// blob: 0cfdb1a8b52bfa32f2224a3c4abcfc1244d8ba64 [file] [log] [blame]
//===--- Differentiation.cpp - SIL Automatic Differentiation --*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2018 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements automatic differentiation.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "differentiation"
#include "swift/AST/ASTMangler.h"
#include "swift/AST/ASTPrinter.h"
#include "swift/AST/AnyFunctionRef.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/Builtins.h"
#include "swift/AST/DeclContext.h"
#include "swift/AST/DiagnosticsSIL.h"
#include "swift/AST/Expr.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/AST/LazyResolver.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/SourceFile.h"
#include "swift/AST/SubstitutionMap.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/SIL/FormalLinkage.h"
#include "swift/SIL/PrettyStackTrace.h"
#include "swift/SIL/SILBuilder.h"
#include "swift/SIL/TypeSubstCloner.h"
#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
#include "swift/SILOptimizer/Differentiation/ADContext.h"
#include "swift/SILOptimizer/Differentiation/JVPCloner.h"
#include "swift/SILOptimizer/Differentiation/Thunk.h"
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
#include "swift/SILOptimizer/PassManager/Passes.h"
#include "swift/SILOptimizer/PassManager/Transforms.h"
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/BreadthFirstIterator.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/CommandLine.h"
using namespace swift;
using namespace swift::autodiff;
using llvm::DenseMap;
using llvm::SmallDenseMap;
using llvm::SmallDenseSet;
using llvm::SmallMapVector;
using llvm::SmallSet;
/// Command-line flag that enables experimental transposition of
/// `@differentiable(linear)` functions. Initialized to `false`.
static llvm::cl::opt<bool> EnableExperimentalLinearMapTransposition(
"enable-experimental-linear-map-transposition", llvm::cl::init(false));
/// Command-line flag used to disable folding of
/// `differentiable_function_extract` instructions for SIL testing purposes.
/// Initialized to `true`.
static llvm::cl::opt<bool> SkipFoldingDifferentiableFunctionExtraction(
"differentiation-skip-folding-differentiable-function-extraction",
llvm::cl::init(true));
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
/// Given a dumpable value, dumps it to `llvm::dbgs()`.
template <typename T> static inline void debugDump(T &v) {
LLVM_DEBUG(llvm::dbgs() << "\n==== BEGIN DEBUG DUMP ====\n"
<< v << "\n==== END DEBUG DUMP ====\n");
}
namespace {
/// Holds the state shared by the differentiation transform (the owning module
/// transform and the `ADContext`) and declares the entry points used to
/// canonicalize differentiability witnesses and to process
/// `differentiable_function` / `linear_function` instructions.
class DifferentiationTransformer {
private:
/// Reference to the main transform.
SILModuleTransform &transform;
/// Context necessary for performing the transformations. Constructed from
/// `transform`.
ADContext context;
/// Promotes the given `differentiable_function` instruction to a valid
/// `@differentiable` function-typed value.
SILValue promoteToDifferentiableFunction(DifferentiableFunctionInst *inst,
SILBuilder &builder, SILLocation loc,
DifferentiationInvoker invoker);
/// Given a `linear_function` instruction that is missing a transpose operand,
/// return a new `linear_function` instruction with the transpose filled in.
SILValue promoteToLinearFunction(LinearFunctionInst *inst,
SILBuilder &builder, SILLocation loc,
DifferentiationInvoker invoker);
public:
/// Construct a `DifferentiationTransformer` for the given module transform.
explicit DifferentiationTransformer(SILModuleTransform &transform)
: transform(transform), context(transform) {}
/// Returns the module transform this transformer was constructed with.
SILModuleTransform &getTransform() { return transform; }
/// Returns the differentiation context owned by this transformer.
ADContext &getContext() { return context; }
/// Canonicalize the given witness, filling in derivative functions if
/// missing.
///
/// Generated derivative functions have the same linkage as the witness.
///
/// \param serializeFunctions specifies whether generated functions should be
/// serialized.
///
/// A `true` result is treated as failure by
/// `emitDerivativeFunctionReference`.
bool canonicalizeDifferentiabilityWitness(
SILFunction *original, SILDifferentiabilityWitness *witness,
DifferentiationInvoker invoker, IsSerialized_t serializeFunctions);
/// Process the given `differentiable_function` instruction, filling in
/// missing derivative functions if necessary.
// NOTE(review): the meaning of the `bool` result is not visible in this
// declaration; presumably `true` signals an error - confirm at the definition.
bool processDifferentiableFunctionInst(DifferentiableFunctionInst *dfi);
/// Process the given `linear_function` instruction, filling in the missing
/// transpose function if necessary.
// NOTE(review): the meaning of the `bool` result is not visible in this
// declaration; presumably `true` signals an error - confirm at the definition.
bool processLinearFunctionInst(LinearFunctionInst *lfi);
/// Fold `differentiable_function_extract` users of the given
/// `differentiable_function` instruction, directly replacing them with
/// `differentiable_function` instruction operands. If the
/// `differentiable_function` instruction has no remaining uses, delete the
/// instruction itself after folding.
///
/// Folding can be disabled by the
/// `SkipFoldingDifferentiableFunctionExtraction` flag for SIL testing
/// purposes.
void foldDifferentiableFunctionExtraction(DifferentiableFunctionInst *source);
};
} // end anonymous namespace
/// A function without a return block cannot be differentiated. When
/// `original` has no return block, emits a "missing return" error at the end
/// of its source location and returns true; otherwise returns false.
static bool diagnoseNoReturn(ADContext &context, SILFunction *original,
                             DifferentiationInvoker invoker) {
  const bool hasReturnBlock = original->findReturnBB() != original->end();
  if (!hasReturnBlock) {
    context.emitNondifferentiabilityError(
        original->getLocation().getEndSourceLoc(), invoker,
        diag::autodiff_missing_return);
  }
  return !hasReturnBlock;
}
/// Scans the terminators of `original` for unsupported control flow. On the
/// first branching terminator that is not in the supported set, emits a
/// "control flow not supported" error at that terminator and returns true.
/// Returns false if nothing was diagnosed.
///
/// Update the supported set as control flow support is added.
static bool diagnoseUnsupportedControlFlow(ADContext &context,
                                           SILFunction *original,
                                           DifferentiationInvoker invoker) {
  // With at most one basic block there is nothing to check.
  if (original->getBlocks().size() <= 1)
    return false;
  for (auto &block : *original) {
    auto *terminator = block.getTerminator();
    // Branching terminators that differentiation currently supports.
    const bool isSupportedTerminator =
        isa<BranchInst>(terminator) || isa<CondBranchInst>(terminator) ||
        isa<SwitchEnumInst>(terminator) ||
        isa<SwitchEnumAddrInst>(terminator) ||
        isa<CheckedCastBranchInst>(terminator) ||
        isa<CheckedCastValueBranchInst>(terminator) ||
        isa<CheckedCastAddrBranchInst>(terminator) ||
        isa<TryApplyInst>(terminator);
    // Only an unsupported terminator that is also a branch is an error.
    if (isSupportedTerminator || !terminator->isBranch())
      continue;
    context.emitNondifferentiabilityError(
        terminator, invoker, diag::autodiff_control_flow_not_supported);
    return true;
  }
  return false;
}
/// Checks the requirements of the derivative generic signature against the
/// given substitution map. If any requirement is unsatisfied, emits a single
/// error at `loc` listing all unsatisfied requirements and returns true.
/// Returns false when there is nothing to check or everything is satisfied.
static bool diagnoseUnsatisfiedRequirements(ADContext &context,
                                            CanSILFunctionType origFnTy,
                                            GenericSignature derivativeGenSig,
                                            SubstitutionMap substMap,
                                            DifferentiationInvoker invoker,
                                            SourceLoc loc) {
  // Without a derivative generic signature there are no requirements.
  if (!derivativeGenSig)
    return false;
  // If the original function is polymorphic and its generic signature equals
  // the derivative generic signature, the requirements are satisfied. This
  // shortcut is needed because the logic below does not correctly handle
  // polymorphic original functions.
  // TODO(TF-1055): Can be removed after we have a robust solution for TF-1055.
  if (auto origGenSig = origFnTy->getInvocationGenericSignature())
    if (origGenSig->isEqual(derivativeGenSig))
      return false;
  auto requirements = derivativeGenSig->getRequirements();
  if (requirements.empty())
    return false;

  auto *swiftModule = context.getModule().getSwiftModule();
  // Substitutes `type` using `substMap`, looking up conformances in the
  // current module. Falls back to `type` itself if substitution yields a null
  // type.
  auto substituteIfPossible = [&](Type type) -> Type {
    if (auto substituted =
            type.subst(QuerySubstitutionMap{substMap},
                       LookUpConformanceInModule(swiftModule)))
      return substituted;
    return type;
  };

  // Walk all requirements and collect the ones that do not hold.
  SmallVector<Requirement, 2> unsatisfiedRequirements;
  for (auto req : requirements) {
    const auto reqKind = req.getKind();
    auto firstType = substituteIfPossible(req.getFirstType());
    // Layout requirements have no second type.
    Type secondType;
    if (reqKind != RequirementKind::Layout)
      secondType = substituteIfPossible(req.getSecondType());

    bool satisfied = true;
    switch (reqKind) {
    case RequirementKind::Layout: {
      // Only class layout constraints are checked.
      // TODO: Check other layout requirements. Note that `@differentiable`
      // attribute type-checking does not yet support layout requirements in
      // where clauses; layout requirements in derivative generic signatures
      // can be formed only from `differentiable_function` instructions whose
      // original function operand is generic with layout requirements.
      auto layout = req.getLayoutConstraint();
      if (layout->getKind() == LayoutConstraintKind::Class)
        if (!firstType->satisfiesClassConstraint())
          satisfied = false;
      break;
    }
    case RequirementKind::SameType:
      // The two substituted types must be equal.
      if (!firstType->isEqual(secondType))
        satisfied = false;
      break;
    case RequirementKind::Superclass:
      // The second type must be an exact superclass of the first type.
      if (!secondType->isExactSuperclassOf(firstType))
        satisfied = false;
      break;
    case RequirementKind::Conformance: {
      // The first type must conform to the protocol in the current module.
      // The protocol is taken from the unsubstituted requirement.
      auto protocolType = req.getSecondType()->castTo<ProtocolType>();
      auto protocol = protocolType->getDecl();
      assert(protocol && "Expected protocol in generic signature requirement");
      if (!swiftModule->lookupConformance(firstType, protocol))
        satisfied = false;
      break;
    }
    }
    if (!satisfied)
      unsatisfiedRequirements.push_back(req);
  }
  if (unsatisfiedRequirements.empty())
    return false;

  // Render the unsatisfied requirements as a ", "-separated list and report.
  std::string reqText;
  llvm::raw_string_ostream stream(reqText);
  bool isFirst = true;
  for (Requirement req : unsatisfiedRequirements) {
    if (!isFirst)
      stream << ", ";
    isFirst = false;
    req.print(stream, PrintOptions());
  }
  context.emitNondifferentiabilityError(
      loc, invoker, diag::autodiff_function_assoc_func_unmet_requirements,
      stream.str());
  return true;
}
//===----------------------------------------------------------------------===//
// Code emission utilities
//===----------------------------------------------------------------------===//
/// Duplicates the arguments of `applySite` so the apply site can be
/// re-created. One value per argument is appended to `copiedArgs`, in
/// argument order. Values passed with a guaranteed convention (other than
/// aliasable inout) are also appended to `newArgsToDestroy`. Every
/// `alloc_stack` created here to hold a copied address argument is appended
/// to `newBuffersToDealloc`.
static void copyParameterArgumentsForApply(
    ApplySite applySite, SmallVectorImpl<SILValue> &copiedArgs,
    SmallVectorImpl<SILValue> &newArgsToDestroy,
    SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
  LLVM_DEBUG({
    auto &s = getADDebugStream() << "Copying arguments from apply site: ";
    applySite.getInstruction()->print(s);
  });
  auto loc = applySite.getLoc();
  copiedArgs.reserve(applySite.getNumArguments());
  // All copies are emitted at the apply site instruction.
  SILBuilderWithScope copyBuilder(applySite.getInstruction());
  for (auto &operand : applySite.getArgumentOperands()) {
    auto originalArg = operand.get();
    auto convention = applySite.getArgumentConvention(operand);
    const bool isInoutAliasable =
        convention == SILArgumentConvention::Indirect_InoutAliasable;

    auto copiedArg = originalArg;
    if (originalArg->getType().isObject()) {
      // Objects are copied (retained) so the new closure can own them, unless
      // ownership is enabled and the value has no ownership.
      if (!copyBuilder.hasOwnership() ||
          originalArg.getOwnershipKind() != OwnershipKind::None)
        copiedArg = copyBuilder.emitCopyValueOperation(loc, originalArg);
    } else if (!isInoutAliasable) {
      // Any other address argument is copied into a fresh stack buffer with
      // `copy_addr`.
      auto *buffer = copyBuilder.createAllocStack(loc, originalArg->getType());
      newBuffersToDealloc.push_back(buffer);
      copyBuilder.createCopyAddr(loc, originalArg, buffer, IsNotTake,
                                 IsInitialization);
      copiedArg = buffer;
    }
    // Otherwise the argument is an aliasable inout reference: it is a
    // `@noescape` capture and is forwarded without copying.

    copiedArgs.push_back(copiedArg);
    if (convention.isGuaranteedConvention() && !isInoutAliasable)
      newArgsToDestroy.push_back(copiedArg);
  }
}
/// A function value may reach its use (usually an `apply`) through a chain of
/// conversion instructions, e.g. `thin_to_thick_function`. Given the old
/// function `oldFunc`, its converted form `oldConvertedFunc`, and a
/// replacement `newFunc`, this helper recursively rebuilds the same chain of
/// conversions on top of `newFunc` and returns the result.
///
/// `copy_value`, `begin_borrow` and `convert_function` are looked through
/// without being re-emitted; `thin_to_thick_function` and `partial_apply` are
/// re-emitted with `builder` at `loc`. Stack buffers allocated while copying
/// `partial_apply` arguments are appended to `newBuffersToDealloc`.
///
/// If `newFuncGenSig` is specified, it is used to create substitution maps
/// for reapplied `partial_apply` instructions.
static SILValue reapplyFunctionConversion(
    ADContext &context, SILValue newFunc, SILValue oldFunc,
    SILValue oldConvertedFunc, SILBuilder &builder, SILLocation loc,
    SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc,
    IndexSubset *parameterIndices, IndexSubset *resultIndices,
    GenericSignature newFuncGenSig = GenericSignature()) {
  // Base case: no conversion sits between the old function and its use.
  if (oldFunc == oldConvertedFunc)
    return newFunc;

  // Recurses on the operand of a conversion, forwarding everything else.
  auto reapplyOnOperand = [&](SILValue innerOldConvertedFunc) -> SILValue {
    return reapplyFunctionConversion(
        context, newFunc, oldFunc, innerOldConvertedFunc, builder, loc,
        newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig);
  };

  // copy_value: skipped. The caller of `reapplyFunctionConversion` should
  // consume the re-converted function, so no new `copy_value` is needed.
  if (auto *cvi = dyn_cast<CopyValueInst>(oldConvertedFunc))
    return reapplyOnOperand(cvi->getOperand());

  // begin_borrow: skipped for the same reason as `copy_value`.
  if (auto *bbi = dyn_cast<BeginBorrowInst>(oldConvertedFunc))
    return reapplyOnOperand(bbi->getOperand());

  // convert_function: looked through without re-emitting a conversion.
  if (auto *cfi = dyn_cast<ConvertFunctionInst>(oldConvertedFunc))
    return reapplyOnOperand(cfi->getOperand());

  // thin_to_thick_function: re-emit on the re-converted operand, using the
  // operand's function type with a thick representation.
  if (auto *tttfi = dyn_cast<ThinToThickFunctionInst>(oldConvertedFunc)) {
    auto innerNewFunc = reapplyOnOperand(tttfi->getOperand());
    auto thickTy = innerNewFunc->getType()
                       .castTo<SILFunctionType>()
                       ->getWithRepresentation(
                           SILFunctionTypeRepresentation::Thick);
    return builder.createThinToThickFunction(
        loc, innerNewFunc, SILType::getPrimitiveObjectType(thickTy));
  }

  // partial_apply: copy the captured arguments, then re-emit on the
  // re-converted callee.
  if (auto *pai = dyn_cast<PartialApplyInst>(oldConvertedFunc)) {
    SmallVector<SILValue, 8> newArgs;
    newArgs.reserve(pai->getNumArguments());
    SmallVector<SILValue, 1> newArgsToDestroy;
    copyParameterArgumentsForApply(pai, newArgs, newArgsToDestroy,
                                   newBuffersToDealloc);
    auto innerNewFunc = reapplyOnOperand(pai->getCallee());

    // Reabstraction thunk `partial_apply` reapplications require special
    // support. Reabstraction thunk JVP/VJP expects a `@differentiable`
    // function-typed argument to avoid opaque function non-differentiability
    // errors. Thus, `partial_apply` reapplications must first form a
    // `differentiable_function` of the function-typed thunk argument.
    bool calleeIsReabstractionThunk = false;
    if (auto *fri = dyn_cast<FunctionRefInst>(oldFunc))
      calleeIsReabstractionThunk =
          fri->getReferencedFunctionOrNull()->isThunk() ==
          IsReabstractionThunk;
    if (calleeIsReabstractionThunk) {
      assert(newArgs.size() == 1 &&
             "Expected reabstraction thunk to be partially applied with only "
             "one argument");
      auto *dfi = context.createDifferentiableFunction(
          builder, loc, parameterIndices, resultIndices, newArgs.back());
      context.getDifferentiableFunctionInstWorklist().push_back(dfi);
      newArgs.back() = dfi;
    }

    // Compute the substitution map for reapplying `partial_apply`.
    // - If the reapplied function is not polymorphic, use an empty
    //   substitution map regardless of the original `partial_apply`'s
    //   substitution map.
    //   - This case is triggered for reapplying `partial_apply` where
    //     `newFunc` is a `differentiability_witness_function` where the
    //     witness generic signature has all concrete parameters while the
    //     original function's generic signature does not. In this case, the
    //     original function type is polymorphic while derivative function
    //     types are not (specialized with concrete types from same-type
    //     requirements).
    // - Otherwise, if `newFuncGenSig` is specified, combine it with the
    //   original `partial_apply`'s substitution map.
    // - Otherwise, use the original `partial_apply`'s substitution map.
    SubstitutionMap substMap;
    const bool isPolymorphic =
        innerNewFunc->getType().castTo<SILFunctionType>()->isPolymorphic();
    if (isPolymorphic) {
      if (newFuncGenSig) {
        substMap = SubstitutionMap::get(
            newFuncGenSig, QuerySubstitutionMap{pai->getSubstitutionMap()},
            LookUpConformanceInModule(builder.getModule().getSwiftModule()));
      } else {
        substMap = pai->getSubstitutionMap();
      }
    }
    return builder.createPartialApply(loc, innerNewFunc, substMap, newArgs,
                                      ParameterConvention::Direct_Guaranteed);
  }

  llvm_unreachable("Unhandled function conversion instruction");
}
/// Emits a reference to a derivative function of `original`, differentiated
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
/// the derivative function and the actual indices that the derivative function
/// is with respect to.
///
/// Four kinds of `original` are handled, in this order:
/// - a `@differentiable` function value: the derivative is extracted from it;
/// - a `function_ref` (seen through function conversions): a differentiability
///   witness is looked up, or created and canonicalized, and referenced via
///   `differentiability_witness_function`;
/// - a `witness_method`: a `witness_method` for the derivative is emitted;
/// - a `class_method`: a `class_method` for the derivative is emitted.
/// Anything else is diagnosed as an opaque non-differentiable function.
///
/// \param transformer the transformer used to canonicalize newly created
///        differentiability witnesses.
/// \param builder the builder used to emit all new instructions.
/// \param desiredConfig the requested parameter and result indices.
/// \param kind whether a JVP or a VJP is requested.
/// \param original the original function value.
/// \param invoker the invoker used for emitting diagnostics.
/// \param newBuffersToDealloc receives stack buffers allocated while
///        reapplying function conversions.
///
/// Returns `None` on failure, signifying that a diagnostic has been emitted
/// using `invoker`.
static Optional<std::pair<SILValue, AutoDiffConfig>>
emitDerivativeFunctionReference(
    DifferentiationTransformer &transformer, SILBuilder &builder,
    AutoDiffConfig desiredConfig, AutoDiffDerivativeFunctionKind kind,
    SILValue original, DifferentiationInvoker invoker,
    SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
  ADContext &context = transformer.getContext();
  // If `original` is itself a `differentiable_function_extract` of the
  // original function component, look through it and continue with its
  // `@differentiable` function operand.
  if (auto *inst = original->getDefiningInstruction())
    if (auto *dfei = dyn_cast<DifferentiableFunctionExtractInst>(inst))
      if (dfei->getExtractee() ==
          NormalDifferentiableFunctionTypeComponent::Original)
        original = dfei->getOperand();
  // If `original` is a `@differentiable` function, just extract the
  // derivative function.
  // NOTE(review): `castTo` is expected to assert rather than return null, so
  // the outer condition presumably always holds - confirm.
  if (auto diffableFnType = original->getType().castTo<SILFunctionType>()) {
    if (diffableFnType->isDifferentiable()) {
      // Every desired parameter index must be a differentiability parameter
      // of the `@differentiable` function type.
      auto paramIndices =
          diffableFnType->getDifferentiabilityParameterIndices();
      for (auto i : desiredConfig.parameterIndices->getIndices()) {
        if (!paramIndices->contains(i)) {
          context.emitNondifferentiabilityError(
              original, invoker,
              diag::
                  autodiff_function_noderivative_parameter_not_differentiable);
          return None;
        }
      }
      // Borrow the `@differentiable` function, extract the derivative, and
      // copy it (unless it has no ownership) before ending the borrow.
      auto borrowedDiffFunc =
          builder.emitBeginBorrowOperation(original.getLoc(), original);
      SILValue derivativeFn = builder.createDifferentiableFunctionExtract(
          borrowedDiffFunc.getLoc(), kind, borrowedDiffFunc);
      if (!builder.hasOwnership() ||
          derivativeFn.getOwnershipKind() != OwnershipKind::None)
        derivativeFn =
            builder.emitCopyValueOperation(original.getLoc(), derivativeFn);
      builder.emitEndBorrowOperation(original.getLoc(), borrowedDiffFunc);
      return std::make_pair(derivativeFn, desiredConfig);
    }
  }
  // Handle `function_ref` original function.
  if (auto *originalFRI =
          peerThroughFunctionConversions<FunctionRefInst>(original)) {
    auto loc = originalFRI->getLoc();
    auto *originalFn = originalFRI->getReferencedFunctionOrNull();
    assert(originalFn);
    auto originalFnTy = originalFn->getLoweredFunctionType();
    auto *desiredParameterIndices = desiredConfig.parameterIndices;
    auto *desiredResultIndices = desiredConfig.resultIndices;
    // NOTE(TF-893): Extending capacity is necessary when `originalFnTy` has
    // parameters corresponding to captured variables.
    // TODO: If possible, change `autodiff::getLoweredParameterIndices` to
    // take `CaptureInfo` into account.
    if (originalFnTy->getNumParameters() >
        desiredParameterIndices->getCapacity()) {
      desiredParameterIndices = desiredParameterIndices->extendingCapacity(
          context.getASTContext(), originalFnTy->getNumParameters());
    }
    // Look up a differentiability witness with the exact configuration.
    auto *minimalWitness = getExactDifferentiabilityWitness(
        context.getModule(), originalFn, desiredParameterIndices,
        desiredResultIndices);
    // Otherwise, look up a differentiability witness with a minimal superset
    // configuration.
    if (!minimalWitness)
      minimalWitness = getOrCreateMinimalASTDifferentiabilityWitness(
          context.getModule(), originalFn, desiredParameterIndices,
          desiredResultIndices);
    // If no minimal witness exists, check non-differentiable cases before
    // creating a new private differentiability witness.
    if (!minimalWitness) {
      // If the function is intentionally marked as being opaque to
      // differentiation, then we should not create a task for it.
      if (originalFn->hasSemanticsAttr("autodiff.opaque")) {
        context.emitNondifferentiabilityError(
            original, invoker,
            diag::autodiff_opaque_function_not_differentiable);
        return None;
      }
      // Check and diagnose non-differentiable arguments.
      for (unsigned paramIndex : range(originalFnTy->getNumParameters())) {
        if (desiredConfig.isWrtParameter(paramIndex) &&
            !originalFnTy->getParameters()[paramIndex]
                 .getSILStorageInterfaceType()
                 .isDifferentiable(context.getModule())) {
          context.emitNondifferentiabilityError(
              original, invoker, diag::autodiff_nondifferentiable_argument);
          return None;
        }
      }
      // Check and diagnose non-differentiable results.
      for (auto resultIndex : desiredResultIndices->getIndices()) {
        SILType resultType;
        if (resultIndex >= originalFnTy->getNumResults()) {
          // Result indices past the formal results refer to indirect
          // mutating parameters, counted from zero.
          auto inoutParamIdx = resultIndex - originalFnTy->getNumResults();
          auto inoutParam =
              *std::next(originalFnTy->getIndirectMutatingParameters().begin(),
                         inoutParamIdx);
          resultType = inoutParam.getSILStorageInterfaceType();
        } else {
          resultType = originalFnTy->getResults()[resultIndex]
                           .getSILStorageInterfaceType();
        }
        if (!resultType.isDifferentiable(context.getModule())) {
          context.emitNondifferentiabilityError(
              original, invoker, diag::autodiff_nondifferentiable_result);
          return None;
        }
      }
      // Check and diagnose external declarations.
      if (originalFn->isExternalDeclaration()) {
        context.emitNondifferentiabilityError(
            original, invoker,
            diag::autodiff_external_nondifferentiable_function);
        return None;
      }
      // Sanity check passed. Create a new differentiability witness and
      // canonicalize it.
      GenericSignature contextualDerivativeGenSig = GenericSignature();
      if (invoker.getKind() ==
          DifferentiationInvoker::Kind::IndirectDifferentiation)
        contextualDerivativeGenSig =
            invoker.getIndirectDifferentiation()
                .second->getDerivativeGenericSignature();
      auto derivativeConstrainedGenSig =
          autodiff::getConstrainedDerivativeGenericSignature(
              originalFn->getLoweredFunctionType(), desiredParameterIndices,
              contextualDerivativeGenSig,
              LookUpConformanceInModule(context.getModule().getSwiftModule()));
      minimalWitness = SILDifferentiabilityWitness::createDefinition(
          context.getModule(), SILLinkage::Private, originalFn,
          desiredParameterIndices, desiredResultIndices,
          derivativeConstrainedGenSig, /*jvp*/ nullptr,
          /*vjp*/ nullptr, /*isSerialized*/ false);
      // A `true` result means canonicalization failed.
      if (transformer.canonicalizeDifferentiabilityWitness(
              originalFn, minimalWitness, invoker, IsNotSerialized))
        return None;
    }
    assert(minimalWitness);
    // A serialized function cannot reference a derivative whose witness does
    // not have public visibility.
    if (original->getFunction()->isSerialized() &&
        !hasPublicVisibility(minimalWitness->getLinkage())) {
      enum { Inlinable = 0, DefaultArgument = 1 };
      unsigned fragileKind = Inlinable;
      // FIXME: This is not a very robust way of determining if the function is
      // a default argument. Also, we have not exhaustively listed all the kinds
      // of fragility.
      if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
        fragileKind = DefaultArgument;
      context.emitNondifferentiabilityError(
          original, invoker, diag::autodiff_private_derivative_from_fragile,
          fragileKind,
          llvm::isa_and_nonnull<AbstractClosureExpr>(
              originalFRI->getLoc().getAsASTNode<Expr>()));
      return None;
    }
    // TODO(TF-482): Move generic requirement checking logic to
    // `getExactDifferentiabilityWitness` and
    // `getOrCreateMinimalASTDifferentiabilityWitness`.
    // Get the substitution map for checking unmet generic requirements.
    // By default, use the forwarding substitution map of the original function.
    // If the original callee is a `partial_apply` or `apply` instruction, use
    // its substitution map instead.
    auto substMap = original->getFunction()->getForwardingSubstitutionMap();
    if (auto *pai =
            peerThroughFunctionConversions<PartialApplyInst>(original)) {
      substMap = pai->getSubstitutionMap();
    } else if (auto *ai = peerThroughFunctionConversions<ApplyInst>(original)) {
      substMap = ai->getSubstitutionMap();
    }
    if (diagnoseUnsatisfiedRequirements(
            context, original->getType().castTo<SILFunctionType>(),
            minimalWitness->getDerivativeGenericSignature(), substMap, invoker,
            original.getLoc().getSourceLoc()))
      return None;
    // Map the derivative function kind to the witness function kind.
    DifferentiabilityWitnessFunctionKind witnessKind;
    switch (kind) {
    case AutoDiffDerivativeFunctionKind::JVP:
      witnessKind = DifferentiabilityWitnessFunctionKind::JVP;
      break;
    case AutoDiffDerivativeFunctionKind::VJP:
      witnessKind = DifferentiabilityWitnessFunctionKind::VJP;
      break;
    }
    // Reference the derivative through the witness, then rebuild the function
    // conversions that were applied to the original `function_ref`.
    auto *derivativeFnRef = builder.createDifferentiabilityWitnessFunction(
        loc, witnessKind, minimalWitness);
    auto convertedRef = reapplyFunctionConversion(
        context, derivativeFnRef, originalFRI, original, builder, loc,
        newBuffersToDealloc, desiredConfig.parameterIndices,
        desiredConfig.resultIndices,
        derivativeFnRef->getType()
            .getASTType()
            ->castTo<SILFunctionType>()
            ->getSubstGenericSignature());
    return std::make_pair(convertedRef, minimalWitness->getConfig());
  }
  // Handle `witness_method`.
  if (auto *witnessMethod =
          peerThroughFunctionConversions<WitnessMethodInst>(original)) {
    auto loc = witnessMethod->getLoc();
    auto requirementDeclRef = witnessMethod->getMember();
    auto *requirementDecl = requirementDeclRef.getAbstractFunctionDecl();
    // If requirement declaration does not have any derivative function
    // configurations, produce an error.
    if (requirementDecl->getDerivativeFunctionConfigurations().empty()) {
      context.emitNondifferentiabilityError(
          original, invoker, diag::autodiff_protocol_member_not_differentiable);
      return None;
    }
    // Find the minimal derivative configuration: minimal parameter indices and
    // corresponding derivative generic signature. If it does not exist, produce
    // an error.
    IndexSubset *minimalASTParamIndices = nullptr;
    auto minimalConfig = findMinimalDerivativeConfiguration(
        requirementDecl, desiredConfig.parameterIndices,
        minimalASTParamIndices);
    if (!minimalConfig) {
      context.emitNondifferentiabilityError(
          original, invoker,
          diag::autodiff_member_subset_indices_not_differentiable);
      return None;
    }
    // Emit a `witness_method` instruction for the derivative function.
    auto originalType = witnessMethod->getType().castTo<SILFunctionType>();
    auto assocType = originalType->getAutoDiffDerivativeFunctionType(
        minimalConfig->parameterIndices, minimalConfig->resultIndices, kind,
        context.getTypeConverter(),
        LookUpConformanceInModule(builder.getModule().getSwiftModule()));
    auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
        kind, minimalASTParamIndices, minimalConfig->derivativeGenericSignature,
        context.getASTContext());
    auto *ref = builder.createWitnessMethod(
        loc, witnessMethod->getLookupType(), witnessMethod->getConformance(),
        requirementDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),
        SILType::getPrimitiveObjectType(assocType));
    auto convertedRef = reapplyFunctionConversion(
        context, ref, witnessMethod, original, builder, loc,
        newBuffersToDealloc, desiredConfig.parameterIndices,
        desiredConfig.resultIndices);
    return std::make_pair(convertedRef, *minimalConfig);
  }
  // Handle `class_method`.
  if (auto *classMethod =
          peerThroughFunctionConversions<ClassMethodInst>(original)) {
    auto loc = classMethod->getLoc();
    auto methodDeclRef = classMethod->getMember();
    auto *methodDecl = methodDeclRef.getAbstractFunctionDecl();
    // If method declaration does not have any derivative function
    // configurations, produce an error.
    if (methodDecl->getDerivativeFunctionConfigurations().empty()) {
      context.emitNondifferentiabilityError(
          original, invoker, diag::autodiff_class_member_not_differentiable);
      return None;
    }
    // Find the minimal derivative configuration: minimal parameter indices and
    // corresponding derivative generic signature. If it does not exist, produce
    // an error.
    IndexSubset *minimalASTParamIndices = nullptr;
    auto minimalConfig = findMinimalDerivativeConfiguration(
        methodDecl, desiredConfig.parameterIndices, minimalASTParamIndices);
    if (!minimalConfig) {
      context.emitNondifferentiabilityError(
          original, invoker,
          diag::autodiff_member_subset_indices_not_differentiable);
      return None;
    }
    // Emit a `class_method` instruction for the derivative function.
    auto originalType = classMethod->getType().castTo<SILFunctionType>();
    auto assocType = originalType->getAutoDiffDerivativeFunctionType(
        minimalConfig->parameterIndices, minimalConfig->resultIndices, kind,
        context.getTypeConverter(),
        LookUpConformanceInModule(builder.getModule().getSwiftModule()));
    auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
        kind, minimalASTParamIndices, minimalConfig->derivativeGenericSignature,
        context.getASTContext());
    auto *ref = builder.createClassMethod(
        loc, classMethod->getOperand(),
        methodDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),
        SILType::getPrimitiveObjectType(assocType));
    auto convertedRef = reapplyFunctionConversion(
        context, ref, classMethod, original, builder, loc, newBuffersToDealloc,
        desiredConfig.parameterIndices, desiredConfig.resultIndices);
    return std::make_pair(convertedRef, *minimalConfig);
  }
  // Emit the general opaque function error.
  context.emitNondifferentiabilityError(
      original, invoker, diag::autodiff_opaque_function_not_differentiable);
  return None;
}
//===----------------------------------------------------------------------===//
// `SILDifferentiabilityWitness` processing
//===----------------------------------------------------------------------===//
/// Creates the VJP function for `original` according to the configuration of
/// `witness`. The function gets a mangled derivative name, the witness's
/// linkage, the given serialization kind, and a debug scope at the original
/// function's location. No body is emitted here.
static SILFunction *createEmptyVJP(ADContext &context, SILFunction *original,
                                   SILDifferentiabilityWitness *witness,
                                   IsSerialized_t isSerialized) {
  LLVM_DEBUG(getADDebugStream()
             << "Creating VJP:\n\t"
             << "Original type: " << original->getLoweredFunctionType()
             << "\n\t");
  auto &module = context.getModule();
  auto originalTy = original->getLoweredFunctionType();
  auto config = witness->getConfig();
  const auto vjpKind = AutoDiffDerivativeFunctionKind::VJP;

  // Mangle the VJP name and intern it in the AST context.
  Mangle::ASTMangler mangler;
  auto mangledName = mangler.mangleAutoDiffDerivativeFunctionHelper(
      original->getName(), vjpKind, config);
  auto vjpName = original->getASTContext().getIdentifier(mangledName).str();

  // The VJP uses the witness's derivative generic signature, if any. A
  // generic environment is only needed when not all generic parameters are
  // concrete.
  CanGenericSignature vjpCanGenSig;
  GenericEnvironment *vjpGenericEnv = nullptr;
  if (auto vjpGenSig = witness->getDerivativeGenericSignature()) {
    vjpCanGenSig = vjpGenSig->getCanonicalSignature();
    if (vjpCanGenSig && !vjpCanGenSig->areAllParamsConcrete())
      vjpGenericEnv = vjpCanGenSig->getGenericEnvironment();
  }

  // Compute the VJP function type from the original function type.
  const bool isReabstractionThunk =
      original->isThunk() == IsReabstractionThunk;
  auto vjpType = originalTy->getAutoDiffDerivativeFunctionType(
      config.parameterIndices, config.resultIndices, vjpKind, module.Types,
      LookUpConformanceInModule(module.getSwiftModule()), vjpCanGenSig,
      /*isReabstractionThunk*/ isReabstractionThunk);

  // Create the function and give it a debug scope.
  SILOptFunctionBuilder fb(context.getTransform());
  auto *vjp = fb.createFunction(
      witness->getLinkage(), vjpName, vjpType, vjpGenericEnv,
      original->getLocation(), original->isBare(), IsNotTransparent,
      isSerialized, original->isDynamicallyReplaceable());
  vjp->setDebugScope(new (module) SILDebugScope(original->getLocation(), vjp));
  LLVM_DEBUG(llvm::dbgs() << "VJP type: " << vjp->getLoweredFunctionType()
                          << "\n");
  return vjp;
}
/// Creates and returns a new JVP function for `original` without emitting a
/// body for it.
///
/// The name is mangled from the original function's name, the JVP kind and the
/// configuration of `witness`. The type is the original's lowered function
/// type converted to the JVP derivative function type under the witness's
/// canonical derivative generic signature (if any). Linkage is taken from
/// `witness`; serialization is taken from `isSerialized`.
static SILFunction *createEmptyJVP(ADContext &context, SILFunction *original,
                                   SILDifferentiabilityWitness *witness,
                                   IsSerialized_t isSerialized) {
  LLVM_DEBUG({
    auto &s = getADDebugStream();
    s << "Creating JVP:\n\t";
    s << "Original type: " << original->getLoweredFunctionType() << "\n\t";
  });

  auto &module = context.getModule();
  auto loweredOriginalTy = original->getLoweredFunctionType();
  auto witnessConfig = witness->getConfig();

  // Mangle the JVP name and turn it into an identifier of the original
  // function's AST context.
  Mangle::ASTMangler mangler;
  auto mangledName = mangler.mangleAutoDiffDerivativeFunctionHelper(
      original->getName(), AutoDiffDerivativeFunctionKind::JVP, witnessConfig);
  auto jvpName = original->getASTContext().getIdentifier(mangledName).str();

  // Canonical derivative generic signature; stays empty when the witness has
  // no derivative generic signature.
  CanGenericSignature canGenSig;
  if (auto derivativeGenSig = witness->getDerivativeGenericSignature())
    canGenSig = derivativeGenSig->getCanonicalSignature();

  // Only request a generic environment when a signature exists and not all of
  // its parameters are concrete.
  GenericEnvironment *genericEnv =
      (canGenSig && !canGenSig->areAllParamsConcrete())
          ? canGenSig->getGenericEnvironment()
          : nullptr;

  const bool isReabstractionThunk =
      original->isThunk() == IsReabstractionThunk;
  auto jvpType = loweredOriginalTy->getAutoDiffDerivativeFunctionType(
      witnessConfig.parameterIndices, witnessConfig.resultIndices,
      AutoDiffDerivativeFunctionKind::JVP, module.Types,
      LookUpConformanceInModule(module.getSwiftModule()), canGenSig,
      /*isReabstractionThunk*/ isReabstractionThunk);

  SILOptFunctionBuilder functionBuilder(context.getTransform());
  auto *jvp = functionBuilder.createFunction(
      witness->getLinkage(), jvpName, jvpType, genericEnv,
      original->getLocation(), original->isBare(), IsNotTransparent,
      isSerialized, original->isDynamicallyReplaceable());
  jvp->setDebugScope(new (module) SILDebugScope(original->getLocation(), jvp));

  LLVM_DEBUG(llvm::dbgs() << "JVP type: " << jvp->getLoweredFunctionType()
                          << "\n");
  return jvp;
}
/// Gives `f` a body consisting of a single entry block that applies the fatal
/// error function named `fatalErrorFuncName` and ends in `unreachable`.
///
/// The callee is looked up (or created) with type
/// `@convention(thin) () -> Never`.
static void emitFatalError(ADContext &context, SILFunction *f,
                           StringRef fatalErrorFuncName) {
  auto *entryBlock = f->createBasicBlock();
  createEntryArguments(f);
  SILBuilder builder(entryBlock);
  auto loc = f->getLocation();

  // Destroy all owned arguments to pass ownership verification. When the
  // builder reports no ownership, every argument is destroyed.
  for (auto *argument : entryBlock->getArguments()) {
    const bool needsDestroy =
        !builder.hasOwnership() ||
        argument->getOwnershipKind() == OwnershipKind::Owned;
    if (needsDestroy)
      builder.emitDestroyOperation(loc, argument);
  }

  // Build the callee type `@convention(thin) () -> Never`: no parameters, no
  // yields, no error result, and a single unowned `Never` result.
  SILResultInfo neverResult(context.getModule().getASTContext().getNeverType(),
                            ResultConvention::Unowned);
  auto calleeType = SILFunctionType::get(
      /*genericSig*/ nullptr, SILFunctionType::ExtInfo::getThin(),
      SILCoroutineKind::None, ParameterConvention::Direct_Unowned, {},
      /*interfaceYields*/ {}, neverResult,
      /*interfaceErrorResults*/ None, {}, {}, context.getASTContext());

  // Find or declare the fatal error function, then call it with a nice
  // message baked into the callee itself.
  SILOptFunctionBuilder functionBuilder(context.getTransform());
  auto *callee = functionBuilder.getOrCreateFunction(
      loc, fatalErrorFuncName, SILLinkage::PublicExternal, calleeType,
      IsNotBare, IsNotTransparent, IsNotSerialized, IsNotDynamic,
      ProfileCounter(), IsNotThunk);
  auto *calleeRef = builder.createFunctionRef(loc, callee);
  builder.createApply(loc, calleeRef, SubstitutionMap(), {});
  builder.createUnreachable(loc);
}
/// Ensures that `witness` (which must be a definition) has both a JVP and a
/// VJP, synthesizing whichever is missing for `original`.
///
/// Newly created derivative functions are serialized according to
/// `serializeFunctions` and are recorded in the context as generated
/// functions. Returns true on error.
bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
    SILFunction *original, SILDifferentiabilityWitness *witness,
    DifferentiationInvoker invoker, IsSerialized_t serializeFunctions) {
  // Describe the witness being processed for crash stack traces.
  std::string traceMessage;
  llvm::raw_string_ostream traceStream(traceMessage);
  traceStream << "processing ";
  witness->print(traceStream);
  traceStream << " on";
  traceStream.flush();
  PrettyStackTraceSILFunction trace(traceMessage.c_str(), original);

  assert(witness->isDefinition());

  const bool forwardModeEnabled =
      context.getASTContext()
          .LangOpts.EnableExperimentalForwardModeDifferentiation;

  // If the JVP doesn't exist, need to synthesize it.
  if (!witness->getJVP()) {
    // With forward mode enabled, diagnose:
    // - Functions with no return.
    // - Functions with unsupported control flow.
    if (forwardModeEnabled) {
      if (diagnoseNoReturn(context, original, invoker) ||
          diagnoseUnsupportedControlFlow(context, original, invoker))
        return true;
    }

    // Create empty JVP.
    auto *jvp = createEmptyJVP(context, original, witness, serializeFunctions);
    witness->setJVP(jvp);
    context.recordGeneratedFunction(jvp);

    // For now, only do JVP generation if the flag is enabled and if custom VJP
    // does not exist. If custom VJP exists but custom JVP does not, skip JVP
    // generation because generated JVP may not match semantics of custom VJP.
    // Instead, create an empty JVP.
    const bool generateJVPBody = forwardModeEnabled && !witness->getVJP();
    if (!generateJVPBody) {
      // If JVP generation is disabled or a user-defined custom VJP function
      // exists, fatal error with a nice message.
      emitFatalError(context, jvp,
                     "_fatalErrorForwardModeDifferentiationDisabled");
      LLVM_DEBUG(getADDebugStream()
                 << "Generated empty JVP for " << original->getName() << ":\n"
                 << *jvp);
    } else {
      // JVP and differential generation do not currently support functions with
      // multiple basic blocks.
      if (original->getBlocks().size() > 1) {
        context.emitNondifferentiabilityError(
            original->getLocation().getSourceLoc(), invoker,
            diag::autodiff_jvp_control_flow_not_supported);
        return true;
      }
      // Emit JVP function.
      JVPCloner jvpCloner(context, original, witness, jvp, invoker);
      if (jvpCloner.run())
        return true;
    }
  }

  // A VJP that already exists needs no further work.
  if (witness->getVJP())
    return false;

  // Otherwise synthesize the VJP. Diagnose:
  // - Functions with no return.
  // - Functions with unsupported control flow.
  if (diagnoseNoReturn(context, original, invoker) ||
      diagnoseUnsupportedControlFlow(context, original, invoker))
    return true;

  // Create empty VJP.
  auto *vjp = createEmptyVJP(context, original, witness, serializeFunctions);
  witness->setVJP(vjp);
  context.recordGeneratedFunction(vjp);

  // Emit VJP function; the cloner's result is this function's error status.
  VJPCloner vjpCloner(context, original, witness, vjp, invoker);
  return vjpCloner.run();
}
//===----------------------------------------------------------------------===//
// Differentiation pass implementation
//===----------------------------------------------------------------------===//
namespace {
/// The automatic differentiation pass: a module transform whose `run` entry
/// point is defined out of line.
class Differentiation : public SILModuleTransform {
public:
  Differentiation() = default;

  /// Runs differentiation over the whole module.
  void run() override;
};
} // end anonymous namespace
/// Given a curry thunk application, clone the thunk to return a
/// `@differentiable` function-typed value and apply the cloned thunk.
///
/// Curry thunk type: `(Self) -> (T, ...) -> U`.
/// Cloned thunk type: `(Self) -> @differentiable (T, ...) -> U`.
///
/// Returns a null value when the original function operand of `dfi` is not an
/// `apply` of a `function_ref` whose callee has a function-typed single
/// result, or when processing the `differentiable_function` instruction
/// created inside a newly cloned thunk reports an error. Otherwise returns a
/// new `apply` of the cloned thunk, emitted through `builder` at `loc`.
///
/// NOTE(review): `invoker` is not referenced anywhere in this function body.
static SILValue promoteCurryThunkApplicationToDifferentiableFunction(
DifferentiationTransformer &dt, DifferentiableFunctionInst *dfi,
SILBuilder &builder, SILLocation loc, DifferentiationInvoker invoker) {
auto origFnOperand = dfi->getOriginalFunction();
auto *parameterIndices = dfi->getParameterIndices();
auto *resultIndices = dfi->getResultIndices();
auto &context = dt.getContext();
// Check for curry thunk application:
// - The original function operand must be an `apply` instruction.
// - The `apply` callee must be a `function_ref` instruction.
// - The callee must return a function-typed value.
auto *ai = dyn_cast<ApplyInst>(origFnOperand);
if (!ai)
return nullptr;
auto *thunkRef = dyn_cast<FunctionRefInst>(ai->getCallee());
if (!thunkRef)
return nullptr;
// NOTE(review): `thunk` is dereferenced below without a null check even
// though it comes from an `...OrNull` accessor — confirm it cannot be null
// for a `function_ref` here.
auto *thunk = thunkRef->getReferencedFunctionOrNull();
auto thunkTy = thunk->getLoweredFunctionType();
auto thunkResult = thunkTy->getSingleResult();
auto resultFnTy = thunkResult.getInterfaceType()->getAs<SILFunctionType>();
if (!resultFnTy)
return nullptr;
// Create a new curry thunk.
AutoDiffConfig desiredConfig(parameterIndices, resultIndices);
// TODO(TF-685): Use more principled mangling for thunks.
auto newThunkName = "AD__" + thunk->getName().str() +
"__differentiable_curry_thunk_" + desiredConfig.mangle();
// Construct new curry thunk type with `@differentiable` function
// result.
auto diffResultFnTy = resultFnTy->getWithExtInfo(
resultFnTy->getExtInfo()
.intoBuilder()
.withDifferentiabilityKind(DifferentiabilityKind::Normal)
.build());
auto newThunkResult = thunkResult.getWithInterfaceType(diffResultFnTy);
// The new thunk type copies every component of the old thunk type except the
// result, which is replaced by the `@differentiable` one; yields and the
// error result are passed as empty.
auto thunkType = SILFunctionType::get(
thunkTy->getSubstGenericSignature(), thunkTy->getExtInfo(),
thunkTy->getCoroutineKind(), thunkTy->getCalleeConvention(),
thunkTy->getParameters(), {}, {newThunkResult}, {},
thunkTy->getPatternSubstitutions(), thunkTy->getInvocationSubstitutions(),
thunkTy->getASTContext());
// Construct new curry thunk, returning a `@differentiable` function.
SILOptFunctionBuilder fb(dt.getTransform());
auto *newThunk = fb.getOrCreateFunction(
loc, newThunkName, getSpecializedLinkage(thunk, thunk->getLinkage()),
thunkType, thunk->isBare(), thunk->isTransparent(), thunk->isSerialized(),
thunk->isDynamicallyReplaceable(), ProfileCounter(), thunk->isThunk());
// If new thunk is newly created: clone the old thunk body, wrap the
// returned function value with an `differentiable_function`
// instruction, and process the `differentiable_function` instruction.
// An empty body is used as the signal that `getOrCreateFunction` did not find
// an already populated thunk of this name.
if (newThunk->empty()) {
if (auto newThunkGenSig = thunkType->getSubstGenericSignature())
newThunk->setGenericEnvironment(newThunkGenSig->getGenericEnvironment());
// TODO(TF-1206): Enable ownership in all differentiation thunks.
newThunk->setOwnershipEliminated();
BasicTypeSubstCloner cloner(thunk, newThunk);
cloner.cloneFunction();
auto *retInst = cast<ReturnInst>(newThunk->findReturnBB()->getTerminator());
auto returnValue = retInst->getOperand();
// Create `differentiable_function` instruction directly after the
// defining instruction (e.g. `partial_apply`) of the returned value.
// Note: `differentiable_function` is not created at the end of the
// new thunk to avoid `alloc_stack`/`dealloc_stack` ordering issues.
// NOTE(review): assumes the returned value has a defining instruction (i.e.
// it is not a block argument) — confirm.
SILBuilderWithScope dfiBuilder(
std::next(returnValue->getDefiningInstruction()->getIterator()));
auto *dfi = context.createDifferentiableFunction(
dfiBuilder, loc, parameterIndices, resultIndices, returnValue);
// Replace the cloned `return` with one that returns the new
// `differentiable_function` value.
dfiBuilder.setInsertionPoint(newThunk->findReturnBB());
dfiBuilder.createReturn(loc, dfi);
retInst->eraseFromParent();
context.recordGeneratedFunction(newThunk);
context.getDifferentiableFunctionInstWorklist().push_back(dfi);
// Process the inner instruction right away; a true result means an error.
if (dt.processDifferentiableFunctionInst(dfi))
return nullptr;
}
// Apply the new curry thunk.
auto *newThunkRef = builder.createFunctionRef(loc, newThunk);
context.recordGeneratedFunctionReference(newThunkRef);
SmallVector<SILValue, 8> newArgs;
SmallVector<SILValue, 8> newArgsToDestroy;
SmallVector<AllocStackInst *, 1> newBuffersToDealloc;
copyParameterArgumentsForApply(ai, newArgs, newArgsToDestroy,
newBuffersToDealloc);
auto *newApply = builder.createApply(
loc, newThunkRef, ai->getSubstitutionMap(), newArgs, ai->isNonThrowing());
// After the apply, release whatever `copyParameterArgumentsForApply` reported
// as needing destruction or stack deallocation.
for (auto arg : newArgsToDestroy)
builder.emitDestroyOperation(loc, arg);
for (auto *alloc : newBuffersToDealloc)
builder.createDeallocStack(loc, alloc);
return newApply;
}
/// Builds a `@differentiable` function value for the original function operand
/// of `dfi`, emitting all new instructions through `builder` at `loc`.
///
/// The curry thunk promotion path is tried first; if it yields a value, that
/// value is returned as is. Otherwise, for each of JVP and VJP, a derivative
/// function reference is emitted and then, as needed: replaced by a parameter
/// subset thunk when the found configuration differs from the desired one,
/// converted from thin to thick, and reabstracted when it is not
/// ABI-compatible with the expected derivative function type. The results are
/// bundled into a new `differentiable_function` instruction.
///
/// Returns a null value when a derivative function reference cannot be
/// emitted, or when a parameter subset thunk would be required for an original
/// function that was partially applied with arguments.
SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
DifferentiableFunctionInst *dfi, SILBuilder &builder, SILLocation loc,
DifferentiationInvoker invoker) {
auto &astCtx = context.getASTContext();
auto origFnOperand = dfi->getOriginalFunction();
auto origFnTy = origFnOperand->getType().castTo<SILFunctionType>();
auto *parameterIndices = dfi->getParameterIndices();
auto *resultIndices = dfi->getResultIndices();
// Prefer the curry thunk path; a null result means it did not apply (or
// failed) and the general path below is used.
if (auto diffFn = promoteCurryThunkApplicationToDifferentiableFunction(
*this, dfi, builder, loc, invoker))
return diffFn;
AutoDiffConfig desiredConfig(parameterIndices, resultIndices);
SmallVector<SILValue, 2> derivativeFns;
SmallVector<AllocStackInst *, 2> newBuffersToDealloc;
// JVP is handled first and VJP second, so `derivativeFns[0]` is the JVP and
// `derivativeFns[1]` is the VJP.
for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP,
AutoDiffDerivativeFunctionKind::VJP}) {
auto derivativeFnAndIndices = emitDerivativeFunctionReference(
*this, builder, desiredConfig, derivativeFnKind, origFnOperand,
invoker, newBuffersToDealloc);
// Show an error at the operator, highlight the argument, and show a note
// at the definition site of the argument.
if (!derivativeFnAndIndices)
return nullptr;
auto derivativeFn = derivativeFnAndIndices->first;
context.recordGeneratedFunctionReference(derivativeFn);
// If desired indices are a subset of actual indices, create a "subset
// indices thunk" and destroy the emitted derivative function reference.
// - For JVPs: the thunked JVP returns a differential taking fewer
// parameters (using `.zero` for the dropped parameters).
// - For VJPs: the thunked VJP returns a pullback that drops the unused
// tangent values.
auto actualConfig = derivativeFnAndIndices->second;
// NOTE: `desiredIndices` may come from a partially-applied function and
// have smaller capacity than `actualIndices`. We expect this logic to go
// away when we support `@differentiable` partial apply.
// if (actualIndices != desiredIndices) { // TODO: Re-enable.
auto extendedDesiredParameterIndices =
desiredConfig.parameterIndices->extendingCapacity(
astCtx, actualConfig.parameterIndices->getCapacity());
// Parameter indices are compared after extending the desired set to the
// actual capacity; result indices are compared directly.
if (!actualConfig.parameterIndices->equals(extendedDesiredParameterIndices)
|| !actualConfig.resultIndices->equals(desiredConfig.resultIndices)) {
// Destroy the already emitted derivative function reference because it
// is no longer used.
builder.emitDestroyValueOperation(loc, derivativeFn);
// Check if underlying original function reference has been partially
// applied with arguments. If so, produce an error: parameter subset
// thunks do not yet support this case because partially applied arguments
// cannot be propagated to parameter subset thunks.
auto didPartiallyApplyArguments = [](SILValue original) {
// Walk the chain of `partial_apply` callees, looking through function
// conversions, until one with at least one argument is found.
while (auto *pai =
peerThroughFunctionConversions<PartialApplyInst>(original)) {
if (pai->getNumArguments() > 0)
return true;
original = pai->getCallee();
}
return false;
};
if (didPartiallyApplyArguments(origFnOperand)) {
context.emitNondifferentiabilityError(
origFnOperand, invoker,
diag::autodiff_cannot_param_subset_thunk_partially_applied_orig_fn);
return nullptr;
}
// Create the parameter subset thunk.
assert(actualConfig.parameterIndices->isSupersetOf(
extendedDesiredParameterIndices));
SILFunction *thunk;
SubstitutionMap interfaceSubs;
SILOptFunctionBuilder fb(transform);
std::tie(thunk, interfaceSubs) =
getOrCreateSubsetParametersThunkForDerivativeFunction(
fb, origFnOperand, derivativeFn, derivativeFnKind, desiredConfig,
actualConfig);
auto *thunkFRI = builder.createFunctionRef(loc, thunk);
// A thunk with a generic signature is specialized through a `partial_apply`
// with no arguments; otherwise the `function_ref` is used directly.
if (auto genSig =
thunk->getLoweredFunctionType()->getSubstGenericSignature()) {
derivativeFn =
builder.createPartialApply(loc, thunkFRI, interfaceSubs, {},
ParameterConvention::Direct_Guaranteed);
} else {
derivativeFn = thunkFRI;
}
}
auto expectedDerivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType(
parameterIndices, resultIndices, derivativeFnKind,
context.getTypeConverter(),
LookUpConformanceInModule(context.getModule().getSwiftModule()));
// If `derivativeFn` is `@convention(thin)` but is expected to be
// `@convention(thick)`, emit a `thin_to_thick` instruction.
if (expectedDerivativeFnTy->getRepresentation() ==
SILFunctionTypeRepresentation::Thick &&
derivativeFn->getType()
.castTo<SILFunctionType>()
->getRepresentation() == SILFunctionTypeRepresentation::Thin) {
derivativeFn = builder.createThinToThickFunction(
loc, derivativeFn,
SILType::getPrimitiveObjectType(expectedDerivativeFnTy));
}
// If derivative function value's type is not ABI-compatible with the
// expected derivative function type (i.e. parameter and result conventions
// do not match), perform reabstraction.
auto abiCompatibility = expectedDerivativeFnTy->isABICompatibleWith(
derivativeFn->getType().castTo<SILFunctionType>(), *dfi->getFunction());
if (!abiCompatibility.isCompatible()) {
SILOptFunctionBuilder fb(context.getTransform());
// The substitution map callback is the identity.
auto newDerivativeFn = reabstractFunction(
builder, fb, loc, derivativeFn, expectedDerivativeFnTy,
[](SubstitutionMap substMap) { return substMap; });
derivativeFn = newDerivativeFn;
assert(expectedDerivativeFnTy
->isABICompatibleWith(
derivativeFn->getType().castTo<SILFunctionType>(),
*dfi->getFunction())
.isCompatible());
}
derivativeFns.push_back(derivativeFn);
}
// Deallocate temporary buffers used for creating derivative functions.
// Buffers are released in reverse order of their collection.
for (auto *buf : llvm::reverse(newBuffersToDealloc))
builder.createDeallocStack(loc, buf);
// If our original copy does not have none ownership, copy it.
// The copy is also made whenever the builder reports no ownership.
if (!builder.hasOwnership() ||
origFnOperand.getOwnershipKind() != OwnershipKind::None)
origFnOperand = builder.emitCopyValueOperation(loc, origFnOperand);
auto *newDiffFn = context.createDifferentiableFunction(
builder, loc, parameterIndices, resultIndices, origFnOperand,
std::make_pair(derivativeFns[0], derivativeFns[1]));
// NOTE(review): this enqueues the input `dfi`, not `newDiffFn` — confirm
// that this is intended.
context.getDifferentiableFunctionInstWorklist().push_back(dfi);
return newDiffFn;
}
/// Builds a new `linear_function` value for the original function operand of
/// `lfi`, emitting instructions through `builder` at `loc`, and returns it.
///
/// Note: for now, the transpose function operand of the new instruction is an
/// undef value of the transpose function type. Eventually, a legitimate
/// transpose function operand should be created and used.
SILValue DifferentiationTransformer::promoteToLinearFunction(
    LinearFunctionInst *lfi, SILBuilder &builder, SILLocation loc,
    DifferentiationInvoker invoker) {
  auto originalFn = lfi->getOriginalFunction();

  // Copy the original function value unless the builder has ownership and the
  // value's ownership kind is none.
  const bool needsCopy =
      !builder.hasOwnership() ||
      originalFn.getOwnershipKind() != OwnershipKind::None;
  if (needsCopy)
    originalFn = builder.emitCopyValueOperation(loc, originalFn);

  auto *paramIndices = lfi->getParameterIndices();

  // Derive the transpose function type from the original function type.
  auto originalFnTy = originalFn->getType().castTo<SILFunctionType>();
  auto transposeFnTy = originalFnTy->getAutoDiffTransposeFunctionType(
      paramIndices, context.getTypeConverter(),
      LookUpConformanceInModule(builder.getModule().getSwiftModule()));

  // Placeholder transpose operand: an undef of the transpose function type.
  auto *undefTranspose = SILUndef::get(
      SILType::getPrimitiveObjectType(transposeFnTy), builder.getFunction());

  auto *promoted = context.createLinearFunction(
      builder, loc, paramIndices, originalFn, SILValue(undefTranspose));
  context.getLinearFunctionInstWorklist().push_back(lfi);
  return promoted;
}
/// Fold `differentiable_function_extract` users of the given
/// `differentiable_function` instruction, directly replacing them with
/// `differentiable_function` instruction operands. If the
/// `differentiable_function` instruction has no remaining uses, delete the
/// instruction itself after folding.
///
/// Folding can be disabled by the `SkipFoldingDifferentiableFunctionExtraction`
/// flag for SIL testing purposes.
// FIXME: This function is not correctly detecting the foldable pattern and
// needs to be rewritten.
void DifferentiationTransformer::foldDifferentiableFunctionExtraction(
    DifferentiableFunctionInst *source) {
  // Collect the `differentiable_function_extract` users up front. Erasing a
  // user removes its operand from `source`'s use list, so erasing while
  // iterating over `getUses()` would advance the iterator through an operand
  // that has already been destroyed.
  SmallVector<DifferentiableFunctionExtractInst *, 4> extractUsers;
  for (auto use : source->getUses()) {
    // Users that are not `differentiable_function_extract` instructions are
    // left untouched.
    if (auto *dfei =
            dyn_cast<DifferentiableFunctionExtractInst>(use->getUser()))
      extractUsers.push_back(dfei);
  }

  for (auto *dfei : extractUsers) {
    if (dfei->getExtractee() ==
        NormalDifferentiableFunctionTypeComponent::Original) {
      // Fold original function extractors.
      auto originalFnValue = source->getOriginalFunction();
      dfei->replaceAllUsesWith(originalFnValue);
    } else {
      // Fold derivative function extractors. The derivative function kind is
      // only queried for non-original extractees.
      auto derivativeFnValue =
          source->getDerivativeFunction(dfei->getDerivativeFunctionKind());
      dfei->replaceAllUsesWith(derivativeFnValue);
    }
    dfei->eraseFromParent();
  }

  // If the `differentiable_function` instruction has no remaining uses, erase
  // it, destroying its JVP and VJP operands first.
  if (isInstructionTriviallyDead(source)) {
    SILBuilder builder(source);
    builder.emitDestroyAddrAndFold(source->getLoc(), source->getJVPFunction());
    builder.emitDestroyAddrAndFold(source->getLoc(), source->getVJPFunction());
    source->eraseFromParent();
  }
  // Mark `source` as processed so that it won't be reprocessed after deletion.
  context.markDifferentiableFunctionInstAsProcessed(source);
}
/// Canonicalizes a single `differentiable_function` instruction.
///
/// An instruction that already has derivative functions is left alone.
/// Otherwise the instruction is promoted, its uses are redirected to the
/// promoted value, the instruction is erased, and (unless disabled by
/// `SkipFoldingDifferentiableFunctionExtraction`) extraction folding is run on
/// the promoted instruction. Returns true on error.
bool DifferentiationTransformer::processDifferentiableFunctionInst(
    DifferentiableFunctionInst *dfi) {
  PrettyStackTraceSILNode instTrace("canonicalizing `differentiable_function`",
                                    cast<SILInstruction>(dfi));
  PrettyStackTraceSILFunction functionTrace("...in", dfi->getFunction());
  LLVM_DEBUG({
    auto &s = getADDebugStream() << "Processing DifferentiableFunctionInst:\n";
    dfi->printInContext(s);
  });

  // Nothing to do when `dfi` already carries its derivative functions.
  if (dfi->hasDerivativeFunctions())
    return false;

  // Capture everything needed from `dfi` before it is erased below.
  auto *enclosingFn = dfi->getFunction();
  auto loc = dfi->getLoc();
  SILBuilderWithScope builder(dfi);
  auto promoted = promoteToDifferentiableFunction(dfi, builder, loc, dfi);

  // Mark `dfi` as processed so that it won't be reprocessed after deletion.
  context.markDifferentiableFunctionInstAsProcessed(dfi);
  if (!promoted)
    return true;

  // Redirect all uses to the promoted value, destroy the original operand,
  // and remove the old instruction.
  dfi->replaceAllUsesWith(promoted);
  builder.emitDestroyValueOperation(loc, dfi->getOriginalFunction());
  dfi->eraseFromParent();

  // If folding is enabled and the promoted value is itself a
  // `differentiable_function` instruction, fold its
  // `differentiable_function_extract` users.
  auto *promotedDFI = !SkipFoldingDifferentiableFunctionExtraction
                          ? dyn_cast<DifferentiableFunctionInst>(promoted)
                          : nullptr;
  if (promotedDFI)
    foldDifferentiableFunctionExtraction(promotedDFI);

  transform.invalidateAnalysis(enclosingFn,
                               SILAnalysis::InvalidationKind::FunctionBody);
  return false;
}
/// Canonicalizes a single `linear_function` instruction.
///
/// An instruction that already has a transpose function is left alone.
/// Otherwise the instruction is promoted, its uses are redirected to the
/// promoted value, and the instruction is erased. Returns true on error.
bool DifferentiationTransformer::processLinearFunctionInst(
    LinearFunctionInst *lfi) {
  PrettyStackTraceSILNode lfiTrace("canonicalizing `linear_function`",
                                   cast<SILInstruction>(lfi));
  PrettyStackTraceSILFunction functionTrace("...in", lfi->getFunction());
  LLVM_DEBUG({
    auto &s = getADDebugStream() << "Processing LinearFunctionInst:\n";
    lfi->printInContext(s);
  });

  // Nothing to do when `lfi` already carries its transpose function.
  if (lfi->hasTransposeFunction())
    return false;

  // Capture everything needed from `lfi` before it is erased below.
  auto *enclosingFn = lfi->getFunction();
  auto loc = lfi->getLoc();
  SILBuilderWithScope builder(lfi);
  auto promoted = promoteToLinearFunction(lfi, builder, loc, lfi);

  // Mark `lfi` as processed so that it won't be reprocessed after deletion.
  context.markLinearFunctionInstAsProcessed(lfi);
  if (!promoted)
    return true;

  // Redirect all uses to the promoted value, destroy the original operand,
  // and remove the old instruction.
  lfi->replaceAllUsesWith(promoted);
  builder.emitDestroyValueOperation(loc, lfi->getOriginalFunction());
  lfi->eraseFromParent();

  transform.invalidateAnalysis(enclosingFn,
                               SILAnalysis::InvalidationKind::FunctionBody);
  return false;
}
/// Automatic differentiation transform entry.
///
/// Collects everything in the module that triggers differentiation
/// (differentiability witness definitions, `differentiable_function`
/// instructions and `linear_function` instructions), checks that the required
/// modules are available, and then canonicalizes the witnesses and drains both
/// instruction worklists. If any step reports an error, the context is cleaned
/// up before returning.
void Differentiation::run() {
  auto &module = *getModule();
  auto &astCtx = module.getASTContext();
  debugDump(module);

  // A transformation helper.
  DifferentiationTransformer transformer(*this);
  ADContext &context = transformer.getContext();

  bool errorOccurred = false;

  // Register all the SIL differentiability witnesses in the module that trigger
  // differentiation. Declarations are skipped; only definitions are
  // registered.
  for (auto &witness : module.getDifferentiabilityWitnesses()) {
    if (witness.isDeclaration())
      continue;
    context.addInvoker(&witness);
  }

  // Register all the `differentiable_function` and `linear_function`
  // instructions in the module that trigger differentiation.
  for (SILFunction &f : module) {
    for (SILBasicBlock &bb : f) {
      for (SILInstruction &i : bb) {
        if (auto *dfi = dyn_cast<DifferentiableFunctionInst>(&i)) {
          context.getDifferentiableFunctionInstWorklist().push_back(dfi);
        } else if (auto *lfi = dyn_cast<LinearFunctionInst>(&i)) {
          // If linear map transposition is not enabled and an uncanonical
          // `linear_function` instruction is encountered, emit a diagnostic.
          // FIXME(SR-11850): Finish support for linear map transposition.
          if (!EnableExperimentalLinearMapTransposition) {
            if (!lfi->hasTransposeFunction()) {
              astCtx.Diags.diagnose(
                  lfi->getLoc().getSourceLoc(),
                  diag::autodiff_conversion_to_linear_function_not_supported);
              errorOccurred = true;
            }
          }
          context.getLinearFunctionInstWorklist().push_back(lfi);
        }
      }
    }
  }

  // If nothing has triggered differentiation, there's nothing to do.
  if (context.getInvokers().empty() &&
      context.getDifferentiableFunctionInstWorklist().empty() &&
      context.getLinearFunctionInstWorklist().empty())
    return;

  // Differentiation relies on the stdlib (the Swift module).
  // If it's not imported, it's an internal error.
  if (!astCtx.getStdlibModule()) {
    astCtx.Diags.diagnose(SourceLoc(),
                          diag::autodiff_internal_swift_not_imported);
    return;
  }
  if (!astCtx.getLoadedModule(astCtx.Id_Differentiation)) {
    // Pick a location for the diagnostic from whichever trigger exists: the
    // first invoker, else the last `differentiable_function` instruction, else
    // the last `linear_function` instruction. At least one of the three is
    // non-empty because of the early return above. The worklists are only
    // read here, not popped.
    SourceLoc loc;
    if (!context.getInvokers().empty()) {
      loc = context.getInvokers().front().second.getLocation();
    } else if (!context.getDifferentiableFunctionInstWorklist().empty()) {
      loc = context.getDifferentiableFunctionInstWorklist()
                .back()
                ->getLoc()
                .getSourceLoc();
    } else {
      assert(!context.getLinearFunctionInstWorklist().empty());
      loc = context.getLinearFunctionInstWorklist()
                .back()
                ->getLoc()
                .getSourceLoc();
    }
    astCtx.Diags.diagnose(loc,
                          diag::autodiff_differentiation_module_not_imported);
    return;
  }

  // Process all invokers.
  for (auto invokerPair : context.getInvokers()) {
    auto *witness = invokerPair.first;
    auto *original = witness->getOriginalFunction();
    auto invoker = invokerPair.second;

    if (transformer.canonicalizeDifferentiabilityWitness(
            original, witness, invoker, original->isSerialized()))
      errorOccurred = true;
  }

  // Iteratively process `differentiable_function` instruction worklist.
  while (!context.getDifferentiableFunctionInstWorklist().empty()) {
    auto *dfi = context.getDifferentiableFunctionInstWorklist().pop_back_val();
    // Skip instructions that have already been processed.
    if (context.isDifferentiableFunctionInstProcessed(dfi))
      continue;
    errorOccurred |= transformer.processDifferentiableFunctionInst(dfi);
  }

  // Iteratively process `linear_function` instruction worklist.
  while (!context.getLinearFunctionInstWorklist().empty()) {
    auto *lfi = context.getLinearFunctionInstWorklist().pop_back_val();
    // Skip instructions that have already been processed.
    if (context.isLinearFunctionInstProcessed(lfi))
      continue;
    errorOccurred |= transformer.processLinearFunctionInst(lfi);
  }

  // If any error occurred while processing witnesses or
  // `differentiable_function` instructions, clean up.
  if (errorOccurred) {
    context.cleanUp();
    return;
  }

  LLVM_DEBUG(getADDebugStream() << "All differentiation finished\n");
}
//===----------------------------------------------------------------------===//
// Pass creation
//===----------------------------------------------------------------------===//
/// Creates a new instance of the differentiation pass. The instance is
/// allocated with `new` and returned as a `SILTransform` pointer.
SILTransform *swift::createDifferentiation() {
  return new Differentiation();
}