// blob: fcd72f915530492fb9881d1535d4bf00026f7367 [file] [log] [blame]
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
// rdar://71642726 this test is crashing with optimizations.
// REQUIRES: swift_test_mode_optimize_none
import _Differentiation
import StdlibUnittest
// Test suite to which the `zeroTangentVector` tests in this file are added via
// `ZeroTangentVectorTests.test(...)`.
var ZeroTangentVectorTests = TestSuite("zeroTangentVectorInitializer")
/// A pair of two independently typed differentiable values.
///
/// The `Differentiable` conformance is declared without any explicit members,
/// so all of its requirements are compiler-derived from `x` and `y`.
struct Generic<T, U>: Differentiable where T: Differentiable, U: Differentiable {
  /// First stored component.
  var x: T
  /// Second stored component.
  var y: U
}
/// Wraps a `Generic<T, U>` in another struct so that the compiler-derived
/// `Differentiable` conformance is exercised through one level of nesting.
struct Nested<T, U>: Differentiable where T: Differentiable, U: Differentiable {
  /// The wrapped differentiable pair.
  var generic: Generic<T, U>
}
// Checks that the derived `zeroTangentVector` of a generic struct and of a
// struct nesting it equals a hand-written zero tangent whose arrays have the
// same element counts as the arrays stored in the original value.
ZeroTangentVectorTests.test("Derivation") {
  typealias G = Generic<[Float], [[Float]]>
  let generic = G(x: [1, 2, 3], y: [[4, 5, 6], [], [2]])
  // Expected zero: one `0` per element of `generic`, empty arrays stay empty.
  let genericZero = G.TangentVector(x: [0, 0, 0], y: [[0, 0, 0], [], [0]])
  // `expectEqual` takes the expected value first and the actual value second,
  // so that a failure report labels the two values correctly.
  expectEqual(genericZero, generic.zeroTangentVector)

  // Same check through one level of nesting.
  let nested = Nested(generic: generic)
  let nestedZero = Nested.TangentVector(generic: genericZero)
  expectEqual(nestedZero, nested.zeroTangentVector)
}
// Test differentiation correctness involving projection operations and
// per-instance zeros.
//
// `test` builds a `Struct` by concatenation and returns only its `x`
// projection; the `withDerivative` closure compares the `dresult` value it
// receives against the expected tangent for the whole struct.
ZeroTangentVectorTests.test("DifferentiationCorrectness") {
  struct Struct: Differentiable {
    var x, y: [Float]
  }
  // Concatenates the `x` arrays and the `y` arrays of the two operands.
  func concatenated(_ lhs: Struct, _ rhs: Struct) -> Struct {
    return Struct(x: lhs.x + rhs.x, y: lhs.y + rhs.y)
  }
  // Function under differentiation: only `result.x` is returned, so `result.y`
  // does not contribute to the output.
  func test(_ s: Struct) -> [Float] {
    let result = concatenated(s, s).withDerivative { dresult in
      // FIXME(TF-1008): Fix incorrect derivative values for
      // "projection operation" operands when differentiation transform uses
      // `Differentiable.zeroTangentVectorInitializer`.
      // Actual: TangentVector(x: [1.0, 1.0, 1.0], y: [])
      // Expected: TangentVector(x: [1.0, 1.0, 1.0], y: [1.0, 1.0, 1.0])
      //
      // `expectEqual` takes the expected value first and the actual value
      // second, so a failure report uses the same labels as the note above.
      expectEqual(Struct.TangentVector(x: [1, 1, 1], y: [1, 1, 1]), dresult)
    }
    return result.x
  }
  let s = Struct(x: [1, 2, 3], y: [1, 2, 3])
  let pb = pullback(at: s, in: test)
  // FIXME(TF-1008): Remove `expectCrash` when differentiation transform uses
  // `Differentiable.zeroTangentVectorInitializer`.
  // NOTE(review): the seed below has three elements although `test(s)` returns
  // the six-element `s.x + s.x`; confirm this is intended.
  expectCrash {
    _ = pb([1, 1, 1])
  }
}
// Entry point of the test executable: invoked after the
// `ZeroTangentVectorTests.test(...)` registrations above.
runAllTests()