blob: f5394d6c33379932df55a7974962f77e3a6a4d01 [file] [log] [blame]
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
import StdlibUnittest
#if os(macOS)
import Darwin.C
#else
import Glibc
#endif
import DifferentiationUnittest
// Test suite for a `Differentiable` type whose `TangentVector` is written by
// hand instead of being synthesized. Bound with `let`: the suite reference is
// never reassigned, only used as the receiver of `testWithLeakChecking`.
let SeparateTangentTypeTests = TestSuite("SeparateTangentType")
/// A `Differentiable` struct whose tangent space is declared manually.
///
/// Only `w` and `b` participate in differentiation. `flag` is marked
/// `@noDerivative`, has no counterpart in `TangentVector`, and is never
/// modified by `move(along:)`.
struct DifferentiableSubset : Differentiable {
  @differentiable(wrt: self)
  var w: Tracked<Float>
  @differentiable(wrt: self)
  var b: Tracked<Float>
  @noDerivative var flag: Bool

  /// Moves each differentiable component along the matching component of
  /// `direction`; `flag` is deliberately left as-is.
  mutating func move(along direction: TangentVector) {
    self.w.move(along: direction.w)
    self.b.move(along: direction.b)
  }

  /// Hand-written tangent space: one stored property per differentiable
  /// stored property of the parent. It serves as its own tangent vector.
  struct TangentVector : Differentiable, AdditiveArithmetic {
    typealias TangentVector = DifferentiableSubset.TangentVector
    var w: Tracked<Float>
    var b: Tracked<Float>
  }
}
// The pullback of the identity function must return the incoming cotangent
// unchanged. Pullbacks are linear, so a zero seed always maps to zero; a
// non-zero seed is what actually exercises the derivative.
SeparateTangentTypeTests.testWithLeakChecking("Trivial") {
  let x = DifferentiableSubset(w: 0, b: 1, flag: false)
  let pb = pullback(at: x) { x in x }
  let seed = DifferentiableSubset.TangentVector(w: 1, b: 2)
  expectEqual(pb(seed), seed)
  expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
}
// A closure that builds a fresh value without reading its input is constant
// with respect to `x`. Its pullback must therefore yield a zero gradient even
// for a non-zero seed; a zero seed alone would pass trivially.
SeparateTangentTypeTests.testWithLeakChecking("Initialization") {
  let x = DifferentiableSubset(w: 0, b: 1, flag: false)
  let pb = pullback(at: x) { x in DifferentiableSubset(w: 1, b: 2, flag: true) }
  let seed = DifferentiableSubset.TangentVector(w: 1, b: 1)
  expectEqual(pb(seed), DifferentiableSubset.TangentVector.zero)
  expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
}
// f(x) = (w: x.w * x.w, b: x.b * x.b), so the pullback of seed (vw, vb) is
// (2 * x.w * vw, 2 * x.b * vb). Non-zero `w` and `b` ensure that a missing
// gradient for either component is detected. The values are exact in Float.
SeparateTangentTypeTests.testWithLeakChecking("SomeArithmetics") {
  let x = DifferentiableSubset(w: 2, b: 3, flag: false)
  let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) }
  let seed = DifferentiableSubset.TangentVector(w: 1, b: 1)
  expectEqual(pb(seed), DifferentiableSubset.TangentVector(w: 4, b: 6))
  expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
}
// Hand control to StdlibUnittest, which runs the tests registered above.
runAllTests()