| // RUN: %target-sil-opt %s | %FileCheck %s |
| |
| // RUN: %empty-directory(%t) |
| // RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name differentiable_function |
| // RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name differentiable_function |
| // RUN: %target-sil-opt %t/tmp.2.sib -module-name differentiable_function | %FileCheck %s |
| |
| sil_stage raw |
| |
| import Swift |
| import Builtin |
| |
| // Body-less declarations used as function_ref operands in @test: one with the |
| // thin convention and one with the method convention. Both take three Floats |
| // and return a Float. |
| sil @examplefunc : $@convention(thin) (Float, Float, Float) -> Float |
| sil @examplemethod : $@convention(method) (Float, Float, Float) -> Float |
| |
| // CHECK-LABEL: sil @test |
| // Builds differentiable functions from @examplefunc (thin) and @examplemethod |
| // (method), each once with parameters 0 1 2 and once with parameter 0 only, |
| // and extracts the [vjp] component from each of them. The FileCheck |
| // expectations in this body match literal value numbers (%2, %4, %7, %9), so |
| // the order of the instructions must not change. |
| sil @test : $@convention(thin) () -> () { |
| bb0: |
| // Thin convention, all three parameters listed. |
| %0 = function_ref @examplefunc : $@convention(thin) (Float, Float, Float) -> Float |
| %1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (Float, Float, Float) -> Float |
| |
| // CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float |
| %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float |
| // Thin convention, parameter 0 only: the operand type of the extract below |
| // marks parameters 1 and 2 as @nondiff. |
| %3 = differentiable_function [parameters 0] %0 : $@convention(thin) (Float, Float, Float) -> Float |
| |
| // CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float |
| %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float |
| // Same two cases again, this time with the method convention. |
| %5 = function_ref @examplemethod : $@convention(method) (Float, Float, Float) -> Float |
| %6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (Float, Float, Float) -> Float |
| |
| // CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float |
| %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float |
| %8 = differentiable_function [parameters 0] %5 : $@convention(method) (Float, Float, Float) -> Float |
| |
| // CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float |
| %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float |
| |
| // None of the extracted values are used; the function returns an empty tuple. |
| %ret = tuple () |
| return %ret : $() |
| } |
| |
| // The adjoint function emitted by the compiler. Its parameters are a vector, as in |
| // vector-Jacobian products, and pullback values. The function is partially |
| // applied to a pullback struct to form a pullback, which takes a vector and |
| // returns vector-Jacobian products evaluated at the original parameter. |
| sil hidden @foo_adj : $@convention(thin) (Float, Float, Float) -> Float { |
| bb0(%0 : $Float, %1 : $Float, %2 : $Float): |
| // Returns the third argument unchanged; %0 and %1 are not read. |
| return %2 : $Float |
| } |
| |
| // The original function, marked with a `[differentiable source 0 wrt 0]` attribute. |
| // The attribute itself names no derivative function. |
| sil hidden [differentiable source 0 wrt 0] @foo : $@convention(thin) (Float) -> Float { |
| bb0(%0 : $Float): |
| // Identity: returns its single Float argument unchanged. |
| return %0 : $Float |
| } |
| |
| // The vector-Jacobian product function, which returns the original result and a pullback. |
| sil hidden @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { |
| bb0(%0 : $Float): |
| // Compute the original result by calling @foo on the argument. |
| %1 = function_ref @foo : $@convention(thin) (Float) -> Float |
| %2 = apply %1(%0) : $@convention(thin) (Float) -> Float |
| // Partially apply the three-parameter @foo_adj to two values (the argument |
| // and the original result), leaving a closure of type (Float) -> Float. |
| %3 = function_ref @foo_adj : $@convention(thin) (Float, Float, Float) -> Float |
| %4 = partial_apply [callee_guaranteed] %3(%0, %2) : $@convention(thin) (Float, Float, Float) -> Float |
| // Return the original result paired with that closure. |
| %5 = tuple (%2 : $Float, %4 : $@callee_guaranteed (Float) -> Float) |
| return %5 : $(Float, @callee_guaranteed (Float) -> Float) |
| } |
| |
| // Exercises both forms of differentiable_function on @foo (with and without a |
| // with_derivative clause) plus the [vjp] and [original] extractees. |
| sil @make_diff_func : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float { |
| bb0: |
| %orig = function_ref @foo : $@convention(thin) (Float) -> Float |
| // Form without a with_derivative clause; this is the value that is returned. |
| %undiffedFunc = differentiable_function [parameters 0] %orig : $@convention(thin) (Float) -> Float |
| %vjp = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| // Form with a with_derivative clause: the first entry is undef and the second |
| // is @foo_vjp. NOTE(review): the first entry is presumably the JVP slot -- confirm. |
| %diffFunc = differentiable_function [parameters 0] %orig : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} |
| // Both extracted values are unused; they are here only to be parsed and printed. |
| %extractedVJP = differentiable_function_extract [vjp] %diffFunc : $@differentiable @convention(thin) (Float) -> Float |
| %extractedOriginal = differentiable_function_extract [original] %diffFunc : $@differentiable @convention(thin) (Float) -> Float |
| // Returns %undiffedFunc, not %diffFunc. |
| return %undiffedFunc : $@differentiable @convention(thin) (Float) -> Float |
| } |
| |
| // CHECK-LABEL: @make_diff_func : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float |
| // CHECK: [[FOO:%.*]] = function_ref @foo : $@convention(thin) (Float) -> Float |
| // CHECK: [[UNDIFFED_FOO:%.*]] = differentiable_function [parameters 0] [[FOO]] : $@convention(thin) (Float) -> Float |
| // CHECK: [[FOO_VJP:%.*]] = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| // CHECK: [[DIFFED_FOO:%.*]] = differentiable_function [parameters 0] [[FOO]] : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[FOO_VJP]] : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} |
| // CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float |
| // CHECK: [[EXTRACTED_ORIG:%.*]] = differentiable_function_extract [original] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float |
| // CHECK: return [[UNDIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float |