| //===--- Common.cpp - Automatic differentiation common utils --*- C++ -*---===// |
| // |
| // This source file is part of the Swift.org open source project |
| // |
| // Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors |
| // Licensed under Apache License v2.0 with Runtime Library Exception |
| // |
| // See https://swift.org/LICENSE.txt for license information |
| // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Automatic differentiation common utilities. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #define DEBUG_TYPE "differentiation" |
| |
| #include "swift/SILOptimizer/Differentiation/Common.h" |
| #include "swift/AST/TypeCheckRequests.h" |
| #include "swift/SILOptimizer/Differentiation/ADContext.h" |
| |
| namespace swift { |
| namespace autodiff { |
| |
| raw_ostream &getADDebugStream() { return llvm::dbgs() << "[AD] "; } |
| |
| //===----------------------------------------------------------------------===// |
| // Helpers |
| //===----------------------------------------------------------------------===// |
| |
| ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v) { |
| // Find the `pointer_to_address` result, peering through `index_addr`. |
| auto *ptai = dyn_cast<PointerToAddressInst>(v); |
| if (auto *iai = dyn_cast<IndexAddrInst>(v)) |
| ptai = dyn_cast<PointerToAddressInst>(iai->getOperand(0)); |
| if (!ptai) |
| return nullptr; |
| // Return the `array.uninitialized_intrinsic` application, if it exists. |
| if (auto *dti = dyn_cast<DestructureTupleInst>( |
| ptai->getOperand()->getDefiningInstruction())) |
| return ArraySemanticsCall(dti->getOperand(), |
| semantics::ARRAY_UNINITIALIZED_INTRINSIC); |
| return nullptr; |
| } |
| |
| DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) { |
| bool foundDestructureTupleUser = false; |
| if (!value->getType().is<TupleType>()) |
| return nullptr; |
| DestructureTupleInst *result = nullptr; |
| for (auto *use : value->getUses()) { |
| if (auto *dti = dyn_cast<DestructureTupleInst>(use->getUser())) { |
| assert(!foundDestructureTupleUser && |
| "There should only be one `destructure_tuple` user of a tuple"); |
| foundDestructureTupleUser = true; |
| result = dti; |
| } |
| } |
| return result; |
| } |
| |
| bool isSemanticMemberAccessor(SILFunction *original) { |
| auto *dc = original->getDeclContext(); |
| if (!dc) |
| return false; |
| auto *decl = dc->getAsDecl(); |
| if (!decl) |
| return false; |
| auto *accessor = dyn_cast<AccessorDecl>(decl); |
| if (!accessor) |
| return false; |
| // Currently, only getters and setters are supported. |
| // TODO(SR-12640): Support `modify` accessors. |
| if (accessor->getAccessorKind() != AccessorKind::Get && |
| accessor->getAccessorKind() != AccessorKind::Set) |
| return false; |
| // Accessor must come from a `var` declaration. |
| auto *varDecl = dyn_cast<VarDecl>(accessor->getStorage()); |
| if (!varDecl) |
| return false; |
| // Return true for stored property accessors. |
| if (varDecl->hasStorage() && varDecl->isInstanceMember()) |
| return true; |
| // Return true for properties that have attached property wrappers. |
| if (varDecl->hasAttachedPropertyWrapper()) |
| return true; |
| // Otherwise, return false. |
| // User-defined accessors can never be supported because they may use custom |
| // logic that does not semantically perform a member access. |
| return false; |
| } |
| |
| bool hasSemanticMemberAccessorCallee(ApplySite applySite) { |
| if (auto *FRI = dyn_cast<FunctionRefBaseInst>(applySite.getCallee())) |
| if (auto *F = FRI->getReferencedFunctionOrNull()) |
| return isSemanticMemberAccessor(F); |
| return false; |
| } |
| |
| void forEachApplyDirectResult( |
| FullApplySite applySite, |
| llvm::function_ref<void(SILValue)> resultCallback) { |
| switch (applySite.getKind()) { |
| case FullApplySiteKind::ApplyInst: { |
| auto *ai = cast<ApplyInst>(applySite.getInstruction()); |
| if (!ai->getType().is<TupleType>()) { |
| resultCallback(ai); |
| return; |
| } |
| if (auto *dti = getSingleDestructureTupleUser(ai)) |
| for (auto directResult : dti->getResults()) |
| resultCallback(directResult); |
| break; |
| } |
| case FullApplySiteKind::BeginApplyInst: { |
| auto *bai = cast<BeginApplyInst>(applySite.getInstruction()); |
| for (auto directResult : bai->getResults()) |
| resultCallback(directResult); |
| break; |
| } |
| case FullApplySiteKind::TryApplyInst: { |
| auto *tai = cast<TryApplyInst>(applySite.getInstruction()); |
| for (auto *succBB : tai->getSuccessorBlocks()) |
| for (auto *arg : succBB->getArguments()) |
| resultCallback(arg); |
| break; |
| } |
| } |
| } |
| |
| void collectAllFormalResultsInTypeOrder(SILFunction &function, |
| SmallVectorImpl<SILValue> &results) { |
| SILFunctionConventions convs(function.getLoweredFunctionType(), |
| function.getModule()); |
| auto indResults = function.getIndirectResults(); |
| auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator()); |
| auto retVal = retInst->getOperand(); |
| SmallVector<SILValue, 8> dirResults; |
| if (auto *tupleInst = |
| dyn_cast_or_null<TupleInst>(retVal->getDefiningInstruction())) |
| dirResults.append(tupleInst->getElements().begin(), |
| tupleInst->getElements().end()); |
| else |
| dirResults.push_back(retVal); |
| unsigned indResIdx = 0, dirResIdx = 0; |
| for (auto &resInfo : convs.getResults()) |
| results.push_back(resInfo.isFormalDirect() ? dirResults[dirResIdx++] |
| : indResults[indResIdx++]); |
| // Treat `inout` parameters as semantic results. |
| // Append `inout` parameters after formal results. |
| for (auto i : range(convs.getNumParameters())) { |
| auto paramInfo = convs.getParameters()[i]; |
| if (!paramInfo.isIndirectMutating()) |
| continue; |
| auto *argument = function.getArgumentsWithoutIndirectResults()[i]; |
| results.push_back(argument); |
| } |
| } |
| |
| void collectAllDirectResultsInTypeOrder(SILFunction &function, |
| SmallVectorImpl<SILValue> &results) { |
| SILFunctionConventions convs(function.getLoweredFunctionType(), |
| function.getModule()); |
| auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator()); |
| auto retVal = retInst->getOperand(); |
| if (auto *tupleInst = dyn_cast<TupleInst>(retVal)) |
| results.append(tupleInst->getElements().begin(), |
| tupleInst->getElements().end()); |
| else |
| results.push_back(retVal); |
| } |
| |
| void collectAllActualResultsInTypeOrder( |
| ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults, |
| SmallVectorImpl<SILValue> &results) { |
| auto calleeConvs = ai->getSubstCalleeConv(); |
| unsigned indResIdx = 0, dirResIdx = 0; |
| for (auto &resInfo : calleeConvs.getResults()) { |
| results.push_back(resInfo.isFormalDirect() |
| ? extractedDirectResults[dirResIdx++] |
| : ai->getIndirectSILResults()[indResIdx++]); |
| } |
| } |
| |
| void collectMinimalIndicesForFunctionCall( |
| ApplyInst *ai, AutoDiffConfig parentConfig, |
| const DifferentiableActivityInfo &activityInfo, |
| SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> ¶mIndices, |
| SmallVectorImpl<unsigned> &resultIndices) { |
| auto calleeFnTy = ai->getSubstCalleeType(); |
| auto calleeConvs = ai->getSubstCalleeConv(); |
| // Parameter indices are indices (in the callee type signature) of parameter |
| // arguments that are varied or are arguments. |
| // Record all parameter indices in type order. |
| unsigned currentParamIdx = 0; |
| for (auto applyArg : ai->getArgumentsWithoutIndirectResults()) { |
| if (activityInfo.isActive(applyArg, parentConfig)) |
| paramIndices.push_back(currentParamIdx); |
| ++currentParamIdx; |
| } |
| // Result indices are indices (in the callee type signature) of results that |
| // are useful. |
| SmallVector<SILValue, 8> directResults; |
| forEachApplyDirectResult(ai, [&](SILValue directResult) { |
| directResults.push_back(directResult); |
| }); |
| auto indirectResults = ai->getIndirectSILResults(); |
| // Record all results and result indices in type order. |
| results.reserve(calleeFnTy->getNumResults()); |
| unsigned dirResIdx = 0; |
| unsigned indResIdx = calleeConvs.getSILArgIndexOfFirstIndirectResult(); |
| for (auto &resAndIdx : enumerate(calleeConvs.getResults())) { |
| auto &res = resAndIdx.value(); |
| unsigned idx = resAndIdx.index(); |
| if (res.isFormalDirect()) { |
| results.push_back(directResults[dirResIdx]); |
| if (auto dirRes = directResults[dirResIdx]) |
| if (dirRes && activityInfo.isActive(dirRes, parentConfig)) |
| resultIndices.push_back(idx); |
| ++dirResIdx; |
| } else { |
| results.push_back(indirectResults[indResIdx]); |
| if (activityInfo.isActive(indirectResults[indResIdx], parentConfig)) |
| resultIndices.push_back(idx); |
| ++indResIdx; |
| } |
| } |
| // Record all `inout` parameters as results. |
| auto inoutParamResultIndex = calleeFnTy->getNumResults(); |
| for (auto ¶mAndIdx : enumerate(calleeConvs.getParameters())) { |
| auto ¶m = paramAndIdx.value(); |
| if (!param.isIndirectMutating()) |
| continue; |
| unsigned idx = paramAndIdx.index(); |
| auto inoutArg = ai->getArgument(idx); |
| results.push_back(inoutArg); |
| resultIndices.push_back(inoutParamResultIndex++); |
| } |
| // Make sure the function call has active results. |
| #ifndef NDEBUG |
| auto numResults = calleeFnTy->getNumResults() + |
| calleeFnTy->getNumIndirectMutatingParameters(); |
| assert(results.size() == numResults); |
| assert(llvm::any_of(results, [&](SILValue result) { |
| return activityInfo.isActive(result, parentConfig); |
| })); |
| #endif |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Diagnostic utilities |
| //===----------------------------------------------------------------------===// |
| |
| SILLocation getValidLocation(SILValue v) { |
| auto loc = v.getLoc(); |
| if (loc.isNull() || loc.getSourceLoc().isInvalid()) |
| loc = v->getFunction()->getLocation(); |
| return loc; |
| } |
| |
| SILLocation getValidLocation(SILInstruction *inst) { |
| auto loc = inst->getLoc(); |
| if (loc.isNull() || loc.getSourceLoc().isInvalid()) |
| loc = inst->getFunction()->getLocation(); |
| return loc; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Tangent property lookup utilities |
| //===----------------------------------------------------------------------===// |
| |
| VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField, |
| CanType baseType, SILLocation loc, |
| DifferentiationInvoker invoker) { |
| auto &astCtx = context.getASTContext(); |
| auto tanFieldInfo = evaluateOrDefault( |
| astCtx.evaluator, TangentStoredPropertyRequest{originalField, baseType}, |
| TangentPropertyInfo(nullptr)); |
| // If no error, return the tangent property. |
| if (tanFieldInfo) |
| return tanFieldInfo.tangentProperty; |
| // Otherwise, diagnose error and return nullptr. |
| assert(tanFieldInfo.error); |
| auto *parentDC = originalField->getDeclContext(); |
| assert(parentDC->isTypeContext()); |
| auto parentDeclName = parentDC->getSelfNominalTypeDecl()->getNameStr(); |
| auto fieldName = originalField->getNameStr(); |
| auto sourceLoc = loc.getSourceLoc(); |
| switch (tanFieldInfo.error->kind) { |
| case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty: |
| llvm_unreachable( |
| "`@noDerivative` stored property accesses should not be " |
| "differentiated; activity analysis should not mark as varied"); |
| case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable: |
| context.emitNondifferentiabilityError( |
| sourceLoc, invoker, |
| diag::autodiff_stored_property_parent_not_differentiable, |
| parentDeclName, fieldName); |
| break; |
| case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable: |
| context.emitNondifferentiabilityError( |
| sourceLoc, invoker, diag::autodiff_stored_property_not_differentiable, |
| parentDeclName, fieldName, originalField->getInterfaceType()); |
| break; |
| case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct: |
| context.emitNondifferentiabilityError( |
| sourceLoc, invoker, diag::autodiff_stored_property_tangent_not_struct, |
| parentDeclName, fieldName); |
| break; |
| case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound: |
| context.emitNondifferentiabilityError( |
| sourceLoc, invoker, |
| diag::autodiff_stored_property_no_corresponding_tangent, parentDeclName, |
| fieldName); |
| break; |
| case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType: |
| context.emitNondifferentiabilityError( |
| sourceLoc, invoker, diag::autodiff_tangent_property_wrong_type, |
| parentDeclName, fieldName, tanFieldInfo.error->getType()); |
| break; |
| case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored: |
| context.emitNondifferentiabilityError( |
| sourceLoc, invoker, diag::autodiff_tangent_property_not_stored, |
| parentDeclName, fieldName); |
| break; |
| } |
| return nullptr; |
| } |
| |
| VarDecl *getTangentStoredProperty(ADContext &context, |
| SingleValueInstruction *projectionInst, |
| CanType baseType, |
| DifferentiationInvoker invoker) { |
| assert(isa<StructExtractInst>(projectionInst) || |
| isa<StructElementAddrInst>(projectionInst) || |
| isa<RefElementAddrInst>(projectionInst)); |
| Projection proj(projectionInst); |
| auto loc = getValidLocation(projectionInst); |
| auto *field = proj.getVarDecl(projectionInst->getOperand(0)->getType()); |
| return getTangentStoredProperty(context, field, baseType, |
| loc, invoker); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Code emission utilities |
| //===----------------------------------------------------------------------===// |
| |
| SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder, |
| SILLocation loc) { |
| if (elements.size() == 1) |
| return elements.front(); |
| return builder.createTuple(loc, elements); |
| } |
| |
| void extractAllElements(SILValue value, SILBuilder &builder, |
| SmallVectorImpl<SILValue> &results) { |
| auto tupleType = value->getType().getAs<TupleType>(); |
| if (!tupleType) { |
| results.push_back(value); |
| return; |
| } |
| if (builder.hasOwnership()) { |
| auto *dti = builder.createDestructureTuple(value.getLoc(), value); |
| results.append(dti->getResults().begin(), dti->getResults().end()); |
| return; |
| } |
| for (auto i : range(tupleType->getNumElements())) |
| results.push_back(builder.createTupleExtract(value.getLoc(), value, i)); |
| } |
| |
| void emitZeroIntoBuffer(SILBuilder &builder, CanType type, |
| SILValue bufferAccess, SILLocation loc) { |
| auto &astCtx = builder.getASTContext(); |
| auto *swiftMod = builder.getModule().getSwiftModule(); |
| auto &typeConverter = builder.getModule().Types; |
| // Look up conformance to `AdditiveArithmetic`. |
| auto *additiveArithmeticProto = |
| astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic); |
| auto confRef = swiftMod->lookupConformance(type, additiveArithmeticProto); |
| assert(!confRef.isInvalid() && "Missing conformance to `AdditiveArithmetic`"); |
| // Look up `AdditiveArithmetic.zero.getter`. |
| auto zeroDeclLookup = additiveArithmeticProto->lookupDirect(astCtx.Id_zero); |
| auto *zeroDecl = cast<VarDecl>(zeroDeclLookup.front()); |
| assert(zeroDecl->isProtocolRequirement()); |
| auto *accessorDecl = zeroDecl->getOpaqueAccessor(AccessorKind::Get); |
| SILDeclRef accessorDeclRef(accessorDecl, SILDeclRef::Kind::Func); |
| auto silFnType = typeConverter.getConstantType( |
| TypeExpansionContext::minimal(), accessorDeclRef); |
| // %wm = witness_method ... |
| auto *getter = builder.createWitnessMethod(loc, type, confRef, |
| accessorDeclRef, silFnType); |
| // %metatype = metatype $T |
| auto metatypeType = CanMetatypeType::get(type, MetatypeRepresentation::Thick); |
| auto metatype = builder.createMetatype( |
| loc, SILType::getPrimitiveObjectType(metatypeType)); |
| auto subMap = SubstitutionMap::getProtocolSubstitutions( |
| additiveArithmeticProto, type, confRef); |
| builder.createApply(loc, getter, subMap, {bufferAccess, metatype}, |
| /*isNonThrowing*/ false); |
| builder.emitDestroyValueOperation(loc, getter); |
| } |
| |
| SILValue emitMemoryLayoutSize( |
| SILBuilder &builder, SILLocation loc, CanType type) { |
| auto &ctx = builder.getASTContext(); |
| auto id = ctx.getIdentifier(getBuiltinName(BuiltinValueKind::Sizeof)); |
| auto *builtin = cast<FuncDecl>(getBuiltinValueDecl(ctx, id)); |
| auto metatypeTy = SILType::getPrimitiveObjectType( |
| CanMetatypeType::get(type, MetatypeRepresentation::Thin)); |
| auto metatypeVal = builder.createMetatype(loc, metatypeTy); |
| return builder.createBuiltin( |
| loc, id, SILType::getBuiltinWordType(ctx), |
| SubstitutionMap::get( |
| builtin->getGenericSignature(), ArrayRef<Type>{type}, {}), |
| {metatypeVal}); |
| } |
| |
| SILValue emitProjectTopLevelSubcontext( |
| SILBuilder &builder, SILLocation loc, SILValue context, |
| SILType subcontextType) { |
| assert(context.getOwnershipKind() == OwnershipKind::Guaranteed); |
| auto &ctx = builder.getASTContext(); |
| auto id = ctx.getIdentifier( |
| getBuiltinName(BuiltinValueKind::AutoDiffProjectTopLevelSubcontext)); |
| assert(context->getType() == SILType::getNativeObjectType(ctx)); |
| auto *subcontextAddr = builder.createBuiltin( |
| loc, id, SILType::getRawPointerType(ctx), SubstitutionMap(), {context}); |
| return builder.createPointerToAddress( |
| loc, subcontextAddr, subcontextType.getAddressType(), /*isStrict*/ true); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Utilities for looking up derivatives of functions |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns the AbstractFunctionDecl corresponding to `F`. If there isn't one, |
| /// returns `nullptr`. |
| static AbstractFunctionDecl *findAbstractFunctionDecl(SILFunction *F) { |
| auto *DC = F->getDeclContext(); |
| if (!DC) |
| return nullptr; |
| auto *D = DC->getAsDecl(); |
| if (!D) |
| return nullptr; |
| return dyn_cast<AbstractFunctionDecl>(D); |
| } |
| |
| SILDifferentiabilityWitness * |
| getExactDifferentiabilityWitness(SILModule &module, SILFunction *original, |
| IndexSubset *parameterIndices, |
| IndexSubset *resultIndices) { |
| for (auto *w : module.lookUpDifferentiabilityWitnessesForFunction( |
| original->getName())) { |
| if (w->getParameterIndices() == parameterIndices && |
| w->getResultIndices() == resultIndices) |
| return w; |
| } |
| return nullptr; |
| } |
| |
| Optional<AutoDiffConfig> |
| findMinimalDerivativeConfiguration(AbstractFunctionDecl *original, |
| IndexSubset *parameterIndices, |
| IndexSubset *&minimalASTParameterIndices) { |
| Optional<AutoDiffConfig> minimalConfig = None; |
| auto configs = original->getDerivativeFunctionConfigurations(); |
| for (auto config : configs) { |
| auto *silParameterIndices = autodiff::getLoweredParameterIndices( |
| config.parameterIndices, |
| original->getInterfaceType()->castTo<AnyFunctionType>()); |
| // If all indices in `parameterIndices` are in `daParameterIndices`, and |
| // it has fewer indices than our current candidate and a primitive VJP, |
| // then `attr` is our new candidate. |
| // |
| // NOTE(TF-642): `attr` may come from a un-partial-applied function and |
| // have larger capacity than the desired indices. We expect this logic to |
| // go away when `partial_apply` supports `@differentiable` callees. |
| if (silParameterIndices->isSupersetOf(parameterIndices->extendingCapacity( |
| original->getASTContext(), silParameterIndices->getCapacity())) && |
| // fewer parameters than before |
| (!minimalConfig || |
| silParameterIndices->getNumIndices() < |
| minimalConfig->parameterIndices->getNumIndices())) { |
| minimalASTParameterIndices = config.parameterIndices; |
| minimalConfig = |
| AutoDiffConfig(silParameterIndices, config.resultIndices, |
| autodiff::getDifferentiabilityWitnessGenericSignature( |
| original->getGenericSignature(), |
| config.derivativeGenericSignature)); |
| } |
| } |
| return minimalConfig; |
| } |
| |
| SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness( |
| SILModule &module, SILFunction *original, IndexSubset *parameterIndices, |
| IndexSubset *resultIndices) { |
| // AST differentiability witnesses always have a single result. |
| if (resultIndices->getCapacity() != 1 || !resultIndices->contains(0)) |
| return nullptr; |
| |
| // Explicit differentiability witnesses only exist on SIL functions that come |
| // from AST functions. |
| auto *originalAFD = findAbstractFunctionDecl(original); |
| if (!originalAFD) |
| return nullptr; |
| |
| IndexSubset *minimalASTParameterIndices = nullptr; |
| auto minimalConfig = findMinimalDerivativeConfiguration( |
| originalAFD, parameterIndices, minimalASTParameterIndices); |
| if (!minimalConfig) |
| return nullptr; |
| |
| std::string originalName = original->getName().str(); |
| // If original function requires a foreign entry point, use the foreign SIL |
| // function to get or create the minimal differentiability witness. |
| if (requiresForeignEntryPoint(originalAFD)) { |
| originalName = SILDeclRef(originalAFD).asForeign().mangle(); |
| original = module.lookUpFunction(SILDeclRef(originalAFD).asForeign()); |
| } |
| |
| auto *existingWitness = |
| module.lookUpDifferentiabilityWitness({originalName, *minimalConfig}); |
| if (existingWitness) |
| return existingWitness; |
| |
| assert(original->isExternalDeclaration() && |
| "SILGen should create differentiability witnesses for all function " |
| "definitions with explicit differentiable attributes"); |
| |
| return SILDifferentiabilityWitness::createDeclaration( |
| module, SILLinkage::PublicExternal, original, |
| minimalConfig->parameterIndices, minimalConfig->resultIndices, |
| minimalConfig->derivativeGenericSignature); |
| } |
| |
| } // end namespace autodiff |
| } // end namespace swift |