// RUN: %target-swift-emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-SIL
import _Differentiation
/// Generic identity function: returns its argument unchanged.
///
/// - Parameter x: Any value whose type conforms to `Differentiable`.
/// - Returns: `x`.
func identity<T : Differentiable>(_ x: T) -> T {
return x
// Differentiate `identity` through a closure, with `T` bound to `Float`.
// The result is discarded.
_ = gradient(at: Float(1), in: { x in identity(x) })
// Test PullbackCloner local buffer allocation.
// Verify that local buffers are immediately set to zero.
// CHECK-SIL-LABEL: sil private @AD__identity__pullback_src_0_wrt_0_{{16_Differentiation|s}}14DifferentiableRzl
// CHECK-SIL: [[ORIG_COTAN:%.*]] = alloc_stack $τ_0_0.TangentVector
// CHECK-SIL-NEXT: [[ZERO_WITNESS:%.*]] = witness_method $τ_0_0.TangentVector,!getter
// CHECK-SIL-NEXT: [[ORIG_COTAN_METATYPE:%.*]] = metatype $@thick τ_0_0.TangentVector.Type
// Test TF-201: differentiate direct references to generic function.
// This involves reabstraction thunk differentiation.
// Unlike the closure-based call earlier in the file, `identity` is passed here
// by name rather than being wrapped in a closure.
_ = gradient(at: Float(1), in: identity)
/// Protocol refining both `Differentiable` and `AdditiveArithmetic`, and
/// declaring `+` as one of its own requirements.
protocol DifferentiableAdditiveArithmetic: Differentiable & AdditiveArithmetic {
static func + (lhs: Self, rhs: Self) -> Self
// Empty conformance: no members are added in the extension body.
extension Float: DifferentiableAdditiveArithmetic {}
/// Adds `x` to itself twice using the `+` requirement of
/// `DifferentiableAdditiveArithmetic`.
///
/// - Parameter x: The value to add up.
/// - Returns: `x + x + x`.
func generic<T: DifferentiableAdditiveArithmetic>(_ x: T) -> T {
x + x + x
// Differentiate `generic` by direct reference at a `Float` input.
_ = gradient(at: Float(10), in: generic)
/// Generic `Differentiable` struct holding a single `Differentiable` value.
struct Wrapper<Scalar : Differentiable> : Differentiable {
// The only stored property.
var value: Scalar
/// Stores `value` in the `value` property. The argument label is omitted.
init(_ value: Scalar) { self.value = value }
/// Overload of `generic` that unwraps a `Wrapper`.
///
/// `T` carries no explicit constraint in this signature; `Wrapper` itself is
/// declared with `Scalar : Differentiable`.
///
/// - Parameter x: The wrapper to read from.
/// - Returns: `x.value`.
func generic<T>(_ x: Wrapper<T>) -> T {
return x.value
// Differentiate the `Wrapper` overload by direct reference.
_ = gradient(at: Wrapper<Float>(1), in: generic)
/// Three-parameter function mixing two generic parameters with a concrete
/// `Float`. Only the first argument contributes to the result.
///
/// - Parameters:
///   - x: Returned unchanged.
///   - y: Unused in the body.
///   - z: Unused in the body.
/// - Returns: `x`.
func generic2<T: Differentiable, U: Differentiable>(_ x: T, _ y: Float, _ z: U) -> T {
return x
/// Differentiates `generic2` from inside a generic context.
///
/// - Parameter x: Passed as the third differentiation point, so `generic2`'s
///   `U` is the still-generic `Wrapper<T>`.
func foo<T>(_ x: Wrapper<T>) {
// The three points are `Float(1)`, the literal `2`, and `x`.
_ = gradient(at: Float(1), 2, x, in: generic2)
// Test case where associated derivative function's requirements are met.
/// Adds `mean()` and `variance()` to `Wrapper` when `Scalar` is `Numeric`.
/// Each method's `@differentiable` attribute carries its own `where` clause
/// (`Scalar : Differentiable & FloatingPoint`) on top of the extension's.
extension Wrapper where Scalar : Numeric {
/// Returns `self` unchanged. Differentiable with respect to `self`.
@differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
func mean() -> Wrapper {
return self
/// Returns `mean()`. The attribute's requirements are identical to the ones
/// written on `mean()`.
@differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
func variance() -> Wrapper {
return mean() // ok
// Request the pullback of `variance()` at a concrete `Wrapper<Float>`.
_ = pullback(at: Wrapper<Float>(1), in: { $0.variance() })
// Tests TF-277.
/// `Differentiable` protocol with a single `Differentiable` associated type.
/// No other requirements are visible in this view.
protocol Layer : Differentiable {
associatedtype Output : Differentiable
/// Holds a model and a differentiable loss function over the model's
/// `Output` associated type.
struct SupervisedTrainer<Model : Layer> {
var model: Model
// Stored `@differentiable` function value taking two `Model.Output`s.
var lossFunction: @differentiable (Model.Output, Model.Output) -> Float
/// Takes the gradient of `lossFunction(y, y)` with respect to `y` and
/// discards it.
///
/// - Parameter y: Used as the differentiation point. The closure parameter
///   of the same name shadows it inside the closure.
func fit(y: Model.Output) {
_ = gradient(at: y) { y in return self.lossFunction(y, y) }
// Tests TF-440.
/// `Differentiable` pair of two independently typed `Differentiable` values.
struct TF_440_Input<Input: Differentiable, State: Differentiable>
: Differentiable {
var input: Input
var state: State
/// Declares three overloads of `applied(to:)` that differ in the input's
/// generic arguments and/or the result type.
struct TF_440<T : Differentiable> {
/// Fully concrete overload: returns the `Float` state.
func applied(to input: TF_440_Input<Float, Float>) -> Float {
return input.state
/// Input is generic in `T`; returns the `Float` state.
func applied(to input: TF_440_Input<T, Float>) -> Float {
return input.state
/// Same parameter type as the previous overload, but returns the `T` input.
func applied(to input: TF_440_Input<T, Float>) -> T {
return input.input
// Tests TF-508: differentiation requirements with dependent member types.
/// Protocol with one unconstrained associated type. Constraints on `Scalar`
/// are added by the extensions that follow.
protocol TF_508_Proto {
associatedtype Scalar
/// Provides `+` and `-` when `Scalar` is `FloatingPoint`.
extension TF_508_Proto where Scalar : FloatingPoint {
// NOTE(review): the `where` clause below has no visible attribute opening it
// in this view. Presumably it belongs to a `@differentiable(...)` attribute
// on `+` -- confirm against the full file.
where Self : Differentiable, Scalar : Differentiable,
// Conformance requirement with dependent member type.
Self.TangentVector : TF_508_Proto
/// Returns `lhs`; `rhs` is ignored.
static func +(lhs: Self, rhs: Self) -> Self {
return lhs
// NOTE(review): same situation as above -- this `where` clause presumably
// belongs to an attribute on `-` whose opening line is not visible. Confirm.
where Self : Differentiable, Scalar : Differentiable,
// Same-type requirement with dependent member type.
Self.TangentVector == Float
/// Returns `lhs`; `rhs` is ignored.
static func -(lhs: Self, rhs: Self) -> Self {
return lhs
/// Registers a derivative for `+`. The extension constrains
/// `Self.TangentVector` with a conformance requirement (`: TF_508_Proto`).
extension TF_508_Proto where Self : Differentiable,
Scalar : FloatingPoint & Differentiable,
Self.TangentVector : TF_508_Proto {
/// Derivative of `+`.
///
/// - Returns: `lhs` as the value, and a pullback that hands the incoming
///   tangent `v` back for both operands.
@derivative(of: +)
static func vjpAdd(lhs: Self, rhs: Self)
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (lhs, { v in (v, v) })
/// Registers a derivative for `-`. The extension constrains
/// `Self.TangentVector` with a same-type requirement (`== Float`).
extension TF_508_Proto where Self : Differentiable,
Scalar : FloatingPoint & Differentiable,
Self.TangentVector == Float {
/// Derivative of `-`.
///
/// - Returns: `lhs` as the value, and a pullback that hands the incoming
///   tangent `v` back for both operands.
@derivative(of: -)
static func vjpSubtract(lhs: Self, rhs: Self)
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (lhs, { v in (v, v) })
/// Empty generic struct conforming to `TF_508_Proto` and
/// `AdditiveArithmetic`. It declares no stored properties.
struct TF_508_Struct<Scalar : AdditiveArithmetic>
: TF_508_Proto, AdditiveArithmetic {}
/// Conditional `Differentiable` conformance in which the struct is its own
/// tangent vector.
extension TF_508_Struct : Differentiable where Scalar : Differentiable {
typealias TangentVector = TF_508_Struct
/// Driver for the TF-508 declarations: requests pullbacks of closures that
/// apply `+` and `-` to `TF_508_Struct<Float>` values. Results are discarded.
func TF_508() {
let x = TF_508_Struct<Float>()
// Test conformance requirement with dependent member type.
_ = pullback(at: x, in: { (x: TF_508_Struct<Float>) -> TF_508_Struct<Float> in
return x + x
// Test same-type requirement with dependent member type.
// NOTE(review): `TF_508_Struct` declares `TangentVector = TF_508_Struct`,
// while the `-` derivative is declared under `Self.TangentVector == Float`.
// Which derivative (if any) applies here is not evident from this view.
_ = pullback(at: x, in: { (x: TF_508_Struct<Float>) -> TF_508_Struct<Float> in
return x - x
// TF-523
/// `Differentiable & AdditiveArithmetic` struct that is its own tangent
/// vector.
struct TF_523_Struct : Differentiable & AdditiveArithmetic {
// Single stored property, defaulting to 1.
var a: Float = 1
typealias TangentVector = TF_523_Struct
/// Reads the stored property and doubles it.
///
/// - Parameter x: The struct to read from.
/// - Returns: `x.a * 2`.
func TF_523_f(_ x: TF_523_Struct) -> Float {
return x.a * 2
// TF-534: Thunk substitution map remapping.
/// Callable, `Differentiable` layer protocol with `Differentiable` input and
/// output associated types.
protocol TF_534_Layer : Differentiable {
associatedtype Input : Differentiable
associatedtype Output : Differentiable
/// Call-syntax requirement: a conforming value can be applied as `model(x)`.
func callAsFunction(_ input: Input) -> Output
/// Empty generic `Differentiable` struct. `Scalar` is unconstrained and is
/// not used by any member.
struct TF_534_Tensor<Scalar> : Differentiable {}
/// Calls `valueWithPullback(at:in:)` on a generic layer whose `Output` is
/// pinned to `TF_534_Tensor<Float>` by a same-type requirement.
///
/// - Parameters:
///   - model: Passed `inout`, although nothing visible in the body mutates it.
///   - inputs: Captured by the closure and fed to the model.
/// - Returns: Declared as `TF_534_Tensor<Float>`.
func TF_534<Model: TF_534_Layer>(
_ model: inout Model, inputs: Model.Input
) -> TF_534_Tensor<Float> where Model.Output == TF_534_Tensor<Float> {
// The closure parameter `model` shadows the `inout` parameter.
// NOTE(review): the end of this closure is not visible in this view, so how
// the call's result is turned into the declared return type cannot be
// confirmed from here.
return valueWithPullback(at: model) { model -> Model.Output in
return model(inputs)
// TF-546: Test that SILGen linear map thunk performs correct reabstraction.
/// Generic struct with `real` and `imaginary` parts, conforming to
/// `AdditiveArithmetic`.
struct TF_546<T: FloatingPoint>: AdditiveArithmetic {
var real: T
var imaginary: T
/// Memberwise-style initializer; both arguments default to 0.
/// It is marked differentiable only when `T` is `Differentiable` and is its
/// own tangent vector.
@differentiable(where T: Differentiable, T == T.TangentVector)
init(real: T = 0, imaginary: T = 0) {
self.real = real
self.imaginary = imaginary
/// Conditional `Differentiable` conformance in which `TF_546` is its own
/// tangent vector.
extension TF_546: Differentiable where T: Differentiable {
typealias TangentVector = TF_546
/// Registers a derivative for the initializer. The constraints here match
/// the `where` clause of the initializer's `@differentiable` attribute.
extension TF_546 where T: Differentiable, T == T.TangentVector {
/// Derivative of `init(real:imaginary:)`.
///
/// - Returns: The constructed value, and a pullback that splits an incoming
///   `TF_546` tangent into its `(real, imaginary)` components.
@derivative(of: init)
static func _vjpInit(real: T, imaginary: T) -> (value: TF_546, pullback: (TF_546) -> (T, T)) {
return (TF_546(real: real, imaginary: imaginary), { ($0.real, $0.imaginary) })
// Convert a closure that calls the initializer into a `@differentiable`
// function value of type `(Float, Float) -> TF_546<Float>`. The value is
// discarded.
let _: @differentiable(Float, Float) -> TF_546<Float> = { r, i in
TF_546(real: r, imaginary: i)
// TF-652: Test VJPCloner substitution map generic signature.
// The substitution map should have the VJP's generic signature, not the
// original function's.
/// Empty generic struct with an unconstrained `Scalar`.
struct TF_652<Scalar> {}
/// `Differentiable` only when `Scalar` is `FloatingPoint`. The extension
/// body is empty.
extension TF_652 : Differentiable where Scalar : FloatingPoint {}
/// Function whose own generic signature (`Scalar: Numeric`) differs from the
/// requirement written in its `@differentiable` attribute
/// (`Scalar: FloatingPoint`).
///
/// - Parameter x: The value to differentiate with respect to.
/// - Returns: `x`, unchanged.
@differentiable(wrt: x where Scalar: FloatingPoint)
func test<Scalar: Numeric>(x: TF_652<Scalar>) -> TF_652<Scalar> {
// The loop only binds `x` to a discarded value; it does not affect the result.
for _ in 0..<10 {
let _ = x
return x
// TF-682: Test that SILGen linear map thunk performs correct reabstraction.
/// Protocol with one unconstrained associated type. Constraints on `Scalar`
/// are added by the extensions that follow.
protocol TF_682_Proto {
associatedtype Scalar
/// Provides `foo(lhs:)` when `Scalar` is `FloatingPoint`.
extension TF_682_Proto where Scalar : FloatingPoint {
// NOTE(review): the `where` clause below has no visible attribute opening it
// in this view. Presumably it belongs to a `@differentiable(...)` attribute
// on `foo` -- confirm against the full file.
where Self : Differentiable, Scalar : Differentiable,
// Same-type requirement with dependent member type.
Self.TangentVector == Float
/// Returns `lhs`; `self` is not used.
func foo(lhs: Self) -> Self {
return lhs
/// Registers a derivative for the instance method `foo(lhs:)`, under a
/// same-type requirement `Self.TangentVector == Float`.
extension TF_682_Proto where Self : Differentiable,
Scalar : FloatingPoint & Differentiable,
Self.TangentVector == Float {
/// Derivative of `foo(lhs:)`.
///
/// - Returns: `lhs` as the value, and a pullback returning the incoming
///   tangent `v` twice. The pullback's result is a pair of tangents.
@derivative(of: foo)
func vjpFoo(lhs: Self) -> (
value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)
) {
return (lhs, { v in (v, v) })
// NOTE(TF-1208): Differentiation regression due to changes in curry thunk generation.
// TF-688: Test generic curry thunk cloning.
/// Public generic struct with a single stored property.
public struct TF_688_Struct<Scalar> {
var x: Scalar
/// Conditional `Differentiable` conformance, plus a static identity function.
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
/// Returns its argument unchanged.
public static func id(x: Self) -> Self {
return x
/// Public function differentiable with respect to `x` only. It also takes a
/// `@differentiable` function-typed parameter, `reduction`.
///
/// NOTE(review): the default value after `=` on the `reduction` line is not
/// visible in this view, and neither is the function body. Confirm both
/// against the full file before relying on this declaration.
@differentiable(wrt: x)
public func TF_688<Scalar: Differentiable>(
_ x: TF_688_Struct<Scalar>,
reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> =
) -> TF_688_Struct<Scalar> {
// TF-697: Test generic requirements of generated derivative function.
/// `Differentiable` module protocol. `Input` is unconstrained here; only
/// `Output` must be `Differentiable`.
protocol TF_697_Module: Differentiable {
associatedtype Input
associatedtype Output: Differentiable
/// Requirement that must be differentiable with respect to `self` only.
@differentiable(wrt: self)
func callModule(_ input: Input) -> Output
/// Refinement of `TF_697_Module` that additionally requires `Input` to be
/// `Differentiable`.
protocol TF_697_Layer: TF_697_Module where Input: Differentiable {
/// Requirement with the same shape as `callModule(_:)`, but with no
/// `@differentiable` attribute on the declaration.
func callLayer(_ input: Input) -> Output
/// Composition of two layers, where the first layer's output type must equal
/// the second layer's input type.
struct TF_697_Sequential<Layer1: TF_697_Module, Layer2: TF_697_Layer>: TF_697_Module
where Layer1.Output == Layer2.Input {
var layer1: Layer1
var layer2: Layer2
/// Witness for `TF_697_Module.callModule(_:)`, mapping `Layer1.Input` to
/// `Layer2.Output`.
///
/// NOTE(review): the body is not visible in this view.
@differentiable(wrt: self)
func callModule(_ input: Layer1.Input) -> Layer2.Output {
/// Conditional `TF_697_Layer` conformance, available when `Layer1` is itself
/// a `TF_697_Layer`.
extension TF_697_Sequential: TF_697_Layer where Layer1: TF_697_Layer {
/// Witness for `TF_697_Layer.callLayer(_:)`.
///
/// NOTE(review): the body is not visible in this view.
func callLayer(_ input: Layer1.Input) -> Layer2.Output {
// TF-817: Test remapping `apply` callee types in derivative function context.
/// Generic struct with an unconstrained `T` and a single method.
struct TF_817<T> {
/// Maps an `Int` index to a `T`.
///
/// NOTE(review): the body is not visible in this view.
func foo(_ index: Int) -> T {
/// Conditional `Differentiable` conformance, plus a registered derivative for
/// `foo`.
extension TF_817: Differentiable where T: Differentiable {
/// Derivative of `foo`. The pullback maps a `T.TangentVector` to this type's
/// `TangentVector`.
///
/// NOTE(review): the body is not visible in this view.
@derivative(of: foo)
func vjpFoo(index: Int) -> (value: T, pullback: (T.TangentVector) -> (TangentVector)) {
/// Unconstrained extension whose method carries the `T: Differentiable`
/// requirement only inside its `@differentiable` attribute.
extension TF_817 {
/// Differentiable with respect to `self` when `T` is `Differentiable`.
///
/// NOTE(review): no expression is visible after `return` in this view,
/// although the declared result type is `T`. Confirm the returned expression
/// against the full file.
@differentiable(wrt: self where T: Differentiable)
public func test(index: Int) -> T {
return // crash happened here
// TF-886: Test `partial_apply` of linear map subset parameters thunk.
/// Ignores all three (unnamed) arguments and returns 0. Only the third
/// generic parameter, `U`, is constrained to `Differentiable`.
func TF_886_foo<T, U: Differentiable>(_: Float, _: T, _: U) -> Float {
return 0
/// Forwards `x` and `y` to `TF_886_foo`, supplying the literal `0` as the
/// third argument.
///
/// - Returns: The result of `TF_886_foo(x, y, 0)`.
func TF_886_bar<T>(x: Float, y: T) -> Float {
return TF_886_foo(x, y, 0)
// Test layout requirements.
// The layout requirement is "contextual": the requirement is not on `T`, the
// differentiable function parameter/result type.
/// Struct whose second generic parameter, `U`, carries an `AnyObject`
/// requirement. `U` is not used by any visible member.
struct ContextualLayoutRequirement<T: Differentiable, U: AnyObject> {
var stored: T
extension ContextualLayoutRequirement {
/// Forms two `@differentiable (T) -> T` function values and discards them.
/// The parameter `x` is unused.
func test(_ x: T) {
// Closure that ignores its argument and returns the captured `self.stored`.
let _: @differentiable (T) -> T = { _ in self.stored }
// Identity closure.
let _: @differentiable (T) -> T = { $0 }
// The layout requirement directly involves `T`, the differentiable function
// parameter/result type.
// TODO(TF-851): Uncomment the tests below after `@differentiable` function
// SILGen thunking is fixed.
// NOTE(review): the TODO(TF-851) comment directly above this declaration asks
// for these tests to be uncommented later, yet no comment delimiters are
// visible around them in this view. Confirm whether this code is meant to be
// active.
/// Struct whose `AnyObject` requirement is placed on `T` itself.
struct LayoutRequirement<T: AnyObject & Differentiable> {
var stored: T
extension LayoutRequirement {
/// Forms two `@differentiable (T) -> T` function values and discards them.
/// The parameter `x` is unused.
func test(_ x: T) {
// Closure that ignores its argument and returns the captured `self.stored`.
let _: @differentiable (T) -> T = { _ in self.stored }
// Identity closure.
let _: @differentiable (T) -> T = { $0 }
// Test superclass requirements.
/// Empty `Differentiable` class, used as the superclass bound in the
/// declarations that follow.
class Super: Differentiable {}
// The superclass requirement is "contextual": the requirement is not on `T`,
// the differentiable function parameter/result type.
/// Struct whose second generic parameter, `U`, carries a `Super` superclass
/// requirement. `U` is not used by any visible member.
struct ContextualSuperclassRequirement<T: Differentiable, U: Super> {
var stored: T
extension ContextualSuperclassRequirement {
/// Forms two `@differentiable (T) -> T` function values and discards them.
/// The parameter `x` is unused.
func test(_ x: T) {
// Closure that ignores its argument and returns the captured `self.stored`.
let _: @differentiable (T) -> T = { _ in self.stored }
// Identity closure.
let _: @differentiable (T) -> T = { $0 }
// The superclass requirement directly involves `T`, the differentiable
// function parameter/result type.
// TODO(TF-851): Uncomment the tests below after `@differentiable` function
// SILGen thunking is fixed.
// NOTE(review): the TODO(TF-851) comment directly above this declaration asks
// for these tests to be uncommented later, yet no comment delimiters are
// visible around them in this view. Confirm whether this code is meant to be
// active.
/// Struct whose `Super` superclass requirement is placed on `T` itself.
struct SuperclassRequirement<T: Super & Differentiable> {
var stored: T
extension SuperclassRequirement {
/// Forms two `@differentiable (T) -> T` function values and discards them.
/// The parameter `x` is unused.
func test(_ x: T) {
// Closure that ignores its argument and returns the captured `self.stored`.
let _: @differentiable (T) -> T = { _ in self.stored }
// Identity closure.
let _: @differentiable (T) -> T = { $0 }