blob: b509583e6e4bef20e26262916e70313d18a61a08 [file] [log] [blame]
// RUN: %target-swift-frontend -emit-sil -verify -verify-ignore-unknown %s
// FIXME(TF-201): Remove `-verify-ignore-unknown`. This is currently necessary
// due to direct differentiation of reabstraction thunks, which emits errors
// with unknown location.
// `generic` is declared @differentiable, but its body adds a literal to `x`
// using a `+` that is a protocol requirement reached through the
// `FloatingPoint` constraint. The diagnostics annotated below check that the
// compiler reports that requirement as not '@differentiable' and rejects the
// `return` expression.
@differentiable
func generic<T: Differentiable & FloatingPoint>(_ x: T) -> T {
// expected-error @+2 {{expression is not differentiable}}
// expected-note @+1 {{member is not differentiable because the corresponding protocol requirement is not '@differentiable'}}
return x + 1
}
// Passes `generic` directly to `gradient(at:in:)`; `T` is inferred from the
// literal `1.0`.
_ = gradient(at: 1.0, in: generic)
// Test unmet generic requirements.
/// Identity function whose @differentiable attribute registers a custom VJP
/// (`vjpWeirdExtraRequirements`) under a `where` clause that is stricter than
/// the function's own signature. `T` is unconstrained here, but the derivative
/// only applies when `T : Differentiable & CaseIterable` and
/// `T.AllCases : ExpressibleByStringLiteral`.
@differentiable(
vjp: vjpWeirdExtraRequirements
where T : Differentiable & CaseIterable, T.AllCases : ExpressibleByStringLiteral
)
func weird<T>(_ x: T) -> T {
return x
}
/// Custom VJP registered for `weird`. Its generic signature repeats the extra
/// requirements from `weird`'s @differentiable `where` clause.
///
/// - Parameter x: The original input.
/// - Returns: `x` unchanged, paired with a pullback that returns the incoming
///   tangent unchanged (the derivative of the identity).
func vjpWeirdExtraRequirements<
T : Differentiable & CaseIterable
>(_ x: T) -> (T, (T.TangentVector) -> T.TangentVector)
where T.AllCases : ExpressibleByStringLiteral
{
return (x, { $0 })
}
/// Calls `weird` with a `T` known only to be `Differentiable`. The extra
/// requirements attached to `weird`'s derivative (`T : CaseIterable`,
/// `T.AllCases : ExpressibleByStringLiteral`) are therefore not met. The
/// diagnostics annotated below check that this call is reported as not
/// differentiable and that the unmet requirements are listed.
func weirdWrapper<T : Differentiable>(_ x: T) -> T {
// expected-error @+2 {{expression is not differentiable}}
// expected-note @+1 {{function call is not differentiable because generic requirements are not met: 'T : CaseIterable, T.AllCases : ExpressibleByStringLiteral'}}
return weird(x)
}
// Differentiates `weirdWrapper` with `T == Float`, which exercises the call
// to `weird` above.
_ = gradient(at: Float(1), in: { x in weirdWrapper(x) })
/// Generic @differentiable identity function constrained only to
/// `Differentiable`. No diagnostics are annotated for it, so it is presumably
/// expected to differentiate without errors.
@differentiable
func direct<T : Differentiable>(_ x: T) -> T {
return x
}
/// Minimal stand-in for a generic tensor type used by the `TF_6` test.
///
/// It has no stored properties. Its `+` operator is a stub that ignores `rhs`
/// and returns `lhs`, which gives differentiation an operator call to go
/// through.
struct Tensor<Scalar> {
  // Operator parameters never carry argument labels, so the explicit `_` that
  // used to precede `lhs` was redundant. Both parameters are now declared the
  // same way.
  static func + (lhs: Tensor, rhs: Scalar) -> Tensor { return lhs }
}
// Conditional `Differentiable` conformance for `Tensor`. The empty body means
// the requirements are satisfied without hand-written members (presumably
// through compiler derivation; confirm).
extension Tensor : Differentiable where Scalar : Differentiable & FloatingPoint {}
// `TF_6` (the name presumably refers to issue TF-6) checks a @differentiable
// attribute whose `where` clause adds `Scalar : Differentiable` on top of this
// extension's `Scalar : BinaryFloatingPoint`. Because `BinaryFloatingPoint`
// refines `FloatingPoint`, the combination satisfies the conditional
// conformance above. Only `self` is a differentiation parameter; the `Float`
// argument `x` is converted to `Scalar` before being added.
extension Tensor where Scalar : BinaryFloatingPoint {
@differentiable(wrt: (self) where Scalar : Differentiable)
func TF_6(_ x: Float) -> Tensor {
return self + Scalar(x)
}
}
/// Protocol refining `Differentiable`. It declares an associated type `Scalar`
/// that the method requirement below does not use, and a method requirement
/// marked @differentiable with respect to both `self` and `input`.
protocol TF8Proto : Differentiable {
associatedtype Scalar
@differentiable(wrt: (self, input))
func applied(to input: Float) -> Float
}
/// Generic conformer to `TF8Proto`. Its generic parameter `Scalar` presumably
/// witnesses the protocol's associated type of the same name. The only stored
/// property has type `Scalar` and is marked `@noDerivative`. The requirement
/// is implemented with the same @differentiable attribute and returns `input`
/// unchanged.
struct TF8Struct<Scalar> : TF8Proto where Scalar : FloatingPoint & Differentiable {
@noDerivative let bar: Scalar
@differentiable(wrt: (self, input))
func applied(to input: Float) -> Float {
return input
}
}
// Gradient of a closure that calls `squareRoot()` on its input. The argument
// type is inferred from the literal `1.0`, and no diagnostics are annotated
// for this call.
_ = gradient(at: 1.0, in: { x in x.squareRoot() })
//===----------------------------------------------------------------------===//
// Non-differentiable arguments and results
//===----------------------------------------------------------------------===//
/// Differentiable struct with a generic `@noDerivative` stored property and a
/// differentiable `Float` stored property. The closure after it instantiates
/// the struct with `T == Any`.
struct TF_687<T> : Differentiable {
@noDerivative var indirectDummy: T
var base: Float
init(_ base: Float, dummy: T) {
self.base = base
self.indirectDummy = dummy
}
}
// The closure forwards the differentiable input `x` as the `dummy:` argument,
// whose parameter type is `Any`. The diagnostics annotated below check that
// this is reported as a non-differentiable argument and that
// `withoutDerivative(at:)` is suggested.
// expected-error @+2 {{function is not differentiable}}
// expected-note @+1 {{cannot differentiate through a non-differentiable argument; do you want to use 'withoutDerivative(at:)'?}}
let _: @differentiable (Float) -> TF_687<Any> = { x in TF_687<Any>(x, dummy: x) }
//===----------------------------------------------------------------------===//
// Add `Differentiable` conformance for generic wrt parameters
//===----------------------------------------------------------------------===//
// Unconstrained generic identity function. The @differentiable closure below
// calls it with `T == Float`, so differentiation has to treat the generic
// parameter as differentiable at that call. No diagnostics are annotated.
func id<T>(_ x: T) -> T { x }
let _: @differentiable (Float) -> Float = { x in id(x) }
/// Generic wrapper around a single stored value. It is used to check generic
/// differentiation parameters whose `Differentiable` conformance comes from a
/// conditional extension.
struct TF_691<Scalar> {
var x: Scalar
init(_ x: Scalar) {
self.x = x
}
}
// Conditional conformance: `TF_691<Scalar>` is `Differentiable` whenever
// `Scalar` is. The empty body leaves the requirements to the compiler
// (presumably derived; confirm).
extension TF_691: Differentiable where Scalar: Differentiable {}
// Generic identity over `TF_691<T>`. It places no `Differentiable` constraint
// on `T`.
func identity<T>(_ x: TF_691<T>) -> TF_691<T> { x }
// Both closures wrap `x` in a `TF_691<Float>` and pass it through an
// unconstrained generic identity function. No diagnostics are annotated for
// either.
let _: @differentiable (Float) -> TF_691<Float> = { x in identity(TF_691(x)) }
let _: @differentiable (Float) -> TF_691<Float> = { x in id(TF_691(x)) }