Merge pull request #19971 from rudkx/sort-designated-types
[ConstraintSystem] Sort the designated types based on actual argument types.
diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp
index c5bb33e..3b1a18f 100644
--- a/lib/Sema/CSSolver.cpp
+++ b/lib/Sema/CSSolver.cpp
@@ -1683,14 +1683,91 @@
return operatorDecl->getDesignatedNominalTypes();
}
+// Reorder `nominalTypes` (the operator's designated types) in place so that
+// the types suggested by an actual application of the operator are tried
+// first:
+//   1. the nominal types of the concrete argument types, in argument order;
+//   2. then the nominal default types of the argument literal protocols.
+// All matched types end up in a prefix of `nominalTypes`; the remaining
+// types follow in unspecified order. If no applicable-function constraint
+// with usable argument information is found, `nominalTypes` is unchanged.
+//
+// `bindOverload` is a constraint from the operator's disjunction; its first
+// type must be a type variable (it is cast with castTo<TypeVariableType>).
+void ConstraintSystem::sortDesignatedTypes(
+    SmallVectorImpl<NominalTypeDecl *> &nominalTypes,
+    Constraint *bindOverload) {
+  auto *tyvar = bindOverload->getFirstType()->castTo<TypeVariableType>();
+  llvm::SetVector<Constraint *> applicableFns;
+  getConstraintGraph().gatherConstraints(
+      tyvar, applicableFns, ConstraintGraph::GatheringKind::EquivalenceClass,
+      [](Constraint *match) {
+        return match->getKind() == ConstraintKind::ApplicableFunction;
+      });
+
+  // FIXME: This is not true when we run the constraint optimizer.
+  // assert(applicableFns.size() <= 1);
+
+  // We have a disjunction for an operator but no application of it,
+  // so it's being passed as an argument.
+  if (applicableFns.empty())
+    return;
+
+  // FIXME: We have more than one applicable per disjunction as a
+  // result of merging disjunction type variables. We may want
+  // to rip that out at some point.
+  //
+  // Use the first application that tells us something about its
+  // arguments: concrete types or conformances to literal protocols.
+  Constraint *foundApplicable = nullptr;
+  for (auto *applicableFn : applicableFns) {
+    auto *fnTy = applicableFn->getFirstType()->castTo<FunctionType>();
+    ArgumentInfoCollector argInfo(*this, fnTy);
+    if (!argInfo.getTypes().empty() || !argInfo.getLiteralProtocols().empty()) {
+      foundApplicable = applicableFn;
+      break;
+    }
+  }
+
+  if (!foundApplicable)
+    return;
+
+  // FIXME: It would be good to avoid this redundancy.
+  auto *fnTy = foundApplicable->getFirstType()->castTo<FunctionType>();
+  ArgumentInfoCollector argInfo(*this, fnTy);
+
+  // Types in [0, nextType) are already matched and in their final position.
+  size_t nextType = 0;
+
+  // Move the first not-yet-matched occurrence of `nominal` into slot
+  // `nextType`. The search must include index `nextType` itself: starting
+  // at `nextType + 1` would miss a match already sitting in that slot and
+  // then let a later match displace it.
+  auto moveToFront = [&](NominalTypeDecl *nominal) {
+    if (!nominal)
+      return;
+    for (size_t i = nextType; i < nominalTypes.size(); ++i) {
+      if (nominalTypes[i] == nominal) {
+        std::swap(nominalTypes[nextType], nominalTypes[i]);
+        ++nextType;
+        return;
+      }
+    }
+  };
+
+  for (auto argType : argInfo.getTypes())
+    moveToFront(argType->getAnyNominal());
+
+  // At most one unmatched type remains, so there is nothing left to order.
+  if (nextType + 1 >= nominalTypes.size())
+    return;
+
+  for (auto *protocol : argInfo.getLiteralProtocols()) {
+    auto defaultType = TC.getDefaultType(protocol, DC);
+    // getDefaultType can produce a null Type (presumably for literal
+    // protocols without a default type, e.g. ExpressibleByNilLiteral);
+    // there is nothing to prefer in that case.
+    if (!defaultType)
+      continue;
+    moveToFront(defaultType->getAnyNominal());
+  }
+}
+
void ConstraintSystem::partitionForDesignatedTypes(
ArrayRef<Constraint *> Choices, ConstraintMatchLoop forEachChoice,
PartitionAppendCallback appendPartition) {
- auto designatedNominalTypes = getOperatorDesignatedNominalTypes(Choices[0]);
- if (designatedNominalTypes.empty())
+ auto types = getOperatorDesignatedNominalTypes(Choices[0]);
+ if (types.empty())
return;
+ SmallVector<NominalTypeDecl *, 4> designatedNominalTypes(types.begin(),
+ types.end());
+
+ if (designatedNominalTypes.size() > 1)
+ sortDesignatedTypes(designatedNominalTypes, Choices[0]);
+
SmallVector<SmallVector<unsigned, 4>, 4> definedInDesignatedType;
SmallVector<SmallVector<unsigned, 4>, 4> definedInExtensionOfDesignatedType;
diff --git a/lib/Sema/ConstraintSystem.h b/lib/Sema/ConstraintSystem.h
index a76baa4..aa11af7 100644
--- a/lib/Sema/ConstraintSystem.h
+++ b/lib/Sema/ConstraintSystem.h
@@ -3227,6 +3227,12 @@
typedef std::function<void(SmallVectorImpl<unsigned> &options)>
PartitionAppendCallback;
+ // Attempt to sort nominalTypes in place based on what we can discover
+ // about calls into the overloads in the disjunction that bindOverload is
+ // a part of.
+ //
+ // Designated types matching the nominal types of the concrete argument
+ // types of such a call are moved toward the front, followed by those
+ // matching the default types of the literal protocols reported for the
+ // arguments. If no applicable-function constraint with usable argument
+ // information is found, nominalTypes is left unchanged.
+ //
+ // \param nominalTypes The operator's designated nominal types; reordered
+ //   in place.
+ // \param bindOverload A constraint from the operator's disjunction; its
+ //   first type is expected to be a type variable.
+ void sortDesignatedTypes(SmallVectorImpl<NominalTypeDecl *> &nominalTypes,
+ Constraint *bindOverload);
+
// Partition the choices in a disjunction based on those that match
// the designated types for the operator that the disjunction was
// formed for.
diff --git a/validation-test/Sema/type_checker_perf/fast/rdar17024694.swift b/validation-test/Sema/type_checker_perf/fast/rdar17024694.swift
index 0181500..4a799ea 100644
--- a/validation-test/Sema/type_checker_perf/fast/rdar17024694.swift
+++ b/validation-test/Sema/type_checker_perf/fast/rdar17024694.swift
@@ -1,4 +1,4 @@
-// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1
+// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 -swift-version 5 -solver-disable-shrink -disable-constraint-solver-performance-hacks -solver-enable-operator-designated-types
// REQUIRES: tools-release,no_asserts
_ = (2...100).reversed().filter({ $0 % 11 == 0 }).map {