blob: a69ea47663464989f4c9b6a8400459d4b07ecb9f [file] [log] [blame]
// RUN: %target-run-simple-swift
import StdlibUnittest
import DifferentiationUnittest
var CurryingAutodiffTests = TestSuite("CurryingAutodiff")

// Checks that a `@differentiable` instance method, once bound to a receiver
// (i.e. partially applied to `self`), can be stored as a `@differentiable`
// function value and then differentiated.
CurryingAutodiffTests.testWithLeakChecking("StructMember") {
  struct Squarer {
    /// Returns `x * x`; differentiable with respect to `x` (not `self`).
    @differentiable(wrt: (x))
    func square(_ x: Tracked<Float>) -> Tracked<Float> {
      x * x
    }
  }

  let squarer = Squarer()

  // Referencing `squarer.square` implicitly builds the curried function
  // (Squarer) -> (Tracked<Float>) -> Tracked<Float> and applies it to
  // `squarer`, leaving a (Tracked<Float>) -> Tracked<Float> value.
  let bound: @differentiable (Tracked<Float>) -> Tracked<Float> = squarer.square

  // d/dx (x * x) = 2x, which evaluates to 6 at x = 3.
  let point = Tracked<Float>(3.0)
  expectEqual(6.0, point.gradient(in: bound))
}

runAllTests()