Merge pull request #14454 from DougGregor/gsb-term-rewriting
[GSB] Term rewriting for same-type constraints
diff --git a/include/swift/AST/GenericSignatureBuilder.h b/include/swift/AST/GenericSignatureBuilder.h
index b9c0e8d..dff956a 100644
--- a/include/swift/AST/GenericSignatureBuilder.h
+++ b/include/swift/AST/GenericSignatureBuilder.h
@@ -258,11 +258,13 @@
Type getTypeInContext(GenericSignatureBuilder &builder,
GenericEnvironment *genericEnv);
- /// Dump a debugging representation of this equivalence class.
- void dump(llvm::raw_ostream &out) const;
+ /// Dump a debugging representation of this equivalence class.
+ ///
+ /// \param builder If non-null, the rewrite tree that \p builder has recorded
+ /// for this equivalence class (if any) is dumped as well.
+ void dump(llvm::raw_ostream &out,
+ GenericSignatureBuilder *builder = nullptr) const;
- LLVM_ATTRIBUTE_DEPRECATED(void dump() const,
- "only for use in the debugger");
+ /// Debugger convenience: dump to \c llvm::errs(), optionally including the
+ /// rewrite tree recorded in \p builder.
+ LLVM_ATTRIBUTE_DEPRECATED(
+ void dump(GenericSignatureBuilder *builder = nullptr) const,
+ "only for use in the debugger");
/// Caches.
@@ -271,9 +273,8 @@
/// The cached anchor itself.
Type anchor;
- /// The number of members of the equivalence class when the archetype
- /// anchor was cached.
- unsigned numMembers;
+ /// The generation at which the anchor was last computed.
+ unsigned lastGeneration;
} archetypeAnchorCache;
/// Describes a cached nested type.
@@ -444,6 +445,11 @@
/// Note that we have added the nested type nestedPA
void addedNestedType(PotentialArchetype *nestedPA);
+ /// Add a rewrite rule for a same-type constraint between the given
+ /// potential archetypes.
+ ///
+ /// No rule is recorded if the path from either potential archetype to its
+ /// generic parameter contains a member without a resolved associated type.
+ void addSameTypeRewriteRule(PotentialArchetype *type1,
+ PotentialArchetype *type2);
+
/// \brief Add a new conformance requirement specifying that the given
/// potential archetypes are equivalent.
ConstraintResult addSameTypeRequirementBetweenArchetypes(
@@ -802,6 +808,12 @@
/// Determine whether the two given types are in the same equivalence class.
bool areInSameEquivalenceClass(Type type1, Type type2);
+ /// Simplify the given dependent type down to its canonical representation
+ /// by repeatedly applying the recorded same-type rewrite rules.
+ ///
+ /// \returns null if the type involves dependent member types that
+ /// don't have associated types, if it is not rooted in a generic parameter,
+ /// or if the equivalence class of its generic parameter cannot be resolved.
+ Type getCanonicalTypeParameter(Type type);
+
/// Verify the correctness of the given generic signature.
///
/// This routine will test that the given generic signature is both minimal
diff --git a/lib/AST/GenericSignatureBuilder.cpp b/lib/AST/GenericSignatureBuilder.cpp
index f473456..d657f7c 100644
--- a/lib/AST/GenericSignatureBuilder.cpp
+++ b/lib/AST/GenericSignatureBuilder.cpp
@@ -107,6 +107,190 @@
"Delayed requirements left unresolved");
STATISTIC(NumConditionalRequirementsAdded,
"# of conditional requirements added");
+STATISTIC(NumComponentsCollapsedViaRewriting,
+ "# of same-type components collapsed via term rewriting");
+
+namespace {
+
+/// A purely-relative rewrite path consisting of a (possibly empty)
+/// sequence of associated type references.
+using RelativeRewritePath = ArrayRef<AssociatedTypeDecl *>;
+
+/// Describes a rewrite path, which contains an optional base (generic
+/// parameter) followed by a sequence of associated type references.
+class RewritePath {
+ Optional<GenericParamKey> base;
+ TinyPtrVector<AssociatedTypeDecl *> path;
+
+public:
+ RewritePath() { }
+
+ enum PathOrder {
+ Forward,
+ Reverse,
+ };
+
+ /// Form a rewrite path given an optional base and a relative rewrite path.
+ ///
+ /// \param order \c Forward if \p path is already ordered from the base
+ /// outward, \c Reverse if it must be reversed first.
+ RewritePath(Optional<GenericParamKey> base, RelativeRewritePath path,
+ PathOrder order);
+
+ /// Retrieve the base of the given rewrite path.
+ ///
+ /// When present, it indicates that the entire path will be rebased on
+ /// the given base generic parameter. This is required for describing
+ /// rewrites on type parameters themselves, e.g., T == U.
+ Optional<GenericParamKey> getBase() const { return base; }
+
+ /// Retrieve the sequence of associated type references that describes
+ /// the path.
+ ArrayRef<AssociatedTypeDecl *> getPath() const { return path; }
+
+ /// Decompose a type into a path.
+ ///
+ /// \returns the path, or None if it contained unresolved dependent member
+ /// types or is not rooted in a generic parameter.
+ static Optional<RewritePath> createPath(Type type);
+
+ /// Decompose a potential archetype into a path.
+ ///
+ /// \returns the path, or None if any potential archetype on the way to
+ /// the root lacks a resolved associated type.
+ static Optional<RewritePath> createPath(PotentialArchetype *pa);
+
+ /// Compute the common path between this path and \c other, if one exists.
+ ///
+ /// Both paths must have a base. Returns None when the bases differ;
+ /// otherwise returns the (possibly empty) longest common prefix.
+ Optional<RewritePath> commonPath(const RewritePath &other) const;
+
+ /// Form a canonical, dependent type.
+ ///
+ /// This requires that the rewrite path have a base.
+ CanType formDependentType(ASTContext &ctx) const;
+
+ /// Compare the given rewrite paths.
+ ///
+ /// \returns a negative value if this path is preferred over \c other, a
+ /// positive value if \c other is preferred, and zero if they are equal.
+ int compare(const RewritePath &other) const;
+
+ /// Print this path.
+ void print(llvm::raw_ostream &out) const;
+
+ LLVM_ATTRIBUTE_DEPRECATED(void dump() const LLVM_ATTRIBUTE_USED,
+ "only for use within the debugger") {
+ print(llvm::errs());
+ }
+
+ friend bool operator==(const RewritePath &lhs, const RewritePath &rhs) {
+ return lhs.getBase() == rhs.getBase() && lhs.getPath() == rhs.getPath();
+ }
+};
+
+/// A node within the prefix tree that is used to match associated type
+/// references.
+///
+/// Each node owns its child nodes and deletes them in its destructor, so
+/// nodes can be neither copied nor assigned.
+class RewriteTreeNode {
+ /// The associated type that leads to this node.
+ ///
+ /// The bit indicates whether there is a rewrite rule for this particular
+ /// node. If the bit is not set, \c rewrite is invalid.
+ llvm::PointerIntPair<AssociatedTypeDecl *, 1, bool> assocTypeAndHasRewrite;
+
+ /// The sequence of associated types to which a reference to this associated
+ /// type (from the equivalence class root) can be rewritten. This field is
+ /// only valid when the bit of \c assocTypeAndHasRewrite is set.
+ ///
+ /// Consider a requirement "Self.A.B.C == C". This will be encoded as
+ /// a prefix tree starting at the equivalence class for Self with
+ /// the following nodes:
+ ///
+ /// (assocType: A,
+ /// children: [
+ /// (assocType: B,
+ /// children: [
+ /// (assocType: C, rewrite: [C], children: [])
+ /// ])
+ /// ])
+ RewritePath rewrite;
+
+ /// The child nodes, which extend the sequence to be matched.
+ ///
+ /// The child nodes are sorted by the associated type declaration
+ /// pointers, with continuation nodes (null associated type) first, so we
+ /// can perform binary searches quickly.
+ llvm::TinyPtrVector<RewriteTreeNode *> children;
+
+public:
+ ~RewriteTreeNode();
+
+ RewriteTreeNode(AssociatedTypeDecl *assocType)
+ : assocTypeAndHasRewrite(assocType, false) { }
+
+ // The destructor deletes the children, so a copy would end up deleting
+ // them a second time.
+ RewriteTreeNode(const RewriteTreeNode &) = delete;
+ RewriteTreeNode &operator=(const RewriteTreeNode &) = delete;
+
+ /// Retrieve the associated type declaration one must match to use this
+ /// node, which is null for a tree root or for a continuation node that
+ /// holds an additional rewrite rule for its parent's position.
+ AssociatedTypeDecl *getMatch() const {
+ return assocTypeAndHasRewrite.getPointer();
+ }
+
+ /// Determine whether this particular node has a rewrite rule.
+ bool hasRewriteRule() const {
+ return assocTypeAndHasRewrite.getInt();
+ }
+
+ /// Set a new rewrite rule for this particular node. This can only be
+ /// performed once.
+ void setRewriteRule(RewritePath replacementPath) {
+ assert(!hasRewriteRule());
+ assocTypeAndHasRewrite.setInt(true);
+ rewrite = replacementPath;
+ }
+
+ /// Retrieve the path to which this node will be rewritten.
+ const RewritePath &getRewriteRule() const {
+ assert(hasRewriteRule());
+ return rewrite;
+ }
+
+ /// Add a new rewrite rule to this tree node.
+ ///
+ /// \param matchPath The path of associated type declarations that must
+ /// be matched to produce a rewrite.
+ ///
+ /// \param replacementPath The sequence of associated type declarations
+ /// with which a match will be replaced.
+ void addRewriteRule(RelativeRewritePath matchPath,
+ RewritePath replacementPath);
+
+ /// Enumerate all of the paths to which the given matched path can be
+ /// rewritten.
+ ///
+ /// \param matchPath The path to match.
+ ///
+ /// \param callback A callback that will be invoked with (prefix, rewrite)
+ /// pairs, where \c prefix is the length of the matching prefix of
+ /// \c matchPath that matched and \c rewrite is the path to which it can
+ /// be rewritten.
+ ///
+ /// \param depth The number of elements of \c matchPath already matched
+ /// on the way to this node.
+ void enumerateRewritePaths(
+ RelativeRewritePath matchPath,
+ llvm::function_ref<void(unsigned, RewritePath)> callback,
+ unsigned depth = 0) const;
+
+ /// Find the best rewrite rule to match the given path.
+ ///
+ /// \param base The generic parameter at the root of the full path.
+ /// \param path The path to match.
+ /// \param prefixLength The length of the prefix leading up to \c path.
+ /// \returns the matched length within \c path and its replacement, or
+ /// None if no rule applies.
+ Optional<std::pair<unsigned, RewritePath>>
+ bestMatch(GenericParamKey base, RelativeRewritePath path,
+ unsigned prefixLength);
+
+ /// Merge the rewrite rules of this tree into \c other.
+ void mergeInto(RewriteTreeNode *other);
+
+ LLVM_ATTRIBUTE_DEPRECATED(void dump() const LLVM_ATTRIBUTE_USED,
+ "only for use within the debugger");
+
+ /// Dump the tree.
+ void dump(llvm::raw_ostream &out, bool lastChild = true) const;
+
+private:
+ /// Recursive helper for \c mergeInto; \p matchPath holds the associated
+ /// types matched on the way from the root down to this node.
+ void mergeIntoRec(RewriteTreeNode *other,
+ llvm::SmallVectorImpl<AssociatedTypeDecl *> &matchPath);
+};
+}
struct GenericSignatureBuilder::Implementation {
/// Allocator.
@@ -131,6 +315,9 @@
/// Equivalence classes that are not currently being used.
std::vector<void *> FreeEquivalenceClasses;
+ /// The roots of the rewrite tree, keyed by equivalence class.
+ ///
+ /// The nodes are heap-allocated and owned by this map: they are deleted in
+ /// ~Implementation, or earlier when one class's tree is merged into another.
+ DenseMap<const EquivalenceClass *, RewriteTreeNode *> RewriteTreeRoots;
+
/// The generation number, which is incremented whenever we successfully
/// introduce a new constraint.
unsigned Generation = 0;
@@ -141,16 +328,6 @@
/// Whether we are currently processing delayed requirements.
bool ProcessingDelayedRequirements = false;
- /// Tear down an implementation.
- ~Implementation();
-
- /// Allocate a new equivalence class with the given representative.
- EquivalenceClass *allocateEquivalenceClass(
- PotentialArchetype *representative);
-
- /// Deallocate the given equivalence class, returning it to the free list.
- void deallocateEquivalenceClass(EquivalenceClass *equivClass);
-
/// Whether there were any errors.
bool HadAnyError = false;
@@ -161,12 +338,35 @@
/// Whether we've already finalized the builder.
bool finalized = false;
#endif
+
+ /// Tear down an implementation.
+ ~Implementation();
+
+ /// Allocate a new equivalence class with the given representative.
+ EquivalenceClass *allocateEquivalenceClass(
+ PotentialArchetype *representative);
+
+ /// Deallocate the given equivalence class, returning it to the free list.
+ void deallocateEquivalenceClass(EquivalenceClass *equivClass);
+
+ /// Retrieve the rewrite tree root for the given equivalence class,
+ /// if present.
+ RewriteTreeNode *getRewriteTreeRootIfPresent(
+ const EquivalenceClass *equivClass);
+
+ /// Retrieve the rewrite tree root for the given equivalence class,
+ /// creating it if needed.
+ RewriteTreeNode *getOrCreateRewriteTreeRoot(
+ const EquivalenceClass *equivClass);
};
#pragma mark Memory management
GenericSignatureBuilder::Implementation::~Implementation() {
for (auto pa : PotentialArchetypes)
pa->~PotentialArchetype();
+
+ // Rewrite tree roots are allocated with 'new' on first request; deleting a
+ // root also deletes all of its descendant nodes.
+ for (const auto &root : RewriteTreeRoots)
+ delete root.second;
}
EquivalenceClass *
@@ -1915,31 +2115,13 @@
return populateResult((nestedTypeNameCache[name] = std::move(entry)));
}
-/// Determine whether any part of this potential archetype's path to the
-/// root contains the given equivalence class.
-static bool pathContainsEquivalenceClass(GenericSignatureBuilder &builder,
- PotentialArchetype *pa,
- EquivalenceClass *equivClass) {
- // Chase the potential archetype up to the root.
- for (; pa; pa = pa->getParent()) {
- // Check whether this potential archetype is in the given equivalence
- // class.
- if (pa->getOrCreateEquivalenceClass(builder) == equivClass)
- return true;
- }
-
- return false;
-}
-
Type EquivalenceClass::getAnchor(
GenericSignatureBuilder &builder,
TypeArrayView<GenericTypeParamType> genericParams) {
- // Check whether the cache is valid.
- if (archetypeAnchorCache.anchor &&
- archetypeAnchorCache.numMembers == members.size()) {
- ++NumArchetypeAnchorCacheHits;
+ // Substitute into the anchor with the given generic parameters.
+ auto substAnchor = [&] {
+ if (genericParams.empty()) return archetypeAnchorCache.anchor;
- // Reparent the anchor using genericParams.
return archetypeAnchorCache.anchor.subst(
[&](SubstitutableType *dependentType) {
if (auto gp = dyn_cast<GenericTypeParamType>(dependentType)) {
@@ -1951,73 +2133,43 @@
return Type(dependentType);
},
MakeAbstractConformanceForGenericType());
+
+ };
+
+ // Check whether the cache is valid.
+ if (archetypeAnchorCache.anchor &&
+ archetypeAnchorCache.lastGeneration == builder.Impl->Generation) {
+ ++NumArchetypeAnchorCacheHits;
+ return substAnchor();
}
- // Map the members of this equivalence class to the best associated type
- // within that equivalence class.
- llvm::SmallDenseMap<EquivalenceClass *, AssociatedTypeDecl *> nestedTypes;
+ // Check whether we already have an anchor, in which case we
+ // can simplify it further.
+ if (archetypeAnchorCache.anchor) {
+ // Record the cache miss.
+ ++NumArchetypeAnchorCacheMisses;
- Type bestGenericParam;
+ // Update the anchor by simplifying it further.
+ archetypeAnchorCache.anchor =
+ builder.getCanonicalTypeParameter(archetypeAnchorCache.anchor);
+ archetypeAnchorCache.lastGeneration = builder.Impl->Generation;
+ return substAnchor();
+ }
+
+ // Form the anchor.
for (auto member : members) {
- // If the member is a generic parameter, keep the best generic parameter.
- if (member->isGenericParam()) {
- Type genericParamType = member->getDependentType(genericParams);
- if (!bestGenericParam ||
- compareDependentTypes(genericParamType, bestGenericParam) < 0)
- bestGenericParam = genericParamType;
- continue;
- }
+ auto anchorType =
+ builder.getCanonicalTypeParameter(member->getDependentType(genericParams));
+ if (!anchorType) continue;
- // If we saw a generic parameter, ignore any nested types.
- if (bestGenericParam) continue;
-
- // If the nested type doesn't have an associated type, skip it.
- auto assocType = member->getResolvedAssociatedType();
- if (!assocType) continue;
-
- // Dig out the equivalence class of the parent.
- auto parentEquivClass =
- member->getParent()->getOrCreateEquivalenceClass(builder);
-
- // If the path from this member to the root contains this equivalence
- // class, it cannot be part of the anchor.
- if (pathContainsEquivalenceClass(builder, member->getParent(), this))
- continue;
-
- // Take the best associated type for this equivalence class.
- assocType = assocType->getAssociatedTypeAnchor();
- auto &bestAssocType = nestedTypes[parentEquivClass];
- if (!bestAssocType ||
- compareAssociatedTypes(assocType, bestAssocType) < 0)
- bestAssocType = assocType;
+ // Record the cache miss and update the cache.
+ ++NumArchetypeAnchorCacheMisses;
+ archetypeAnchorCache.anchor = anchorType;
+ archetypeAnchorCache.lastGeneration = builder.Impl->Generation;
+ return substAnchor();
}
- // If we found a generic parameter, return that.
- if (bestGenericParam)
- return bestGenericParam;
-
- // Determine the best anchor among the parent equivalence classes.
- Type bestParentAnchor;
- AssociatedTypeDecl *bestAssocType = nullptr;
- std::pair<EquivalenceClass *, Identifier> bestNestedType;
- for (const auto &nestedType : nestedTypes) {
- auto parentAnchor = nestedType.first->getAnchor(builder, genericParams);
- if (!bestParentAnchor ||
- compareDependentTypes(parentAnchor, bestParentAnchor) < 0) {
- bestParentAnchor = parentAnchor;
- bestAssocType = nestedType.second;
- }
- }
-
- // Form the anchor type.
- Type anchorType = DependentMemberType::get(bestParentAnchor, bestAssocType);
-
- // Record the cache miss and update the cache.
- ++NumArchetypeAnchorCacheMisses;
- archetypeAnchorCache.anchor = anchorType;
- archetypeAnchorCache.numMembers = members.size();
-
- return anchorType;
+ llvm_unreachable("Unable to compute anchor");
}
Type EquivalenceClass::getTypeInContext(GenericSignatureBuilder &builder,
@@ -2138,7 +2290,8 @@
return archetype;
}
-void EquivalenceClass::dump(llvm::raw_ostream &out) const {
+void EquivalenceClass::dump(llvm::raw_ostream &out,
+ GenericSignatureBuilder *builder) const {
out << "Equivalence class represented by "
<< members.front()->getRepresentative()->getDebugName() << ":\n";
out << "Members: ";
@@ -2183,6 +2336,13 @@
out << "\n";
+ if (builder) {
+ if (auto rewriteRoot = builder->Impl->getRewriteTreeRootIfPresent(this)) {
+ out << "---Rewrite tree---\n";
+ rewriteRoot->dump(out);
+ }
+ }
+
{
out << "---GraphViz output for same-type constraints---\n";
@@ -2215,8 +2375,8 @@
}
}
-void EquivalenceClass::dump() const {
- dump(llvm::errs());
+/// Debugger entry point: dump to llvm::errs(), forwarding \p builder so that
+/// the rewrite tree for this class can be included when it is non-null.
+void EquivalenceClass::dump(GenericSignatureBuilder *builder) const {
+ dump(llvm::errs(), builder);
}
void DelayedRequirement::dump(llvm::raw_ostream &out) const {
@@ -2523,6 +2683,24 @@
return 0;
}
+/// Compare two dependent paths to determine which is better.
+///
+/// \returns a negative value if \p path1 is better, a positive value if
+/// \p path2 is better, and zero if the paths are identical. Shorter paths are
+/// better; equal-length paths are ordered by their first differing
+/// associated type.
+static int compareDependentPaths(ArrayRef<AssociatedTypeDecl *> path1,
+ ArrayRef<AssociatedTypeDecl *> path2) {
+ if (path1.size() < path2.size()) return -1;
+ if (path1.size() > path2.size()) return 1;
+
+ // Equal lengths: the first pair of associated types that differ decides.
+ for (unsigned i = 0, n = path1.size(); i != n; ++i) {
+ int order = compareAssociatedTypes(path1[i], path2[i]);
+ if (order != 0) return order;
+ }
+
+ // The paths are identical.
+ return 0;
+}
+
namespace {
/// Function object used to suppress conflict diagnoses when we know we'll
/// see them again later.
@@ -2883,6 +3061,507 @@
}
}
+#pragma mark Rewrite tree
+RewritePath::RewritePath(Optional<GenericParamKey> base,
+ RelativeRewritePath path,
+ PathOrder order)
+ : base(base)
+{
+ // Copy the associated types, reversing them when the caller collected them
+ // from the outermost member inward.
+ auto &storage = this->path;
+ if (order == Reverse)
+ storage.insert(storage.begin(), path.rbegin(), path.rend());
+ else
+ storage.insert(storage.begin(), path.begin(), path.end());
+}
+
+Optional<RewritePath> RewritePath::createPath(PotentialArchetype *pa) {
+ // Walk from the potential archetype up to its root generic parameter,
+ // collecting each step's associated type (innermost first).
+ SmallVector<AssociatedTypeDecl *, 4> reversedPath;
+ for (; pa->getParent(); pa = pa->getParent()) {
+ auto assocType = pa->getResolvedAssociatedType();
+ if (!assocType) return None;
+ reversedPath.push_back(assocType);
+ }
+
+ // 'pa' is now the root generic parameter.
+ return RewritePath(pa->getGenericParamKey(), reversedPath, Reverse);
+}
+
+Optional<RewritePath> RewritePath::createPath(Type type) {
+ // Peel dependent member types off from the outside in, recording the
+ // associated type of each.
+ SmallVector<AssociatedTypeDecl *, 4> reversedPath;
+ for (auto depMemTy = type->getAs<DependentMemberType>(); depMemTy;
+ depMemTy = type->getAs<DependentMemberType>()) {
+ auto assocType = depMemTy->getAssocType();
+ if (!assocType) return None;
+
+ reversedPath.push_back(assocType);
+ type = depMemTy->getBase();
+ }
+
+ // Whatever remains must be a generic parameter.
+ if (auto genericParam = type->getAs<GenericTypeParamType>())
+ return RewritePath(GenericParamKey(genericParam), reversedPath, Reverse);
+
+ return None;
+}
+
+Optional<RewritePath> RewritePath::commonPath(const RewritePath &other) const {
+ assert(getBase().hasValue() && other.getBase().hasValue());
+
+ // Paths rooted at different generic parameters have nothing in common.
+ if (*getBase() != *other.getBase()) return None;
+
+ // Walk the shorter path against the longer one to find the longest common
+ // prefix.
+ RelativeRewritePath shorter = getPath();
+ RelativeRewritePath longer = other.getPath();
+ if (shorter.size() > longer.size())
+ std::swap(shorter, longer);
+
+ unsigned prefixLength = 0;
+ while (prefixLength != shorter.size() &&
+ shorter[prefixLength] == longer[prefixLength])
+ ++prefixLength;
+
+ // Form the common path.
+ return RewritePath(getBase(), shorter.slice(0, prefixLength), Forward);
+}
+
+/// Form a dependent type with the given generic parameter, then following the
+/// path of associated types.
+static Type formDependentType(GenericTypeParamType *base,
+ RelativeRewritePath path) {
+ Type result = base;
+ for (auto assocType : path)
+ result = DependentMemberType::get(result, assocType);
+ return result;
+}
+
+/// Form a dependent type with the (canonical) generic parameter for the given
+/// parameter key, then following the path of associated types.
+static Type formDependentType(ASTContext &ctx, GenericParamKey genericParam,
+ RelativeRewritePath path) {
+ auto baseType =
+ GenericTypeParamType::get(genericParam.Depth, genericParam.Index, ctx);
+ return formDependentType(baseType, path);
+}
+
+CanType RewritePath::formDependentType(ASTContext &ctx) const {
+ assert(getBase());
+ Type dependentType = ::formDependentType(ctx, *getBase(), getPath());
+ return CanType(dependentType);
+}
+
+int RewritePath::compare(const RewritePath &other) const {
+ const bool hasBase = getBase().hasValue();
+ const bool otherHasBase = other.getBase().hasValue();
+
+ // Relative paths (no base) are preferred over absolute ones.
+ if (hasBase != otherHasBase)
+ return hasBase ? 1 : -1;
+
+ // Two absolute paths are ordered first by their generic parameters.
+ if (hasBase) {
+ GenericParamKey myBase = *getBase();
+ GenericParamKey theirBase = *other.getBase();
+ if (myBase != theirBase)
+ return myBase < theirBase ? -1 : 1;
+ }
+
+ // Otherwise, order by the associated types along the path.
+ return compareDependentPaths(getPath(), other.getPath());
+}
+
+void RewritePath::print(llvm::raw_ostream &out) const {
+ out << "[";
+
+ // Absolute paths start with their (depth, index) generic parameter.
+ if (auto pathBase = getBase()) {
+ out << "(" << pathBase->Depth << ", " << pathBase->Index << ")";
+ if (!getPath().empty()) out << " -> ";
+ }
+
+ // Then each associated type as Protocol.Name, separated by arrows.
+ bool isFirst = true;
+ for (auto assocType : getPath()) {
+ if (!isFirst) out << " -> ";
+ isFirst = false;
+
+ out.changeColor(raw_ostream::BLUE);
+ out << assocType->getProtocol()->getName() << "."
+ << assocType->getName();
+ out.resetColor();
+ }
+ out << "]";
+}
+
+RewriteTreeNode::~RewriteTreeNode() {
+ // Each node owns its children (and, recursively, their subtrees).
+ for (RewriteTreeNode *childNode : children)
+ delete childNode;
+}
+
+namespace {
+/// Function object used to order rewrite tree nodes based on the address
+/// of the associated type, with null (continuation) entries first.
+class OrderTreeRewriteNode {
+ static bool less(AssociatedTypeDecl *lhs, AssociatedTypeDecl *rhs) {
+ // A null pointer sorts before every non-null pointer.
+ if (!lhs || !rhs)
+ return !lhs && rhs;
+
+ // std::less gives a well-defined total order over pointers.
+ return std::less<AssociatedTypeDecl *>()(lhs, rhs);
+ }
+
+public:
+ bool operator()(RewriteTreeNode *lhs, RewriteTreeNode *rhs) const {
+ return less(lhs->getMatch(), rhs->getMatch());
+ }
+
+ bool operator()(RewriteTreeNode *lhs, AssociatedTypeDecl *rhs) const {
+ return less(lhs->getMatch(), rhs);
+ }
+
+ bool operator()(AssociatedTypeDecl *lhs, RewriteTreeNode *rhs) const {
+ return less(lhs, rhs->getMatch());
+ }
+};
+}
+
+/// Record a rule that rewrites \p matchPath (relative to this node) to
+/// \p replacementPath.
+///
+/// If the node reached by \p matchPath already holds a different rule, the
+/// new rule is stored in a continuation child (a child with a null
+/// associated type). Continuation children are kept at the front of
+/// \c children.
+void RewriteTreeNode::addRewriteRule(RelativeRewritePath matchPath,
+ RewritePath replacementPath) {
+ // If the match path is empty, we're adding the rewrite rule to this node.
+ if (matchPath.empty()) {
+ // If we don't already have a rewrite rule, add it.
+ if (!hasRewriteRule()) {
+ setRewriteRule(replacementPath);
+ return;
+ }
+
+ // If we already have this rewrite rule, we're done.
+ if (getRewriteRule() == replacementPath) return;
+
+ // Check whether any of the continuation children matches.
+ auto insertPos = children.begin();
+ while (insertPos != children.end() && !(*insertPos)->getMatch()) {
+ if ((*insertPos)->hasRewriteRule() &&
+ (*insertPos)->getRewriteRule() == replacementPath)
+ return;
+
+ ++insertPos;
+ }
+
+ // We already have a rewrite rule, so add a new child with a
+ // null associated type match to hold the rewrite rule.
+ // Inserting after the existing continuation children keeps all null
+ // matches ahead of the non-null ones, as OrderTreeRewriteNode requires.
+ auto newChild = new RewriteTreeNode(nullptr);
+ newChild->setRewriteRule(replacementPath);
+ children.insert(insertPos, newChild);
+ return;
+ }
+
+ // Find (or create) a child node describing the next step in the match.
+ auto matchFront = matchPath.front();
+ auto childPos =
+ std::lower_bound(children.begin(), children.end(), matchFront,
+ OrderTreeRewriteNode());
+ if (childPos == children.end() || (*childPos)->getMatch() != matchFront) {
+ // Insert at the lower bound so the children stay sorted.
+ childPos = children.insert(childPos, new RewriteTreeNode(matchFront));
+ }
+
+ // Add the rewrite rule to the child.
+ (*childPos)->addRewriteRule(matchPath.slice(1), replacementPath);
+}
+
+/// Report, via \p callback, every rewrite rule that matches a prefix of
+/// \p matchPath beginning at index \p depth.
+///
+/// Rules reached through the child that matches matchPath[depth] are reported
+/// first. They are followed by the rule stored at this node (if any), and
+/// then by the rules held in this node's continuation children.
+void RewriteTreeNode::enumerateRewritePaths(
+ RelativeRewritePath matchPath,
+ llvm::function_ref<void(unsigned, RewritePath)> callback,
+ unsigned depth) const {
+ // Determine whether we know anything about the next step in the path.
+ // Once the whole path has been consumed, childPos is end(), so
+ // matchPath[depth] below is never evaluated out of range.
+ auto childPos =
+ depth < matchPath.size()
+ ? std::lower_bound(children.begin(), children.end(),
+ matchPath[depth], OrderTreeRewriteNode())
+ : children.end();
+ if (childPos != children.end() &&
+ (*childPos)->getMatch() == matchPath[depth]) {
+ // Try to match the rest of the path.
+ (*childPos)->enumerateRewritePaths(matchPath, callback, depth + 1);
+ }
+
+ // If we have a rewrite rule at this position, invoke it.
+ if (hasRewriteRule()) {
+ // Invoke the callback with the first result.
+ callback(depth, rewrite);
+ }
+
+ // Walk any children with NULL associated types; they might have more matches.
+ // Such continuation children sort first, so stop at the first non-null one.
+ // They consume no path element, so they are visited at the same depth.
+ for (auto otherRewrite : children) {
+ if (otherRewrite->getMatch()) break;
+ otherRewrite->enumerateRewritePaths(matchPath, callback, depth);
+ }
+}
+
+/// Find the best rewrite rule to apply to \p path.
+///
+/// Each candidate has an "adjusted length": the number of path elements it
+/// replaces. This is the matched length, plus \p prefixLength for absolute
+/// rules, since those replace everything from the generic parameter onward.
+///
+/// Two kinds of candidate are ignored:
+/// - candidates that replace nothing;
+/// - absolute rules to the same base \p base that would not shorten the path.
+///
+/// The candidate with the largest adjusted length wins. Ties go to the
+/// replacement preferred by RewritePath::compare.
+///
+/// \returns the matched length within \p path and the replacement path, or
+/// None if no rule applies.
+Optional<std::pair<unsigned, RewritePath>>
+RewriteTreeNode::bestMatch(GenericParamKey base, RelativeRewritePath path,
+ unsigned prefixLength) {
+ Optional<std::pair<unsigned, RewritePath>> best;
+ unsigned bestAdjustedLength = 0;
+ // Note: inside the callback, 'path' is the candidate replacement path and
+ // shadows the parameter of the same name.
+ enumerateRewritePaths(path,
+ [&](unsigned length, RewritePath path) {
+ // Determine how much of the original path will be replaced by the rewrite.
+ unsigned adjustedLength = length;
+ if (auto newBase = path.getBase()) {
+ adjustedLength += prefixLength;
+
+ // If the base is unchanged, make sure we're reducing the length.
+ if (*newBase == base && adjustedLength <= path.getPath().size())
+ return;
+ }
+
+ if (adjustedLength == 0) return;
+
+ // 'best' is always set when the lengths tie here: adjustedLength is
+ // non-zero, and bestAdjustedLength only becomes non-zero once a
+ // candidate has been recorded.
+ if (adjustedLength > bestAdjustedLength ||
+ (adjustedLength == bestAdjustedLength &&
+ path.compare(best->second) < 0)) {
+ best = { length, path };
+ bestAdjustedLength = adjustedLength;
+ }
+ });
+
+ return best;
+}
+
+void RewriteTreeNode::mergeInto(RewriteTreeNode *other) {
+ SmallVector<AssociatedTypeDecl *, 4> matchPath;
+ mergeIntoRec(other, matchPath);
+}
+
+/// Re-add every rewrite rule in the subtree rooted at this node to \p other.
+///
+/// \param matchPath On entry, the associated types matched on the way from
+/// the root down to (but not including) this node. It is restored to that
+/// state before returning.
+void RewriteTreeNode::mergeIntoRec(
+ RewriteTreeNode *other,
+ llvm::SmallVectorImpl<AssociatedTypeDecl *> &matchPath) {
+ // FIXME: A destructive version of this operation would be more efficient,
+ // since we generally don't care about this (source) tree after doing this.
+
+ // The root and continuation nodes have no associated type, so they don't
+ // extend the match path.
+ const bool extendsPath = getMatch() != nullptr;
+ if (extendsPath)
+ matchPath.push_back(getMatch());
+
+ // Add this rewrite rule, if there is one.
+ if (hasRewriteRule())
+ other->addRewriteRule(matchPath, rewrite);
+
+ // Recurse into the child nodes.
+ for (auto child : children)
+ child->mergeIntoRec(other, matchPath);
+
+ if (extendsPath)
+ matchPath.pop_back();
+}
+
+/// Dump this node and its subtree to llvm::errs().
+void RewriteTreeNode::dump() const {
+ dump(llvm::errs());
+}
+
+/// Dump this node and its subtree, one node per line, with ASCII connectors
+/// showing the tree structure.
+///
+/// Nodes with a null associated type (the root and continuation nodes) print
+/// as "(cont'd)". Nodes that hold a rule print " --> " followed by the rule.
+///
+/// \param lastChild Whether this node is drawn as the last child of its
+/// parent. This controls the connector column printed for its children.
+void RewriteTreeNode::dump(llvm::raw_ostream &out, bool lastChild) const {
+ // Connector columns contributed by the ancestors of the node being printed.
+ std::string prefixStr;
+
+ // A std::function (rather than a plain lambda) so that the printer can
+ // call itself recursively.
+ std::function<void(const RewriteTreeNode *, bool lastChild)> print;
+ print = [&](const RewriteTreeNode *node, bool lastChild) {
+ out << prefixStr << " `--";
+
+ // Print the node name.
+ out.changeColor(raw_ostream::GREEN);
+ if (auto assoc = node->getMatch())
+ out << assoc->getProtocol()->getName() << "." << assoc->getName();
+ else
+ out << "(cont'd)";
+ out.resetColor();
+
+ // Print the rewrite, if there is one.
+ if (node->hasRewriteRule()) {
+ out << " --> ";
+ node->rewrite.print(out);
+ }
+
+ out << "\n";
+
+ // Print children.
+ // Extend the prefix for the children. A '|' continues this node's column
+ // unless this node is its parent's last child.
+ prefixStr += ' ';
+ prefixStr += (lastChild ? ' ' : '|');
+ prefixStr += " ";
+
+ for (auto child : node->children) {
+ print(child, child == node->children.back());
+ }
+
+ // Remove this level's segment again before returning to the siblings.
+ prefixStr.erase(prefixStr.end() - 4, prefixStr.end());
+ };
+
+ print(this, lastChild);
+}
+
+RewriteTreeNode *
+GenericSignatureBuilder::Implementation::getRewriteTreeRootIfPresent(
+ const EquivalenceClass *equivClass) {
+ // DenseMap::lookup yields a default-constructed (null) value when the key
+ // is absent.
+ return RewriteTreeRoots.lookup(equivClass);
+}
+
+RewriteTreeNode *
+GenericSignatureBuilder::Implementation::getOrCreateRewriteTreeRoot(
+ const EquivalenceClass *equivClass) {
+ // Reserve the slot; if an entry already existed, hand back its root.
+ auto inserted = RewriteTreeRoots.insert({equivClass, nullptr});
+ if (!inserted.second) return inserted.first->second;
+
+ // Lazily allocate an empty root (one with no associated type to match).
+ auto root = new RewriteTreeNode(nullptr);
+ inserted.first->second = root;
+ return root;
+}
+
+/// Record a rewrite rule derived from the same-type constraint pa1 == pa2.
+///
+/// If both sides are rooted in the same generic parameter (possibly only
+/// after simplifying them with the existing rules), a relative rule is added.
+/// It lives in the tree of the common prefix's equivalence class and rewrites
+/// the worse remaining suffix to the better one.
+///
+/// Otherwise, an absolute rule is added to the tree of the worse path's
+/// generic parameter, rewriting that whole path to the better one.
+///
+/// Nothing is recorded in two cases:
+/// - either side has a member without a resolved associated type;
+/// - both sides simplify to the same type.
+void GenericSignatureBuilder::addSameTypeRewriteRule(PotentialArchetype *pa1,
+ PotentialArchetype *pa2){
+ auto pathOpt1 = RewritePath::createPath(pa1);
+ if (!pathOpt1) return;
+
+ auto pathOpt2 = RewritePath::createPath(pa2);
+ if (!pathOpt2) return;
+
+ auto path1 = std::move(pathOpt1).getValue();
+ auto path2 = std::move(pathOpt2).getValue();
+
+ // Look for a common path.
+ // (commonPath only fails when the two generic parameters differ.)
+ auto prefix = path1.commonPath(path2);
+
+ // If we didn't find a common path, try harder.
+ Type simplifiedType1;
+ Type simplifiedType2;
+ if (!prefix) {
+ // Simplify both sides in the hope of uncovering a common path.
+ simplifiedType1 = getCanonicalTypeParameter(pa1->getDependentType({ }));
+ simplifiedType2 = getCanonicalTypeParameter(pa2->getDependentType({ }));
+ if (simplifiedType1->isEqual(simplifiedType2)) return;
+
+ // Create new paths from the simplified types.
+ path1 = *RewritePath::createPath(simplifiedType1);
+ path2 = *RewritePath::createPath(simplifiedType2);
+
+ // Find a common path.
+ prefix = path1.commonPath(path2);
+ }
+
+ // When we have a common prefix, form a rewrite rule using relative paths.
+ if (prefix) {
+ // Find the better relative rewrite path.
+ RelativeRewritePath relPath1
+ = path1.getPath().slice(prefix->getPath().size());
+ RelativeRewritePath relPath2
+ = path2.getPath().slice(prefix->getPath().size());
+ // Order the paths so that we go to the more-canonical path.
+ // A negative result means relPath1 is better. After the swap, relPath1
+ // is the path to match and relPath2 is its replacement.
+ if (compareDependentPaths(relPath1, relPath2) < 0)
+ std::swap(relPath1, relPath2);
+
+ // Find the equivalence class for the prefix.
+ CanType commonType = prefix->formDependentType(getASTContext());
+ auto equivClass =
+ resolveEquivalenceClass(commonType, ArchetypeResolutionKind::WellFormed);
+ assert(equivClass && "Prefix cannot be resolved?");
+
+ // Add the rewrite rule.
+ auto root = Impl->getOrCreateRewriteTreeRoot(equivClass);
+ root->addRewriteRule(relPath1,
+ RewritePath(None, relPath2, RewritePath::Forward));
+
+ return;
+ }
+
+ // Otherwise, form a rewrite rule with absolute paths.
+ // The simplified types are always set here, because this point is only
+ // reached when the initial commonPath attempt failed.
+
+ // Find the better path and make sure it's in path2.
+ // A negative result means simplifiedType1 is the better type.
+ if (compareDependentTypes(simplifiedType1, simplifiedType2) < 0) {
+ std::swap(path1, path2);
+ std::swap(simplifiedType1, simplifiedType2);
+ }
+
+ // Add the rewrite rule.
+ Type firstBase =
+ GenericTypeParamType::get(path1.getBase()->Depth, path1.getBase()->Index,
+ getASTContext());
+ auto equivClass =
+ resolveEquivalenceClass(firstBase, ArchetypeResolutionKind::WellFormed);
+ assert(equivClass && "Base cannot be resolved?");
+
+ auto root = Impl->getOrCreateRewriteTreeRoot(equivClass);
+ root->addRewriteRule(path1.getPath(), path2);
+}
+
+/// Simplify \p type to its canonical form using the recorded rewrite rules.
+///
+/// The type is decomposed into a generic parameter and a path of associated
+/// types, and the path is walked one associated type at a time. At each step,
+/// the best rule from the current equivalence class's rewrite tree (if any)
+/// is applied. After a rewrite, the walk restarts from the generic parameter,
+/// because the rewrite may enable further rewrites.
+///
+/// \returns the simplified type, or null if either:
+/// - \p type is not a generic parameter followed by resolved associated
+///   types; or
+/// - the equivalence class of that generic parameter cannot be resolved.
+Type GenericSignatureBuilder::getCanonicalTypeParameter(Type type) {
+ auto initialPath = RewritePath::createPath(type);
+ if (!initialPath) return nullptr;
+
+ auto genericParamType =
+ GenericTypeParamType::get(initialPath->getBase()->Depth,
+ initialPath->getBase()->Index,
+ getASTContext());
+ auto initialEquivClass =
+ resolveEquivalenceClass(genericParamType,
+ ArchetypeResolutionKind::WellFormed);
+ if (!initialEquivClass) return nullptr;
+
+ // Walk state: the first startIndex elements of 'path' have been walked.
+ // currentType is the type they form, and equivClass is its equivalence class.
+ unsigned startIndex = 0;
+ auto equivClass = initialEquivClass;
+ Type currentType = genericParamType;
+ SmallVector<AssociatedTypeDecl *, 4> path(initialPath->getPath().begin(),
+ initialPath->getPath().end());
+ do {
+ if (auto rootNode = Impl->getRewriteTreeRootIfPresent(equivClass)) {
+ // Find the best rewrite rule for the path starting at startIndex.
+ auto match =
+ rootNode->bestMatch(genericParamType,
+ llvm::makeArrayRef(path).slice(startIndex),
+ startIndex);
+
+ // If we have a match, replace the matched path with the replacement
+ // path.
+ if (match) {
+ // Determine the range in the path which we'll be replacing. An
+ // absolute replacement (one with a base) replaces everything from the
+ // generic parameter onward.
+ unsigned replaceStartIndex = match->second.getBase() ? 0 : startIndex;
+ unsigned replaceEndIndex = startIndex + match->first;
+
+ // Overwrite the beginning of the match.
+ auto replacementPath = match->second.getPath();
+ assert((replaceEndIndex - replaceStartIndex) >= replacementPath.size());
+ auto replacementStartPos = path.begin() + replaceStartIndex;
+ std::copy(replacementPath.begin(), replacementPath.end(),
+ replacementStartPos);
+
+ // Erase the rest.
+ path.erase(replacementStartPos + replacementPath.size(),
+ path.begin() + replaceEndIndex);
+
+ // If this is an absolute path, use the new base.
+ if (auto newBase = match->second.getBase()) {
+ genericParamType =
+ GenericTypeParamType::get(newBase->Depth, newBase->Index,
+ getASTContext());
+ initialEquivClass =
+ resolveEquivalenceClass(genericParamType,
+ ArchetypeResolutionKind::WellFormed);
+ assert(initialEquivClass && "Must have an equivalence class");
+ }
+
+ // Move back to the beginning; we may have opened up other rewrites.
+ startIndex = 0;
+ currentType = genericParamType;
+ equivClass = initialEquivClass;
+ continue;
+ }
+ }
+
+ // If we've hit the end of the path, we're done.
+ if (startIndex >= path.size()) break;
+
+ // FIXME: It would be nice if there were a better way to get the equivalence
+ // class of a named nested type.
+ currentType = DependentMemberType::get(currentType, path[startIndex++]);
+ equivClass =
+ resolveEquivalenceClass(currentType, ArchetypeResolutionKind::WellFormed);
+ // If the nested type can't be resolved, stop simplifying here.
+ if (!equivClass) break;
+ } while (true);
+
+ return formDependentType(genericParamType, path);
+}
+
#pragma mark Equivalence classes
EquivalenceClass::EquivalenceClass(PotentialArchetype *representative)
: recursiveConcreteType(false), invalidConcreteType(false),
@@ -3865,6 +4544,9 @@
PotentialArchetype *OrigT2,
const RequirementSource *Source)
{
+ // Add a rewrite rule based on the given same-type constraint.
+ addSameTypeRewriteRule(OrigT1, OrigT2);
+
// Record the same-type constraint, and bail out if it was already known.
if (!OrigT1->getOrCreateEquivalenceClass(*this)
->recordSameTypeConstraint(OrigT1, OrigT2, Source))
@@ -3910,6 +4592,20 @@
equivClass->sameTypeConstraints.end(),
equivClass2->sameTypeConstraints.begin(),
equivClass2->sameTypeConstraints.end());
+
+ // Combine the rewrite rules. Look up the existing trees rather than creating
+ // them: a class without a tree has no rules to merge.
+ if (auto rewriteRoot2 = Impl->getRewriteTreeRootIfPresent(equivClass2)) {
+ if (auto rewriteRoot1 = Impl->getRewriteTreeRootIfPresent(equivClass)) {
+ // Merge the second rewrite tree into the first.
+ rewriteRoot2->mergeInto(rewriteRoot1);
+ Impl->RewriteTreeRoots.erase(equivClass2);
+ delete rewriteRoot2;
+ } else {
+ // Take the second rewrite tree and make it the first.
+ Impl->RewriteTreeRoots.erase(equivClass2);
+ (void)Impl->RewriteTreeRoots.insert({equivClass, rewriteRoot2});
+ }
+ }
}
// Same-type-to-concrete requirements.
@@ -5446,7 +6142,7 @@
return lhs.constraint < rhs.constraint;
}
- LLVM_ATTRIBUTE_DEPRECATED(void dump() const,
+ LLVM_ATTRIBUTE_DEPRECATED(void dump() const LLVM_ATTRIBUTE_USED,
"only for use in the debugger");
};
}