blob: 2b4dce4d9cca73a990aeca54f8a15ac088d6cf29 [file] [log] [blame]
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
import StdlibUnittest
import _Differentiation
// Test suite grouping the subset-parameter-thunk cases in this file. The
// binding is never reassigned -- cases are only added through `.test(...)`
// calls -- so it is declared as a constant.
let SubsetParameterThunkTests = TestSuite("SubsetParameterThunks")
/// A concrete-typed no-op that exists only to carry the custom derivative
/// registered for it with `@derivative(of: inoutDirect)`.
///
/// - Parameters:
///   - x: Ignored.
///   - y: Left unchanged.
///   - z: Ignored.
func inoutDirect(_ x: Float, _ y: inout Double, _ z: Float) {
  // Intentionally empty: only the registered derivative matters.
}
/// Custom VJP for `inoutDirect`, covering `x`, `y` (inout) and `z`, as the
/// pullback type shows.
///
/// The pullback ignores the value of the incoming cotangent. It overwrites it
/// with 3, which becomes the derivative reported for `y`, and returns 2 and 4
/// as the derivatives for `x` and `z`.
@derivative(of: inoutDirect)
func vjpInoutDirect(_ x: Float, _ y: inout Double, _ z: Float) -> (
  value: Void, pullback: (inout Double) -> (Float, Float)
) {
  // A named local function that captures nothing stands in for a closure literal.
  func backpropagate(_ cotangent: inout Double) -> (Float, Float) {
    cotangent = 3
    return (2, 4)
  }
  return (value: (), pullback: backpropagate)
}
// Differentiates a caller of `inoutDirect` wrt all of its parameters and wrt
// each two-parameter subset. Only one derivative of `inoutDirect` is
// registered (wrt `x`, `y` and `z`, per `vjpInoutDirect`'s pullback type).
// NOTE(review): subset differentiation presumably goes through a compiler-
// generated subset-parameters thunk of that derivative (hence the suite
// name) -- confirm against the AutoDiff implementation.
SubsetParameterThunkTests.test("InoutParametersDirect") {
// Declare the caller differentiable wrt each parameter individually.
@differentiable(wrt: x)
@differentiable(wrt: y)
@differentiable(wrt: z)
func inoutDirectCaller(_ x: Float, _ y: Double, _ z: Float) -> Double {
// Copy `y` so it can be passed `inout`; the caller returns whatever
// `inoutDirect` leaves in `result`.
var result = y
inoutDirect(x, &result, z)
return result
}
let x: Float = 3
let y: Double = 4
let z: Float = 5
// Expected values come straight from `vjpInoutDirect`'s pullback. It returns
// 2 for `x` and 4 for `z`, and it overwrites the incoming cotangent with 3,
// which becomes the gradient for `y`. None of them depend on x, y, z.
expectEqual((2, 3, 4), gradient(at: x, y, z, in: inoutDirectCaller))
// Each remaining check differentiates wrt a two-parameter subset. The third
// parameter is held fixed by capturing it in the closure.
expectEqual((3, 4), gradient(at: y, z, in: { y, z in inoutDirectCaller(x, y, z) }))
expectEqual((2, 4), gradient(at: x, z, in: { x, z in inoutDirectCaller(x, y, z) }))
expectEqual((2, 3), gradient(at: x, y, in: { x, y in inoutDirectCaller(x, y, z) }))
}
/// Generic no-op that exists only to carry the custom derivative registered
/// for it with `@derivative(of: inoutIndirect)`.
///
/// - Parameters:
///   - x: Ignored.
///   - y: Left unchanged.
///   - z: Ignored.
func inoutIndirect<T: Differentiable, U: Differentiable, V: Differentiable>(
  _ x: T, _ y: inout U, _ z: V
) {
  // Intentionally empty: only the registered derivative matters.
}
/// Custom VJP for `inoutIndirect`, covering `x`, `y` (inout) and `z`, as the
/// pullback type shows.
///
/// The pullback passes the incoming cotangent for `y` through unchanged and
/// reports zero derivatives for `x` and `z`.
@derivative(of: inoutIndirect)
func vjpInoutIndirect<T: Differentiable, U: Differentiable, V: Differentiable>(
  _ x: T, _ y: inout U, _ z: V
) -> (
  value: Void, pullback: (inout U.TangentVector) -> (T.TangentVector, V.TangentVector)
) {
  func backpropagate(
    _ cotangent: inout U.TangentVector
  ) -> (T.TangentVector, V.TangentVector) {
    // `cotangent` is deliberately left untouched.
    (T.TangentVector.zero, V.TangentVector.zero)
  }
  return (value: (), pullback: backpropagate)
}
// Generic counterpart of "InoutParametersDirect". It differentiates a generic
// caller of `inoutIndirect` wrt all of its parameters and wrt each
// two-parameter subset. Only one derivative of `inoutIndirect` is registered
// (wrt `x`, `y` and `z`, per `vjpInoutIndirect`'s pullback type).
// NOTE(review): "Indirect" presumably refers to the generic parameters being
// passed indirectly -- inferred from naming, confirm.
SubsetParameterThunkTests.test("InoutParametersIndirect") {
// Declare the caller differentiable wrt each parameter individually, plus
// (bare `@differentiable`) wrt all parameters.
@differentiable(wrt: x)
@differentiable(wrt: y)
@differentiable(wrt: z)
@differentiable
func inoutIndirectCaller<T: Differentiable, U: Differentiable, V: Differentiable>(
_ x: T, _ y: U, _ z: V
) -> U {
// Copy `y` so it can be passed `inout`; the caller returns whatever
// `inoutIndirect` leaves in `result`.
var result = y
inoutIndirect(x, &result, z)
return result
}
// Instantiated with T = Float, U = Double, V = Float.
let x: Float = 3
let y: Double = 4
let z: Float = 5
// `vjpInoutIndirect`'s pullback returns zeros for `x` and `z` and leaves the
// cotangent for `y` untouched, so the gradient for `y` is the seed (1).
expectEqual((0, 1, 0), gradient(at: x, y, z, in: inoutIndirectCaller))
// Each remaining check differentiates wrt a two-parameter subset. The third
// parameter is held fixed by capturing it in the closure.
expectEqual((1, 0), gradient(at: y, z, in: { y, z in inoutIndirectCaller(x, y, z) }))
expectEqual((0, 0), gradient(at: x, z, in: { x, z in inoutIndirectCaller(x, y, z) }))
expectEqual((0, 1), gradient(at: x, y, in: { x, y in inoutIndirectCaller(x, y, z) }))
}
// StdlibUnittest entry point; presumably runs every test registered on the
// suites in this file.
runAllTests()