blob: 5c9de0642ae72035dbc5523e3007b0e9ca75d367 [file] [log] [blame]
//===--- DerivedConformanceVectorProtocol.cpp -----------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements explicit derivation of the VectorProtocol protocol for
// struct types.
//
//===----------------------------------------------------------------------===//
#include "CodeSynthesis.h"
#include "TypeChecker.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Expr.h"
#include "swift/AST/GenericSignature.h"
#include "swift/AST/Module.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/Pattern.h"
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/Types.h"
#include "DerivedConformances.h"
using namespace swift;
// Look up the `VectorSpaceScalar` type witness of a variable's type.
//
// - varDecl: the variable whose (contextual) value type is inspected.
// - DC: the declaration context used both to map the interface type into
//   context and to perform the `VectorProtocol` conformance check.
//
// Returns a null `Type` when the variable's interface type contains an error
// or when its type does not conform to `VectorProtocol` in `DC`.
static Type getVectorProtocolVectorSpaceScalarAssocType(
    VarDecl *varDecl, DeclContext *DC) {
  auto &astContext = varDecl->getASTContext();
  auto *proto = astContext.getProtocol(KnownProtocolKind::VectorProtocol);

  // Bail out early on ill-formed declarations.
  if (varDecl->getInterfaceType()->hasError())
    return Type();

  auto contextualType =
      DC->mapTypeIntoContext(varDecl->getValueInterfaceType());
  auto conformance =
      TypeChecker::conformsToProtocol(contextualType, proto, DC);
  if (conformance)
    return conformance.getTypeWitnessByName(contextualType,
                                            astContext.Id_VectorSpaceScalar);
  return Type();
}
// Compute the `VectorSpaceScalar` associated type that can be derived for
// `nominal` in context `DC`.
//
// Derivation succeeds only for structs whose stored properties all conform to
// `VectorProtocol` and agree on a single `VectorSpaceScalar` type; that shared
// type is returned. A null `Type` is returned when `nominal` is not a struct,
// when any stored property is ill-formed or non-conforming, or when two
// properties disagree. A struct with no stored properties passes the struct
// check but also yields a null `Type`, since no scalar type is ever recorded.
static Type deriveVectorProtocol_VectorSpaceScalar(NominalTypeDecl *nominal,
                                                   DeclContext *DC) {
  // Only struct declarations are eligible.
  if (!isa<StructDecl>(nominal))
    return Type();

  Type commonScalar;
  for (auto *property : nominal->getStoredProperties()) {
    // Ill-formed property: give up.
    if (property->getInterfaceType()->hasError())
      return Type();

    // Property type must conform to `VectorProtocol`.
    Type propertyScalar =
        getVectorProtocolVectorSpaceScalarAssocType(property, DC);
    if (!propertyScalar)
      return Type();

    if (!commonScalar) {
      // First property seen: it establishes the candidate scalar type.
      commonScalar = propertyScalar;
    } else if (!propertyScalar->isEqual(commonScalar)) {
      // Subsequent properties must agree with the candidate.
      return Type();
    }
  }
  return commonScalar;
}
// Report whether a `VectorProtocol` conformance can be synthesized for
// `nominal` in context `DC`.
//
// Two conditions must hold:
//  1. No `let` stored property has an initial value. (This restriction may be
//     lifted later with support for "true" memberwise initializers that
//     initialize all stored properties, including initial value information.)
//  2. The `VectorSpaceScalar` associated type is derivable.
bool DerivedConformance::canDeriveVectorProtocol(NominalTypeDecl *nominal,
                                                 DeclContext *DC) {
  // Short-circuit evaluation keeps the scalar derivation from running when
  // the `let` restriction already rules the type out.
  return !hasLetStoredPropertyWithInitialValue(nominal) &&
         static_cast<bool>(deriveVectorProtocol_VectorSpaceScalar(nominal, DC));
}
// Synthesize body for a `VectorProtocol` method requirement.
//
// The generated body is a single `return` statement that calls the nominal
// type's memberwise initializer, passing for each stored property the result
// of applying the requested method to that property:
//
//   return Nominal(x: self.x.method(label: param),
//                  y: self.y.method(label: param), ...)
//
// - funcDecl: the function whose body is being synthesized; its parent context
//   supplies the nominal type, and its first parameter is forwarded to every
//   member method call.
// - methodName: name of the `VectorProtocol` requirement to invoke on each
//   stored property.
// - methodParamLabel: argument label used in each member method call.
//
// Returns the synthesized brace statement paired with `false`.
// NOTE(review): the `bool` is presumably the "already type-checked" flag of
// the body-synthesizer contract -- confirm against `BodySynthesizer`.
static std::pair<BraceStmt *, bool>
deriveBodyVectorProtocol_method(AbstractFunctionDecl *funcDecl,
                                Identifier methodName,
                                Identifier methodParamLabel) {
  auto *parentDC = funcDecl->getParent();
  auto *nominal = parentDC->getSelfNominalTypeDecl();
  auto &C = nominal->getASTContext();

  // Create memberwise initializer: `Nominal.init(...)`.
  auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer();
  assert(memberwiseInitDecl && "Memberwise initializer must exist");
  auto *initDRE =
      new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true);
  initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
  auto *nominalTypeExpr = TypeExpr::createImplicitForDecl(
      DeclNameLoc(), nominal, funcDecl,
      funcDecl->mapTypeIntoContext(nominal->getInterfaceType()));
  auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr);

  // Get method protocol requirement.
  auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
  auto *methodReq = getProtocolRequirement(vectorProto, methodName);

  // Get references to `self` and parameter declarations.
  auto *selfDecl = funcDecl->getImplicitSelfDecl();
  auto *paramDecl = funcDecl->getParameters()->get(0);

  // Create call expression applying a member method to the parameter.
  // Format: `<member>.method(<parameter>)`.
  // Example: `x.scaled(by: scalar)`.
  auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * {
    auto *module = nominal->getModuleContext();
    auto memberType =
        parentDC->mapTypeIntoContext(member->getValueInterfaceType());
    auto confRef = module->lookupConformance(memberType, vectorProto);
    assert(confRef && "Member does not conform to `VectorProtocol`");

    // Get member type's method, e.g. `Member.scaled(by:)`.
    // Use protocol requirement declaration for the method by default: this
    // will be dynamically dispatched.
    ValueDecl *memberMethodDecl = methodReq;
    // If conformance reference is concrete, then use concrete witness
    // declaration for the method (when one is available).
    if (confRef.isConcrete()) {
      if (auto *concreteMemberMethodDecl =
              confRef.getConcrete()->getWitnessDecl(methodReq))
        memberMethodDecl = concreteMemberMethodDecl;
    }
    assert(memberMethodDecl && "Member method declaration must exist");

    // Create reference to member method: `x.scaled(by:)`.
    // NOTE(TF-1054): create new `DeclRefExpr`s per loop iteration to avoid
    // `ConstraintSystem::resolveOverload` error.
    auto *selfDRE =
        new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
    auto memberExpr =
        new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
                              /*Implicit*/ true);
    auto memberMethodExpr =
        new (C) MemberRefExpr(memberExpr, SourceLoc(), memberMethodDecl,
                              DeclNameLoc(), /*Implicit*/ true);
    auto *paramDRE =
        new (C) DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true);
    // Create expression: `x.scaled(by: scalar)`.
    return CallExpr::createImplicit(C, memberMethodExpr, {paramDRE},
                                    {methodParamLabel});
  };

  // Create array of member method call expressions, one per stored property,
  // together with the matching initializer argument labels.
  llvm::SmallVector<Expr *, 2> memberMethodCallExprs;
  llvm::SmallVector<Identifier, 2> memberNames;
  for (auto *member : nominal->getStoredProperties()) {
    memberMethodCallExprs.push_back(createMemberMethodCallExpr(member));
    memberNames.push_back(member->getName());
  }

  // Call memberwise initializer with member method call expressions.
  auto *callExpr =
      CallExpr::createImplicit(C, initExpr, memberMethodCallExprs, memberNames);
  ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true);
  auto *braceStmt =
      BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true);
  return std::pair<BraceStmt *, bool>(braceStmt, false);
}
// Build and register an implicit function declaration that satisfies a
// `VectorProtocol` method requirement.
//
// The declaration has the shape
//   func methodBaseName(argumentLabel parameterName: parameterType)
//       -> returnType
// and is created in the conformance context of `derived`. Its body is not
// built here; `bodySynthesizer` is attached to the declaration instead.
// The new declaration is added to the conformance context and returned.
static ValueDecl *deriveVectorProtocol_method(
    DerivedConformance &derived, Identifier methodBaseName,
    Identifier argumentLabel, Identifier parameterName, Type parameterType,
    Type returnType, AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
  auto &astContext = derived.Context;
  auto conformanceDC = derived.getConformanceContext();

  // The method's single parameter.
  auto *onlyParam = new (astContext)
      ParamDecl(SourceLoc(), SourceLoc(), argumentLabel, SourceLoc(),
                parameterName, conformanceDC);
  onlyParam->setSpecifier(ParamDecl::Specifier::Default);
  onlyParam->setInterfaceType(parameterType);
  ParameterList *paramList = ParameterList::create(astContext, {onlyParam});

  // The implicit, non-static, non-async, non-throwing, non-generic function.
  DeclName fullName(astContext, methodBaseName, paramList);
  auto method = FuncDecl::createImplicit(
      astContext, StaticSpellingKind::None, fullName, SourceLoc(),
      /*Async*/ false, /*Throws*/ false, /*GenericParams*/ nullptr, paramList,
      returnType, conformanceDC);

  // Attach the body synthesizer and inherit the generic signature of the
  // conformance context and the formal access of the nominal type.
  method->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context);
  method->setGenericSignature(conformanceDC->getGenericSignatureOfContext());
  method->copyFormalAccessFrom(derived.Nominal,
                               /*sourceIsParentContext*/ true);

  derived.addMembersToConformanceContext({method});
  return method;
}
/// Synthesize a method declaration that has the following signature:
///   func {methodBaseName}(
///     {argumentLabel} {parameterName}: VectorSpaceScalar
///   ) -> Self
///
/// The parameter type is the derived `VectorSpaceScalar` type of the nominal
/// type (mapped out of context), and the return type is the declared
/// interface type of the conformance context.
///
/// Returns the synthesized declaration, or `nullptr` if the
/// `VectorSpaceScalar` type cannot be derived for the nominal type.
static ValueDecl *deriveVectorProtocol_unaryMethodOnScalar(
    DerivedConformance &derived, Identifier methodBaseName,
    Identifier argumentLabel, Identifier parameterName) {
  auto &C = derived.Context;
  auto *nominal = derived.Nominal;
  auto *parentDC = derived.getConformanceContext();
  auto selfInterfaceType = parentDC->getDeclaredInterfaceType();

  // `deriveVectorProtocol_VectorSpaceScalar` returns a null type when the
  // scalar type is not derivable; guard before dereferencing it.
  Type contextualScalarType =
      deriveVectorProtocol_VectorSpaceScalar(nominal, parentDC);
  if (!contextualScalarType)
    return nullptr;
  auto scalarType = contextualScalarType->mapTypeOutOfContext();

  // The synthesizer callback receives an opaque context pointer; it is
  // decoded as a two-element `Identifier` array: {method name, param label}.
  auto bodySynthesizer = [](AbstractFunctionDecl *funcDecl,
                            void *ctx) -> std::pair<BraceStmt *, bool> {
    auto methodNameAndLabel = static_cast<Identifier *>(ctx);
    return deriveBodyVectorProtocol_method(
        funcDecl, methodNameAndLabel[0], methodNameAndLabel[1]);
  };

  // Copy the pair via `C.AllocateCopy` so the synthesizer context does not
  // point at this function's stack array.
  Identifier baseNameAndLabel[2] = {methodBaseName, argumentLabel};
  return deriveVectorProtocol_method(
      derived, methodBaseName, argumentLabel, parameterName, scalarType,
      selfInterfaceType,
      {bodySynthesizer, C.AllocateCopy(baseNameAndLabel).data()});
}
// Derive the witness for a `VectorProtocol` value requirement.
//
// Recognized requirements, selected by base name:
//   - `scaled`:      label `by`, parameter `scale`
//   - `adding`:      no label,   parameter `x`
//   - `subtracting`: no label,   parameter `x`
//
// Returns `nullptr` if the conformance context is disallowed (a diagnostic is
// issued by `checkAndDiagnoseDisallowedContext`) or if the requirement is not
// one of the above (a `broken_vector_protocol_requirement` diagnostic is
// emitted at the requirement's location).
ValueDecl *DerivedConformance::deriveVectorProtocol(ValueDecl *requirement) {
  // Diagnose conformances in disallowed contexts.
  if (checkAndDiagnoseDisallowedContext(requirement))
    return nullptr;

  auto &C = requirement->getASTContext();
  auto requirementName = requirement->getBaseName();

  if (requirementName == Context.Id_scaled) {
    return deriveVectorProtocol_unaryMethodOnScalar(*this, C.Id_scaled,
                                                    C.Id_by, C.Id_scale);
  }
  if (requirementName == Context.Id_adding) {
    return deriveVectorProtocol_unaryMethodOnScalar(*this, C.Id_adding,
                                                    Identifier(), C.Id_x);
  }
  if (requirementName == Context.Id_subtracting) {
    return deriveVectorProtocol_unaryMethodOnScalar(*this, C.Id_subtracting,
                                                    Identifier(), C.Id_x);
  }

  // Unknown requirement: the protocol declaration is not what we expect.
  Context.Diags.diagnose(requirement->getLoc(),
                         diag::broken_vector_protocol_requirement);
  return nullptr;
}
// Derive the type witness for a `VectorProtocol` associated type requirement.
//
// Only `VectorSpaceScalar` is recognized; its witness is computed from the
// nominal type's stored properties in the conformance context. Returns a null
// type if the conformance context is disallowed, or if the requirement is not
// `VectorSpaceScalar` (in which case a `broken_vector_protocol_requirement`
// diagnostic is emitted at the requirement's location).
Type DerivedConformance::deriveVectorProtocol(AssociatedTypeDecl *requirement) {
  // Diagnose conformances in disallowed contexts.
  if (checkAndDiagnoseDisallowedContext(requirement))
    return nullptr;

  const bool isScalarRequirement =
      requirement->getBaseName() == Context.Id_VectorSpaceScalar;
  if (!isScalarRequirement) {
    Context.Diags.diagnose(requirement->getLoc(),
                           diag::broken_vector_protocol_requirement);
    return nullptr;
  }

  return deriveVectorProtocol_VectorSpaceScalar(Nominal,
                                                getConformanceContext());
}