| //===--- JVPCloner.cpp - JVP function generation --------------*- C++ -*---===// |
| // |
| // This source file is part of the Swift.org open source project |
| // |
| // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
| // Licensed under Apache License v2.0 with Runtime Library Exception |
| // |
| // See https://swift.org/LICENSE.txt for license information |
| // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines a helper class for generating JVP functions for automatic |
| // differentiation. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #define DEBUG_TYPE "differentiation" |
| |
| #include "swift/SILOptimizer/Differentiation/JVPCloner.h" |
| #include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" |
| #include "swift/SILOptimizer/Differentiation/ADContext.h" |
| #include "swift/SILOptimizer/Differentiation/AdjointValue.h" |
| #include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" |
| #include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" |
| #include "swift/SILOptimizer/Differentiation/PullbackCloner.h" |
| #include "swift/SILOptimizer/Differentiation/Thunk.h" |
| |
| #include "swift/SIL/LoopInfo.h" |
| #include "swift/SIL/TypeSubstCloner.h" |
| #include "swift/SILOptimizer/Analysis/LoopAnalysis.h" |
| #include "swift/SILOptimizer/PassManager/PrettyStackTrace.h" |
| #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" |
| #include "llvm/ADT/DenseMap.h" |
| |
| using namespace swift; |
| using namespace autodiff; |
| |
| namespace swift { |
| namespace autodiff { |
| |
| class JVPCloner::Implementation final |
| : public TypeSubstCloner<JVPCloner::Implementation, SILOptFunctionBuilder> { |
| private: |
| /// The global context. |
| ADContext &context; |
| |
| /// The original function. |
| SILFunction *const original; |
| |
| /// The differentiability witness. Supplies the AutoDiff configuration |
| /// (`getConfig`) and the derivative generic signature (`getTangentSpace`). |
| SILDifferentiabilityWitness *const witness; |
| |
| /// The JVP function. |
| SILFunction *const jvp; |
| |
| /// Allocator passed to `AdjointValue::createZero`/`createConcrete` when |
| /// tangent values are created. |
| llvm::BumpPtrAllocator allocator; |
| |
| /// The differentiation invoker. |
| DifferentiationInvoker invoker; |
| |
| /// Info from activity analysis on the original function. |
| const DifferentiableActivityInfo &activityInfo; |
| |
| /// The loop info. |
| SILLoopInfo *loopInfo; |
| |
| /// The differential info. |
| LinearMapInfo differentialInfo; |
| |
| /// Set once a non-differentiability error has been emitted. While set, |
| /// `visit` and `postProcess` return without doing any work. |
| bool errorOccurred = false; |
| |
| //--------------------------------------------------------------------------// |
| // Differential generation related fields |
| //--------------------------------------------------------------------------// |
| |
| /// The builder for the differential function. |
| SILBuilder differentialBuilder; |
| |
| /// Mapping from original basic blocks to corresponding differential basic |
| /// blocks. |
| llvm::DenseMap<SILBasicBlock *, SILBasicBlock *> diffBBMap; |
| |
| /// Mapping from original values to corresponding tangent values. The key is |
| /// the original value alone; no basic block is part of the key. |
| llvm::DenseMap<SILValue, AdjointValue> tangentValueMap; |
| |
| /// Mapping from original basic blocks and original buffers to corresponding |
| /// tangent buffers. |
| llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap; |
| |
| /// Mapping from basic blocks to differential struct arguments. |
| /// NOTE(review): `getDifferentialStructArgument` indexes this map with an |
| /// original basic block; confirm the key used where the map is populated. |
| llvm::DenseMap<SILBasicBlock *, SILArgument *> differentialStructArguments; |
| |
| /// Mapping from differential struct field declarations to differential struct |
| /// elements destructured from the linear map basic block argument. In the |
| /// beginning of each differential basic block, the block's differential |
| /// struct is destructured into the individual elements stored here. |
| llvm::DenseMap<VarDecl *, SILValue> differentialStructElements; |
| |
| /// An auxiliary differential local allocation builder. |
| SILBuilder diffLocalAllocBuilder; |
| |
| /// Stack buffers allocated for storing local tangent values. |
| SmallVector<SILValue, 8> differentialLocalAllocations; |
| |
| /// Mapping from original blocks to differential values. Used to build |
| /// differential struct instances: `visitApplyInst` appends the differential |
| /// returned by each JVP call, and `buildDifferentialValueStructValue` reads |
| /// the list back. |
| llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> differentialValues; |
| |
| //--------------------------------------------------------------------------// |
| // Getters |
| //--------------------------------------------------------------------------// |
| |
| /// The AST context, obtained from the JVP function. |
| ASTContext &getASTContext() const { return jvp->getASTContext(); } |
| /// The SIL module, obtained from the JVP function. |
| SILModule &getModule() const { return jvp->getModule(); } |
| /// The AutoDiff configuration stored in the differentiability witness. |
| const AutoDiffConfig getConfig() const { return witness->getConfig(); } |
| /// The builder used to emit instructions into the differential. |
| SILBuilder &getDifferentialBuilder() { return differentialBuilder; } |
| /// The function `differentialBuilder` is building into (the differential). |
| SILFunction &getDifferential() { return differentialBuilder.getFunction(); } |
| /// Return the differential struct argument recorded for `origBB`. In |
| /// asserts builds, additionally check that the argument's struct declaration |
| /// is the linear map struct that `differentialInfo` reports for the block. |
| SILArgument *getDifferentialStructArgument(SILBasicBlock *origBB) { |
|   SILArgument *structArg = differentialStructArguments[origBB]; |
| #ifndef NDEBUG |
|   auto *argStructDecl = structArg->getType().getStructOrBoundGenericStruct(); |
|   assert(argStructDecl == differentialInfo.getLinearMapStruct(origBB)); |
| #endif |
|   return structArg; |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Differential struct mapping |
| //--------------------------------------------------------------------------// |
| |
| /// Called by `visitInstructionsInBlock` with the results of the |
| /// `destructure_struct` emitted for the differential struct of `origBB`. |
| /// Defined out of line. |
| /// NOTE(review): presumably records `values` in `differentialStructElements`; |
| /// confirm against the definition. |
| void initializeDifferentialStructElements(SILBasicBlock *origBB, |
| SILInstructionResultArray values); |
| |
| /// Defined out of line. |
| /// NOTE(review): presumably returns the destructured differential struct |
| /// element for `field` in the struct of `origBB`; confirm against the |
| /// definition. |
| SILValue getDifferentialStructElement(SILBasicBlock *origBB, VarDecl *field); |
| |
| //--------------------------------------------------------------------------// |
| // General utilities |
| //--------------------------------------------------------------------------// |
| |
| /// Lower the AST type `type` to a SIL type. The abstraction pattern is built |
| /// from the JVP's substituted generic signature and the type canonicalized |
| /// in that signature. |
| SILType getLoweredType(Type type) { |
|   auto genericSig = jvp->getLoweredFunctionType()->getSubstGenericSignature(); |
|   auto canonicalType = type->getCanonicalType(genericSig); |
|   Lowering::AbstractionPattern origPattern(genericSig, canonicalType); |
|   return jvp->getLoweredType(origPattern, type); |
| } |
| |
| /// Lower the declared interface type of `nominal` to a SIL type, after |
| /// remapping it with `getOpASTType`. |
| SILType getNominalDeclLoweredType(NominalTypeDecl *nominal) { |
|   auto declaredType = nominal->getDeclaredInterfaceType()->getCanonicalType(); |
|   auto remappedType = getOpASTType(declaredType); |
|   return getLoweredType(remappedType); |
| } |
| |
| /// Emit, in the JVP, a `struct` instruction holding the differential values |
| /// collected for the original block that `termInst` terminates. |
| /// |
| /// For a non-entry block, the last argument of the corresponding JVP block |
| /// is placed in front of the collected differential values. |
| StructInst *buildDifferentialValueStructValue(TermInst *termInst) { |
|   assert(termInst->getFunction() == original); |
|   auto structLoc = termInst->getFunction()->getLocation(); |
|   auto *originalBlock = termInst->getParent(); |
|   auto *clonedBlock = BBMap[originalBlock]; |
|   assert(clonedBlock && "Basic block mapping should exist"); |
|   auto *linearMapStruct = differentialInfo.getLinearMapStruct(originalBlock); |
|   assert(linearMapStruct && |
|          "The differential struct should have been declared"); |
|   auto loweredStructTy = getNominalDeclLoweredType(linearMapStruct); |
|   // Copy the recorded values: the list stored in `differentialValues` must |
|   // not pick up the extra leading element added below. |
|   auto structElements = differentialValues[originalBlock]; |
|   if (!originalBlock->isEntry()) { |
|     auto *trailingArg = clonedBlock->getArguments().back(); |
|     structElements.insert(structElements.begin(), trailingArg); |
|   } |
|   auto &jvpBuilder = getBuilder(); |
|   return jvpBuilder.createStruct(structLoc, loweredStructTy, structElements); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Tangent value factory methods |
| //--------------------------------------------------------------------------// |
| |
| /// Create a symbolic zero tangent value of `type`, after remapping the type |
| /// into the differential's context. |
| AdjointValue makeZeroTangentValue(SILType type) { |
|   auto typeInDifferential = remapSILTypeInDifferential(type); |
|   return AdjointValue::createZero(allocator, typeInDifferential); |
| } |
| |
| /// Wrap the SIL value `value` as a concrete tangent value, allocated from |
| /// `allocator`. |
| AdjointValue makeConcreteTangentValue(SILValue value) { |
| return AdjointValue::createConcrete(allocator, value); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Tangent materialization |
| //--------------------------------------------------------------------------// |
| |
| /// Emit code in the differential that zero-initializes the buffer |
| /// `bufferAccess`, whose type is `type`. |
| /// |
| /// If the tangent space of `type` is a tangent vector, the zero is emitted |
| /// with `emitZeroIntoBuffer`. If it is a tuple, each element address is |
| /// projected and zero-initialized recursively. |
| void emitZeroIndirect(CanType type, SILValue bufferAccess, SILLocation loc) { |
|   // Use the differential builder itself rather than a copy of it. |
|   auto &builder = getDifferentialBuilder(); |
|   auto tangentSpace = getTangentSpace(type); |
|   assert(tangentSpace && "No tangent space for this type"); |
|   switch (tangentSpace->getKind()) { |
|   case TangentSpace::Kind::TangentVector: |
|     emitZeroIntoBuffer(builder, type, bufferAccess, loc); |
|     return; |
|   case TangentSpace::Kind::Tuple: { |
|     auto tupleType = tangentSpace->getTuple(); |
|     for (unsigned i : range(tupleType->getNumElements())) { |
|       auto eltAddr = builder.createTupleElementAddr(loc, bufferAccess, i); |
|       emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(), |
|                        eltAddr, loc); |
|     } |
|     return; |
|   } |
|   } |
| } |
| |
| /// Emit code in the differential that produces a zero value of `type` as a |
| /// direct (loadable) value. |
| /// |
| /// The zero is written into a temporary stack buffer with |
| /// `emitZeroIndirect`, loaded out of it with `take` ownership, and the |
| /// buffer is deallocated again. |
| SILValue emitZeroDirect(CanType type, SILLocation loc) { |
|   // Use the differential builder itself rather than a copy of it. |
|   auto &diffBuilder = getDifferentialBuilder(); |
|   auto silType = getModule().Types.getLoweredLoadableType( |
|       type, TypeExpansionContext::minimal(), getModule()); |
|   auto *buffer = diffBuilder.createAllocStack(loc, silType); |
|   emitZeroIndirect(type, buffer, loc); |
|   auto loaded = diffBuilder.emitLoadValueOperation( |
|       loc, buffer, LoadOwnershipQualifier::Take); |
|   diffBuilder.createDeallocStack(loc, buffer); |
|   return loaded; |
| } |
| |
| /// Turn the object-typed tangent value `val` into a SIL value: |
| /// - a concrete value is returned as is; |
| /// - a zero value is emitted with `emitZeroDirect`; |
| /// - an aggregate value is not supported and is unreachable. |
| SILValue materializeTangentDirect(AdjointValue val, SILLocation loc) { |
|   assert(val.getType().isObject()); |
|   LLVM_DEBUG(getADDebugStream() |
|              << "Materializing tangents for " << val << '\n'); |
|   switch (val.getKind()) { |
|   case AdjointValueKind::Concrete: |
|     return val.getConcreteValue(); |
|   case AdjointValueKind::Zero: |
|     return emitZeroDirect(val.getSwiftType(), loc); |
|   case AdjointValueKind::Aggregate: |
|     llvm_unreachable( |
|         "Tuples and structs are not supported in forward mode yet."); |
|   } |
|   llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715 |
| } |
| |
| /// Return a SIL value for the tangent value `val`. A concrete value is |
| /// returned directly; anything else goes through |
| /// `materializeTangentDirect`. |
| SILValue materializeTangent(AdjointValue val, SILLocation loc) { |
|   if (!val.isConcrete()) { |
|     LLVM_DEBUG(getADDebugStream() << "Materializing tangent: Value is " |
|                                      "non-concrete. Materializing directly.\n"); |
|     return materializeTangentDirect(val, loc); |
|   } |
|   LLVM_DEBUG(getADDebugStream() |
|              << "Materializing tangent: Value is concrete.\n"); |
|   return val.getConcreteValue(); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Tangent value mapping |
| //--------------------------------------------------------------------------// |
| |
| /// Get the tangent for an original value. The given value must be an |
| /// object-typed value in the original function. |
| /// |
| /// This method first tries to find an entry in `tangentValueMap`. If an entry |
| /// doesn't exist, a zero tangent of the remapped tangent type is created, |
| /// recorded in the map, and returned. |
| AdjointValue getTangentValue(SILValue originalValue) { |
|   assert(originalValue->getType().isObject()); |
|   assert(originalValue->getFunction() == original); |
|   // Look up first: building the zero tangent allocates from `allocator`, so |
|   // only do it when the value has no tangent yet. |
|   auto existing = tangentValueMap.find(originalValue); |
|   if (existing != tangentValueMap.end()) |
|     return existing->getSecond(); |
|   auto zeroTangent = |
|       makeZeroTangentValue(getRemappedTangentType(originalValue->getType())); |
|   auto insertion = tangentValueMap.try_emplace(originalValue, zeroTangent); |
|   return insertion.first->getSecond(); |
| } |
| |
| /// Record `newTangentValue` as the tangent of `originalValue`. |
| /// |
| /// Both values must be object-typed, `originalValue` must belong to the |
| /// original function, the tangent must have the remapped tangent type of the |
| /// original value, and the original value must not have a tangent yet. All |
| /// of these are checked by assertions only. |
| void setTangentValue(SILBasicBlock *origBB, SILValue originalValue, |
|                      AdjointValue newTangentValue) { |
| #ifndef NDEBUG |
|   if (auto *definingInst = originalValue->getDefiningInstruction()) { |
|     bool isTupleTypedApplyResult = |
|         originalValue->getType().is<TupleType>() && |
|         isa<ApplyInst>(definingInst); |
|     assert(!isTupleTypedApplyResult && |
|            "Should not set tangent value for tuple-typed result from `apply` " |
|            "instruction; use `destructure_tuple` on `apply` result and set " |
|            "tangent value for `destructure_tuple` results instead."); |
|   } |
| #endif |
|   assert(originalValue->getType().isObject()); |
|   assert(newTangentValue.getType().isObject()); |
|   assert(originalValue->getFunction() == original); |
|   LLVM_DEBUG(getADDebugStream() |
|              << "Setting tangent value for " << originalValue); |
|   // The tangent value must be in the tangent space. |
|   assert(newTangentValue.getType() == |
|          getRemappedTangentType(originalValue->getType())); |
|   bool wasInserted = |
|       tangentValueMap.try_emplace(originalValue, newTangentValue).second; |
|   (void)wasInserted; |
|   assert(wasInserted && "The tangent value should not already exist."); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Tangent buffer mapping |
| //--------------------------------------------------------------------------// |
| |
| /// Record `tangentBuffer` as the tangent buffer of the address-typed |
| /// `originalBuffer` in block `origBB`. Asserts that no tangent buffer was |
| /// recorded for that (block, buffer) pair before. |
| void setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer, |
|                       SILValue tangentBuffer) { |
|   assert(originalBuffer->getType().isAddress()); |
|   bool wasInserted = |
|       bufferMap.try_emplace({origBB, originalBuffer}, tangentBuffer).second; |
|   assert(wasInserted && "Tangent buffer already exists"); |
|   (void)wasInserted; |
| } |
| |
| /// Return a reference to the tangent buffer recorded for the address-typed |
| /// `originalBuffer` in block `origBB`. Asserts that the buffer belongs to the |
| /// original function and that a tangent buffer has been recorded for it. |
| SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) { |
|   assert(originalBuffer->getType().isAddress()); |
|   assert(originalBuffer->getFunction() == original); |
|   auto entry = bufferMap.find({origBB, originalBuffer}); |
|   assert(entry != bufferMap.end() && "Tangent buffer should already exist"); |
|   return entry->getSecond(); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Differential type calculations |
| //--------------------------------------------------------------------------// |
| |
| /// Apply the differential function's forwarding substitution map to the |
| /// replacement types of `substMap` and return the result. |
| SubstitutionMap remapSubstitutionMapInDifferential(SubstitutionMap substMap) { |
|   auto differentialSubs = getDifferential().getForwardingSubstitutionMap(); |
|   return substMap.subst(differentialSubs); |
| } |
| |
| /// Map `ty` into the differential function's context. A type that contains |
| /// archetypes is first mapped out of its current context. |
| Type remapTypeInDifferential(Type ty) { |
|   Type typeToMap = ty->hasArchetype() ? ty->mapTypeOutOfContext() : ty; |
|   return getDifferential().mapTypeIntoContext(typeToMap); |
| } |
| |
| /// Map the SIL type `ty` into the differential function's context. A type |
| /// that contains archetypes is first mapped out of its current context. |
| SILType remapSILTypeInDifferential(SILType ty) { |
|   SILType typeToMap = ty.hasArchetype() ? ty.mapTypeOutOfContext() : ty; |
|   return getDifferential().mapTypeIntoContext(typeToMap); |
| } |
| |
| /// Look up the tangent space of the canonical type `type`, resolving |
| /// conformances in the module's Swift module. If the witness has a |
| /// derivative generic signature, the type is first canonicalized in it. |
| Optional<TangentSpace> getTangentSpace(CanType type) { |
|   auto derivativeGenSig = witness->getDerivativeGenericSignature(); |
|   if (derivativeGenSig) |
|     type = derivativeGenSig->getCanonicalTypeInContext(type); |
|   auto *swiftModule = getModule().getSwiftModule(); |
|   return type->getAutoDiffTangentSpace(LookUpConformanceInModule(swiftModule)); |
| } |
| |
| /// Return the tangent space SIL type of `type`, with the same category |
| /// (object or address) as `type`. The type is remapped into the |
| /// differential's context first. The tangent space is assumed to exist: the |
| /// lookup result is dereferenced without a check. |
| SILType getRemappedTangentType(SILType type) { |
|   auto remappedASTType = remapSILTypeInDifferential(type).getASTType(); |
|   auto tangentASTType = getTangentSpace(remappedASTType)->getCanonicalType(); |
|   return SILType::getPrimitiveType(tangentASTType, type.getCategory()); |
| } |
| |
| /// Set up the differential function. This includes: |
| /// - Creating all differential blocks. |
| /// - Creating differential entry block arguments based on the function type. |
| /// - Creating tangent value mapping for original/differential parameters. |
| /// - Checking for unvaried result and emitting related warnings. |
| /// |
| /// Only declared here; the definition is out of line. |
| void prepareForDifferentialGeneration(); |
| |
| public: |
| /// Construct the cloner for the given original function, differentiability |
| /// witness, JVP function and invoker. Defined out of line. |
| explicit Implementation(ADContext &context, SILFunction *original, |
| SILDifferentiabilityWitness *witness, |
| SILFunction *jvp, DifferentiationInvoker invoker); |
| |
| /// Defined out of line. |
| /// NOTE(review): presumably creates the still-empty differential function |
| /// for `witness` using `linearMapInfo`; confirm against the definition. |
| static SILFunction * |
| createEmptyDifferential(ADContext &context, |
| SILDifferentiabilityWitness *witness, |
| LinearMapInfo *linearMapInfo); |
| |
| /// Run JVP generation. Returns true on error. |
| bool run(); |
| |
| /// The JVP function this cloner generates into. |
| SILFunction &getJVP() const { return *jvp; } |
| |
| void postProcess(SILInstruction *orig, SILInstruction *cloned) { |
| if (errorOccurred) |
| return; |
| SILClonerWithScopes::postProcess(orig, cloned); |
| } |
| |
| /// Return the JVP basic block that `BBMap` holds for the original block |
| /// `bb`. |
| SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) { return BBMap[bb]; } |
| |
| /// General visitor for all instructions. If any error is emitted by previous |
| /// visits, bail out. |
| void visit(SILInstruction *inst) { |
| if (errorOccurred) |
| return; |
| if (differentialInfo.shouldDifferentiateInstruction(inst)) { |
| LLVM_DEBUG(getADDebugStream() << "JVPCloner visited:\n[ORIG]" << *inst); |
| #ifndef NDEBUG |
| auto diffBuilder = getDifferentialBuilder(); |
| auto beforeInsertion = std::prev(diffBuilder.getInsertionPoint()); |
| #endif |
| TypeSubstCloner::visit(inst); |
| LLVM_DEBUG({ |
| auto &s = llvm::dbgs() << "[TAN] Emitted in differential:\n"; |
| auto afterInsertion = diffBuilder.getInsertionPoint(); |
| for (auto it = ++beforeInsertion; it != afterInsertion; ++it) |
| s << *it; |
| }); |
| } else { |
| TypeSubstCloner::visit(inst); |
| } |
| } |
| |
| /// Emit a non-differentiability error for `inst` (with the |
| /// "expression not differentiable" note) and set `errorOccurred`, which |
| /// makes subsequent `visit` calls return early. |
| /// NOTE(review): presumably reached for instructions that have no more |
| /// specific visitor in this class; confirm against the visitor base class. |
| void visitSILInstruction(SILInstruction *inst) { |
| context.emitNondifferentiabilityError( |
| inst, invoker, diag::autodiff_expression_not_differentiable_note); |
| errorOccurred = true; |
| } |
| |
| /// Visit the instructions of the original block `bb`. |
| /// |
| /// Before cloning, the differential builder is positioned at the |
| /// differential block mapped to `bb`, the block's last argument (the |
| /// differential struct) is destructured, and the resulting elements are |
| /// handed to `initializeDifferentialStructElements`. |
| void visitInstructionsInBlock(SILBasicBlock *bb) { |
|   auto &tangentBuilder = getDifferentialBuilder(); |
|   auto differentialLoc = getDifferential().getLocation(); |
|   auto *differentialBB = diffBBMap.lookup(bb); |
|   auto *structArg = differentialBB->getArguments().back(); |
|   tangentBuilder.setInsertionPoint(differentialBB); |
|   auto *destructure = |
|       tangentBuilder.createDestructureStruct(differentialLoc, structArg); |
|   initializeDifferentialStructElements(bb, destructure->getResults()); |
|   TypeSubstCloner::visitInstructionsInBlock(bb); |
| } |
| |
| /// If an `apply` has active results or active inout parameters, replace it |
| /// with an `apply` of its JVP. |
| /// |
| /// Steps, in emission order: |
| /// 1. If the call should not be differentiated, zero-initialize the tangent |
| ///    buffers of its active indirect results and clone it unchanged. The |
| ///    array literal initialization intrinsic is also cloned unchanged. |
| /// 2. Compute the minimal active parameter/result indices for the call. |
| /// 3. Obtain the JVP value: extracted from the callee if the callee already |
| ///    has a `@differentiable` function type, otherwise through a new |
| ///    `differentiable_function` instruction that is added to the worklist. |
| /// 4. Apply the JVP, map the original results, and record the returned |
| ///    differential (reabstracted if its type differs from the lowered type |
| ///    of the linear map declaration) in `differentialValues`. |
| /// 5. Emit the tangent code with `emitTangentForApplyInst`. |
| /// |
| /// On a non-differentiability error, `errorOccurred` is set and the method |
| /// returns early. |
| void visitApplyInst(ApplyInst *ai) { |
|   bool shouldDifferentiate = |
|       differentialInfo.shouldDifferentiateApplySite(ai); |
|   // If the function has no active arguments or results, zero-initialize the |
|   // tangent buffers of the active indirect results. |
|   if (!shouldDifferentiate) { |
|     for (auto indResult : ai->getIndirectSILResults()) |
|       if (activityInfo.isActive(indResult, getConfig())) { |
|         auto &tanBuf = getTangentBuffer(ai->getParent(), indResult); |
|         emitZeroIndirect(tanBuf->getType().getASTType(), tanBuf, |
|                          tanBuf.getLoc()); |
|       } |
|   } |
|   // If the function should not be differentiated or it's the array literal |
|   // initialization intrinsic, just do standard cloning. |
|   if (!shouldDifferentiate || |
|       ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) { |
|     LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n'); |
|     TypeSubstCloner::visitApplyInst(ai); |
|     return; |
|   } |
| |
|   auto loc = ai->getLoc(); |
|   auto &builder = getBuilder(); |
|   auto origCallee = getOpValue(ai->getCallee()); |
|   auto originalFnTy = origCallee->getType().castTo<SILFunctionType>(); |
| |
|   LLVM_DEBUG(getADDebugStream() << "JVP-transforming:\n" << *ai << '\n'); |
| |
|   // Get the minimal parameter and result indices required for differentiating |
|   // this `apply`. |
|   SmallVector<SILValue, 4> allResults; |
|   SmallVector<unsigned, 8> activeParamIndices; |
|   SmallVector<unsigned, 8> activeResultIndices; |
|   collectMinimalIndicesForFunctionCall(ai, getConfig(), activityInfo, |
|                                        allResults, activeParamIndices, |
|                                        activeResultIndices); |
|   assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); |
|   assert(!activeResultIndices.empty() && "Result indices cannot be empty"); |
|   LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={"; |
|              llvm::interleave( |
|                  activeParamIndices.begin(), activeParamIndices.end(), |
|                  [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); |
|              s << "}, results={"; llvm::interleave( |
|                  activeResultIndices.begin(), activeResultIndices.end(), |
|                  [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); |
|              s << "}\n";); |
| |
|   // Form expected indices. Indirect mutating (inout) parameters count as |
|   // results, after the formal results. |
|   auto numResults = |
|       ai->getSubstCalleeType()->getNumResults() + |
|       ai->getSubstCalleeType()->getNumIndirectMutatingParameters(); |
|   AutoDiffConfig config( |
|       IndexSubset::get(getASTContext(), |
|                        ai->getArgumentsWithoutIndirectResults().size(), |
|                        activeParamIndices), |
|       IndexSubset::get(getASTContext(), numResults, activeResultIndices)); |
| |
|   // Emit the JVP. |
|   SILValue jvpValue; |
|   // If functionSource is a `@differentiable` function, just extract it. |
|   if (originalFnTy->isDifferentiable()) { |
|     auto paramIndices = originalFnTy->getDifferentiabilityParameterIndices(); |
|     for (auto i : config.parameterIndices->getIndices()) { |
|       if (!paramIndices->contains(i)) { |
|         context.emitNondifferentiabilityError( |
|             origCallee, invoker, |
|             diag:: |
|                 autodiff_function_noderivative_parameter_not_differentiable); |
|         errorOccurred = true; |
|         return; |
|       } |
|     } |
|     builder.emitScopedBorrowOperation( |
|         loc, origCallee, [&](SILValue borrowedDiffFunc) { |
|           jvpValue = builder.createDifferentiableFunctionExtract( |
|               loc, NormalDifferentiableFunctionTypeComponent::JVP, |
|               borrowedDiffFunc); |
|           jvpValue = builder.emitCopyValueOperation(loc, jvpValue); |
|         }); |
|   } |
| |
|   // If JVP has not yet been found, emit a `differentiable_function` |
|   // instruction on the remapped function operand and |
|   // a `differentiable_function_extract` instruction to get the JVP. |
|   // The `differentiable_function` instruction will be canonicalized during |
|   // the transform main loop. |
|   if (!jvpValue) { |
|     // FIXME: Handle indirect differentiation invokers. This may require some |
|     // redesign: currently, each original function + witness pair is mapped |
|     // only to one invoker. |
|     /* |
|      DifferentiationInvoker indirect(ai, attr); |
|      auto insertion = |
|          context.getInvokers().try_emplace({original, attr}, indirect); |
|      auto &invoker = insertion.first->getSecond(); |
|      invoker = indirect; |
|      */ |
| |
|     // If the original `apply` instruction has a substitution map, then the |
|     // applied function is specialized. |
|     // In the JVP, specialization is also necessary for parity. The original |
|     // function operand is specialized with a remapped version of same |
|     // substitution map using an argument-less `partial_apply`. |
|     if (ai->getSubstitutionMap().empty()) { |
|       origCallee = builder.emitCopyValueOperation(loc, origCallee); |
|     } else { |
|       auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap()); |
|       auto jvpPartialApply = getBuilder().createPartialApply( |
|           ai->getLoc(), origCallee, substMap, {}, |
|           ParameterConvention::Direct_Guaranteed); |
|       origCallee = jvpPartialApply; |
|     } |
| |
|     // Check and diagnose non-differentiable original function type. Returns |
|     // true (and sets `errorOccurred`) if an error was emitted. |
|     auto diagnoseNondifferentiableOriginalFunctionType = |
|         [&](CanSILFunctionType origFnTy) { |
|           // Check and diagnose non-differentiable arguments. |
|           for (auto paramIndex : config.parameterIndices->getIndices()) { |
|             if (!origFnTy->getParameters()[paramIndex] |
|                      .getSILStorageInterfaceType() |
|                      .isDifferentiable(getModule())) { |
|               auto arg = ai->getArgumentsWithoutIndirectResults()[paramIndex]; |
|               auto startLoc = arg.getLoc().getStartSourceLoc(); |
|               auto endLoc = arg.getLoc().getEndSourceLoc(); |
|               context |
|                   .emitNondifferentiabilityError( |
|                       arg, invoker, diag::autodiff_nondifferentiable_argument) |
|                   .fixItInsert(startLoc, "withoutDerivative(at: ") |
|                   .fixItInsertAfter(endLoc, ")"); |
|               errorOccurred = true; |
|               return true; |
|             } |
|           } |
|           // Check and diagnose non-differentiable results. |
|           for (auto resultIndex : config.resultIndices->getIndices()) { |
|             SILType remappedResultType; |
|             // Indices past the formal results refer to inout arguments. |
|             if (resultIndex >= origFnTy->getNumResults()) { |
|               auto inoutArgIdx = resultIndex - origFnTy->getNumResults(); |
|               auto inoutArg = |
|                   *std::next(ai->getInoutArguments().begin(), inoutArgIdx); |
|               remappedResultType = inoutArg->getType(); |
|             } else { |
|               remappedResultType = origFnTy->getResults()[resultIndex] |
|                                        .getSILStorageInterfaceType(); |
|             } |
|             if (!remappedResultType.isDifferentiable(getModule())) { |
|               auto startLoc = ai->getLoc().getStartSourceLoc(); |
|               auto endLoc = ai->getLoc().getEndSourceLoc(); |
|               context |
|                   .emitNondifferentiabilityError( |
|                       origCallee, invoker, |
|                       diag::autodiff_nondifferentiable_result) |
|                   .fixItInsert(startLoc, "withoutDerivative(at: ") |
|                   .fixItInsertAfter(endLoc, ")"); |
|               errorOccurred = true; |
|               return true; |
|             } |
|           } |
|           return false; |
|         }; |
|     if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) |
|       return; |
| |
|     auto *diffFuncInst = context.createDifferentiableFunction( |
|         builder, loc, config.parameterIndices, config.resultIndices, |
|         origCallee); |
| |
|     // Record the `differentiable_function` instruction. |
|     context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst); |
| |
|     builder.emitScopedBorrowOperation( |
|         loc, diffFuncInst, [&](SILValue borrowedADFunc) { |
|           auto extractedJVP = builder.createDifferentiableFunctionExtract( |
|               loc, NormalDifferentiableFunctionTypeComponent::JVP, |
|               borrowedADFunc); |
|           jvpValue = builder.emitCopyValueOperation(loc, extractedJVP); |
|         }); |
|     builder.emitDestroyValueOperation(loc, diffFuncInst); |
|   } |
| |
|   // Call the JVP using the original parameters. |
|   SmallVector<SILValue, 8> jvpArgs; |
|   auto jvpFnTy = getOpType(jvpValue->getType()).castTo<SILFunctionType>(); |
|   auto numJVPArgs = |
|       jvpFnTy->getNumParameters() + jvpFnTy->getNumIndirectFormalResults(); |
|   jvpArgs.reserve(numJVPArgs); |
|   // Collect substituted arguments. |
|   for (auto origArg : ai->getArguments()) |
|     jvpArgs.push_back(getOpValue(origArg)); |
|   assert(jvpArgs.size() == numJVPArgs); |
|   // Apply the JVP. |
|   // The JVP should be specialized, so no substitution map is necessary. |
|   auto *jvpCall = getBuilder().createApply(loc, jvpValue, SubstitutionMap(), |
|                                            jvpArgs, ai->isNonThrowing()); |
|   LLVM_DEBUG(getADDebugStream() << "Applied jvp function\n" << *jvpCall); |
| |
|   // Release the differentiable function. |
|   builder.emitDestroyValueOperation(loc, jvpValue); |
| |
|   // Get the JVP results (original results and differential). The |
|   // differential is the last direct result. |
|   SmallVector<SILValue, 8> jvpDirectResults; |
|   extractAllElements(jvpCall, builder, jvpDirectResults); |
|   auto originalDirectResults = |
|       ArrayRef<SILValue>(jvpDirectResults).drop_back(1); |
|   auto originalDirectResult = |
|       joinElements(originalDirectResults, getBuilder(), jvpCall->getLoc()); |
| |
|   mapValue(ai, originalDirectResult); |
| |
|   // Some instructions that produce the callee may have been cloned. |
|   // If the original callee did not have any users beyond this `apply`, |
|   // recursively kill the cloned callee. |
|   // `dyn_cast_or_null` (not `cast_or_null`): the callee may have no defining |
|   // instruction, or be a result of a multiple-value instruction; both cases |
|   // are skipped. |
|   if (auto *origCalleeInst = dyn_cast_or_null<SingleValueInstruction>( |
|           ai->getCallee()->getDefiningInstruction())) |
|     if (origCalleeInst->hasOneUse()) |
|       recursivelyDeleteTriviallyDeadInstructions( |
|           getOpValue(origCalleeInst)->getDefiningInstruction()); |
| |
|   // Add the differential function for when we create the struct we partially |
|   // apply to the differential we are generating. |
|   auto differential = jvpDirectResults.back(); |
|   auto *differentialDecl = differentialInfo.lookUpLinearMapDecl(ai); |
|   auto originalDifferentialType = |
|       getOpType(differential->getType()).getAs<SILFunctionType>(); |
|   auto loweredDifferentialType = |
|       getOpType(getLoweredType(differentialDecl->getInterfaceType())) |
|           .castTo<SILFunctionType>(); |
|   // If actual differential type does not match lowered differential type, |
|   // reabstract the differential using a thunk. |
|   if (!loweredDifferentialType->isEqual(originalDifferentialType)) { |
|     SILOptFunctionBuilder fb(context.getTransform()); |
|     differential = reabstractFunction( |
|         builder, fb, loc, differential, loweredDifferentialType, |
|         [this](SubstitutionMap subs) -> SubstitutionMap { |
|           return this->getOpSubstitutionMap(subs); |
|         }); |
|   } |
|   differentialValues[ai->getParent()].push_back(differential); |
| |
|   // Differential emission. |
|   emitTangentForApplyInst(ai, config, originalDifferentialType); |
| } |
| |
| /// Emit the JVP's `return`: a tuple of the cloned original results followed |
| /// by the differential. |
| /// |
| /// The differential is a `partial_apply` of the differential function to the |
| /// differential struct built for the returning block. If the type of that |
| /// `partial_apply` differs from the differential type in the JVP's |
| /// signature, it is converted with `convert_function`; a conversion that is |
| /// not ABI-compatible is a fatal error. |
| void visitReturnInst(ReturnInst *ri) { |
|   auto loc = ri->getOperand().getLoc(); |
|   auto &builder = getBuilder(); |
|   auto *diffStructVal = buildDifferentialValueStructValue(ri); |
| |
|   // Collect the JVP values corresponding to the original return value. The |
|   // differential is appended to this list further down. |
|   SmallVector<SILValue, 8> directResults; |
|   auto clonedResult = getOpValue(ri->getOperand()); |
|   extractAllElements(clonedResult, builder, directResults); |
| |
|   // Pick the substitution map used to reference the differential. |
|   auto *jvpGenericEnv = jvp->getGenericEnvironment(); |
|   auto jvpSubstMap = jvpGenericEnv |
|                          ? jvpGenericEnv->getForwardingSubstitutionMap() |
|                          : jvp->getForwardingSubstitutionMap(); |
| |
|   // Reference the differential and partially apply it to the struct. |
|   auto *differentialRef = builder.createFunctionRef(loc, &getDifferential()); |
|   auto *differentialPartialApply = builder.createPartialApply( |
|       loc, differentialRef, jvpSubstMap, {diffStructVal}, |
|       ParameterConvention::Direct_Guaranteed); |
| |
|   // The differential type expected by the JVP signature is its last result, |
|   // with the substitution map applied. |
|   auto differentialType = jvp->getLoweredFunctionType() |
|                               ->getResults() |
|                               .back() |
|                               .getSILStorageInterfaceType(); |
|   differentialType = differentialType.substGenericArgs( |
|       getModule(), jvpSubstMap, TypeExpansionContext::minimal()); |
|   differentialType = differentialType.subst(getModule(), jvpSubstMap); |
|   auto expectedFnType = differentialType.castTo<SILFunctionType>(); |
|   auto actualFnType = |
|       differentialPartialApply->getType().castTo<SILFunctionType>(); |
| |
|   // If necessary, convert the differential value to the returned differential |
|   // function type. |
|   SILValue differentialValue = differentialPartialApply; |
|   if (!(actualFnType == expectedFnType)) { |
|     auto compatibility = actualFnType->isABICompatibleWith(expectedFnType, *jvp); |
|     if (!compatibility.isCompatible()) |
|       llvm::report_fatal_error("Differential value type is not ABI-compatible " |
|                                "with the returned differential type"); |
|     differentialValue = builder.createConvertFunction( |
|         loc, differentialPartialApply, differentialType, |
|         /*withoutActuallyEscaping*/ false); |
|   } |
| |
|   // Return a tuple of the original result and differential. |
|   directResults.push_back(differentialValue); |
|   auto returnValue = joinElements(directResults, builder, loc); |
|   builder.createReturn(ri->getLoc(), returnValue); |
| } |
| |
| /// `br` is not supported by this cloner: visiting one is treated as |
| /// unreachable. |
| void visitBranchInst(BranchInst *bi) { |
| llvm_unreachable("Unsupported SIL instruction."); |
| } |
| |
| /// `cond_br` is not supported by this cloner: visiting one is treated as |
| /// unreachable. |
| void visitCondBranchInst(CondBranchInst *cbi) { |
| llvm_unreachable("Unsupported SIL instruction."); |
| } |
| |
| /// `switch_enum` is not supported by this cloner: visiting one is treated as |
| /// unreachable. |
| void visitSwitchEnumInst(SwitchEnumInst *sei) { |
| llvm_unreachable("Unsupported SIL instruction."); |
| } |
| |
| /// Clone a `differentiable_function` instruction into the JVP and queue the |
| /// clone on the context's `differentiable_function` worklist. |
| void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) { |
|   TypeSubstCloner::visitDifferentiableFunctionInst(dfi); |
|   auto *clonedDFI = cast<DifferentiableFunctionInst>(getOpValue(dfi)); |
|   auto &worklist = context.getDifferentiableFunctionInstWorklist(); |
|   worklist.push_back(clonedDFI); |
| } |
| |
| /// Clone a `linear_function` instruction into the JVP and queue the clone on |
| /// the context's `linear_function` worklist. |
| void visitLinearFunctionInst(LinearFunctionInst *lfi) { |
|   TypeSubstCloner::visitLinearFunctionInst(lfi); |
|   auto *clonedLFI = cast<LinearFunctionInst>(getOpValue(lfi)); |
|   auto &worklist = context.getLinearFunctionInstWorklist(); |
|   worklist.push_back(clonedLFI); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Tangent emission helpers |
| //--------------------------------------------------------------------------// |
| |
| /// Defines `visit<INST>Inst`, which first clones the instruction through
| /// `TypeSubstCloner` and then, if `differentialInfo` reports that the
| /// instruction should be differentiated, calls `emitTangentFor<INST>Inst`.
| /// The macro ends with the signature of `emitTangentFor<INST>Inst`, so the
| /// braces that follow a use of the macro form that function's body; `ID` is
| /// the name of its instruction parameter.
| #define CLONE_AND_EMIT_TANGENT(INST, ID) \
| void visit##INST##Inst(INST##Inst *inst) { \
| TypeSubstCloner::visit##INST##Inst(inst); \
| if (differentialInfo.shouldDifferentiateInstruction(inst)) \
| emitTangentFor##INST##Inst(inst); \
| } \
| void emitTangentFor##INST##Inst(INST##Inst *(ID))
| |
| /// Handle `begin_borrow` instruction.
| /// Original: y = begin_borrow x
| /// Tangent: tan[y] = begin_borrow tan[x]
| CLONE_AND_EMIT_TANGENT(BeginBorrow, bbi) {
| auto &tanBuilder = getDifferentialBuilder();
| auto borrowLoc = bbi->getLoc();
| // Materialize the operand's tangent, then borrow it in the differential.
| auto operandTangent = getTangentValue(bbi->getOperand());
| auto ownedTangent = materializeTangent(operandTangent, borrowLoc);
| auto borrowedTangent =
| tanBuilder.emitBeginBorrowOperation(borrowLoc, ownedTangent);
| setTangentValue(bbi->getParent(), bbi,
| makeConcreteTangentValue(borrowedTangent));
| }
| |
| /// Handle `end_borrow` instruction.
| /// Original: end_borrow x
| /// Tangent: end_borrow tan[x]
| CLONE_AND_EMIT_TANGENT(EndBorrow, ebi) {
| auto &tanBuilder = getDifferentialBuilder();
| auto borrowEndLoc = ebi->getLoc();
| auto operandTangent = getTangentValue(ebi->getOperand());
| auto borrowedTangent = materializeTangent(operandTangent, borrowEndLoc);
| tanBuilder.emitEndBorrowOperation(borrowEndLoc, borrowedTangent);
| }
| |
| /// Handle `destroy_value` instruction.
| /// Original: destroy_value x
| /// Tangent: destroy_value tan[x]
| CLONE_AND_EMIT_TANGENT(DestroyValue, dvi) {
| auto &tanBuilder = getDifferentialBuilder();
| auto destroyLoc = dvi->getLoc();
| auto operandTangent = getTangentValue(dvi->getOperand());
| auto tangentToDestroy = materializeTangent(operandTangent, destroyLoc);
| tanBuilder.emitDestroyValueOperation(destroyLoc, tangentToDestroy);
| }
| |
| /// Handle `copy_value` instruction.
| /// Original: y = copy_value x
| /// Tangent: tan[y] = copy_value tan[x]
| CLONE_AND_EMIT_TANGENT(CopyValue, cvi) {
| auto &tanBuilder = getDifferentialBuilder();
| auto copyLoc = cvi->getLoc();
| auto operandTangent = getTangentValue(cvi->getOperand());
| auto sourceTangent = materializeTangent(operandTangent, copyLoc);
| auto copiedTangent = tanBuilder.emitCopyValueOperation(copyLoc, sourceTangent);
| setTangentValue(cvi->getParent(), cvi,
| makeConcreteTangentValue(copiedTangent));
| }
| |
| /// Handle `load` instruction.
| /// Original: y = load x
| /// Tangent: tan[y] = load tan[x]
| /// When the load itself is not differentiated but takes from an active
| /// buffer, that buffer's tangent buffer is destroyed instead.
| void visitLoadInst(LoadInst *li) {
| TypeSubstCloner::visitLoadInst(li);
| auto *origBB = li->getParent();
| if (differentialInfo.shouldDifferentiateInstruction(li)) {
| // Standard differential cloning: load from the operand's tangent buffer
| // with the same ownership qualifier as the original load.
| auto &tanBuilder = getDifferentialBuilder();
| auto loadLoc = li->getLoc();
| auto sourceTangentBuf = getTangentBuffer(origBB, li->getOperand());
| auto loadedTangent = tanBuilder.emitLoadValueOperation(
| loadLoc, sourceTangentBuf, li->getOwnershipQualifier());
| setTangentValue(origBB, li, makeConcreteTangentValue(loadedTangent));
| return;
| }
| // Non-differentiated load: only a `[take]` from an active buffer needs any
| // work in the differential.
| if (li->getOwnershipQualifier() != LoadOwnershipQualifier::Take)
| return;
| if (!activityInfo.isActive(li->getOperand(), getConfig()))
| return;
| auto &takenTangentBuf = getTangentBuffer(origBB, li->getOperand());
| getDifferentialBuilder().emitDestroyOperation(takenTangentBuf.getLoc(),
| takenTangentBuf);
| }
| |
| /// Handle `load_borrow` instruction.
| /// Original: y = load_borrow x
| /// Tangent: tan[y] = load_borrow tan[x]
| CLONE_AND_EMIT_TANGENT(LoadBorrow, lbi) {
| auto &tanBuilder = getDifferentialBuilder();
| auto *origBB = lbi->getParent();
| auto sourceTangentBuf = getTangentBuffer(origBB, lbi->getOperand());
| auto borrowedTangent =
| tanBuilder.emitLoadBorrowOperation(lbi->getLoc(), sourceTangentBuf);
| setTangentValue(origBB, lbi, makeConcreteTangentValue(borrowedTangent));
| }
| |
| /// Handle `store` instruction in the differential.
| /// Original: store x to y
| /// Tangent: store tan[x] to tan[y]
| /// When the store itself is not differentiated but its destination is an
| /// active buffer, the destination's tangent buffer is zero-initialized.
| void visitStoreInst(StoreInst *si) {
| TypeSubstCloner::visitStoreInst(si);
| auto *origBB = si->getParent();
| if (differentialInfo.shouldDifferentiateInstruction(si)) {
| // Standard differential cloning: store the source tangent into the
| // destination tangent buffer with the original ownership qualifier.
| auto &tanBuilder = getDifferentialBuilder();
| auto storeLoc = si->getLoc();
| auto sourceTangent =
| materializeTangent(getTangentValue(si->getSrc()), storeLoc);
| auto &destTangentBuf = getTangentBuffer(origBB, si->getDest());
| tanBuilder.emitStoreValueOperation(storeLoc, sourceTangent, destTangentBuf,
| si->getOwnershipQualifier());
| return;
| }
| // Non-differentiated store: nothing to do unless the destination is active.
| if (!activityInfo.isActive(si->getDest(), getConfig()))
| return;
| auto &destTangentBuf = getTangentBuffer(origBB, si->getDest());
| emitZeroIndirect(destTangentBuf->getType().getASTType(), destTangentBuf,
| destTangentBuf.getLoc());
| }
| |
| /// Handle `store_borrow` instruction in the differential.
| /// Original: store_borrow x to y
| /// Tangent: store_borrow tan[x] to tan[y]
| /// When the instruction itself is not differentiated but its destination is
| /// an active buffer, the destination's tangent buffer is zero-initialized.
| void visitStoreBorrowInst(StoreBorrowInst *sbi) {
| TypeSubstCloner::visitStoreBorrowInst(sbi);
| auto *origBB = sbi->getParent();
| if (differentialInfo.shouldDifferentiateInstruction(sbi)) {
| // Standard differential cloning.
| auto &tanBuilder = getDifferentialBuilder();
| auto storeLoc = sbi->getLoc();
| auto sourceTangent =
| materializeTangent(getTangentValue(sbi->getSrc()), storeLoc);
| auto &destTangentBuf = getTangentBuffer(origBB, sbi->getDest());
| tanBuilder.createStoreBorrow(storeLoc, sourceTangent, destTangentBuf);
| return;
| }
| // Non-differentiated store: nothing to do unless the destination is active.
| if (!activityInfo.isActive(sbi->getDest(), getConfig()))
| return;
| auto &destTangentBuf = getTangentBuffer(origBB, sbi->getDest());
| emitZeroIndirect(destTangentBuf->getType().getASTType(), destTangentBuf,
| destTangentBuf.getLoc());
| }
| |
| /// Handle `copy_addr` instruction.
| /// Original: copy_addr x to y
| /// Tangent: copy_addr tan[x] to tan[y]
| void visitCopyAddrInst(CopyAddrInst *cai) {
| TypeSubstCloner::visitCopyAddrInst(cai);
| // If a non-active buffer is copied into an active buffer, zero-initialize
| // the destination buffer's tangent buffer.
| // If an active buffer is copied with take into a non-active buffer, destroy
| // the source buffer's tangent buffer.
| if (!differentialInfo.shouldDifferentiateInstruction(cai)) {
| if (activityInfo.isActive(cai->getDest(), getConfig())) {
| auto &tanBufDest = getTangentBuffer(cai->getParent(), cai->getDest());
| emitZeroIndirect(tanBufDest->getType().getASTType(), tanBufDest,
| tanBufDest.getLoc());
| }
| if (cai->isTakeOfSrc() &&
| activityInfo.isActive(cai->getSrc(), getConfig())) {
| auto &tanBufSrc = getTangentBuffer(cai->getParent(), cai->getSrc());
| getDifferentialBuilder().emitDestroyOperation(tanBufSrc.getLoc(),
| tanBufSrc);
| }
| return;
| }
| // Otherwise, do standard differential cloning.
| // Bind the differential builder by reference, as the other visitors do,
| // rather than working on a by-value copy of it.
| auto &diffBuilder = getDifferentialBuilder();
| auto loc = cai->getLoc();
| auto *bb = cai->getParent();
| auto &tanSrc = getTangentBuffer(bb, cai->getSrc());
| auto tanDest = getTangentBuffer(bb, cai->getDest());
| // Mirror the original's take/initialization flags on the tangent copy.
| diffBuilder.createCopyAddr(loc, tanSrc, tanDest, cai->isTakeOfSrc(),
| cai->isInitializationOfDest());
| }
| |
| /// Handle `unconditional_checked_cast_addr` instruction.
| /// Original: unconditional_checked_cast_addr $X in x to $Y in y
| /// Tangent: unconditional_checked_cast_addr $X.Tan in tan[x]
| /// to $Y.Tan in tan[y]
| CLONE_AND_EMIT_TANGENT(UnconditionalCheckedCastAddr, uccai) {
| // Bind the differential builder by reference, as the other visitors do,
| // rather than working on a by-value copy of it.
| auto &diffBuilder = getDifferentialBuilder();
| auto loc = uccai->getLoc();
| auto *bb = uccai->getParent();
| auto &tanSrc = getTangentBuffer(bb, uccai->getSrc());
| auto tanDest = getTangentBuffer(bb, uccai->getDest());
| 
| // The cast's formal source/target types are the AST types of the tangent
| // buffers themselves.
| diffBuilder.createUnconditionalCheckedCastAddr(
| loc, tanSrc, tanSrc->getType().getASTType(), tanDest,
| tanDest->getType().getASTType());
| }
| |
| /// Handle `begin_access` instruction (and do differentiability checks).
| /// Original: y = begin_access x
| /// Tangent: tan[y] = begin_access tan[x]
| ///
| /// A `[modify]` access whose source is a `global_addr` or a `project_box` is
| /// diagnosed as non-differentiable; in that case `errorOccurred` is set and
| /// no tangent code is emitted.
| CLONE_AND_EMIT_TANGENT(BeginAccess, bai) {
| // Check for non-differentiable writes.
| if (bai->getAccessKind() == SILAccessKind::Modify) {
| // Only the kind of the source instruction matters here, so test it with
| // `isa` instead of binding an unused `dyn_cast` result.
| if (isa<GlobalAddrInst>(bai->getSource())) {
| context.emitNondifferentiabilityError(
| bai, invoker,
| diag::autodiff_cannot_differentiate_writes_to_global_variables);
| errorOccurred = true;
| return;
| }
| if (isa<ProjectBoxInst>(bai->getSource())) {
| context.emitNondifferentiabilityError(
| bai, invoker,
| diag::autodiff_cannot_differentiate_writes_to_mutable_captures);
| errorOccurred = true;
| return;
| }
| }
| 
| auto &diffBuilder = getDifferentialBuilder();
| auto *bb = bai->getParent();
| 
| // Open an access on the source's tangent buffer with the same kind,
| // enforcement and flags as the original access.
| auto tanSrc = getTangentBuffer(bb, bai->getSource());
| auto *tanDest = diffBuilder.createBeginAccess(
| bai->getLoc(), tanSrc, bai->getAccessKind(), bai->getEnforcement(),
| bai->hasNoNestedConflict(), bai->isFromBuiltin());
| setTangentBuffer(bb, bai, tanDest);
| }
| |
| /// Handle `end_access` instruction.
| /// Original: end_access x
| /// Tangent: end_access tan[x]
| CLONE_AND_EMIT_TANGENT(EndAccess, eai) {
| auto &tanBuilder = getDifferentialBuilder();
| auto tangentAccess = getTangentBuffer(eai->getParent(), eai->getOperand());
| // Preserve the original's `[abort]` flag on the tangent access.
| tanBuilder.createEndAccess(eai->getLoc(), tangentAccess, eai->isAborting());
| }
| |
| /// Handle `alloc_stack` instruction.
| /// Original: y = alloc_stack $T
| /// Tangent: tan[y] = alloc_stack $T.Tangent
| CLONE_AND_EMIT_TANGENT(AllocStack, asi) {
| auto &tanBuilder = getDifferentialBuilder();
| // Allocate a buffer of the remapped tangent type, forwarding the original
| // allocation's variable info.
| auto *tangentAlloc = tanBuilder.createAllocStack(
| asi->getLoc(), getRemappedTangentType(asi->getElementType()),
| asi->getVarInfo());
| setTangentBuffer(asi->getParent(), asi, tangentAlloc);
| }
| |
| /// Handle `dealloc_stack` instruction.
| /// Original: dealloc_stack x
| /// Tangent: dealloc_stack tan[x]
| CLONE_AND_EMIT_TANGENT(DeallocStack, dsi) {
| auto &tanBuilder = getDifferentialBuilder();
| auto tangentAlloc = getTangentBuffer(dsi->getParent(), dsi->getOperand());
| tanBuilder.createDeallocStack(dsi->getLoc(), tangentAlloc);
| }
| |
| /// Handle `destroy_addr` instruction.
| /// Original: destroy_addr x
| /// Tangent: destroy_addr tan[x]
| CLONE_AND_EMIT_TANGENT(DestroyAddr, dai) {
| auto &tanBuilder = getDifferentialBuilder();
| auto tangentAddr = getTangentBuffer(dai->getParent(), dai->getOperand());
| tanBuilder.createDestroyAddr(dai->getLoc(), tangentAddr);
| }
| |
| /// Handle `struct` instruction.
| /// Original: y = struct $T (x0, x1, x2, ...)
| /// Tangent: tan[y] = struct $T.Tangent (tan[x0], tan[x1], tan[x2], ...)
| CLONE_AND_EMIT_TANGENT(Struct, si) {
| auto &tanBuilder = getDifferentialBuilder();
| // Gather the concrete tangent of every element, in element order.
| SmallVector<SILValue, 4> elementTangents;
| for (auto origElement : si->getElements()) {
| auto elementTangent = getTangentValue(origElement);
| elementTangents.push_back(elementTangent.getConcreteValue());
| }
| // Aggregate them into a value of the remapped tangent struct type.
| auto tangentStructType = getRemappedTangentType(si->getType());
| auto tangentStruct =
| tanBuilder.createStruct(si->getLoc(), tangentStructType, elementTangents);
| setTangentValue(si->getParent(), si,
| makeConcreteTangentValue(tangentStruct));
| }
| |
| /// Handle `struct_extract` instruction.
| /// Original: y = struct_extract x, #field
| /// Tangent: tan[y] = struct_extract tan[x], #field'
| /// ^~~~~~~
| /// field in tangent space corresponding to #field
| ///
| /// If no tangent stored property can be found for the field,
| /// `errorOccurred` is set and no tangent code is emitted.
| CLONE_AND_EMIT_TANGENT(StructExtract, sei) {
| assert(!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
| "`struct_extract` with `@noDerivative` field should not be "
| "differentiated; activity analysis should not marked as varied.");
| // Bind the differential builder by reference, as the other visitors do,
| // rather than working on a by-value copy of it.
| auto &diffBuilder = getDifferentialBuilder();
| auto loc = getValidLocation(sei);
| // Find the corresponding field in the tangent space.
| auto structType =
| remapSILTypeInDifferential(sei->getOperand()->getType()).getASTType();
| auto *tanField =
| getTangentStoredProperty(context, sei, structType, invoker);
| if (!tanField) {
| errorOccurred = true;
| return;
| }
| // Emit tangent `struct_extract`.
| auto tanStruct =
| materializeTangent(getTangentValue(sei->getOperand()), loc);
| auto tangentInst =
| diffBuilder.createStructExtract(loc, tanStruct, tanField);
| // Update tangent value mapping for `struct_extract` result.
| auto tangentResult = makeConcreteTangentValue(tangentInst);
| setTangentValue(sei->getParent(), sei, tangentResult);
| }
| |
| /// Handle `struct_element_addr` instruction.
| /// Original: y = struct_element_addr x, #field
| /// Tangent: tan[y] = struct_element_addr tan[x], #field'
| /// ^~~~~~~
| /// field in tangent space corresponding to #field
| ///
| /// If no tangent stored property can be found for the field,
| /// `errorOccurred` is set and no tangent code is emitted.
| CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) {
| assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
| "`struct_element_addr` with `@noDerivative` field should not be "
| "differentiated; activity analysis should not marked as varied.");
| // Bind the differential builder by reference, as the other visitors do,
| // rather than working on a by-value copy of it.
| auto &diffBuilder = getDifferentialBuilder();
| auto *bb = seai->getParent();
| auto loc = getValidLocation(seai);
| // Find the corresponding field in the tangent space.
| auto structType =
| remapSILTypeInDifferential(seai->getOperand()->getType()).getASTType();
| auto *tanField =
| getTangentStoredProperty(context, seai, structType, invoker);
| if (!tanField) {
| errorOccurred = true;
| return;
| }
| // Emit tangent `struct_element_addr`.
| auto tanOperand = getTangentBuffer(bb, seai->getOperand());
| auto tangentInst =
| diffBuilder.createStructElementAddr(loc, tanOperand, tanField);
| // Update tangent buffer map for `struct_element_addr`.
| setTangentBuffer(bb, seai, tangentInst);
| }
| |
| /// Handle `tuple` instruction.
| /// Original: y = tuple (x0, x1, x2, ...)
| /// Tangent: tan[y] = tuple (tan[x0], tan[x1], tan[x2], ...)
| /// ^~~
| /// excluding non-differentiable elements
| CLONE_AND_EMIT_TANGENT(Tuple, ti) {
| // Bind the differential builder by reference, as the other visitors do,
| // rather than working on a by-value copy of it.
| auto &diffBuilder = getDifferentialBuilder();
| // Get the tangents of all the tuple elements.
| SmallVector<SILValue, 8> tangentTupleElements;
| for (auto elem : ti->getElements()) {
| // Elements whose type has no tangent space do not appear in the tangent.
| if (!getTangentSpace(elem->getType().getASTType()))
| continue;
| tangentTupleElements.push_back(
| materializeTangent(getTangentValue(elem), ti->getLoc()));
| }
| // Emit the instruction and add the tangent mapping.
| auto tanTuple =
| joinElements(tangentTupleElements, diffBuilder, ti->getLoc());
| setTangentValue(ti->getParent(), ti, makeConcreteTangentValue(tanTuple));
| }
| |
| /// Handle `tuple_extract` instruction.
| /// Original: y = tuple_extract x, <n>
| /// Tangent: tan[y] = tuple_extract tan[x], <n'>
| /// ^~~~
| /// tuple tangent space index corresponding to n
| CLONE_AND_EMIT_TANGENT(TupleExtract, tei) {
| auto &tanBuilder = getDifferentialBuilder();
| auto extractLoc = tei->getLoc();
| auto *origBB = tei->getParent();
| auto origTupleTy = tei->getOperand()->getType().castTo<TupleType>();
| // The tangent index is the number of preceding elements that have a
| // tangent space.
| unsigned tangentFieldIndex = 0;
| for (unsigned i = 0, e = tei->getFieldIndex(); i != e; ++i) {
| auto precedingEltType =
| origTupleTy->getElement(i).getType()->getCanonicalType();
| if (getTangentSpace(precedingEltType))
| ++tangentFieldIndex;
| }
| auto tangentEltType = getRemappedTangentType(tei->getType());
| auto tangentTuple =
| materializeTangent(getTangentValue(tei->getOperand()), extractLoc);
| if (tangentTuple->getType().is<TupleType>()) {
| auto tangentElt = tanBuilder.createTupleExtract(
| extractLoc, tangentTuple, tangentFieldIndex, tangentEltType);
| setTangentValue(origBB, tei, makeConcreteTangentValue(tangentElt));
| return;
| }
| // A non-tuple tangent must represent a "single element tuple type": it is
| // itself the tangent of the extracted element.
| setTangentValue(origBB, tei, makeConcreteTangentValue(tangentTuple));
| }
| |
| /// Handle `tuple_element_addr` instruction.
| /// Original: y = tuple_element_addr x, <n>
| /// Tangent: tan[y] = tuple_element_addr tan[x], <n'>
| /// ^~~~
| /// tuple tangent space index corresponding to n
| CLONE_AND_EMIT_TANGENT(TupleElementAddr, teai) {
| auto &tanBuilder = getDifferentialBuilder();
| auto *origBB = teai->getParent();
| auto origTupleTy = teai->getOperand()->getType().castTo<TupleType>();
| // The tangent index is the number of preceding elements that have a
| // tangent space.
| unsigned tangentFieldIndex = 0;
| for (unsigned i = 0, e = teai->getFieldIndex(); i != e; ++i) {
| auto precedingEltType =
| origTupleTy->getElement(i).getType()->getCanonicalType();
| if (getTangentSpace(precedingEltType))
| ++tangentFieldIndex;
| }
| auto tangentEltType = getRemappedTangentType(teai->getType());
| auto tangentTupleBuf = getTangentBuffer(origBB, teai->getOperand());
| SILValue tangentEltBuf;
| if (tangentTupleBuf->getType().is<TupleType>()) {
| tangentEltBuf = tanBuilder.createTupleElementAddr(
| teai->getLoc(), tangentTupleBuf, tangentFieldIndex, tangentEltType);
| } else {
| // A non-tuple tangent buffer must represent a "single element tuple
| // type": it is itself the tangent buffer of the projected element.
| tangentEltBuf = tangentTupleBuf;
| }
| setTangentBuffer(origBB, teai, tangentEltBuf);
| }
| |
| /// Handle `destructure_tuple` instruction.
| /// Original: (y0, y1, ...) = destructure_tuple x
| /// Tangent: (tan[y0], tan[y1], ...) = destructure_tuple tan[x]
| /// Results whose type has no tangent space are skipped, so the i-th tangent
| /// element is assigned to the i-th original result that has a tangent space.
| CLONE_AND_EMIT_TANGENT(DestructureTuple, dti) {
| assert(llvm::any_of(dti->getResults(),
| [&](SILValue elt) {
| return activityInfo.isActive(elt, getConfig());
| }) &&
| "`destructure_tuple` should have at least one active result");
| 
| auto &tanBuilder = getDifferentialBuilder();
| auto *origBB = dti->getParent();
| auto destructureLoc = dti->getLoc();
| 
| // Split the operand's tangent into elements. A non-tuple tangent is used
| // as the single tangent element.
| auto tangentTuple =
| materializeTangent(getTangentValue(dti->getOperand()), destructureLoc);
| SmallVector<SILValue, 4> tangentElements;
| if (!tangentTuple->getType().is<TupleType>()) {
| tangentElements.push_back(tangentTuple);
| } else {
| auto *tangentDestructure =
| tanBuilder.createDestructureTuple(destructureLoc, tangentTuple);
| auto tangentResults = tangentDestructure->getResults();
| tangentElements.append(tangentResults.begin(), tangentResults.end());
| }
| // Pair each original result that has a tangent space with the next
| // tangent element.
| unsigned nextTangentIndex = 0;
| for (SILValue origResult : dti->getResults()) {
| if (!getTangentSpace(origResult->getType().getASTType()))
| continue;
| auto tangentElement = tangentElements[nextTangentIndex];
| ++nextTangentIndex;
| setTangentValue(origBB, origResult,
| makeConcreteTangentValue(tangentElement));
| }
| }
| |
| #undef CLONE_AND_EMIT_TANGENT |
| |
| /// Handle `apply` instruction, given:
| /// - The minimal indices for differentiating the `apply`.
| /// - The original non-reabstracted differential type.
| ///
| /// Original: y = apply f(x0, x1, ...)
| /// Tangent: tan[y] = apply diff_f(tan[x0], tan[x1], ...)
| ///
| /// The differential callee is read from the linear map struct field recorded
| /// for `ai`. Returns early, without emitting the call, if `errorOccurred` is
| /// set while the argument tangents are being collected.
| void emitTangentForApplyInst(ApplyInst *ai, AutoDiffConfig applyConfig,
| CanSILFunctionType originalDifferentialType) {
| assert(differentialInfo.shouldDifferentiateApplySite(ai));
| auto *bb = ai->getParent();
| auto loc = ai->getLoc();
| auto &diffBuilder = getDifferentialBuilder();
| 
| // Get the differential value.
| auto *field = differentialInfo.lookUpLinearMapDecl(ai);
| assert(field);
| SILValue differential = getDifferentialStructElement(bb, field);
| auto differentialType = remapSILTypeInDifferential(differential->getType())
| .castTo<SILFunctionType>();
| 
| // Get the differential arguments.
| SmallVector<SILValue, 8> diffArgs;
| // Stack buffers allocated below to hold zero tangents of non-active
| // indirect arguments. They are deallocated after the differential call.
| SmallVector<SILValue, 2> zeroTangentBuffers;
| 
| // Indirect results come first: pass their tangent buffers.
| for (auto indRes : ai->getIndirectSILResults())
| diffArgs.push_back(getTangentBuffer(bb, indRes));
| 
| auto origArgs = ai->getArgumentsWithoutIndirectResults();
| // Get the tangent value of the original arguments.
| for (auto i : indices(origArgs)) {
| auto origArg = origArgs[i];
| // If the argument is not active:
| // - Skip the element, if it is not differentiable.
| // - Otherwise, add a zero value to that location.
| if (!activityInfo.isActive(origArg, getConfig())) {
| auto origCalleeType = ai->getSubstCalleeType();
| if (!origCalleeType->isDifferentiable())
| continue;
| auto actualOrigCalleeIndices =
| origCalleeType->getDifferentiabilityParameterIndices();
| if (actualOrigCalleeIndices->contains(i)) {
| auto tanParamType = getRemappedTangentType(origArg->getType());
| SILValue tanParam;
| if (origArg->getType().isObject()) {
| tanParam = emitZeroDirect(tanParamType.getASTType(), loc);
| } else {
| // Indirect argument: materialize the zero tangent in a fresh stack
| // buffer and remember the buffer so it can be deallocated.
| tanParam = diffBuilder.createAllocStack(loc, tanParamType);
| emitZeroIndirect(tanParamType.getASTType(), tanParam, loc);
| zeroTangentBuffers.push_back(tanParam);
| }
| // The zero tangent must be passed in both the direct and the indirect
| // case; otherwise the differential call is missing an argument.
| diffArgs.push_back(tanParam);
| }
| }
| // Otherwise, if the argument is active, handle the argument normally by
| // getting its tangent value.
| else {
| SILValue tanParam;
| if (origArg->getType().isObject()) {
| tanParam = materializeTangent(getTangentValue(origArg), loc);
| } else {
| tanParam = getTangentBuffer(ai->getParent(), origArg);
| }
| diffArgs.push_back(tanParam);
| if (errorOccurred)
| return;
| }
| }
| 
| // If callee differential was reabstracted in JVP, reabstract the callee
| // differential.
| if (!differentialType->isEqual(originalDifferentialType)) {
| SILOptFunctionBuilder fb(context.getTransform());
| differential = reabstractFunction(
| diffBuilder, fb, loc, differential, originalDifferentialType,
| [this](SubstitutionMap subs) -> SubstitutionMap {
| return this->getOpSubstitutionMap(subs);
| });
| }
| 
| // Call the differential.
| auto *differentialCall =
| diffBuilder.createApply(loc, differential, SubstitutionMap(), diffArgs,
| /*isNonThrowing*/ false);
| diffBuilder.emitDestroyValueOperation(loc, differential);
| 
| // Deallocate the temporary zero tangent buffers in reverse allocation
| // order, as stack discipline requires.
| // NOTE(review): if the zero tangent is non-trivial and the differential
| // does not consume the argument, a `destroy_addr` is presumably needed
| // before each `dealloc_stack` -- confirm against the parameter convention.
| for (auto zeroBuf : llvm::reverse(zeroTangentBuffers))
| diffBuilder.createDeallocStack(loc, zeroBuf);
| 
| // Get the original `apply` results.
| SmallVector<SILValue, 8> origDirectResults;
| forEachApplyDirectResult(ai, [&](SILValue directResult) {
| origDirectResults.push_back(directResult);
| });
| SmallVector<SILValue, 8> origAllResults;
| collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults);
| 
| // Get the callee differential `apply` results.
| SmallVector<SILValue, 8> differentialDirectResults;
| extractAllElements(differentialCall, getDifferentialBuilder(),
| differentialDirectResults);
| SmallVector<SILValue, 8> differentialAllResults;
| collectAllActualResultsInTypeOrder(
| differentialCall, differentialDirectResults, differentialAllResults);
| // `inout` arguments are appended after the formal results on both sides.
| for (auto inoutArg : ai->getInoutArguments())
| origAllResults.push_back(inoutArg);
| for (auto inoutArg : differentialCall->getInoutArguments())
| differentialAllResults.push_back(inoutArg);
| assert(applyConfig.resultIndices->getNumIndices() ==
| differentialAllResults.size());
| 
| // Set tangent values for original `apply` results.
| unsigned differentialResultIndex = 0;
| for (auto resultIndex : applyConfig.resultIndices->getIndices()) {
| auto origResult = origAllResults[resultIndex];
| auto differentialResult =
| differentialAllResults[differentialResultIndex++];
| // Only object-typed results get a tangent value mapping here.
| if (origResult->getType().isObject()) {
| if (!origResult->getType().is<TupleType>()) {
| setTangentValue(bb, origResult,
| makeConcreteTangentValue(differentialResult));
| } else if (auto *dti = getSingleDestructureTupleUser(ai)) {
| // Tuple-typed result: assign the tangent to the single active element
| // of the `destructure_tuple` user.
| bool notSetValue = true;
| for (auto result : dti->getResults()) {
| if (activityInfo.isActive(result, getConfig())) {
| assert(notSetValue &&
| "This was incorrectly set, should only have one active "
| "result from the tuple.");
| notSetValue = false;
| setTangentValue(bb, result,
| makeConcreteTangentValue(differentialResult));
| }
| }
| }
| }
| }
| }
| |
| /// Generate a `return` instruction in the current differential basic block.
| /// The returned value joins the materialized tangents of those original
| /// direct results whose index is in the config's result indices.
| void emitReturnInstForDifferential() {
| auto returnLoc = getDifferential().getLocation();
| auto &tanBuilder = getDifferentialBuilder();
| 
| // Collect the original function's direct results.
| SmallVector<SILValue, 2> origDirectResults;
| collectAllDirectResultsInTypeOrder(*original, origDirectResults);
| 
| // Materialize the tangent of every differentiated direct result.
| SmallVector<SILValue, 8> tangentResults;
| for (unsigned idx = 0, e = origDirectResults.size(); idx != e; ++idx) {
| if (getConfig().resultIndices->contains(idx)) {
| auto resultTangent = getTangentValue(origDirectResults[idx]);
| tangentResults.push_back(materializeTangent(resultTangent, returnLoc));
| }
| }
| 
| auto joinedTangent = joinElements(tangentResults, tanBuilder, returnLoc);
| tanBuilder.createReturn(returnLoc, joinedTangent);
| }
| }; |
| |
| //--------------------------------------------------------------------------// |
| // Initialization |
| //--------------------------------------------------------------------------// |
| |
| /// Initialization helper function.
| ///
| /// Returns the substitution map used for type remapping: the original
| /// function's forwarding substitution map when the JVP has no generic
| /// environment, otherwise a map over the JVP's generic signature built from
| /// the JVP's forwarding substitutions.
| static SubstitutionMap getSubstitutionMap(SILFunction *original,
| SILFunction *jvp) {
| auto origForwardingSubs = original->getForwardingSubstitutionMap();
| auto *jvpGenEnv = jvp->getGenericEnvironment();
| if (!jvpGenEnv)
| return origForwardingSubs;
| auto jvpForwardingSubs = jvpGenEnv->getForwardingSubstitutionMap();
| return SubstitutionMap::get(
| jvpGenEnv->getGenericSignature(), QuerySubstitutionMap{jvpForwardingSubs},
| LookUpConformanceInSubstitutionMap(jvpForwardingSubs));
| }
| |
| /// Initialization helper function.
| ///
| /// Returns the activity info for the given original function, autodiff indices,
| /// and JVP generic signature. `config` is only used for the debug dump.
| static const DifferentiableActivityInfo &
| getActivityInfo(ADContext &context, SILFunction *original,
| AutoDiffConfig config, SILFunction *jvp) {
| // Look up the activity collection computed for the original function.
| auto *activityAnalysis =
| context.getPassManager().getAnalysis<DifferentiableActivityAnalysis>();
| auto &origActivities = *activityAnalysis->get(original);
| // Select the JVP activity info for the JVP's generic signature.
| auto jvpGenSig = jvp->getLoweredFunctionType()->getSubstGenericSignature();
| auto &jvpActivityInfo = origActivities.getActivityInfo(
| jvpGenSig, AutoDiffDerivativeFunctionKind::JVP);
| LLVM_DEBUG(jvpActivityInfo.dump(config, getADDebugStream()));
| return jvpActivityInfo;
| }
| |
| /// Sets up the cloner from `original` into `jvp`. The member initializers
| /// compute the activity info, loop info and differential linear map info for
| /// `original`, create the empty differential function, and point the
| /// differential builders at it.
| JVPCloner::Implementation::Implementation(ADContext &context,
| SILFunction *original,
| SILDifferentiabilityWitness *witness,
| SILFunction *jvp,
| DifferentiationInvoker invoker)
| : TypeSubstCloner(*jvp, *original, getSubstitutionMap(original, jvp)),
| context(context), original(original), witness(witness), jvp(jvp),
| invoker(invoker),
| activityInfo(
| getActivityInfo(context, original, witness->getConfig(), jvp)),
| loopInfo(context.getPassManager().getAnalysis<SILLoopAnalysis>()
| ->get(original)),
| differentialInfo(context, AutoDiffLinearMapKind::Differential, original,
| jvp, witness->getConfig(), activityInfo, loopInfo),
| differentialBuilder(SILBuilder(
| *createEmptyDifferential(context, witness, &differentialInfo))),
| diffLocalAllocBuilder(getDifferential()) {
| // The empty differential function was created by the `differentialBuilder`
| // initializer above; record it with the context as a generated function.
| context.recordGeneratedFunction(&getDifferential());
| }
| |
| /// Heap-allocates the `Implementation` and binds the `impl` reference to it;
| /// the destructor releases it.
| JVPCloner::JVPCloner(ADContext &context, SILFunction *original,
| SILDifferentiabilityWitness *witness, SILFunction *jvp,
| DifferentiationInvoker invoker)
| : impl(*new Implementation(context, original, witness, jvp, invoker)) {}
| |
| /// Deletes the `Implementation` that the constructor allocated with `new`.
| JVPCloner::~JVPCloner() { delete &impl; }
| |
| //--------------------------------------------------------------------------// |
| // Differential struct mapping |
| //--------------------------------------------------------------------------// |
| |
| /// Records, for each stored property of the differential struct of `origBB`,
| /// the corresponding value from `values` (paired in order). Asserts that the
| /// counts match, that no value has guaranteed ownership, and that no field
| /// is mapped twice.
| void JVPCloner::Implementation::initializeDifferentialStructElements(
| SILBasicBlock *origBB, SILInstructionResultArray values) {
| auto storedFields =
| differentialInfo.getLinearMapStruct(origBB)->getStoredProperties();
| assert(storedFields.size() == values.size() &&
| "The number of differential struct fields must equal the number of "
| "differential struct element values");
| for (auto fieldAndValue : llvm::zip(storedFields, values)) {
| VarDecl *field = std::get<0>(fieldAndValue);
| SILValue value = std::get<1>(fieldAndValue);
| assert(value.getOwnershipKind() != OwnershipKind::Guaranteed &&
| "Differential struct elements must be @owned");
| bool isNewMapping = differentialStructElements.insert({field, value}).second;
| (void)isNewMapping;
| assert(isNewMapping &&
| "A differential struct element mapping already exists!");
| }
| }
| |
| /// Returns the value recorded for `field` in the differential struct element
| /// map. Asserts that `field` belongs to the differential struct of `origBB`
| /// and that a mapping for it exists.
| SILValue
| JVPCloner::Implementation::getDifferentialStructElement(SILBasicBlock *origBB,
| VarDecl *field) {
| // The field must be declared in the linear map struct of this block.
| assert(differentialInfo.getLinearMapStruct(origBB) ==
| cast<StructDecl>(field->getDeclContext()));
| assert(differentialStructElements.count(field) &&
| "Differential struct element for this field does not exist!");
| SILValue element = differentialStructElements.lookup(field);
| return element;
| }
| |
| //--------------------------------------------------------------------------// |
| // Tangent emission helpers |
| //--------------------------------------------------------------------------// |
| |
| void JVPCloner::Implementation::prepareForDifferentialGeneration() { |
| // Create differential blocks and arguments. |
| auto &differential = getDifferential(); |
| auto diffLoc = differential.getLocation(); |
| auto *origEntry = original->getEntryBlock(); |
| auto origFnTy = original->getLoweredFunctionType(); |
| |
| for (auto &origBB : *original) { |
| auto *diffBB = differential.createBasicBlock(); |
| diffBBMap.insert({&origBB, diffBB}); |
| // If the BB is the original entry, then the differential block that we |
| // just created must be the differential function's entry. Create |
| // differential entry arguments and continue. |
| if (&origBB == origEntry) { |
| assert(diffBB->isEntry()); |
| createEntryArguments(&differential); |
| auto *lastArg = diffBB->getArguments().back(); |
| #ifndef NDEBUG |
| auto diffStructLoweredType = remapSILTypeInDifferential( |
| differentialInfo.getLinearMapStructLoweredType(&origBB)); |
| assert(lastArg->getType() == diffStructLoweredType); |
| #endif |
| differentialStructArguments[&origBB] = lastArg; |
| } |
| |
| LLVM_DEBUG({ |
| auto &s = getADDebugStream() |
| << "Original bb" + std::to_string(origBB.getDebugID()) |
| << ": To differentiate or not to differentiate?\n"; |
| for (auto &inst : origBB) { |
| s << (differentialInfo.shouldDifferentiateInstruction(&inst) ? "[x] " |
| : "[ ] ") |
| << inst; |
| } |
| }); |
| } |
| |
| assert(diffBBMap.size() == 1 && |
| "Can only currently handle single basic block functions"); |
| |
| // The differential function has type: |
| // (arg0', ..., argn', entry_df_struct) -> result'. |
| auto diffParamArgs = |
| differential.getArgumentsWithoutIndirectResults().drop_back(); |
| assert(diffParamArgs.size() == |
| witness->getConfig().parameterIndices->getNumIndices()); |
| auto origParamArgs = original->getArgumentsWithoutIndirectResults(); |
| |
| // TODO(TF-788): Re-enable non-varied result warning. |
| /* |
| // Check if result is not varied. |
| SmallVector<SILValue, 8> origFormalResults; |
| collectAllFormalResultsInTypeOrder(*original, origFormalResults); |
| std::get<0>(pair); |
| for (auto resultIndex : getConfig().results->getIndices()) { |
| auto origResult = origFormalResults[resultIndex]; |
| // Emit warning if original result is not varied, because it will always |
| // have a zero derivative. |
| if (!activityInfo.isVaried(origResult, getConfig().parameters)) { |
| // Emit fixit if original result has a valid source location. |
| auto startLoc = origResult.getLoc().getStartSourceLoc(); |
| auto endLoc = origResult.getLoc().getEndSourceLoc(); |
| if (startLoc.isValid() && endLoc.isValid()) { |
| context.diagnose(startLoc, diag::autodiff_nonvaried_result_fixit) |
| .fixItInsert(startLoc, "withoutDerivative(at:") |
| .fixItInsertAfter(endLoc, ")"); |
| } |
| } |
| } |
| */ |
| |
| // Initialize tangent mapping for parameters. |
| auto diffParamsIt = getConfig().parameterIndices->begin(); |
| for (auto index : range(diffParamArgs.size())) { |
| auto *diffArg = diffParamArgs[index]; |
| auto *origArg = origParamArgs[*diffParamsIt]; |
| ++diffParamsIt; |
| if (diffArg->getType().isAddress()) { |
| setTangentBuffer(origEntry, origArg, diffArg); |
| } else { |
| setTangentValue(origEntry, origArg, makeConcreteTangentValue(diffArg)); |
| } |
| LLVM_DEBUG(getADDebugStream() |
| << "Assigned parameter " << *diffArg |
| << " as the tangent of original result " << *origArg); |
| } |
| |
| // Initialize tangent mapping for original indirect results and non-wrt |
| // `inout` parameters. The tangent buffers of these address values are |
| // differential indirect results. |
| |
| // Collect original results. |
| SmallVector<SILValue, 2> originalResults; |
| collectAllFormalResultsInTypeOrder(*original, originalResults); |
| |
| // Iterate over differentiability results. |
| differentialBuilder.setInsertionPoint(differential.getEntryBlock()); |
| auto diffIndResults = differential.getIndirectResults(); |
| unsigned differentialIndirectResultIndex = 0; |
| for (auto resultIndex : getConfig().resultIndices->getIndices()) { |
| auto origResult = originalResults[resultIndex]; |
| // Handle original formal indirect result. |
| if (resultIndex < origFnTy->getNumResults()) { |
| // Skip original direct results. |
| if (origResult->getType().isObject()) |
| continue; |
| auto diffIndResult = diffIndResults[differentialIndirectResultIndex++]; |
| setTangentBuffer(origEntry, origResult, diffIndResult); |
| // If original indirect result is non-varied, zero-initialize its tangent |
| // buffer. |
| if (!activityInfo.isVaried(origResult, getConfig().parameterIndices)) |
| emitZeroIndirect(diffIndResult->getType().getASTType(), diffIndResult, |
| diffLoc); |
| continue; |
| } |
| // Handle original non-wrt `inout` parameter. |
| // Only original *non-wrt* `inout` parameters have corresponding |
| // differential indirect results. |
| auto inoutParamIndex = resultIndex - origFnTy->getNumResults(); |
| auto inoutParamIt = std::next( |
| origFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); |
| auto paramIndex = |
| std::distance(origFnTy->getParameters().begin(), &*inoutParamIt); |
| if (getConfig().parameterIndices->contains(paramIndex)) |
| continue; |
| auto diffIndResult = diffIndResults[differentialIndirectResultIndex++]; |
| setTangentBuffer(origEntry, origResult, diffIndResult); |
| // Original `inout` parameters are initialized, so their tangent buffers |
| // must also be initialized. |
| emitZeroIndirect(diffIndResult->getType().getASTType(), diffIndResult, |
| diffLoc); |
| } |
| } |
| |
| /// Creates the differential function associated with `witness` and returns
| /// it.
| ///
| /// This computes the differential's lowered function type (tangent results,
| /// tangent parameters, and a trailing linear map struct parameter), mangles
| /// its name, and creates the function with a debug scope. No basic blocks are
| /// added here.
| /*static*/ SILFunction *JVPCloner::Implementation::createEmptyDifferential(
|     ADContext &context, SILDifferentiabilityWitness *witness,
|     LinearMapInfo *linearMapInfo) {
|   auto &silModule = context.getModule();
|   auto *origFn = witness->getOriginalFunction();
|   auto *jvpFn = witness->getJVP();
|   auto origFnTy = origFn->getLoweredFunctionType();
|   // Get witness generic signature for remapping types.
|   // Witness generic signature may have more requirements than JVP generic
|   // signature: when witness generic signature has same-type requirements
|   // binding all generic parameters to concrete types, JVP function type uses
|   // all the concrete types and JVP generic signature is null.
|   auto witnessGenSig = witness->getDerivativeGenericSignature();
|   CanGenericSignature witnessCanSig;
|   if (witnessGenSig)
|     witnessCanSig = witnessGenSig->getCanonicalSignature();
|   auto conformanceLookup =
|       LookUpConformanceInModule(silModule.getSwiftModule());
|
|   // Rewrites the interface type of a result/parameter info so that it is
|   // canonical with respect to the witness generic signature.
|   auto remapIntoWitnessSig = [&](auto info) {
|     return info.getWithInterfaceType(
|         info.getInterfaceType()->getCanonicalType(witnessCanSig));
|   };
|   // Yields the tangent space type of `ty`, canonicalized with respect to the
|   // witness generic signature.
|   auto tangentTypeInWitnessSig = [&](CanType ty) -> CanType {
|     return ty->getAutoDiffTangentSpace(conformanceLookup)
|         ->getType()
|         ->getCanonicalType(witnessCanSig);
|   };
|
|   // The differential takes:
|   // - one tangent value per wrt parameter, followed by
|   // - the differential struct for the original entry.
|   // Its results live in the tangent spaces of the original results.
|   SmallVector<SILParameterInfo, 8> diffParamInfos;
|   SmallVector<SILResultInfo, 8> diffResultInfos;
|   auto origParamInfos = origFnTy->getParameters();
|   auto config = witness->getConfig();
|
|   for (auto resultIndex : config.resultIndices->getIndices()) {
|     // Formal original result: its tangent becomes a differential result with
|     // the same convention.
|     if (resultIndex < origFnTy->getNumResults()) {
|       auto remappedResult =
|           remapIntoWitnessSig(origFnTy->getResults()[resultIndex]);
|       diffResultInfos.push_back(SILResultInfo(
|           tangentTypeInWitnessSig(remappedResult.getInterfaceType()),
|           remappedResult.getConvention()));
|       continue;
|     }
|     // Otherwise the index designates an original `inout` parameter; find its
|     // position within the full parameter list.
|     auto inoutIndex = resultIndex - origFnTy->getNumResults();
|     auto inoutIt = std::next(
|         origFnTy->getIndirectMutatingParameters().begin(), inoutIndex);
|     auto origParamIndex =
|         std::distance(origFnTy->getParameters().begin(), &*inoutIt);
|     // A wrt `inout` parameter already has a corresponding differential
|     // parameter, so no differential result is added for it.
|     if (config.parameterIndices->contains(origParamIndex))
|       continue;
|     auto inoutParamInfo = origFnTy->getParameters()[origParamIndex];
|     auto inoutTangent =
|         inoutParamInfo.getInterfaceType()->getAutoDiffTangentSpace(
|             conformanceLookup);
|     assert(inoutTangent && "Parameter type does not have a tangent space?");
|     diffResultInfos.push_back(SILResultInfo(inoutTangent->getCanonicalType(),
|                                             ResultConvention::Indirect));
|   }
|
|   // One differential parameter per requested wrt parameter, keeping the
|   // original parameter convention.
|   for (auto wrtIndex : config.parameterIndices->getIndices()) {
|     auto remappedParam = remapIntoWitnessSig(origParamInfos[wrtIndex]);
|     diffParamInfos.push_back(SILParameterInfo(
|         tangentTypeInWitnessSig(remappedParam.getInterfaceType()),
|         remappedParam.getConvention()));
|   }
|
|   // The last parameter is the differential struct of the original entry
|   // block. This is the returned differential's closure context.
|   auto *origEntryBB = origFn->getEntryBlock();
|   auto *linearMapStruct = linearMapInfo->getLinearMapStruct(origEntryBB);
|   auto linearMapStructTy =
|       linearMapStruct->getDeclaredInterfaceType()->getCanonicalType(
|           witnessCanSig);
|   diffParamInfos.push_back(
|       SILParameterInfo(linearMapStructTy, ParameterConvention::Direct_Owned));
|
|   // Mangle the differential's name and intern it in the AST context.
|   Mangle::ASTMangler mangler;
|   auto mangledDiffName = mangler.mangleAutoDiffLinearMapHelper(
|       origFn->getName(), AutoDiffLinearMapKind::Differential,
|       witness->getConfig());
|   auto diffName = origFn->getASTContext().getIdentifier(mangledDiffName).str();
|
|   // Set differential generic signature equal to JVP generic signature.
|   // Do not use witness generic signature, which may have same-type
|   // requirements binding all generic parameters to concrete types.
|   auto diffGenericSig =
|       jvpFn->getLoweredFunctionType()->getSubstGenericSignature();
|   GenericEnvironment *diffGenericEnv = nullptr;
|   if (diffGenericSig)
|     diffGenericEnv = diffGenericSig->getGenericEnvironment();
|   auto diffFnTy = SILFunctionType::get(
|       diffGenericSig, SILExtInfo::getThin(), origFnTy->getCoroutineKind(),
|       origFnTy->getCalleeConvention(), diffParamInfos, {}, diffResultInfos,
|       None, origFnTy->getPatternSubstitutions(),
|       origFnTy->getInvocationSubstitutions(), origFn->getASTContext());
|
|   // Linkage follows the JVP: public when the JVP is serialized, hidden
|   // otherwise.
|   SILOptFunctionBuilder builder(context.getTransform());
|   auto diffLinkage = SILLinkage::Hidden;
|   if (jvpFn->isSerialized())
|     diffLinkage = SILLinkage::Public;
|   auto *diffFn = builder.createFunction(
|       diffLinkage, diffName, diffFnTy, diffGenericEnv, origFn->getLocation(),
|       origFn->isBare(), IsNotTransparent, jvpFn->isSerialized(),
|       origFn->isDynamicallyReplaceable());
|   diffFn->setDebugScope(
|       new (silModule) SILDebugScope(origFn->getLocation(), diffFn));
|   return diffFn;
| }
| |
| /// Generates the JVP body (and the differential alongside it) by cloning the
| /// original function into a fresh JVP entry block.
| ///
| /// Returns the value of `errorOccurred` once cloning has finished, i.e. `true`
| /// when an error was recorded.
| bool JVPCloner::Implementation::run() {
|   PrettyStackTraceSILFunction trace("generating JVP and differential for",
|                                     original);
|   LLVM_DEBUG(getADDebugStream() << "Cloning original @" << original->getName()
|                                 << " to jvp @" << jvp->getName() << '\n');
|   // Set up the JVP entry block and its arguments, then prepare the
|   // differential before any cloning happens.
|   auto *jvpEntry = jvp->createBasicBlock();
|   createEntryArguments(jvp);
|   prepareForDifferentialGeneration();
|   // Gather the entry block arguments (created above) as the values the
|   // cloned body starts from.
|   SmallVector<SILValue, 4> jvpEntryArgs;
|   for (auto *entryArg : jvpEntry->getArguments())
|     jvpEntryArgs.push_back(entryArg);
|   cloneFunctionBody(original, jvpEntry, jvpEntryArgs);
|   emitReturnInstForDifferential();
|   // Only dump the generated functions when no error was recorded.
|   if (!errorOccurred) {
|     LLVM_DEBUG(getADDebugStream()
|                << "Generated JVP for " << original->getName() << ":\n"
|                << *jvp);
|     LLVM_DEBUG(getADDebugStream()
|                << "Generated differential for " << original->getName() << ":\n"
|                << getDifferential());
|   }
|   return errorOccurred;
| }
| |
| } // end namespace autodiff |
| } // end namespace swift |
| |
| /// Runs JVP generation through the implementation object.
| ///
| /// Returns `true` if the implementation reported an error. When `NDEBUG` is
| /// not defined, a JVP generated without error is verified before returning.
| bool JVPCloner::run() {
|   // On failure, report the error without verifying the JVP.
|   if (impl.run())
|     return true;
| #ifndef NDEBUG
|   getJVP().verify();
| #endif
|   return false;
| }
| |
| /// Returns the JVP function, as reported by the implementation object.
| SILFunction &JVPCloner::getJVP() const {
|   return impl.getJVP();
| }