blob: 714487f378608c389869f4cae54e591e64080d57 [file] [log] [blame]
// RUN: %target-swift-frontend -typecheck -verify %s
import _Differentiation
//===----------------------------------------------------------------------===//
// Basic @differentiable function types.
//===----------------------------------------------------------------------===//
// The attribute is only meaningful on function types; a bare value type is rejected.
// expected-error @+1 {{@differentiable attribute only applies to function types}}
let _: @differentiable Float
// Both non-throwing and throwing function types accept the attribute silently.
let _: @differentiable (Float) -> Float
let _: @differentiable (Float) throws -> Float
//===----------------------------------------------------------------------===//
// Type differentiability
//===----------------------------------------------------------------------===//
/// Plain struct wrapping an `Int`; it does not conform to `Differentiable` and is
/// used as a deliberately non-differentiable parameter/result type in this file.
struct NonDiffType {
  var x: Int
}
// FIXME: Properly type-check parameters and the result's differentiability
// A non-`Differentiable` parameter is rejected; with no other valid parameter,
// no `@noDerivative` fix-it is offered.
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
let _: @differentiable (NonDiffType) -> Float
// Emit `@noDerivative` fixit iff there is at least one valid differentiability parameter.
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'; did you want to add '@noDerivative' to this parameter?}} {{32-32=@noDerivative }}
let _: @differentiable (Float, NonDiffType) -> Float
// For `@differentiable(linear)`, the result must also be its own `TangentVector`.
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
let _: @differentiable(linear) (Float) -> NonDiffType
// Emit `@noDerivative` fixit iff there is at least one valid linearity parameter.
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(linear)'; did you want to add '@noDerivative' to this parameter?}} {{40-40=@noDerivative }}
let _: @differentiable(linear) (Float, NonDiffType) -> Float
// A non-`Differentiable` result of a plain `@differentiable` function type is rejected.
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
let _: @differentiable (Float) -> NonDiffType
// NOTE(review): this pair repeats the `@differentiable(linear) (Float) -> NonDiffType`
// case a few lines up verbatim; it adds no coverage and could likely be dropped.
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
let _: @differentiable(linear) (Float) -> NonDiffType
// A `@differentiable(linear)` function from `Float` to `Float` is accepted.
let _: @differentiable(linear) (Float) -> Float
// Curried forms: the inner function type is the outer function's result, and the
// diagnostics below show it is rejected as a non-`Differentiable` result whether
// or not it is itself marked `@differentiable`.
// expected-error @+1 {{result type '@differentiable (U) -> Float' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
func test1<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> @differentiable (U) -> Float) {}
// expected-error @+1 {{result type '(U) -> Float' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
func test2<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> (U) -> Float) {}
// With an `Int` innermost result, the inner `@differentiable` type is diagnosed too,
// giving two separate diagnostics on the same line.
// expected-error @+2 {{result type 'Int' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
// expected-error @+1 {{result type '@differentiable (U) -> Int' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
func test3<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> @differentiable (U) -> Int) {}
// expected-error @+1 {{result type '(U) -> Int' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
func test4<T: Differentiable, U: Differentiable>(_: @differentiable (T) -> (U) -> Int) {}
//===----------------------------------------------------------------------===//
// Function conversion
//===----------------------------------------------------------------------===//
/// Function with similar signature as `gradient`, for testing purposes.
/// The body is intentionally empty: the file only exercises which arguments may
/// be converted to its `@differentiable (T) -> U` parameter.
func fakeGradient<T, U: FloatingPoint>(of f: @differentiable (T) -> U) {}
// An escaping closure parameter is an opaque function value, neither a `func`/`init`
// reference nor a literal closure, so it cannot be passed where a `@differentiable`
// function is required. The note suggests marking the parameter type instead.
func takesOpaqueClosure(f: @escaping (Float) -> Float) {
// expected-note @-1 {{did you mean to take a '@differentiable' closure?}} {{38-38=@differentiable }}
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or 'init' or a literal closure}}
fakeGradient(of: f)
}
// A global constant holding a closure is an opaque value rather than a `func`
// reference or a literal closure, so converting it is rejected.
let globalAddOne: (Float) -> Float = { $0 + 1 }
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or 'init' or a literal closure}}
fakeGradient(of: globalAddOne)
// The same conversion rules apply in a local scope. Both the global and a local
// closure-valued constant are rejected. A literal closure that merely calls the
// local constant is accepted here.
func someScope() {
let localAddOne: (Float) -> Float = { $0 + 1 }
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or 'init' or a literal closure}}
fakeGradient(of: globalAddOne)
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or 'init' or a literal closure}}
fakeGradient(of: localAddOne)
// The following case is okay during type checking, but will fail in the AD transform.
fakeGradient { localAddOne($0) }
}
/// Global function returning `x + 1`; a reference to it is a valid source for a
/// `@differentiable` function value.
func addOne(x: Float) -> Float {
  return x + 1
}
// A reference to a global `func` converts to `@differentiable` without error.
fakeGradient(of: addOne) // okay
/// Static and instance `addOne(x:)` on `Float`, both returning `x + 1` (the
/// instance method ignores `self`). They are used as `func` references that
/// convert to `@differentiable` function values.
extension Float {
  static func addOne(x: Float) -> Float {
    return x + 1
  }

  func addOne(x: Float) -> Float {
    return x + 1
  }
}
// References to both a static method and a bound instance method convert cleanly.
fakeGradient(of: Float.addOne) // okay
fakeGradient(of: Float(1.0).addOne) // okay
// TODO(TF-908): Remove this test once linear-to-differentiable conversion is supported.
// Until then, an `as` cast from a `@differentiable(linear)` value to a plain
// `@differentiable` function type is diagnosed as not yet supported.
func linearToDifferentiable(_ f: @escaping @differentiable(linear) (Float) -> Float) {
// expected-error @+1 {{conversion from '@differentiable(linear)' to '@differentiable' is not yet supported}}
_ = f as @differentiable (Float) -> Float
}
// The reverse direction fails for a different reason: `f` is an opaque parameter,
// not a `func`/`init` reference or a literal closure, so no
// `@differentiable(linear)` function can be formed from it.
func differentiableToLinear(_ f: @escaping @differentiable (Float) -> Float) {
// expected-error @+1 {{a '@differentiable(linear)' function can only be formed from a reference to a 'func' or 'init' or a literal closure}}
_ = f as @differentiable(linear) (Float) -> Float
}
/// Minimal `Differentiable` struct with a single `Float` field. No conformance
/// members are written out, so they are presumably compiler-synthesized.
struct Struct: Differentiable {
var x: Float
}
// A reference to an `init` (the memberwise initializer) is accepted as a
// `@differentiable` function value.
let _: @differentiable (Float) -> Struct = Struct.init
//===----------------------------------------------------------------------===//
// Parameter selection (@noDerivative)
//===----------------------------------------------------------------------===//
// `@noDerivative` is rejected anywhere except on a parameter of a
// `@differentiable` function type.
// expected-error @+1 {{'@noDerivative' may only be used on parameters of '@differentiable' function types}}
let _: @noDerivative Float
// expected-error @+1 {{'@noDerivative' may only be used on parameters of '@differentiable' function types}}
let _: (@noDerivative Float) -> Float
// expected-error @+1 {{'@noDerivative' may only be used on parameters of '@differentiable' function types}}
let _: (@noDerivative Float, Float) -> Float
let _: @differentiable (Float, @noDerivative Float) -> Float // okay
// Marking every parameter `@noDerivative` leaves no differentiability parameter,
// which is rejected.
// expected-error @+1 {{'@differentiable' function type requires at least one differentiability parameter, i.e. a non-'@noDerivative' parameter whose type conforms to 'Differentiable'}}
let _: @differentiable (@noDerivative Float) -> Float
// expected-error @+1 {{'@differentiable' function type requires at least one differentiability parameter, i.e. a non-'@noDerivative' parameter whose type conforms to 'Differentiable'}}
let _: @differentiable (@noDerivative Float, @noDerivative Int) -> Float
// expected-error @+1 {{'@differentiable' function type requires at least one differentiability parameter, i.e. a non-'@noDerivative' parameter whose type conforms to 'Differentiable'}}
let _: @differentiable (@noDerivative Float, @noDerivative Float) -> Float
// The remaining `Int` parameter fails the linearity requirement.
// expected-error @+1 {{parameter type 'Int' does not conform to 'Differentiable' and satisfy 'Int == Int.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
let _: @differentiable(linear) (@noDerivative Float, Int) -> Float
// `@noDerivative` in result position is rejected, with or without `@differentiable`.
// expected-error @+1 {{'@noDerivative' may only be used on parameters of '@differentiable' function types}}
let _: (Float) -> @noDerivative Float
// expected-error @+1 {{'@noDerivative' may only be used on parameters of '@differentiable' function types}}
let _: @differentiable (Float) -> @noDerivative Float
// A variadic `@noDerivative` parameter outside a `@differentiable` type triggers
// both the placement and the variadic diagnostics.
// expected-error @+2 {{'@noDerivative' may only be used on parameters of '@differentiable' function types}}
// expected-error @+1 {{'@noDerivative' must not be used on variadic parameters}}
let _: (Float, @noDerivative Float...) -> Float
// A leading `@noDerivative` parameter is fine when a later parameter is differentiable.
let _: @differentiable (@noDerivative Float, Float) -> Float
// Variadic parameters cannot be `@noDerivative` even inside a `@differentiable` type.
// expected-error @+1 {{'@noDerivative' must not be used on variadic parameters}}
let _: @differentiable (Float, @noDerivative Float...) -> Float
//===----------------------------------------------------------------------===//
// Inferred conformances
//===----------------------------------------------------------------------===//
// Generic parameters used in a `@differentiable` function type need no explicit
// `Differentiable` constraints here. The declarations and calls below type-check
// cleanly. The requirements are presumably inferred from the function type, as
// the section title suggests.
let diffFunc: @differentiable (Float) -> Float
let linearFunc: @differentiable(linear) (Float) -> Float
func inferredConformances<T, U>(_: @differentiable (T) -> U) {}
func inferredConformancesLinear<T, U>(_: @differentiable(linear) (T) -> U) {}
inferredConformances(diffFunc)
inferredConformancesLinear(linearFunc)
// The same holds in result position (empty bodies are only type-checked here).
func inferredConformancesResult<T, U>() -> @differentiable (T) -> U {}
func inferredConformancesResultLinear<T, U>() -> @differentiable(linear) (T) -> U {}
// A `@noDerivative` parameter of otherwise unconstrained type `U` is also accepted.
let diffFuncWithNondiff: @differentiable (Float, @noDerivative Int) -> Float
let linearFuncWithNondiff: @differentiable(linear) (Float, @noDerivative Int) -> Float
func inferredConformances<T, U, V>(_: @differentiable (T, @noDerivative U) -> V) {}
func inferredConformancesLinear<T, U, V>(_: @differentiable(linear) (T, @noDerivative U) -> V) {}
inferredConformances(diffFuncWithNondiff)
inferredConformancesLinear(linearFuncWithNondiff)
/// Generic two-component vector. It gains `Equatable`, `AdditiveArithmetic` and
/// `Differentiable` only conditionally, through separate extensions.
struct Vector<T> {
  var x: T
  var y: T
}
// Conditional conformances for `Vector`. The `AdditiveArithmetic` members are
// trapping stubs; that suffices because the RUN line only type-checks this file.
extension Vector: Equatable where T: Equatable {}
extension Vector: AdditiveArithmetic where T: AdditiveArithmetic {
static var zero: Self { fatalError() }
static func + (lhs: Self, rhs: Self) -> Self { fatalError() }
static func - (lhs: Self, rhs: Self) -> Self { fatalError() }
}
// `Differentiable` conformance whose `TangentVector` is a separate nested struct.
// `Vector<T>` is therefore never its own `TangentVector`, so it cannot satisfy
// the `@differentiable(linear)` requirement.
extension Vector: Differentiable where T: Differentiable {
struct TangentVector: Equatable, AdditiveArithmetic, Differentiable {
var x, y: T.TangentVector
// Trapping stubs; never executed since the file is only type-checked.
static var zero: Self { fatalError() }
static func + (lhs: Self, rhs: Self) -> Self { fatalError() }
static func - (lhs: Self, rhs: Self) -> Self { fatalError() }
typealias TangentVector = Self
}
mutating func move(along direction: TangentVector) { fatalError() }
}
// `Vector` overload of the generic helper. It is presumably listed as a candidate
// once for each of the two failing `inferredConformancesGeneric(nondiff)` calls,
// which would explain the count of 2 on its note.
// expected-note@+1 2 {{found this candidate}}
func inferredConformancesGeneric<T, U>(_: @differentiable (Vector<T>) -> Vector<U>) {}
// With `linear`, `Vector<T>` and `Vector<U>` would need to equal their own
// `TangentVector`, which `Vector`'s conformance never provides. The declaration
// itself is therefore diagnosed, in addition to acting as a call candidate.
// expected-note @+5 2 {{found this candidate}}
// expected-error @+4 {{generic signature requires types 'Vector<T>' and 'Vector<T>.TangentVector' to be the same}}
// expected-error @+3 {{generic signature requires types 'Vector<U>' and 'Vector<U>.TangentVector' to be the same}}
// expected-error @+2 {{parameter type 'Vector<T>' does not conform to 'Differentiable' and satisfy 'Vector<T> == Vector<T>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
// expected-error @+1 {{result type 'Vector<U>' does not conform to 'Differentiable' and satisfy 'Vector<U> == Vector<U>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
func inferredConformancesGenericLinear<T, U>(_: @differentiable(linear) (Vector<T>) -> Vector<U>) {}
// `Int`-based overload; `Vector<Int>` is not `Differentiable`. The empty body with
// a non-Void result is presumably tolerated because the RUN line stops at type checking.
func nondiff(x: Vector<Int>) -> Vector<Int> {}
// TODO(diagnostics): Ambiguity notes for the two following calls should say that `T` and `U` both fail to conform to `Differentiable`,
// but we currently have no way to coalesce notes for multiple fixes into a single note.
// No overload accepts the `Int`-based `nondiff`.
// expected-error @+1 {{no exact matches in call to global function 'inferredConformancesGeneric'}}
inferredConformancesGeneric(nondiff)
// expected-error @+1 {{no exact matches in call to global function 'inferredConformancesGenericLinear'}}
inferredConformancesGenericLinear(nondiff)
// `Float`-based overload: `Vector<Float>` is `Differentiable`, so the call succeeds.
func diff(x: Vector<Float>) -> Vector<Float> {}
inferredConformancesGeneric(diff) // okay!
// Result-position variant over `Vector`; accepted without explicit constraints.
func inferredConformancesGenericResult<T, U>() -> @differentiable (Vector<T>) -> Vector<U> {}
// Result-position `linear` variant over `Vector`; it fails for the same reason as
// the parameter-position `linear` overload.
// expected-error @+4 {{generic signature requires types 'Vector<T>' and 'Vector<T>.TangentVector' to be the same}}
// expected-error @+3 {{generic signature requires types 'Vector<U>' and 'Vector<U>.TangentVector' to be the same}}
// expected-error @+2 {{parameter type 'Vector<T>' does not conform to 'Differentiable' and satisfy 'Vector<T> == Vector<T>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
// expected-error @+1 {{result type 'Vector<U>' does not conform to 'Differentiable' and satisfy 'Vector<U> == Vector<U>.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
func inferredConformancesGenericResultLinear<T, U>() -> @differentiable(linear) (Vector<T>) -> Vector<U> {}
/// Generic two-component vector which, unlike `Vector`, uses itself as its
/// `TangentVector` when its `Differentiable` conformance applies.
struct Linear<T> {
  var x: T
  var y: T
}
// Conditional conformances for `Linear`. No members are written for `Equatable`
// or `AdditiveArithmetic`, so they are presumably synthesized. `Differentiable`
// additionally requires `T == T.TangentVector` and makes `Linear` its own
// `TangentVector`.
extension Linear: Equatable where T: Equatable {}
extension Linear: AdditiveArithmetic where T: AdditiveArithmetic {}
extension Linear: Differentiable where T: Differentiable, T == T.TangentVector {
typealias TangentVector = Self
}
// `Linear` overload of the generic helper; another candidate for the failing calls.
// expected-note @+1 2 {{found this candidate}}
func inferredConformancesGeneric<T, U>(_: @differentiable (Linear<T>) -> Linear<U>) {}
// `linear` overload over `Linear`. Unlike the `Vector` version, the declaration
// itself is accepted, presumably because `Linear` is its own `TangentVector`.
// expected-note @+1 2 {{found this candidate}}
func inferredConformancesGenericLinear<T, U>(_: @differentiable(linear) (Linear<T>) -> Linear<U>) {}
// `Int`-based `Linear` overload of `nondiff`; `Int` is not `Differentiable`.
func nondiff(x: Linear<Int>) -> Linear<Int> {}
// Neither overload accepts either `Int`-based `nondiff`.
// expected-error @+1 {{no exact matches in call to global function 'inferredConformancesGeneric'}}
inferredConformancesGeneric(nondiff)
// expected-error @+1 {{no exact matches in call to global function 'inferredConformancesGenericLinear'}}
inferredConformancesGenericLinear(nondiff)
// `Float`-based `Linear` overload of `diff`; the call below type-checks.
func diff(x: Linear<Float>) -> Linear<Float> {}
inferredConformancesGeneric(diff) // okay!
// Result-position variants over `Linear`; both are accepted, including `linear`.
func inferredConformancesGenericResult<T, U>() -> @differentiable (Linear<T>) -> Linear<U> {}
func inferredConformancesGenericResultLinear<T, U>() -> @differentiable(linear) (Linear<T>) -> Linear<U> {}