Merge pull request #6253 from rudkx/cs-typemap
diff --git a/include/swift/AST/Expr.h b/include/swift/AST/Expr.h
index d0274ad..e7ea7c8 100644
--- a/include/swift/AST/Expr.h
+++ b/include/swift/AST/Expr.h
@@ -576,6 +576,13 @@
/// \param allowOverwrite - true if it's okay if an expression already
/// has an access kind
+ /// \param getType - callback used to look up the type of each visited
+ /// expression; defaults to Expr::getType(). It precedes allowOverwrite,
+ /// so callers that pass allowOverwrite must also pass getType.
void propagateLValueAccessKind(AccessKind accessKind,
+ std::function<Type(Expr *)> getType
+ = [](Expr *E) -> Type {
+ return E->getType();
+ },
bool allowOverwrite = false);
/// Retrieves the declaration that is being referenced by this
diff --git a/lib/AST/Expr.cpp b/lib/AST/Expr.cpp
index 2729415..d5299ca 100644
--- a/lib/AST/Expr.cpp
+++ b/lib/AST/Expr.cpp
@@ -213,17 +213,22 @@
/// Propagate l-value use information to children.
void Expr::propagateLValueAccessKind(AccessKind accessKind,
+ std::function<Type(Expr *)> getType,
bool allowOverwrite) {
/// A visitor class which walks an entire l-value expression.
class PropagateAccessKind
: public ExprVisitor<PropagateAccessKind, void, AccessKind> {
+ // Needed in every build (used outside asserts), so not under NDEBUG.
+ std::function<Type(Expr *)> GetType;
#ifndef NDEBUG
bool AllowOverwrite;
#endif
public:
- PropagateAccessKind(bool allowOverwrite)
+ PropagateAccessKind(std::function<Type(Expr *)> getType,
+ bool allowOverwrite)
+ : GetType(getType)
#ifndef NDEBUG
- : AllowOverwrite(allowOverwrite)
+ , AllowOverwrite(allowOverwrite)
#endif
{}
@@ -231,7 +236,7 @@
assert((AllowOverwrite || !E->hasLValueAccessKind()) &&
"l-value access kind has already been set");
- assert(E->getType()->isAssignableType() &&
+ assert(GetType(E)->isAssignableType() &&
"setting access kind on non-l-value");
E->setLValueAccessKind(kind);
@@ -255,11 +260,11 @@
}
void visitMemberRefExpr(MemberRefExpr *E, AccessKind accessKind) {
- if (!E->getBase()->getType()->isLValueType()) return;
+ if (!GetType(E->getBase())->isLValueType()) return;
visit(E->getBase(), getBaseAccessKind(E->getMember(), accessKind));
}
void visitSubscriptExpr(SubscriptExpr *E, AccessKind accessKind) {
- if (!E->getBase()->getType()->isLValueType()) return;
+ if (!GetType(E->getBase())->isLValueType()) return;
visit(E->getBase(), getBaseAccessKind(E->getDecl(), accessKind));
}
@@ -354,7 +359,7 @@
#undef NON_LVALUE_EXPR
};
- PropagateAccessKind(allowOverwrite).visit(this, accessKind);
+ PropagateAccessKind(getType, allowOverwrite).visit(this, accessKind);
}
ConcreteDeclRef Expr::getReferencedDecl() const {
diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp
index d4f3d1b..ccde1f6 100644
--- a/lib/Sema/CSApply.cpp
+++ b/lib/Sema/CSApply.cpp
@@ -266,7 +266,18 @@
return member->getAccessSemanticsFromContext(DC);
}
+void ConstraintSystem::propagateLValueAccessKind(Expr *E,
+ AccessKind accessKind,
+ bool allowOverwrite) {
+ E->propagateLValueAccessKind(accessKind,
+ [this](Expr *subExpr) -> Type { // read from the CS type map
+ return getType(subExpr);
+ },
+ allowOverwrite);
+}
+
namespace {
+
/// \brief Rewrites an expression by applying the solution of a constraint
/// system to that expression.
class ExprRewriter : public ExprVisitor<ExprRewriter, Expr *> {
@@ -688,8 +699,8 @@
// down to the original existential value. Otherwise, propagateLVAK
// will handle this.
if (record.OpaqueValue->hasLValueAccessKind())
- record.ExistentialValue->propagateLValueAccessKind(
- record.OpaqueValue->getLValueAccessKind());
+ cs.propagateLValueAccessKind(record.ExistentialValue,
+ record.OpaqueValue->getLValueAccessKind());
// Form the open-existential expression.
result = new (tc.Context) OpenExistentialExpr(
@@ -2803,7 +2814,7 @@
// case (when we turn the inout into an UnsafePointer) than to try to
// discover that we're in that case right now.
if (!cs.getType(expr->getSubExpr())->is<UnresolvedType>())
- expr->getSubExpr()->propagateLValueAccessKind(AccessKind::ReadWrite);
+ cs.propagateLValueAccessKind(expr->getSubExpr(), AccessKind::ReadWrite);
auto objectTy = cs.getType(expr->getSubExpr())->getRValueType();
// The type is simply inout of whatever the lvalue's object type was.
@@ -3322,7 +3333,7 @@
auto destTy = cs.computeAssignDestType(expr->getDest(), expr->getLoc());
if (!destTy)
return nullptr;
- expr->getDest()->propagateLValueAccessKind(AccessKind::Write);
+ cs.propagateLValueAccessKind(expr->getDest(), AccessKind::Write);
// Convert the source to the simplified destination type.
auto locator =
@@ -5284,7 +5295,7 @@
break;
// Load from the lvalue.
- expr->propagateLValueAccessKind(AccessKind::Read);
+ cs.propagateLValueAccessKind(expr, AccessKind::Read);
expr = cs.cacheType(
new (tc.Context) LoadExpr(expr, fromType->getRValueType()));
@@ -5427,8 +5438,9 @@
if (pointerKind == PTK_UnsafePointer) {
// Overwrite the l-value access kind to be read-only if we're
// converting to a non-mutable pointer type.
- cast<InOutExpr>(expr->getValueProvidingExpr())->getSubExpr()
- ->propagateLValueAccessKind(AccessKind::Read, /*overwrite*/ true);
+ auto *E = cast<InOutExpr>(expr->getValueProvidingExpr())->getSubExpr();
+ cs.propagateLValueAccessKind(E,
+ AccessKind::Read, /*overwrite*/ true);
}
tc.requirePointerArgumentIntrinsics(expr->getLoc());
@@ -5560,7 +5572,7 @@
// In an 'inout' operator like "++i", the operand is converted from
// an implicit lvalue to an inout argument.
assert(toIO->getObjectType()->isEqual(fromLValue->getObjectType()));
- expr->propagateLValueAccessKind(AccessKind::ReadWrite);
+ cs.propagateLValueAccessKind(expr, AccessKind::ReadWrite);
return cs.cacheType(new (tc.Context)
InOutExpr(expr->getStartLoc(), expr, toType,
/*isImplicit*/ true));
@@ -5578,7 +5590,7 @@
if (performLoad) {
// Load from the lvalue.
- expr->propagateLValueAccessKind(AccessKind::Read);
+ cs.propagateLValueAccessKind(expr, AccessKind::Read);
expr = cs.cacheType(new (tc.Context)
LoadExpr(expr, fromLValue->getObjectType()));
@@ -5811,7 +5823,7 @@
// Use InOutExpr to convert it to an explicit inout argument for the
// receiver.
- expr->propagateLValueAccessKind(AccessKind::ReadWrite);
+ cs.propagateLValueAccessKind(expr, AccessKind::ReadWrite);
return cs.cacheType(new (ctx) InOutExpr(expr->getStartLoc(), expr, toType,
/*isImplicit*/ true));
}
@@ -6600,7 +6612,7 @@
// If we already have an rvalue, we're done, otherwise emit a load.
if (auto lvalueTy = getType(expr)->getAs<LValueType>()) {
- expr->propagateLValueAccessKind(AccessKind::Read);
+ propagateLValueAccessKind(expr, AccessKind::Read);
return cacheType(new (getASTContext())
LoadExpr(expr, lvalueTy->getObjectType()));
}
diff --git a/lib/Sema/ConstraintSystem.h b/lib/Sema/ConstraintSystem.h
index ca7fed4..319811d 100644
--- a/lib/Sema/ConstraintSystem.h
+++ b/lib/Sema/ConstraintSystem.h
@@ -1764,6 +1764,19 @@
/// \brief Determine if the type in question is AnyHashable.
bool isAnyHashableType(Type t);
+ /// Call Expr::propagateLValueAccessKind on the given expression,
+ /// reading expression types from the ConstraintSystem expression
+ /// type map (getType) instead of from Expr::getType().
+ ///
+ /// \param E - the l-value expression to mark.
+ /// \param accessKind - the access kind to propagate to E and its
+ /// l-value children.
+ /// \param allowOverwrite - true if it's okay if an expression already
+ /// has an access kind.
+ void propagateLValueAccessKind(Expr *E,
+ AccessKind accessKind,
+ bool allowOverwrite = false);
+
private:
/// Introduce the constraints associated with the given type variable
/// into the worklist.