Merge pull request #18785 from xedin/select-best-disjunction-improvements
[CSSolver] Refactor `selectBestBindingDisjunction`
diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp
index 5a26b26..4bc5add 100644
--- a/lib/Sema/CSSolver.cpp
+++ b/lib/Sema/CSSolver.cpp
@@ -1856,62 +1856,41 @@
if (disjunctions.empty())
return nullptr;
- // Collect any disjunctions that simply attempt bindings for a
- // type variable.
- SmallVector<Constraint *, 8> bindingDisjunctions;
+ auto getAsTypeVar = [&cs](Type type) {
+ return cs.simplifyType(type)->getRValueType()->getAs<TypeVariableType>();
+ };
+
+ Constraint *firstBindDisjunction = nullptr;
for (auto *disjunction : disjunctions) {
- TypeVariableType *commonTypeVariable = nullptr;
- if (llvm::all_of(
- disjunction->getNestedConstraints(),
- [&](Constraint *bindingConstraint) {
- if (bindingConstraint->getKind() != ConstraintKind::Bind)
- return false;
+ auto choices = disjunction->getNestedConstraints();
+ assert(!choices.empty());
- auto *tv = cs.simplifyType(bindingConstraint->getFirstType())
- ->getRValueType()
- ->getAs<TypeVariableType>();
- // Only do this for simple type variable bindings, not for
- // bindings like: ($T1) -> $T2 bind String -> Int
- if (!tv)
- return false;
+ auto *choice = choices.front();
+ if (choice->getKind() != ConstraintKind::Bind)
+ continue;
- // If we've seen a variable before, make sure that this is
- // the same one.
- if (commonTypeVariable == tv)
- return true;
- if (commonTypeVariable)
- return false;
+ // We can judge the disjunction based on a single choice
+ // because all of the choices (of a bind overload set) should
+ // have the same left-hand side.
+ // Only do this for simple type variable bindings, not for
+ // bindings like: ($T1) -> $T2 bind String -> Int
+ auto *typeVar = getAsTypeVar(choice->getFirstType());
+ if (!typeVar)
+ continue;
- commonTypeVariable = tv;
- return true;
- })) {
- bindingDisjunctions.push_back(disjunction);
- }
- }
-
- for (auto *disjunction : bindingDisjunctions) {
- auto nested = disjunction->getNestedConstraints();
- assert(!nested.empty());
- auto *tv = cs.simplifyType(nested[0]->getFirstType())
- ->getRValueType()
- ->getAs<TypeVariableType>();
- assert(tv);
+ if (!firstBindDisjunction)
+ firstBindDisjunction = disjunction;
llvm::SetVector<Constraint *> constraints;
cs.getConstraintGraph().gatherConstraints(
- tv, constraints, ConstraintGraph::GatheringKind::EquivalenceClass,
+ typeVar, constraints, ConstraintGraph::GatheringKind::EquivalenceClass,
[](Constraint *constraint) {
return constraint->getKind() == ConstraintKind::Conversion;
});
for (auto *constraint : constraints) {
- auto toType =
- cs.simplifyType(constraint->getSecondType())->getRValueType();
- auto *toTV = toType->getAs<TypeVariableType>();
- if (tv != toTV)
- continue;
-
- return disjunction;
+ if (typeVar == getAsTypeVar(constraint->getSecondType()))
+ return disjunction;
}
}
@@ -1919,10 +1898,7 @@
// those. These ensure that we attempt to bind types earlier than
// trying the elements of other disjunctions, which can often mean
// we fail faster.
- if (!bindingDisjunctions.empty())
- return bindingDisjunctions[0];
-
- return nullptr;
+ return firstBindDisjunction;
}
Constraint *ConstraintSystem::selectDisjunction() {
diff --git a/lib/Sema/Constraint.cpp b/lib/Sema/Constraint.cpp
index aef0dd1..a724019 100644
--- a/lib/Sema/Constraint.cpp
+++ b/lib/Sema/Constraint.cpp
@@ -722,6 +722,25 @@
return constraints.front();
}
+#ifndef NDEBUG
+ assert(!constraints.empty());
+ // Verify that all disjunction choices have the same left-hand side.
+ Type commonType;
+ assert(llvm::all_of(constraints, [&](const Constraint *choice) -> bool {
+ // If this disjunction is formed from "fixed"
+ // constraints, let's not try to validate it.
+ if (choice->HasRestriction || choice->getFix())
+ return true;
+
+ auto currentType = choice->getFirstType();
+ if (!commonType) {
+ commonType = currentType;
+ return true;
+ }
+ return commonType->isEqual(currentType);
+ }));
+#endif
+
// Create the disjunction constraint.
uniqueTypeVariables(typeVars);
unsigned size = totalSizeToAlloc<TypeVariableType*>(typeVars.size());