Merge pull request #19971 from rudkx/sort-designated-types

[ConstraintSystem] Sort the designated types based on actual argument…
diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp
index c5bb33e..3b1a18f 100644
--- a/lib/Sema/CSSolver.cpp
+++ b/lib/Sema/CSSolver.cpp
@@ -1683,14 +1683,91 @@
   return operatorDecl->getDesignatedNominalTypes();
 }
 
+// Reorders nominalTypes in place: types matching the concrete argument types
+// of the operator application come first, then the literal default types.
+void ConstraintSystem::sortDesignatedTypes(
+    SmallVectorImpl<NominalTypeDecl *> &nominalTypes,
+    Constraint *bindOverload) {
+  auto *tyvar = bindOverload->getFirstType()->castTo<TypeVariableType>();
+  llvm::SetVector<Constraint *> applicableFns;
+  getConstraintGraph().gatherConstraints(
+      tyvar, applicableFns, ConstraintGraph::GatheringKind::EquivalenceClass,
+      [](Constraint *match) {
+        return match->getKind() == ConstraintKind::ApplicableFunction;
+      });
+
+  // FIXME: This is not true when we run the constraint optimizer.
+  // assert(applicableFns.size() <= 1);
+
+  // We have a disjunction for an operator but no application of it,
+  // so it's being passed as an argument.
+  if (applicableFns.empty())
+    return;
+
+  // FIXME: We have more than one applicable per disjunction as a
+  //        result of merging disjunction type variables. We may want
+  //        to rip that out at some point.
+  Constraint *foundApplicable = nullptr;
+  for (auto *applicableFn : applicableFns) {
+    auto *fnTy = applicableFn->getFirstType()->castTo<FunctionType>();
+    ArgumentInfoCollector argInfo(*this, fnTy);
+    // Stop if we hit anything with concrete types or conformances to
+    // literals.
+    if (!argInfo.getTypes().empty() || !argInfo.getLiteralProtocols().empty()) {
+      foundApplicable = applicableFn;
+      break;
+    }
+  }
+
+  if (!foundApplicable)
+    return;
+
+  // FIXME: It would be good to avoid this redundancy.
+  auto *fnTy = foundApplicable->getFirstType()->castTo<FunctionType>();
+  ArgumentInfoCollector argInfo(*this, fnTy);
+
+  // Slots [0, nextType) are already placed. placeNext searches from nextType
+  // itself (not past it) so a type already in the next slot is still counted.
+  size_t nextType = 0;
+  auto placeNext = [&](NominalTypeDecl *nominal) {
+    for (size_t i = nextType; i < nominalTypes.size(); ++i) {
+      if (nominal == nominalTypes[i]) {
+        std::swap(nominalTypes[nextType], nominalTypes[i]);
+        ++nextType;
+        return;
+      }
+    }
+  };
+
+  for (auto argType : argInfo.getTypes())
+    placeNext(argType->getAnyNominal());
+
+  // With at most one slot left, no further reordering is possible.
+  if (nextType + 1 >= nominalTypes.size())
+    return;
+
+  for (auto *protocol : argInfo.getLiteralProtocols()) {
+    // getDefaultType can yield a null Type (presumably for protocols such
+    // as the nil-literal one that have no default); skip those.
+    if (auto defaultType = TC.getDefaultType(protocol, DC))
+      placeNext(defaultType->getAnyNominal());
+  }
+}
+
 void ConstraintSystem::partitionForDesignatedTypes(
     ArrayRef<Constraint *> Choices, ConstraintMatchLoop forEachChoice,
     PartitionAppendCallback appendPartition) {
 
-  auto designatedNominalTypes = getOperatorDesignatedNominalTypes(Choices[0]);
-  if (designatedNominalTypes.empty())
+  auto types = getOperatorDesignatedNominalTypes(Choices[0]);
+  if (types.empty())
     return;
 
+  SmallVector<NominalTypeDecl *, 4> designatedNominalTypes(types.begin(),
+                                                           types.end());
+
+  if (designatedNominalTypes.size() > 1)
+    sortDesignatedTypes(designatedNominalTypes, Choices[0]);
+
   SmallVector<SmallVector<unsigned, 4>, 4> definedInDesignatedType;
   SmallVector<SmallVector<unsigned, 4>, 4> definedInExtensionOfDesignatedType;
 
diff --git a/lib/Sema/ConstraintSystem.h b/lib/Sema/ConstraintSystem.h
index a76baa4..aa11af7 100644
--- a/lib/Sema/ConstraintSystem.h
+++ b/lib/Sema/ConstraintSystem.h
@@ -3227,6 +3227,12 @@
   typedef std::function<void(SmallVectorImpl<unsigned> &options)>
       PartitionAppendCallback;
 
+  // Attempt to sort nominalTypes, in place, based on what we can discover
+  // about calls into the overloads in the disjunction that bindOverload is
+  // a part of (concrete argument types first, then literal default types).
+  void sortDesignatedTypes(SmallVectorImpl<NominalTypeDecl *> &nominalTypes,
+                           Constraint *bindOverload);
+
   // Partition the choices in a disjunction based on those that match
   // the designated types for the operator that the disjunction was
   // formed for.
diff --git a/validation-test/Sema/type_checker_perf/fast/rdar17024694.swift b/validation-test/Sema/type_checker_perf/fast/rdar17024694.swift
index 0181500..4a799ea 100644
--- a/validation-test/Sema/type_checker_perf/fast/rdar17024694.swift
+++ b/validation-test/Sema/type_checker_perf/fast/rdar17024694.swift
@@ -1,4 +1,4 @@
-// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1
+// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 -swift-version 5 -solver-disable-shrink -disable-constraint-solver-performance-hacks -solver-enable-operator-designated-types
 // REQUIRES: tools-release,no_asserts
 
 _ = (2...100).reversed().filter({ $0 % 11 == 0 }).map {