blob: a9c15adff158b4fb8ab8d8407143e02c13118091 [file] [log] [blame]
// RUN: %target-swift-frontend -typecheck -verify %s
// Test top-level functions.
// expected-note @+1 {{'sin' defined here}}
func sin(_ x: Float) -> Float {
return x // dummy implementation
}
@differentiating(sin) // ok
func jvpSin(x: @nondiff Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
@differentiating(sin, wrt: x) // ok
func vjpSinExplicitWrt(x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x, { $0 })
}
// expected-error @+1 {{a derivative already exists for 'sin'}}
@differentiating(sin)
func vjpDuplicate(x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x, { $0 })
}
// expected-error @+1 {{'@differentiating' attribute requires function to return a two-element tuple of type '(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' or '(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'}}
@differentiating(sin)
func jvpSinResultInvalid(x: @nondiff Float) -> Float {
return x
}
// expected-error @+1 {{'@differentiating' attribute requires function to return a two-element tuple (second element must have label 'pullback:' or 'differential:')}}
@differentiating(sin)
func vjpSinResultWrongLabel(x: Float) -> (value: Float, (Float) -> Float) {
return (x, { $0 })
}
// expected-error @+1 {{'@differentiating' attribute requires function to return a two-element tuple (first element type 'Int' must conform to 'Differentiable')}}
@differentiating(sin)
func vjpSinResultNotDifferentiable(x: Int) -> (value: Int, pullback: (Int) -> Int) {
return (x, { $0 })
}
// expected-error @+2 {{function result's 'pullback' type does not match 'sin'}}
// expected-note @+2 {{'pullback' does not have expected type '(Float.TangentVector) -> (Float.TangentVector)' (aka '(Float) -> Float')}}
@differentiating(sin)
func vjpSinResultInvalidSeedType(x: Float) -> (value: Float, pullback: (Double) -> Double) {
return (x, { $0 })
}
func generic<T : Differentiable>(_ x: T, _ y: T) -> T {
return x
}
@differentiating(generic) // ok
func jvpGeneric<T : Differentiable>(x: T, y: T) -> (value: T, differential: (T.TangentVector, T.TangentVector) -> T.TangentVector) {
return (x, { $0 + $1 })
}
// expected-error @+1 {{'@differentiating' attribute requires function to return a two-element tuple (second element must have label 'pullback:' or 'differential:')}}
@differentiating(generic)
func vjpGenericWrongLabel<T : Differentiable>(x: T, y: T) -> (value: T, (T) -> (T, T)) {
return (x, { ($0, $0) })
}
// expected-error @+1 {{could not find function 'generic' with expected type '<T where T : _Differentiable, T == T.TangentVector> (x: T) -> T'}}
@differentiating(generic)
func vjpGenericDiffParamMismatch<T : Differentiable>(x: T) -> (value: T, pullback: (T) -> (T, T)) where T == T.TangentVector {
return (x, { ($0, $0) })
}
@differentiating(generic) // ok
func vjpGenericExtraGenericRequirements<T : Differentiable & FloatingPoint>(x: T, y: T) -> (value: T, pullback: (T) -> (T, T)) where T == T.TangentVector {
return (x, { ($0, $0) })
}
func foo<T : FloatingPoint & Differentiable>(_ x: T) -> T { return x }
// Test `wrt` clauses.
func add(x: Float, y: Float) -> Float {
return x + y
}
@differentiating(add, wrt: x) // ok
func vjpAddWrtX(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float)) {
return (x + y, { $0 })
}
@differentiating(add, wrt: (x, y)) // ok
func vjpAddWrtXY(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
return (x + y, { ($0, $0) })
}
// expected-error @+1 {{unknown parameter name 'z'}}
@differentiating(add, wrt: z)
func vjpUnknownParam(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float)) {
return (x + y, { $0 })
}
// expected-error @+1 {{parameters must be specified in original order}}
@differentiating(add, wrt: (y, x))
func vjpParamOrderNotIncreasing(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
return (x + y, { ($0, $0) })
}
// expected-error @+1 {{'self' parameter is only applicable to instance methods}}
@differentiating(add, wrt: self)
func vjpInvalidSelfParam(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
return (x + y, { ($0, $0) })
}
func noParams() -> Float {
return 1
}
// expected-error @+1 {{'vjpNoParams()' has no parameters to differentiate with respect to}}
@differentiating(noParams)
func vjpNoParams() -> (value: Float, pullback: (Float) -> Float) {
return (1, { $0 })
}
func noDiffParams(x: Int) -> Float {
return 1
}
// expected-error @+1 {{no differentiation parameters could be inferred; must differentiate with respect to at least one parameter conforming to 'Differentiable'}}
@differentiating(noDiffParams)
func vjpNoDiffParams(x: Int) -> (value: Float, pullback: (Float) -> Int) {
return (1, { _ in 0 })
}
// expected-error @+1 {{functions ('@differentiable (Float) -> Float') cannot be differentiated with respect to}}
@differentiable(wrt: fn)
func invalidDiffWrtFunction(_ fn: @differentiable(Float) -> Float) -> Float {
return fn(.pi)
}
// expected-error @+2 {{type 'T' does not conform to protocol 'FloatingPoint'}}
// expected-error @+1 {{could not find function 'foo' with expected type '<T where T : AdditiveArithmetic, T : _Differentiable> (T) -> T'}}
@differentiating(foo)
func vjpFoo<T : AdditiveArithmetic & Differentiable>(_ x: T) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector)) {
return (x, { $0 })
}
@differentiating(foo)
func vjpFooExtraGenericRequirements<T : FloatingPoint & Differentiable & BinaryInteger>(_ x: T) -> (value: T, pullback: (T) -> (T)) where T == T.TangentVector {
return (x, { $0 })
}
// Test static methods.
extension AdditiveArithmetic where Self : Differentiable {
// expected-error @+1 {{derivative not in the same file as the original function}}
@differentiating(+)
static func vjpPlus(x: Self, y: Self) -> (value: Self, pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)) {
return (x + y, { v in (v, v) })
}
}
extension FloatingPoint where Self : Differentiable, Self == Self.TangentVector {
// expected-error @+1 {{derivative not in the same file as the original function}}
@differentiating(+)
static func vjpPlus(x: Self, y: Self) -> (value: Self, pullback: (Self) -> (Self, Self)) {
return (x + y, { v in (v, v) })
}
}
extension Differentiable where Self : AdditiveArithmetic {
// expected-error @+1 {{'+' is not defined in the current type context}}
@differentiating(+)
static func vjpPlus(x: Self, y: Self) -> (value: Self, pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)) {
return (x + y, { v in (v, v) })
}
}
extension AdditiveArithmetic where Self : Differentiable, Self == Self.TangentVector {
// expected-error @+1 {{could not find function '+' with expected type '<Self where Self : _Differentiable, Self == Self.TangentVector> (Self) -> (Self, Self) -> Self'}}
@differentiating(+)
func vjpPlusInstanceMethod(x: Self, y: Self) -> (value: Self, pullback: (Self) -> (Self, Self)) {
return (x + y, { v in (v, v) })
}
}
// Test instance methods.
protocol InstanceMethod : Differentiable {
// expected-note @+1 {{'foo' defined here}}
func foo(_ x: Self) -> Self
func foo2(_ x: Self) -> Self
// expected-note @+1 {{'bar' defined here}}
func bar<T : Differentiable>(_ x: T) -> Self
func bar2<T : Differentiable>(_ x: T) -> Self
}
extension InstanceMethod {
// If `Self` conforms to `Differentiable`, then `Self` is currently always inferred to be a differentiation parameter.
// expected-error @+2 {{function result's 'pullback' type does not match 'foo'}}
// expected-note @+2 {{'pullback' does not have expected type '(Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)'}}
@differentiating(foo)
func vjpFoo(x: Self) -> (value: Self, pullback: (TangentVector) -> TangentVector) {
return (x, { $0 })
}
@differentiating(foo)
func jvpFoo(x: Self) -> (value: Self, differential: (TangentVector, TangentVector) -> (TangentVector)) {
return (x, { $0 + $1 })
}
@differentiating(foo, wrt: (self, x))
func vjpFooWrt(x: Self) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (x, { ($0, $0) })
}
}
extension InstanceMethod {
// expected-error @+2 {{function result's 'pullback' type does not match 'bar'}}
// expected-note @+2 {{'pullback' does not have expected type '(Self.TangentVector) -> (Self.TangentVector, T.TangentVector)'}}
@differentiating(bar)
func vjpBar<T : Differentiable>(_ x: T) -> (value: Self, pullback: (TangentVector) -> T.TangentVector) {
return (self, { _ in .zero })
}
@differentiating(bar)
func vjpBar<T : Differentiable>(_ x: T) -> (value: Self, pullback: (TangentVector) -> (TangentVector, T.TangentVector)) {
return (self, { ($0, .zero) })
}
@differentiating(bar, wrt: (self, x))
func jvpBarWrt<T : Differentiable>(_ x: T) -> (value: Self, differential: (TangentVector, T.TangentVector) -> TangentVector) {
return (self, { dself, dx in dself })
}
}
extension InstanceMethod where Self == Self.TangentVector {
@differentiating(foo2)
func vjpFooExtraRequirements(x: Self) -> (value: Self, pullback: (Self) -> (Self, Self)) {
return (x, { ($0, $0) })
}
@differentiating(foo2)
func jvpFooExtraRequirements(x: Self) -> (value: Self, differential: (Self, Self) -> (Self)) {
return (x, { $0 + $1 })
}
@differentiating(bar2)
func vjpBarExtraRequirements<T : Differentiable>(x: T) -> (value: Self, pullback: (Self) -> (Self, T.TangentVector)) {
return (self, { ($0, .zero) })
}
@differentiating(bar2)
func jvpBarExtraRequirements<T : Differentiable>(_ x: T) -> (value: Self, differential: (Self, T.TangentVector) -> Self) {
return (self, { dself, dx in dself })
}
}
protocol GenericInstanceMethod : Differentiable where Self == Self.TangentVector {
func instanceMethod<T : Differentiable>(_ x: T) -> T
}
extension GenericInstanceMethod {
func jvpInstanceMethod<T : Differentiable>(_ x: T) -> (T, (T.TangentVector) -> (TangentVector, T.TangentVector)) {
return (x, { v in (self, v) })
}
func vjpInstanceMethod<T : Differentiable>(_ x: T) -> (T, (T.TangentVector) -> (TangentVector, T.TangentVector)) {
return (x, { v in (self, v) })
}
}
// Test extra generic constraints.
func bar<T>(_ x: T) -> T {
return x
}
@differentiating(bar)
func vjpBar<T : Differentiable>(_ x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector) {
return (x, { $0 })
}
func baz<T, U>(_ x: T, _ y: U) -> T {
return x
}
@differentiating(baz)
func vjpBaz<T : Differentiable, U : Differentiable>(_ x: T, _ y: U)
-> (value: T, pullback: (T) -> (T, U))
where T == T.TangentVector, U == U.TangentVector
{
return (x, { ($0, .zero) })
}
protocol InstanceMethodProto {
func bar() -> Float
}
extension InstanceMethodProto where Self : Differentiable {
@differentiating(bar)
func vjpBar() -> (value: Float, pullback: (Float) -> TangentVector) {
return (bar(), { _ in .zero })
}
}
// Test consistent usages of `@differentiable` and `@differentiating` where
// derivative functions are specified in both attributes.
@differentiable(jvp: jvpConsistent, vjp: vjpConsistent)
func consistentSpecifiedDerivatives(_ x: Float) -> Float {
return x
}
@differentiating(consistentSpecifiedDerivatives)
func jvpConsistent(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
return (x, { $0 })
}
@differentiating(consistentSpecifiedDerivatives(_:))
func vjpConsistent(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x, { $0 })
}
// Test usage of `@differentiable` on a stored property
struct PropertyDiff : Differentiable & AdditiveArithmetic {
// expected-error @+1 {{'@differentiable' attribute on stored property cannot specify 'jvp:' or 'vjp:'}}
@differentiable(vjp: vjpPropertyA)
var a: Float = 1
typealias TangentVector = PropertyDiff
typealias AllDifferentiableVariables = PropertyDiff
func vjpPropertyA() -> (Float, (Float) -> PropertyDiff) {
(.zero, { _ in .zero })
}
}
@differentiable
func f(_ x: PropertyDiff) -> Float {
return x.a
}
let a = gradient(at: PropertyDiff(), in: f)
print(a)
// Index based 'wrt:'
func add2(x: Float, y: Float) -> Float {
return x + y
}
@differentiating(add2, wrt: (0, y)) // ok
func two3(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
return (x + y, { ($0, $0) })
}
@differentiating(add2, wrt: (1)) // ok
func two4(x: Float, y: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x + y, { $0 })
}
@differentiating(add2, wrt: 2) // expected-error {{parameter index is larger than total number of parameters}}
func two5(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
return (x + y, { ($0, $0) })
}
@differentiating(add2, wrt: (1, x)) // expected-error {{parameters must be specified in original order}}
func two6(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
return (x + y, { ($0, $0) })
}
@differentiating(add2, wrt: (1, 0)) // expected-error {{parameters must be specified in original order}}
func two7(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) {
return (x + y, { ($0, $0) })
}
// Test class methods.
class Super {
@differentiable
func foo(_ x: Float) -> Float {
return x
}
@differentiating(foo)
func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (foo(x), { v in v })
}
}
class Sub : Super {
// TODO(TF-649): Enable `@differentiating` to override original functions from superclass.
// expected-error @+1 {{'foo' is not defined in the current type context}}
@differentiating(foo)
override func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (foo(x), { v in v })
}
}