blob: 2d4e03e1e271af35662969c5b0910792a1a805f4 [file] [log] [blame]
//===--- BuilderTransform.cpp - Function-builder transformation -----------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2018 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements routines associated with the function-builder
// transformation.
//
//===----------------------------------------------------------------------===//
#include "ConstraintSystem.h"
#include "TypeChecker.h"
#include "swift/AST/ASTVisitor.h"
#include "swift/AST/ASTWalker.h"
#include "swift/AST/NameLookup.h"
#include "swift/AST/NameLookupRequests.h"
#include "swift/AST/ParameterList.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <iterator>
#include <map>
#include <memory>
#include <utility>
#include <tuple>
using namespace swift;
using namespace constraints;
namespace {
/// Visitor to classify the contents of the given closure.
///
/// The visitor walks the statements of a function-builder body. Every
/// construct it cannot translate is reported through `unhandledNode` (only
/// the first one is kept). When constructed with `wantExpr == true` it also
/// produces the expression made of calls to the builder's static `build*`
/// functions; with `wantExpr == false` the `visit` methods only classify
/// and the call-building helpers return null.
class BuilderClosureVisitor
    : public StmtVisitor<BuilderClosureVisitor, Expr *> {
  /// Constraint system used to record types for the generated type
  /// expression. May be null, but only when the builder type contains no
  /// type variables (asserted in the constructor).
  ConstraintSystem *cs;
  ASTContext &ctx;
  /// Whether expressions should actually be built.
  bool wantExpr;
  Type builderType;
  /// The nominal declaration of `builderType`, as returned by
  /// `getAnyNominal()`.
  NominalTypeDecl *builder = nullptr;
  /// Cache of the answers to label-less `builderSupports` queries, keyed by
  /// the function's base name.
  llvm::SmallDenseMap<Identifier, bool> supportedOps;

public:
  /// The first statement or declaration encountered that the transform
  /// cannot handle, if any.
  SkipUnhandledConstructInFunctionBuilder::UnhandledNode unhandledNode;

private:
  /// Produce a builder call to the given named function with the given
  /// arguments.
  ///
  /// \param loc Location used for the builder type reference and the member
  /// reference.
  /// \param fnName Name of the static builder function to call.
  /// \param args Arguments to the call.
  /// \param argLabels Labels for the leading arguments; must not be longer
  /// than \p args because each label takes its location from the argument
  /// at the same index.
  /// \param allowOneWay Whether the result may be wrapped in a `OneWayExpr`
  /// when one-way constraints are enabled in the language options.
  ///
  /// \returns the call expression, or null if no expression is wanted.
  Expr *buildCallIfWanted(SourceLoc loc,
                          Identifier fnName, ArrayRef<Expr *> args,
                          ArrayRef<Identifier> argLabels,
                          bool allowOneWay) {
    if (!wantExpr)
      return nullptr;

    // FIXME: Setting a TypeLoc on this expression is necessary in order
    // to get diagnostics if something about this builder call fails,
    // e.g. if there isn't a matching overload for `buildBlock`.
    // But we can only do this if there isn't a type variable in the type.
    TypeLoc typeLoc;
    if (!builderType->hasTypeVariable()) {
      typeLoc = TypeLoc(new (ctx) FixedTypeRepr(builderType, loc), builderType);
    }

    auto typeExpr = new (ctx) TypeExpr(typeLoc);
    if (cs) {
      cs->setType(typeExpr, MetatypeType::get(builderType));
      cs->setType(&typeExpr->getTypeLoc(), builderType);
    }

    // Each label is located at the start of the argument it applies to.
    SmallVector<SourceLoc, 4> argLabelLocs;
    for (auto i : indices(argLabels)) {
      argLabelLocs.push_back(args[i]->getStartLoc());
    }

    typeExpr->setImplicit();
    auto memberRef = new (ctx) UnresolvedDotExpr(
        typeExpr, loc, fnName, DeclNameLoc(loc), /*implicit=*/true);

    // The parentheses span the arguments, or collapse to `loc` if there
    // are none.
    SourceLoc openLoc = args.empty() ? loc : args.front()->getStartLoc();
    SourceLoc closeLoc = args.empty() ? loc : args.back()->getEndLoc();
    Expr *result = CallExpr::create(ctx, memberRef, openLoc, args,
                                    argLabels, argLabelLocs, closeLoc,
                                    /*trailing closure*/ nullptr,
                                    /*implicit*/true);

    if (ctx.LangOpts.FunctionBuilderOneWayConstraints && allowOneWay) {
      // Form a one-way constraint to prevent backward propagation.
      result = new (ctx) OneWayExpr(result);
    }

    return result;
  }

  /// Check whether the builder supports the given operation.
  ///
  /// An operation is supported when the builder declares a static function
  /// named \p fnName whose leading argument labels match \p argLabels (any
  /// labels are accepted when \p argLabels is empty).
  ///
  /// \returns true if a matching static function was found; false
  /// otherwise, including when the builder type has no nominal declaration.
  bool builderSupports(Identifier fnName,
                       ArrayRef<Identifier> argLabels = {}) {
    // The cache is keyed on the base name alone, so it can only answer
    // queries that do not constrain the labels. Sharing an entry between
    // e.g. `buildEither(first:)` and `buildEither(second:)` would make the
    // second query return the first one's answer.
    const bool cacheable = argLabels.empty();
    if (cacheable) {
      auto known = supportedOps.find(fnName);
      if (known != supportedOps.end()) {
        return known->second;
      }
    }

    // Without a nominal declaration there is nothing to look members up in.
    if (!builder)
      return false;

    bool found = false;
    for (auto decl : builder->lookupDirect(fnName)) {
      if (auto func = dyn_cast<FuncDecl>(decl)) {
        // Function must be static.
        if (!func->isStatic())
          continue;

        // Function must have the right argument labels, if provided.
        if (!argLabels.empty()) {
          auto funcLabels = func->getFullName().getArgumentNames();
          if (argLabels.size() > funcLabels.size() ||
              funcLabels.slice(0, argLabels.size()) != argLabels)
            continue;
        }

        // Okay, it's a good-enough match.
        found = true;
        break;
      }
    }

    if (cacheable)
      supportedOps[fnName] = found;
    return found;
  }

public:
  /// Create a visitor for a body transformed by \p builderType.
  ///
  /// \param cs Constraint system to record types in; may be null only if
  /// \p builderType contains no type variables.
  /// \param wantExpr Whether to build the transformed expression rather
  /// than merely classify the body.
  BuilderClosureVisitor(ASTContext &ctx, ConstraintSystem *cs,
                        bool wantExpr, Type builderType)
      : cs(cs), ctx(ctx), wantExpr(wantExpr), builderType(builderType) {
    assert((cs || !builderType->hasTypeVariable()) &&
           "cannot handle builder type with type variables without "
           "constraint system");
    builder = builderType->getAnyNominal();
  }

// Defines a visit method for a statement kind the transform does not
// support: it records the statement as the first unhandled node (unless one
// was already recorded) and produces no expression.
#define CONTROL_FLOW_STMT(StmtClass)                       \
  Expr *visit##StmtClass##Stmt(StmtClass##Stmt *stmt) {    \
    if (!unhandledNode)                                    \
      unhandledNode = stmt;                                \
                                                           \
    return nullptr;                                        \
  }

  /// Translate a brace statement into a call to `buildBlock` whose
  /// arguments are the expressions produced by its elements.
  Expr *visitBraceStmt(BraceStmt *braceStmt) {
    SmallVector<Expr *, 4> expressions;
    for (const auto &node : braceStmt->getElements()) {
      if (auto stmt = node.dyn_cast<Stmt *>()) {
        // Statements contribute an argument only if they produced one.
        auto expr = visit(stmt);
        if (expr)
          expressions.push_back(expr);
        continue;
      }

      if (auto decl = node.dyn_cast<Decl *>()) {
        // Just ignore #if; the chosen children should appear in the
        // surrounding context.  This isn't good for source tools but it
        // at least works.
        if (isa<IfConfigDecl>(decl))
          continue;

        // Any other declaration is not supported.
        if (!unhandledNode)
          unhandledNode = decl;

        continue;
      }

      auto expr = node.get<Expr *>();
      if (wantExpr && ctx.LangOpts.FunctionBuilderOneWayConstraints)
        expr = new (ctx) OneWayExpr(expr);

      expressions.push_back(expr);
    }

    // Call Builder.buildBlock(... args ...)
    return buildCallIfWanted(braceStmt->getStartLoc(),
                             ctx.Id_buildBlock, expressions,
                             /*argLabels=*/{ },
                             /*allowOneWay=*/true);
  }

  /// Only implicit `return` statements that carry a result are accepted;
  /// their result expression is used directly. Everything else is
  /// unhandled.
  Expr *visitReturnStmt(ReturnStmt *stmt) {
    // Allow implicit returns due to 'return' elision.
    if (!stmt->isImplicit() || !stmt->hasResult()) {
      if (!unhandledNode)
        unhandledNode = stmt;
      return nullptr;
    }

    return stmt->getResult();
  }

  /// Translate a `do` statement into a call to `buildDo` wrapping the
  /// translated body, provided the builder declares `buildDo`.
  Expr *visitDoStmt(DoStmt *doStmt) {
    if (!builderSupports(ctx.Id_buildDo)) {
      if (!unhandledNode)
        unhandledNode = doStmt;
      return nullptr;
    }

    auto arg = visit(doStmt->getBody());
    if (!arg)
      return nullptr;

    return buildCallIfWanted(doStmt->getStartLoc(), ctx.Id_buildDo, arg,
                             /*argLabels=*/{ }, /*allowOneWay=*/true);
  }

  CONTROL_FLOW_STMT(Yield)
  CONTROL_FLOW_STMT(Defer)

  /// Return the condition's sole boolean expression, or null if the
  /// condition does not consist of exactly one boolean element.
  static Expr *getTrivialBooleanCondition(StmtCondition condition) {
    if (condition.size() != 1)
      return nullptr;

    return condition.front().getBooleanOrNull();
  }

  /// Walk an if/else-if chain, checking that every condition is trivial.
  ///
  /// \param numPayloads Incremented once for each `then` clause and once
  /// for a final plain `else` clause.
  /// \param isOptional Set to true if the chain ends without an `else`.
  ///
  /// \returns false if some condition in the chain is not trivial.
  static bool isBuildableIfChainRecursive(IfStmt *ifStmt,
                                          unsigned &numPayloads,
                                          bool &isOptional) {
    // The conditional must be trivial.
    if (!getTrivialBooleanCondition(ifStmt->getCond()))
      return false;

    // The 'then' clause contributes a payload.
    numPayloads++;

    // If there's an 'else' clause, it contributes payloads:
    if (auto elseStmt = ifStmt->getElseStmt()) {
      // If it's 'else if', it contributes payloads recursively.
      if (auto elseIfStmt = dyn_cast<IfStmt>(elseStmt)) {
        return isBuildableIfChainRecursive(elseIfStmt, numPayloads,
                                           isOptional);
      // Otherwise it's just the one.
      } else {
        numPayloads++;
      }

    // If not, the chain result is at least optional.
    } else {
      isOptional = true;
    }

    return true;
  }

  /// Check whether the given if-chain can be translated with this builder:
  /// its shape must be supported and the builder must declare the
  /// operations the translation will call.
  bool isBuildableIfChain(IfStmt *ifStmt, unsigned &numPayloads,
                          bool &isOptional) {
    if (!isBuildableIfChainRecursive(ifStmt, numPayloads, isOptional))
      return false;

    // If there's a missing 'else', we need 'buildIf' to exist.
    if (isOptional && !builderSupports(ctx.Id_buildIf))
      return false;

    // If there are multiple clauses, we need 'buildEither(first:)' and
    // 'buildEither(second:)' to both exist.
    if (numPayloads > 1) {
      if (!builderSupports(ctx.Id_buildEither, {ctx.Id_first}) ||
          !builderSupports(ctx.Id_buildEither, {ctx.Id_second}))
        return false;
    }

    return true;
  }

  /// Translate an if-chain into a conditional expression, wrapped in a call
  /// to `buildIf` when the chain has no final `else`.
  Expr *visitIfStmt(IfStmt *ifStmt) {
    // Check whether the chain is buildable and whether it terminates
    // without an `else`.
    bool isOptional = false;
    unsigned numPayloads = 0;
    if (!isBuildableIfChain(ifStmt, numPayloads, isOptional)) {
      if (!unhandledNode)
        unhandledNode = ifStmt;
      return nullptr;
    }

    // Attempt to build the chain, propagating short-circuits, which
    // might arise either due to error or not wanting an expression.
    auto chainExpr =
      buildIfChainRecursive(ifStmt, 0, numPayloads, isOptional);
    if (!chainExpr)
      return nullptr;
    assert(wantExpr);

    // The operand should have optional type if we had optional results,
    // so we just need to call `buildIf` now, since we're at the top level.
    if (isOptional) {
      chainExpr = buildCallIfWanted(ifStmt->getStartLoc(),
                                    ctx.Id_buildIf, chainExpr,
                                    /*argLabels=*/{ },
                                    /*allowOneWay=*/true);
    } else if (ctx.LangOpts.FunctionBuilderOneWayConstraints) {
      // Form a one-way constraint to prevent backward propagation.
      chainExpr = new (ctx) OneWayExpr(chainExpr);
    }

    return chainExpr;
  }

  /// Recursively build an if-chain: build an expression which will have
  /// a value of the chain result type before any call to `buildIf`.
  /// The expression will perform any necessary calls to `buildEither`,
  /// and the result will have optional type if `isOptional` is true.
  ///
  /// \param payloadIndex Index of this statement's `then` payload within
  /// the chain.
  /// \param numPayloads Total number of payloads in the chain.
  ///
  /// \returns the conditional expression, or null if no expression is
  /// wanted or one of the clauses failed to produce one.
  Expr *buildIfChainRecursive(IfStmt *ifStmt, unsigned payloadIndex,
                              unsigned numPayloads, bool isOptional) {
    assert(payloadIndex < numPayloads);
    // Make sure we recursively visit both sides even if we're not
    // building expressions.

    // Build the then clause.  This will have the corresponding payload
    // type (i.e. not wrapped in any way).
    Expr *thenArg = visit(ifStmt->getThenStmt());

    // Build the else clause, if present.  If this is from an else-if,
    // this will be fully wrapped; otherwise it will have the corresponding
    // payload type (at index `payloadIndex + 1`).
    assert(ifStmt->getElseStmt() || isOptional);
    bool isElseIf = false;
    Optional<Expr *> elseChain;
    if (auto elseStmt = ifStmt->getElseStmt()) {
      if (auto elseIfStmt = dyn_cast<IfStmt>(elseStmt)) {
        isElseIf = true;
        elseChain = buildIfChainRecursive(elseIfStmt, payloadIndex + 1,
                                          numPayloads, isOptional);
      } else {
        elseChain = visit(elseStmt);
      }
    }

    // Short-circuit if appropriate.
    if (!wantExpr || !thenArg || (elseChain && !*elseChain))
      return nullptr;

    // Okay, build the conditional expression.

    // Prepare the `then` operand by wrapping it to produce a chain result.
    SourceLoc thenLoc = ifStmt->getThenStmt()->getStartLoc();
    Expr *thenExpr = buildWrappedChainPayload(thenArg, payloadIndex,
                                              numPayloads, isOptional);

    // Prepare the `else` operand:
    Expr *elseExpr;
    SourceLoc elseLoc;

    // - If there's no `else` clause, use `Optional.none`.
    if (!elseChain) {
      assert(isOptional);
      elseLoc = ifStmt->getEndLoc();
      elseExpr = buildNoneExpr(elseLoc);

    // - If there's an `else if`, the chain expression from that
    //   should already be producing a chain result.
    } else if (isElseIf) {
      elseExpr = *elseChain;
      elseLoc = ifStmt->getElseLoc();

    // - Otherwise, wrap it to produce a chain result.
    } else {
      elseLoc = ifStmt->getElseLoc();
      elseExpr = buildWrappedChainPayload(*elseChain,
                                          payloadIndex + 1, numPayloads,
                                          isOptional);
    }

    Expr *condition = getTrivialBooleanCondition(ifStmt->getCond());
    assert(condition && "checked by isBuildableIfChain");

    auto ifExpr = new (ctx) IfExpr(condition, thenLoc, thenExpr,
                                   elseLoc, elseExpr);
    ifExpr->setImplicit();
    return ifExpr;
  }

  /// Wrap a payload value in an expression which will produce a chain
  /// result (without `buildIf`).
  ///
  /// \param operand The payload expression to wrap.
  /// \param payloadIndex Position of the payload within the chain.
  /// \param numPayloads Total number of payloads in the chain.
  /// \param isOptional Whether to additionally inject the result into
  /// `Optional`.
  Expr *buildWrappedChainPayload(Expr *operand, unsigned payloadIndex,
                                 unsigned numPayloads, bool isOptional) {
    assert(payloadIndex < numPayloads);

    // Inject into the appropriate chain position.
    //
    // We produce a (left-biased) balanced binary tree of Eithers in order
    // to prevent requiring a linear number of injections in the worst case.
    // That is, if we have 13 clauses, we want to produce:
    //
    //                      /------------------Either------------\
    //           /-------Either-------\                     /--Either--\
    //     /--Either--\          /--Either--\          /--Either--\     \
    //   /-E-\      /-E-\      /-E-\      /-E-\      /-E-\      /-E-\    \
    // 0000 0001  0010 0011  0100 0101  0110 0111  1000 1001  1010 1011 1100
    //
    // Note that a prefix of length D of the payload index acts as a path
    // through the tree to the node at depth D.  On the rightmost path
    // through the tree (when this prefix is equal to the corresponding
    // prefix of the maximum payload index), the bits of the index mark
    // where Eithers are required.
    //
    // Since we naturally want to build from the innermost Either out, and
    // therefore work with progressively shorter prefixes, we can do it all
    // with right-shifts.
    for (auto path = payloadIndex, maxPath = numPayloads - 1;
         maxPath != 0; path >>= 1, maxPath >>= 1) {
      // Skip making Eithers on the rightmost path where they aren't required.
      // This isn't just an optimization: adding spurious Eithers could
      // leave us with unresolvable type variables if `buildEither` has
      // a signature like:
      //    static func buildEither<T,U>(first value: T) -> Either<T,U>
      // which relies on unification to work.
      if (path == maxPath && !(maxPath & 1)) continue;

      // The low bit of the path selects the side of this Either.
      bool isSecond = (path & 1);
      operand = buildCallIfWanted(operand->getStartLoc(),
                                  ctx.Id_buildEither, operand,
                                  {isSecond ? ctx.Id_second : ctx.Id_first},
                                  /*allowOneWay=*/false);
    }

    // Inject into Optional if required.  We'll be adding the call to
    // `buildIf` after all the recursive calls are complete.
    if (isOptional) {
      operand = buildSomeExpr(operand);
    }

    return operand;
  }

  /// Build an implicit `Optional.some(arg)` call located at the start of
  /// \p arg.
  Expr *buildSomeExpr(Expr *arg) {
    auto optionalDecl = ctx.getOptionalDecl();
    auto optionalType = optionalDecl->getDeclaredType();

    auto loc = arg->getStartLoc();
    auto optionalTypeExpr =
      TypeExpr::createImplicitHack(loc, optionalType, ctx);
    auto someRef = new (ctx) UnresolvedDotExpr(
        optionalTypeExpr, loc, ctx.getIdentifier("some"),
        DeclNameLoc(loc), /*implicit=*/true);
    return CallExpr::createImplicit(ctx, someRef, arg, { });
  }

  /// Build an implicit `Optional.none` reference located at \p endLoc.
  Expr *buildNoneExpr(SourceLoc endLoc) {
    auto optionalDecl = ctx.getOptionalDecl();
    auto optionalType = optionalDecl->getDeclaredType();

    auto optionalTypeExpr =
      TypeExpr::createImplicitHack(endLoc, optionalType, ctx);
    return new (ctx) UnresolvedDotExpr(
        optionalTypeExpr, endLoc, ctx.getIdentifier("none"),
        DeclNameLoc(endLoc), /*implicit=*/true);
  }

  // The remaining statement kinds are not supported by the transform.
  CONTROL_FLOW_STMT(Guard)
  CONTROL_FLOW_STMT(While)
  CONTROL_FLOW_STMT(DoCatch)
  CONTROL_FLOW_STMT(RepeatWhile)
  CONTROL_FLOW_STMT(ForEach)
  CONTROL_FLOW_STMT(Switch)
  CONTROL_FLOW_STMT(Case)
  CONTROL_FLOW_STMT(Catch)
  CONTROL_FLOW_STMT(Break)
  CONTROL_FLOW_STMT(Continue)
  CONTROL_FLOW_STMT(Fallthrough)
  CONTROL_FLOW_STMT(Fail)
  CONTROL_FLOW_STMT(Throw)
  CONTROL_FLOW_STMT(PoundAssert)

#undef CONTROL_FLOW_STMT
};
} // end anonymous namespace
/// Apply the function-builder transform to the body of a function.
///
/// \param FD The function whose body is being transformed.
/// \param body The original body.
/// \param builderType The function builder type to transform with.
///
/// \returns a new body consisting of a single implicit `return` of the
/// transformed expression, or null if no expression could be built or the
/// function has no usable body result type.
BraceStmt *
TypeChecker::applyFunctionBuilderBodyTransform(FuncDecl *FD,
                                               BraceStmt *body,
                                               Type builderType) {
  // Run the transform without a constraint system, asking it to produce
  // the single result expression.
  BuilderClosureVisitor transformer(Context, /*cs=*/nullptr,
                                    /*wantExpr=*/true, builderType);
  Expr *builtExpr = transformer.visit(body);
  if (builtExpr == nullptr)
    return nullptr;

  // Without a usable result type for the body there is nothing to return
  // into.
  Type bodyResultType = AnyFunctionRef(FD).getBodyResultType();
  if (!bodyResultType)
    return nullptr;
  if (bodyResultType->hasError())
    return nullptr;

  // Wrap the expression in an implicit `return`, positioned where the
  // expression starts, and make it the sole element of the new body.
  auto implicitReturn =
      new (Context) ReturnStmt(builtExpr->getStartLoc(), builtExpr,
                               /*implicit*/ true);
  return BraceStmt::create(Context, body->getLBraceLoc(), { implicitReturn },
                           body->getRBraceLoc());
}
/// Apply the given function builder to a closure within this constraint
/// system.
///
/// \param closure The closure whose body is to be transformed.
/// \param builderType The function builder type; any type parameters in it
/// are replaced using the types opened for \p calleeLocator.
/// \param calleeLocator Locator identifying the callee whose opened types
/// are used for that substitution.
/// \param locator Locator used for failures, fixes and the result-type
/// constraint.
///
/// \returns success if the transform was applied or deliberately
/// suppressed, and failure otherwise.
ConstraintSystem::TypeMatchResult ConstraintSystem::applyFunctionBuilder(
    ClosureExpr *closure, Type builderType, ConstraintLocator *calleeLocator,
    ConstraintLocatorBuilder locator) {
  auto builderDecl = builderType->getAnyNominal();
  assert(builderDecl && "Bad function builder type");
  assert(builderDecl->getAttrs().hasAttribute<FunctionBuilderAttr>());

  // FIXME: Right now, single-expression closures suppress the function
  // builder translation.
  if (closure->hasSingleExpressionBody())
    return getTypeMatchSuccess();

  // Pre-check the closure body: pre-check any expressions in it and look
  // for return statements.
  switch (TC.preCheckFunctionBuilderClosureBody(closure)) {
  case FunctionBuilderClosurePreCheck::Error:
    // The pre-check itself failed.
    return getTypeMatchFailure(locator);

  case FunctionBuilderClosurePreCheck::HasReturnStmt:
    // A return statement suppresses the transform, but solving of the
    // constraint system continues.
    return getTypeMatchSuccess();

  case FunctionBuilderClosurePreCheck::Okay:
    // Nothing in the way; go on to apply the transform.
    break;
  }

  // First pass: classify the body without building anything, to find out
  // whether this specific builder can be applied at all.
  {
    BuilderClosureVisitor classifier(getASTContext(), this,
                                     /*wantExpr=*/false, builderType);
    (void)classifier.visit(closure->getBody());

    // A control-flow statement or declaration the builder cannot handle
    // means this is not a well-formed function builder application.
    if (classifier.unhandledNode) {
      // Outside of diagnostic mode that is simply a failure.
      if (!shouldAttemptFixes()) {
        return getTypeMatchFailure(locator);
      }

      // Otherwise record the first unhandled construct as a fix.
      auto skipFix = SkipUnhandledConstructInFunctionBuilder::create(
          *this, classifier.unhandledNode, builderDecl,
          getConstraintLocator(locator));
      if (recordFix(skipFix)) {
        return getTypeMatchFailure(locator);
      }
    }
  }

  // Replace any type parameters in the builder type with the type
  // variables opened for this callee.
  if (builderType->hasTypeParameter()) {
    // Find the opened types for this callee and substitute in the type
    // parameters.
    for (const auto &entry : OpenedTypes) {
      if (entry.first != calleeLocator)
        continue;

      OpenedTypeMap replacements(entry.second.begin(), entry.second.end());
      builderType = openType(builderType, replacements);
      break;
    }

    assert(!builderType->hasTypeParameter());
  }

  // Second pass: actually build the transformed expression.
  BuilderClosureVisitor transformer(getASTContext(), this,
                                    /*wantExpr=*/true, builderType);
  Expr *singleExpr = transformer.visit(closure->getBody());

  // We've already pre-checked all the original expressions, but do the
  // pre-check to the generated expression just to set up any preconditions
  // that CSGen might have.
  //
  // TODO: just build the AST the way we want it in the first place.
  if (TC.preCheckExpression(singleExpr, closure))
    return getTypeMatchFailure(locator);

  singleExpr = generateConstraints(singleExpr, closure);
  if (!singleExpr)
    return getTypeMatchFailure(locator);

  Type transformedType = getType(singleExpr);
  assert(transformedType && "Missing type");

  // Record the transformation; a closure must not be transformed twice
  // along the same solver path.
  assert(std::find_if(
      builderTransformedClosures.begin(),
      builderTransformedClosures.end(),
      [&](const std::pair<ClosureExpr *, AppliedBuilderTransform> &elt) {
        return elt.first == closure;
      }) == builderTransformedClosures.end() &&
         "already transformed this closure along this path!?!");
  builderTransformedClosures.push_back(
      std::make_pair(closure,
                     AppliedBuilderTransform{builderType, singleExpr}));

  // Bind the result type of the closure to the type of the transformed
  // expression.
  auto closureFnType = getType(closure)->castTo<FunctionType>();
  addConstraint(ConstraintKind::Equal, closureFnType->getResult(),
                transformedType, locator);
  return getTypeMatchSuccess();
}
namespace {
/// Pre-check all the expressions in the closure body.
class PreCheckFunctionBuilderClosure : public ASTWalker {
TypeChecker &TC;
ClosureExpr *Closure;
bool HasReturnStmt = false;
bool HasError = false;
public:
PreCheckFunctionBuilderClosure(TypeChecker &tc, ClosureExpr *closure)
: TC(tc), Closure(closure) {}
FunctionBuilderClosurePreCheck run() {
Stmt *oldBody = Closure->getBody();
Stmt *newBody = oldBody->walk(*this);
// If the walk was aborted, it was because we had a problem of some kind.
assert((newBody == nullptr) == (HasError || HasReturnStmt) &&
"unexpected short-circuit while walking closure body");
if (!newBody) {
if (HasError)
return FunctionBuilderClosurePreCheck::Error;
return FunctionBuilderClosurePreCheck::HasReturnStmt;
}
assert(oldBody == newBody && "pre-check walk wasn't in-place?");
return FunctionBuilderClosurePreCheck::Okay;
}
std::pair<bool, Expr *> walkToExprPre(Expr *E) override {
// Pre-check the expression. If this fails, abort the walk immediately.
// Otherwise, replace the expression with the result of pre-checking.
// In either case, don't recurse into the expression.
if (TC.preCheckExpression(E, /*DC*/ Closure)) {
HasError = true;
return std::make_pair(false, nullptr);
}
return std::make_pair(false, E);
}
std::pair<bool, Stmt *> walkToStmtPre(Stmt *S) override {
// If we see a return statement, abort the walk immediately.
if (isa<ReturnStmt>(S)) {
HasReturnStmt = true;
return std::make_pair(false, nullptr);
}
// Otherwise, recurse into the statement normally.
return std::make_pair(true, S);
}
};
}
/// Pre-check the body of a closure that a function builder may be applied
/// to, caching the outcome per closure.
FunctionBuilderClosurePreCheck
TypeChecker::preCheckFunctionBuilderClosureBody(ClosureExpr *closure) {
  // Single-expression closures should already have been pre-checked.
  if (closure->hasSingleExpressionBody())
    return FunctionBuilderClosurePreCheck::Okay;

  // Reuse the outcome of an earlier analysis of this closure, if any.
  auto cached = precheckedFunctionBuilderClosures.find(closure);
  if (cached != precheckedFunctionBuilderClosures.end())
    return cached->second;

  // Otherwise run the analysis now and remember its outcome.
  PreCheckFunctionBuilderClosure checker(*this, closure);
  auto outcome = checker.run();
  precheckedFunctionBuilderClosures.insert(std::make_pair(closure, outcome));
  return outcome;
}