// blob: 78484a0fc3f837d77830ebdf25248602fa87a9a5 [file] [log] [blame]
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
import StdlibUnittest
import DifferentiationUnittest
// Test suite on which the class-method differentiation tests below are registered.
var ClassMethodTests: TestSuite = TestSuite("ClassMethods")
// Differentiates through a method of a `final` class: for `x * x` the
// gradient at `i` is expected to be `2 * i` for every integer in -5...5.
ClassMethodTests.test("Final") {
  final class Final : Differentiable {
    func method(_ x: Tracked<Float>) -> Tracked<Float> {
      x * x
    }
  }

  for i in -5...5 {
    let input = Tracked<Float>(Float(i))
    let expected = Tracked<Float>(Float(i * 2))
    let actual = gradient(at: input) { x in
      Final().method(x)
    }
    expectEqual(expected, actual)
  }
}
// Differentiates `f` with respect to `x` through a `Super`-typed reference.
// The expectations below require the value and gradient of the overriding
// implementation when a subclass instance is passed in.
ClassMethodTests.test("Simple") {
  // Base class: `f` is `2 * x` with hand-written JVP and VJP.
  class Super {
    @differentiable(wrt: x, jvp: jvpf, vjp: vjpf)
    func f(_ x: Tracked<Float>) -> Tracked<Float> {
      2 * x
    }

    final func jvpf(_ x: Tracked<Float>) -> (Tracked<Float>, (Tracked<Float>) -> Tracked<Float>) {
      let differential: (Tracked<Float>) -> Tracked<Float> = { tangent in 2 * tangent }
      return (f(x), differential)
    }

    final func vjpf(_ x: Tracked<Float>) -> (Tracked<Float>, (Tracked<Float>) -> Tracked<Float>) {
      let backprop: (Tracked<Float>) -> Tracked<Float> = { cotangent in 2 * cotangent }
      return (f(x), backprop)
    }
  }

  // Override without custom derivatives: `f` is `3 * x`.
  class SubOverride : Super {
    @differentiable(wrt: x)
    override func f(_ x: Tracked<Float>) -> Tracked<Float> {
      3 * x
    }
  }

  // Override that also supplies its own JVP and VJP for `3 * x`.
  class SubOverrideCustomDerivatives : Super {
    @differentiable(wrt: x, jvp: jvpf2, vjp: vjpf2)
    override func f(_ x: Tracked<Float>) -> Tracked<Float> {
      3 * x
    }

    final func jvpf2(_ x: Tracked<Float>) -> (Tracked<Float>, (Tracked<Float>) -> Tracked<Float>) {
      let differential: (Tracked<Float>) -> Tracked<Float> = { tangent in 3 * tangent }
      return (f(x), differential)
    }

    final func vjpf2(_ x: Tracked<Float>) -> (Tracked<Float>, (Tracked<Float>) -> Tracked<Float>) {
      let backprop: (Tracked<Float>) -> Tracked<Float> = { cotangent in 3 * cotangent }
      return (f(x), backprop)
    }
  }

  // Returns `(c.f(1), d/dx c.f(x) at x = 1)` for the given instance.
  func classValueWithGradient(_ c: Super) -> (Tracked<Float>, Tracked<Float>) {
    valueWithGradient(at: 1, in: { x in c.f(x) })
  }

  expectEqual((2, 2), classValueWithGradient(Super()))
  expectEqual((3, 3), classValueWithGradient(SubOverride()))
  expectEqual((3, 3), classValueWithGradient(SubOverrideCustomDerivatives()))
}
// Differentiates `f` with respect to both `self` and `x` through a
// `Super`-typed reference and compares the resulting gradients for the base
// class and for two overriding subclasses.
ClassMethodTests.test("SimpleWrtSelf") {
  class Super : Differentiable {
    var base: Tracked<Float>
    // FIXME(TF-648): Dummy to make `Super.AllDifferentiableVariables` be nontrivial.
    var _nontrivial: [Tracked<Float>] = []

    // TODO(TF-654): Uncomment attribute when differentiation supports class initializers.
    // TODO(TF-645): Remove `vjpInit` when differentiation supports `ref_element_addr`.
    // @differentiable(vjp: vjpInit)
    required init(base: Tracked<Float>) {
      self.base = base
    }

    static func vjpInit(base: Tracked<Float>) -> (Super, (TangentVector) -> Tracked<Float>) {
      return (Super(base: base), { x in x.base })
    }

    @differentiable(wrt: (self, x), jvp: jvpf, vjp: vjpf)
    func f(_ x: Tracked<Float>) -> Tracked<Float> {
      return base * x
    }

    // JVP of `f` with respect to `(self, x)`.
    final func jvpf(
      _ x: Tracked<Float>
    ) -> (Tracked<Float>, (TangentVector, Tracked<Float>) -> Tracked<Float>) {
      let base = self.base
      // Product rule: d(base * x) = dbase * x + base * dx. This mirrors the
      // two components returned by the pullback in `vjpf`.
      return (f(x), { (dself, dx) in dself.base * x + base * dx })
    }

    // VJP of `f` with respect to `(self, x)`.
    final func vjpf(
      _ x: Tracked<Float>
    ) -> (Tracked<Float>, (Tracked<Float>) -> (TangentVector, Tracked<Float>)) {
      let base = self.base
      return (f(x), { v in
        (TangentVector(base: v * x, _nontrivial: []), base * v)
      })
    }
  }

  // Override that ignores `base`: `f` is `3 * x`, derivatives are not hand-written.
  class SubOverride : Super {
    @differentiable(wrt: (self, x))
    override func f(_ x: Tracked<Float>) -> Tracked<Float> {
      return 3 * x
    }
  }

  // Override with a custom JVP/VJP for the `wrt: x` configuration only.
  class SubOverrideCustomDerivatives : Super {
    @differentiable(wrt: (self, x))
    @differentiable(wrt: x, jvp: jvpf2, vjp: vjpf2)
    override func f(_ x: Tracked<Float>) -> Tracked<Float> {
      return 3 * x
    }

    final func jvpf2(_ x: Tracked<Float>) -> (Tracked<Float>, (Tracked<Float>) -> Tracked<Float>) {
      return (f(x), { v in 3 * v })
    }

    final func vjpf2(_ x: Tracked<Float>) -> (Tracked<Float>, (Tracked<Float>) -> Tracked<Float>) {
      return (f(x), { v in 3 * v })
    }
  }

  // TODO(TF-654): Uncomment when differentiation supports class initializers.
  /*
  let v = Super.TangentVector(base: 100, _nontrivial: [])
  expectEqual(100, pullback(at: 1337) { x in Super(base: x) }(v))
  expectEqual(100, pullback(at: 1337) { x in SubOverride(base: x) }(v))
  expectEqual(100, pullback(at: 1337) { x in SubOverrideCustomDerivatives(base: x) }(v))
  */

  // `valueWithGradient` is not used because nested tuples cannot be compared
  // with `expectEqual`.
  func classGradient(_ c: Super) -> (Super.TangentVector, Tracked<Float>) {
    return gradient(at: c, 10) { c, x in c.f(x) }
  }

  // Base class: d/dbase (base * x) = x = 10 and d/dx (base * x) = base = 2.
  expectEqual((Super.TangentVector(base: 10, _nontrivial: []), 2),
              classGradient(Super(base: 2)))
  // Overrides compute `3 * x`: no dependence on `base`, and d/dx = 3.
  expectEqual((Super.TangentVector(base: 0, _nontrivial: []), 3),
              classGradient(SubOverride(base: 2)))
  expectEqual((Super.TangentVector(base: 0, _nontrivial: []), 3),
              classGradient(SubOverrideCustomDerivatives(base: 2)))
}
// Differentiates `f` through a generic base class and through generic and
// specialized subclasses, with and without custom derivatives.
ClassMethodTests.test("Generics") {
  // Generic base class: `f` is `2 * x` with hand-written JVP and VJP.
  class Super<T : Differentiable & FloatingPoint> where T == T.TangentVector {
    @differentiable(wrt: x, jvp: jvpf, vjp: vjpf)
    func f(_ x: Tracked<T>) -> Tracked<T> {
      return Tracked<T>(2) * x
    }

    final func jvpf(
      _ x: Tracked<T>
    ) -> (Tracked<T>, (Tracked<T>.TangentVector) -> Tracked<T>.TangentVector) {
      return (f(x), { v in Tracked<T>(2) * v })
    }

    final func vjpf(
      _ x: Tracked<T>
    ) -> (Tracked<T>, (Tracked<T>.TangentVector) -> Tracked<T>.TangentVector) {
      return (f(x), { v in Tracked<T>(2) * v })
    }
  }

  // Generic override: `f` is the identity.
  class SubOverride<T : Differentiable & FloatingPoint> : Super<T> where T == T.TangentVector {
    @differentiable(wrt: x)
    override func f(_ x: Tracked<T>) -> Tracked<T> {
      return x
    }
  }

  // Override that fixes `T` to `Float`: `f` is `3 * x`.
  class SubSpecializeOverride : Super<Float> {
    @differentiable(wrt: x)
    override func f(_ x: Tracked<Float>) -> Tracked<Float> {
      return 3 * x
    }
  }

  // Generic override of `3 * x` with its own JVP and VJP.
  class SubOverrideCustomDerivatives<T : Differentiable & FloatingPoint> : Super<T>
  where T == T.TangentVector {
    @differentiable(wrt: x, jvp: jvpf2, vjp: vjpf2)
    override func f(_ x: Tracked<T>) -> Tracked<T> {
      return Tracked<T>(3) * x
    }

    final func jvpf2(
      _ x: Tracked<T>
    ) -> (Tracked<T>, (Tracked<T>.TangentVector) -> Tracked<T>.TangentVector) {
      return (f(x), { v in Tracked<T>(3) * v })
    }

    final func vjpf2(
      _ x: Tracked<T>
    ) -> (Tracked<T>, (Tracked<T>.TangentVector) -> Tracked<T>.TangentVector) {
      return (f(x), { v in Tracked<T>(3) * v })
    }
  }

  // `Float80` only exists on x86 targets other than Windows and Android, so
  // the `Float80` specialization is compiled only where the type is available.
  #if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
  // Override that fixes `T` to `Float80`, with its own JVP and VJP.
  class SubSpecializeOverrideCustomDerivatives : Super<Float80> {
    @differentiable(wrt: x, jvp: jvpf2, vjp: vjpf2)
    override func f(_ x: Tracked<Float80>) -> Tracked<Float80> {
      return 3 * x
    }

    final func jvpf2(
      _ x: Tracked<Float80>
    ) -> (Tracked<Float80>, (Tracked<Float80>) -> Tracked<Float80>) {
      return (f(x), { v in 3 * v })
    }

    final func vjpf2(
      _ x: Tracked<Float80>
    ) -> (Tracked<Float80>, (Tracked<Float80>) -> Tracked<Float80>) {
      return (f(x), { v in 3 * v })
    }
  }
  #endif

  // Returns the unwrapped `(c.f(1), d/dx c.f(x) at x = 1)` for the given instance.
  func classValueWithGradient<T : Differentiable & FloatingPoint>(
    _ c: Super<T>
  ) -> (T, T) where T == T.TangentVector {
    let (value, grad) = valueWithGradient(at: Tracked<T>(1), in: {
      c.f($0) })
    return (value.value, grad.value)
  }

  expectEqual((2, 2), classValueWithGradient(Super<Float>()))
  expectEqual((1, 1), classValueWithGradient(SubOverride<Float>()))
  expectEqual((3, 3), classValueWithGradient(SubSpecializeOverride()))
  expectEqual((3, 3), classValueWithGradient(SubOverrideCustomDerivatives<Float>()))
  #if !(os(Windows) || os(Android)) && (arch(i386) || arch(x86_64))
  expectEqual((3, 3), classValueWithGradient(SubSpecializeOverrideCustomDerivatives()))
  #endif
}
// Differentiates a zero-argument method with respect to `self`, for a base
// class and for a subclass that overrides the method and its VJP.
ClassMethodTests.test("Methods") {
  class Super : Differentiable {
    var base: Tracked<Float>
    // Dummy to make `Super.AllDifferentiableVariables` be nontrivial.
    var _nontrivial: [Tracked<Float>] = []

    // TODO(TF-654): Uncomment attribute when differentiation supports class initializers.
    // TODO(TF-645): Remove `vjpInit` when differentiation supports `ref_element_addr`.
    // @differentiable(vjp: vjpInit)
    init(base: Tracked<Float>) {
      self.base = base
    }

    static func vjpInit(base: Tracked<Float>) -> (Super, (TangentVector) -> Tracked<Float>) {
      (Super(base: base), { tangent in tangent.base })
    }

    @differentiable(vjp: vjpSquared)
    func squared() -> Tracked<Float> {
      base * base
    }

    // VJP of `squared`: the pullback maps `v` to `2 * base * v` in the `base` slot.
    final func vjpSquared() -> (Tracked<Float>, (Tracked<Float>) -> TangentVector) {
      let captured = self.base
      let backprop: (Tracked<Float>) -> TangentVector = { v in
        TangentVector(base: 2 * captured * v, _nontrivial: [])
      }
      return (captured * captured, backprop)
    }
  }

  class Sub1 : Super {
    @differentiable(vjp: vjpSquared2)
    override func squared() -> Tracked<Float> {
      base * base
    }

    // VJP registered for the override; same formula as `Super.vjpSquared`.
    final func vjpSquared2() -> (Tracked<Float>, (Tracked<Float>) -> TangentVector) {
      let captured = self.base
      let backprop: (Tracked<Float>) -> TangentVector = { v in
        TangentVector(base: 2 * captured * v, _nontrivial: [])
      }
      return (captured * captured, backprop)
    }
  }

  // NOTE(review): this helper is not called anywhere in this test.
  func classValueWithGradient(_ c: Super) -> (Tracked<Float>, Super.TangentVector) {
    valueWithGradient(at: c, in: { object in object.squared() })
  }

  // TODO(TF-654, TF-645): Uncomment when differentiation supports class initializers or `ref_element_addr`.
  // expectEqual(4, gradient(at: 2) { x in Super(base: x).squared() })
  // TODO(TF-647): Handle `unchecked_ref_cast` in `Sub1.init` during pullback generation.
  // FIXME: `Super.init` VJP type mismatch for empty `Super.AllDifferentiableVariables`:
  // SIL verification failed: VJP type does not match expected VJP type
  // $@convention(method) (Tracked<Float>, @thick Super.Type) -> (@owned Super, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Tracked<Float>)
  // $@convention(method) (Tracked<Float>, @owned Super) -> (@owned Super, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Tracked<Float>)
  // expectEqual(4, gradient(at: 2) { x in Sub1(base: x).squared() })

  // d/dbase (base * base) at base = 2 is 4 for both classes.
  let superGradient = gradient(at: Super(base: 2)) { foo in foo.squared() }
  expectEqual(Super.TangentVector(base: 4, _nontrivial: []), superGradient)

  let subGradient = gradient(at: Sub1(base: 2)) { foo in foo.squared() }
  expectEqual(Sub1.TangentVector(base: 4, _nontrivial: []), subGradient)
}
// Differentiates a computed property with a custom VJP with respect to `self`.
ClassMethodTests.test("Properties") {
  class Super : Differentiable {
    var base: Tracked<Float>

    // TODO(TF-654): Uncomment attribute when differentiation supports class initializers.
    // TODO(TF-645): Remove `vjpInit` when differentiation supports `ref_element_addr`.
    // @differentiable(vjp: vjpInit)
    init(base: Tracked<Float>) {
      self.base = base
    }

    static func vjpInit(base: Tracked<Float>) -> (Super, (TangentVector) -> Tracked<Float>) {
      (Super(base: base), { tangent in tangent.base })
    }

    @differentiable(vjp: vjpSquared)
    var squared: Tracked<Float> {
      base * base
    }

    // VJP of `squared`: the pullback maps `v` to `2 * base * v`.
    final func vjpSquared() -> (Tracked<Float>, (Tracked<Float>) -> TangentVector) {
      let captured = self.base
      let backprop: (Tracked<Float>) -> TangentVector = { v in
        TangentVector(base: 2 * captured * v)
      }
      return (captured * captured, backprop)
    }
  }

  class Sub1 : Super {
    // FIXME(TF-625): Crash due to `Super.AllDifferentiableVariables` abstraction pattern mismatch.
    // SIL verification failed: vtable entry for #<anonymous function>Super.squared!getter.1.jvp.1.S must be ABI-compatible
    // ABI incompatible return values
    // @convention(method) (@guaranteed Super) -> (Tracked<Float>, @owned @callee_guaranteed (@guaranteed Super.AllDifferentiableVariables) -> Tracked<Float>)
    // @convention(method) (@guaranteed Sub1) -> (Tracked<Float>, @owned @callee_guaranteed (Super.AllDifferentiableVariables) -> Tracked<Float>)
    // @differentiable
    // override var squared: Tracked<Float> { base * base }
  }

  // NOTE(review): this helper is not called anywhere in this test.
  func classValueWithGradient(_ c: Super) -> (Tracked<Float>, Super.TangentVector) {
    valueWithGradient(at: c, in: { object in object.squared })
  }

  // TODO(TF-654, TF-645): Uncomment when differentiation supports class initializers or `ref_element_addr`.
  // expectEqual(4, gradient(at: 2) { x in Super(base: x).squared })

  // d/dbase (base * base) at base = 2 is 4.
  let superGradient = gradient(at: Super(base: 2)) { foo in foo.squared }
  expectEqual(Super.TangentVector(base: 4), superGradient)
}
// Checks which value of `coefficient` a derivative sees when the property is
// mutated after the differentiated call but before the function returns.
ClassMethodTests.test("Capturing") {
  class Multiplier {
    var coefficient: Tracked<Float>

    init(_ coefficient: Tracked<Float>) {
      self.coefficient = coefficient
    }

    // Case 1: generated VJP.
    @differentiable
    func apply(to x: Tracked<Float>) -> Tracked<Float> {
      coefficient * x
    }

    // Case 2: custom VJP capturing `self`.
    @differentiable(wrt: (x), vjp: vjpApply2)
    func apply2(to x: Tracked<Float>) -> Tracked<Float> {
      coefficient * x
    }

    final func vjpApply2(
      to x: Tracked<Float>
    ) -> (Tracked<Float>, (Tracked<Float>) -> Tracked<Float>) {
      let value = coefficient * x
      // The pullback reads `self.coefficient` only when it is invoked.
      return (value, { cotangent in self.coefficient * cotangent })
    }

    // Case 3: custom VJP capturing `self.coefficient`.
    @differentiable(wrt: x, vjp: vjpApply3)
    func apply3(to x: Tracked<Float>) -> Tracked<Float> {
      coefficient * x
    }

    final func vjpApply3(
      to x: Tracked<Float>
    ) -> (Tracked<Float>, (Tracked<Float>) -> Tracked<Float>) {
      // Snapshot the coefficient now so later mutation cannot affect the pullback.
      let snapshot = self.coefficient
      return (snapshot * x, { cotangent in snapshot * cotangent })
    }
  }

  // Case 1: the expected gradient is the coefficient at call time (10).
  func f(_ x: Tracked<Float>) -> Tracked<Float> {
    let multiplier = Multiplier(10)
    let product = multiplier.apply(to: x)
    multiplier.coefficient += 1
    return product
  }
  expectEqual(10, gradient(at: 1, in: f))

  // Case 2: the expected gradient is 11, i.e. the coefficient after `+= 1`.
  func f2(_ x: Tracked<Float>) -> Tracked<Float> {
    let multiplier = Multiplier(10)
    let product = multiplier.apply2(to: x)
    multiplier.coefficient += 1
    return product
  }
  expectEqual(11, gradient(at: 1, in: f2))

  // Case 3: the snapshot taken in `vjpApply3` keeps the expected gradient at 10.
  func f3(_ x: Tracked<Float>) -> Tracked<Float> {
    let multiplier = Multiplier(10)
    let product = multiplier.apply3(to: x)
    multiplier.coefficient += 1
    return product
  }
  expectEqual(10, gradient(at: 1, in: f3))
}
// Kept as the final statement so that every `ClassMethodTests.test(...)`
// registration above has already executed before the tests are run.
runAllTests()