blob: a1b6ad757fd90457702f92f41e4c3cc2fb406fbc [file] [log] [blame]
//===--- DerivedConformanceVectorProtocol.cpp -----------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements explicit derivation of the VectorProtocol protocol for
// struct types.
//
//===----------------------------------------------------------------------===//
#include "CodeSynthesis.h"
#include "TypeChecker.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Expr.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/Module.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/Pattern.h"
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/Types.h"
#include "DerivedConformances.h"
using namespace swift;
// Look up the single protocol requirement of `proto` whose name is `name`.
//
// Only declarations that live directly inside a protocol and are genuine
// requirements survive the filter. Same-named default implementations (whose
// decl context is not the protocol itself) are discarded. Exactly one
// requirement is expected to remain.
// TODO: Move function to shared place for use with other derived conformances.
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
  auto candidates = proto->lookupDirect(name);
  llvm::erase_if(candidates, [](ValueDecl *decl) {
    bool declaredInProtocol = isa<ProtocolDecl>(decl->getDeclContext());
    return !(declaredInProtocol && decl->isProtocolRequirement());
  });
  assert(candidates.size() == 1 && "Ambiguous protocol requirement");
  return candidates.front();
}
// Report whether `nominal` declares at least one stored `let` property that
// also has an initial value (e.g. `let x: Float = 1`).
// TODO: Move function to shared place for use with other derived conformances.
static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) {
  for (auto *property : nominal->getStoredProperties())
    if (property->isLet() && property->hasInitialValue())
      return true;
  return false;
}
// Compute the `VectorSpaceScalar` type witness of the stored property
// `varDecl`, as seen from `DC`.
//
// The property's interface type is resolved on demand. Returns `nullptr` when
// that type is still unavailable after resolution, or when the property's
// contextual type does not conform to `VectorProtocol` in `DC`.
static Type getVectorProtocolVectorSpaceScalarAssocType(
    VarDecl *varDecl, DeclContext *DC) {
  auto &ctx = varDecl->getASTContext();
  auto *vectorProtocol = ctx.getProtocol(KnownProtocolKind::VectorProtocol);
  if (!varDecl->hasInterfaceType()) {
    ctx.getLazyResolver()->resolveDeclSignature(varDecl);
    if (!varDecl->hasInterfaceType())
      return nullptr;
  }
  auto contextualType =
      DC->mapTypeIntoContext(varDecl->getValueInterfaceType());
  if (auto conformance = TypeChecker::conformsToProtocol(
          contextualType, vectorProtocol, DC, None))
    return conformance->getTypeWitnessByName(contextualType,
                                             ctx.Id_VectorSpaceScalar);
  return nullptr;
}
// Return the `VectorSpaceScalar` associated type for the given nominal type in
// the given context, or `nullptr` if `VectorSpaceScalar` cannot be derived.
//
// Derivation succeeds only when `nominal` is a struct with at least one stored
// property, and every stored property conforms to `VectorProtocol` with the
// same `VectorSpaceScalar` type. A struct with zero stored properties yields
// `nullptr`, because there is no member to infer the scalar type from.
static Type deriveVectorProtocol_VectorSpaceScalar(NominalTypeDecl *nominal,
                                                   DeclContext *DC) {
  // Nominal type must be a struct.
  if (!isa<StructDecl>(nominal))
    return nullptr;
  Type sameScalarType;
  for (auto member : nominal->getStoredProperties()) {
    // The helper resolves the member's interface type on demand. It returns
    // `nullptr` if that type is unavailable or if the member does not conform
    // to `VectorProtocol`.
    auto scalarType = getVectorProtocolVectorSpaceScalarAssocType(member, DC);
    if (!scalarType)
      return nullptr;
    // The first stored property fixes the candidate scalar type.
    if (!sameScalarType) {
      sameScalarType = scalarType;
      continue;
    }
    // Every stored property must agree on `VectorSpaceScalar`.
    if (!scalarType->isEqual(sameScalarType))
      return nullptr;
  }
  // Still null here when the struct has no stored properties.
  return sameScalarType;
}
// Decide whether a `VectorProtocol` conformance can be derived for `nominal`
// in context `DC`. Both conditions below must hold.
//
// 1. The type has no `let` stored property with an initial value. This
//    restriction may be lifted later with support for "true" memberwise
//    initializers that initialize all stored properties, including initial
//    value information.
// 2. A `VectorSpaceScalar` associated type can be derived from its members.
bool DerivedConformance::canDeriveVectorProtocol(NominalTypeDecl *nominal,
                                                 DeclContext *DC) {
  const bool hasInitializedLet = hasLetStoredPropertyWithInitialValue(nominal);
  return !hasInitializedLet &&
         static_cast<bool>(deriveVectorProtocol_VectorSpaceScalar(nominal, DC));
}
// Synthesize the body of a derived `VectorProtocol` method requirement.
//
// For a struct with stored properties `a`, `b`, ..., the generated body is:
//   return Nominal(a: self.a.method(label: param),
//                  b: self.b.method(label: param), ...)
// Here `method` is the requirement named `methodName`, and `label` is
// `methodParamLabel`, which is empty for unlabeled arguments. `param` is the
// first parameter of `funcDecl`. A member call targets the member type's
// concrete witness when its conformance is concrete. Otherwise it targets the
// protocol requirement itself.
//
// Returns the body paired with `false`.
// NOTE(review): the `false` presumably tells the caller that the body still
// needs type-checking. Confirm against `AbstractFunctionDecl::BodySynthesizer`.
static std::pair<BraceStmt *, bool>
deriveBodyVectorProtocol_method(AbstractFunctionDecl *funcDecl,
                                Identifier methodName,
                                Identifier methodParamLabel) {
  auto *parentDC = funcDecl->getParent();
  auto *nominal = parentDC->getSelfNominalTypeDecl();
  auto &C = nominal->getASTContext();
  // Create memberwise initializer: `Nominal.init(...)`.
  // `deriveVectorProtocol_method` creates one when the type lacks it.
  auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer();
  assert(memberwiseInitDecl && "Memberwise initializer must exist");
  auto *initDRE =
      new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true);
  initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
  auto *nominalTypeExpr = TypeExpr::createForDecl(SourceLoc(), nominal,
                                                  funcDecl, /*Implicit*/ true);
  auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr);
  // Get method protocol requirement.
  auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
  auto *methodReq = getProtocolRequirement(vectorProto, methodName);
  // Get references to `self` and the (single) parameter declaration.
  auto *selfDecl = funcDecl->getImplicitSelfDecl();
  auto *selfDRE =
      new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
  auto *paramDecl = funcDecl->getParameters()->get(0);
  auto *paramDRE =
      new (C) DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true);
  // Create call expression applying a member method to the parameter.
  // Format: `<member>.method(<parameter>)`.
  // Example: `x.scaled(by: scalar)`.
  auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * {
    auto *module = nominal->getModuleContext();
    auto memberType =
        parentDC->mapTypeIntoContext(member->getValueInterfaceType());
    auto confRef = module->lookupConformance(memberType, vectorProto);
    assert(confRef && "Member does not conform to `VectorProtocol`");
    // Get member type's method, e.g. `Member.scaled(by:)`.
    // Use protocol requirement declaration for the method by default: this
    // will be dynamically dispatched.
    ValueDecl *memberMethodDecl = methodReq;
    // If conformance reference is concrete, then use concrete witness
    // declaration for the method.
    if (confRef->isConcrete()) {
      if (auto *concreteMemberMethodDecl =
              confRef->getConcrete()->getWitnessDecl(methodReq))
        memberMethodDecl = concreteMemberMethodDecl;
    }
    assert(memberMethodDecl && "Member method declaration must exist");
    auto memberMethodDRE =
        new (C) DeclRefExpr(memberMethodDecl, DeclNameLoc(), /*Implicit*/ true);
    memberMethodDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
    // Create reference to member method: `x.scaled(by:)`.
    auto memberExpr =
        new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
                              /*Implicit*/ true);
    auto memberMethodExpr =
        new (C) DotSyntaxCallExpr(memberMethodDRE, SourceLoc(), memberExpr);
    // Create expression: `x.scaled(by: scalar)`.
    return CallExpr::createImplicit(C, memberMethodExpr, {paramDRE},
                                    {methodParamLabel});
  };
  // Create one member method call per stored property. The property names
  // double as argument labels for the memberwise initializer.
  llvm::SmallVector<Expr *, 2> memberMethodCallExprs;
  llvm::SmallVector<Identifier, 2> memberNames;
  for (auto *member : nominal->getStoredProperties()) {
    memberMethodCallExprs.push_back(createMemberMethodCallExpr(member));
    memberNames.push_back(member->getName());
  }
  // Call memberwise initializer with member method call expressions.
  auto *callExpr =
      CallExpr::createImplicit(C, initExpr, memberMethodCallExprs, memberNames);
  ASTNode returnStmt =
      new (C) ReturnStmt(SourceLoc(), callExpr, /*Implicit*/ true);
  auto *braceStmt =
      BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(),
                        /*Implicit*/ true);
  return std::pair<BraceStmt *, bool>(braceStmt, false);
}
// Synthesize function declaration for a `VectorProtocol` method requirement.
//
// Builds an implicit, non-static, non-throwing, non-generic method named
// `methodBaseName`. It takes one parameter (`argumentLabel parameterName:
// parameterType`) and returns `returnType`. The method is added to the
// conformance context and registered as a synthesized declaration. Its body
// is produced later by `bodySynthesizer`. The method inherits the generic
// signature of the conformance context and the formal access of the nominal
// type.
//
// Side effect: if the nominal type has no effective memberwise initializer,
// one is created and added to it, because the synthesized bodies call it.
// Returns the new function declaration.
static ValueDecl *deriveVectorProtocol_method(
DerivedConformance &derived, Identifier methodBaseName,
Identifier argumentLabel, Identifier parameterName, Type parameterType,
Type returnType, AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
auto nominal = derived.Nominal;
auto &TC = derived.TC;
auto &C = derived.TC.Context;
auto parentDC = derived.getConformanceContext();
// Build the single parameter: `argumentLabel parameterName: parameterType`.
auto *param =
new (C) ParamDecl(ParamDecl::Specifier::Default, SourceLoc(), SourceLoc(),
argumentLabel, SourceLoc(), parameterName, parentDC);
param->setInterfaceType(parameterType);
ParameterList *params = ParameterList::create(C, {param});
DeclName declName(C, methodBaseName, params);
auto funcDecl = FuncDecl::create(C, SourceLoc(), StaticSpellingKind::None,
SourceLoc(), declName, SourceLoc(),
/*Throws*/ false, SourceLoc(),
/*GenericParams*/ nullptr, params,
TypeLoc::withoutLoc(returnType), parentDC);
// Mark the declaration implicit and attach the deferred body synthesizer.
// Then compute its type under the conformance context's generic signature.
funcDecl->setImplicit();
funcDecl->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context);
funcDecl->setGenericSignature(parentDC->getGenericSignatureOfContext());
funcDecl->computeType();
funcDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
derived.addMembersToConformanceContext({funcDecl});
C.addSynthesizedDecl(funcDecl);
// Returned nominal type must define a memberwise initializer.
// Add memberwise initializer if necessary.
if (!nominal->getEffectiveMemberwiseInitializer()) {
// The implicit memberwise constructor must be explicitly created so that
// it can be called in `VectorProtocol` methods. Normally, the memberwise
// constructor is synthesized during SILGen, which is too late.
auto *initDecl = createMemberwiseImplicitConstructor(TC, nominal);
nominal->addMember(initDecl);
C.addSynthesizedDecl(initDecl);
}
return funcDecl;
}
/// Synthesize a method declaration that has the following signature:
///   func {methodBaseName}(
///     {argumentLabel} {parameterName}: VectorSpaceScalar
///   ) -> Self
///
/// The body is produced by a body synthesizer that forwards to
/// `deriveBodyVectorProtocol_method` with `methodBaseName` and
/// `argumentLabel`.
static ValueDecl *deriveVectorProtocol_unaryMethodOnScalar(
    DerivedConformance &derived, Identifier methodBaseName,
    Identifier argumentLabel, Identifier parameterName) {
  auto &C = derived.TC.Context;
  auto *nominal = derived.Nominal;
  auto *parentDC = derived.getConformanceContext();
  auto selfInterfaceType = parentDC->getDeclaredInterfaceType();
  // `deriveVectorProtocol_VectorSpaceScalar` returns a null `Type` when no
  // scalar type can be derived. Presumably `canDeriveVectorProtocol` rules
  // that out before derivation. Assert instead of dereferencing a null `Type`.
  auto scalarType = deriveVectorProtocol_VectorSpaceScalar(nominal, parentDC);
  assert(scalarType && "`VectorSpaceScalar` must be derivable");
  scalarType = scalarType->mapTypeOutOfContext();
  auto bodySynthesizer = [](AbstractFunctionDecl *funcDecl,
                            void *ctx) -> std::pair<BraceStmt *, bool> {
    // `ctx` points at {methodBaseName, argumentLabel} (see below).
    auto *methodNameAndLabel = static_cast<Identifier *>(ctx);
    return deriveBodyVectorProtocol_method(
        funcDecl, methodNameAndLabel[0], methodNameAndLabel[1]);
  };
  // Copy the identifiers into ASTContext-allocated storage so that the
  // synthesizer context outlives this stack frame.
  Identifier baseNameAndLabel[2] = {methodBaseName, argumentLabel};
  return deriveVectorProtocol_method(
      derived, methodBaseName, argumentLabel, parameterName, scalarType,
      selfInterfaceType,
      {bodySynthesizer, C.AllocateCopy(baseNameAndLabel).data()});
}
// Derive the witness for a `VectorProtocol` method requirement.
//
// Supported requirements, all of shape `(VectorSpaceScalar) -> Self`:
//   scaled(by scale:), adding(_ x:), subtracting(_ x:)
// Returns `nullptr` if the conformance context is disallowed; that check
// emits its own diagnostic. Also returns `nullptr` after diagnosing
// `broken_vector_protocol_requirement` for any other requirement.
ValueDecl *DerivedConformance::deriveVectorProtocol(ValueDecl *requirement) {
  if (checkAndDiagnoseDisallowedContext(requirement))
    return nullptr;

  auto &C = requirement->getASTContext();
  // Each row: requirement base name, argument label, parameter name.
  struct ScalarMethodSpec {
    Identifier baseName;
    Identifier argumentLabel;
    Identifier parameterName;
  };
  const ScalarMethodSpec supportedMethods[] = {
      {C.Id_scaled, C.Id_by, C.Id_scale},
      {C.Id_adding, Identifier(), C.Id_x},
      {C.Id_subtracting, Identifier(), C.Id_x},
  };

  auto requirementName = requirement->getBaseName();
  for (const auto &spec : supportedMethods) {
    if (requirementName == spec.baseName)
      return deriveVectorProtocol_unaryMethodOnScalar(
          *this, spec.baseName, spec.argumentLabel, spec.parameterName);
  }

  TC.diagnose(requirement->getLoc(), diag::broken_vector_protocol_requirement);
  return nullptr;
}
// Derive the type witness for a `VectorProtocol` associated type requirement.
//
// Only `VectorSpaceScalar` is supported. Returns `nullptr` if the conformance
// context is disallowed; that check emits its own diagnostic. Also returns
// `nullptr` after diagnosing `broken_vector_protocol_requirement` for any other
// associated type.
Type DerivedConformance::deriveVectorProtocol(AssociatedTypeDecl *requirement) {
  if (checkAndDiagnoseDisallowedContext(requirement))
    return nullptr;
  bool isScalarRequirement =
      requirement->getBaseName() == TC.Context.Id_VectorSpaceScalar;
  if (!isScalarRequirement) {
    TC.diagnose(requirement->getLoc(),
                diag::broken_vector_protocol_requirement);
    return nullptr;
  }
  return deriveVectorProtocol_VectorSpaceScalar(Nominal,
                                                getConformanceContext());
}