blob: d617609ded57e4282330e60b8b0cfedc4efd3ddb [file] [log] [blame]
// RUN: %target-swift-frontend -emit-silgen -enable-testing -verify %s | %FileCheck %s -check-prefix=CHECK-AST
// RUN: %target-swift-frontend -emit-silgen -enable-testing -verify %s | %FileCheck %s -check-prefix=CHECK-SIL
//===----------------------------------------------------------------------===//
// Normal types
//===----------------------------------------------------------------------===//
// Primal function with a user-provided VJP (`dfoo`), differentiable with
// respect to both parameters. `@_silgen_name` pins the SIL symbol to `@foo`
// so the emitted function can be matched by name. The constant body is a
// placeholder: this test only inspects the `[differentiable ...]` attribute
// SILGen attaches to `@foo`. Note that the expected SIL names a generated
// `AD__foo__vjp_...` function as the VJP rather than `@dfoo` directly.
@_silgen_name("foo")
@differentiable(vjp: dfoo)
public func foo(_ x: Float, _ y: Float) -> Float {
return 1
}
// CHECK-SIL-LABEL: sil [differentiable source 0 wrt 0, 1 vjp @AD__foo__vjp_src_0_wrt_0_1] [ossa] @foo
// VJP registered for `foo`: returns the original result together with a
// pullback that maps an output cotangent to cotangents for `(x, y)`.
// The pullback's constant `(1, 1)` is a placeholder, not the true derivative
// of `foo` (whose body returns a constant); only SIL attributes are checked.
@_silgen_name("dfoo")
public func dfoo(_ x: Float, _ y: Float) -> (Float, (Float) -> (Float, Float)) {
  let originalResult = foo(x, y)
  let pullback: (Float) -> (Float, Float) = { _ in (1, 1) }
  return (originalResult, pullback)
}
// CHECK-SIL-LABEL: sil [ossa] @dfoo
//===----------------------------------------------------------------------===//
// Indirect returns
//===----------------------------------------------------------------------===//
// Generic primal with a user-provided VJP (`dfoo_indir_ret`). Because the
// result type is the generic `T`, SIL passes `y` as `@in_guaranteed T` and
// returns the result indirectly through an `@out T` argument (the first
// `bb0` argument below). This exercises the `[differentiable ...]` SIL
// attribute on a function with an indirect result.
@_silgen_name("foo_indir_ret")
@differentiable(vjp: dfoo_indir_ret)
public func foo_indir_ret<T: Differentiable>(_ x: Float, _ y: T) -> T {
return y
}
// CHECK-SIL-LABEL: sil [differentiable source 0 wrt 0, 1 vjp @AD__foo_indir_ret__vjp_src_0_wrt_0_1] [ossa] @foo_indir_ret : $@convention(thin) <T where T : _Differentiable> (Float, @in_guaranteed T) -> @out T {
// CHECK-SIL: bb0(%0 : $*T, %1 : $Float, %2 : $*T):
// VJP registered for `foo_indir_ret`. The pullback maps a cotangent `seed`
// for the result to `(x, seed)`. The `x` component is a placeholder rather
// than a real derivative: the primal ignores `x`. Only SIL attributes are
// checked by this test.
@_silgen_name("dfoo_indir_ret")
public func dfoo_indir_ret<T: Differentiable>(_ x: Float, _ y: T) -> (T, (T.TangentVector) -> (Float, T.TangentVector)) {
  let pullback: (T.TangentVector) -> (Float, T.TangentVector) = { seed in (x, seed) }
  return (y, pullback)
}
//===----------------------------------------------------------------------===//
// JVP
//===----------------------------------------------------------------------===//
// Primal with a user-provided JVP (`dhasjvp`) instead of a VJP, differentiable
// with respect to both parameters. SILGen is expected to record the derivative
// under `jvp` (as a generated `AD__hasjvp__jvp_...` function) in the
// `[differentiable ...]` SIL attribute. The constant body is a placeholder.
@_silgen_name("hasjvp")
@differentiable(jvp: dhasjvp)
public func hasjvp(_ x: Float, _ y: Float) -> Float {
return 1
}
// CHECK-SIL-LABEL: sil [differentiable source 0 wrt 0, 1 jvp @AD__hasjvp__jvp_src_0_wrt_0_1] [ossa] @hasjvp
// JVP registered for `hasjvp`: returns the original result (the constant 1,
// matching `hasjvp`'s body) paired with a differential that maps input
// tangents `(dx, dy)` to an output tangent. The differential's constant
// result is a placeholder; only SIL attributes are checked.
@_silgen_name("dhasjvp")
public func dhasjvp(_ x: Float, _ y: Float) -> (Float, (Float, Float) -> Float) {
  let differential: (Float, Float) -> Float = { _, _ in 1 }
  return (1, differential)
}
// CHECK-SIL-LABEL: sil [ossa] @dhasjvp
//===----------------------------------------------------------------------===//
// VJP
//===----------------------------------------------------------------------===//
// `@inlinable` primal with a user-provided VJP (`dhasvjp`). Because the body
// is inlinable, its SIL function is emitted `[serialized]`. The test checks
// that the `[differentiable ...]` attribute is still attached next to it.
// The constant body is a placeholder.
@inlinable
@_silgen_name("hasvjp")
@differentiable(vjp: dhasvjp)
public func hasvjp(_ x: Float, _ y: Float) -> Float {
return 1
}
// CHECK-SIL-LABEL: sil [serialized] [differentiable source 0 wrt 0, 1 vjp @AD__hasvjp__vjp_src_0_wrt_0_1] [ossa] @hasvjp
// VJP registered for `hasvjp`: returns the original result (the constant 1,
// matching `hasvjp`'s body) with a pullback that maps an output cotangent to
// cotangents for `(x, y)`. The pullback's constant `(1, 1)` is a placeholder
// rather than the true derivative; only SIL attributes are checked.
@_silgen_name("dhasvjp")
public func dhasvjp(_ x: Float, _ y: Float) -> (Float, (Float) -> (Float, Float)) {
  let pullback: (Float) -> (Float, Float) = { _ in (1, 1) }
  return (1, pullback)
}
// CHECK-SIL-LABEL: sil [ossa] @dhasvjp
//===----------------------------------------------------------------------===//
// Computed property
//===----------------------------------------------------------------------===//
// Struct with no stored properties whose computed property `computedProp` is
// marked `@differentiable` with respect to `self`, using custom JVP and VJP
// methods defined below.
struct DiffComputedProp : Differentiable & AdditiveArithmetic {
// The attribute is written on the property declaration. The test that
// follows this struct verifies that it ends up on the generated getter in SIL.
@differentiable(wrt: (self), jvp: computedPropJVP, vjp: computedPropVJP)
var computedProp: Float {
return 0
}
// JVP for `computedProp`: returns the value and a differential taking a
// tangent of `self`. NOTE(review): the signature assumes the tangent type of
// `DiffComputedProp` is `DiffComputedProp` itself (presumably via its
// `AdditiveArithmetic` conformance) — confirm. The body is never executed;
// the test only runs SILGen.
@_silgen_name("computedPropJVP")
func computedPropJVP() -> (Float, (DiffComputedProp) -> Float) {
fatalError("unimplemented")
}
// VJP for `computedProp`: returns the value and a pullback producing a
// cotangent for `self` (same tangent-type assumption as above). The body is
// never executed; the test only runs SILGen.
@_silgen_name("computedPropVJP")
func computedPropVJP() -> (Float, (Float) -> DiffComputedProp) {
fatalError("unimplemented")
}
}
// Check that the `@differentiable` attribute is transferred from the computed
// property's storage declaration to its getter accessor.
// CHECK-AST: struct DiffComputedProp : _Differentiable & AdditiveArithmetic {
// CHECK-AST-NEXT: var computedProp: Float { get }
// CHECK-AST: }
// CHECK-SIL-LABEL: DiffComputedProp.computedProp.getter
// CHECK-SIL-NEXT: [differentiable source 0 wrt 0 jvp @AD__$s26differentiable_attr_silgen16DiffComputedPropV08computedF0Sfvg__jvp_src_0_wrt_0 vjp @AD__$s26differentiable_attr_silgen16DiffComputedPropV08computedF0Sfvg__vjp_src_0_wrt_0]
// Public struct with a `@differentiable` stored property that has an initial
// value. The test checks SILGen's output for the property's initial-value
// expression, which is expected to be a `[transparent]` SIL function.
public struct MyLayer: Differentiable {
@differentiable
var x: Float = 10
}
// CHECK-SIL-LABEL: initialization expression of MyLayer.x
// CHECK-SIL-NEXT: sil [transparent] [ossa] @$s26differentiable_attr_silgen7MyLayerV1xSfvpfi : $@convention(thin) () -> Float