Merge pull request #14454 from DougGregor/gsb-term-rewriting

[GSB] Term rewriting for same-type constraints
diff --git a/include/swift/AST/GenericSignatureBuilder.h b/include/swift/AST/GenericSignatureBuilder.h
index b9c0e8d..dff956a 100644
--- a/include/swift/AST/GenericSignatureBuilder.h
+++ b/include/swift/AST/GenericSignatureBuilder.h
@@ -258,11 +258,13 @@
     Type getTypeInContext(GenericSignatureBuilder &builder,
                           GenericEnvironment *genericEnv);
 
-    /// Dump a debugging representation of this equivalence class.
-    void dump(llvm::raw_ostream &out) const;
+    /// Dump a debugging representation of this equivalence class.
+    void dump(llvm::raw_ostream &out,
+              GenericSignatureBuilder *builder = nullptr) const;
 
-    LLVM_ATTRIBUTE_DEPRECATED(void dump() const,
-                              "only for use in the debugger");
+    LLVM_ATTRIBUTE_DEPRECATED(
+                  void dump(GenericSignatureBuilder *builder = nullptr) const,
+                  "only for use in the debugger");
 
     /// Caches.
 
@@ -271,9 +273,8 @@
       /// The cached anchor itself.
       Type anchor;
 
-      /// The number of members of the equivalence class when the archetype
-      /// anchor was cached.
-      unsigned numMembers;
+      /// The builder generation at which the anchor was last computed.
+      unsigned lastGeneration;
     } archetypeAnchorCache;
 
     /// Describes a cached nested type.
@@ -444,6 +445,11 @@
   /// Note that we have added the nested type nestedPA
   void addedNestedType(PotentialArchetype *nestedPA);
 
+  /// Add a rewrite rule for a same-type constraint between the given
+  /// potential archetypes.
+  void addSameTypeRewriteRule(PotentialArchetype *type1,
+                              PotentialArchetype *type2);
+
   /// \brief Add a new conformance requirement specifying that the given
   /// potential archetypes are equivalent.
   ConstraintResult addSameTypeRequirementBetweenArchetypes(
@@ -802,6 +808,12 @@
   /// Determine whether the two given types are in the same equivalence class.
   bool areInSameEquivalenceClass(Type type1, Type type2);
 
+  /// Simplify the given dependent type down to its canonical representation.
+  ///
+  /// \returns null if the type involves dependent member types that
+  /// don't have associated types.
+  Type getCanonicalTypeParameter(Type type);
+
   /// Verify the correctness of the given generic signature.
   ///
   /// This routine will test that the given generic signature is both minimal
diff --git a/lib/AST/GenericSignatureBuilder.cpp b/lib/AST/GenericSignatureBuilder.cpp
index f473456..d657f7c 100644
--- a/lib/AST/GenericSignatureBuilder.cpp
+++ b/lib/AST/GenericSignatureBuilder.cpp
@@ -107,6 +107,190 @@
           "Delayed requirements left unresolved");
 STATISTIC(NumConditionalRequirementsAdded,
           "# of conditional requirements added");
+STATISTIC(NumComponentsCollapsedViaRewriting,
+          "# of same-type components collapsed via term rewriting");
+
+namespace  {
+
+/// A purely-relative rewrite path consisting of a (possibly empty)
+/// sequence of associated type references.
+using RelativeRewritePath = ArrayRef<AssociatedTypeDecl *>;
+
+/// Describes a rewrite path, which contains an optional base (generic
+/// parameter) followed by a sequence of associated type references.
+class RewritePath {
+  Optional<GenericParamKey> base;
+  TinyPtrVector<AssociatedTypeDecl *> path;
+
+public:
+  RewritePath() { }
+
+  enum PathOrder {
+    Forward,
+    Reverse,
+  };
+
+  /// Form a rewrite path given an optional base and a relative rewrite path.
+  RewritePath(Optional<GenericParamKey> base, RelativeRewritePath path,
+              PathOrder order);
+
+  /// Retrieve the base of the given rewrite path.
+  ///
+  /// When present, it indicates that the entire path will be rebased on
+  /// the given base generic parameter. This is required for describing
+  /// rewrites on type parameters themselves, e.g., T == U.
+  Optional<GenericParamKey> getBase() const { return base; }
+
+  /// Retrieve the sequence of associated type references that describes
+  /// the path.
+  ArrayRef<AssociatedTypeDecl *> getPath() const { return path; }
+
+  /// Decompose a type into a path.
+  ///
+  /// \returns the path, or None if it contained unresolved dependent member
+  /// types or was not rooted at a generic parameter.
+  Optional<RewritePath> static createPath(Type type);
+
+  /// Decompose a potential archetype into a path.
+  ///
+  /// \returns the path, or None if it contained potential archetypes
+  /// that are not resolved to associated types.
+  Optional<RewritePath> static createPath(PotentialArchetype *pa);
+
+  /// Longest common prefix with \c other, or None if the bases differ.
+  Optional<RewritePath> commonPath(const RewritePath &other) const;
+
+  /// Form a canonical, dependent type.
+  ///
+  /// This requires that the rewrite path have a base.
+  CanType formDependentType(ASTContext &ctx) const;
+
+  /// Compare with \c other; a negative result means this path is better.
+  int compare(const RewritePath &other) const;
+
+  /// Print this path.
+  void print(llvm::raw_ostream &out) const;
+
+  LLVM_ATTRIBUTE_DEPRECATED(void dump() const LLVM_ATTRIBUTE_USED,
+                            "only for use within the debugger") {
+    print(llvm::errs());
+  }
+
+  friend bool operator==(const RewritePath &lhs, const RewritePath &rhs) {
+    return lhs.getBase() == rhs.getBase() && lhs.getPath() == rhs.getPath();
+  }
+};
+
+/// A node within the prefix tree that is used to match associated type
+/// references.
+class RewriteTreeNode {
+  /// The associated type that leads to this node.
+  ///
+  /// The bit indicates whether there is a rewrite rule for this particular
+  /// node. If the bit is not set, \c rewrite is invalid.
+  llvm::PointerIntPair<AssociatedTypeDecl *, 1, bool> assocTypeAndHasRewrite;
+
+  /// The sequence of associated types to which a reference to this associated
+  /// type (from the equivalence class root) can be rewritten. This field is
+  /// only valid when the bit of \c assocTypeAndHasRewrite is set.
+  ///
+  /// Consider a requirement "Self.A.B.C == C". This will be encoded as
+  /// a prefix tree starting at the equivalence class for Self with
+  /// the following nodes:
+  ///
+  /// (assocType: A,
+  ///   children: [
+  ///     (assocType: B,
+  ///       children: [
+  ///         (assocType: C, rewrite: [C], children: [])
+  ///       ])
+  ///   ])
+  RewritePath rewrite;
+
+  /// The child nodes, which extend the sequence to be matched.
+  ///
+  /// The child nodes are sorted by the associated type declaration
+  /// pointers (null matches first), so we can perform binary searches.
+  llvm::TinyPtrVector<RewriteTreeNode *> children;
+
+public:
+  ~RewriteTreeNode();
+
+  RewriteTreeNode(AssociatedTypeDecl *assocType)
+    : assocTypeAndHasRewrite(assocType, false) { }
+
+  /// Retrieve the associated type declaration one must match to use this
+  /// node, which is null for tree roots and continuation nodes.
+  AssociatedTypeDecl *getMatch() const {
+    return assocTypeAndHasRewrite.getPointer();
+  }
+
+  /// Determine whether this particular node has a rewrite rule.
+  bool hasRewriteRule() const {
+    return assocTypeAndHasRewrite.getInt();
+  }
+
+  /// Set a new rewrite rule for this particular node. This can only be
+  /// performed once.
+  void setRewriteRule(RewritePath replacementPath) {
+    assert(!hasRewriteRule());
+    assocTypeAndHasRewrite.setInt(true);
+    rewrite = replacementPath;
+  }
+
+  /// Retrieve the path to which this node will be rewritten.
+  const RewritePath &getRewriteRule() const {
+    assert(hasRewriteRule());
+    return rewrite;
+  }
+
+  /// Add a new rewrite rule to this tree node.
+  ///
+  /// \param matchPath The path of associated type declarations that must
+  /// be matched to produce a rewrite.
+  ///
+  /// \param replacementPath The sequence of associated type declarations
+  /// with which a match will be replaced.
+  void addRewriteRule(RelativeRewritePath matchPath,
+                      RewritePath replacementPath);
+
+  /// Enumerate all of the paths to which the given matched path can be
+  /// rewritten.
+  ///
+  /// \param matchPath The path to match.
+  ///
+  /// \param callback A callback that will be invoked with (prefix, rewrite)
+  /// pairs, where \c prefix is the length of the matching prefix of
+  /// \c matchPath that matched and \c rewrite is the path to which it can
+  /// be rewritten.
+  void enumerateRewritePaths(
+                       RelativeRewritePath matchPath,
+                       llvm::function_ref<void(unsigned, RewritePath)> callback,
+                       unsigned depth = 0) const;
+
+  /// Find the best rewrite rule for \c path, as a (length, rewrite) pair.
+  ///
+  /// \param path The path to match.
+  /// \param prefixLength The length of the prefix leading up to \c path.
+  Optional<std::pair<unsigned, RewritePath>>
+  bestMatch(GenericParamKey base, RelativeRewritePath path,
+            unsigned prefixLength);
+
+  /// Merge this tree's rewrite rules into \c other.
+  void mergeInto(RewriteTreeNode *other);
+
+  LLVM_ATTRIBUTE_DEPRECATED(void dump() const LLVM_ATTRIBUTE_USED,
+                            "only for use within the debugger");
+
+  /// Dump the tree.
+  void dump(llvm::raw_ostream &out, bool lastChild = true) const;
+
+private:
+  /// Recursive step of mergeInto(); \c matchPath is the path so far.
+  void mergeIntoRec(RewriteTreeNode *other,
+                    llvm::SmallVectorImpl<AssociatedTypeDecl *> &matchPath);
+};
+}
 
 struct GenericSignatureBuilder::Implementation {
   /// Allocator.
@@ -131,6 +315,9 @@
   /// Equivalence classes that are not currently being used.
   std::vector<void *> FreeEquivalenceClasses;
 
+  /// The roots of the rewrite trees, one per equivalence class.
+  DenseMap<const EquivalenceClass *, RewriteTreeNode *> RewriteTreeRoots;
+
   /// The generation number, which is incremented whenever we successfully
   /// introduce a new constraint.
   unsigned Generation = 0;
@@ -141,16 +328,6 @@
   /// Whether we are currently processing delayed requirements.
   bool ProcessingDelayedRequirements = false;
 
-  /// Tear down an implementation.
-  ~Implementation();
-
-  /// Allocate a new equivalence class with the given representative.
-  EquivalenceClass *allocateEquivalenceClass(
-                                       PotentialArchetype *representative);
-
-  /// Deallocate the given equivalence class, returning it to the free list.
-  void deallocateEquivalenceClass(EquivalenceClass *equivClass);
-
   /// Whether there were any errors.
   bool HadAnyError = false;
 
@@ -161,12 +338,35 @@
   /// Whether we've already finalized the builder.
   bool finalized = false;
 #endif
+
+  /// Tear down an implementation.
+  ~Implementation();
+
+  /// Allocate a new equivalence class with the given representative.
+  EquivalenceClass *allocateEquivalenceClass(
+                                       PotentialArchetype *representative);
+
+  /// Deallocate the given equivalence class, returning it to the free list.
+  void deallocateEquivalenceClass(EquivalenceClass *equivClass);
+
+  /// Retrieve the rewrite tree root for the given equivalence class,
+  /// if present.
+  RewriteTreeNode *getRewriteTreeRootIfPresent(
+                                      const EquivalenceClass *equivClass);
+
+  /// Retrieve the rewrite tree root for the given equivalence class,
+  /// creating it if needed.
+  RewriteTreeNode *getOrCreateRewriteTreeRoot(
+                                        const EquivalenceClass *equivClass);
 };
 
 #pragma mark Memory management
 GenericSignatureBuilder::Implementation::~Implementation() {
   for (auto pa : PotentialArchetypes)
     pa->~PotentialArchetype();
+
+  for (const auto &root : RewriteTreeRoots)
+    delete root.second;
 }
 
 EquivalenceClass *
@@ -1915,31 +2115,13 @@
   return populateResult((nestedTypeNameCache[name] = std::move(entry)));
 }
 
-/// Determine whether any part of this potential archetype's path to the
-/// root contains the given equivalence class.
-static bool pathContainsEquivalenceClass(GenericSignatureBuilder &builder,
-                                         PotentialArchetype *pa,
-                                         EquivalenceClass *equivClass) {
-  // Chase the potential archetype up to the root.
-  for (; pa; pa = pa->getParent()) {
-    // Check whether this potential archetype is in the given equivalence
-    // class.
-    if (pa->getOrCreateEquivalenceClass(builder) == equivClass)
-      return true;
-  }
-
-  return false;
-}

 Type EquivalenceClass::getAnchor(
                             GenericSignatureBuilder &builder,
                             TypeArrayView<GenericTypeParamType> genericParams) {
-  // Check whether the cache is valid.
-  if (archetypeAnchorCache.anchor &&
-      archetypeAnchorCache.numMembers == members.size()) {
-    ++NumArchetypeAnchorCacheHits;
+  // Substitute into the anchor with the given generic parameters.
+  auto substAnchor = [&] {
+    if (genericParams.empty()) return archetypeAnchorCache.anchor;
 
-    // Reparent the anchor using genericParams.
     return archetypeAnchorCache.anchor.subst(
              [&](SubstitutableType *dependentType) {
                if (auto gp = dyn_cast<GenericTypeParamType>(dependentType)) {
@@ -1951,73 +2133,43 @@
                return Type(dependentType);
              },
              MakeAbstractConformanceForGenericType());
+
+  };
+
+  // Check whether the cache is valid.
+  if (archetypeAnchorCache.anchor &&
+      archetypeAnchorCache.lastGeneration == builder.Impl->Generation) {
+    ++NumArchetypeAnchorCacheHits;
+    return substAnchor();
   }
 
-  // Map the members of this equivalence class to the best associated type
-  // within that equivalence class.
-  llvm::SmallDenseMap<EquivalenceClass *, AssociatedTypeDecl *> nestedTypes;
+  // Check whether we already have an anchor, in which case we
+  // can simplify it further.
+  if (archetypeAnchorCache.anchor) {
+    // Record the cache miss.
+    ++NumArchetypeAnchorCacheMisses;
 
-  Type bestGenericParam;
+    // Update the anchor by simplifying it further.
+    archetypeAnchorCache.anchor =
+      builder.getCanonicalTypeParameter(archetypeAnchorCache.anchor);
+    archetypeAnchorCache.lastGeneration = builder.Impl->Generation;
+    return substAnchor();
+  }
+
+  // Form the anchor from the first member that can be canonicalized.
   for (auto member : members) {
-    // If the member is a generic parameter, keep the best generic parameter.
-    if (member->isGenericParam()) {
-      Type genericParamType = member->getDependentType(genericParams);
-      if (!bestGenericParam ||
-          compareDependentTypes(genericParamType, bestGenericParam) < 0)
-        bestGenericParam = genericParamType;
-      continue;
-    }
+    auto anchorType =
+      builder.getCanonicalTypeParameter(member->getDependentType(genericParams));
+    if (!anchorType) continue;
 
-    // If we saw a generic parameter, ignore any nested types.
-    if (bestGenericParam) continue;
-
-    // If the nested type doesn't have an associated type, skip it.
-    auto assocType = member->getResolvedAssociatedType();
-    if (!assocType) continue;
-
-    // Dig out the equivalence class of the parent.
-    auto parentEquivClass =
-      member->getParent()->getOrCreateEquivalenceClass(builder);
-
-    // If the path from this member to the root contains this equivalence
-    // class, it cannot be part of the anchor.
-    if (pathContainsEquivalenceClass(builder, member->getParent(), this))
-      continue;
-
-    // Take the best associated type for this equivalence class.
-    assocType = assocType->getAssociatedTypeAnchor();
-    auto &bestAssocType = nestedTypes[parentEquivClass];
-    if (!bestAssocType ||
-        compareAssociatedTypes(assocType, bestAssocType) < 0)
-      bestAssocType = assocType;
+    // Record the cache miss and update the cache.
+    ++NumArchetypeAnchorCacheMisses;
+    archetypeAnchorCache.anchor = anchorType;
+    archetypeAnchorCache.lastGeneration = builder.Impl->Generation;
+    return substAnchor();
   }
 
-  // If we found a generic parameter, return that.
-  if (bestGenericParam)
-    return bestGenericParam;
-
-  // Determine the best anchor among the parent equivalence classes.
-  Type bestParentAnchor;
-  AssociatedTypeDecl *bestAssocType = nullptr;
-  std::pair<EquivalenceClass *, Identifier> bestNestedType;
-  for (const auto &nestedType : nestedTypes) {
-    auto parentAnchor = nestedType.first->getAnchor(builder, genericParams);
-    if (!bestParentAnchor ||
-        compareDependentTypes(parentAnchor, bestParentAnchor) < 0) {
-      bestParentAnchor = parentAnchor;
-      bestAssocType = nestedType.second;
-    }
-  }
-
-  // Form the anchor type.
-  Type anchorType = DependentMemberType::get(bestParentAnchor, bestAssocType);
-
-  // Record the cache miss and update the cache.
-  ++NumArchetypeAnchorCacheMisses;
-  archetypeAnchorCache.anchor = anchorType;
-  archetypeAnchorCache.numMembers = members.size();
-
-  return anchorType;
+  llvm_unreachable("Unable to compute anchor");
 }
 
 Type EquivalenceClass::getTypeInContext(GenericSignatureBuilder &builder,
@@ -2138,7 +2290,8 @@
   return archetype;
 }
 
-void EquivalenceClass::dump(llvm::raw_ostream &out) const {
+void EquivalenceClass::dump(llvm::raw_ostream &out,
+                            GenericSignatureBuilder *builder) const {
   out << "Equivalence class represented by "
     << members.front()->getRepresentative()->getDebugName() << ":\n";
   out << "Members: ";
@@ -2183,6 +2336,13 @@
 
   out << "\n";
 
+  if (builder) {
+    if (auto rewriteRoot = builder->Impl->getRewriteTreeRootIfPresent(this)) {
+      out << "---Rewrite tree---\n";
+      rewriteRoot->dump(out);
+    }
+  }
+
   {
     out << "---GraphViz output for same-type constraints---\n";
 
@@ -2215,8 +2375,8 @@
   }
 }
 
-void EquivalenceClass::dump() const {
-  dump(llvm::errs());
+void EquivalenceClass::dump(GenericSignatureBuilder *builder) const {
+  dump(llvm::errs(), builder);
 }
 
 void DelayedRequirement::dump(llvm::raw_ostream &out) const {
@@ -2523,6 +2683,24 @@
   return 0;
 }
 
+/// Compare two dependent paths; a negative result means \c path1 is better.
+static int compareDependentPaths(ArrayRef<AssociatedTypeDecl *> path1,
+                                 ArrayRef<AssociatedTypeDecl *> path2) {
+  // Shorter paths win.
+  if (path1.size() != path2.size())
+    return path1.size() < path2.size() ? -1 : 1;
+
+  // The paths are the same length, so order by comparing the associated
+  // types.
+  for (unsigned index : indices(path1)) {
+    if (int result = compareAssociatedTypes(path1[index], path2[index]))
+      return result;
+  }
+
+  // Identical paths.
+  return 0;
+}
+
 namespace {
   /// Function object used to suppress conflict diagnoses when we know we'll
   /// see them again later.
@@ -2883,6 +3061,507 @@
   }
 }
 
+#pragma mark Rewrite tree
+RewritePath::RewritePath(Optional<GenericParamKey> base,
+                         RelativeRewritePath path,
+                         PathOrder order)
+  : base(base)
+{
+  switch (order) {
+  case Forward:
+    this->path.insert(this->path.begin(), path.begin(), path.end());
+    break;
+
+  case Reverse:
+    this->path.insert(this->path.begin(), path.rbegin(), path.rend());
+    break;
+  }
+}
+
+Optional<RewritePath> RewritePath::createPath(PotentialArchetype *pa) {
+  SmallVector<AssociatedTypeDecl *, 4> path;
+  while (auto parent = pa->getParent()) {
+    auto assocType = pa->getResolvedAssociatedType();
+    if (!assocType) return None;
+
+    path.push_back(assocType);
+    pa = parent;
+  }
+
+  return RewritePath(pa->getGenericParamKey(), path, Reverse);
+}
+
+Optional<RewritePath> RewritePath::createPath(Type type) {
+  SmallVector<AssociatedTypeDecl *, 4> path;
+  while (auto depMemTy = type->getAs<DependentMemberType>()) {
+    auto assocType = depMemTy->getAssocType();
+    if (!assocType) return None;
+
+    path.push_back(assocType);
+    type = depMemTy->getBase();
+  }
+
+  auto genericParam = type->getAs<GenericTypeParamType>();
+  if (!genericParam) return None;
+
+  return RewritePath(GenericParamKey(genericParam), path, Reverse);
+}
+
+Optional<RewritePath> RewritePath::commonPath(const RewritePath &other) const {
+  assert(getBase().hasValue() && other.getBase().hasValue());
+
+  if (*getBase() != *other.getBase()) return None;
+
+  // Find the longest common prefix.
+  RelativeRewritePath path1 = getPath();
+  RelativeRewritePath path2 = other.getPath();
+  if (path1.size() > path2.size())
+    std::swap(path1, path2);
+  unsigned prefixLength =
+    std::mismatch(path1.begin(), path1.end(), path2.begin()).first
+      - path1.begin();
+
+  // Form the common path.
+  return RewritePath(getBase(), path1.slice(0, prefixLength), Forward);
+}
+
+/// Form a dependent type by starting at the given generic parameter and then
+/// following the path of associated types.
+static Type formDependentType(GenericTypeParamType *base,
+                              RelativeRewritePath path) {
+  return std::accumulate(path.begin(), path.end(), Type(base),
+                         [](Type type, AssociatedTypeDecl *assocType) -> Type {
+                           return DependentMemberType::get(type, assocType);
+                         });
+}
+
+/// Form a dependent type by starting at the (canonical) generic parameter for
+/// the given parameter key and then following the path of associated types.
+static Type formDependentType(ASTContext &ctx, GenericParamKey genericParam,
+                              RelativeRewritePath path) {
+  return formDependentType(GenericTypeParamType::get(genericParam.Depth,
+                                                     genericParam.Index,
+                                                     ctx),
+                           path);
+}
+
+CanType RewritePath::formDependentType(ASTContext &ctx) const {
+  assert(getBase());
+  return CanType(::formDependentType(ctx, *getBase(), getPath()));
+}
+
+int RewritePath::compare(const RewritePath &other) const {
+  // Prefer relative to absolute paths.
+  if (getBase().hasValue() != other.getBase().hasValue()) {
+    return other.getBase().hasValue() ? -1 : 1;
+  }
+
+  // Order based on the bases.
+  if (getBase() && *getBase() != *other.getBase())
+    return (*getBase() < *other.getBase()) ? -1 : 1;
+
+  // Order based on the path contents.
+  return compareDependentPaths(getPath(), other.getPath());
+}
+
+void RewritePath::print(llvm::raw_ostream &out) const {
+  out << "[";
+
+  if (getBase()) {
+    out << "(" << getBase()->Depth << ", " << getBase()->Index << ")";
+    if (!getPath().empty()) out << " -> ";
+  }
+
+  interleave(getPath().begin(), getPath().end(),
+             [&](AssociatedTypeDecl *assocType) {
+               out.changeColor(raw_ostream::BLUE);
+               out << assocType->getProtocol()->getName() << "."
+               << assocType->getName();
+               out.resetColor();
+             }, [&] {
+               out << " -> ";
+             });
+  out << "]";
+}
+
+RewriteTreeNode::~RewriteTreeNode() {
+  for (auto child : children)
+    delete child;
+}
+
+namespace {
+/// Function object used to order rewrite tree nodes based on the address
+/// of the associated type they match, with null matches ordered first.
+class OrderTreeRewriteNode {
+  bool compare(AssociatedTypeDecl *lhs, AssociatedTypeDecl *rhs) const {
+    // Make sure null pointers precede everything else.
+    if (static_cast<bool>(lhs) != static_cast<bool>(rhs))
+      return static_cast<bool>(rhs);
+
+    // Use std::less to provide a defined ordering.
+    return std::less<AssociatedTypeDecl *>()(lhs, rhs);
+  }
+
+public:
+  bool operator()(RewriteTreeNode *lhs, AssociatedTypeDecl *rhs) const {
+    return compare(lhs->getMatch(), rhs);
+  }
+
+  bool operator()(AssociatedTypeDecl *lhs, RewriteTreeNode *rhs) const {
+    return compare(lhs, rhs->getMatch());
+  }
+
+  bool operator()(RewriteTreeNode *lhs, RewriteTreeNode *rhs) const {
+    return compare(lhs->getMatch(), rhs->getMatch());
+  }
+};
+}
+
+void RewriteTreeNode::addRewriteRule(RelativeRewritePath matchPath,
+                                     RewritePath replacementPath) {
+  // If the match path is empty, we're adding the rewrite rule to this node.
+  if (matchPath.empty()) {
+    // If we don't already have a rewrite rule, add it.
+    if (!hasRewriteRule()) {
+      setRewriteRule(replacementPath);
+      return;
+    }
+
+    // If we already have this rewrite rule, we're done.
+    if (getRewriteRule() == replacementPath) return;
+
+    // Check whether a continuation child already holds this rewrite rule.
+    auto insertPos = children.begin();
+    while (insertPos != children.end() && !(*insertPos)->getMatch()) {
+      if ((*insertPos)->hasRewriteRule() &&
+          (*insertPos)->getRewriteRule() == replacementPath)
+        return;
+
+      ++insertPos;
+    }
+
+    // We already have a rewrite rule, so add a new child with a
+    // null associated type match to hold the rewrite rule.
+    auto newChild = new RewriteTreeNode(nullptr);
+    newChild->setRewriteRule(replacementPath);
+    children.insert(insertPos, newChild);
+    return;
+  }
+
+  // Find (or create) a child node describing the next step in the match.
+  auto matchFront = matchPath.front();
+  auto childPos =
+    std::lower_bound(children.begin(), children.end(), matchFront,
+                     OrderTreeRewriteNode());
+  if (childPos == children.end() || (*childPos)->getMatch() != matchFront) {
+    childPos = children.insert(childPos, new RewriteTreeNode(matchFront));
+  }
+
+  // Add the rewrite rule to the child.
+  (*childPos)->addRewriteRule(matchPath.slice(1), replacementPath);
+}
+
+void RewriteTreeNode::enumerateRewritePaths(
+                       RelativeRewritePath matchPath,
+                       llvm::function_ref<void(unsigned, RewritePath)> callback,
+                       unsigned depth) const {
+  // Determine whether we know anything about the next step in the path.
+  auto childPos =
+    depth < matchPath.size()
+      ? std::lower_bound(children.begin(), children.end(),
+                         matchPath[depth], OrderTreeRewriteNode())
+      : children.end();
+  if (childPos != children.end() &&
+      (*childPos)->getMatch() == matchPath[depth]) {
+    // Try to match the rest of the path.
+    (*childPos)->enumerateRewritePaths(matchPath, callback, depth + 1);
+  }
+
+  // If we have a rewrite rule at this position, invoke it.
+  if (hasRewriteRule()) {
+    // Report how many path elements matched, along with this rewrite.
+    callback(depth, rewrite);
+  }
+
+  // Walk any children with NULL associated types; they might have more matches.
+  for (auto otherRewrite : children) {
+    if (otherRewrite->getMatch()) break;
+    otherRewrite->enumerateRewritePaths(matchPath, callback, depth);
+  }
+}
+
+Optional<std::pair<unsigned, RewritePath>>
+RewriteTreeNode::bestMatch(GenericParamKey base, RelativeRewritePath path,
+                           unsigned prefixLength) {
+  Optional<std::pair<unsigned, RewritePath>> best;
+  unsigned bestAdjustedLength = 0;
+  enumerateRewritePaths(path,
+                        [&](unsigned length, RewritePath path) {
+    // Determine how much of the original path will be replaced by the rewrite.
+    unsigned adjustedLength = length;
+    if (auto newBase = path.getBase()) {
+      adjustedLength += prefixLength;
+
+      // If the base is unchanged, make sure we're reducing the length.
+      if (*newBase == base && adjustedLength <= path.getPath().size())
+        return;
+    }
+
+    if (adjustedLength == 0) return;
+
+    if (adjustedLength > bestAdjustedLength ||
+        (adjustedLength == bestAdjustedLength &&
+         path.compare(best->second) < 0)) {
+      best = { length, path };
+      bestAdjustedLength = adjustedLength;
+    }
+  });
+
+  return best;
+}
+
+void RewriteTreeNode::mergeInto(RewriteTreeNode *other) {
+  SmallVector<AssociatedTypeDecl *, 4> matchPath;
+  mergeIntoRec(other, matchPath);
+}
+
+void RewriteTreeNode::mergeIntoRec(
+                     RewriteTreeNode *other,
+                     llvm::SmallVectorImpl<AssociatedTypeDecl *> &matchPath) {
+  // FIXME: A destructive version of this operation would be more efficient,
+  // since we generally don't care about \c other after doing this.
+  if (auto assocType = getMatch())
+    matchPath.push_back(assocType);
+
+  // Add this rewrite rule, if there is one.
+  if (hasRewriteRule())
+    other->addRewriteRule(matchPath, rewrite);
+
+  // Recurse into the child nodes.
+  for (auto child : children)
+    child->mergeIntoRec(other, matchPath);
+
+  if (auto assocType = getMatch())
+    matchPath.pop_back();
+}
+
+void RewriteTreeNode::dump() const {
+  dump(llvm::errs());
+}
+
+void RewriteTreeNode::dump(llvm::raw_ostream &out, bool lastChild) const {
+  std::string prefixStr;
+
+  std::function<void(const RewriteTreeNode *, bool lastChild)> print;
+  print = [&](const RewriteTreeNode *node, bool lastChild) {
+    out << prefixStr << " `--";
+
+    // Print the node name.
+    out.changeColor(raw_ostream::GREEN);
+    if (auto assoc = node->getMatch())
+      out << assoc->getProtocol()->getName() << "." << assoc->getName();
+    else
+      out << "(cont'd)";
+    out.resetColor();
+
+    // Print the rewrite, if there is one.
+    if (node->hasRewriteRule()) {
+      out << " --> ";
+      node->rewrite.print(out);
+    }
+
+    out << "\n";
+
+    // Print children, extending the indentation prefix by four characters.
+    prefixStr += ' ';
+    prefixStr += (lastChild ? ' ' : '|');
+    prefixStr += "  ";
+
+    for (auto child : node->children) {
+      print(child, child == node->children.back());
+    }
+
+    prefixStr.erase(prefixStr.end() - 4, prefixStr.end());
+  };
+
+  print(this, lastChild);
+}
+
+RewriteTreeNode *
+GenericSignatureBuilder::Implementation::getRewriteTreeRootIfPresent(
+                                          const EquivalenceClass *equivClass) {
+  auto known = RewriteTreeRoots.find(equivClass);
+  if (known != RewriteTreeRoots.end()) return known->second;
+
+  return nullptr;
+}
+
+RewriteTreeNode *
+GenericSignatureBuilder::Implementation::getOrCreateRewriteTreeRoot(
+                                          const EquivalenceClass *equivClass) {
+  auto known = RewriteTreeRoots.find(equivClass);
+  if (known != RewriteTreeRoots.end()) return known->second;
+
+  auto root = new RewriteTreeNode(nullptr);
+  RewriteTreeRoots[equivClass] = root;
+  return root;
+}
+
+void GenericSignatureBuilder::addSameTypeRewriteRule(PotentialArchetype *pa1,
+                                                     PotentialArchetype *pa2){
+  auto pathOpt1 = RewritePath::createPath(pa1);
+  if (!pathOpt1) return;
+
+  auto pathOpt2 = RewritePath::createPath(pa2);
+  if (!pathOpt2) return;
+
+  auto path1 = std::move(pathOpt1).getValue();
+  auto path2 = std::move(pathOpt2).getValue();
+
+  // Look for a common path.
+  auto prefix = path1.commonPath(path2);
+
+  // If we didn't find a common path, try harder.
+  Type simplifiedType1;
+  Type simplifiedType2;
+  if (!prefix) {
+    // Simplify both sides to uncover a common path (assumes non-null results).
+    simplifiedType1 = getCanonicalTypeParameter(pa1->getDependentType({ }));
+    simplifiedType2 = getCanonicalTypeParameter(pa2->getDependentType({ }));
+    if (simplifiedType1->isEqual(simplifiedType2)) return;
+
+    // Create new paths from the simplified types.
+    path1 = *RewritePath::createPath(simplifiedType1);
+    path2 = *RewritePath::createPath(simplifiedType2);
+
+    // Find a common path.
+    prefix = path1.commonPath(path2);
+  }
+
+  // When we have a common prefix, form a rewrite rule using relative paths.
+  if (prefix) {
+    // Compute the relative paths that follow the common prefix.
+    RelativeRewritePath relPath1
+      = path1.getPath().slice(prefix->getPath().size());
+    RelativeRewritePath relPath2
+      = path2.getPath().slice(prefix->getPath().size());
+    // Order the paths so that we go to the more-canonical path.
+    if (compareDependentPaths(relPath1, relPath2) < 0)
+      std::swap(relPath1, relPath2);
+
+    // Find the equivalence class for the prefix.
+    CanType commonType = prefix->formDependentType(getASTContext());
+    auto equivClass =
+      resolveEquivalenceClass(commonType, ArchetypeResolutionKind::WellFormed);
+    assert(equivClass && "Prefix cannot be resolved?");
+
+    // Add the rewrite rule.
+    auto root = Impl->getOrCreateRewriteTreeRoot(equivClass);
+    root->addRewriteRule(relPath1,
+                         RewritePath(None, relPath2, RewritePath::Forward));
+
+    return;
+  }
+
+  // Otherwise, form a rewrite rule with absolute paths.
+
+  // Find the better path and make sure it's in path2.
+  if (compareDependentTypes(simplifiedType1, simplifiedType2) < 0) {
+    std::swap(path1, path2);
+    std::swap(simplifiedType1, simplifiedType2);
+  }
+
+  // Add the rewrite rule.
+  Type firstBase =
+    GenericTypeParamType::get(path1.getBase()->Depth, path1.getBase()->Index,
+                              getASTContext());
+  auto equivClass =
+    resolveEquivalenceClass(firstBase, ArchetypeResolutionKind::WellFormed);
+  assert(equivClass && "Base cannot be resolved?");
+
+  auto root = Impl->getOrCreateRewriteTreeRoot(equivClass);
+  root->addRewriteRule(path1.getPath(), path2);
+}
+
+/// Simplify the type parameter \p type by repeatedly applying the rewrite
+/// rules recorded for the equivalence classes along its path.
+///
+/// The path is walked from its base generic parameter one associated type at
+/// a time. At each prefix, the rewrite tree of that prefix's equivalence
+/// class, if any, is consulted. On a match, the matched components are
+/// replaced, the base generic parameter is switched when the replacement is
+/// absolute, and the walk restarts from the beginning.
+///
+/// \returns the rewritten dependent type. Returns a null type if \p type
+/// cannot be turned into a rewrite path or its base generic parameter has no
+/// equivalence class.
+Type GenericSignatureBuilder::getCanonicalTypeParameter(Type type) {
+  auto initialPath = RewritePath::createPath(type);
+  if (!initialPath) return nullptr;
+
+  auto genericParamType =
+    GenericTypeParamType::get(initialPath->getBase()->Depth,
+                              initialPath->getBase()->Index,
+                              getASTContext());
+  auto initialEquivClass =
+    resolveEquivalenceClass(genericParamType,
+                            ArchetypeResolutionKind::WellFormed);
+  if (!initialEquivClass) return nullptr;
+
+  // startIndex counts the leading path components already walked;
+  // currentType and equivClass describe that prefix of the path.
+  unsigned startIndex = 0;
+  auto equivClass = initialEquivClass;
+  Type currentType = genericParamType;
+  SmallVector<AssociatedTypeDecl *, 4> path(initialPath->getPath().begin(),
+                                            initialPath->getPath().end());
+  // NOTE(review): each successful rewrite restarts the walk from the base, so
+  // termination relies on the recorded rewrite rules never forming a cycle.
+  do {
+    if (auto rootNode = Impl->getRewriteTreeRootIfPresent(equivClass)) {
+      // Find the best rewrite rule for the path starting at startIndex.
+      auto match =
+        rootNode->bestMatch(genericParamType,
+                            llvm::makeArrayRef(path).slice(startIndex),
+                            startIndex);
+
+      // If we have a match, replace the matched path with the replacement
+      // path.
+      if (match) {
+        // Determine the range in the path which we'll be replacing. An
+        // absolute replacement (one with a base) replaces everything from the
+        // start; a relative one replaces only from startIndex.
+        unsigned replaceStartIndex = match->second.getBase() ? 0 : startIndex;
+        unsigned replaceEndIndex = startIndex + match->first;
+
+        // Overwrite the beginning of the match.
+        auto replacementPath = match->second.getPath();
+        assert((replaceEndIndex - replaceStartIndex) >= replacementPath.size());
+        auto replacementStartPos = path.begin() + replaceStartIndex;
+        std::copy(replacementPath.begin(), replacementPath.end(),
+                  replacementStartPos);
+
+        // Erase the rest.
+        path.erase(replacementStartPos + replacementPath.size(),
+                   path.begin() + replaceEndIndex);
+
+        // If this is an absolute path, use the new base.
+        if (auto newBase = match->second.getBase()) {
+          genericParamType =
+            GenericTypeParamType::get(newBase->Depth, newBase->Index,
+                                      getASTContext());
+          initialEquivClass =
+            resolveEquivalenceClass(genericParamType,
+                                    ArchetypeResolutionKind::WellFormed);
+          assert(initialEquivClass && "Must have an equivalence class");
+        }
+
+        // Move back to the beginning; we may have opened up other rewrites.
+        startIndex = 0;
+        currentType = genericParamType;
+        equivClass = initialEquivClass;
+        continue;
+      }
+    }
+
+    // If we've hit the end of the path, we're done.
+    if (startIndex >= path.size()) break;
+
+    // FIXME: It would be nice if there were a better way to get the equivalence
+    // class of a named nested type.
+    currentType = DependentMemberType::get(currentType, path[startIndex++]);
+    equivClass =
+      resolveEquivalenceClass(currentType, ArchetypeResolutionKind::WellFormed);
+    // Stop simplifying once a prefix has no equivalence class.
+    if (!equivClass) break;
+  } while (true);
+
+  return formDependentType(genericParamType, path);
+}
+
 #pragma mark Equivalence classes
 EquivalenceClass::EquivalenceClass(PotentialArchetype *representative)
   : recursiveConcreteType(false), invalidConcreteType(false),
@@ -3865,6 +4544,9 @@
        PotentialArchetype *OrigT2,
        const RequirementSource *Source) 
 {
+  // Add a rewrite rule based on the given same-type constraint.
+  addSameTypeRewriteRule(OrigT1, OrigT2);
+
   // Record the same-type constraint, and bail out if it was already known.
   if (!OrigT1->getOrCreateEquivalenceClass(*this)
         ->recordSameTypeConstraint(OrigT1, OrigT2, Source))
@@ -3910,6 +4592,20 @@
                                    equivClass->sameTypeConstraints.end(),
                                    equivClass2->sameTypeConstraints.begin(),
                                    equivClass2->sameTypeConstraints.end());
+
+    // Combine the rewrite rules.
+    if (auto rewriteRoot2 = Impl->getOrCreateRewriteTreeRoot(equivClass2)) {
+      if (auto rewriteRoot1 = Impl->getOrCreateRewriteTreeRoot(equivClass)) {
+        // Merge the second rewrite tree into the first.
+        rewriteRoot2->mergeInto(rewriteRoot1);
+        Impl->RewriteTreeRoots.erase(equivClass2);
+        delete rewriteRoot2;
+      } else {
+        // Take the second rewrite tree and make it the first.
+        Impl->RewriteTreeRoots.erase(equivClass2);
+        (void)Impl->RewriteTreeRoots.insert({equivClass, rewriteRoot2});
+      }
+    }
   }
 
   // Same-type-to-concrete requirements.
@@ -5446,7 +6142,7 @@
       return lhs.constraint < rhs.constraint;
     }
 
-    LLVM_ATTRIBUTE_DEPRECATED(void dump() const,
+    LLVM_ATTRIBUTE_DEPRECATED(void dump() const LLVM_ATTRIBUTE_USED,
                               "only for use in the debugger");
   };
 }