Merge pull request #15450 from DougGregor/infer-ext-generic-typealias

Retain type sugar for extension declarations that name generic typealiases
diff --git a/lib/AST/DeclContext.cpp b/lib/AST/DeclContext.cpp
index 9d3c017..21c40b0 100644
--- a/lib/AST/DeclContext.cpp
+++ b/lib/AST/DeclContext.cpp
@@ -47,15 +47,33 @@
 
 GenericTypeDecl *
 DeclContext::getAsTypeOrTypeExtensionContext() const {
-  if (auto decl = const_cast<Decl*>(getAsDeclOrDeclExtensionContext())) {
-    if (auto ED = dyn_cast<ExtensionDecl>(decl)) {
-      if (auto type = ED->getExtendedType())
-        return type->getAnyNominal();
-      return nullptr;
+  auto decl = const_cast<Decl*>(getAsDeclOrDeclExtensionContext());
+  if (!decl) return nullptr;
+
+  auto ext = dyn_cast<ExtensionDecl>(decl);
+  if (!ext) return dyn_cast<GenericTypeDecl>(decl);
+
+  auto type = ext->getExtendedType();
+  if (!type) return nullptr;
+
+  do {
+    // Expected case: we reference a nominal type (potentially through sugar).
+    if (auto nominal = type->getAnyNominal())
+      return nominal;
+
+    // Early type checking case: the type still names a typealias rather than
+    // a nominal type, so explicitly look through the alias's underlying type
+    // if there is one.
+    if (auto typealias =
+          dyn_cast_or_null<TypeAliasDecl>(type->getAnyGeneric())) {
+      type = typealias->getUnderlyingTypeLoc().getType();
+      if (!type) return nullptr;
+
+      continue;
     }
-    return dyn_cast<GenericTypeDecl>(decl);
-  }
-  return nullptr;
+
+    return nullptr;
+  } while (true);
 }
 
 /// If this DeclContext is a NominalType declaration or an
diff --git a/lib/Sema/TypeCheckDecl.cpp b/lib/Sema/TypeCheckDecl.cpp
index df3c75b..4d451f2 100644
--- a/lib/Sema/TypeCheckDecl.cpp
+++ b/lib/Sema/TypeCheckDecl.cpp
@@ -8197,46 +8197,139 @@
   assert(D->hasAccess());
 }
 
+bool swift::isPassThroughTypealias(TypeAliasDecl *typealias) {
+  // Pass-through only makes sense when the typealias refers to a nominal
+  // type.
+  Type underlyingType = typealias->getUnderlyingTypeLoc().getType();
+  auto nominal = underlyingType->getAnyNominal();
+  if (!nominal) return false;
+
+  // Check that the nominal type and the typealias are either both generic
+  // at this level or neither is.
+  if (nominal->isGeneric() != typealias->isGeneric())
+    return false;
+
+  // Make sure either both have generic signatures or neither does.
+  auto nominalSig = nominal->getGenericSignature();
+  auto typealiasSig = typealias->getGenericSignature();
+  if (static_cast<bool>(nominalSig) != static_cast<bool>(typealiasSig))
+    return false;
+
+  // If neither is generic, we're done: it's a pass-through alias.
+  if (!nominalSig) return true;
+
+  // Check that the type parameters are the same the whole way through.
+  auto nominalGenericParams = nominalSig->getGenericParams();
+  auto typealiasGenericParams = typealiasSig->getGenericParams();
+  if (nominalGenericParams.size() != typealiasGenericParams.size())
+    return false;
+  if (!std::equal(nominalGenericParams.begin(), nominalGenericParams.end(),
+                  typealiasGenericParams.begin(),
+                  [](GenericTypeParamType *gp1, GenericTypeParamType *gp2) {
+                    return gp1->isEqual(gp2);
+                  }))
+    return false;
+
+  // If neither is generic at this level, we have a pass-through typealias.
+  if (!typealias->isGeneric()) return true;
+
+  auto boundGenericType = underlyingType->getAs<BoundGenericType>();
+  if (!boundGenericType) return false;
+
+  // If our arguments line up with our innermost generic parameters, it's
+  // a pass-through typealias.
+  auto innermostGenericParams = typealiasSig->getInnermostGenericParams();
+  auto boundArgs = boundGenericType->getGenericArgs();
+  if (boundArgs.size() != innermostGenericParams.size())
+    return false;
+
+  return std::equal(boundArgs.begin(), boundArgs.end(),
+                    innermostGenericParams.begin(),
+                    [](Type arg, GenericTypeParamType *gp) {
+                      return arg->isEqual(gp);
+                    });
+}
+
 /// Form the interface type of an extension from the raw type and the
 /// extension's list of generic parameters.
-static Type formExtensionInterfaceType(Type type,
-                                       GenericParamList *genericParams) {
+static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext,
+                                       Type type,
+                                       GenericParamList *genericParams,
+                                       bool &mustInferRequirements) {
   // Find the nominal type declaration and its parent type.
   Type parentType;
-  NominalTypeDecl *nominal;
+  GenericTypeDecl *genericDecl;
   if (auto unbound = type->getAs<UnboundGenericType>()) {
     parentType = unbound->getParent();
-    nominal = cast<NominalTypeDecl>(unbound->getDecl());
+    genericDecl = unbound->getDecl();
   } else {
     if (type->is<ProtocolCompositionType>())
       type = type->getCanonicalType();
     auto nominalType = type->castTo<NominalType>();
     parentType = nominalType->getParent();
-    nominal = nominalType->getDecl();
+    genericDecl = nominalType->getDecl();
   }
 
   // Reconstruct the parent, if there is one.
   if (parentType) {
     // Build the nested extension type.
-    auto parentGenericParams = nominal->getGenericParams()
+    auto parentGenericParams = genericDecl->getGenericParams()
                                  ? genericParams->getOuterParameters()
                                  : genericParams;
-    parentType = formExtensionInterfaceType(parentType, parentGenericParams);
+    parentType =
+      formExtensionInterfaceType(tc, ext, parentType, parentGenericParams,
+                                 mustInferRequirements);
   }
 
-  // If we don't have generic parameters at this level, just build the result.
-  if (!nominal->getGenericParams() || isa<ProtocolDecl>(nominal)) {
-    return NominalType::get(nominal, parentType,
-                            nominal->getASTContext());
+  // Find the nominal type.
+  auto nominal = dyn_cast<NominalTypeDecl>(genericDecl);
+  auto typealias = dyn_cast<TypeAliasDecl>(genericDecl);
+  if (!nominal) {
+    Type underlyingType = typealias->getUnderlyingTypeLoc().getType();
+    nominal = underlyingType->getNominalOrBoundGenericNominal();
   }
 
-  // Form the bound generic type with the type parameters provided.
+  // Form the result.
+  Type resultType;
   SmallVector<Type, 2> genericArgs;
-  for (auto gp : *genericParams) {
-    genericArgs.push_back(gp->getDeclaredInterfaceType());
+  if (!nominal->isGeneric() || isa<ProtocolDecl>(nominal)) {
+    resultType = NominalType::get(nominal, parentType,
+                                  nominal->getASTContext());
+  } else {
+    // Form the bound generic type with the type parameters provided.
+    for (auto gp : *genericParams) {
+      genericArgs.push_back(gp->getDeclaredInterfaceType());
+    }
+
+    resultType = BoundGenericType::get(nominal, parentType, genericArgs);
   }
 
-  return BoundGenericType::get(nominal, parentType, genericArgs);
+  // If we have a typealias, try to form type sugar.
+  if (typealias && isPassThroughTypealias(typealias)) {
+    auto typealiasSig = typealias->getGenericSignature();
+    if (typealiasSig) {
+      auto subMap =
+        typealiasSig->getSubstitutionMap(
+                              [](SubstitutableType *type) -> Type {
+                                return Type(type);
+                              },
+                              [](CanType dependentType,
+                                 Type replacementType,
+                                 ProtocolType *protoType) {
+                                auto proto = protoType->getDecl();
+                                return ProtocolConformanceRef(proto);
+                              });
+
+      resultType = BoundNameAliasType::get(typealias, parentType,
+                                           subMap, resultType);
+
+      mustInferRequirements = true;
+    } else {
+      resultType = typealias->getDeclaredInterfaceType();
+    }
+  }
+
+  return resultType;
 }
 
 /// Visit the given generic parameter lists from the outermost to the innermost,
@@ -8258,7 +8351,10 @@
   assert(!ext->getGenericEnvironment());
 
   // Form the interface type of the extension.
-  Type extInterfaceType = formExtensionInterfaceType(type, genericParams);
+  bool mustInferRequirements = false;
+  Type extInterfaceType =
+    formExtensionInterfaceType(tc, ext, type, genericParams,
+                               mustInferRequirements);
 
   // Prepare all of the generic parameter lists for generic signature
   // validation.
@@ -8280,7 +8376,8 @@
   auto *env = tc.checkGenericEnvironment(genericParams,
                                          ext->getDeclContext(), nullptr,
                                          /*allowConcreteGenericParams=*/true,
-                                         ext, inferExtendedTypeReqs);
+                                         ext, inferExtendedTypeReqs,
+                                         mustInferRequirements);
 
   // Validate the generic parameters for the last time, to splat down
   // actual archetypes.
@@ -8319,7 +8416,14 @@
     return;
 
   // Validate the nominal type declaration being extended.
-  auto nominal = extendedType->getAnyNominal();
+  NominalTypeDecl *nominal = extendedType->getAnyNominal();
+  if (!nominal) {
+    auto unbound = cast<UnboundGenericType>(extendedType.getPointer());
+    auto typealias = cast<TypeAliasDecl>(unbound->getDecl());
+    validateDecl(typealias);
+
+    nominal = typealias->getUnderlyingTypeLoc().getType()->getAnyNominal();
+  }
   validateDecl(nominal);
 
   if (nominal->getGenericParamsOfContext()) {
diff --git a/lib/Sema/TypeCheckGeneric.cpp b/lib/Sema/TypeCheckGeneric.cpp
index c435b9e..d620bc7 100644
--- a/lib/Sema/TypeCheckGeneric.cpp
+++ b/lib/Sema/TypeCheckGeneric.cpp
@@ -1147,13 +1147,14 @@
                       bool allowConcreteGenericParams,
                       ExtensionDecl *ext,
                       llvm::function_ref<void(GenericSignatureBuilder &)>
-                        inferRequirements) {
+                        inferRequirements,
+                      bool mustInferRequirements) {
   assert(genericParams && "Missing generic parameters?");
   bool recursivelyVisitGenericParams =
     genericParams->getOuterParameters() && !parentSig;
 
   GenericSignature *sig;
-  if (!ext || ext->getTrailingWhereClause() ||
+  if (!ext || mustInferRequirements || ext->getTrailingWhereClause() ||
       getExtendedTypeGenericDepth(ext) != genericParams->getDepth()) {
     // Collect the generic parameters.
     SmallVector<GenericTypeParamType *, 4> allGenericParams;
diff --git a/lib/Sema/TypeChecker.cpp b/lib/Sema/TypeChecker.cpp
index 215d3c0..70da39f 100644
--- a/lib/Sema/TypeChecker.cpp
+++ b/lib/Sema/TypeChecker.cpp
@@ -323,7 +323,8 @@
       auto extendedNominal = aliasDecl->getDeclaredInterfaceType()->getAnyNominal();
       if (extendedNominal) {
         extendedType = extendedNominal->getDeclaredType();
-        ED->getExtendedTypeLoc().setType(extendedType);
+        if (!isPassThroughTypealias(aliasDecl))
+          ED->getExtendedTypeLoc().setType(extendedType);
       }
     }
   }
diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h
index 76c1f35..0e5d12f 100644
--- a/lib/Sema/TypeChecker.h
+++ b/lib/Sema/TypeChecker.h
@@ -1406,7 +1406,8 @@
                         bool allowConcreteGenericParams,
                         ExtensionDecl *ext,
                         llvm::function_ref<void(GenericSignatureBuilder &)>
-                          inferRequirements);
+                          inferRequirements,
+                        bool mustInferRequirements);
 
   /// Construct a new generic environment for the given declaration context.
   ///
@@ -1426,7 +1427,8 @@
                         ExtensionDecl *ext) {
     return checkGenericEnvironment(genericParams, dc, outerSignature,
                                    allowConcreteGenericParams, ext,
-                                   [&](GenericSignatureBuilder &) { });
+                                   [&](GenericSignatureBuilder &) { },
+                                   /*mustInferRequirements=*/false);
   }
 
   /// Validate the signature of a generic type.
@@ -2531,7 +2533,27 @@
 bool isAcceptableDynamicMemberLookupSubscript(SubscriptDecl *decl,
                                               DeclContext *DC,
                                               TypeChecker &TC);
-  
+
+/// Determine whether this is a "pass-through" typealias, which has the
+/// same type parameters as the nominal type it references and specializes
+/// the underlying nominal type with exactly those type parameters.
+/// For example, the following typealias \c GX is a pass-through typealias:
+///
+/// \code
+/// struct X<T, U> { }
+/// typealias GX<A, B> = X<A, B>
+/// \endcode
+///
+/// whereas \c GX2 and \c GX3 are not pass-through because \c GX2 has
+/// different type parameters and \c GX3 doesn't pass its type parameters
+/// directly through.
+///
+/// \code
+/// typealias GX2<A> = X<A, A>
+/// typealias GX3<A, B> = X<B, A>
+/// \endcode
+bool isPassThroughTypealias(TypeAliasDecl *typealias);
+
 } // end namespace swift
 
 #endif
diff --git a/test/Compatibility/stdlib_generic_typealiases.swift b/test/Compatibility/stdlib_generic_typealiases.swift
new file mode 100644
index 0000000..4a5d6ec
--- /dev/null
+++ b/test/Compatibility/stdlib_generic_typealiases.swift
@@ -0,0 +1,18 @@
+// RUN: %target-typecheck-verify-swift
+
+struct RequiresComparable<T: Comparable> { }
+
+extension CountableRange { // expected-warning{{'CountableRange' is deprecated: renamed to 'Range'}}
+  // expected-note@-1{{use 'Range' instead}}{{11-25=Range}}
+  func testComparable() {
+    _ = RequiresComparable<Bound>()
+  }
+}
+
+struct RequiresHashable<T: Hashable> { }
+
+extension DictionaryIndex {
+  func testHashable() {
+    _ = RequiresHashable<Key>()
+  }
+}
diff --git a/test/Generics/requirement_inference.swift b/test/Generics/requirement_inference.swift
index 28f1d9f..02cd9c0 100644
--- a/test/Generics/requirement_inference.swift
+++ b/test/Generics/requirement_inference.swift
@@ -481,10 +481,28 @@
 }
 
 // Extend using the inferred requirement.
-// FIXME: Currently broken.
 extension X1WithP2 {
   func f() {
-    _ = X5<T>() // FIXME: expected-error{{type 'T' does not conform to protocol 'P2'}}
+    _ = X5<T>() // okay: inferred T: P2 from generic typealias
+  }
+}
+
+extension X1: P1 {
+  func p1() { }
+}
+
+typealias X1WithP2Changed<T: P2> = X1<X1<T>>
+typealias X1WithP2MoreArgs<T: P2, U> = X1<T>
+
+extension X1WithP2Changed {
+  func bad1() {
+    _ = X5<T>() // expected-error{{type 'T' does not conform to protocol 'P2'}}
+  }
+}
+
+extension X1WithP2MoreArgs {
+  func bad2() {
+    _ = X5<T>() // expected-error{{type 'T' does not conform to protocol 'P2'}}
   }
 }
 
diff --git a/validation-test/Serialization/rdar29694978.swift b/validation-test/Serialization/rdar29694978.swift
index 9530375..8e92b38 100644
--- a/validation-test/Serialization/rdar29694978.swift
+++ b/validation-test/Serialization/rdar29694978.swift
@@ -22,7 +22,7 @@
 
 // CHECK-DAG: typealias MyGenericType<T> = GenericType<T>
 typealias MyGenericType<T: NSObject> = GenericType<T>
-// CHECK-DAG: extension GenericType where Element : AnyObject
+// CHECK-DAG: extension GenericType where Element : NSObject
 extension MyGenericType {}
 // CHECK-DAG: extension GenericType where Element == NSObject
 extension MyGenericType where Element == NSObject {}