[AutoDiff] Improve documentation for derivative function SILGen. (#27753)
Improve documentation and naming for derivative function SILGen.
SILGen generates thunks for derivative functions registered via
`@differentiable` and `@differentiating` attributes.
Currently, two SILGen derivative thunk kinds exist:
- `SILGenModule::getOrCreateAutoDiffDerivativeForwardingThunk`
  - This creates a simple thunk that forwards arguments and returns
    results. This is generated when no reabstraction or self reordering
    is necessary.
- `SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk`
  - This creates a thunk that performs reabstraction and/or self reordering.
diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp
index 93e346a..15ec27f 100644
--- a/lib/SILGen/SILGen.cpp
+++ b/lib/SILGen/SILGen.cpp
@@ -848,14 +848,14 @@
SILFunction *derivativeThunk;
if (reorderSelf ||
derivative->getLoweredFunctionType() != expectedDerivativeType) {
- derivativeThunk = getOrCreateAutoDiffDerivativeFunctionThunk(
+ derivativeThunk = getOrCreateAutoDiffDerivativeReabstractionThunk(
originalFunction, indices, derivative, kind, reorderSelf);
} else {
// Note: `AutoDiffDerivativeFunctionIdentifier` must be constructed with
// the AST-level parameter indices, not the SIL-level ones.
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
kind, config.parameterIndices, getASTContext());
- derivativeThunk = getOrCreateAutoDiffThunk(
+ derivativeThunk = getOrCreateAutoDiffDerivativeForwardingThunk(
SILDeclRef(originalAFD).asAutoDiffDerivativeFunction(id), derivative,
expectedDerivativeType);
}
diff --git a/lib/SILGen/SILGen.h b/lib/SILGen/SILGen.h
index 879fa64..0d3291d 100644
--- a/lib/SILGen/SILGen.h
+++ b/lib/SILGen/SILGen.h
@@ -148,11 +148,13 @@
CanSILFunctionType constantTy);
// SWIFT_ENABLE_TENSORFLOW
- /// Get or create an autodiff derivative function thunk for the given
- /// SILDeclRef, SILFunction, and derivative function type.
- SILFunction *getOrCreateAutoDiffThunk(SILDeclRef derivativeFnRef,
- SILFunction *derivativeFn,
- CanSILFunctionType derivativeFnTy);
+ /// Get or create an autodiff derivative function forwarding thunk for the
+ /// given derivative SILDeclRef, SILFunction, and function type.
+ /// The thunk simply forwards arguments and returns results: use this when no
+ /// reabstraction or self reordering is necessary.
+ SILFunction *getOrCreateAutoDiffDerivativeForwardingThunk(
+ SILDeclRef derivativeFnRef, SILFunction *derivativeFn,
+ CanSILFunctionType derivativeFnTy);
// SWIFT_ENABLE_TENSORFLOW
/// Get or create an autodiff derivative function vtable entry thunk for the
@@ -180,12 +182,51 @@
CanType dynamicSelfType);
// SWIFT_ENABLE_TENSORFLOW
- /// Get or create a thunk for reabstracting user-defined JVP/VJP functions.
+ /// Get or create an autodiff derivative function thunk performing
+ /// reabstraction and/or self-reordering.
+ ///
+ /// Self-reordering is done for canonicalizing the types of derivative
+ /// functions for instance methods wrt self. We want users to define
+ /// derivatives with the following AST function types:
+ ///
+ /// JVP:
+ /// - Takes `Self` as first parameter.
+ /// - Returns differential taking `Self.Tan` as first parameter.
+ ///
+ /// (Self) -> (T, ...) -> (R, (Self.Tan, T.Tan, ...) -> R.Tan)
+ ///
+ /// VJP:
+ /// - Takes `Self` as first parameter.
+ /// - Returns pullback returning `Self.Tan` as first result.
+ ///
+ /// (Self) -> (T, ...) -> (R, (R.Tan) -> (Self.Tan, T.Tan, ...))
+ ///
+ /// However, the curried `Self` parameter in the AST JVP/VJP function types
+ /// becomes the *last* parameter in the flattened parameter list of their
+ /// lowered SIL function types.
+ ///
+ /// JVP:
+ /// - Takes `Self` as *last* parameter.
+ /// - Returns differential taking `Self.Tan` as *first* parameter.
+ ///
+ /// $(T, ..., Self) -> (R, (Self.Tan, T.Tan, ...) -> R.Tan)
+ ///
+ /// VJP:
+ /// - Takes `Self` as *last* parameter.
+ /// - Returns pullback returning `Self.Tan` as *first* result.
+ ///
+ /// $(T, ..., Self) -> (R, (R.Tan) -> (Self.Tan, T.Tan, ...))
+ ///
+ /// This leads to a parameter ordering inconsistency, and would require the
+ /// Differentiation transform to handle "wrt self instance method derivatives"
+ /// specially. However, canonicalization during SILGen makes the parameter
+ /// ordering uniform for "wrt self instance method derivatives" and simplifies
+ /// the transform rules.
///
/// If `reorderSelf` is true, reorder self so that it appears as:
/// - The last parameter in the returned differential.
/// - The last result in the returned pullback.
- SILFunction *getOrCreateAutoDiffDerivativeFunctionThunk(
+ SILFunction *getOrCreateAutoDiffDerivativeReabstractionThunk(
SILFunction *original, SILAutoDiffIndices &indices,
SILFunction *derivativeFn,
AutoDiffDerivativeFunctionKind derivativeFnKind, bool reorderSelf);
diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp
index da29d3f..148722e 100644
--- a/lib/SILGen/SILGenPoly.cpp
+++ b/lib/SILGen/SILGenPoly.cpp
@@ -3665,7 +3665,7 @@
// SWIFT_ENABLE_TENSORFLOW
SILFunction *
-SILGenModule::getOrCreateAutoDiffDerivativeFunctionThunk(
+SILGenModule::getOrCreateAutoDiffDerivativeReabstractionThunk(
SILFunction *original, SILAutoDiffIndices &indices,
SILFunction *derivativeFn, AutoDiffDerivativeFunctionKind derivativeFnKind,
bool reorderSelf) {
diff --git a/lib/SILGen/SILGenThunk.cpp b/lib/SILGen/SILGenThunk.cpp
index 1060e13..5bb9e18 100644
--- a/lib/SILGen/SILGenThunk.cpp
+++ b/lib/SILGen/SILGenThunk.cpp
@@ -72,9 +72,9 @@
// SWIFT_ENABLE_TENSORFLOW
SILFunction *
-SILGenModule::getOrCreateAutoDiffThunk(SILDeclRef derivativeFnDeclRef,
- SILFunction *derivativeFn,
- CanSILFunctionType derivativeFnTy) {
+SILGenModule::getOrCreateAutoDiffDerivativeForwardingThunk(
+ SILDeclRef derivativeFnDeclRef, SILFunction *derivativeFn,
+ CanSILFunctionType derivativeFnTy) {
auto *autoDiffFuncId =
derivativeFnDeclRef.autoDiffDerivativeFunctionIdentifier;
assert(autoDiffFuncId);