blob: 879ad6317e99d91056404a17d3a3a75f7d15c1fe [file] [log] [blame]
// RUN: %target-sil-opt %s | %FileCheck %s
// RUN: %empty-directory(%t)
// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name differentiable_function
// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name differentiable_function
// RUN: %target-sil-opt %t/tmp.2.sib -module-name differentiable_function | %FileCheck %s
sil_stage raw
import Swift
import Builtin
// External declarations (no bodies in this file). @test below takes
// function_refs to both: one uses the thin convention, the other the method
// convention, each with three Float parameters and a Float result.
sil @examplefunc : $@convention(thin) (Float, Float, Float) -> Float
sil @examplemethod : $@convention(method) (Float, Float, Float) -> Float
// CHECK-LABEL: sil @test
// Wraps @examplefunc and @examplemethod in differentiable_function, once with
// parameter indices 0 1 2 and once with parameter index 0 only, and extracts
// the [vjp] component from each. Nothing is applied; the function returns an
// empty tuple.
// NOTE(review): the expected-output comments in this body spell out literal
// value numbers (%2, %4, %7, %9) that mirror the input lines, so instructions
// here should not be reordered or inserted without updating them.
sil @test : $@convention(thin) () -> () {
bb0:
// Thin convention, parameter indices 0 1 2.
%0 = function_ref @examplefunc : $@convention(thin) (Float, Float, Float) -> Float
%1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (Float, Float, Float) -> Float
// CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float
%2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (Float, Float, Float) -> Float
// Thin convention, parameter index 0 only: the operand type of the extract
// marks the remaining two parameters @nondiff.
%3 = differentiable_function [parameters 0] %0 : $@convention(thin) (Float, Float, Float) -> Float
// CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float
%4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (Float, @nondiff Float, @nondiff Float) -> Float
// Method convention, parameter indices 0 1 2.
%5 = function_ref @examplemethod : $@convention(method) (Float, Float, Float) -> Float
%6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (Float, Float, Float) -> Float
// CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float
%7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (Float, Float, Float) -> Float
// Method convention, parameter index 0 only.
%8 = differentiable_function [parameters 0] %5 : $@convention(method) (Float, Float, Float) -> Float
// CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float
%9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (Float, @nondiff Float, @nondiff Float) -> Float
// None of the extracted values are used; return the empty tuple.
%ret = tuple ()
return %ret : $()
}
// The adjoint function emitted by the compiler. Its parameters are a vector, as
// in vector-Jacobian products, and pullback values. The function is partially
// applied to a pullback struct to form a pullback, which takes a vector and
// returns vector-Jacobian products evaluated at the original parameter.
// Takes three Float arguments and returns the third one unchanged; the first
// two are ignored. @foo_vjp references this function in a partial_apply that
// supplies two Float values.
sil hidden @foo_adj : $@convention(thin) (Float, Float, Float) -> Float {
bb0(%first : $Float, %second : $Float, %third : $Float):
  return %third : $Float
}
// The original function with an attribute that specifies the compiler-emitted pullback.
// Identity on Float: the single argument is returned as-is. Carries the
// [differentiable source 0 wrt 0] attribute.
sil hidden [differentiable source 0 wrt 0] @foo : $@convention(thin) (Float) -> Float {
bb0(%input : $Float):
  return %input : $Float
}
// The vector-Jacobian product function, which returns the original result and a pullback.
// Returns a pair: the result of calling @foo on the argument, and a
// callee-guaranteed (Float) -> Float closure formed by partially applying
// @foo_adj.
sil hidden @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
bb0(%x : $Float):
  // Evaluate the original function on the incoming argument.
  %fooRef = function_ref @foo : $@convention(thin) (Float) -> Float
  %y = apply %fooRef(%x) : $@convention(thin) (Float) -> Float
  // Bind the argument and the original result into @foo_adj, which leaves a
  // closure that still expects one Float.
  %adjRef = function_ref @foo_adj : $@convention(thin) (Float, Float, Float) -> Float
  %closure = partial_apply [callee_guaranteed] %adjRef(%x, %y) : $@convention(thin) (Float, Float, Float) -> Float
  // Package the original result together with the closure and return both.
  %pair = tuple (%y : $Float, %closure : $@callee_guaranteed (Float) -> Float)
  return %pair : $(Float, @callee_guaranteed (Float) -> Float)
}
// Exercises differentiable_function both without and with a with_derivative
// clause, plus differentiable_function_extract for the [vjp] and [original]
// components. The expected-output lines after this function match the
// instructions in sequence, so their order is kept as-is.
sil @make_diff_func : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float {
bb0:
  %fooRef = function_ref @foo : $@convention(thin) (Float) -> Float
  // Bundle with no derivative functions supplied; this is the value returned.
  %bareBundle = differentiable_function [parameters 0] %fooRef : $@convention(thin) (Float) -> Float
  %fooVJPRef = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
  // Bundle with explicit derivatives: the first slot is undef, the second is
  // @foo_vjp.
  %fullBundle = differentiable_function [parameters 0] %fooRef : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %fooVJPRef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
  // Pull two components back out of the explicit bundle; neither is used
  // afterwards.
  %vjpPart = differentiable_function_extract [vjp] %fullBundle : $@differentiable @convention(thin) (Float) -> Float
  %originalPart = differentiable_function_extract [original] %fullBundle : $@differentiable @convention(thin) (Float) -> Float
  return %bareBundle : $@differentiable @convention(thin) (Float) -> Float
}
// CHECK-LABEL: @make_diff_func : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float
// CHECK: [[FOO:%.*]] = function_ref @foo : $@convention(thin) (Float) -> Float
// CHECK: [[UNDIFFED_FOO:%.*]] = differentiable_function [parameters 0] [[FOO]] : $@convention(thin) (Float) -> Float
// CHECK: [[FOO_VJP:%.*]] = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK: [[DIFFED_FOO:%.*]] = differentiable_function [parameters 0] [[FOO]] : $@convention(thin) (Float) -> Float with_derivative {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[FOO_VJP]] : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
// CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float
// CHECK: [[EXTRACTED_ORIG:%.*]] = differentiable_function_extract [original] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float
// CHECK: return [[UNDIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float