blob: 6c475d4d11c665b4dfd72a828af2adb58dfb2d05 [file] [log] [blame]
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
import StdlibUnittest
import DifferentiationUnittest
/// Suite for derivatives whose registered differentiability parameters are a
/// superset of the parameters actually being differentiated.
///
/// Bound with `let`: the suite is only used to register tests and is never
/// reassigned.
let SupersetVJPTests = TestSuite("SupersetVJP")
/// Returns `x * y` (or `y` when `x > 1000`), with the custom VJP `dmulxy`
/// registered for both parameters `(x, y)`.
///
/// The tests below differentiate it with respect to a single parameter, so the
/// derivative presumably has to come from this (x, y) superset VJP.
@differentiable(wrt: (x, y), vjp: dmulxy)
func mulxy(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
// Use control flow to prevent AD from deriving a derivative from this body,
// so the custom VJP is what gets exercised. NB: revisit once control flow
// differentiation is supported.
// NOTE(review): `dmulxy` does not model this branch; its pullback is only
// correct for x <= 1000, which covers every input used in this file.
if x > 1000 {
return y
}
return x * y
}
/// Custom VJP for `mulxy` with respect to both `x` and `y`.
///
/// Returns `mulxy(x, y)` together with a pullback that maps a cotangent to the
/// pair `(y * cotangent, x * cotangent)`, i.e. the partials of `x * y`.
func dmulxy(
  _ x: Tracked<Float>,
  _ y: Tracked<Float>
) -> (Tracked<Float>, (Tracked<Float>) -> (Tracked<Float>, Tracked<Float>)) {
  let value = mulxy(x, y)
  let pullback: (Tracked<Float>) -> (Tracked<Float>, Tracked<Float>) = { cotangent in
    (y * cotangent, x * cotangent)
  }
  return (value, pullback)
}
/// Thin wrapper that forwards to `mulxy`.
///
/// It carries no `@differentiable` attribute, so differentiating it requires
/// AD to differentiate through the nested call to `mulxy`.
func calls_mulxy(_ x: Tracked<Float>, _ y: Tracked<Float>) -> Tracked<Float> {
  mulxy(x, y)
}
SupersetVJPTests.testWithLeakChecking("Superset") {
  // Differentiate `mulxy` w.r.t. `x` only (y held at 3); the only registered
  // derivative is the (x, y) VJP `dmulxy`. Expected: d(x * 3)/dx = 3.
  let dx = gradient(at: 2, in: { x in mulxy(x, 3) })
  expectEqual(3, dx)
}
SupersetVJPTests.testWithLeakChecking("SupersetNested") {
  // Differentiate w.r.t. `y` only, reaching `mulxy` through the
  // non-`@differentiable` wrapper `calls_mulxy`. Expected: d(2 * y)/dy = 2.
  let dy = gradient(at: 3, in: { y in calls_mulxy(2, y) })
  expectEqual(2, dy)
}
SupersetVJPTests.testWithLeakChecking("CrossModuleClosure") {
  // The closure only applies `+` on `Tracked<Float>`, an operator presumably
  // declared in DifferentiationUnittest rather than this module.
  // Expected: d(x + 2)/dx = 1.
  let point = Tracked<Float>(1)
  expectEqual(1, gradient(at: point, in: { $0 + 2 }))
}
SupersetVJPTests.testWithLeakChecking("SubsetOfSubset") {
// `foo` is declared differentiable w.r.t. (x, z); the gradient below is taken
// w.r.t. `x` alone, i.e. a subset of that declared parameter subset.
@differentiable(wrt: (x, z))
func foo(_ x: Tracked<Float>, _ y: Tracked<Float>, _ z: Tracked<Float>) -> Tracked<Float> {
// Constant result with no derivative, so the expected gradient is 0.
withoutDerivative(at: 0)
}
expectEqual(0, gradient(at: 0, in: { x in foo(x, 0, 0) }))
}
SupersetVJPTests.test("ApplySubset") {
// TF-914
// Regression test: `foo` is differentiable only w.r.t. `x`, but calls the
// `@differentiable (T, T) -> T` argument `apply`. Differentiating `foo` thus
// needs `apply`'s derivative w.r.t. its first parameter only, a subset of the
// parameters `apply` is differentiable over. Expected: d(x + 0)/dx = 1.
// NOTE(review): unlike the other tests in this suite, this registers with
// `test` rather than `testWithLeakChecking`, so `Tracked` leaks are not
// checked here; confirm whether that is intentional.
@differentiable(wrt: x)
func foo<T: Differentiable>(_ x: T, _ y: T, apply: @differentiable (T, T) -> T) -> T {
return apply(x, y)
}
expectEqual(1, gradient(at: Tracked<Float>(0)) { x in foo(x, 0) { $0 + $1 } })
}
// FIXME: The expression `(+) as @differentiable (Float, @nondiff Float) -> Float`
// forms a curry thunk of `Float.+` before conversion to @differentiable, and AD
// doesn't know how to differentiate the curry thunk, so it produces a
// "function is not differentiable" error.
// SupersetVJPTests.test("CrossModule") {
// let grad = gradient(at: Float(1), Float(2), in: (+) as @differentiable (Float, @nondiff Float) -> Float)
// expectEqual(Float(1), grad)
// }
SupersetVJPTests.testWithLeakChecking("IndirectResults") {
// `x_T` is generic over `T`, so its pullback's second result has the generic
// type `T.TangentVector` (presumably returned indirectly, hence the name).
@differentiable(wrt: (x, y), vjp: dx_T)
func x_T<T : Differentiable>(_ x: Tracked<Float>, _ y: T) -> Tracked<Float> {
// Both branches return `x`; the branch presumably exists only to keep AD from
// deriving a derivative from the body, so the custom VJP `dx_T` is used.
if x > 1000 { return x }
return x
}
// Custom VJP for `x_T` w.r.t. (x, y). NOTE: the pullback returns `x * v`, not
// the true derivative of the identity (`v`). The expected gradient of 2 at
// x = 2 below therefore confirms that this custom VJP, not a synthesized
// derivative, was used. The cotangent for `y` is always `.zero`.
func dx_T<T : Differentiable>(
_ x: Tracked<Float>, _ y: T
) -> (Tracked<Float>, (Tracked<Float>) -> (Tracked<Float>, T.TangentVector)) {
return (x_T(x, y), { v in (x * v, .zero) })
}
expectEqual(2, gradient(at: 2) { x in x_T(x, Tracked<Float>(3)) })
}
// Hand control to StdlibUnittest to run the tests registered above.
runAllTests()