// blob: 3140fcbcca1a23f5e4aaacf16c0482dce710fcb2 [file] [log] [blame]
// RUN: %target-swift-frontend -typecheck -verify %s
// Stored global variables: `@differentiable` is rejected on both a global
// `let` constant and a global `var`.
@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
let globalConst: Float = 1
@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
var globalVar: Float = 1
// Local computed variables declared inside a function body. `@differentiable`
// on a getter-only local and on a getter/setter local is diagnosed as having
// no parameters to differentiate with respect to. The diagnostic refers to
// the declaration as '_'.
func testLocalVariables() {
// expected-error @+1 {{'_' has no parameters to differentiate with respect to}}
@differentiable
var getter: Float {
return 1
}
// expected-error @+1 {{'_' has no parameters to differentiate with respect to}}
@differentiable
var getterSetter: Float {
get { return 1 }
set {}
}
}
// `@differentiable` cannot be attached to a protocol declaration. Only the
// applicability error is checked on this line. `dfoo` is not a top-level
// function in the visible code.
@differentiable(vjp: dfoo) // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
protocol P {}
// An empty argument list, with no `jvp:` or `vjp:`, is accepted.
@differentiable() // ok!
func no_jvp_or_vjp(_ x: Float) -> Float {
return x * x
}
// Test duplicated `@differentiable` attributes.
// Two attributes with identical parameters on one declaration are reported as
// duplicates: the error is attached to the first of the pair and the note to
// the second. Attributes with different `wrt:` sets (`arg1` vs `arg2`) do not
// conflict.
@differentiable // expected-error {{duplicate '@differentiable' attribute with same parameters}}
@differentiable // expected-note {{other attribute declared here}}
func dupe_attributes(arg: Float) -> Float { return arg }
@differentiable(wrt: arg1)
@differentiable(wrt: arg2) // expected-error {{duplicate '@differentiable' attribute with same parameters}}
@differentiable(wrt: arg2) // expected-note {{other attribute declared here}}
func dupe_attributes(arg1: Float, arg2: Float) -> Float { return arg1 }
// Duplicate detection on computed properties.
// - `computed1`: an attribute on the property and an identical one on its
//   `get` accessor are treated as duplicates.
// - `computed2`: attributes that differ only in their `where` clause are
//   currently also treated as duplicates (see TODO(TF-482)).
struct ComputedPropertyDupeAttributes<T : Differentiable> : Differentiable {
var value: T
@differentiable // expected-error {{duplicate '@differentiable' attribute with same parameters}}
var computed1: T {
@differentiable // expected-note {{other attribute declared here}}
get { value }
set { value = newValue }
}
// TODO(TF-482): Remove diagnostics when `@differentiable` attributes are
// also uniqued based on generic requirements.
@differentiable(where T == Float) // expected-error {{duplicate '@differentiable' attribute with same parameters}}
@differentiable(where T == Double) // expected-note {{other attribute declared here}}
var computed2: T {
get { value }
set { value = newValue }
}
}
// Test TF-568.
// Requirements with a bare `@differentiable` on a protocol that refines
// `Differentiable`, where neither requirement takes any arguments (so `self`
// is the only remaining candidate). No diagnostics are expected.
protocol WrtOnlySelfProtocol : Differentiable {
@differentiable
var computedProperty: Float { get }
@differentiable
func method() -> Float
}
// Differentiating with respect to a parameter of a class type that conforms to
// `Differentiable`. No diagnostic is expected here.
// NOTE(review): the function name says "invalid", but no error is attached to
// it; presumably class-typed parameters became supported. Confirm intent.
class Class : Differentiable {}
@differentiable(wrt: x)
func invalidDiffWrtClass(_ x: Class) -> Class {
return x
}
// A protocol-existential parameter cannot be listed in `wrt:`.
protocol Proto {}
// expected-error @+1 {{cannot differentiate with respect to protocol existential ('Proto')}}
@differentiable(wrt: x)
func invalidDiffWrtExistential(_ x: Proto) -> Proto {
return x
}
// A function-typed parameter, even a `@differentiable` one, cannot be listed
// explicitly in `wrt:`.
// expected-error @+1 {{functions ('@differentiable (Float) -> Float') cannot be differentiated with respect to}}
@differentiable(wrt: fn)
func invalidDiffWrtFunction(_ fn: @differentiable(Float) -> Float) -> Float {
return fn(.pi)
}
// Free functions: a function with no parameters and a function returning
// `Void` both cannot be differentiated.
// expected-error @+1 {{'invalidDiffNoParams()' has no parameters to differentiate with respect to}}
@differentiable
func invalidDiffNoParams() -> Float {
return 1
}
// expected-error @+1 {{cannot differentiate void function 'invalidDiffVoidResult(x:)'}}
@differentiable
func invalidDiffVoidResult(x: Float) {}
// Test static methods.
// The same two restrictions as for free functions: no parameters, and a
// `Void` result.
struct StaticMethod {
// expected-error @+1 {{'invalidDiffNoParams()' has no parameters to differentiate with respect to}}
@differentiable
static func invalidDiffNoParams() -> Float {
return 1
}
// expected-error @+1 {{cannot differentiate void function 'invalidDiffVoidResult(x:)'}}
@differentiable
static func invalidDiffVoidResult(x: Float) {}
}
// Test instance methods.
// `InstanceMethod` does not conform to `Differentiable`, so a method with no
// arguments has nothing to differentiate with respect to. A `Void` result is
// rejected as well.
struct InstanceMethod {
// expected-error @+1 {{'invalidDiffNoParams()' has no parameters to differentiate with respect to}}
@differentiable
func invalidDiffNoParams() -> Float {
return 1
}
// expected-error @+1 {{cannot differentiate void function 'invalidDiffVoidResult(x:)'}}
@differentiable
func invalidDiffVoidResult(x: Float) {}
}
// Test instance methods for a `Differentiable` type.
// Unlike a non-`Differentiable` struct, a zero-argument method is accepted
// here (presumably differentiating with respect to `self`).
struct DifferentiableInstanceMethod : Differentiable {
@differentiable // ok
func noParams() -> Float {
return 1
}
}
// Test subscript methods.
// `@differentiable` is accepted on a subscript declaration and on an explicit
// subscript getter, but rejected on a subscript setter.
struct SubscriptMethod {
@differentiable // ok
subscript(implicitGetter x: Float) -> Float {
return x
}
@differentiable // ok
subscript(implicitGetterSetter x: Float) -> Float {
get { return x }
set {}
}
subscript(explicit x: Float) -> Float {
@differentiable // ok
get { return x }
@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
set {}
}
subscript(x: Float, y: Float) -> Float {
@differentiable // ok
get { return x + y }
@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
set {}
}
}
// JVP
// `@differentiable(jvp:)` names a derivative function that returns the
// original result paired with a differential closure, which maps parameter
// tangents to a result tangent.
@differentiable(jvp: jvpSimpleJVP)
func jvpSimple(x: Float) -> Float {
return x
}
func jvpSimpleJVP(x: Float) -> (Float, ((Float) -> Float)) {
return (x, { v in v })
}
// `wrt:` accepts either a bare parameter name or a parenthesized list. Both
// spellings share one JVP, whose differential takes only the tangent for `y`.
@differentiable(wrt: y, jvp: jvpWrtSubsetJVP)
func jvpWrtSubset1(x: Float, y: Float) -> Float {
return x + y
}
@differentiable(wrt: (y), jvp: jvpWrtSubsetJVP)
func jvpWrtSubset2(x: Float, y: Float) -> Float {
return x + y
}
func jvpWrtSubsetJVP(x: Float, y: Float) -> (Float, (Float) -> Float) {
return (x + y, { v in v })
}
// Without `wrt:`, both `Float` parameters are differentiated, so the
// differential takes one tangent per parameter.
@differentiable(jvp: jvp2ParamsJVP)
func jvp2Params(x: Float, y: Float) -> Float {
return x + y
}
func jvp2ParamsJVP(x: Float, y: Float) -> (Float, (Float, Float) -> Float) {
return (x + y, { (a, b) in a + b })
}
// `wrt:` validation: an unknown parameter name is rejected, and so is a list
// that does not follow declaration order.
// expected-error @+1 {{unknown parameter name 'y'}}
@differentiable(wrt: (y))
func jvpUnknownParam(x: Float) -> Float {
return x
}
// expected-error @+1 {{parameters must be specified in original order}}
@differentiable(wrt: (y, x))
func jvpParamOrderNotIncreasing(x: Float, y: Float) -> Float {
return x * y
}
// The JVP's differential returns `Int` instead of `Float.TangentVector`, so
// the derivative fails the signature check against the required JVP type.
// expected-error @+1 {{'jvpWrongTypeJVP' does not have expected type '(Float) -> (Float, (Float.TangentVector) -> Float.TangentVector)' (aka '(Float) -> (Float, (Float) -> Float)'}}
@differentiable(jvp: jvpWrongTypeJVP)
func jvpWrongType(x: Float) -> Float {
return x
}
func jvpWrongTypeJVP(x: Float) -> (Float, (Float) -> Int) {
return (x, { v in Int(v) })
}
// The original function needs at least one `Differentiable` parameter and a
// `Differentiable` result. Each case reuses `jvpSimpleJVP`; the diagnostics
// are about the original function's signature.
// expected-error @+1 {{no differentiation parameters could be inferred; must differentiate with respect to at least one parameter conforming to 'Differentiable'}}
@differentiable(jvp: jvpSimpleJVP)
func jvpNonDiffParam(x: Int) -> Float {
return Float(x)
}
// expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'Int' does not conform to 'Differentiable'}}
@differentiable(jvp: jvpSimpleJVP)
func jvpNonDiffResult(x: Float) -> Int {
return Int(x)
}
// expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but '(Float, Int)' does not conform to 'Differentiable'}}
@differentiable(jvp: jvpSimpleJVP)
func jvpNonDiffResult2(x: Float) -> (Float, Int) {
return (x, Int(x))
}
// A derivative name that resolves to more than one overload (here the two
// overloads differ only in the argument label) is rejected as ambiguous.
// NOTE(review): the helper is named `...VJP` but is used in the `jvp:`
// position; renaming would also require updating the diagnostic text below.
// expected-error @+1 {{ambiguous or overloaded identifier 'jvpAmbiguousVJP' cannot be used in '@differentiable' attribute}}
@differentiable(jvp: jvpAmbiguousVJP)
func jvpAmbiguous(x: Float) -> Float {
return x
}
func jvpAmbiguousVJP(_ x: Float) -> (Float, (Float) -> Float) {
return (x, { $0 })
}
func jvpAmbiguousVJP(x: Float) -> (Float, (Float) -> Float) {
return (x, { $0 })
}
// A bare `@differentiable` (no derivative function named) on a method of a
// class that does not conform to `Differentiable`. No diagnostic is expected.
class DifferentiableClassMethod {
// Direct differentiation case.
@differentiable
func foo(_ x: Float) -> Float {
return x
}
}
// `JVPStruct` becomes `Differentiable` in a later extension.
// - `@differentiable` on the stored `let` property `p` is accepted (no
//   diagnostic).
// - `funcWrongType` returns `Double`, but its JVP `funcJVP` produces `Float`,
//   so the JVP fails the signature check.
struct JVPStruct {
@differentiable
let p: Float
// expected-error @+1 {{'funcJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}}
@differentiable(wrt: (self), jvp: funcJVP)
func funcWrongType() -> Double {
fatalError("unimplemented")
}
}
// JVP referenced by `funcWrongType`. Its `Float` result deliberately does not
// match that method's `Double` result.
extension JVPStruct {
func funcJVP() -> (Float, (JVPStruct) -> Float) {
fatalError("unimplemented")
}
}
// Hand-written `AdditiveArithmetic` conformance, plus a `Scalar` typealias and
// a scalar `*`, so `JVPStruct` can serve as its own tangent vector. The bodies
// are stubs because this file is only type-checked (`-typecheck`).
extension JVPStruct : AdditiveArithmetic {
static var zero: JVPStruct { fatalError("unimplemented") }
static func + (lhs: JVPStruct, rhs: JVPStruct) -> JVPStruct {
fatalError("unimplemented")
}
static func - (lhs: JVPStruct, rhs: JVPStruct) -> JVPStruct {
fatalError("unimplemented")
}
typealias Scalar = Float
static func * (lhs: Float, rhs: JVPStruct) -> JVPStruct {
fatalError("unimplemented")
}
}
// `JVPStruct` is its own `TangentVector`.
extension JVPStruct : Differentiable {
typealias TangentVector = JVPStruct
}
// JVP for a method, with respect to the non-`self` parameter `x` only. The
// differential takes just the `Float` tangent of `x`.
extension JVPStruct {
@differentiable(wrt: x, jvp: wrtAllNonSelfJVP)
func wrtAllNonSelf(x: Float) -> Float {
return x + p
}
func wrtAllNonSelfJVP(x: Float) -> (Float, (Float) -> Float) {
return (x + p, { v in v })
}
}
// JVP with respect to both `self` and `x`. The differential takes a
// `JVPStruct` tangent (for `self`) followed by a `Float` tangent (for `x`).
extension JVPStruct {
@differentiable(wrt: (self, x), jvp: wrtAllJVP)
func wrtAll(x: Float) -> Float {
return x + p
}
func wrtAllJVP(x: Float) -> (Float, (JVPStruct, Float) -> Float) {
return (x + p, { (a, b) in a.p + b })
}
}
// JVPs on computed properties, all sharing `computedPropJVP`:
// - accepted on a getter-only property and on an explicit `get` accessor;
// - a `Double` property fails the signature check, because the JVP produces
//   `Float`;
// - rejected on a `set` accessor.
extension JVPStruct {
@differentiable(jvp: computedPropJVP)
var computedPropOk1: Float {
return 0
}
var computedPropOk2: Float {
@differentiable(jvp: computedPropJVP)
get {
return 0
}
}
// expected-error @+1 {{'computedPropJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}}
@differentiable(jvp: computedPropJVP)
var computedPropWrongType: Double {
return 0
}
var computedPropWrongAccessor: Float {
get {
return 0
}
// expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}}
@differentiable(jvp: computedPropJVP)
set {
fatalError("unimplemented")
}
}
func computedPropJVP() -> (Float, (JVPStruct) -> Float) {
fatalError("unimplemented")
}
}
// VJP
// `@differentiable(vjp:)` names a derivative function that returns the
// original result paired with a pullback closure, which maps a result tangent
// back to parameter tangents.
@differentiable(vjp: vjpSimpleVJP)
func vjpSimple(x: Float) -> Float {
return x
}
func vjpSimpleVJP(x: Float) -> (Float, ((Float) -> Float)) {
return (x, { v in v })
}
// VJP with respect to `y` only. The pullback returns just the tangent for `y`.
@differentiable(wrt: (y), vjp: vjpWrtSubsetVJP)
func vjpWrtSubset(x: Float, y: Float) -> Float {
return x + y
}
func vjpWrtSubsetVJP(x: Float, y: Float) -> (Float, (Float) -> Float) {
return (x + y, { v in v })
}
// Without `wrt:`, both parameters are differentiated, so the pullback returns
// a tuple with one tangent per parameter.
@differentiable(vjp: vjp2ParamsVJP)
func vjp2Params(x: Float, y: Float) -> Float {
return x + y
}
func vjp2ParamsVJP(x: Float, y: Float) -> (Float, (Float) -> (Float, Float)) {
return (x + y, { v in (v, v) })
}
// The VJP's pullback returns `Int` instead of `Float.TangentVector`, so the
// derivative fails the signature check against the required VJP type.
// expected-error @+1 {{'vjpWrongTypeVJP' does not have expected type '(Float) -> (Float, (Float.TangentVector) -> Float.TangentVector)' (aka '(Float) -> (Float, (Float) -> Float)'}}
@differentiable(vjp: vjpWrongTypeVJP)
func vjpWrongType(x: Float) -> Float {
return x
}
func vjpWrongTypeVJP(x: Float) -> (Float, (Float) -> Int) {
return (x, { v in Int(v) })
}
// The same parameter and result requirements as the JVP cases, checked with a
// `vjp:` argument. Each case reuses `vjpSimpleVJP`; the diagnostics are about
// the original function's signature.
// expected-error @+1 {{no differentiation parameters could be inferred; must differentiate with respect to at least one parameter conforming to 'Differentiable'}}
@differentiable(vjp: vjpSimpleVJP)
func vjpNonDiffParam(x: Int) -> Float {
return Float(x)
}
// expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'Int' does not conform to 'Differentiable'}}
@differentiable(vjp: vjpSimpleVJP)
func vjpNonDiffResult(x: Float) -> Int {
return Int(x)
}
// expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but '(Float, Int)' does not conform to 'Differentiable'}}
@differentiable(vjp: vjpSimpleVJP)
func vjpNonDiffResult2(x: Float) -> (Float, Int) {
return (x, Int(x))
}
// `VJPStruct` becomes `Differentiable` in a later extension.
// `funcWrongType` returns `Double`, but its VJP `funcVJP` produces `Float`, so
// the VJP fails the signature check.
struct VJPStruct {
let p: Float
// expected-error @+1 {{'funcVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}}
@differentiable(vjp: funcVJP)
func funcWrongType() -> Double {
fatalError("unimplemented")
}
}
// VJP referenced by `funcWrongType`. Its `Float` result deliberately does not
// match that method's `Double` result.
extension VJPStruct {
func funcVJP() -> (Float, (Float) -> VJPStruct) {
fatalError("unimplemented")
}
}
// Hand-written `AdditiveArithmetic` conformance, plus a `Scalar` typealias and
// a scalar `*`, so `VJPStruct` can serve as its own tangent vector. The bodies
// are stubs because this file is only type-checked (`-typecheck`).
extension VJPStruct : AdditiveArithmetic {
static var zero: VJPStruct { fatalError("unimplemented") }
static func + (lhs: VJPStruct, rhs: VJPStruct) -> VJPStruct {
fatalError("unimplemented")
}
static func - (lhs: VJPStruct, rhs: VJPStruct) -> VJPStruct {
fatalError("unimplemented")
}
typealias Scalar = Float
static func * (lhs: Float, rhs: VJPStruct) -> VJPStruct {
fatalError("unimplemented")
}
}
// `VJPStruct` is its own `TangentVector`.
extension VJPStruct : Differentiable {
typealias TangentVector = VJPStruct
}
// VJP for a method, with respect to the non-`self` parameter `x` only. The
// pullback returns just the `Float` tangent of `x`.
extension VJPStruct {
@differentiable(wrt: x, vjp: wrtAllNonSelfVJP)
func wrtAllNonSelf(x: Float) -> Float {
return x + p
}
func wrtAllNonSelfVJP(x: Float) -> (Float, (Float) -> Float) {
return (x + p, { v in v })
}
}
// VJP with respect to both `self` and `x`. The pullback returns a
// `(VJPStruct, Float)` tuple of tangents in parameter order.
extension VJPStruct {
@differentiable(wrt: (self, x), vjp: wrtAllVJP)
func wrtAll(x: Float) -> Float {
return x + p
}
func wrtAllVJP(x: Float) -> (Float, (Float) -> (VJPStruct, Float)) {
fatalError("unimplemented")
}
}
// VJPs on computed properties, all sharing `computedPropVJP`:
// - accepted on a getter-only property and on an explicit `get` accessor;
// - a `Double` property fails the signature check, because the VJP produces
//   `Float`;
// - rejected on a `set` accessor.
extension VJPStruct {
@differentiable(vjp: computedPropVJP)
var computedPropOk1: Float {
return 0
}
var computedPropOk2: Float {
@differentiable(vjp: computedPropVJP)
get {
return 0
}
}
// expected-error @+1 {{'computedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}}
@differentiable(vjp: computedPropVJP)
var computedPropWrongType: Double {
return 0
}
var computedPropWrongAccessor: Float {
get {
return 0
}
// expected-error @+1 {{'@differentiable' attribute cannot be applied to this declaration}}
@differentiable(vjp: computedPropVJP)
set {
fatalError("unimplemented")
}
}
func computedPropVJP() -> (Float, (Float) -> VJPStruct) {
fatalError("unimplemented")
}
}
// Malformed `where` clauses: an empty clause (which also produces a parse
// error), and a clause on a function that has no generic parameters.
// expected-error @+2 {{empty 'where' clause in '@differentiable' attribute}}
// expected-error @+1 {{expected type}}
@differentiable(where)
func emptyWhereClause<T>(x: T) -> T {
return x
}
// expected-error @+1 {{trailing 'where' clause in '@differentiable' attribute of non-generic function 'nongenericWhereClause(x:)'}}
@differentiable(where T : Differentiable)
func nongenericWhereClause(x: Float) -> Float {
return x
}
// The attribute's `where` clause supplies the `Differentiable` constraint that
// the unconstrained generic `where1` lacks. Both derivatives declare the same
// constraint directly on `T`.
@differentiable(jvp: jvpWhere1, vjp: vjpWhere1 where T : Differentiable)
func where1<T>(x: T) -> T {
return x
}
func jvpWhere1<T : Differentiable>(x: T) -> (T, (T.TangentVector) -> T.TangentVector) {
return (x, { v in v })
}
func vjpWhere1<T : Differentiable>(x: T) -> (T, (T.TangentVector) -> T.TangentVector) {
return (x, { v in v })
}
// Test derivative functions with result tuple type labels.
// Labels on the derivative's result tuple (`value`/`differential` for the JVP,
// `value`/`pullback` for the VJP) are accepted.
@differentiable(jvp: jvpResultLabels, vjp: vjpResultLabels)
func derivativeResultLabels(_ x: Float) -> Float {
return x
}
func jvpResultLabels(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
return (x, { $0 })
}
func vjpResultLabels(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x, { $0 })
}
// Labeled derivative result tuples for a static method and for an instance
// method. Each derivative is resolved from the same kind of member (static or
// instance) as the original.
struct ResultLabelTest {
@differentiable(jvp: jvpResultLabels, vjp: vjpResultLabels)
static func derivativeResultLabels(_ x: Float) -> Float {
return x
}
static func jvpResultLabels(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
return (x, { $0 })
}
static func vjpResultLabels(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x, { $0 })
}
@differentiable(jvp: jvpResultLabels, vjp: vjpResultLabels)
func derivativeResultLabels(_ x: Float) -> Float {
return x
}
func jvpResultLabels(_ x: Float) -> (value: Float, differential: (Float) -> Float) {
return (x, { $0 })
}
func vjpResultLabels(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
return (x, { $0 })
}
}
// `Tensor` conforms to `Differentiable` only when `Scalar` does. `where2`
// restores that constraint through the attribute's `where` clause.
// NOTE(review): `adjWhere2`, `jvpWhere2`, and `vjpWhere2` are not referenced by
// any attribute in the visible code; `where2` uses no derivative function.
struct Tensor<Scalar> : AdditiveArithmetic {}
extension Tensor : Differentiable where Scalar : Differentiable {}
@differentiable(where Scalar : Differentiable)
func where2<Scalar : Numeric>(x: Tensor<Scalar>) -> Tensor<Scalar> {
return x
}
func adjWhere2<Scalar : Numeric & Differentiable>(seed: Tensor<Scalar>, originalResult: Tensor<Scalar>, x: Tensor<Scalar>) -> Tensor<Scalar> {
return seed
}
func jvpWhere2<Scalar : Numeric & Differentiable>(x: Tensor<Scalar>) -> (Tensor<Scalar>, (Tensor<Scalar>) -> Tensor<Scalar>) {
return (x, { v in v })
}
func vjpWhere2<Scalar : Numeric & Differentiable>(x: Tensor<Scalar>) -> (Tensor<Scalar>, (Tensor<Scalar>) -> Tensor<Scalar>) {
return (x, { v in v })
}
// A `where` clause inside nested generic types, placing constraints on `T` and
// on the inner generic parameter `V`.
// NOTE(review): the method declares its own generic parameter `T`, which
// shadows `A`'s `T`; from here it is not obvious which `T` the `where` clause
// constrains. Confirm this is intentional.
struct A<T> {
struct B<U, V> {
@differentiable(wrt: x where T : Differentiable, V : Differentiable, V.TangentVector == V)
func whereInGenericContext<T>(x: T) -> T {
return x
}
}
}
// A `where Self : Differentiable` clause on a protocol-extension method,
// differentiating with respect to `self`.
extension FloatingPoint {
@differentiable(wrt: (self) where Self : Differentiable)
func whereClauseExtension() -> Self {
return self
}
}
// A VJP that takes `[Int32]` does not match an original that takes a variadic
// `Int32...` parameter.
// expected-error @+1 {{'vjpNonvariadic' does not have expected type '(Float, Int32...) -> (Float, (Float.TangentVector) -> Float.TangentVector)' (aka '(Float, Int32...) -> (Float, (Float) -> Float)')}}
@differentiable(wrt: x, vjp: vjpNonvariadic)
func variadic(_ x: Float, indices: Int32...) -> Float {
return x
}
func vjpNonvariadic(_ x: Float, indices: [Int32]) -> (Float, (Float) -> Float) {
return (x, { $0 })
}
// Requirements in the attribute's `where` clause:
// - a conformance to a concrete type (`Float`) is rejected, and `Scalar`
//   remains unconstrained, so no parameters can be inferred;
// - `AnyObject` is accepted;
// - layout requirements (`_Trivial`) are rejected.
// Finally, an unconstrained generic parameter provides no `Differentiable`
// parameter to infer.
// expected-error @+3 {{type 'Scalar' constrained to non-protocol, non-class type 'Float'}}
// expected-error @+2 {{no differentiation parameters could be inferred; must differentiate with respect to at least one parameter conforming to 'Differentiable'}}
// expected-note @+1 {{use 'Scalar == Float' to require 'Scalar' to be 'Float'}}
@differentiable(where Scalar : Float)
func invalidRequirementConformance<Scalar>(x: Scalar) -> Scalar {
return x
}
@differentiable(where T : AnyObject)
func invalidAnyObjectRequirement<T : Differentiable>(x: T) -> T {
return x
}
// expected-error @+1 {{'@differentiable' attribute does not support layout requirements}}
@differentiable(where Scalar : _Trivial)
func invalidRequirementLayout<Scalar>(x: Scalar) -> Scalar {
return x
}
// expected-error @+1 {{no differentiation parameters could be inferred; must differentiate with respect to at least one parameter conforming to 'Differentiable'}}
@differentiable
func missingConformance<T>(_ x: T) -> T {
return x
}
// Protocol requirements that carry `@differentiable`. The "protocol requires"
// notes accompany the conformance failure of `DiffAttrConformanceErrors`, whose
// implementations lack the attributes. `ProtocolRequirementsRefined` redeclares
// `f1` without the attribute, which is an error.
protocol ProtocolRequirements : Differentiable {
// expected-note @+2 {{protocol requires initializer 'init(x:y:)' with type '(x: Float, y: Float)'}}
@differentiable
init(x: Float, y: Float)
// expected-note @+2 {{protocol requires initializer 'init(x:y:)' with type '(x: Float, y: Int)'}}
@differentiable(wrt: x)
init(x: Float, y: Int)
// expected-note @+2 {{protocol requires function 'amb(x:y:)' with type '(Float, Float) -> Float';}}
@differentiable
func amb(x: Float, y: Float) -> Float
// expected-note @+2 {{protocol requires function 'amb(x:y:)' with type '(Float, Int) -> Float';}}
@differentiable(wrt: x)
func amb(x: Float, y: Int) -> Float
// expected-note @+3 {{protocol requires function 'f1'}}
// expected-note @+2 {{overridden declaration is here}}
@differentiable(wrt: (self, x))
func f1(_ x: Float) -> Float
// expected-note @+2 {{protocol requires function 'f2'}}
@differentiable(wrt: (self, x, y))
func f2(_ x: Float, _ y: Float) -> Float
}
protocol ProtocolRequirementsRefined : ProtocolRequirements {
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}}
func f1(_ x: Float) -> Float
}
// A conforming type whose implementations omit the `@differentiable`
// attributes required by `ProtocolRequirements` fails conformance. `f2` does
// carry an attribute, but with a smaller `wrt:` set than the requirement. The
// FIXME(TF-284) lines mark notes that should not appear.
// expected-error @+1 {{does not conform to protocol 'ProtocolRequirements'}}
struct DiffAttrConformanceErrors : ProtocolRequirements {
var x: Float
var y: Float
// FIXME(TF-284): Fix unexpected diagnostic.
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
// expected-note @+1 {{candidate has non-matching type '(x: Float, y: Float)'}}
init(x: Float, y: Float) {
self.x = x
self.y = y
}
// FIXME(TF-284): Fix unexpected diagnostic.
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
// expected-note @+1 {{candidate has non-matching type '(x: Float, y: Int)'}}
init(x: Float, y: Int) {
self.x = x
self.y = Float(y)
}
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
// expected-note @+1 {{candidate has non-matching type '(Float, Float) -> Float'}}
func amb(x: Float, y: Float) -> Float {
return x
}
// expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)'}}
// expected-note @+1 {{candidate has non-matching type '(Float, Int) -> Float'}}
func amb(x: Float, y: Int) -> Float {
return x
}
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
func f1(_ x: Float) -> Float {
return x
}
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
@differentiable(wrt: (self, x))
func f2(_ x: Float, _ y: Float) -> Float {
return x + y
}
}
// A default implementation that lacks `@differentiable`, on a protocol with no
// conforming types. Nothing is diagnosed (see TODO(TF-650)).
protocol ProtocolRequirementsWithDefault_NoConformingTypes {
@differentiable
func f1(_ x: Float) -> Float
}
extension ProtocolRequirementsWithDefault_NoConformingTypes {
// TODO(TF-650): It would be nice to diagnose protocol default implementation
// with missing `@differentiable` attribute.
func f1(_ x: Float) -> Float { x }
}
// Neither the protocol's default implementation nor the conforming type's own
// `f1` carries `@differentiable`, so conformance fails. Both candidates are
// reported.
protocol ProtocolRequirementsWithDefault {
// expected-note @+2 {{protocol requires function 'f1'}}
@differentiable
func f1(_ x: Float) -> Float
}
extension ProtocolRequirementsWithDefault {
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
func f1(_ x: Float) -> Float { x }
}
// expected-error @+1 {{type 'DiffAttrConformanceErrors2' does not conform to protocol 'ProtocolRequirementsWithDefault'}}
struct DiffAttrConformanceErrors2 : ProtocolRequirementsWithDefault {
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
func f1(_ x: Float) -> Float { x }
}
// The protocol does not refine `Differentiable`, but its `@differentiable`
// requirement must still be matched. The conforming type's method lacks the
// attribute, so conformance fails.
protocol NotRefiningDiffable {
@differentiable(wrt: x)
// expected-note @+1 {{protocol requires function 'a' with type '(Float) -> Float'; do you want to add a stub?}}
func a(_ x: Float) -> Float
}
// expected-error @+1 {{type 'CertainlyNotDiffableWrtSelf' does not conform to protocol 'NotRefiningDiffable'}}
struct CertainlyNotDiffableWrtSelf : NotRefiningDiffable {
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
func a(_ x: Float) -> Float { return x * 5.0 }
}
// TF-285: a requirement with two `@differentiable` attributes. Every one must
// be matched; an implementation that provides only `wrt: x` is reported as
// missing `wrt: (x, y)`.
protocol TF285 : Differentiable {
@differentiable(wrt: (x, y))
@differentiable(wrt: x)
// expected-note @+1 {{protocol requires function 'foo(x:y:)' with type '(Float, Float) -> Float'; do you want to add a stub?}}
func foo(x: Float, y: Float) -> Float
}
// expected-error @+1 {{type 'TF285MissingOneDiffAttr' does not conform to protocol 'TF285'}}
struct TF285MissingOneDiffAttr : TF285 {
// Requirement is missing an attribute.
@differentiable(wrt: x)
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))}}
func foo(x: Float, y: Float) -> Float {
return x
}
}
// TF-521: Test invalid `@differentiable` attribute due to invalid
// `Differentiable` conformance (`TangentVector` does not conform to
// `AdditiveArithmetic`).
// Because the conformance is invalid, `TF_521<T>` is not treated as
// `Differentiable`, both for the initializer's attribute and for a
// `@differentiable` closure that returns `TF_521<Float>`.
struct TF_521<T: FloatingPoint> {
var real: T
var imaginary: T
// expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'TF_521<T>' does not conform to 'Differentiable'}}
@differentiable(vjp: _vjpInit where T: Differentiable, T == T.TangentVector)
init(real: T = 0, imaginary: T = 0) {
self.real = real
self.imaginary = imaginary
}
}
// The conditional conformance names `TF_521` itself as `TangentVector`, and
// `TF_521` has no `AdditiveArithmetic` conformance.
// expected-error @+2 {{type 'TF_521<T>' does not conform to protocol '_Differentiable'}}
// expected-note @+1 {{do you want to add protocol stubs}}
extension TF_521: Differentiable where T: Differentiable {
// expected-note @+1 {{possibly intended match 'TF_521<T>.TangentVector' does not conform to 'AdditiveArithmetic'}}
typealias TangentVector = TF_521
typealias AllDifferentiableVariables = TF_521
}
// VJP named by the initializer's attribute, under the same constraints as its
// `where` clause.
extension TF_521 where T: Differentiable, T == T.TangentVector {
static func _vjpInit(real: T, imaginary: T) -> (TF_521, (TF_521) -> (T, T)) {
return (TF_521(real: real, imaginary: imaginary), { ($0.real, $0.imaginary) })
}
}
// expected-error @+1 {{result is not differentiable, but the function type is marked '@differentiable'}}
let _: @differentiable(Float, Float) -> TF_521<Float> = { r, i in
TF_521(real: r, imaginary: i)
}
// TF-296: Infer `@differentiable` wrt parameters to be all parameters that conform to `Differentiable`.
// Parameters that do not conform (`Int`), or that are function-typed, are
// skipped rather than diagnosed. This holds both in a `Differentiable` struct
// and in a non-`Differentiable` struct.
@differentiable
func infer1(_ a: Float, _ b: Int) -> Float {
return a + Float(b)
}
@differentiable
func infer2(_ fn: @differentiable(Float) -> Float, x: Float) -> Float {
return fn(x)
}
struct DiffableStruct : Differentiable {
var a: Float
@differentiable
func fn(_ b: Float, _ c: Int) -> Float {
return a + b + Float(c)
}
}
struct NonDiffableStruct {
var a: Float
@differentiable
func fn(_ b: Float) -> Float {
return a + b
}
}
// `linear` functions may not name `jvp:` or `vjp:`; the diagnostic points to
// `transpose:` instead. `const3` is not defined in the visible code; only the
// `linear` diagnostic is checked on these lines.
@differentiable(linear, wrt: x, vjp: const3) // expected-error {{cannot specify 'vjp:' or 'jvp:' for linear functions; use 'transpose:' instead}}
func slope1(_ x: Float) -> Float {
return 3 * x
}
@differentiable(linear, wrt: x, jvp: const3) // expected-error {{cannot specify 'vjp:' or 'jvp:' for linear functions; use 'transpose:' instead}}
func slope2(_ x: Float) -> Float {
return 3 * x
}
@differentiable(linear, jvp: const3, vjp: const3) // expected-error {{cannot specify 'vjp:' or 'jvp:' for linear functions; use 'transpose:' instead}}
func slope3(_ x: Float) -> Float {
return 3 * x
}
// Index based 'wrt:'
// Parameters may be referenced by zero-based index (0 is `x`, 1 is `y`),
// either alone or mixed with names. The combined list must still follow
// declaration order with no repeats. Negative indices and indices past the
// last parameter are rejected.
struct NumberWrtStruct: Differentiable {
var a, b: Float
@differentiable(wrt: 0) // ok
@differentiable(wrt: 1) // ok
func foo1(_ x: Float, _ y: Float) -> Float {
return a*x + b*y
}
@differentiable(wrt: -1) // expected-error {{expected a parameter, which can be a function parameter name, parameter index, or 'self'}}
@differentiable(wrt: (1, x)) // expected-error {{parameters must be specified in original order}}
func foo2(_ x: Float, _ y: Float) -> Float {
return a*x + b*y
}
@differentiable(wrt: (x, 1)) // ok
@differentiable(wrt: (0)) // ok
static func staticFoo1(_ x: Float, _ y: Float) -> Float {
return x + y
}
@differentiable(wrt: (1, 1)) // expected-error {{parameters must be specified in original order}}
@differentiable(wrt: (2)) // expected-error {{parameter index is larger than total number of parameters}}
static func staticFoo2(_ x: Float, _ y: Float) -> Float {
return x + y
}
}
// Name-based and index-based `wrt:` on a two-parameter free function
// (index 0 is `x`, index 1 is `y`). `two1` through `two5` are valid spellings;
// `two6` uses an out-of-range index; `two7` through `two9` list the parameters
// out of order.
@differentiable(wrt: y) // ok
func two1(x: Float, y: Float) -> Float {
return x + y
}
@differentiable(wrt: (x, y)) // ok
func two2(x: Float, y: Float) -> Float {
return x + y
}
@differentiable(wrt: (0, y)) // ok
func two3(x: Float, y: Float) -> Float {
return x + y
}
@differentiable(wrt: (x, 1)) // ok
func two4(x: Float, y: Float) -> Float {
return x + y
}
@differentiable(wrt: (0, 1)) // ok
func two5(x: Float, y: Float) -> Float {
return x + y
}
@differentiable(wrt: 2) // expected-error {{parameter index is larger than total number of parameters}}
func two6(x: Float, y: Float) -> Float {
return x + y
}
@differentiable(wrt: (1, 0)) // expected-error {{parameters must be specified in original order}}
func two7(x: Float, y: Float) -> Float {
return x + y
}
@differentiable(wrt: (1, x)) // expected-error {{parameters must be specified in original order}}
func two8(x: Float, y: Float) -> Float {
return x + y
}
@differentiable(wrt: (y, 0)) // expected-error {{parameters must be specified in original order}}
func two9(x: Float, y: Float) -> Float {
return x + y
}
// Inout 'wrt:' arguments.
// With a `Void` result, the void-function diagnostic is the one reported. With
// a `Float` result, the `inout` parameter itself is rejected as a
// differentiation parameter.
// NOTE(review): `inout2` has no `return` statement. This file is only
// type-checked (`-typecheck`), which presumably is why no missing-return
// diagnostic is reported; confirm before reusing this fixture elsewhere.
@differentiable(wrt: y) // expected-error {{cannot differentiate void function 'inout1(x:y:)'}}
func inout1(x: Float, y: inout Float) -> Void {
let _ = x + y
}
@differentiable(wrt: y) // expected-error {{'inout' parameters ('inout Float') cannot be differentiated with respect to}}
func inout2(x: Float, y: inout Float) -> Float {
let _ = x + y
}
// Test refining protocol requirements with `@differentiable` attribute.
// `DifferentiableDistribution` redeclares `logProbability(of:)` with
// `@differentiable(wrt: self)`. A further refinement that redeclares it with
// no attribute at all is an error.
public protocol Distribution {
associatedtype Value
func logProbability(of value: Value) -> Float
}
public protocol DifferentiableDistribution: Differentiable, Distribution {
// expected-note @+2 {{overridden declaration is here}}
@differentiable(wrt: self)
func logProbability(of value: Value) -> Float
}
// Redeclaring the requirement without the inherited attribute.
// NOTE(review): this section was previously titled "Adding a more general
// `@differentiable` attribute", but the redeclaration below carries no
// attribute; confirm the intended scenario.
public protocol DoubleDifferentiableDistribution: DifferentiableDistribution
where Value: Differentiable {
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable(wrt: self)'}}
func logProbability(of value: Value) -> Float
}
// Test protocol requirement `@differentiable` attribute unsupported features.
// A `where` clause, and `jvp:`/`vjp:` arguments, are rejected on protocol
// requirements. The extension declares `dfoo`, presumably so the derivative
// name resolves and only the unsupported-feature diagnostic is reported.
protocol ProtocolRequirementUnsupported : Differentiable {
associatedtype Scalar
// expected-error @+1 {{'@differentiable' attribute on protocol requirement cannot specify 'where' clause}}
@differentiable(where Scalar: Differentiable)
func unsupportedWhereClause(value: Scalar) -> Float
// expected-error @+1 {{'@differentiable' attribute on protocol requirement cannot specify 'jvp:' or 'vjp:'}}
@differentiable(wrt: x, jvp: dfoo, vjp: dfoo)
func unsupportedDerivatives(_ x: Float) -> Float
}
extension ProtocolRequirementUnsupported {
func dfoo(_ x: Float) -> (Float, (Float) -> Float) {
(x, { $0 })
}
}
// Classes.
// `Super` covers: class initializers (not yet supported); a method whose two
// attributes must be repeated by overrides (see `Sub`); a method with a
// superclass VJP; and a method returning `Self`, which is rejected.
class Super : Differentiable {
var base: Float
// NOTE(TF-654): Class initializers are not yet supported.
// expected-error @+1 {{'@differentiable' attribute does not yet support class initializers}}
@differentiable
init(base: Float) {
self.base = base
}
// Both attributes below must be restated by an override.
@differentiable(wrt: (self, x))
@differentiable(wrt: x, vjp: vjp)
// expected-note @+1 2 {{overridden declaration is here}}
func testMissingAttributes(_ x: Float) -> Float { x }
@differentiable(wrt: x, vjp: vjp)
func testSuperclassDerivatives(_ x: Float) -> Float { x }
// Shared VJP for the two methods above.
final func vjp(_ x: Float) -> (Float, (Float) -> Float) {
fatalError()
}
// expected-error @+1 {{'@differentiable' attribute cannot be declared on class methods returning 'Self'}}
@differentiable(vjp: vjpDynamicSelfResult)
func dynamicSelfResult() -> Self { self }
// TODO(TF-632): Fix "'TangentVector' is not a member type of 'Self'" diagnostic.
// The underlying error should appear instead:
// "covariant 'Self' can only appear at the top level of method result type".
// expected-error @+1 2 {{'TangentVector' is not a member type of 'Self'}}
func vjpDynamicSelfResult() -> (Self, (Self.TangentVector) -> Self.TangentVector) {
return (self, { $0 })
}
}
// Overrides in a subclass:
// - omitting the superclass method's `@differentiable` attributes is an error
//   (one error per missing attribute);
// - a derivative name in the override's attribute is not resolved to the
//   superclass's `final func vjp`; it is reported as not defined in the
//   current type context.
class Sub : Super {
// expected-error @+2 {{overriding declaration is missing attribute '@differentiable(wrt: x)'}}
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}}
override func testMissingAttributes(_ x: Float) -> Float { x }
// expected-error @+1 {{'vjp' is not defined in the current type context}}
@differentiable(wrt: x, vjp: vjp)
override func testSuperclassDerivatives(_ x: Float) -> Float { x }
}