blob: 84399833cbd5b582c2b6076e3e79810e5545c5ee [file] [log] [blame]
//===--- Thunk.cpp - Automatic differentiation thunks ---------*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// Automatic differentiation thunk generation utilities.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "differentiation"
#include "swift/SILOptimizer/Differentiation/Thunk.h"
#include "swift/SILOptimizer/Differentiation/Common.h"
#include "swift/AST/AnyFunctionRef.h"
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/AST/Requirement.h"
#include "swift/AST/SubstitutionMap.h"
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
namespace swift {
namespace autodiff {
//===----------------------------------------------------------------------===//
// Thunk helpers
//===----------------------------------------------------------------------===//
// These helpers are copied/adapted from SILGen. They should be refactored and
// moved to a shared location.
//===----------------------------------------------------------------------===//
/// Computes the generic signature to use for a thunk created on behalf of
/// `fn`.
///
/// If `openedExistential` is null, the substituted generic signature, generic
/// environment, and forwarding substitution map of `fn` are reused directly
/// (`newArchetype` is left untouched).
///
/// Otherwise a new generic parameter, required to conform to the opened
/// existential's type, is added (on top of `fn`'s signature when
/// `inheritGenericSig` is set) and the out-parameters are filled in:
///   - `genericEnv`:    generic environment of the newly built signature.
///   - `newArchetype`:  the new generic parameter mapped into `genericEnv`.
///   - `contextSubs`:   substitutions from `fn`'s signature into `genericEnv`
///                      (only assigned when `fn` has a generic signature).
///   - `interfaceSubs`: substitutions for the new signature, mapping the new
///                      parameter to `openedExistential` and every other
///                      parameter into `fn`'s context.
///
/// Returns the (canonical) generic signature of the thunk.
CanGenericSignature buildThunkSignature(SILFunction *fn, bool inheritGenericSig,
                                        OpenedArchetypeType *openedExistential,
                                        GenericEnvironment *&genericEnv,
                                        SubstitutionMap &contextSubs,
                                        SubstitutionMap &interfaceSubs,
                                        ArchetypeType *&newArchetype) {
  auto parentSig = fn->getLoweredFunctionType()->getSubstGenericSignature();

  // Without an opened existential the thunk simply shares the generic context
  // of the parent function.
  if (!openedExistential) {
    genericEnv = fn->getGenericEnvironment();
    interfaceSubs = fn->getForwardingSubstitutionMap();
    contextSubs = interfaceSubs;
    return parentSig;
  }

  auto &astCtx = fn->getASTContext();
  GenericSignatureBuilder sigBuilder(astCtx);

  // Optionally start from the parent's signature; the new parameter is then
  // placed one depth level past the parent's last generic parameter.
  int newParamDepth = 0;
  if (inheritGenericSig && parentSig) {
    sigBuilder.addGenericSignature(parentSig);
    newParamDepth = parentSig->getGenericParams().back()->getDepth() + 1;
  }

  // Introduce the generic parameter that stands in for the opened existential
  // and require it to conform to the existential's type.
  auto *existentialParam = GenericTypeParamType::get(newParamDepth, 0, astCtx);
  sigBuilder.addGenericParameter(existentialParam);
  Requirement conformanceReq(RequirementKind::Conformance, existentialParam,
                             openedExistential->getOpenedExistentialType());
  auto reqSource =
      GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
  sigBuilder.addRequirement(conformanceReq, reqSource, nullptr);

  auto thunkSig = std::move(sigBuilder).computeGenericSignature(
      SourceLoc(), /*allowConcreteGenericParams=*/true);
  genericEnv = thunkSig->getGenericEnvironment();
  newArchetype = genericEnv->mapTypeIntoContext(existentialParam)
                     ->castTo<ArchetypeType>();

  // Substitutions taking the caller's archetypes to the thunk's archetypes.
  if (parentSig) {
    contextSubs = SubstitutionMap::get(
        parentSig,
        [&](SubstitutableType *ty) -> Type {
          return genericEnv->mapTypeIntoContext(ty);
        },
        MakeAbstractConformanceForGenericType());
  }

  // Substitutions taking the thunk's interface types to the caller's
  // archetypes; the new parameter maps back to the opened existential.
  interfaceSubs = SubstitutionMap::get(
      thunkSig,
      [&](SubstitutableType *ty) -> Type {
        if (ty->isEqual(existentialParam))
          return openedExistential;
        return fn->mapTypeIntoContext(ty);
      },
      MakeAbstractConformanceForGenericType());

  return thunkSig->getCanonicalSignature();
}
/// Builds the lowered function type of a differentiation thunk.
///
/// \param fn               Function on whose behalf the thunk is created; its
///                         generic context is used when archetypes are
///                         involved.
/// \param sourceType       In/out. Type of the function being thunked. On
///                         return it has been substituted into the thunk's
///                         context.
/// \param expectedType     In/out. Type the thunk should present. On return it
///                         has been substituted into the thunk's context.
/// \param genericEnv       Out. Assigned by `buildThunkSignature` when either
///                         type contains archetypes; otherwise untouched.
/// \param interfaceSubs    Out. Assigned by `buildThunkSignature` when either
///                         type contains archetypes; otherwise untouched.
/// \param withoutActuallyEscaping  When set, the no-escape bit of the thunk's
///                         ext-info is cleared.
/// \param thunkKind        For `Reabstraction`, `sourceType` is appended as a
///                         trailing parameter of the thunk.
/// \returns The interface type of the thunk function.
CanSILFunctionType buildThunkType(SILFunction *fn,
                                  CanSILFunctionType &sourceType,
                                  CanSILFunctionType &expectedType,
                                  GenericEnvironment *&genericEnv,
                                  SubstitutionMap &interfaceSubs,
                                  bool withoutActuallyEscaping,
                                  DifferentiationThunkKind thunkKind) {
  assert(!expectedType->isPolymorphic() &&
         !expectedType->getCombinedSubstitutions());
  assert(!sourceType->isPolymorphic() &&
         !sourceType->getCombinedSubstitutions());

  // Cannot build a reabstraction thunk without context. Ownership semantics
  // on the result type are required.
  // Written as a single assertion so that no empty `if` body is left behind
  // when assertions are compiled out.
  assert((thunkKind != DifferentiationThunkKind::Reabstraction ||
          expectedType->getExtInfo().hasContext()) &&
         "reabstraction thunk requires an expected type with context");

  // This may inherit @noescape from the expected type. The `@noescape`
  // attribute is only stripped when using this type to materialize a new decl.
  // Use `@convention(thin)` if:
  // - Building a reabstraction thunk type.
  // - Building an index subset thunk type, where the expected type has context
  //   (i.e. is `@convention(thick)`).
  auto extInfoBuilder = expectedType->getExtInfo().intoBuilder();
  if (thunkKind == DifferentiationThunkKind::Reabstraction ||
      extInfoBuilder.hasContext()) {
    extInfoBuilder = extInfoBuilder.withRepresentation(
        SILFunctionType::Representation::Thin);
  }
  if (withoutActuallyEscaping)
    extInfoBuilder = extInfoBuilder.withNoEscape(false);

  // Does the thunk type involve archetypes other than opened existentials?
  bool hasArchetypes = false;
  // Does the thunk type involve an open existential type?
  CanOpenedArchetypeType openedExistential;
  auto archetypeVisitor = [&](CanType t) {
    // Match *any* archetype first, then classify it. Casting directly to
    // `OpenedArchetypeType` here would make the `else` branch unreachable and
    // leave `hasArchetypes` permanently false.
    if (auto archetypeTy = dyn_cast<ArchetypeType>(t)) {
      if (auto opened = dyn_cast<OpenedArchetypeType>(archetypeTy)) {
        assert((openedExistential == CanArchetypeType() ||
                openedExistential == opened) &&
               "one too many open existentials");
        openedExistential = opened;
      } else {
        hasArchetypes = true;
      }
    }
  };

  // Use the generic signature from the context if the thunk involves
  // generic parameters.
  CanGenericSignature genericSig;
  SubstitutionMap contextSubs;
  ArchetypeType *newArchetype = nullptr;
  if (expectedType->hasArchetype() || sourceType->hasArchetype()) {
    expectedType.visit(archetypeVisitor);
    sourceType.visit(archetypeVisitor);
    genericSig =
        buildThunkSignature(fn, hasArchetypes, openedExistential, genericEnv,
                            contextSubs, interfaceSubs, newArchetype);
  }

  // Substitution callbacks: the opened existential becomes the thunk's new
  // archetype; everything else goes through `contextSubs`.
  auto substTypeHelper = [&](SubstitutableType *type) -> Type {
    if (CanType(type) == openedExistential)
      return newArchetype;
    return Type(type).subst(contextSubs);
  };
  auto substConformanceHelper = LookUpConformanceInSubstitutionMap(contextSubs);

  // Utility function to apply contextSubs, and also replace the
  // opened existential with the new archetype.
  auto substLoweredTypeIntoThunkContext =
      [&](CanSILFunctionType t) -> CanSILFunctionType {
    return SILType::getPrimitiveObjectType(t)
        .subst(fn->getModule(), substTypeHelper, substConformanceHelper)
        .castTo<SILFunctionType>();
  };
  sourceType = substLoweredTypeIntoThunkContext(sourceType);
  expectedType = substLoweredTypeIntoThunkContext(expectedType);

  // If our parent function was pseudogeneric, this thunk must also be
  // pseudogeneric, since we have no way to pass generic parameters.
  if (genericSig)
    if (fn->getLoweredFunctionType()->isPseudogeneric())
      extInfoBuilder = extInfoBuilder.withIsPseudogeneric();

  // Add the function type as the parameter.
  auto contextConvention =
      SILType::getPrimitiveObjectType(sourceType).isTrivial(*fn)
          ? ParameterConvention::Direct_Unowned
          : ParameterConvention::Direct_Guaranteed;
  SmallVector<SILParameterInfo, 4> params;
  params.append(expectedType->getParameters().begin(),
                expectedType->getParameters().end());
  // Add reabstraction function parameter only if building a reabstraction thunk
  // type.
  if (thunkKind == DifferentiationThunkKind::Reabstraction)
    params.push_back({sourceType, sourceType->getExtInfo().hasContext()
                                      ? contextConvention
                                      : ParameterConvention::Direct_Unowned});

  auto mapTypeOutOfContext = [&](CanType type) -> CanType {
    return type->mapTypeOutOfContext()->getCanonicalType(genericSig);
  };

  // Map the parameter and expected types out of context to get the interface
  // type of the thunk.
  SmallVector<SILParameterInfo, 4> interfaceParams;
  interfaceParams.reserve(params.size());
  for (auto &param : params) {
    auto interfaceParam = param.map(mapTypeOutOfContext);
    interfaceParams.push_back(interfaceParam);
  }
  SmallVector<SILYieldInfo, 4> interfaceYields;
  for (auto &yield : expectedType->getYields()) {
    auto interfaceYield = yield.map(mapTypeOutOfContext);
    interfaceYields.push_back(interfaceYield);
  }
  SmallVector<SILResultInfo, 4> interfaceResults;
  for (auto &result : expectedType->getResults()) {
    auto interfaceResult = result.map(mapTypeOutOfContext);
    interfaceResults.push_back(interfaceResult);
  }
  Optional<SILResultInfo> interfaceErrorResult;
  if (expectedType->hasErrorResult()) {
    auto errorResult = expectedType->getErrorResult();
    interfaceErrorResult = errorResult.map(mapTypeOutOfContext);
  }

  // The type of the thunk function.
  return SILFunctionType::get(
      genericSig, extInfoBuilder.build(), expectedType->getCoroutineKind(),
      ParameterConvention::Direct_Unowned, interfaceParams, interfaceYields,
      interfaceResults, interfaceErrorResult,
      expectedType->getPatternSubstitutions(), SubstitutionMap(),
      fn->getASTContext());
}
/// Forwards `originalArgs`, reconciling ownership-convention differences
/// between the parameters of `fromTy` and `toTy`.
/// Adapted from `forwardFunctionArguments` in SILGenPoly.cpp.
///
/// For each argument, in order:
///   - `fromTy` parameter is consumed but the `toTy` parameter is not: a copy
///     is forwarded (`copy_value` for objects; for addresses, a fresh
///     `alloc_stack` initialized with a non-taking `copy_addr`).
///   - `fromTy` parameter is guaranteed but the `toTy` parameter is not: a
///     borrow of the argument is forwarded.
///   - otherwise the argument is forwarded unchanged.
///
/// Outputs:
///   - `forwardedArgs`:    receives one forwarded value per original argument.
///   - `localAllocations`: receives every `alloc_stack` created here; callers
///                         must emit the matching `dealloc_stack`.
///   - `valuesToCleanup`:  receives local values that require cleanup (each
///                         borrow and the argument it was borrowed from).
static void forwardFunctionArgumentsConvertingOwnership(
    SILBuilder &builder, SILLocation loc, CanSILFunctionType fromTy,
    CanSILFunctionType toTy, ArrayRef<SILArgument *> originalArgs,
    SmallVectorImpl<SILValue> &forwardedArgs,
    SmallVectorImpl<AllocStackInst *> &localAllocations,
    SmallVectorImpl<SILValue> &valuesToCleanup) {
  auto sourceParams = fromTy->getParameters();
  auto targetParams = toTy->getParameters();
  assert(sourceParams.size() == targetParams.size());
  assert(sourceParams.size() == originalArgs.size());

  for (auto i : indices(originalArgs)) {
    auto &thunkArg = originalArgs[i];
    auto sourceParam = sourceParams[i];
    auto targetParam = targetParams[i];

    if (sourceParam.isConsumed() && !targetParam.isConsumed()) {
      // A copy is needed to hand over an owned value.
      if (thunkArg->getType().isObject()) {
        // Object-typed argument: `copy_value`.
        forwardedArgs.push_back(builder.emitCopyValueOperation(loc, thunkArg));
      } else {
        // Address-typed argument: copy its contents into a local allocation.
        auto *copyBuffer = builder.createAllocStack(loc, thunkArg->getType());
        builder.createCopyAddr(loc, thunkArg, copyBuffer, IsNotTake,
                               IsInitialization);
        localAllocations.push_back(copyBuffer);
        forwardedArgs.push_back(copyBuffer);
      }
    } else if (sourceParam.isGuaranteed() && !targetParam.isGuaranteed()) {
      // Borrow the argument; both the borrow and the argument are recorded
      // for cleanup.
      auto borrowed = builder.emitBeginBorrowOperation(loc, thunkArg);
      forwardedArgs.push_back(borrowed);
      valuesToCleanup.push_back(borrowed);
      valuesToCleanup.push_back(thunkArg);
    } else {
      // Conventions agree: forward as-is.
      forwardedArgs.push_back(thunkArg);
    }
  }
}
/// Returns the shared reabstraction thunk for converting a function of type
/// `fromType` into one of type `toType`, emitting its body if the function is
/// still empty.
///
/// The thunk type is built by `buildThunkType` with the `Reabstraction` kind
/// using `caller` as the context function. The thunk's last argument is the
/// function value being reabstracted; the body applies it to reabstracted
/// arguments and returns reabstracted results.
///
/// Both `fromType` and `toType` must have no combined substitutions and the
/// same number of parameters and results (asserted).
SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
SILModule &module, SILLocation loc,
SILFunction *caller,
CanSILFunctionType fromType,
CanSILFunctionType toType) {
assert(!fromType->getCombinedSubstitutions());
assert(!toType->getCombinedSubstitutions());
SubstitutionMap interfaceSubs;
GenericEnvironment *genericEnv = nullptr;
// Note: `fromType` and `toType` are passed by reference and are rewritten
// into the thunk's context by `buildThunkType`.
auto thunkType =
buildThunkType(caller, fromType, toType, genericEnv, interfaceSubs,
/*withoutActuallyEscaping*/ false,
DifferentiationThunkKind::Reabstraction);
// The declared type of the thunk has its no-escape bit cleared.
auto thunkDeclType =
thunkType->getWithExtInfo(thunkType->getExtInfo().withNoEscape(false));
auto fromInterfaceType = fromType->mapTypeOutOfContext()->getCanonicalType();
auto toInterfaceType = toType->mapTypeOutOfContext()->getCanonicalType();
Mangle::ASTMangler mangler;
std::string name = mangler.mangleReabstractionThunkHelper(
thunkType, fromInterfaceType, toInterfaceType, Type(),
module.getSwiftModule());
auto *thunk = fb.getOrCreateSharedFunction(
loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized,
ProfileCounter(), IsReabstractionThunk, IsNotDynamic);
// A non-empty function already has a body; reuse it.
if (!thunk->empty())
return thunk;
thunk->setGenericEnvironment(genericEnv);
auto *entry = thunk->createBasicBlock();
SILBuilder builder(entry);
createEntryArguments(thunk);
SILFunctionConventions fromConv(fromType, module);
SILFunctionConventions toConv(toType, module);
assert(toConv.useLoweredAddresses());
// Forward thunk arguments, handling ownership convention mismatches.
// `forwardedArgs` holds the thunk's indirect results first, followed by the
// forwarded non-indirect-result arguments. The trailing function argument is
// excluded via `drop_back()`.
SmallVector<SILValue, 4> forwardedArgs;
for (auto indRes : thunk->getIndirectResults())
forwardedArgs.push_back(indRes);
SmallVector<AllocStackInst *, 4> localAllocations;
SmallVector<SILValue, 4> valuesToCleanup;
forwardFunctionArgumentsConvertingOwnership(
builder, loc, fromType, toType,
thunk->getArgumentsWithoutIndirectResults().drop_back(), forwardedArgs,
localAllocations, valuesToCleanup);
// `arguments` collects the operands of the `apply` emitted below.
// `toArgIter` walks `forwardedArgs` and is shared by the result and
// parameter loops that follow.
SmallVector<SILValue, 4> arguments;
auto toArgIter = forwardedArgs.begin();
auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); };
// Creates an `alloc_stack` and records it so that it is deallocated before
// the thunk returns.
auto createAllocStack = [&](SILType type) {
auto *alloc = builder.createAllocStack(loc, type);
localAllocations.push_back(alloc);
return alloc;
};
// Handle indirect results.
assert(fromType->getNumResults() == toType->getNumResults());
for (unsigned resIdx : range(toType->getNumResults())) {
auto fromRes = fromConv.getResults()[resIdx];
auto toRes = toConv.getResults()[resIdx];
// No abstraction mismatch.
if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) {
// If result types are indirect, directly pass as next argument.
if (toRes.isFormalIndirect())
useNextArgument();
continue;
}
// Convert indirect result to direct result.
// A local buffer receives the callee's indirect result; it is loaded from
// after the call.
if (fromRes.isFormalIndirect()) {
SILType resultTy =
fromConv.getSILType(fromRes, builder.getTypeExpansionContext());
assert(resultTy.isAddress());
auto *indRes = createAllocStack(resultTy);
arguments.push_back(indRes);
continue;
}
// Convert direct result to indirect result.
// Increment thunk argument iterator; reabstraction handled later.
++toArgIter;
}
// Reabstract parameters.
assert(toType->getNumParameters() == fromType->getNumParameters());
for (unsigned paramIdx : range(toType->getNumParameters())) {
auto fromParam = fromConv.getParameters()[paramIdx];
auto toParam = toConv.getParameters()[paramIdx];
// No abstraction mismatch. Directly use next argument.
if (fromParam.isFormalIndirect() == toParam.isFormalIndirect()) {
useNextArgument();
continue;
}
// Convert indirect parameter to direct parameter.
// The direct thunk argument is copied and stored into a local buffer, and
// the buffer is passed instead. `paramTy` is only used by the assertion.
if (fromParam.isFormalIndirect()) {
auto paramTy = fromConv.getSILType(fromType->getParameters()[paramIdx],
builder.getTypeExpansionContext());
if (!paramTy.hasArchetype())
paramTy = thunk->mapTypeIntoContext(paramTy);
assert(paramTy.isAddress());
auto toArg = *toArgIter++;
auto *buf = createAllocStack(toArg->getType());
toArg = builder.emitCopyValueOperation(loc, toArg);
builder.emitStoreValueOperation(loc, toArg, buf,
StoreOwnershipQualifier::Init);
valuesToCleanup.push_back(buf);
arguments.push_back(buf);
continue;
}
// Convert direct parameter to indirect parameter.
// The value is loaded from the thunk's address argument; it is registered
// for cleanup only when an actual `load_borrow` instruction was emitted.
assert(toParam.isFormalIndirect());
auto toArg = *toArgIter++;
auto load = builder.emitLoadBorrowOperation(loc, toArg);
if (isa<LoadBorrowInst>(load))
valuesToCleanup.push_back(load);
arguments.push_back(load);
}
// The function being reabstracted is the thunk's last argument.
auto *fnArg = thunk->getArgumentsWithoutIndirectResults().back();
auto *apply = builder.createApply(loc, fnArg, SubstitutionMap(), arguments,
/*isNonThrowing*/ false);
// Get return elements.
SmallVector<SILValue, 4> results;
// Extract all direct results.
SmallVector<SILValue, 4> directResults;
extractAllElements(apply, builder, directResults);
// Three cursors: the callee's direct results, the callee's indirect results,
// and the thunk's own indirect results.
auto fromDirResultsIter = directResults.begin();
auto fromIndResultsIter = apply->getIndirectSILResults().begin();
auto toIndResultsIter = thunk->getIndirectResults().begin();
// Reabstract results.
for (unsigned resIdx : range(toType->getNumResults())) {
auto fromRes = fromConv.getResults()[resIdx];
auto toRes = toConv.getResults()[resIdx];
// Check function-typed results.
if (isa<SILFunctionType>(fromRes.getInterfaceType()) &&
isa<SILFunctionType>(toRes.getInterfaceType())) {
auto fromFnType = cast<SILFunctionType>(fromRes.getInterfaceType());
auto toFnType = cast<SILFunctionType>(toRes.getInterfaceType());
auto fromUnsubstFnType = fromFnType->getUnsubstitutedType(module);
auto toUnsubstFnType = toFnType->getUnsubstitutedType(module);
// If unsubstituted function types are not equal, perform reabstraction.
if (fromUnsubstFnType != toUnsubstFnType) {
auto fromFn = *fromDirResultsIter++;
auto newFromFn = reabstractFunction(
builder, fb, loc, fromFn, toFnType,
[](SubstitutionMap substMap) { return substMap; });
results.push_back(newFromFn);
continue;
}
}
// No abstraction mismatch.
if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) {
// If result types are direct, add call result as direct thunk result.
if (toRes.isFormalDirect())
results.push_back(*fromDirResultsIter++);
// If result types are indirect, increment indirect result iterators.
else {
++fromIndResultsIter;
++toIndResultsIter;
}
continue;
}
// Load direct results from indirect results.
if (fromRes.isFormalIndirect()) {
auto indRes = *fromIndResultsIter++;
auto load = builder.emitLoadValueOperation(loc, indRes,
LoadOwnershipQualifier::Take);
results.push_back(load);
continue;
}
// Store direct results to indirect results.
assert(toRes.isFormalIndirect());
#ifndef NDEBUG
SILType resultTy =
toConv.getSILType(toRes, builder.getTypeExpansionContext());
assert(resultTy.isAddress());
#endif
auto indRes = *toIndResultsIter++;
auto dirRes = *fromDirResultsIter++;
builder.emitStoreValueOperation(loc, dirRes, indRes,
StoreOwnershipQualifier::Init);
}
// Combine the collected direct results into the single value to return.
auto retVal = joinElements(results, builder, loc);
// Clean up local values.
// Guaranteed values need an `end_borrow`.
// Owned values need to be destroyed.
for (auto arg : valuesToCleanup) {
switch (arg.getOwnershipKind()) {
case OwnershipKind::Any:
llvm_unreachable("value with any ownership kind?!");
case OwnershipKind::Guaranteed:
builder.emitEndBorrowOperation(loc, arg);
break;
case OwnershipKind::Owned:
case OwnershipKind::Unowned:
case OwnershipKind::None:
builder.emitDestroyOperation(loc, arg);
break;
}
}
// Deallocate local allocations.
// Emitted in reverse order of allocation.
for (auto *alloc : llvm::reverse(localAllocations))
builder.createDeallocStack(loc, alloc);
// Create return.
builder.createReturn(loc, retVal);
LLVM_DEBUG(auto &s = getADDebugStream() << "Created reabstraction thunk.\n";
s << " From type: " << fromType << '\n';
s << " To type: " << toType << '\n'; s << '\n'
<< *thunk);
return thunk;
}
/// Reabstracts the function value `fn` to `toType` and returns the resulting
/// value.
///
/// A reabstraction thunk between the unsubstituted forms of `fn`'s type and
/// `toType` is obtained via `getOrCreateReabstractionThunk`, then partially
/// applied to `fn`. `convert_function` instructions are emitted before and/or
/// after the `partial_apply` whenever the substituted and unsubstituted types
/// differ.
///
/// \param remapSubstitutions  Callback applied to the thunk's forwarding
///        substitution map to produce the substitutions used by the
///        `partial_apply`.
SILValue reabstractFunction(
    SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
    SILValue fn, CanSILFunctionType toType,
    std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions) {
  auto &silModule = *fn->getModule();
  auto originalFnType = fn->getType().getAs<SILFunctionType>();
  auto sourceUnsubstType = originalFnType->getUnsubstitutedType(silModule);
  auto targetUnsubstType = toType->getUnsubstitutedType(silModule);

  auto *reabstractionThunk = getOrCreateReabstractionThunk(
      fb, silModule, loc,
      /*caller*/ fn->getFunction(), sourceUnsubstType, targetUnsubstType);
  auto *thunkFnRef = builder.createFunctionRef(loc, reabstractionThunk);

  SILValue reabstracted = fn;

  // Bring the input to its unsubstituted type if it is not already there.
  if (originalFnType != sourceUnsubstType) {
    reabstracted = builder.createConvertFunction(
        loc, reabstracted, SILType::getPrimitiveObjectType(sourceUnsubstType),
        /*withoutActuallyEscaping*/ false);
  }

  // Bind the function as the thunk's trailing argument.
  reabstracted = builder.createPartialApply(
      loc, thunkFnRef,
      remapSubstitutions(reabstractionThunk->getForwardingSubstitutionMap()),
      {reabstracted}, originalFnType->getCalleeConvention());

  // Convert to the requested (substituted) target type if necessary.
  if (toType != targetUnsubstType) {
    reabstracted = builder.createConvertFunction(
        loc, reabstracted, SILType::getPrimitiveObjectType(toType),
        /*withoutActuallyEscaping*/ false);
  }

  return reabstracted;
}
/// Returns a thunk that adapts a linear map (differential or pullback) that
/// is differentiable with respect to `actualConfig`'s parameter indices so
/// that it can be used as a linear map with respect to `desiredConfig`'s
/// parameter indices, together with the interface substitution map computed
/// by `buildThunkType`.
///
/// The thunk's last argument is the linear map being wrapped. The body is
/// emitted only if the shared function is still empty:
///   - JVP (differential): thunk arguments are forwarded for desired
///     parameters and zero values are synthesized for the others; the result
///     of the `apply` is returned directly.
///   - VJP (pullback): thunk indirect results are forwarded for desired
///     parameters and temporary stack buffers are passed for the others; only
///     the desired direct results are returned and undesired results are
///     destroyed.
///
/// \param parentThunk    Function supplying the generic context, module, and
///                       location for the new thunk.
/// \param origFnType     Type of the original function (used to detect
///                       `inout` parameters).
/// \param linearMapType  Type of the linear map being thunked.
/// \param targetType     Type the thunk should present.
std::pair<SILFunction *, SubstitutionMap>
getOrCreateSubsetParametersThunkForLinearMap(
    SILOptFunctionBuilder &fb, SILFunction *parentThunk,
    CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
    CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
    AutoDiffConfig desiredConfig, AutoDiffConfig actualConfig) {
  LLVM_DEBUG(getADDebugStream()
             << "Getting a subset parameters thunk for " << linearMapType
             << " from " << actualConfig << " to " << desiredConfig << '\n');

  assert(!linearMapType->getCombinedSubstitutions());
  assert(!targetType->getCombinedSubstitutions());
  SubstitutionMap interfaceSubs;
  GenericEnvironment *genericEnv = nullptr;
  // Note: `linearMapType` and `targetType` are passed by reference and are
  // rewritten into the thunk's context by `buildThunkType`.
  auto thunkType = buildThunkType(parentThunk, linearMapType, targetType,
                                  genericEnv, interfaceSubs,
                                  /*withoutActuallyEscaping*/ true,
                                  DifferentiationThunkKind::Reabstraction);

  // TODO(TF-685): Use more principled mangling for thunks.
  std::string thunkName;
  switch (kind) {
  case AutoDiffDerivativeFunctionKind::JVP:
    thunkName = "differential";
    break;
  case AutoDiffDerivativeFunctionKind::VJP:
    thunkName = "pullback";
  }
  Mangle::ASTMangler mangler;
  auto fromInterfaceType =
      linearMapType->mapTypeOutOfContext()->getCanonicalType();
  auto toInterfaceType = targetType->mapTypeOutOfContext()->getCanonicalType();
  CanType dynamicSelfType;
  thunkName = "AD__" +
              mangler.mangleReabstractionThunkHelper(
                  thunkType, fromInterfaceType, toInterfaceType,
                  dynamicSelfType, parentThunk->getModule().getSwiftModule()) +
              "_" + desiredConfig.mangle() + "_" + thunkName;
  thunkName += "_index_subset_thunk";

  auto loc = parentThunk->getLocation();
  auto *thunk = fb.getOrCreateSharedFunction(
      loc, thunkName, thunkType, IsBare, IsTransparent, IsSerialized,
      ProfileCounter(), IsThunk, IsNotDynamic);

  // A non-empty function already has a body; reuse it.
  if (!thunk->empty())
    return {thunk, interfaceSubs};

  // TODO(TF-1206): Enable ownership in all differentiation thunks.
  thunk->setOwnershipEliminated();
  thunk->setGenericEnvironment(genericEnv);
  auto *entry = thunk->createBasicBlock();
  SILBuilder builder(entry);
  createEntryArguments(thunk);

  // Get arguments.
  SmallVector<SILValue, 4> arguments;
  SmallVector<AllocStackInst *, 4> localAllocations;

  // Build a `.zero` argument for the given `Differentiable`-conforming type.
  // The zero is materialized in a stack buffer; the buffer itself is passed
  // for address types, otherwise the value is loaded out of it.
  auto buildZeroArgument = [&](SILType zeroSILType) {
    auto zeroSILObjType = zeroSILType.getObjectType();
    auto zeroType = zeroSILType.getASTType();
    auto *swiftMod = parentThunk->getModule().getSwiftModule();
    auto tangentSpace =
        zeroType->getAutoDiffTangentSpace(LookUpConformanceInModule(swiftMod));
    assert(tangentSpace && "No tangent space for this type");
    switch (tangentSpace->getKind()) {
    case TangentSpace::Kind::TangentVector: {
      auto *buf = builder.createAllocStack(loc, zeroSILObjType);
      localAllocations.push_back(buf);
      emitZeroIntoBuffer(builder, zeroType, buf, loc);
      if (zeroSILType.isAddress()) {
        arguments.push_back(buf);
      } else {
        auto arg = builder.emitLoadValueOperation(loc, buf,
                                                  LoadOwnershipQualifier::Take);
        arguments.push_back(arg);
      }
      break;
    }
    case TangentSpace::Kind::Tuple: {
      llvm_unreachable("Unimplemented: Handle zero initialization for tuples");
    }
    }
  };

  // The indices in `actualConfig` and `desiredConfig` are with respect to the
  // original function. However, the differential parameters and pullback
  // results may already be w.r.t. a subset. We create a map between the
  // original function's actual parameter indices and the linear map's actual
  // indices.
  // Example:
  //   Original: (T0, T1, T2) -> R
  //   Actual indices: 0, 2
  //   Original differential: (T0, T2) -> R
  //   Original pullback: R -> (T0, T2)
  //   Desired indices w.r.t. original: 2
  //   Desired indices w.r.t. linear map: 1
  // Entries for indices not in `actualConfig` stay at `UINT_MAX`.
  SmallVector<unsigned, 4> actualParamIndicesMap(
      actualConfig.parameterIndices->getCapacity(), UINT_MAX);
  {
    unsigned indexInBitVec = 0;
    for (auto index : actualConfig.parameterIndices->getIndices()) {
      actualParamIndicesMap[index] = indexInBitVec;
      ++indexInBitVec;
    }
  }
  auto mapOriginalParameterIndex = [&](unsigned index) -> unsigned {
    auto mappedIndex = actualParamIndicesMap[index];
    assert(mappedIndex < actualConfig.parameterIndices->getCapacity());
    return mappedIndex;
  };

  switch (kind) {
  // Differential arguments are:
  // - All indirect results, followed by:
  // - An interleaving of:
  //   - Thunk arguments (when parameter index is in both desired and actual
  //     indices).
  //   - Zeros (when parameter is not in desired indices).
  case AutoDiffDerivativeFunctionKind::JVP: {
    // Forward all indirect results.
    arguments.append(thunk->getIndirectResults().begin(),
                     thunk->getIndirectResults().end());
    auto toArgIter = thunk->getArgumentsWithoutIndirectResults().begin();
    auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); };
    // Iterate over actual indices.
    for (unsigned i : actualConfig.parameterIndices->getIndices()) {
      // If index is desired, use next argument.
      if (desiredConfig.isWrtParameter(i)) {
        useNextArgument();
      }
      // Otherwise, construct and use a zero argument.
      else {
        auto zeroSILType =
            linearMapType->getParameters()[mapOriginalParameterIndex(i)]
                .getSILStorageInterfaceType();
        buildZeroArgument(zeroSILType);
      }
    }
    break;
  }
  // Pullback arguments are:
  // - An interleaving of:
  //   - Thunk indirect results (when parameter index is in both desired and
  //     actual indices).
  //   - Zeros (when parameter is not in desired indices).
  // - All actual arguments.
  case AutoDiffDerivativeFunctionKind::VJP: {
    auto toIndirectResultsIter = thunk->getIndirectResults().begin();
    auto useNextIndirectResult = [&]() {
      arguments.push_back(*toIndirectResultsIter++);
    };
    // Collect pullback arguments.
    unsigned pullbackResultIndex = 0;
    for (unsigned i : actualConfig.parameterIndices->getIndices()) {
      auto origParamInfo = origFnType->getParameters()[i];
      // Skip original `inout` parameters. All non-indirect-result pullback
      // arguments (including `inout` arguments) are appended to `arguments`
      // later.
      if (origParamInfo.isIndirectMutating())
        continue;
      // Check the index before using it to access the results.
      assert(pullbackResultIndex < linearMapType->getNumResults());
      auto resultInfo = linearMapType->getResults()[pullbackResultIndex];
      ++pullbackResultIndex;
      // Skip pullback direct results. Only indirect results are relevant as
      // arguments.
      if (resultInfo.isFormalDirect())
        continue;
      // If index is desired, use next pullback indirect result.
      if (desiredConfig.isWrtParameter(i)) {
        useNextIndirectResult();
        continue;
      }
      // Otherwise, allocate and use an uninitialized pullback indirect result.
      auto *indirectResult = builder.createAllocStack(
          loc, resultInfo.getSILStorageInterfaceType());
      localAllocations.push_back(indirectResult);
      arguments.push_back(indirectResult);
    }
    // Forward all actual non-indirect-result arguments.
    // The final thunk argument (the linear map itself) is excluded.
    arguments.append(thunk->getArgumentsWithoutIndirectResults().begin(),
                     thunk->getArgumentsWithoutIndirectResults().end() - 1);
    break;
  }
  }

  // Get the linear map thunk argument and apply it.
  auto *linearMap = thunk->getArguments().back();
  auto *ai = builder.createApply(loc, linearMap, SubstitutionMap(), arguments,
                                 /*isNonThrowing*/ false);

  // If differential thunk, deallocate local allocations and directly return
  // `apply` result.
  if (kind == AutoDiffDerivativeFunctionKind::JVP) {
    for (auto *alloc : llvm::reverse(localAllocations))
      builder.createDeallocStack(loc, alloc);
    builder.createReturn(loc, ai);
    return {thunk, interfaceSubs};
  }

  // If pullback thunk, return only the desired results and clean up the
  // undesired results.
  SmallVector<SILValue, 8> pullbackDirectResults;
  extractAllElements(ai, builder, pullbackDirectResults);
  SmallVector<SILValue, 8> allResults;
  collectAllActualResultsInTypeOrder(ai, pullbackDirectResults, allResults);
  // Collect pullback `inout` arguments in type order.
  // Each one is inserted at the position of its parameter within the linear
  // map's actual indices.
  unsigned inoutArgIdx = 0;
  SILFunctionConventions origConv(origFnType, thunk->getModule());
  for (auto paramIdx : actualConfig.parameterIndices->getIndices()) {
    auto paramInfo = origConv.getParameters()[paramIdx];
    if (!paramInfo.isIndirectMutating())
      continue;
    auto inoutArg = *std::next(ai->getInoutArguments().begin(), inoutArgIdx++);
    unsigned mappedParamIdx = mapOriginalParameterIndex(paramIdx);
    allResults.insert(allResults.begin() + mappedParamIdx, inoutArg);
  }
  assert(allResults.size() == actualConfig.parameterIndices->getNumIndices() &&
         "Number of pullback results should match number of differentiability "
         "parameters");

  SmallVector<SILValue, 8> results;
  for (unsigned i : actualConfig.parameterIndices->getIndices()) {
    unsigned mappedIndex = mapOriginalParameterIndex(i);
    // If result is desired:
    // - Do nothing if result is indirect.
    //   (It was already forwarded to the `apply` instruction).
    // - Push it to `results` if result is direct.
    auto result = allResults[mappedIndex];
    if (desiredConfig.isWrtParameter(i)) {
      if (result->getType().isObject())
        results.push_back(result);
    }
    // Otherwise, cleanup the unused results.
    else {
      if (result->getType().isAddress())
        builder.emitDestroyAddrAndFold(loc, result);
      else
        builder.emitDestroyValueOperation(loc, result);
    }
  }
  // Deallocate local allocations and return final direct result.
  for (auto *alloc : llvm::reverse(localAllocations))
    builder.createDeallocStack(loc, alloc);
  auto result = joinElements(results, builder, loc);
  builder.createReturn(loc, result);
  return {thunk, interfaceSubs};
}
std::pair<SILFunction *, SubstitutionMap>
getOrCreateSubsetParametersThunkForDerivativeFunction(
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
AutoDiffDerivativeFunctionKind kind, AutoDiffConfig desiredConfig,
AutoDiffConfig actualConfig) {
LLVM_DEBUG(getADDebugStream()
<< "Getting a subset parameters thunk for derivative function "
<< derivativeFn << " of the original function " << origFnOperand
<< " from " << actualConfig << " to " << desiredConfig << '\n');
auto origFnType = origFnOperand->getType().castTo<SILFunctionType>();
auto &module = fb.getModule();
auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());
// Compute target type for thunking.
auto derivativeFnType = derivativeFn->getType().castTo<SILFunctionType>();
auto targetType = origFnType->getAutoDiffDerivativeFunctionType(
desiredConfig.parameterIndices, desiredConfig.resultIndices, kind,
module.Types, lookupConformance);
auto *caller = derivativeFn->getFunction();
if (targetType->hasArchetype()) {
auto substTargetType =
caller->mapTypeIntoContext(targetType->mapTypeOutOfContext())
->getCanonicalType();
targetType = SILType::getPrimitiveObjectType(substTargetType)
.castTo<SILFunctionType>();
}
assert(derivativeFnType->getNumParameters() ==
targetType->getNumParameters());
assert(derivativeFnType->getNumResults() == targetType->getNumResults());
// Build thunk type.
SubstitutionMap interfaceSubs;
GenericEnvironment *genericEnv = nullptr;
auto thunkType = buildThunkType(derivativeFn->getFunction(), derivativeFnType,
targetType, genericEnv, interfaceSubs,
/*withoutActuallyEscaping*/ false,
DifferentiationThunkKind::IndexSubset);
// FIXME: The logic for resolving `assocRef` does not reapply function
// conversions, which is problematic if `derivativeFn` is a `partial_apply`
// instruction.
StringRef origName;
if (auto *origFnRef =
peerThroughFunctionConversions<FunctionRefInst>(origFnOperand)) {
origName = origFnRef->getInitiallyReferencedFunction()->getName();
} else if (auto *origMethodInst =
peerThroughFunctionConversions<MethodInst>(origFnOperand)) {
origName = origMethodInst->getMember()
.getAnyFunctionRef()
->getAbstractFunctionDecl()
->getNameStr();
}
assert(!origName.empty() && "Original function name could not be resolved");
// TODO(TF-685): Use more principled mangling for thunks.
std::string thunkName;
switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP:
thunkName = "jvp";
break;
case AutoDiffDerivativeFunctionKind::VJP:
thunkName = "vjp";
}
Mangle::ASTMangler mangler;
auto fromInterfaceType =
derivativeFnType->mapTypeOutOfContext()->getCanonicalType();
auto toInterfaceType = targetType->mapTypeOutOfContext()->getCanonicalType();
CanType dynamicSelfType;
thunkName = "AD__orig_" + origName.str() + "_" +
mangler.mangleReabstractionThunkHelper(
thunkType, fromInterfaceType, toInterfaceType,
dynamicSelfType, module.getSwiftModule()) +
"_" + desiredConfig.mangle() + "_" + thunkName;
thunkName += "_subset_parameters_thunk";
auto loc = origFnOperand.getLoc();
auto *thunk = fb.getOrCreateSharedFunction(
loc, thunkName, thunkType, IsBare, IsTransparent, caller->isSerialized(),
ProfileCounter(), IsThunk, IsNotDynamic);
if (!thunk->empty())
return {thunk, interfaceSubs};
thunk->setGenericEnvironment(genericEnv);
auto *entry = thunk->createBasicBlock();
SILBuilder builder(entry);
createEntryArguments(thunk);
SubstitutionMap assocSubstMap;
if (auto *partialApply = dyn_cast<PartialApplyInst>(derivativeFn))
assocSubstMap = partialApply->getSubstitutionMap();
// FIXME: The logic for resolving `assocRef` does not reapply function
// conversions, which is problematic if `derivativeFn` is a `partial_apply`
// instruction.
SILValue assocRef;
if (auto *derivativeFnRef =
peerThroughFunctionConversions<FunctionRefInst>(derivativeFn)) {
auto *assoc = derivativeFnRef->getReferencedFunctionOrNull();
assocRef = builder.createFunctionRef(loc, assoc);
} else if (auto *assocMethodInst =
peerThroughFunctionConversions<WitnessMethodInst>(
derivativeFn)) {
assocRef = builder.createWitnessMethod(
loc, assocMethodInst->getLookupType(),
assocMethodInst->getConformance(), assocMethodInst->getMember(),
thunk->mapTypeIntoContext(assocMethodInst->getType()));
} else if (auto *assocMethodInst =
peerThroughFunctionConversions<ClassMethodInst>(
derivativeFn)) {
auto classOperand = thunk->getArgumentsWithoutIndirectResults().back();
#ifndef NDEBUG
auto classOperandType = assocMethodInst->getOperand()->getType();
assert(classOperand->getType() == classOperandType);
#endif
assocRef = builder.createClassMethod(
loc, classOperand, assocMethodInst->getMember(),
thunk->mapTypeIntoContext(assocMethodInst->getType()));
} else if (auto *diffWitFn = peerThroughFunctionConversions<
DifferentiabilityWitnessFunctionInst>(derivativeFn)) {
assocRef = builder.createDifferentiabilityWitnessFunction(
loc, diffWitFn->getWitnessKind(), diffWitFn->getWitness());
}
assert(assocRef && "Expected derivative function to be resolved");
assocSubstMap = assocSubstMap.subst(thunk->getForwardingSubstitutionMap());
derivativeFnType = assocRef->getType().castTo<SILFunctionType>();
SmallVector<SILValue, 4> arguments;
arguments.append(thunk->getArguments().begin(), thunk->getArguments().end());
assert(arguments.size() ==
derivativeFnType->getNumParameters() +
derivativeFnType->getNumIndirectFormalResults());
auto *apply = builder.createApply(loc, assocRef, assocSubstMap, arguments,
/*isNonThrowing*/ false);
// Extract all direct results.
SmallVector<SILValue, 8> directResults;
extractAllElements(apply, builder, directResults);
auto originalDirectResults = ArrayRef<SILValue>(directResults).drop_back(1);
auto originalDirectResult =
joinElements(originalDirectResults, builder, apply->getLoc());
auto linearMap = directResults.back();
auto linearMapType = linearMap->getType().castTo<SILFunctionType>();
auto linearMapTargetType = targetType->getResults()
.back()
.getSILStorageInterfaceType()
.castTo<SILFunctionType>();
auto unsubstLinearMapType = linearMapType->getUnsubstitutedType(module);
auto unsubstLinearMapTargetType =
linearMapTargetType->getUnsubstitutedType(module);
SILFunction *linearMapThunk;
SubstitutionMap linearMapSubs;
std::tie(linearMapThunk, linearMapSubs) =
getOrCreateSubsetParametersThunkForLinearMap(
fb, thunk, origFnType, unsubstLinearMapType,
unsubstLinearMapTargetType, kind, desiredConfig, actualConfig);
auto *linearMapThunkFRI = builder.createFunctionRef(loc, linearMapThunk);
SILValue thunkedLinearMap = linearMap;
if (linearMapType != unsubstLinearMapType) {
thunkedLinearMap = builder.createConvertFunction(
loc, thunkedLinearMap,
SILType::getPrimitiveObjectType(unsubstLinearMapType),
/*withoutActuallyEscaping*/ false);
}
thunkedLinearMap = builder.createPartialApply(
loc, linearMapThunkFRI, linearMapSubs, {thunkedLinearMap},
ParameterConvention::Direct_Guaranteed);
if (linearMapTargetType != unsubstLinearMapTargetType) {
thunkedLinearMap = builder.createConvertFunction(
loc, thunkedLinearMap,
SILType::getPrimitiveObjectType(linearMapTargetType),
/*withoutActuallyEscaping*/ false);
}
assert(origFnType->getNumResults() +
origFnType->getNumIndirectMutatingParameters() ==
1);
if (origFnType->getNumResults() > 0 &&
origFnType->getResults().front().isFormalDirect()) {
auto result =
joinElements({originalDirectResult, thunkedLinearMap}, builder, loc);
builder.createReturn(loc, result);
} else {
builder.createReturn(loc, thunkedLinearMap);
}
return {thunk, interfaceSubs};
}
} // end namespace autodiff
} // end namespace swift