// blob: 97be01190465357368ca6638342e26968b240374 [file] [log] [blame]
// SWIFT_ENABLE_TENSORFLOW
// RUN: %empty-directory(%t)
// RUN: %target-swift-frontend %s -emit-module -parse-as-library -o %t
// RUN: llvm-bcanalyzer %t/differentiable_attr.swiftmodule | %FileCheck %s -check-prefix=BCANALYZER
// RUN: %target-sil-opt -enable-sil-verify-all %t/differentiable_attr.swiftmodule -o - | %FileCheck %s
// BCANALYZER-NOT: UnknownCode
// Bare `@differentiable` on a one-parameter function. The expectation below
// requires the attribute read back from the emitted module to spell out the
// parameter explicitly as `wrt: x`.
// CHECK: @differentiable(wrt: x)
// CHECK-NEXT: func simple(x: Float) -> Float
@differentiable
func simple(x: Float) -> Float {
return x
}
// `@differentiable(linear)` with no `wrt:` clause. The expected output keeps
// the `linear` marker and additionally lists the parameter as `wrt: x`.
// CHECK: @differentiable(linear, wrt: x)
// CHECK-NEXT: func simple2(x: Float) -> Float
@differentiable(linear)
func simple2(x: Float) -> Float {
return x
}
// `linear` combined with an explicit `wrt: x`. The expected output is the
// attribute exactly as written in the source.
// CHECK: @differentiable(linear, wrt: x)
// CHECK-NEXT: func simple4(x: Float) -> Float
@differentiable(linear, wrt: x)
func simple4(x: Float) -> Float {
return x
}
// Returns `x` unchanged, paired with a closure that hands back its argument.
// No attribute in this file refers to this helper.
// NOTE(review): the name suggests it was once a `jvp:` derivative for
// `simple` - confirm before removing.
func jvpSimple(x: Float) -> (Float, (Float) -> Float) {
  let identity: (Float) -> Float = { $0 }
  return (x, identity)
}
// Returns `x` unchanged, paired with a closure that hands back its argument.
// No attribute in this file refers to this helper.
// NOTE(review): the name suggests it was once a `vjp:` derivative for
// `simple` - confirm before removing.
func vjpSimple(x: Float) -> (Float, (Float) -> Float) {
  let passThrough: (Float) -> Float = { input in
    return input
  }
  return (x, passThrough)
}
// Two-parameter function differentiable with respect to `x` only. The
// expected output lists just `x` in the `wrt:` clause, not `y`.
// CHECK: @differentiable(wrt: x)
// CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float
@differentiable(wrt: x)
func testWrtClause(x: Float, y: Float) -> Float {
return x + y
}
// Differentiable struct whose only member is an instance method, used to
// exercise a `wrt:` clause that names `self`.
struct InstanceMethod : Differentiable {
// The `wrt:` list is the tuple `(self, y)`; parameter `x` is left out. The
// expected output reproduces that tuple form unchanged.
// CHECK: @differentiable(wrt: (self, y))
// CHECK-NEXT: func testWrtClause(x: Float, y: Float) -> Float
@differentiable(wrt: (self, y))
func testWrtClause(x: Float, y: Float) -> Float {
return x + y
}
}
// Attribute carrying only a `where` clause on a generic function. The
// expected output places `wrt: x` in front of the where clause, and shows the
// inline `T : Numeric` constraint as a trailing `where` clause on the
// function itself.
// CHECK: @differentiable(wrt: x where T : Differentiable)
// CHECK-NEXT: func testOnlyWhereClause<T>(x: T) -> T where T : Numeric
@differentiable(where T : Differentiable)
func testOnlyWhereClause<T : Numeric>(x: T) -> T {
return x
}
// Generic function with a where-clause-only attribute; the expected output
// adds `wrt: x` ahead of the where clause.
// NOTE(review): apart from its name this case is identical to
// `testOnlyWhereClause` - presumably the two once differed; confirm whether
// both are still needed.
// CHECK: @differentiable(wrt: x where T : Differentiable)
// CHECK-NEXT: func testWhereClause<T>(x: T) -> T where T : Numeric
@differentiable(where T : Differentiable)
func testWhereClause<T : Numeric>(x: T) -> T {
return x
}
// Requirement-free protocol; the extensions below attach `@differentiable`
// methods constrained on `Self` to it.
protocol P {}
// Protocol-extension method whose attribute combines `wrt: self` with a
// `Self : Differentiable` requirement.
extension P {
// The attribute is already fully explicit, so the expected output is the
// same text as the source attribute.
// CHECK: @differentiable(wrt: self where Self : Differentiable)
// CHECK-NEXT: func testWhereClauseMethod() -> Self
@differentiable(wrt: self where Self : Differentiable)
func testWhereClauseMethod() -> Self {
return self
}
}
// Where clause with both a conformance requirement and a same-type
// requirement (`T == T.TangentVector`). The expected output adds `wrt: x` and
// keeps the two requirements in the order they are written.
// CHECK: @differentiable(wrt: x where T : Differentiable, T == T.TangentVector)
// CHECK-NEXT: func testWhereClauseMethodTypeConstraint<T>(x: T) -> T where T : Numeric
@differentiable(where T : Differentiable, T == T.TangentVector)
func testWhereClauseMethodTypeConstraint<T : Numeric>(x: T) -> T {
return x
}
// Protocol-extension counterpart of the same-type-constraint case above.
extension P {
// The source lists the same-type requirement first, written as
// `Self.TangentVector == Self`. The expected output lists the conformance
// requirement first and spells the same-type requirement as
// `Self == Self.TangentVector`, so the printed form intentionally differs
// from the source order.
// CHECK: @differentiable(wrt: self where Self : Differentiable, Self == Self.TangentVector)
// CHECK-NEXT: func testWhereClauseMethodTypeConstraint() -> Self
@differentiable(wrt: self where Self.TangentVector == Self, Self : Differentiable)
func testWhereClauseMethodTypeConstraint() -> Self {
return self
}
}
// Uses `@differentiable` as a function-type attribute on a parameter rather
// than as a declaration attribute. The body is empty; only the printed
// signature is matched.
// CHECK: func testDifferentiableParam(f: @differentiable (Float) -> Float)
func testDifferentiableParam(f: @differentiable (Float) -> Float) {}