| // RUN: %target-swift-frontend -dump-ast %s | %FileCheck %s -check-prefix=CHECK-AST |
| // RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s -check-prefix=CHECK-SILGEN |
| |
| //===----------------------------------------------------------------------===// |
| // Closure conversion |
| //===----------------------------------------------------------------------===// |
| |
| func thin(x: Float) -> Float { return x } |
| |
| func myfunction(_ f: @escaping @differentiable (Float) -> (Float)) -> (Float) -> Float { |
| // @differentiable functions should be callable. |
| _ = f(.zero) |
| return f |
| } |
| |
| func myfunction2(_ f: @escaping @differentiable(linear) (Float) -> (Float)) -> (Float) -> Float { |
| // @differentiable(linear) functions should be callable. |
| _ = f(.zero) |
| return f |
| } |
| |
| var global_f: @differentiable (Float) -> Float = {$0} |
| var global_f_linear: @differentiable(linear) (Float) -> Float = {$0} |
| |
| func calls_global_f() { |
| _ = global_f(10) |
| // TODO(TF-900, TF-902): Uncomment the following line to test loading a linear function from memory and direct calls to a linear function. |
| // _ = global_f_linear(10) |
| } |
| |
| func apply() { |
| _ = myfunction(thin) |
| _ = myfunction2(thin) |
| } |
| |
| // CHECK-AST-LABEL: (func_decl {{.*}} "myfunction(_:)" |
| // CHECK-AST: (call_expr type='(Float)' |
| // CHECK-AST: (declref_expr type='@differentiable (Float) -> (Float)' |
| // CHECK-AST: (return_stmt |
| // CHECK-AST: (function_conversion_expr implicit type='(Float) -> Float' |
| // CHECK-AST: (differentiable_function_extract_original implicit type='(Float) -> (Float)' |
| // CHECK-AST: (declref_expr type='@differentiable (Float) -> (Float)' |
| // CHECK-AST-LABEL: (func_decl {{.*}} "apply()" |
| // CHECK-AST: (function_conversion_expr implicit type='@differentiable (Float) -> (Float)' |
| // CHECK-AST: (differentiable_function implicit type='@differentiable (Float) -> Float' |
| // CHECK-AST: (declref_expr type='(Float) -> Float' |
| |
| // CHECK-SILGEN-LABEL: @{{.*}}myfunction{{.*}} |
| // CHECK-SILGEN: bb0([[DIFF:%.*]] : @guaranteed $@differentiable @callee_guaranteed (Float) -> Float): |
| // CHECK-SILGEN: [[COPIED_DIFF:%.*]] = copy_value [[DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[BORROWED_DIFF:%.*]] = begin_borrow [[COPIED_DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: apply [[BORROWED_DIFF]]({{%.*}}) : $@differentiable @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: end_borrow [[BORROWED_DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: destroy_value [[COPIED_DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[COPIED_DIFF:%.*]] = copy_value [[DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[BORROWED_DIFF:%.*]] = begin_borrow [[COPIED_DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[BORROWED_ORIG:%.*]] = differentiable_function_extract [original] [[BORROWED_DIFF]] : $@differentiable @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[COPIED_ORIG:%.*]] = copy_value [[BORROWED_ORIG]] : $@callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: return [[COPIED_ORIG]] : $@callee_guaranteed (Float) -> Float |
| |
| // CHECK-SILGEN-LABEL: @{{.*}}myfunction2{{.*}} |
| // CHECK-SILGEN: bb0([[LIN:%.*]] : @guaranteed $@differentiable(linear) @callee_guaranteed (Float) -> Float): |
| // CHECK-SILGEN: [[COPIED_LIN:%.*]] = copy_value [[LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[BORROWED_LIN:%.*]] = begin_borrow [[COPIED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: apply [[BORROWED_LIN]]({{%.*}}) : $@differentiable(linear) @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: end_borrow [[BORROWED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[COPIED_LIN:%.*]] = copy_value [[LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[BORROWED_LIN:%.*]] = begin_borrow [[COPIED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[BORROWED_ORIG:%.*]] = linear_function_extract [original] [[BORROWED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[COPIED_ORIG:%.*]] = copy_value [[BORROWED_ORIG]] : $@callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: end_borrow [[BORROWED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: destroy_value [[COPIED_LIN]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: return [[COPIED_ORIG]] : $@callee_guaranteed (Float) -> Float |
| |
| // CHECK-SILGEN-LABEL: @{{.*}}apply{{.*}} |
| // CHECK-SILGEN: [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float |
| // CHECK-SILGEN-NEXT: [[ORIG_THICK:%.*]] = thin_to_thick_function [[ORIG]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN-NEXT: [[DIFFED:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[ORIG:%.*]] = function_ref @{{.*}}thin{{.*}} : $@convention(thin) (Float) -> Float |
| // CHECK-SILGEN-NEXT: [[ORIG_THICK:%.*]] = thin_to_thick_function [[ORIG]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN-NEXT: [[LIN:%.*]] = linear_function [parameters 0] [[ORIG_THICK]] : $@callee_guaranteed (Float) -> Float |
| |
| //===----------------------------------------------------------------------===// |
| // Reabstraction |
| //===----------------------------------------------------------------------===// |
| |
| func pullback<T, R>( |
| at x: T, in f: @escaping @differentiable (T) -> R |
| ) -> (R.TangentVector) -> T.TangentVector { |
| fatalError() |
| } |
| |
| func appliesReabstraction(_ f: @escaping @differentiable (Float) -> Float) { |
| _ = pullback(at: .zero, in: f) |
| } |
| |
| // CHECK-SILGEN-LABEL: @{{.*}}appliesReabstraction{{.*}} |
| // CHECK-SILGEN: bb0([[DIFF_FUNC_ARG:%.*]] : @guaranteed $@differentiable @callee_guaranteed (Float) -> Float): |
| // CHECK-SILGEN: [[DIFF_FUNC:%.*]] = copy_value [[DIFF_FUNC_ARG]] : $@differentiable @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[DIFF_FUNC_BORROWED:%.*]] = begin_borrow [[DIFF_FUNC]] : $@differentiable @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[ORIG:%.*]] = differentiable_function_extract [original] [[DIFF_FUNC_BORROWED]] : $@differentiable @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[ORIG_COPY:%.*]] = copy_value [[ORIG]] : $@callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[REABS_ORIG:%.*]] = function_ref @$sS2fIegyd_S2fIegnr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> @out Float |
| // CHECK-SILGEN: [[NEW_ORIG:%.*]] = partial_apply [callee_guaranteed] [[REABS_ORIG]]([[ORIG_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> @out Float |
| // CHECK-SILGEN: [[NEW_ORIG_CONVERTED:%.*]] = convert_function [[NEW_ORIG]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float> |
| // CHECK-SILGEN: [[JVP:%.*]] = differentiable_function_extract [jvp] [[DIFF_FUNC_BORROWED]] : $@differentiable @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[JVP_COPY:%.*]] = copy_value [[JVP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| // CHECK-SILGEN: [[REABS_JVP:%.*]] = function_ref @$sS4fIegyd_Iegydo_S2fxq_r0_lyS2fIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) |
| // CHECK-SILGEN: [[NEW_JVP:%.*]] = partial_apply [callee_guaranteed] %19(%18) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) |
| // CHECK-SILGEN: [[NEW_JVP_CONVERTED:%.*]] = convert_function [[NEW_JVP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float> |
| // CHECK-SILGEN: [[VJP:%.*]] = differentiable_function_extract [vjp] [[DIFF_FUNC_BORROWED]] : $@differentiable @callee_guaranteed (Float) -> Float |
| // CHECK-SILGEN: [[VJP_COPY:%.*]] = copy_value [[VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| // CHECK-SILGEN: [[REABS_VJP:%.*]] = function_ref @$sS4fIegyd_Iegydo_S2fxq_r0_lyS2fIsegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) |
| // CHECK-SILGEN: [[NEW_VJP:%.*]] = partial_apply [callee_guaranteed] [[REABS_VJP]]([[VJP_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) |
| // CHECK-SILGEN: [[NEW_VJP_CONVERTED:%.*]] = convert_function [[NEW_VJP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float>) to $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float> |
| // CHECK-SILGEN: [[NEW_DIFF_FUNC:%.*]] = differentiable_function [parameters 0] [results 0] [[NEW_ORIG_CONVERTED]] : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <Float, Float> with_derivative {[[NEW_JVP_CONVERTED]] : $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>, [[NEW_VJP_CONVERTED]] : $@callee_guaranteed @substituted <τ_0_0, τ_0_1, τ_0_2, τ_0_3> (@in_guaranteed τ_0_0) -> (@out τ_0_1, @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_2, τ_0_3>) for <Float, Float, Float, Float>} |
| // CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>) -> @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_1.TangentVector, τ_0_0.TangentVector> |
| // CHECK-SILGEN: apply [[DIFF_API]]<Float, Float>({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_0, τ_0_1>) -> @owned @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> @out τ_0_1 for <τ_0_1.TangentVector, τ_0_0.TangentVector> |