| //===--- Common.h - Automatic differentiation common utils ----*- C++ -*---===// |
| // |
| // This source file is part of the Swift.org open source project |
| // |
| // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
| // Licensed under Apache License v2.0 with Runtime Library Exception |
| // |
| // See https://swift.org/LICENSE.txt for license information |
| // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Automatic differentiation common utilities. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_COMMON_H |
| #define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_COMMON_H |
| |
| #include "swift/AST/DiagnosticsSIL.h" |
| #include "swift/AST/Expr.h" |
| #include "swift/AST/SemanticAttrs.h" |
| #include "swift/SIL/SILDifferentiabilityWitness.h" |
| #include "swift/SIL/SILType.h" |
| #include "swift/SIL/SILFunction.h" |
| #include "swift/SIL/Projection.h" |
| #include "swift/SIL/SILModule.h" |
| #include "swift/SIL/TypeSubstCloner.h" |
| #include "swift/SILOptimizer/Analysis/ArraySemantic.h" |
| #include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" |
| #include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" |
| |
| namespace swift { |
| |
| namespace autodiff { |
| |
| class ADContext; |
| |
| //===----------------------------------------------------------------------===// |
| // Helpers |
| //===----------------------------------------------------------------------===// |
| |
| /// Prints an "[AD] " prefix to `llvm::dbgs()` and returns the debug stream. |
| /// This is being used to print short debug messages within the AD pass. |
| raw_ostream &getADDebugStream(); |
| |
| /// Given an element address from an `array.uninitialized_intrinsic` `apply` |
| /// instruction, returns the `apply` instruction. The element address is either |
| /// a `pointer_to_address` or `index_addr` instruction to the `RawPointer` |
| /// result of the instrinsic: |
| /// |
| /// %result = apply %array.uninitialized_intrinsic : $(Array<T>, RawPointer) |
| /// (%array, %ptr) = destructure_tuple %result |
| /// %elt0 = pointer_to_address %ptr to $*T // element address |
| /// %index_1 = integer_literal $Builtin.Word, 1 |
| /// %elt1 = index_addr %elt0, %index_1 // element address |
| /// ... |
| // TODO(SR-12894): Find a better name and move this general utility to |
| // ArraySemantic.h. |
| ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v); |
| |
| /// Given a value, finds its single `destructure_tuple` user if the value is |
| /// tuple-typed and such a user exists. |
| DestructureTupleInst *getSingleDestructureTupleUser(SILValue value); |
| |
| /// Returns true if the given original function is a "semantic member accessor". |
| /// |
| /// "Semantic member accessors" are attached to member properties that have a |
| /// corresponding tangent stored property in the parent `TangentVector` type. |
| /// These accessors have special-case pullback generation based on their |
| /// semantic behavior. |
| /// |
| /// "Semantic member accessors" currently include: |
| /// - Stored property accessors. These are implicitly generated. |
| /// - Property wrapper wrapped value accessors. These are implicitly generated |
| /// and internally call `var wrappedValue`. |
| bool isSemanticMemberAccessor(SILFunction *original); |
| |
| /// Returns true if the given apply site has a "semantic member accessor" |
| /// callee. |
| bool hasSemanticMemberAccessorCallee(ApplySite applySite); |
| |
| /// Given a full apply site, apply the given callback to each of its |
| /// "direct results". |
| /// |
| /// - `apply` |
| /// Special case because `apply` returns a single (possibly tuple-typed) result |
| /// instead of multiple results. If the `apply` has a single |
| /// `destructure_tuple` user, treat the `destructure_tuple` results as the |
| /// `apply` direct results. |
| /// |
| /// - `begin_apply` |
| /// Apply callback to each `begin_apply` direct result. |
| /// |
| /// - `try_apply` |
| /// Apply callback to each `try_apply` successor basic block argument. |
| void forEachApplyDirectResult( |
| FullApplySite applySite, llvm::function_ref<void(SILValue)> resultCallback); |
| |
| /// Given a function, gathers all of its formal results (both direct and |
| /// indirect) in an order defined by its result type. Note that "formal results" |
| /// refer to result values in the body of the function, not at call sites. |
| void collectAllFormalResultsInTypeOrder(SILFunction &function, |
| SmallVectorImpl<SILValue> &results); |
| |
| /// Given a function, gathers all of its direct results in an order defined by |
| /// its result type. Note that "formal results" refer to result values in the |
| /// body of the function, not at call sites. |
| void collectAllDirectResultsInTypeOrder(SILFunction &function, |
| SmallVectorImpl<SILValue> &results); |
| |
| /// Given a function call site, gathers all of its actual results (both direct |
| /// and indirect) in an order defined by its result type. |
| void collectAllActualResultsInTypeOrder( |
| ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults, |
| SmallVectorImpl<SILValue> &results); |
| |
| /// For an `apply` instruction with active results, compute: |
| /// - The results of the `apply` instruction, in type order. |
| /// - The set of minimal parameter and result indices for differentiating the |
| /// `apply` instruction. |
| void collectMinimalIndicesForFunctionCall( |
| ApplyInst *ai, AutoDiffConfig parentConfig, |
| const DifferentiableActivityInfo &activityInfo, |
| SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> ¶mIndices, |
| SmallVectorImpl<unsigned> &resultIndices); |
| |
| /// Returns the underlying instruction for the given SILValue, if it exists, |
| /// peering through function conversion instructions. |
| template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) { |
| if (auto *inst = dyn_cast<Inst>(value)) |
| return inst; |
| if (auto *cvi = dyn_cast<CopyValueInst>(value)) |
| return peerThroughFunctionConversions<Inst>(cvi->getOperand()); |
| if (auto *bbi = dyn_cast<BeginBorrowInst>(value)) |
| return peerThroughFunctionConversions<Inst>(bbi->getOperand()); |
| if (auto *tttfi = dyn_cast<ThinToThickFunctionInst>(value)) |
| return peerThroughFunctionConversions<Inst>(tttfi->getOperand()); |
| if (auto *cfi = dyn_cast<ConvertFunctionInst>(value)) |
| return peerThroughFunctionConversions<Inst>(cfi->getOperand()); |
| if (auto *pai = dyn_cast<PartialApplyInst>(value)) |
| return peerThroughFunctionConversions<Inst>(pai->getCallee()); |
| return nullptr; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Diagnostic utilities |
| //===----------------------------------------------------------------------===// |
| |
| // Returns `v`'s location if it is valid. Otherwise, returns `v`'s function's |
| // location as as a fallback. Used for diagnostics. |
| SILLocation getValidLocation(SILValue v); |
| |
| // Returns `inst`'s location if it is valid. Otherwise, returns `inst`'s |
| // function's location as as a fallback. Used for diagnostics. |
| SILLocation getValidLocation(SILInstruction *inst); |
| |
| //===----------------------------------------------------------------------===// |
| // Tangent property lookup utilities |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns the tangent stored property of the given original stored property |
| /// and base type. On error, emits diagnostic and returns nullptr. |
| VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField, |
| CanType baseType, SILLocation loc, |
| DifferentiationInvoker invoker); |
| |
| /// Returns the tangent stored property of the original stored property |
| /// referenced by the given projection instruction with the given base type. |
| /// On error, emits diagnostic and returns nullptr. |
| /// |
| /// NOTE: Asserts if \p projectionInst is not one of: struct_extract, |
| /// struct_element_addr, or ref_element_addr. |
| VarDecl *getTangentStoredProperty(ADContext &context, |
| SingleValueInstruction *projectionInst, |
| CanType baseType, |
| DifferentiationInvoker invoker); |
| |
| //===----------------------------------------------------------------------===// |
| // Code emission utilities |
| //===----------------------------------------------------------------------===// |
| |
| /// Given a range of elements, joins these into a single value. If there's |
| /// exactly one element, returns that element. Otherwise, creates a tuple using |
| /// a `tuple` instruction. |
| SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder, |
| SILLocation loc); |
| |
| /// Given a value, extracts all elements to `results` from this value if it has |
| /// a tuple type. Otherwise, add this value directly to `results`. |
| void extractAllElements(SILValue value, SILBuilder &builder, |
| SmallVectorImpl<SILValue> &results); |
| |
| /// Emit a zero value into the given buffer access by calling |
| /// `AdditiveArithmetic.zero`. The given type must conform to |
| /// `AdditiveArithmetic`. |
| void emitZeroIntoBuffer(SILBuilder &builder, CanType type, |
| SILValue bufferAccess, SILLocation loc); |
| |
| /// Emit a `Builtin.Word` value that represents the given type's memory layout |
| /// size. |
| SILValue emitMemoryLayoutSize( |
| SILBuilder &builder, SILLocation loc, CanType type); |
| |
| /// Emit a projection of the top-level subcontext from the context object. |
| SILValue emitProjectTopLevelSubcontext( |
| SILBuilder &builder, SILLocation loc, SILValue context, |
| SILType subcontextType); |
| |
| //===----------------------------------------------------------------------===// |
| // Utilities for looking up derivatives of functions |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns a differentiability witness (definition or declaration) exactly |
| /// matching the specified indices. If none are found in the given `module`, |
| /// returns `nullptr`. |
| /// |
| /// \param parameterIndices must be lowered to SIL. |
| /// \param resultIndices must be lowered to SIL. |
| SILDifferentiabilityWitness * |
| getExactDifferentiabilityWitness(SILModule &module, SILFunction *original, |
| IndexSubset *parameterIndices, |
| IndexSubset *resultIndices); |
| |
| /// Finds the derivative configuration (from `@differentiable` and |
| /// `@derivative` attributes) for `original` whose parameter indices are a |
| /// minimal superset of the specified AST parameter indices. Returns `None` if |
| /// no such configuration is found. |
| /// |
| /// \param parameterIndices must be lowered to SIL. |
| /// \param minimalASTParameterIndices is an output parameter that is set to the |
| /// AST indices of the minimal configuration, or to `nullptr` if no such |
| /// configuration exists. |
| Optional<AutoDiffConfig> |
| findMinimalDerivativeConfiguration(AbstractFunctionDecl *original, |
| IndexSubset *parameterIndices, |
| IndexSubset *&minimalASTParameterIndices); |
| |
| /// Returns a differentiability witness for `original` whose parameter indices |
| /// are a minimal superset of the specified parameter indices and whose result |
| /// indices match the given result indices, out of all |
| /// differentiability witnesses that come from AST "@differentiable" or |
| /// "@differentiating" attributes. |
| /// |
| /// This function never creates new differentiability witness definitions. |
| /// However, this function may create new differentiability witness declarations |
| /// referring to definitions in other modules when these witnesses have not yet |
| /// been declared in the current module. |
| /// |
| /// \param module is the SILModule in which to get or create the witnesses. |
| /// \param parameterIndices must be lowered to SIL. |
| /// \param resultIndices must be lowered to SIL. |
| SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( |
| SILModule &module, SILFunction *original, IndexSubset *parameterIndices, |
| IndexSubset *resultIndices); |
| |
| } // end namespace autodiff |
| |
| /// Creates arguments in the entry block based on the function type. |
| inline void createEntryArguments(SILFunction *f) { |
| auto *entry = f->getEntryBlock(); |
| auto conv = f->getConventions(); |
| auto &ctx = f->getASTContext(); |
| auto moduleDecl = f->getModule().getSwiftModule(); |
| assert((entry->getNumArguments() == 0 || conv.getNumSILArguments() == 0) && |
| "Entry already has arguments?!"); |
| auto createFunctionArgument = [&](SILType type) { |
| // Create a dummy parameter declaration. |
| // Necessary to prevent crash during argument explosion optimization. |
| auto loc = f->getLocation().getSourceLoc(); |
| auto *decl = new (ctx) |
| ParamDecl(loc, loc, Identifier(), loc, Identifier(), moduleDecl); |
| decl->setSpecifier(ParamDecl::Specifier::Default); |
| entry->createFunctionArgument(type, decl); |
| }; |
| for (auto indResTy : |
| conv.getIndirectSILResultTypes(f->getTypeExpansionContext())) { |
| if (indResTy.hasArchetype()) |
| indResTy = indResTy.mapTypeOutOfContext(); |
| createFunctionArgument(f->mapTypeIntoContext(indResTy).getAddressType()); |
| } |
| for (auto paramTy : conv.getParameterSILTypes(f->getTypeExpansionContext())) { |
| if (paramTy.hasArchetype()) |
| paramTy = paramTy.mapTypeOutOfContext(); |
| createFunctionArgument(f->mapTypeIntoContext(paramTy)); |
| } |
| } |
| |
| /// Cloner that remaps types using the target function's generic environment. |
| class BasicTypeSubstCloner final |
| : public TypeSubstCloner<BasicTypeSubstCloner, SILOptFunctionBuilder> { |
| |
| static SubstitutionMap getSubstitutionMap(SILFunction *target) { |
| if (auto *targetGenEnv = target->getGenericEnvironment()) |
| return targetGenEnv->getForwardingSubstitutionMap(); |
| return SubstitutionMap(); |
| } |
| |
| public: |
| explicit BasicTypeSubstCloner(SILFunction *original, SILFunction *target) |
| : TypeSubstCloner(*target, *original, getSubstitutionMap(target)) {} |
| |
| void postProcess(SILInstruction *orig, SILInstruction *cloned) { |
| SILClonerWithScopes::postProcess(orig, cloned); |
| } |
| |
| void cloneFunction() { |
| auto &newFunction = Builder.getFunction(); |
| auto *entry = newFunction.createBasicBlock(); |
| createEntryArguments(&newFunction); |
| SmallVector<SILValue, 8> entryArguments(newFunction.getArguments().begin(), |
| newFunction.getArguments().end()); |
| cloneFunctionBody(&Original, entry, entryArguments); |
| } |
| }; |
| |
| } // end namespace swift |
| |
| #endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_COMMON_H |