| // RUN: %target-swift-frontend -emit-ir %s | %FileCheck %s |
| |
| sil_stage raw |
| |
| import Swift |
| import Builtin |
| |
| // The adjoint function emitted by the compiler. Its parameters are a vector, as |
| // in vector-Jacobian products, and a pullback struct value. The function is not |
| // itself a pullback; it is meant to be partially applied to form a pullback, |
| // which takes a vector and returns vector-Jacobian products evaluated at the |
| // original parameter. |
| sil hidden @foo_adj : $@convention(thin) (Float, Float, Float) -> Float { |
| // Hands back its first argument unchanged; the two trailing arguments (the |
| // ones @foo_vjp binds via partial_apply) are not read. |
| bb0(%vec : $Float, %captured0 : $Float, %captured1 : $Float): |
| return %vec : $Float |
| } |
| |
| // The original function with an attribute that specifies the compiler-emitted VJP function (@foo_vjp). |
| sil hidden [differentiable source 0 wrt 0 vjp @foo_vjp] @foo : $@convention(thin) (Float) -> Float { |
| // Identity on Float: the sole argument is returned as-is. |
| bb0(%x : $Float): |
| return %x : $Float |
| } |
| |
| // The vector-Jacobian product function, which returns the original result and a pullback. |
| sil hidden @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { |
| bb0(%x : $Float): |
| // Look up both callees first: the original function and the adjoint. |
| %origFn = function_ref @foo : $@convention(thin) (Float) -> Float |
| %adjFn = function_ref @foo_adj : $@convention(thin) (Float, Float, Float) -> Float |
| // Evaluate the original function on the incoming argument. |
| %y = apply %origFn(%x) : $@convention(thin) (Float) -> Float |
| // Bind the argument and the original result as the two trailing parameters |
| // of @foo_adj, leaving a (Float) -> Float closure. |
| %pullback = partial_apply [callee_guaranteed] %adjFn(%x, %y) : $@convention(thin) (Float, Float, Float) -> Float |
| // Package the original result together with the closure. |
| %result = tuple (%y : $Float, %pullback : $@callee_guaranteed (Float) -> Float) |
| return %result : $(Float, @callee_guaranteed (Float) -> Float) |
| } |
| |
| sil @make_diff_func : $@convention(thin) () -> (@convention(thin) (Float) -> Float, @convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) { |
| bb0: |
| // Reference the original function and its VJP. |
| %origFn = function_ref @foo : $@convention(thin) (Float) -> Float |
| %vjpFn = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| // Bundle them into a @differentiable function value; the first derivative |
| // slot is left undef and only the VJP slot is populated. |
| %bundle = differentiable_function [parameters 0] %origFn : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjpFn : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)} |
| // Project the original function and the VJP back out of the bundle. |
| %origOut = differentiable_function_extract [original] %bundle : $@differentiable @convention(thin) (Float) -> Float |
| %vjpOut = differentiable_function_extract [vjp] %bundle : $@differentiable @convention(thin) (Float) -> Float |
| // Return both extracted functions as a pair. |
| %pair = tuple (%origOut : $@convention(thin) (Float) -> Float, %vjpOut : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) |
| return %pair : $(@convention(thin) (Float) -> Float, @convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) |
| } |
| |
| sil @caller : $@convention(thin) () -> () { |
| bb0: |
| // Obtain the (original, VJP) pair from @make_diff_func. |
| %maker = function_ref @make_diff_func : $@convention(thin) () -> (@convention(thin) (Float) -> Float, @convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) |
| %pair = apply %maker() : $@convention(thin) () -> (@convention(thin) (Float) -> Float, @convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) |
| // Take element 1 (the VJP) and call it with an undef argument; the call's |
| // result is not used afterwards. |
| %vjpFn = tuple_extract %pair : $(@convention(thin) (Float) -> Float, @convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)), 1 |
| %vjpResult = apply %vjpFn(undef) : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) |
| %unit = tuple () |
| return %unit : $() |
| } |
| |
| // CHECK-LABEL: swiftcc { i8*, i8* } @make_diff_func() |
| // CHECK-NEXT: entry: |
| // CHECK-NEXT: ret { i8*, i8* } { i8* bitcast (float (float)* @foo to i8*), i8* bitcast ({ float, i8*, %swift.refcounted* } (float)* @foo_vjp to i8*) } |
| |
| // CHECK-LABEL: swiftcc void @caller() |
| // CHECK-NEXT: entry: |
| // CHECK-NEXT: [[RESULT_TUPLE:%.*]] = call swiftcc { i8*, i8* } @make_diff_func() |
| // CHECK: [[VJP:%.*]] = extractvalue { i8*, i8* } [[RESULT_TUPLE]], 1 |
| // CHECK: [[VJP_TYPED:%.*]] = bitcast i8* [[VJP]] to { float, i8*, %swift.refcounted* } (float)* |
| // CHECK: call swiftcc { float, i8*, %swift.refcounted* } [[VJP_TYPED]](float undef) |