Merge pull request #18785 from xedin/select-best-disjunction-improvements

[CSSolver] Refactor `selectBestBindingDisjunction`
diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp
index 5a26b26..4bc5add 100644
--- a/lib/Sema/CSSolver.cpp
+++ b/lib/Sema/CSSolver.cpp
@@ -1856,62 +1856,41 @@
   if (disjunctions.empty())
     return nullptr;
 
-  // Collect any disjunctions that simply attempt bindings for a
-  // type variable.
-  SmallVector<Constraint *, 8> bindingDisjunctions;
+  auto getAsTypeVar = [&cs](Type type) {
+    return cs.simplifyType(type)->getRValueType()->getAs<TypeVariableType>();
+  };
+
+  Constraint *firstBindDisjunction = nullptr;
   for (auto *disjunction : disjunctions) {
-    TypeVariableType *commonTypeVariable = nullptr;
-    if (llvm::all_of(
-            disjunction->getNestedConstraints(),
-            [&](Constraint *bindingConstraint) {
-              if (bindingConstraint->getKind() != ConstraintKind::Bind)
-                return false;
+    auto choices = disjunction->getNestedConstraints();
+    assert(!choices.empty());
 
-              auto *tv = cs.simplifyType(bindingConstraint->getFirstType())
-                             ->getRValueType()
-                             ->getAs<TypeVariableType>();
-              // Only do this for simple type variable bindings, not for
-              // bindings like: ($T1) -> $T2 bind String -> Int
-              if (!tv)
-                return false;
+    auto *choice = choices.front();
+    if (choice->getKind() != ConstraintKind::Bind)
+      continue;
 
-              // If we've seen a variable before, make sure that this is
-              // the same one.
-              if (commonTypeVariable == tv)
-                return true;
-              if (commonTypeVariable)
-                return false;
+    // We can judge the whole disjunction by its first choice
+    // because all choices of a bind overload set should
+    // have the same left-hand side.
+    // Only do this for simple type variable bindings, not for
+    // bindings like: ($T1) -> $T2 bind String -> Int
+    auto *typeVar = getAsTypeVar(choice->getFirstType());
+    if (!typeVar)
+      continue;
 
-              commonTypeVariable = tv;
-              return true;
-            })) {
-      bindingDisjunctions.push_back(disjunction);
-    }
-  }
-
-  for (auto *disjunction : bindingDisjunctions) {
-    auto nested = disjunction->getNestedConstraints();
-    assert(!nested.empty());
-    auto *tv = cs.simplifyType(nested[0]->getFirstType())
-                   ->getRValueType()
-                   ->getAs<TypeVariableType>();
-    assert(tv);
+    if (!firstBindDisjunction)
+      firstBindDisjunction = disjunction;
 
     llvm::SetVector<Constraint *> constraints;
     cs.getConstraintGraph().gatherConstraints(
-        tv, constraints, ConstraintGraph::GatheringKind::EquivalenceClass,
+        typeVar, constraints, ConstraintGraph::GatheringKind::EquivalenceClass,
         [](Constraint *constraint) {
           return constraint->getKind() == ConstraintKind::Conversion;
         });
 
     for (auto *constraint : constraints) {
-      auto toType =
-          cs.simplifyType(constraint->getSecondType())->getRValueType();
-      auto *toTV = toType->getAs<TypeVariableType>();
-      if (tv != toTV)
-        continue;
-
-      return disjunction;
+      if (typeVar == getAsTypeVar(constraint->getSecondType()))
+        return disjunction;
     }
   }
 
@@ -1919,10 +1898,7 @@
   // those. These ensure that we attempt to bind types earlier than
   // trying the elements of other disjunctions, which can often mean
   // we fail faster.
-  if (!bindingDisjunctions.empty())
-    return bindingDisjunctions[0];
-
-  return nullptr;
+  return firstBindDisjunction;
 }
 
 Constraint *ConstraintSystem::selectDisjunction() {
diff --git a/lib/Sema/Constraint.cpp b/lib/Sema/Constraint.cpp
index aef0dd1..a724019 100644
--- a/lib/Sema/Constraint.cpp
+++ b/lib/Sema/Constraint.cpp
@@ -722,6 +722,25 @@
     return constraints.front();
   }
 
+#ifndef NDEBUG
+  assert(!constraints.empty());
+  // Verify that all disjunction choices have the same left-hand side.
+  Type commonType;
+  assert(llvm::all_of(constraints, [&](const Constraint *choice) -> bool {
+    // If this disjunction is formed from "fixed" constraints
+    // (restrictions or fixes), don't try to validate it.
+    if (choice->HasRestriction || choice->getFix())
+      return true;
+
+    auto currentType = choice->getFirstType();
+    if (!commonType) {
+      commonType = currentType;
+      return true;
+    }
+    return commonType->isEqual(currentType);
+  }));
+#endif
+
   // Create the disjunction constraint.
   uniqueTypeVariables(typeVars);
   unsigned size = totalSizeToAlloc<TypeVariableType*>(typeVars.size());