| //===--- DerivedConformanceDifferentiable.cpp - Derived Differentiable ----===// |
| // |
| // This source file is part of the Swift.org open source project |
| // |
| // Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors |
| // Licensed under Apache License v2.0 with Runtime Library Exception |
| // |
| // See https://swift.org/LICENSE.txt for license information |
| // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // SWIFT_ENABLE_TENSORFLOW |
| // |
| // This file implements explicit derivation of the Differentiable protocol for |
| // struct and class types. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "CodeSynthesis.h" |
| #include "TypeChecker.h" |
| #include "swift/AST/AutoDiff.h" |
| #include "swift/AST/Decl.h" |
| #include "swift/AST/Expr.h" |
| #include "swift/AST/Module.h" |
| #include "swift/AST/ParameterList.h" |
| #include "swift/AST/Pattern.h" |
| #include "swift/AST/ProtocolConformance.h" |
| #include "swift/AST/Stmt.h" |
| #include "swift/AST/Types.h" |
| #include "DerivedConformances.h" |
| |
| using namespace swift; |
| |
| /// Return the protocol requirement with the specified name. |
| /// TODO: Move function to shared place for use with other derived conformances. |
| static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) { |
| auto lookup = proto->lookupDirect(name); |
| // Erase declarations that are not protocol requirements. |
| // This is important for removing default implementations of the same name. |
| llvm::erase_if(lookup, [](ValueDecl *v) { |
| return !isa<ProtocolDecl>(v->getDeclContext()) || |
| !v->isProtocolRequirement(); |
| }); |
| assert(lookup.size() <= 1 && "Ambiguous protocol requirement"); |
| return lookup.front(); |
| } |
| |
| /// Get the stored properties of a nominal type that are relevant for |
| /// differentiation, except the ones tagged `@noDerivative`. |
| static void |
| getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, |
| DeclContext *DC, |
| SmallVectorImpl<VarDecl *> &result) { |
| auto &C = nominal->getASTContext(); |
| auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); |
| for (auto *vd : nominal->getStoredProperties()) { |
| if (vd->getAttrs().hasAttribute<NoDerivativeAttr>()) |
| continue; |
| if (vd->isLet()) |
| continue; |
| if (!vd->hasInterfaceType()) |
| C.getLazyResolver()->resolveDeclSignature(vd); |
| if (!vd->hasInterfaceType()) |
| continue; |
| auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType()); |
| if (!TypeChecker::conformsToProtocol(varType, diffableProto, nominal, |
| None)) |
| continue; |
| result.push_back(vd); |
| } |
| } |
| |
| /// Convert the given `ValueDecl` to a `StructDecl` if it is a `StructDecl` or a |
| /// `TypeDecl` with an underlying struct type. Otherwise, return `nullptr`. |
| static StructDecl *convertToStructDecl(ValueDecl *v) { |
| if (auto *structDecl = dyn_cast<StructDecl>(v)) |
| return structDecl; |
| auto *typeDecl = dyn_cast<TypeDecl>(v); |
| if (!typeDecl) |
| return nullptr; |
| return dyn_cast_or_null<StructDecl>( |
| typeDecl->getDeclaredInterfaceType()->getAnyNominal()); |
| } |
| |
| /// Get the `Differentiable` protocol `TangentVector` associated type for the |
| /// given `VarDecl`. |
| /// TODO: Generalize and move function to shared place for use with other derived |
| /// conformances. |
| static Type getTangentVectorType(VarDecl *decl, DeclContext *DC) { |
| auto &C = decl->getASTContext(); |
| auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); |
| if (!decl->hasInterfaceType()) |
| C.getLazyResolver()->resolveDeclSignature(decl); |
| auto varType = DC->mapTypeIntoContext(decl->getValueInterfaceType()); |
| auto conf = TypeChecker::conformsToProtocol(varType, diffableProto, DC, |
| None); |
| if (!conf) |
| return nullptr; |
| Type tangentType = conf->getTypeWitnessByName(varType, C.Id_TangentVector); |
| return tangentType; |
| } |
| |
| // Get the `Differentiable` protocol associated `TangentVector` struct for the |
| // given nominal `DeclContext`. Asserts that the `TangentVector` struct type |
| // exists. |
| static StructDecl *getTangentVectorStructDecl(DeclContext *DC) { |
| assert(DC->getSelfNominalTypeDecl() && "Must be a nominal `DeclContext`"); |
| auto &C = DC->getASTContext(); |
| auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); |
| assert(diffableProto && "`Differentiable` protocol not found"); |
| auto conf = TypeChecker::conformsToProtocol(DC->getSelfTypeInContext(), |
| diffableProto, DC, None); |
| assert(conf && "Nominal must conform to `Differentiable`"); |
| auto assocType = conf->getTypeWitnessByName( |
| DC->getSelfTypeInContext(), C.Id_TangentVector); |
| assert(assocType && "`Differentiable.TangentVector` type not found"); |
| auto *structDecl = dyn_cast<StructDecl>(assocType->getAnyNominal()); |
| assert(structDecl && "Associated type must be a struct type"); |
| return structDecl; |
| } |
| |
| bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, |
| DeclContext *DC) { |
| // Nominal type must be a struct or class. (No stored properties is okay.) |
| if (!isa<StructDecl>(nominal) && !isa<ClassDecl>(nominal)) |
| return false; |
| auto &C = nominal->getASTContext(); |
| auto *lazyResolver = C.getLazyResolver(); |
| auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); |
| auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); |
| |
| // Nominal type must not customize `TangentVector` to anything other than |
| // `Self`. Otherwise, synthesis is semantically unsupported. |
| auto tangentDecls = nominal->lookupDirect(C.Id_TangentVector); |
| auto nominalTypeInContext = |
| DC->mapTypeIntoContext(nominal->getDeclaredInterfaceType()); |
| |
| auto isValidAssocTypeCandidate = [&](ValueDecl *v) -> StructDecl * { |
| // Valid candidate must be a struct or a typealias to a struct. |
| auto *structDecl = convertToStructDecl(v); |
| if (!structDecl) |
| return nullptr; |
| // Valid candidate must either: |
| // 1. Be implicit (previously synthesized). |
| if (structDecl->isImplicit()) |
| return structDecl; |
| // 2. Equal nominal's implicit parent. |
| // This can occur during mutually recursive constraints. Example: |
| // `X == X.TangentVector`. |
| if (nominal->isImplicit() && structDecl == nominal->getDeclContext() && |
| TypeChecker::conformsToProtocol(structDecl->getDeclaredInterfaceType(), |
| diffableProto, DC, None)) |
| return structDecl; |
| // 3. Equal nominal and conform to `AdditiveArithmetic`. |
| if (structDecl == nominal) { |
| // Check conformance to `AdditiveArithmetic`. |
| if (TypeChecker::conformsToProtocol(nominalTypeInContext, addArithProto, |
| DC, None)) |
| return structDecl; |
| } |
| // Otherwise, candidate is invalid. |
| return nullptr; |
| }; |
| |
| auto invalidTangentDecls = llvm::partition(tangentDecls, [&](ValueDecl *v) { |
| return isValidAssocTypeCandidate(v); |
| }); |
| |
| auto validTangentDeclCount = |
| std::distance(tangentDecls.begin(), invalidTangentDecls); |
| auto invalidTangentDeclCount = |
| std::distance(invalidTangentDecls, tangentDecls.end()); |
| |
| // There cannot be any invalid `TangentVector` types. |
| // There can be at most one valid `TangentVector` type. |
| if (invalidTangentDeclCount != 0 || validTangentDeclCount > 1) |
| return false; |
| |
| // All stored properties not marked with `@noDerivative`: |
| // - Must conform to `Differentiable`. |
| // - Must not have any `let` stored properties with an initial value. |
| // - This restriction may be lifted later with support for "true" memberwise |
| // initializers that initialize all stored properties, including initial |
| // value information. |
| SmallVector<VarDecl *, 16> diffProperties; |
| getStoredPropertiesForDifferentiation(nominal, DC, diffProperties); |
| return llvm::all_of(diffProperties, [&](VarDecl *v) { |
| if (!v->hasInterfaceType()) |
| lazyResolver->resolveDeclSignature(v); |
| if (!v->hasInterfaceType()) |
| return false; |
| auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType()); |
| return (bool)TypeChecker::conformsToProtocol(varType, diffableProto, DC, |
| None); |
| }); |
| } |
| |
| /// Determine if a EuclideanDifferentiable requirement can be derived for a type. |
| /// |
| /// \returns True if the requirement can be derived. |
| bool DerivedConformance::canDeriveEuclideanDifferentiable( |
| NominalTypeDecl *nominal, DeclContext *DC) { |
| if (!canDeriveDifferentiable(nominal, DC)) |
| return false; |
| auto &C = nominal->getASTContext(); |
| auto *lazyResolver = C.getLazyResolver(); |
| auto *eucDiffProto = |
| C.getProtocol(KnownProtocolKind::EuclideanDifferentiable); |
| // Return true if all differentiation stored properties conform to |
| // `AdditiveArithmetic` and their `TangentVector` equals themselves. |
| SmallVector<VarDecl *, 16> diffProperties; |
| getStoredPropertiesForDifferentiation(nominal, DC, diffProperties); |
| return llvm::all_of(diffProperties, [&](VarDecl *member) { |
| if (!member->hasInterfaceType()) |
| lazyResolver->resolveDeclSignature(member); |
| if (!member->hasInterfaceType()) |
| return false; |
| auto varType = DC->mapTypeIntoContext(member->getValueInterfaceType()); |
| return (bool)TypeChecker::conformsToProtocol( |
| varType, eucDiffProto, DC, None); |
| }); |
| } |
| |
| /// Synthesize body for a `Differentiable` method requirement. |
| static std::pair<BraceStmt *, bool> |
| deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, |
| Identifier methodName, |
| Identifier methodParamLabel) { |
| auto *parentDC = funcDecl->getParent(); |
| auto *nominal = parentDC->getSelfNominalTypeDecl(); |
| auto &C = nominal->getASTContext(); |
| |
| // Get method protocol requirement. |
| auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable); |
| auto *methodReq = getProtocolRequirement(diffProto, methodName); |
| |
| // Get references to `self` and parameter declarations. |
| auto *selfDecl = funcDecl->getImplicitSelfDecl(); |
| auto *selfDRE = |
| new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true); |
| auto *paramDecl = funcDecl->getParameters()->get(0); |
| auto *paramDRE = |
| new (C) DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true); |
| |
| SmallVector<VarDecl *, 8> diffProperties; |
| getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); |
| |
| // Create call expression applying a member method to a parameter member. |
| // Format: `<member>.method(<parameter>.<member>)`. |
| // Example: `x.move(along: direction.x)`. |
| auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * { |
| auto *module = nominal->getModuleContext(); |
| auto memberType = |
| parentDC->mapTypeIntoContext(member->getValueInterfaceType()); |
| auto confRef = module->lookupConformance(memberType, diffProto); |
| assert(confRef && "Member does not conform to `Differentiable`"); |
| |
| // Get member type's method, e.g. `Member.move(along:)`. |
| // Use protocol requirement declaration for the method by default: this |
| // will be dynamically dispatched. |
| ValueDecl *memberMethodDecl = methodReq; |
| // If conformance reference is concrete, then use concrete witness |
| // declaration for the operator. |
| if (confRef->isConcrete()) |
| memberMethodDecl = confRef->getConcrete()->getWitnessDecl( |
| methodReq); |
| assert(memberMethodDecl && "Member method declaration must exist"); |
| auto memberMethodDRE = |
| new (C) DeclRefExpr(memberMethodDecl, DeclNameLoc(), /*Implicit*/ true); |
| memberMethodDRE->setFunctionRefKind(FunctionRefKind::SingleApply); |
| |
| // Create reference to member method: `x.move(along:)`. |
| auto memberExpr = |
| new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(), |
| /*Implicit*/ true); |
| auto memberMethodExpr = |
| new (C) DotSyntaxCallExpr(memberMethodDRE, SourceLoc(), memberExpr); |
| |
| // Create reference to parameter member: `direction.x`. |
| VarDecl *paramMember = nullptr; |
| auto *paramNominal = paramDecl->getType()->getAnyNominal(); |
| assert(paramNominal && "Parameter should have a nominal type"); |
| // Find parameter member corresponding to returned nominal member. |
| for (auto *candidate : paramNominal->getStoredProperties()) { |
| if (candidate->getName() == member->getName()) { |
| paramMember = candidate; |
| break; |
| } |
| } |
| assert(paramMember && "Could not find corresponding parameter member"); |
| auto *paramMemberExpr = |
| new (C) MemberRefExpr(paramDRE, SourceLoc(), paramMember, DeclNameLoc(), |
| /*Implicit*/ true); |
| // Create expression: `x.move(along: direction.x)`. |
| return CallExpr::createImplicit(C, memberMethodExpr, {paramMemberExpr}, |
| {methodParamLabel}); |
| }; |
| |
| // Create array of member method call expressions. |
| llvm::SmallVector<ASTNode, 2> memberMethodCallExprs; |
| llvm::SmallVector<Identifier, 2> memberNames; |
| for (auto *member : diffProperties) { |
| memberMethodCallExprs.push_back(createMemberMethodCallExpr(member)); |
| memberNames.push_back(member->getName()); |
| } |
| auto *braceStmt = BraceStmt::create(C, SourceLoc(), memberMethodCallExprs, |
| SourceLoc(), true); |
| return std::pair<BraceStmt *, bool>(braceStmt, false); |
| } |
| |
| /// Synthesize body for `move(along:)`. |
| static std::pair<BraceStmt *, bool> |
| deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) { |
| auto &C = funcDecl->getASTContext(); |
| return deriveBodyDifferentiable_method(funcDecl, C.Id_move, |
| C.getIdentifier("along")); |
| } |
| |
| /// Synthesize function declaration for a `Differentiable` method requirement. |
| static ValueDecl *deriveDifferentiable_method( |
| DerivedConformance &derived, Identifier methodName, Identifier argumentName, |
| Identifier parameterName, Type parameterType, Type returnType, |
| AbstractFunctionDecl::BodySynthesizer bodySynthesizer) { |
| auto *nominal = derived.Nominal; |
| auto &C = derived.TC.Context; |
| auto *parentDC = derived.getConformanceContext(); |
| |
| auto *param = |
| new (C) ParamDecl(ParamDecl::Specifier::Default, SourceLoc(), SourceLoc(), |
| argumentName, SourceLoc(), parameterName, parentDC); |
| param->setInterfaceType(parameterType); |
| ParameterList *params = ParameterList::create(C, {param}); |
| |
| DeclName declName(C, methodName, params); |
| auto *funcDecl = FuncDecl::create(C, SourceLoc(), StaticSpellingKind::None, |
| SourceLoc(), declName, SourceLoc(), |
| /*Throws*/ false, SourceLoc(), |
| /*GenericParams=*/nullptr, params, |
| TypeLoc::withoutLoc(returnType), parentDC); |
| if (!nominal->getSelfClassDecl()) |
| funcDecl->setSelfAccessKind(SelfAccessKind::Mutating); |
| funcDecl->setImplicit(); |
| funcDecl->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context); |
| |
| funcDecl->setGenericSignature(parentDC->getGenericSignatureOfContext()); |
| funcDecl->computeType(); |
| funcDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); |
| |
| derived.addMembersToConformanceContext({funcDecl}); |
| C.addSynthesizedDecl(funcDecl); |
| |
| return funcDecl; |
| } |
| |
| /// Synthesize the `move(along:)` function declaration. |
| static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) { |
| auto &C = derived.TC.Context; |
| auto *parentDC = derived.getConformanceContext(); |
| |
| auto *tangentDecl = getTangentVectorStructDecl(parentDC); |
| auto tangentType = tangentDecl->getDeclaredInterfaceType(); |
| |
| return deriveDifferentiable_method( |
| derived, C.Id_move, C.getIdentifier("along"), |
| C.getIdentifier("direction"), tangentType, C.TheEmptyTupleType, |
| {deriveBodyDifferentiable_move, nullptr}); |
| } |
| |
| /// Synthesize the `differentiableVectorView` property declaration. |
| static ValueDecl *deriveEuclideanDifferentiable_differentiableVectorView( |
| DerivedConformance &derived) { |
| auto &C = derived.TC.Context; |
| auto *parentDC = derived.getConformanceContext(); |
| |
| auto *tangentDecl = getTangentVectorStructDecl(parentDC); |
| auto tangentType = tangentDecl->getDeclaredInterfaceType(); |
| auto tangentContextualType = parentDC->mapTypeIntoContext(tangentType); |
| |
| VarDecl *vectorViewDecl; |
| PatternBindingDecl *pbDecl; |
| std::tie(vectorViewDecl, pbDecl) = derived.declareDerivedProperty( |
| C.Id_differentiableVectorView, tangentType, tangentContextualType, |
| /*isStatic*/ false, /*isFinal*/ true); |
| |
| struct GetterSynthesizerContext { |
| StructDecl *tangentDecl; |
| Type tangentContextualType; |
| }; |
| |
| auto getterSynthesizer = [](AbstractFunctionDecl *getterDecl, void *ctx) |
| -> std::pair<BraceStmt *, bool> { |
| auto *context = reinterpret_cast<GetterSynthesizerContext *>(ctx); |
| assert(context && "Invalid context"); |
| auto *parentDC = getterDecl->getParent(); |
| auto *nominal = parentDC->getSelfNominalTypeDecl(); |
| auto *module = nominal->getModuleContext(); |
| auto &C = nominal->getASTContext(); |
| auto *eucDiffProto = |
| C.getProtocol(KnownProtocolKind::EuclideanDifferentiable); |
| auto *vectorViewReq = |
| eucDiffProto->lookupDirect(C.Id_differentiableVectorView).front(); |
| |
| SmallVector<VarDecl *, 8> diffProperties; |
| getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); |
| |
| // Create a reference to the memberwise initializer: `TangentVector.init`. |
| auto *memberwiseInitDecl = |
| context->tangentDecl->getEffectiveMemberwiseInitializer(); |
| assert(memberwiseInitDecl && "Memberwise initializer must exist"); |
| assert(diffProperties.size() == |
| memberwiseInitDecl->getParameters()->size()); |
| // `TangentVector` |
| auto *tangentTypeExpr = |
| TypeExpr::createImplicit(context->tangentContextualType, C); |
| // `TangentVector.init` |
| auto *initDRE = new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), |
| /*Implicit*/ true); |
| initDRE->setFunctionRefKind(FunctionRefKind::SingleApply); |
| auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, tangentTypeExpr); |
| initExpr->setThrows(false); |
| initExpr->setImplicit(); |
| |
| // Create a call: |
| // TangentVector.init( |
| // <property_name_1...>: |
| // self.<property_name_1>.differentiableVectorView, |
| // <property_name_2...>: |
| // self.<property_name_2>.differentiableVectorView, |
| // ... |
| // ) |
| SmallVector<Identifier, 8> argLabels; |
| SmallVector<Expr *, 8> memberRefs; |
| for (auto *member : diffProperties) { |
| auto *selfDRE = new (C) DeclRefExpr(getterDecl->getImplicitSelfDecl(), |
| DeclNameLoc(), |
| /*Implicit*/ true); |
| auto *memberExpr = new (C) MemberRefExpr( |
| selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true); |
| auto memberType = |
| parentDC->mapTypeIntoContext(member->getValueInterfaceType()); |
| auto confRef = module->lookupConformance(memberType, eucDiffProto); |
| assert(confRef && |
| "Member missing conformance to `EuclideanDifferentiable`"); |
| ConcreteDeclRef memberDeclRef = vectorViewReq; |
| if (confRef->isConcrete()) |
| memberDeclRef = confRef->getConcrete()->getWitnessDecl(vectorViewReq); |
| argLabels.push_back(member->getName()); |
| memberRefs.push_back(new (C) MemberRefExpr( |
| memberExpr, SourceLoc(), memberDeclRef, DeclNameLoc(), |
| /*Implicit*/ true)); |
| } |
| assert(memberRefs.size() == argLabels.size()); |
| CallExpr *callExpr = |
| CallExpr::createImplicit(C, initExpr, memberRefs, argLabels); |
| |
| // Create a return statement: `return TangentVector.init(...)`. |
| ASTNode retStmt = |
| new (C) ReturnStmt(SourceLoc(), callExpr, /*implicit*/ true); |
| auto *braceStmt = BraceStmt::create(C, SourceLoc(), retStmt, SourceLoc(), |
| /*implicit*/ true); |
| return std::make_pair(braceStmt, false); |
| }; |
| auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty( |
| vectorViewDecl, tangentContextualType); |
| getterDecl->setBodySynthesizer( |
| getterSynthesizer, /*context*/ C.AllocateObjectCopy( |
| GetterSynthesizerContext{tangentDecl, tangentContextualType})); |
| derived.addMembersToConformanceContext({vectorViewDecl, pbDecl}); |
| return vectorViewDecl; |
| } |
| |
| /// Return associated `TangentVector` struct for a nominal type, if it exists. |
| /// If not, synthesize the struct. |
| static StructDecl * |
| getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) { |
| auto &TC = derived.TC; |
| auto *parentDC = derived.getConformanceContext(); |
| auto *nominal = derived.Nominal; |
| auto &C = nominal->getASTContext(); |
| |
| // If the associated struct already exists, return it. |
| auto lookup = nominal->lookupDirect(C.Id_TangentVector); |
| assert(lookup.size() < 2 && |
| "Expected at most one associated type named `TangentVector`"); |
| if (lookup.size() == 1) { |
| auto *structDecl = convertToStructDecl(lookup.front()); |
| assert(structDecl && "Expected lookup result to be a struct"); |
| return structDecl; |
| } |
| |
| // Otherwise, synthesize a new struct. |
| auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); |
| auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredType()); |
| auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); |
| auto addArithType = TypeLoc::withoutLoc(addArithProto->getDeclaredType()); |
| auto *pointMulProto = |
| C.getProtocol(KnownProtocolKind::PointwiseMultiplicative); |
| auto pointMulType = TypeLoc::withoutLoc(pointMulProto->getDeclaredType()); |
| auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions); |
| auto mathType = TypeLoc::withoutLoc(mathProto->getDeclaredType()); |
| auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol); |
| auto vectorType = TypeLoc::withoutLoc(vectorProto->getDeclaredType()); |
| auto *kpIterableProto = C.getProtocol(KnownProtocolKind::KeyPathIterable); |
| auto kpIterableType = TypeLoc::withoutLoc(kpIterableProto->getDeclaredType()); |
| |
| // By definition, `TangentVector` must conform to `Differentiable` and |
| // `AdditiveArithmetic`. |
| SmallVector<TypeLoc, 4> inherited{diffableType, addArithType}; |
| |
| // Cache original members and their associated types for later use. |
| SmallVector<VarDecl *, 8> diffProperties; |
| getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); |
| |
| // Add ad-hoc implicit conformances for `TangentVector`. |
| // TODO(TF-632): Remove this implicit conformance logic when synthesized |
| // member types can be extended. |
| |
| // `TangentVector` struct can derive `PointwiseMultiplicative` if the |
| // `TangentVector` types of all stored properties conform to |
| // `PointwiseMultiplicative`. |
| bool canDerivePointwiseMultiplicative = |
| llvm::all_of(diffProperties, [&](VarDecl *vd) { |
| return TC.conformsToProtocol(getTangentVectorType(vd, parentDC), |
| pointMulProto, parentDC, None); |
| }); |
| |
| // `TangentVector` struct can derive `ElementaryFunctions` if the |
| // `TangentVector` types of all stored properties conform to |
| // `ElementaryFunctions`. |
| bool canDeriveElementaryFunctions = |
| llvm::all_of(diffProperties, [&](VarDecl *vd) { |
| return TC.conformsToProtocol(getTangentVectorType(vd, parentDC), |
| mathProto, parentDC, None); |
| }); |
| |
| // `TangentVector` struct can derive `VectorProtocol` if the `TangentVector` |
| // types of all members conform to `VectorProtocol` and share the same |
| // `VectorSpaceScalar` type. |
| Type sameScalarType; |
| bool canDeriveVectorProtocol = !diffProperties.empty() && |
| llvm::all_of(diffProperties, [&](VarDecl *vd) { |
| auto tanType = getTangentVectorType(vd, parentDC); |
| auto conf = TC.conformsToProtocol(tanType, vectorProto, nominal, None); |
| if (!conf) |
| return false; |
| auto scalarType = |
| conf->getTypeWitnessByName(tanType, C.Id_VectorSpaceScalar); |
| if (!sameScalarType) { |
| sameScalarType = scalarType; |
| return true; |
| } |
| return scalarType->isEqual(sameScalarType); |
| }); |
| |
| // `TangentVector` struct should derive `KeyPathIterable` if the parent struct |
| // conforms to `KeyPathIterable`. |
| bool shouldDeriveKeyPathIterable = |
| TC.conformsToProtocol(nominal->getDeclaredInterfaceType(), |
| kpIterableProto, parentDC, None).hasValue(); |
| |
| // If all members conform to `PointwiseMultiplicative`, make the |
| // `TangentVector` struct conform to `PointwiseMultiplicative`. |
| if (canDerivePointwiseMultiplicative) |
| inherited.push_back(pointMulType); |
| // If all members conform to `ElementaryFunctions`, make the `TangentVector` |
| // struct conform to `ElementaryFunctions`. |
| if (canDeriveElementaryFunctions) |
| inherited.push_back(mathType); |
| // If all members also conform to `VectorProtocol` with the same `Scalar` |
| // type, make the `TangentVector` struct conform to `VectorProtocol`. |
| if (canDeriveVectorProtocol) |
| inherited.push_back(vectorType); |
| // If parent type conforms to `KeyPathIterable`, make the `TangentVector` |
| // struct conform to `KeyPathIterable`. |
| if (shouldDeriveKeyPathIterable) |
| inherited.push_back(kpIterableType); |
| |
| auto *structDecl = |
| new (C) StructDecl(SourceLoc(), C.Id_TangentVector, SourceLoc(), |
| /*Inherited*/ C.AllocateCopy(inherited), |
| /*GenericParams*/ {}, parentDC); |
| structDecl->setImplicit(); |
| structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); |
| |
| // Add members to `TangentVector` struct. |
| for (auto *member : diffProperties) { |
| // Add this member's corresponding `TangentVector` type to the parent's |
| // `TangentVector` struct. |
| auto *newMember = new (C) VarDecl( |
| member->isStatic(), member->getIntroducer(), member->isCaptureList(), |
| /*NameLoc*/ SourceLoc(), member->getName(), structDecl); |
| // NOTE: `newMember` is not marked as implicit here, because that affects |
| // memberwise initializer synthesis. |
| |
| auto memberAssocType = getTangentVectorType(member, parentDC); |
| auto memberAssocInterfaceType = memberAssocType->hasArchetype() |
| ? memberAssocType->mapTypeOutOfContext() |
| : memberAssocType; |
| auto memberAssocContextualType = |
| parentDC->mapTypeIntoContext(memberAssocInterfaceType); |
| newMember->setInterfaceType(memberAssocInterfaceType); |
| newMember->setType(memberAssocContextualType); |
| Pattern *memberPattern = |
| new (C) NamedPattern(newMember, /*implicit*/ true); |
| memberPattern->setType(memberAssocContextualType); |
| memberPattern = TypedPattern::createImplicit( |
| C, memberPattern, memberAssocContextualType); |
| memberPattern->setType(memberAssocContextualType); |
| auto *memberBinding = PatternBindingDecl::createImplicit( |
| C, StaticSpellingKind::None, memberPattern, /*initExpr*/ nullptr, |
| structDecl); |
| structDecl->addMember(newMember); |
| structDecl->addMember(memberBinding); |
| newMember->copyFormalAccessFrom(member, /*sourceIsParentContext*/ true); |
| newMember->setSetterAccess(member->getFormalAccess()); |
| C.addSynthesizedDecl(newMember); |
| C.addSynthesizedDecl(memberBinding); |
| |
| // Now that this member is in the `TangentVector` type, it should be marked |
| // `@differentiable` so that the differentiation transform will synthesize |
| // derivative functions for it. We only add this to public stored |
| // properties, because their access outside the module will go through a |
| // call to the getter. |
| if (member->getEffectiveAccess() > AccessLevel::Internal && |
| !member->getAttrs().hasAttribute<DifferentiableAttr>()) { |
| if (!member->getSynthesizedAccessor(AccessorKind::Get) |
| ->hasInterfaceType()) |
| TC.resolveDeclSignature(member->getAccessor(AccessorKind::Get)); |
| // If member or its getter already has a `@differentiable` attribute, |
| // continue. |
| if (member->getAttrs().hasAttribute<DifferentiableAttr>() || |
| member->getAccessor(AccessorKind::Get) |
| ->getAttrs() |
| .hasAttribute<DifferentiableAttr>()) |
| continue; |
| GenericSignature derivativeGenSig = GenericSignature(); |
| // If the parent declaration context is an extension, the nominal type may |
| // conditionally conform to `Differentiable`. Use the extension generic |
| // requirements in getter `@differentiable` attributes. |
| if (auto *extDecl = dyn_cast<ExtensionDecl>(parentDC->getAsDecl())) |
| derivativeGenSig = extDecl->getGenericSignature(); |
| auto *diffableAttr = DifferentiableAttr::create( |
| C, /*implicit*/ true, SourceLoc(), SourceLoc(), |
| /*linear*/ false, {}, None, None, derivativeGenSig); |
| member->getAttrs().add(diffableAttr); |
| // Set getter `@differentiable` attribute parameter indices. |
| diffableAttr->setParameterIndices(IndexSubset::get(C, 1, {0})); |
| } |
| } |
| |
| // If nominal type has `@_fixed_layout` attribute, mark `TangentVector` struct |
| // as `@_fixed_layout` as well. |
| if (nominal->getAttrs().hasAttribute<FixedLayoutAttr>()) |
| structDecl->addFixedLayoutAttr(); |
| |
| // The implicit memberwise constructor must be explicitly created so that it |
| // can called in `AdditiveArithmetic` and `Differentiable` methods. Normally, |
| // the memberwise constructor is synthesized during SILGen, which is too late. |
| auto *initDecl = createMemberwiseImplicitConstructor(TC, structDecl); |
| structDecl->addMember(initDecl); |
| C.addSynthesizedDecl(initDecl); |
| |
| // After memberwise initializer is synthesized, mark members as implicit. |
| for (auto *member : structDecl->getStoredProperties()) |
| member->setImplicit(); |
| |
| derived.addMembersToConformanceContext({structDecl}); |
| C.addSynthesizedDecl(structDecl); |
| |
| return structDecl; |
| } |
| |
| /// Add a typealias declaration with the given name and underlying target |
| /// struct type to the given source nominal declaration context. |
| static void addAssociatedTypeAliasDecl(Identifier name, |
| DeclContext *sourceDC, |
| StructDecl *target, |
| TypeChecker &TC) { |
| auto &C = TC.Context; |
| auto *nominal = sourceDC->getSelfNominalTypeDecl(); |
| assert(nominal && "Expected `DeclContext` to be a nominal type"); |
| auto lookup = nominal->lookupDirect(name); |
| assert(lookup.size() < 2 && |
| "Expected at most one associated type named member"); |
| // If implicit type declaration with the given name already exists in source |
| // struct, return it. |
| if (lookup.size() == 1) { |
| auto existingTypeDecl = dyn_cast<TypeDecl>(lookup.front()); |
| assert(existingTypeDecl && existingTypeDecl->isImplicit() && |
| "Expected lookup result to be an implicit type declaration"); |
| return; |
| } |
| // Otherwise, create a new typealias. |
| auto *aliasDecl = new (C) |
| TypeAliasDecl(SourceLoc(), SourceLoc(), name, SourceLoc(), {}, sourceDC); |
| aliasDecl->setUnderlyingType(target->getDeclaredInterfaceType()); |
| aliasDecl->setImplicit(); |
| aliasDecl->setGenericSignature(sourceDC->getGenericSignatureOfContext()); |
| cast<IterableDeclContext>(sourceDC->getAsDecl())->addMember(aliasDecl); |
| aliasDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); |
| aliasDecl->computeType(); |
| TC.validateDecl(aliasDecl); |
| C.addSynthesizedDecl(aliasDecl); |
| }; |
| |
| /// Diagnose stored properties in the nominal that do not have an explicit |
| /// `@noDerivative` attribute, but either: |
| /// - Do not conform to `Differentiable`. |
| /// - Are a `let` stored property. |
| /// Emit a warning and a fixit so that users will make the attribute explicit. |
| static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC, |
| NominalTypeDecl *nominal, |
| DeclContext* DC) { |
| auto *diffableProto = |
| TC.Context.getProtocol(KnownProtocolKind::Differentiable); |
| bool nominalCanDeriveAdditiveArithmetic = |
| DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC); |
| for (auto *vd : nominal->getStoredProperties()) { |
| if (!vd->hasInterfaceType()) |
| TC.resolveDeclSignature(vd); |
| if (!vd->hasInterfaceType()) |
| continue; |
| auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType()); |
| if (vd->getAttrs().hasAttribute<NoDerivativeAttr>()) |
| continue; |
| // Check whether to diagnose stored property. |
| bool conformsToDifferentiable = |
| TC.conformsToProtocol(varType, diffableProto, nominal, None).hasValue(); |
| // If stored property should not be diagnosed, continue. |
| if (conformsToDifferentiable && !vd->isLet()) |
| continue; |
| // Otherwise, add an implicit `@noDerivative` attribute. |
| vd->getAttrs().add( |
| new (TC.Context) NoDerivativeAttr(/*Implicit*/ true)); |
| auto loc = vd->getAttributeInsertionLoc(/*forModifier*/ false); |
| assert(loc.isValid() && "Expected valid source location"); |
| // If nominal type can conform to `AdditiveArithmetic`, suggest conforming |
| // adding a conformance to `AdditiveArithmetic`. |
| // `Differentiable` protocol requirements all have default implementations |
| // when `Self` conforms to `AdditiveArithmetic`, so `Differentiable` |
| // derived conformances will no longer be necessary. |
| if (!conformsToDifferentiable) { |
| TC.diagnose(loc, |
| diag::differentiable_nondiff_type_implicit_noderivative_fixit, |
| vd->getName(), nominal->getName(), |
| nominalCanDeriveAdditiveArithmetic) |
| .fixItInsert(loc, "@noDerivative "); |
| continue; |
| } |
| TC.diagnose(loc, |
| diag::differentiable_let_property_implicit_noderivative_fixit, |
| vd->getName(), nominal->getName(), |
| nominalCanDeriveAdditiveArithmetic) |
| .fixItInsert(loc, "@noDerivative "); |
| } |
| } |
| |
| /// Get or synthesize `TangentVector` struct type. |
| static Type |
| getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) { |
| auto &TC = derived.TC; |
| auto *parentDC = derived.getConformanceContext(); |
| auto *nominal = derived.Nominal; |
| auto &C = nominal->getASTContext(); |
| |
| // Get or synthesize `TangentVector` struct. |
| auto *tangentStruct = |
| getOrSynthesizeTangentVectorStruct(derived, C.Id_TangentVector); |
| if (!tangentStruct) |
| return nullptr; |
| // Check and emit warnings for implicit `@noDerivative` members. |
| checkAndDiagnoseImplicitNoDerivative(TC, nominal, parentDC); |
| // Add `TangentVector` typealias for `TangentVector` struct. |
| addAssociatedTypeAliasDecl(C.Id_TangentVector, |
| tangentStruct, tangentStruct, TC); |
| TC.validateDecl(tangentStruct); |
| |
| // Sanity checks for synthesized struct. |
| assert(DerivedConformance::canDeriveAdditiveArithmetic(tangentStruct, |
| parentDC) && |
| "Should be able to derive `AdditiveArithmetic`"); |
| assert(DerivedConformance::canDeriveDifferentiable(tangentStruct, parentDC) && |
| "Should be able to derive `Differentiable`"); |
| |
| // Return the `TangentVector` struct type. |
| return parentDC->mapTypeIntoContext( |
| tangentStruct->getDeclaredInterfaceType()); |
| } |
| |
| /// Synthesize the `TangentVector` struct type. |
| static Type |
| deriveDifferentiable_TangentVectorStruct(DerivedConformance &derived) { |
| auto &TC = derived.TC; |
| auto *parentDC = derived.getConformanceContext(); |
| auto *nominal = derived.Nominal; |
| auto &C = nominal->getASTContext(); |
| |
| // Get all stored properties for differentation. |
| SmallVector<VarDecl *, 16> diffProperties; |
| getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); |
| |
| // If any member has an invalid `TangentVector` type, return nullptr. |
| for (auto *member : diffProperties) |
| if (!getTangentVectorType(member, parentDC)) |
| return nullptr; |
| |
| // Prevent re-synthesis during repeated calls. |
| // FIXME: Investigate why this is necessary to prevent duplicate synthesis. |
| auto lookup = nominal->lookupDirect(C.Id_TangentVector); |
| if (lookup.size() == 1) |
| if (auto *structDecl = convertToStructDecl(lookup.front())) |
| if (structDecl->isImplicit()) |
| return structDecl->getDeclaredInterfaceType(); |
| |
| // Check whether at least one `@noDerivative` stored property exists. |
| unsigned numStoredProperties = |
| std::distance(nominal->getStoredProperties().begin(), |
| nominal->getStoredProperties().end()); |
| bool hasNoDerivativeStoredProp = diffProperties.size() != numStoredProperties; |
| |
| // Check conditions for returning `Self`. |
| // - `Self` is not a class type. |
| // - No `@noDerivative` stored properties exist. |
| // - All stored properties must have `TangentVector` type equal to `Self`. |
| // - Parent type must also conform to `AdditiveArithmetic`. |
| bool allMembersAssocTypeEqualsSelf = |
| llvm::all_of(diffProperties, [&](VarDecl *member) { |
| auto memberAssocType = getTangentVectorType(member, parentDC); |
| return member->getType()->isEqual(memberAssocType); |
| }); |
| |
| auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); |
| auto nominalConformsToAddArith = |
| TC.conformsToProtocol(parentDC->getSelfTypeInContext(), addArithProto, |
| parentDC, None); |
| |
| // Return `Self` if conditions are met. |
| if (!hasNoDerivativeStoredProp && !nominal->getSelfClassDecl() && |
| allMembersAssocTypeEqualsSelf && nominalConformsToAddArith) { |
| auto selfType = parentDC->getSelfTypeInContext(); |
| auto *aliasDecl = |
| new (C) TypeAliasDecl(SourceLoc(), SourceLoc(), C.Id_TangentVector, |
| SourceLoc(), {}, parentDC); |
| aliasDecl->setUnderlyingType(selfType); |
| aliasDecl->setImplicit(); |
| aliasDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); |
| aliasDecl->computeType(); |
| TC.validateDecl(aliasDecl); |
| derived.addMembersToConformanceContext({aliasDecl}); |
| C.addSynthesizedDecl(aliasDecl); |
| return selfType; |
| } |
| |
| // Otherwise, get or synthesize `TangentVector` struct type. |
| return getOrSynthesizeTangentVectorStructType(derived); |
| } |
| |
| ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) { |
| // Diagnose conformances in disallowed contexts. |
| if (checkAndDiagnoseDisallowedContext(requirement)) |
| return nullptr; |
| if (requirement->getBaseName() == TC.Context.Id_move) |
| return deriveDifferentiable_move(*this); |
| TC.diagnose(requirement->getLoc(), diag::broken_differentiable_requirement); |
| return nullptr; |
| } |
| |
| Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) { |
| // Diagnose conformances in disallowed contexts. |
| if (checkAndDiagnoseDisallowedContext(requirement)) |
| return nullptr; |
| if (requirement->getBaseName() == TC.Context.Id_TangentVector) |
| return deriveDifferentiable_TangentVectorStruct(*this); |
| TC.diagnose(requirement->getLoc(), diag::broken_differentiable_requirement); |
| return nullptr; |
| } |
| |
| /// Derive a EuclideanDifferentiable requirement for a nominal type. |
| /// |
| /// \returns the derived member, which will also be added to the type. |
| ValueDecl *DerivedConformance::deriveEuclideanDifferentiable( |
| ValueDecl *requirement) { |
| // Diagnose conformances in disallowed contexts. |
| if (checkAndDiagnoseDisallowedContext(requirement)) |
| return nullptr; |
| if (requirement->getFullName() == TC.Context.Id_differentiableVectorView) |
| return deriveEuclideanDifferentiable_differentiableVectorView(*this); |
| TC.diagnose(requirement->getLoc(), |
| diag::broken_euclidean_differentiable_requirement); |
| return nullptr; |
| } |