Merge pull request #18317 from DougGregor/refactor-name-lookup

diff --git a/include/swift/AST/DeclContext.h b/include/swift/AST/DeclContext.h
index d71fc0c..69b250e 100644
--- a/include/swift/AST/DeclContext.h
+++ b/include/swift/AST/DeclContext.h
@@ -486,6 +486,32 @@
                        LazyResolver *typeResolver,
                        SmallVectorImpl<ValueDecl *> &decls) const;
 
+  /// Look for the set of declarations with the given name within the
+  /// given set of type declarations.
+  ///
+  /// \param types The type declarations to look into.
+  ///
+  /// \param member The member to search for.
+  ///
+  /// \param options Options that control name lookup, based on the
+  /// \c NL_* constants in \c NameLookupOptions.
+  ///
+  /// \param[out] decls Will be populated with the declarations found by name
+  /// lookup.
+  ///
+  /// \returns true if anything was found.
+  bool lookupQualified(ArrayRef<TypeDecl *> types, DeclName member,
+                       NLOptions options,
+                       SmallVectorImpl<ValueDecl *> &decls) const;
+
+  /// Qualified lookup of \p member in \p module; true if anything was found.
+  bool lookupQualified(ModuleDecl *module, DeclName member, NLOptions options,
+                       SmallVectorImpl<ValueDecl *> &decls) const;
+
+  /// Perform \c AnyObject lookup for \p member; true if anything was found.
+  bool lookupAnyObject(DeclName member, NLOptions options,
+                       SmallVectorImpl<ValueDecl *> &decls) const;
+
   /// Look up all Objective-C methods with the given selector visible
   /// in the enclosing module.
   void lookupAllObjCMethods(
diff --git a/lib/AST/NameLookup.cpp b/lib/AST/NameLookup.cpp
index 01f1e2a..c4b0979 100644
--- a/lib/AST/NameLookup.cpp
+++ b/lib/AST/NameLookup.cpp
@@ -1523,21 +1523,29 @@
   llvm_unreachable("unhandled storage decl kind");
 }
 
-bool DeclContext::lookupQualified(Type type,
-                                  DeclName member,
-                                  NLOptions options,
-                                  LazyResolver *typeResolver,
-                                  SmallVectorImpl<ValueDecl *> &decls) const {
-  using namespace namelookup;
-  assert(decls.empty() && "additive lookup not supported");
+/// Configure name lookup for the given declaration context and options.
+///
+/// This utility is used by qualified name lookup.
+static void configureLookup(const DeclContext *dc,
+                            NLOptions &options,
+                            ReferencedNameTracker *&tracker,
+                            bool &isLookupCascading) {
+  auto &ctx = dc->getASTContext();
+  if (!ctx.LangOpts.EnableAccessControl)
+    options |= NL_IgnoreAccessControl;
 
-  if (!typeResolver)
-    typeResolver = getASTContext().getLazyResolver();
-  
-  auto checkLookupCascading = [this, options]() -> Optional<bool> {
+  // Find the dependency tracker we'll need for this lookup.
+  tracker = nullptr;
+  if (auto containingSourceFile =
+          dyn_cast<SourceFile>(dc->getModuleScopeContext())) {
+    tracker = containingSourceFile->getReferencedNameTracker();
+  }
+
+  auto checkLookupCascading = [dc, options]() -> Optional<bool> {
     switch (static_cast<unsigned>(options & NL_KnownDependencyMask)) {
     case 0:
-      return isCascadingContextForLookup(/*functionsAreNonCascading=*/false);
+      return dc->isCascadingContextForLookup(
+               /*functionsAreNonCascading=*/false);
     case NL_KnownNonCascadingDependency:
       return false;
     case NL_KnownCascadingDependency:
@@ -1554,161 +1562,207 @@
     }
   };
 
-  // Look for module references.
-  if (auto moduleTy = type->getAs<ModuleType>()) {
-    ModuleDecl *module = moduleTy->getModule();
-    auto topLevelScope = getModuleScopeContext();
-    if (module == topLevelScope->getParentModule()) {
-      if (auto maybeLookupCascade = checkLookupCascading()) {
-        recordLookupOfTopLevelName(topLevelScope, member,
-                                   maybeLookupCascade.getValue());
-      }
-      lookupInModule(module, /*accessPath=*/{}, member, decls,
-                     NLKind::QualifiedLookup, ResolutionKind::Overloadable,
-                     typeResolver, topLevelScope);
-    } else {
-      // Note: This is a lookup into another module. Unless we're compiling
-      // multiple modules at once, or if the other module re-exports this one,
-      // it shouldn't be possible to have a dependency from that module on
-      // anything in this one.
-
-      // Perform the lookup in all imports of this module.
-      forAllVisibleModules(this,
-                           [&](const ModuleDecl::ImportedModule &import) -> bool {
-        if (import.second != module)
-          return true;
-        lookupInModule(import.second, import.first, member, decls,
-                       NLKind::QualifiedLookup, ResolutionKind::Overloadable,
-                       typeResolver, topLevelScope);
-        // If we're able to do an unscoped lookup, we see everything. No need
-        // to keep going.
-        return !import.first.empty();
-      });
-    }
-
-    llvm::SmallPtrSet<ValueDecl *, 4> knownDecls;
-    decls.erase(std::remove_if(decls.begin(), decls.end(),
-                               [&](ValueDecl *vd) -> bool {
-      // If we're performing a type lookup, don't even attempt to validate
-      // the decl if its not a type.
-      if ((options & NL_OnlyTypes) && !isa<TypeDecl>(vd))
-        return true;
-
-      return !knownDecls.insert(vd).second;
-    }), decls.end());
-
-    if (auto *debugClient = topLevelScope->getParentModule()->getDebugClient())
-      filterForDiscriminator(decls, debugClient);
-    
-    return !decls.empty();
-  }
-
-  auto &ctx = getASTContext();
-  if (!ctx.LangOpts.EnableAccessControl)
-    options |= NL_IgnoreAccessControl;
-
-  // The set of nominal type declarations we should (and have) visited.
-  SmallVector<NominalTypeDecl *, 4> stack;
-  llvm::SmallPtrSet<NominalTypeDecl *, 4> visited;
-
-  // Note that we don't have to visit the superclass of a protocol.
-  // If we started with an archetype or existential, we'll visit the
-  // superclass because we will have added it to the stack upfront.
-  //
-  // If we started with a concrete class conforming to a protocol
-  // with a superclass, we will visit the superclass from the
-  // concrete type.
-  bool checkProtocolSuperclass = false;
-
-  // Handle nominal types.
-  bool wantProtocolMembers = (options & NL_ProtocolMembers);
-  bool wantLookupInAllClasses = false;
-  if (auto nominal = type->getAnyNominal()) {
-    visited.insert(nominal);
-    stack.push_back(nominal);
-
-    if (isa<ProtocolDecl>(nominal))
-      checkProtocolSuperclass = true;
-  }
-  // Handle archetypes
-  else if (auto archetypeTy = type->getAs<ArchetypeType>()) {
-    // Look in the protocols to which the archetype conforms (always).
-    for (auto proto : archetypeTy->getConformsTo())
-      if (visited.insert(proto).second)
-        stack.push_back(proto);
-
-    // Look into the superclasses of this archetype.
-    if (auto superclass = archetypeTy->getSuperclass())
-      if (auto superclassDecl = superclass->getClassOrBoundGenericClass())
-        if (visited.insert(superclassDecl).second)
-          stack.push_back(superclassDecl);
-  }
-  // Handle protocol compositions.
-  else if (auto compositionTy = type->getAs<ProtocolCompositionType>()) {
-    auto layout = compositionTy->getExistentialLayout();
-
-    if (layout.isAnyObject() &&
-        (options & NL_DynamicLookup))
-      wantLookupInAllClasses = true;
-
-    for (auto proto : layout.getProtocols()) {
-      auto *protoDecl = proto->getDecl();
-      if (visited.insert(protoDecl).second)
-        stack.push_back(protoDecl);
-    }
-
-    if (auto superclass = layout.explicitSuperclass) {
-      auto *superclassDecl = superclass->getClassOrBoundGenericClass();
-      if (visited.insert(superclassDecl).second)
-        stack.push_back(superclassDecl);
-    }
-  } else {
-    llvm_unreachable("Bad type for qualified lookup");
-  }
-
-  // Allow filtering of the visible declarations based on various
-  // criteria.
-  bool onlyCompleteObjectInits = false;
-  auto isAcceptableDecl = [&](NominalTypeDecl *current, ValueDecl *decl) -> bool {
-    // Filter out designated initializers, if requested.
-    if (onlyCompleteObjectInits) {
-      if (auto ctor = dyn_cast<ConstructorDecl>(decl)) {
-        if (isa<ClassDecl>(ctor->getDeclContext()) &&
-            !ctor->isInheritable())
-          return false;
-      } else {
-        return false;
-      }
-    }
-
-    // Ignore stub implementations.
-    if (auto ctor = dyn_cast<ConstructorDecl>(decl)) {
-      if (ctor->hasStubImplementation())
-        return false;
-    }
-
-    // Check access.
-    if (!(options & NL_IgnoreAccessControl)) {
-      return decl->isAccessibleFrom(this);
-    }
-
-    return true;
-  };
-
-  ReferencedNameTracker *tracker = nullptr;
-  if (auto containingSourceFile = dyn_cast<SourceFile>(getModuleScopeContext()))
-    tracker = containingSourceFile->getReferencedNameTracker();
-
-  bool isLookupCascading;
+  // Determine whether a lookup here will cascade.
+  isLookupCascading = false;
   if (tracker) {
     if (auto maybeLookupCascade = checkLookupCascading())
       isLookupCascading = maybeLookupCascade.getValue();
     else
       tracker = nullptr;
   }
+}
+
+/// Determine whether \p decl is an acceptable lookup result when searching
+/// from \p dc: initializer filters first, then access (unless ignored).
+static bool isAcceptableLookupResult(const DeclContext *dc,
+                                     NLOptions options,
+                                     ValueDecl *decl,
+                                     bool onlyCompleteObjectInits) {
+  // If requested, reject non-constructors and non-inheritable class inits.
+  if (onlyCompleteObjectInits) {
+    if (auto ctor = dyn_cast<ConstructorDecl>(decl)) {
+      if (isa<ClassDecl>(ctor->getDeclContext()) && !ctor->isInheritable())
+        return false;
+    } else {
+      return false;
+    }
+  }
+
+  // Ignore constructors with stub implementations.
+  if (auto ctor = dyn_cast<ConstructorDecl>(decl)) {
+    if (ctor->hasStubImplementation())
+      return false;
+  }
+
+  // Check access from \p dc, unless access control is being ignored.
+  if (!(options & NL_IgnoreAccessControl)) {
+    return decl->isAccessibleFrom(dc);
+  }
+
+  return true;
+}
+
+/// Once name lookup has gathered a set of results, perform any necessary
+/// steps to prune the result set before returning it to the caller.
+static bool finishLookup(const DeclContext *dc, NLOptions options,
+                         SmallVectorImpl<ValueDecl *> &decls) {
+  // If we're supposed to remove overridden declarations, do so now.
+  if (options & NL_RemoveOverridden)
+    removeOverriddenDecls(decls);
+
+  // If we're supposed to remove shadowed/hidden declarations, do so now.
+  ModuleDecl *M = dc->getParentModule();
+  if (options & NL_RemoveNonVisible)
+    removeShadowedDecls(decls, M);
+
+  if (auto *debugClient = M->getDebugClient())
+    filterForDiscriminator(decls, debugClient);
+
+  // We're done. Report whether any declarations remain.
+  return !decls.empty();
+}
+
+/// Inspect the given type to determine which nominal type declarations it
+/// directly references, appending them to \p decls for later name lookup.
+static void extractDirectlyReferencedNominalTypes(
+              Type type, SmallVectorImpl<NominalTypeDecl *> &decls) {
+  if (auto nominal = type->getAnyNominal()) {
+    decls.push_back(nominal);
+    return;
+  }
+
+  if (auto archetypeTy = type->getAs<ArchetypeType>()) {
+    // Look in the protocols to which the archetype conforms (always).
+    for (auto proto : archetypeTy->getConformsTo())
+      decls.push_back(proto);
+
+    // Look into the superclass of this archetype, if it has a class decl.
+    if (auto superclass = archetypeTy->getSuperclass()) {
+      if (auto superclassDecl = superclass->getClassOrBoundGenericClass())
+        decls.push_back(superclassDecl);
+    }
+
+    return;
+  }
+
+  if (auto compositionTy = type->getAs<ProtocolCompositionType>()) {
+    auto layout = compositionTy->getExistentialLayout();
+
+    for (auto proto : layout.getProtocols()) {
+      auto *protoDecl = proto->getDecl();
+      decls.push_back(protoDecl);
+    }
+
+    if (auto superclass = layout.explicitSuperclass) {
+      auto *superclassDecl = superclass->getClassOrBoundGenericClass();
+      if (superclassDecl)
+        decls.push_back(superclassDecl);
+    }
+
+    return;
+  }
+
+  llvm_unreachable("Not a type containing nominal types?");
+}
+
+bool DeclContext::lookupQualified(Type type,
+                                  DeclName member,
+                                  NLOptions options,
+                                  LazyResolver *typeResolver,
+                                  SmallVectorImpl<ValueDecl *> &decls) const {
+  using namespace namelookup;
+  assert(decls.empty() && "additive lookup not supported");
+
+  // Handle AnyObject lookup. NOTE(review): \p typeResolver is unused below.
+  if (type->isAnyObject())
+    return lookupAnyObject(member, options, decls);
+
+  // Handle lookup in a module.
+  if (auto moduleTy = type->getAs<ModuleType>())
+    return lookupQualified(moduleTy->getModule(), member, options, decls);
+
+  // Figure out which nominal types we will look into.
+  SmallVector<NominalTypeDecl *, 4> nominalTypesToLookInto;
+  extractDirectlyReferencedNominalTypes(type, nominalTypesToLookInto);
+
+  SmallVector<TypeDecl *, 4> declsToLookInto(nominalTypesToLookInto.begin(),
+                                             nominalTypesToLookInto.end());
+  return lookupQualified(declsToLookInto, member, options, decls);
+}
+
+bool DeclContext::lookupQualified(ArrayRef<TypeDecl *> typeDecls,
+                                  DeclName member,
+                                  NLOptions options,
+                                  SmallVectorImpl<ValueDecl *> &decls) const {
+  using namespace namelookup;
+  assert(decls.empty() && "additive lookup not supported");
+
+  // If we have a module, look in it.
+  if (typeDecls.size() == 1 && isa<ModuleDecl>(typeDecls[0])) {
+    return lookupQualified(cast<ModuleDecl>(typeDecls[0]), member, options,
+                           decls);
+  }
+
+  // Configure lookup and dig out the tracker.
+  ReferencedNameTracker *tracker = nullptr;
+  bool isLookupCascading;
+  configureLookup(this, options, tracker, isLookupCascading);
+
+  // Tracking for the nominal types we'll visit.
+  SmallVector<NominalTypeDecl *, 4> stack;
+  llvm::SmallPtrSet<NominalTypeDecl *, 4> visited;
+  bool sawClassDecl = false;
+
+  // Add the given nominal type to the stack.
+  auto addNominalType = [&](NominalTypeDecl *nominal) {
+    if (!visited.insert(nominal).second)
+      return false;
+
+    if (isa<ClassDecl>(nominal))
+      sawClassDecl = true;
+
+    stack.push_back(nominal);
+    return true;
+  };
+
+  // Walk the given type declarations, resolving typealiases to nominal types.
+  ASTContext &ctx = getASTContext();
+  for (auto typeDecl : typeDecls) {
+    // Add nominal types directly.
+    if (auto nominal = dyn_cast<NominalTypeDecl>(typeDecl)) {
+      addNominalType(nominal);
+      continue;
+    }
+
+    // For typealiases, extract nominal type declarations from the underlying
+    // type of the typealias.
+    if (auto typealias = dyn_cast<TypeAliasDecl>(typeDecl)) {
+      SmallVector<NominalTypeDecl *, 4> nominalTypeDecls;
+      if (!typealias->getUnderlyingTypeLoc().getType()) {
+        auto lazyResolver = ctx.getLazyResolver();
+        assert(lazyResolver && "Cannot resolve underlying type of typealias");
+        lazyResolver->resolveDeclSignature(typealias);
+      }
+
+      extractDirectlyReferencedNominalTypes(
+          typealias->getUnderlyingTypeLoc().getType(),
+          nominalTypeDecls);
+      for (auto nominal : nominalTypeDecls)
+        addNominalType(nominal);
+      continue;
+    }
+
+    llvm_unreachable("Cannot look into multiple modules "
+                     "or any type parameter");
+  }
+
+  // Whether we only want to return complete object initializers.
+  bool onlyCompleteObjectInits = false;
 
   // Visit all of the nominal types we know about, discovering any others
   // we need along the way.
+  auto typeResolver = ctx.getLazyResolver();
+  bool wantProtocolMembers = (options & NL_ProtocolMembers);
   while (!stack.empty()) {
     auto current = stack.back();
     stack.pop_back();
@@ -1732,7 +1786,8 @@
       if ((options & NL_OnlyTypes) && !isa<TypeDecl>(decl))
         continue;
 
-      if (isAcceptableDecl(current, decl))
+      if (isAcceptableLookupResult(this, options, decl,
+                                   onlyCompleteObjectInits))
         decls.push_back(decl);
     }
 
@@ -1762,7 +1817,7 @@
     if (!wantProtocolMembers && !currentIsProtocol)
       continue;
 
-    if (checkProtocolSuperclass) {
+    if (!sawClassDecl) {
       if (auto *protoDecl = dyn_cast<ProtocolDecl>(current)) {
         if (auto superclassDecl = protoDecl->getSuperclassDecl()) {
           visited.insert(superclassDecl);
@@ -1784,66 +1839,114 @@
       wantProtocolMembers = false;
   }
 
-  // If we want to perform lookup into all classes, do so now.
-  if (wantLookupInAllClasses) {
-    if (tracker)
-      tracker->addDynamicLookupName(member.getBaseName(), isLookupCascading);
+  return finishLookup(this, options, decls);
+}
 
-    // Collect all of the visible declarations.
-    SmallVector<ValueDecl *, 4> allDecls;
-    forAllVisibleModules(this, [&](ModuleDecl::ImportedModule import) {
-      import.second->lookupClassMember(import.first, member, allDecls);
-    });
+bool DeclContext::lookupQualified(ModuleDecl *module, DeclName member,
+                                  NLOptions options,
+                                  SmallVectorImpl<ValueDecl *> &decls) const {
+  using namespace namelookup;
+  assert(decls.empty() && "additive lookup not supported");
 
-    // For each declaration whose context is not something we've
-    // already visited above, add it to the list of declarations.
-    llvm::SmallPtrSet<ValueDecl *, 4> knownDecls;
-    for (auto decl : allDecls) {
-      // If we're performing a type lookup, don't even attempt to validate
-      // the decl if its not a type.
-      if ((options & NL_OnlyTypes) && !isa<TypeDecl>(decl))
-        continue;
+  // Configure lookup and dig out the tracker.
+  ReferencedNameTracker *tracker = nullptr;
+  bool isLookupCascading;
+  configureLookup(this, options, tracker, isLookupCascading);
 
-      // If the declaration is not @objc, it cannot be called dynamically.
-      if (!decl->isObjC())
-        continue;
-
-      // If the declaration has an override, name lookup will also have
-      // found the overridden method. Skip this declaration, because we
-      // prefer the overridden method.
-      if (decl->getOverriddenDecl())
-        continue;
-
-      auto dc = decl->getDeclContext();
-      auto nominal = dyn_cast<NominalTypeDecl>(dc);
-      if (!nominal) {
-        auto ext = cast<ExtensionDecl>(dc);
-        nominal = ext->getExtendedType()->getAnyNominal();
-        assert(nominal && "Couldn't find nominal type?");
-      }
-
-      // If we didn't visit this nominal type above, add this
-      // declaration to the list.
-      if (!visited.count(nominal) && knownDecls.insert(decl).second &&
-          isAcceptableDecl(nominal, decl))
-        decls.push_back(decl);
+  ASTContext &ctx = getASTContext();
+  auto topLevelScope = getModuleScopeContext();
+  if (module == topLevelScope->getParentModule()) {
+    if (tracker) {
+      recordLookupOfTopLevelName(topLevelScope, member, isLookupCascading);
     }
+    lookupInModule(module, /*accessPath=*/{}, member, decls,
+                   NLKind::QualifiedLookup, ResolutionKind::Overloadable,
+                   ctx.getLazyResolver(), topLevelScope);
+  } else {
+    // Note: This is a lookup into another module. Unless we're compiling
+    // multiple modules at once, or if the other module re-exports this one,
+    // it shouldn't be possible to have a dependency from that module on
+    // anything in this one.
+
+    // Perform the lookup in all imports of this module.
+    forAllVisibleModules(this,
+                         [&](const ModuleDecl::ImportedModule &import) -> bool {
+      if (import.second != module)
+        return true;
+      lookupInModule(import.second, import.first, member, decls,
+                     NLKind::QualifiedLookup, ResolutionKind::Overloadable,
+                     ctx.getLazyResolver(), topLevelScope);
+      // If we're able to do an unscoped lookup, we see everything. No need
+      // to keep going.
+      return !import.first.empty();
+    });
   }
 
-  // If we're supposed to remove overridden declarations, do so now.
-  if (options & NL_RemoveOverridden)
-    removeOverriddenDecls(decls);
+  llvm::SmallPtrSet<ValueDecl *, 4> knownDecls;
+  decls.erase(std::remove_if(decls.begin(), decls.end(),
+                             [&](ValueDecl *vd) -> bool {
+    // If we're performing a type lookup, skip non-types.
+    if ((options & NL_OnlyTypes) && !isa<TypeDecl>(vd))
+      return true;
 
-  // If we're supposed to remove shadowed/hidden declarations, do so now.
-  ModuleDecl *M = getParentModule();
-  if (options & NL_RemoveNonVisible)
-    removeShadowedDecls(decls, M);
+    return !knownDecls.insert(vd).second;
+  }), decls.end());
 
-  if (auto *debugClient = M->getDebugClient())
-    filterForDiscriminator(decls, debugClient);
+  return finishLookup(this, options, decls);
+}
 
-  // We're done. Report success/failure.
-  return !decls.empty();
+bool DeclContext::lookupAnyObject(DeclName member, NLOptions options,
+                                  SmallVectorImpl<ValueDecl *> &decls) const {
+  using namespace namelookup;
+  assert(decls.empty() && "additive lookup not supported");
+
+  // Configure lookup and dig out the tracker.
+  ReferencedNameTracker *tracker = nullptr;
+  bool isLookupCascading;
+  configureLookup(this, options, tracker, isLookupCascading);
+
+  // Record this lookup.
+  if (tracker)
+    tracker->addDynamicLookupName(member.getBaseName(), isLookupCascading);
+
+  // Type-only lookup won't find anything on AnyObject.
+  if (options & NL_OnlyTypes)
+    return false;
+
+  // Collect all of the visible declarations.
+  SmallVector<ValueDecl *, 4> allDecls;
+  forAllVisibleModules(this, [&](ModuleDecl::ImportedModule import) {
+    import.second->lookupClassMember(import.first, member, allDecls);
+  });
+
+  // Filter the candidates: keep each distinct @objc declaration that does
+  // not override another and that is an acceptable lookup result.
+  llvm::SmallPtrSet<ValueDecl *, 4> knownDecls;
+  for (auto decl : allDecls) {
+    // If the declaration is not @objc, it cannot be called dynamically.
+    if (!decl->isObjC())
+      continue;
+
+    // If the declaration has an override, name lookup will also have
+    // found the overridden method. Skip this declaration, because we
+    // prefer the overridden method.
+    if (decl->getOverriddenDecl())
+      continue;
+
+    auto dc = decl->getDeclContext();
+    auto nominal = dc->getAsNominalTypeOrNominalTypeExtensionContext();
+    assert(nominal && "Couldn't find nominal type?");
+
+    // If we didn't see this declaration before, and it's an acceptable
+    // result for a lookup from this context, add it to the list of
+    // declarations.
+    if (knownDecls.insert(decl).second &&
+        isAcceptableLookupResult(this, options, decl,
+                                 /*onlyCompleteObjectInits=*/false))
+      decls.push_back(decl);
+  }
+
+  return finishLookup(this, options, decls);
 }
 
 void DeclContext::lookupAllObjCMethods(