| // RUN: %target-swift-frontend-typecheck -verify -disable-availability-checking %s |
| // RUN: %target-swift-frontend-typecheck -enable-testing -verify -disable-availability-checking %s |
| |
| // Swift.AdditiveArithmetic:3:17: note: cannot yet register derivative default implementation for protocol requirements |
| |
| import _Differentiation |
| |
| // Dummy `Differentiable`-conforming type that serves as its own `TangentVector`. |
| // `InoutParameters` (below) uses it as its tangent vector. |
| struct DummyTangentVector: Differentiable & AdditiveArithmetic { |
| static var zero: Self { Self() } |
| static func + (_: Self, _: Self) -> Self { Self() } |
| static func - (_: Self, _: Self) -> Self { Self() } |
| typealias TangentVector = Self |
| } |
| |
| // Test top-level functions. |
| // Both registrations below are valid: one omits `wrt:`, the other names `x` explicitly. |
| |
| func id(_ x: Float) -> Float { |
| return x |
| } |
| @derivative(of: id) |
| func jvpId(x: Float) -> (value: Float, differential: (Float) -> (Float)) { |
| return (x, { $0 }) |
| } |
| @derivative(of: id, wrt: x) |
| func vjpIdExplicitWrt(x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| return (x, { $0 }) |
| } |
| |
| // Valid derivatives of a generic original. The second one adds requirements |
| // (`FloatingPoint`, `T == T.TangentVector`) that the original does not have. |
| func generic<T: Differentiable>(_ x: T, _ y: T) -> T { |
| return x |
| } |
| @derivative(of: generic) |
| func jvpGeneric<T: Differentiable>(x: T, y: T) -> ( |
| value: T, differential: (T.TangentVector, T.TangentVector) -> T.TangentVector |
| ) { |
| return (x, { $0 + $1 }) |
| } |
| @derivative(of: generic) |
| func vjpGenericExtraGenericRequirements<T: Differentiable & FloatingPoint>( |
| x: T, y: T |
| ) -> (value: T, pullback: (T) -> (T, T)) where T == T.TangentVector { |
| return (x, { ($0, $0) }) |
| } |
| |
| // Test `wrt` parameter clauses. |
| // A single parameter name and a parenthesized list of names are both accepted. |
| |
| func add(x: Float, y: Float) -> Float { |
| return x + y |
| } |
| @derivative(of: add, wrt: x) // ok |
| func vjpAddWrtX(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float)) { |
| return (x + y, { $0 }) |
| } |
| @derivative(of: add, wrt: (x, y)) // ok |
| func vjpAddWrtXY(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) { |
| return (x + y, { ($0, $0) }) |
| } |
| |
| // Test index-based `wrt` parameters. |
| // Indices are zero-based and may be mixed with parameter names. |
| |
| func subtract(x: Float, y: Float) -> Float { |
| return x - y |
| } |
| @derivative(of: subtract, wrt: (0, y)) // ok |
| func vjpSubtractWrt0Y(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) { |
| return (x - y, { ($0, $0) }) |
| } |
| @derivative(of: subtract, wrt: (1)) // ok |
| func vjpSubtractWrt1(x: Float, y: Float) -> (value: Float, pullback: (Float) -> Float) { |
| return (x - y, { $0 }) |
| } |
| |
| // Test invalid original function. |
| // An `of:` name that does not resolve is diagnosed as a lookup failure. |
| |
| // expected-error @+1 {{cannot find 'nonexistentFunction' in scope}} |
| @derivative(of: nonexistentFunction) |
| func vjpOriginalFunctionNotFound(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| // Test `@derivative` attribute where `value:` result does not conform to `Differentiable`. |
| // Invalid original function should be diagnosed first. |
| // expected-error @+1 {{cannot find 'nonexistentFunction' in scope}} |
| @derivative(of: nonexistentFunction) |
| func vjpOriginalFunctionNotFound2(_ x: Float) -> (value: Int, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| // Test incorrect `@derivative` declaration type. |
| // Every derivative below targets `incorrectDerivativeType` and is malformed in one way: |
| // the result is not a correctly labeled two-element tuple, or a parameter/pullback type is wrong. |
| |
| // expected-note @+2 {{'incorrectDerivativeType' defined here}} |
| // expected-note @+1 {{candidate global function does not have expected type '(Int) -> Int'}} |
| func incorrectDerivativeType(_ x: Float) -> Float { |
| return x |
| } |
| |
| // expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple; first element must have label 'value:' and second element must have label 'pullback:' or 'differential:'}} |
| @derivative(of: incorrectDerivativeType) |
| func jvpResultIncorrect(x: Float) -> Float { |
| return x |
| } |
| // expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple; first element must have label 'value:'}} |
| @derivative(of: incorrectDerivativeType) |
| func vjpResultIncorrectFirstLabel(x: Float) -> (Float, (Float) -> Float) { |
| return (x, { $0 }) |
| } |
| // expected-error @+1 {{'@derivative(of:)' attribute requires function to return a two-element tuple; second element must have label 'pullback:' or 'differential:'}} |
| @derivative(of: incorrectDerivativeType) |
| func vjpResultIncorrectSecondLabel(x: Float) -> (value: Float, (Float) -> Float) { |
| return (x, { $0 }) |
| } |
| // expected-error @+1 {{referenced declaration 'incorrectDerivativeType' could not be resolved}} |
| @derivative(of: incorrectDerivativeType) |
| func vjpResultNotDifferentiable(x: Int) -> ( |
| value: Int, pullback: (Int) -> Int |
| ) { |
| return (x, { $0 }) |
| } |
| // expected-error @+2 {{function result's 'pullback' type does not match 'incorrectDerivativeType'}} |
| // expected-note @+3 {{'pullback' does not have expected type '(Float.TangentVector) -> Float.TangentVector' (aka '(Float) -> Float')}} |
| @derivative(of: incorrectDerivativeType) |
| func vjpResultIncorrectPullbackType(x: Float) -> ( |
| value: Float, pullback: (Double) -> Double |
| ) { |
| return (x, { $0 }) |
| } |
| |
| // Test invalid `wrt:` differentiation parameters. |
| // Each derivative below uses a `wrt:` clause that its original cannot satisfy. |
| |
| func invalidWrtParam(_ x: Float, _ y: Float) -> Float { |
| return x |
| } |
| |
| // expected-error @+1 {{unknown parameter name 'z'}} |
| @derivative(of: add, wrt: z) |
| func vjpUnknownParam(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float)) { |
| return (x + y, { $0 }) |
| } |
| // expected-error @+1 {{parameters must be specified in original order}} |
| @derivative(of: invalidWrtParam, wrt: (y, x)) |
| func vjpParamOrderNotIncreasing(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) { |
| return (x + y, { ($0, $0) }) |
| } |
| // expected-error @+1 {{'self' parameter is only applicable to instance methods}} |
| @derivative(of: invalidWrtParam, wrt: self) |
| func vjpInvalidSelfParam(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) { |
| return (x + y, { ($0, $0) }) |
| } |
| // expected-error @+1 {{parameter index is larger than total number of parameters}} |
| @derivative(of: invalidWrtParam, wrt: 2) |
| func vjpSubtractWrt2(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) { |
| return (x - y, { ($0, $0) }) |
| } |
| // expected-error @+1 {{parameters must be specified in original order}} |
| @derivative(of: invalidWrtParam, wrt: (1, x)) |
| func vjpSubtractWrt1x(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) { |
| return (x - y, { ($0, $0) }) |
| } |
| // expected-error @+1 {{parameters must be specified in original order}} |
| @derivative(of: invalidWrtParam, wrt: (1, 0)) |
| func vjpSubtractWrt10(x: Float, y: Float) -> (value: Float, pullback: (Float) -> (Float, Float)) { |
| return (x - y, { ($0, $0) }) |
| } |
| |
| // An original without parameters has nothing to differentiate with respect to. |
| func noParameters() -> Float { |
| return 1 |
| } |
| // expected-error @+1 {{'vjpNoParameters()' has no parameters to differentiate with respect to}} |
| @derivative(of: noParameters) |
| func vjpNoParameters() -> (value: Float, pullback: (Float) -> Float) { |
| return (1, { $0 }) |
| } |
| |
| // The only parameter is an `Int`, so no differentiation parameter can be inferred. |
| func noDifferentiableParameters(x: Int) -> Float { |
| return 1 |
| } |
| // expected-error @+1 {{no differentiation parameters could be inferred; must differentiate with respect to at least one parameter conforming to 'Differentiable'}} |
| @derivative(of: noDifferentiableParameters) |
| func vjpNoDifferentiableParameters(x: Int) -> ( |
| value: Float, pullback: (Float) -> Int |
| ) { |
| return (1, { _ in 0 }) |
| } |
| |
| // A function-typed parameter cannot be named as a differentiation parameter. |
| func functionParameter(_ fn: (Float) -> Float) -> Float { |
| return fn(1) |
| } |
| // expected-error @+1 {{can only differentiate with respect to parameters that conform to 'Differentiable', but '(Float) -> Float' does not conform to 'Differentiable'}} |
| @derivative(of: functionParameter, wrt: fn) |
| func vjpFunctionParameter(_ fn: (Float) -> Float) -> ( |
| value: Float, pullback: (Float) -> Float |
| ) { |
| return (functionParameter(fn), { $0 }) |
| } |
| |
| // Test static methods. |
| |
| // Requirements whose default implementations serve as the static-method originals. |
| protocol StaticMethod: Differentiable { |
| static func foo(_ x: Float) -> Float |
| static func generic<T: Differentiable>(_ x: T) -> T |
| } |
| |
| // Default implementations of the `StaticMethod` requirements (identity functions). |
| extension StaticMethod { |
| static func foo(_ x: Float) -> Float { x } |
| static func generic<T: Differentiable>(_ x: T) -> T { x } |
| } |
| |
| extension StaticMethod { |
| // Derivatives of static methods, referenced by bare or qualified (`StaticMethod.foo`) name. |
| @derivative(of: foo) |
| static func jvpFoo(x: Float) -> (value: Float, differential: (Float) -> Float) |
| { |
| return (x, { $0 }) |
| } |
| |
| // Test qualified declaration name. |
| @derivative(of: StaticMethod.foo) |
| static func vjpFoo(x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| return (x, { $0 }) |
| } |
| |
| @derivative(of: generic) |
| static func vjpGeneric<T: Differentiable>(_ x: T) -> ( |
| value: T, pullback: (T.TangentVector) -> (T.TangentVector) |
| ) { |
| return (x, { $0 }) |
| } |
| |
| // Static methods have no `self` parameter, so `wrt: self` is rejected. |
| // expected-error @+1 {{'self' parameter is only applicable to instance methods}} |
| @derivative(of: foo, wrt: (self, x)) |
| static func vjpFooWrtSelf(x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| return (x, { $0 }) |
| } |
| } |
| |
| // Test instance methods. |
| |
| // Requirements whose default implementations serve as the instance-method originals. |
| protocol InstanceMethod: Differentiable { |
| func foo(_ x: Self) -> Self |
| func generic<T: Differentiable>(_ x: T) -> Self |
| } |
| |
| extension InstanceMethod { |
| // Default implementations; the next extension redeclares both methods on purpose. |
| // expected-note @+2 {{'foo' previously declared here}} |
| // expected-note @+1 {{'foo' defined here}} |
| func foo(_ x: Self) -> Self { x } |
| |
| // expected-note @+2 {{'generic' previously declared here}} |
| // expected-note @+1 {{'generic' defined here}} |
| func generic<T: Differentiable>(_ x: T) -> Self { self } |
| } |
| |
| extension InstanceMethod { |
| // Deliberate redeclarations with identical signatures; both are diagnosed as invalid. |
| // expected-error @+1 {{invalid redeclaration of 'foo'}} |
| func foo(_ x: Self) -> Self { self } |
| |
| // expected-error @+1 {{invalid redeclaration of 'generic'}} |
| func generic<T: Differentiable>(_ x: T) -> Self { self } |
| } |
| |
| extension InstanceMethod { |
| // Valid derivatives: without `wrt:`, both `self` and `x` are differentiation parameters. |
| @derivative(of: foo) |
| func jvpFoo(x: Self) -> ( |
| value: Self, differential: (TangentVector, TangentVector) -> (TangentVector) |
| ) { |
| return (x, { $0 + $1 }) |
| } |
| |
| // Test qualified declaration name. |
| @derivative(of: InstanceMethod.foo, wrt: x) |
| func jvpFooWrtX(x: Self) -> ( |
| value: Self, differential: (TangentVector) -> (TangentVector) |
| ) { |
| return (x, { $0 }) |
| } |
| |
| @derivative(of: generic) |
| func vjpGeneric<T: Differentiable>(_ x: T) -> ( |
| value: Self, pullback: (TangentVector) -> (TangentVector, T.TangentVector) |
| ) { |
| return (self, { ($0, .zero) }) |
| } |
| |
| @derivative(of: generic, wrt: (self, x)) |
| func jvpGenericWrt<T: Differentiable>(_ x: T) -> (value: Self, differential: (TangentVector, T.TangentVector) -> TangentVector) { |
| return (self, { dself, dx in dself }) |
| } |
| |
| // When `self` is listed explicitly in `wrt:`, it must come first. |
| // expected-error @+1 {{'self' parameter must come first in the parameter list}} |
| @derivative(of: generic, wrt: (x, self)) |
| func jvpGenericWrtSelf<T: Differentiable>(_ x: T) -> (value: Self, differential: (TangentVector, T.TangentVector) -> TangentVector) { |
| return (self, { dself, dx in dself }) |
| } |
| } |
| |
| extension InstanceMethod { |
| // Each pullback below leaves out the tangent for `self`, so its type does not match. |
| // If `Self` conforms to `Differentiable`, then `Self` is inferred to be a differentiation parameter. |
| // expected-error @+2 {{function result's 'pullback' type does not match 'foo'}} |
| // expected-note @+3 {{'pullback' does not have expected type '(Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)'}} |
| @derivative(of: foo) |
| func vjpFoo(x: Self) -> ( |
| value: Self, pullback: (TangentVector) -> TangentVector |
| ) { |
| return (x, { $0 }) |
| } |
| |
| // If `Self` conforms to `Differentiable`, then `Self` is inferred to be a differentiation parameter. |
| // expected-error @+2 {{function result's 'pullback' type does not match 'generic'}} |
| // expected-note @+3 {{'pullback' does not have expected type '(Self.TangentVector) -> (Self.TangentVector, T.TangentVector)'}} |
| @derivative(of: generic) |
| func vjpGeneric<T: Differentiable>(_ x: T) -> ( |
| value: Self, pullback: (TangentVector) -> T.TangentVector |
| ) { |
| return (self, { _ in .zero }) |
| } |
| } |
| |
| // Test `@derivative` declaration with more constrained generic signature. |
| // `req1` has no requirements; its derivative adds `T: Differentiable`. |
| |
| func req1<T>(_ x: T) -> T { |
| return x |
| } |
| @derivative(of: req1) |
| func vjpExtraConformanceConstraint<T: Differentiable>(_ x: T) -> ( |
| value: T, pullback: (T.TangentVector) -> T.TangentVector |
| ) { |
| return (x, { $0 }) |
| } |
| |
| // The derivative adds conformance and same-type requirements on both `T` and `U`. |
| func req2<T, U>(_ x: T, _ y: U) -> T { |
| return x |
| } |
| @derivative(of: req2) |
| func vjpExtraConformanceConstraints<T: Differentiable, U: Differentiable>( _ x: T, _ y: U) -> ( |
| value: T, pullback: (T) -> (T, U) |
| ) where T == T.TangentVector, U == U.TangentVector, T: CustomStringConvertible { |
| return (x, { ($0, .zero) }) |
| } |
| |
| // Test `@derivative` declaration with extra same-type requirements. |
| // Here the derivative fixes `T.TangentVector` to the concrete type `Float`. |
| func req3<T>(_ x: T) -> T { |
| return x |
| } |
| @derivative(of: req3) |
| func vjpSameTypeRequirementsGenericParametersAllConcrete<T>(_ x: T) -> ( |
| value: T, pullback: (T.TangentVector) -> T.TangentVector |
| ) where T: Differentiable, T.TangentVector == Float { |
| return (x, { $0 }) |
| } |
| |
| // Generic wrapper used to register a derivative for `init(_:)` in a constrained extension. |
| struct Wrapper<T: Equatable>: Equatable { |
| var x: T |
| init(_ x: T) { self.x = x } |
| } |
| // Arithmetic forwards to the wrapped value. |
| extension Wrapper: AdditiveArithmetic where T: AdditiveArithmetic { |
| static var zero: Self { .init(.zero) } |
| static func + (lhs: Self, rhs: Self) -> Self { .init(lhs.x + rhs.x) } |
| static func - (lhs: Self, rhs: Self) -> Self { .init(lhs.x - rhs.x) } |
| } |
| // `Differentiable` only when `T` is its own tangent vector. |
| extension Wrapper: Differentiable where T: Differentiable, T == T.TangentVector { |
| typealias TangentVector = Wrapper<T.TangentVector> |
| } |
| // Uses the same constraints as the `Differentiable` conformance above. |
| extension Wrapper where T: Differentiable, T == T.TangentVector { |
| @derivative(of: init(_:)) |
| static func vjpInit(_ x: T) -> (value: Self, pullback: (Wrapper<T>.TangentVector) -> (T)) { |
| fatalError() |
| } |
| } |
| |
| // Test class methods. |
| // `Super` registers a derivative for its own `foo`; `Sub` below tries to override it. |
| |
| class Super { |
| @differentiable |
| // expected-note @+1 {{candidate instance method is not defined in the current type context}} |
| func foo(_ x: Float) -> Float { |
| return x |
| } |
| |
| @derivative(of: foo) |
| func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| return (foo(x), { v in v }) |
| } |
| } |
| |
| // `Sub` does not redeclare `foo`; it only overrides the derivative function. |
| class Sub: Super { |
| // TODO(TF-649): Enable `@derivative` to override derivatives for original |
| // declaration defined in superclass. |
| // expected-error @+1 {{referenced declaration 'foo' could not be resolved}} |
| @derivative(of: foo) |
| override func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) |
| { |
| return (foo(x), { v in v }) |
| } |
| } |
| |
| // Test non-`func` original declarations. |
| |
| // `Struct` hosts the computed-property, initializer, and subscript originals below. |
| struct Struct<T> { |
| var x: T |
| } |
| // Conditional conformances. The arithmetic requirements are `fatalError()` stubs; |
| // this file is only type-checked (see the RUN lines), never executed. |
| extension Struct: Equatable where T: Equatable {} |
| extension Struct: Differentiable & AdditiveArithmetic |
| where T: Differentiable & AdditiveArithmetic { |
| static var zero: Self { |
| fatalError() |
| } |
| static func + (lhs: Self, rhs: Self) -> Self { |
| fatalError() |
| } |
| static func - (lhs: Self, rhs: Self) -> Self { |
| fatalError() |
| } |
| typealias TangentVector = Struct<T.TangentVector> |
| mutating func move(along direction: TangentVector) { |
| x.move(along: direction.x) |
| } |
| } |
| |
| // Generic class with a conditional `Differentiable` conformance. The conformance |
| // extension has an empty body, presumably relying on compiler-derived members. |
| class Class<T> { |
| var x: T |
| init(_ x: T) { |
| self.x = x |
| } |
| } |
| extension Class: Differentiable where T: Differentiable {} |
| |
| // Test computed properties. |
| |
| // Computed property with `get`, `set`, and `_modify` accessors, all forwarding to `x`. |
| extension Struct { |
| var computedProperty: T { |
| get { x } |
| set { x = newValue } |
| _modify { yield &x } |
| } |
| } |
| extension Struct where T: Differentiable & AdditiveArithmetic { |
| // A derivative may target the property itself or its `get`/`set` accessors; `_modify` is rejected. |
| @derivative(of: computedProperty) |
| func vjpProperty() -> (value: T, pullback: (T.TangentVector) -> TangentVector) { |
| return (x, { v in .init(x: v) }) |
| } |
| |
| @derivative(of: computedProperty.get) |
| func jvpProperty() -> (value: T, differential: (TangentVector) -> T.TangentVector) { |
| fatalError() |
| } |
| |
| @derivative(of: computedProperty.set) |
| mutating func vjpPropertySetter(_ newValue: T) -> ( |
| value: (), pullback: (inout TangentVector) -> T.TangentVector |
| ) { |
| fatalError() |
| } |
| |
| // expected-error @+1 {{cannot register derivative for _modify accessor}} |
| @derivative(of: computedProperty._modify) |
| mutating func vjpPropertyModify(_ newValue: T) -> ( |
| value: (), pullback: (inout TangentVector) -> T.TangentVector |
| ) { |
| fatalError() |
| } |
| } |
| |
| // Test initializers. |
| |
| // Initializers used as `@derivative` originals below; their bodies are empty. |
| extension Struct { |
| init(_ x: Float) {} |
| init(_ x: T, y: Float) {} |
| } |
| extension Struct where T: Differentiable & AdditiveArithmetic { |
| // Initializer derivatives are static and return the constructed `Struct` as `value`. |
| @derivative(of: init) |
| static func vjpInit(_ x: Float) -> ( |
| value: Struct, pullback: (TangentVector) -> Float |
| ) { |
| return (.init(x), { _ in .zero }) |
| } |
| |
| @derivative(of: init(_:y:)) |
| static func vjpInit2(_ x: T, _ y: Float) -> ( |
| value: Struct, pullback: (TangentVector) -> (T.TangentVector, Float) |
| ) { |
| return (.init(x, y: y), { _ in (.zero, .zero) }) |
| } |
| } |
| |
| // Test subscripts. |
| |
| // Unlabeled and labeled get/set subscripts, plus a get-only generic subscript. |
| extension Struct { |
| subscript() -> Float { |
| get { 1 } |
| set {} |
| } |
| |
| subscript(float float: Float) -> Float { |
| get { 1 } |
| set {} |
| } |
| |
| // expected-note @+1 {{candidate subscript does not have a setter}} |
| subscript<T: Differentiable>(x: T) -> T { x } |
| } |
| extension Struct where T: Differentiable & AdditiveArithmetic { |
| // A derivative of a bare subscript conflicts with one already registered for its getter, |
| // so each bare-subscript derivative that follows a `.get` derivative is rejected. |
| @derivative(of: subscript.get) |
| func vjpSubscriptGetter() -> (value: Float, pullback: (Float) -> TangentVector) { |
| return (1, { _ in .zero }) |
| } |
| |
| // expected-error @+2 {{a derivative already exists for '_'}} |
| // expected-note @-6 {{other attribute declared here}} |
| @derivative(of: subscript) |
| func vjpSubscript() -> (value: Float, pullback: (Float) -> TangentVector) { |
| return (1, { _ in .zero }) |
| } |
| |
| @derivative(of: subscript().get) |
| func jvpSubscriptGetter() -> (value: Float, differential: (TangentVector) -> Float) { |
| return (1, { _ in .zero }) |
| } |
| |
| @derivative(of: subscript(float:).get, wrt: self) |
| func vjpSubscriptLabeledGetter(float: Float) -> (value: Float, pullback: (Float) -> TangentVector) { |
| return (1, { _ in .zero }) |
| } |
| |
| // expected-error @+2 {{a derivative already exists for '_'}} |
| // expected-note @-6 {{other attribute declared here}} |
| @derivative(of: subscript(float:), wrt: self) |
| func vjpSubscriptLabeled(float: Float) -> (value: Float, pullback: (Float) -> TangentVector) { |
| return (1, { _ in .zero }) |
| } |
| |
| @derivative(of: subscript(float:).get) |
| func jvpSubscriptLabeledGetter(float: Float) -> (value: Float, differential: (TangentVector, Float) -> Float) { |
| return (1, { (_,_) in 1}) |
| } |
| |
| @derivative(of: subscript(_:).get, wrt: self) |
| func vjpSubscriptGenericGetter<T: Differentiable>(x: T) -> (value: T, pullback: (T.TangentVector) -> TangentVector) { |
| return (x, { _ in .zero }) |
| } |
| |
| // expected-error @+2 {{a derivative already exists for '_'}} |
| // expected-note @-6 {{other attribute declared here}} |
| @derivative(of: subscript(_:), wrt: self) |
| func vjpSubscriptGeneric<T: Differentiable>(x: T) -> (value: T, pullback: (T.TangentVector) -> TangentVector) { |
| return (x, { _ in .zero }) |
| } |
| |
| // Setter derivatives are `mutating`, take the new value as the last parameter, and produce `()` as `value`. |
| @derivative(of: subscript.set) |
| mutating func vjpSubscriptSetter(_ newValue: Float) -> ( |
| value: (), pullback: (inout TangentVector) -> Float |
| ) { |
| fatalError() |
| } |
| |
| @derivative(of: subscript().set) |
| mutating func jvpSubscriptSetter(_ newValue: Float) -> ( |
| value: (), differential: (inout TangentVector, Float) -> () |
| ) { |
| fatalError() |
| } |
| |
| @derivative(of: subscript(float:).set) |
| mutating func vjpSubscriptLabeledSetter(float: Float, newValue: Float) -> ( |
| value: (), pullback: (inout TangentVector) -> (Float, Float) |
| ) { |
| fatalError() |
| } |
| |
| @derivative(of: subscript(float:).set) |
| mutating func jvpSubscriptLabeledSetter(float: Float, _ newValue: Float) -> ( |
| value: (), differential: (inout TangentVector, Float, Float) -> Void |
| ) { |
| fatalError() |
| } |
| |
| // Error: original subscript has no setter. |
| // expected-error @+1 {{referenced declaration 'subscript(_:)' could not be resolved}} |
| @derivative(of: subscript(_:).set, wrt: self) |
| mutating func vjpSubscriptGeneric_NoSetter<T: Differentiable>(x: T) -> ( |
| value: T, pullback: (T.TangentVector) -> TangentVector |
| ) { |
| return (x, { _ in .zero }) |
| } |
| } |
| |
| extension Class { |
| // Get/set subscript on a class; a derivative for its setter is rejected below. |
| subscript() -> Float { |
| get { 1 } |
| // expected-note @+1 {{'subscript()' declared here}} |
| set {} |
| } |
| } |
| extension Class where T: Differentiable { |
| // The bare-subscript derivative conflicts with the getter derivative registered first. |
| @derivative(of: subscript.get) |
| func vjpSubscriptGetter() -> (value: Float, pullback: (Float) -> TangentVector) { |
| return (1, { _ in .zero }) |
| } |
| |
| // expected-error @+2 {{a derivative already exists for '_'}} |
| // expected-note @-6 {{other attribute declared here}} |
| @derivative(of: subscript) |
| func vjpSubscript() -> (value: Float, pullback: (Float) -> TangentVector) { |
| return (1, { _ in .zero }) |
| } |
| |
| // FIXME(SR-13096): Enable derivative registration for class property/subscript setters. |
| // This requires changing derivative type calculation rules for functions with |
| // class-typed parameters. We need to assume that all functions taking |
| // class-typed operands may mutate those operands. |
| // expected-error @+1 {{cannot yet register derivative for class property or subscript setters}} |
| @derivative(of: subscript.set) |
| func vjpSubscriptSetter(_ newValue: Float) -> ( |
| value: (), pullback: (inout TangentVector) -> Float |
| ) { |
| fatalError() |
| } |
| } |
| |
| // Test duplicate `@derivative` attribute. |
| // Two JVPs registered for the same original with the same (inferred) `wrt` parameters conflict. |
| |
| func duplicate(_ x: Float) -> Float { x } |
| // expected-note @+1 {{other attribute declared here}} |
| @derivative(of: duplicate) |
| func jvpDuplicate1(_ x: Float) -> (value: Float, differential: (Float) -> Float) { |
| return (duplicate(x), { $0 }) |
| } |
| // expected-error @+1 {{a derivative already exists for 'duplicate'}} |
| @derivative(of: duplicate) |
| func jvpDuplicate2(_ x: Float) -> (value: Float, differential: (Float) -> Float) { |
| return (duplicate(x), { $0 }) |
| } |
| |
| // Test invalid original declaration kind. |
| // A stored global variable has no getter to register a derivative for. |
| |
| // expected-note @+1 {{candidate var does not have a getter}} |
| var globalVariable: Float |
| |
| // expected-error @+1 {{referenced declaration 'globalVariable' could not be resolved}} |
| @derivative(of: globalVariable) |
| func invalidOriginalDeclaration(x: Float) -> ( |
| value: Float, differential: (Float) -> (Float) |
| ) { |
| return (x, { $0 }) |
| } |
| |
| // Test ambiguous original declaration. |
| // `jvpAmbiguous` requires both `P1` and `P2`, so both `ambiguous` overloads are candidates. |
| |
| protocol P1 {} |
| protocol P2 {} |
| // expected-note @+1 {{candidate global function found here}} |
| func ambiguous<T: P1>(_ x: T) -> T { x } |
| // expected-note @+1 {{candidate global function found here}} |
| func ambiguous<T: P2>(_ x: T) -> T { x } |
| |
| // expected-error @+1 {{referenced declaration 'ambiguous' is ambiguous}} |
| @derivative(of: ambiguous) |
| func jvpAmbiguous<T: P1 & P2 & Differentiable>(x: T) |
| -> (value: T, differential: (T.TangentVector) -> (T.TangentVector)) |
| { |
| return (x, { $0 }) |
| } |
| |
| // Test no valid original declaration. |
| // Original declarations are invalid because they have extra generic |
| // requirements unsatisfied by the `@derivative` function. |
| // (`jvpInvalid` only requires `T: Differentiable`.) |
| |
| // expected-note @+1 {{candidate global function does not have type equal to or less constrained than '<T where T : Differentiable> (x: T) -> T'}} |
| func invalid<T: BinaryFloatingPoint>(x: T) -> T { x } |
| // expected-note @+1 {{candidate global function does not have type equal to or less constrained than '<T where T : Differentiable> (x: T) -> T'}} |
| func invalid<T: CustomStringConvertible>(x: T) -> T { x } |
| // expected-note @+1 {{candidate global function does not have type equal to or less constrained than '<T where T : Differentiable> (x: T) -> T'}} |
| func invalid<T: FloatingPoint>(x: T) -> T { x } |
| |
| // expected-error @+1 {{referenced declaration 'invalid' could not be resolved}} |
| @derivative(of: invalid) |
| func jvpInvalid<T: Differentiable>(x: T) -> ( |
| value: T, differential: (T.TangentVector) -> T.TangentVector |
| ) { |
| return (x, { $0 }) |
| } |
| |
| // Test invalid derivative type context: instance vs static method mismatch. |
| // `jvpStatic` is an instance method, so it cannot register the static `staticMethod`. |
| |
| struct InvalidTypeContext<T: Differentiable> { |
| // expected-note @+1 {{candidate static method does not have type equal to or less constrained than '<T where T : Differentiable> (InvalidTypeContext<T>) -> (T) -> T'}} |
| static func staticMethod(_ x: T) -> T { x } |
| |
| // expected-error @+1 {{referenced declaration 'staticMethod' could not be resolved}} |
| @derivative(of: staticMethod) |
| func jvpStatic(_ x: T) -> ( |
| value: T, differential: (T.TangentVector) -> (T.TangentVector) |
| ) { |
| return (x, { $0 }) |
| } |
| } |
| |
| // Test stored property original declaration. |
| |
| // `stored` is a stored property; registering a derivative for it is rejected below. |
| struct HasStoredProperty { |
| // expected-note @+1 {{'stored' declared here}} |
| var stored: Float |
| } |
| // Conformances with `fatalError()` stubs; `HasStoredProperty` is its own tangent vector. |
| extension HasStoredProperty: Differentiable & AdditiveArithmetic { |
| static var zero: Self { |
| fatalError() |
| } |
| static func + (lhs: Self, rhs: Self) -> Self { |
| fatalError() |
| } |
| static func - (lhs: Self, rhs: Self) -> Self { |
| fatalError() |
| } |
| typealias TangentVector = Self |
| } |
| extension HasStoredProperty { |
| // Stored properties cannot be `@derivative` originals. |
| // expected-error @+1 {{cannot register derivative for stored property 'stored'}} |
| @derivative(of: stored) |
| func vjpStored() -> (value: Float, pullback: (Float) -> TangentVector) { |
| return (stored, { _ in .zero }) |
| } |
| } |
| |
| // Test derivative registration for protocol requirements. Currently unsupported. |
| // TODO(TF-982): Lift this restriction and add proper support. |
| |
| // Protocol requirement used as a `@derivative` original in the extension below. |
| protocol ProtocolRequirementDerivative { |
| // expected-note @+1 {{cannot yet register derivative default implementation for protocol requirements}} |
| func requirement(_ x: Float) -> Float |
| } |
| extension ProtocolRequirementDerivative { |
| // Registering a derivative for a protocol requirement is not yet supported. |
| // expected-error @+1 {{referenced declaration 'requirement' could not be resolved}} |
| @derivative(of: requirement) |
| func vjpRequirement(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| } |
| |
| // Test `inout` parameters. |
| |
| // An original with both an `inout` parameter and a result is rejected. |
| func multipleSemanticResults(_ x: inout Float) -> Float { |
| return x |
| } |
| // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} |
| @derivative(of: multipleSemanticResults) |
| func vjpMultipleSemanticResults(x: inout Float) -> ( |
| value: Float, pullback: (Float) -> Float |
| ) { |
| return (multipleSemanticResults(&x), { $0 }) |
| } |
| |
| // Differentiable type whose tangent vector is `DummyTangentVector`; `move(along:)` is a no-op. |
| struct InoutParameters: Differentiable { |
| typealias TangentVector = DummyTangentVector |
| mutating func move(along _: TangentVector) {} |
| } |
| |
| extension InoutParameters { |
| // `staticMethod` takes `lhs` as `inout` and returns `Void`; it is the original for every derivative below. |
| // expected-note @+1 4 {{'staticMethod(_:rhs:)' defined here}} |
| static func staticMethod(_ lhs: inout Self, rhs: Self) {} |
| |
| // Test wrt `inout` parameter. |
| // With `lhs` differentiated, its tangent is `inout` in both the pullback and the differential. |
| |
| @derivative(of: staticMethod) |
| static func vjpWrtInout(_ lhs: inout Self, _ rhs: Self) -> ( |
| value: Void, pullback: (inout TangentVector) -> TangentVector |
| ) { fatalError() } |
| |
| // expected-error @+1 {{function result's 'pullback' type does not match 'staticMethod(_:rhs:)'}} |
| @derivative(of: staticMethod) |
| static func vjpWrtInoutMismatch(_ lhs: inout Self, _ rhs: Self) -> ( |
| // expected-note @+1 {{'pullback' does not have expected type '(inout InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(inout DummyTangentVector) -> DummyTangentVector')}} |
| value: Void, pullback: (TangentVector) -> TangentVector |
| ) { fatalError() } |
| |
| @derivative(of: staticMethod) |
| static func jvpWrtInout(_ lhs: inout Self, _ rhs: Self) -> ( |
| value: Void, differential: (inout TangentVector, TangentVector) -> Void |
| ) { fatalError() } |
| |
| // expected-error @+1 {{function result's 'differential' type does not match 'staticMethod(_:rhs:)'}} |
| @derivative(of: staticMethod) |
| static func jvpWrtInoutMismatch(_ lhs: inout Self, _ rhs: Self) -> ( |
| // expected-note @+1 {{'differential' does not have expected type '(inout InoutParameters.TangentVector, InoutParameters.TangentVector) -> ()' (aka '(inout DummyTangentVector, DummyTangentVector) -> ()')}} |
| value: Void, differential: (TangentVector, TangentVector) -> Void |
| ) { fatalError() } |
| |
| // Test non-wrt `inout` parameter. |
| // With only `rhs` differentiated, no tangent is `inout`. |
| |
| @derivative(of: staticMethod, wrt: rhs) |
| static func vjpNotWrtInout(_ lhs: inout Self, _ rhs: Self) -> ( |
| value: Void, pullback: (TangentVector) -> TangentVector |
| ) { fatalError() } |
| |
| // expected-error @+1 {{function result's 'pullback' type does not match 'staticMethod(_:rhs:)'}} |
| @derivative(of: staticMethod, wrt: rhs) |
| static func vjpNotWrtInoutMismatch(_ lhs: inout Self, _ rhs: Self) -> ( |
| // expected-note @+1 {{'pullback' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} |
| value: Void, pullback: (inout TangentVector) -> TangentVector |
| ) { fatalError() } |
| |
| @derivative(of: staticMethod, wrt: rhs) |
| static func jvpNotWrtInout(_ lhs: inout Self, _ rhs: Self) -> ( |
| value: Void, differential: (TangentVector) -> TangentVector |
| ) { fatalError() } |
| |
| // Mismatch case, named like its siblings so it does not overload `jvpNotWrtInout` by return type alone. |
| // expected-error @+1 {{function result's 'differential' type does not match 'staticMethod(_:rhs:)'}} |
| @derivative(of: staticMethod, wrt: rhs) |
| static func jvpNotWrtInoutMismatch(_ lhs: inout Self, _ rhs: Self) -> ( |
| // expected-note @+1 {{'differential' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} |
| value: Void, differential: (inout TangentVector) -> TangentVector |
| ) { fatalError() } |
| } |
| |
| extension InoutParameters { |
| // `mutatingMethod` takes `self` as an implicit `inout` parameter; it is the original below. |
| // expected-note @+1 4 {{'mutatingMethod' defined here}} |
| mutating func mutatingMethod(_ other: Self) {} |
| |
| // Test wrt `inout` `self` parameter. |
| // With `self` differentiated, its tangent is `inout` in both the pullback and the differential. |
| |
| @derivative(of: mutatingMethod) |
| mutating func vjpWrtInout(_ other: Self) -> ( |
| value: Void, pullback: (inout TangentVector) -> TangentVector |
| ) { fatalError() } |
| |
| // expected-error @+1 {{function result's 'pullback' type does not match 'mutatingMethod'}} |
| @derivative(of: mutatingMethod) |
| mutating func vjpWrtInoutMismatch(_ other: Self) -> ( |
| // expected-note @+1 {{'pullback' does not have expected type '(inout InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(inout DummyTangentVector) -> DummyTangentVector')}} |
| value: Void, pullback: (TangentVector) -> TangentVector |
| ) { fatalError() } |
| |
| @derivative(of: mutatingMethod) |
| mutating func jvpWrtInout(_ other: Self) -> ( |
| value: Void, differential: (inout TangentVector, TangentVector) -> Void |
| ) { fatalError() } |
| |
| // expected-error @+1 {{function result's 'differential' type does not match 'mutatingMethod'}} |
| @derivative(of: mutatingMethod) |
| mutating func jvpWrtInoutMismatch(_ other: Self) -> ( |
| // expected-note @+1 {{'differential' does not have expected type '(inout InoutParameters.TangentVector, InoutParameters.TangentVector) -> ()' (aka '(inout DummyTangentVector, DummyTangentVector) -> ()')}} |
| value: Void, differential: (TangentVector, TangentVector) -> Void |
| ) { fatalError() } |
| |
| // Test non-wrt `inout` `self` parameter. |
| // With only `other` differentiated, no tangent is `inout`. |
| |
| @derivative(of: mutatingMethod, wrt: other) |
| mutating func vjpNotWrtInout(_ other: Self) -> ( |
| value: Void, pullback: (TangentVector) -> TangentVector |
| ) { fatalError() } |
| |
| // expected-error @+1 {{function result's 'pullback' type does not match 'mutatingMethod'}} |
| @derivative(of: mutatingMethod, wrt: other) |
| mutating func vjpNotWrtInoutMismatch(_ other: Self) -> ( |
| // expected-note @+1 {{'pullback' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} |
| value: Void, pullback: (inout TangentVector) -> TangentVector |
| ) { fatalError() } |
| |
| @derivative(of: mutatingMethod, wrt: other) |
| mutating func jvpNotWrtInout(_ other: Self) -> ( |
| value: Void, differential: (TangentVector) -> TangentVector |
| ) { fatalError() } |
| |
| // expected-error @+1 {{function result's 'differential' type does not match 'mutatingMethod'}} |
| @derivative(of: mutatingMethod, wrt: other) |
| mutating func jvpNotWrtInoutMismatch(_ other: Self) -> ( |
| // expected-note @+1 {{'differential' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} |
| value: Void, differential: (TangentVector, TangentVector) -> Void |
| ) { fatalError() } |
| } |
| |
| // Test no semantic results. |
| |
| // `noSemanticResults` returns `Void` and has no `inout` parameters, so there |
| // is no semantic result to differentiate; registering a derivative for it is |
| // rejected with the void-function diagnostic. |
| func noSemanticResults(_ x: Float) {} |
| |
| // expected-error @+1 {{cannot differentiate void function 'noSemanticResults'}} |
| @derivative(of: noSemanticResults) |
| func vjpNoSemanticResults(_ x: Float) -> (value: Void, pullback: Void) {} |
| |
| // Test multiple semantic results. |
| |
| // Originals that have both an `inout` parameter and a formal result cannot be |
| // derivative targets. |
| extension InoutParameters { |
| // An `inout Float` parameter alongside a `Float` result. |
| func multipleSemanticResults(_ x: inout Float) -> Float { x } |
| // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} |
| @derivative(of: multipleSemanticResults) |
| func vjpMultipleSemanticResults(_ x: inout Float) -> ( |
| value: Float, pullback: (inout Float) -> Void |
| ) { fatalError() } |
| |
| // Even an `inout Void` parameter triggers the same error. |
| // NOTE(review): the body is empty despite the `Float` result; presumably this |
| // is tolerated because the RUN lines only type-check. Confirm before reusing. |
| func inoutVoid(_ x: Float, _ void: inout Void) -> Float {} |
| // expected-error @+1 {{cannot differentiate functions with both an 'inout' parameter and a result}} |
| @derivative(of: inoutVoid) |
| func vjpInoutVoidParameter(_ x: Float, _ void: inout Void) -> ( |
| value: Float, pullback: (inout Float) -> Void |
| ) { fatalError() } |
| } |
| |
| // Test original/derivative function `inout` parameter mismatches. |
| |
| // The original is looked up using a function type derived from the |
| // derivative's signature, including the `inout`-ness of parameters and of |
| // `self`. A mismatch leaves the reference unresolved. |
| extension InoutParameters { |
| // The derivative takes `x` as `inout`; the original does not. |
| // expected-note @+1 {{candidate instance method does not have expected type '(InoutParameters) -> (inout Float) -> Void'}} |
| func inoutParameterMismatch(_ x: Float) {} |
| |
| // expected-error @+1 {{referenced declaration 'inoutParameterMismatch' could not be resolved}} |
| @derivative(of: inoutParameterMismatch) |
| func vjpInoutParameterMismatch(_ x: inout Float) -> (value: Void, pullback: (inout Float) -> Void) { |
| fatalError() |
| } |
| |
| // The derivative is `mutating` (so its `self` is `inout`); the original is not. |
| // expected-note @+1 {{candidate instance method does not have expected type '(inout InoutParameters) -> (Float) -> Void'}} |
| func mutatingMismatch(_ x: Float) {} |
| |
| // expected-error @+1 {{referenced declaration 'mutatingMismatch' could not be resolved}} |
| @derivative(of: mutatingMismatch) |
| mutating func vjpMutatingMismatch(_ x: Float) -> (value: Void, pullback: (inout Float) -> Void) { |
| fatalError() |
| } |
| } |
| |
| // Test cross-file derivative registration. |
| |
| // `rounded()` is declared by the standard library's `FloatingPoint`, not in |
| // this file; this registration should type-check without diagnostics. |
| // `@usableFromInline` matches the public original, in the same way as the |
| // `public_original_usablefrominline_derivative` case in this file. |
| extension FloatingPoint where Self: Differentiable { |
| @usableFromInline |
| @derivative(of: rounded) |
| func vjpRounded() -> ( |
| value: Self, |
| pullback: (Self.TangentVector) -> (Self.TangentVector) |
| ) { |
| fatalError() |
| } |
| } |
| |
| // `+` here refers to the `AdditiveArithmetic` operator requirement. Default |
| // derivatives for protocol requirements are not yet supported (compare |
| // `HasADefaultDerivative` in this file), so the reference does not resolve. |
| extension Differentiable where Self: AdditiveArithmetic { |
| // expected-error @+1 {{referenced declaration '+' could not be resolved}} |
| @derivative(of: +) |
| static func vjpPlus(x: Self, y: Self) -> ( |
| value: Self, |
| pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector) |
| ) { |
| return (x + y, { v in (v, v) }) |
| } |
| } |
| |
| // Like `vjpPlus`, but the derivative is an instance method while the `+` |
| // operator is declared `static`. The reference to `+` still does not resolve. |
| extension AdditiveArithmetic |
| where Self: Differentiable, Self == Self.TangentVector { |
| // expected-error @+1 {{referenced declaration '+' could not be resolved}} |
| @derivative(of: +) |
| func vjpPlusInstanceMethod(x: Self, y: Self) -> ( |
| value: Self, pullback: (Self) -> (Self, Self) |
| ) { |
| return (x + y, { v in (v, v) }) |
| } |
| } |
| |
| // Test derivatives of default implementations. |
| protocol HasADefaultImplementation { |
| func req(_ x: Float) -> Float |
| } |
| // The extension supplies a default implementation of `req`, and the |
| // derivative registered here targets that implementation rather than the |
| // requirement. The derivative may reuse the name `req`, because its return |
| // type differs. |
| extension HasADefaultImplementation { |
| func req(_ x: Float) -> Float { x } |
| // ok |
| @derivative(of: req) |
| func req(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| (x, { 10 * $0 }) |
| } |
| } |
| |
| // Test default derivatives of requirements. |
| // Unlike `HasADefaultImplementation`, no default implementation of `req` |
| // exists, so `@derivative(of: req)` in the extension can only name the |
| // requirement itself, which is not yet supported (TF-982). |
| protocol HasADefaultDerivative { |
| // expected-note @+1 {{cannot yet register derivative default implementation for protocol requirements}} |
| func req(_ x: Float) -> Float |
| } |
| extension HasADefaultDerivative { |
| // TODO(TF-982): Support default derivatives for protocol requirements. |
| // expected-error @+1 {{referenced declaration 'req' could not be resolved}} |
| @derivative(of: req) |
| func vjpReq(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| (x, { 10 * $0 }) |
| } |
| } |
| |
| // MARK: - Original function visibility = derivative function visibility |
| |
| // Every derivative below is accepted. Names encode the original's and the |
| // derivative's access levels (e.g. `public_original_usablefrominline_derivative`), |
| // and each derivative is named like its original with a leading `_`. |
| public func public_original_public_derivative(_ x: Float) -> Float { x } |
| @derivative(of: public_original_public_derivative) |
| public func _public_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| // An internal derivative marked `@usableFromInline` matches a `public` original. |
| public func public_original_usablefrominline_derivative(_ x: Float) -> Float { x } |
| @usableFromInline |
| @derivative(of: public_original_usablefrominline_derivative) |
| func _public_original_usablefrominline_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| func internal_original_internal_derivative(_ x: Float) -> Float { x } |
| @derivative(of: internal_original_internal_derivative) |
| func _internal_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| private func private_original_private_derivative(_ x: Float) -> Float { x } |
| @derivative(of: private_original_private_derivative) |
| private func _private_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| fileprivate func fileprivate_original_fileprivate_derivative(_ x: Float) -> Float { x } |
| @derivative(of: fileprivate_original_fileprivate_derivative) |
| fileprivate func _fileprivate_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| // An internal original also accepts an internal derivative marked |
| // `@usableFromInline`, `@inlinable`, or `@_alwaysEmitIntoClient`. |
| func internal_original_usablefrominline_derivative(_ x: Float) -> Float { x } |
| @usableFromInline |
| @derivative(of: internal_original_usablefrominline_derivative) |
| func _internal_original_usablefrominline_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| func internal_original_inlinable_derivative(_ x: Float) -> Float { x } |
| @inlinable |
| @derivative(of: internal_original_inlinable_derivative) |
| func _internal_original_inlinable_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| func internal_original_alwaysemitintoclient_derivative(_ x: Float) -> Float { x } |
| @_alwaysEmitIntoClient |
| @derivative(of: internal_original_alwaysemitintoclient_derivative) |
| func _internal_original_alwaysemitintoclient_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| // MARK: - Original function visibility < derivative function visibility |
| |
| // Every derivative below is more visible than its original and is rejected. |
| // The accompanying fix-it rewrites the derivative's access modifier to the |
| // original's, or inserts one when the derivative has none. |
| // A `@usableFromInline` original is still reported as `internal`. |
| @usableFromInline |
| func usablefrominline_original_public_derivative(_ x: Float) -> Float { x } |
| // expected-error @+1 {{derivative function must have same access level as original function; derivative function '_usablefrominline_original_public_derivative' is public, but original function 'usablefrominline_original_public_derivative' is internal}} |
| @derivative(of: usablefrominline_original_public_derivative) |
| // expected-note @+1 {{mark the derivative function as 'internal' to match the original function}} {{1-7=internal}} |
| public func _usablefrominline_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| func internal_original_public_derivative(_ x: Float) -> Float { x } |
| // expected-error @+1 {{derivative function must have same access level as original function; derivative function '_internal_original_public_derivative' is public, but original function 'internal_original_public_derivative' is internal}} |
| @derivative(of: internal_original_public_derivative) |
| // expected-note @+1 {{mark the derivative function as 'internal' to match the original function}} {{1-7=internal}} |
| public func _internal_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| // `@usableFromInline` on the derivative does not let it match a `private` original. |
| private func private_original_usablefrominline_derivative(_ x: Float) -> Float { x } |
| // expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_usablefrominline_derivative' is internal, but original function 'private_original_usablefrominline_derivative' is private}} |
| @derivative(of: private_original_usablefrominline_derivative) |
| @usableFromInline |
| // expected-note @+1 {{mark the derivative function as 'private' to match the original function}} {{1-1=private }} |
| func _private_original_usablefrominline_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| private func private_original_public_derivative(_ x: Float) -> Float { x } |
| // expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_public_derivative' is public, but original function 'private_original_public_derivative' is private}} |
| @derivative(of: private_original_public_derivative) |
| // expected-note @+1 {{mark the derivative function as 'private' to match the original function}} {{1-7=private}} |
| public func _private_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| private func private_original_internal_derivative(_ x: Float) -> Float { x } |
| // expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_internal_derivative' is internal, but original function 'private_original_internal_derivative' is private}} |
| @derivative(of: private_original_internal_derivative) |
| // expected-note @+1 {{mark the derivative function as 'private' to match the original function}} |
| func _private_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| fileprivate func fileprivate_original_private_derivative(_ x: Float) -> Float { x } |
| // expected-error @+1 {{derivative function must have same access level as original function; derivative function '_fileprivate_original_private_derivative' is private, but original function 'fileprivate_original_private_derivative' is fileprivate}} |
| @derivative(of: fileprivate_original_private_derivative) |
| // expected-note @+1 {{mark the derivative function as 'fileprivate' to match the original function}} {{1-8=fileprivate}} |
| private func _fileprivate_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| private func private_original_fileprivate_derivative(_ x: Float) -> Float { x } |
| // expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_fileprivate_derivative' is fileprivate, but original function 'private_original_fileprivate_derivative' is private}} |
| @derivative(of: private_original_fileprivate_derivative) |
| // expected-note @+1 {{mark the derivative function as 'private' to match the original function}} {{1-12=private}} |
| fileprivate func _private_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| // MARK: - Original function visibility > derivative function visibility |
| |
| // Every derivative below is less visible than its original and is rejected. |
| // For a `public` original, the suggested fix is to add `@usableFromInline` |
| // rather than to make the derivative `public`. |
| // NOTE(review): despite its name, `_public_original_private_derivative` is |
| // declared `fileprivate`, and the diagnostic text below is pinned to that; |
| // change the declaration only together with the diagnostic. |
| public func public_original_private_derivative(_ x: Float) -> Float { x } |
| // expected-error @+1 {{derivative function must have same access level as original function; derivative function '_public_original_private_derivative' is fileprivate, but original function 'public_original_private_derivative' is public}} |
| @derivative(of: public_original_private_derivative) |
| // expected-note @+1 {{mark the derivative function as '@usableFromInline' to match the original function}} {{1-1=@usableFromInline }} |
| fileprivate func _public_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| public func public_original_internal_derivative(_ x: Float) -> Float { x } |
| // expected-error @+1 {{derivative function must have same access level as original function; derivative function '_public_original_internal_derivative' is internal, but original function 'public_original_internal_derivative' is public}} |
| @derivative(of: public_original_internal_derivative) |
| // expected-note @+1 {{mark the derivative function as '@usableFromInline' to match the original function}} {{1-1=@usableFromInline }} |
| func _public_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| func internal_original_fileprivate_derivative(_ x: Float) -> Float { x } |
| // expected-error @+1 {{derivative function must have same access level as original function; derivative function '_internal_original_fileprivate_derivative' is fileprivate, but original function 'internal_original_fileprivate_derivative' is internal}} |
| @derivative(of: internal_original_fileprivate_derivative) |
| // expected-note @+1 {{mark the derivative function as 'internal' to match the original function}} {{1-12=internal}} |
| fileprivate func _internal_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| // Test invalid reference to an accessor of a non-storage declaration. |
| |
| // `function` is a plain global function, not a storage declaration, so the |
| // `.get` accessor suffix in the derivative's reference has nothing to name. |
| // expected-note @+1 {{candidate global function does not have a getter}} |
| func function(_ x: Float) -> Float { |
| x |
| } |
| |
| // expected-error @+1 {{referenced declaration 'function' could not be resolved}} |
| @derivative(of: function(_:).get) |
| func vjpFunction(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |
| |
| // Test the ambiguity that arises when a type's method name is the same |
| // as an accessor name (e.g. a method named `set` on `Float`). |
| |
| extension Float { |
| // Original function name conflicts with an accessor name ("set"). |
| func set() -> Float { |
| self |
| } |
| |
| // Original function name does not conflict with an accessor name. |
| func method() -> Float { |
| self |
| } |
| |
| // Test ambiguous parse. |
| // Expected: |
| // - Base type: `Float` |
| // - Declaration name: `set` |
| // - Accessor kind: <none> |
| // Actual: |
| // - Base type: <none> |
| // - Declaration name: `Float` |
| // - Accessor kind: `set` |
| // The diagnostic below pins the current ("Actual") parse: lookup searches for |
| // a declaration named `Float` and fails. |
| // expected-error @+1 {{cannot find 'Float' in scope}} |
| @derivative(of: Float.set) |
| func jvpSet() -> (value: Float, differential: (Float) -> Float) { |
| fatalError() |
| } |
| |
| // `method` is not an accessor name, so `Float.method` resolves normally. |
| @derivative(of: Float.method) |
| func jvpMethod() -> (value: Float, differential: (Float) -> Float) { |
| fatalError() |
| } |
| } |
| |
| // Test original function with opaque result type. |
| |
| // The original's opaque result type `some Differentiable` is not treated as |
| // `Float`, even though the body returns a `Float`, so the original does not |
| // match the `(Float) -> Float` type derived from the derivative. |
| // expected-note @+1 {{candidate global function does not have expected type '(Float) -> Float'}} |
| func opaqueResult(_ x: Float) -> some Differentiable { x } |
| |
| // expected-error @+1 {{referenced declaration 'opaqueResult' could not be resolved}} |
| @derivative(of: opaqueResult) |
| func vjpOpaqueResult(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { |
| fatalError() |
| } |