blob: 601f1907ed26527bdcc889edb1e69a787d26457e [file] [log] [blame]
// RUN: %target-swift-frontend -emit-ir %s | %FileCheck %s
sil_stage raw
import Swift
import Builtin
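// The adjoint function emitted by the compiler. Its first parameter is the
// incoming vector (as in vector-Jacobian products); its remaining two parameters
// receive values captured from the original computation (the original argument
// and the original result, bound by the partial_apply in @foo_vjp). The function
// is not itself a pullback: it is partially applied to those captured values to
// form a pullback, which takes a vector and returns the vector-Jacobian product
// evaluated at the original argument.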
sil hidden @foo_adj : $@convention(thin) (Float, Float, Float) -> Float {
// Stub body: the two trailing (captured) parameters are ignored and the
// incoming vector is passed straight through as the result.
bb0(%seed : $Float, %capturedArg : $Float, %capturedResult : $Float):
  return %seed : $Float
}
// The original function. Its [differentiable] attribute names @foo_vjp as the VJP (vector-Jacobian product) function with respect to parameter 0.
sil hidden [differentiable source 0 wrt 0 vjp @foo_vjp] @foo : $@convention(thin) (Float) -> Float {
bb0(%x : $Float):
  // Identity: the argument is handed back unchanged.
  return %x : $Float
}
// The vector-Jacobian product (VJP) function, which returns the original result together with a pullback closure.
sil hidden @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%x : $Float):
  // Evaluate the original function to obtain the primal result.
  %fooFn = function_ref @foo : $@convention(thin) (Float) -> Float
  %y = apply %fooFn(%x) : $@convention(thin) (Float) -> Float
  // Bind the argument and the result to the two trailing parameters of
  // @foo_adj, leaving a single-Float-parameter closure used as the pullback.
  %adjFn = function_ref @foo_adj : $@convention(thin) (Float, Float, Float) -> Float
  %pullback = partial_apply [callee_guaranteed] %adjFn(%x, %y) : $@convention(thin) (Float, Float, Float) -> Float
  // Return (result, pullback).
  %pair = tuple (%y : $Float, %pullback : $@callee_guaranteed (Float) -> Float)
  return %pair : $(Float, @callee_guaranteed (Float) -> Float)
}
sil @make_diff_func : $@convention(thin) () -> (@convention(thin) (Float) -> Float, @convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) {
// Bundles @foo and @foo_vjp into a @differentiable function value, then
// extracts the original and VJP components again and returns them as a
// pair (original, vjp).
bb0:
%orig = function_ref @foo : $@convention(thin) (Float) -> Float
%vjp = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// Differentiable with respect to parameter 0. The first derivative slot is
// filled with undef; only the second slot (the VJP) gets a real function.
%diffFunc = differentiable_function [parameters 0] %orig : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
// Pull the two components back out of the bundled value.
%extractedOrig = differentiable_function_extract [original] %diffFunc : $@differentiable @convention(thin) (Float) -> Float
%extractedVJP = differentiable_function_extract [vjp] %diffFunc : $@differentiable @convention(thin) (Float) -> Float
%tuple = tuple (%extractedOrig : $@convention(thin) (Float) -> Float, %extractedVJP : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float))
return %tuple : $(@convention(thin) (Float) -> Float, @convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float))
}
sil @caller : $@convention(thin) () -> () {
// Calls @make_diff_func, takes element 1 of the returned pair (the VJP),
// and invokes it with an undef argument. The call's result is not used.
bb0:
%f = function_ref @make_diff_func : $@convention(thin) () -> (@convention(thin) (Float) -> Float, @convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float))
%tuple = apply %f() : $@convention(thin) () -> (@convention(thin) (Float) -> Float, @convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float))
// Element 1 is the VJP function (element 0 is the original).
%vjp = tuple_extract %tuple : $(@convention(thin) (Float) -> Float, @convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)), 1
%res = apply %vjp(undef) : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
%ret = tuple ()
return %ret : $()
}
// CHECK-LABEL: swiftcc { i8*, i8* } @make_diff_func()
// CHECK-NEXT: entry:
// CHECK-NEXT: ret { i8*, i8* } { i8* bitcast (float (float)* @foo to i8*), i8* bitcast ({ float, i8*, %swift.refcounted* } (float)* @foo_vjp to i8*) }
// CHECK-LABEL: swiftcc void @caller()
// CHECK-NEXT: entry:
// CHECK-NEXT: [[RESULT_TUPLE:%.*]] = call swiftcc { i8*, i8* } @make_diff_func()
// CHECK: [[VJP:%.*]] = extractvalue { i8*, i8* } [[RESULT_TUPLE]], 1
// CHECK: [[VJP_TYPED:%.*]] = bitcast i8* [[VJP]] to { float, i8*, %swift.refcounted* } (float)*
// CHECK: call swiftcc { float, i8*, %swift.refcounted* } [[VJP_TYPED]](float undef)