| //===--- PullbackCloner.cpp - Pullback function generation ---*- C++ -*----===// |
| // |
| // This source file is part of the Swift.org open source project |
| // |
| // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
| // Licensed under Apache License v2.0 with Runtime Library Exception |
| // |
| // See https://swift.org/LICENSE.txt for license information |
| // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines a helper class for generating pullback functions for |
| // automatic differentiation. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #define DEBUG_TYPE "differentiation" |
| |
| #include "swift/SILOptimizer/Differentiation/PullbackCloner.h" |
| #include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h" |
| #include "swift/SILOptimizer/Differentiation/ADContext.h" |
| #include "swift/SILOptimizer/Differentiation/AdjointValue.h" |
| #include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h" |
| #include "swift/SILOptimizer/Differentiation/LinearMapInfo.h" |
| #include "swift/SILOptimizer/Differentiation/Thunk.h" |
| #include "swift/SILOptimizer/Differentiation/VJPCloner.h" |
| |
| #include "swift/AST/Expr.h" |
| #include "swift/AST/PropertyWrappers.h" |
| #include "swift/AST/TypeCheckRequests.h" |
| #include "swift/SIL/InstructionUtils.h" |
| #include "swift/SIL/Projection.h" |
| #include "swift/SIL/TypeSubstCloner.h" |
| #include "swift/SILOptimizer/PassManager/PrettyStackTrace.h" |
| #include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h" |
| #include "llvm/ADT/DenseMap.h" |
| |
| namespace swift { |
| |
| class SILDifferentiabilityWitness; |
| class SILBasicBlock; |
| class SILFunction; |
| class SILInstruction; |
| |
| namespace autodiff { |
| |
| class ADContext; |
| class VJPCloner; |
| |
| /// The implementation class for `PullbackCloner`. |
| /// |
| /// The implementation class is a `SILInstructionVisitor`. Effectively, it acts |
| /// as a `SILCloner` that visits basic blocks in post-order and that visits |
| /// instructions per basic block in reverse order. This visitation order is |
| /// necessary for generating pullback functions, whose control flow graph is |
| /// approximately a transposed version of the original function's control flow |
| /// graph. |
| class PullbackCloner::Implementation final |
| : public SILInstructionVisitor<PullbackCloner::Implementation> { |
| |
| public: |
| explicit Implementation(VJPCloner &vjpCloner); |
| |
| private: |
| /// The parent VJP cloner. |
| VJPCloner &vjpCloner; |
| |
| /// Dominance info for the original function. |
| DominanceInfo *domInfo = nullptr; |
| |
| /// Post-dominance info for the original function. |
| PostDominanceInfo *postDomInfo = nullptr; |
| |
| /// Post-order info for the original function. |
| PostOrderFunctionInfo *postOrderInfo = nullptr; |
| |
| /// Mapping from original basic blocks to corresponding pullback basic blocks. |
| /// Pullback basic blocks always have the predecessor as the single argument. |
| llvm::DenseMap<SILBasicBlock *, SILBasicBlock *> pullbackBBMap; |
| |
| /// Mapping from original basic blocks and original values to corresponding |
| /// adjoint values. |
| llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, AdjointValue> valueMap; |
| |
| /// Mapping from original basic blocks and original values to corresponding |
| /// adjoint buffers. |
| llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap; |
| |
| /// Mapping from pullback struct field declarations to pullback struct |
| /// elements destructured from the linear map basic block argument. In the |
| /// beginning of each pullback basic block, the block's pullback struct is |
| /// destructured into individual elements stored here. |
| llvm::DenseMap<VarDecl *, SILValue> pullbackStructElements; |
| |
| /// Mapping from original basic blocks and successor basic blocks to |
| /// corresponding pullback trampoline basic blocks. Trampoline basic blocks |
| /// take additional arguments in addition to the predecessor enum argument. |
| llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, SILBasicBlock *> |
| pullbackTrampolineBBMap; |
| |
| /// Mapping from original basic blocks to dominated active values. |
| llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> activeValues; |
| |
| /// Mapping from original basic blocks and original active values to |
| /// corresponding pullback block arguments. |
| llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILArgument *> |
| activeValuePullbackBBArgumentMap; |
| |
| /// Mapping from original basic blocks to local temporary values to be cleaned |
| /// up. This is populated when pullback emission is run on one basic block and |
| /// cleaned before processing another basic block. |
| llvm::DenseMap<SILBasicBlock *, SmallSetVector<SILValue, 64>> |
| blockTemporaries; |
| |
| /// The main builder. |
| SILBuilder builder; |
| |
| /// An auxiliary local allocation builder. |
| SILBuilder localAllocBuilder; |
| |
| /// Stack buffers allocated for storing local adjoint values. |
| SmallVector<AllocStackInst *, 64> functionLocalAllocations; |
| |
| /// A set used to remember local allocations that were destroyed. |
| llvm::SmallDenseSet<SILValue> destroyedLocalAllocations; |
| |
| /// The seed arguments of the pullback function. |
| SmallVector<SILArgument *, 4> seeds; |
| |
| /// The `AutoDiffLinearMapContext` object, if any. |
| SILValue contextValue = nullptr; |
| |
| /// Allocator passed to the `AdjointValue` factory functions; it also backs the |
| /// element storage of aggregate adjoint values created by |
| /// `makeAggregateAdjointValue`. |
| llvm::BumpPtrAllocator allocator; |
| |
| /// NOTE(review): presumably records that pullback generation hit an error; the |
| /// code that reads and writes this flag is not in this section -- confirm. |
| bool errorOccurred = false; |
| |
| /// Returns the `ADContext` of the parent VJP cloner. |
| ADContext &getContext() const { return vjpCloner.getContext(); } |
| /// Returns the SIL module, obtained through the `ADContext`. |
| SILModule &getModule() const { return getContext().getModule(); } |
| /// Returns the AST context, obtained through the pullback function. |
| ASTContext &getASTContext() const { return getPullback().getASTContext(); } |
| /// Returns the original function, as reported by the parent VJP cloner. |
| SILFunction &getOriginal() const { return vjpCloner.getOriginal(); } |
| /// Returns the differentiability witness held by the parent VJP cloner. |
| SILDifferentiabilityWitness *getWitness() const { |
| return vjpCloner.getWitness(); |
| } |
| /// Returns the differentiation invoker of the parent VJP cloner. |
| DifferentiationInvoker getInvoker() const { return vjpCloner.getInvoker(); } |
| /// Returns the linear map info that the parent VJP cloner keeps for the |
| /// pullback. |
| LinearMapInfo &getPullbackInfo() const { return vjpCloner.getPullbackInfo(); } |
| /// Returns the autodiff configuration of the parent VJP cloner. |
| AutoDiffConfig getConfig() const { return vjpCloner.getConfig(); } |
| /// Returns the activity analysis result held by the parent VJP cloner. |
| const DifferentiableActivityInfo &getActivityInfo() const { |
| return vjpCloner.getActivityInfo(); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Pullback struct mapping |
| //--------------------------------------------------------------------------// |
| |
| /// Caches the destructured pullback struct element values of `origBB`, keyed |
| /// by the stored property of the block's pullback struct that each value |
| /// corresponds to. `values` must hold exactly one value per stored property, |
| /// in the order of the stored properties. |
| void initializePullbackStructElements(SILBasicBlock *origBB, |
| SILInstructionResultArray values) { |
| auto storedProperties = |
| getPullbackInfo().getLinearMapStruct(origBB)->getStoredProperties(); |
| assert(storedProperties.size() == values.size() && |
| "The number of pullback struct fields must equal the number of " |
| "pullback struct element values"); |
| for (auto fieldAndElement : llvm::zip(storedProperties, values)) { |
| VarDecl *field = std::get<0>(fieldAndElement); |
| SILValue element = std::get<1>(fieldAndElement); |
| assert(element.getOwnershipKind() != OwnershipKind::Guaranteed && |
| "Pullback struct elements must be @owned"); |
| // Each field may be registered only once. |
| bool isNewEntry = pullbackStructElements.insert({field, element}).second; |
| (void)isNewEntry; |
| assert(isNewEntry && "A pullback struct element already exists!"); |
| } |
| } |
| |
| /// Looks up the cached pullback struct element value for `field`, which must |
| /// be a stored property of the pullback struct of `origBB`. |
| SILValue getPullbackStructElement(SILBasicBlock *origBB, VarDecl *field) { |
| assert(getPullbackInfo().getLinearMapStruct(origBB) == |
| cast<StructDecl>(field->getDeclContext())); |
| auto entry = pullbackStructElements.find(field); |
| assert(entry != pullbackStructElements.end() && |
| "Pullback struct element for this field does not exist!"); |
| // Like `DenseMap::lookup`, fall back to a null value for a missing field |
| // when assertions are disabled. |
| if (entry == pullbackStructElements.end()) |
| return SILValue(); |
| return entry->getSecond(); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Type transformer |
| //--------------------------------------------------------------------------// |
| |
| /// Returns the type lowering of the given AST type in the pullback function, |
| /// using an abstraction pattern built from the pullback's substituted generic |
| /// signature. |
| const Lowering::TypeLowering &getTypeLowering(Type type) { |
| auto &pullback = getPullback(); |
| auto pbGenSig = pullback.getLoweredFunctionType()->getSubstGenericSignature(); |
| auto canonicalType = type->getCanonicalType(pbGenSig); |
| return pullback.getTypeLowering( |
| Lowering::AbstractionPattern(pbGenSig, canonicalType), type); |
| } |
| |
| /// Remaps the given type into the pullback function's context: archetypes are |
| /// first mapped out of context, the AST type is canonicalized with respect to |
| /// the pullback's substituted generic signature, and the result is mapped |
| /// into the pullback's context. The value category of `ty` is preserved. |
| SILType remapType(SILType ty) { |
| auto &pullback = getPullback(); |
| SILType interfaceTy = ty.hasArchetype() ? ty.mapTypeOutOfContext() : ty; |
| auto pbGenSig = pullback.getLoweredFunctionType()->getSubstGenericSignature(); |
| auto canonicalASTType = interfaceTy.getASTType()->getCanonicalType(pbGenSig); |
| return pullback.mapTypeIntoContext( |
| SILType::getPrimitiveType(canonicalASTType, interfaceTy.getCategory())); |
| } |
| |
| /// Returns the tangent space of the given type, looked up in the Swift module |
| /// of the SIL module. If the witness has a derivative generic signature, the |
| /// type is first canonicalized in the context of that signature. |
| Optional<TangentSpace> getTangentSpace(CanType type) { |
| // Use witness generic signature to remap types. |
| auto witnessGenSig = getWitness()->getDerivativeGenericSignature(); |
| CanType lookupType = |
| witnessGenSig ? witnessGenSig->getCanonicalTypeInContext(type) : type; |
| return lookupType->getAutoDiffTangentSpace( |
| LookUpConformanceInModule(getModule().getSwiftModule())); |
| } |
| |
| /// Returns the tangent value category (object or address) of the given |
| /// original value. |
| SILValueCategory getTangentValueCategory(SILValue v) { |
| // Tangent value category table: |
| // |
| // Let $L be a loadable type and $*A be an address-only type. |
| // |
| // Original type | Tangent type loadable? | Tangent value category and type |
| // --------------|------------------------|-------------------------------- |
| // $L | loadable | object, $L' (no mismatch) |
| // $*A | loadable | address, $*L' (create a buffer) |
| // $L | address-only | address, $*A' (no alternative) |
| // $*A | address-only | address, $*A' (no alternative) |
| |
| // TODO(SR-13077): Make "tangent value category" depend solely on whether |
| // the tangent type is loadable or address-only. |
| // |
| // For loadable tangent types, using symbolic adjoint values instead of |
| // concrete adjoint buffers is more efficient. |
| |
| const SILType origType = v->getType(); |
| // Fast path: an address-typed original value currently always has the |
| // "address" tangent value category. |
| if (origType.isAddress()) |
| return SILValueCategory::Address; |
| // The tangent value category is "object" only for an object-typed original |
| // value whose tangent type is loadable; everything else is "address". |
| auto tanSpace = getTangentSpace(remapType(origType).getASTType()); |
| auto tanASTType = tanSpace->getCanonicalType(); |
| const bool hasLoadableObjectTangent = |
| origType.isObject() && getTypeLowering(tanASTType).isLoadable(); |
| return hasLoadableObjectTangent ? SILValueCategory::Object |
| : SILValueCategory::Address; |
| } |
| |
| /// Returns the tangent space type of the remapped `type`, with the same value |
| /// category as `type`. The remapped type is assumed to conform to |
| /// `Differentiable`, i.e. to have a tangent space. |
| SILType getRemappedTangentType(SILType type) { |
| auto remappedASTType = remapType(type).getASTType(); |
| auto tangentASTType = getTangentSpace(remappedASTType)->getCanonicalType(); |
| return SILType::getPrimitiveType(tangentASTType, type.getCategory()); |
| } |
| |
| /// Applies the pullback function's forwarding substitution map to the given |
| /// substitution map and returns the result. |
| SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap) { |
| auto pullbackSubstMap = getPullback().getForwardingSubstitutionMap(); |
| return substMap.subst(pullbackSubstMap); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Temporary value management |
| //--------------------------------------------------------------------------// |
| |
| /// Registers `value` as a temporary of its parent pullback block so that |
| /// `cleanUpTemporariesForBlock` destroys it later. The value must be an |
| /// object in the pullback function and must not be registered already. |
| /// Returns `value` to allow chaining. |
| SILValue recordTemporary(SILValue value) { |
| assert(value->getType().isObject()); |
| assert(value->getFunction() == &getPullback()); |
| auto &temporaries = blockTemporaries[value->getParentBlock()]; |
| bool isNewTemporary = temporaries.insert(value); |
| (void)isNewTemporary; |
| LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value); |
| assert(isNewTemporary && "Temporary already recorded?"); |
| return value; |
| } |
| |
| /// Emits a destroy for every temporary recorded for the pullback block `bb`, |
| /// then forgets those temporaries. |
| void cleanUpTemporariesForBlock(SILBasicBlock *bb, SILLocation loc) { |
| assert(bb->getParent() == &getPullback()); |
| LLVM_DEBUG(getADDebugStream() << "Cleaning up temporaries for pullback bb" |
| << bb->getDebugID() << '\n'); |
| auto &temporaries = blockTemporaries[bb]; |
| for (SILValue temporary : temporaries) |
| builder.emitDestroyValueOperation(loc, temporary); |
| temporaries.clear(); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Adjoint value factory methods |
| //--------------------------------------------------------------------------// |
| |
| /// Creates a symbolic zero adjoint value of the remapped `type`, allocated |
| /// with `allocator`. |
| AdjointValue makeZeroAdjointValue(SILType type) { |
| return AdjointValue::createZero(allocator, remapType(type)); |
| } |
| |
| /// Creates a concrete adjoint value wrapping `value`, allocated with |
| /// `allocator`. |
| AdjointValue makeConcreteAdjointValue(SILValue value) { |
| return AdjointValue::createConcrete(allocator, value); |
| } |
| |
| /// Creates a symbolic aggregate adjoint value of the remapped `type` whose |
| /// elements are copies of `elements`. The element copies live in storage |
| /// obtained from `allocator`. |
| template <typename EltRange> |
| AdjointValue makeAggregateAdjointValue(SILType type, EltRange elements) { |
| auto numElements = elements.size(); |
| // Raw, suitably aligned storage for the element copies. |
| void *storage = allocator.Allocate(numElements * sizeof(AdjointValue), |
| alignof(AdjointValue)); |
| auto *firstElement = reinterpret_cast<AdjointValue *>(storage); |
| std::uninitialized_copy(elements.begin(), elements.end(), firstElement); |
| MutableArrayRef<AdjointValue> elementsCopy(firstElement, numElements); |
| return AdjointValue::createAggregate(allocator, remapType(type), |
| elementsCopy); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Adjoint value materialization |
| //--------------------------------------------------------------------------// |
| |
| /// Materializes an adjoint value as a SIL object value. The type of the given |
| /// adjoint value must be an object type. |
| /// |
| /// A concrete adjoint value is returned as is. A zero or aggregate adjoint |
| /// value is emitted with `builder` and recorded as a block temporary. |
| SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc) { |
| assert(val.getType().isObject()); |
| LLVM_DEBUG(getADDebugStream() |
| << "Materializing adjoint for " << val << '\n'); |
| switch (val.getKind()) { |
| case AdjointValueKind::Concrete: |
| return val.getConcreteValue(); |
| case AdjointValueKind::Zero: { |
| auto zero = emitZeroDirect(val.getType().getASTType(), loc); |
| return recordTemporary(zero); |
| } |
| case AdjointValueKind::Aggregate: { |
| // Materialize every element and copy it, so that the aggregate built |
| // below gets its own copies. |
| SmallVector<SILValue, 8> elementCopies; |
| for (unsigned idx = 0, numElts = val.getNumAggregateElements(); |
| idx != numElts; ++idx) { |
| auto element = materializeAdjointDirect(val.getAggregateElement(idx), loc); |
| elementCopies.push_back(builder.emitCopyValueOperation(loc, element)); |
| } |
| SILValue aggregate; |
| if (val.getType().is<TupleType>()) |
| aggregate = builder.createTuple(loc, val.getType(), elementCopies); |
| else |
| aggregate = builder.createStruct(loc, val.getType(), elementCopies); |
| return recordTemporary(aggregate); |
| } |
| } |
| llvm_unreachable("unhandled adjoint value kind!"); |
| } |
| |
| /// Materializes an adjoint value indirectly into the SIL buffer at |
| /// `destAddress`, which must have an address type. |
| /// |
| /// A concrete adjoint value is copied before being stored, so the value held |
| /// by `val` is not consumed by the initializing store. |
| void materializeAdjointIndirect(AdjointValue val, SILValue destAddress, |
| SILLocation loc) { |
| assert(destAddress->getType().isAddress()); |
| switch (val.getKind()) { |
| /// If adjoint value is a symbolic zero, emit a call to |
| /// `AdditiveArithmetic.zero`. |
| case AdjointValueKind::Zero: |
| emitZeroIndirect(val.getSwiftType(), destAddress, loc); |
| break; |
| /// If adjoint value is a symbolic aggregate (tuple or struct), recursively |
| /// materialize the elements of the symbolic tuple or struct, filling the |
| /// buffer. |
| case AdjointValueKind::Aggregate: { |
| if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) { |
| for (auto idx : range(val.getNumAggregateElements())) { |
| auto eltTy = SILType::getPrimitiveAddressType( |
| tupTy->getElementType(idx)->getCanonicalType()); |
| auto *eltBuf = |
| builder.createTupleElementAddr(loc, destAddress, idx, eltTy); |
| materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc); |
| } |
| } else if (auto *structDecl = |
| val.getSwiftType()->getStructOrBoundGenericStruct()) { |
| // Aggregate elements are indexed in stored property order. |
| auto fieldIt = structDecl->getStoredProperties().begin(); |
| for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end(); |
| ++fieldIt, ++i) { |
| auto eltBuf = |
| builder.createStructElementAddr(loc, destAddress, *fieldIt); |
| materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc); |
| } |
| } else { |
| llvm_unreachable("Not an aggregate type"); |
| } |
| break; |
| } |
| /// If adjoint value is concrete, it is already materialized. Store a copy of |
| /// it in the destination address. |
| case AdjointValueKind::Concrete: { |
| // An initializing store takes ownership of its operand. Store a copy so |
| // that the concrete value itself stays valid for its other users, matching |
| // how `materializeAdjointDirect` copies elements before aggregating them. |
| auto concreteVal = val.getConcreteValue(); |
| auto concreteValCopy = builder.emitCopyValueOperation(loc, concreteVal); |
| builder.emitStoreValueOperation(loc, concreteValCopy, destAddress, |
| StoreOwnershipQualifier::Init); |
| break; |
| } |
| } |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Helpers for adjoint value materialization |
| //--------------------------------------------------------------------------// |
| |
| /// Emits a zero value into the given address by calling |
| /// `AdditiveArithmetic.zero`. The given type must conform to |
| /// `AdditiveArithmetic`. |
| /// |
| /// For a `TangentVector` tangent space the zero is emitted with |
| /// `emitZeroIntoBuffer`. For a tuple tangent space every tuple element |
| /// address is zero-initialized recursively. |
| void emitZeroIndirect(CanType type, SILValue address, SILLocation loc) { |
| auto tangentSpace = getTangentSpace(type); |
| assert(tangentSpace && "No tangent space for this type"); |
| switch (tangentSpace->getKind()) { |
| case TangentSpace::Kind::TangentVector: |
| emitZeroIntoBuffer(builder, type, address, loc); |
| return; |
| case TangentSpace::Kind::Tuple: { |
| auto tupleType = tangentSpace->getTuple(); |
| for (unsigned i : range(tupleType->getNumElements())) { |
| auto eltAddr = builder.createTupleElementAddr(loc, address, i); |
| emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(), |
| eltAddr, loc); |
| } |
| return; |
| } |
| } |
| } |
| |
| /// Emits a zero value of the given type and returns it as an object. The type |
| /// must conform to `AdditiveArithmetic` and be loadable in SIL. |
| /// |
| /// The zero is first emitted into a temporary stack buffer with |
| /// `emitZeroIndirect` and then loaded (taken) out of that buffer. |
| SILValue emitZeroDirect(CanType type, SILLocation loc) { |
| auto &module = getModule(); |
| auto silType = module.Types.getLoweredLoadableType( |
| type, TypeExpansionContext::minimal(), module); |
| auto *zeroBuffer = builder.createAllocStack(loc, silType); |
| emitZeroIndirect(type, zeroBuffer, loc); |
| auto zeroValue = builder.emitLoadValueOperation( |
| loc, zeroBuffer, LoadOwnershipQualifier::Take); |
| builder.createDeallocStack(loc, zeroBuffer); |
| return zeroValue; |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Adjoint value mapping |
| //--------------------------------------------------------------------------// |
| |
| /// Returns true if an adjoint value is recorded for the given object value of |
| /// the original function in the given original block. |
| bool hasAdjointValue(SILBasicBlock *origBB, SILValue originalValue) const { |
| assert(origBB->getParent() == &getOriginal()); |
| assert(originalValue->getType().isObject()); |
| return valueMap.find({origBB, originalValue}) != valueMap.end(); |
| } |
| |
| /// Sets the adjoint value for the original value in the given original block. |
| /// If an adjoint value is already recorded for that pair, it is overwritten. |
| /// |
| /// Both values must be objects, and the adjoint value's type must be the |
| /// remapped tangent type of the original value's type. |
| void setAdjointValue(SILBasicBlock *origBB, SILValue originalValue, |
| AdjointValue adjointValue) { |
| LLVM_DEBUG(getADDebugStream() |
| << "Setting adjoint value for " << originalValue); |
| assert(origBB->getParent() == &getOriginal()); |
| assert(originalValue->getType().isObject()); |
| assert(getTangentValueCategory(originalValue) == SILValueCategory::Object); |
| assert(adjointValue.getType().isObject()); |
| assert(originalValue->getFunction() == &getOriginal()); |
| // The adjoint value must be in the tangent space. |
| assert(adjointValue.getType() == |
| getRemappedTangentType(originalValue->getType())); |
| auto insertion = |
| valueMap.try_emplace({origBB, originalValue}, adjointValue); |
| // `try_emplace` leaves an existing entry untouched, so overwrite it |
| // explicitly. |
| if (!insertion.second) |
| insertion.first->getSecond() = adjointValue; |
| // Log after the overwrite so that the message shows the value actually |
| // stored, not the one being replaced. |
| LLVM_DEBUG(getADDebugStream() |
| << "The new adjoint value, replacing the existing one, is: " |
| << insertion.first->getSecond()); |
| } |
| |
| /// Returns the adjoint value recorded for an object value of the original |
| /// function in the given original block. |
| /// |
| /// If no adjoint value is recorded yet, a symbolic zero of the remapped |
| /// tangent type is recorded and returned. |
| AdjointValue getAdjointValue(SILBasicBlock *origBB, SILValue originalValue) { |
| assert(origBB->getParent() == &getOriginal()); |
| assert(originalValue->getType().isObject()); |
| assert(getTangentValueCategory(originalValue) == SILValueCategory::Object); |
| assert(originalValue->getFunction() == &getOriginal()); |
| // The zero is only stored when the key is not in the map yet. |
| AdjointValue zeroAdjoint = |
| makeZeroAdjointValue(getRemappedTangentType(originalValue->getType())); |
| auto entry = valueMap.try_emplace({origBB, originalValue}, zeroAdjoint).first; |
| return entry->getSecond(); |
| } |
| |
| /// Accumulates `newAdjointValue` into the adjoint value of `originalValue`. |
| /// |
| /// If no adjoint value is recorded yet, `newAdjointValue` becomes the adjoint |
| /// value. Otherwise the recorded adjoint value and `newAdjointValue` are |
| /// summed and the sum is recorded as the new adjoint value. |
| void addAdjointValue(SILBasicBlock *origBB, SILValue originalValue, |
| AdjointValue newAdjointValue, SILLocation loc) { |
| assert(origBB->getParent() == &getOriginal()); |
| assert(originalValue->getType().isObject()); |
| assert(newAdjointValue.getType().isObject()); |
| assert(originalValue->getFunction() == &getOriginal()); |
| LLVM_DEBUG(getADDebugStream() |
| << "Adding adjoint value for " << originalValue); |
| // The adjoint value must be in the tangent space. |
| assert(newAdjointValue.getType() == |
| getRemappedTangentType(originalValue->getType())); |
| auto emplaced = |
| valueMap.try_emplace({origBB, originalValue}, newAdjointValue); |
| const bool hadNoAdjoint = emplaced.second; |
| if (!hadNoAdjoint) { |
| // An adjoint is already recorded: take it out of the map and accumulate |
| // the new adjoint onto it. |
| AdjointValue previousAdjoint = emplaced.first->getSecond(); |
| valueMap.erase(emplaced.first); |
| AdjointValue summedAdjoint = |
| accumulateAdjointsDirect(previousAdjoint, newAdjointValue, loc); |
| // If the original value is the `Array` result of an |
| // `array.uninitialized_intrinsic` application, accumulate adjoint buffers |
| // for the array element addresses. |
| accumulateArrayLiteralElementAddressAdjoints(origBB, originalValue, |
| summedAdjoint, loc); |
| setAdjointValue(origBB, originalValue, summedAdjoint); |
| } |
| } |
| |
| /// Returns the pullback block argument registered for the given original |
| /// block and active value. The active value must have the "object" tangent |
| /// value category, and the argument must belong to the pullback block of |
| /// `origBB`. |
| SILArgument *getActiveValuePullbackBlockArgument(SILBasicBlock *origBB, |
| SILValue activeValue) { |
| assert(getTangentValueCategory(activeValue) == SILValueCategory::Object); |
| assert(origBB->getParent() == &getOriginal()); |
| SILArgument *argument = |
| activeValuePullbackBBArgumentMap[{origBB, activeValue}]; |
| assert(argument); |
| assert(argument->getParent() == getPullbackBlock(origBB)); |
| return argument; |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Adjoint value accumulation |
| //--------------------------------------------------------------------------// |
| |
| /// Given two adjoint values, accumulates them and returns their sum. |
| AdjointValue accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs, |
| SILLocation loc); |
| |
| /// Generates code returning `result = lhs + rhs`. |
| /// |
| /// Given two materialized adjoint values, accumulates them and returns their |
| /// sum. The adjoint values must have a loadable type. |
| SILValue accumulateDirect(SILValue lhs, SILValue rhs, SILLocation loc); |
| |
| /// Generates code for `resultAddress = lhsAddress + rhsAddress`. |
| /// |
| /// Given two addresses with the same `AdditiveArithmetic`-conforming type, |
| /// accumulates them into a result address using `AdditiveArithmetic.+`. |
| void accumulateIndirect(SILValue resultAddress, SILValue lhsAddress, |
| SILValue rhsAddress, SILLocation loc); |
| |
| /// Generates code for `lhsDestAddress += rhsAddress`. |
| /// |
| /// Given two addresses with the same `AdditiveArithmetic`-conforming type, |
| /// accumulates the rhs into the lhs using `AdditiveArithmetic.+=`. |
| void accumulateIndirect(SILValue lhsDestAddress, SILValue rhsAddress, |
| SILLocation loc); |
| |
| //--------------------------------------------------------------------------// |
| // Adjoint buffer mapping |
| //--------------------------------------------------------------------------// |
| |
| /// If the given original value is an address projection, returns a |
| /// corresponding adjoint projection to be used as its adjoint buffer. |
| /// |
| /// Helper function for `getAdjointBuffer`. |
| SILValue getAdjointProjection(SILBasicBlock *origBB, SILValue originalValue); |
| |
| /// Returns the adjoint buffer for the original value. |
| /// |
| /// This method first tries to find an existing entry in the adjoint buffer |
| /// mapping. If no entry exists, the new adjoint buffer is either the |
| /// projection returned by `getAdjointProjection`, or a fresh function-local |
| /// allocation that is zero-initialized right after it is allocated. |
| SILValue getAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue) { |
| assert(getTangentValueCategory(originalValue) == SILValueCategory::Address); |
| assert(originalValue->getFunction() == &getOriginal()); |
| auto insertion = bufferMap.try_emplace({origBB, originalValue}, SILValue()); |
| if (!insertion.second) // not inserted |
| return insertion.first->getSecond(); |
| |
| // NOTE: Do not use `insertion.first` below this point. `DenseMap` |
| // insertions invalidate iterators, and the helpers called from here may add |
| // entries to `bufferMap` (their definitions are not all visible here). Look |
| // the entry up again when writing the result. |
| |
| // If the original buffer is a projection, return a corresponding projection |
| // into the adjoint buffer. |
| if (auto adjProj = getAdjointProjection(origBB, originalValue)) |
| return (bufferMap[{origBB, originalValue}] = adjProj); |
| |
| auto bufType = getRemappedTangentType(originalValue->getType()); |
| // Create an uninitialized local allocation in the pullback entry. |
| auto *newBuf = createFunctionLocalAllocation( |
| bufType, RegularLocation::getAutoGeneratedLocation()); |
| // Temporarily change global builder insertion point and emit zero into the |
| // local allocation. |
| auto insertionPoint = builder.getInsertionBB(); |
| builder.setInsertionPoint(localAllocBuilder.getInsertionBB(), |
| localAllocBuilder.getInsertionPoint()); |
| emitZeroIndirect(bufType.getASTType(), newBuf, newBuf->getLoc()); |
| builder.setInsertionPoint(insertionPoint); |
| return (bufferMap[{origBB, originalValue}] = newBuf); |
| } |
| |
| /// Records `adjointBuffer` as the adjoint buffer of the original value in the |
| /// given original block. Asserts that no adjoint buffer is recorded for that |
| /// pair yet; without assertions an existing entry is left unchanged. |
| void setAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue, |
| SILValue adjointBuffer) { |
| assert(getTangentValueCategory(originalValue) == SILValueCategory::Address); |
| bool isNewEntry = |
| bufferMap.try_emplace({origBB, originalValue}, adjointBuffer).second; |
| assert(isNewEntry && "Adjoint buffer already exists"); |
| (void)isNewEntry; |
| } |
| |
| /// Accumulates the value at `rhsAddress` into the adjoint buffer of the |
| /// original value. The adjoint buffer is obtained with `getAdjointBuffer`, so |
| /// it is created on demand. `rhsAddress` must be an address in the pullback |
| /// function. |
| void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue, |
| SILValue rhsAddress, SILLocation loc) { |
| assert(getTangentValueCategory(originalValue) == |
| SILValueCategory::Address && |
| rhsAddress->getType().isAddress()); |
| assert(originalValue->getFunction() == &getOriginal()); |
| assert(rhsAddress->getFunction() == &getPullback()); |
| // In-place accumulation: `adjoint buffer += rhsAddress`. |
| accumulateIndirect(getAdjointBuffer(origBB, originalValue), rhsAddress, loc); |
| } |
| |
| /// Returns the insertion point for the next function-local allocation: the |
| /// position of the most recently created local allocation if there is one, |
| /// otherwise the start of the pullback entry block. |
| /// |
| /// Helper for `createFunctionLocalAllocation`. |
| SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint() { |
| if (!functionLocalAllocations.empty()) { |
| // Insert before the last local allocation. Inserting before rather than |
| // after ensures that allocation and zero initialization instructions are |
| // grouped together. |
| AllocStackInst *lastLocalAlloc = functionLocalAllocations.back(); |
| return lastLocalAlloc->getDefiningInstruction()->getIterator(); |
| } |
| // No local allocations yet: insert at the pullback entry start. |
| return getPullback().getEntryBlock()->begin(); |
| } |
| |
| /// Creates a local allocation of the given type in the pullback entry block, |
| /// remembers it in `functionLocalAllocations`, and returns it. |
| /// |
| /// Local allocations are created uninitialized in the pullback entry and |
| /// deallocated in the pullback exit. All local allocations not in |
| /// `destroyedLocalAllocations` are also destroyed in the pullback exit. |
| /// |
| /// Helper for `getAdjointBuffer`. |
| AllocStackInst *createFunctionLocalAllocation(SILType type, SILLocation loc) { |
| // Position the local allocation builder: before the last local allocation, |
| // or at the start of the pullback function's entry if no local allocations |
| // exist yet. |
| auto insertionPoint = getNextFunctionLocalAllocationInsertionPoint(); |
| localAllocBuilder.setInsertionPoint(getPullback().getEntryBlock(), |
| insertionPoint); |
| // Create, remember, and return the local allocation. |
| functionLocalAllocations.push_back( |
| localAllocBuilder.createAllocStack(loc, type)); |
| return functionLocalAllocations.back(); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Optional differentiation |
| //--------------------------------------------------------------------------// |
| |
| /// Given a `wrappedAdjoint` value of type `T.TangentVector`, creates an |
| /// `Optional<T>.TangentVector` value from it and adds it to the adjoint value |
| /// of `optionalValue`. |
| /// |
| /// `wrappedAdjoint` may be an object or address value, both cases are |
| /// handled. |
| void accumulateAdjointForOptional(SILBasicBlock *bb, SILValue optionalValue, |
| SILValue wrappedAdjoint); |
| |
| //--------------------------------------------------------------------------// |
| // Array literal initialization differentiation |
| //--------------------------------------------------------------------------// |
| |
| /// Given the adjoint value of an array initialized from an |
| /// `array.uninitialized_intrinsic` application and an array element index, |
| /// returns an `alloc_stack` containing the adjoint value of the array element |
| /// at the given index by applying `Array.TangentVector.subscript`. |
| AllocStackInst *getArrayAdjointElementBuffer(SILValue arrayAdjoint, |
| int eltIndex, SILLocation loc); |
| |
| /// Given the adjoint value of an array initialized from an |
| /// `array.uninitialized_intrinsic` application, accumulates the adjoint |
| /// value's elements into the adjoint buffers of its element addresses. |
| void accumulateArrayLiteralElementAddressAdjoints( |
| SILBasicBlock *origBB, SILValue originalValue, |
| AdjointValue arrayAdjointValue, SILLocation loc); |
| |
| //--------------------------------------------------------------------------// |
| // CFG mapping |
| //--------------------------------------------------------------------------// |
| |
| /// Returns the pullback block mapped to the given original block, or null if |
| /// `pullbackBBMap` has no entry for it. |
| SILBasicBlock *getPullbackBlock(SILBasicBlock *originalBlock) { |
| return pullbackBBMap.lookup(originalBlock); |
| } |
| |
| /// Returns the pullback trampoline block mapped to the given pair of original |
| /// block and successor block, or null if `pullbackTrampolineBBMap` has no |
| /// entry for the pair. |
| SILBasicBlock *getPullbackTrampolineBlock(SILBasicBlock *originalBlock, |
| SILBasicBlock *successorBlock) { |
| return pullbackTrampolineBBMap.lookup({originalBlock, successorBlock}); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Debugging utilities |
| //--------------------------------------------------------------------------// |
| |
| /// Debugging utility: prints the entries of `valueMap` to the AD debug |
| /// stream, grouped by original basic block. Only original blocks that have an |
| /// entry in `pullbackBBMap` are printed. |
| void printAdjointValueMapping() { |
| // Bucket the (original value, adjoint value) entries by original block. |
| llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, AdjointValue>> |
| valuesByBlock; |
| for (auto entry : valueMap) |
| valuesByBlock[entry.first.first].insert({entry.first.second, entry.second}); |
| // Walk the original function's blocks and print the bucket of each block. |
| auto &s = getADDebugStream() << "Adjoint value mapping:\n"; |
| for (auto &origBB : getOriginal()) { |
| if (pullbackBBMap.find(&origBB) == pullbackBBMap.end()) |
| continue; |
| auto &blockValues = valuesByBlock[&origBB]; |
| s << "bb" << origBB.getDebugID() << " (size " << blockValues.size() |
| << "):\n"; |
| for (auto valueAndAdjoint : blockValues) { |
| s << "ORIG: " << valueAndAdjoint.first; |
| s << "ADJ: " << valueAndAdjoint.second << '\n'; |
| } |
| s << '\n'; |
| } |
| } |
| |
| void printAdjointBufferMapping() { |
| // Group original/adjoint buffers by basic block. |
| llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, SILValue>> tmp; |
| for (auto pair : bufferMap) { |
| auto origPair = pair.first; |
| auto *origBB = origPair.first; |
| auto origBuf = origPair.second; |
| auto adjBuf = pair.second; |
| tmp[origBB][origBuf] = adjBuf; |
| } |
| // Print original/adjoint buffers per basic block. |
| auto &s = getADDebugStream() << "Adjoint buffer mapping:\n"; |
| for (auto &origBB : getOriginal()) { |
| if (!pullbackBBMap.count(&origBB)) |
| continue; |
| auto bbBufferMap = tmp[&origBB]; |
| s << "bb" << origBB.getDebugID(); |
| s << " (size " << bbBufferMap.size() << "):\n"; |
| for (auto valuePair : bbBufferMap) { |
| auto origBuf = valuePair.first; |
| auto adjBuf = valuePair.second; |
| s << "ORIG: " << origBuf; |
| s << "ADJ: " << adjBuf << '\n'; |
| } |
| s << '\n'; |
| } |
| } |
| |
| public: |
| //--------------------------------------------------------------------------// |
| // Entry point |
| //--------------------------------------------------------------------------// |
| |
| /// Entry point: performs pullback generation on the empty pullback function. |
| /// Returns true if any error occurs (i.e. `false` means success). |
| bool run(); |
|
| /// Performs pullback generation on the empty pullback function, given that |
| /// the original function is a "semantic member accessor". |
| /// |
| /// "Semantic member accessors" are attached to member properties that have a |
| /// corresponding tangent stored property in the parent `TangentVector` type. |
| /// These accessors have special-case pullback generation based on their |
| /// semantic behavior. |
| /// |
| /// Returns true if any error occurs. NOTE(review): the `...Getter`/`...Setter` declarations below are presumably the per-accessor-kind variants -- confirm against their definitions. |
| bool runForSemanticMemberAccessor(); |
| bool runForSemanticMemberGetter(); |
| bool runForSemanticMemberSetter(); |
| |
| /// If the original result `origNonvariedResult` is non-varied, it will always have |
| /// a zero derivative. In that case, skip full pullback generation and simply emit |
| /// zero derivatives for the wrt (differentiation) parameters. |
| void emitZeroDerivativesForNonvariedResult(SILValue origNonvariedResult); |
| |
| /// Public helper so that our users can get the underlying newly created |
| /// function. |
| SILFunction &getPullback() const { return vjpCloner.getPullback(); } |
| |
| using TrampolineBlockSet = SmallPtrSet<SILBasicBlock *, 4>; |
|
| /// Determines the pullback successor block for the original block `origBB` and |
| /// one of its predecessors, `origPredBB`. When a trampoline block is necessary, |
| /// emits code into the trampoline block to trampoline the adjoint values of the |
| /// original block's active values. |
| /// |
| /// Populates `pullbackTrampolineBlockMap`, which maps active values' adjoint |
| /// values to the set (`TrampolineBlockSet`) of pullback successor blocks in which |
| /// they are used. This allows us to release those values in pullback successor |
| /// blocks that do not use them. |
| SILBasicBlock * |
| buildPullbackSuccessor(SILBasicBlock *origBB, SILBasicBlock *origPredBB, |
| llvm::SmallDenseMap<SILValue, TrampolineBlockSet> |
| &pullbackTrampolineBlockMap); |
| |
| /// Emits pullback code for the original basic block `bb` into its corresponding pullback block. |
| void visitSILBasicBlock(SILBasicBlock *bb); |
| |
| void visit(SILInstruction *inst) { |
| if (errorOccurred) |
| return; |
| |
| LLVM_DEBUG(getADDebugStream() |
| << "PullbackCloner visited:\n[ORIG]" << *inst); |
| #ifndef NDEBUG |
| auto beforeInsertion = std::prev(builder.getInsertionPoint()); |
| #endif |
| SILInstructionVisitor::visit(inst); |
| LLVM_DEBUG({ |
| auto &s = llvm::dbgs() << "[ADJ] Emitted in pullback:\n"; |
| auto afterInsertion = builder.getInsertionPoint(); |
| for (auto it = ++beforeInsertion; it != afterInsertion; ++it) |
| s << *it; |
| }); |
| } |
| |
| /// Fallback instruction visitor for unhandled instructions. |
| /// Emit a general non-differentiability diagnostic. |
| void visitSILInstruction(SILInstruction *inst) { |
| LLVM_DEBUG(getADDebugStream() |
| << "Unhandled instruction in PullbackCloner: " << *inst); |
| getContext().emitNondifferentiabilityError( |
| inst, getInvoker(), diag::autodiff_expression_not_differentiable_note); |
| errorOccurred = true; |
| } |
| |
| /// Handle `apply` instruction by emitting a call to the callee's pullback. |
| /// Original: (y0, y1, ...) = apply @fn (x0, x1, ...) |
| /// Adjoint: (adj[x0], adj[x1], ...) += apply @fn_pullback (adj[y0], ...) |
| void visitApplyInst(ApplyInst *ai) { |
| assert(getPullbackInfo().shouldDifferentiateApplySite(ai)); |
| // Skip `array.uninitialized_intrinsic` applications, which have special |
| // `store` and `copy_addr` support. |
| if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) |
| return; |
| auto loc = ai->getLoc(); |
| auto *bb = ai->getParent(); |
| // Handle `array.finalize_intrinsic` applications. |
| // `array.finalize_intrinsic` semantically behaves like an identity |
| // function, so the result's adjoint is forwarded unchanged to the argument. |
| if (ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) { |
| assert(ai->getNumArguments() == 1 && |
| "Expected intrinsic to have one operand"); |
| // Accumulate result's adjoint into argument's adjoint. |
| auto adjResult = getAdjointValue(bb, ai); |
| auto origArg = ai->getArgumentsWithoutIndirectResults().front(); |
| addAdjointValue(bb, origArg, adjResult, loc); |
| return; |
| } |
| // Replace a call to a function with a call to its pullback. |
| auto &nestedApplyInfo = getContext().getNestedApplyInfo(); |
| auto applyInfoLookup = nestedApplyInfo.find(ai); |
| // If no `NestedApplyInfo` was found, then this `apply` doesn't need to be |
| // differentiated. |
| if (applyInfoLookup == nestedApplyInfo.end()) { |
| // Must not be active. |
| assert(!getActivityInfo().isActive(ai, getConfig())); |
| return; |
| } |
| auto applyInfo = applyInfoLookup->getSecond(); |
|
| // Get the callee pullback value from the linear map field recorded for this `apply`. |
| auto *field = getPullbackInfo().lookUpLinearMapDecl(ai); |
| assert(field); |
| auto pullback = getPullbackStructElement(ai->getParent(), field); |
|
| // Get the original results of the `apply` instruction: direct results first. |
| SmallVector<SILValue, 8> origDirectResults; |
| forEachApplyDirectResult(ai, [&](SILValue directResult) { |
| origDirectResults.push_back(directResult); |
| }); |
| SmallVector<SILValue, 8> origAllResults; |
| collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults); |
| // Append `inout` arguments after original results. |
| for (auto paramIdx : applyInfo.config.parameterIndices->getIndices()) { |
| auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( |
| ai->getNumIndirectResults() + paramIdx); |
| if (!paramInfo.isIndirectMutating()) |
| continue; |
| origAllResults.push_back( |
| ai->getArgumentsWithoutIndirectResults()[paramIdx]); |
| } |
|
| // Get callee pullback arguments: indirect result buffers first, then seeds. |
| SmallVector<SILValue, 8> args; |
|
| // Handle callee pullback indirect results. |
| // Create local allocations for these and destroy them after the call. |
| auto pullbackType = |
| remapType(pullback->getType()).castTo<SILFunctionType>(); |
| auto actualPullbackType = applyInfo.originalPullbackType |
| ? *applyInfo.originalPullbackType |
| : pullbackType; |
| actualPullbackType = actualPullbackType->getUnsubstitutedType(getModule()); |
| SmallVector<AllocStackInst *, 4> pullbackIndirectResults; |
| for (auto indRes : actualPullbackType->getIndirectFormalResults()) { |
| auto *alloc = builder.createAllocStack( |
| loc, remapType(indRes.getSILStorageInterfaceType())); |
| pullbackIndirectResults.push_back(alloc); |
| args.push_back(alloc); |
| } |
|
| // Collect callee pullback formal arguments, one seed per differentiation result index. |
| for (auto resultIndex : applyInfo.config.resultIndices->getIndices()) { |
| assert(resultIndex < origAllResults.size()); |
| auto origResult = origAllResults[resultIndex]; |
| // Get the seed (i.e. adjoint value of the original result). |
| SILValue seed; |
| switch (getTangentValueCategory(origResult)) { |
| case SILValueCategory::Object: |
| seed = materializeAdjointDirect(getAdjointValue(bb, origResult), loc); |
| break; |
| case SILValueCategory::Address: |
| seed = getAdjointBuffer(bb, origResult); |
| break; |
| } |
| args.push_back(seed); |
| } |
|
| // If callee pullback was reabstracted in VJP, reabstract callee pullback. |
| if (applyInfo.originalPullbackType) { |
| SILOptFunctionBuilder fb(getContext().getTransform()); |
| pullback = reabstractFunction( |
| builder, fb, loc, pullback, *applyInfo.originalPullbackType, |
| [this](SubstitutionMap subs) -> SubstitutionMap { |
| return this->remapSubstitutionMap(subs); |
| }); |
| } |
|
| // Call the callee pullback, then destroy the pullback value itself. |
| auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(), |
| args, /*isNonThrowing*/ false); |
| builder.emitDestroyValueOperation(loc, pullback); |
|
| // Extract all direct results from `pullbackCall`. |
| SmallVector<SILValue, 8> dirResults; |
| extractAllElements(pullbackCall, builder, dirResults); |
| // Get all results in type-defined order. |
| SmallVector<SILValue, 8> allResults; |
| collectAllActualResultsInTypeOrder(pullbackCall, dirResults, allResults); |
|
| LLVM_DEBUG({ |
| auto &s = getADDebugStream(); |
| s << "All results of the nested pullback call:\n"; |
| llvm::for_each(allResults, [&](SILValue v) { s << v; }); |
| }); |
|
| // Accumulate adjoints for original differentiation parameters, consuming pullback results in order. |
| auto allResultsIt = allResults.begin(); |
| for (unsigned i : applyInfo.config.parameterIndices->getIndices()) { |
| auto origArg = ai->getArgument(ai->getNumIndirectResults() + i); |
| // Skip adjoint accumulation for `inout` arguments (no pullback result is consumed for them). |
| auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( |
| ai->getNumIndirectResults() + i); |
| if (paramInfo.isIndirectMutating()) |
| continue; |
| auto tan = *allResultsIt++; |
| if (tan->getType().isAddress()) { |
| addToAdjointBuffer(bb, origArg, tan, loc); |
| } else { |
| if (origArg->getType().isAddress()) { |
| auto *tmpBuf = builder.createAllocStack(loc, tan->getType()); |
| builder.emitStoreValueOperation(loc, tan, tmpBuf, |
| StoreOwnershipQualifier::Init); |
| addToAdjointBuffer(bb, origArg, tmpBuf, loc); |
| builder.emitDestroyAddrAndFold(loc, tmpBuf); |
| builder.createDeallocStack(loc, tmpBuf); |
| } else { |
| recordTemporary(tan); |
| addAdjointValue(bb, origArg, makeConcreteAdjointValue(tan), loc); |
| } |
| } |
| } |
| // Destroy unused pullback direct results. Needed for pullback results from |
| // VJPs extracted from `@differentiable` function callees, where the |
| // `@differentiable` function's differentiation parameter indices are a |
| // superset of the active `apply` parameter indices. |
| while (allResultsIt != allResults.end()) { |
| auto unusedPullbackDirectResult = *allResultsIt++; |
| if (unusedPullbackDirectResult->getType().isAddress()) |
| continue; |
| builder.emitDestroyValueOperation(loc, unusedPullbackDirectResult); |
| } |
| // Destroy and deallocate pullback indirect results, in reverse allocation order. |
| for (auto *alloc : llvm::reverse(pullbackIndirectResults)) { |
| builder.emitDestroyAddrAndFold(loc, alloc); |
| builder.createDeallocStack(loc, alloc); |
| } |
| } |
| |
| void visitBeginApplyInst(BeginApplyInst *bai) { |
| // Diagnose `begin_apply` instructions. |
| // Coroutine differentiation is not yet supported. |
| getContext().emitNondifferentiabilityError( |
| bai, getInvoker(), diag::autodiff_coroutines_not_supported); |
| errorOccurred = true; |
| return; |
| } |
| |
| /// Handle `struct` instruction. |
| /// Original: y = struct (x0, x1, x2, ...) |
| /// Adjoint: adj[x0] += struct_extract adj[y], #x0 |
| /// adj[x1] += struct_extract adj[y], #x1 |
| /// adj[x2] += struct_extract adj[y], #x2 |
| /// ... |
| void visitStructInst(StructInst *si) { |
| auto *bb = si->getParent(); |
| auto loc = si->getLoc(); |
| auto *structDecl = si->getStructDecl(); |
| switch (getTangentValueCategory(si)) { |
| case SILValueCategory::Object: { |
| auto av = getAdjointValue(bb, si); |
| switch (av.getKind()) { |
| case AdjointValueKind::Zero: { |
| for (auto *field : structDecl->getStoredProperties()) { |
| auto fv = si->getFieldValue(field); |
| addAdjointValue( |
| bb, fv, |
| makeZeroAdjointValue(getRemappedTangentType(fv->getType())), loc); |
| } |
| break; |
| } |
| case AdjointValueKind::Concrete: { |
| auto adjStruct = materializeAdjointDirect(std::move(av), loc); |
| auto *dti = builder.createDestructureStruct(si->getLoc(), adjStruct); |
| |
| // Find the struct `TangentVector` type. |
| auto structTy = remapType(si->getType()).getASTType(); |
| #ifndef NDEBUG |
| auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType(); |
| assert(!getTypeLowering(tangentVectorTy).isAddressOnly()); |
| assert(tangentVectorTy->getStructOrBoundGenericStruct()); |
| #endif |
| |
| // Accumulate adjoints for the fields of the `struct` operand. |
| unsigned fieldIndex = 0; |
| for (auto it = structDecl->getStoredProperties().begin(); |
| it != structDecl->getStoredProperties().end(); |
| ++it, ++fieldIndex) { |
| VarDecl *field = *it; |
| if (field->getAttrs().hasAttribute<NoDerivativeAttr>()) |
| continue; |
| // Find the corresponding field in the tangent space. |
| auto *tanField = getTangentStoredProperty( |
| getContext(), field, structTy, loc, getInvoker()); |
| if (!tanField) { |
| errorOccurred = true; |
| return; |
| } |
| auto tanElt = dti->getResult(fieldIndex); |
| addAdjointValue(bb, si->getFieldValue(field), |
| makeConcreteAdjointValue(tanElt), si->getLoc()); |
| } |
| break; |
| } |
| case AdjointValueKind::Aggregate: { |
| // Note: All user-called initializations go through the calls to the |
| // initializer, and synthesized initializers only have one level of |
| // struct formation which will not result into any aggregate adjoint |
| // valeus. |
| llvm_unreachable( |
| "Aggregate adjoint values should not occur for `struct` " |
| "instructions"); |
| } |
| } |
| break; |
| } |
| case SILValueCategory::Address: { |
| auto adjBuf = getAdjointBuffer(bb, si); |
| // Find the struct `TangentVector` type. |
| auto structTy = remapType(si->getType()).getASTType(); |
| // Accumulate adjoints for the fields of the `struct` operand. |
| unsigned fieldIndex = 0; |
| for (auto it = structDecl->getStoredProperties().begin(); |
| it != structDecl->getStoredProperties().end(); ++it, ++fieldIndex) { |
| VarDecl *field = *it; |
| if (field->getAttrs().hasAttribute<NoDerivativeAttr>()) |
| continue; |
| // Find the corresponding field in the tangent space. |
| auto *tanField = getTangentStoredProperty(getContext(), field, structTy, |
| loc, getInvoker()); |
| if (!tanField) { |
| errorOccurred = true; |
| return; |
| } |
| auto *adjFieldBuf = |
| builder.createStructElementAddr(loc, adjBuf, tanField); |
| auto fieldValue = si->getFieldValue(field); |
| switch (getTangentValueCategory(fieldValue)) { |
| case SILValueCategory::Object: { |
| auto adjField = builder.emitLoadValueOperation( |
| loc, adjFieldBuf, LoadOwnershipQualifier::Copy); |
| recordTemporary(adjField); |
| addAdjointValue(bb, fieldValue, makeConcreteAdjointValue(adjField), |
| loc); |
| break; |
| } |
| case SILValueCategory::Address: { |
| addToAdjointBuffer(bb, fieldValue, adjFieldBuf, loc); |
| break; |
| } |
| } |
| } |
| } break; |
| } |
| } |
| |
| /// Handle `struct_extract` instruction. |
| /// Original: y = struct_extract x, #field |
| /// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0) |
| /// ^~~~~~~ |
| /// field in tangent space corresponding to #field |
| void visitStructExtractInst(StructExtractInst *sei) { |
| auto *bb = sei->getParent(); |
| auto loc = getValidLocation(sei); |
| // Find the corresponding field in the tangent space. |
| auto structTy = remapType(sei->getOperand()->getType()).getASTType(); |
| auto *tanField = |
| getTangentStoredProperty(getContext(), sei, structTy, getInvoker()); |
| // Check the `struct_extract` operand's value tangent category. |
| switch (getTangentValueCategory(sei->getOperand())) { |
| case SILValueCategory::Object: { |
| auto tangentVectorTy = getTangentSpace(structTy)->getCanonicalType(); |
| auto *tangentVectorDecl = |
| tangentVectorTy->getStructOrBoundGenericStruct(); |
| assert(tangentVectorDecl); |
| auto tangentVectorSILTy = |
| SILType::getPrimitiveObjectType(tangentVectorTy); |
| assert(tanField && "Invalid projections should have been diagnosed"); |
| // Accumulate adjoint for the `struct_extract` operand. |
| auto av = getAdjointValue(bb, sei); |
| switch (av.getKind()) { |
| case AdjointValueKind::Zero: |
| addAdjointValue(bb, sei->getOperand(), |
| makeZeroAdjointValue(tangentVectorSILTy), loc); |
| break; |
| case AdjointValueKind::Concrete: |
| case AdjointValueKind::Aggregate: { |
| SmallVector<AdjointValue, 8> eltVals; |
| for (auto *field : tangentVectorDecl->getStoredProperties()) { |
| if (field == tanField) { |
| eltVals.push_back(av); |
| } else { |
| auto substMap = tangentVectorTy->getMemberSubstitutionMap( |
| field->getModuleContext(), field); |
| auto fieldTy = field->getType().subst(substMap); |
| auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType(); |
| assert(fieldSILTy.isObject()); |
| eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); |
| } |
| } |
| addAdjointValue(bb, sei->getOperand(), |
| makeAggregateAdjointValue(tangentVectorSILTy, eltVals), |
| loc); |
| } |
| } |
| break; |
| } |
| case SILValueCategory::Address: { |
| auto adjBase = getAdjointBuffer(bb, sei->getOperand()); |
| auto *adjBaseElt = |
| builder.createStructElementAddr(loc, adjBase, tanField); |
| // Check the `struct_extract`'s value tangent category. |
| switch (getTangentValueCategory(sei)) { |
| case SILValueCategory::Object: { |
| auto adjElt = getAdjointValue(bb, sei); |
| auto concreteAdjElt = materializeAdjointDirect(adjElt, loc); |
| auto concreteAdjEltCopy = |
| builder.emitCopyValueOperation(loc, concreteAdjElt); |
| auto *alloc = builder.createAllocStack(loc, adjElt.getType()); |
| builder.emitStoreValueOperation(loc, concreteAdjEltCopy, alloc, |
| StoreOwnershipQualifier::Init); |
| accumulateIndirect(adjBaseElt, alloc, loc); |
| builder.createDestroyAddr(loc, alloc); |
| builder.createDeallocStack(loc, alloc); |
| break; |
| } |
| case SILValueCategory::Address: { |
| auto adjElt = getAdjointBuffer(bb, sei); |
| accumulateIndirect(adjBaseElt, adjElt, loc); |
| break; |
| } |
| } |
| break; |
| } |
| } |
| } |
| |
| /// Handle `ref_element_addr` instruction. |
| /// Original: y = ref_element_addr x, <n> |
| /// Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0) |
| /// ^~~~~~~ |
| /// field in tangent space corresponding to #field |
| void visitRefElementAddrInst(RefElementAddrInst *reai) { |
| auto *bb = reai->getParent(); |
| auto loc = reai->getLoc(); |
| auto adjBuf = getAdjointBuffer(bb, reai); |
| auto classOperand = reai->getOperand(); |
| auto classType = remapType(reai->getOperand()->getType()).getASTType(); |
| auto *tanField = |
| getTangentStoredProperty(getContext(), reai, classType, getInvoker()); |
| assert(tanField && "Invalid projections should have been diagnosed"); |
| switch (getTangentValueCategory(classOperand)) { |
| case SILValueCategory::Object: { |
| auto classTy = remapType(classOperand->getType()).getASTType(); |
| auto tangentVectorTy = getTangentSpace(classTy)->getCanonicalType(); |
| auto tangentVectorSILTy = |
| SILType::getPrimitiveObjectType(tangentVectorTy); |
| auto *tangentVectorDecl = |
| tangentVectorTy->getStructOrBoundGenericStruct(); |
| // Accumulate adjoint for the `ref_element_addr` operand. |
| SmallVector<AdjointValue, 8> eltVals; |
| for (auto *field : tangentVectorDecl->getStoredProperties()) { |
| if (field == tanField) { |
| auto adjElt = builder.emitLoadValueOperation( |
| reai->getLoc(), adjBuf, LoadOwnershipQualifier::Copy); |
| eltVals.push_back(makeConcreteAdjointValue(adjElt)); |
| recordTemporary(adjElt); |
| } else { |
| auto substMap = tangentVectorTy->getMemberSubstitutionMap( |
| field->getModuleContext(), field); |
| auto fieldTy = field->getType().subst(substMap); |
| auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType(); |
| assert(fieldSILTy.isObject()); |
| eltVals.push_back(makeZeroAdjointValue(fieldSILTy)); |
| } |
| } |
| addAdjointValue(bb, classOperand, |
| makeAggregateAdjointValue(tangentVectorSILTy, eltVals), |
| loc); |
| break; |
| } |
| case SILValueCategory::Address: { |
| auto adjBufClass = getAdjointBuffer(bb, classOperand); |
| auto adjBufElt = |
| builder.createStructElementAddr(loc, adjBufClass, tanField); |
| accumulateIndirect(adjBufElt, adjBuf, loc); |
| break; |
| } |
| } |
| } |
| |
| /// Handle `tuple` instruction. |
| /// Original: y = tuple (x0, x1, x2, ...) |
| /// Adjoint: (adj[x0], adj[x1], adj[x2], ...) += destructure_tuple adj[y] |
| /// ^~~ |
| /// excluding non-differentiable elements |
| void visitTupleInst(TupleInst *ti) { |
| auto *bb = ti->getParent(); |
| auto loc = ti->getLoc(); |
| switch (getTangentValueCategory(ti)) { |
| case SILValueCategory::Object: { |
| auto av = getAdjointValue(bb, ti); |
| switch (av.getKind()) { |
| case AdjointValueKind::Zero: |
| for (auto elt : ti->getElements()) { |
| if (!getTangentSpace(elt->getType().getASTType())) |
| continue; |
| addAdjointValue( |
| bb, elt, |
| makeZeroAdjointValue(getRemappedTangentType(elt->getType())), |
| loc); |
| } |
| break; |
| case AdjointValueKind::Concrete: { |
| auto adjVal = av.getConcreteValue(); |
| auto adjValCopy = builder.emitCopyValueOperation(loc, adjVal); |
| SmallVector<SILValue, 4> adjElts; |
| if (!adjVal->getType().getAs<TupleType>()) { |
| recordTemporary(adjValCopy); |
| adjElts.push_back(adjValCopy); |
| } else { |
| auto *dti = builder.createDestructureTuple(loc, adjValCopy); |
| for (auto adjElt : dti->getResults()) |
| recordTemporary(adjElt); |
| adjElts.append(dti->getResults().begin(), dti->getResults().end()); |
| } |
| // Accumulate adjoints for `tuple` operands, skipping the |
| // non-`Differentiable` ones. |
| unsigned adjIndex = 0; |
| for (auto i : range(ti->getNumOperands())) { |
| if (!getTangentSpace(ti->getOperand(i)->getType().getASTType())) |
| continue; |
| auto adjElt = adjElts[adjIndex++]; |
| addAdjointValue(bb, ti->getOperand(i), |
| makeConcreteAdjointValue(adjElt), loc); |
| } |
| break; |
| } |
| case AdjointValueKind::Aggregate: |
| unsigned adjIndex = 0; |
| for (auto i : range(ti->getElements().size())) { |
| if (!getTangentSpace(ti->getElement(i)->getType().getASTType())) |
| continue; |
| addAdjointValue(bb, ti->getElement(i), |
| av.getAggregateElement(adjIndex++), loc); |
| } |
| break; |
| } |
| break; |
| } |
| case SILValueCategory::Address: { |
| auto adjBuf = getAdjointBuffer(bb, ti); |
| // Accumulate adjoints for `tuple` operands, skipping the |
| // non-`Differentiable` ones. |
| unsigned adjIndex = 0; |
| for (auto i : range(ti->getNumOperands())) { |
| if (!getTangentSpace(ti->getOperand(i)->getType().getASTType())) |
| continue; |
| auto adjBufElt = |
| builder.createTupleElementAddr(loc, adjBuf, adjIndex++); |
| auto adjElt = getAdjointBuffer(bb, ti->getOperand(i)); |
| accumulateIndirect(adjElt, adjBufElt, loc); |
| } |
| break; |
| } |
| } |
| } |
| |
| /// Handle `tuple_extract` instruction. |
| /// Original: y = tuple_extract x, <n> |
| /// Adjoint: adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0) |
| /// ^~~~~~ |
| /// n'-th element, where n' is tuple tangent space |
| /// index corresponding to n |
| void visitTupleExtractInst(TupleExtractInst *tei) { |
| auto *bb = tei->getParent(); |
| auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType()); |
| auto av = getAdjointValue(bb, tei); |
| switch (av.getKind()) { |
| case AdjointValueKind::Zero: |
| addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy), |
| tei->getLoc()); |
| break; |
| case AdjointValueKind::Aggregate: |
| case AdjointValueKind::Concrete: { |
| auto tupleTy = tei->getTupleType(); |
| auto tupleTanTupleTy = tupleTanTy.getAs<TupleType>(); |
| if (!tupleTanTupleTy) { |
| addAdjointValue(bb, tei->getOperand(), av, tei->getLoc()); |
| break; |
| } |
| SmallVector<AdjointValue, 8> elements; |
| unsigned adjIdx = 0; |
| for (unsigned i : range(tupleTy->getNumElements())) { |
| if (!getTangentSpace( |
| tupleTy->getElement(i).getType()->getCanonicalType())) |
| continue; |
| if (tei->getFieldIndex() == i) |
| elements.push_back(av); |
| else |
| elements.push_back(makeZeroAdjointValue( |
| getRemappedTangentType(SILType::getPrimitiveObjectType( |
| tupleTanTupleTy->getElementType(adjIdx++) |
| ->getCanonicalType())))); |
| } |
| if (elements.size() == 1) { |
| addAdjointValue(bb, tei->getOperand(), elements.front(), tei->getLoc()); |
| break; |
| } |
| addAdjointValue(bb, tei->getOperand(), |
| makeAggregateAdjointValue(tupleTanTy, elements), |
| tei->getLoc()); |
| break; |
| } |
| } |
| } |
| |
| /// Handle `destructure_tuple` instruction. |
| /// Original: (y0, ..., yn) = destructure_tuple x |
| /// Adjoint: adj[x].0 += adj[y0] |
| /// ... |
| /// adj[x].n += adj[yn] |
| void visitDestructureTupleInst(DestructureTupleInst *dti) { |
| auto *bb = dti->getParent(); |
| auto loc = dti->getLoc(); |
| auto tupleTanTy = getRemappedTangentType(dti->getOperand()->getType()); |
| // Check the `destructure_tuple` operand's value tangent category. |
| switch (getTangentValueCategory(dti->getOperand())) { |
| case SILValueCategory::Object: { |
| SmallVector<AdjointValue, 8> adjValues; |
| for (auto origElt : dti->getResults()) { |
| // Skip non-`Differentiable` tuple elements. |
| if (!getTangentSpace(remapType(origElt->getType()).getASTType())) |
| continue; |
| adjValues.push_back(getAdjointValue(bb, origElt)); |
| } |
| // Handle tuple tangent type. |
| // Add adjoints for every tuple element that has a tangent space. |
| if (tupleTanTy.is<TupleType>()) { |
| assert(adjValues.size() > 1); |
| addAdjointValue(bb, dti->getOperand(), |
| makeAggregateAdjointValue(tupleTanTy, adjValues), loc); |
| } |
| // Handle non-tuple tangent type. |
| // Add adjoint for the single tuple element that has a tangent space. |
| else { |
| assert(adjValues.size() == 1); |
| addAdjointValue(bb, dti->getOperand(), adjValues.front(), loc); |
| } |
| break; |
| } |
| case SILValueCategory::Address: { |
| auto adjBuf = getAdjointBuffer(bb, dti->getOperand()); |
| unsigned adjIndex = 0; |
| for (auto origElt : dti->getResults()) { |
| // Skip non-`Differentiable` tuple elements. |
| if (!getTangentSpace(remapType(origElt->getType()).getASTType())) |
| continue; |
| // Handle tuple tangent type. |
| // Add adjoints for every tuple element that has a tangent space. |
| if (tupleTanTy.is<TupleType>()) { |
| auto adjEltBuf = getAdjointBuffer(bb, origElt); |
| auto adjBufElt = |
| builder.createTupleElementAddr(loc, adjBuf, adjIndex); |
| accumulateIndirect(adjBufElt, adjEltBuf, loc); |
| } |
| // Handle non-tuple tangent type. |
| // Add adjoint for the single tuple element that has a tangent space. |
| else { |
| auto adjEltBuf = getAdjointBuffer(bb, origElt); |
| addToAdjointBuffer(bb, dti->getOperand(), adjEltBuf, loc); |
| } |
| ++adjIndex; |
| } |
| break; |
| } |
| } |
| } |
| |
| /// Handle `load` or `load_borrow` instruction |
| /// Original: y = load/load_borrow x |
| /// Adjoint: adj[x] += adj[y] |
| void visitLoadOperation(SingleValueInstruction *inst) { |
| assert(isa<LoadInst>(inst) || isa<LoadBorrowInst>(inst)); |
| auto *bb = inst->getParent(); |
| auto loc = inst->getLoc(); |
| switch (getTangentValueCategory(inst)) { |
| case SILValueCategory::Object: { |
| auto adjVal = materializeAdjointDirect(getAdjointValue(bb, inst), loc); |
| // Allocate a local buffer and store the adjoint value. This buffer will |
| // be used for accumulation into the adjoint buffer. |
| auto adjBuf = builder.createAllocStack(loc, adjVal->getType(), SILDebugVariable()); |
| auto copy = builder.emitCopyValueOperation(loc, adjVal); |
| builder.emitStoreValueOperation(loc, copy, adjBuf, |
| StoreOwnershipQualifier::Init); |
| // Accumulate the adjoint value in the local buffer into the adjoint |
| // buffer. |
| addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc); |
| builder.emitDestroyAddr(loc, adjBuf); |
| builder.createDeallocStack(loc, adjBuf); |
| break; |
| } |
| case SILValueCategory::Address: { |
| auto adjBuf = getAdjointBuffer(bb, inst); |
| addToAdjointBuffer(bb, inst->getOperand(0), adjBuf, loc); |
| break; |
| } |
| } |
| } |
| void visitLoadInst(LoadInst *li) { visitLoadOperation(li); } |
| void visitLoadBorrowInst(LoadBorrowInst *lbi) { visitLoadOperation(lbi); } |
| |
| /// Handle `store` or `store_borrow` instruction. |
| /// Original: store/store_borrow x to y |
| /// Adjoint: adj[x] += load adj[y]; adj[y] = 0 |
| void visitStoreOperation(SILBasicBlock *bb, SILLocation loc, SILValue origSrc, |
| SILValue origDest) { |
| auto adjBuf = getAdjointBuffer(bb, origDest); |
| switch (getTangentValueCategory(origSrc)) { |
| case SILValueCategory::Object: { |
| auto adjVal = builder.emitLoadValueOperation( |
| loc, adjBuf, LoadOwnershipQualifier::Take); |
| recordTemporary(adjVal); |
| addAdjointValue(bb, origSrc, makeConcreteAdjointValue(adjVal), loc); |
| emitZeroIndirect(adjBuf->getType().getASTType(), adjBuf, loc); |
| break; |
| } |
| case SILValueCategory::Address: { |
| addToAdjointBuffer(bb, origSrc, adjBuf, loc); |
| builder.emitDestroyAddr(loc, adjBuf); |
| emitZeroIndirect(adjBuf->getType().getASTType(), adjBuf, loc); |
| break; |
| } |
| } |
| } |
| void visitStoreInst(StoreInst *si) { |
| visitStoreOperation(si->getParent(), si->getLoc(), si->getSrc(), |
| si->getDest()); |
| } |
| void visitStoreBorrowInst(StoreBorrowInst *sbi) { |
| visitStoreOperation(sbi->getParent(), sbi->getLoc(), sbi->getSrc(), |
| sbi->getDest()); |
| } |
| |
| /// Handle `copy_addr` instruction. |
| /// Original: copy_addr x to y |
| /// Adjoint: adj[x] += adj[y]; adj[y] = 0 |
| void visitCopyAddrInst(CopyAddrInst *cai) { |
| auto *bb = cai->getParent(); |
| auto adjDest = getAdjointBuffer(bb, cai->getDest()); |
| auto destType = remapType(adjDest->getType()); |
| addToAdjointBuffer(bb, cai->getSrc(), adjDest, cai->getLoc()); |
| builder.emitDestroyAddrAndFold(cai->getLoc(), adjDest); |
| emitZeroIndirect(destType.getASTType(), adjDest, cai->getLoc()); |
| } |
| |
| /// Handle `copy_value` instruction. |
| /// Original: y = copy_value x |
| /// Adjoint: adj[x] += adj[y] |
| void visitCopyValueInst(CopyValueInst *cvi) { |
| auto *bb = cvi->getParent(); |
| switch (getTangentValueCategory(cvi)) { |
| case SILValueCategory::Object: { |
| auto adj = getAdjointValue(bb, cvi); |
| addAdjointValue(bb, cvi->getOperand(), adj, cvi->getLoc()); |
| break; |
| } |
| case SILValueCategory::Address: { |
| auto adjDest = getAdjointBuffer(bb, cvi); |
| auto destType = remapType(adjDest->getType()); |
| addToAdjointBuffer(bb, cvi->getOperand(), adjDest, cvi->getLoc()); |
| builder.emitDestroyAddrAndFold(cvi->getLoc(), adjDest); |
| emitZeroIndirect(destType.getASTType(), adjDest, cvi->getLoc()); |
| break; |
| } |
| } |
| } |
| |
| /// Handle `begin_borrow` instruction. |
| /// Original: y = begin_borrow x |
| /// Adjoint: adj[x] += adj[y] |
| void visitBeginBorrowInst(BeginBorrowInst *bbi) { |
| auto *origBB = bbi->getParent(); |
| auto loc = bbi->getLoc(); |
| auto borrowedValue = bbi->getOperand(); |
| switch (getTangentValueCategory(bbi)) { |
| case SILValueCategory::Object: { |
| // Direct tangent: the borrow's adjoint value is added to the operand's. |
| auto borrowAdj = getAdjointValue(origBB, bbi); |
| addAdjointValue(origBB, borrowedValue, borrowAdj, loc); |
| break; |
| } |
| case SILValueCategory::Address: { |
| // Indirect tangent: accumulate the borrow's adjoint buffer into the |
| // operand's adjoint buffer, then reset the borrow's buffer to zero. |
| auto borrowAdjBuf = getAdjointBuffer(origBB, bbi); |
| auto zeroTy = remapType(borrowAdjBuf->getType()).getASTType(); |
| addToAdjointBuffer(origBB, borrowedValue, borrowAdjBuf, loc); |
| builder.emitDestroyAddrAndFold(loc, borrowAdjBuf); |
| emitZeroIndirect(zeroTy, borrowAdjBuf, loc); |
| break; |
| } |
| } |
| } |
| |
| /// Handle `begin_access` instruction. |
| /// Original: y = begin_access x |
| /// Adjoint: nothing (only diagnoses non-differentiable `modify` accesses) |
| void visitBeginAccessInst(BeginAccessInst *bai) { |
| // Only `modify` accesses are checked for non-differentiable writes. |
| if (bai->getAccessKind() != SILAccessKind::Modify) |
| return; |
| auto accessedSource = bai->getSource(); |
| if (isa<GlobalAddrInst>(accessedSource)) { |
| // The access source is a `global_addr`: diagnose a global variable write. |
| getContext().emitNondifferentiabilityError( |
| bai, getInvoker(), |
| diag::autodiff_cannot_differentiate_writes_to_global_variables); |
| errorOccurred = true; |
| } else if (isa<ProjectBoxInst>(accessedSource)) { |
| // The access source is a `project_box`: diagnose a mutable capture write. |
| getContext().emitNondifferentiabilityError( |
| bai, getInvoker(), |
| diag::autodiff_cannot_differentiate_writes_to_mutable_captures); |
| errorOccurred = true; |
| } |
| } |
| |
| /// Handle `unconditional_checked_cast_addr` instruction. |
| /// Original: y = unconditional_checked_cast_addr x |
| /// Adjoint: adj[x] += unconditional_checked_cast_addr adj[y]; adj[y] = 0 |
| void visitUnconditionalCheckedCastAddrInst( |
| UnconditionalCheckedCastAddrInst *uccai) { |
| auto *origBB = uccai->getParent(); |
| auto loc = uccai->getLoc(); |
| auto destAdjBuf = getAdjointBuffer(origBB, uccai->getDest()); |
| auto srcAdjBuf = getAdjointBuffer(origBB, uccai->getSrc()); |
| auto destAdjTy = destAdjBuf->getType(); |
| auto srcAdjTy = srcAdjBuf->getType(); |
| auto zeroTy = remapType(destAdjTy).getASTType(); |
| // Cast adj[dest] into a temporary stack buffer of adj[src]'s type. |
| auto castTmpBuf = builder.createAllocStack(loc, srcAdjTy); |
| builder.createUnconditionalCheckedCastAddr( |
| loc, destAdjBuf, destAdjTy.getASTType(), castTmpBuf, |
| srcAdjTy.getASTType()); |
| // adj[src] += cast result; afterwards the temporary is destroyed and freed. |
| addToAdjointBuffer(origBB, uccai->getSrc(), castTmpBuf, loc); |
| builder.emitDestroyAddrAndFold(loc, castTmpBuf); |
| builder.createDeallocStack(loc, castTmpBuf); |
| // adj[dest] = 0 |
| emitZeroIndirect(zeroTy, destAdjBuf, loc); |
| } |
| |
| /// Handle `unchecked_ref_cast` instruction. |
| /// Original: y = unchecked_ref_cast x |
| /// Adjoint: adj[x] += adj[y] |
| /// (assuming adj[x] and adj[y] have the same type) |
| void visitUncheckedRefCastInst(UncheckedRefCastInst *urci) { |
| auto *origBB = urci->getParent(); |
| auto loc = urci->getLoc(); |
| auto castOperand = urci->getOperand(); |
| // Sanity checks: the operand is an object and both sides of the cast share |
| // one remapped tangent type, so adjoints can be forwarded unchanged. |
| assert(castOperand->getType().isObject()); |
| assert(getRemappedTangentType(castOperand->getType()) == |
| getRemappedTangentType(urci->getType()) && |
| "Operand/result must have the same `TangentVector` type"); |
| switch (getTangentValueCategory(urci)) { |
| case SILValueCategory::Object: { |
| // Direct tangent: add the cast result's adjoint value to the operand's. |
| addAdjointValue(origBB, castOperand, getAdjointValue(origBB, urci), loc); |
| break; |
| } |
| case SILValueCategory::Address: { |
| // Indirect tangent: accumulate the result's adjoint buffer into the |
| // operand's adjoint buffer, then reset the result's buffer to zero. |
| auto resultAdjBuf = getAdjointBuffer(origBB, urci); |
| auto zeroTy = remapType(resultAdjBuf->getType()).getASTType(); |
| addToAdjointBuffer(origBB, castOperand, resultAdjBuf, loc); |
| builder.emitDestroyAddrAndFold(loc, resultAdjBuf); |
| emitZeroIndirect(zeroTy, resultAdjBuf, loc); |
| break; |
| } |
| } |
| } |
| |
| /// Handle `upcast` instruction. |
| /// Original: y = upcast x |
| /// Adjoint: adj[x] += adj[y] |
| /// (assuming adj[x] and adj[y] have the same type) |
| void visitUpcastInst(UpcastInst *ui) { |
| auto *origBB = ui->getParent(); |
| auto loc = ui->getLoc(); |
| auto upcastOperand = ui->getOperand(); |
| // Sanity checks: the operand is an object and the operand/result remapped |
| // tangent types agree, so adjoints can be forwarded unchanged. |
| assert(upcastOperand->getType().isObject()); |
| assert(getRemappedTangentType(upcastOperand->getType()) == |
| getRemappedTangentType(ui->getType()) && |
| "Operand/result must have the same `TangentVector` type"); |
| switch (getTangentValueCategory(ui)) { |
| case SILValueCategory::Object: { |
| // Direct tangent: add the upcast result's adjoint value to the operand's. |
| auto resultAdj = getAdjointValue(origBB, ui); |
| addAdjointValue(origBB, upcastOperand, resultAdj, loc); |
| break; |
| } |
| case SILValueCategory::Address: { |
| // Indirect tangent: accumulate the result's adjoint buffer into the |
| // operand's adjoint buffer, then reset the result's buffer to zero. |
| auto resultAdjBuf = getAdjointBuffer(origBB, ui); |
| auto zeroTy = remapType(resultAdjBuf->getType()).getASTType(); |
| addToAdjointBuffer(origBB, upcastOperand, resultAdjBuf, loc); |
| builder.emitDestroyAddrAndFold(loc, resultAdjBuf); |
| emitZeroIndirect(zeroTy, resultAdjBuf, loc); |
| break; |
| } |
| } |
| } |
| |
| /// Handle `unchecked_take_enum_data_addr` instruction. |
| /// Currently, only `Optional`-typed operands are supported. |
| /// Original: y = unchecked_take_enum_data_addr x : $*Enum, #Enum.Case |
| /// Adjoint: adj[x] += $Enum.TangentVector(adj[y]) |
| void |
| visitUncheckedTakeEnumDataAddrInst(UncheckedTakeEnumDataAddrInst *utedai) { |
| auto *origBB = utedai->getParent(); |
| auto payloadAdjBuf = getAdjointBuffer(origBB, utedai); |
| auto operandEnumDecl = |
| utedai->getOperand()->getType().getASTType().getEnumOrBoundGenericEnum(); |
| // `Optional` is the only enum operand type handled here: wrap the payload |
| // adjoint and accumulate it into the adjoint of the `Optional` operand. |
| if (operandEnumDecl == getASTContext().getOptionalDecl()) { |
| accumulateAdjointForOptional(origBB, utedai->getOperand(), payloadAdjBuf); |
| return; |
| } |
| // Every other enum operand type is diagnosed as non-differentiable. |
| LLVM_DEBUG(getADDebugStream() |
| << "Unhandled instruction in PullbackCloner: " << *utedai); |
| getContext().emitNondifferentiabilityError( |
| utedai, getInvoker(), |
| diag::autodiff_expression_not_differentiable_note); |
| errorOccurred = true; |
| } |
| |
| #define NOT_DIFFERENTIABLE(INST, DIAG) void visit##INST##Inst(INST##Inst *inst); |
| #undef NOT_DIFFERENTIABLE /* NOTE(review): defined and undefined back-to-back with no use in between; it expands nothing here. */ |
| |
| #define NO_ADJOINT(INST) \ |
| void visit##INST##Inst(INST##Inst *inst) {} |
| // Terminators. Each NO_ADJOINT(X) below defines an empty `visitXInst`, i.e. a visitor that emits no adjoint code. |
| NO_ADJOINT(Return) |
| NO_ADJOINT(Branch) |
| NO_ADJOINT(CondBranch) |
| |
| // Address projections. |
| NO_ADJOINT(StructElementAddr) |
| NO_ADJOINT(TupleElementAddr) |
| |
| // Array literal initialization address projections. |
| NO_ADJOINT(PointerToAddress) |
| NO_ADJOINT(IndexAddr) |
| |
| // Memory allocation/access. |
| NO_ADJOINT(AllocStack) |
| NO_ADJOINT(DeallocStack) |
| NO_ADJOINT(EndAccess) |
| |
| // Debugging/reference counting instructions. |
| NO_ADJOINT(DebugValue) |
| NO_ADJOINT(DebugValueAddr) |
| NO_ADJOINT(RetainValue) |
| NO_ADJOINT(RetainValueAddr) |
| NO_ADJOINT(ReleaseValue) |
| NO_ADJOINT(ReleaseValueAddr) |
| NO_ADJOINT(StrongRetain) |
| NO_ADJOINT(StrongRelease) |
| NO_ADJOINT(UnownedRetain) |
| NO_ADJOINT(UnownedRelease) |
| NO_ADJOINT(StrongRetainUnowned) |
| NO_ADJOINT(DestroyValue) |
| NO_ADJOINT(DestroyAddr) |
| |
| // Value ownership. |
| NO_ADJOINT(EndBorrow) |
| #undef NO_ADJOINT |
| }; |
| |
| PullbackCloner::Implementation::Implementation(VJPCloner &vjpCloner) |
| : vjpCloner(vjpCloner), builder(getPullback()), |
| localAllocBuilder(getPullback()) { |
| // Cache the dominance, post-dominance and post-order info of the original |
| // function, obtained from the pass manager's analyses. |
| auto *origFn = &vjpCloner.getOriginal(); |
| auto &pm = getContext().getPassManager(); |
| domInfo = pm.getAnalysis<DominanceAnalysis>()->get(origFn); |
| postDomInfo = pm.getAnalysis<PostDominanceAnalysis>()->get(origFn); |
| postOrderInfo = pm.getAnalysis<PostOrderAnalysis>()->get(origFn); |
| } |
| |
| PullbackCloner::PullbackCloner(VJPCloner &vjpCloner) |
| : impl(*new Implementation(vjpCloner)) {} /* Heap-allocates the implementation object; `impl` refers to it. */ |
| |
| PullbackCloner::~PullbackCloner() { delete &impl; } /* Deletes the implementation object that `impl` refers to. */ |
| |
| //--------------------------------------------------------------------------// |
| // Entry point |
| //--------------------------------------------------------------------------// |
| |
| bool PullbackCloner::run() { |
| // Delegate pullback generation to the implementation object. |
| const bool hadError = impl.run(); |
| #ifndef NDEBUG |
| // In asserts builds, verify the generated pullback unless an error occurred. |
| if (!hadError) |
| impl.getPullback().verify(); |
| #endif |
| return hadError; |
| } |
| |
| bool PullbackCloner::Implementation::run() { /* Generates the body of the pullback function. Returns true if an error occurred. */ |
| PrettyStackTraceSILFunction trace("generating pullback for", &getOriginal()); |
| auto &original = getOriginal(); |
| auto &pullback = getPullback(); |
| auto pbLoc = getPullback().getLocation(); |
| LLVM_DEBUG(getADDebugStream() << "Running PullbackCloner on\n" << original); |
| |
| auto origExitIt = original.findReturnBB(); |
| assert(origExitIt != original.end() && |
| "Functions without returns must have been diagnosed"); |
| auto *origExit = &*origExitIt; |
| |
| // Collect original formal results. |
| SmallVector<SILValue, 8> origFormalResults; |
| collectAllFormalResultsInTypeOrder(original, origFormalResults); |
| for (auto resultIndex : getConfig().resultIndices->getIndices()) { |
| auto origResult = origFormalResults[resultIndex]; |
| // If original result is non-varied, it will always have a zero derivative. |
| // Skip full pullback generation and simply emit zero derivatives for wrt |
| // parameters. |
| // |
| // NOTE(TF-876): This shortcut is currently necessary for functions |
| // returning non-varied result with >1 basic block where some basic blocks |
| // have no dominated active values; control flow differentiation does not |
| // handle this case. See TF-876 for context. |
| if (!getActivityInfo().isVaried(origResult, getConfig().parameterIndices)) { |
| emitZeroDerivativesForNonvariedResult(origResult); |
| return false; |
| } |
| } |
| |
| // Collect dominated active values in original basic blocks. |
| // Adjoint values of dominated active values are passed as pullback block |
| // arguments. |
| DominanceOrder domOrder(original.getEntryBlock(), domInfo); |
| // Keep track of visited values. |
| SmallPtrSet<SILValue, 8> visited; |
| while (auto *bb = domOrder.getNext()) { |
| auto &bbActiveValues = activeValues[bb]; |
| // If the current block has an immediate dominator, append the immediate |
| // dominator block's active values to the current block's active values. |
| if (auto *domNode = domInfo->getNode(bb)->getIDom()) { |
| auto &domBBActiveValues = activeValues[domNode->getBlock()]; |
| bbActiveValues.append(domBBActiveValues.begin(), domBBActiveValues.end()); |
| } |
| // If `v` is active and has not been visited, records it as an active value |
| // in the original basic block. |
| // For active values unsupported by differentiation, emits a diagnostic and |
| // returns true. Otherwise, returns false. |
| auto recordValueIfActive = [&](SILValue v) -> bool { |
| // If value is not active, skip. |
| if (!getActivityInfo().isActive(v, getConfig())) |
| return false; |
| // If active value has already been visited, skip. |
| if (visited.count(v)) |
| return false; |
| // Mark active value as visited. |
| visited.insert(v); |
| |
| // Diagnose unsupported active values. |
| auto type = v->getType(); |
| // Do not emit remaining activity-related diagnostics for semantic member |
| // accessors, which have special-case pullback generation. |
| if (isSemanticMemberAccessor(&original)) |
| return false; |
| // Diagnose active enum values. Differentiation of enum values requires |
| // special adjoint value handling and is not yet supported. Diagnose |
| // only the first active enum value to prevent too many diagnostics. |
| // |
| // Do not diagnose `Optional`-typed values, which will have special-case |
| // differentiation support. |
| if (auto *enumDecl = type.getEnumOrBoundGenericEnum()) { |
| if (enumDecl != getContext().getASTContext().getOptionalDecl()) { |
| getContext().emitNondifferentiabilityError( |
| v, getInvoker(), diag::autodiff_enums_unsupported); |
| errorOccurred = true; |
| return true; |
| } |
| } |
| // Diagnose unsupported stored property projections. |
| if (isa<StructExtractInst>(v) || isa<RefElementAddrInst>(v) || |
| isa<StructElementAddrInst>(v)) { |
| auto *inst = cast<SingleValueInstruction>(v); |
| assert(inst->getNumOperands() == 1); |
| auto baseType = remapType(inst->getOperand(0)->getType()).getASTType(); |
| if (!getTangentStoredProperty(getContext(), inst, baseType, |
| getInvoker())) { |
| errorOccurred = true; |
| return true; |
| } |
| } |
| // Skip address projections. |
| // Address projections do not need their own adjoint buffers; they |
| // become projections into their adjoint base buffer. |
| if (Projection::isAddressProjection(v)) |
| return false; |
| // Record active value. |
| bbActiveValues.push_back(v); |
| return false; |
| }; |
| // Record all active values in the basic block. |
| for (auto *arg : bb->getArguments()) |
| if (recordValueIfActive(arg)) |
| return true; |
| for (auto &inst : *bb) { |
| for (auto op : inst.getOperandValues()) |
| if (recordValueIfActive(op)) |
| return true; |
| for (auto result : inst.getResults()) |
| if (recordValueIfActive(result)) |
| return true; |
| } |
| domOrder.pushChildren(bb); |
| } |
| |
| // Create pullback blocks and arguments, visiting original blocks using BFS |
| // starting from the original exit block. Unvisited original basic blocks |
| // (e.g. unreachable blocks) are not relevant for pullback generation and thus |
| // ignored. |
| // The original blocks in traversal order for pullback generation. |
| SmallVector<SILBasicBlock *, 8> originalBlocks; |
| // The set of visited original blocks. |
| SmallDenseSet<SILBasicBlock *, 8> visitedBlocks; |
| |
| // Perform BFS from the original exit block. |
| { |
| std::deque<SILBasicBlock *> worklist = {}; |
| worklist.push_back(origExit); |
| visitedBlocks.insert(origExit); |
| while (!worklist.empty()) { |
| auto *BB = worklist.front(); |
| worklist.pop_front(); |
| |
| originalBlocks.push_back(BB); |
| |
| for (auto *nextBB : BB->getPredecessorBlocks()) { |
| if (!visitedBlocks.count(nextBB)) { |
| worklist.push_back(nextBB); |
| visitedBlocks.insert(nextBB); |
| } |
| } |
| } |
| } |
| |
| for (auto *origBB : originalBlocks) { |
| auto *pullbackBB = pullback.createBasicBlock(); |
| pullbackBBMap.insert({origBB, pullbackBB}); |
| auto pbStructLoweredType = |
| remapType(getPullbackInfo().getLinearMapStructLoweredType(origBB)); |
| // If the BB is the original exit, then the pullback block that we just |
| // created must be the pullback function's entry. For the pullback entry, |
| // create entry arguments and continue to the next block. |
| if (origBB == origExit) { |
| assert(pullbackBB->isEntry()); |
| createEntryArguments(&pullback); |
| builder.setInsertionPoint(pullbackBB); |
| // Obtain the context object, if any, and the top-level subcontext, i.e. |
| // the main pullback struct. |
| SILValue mainPullbackStruct; |
| if (getPullbackInfo().hasLoops()) { |
| // The last argument is the context object (`Builtin.NativeObject`). |
| contextValue = pullbackBB->getArguments().back(); |
| assert(contextValue->getType() == |
| SILType::getNativeObjectType(getASTContext())); |
| // Load the pullback struct. |
| auto subcontextAddr = emitProjectTopLevelSubcontext( |
| builder, pbLoc, contextValue, pbStructLoweredType); |
| mainPullbackStruct = builder.createLoad( |
| pbLoc, subcontextAddr, |
| pbStructLoweredType.isTrivial(getPullback()) ? |
| LoadOwnershipQualifier::Trivial : LoadOwnershipQualifier::Take); |
| } else { |
| // Without loops, the last entry argument is the main pullback struct itself. |
| mainPullbackStruct = pullbackBB->getArguments().back(); |
| assert(mainPullbackStruct->getType() == pbStructLoweredType); |
| } |
| |
| auto *dsi = builder.createDestructureStruct(pbLoc, mainPullbackStruct); |
| initializePullbackStructElements(origBB, dsi->getResults()); |
| continue; |
| } |
| // Get all active values in the original block. |
| // If the original block has no active values, continue. |
| auto &bbActiveValues = activeValues[origBB]; |
| if (bbActiveValues.empty()) |
| continue; |
| // Otherwise, if the original block has active values: |
| // - For each active buffer in the original block, allocate a new local |
| // buffer in the pullback entry. (All adjoint buffers are allocated in |
| // the pullback entry and deallocated in the pullback exit.) |
| // - For each active value in the original block, add adjoint value |
| // arguments to the pullback block. |
| for (auto activeValue : bbActiveValues) { |
| switch (getTangentValueCategory(activeValue)) { |
| case SILValueCategory::Address: { |
| // Allocate and zero initialize a new local buffer using |
| // `getAdjointBuffer`. |
| builder.setInsertionPoint(pullback.getEntryBlock()); |
| getAdjointBuffer(origBB, activeValue); |
| break; |
| } |
| case SILValueCategory::Object: { |
| // Create and register pullback block argument for the active value. |
| auto *pullbackArg = pullbackBB->createPhiArgument( |
| getRemappedTangentType(activeValue->getType()), |
| OwnershipKind::Owned); |
| activeValuePullbackBBArgumentMap[{origBB, activeValue}] = pullbackArg; |
| recordTemporary(pullbackArg); |
| break; |
| } |
| } |
| } |
| // Add a pullback struct argument. |
| auto *pbStructArg = pullbackBB->createPhiArgument(pbStructLoweredType, |
| OwnershipKind::Owned); |
| // Destructure the pullback struct to get the elements. |
| builder.setInsertionPoint(pullbackBB); |
| auto *dsi = builder.createDestructureStruct(pbLoc, pbStructArg); |
| initializePullbackStructElements(origBB, dsi->getResults()); |
| |
| // - Create pullback trampoline blocks for each successor block of the |
| // original block. Pullback trampoline blocks only have a pullback |
| // struct argument. They branch from a pullback successor block to the |
| // pullback original block, passing adjoint values of active values. |
| for (auto *succBB : origBB->getSuccessorBlocks()) { |
| // Skip generating pullback block for original unreachable blocks. |
| if (!visitedBlocks.count(succBB)) |
| continue; |
| auto *pullbackTrampolineBB = pullback.createBasicBlockBefore(pullbackBB); |
| pullbackTrampolineBBMap.insert({{origBB, succBB}, pullbackTrampolineBB}); |
| // Get the enum element type (i.e. the pullback struct type). The enum |
| // element type may be boxed if the enum is indirect. |
| auto enumLoweredTy = |
| getPullbackInfo().getBranchingTraceEnumLoweredType(succBB); |
| auto *enumEltDecl = |
| getPullbackInfo().lookUpBranchingTraceEnumElement(origBB, succBB); |
| auto enumEltType = remapType(enumLoweredTy.getEnumElementType( |
| enumEltDecl, getModule(), TypeExpansionContext::minimal())); |
| pullbackTrampolineBB->createPhiArgument(enumEltType, |
| OwnershipKind::Owned); |
| } |
| } |
| |
| auto *pullbackEntry = pullback.getEntryBlock(); |
| // The pullback function has type: |
| // `(seed0, seed1, ..., exit_pb_struct|context_obj) -> (d_arg0, ..., d_argn)`. |
| auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults(); |
| assert(getConfig().resultIndices->getNumIndices() == pbParamArgs.size() - 1 && |
| pbParamArgs.size() >= 2); |
| // Assign adjoints for original result. |
| builder.setInsertionPoint(pullbackEntry, |
| getNextFunctionLocalAllocationInsertionPoint()); |
| unsigned seedIndex = 0; |
| for (auto resultIndex : getConfig().resultIndices->getIndices()) { |
| auto origResult = origFormalResults[resultIndex]; |
| auto *seed = pbParamArgs[seedIndex]; |
| if (seed->getType().isAddress()) { |
| // If the seed argument is an `inout` parameter, assign it directly as |
| // the adjoint buffer of the original result. |
| auto seedParamInfo = |
| pullback.getLoweredFunctionType()->getParameters()[seedIndex]; |
| if (seedParamInfo.isIndirectInOut()) { |
| setAdjointBuffer(origExit, origResult, seed); |
| } |
| // Otherwise, assign a copy of the seed argument as the adjoint buffer of |
| // the original result. |
| else { |
| auto *seedBufCopy = |
| createFunctionLocalAllocation(seed->getType(), pbLoc); |
| builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake, |
| IsInitialization); |
| setAdjointBuffer(origExit, origResult, seedBufCopy); |
| LLVM_DEBUG(getADDebugStream() |
| << "Assigned seed buffer " << *seedBufCopy |
| << " as the adjoint of original indirect result " |
| << origResult); |
| } |
| } else { |
| addAdjointValue(origExit, origResult, makeConcreteAdjointValue(seed), |
| pbLoc); |
| LLVM_DEBUG(getADDebugStream() |
| << "Assigned seed " << *seed |
| << " as the adjoint of original result " << origResult); |
| } |
| ++seedIndex; |
| } |
| |
| // If the original function is an accessor with special-case pullback |
| // generation logic, do special-case generation. |
| if (isSemanticMemberAccessor(&original)) { |
| if (runForSemanticMemberAccessor()) |
| return true; |
| } |
| // Otherwise, perform standard pullback generation. |
| // Visit original blocks in the BFS traversal order computed above (starting at the original exit) |
| // and perform differentiation in corresponding pullback blocks. If errors occurred, back out. |
| else { |
| for (auto *bb : originalBlocks) { |
| visitSILBasicBlock(bb); |
| if (errorOccurred) |
| return true; |
| } |
| } |
| |
| // Prepare and emit a `return` in the pullback exit block. |
| auto *origEntry = getOriginal().getEntryBlock(); |
| auto *pbExit = getPullbackBlock(origEntry); |
| builder.setInsertionPoint(pbExit); |
| |
| // This vector will contain all the materialized return elements. |
| SmallVector<SILValue, 8> retElts; |
| // This vector will contain all indirect parameter adjoint buffers. |
| SmallVector<SILValue, 4> indParamAdjoints; |
| |
| auto conv = getOriginal().getConventions(); |
| auto origParams = getOriginal().getArgumentsWithoutIndirectResults(); |
| |
| // Materializes the return element corresponding to the parameter |
| // `parameterIndex` into `retElts` (direct tangents) or records its adjoint buffer in `indParamAdjoints` (indirect tangents). |
| auto addRetElt = [&](unsigned parameterIndex) -> void { |
| auto origParam = origParams[parameterIndex]; |
| switch (getTangentValueCategory(origParam)) { |
| case SILValueCategory::Object: { |
| auto pbVal = getAdjointValue(origEntry, origParam); |
| auto val = materializeAdjointDirect(pbVal, pbLoc); |
| auto newVal = builder.emitCopyValueOperation(pbLoc, val); |
| retElts.push_back(newVal); |
| break; |
| } |
| case SILValueCategory::Address: { |
| auto adjBuf = getAdjointBuffer(origEntry, origParam); |
| indParamAdjoints.push_back(adjBuf); |
| break; |
| } |
| } |
| }; |
| // Collect differentiation parameter adjoints. |
| for (auto i : getConfig().parameterIndices->getIndices()) { |
| // Skip `inout` parameters. |
| if (conv.getParameters()[i].isIndirectMutating()) |
| continue; |
| addRetElt(i); |
| } |
| |
| // Copy the indirect parameter adjoint buffers to the pullback's indirect results. |
| assert(indParamAdjoints.size() == getPullback().getIndirectResults().size() && |
| "Indirect parameter adjoint count mismatch"); |
| for (auto pair : zip(indParamAdjoints, getPullback().getIndirectResults())) { |
| auto source = std::get<0>(pair); |
| auto *dest = std::get<1>(pair); |
| builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization); |
| // Prevent source buffer from being deallocated, since the underlying |
| // value is moved. |
| destroyedLocalAllocations.insert(source); |
| } |
| |
| // Emit cleanups for all local values. |
| cleanUpTemporariesForBlock(pbExit, pbLoc); |
| // Deallocate local allocations. |
| for (auto alloc : functionLocalAllocations) { |
| // Assert that local allocations have at least one use. |
| // Buffers should not be allocated needlessly. |
| assert(!alloc->use_empty()); |
| if (!destroyedLocalAllocations.count(alloc)) { |
| builder.emitDestroyAddrAndFold(pbLoc, alloc); |
| destroyedLocalAllocations.insert(alloc); |
| } |
| builder.createDeallocStack(pbLoc, alloc); |
| } |
| builder.createReturn(pbLoc, joinElements(retElts, builder, pbLoc)); |
| |
| #ifndef NDEBUG |
| bool leakFound = false; |
| // Ensure all temporaries have been cleaned up: any entry still recorded in a block's temporaries set is reported as a leak. |
| for (auto &bb : pullback) { |
| for (auto temp : blockTemporaries[&bb]) { |
| if (blockTemporaries[&bb].count(temp)) { |
| leakFound = true; |
| getADDebugStream() << "Found leaked temporary:\n" << temp; |
| } |
| } |
| } |
| // Ensure all local allocations have been cleaned up. |
| for (auto localAlloc : functionLocalAllocations) { |
| if (!destroyedLocalAllocations.count(localAlloc)) { |
| leakFound = true; |
| getADDebugStream() << "Found leaked local buffer:\n" << localAlloc; |
| } |
| } |
| assert(!leakFound && "Leaks found!"); |
| #endif |
| |
| LLVM_DEBUG(getADDebugStream() |
| << "Generated pullback for " << original.getName() << ":\n" |
| << pullback); |
| return errorOccurred; |
| } |
| |
| void PullbackCloner::Implementation::emitZeroDerivativesForNonvariedResult( |
| SILValue origNonvariedResult) { |
| auto &pb = getPullback(); |
| auto loc = pb.getLocation(); |
| /* |
| // TODO(TF-788): Re-enable non-varied result warning. |
| // Emit fixit if original non-varied result has a valid source location. |
| auto startLoc = origNonvariedResult.getLoc().getStartSourceLoc(); |
| auto endLoc = origNonvariedResult.getLoc().getEndSourceLoc(); |
| if (startLoc.isValid() && endLoc.isValid()) { |
| getContext().diagnose(startLoc, diag::autodiff_nonvaried_result_fixit) |
| .fixItInsert(startLoc, "withoutDerivative(at:") |
| .fixItInsertAfter(endLoc, ")"); |
| } |
| */ |
| LLVM_DEBUG(getADDebugStream() << getOriginal().getName() |
| << " has non-varied result, returning zero" |
| " for all pullback results\n"); |
| // The whole pullback is a single block: create it with its entry arguments |
| // and emit everything there. |
| auto *entryBB = pb.createBasicBlock(); |
| createEntryArguments(&pb); |
| builder.setInsertionPoint(entryBB); |
| // Owned entry arguments are not used by this pullback; destroy them. |
| for (auto *entryArg : entryBB->getArguments()) { |
| if (entryArg->getOwnershipKind() == OwnershipKind::Owned) |
| builder.emitDestroyOperation(loc, entryArg); |
| } |
| // Emit a zero for every pullback result: direct results are collected for |
| // the `return`, indirect results are written into their result buffers in |
| // order. |
| SmallVector<SILValue, 4> zeroDirectResults; |
| auto nextIndirectResult = pb.getIndirectResults().begin(); |
| for (auto resultInfo : pb.getLoweredFunctionType()->getResults()) { |
| auto zeroTy = pb.mapTypeIntoContext(resultInfo.getInterfaceType()) |
| ->getCanonicalType(); |
| if (!resultInfo.isFormalDirect()) { |
| emitZeroIndirect(zeroTy, *nextIndirectResult, loc); |
| ++nextIndirectResult; |
| continue; |
| } |
| zeroDirectResults.push_back(emitZeroDirect(zeroTy, loc)); |
| } |
| builder.createReturn(loc, joinElements(zeroDirectResults, builder, loc)); |
| LLVM_DEBUG(getADDebugStream() |
| << "Generated pullback for " << getOriginal().getName() << ":\n" |
| << pb); |
| } |
| |
| void PullbackCloner::Implementation::accumulateAdjointForOptional( |
| SILBasicBlock *bb, SILValue optionalValue, SILValue wrappedAdjoint) { |
| auto pbLoc = getPullback().getLocation(); |
| // Builds `Optional<T>.TangentVector(.some(wrappedAdjoint))` in a temporary buffer and accumulates it into the adjoint buffer of `optionalValue` in `bb`. |
| // `Optional<T>` |
| auto optionalTy = remapType(optionalValue->getType()); |
| assert(optionalTy.getASTType().getEnumOrBoundGenericEnum() == |
| getASTContext().getOptionalDecl()); |
| // `T` |
| auto wrappedType = optionalTy.getOptionalObjectType(); |
| // `T.TangentVector` |
| auto wrappedTanType = remapType(wrappedAdjoint->getType()); |
| // `Optional<T.TangentVector>` |
| auto optionalOfWrappedTanType = SILType::getOptionalType(wrappedTanType); |
| // `Optional<T>.TangentVector` |
| auto optionalTanTy = getRemappedTangentType(optionalTy); |
| auto *optionalTanDecl = optionalTanTy.getNominalOrBoundGenericNominal(); |
| // Look up the `Optional<T>.TangentVector.init` declaration: the constructor declared in the `Differentiation` module or the stdlib module. |
| auto initLookup = |
| optionalTanDecl->lookupDirect(DeclBaseName::createConstructor()); |
| ConstructorDecl *constructorDecl = nullptr; |
| for (auto *candidate : initLookup) { |
| auto candidateModule = candidate->getModuleContext(); |
| if (candidateModule->getName() == |
| builder.getASTContext().Id_Differentiation || |
| candidateModule->isStdlibModule()) { |
| assert(!constructorDecl && "Multiple `Optional.TangentVector.init`s"); |
| constructorDecl = cast<ConstructorDecl>(candidate); |
| #ifdef NDEBUG |
| break; /* Only without asserts: asserts builds keep scanning so the duplicate-initializer assert above can fire. */ |
| #endif |
| } |
| } |
| assert(constructorDecl && "No `Optional.TangentVector.init`"); |
| |
| // Allocate a local buffer for the `Optional` adjoint value. |
| auto *optTanAdjBuf = builder.createAllocStack(pbLoc, optionalTanTy); |
| // Find `Optional<T.TangentVector>.some` EnumElementDecl. |
| auto someEltDecl = builder.getASTContext().getOptionalSomeDecl(); |
| |
| // Initialize an `Optional<T.TangentVector>` buffer from `wrappedAdjoint` as |
| // the input for `Optional<T>.TangentVector.init`. |
| auto *optArgBuf = builder.createAllocStack(pbLoc, optionalOfWrappedTanType); |
| if (optionalOfWrappedTanType.isLoadableOrOpaque(builder.getFunction())) { |
| // %enum = enum $Optional<T.TangentVector>, #Optional.some!enumelt, |
| // %wrappedAdjoint : $T.TangentVector |
| auto *enumInst = builder.createEnum(pbLoc, wrappedAdjoint, someEltDecl, |
| optionalOfWrappedTanType); |
| // store %enum to %optArgBuf |
| builder.emitStoreValueOperation(pbLoc, enumInst, optArgBuf, |
| StoreOwnershipQualifier::Init); |
| } else { |
| // %enumAddr = init_enum_data_addr %optArgBuf $Optional<T.TangentVector>, |
| // #Optional.some!enumelt |
| auto *enumAddr = builder.createInitEnumDataAddr( |
| pbLoc, optArgBuf, someEltDecl, wrappedTanType.getAddressType()); |
| // copy_addr %wrappedAdjoint to [initialization] %enumAddr |
| builder.createCopyAddr(pbLoc, wrappedAdjoint, enumAddr, IsNotTake, |
| IsInitialization); |
| // inject_enum_addr %optArgBuf : $*Optional<T.TangentVector>, |
| // #Optional.some!enumelt |
| builder.createInjectEnumAddr(pbLoc, optArgBuf, someEltDecl); |
| } |
| |
| // Apply `Optional<T>.TangentVector.init`. |
| SILOptFunctionBuilder fb(getContext().getTransform()); |
| // %init_fn = function_ref @Optional<T>.TangentVector.init |
| auto *initFn = fb.getOrCreateFunction(pbLoc, SILDeclRef(constructorDecl), |
| NotForDefinition); |
| auto *initFnRef = builder.createFunctionRef(pbLoc, initFn); |
| auto *diffProto = |
| builder.getASTContext().getProtocol(KnownProtocolKind::Differentiable); |
| auto *swiftModule = getModule().getSwiftModule(); |
| auto diffConf = |
| swiftModule->lookupConformance(wrappedType.getASTType(), diffProto); |
| assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`"); |
| auto subMap = SubstitutionMap::get( |
| initFn->getLoweredFunctionType()->getSubstGenericSignature(), |
| ArrayRef<Type>(wrappedType.getASTType()), {diffConf}); |
| // %metatype = metatype $Optional<T>.TangentVector.Type |
| auto metatypeType = CanMetatypeType::get(optionalTanTy.getASTType(), |
| MetatypeRepresentation::Thin); |
| auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType); |
| auto metatype = builder.createMetatype(pbLoc, metatypeSILType); |
| // apply %init_fn(%optTanAdjBuf, %optArgBuf, %metatype) |
| builder.createApply(pbLoc, initFnRef, subMap, |
| {optTanAdjBuf, optArgBuf, metatype}); |
| builder.createDeallocStack(pbLoc, optArgBuf); |
| |
| // Accumulate adjoint for the incoming `Optional` value, then destroy and deallocate the temporary `Optional<T>.TangentVector` buffer. |
| addToAdjointBuffer(bb, optionalValue, optTanAdjBuf, pbLoc); |
| builder.emitDestroyAddr(pbLoc, optTanAdjBuf); |
| builder.createDeallocStack(pbLoc, optTanAdjBuf); |
| } |
| |
| /// Returns the pullback successor block for the original CFG edge |
| /// `origPredBB -> origBB`. |
| /// |
| /// When the edge has no pullback trampoline block, the pullback block of |
| /// `origPredBB` is returned as-is. Otherwise, the trampoline block is |
| /// terminated with a branch to that pullback block, forwarding the adjoints of |
| /// the predecessor's active object values followed by the pullback struct, and |
| /// the trampoline block is returned. |
| /// |
| /// \param pullbackTrampolineBlockMap maps each forwarded adjoint value to the |
| ///        set of trampoline blocks that use it; updated by this function. |
| SILBasicBlock *PullbackCloner::Implementation::buildPullbackSuccessor( |
|     SILBasicBlock *origBB, SILBasicBlock *origPredBB, |
|     SmallDenseMap<SILValue, TrampolineBlockSet> &pullbackTrampolineBlockMap) { |
|   auto *predPullbackBB = getPullbackBlock(origPredBB); |
|   auto *trampolineBB = getPullbackTrampolineBlock(origPredBB, origBB); |
|   // No trampoline for this edge: the successor is simply the predecessor's |
|   // pullback block. |
|   if (trampolineBB == nullptr) |
|     return predPullbackBB; |
| |
|   // The trampoline block carries a single argument, which is forwarded (after |
|   // a load, for predecessors inside a loop) as the last branch argument. |
|   assert(trampolineBB->getNumArguments() == 1); |
|   auto loc = origBB->getParent()->getLocation(); |
|   SmallVector<SILValue, 8> branchArgs; |
| |
|   // Hand the adjoint of every value active in the predecessor over to the |
|   // predecessor's pullback block. |
|   for (auto activeValue : activeValues[origPredBB]) { |
|     LLVM_DEBUG(getADDebugStream() |
|                << "Propagating adjoint of active value " << activeValue |
|                << " to predecessors' pullback blocks\n"); |
|     if (getTangentValueCategory(activeValue) == SILValueCategory::Address) { |
|       // Adjoint buffers travel through memory: `copy_addr` from the buffer |
|       // of `origBB` into the buffer of the predecessor. |
|       auto srcAdjBuf = getAdjointBuffer(origBB, activeValue); |
|       auto dstAdjBuf = getAdjointBuffer(origPredBB, activeValue); |
|       builder.createCopyAddr(loc, srcAdjBuf, dstAdjBuf, IsNotTake, |
|                              IsNotInitialization); |
|       continue; |
|     } |
| |
|     // Object adjoints travel as trampoline branch arguments. |
|     auto forwardedAdj = |
|         materializeAdjointDirect(getAdjointValue(origBB, activeValue), loc); |
|     // If the materialized value is not yet tracked in the map, copy it and |
|     // make the copy the adjoint of the active value in `origBB`. |
|     if (pullbackTrampolineBlockMap.count(forwardedAdj) == 0) { |
|       forwardedAdj = builder.emitCopyValueOperation(loc, forwardedAdj); |
|       setAdjointValue(origBB, activeValue, |
|                       makeConcreteAdjointValue(forwardedAdj)); |
|     } |
|     // Record this trampoline block as a user of the forwarded value. |
|     pullbackTrampolineBlockMap |
|         .try_emplace(forwardedAdj, TrampolineBlockSet()) |
|         .first->getSecond() |
|         .insert(trampolineBB); |
|     branchArgs.push_back(forwardedAdj); |
| |
|     // If the predecessor has no adjoint registered for the active value yet, |
|     // use the pullback block argument that receives the forwarded adjoint. |
|     // TODO: Hoist this logic out of loop over predecessor blocks to |
|     // remove the `hasAdjointValue` check. |
|     if (hasAdjointValue(origPredBB, activeValue)) |
|       continue; |
|     auto *predPullbackBBArg = |
|         getActiveValuePullbackBlockArgument(origPredBB, activeValue); |
|     setAdjointValue(origPredBB, activeValue, |
|                     makeConcreteAdjointValue(predPullbackBBArg)); |
|   } |
| |
|   // Forward the pullback struct as the final branch argument. |
|   SILBuilder trampolineBuilder(trampolineBB); |
|   auto *trampolineArg = trampolineBB->getArguments().front(); |
|   if (!vjpCloner.getLoopInfo()->getLoopFor(origPredBB)) { |
|     // Predecessor outside any loop: the argument is forwarded directly. |
|     branchArgs.push_back(trampolineArg); |
|   } else { |
|     // Predecessor inside a loop: the argument is a raw pointer to the |
|     // pullback struct, which is loaded before being forwarded. |
|     assert(trampolineArg->getType() == |
|            SILType::getRawPointerType(getASTContext())); |
|     auto pbStructTy = |
|         remapType(getPullbackInfo().getLinearMapStructLoweredType(origPredBB)); |
|     auto pbStructAddr = trampolineBuilder.createPointerToAddress( |
|         loc, trampolineArg, pbStructTy.getAddressType(), |
|         /*isStrict*/ true); |
|     auto loadQualifier = pbStructTy.isTrivial(getPullback()) |
|                              ? LoadOwnershipQualifier::Trivial |
|                              : LoadOwnershipQualifier::Take; |
|     auto pbStructVal = |
|         trampolineBuilder.createLoad(loc, pbStructAddr, loadQualifier); |
|     branchArgs.push_back(pbStructVal); |
|   } |
|   // Terminate the trampoline with a branch to the predecessor's pullback |
|   // block. |
|   trampolineBuilder.createBranch(loc, predPullbackBB, branchArgs); |
|   return trampolineBB; |
| } |
| |
| /// Emits the pullback block for the original basic block `bb`. |
| /// |
| /// Steps: |
| /// 1. Visit, in reverse order, every instruction of `bb` that should be |
| ///    differentiated; stop early if `errorOccurred` is set. |
| /// 2. If `bb` is the original entry, leave the pullback block unterminated |
| ///    (it is finished in `PullbackCloner::Implementation::run()`). |
| /// 3. Otherwise, propagate adjoints of active block arguments to the incoming |
| ///    values, build one pullback successor per original predecessor, destroy |
| ///    forwarded values in successors that do not trampoline them, clean up |
| ///    block temporaries, and terminate with a `switch_enum` on the |
| ///    predecessor enum value. |
| void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) { |
|   auto pbLoc = getPullback().getLocation(); |
|   // Get the corresponding pullback basic block. |
|   auto *pbBB = getPullbackBlock(bb); |
|   builder.setInsertionPoint(pbBB); |
| |
|   LLVM_DEBUG({ |
|     auto &s = getADDebugStream() |
|               << "Original bb" + std::to_string(bb->getDebugID()) |
|               << ": To differentiate or not to differentiate?\n"; |
|     for (auto &inst : llvm::reverse(*bb)) { |
|       s << (getPullbackInfo().shouldDifferentiateInstruction(&inst) ? "[x] " |
|                                                                      : "[ ] ") |
|         << inst; |
|     } |
|   }); |
| |
|   // Visit each instruction in reverse order. |
|   for (auto &inst : llvm::reverse(*bb)) { |
|     if (!getPullbackInfo().shouldDifferentiateInstruction(&inst)) |
|       continue; |
|     // Differentiate instruction. |
|     visit(&inst); |
|     if (errorOccurred) |
|       return; |
|   } |
| |
|   // Emit a branching terminator for the block. |
|   // If the original block is the original entry, then the pullback block is |
|   // the pullback exit. This is handled specially in |
|   // `PullbackCloner::Implementation::run()`, so we leave the block |
|   // non-terminated. |
|   if (bb->isEntry()) |
|     return; |
| |
|   // Otherwise, add a `switch_enum` terminator for non-exit |
|   // pullback blocks. |
|   // 1. Get the pullback struct pullback block argument. |
|   // 2. Extract the predecessor enum value from the pullback struct value. |
|   auto *predEnum = getPullbackInfo().getBranchingTraceDecl(bb); |
|   // `predEnum` is only read by the assertion before the `switch_enum`. |
|   (void)predEnum; |
|   auto *predEnumField = getPullbackInfo().lookUpLinearMapStructEnumField(bb); |
|   auto predEnumVal = getPullbackStructElement(bb, predEnumField); |
| |
|   // Returns true if the given terminator instruction is a `switch_enum` on |
|   // an `Optional`-typed value. `switch_enum` instructions require |
|   // special-case adjoint value propagation for the operand. |
|   // Defined once here rather than once per block argument. |
|   auto isSwitchEnumInstOnOptional = |
|       [&ctx = getASTContext()](TermInst *termInst) { |
|         auto *sei = dyn_cast_or_null<SwitchEnumInst>(termInst); |
|         if (!sei) |
|           return false; |
|         auto operandTy = sei->getOperand()->getType(); |
|         return operandTy.getASTType().getEnumOrBoundGenericEnum() == |
|                ctx.getOptionalDecl(); |
|       }; |
| |
|   // Propagate adjoint values from active basic block arguments to |
|   // incoming values (predecessor terminator operands). |
|   for (auto *bbArg : bb->getArguments()) { |
|     if (!getActivityInfo().isActive(bbArg, getConfig())) |
|       continue; |
|     // Get predecessor terminator operands. |
|     SmallVector<std::pair<SILBasicBlock *, SILValue>, 4> incomingValues; |
|     bbArg->getSingleTerminatorOperands(incomingValues); |
| |
|     // The argument's single terminator does not depend on the incoming value, |
|     // so classify it once instead of once per incoming value. |
|     const bool isOptionalSwitchEnumArg = |
|         isSwitchEnumInstOnOptional(bbArg->getSingleTerminator()); |
| |
|     // Check the tangent value category of the active basic block argument. |
|     switch (getTangentValueCategory(bbArg)) { |
|     // If argument has a loadable tangent value category: materialize adjoint |
|     // value of the argument, create a copy, and set the copy as the adjoint |
|     // value of incoming values. |
|     case SILValueCategory::Object: { |
|       auto bbArgAdj = getAdjointValue(bb, bbArg); |
|       auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc); |
|       auto concreteBBArgAdjCopy = |
|           builder.emitCopyValueOperation(pbLoc, concreteBBArgAdj); |
|       for (auto &incoming : incomingValues) { |
|         auto *predBB = incoming.first; |
|         auto incomingValue = incoming.second; |
|         // Handle `switch_enum` on `Optional`. |
|         if (isOptionalSwitchEnumArg) { |
|           accumulateAdjointForOptional(bb, incomingValue, concreteBBArgAdjCopy); |
|         } else { |
|           blockTemporaries[getPullbackBlock(predBB)].insert( |
|               concreteBBArgAdjCopy); |
|           setAdjointValue(predBB, incomingValue, |
|                           makeConcreteAdjointValue(concreteBBArgAdjCopy)); |
|         } |
|       } |
|       break; |
|     } |
|     // If argument has an address tangent value category: accumulate the |
|     // argument's adjoint buffer into the adjoint of each incoming value. |
|     case SILValueCategory::Address: { |
|       auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg); |
|       for (auto &incoming : incomingValues) { |
|         auto *predBB = incoming.first; |
|         auto incomingValue = incoming.second; |
|         // Handle `switch_enum` on `Optional`. |
|         if (isOptionalSwitchEnumArg) |
|           accumulateAdjointForOptional(bb, incomingValue, bbArgAdjBuf); |
|         else |
|           addToAdjointBuffer(predBB, incomingValue, bbArgAdjBuf, pbLoc); |
|       } |
|       break; |
|     } |
|     } |
|   } |
| |
|   // 3. Build the pullback successor cases for the `switch_enum` |
|   //    instruction. The pullback successors correspond to the predecessors |
|   //    of the current block. |
|   SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> |
|       pullbackSuccessorCases; |
|   // A map from active values' adjoint values to the trampoline blocks that |
|   // are using them. |
|   SmallDenseMap<SILValue, TrampolineBlockSet> pullbackTrampolineBlockMap; |
|   SmallVector<SILBasicBlock *, 8> pullbackSuccBBs; |
|   for (auto *predBB : bb->getPredecessorBlocks()) { |
|     auto *pullbackSuccBB = |
|         buildPullbackSuccessor(bb, predBB, pullbackTrampolineBlockMap); |
|     pullbackSuccBBs.push_back(pullbackSuccBB); |
|     auto *enumEltDecl = |
|         getPullbackInfo().lookUpBranchingTraceEnumElement(predBB, bb); |
|     pullbackSuccessorCases.push_back({enumEltDecl, pullbackSuccBB}); |
|   } |
|   // Values are trampolined by only a subset of pullback successor blocks. |
|   // Other successors blocks should destroy the value to balance the reference |
|   // count. |
|   // Iterate by reference: each entry owns a block set that should not be |
|   // copied just to be queried. |
|   for (auto &entry : pullbackTrampolineBlockMap) { |
|     auto value = entry.getFirst(); |
|     // The set of trampoline BBs that are users of `value`. |
|     auto &userTrampolineBBSet = entry.getSecond(); |
|     // For each pullback successor block that does not trampoline the value, |
|     // release the value. |
|     for (auto *pullbackSuccBB : pullbackSuccBBs) { |
|       if (userTrampolineBBSet.count(pullbackSuccBB)) |
|         continue; |
|       // Use a distinctly named builder positioned at the start of the |
|       // successor block; it must not shadow the member `builder`. |
|       SILBuilder destroyBuilder(pullbackSuccBB->begin()); |
|       destroyBuilder.emitDestroyValueOperation(pbLoc, value); |
|     } |
|   } |
|   // Emit cleanups for all block-local temporaries. |
|   cleanUpTemporariesForBlock(pbBB, pbLoc); |
|   // Branch to pullback successor blocks. |
|   assert(pullbackSuccessorCases.size() == predEnum->getNumElements()); |
|   builder.createSwitchEnum(pbLoc, predEnumVal, /*DefaultBB*/ nullptr, |
|                            pullbackSuccessorCases); |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Member accessor pullback generation |
| //--------------------------------------------------------------------------// |
| |
| /// Dispatches pullback generation for a semantic member accessor based on the |
| /// accessor kind of the original function's declaration. |
| /// |
| /// \returns the result of the getter- or setter-specific routine. Any other |
| /// accessor kind is treated as unreachable. |
| bool PullbackCloner::Implementation::runForSemanticMemberAccessor() { |
|   auto *origDecl = getOriginal().getDeclContext()->getAsDecl(); |
|   const auto accessorKind = cast<AccessorDecl>(origDecl)->getAccessorKind(); |
|   if (accessorKind == AccessorKind::Get) |
|     return runForSemanticMemberGetter(); |
|   if (accessorKind == AccessorKind::Set) |
|     return runForSemanticMemberSetter(); |
|   // TODO(SR-12640): Support `modify` accessors. |
|   llvm_unreachable("Unsupported accessor kind; inconsistent with " |
|                    "`isSemanticMemberAccessor`?"); |
| } |
| |
| /// Emits the pullback body for a semantic member getter into the pullback's |
| /// entry block. |
| /// |
| /// The adjoint of `self` is built as a tangent struct in which the stored |
| /// property matching the accessed field receives the adjoint of the getter |
| /// result and every other stored property is zero. |
| /// |
| /// \returns true (and sets `errorOccurred`) when no tangent stored property |
| /// can be found for the accessed field; false otherwise. |
| bool PullbackCloner::Implementation::runForSemanticMemberGetter() { |
|   auto &original = getOriginal(); |
|   auto &pullback = getPullback(); |
|   auto pbLoc = pullback.getLocation(); |
| |
|   auto *getterDecl = |
|       cast<AccessorDecl>(original.getDeclContext()->getAsDecl()); |
|   assert(getterDecl->getAccessorKind() == AccessorKind::Get); |
| |
|   auto *origEntry = original.getEntryBlock(); |
|   builder.setInsertionPoint(pullback.getEntryBlock()); |
| |
|   // Shapes expected here: |
|   //   Getter type:   $(Self) -> Result |
|   //   Pullback type: $(Result', PB_Struct|Context) -> Self' |
|   assert(original.getLoweredFunctionType()->getNumParameters() == 1); |
|   assert(pullback.getLoweredFunctionType()->getNumParameters() == 2); |
|   assert(pullback.getLoweredFunctionType()->getNumResults() == 1); |
|   SILValue origSelf = original.getArgumentsWithoutIndirectResults().front(); |
| |
|   // The getter's single semantic result is the value being differentiated. |
|   SmallVector<SILValue, 8> formalResults; |
|   collectAllFormalResultsInTypeOrder(original, formalResults); |
|   assert(getConfig().resultIndices->getNumIndices() == 1 && |
|          "Getter should have one semantic result"); |
|   auto origResult = formalResults[*getConfig().resultIndices->begin()]; |
| |
|   // The pullback's first result is the tangent of `self`. |
|   auto pbResultInfo = pullback.getConventions().getResults().front(); |
|   auto tangentVectorSILTy = pbResultInfo.getSILStorageType( |
|       getModule(), pullback.getLoweredFunctionType(), |
|       TypeExpansionContext::minimal()); |
|   auto tangentVectorTy = tangentVectorSILTy.getASTType(); |
|   auto *tangentVectorDecl = tangentVectorTy->getStructOrBoundGenericStruct(); |
| |
|   // Find the tangent stored property that corresponds to the accessed field. |
|   auto *origField = cast<VarDecl>(getterDecl->getStorage()); |
|   auto selfASTTy = remapType(origSelf->getType()).getASTType(); |
|   auto *tanField = getTangentStoredProperty(getContext(), origField, selfASTTy, |
|                                             pbLoc, getInvoker()); |
|   if (tanField == nullptr) { |
|     errorOccurred = true; |
|     return true; |
|   } |
| |
|   // Dispatch on the value category of `self`'s tangent. |
|   // TODO(TF-1255): Simplify using unified adjoint value data structure. |
|   switch (getTangentValueCategory(origSelf)) { |
|   case SILValueCategory::Object: { |
|     auto resultAdj = getAdjointValue(origEntry, origResult); |
|     switch (resultAdj.getKind()) { |
|     case AdjointValueKind::Zero: |
|       // A zero result adjoint contributes a zero adjoint for `self`. |
|       addAdjointValue(origEntry, origSelf, |
|                       makeZeroAdjointValue(tangentVectorSILTy), pbLoc); |
|       break; |
|     case AdjointValueKind::Concrete: |
|     case AdjointValueKind::Aggregate: { |
|       // Assemble an aggregate: the result adjoint in the tangent field's |
|       // slot, zeros everywhere else. |
|       SmallVector<AdjointValue, 8> fieldAdjs; |
|       for (auto *field : tangentVectorDecl->getStoredProperties()) { |
|         if (field == tanField) { |
|           fieldAdjs.push_back(resultAdj); |
|           continue; |
|         } |
|         auto memberSubs = tangentVectorTy->getMemberSubstitutionMap( |
|             field->getModuleContext(), field); |
|         auto substFieldTy = field->getType().subst(memberSubs); |
|         auto loweredFieldTy = getTypeLowering(substFieldTy).getLoweredType(); |
|         assert(loweredFieldTy.isObject()); |
|         fieldAdjs.push_back(makeZeroAdjointValue(loweredFieldTy)); |
|       } |
|       addAdjointValue(origEntry, origSelf, |
|                       makeAggregateAdjointValue(tangentVectorSILTy, fieldAdjs), |
|                       pbLoc); |
|     } |
|     } |
|     break; |
|   } |
|   case SILValueCategory::Address: { |
|     assert(pullback.getIndirectResults().size() == 1); |
|     auto pbIndirectResult = pullback.getIndirectResults().front(); |
|     // Allocate a local buffer for `self`'s adjoint and fill it field by field. |
|     auto *selfAdjBuf = createFunctionLocalAllocation( |
|         pbIndirectResult->getType().getObjectType(), pbLoc); |
|     setAdjointBuffer(origEntry, origSelf, selfAdjBuf); |
|     for (auto *field : tangentVectorDecl->getStoredProperties()) { |
|       auto *fieldAddr = |
|           builder.createStructElementAddr(pbLoc, selfAdjBuf, field); |
|       if (field != tanField) { |
|         // Fields other than the tangent field are zero-initialized. |
|         auto fieldASTTy = |
|             pullback.mapTypeIntoContext(field->getInterfaceType()) |
|                 ->getCanonicalType(); |
|         emitZeroIndirect(fieldASTTy, fieldAddr, pbLoc); |
|         continue; |
|       } |
|       // Initialize the tangent field from the result's adjoint, according to |
|       // the value category of the result's tangent. |
|       // TODO(TF-1255): Simplify using unified adjoint value data structure. |
|       switch (getTangentValueCategory(origResult)) { |
|       case SILValueCategory::Object: { |
|         auto resultAdjValue = materializeAdjointDirect( |
|             getAdjointValue(origEntry, origResult), pbLoc); |
|         auto resultAdjCopy = |
|             builder.emitCopyValueOperation(pbLoc, resultAdjValue); |
|         builder.emitStoreValueOperation(pbLoc, resultAdjCopy, fieldAddr, |
|                                         StoreOwnershipQualifier::Init); |
|         break; |
|       } |
|       case SILValueCategory::Address: { |
|         auto resultAdjBuf = getAdjointBuffer(origEntry, origResult); |
|         // The buffer's contents are moved out (`IsTake`), so record the |
|         // buffer in `destroyedLocalAllocations`. |
|         builder.createCopyAddr(pbLoc, resultAdjBuf, fieldAddr, IsTake, |
|                                IsInitialization); |
|         destroyedLocalAllocations.insert(resultAdjBuf); |
|         break; |
|       } |
|       } |
|     } |
|     break; |
|   } |
|   } |
|   return false; |
| } |
| |
| /// Emits the pullback body for a semantic member setter into the pullback's |
| /// entry block. |
| /// |
| /// The adjoint of the new value is taken out of the tangent stored property of |
| /// `self`'s adjoint buffer, and that property is then reset to zero. |
| /// |
| /// \returns true (and sets `errorOccurred`) when no tangent stored property |
| /// can be found for the accessed field; false otherwise. |
| bool PullbackCloner::Implementation::runForSemanticMemberSetter() { |
|   auto &original = getOriginal(); |
|   auto &pullback = getPullback(); |
|   auto pbLoc = getPullback().getLocation(); |
| |
|   auto *accessor = cast<AccessorDecl>(original.getDeclContext()->getAsDecl()); |
|   assert(accessor->getAccessorKind() == AccessorKind::Set); |
| |
|   auto *origEntry = original.getEntryBlock(); |
|   auto *pbEntry = pullback.getEntryBlock(); |
|   builder.setInsertionPoint(pbEntry); |
| |
|   // Get setter argument values. |
|   // Setter type: $(Argument, inout Self) -> () |
|   // Pullback type (wrt self): $(inout Self', PB_Struct) -> () |
|   // Pullback type (wrt both): $(inout Self', PB_Struct) -> Argument' |
|   assert(original.getLoweredFunctionType()->getNumParameters() == 2); |
|   assert(pullback.getLoweredFunctionType()->getNumParameters() == 2); |
|   assert(pullback.getLoweredFunctionType()->getNumResults() == 0 || |
|          pullback.getLoweredFunctionType()->getNumResults() == 1); |
| |
|   // The new value is the first argument; `self` is the second. |
|   SILValue origArg = original.getArgumentsWithoutIndirectResults()[0]; |
|   SILValue origSelf = original.getArgumentsWithoutIndirectResults()[1]; |
| |
|   // Look up the corresponding field in the tangent space. |
|   auto *origField = cast<VarDecl>(accessor->getStorage()); |
|   auto baseType = remapType(origSelf->getType()).getASTType(); |
|   auto *tanField = getTangentStoredProperty(getContext(), origField, baseType, |
|                                             pbLoc, getInvoker()); |
|   if (!tanField) { |
|     errorOccurred = true; |
|     return true; |
|   } |
| |
|   auto adjSelf = getAdjointBuffer(origEntry, origSelf); |
|   auto *adjSelfElt = builder.createStructElementAddr(pbLoc, adjSelf, tanField); |
|   // Switch based on the value category of the argument's *tangent*, not of |
|   // the argument itself: the branches below store the adjoint either as a |
|   // value or as a buffer, and that choice follows the tangent's category. |
|   // This mirrors `runForSemanticMemberGetter`, which switches on |
|   // `getTangentValueCategory(origResult)`. |
|   // TODO(TF-1255): Simplify using unified adjoint value data structure. |
|   switch (getTangentValueCategory(origArg)) { |
|   case SILValueCategory::Object: { |
|     // Move the field's adjoint out of `self`'s adjoint buffer and register |
|     // it as the argument's adjoint value; it is cleaned up with the entry |
|     // block's temporaries. |
|     auto adjArg = builder.emitLoadValueOperation(pbLoc, adjSelfElt, |
|                                                  LoadOwnershipQualifier::Take); |
|     setAdjointValue(origEntry, origArg, makeConcreteAdjointValue(adjArg)); |
|     blockTemporaries[pbEntry].insert(adjArg); |
|     break; |
|   } |
|   case SILValueCategory::Address: { |
|     // Accumulate the field's adjoint into the argument's adjoint buffer, |
|     // then destroy the field so it can be re-initialized below. |
|     addToAdjointBuffer(origEntry, origArg, adjSelfElt, pbLoc); |
|     builder.emitDestroyOperation(pbLoc, adjSelfElt); |
|     break; |
|   } |
|   } |
|   // In both cases the field no longer holds a value; re-initialize it to |
|   // zero. |
|   emitZeroIndirect(adjSelfElt->getType().getASTType(), adjSelfElt, pbLoc); |
| |
|   return false; |
| } |
| |
| //--------------------------------------------------------------------------// |
| // Adjoint buffer mapping |
| //--------------------------------------------------------------------------// |
| |
| /// Returns the adjoint projection for `originalProjection` in `origBB`, or a |
| /// null `SILValue` when `originalProjection` is none of the handled forms. |
| /// |
| /// Handled forms: `struct_element_addr`, `tuple_element_addr`, |
| /// `ref_element_addr`, `begin_access`, and element addresses of an |
| /// `array.uninitialized_intrinsic` application. |
| SILValue PullbackCloner::Implementation::getAdjointProjection( |
|     SILBasicBlock *origBB, SILValue originalProjection) { |
|   // `struct_element_addr`: project the tangent stored property out of the |
|   // base adjoint buffer. |
|   if (auto *seai = dyn_cast<StructElementAddrInst>(originalProjection)) { |
|     assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() && |
|            "`@noDerivative` struct projections should never be active"); |
|     auto baseAdjBuf = getAdjointBuffer(origBB, seai->getOperand()); |
|     auto baseASTTy = remapType(seai->getOperand()->getType()).getASTType(); |
|     auto *tanField = |
|         getTangentStoredProperty(getContext(), seai, baseASTTy, getInvoker()); |
|     assert(tanField && "Invalid projections should have been diagnosed"); |
|     return builder.createStructElementAddr(seai->getLoc(), baseAdjBuf, |
|                                            tanField); |
|   } |
| |
|   // `tuple_element_addr`: project the matching element out of the base |
|   // adjoint buffer. |
|   if (auto *teai = dyn_cast<TupleElementAddrInst>(originalProjection)) { |
|     auto tupleAddr = teai->getOperand(); |
|     auto baseAdjBuf = getAdjointBuffer(origBB, tupleAddr); |
|     // A base adjoint that is not tuple-typed is returned unprojected. |
|     if (!baseAdjBuf->getType().is<TupleType>()) |
|       return baseAdjBuf; |
|     auto origTupleTy = tupleAddr->getType().castTo<TupleType>(); |
|     // The adjoint index is the number of preceding original elements that |
|     // have a tangent space. |
|     unsigned adjIndex = 0; |
|     for (unsigned i = 0, fieldIdx = teai->getFieldIndex(); i != fieldIdx; |
|          ++i) { |
|       auto eltTy = origTupleTy->getElement(i).getType()->getCanonicalType(); |
|       if (getTangentSpace(eltTy)) |
|         ++adjIndex; |
|     } |
|     return builder.createTupleElementAddr(teai->getLoc(), baseAdjBuf, |
|                                           adjIndex); |
|   } |
| |
|   // `ref_element_addr`: a fresh local allocation initialized with the tangent |
|   // field taken from the class operand's adjoint. |
|   if (auto *reai = dyn_cast<RefElementAddrInst>(originalProjection)) { |
|     assert(!reai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() && |
|            "`@noDerivative` class projections should never be active"); |
|     auto loc = reai->getLoc(); |
|     // The class operand, with any `begin_borrow` stripped. |
|     auto classOperand = stripBorrow(reai->getOperand()); |
|     auto classASTTy = remapType(reai->getOperand()->getType()).getASTType(); |
|     auto *tanField = getTangentStoredProperty( |
|         getContext(), reai->getField(), classASTTy, loc, getInvoker()); |
|     assert(tanField && "Invalid projections should have been diagnosed"); |
|     // Allocate the buffer that will hold the element's adjoint. |
|     auto eltTanASTTy = tanField->getValueInterfaceType()->getCanonicalType(); |
|     auto eltTanSILTy = |
|         remapType(SILType::getPrimitiveAddressType(eltTanASTTy)); |
|     auto *eltAdjBuffer = createFunctionLocalAllocation(eltTanSILTy, loc); |
|     if (getTangentValueCategory(classOperand) == SILValueCategory::Object) { |
|       // The class adjoint is a value (currently it must be a `TangentVector` |
|       // struct): borrow it, extract the field, and store a copy into the |
|       // element buffer. |
|       auto classAdj = |
|           materializeAdjointDirect(getAdjointValue(origBB, classOperand), loc); |
|       builder.emitScopedBorrowOperation( |
|           loc, classAdj, [&](SILValue borrowedClassAdj) { |
|             auto *fieldAdj = |
|                 builder.createStructExtract(loc, borrowedClassAdj, tanField); |
|             auto fieldAdjCopy = builder.emitCopyValueOperation(loc, fieldAdj); |
|             builder.emitStoreValueOperation(loc, fieldAdjCopy, eltAdjBuffer, |
|                                             StoreOwnershipQualifier::Init); |
|           }); |
|     } else { |
|       // The class adjoint is a buffer (currently it must be a `TangentVector` |
|       // struct): copy the field's contents into the element buffer. |
|       auto classAdjBuf = getAdjointBuffer(origBB, classOperand); |
|       auto *fieldAdjAddr = |
|           builder.createStructElementAddr(loc, classAdjBuf, tanField); |
|       builder.createCopyAddr(loc, fieldAdjAddr, eltAdjBuffer, IsNotTake, |
|                              IsInitialization); |
|     } |
|     return eltAdjBuffer; |
|   } |
| |
|   // `begin_access`: the adjoint buffer of the accessed base is used directly. |
|   if (auto *bai = dyn_cast<BeginAccessInst>(originalProjection)) { |
|     auto baseAdjBuf = getAdjointBuffer(origBB, bai->getOperand()); |
|     // On error, cache and return a null value for this projection. |
|     if (errorOccurred) |
|       return (bufferMap[{origBB, originalProjection}] = SILValue()); |
|     return baseAdjBuf; |
|   } |
| |
|   // Element address of an `array.uninitialized_intrinsic` application: a |
|   // buffer obtained from the base array's adjoint value at the element index. |
|   auto *ai = |
|       getAllocateUninitializedArrayIntrinsicElementAddress(originalProjection); |
|   auto *definingInst = dyn_cast_or_null<SingleValueInstruction>( |
|       originalProjection->getDefiningInstruction()); |
|   const bool isArrayElementAddress = |
|       ai && definingInst && |
|       (isa<PointerToAddressInst>(definingInst) || |
|        isa<IndexAddrInst>(definingInst)); |
|   // Anything else has no adjoint projection. |
|   if (!isArrayElementAddress) |
|     return SILValue(); |
| |
|   // The element index is zero unless the address is an `index_addr`, whose |
|   // index operand is read as an integer literal. |
|   int eltIndex = 0; |
|   if (auto *iai = dyn_cast<IndexAddrInst>(definingInst)) { |
|     auto *indexLiteral = cast<IntegerLiteralInst>(iai->getIndex()); |
|     eltIndex = indexLiteral->getValue().getLimitedValue(); |
|   } |
|   // Locate the array's adjoint through the `destructure_tuple` user of the |
|   // intrinsic application. |
|   SILValue arrayAdjoint; |
|   assert(ai && "Expected `array.uninitialized_intrinsic` application"); |
|   for (auto use : ai->getUses()) { |
|     auto *dti = dyn_cast<DestructureTupleInst>(use->getUser()); |
|     if (dti == nullptr) |
|       continue; |
|     assert(!arrayAdjoint && "Array adjoint already found"); |
|     // The first `destructure_tuple` result is the `Array` value. |
|     auto arrayValue = dti->getResult(0); |
|     arrayAdjoint = materializeAdjointDirect( |
|         getAdjointValue(origBB, arrayValue), definingInst->getLoc()); |
|   } |
|   assert(arrayAdjoint && "Array does not have adjoint value"); |
|   // Get the buffer holding the array element's adjoint. |
|   auto *eltAdjBuffer = |
|       getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, ai->getLoc()); |
|   return eltAdjBuffer; |
| } |
| |
| //----------------------------------------------------------------------------// |
| // Adjoint value accumulation |
| //----------------------------------------------------------------------------// |
| |
| /// Returns the sum of two adjoint values that have object (direct) tangent |
| /// representations. |
| /// |
| /// The result depends on the kinds of the operands: |
| /// - zero + x and x + zero return the non-zero operand unchanged; |
| /// - concrete + concrete emits an accumulation and returns a concrete value |
| ///   recorded as a temporary; |
| /// - concrete + aggregate destructures a copy of the concrete value and sums |
| ///   element-wise, returning an aggregate; |
| /// - aggregate + concrete is handled by swapping the operands; |
| /// - aggregate + aggregate sums element-wise, returning an aggregate. |
| AdjointValue PullbackCloner::Implementation::accumulateAdjointsDirect( |
|     AdjointValue lhs, AdjointValue rhs, SILLocation loc) { |
|   LLVM_DEBUG(getADDebugStream() << "Accumulating adjoint directly.\nLHS: " |
|                                 << lhs << "\nRHS: " << rhs << '\n'); |
|   switch (lhs.getKind()) { |
|   // x |
|   case AdjointValueKind::Concrete: { |
|     auto lhsVal = lhs.getConcreteValue(); |
|     switch (rhs.getKind()) { |
|     // x + y |
|     case AdjointValueKind::Concrete: { |
|       auto rhsVal = rhs.getConcreteValue(); |
|       auto sum = recordTemporary(accumulateDirect(lhsVal, rhsVal, loc)); |
|       return makeConcreteAdjointValue(sum); |
|     } |
|     // x + 0 => x |
|     case AdjointValueKind::Zero: |
|       return lhs; |
|     // x + (y, z) => (x.0 + y, x.1 + z) |
|     case AdjointValueKind::Aggregate: { |
|       SmallVector<AdjointValue, 8> newElements; |
|       auto lhsTy = lhsVal->getType().getASTType(); |
|       // Destructure a copy of `x`; the results are recorded as temporaries. |
|       auto lhsValCopy = builder.emitCopyValueOperation(loc, lhsVal); |
|       if (lhsTy->is<TupleType>()) { |
|         auto elts = builder.createDestructureTuple(loc, lhsValCopy); |
|         llvm::for_each(elts->getResults(), |
|                        [this](SILValue result) { recordTemporary(result); }); |
|         for (auto i : indices(elts->getResults())) { |
|           auto rhsElt = rhs.getAggregateElement(i); |
|           newElements.push_back(accumulateAdjointsDirect( |
|               makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc)); |
|         } |
|       } else if (lhsTy->getStructOrBoundGenericStruct()) { |
|         auto elts = |
|             builder.createDestructureStruct(lhsVal.getLoc(), lhsValCopy); |
|         llvm::for_each(elts->getResults(), |
|                        [this](SILValue result) { recordTemporary(result); }); |
|         for (unsigned i : indices(elts->getResults())) { |
|           auto rhsElt = rhs.getAggregateElement(i); |
|           newElements.push_back(accumulateAdjointsDirect( |
|               makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc)); |
|         } |
|       } else { |
|         llvm_unreachable("Not an aggregate type"); |
|       } |
|       return makeAggregateAdjointValue(lhsVal->getType(), newElements); |
|     } |
|     } |
|     // Every handled kind returns above. Trap here rather than silently |
|     // falling through into the `Zero` case and returning `rhs`. |
|     llvm_unreachable("Invalid adjoint value kind"); |
|   } |
|   // 0 |
|   case AdjointValueKind::Zero: |
|     // 0 + x => x |
|     return rhs; |
|   // (x, y) |
|   case AdjointValueKind::Aggregate: |
|     switch (rhs.getKind()) { |
|     // (x, y) + z => (z.0 + x, z.1 + y) |
|     case AdjointValueKind::Concrete: |
|       return accumulateAdjointsDirect(rhs, lhs, loc); |
|     // (x, y) + 0 => (x, y) |
|     case AdjointValueKind::Zero: |
|       return lhs; |
|     // (x, y) + (z, w) => (x + z, y + w) |
|     case AdjointValueKind::Aggregate: { |
|       SmallVector<AdjointValue, 8> newElements; |
|       for (auto i : range(lhs.getNumAggregateElements())) |
|         newElements.push_back(accumulateAdjointsDirect( |
|             lhs.getAggregateElement(i), rhs.getAggregateElement(i), loc)); |
|       return makeAggregateAdjointValue(lhs.getType(), newElements); |
|     } |
|     } |
|     // Every handled kind returns above. |
|     llvm_unreachable("Invalid adjoint value kind"); |
|   } |
|   llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715 |
| } |
| |
| /// Emits code that adds two object-typed adjoint values of the same type and |
| /// returns the sum. |
| /// |
| /// Both operands are copied first. For a `TangentVector` tangent space the |
| /// copies are stored into stack buffers, summed via the indirect accumulation |
| /// routine, and the result is loaded back. For a tuple tangent space the |
| /// copies are destructured and summed element by element. |
| SILValue PullbackCloner::Implementation::accumulateDirect(SILValue lhs, |
|                                                           SILValue rhs, |
|                                                           SILLocation loc) { |
|   LLVM_DEBUG(getADDebugStream() << "Emitting adjoint accumulation for lhs: " |
|                                 << lhs << " and rhs: " << rhs); |
|   assert(lhs->getType() == rhs->getType() && "Adjoints must have equal types!"); |
|   assert(lhs->getType().isObject() && rhs->getType().isObject() && |
|          "Adjoint types must be both object types!"); |
|   auto sumTy = lhs->getType(); |
|   auto tangentSpace = getTangentSpace(sumTy.getASTType()); |
|   // Work on copies of both operands. |
|   auto lhsOperand = builder.emitCopyValueOperation(loc, lhs); |
|   auto rhsOperand = builder.emitCopyValueOperation(loc, rhs); |
|   assert(tangentSpace && "No tangent space for this type"); |
|   switch (tangentSpace->getKind()) { |
|   case TangentSpace::Kind::TangentVector: { |
|     // Stack buffers: one for the sum, one per operand. |
|     auto *sumBuf = builder.createAllocStack(loc, sumTy); |
|     auto *lhsBuf = builder.createAllocStack(loc, sumTy); |
|     auto *rhsBuf = builder.createAllocStack(loc, sumTy); |
|     // Move the operand copies into their buffers. |
|     builder.emitStoreValueOperation(loc, lhsOperand, lhsBuf, |
|                                     StoreOwnershipQualifier::Init); |
|     builder.emitStoreValueOperation(loc, rhsOperand, rhsBuf, |
|                                     StoreOwnershipQualifier::Init); |
|     // sumBuf = lhsBuf + rhsBuf |
|     accumulateIndirect(sumBuf, lhsBuf, rhsBuf, loc); |
|     // Tear down the operand buffers, deallocating in reverse allocation |
|     // order. |
|     builder.emitDestroyAddr(loc, lhsBuf); |
|     builder.emitDestroyAddr(loc, rhsBuf); |
|     builder.createDeallocStack(loc, rhsBuf); |
|     builder.createDeallocStack(loc, lhsBuf); |
|     // Take the sum out of its buffer, then free the buffer. |
|     auto sum = builder.emitLoadValueOperation(loc, sumBuf, |
|                                               LoadOwnershipQualifier::Take); |
|     builder.createDeallocStack(loc, sumBuf); |
|     return sum; |
|   } |
|   case TangentSpace::Kind::Tuple: { |
|     // Sum the tuples element-wise and rebuild a tuple from the sums. |
|     auto lhsElts = builder.createDestructureTuple(loc, lhsOperand)->getResults(); |
|     auto rhsElts = builder.createDestructureTuple(loc, rhsOperand)->getResults(); |
|     SmallVector<SILValue, 8> eltSums; |
|     for (auto eltPair : llvm::zip(lhsElts, rhsElts)) { |
|       SILValue lhsElt = std::get<0>(eltPair); |
|       SILValue rhsElt = std::get<1>(eltPair); |
|       eltSums.push_back(accumulateDirect(lhsElt, rhsElt, loc)); |
|     } |
|     return builder.createTuple(loc, sumTy, eltSums); |
|   } |
|   } |
|   llvm_unreachable("Invalid tangent space"); // silences MSVC C4715 |
| } |
| |
| /// Emits code that writes the sum of the values at `lhsAddress` and |
| /// `rhsAddress` to `resultAddress`. |
| /// |
| /// For a `TangentVector` tangent space, the `AdditiveArithmetic` `+` witness |
| /// is applied. For a tuple tangent space, the sum is computed element by |
| /// element through `tuple_element_addr` projections. |
| void PullbackCloner::Implementation::accumulateIndirect(SILValue resultAddress, |
|                                                         SILValue lhsAddress, |
|                                                         SILValue rhsAddress, |
|                                                         SILLocation loc) { |
|   assert(lhsAddress->getType() == rhsAddress->getType() && |
|          "Adjoint values must have same type!"); |
|   assert(lhsAddress->getType().isAddress() && |
|          rhsAddress->getType().isAddress() && |
|          "Adjoint values must both have address types!"); |
|   auto operandASTTy = lhsAddress->getType().getASTType(); |
|   auto *swiftMod = getModule().getSwiftModule(); |
|   auto tangentSpace = operandASTTy->getAutoDiffTangentSpace( |
|       LookUpConformanceInModule(swiftMod)); |
|   assert(tangentSpace && "No tangent space for this type"); |
|   switch (tangentSpace->getKind()) { |
|   case TangentSpace::Kind::Tuple: { |
|     // Recurse on each tuple element. |
|     auto tupleTy = tangentSpace->getTuple(); |
|     for (unsigned eltIdx : range(tupleTy->getNumElements())) { |
|       auto *resultElt = |
|           builder.createTupleElementAddr(loc, resultAddress, eltIdx); |
|       auto *lhsElt = builder.createTupleElementAddr(loc, lhsAddress, eltIdx); |
|       auto *rhsElt = builder.createTupleElementAddr(loc, rhsAddress, eltIdx); |
|       accumulateIndirect(resultElt, lhsElt, rhsElt, loc); |
|     } |
|     return; |
|   } |
|   case TangentSpace::Kind::TangentVector: { |
|     auto *additiveProto = getContext().getAdditiveArithmeticProtocol(); |
|     auto *plusDecl = getContext().getPlusDecl(); |
|     // Look up the conformance in the tangent nominal's module when there is |
|     // one, and in the current Swift module otherwise. |
|     auto *lookupModule = getModule().getSwiftModule(); |
|     if (tangentSpace->getNominal()) |
|       lookupModule = tangentSpace->getNominal()->getModuleContext(); |
|     auto conformance = |
|         lookupModule->lookupConformance(operandASTTy, additiveProto); |
|     assert(!conformance.isInvalid() && |
|            "Missing conformance to `AdditiveArithmetic`"); |
|     SILDeclRef plusRef(plusDecl, SILDeclRef::Kind::Func); |
|     auto plusFnTy = getContext().getTypeConverter().getConstantType( |
|         TypeExpansionContext::minimal(), plusRef); |
|     // %0 = witness_method @+ |
|     auto plusWitness = builder.createWitnessMethod(loc, operandASTTy, |
|                                                    conformance, plusRef, |
|                                                    plusFnTy); |
|     auto protoSubs = SubstitutionMap::getProtocolSubstitutions( |
|         additiveProto, operandASTTy, conformance); |
|     // %1 = metatype $T.Type |
|     auto selfMetatypeTy = SILType::getPrimitiveObjectType( |
|         CanMetatypeType::get(operandASTTy, MetatypeRepresentation::Thick)); |
|     auto selfMetatype = builder.createMetatype(loc, selfMetatypeTy); |
|     // %2 = apply %0(%result, %rhs, %lhs, %1) |
|     builder.createApply(loc, plusWitness, protoSubs, |
|                         {resultAddress, rhsAddress, lhsAddress, selfMetatype}, |
|                         /*isNonThrowing*/ false); |
|     builder.emitDestroyValueOperation(loc, plusWitness); |
|     return; |
|   } |
|   } |
| } |
| |
| /// Emits code that adds the value at `rhsAddress` into the value at |
| /// `lhsDestAddress` in place. |
| /// |
| /// For a `TangentVector` tangent space, the `AdditiveArithmetic` `+=` witness |
| /// is applied. For a tuple tangent space, the accumulation is performed |
| /// element by element through `tuple_element_addr` projections. Both |
| /// addresses must belong to the pullback function. |
| void PullbackCloner::Implementation::accumulateIndirect(SILValue lhsDestAddress, |
|                                                         SILValue rhsAddress, |
|                                                         SILLocation loc) { |
|   assert(lhsDestAddress->getType().isAddress() && |
|          rhsAddress->getType().isAddress()); |
|   assert(lhsDestAddress->getFunction() == &getPullback()); |
|   assert(rhsAddress->getFunction() == &getPullback()); |
|   auto destASTTy = lhsDestAddress->getType().getASTType(); |
|   auto *swiftMod = getModule().getSwiftModule(); |
|   auto tangentSpace = |
|       destASTTy->getAutoDiffTangentSpace(LookUpConformanceInModule(swiftMod)); |
|   assert(tangentSpace && "No tangent space for this type"); |
|   switch (tangentSpace->getKind()) { |
|   case TangentSpace::Kind::Tuple: { |
|     // Recurse on each tuple element. |
|     auto tupleTy = tangentSpace->getTuple(); |
|     for (unsigned eltIdx : range(tupleTy->getNumElements())) { |
|       auto *destElt = |
|           builder.createTupleElementAddr(loc, lhsDestAddress, eltIdx); |
|       auto *rhsElt = builder.createTupleElementAddr(loc, rhsAddress, eltIdx); |
|       accumulateIndirect(destElt, rhsElt, loc); |
|     } |
|     return; |
|   } |
|   case TangentSpace::Kind::TangentVector: { |
|     auto *additiveProto = getContext().getAdditiveArithmeticProtocol(); |
|     auto *plusEqualDecl = getContext().getPlusEqualDecl(); |
|     // The conformance is looked up in the current Swift module. |
|     auto conformance = swiftMod->lookupConformance(destASTTy, additiveProto); |
|     assert(!conformance.isInvalid() && |
|            "Missing conformance to `AdditiveArithmetic`"); |
|     SILDeclRef plusEqualRef(plusEqualDecl, SILDeclRef::Kind::Func); |
|     auto plusEqualFnTy = getContext().getTypeConverter().getConstantType( |
|         TypeExpansionContext::minimal(), plusEqualRef); |
|     // %0 = witness_method @+= |
|     auto plusEqualWitness = builder.createWitnessMethod( |
|         loc, destASTTy, conformance, plusEqualRef, plusEqualFnTy); |
|     auto protoSubs = SubstitutionMap::getProtocolSubstitutions( |
|         additiveProto, destASTTy, conformance); |
|     // %1 = metatype $T.Type |
|     auto selfMetatypeTy = SILType::getPrimitiveObjectType( |
|         CanMetatypeType::get(destASTTy, MetatypeRepresentation::Thick)); |
|     auto selfMetatype = builder.createMetatype(loc, selfMetatypeTy); |
|     // %2 = apply %0(%lhs, %rhs, %1) |
|     builder.createApply(loc, plusEqualWitness, protoSubs, |
|                         {lhsDestAddress, rhsAddress, selfMetatype}, |
|                         /*isNonThrowing*/ false); |
|     builder.emitDestroyValueOperation(loc, plusEqualWitness); |
|     return; |
|   } |
|   } |
| } |
| |
| //----------------------------------------------------------------------------// |
| // Array literal initialization differentiation |
| //----------------------------------------------------------------------------// |
| |
| /// Accumulates the adjoint of an array literal into the adjoint buffers of the |
| /// literal's element addresses. |
| /// |
| /// This is a no-op unless `originalValue` is result 0 of a `destructure_tuple` |
| /// whose operand is an `array.uninitialized_intrinsic` semantic call. |
| /// Otherwise `arrayAdjointValue` is materialized and, for each |
| /// `pointer_to_address` user of the tuple's result 1, element 0 of the array |
| /// adjoint is accumulated into that instruction's adjoint buffer. For each |
| /// `index_addr` user of that `pointer_to_address`, the element selected by the |
| /// `index_addr`'s integer-literal index is accumulated into its adjoint buffer. |
| /// |
| /// \param origBB original basic block used to look up adjoint buffers. |
| /// \param originalValue original value that may be an array literal. |
| /// \param arrayAdjointValue adjoint value of the array. |
| /// \param loc location attached to every emitted instruction. |
| void PullbackCloner::Implementation:: |
| accumulateArrayLiteralElementAddressAdjoints(SILBasicBlock *origBB, |
| SILValue originalValue, |
| AdjointValue arrayAdjointValue, |
| SILLocation loc) { |
| // Return if the original value is not the `Array` result of an |
| // `array.uninitialized_intrinsic` application. |
| auto *dti = dyn_cast_or_null<DestructureTupleInst>( |
| originalValue->getDefiningInstruction()); |
| if (!dti) |
| return; |
| if (!ArraySemanticsCall(dti->getOperand(), |
| semantics::ARRAY_UNINITIALIZED_INTRINSIC)) |
| return; |
| if (originalValue != dti->getResult(0)) |
| return; |
| // Accumulate the array's adjoint value into the adjoint buffers of its |
| // element addresses: `pointer_to_address` and `index_addr` instructions. |
| LLVM_DEBUG(getADDebugStream() |
| << "Accumulating adjoint value for array literal into element " |
| "address adjoint buffers" |
| << originalValue); |
| auto arrayAdjoint = materializeAdjointDirect(arrayAdjointValue, loc); |
| // Continue emitting in the block that holds the materialized array adjoint. |
| builder.setInsertionPoint(arrayAdjoint->getParentBlock()); |
| for (auto use : dti->getResult(1)->getUses()) { |
| auto *ptai = dyn_cast<PointerToAddressInst>(use->getUser()); |
| // `dyn_cast` yields null for any user that is not a `pointer_to_address`. |
| // Skip such users instead of dereferencing a null instruction below. |
| if (!ptai) |
| continue; |
| // The `pointer_to_address` itself addresses element 0. |
| auto firstEltAdjBuf = getAdjointBuffer(origBB, ptai); |
| auto *firstEltAdj = getArrayAdjointElementBuffer(arrayAdjoint, 0, loc); |
| accumulateIndirect(firstEltAdjBuf, firstEltAdj, loc); |
| // Every `index_addr` user addresses the element at its literal index. |
| for (auto eltUse : ptai->getUses()) { |
| auto *iai = dyn_cast<IndexAddrInst>(eltUse->getUser()); |
| if (!iai) |
| continue; |
| auto *ili = cast<IntegerLiteralInst>(iai->getIndex()); |
| auto eltIndex = ili->getValue().getLimitedValue(); |
| auto indexedAdjBuf = getAdjointBuffer(origBB, iai); |
| auto *indexedEltAdj = |
| getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, loc); |
| accumulateIndirect(indexedAdjBuf, indexedEltAdj, loc); |
| } |
| } |
| } |
| |
| /// Emits a call to the `Array.TangentVector` subscript getter that reads the |
| /// element at `eltIndex` from `arrayAdjoint` into a fresh function-local |
| /// buffer, and returns that buffer. |
| /// |
| /// \param arrayAdjoint array adjoint value; its AST type must be a struct type |
| ///        whose parent is a bound generic struct type. |
| /// \param eltIndex index of the element to read. |
| /// \param loc location attached to every emitted instruction. |
| /// \returns the `alloc_stack` holding the element adjoint. |
| AllocStackInst *PullbackCloner::Implementation::getArrayAdjointElementBuffer( |
| SILValue arrayAdjoint, int eltIndex, SILLocation loc) { |
| auto &astCtx = builder.getASTContext(); |
| // The first generic argument of the tangent struct's parent type is used as |
| // the element tangent type. |
| auto arrayTangentTy = cast<StructType>(arrayAdjoint->getType().getASTType()); |
| auto boundArrayTy = |
| arrayTangentTy->getParent()->castTo<BoundGenericStructType>(); |
| auto eltTangentTy = boundArrayTy->getGenericArgs().front()->getCanonicalType(); |
| auto eltTangentSILTy = |
| remapType(SILType::getPrimitiveAddressType(eltTangentTy)); |
| // Get `function_ref` and generic signature of |
| // `Array.TangentVector.subscript.getter`. |
| auto *tangentStructDecl = arrayTangentTy->getStructOrBoundGenericStruct(); |
| auto subscriptCandidates = |
| tangentStructDecl->lookupDirect(DeclBaseName::createSubscript()); |
| SubscriptDecl *subscriptDecl = nullptr; |
| for (auto *candidate : subscriptCandidates) { |
| // Only accept subscripts declared in `_Differentiation` or the stdlib. |
| auto *owningModule = candidate->getModuleContext(); |
| bool isKnownModule = owningModule->getName() == astCtx.Id_Differentiation || |
| owningModule->isStdlibModule(); |
| if (!isKnownModule) |
| continue; |
| assert(!subscriptDecl && "Multiple `Array.TangentVector.subscript`s"); |
| subscriptDecl = cast<SubscriptDecl>(candidate); |
| // Assert builds keep scanning so the uniqueness assertion above can fire. |
| #ifdef NDEBUG |
| break; |
| #endif |
| } |
| assert(subscriptDecl && "No `Array.TangentVector.subscript`"); |
| auto *getterDecl = subscriptDecl->getOpaqueAccessor(AccessorKind::Get); |
| assert(getterDecl && "No `Array.TangentVector.subscript` getter"); |
| SILOptFunctionBuilder fb(getContext().getTransform()); |
| auto *getterFn = |
| fb.getOrCreateFunction(loc, SILDeclRef(getterDecl), NotForDefinition); |
| // %subscript_fn = function_ref @Array.TangentVector<T>.subscript.getter |
| auto *getterFnRef = builder.createFunctionRef(loc, getterFn); |
| auto getterGenSig = |
| getterFn->getLoweredFunctionType()->getSubstGenericSignature(); |
| // Apply `Array.TangentVector.subscript.getter` to get array element adjoint |
| // buffer. |
| // `Int`'s first stored property supplies the builtin integer type. |
| auto *intDecl = astCtx.getIntDecl(); |
| auto builtinIntASTTy = intDecl->getStoredProperties() |
| .front() |
| ->getInterfaceType() |
| ->getCanonicalType(); |
| auto builtinIntSILTy = SILType::getPrimitiveObjectType(builtinIntASTTy); |
| auto intSILTy = SILType::getPrimitiveObjectType( |
| intDecl->getDeclaredInterfaceType()->getCanonicalType()); |
| // %index_literal = integer_literal $Builtin.IntXX, <index> |
| auto *indexLiteral = |
| builder.createIntegerLiteral(loc, builtinIntSILTy, eltIndex); |
| // %index_int = struct $Int (%index_literal) |
| auto *indexInt = builder.createStruct(loc, intSILTy, {indexLiteral}); |
| // Look up the conformances needed by the getter's generic signature. |
| auto *swiftModule = getModule().getSwiftModule(); |
| auto *diffProto = astCtx.getProtocol(KnownProtocolKind::Differentiable); |
| auto diffConf = swiftModule->lookupConformance(eltTangentTy, diffProto); |
| assert(!diffConf.isInvalid() && "Missing conformance to `Differentiable`"); |
| auto *addArithProto = |
| astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic); |
| auto addArithConf = |
| swiftModule->lookupConformance(eltTangentTy, addArithProto); |
| assert(!addArithConf.isInvalid() && |
| "Missing conformance to `AdditiveArithmetic`"); |
| auto getterSubs = SubstitutionMap::get(getterGenSig, {eltTangentTy}, |
| {addArithConf, diffConf}); |
| // %elt_adj = alloc_stack $T.TangentVector |
| // Create and register a local allocation. |
| auto *eltAdjBuffer = createFunctionLocalAllocation(eltTangentSILTy, loc); |
| // Temporarily move the builder to the local-allocation insertion point, emit |
| // a zero into the new allocation, then return to the saved block. |
| auto *savedBB = builder.getInsertionBB(); |
| builder.setInsertionPoint(localAllocBuilder.getInsertionBB(), |
| localAllocBuilder.getInsertionPoint()); |
| emitZeroIndirect(eltTangentTy, eltAdjBuffer, loc); |
| builder.setInsertionPoint(savedBB); |
| // Immediately destroy the emitted zero value. |
| // NOTE: It is not efficient to emit a zero value then immediately destroy |
| // it. However, it was the easiest way to avoid "lifetime mismatch in |
| // predecessors" memory lifetime verification errors for control flow |
| // differentiation. |
| // Perhaps we can avoid emitting a zero value if local allocations are created |
| // per pullback bb instead of all in the pullback entry: TF-1075. |
| builder.emitDestroyOperation(loc, eltAdjBuffer); |
| // apply %subscript_fn<T.TangentVector>(%elt_adj, %index_int, %array_adj) |
| builder.createApply(loc, getterFnRef, getterSubs, |
| {eltAdjBuffer, indexInt, arrayAdjoint}); |
| return eltAdjBuffer; |
| } |
| |
| } // end namespace autodiff |
| } // end namespace swift |