| //===--- Differentiation.cpp - SIL Automatic Differentiation --*- C++ -*---===// |
| // |
| // This source file is part of the Swift.org open source project |
| // |
| // Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors |
| // Licensed under Apache License v2.0 with Runtime Library Exception |
| // |
| // See https://swift.org/LICENSE.txt for license information |
| // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // SWIFT_ENABLE_TENSORFLOW |
| // |
| // This file implements automatic differentiation. |
| // |
| // NOTE: Although the AD feature is developed as part of the Swift for |
| // TensorFlow project, it is completely independent from TensorFlow support. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #define DEBUG_TYPE "differentiation" |
| |
| #include "Differentiation.h" |
| #include "swift/AST/ASTMangler.h" |
| #include "swift/AST/ASTPrinter.h" |
| #include "swift/AST/AnyFunctionRef.h" |
| #include "swift/AST/AutoDiff.h" |
| #include "swift/AST/Builtins.h" |
| #include "swift/AST/DeclContext.h" |
| #include "swift/AST/DiagnosticsSIL.h" |
| #include "swift/AST/Expr.h" |
| #include "swift/AST/GenericEnvironment.h" |
| #include "swift/AST/GenericSignatureBuilder.h" |
| #include "swift/AST/SourceFile.h" |
| #include "swift/AST/ParameterList.h" |
| #include "swift/AST/SubstitutionMap.h" |
| #include "swift/AST/TypeCheckRequests.h" |
| #include "swift/SIL/FormalLinkage.h" |
| #include "swift/SIL/LoopInfo.h" |
| #include "swift/SIL/Projection.h" |
| #include "swift/SIL/SILBuilder.h" |
| #include "swift/SIL/TypeSubstCloner.h" |
| #include "swift/SILOptimizer/Analysis/DominanceAnalysis.h" |
| #include "swift/SILOptimizer/Analysis/LoopAnalysis.h" |
| #include "swift/SILOptimizer/PassManager/Passes.h" |
| #include "swift/SILOptimizer/PassManager/Transforms.h" |
| #include "swift/SILOptimizer/Utils/LoopUtils.h" |
| #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" |
| #include "llvm/ADT/APSInt.h" |
| #include "llvm/ADT/BreadthFirstIterator.h" |
| #include "llvm/ADT/DenseSet.h" |
| |
| using namespace swift; |
| using llvm::DenseMap; |
| using llvm::SmallDenseMap; |
| using llvm::SmallDenseSet; |
| using llvm::SmallMapVector; |
| using llvm::SmallSet; |
| |
| /// This flag is used to disable `differentiable_function_extract` instruction |
| /// folding for SIL testing purposes. |
| static llvm::cl::opt<bool> SkipFoldingDifferentiableFunctionExtraction( |
| "differentiation-skip-folding-differentiable-function-extraction", |
| llvm::cl::init(true)); |
| |
| //===----------------------------------------------------------------------===// |
| // Helpers |
| //===----------------------------------------------------------------------===// |
| |
| /// Prints an "[AD] " prefix to `llvm::dbgs()` and returns the debug stream. |
| /// This is being used to print short debug messages within the AD pass. |
| static raw_ostream &getADDebugStream() { return llvm::dbgs() << "[AD] "; } |
| |
| /// Given a dumpable value, dumps it to `llvm::dbgs()`. |
| template <typename T> static inline void debugDump(T &v) { |
| LLVM_DEBUG(llvm::dbgs() << "\n==== BEGIN DEBUG DUMP ====\n" |
| << v << "\n==== END DEBUG DUMP ====\n"); |
| } |
| |
| static bool isWithoutDerivative(SILValue v) { |
| if (auto *fnRef = dyn_cast<FunctionRefInst>(v)) |
| return fnRef->getReferencedFunctionOrNull()->hasSemanticsAttr( |
| "autodiff.nonvarying"); |
| return false; |
| } |
| |
| static bool isArrayLiteralIntrinsic(ApplyInst *ai) { |
| return ai->hasSemantics("array.uninitialized_intrinsic"); |
| } |
| |
| static ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v) { |
| if (auto *applyInst = dyn_cast<ApplyInst>(v)) |
| if (isArrayLiteralIntrinsic(applyInst)) |
| return applyInst; |
| return nullptr; |
| } |
| |
| /// Given a value, find its single `destructure_tuple` user if the value is |
| /// tuple-typed and such a user exists. |
| static DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) { |
| bool foundDestructureTupleUser = false; |
| if (!value->getType().is<TupleType>()) |
| return nullptr; |
| DestructureTupleInst *result = nullptr; |
| for (auto *use : value->getUses()) { |
| if (auto *dti = dyn_cast<DestructureTupleInst>(use->getUser())) { |
| assert(!foundDestructureTupleUser && |
| "There should only be one `destructure_tuple` user of a tuple"); |
| foundDestructureTupleUser = true; |
| result = dti; |
| } |
| } |
| return result; |
| } |
| |
| /// Given an `apply` instruction, apply the given callback to each of its |
| /// direct results. If the `apply` instruction has a single `destructure_tuple` |
| /// user, apply the callback to the results of the `destructure_tuple` user. |
| static void forEachApplyDirectResult( |
| ApplyInst *ai, llvm::function_ref<void(SILValue)> resultCallback) { |
| if (!ai->getType().is<TupleType>()) { |
| resultCallback(ai); |
| return; |
| } |
| if (auto *dti = getSingleDestructureTupleUser(ai)) |
| for (auto result : dti->getResults()) |
| resultCallback(result); |
| } |
| |
| /// Given a function, gather all of its formal results (both direct and |
| /// indirect) in an order defined by its result type. Note that "formal results" |
| /// refer to result values in the body of the function, not at call sites. |
| static void |
| collectAllFormalResultsInTypeOrder(SILFunction &function, |
| SmallVectorImpl<SILValue> &results) { |
| SILFunctionConventions convs(function.getLoweredFunctionType(), |
| function.getModule()); |
| auto indResults = function.getIndirectResults(); |
| auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator()); |
| auto retVal = retInst->getOperand(); |
| SmallVector<SILValue, 8> dirResults; |
| if (auto *tupleInst = |
| dyn_cast_or_null<TupleInst>(retVal->getDefiningInstruction())) |
| dirResults.append(tupleInst->getElements().begin(), |
| tupleInst->getElements().end()); |
| else |
| dirResults.push_back(retVal); |
| unsigned indResIdx = 0, dirResIdx = 0; |
| for (auto &resInfo : convs.getResults()) |
| results.push_back(resInfo.isFormalDirect() ? dirResults[dirResIdx++] |
| : indResults[indResIdx++]); |
| } |
| |
| /// Given a function, gather all of its direct results in an order defined by |
| /// its result type. Note that "formal results" refer to result values in the |
| /// body of the function, not at call sites. |
| static void |
| collectAllDirectResultsInTypeOrder(SILFunction &function, |
| SmallVectorImpl<SILValue> &results) { |
| SILFunctionConventions convs(function.getLoweredFunctionType(), |
| function.getModule()); |
| auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator()); |
| auto retVal = retInst->getOperand(); |
| if (auto *tupleInst = dyn_cast<TupleInst>(retVal)) |
| results.append(tupleInst->getElements().begin(), |
| tupleInst->getElements().end()); |
| else |
| results.push_back(retVal); |
| } |
| |
| /// Given a function call site, gather all of its actual results (both direct |
| /// and indirect) in an order defined by its result type. |
| static void collectAllActualResultsInTypeOrder( |
| ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults, |
| SmallVectorImpl<SILValue> &results) { |
| auto calleeConvs = ai->getSubstCalleeConv(); |
| unsigned indResIdx = 0, dirResIdx = 0; |
| for (auto &resInfo : calleeConvs.getResults()) { |
| results.push_back(resInfo.isFormalDirect() |
| ? extractedDirectResults[dirResIdx++] |
| : ai->getIndirectSILResults()[indResIdx++]); |
| } |
| } |
| |
| /// Given a range of types, joins these into a single type. If there's exactly |
| /// one element type, returns that element type. Otherwise, creates a tuple type |
| /// of all element types. |
| template <typename TypeRange> |
| static CanType joinElementTypes(TypeRange &&range, const ASTContext &ctx) { |
| if (range.size() == 1) |
| return range.front(); |
| auto typeElts = |
| map<SmallVector<TupleTypeElt, 8>>(range, [&](Type type) { return type; }); |
| return TupleType::get(typeElts, ctx); |
| } |
| |
| /// Given a range of SIL values, retrieves the canonical types of these values, |
| /// and joins these types into a single type. |
| template <typename SILValueRange> |
| static CanType joinElementTypesFromValues(SILValueRange &&range, |
| const ASTContext &ctx) { |
| if (range.size() == 1) |
| return range.front()->getType().getASTType(); |
| SmallVector<TupleTypeElt, 8> elts; |
| transform(range, elts.begin(), |
| [&](SILValue val) { return val->getType().getASTType(); }); |
| return TupleType::get(elts, ctx)->getCanonicalType(); |
| } |
| |
| /// Given an operator name, such as '+', and a protocol, returns the '+' |
| /// operator. If the operator does not exist in the protocol, returns null. |
| static FuncDecl *findOperatorDeclInProtocol(DeclName operatorName, |
| ProtocolDecl *protocol) { |
| assert(operatorName.isOperator()); |
| // Find the operator requirement in the given protocol declaration. |
| auto opLookup = protocol->lookupDirect(operatorName); |
| for (auto *decl : opLookup) { |
| if (!decl->isProtocolRequirement()) |
| continue; |
| auto *fd = dyn_cast<FuncDecl>(decl); |
| if (!fd || !fd->isStatic() || !fd->isOperator()) |
| continue; |
| return fd; |
| } |
| // Not found. |
| return nullptr; |
| } |
| |
| /// Returns the "constrained" derivative generic signature given: |
| /// - An original SIL function type. |
| /// - A wrt parameter index subset. |
| /// - A possibly uncanonical derivative generic signature (optional). |
| /// - Additional derivative requirements (optional). |
| /// The constrained derivative generic signature constrains all wrt parameters |
| /// to conform to `Differentiable`. |
| static GenericSignature getConstrainedDerivativeGenericSignature( |
| CanSILFunctionType originalFnTy, IndexSubset *paramIndexSet, |
| GenericSignature derivativeGenSig) { |
| if (!derivativeGenSig) |
| derivativeGenSig = originalFnTy->getGenericSignature(); |
| if (!derivativeGenSig) |
| return nullptr; |
| // Constrain all wrt parameters to `Differentiable`. |
| auto &ctx = derivativeGenSig->getASTContext(); |
| auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable); |
| SmallVector<Requirement, 4> requirements; |
| for (unsigned paramIdx : paramIndexSet->getIndices()) { |
| auto paramType = originalFnTy->getParameters()[paramIdx].getType(); |
| Requirement req(RequirementKind::Conformance, paramType, |
| diffableProto->getDeclaredType()); |
| requirements.push_back(req); |
| } |
| return evaluateOrDefault( |
| ctx.evaluator, |
| AbstractGenericSignatureRequest{ |
| derivativeGenSig.getPointer(), |
| /*addedGenericParams*/ {}, |
| std::move(requirements)}, |
| nullptr); |
| } |
| |
| /// Returns the canonical derivative generic signature for the given |
| /// `[differentiable]` attribute and original function. |
| /// - Return the `[differentiable]` attribute derivative generic signature if |
| /// it exists. |
| /// - Otherwise, return the original function's generic signature. |
| static CanGenericSignature getDerivativeGenericSignature( |
| SILDifferentiableAttr *attr, SILFunction *original) { |
| if (auto attrDerivativeGenSig = attr->getDerivativeGenericSignature()) |
| return attrDerivativeGenSig->getCanonicalSignature(); |
| return original->getLoweredFunctionType()->getGenericSignature(); |
| } |
| |
| // Clone the generic parameters of the given generic signature and return a new |
| // `GenericParamList`. |
| static GenericParamList *cloneGenericParameters(ASTContext &ctx, |
| DeclContext *dc, |
| CanGenericSignature sig) { |
| SmallVector<GenericTypeParamDecl *, 2> clonedParams; |
| for (auto paramType : sig->getGenericParams()) { |
| auto clonedParam = new (ctx) GenericTypeParamDecl( |
| dc, paramType->getName(), SourceLoc(), paramType->getDepth(), |
| paramType->getIndex()); |
| clonedParam->setDeclContext(dc); |
| clonedParam->setImplicit(true); |
| clonedParams.push_back(clonedParam); |
| } |
| return GenericParamList::create(ctx, SourceLoc(), clonedParams, SourceLoc()); |
| } |
| |
| /// Given an `differentiable_function` instruction, find the corresponding |
| /// differential operator used in the AST. If no differential operator is found, |
| /// return nullptr. |
| static DifferentiableFunctionExpr * |
| findDifferentialOperator(DifferentiableFunctionInst *inst) { |
| return inst->getLoc().getAsASTNode<DifferentiableFunctionExpr>(); |
| } |
| |
| /// Returns the underlying instruction for the given SILValue, if it exists, |
| /// peering through function conversion instructions. |
| template<class Inst> |
| static Inst *peerThroughFunctionConversions(SILValue value) { |
| if (auto *inst = dyn_cast<Inst>(value)) |
| return inst; |
| if (auto *thinToThick = dyn_cast<ThinToThickFunctionInst>(value)) |
| return peerThroughFunctionConversions<Inst>(thinToThick->getOperand()); |
| if (auto *convertFn = dyn_cast<ConvertFunctionInst>(value)) |
| return peerThroughFunctionConversions<Inst>(convertFn->getOperand()); |
| if (auto *partialApply = dyn_cast<PartialApplyInst>(value)) |
| return peerThroughFunctionConversions<Inst>(partialApply->getCallee()); |
| return nullptr; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Auxiliary data structures |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| class ADContext; |
| |
| /// The invoker of a differentiation task. It can be some user syntax, e.g. |
| /// an `differentiable_function` instruction lowered from an |
| /// `DifferentiableFunctionExpr` expression, the differentiation pass, or |
| /// nothing at all. This will be used to emit informative diagnostics. |
| struct DifferentiationInvoker { |
| public: |
| /// The kind of the invoker of a differentiation task. |
| enum class Kind { |
| // Invoked by an `differentiable_function` instruction, which may or may not |
| // be linked to a Swift AST node (e.g. an `DifferentiableFunctionExpr` |
| // expression). |
| DifferentiableFunctionInst, |
| |
| // Invoked by the indirect application of differentiation. This case has an |
| // associated original `apply` instruction and `[differentiable]` attribute. |
| IndirectDifferentiation, |
| |
| // Invoker by a `[differentiable]` attribute in SIL **without** being linked |
| // to a Swift AST attribute. This case has an associated `[differentiable]` |
| // attribute. |
| SILDifferentiableAttribute |
| }; |
| |
| private: |
| Kind kind; |
| union Value { |
| /// The instruction associated with the `DifferentiableFunctionInst` case. |
| DifferentiableFunctionInst *diffFuncInst; |
| Value(DifferentiableFunctionInst *inst) : diffFuncInst(inst) {} |
| |
| /// The parent `apply` instruction and `[differentiable]` attribute |
| /// associated with the `IndirectDifferentiation` case. |
| std::pair<ApplyInst *, SILDifferentiableAttr *> |
| indirectDifferentiation; |
| Value(ApplyInst *applyInst, SILDifferentiableAttr *attr) |
| : indirectDifferentiation({applyInst, attr}) {} |
| |
| /// The `[differentiable]` attribute associated with the |
| /// `SILDifferentiableAttribute` case. |
| SILDifferentiableAttr *silDifferentiableAttribute; |
| Value(SILDifferentiableAttr *attr) : silDifferentiableAttribute(attr) {} |
| } value; |
| |
| /*implicit*/ |
| DifferentiationInvoker(Kind kind, Value value) : kind(kind), value(value) {} |
| |
| public: |
| DifferentiationInvoker(DifferentiableFunctionInst *inst) |
| : kind(Kind::DifferentiableFunctionInst), value(inst) {} |
| DifferentiationInvoker(ApplyInst *applyInst, SILDifferentiableAttr *attr) |
| : kind(Kind::IndirectDifferentiation), |
| value({applyInst, attr}) {} |
| DifferentiationInvoker(SILDifferentiableAttr *attr) |
| : kind(Kind::SILDifferentiableAttribute), value(attr) {} |
| |
| Kind getKind() const { return kind; } |
| |
| DifferentiableFunctionInst *getDifferentiableFunctionInst() const { |
| assert(kind == Kind::DifferentiableFunctionInst); |
| return value.diffFuncInst; |
| } |
| |
| std::pair<ApplyInst *, SILDifferentiableAttr *> |
| getIndirectDifferentiation() const { |
| assert(kind == Kind::IndirectDifferentiation); |
| return value.indirectDifferentiation; |
| } |
| |
| |
| SILDifferentiableAttr *getSILDifferentiableAttribute() const { |
| assert(kind == Kind::SILDifferentiableAttribute); |
| return value.silDifferentiableAttribute; |
| } |
| |
| SourceLoc getLocation() const { |
| switch (kind) { |
| case Kind::DifferentiableFunctionInst: |
| return getDifferentiableFunctionInst()->getLoc().getSourceLoc(); |
| case Kind::IndirectDifferentiation: |
| return getIndirectDifferentiation().first->getLoc().getSourceLoc(); |
| case Kind::SILDifferentiableAttribute: |
| return getSILDifferentiableAttribute()->getOriginal() |
| ->getLocation().getSourceLoc(); |
| } |
| } |
| |
| void print(llvm::raw_ostream &os) const; |
| }; |
| |
| class DifferentiableActivityInfo; |
| |
| /// Information about the JVP/VJP function produced during JVP/VJP generation, |
| /// e.g. mappings from original values to corresponding values in the |
| /// pullback/differential struct. |
| /// |
| /// A linear map struct is an aggregate value containing linear maps checkpointed |
| /// during the JVP/VJP computation. Linear map structs are generated for every |
| /// original function during JVP/VJP generation. Linear map struct values are |
| /// constructed by JVP/VJP functions and consumed by pullback/differential |
| /// functions. |
| class LinearMapInfo { |
| private: |
| /// The linear map kind. |
| AutoDiffLinearMapKind kind; |
| |
| /// The original function. |
| SILFunction *const original; |
| |
| /// The derivative function. |
| SILFunction *const derivative; |
| |
| /// Activity info of the original function. |
| const DifferentiableActivityInfo &activityInfo; |
| |
| /// Differentiation indices of the function. |
| const SILAutoDiffIndices &indices; |
| |
| /// Mapping from original basic blocks to linear map structs. |
| DenseMap<SILBasicBlock *, StructDecl *> linearMapStructs; |
| |
| /// Mapping from original basic blocks to branching trace enums. |
| /// For pullbacks: these are predecessor enums. |
| /// For differentials: these are successor enums. |
| DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls; |
| |
| /// Mapping from `apply` and `struct_extract` instructions in the original |
| /// function to the corresponding linear map declaration in the linear map |
| /// struct. |
| DenseMap<SILInstruction *, VarDecl *> linearMapValueMap; |
| |
| /// Mapping from predecessor+succcessor basic block pairs in original function |
| /// to the corresponding branching trace enum case. |
| DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *> |
| branchingTraceEnumCases; |
| |
| /// Mapping from linear map structs to their branching trace enum fields. |
| DenseMap<StructDecl *, VarDecl *> linearMapStructEnumFields; |
| |
| /// A type converter, used to compute struct/enum SIL types. |
| Lowering::TypeConverter &typeConverter; |
| |
| private: |
| /// Remaps the given type into the derivative function's context. |
| SILType remapTypeInDerivative(SILType ty) { |
| if (ty.hasArchetype()) |
| return derivative->mapTypeIntoContext(ty.mapTypeOutOfContext()); |
| return derivative->mapTypeIntoContext(ty); |
| } |
| |
| /// Adds a `VarDecl` member with the given name and type to the given nominal |
| /// declaration. |
| VarDecl *addVarDecl(NominalTypeDecl *nominal, StringRef name, Type type) { |
| auto &astCtx = nominal->getASTContext(); |
| auto id = astCtx.getIdentifier(name); |
| auto *varDecl = new (astCtx) VarDecl( |
| /*IsStatic*/ false, VarDecl::Introducer::Var, /*IsCaptureList*/ false, |
| SourceLoc(), id, nominal); |
| varDecl->setAccess(nominal->getEffectiveAccess()); |
| if (type->hasArchetype()) |
| varDecl->setInterfaceType(type->mapTypeOutOfContext()); |
| else |
| varDecl->setInterfaceType(type); |
| nominal->addMember(varDecl); |
| return varDecl; |
| } |
| |
| /// Retrieves the file unit that contains implicit declarations in the |
| /// current Swift module. If it does not exist, create one. |
| /// |
| // FIXME: Currently it defaults to the file containing `original`, if it can |
| // be determined. Otherwise, it defaults to any file unit in the module. To |
| // handle this more properly, we could revive the DerivedFileUnit class to |
| // contain all synthesized implicit type declarations. |
| SourceFile &getDeclarationFileUnit() { |
| if (original->hasLocation()) |
| if (auto *declContext = original->getLocation().getAsDeclContext()) |
| if (auto *parentSourceFile = declContext->getParentSourceFile()) |
| return *parentSourceFile; |
| for (auto *file : original->getModule().getSwiftModule()->getFiles()) |
| if (auto *src = dyn_cast<SourceFile>(file)) |
| return *src; |
| llvm_unreachable("No files?"); |
| } |
| |
| /// Compute and set the access level for the given nominal type, given the |
| /// original function linkage. |
| void computeAccessLevel( |
| NominalTypeDecl *nominal, SILLinkage originalLinkage) { |
| auto &astCtx = nominal->getASTContext(); |
| switch (originalLinkage) { |
| case swift::SILLinkage::Public: |
| case swift::SILLinkage::PublicNonABI: |
| nominal->setAccess(AccessLevel::Internal); |
| nominal->getAttrs().add( |
| new (astCtx) UsableFromInlineAttr(/*Implicit*/ true)); |
| break; |
| case swift::SILLinkage::Hidden: |
| case swift::SILLinkage::Shared: |
| nominal->setAccess(AccessLevel::Internal); |
| break; |
| case swift::SILLinkage::Private: |
| nominal->setAccess(AccessLevel::FilePrivate); |
| break; |
| default: |
| // When the original function has external linkage, we create an internal |
| // struct for use by our own module. This is necessary for cross-cell |
| // differentiation in Jupyter. |
| // TODO: Add a test in the compiler that exercises a similar situation as |
| // cross-cell differentiation in Jupyter. |
| nominal->setAccess(AccessLevel::Internal); |
| } |
| } |
| |
| /// Creates an enum declaration with the given JVP/VJP generic signature, |
| /// whose cases represent the predecessors/successors of the given original |
| /// block. |
| EnumDecl *createBranchingTraceDecl(SILBasicBlock *originalBB, |
| SILAutoDiffIndices indices, |
| CanGenericSignature genericSig, |
| SILLoopInfo *loopInfo) { |
| assert(originalBB->getParent() == original); |
| auto &astCtx = original->getASTContext(); |
| auto *moduleDecl = original->getModule().getSwiftModule(); |
| auto &file = getDeclarationFileUnit(); |
| // Create a branching trace enum. |
| std::string enumName; |
| switch (kind) { |
| case AutoDiffLinearMapKind::Differential: |
| enumName = |
| "_AD__" + original->getName().str() + |
| "_bb" + std::to_string(originalBB->getDebugID()) + |
| "__Succ__" + indices.mangle(); |
| break; |
| case AutoDiffLinearMapKind::Pullback: |
| enumName = |
| "_AD__" + original->getName().str() + |
| "_bb" + std::to_string(originalBB->getDebugID()) + |
| "__Pred__" + indices.mangle(); |
| break; |
| } |
| auto enumId = astCtx.getIdentifier(enumName); |
| auto loc = original->getLocation().getSourceLoc(); |
| GenericParamList *genericParams = nullptr; |
| if (genericSig) |
| genericParams = cloneGenericParameters(astCtx, &file, genericSig); |
| auto *branchingTraceDecl = new (astCtx) EnumDecl( |
| /*EnumLoc*/ SourceLoc(), /*Name*/ enumId, /*NameLoc*/ loc, |
| /*Inherited*/ {}, /*GenericParams*/ genericParams, /*DC*/ &file); |
| // Note: must mark enum as implicit to satisfy assertion in |
| // `Parser::parseDeclListDelayed`. |
| branchingTraceDecl->setImplicit(); |
| if (genericSig) |
| branchingTraceDecl->setGenericSignature(genericSig); |
| computeAccessLevel(branchingTraceDecl, |
| original->getEffectiveSymbolLinkage()); |
| branchingTraceDecl->computeType(); |
| assert(branchingTraceDecl->hasInterfaceType()); |
| file.addVisibleDecl(branchingTraceDecl); |
| // Add basic block enum cases. |
| for (auto *predBB : originalBB->getPredecessorBlocks()) { |
| auto bbId = "bb" + std::to_string(predBB->getDebugID()); |
| auto *linearMapStruct = getLinearMapStruct(predBB); |
| assert(linearMapStruct); |
| auto linearMapStructTy = |
| linearMapStruct->getDeclaredInterfaceType()->getCanonicalType(); |
| // Create dummy declaration representing enum case parameter. |
| auto *decl = new (astCtx) |
| ParamDecl(ParamDecl::Specifier::Default, loc, loc, Identifier(), loc, |
| Identifier(), moduleDecl); |
| if (linearMapStructTy->hasArchetype()) |
| decl->setInterfaceType(linearMapStructTy->mapTypeOutOfContext()); |
| else |
| decl->setInterfaceType(linearMapStructTy); |
| // Create enum element and enum case declarations. |
| auto *paramList = ParameterList::create(astCtx, {decl}); |
| auto *enumEltDecl = new (astCtx) EnumElementDecl( |
| /*IdentifierLoc*/ loc, DeclName(astCtx.getIdentifier(bbId)), |
| paramList, loc, /*RawValueExpr*/ nullptr, branchingTraceDecl); |
| enumEltDecl->setImplicit(); |
| enumEltDecl->computeType(); |
| auto *enumCaseDecl = EnumCaseDecl::create( |
| /*CaseLoc*/ loc, {enumEltDecl}, branchingTraceDecl); |
| enumCaseDecl->setImplicit(); |
| branchingTraceDecl->addMember(enumEltDecl); |
| branchingTraceDecl->addMember(enumCaseDecl); |
| // Record enum element declaration. |
| branchingTraceEnumCases.insert({{predBB, originalBB}, enumEltDecl}); |
| } |
| // If original block is in a loop, mark branching trace enum as indirect. |
| if (loopInfo->getLoopFor(originalBB)) |
| branchingTraceDecl->getAttrs().add( |
| new (astCtx) IndirectAttr(/*Implicit*/ true)); |
| return branchingTraceDecl; |
| } |
| |
| /// Creates a struct declaration with the given JVP/VJP generic signature, for |
| /// storing the linear map values and predecessor/successor basic block of the |
| /// given original block. |
| StructDecl * |
| createLinearMapStruct(SILBasicBlock *originalBB, SILAutoDiffIndices indices, |
| CanGenericSignature genericSig) { |
| assert(originalBB->getParent() == original); |
| auto *original = originalBB->getParent(); |
| auto &astCtx = original->getASTContext(); |
| auto &file = getDeclarationFileUnit(); |
| std::string structName; |
| switch (kind) { |
| case swift::AutoDiffLinearMapKind::Differential: |
| structName = |
| "_AD__" + original->getName().str() + |
| "_bb" + std::to_string(originalBB->getDebugID()) + |
| "__DF__" + indices.mangle(); |
| break; |
| case swift::AutoDiffLinearMapKind::Pullback: |
| structName = |
| "_AD__" + original->getName().str() + |
| "_bb" + std::to_string(originalBB->getDebugID()) + |
| "__PB__" + indices.mangle(); |
| break; |
| } |
| auto structId = astCtx.getIdentifier(structName); |
| GenericParamList *genericParams = nullptr; |
| if (genericSig) |
| genericParams = cloneGenericParameters(astCtx, &file, genericSig); |
| auto *linearMapStruct = new (astCtx) StructDecl( |
| /*StructLoc*/ SourceLoc(), /*Name*/ structId, /*NameLoc*/ SourceLoc(), |
| /*Inherited*/ {}, /*GenericParams*/ genericParams, /*DC*/ &file); |
| // Note: must mark struct as implicit to satisfy assertion in |
| // `Parser::parseDeclListDelayed`. |
| linearMapStruct->setImplicit(); |
| if (genericSig) |
| linearMapStruct->setGenericSignature(genericSig); |
| computeAccessLevel( |
| linearMapStruct, original->getEffectiveSymbolLinkage()); |
| linearMapStruct->computeType(); |
| assert(linearMapStruct->hasInterfaceType()); |
| file.addVisibleDecl(linearMapStruct); |
| return linearMapStruct; |
| } |
| |
| /// Add a linear map to the linear map struct. |
| VarDecl *addLinearMapDecl(SILInstruction *inst, SILType linearMapType) { |
| // IRGen requires decls to have AST types (not `SILFunctionType`), so we |
| // convert the `SILFunctionType` of the linear map to a `FunctionType` with |
| // the same parameters and results. |
| auto silFnTy = linearMapType.castTo<SILFunctionType>(); |
| SmallVector<AnyFunctionType::Param, 8> params; |
| for (auto ¶m : silFnTy->getParameters()) |
| params.push_back(AnyFunctionType::Param(param.getType())); |
| AnyFunctionType *astFnTy; |
| if (auto genSig = silFnTy->getGenericSignature()) |
| astFnTy = GenericFunctionType::get( |
| genSig, params, silFnTy->getAllResultsType().getASTType()); |
| else |
| astFnTy = FunctionType::get( |
| params, silFnTy->getAllResultsType().getASTType()); |
| |
| auto *origBB = inst->getParent(); |
| auto *linMapStruct = getLinearMapStruct(origBB); |
| std::string linearMapName; |
| switch (kind) { |
| case AutoDiffLinearMapKind::Differential: |
| linearMapName = "differential_" + llvm::itostr(linearMapValueMap.size()); |
| break; |
| case AutoDiffLinearMapKind::Pullback: |
| linearMapName = "pullback_" + llvm::itostr(linearMapValueMap.size()); |
| break; |
| } |
| auto *linearMapDecl = addVarDecl(linMapStruct, linearMapName, astFnTy); |
| linearMapValueMap.insert({inst, linearMapDecl}); |
| return linearMapDecl; |
| } |
| |
| /// Given an `apply` instruction, conditionally adds its linear map function |
| /// to the linear map struct if it is active. |
| void addLinearMapToStruct(ADContext &context, ApplyInst *ai, |
| const SILAutoDiffIndices &indices); |
| |
| /// Generate linear map struct and branching enum declarations for the given |
| /// function. Linear map structs are populated with linear map fields and a |
| /// branching enum field. |
| void generateDifferentiationDataStructures( |
| ADContext &context, const SILAutoDiffIndices &indices, |
| SILFunction *derivative); |
| |
| public: |
| bool shouldDifferentiateApplyInst(ApplyInst *ai); |
| bool shouldDifferentiateInstruction(SILInstruction *inst); |
| |
| LinearMapInfo(const LinearMapInfo &) = delete; |
| LinearMapInfo &operator=(const LinearMapInfo &) = delete; |
| |
| explicit LinearMapInfo(ADContext &context, |
| AutoDiffLinearMapKind kind, |
| SILFunction *original, SILFunction *derivative, |
| const SILAutoDiffIndices &indices, |
| const DifferentiableActivityInfo &activityInfo); |
| |
| /// Returns the linear map struct associated with the given original block. |
| StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const { |
| return linearMapStructs.lookup(origBB); |
| } |
| |
| /// Returns the lowered SIL type of the linear map struct associated with the |
| /// given original block. |
| SILType getLinearMapStructLoweredType(SILBasicBlock *origBB) const { |
| auto *linMapStruct = getLinearMapStruct(origBB); |
| auto linMapStructType = |
| linMapStruct->getDeclaredInterfaceType()->getCanonicalType(); |
| return typeConverter.getLoweredType(linMapStructType, |
| ResilienceExpansion::Minimal); |
| } |
| |
| /// Returns the branching trace enum associated with the given original block. |
| EnumDecl *getBranchingTraceDecl(SILBasicBlock *origBB) const { |
| return branchingTraceDecls.lookup(origBB); |
| } |
| |
| /// Returns the lowered SIL type of the branching trace enum associated with |
| /// the given original block. |
| SILType getBranchingTraceEnumLoweredType(SILBasicBlock *origBB) const { |
| auto *traceDecl = getBranchingTraceDecl(origBB); |
| auto traceDeclType = |
| traceDecl->getDeclaredInterfaceType()->getCanonicalType(); |
| return typeConverter.getLoweredType(traceDeclType, |
| ResilienceExpansion::Minimal); |
| } |
| |
| /// Returns the enum element in the given successor block's branching trace |
| /// enum corresponding to the given predecessor block. |
| EnumElementDecl * |
| lookUpBranchingTraceEnumElement(SILBasicBlock *origPredBB, |
| SILBasicBlock *origSuccBB) const { |
| assert(origPredBB->getParent() == original); |
| return branchingTraceEnumCases.lookup({origPredBB, origSuccBB}); |
| } |
| |
| /// Returns the mapping from linear map structs to their branching trace enum |
| /// fields. |
| DenseMap<StructDecl *, VarDecl *> &getLinearMapStructEnumFields() { |
| return linearMapStructEnumFields; |
| } |
| |
| /// Returns the branching trace enum field for the linear map struct of the |
| /// given original block. |
| VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) { |
| auto *linearMapStruct = getLinearMapStruct(origBB); |
| return linearMapStructEnumFields.lookup(linearMapStruct); |
| } |
| |
| /// Finds the linear map declaration in the pullback struct for an `apply` or |
| /// `struct_extract` in the original function. |
| VarDecl *lookUpLinearMapDecl(SILInstruction *inst) { |
| auto lookup = linearMapValueMap.find(inst); |
| assert(lookup != linearMapValueMap.end() && |
| "No linear map declaration corresponding to the given instruction"); |
| return lookup->getSecond(); |
| } |
| }; |
| |
| /// Stores `apply` instruction information calculated by VJP generation. |
| struct NestedApplyInfo { |
| /// The differentiation indices that are used to differentiate this `apply` |
| /// instruction. |
| SILAutoDiffIndices indices; |
| /// The original pullback type before reabstraction. `None` if the pullback |
| /// type is not reabstracted. |
| Optional<CanSILFunctionType> originalPullbackType; |
| }; |
| |
| static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, |
| DifferentiationInvoker invoker) { |
| invoker.print(os); |
| return os; |
| } |
| |
| void DifferentiationInvoker::print(llvm::raw_ostream &os) const { |
| os << "(differentiation_invoker "; |
| switch (kind) { |
| case Kind::DifferentiableFunctionInst: |
| os << "differentiable_function_inst=(" << *getDifferentiableFunctionInst() |
| << ")"; |
| break; |
| case Kind::IndirectDifferentiation: { |
| auto indDiff = getIndirectDifferentiation(); |
| os << "indirect_differentiation=(" << *std::get<0>(indDiff) << ')'; |
| // TODO: Enable printing parent invokers. |
| // May require storing a `DifferentiableInvoker *` in the |
| // `IndirectDifferentiation` case. |
| /* |
| SILInstruction *inst; |
| SILDifferentiableAttr *attr; |
| std::tie(inst, attr) = getIndirectDifferentiation(); |
| auto invokerLookup = invokers.find(attr); // No access to ADContext? |
| assert(invokerLookup != invokers.end() && "Expected parent invoker"); |
| */ |
| break; |
| } |
| case Kind::SILDifferentiableAttribute: { |
| auto diffAttr = getSILDifferentiableAttribute(); |
| os << "sil_differentiable_attribute=(attr=("; |
| diffAttr->print(os); |
| os << ") function=" << diffAttr->getOriginal()->getName(); |
| break; |
| } |
| } |
| os << ')'; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ADContext - Per-module contextual information for the Differentiation pass. |
| //===----------------------------------------------------------------------===// |
| |
| class ADContext { |
| private: |
| /// Reference to the main transform. |
| SILModuleTransform &transform; |
| |
| /// The module where Differentiation is performed on. |
| SILModule &module; |
| |
| /// AST context. |
| ASTContext &astCtx = module.getASTContext(); |
| |
| /// Shared pass manager. |
| SILPassManager &passManager; |
| |
| /// The worklist (stack) of `differentiable_function` instructions to be |
| /// processed. |
| SmallVector<DifferentiableFunctionInst *, 32> differentiableFunctionInsts; |
| |
| /// The set of `differentiable_function` instructions that have been |
| /// processed. Used to avoid reprocessing invalidated instructions. |
| SmallPtrSet<DifferentiableFunctionInst *, 32> |
| processedDifferentiableFunctionInsts; |
| |
| /// Mapping from `[differentiable]` attributes to invokers. |
| /// `SmallMapVector` is used for deterministic insertion order iteration. |
| SmallMapVector<SILDifferentiableAttr *, DifferentiationInvoker, 32> |
| invokers; |
| |
| /// Mapping from `differentiable_function` instructions to result indices. |
| DenseMap<DifferentiableFunctionInst *, unsigned> resultIndices; |
| |
| /// Mapping from original `apply` instructions to their corresponding |
| /// `NestedApplyInfo`s. |
| DenseMap<ApplyInst *, NestedApplyInfo> nestedApplyInfo; |
| |
| /// List of generated functions (JVPs, VJPs, pullbacks, and thunks). |
| /// Saved for deletion during cleanup. |
| SmallVector<SILFunction *, 32> generatedFunctions; |
| |
| /// List of references to generated functions. |
| /// Saved for deletion during cleanup. |
| SmallVector<SILValue, 32> generatedFunctionReferences; |
| |
| /// The AdditiveArithmetic protocol in the standard library. |
| ProtocolDecl *additiveArithmeticProtocol = |
| astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic); |
| |
| /// `AdditiveArithmetic.+` declaration. |
| mutable FuncDecl *cachedPlusFn = nullptr; |
| /// `AdditiveArithmetic.+=` declaration. |
| mutable FuncDecl *cachedPlusEqualFn = nullptr; |
| |
| public: |
| /// Construct an ADContext for the given module. |
| explicit ADContext(SILModuleTransform &transform); |
| |
| //--------------------------------------------------------------------------// |
| // General utilities |
| //--------------------------------------------------------------------------// |
| |
| SILModuleTransform &getTransform() const { return transform; } |
| SILModule &getModule() const { return module; } |
| ASTContext &getASTContext() const { return module.getASTContext(); } |
| SILPassManager &getPassManager() const { return passManager; } |
| Lowering::TypeConverter &getTypeConverter() { return module.Types; } |
| |
| SmallVectorImpl<DifferentiableFunctionInst *> & |
| getDifferentiableFunctionInsts() { |
| return differentiableFunctionInsts; |
| } |
| |
| SmallPtrSetImpl<DifferentiableFunctionInst *> & |
| getProcessedDifferentiableFunctionInsts() { |
| return processedDifferentiableFunctionInsts; |
| } |
| |
| llvm::SmallMapVector<SILDifferentiableAttr *, DifferentiationInvoker, 32> & |
| getInvokers() { |
| return invokers; |
| } |
| |
| DenseMap<DifferentiableFunctionInst *, unsigned> &getResultIndices() { |
| return resultIndices; |
| } |
| |
| DenseMap<ApplyInst *, NestedApplyInfo> &getNestedApplyInfo() { |
| return nestedApplyInfo; |
| } |
| |
| SmallVector<SILFunction *, 32> &getGeneratedFunctions() { |
| return generatedFunctions; |
| } |
| |
| SmallVector<SILValue, 32> &getGeneratedFunctionReferences() { |
| return generatedFunctionReferences; |
| } |
| |
| ProtocolDecl *getAdditiveArithmeticProtocol() const { |
| return additiveArithmeticProtocol; |
| } |
| |
| FuncDecl *getPlusDecl() const { |
| if (!cachedPlusFn) { |
| cachedPlusFn = findOperatorDeclInProtocol( |
| astCtx.getIdentifier("+"), additiveArithmeticProtocol); |
| assert(cachedPlusFn && "AdditiveArithmetic.+ not found"); |
| } |
| return cachedPlusFn; |
| } |
| |
| FuncDecl *getPlusEqualDecl() const { |
| if (!cachedPlusEqualFn) { |
| cachedPlusEqualFn = findOperatorDeclInProtocol( |
| astCtx.getIdentifier("+="), additiveArithmeticProtocol); |
| assert(cachedPlusEqualFn && "AdditiveArithmetic.+= not found"); |
| } |
| return cachedPlusEqualFn; |
| } |
| |
| void cleanUp() { |
| for (auto invokerPair : invokers) { |
| auto *attr = std::get<0>(invokerPair); |
| auto *original = attr->getOriginal(); |
| LLVM_DEBUG(getADDebugStream() |
| << "Removing [differentiable] attribute for " |
| << original->getName() << '\n'); |
| original->removeDifferentiableAttr(attr); |
| } |
| // Delete all references to generated functions. |
| for (auto fnRef : generatedFunctionReferences) { |
| if (auto *fnRefInst = |
| peerThroughFunctionConversions<FunctionRefInst>(fnRef)) { |
| fnRefInst->replaceAllUsesWithUndef(); |
| fnRefInst->eraseFromParent(); |
| } |
| } |
| // Delete all generated functions. |
| for (auto *generatedFunction : generatedFunctions) { |
| LLVM_DEBUG(getADDebugStream() |
| << "Deleting generated function " |
| << generatedFunction->getName() << '\n'); |
| generatedFunction->dropAllReferences(); |
| transform.notifyWillDeleteFunction(generatedFunction); |
| module.eraseFunction(generatedFunction); |
| } |
| } |
| |
| //--------------------------------------------------------------------------// |
| // `[differentiable]` attribute lookup and registration |
| //--------------------------------------------------------------------------// |
| |
| /// Finds the `[differentiable]` attribute on the specified original function |
| /// with the exact specified parameter indices. Returns nullptr if no such |
| /// attribute exists. |
| SILDifferentiableAttr *lookUpDifferentiableAttr( |
| SILFunction *original, const SILAutoDiffIndices &indices) const { |
| for (auto *attr : original->getDifferentiableAttrs()) |
| if (attr->getIndices() == indices) |
| return attr; |
| return nullptr; |
| } |
| |
| /// Finds the `[differentiable]` attribute on the specified original function |
| /// whose parameter indices are a minimal superset of the specified parameter |
| /// indices. Returns nullptr if no such attribute exists. |
| SILDifferentiableAttr *lookUpMinimalDifferentiableAttr( |
| SILFunction *original, const SILAutoDiffIndices &indices) const { |
| auto *minimalIndexSet = IndexSubset::getDefault( |
| getASTContext(), |
| original->getLoweredFunctionType()->getNumParameters(), false); |
| auto *indexSet = indices.parameters; |
| if (auto *exactAttr = lookUpDifferentiableAttr(original, indices)) |
| return exactAttr; |
| SILDifferentiableAttr *minimalAttr = nullptr; |
| for (auto *da : original->getDifferentiableAttrs()) { |
| if (da->getIndices().source != indices.source) |
| continue; |
| auto *daIndexSet = da->getIndices().parameters; |
| // If all indices in `indexSet` are in `daIndexSet`, and it has fewer |
| // indices than our current candidate and a primitive VJP, then `da` is |
| // our new candidate. |
| // |
| // NOTE(TF-642): `da` may come from a un-partial-applied function and |
| // have larger capacity than the desired indices. We expect this logic to |
| // go away when `partial_apply` supports `@differentiable` callees. |
| if (daIndexSet->isSupersetOf(indexSet->extendingCapacity( |
| getASTContext(), daIndexSet->getCapacity())) && |
| // fewer parameters than before |
| (minimalIndexSet->isEmpty() || |
| daIndexSet->getNumIndices() < minimalIndexSet->getNumIndices())) { |
| minimalAttr = da; |
| minimalIndexSet = daIndexSet; |
| } |
| } |
| return minimalAttr; |
| } |
| |
| /// Finds the `@differentiable` attribute (and its parameter indices) on the |
| /// specified original function whose parameter indices are a minimal |
| /// superset of the specified parameter indices. Returns nullptr if no such |
| /// attribute exists. |
| std::pair<const DifferentiableAttr *, IndexSubset *> |
| lookUpMinimalASTDifferentiableAttrAndIndexSubset( |
| SILDeclRef originalDeclRef, CanSILFunctionType originalFnType, |
| const SILAutoDiffIndices &indices) { |
| auto *original = originalDeclRef.getDecl(); |
| const DifferentiableAttr *minimalAttr = nullptr; |
| auto *minimalIndexSet = IndexSubset::getDefault( |
| getASTContext(), originalFnType->getNumParameters(), false); |
| auto *indexSet = indices.parameters; |
| for (auto *da : original->getAttrs().getAttributes<DifferentiableAttr>()) { |
| auto *daParamIndices = da->getParameterIndices(); |
| auto *daIndexSet = autodiff::getLoweredParameterIndices( |
| daParamIndices, original->getInterfaceType()->castTo<AnyFunctionType>()); |
| // If all indices in `indexSet` are in `daIndexSet`, and it has fewer |
| // indices than our current candidate and a primitive VJP, then `da` is |
| // our new candidate. |
| // |
| // NOTE(TF-642): `da` may come from a un-partial-applied function and |
| // have larger capacity than the desired indices. We expect this logic to |
| // go away when `partial_apply` supports `@differentiable` callees. |
| if (daIndexSet->isSupersetOf(indexSet->extendingCapacity(getASTContext(), |
| daIndexSet->getCapacity())) && |
| // fewer parameters than before |
| (minimalIndexSet->isEmpty() || |
| daIndexSet->getNumIndices() < minimalIndexSet->getNumIndices())) { |
| minimalAttr = da; |
| minimalIndexSet = daIndexSet; |
| } |
| } |
| return std::make_pair(minimalAttr, minimalIndexSet); |
| } |
| |
| /// Creates a `[differentiable]` attribute on the specified original function |
| /// with the specified parameter indices. |
| SILDifferentiableAttr *createDifferentiableAttr( |
| SILFunction *original, const SILAutoDiffIndices &indices, |
| GenericSignature derivativeGenericSignature) const { |
| assert(!lookUpDifferentiableAttr(original, indices)); |
| auto derivativeConstrainedGenSig = getConstrainedDerivativeGenericSignature( |
| original->getLoweredFunctionType(), indices.parameters, |
| derivativeGenericSignature); |
| auto *attr = SILDifferentiableAttr::create(getModule(), indices, |
| /*jvpName*/ StringRef(), |
| /*vjpName*/ StringRef(), |
| derivativeConstrainedGenSig); |
| original->addDifferentiableAttr(attr); |
| return attr; |
| } |
| |
| /// Finds or creates a `[differentiable]` attribute on the specified |
| /// original function corresponding to the specified parameter indices. |
| SILDifferentiableAttr *getOrCreateDifferentiableAttr( |
| SILFunction *original, const SILAutoDiffIndices &indices, |
| GenericSignature derivativeGenericSignature) { |
| if (auto *attr = lookUpDifferentiableAttr(original, indices)) |
| return attr; |
| assert(original->isDefinition()); |
| return createDifferentiableAttr(original, indices, |
| derivativeGenericSignature); |
| } |
| |
| /// Creates an `differentiable_function` instruction using the given builder |
| /// and arguments. Erase the newly created instruction from the processed set, |
| /// if it exists - it may exist in the processed set if it has the same |
| /// pointer value as a previously processed and deleted instruction. |
| DifferentiableFunctionInst *createDifferentiableFunction( |
| SILBuilder &builder, SILLocation loc, |
| IndexSubset *parameterIndices, SILValue original, |
| Optional<std::pair<SILValue, SILValue>> derivativeFunctions = None) { |
| auto *dfi = builder.createDifferentiableFunction( |
| loc, parameterIndices, original, derivativeFunctions); |
| processedDifferentiableFunctionInsts.erase(dfi); |
| return dfi; |
| } |
| |
| private: |
| /// Promotes the given `differentiable_function` instruction to a valid |
| /// `@differentiable` function-typed value. |
| SILValue promoteToDifferentiableFunction( |
| DifferentiableFunctionInst *inst, SILBuilder &builder, SILLocation loc, |
| DifferentiationInvoker invoker); |
| |
| public: |
| /// Process the given `[differentiable]` attribute, filling in JVP/VJPs if |
| /// missing. |
| bool processDifferentiableAttribute( |
| SILFunction *original, SILDifferentiableAttr *attr, |
| DifferentiationInvoker invoker); |
| |
| /// Process the given `differentiable_function` instruction, filling in |
| /// missing derivative functions if necessary. |
| bool processDifferentiableFunctionInst(DifferentiableFunctionInst *dfi); |
| |
| /// Fold `differentiable_function_extract` users of the given |
| /// `differentiable_function` instruction, directly replacing them with |
| /// `differentiable_function` instruction operands. If the |
| /// `differentiable_function` instruction has no remaining uses, delete the |
| /// instruction itself after folding. |
| /// |
| /// Folding can be disabled by the |
| /// `SkipFoldingDifferentiableFunctionExtraction` flag for SIL testing |
| /// purposes. |
| void foldDifferentiableFunctionExtraction(DifferentiableFunctionInst *source); |
| |
| /// Get or create a derivative function parameter index subset thunk from |
| /// `actualIndices` to `desiredIndices` for the given associated function |
| /// value and original function operand. Returns a pair of the parameter |
| /// index subset thunk and its interface substitution map (used to partially |
| /// apply the thunk). |
| /// Calls `getOrCreateSubsetParametersThunkForLinearMap` to thunk the linear |
| /// map returned by the derivative function. |
| std::pair<SILFunction *, SubstitutionMap> |
| getOrCreateSubsetParametersThunkForDerivativeFunction( |
| SILValue origFnOperand, SILValue derivativeFn, |
| AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices, |
| SILAutoDiffIndices actualIndices); |
| |
| /// Get or create a derivative function parameter index subset thunk from |
| /// `actualIndices` to `desiredIndices` for the given associated function |
| /// value and original function operand. Returns a pair of the parameter |
| /// index subset thunk and its interface substitution map (used to partially |
| /// apply the thunk). |
| std::pair<SILFunction *, SubstitutionMap> |
| getOrCreateSubsetParametersThunkForLinearMap( |
| SILFunction *assocFn, CanSILFunctionType linearMapType, |
| CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind, |
| SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices); |
| |
| public: |
| /// Declare an external reference to a derivative function of `original`, |
| /// given a `[differentiable]` attribute of `original` and the associated |
| /// function kind. |
| SILFunction * |
| declareExternalDerivativeFunction(SILFunction *original, |
| SILDifferentiableAttr *attr, StringRef name, |
| AutoDiffDerivativeFunctionKind kind); |
| |
| template <typename ...T, typename ...U> |
| InFlightDiagnostic diagnose(SourceLoc loc, Diag<T...> diag, |
| U &&...args) const { |
| return getASTContext().Diags.diagnose(loc, diag, std::forward<U>(args)...); |
| } |
| |
| /// Given an instruction and a differentiation task associated with the |
| /// parent function, emits a "not differentiable" error based on the task. If |
| /// the task is indirect, emits notes all the way up to the outermost task, |
| /// and emits an error at the outer task. Otherwise, emits an error directly. |
| template<typename ...T, typename ...U> |
| InFlightDiagnostic emitNondifferentiabilityError( |
| SILInstruction *inst, DifferentiationInvoker invoker, |
| Diag<T...> diag, U &&...args); |
| |
| /// Given a value and a differentiation task associated with the parent |
| /// function, emits a "not differentiable" error based on the task. If the |
| /// task is indirect, emits notes all the way up to the outermost task, and |
| /// emits an error at the outer task. Otherwise, emits an error directly. |
| template<typename ...T, typename ...U> |
| InFlightDiagnostic emitNondifferentiabilityError( |
| SILValue value, DifferentiationInvoker invoker, |
| Diag<T...> diag, U &&...args); |
| |
| /// Emit a "not differentiable" error based on the given differentiation task |
| /// and diagnostic. |
| template<typename ...T, typename ...U> |
| InFlightDiagnostic emitNondifferentiabilityError( |
| SourceLoc loc, DifferentiationInvoker invoker, |
| Diag<T...> diag, U &&...args); |
| }; |
| } // end anonymous namespace |
| |
| ADContext::ADContext(SILModuleTransform &transform) |
| : transform(transform), module(*transform.getModule()), |
| passManager(*transform.getPassManager()) {} |
| |
| template<typename ...T, typename ...U> |
| InFlightDiagnostic |
| ADContext::emitNondifferentiabilityError(SILValue value, |
| DifferentiationInvoker invoker, |
| Diag<T...> diag, U &&...args) { |
| LLVM_DEBUG({ |
| getADDebugStream() << "Diagnosing non-differentiability.\n"; |
| getADDebugStream() << "For value:\n" << value; |
| getADDebugStream() << "With invoker:\n" << invoker << '\n'; |
| }); |
| auto valueLoc = value.getLoc().getSourceLoc(); |
| // If instruction does not have a valid location, use the function location |
| // as a fallback. Improves diagnostics in some cases. |
| if (valueLoc.isInvalid()) |
| valueLoc = value->getFunction()->getLocation().getSourceLoc(); |
| return emitNondifferentiabilityError(valueLoc, invoker, diag, |
| std::forward<U>(args)...); |
| } |
| |
| template<typename ...T, typename ...U> |
| InFlightDiagnostic |
| ADContext::emitNondifferentiabilityError(SILInstruction *inst, |
| DifferentiationInvoker invoker, |
| Diag<T...> diag, U &&...args) { |
| LLVM_DEBUG({ |
| getADDebugStream() << "Diagnosing non-differentiability.\n"; |
| getADDebugStream() << "For instruction:\n" << *inst; |
| getADDebugStream() << "With invoker:\n" << invoker << '\n'; |
| }); |
| auto instLoc = inst->getLoc().getSourceLoc(); |
| // If instruction does not have a valid location, use the function location |
| // as a fallback. Improves diagnostics for `ref_element_addr` generated in |
| // synthesized stored property getters. |
| if (instLoc.isInvalid()) |
| instLoc = inst->getFunction()->getLocation().getSourceLoc(); |
| return emitNondifferentiabilityError(instLoc, invoker, diag, |
| std::forward<U>(args)...); |
| } |
| |
| template<typename ...T, typename ...U> |
| InFlightDiagnostic |
| ADContext::emitNondifferentiabilityError(SourceLoc loc, |
| DifferentiationInvoker invoker, |
| Diag<T...> diag, U &&...args) { |
| switch (invoker.getKind()) { |
| // For `differentiable_function` instructions: if the `differentiable_function` |
| // instruction comes from a differential operator, emit an error on the |
| // expression and a note on the non-differentiable operation. Otherwise, emit |
| // both an error and note on the non-differentiation operation. |
| case DifferentiationInvoker::Kind::DifferentiableFunctionInst: { |
| auto *inst = invoker.getDifferentiableFunctionInst(); |
| if (auto *expr = findDifferentialOperator(inst)) { |
| diagnose(expr->getLoc(), diag::autodiff_function_not_differentiable_error) |
| .highlight(expr->getSubExpr()->getSourceRange()); |
| return diagnose(loc, diag, std::forward<U>(args)...); |
| } |
| diagnose(loc, diag::autodiff_expression_not_differentiable_error); |
| return diagnose(loc, diag, std::forward<U>(args)...); |
| } |
| |
| // For `[differentiable]` attributes, try to find an AST function declaration |
| // and `@differentiable` attribute. If they are found, emit an error on the |
| // `@differentiable` attribute; otherwise, emit an error on the SIL function. |
| // Emit a note at the non-differentiable operation. |
| case DifferentiationInvoker::Kind::SILDifferentiableAttribute: { |
| auto *attr = invoker.getSILDifferentiableAttribute(); |
| auto *original = attr->getOriginal(); |
| bool foundAttr = false; |
| if (auto *declContext = original->getDeclContext()) { |
| if (auto *fnDecl = declContext->getInnermostDeclarationDeclContext()) { |
| if (auto *diffAttr = |
| fnDecl->getAttrs().getAttribute<DifferentiableAttr>()) { |
| diagnose(diffAttr->getLocation(), |
| diag::autodiff_function_not_differentiable_error) |
| .highlight(diffAttr->getRangeWithAt()); |
| diagnose(original->getLocation().getSourceLoc(), |
| diag::autodiff_when_differentiating_function_definition); |
| foundAttr = true; |
| } |
| } |
| } |
| // Fallback if we cannot find the expected attribute. |
| if (!foundAttr) |
| diagnose(original->getLocation().getSourceLoc(), |
| diag::autodiff_function_not_differentiable_error); |
| return diagnose(loc, diag, std::forward<U>(args)...); |
| } |
| |
| // For indirect differentiation, emit a "not differentiable" note on the |
| // expression first. Then emit an error at the source invoker of |
| // differentiation, and a "when differentiating this" note at each indirect |
| // invoker. |
| case DifferentiationInvoker::Kind::IndirectDifferentiation: { |
| SILInstruction *inst; |
| SILDifferentiableAttr *attr; |
| std::tie(inst, attr) = invoker.getIndirectDifferentiation(); |
| auto invokerLookup = invokers.find(attr); |
| assert(invokerLookup != invokers.end() && "Expected parent invoker"); |
| emitNondifferentiabilityError(inst, invokerLookup->second, |
| diag::autodiff_expression_not_differentiable_note); |
| return diagnose(loc, diag::autodiff_when_differentiating_function_call); |
| } |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Activity Analysis |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| class DifferentiableActivityCollection; |
| |
| /// In many real situations, the end-users of AD need only the derivatives of |
| /// some selected outputs of `P` with respect to some selected inputs of `P`. |
| /// Whatever the differentiation mode (tangent, reverse,...), these restrictions |
| /// allow the AD tool to produce a much more efficient differentiated program. |
| /// Essentially, fixing some inputs and neglecting some outputs allows AD to |
| /// just forget about several intermediate differentiated variables. |
| /// |
| /// Activity analysis is the specific analysis that detects these situations, |
| /// therefore allowing for a better differentiated code. Activity analysis is |
| /// present in all transformation-based AD tools. |
| /// |
| /// To begin with, the end-user specifies that only some output variables (the |
| /// “dependent”) must be differentiated with respect to only some input |
| /// variables (the “independent”). We say that variable `y` depends on `x` when |
| /// the derivative of `y` with respect to `x` is not trivially null. We say that |
| /// a variable is “varied” if it depends on at least one independent. Conversely |
| /// we say that a variable is “useful” if at least one dependent depends on it. |
| /// Finally, we say that a variable is “active” if it is at the same time varied |
| /// and useful. In the special case of the tangent mode, it is easy to check |
| /// that when variable `v` is not varied at some place in the program, then its |
| /// derivative `v̇` at this place is certainly null. Conversely when variable `v` |
| /// is not useful, then whatever the value of `v̇`, this value does not matter |
| /// for the final result. Symmetric reasoning applies for the reverse mode of |
| /// AD: observing that differentiated variables go upstream, we see that a |
| /// useless variable has a null derivative, in other words the partial |
| /// derivative of the output with respect to this variable is null. Conversely |
| /// when variable `v` is not varied, then whatever the value of `v`, this value |
| /// does not matter for the final result. |
| /// |
| /// Reference: |
| /// Laurent Hascoët. Automatic Differentiation by Program Transformation. 2007. |
| class DifferentiableActivityAnalysis |
| : public FunctionAnalysisBase<DifferentiableActivityCollection> { |
| private: |
| DominanceAnalysis *dominanceAnalysis = nullptr; |
| PostDominanceAnalysis *postDominanceAnalysis = nullptr; |
| |
| public: |
| explicit DifferentiableActivityAnalysis() |
| : FunctionAnalysisBase(SILAnalysisKind::DifferentiableActivity) {} |
| |
| static bool classof(const SILAnalysis *s) { |
| return s->getKind() == SILAnalysisKind::DifferentiableActivity; |
| } |
| |
| virtual bool shouldInvalidate(SILAnalysis::InvalidationKind k) override { |
| return k & InvalidationKind::Everything; |
| } |
| |
| virtual std::unique_ptr<DifferentiableActivityCollection> |
| newFunctionAnalysis(SILFunction *f) override; |
| |
| virtual void initialize(SILPassManager *pm) override; |
| }; |
| } // end anonymous namespace |
| |
| namespace { |
/// Differentiation activity bits that can be attached to a SIL value.
enum class ActivityFlags : unsigned {
  /// Set when the value depends on a function parameter.
  Varied = 1u << 1,
  /// Set when the value contributes to a result.
  Useful = 1u << 2,
  /// Both bits set: the value is varied and useful at the same time.
  Active = Varied | Useful,
};
| |
| using Activity = OptionSet<ActivityFlags>; |
| |
| /// Result of activity analysis on a function. Accepts queries for whether a |
| /// value is "varied", "useful" or "active" against certain differentiation |
| /// indices. |
| class DifferentiableActivityInfo { |
| private: |
| DifferentiableActivityCollection &parent; |
| |
| /// The derivative generic signature. |
| GenericSignature derivativeGenericSignature; |
| |
| /// Input values, i.e. parameters (both direct and indirect). |
| SmallVector<SILValue, 4> inputValues; |
| /// Output values, i.e. individual values (not the final tuple) being returned |
| /// by the `return` instruction. |
| SmallVector<SILValue, 4> outputValues; |
| |
| /// The set of useful variables, indexed by the corresponding dependent value |
| /// (output) index. |
| SmallVector<SmallDenseSet<SILValue>, 4> usefulValueSets; |
| /// The set of useful variables, indexed by the corresponding independent |
| /// value (input) index. |
| SmallVector<SmallDenseSet<SILValue>, 4> variedValueSets; |
| |
| /// The original function. |
| SILFunction &getFunction(); |
| |
| /// The conformance lookup function. |
| LookupConformanceFn getLookupConformanceFunction() { |
| // Look up in derivative generic signature, if defined. |
| if (derivativeGenericSignature) |
| return LookUpConformanceInSignature( |
| derivativeGenericSignature.getPointer()); |
| // Otherwise, look up in the module. |
| return LookUpConformanceInModule( |
| getFunction().getModule().getSwiftModule()); |
| } |
| |
| /// Perform analysis and populate sets. |
| void analyze(DominanceInfo *di, PostDominanceInfo *pdi); |
| |
| void setVaried(SILValue value, unsigned independentVariableIndex); |
| void setVariedAcrossArrayInitialization(SILValue value, |
| unsigned independentVariableIndex); |
| void setUseful(SILValue value, unsigned dependentVariableIndex); |
| void setUsefulAcrossArrayInitialization(SILValue value, |
| unsigned dependentVariableIndex); |
| /// Marks the given value as "varied" and recursively propagates "varied" |
| /// inwards (to operands) through projections. Skips any `@noDerivative` |
| /// struct field projections. |
| void propagateVariedInwardsThroughProjections( |
| SILValue value, unsigned independentVariableIndex); |
| void propagateUsefulThroughBuffer(SILValue value, |
| unsigned dependentVariableIndex); |
| |
| public: |
| explicit DifferentiableActivityInfo( |
| DifferentiableActivityCollection &parent, |
| GenericSignature derivativeGenericSignature); |
| |
| bool isVaried(SILValue value, unsigned independentVariableIndex) const; |
| bool isUseful(SILValue value, unsigned dependentVariableIndex) const; |
| bool isVaried(SILValue value, IndexSubset *parameterIndices) const; |
| bool isActive(SILValue value, const SILAutoDiffIndices &indices) const; |
| |
| Activity getActivity(SILValue value, |
| const SILAutoDiffIndices &indices) const; |
| Activity getActivity(SILInstruction *inst, |
| const SILAutoDiffIndices &indices) const; |
| }; |
| |
| /// Given a parameter argument (not indirect result) and some differentiation |
| /// indices, figure out whether the parent function is being differentiated with |
| /// respect to this parameter, according to the indices. |
| static bool isDifferentiationParameter(SILArgument *argument, |
| IndexSubset *indices) { |
| if (!argument) return false; |
| auto *function = argument->getFunction(); |
| auto paramArgs = function->getArgumentsWithoutIndirectResults(); |
| for (unsigned i : indices->getIndices()) |
| if (paramArgs[i] == argument) |
| return true; |
| return false; |
| } |
| |
| /// For an `apply` instruction with active results, compute: |
| /// - The results of the `apply` instruction, in type order. |
| /// - The set of minimal parameter and result indices for differentiating the |
| /// `apply` instruction. |
| static void collectMinimalIndicesForFunctionCall( |
| ApplyInst *ai, const SILAutoDiffIndices &parentIndices, |
| const DifferentiableActivityInfo &activityInfo, |
| SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> ¶mIndices, |
| SmallVectorImpl<unsigned> &resultIndices) { |
| auto calleeFnTy = ai->getSubstCalleeType(); |
| auto calleeConvs = ai->getSubstCalleeConv(); |
| // Parameter indices are indices (in the callee type signature) of parameter |
| // arguments that are varied or are arguments. |
| // Record all parameter indices in type order. |
| unsigned currentParamIdx = 0; |
| for (auto applyArg : ai->getArgumentsWithoutIndirectResults()) { |
| if (activityInfo.isVaried(applyArg, parentIndices.parameters) || |
| isDifferentiationParameter(dyn_cast<SILArgument>(applyArg), |
| parentIndices.parameters)) |
| paramIndices.push_back(currentParamIdx); |
| ++currentParamIdx; |
| } |
| // Result indices are indices (in the callee type signature) of results that |
| // are useful. |
| SmallVector<SILValue, 8> directResults; |
| forEachApplyDirectResult(ai, [&](SILValue directResult) { |
| directResults.push_back(directResult); |
| }); |
| auto indirectResults = ai->getIndirectSILResults(); |
| // Record all results and result indices in type order. |
| results.reserve(calleeFnTy->getNumResults()); |
| unsigned dirResIdx = 0; |
| unsigned indResIdx = calleeConvs.getSILArgIndexOfFirstIndirectResult(); |
| for (auto &resAndIdx : enumerate(calleeConvs.getResults())) { |
| auto &res = resAndIdx.value(); |
| unsigned idx = resAndIdx.index(); |
| if (res.isFormalDirect()) { |
| results.push_back(directResults[dirResIdx]); |
| if (auto dirRes = directResults[dirResIdx]) |
| if (dirRes && activityInfo.isUseful(dirRes, parentIndices.source)) |
| resultIndices.push_back(idx); |
| ++dirResIdx; |
| } else { |
| results.push_back(indirectResults[indResIdx]); |
| if (activityInfo.isUseful(indirectResults[indResIdx], |
| parentIndices.source)) |
| resultIndices.push_back(idx); |
| ++indResIdx; |
| } |
| } |
| // Make sure the function call has active results. |
| assert(results.size() == calleeFnTy->getNumResults()); |
| assert(llvm::any_of(results, [&](SILValue result) { |
| return activityInfo.isActive(result, parentIndices); |
| })); |
| } |
| |
| LinearMapInfo::LinearMapInfo(ADContext &context, |
| AutoDiffLinearMapKind kind, |
| SILFunction *original, SILFunction *derivative, |
| const SILAutoDiffIndices &indices, |
| const DifferentiableActivityInfo &activityInfo) |
| : kind(kind), original(original), derivative(derivative), |
| activityInfo(activityInfo), indices(indices), |
| typeConverter(context.getTypeConverter()) { |
| generateDifferentiationDataStructures(context, indices, derivative); |
| } |
| |
| /// Returns a flag that indicates whether the `apply` instruction should be |
| /// differentiated, given the differentiation indices of the instruction's |
| /// parent function. Whether the `apply` should be differentiated is determined |
| /// sequentially from the following conditions: |
| /// 1. The instruction has an active `inout` argument. |
| /// 2. The instruction is a call to the array literal initialization intrinsic |
| /// ("array.uninitialized_intrinsic"), where the result is active and where |
| /// there is a `store` of an active value into the array's buffer. |
| /// 3. The instruction has both an active result (direct or indirect) and an |
| /// active argument. |
| bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) { |
| // Function applications with an inout argument should be differentiated. |
| auto paramInfos = ai->getSubstCalleeConv().getParameters(); |
| auto arguments = ai->getArgumentsWithoutIndirectResults(); |
| for (auto i : swift::indices(paramInfos)) |
| if (paramInfos[i].isIndirectInOut() && |
| activityInfo.isActive(arguments[i], indices)) |
| return true; |
| |
| bool hasActiveDirectResults = false; |
| forEachApplyDirectResult(ai, [&](SILValue directResult) { |
| hasActiveDirectResults |= activityInfo.isActive(directResult, indices); |
| }); |
| bool hasActiveIndirectResults = llvm::any_of(ai->getIndirectSILResults(), |
| [&](SILValue result) { return activityInfo.isActive(result, indices); }); |
| bool hasActiveResults = hasActiveDirectResults || hasActiveIndirectResults; |
| |
| // TODO: Pattern match to make sure there is at least one `store` to the |
| // array's active buffer. |
| if (isArrayLiteralIntrinsic(ai) && hasActiveResults) |
| return true; |
| |
| bool hasActiveArguments = llvm::any_of(arguments, |
| [&](SILValue arg) { return activityInfo.isActive(arg, indices); }); |
| return hasActiveResults && hasActiveArguments; |
| } |
| |
| /// Returns a flag indicating whether the instruction should be differentiated, |
| /// given the differentiation indices of the instruction's parent function. |
| /// Whether the instruction should be differentiated is determined sequentially |
| /// from any of the following conditions: |
| /// 1. The instruction is an `apply` and `shouldDifferentiateApplyInst` returns |
| /// true. |
| /// 2. The instruction has a source operand and a destination operand, both |
| /// being active. |
| /// 3. The instruction is an allocation instruction and has an active result. |
| /// 4. The instruction performs reference counting, lifetime ending, access |
| /// ending, or destroying on an active operand. |
| /// 5. The instruction creates an SSA copy of an active operand. |
| bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) { |
| // An `apply` with an active argument and an active result (direct or |
| // indirect) should be differentiated. |
| if (auto *ai = dyn_cast<ApplyInst>(inst)) |
| return shouldDifferentiateApplyInst(ai); |
| // Anything with an active result and an active operand should be |
| // differentiated. |
| auto hasActiveOperands = llvm::any_of(inst->getAllOperands(), |
| [&](Operand &op) { return activityInfo.isActive(op.get(), indices); }); |
| auto hasActiveResults = llvm::any_of(inst->getResults(), |
| [&](SILValue val) { return activityInfo.isActive(val, indices); }); |
| if (hasActiveOperands && hasActiveResults) |
| return true; |
| // A `store`-like instruction does not have an SSA result, but has two |
| // operands that represent the source and the destination. We treat them as |
| // the input and the output, respectively. |
| #define CHECK_INST_TYPE_ACTIVE_DEST(INST) \ |
| if (auto *castInst = dyn_cast<INST##Inst>(inst)) \ |
| return activityInfo.isActive(castInst->getDest(), indices); |
| CHECK_INST_TYPE_ACTIVE_DEST(Store) |
| CHECK_INST_TYPE_ACTIVE_DEST(StoreBorrow) |
| CHECK_INST_TYPE_ACTIVE_DEST(CopyAddr) |
| CHECK_INST_TYPE_ACTIVE_DEST(UnconditionalCheckedCastAddr) |
| #undef CHECK_INST_TYPE_ACTIVE_DEST |
| // Should differentiate any allocation instruction that has an active result. |
| if ((isa<AllocationInst>(inst) && hasActiveResults)) |
| return true; |
| if (hasActiveOperands) { |
| // Should differentiate any instruction that performs reference counting, |
| // lifetime ending, access ending, or destroying on an active operand. |
| if (isa<RefCountingInst>(inst) || isa<EndAccessInst>(inst) || |
| isa<EndBorrowInst>(inst) || isa<DeallocationInst>(inst) || |
| isa<DestroyValueInst>(inst) || isa<DestroyAddrInst>(inst)) |
| return true; |
| // Should differentiate any instruction that creates an SSA copy of an |
| // active operand. |
| if (isa<CopyValueInst>(inst)) |
| return true; |
| } |
| return false; |
| } |
| |
| /// Takes an `apply` instruction and adds its linear map function to the |
| /// linear map struct if it is active. |
| void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai, |
| const SILAutoDiffIndices &indices) { |
| SmallVector<SILValue, 4> allResults; |
| SmallVector<unsigned, 8> activeParamIndices; |
| SmallVector<unsigned, 8> activeResultIndices; |
| collectMinimalIndicesForFunctionCall( |
| ai, indices, activityInfo, allResults, activeParamIndices, |
| activeResultIndices); |
| |
| // Check if there are any active results or arguments. If not, skip |
| // this instruction. |
| auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) { |
| return activityInfo.isActive(res, indices); |
| }); |
| auto hasActiveArguments = llvm::any_of( |
| ai->getArgumentsWithoutIndirectResults(), [&](SILValue arg) { |
| return activityInfo.isActive(arg, indices); |
| }); |
| if (!hasActiveResults || !hasActiveArguments) |
| return; |
| |
| // Compute differentiation result index. |
| auto source = activeResultIndices.front(); |
| // Compute differentiation parameters. |
| // - If the callee has `@differentiable` function type, use differentiation |
| // parameters from the function type. |
| // - Otherwise, use the active parameters. |
| IndexSubset *parameters; |
| auto origFnSubstTy = ai->getSubstCalleeType(); |
| auto remappedOrigFnSubstTy = |
| remapTypeInDerivative(SILType::getPrimitiveObjectType(origFnSubstTy)) |
| .castTo<SILFunctionType>(); |
| if (remappedOrigFnSubstTy->isDifferentiable()) { |
| parameters = remappedOrigFnSubstTy->getDifferentiationParameterIndices(); |
| } else { |
| parameters = IndexSubset::get( |
| original->getASTContext(), |
| ai->getArgumentsWithoutIndirectResults().size(), |
| activeParamIndices); |
| } |
| // Create autodiff indices for the `apply` instruction. |
| SILAutoDiffIndices applyIndices(source, parameters); |
| |
| // Check for non-differentiable original function type. |
| auto checkNondifferentiableOriginalFunctionType = |
| [&](CanSILFunctionType origFnTy) { |
| // Check non-differentiable arguments. |
| for (unsigned paramIndex : range(origFnTy->getNumParameters())) { |
| auto remappedParamType = |
| origFnTy->getParameters()[paramIndex].getSILStorageType(); |
| if (applyIndices.isWrtParameter(paramIndex) && |
| !remappedParamType.isDifferentiable(derivative->getModule())) |
| return true; |
| } |
| // Check non-differentiable results. |
| auto remappedResultType = |
| origFnTy->getResults()[applyIndices.source].getSILStorageType(); |
| if (!remappedResultType.isDifferentiable(derivative->getModule())) |
| return true; |
| return false; |
| }; |
| if (checkNondifferentiableOriginalFunctionType(remappedOrigFnSubstTy)) |
| return; |
| |
| AutoDiffDerivativeFunctionKind derivativeFnKind(kind); |
| auto derivativeFnType = remappedOrigFnSubstTy->getAutoDiffDerivativeFunctionType( |
| parameters, source, derivativeFnKind, context.getTypeConverter(), |
| LookUpConformanceInModule(derivative->getModule().getSwiftModule())); |
| |
| auto derivativeFnResultTypes = |
| derivativeFnType->getAllResultsType().castTo<TupleType>(); |
| derivativeFnResultTypes->getElement(derivativeFnResultTypes->getElements().size() - 1); |
| auto linearMapSILType = SILType::getPrimitiveObjectType( |
| derivativeFnResultTypes |
| ->getElement(derivativeFnResultTypes->getElements().size() - 1) |
| .getType() |
| ->getCanonicalType()); |
| addLinearMapDecl(ai, linearMapSILType); |
| } |
| |
| void LinearMapInfo::generateDifferentiationDataStructures( |
| ADContext &context, const SILAutoDiffIndices &indices, |
| SILFunction *derivativeFn) { |
| auto &astCtx = original->getASTContext(); |
| auto *loopAnalysis = context.getPassManager().getAnalysis<SILLoopAnalysis>(); |
| auto *loopInfo = loopAnalysis->get(original); |
| |
| // Get the derivative function generic signature. |
| CanGenericSignature derivativeFnGenSig = nullptr; |
| if (auto *derivativeFnGenEnv = derivativeFn->getGenericEnvironment()) |
| derivativeFnGenSig = |
| derivativeFnGenEnv->getGenericSignature()->getCanonicalSignature(); |
| |
| // Create linear map struct for each original block. |
| for (auto &origBB : *original) { |
| auto *linearMapStruct = |
| createLinearMapStruct(&origBB, indices, derivativeFnGenSig); |
| linearMapStructs.insert({&origBB, linearMapStruct}); |
| } |
| |
| // Create branching trace enum for each original block and add it as a field |
| // in the corresponding struct. |
| StringRef traceEnumFieldName; |
| switch (kind) { |
| case AutoDiffLinearMapKind::Differential: |
| traceEnumFieldName = "successor"; |
| break; |
| case AutoDiffLinearMapKind::Pullback: |
| traceEnumFieldName = "predecessor"; |
| break; |
| } |
| for (auto &origBB : *original) { |
| auto *traceEnum = |
| createBranchingTraceDecl(&origBB, indices, derivativeFnGenSig, loopInfo); |
| branchingTraceDecls.insert({&origBB, traceEnum}); |
| if (origBB.isEntry()) |
| continue; |
| // Add branching trace enum field to corresponding linear map struct. |
| auto *linearMapStruct = getLinearMapStruct(&origBB); |
| auto *traceEnumField = |
| addVarDecl(linearMapStruct, |
| astCtx.getIdentifier(traceEnumFieldName).str(), |
| traceEnum->getDeclaredInterfaceType()); |
| linearMapStructEnumFields.insert({linearMapStruct, traceEnumField}); |
| } |
| |
| // Add linear map fields to the linear map structs. |
| for (auto &origBB : *original) { |
| for (auto &inst : origBB) { |
| if (auto *ai = dyn_cast<ApplyInst>(&inst)) { |
| // Check for active 'inout' arguments. |
| bool isInout = false; |
| auto paramInfos = ai->getSubstCalleeConv().getParameters(); |
| for (unsigned i : swift::indices(paramInfos)) { |
| if (paramInfos[i].isIndirectInOut() && |
| activityInfo.isActive(ai->getArgumentsWithoutIndirectResults()[i], |
| indices)) { |
| // Reject functions with active inout arguments. It's not yet |
| // supported. |
| isInout = true; |
| break; |
| } |
| } |
| if (isInout) |
| continue; |
| |
| // Add linear map field to struct for active `apply` instructions. |
| // Skip array literal intrinsic applications since array literal |
| // initialization is linear and handled separately. |
| if (!shouldDifferentiateApplyInst(ai) || isArrayLiteralIntrinsic(ai)) |
| continue; |
| |
| LLVM_DEBUG(getADDebugStream() << "Adding linear map struct field for " |
| << *ai); |
| addLinearMapToStruct(context, ai, indices); |
| } |
| } |
| } |
| |
| // Print generated linear map structs and branching trace enums. |
| // These declarations do not show up with `-emit-sil` because they are |
| // implicit. Instead, use `-Xllvm -debug-only=differentiation` to test |
| // declarations with FileCheck. |
| LLVM_DEBUG({ |
| auto &s = getADDebugStream(); |
| PrintOptions printOptions; |
| printOptions.TypeDefinitions = true; |
| printOptions.ExplodePatternBindingDecls = true; |
| printOptions.SkipImplicit = false; |
| s << "Generated linear map structs and branching trace enums for @" |
| << original->getName() << ":\n"; |
| for (auto &origBB : *original) { |
| auto *linearMapStruct = getLinearMapStruct(&origBB); |
| linearMapStruct->print(s, printOptions); s << '\n'; |
| } |
| for (auto &origBB : *original) { |
| auto *traceEnum = getBranchingTraceDecl(&origBB); |
| traceEnum->print(s, printOptions); s << '\n'; |
| } |
| }); |
| } |
| |
| class DifferentiableActivityCollection { |
| public: |
| SmallDenseMap<GenericSignature, DifferentiableActivityInfo> activityInfoMap; |
| SILFunction &function; |
| DominanceInfo *domInfo; |
| PostDominanceInfo *postDomInfo; |
| |
| DifferentiableActivityInfo &getActivityInfo( |
| GenericSignature assocGenSig, AutoDiffDerivativeFunctionKind kind) { |
| auto activityInfoLookup = activityInfoMap.find(assocGenSig); |
| if (activityInfoLookup != activityInfoMap.end()) |
| return activityInfoLookup->getSecond(); |
| auto insertion = activityInfoMap.insert( |
| {assocGenSig, DifferentiableActivityInfo(*this, assocGenSig)}); |
| return insertion.first->getSecond(); |
| } |
| |
| explicit DifferentiableActivityCollection(SILFunction &f, |
| DominanceInfo *di, |
| PostDominanceInfo *pdi); |
| }; |
| |
| } // end anonymous namespace |
| |
| std::unique_ptr<DifferentiableActivityCollection> |
| DifferentiableActivityAnalysis::newFunctionAnalysis(SILFunction *f) { |
| assert(dominanceAnalysis && "Expect a valid dominance anaysis"); |
| assert(postDominanceAnalysis && "Expect a valid post-dominance anaysis"); |
| return llvm::make_unique<DifferentiableActivityCollection>( |
| *f, dominanceAnalysis->get(f), postDominanceAnalysis->get(f)); |
| } |
| |
| void DifferentiableActivityAnalysis::initialize(SILPassManager *pm) { |
| dominanceAnalysis = pm->getAnalysis<DominanceAnalysis>(); |
| postDominanceAnalysis = pm->getAnalysis<PostDominanceAnalysis>(); |
| } |
| |
| SILAnalysis *swift::createDifferentiableActivityAnalysis(SILModule *m) { |
| return new DifferentiableActivityAnalysis(); |
| } |
| |
| DifferentiableActivityCollection::DifferentiableActivityCollection( |
| SILFunction &f, DominanceInfo *di, PostDominanceInfo *pdi) |
| : function(f), domInfo(di), postDomInfo(pdi) {} |
| |
| DifferentiableActivityInfo::DifferentiableActivityInfo( |
| DifferentiableActivityCollection &parent, GenericSignature derivGenSig) |
| : parent(parent), derivativeGenericSignature(derivGenSig) { |
| analyze(parent.domInfo, parent.postDomInfo); |
| } |
| |
| SILFunction &DifferentiableActivityInfo::getFunction() { |
| return parent.function; |
| } |
| |
| void DifferentiableActivityInfo::analyze(DominanceInfo *di, |
| PostDominanceInfo *pdi) { |
| auto &function = getFunction(); |
| LLVM_DEBUG(getADDebugStream() |
| << "Running activity analysis on @" << function.getName() << '\n'); |
| // Inputs are just function's arguments, count `n`. |
| auto paramArgs = function.getArgumentsWithoutIndirectResults(); |
| for (auto value : paramArgs) |
| inputValues.push_back(value); |
| LLVM_DEBUG({ |
| auto &s = getADDebugStream(); |
| s << "Inputs in @" << function.getName() << ":\n"; |
| for (auto val : inputValues) |
| s << val << '\n'; |
| }); |
| // Outputs are indirect result buffers and return values, count `m`. |
| collectAllFormalResultsInTypeOrder(function, outputValues); |
| LLVM_DEBUG({ |
| auto &s = getADDebugStream(); |
| s << "Outputs in @" << function.getName() << ":\n"; |
| for (auto val : outputValues) |
| s << val << '\n'; |
| }); |
| |
| // Mark inputs as varied. |
| assert(variedValueSets.empty()); |
| for (auto input : inputValues) |
| variedValueSets.push_back({input}); |
| // Propagate varied-ness through the function in dominance order. |
| DominanceOrder domOrder(function.getEntryBlock(), di); |
| while (auto *bb = domOrder.getNext()) { |
| for (auto &inst : *bb) { |
| for (auto i : indices(inputValues)) { |
| // Handle `apply`. |
| if (auto *ai = dyn_cast<ApplyInst>(&inst)) { |
| // If callee is non-varying, skip. |
| if (isWithoutDerivative(ai->getCallee())) |
| continue; |
| // If any argument is varied, set all direct and indirect results as |
| // varied. |
| for (auto arg : ai->getArgumentsWithoutIndirectResults()) { |
| if (isVaried(arg, i)) { |
| for (auto indRes : ai->getIndirectSILResults()) |
| setVaried(indRes, i); |
| forEachApplyDirectResult(ai, [&](SILValue directResult) { |
| setVaried(directResult, i); |
| }); |
| } |
| } |
| } |
| // Handle store-like instructions: |
| // `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast` |
| #define PROPAGATE_VARIED_THROUGH_STORE(INST) \ |
| else if (auto *si = dyn_cast<INST##Inst>(&inst)) { \ |
| if (isVaried(si->getSrc(), i)) \ |
| propagateVariedInwardsThroughProjections(si->getDest(), i); \ |
| } |
| PROPAGATE_VARIED_THROUGH_STORE(Store) |
| PROPAGATE_VARIED_THROUGH_STORE(StoreBorrow) |
| PROPAGATE_VARIED_THROUGH_STORE(CopyAddr) |
| PROPAGATE_VARIED_THROUGH_STORE(UnconditionalCheckedCastAddr) |
| #undef PROPAGATE_VARIED_THROUGH_STORE |
| // Handle `tuple_element_addr`. |
| else if (auto *teai = dyn_cast<TupleElementAddrInst>(&inst)) { |
| if (isVaried(teai->getOperand(), i)) { |
| auto projType = teai->getType().getASTType(); |
| if (derivativeGenericSignature && projType->hasArchetype()) |
| projType = derivativeGenericSignature->getCanonicalTypeInContext( |
| projType->mapTypeOutOfContext()); |
| if (projType->getAutoDiffAssociatedTangentSpace( |
| getLookupConformanceFunction())) |
| setVaried(teai, i); |
| } |
| } |
| // Handle `struct_extract` and `struct_element_addr` instructions. |
| // - If the field is marked `@noDerivative`, do not set the result as |
| // varied because it is not in the set of differentiable variables. |
| // - Otherwise, propagate variedness from operand to result as usual. |
| #define PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(INST) \ |
| else if (auto *sei = dyn_cast<INST##Inst>(&inst)) { \ |
| if (isVaried(sei->getOperand(), i) && \ |
| !sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \ |
| setVaried(sei, i); \ |
| } |
| PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructExtract) |
| PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructElementAddr) |
| #undef PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION |
| // Handle `br`. |
| else if (auto *bi = dyn_cast<BranchInst>(&inst)) { |
| for (auto &op : bi->getAllOperands()) |
| if (isVaried(op.get(), i)) |
| setVaried(bi->getArgForOperand(&op), i); |
| } |
| // Handle `cond_br`. |
| else if (auto *cbi = dyn_cast<CondBranchInst>(&inst)) { |
| for (unsigned opIdx : indices(cbi->getTrueOperands())) { |
| auto &op = cbi->getTrueOperands()[opIdx]; |
| if (isVaried(op.get(), i)) |
| setVaried(cbi->getTrueBB()->getArgument(opIdx), i); |
| } |
| for (unsigned opIdx : indices(cbi->getFalseOperands())) { |
| auto &op = cbi->getFalseOperands()[opIdx]; |
| if (isVaried(op.get(), i)) |
| setVaried(cbi->getFalseBB()->getArgument(opIdx), i); |
| } |
| } |
| // Handle `switch_enum`. |
| else if (auto *sei = dyn_cast<SwitchEnumInst>(&inst)) { |
| if (isVaried(sei->getOperand(), i)) |
| for (auto *succBB : sei->getSuccessorBlocks()) |
| for (auto *arg : succBB->getArguments()) |
| setVaried(arg, i); |
| } |
| // Handle everything else. |
| else { |
| for (auto &op : inst.getAllOperands()) |
| if (isVaried(op.get(), i)) |
| for (auto result : inst.getResults()) |
| setVaried(result, i); |
| } |
| } |
| } |
| domOrder.pushChildren(bb); |
| } |
| |
| // Mark differentiable outputs as useful. |
| assert(usefulValueSets.empty()); |
| for (auto output : outputValues) { |
| usefulValueSets.push_back({}); |
| // If the output has an address or class type, propagate usefulness |
| // recursively. |
| if (output->getType().isAddress() || |
| output->getType().isClassOrClassMetatype()) |
| propagateUsefulThroughBuffer(output, usefulValueSets.size() - 1); |
| // Otherwise, just mark the output as useful. |
| else |
| setUseful(output, usefulValueSets.size() - 1); |
| } |
| // Propagate usefulness through the function in post-dominance order. |
| PostDominanceOrder postDomOrder(&*function.findReturnBB(), pdi); |
| while (auto *bb = postDomOrder.getNext()) { |
| for (auto &inst : reversed(*bb)) { |
| for (auto i : indices(outputValues)) { |
| // Handle indirect results in `apply`. |
| if (auto *ai = dyn_cast<ApplyInst>(&inst)) { |
| if (isWithoutDerivative(ai->getCallee())) |
| continue; |
| auto checkAndSetUseful = [&](SILValue res) { |
| if (isUseful(res, i)) |
| for (auto arg : ai->getArgumentsWithoutIndirectResults()) |
| setUseful(arg, i); |
| }; |
| for (auto dirRes : ai->getResults()) |
| checkAndSetUseful(dirRes); |
| for (auto indRes : ai->getIndirectSILResults()) |
| checkAndSetUseful(indRes); |
| auto paramInfos = ai->getSubstCalleeConv().getParameters(); |
| for (auto i : indices(paramInfos)) |
| if (paramInfos[i].isIndirectInOut()) |
| checkAndSetUseful(ai->getArgumentsWithoutIndirectResults()[i]); |
| } |
| // Handle store-like instructions: |
| // `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast` |
| #define PROPAGATE_USEFUL_THROUGH_STORE(INST, PROPAGATE) \ |
| else if (auto *si = dyn_cast<INST##Inst>(&inst)) { \ |
| if (isUseful(si->getDest(), i)) \ |
| PROPAGATE(si->getSrc(), i); \ |
| } |
| PROPAGATE_USEFUL_THROUGH_STORE(Store, setUseful) |
| PROPAGATE_USEFUL_THROUGH_STORE(StoreBorrow, setUseful) |
| PROPAGATE_USEFUL_THROUGH_STORE(CopyAddr, propagateUsefulThroughBuffer) |
| PROPAGATE_USEFUL_THROUGH_STORE(UnconditionalCheckedCastAddr, |
| propagateUsefulThroughBuffer) |
| #undef PROPAGATE_USEFUL_THROUGH_STORE |
| // Handle struct element extraction, skipping `@noDerivative` fields: |
| // `struct_extract`, `struct_element_addr`. |
| #define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(INST, PROPAGATE) \ |
| else if (auto *sei = dyn_cast<INST##Inst>(&inst)) { \ |
| if (isUseful(sei, i)) { \ |
| auto hasNoDeriv = sei->getField()->getAttrs() \ |
| .hasAttribute<NoDerivativeAttr>(); \ |
| if (!hasNoDeriv) \ |
| PROPAGATE(sei->getOperand(), i); \ |
| } \ |
| } |
| PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructExtract, setUseful) |
| PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructElementAddr, |
| propagateUsefulThroughBuffer) |
| #undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION |
| // Handle everything else. |
| else if (llvm::any_of(inst.getResults(), |
| [&](SILValue res) { return isUseful(res, i); })) { |
| for (auto &op : inst.getAllOperands()) { |
| auto value = op.get(); |
| if (value->getType().isAddress()) |
| propagateUsefulThroughBuffer(value, i); |
| setUseful(value, i); |
| } |
| } |
| } |
| } |
| // Propagate usefulness from basic block arguments to incoming phi values. |
| for (auto i : indices(outputValues)) { |
| for (auto *arg : bb->getArguments()) { |
| if (isUseful(arg, i)) { |
| SmallVector<SILValue, 4> incomingValues; |
| arg->getSingleTerminatorOperands(incomingValues); |
| for (auto incomingValue : incomingValues) |
| setUseful(incomingValue, i); |
| } |
| } |
| } |
| postDomOrder.pushChildren(bb); |
| } |
| } |
| |
| void DifferentiableActivityInfo::setVariedAcrossArrayInitialization( |
| SILValue value, unsigned independentVariableIndex) { |
| auto uai = getAllocateUninitializedArrayIntrinsic(value); |
| if (!uai) return; |
| for (auto use : value->getUses()) |
| if (auto *dti = dyn_cast<DestructureTupleInst>(use->getUser())) |
| // The first tuple field of the intrinsic's return value is the array. |
| setVaried(dti->getResult(0), independentVariableIndex); |
| } |
| |
| void DifferentiableActivityInfo::setUsefulAcrossArrayInitialization( |
| SILValue value, unsigned dependentVariableIndex) { |
| // Array initializer syntax is lowered to an intrinsic and one or more |
| // stores to a `RawPointer` returned by the intrinsic. |
| auto uai = getAllocateUninitializedArrayIntrinsic(value); |
| if (!uai) return; |
| for (auto use : value->getUses()) { |
| auto dti = dyn_cast<DestructureTupleInst>(use->getUser()); |
| if (!dti) continue; |
| // The second tuple field of the return value is the `RawPointer`. |
| for (auto use : dti->getResult(1)->getUses()) { |
| // The `RawPointer` passes through a `pointer_to_address`. That |
| // instruction's first use is a `store` whose src is useful; its |
| // subsequent uses are `index_addr`s whose only use is a useful `store`. |
| for (auto use : use->getUser()->getResult(0)->getUses()) { |
| auto inst = use->getUser(); |
| if (auto si = dyn_cast<StoreInst>(inst)) { |
| setUseful(si->getSrc(), dependentVariableIndex); |
| } else if (auto iai = dyn_cast<IndexAddrInst>(inst)) { |
| for (auto use : iai->getUses()) |
| if (auto si = dyn_cast<StoreInst>(use->getUser())) |
| setUseful(si->getSrc(), dependentVariableIndex); |
| } |
| } |
| } |
| } |
| } |
| |
| void DifferentiableActivityInfo::setVaried(SILValue value, |
| unsigned independentVariableIndex) { |
| variedValueSets[independentVariableIndex].insert(value); |
| setVariedAcrossArrayInitialization(value, independentVariableIndex); |
| } |
| |
| void DifferentiableActivityInfo::setUseful(SILValue value, |
| unsigned dependentVariableIndex) { |
| usefulValueSets[dependentVariableIndex].insert(value); |
| setUsefulAcrossArrayInitialization(value, dependentVariableIndex); |
| } |
| |
| void DifferentiableActivityInfo::propagateVariedInwardsThroughProjections( |
| SILValue value, unsigned independentVariableIndex) { |
| #define SKIP_NODERIVATIVE(INST) \ |
| if (auto *sei = dyn_cast<INST##Inst>(value)) \ |
| if (sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \ |
| return; |
| SKIP_NODERIVATIVE(StructExtract) |
| SKIP_NODERIVATIVE(StructElementAddr) |
| #undef SKIP_NODERIVATIVE |
| setVaried(value, independentVariableIndex); |
| auto *inst = value->getDefiningInstruction(); |
| if (!inst || isa<ApplyInst>(inst)) |
| return; |
| // Standard propagation. |
| for (auto &op : inst->getAllOperands()) |
| propagateVariedInwardsThroughProjections( |
| op.get(), independentVariableIndex); |
| } |
| |
| void DifferentiableActivityInfo::propagateUsefulThroughBuffer( |
| SILValue value, unsigned dependentVariableIndex) { |
| assert(value->getType().isAddress() || |
| value->getType().isClassOrClassMetatype()); |
| // Check whether value is already useful to prevent infinite recursion. |
| if (isUseful(value, dependentVariableIndex)) |
| return; |
| setUseful(value, dependentVariableIndex); |
| if (auto *inst = value->getDefiningInstruction()) |
| for (auto &operand : inst->getAllOperands()) |
| if (operand.get()->getType().isAddress()) |
| propagateUsefulThroughBuffer(operand.get(), dependentVariableIndex); |
| // Recursively propagate usefulness through users that are projections or |
| // `begin_access` instructions. |
| for (auto use : value->getUses()) { |
| for (auto res : use->getUser()->getResults()) { |
| #define SKIP_NODERIVATIVE(INST) \ |
| if (auto *sei = dyn_cast<INST##Inst>(res)) \ |
| if (sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \ |
| continue; |
| SKIP_NODERIVATIVE(StructExtract) |
| SKIP_NODERIVATIVE(StructElementAddr) |
| #undef SKIP_NODERIVATIVE |
| if (Projection::isAddressProjection(res) || isa<BeginAccessInst>(res)) |
| propagateUsefulThroughBuffer(res, dependentVariableIndex); |
| } |
| } |
| } |
| |
| bool DifferentiableActivityInfo::isVaried( |
| SILValue value, unsigned independentVariableIndex) const { |
| assert(independentVariableIndex < variedValueSets.size() && |
| "Independent variable index out of range"); |
| auto &set = variedValueSets[independentVariableIndex]; |
| return set.count(value); |
| } |
| |
| bool DifferentiableActivityInfo::isVaried( |
| SILValue value, IndexSubset *parameterIndices) const { |
| for (auto paramIdx : parameterIndices->getIndices()) |
| if (isVaried(value, paramIdx)) |
| return true; |
| return false; |
| } |
| |
| bool DifferentiableActivityInfo::isUseful( |
| SILValue value, unsigned dependentVariableIndex) const { |
| assert(dependentVariableIndex < usefulValueSets.size() && |
| "Dependent variable index out of range"); |
| auto &set = usefulValueSets[dependentVariableIndex]; |
| return set.count(value); |
| } |
| |
| bool DifferentiableActivityInfo::isActive( |
| SILValue value, const SILAutoDiffIndices &indices) const { |
| return isVaried(value, indices.parameters) && isUseful(value, indices.source); |
| } |
| |
| Activity DifferentiableActivityInfo::getActivity( |
| SILValue value, const SILAutoDiffIndices &indices) const { |
| Activity activity; |
| if (isVaried(value, indices.parameters)) |
| activity |= ActivityFlags::Varied; |
| if (isUseful(value, indices.source)) |
| activity |= ActivityFlags::Useful; |
| return activity; |
| } |
| |
| Activity DifferentiableActivityInfo::getActivity( |
| SILInstruction *inst, const SILAutoDiffIndices &indices) const { |
| Activity activity; |
| for (auto result : inst->getResults()) |
| activity |= getActivity(result, indices); |
| return activity; |
| } |
| |
| static void dumpActivityInfo(SILValue value, |
| const SILAutoDiffIndices &indices, |
| const DifferentiableActivityInfo &activityInfo, |
| llvm::raw_ostream &s = llvm::dbgs()) { |
| s << '['; |
| auto activity = activityInfo.getActivity(value, indices); |
| switch (activity.toRaw()) { |
| case 0: s << "NONE"; break; |
| case (unsigned)ActivityFlags::Varied: s << "VARIED"; break; |
| case (unsigned)ActivityFlags::Useful: s << "USEFUL"; break; |
| case (unsigned)ActivityFlags::Active: s << "ACTIVE"; break; |
| } |
| s << "] " << value; |
| } |
| |
| static void dumpActivityInfo(SILFunction &fn, |
| const SILAutoDiffIndices &indices, |
| const DifferentiableActivityInfo &activityInfo, |
| llvm::raw_ostream &s = llvm::dbgs()) { |
| s << "Activity info for " << fn.getName() << " at " << indices << '\n'; |
| for (auto &bb : fn) { |
| s << "bb" << bb.getDebugID() << ":\n"; |
| for (auto *arg : bb.getArguments()) |
| dumpActivityInfo(arg, indices, activityInfo, s); |
| for (auto &inst : bb) |
| for (auto res : inst.getResults()) |
| dumpActivityInfo(res, indices, activityInfo, s); |
| s << '\n'; |
| } |
| } |
| |
| /// If the original function doesn't have a return, it cannot be differentiated. |
| /// Returns true if error is emitted. |
| static bool diagnoseNoReturn(ADContext &context, SILFunction *original, |
| DifferentiationInvoker invoker) { |
| if (original->findReturnBB() != original->end()) |
| return false; |
| context.emitNondifferentiabilityError( |
| original->getLocation().getEndSourceLoc(), invoker, |
| diag::autodiff_missing_return); |
| return true; |
| } |
| |
| /// If the original function contains unsupported control flow, emit a "control |
| /// flow unsupported" error at appropriate source locations. Returns true if |
| /// error is emitted. |
| /// |
| /// Update as control flow support is added. Currently, branching terminators |
| /// other than `br`, `cond_br`, `switch_enum` are not supported. |
| static bool diagnoseUnsupportedControlFlow(ADContext &context, |
| SILFunction *original, |
| DifferentiationInvoker invoker) { |
| if (original->getBlocks().size() <= 1) |
| return false; |
| // Diagnose unsupported branching terminators. |
| for (auto &bb : *original) { |
| auto *term = bb.getTerminator(); |
| // Supported terminators are: `br`, `cond_br`, `switch_enum`. |
| if (isa<BranchInst>(term) || isa<CondBranchInst>(term) || |
| isa<SwitchEnumInst>(term)) |
| continue; |
| // If terminator is an unsupported branching terminator, emit an error. |
| if (term->isBranch()) { |
| context.emitNondifferentiabilityError( |
| term, invoker, diag::autodiff_control_flow_not_supported); |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| /// Check whether the given requirements are satisfied, with the given |
| /// derivative generic signature (containing requirements), original function, |
| /// and substitution map. Returns true if error is emitted. |
| static bool diagnoseUnsatisfiedRequirements(ADContext &context, |
| GenericSignature derivativeGenSig, |
| SILFunction *original, |
| SubstitutionMap substMap, |
| DifferentiationInvoker invoker, |
| SourceLoc loc) { |
| // If there are no derivative requirements, return false. |
| if (!derivativeGenSig) |
| return false; |
| auto requirements = derivativeGenSig->getRequirements(); |
| if (requirements.empty()) |
| return false; |
| // Iterate through all requirements and check whether they are satisfied. |
| auto *swiftModule = context.getModule().getSwiftModule(); |
| SmallVector<Requirement, 2> unsatisfiedRequirements; |
| for (auto req : requirements) { |
| auto firstType = req.getFirstType(); |
| Type secondType; |
| // Substitute first and second types using the given substitution map, |
| // looking up conformances in the current module, if possible. |
| if (auto substFirstType = |
| firstType.subst(QuerySubstitutionMap{substMap}, |
| LookUpConformanceInModule(swiftModule))) { |
| firstType = substFirstType; |
| } |
| if (req.getKind() != RequirementKind::Layout) { |
| secondType = req.getSecondType(); |
| if (auto substSecondType = |
| secondType.subst(QuerySubstitutionMap{substMap}, |
| LookUpConformanceInModule(swiftModule))) { |
| secondType = substSecondType; |
| } |
| } |
| switch (req.getKind()) { |
| // Check layout requirements. |
| case RequirementKind::Layout: { |
| auto layout = req.getLayoutConstraint(); |
| switch (layout->getKind()) { |
| case LayoutConstraintKind::Class: |
| if (!firstType->satisfiesClassConstraint()) |
| unsatisfiedRequirements.push_back(req); |
| continue; |
| default: |
| // TODO: Check other layout requirements. Note that `@differentiable` |
| // attribute type-checking does not yet support layout requirements in |
| // where clauses; layout requirements in derivative generic signatures |
| // can be formed only from `differentiable_function` instructions whose |
| // original function operand is generic with layout requirements. |
| break; |
| } |
| continue; |
| } |
| // Check same type requirements. |
| case RequirementKind::SameType: |
| // If the first type does not equal the second type, then record the |
| // unsatisfied requirement. |
| if (!firstType->isEqual(secondType)) |
| unsatisfiedRequirements.push_back(req); |
| continue; |
| // Check superclass requirements. |
| case RequirementKind::Superclass: { |
| // If the second type is not an exact superclass of second type, then |
| // record the unsatisfied requirement. |
| if (!secondType->isExactSuperclassOf(firstType)) |
| unsatisfiedRequirements.push_back(req); |
| continue; |
| } |
| // Check conformance requirements. |
| case RequirementKind::Conformance: { |
| auto protocolType = req.getSecondType()->castTo<ProtocolType>(); |
| auto protocol = protocolType->getDecl(); |
| assert(protocol && "Expected protocol in generic signature requirement"); |
| // If the first type does not conform to the second type in the current |
| // module, then record the unsatisfied requirement. |
| if (!swiftModule->lookupConformance(firstType, protocol)) |
| unsatisfiedRequirements.push_back(req); |
| continue; |
| } |
| } |
| } |
| if (unsatisfiedRequirements.empty()) |
| return false; |
| // Diagnose unsatisfied requirements. |
| std::string reqText; |
| llvm::raw_string_ostream stream(reqText); |
| interleave(unsatisfiedRequirements, |
| [&](Requirement req) { req.print(stream, PrintOptions()); }, |
| [&] { stream << ", "; }); |
| context.emitNondifferentiabilityError( |
| loc, invoker, diag::autodiff_function_assoc_func_unmet_requirements, |
| stream.str()); |
| return true; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Code emission utilities |
| //===----------------------------------------------------------------------===// |
| |
| /// Given a value, extracts all elements to `results` from this value if it has |
| /// a tuple type. Otherwise, add this value directly to `results`. |
| static void extractAllElements(SILValue value, SILBuilder &builder, |
| SmallVectorImpl<SILValue> &results) { |
| auto tupleType = value->getType().getAs<TupleType>(); |
| if (!tupleType) { |
| results.push_back(value); |
| return; |
| } |
| if (builder.hasOwnership()) { |
| auto *dti = builder.createDestructureTuple(value.getLoc(), value); |
| results.append(dti->getResults().begin(), dti->getResults().end()); |
| return; |
| } |
| for (auto i : range(tupleType->getNumElements())) |
| results.push_back(builder.createTupleExtract(value.getLoc(), value, i)); |
| } |
| |
| /// Given a range of elements, joins these into a single value. If there's |
| /// exactly one element, returns that element. Otherwise, creates a tuple using |
| /// a `tuple` instruction. |
| static SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder, |
| SILLocation loc) { |
| if (elements.size() == 1) |
| return elements.front(); |
| return builder.createTuple(loc, elements); |
| } |
| |
| /// Given an apply site, emit copies of all parameters and place them in |
| /// `copiedArgs`. Any buffers that need to be destroyed will be added to |
| /// `newArgsToDestroy`. Any new buffers that need to be deallocated will be |
| /// added to `newBuffersToDealloc`. This helper is used for duplicating an |
| /// apply site. |
| static void copyParameterArgumentsForApply( |
| ApplySite applySite, SmallVectorImpl<SILValue> &copiedArgs, |
| SmallVectorImpl<SILValue> &newArgsToDestroy, |
| SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) { |
| LLVM_DEBUG({ |
| auto &s = getADDebugStream() << "Copying arguments from apply site: "; |
| applySite.getInstruction()->print(s); |
| }); |
| auto loc = applySite.getLoc(); |
| copiedArgs.reserve(applySite.getNumArguments()); |
| SILBuilder copyBuilder(applySite.getInstruction()); |
| for (auto &argOperand : applySite.getArgumentOperands()) { |
| auto arg = argOperand.get(); |
| auto argConv = applySite.getArgumentConvention(argOperand); |
| auto collectNewArg = [&](SILValue newArg) { |
| copiedArgs.push_back(newArg); |
| if (argConv.isGuaranteedConvention() && |
| argConv != SILArgumentConvention::Indirect_InoutAliasable) |
| newArgsToDestroy.push_back(newArg); |
| }; |
| // Copy the argument if it's to be owned by the newly created closure. |
| // Objects are to be retained. |
| if (arg->getType().isObject()) { |
| auto newArg = copyBuilder.emitCopyValueOperation(loc, arg); |
| collectNewArg(newArg); |
| continue; |
| } |
| // Addresses depend on argument conventions. |
| // If the argument is an aliasable inout reference, do not copy the |
| // argument since it's a `@noescape` capture. |
| if (argConv == SILArgumentConvention::Indirect_InoutAliasable) { |
| collectNewArg(arg); |
| continue; |
| } |
| // Otherwise, it must be address-only. Create a new buffer and perform |
| // `copy_addr`. |
| auto *argCopy = copyBuilder.createAllocStack(loc, arg->getType()); |
| newBuffersToDealloc.push_back(argCopy); |
| copyBuilder.createCopyAddr(loc, arg, argCopy, IsNotTake, |
| IsInitialization); |
| collectNewArg(argCopy); |
| } |
| } |
| |
| /// When a function value is used in an instruction (usually `apply`), there's |
| /// some conversion instruction in between, e.g. `thin_to_thick_function`. Given |
| /// a new function value and an old function value, this helper function |
| /// recursively converts the new function just like how the old function is |
| /// converted. If the new function's generic signature is specified, it is used |
| /// to create substitution maps for reapplied `partial_apply` instructions. |
| static SILValue |
| reapplyFunctionConversion( |
| SILValue newFunc, SILValue oldFunc, SILValue oldConvertedFunc, |
| SILBuilder &builder, SILLocation loc, |
| SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc, |
| GenericSignature newFuncGenSig = GenericSignature()) { |
| // If the old func is the new func, then there's no conversion. |
| if (oldFunc == oldConvertedFunc) |
| return newFunc; |
| // Handle a few instruction cases. |
| // thin_to_thick_function |
| if (auto *tttfi = dyn_cast<ThinToThickFunctionInst>(oldConvertedFunc)) { |
| auto innerNewFunc = reapplyFunctionConversion( |
| newFunc, oldFunc, tttfi->getOperand(), builder, loc, |
| newBuffersToDealloc, newFuncGenSig); |
| auto operandFnTy = innerNewFunc->getType().castTo<SILFunctionType>(); |
| auto thickTy = operandFnTy->getWithRepresentation( |
| SILFunctionTypeRepresentation::Thick); |
| auto silTy = SILType::getPrimitiveObjectType(thickTy); |
| return builder.createThinToThickFunction(loc, innerNewFunc, silTy); |
| } |
| // partial_apply |
| if (auto *pai = dyn_cast<PartialApplyInst>(oldConvertedFunc)) { |
| SmallVector<SILValue, 8> newArgs; |
| newArgs.reserve(pai->getNumArguments()); |
| SmallVector<SILValue, 1> newArgsToDestroy; |
| copyParameterArgumentsForApply(pai, newArgs, newArgsToDestroy, |
| newBuffersToDealloc); |
| auto innerNewFunc = reapplyFunctionConversion( |
| newFunc, oldFunc, pai->getCallee(), builder, loc, newBuffersToDealloc, |
| newFuncGenSig); |
| // If new function's generic signature is specified, use it to create |
| // substitution map for reapplied `partial_apply` instruction. |
| auto substMap = !newFuncGenSig |
| ? pai->getSubstitutionMap() |
| : SubstitutionMap::get( |
| newFuncGenSig, QuerySubstitutionMap{pai->getSubstitutionMap()}, |
| LookUpConformanceInModule(builder.getModule().getSwiftModule())); |
| return builder.createPartialApply(loc, innerNewFunc, substMap, newArgs, |
| ParameterConvention::Direct_Guaranteed); |
| } |
| llvm_unreachable("Unhandled function conversion instruction"); |
| } |
| |
| /// Emits a reference to a derivative function of `original`, differentiated |
| /// with respect to a superset of `desiredIndices`. Returns the `SILValue` for |
| /// the derivative function and the actual indices that the derivative function |
| /// is with respect to. |
| /// |
| /// Returns `None` on failure, signifying that a diagnostic has been emitted. |
| /// |
| /// Creates new differentiation tasks, if necessary, using `invoker` as the |
| /// invoker. Calls `taskCallback` for all newly-created tasks (but may also call |
| /// `taskCallback` for already-existing tasks), so that the caller can make sure |
| /// that the task actually gets executed. |
| /// |
| /// FIXME: This is too complicated and needs to be rewritten. |
| static Optional<std::pair<SILValue, SILAutoDiffIndices>> |
| emitDerivativeFunctionReference( |
| ADContext &context, SILBuilder &builder, SILAutoDiffIndices desiredIndices, |
| AutoDiffDerivativeFunctionKind kind, SILValue original, |
| DifferentiationInvoker invoker, |
| SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) { |
| |
| SILValue functionSource = original; |
| |
| // If `original` is itself an `DifferentiableFunctionExtractInst` whose kind matches |
| // the given kind and desired differentiation parameter indices, simply |
| // extract the derivative function of its function operand, retain the |
| // derivative function, and return it. |
| if (auto *inst = original->getDefiningInstruction()) |
| if (auto *dfei = dyn_cast<DifferentiableFunctionExtractInst>(inst)) |
| if (dfei->getExtractee() == |
| NormalDifferentiableFunctionTypeComponent::Original) |
| functionSource = dfei->getFunctionOperand(); |
| |
| // If `functionSource` is a `@differentiable` function, just extract the |
| // derivative function. |
| if (auto diffableFnType = |
| functionSource->getType().castTo<SILFunctionType>()) { |
| if (diffableFnType->isDifferentiable()) { |
| auto paramIndices = diffableFnType->getDifferentiationParameterIndices(); |
| for (auto i : desiredIndices.parameters->getIndices()) { |
| if (!paramIndices->contains(i)) { |
| context.emitNondifferentiabilityError(functionSource, invoker, |
| diag::autodiff_function_nondiff_parameter_not_differentiable); |
| return None; |
| } |
| } |
| auto borrowedDiffFunc = builder.emitBeginBorrowOperation( |
| functionSource.getLoc(), functionSource); |
| SILValue derivativeFn = builder.createDifferentiableFunctionExtract( |
| borrowedDiffFunc.getLoc(), kind, borrowedDiffFunc); |
| derivativeFn = |
| builder.emitCopyValueOperation(functionSource.getLoc(), derivativeFn); |
| builder.emitEndBorrowOperation(functionSource.getLoc(), borrowedDiffFunc); |
| SILAutoDiffIndices indices(0, desiredIndices.parameters); |
| return std::make_pair(derivativeFn, indices); |
| } |
| } |
| |
| // Find local function reference. |
| if (auto *originalFRI = |
| peerThroughFunctionConversions<FunctionRefInst>(original)) { |
| auto loc = originalFRI->getLoc(); |
| auto *originalFn = originalFRI->getReferencedFunctionOrNull(); |
| // Attempt to look up a `[differentiable]` attribute that minimally |
| // satisfies the specified indices. |
| // TODO(TF-482): Change `lookUpMinimalDifferentiableAttr` to additionally |
| // check whether `[differentiable]` attribute generic requirements are |
| // satisfied. |
| auto *minimalAttr = |
| context.lookUpMinimalDifferentiableAttr(originalFn, desiredIndices); |
| if (!minimalAttr) { |
| // If the function is intentionally marked as being opaque to |
| // differentiation, then we should not create a task for it. |
| if (originalFn->hasSemanticsAttr("autodiff.opaque")) { |
| context.emitNondifferentiabilityError(original, invoker, |
| diag::autodiff_opaque_function_not_differentiable); |
| return None; |
| } |
| // Check and diagnose non-differentiable arguments. |
| auto originalFnTy = originalFn->getLoweredFunctionType(); |
| for (unsigned paramIndex : range(originalFnTy->getNumParameters())) { |
| if (desiredIndices.isWrtParameter(paramIndex) && |
| !originalFnTy->getParameters()[paramIndex] |
| .getSILStorageType() |
| .isDifferentiable(context.getModule())) { |
| auto diag = context.emitNondifferentiabilityError( |
| original, invoker, diag::autodiff_nondifferentiable_argument); |
| return None; |
| } |
| } |
| // Check and diagnose non-differentiable results. |
| if (!originalFnTy->getResults()[desiredIndices.source] |
| .getSILStorageType() |
| .isDifferentiable(context.getModule())) { |
| context.emitNondifferentiabilityError( |
| original, invoker, diag::autodiff_nondifferentiable_result); |
| return None; |
| } |
| // Check and diagnose external declarations. |
| if (originalFn->isExternalDeclaration()) { |
| context.emitNondifferentiabilityError( |
| original, invoker, |
| diag::autodiff_external_nondifferentiable_function); |
| return None; |
| } |
| // Sanity check passed. Create a new `[differentiable]` attribute and |
| // process it it. |
| GenericSignature contextualDerivativeGenSig = GenericSignature(); |
| if (invoker.getKind() == |
| DifferentiationInvoker::Kind::IndirectDifferentiation) |
| contextualDerivativeGenSig = invoker.getIndirectDifferentiation().second |
| ->getDerivativeGenericSignature(); |
| auto *newAttr = context.getOrCreateDifferentiableAttr( |
| originalFn, desiredIndices, contextualDerivativeGenSig); |
| if (context.processDifferentiableAttribute(originalFn, newAttr, invoker)) |
| return None; |
| minimalAttr = newAttr; |
| } |
| assert(minimalAttr); |
| // TODO(TF-482): Move generic requirement checking logic to |
| // `lookUpMinimalDifferentiableAttr`. |
| // Get the substitution map for checking unmet generic requirements. |
| // By default, use the forwarding substitution map of the original function. |
| // If the original callee is a `partial_apply` or `apply` instruction, use |
| // its substitution map instead. |
| auto substMap = original->getFunction()->getForwardingSubstitutionMap(); |
| if (auto *pai = dyn_cast<PartialApplyInst>(original)) { |
| substMap = pai->getSubstitutionMap(); |
| } else if (auto *ai = dyn_cast<ApplyInst>(original)) { |
| substMap = ai->getSubstitutionMap(); |
| } |
| if (diagnoseUnsatisfiedRequirements( |
| context, minimalAttr->getDerivativeGenericSignature(), originalFn, |
| substMap, invoker, original.getLoc().getSourceLoc())) |
| return None; |
| if (context.processDifferentiableAttribute( |
| originalFn, minimalAttr, invoker)) |
| return None; |
| SILFunction *derivativeFn = nullptr; |
| switch (kind) { |
| case AutoDiffDerivativeFunctionKind::JVP: |
| assert(!minimalAttr->getJVPName().empty() && "Expected JVP name"); |
| derivativeFn = context.getModule().lookUpFunction(minimalAttr->getJVPName()); |
| break; |
| case AutoDiffDerivativeFunctionKind::VJP: |
| assert(!minimalAttr->getVJPName().empty() && "Expected VJP name"); |
| derivativeFn = context.getModule().lookUpFunction(minimalAttr->getVJPName()); |
| break; |
| } |
| auto *derivativeFnRef = builder.createFunctionRef(loc, derivativeFn); |
| // FIXME(TF-201): Handle direct differentiation of reabstraction thunks. |
| // Tentative solution: clone a new reabstraction thunk where function |
| // argument has a `@differentiable` function type. |
| if (originalFn->isThunk() == IsReabstractionThunk) { |
| // Handle here. |
| } |
| auto convertedRef = reapplyFunctionConversion( |
| derivativeFnRef, originalFRI, original, builder, loc, |
| newBuffersToDealloc, |
| derivativeFn->getLoweredFunctionType()->getGenericSignature()); |
| return std::make_pair(convertedRef, minimalAttr->getIndices()); |
| } |
| |
| // Find witness method retrieval. |
| if (auto *witnessMethod = |
| peerThroughFunctionConversions<WitnessMethodInst>(original)) { |
| auto loc = witnessMethod->getLoc(); |
| auto requirementDeclRef = witnessMethod->getMember(); |
| auto *requirementDecl = requirementDeclRef.getDecl(); |
| auto witnessMethodType = witnessMethod->getType().castTo<SILFunctionType>(); |
| // If requirement declaration does not have any `@differentiable` |
| // attributes, produce an error. |
| if (!requirementDecl->getAttrs().hasAttribute<DifferentiableAttr>()) { |
| context.emitNondifferentiabilityError( |
| original, invoker, diag::autodiff_protocol_member_not_differentiable); |
| return None; |
| } |
| // Get the minimal `@differentiable` attribute and parameter index subset. |
| const DifferentiableAttr *minimalAttr; |
| IndexSubset *minimalParamIndexSet; |
| std::tie(minimalAttr, minimalParamIndexSet) = |
| context.lookUpMinimalASTDifferentiableAttrAndIndexSubset( |
| requirementDeclRef, witnessMethodType, desiredIndices); |
| SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet); |
| // If minimal `@differentiable` attribute does not exist, then no attribute |
| // exists with a superset of the desired indices. Produce an error. |
| if (!minimalAttr) { |
| context.emitNondifferentiabilityError( |
| original, invoker, |
| diag::autodiff_member_subset_indices_not_differentiable); |
| return None; |
| } |
| // Emit a `witness_method` instruction for the derivative function. |
| auto originalType = witnessMethod->getType().castTo<SILFunctionType>(); |
| auto assocType = originalType->getAutoDiffDerivativeFunctionType( |
| minimalIndices.parameters, minimalIndices.source, |
| kind, context.getTypeConverter(), |
| LookUpConformanceInModule(builder.getModule().getSwiftModule())); |
| auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get( |
| kind, minimalAttr->getParameterIndices(), context.getASTContext()); |
| auto *ref = builder.createWitnessMethod( |
| loc, witnessMethod->getLookupType(), witnessMethod->getConformance(), |
| requirementDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId), |
| SILType::getPrimitiveObjectType(assocType)); |
| auto convertedRef = |
| reapplyFunctionConversion(ref, witnessMethod, original, builder, loc, |
| newBuffersToDealloc); |
| return std::make_pair(convertedRef, minimalIndices); |
| } |
| |
| // Find class method. |
| if (auto *classMethodInst = |
| peerThroughFunctionConversions<ClassMethodInst>(original)) { |
| auto loc = classMethodInst->getLoc(); |
| auto methodDeclRef = classMethodInst->getMember(); |
| auto *methodDecl = methodDeclRef.getDecl(); |
| auto classMethodType = classMethodInst->getType().castTo<SILFunctionType>(); |
| // If method declaration does not have any `@differentiable` attributes, |
| // produce an error. |
| if (!methodDecl->getAttrs().hasAttribute<DifferentiableAttr>()) { |
| context.emitNondifferentiabilityError( |
| original, invoker, diag::autodiff_class_member_not_differentiable); |
| return None; |
| } |
| // Get the minimal `@differentiable` attribute and parameter index subset. |
| const DifferentiableAttr *minimalAttr; |
| IndexSubset *minimalParamIndexSet; |
| std::tie(minimalAttr, minimalParamIndexSet) = |
| context.lookUpMinimalASTDifferentiableAttrAndIndexSubset( |
| methodDeclRef, classMethodType, desiredIndices); |
| SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet); |
| // If minimal `@differentiable` attribute does not exist, then no attribute |
| // exists with a superset of the desired indices. Produce an error. |
| if (!minimalAttr) { |
| context.emitNondifferentiabilityError( |
| original, invoker, |
| diag::autodiff_member_subset_indices_not_differentiable); |
| return None; |
| } |
| // Emit a `class_method` instruction for the derivative function. |
| auto originalType = classMethodInst->getType().castTo<SILFunctionType>(); |
| auto assocType = originalType->getAutoDiffDerivativeFunctionType( |
| minimalIndices.parameters, minimalIndices.source, |
| kind, context.getTypeConverter(), |
| LookUpConformanceInModule(builder.getModule().getSwiftModule())); |
| auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get( |
| kind, minimalAttr->getParameterIndices(), |
| context.getASTContext()); |
| auto *ref = builder.createClassMethod( |
| loc, classMethodInst->getOperand(), |
| methodDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId), |
| SILType::getPrimitiveObjectType(assocType)); |
| auto convertedRef = |
| reapplyFunctionConversion(ref, classMethodInst, original, builder, loc, |
| newBuffersToDealloc); |
| return std::make_pair(convertedRef, minimalIndices); |
| } |
| |
| // Emit the general opaque function error. |
| context.emitNondifferentiabilityError(original, invoker, |
| diag::autodiff_opaque_function_not_differentiable); |
| return None; |
| } |
| |
| /// Emit a zero value into the given buffer access by calling |
| /// `AdditiveArithmetic.zero`. The given type must conform to |
| /// `AdditiveArithmetic`. |
| static void emitZeroIntoBuffer( |
| SILBuilder &builder, CanType type, SILValue bufferAccess, |
| SILLocation loc) { |
| auto &astCtx = builder.getASTContext(); |
| auto *swiftMod = builder.getModule().getSwiftModule(); |
| auto &typeConverter = builder.getModule().Types; |
| // Look up conformance to `AdditiveArithmetic`. |
| auto *additiveArithmeticProto = |
| astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic); |
| auto confRef = swiftMod->lookupConformance(type, additiveArithmeticProto); |
| assert(confRef.hasValue() && "Missing conformance to `AdditiveArithmetic`"); |
| // Look up `AdditiveArithmetic.zero.getter`. |
| auto zeroDeclLookup = additiveArithmeticProto->lookupDirect(astCtx.Id_zero); |
| auto *zeroDecl = cast<VarDecl>(zeroDeclLookup.front()); |
| assert(zeroDecl->isProtocolRequirement()); |
| auto *accessorDecl = zeroDecl->getAccessor(AccessorKind::Get); |
| SILDeclRef accessorDeclRef(accessorDecl, SILDeclRef::Kind::Func); |
| auto silFnType = typeConverter.getConstantType(accessorDeclRef); |
| // %wm = witness_method ... |
| auto *getter = builder.createWitnessMethod( |
| loc, type, *confRef, accessorDeclRef, silFnType); |
| // %metatype = metatype $T |
| auto metatypeType = CanMetatypeType::get( |
| type, MetatypeRepresentation::Thick); |
| auto metatype = builder.createMetatype( |
| loc, SILType::getPrimitiveObjectType(metatypeType)); |
| auto subMap = SubstitutionMap::getProtocolSubstitutions( |
| additiveArithmeticProto, type, *confRef); |
| builder.createApply(loc, getter, subMap, {bufferAccess, metatype}, |
| /*isNonThrowing*/ false); |
| builder.emitDestroyValueOperation(loc, getter); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Thunk helpers |
| //===----------------------------------------------------------------------===// |
| // These helpers are copied/adapted from SILGen. They should be refactored and |
| // moved to a shared location. |
| //===----------------------------------------------------------------------===// |
| |
| static CanGenericSignature |
| buildThunkSignature(SILFunction *fn, |
| bool inheritGenericSig, |
| OpenedArchetypeType *openedExistential, |
| GenericEnvironment *&genericEnv, |
| SubstitutionMap &contextSubs, |
| SubstitutionMap &interfaceSubs, |
| ArchetypeType *&newArchetype) { |
| // If there's no opened existential, we just inherit the generic environment |
| // from the parent function. |
| if (openedExistential == nullptr) { |
| auto genericSig = fn->getLoweredFunctionType()->getGenericSignature(); |
| genericEnv = fn->getGenericEnvironment(); |
| interfaceSubs = fn->getForwardingSubstitutionMap(); |
| contextSubs = interfaceSubs; |
| return genericSig; |
| } |
| |
| auto &ctx = fn->getASTContext(); |
| GenericSignatureBuilder builder(ctx); |
| |
| // Add the existing generic signature. |
| int depth = 0; |
| if (inheritGenericSig) { |
| if (auto genericSig = |
| fn->getLoweredFunctionType()->getGenericSignature()) { |
| builder.addGenericSignature(genericSig); |
| depth = genericSig->getGenericParams().back()->getDepth() + 1; |
| } |
| } |
| |
| // Add a new generic parameter to replace the opened existential. |
| auto *newGenericParam = GenericTypeParamType::get(depth, 0, ctx); |
| |
| builder.addGenericParameter(newGenericParam); |
| Requirement newRequirement(RequirementKind::Conformance, newGenericParam, |
| openedExistential->getOpenedExistentialType()); |
| auto source = |
| GenericSignatureBuilder::FloatingRequirementSource::forAbstract(); |
| builder.addRequirement(newRequirement, source, nullptr); |
| |
| auto genericSig = std::move(builder).computeGenericSignature( |
| SourceLoc(), /*allowConcreteGenericParams=*/true); |
| genericEnv = genericSig->getGenericEnvironment(); |
| |
| newArchetype = genericEnv->mapTypeIntoContext(newGenericParam) |
| ->castTo<ArchetypeType>(); |
| |
| // Calculate substitutions to map the caller's archetypes to the thunk's |
| // archetypes. |
| if (auto calleeGenericSig = |
| fn->getLoweredFunctionType()->getGenericSignature()) { |
| contextSubs = SubstitutionMap::get( |
| calleeGenericSig, |
| [&](SubstitutableType *type) -> Type { |
| return genericEnv->mapTypeIntoContext(type); |
| }, |
| MakeAbstractConformanceForGenericType()); |
| } |
| |
| // Calculate substitutions to map interface types to the caller's archetypes. |
| interfaceSubs = SubstitutionMap::get( |
| genericSig, |
| [&](SubstitutableType *type) -> Type { |
| if (type->isEqual(newGenericParam)) |
| return openedExistential; |
| return fn->mapTypeIntoContext(type); |
| }, |
| MakeAbstractConformanceForGenericType()); |
| |
| return genericSig->getCanonicalSignature(); |
| |
| } |
| |
/// The kinds of thunks created by the differentiation transform.
enum class DifferentiationThunkKind {
  /// A reabstraction thunk.
  ///
  /// A reabstraction thunk transforms a function-typed value into another one
  /// with different parameter/result abstraction patterns. This is identical
  /// to the thunks generated by SILGen.
  Reabstraction,

  /// An index subset thunk.
  ///
  /// An index subset thunk is used to transform JVPs/VJPs into a version that
  /// is "wrt" fewer differentiation parameters.
  /// - Differentials of thunked JVPs use zero for non-requested
  ///   differentiation parameters.
  /// - Pullbacks of thunked VJPs discard results for non-requested
  ///   differentiation parameters.
  IndexSubset
};
| |
| /// Build the type of a function transformation thunk. |
| static CanSILFunctionType buildThunkType(SILFunction *fn, |
| CanSILFunctionType &sourceType, |
| CanSILFunctionType &expectedType, |
| GenericEnvironment *&genericEnv, |
| SubstitutionMap &interfaceSubs, |
| bool withoutActuallyEscaping, |
| DifferentiationThunkKind thunkKind) { |
| assert(!expectedType->isPolymorphic()); |
| assert(!sourceType->isPolymorphic()); |
| |
| auto &module = fn->getModule(); |
| auto origType = sourceType; |
| |
| // Cannot build a reabstraction thunk without context. Ownership semantics |
| // on the result type are required. |
| if (thunkKind == DifferentiationThunkKind::Reabstraction) |
| assert(expectedType->getExtInfo().hasContext()); |
| |
| // This may inherit @noescape from the expected type. The `@noescape` |
| // attribute is only stripped when using this type to materialize a new decl. |
| // Use `@convention(thin)` if: |
| // - Building a reabstraction thunk type. |
| // - Building an index subset thunk type, where the expected type has context |
| // (i.e. is `@convention(thick)`). |
| auto extInfo = expectedType->getExtInfo(); |
| if (thunkKind == DifferentiationThunkKind::Reabstraction || |
| extInfo.hasContext()) { |
| extInfo = extInfo.withRepresentation( |
| SILFunctionType::Representation::Thin); |
| } |
| if (withoutActuallyEscaping) |
| extInfo = extInfo.withNoEscape(false); |
| |
| // Does the thunk type involve archetypes other than opened existentials? |
| bool hasArchetypes = false; |
| // Does the thunk type involve an open existential type? |
| CanOpenedArchetypeType openedExistential; |
| auto archetypeVisitor = [&](CanType t) { |
| if (auto archetypeTy = dyn_cast<OpenedArchetypeType>(t)) { |
| if (auto opened = dyn_cast<OpenedArchetypeType>(archetypeTy)) { |
| assert((openedExistential == CanArchetypeType() || |
| openedExistential == opened) && |
| "one too many open existentials"); |
| openedExistential = opened; |
| } else { |
| hasArchetypes = true; |
| } |
| } |
| }; |
| |
| // Use the generic signature from the context if the thunk involves |
| // generic parameters. |
| CanGenericSignature genericSig; |
| SubstitutionMap contextSubs; |
| ArchetypeType *newArchetype = nullptr; |
| |
| if (expectedType->hasArchetype() || sourceType->hasArchetype()) { |
| expectedType.visit(archetypeVisitor); |
| sourceType.visit(archetypeVisitor); |
| genericSig = buildThunkSignature( |
| fn, hasArchetypes, openedExistential, genericEnv, contextSubs, |
| interfaceSubs, newArchetype); |
| } |
| |
| // Utility function to apply contextSubs, and also replace the |
| // opened existential with the new archetype. |
| auto substIntoThunkContext = [&](CanType t) -> CanType { |
| return t.subst( |
| [&](SubstitutableType *type) -> Type { |
| if (CanType(type) == openedExistential) |
| return newArchetype; |
| return Type(type).subst(contextSubs); |
| }, |
| LookUpConformanceInSubstitutionMap(contextSubs), |
| SubstFlags::AllowLoweredTypes)->getCanonicalType(); |
| }; |
| |
| sourceType = cast<SILFunctionType>(substIntoThunkContext(sourceType)); |
| expectedType = cast<SILFunctionType>(substIntoThunkContext(expectedType)); |
| |
| // If our parent function was pseudogeneric, this thunk must also be |
| // pseudogeneric, since we have no way to pass generic parameters. |
| if (genericSig) |
| if (origType->isPseudogeneric()) |
| extInfo = extInfo.withIsPseudogeneric(); |
| |
| // Add the function type as the parameter. |
| auto contextConvention = |
| SILType::getPrimitiveObjectType(sourceType).isTrivial(*fn) |
| ? ParameterConvention::Direct_Unowned |
| : ParameterConvention::Direct_Guaranteed; |
| SmallVector<SILParameterInfo, 4> params; |
| params.append(expectedType->getParameters().begin(), |
| expectedType->getParameters().end()); |
| // Add reabstraction function parameter only if building a reabstraction thunk |
| // type. |
| if (thunkKind == DifferentiationThunkKind::Reabstraction) |
| params.push_back({sourceType, sourceType->getExtInfo().hasContext() |
| ? contextConvention |
| : ParameterConvention::Direct_Unowned}); |
| |
| // Map the parameter and expected types out of context to get the interface |
| // type of the thunk. |
| SmallVector<SILParameterInfo, 4> interfaceParams; |
| interfaceParams.reserve(params.size()); |
| for (auto ¶m : params) { |
| auto paramIfaceTy = param.getType()->mapTypeOutOfContext(); |
| interfaceParams.push_back(SILParameterInfo( |
| paramIfaceTy->getCanonicalType(genericSig), param.getConvention())); |
| } |
| |
| SmallVector<SILYieldInfo, 4> interfaceYields; |
| for (auto &yield : expectedType->getYields()) { |
| auto yieldIfaceTy = yield.getType()->mapTypeOutOfContext(); |
| auto interfaceYield = |
| yield.getWithType(yieldIfaceTy->getCanonicalType(genericSig)); |
| interfaceYields.push_back(interfaceYield); |
| } |
| |
| SmallVector<SILResultInfo, 4> interfaceResults; |
| for (auto &result : expectedType->getResults()) { |
| auto resultIfaceTy = result.getType()->mapTypeOutOfContext(); |
| auto interfaceResult = |
| result.getWithType(resultIfaceTy->getCanonicalType(genericSig)); |
| interfaceResults.push_back(interfaceResult); |
| } |
| |
| Optional<SILResultInfo> interfaceErrorResult; |
| if (expectedType->hasErrorResult()) { |
| auto errorResult = expectedType->getErrorResult(); |
| auto errorIfaceTy = errorResult.getType()->mapTypeOutOfContext(); |
| interfaceErrorResult = |
| SILResultInfo(errorIfaceTy->getCanonicalType(genericSig), |
| expectedType->getErrorResult().getConvention()); |
| } |
| |
| // The type of the thunk function. |
| return SILFunctionType::get( |
| genericSig, extInfo, expectedType->getCoroutineKind(), |
| ParameterConvention::Direct_Unowned, interfaceParams, interfaceYields, |
| interfaceResults, interfaceErrorResult, module.getASTContext()); |
| } |
| |
| /// Get or create a reabstraction thunk from `fromType` to `toType`, to be |
| /// called in `caller`. |
| static SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb, |
| SILModule &module, |
| SILLocation loc, |
| SILFunction *caller, |
| CanSILFunctionType fromType, |
| CanSILFunctionType toType) { |
| SubstitutionMap interfaceSubs; |
| GenericEnvironment *genericEnv = nullptr; |
| auto thunkType = buildThunkType( |
| caller, fromType, toType, genericEnv, interfaceSubs, |
| /*withoutActuallyEscaping*/ false, |
| DifferentiationThunkKind::Reabstraction); |
| auto thunkDeclType = |
| thunkType->getWithExtInfo(thunkType->getExtInfo().withNoEscape(false)); |
| |
| auto fromInterfaceType = fromType->mapTypeOutOfContext()->getCanonicalType(); |
| auto toInterfaceType = toType->mapTypeOutOfContext()->getCanonicalType(); |
| |
| Mangle::ASTMangler mangler; |
| std::string name = mangler.mangleReabstractionThunkHelper( |
| thunkType, fromInterfaceType, toInterfaceType, |
| Type(), module.getSwiftModule()); |
| |
| auto *thunk = fb.getOrCreateSharedFunction( |
| loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized, |
| ProfileCounter(), IsReabstractionThunk, IsNotDynamic); |
| if (!thunk->empty()) |
| return thunk; |
| |
| thunk->setGenericEnvironment(genericEnv); |
| thunk->setOwnershipEliminated(); |
| auto *entry = thunk->createBasicBlock(); |
| SILBuilder builder(entry); |
| createEntryArguments(thunk); |
| |
| SILFunctionConventions fromConv(fromType, module); |
| SILFunctionConventions toConv(toType, module); |
| assert(toConv.useLoweredAddresses()); |
| |
| auto *fnArg = thunk->getArgumentsWithoutIndirectResults().back(); |
| |
| SmallVector<SILValue, 4> arguments; |
| auto toArgIter = thunk->getArguments().begin(); |
| auto useNextArgument = [&]() { |
| arguments.push_back(*toArgIter++); |
| }; |
| |
| SmallVector<AllocStackInst *, 4> localAllocations; |
| auto createAllocStack = [&](SILType type) { |
| auto *alloc = builder.createAllocStack(loc, type); |
| localAllocations.push_back(alloc); |
| return alloc; |
| }; |
| |
| // Handle indirect results. |
| assert(fromType->getNumResults() == toType->getNumResults()); |
| for (unsigned resIdx : range(toType->getNumResults())) { |
| auto fromRes = fromConv.getResults()[resIdx]; |
| auto toRes = toConv.getResults()[resIdx]; |
| // No abstraction mismatch. |
| if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) { |
| // If result types are indirect, directly pass as next argument. |
| if (toRes.isFormalIndirect()) |
| useNextArgument(); |
| continue; |
| } |
| // Convert indirect result to direct result. |
| if (fromRes.isFormalIndirect()) { |
| SILType resultTy = fromConv.getSILType(fromRes); |
| assert(resultTy.isAddress()); |
| auto *indRes = createAllocStack(resultTy); |
| arguments.push_back(indRes); |
| continue; |
| } |
| // Convert direct result to indirect result. |
| // Increment thunk argument iterator; reabstraction handled later. |
| toArgIter++; |
| } |
| |
| // Reabstract parameters. |
| assert(toType->getNumParameters() == fromType->getNumParameters()); |
| for (unsigned paramIdx : range(toType->getNumParameters())) { |
| auto fromParam = fromConv.getParameters()[paramIdx]; |
| auto toParam = toConv.getParameters()[paramIdx]; |
| // No abstraction mismatch. Directly use next argument. |
| if (fromParam.isFormalIndirect() == toParam.isFormalIndirect()) { |
| useNextArgument(); |
| continue; |
| } |
| // Convert indirect parameter to direct parameter. |
| if (fromParam.isFormalIndirect()) { |
| auto paramTy = fromConv.getSILType(fromType->getParameters()[paramIdx]); |
| if (!paramTy.hasArchetype()) |
| paramTy = thunk->mapTypeIntoContext(paramTy); |
| assert(paramTy.isAddress()); |
| auto *toArg = *toArgIter++; |
| auto *buf = createAllocStack(toArg->getType()); |
| builder.createStore(loc, toArg, buf, |
| StoreOwnershipQualifier::Unqualified); |
| arguments.push_back(buf); |
| continue; |
| } |
| // Convert direct parameter to indirect parameter. |
| assert(toParam.isFormalIndirect()); |
| auto *toArg = *toArgIter++; |
| auto *load = builder.createLoad(loc, toArg, |
| LoadOwnershipQualifier::Unqualified); |
| arguments.push_back(load); |
| } |
| |
| auto *apply = builder.createApply( |
| loc, fnArg, SubstitutionMap(), arguments, /*isNonThrowing*/ false); |
| |
| // Get return elements. |
| SmallVector<SILValue, 4> results; |
| // Extract all direct results. |
| SmallVector<SILValue, 4> directResults; |
| extractAllElements(apply, builder, directResults); |
| |
| auto fromDirResultsIter = directResults.begin(); |
| auto fromIndResultsIter = apply->getIndirectSILResults().begin(); |
| auto toIndResultsIter = thunk->getIndirectResults().begin(); |
| // Reabstract results. |
| for (unsigned resIdx : range(toType->getNumResults())) { |
| auto fromRes = fromConv.getResults()[resIdx]; |
| auto toRes = toConv.getResults()[resIdx]; |
| // No abstraction mismatch. |
| if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) { |
| // If result types are direct, add call result as direct thunk result. |
| if (toRes.isFormalDirect()) |
| results.push_back(*fromDirResultsIter++); |
| // If result types are indirect, increment indirect result iterators. |
| else { |
| ++fromIndResultsIter; |
| ++toIndResultsIter; |
| } |
| continue; |
| } |
| // Load direct results from indirect results. |
| if (fromRes.isFormalIndirect()) { |
| auto indRes = *fromIndResultsIter++; |
| auto *load = builder.createLoad(loc, indRes, |
| LoadOwnershipQualifier::Unqualified); |
| results.push_back(load); |
| continue; |
| } |
| // Store direct results to indirect results. |
| assert(toRes.isFormalIndirect()); |
| SILType resultTy = toConv.getSILType(toRes); |
| assert(resultTy.isAddress()); |
| auto indRes = *toIndResultsIter++; |
| builder.createStore(loc, *fromDirResultsIter++, indRes, |
| StoreOwnershipQualifier::Unqualified); |
| } |
| auto retVal = joinElements(results, builder, loc); |
| |
| // Deallocate local allocations. |
| for (auto *alloc : reversed(localAllocations)) |
| builder.createDeallocStack(loc, alloc); |
| |
| // Create return. |
| builder.createReturn(loc, retVal); |
| |
| LLVM_DEBUG(auto &s = getADDebugStream() << "Created reabstraction thunk.\n"; |
| s << " From type: " << fromType << '\n'; |
| s << " To type: " << toType << '\n'; |
| s << '\n' << *thunk); |
| |
| return thunk; |
| } |
| |
| namespace { |
| class VJPEmitter final |
| : public TypeSubstCloner<VJPEmitter, SILOptFunctionBuilder> { |
| friend class PullbackEmitter; |
| |
| private: |
| /// The global context. |
| ADContext &context; |
| |
| /// The original function. |
| SILFunction *const original; |
| |
| /// The `[differentiable]` attribute. |
| SILDifferentiableAttr *const attr; |
| |
| /// The VJP function. |
| SILFunction *const vjp; |
| |
| /// The pullback function. |
| SILFunction *pullback; |
| |
| /// The differentiation invoker. |
| DifferentiationInvoker invoker; |
| |
| /// Info from activity analysis on the original function. |
| const DifferentiableActivityInfo &activityInfo; |
| |
| /// The linear map info. |
| LinearMapInfo pullbackInfo; |
| |
| /// Caches basic blocks whose phi arguments have been remapped (adding a |
| /// predecessor enum argument). |
| SmallPtrSet<SILBasicBlock *, 4> remappedBasicBlocks; |
| |
| /// A pair of a trampoline block phi argument and its corresponding |
| /// destination block phi argument. |
| struct TrampolinedArgumentPair { |
| SILPhiArgument *trampolineArgument; |
| SILPhiArgument *destinationArgument; |
| }; |
| /// An array that keeps track of all `@guaranteed` phi arguments in any |
| /// trampoline blocks we've added. Each of these arguments needs to have a |
| /// lifetime-ending use past its destination argument's lifetime-ending use, |
| /// so we keep track of these pairs of arguments and emit `end_borrow`s when |
| /// function cloning is finished. |
| SmallVector<TrampolinedArgumentPair, 8> trampolinedGuaranteedPhiArguments; |
| |
| bool errorOccurred = false; |
| |
| /// Mapping from original blocks to pullback values. Used to build pullback |
| /// struct instances. |
| DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> pullbackValues; |
| |
| ASTContext &getASTContext() const { return vjp->getASTContext(); } |
| SILModule &getModule() const { return vjp->getModule(); } |
| const SILAutoDiffIndices &getIndices() const { return attr->getIndices(); } |
| |
| static SubstitutionMap getSubstitutionMap(SILFunction *original, |
| SILFunction *vjp) { |
| auto substMap = original->getForwardingSubstitutionMap(); |
| if (auto *vjpGenEnv = vjp->getGenericEnvironment()) { |
| auto vjpSubstMap = vjpGenEnv->getForwardingSubstitutionMap(); |
| substMap = SubstitutionMap::get( |
| vjpGenEnv->getGenericSignature(), QuerySubstitutionMap{vjpSubstMap}, |
| LookUpConformanceInSubstitutionMap(vjpSubstMap)); |
| } |
| return substMap; |
| } |
| |
| static const DifferentiableActivityInfo &getActivityInfo( |
| ADContext &context, SILFunction *original, |
| const SILAutoDiffIndices &indices, SILFunction *vjp) { |
| // Get activity info of the original function. |
| auto &passManager = context.getPassManager(); |
| auto *activityAnalysis = |
| passManager.getAnalysis<DifferentiableActivityAnalysis>(); |
| auto &activityCollection = *activityAnalysis->get(original); |
| auto &activityInfo = activityCollection.getActivityInfo( |
| vjp->getLoweredFunctionType()->getGenericSignature(), |
| AutoDiffDerivativeFunctionKind::VJP); |
| LLVM_DEBUG( |
| dumpActivityInfo(*original, indices, activityInfo, getADDebugStream())); |
| return activityInfo; |
| } |
| |
| public: |
| explicit VJPEmitter(ADContext &context, SILFunction *original, |
| SILDifferentiableAttr *attr, SILFunction *vjp, |
| DifferentiationInvoker invoker) |
| : TypeSubstCloner(*vjp, *original, getSubstitutionMap(original, vjp)), |
| context(context), original(original), attr(attr), vjp(vjp), |
| invoker(invoker), activityInfo(getActivityInfo( |
| context, original, attr->getIndices(), vjp)), |
| pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp, |
| attr->getIndices(), activityInfo) { |
| // Create empty pullback function. |
| pullback = createEmptyPullback(); |
| context.getGeneratedFunctions().push_back(pullback); |
| } |
| |
| SILFunction *createEmptyPullback() { |
| auto &module = context.getModule(); |
| auto origTy = original->getLoweredFunctionType(); |
| auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule()); |
| |
| // RAII that pushes the original function's generic signature to |
| // `module.Types` so that the calls to `module.Types.getTypeLowering()` |
| // below will know the original function's generic parameter types. |
| Lowering::GenericContextScope genericContextScope( |
| module.Types, origTy->getGenericSignature()); |
| |
| // Given a type, returns its formal SIL parameter info. |
| auto getTangentParameterInfoForOriginalResult = [&]( |
| CanType tanType, ResultConvention origResConv) -> SILParameterInfo { |
| auto &tl = context.getTypeConverter().getTypeLowering( |
| tanType, ResilienceExpansion::Minimal); |
| ParameterConvention conv; |
| switch (origResConv) { |
| case ResultConvention::Owned: |
| case ResultConvention::Autoreleased: |
| conv = tl.isTrivial() |
| ? ParameterConvention::Direct_Unowned |
| : ParameterConvention::Direct_Guaranteed; |
| break; |
| case ResultConvention::Unowned: |
| case ResultConvention::UnownedInnerPointer: |
| conv = ParameterConvention::Direct_Unowned; |
| break; |
| case ResultConvention::Indirect: |
| conv = ParameterConvention::Indirect_In_Guaranteed; |
| break; |
| } |
| return {tanType, conv}; |
| }; |
| |
| // Given a type, returns its formal SIL result info. |
| auto getTangentResultInfoForOriginalParameter = [&]( |
| CanType tanType, ParameterConvention origParamConv) -> SILResultInfo { |
| auto &tl = context.getTypeConverter().getTypeLowering( |
| tanType, ResilienceExpansion::Minimal); |
| ResultConvention conv; |
| switch (origParamConv) { |
| case ParameterConvention::Direct_Owned: |
| case ParameterConvention::Direct_Guaranteed: |
| case ParameterConvention::Direct_Unowned: |
| conv = tl.isTrivial() |
| ? ResultConvention::Unowned |
| : ResultConvention::Owned; |
| break; |
| case ParameterConvention::Indirect_In: |
| case ParameterConvention::Indirect_Inout: |
| case ParameterConvention::Indirect_In_Constant: |
| case ParameterConvention::Indirect_In_Guaranteed: |
| case ParameterConvention::Indirect_InoutAliasable: |
| conv = ResultConvention::Indirect; |
| break; |
| } |
| return {tanType, conv}; |
| }; |
| |
| // Parameters of the pullback are: |
| // - the tangent vectors of the original results, and |
| // - a pullback struct. |
| // Results of the pullback are in the tangent space of the original |
| // parameters. |
| SmallVector<SILParameterInfo, 8> pbParams; |
| SmallVector<SILResultInfo, 8> adjResults; |
| auto origParams = origTy->getParameters(); |
| auto indices = attr->getIndices(); |
| |
| // Add pullback parameter for the seed. |
| auto origResInfo = origTy->getResults()[indices.source]; |
| pbParams.push_back(getTangentParameterInfoForOriginalResult( |
| origResInfo.getType() |
| ->getAutoDiffAssociatedTangentSpace(lookupConformance) |
| ->getCanonicalType(), origResInfo.getConvention())); |
| |
| // Accept a pullback struct in the pullback parameter list. This is the |
| // returned pullback's closure context. |
| auto *origExit = &*original->findReturnBB(); |
| auto *pbStruct = pullbackInfo.getLinearMapStruct(origExit); |
| auto pbStructType = pbStruct->getDeclaredInterfaceType() |
| ->getCanonicalType(); |
| pbParams.push_back({pbStructType, ParameterConvention::Direct_Owned}); |
| |
| // Add pullback results for the requested wrt parameters. |
| for (auto i : indices.parameters->getIndices()) { |
| auto origParam = origParams[i]; |
| adjResults.push_back(getTangentResultInfoForOriginalParameter( |
| origParam.getType() |
| ->getAutoDiffAssociatedTangentSpace(lookupConformance) |
| ->getCanonicalType(), origParam.getConvention())); |
| } |
| |
| Mangle::ASTMangler mangler; |
| auto pbName = original->getASTContext().getIdentifier( |
| mangler.mangleAutoDiffLinearMapHelper( |
| original->getName(), AutoDiffLinearMapKind::Pullback, |
| indices)).str(); |
| auto pbGenericSig = getDerivativeGenericSignature(attr, original); |
| auto *pbGenericEnv = |
| pbGenericSig ? pbGenericSig->getGenericEnvironment() : nullptr; |
| auto pbType = SILFunctionType::get( |
| pbGenericSig, origTy->getExtInfo(), origTy->getCoroutineKind(), |
| origTy->getCalleeConvention(), pbParams, {}, adjResults, None, |
| original->getASTContext()); |
| |
| SILOptFunctionBuilder fb(context.getTransform()); |
| // The generated pullback linkage is set to Hidden because generated |
| // pullbacks are never called cross-module. |
| auto linkage = SILLinkage::Hidden; |
| auto *pullback = fb.createFunction( |
| linkage, pbName, pbType, pbGenericEnv, original->getLocation(), |
| original->isBare(), IsNotTransparent, original->isSerialized(), |
| original->isDynamicallyReplaceable()); |
| pullback->setDebugScope(new (module) |
| SILDebugScope(original->getLocation(), |
| pullback)); |
| return pullback; |
| } |
| |
| /// Run VJP generation. Returns true on error. |
| bool run(); |
| |
| void postProcess(SILInstruction *orig, SILInstruction *cloned) { |
| if (errorOccurred) |
| return; |
| SILClonerWithScopes::postProcess(orig, cloned); |
| } |
| |
| /// Remap original basic blocks, adding predecessor enum arguments. |
| SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) { |
| auto *vjpBB = BBMap[bb]; |
| // If error has occurred, or if block has already been remapped, return |
| // remapped, return remapped block. |
| if (errorOccurred || remappedBasicBlocks.count(bb)) |
| return vjpBB; |
| // Add predecessor enum argument to the remapped block. |
| auto *predEnum = pullbackInfo.getBranchingTraceDecl(bb); |
| auto enumTy = getOpASTType(predEnum->getDeclaredInterfaceType() |
| ->getCanonicalType()); |
| auto enumLoweredTy = context.getTypeConverter().getLoweredType( |
| enumTy, ResilienceExpansion::Minimal); |
| vjpBB->createPhiArgument(enumLoweredTy, ValueOwnershipKind::Owned); |
| remappedBasicBlocks.insert(bb); |
| return vjpBB; |
| } |
| |
| /// General visitor for all instructions. If any error is emitted by previous |
| /// visits, bail out. |
| void visit(SILInstruction *inst) { |
| if (errorOccurred) |
| return; |
| TypeSubstCloner::visit(inst); |
| } |
| |
| void visitSILInstruction(SILInstruction *inst) { |
| context.emitNondifferentiabilityError(inst, invoker, |
| diag::autodiff_expression_not_differentiable_note); |
| errorOccurred = true; |
| } |
| |
| private: |
| /// Get the lowered SIL type of the given nominal type declaration. |
| SILType getNominalDeclLoweredType(NominalTypeDecl *nominal) { |
| auto nomType = getOpASTType( |
| nominal->getDeclaredInterfaceType()->getCanonicalType()); |
| auto nomSILType = context.getTypeConverter().getLoweredType( |
| nomType, ResilienceExpansion::Minimal); |
| return nomSILType; |
| } |
| |
| /// Build a pullback struct value for the original block corresponding to the |
| /// given terminator. |
| StructInst *buildPullbackValueStructValue(TermInst *termInst) { |
| assert(termInst->getFunction() == original); |
| auto loc = termInst->getFunction()->getLocation(); |
| auto *origBB = termInst->getParent(); |
| auto *vjpBB = BBMap[origBB]; |
| auto *pbStruct = pullbackInfo.getLinearMapStruct(origBB); |
| auto structLoweredTy = getNominalDeclLoweredType(pbStruct); |
| auto bbPullbackValues = pullbackValues[origBB]; |
| if (!origBB->isEntry()) { |
| auto *predEnumArg = vjpBB->getArguments().back(); |
| bbPullbackValues.insert(bbPullbackValues.begin(), predEnumArg); |
| } |
| return getBuilder().createStruct(loc, structLoweredTy, bbPullbackValues); |
| } |
| |
| /// Build a predecessor enum instance using the given builder for the given |
| /// original predecessor/successor blocks and pullback struct value. |
| EnumInst *buildPredecessorEnumValue(SILBuilder &builder, |
| SILBasicBlock *predBB, |
| SILBasicBlock *succBB, |
| SILValue pbStructVal) { |
| auto loc = pbStructVal.getLoc(); |
| auto *succEnum = pullbackInfo.getBranchingTraceDecl(succBB); |
| auto enumLoweredTy = getNominalDeclLoweredType(succEnum); |
| auto *enumEltDecl = |
| pullbackInfo.lookUpBranchingTraceEnumElement(predBB, succBB); |
| auto enumEltType = getOpType( |
| enumLoweredTy.getEnumElementType(enumEltDecl, getModule())); |
| // If the enum element type does not have a box type (i.e. the enum case is |
| // not indirect), then directly create an enum. |
| auto boxType = dyn_cast<SILBoxType>(enumEltType.getASTType()); |
| if (!boxType) |
| return builder.createEnum(loc, pbStructVal, enumEltDecl, enumLoweredTy); |
| // Otherwise, box the pullback struct value and create an enum. |
| auto *newBox = builder.createAllocBox(loc, boxType); |
| builder.emitScopedBorrowOperation( |
| loc, newBox, [&](SILValue borrowedBox) { |
| auto *projectBox = builder.createProjectBox(loc, newBox, /*index*/ 0); |
| builder.emitStoreValueOperation(loc, pbStructVal, projectBox, |
| StoreOwnershipQualifier::Init); |
| }); |
| return builder.createEnum(loc, newBox, enumEltDecl, enumLoweredTy); |
| } |
| |
| public: |
| void visitReturnInst(ReturnInst *ri) { |
| auto loc = ri->getOperand().getLoc(); |
| auto *origExit = ri->getParent(); |
| auto &builder = getBuilder(); |
| auto *pbStructVal = buildPullbackValueStructValue(ri); |
| |
| // Get the value in the VJP corresponding to the original result. |
| auto *origRetInst = cast<ReturnInst>(origExit->getTerminator()); |
| auto origResult = getOpValue(origRetInst->getOperand()); |
| SmallVector<SILValue, 8> origResults; |
| extractAllElements(origResult, builder, origResults); |
| |
| // Get and partially apply the pullback. |
| auto vjpGenericEnv = vjp->getGenericEnvironment(); |
| auto vjpSubstMap = vjpGenericEnv |
| ? vjpGenericEnv->getForwardingSubstitutionMap() |
| : vjp->getForwardingSubstitutionMap(); |
| auto *pullbackRef = builder.createFunctionRef(loc, pullback); |
| auto *pullbackPartialApply = builder.createPartialApply( |
| loc, pullbackRef, vjpSubstMap, {pbStructVal}, |
| ParameterConvention::Direct_Guaranteed); |
| |
| // Return a tuple of the original result and pullback. |
| SmallVector<SILValue, 8> directResults; |
| directResults.append(origResults.begin(), origResults.end()); |
| directResults.push_back(pullbackPartialApply); |
| builder.createReturn( |
| ri->getLoc(), joinElements(directResults, builder, loc)); |
| } |
| |
| void visitBranchInst(BranchInst *bi) { |
| // Build pullback struct value for original block. |
| // Build predecessor enum value for destination block. |
| auto *origBB = bi->getParent(); |
| auto *pbStructVal = buildPullbackValueStructValue(bi); |
| auto *enumVal = buildPredecessorEnumValue( |
| getBuilder(), origBB, bi->getDestBB(), pbStructVal); |
| |
| // Remap arguments, appending the new enum values. |
| SmallVector<SILValue, 8> args; |
| for (auto origArg : bi->getArgs()) |
| args.push_back(getOpValue(origArg)); |
| args.push_back(enumVal); |
| |
| // Create a new `br` instruction. |
| getBuilder().createBranch( |
| bi->getLoc(), getOpBasicBlock(bi->getDestBB()), args); |
| } |
| |
| void visitCondBranchInst(CondBranchInst *cbi) { |
| // Build pullback struct value for original block. |
| // Build predecessor enum values for true/false blocks. |
| auto *origBB = cbi->getParent(); |
| auto *pbStructVal = buildPullbackValueStructValue(cbi); |
| |
| // Creates a trampoline block for given original successor block. The |
| // trampoline block has the same arguments as the VJP successor block but |
| // drops the last predecessor enum argument. The generated `switch_enum` |
| // instruction branches to the trampoline block, and the trampoline block |
| // constructs a predecessor enum value and branches to the VJP successor |
| // block. |
| auto createTrampolineBasicBlock = |
| [&](SILBasicBlock *origSuccBB) -> SILBasicBlock * { |
| auto *vjpSuccBB = getOpBasicBlock(origSuccBB); |
| // Create the trampoline block. |
| auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); |
| for (auto *arg : vjpSuccBB->getArguments().drop_back()) |
| trampolineBB->createPhiArgument(arg->getType(), |
| arg->getOwnershipKind()); |
| // Build predecessor enum value for successor block and branch to it. |
| SILBuilder trampolineBuilder(trampolineBB); |
| auto *succEnumVal = buildPredecessorEnumValue( |
| trampolineBuilder, origBB, origSuccBB, pbStructVal); |
| SmallVector<SILValue, 4> forwardedArguments( |
| trampolineBB->getArguments().begin(), |
| trampolineBB->getArguments().end()); |
| forwardedArguments.push_back(succEnumVal); |
| trampolineBuilder.createBranch(cbi->getLoc(), vjpSuccBB, |
| forwardedArguments); |
| return trampolineBB; |
| }; |
| |
| // Create a new `cond_br` instruction. |
| getBuilder().createCondBranch( |
| cbi->getLoc(), getOpValue(cbi->getCondition()), |
| createTrampolineBasicBlock(cbi->getTrueBB()), |
| createTrampolineBasicBlock(cbi->getFalseBB())); |
| } |
| |
| void visitSwitchEnumInst(SwitchEnumInst *sei) { |
| // Build pullback struct value for original block. |
| auto *origBB = sei->getParent(); |
| auto *pbStructVal = buildPullbackValueStructValue(sei); |
| |
| // Creates a trampoline block for given original successor block. The |
| // trampoline block has the same arguments as the VJP successor block but |
| // drops the last predecessor enum argument. The generated `switch_enum` |
| // instruction branches to the trampoline block, and the trampoline block |
| // constructs a predecessor enum value and branches to the VJP successor |
| // block. |
| auto createTrampolineBasicBlock = |
| [&](SILBasicBlock *origSuccBB) -> SILBasicBlock * { |
| auto *vjpSuccBB = getOpBasicBlock(origSuccBB); |
| // Create the trampoline block. |
| auto *trampolineBB = vjp->createBasicBlockBefore(vjpSuccBB); |
| for (auto *destArg : vjpSuccBB->getArguments().drop_back()) { |
| auto *trampolineArg = trampolineBB->createPhiArgument( |
| destArg->getType(), destArg->getOwnershipKind()); |
| // Each `@guaranteed` trampoline argument needs to have a |
| // lifetime-ending use past its destination argument's lifetime-ending |
| // uses, so we keep track of these pairs of arguments in |
| // `trampolinedGuaranteedPhiArguments` and emit `end_borrow`s when |
| // function cloning is finished. |
| if (trampolineArg->getOwnershipKind() == ValueOwnershipKind::Guaranteed) |
| trampolinedGuaranteedPhiArguments.push_back( |
| {trampolineArg, cast<SILPhiArgument>(destArg)}); |
| } |
| // Build predecessor enum value for successor block and branch to it. |
| SILBuilder trampolineBuilder(trampolineBB); |
| auto *succEnumVal = buildPredecessorEnumValue( |
| trampolineBuilder, origBB, origSuccBB, pbStructVal); |
| SmallVector<SILValue, 4> forwardedArguments( |
| trampolineBB->getArguments().begin(), |
| trampolineBB->getArguments().end()); |
| forwardedArguments.push_back(succEnumVal); |
| trampolineBuilder.createBranch(sei->getLoc(), vjpSuccBB, |
| forwardedArguments); |
| return trampolineBB; |
| }; |
| |
| // Create trampoline successor basic blocks. |
| SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> caseBBs; |
| for (unsigned i : range(sei->getNumCases())) { |
| auto caseBB = sei->getCase(i); |
| auto *trampolineBB = createTrampolineBasicBlock(caseBB.second); |
| caseBBs.push_back({caseBB.first, trampolineBB}); |
| } |
| // Create trampoline default basic block. |
| SILBasicBlock *newDefaultBB = nullptr; |
| if (auto *defaultBB = sei->getDefaultBBOrNull().getPtrOrNull()) |
| newDefaultBB = createTrampolineBasicBlock(defaultBB); |
| |
| // Create a new `switch_enum` instruction. |
| getBuilder().createSwitchEnum( |
| sei->getLoc(), getOpValue(sei->getOperand()), newDefaultBB, caseBBs); |
| } |
| |
| // If an `apply` has active results or active inout parameters, replace it |
| // with an `apply` of its VJP. |
| void visitApplyInst(ApplyInst *ai) { |
| // If the function should not be differentiated or its the array literal |
| // initialization intrinsic, just do standard cloning. |
| if (!pullbackInfo.shouldDifferentiateApplyInst(ai) || |
| isArrayLiteralIntrinsic(ai)) { |
| LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n'); |
| TypeSubstCloner::visitApplyInst(ai); |
| return; |
| } |
| |
| // Check and reject functions with active inout arguments. It's not yet |
| // supported. |
| auto paramInfos = ai->getSubstCalleeConv().getParameters(); |
| auto paramArgs = ai->getArgumentsWithoutIndirectResults(); |
| for (unsigned i : swift::indices(paramInfos)) { |
| if (paramInfos[i].isIndirectInOut() && |
| activityInfo.isActive(paramArgs[i], getIndices())) { |
| context.emitNondifferentiabilityError(ai, invoker, |
| diag::autodiff_cannot_differentiate_through_inout_arguments); |
| errorOccurred = true; |
| return; |
| } |
| } |
| |
| LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *ai << '\n'); |
| |
| // Get the minimal parameter and result indices required for differentiating |
| // this `apply`. |
| SmallVector<SILValue, 4> allResults; |
| SmallVector<unsigned, 8> activeParamIndices; |
| SmallVector<unsigned, 8> activeResultIndices; |
| collectMinimalIndicesForFunctionCall(ai, getIndices(), activityInfo, |
| allResults, activeParamIndices, |
| activeResultIndices); |
| assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); |
| assert(!activeResultIndices.empty() && "Result indices cannot be empty"); |
| LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={"; |
| interleave(activeParamIndices.begin(), activeParamIndices.end(), |
| [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); |
| s << "}, results={"; interleave( |
| activeResultIndices.begin(), activeResultIndices.end(), |
| [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); |
| s << "}\n";); |
| // FIXME: We don't support multiple active results yet. |
| if (activeResultIndices.size() > 1) { |
| context.emitNondifferentiabilityError( |
| ai, invoker, diag::autodiff_expression_not_differentiable_note); |
| errorOccurred = true; |
| return; |
| } |
| |
| // Form expected indices, assuming there's only one result. |
| SILAutoDiffIndices indices( |
| activeResultIndices.front(), |
| IndexSubset::get( |
| getASTContext(), ai->getArgumentsWithoutIndirectResults().size(), |
| activeParamIndices)); |
| |
| // Emit the VJP. |
| auto loc = ai->getLoc(); |
| auto &builder = getBuilder(); |
| auto original = getOpValue(ai->getCallee()); |
| SILValue vjpValue; |
| // If functionSource is a `@differentiable` function, just extract it. |
| auto originalFnTy = original->getType().castTo<SILFunctionType>(); |
| if (originalFnTy->isDifferentiable()) { |
| auto paramIndices = originalFnTy->getDifferentiationParameterIndices(); |
| for (auto i : indices.parameters->getIndices()) { |
| if (!paramIndices->contains(i)) { |
| context.emitNondifferentiabilityError(original, invoker, |
| diag::autodiff_function_nondiff_parameter_not_differentiable); |
| errorOccurred = true; |
| return; |
| } |
| } |
| auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, original); |
| vjpValue = builder.createDifferentiableFunctionExtract( |
| loc, NormalDifferentiableFunctionTypeComponent::VJP, |
| borrowedDiffFunc); |
| vjpValue = builder.emitCopyValueOperation(loc, vjpValue); |
| } |
| |
| // Check and diagnose non-differentiable original function type. |
| auto diagnoseNondifferentiableOriginalFunctionType = |
| [&](CanSILFunctionType origFnTy) { |
| // Check and diagnose non-differentiable arguments. |
| for (unsigned paramIndex : range(originalFnTy->getNumParameters())) { |
| if (indices.isWrtParameter(paramIndex) && |
| !originalFnTy->getParameters()[paramIndex] |
| .getSILStorageType() |
| .isDifferentiable(getModule())) { |
| context.emitNondifferentiabilityError( |
| ai->getArgumentsWithoutIndirectResults()[paramIndex], invoker, |
| diag::autodiff_nondifferentiable_argument); |
| errorOccurred = true; |
| return true; |
| } |
| } |
| // Check and diagnose non-differentiable results. |
| if (!originalFnTy->getResults()[indices.source] |
| .getSILStorageType() |
| .isDifferentiable(getModule())) { |
| context.emitNondifferentiabilityError( |
| original, invoker, diag::autodiff_nondifferentiable_result); |
| errorOccurred = true; |
| return true; |
| } |
| return false; |
| }; |
| if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) |
| return; |
| |
| // If VJP has not yet been found, emit an `differentiable_function` |
| // instruction on the remapped original function operand and |
| // an `differentiable_function_extract` instruction to get the VJP. |
| // The `differentiable_function` instruction will be canonicalized during |
| // the transform main loop. |
| if (!vjpValue) { |
| // FIXME: Handle indirect differentiation invokers. This may require some |
| // redesign: currently, each original function + attribute pair is mapped |
| // only to one invoker. |
| /* |
| DifferentiationInvoker indirect(ai, attr); |
| auto insertion = |
| context.getInvokers().try_emplace({this->original, attr}, indirect); |
| auto &invoker = insertion.first->getSecond(); |
| invoker = indirect; |
| */ |
| |
| // If the original `apply` instruction has a substitution map, then the |
| // applied function is specialized. |
| // In the VJP, specialization is also necessary for parity. The original |
| // function operand is specialized with a remapped version of same |
| // substitution map using an argument-less `partial_apply`. |
| if (ai->getSubstitutionMap().empty()) { |
| original = builder.emitCopyValueOperation(loc, original); |
| } else { |
| auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap()); |
| auto vjpPartialApply = getBuilder().createPartialApply( |
| ai->getLoc(), original, substMap, {}, |
| ParameterConvention::Direct_Guaranteed); |
| original = vjpPartialApply; |
| originalFnTy = original->getType().castTo<SILFunctionType>(); |
| // Diagnose if new original function type is non-differentiable. |
| if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) |
| return; |
| } |
| |
| auto *diffFuncInst = context.createDifferentiableFunction( |
| getBuilder(), loc, indices.parameters, original); |
| |
| // Record the `differentiable_function` instruction. |
| context.getDifferentiableFunctionInsts().push_back(diffFuncInst); |
| // TODO(TF-689): Make `differentiable_function` store result indices and |
| // remove `ADContext::resultIndices`. |
| context.getResultIndices()[diffFuncInst] = activeResultIndices.front(); |
| |
| auto borrowedADFunc = |
| builder.emitBeginBorrowOperation(loc, diffFuncInst); |
| auto extractedVJP = getBuilder().createDifferentiableFunctionExtract( |
| loc, NormalDifferentiableFunctionTypeComponent::VJP, |
| borrowedADFunc); |
| vjpValue = builder.emitCopyValueOperation(loc, extractedVJP); |
| builder.emitEndBorrowOperation(loc, borrowedADFunc); |
| builder.emitDestroyValueOperation(loc, diffFuncInst); |
| } |
| |
| // Record desired/actual VJP indices. |
| // Temporarily set original pullback type to `None`. |
| NestedApplyInfo info{indices, /*originalPullbackType*/ None}; |
| auto insertion = context.getNestedApplyInfo().try_emplace(ai, info); |
| auto &nestedApplyInfo = insertion.first->getSecond(); |
| nestedApplyInfo = info; |
| |
| // Call the VJP using the original parameters. |
| SmallVector<SILValue, 8> vjpArgs; |
| auto vjpFnTy = getOpType(vjpValue->getType()).castTo<SILFunctionType>(); |
| auto numVJPArgs = |
| vjpFnTy->getNumParameters() + vjpFnTy->getNumIndirectFormalResults(); |
| vjpArgs.reserve(numVJPArgs); |
| // Collect substituted arguments. |
| for (auto origArg : ai->getArguments()) |
| vjpArgs.push_back(getOpValue(origArg)); |
| assert(vjpArgs.size() == numVJPArgs); |
| // Apply the VJP. |
| // The VJP should be specialized, so no substitution map is necessary. |
| auto *vjpCall = getBuilder().createApply(loc, vjpValue, SubstitutionMap(), |
| vjpArgs, ai->isNonThrowing()); |
| LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall); |
| builder.emitDestroyValueOperation(loc, vjpValue); |
| |
| // Get the VJP results (original results and pullback). |
| SmallVector<SILValue, 8> vjpDirectResults; |
| extractAllElements(vjpCall, getBuilder(), vjpDirectResults); |
| ArrayRef<SILValue> originalDirectResults = |
| ArrayRef<SILValue>(vjpDirectResults).drop_back(1); |
| SILValue originalDirectResult = joinElements(originalDirectResults, |
| getBuilder(), |
| vjpCall->getLoc()); |
| SILValue pullback = vjpDirectResults.back(); |
| |
| // Store the original result to the value map. |
| mapValue(ai, originalDirectResult); |
| |
| // Checkpoint the pullback. |
| auto *pullbackDecl = pullbackInfo.lookUpLinearMapDecl(ai); |
| |
| // If actual pullback type does not match lowered pullback type, reabstract |
| // the pullback using a thunk. |
| auto actualPullbackType = |
| getOpType(pullback->getType()).getAs<SILFunctionType>(); |
| auto vjpGenSig = SubsMap.getGenericSignature() |
| ? SubsMap.getGenericSignature()->getCanonicalSignature() |
| : nullptr; |
| Lowering::GenericContextScope genericContextScope( |
| context.getTypeConverter(), vjpGenSig); |
| auto loweredPullbackType = |
| getOpType(context.getTypeConverter().getLoweredType( |
| pullbackDecl->getInterfaceType()->getCanonicalType(), |
| ResilienceExpansion::Minimal)) |
| .castTo<SILFunctionType>(); |
| if (!loweredPullbackType->isEqual(actualPullbackType)) { |
| // Set non-reabstracted original pullback type in nested apply info. |
| nestedApplyInfo.originalPullbackType = actualPullbackType; |
| SILOptFunctionBuilder fb(context.getTransform()); |
| auto *thunk = getOrCreateReabstractionThunk( |
| fb, getModule(), loc, /*caller*/ vjp, actualPullbackType, |
| loweredPullbackType); |
| auto *thunkRef = getBuilder().createFunctionRef(loc, thunk); |
| pullback = getBuilder().createPartialApply( |
| ai->getLoc(), thunkRef, |
| getOpSubstitutionMap(thunk->getForwardingSubstitutionMap()), |
| {pullback}, actualPullbackType->getCalleeConvention()); |
| } |
| pullbackValues[ai->getParent()].push_back(pullback); |
| |
| // Some instructions that produce the callee may have been cloned. |
| // If the original callee did not have any users beyond this `apply`, |
| // recursively kill the cloned callee. |
| if (auto *origCallee = cast_or_null<SingleValueInstruction>( |
| ai->getCallee()->getDefiningInstruction())) |
| if (origCallee->hasOneUse()) |
| recursivelyDeleteTriviallyDeadInstructions( |
| getOpValue(origCallee)->getDefiningInstruction()); |
| } |
| |
| void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) { |
| // Clone `differentiable_function` from original to VJP, then add the cloned |
| // instruction to the `differentiable_function` worklist. |
| TypeSubstCloner::visitDifferentiableFunctionInst(dfi); |
| auto *newDFI = cast<DifferentiableFunctionInst>(getOpValue(dfi)); |
| context.getDifferentiableFunctionInsts().push_back(newDFI); |
| } |
| }; |
| } // end anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // AdjointValue - a symbolic representation for adjoint values that allows |
| // for efficient differentiation of aggregates. |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| class PullbackEmitter; |
| class AdjointValue; |
| |
/// The three representations an adjoint value can take.
enum AdjointValueKind {
  /// The zero adjoint. It is kept symbolic because `0 + x = x`: combining a
  /// zero adjoint with another adjoint (e.g. when differentiating a fanout)
  /// needs no emitted addition.
  Zero,

  /// An adjoint made up of element adjoint values.
  Aggregate,

  /// An adjoint backed by a concrete SIL value.
  Concrete,
};
| |
| class AdjointValueBase { |
| friend class AdjointValue; |
| |
| /// The kind of this adjoint value. |
| AdjointValueKind kind; |
| |
| /// The type of this value as if it were materialized as a SIL value. |
| SILType type; |
| |
| /// The underlying value. |
| union Value { |
| ArrayRef<AdjointValue> aggregate; |
| SILValue concrete; |
| Value(ArrayRef<AdjointValue> v) : aggregate(v) {} |
| Value(SILValue v) : concrete(v) {} |
| Value() {} |
| } value; |
| |
| explicit AdjointValueBase(SILType type, |
| ArrayRef<AdjointValue> aggregate) |
| : kind(AdjointValueKind::Aggregate), type(type), value(aggregate) {} |
| |
| explicit AdjointValueBase(SILValue v) |
| : kind(AdjointValueKind::Concrete), type(v->getType()), value(v) {} |
| |
| explicit AdjointValueBase(SILType type) |
| : kind(AdjointValueKind::Zero), type(type) {} |
| }; |
| |
| /// A symbolic adjoint value that is capable of representing zero value 0 and |
| /// 1, in addition to a materialized SILValue. This is expected to be passed |
| /// around by value in most cases, as it's two words long. |
| class AdjointValue final { |
| friend class PullbackEmitter; |
| |
| private: |
| /// The kind of this adjoint value. |
| AdjointValueBase *base; |
| /*implicit*/ AdjointValue(AdjointValueBase *base = nullptr) : base(base) {} |
| |
| public: |
| AdjointValueBase *operator->() const { return base; } |
| AdjointValueBase &operator*() const { return *base; } |
| |
| static AdjointValue createConcrete(llvm::BumpPtrAllocator &allocator, |
| SILValue value) { |
| return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(value); |
| } |
| |
| template<typename EltRange> |
| static AdjointValue createAggregate(llvm::BumpPtrAllocator &allocator, |
| SILType type, EltRange elements) { |
| AdjointValue *buf = reinterpret_cast<AdjointValue *>(allocator.Allocate( |
| elements.size() * sizeof(AdjointValue), alignof(AdjointValue))); |
| MutableArrayRef<AdjointValue> elementsCopy(buf, elements.size()); |
| std::uninitialized_copy(elements.begin(), elements.end(), |
| elementsCopy.begin()); |
| return new (allocator.Allocate<AdjointValueBase>()) |
| AdjointValueBase(type, elementsCopy); |
| } |
| |
| static AdjointValue createZero(llvm::BumpPtrAllocator &allocator, |
| SILType type) { |
| return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(type); |
| } |
| |
| AdjointValueKind getKind() const { return base->kind; } |
| SILType getType() const { return base->type; } |
| CanType getSwiftType() const { return getType().getASTType(); } |
| |
| NominalTypeDecl *getAnyNominal() const { |
| return getSwiftType()->getAnyNominal(); |
| } |
| |
| bool isZero() const { return getKind() == AdjointValueKind::Zero; } |
| bool isAggregate() const { return getKind() == AdjointValueKind::Aggregate; } |
| bool isConcrete() const { return getKind() == AdjointValueKind::Concrete; } |
| |
| unsigned getNumAggregateElements() const { |
| assert(isAggregate()); |
| return base->value.aggregate.size(); |
| } |
| |
| AdjointValue getAggregateElement(unsigned i) const { |
| assert(isAggregate()); |
| return base->value.aggregate[i]; |
| } |
| |
| ArrayRef<AdjointValue> getAggregateElements() const { |
| return base->value.aggregate; |
| } |
| |
| SILValue getConcreteValue() const { |
| assert(isConcrete()); |
| return base->value.concrete; |
| } |
| |
| void print(llvm::raw_ostream &s) const { |
| switch (getKind()) { |
| case AdjointValueKind::Zero: |
| s << "Zero"; |
| break; |
| case AdjointValueKind::Aggregate: |
| s << "Aggregate<"; |
| if (auto *decl = |
| getType().getASTType()->getStructOrBoundGenericStruct()) { |
| s << "Struct>("; |
| interleave(llvm::zip(decl->getStoredProperties(), |
| base->value.aggregate), |
| [&s](std::tuple<VarDecl *, |
| const AdjointValue &> elt) { |
| s << std::get<0>(elt)->getName() << ": "; |
| std::get<1>(elt).print(s); |
| }, [&s] { s << ", "; }); |
| } else if (auto tupleType = getType().getAs<TupleType>()) { |
| s << "Tuple>("; |
| interleave(base->value.aggregate, |
| [&s](const AdjointValue &elt) { elt.print(s); }, |
| [&s] { s << ", "; }); |
| } else { |
| llvm_unreachable("Invalid aggregate"); |
| } |
| s << ')'; |
| break; |
| case AdjointValueKind::Concrete: |
| s << "Concrete(" << base->value.concrete << ')'; |
| break; |
| } |
| } |
| }; |
| |
| inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, |
| const AdjointValue &adjVal) { |
| adjVal.print(os); |
| return os; |
| } |
| |
| } // end anonymous namespace |
| |
| namespace { |
| |
| class JVPEmitter final |
| : public TypeSubstCloner<JVPEmitter, SILOptFunctionBuilder> { |
| private: |
| /// The global context. |
| ADContext &context; |
| |
| /// The original function. |
| SILFunction *const original; |
| |
| /// The `[differentiable]` attribute. |
| SILDifferentiableAttr *const attr; |
| |
| /// The JVP function. |
| SILFunction *const jvp; |
| |
| llvm::BumpPtrAllocator allocator; |
| |
| /// The differentiation invoker. |
| DifferentiationInvoker invoker; |
| |
| /// Info from activity analysis on the original function. |
| const DifferentiableActivityInfo &activityInfo; |
| |
| /// The differential info. |
| LinearMapInfo differentialInfo; |
| |
| bool errorOccurred = false; |
| |
| //--------------------------------------------------------------------------// |
| // Differential generation related fields |
| //--------------------------------------------------------------------------// |
| |
| /// The builder for the differential function. |
| SILBuilder differentialBuilder; |
| |
| /// Mapping from original basic blocks to corresponding differential basic |
| /// blocks. |
| DenseMap<SILBasicBlock *, SILBasicBlock *> diffBBMap; |
| |
| /// Mapping from original values to corresponding tangent values. |
| DenseMap<SILValue, AdjointValue> tangentValueMap; |
| |
| /// Mapping from original basic blocks and original buffers to corresponding |
| /// tangent buffers. |
| DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap; |
| |
| /// Mapping from differential basic blocks to differential struct arguments. |
| DenseMap<SILBasicBlock *, SILArgument *> differentialStructArguments; |
| |
| /// Mapping from differential struct field declarations to differential struct |
| /// elements destructured from the linear map basic block argument. In the |
| /// beginning of each differential basic block, the block's differential |
| /// struct is destructured into the individual elements stored here. |
| DenseMap<VarDecl *, SILValue> differentialStructElements; |
| |
| /// An auxiliary differential local allocation builder. |
| SILBuilder diffLocalAllocBuilder; |
| |
| /// Stack buffers allocated for storing local tangent values. |
| SmallVector<SILValue, 8> differentialLocalAllocations; |
| |
| /// Mapping from original blocks to differential values. Used to build |
| /// differential struct instances. |
| DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> differentialValues; |
| |
| //--------------------------------------------------------------------------// |
| // Getters |
| //--------------------------------------------------------------------------// |
| |
| ASTContext &getASTContext() const { return jvp->getASTContext(); } |
| SILModule &getModule() const { return jvp->getModule(); } |
| const SILAutoDiffIndices &getIndices() const { return attr->getIndices(); } |
| SILBuilder &getDifferentialBuilder() { return differentialBuilder; } |
| SILFunction &getDifferential() { |
| return differentialBuilder.getFunction(); |
| } |
| SILArgument *getDifferentialStructArgument(SILBasicBlock *origBB) { |
| #ifndef NDEBUG |
| auto *diffStruct = differentialStructArguments[origBB]->getType() |
| .getStructOrBoundGenericStruct(); |
| assert(diffStruct == differentialInfo.getLinearMapStruct(origBB)); |
| #endif |
| return differentialStructArguments[origBB]; |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Initialization helpers |
| //--------------------------------------------------------------------------// |
| |
| static SubstitutionMap getSubstitutionMap(SILFunction *original, |
| SILFunction *jvp) { |
| auto substMap = original->getForwardingSubstitutionMap(); |
| if (auto *jvpGenEnv = jvp->getGenericEnvironment()) { |
| auto jvpSubstMap = jvpGenEnv->getForwardingSubstitutionMap(); |
| substMap = SubstitutionMap::get( |
| jvpGenEnv->getGenericSignature(), QuerySubstitutionMap{jvpSubstMap}, |
| LookUpConformanceInSubstitutionMap(jvpSubstMap)); |
| } |
| return substMap; |
| } |
| |
| /// Returns the activity info about the SILValues in the original function. |
| static const DifferentiableActivityInfo &getActivityInfo( |
| ADContext &context, SILFunction *original, |
| const SILAutoDiffIndices &indices, SILFunction *jvp) { |
| // Get activity info of the original function. |
| auto &passManager = context.getPassManager(); |
| auto *activityAnalysis = |
| passManager.getAnalysis<DifferentiableActivityAnalysis>(); |
| auto &activityCollection = *activityAnalysis->get(original); |
| auto &activityInfo = activityCollection.getActivityInfo( |
| jvp->getLoweredFunctionType()->getGenericSignature(), |
| AutoDiffDerivativeFunctionKind::JVP); |
| LLVM_DEBUG( |
| dumpActivityInfo(*original, indices, activityInfo, getADDebugStream())); |
| return activityInfo; |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Differential struct mapping |
| //--------------------------------------------------------------------------// |
| |
| void initializeDifferentialStructElements(SILBasicBlock *origBB, |
| SILInstructionResultArray values) { |
| auto *diffStructDecl = differentialInfo.getLinearMapStruct(origBB); |
| assert(diffStructDecl->getStoredProperties().size() == values.size() && |
| "The number of differential struct fields must equal the number of " |
| "differential struct element values"); |
| for (auto pair : llvm::zip(diffStructDecl->getStoredProperties(), values)) { |
| assert( |
| std::get<1>(pair).getOwnershipKind() != ValueOwnershipKind::Guaranteed |
| && "Differential struct elements must be @owned"); |
| auto insertion = differentialStructElements.insert({std::get<0>(pair), |
| std::get<1>(pair)}); |
| (void)insertion; |
| assert(insertion.second && |
| "A differential struct element mapping already exists!"); |
| } |
| } |
| |
| SILValue getDifferentialStructElement(SILBasicBlock *origBB, VarDecl *field) { |
| assert(differentialInfo.getLinearMapStruct(origBB) == |
| cast<StructDecl>(field->getDeclContext())); |
| assert(differentialStructElements.count(field) && |
| "Differential struct element for this field does not exist!"); |
| return differentialStructElements.lookup(field); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // General utilities |
| //--------------------------------------------------------------------------// |
| |
| SILBasicBlock::iterator getNextDifferentialLocalAllocationInsertionPoint() { |
| // If there are no local allocations, insert at the beginning of the tangent |
| // entry. |
| if (differentialLocalAllocations.empty()) |
| return getDifferential().getEntryBlock()->begin(); |
| // Otherwise, insert before the last local allocation. Inserting before |
| // rather than after ensures that allocation and zero initialization |
| // instructions are grouped together. |
| auto lastLocalAlloc = differentialLocalAllocations.back(); |
| auto it = lastLocalAlloc->getDefiningInstruction()->getIterator(); |
| return it; |
| } |
| |
| /// Get the lowered SIL type of the given nominal type declaration. |
| SILType getNominalDeclLoweredType(NominalTypeDecl *nominal) { |
| auto nomType = |
| getOpASTType(nominal->getDeclaredInterfaceType()->getCanonicalType()); |
| auto nomSILType = context.getTypeConverter().getLoweredType( |
| nomType, ResilienceExpansion::Minimal); |
| return nomSILType; |
| } |
| |
| /// Build a differential struct value for the original block corresponding to |
| /// the given terminator. |
| StructInst *buildDifferentialValueStructValue(TermInst *termInst) { |
| assert(termInst->getFunction() == original); |
| auto loc = termInst->getFunction()->getLocation(); |
| auto *origBB = termInst->getParent(); |
| auto *jvpBB = BBMap[origBB]; |
| assert(jvpBB && "Basic block mapping should exist"); |
| auto *diffStruct = differentialInfo.getLinearMapStruct(origBB); |
| assert(diffStruct && "The differential struct should have been declared"); |
| auto structLoweredTy = getNominalDeclLoweredType(diffStruct); |
| auto bbDifferentialValues = differentialValues[origBB]; |
| if (!origBB->isEntry()) { |
| auto *enumArg = jvpBB->getArguments().back(); |
| bbDifferentialValues.insert(bbDifferentialValues.begin(), enumArg); |
| } |
| return getBuilder().createStruct(loc, structLoweredTy, |
| bbDifferentialValues); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Tangent value factory methods |
| //--------------------------------------------------------------------------// |
| |
| AdjointValue makeZeroTangentValue(SILType type) { |
| return AdjointValue::createZero( |
| allocator, remapSILTypeInDifferential(type)); |
| } |
| |
| AdjointValue makeConcreteTangentValue(SILValue value) { |
| return AdjointValue::createConcrete(allocator, value); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Tangent materialization |
| //--------------------------------------------------------------------------// |
| |
| void emitZeroIndirect(CanType type, SILValue bufferAccess, |
| SILLocation loc) { |
| auto builder = getDifferentialBuilder(); |
| auto tangentSpace = getTangentSpace(type); |
| assert(tangentSpace && "No tangent space for this type"); |
| switch (tangentSpace->getKind()) { |
| case VectorSpace::Kind::Vector: |
| emitZeroIntoBuffer(builder, type, bufferAccess, loc); |
| return; |
| case VectorSpace::Kind::Tuple: { |
| auto tupleType = tangentSpace->getTuple(); |
| SmallVector<SILValue, 8> zeroElements; |
| for (unsigned i : range(tupleType->getNumElements())) { |
| auto eltAddr = builder.createTupleElementAddr(loc, bufferAccess, i); |
| emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(), |
| eltAddr, loc); |
| } |
| return; |
| } |
| case VectorSpace::Kind::Function: { |
| llvm_unreachable( |
| "Unimplemented: Emit thunks for abstracting zero initialization"); |
| } |
| } |
| } |
| |
| SILValue emitZeroDirect(CanType type, SILLocation loc) { |
| auto diffBuilder = getDifferentialBuilder(); |
| auto silType = getModule().Types.getLoweredLoadableType( |
| type, ResilienceExpansion::Minimal, getModule()); |
| auto *buffer = diffBuilder.createAllocStack(loc, silType); |
| emitZeroIndirect(type, buffer, loc); |
| auto loaded = diffBuilder.emitLoadValueOperation( |
| loc, buffer, LoadOwnershipQualifier::Take); |
| diffBuilder.createDeallocStack(loc, buffer); |
| return loaded; |
| } |
| |
| SILValue materializeTangentDirect(AdjointValue val, SILLocation loc) { |
| assert(val.getType().isObject()); |
| LLVM_DEBUG(getADDebugStream() |
| << "Materializing tangents for " << val << '\n'); |
| switch (val.getKind()) { |
| case AdjointValueKind::Zero: { |
| auto zeroVal = emitZeroDirect(val.getSwiftType(), loc); |
| return zeroVal; |
| } |
| case AdjointValueKind::Aggregate: |
| llvm_unreachable( |
| "Tuples and structs are not supported in forward mode yet."); |
| case AdjointValueKind::Concrete: |
| return val.getConcreteValue(); |
| } |
| } |
| |
| SILValue materializeTangent(AdjointValue val, SILLocation loc) { |
| if (val.isConcrete()) { |
| LLVM_DEBUG(getADDebugStream() |
| << "Materializing tangent: Value is concrete.\n"); |
| return val.getConcreteValue(); |
| } |
| LLVM_DEBUG(getADDebugStream() << "Materializing tangent: Value is " |
| "non-concrete. Materializing directly.\n"); |
| return materializeTangentDirect(val, loc); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Tangent buffer mapping |
| //--------------------------------------------------------------------------// |
| |
| void setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer, |
| SILValue tangentBuffer) { |
| assert(originalBuffer->getType().isAddress()); |
| auto insertion = |
| bufferMap.try_emplace({origBB, originalBuffer}, tangentBuffer); |
| assert(insertion.second && "tangent buffer already exists."); |
| (void)insertion; |
| } |
| |
| SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) { |
| assert(originalBuffer->getType().isAddress()); |
| assert(originalBuffer->getFunction() == original); |
| auto insertion = bufferMap.try_emplace({origBB, originalBuffer}, |
| SILValue()); |
| assert(!insertion.second && "tangent buffer should already exist"); |
| return insertion.first->getSecond(); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Differential type calculations |
| //--------------------------------------------------------------------------// |
| |
| /// Substitutes all replacement types of the given substitution map using the |
| /// tangent function's substitution map. |
| SubstitutionMap remapSubstitutionMapInDifferential(SubstitutionMap substMap) { |
| return substMap.subst(getDifferential().getForwardingSubstitutionMap()); |
| } |
| |
| /// Remap any archetypes into the differential function's context. |
| Type remapTypeInDifferential(Type ty) { |
| if (ty->hasArchetype()) |
| return getDifferential().mapTypeIntoContext(ty->mapTypeOutOfContext()); |
| return getDifferential().mapTypeIntoContext(ty); |
| } |
| |
| /// Remap any archetypes into the differential function's context. |
| SILType remapSILTypeInDifferential(SILType ty) { |
| if (ty.hasArchetype()) |
| return getDifferential().mapTypeIntoContext(ty.mapTypeOutOfContext()); |
| return getDifferential().mapTypeIntoContext(ty); |
| } |
| |
| /// Find the tangent space of a given canonical type. |
| Optional<VectorSpace> getTangentSpace(CanType type) { |
| return type->getAutoDiffAssociatedTangentSpace( |
| LookUpConformanceInModule(getModule().getSwiftModule())); |
| } |
| |
| /// Assuming the given type conforms to `Differentiable` after remapping, |
| /// returns the associated tangent space SIL type. |
| SILType getRemappedTangentType(SILType type) { |
| return SILType::getPrimitiveType( |
| getTangentSpace(remapSILTypeInDifferential(type).getASTType()) |
| ->getCanonicalType(), |
| type.getCategory()); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Tangent value mapping |
| //--------------------------------------------------------------------------// |
| |
| /// Get the tangent for an original value. The given value must be in the |
| /// original function. |
| /// |
| /// This method first tries to find an entry in `tangentValueMap`. If an entry |
| /// doesn't exist, create a zero tangent. |
| AdjointValue getTangentValue(SILValue originalValue) { |
| assert(originalValue->getType().isObject()); |
| assert(originalValue->getFunction() == original); |
| auto insertion = tangentValueMap.try_emplace( |
| originalValue, makeZeroTangentValue( |
| getRemappedTangentType(originalValue->getType()))); |
| return insertion.first->getSecond(); |
| } |
| |
| /// Map the tangent value to the given original value. |
| void setTangentValue(SILBasicBlock *origBB, SILValue originalValue, |
| AdjointValue newTangentValue) { |
| if (auto *defInst = originalValue->getDefiningInstruction()) { |
| bool isTupleTypedApplyResult = |
| isa<ApplyInst>(defInst) && originalValue->getType().is<TupleType>(); |
| assert(!isTupleTypedApplyResult && |
| "Should not set tangent value for tuple-typed result from `apply` " |
| "instruction; use `destructure_tuple` on `apply` result and set " |
| "tangent value for `destructure_tuple` results instead."); |
| } |
| assert(originalValue->getType().isObject()); |
| assert(newTangentValue.getType().isObject()); |
| assert(originalValue->getFunction() == original); |
| LLVM_DEBUG(getADDebugStream() << "Adding tangent for " << originalValue); |
| // The tangent value must be in the tangent space. |
| assert(newTangentValue.getType() == |
| getRemappedTangentType(originalValue->getType())); |
| auto insertion = |
| tangentValueMap.try_emplace(originalValue, newTangentValue); |
| auto inserted = insertion.second; |
| assert(inserted && "The tangent value should not already exist."); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Tangent emission helpers |
| //--------------------------------------------------------------------------// |
| public: |
/// Defines `visit<KIND>Inst`, which clones the instruction through the base
/// cloner and then calls `emitTangentFor<KIND>Inst` when `differentialInfo`
/// reports that the instruction should be differentiated. The macro ends with
/// the signature of `emitTangentFor<KIND>Inst`, so the use site supplies that
/// method's body; `PARAM` names its instruction parameter.
#define CLONE_AND_EMIT_TANGENT(KIND, PARAM) \
  void visit##KIND##Inst(KIND##Inst *inst) { \
    TypeSubstCloner::visit##KIND##Inst(inst); \
    if (differentialInfo.shouldDifferentiateInstruction(inst)) \
      emitTangentFor##KIND##Inst(inst); \
  } \
  void emitTangentFor##KIND##Inst(KIND##Inst *(PARAM))
| |
| CLONE_AND_EMIT_TANGENT(BeginBorrow, bbi) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto loc = bbi->getLoc(); |
| auto tanVal = materializeTangent(getTangentValue(bbi->getOperand()), loc); |
| auto tanValBorrow = diffBuilder.emitBeginBorrowOperation(loc, tanVal); |
| setTangentValue(bbi->getParent(), bbi, |
| makeConcreteTangentValue(tanValBorrow)); |
| } |
| |
| CLONE_AND_EMIT_TANGENT(EndBorrow, ebi) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto loc = ebi->getLoc(); |
| auto tanVal = materializeTangent(getTangentValue(ebi->getOperand()), loc); |
| diffBuilder.emitEndBorrowOperation(loc, tanVal); |
| } |
| |
| CLONE_AND_EMIT_TANGENT(DestroyValue, dvi) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto loc = dvi->getLoc(); |
| auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc); |
| diffBuilder.emitDestroyValue(loc, tanVal); |
| } |
| |
| CLONE_AND_EMIT_TANGENT(CopyValue, cvi) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto tan = getTangentValue(cvi->getOperand()); |
| auto tanVal = materializeTangent(tan, cvi->getLoc()); |
| auto tanValCopy = diffBuilder.emitCopyValueOperation(cvi->getLoc(), tanVal); |
| setTangentValue(cvi->getParent(), cvi, |
| makeConcreteTangentValue(tanValCopy)); |
| } |
| |
| /// Handle `load` instruction. |
| /// Original: y = load x |
| /// Tangent: tan[y] = load tan[x] |
| CLONE_AND_EMIT_TANGENT(Load, li) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto *bb = li->getParent(); |
| auto loc = li->getLoc(); |
| auto tanBuf = getTangentBuffer(bb, li->getOperand()); |
| auto tanVal = diffBuilder.emitLoadValueOperation( |
| loc, tanBuf, li->getOwnershipQualifier()); |
| setTangentValue(bb, li, makeConcreteTangentValue(tanVal)); |
| } |
| |
| /// Handle `load_borrow` instruction. |
| /// Original: y = load_borrow x |
| /// Tangent: tan[y] = load_borrow tan[x] |
| CLONE_AND_EMIT_TANGENT(LoadBorrow, lbi) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto *bb = lbi->getParent(); |
| auto loc = lbi->getLoc(); |
| auto tanBuf = getTangentBuffer(bb, lbi->getOperand()); |
| auto tanVal = diffBuilder.emitLoadBorrowOperation( |
| loc, tanBuf); |
| setTangentValue(bb, lbi, makeConcreteTangentValue(tanVal)); |
| } |
| |
| /// Handle `store` instruction in the differential. |
| /// Original: store x to y |
| /// Tangent: store tan[x] to tan[y] |
| CLONE_AND_EMIT_TANGENT(Store, si) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto loc = si->getLoc(); |
| auto tanValSrc = materializeTangent(getTangentValue(si->getSrc()), loc); |
| auto &tanValDest = getTangentBuffer(si->getParent(), si->getDest()); |
| diffBuilder.emitStoreValueOperation( |
| loc, tanValSrc, tanValDest, si->getOwnershipQualifier()); |
| } |
| |
| /// Handle `store_borrow` instruction in the differential. |
| /// Original: store_borrow x to y |
| /// Tangent: store_borrow tan[x] to tan[y] |
| CLONE_AND_EMIT_TANGENT(StoreBorrow, sbi) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto loc = sbi->getLoc(); |
| auto tanValSrc = materializeTangent(getTangentValue(sbi->getSrc()), loc); |
| auto &tanValDest = getTangentBuffer(sbi->getParent(), sbi->getDest()); |
| diffBuilder.createStoreBorrow(loc, tanValSrc, tanValDest); |
| } |
| |
| /// Handle `copy_addr` instruction. |
| /// Original: copy_addr x to y |
| /// Tangent: copy_addr tan[x] to tan[y] |
| CLONE_AND_EMIT_TANGENT(CopyAddr, cai) { |
| auto *diffGenEnv = getDifferential().getGenericEnvironment(); |
| auto diffGenSig = diffGenEnv |
| ? diffGenEnv->getGenericSignature()->getCanonicalSignature() |
| : nullptr; |
| Lowering::GenericContextScope genericContextScope( |
| context.getTypeConverter(), diffGenSig); |
| |
| auto diffBuilder = getDifferentialBuilder(); |
| auto loc = cai->getLoc(); |
| auto *bb = cai->getParent(); |
| auto &tanSrc = getTangentBuffer(bb, cai->getSrc()); |
| auto tanDest = getTangentBuffer(bb, cai->getDest()); |
| |
| diffBuilder.createCopyAddr(loc, tanSrc, tanDest, cai->isTakeOfSrc(), |
| cai->isInitializationOfDest()); |
| } |
| |
| /// Handle `unconditional_checked_cast_addr` instruction. |
| /// Original: unconditional_checked_cast_addr $X in x to $Y in y |
| /// Tangent: unconditional_checked_cast_addr $X.Tan in tan[x] |
| /// to $Y.Tan in tan[y] |
| CLONE_AND_EMIT_TANGENT(UnconditionalCheckedCastAddr, uccai) { |
| auto diffBuilder = getDifferentialBuilder(); |
| auto loc = uccai->getLoc(); |
| auto *bb = uccai->getParent(); |
| auto &tanSrc = getTangentBuffer(bb, uccai->getSrc()); |
| auto tanDest = getTangentBuffer(bb, uccai->getDest()); |
| |
| diffBuilder.createUnconditionalCheckedCastAddr( |
| loc, tanSrc, tanSrc->getType().getASTType(), tanDest, |
| tanDest->getType().getASTType()); |
| } |
| |
| /// Handle `begin_access` instruction (and do differentiability checks). |
| /// Original: y = begin_access x |
| /// Tangent: tan[y] = begin_access tan[x] |
| CLONE_AND_EMIT_TANGENT(BeginAccess, bai) { |
| // Check for non-differentiable writes. |
| if (bai->getAccessKind() == SILAccessKind::Modify) { |
| if (auto *gai = dyn_cast<GlobalAddrInst>(bai->getSource())) { |
| context.emitNondifferentiabilityError(bai, invoker, |
| diag::autodiff_cannot_differentiate_writes_to_global_variables); |
| errorOccurred = true; |
| return; |
| } |
| if (auto *pbi = dyn_cast<ProjectBoxInst>(bai->getSource())) { |
| context.emitNondifferentiabilityError(bai, invoker, |
| diag::autodiff_cannot_differentiate_writes_to_mutable_captures); |
| errorOccurred = true; |
| return; |
| } |
| } |
| |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto *bb = bai->getParent(); |
| |
| auto tanSrc = getTangentBuffer(bb, bai->getSource()); |
| auto *tanDest = diffBuilder.createBeginAccess( |
| bai->getLoc(), tanSrc, bai->getAccessKind(), bai->getEnforcement(), |
| bai->hasNoNestedConflict(), bai->isFromBuiltin()); |
| setTangentBuffer(bb, bai, tanDest); |
| } |
| |
| /// Handle `end_access` instruction. |
| /// Original: begin_access x |
| /// Tangent: end_access tan[x] |
| CLONE_AND_EMIT_TANGENT(EndAccess, eai) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto *bb = eai->getParent(); |
| auto loc = eai->getLoc(); |
| auto tanSrc = getTangentBuffer(bb, eai->getOperand()); |
| diffBuilder.createEndAccess(loc, tanSrc, eai->isAborting()); |
| } |
| |
| /// Handle `alloc_stack` instruction. |
| /// Original: y = alloc_stack $T |
| /// Tangent: tan[y] = alloc_stack $T.Tangent |
| CLONE_AND_EMIT_TANGENT(AllocStack, asi) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto *mappedAllocStackInst = diffBuilder.createAllocStack( |
| asi->getLoc(), getRemappedTangentType(asi->getElementType()), |
| asi->getVarInfo()); |
| bufferMap.try_emplace({asi->getParent(), asi}, |
| mappedAllocStackInst); |
| } |
| |
| /// Handle `dealloc_stack` instruction. |
| /// Original: dealloc_stack x |
| /// Tangent: dealloc_stack tan[x] |
| CLONE_AND_EMIT_TANGENT(DeallocStack, dsi) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto tanBuf = getTangentBuffer(dsi->getParent(), dsi->getOperand()); |
| diffBuilder.createDeallocStack(dsi->getLoc(), tanBuf); |
| } |
| |
| /// Handle `destroy_addr` instruction. |
| /// Original: destroy_addr x |
| /// Tangent: destroy_addr tan[x] |
| CLONE_AND_EMIT_TANGENT(DestroyAddr, dai) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto tanBuf = getTangentBuffer(dai->getParent(), dai->getOperand()); |
| diffBuilder.createDestroyAddr(dai->getLoc(), tanBuf); |
| } |
| |
| /// Handle `struct` instruction. |
| /// Original: y = struct $T (x0, x1, x2, ...) |
| /// Tangent: tan[y] = struct $T.Tangent (tan[x0], tan[x1], tan[x2], ...) |
| CLONE_AND_EMIT_TANGENT(Struct, si) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| SmallVector<SILValue, 4> tangentElements; |
| for (auto elem : si->getElements()) |
| tangentElements.push_back(getTangentValue(elem).getConcreteValue()); |
| auto tanExtract = diffBuilder.createStruct( |
| si->getLoc(), getRemappedTangentType(si->getType()), tangentElements); |
| setTangentValue(si->getParent(), si, makeConcreteTangentValue(tanExtract)); |
| } |
| |
| /// Handle `struct_extract` instruction. |
| /// Original: y = struct_extract x, #field |
| /// Tangent: tan[y] = struct_extract tan[x], #field' |
| /// ^~~~~~~ |
| /// field in tangent space corresponding to #field |
| CLONE_AND_EMIT_TANGENT(StructExtract, sei) { |
| assert(!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() && |
| "`struct_extract` with `@noDerivative` field should not be " |
| "differentiated; activity analysis should not marked as varied."); |
| |
| auto diffBuilder = getDifferentialBuilder();; |
| auto tangentVectorTy = |
| getRemappedTangentType(sei->getOperand()->getType()); |
| auto *tangentVectorDecl = |
| tangentVectorTy.getStructOrBoundGenericStruct(); |
| |
| // Find the corresponding field in the tangent space. |
| VarDecl *tanField = nullptr; |
| // If the tangent space is the original struct, then field is the same. |
| if (tangentVectorDecl == sei->getStructDecl()) |
| tanField = sei->getField(); |
| // Otherwise, look up the field by name. |
| else { |
| auto tanFieldLookup = |
| tangentVectorDecl->lookupDirect(sei->getField()->getName()); |
| if (tanFieldLookup.empty()) { |
| context.emitNondifferentiabilityError( |
| sei, invoker, |
| diag::autodiff_stored_property_no_corresponding_tangent, |
| sei->getStructDecl()->getNameStr(), |
| sei->getField()->getNameStr()); |
| errorOccurred = true; |
| return; |
| } |
| tanField = cast<VarDecl>(tanFieldLookup.front()); |
| } |
| // Emit tangent `struct_extract`. |
| auto tanStruct = |
| materializeTangent(getTangentValue(sei->getOperand()), sei->getLoc()); |
| auto tangentInst = |
| diffBuilder.createStructExtract(sei->getLoc(), tanStruct, tanField); |
| // Update tangent value mapping for `struct_extract` result. |
| auto tangentResult = makeConcreteTangentValue(tangentInst); |
| setTangentValue(sei->getParent(), sei, tangentResult); |
| } |
| |
| /// Handle `struct_element_addr` instruction. |
| /// Original: y = struct_element_addr x, #field |
| /// Tangent: tan[y] = struct_element_addr tan[x], #field' |
| /// ^~~~~~~ |
| /// field in tangent space corresponding to #field |
| CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) { |
| assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() && |
| "`struct_element_addr` with `@noDerivative` field should not be " |
| "differentiated; activity analysis should not marked as varied."); |
| |
| auto diffBuilder = getDifferentialBuilder(); |
| auto *bb = seai->getParent(); |
| auto tangentVectorTy = |
| getRemappedTangentType(seai->getOperand()->getType()); |
| auto *tangentVectorDecl = |
| tangentVectorTy.getStructOrBoundGenericStruct(); |
| |
| // Find the corresponding field in the tangent space. |
| VarDecl *tanField = nullptr; |
| // If the tangent space is the original struct, then field is the same. |
| if (tangentVectorDecl == seai->getStructDecl()) |
| tanField = seai->getField(); |
| // Otherwise, look up the field by name. |
| else { |
| auto tanFieldLookup = |
| tangentVectorDecl->lookupDirect(seai->getField()->getName()); |
| if (tanFieldLookup.empty()) { |
| context.emitNondifferentiabilityError( |
| seai, invoker, |
| diag::autodiff_stored_property_no_corresponding_tangent, |
| seai->getStructDecl()->getNameStr(), |
| seai->getField()->getNameStr()); |
| errorOccurred = true; |
| return; |
| } |
| tanField = cast<VarDecl>(tanFieldLookup.front()); |
| } |
| |
| // Emit tangent `struct_element_addr`. |
| auto tanOperand = getTangentBuffer(bb, seai->getOperand()); |
| auto tangentInst = diffBuilder.createStructElementAddr( |
| seai->getLoc(), tanOperand, tanField); |
| // Update tangent buffer map for `struct_element_addr`. |
| setTangentBuffer(bb, seai, tangentInst); |
| } |
| |
| /// Handle `tuple` instruction. |
| /// Original: y = tuple (x0, x1, x2, ...) |
| /// Tangent: tan[y] = tuple (tan[x0], tan[x1], tan[x2], ...) |
| CLONE_AND_EMIT_TANGENT(Tuple, ti) { |
| auto diffBuilder = getDifferentialBuilder(); |
| |
| // Get the tangents of all the tuple elements. |
| SmallVector<SILValue, 8> tangentTupleElements; |
| for (auto elem : ti->getElements()) { |
| tangentTupleElements.push_back( |
| materializeTangent(getTangentValue(elem), ti->getLoc())); |
| } |
| |
| // Emit the instruction and add the tangent mapping. |
| auto tanTuple = diffBuilder.createTuple(ti->getLoc(), tangentTupleElements); |
| setTangentValue(ti->getParent(), ti, makeConcreteTangentValue(tanTuple)); |
| } |
| |
| /// Handle `tuple_extract` instruction. |
| /// Original: y = tuple_extract x, <n> |
| /// Tangent: tan[y] = tuple_extract tan[x], <n'> |
| /// ^~~~ |
| /// tuple tangent space index corresponding to n |
| CLONE_AND_EMIT_TANGENT(TupleExtract, tei) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto loc = tei->getLoc(); |
| auto origTupleTy = tei->getOperand()->getType().castTo<TupleType>(); |
| unsigned tanIndex = 0; |
| for (unsigned i : range(tei->getFieldNo())) { |
| if (getTangentSpace( |
| origTupleTy->getElement(i).getType()->getCanonicalType())) |
| ++tanIndex; |
| } |
| auto tanType = getRemappedTangentType(tei->getType()); |
| auto tanSource = materializeTangent( |
| getTangentValue(tei->getOperand()), loc); |
| SILValue tanBuf; |
| // If the tangent buffer of the source does not have a tuple type, then |
| // it must represent a "single element tuple type". Use it directly. |
| if (!tanSource->getType().is<TupleType>()) { |
| setTangentValue(tei->getParent(), tei, |
| makeConcreteTangentValue(tanSource)); |
| } else { |
| tanBuf = diffBuilder.createTupleExtract(loc, tanSource, tanIndex, tanType); |
| bufferMap.try_emplace({tei->getParent(), tei}, tanBuf); |
| } |
| } |
| |
| /// Handle `tuple_element_addr` instruction. |
| /// Original: y = tuple_element_addr x, <n> |
| /// Tangent: tan[y] = tuple_element_addr tan[x], <n'> |
| /// ^~~~ |
| /// tuple tangent space index corresponding to n |
| CLONE_AND_EMIT_TANGENT(TupleElementAddr, teai) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto origTupleTy = teai->getOperand()->getType().castTo<TupleType>(); |
| unsigned tanIndex = 0; |
| for (unsigned i : range(teai->getFieldNo())) { |
| if (getTangentSpace( |
| origTupleTy->getElement(i).getType()->getCanonicalType())) |
| ++tanIndex; |
| } |
| auto tanType = getRemappedTangentType(teai->getType()); |
| auto tanSource = getTangentBuffer(teai->getParent(), teai->getOperand()); |
| SILValue tanBuf; |
| // If the tangent buffer of the source does not have a tuple type, then |
| // it must represent a "single element tuple type". Use it directly. |
| if (!tanSource->getType().is<TupleType>()) { |
| tanBuf = tanSource; |
| } else { |
| tanBuf = diffBuilder.createTupleElementAddr( |
| teai->getLoc(), tanSource, tanIndex, tanType); |
| } |
| bufferMap.try_emplace({teai->getParent(), teai}, tanBuf); |
| } |
| |
| /// Handle `destructure_tuple` instruction. |
| /// Original: (y0, y1, ...) = destructure_tuple x, <n> |
| /// Tangent: (tan[y0], tan[y1], ...) = destructure_tuple tan[x], <n'> |
| /// ^~~~ |
| /// tuple tangent space index corresponding to n |
| CLONE_AND_EMIT_TANGENT(DestructureTuple, dti) { |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto *bb = dti->getParent(); |
| auto loc = dti->getLoc(); |
| |
| SmallVector<SILValue, 2> activeOrigResults; |
| bool hasActiveResult = false; |
| for (auto result : dti->getResults()) { |
| if (activityInfo.isActive(result, getIndices())) { |
| activeOrigResults.push_back(result); |
| hasActiveResult = true; |
| break; |
| } |
| } |
| assert(!activeOrigResults.empty() && |
| "original 'destructure_tuple' should have at least one active " |
| "result"); |
| |
| auto tanTuple = |
| materializeTangent(getTangentValue(dti->getOperand()), loc); |
| auto *tupleElements = diffBuilder.createDestructureTuple(loc, tanTuple); |
| for (auto i : range(tupleElements->getNumResults())) { |
| auto origElem = dti->getResult(i); |
| auto tanElem = tupleElements->getResult(i); |
| setTangentValue(bb, origElem, makeConcreteTangentValue(tanElem)); |
| } |
| } |
| |
| #undef CLONE_AND_EMIT_TANGENT |
| |
| /// Handle `apply` instruction. |
| /// Original: y = apply f(x0, x1, ...) |
| /// Tangent: tan[y] = apply diff_f(tan[x0], tan[x1], ...) |
| void emitTangentForApplyInst(ApplyInst *ai, |
| const SILAutoDiffIndices &actualIndices, |
| CanSILFunctionType originalDifferentialType) { |
| assert(differentialInfo.shouldDifferentiateApplyInst(ai)); |
| auto *bb = ai->getParent(); |
| auto loc = ai->getLoc(); |
| auto &diffBuilder = getDifferentialBuilder(); |
| |
| // Get the differential value. |
| auto *field = differentialInfo.lookUpLinearMapDecl(ai); |
| assert(field); |
| SILValue differential = getDifferentialStructElement(bb, field); |
| auto differentialType = remapSILTypeInDifferential(differential->getType()) |
| .castTo<SILFunctionType>(); |
| |
| // Get the differential arguments. |
| SmallVector<SILValue, 8> diffArgs; |
| |
| for (auto indRes : ai->getIndirectSILResults()) |
| diffArgs.push_back(getTangentBuffer(bb, indRes)); |
| |
| auto paramArgs = ai->getArgumentsWithoutIndirectResults(); |
| // Get the tangent value of the original arguments. |
| for (auto i : indices(paramArgs)) { |
| auto origArg = paramArgs[i]; |
| // If the argument is not active: |
| // - Skip the element, if it is not differentiable. |
| // - Otherwise, add a zero value to that location. |
| if (!activityInfo.isActive(origArg, getIndices())) { |
| auto origCalleeType = ai->getSubstCalleeType(); |
| if (!origCalleeType->isDifferentiable()) |
| continue; |
| auto actualOrigCalleeIndices = |
| origCalleeType->getDifferentiationParameterIndices(); |
| if (actualOrigCalleeIndices->contains(i)) { |
| SILValue tanParam; |
| if (origArg->getType().isObject()) { |
| tanParam = emitZeroDirect( |
| getRemappedTangentType(origArg->getType()).getASTType(), loc); |
| diffArgs.push_back(tanParam); |
| } else { |
| tanParam = diffBuilder.createAllocStack( |
| loc, getRemappedTangentType(origArg->getType())); |
| emitZeroIndirect( |
| getRemappedTangentType(origArg->getType()).getASTType(), tanParam, |
| loc); |
| } |
| } |
| } |
| // Otherwise, if the argument is active, handle the argument normally by |
| // getting its tangent value. |
| else { |
| SILValue tanParam; |
| if (origArg->getType().isObject()) { |
| tanParam = materializeTangent(getTangentValue(origArg), loc); |
| } else { |
| tanParam = getTangentBuffer(ai->getParent(), origArg); |
| } |
| diffArgs.push_back(tanParam); |
| if (errorOccurred) |
| return; |
| } |
| } |
| |
| // If callee differential was reabstracted in JVP, reabstract the callee |
| // differential. |
| if (!differentialType->isEqual(originalDifferentialType)) { |
| SILOptFunctionBuilder fb(context.getTransform()); |
| auto *thunk = getOrCreateReabstractionThunk( |
| fb, context.getModule(), loc, &getDifferential(), |
| differentialType, originalDifferentialType); |
| auto *thunkRef = diffBuilder.createFunctionRef(loc, thunk); |
| differential = diffBuilder.createPartialApply( |
| loc, thunkRef, |
| remapSubstitutionMapInDifferential( |
| thunk->getForwardingSubstitutionMap()), |
| {differential}, differentialType->getCalleeConvention()); |
| } |
| |
| // Call the differential. |
| auto *differentialCall = diffBuilder.createApply( |
| loc, differential, SubstitutionMap(), diffArgs, |
| /*isNonThrowing*/ false); |
| diffBuilder.emitDestroyValueOperation(loc, differential); |
| assert(differentialCall->getNumResults() == 1 && |
| "Expected differential to return one result"); |
| |
| // Get the original results of the `apply` instructions. |
| SmallVector<SILValue, 8> origDirectResults; |
| forEachApplyDirectResult(ai, [&](SILValue directResult) { |
| origDirectResults.push_back(directResult); |
| }); |
| SmallVector<SILValue, 8> origAllResults; |
| collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults); |
| auto origResult = origAllResults[actualIndices.source]; |
| |
| // Get the differential results of the `apply` instructions. |
| SmallVector<SILValue, 8> differentialDirectResults; |
| forEachApplyDirectResult(differentialCall, [&](SILValue directResult) { |
| differentialDirectResults.push_back(directResult); |
| }); |
| SmallVector<SILValue, 8> differentialAllResults; |
| collectAllActualResultsInTypeOrder(differentialCall, |
| differentialDirectResults, |
| differentialAllResults); |
| auto differentialResult = differentialAllResults.front(); |
| |
| // Add tangent for original result. |
| if (origResult->getType().isObject()) { |
| if (!origResult->getType().is<TupleType>()) { |
| setTangentValue(bb, origResult, |
| makeConcreteTangentValue(differentialResult)); |
| } else if (auto *dti = getSingleDestructureTupleUser(ai)) { |
| bool notSetValue = true; |
| for (auto result : dti->getResults()) { |
| if (activityInfo.isActive(result, getIndices())) { |
| assert(notSetValue && |
| "This was incorrectly set, should only have one active " |
| "result from the tuple."); |
| notSetValue = false; |
| setTangentValue(bb, result, |
| makeConcreteTangentValue(differentialResult)); |
| } |
| } |
| } |
| } |
| } |
| |
| /// Generate a `return` instruction in the current differential basic block. |
| void emitReturnInstForDifferential() { |
| auto &differential = getDifferential(); |
| auto diffLoc = differential.getLocation(); |
| auto &diffBuilder = getDifferentialBuilder(); |
| |
| SmallVector<SILValue, 2> activeResults; |
| |
| // This vector will contain all the materialized return elements. |
| SmallVector<SILValue, 8> retElts; |
| SmallVector<SILValue, 2> originalResults; |
| collectAllDirectResultsInTypeOrder(*original, originalResults); |
| |
| // Materializes the return element corresponding to the result |
| // `resultIndex` into the `retElts` vector. |
| auto addActiveResult = [&](unsigned resultIndex) -> void { |
| auto origResult = originalResults[resultIndex]; |
| assert(origResult->getType().isObject() && |
| "Should only be handling direct results for 'return' " |
| "instruction."); |
| if (activityInfo.isActive(origResult, getIndices())) { |
| activeResults.push_back(origResult); |
| } |
| }; |
| // Create an array of the direct tangent values of the original results. |
| for (auto i : range(originalResults.size())) |
| addActiveResult(i); |
| assert(activeResults.size() <= 1); |
| |
| if (activeResults.empty() && !originalResults.empty()) { |
| // Create zero tangent value for direct result. |
| auto origResult = originalResults[getIndices().source]; |
| assert(origResult->getType().isObject() && |
| "Should only be handling direct results for 'return' " |
| "instruction."); |
| auto zeroType = origResult->getType().getASTType(); |
| auto zero = |
| emitZeroDirect(getTangentSpace(zeroType)->getCanonicalType(), |
| diffLoc); |
| retElts.push_back(zero); |
| } else if (!activeResults.empty()) { |
| auto diffVal = getTangentValue(activeResults.front()); |
| auto val = materializeTangent(diffVal, diffLoc); |
| retElts.push_back(val); |
| } |
| |
| diffBuilder.createReturn( |
| diffLoc, joinElements(retElts, diffBuilder, diffLoc)); |
| } |
| |
| private: |
| |
| /// Set up the differential function. This includes: |
| /// - Creating all differential blocks. |
| /// - Creating differential entry block arguments based on the function type. |
| /// - Creating tangent value mapping for original/differential parameters. |
| /// - Checking for unvaried result and emitting related warnings. |
| void prepareForDifferentialGeneration() { |
| // Create differential blocks and arguments. |
| auto *diffGenEnv = getDifferential().getGenericEnvironment(); |
| auto diffGenSig = diffGenEnv |
| ? diffGenEnv->getGenericSignature()->getCanonicalSignature() |
| : nullptr; |
| auto &differential = getDifferential(); |
| auto *origEntry = original->getEntryBlock(); |
| for (auto &origBB : *original) { |
| auto *diffBB = differential.createBasicBlock(); |
| diffBBMap.insert({&origBB, diffBB}); |
| { |
| Lowering::GenericContextScope genericContextScope( |
| context.getTypeConverter(), diffGenSig); |
| auto diffStructLoweredType = remapSILTypeInDifferential( |
| differentialInfo.getLinearMapStructLoweredType(&origBB)); |
| |
| // If the BB is the original entry, then the differential block that we |
| // just created must be the differential function's entry. Create |
| // differential entry arguments and continue. |
| if (&origBB == origEntry) { |
| assert(diffBB->isEntry()); |
| createEntryArguments(&differential); |
| auto *lastArg = diffBB->getArguments().back(); |
| assert(lastArg->getType() == diffStructLoweredType); |
| differentialStructArguments[&origBB] = lastArg; |
| } |
| } |
| |
| LLVM_DEBUG({ |
| auto &s = getADDebugStream() |
| << "Original bb" + std::to_string(origBB.getDebugID()) |
| << ": To differentiate or not to differentiate?\n"; |
| for (auto &inst : origBB) { |
| s << (differentialInfo.shouldDifferentiateInstruction(&inst) |
| ? "[∂] " : "[ ] ") |
| << inst; |
| } |
| }); |
| } |
| |
| assert(diffBBMap.size() == 1 && |
| "Can only currently handle single basic block functions"); |
| |
| // The differential function has type: |
| // (arg0', ..., argn', entry_df_struct) -> result'. |
| auto diffParamArgs = |
| differential.getArgumentsWithoutIndirectResults().drop_back(); |
| assert(diffParamArgs.size() == |
| attr->getIndices().parameters->getNumIndices()); |
| auto origParamArgs = original->getArgumentsWithoutIndirectResults(); |
| |
| // TODO(TF-788): Re-enable non-varied result warning. |
| /* |
| // Check if result is not varied. |
| SmallVector<SILValue, 8> origFormalResults; |
| collectAllFormalResultsInTypeOrder(*original, origFormalResults); |
| auto origResult = origFormalResults[getIndices().source]; |
| // Emit warning if original result is not varied, because it will always |
| // have a zero derivative. |
| if (!activityInfo.isVaried(origResult, getIndices().parameters)) { |
| // Emit fixit if original result has a valid source location. |
| auto startLoc = origResult.getLoc().getStartSourceLoc(); |
| auto endLoc = origResult.getLoc().getEndSourceLoc(); |
| if (startLoc.isValid() && endLoc.isValid()) { |
| context.diagnose(startLoc, diag::autodiff_nonvaried_result_fixit) |
| .fixItInsert(startLoc, "withoutDerivative(at:") |
| .fixItInsertAfter(endLoc, ")"); |
| } |
| } |
| */ |
| |
| // Initialize tangent mapping for parameters. |
| auto diffParamsIt = getIndices().parameters->begin(); |
| for (auto index : range(diffParamArgs.size())) { |
| auto *diffArg = diffParamArgs[index]; |
| auto *origArg = origParamArgs[*diffParamsIt]; |
| diffParamsIt++; |
| if (diffArg->getType().isAddress()) { |
| setTangentBuffer(origEntry, origArg, diffArg); |
| } else { |
| setTangentValue( |
| origEntry, origArg, makeConcreteTangentValue(diffArg)); |
| } |
| LLVM_DEBUG(getADDebugStream() |
| << "Assigned parameter " << *diffArg |
| << " as the tangent of original result " << *origArg); |
| } |
| |
| // Initialize tangent mapping for indirect results. |
| auto origIndResults = original->getIndirectResults(); |
| auto diffIndResults = differential.getIndirectResults(); |
| assert(origIndResults.size() == diffIndResults.size()); |
| |
| for (auto &origBB : *original) |
| for (auto i : indices(diffIndResults)) |
| setTangentBuffer(&origBB, origIndResults[i], diffIndResults[i]); |
| } |
| |
| public: |
| explicit JVPEmitter(ADContext &context, SILFunction *original, |
| SILDifferentiableAttr *attr, SILFunction *jvp, |
| DifferentiationInvoker invoker) |
| : TypeSubstCloner(*jvp, *original, getSubstitutionMap(original, jvp)), |
| context(context), original(original), attr(attr), jvp(jvp), |
| invoker(invoker), activityInfo(getActivityInfo( |
| context, original, attr->getIndices(), jvp)), |
| differentialInfo(context, AutoDiffLinearMapKind::Differential, original, |
| jvp, attr->getIndices(), activityInfo), |
| differentialBuilder(SILBuilder(*createEmptyDifferential( |
| context, original, attr, &differentialInfo))), |
| diffLocalAllocBuilder(getDifferential()) { |
| // Create empty differential function. |
| context.getGeneratedFunctions().push_back(&getDifferential()); |
| } |
| |
| static SILFunction *createEmptyDifferential(ADContext &context, |
| SILFunction *original, |
| SILDifferentiableAttr *attr, |
| LinearMapInfo *linearMapInfo) { |
| auto &module = context.getModule(); |
| auto origTy = original->getLoweredFunctionType(); |
| auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule()); |
| |
| // RAII that pushes the original function's generic signature to |
| // `module.Types` so that calls to `module.Types.getTypeLowering()` below |
| // will know the original function's generic parameter types. |
| Lowering::GenericContextScope genericContextScope( |
| module.Types, origTy->getGenericSignature()); |
| |
| // Parameters of the differential are: |
| // - the tangent values of the wrt parameters. |
| // - the differential struct for the original entry. |
| // Result of the differential is in the tangent space of the original |
| // result. |
| SmallVector<SILParameterInfo, 8> dfParams; |
| SmallVector<SILResultInfo, 8> dfResults; |
| auto origParams = origTy->getParameters(); |
| auto indices = attr->getIndices(); |
| |
| // Add differential results. |
| auto origResInfo = origTy->getResults()[indices.source]; |
| dfResults.push_back( |
| SILResultInfo(origResInfo.getType() |
| ->getAutoDiffAssociatedTangentSpace(lookupConformance) |
| ->getCanonicalType(), |
| origResInfo.getConvention())); |
| |
| // Add differential parameters for the requested wrt parameters. |
| for (auto i : indices.parameters->getIndices()) { |
| auto origParam = origParams[i]; |
| dfParams.push_back(SILParameterInfo( |
| origParam.getType() |
| ->getAutoDiffAssociatedTangentSpace(lookupConformance) |
| ->getCanonicalType(), |
| origParam.getConvention())); |
| } |
| |
| // Accept a differential struct in the differential parameter list. This is |
| // the returned differential's closure context. |
| auto *origEntry = original->getEntryBlock(); |
| auto *dfStruct = linearMapInfo->getLinearMapStruct(origEntry); |
| auto dfStructType = |
| dfStruct->getDeclaredInterfaceType()->getCanonicalType(); |
| dfParams.push_back({dfStructType, ParameterConvention::Direct_Owned}); |
| |
| Mangle::ASTMangler mangler; |
| auto diffName = original->getASTContext().getIdentifier( |
| mangler.mangleAutoDiffLinearMapHelper( |
| original->getName(), AutoDiffLinearMapKind::Differential, |
| indices)).str(); |
| auto diffGenericSig = getDerivativeGenericSignature(attr, original); |
| auto *diffGenericEnv = |
| diffGenericSig ? diffGenericSig->getGenericEnvironment() : nullptr; |
| auto diffType = SILFunctionType::get( |
| diffGenericSig, origTy->getExtInfo(), origTy->getCoroutineKind(), |
| origTy->getCalleeConvention(), dfParams, {}, dfResults, None, |
| original->getASTContext()); |
| |
| SILOptFunctionBuilder fb(context.getTransform()); |
| // The generated tangent linkage is set to Hidden because generated tangent |
| // are never called cross-module. |
| auto linkage = SILLinkage::Hidden; |
| auto *differential = fb.createFunction( |
| linkage, diffName, diffType, diffGenericEnv, original->getLocation(), |
| original->isBare(), IsNotTransparent, original->isSerialized(), |
| original->isDynamicallyReplaceable()); |
| differential->setDebugScope( |
| new (module) SILDebugScope(original->getLocation(), differential)); |
| |
| return differential; |
| } |
| |
| /// Run JVP generation. Returns true on error. |
| bool run() { |
| LLVM_DEBUG(getADDebugStream() |
| << "Cloning original @" << original->getName() |
| << " to jvp @" << jvp->getName() << '\n'); |
| // Create JVP and differential entry and arguments. |
| auto *entry = jvp->createBasicBlock(); |
| createEntryArguments(jvp); |
| prepareForDifferentialGeneration(); |
| // Clone. |
| SmallVector<SILValue, 4> entryArgs(entry->getArguments().begin(), |
| entry->getArguments().end()); |
| cloneFunctionBody(original, entry, entryArgs); |
| emitReturnInstForDifferential(); |
| // If errors occurred, back out. |
| if (errorOccurred) |
| return true; |
| LLVM_DEBUG(getADDebugStream() << "Generated JVP for " |
| << original->getName() << ":\n" << *jvp); |
| LLVM_DEBUG(getADDebugStream() << "Generated differential for " |
| << original->getName() << ":\n" << getDifferential()); |
| return errorOccurred; |
| } |
| |
| void postProcess(SILInstruction *orig, SILInstruction *cloned) { |
| if (errorOccurred) |
| return; |
| SILClonerWithScopes::postProcess(orig, cloned); |
| } |
| |
| /// Remap original basic blocks. |
| SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) { |
| auto *jvpBB = BBMap[bb]; |
| return jvpBB; |
| } |
| |
| /// General visitor for all instructions. If any error is emitted by previous |
| /// visits, bail out. |
| void visit(SILInstruction *inst) { |
| auto diffBuilder = getDifferentialBuilder(); |
| if (errorOccurred) |
| return; |
| if (differentialInfo.shouldDifferentiateInstruction(inst)) { |
| LLVM_DEBUG(getADDebugStream() << "JVPEmitter visited:\n[ORIG]" << *inst); |
| #ifndef NDEBUG |
| auto beforeInsertion = std::prev(diffBuilder.getInsertionPoint()); |
| #endif |
| TypeSubstCloner::visit(inst); |
| LLVM_DEBUG({ |
| auto &s = llvm::dbgs() << "[TAN] Emitted in differential:\n"; |
| auto afterInsertion = diffBuilder.getInsertionPoint(); |
| for (auto it = ++beforeInsertion; it != afterInsertion; ++it) |
| s << *it; |
| }); |
| } else { |
| TypeSubstCloner::visit(inst); |
| } |
| } |
| |
| void visitSILInstruction(SILInstruction *inst) { |
| context.emitNondifferentiabilityError(inst, invoker, |
| diag::autodiff_expression_not_differentiable_note); |
| errorOccurred = true; |
| } |
| |
| void visitInstructionsInBlock(SILBasicBlock *bb) { |
| // Destructure the differential struct to get the elements. |
| auto &diffBuilder = getDifferentialBuilder(); |
| auto diffLoc = getDifferential().getLocation(); |
| auto *diffBB = diffBBMap.lookup(bb); |
| auto *mainDifferentialStruct = diffBB->getArguments().back(); |
| diffBuilder.setInsertionPoint(diffBB); |
| auto *dsi = diffBuilder.createDestructureStruct( |
| diffLoc, mainDifferentialStruct); |
| initializeDifferentialStructElements(bb, dsi->getResults()); |
| TypeSubstCloner::visitInstructionsInBlock(bb); |
| } |
| |
| /// Visits an `apply` instruction. |
| /// |
| /// If an `apply` has active results or active inout parameters, replace it |
| /// with an `apply` of its JVP. Otherwise, or when it is the array literal |
| /// initialization intrinsic, the instruction is cloned unchanged. |
| /// |
| /// For a differentiated `apply` the steps are: |
| /// 1. Diagnose unsupported input: active `inout` arguments and more than one |
| /// active result. |
| /// 2. Obtain the JVP value, either by extracting it from a `@differentiable` |
| /// callee or by wrapping the callee in a new `differentiable_function` |
| /// instruction (recorded in the context) and extracting from that. |
| /// 3. Apply the JVP to the remapped original arguments and map `ai` to the |
| /// original results (all JVP results but the last). |
| /// 4. Record the last JVP result, the differential, for the parent block, |
| /// reabstracting it through a thunk when its type differs from the lowered |
| /// type of the linear map declaration. |
| /// 5. Emit the tangent code via `emitTangentForApplyInst`. |
| /// |
| /// Whenever a diagnostic is emitted, `errorOccurred` is set and the method |
| /// returns early. |
| void visitApplyInst(ApplyInst *ai) { |
| // If the function should not be differentiated or it's the array literal |
| // initialization intrinsic, just do standard cloning. |
| if (!differentialInfo.shouldDifferentiateApplyInst(ai) || |
| isArrayLiteralIntrinsic(ai)) { |
| LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n'); |
| TypeSubstCloner::visitApplyInst(ai); |
| return; |
| } |
| |
| // Check and reject functions with active inout arguments. It's not yet |
| // supported. |
| auto paramInfos = ai->getSubstCalleeConv().getParameters(); |
| auto paramArgs = ai->getArgumentsWithoutIndirectResults(); |
| for (unsigned i : swift::indices(paramInfos)) { |
| if (paramInfos[i].isIndirectInOut() && |
| activityInfo.isActive(paramArgs[i], getIndices())) { |
| context.emitNondifferentiabilityError(ai, invoker, |
| diag::autodiff_cannot_differentiate_through_inout_arguments); |
| errorOccurred = true; |
| return; |
| } |
| } |
| |
| LLVM_DEBUG(getADDebugStream() << "JVP-transforming:\n" << *ai << '\n'); |
| |
| // Get the minimal parameter and result indices required for differentiating |
| // this `apply`. |
| SmallVector<SILValue, 4> allResults; |
| SmallVector<unsigned, 8> activeParamIndices; |
| SmallVector<unsigned, 8> activeResultIndices; |
| collectMinimalIndicesForFunctionCall(ai, getIndices(), activityInfo, |
| allResults, activeParamIndices, |
| activeResultIndices); |
| assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); |
| assert(!activeResultIndices.empty() && "Result indices cannot be empty"); |
| LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={"; |
| interleave(activeParamIndices.begin(), activeParamIndices.end(), |
| [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); |
| s << "}, results={"; interleave( |
| activeResultIndices.begin(), activeResultIndices.end(), |
| [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); |
| s << "}\n";); |
| // FIXME: We don't support multiple active results yet. |
| if (activeResultIndices.size() > 1) { |
| context.emitNondifferentiabilityError( |
| ai, invoker, diag::autodiff_expression_not_differentiable_note); |
| errorOccurred = true; |
| return; |
| } |
| // Form expected indices, assuming there's only one result. |
| SILAutoDiffIndices indices( |
| activeResultIndices.front(), |
| IndexSubset::get( |
| getASTContext(), ai->getArgumentsWithoutIndirectResults().size(), |
| activeParamIndices)); |
| |
| // Emit the JVP. |
| auto loc = ai->getLoc(); |
| auto &builder = getBuilder(); |
| auto original = getOpValue(ai->getCallee()); |
| SILValue jvpValue; |
| // If the callee is a `@differentiable` function, just extract its JVP. |
| auto originalFnTy = original->getType().castTo<SILFunctionType>(); |
| if (originalFnTy->isDifferentiable()) { |
| // Every parameter we differentiate with respect to must be one of the |
| // callee's differentiation parameters. |
| auto paramIndices = originalFnTy->getDifferentiationParameterIndices(); |
| for (auto i : indices.parameters->getIndices()) { |
| if (!paramIndices->contains(i)) { |
| context.emitNondifferentiabilityError(original, invoker, |
| diag::autodiff_function_nondiff_parameter_not_differentiable); |
| errorOccurred = true; |
| return; |
| } |
| } |
| auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, original); |
| jvpValue = builder.createDifferentiableFunctionExtract( |
| loc, NormalDifferentiableFunctionTypeComponent::JVP, |
| borrowedDiffFunc); |
| jvpValue = builder.emitCopyValueOperation(loc, jvpValue); |
| } |
| |
| // If JVP has not yet been found, emit a `differentiable_function` |
| // instruction on the remapped original function operand and |
| // a `differentiable_function_extract` instruction to get the JVP. |
| // The `differentiable_function` instruction will be canonicalized during |
| // the transform main loop. |
| if (!jvpValue) { |
| // FIXME: Handle indirect differentiation invokers. This may require some |
| // redesign: currently, each original function + attribute pair is mapped |
| // only to one invoker. |
| /* |
| DifferentiationInvoker indirect(ai, attr); |
| auto insertion = |
| context.getInvokers().try_emplace({this->original, attr}, indirect); |
| auto &invoker = insertion.first->getSecond(); |
| invoker = indirect; |
| */ |
| |
| // If the original `apply` instruction has a substitution map, then the |
| // applied function is specialized. |
| // In the JVP, specialization is also necessary for parity. The original |
| // function operand is specialized with a remapped version of same |
| // substitution map using an argument-less `partial_apply`. |
| if (ai->getSubstitutionMap().empty()) { |
| original = builder.emitCopyValueOperation(loc, original); |
| } else { |
| auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap()); |
| auto jvpPartialApply = builder.createPartialApply( |
| loc, original, substMap, {}, |
| ParameterConvention::Direct_Guaranteed); |
| original = jvpPartialApply; |
| } |
| |
| // Check and diagnose non-differentiable original function type. Returns |
| // true (after setting `errorOccurred`) if a diagnostic was emitted. |
| auto diagnoseNondifferentiableOriginalFunctionType = |
| [&](CanSILFunctionType origFnTy) { |
| // Check and diagnose non-differentiable arguments. |
| for (unsigned paramIndex : range(origFnTy->getNumParameters())) { |
| if (indices.isWrtParameter(paramIndex) && |
| !origFnTy->getParameters()[paramIndex] |
| .getSILStorageType() |
| .isDifferentiable(getModule())) { |
| context.emitNondifferentiabilityError( |
| ai->getArgumentsWithoutIndirectResults()[paramIndex], invoker, |
| diag::autodiff_nondifferentiable_argument); |
| errorOccurred = true; |
| return true; |
| } |
| } |
| // Check and diagnose non-differentiable results. |
| if (!origFnTy->getResults()[indices.source] |
| .getSILStorageType() |
| .isDifferentiable(getModule())) { |
| context.emitNondifferentiabilityError( |
| original, invoker, diag::autodiff_nondifferentiable_result); |
| errorOccurred = true; |
| return true; |
| } |
| return false; |
| }; |
| if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) |
| return; |
| |
| auto *diffFuncInst = context.createDifferentiableFunction( |
| builder, loc, indices.parameters, original); |
| |
| // Record the `differentiable_function` instruction. |
| context.getDifferentiableFunctionInsts().push_back(diffFuncInst); |
| // TODO(TF-689): Make `differentiable_function` store result indices and |
| // remove `ADContext::resultIndices`. |
| context.getResultIndices()[diffFuncInst] = activeResultIndices.front(); |
| |
| auto borrowedADFunc = |
| builder.emitBeginBorrowOperation(loc, diffFuncInst); |
| auto extractedJVP = builder.createDifferentiableFunctionExtract( |
| loc, NormalDifferentiableFunctionTypeComponent::JVP, |
| borrowedADFunc); |
| jvpValue = builder.emitCopyValueOperation(loc, extractedJVP); |
| builder.emitEndBorrowOperation(loc, borrowedADFunc); |
| builder.emitDestroyValueOperation(loc, diffFuncInst); |
| } |
| |
| // Call the JVP using the original parameters. |
| SmallVector<SILValue, 8> jvpArgs; |
| auto jvpFnTy = getOpType(jvpValue->getType()).castTo<SILFunctionType>(); |
| auto numJVPArgs = |
| jvpFnTy->getNumParameters() + jvpFnTy->getNumIndirectFormalResults(); |
| jvpArgs.reserve(numJVPArgs); |
| // Collect substituted arguments. |
| for (auto origArg : ai->getArguments()) |
| jvpArgs.push_back(getOpValue(origArg)); |
| assert(jvpArgs.size() == numJVPArgs); |
| // Apply the JVP. |
| // The JVP should be specialized, so no substitution map is necessary. |
| auto *jvpCall = builder.createApply(loc, jvpValue, SubstitutionMap(), |
| jvpArgs, ai->isNonThrowing()); |
| LLVM_DEBUG(getADDebugStream() << "Applied jvp function\n" << *jvpCall); |
| |
| // Release the differentiable function. |
| builder.emitDestroyValueOperation(loc, jvpValue); |
| |
| // Get the JVP results (original results and differential). The |
| // differential is the last element; everything before it is the original |
| // result. |
| SmallVector<SILValue, 8> jvpDirectResults; |
| extractAllElements(jvpCall, builder, jvpDirectResults); |
| auto originalDirectResults = |
| ArrayRef<SILValue>(jvpDirectResults).drop_back(1); |
| auto originalDirectResult = |
| joinElements(originalDirectResults, builder, jvpCall->getLoc()); |
| |
| mapValue(ai, originalDirectResult); |
| |
| // Some instructions that produce the callee may have been cloned. |
| // If the original callee did not have any users beyond this `apply`, |
| // recursively kill the cloned callee. |
| if (auto *origCallee = cast_or_null<SingleValueInstruction>( |
| ai->getCallee()->getDefiningInstruction())) |
| if (origCallee->hasOneUse()) |
| recursivelyDeleteTriviallyDeadInstructions( |
| getOpValue(origCallee)->getDefiningInstruction()); |
| |
| // Add the differential function for when we create the struct we partially |
| // apply to the differential we are generating. |
| auto differential = jvpDirectResults.back(); |
| auto *differentialDecl = differentialInfo.lookUpLinearMapDecl(ai); |
| auto originalDifferentialType = |
| getOpType(differential->getType()).getAs<SILFunctionType>(); |
| auto differentialType = |
| remapType(differential->getType()) |
| .castTo<SILFunctionType>(); |
| auto jvpGenSig = SubsMap.getGenericSignature() |
| ? SubsMap.getGenericSignature()->getCanonicalSignature() |
| : nullptr; |
| Lowering::GenericContextScope genericContextScope( |
| context.getTypeConverter(), jvpGenSig); |
| auto loweredDifferentialType = |
| getOpType(context.getTypeConverter().getLoweredType( |
| differentialDecl->getInterfaceType()->getCanonicalType(), |
| ResilienceExpansion::Minimal)) |
| .castTo<SILFunctionType>(); |
| // If actual differential type does not match lowered differential type, |
| // reabstract the differential using a thunk. |
| if (!loweredDifferentialType->isEqual(originalDifferentialType)) { |
| SILOptFunctionBuilder fb(context.getTransform()); |
| auto *thunk = getOrCreateReabstractionThunk( |
| fb, context.getModule(), loc, &getDifferential(), |
| differentialType, loweredDifferentialType); |
| auto *thunkRef = builder.createFunctionRef(loc, thunk); |
| differential = builder.createPartialApply( |
| loc, thunkRef, |
| getOpSubstitutionMap(thunk->getForwardingSubstitutionMap()), |
| {differential}, differentialType->getCalleeConvention()); |
| } |
| differentialValues[ai->getParent()].push_back(differential); |
| |
| // Differential emission. |
| emitTangentForApplyInst(ai, indices, originalDifferentialType); |
| } |
| |
| /// Visits the original `return` instruction. In the JVP this emits a |
| /// `return` of the elements of the remapped original result followed by the |
| /// differential function, partially applied to the differential struct value |
| /// built for `ri`. |
| void visitReturnInst(ReturnInst *ri) { |
| auto loc = ri->getOperand().getLoc(); |
| auto &builder = getBuilder(); |
| auto *diffStructVal = buildDifferentialValueStructValue(ri); |
| |
| // Get the JVP value corresponding to the original function's return value. |
| // `ri` is itself the terminator of its block, so its operand is used |
| // directly. |
| auto origResult = getOpValue(ri->getOperand()); |
| SmallVector<SILValue, 8> origResults; |
| extractAllElements(origResult, builder, origResults); |
| |
| // Get and partially apply the differential. |
| auto jvpGenericEnv = jvp->getGenericEnvironment(); |
| auto jvpSubstMap = jvpGenericEnv |
| ? jvpGenericEnv->getForwardingSubstitutionMap() |
| : jvp->getForwardingSubstitutionMap(); |
| auto *differentialRef = |
| builder.createFunctionRef(loc, &getDifferential()); |
| auto *differentialPartialApply = builder.createPartialApply( |
| loc, differentialRef, jvpSubstMap, {diffStructVal}, |
| ParameterConvention::Direct_Guaranteed); |
| |
| // Return a tuple of the original result and differential. |
| SmallVector<SILValue, 8> directResults; |
| directResults.append(origResults.begin(), origResults.end()); |
| directResults.push_back(differentialPartialApply); |
| builder.createReturn( |
| ri->getLoc(), joinElements(directResults, builder, loc)); |
| } |
| |
| /// `br` is not handled by this emitter; reaching this visitor is treated as |
| /// unreachable. |
| void visitBranchInst(BranchInst *bi) { |
| llvm_unreachable("Unsupported SIL instruction."); |
| } |
| |
| /// `cond_br` is not handled by this emitter; reaching this visitor is |
| /// treated as unreachable. |
| void visitCondBranchInst(CondBranchInst *cbi) { |
| llvm_unreachable("Unsupported SIL instruction."); |
| } |
| |
| /// `switch_enum` is not handled by this emitter; reaching this visitor is |
| /// treated as unreachable. |
| void visitSwitchEnumInst(SwitchEnumInst *sei) { |
| llvm_unreachable("Unsupported SIL instruction."); |
| } |
| |
| /// Clones a `differentiable_function` instruction from the original into the |
| /// JVP and appends the clone to the context's list of |
| /// `differentiable_function` instructions. |
| void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) { |
| TypeSubstCloner::visitDifferentiableFunctionInst(dfi); |
| auto *clonedDFI = cast<DifferentiableFunctionInst>(getOpValue(dfi)); |
| // Queue the clone on the `differentiable_function` worklist. |
| context.getDifferentiableFunctionInsts().push_back(clonedDFI); |
| } |
| }; |
| } // end anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // PullbackEmitter - visitors on the original function for pullback code |
| // generation |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> { |
| private: |
| /// The VJP emitter that owns this pullback emitter. |
| VJPEmitter &vjpEmitter; |
| |
| /// Dominance information of the original function. |
| DominanceInfo *domInfo = nullptr; |
| |
| /// Post-dominance information of the original function. |
| PostDominanceInfo *postDomInfo = nullptr; |
| |
| /// Post-order information of the original function. |
| PostOrderFunctionInfo *postOrderInfo = nullptr; |
| |
| /// Original basic block -> pullback basic block. A pullback basic block |
| /// always takes the predecessor as its single argument. |
| DenseMap<SILBasicBlock *, SILBasicBlock *> pullbackBBMap; |
| |
| /// (Original basic block, original value) -> adjoint value. |
| DenseMap<std::pair<SILBasicBlock *, SILValue>, AdjointValue> valueMap; |
| |
| /// (Original basic block, original buffer) -> adjoint buffer. |
| DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap; |
| |
| /// Pullback basic block -> its pullback struct argument. |
| DenseMap<SILBasicBlock *, SILArgument *> pullbackStructArguments; |
| |
| /// Pullback struct field declaration -> element value destructured from the |
| /// linear map basic block argument. At the start of every pullback basic |
| /// block, that block's pullback struct is destructured and the individual |
| /// elements are stored here. |
| DenseMap<VarDecl *, SILValue> pullbackStructElements; |
| |
| /// (Original basic block, successor basic block) -> pullback trampoline |
| /// basic block. A trampoline block takes further arguments besides the |
| /// predecessor enum argument. |
| DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, SILBasicBlock *> |
| pullbackTrampolineBBMap; |
| |
| /// Original basic block -> the active values it dominates. |
| DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> activeValues; |
| |
| /// (Original basic block, original active value) -> pullback block |
| /// argument. |
| DenseMap<std::pair<SILBasicBlock *, SILValue>, SILArgument *> |
| activeValuePullbackBBArgumentMap; |
| |
| /// Basic block -> local temporary values awaiting cleanup. Filled while |
| /// pullback emission runs on one basic block and cleaned before another |
| /// basic block is processed. |
| DenseMap<SILBasicBlock *, SmallVector<SILValue, 64>> |
| blockTemporaries; |
| |
| /// The temporaries currently recorded. `recordTemporary` inserts into this |
| /// set (asserting the value is new) and `cleanUpTemporariesForBlock` erases |
| /// from it. |
| llvm::DenseSet<SILValue> blockTemporarySet; |
| |
| /// The primary builder. |
| SILBuilder builder; |
| |
| /// A secondary builder reserved for local allocations. |
| SILBuilder localAllocBuilder; |
| |
| /// Stack buffers that were allocated to hold local adjoint values. |
| SmallVector<SILValue, 64> functionLocalAllocations; |
| |
| /// Remembers which local allocations have already been destroyed. |
| llvm::SmallDenseSet<SILValue> destroyedLocalAllocations; |
| |
| /// The pullback function's seed argument. |
| SILArgument *seed = nullptr; |
| |
| /// Bump-pointer allocator owned by this emitter. |
| /// NOTE(review): its users are outside this part of the file; presumably it |
| /// backs adjoint value storage -- confirm. |
| llvm::BumpPtrAllocator allocator; |
| |
| /// Error flag, initially false. It is consulted by `getAdjointProjection`. |
| bool errorOccurred = false; |
| |
| // Convenience accessors; each forwards to state held by the parent VJP |
| // emitter. |
| |
| /// The differentiation context of the parent VJP emitter. |
| ADContext &getContext() const { return vjpEmitter.context; } |
| /// The SIL module of the differentiation context. |
| SILModule &getModule() const { return getContext().getModule(); } |
| /// The AST context of the pullback function. |
| ASTContext &getASTContext() const { return getPullback().getASTContext(); } |
| /// The original function. |
| SILFunction &getOriginal() const { return *vjpEmitter.original; } |
| /// The pullback function. |
| SILFunction &getPullback() const { return *vjpEmitter.pullback; } |
| /// The `SILDifferentiableAttr` held by the parent VJP emitter. |
| SILDifferentiableAttr *getAttr() const { return vjpEmitter.attr; } |
| /// A copy of the parent VJP emitter's differentiation invoker. |
| DifferentiationInvoker getInvoker() const { return vjpEmitter.invoker; } |
| /// The linear map info that the parent VJP emitter keeps for the pullback. |
| LinearMapInfo &getPullbackInfo() { return vjpEmitter.pullbackInfo; } |
| /// The differentiation indices of the parent VJP emitter. |
| const SILAutoDiffIndices &getIndices() const { |
| return vjpEmitter.getIndices(); |
| } |
| /// The activity info of the parent VJP emitter. |
| const DifferentiableActivityInfo &getActivityInfo() const { |
| return vjpEmitter.activityInfo; |
| } |
| |
| public: |
| /// Creates a pullback emitter on top of `vjpEmitter`. Both builders are |
| /// constructed on the pullback function; the dominance, post-dominance and |
| /// post-order info of the original function are obtained from the pass |
| /// manager's analyses. |
| explicit PullbackEmitter(VJPEmitter &vjpEmitter) |
| : vjpEmitter(vjpEmitter), builder(getPullback()), |
| localAllocBuilder(getPullback()) { |
| auto &passManager = getContext().getPassManager(); |
| auto *dominance = passManager.getAnalysis<DominanceAnalysis>(); |
| auto *postDominance = passManager.getAnalysis<PostDominanceAnalysis>(); |
| auto *postOrder = passManager.getAnalysis<PostOrderAnalysis>(); |
| auto *originalFn = vjpEmitter.original; |
| domInfo = dominance->get(originalFn); |
| postDomInfo = postDominance->get(originalFn); |
| postOrderInfo = postOrder->get(originalFn); |
| } |
| |
| private: |
| //--------------------------------------------------------------------------// |
| // Pullback struct mapping |
| //--------------------------------------------------------------------------// |
| |
| /// Pairs each stored property of the pullback struct of `origBB` with the |
| /// value at the same position in `values` and registers the pair in |
| /// `pullbackStructElements`. The two sequences must have the same length |
| /// and no field may already be registered (both checked by assertion). |
| void initializePullbackStructElements(SILBasicBlock *origBB, |
| SILInstructionResultArray values) { |
| auto *structDecl = getPullbackInfo().getLinearMapStruct(origBB); |
| assert(structDecl->getStoredProperties().size() == values.size() && |
| "The number of pullback struct fields must equal the number of " |
| "pullback struct element values"); |
| for (auto fieldAndValue : |
| llvm::zip(structDecl->getStoredProperties(), values)) { |
| auto field = std::get<0>(fieldAndValue); |
| auto element = std::get<1>(fieldAndValue); |
| assert( |
| element.getOwnershipKind() != ValueOwnershipKind::Guaranteed |
| && "Pullback struct elements must be @owned"); |
| bool isNewField = pullbackStructElements.insert({field, element}).second; |
| (void)isNewField; |
| assert(isNewField && "A pullback struct element already exists!"); |
| } |
| } |
| |
| /// Returns the value registered in `pullbackStructElements` for `field`. |
| /// `field` must be declared in the pullback struct of `origBB` and must |
| /// have been registered (both checked by assertion). |
| SILValue getPullbackStructElement(SILBasicBlock *origBB, VarDecl *field) { |
| assert(getPullbackInfo().getLinearMapStruct(origBB) == |
| cast<StructDecl>(field->getDeclContext())); |
| assert(pullbackStructElements.count(field) && |
| "Pullback struct element for this field does not exist!"); |
| SILValue element = pullbackStructElements.lookup(field); |
| return element; |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Adjoint value factory methods |
| //--------------------------------------------------------------------------// |
| |
| /// Factory for a zero adjoint value of the given type (defined out of |
| /// line). |
| AdjointValue makeZeroAdjointValue(SILType type); |
| |
| /// Factory for a concrete adjoint value wrapping `value` (defined out of |
| /// line). |
| AdjointValue makeConcreteAdjointValue(SILValue value); |
| |
| /// Factory for an aggregate adjoint value of the given type built from |
| /// `elements` (defined out of line). |
| template<typename EltRange> |
| AdjointValue makeAggregateAdjointValue(SILType type, EltRange elements); |
| |
| //--------------------------------------------------------------------------// |
| // Temporary value management |
| //--------------------------------------------------------------------------// |
| |
| /// Registers `value`, which must be of object type, as a temporary of its |
| /// parent block so that it gets cleaned up before that block's terminator. |
| /// A value may be recorded only once (checked by assertion). Returns |
| /// `value`. |
| SILValue recordTemporary(SILValue value) { |
| assert(value->getType().isObject()); |
| auto *parentBB = value->getParentBlock(); |
| blockTemporaries[parentBB].push_back(value); |
| LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value); |
| bool isNewTemporary = blockTemporarySet.insert(value).second; |
| (void)isNewTemporary; |
| assert(isNewTemporary && "Temporary already recorded?"); |
| return value; |
| } |
| |
| /// Emits a destroy operation at `loc` for every temporary recorded for `bb` |
| /// and removes each of them from `blockTemporarySet`. The per-block list in |
| /// `blockTemporaries` itself is left untouched. |
| void cleanUpTemporariesForBlock(SILBasicBlock *bb, SILLocation loc) { |
| LLVM_DEBUG(getADDebugStream() << "Cleaning up temporaries for bb" |
| << bb->getDebugID() << '\n'); |
| auto &temporaries = blockTemporaries[bb]; |
| for (auto temporary : temporaries) { |
| builder.emitDestroyValueOperation(loc, temporary); |
| blockTemporarySet.erase(temporary); |
| } |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Symbolic value materializers |
| //--------------------------------------------------------------------------// |
| |
| /// Materializes the adjoint value `val` directly, as a SIL value. The type |
| /// of `val` must be loadable. |
| SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc); |
| |
| /// Materializes the adjoint value `val` indirectly, into the SIL buffer |
| /// `destBuffer`. |
| void materializeAdjointIndirect(AdjointValue val, SILValue destBuffer, |
| SILLocation loc); |
| |
| //--------------------------------------------------------------------------// |
| // Helpers for symbolic value materializers |
| //--------------------------------------------------------------------------// |
| |
| /// Emits a zero value into `bufferAccess` through a call to |
| /// `AdditiveArithmetic.zero`. `type` must conform to `AdditiveArithmetic`. |
| void emitZeroIndirect(CanType type, SILValue bufferAccess, SILLocation loc); |
| |
| /// Emits and returns a zero value through a call to |
| /// `AdditiveArithmetic.zero`. `type` must conform to `AdditiveArithmetic` |
| /// and be loadable in SIL. |
| SILValue emitZeroDirect(CanType type, SILLocation loc); |
| |
| //--------------------------------------------------------------------------// |
| // Accumulator |
| //--------------------------------------------------------------------------// |
| |
| /// Materializes the adjoint value `val` in the most efficient way. |
| SILValue materializeAdjoint(AdjointValue val, SILLocation loc); |
| |
| /// Accumulates the two adjoint values `lhs` and `rhs`. |
| AdjointValue accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs, |
| SILLocation loc); |
| |
| /// Accumulates two materialized adjoint values. Both must be objects of |
| /// loadable type. |
| SILValue accumulateDirect(SILValue lhs, SILValue rhs, SILLocation loc); |
| |
| /// Accumulates two materialized adjoint values into `resultBufAccess` using |
| /// `AdditiveArithmetic.+`, depending on the differentiation mode. |
| void accumulateIndirect(SILValue resultBufAccess, |
| SILValue lhsBufAccess, SILValue rhsBufAccess, |
| SILLocation loc); |
| |
| /// In-place variant for two buffers of an `AdditiveArithmetic` type: the |
| /// right hand side is accumulated into the left hand side using `+=`. |
| void accumulateIndirect(SILValue lhsDestAccess, SILValue rhsAccess, |
| SILLocation loc); |
| |
| //--------------------------------------------------------------------------// |
| // Type transformer |
| //--------------------------------------------------------------------------// |
| |
| /// Maps `ty` into the pullback function's context. A type that contains |
| /// archetypes is first mapped out of its context. |
| SILType remapType(SILType ty) { |
| auto &pullback = getPullback(); |
| if (!ty.hasArchetype()) |
| return pullback.mapTypeIntoContext(ty); |
| auto interfaceTy = ty.mapTypeOutOfContext(); |
| return pullback.mapTypeIntoContext(interfaceTy); |
| } |
| |
| /// Returns the result of `getAutoDiffAssociatedTangentSpace` for `type`, |
| /// with conformances looked up in the Swift module of the SIL module. |
| Optional<VectorSpace> getTangentSpace(CanType type) { |
| auto *swiftModule = getModule().getSwiftModule(); |
| return type->getAutoDiffAssociatedTangentSpace( |
| LookUpConformanceInModule(swiftModule)); |
| } |
| |
| /// Returns the tangent space type associated with the remapped `type`, in |
| /// the same SIL value category as `type`. The remapped type is assumed to |
| /// conform to `Differentiable`; the tangent space lookup result is used |
| /// without a check. |
| SILType getRemappedTangentType(SILType type) { |
| auto remappedASTType = remapType(type).getASTType(); |
| auto tangentSpace = getTangentSpace(remappedASTType); |
| return SILType::getPrimitiveType(tangentSpace->getCanonicalType(), |
| type.getCategory()); |
| } |
| |
| /// Applies the pullback function's forwarding substitution map to all |
| /// replacement types of `substMap` and returns the result. |
| SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap) { |
| auto pullbackSubstMap = getPullback().getForwardingSubstitutionMap(); |
| return substMap.subst(pullbackSubstMap); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Managed value mapping |
| //--------------------------------------------------------------------------// |
| |
| /// Tells whether `valueMap` holds an adjoint value for `originalValue` in |
| /// `origBB`. `origBB` must belong to the original function and |
| /// `originalValue` must be of object type. |
| bool hasAdjointValue(SILBasicBlock *origBB, SILValue originalValue) const { |
| assert(origBB->getParent() == &getOriginal()); |
| assert(originalValue->getType().isObject()); |
| return valueMap.find({origBB, originalValue}) != valueMap.end(); |
| } |
| |
| /// Sets the adjoint value that corresponds to `originalValue` in `origBB`. |
| /// If an adjoint value is already recorded for that pair, it is replaced |
| /// (and, in debug builds, the value being replaced is logged). |
| /// |
| /// Both `originalValue` and `adjointValue` must be of object type, |
| /// `originalValue` must belong to the original function, and the type of |
| /// `adjointValue` must be the remapped tangent type of `originalValue`. |
| void setAdjointValue(SILBasicBlock *origBB, SILValue originalValue, |
| AdjointValue adjointValue) { |
| LLVM_DEBUG(getADDebugStream() << "Setting adjoint value for " |
| << originalValue); |
| assert(origBB->getParent() == &getOriginal()); |
| assert(originalValue->getType().isObject()); |
| assert(adjointValue.getType().isObject()); |
| assert(originalValue->getFunction() == &getOriginal()); |
| // The adjoint value must be in the tangent space. |
| assert(adjointValue.getType() == |
| getRemappedTangentType(originalValue->getType())); |
| auto insertion = valueMap.try_emplace({origBB, originalValue}, |
| adjointValue); |
| if (insertion.second) |
| return; |
| // Only report a replacement when there really is an existing value, and |
| // print it before it is overwritten. |
| LLVM_DEBUG(getADDebugStream() |
| << "The existing adjoint value will be replaced: " |
| << insertion.first->getSecond()); |
| insertion.first->getSecond() = adjointValue; |
| } |
| |
| /// Get the adjoint for an original value. The given value must be in the |
| /// original function and be of object type. |
| /// |
| /// This method first tries to find an entry in `valueMap`. If an adjoint |
| /// doesn't exist, a zero adjoint of the remapped tangent type is created, |
| /// stored in `valueMap` and returned. |
| AdjointValue getAdjointValue(SILBasicBlock *origBB, SILValue originalValue) { |
| assert(origBB->getParent() == &getOriginal()); |
| assert(originalValue->getType().isObject()); |
| assert(originalValue->getFunction() == &getOriginal()); |
| auto existing = valueMap.find({origBB, originalValue}); |
| if (existing != valueMap.end()) |
| return existing->getSecond(); |
| // Build the zero adjoint only when it is actually needed. |
| auto zeroAdjoint = makeZeroAdjointValue( |
| getRemappedTangentType(originalValue->getType())); |
| auto insertion = valueMap.try_emplace({origBB, originalValue}, zeroAdjoint); |
| return insertion.first->getSecond(); |
| } |
| |
| /// Contributes `newAdjointValue` to the adjoint of `originalValue` in |
| /// `origBB`. When no adjoint is recorded yet, the new value is stored as is; |
| /// otherwise the recorded adjoint is removed, accumulated with the new value |
| /// and the sum is stored through `setAdjointValue`. |
| void addAdjointValue(SILBasicBlock *origBB, SILValue originalValue, |
| AdjointValue newAdjointValue, SILLocation loc) { |
| assert(origBB->getParent() == &getOriginal()); |
| assert(originalValue->getType().isObject()); |
| assert(newAdjointValue.getType().isObject()); |
| assert(originalValue->getFunction() == &getOriginal()); |
| LLVM_DEBUG(getADDebugStream() << "Adding adjoint for " << originalValue); |
| // The adjoint value must be in the tangent space. |
| assert(newAdjointValue.getType() == |
| getRemappedTangentType(originalValue->getType())); |
| auto insertion = |
| valueMap.try_emplace({origBB, originalValue}, newAdjointValue); |
| // First contribution: the freshly inserted value is the adjoint. |
| if (insertion.second) |
| return; |
| // Otherwise take the previous adjoint out of the map before accumulating |
| // the new contribution onto it. |
| auto previousAdjoint = insertion.first->getSecond(); |
| valueMap.erase(insertion.first); |
| auto accumulated = |
| accumulateAdjointsDirect(previousAdjoint, newAdjointValue, loc); |
| setAdjointValue(origBB, originalValue, accumulated); |
| } |
| |
| /// Get the pullback block argument corresponding to the given original block |
| /// and active value. An argument must have been registered for the pair and |
| /// must belong to the pullback block of `origBB` (checked by assertion). |
| SILArgument *getActiveValuePullbackBlockArgument(SILBasicBlock *origBB, |
| SILValue activeValue) { |
| assert(origBB->getParent() == &getOriginal()); |
| // Use `lookup` rather than `operator[]` so that a miss does not insert a |
| // null entry into the map. |
| auto *pullbackBBArg = |
| activeValuePullbackBBArgumentMap.lookup({origBB, activeValue}); |
| assert(pullbackBBArg); |
| assert(pullbackBBArg->getParent() == getPullbackBlock(origBB)); |
| return pullbackBBArg; |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Buffer mapping |
| //--------------------------------------------------------------------------// |
| |
| /// Registers `adjointBuffer` as the adjoint buffer of `originalBuffer` in |
| /// `origBB`. `originalBuffer` must be an address and no adjoint buffer may |
| /// be registered for the pair yet (both checked by assertion). |
| void setAdjointBuffer(SILBasicBlock *origBB, |
| SILValue originalBuffer, |
| SILValue adjointBuffer) { |
| assert(originalBuffer->getType().isAddress()); |
| bool isNewEntry = |
| bufferMap.try_emplace({origBB, originalBuffer}, adjointBuffer).second; |
| assert(isNewEntry); |
| (void)isNewEntry; |
| } |
| |
| /// Returns the adjoint counterpart of the original address projection |
| /// `originalProjection`, or a null `SILValue` when it is none of the handled |
| /// kinds (`struct_element_addr`, `tuple_element_addr`, `begin_access`). |
| SILValue getAdjointProjection(SILBasicBlock *origBB, |
| SILValue originalProjection) { |
| // `struct_element_addr`: project the field with the same name out of the |
| // adjoint buffer of the struct operand. |
| if (auto *structElt = dyn_cast<StructElementAddrInst>(originalProjection)) { |
| auto adjBase = getAdjointBuffer(origBB, structElt->getOperand()); |
| auto *tangentStructDecl = |
| adjBase->getType().getStructOrBoundGenericStruct(); |
| auto candidates = |
| tangentStructDecl->lookupDirect(structElt->getField()->getName()); |
| assert(candidates.size() == 1); |
| auto *tangentField = cast<VarDecl>(candidates.front()); |
| return builder.createStructElementAddr( |
| structElt->getLoc(), adjBase, tangentField); |
| } |
| // `tuple_element_addr`: project out of the adjoint buffer of the tuple |
| // operand. |
| if (auto *tupleElt = dyn_cast<TupleElementAddrInst>(originalProjection)) { |
| auto origTuple = tupleElt->getOperand(); |
| auto adjTuple = getAdjointBuffer(origBB, origTuple); |
| // An adjoint buffer that is not of tuple type is returned as is. |
| if (!adjTuple->getType().is<TupleType>()) |
| return adjTuple; |
| auto origTupleTy = origTuple->getType().castTo<TupleType>(); |
| // The adjoint element index is the number of preceding original elements |
| // for which a tangent space is found. |
| unsigned tangentIndex = 0; |
| for (unsigned eltNo : range(tupleElt->getFieldNo())) { |
| auto eltType = |
| origTupleTy->getElement(eltNo).getType()->getCanonicalType(); |
| if (getTangentSpace(eltType)) |
| ++tangentIndex; |
| } |
| return builder.createTupleElementAddr( |
| tupleElt->getLoc(), adjTuple, tangentIndex); |
| } |
| // `begin_access`: the adjoint buffer of the accessed base is used, unless an |
| // error has occurred, in which case a null value is stored for the |
| // projection and returned. |
| if (auto *access = dyn_cast<BeginAccessInst>(originalProjection)) { |
| auto adjBase = getAdjointBuffer(origBB, access->getOperand()); |
| if (!errorOccurred) |
| return adjBase; |
| return (bufferMap[{origBB, originalProjection}] = SILValue()); |
| } |
| return SILValue(); |
| } |
| |
| /// Returns the position at which the next function-local allocation should |
| /// be inserted: the start of the pullback entry block when no local |
| /// allocation exists yet, otherwise the position of the most recently |
| /// registered local allocation. |
| SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint() { |
| if (!functionLocalAllocations.empty()) { |
| // Insert before the last local allocation. Inserting before rather than |
| // after ensures that allocation and zero initialization instructions are |
| // grouped together. |
| auto newestAlloc = functionLocalAllocations.back(); |
| return newestAlloc->getDefiningInstruction()->getIterator(); |
| } |
| return getPullback().getEntryBlock()->begin(); |
| } |
| |
| /// Returns a reference to the `bufferMap` entry holding the adjoint buffer |
| /// of `originalBuffer` in `origBB`, creating the adjoint buffer on first |
| /// use. `originalBuffer` must be an address in the original function. |
| /// |
| /// On first use, the adjoint buffer is either the corresponding projection |
| /// into another adjoint buffer (when `originalBuffer` is a handled |
| /// projection) or a new zero-initialized `alloc_stack` in the pullback entry |
| /// block, which is also registered in `functionLocalAllocations`. |
| SILValue &getAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer) { |
| assert(originalBuffer->getType().isAddress()); |
| assert(originalBuffer->getFunction() == &getOriginal()); |
| auto insertion = bufferMap.try_emplace({origBB, originalBuffer}, |
| SILValue()); |
| if (!insertion.second) // not inserted |
| return insertion.first->getSecond(); |
| // NOTE: `insertion.first` must not be used past this point. The projection |
| // lookup below can call back into this method and insert into `bufferMap`, |
| // which invalidates DenseMap iterators. Entries are re-indexed by key |
| // instead. |
| |
| // If the original buffer is a projection, return a corresponding projection |
| // into the adjoint buffer. |
| if (auto adjProj = getAdjointProjection(origBB, originalBuffer)) |
| return (bufferMap[{origBB, originalBuffer}] = adjProj); |
| |
| // Set insertion point for local allocation builder: before the last local |
| // allocation, or at the start of the pullback function's entry if no local |
| // allocations exist yet. |
| localAllocBuilder.setInsertionPoint( |
| getPullback().getEntryBlock(), |
| getNextFunctionLocalAllocationInsertionPoint()); |
| // Allocate local buffer and initialize to zero. |
| auto bufObjectType = getRemappedTangentType(originalBuffer->getType()); |
| auto *newBuf = localAllocBuilder.createAllocStack( |
| RegularLocation::getAutoGeneratedLocation(), bufObjectType); |
| // Temporarily change global builder insertion point and emit zero into the |
| // local buffer. Only the insertion block is saved, so afterwards the global |
| // builder is pointed at that block again rather than at an exact position. |
| auto savedInsertionBB = builder.getInsertionBB(); |
| builder.setInsertionPoint( |
| localAllocBuilder.getInsertionBB(), |
| localAllocBuilder.getInsertionPoint()); |
| emitZeroIndirect(bufObjectType.getASTType(), newBuf, newBuf->getLoc()); |
| builder.setInsertionPoint(savedInsertionBB); |
| // Register the local buffer. |
| functionLocalAllocations.push_back(newBuf); |
| return (bufferMap[{origBB, originalBuffer}] = newBuf); |
| } |
| |
| // Accumulates `rhsBufferAccess` into the adjoint buffer corresponding to |
| // `originalBuffer`. |
| void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer, |
| SILValue rhsBufferAccess, SILLocation loc) { |
| assert(originalBuffer->getType().isAddress() && |
| rhsBufferAccess->getType().isAddress()); |
| assert(originalBuffer->getFunction() == &getOriginal()); |
| assert(rhsBufferAccess->getFunction() == &getPullback()); |
| auto adjointBuffer = getAdjointBuffer(origBB, originalBuffer); |
| accumulateIndirect(adjointBuffer, rhsBufferAccess, loc); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // CFG mapping |
| //--------------------------------------------------------------------------// |
| |
| SILBasicBlock *getPullbackBlock(SILBasicBlock *originalBlock) { |
| return pullbackBBMap.lookup(originalBlock); |
| } |
| |
| SILBasicBlock *getPullbackTrampolineBlock( |
| SILBasicBlock *originalBlock, SILBasicBlock *successorBlock) { |
| return pullbackTrampolineBBMap.lookup({originalBlock, successorBlock}); |
| } |
| |
| public: |
| //--------------------------------------------------------------------------// |
| // Entry point |
| //--------------------------------------------------------------------------// |
| |
| /// Performs pullback generation on the empty pullback function. Returns true |
| /// if any error occurs. |
| bool run() { |
| auto &original = getOriginal(); |
| auto &pullback = getPullback(); |
| auto pbLoc = getPullback().getLocation(); |
| LLVM_DEBUG(getADDebugStream() << "Running PullbackEmitter on\n" |
| << original); |
| |
| auto *pbGenEnv = getPullback().getGenericEnvironment(); |
| auto pbGenSig = pbGenEnv |
| ? pbGenEnv->getGenericSignature()->getCanonicalSignature() |
| : nullptr; |
| Lowering::GenericContextScope genericContextScope( |
| getContext().getTypeConverter(), pbGenSig); |
| auto origExitIt = original.findReturnBB(); |
| assert(origExitIt != original.end() && |
| "Functions without returns must have been diagnosed"); |
| auto *origExit = &*origExitIt; |
| |
| SmallVector<SILValue, 8> origFormalResults; |
| collectAllFormalResultsInTypeOrder(original, origFormalResults); |
| auto origResult = origFormalResults[getIndices().source]; |
| |
| // If original result is non-varied, it will always have a zero derivative. |
| // Skip full pullback generation and simply emit zero derivatives for wrt |
| // parameters. |
| // |
| // NOTE(TF-876): This shortcut is currently necessary for functions |
| // returning non-varied result with >1 basic block where some basic blocks |
| // have no dominated active values; control flow differentiation does not |
| // handle this case. See TF-876 for context. |
| if (!getActivityInfo().isVaried(origResult, getIndices().parameters)) { |
| emitZeroDerivativesForNonvariedResult(origResult); |
| return false; |
| } |
| |
| // Get dominated active values in original blocks. |
| // Adjoint values of dominated active values are passed as pullback block |
| // arguments. |
| DominanceOrder domOrder(original.getEntryBlock(), domInfo); |
| while (auto *bb = domOrder.getNext()) { |
| auto &bbActiveValues = activeValues[bb]; |
| // If the current block has an immediate dominator, append the immediate |
| // dominator block's active values to the current block's active values. |
| if (auto *domNode = domInfo->getNode(bb)->getIDom()) { |
| auto &domBBActiveValues = activeValues[domNode->getBlock()]; |
| bbActiveValues.append(domBBActiveValues.begin(), |
| domBBActiveValues.end()); |
| } |
| SmallPtrSet<SILValue, 8> visited(bbActiveValues.begin(), |
| bbActiveValues.end()); |
| // Register a value as active if it has not yet been visited. |
| auto addActiveValue = [&](SILValue v) { |
| if (visited.count(v)) |
| return; |
| // Diagnose active enum values. Differentiation of enum values is not |
| // yet supported; requires special adjoint value handling. |
| if (v->getType().getEnumOrBoundGenericEnum()) { |
| getContext().emitNondifferentiabilityError( |
| v, getInvoker(), diag::autodiff_enums_unsupported); |
| errorOccurred = true; |
| } |
| // Skip address projections. |
| // Address projections do not need their own adjoint buffers; they |
| // become projections into their adjoint base buffer. |
| if (Projection::isAddressProjection(v)) |
| return; |
| visited.insert(v); |
| bbActiveValues.push_back(v); |
| }; |
| // Register bb arguments and all instruction operands/results. |
| for (auto *arg : bb->getArguments()) |
| if (getActivityInfo().isActive(arg, getIndices())) |
| addActiveValue(arg); |
| for (auto &inst : *bb) { |
| for (auto op : inst.getOperandValues()) |
| if (getActivityInfo().isActive(op, getIndices())) |
| addActiveValue(op); |
| for (auto result : inst.getResults()) |
| if (getActivityInfo().isActive(result, getIndices())) |
| addActiveValue(result); |
| } |
| domOrder.pushChildren(bb); |
| if (errorOccurred) |
| return true; |
| } |
| |
| // Create pullback blocks and arguments, visiting original blocks in |
| // post-order post-dominance order. |
| SmallVector<SILBasicBlock *, 8> postOrderPostDomOrder; |
| // Start from the root node, which may have a marker `nullptr` block if |
| // there are multiple roots. |
| PostOrderPostDominanceOrder postDomOrder(postDomInfo->getRootNode(), |
| postOrderInfo, original.size()); |
| while (auto *origNode = postDomOrder.getNext()) { |
| auto *origBB = origNode->getBlock(); |
| postDomOrder.pushChildren(origNode); |
| // If node is the `nullptr` marker basic block, do not push it. |
| if (!origBB) |
| continue; |
| postOrderPostDomOrder.push_back(origBB); |
| } |
| for (auto *origBB : postOrderPostDomOrder) { |
| auto *pullbackBB = pullback.createBasicBlock(); |
| pullbackBBMap.insert({origBB, pullbackBB}); |
| auto pbStructLoweredType = |
| remapType(getPullbackInfo().getLinearMapStructLoweredType(origBB)); |
| // If the BB is the original exit, then the pullback block that we just |
| // created must be the pullback function's entry. For the pullback entry, |
| // create entry arguments and continue to the next block. |
| if (origBB == origExit) { |
| assert(pullbackBB->isEntry()); |
| createEntryArguments(&pullback); |
| auto *mainPullbackStruct = pullbackBB->getArguments().back(); |
| assert(mainPullbackStruct->getType() == pbStructLoweredType); |
| pullbackStructArguments[origBB] = mainPullbackStruct; |
| // Destructure the pullback struct to get the elements. |
| builder.setInsertionPoint(pullbackBB); |
| auto *dsi = builder.createDestructureStruct(pbLoc, mainPullbackStruct); |
| initializePullbackStructElements(origBB, dsi->getResults()); |
| continue; |
| } |
| // Get all active values in the original block. |
| // If the original block has no active values, continue. |
| auto &bbActiveValues = activeValues[origBB]; |
| if (bbActiveValues.empty()) |
| continue; |
| // Otherwise, if the original block has active values: |
| // - For each active buffer in the original block, allocate a new local |
| // buffer in the pullback entry. (All adjoint buffers are allocated in |
| // the pullback entry and deallocated in the pullback exit.) |
| // - For each active value in the original block, add adjoint value |
| // arguments to the pullback block. |
| for (auto activeValue : bbActiveValues) { |
| if (activeValue->getType().isAddress()) { |
| // Allocate and zero initialize a new local buffer using |
| // `getAdjointBuffer`. |
| builder.setInsertionPoint(pullback.getEntryBlock()); |
| getAdjointBuffer(origBB, activeValue); |
| } else { |
| // Create and register pullback block argument for the active value. |
| auto *pullbackArg = pullbackBB->createPhiArgument( |
| getRemappedTangentType(activeValue->getType()), |
| ValueOwnershipKind::Owned); |
| activeValuePullbackBBArgumentMap[{origBB, activeValue}] = pullbackArg; |
| recordTemporary(pullbackArg); |
| } |
| } |
| // Add a pullback struct argument. |
| auto *pbStructArg = pullbackBB->createPhiArgument( |
| pbStructLoweredType, ValueOwnershipKind::Owned); |
| pullbackStructArguments[origBB] = pbStructArg; |
| // Destructure the pullback struct to get the elements. |
| builder.setInsertionPoint(pullbackBB); |
| auto *dsi = builder.createDestructureStruct(pbLoc, pbStructArg); |
| initializePullbackStructElements(origBB, dsi->getResults()); |
| |
| // - Create pullback trampoline blocks for each successor block of the |
| // original block. Pullback trampoline blocks only have a pullback |
| // struct argument. They branch from a pullback successor block to the |
| // pullback original block, passing adjoint values of active values. |
| for (auto *succBB : origBB->getSuccessorBlocks()) { |
| auto *pullbackTrampolineBB = |
| pullback.createBasicBlockBefore(pullbackBB); |
| pullbackTrampolineBBMap.insert({{origBB, succBB}, |
| pullbackTrampolineBB}); |
| // Get the enum element type (i.e. the pullback struct type). The enum |
| // element type may be boxed if the enum is indirect. |
| auto enumLoweredTy = |
| getPullbackInfo().getBranchingTraceEnumLoweredType(succBB); |
| auto *enumEltDecl = |
| getPullbackInfo().lookUpBranchingTraceEnumElement(origBB, succBB); |
| auto enumEltType = remapType( |
| enumLoweredTy.getEnumElementType(enumEltDecl, getModule())); |
| pullbackTrampolineBB->createPhiArgument(enumEltType, |
| ValueOwnershipKind::Owned); |
| } |
| } |
| |
| auto *pullbackEntry = pullback.getEntryBlock(); |
| // The pullback function has type (seed, exit_pbs) -> ([arg0], ..., [argn]). |
| auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults(); |
| assert(pbParamArgs.size() == 2); |
| seed = pbParamArgs[0]; |
| |
| // Assign adjoint for original result. |
| builder.setInsertionPoint( |
| pullbackEntry, getNextFunctionLocalAllocationInsertionPoint()); |
| if (seed->getType().isAddress()) { |
| auto *seedBufCopy = builder.createAllocStack(pbLoc, seed->getType()); |
| builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake, |
| IsInitialization); |
| setAdjointBuffer(origExit, origResult, seedBufCopy); |
| functionLocalAllocations.push_back(seedBufCopy); |
| LLVM_DEBUG(getADDebugStream() |
| << "Assigned seed buffer " << seedBufCopy |
| << " as the adjoint of original indirect result " |
| << origResult); |
| } else { |
| setAdjointValue(origExit, origResult, makeConcreteAdjointValue(seed)); |
| LLVM_DEBUG(getADDebugStream() |
| << "Assigned seed " << *seed |
| << " as the adjoint of original result " << origResult); |
| } |
| |
| // Visit original blocks blocks in post-order and perform differentiation |
| // in corresponding pullback blocks. If errors occurred, back out. |
| for (auto *bb : postOrderPostDomOrder) { |
| visitSILBasicBlock(bb); |
| if (errorOccurred) |
| return true; |
| } |
| |
| // Prepare and emit a `return` in the pullback exit block. |
| auto *origEntry = getOriginal().getEntryBlock(); |
| auto *pbExit = getPullbackBlock(origEntry); |
| builder.setInsertionPoint(pbExit); |
| |
| // This vector will contain all the materialized return elements. |
| SmallVector<SILValue, 8> retElts; |
| // This vector will contain all indirect parameter adjoint buffers. |
| SmallVector<SILValue, 4> indParamAdjoints; |
| |
| auto origParams = getOriginal().getArgumentsWithoutIndirectResults(); |
| |
| // Materializes the return element corresponding to the parameter |
| // `parameterIndex` into the `retElts` vector. |
| auto addRetElt = [&](unsigned parameterIndex) -> void { |
| auto origParam = origParams[parameterIndex]; |
| if (origParam->getType().isObject()) { |
| auto pbVal = getAdjointValue(origEntry, origParam); |
| auto val = materializeAdjointDirect(pbVal, pbLoc); |
| auto newVal = builder.emitCopyValueOperation(pbLoc, val); |
| retElts.push_back(newVal); |
| } else { |
| auto adjBuf = getAdjointBuffer(origEntry, origParam); |
| indParamAdjoints.push_back(adjBuf); |
| } |
| }; |
| // Collect differentiation parameter adjoints. |
| for (auto i : getIndices().parameters->getIndices()) |
| addRetElt(i); |
| |
| // Copy them to adjoint indirect results. |
| assert(indParamAdjoints.size() == |
| getPullback().getIndirectResults().size() && |
| "Indirect parameter adjoint count mismatch"); |
| for (auto pair : zip(indParamAdjoints, |
| getPullback().getIndirectResults())) { |
| auto source = std::get<0>(pair); |
| auto *dest = std::get<1>(pair); |
| builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization); |
| // Prevent source buffer from being deallocated, since the underlying |
| // value is moved. |
| destroyedLocalAllocations.insert(source); |
| } |
| |
| // Emit cleanups for all local values. |
| cleanUpTemporariesForBlock(pbExit, pbLoc); |
| // Deallocate local allocations. |
| for (auto alloc : functionLocalAllocations) { |
| // Assert that local allocations have at least one use. |
| // Buffers should not be allocated needlessly. |
| assert(!alloc->use_empty()); |
| if (!destroyedLocalAllocations.count(alloc)) { |
| builder.emitDestroyAddrAndFold(pbLoc, alloc); |
| destroyedLocalAllocations.insert(alloc); |
| } |
| builder.createDeallocStack(pbLoc, alloc); |
| } |
| builder.createReturn(pbLoc, joinElements(retElts, builder, pbLoc)); |
| |
| #ifndef NDEBUG |
| bool leakFound = false; |
| // Ensure all temporaries have been cleaned up. |
| for (auto &bb : pullback) { |
| for (auto temp : blockTemporaries[&bb]) { |
| if (blockTemporarySet.count(temp)) { |
| leakFound = true; |
| getADDebugStream() << "Found leaked temporary:\n" << temp; |
| } |
| } |
| } |
| // Ensure all local allocations have been cleaned up. |
| for (auto localAlloc : functionLocalAllocations) { |
| if (!destroyedLocalAllocations.count(localAlloc)) { |
| leakFound = true; |
| getADDebugStream() << "Found leaked local buffer:\n" << localAlloc; |
| } |
| } |
| assert(!leakFound && "Leaks found!"); |
| #endif |
| |
| LLVM_DEBUG(getADDebugStream() << "Generated pullback for " |
| << original.getName() << ":\n" << pullback); |
| return errorOccurred; |
| } |
| |
| /// If original result is non-varied, it will always have a zero derivative. |
| /// Skip full pullback generation and simply emit zero derivatives for wrt |
| /// parameters. |
| void emitZeroDerivativesForNonvariedResult(SILValue origNonvariedResult) { |
| auto &pullback = getPullback(); |
| auto pbLoc = getPullback().getLocation(); |
| /* |
| // TODO(TF-788): Re-enable non-varied result warning. |
| // Emit fixit if original non-varied result has a valid source location. |
| auto startLoc = origNonvariedResult.getLoc().getStartSourceLoc(); |
| auto endLoc = origNonvariedResult.getLoc().getEndSourceLoc(); |
| if (startLoc.isValid() && endLoc.isValid()) { |
| getContext().diagnose(startLoc, diag::autodiff_nonvaried_result_fixit) |
| .fixItInsert(startLoc, "withoutDerivative(at:") |
| .fixItInsertAfter(endLoc, ")"); |
| } |
| */ |
| LLVM_DEBUG(getADDebugStream() << getOriginal().getName() |
| << " has non-varied result, returning zero" |
| " for all pullback results\n"); |
| auto *pullbackEntry = pullback.createBasicBlock(); |
| createEntryArguments(&pullback); |
| builder.setInsertionPoint(pullbackEntry); |
| // Destroy all owned arguments. |
| for (auto *arg : pullbackEntry->getArguments()) |
| if (arg->getOwnershipKind() == ValueOwnershipKind::Owned) |
| builder.emitDestroyOperation(pbLoc, arg); |
| // Return zero for each result. |
| SmallVector<SILValue, 4> directResults; |
| auto indirectResultIt = pullback.getIndirectResults().begin(); |
| for (auto resultInfo : pullback.getLoweredFunctionType()->getResults()) { |
| auto resultType = |
| pullback.mapTypeIntoContext(resultInfo.getType())->getCanonicalType(); |
| if (resultInfo.isFormalDirect()) |
| directResults.push_back(emitZeroDirect(resultType, pbLoc)); |
| else |
| emitZeroIndirect(resultType, *indirectResultIt++, pbLoc); |
| } |
| builder.createReturn(pbLoc, joinElements(directResults, builder, pbLoc)); |
| LLVM_DEBUG(getADDebugStream() << "Generated pullback for " |
| << getOriginal().getName() << ":\n" |
| << pullback); |
| } |
| |
| using TrampolineBlockSet = SmallPtrSet<SILBasicBlock *, 4>; |
| |
| /// Determine the pullback successor block for a given original block and one |
| /// of its predecessors. When a trampoline block is necessary, emit code into |
| /// the trampoline block to trampoline the original block's active value's |
| /// adjoint values. A dense map `trampolineArgs` will be populated to keep |
| /// track of which pullback successor blocks each active value's adjoint value |
| /// is used, so that we can release those values in pullback successor blocks |
| /// that are not using them. |
| SILBasicBlock *buildPullbackSuccessor( |
| SILBasicBlock *origBB, SILBasicBlock *origPredBB, |
| SmallDenseMap<SILValue, TrampolineBlockSet> &pullbackTrampolineBlockMap) { |
| // Get the pullback block and optional pullback trampoline block of the |
| // predecessor block. |
| auto *pullbackBB = getPullbackBlock(origPredBB); |
| auto *pullbackTrampolineBB = getPullbackTrampolineBlock(origPredBB, origBB); |
| // If the predecessor block does not have a corresponding pullback |
| // trampoline block, then the pullback successor is the pullback block. |
| if (!pullbackTrampolineBB) |
| return pullbackBB; |
| |
| // Otherwise, the pullback successor is the pullback trampoline block, |
| // which branches to the pullback block and propagates adjoint values of |
| // active values. |
| assert(pullbackTrampolineBB->getNumArguments() == 1); |
| auto loc = origBB->getParent()->getLocation(); |
| SmallVector<SILValue, 8> trampolineArguments; |
| // Propagate adjoint values/buffers of active values/buffers to |
| // predecessor blocks. |
| auto &predBBActiveValues = activeValues[origPredBB]; |
| for (auto activeValue : predBBActiveValues) { |
| LLVM_DEBUG(getADDebugStream() |
| << "Propagating active adjoint " << activeValue |
| << " to predecessors' pullback blocks\n"); |
| if (activeValue->getType().isObject()) { |
| auto activeValueAdj = getAdjointValue(origBB, activeValue); |
| auto concreteActiveValueAdj = |
| materializeAdjointDirect(activeValueAdj, loc); |
| |
| if (!pullbackTrampolineBlockMap.count(concreteActiveValueAdj)) { |
| concreteActiveValueAdj = |
| builder.emitCopyValueOperation(loc, concreteActiveValueAdj); |
| setAdjointValue(origBB, activeValue, |
| makeConcreteAdjointValue(concreteActiveValueAdj)); |
| } |
| auto insertion = pullbackTrampolineBlockMap.try_emplace( |
| concreteActiveValueAdj, TrampolineBlockSet()); |
| auto &blockSet = insertion.first->getSecond(); |
| blockSet.insert(pullbackTrampolineBB); |
| trampolineArguments.push_back(concreteActiveValueAdj); |
| |
| // If the pullback block does not yet have a registered adjoint |
| // value for the active value, set the adjoint value to the |
| // forwarded adjoint value argument. |
| // TODO: Hoist this logic out of loop over predecessor blocks to |
| // remove the `hasAdjointValue` check. |
| if (!hasAdjointValue(origPredBB, activeValue)) { |
| auto *pullbackBBArg = |
| getActiveValuePullbackBlockArgument(origPredBB, activeValue); |
| auto forwardedArgAdj = makeConcreteAdjointValue(pullbackBBArg); |
| setAdjointValue(origPredBB, activeValue, forwardedArgAdj); |
| } |
| } else { |
| // Propagate adjoint buffers using `copy_addr`. |
| auto adjBuf = getAdjointBuffer(origBB, activeValue); |
| auto predAdjBuf = getAdjointBuffer(origPredBB, activeValue); |
| builder.createCopyAddr( |
| loc, adjBuf, predAdjBuf, IsNotTake, IsNotInitialization); |
| } |
| } |
| // Propagate pullback struct argument. |
| SILBuilder pullbackTrampolineBBBuilder(pullbackTrampolineBB); |
| auto *predPBStructVal = pullbackTrampolineBB->getArguments().front(); |
| auto boxType = |
| dyn_cast<SILBoxType>(predPBStructVal->getType().getASTType()); |
| if (!boxType) { |
| trampolineArguments.push_back(predPBStructVal); |
| } else { |
| auto *projectBox = pullbackTrampolineBBBuilder.createProjectBox( |
| loc, predPBStructVal, /*index*/ 0); |
| auto loaded = pullbackTrampolineBBBuilder.emitLoadValueOperation( |
| loc, projectBox, LoadOwnershipQualifier::Copy); |
| pullbackTrampolineBBBuilder.emitDestroyValueOperation(loc, |
| predPBStructVal); |
| trampolineArguments.push_back(loaded); |
| } |
| // Branch from pullback trampoline block to pullback block. |
| pullbackTrampolineBBBuilder.createBranch(loc, pullbackBB, |
| trampolineArguments); |
| return pullbackTrampolineBB; |
| } |
| |
| /// Emit pullback code in the corresponding pullback block. |
| void visitSILBasicBlock(SILBasicBlock *bb) { |
| auto pbLoc = getPullback().getLocation(); |
| // Get the corresponding pullback basic block. |
| auto *pbBB = getPullbackBlock(bb); |
| builder.setInsertionPoint(pbBB); |
| |
| LLVM_DEBUG({ |
| auto &s = getADDebugStream() |
| << "Original bb" + std::to_string(bb->getDebugID()) |
| << ": To differentiate or not to differentiate?\n"; |
| for (auto &inst : reversed(*bb)) { |
| s << (getPullbackInfo().shouldDifferentiateInstruction(&inst) |
| ? "[∂] " : "[ ] ") |
| << inst; |
| } |
| }); |
| |
| // Visit each instruction in reverse order. |
| for (auto &inst : reversed(*bb)) { |
| if (!getPullbackInfo().shouldDifferentiateInstruction(&inst)) |
| continue; |
| // Differentiate instruction. |
| visit(&inst); |
| if (errorOccurred) |
| return; |
| } |
| |
| // Emit a branching terminator for the block. |
| // If the original block is the original entry, then the pullback block is |
| // the pullback exit. This is handled specially in `PullbackEmitter::run()`, |
| // so we leave the block non-terminated. |
| if (bb->isEntry()) |
| return; |
| |
| // Otherwise, add a `switch_enum` terminator for non-exit |
| // pullback blocks. |
| // 1. Get the pullback struct pullback block argument. |
| // 2. Extract the predecessor enum value from the pullback struct value. |
| auto *predEnum = getPullbackInfo().getBranchingTraceDecl(bb); |
| auto *predEnumField = |
| getPullbackInfo().lookUpLinearMapStructEnumField(bb); |
| auto predEnumVal = getPullbackStructElement(bb, predEnumField); |
| |
| // Propagate adjoint values from active basic block arguments to |
| // predecessor terminator operands. |
| for (auto *bbArg : bb->getArguments()) { |
| if (!getActivityInfo().isActive(bbArg, getIndices())) |
| continue; |
| // Get predecessor terminator operands. |
| SmallVector<std::pair<SILBasicBlock *, SILValue>, 4> incomingValues; |
| bbArg->getSingleTerminatorOperands(incomingValues); |
| // Initialize adjoint value of predecessor terminator operands as |
| // adjoint value of current block arguments. |
| auto bbArgAdj = getAdjointValue(bb, bbArg); |
| for (auto pair : incomingValues) { |
| auto *predBB = std::get<0>(pair); |
| auto incomingValue = std::get<1>(pair); |
| setAdjointValue(predBB, incomingValue, bbArgAdj); |
| } |
| } |
| |
| // 3. Build the pullback successor cases for the `switch_enum` |
| // instruction. The pullback successors correspond to the predecessors |
| // of the current block. |
| SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> |
| pullbackSuccessorCases; |
| // A map from active values' adjoint values to the trampoline blocks that |
| // are using them. |
| SmallDenseMap<SILValue, TrampolineBlockSet> pullbackTrampolineBlockMap; |
| SmallVector<SILBasicBlock *, 8> pullbackSuccBBs; |
| for (auto *predBB : bb->getPredecessorBlocks()) { |
| auto *pullbackSuccBB = buildPullbackSuccessor(bb, predBB, |
| pullbackTrampolineBlockMap); |
| pullbackSuccBBs.push_back(pullbackSuccBB); |
| auto *enumEltDecl = |
| getPullbackInfo().lookUpBranchingTraceEnumElement(predBB, bb); |
| pullbackSuccessorCases.push_back({enumEltDecl, pullbackSuccBB}); |
| } |
| // Values are trampolined by only a subset of pullback successor blocks. |
| // Other successors blocks should destroy the value to balance the reference |
| // count. |
| for (auto pair : pullbackTrampolineBlockMap) { |
| auto value = pair.getFirst(); |
| // The set of trampoline BBs that are users of `value`. |
| auto &userTrampolineBBSet = pair.getSecond(); |
| // For each pullback successor block that does not trampoline the value, |
| // release the value. |
| for (auto *pullbackSuccBB : pullbackSuccBBs) { |
| if (userTrampolineBBSet.count(pullbackSuccBB)) |
| continue; |
| SILBuilder builder(pullbackSuccBB->begin()); |
| builder.emitDestroyValueOperation(pbLoc, value); |
| } |
| } |
| // Emit cleanups for all block-local temporaries. |
| cleanUpTemporariesForBlock(pbBB, pbLoc); |
| // - If the original block has exactly one predecessor, then the pullback |
| // block has exactly one successor. Extract the pullback struct value |
| // from the predecessor enum value using `unchecked_take_enum_data_addr` |
| // and `load [take]`, and branch to the pullback successor block. |
| assert(pullbackSuccessorCases.size() == predEnum->getNumElements()); |
| builder.createSwitchEnum( |
| pbLoc, predEnumVal, /*DefaultBB*/ nullptr, pullbackSuccessorCases); |
| } |
| |
| void visit(SILInstruction *inst) { |
| if (errorOccurred) |
| return; |
| |
| LLVM_DEBUG(getADDebugStream() |
| << "PullbackEmitter visited:\n[ORIG]" << *inst); |
| #ifndef NDEBUG |
| auto beforeInsertion = std::prev(builder.getInsertionPoint()); |
| #endif |
| SILInstructionVisitor::visit(inst); |
| LLVM_DEBUG({ |
| auto &s = llvm::dbgs() << "[ADJ] Emitted in pullback:\n"; |
| auto afterInsertion = builder.getInsertionPoint(); |
| for (auto it = ++beforeInsertion; it != afterInsertion; ++it) |
| s << *it; |
| }); |
| } |
| |
| void visitSILInstruction(SILInstruction *inst) { |
| LLVM_DEBUG(getADDebugStream() |
| << "Unhandled instruction in PullbackEmitter: " << *inst); |
| getContext().emitNondifferentiabilityError(inst, getInvoker(), |
| diag::autodiff_expression_not_differentiable_note); |
| errorOccurred = true; |
| } |
| |
| AllocStackInst * |
| emitArrayTangentSubscript(ApplyInst *ai, SILType eltType, |
| SILValue adjointArray, SILValue fnRef, |
| CanGenericSignature genericSig, int index) { |
| auto &ctx = builder.getASTContext(); |
| auto astType = eltType.getASTType(); |
| auto literal = builder.createIntegerLiteral( |
| ai->getLoc(), SILType::getBuiltinIntegerType(64, ctx), index); |
| auto intType = SILType::getPrimitiveObjectType( |
| ctx.getIntDecl()->getDeclaredType()->getCanonicalType()); |
| auto intStruct = builder.createStruct(ai->getLoc(), intType, {literal}); |
| AllocStackInst *subscriptBuffer = |
| builder.createAllocStack(ai->getLoc(), eltType); |
| auto swiftModule = getModule().getSwiftModule(); |
| auto diffProto = ctx.getProtocol(KnownProtocolKind::Differentiable); |
| auto diffConf = swiftModule->lookupConformance(astType, diffProto); |
| assert(diffConf.hasValue() && "Missing conformance to `Differentiable`"); |
| auto addArithProto = ctx.getProtocol(KnownProtocolKind::AdditiveArithmetic); |
| auto addArithConf = swiftModule->lookupConformance(astType, addArithProto); |
| assert(addArithConf.hasValue() && |
| "Missing conformance to `AdditiveArithmetic`"); |
| auto subMap = |
| SubstitutionMap::get(genericSig, {astType}, {*addArithConf, *diffConf}); |
| builder.createApply(ai->getLoc(), fnRef, subMap, |
| {subscriptBuffer, intStruct, adjointArray}); |
| return subscriptBuffer; |
| } |
| |
| void accumulateArrayTangentSubscriptDirect(ApplyInst *ai, SILType eltType, |
| StoreInst *si, |
| AllocStackInst *subscriptBuffer) { |
| auto newAdjValue = builder.emitLoadValueOperation( |
| ai->getLoc(), subscriptBuffer, LoadOwnershipQualifier::Take); |
| recordTemporary(newAdjValue); |
| SILValue src = si->getSrc(); |
| // When the store's source is a `copy_value`, the `copy_value` is part of |
| // array literal initialization. In this case, add the adjoint to the source |
| // of the copy directly. |
| if (auto *cvi = dyn_cast<CopyValueInst>(src)) |
| src = cvi->getOperand(); |
| addAdjointValue(si->getParent(), src, |
| makeConcreteAdjointValue(newAdjValue), si->getLoc()); |
| blockTemporaries[ai->getParent()].push_back(newAdjValue); |
| builder.createDeallocStack(ai->getLoc(), subscriptBuffer); |
| } |
| |
| void accumulateArrayTangentSubscriptIndirect( |
| ApplyInst *ai, CopyAddrInst *cai, AllocStackInst *subscriptBuffer) { |
| addToAdjointBuffer(cai->getParent(), cai->getSrc(), subscriptBuffer, |
| cai->getLoc()); |
| builder.emitDestroyAddrAndFold(cai->getLoc(), subscriptBuffer); |
| builder.createDeallocStack(ai->getLoc(), subscriptBuffer); |
| } |
| |
| void visitArrayInitialization(ApplyInst *ai) { |
| LLVM_DEBUG(getADDebugStream() << "Visiting array initialization:\n" << *ai); |
| SILValue adjointArray; |
| SILValue fnRef; |
| CanGenericSignature genericSig; |
| for (auto use : ai->getUses()) { |
| auto *dti = dyn_cast<DestructureTupleInst>(use->getUser()); |
| if (!dti) continue; |
| // The first tuple field of the return value is the `Array`. |
| adjointArray = getAdjointValue(ai->getParent(), dti->getResult(0)) |
| .getConcreteValue(); |
| assert(adjointArray && "Array does not have adjoint value"); |
| auto astType = adjointArray->getType().getASTType(); |
| auto typeDecl = astType->getStructOrBoundGenericStruct(); |
| auto subscriptDecl = cast<SubscriptDecl>(typeDecl->lookupDirect( |
| DeclBaseName::createSubscript()).front()); |
| auto subscriptGet = subscriptDecl->getAccessor(AccessorKind::Get); |
| SILDeclRef subscriptRef(subscriptGet, SILDeclRef::Kind::Func); |
| auto fnBuilder = SILOptFunctionBuilder(getContext().getTransform()); |
| auto fn = fnBuilder.getOrCreateFunction( |
| ai->getLoc(), subscriptRef, NotForDefinition); |
| genericSig = fn->getLoweredFunctionType()->getGenericSignature(); |
| fnRef = builder.createFunctionRef(ai->getLoc(), fn); |
| } |
| assert(adjointArray && "Array does not have adjoint value"); |
| assert(genericSig && "No generic signature"); |
| assert(fnRef && "Could not create `function_ref`"); |
| // Two loops because the `tuple_extract` instructions can be reached in |
| // either order. |
| for (auto use : ai->getUses()) { |
| auto *dti = dyn_cast<DestructureTupleInst>(use->getUser()); |
| if (!dti) continue; |
| // The second tuple field is the `RawPointer`. |
| for (auto use : dti->getResult(1)->getUses()) { |
| // The `RawPointer` passes through a `pointer_to_address`. That |
| // instruction's first use is a `store` whose src is useful; its |
| // subsequent uses are `index_addr`s whose only use is a useful |
| // `store`. In the indirect case, each `store` is instead a |
| // `copy_addr`. |
| for (auto use : use->getUser()->getResult(0)->getUses()) { |
| auto inst = use->getUser(); |
| if (auto si = dyn_cast<StoreInst>(inst)) { |
| auto tanType = getRemappedTangentType(si->getSrc()->getType()); |
| auto subscriptBuffer = emitArrayTangentSubscript( |
| ai, tanType, adjointArray, fnRef, genericSig, 0); |
| accumulateArrayTangentSubscriptDirect( |
| ai, tanType, si, subscriptBuffer); |
| } else if (auto cai = dyn_cast<CopyAddrInst>(inst)) { |
| auto tanType = getRemappedTangentType(cai->getSrc()->getType()); |
| auto subscriptBuffer = emitArrayTangentSubscript( |
| ai, tanType, adjointArray, fnRef, genericSig, 0); |
| accumulateArrayTangentSubscriptIndirect( |
| ai, cai, subscriptBuffer); |
| } else if (auto iai = dyn_cast<IndexAddrInst>(inst)) { |
| for (auto use : iai->getUses()) { |
| if (auto si = dyn_cast<StoreInst>(use->getUser())) { |
| auto literal = dyn_cast<IntegerLiteralInst>(iai->getIndex()); |
| auto tanType = getRemappedTangentType( |
| si->getSrc()->getType()); |
| auto subscriptBuffer = emitArrayTangentSubscript( |
| ai, tanType, adjointArray, fnRef, |
| genericSig, literal->getValue().getLimitedValue()); |
| accumulateArrayTangentSubscriptDirect( |
| ai, tanType, si, subscriptBuffer); |
| } else if (auto cai = dyn_cast<CopyAddrInst>(use->getUser())) { |
| auto literal = dyn_cast<IntegerLiteralInst>(iai->getIndex()); |
| auto tanType = getRemappedTangentType( |
| cai->getSrc()->getType()); |
| auto subscriptBuffer = emitArrayTangentSubscript( |
| ai, tanType, adjointArray, fnRef, |
| genericSig, literal->getValue().getLimitedValue()); |
| accumulateArrayTangentSubscriptIndirect( |
| ai, cai, subscriptBuffer); |
| } |
| } |
| } |
| } |
| } |
| } |
| } |
| |
| void visitApplyInst(ApplyInst *ai) { |
| assert(getPullbackInfo().shouldDifferentiateApplyInst(ai)); |
| // Handle array uninitialized allocation intrinsic specially. |
| if (isArrayLiteralIntrinsic(ai)) |
| return visitArrayInitialization(ai); |
| // Replace a call to a function with a call to its pullback. |
| auto &nestedApplyInfo = getContext().getNestedApplyInfo(); |
| auto applyInfoLookup = nestedApplyInfo.find(ai); |
| // If no `NestedApplyInfo` was found, then this task doesn't need to be |
| // differentiated. |
| if (applyInfoLookup == nestedApplyInfo.end()) { |
| // Must not be active. |
| assert(!getActivityInfo().isActive(ai, getIndices())); |
| return; |
| } |
| auto applyInfo = applyInfoLookup->getSecond(); |
| |
| // Get the pullback. |
| auto *field = getPullbackInfo().lookUpLinearMapDecl(ai); |
| assert(field); |
| auto loc = ai->getLoc(); |
| auto pullback = getPullbackStructElement(ai->getParent(), field); |
| |
| // Get the original result of the `apply` instruction. |
| SmallVector<SILValue, 8> args; |
| SmallVector<SILValue, 8> origDirectResults; |
| forEachApplyDirectResult(ai, [&](SILValue directResult) { |
| origDirectResults.push_back(directResult); |
| }); |
| SmallVector<SILValue, 8> origAllResults; |
| collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults); |
| assert(applyInfo.indices.source < origAllResults.size()); |
| auto origResult = origAllResults[applyInfo.indices.source]; |
| assert(origResult); |
| auto origNumIndRes = ai->getNumIndirectResults(); |
| |
| auto pullbackType = |
| remapType(pullback->getType()).castTo<SILFunctionType>(); |
| |
| // Get the seed (i.e. adjoint value of the original result). |
| SILValue seed; |
| auto *bb = ai->getParent(); |
| if (origResult->getType().isObject()) { |
| // Otherwise, materialize adjoint value of `ai`. |
| seed = materializeAdjoint(getAdjointValue(bb, origResult), loc); |
| } else { |
| seed = getAdjointBuffer(bb, origResult); |
| } |
| |
| // Create allocations for pullback indirect results. |
| SmallVector<AllocStackInst *, 4> pullbackIndirectResults; |
| auto actualPullbackType = applyInfo.originalPullbackType |
| ? *applyInfo.originalPullbackType |
| : pullbackType; |
| for (auto indRes : actualPullbackType->getIndirectFormalResults()) { |
| auto *alloc = |
| builder.createAllocStack(loc, remapType(indRes.getSILStorageType())); |
| pullbackIndirectResults.push_back(alloc); |
| args.push_back(alloc); |
| } |
| |
| // If callee pullback was reabstracted in VJP, reabstract callee pullback. |
| if (applyInfo.originalPullbackType) { |
| SILOptFunctionBuilder fb(getContext().getTransform()); |
| auto *thunk = getOrCreateReabstractionThunk( |
| fb, getContext().getModule(), loc, &getPullback(), |
| pullbackType, *applyInfo.originalPullbackType); |
| auto *thunkRef = builder.createFunctionRef(loc, thunk); |
| pullback = builder.createPartialApply( |
| loc, thunkRef, |
| remapSubstitutionMap(thunk->getForwardingSubstitutionMap()), |
| {pullback}, pullbackType->getCalleeConvention()); |
| } |
| args.push_back(seed); |
| |
| // Call the callee pullback. |
| auto *pullbackCall = builder.createApply( |
| loc, pullback, SubstitutionMap(), args, /*isNonThrowing*/ false); |
| builder.emitDestroyValueOperation(loc, pullback); |
| |
| // Extract all results from `pullbackCall`. |
| SmallVector<SILValue, 8> dirResults; |
| extractAllElements(pullbackCall, builder, dirResults); |
| // Get all results in type-defined order. |
| SmallVector<SILValue, 8> allResults; |
| collectAllActualResultsInTypeOrder(pullbackCall, dirResults, allResults); |
| LLVM_DEBUG({ |
| auto &s = getADDebugStream(); |
| s << "All results of the nested pullback call:\n"; |
| llvm::for_each(allResults, [&](SILValue v) { s << v; }); |
| }); |
| |
| // Accumulate adjoints for original differentiation parameters. |
| auto allResultsIt = allResults.begin(); |
| for (unsigned i : applyInfo.indices.parameters->getIndices()) { |
| auto origArg = ai->getArgument(origNumIndRes + i); |
| auto tan = *allResultsIt++; |
| if (tan->getType().isAddress()) { |
| addToAdjointBuffer(bb, origArg, tan, loc); |
| } else { |
| if (origArg->getType().isAddress()) { |
| auto *tmpBuf = builder.createAllocStack(loc, tan->getType()); |
| builder.emitStoreValueOperation(loc, tan, tmpBuf, |
| StoreOwnershipQualifier::Init); |
| addToAdjointBuffer(bb, origArg, tmpBuf, loc); |
| builder.emitDestroyAddrAndFold(loc, tmpBuf); |
| builder.createDeallocStack(loc, tmpBuf); |
| } |
| else { |
| recordTemporary(tan); |
| addAdjointValue(bb, origArg, makeConcreteAdjointValue(tan), loc); |
| } |
| } |
| } |
| // Destroy and deallocate pullback indirect results. |
| for (auto *alloc : reversed(pullbackIndirectResults)) { |
| builder.emitDestroyAddrAndFold(loc, alloc); |
| builder.createDeallocStack(loc, alloc); |
| } |
| } |
| |
| /// Handle `struct` instruction. |
| /// Original: y = struct (x0, x1, x2, ...) |
| /// Adjoint: adj[x0] += struct_extract adj[y], #x0 |
| /// adj[x1] += struct_extract adj[y], #x1 |
| /// adj[x2] += struct_extract adj[y], #x2 |
| /// ... |
| void visitStructInst(StructInst *si) { |
| auto *bb = si->getParent(); |
| auto loc = si->getLoc(); |
| auto *structDecl = si->getStructDecl(); |
| auto av = getAdjointValue(bb, si); |
| switch (av.getKind()) { |
| case AdjointValueKind::Zero: |
| for (auto *field : structDecl->getStoredProperties()) { |
| auto fv = si->getFieldValue(field); |
| addAdjointValue(bb, fv, |
| makeZeroAdjointValue(getRemappedTangentType(fv->getType())), loc); |
| } |
| break; |
| case AdjointValueKind::Concrete: { |
| auto adjStruct = materializeAdjointDirect(std::move(av), loc); |
| // Find the struct `TangentVector` type. |
| auto structTy = remapType(si->getType()).getASTType(); |
| auto tangentVectorTy = |
| getTangentSpace(structTy)->getType()->getCanonicalType(); |
| assert(!getModule().Types.getTypeLowering( |
| tangentVectorTy, ResilienceExpansion::Minimal) |
| .isAddressOnly()); |
| auto *tangentVectorDecl = |
| tangentVectorTy->getStructOrBoundGenericStruct(); |
| assert(tangentVectorDecl); |
| |
| auto *dti = builder.createDestructureStruct(si->getLoc(), adjStruct); |
| // Accumulate adjoints for the fields of the `struct` operand. |
| unsigned fieldIndex = 0; |
| for (auto it = structDecl->getStoredProperties().begin(); |
| it != structDecl->getStoredProperties().end(); ++it, ++fieldIndex) { |
| VarDecl *field = *it; |
| if (field->getAttrs().hasAttribute<NoDerivativeAttr>()) |
| continue; |
| // Find the corresponding field in the tangent space. |
| VarDecl *tanField = nullptr; |
| if (tangentVectorDecl == structDecl) |
| tanField = field; |
| // Otherwise, look up the field by name. |
| else { |
| auto tanFieldLookup = |
| tangentVectorDecl->lookupDirect(field->getName()); |
| if (tanFieldLookup.empty()) { |
| getContext().emitNondifferentiabilityError( |
| si, getInvoker(), |
| diag::autodiff_stored_property_no_corresponding_tangent, |
| tangentVectorDecl->getNameStr(), field->getNameStr()); |
| errorOccurred = true; |
| return; |
| } |
| tanField = cast<VarDecl>(tanFieldLookup.front()); |
| } |
| assert(tanField); |
| auto tanElt = dti->getResult(fieldIndex); |
| addAdjointValue( |
| bb, si->getFieldValue(field), |
| makeConcreteAdjointValue(tanElt), si->getLoc()); |
| } |
| break; |
| } |
| case AdjointValueKind::Aggregate: { |
| // Note: All user-called initializations go through the calls to the |
| // initializer, and synthesized initializers only have one level of struct |
| // formation which will not result into any aggregate adjoint valeus. |
| llvm_unreachable("Aggregate adjoint values should not occur for `struct` " |
| "instructions"); |
| } |
| } |
| } |
| |
| /// Handle `struct_extract` instruction. |
| /// Original: y = struct_extract x, #field |
| /// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0) |
| /// ^~~~~~~ |
| /// field in tangent space corresponding to #field |
| void visitStructExtractInst(StructExtractInst *sei) { |
| assert(!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() && |
| "`struct_extract` with `@noDerivative` field should not be " |
| "differentiated; activity analysis should not marked as varied"); |
| auto *bb = sei->getParent(); |
| auto structTy = remapType(sei->getOperand()->getType()).getASTType(); |
| auto tangentVectorTy = |
| getTangentSpace(structTy)->getType()->getCanonicalType(); |
| assert(!getModule().Types.getTypeLowering( |
| tangentVectorTy, ResilienceExpansion::Minimal) |
| .isAddressOnly()); |
| auto tangentVectorSILTy = |
| SILType::getPrimitiveObjectType(tangentVectorTy); |
| auto *tangentVectorDecl = |
| tangentVectorTy->getStructOrBoundGenericStruct(); |
| assert(tangentVectorDecl); |
| // Find the corresponding field in the tangent space. |
| VarDecl *tanField = nullptr; |
| // If the tangent space is the original struct, then field is the same. |
| if (tangentVectorDecl == sei->getStructDecl()) |
| tanField = sei->getField(); |
| // Otherwise, look up the field by name. |
| else { |
| auto tanFieldLookup = |
| tangentVectorDecl->lookupDirect(sei->getField()->getName()); |
| if (tanFieldLookup.empty()) { |
| getContext().emitNondifferentiabilityError( |
| sei, getInvoker(), |
| diag::autodiff_stored_property_no_corresponding_tangent, |
| sei->getStructDecl()->getNameStr(), |
| sei->getField()->getNameStr()); |
| errorOccurred = true; |
| return; |
| } |
| tanField = cast<VarDecl>(tanFieldLookup.front()); |
| } |
| // Accumulate adjoint for the `struct_extract` operand. |
| auto av = getAdjointValue(bb, sei); |
| switch (av.getKind()) { |
| case AdjointValueKind::Zero: |
| addAdjointValue(bb, sei->getOperand(), |
| makeZeroAdjointValue(tangentVectorSILTy), sei->getLoc()); |
| break; |
| case AdjointValueKind::Concrete: |
| case AdjointValueKind::Aggregate: { |
| SmallVector<AdjointValue, 8> eltVals; |
| for (auto *field : tangentVectorDecl->getStoredProperties()) { |
| if (field == tanField) { |
| eltVals.push_back(av); |
| } else { |
| auto substMap = tangentVectorTy->getMemberSubstitutionMap( |
| field->getModuleContext(), field); |
| auto fieldTy = field->getType().subst(substMap); |
| auto fieldSILTy = |
| getContext().getTypeConverter().getLoweredType( |
| fieldTy, ResilienceExpansion::Minimal); |
| assert(fieldSILTy.isObject()); |
| eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); |
| } |
| } |
| addAdjointValue(bb, sei->getOperand(), |
| makeAggregateAdjointValue(tangentVectorSILTy, eltVals), |
| sei->getLoc()); |
| } |
| } |
| } |
| |
| /// Handle `tuple` instruction. |
| /// Original: y = tuple (x0, x1, x2, ...) |
| /// Adjoint: adj[x0] += tuple_extract adj[y], 0 |
| /// ... |
| void visitTupleInst(TupleInst *ti) { |
| auto *bb = ti->getParent(); |
| auto av = getAdjointValue(bb, ti); |
| switch (av.getKind()) { |
| case AdjointValueKind::Zero: |
| for (auto eltVal : ti->getElements()) { |
| if (!getTangentSpace(eltVal->getType().getASTType())) |
| continue; |
| addAdjointValue(bb, eltVal, |
| makeZeroAdjointValue(getRemappedTangentType(eltVal->getType())), |
| ti->getLoc()); |
| } |
| break; |
| case AdjointValueKind::Concrete: { |
| auto val = av.getConcreteValue(); |
| unsigned adjIdx = 0; |
| auto elts = builder.createDestructureTuple(ti->getLoc(), val); |
| for (auto i : range(ti->getNumOperands())) { |
| if (!getTangentSpace(ti->getOperand(i)->getType().getASTType())) |
| continue; |
| auto adjElt = val; |
| if (val->getType().is<TupleType>()) |
| adjElt = elts->getResult(adjIdx++); |
| addAdjointValue(bb, ti->getOperand(i), |
| makeConcreteAdjointValue(adjElt), ti->getLoc()); |
| } |
| break; |
| } |
| case AdjointValueKind::Aggregate: |
| unsigned adjIdx = 0; |
| for (auto i : range(ti->getElements().size())) { |
| if (!getTangentSpace(ti->getElement(i)->getType().getASTType())) |
| continue; |
| addAdjointValue(bb, ti->getElement(i), av.getAggregateElement(adjIdx++), |
| ti->getLoc()); |
| } |
| break; |
| } |
| } |
| |
| /// Handle `tuple_extract` instruction. |
| /// Original: y = tuple_extract x, <n> |
| /// Adjoint: adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0) |
| /// ^~~~~~ |
| /// n'-th element, where n' is tuple tangent space |
| /// index corresponding to n |
| void visitTupleExtractInst(TupleExtractInst *tei) { |
| auto *bb = tei->getParent(); |
| auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType()); |
| auto av = getAdjointValue(bb, tei); |
| switch (av.getKind()) { |
| case AdjointValueKind::Zero: |
| addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy), |
| tei->getLoc()); |
| break; |
| case AdjointValueKind::Aggregate: |
| case AdjointValueKind::Concrete: { |
| auto tupleTy = tei->getTupleType(); |
| auto tupleTanTupleTy = tupleTanTy.getAs<TupleType>(); |
| if (!tupleTanTupleTy) { |
| addAdjointValue(bb, tei->getOperand(), av, tei->getLoc()); |
| break; |
| } |
| SmallVector<AdjointValue, 8> elements; |
| unsigned adjIdx = 0; |
| for (unsigned i : range(tupleTy->getNumElements())) { |
| if (!getTangentSpace( |
| tupleTy->getElement(i).getType()->getCanonicalType())) |
| continue; |
| if (tei->getFieldNo() == i) |
| elements.push_back(av); |
| else |
| elements.push_back(makeZeroAdjointValue( |
| getRemappedTangentType(SILType::getPrimitiveObjectType( |
| tupleTanTupleTy->getElementType(adjIdx++) |
| ->getCanonicalType())))); |
| } |
| if (elements.size() == 1) { |
| addAdjointValue(bb, tei->getOperand(), elements.front(), tei->getLoc()); |
| break; |
| } |
| addAdjointValue(bb, tei->getOperand(), |
| makeAggregateAdjointValue(tupleTanTy, elements), tei->getLoc()); |
| break; |
| } |
| } |
| } |
| |
| /// Handle `destructure_tuple` instruction. |
| /// Original: (y0, ..., yn) = destructure_tuple x |
| /// Adjoint: adj[x].0 += adj[y0] |
| /// ... |
| /// adj[x].n += adj[yn] |
| void visitDestructureTupleInst(DestructureTupleInst *dti) { |
| auto *bb = dti->getParent(); |
| auto tupleTanTy = getRemappedTangentType(dti->getOperand()->getType()); |
| SmallVector<AdjointValue, 8> adjValues; |
| for (auto origElt : dti->getResults()) { |
| if (!getTangentSpace(origElt->getType().getASTType())) |
| continue; |
| adjValues.push_back(getAdjointValue(bb, origElt)); |
| } |
| addAdjointValue(bb, dti->getOperand(), |
| makeAggregateAdjointValue(tupleTanTy, adjValues), |
| dti->getLoc()); |
| } |
| |
| /// Handle `load` or `load_borrow` instruction |
| /// Original: y = load/load_borrow x |
| /// Adjoint: adj[x] += adj[y] |
| void visitLoadOperation(SingleValueInstruction *inst) { |
| assert(isa<LoadInst>(inst) || isa<LoadBorrowInst>(inst)); |
| auto *bb = inst->getParent(); |
| auto adjVal = |
| materializeAdjointDirect(getAdjointValue(bb, inst), inst->getLoc()); |
| // Allocate a local buffer and store the adjoint value. This buffer will be |
| // used for accumulation into the adjoint buffer. |
| auto *localBuf = builder.createAllocStack(inst->getLoc(), adjVal->getType()); |
| auto copy = builder.emitCopyValueOperation(inst->getLoc(), adjVal); |
| builder.emitStoreValueOperation(inst->getLoc(), copy, localBuf, |
| StoreOwnershipQualifier::Init); |
| // Accumulate the adjoint value in the local buffer into the adjoint buffer. |
| addToAdjointBuffer(bb, inst->getOperand(0), localBuf, inst->getLoc()); |
| builder.emitDestroyAddr(inst->getLoc(), localBuf); |
| builder.createDeallocStack(inst->getLoc(), localBuf); |
| } |
| void visitLoadInst(LoadInst *li) { visitLoadOperation(li); } |
| void visitLoadBorrowInst(LoadBorrowInst *lbi) { visitLoadOperation(lbi); } |
| |
| /// Handle `store` or `store_borrow` instruction. |
| /// Original: store/store_borrow x to y |
| /// Adjoint: adj[x] += load adj[y]; adj[y] = 0 |
| void visitStoreOperation(SILBasicBlock *bb, SILLocation loc, |
| SILValue origSrc, SILValue origDest) { |
| auto &adjBuf = getAdjointBuffer(bb, origDest); |
| auto bufType = remapType(adjBuf->getType()); |
| auto adjVal = builder.emitLoadValueOperation( |
| loc, adjBuf, LoadOwnershipQualifier::Take); |
| recordTemporary(adjVal); |
| addAdjointValue(bb, origSrc, makeConcreteAdjointValue(adjVal), loc); |
| emitZeroIndirect(bufType.getASTType(), adjBuf, loc); |
| } |
| void visitStoreInst(StoreInst *si) { |
| visitStoreOperation( |
| si->getParent(), si->getLoc(), si->getSrc(), si->getDest()); |
| } |
| void visitStoreBorrowInst(StoreBorrowInst *sbi) { |
| visitStoreOperation( |
| sbi->getParent(), sbi->getLoc(), sbi->getSrc(), sbi->getDest()); |
| } |
| |
| /// Handle `copy_addr` instruction. |
| /// Original: copy_addr x to y |
| /// Adjoint: adj[x] += adj[y]; adj[y] = 0 |
| void visitCopyAddrInst(CopyAddrInst *cai) { |
| auto *bb = cai->getParent(); |
| auto &adjDest = getAdjointBuffer(bb, cai->getDest()); |
| auto destType = remapType(adjDest->getType()); |
| addToAdjointBuffer(bb, cai->getSrc(), adjDest, cai->getLoc()); |
| builder.emitDestroyAddrAndFold(cai->getLoc(), adjDest); |
| emitZeroIndirect(destType.getASTType(), adjDest, cai->getLoc()); |
| } |
| |
| /// Handle `copy_value` instruction. |
| /// Original: y = copy_value x |
| /// Adjoint: adj[x] += adj[y] |
| void visitCopyValueInst(CopyValueInst *cvi) { |
| auto *bb = cvi->getParent(); |
| auto adj = getAdjointValue(bb, cvi); |
| addAdjointValue(bb, cvi->getOperand(), adj, cvi->getLoc()); |
| } |
| |
| /// Handle `begin_borrow` instruction. |
| /// Original: y = begin_borrow x |
| /// Adjoint: adj[x] += adj[y] |
| void visitBeginBorrowInst(BeginBorrowInst *bbi) { |
| auto *bb = bbi->getParent(); |
| auto adj = getAdjointValue(bb, bbi); |
| addAdjointValue(bb, bbi->getOperand(), adj, bbi->getLoc()); |
| } |
| |
| /// Handle `begin_access` instruction. |
| /// Original: y = begin_access x |
| /// Adjoint: nothing |
| void visitBeginAccessInst(BeginAccessInst *bai) { |
| // Check for non-differentiable writes. |
| if (bai->getAccessKind() == SILAccessKind::Modify) { |
| if (auto *gai = dyn_cast<GlobalAddrInst>(bai->getSource())) { |
| getContext().emitNondifferentiabilityError(bai, getInvoker(), |
| diag::autodiff_cannot_differentiate_writes_to_global_variables); |
| errorOccurred = true; |
| return; |
| } |
| if (auto *pbi = dyn_cast<ProjectBoxInst>(bai->getSource())) { |
| getContext().emitNondifferentiabilityError(bai, getInvoker(), |
| diag::autodiff_cannot_differentiate_writes_to_mutable_captures); |
| errorOccurred = true; |
| return; |
| } |
| } |
| } |
| |
| /// Handle `unconditional_checked_cast_addr` instruction. |
| /// Original: y = unconditional_checked_cast_addr x |
| /// Adjoint: adj[x] += unconditional_checked_cast_addr adj[y] |
| void visitUnconditionalCheckedCastAddrInst( |
| UnconditionalCheckedCastAddrInst *uccai) { |
| auto *bb = uccai->getParent(); |
| auto &adjDest = getAdjointBuffer(bb, uccai->getDest()); |
| auto &adjSrc = getAdjointBuffer(bb, uccai->getSrc()); |
| auto destType = remapType(adjDest->getType()); |
| auto castBuf = builder.createAllocStack(uccai->getLoc(), adjSrc->getType()); |
| builder.createUnconditionalCheckedCastAddr( |
| uccai->getLoc(), adjDest, adjDest->getType().getASTType(), castBuf, |
| adjSrc->getType().getASTType()); |
| addToAdjointBuffer(bb, uccai->getSrc(), castBuf, uccai->getLoc()); |
| builder.emitDestroyAddrAndFold(uccai->getLoc(), castBuf); |
| builder.createDeallocStack(uccai->getLoc(), castBuf); |
| emitZeroIndirect(destType.getASTType(), adjDest, uccai->getLoc()); |
| } |
| |
| #define NOT_DIFFERENTIABLE(INST, DIAG) \ |
| void visit##INST##Inst(INST##Inst *inst) { \ |
| getContext().emitNondifferentiabilityError( \ |
| inst, getInvoker(), diag::DIAG); \ |
| errorOccurred = true; \ |
| return; \ |
| } |
| NOT_DIFFERENTIABLE(RefElementAddr, autodiff_class_property_not_supported) |
| #undef NOT_DIFFERENTIABLE |
| |
| #define NO_ADJOINT(INST) \ |
| void visit##INST##Inst(INST##Inst *inst) {} |
| // Terminators. |
| NO_ADJOINT(Return) |
| NO_ADJOINT(Branch) |
| NO_ADJOINT(CondBranch) |
| |
| // Buffer projection. |
| NO_ADJOINT(StructElementAddr) |
| NO_ADJOINT(TupleElementAddr) |
| |
| // Memory allocation/access. |
| NO_ADJOINT(AllocStack) |
| NO_ADJOINT(DeallocStack) |
| NO_ADJOINT(EndAccess) |
| |
| // Debugging/reference counting instructions. |
| NO_ADJOINT(DebugValue) |
| NO_ADJOINT(DebugValueAddr) |
| NO_ADJOINT(RetainValue) |
| NO_ADJOINT(RetainValueAddr) |
| NO_ADJOINT(ReleaseValue) |
| NO_ADJOINT(ReleaseValueAddr) |
| NO_ADJOINT(StrongRetain) |
| NO_ADJOINT(StrongRelease) |
| NO_ADJOINT(UnownedRetain) |
| NO_ADJOINT(UnownedRelease) |
| NO_ADJOINT(StrongRetainUnowned) |
| NO_ADJOINT(DestroyValue) |
| NO_ADJOINT(DestroyAddr) |
| |
| // Value ownership. |
| NO_ADJOINT(EndBorrow) |
| #undef NO_DERIVATIVE |
| }; |
| } // end anonymous namespace |
| |
| AdjointValue PullbackEmitter::makeZeroAdjointValue(SILType type) { |
| return AdjointValue::createZero(allocator, remapType(type)); |
| } |
| |
| AdjointValue |
| PullbackEmitter::makeConcreteAdjointValue(SILValue value) { |
| return AdjointValue::createConcrete(allocator, value); |
| } |
| |
| template<typename EltRange> |
| AdjointValue PullbackEmitter::makeAggregateAdjointValue( |
| SILType type, EltRange elements) { |
| return AdjointValue::createAggregate(allocator, remapType(type), elements); |
| } |
| |
| SILValue PullbackEmitter::materializeAdjointDirect( |
| AdjointValue val, SILLocation loc) { |
| assert(val.getType().isObject()); |
| LLVM_DEBUG(getADDebugStream() << |
| "Materializing adjoints for " << val << '\n'); |
| switch (val.getKind()) { |
| case AdjointValueKind::Zero: |
| return recordTemporary(emitZeroDirect(val.getType().getASTType(), loc)); |
| case AdjointValueKind::Aggregate: { |
| SmallVector<SILValue, 8> elements; |
| for (auto i : range(val.getNumAggregateElements())) { |
| auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc); |
| elements.push_back(builder.emitCopyValueOperation(loc, eltVal)); |
| } |
| if (val.getType().is<TupleType>()) |
| return recordTemporary( |
| builder.createTuple(loc, val.getType(), elements)); |
| else |
| return recordTemporary( |
| builder.createStruct(loc, val.getType(), elements)); |
| } |
| case AdjointValueKind::Concrete: |
| return val.getConcreteValue(); |
| } |
| } |
| |
| SILValue PullbackEmitter::materializeAdjoint(AdjointValue val, |
| SILLocation loc) { |
| if (val.isConcrete()) { |
| LLVM_DEBUG(getADDebugStream() |
| << "Materializing adjoint: Value is concrete.\n"); |
| return val.getConcreteValue(); |
| } |
| LLVM_DEBUG(getADDebugStream() << "Materializing adjoint: Value is " |
| "non-concrete. Materializing directly.\n"); |
| return materializeAdjointDirect(val, loc); |
| } |
| |
| void PullbackEmitter::materializeAdjointIndirect( |
| AdjointValue val, SILValue destBufferAccess, SILLocation loc) { |
| switch (val.getKind()) { |
| /// Given a `%buf : *T, emit instructions that produce a zero or an aggregate |
| /// of zeros of the expected type. When `T` conforms to |
| /// `AdditiveArithmetic`, we emit a call to `AdditiveArithmetic.zero`. When |
| /// `T` is a builtin float, we emit a `float_literal` instruction. |
| /// Otherwise, we assert that `T` must be an aggregate where each element |
| /// conforms to `AdditiveArithmetic` or is a builtin float. We expect to emit |
| /// a zero for each element and use the appropriate aggregate constructor |
| /// instruction (in this case, `tuple`) to produce a tuple. But currently, |
| /// since we need indirect passing for aggregate instruction, we just use |
| /// `tuple_element_addr` to get element buffers and write elements to them. |
| case AdjointValueKind::Zero: |
| emitZeroIndirect(val.getSwiftType(), destBufferAccess, loc); |
| break; |
| /// Given a `%buf : *(T0, T1, T2, ...)` or `%buf : *Struct` recursively emit |
| /// instructions to materialize the symbolic tuple or struct, filling the |
| /// buffer. |
| case AdjointValueKind::Aggregate: { |
| if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) { |
| for (auto idx : range(val.getNumAggregateElements())) { |
| auto eltTy = SILType::getPrimitiveAddressType( |
| tupTy->getElementType(idx)->getCanonicalType()); |
| auto *eltBuf = |
| builder.createTupleElementAddr(loc, destBufferAccess, idx, eltTy); |
| materializeAdjointIndirect( |
| val.getAggregateElement(idx), eltBuf, loc); |
| } |
| } else if (auto *structDecl = |
| val.getSwiftType()->getStructOrBoundGenericStruct()) { |
| auto fieldIt = structDecl->getStoredProperties().begin(); |
| for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end(); |
| ++fieldIt, ++i) { |
| auto eltBuf = |
| builder.createStructElementAddr(loc, destBufferAccess, *fieldIt); |
| materializeAdjointIndirect( |
| val.getAggregateElement(i), eltBuf, loc); |
| } |
| } else { |
| llvm_unreachable("Not an aggregate type"); |
| } |
| break; |
| } |
| /// Value is already materialized! |
| case AdjointValueKind::Concrete: |
| auto concreteVal = val.getConcreteValue(); |
| builder.emitStoreValueOperation(loc, concreteVal, destBufferAccess, |
| StoreOwnershipQualifier::Init); |
| break; |
| } |
| } |
| |
| void PullbackEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess, |
| SILLocation loc) { |
| auto tangentSpace = getTangentSpace(type); |
| assert(tangentSpace && "No tangent space for this type"); |
| switch (tangentSpace->getKind()) { |
| case VectorSpace::Kind::Vector: |
| emitZeroIntoBuffer(builder, type, bufferAccess, loc); |
| return; |
| case VectorSpace::Kind::Tuple: { |
| auto tupleType = tangentSpace->getTuple(); |
| SmallVector<SILValue, 8> zeroElements; |
| for (unsigned i : range(tupleType->getNumElements())) { |
| auto eltAddr = builder.createTupleElementAddr(loc, bufferAccess, i); |
| emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(), |
| eltAddr, loc); |
| } |
| return; |
| } |
| case VectorSpace::Kind::Function: { |
| llvm_unreachable( |
| "Unimplemented: Emit thunks for abstracting zero initialization"); |
| } |
| } |
| } |
| |
| SILValue PullbackEmitter::emitZeroDirect(CanType type, SILLocation loc) { |
| auto silType = getModule().Types.getLoweredLoadableType( |
| type, ResilienceExpansion::Minimal, getModule()); |
| auto *buffer = builder.createAllocStack(loc, silType); |
| emitZeroIndirect(type, buffer, loc); |
| auto loaded = builder.emitLoadValueOperation( |
| loc, buffer, LoadOwnershipQualifier::Take); |
| builder.createDeallocStack(loc, buffer); |
| return loaded; |
| } |
| |
| AdjointValue |
| PullbackEmitter::accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs, |
| SILLocation loc) { |
| LLVM_DEBUG(getADDebugStream() |
| << "Materializing adjoint directly.\nLHS: " << lhs |
| << "\nRHS: " << rhs << '\n'); |
| |
| switch (lhs.getKind()) { |
| // x |
| case AdjointValueKind::Concrete: { |
| auto lhsVal = lhs.getConcreteValue(); |
| switch (rhs.getKind()) { |
| // x + y |
| case AdjointValueKind::Concrete: { |
| auto rhsVal = rhs.getConcreteValue(); |
| auto sum = recordTemporary(accumulateDirect(lhsVal, rhsVal, loc)); |
| return makeConcreteAdjointValue(sum); |
| } |
| // x + 0 => x |
| case AdjointValueKind::Zero: |
| return lhs; |
| // x + (y, z) => (x.0 + y, x.1 + z) |
| case AdjointValueKind::Aggregate: |
| SmallVector<AdjointValue, 8> newElements; |
| auto lhsTy = lhsVal->getType().getASTType(); |
| auto lhsValCopy = builder.emitCopyValueOperation(loc, lhsVal); |
| if (auto *tupTy = lhsTy->getAs<TupleType>()) { |
| auto elts = builder.createDestructureTuple(loc, lhsValCopy); |
| llvm::for_each(elts->getResults(), |
| [this](SILValue result) { recordTemporary(result); }); |
| for (auto i : indices(elts->getResults())) { |
| auto rhsElt = rhs.getAggregateElement(i); |
| newElements.push_back(accumulateAdjointsDirect( |
| makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc)); |
| } |
| } else if (auto *structDecl = lhsTy->getStructOrBoundGenericStruct()) { |
| auto elts = |
| builder.createDestructureStruct(lhsVal.getLoc(), lhsValCopy); |
| llvm::for_each(elts->getResults(), |
| [this](SILValue result) { recordTemporary(result); }); |
| for (unsigned i : indices(elts->getResults())) { |
| auto rhsElt = rhs.getAggregateElement(i); |
| newElements.push_back( |
| accumulateAdjointsDirect( |
| makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc)); |
| } |
| } else { |
| llvm_unreachable("Not an aggregate type"); |
| } |
| return makeAggregateAdjointValue(lhsVal->getType(), newElements); |
| } |
| } |
| // 0 |
| case AdjointValueKind::Zero: |
| // 0 + x => x |
| return rhs; |
| // (x, y) |
| case AdjointValueKind::Aggregate: |
| switch (rhs.getKind()) { |
| // (x, y) + z => (x + z.0, y + z.1) |
| case AdjointValueKind::Concrete: |
| // x + 0 => x |
| case AdjointValueKind::Zero: |
| return lhs; |
| // (x, y) + (z, w) => (x + z, y + w) |
| case AdjointValueKind::Aggregate: { |
| SmallVector<AdjointValue, 8> newElements; |
| for (auto i : range(lhs.getNumAggregateElements())) |
| newElements.push_back( |
| accumulateAdjointsDirect(lhs.getAggregateElement(i), |
| rhs.getAggregateElement(i), |
| loc)); |
| return makeAggregateAdjointValue(lhs.getType(), newElements); |
| } |
| } |
| } |
| } |
| |
| SILValue PullbackEmitter::accumulateDirect(SILValue lhs, SILValue rhs, |
| SILLocation loc) { |
| // TODO: Optimize for the case when lhs == rhs. |
| LLVM_DEBUG(getADDebugStream() << |
| "Emitting adjoint accumulation for lhs: " << lhs << |
| " and rhs: " << rhs << "\n"); |
| assert(lhs->getType() == rhs->getType() && "Adjoints must have equal types!"); |
| assert(lhs->getType().isObject() && rhs->getType().isObject() && |
| "Adjoint types must be both object types!"); |
| auto adjointTy = lhs->getType(); |
| auto adjointASTTy = adjointTy.getASTType(); |
| auto tangentSpace = getTangentSpace(adjointASTTy); |
| auto lhsCopy = builder.emitCopyValueOperation(loc, lhs); |
| auto rhsCopy = builder.emitCopyValueOperation(loc, rhs); |
| assert(tangentSpace && "No tangent space for this type"); |
| switch (tangentSpace->getKind()) { |
| case VectorSpace::Kind::Vector: { |
| // Allocate buffers for inputs and output. |
| auto *resultBuf = builder.createAllocStack(loc, adjointTy); |
| auto *lhsBuf = builder.createAllocStack(loc, adjointTy); |
| auto *rhsBuf = builder.createAllocStack(loc, adjointTy); |
| // Initialize input buffers. |
| builder.emitStoreValueOperation(loc, lhsCopy, lhsBuf, |
| StoreOwnershipQualifier::Init); |
| builder.emitStoreValueOperation(loc, rhsCopy, rhsBuf, |
| StoreOwnershipQualifier::Init); |
| accumulateIndirect(resultBuf, lhsBuf, rhsBuf, loc); |
| builder.emitDestroyAddr(loc, lhsBuf); |
| builder.emitDestroyAddr(loc, rhsBuf); |
| // Deallocate input buffers. |
| builder.createDeallocStack(loc, rhsBuf); |
| builder.createDeallocStack(loc, lhsBuf); |
| auto val = builder.emitLoadValueOperation( |
| loc, resultBuf, LoadOwnershipQualifier::Take); |
| // Deallocate result buffer. |
| builder.createDeallocStack(loc, resultBuf); |
| return val; |
| } |
| case VectorSpace::Kind::Tuple: { |
| SmallVector<SILValue, 8> adjElements; |
| auto lhsElts = builder.createDestructureTuple(loc, lhsCopy)->getResults(); |
| auto rhsElts = builder.createDestructureTuple(loc, rhsCopy)->getResults(); |
| for (auto zipped : llvm::zip(lhsElts, rhsElts)) |
| adjElements.push_back( |
| accumulateDirect(std::get<0>(zipped), std::get<1>(zipped), loc)); |
| return builder.createTuple(loc, adjointTy, adjElements); |
| } |
| case VectorSpace::Kind::Function: { |
| llvm_unreachable( |
| "Unimplemented: Emit thunks for abstracting adjoint accumulation"); |
| } |
| } |
| } |
| |
| void PullbackEmitter::accumulateIndirect( |
| SILValue resultBufAccess, SILValue lhsBufAccess, SILValue rhsBufAccess, |
| SILLocation loc) { |
| // TODO: Optimize for the case when lhs == rhs. |
| assert(lhsBufAccess->getType() == rhsBufAccess->getType() && |
| "Adjoint values must have same type!"); |
| assert(lhsBufAccess->getType().isAddress() && |
| rhsBufAccess->getType().isAddress() && |
| "Adjoint values must both have address types!"); |
| auto adjointTy = lhsBufAccess->getType(); |
| auto adjointASTTy = adjointTy.getASTType(); |
| auto *swiftMod = getModule().getSwiftModule(); |
| auto tangentSpace = adjointASTTy->getAutoDiffAssociatedTangentSpace( |
| LookUpConformanceInModule(swiftMod)); |
| assert(tangentSpace && "No tangent space for this type"); |
| switch (tangentSpace->getKind()) { |
| case VectorSpace::Kind::Vector: { |
| auto *proto = getContext().getAdditiveArithmeticProtocol(); |
| auto *combinerFuncDecl = getContext().getPlusDecl(); |
| // Call the combiner function and return. |
| auto adjointParentModule = tangentSpace->getNominal() |
| ? tangentSpace->getNominal()->getModuleContext() |
| : getModule().getSwiftModule(); |
| auto confRef = adjointParentModule->lookupConformance(adjointASTTy, |
| proto); |
| assert(confRef.hasValue() && "Missing conformance to `AdditiveArithmetic`"); |
| SILDeclRef declRef(combinerFuncDecl, SILDeclRef::Kind::Func); |
| auto silFnTy = getContext().getTypeConverter().getConstantType(declRef); |
| // %0 = witness_method @+ |
| auto witnessMethod = builder.createWitnessMethod(loc, adjointASTTy, |
| *confRef, declRef, |
| silFnTy); |
| auto subMap = SubstitutionMap::getProtocolSubstitutions( |
| proto, adjointASTTy, *confRef); |
| // %1 = metatype $T.Type |
| auto metatypeType = |
| CanMetatypeType::get(adjointASTTy, MetatypeRepresentation::Thick); |
| auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType); |
| auto metatype = builder.createMetatype(loc, metatypeSILType); |
| // %2 = apply $0(%result, %new, %old, %1) |
| builder.createApply(loc, witnessMethod, subMap, |
| {resultBufAccess, rhsBufAccess, lhsBufAccess, metatype}, |
| /*isNonThrowing*/ false); |
| builder.emitDestroyValueOperation(loc, witnessMethod); |
| return; |
| } |
| case VectorSpace::Kind::Tuple: { |
| auto tupleType = tangentSpace->getTuple(); |
| for (unsigned i : range(tupleType->getNumElements())) { |
| auto *destAddr = builder.createTupleElementAddr(loc, resultBufAccess, i); |
| auto *eltAddrLHS = builder.createTupleElementAddr(loc, lhsBufAccess, i); |
| auto *eltAddrRHS = builder.createTupleElementAddr(loc, rhsBufAccess, i); |
| accumulateIndirect(destAddr, eltAddrLHS, eltAddrRHS, loc); |
| } |
| return; |
| } |
| case VectorSpace::Kind::Function: { |
| llvm_unreachable( |
| "Unimplemented: Emit thunks for abstracting adjoint value " |
| "accumulation"); |
| } |
| } |
| } |
| |
| void PullbackEmitter::accumulateIndirect(SILValue lhsDestAccess, |
| SILValue rhsAccess, SILLocation loc) { |
| assert(lhsDestAccess->getType().isAddress() && |
| rhsAccess->getType().isAddress()); |
| assert(lhsDestAccess->getFunction() == &getPullback()); |
| assert(rhsAccess->getFunction() == &getPullback()); |
| auto type = lhsDestAccess->getType(); |
| auto astType = type.getASTType(); |
| auto *swiftMod = getModule().getSwiftModule(); |
| auto tangentSpace = astType->getAutoDiffAssociatedTangentSpace( |
| LookUpConformanceInModule(swiftMod)); |
| assert(tangentSpace && "No tangent space for this type"); |
| switch (tangentSpace->getKind()) { |
| case VectorSpace::Kind::Vector: { |
| auto *proto = getContext().getAdditiveArithmeticProtocol(); |
| auto *accumulatorFuncDecl = getContext().getPlusEqualDecl(); |
| // Call the combiner function and return. |
| auto confRef = swiftMod->lookupConformance(astType, proto); |
| assert(confRef.hasValue() && "Missing conformance to `AdditiveArithmetic`"); |
| SILDeclRef declRef(accumulatorFuncDecl, SILDeclRef::Kind::Func); |
| auto silFnTy = getContext().getTypeConverter().getConstantType(declRef); |
| // %0 = witness_method @+= |
| auto witnessMethod = |
| builder.createWitnessMethod(loc, astType, *confRef, declRef, silFnTy); |
| auto subMap = |
| SubstitutionMap::getProtocolSubstitutions(proto, astType, *confRef); |
| // %1 = metatype $T.Type |
| auto metatypeType = |
| CanMetatypeType::get(astType, MetatypeRepresentation::Thick); |
| auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType); |
| auto metatype = builder.createMetatype(loc, metatypeSILType); |
| // %2 = apply $0(%lhs, %rhs, %1) |
| builder.createApply(loc, witnessMethod, subMap, |
| {lhsDestAccess, rhsAccess, metatype}, |
| /*isNonThrowing*/ false); |
| builder.emitDestroyValueOperation(loc, witnessMethod); |
| return; |
| } |
| case VectorSpace::Kind::Tuple: { |
| auto tupleType = tangentSpace->getTuple(); |
| for (unsigned i : range(tupleType->getNumElements())) { |
| auto *destAddr = builder.createTupleElementAddr(loc, lhsDestAccess, i); |
| auto *eltAddrRHS = builder.createTupleElementAddr(loc, rhsAccess, i); |
| accumulateIndirect(destAddr, eltAddrRHS, loc); |
| } |
| return; |
| } |
| case VectorSpace::Kind::Function: { |
| llvm_unreachable( |
| "Unimplemented: Emit thunks for abstracting adjoint value " |
| "accumulation"); |
| } |
| } |
| } |
| |
| bool VJPEmitter::run() { |
| LLVM_DEBUG(getADDebugStream() |
| << "Cloning original @" << original->getName() |
| << " to vjp @" << vjp->getName() << '\n'); |
| // Create entry BB and arguments. |
| auto *entry = vjp->createBasicBlock(); |
| createEntryArguments(vjp); |
| |
| // Clone. |
| SmallVector<SILValue, 4> entryArgs(entry->getArguments().begin(), |
| entry->getArguments().end()); |
| cloneFunctionBody(original, entry, entryArgs); |
| // If errors occurred, back out. |
| if (errorOccurred) |
| return true; |
| |
| // Each `@guaranteed` trampoline argument needs to have a lifetime-ending use |
| // past its destination argument's lifetime-ending uses (aka. `end_borrow`). |
| // `trampolinedGuaranteedPhiArguments` tracks all `@guaranteed` trampoline |
| // arguments. We emit an `end_borrow` immediately past each destination |
| // argument's lifetime-ending uses. |
| for (auto &trampolinedArgPair : trampolinedGuaranteedPhiArguments) { |
| for (auto *destArgUse : trampolinedArgPair.destinationArgument->getUses()) { |
| if (auto *lifetimeEnd = dyn_cast<EndBorrowInst>(destArgUse->getUser())) { |
| getBuilder().setInsertionPoint(lifetimeEnd->getParentBlock(), |
| std::next(lifetimeEnd->getIterator())); |
| getBuilder().emitEndBorrowOperation( |
| lifetimeEnd->getLoc(), trampolinedArgPair.trampolineArgument); |
| } |
| } |
| } |
| |
| // Generate pullback code. |
| PullbackEmitter PullbackEmitter(*this); |
| if (PullbackEmitter.run()) { |
| errorOccurred = true; |
| return true; |
| } |
| LLVM_DEBUG(getADDebugStream() << "Generated VJP for " |
| << original->getName() << ":\n" << *vjp); |
| return errorOccurred; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // `[differentiable]` attribute processing |
| //===----------------------------------------------------------------------===// |
| |
| SILFunction * |
| ADContext::declareExternalDerivativeFunction( |
| SILFunction *original, SILDifferentiableAttr *attr, StringRef name, |
| AutoDiffDerivativeFunctionKind kind) { |
| auto &module = getModule(); |
| auto &indices = attr->getIndices(); |
| auto originalTy = original->getLoweredFunctionType(); |
| auto originalLoc = original->getLocation(); |
| auto assocGenSig = getDerivativeGenericSignature(attr, original); |
| auto derivativeFnTy = originalTy->getAutoDiffDerivativeFunctionType( |
| indices.parameters, indices.source, kind, module.Types, |
| LookUpConformanceInModule(module.getSwiftModule()), assocGenSig); |
| SILOptFunctionBuilder fb(getTransform()); |
| // Create external function declaration. |
| auto *derivativeFn = fb.createFunction( |
| SILLinkage::PublicExternal, name, derivativeFnTy, |
| /*genericEnv*/ nullptr, originalLoc, original->isBare(), IsNotTransparent, |
| original->isSerialized(), original->isDynamicallyReplaceable()); |
| // Note: Setting debug scope prevents crashes during later transforms. |
| derivativeFn->setDebugScope(new (module) SILDebugScope(originalLoc, derivativeFn)); |
| return derivativeFn; |
| } |
| |
| static SILFunction *createEmptyVJP( |
| ADContext &context, SILFunction *original, SILDifferentiableAttr *attr, |
| bool isExported) { |
| LLVM_DEBUG({ |
| auto &s = getADDebugStream(); |
| s << "Creating VJP:\n\t"; |
| s << "Original type: " << original->getLoweredFunctionType() << "\n\t"; |
| }); |
| |
| auto &module = context.getModule(); |
| auto originalTy = original->getLoweredFunctionType(); |
| auto indices = attr->getIndices(); |
| |
| // === Create an empty VJP. === |
| Mangle::ASTMangler mangler; |
| auto vjpName = original->getASTContext().getIdentifier( |
| mangler.mangleAutoDiffDerivativeFunctionHelper( |
| original->getName(), AutoDiffDerivativeFunctionKind::VJP, indices)) |
| .str(); |
| auto vjpGenericSig = getDerivativeGenericSignature(attr, original); |
| |
| // RAII that pushes the original function's generic signature to |
| // `module.Types` so that calls to `module.Types.getTypeLowering()` below |
| // will know the VJP's generic parameter types. |
| Lowering::GenericContextScope genericContextScope( |
| module.Types, vjpGenericSig); |
| |
| auto *vjpGenericEnv = vjpGenericSig |
| ? vjpGenericSig->getGenericEnvironment() |
| : nullptr; |
| auto vjpType = originalTy->getAutoDiffDerivativeFunctionType( |
| indices.parameters, indices.source, AutoDiffDerivativeFunctionKind::VJP, |
| module.Types, LookUpConformanceInModule(module.getSwiftModule()), |
| vjpGenericSig); |
| |
| SILOptFunctionBuilder fb(context.getTransform()); |
| auto linkage = autodiff::getAutoDiffDerivativeFunctionLinkage( |
| original->getLinkage(), isExported); |
| auto *vjp = fb.createFunction(linkage, vjpName, vjpType, vjpGenericEnv, |
| original->getLocation(), original->isBare(), |
| IsNotTransparent, original->isSerialized(), |
| original->isDynamicallyReplaceable()); |
| vjp->setDebugScope(new (module) SILDebugScope(original->getLocation(), vjp)); |
| attr->setVJPName(vjpName); |
| |
| LLVM_DEBUG(llvm::dbgs() << "VJP type: " << vjp->getLoweredFunctionType() |
| << "\n"); |
| return vjp; |
| } |
| |
| static SILFunction *createEmptyJVP( |
| ADContext &context, SILFunction *original, SILDifferentiableAttr *attr, |
| bool isExported) { |
| LLVM_DEBUG({ |
| auto &s = getADDebugStream(); |
| s << "Creating JVP:\n\t"; |
| s << "Original type: " << original->getLoweredFunctionType() << "\n\t"; |
| }); |
| |
| auto &module = context.getModule(); |
| auto originalTy = original->getLoweredFunctionType(); |
| auto indices = attr->getIndices(); |
| |
| // === Create an empty JVP. === |
| Mangle::ASTMangler mangler; |
| auto jvpName = original->getASTContext().getIdentifier( |
| mangler.mangleAutoDiffDerivativeFunctionHelper( |
| original->getName(), AutoDiffDerivativeFunctionKind::JVP, indices)) |
| .str(); |
| auto jvpGenericSig = getDerivativeGenericSignature(attr, original); |
| |
| // RAII that pushes the original function's generic signature to |
| // `module.Types` so that calls to `module.Types.getTypeLowering()` below |
| // will know the VJP's generic parameter types. |
| Lowering::GenericContextScope genericContextScope( |
| module.Types, jvpGenericSig); |
| |
| auto *jvpGenericEnv = jvpGenericSig |
| ? jvpGenericSig->getGenericEnvironment() |
| : nullptr; |
| auto jvpType = originalTy->getAutoDiffDerivativeFunctionType( |
| indices.parameters, indices.source, |
| AutoDiffDerivativeFunctionKind::JVP, module.Types, |
| LookUpConformanceInModule(module.getSwiftModule()), jvpGenericSig); |
| |
| SILOptFunctionBuilder fb(context.getTransform()); |
| auto linkage = autodiff::getAutoDiffDerivativeFunctionLinkage( |
| original->getLinkage(), isExported); |
| auto *jvp = fb.createFunction(linkage, jvpName, jvpType, jvpGenericEnv, |
| original->getLocation(), original->isBare(), |
| IsNotTransparent, original->isSerialized(), |
| original->isDynamicallyReplaceable()); |
| jvp->setDebugScope(new (module) SILDebugScope(original->getLocation(), jvp)); |
| attr->setJVPName(jvpName); |
| |
| LLVM_DEBUG(llvm::dbgs() << "JVP type: " << jvp->getLoweredFunctionType() |
| << "\n"); |
| return jvp; |
| } |
| |
| /// Returns true on error. |
| bool ADContext::processDifferentiableAttribute( |
| SILFunction *original, SILDifferentiableAttr *attr, |
| DifferentiationInvoker invoker) { |
| auto &module = getModule(); |
| // Try to look up JVP only if attribute specifies JVP name or if original |
| // function is an external declaration. If JVP function cannot be found, |
| // create an external JVP reference. |
| StringRef jvpName; |
| SILFunction *jvp = nullptr; |
| if (attr->hasJVP()) { |
| jvpName = attr->getJVPName(); |
| } else if (original->isExternalDeclaration()) { |
| Mangle::ASTMangler mangler; |
| jvpName = original->getASTContext().getIdentifier( |
| mangler.mangleAutoDiffDerivativeFunctionHelper( |
| original->getName(), AutoDiffDerivativeFunctionKind::JVP, |
| attr->getIndices())).str(); |
| } |
| if (!jvpName.empty()) { |
| jvp = module.lookUpFunction(jvpName); |
| if (!jvp) |
| jvp = declareExternalDerivativeFunction( |
| original, attr, jvpName, AutoDiffDerivativeFunctionKind::JVP); |
| attr->setJVPName(jvpName); |
| } |
| |
| // If differentiation is triggered by `[differentiable]`, derivative function |
| // should share linkage of original function. |
| auto isDerivativeFnExported = |
| invoker.getKind() == |
| DifferentiationInvoker::Kind::SILDifferentiableAttribute; |
| |
| // Try to look up VJP only if attribute specifies VJP name or if original |
| // function is an external declaration. If VJP function cannot be found, |
| // create an external VJP reference. |
| StringRef vjpName; |
| SILFunction *vjp = nullptr; |
| if (attr->hasVJP()) { |
| vjpName = attr->getVJPName(); |
| } else if (original->isExternalDeclaration()) { |
| Mangle::ASTMangler mangler; |
| vjpName = original->getASTContext().getIdentifier( |
| mangler.mangleAutoDiffDerivativeFunctionHelper( |
| original->getName(), AutoDiffDerivativeFunctionKind::VJP, |
| attr->getIndices())).str(); |
| } |
| if (!vjpName.empty()) { |
| vjp = module.lookUpFunction(vjpName); |
| if (!vjp) |
| vjp = declareExternalDerivativeFunction( |
| original, attr, vjpName, AutoDiffDerivativeFunctionKind::VJP); |
| attr->setVJPName(vjpName); |
| } |
| |
| // If the JVP doesn't exist, need to synthesize it. |
| if (!jvp) { |
| // Diagnose: |
| // - Functions with no return. |
| // - Functions with unsupported control flow. |
| if (getASTContext().LangOpts.EnableExperimentalForwardModeDifferentiation && |
| (diagnoseNoReturn(*this, original, invoker) || |
| diagnoseUnsupportedControlFlow(*this, original, invoker))) |
| return true; |
| |
| jvp = createEmptyJVP(*this, original, attr, isDerivativeFnExported); |
| getGeneratedFunctions().push_back(jvp); |
| |
| // For now, only do JVP generation if the flag is enabled and if custom VJP |
| // does not exist. If custom VJP exists but custom JVP does not, skip JVP |
| // generation because generated JVP may not match semantics of custom VJP. |
| // Instead, create an empty JVP. |
| if (getASTContext().LangOpts.EnableExperimentalForwardModeDifferentiation && |
| !vjp) { |
| // JVP and differential generation do not currently support functions with |
| // multiple basic blocks. |
| if (original->getBlocks().size() > 1) { |
| emitNondifferentiabilityError( |
| original->getLocation().getSourceLoc(), invoker, |
| diag::autodiff_jvp_control_flow_not_supported); |
| return true; |
| } |
| |
| JVPEmitter emitter(*this, original, attr, jvp, invoker); |
| if (emitter.run()) |
| return true; |
| } else { |
| LLVM_DEBUG(getADDebugStream() |
| << "Generating empty JVP for original @" |
| << original->getName() << '\n'); |
| // Create empty JVP body since custom VJP exists. |
| auto *entry = jvp->createBasicBlock(); |
| createEntryArguments(jvp); |
| SILBuilder builder(entry); |
| auto loc = jvp->getLocation(); |
| |
| // Destroy all owned arguments. |
| for (auto *arg : entry->getArguments()) |
| if (arg->getOwnershipKind() == ValueOwnershipKind::Owned) |
| builder.emitDestroyOperation(loc, arg); |
| |
| // Fatal error in case this JVP is called by the user. |
| auto neverResultInfo = SILResultInfo( |
| module.getASTContext().getNeverType(), ResultConvention::Unowned); |
| auto fatalErrorJVPType = SILFunctionType::get( |
| /*genericSig*/ nullptr, |
| SILFunctionType::ExtInfo().withRepresentation( |
| SILFunctionTypeRepresentation::Thin), |
| SILCoroutineKind::None, ParameterConvention::Direct_Unowned, {}, |
| /*interfaceYields*/ {}, neverResultInfo, |
| /*interfaceErrorResults*/ None, getASTContext()); |
| auto fnBuilder = SILOptFunctionBuilder(getTransform()); |
| auto *fatalErrrorJvpFunc = fnBuilder.getOrCreateFunction( |
| loc, "_printJVPErrorAndExit", SILLinkage::PublicExternal, |
| fatalErrorJVPType, IsNotBare, IsNotTransparent, IsNotSerialized, |
| IsNotDynamic, ProfileCounter(), IsNotThunk); |
| auto *jvpErrorFuncRef = |
| builder.createFunctionRef(loc, fatalErrrorJvpFunc); |
| builder.createApply(loc, jvpErrorFuncRef, SubstitutionMap(), {}); |
| builder.createUnreachable(loc); |
| LLVM_DEBUG(getADDebugStream() << "Generated empty JVP for " |
| << original->getName() << ":\n" << *jvp); |
| } |
| } |
| |
| // If the VJP doesn't exist, need to synthesize it. |
| if (!vjp) { |
| // Diagnose: |
| // - Functions with no return. |
| // - Functions with unsupported control flow. |
| if (diagnoseNoReturn(*this, original, invoker) || |
| diagnoseUnsupportedControlFlow(*this, original, invoker)) |
| return true; |
| |
| vjp = createEmptyVJP(*this, original, attr, isDerivativeFnExported); |
| getGeneratedFunctions().push_back(vjp); |
| VJPEmitter emitter(*this, original, attr, vjp, invoker); |
| return emitter.run(); |
| } |
| |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Differentiation pass implementation |
| //===----------------------------------------------------------------------===// |
| |
| /// The automatic differentiation pass. |
| namespace { |
| class Differentiation : public SILModuleTransform { |
| public: |
| Differentiation() : SILModuleTransform() {} |
| void run() override; |
| }; |
| } // end anonymous namespace |
| |
| std::pair<SILFunction *, SubstitutionMap> |
| ADContext::getOrCreateSubsetParametersThunkForLinearMap( |
| SILFunction *parentThunk, CanSILFunctionType linearMapType, |
| CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind, |
| SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices) { |
| LLVM_DEBUG(getADDebugStream() |
| << "Getting a subset parameters thunk for " << linearMapType |
| << " from " << actualIndices << " to " << desiredIndices << '\n'); |
| |
| SubstitutionMap interfaceSubs; |
| GenericEnvironment *genericEnv = nullptr; |
| auto thunkType = buildThunkType( |
| parentThunk, linearMapType, targetType, genericEnv, interfaceSubs, |
| /*withoutActuallyEscaping*/ true, |
| DifferentiationThunkKind::Reabstraction); |
| |
| // TODO(TF-685): Use more principled mangling for thunks. |
| std::string thunkName; |
| switch (kind) { |
| case AutoDiffDerivativeFunctionKind::JVP: |
| thunkName = "differential"; |
| break; |
| case AutoDiffDerivativeFunctionKind::VJP: |
| thunkName = "pullback"; |
| } |
| Mangle::ASTMangler mangler; |
| auto fromInterfaceType = |
| linearMapType->mapTypeOutOfContext()->getCanonicalType(); |
| auto toInterfaceType = targetType->mapTypeOutOfContext()->getCanonicalType(); |
| CanType dynamicSelfType; |
| thunkName = "AD__" + mangler.mangleReabstractionThunkHelper( |
| thunkType, fromInterfaceType, toInterfaceType, dynamicSelfType, |
| module.getSwiftModule()) + "_" + desiredIndices.mangle() + "_" + |
| thunkName; |
| thunkName += "_index_subset_thunk"; |
| |
| auto loc = parentThunk->getLocation(); |
| SILOptFunctionBuilder fb(getTransform()); |
| auto *thunk = fb.getOrCreateSharedFunction( |
| loc, thunkName, thunkType, IsBare, IsTransparent, IsSerialized, |
| ProfileCounter(), IsThunk, IsNotDynamic); |
| |
| if (!thunk->empty()) |
| return {thunk, interfaceSubs}; |
| |
| thunk->setGenericEnvironment(genericEnv); |
| thunk->setOwnershipEliminated(); |
| auto *entry = thunk->createBasicBlock(); |
| SILBuilder builder(entry); |
| createEntryArguments(thunk); |
| |
| // Get arguments. |
| SmallVector<SILValue, 4> arguments; |
| SmallVector<AllocStackInst *, 4> localAllocations; |
| |
| // Build a `.zero` argument for the given `Differentiable`-conforming type. |
| auto buildZeroArgument = [&](SILType zeroSILType) { |
| auto zeroSILObjType = zeroSILType.getObjectType(); |
| auto zeroType = zeroSILType.getASTType(); |
| auto *swiftMod = getModule().getSwiftModule(); |
| auto tangentSpace = zeroType->getAutoDiffAssociatedTangentSpace( |
| LookUpConformanceInModule(swiftMod)); |
| assert(tangentSpace && "No tangent space for this type"); |
| switch (tangentSpace->getKind()) { |
| case VectorSpace::Kind::Vector: { |
| auto *buf = builder.createAllocStack(loc, zeroSILObjType); |
| localAllocations.push_back(buf); |
| emitZeroIntoBuffer(builder, zeroType, buf, loc); |
| if (zeroSILType.isAddress()) |
| arguments.push_back(buf); |
| else { |
| auto *arg = builder.createLoad(loc, buf, |
| LoadOwnershipQualifier::Unqualified); |
| arguments.push_back(arg); |
| } |
| break; |
| } |
| case VectorSpace::Kind::Tuple: { |
| llvm_unreachable( |
| "Unimplemented: Handle zero initialization for tuples"); |
| } |
| case VectorSpace::Kind::Function: |
| llvm_unreachable( |
| "Unimplemented: Emit thunks for abstracting zero initialization"); |
| } |
| }; |
| |
| // `actualIndices` and `desiredIndices` are with respect to the original |
| // function. However, the differential parameters and pullback results may |
| // already be w.r.t. a subset. We create a map between the original function's |
| // actual parameter indices and the linear map's actual indices. |
| // Example: |
| // Original: (T0, T1, T2) -> R |
| // Actual indices: 0, 2 |
| // Original differential: (T0, T2) -> R |
| // Original pullback: R -> (T0, T2) |
| // Desired indices w.r.t. original: 2 |
| // Desired indices w.r.t. linear map: 1 |
| SmallVector<unsigned, 4> actualParamIndicesMap( |
| actualIndices.parameters->getCapacity(), UINT_MAX); |
| { |
| unsigned indexInBitVec = 0; |
| for (auto index : actualIndices.parameters->getIndices()) { |
| actualParamIndicesMap[index] = indexInBitVec; |
| indexInBitVec++; |
| } |
| } |
| auto mapOriginalParameterIndex = [&](unsigned index) -> unsigned { |
| auto mappedIndex = actualParamIndicesMap[index]; |
| assert(mappedIndex < actualIndices.parameters->getCapacity()); |
| return mappedIndex; |
| }; |
| |
| switch (kind) { |
| // Differential arguments are: |
| // - All indirect results, followed by: |
| // - An interleaving of: |
| // - Thunk arguments (when parameter index is in both desired and actual |
| // indices). |
| // - Zeros (when parameter is not in desired indices). |
| case AutoDiffDerivativeFunctionKind::JVP: { |
| // Forward all indirect results. |
| arguments.append(thunk->getIndirectResults().begin(), |
| thunk->getIndirectResults().end()); |
| auto toArgIter = thunk->getArgumentsWithoutIndirectResults().begin(); |
| auto useNextArgument = [&]() { |
| arguments.push_back(*toArgIter++); |
| }; |
| // Iterate over actual indices. |
| for (unsigned i : actualIndices.parameters->getIndices()) { |
| // If index is desired, use next argument. |
| if (desiredIndices.isWrtParameter(i)) { |
| useNextArgument(); |
| } |
| // Otherwise, construct and use a zero argument. |
| else { |
| auto zeroSILType = |
| linearMapType->getParameters()[mapOriginalParameterIndex(i)] |
| .getSILStorageType(); |
| buildZeroArgument(zeroSILType); |
| } |
| } |
| break; |
| } |
| // Pullback arguments are: |
| // - An interleaving of: |
| // - Thunk indirect results (when parameter index is in both desired and |
| // actual indices). |
| // - Zeros (when parameter is not in desired indices). |
| // - All actual arguments. |
| case AutoDiffDerivativeFunctionKind::VJP: { |
| auto toIndirectResultsIter = thunk->getIndirectResults().begin(); |
| auto useNextResult = [&]() { |
| arguments.push_back(*toIndirectResultsIter++); |
| }; |
| // Iterate over actual indices. |
| for (unsigned i : actualIndices.parameters->getIndices()) { |
| auto resultInfo = |
| linearMapType->getResults()[mapOriginalParameterIndex(i)]; |
| // Skip direct results. Only indirect results are relevant as arguments. |
| if (resultInfo.isFormalDirect()) |
| continue; |
| // If index is desired, use next indirect result. |
| if (desiredIndices.isWrtParameter(i)) { |
| useNextResult(); |
| continue; |
| } |
| // Otherwise, construct and use an uninitialized indirect result. |
| auto *indirectResult = |
| builder.createAllocStack(loc, resultInfo.getSILStorageType()); |
| localAllocations.push_back(indirectResult); |
| arguments.push_back(indirectResult); |
| } |
| // Foward all actual non-indirect-result arguments. |
| arguments.append(thunk->getArgumentsWithoutIndirectResults().begin(), |
| thunk->getArgumentsWithoutIndirectResults().end() - 1); |
| break; |
| } |
| } |
| |
| // Get the linear map thunk argument and apply it. |
| auto *linearMap = thunk->getArguments().back(); |
| auto *ai = builder.createApply( |
| loc, linearMap, SubstitutionMap(), arguments, /*isNonThrowing*/ false); |
| |
| // If differential thunk, deallocate local allocations and directly return |
| // `apply` result. |
| if (kind == AutoDiffDerivativeFunctionKind::JVP) { |
| for (auto *alloc : reversed(localAllocations)) |
| builder.createDeallocStack(loc, alloc); |
| builder.createReturn(loc, ai); |
| return {thunk, interfaceSubs}; |
| } |
| |
| // If pullback thunk, return only the desired results and clean up the |
| // undesired results. |
| SmallVector<SILValue, 8> pullbackDirectResults; |
| extractAllElements(ai, builder, pullbackDirectResults); |
| SmallVector<SILValue, 8> allResults; |
| collectAllActualResultsInTypeOrder(ai, pullbackDirectResults, allResults); |
| |
| SmallVector<SILValue, 8> results; |
| for (unsigned i : actualIndices.parameters->getIndices()) { |
| // If result is desired: |
| // - Do nothing if result is indirect. |
| // (It was already forwarded to the `apply` instruction). |
| // - Push it to `results` if result is direct. |
| auto result = allResults[mapOriginalParameterIndex(i)]; |
| if (desiredIndices.isWrtParameter(i)) { |
| if (result->getType().isObject()) |
| results.push_back(result); |
| } |
| // Otherwise, cleanup the unused results. |
| else { |
| if (result->getType().isAddress()) |
| builder.emitDestroyAddrAndFold(loc, result); |
| else |
| builder.emitDestroyValueOperation(loc, result); |
| } |
| } |
| // Deallocate local allocations and return final direct result. |
| for (auto *alloc : reversed(localAllocations)) |
| builder.createDeallocStack(loc, alloc); |
| auto result = joinElements(results, builder, loc); |
| builder.createReturn(loc, result); |
| |
| getGeneratedFunctions().push_back(thunk); |
| return {thunk, interfaceSubs}; |
| } |
| |
| std::pair<SILFunction *, SubstitutionMap> |
| ADContext::getOrCreateSubsetParametersThunkForDerivativeFunction( |
| SILValue origFnOperand, SILValue derivativeFn, |
| AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices, |
| SILAutoDiffIndices actualIndices) { |
| LLVM_DEBUG(getADDebugStream() |
| << "Getting a subset parameters thunk for derivative function " |
| << derivativeFn << " of the original function " << origFnOperand |
| << " from " << actualIndices << " to " << desiredIndices << '\n'); |
| |
| auto origFnType = origFnOperand->getType().castTo<SILFunctionType>(); |
| auto &module = getModule(); |
| auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule()); |
| |
| // Compute target type for thunking. |
| auto derivativeFnType = derivativeFn->getType().castTo<SILFunctionType>(); |
| auto targetType = origFnType->getAutoDiffDerivativeFunctionType( |
| desiredIndices.parameters, desiredIndices.source, kind, module.Types, |
| lookupConformance); |
| auto *caller = derivativeFn->getFunction(); |
| if (targetType->hasArchetype()) { |
| auto substTargetType = caller->mapTypeIntoContext( |
| targetType->mapTypeOutOfContext())->getCanonicalType(); |
| targetType = SILType::getPrimitiveObjectType(substTargetType) |
| .castTo<SILFunctionType>(); |
| } |
| assert(derivativeFnType->getNumParameters() == targetType->getNumParameters()); |
| assert(derivativeFnType->getNumResults() == targetType->getNumResults()); |
| |
| // Build thunk type. |
| SubstitutionMap interfaceSubs; |
| GenericEnvironment *genericEnv = nullptr; |
| auto thunkType = buildThunkType( |
| derivativeFn->getFunction(), derivativeFnType, targetType, genericEnv, |
| interfaceSubs, /*withoutActuallyEscaping*/ false, |
| DifferentiationThunkKind::IndexSubset); |
| |
| // FIXME: The logic for resolving `assocRef` does not reapply function |
| // conversions, which is problematic if `derivativeFn` is a `partial_apply` |
| // instruction. |
| StringRef origName; |
| if (auto *origFnRef = |
| peerThroughFunctionConversions<FunctionRefInst>(origFnOperand)) { |
| origName = origFnRef->getInitiallyReferencedFunction()->getName(); |
| } else if (auto *origMethodInst = |
| peerThroughFunctionConversions<MethodInst>(origFnOperand)) { |
| origName = origMethodInst->getMember().getAnyFunctionRef() |
| ->getAbstractFunctionDecl()->getNameStr(); |
| } |
| assert(!origName.empty() && "Original function name could not be resolved"); |
| // TODO(TF-685): Use more principled mangling for thunks. |
| std::string thunkName; |
| switch (kind) { |
| case AutoDiffDerivativeFunctionKind::JVP: |
| thunkName = "jvp"; |
| break; |
| case AutoDiffDerivativeFunctionKind::VJP: |
| thunkName = "vjp"; |
| } |
| Mangle::ASTMangler mangler; |
| auto fromInterfaceType = |
| derivativeFnType->mapTypeOutOfContext()->getCanonicalType(); |
| auto toInterfaceType = targetType->mapTypeOutOfContext()->getCanonicalType(); |
| CanType dynamicSelfType; |
| thunkName = "AD__orig_" + origName.str() + "_" + |
| mangler.mangleReabstractionThunkHelper( |
| thunkType, fromInterfaceType, toInterfaceType, dynamicSelfType, |
| module.getSwiftModule()) + "_" + desiredIndices.mangle() + "_" + |
| thunkName; |
| thunkName += "_subset_parameters_thunk"; |
| |
| auto loc = origFnOperand.getLoc(); |
| SILOptFunctionBuilder fb(getTransform()); |
| auto *thunk = fb.getOrCreateSharedFunction( |
| loc, thunkName, thunkType, IsBare, IsTransparent, caller->isSerialized(), |
| ProfileCounter(), IsThunk, IsNotDynamic); |
| |
| if (!thunk->empty()) |
| return {thunk, interfaceSubs}; |
| |
| thunk->setOwnershipEliminated(); |
| thunk->setGenericEnvironment(genericEnv); |
| auto *entry = thunk->createBasicBlock(); |
| SILBuilder builder(entry); |
| createEntryArguments(thunk); |
| |
| SubstitutionMap assocSubstMap; |
| if (auto *partialApply = dyn_cast<PartialApplyInst>(derivativeFn)) |
| assocSubstMap = partialApply->getSubstitutionMap(); |
| |
| // FIXME: The logic for resolving `assocRef` does not reapply function |
| // conversions, which is problematic if `derivativeFn` is a `partial_apply` |
| // instruction. |
| SILValue assocRef; |
| if (auto *derivativeFnRef = |
| peerThroughFunctionConversions<FunctionRefInst>(derivativeFn)) { |
| auto *assoc = derivativeFnRef->getReferencedFunctionOrNull(); |
| assocRef = builder.createFunctionRef(loc, assoc); |
| } else if (auto *assocMethodInst = |
| peerThroughFunctionConversions<WitnessMethodInst>(derivativeFn)) { |
| assocRef = builder.createWitnessMethod( |
| loc, assocMethodInst->getLookupType(), |
| assocMethodInst->getConformance(), assocMethodInst->getMember(), |
| thunk->mapTypeIntoContext(assocMethodInst->getType())); |
| } else if (auto *assocMethodInst = |
| peerThroughFunctionConversions<ClassMethodInst>(derivativeFn)) { |
| auto classOperand = thunk->getArgumentsWithoutIndirectResults().back(); |
| auto classOperandType = assocMethodInst->getOperand()->getType(); |
| assert(classOperand->getType() == classOperandType); |
| assocRef = builder.createClassMethod( |
| loc, classOperand, assocMethodInst->getMember(), |
| thunk->mapTypeIntoContext(assocMethodInst->getType())); |
| } |
| assert(assocRef && "Expected derivative function to be resolved"); |
| |
| assocSubstMap = assocSubstMap.subst(thunk->getForwardingSubstitutionMap()); |
| derivativeFnType = assocRef->getType().castTo<SILFunctionType>(); |
| |
| SmallVector<SILValue, 4> arguments; |
| arguments.append(thunk->getArguments().begin(), thunk->getArguments().end()); |
| assert(arguments.size() == derivativeFnType->getNumParameters() + |
| derivativeFnType->getNumIndirectFormalResults()); |
| auto *apply = builder.createApply( |
| loc, assocRef, assocSubstMap, arguments, /*isNonThrowing*/ false); |
| |
| // Extract all direct results. |
| SmallVector<SILValue, 8> directResults; |
| extractAllElements(apply, builder, directResults); |
| auto originalDirectResults = ArrayRef<SILValue>(directResults).drop_back(1); |
| auto originalDirectResult = |
| joinElements(originalDirectResults, builder, apply->getLoc()); |
| auto linearMap = directResults.back(); |
| |
| auto linearMapType = linearMap->getType().castTo<SILFunctionType>(); |
| auto linearMapTargetType = targetType->getResults().back().getSILStorageType() |
| .castTo<SILFunctionType>(); |
| |
| SILFunction *linearMapThunk; |
| SubstitutionMap linearMapSubs; |
| std::tie(linearMapThunk, linearMapSubs) = |
| getOrCreateSubsetParametersThunkForLinearMap( |
| thunk, linearMapType, linearMapTargetType, kind, |
| desiredIndices, actualIndices); |
| |
| auto *linearMapThunkFRI = builder.createFunctionRef(loc, linearMapThunk); |
| auto *thunkedLinearMap = builder.createPartialApply( |
| loc, linearMapThunkFRI, linearMapSubs, {linearMap}, |
| ParameterConvention::Direct_Guaranteed); |
| |
| assert(origFnType->getResults().size() == 1); |
| if (origFnType->getResults().front().isFormalDirect()) { |
| auto result = joinElements( |
| {originalDirectResult, thunkedLinearMap}, builder, loc); |
| builder.createReturn(loc, result); |
| } else { |
| builder.createReturn(loc, thunkedLinearMap); |
| } |
| |
| getGeneratedFunctions().push_back(thunk); |
| return {thunk, interfaceSubs}; |
| } |
| |
| /// Promotes the original function of the given `differentiable_function` |
| /// instruction to a `@differentiable` function-typed value, emitting the |
| /// required instructions with `builder` at `loc`. |
| /// |
| /// Two strategies are used: |
| /// - If the original function is an `apply` of a `function_ref` whose single |
| ///   result is function-typed (handled here as a curry thunk application), a |
| ///   new thunk returning a `@differentiable` function is created (or reused) |
| ///   and applied to copies of the original arguments. |
| /// - Otherwise, JVP and VJP references are emitted for the original function |
| ///   (wrapped in parameter subset thunks when the derivative that was found |
| ///   does not match the desired indices) and bundled with a copy of the |
| ///   original function into a new `differentiable_function` instruction. |
| /// |
| /// Returns the promoted value, or null if promotion failed. `invoker` is |
| /// forwarded to derivative function lookup and to error reporting. |
| SILValue ADContext::promoteToDifferentiableFunction( |
|     DifferentiableFunctionInst *dfi, SILBuilder &builder, SILLocation loc, |
|     DifferentiationInvoker invoker) { |
|   auto origFnOperand = dfi->getOriginalFunction(); |
|   auto origFnTy = origFnOperand->getType().castTo<SILFunctionType>(); |
|   auto parameterIndices = dfi->getParameterIndices(); |
|   unsigned resultIndex = resultIndices[dfi]; |
| |
|   // Handle curry thunk applications specially. |
|   if (auto *ai = dyn_cast<ApplyInst>(origFnOperand)) { |
|     if (auto *thunkRef = dyn_cast<FunctionRefInst>(ai->getCallee())) { |
|       // Create a new curry thunk. |
|       SILAutoDiffIndices desiredIndices(resultIndex, parameterIndices); |
|       auto *thunk = thunkRef->getReferencedFunctionOrNull(); |
|       // TODO(TF-685): Use more principled mangling for thunks. |
|       auto newThunkName = "AD__" + thunk->getName().str() + |
|           "__differentiable_curry_thunk_" + desiredIndices.mangle(); |
| |
|       auto thunkTy = thunk->getLoweredFunctionType(); |
|       auto thunkResult = thunkTy->getSingleResult(); |
|       if (auto resultFnTy = thunkResult.getType()->getAs<SILFunctionType>()) { |
|         // Construct new curry thunk type with `@differentiable` result. |
|         auto diffableResultFnTy = resultFnTy->getWithExtInfo( |
|             resultFnTy->getExtInfo() |
|                 .withDifferentiabilityKind(DifferentiabilityKind::Normal)); |
|         auto newThunkResult = thunkResult.getWithType(diffableResultFnTy); |
|         auto thunkType = SILFunctionType::get( |
|             thunkTy->getGenericSignature(), thunkTy->getExtInfo(), |
|             thunkTy->getCoroutineKind(), thunkTy->getCalleeConvention(), |
|             thunkTy->getParameters(), {}, {newThunkResult}, {}, |
|             thunkTy->getASTContext()); |
| |
|         // Construct new curry thunk, returning a `@differentiable` function. |
|         SILOptFunctionBuilder fb(transform); |
|         auto *newThunk = fb.getOrCreateFunction( |
|             loc, newThunkName, |
|             getSpecializedLinkage(thunk, thunk->getLinkage()), thunkType, |
|             thunk->isBare(), thunk->isTransparent(), thunk->isSerialized(), |
|             thunk->isDynamicallyReplaceable(), ProfileCounter(), |
|             thunk->isThunk()); |
|         // If new thunk is newly created: clone the old thunk body, wrap the |
|         // returned function value with a `differentiable_function` |
|         // instruction, and process the `differentiable_function` instruction. |
|         if (newThunk->empty()) { |
|           if (auto newThunkGenSig = thunkType->getGenericSignature()) |
|             newThunk->setGenericEnvironment( |
|                 newThunkGenSig->getGenericEnvironment()); |
|           newThunk->setOwnershipEliminated(); |
|           BasicTypeSubstCloner cloner(thunk, newThunk); |
|           cloner.run(); |
|           auto *retInst = |
|               cast<ReturnInst>(newThunk->findReturnBB()->getTerminator()); |
|           SILBuilder thunkBuilder(retInst); |
|           // Named `thunkDFI` so that it does not shadow the `dfi` parameter. |
|           auto *thunkDFI = createDifferentiableFunction(thunkBuilder, loc, |
|                                                         parameterIndices, |
|                                                         retInst->getOperand()); |
|           resultIndices[thunkDFI] = resultIndex; |
|           thunkBuilder.createReturn(loc, thunkDFI); |
|           retInst->eraseFromParent(); |
| |
|           getGeneratedFunctions().push_back(newThunk); |
|           getDifferentiableFunctionInsts().push_back(thunkDFI); |
|           // A `true` result signals an error; propagate it as a null value. |
|           if (processDifferentiableFunctionInst(thunkDFI)) |
|             return nullptr; |
|         } |
| |
|         // Apply the new curry thunk. |
|         auto *newThunkRef = builder.createFunctionRef(loc, newThunk); |
|         getGeneratedFunctionReferences().push_back(newThunkRef); |
|         SmallVector<SILValue, 8> newArgs; |
|         SmallVector<SILValue, 8> newArgsToDestroy; |
|         SmallVector<AllocStackInst *, 1> newBuffersToDealloc; |
|         copyParameterArgumentsForApply(ai, newArgs, newArgsToDestroy, |
|                                        newBuffersToDealloc); |
|         auto *newApply = builder.createApply( |
|             ai->getLoc(), newThunkRef, ai->getSubstitutionMap(), newArgs, |
|             ai->isNonThrowing()); |
|         for (auto arg : newArgsToDestroy) { |
|           if (arg->getType().isObject()) |
|             builder.emitDestroyValueOperation(loc, arg); |
|           else |
|             builder.emitDestroyAddr(loc, arg); |
|         } |
|         // Deallocate the temporary buffers in reverse order so that the |
|         // `dealloc_stack` instructions follow stack discipline, matching the |
|         // deallocation loop at the end of this function. |
|         for (auto *alloc : reversed(newBuffersToDealloc)) |
|           builder.createDeallocStack(loc, alloc); |
|         return newApply; |
|       } |
|     } |
|   } |
| |
|   SILAutoDiffIndices desiredIndices(resultIndex, parameterIndices); |
|   SmallVector<SILValue, 2> derivativeFns; |
|   SmallVector<AllocStackInst *, 2> newBuffersToDealloc; |
|   for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP, |
|                                 AutoDiffDerivativeFunctionKind::VJP}) { |
|     auto derivativeFnAndIndices = emitDerivativeFunctionReference( |
|         *this, builder, desiredIndices, derivativeFnKind, origFnOperand, invoker, |
|         newBuffersToDealloc); |
|     // If no derivative function could be referenced, give up. |
|     // NOTE(review): the diagnostic (error at the operator, note at the |
|     // definition site of the argument) is presumably emitted by |
|     // `emitDerivativeFunctionReference` - confirm. |
|     if (!derivativeFnAndIndices) |
|       return nullptr; |
| |
|     auto derivativeFn = derivativeFnAndIndices->first; |
|     getGeneratedFunctionReferences().push_back(derivativeFn); |
| |
|     // If desired indices are a subset of actual indices, create a "subset |
|     // indices thunk" and destroy the emitted derivative function reference. |
|     // - For JVPs: the thunked JVP returns a differential taking fewer |
|     //   parameters (using `.zero` for the dropped parameters). |
|     // - For VJPs: the thunked VJP returns a pullback that drops the unused |
|     //   tangent values. |
|     auto actualIndices = derivativeFnAndIndices->second; |
|     // NOTE: `desiredIndices` may come from a partially-applied function and |
|     // have smaller capacity than `actualIndices`. We expect this logic to go |
|     // away when we support `@differentiable` partial apply. |
|     // if (actualIndices != desiredIndices) { // TODO: Re-enable. |
|     auto extendedDesiredIndices = desiredIndices.parameters->extendingCapacity( |
|         getASTContext(), actualIndices.parameters->getCapacity()); |
|     if (actualIndices.source != desiredIndices.source || |
|         !actualIndices.parameters->equals(extendedDesiredIndices)) { |
|       // Destroy the already emitted derivative function reference because it |
|       // is no longer used. |
|       builder.emitDestroyValueOperation(loc, derivativeFn); |
|       // Check if underlying original function reference has been partially |
|       // applied with arguments. If so, produce an error: parameter subset |
|       // thunks do not yet support this case because partially applied arguments |
|       // cannot be propagated to parameter subset thunks. |
|       auto didPartiallyApplyArguments = [](SILValue original) { |
|         while (auto *pai = |
|                    peerThroughFunctionConversions<PartialApplyInst>(original)) { |
|           if (pai->getNumArguments() > 0) |
|             return true; |
|           original = pai->getCallee(); |
|         } |
|         return false; |
|       }; |
|       if (didPartiallyApplyArguments(origFnOperand)) { |
|         emitNondifferentiabilityError( |
|             origFnOperand, invoker, |
|             diag::autodiff_cannot_param_subset_thunk_partially_applied_orig_fn); |
|         return nullptr; |
|       } |
|       // Create the parameter subset thunk. |
|       assert(actualIndices.parameters->isSupersetOf(extendedDesiredIndices)); |
|       SILFunction *thunk; |
|       SubstitutionMap interfaceSubs; |
|       std::tie(thunk, interfaceSubs) = |
|           getOrCreateSubsetParametersThunkForDerivativeFunction( |
|               origFnOperand, derivativeFn, derivativeFnKind, desiredIndices, |
|               actualIndices); |
|       auto *thunkFRI = builder.createFunctionRef(loc, thunk); |
|       // A generic thunk is specialized with `interfaceSubs` through a |
|       // `partial_apply` taking no arguments; otherwise the `function_ref` is |
|       // used directly. |
|       if (auto genSig = |
|               thunk->getLoweredFunctionType()->getGenericSignature()) { |
|         derivativeFn = builder.createPartialApply( |
|             loc, thunkFRI, interfaceSubs, {}, |
|             ParameterConvention::Direct_Guaranteed); |
|       } else { |
|         derivativeFn = thunkFRI; |
|       } |
|     } |
|     auto expectedDerivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType( |
|         parameterIndices, resultIndex, derivativeFnKind, getTypeConverter(), |
|         LookUpConformanceInModule(getModule().getSwiftModule())); |
|     // If `derivativeFn` is `@convention(thin)` but is expected to be |
|     // `@convention(thick)`, emit a `thin_to_thick` instruction. |
|     if (expectedDerivativeFnTy->getRepresentation() |
|             == SILFunctionTypeRepresentation::Thick && |
|         derivativeFn->getType().castTo<SILFunctionType>()->getRepresentation() |
|             == SILFunctionTypeRepresentation::Thin) { |
|       derivativeFn = builder.createThinToThickFunction( |
|           loc, derivativeFn, SILType::getPrimitiveObjectType(expectedDerivativeFnTy)); |
|     } |
| |
|     derivativeFns.push_back(derivativeFn); |
|   } |
|   // Deallocate temporary buffers used for creating derivative functions. |
|   for (auto *buf : reversed(newBuffersToDealloc)) |
|     builder.createDeallocStack(loc, buf); |
| |
|   // Bundle a copy of the original function with the JVP (`derivativeFns[0]`) |
|   // and the VJP (`derivativeFns[1]`), in the order pushed by the loop above. |
|   auto origFnCopy = builder.emitCopyValueOperation(loc, origFnOperand); |
|   auto *newDFI = createDifferentiableFunction( |
|       builder, loc, parameterIndices, origFnCopy, |
|       std::make_pair(derivativeFns[0], derivativeFns[1])); |
|   // NOTE(review): the two statements below operate on `dfi` (the instruction |
|   // being promoted) rather than on `newDFI`; the assignment writes back the |
|   // value read from `resultIndices[dfi]` at the top of this function. Confirm |
|   // whether `newDFI` was intended. |
|   resultIndices[dfi] = resultIndex; |
|   getDifferentiableFunctionInsts().push_back(dfi); |
| |
|   return newDFI; |
| } |
| |
| /// Fold `differentiable_function_extract` users of the given |
| /// `differentiable_function` instruction, directly replacing them with |
| /// `differentiable_function` instruction operands. If the |
| /// `differentiable_function` instruction has no remaining uses, delete the |
| /// instruction itself after folding. |
| /// |
| /// Folding can be disabled by the `SkipFoldingDifferentiableFunctionExtraction` |
| /// flag for SIL testing purposes. |
| // FIXME: This function is not correctly detecting the foldable pattern and |
| // needs to be rewritten. |
| void ADContext::foldDifferentiableFunctionExtraction( |
|     DifferentiableFunctionInst *source) { |
|   // Iterate through all `differentiable_function` instruction uses. The |
|   // iterator is advanced before the current use is handled: folding erases the |
|   // user, which removes that use from `source`'s use list, so stepping the |
|   // iterator afterwards would read through an erased operand. |
|   for (auto useIt = source->use_begin(), useEnd = source->use_end(); |
|        useIt != useEnd;) { |
|     auto *use = *useIt; |
|     ++useIt; |
|     auto *dfei = dyn_cast<DifferentiableFunctionExtractInst>(use->getUser()); |
|     // Users that are not `differentiable_function_extract` instructions are |
|     // left untouched. |
|     if (!dfei) |
|       continue; |
|     // Fold original function extractors. |
|     if (dfei->getExtractee() == |
|         NormalDifferentiableFunctionTypeComponent::Original) { |
|       auto originalFnValue = source->getOriginalFunction(); |
|       dfei->replaceAllUsesWith(originalFnValue); |
|       dfei->eraseFromParent(); |
|       continue; |
|     } |
|     // Fold derivative function extractors. |
|     auto derivativeFnValue = |
|         source->getDerivativeFunction(dfei->getDerivativeFunctionKind()); |
|     dfei->replaceAllUsesWith(derivativeFnValue); |
|     dfei->eraseFromParent(); |
|   } |
|   // If the `differentiable_function` instruction has no remaining uses, erase |
|   // it. |
|   if (isInstructionTriviallyDead(source)) { |
|     SILBuilder builder(source); |
|     // NOTE(review): these calls use the address-destroying helper on the |
|     // JVP/VJP operands, whereas other code in this file destroys function |
|     // values with `emitDestroyValueOperation`. Left unchanged here; confirm |
|     // the intended cleanup when this function is rewritten (see FIXME above). |
|     builder.emitDestroyAddrAndFold(source->getLoc(), source->getJVPFunction()); |
|     builder.emitDestroyAddrAndFold(source->getLoc(), source->getVJPFunction()); |
|     source->eraseFromParent(); |
|   } |
|   // Mark `source` as processed so that it won't be reprocessed after deletion. |
|   processedDifferentiableFunctionInsts.insert(source); |
| } |
| |
| /// Processes a single `differentiable_function` instruction: if it does not |
| /// have derivative functions yet, promotes it to a `@differentiable` function |
| /// value, replaces the instruction with that value, and optionally folds |
| /// `differentiable_function_extract` users of the result. |
| /// |
| /// Returns true if promotion failed, false otherwise. |
| bool ADContext::processDifferentiableFunctionInst( |
|     DifferentiableFunctionInst *dfi) { |
|   LLVM_DEBUG({ |
|     auto &debugStream = |
|         getADDebugStream() << "Processing DifferentiableFunctionInst:\n"; |
|     dfi->printInContext(debugStream); |
|   }); |
|   // Instructions that already carry derivative functions need no work. |
|   if (dfi->hasDerivativeFunctions()) |
|     return false; |
| |
|   // Capture everything needed from `dfi` up front; it is erased below. |
|   SILFunction *parentFn = dfi->getFunction(); |
|   auto dfiLoc = dfi->getLoc(); |
|   SILBuilder dfiBuilder(dfi); |
| |
|   auto promotedValue = |
|       promoteToDifferentiableFunction(dfi, dfiBuilder, dfiLoc, dfi); |
|   // Record `dfi` as processed, whether or not promotion succeeded, so that it |
|   // is never picked up again. |
|   processedDifferentiableFunctionInsts.insert(dfi); |
|   if (!promotedValue) |
|     return true; |
| |
|   // Redirect all users to the promoted value, destroy the original operand, |
|   // and remove the old instruction. |
|   dfi->replaceAllUsesWith(promotedValue); |
|   dfiBuilder.emitDestroyValueOperation(dfiLoc, dfi->getOriginalFunction()); |
|   dfi->eraseFromParent(); |
| |
|   // Fold `differentiable_function_extract` instructions when the promoted |
|   // value is itself a `differentiable_function` instruction and folding has |
|   // not been disabled. |
|   auto *promotedDFI = dyn_cast<DifferentiableFunctionInst>(promotedValue); |
|   if (promotedDFI && !SkipFoldingDifferentiableFunctionExtraction) |
|     foldDifferentiableFunctionExtraction(promotedDFI); |
| |
|   transform.invalidateAnalysis( |
|       parentFn, SILAnalysis::InvalidationKind::FunctionBody); |
|   return false; |
| } |
| |
| /// AD pass entry. |
| /// |
| /// Collects every `[differentiable]` attribute and `differentiable_function` |
| /// instruction in the module, processes them, and cleans up the |
| /// differentiation context if any step reported an error. |
| void Differentiation::run() { |
|   auto &silModule = *getModule(); |
|   auto &astContext = silModule.getASTContext(); |
|   debugDump(silModule); |
| |
|   // A global differentiation context. |
|   ADContext context(*this); |
| |
|   bool hadError = false; |
| |
|   // Walk the module once, registering the attributes and instructions that |
|   // trigger differentiation. |
|   for (SILFunction &fn : silModule) { |
|     for (auto *attr : fn.getDifferentiableAttrs()) { |
|       DifferentiationInvoker attrInvoker(attr); |
|       assert(!context.getInvokers().count(attr) && |
|              "[differentiable] attribute already has an invoker"); |
|       context.getInvokers().insert({attr, attrInvoker}); |
|     } |
|     for (SILBasicBlock &block : fn) { |
|       for (SILInstruction &inst : block) { |
|         if (auto *dfi = dyn_cast<DifferentiableFunctionInst>(&inst)) { |
|           context.getDifferentiableFunctionInsts().push_back(dfi); |
|           continue; |
|         } |
|         auto *lfi = dyn_cast<LinearFunctionInst>(&inst); |
|         if (!lfi) |
|           continue; |
|         // `linear_function` instructions are diagnosed as unsupported. |
|         astContext.Diags.diagnose( |
|             lfi->getLoc().getSourceLoc(), |
|             diag::autodiff_conversion_to_linear_function_not_supported); |
|         hadError = true; |
|       } |
|     } |
|   } |
| |
|   // Bail out early when nothing in the module triggers differentiation. |
|   if (context.getInvokers().empty() && |
|       context.getDifferentiableFunctionInsts().empty()) |
|     return; |
| |
|   // AD relies on stdlib (the Swift module). A missing stdlib is reported as |
|   // an internal error. |
|   if (!astContext.getStdlibModule()) { |
|     astContext.Diags.diagnose(SourceLoc(), |
|                               diag::autodiff_internal_swift_not_imported); |
|     return; |
|   } |
| |
|   // Process all `[differentiable]` attributes. Each entry is taken by value, |
|   // exactly as before. |
|   for (auto entry : context.getInvokers()) { |
|     auto *attr = entry.first; |
|     auto *originalFn = attr->getOriginal(); |
|     auto attrInvoker = entry.second; |
|     if (context.processDifferentiableAttribute(originalFn, attr, attrInvoker)) |
|       hadError = true; |
|   } |
| |
|   // Drain the `differentiable_function` instruction worklist. Processing an |
|   // instruction may push further instructions onto the worklist. |
|   while (!context.getDifferentiableFunctionInsts().empty()) { |
|     auto *nextDFI = context.getDifferentiableFunctionInsts().back(); |
|     context.getDifferentiableFunctionInsts().pop_back(); |
|     // Skip instructions that have already been processed. |
|     if (context.getProcessedDifferentiableFunctionInsts().count(nextDFI)) |
|       continue; |
|     if (context.processDifferentiableFunctionInst(nextDFI)) |
|       hadError = true; |
|   } |
| |
|   // If any error occurred while processing `[differentiable]` attributes or |
|   // `differentiable_function` instructions, clean up. |
|   if (hadError) { |
|     context.cleanUp(); |
|     return; |
|   } |
| |
|   LLVM_DEBUG(getADDebugStream() << "All differentiation finished\n"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Pass creation |
| //===----------------------------------------------------------------------===// |
| |
| /// Creates a new instance of the differentiation transform. The returned |
| /// object is heap-allocated with `new`. |
| SILTransform *swift::createDifferentiation() { |
|   SILTransform *differentiationPass = new Differentiation; |
|   return differentiationPass; |
| } |