| //===--- DerivedConformanceElementaryFunctions.cpp ------------------------===// |
| // |
| // This source file is part of the Swift.org open source project |
| // |
| // Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors |
| // Licensed under Apache License v2.0 with Runtime Library Exception |
| // |
| // See https://swift.org/LICENSE.txt for license information |
| // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements explicit derivation of the ElementaryFunctions protocol |
| // for struct types. |
| // |
| //===----------------------------------------------------------------------===// |
| |
#include "CodeSynthesis.h"
#include "TypeChecker.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Expr.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/Module.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/Pattern.h"
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/Types.h"
#include "llvm/Support/ErrorHandling.h"
#include "DerivedConformances.h"
| |
| using namespace swift; |
| |
| // Represents synthesizable `ElementaryFunction` protocol requirements. |
| enum ElementaryFunction { |
| #define ELEMENTARY_FUNCTION(ID, NAME) ID, |
| #include "DerivedConformanceElementaryFunctions.def" |
| #undef ELEMENTARY_FUNCTION |
| }; |
| |
| static StringRef getElementaryFunctionName(ElementaryFunction op) { |
| switch (op) { |
| #define ELEMENTARY_FUNCTION(ID, NAME) case ElementaryFunction::ID: return NAME; |
| #include "DerivedConformanceElementaryFunctions.def" |
| #undef ELEMENTARY_FUNCTION |
| } |
| } |
| |
| // Return the protocol requirement with the specified name. |
| // TODO: Move function to shared place for use with other derived conformances. |
| static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) { |
| auto lookup = proto->lookupDirect(name); |
| llvm::erase_if(lookup, [](ValueDecl *v) { |
| return !isa<ProtocolDecl>(v->getDeclContext()) || |
| !v->isProtocolRequirement(); |
| }); |
| assert(lookup.size() == 1 && "Ambiguous protocol requirement"); |
| return lookup.front(); |
| } |
| |
| // Return true if given nominal type has a `let` stored with an initial value. |
| // TODO: Move function to shared place for use with other derived conformances. |
| static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) { |
| return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) { |
| return v->isLet() && v->hasInitialValue(); |
| }); |
| } |
| |
| // Return the `ElementaryFunction` protocol requirement corresponding to the |
| // given elementary function. |
| static ValueDecl *getElementaryFunctionRequirement( |
| ASTContext &C, ElementaryFunction op) { |
| auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions); |
| auto operatorId = C.getIdentifier(getElementaryFunctionName(op)); |
| switch (op) { |
| #define ELEMENTARY_FUNCTION_UNARY(ID, NAME) \ |
| case ID: \ |
| return getProtocolRequirement(mathProto, operatorId); |
| #include "DerivedConformanceElementaryFunctions.def" |
| #undef ELEMENTARY_FUNCTION_UNARY |
| case Root: |
| return getProtocolRequirement(mathProto, operatorId); |
| case Pow: |
| case PowInt: |
| auto lookup = mathProto->lookupDirect(operatorId); |
| lookup.erase(std::remove_if(lookup.begin(), lookup.end(), |
| [](ValueDecl *v) { |
| return !isa<ProtocolDecl>( |
| v->getDeclContext()) || |
| !v->isProtocolRequirement(); |
| }), |
| lookup.end()); |
| assert(lookup.size() == 2 && "Expected two 'pow' functions"); |
| auto *powFuncDecl = cast<FuncDecl>(lookup.front()); |
| auto secondParamType = |
| powFuncDecl->getParameters()->get(1)->getInterfaceType(); |
| if (secondParamType->getAnyNominal() == C.getIntDecl()) |
| return op == PowInt ? lookup.front() : lookup[1]; |
| else |
| return op == PowInt ? lookup[1] : lookup.front(); |
| } |
| } |
| |
| // Get the effective memberwise initializer of the given nominal type, or create |
| // it if it does not exist. |
| static ConstructorDecl *getOrCreateEffectiveMemberwiseInitializer( |
| TypeChecker &TC, NominalTypeDecl *nominal) { |
| auto &C = nominal->getASTContext(); |
| if (auto *initDecl = nominal->getEffectiveMemberwiseInitializer()) |
| return initDecl; |
| auto *initDecl = createMemberwiseImplicitConstructor(TC, nominal); |
| nominal->addMember(initDecl); |
| C.addSynthesizedDecl(initDecl); |
| return initDecl; |
| } |
| |
| bool DerivedConformance::canDeriveElementaryFunctions(NominalTypeDecl *nominal, |
| DeclContext *DC) { |
| // Nominal type must be a struct. (Zero stored properties is okay.) |
| auto *structDecl = dyn_cast<StructDecl>(nominal); |
| if (!structDecl) |
| return false; |
| // Must not have any `let` stored properties with an initial value. |
| // - This restriction may be lifted later with support for "true" memberwise |
| // initializers that initialize all stored properties, including initial |
| // value information. |
| if (hasLetStoredPropertyWithInitialValue(nominal)) |
| return false; |
| // All stored properties must conform to `ElementaryFunctions`. |
| auto &C = nominal->getASTContext(); |
| auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions); |
| return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) { |
| if (!v->hasInterfaceType()) |
| C.getLazyResolver()->resolveDeclSignature(v); |
| if (!v->hasInterfaceType()) |
| return false; |
| auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType()); |
| return (bool)TypeChecker::conformsToProtocol(varType, mathProto, DC, None); |
| }); |
| } |
| |
| // Synthesize body for the given `ElementaryFunction` protocol requirement. |
| static std::pair<BraceStmt *, bool> |
| deriveBodyElementaryFunction(AbstractFunctionDecl *funcDecl, |
| ElementaryFunction op) { |
| auto *parentDC = funcDecl->getParent(); |
| auto *nominal = parentDC->getSelfNominalTypeDecl(); |
| auto &C = nominal->getASTContext(); |
| |
| // Create memberwise initializer: `Nominal.init(...)`. |
| auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer(); |
| assert(memberwiseInitDecl && "Memberwise initializer must exist"); |
| auto *initDRE = |
| new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true); |
| initDRE->setFunctionRefKind(FunctionRefKind::SingleApply); |
| auto *nominalTypeExpr = TypeExpr::createForDecl(SourceLoc(), nominal, |
| funcDecl, /*Implicit*/ true); |
| auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr); |
| |
| // Get operator protocol requirement. |
| auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions); |
| auto *operatorReq = getElementaryFunctionRequirement(C, op); |
| |
| // Create reference(s) to operator parameters: one for unary functions and two |
| // for binary functions. |
| auto params = funcDecl->getParameters(); |
| auto *firstParamDRE = |
| new (C) DeclRefExpr(params->get(0), DeclNameLoc(), /*Implicit*/ true); |
| Expr *secondParamDRE = nullptr; |
| if (params->size() == 2) |
| secondParamDRE = |
| new (C) DeclRefExpr(params->get(1), DeclNameLoc(), /*Implicit*/ true); |
| |
| // Create call expression combining lhs and rhs members using member operator. |
| auto createMemberOpCallExpr = [&](VarDecl *member) -> Expr * { |
| auto module = nominal->getModuleContext(); |
| auto memberType = |
| parentDC->mapTypeIntoContext(member->getValueInterfaceType()); |
| auto confRef = module->lookupConformance(memberType, mathProto); |
| assert(confRef && "Member does not conform to math protocol"); |
| |
| // Get member type's elementary function, e.g. `Member.cos`. |
| // Use protocol requirement declaration for the operator by default: this |
| // will be dynamically dispatched. |
| ValueDecl *memberOpDecl = operatorReq; |
| // If conformance reference is concrete, then use concrete witness |
| // declaration for the operator. |
| if (confRef->isConcrete()) |
| memberOpDecl = confRef->getConcrete()->getWitnessDecl( |
| operatorReq); |
| assert(memberOpDecl && "Member operator declaration must exist"); |
| auto memberOpDRE = |
| new (C) DeclRefExpr(memberOpDecl, DeclNameLoc(), /*Implicit*/ true); |
| auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C); |
| auto memberOpExpr = |
| new (C) DotSyntaxCallExpr(memberOpDRE, SourceLoc(), memberTypeExpr); |
| |
| // - For unary ops, create expression: |
| // `<op>(x.member)`. |
| // - For `pow(_ x: Self, _ y: Self)`, create expression: |
| // `<op>(x.member, y.member)`. |
| // - For `pow(_ x: Self, _ n: Int)` and `root(_ x: Self, n: Int)`, create: |
| // `<op>(x.member, n)`. |
| Expr *firstArg = new (C) MemberRefExpr(firstParamDRE, SourceLoc(), member, |
| DeclNameLoc(), /*Implicit*/ true); |
| Expr *secondArg = nullptr; |
| if (secondParamDRE) { |
| if (op == PowInt || op == Root) |
| secondArg = secondParamDRE; |
| else |
| secondArg = new (C) MemberRefExpr(secondParamDRE, SourceLoc(), member, |
| DeclNameLoc(), /*Implicit*/ true); |
| } |
| SmallVector<Expr *, 2> memberOpArgs{firstArg}; |
| if (secondArg) |
| memberOpArgs.push_back(secondArg); |
| SmallVector<Identifier, 2> memberOpArgLabels(memberOpArgs.size()); |
| auto *memberOpCallExpr = CallExpr::createImplicit( |
| C, memberOpExpr, memberOpArgs, memberOpArgLabels); |
| return memberOpCallExpr; |
| }; |
| |
| // Create array of member operator call expressions. |
| llvm::SmallVector<Expr *, 2> memberOpCallExprs; |
| llvm::SmallVector<Identifier, 2> memberNames; |
| for (auto member : nominal->getStoredProperties()) { |
| memberOpCallExprs.push_back(createMemberOpCallExpr(member)); |
| memberNames.push_back(member->getName()); |
| } |
| // Call memberwise initializer with member operator call expressions. |
| auto *callExpr = |
| CallExpr::createImplicit(C, initExpr, memberOpCallExprs, memberNames); |
| ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true); |
| auto* braceStmt = BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true); |
| return std::pair<BraceStmt *, bool>(braceStmt, false); |
| } |
| |
| #define ELEMENTARY_FUNCTION(ID, NAME) \ |
| static std::pair<BraceStmt *, bool> deriveBodyElementaryFunctions_##ID( \ |
| AbstractFunctionDecl *funcDecl, void *) { \ |
| return deriveBodyElementaryFunction(funcDecl, ID); \ |
| } |
| #include "DerivedConformanceElementaryFunctions.def" |
| #undef ELEMENTARY_FUNCTION |
| |
| // Synthesize function declaration for the given math operator. |
| static ValueDecl *deriveElementaryFunction(DerivedConformance &derived, |
| ElementaryFunction op) { |
| auto nominal = derived.Nominal; |
| auto parentDC = derived.getConformanceContext(); |
| auto &C = derived.TC.Context; |
| auto selfInterfaceType = parentDC->getDeclaredInterfaceType(); |
| |
| // Create parameter declaration with the given name and type. |
| auto createParamDecl = [&](StringRef name, Type type) -> ParamDecl * { |
| auto *param = new (C) |
| ParamDecl(ParamDecl::Specifier::Default, SourceLoc(), SourceLoc(), |
| Identifier(), SourceLoc(), C.getIdentifier(name), parentDC); |
| param->setInterfaceType(type); |
| return param; |
| }; |
| |
| ParameterList *params = nullptr; |
| |
| switch (op) { |
| #define ELEMENTARY_FUNCTION_UNARY(ID, NAME) \ |
| case ID: \ |
| params = \ |
| ParameterList::create(C, {createParamDecl("x", selfInterfaceType)}); \ |
| break; |
| #include "DerivedConformanceElementaryFunctions.def" |
| #undef ELEMENTARY_FUNCTION_UNARY |
| case Pow: |
| params = |
| ParameterList::create(C, {createParamDecl("x", selfInterfaceType), |
| createParamDecl("y", selfInterfaceType)}); |
| break; |
| case PowInt: |
| case Root: |
| params = ParameterList::create( |
| C, {createParamDecl("x", selfInterfaceType), |
| createParamDecl("n", C.getIntDecl()->getDeclaredInterfaceType())}); |
| break; |
| } |
| |
| auto operatorId = C.getIdentifier(getElementaryFunctionName(op)); |
| DeclName operatorDeclName(C, operatorId, params); |
| auto operatorDecl = |
| FuncDecl::create(C, SourceLoc(), StaticSpellingKind::KeywordStatic, |
| SourceLoc(), operatorDeclName, SourceLoc(), |
| /*Throws*/ false, SourceLoc(), |
| /*GenericParams*/ nullptr, params, |
| TypeLoc::withoutLoc(selfInterfaceType), parentDC); |
| operatorDecl->setImplicit(); |
| switch (op) { |
| #define ELEMENTARY_FUNCTION(ID, NAME) \ |
| case ID: \ |
| operatorDecl->setBodySynthesizer(deriveBodyElementaryFunctions_##ID, \ |
| nullptr); \ |
| break; |
| #include "DerivedConformanceElementaryFunctions.def" |
| #undef ELEMENTARY_FUNCTION |
| } |
| operatorDecl->setGenericSignature(parentDC->getGenericSignatureOfContext()); |
| operatorDecl->computeType(); |
| operatorDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); |
| |
| derived.addMembersToConformanceContext({operatorDecl}); |
| C.addSynthesizedDecl(operatorDecl); |
| |
| return operatorDecl; |
| } |
| |
| ValueDecl * |
| DerivedConformance::deriveElementaryFunctions(ValueDecl *requirement) { |
| // Diagnose conformances in disallowed contexts. |
| if (checkAndDiagnoseDisallowedContext(requirement)) |
| return nullptr; |
| // Create memberwise initializer for nominal type if it doesn't already exist. |
| getOrCreateEffectiveMemberwiseInitializer(TC, Nominal); |
| #define ELEMENTARY_FUNCTION_UNARY(ID, NAME) \ |
| if (requirement->getBaseName() == TC.Context.getIdentifier(NAME)) \ |
| return deriveElementaryFunction(*this, ID); |
| #include "DerivedConformanceElementaryFunctions.def" |
| #undef ELEMENTARY_FUNCTION_UNARY |
| if (requirement->getBaseName() == TC.Context.getIdentifier("root")) |
| return deriveElementaryFunction(*this, Root); |
| if (requirement->getBaseName() == TC.Context.getIdentifier("pow")) { |
| auto *powFuncDecl = cast<FuncDecl>(requirement); |
| return powFuncDecl->getParameters()->get(1)->getName().str() == "n" |
| ? deriveElementaryFunction(*this, PowInt) |
| : deriveElementaryFunction(*this, Pow); |
| } |
| TC.diagnose(requirement->getLoc(), |
| diag::broken_elementary_functions_requirement); |
| return nullptr; |
| } |