blob: e2ea63a2fbc0cd432d4730866c9c2121c6860a67 [file] [log] [blame]
//===--- DerivedConformanceRingMathProtocols.cpp --------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements explicit derivation of mathematical ring protocols for
// struct types: AdditiveArithmetic and PointwiseMultiplicative.
//
//===----------------------------------------------------------------------===//
#include "CodeSynthesis.h"
#include "TypeChecker.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Expr.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/Module.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/Pattern.h"
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/Types.h"
#include "llvm/Support/ErrorHandling.h"
#include "DerivedConformances.h"
using namespace swift;
// The math operators whose implementations can be synthesized for struct
// types. An operator value is passed through a body synthesizer's opaque
// `void *` context and decoded again, so the numeric values are spelled out.
enum MathOperator {
  // `+(Self, Self)`, required by `AdditiveArithmetic`.
  Add = 0,
  // `-(Self, Self)`, required by `AdditiveArithmetic`.
  Subtract = 1,
  // `.*(Self, Self)`, required by `PointwiseMultiplicative`.
  Multiply = 2
};
// Returns the source spelling of `op`: "+", "-" or ".*".
static StringRef getMathOperatorName(MathOperator op) {
  switch (op) {
  case Add:
    return "+";
  case Subtract:
    return "-";
  case Multiply:
    return ".*";
  }
  // All enumerators are handled above. Without this, control could fall off
  // the end of a non-void function (undefined behavior, -Wreturn-type) if `op`
  // ever held an out-of-range value.
  llvm_unreachable("Unhandled MathOperator");
}
// Returns the protocol that declares `op` as a requirement: `Add` and
// `Subtract` belong to `AdditiveArithmetic`, `Multiply` to
// `PointwiseMultiplicative`.
static KnownProtocolKind getKnownProtocolKind(MathOperator op) {
  switch (op) {
  case Add:
  case Subtract:
    return KnownProtocolKind::AdditiveArithmetic;
  case Multiply:
    return KnownProtocolKind::PointwiseMultiplicative;
  }
  // All enumerators are handled above; avoid falling off the end of a
  // non-void function if `op` ever held an out-of-range value.
  llvm_unreachable("Unhandled MathOperator");
}
// Returns the single protocol requirement of `proto` named `name`.
// Direct lookup may also find same-named declarations that are not
// requirements (e.g. default implementations); only declarations whose
// context is a protocol and that are protocol requirements are kept.
// Asserts that exactly one such requirement exists.
// TODO: Move function to shared place for use with other derived conformances.
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
  llvm::SmallVector<ValueDecl *, 1> requirements;
  for (ValueDecl *candidate : proto->lookupDirect(name)) {
    bool declaredInProtocol = isa<ProtocolDecl>(candidate->getDeclContext());
    if (declaredInProtocol && candidate->isProtocolRequirement())
      requirements.push_back(candidate);
  }
  assert(requirements.size() == 1 && "Ambiguous protocol requirement");
  return requirements.front();
}
// Returns the effective memberwise initializer of `nominal`. When none exists
// yet, one is synthesized with `createMemberwiseImplicitConstructor`, added as
// a member of `nominal`, and registered with the AST context as a synthesized
// declaration.
static ConstructorDecl *getOrCreateEffectiveMemberwiseInitializer(
    TypeChecker &TC, NominalTypeDecl *nominal) {
  if (auto *existing = nominal->getEffectiveMemberwiseInitializer())
    return existing;

  auto *synthesized = createMemberwiseImplicitConstructor(TC, nominal);
  nominal->addMember(synthesized);
  nominal->getASTContext().addSynthesizedDecl(synthesized);
  return synthesized;
}
// Returns true if `nominal` has at least one stored `let` property that is
// declared with an initial value.
// TODO: Move function to shared place for use with other derived conformances.
static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) {
  for (VarDecl *stored : nominal->getStoredProperties()) {
    if (!stored->isLet())
      continue;
    if (stored->hasInitialValue())
      return true;
  }
  return false;
}
// Returns true if the ring protocol `knownProtoKind` can be derived for
// `nominal` in context `DC`. This requires all of the following:
// - `nominal` is a struct;
// - it has no `let` stored property with an initial value;
// - every stored property's type, mapped into `DC`, conforms to that same
//   protocol.
static bool canDeriveRingProtocol(KnownProtocolKind knownProtoKind,
                                  NominalTypeDecl *nominal, DeclContext *DC) {
  // Nominal type must be a struct. (No stored properties is okay.)
  auto *structDecl = dyn_cast<StructDecl>(nominal);
  if (!structDecl)
    return false;
  // Must not have any `let` stored properties with an initial value.
  // - This restriction may be lifted later with support for "true" memberwise
  //   initializers that initialize all stored properties, including initial
  //   value information.
  if (hasLetStoredPropertyWithInitialValue(nominal))
    return false;
  // All stored properties must conform to the protocol being derived
  // (`AdditiveArithmetic` or `PointwiseMultiplicative`).
  auto &C = nominal->getASTContext();
  auto *proto = C.getProtocol(knownProtoKind);
  // `getProtocol` yields null when the protocol declaration is unavailable;
  // nothing can be derived in that case (the derivation code needs it too).
  if (!proto)
    return false;
  return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
    // Resolve the property's signature on demand, if a resolver is installed.
    if (!v->hasInterfaceType())
      if (auto *resolver = C.getLazyResolver())
        resolver->resolveDeclSignature(v);
    if (!v->hasInterfaceType())
      return false;
    auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType());
    return (bool)TypeChecker::conformsToProtocol(varType, proto, DC, None);
  });
}
// Returns true if an `AdditiveArithmetic` conformance can be derived for
// `nominal` in context `DC`.
bool DerivedConformance::canDeriveAdditiveArithmetic(NominalTypeDecl *nominal,
                                                     DeclContext *DC) {
  const auto protoKind = KnownProtocolKind::AdditiveArithmetic;
  return canDeriveRingProtocol(protoKind, nominal, DC);
}
// Returns true if a `PointwiseMultiplicative` conformance can be derived for
// `nominal` in context `DC`.
bool DerivedConformance::canDerivePointwiseMultiplicative(
    NominalTypeDecl *nominal, DeclContext *DC) {
  const auto protoKind = KnownProtocolKind::PointwiseMultiplicative;
  return canDeriveRingProtocol(protoKind, nominal, DC);
}
// Synthesize body for ring math operator `op` on the parent type of
// `funcDecl`. The generated body is:
//
//   return Nominal(member0: lhs.member0 <op> rhs.member0, ...)
//
// Each member operator refers to the concrete witness when the member's
// conformance is concrete and a witness is found. Otherwise it refers to the
// protocol requirement (dynamically dispatched). The memberwise initializer
// must already exist; the derivation entry points create it before the
// operator declaration is synthesized.
// Returns the body paired with `false` (presumably "not yet type-checked" —
// confirm against `AbstractFunctionDecl::BodySynthesizer`).
static std::pair<BraceStmt *, bool>
deriveBodyMathOperator(AbstractFunctionDecl *funcDecl, MathOperator op) {
  auto *parentDC = funcDecl->getParent();
  auto *nominal = parentDC->getSelfNominalTypeDecl();
  auto &C = nominal->getASTContext();
  // Create memberwise initializer: `Nominal.init(...)`.
  auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer();
  assert(memberwiseInitDecl && "Memberwise initializer must exist");
  auto *initDRE =
      new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true);
  initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
  auto *nominalTypeExpr = TypeExpr::createForDecl(SourceLoc(), nominal,
                                                  funcDecl, /*Implicit*/ true);
  auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr);
  // Get operator protocol requirement.
  auto *proto = C.getProtocol(getKnownProtocolKind(op));
  auto operatorId = C.getIdentifier(getMathOperatorName(op));
  auto *operatorReq = getProtocolRequirement(proto, operatorId);
  // Operator parameters: lhs and rhs.
  auto params = funcDecl->getParameters();
  auto *lhsParam = params->get(0);
  auto *rhsParam = params->get(1);
  // Create expression combining lhs and rhs members using member operator.
  auto createMemberOpExpr = [&](VarDecl *member) -> Expr * {
    auto module = nominal->getModuleContext();
    auto memberType =
        parentDC->mapTypeIntoContext(member->getValueInterfaceType());
    auto confRef = module->lookupConformance(memberType, proto);
    assert(confRef && "Member does not conform to math protocol");
    // Get member type's math operator, e.g. `Member.+`.
    // Use protocol requirement declaration for the operator by default: this
    // will be dynamically dispatched.
    ValueDecl *memberOpDecl = operatorReq;
    // If conformance reference is concrete, then use concrete witness
    // declaration for the operator.
    if (confRef->isConcrete())
      if (auto *concreteMemberMethodDecl =
              confRef->getConcrete()->getWitnessDecl(operatorReq))
        memberOpDecl = concreteMemberMethodDecl;
    assert(memberOpDecl && "Member operator declaration must exist");
    auto memberOpDRE =
        new (C) DeclRefExpr(memberOpDecl, DeclNameLoc(), /*Implicit*/ true);
    auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
    auto memberOpExpr =
        new (C) DotSyntaxCallExpr(memberOpDRE, SourceLoc(), memberTypeExpr);
    // Create expression `lhs.member <op> rhs.member`.
    // Fresh `lhs`/`rhs` references are created for every member. An AST
    // expression node must have a single parent. Reusing one `DeclRefExpr` as
    // the base of every member access makes the type checker visit the same
    // node several times, which breaks constraint solving for structs with two
    // or more stored properties.
    auto *lhsDRE =
        new (C) DeclRefExpr(lhsParam, DeclNameLoc(), /*Implicit*/ true);
    auto *rhsDRE =
        new (C) DeclRefExpr(rhsParam, DeclNameLoc(), /*Implicit*/ true);
    Expr *lhsArg = new (C) MemberRefExpr(lhsDRE, SourceLoc(), member,
                                         DeclNameLoc(), /*Implicit*/ true);
    auto *rhsArg = new (C) MemberRefExpr(rhsDRE, SourceLoc(), member,
                                         DeclNameLoc(), /*Implicit*/ true);
    auto *memberOpArgs =
        TupleExpr::create(C, SourceLoc(), {lhsArg, rhsArg}, {}, {}, SourceLoc(),
                          /*HasTrailingClosure*/ false,
                          /*Implicit*/ true);
    auto *memberOpCallExpr =
        new (C) BinaryExpr(memberOpExpr, memberOpArgs, /*Implicit*/ true);
    return memberOpCallExpr;
  };
  // Create array of member operator call expressions.
  llvm::SmallVector<Expr *, 2> memberOpExprs;
  llvm::SmallVector<Identifier, 2> memberNames;
  for (auto member : nominal->getStoredProperties()) {
    memberOpExprs.push_back(createMemberOpExpr(member));
    memberNames.push_back(member->getName());
  }
  // Call memberwise initializer with member operator call expressions.
  auto *callExpr =
      CallExpr::createImplicit(C, initExpr, memberOpExprs, memberNames);
  ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true);
  return std::pair<BraceStmt *, bool>(
      BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true), false);
}
// Synthesizes the declaration of the static operator `op` on the conformance
// context of `derived`:
//
//   static func <op>(lhs: Self, rhs: Self) -> Self
//
// The body is produced lazily by `deriveBodyMathOperator`. The declaration is
// given the context's generic signature and the nominal type's formal access.
// It is added to the conformance context and registered as synthesized.
static ValueDecl *deriveMathOperator(DerivedConformance &derived,
                                     MathOperator op) {
  auto *nominalDecl = derived.Nominal;
  auto *conformanceDC = derived.getConformanceContext();
  auto &ctx = derived.TC.Context;
  auto selfType = conformanceDC->getDeclaredInterfaceType();

  // Both operands are unlabeled parameters of type `Self`.
  auto makeOperand = [&](StringRef paramName) -> ParamDecl * {
    auto *operand = new (ctx)
        ParamDecl(ParamDecl::Specifier::Default, SourceLoc(), SourceLoc(),
                  Identifier(), SourceLoc(), ctx.getIdentifier(paramName),
                  conformanceDC);
    operand->setInterfaceType(selfType);
    return operand;
  };
  ParamDecl *lhsOperand = makeOperand("lhs");
  ParamDecl *rhsOperand = makeOperand("rhs");
  ParameterList *paramList =
      ParameterList::create(ctx, {lhsOperand, rhsOperand});

  DeclName funcName(ctx, ctx.getIdentifier(getMathOperatorName(op)),
                    paramList);
  auto *opFunc =
      FuncDecl::create(ctx, SourceLoc(), StaticSpellingKind::KeywordStatic,
                       SourceLoc(), funcName, SourceLoc(),
                       /*Throws*/ false, SourceLoc(),
                       /*GenericParams=*/nullptr, paramList,
                       TypeLoc::withoutLoc(selfType), conformanceDC);
  opFunc->setImplicit();

  // The operator kind travels through the synthesizer's opaque context pointer
  // and is decoded again when the body is requested.
  opFunc->setBodySynthesizer(
      [](AbstractFunctionDecl *funcDecl,
         void *opaqueCtx) -> std::pair<BraceStmt *, bool> {
        auto decodedOp =
            static_cast<MathOperator>(reinterpret_cast<intptr_t>(opaqueCtx));
        return deriveBodyMathOperator(funcDecl, decodedOp);
      },
      reinterpret_cast<void *>(static_cast<intptr_t>(op)));

  opFunc->setGenericSignature(conformanceDC->getGenericSignatureOfContext());
  opFunc->computeType();
  opFunc->copyFormalAccessFrom(nominalDecl, /*sourceIsParentContext*/ true);
  derived.addMembersToConformanceContext({opFunc});
  ctx.addSynthesizedDecl(opFunc);
  return opFunc;
}
// Synthesize body for a ring property computed property getter. For the
// requirement `reqDecl` of `proto`, the generated body is
//
//   return Nominal(member0: <base0>.<prop>, ...)
//
// where `<base>` is the member's type when `reqDecl` is static, and
// `self.member` otherwise. `<prop>` is the concrete witness when the member's
// conformance is concrete and a witness is found. Otherwise it is the protocol
// requirement (dynamically dispatched), consistent with
// `deriveBodyMathOperator`.
static std::pair<BraceStmt *, bool>
deriveBodyRingPropertyGetter(AbstractFunctionDecl *funcDecl,
                             ProtocolDecl *proto, ValueDecl *reqDecl) {
  auto *parentDC = funcDecl->getParent();
  auto *nominal = parentDC->getSelfNominalTypeDecl();
  auto &C = nominal->getASTContext();
  auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer();
  assert(memberwiseInitDecl && "Memberwise initializer must exist");
  auto *initDRE =
      new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true);
  initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
  auto *nominalTypeExpr = TypeExpr::createForDecl(SourceLoc(), nominal,
                                                  funcDecl, /*Implicit*/ true);
  auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr);
  auto createMemberRingPropertyExpr = [&](VarDecl *member) -> Expr * {
    auto memberType =
        parentDC->mapTypeIntoContext(member->getValueInterfaceType());
    Expr *memberExpr = nullptr;
    if (reqDecl->isStatic()) {
      // Static property: the base is a type expression, `Member`.
      memberExpr = TypeExpr::createImplicit(memberType, C);
    } else {
      // Instance property: the base is a member ref expression, `self.member`.
      auto *selfDecl = funcDecl->getImplicitSelfDecl();
      auto *selfDRE =
          new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
      memberExpr =
          new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
                                /*Implicit*/ true);
    }
    auto module = nominal->getModuleContext();
    auto confRef = module->lookupConformance(memberType, proto);
    assert(confRef && "Member does not conform to ring protocol");
    // Reference the protocol requirement by default: this will be dynamically
    // dispatched. If the conformance is concrete and a witness is found, use
    // the concrete witness declaration instead. `getWitnessDecl` may return
    // null; a null declaration must never reach `MemberRefExpr`.
    ValueDecl *memberPropDecl = reqDecl;
    if (confRef->isConcrete())
      if (auto *witnessDecl = confRef->getConcrete()->getWitnessDecl(reqDecl))
        memberPropDecl = witnessDecl;
    return new (C) MemberRefExpr(memberExpr, SourceLoc(), memberPropDecl,
                                 DeclNameLoc(), /*Implicit*/ true);
  };
  // Create array of `member.<ring property>` expressions.
  llvm::SmallVector<Expr *, 2> memberPropExprs;
  llvm::SmallVector<Identifier, 2> memberNames;
  for (auto member : nominal->getStoredProperties()) {
    memberPropExprs.push_back(createMemberRingPropertyExpr(member));
    memberNames.push_back(member->getName());
  }
  // Call memberwise initializer with member ring property expressions.
  auto *callExpr =
      CallExpr::createImplicit(C, initExpr, memberPropExprs, memberNames);
  ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true);
  auto *braceStmt =
      BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true);
  return std::pair<BraceStmt *, bool>(braceStmt, false);
}
// Body synthesizer for the getter of the derived `AdditiveArithmetic.zero`
// property. The opaque context argument is unused.
static std::pair<BraceStmt *, bool>
deriveBodyAdditiveArithmetic_zero(AbstractFunctionDecl *funcDecl, void *) {
  auto &ctx = funcDecl->getASTContext();
  auto *additiveArithmetic =
      ctx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
  auto *zeroRequirement =
      getProtocolRequirement(additiveArithmetic, ctx.Id_zero);
  return deriveBodyRingPropertyGetter(funcDecl, additiveArithmetic,
                                      zeroRequirement);
}
// Body synthesizer for the getter of the derived `PointwiseMultiplicative.one`
// property. The opaque context argument is unused.
static std::pair<BraceStmt *, bool>
deriveBodyPointwiseMultiplicative_one(AbstractFunctionDecl *funcDecl, void *) {
  auto &ctx = funcDecl->getASTContext();
  auto *pointwiseMultiplicative =
      ctx.getProtocol(KnownProtocolKind::PointwiseMultiplicative);
  auto *oneRequirement =
      getProtocolRequirement(pointwiseMultiplicative, ctx.Id_one);
  return deriveBodyRingPropertyGetter(funcDecl, pointwiseMultiplicative,
                                      oneRequirement);
}
// Body synthesizer for the getter of the derived
// `PointwiseMultiplicative.reciprocal` property. The opaque context argument
// is unused.
static std::pair<BraceStmt *, bool>
deriveBodyPointwiseMultiplicative_reciprocal(AbstractFunctionDecl *funcDecl,
                                             void *) {
  auto &ctx = funcDecl->getASTContext();
  auto *pointwiseMultiplicative =
      ctx.getProtocol(KnownProtocolKind::PointwiseMultiplicative);
  auto *reciprocalRequirement =
      getProtocolRequirement(pointwiseMultiplicative, ctx.Id_reciprocal);
  return deriveBodyRingPropertyGetter(funcDecl, pointwiseMultiplicative,
                                      reciprocalRequirement);
}
// Synthesizes a final property named `propertyName` whose type is the
// conforming nominal type itself (static when `isStatic` is true). A getter is
// attached via `addGetterToReadOnlyDerivedProperty`, with its body produced
// lazily by `bodySynthesizer`. The property and its pattern binding are then
// added to the conformance context.
static ValueDecl *
deriveRingProperty(DerivedConformance &derived, Identifier propertyName,
                   bool isStatic,
                   AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
  auto *conformanceDC = derived.getConformanceContext();
  auto selfInterfaceTy = derived.Nominal->getDeclaredInterfaceType();
  auto selfContextTy = conformanceDC->mapTypeIntoContext(selfInterfaceTy);

  VarDecl *property = nullptr;
  PatternBindingDecl *binding = nullptr;
  std::tie(property, binding) = derived.declareDerivedProperty(
      propertyName, selfInterfaceTy, selfContextTy, isStatic,
      /*isFinal*/ true);

  auto *getter =
      derived.addGetterToReadOnlyDerivedProperty(property, selfContextTy);
  getter->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context);

  derived.addMembersToConformanceContext({property, binding});
  return property;
}
// Synthesizes the static `zero` property required by `AdditiveArithmetic`.
static ValueDecl *deriveAdditiveArithmetic_zero(DerivedConformance &derived) {
  AbstractFunctionDecl::BodySynthesizer getterBody = {
      deriveBodyAdditiveArithmetic_zero, nullptr};
  return deriveRingProperty(derived, derived.TC.Context.Id_zero,
                            /*isStatic*/ true, getterBody);
}
// Synthesizes the static `one` property required by `PointwiseMultiplicative`.
static ValueDecl *
derivePointwiseMultiplicative_one(DerivedConformance &derived) {
  AbstractFunctionDecl::BodySynthesizer getterBody = {
      deriveBodyPointwiseMultiplicative_one, nullptr};
  return deriveRingProperty(derived, derived.TC.Context.Id_one,
                            /*isStatic*/ true, getterBody);
}
// Synthesizes the instance `reciprocal` property required by
// `PointwiseMultiplicative`.
static ValueDecl *
derivePointwiseMultiplicative_reciprocal(DerivedConformance &derived) {
  AbstractFunctionDecl::BodySynthesizer getterBody = {
      deriveBodyPointwiseMultiplicative_reciprocal, nullptr};
  return deriveRingProperty(derived, derived.TC.Context.Id_reciprocal,
                            /*isStatic*/ false, getterBody);
}
// Derives the `AdditiveArithmetic` requirement `requirement` (`+`, `-` or
// `zero`) for `Nominal`. Returns null after emitting a diagnostic when the
// context is disallowed or the requirement is not one of these.
ValueDecl *
DerivedConformance::deriveAdditiveArithmetic(ValueDecl *requirement) {
  // Diagnose conformances in disallowed contexts.
  if (checkAndDiagnoseDisallowedContext(requirement))
    return nullptr;
  // Every synthesized body calls the memberwise initializer, so make sure it
  // exists before dispatching on the requirement.
  getOrCreateEffectiveMemberwiseInitializer(TC, Nominal);

  auto &ctx = TC.Context;
  auto reqName = requirement->getBaseName();
  if (reqName == ctx.getIdentifier("+"))
    return deriveMathOperator(*this, Add);
  if (reqName == ctx.getIdentifier("-"))
    return deriveMathOperator(*this, Subtract);
  if (reqName == ctx.Id_zero)
    return deriveAdditiveArithmetic_zero(*this);

  // Unknown requirement: the protocol declaration does not match this deriver.
  TC.diagnose(requirement->getLoc(),
              diag::broken_additive_arithmetic_requirement);
  return nullptr;
}
// Derives the `PointwiseMultiplicative` requirement `requirement` (`.*`, `one`
// or `reciprocal`) for `Nominal`. Returns null after emitting a diagnostic when
// the context is disallowed or the requirement is not one of these.
ValueDecl *
DerivedConformance::derivePointwiseMultiplicative(ValueDecl *requirement) {
  // Diagnose conformances in disallowed contexts.
  if (checkAndDiagnoseDisallowedContext(requirement))
    return nullptr;
  // Every synthesized body calls the memberwise initializer, so make sure it
  // exists before dispatching on the requirement.
  getOrCreateEffectiveMemberwiseInitializer(TC, Nominal);

  auto &ctx = TC.Context;
  auto reqName = requirement->getBaseName();
  if (reqName == ctx.getIdentifier(".*"))
    return deriveMathOperator(*this, Multiply);
  if (reqName == ctx.Id_one)
    return derivePointwiseMultiplicative_one(*this);
  if (reqName == ctx.Id_reciprocal)
    return derivePointwiseMultiplicative_reciprocal(*this);

  // Unknown requirement: the protocol declaration does not match this deriver.
  TC.diagnose(requirement->getLoc(),
              diag::broken_pointwise_multiplicative_requirement);
  return nullptr;
}