blob: 124d428b5cd299b50d36be14136bac90eb3f05d0 [file] [log] [blame]
// RUN: %target-swift-emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-SIL
// Generic identity function used to exercise pullback generation for a
// generic original function. `@_silgen_name("identity")` pins its SIL symbol
// name; the pullback label matched by the FileCheck patterns below
// (`AD__identity__pullback_...`) embeds that name.
@_silgen_name("identity")
func identity<T : Differentiable>(_ x: T) -> T {
return x
}
// Differentiate `identity` through a wrapping closure at a `Float` argument,
// which requires emitting the pullback whose SIL is matched below.
_ = gradient(at: Float(1), in: { x in identity(x) })
// Test AdjointEmitter local buffer allocation.
// Verify that local buffers are immediately set to zero.
// CHECK-SIL-LABEL: sil private @AD__identity__pullback_src_0_wrt_0_s14DifferentiableRzl
// CHECK-SIL: [[ORIG_COTAN:%.*]] = alloc_stack $τ_0_0.TangentVector
// CHECK-SIL-NEXT: [[ZERO_WITNESS:%.*]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter
// CHECK-SIL-NEXT: [[ORIG_COTAN_METATYPE:%.*]] = metatype $@thick τ_0_0.TangentVector.Type
// CHECK-SIL-NEXT: [[EMIT_ZERO_INDIRECT:%.*]] = apply [[ZERO_WITNESS]]<τ_0_0.TangentVector>([[ORIG_COTAN]], [[ORIG_COTAN_METATYPE]])
// CHECK-SIL: }
// Test TF-201: differentiate direct references to generic function.
// This involves reabstraction thunk differentiation.
// Unlike the closure form used earlier, this passes the generic `identity`
// function itself, so a reabstraction thunk around it must be differentiated.
_ = gradient(at: Float(1), in: identity)
// Protocol refining `Differentiable & AdditiveArithmetic` whose `+`
// requirement is itself `@differentiable`, so generic code constrained to it
// can be differentiated through `+`.
protocol DifferentiableAdditiveArithmetic: Differentiable & AdditiveArithmetic {
@differentiable
static func + (lhs: Self, rhs: Self) -> Self
}
// `Float` satisfies the protocol with its existing members; nothing is added
// here.
extension Float: DifferentiableAdditiveArithmetic {}
// Overload of `generic` constrained to the protocol above; it is
// differentiable only through the protocol's `@differentiable` `+`.
// Returns `x + x + x`.
func generic<T: DifferentiableAdditiveArithmetic>(_ x: T) -> T {
x + x + x
}
// Direct reference to the generic function, used at `T == Float`.
_ = gradient(at: Float(10), in: generic)
// Minimal generic wrapper around one stored value. Its `Differentiable`
// members are not written out, so they are presumably synthesized by the
// compiler.
struct Wrapper<Scalar : Differentiable> : Differentiable {
var value: Scalar
init(_ value: Scalar) { self.value = value }
}
// Overload of `generic` that projects the stored `value` out of a `Wrapper`.
// `T` carries no explicit constraint here; `Wrapper<T>` itself requires
// `T : Differentiable`.
func generic<T>(_ x: Wrapper<T>) -> T {
return x.value
}
// Differentiate the `Wrapper` overload of `generic` at `Wrapper<Float>`.
_ = gradient(at: Wrapper<Float>(1), in: generic)
// Three-parameter generic function whose result depends only on `x`; the
// `Float` and `U` parameters are unused.
func generic2<T: Differentiable, U: Differentiable>(_ x: T, _ y: Float, _ z: U) -> T {
return x
}
// Takes the gradient of `generic2` with respect to all three arguments from
// inside a generic context, with `generic2`'s `U` bound to `Wrapper<T>`.
func foo<T>(_ x: Wrapper<T>) {
_ = gradient(at: Float(1), 2, x, in: generic2)
}
// Test case where associated derivative function's requirements are met.
// `mean` and `variance` are declared for any `Numeric` scalar, but they are
// `@differentiable` only under the stricter
// `Scalar : Differentiable & FloatingPoint` where-clause.
extension Wrapper where Scalar : Numeric {
@differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
func mean() -> Wrapper {
return self
}
// Calls `mean()`. Its derivative requirements match this function's own
// `@differentiable` where-clause, so differentiating the call is allowed.
@differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
func variance() -> Wrapper {
return mean() // ok
}
}
// Pull back through `variance()` at `Wrapper<Float>`; `Float` satisfies the
// where-clause above.
_ = pullback(at: Wrapper<Float>(1), in: { $0.variance() })
// Tests TF-277.
// FIXME(SR-13933): Temporarily disabled due to VJPCloner ownership verification
// failure.
/*
protocol Layer : Differentiable {
associatedtype Output : Differentiable
}
struct SupervisedTrainer<Model : Layer> {
var model: Model
var lossFunction: @differentiable (Model.Output, Model.Output) -> Float
func fit(y: Model.Output) {
_ = gradient(at: y) { y in return self.lossFunction(y, y) }
}
}
*/
// Tests TF-440.
// Generic input/state pair. Its `Differentiable` members are not written out
// (presumably synthesized).
struct TF_440_Input<Input: Differentiable, State: Differentiable>
: Differentiable {
var input: Input
var state: State
}
// Three `@differentiable` overloads of `applied(to:)`. They differ only in the
// generic arguments of `TF_440_Input` (concrete `Float` or the struct's `T`)
// and in their result type.
struct TF_440<T : Differentiable> {
@differentiable
func applied(to input: TF_440_Input<Float, Float>) -> Float {
return input.state
}
@differentiable
func applied(to input: TF_440_Input<T, Float>) -> Float {
return input.state
}
@differentiable
func applied(to input: TF_440_Input<T, Float>) -> T {
return input.input
}
}
// Tests TF-508: differentiation requirements with dependent member types.
// Protocol with a single associated type. It is used to test
// `@differentiable` where-clauses that mention dependent member types.
protocol TF_508_Proto {
associatedtype Scalar
}
// Operators that are differentiable only under extra requirements on `Self`,
// `Scalar`, and the dependent member type `Self.TangentVector`. Both simply
// return `lhs`; presumably only their requirements matter to this test.
extension TF_508_Proto where Scalar : FloatingPoint {
@differentiable(
where Self : Differentiable, Scalar : Differentiable,
// Conformance requirement with dependent member type.
Self.TangentVector : TF_508_Proto
)
static func +(lhs: Self, rhs: Self) -> Self {
return lhs
}
@differentiable(
where Self : Differentiable, Scalar : Differentiable,
// Same-type requirement with dependent member type.
Self.TangentVector == Float
)
static func -(lhs: Self, rhs: Self) -> Self {
return lhs
}
}
// Custom VJP for `+`, declared under the same requirements as its
// `@differentiable` attribute: conformance of `Self.TangentVector` to
// `TF_508_Proto`. The pullback passes the incoming cotangent `v` to both
// operands.
extension TF_508_Proto where Self : Differentiable,
Scalar : FloatingPoint & Differentiable,
Self.TangentVector : TF_508_Proto {
@derivative(of: +)
static func vjpAdd(lhs: Self, rhs: Self)
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (lhs, { v in (v, v) })
}
}
// Custom VJP for `-`, declared under the same requirements as its
// `@differentiable` attribute: the same-type requirement
// `Self.TangentVector == Float`. The pullback also returns `v` for both
// operands.
// NOTE(review): neither pullback is the true derivative of its original
// (both originals return `lhs`). This is presumably intentional, since the
// test only exercises requirement handling.
extension TF_508_Proto where Self : Differentiable,
Scalar : FloatingPoint & Differentiable,
Self.TangentVector == Float {
@derivative(of: -)
static func vjpSubtract(lhs: Self, rhs: Self)
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (lhs, { v in (v, v) })
}
}
// Empty generic struct. Its generic parameter `Scalar` satisfies
// `TF_508_Proto.Scalar`. The `AdditiveArithmetic` members are not written out
// (presumably synthesized, since the body is empty).
struct TF_508_Struct<Scalar : AdditiveArithmetic>
: TF_508_Proto, AdditiveArithmetic {}
// Conditionally `Differentiable`, with the struct serving as its own
// `TangentVector`. That tangent type therefore also conforms to
// `TF_508_Proto`.
extension TF_508_Struct : Differentiable where Scalar : Differentiable {
typealias TangentVector = TF_508_Struct
}
// Takes pullbacks of `x + x` and `x - x` at a `TF_508_Struct<Float>` value and
// discards them. This file is only compiled to SIL, so only type checking and
// SIL emission are exercised.
func TF_508() {
let x = TF_508_Struct<Float>()
// Test conformance requirement with dependent member type.
_ = pullback(at: x, in: { (x: TF_508_Struct<Float>) -> TF_508_Struct<Float> in
return x + x
})
// Test same-type requirement with dependent member type.
_ = pullback(at: x, in: { (x: TF_508_Struct<Float>) -> TF_508_Struct<Float> in
return x - x
})
}
// TF-523
// Struct with a single `Float` field that declares itself as its own tangent
// vector type.
struct TF_523_Struct : Differentiable & AdditiveArithmetic {
var a: Float = 1
typealias TangentVector = TF_523_Struct
}
// Differentiable function that reads a stored property of that struct and
// scales it by 2.
@differentiable
func TF_523_f(_ x: TF_523_Struct) -> Float {
return x.a * 2
}
// TF-534: Thunk substitution map remapping.
// Layer protocol whose `callAsFunction` requirement is `@differentiable`.
protocol TF_534_Layer : Differentiable {
associatedtype Input : Differentiable
associatedtype Output : Differentiable
@differentiable
func callAsFunction(_ input: Input) -> Output
}
// Empty generic tensor stand-in with no stored properties.
struct TF_534_Tensor<Scalar> : Differentiable {}
// Differentiates with respect to `model` through a closure that invokes it
// via `callAsFunction`. The where-clause fixes `Model.Output` to
// `TF_534_Tensor<Float>`.
// - Returns: only the value (`.0`) of `valueWithPullback`.
// `model` is declared `inout` but is not mutated here, and the closure's
// `model` parameter shadows the function parameter.
func TF_534<Model: TF_534_Layer>(
_ model: inout Model, inputs: Model.Input
) -> TF_534_Tensor<Float> where Model.Output == TF_534_Tensor<Float> {
return valueWithPullback(at: model) { model -> Model.Output in
return model(inputs)
}.0
}
// TF-546: Test that SILGen linear map thunk performs correct reabstraction.
// Complex-number-like struct with `real` and `imaginary` parts. Its
// initializer is `@differentiable` only when `T` is `Differentiable` and is
// its own tangent vector.
struct TF_546<T: FloatingPoint>: AdditiveArithmetic {
var real: T
var imaginary: T
@differentiable(where T: Differentiable, T == T.TangentVector)
init(real: T = 0, imaginary: T = 0) {
self.real = real
self.imaginary = imaginary
}
}
// Conditional conformance in which the struct is its own tangent vector.
extension TF_546: Differentiable where T: Differentiable {
typealias TangentVector = TF_546
}
// Custom VJP for `init(real:imaginary:)`. The pullback splits the incoming
// cotangent into its `real` and `imaginary` components.
extension TF_546 where T: Differentiable, T == T.TangentVector {
@derivative(of: init)
static func _vjpInit(real: T, imaginary: T) -> (value: TF_546, pullback: (TF_546) -> (T, T)) {
return (TF_546(real: real, imaginary: imaginary), { ($0.real, $0.imaginary) })
}
}
// A closure calling the initializer is converted to a `@differentiable`
// function type, presumably where the SILGen linear map thunk is exercised.
let _: @differentiable(Float, Float) -> TF_546<Float> = { r, i in
TF_546(real: r, imaginary: i)
}
// TF-652: Test VJPEmitter substitution map generic signature.
// The substitution map should have the VJP's generic signature, not the
// original function's.
// Generic struct that is `Differentiable` only when `Scalar : FloatingPoint`.
// The extension body is empty, so its members are presumably synthesized.
struct TF_652<Scalar> {}
extension TF_652 : Differentiable where Scalar : FloatingPoint {}
// Differentiable under a stricter where-clause (`Scalar: FloatingPoint`) than
// the function's own `Scalar: Numeric` constraint. The loop adds control flow
// without changing `x`, which is returned as-is.
@differentiable(wrt: x where Scalar: FloatingPoint)
func test<Scalar: Numeric>(x: TF_652<Scalar>) -> TF_652<Scalar> {
for _ in 0..<10 {
let _ = x
}
return x
}
// TF-682: Test that SILGen linear map thunk performs correct reabstraction.
// Protocol with a single associated type.
protocol TF_682_Proto {
associatedtype Scalar
}
// `foo` is `@differentiable` only under a same-type requirement on the
// dependent member type `Self.TangentVector`; it simply returns `lhs`.
extension TF_682_Proto where Scalar : FloatingPoint {
@differentiable(
where Self : Differentiable, Scalar : Differentiable,
// Same-type requirement with dependent member type.
Self.TangentVector == Float
)
func foo(lhs: Self) -> Self {
return lhs
}
}
// Custom VJP for `foo`, declared under the matching requirements. The
// pullback returns `(v, v)`, presumably one cotangent for `self` and one for
// `lhs`.
extension TF_682_Proto where Self : Differentiable,
Scalar : FloatingPoint & Differentiable,
Self.TangentVector == Float {
@derivative(of: foo)
func vjpFoo(lhs: Self) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (lhs, { v in (v, v) })
}
}
// NOTE(TF-1208): Differentiation regression due to changes in curry thunk generation.
/*
// TF-688: Test generic curry thunk cloning.
public struct TF_688_Struct<Scalar> {
var x: Scalar
}
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
@differentiable
public static func id(x: Self) -> Self {
return x
}
}
@differentiable(wrt: x)
public func TF_688<Scalar: Differentiable>(
_ x: TF_688_Struct<Scalar>,
reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
) -> TF_688_Struct<Scalar> {
reduction(x)
}
*/
// TF-697: Test generic requirements of generated derivative function.
// Module protocol. `Input` is unconstrained here, and `callModule` is
// `@differentiable` with respect to `self` only.
protocol TF_697_Module: Differentiable {
associatedtype Input
associatedtype Output: Differentiable
@differentiable(wrt: self)
func callModule(_ input: Input) -> Output
}
// Layer protocol. It refines the module by requiring
// `Input: Differentiable` and adds `callLayer`, which is `@differentiable`
// with no explicit `wrt:` list.
protocol TF_697_Layer: TF_697_Module where Input: Differentiable {
@differentiable
func callLayer(_ input: Input) -> Output
}
// Composes two modules by feeding `layer1`'s output into
// `layer2.callLayer`.
struct TF_697_Sequential<Layer1: TF_697_Module, Layer2: TF_697_Layer>: TF_697_Module
where Layer1.Output == Layer2.Input {
var layer1: Layer1
var layer2: Layer2
@differentiable(wrt: self)
func callModule(_ input: Layer1.Input) -> Layer2.Output {
layer2.callLayer(layer1.callModule(input))
}
}
// Conditional conformance: when `Layer1` is also a layer, the sequence is a
// layer whose `callLayer` chains both layers' `callLayer`.
extension TF_697_Sequential: TF_697_Layer where Layer1: TF_697_Layer {
@differentiable
func callLayer(_ input: Layer1.Input) -> Layer2.Output {
layer2.callLayer(layer1.callLayer(input))
}
}
// TF-817: Test remapping `apply` callee types in derivative function context.
// Generic struct whose `foo` body only calls `fatalError()`. This file is
// compiled to SIL and never run.
struct TF_817<T> {
func foo(_ index: Int) -> T {
fatalError()
}
}
// Conditional conformance that also registers a custom VJP for `foo`. The
// pullback yields a cotangent for `self` only, since the `Int` index is not
// `Differentiable`.
extension TF_817: Differentiable where T: Differentiable {
@derivative(of: foo)
func vjpFoo(index: Int) -> (value: T, pullback: (T.TangentVector) -> (TangentVector)) {
fatalError()
}
}
// `test` is differentiable only when `T: Differentiable`. Calling `foo` from
// it used to crash the compiler while remapping the callee type.
extension TF_817 {
@differentiable(wrt: self where T: Differentiable)
public func test(index: Int) -> T {
return self.foo(0) // crash happened here
}
}
// TF-886: Test `partial_apply` of linear map subset parameters thunk.
// Generic function taking a `Float`, an unconstrained `T`, and a
// `U: Differentiable`. It always returns 0.
@differentiable
func TF_886_foo<T, U: Differentiable>(_: Float, _: T, _: U) -> Float {
return 0
}
// Calls `TF_886_foo` with the literal `0` as its `U` argument, so only part of
// `TF_886_foo`'s differentiable inputs come from `TF_886_bar`'s parameters.
// This presumably requires a thunk over a parameter subset of its linear map.
@differentiable
func TF_886_bar<T>(x: Float, y: T) -> Float {
return TF_886_foo(x, y, 0)
}
// Test layout requirements.
// The layout requirement is "contextual": the requirement is not on `T`, the
// differentiable function parameter/result type.
// `U : AnyObject` is a layout requirement on a generic parameter that does not
// appear in the differentiable function types formed below.
struct ContextualLayoutRequirement<T: Differentiable, U: AnyObject> {
var stored: T
}
// Forms two `@differentiable (T) -> T` closures in this generic context. One
// ignores its argument and returns `self.stored`; the other is the identity.
extension ContextualLayoutRequirement {
func test(_ x: T) {
let _: @differentiable (T) -> T = { _ in self.stored }
let _: @differentiable (T) -> T = { $0 }
}
}
// The layout requirement directly involves `T`, the differentiable function
// parameter/result type.
// TODO(TF-851): Uncomment the tests below after `@differentiable` function
// SILGen thunking is fixed.
/*
struct LayoutRequirement<T: AnyObject & Differentiable> {
var stored: T
}
extension LayoutRequirement {
func test(_ x: T) {
let _: @differentiable (T) -> T = { _ in self.stored }
let _: @differentiable (T) -> T = { $0 }
}
}
*/
// Test superclass requirements.
// Class used as a superclass constraint. It declares `Differentiable` with an
// empty body.
class Super: Differentiable {}
// The superclass requirement is "contextual": the requirement is not on `T`,
// the differentiable function parameter/result type.
struct ContextualSuperclassRequirement<T: Differentiable, U: Super> {
var stored: T
}
// Forms two `@differentiable (T) -> T` closures in this generic context. One
// ignores its argument and returns `self.stored`; the other is the identity.
extension ContextualSuperclassRequirement {
func test(_ x: T) {
let _: @differentiable (T) -> T = { _ in self.stored }
let _: @differentiable (T) -> T = { $0 }
}
}
// The superclass requirement directly involves `T`, the differentiable
// function parameter/result type.
// TODO(TF-851): Uncomment the tests below after `@differentiable` function
// SILGen thunking is fixed.
/*
struct SuperclassRequirement<T: Super & Differentiable> {
var stored: T
}
extension SuperclassRequirement {
func test(_ x: T) {
let _: @differentiable (T) -> T = { _ in self.stored }
let _: @differentiable (T) -> T = { $0 }
}
}
*/
// TODO: add more tests.