Merge pull request #18032 from rudkx/restore-order
Return to the old disjunction ordering until some test regressions are addressed.
diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp
index 91ce7cd..5f7d20f 100644
--- a/lib/Sema/CSSolver.cpp
+++ b/lib/Sema/CSSolver.cpp
@@ -1858,24 +1858,98 @@
return false;
}
+// Attempt to find a disjunction of bind constraints where all options
+// in the disjunction are binding the same type variable.
+//
+// Prefer disjunctions where the bound type variable is also the
+// right-hand side of a conversion constraint, since having a concrete
+// type that we're converting to can make it possible to split the
+// constraint system into multiple ones.
+//
+// Returns, in order of preference:
+//  - the first binding disjunction whose (simplified) type variable is
+//    the destination of a Conversion constraint in its equivalence class;
+//  - otherwise the first binding disjunction found in \p disjunctions;
+//  - otherwise nullptr.
+static Constraint *selectBestBindingDisjunction(
+    ConstraintSystem &cs, SmallVectorImpl<Constraint *> &disjunctions) {
+
+  if (disjunctions.empty())
+    return nullptr;
+
+  // Collect any disjunctions that simply attempt bindings for a
+  // type variable.
+  SmallVector<Constraint *, 8> bindingDisjunctions;
+  for (auto *disjunction : disjunctions) {
+    auto nested = disjunction->getNestedConstraints();
+    // llvm::all_of is vacuously true for an empty range; an empty
+    // disjunction binds nothing, and skipping it here keeps the
+    // nested[0] access below valid even when asserts are compiled out.
+    if (nested.empty())
+      continue;
+
+    // Every nested constraint must bind the same simple type variable.
+    // nullptr means "no type variable seen yet"; the lambda only ever
+    // stores non-null values.
+    TypeVariableType *commonTypeVariable = nullptr;
+    bool bindsSingleTypeVariable =
+        llvm::all_of(nested, [&](Constraint *bindingConstraint) {
+          if (bindingConstraint->getKind() != ConstraintKind::Bind)
+            return false;
+
+          auto *tv =
+              bindingConstraint->getFirstType()->getAs<TypeVariableType>();
+          // Only do this for simple type variable bindings, not for
+          // bindings like: ($T1) -> $T2 bind String -> Int
+          if (!tv)
+            return false;
+
+          if (!commonTypeVariable)
+            commonTypeVariable = tv;
+
+          return commonTypeVariable == tv;
+        });
+
+    if (bindsSingleTypeVariable)
+      bindingDisjunctions.push_back(disjunction);
+  }
+
+  for (auto *disjunction : bindingDisjunctions) {
+    auto nested = disjunction->getNestedConstraints();
+    assert(!nested.empty());
+    auto *tv = cs.simplifyType(nested[0]->getFirstType())
+                   ->getRValueType()
+                   ->getAs<TypeVariableType>();
+    assert(tv);
+
+    SmallVector<Constraint *, 8> constraints;
+    cs.getConstraintGraph().gatherConstraints(
+        tv, constraints, ConstraintGraph::GatheringKind::EquivalenceClass);
+
+    // Prefer this disjunction if its type variable is the destination
+    // (after simplification, ignoring lvalue-ness) of a conversion.
+    bool isConversionTarget =
+        llvm::any_of(constraints, [&](Constraint *constraint) {
+          if (constraint->getKind() != ConstraintKind::Conversion)
+            return false;
+
+          auto toType =
+              cs.simplifyType(constraint->getSecondType())->getRValueType();
+          return toType->getAs<TypeVariableType>() == tv;
+        });
+
+    if (isConversionTarget)
+      return disjunction;
+  }
+
+  // If we had any binding disjunctions, return the first of
+  // those. These ensure that we attempt to bind types earlier than
+  // trying the elements of other disjunctions, which can often mean
+  // we fail faster.
+  return bindingDisjunctions.empty() ? nullptr : bindingDisjunctions.front();
+}
+
Constraint *ConstraintSystem::selectDisjunction() {
SmallVector<Constraint *, 4> disjunctions;
collectDisjunctions(disjunctions);
+ if (auto *disjunction = selectBestBindingDisjunction(*this, disjunctions))
+ return disjunction;
- // Pick the disjunction with the lowest disjunction number in order
- // to solve them in the order they were created (which should be
- // stable within an expression).
+ // Pick the disjunction with the smallest number of active choices.
auto minDisjunction =
std::min_element(disjunctions.begin(), disjunctions.end(),
[&](Constraint *first, Constraint *second) -> bool {
- auto firstFound = DisjunctionNumber.find(first);
- auto secondFound = DisjunctionNumber.find(second);
-
- assert(firstFound != DisjunctionNumber.end() &&
- secondFound != DisjunctionNumber.end());
-
- return firstFound->second < secondFound->second;
+ return first->countActiveNestedConstraints() <
+ second->countActiveNestedConstraints();
});
if (minDisjunction != disjunctions.end())
diff --git a/validation-test/stdlib/AnyHashable.swift.gyb b/validation-test/stdlib/AnyHashable.swift.gyb
index c4fef1e..4c318e4 100644
--- a/validation-test/stdlib/AnyHashable.swift.gyb
+++ b/validation-test/stdlib/AnyHashable.swift.gyb
@@ -757,9 +757,8 @@
xs,
equalityOracle: { $0 / 2 == $1 / 2 },
hashEqualityOracle: { $0 / 4 == $1 / 4 })
- let mapped = xs.map(AnyHashable.init)
checkHashable(
- mapped,
+ xs.map(AnyHashable.init),
equalityOracle: { $0 / 2 == $1 / 2 },
hashEqualityOracle: { $0 / 4 == $1 / 4 })
}