| //===--- LinearMapInfo.h --------------------------------------*- C++ -*---===// |
| // |
| // This source file is part of the Swift.org open source project |
| // |
| // Copyright (c) 2019 Apple Inc. and the Swift project authors |
| // Licensed under Apache License v2.0 with Runtime Library Exception |
| // |
| // See https://swift.org/LICENSE.txt for license information |
| // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Linear map struct and branching trace enum information for differentiation. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_LINEARMAPINFO_H |
| #define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_LINEARMAPINFO_H |
| |
| #include "swift/AST/AutoDiff.h" |
| #include "swift/AST/SynthesizedFileUnit.h" |
| #include "swift/SIL/ApplySite.h" |
| #include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" |
| #include "llvm/ADT/DenseMap.h" |
| |
| namespace swift { |
| |
| class SILFunction; |
| class SILLoopInfo; |
| |
| namespace autodiff { |
| |
| class ADContext; |
| |
| /// Linear map struct and branching trace enum information for an original |
| /// function and and derivative function (JVP or VJP). |
| /// |
| /// Linear map structs contain all callee linear maps produced in a JVP/VJP |
| /// basic block. A linear map struct is created for each basic block in the |
| /// original function, and a linear map struct field is created for every active |
| /// `apply` in the original basic block. |
| /// |
| /// Branching trace enums model the control flow graph of the original function. |
| /// A branching trace enum is created for each basic block in the original |
| /// function, and a branching trace enum case is created for every basic block |
| /// predecessor/successor. This supports control flow differentiation: JVP/VJP |
| /// functions build branching trace enums to record an execution trace. Indirect |
| /// branching trace enums are created for basic blocks that are in loops. |
| /// |
| /// Linear map struct values and branching trace enum values are constructed in |
| /// JVP/VJP functions and consumed in pullback/differential functions. |
| class LinearMapInfo { |
| private: |
| /// The linear map kind. |
| AutoDiffLinearMapKind kind; |
| |
| /// The original function. |
| SILFunction *const original; |
| |
| /// The derivative function. |
| SILFunction *const derivative; |
| |
| /// Activity info of the original function. |
| const DifferentiableActivityInfo &activityInfo; |
| |
| /// The original function's loop info. |
| SILLoopInfo *loopInfo; |
| |
| /// Differentiation indices of the function. |
| const AutoDiffConfig config; |
| |
| /// Mapping from original basic blocks to linear map structs. |
| llvm::DenseMap<SILBasicBlock *, StructDecl *> linearMapStructs; |
| |
| /// Mapping from original basic blocks to branching trace enums. |
| /// For pullbacks: these are predecessor enums. |
| /// For differentials: these are successor enums. |
| llvm::DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls; |
| |
| /// Mapping from `apply` instructions in the original function to the |
| /// corresponding linear map field declaration in the linear map struct. |
| llvm::DenseMap<ApplyInst *, VarDecl *> linearMapFieldMap; |
| |
| /// Mapping from predecessor-succcessor basic block pairs in the original |
| /// function to the corresponding branching trace enum case. |
| llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *> |
| branchingTraceEnumCases; |
| |
| /// Mapping from linear map structs to their branching trace enum fields. |
| llvm::DenseMap<StructDecl *, VarDecl *> linearMapStructEnumFields; |
| |
| /// Blocks in a loop. |
| llvm::SmallSetVector<SILBasicBlock *, 4> blocksInLoop; |
| |
| /// A synthesized file unit. |
| SynthesizedFileUnit &synthesizedFile; |
| |
| /// A type converter, used to compute struct/enum SIL types. |
| Lowering::TypeConverter &typeConverter; |
| |
| private: |
| /// Remaps the given type into the derivative function's context. |
| SILType remapTypeInDerivative(SILType ty); |
| |
| /// Adds a `VarDecl` member with the given name and type to the given nominal |
| /// declaration. |
| VarDecl *addVarDecl(NominalTypeDecl *nominal, StringRef name, Type type); |
| |
| /// Retrieves the file unit that contains implicit declarations in the |
| /// current Swift module. |
| SynthesizedFileUnit &getSynthesizedFile() { return synthesizedFile; } |
| |
| /// Computes and sets the access level for the given nominal type, given the |
| /// original function linkage. |
| void computeAccessLevel(NominalTypeDecl *nominal, SILLinkage originalLinkage); |
| |
| /// Creates an enum declaration with the given JVP/VJP generic signature, |
| /// whose cases represent the predecessors/successors of the given original |
| /// block. |
| EnumDecl *createBranchingTraceDecl(SILBasicBlock *originalBB, |
| CanGenericSignature genericSig, |
| SILLoopInfo *loopInfo); |
| |
| /// Creates a struct declaration with the given JVP/VJP generic signature, for |
| /// storing the linear map values and predecessor/successor basic block of the |
| /// given original block. |
| StructDecl *createLinearMapStruct(SILBasicBlock *originalBB, |
| CanGenericSignature genericSig); |
| |
| /// Adds a linear map field to the linear map struct. |
| VarDecl *addLinearMapDecl(ApplyInst *ai, SILType linearMapType); |
| |
| /// Given an `apply` instruction, conditionally adds a linear map struct field |
| /// for its linear map function if it is active. |
| void addLinearMapToStruct(ADContext &context, ApplyInst *ai); |
| |
| /// Generates linear map struct and branching enum declarations for the given |
| /// function. Linear map structs are populated with linear map fields and a |
| /// branching enum field. |
| void generateDifferentiationDataStructures(ADContext &context, |
| SILFunction *derivative); |
| |
| public: |
| bool shouldDifferentiateApplySite(FullApplySite applySite); |
| bool shouldDifferentiateInstruction(SILInstruction *inst); |
| |
| LinearMapInfo(const LinearMapInfo &) = delete; |
| LinearMapInfo &operator=(const LinearMapInfo &) = delete; |
| |
| explicit LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind, |
| SILFunction *original, SILFunction *derivative, |
| AutoDiffConfig config, |
| const DifferentiableActivityInfo &activityInfo, |
| SILLoopInfo *loopInfo); |
| |
| /// Returns the linear map struct associated with the given original block. |
| StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const { |
| return linearMapStructs.lookup(origBB); |
| } |
| |
| /// Returns the lowered SIL type of the linear map struct associated with the |
| /// given original block. |
| SILType getLinearMapStructLoweredType(SILBasicBlock *origBB) const { |
| auto derivativeGenSig = |
| derivative->getLoweredFunctionType()->getSubstGenericSignature(); |
| auto *linMapStruct = getLinearMapStruct(origBB); |
| auto linMapStructType = |
| linMapStruct->getDeclaredInterfaceType()->getCanonicalType( |
| derivativeGenSig); |
| Lowering::AbstractionPattern pattern(derivativeGenSig, linMapStructType); |
| return typeConverter.getLoweredType(pattern, linMapStructType, |
| TypeExpansionContext::minimal()); |
| } |
| |
| /// Returns the branching trace enum associated with the given original block. |
| EnumDecl *getBranchingTraceDecl(SILBasicBlock *origBB) const { |
| return branchingTraceDecls.lookup(origBB); |
| } |
| |
| /// Returns the lowered SIL type of the branching trace enum associated with |
| /// the given original block. |
| SILType getBranchingTraceEnumLoweredType(SILBasicBlock *origBB) const { |
| auto *traceDecl = getBranchingTraceDecl(origBB); |
| auto traceDeclType = |
| traceDecl->getDeclaredInterfaceType()->getCanonicalType(); |
| Lowering::AbstractionPattern pattern( |
| derivative->getLoweredFunctionType()->getSubstGenericSignature(), |
| traceDeclType); |
| return typeConverter.getLoweredType(pattern, traceDeclType, |
| TypeExpansionContext::minimal()); |
| } |
| |
| /// Returns the enum element in the given successor block's branching trace |
| /// enum corresponding to the given predecessor block. |
| EnumElementDecl * |
| lookUpBranchingTraceEnumElement(SILBasicBlock *origPredBB, |
| SILBasicBlock *origSuccBB) const { |
| assert(origPredBB->getParent() == original); |
| return branchingTraceEnumCases.lookup({origPredBB, origSuccBB}); |
| } |
| |
| /// Returns the mapping from linear map structs to their branching trace enum |
| /// fields. |
| llvm::DenseMap<StructDecl *, VarDecl *> &getLinearMapStructEnumFields() { |
| return linearMapStructEnumFields; |
| } |
| |
| /// Returns the branching trace enum field for the linear map struct of the |
| /// given original block. |
| VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) const { |
| auto *linearMapStruct = getLinearMapStruct(origBB); |
| return linearMapStructEnumFields.lookup(linearMapStruct); |
| } |
| |
| /// Finds the linear map declaration in the pullback struct for the given |
| /// `apply` instruction in the original function. |
| VarDecl *lookUpLinearMapDecl(ApplyInst *ai) const { |
| assert(ai->getFunction() == original); |
| auto lookup = linearMapFieldMap.find(ai); |
| assert(lookup != linearMapFieldMap.end() && |
| "No linear map field corresponding to the given `apply`"); |
| return lookup->getSecond(); |
| } |
| |
| bool hasLoops() const { |
| return !blocksInLoop.empty(); |
| } |
| |
| ArrayRef<SILBasicBlock *> getBlocksInLoop() const { |
| return blocksInLoop.getArrayRef(); |
| } |
| }; |
| |
| } // end namespace autodiff |
| } // end namespace swift |
| |
| #endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_LINEARMAPINFO_H |