Merge pull request #15450 from DougGregor/infer-ext-generic-typealias
Retain type sugar for extension declarations that name generic typealiases
diff --git a/lib/AST/DeclContext.cpp b/lib/AST/DeclContext.cpp
index 9d3c017..21c40b0 100644
--- a/lib/AST/DeclContext.cpp
+++ b/lib/AST/DeclContext.cpp
@@ -47,15 +47,33 @@
GenericTypeDecl *
DeclContext::getAsTypeOrTypeExtensionContext() const {
- if (auto decl = const_cast<Decl*>(getAsDeclOrDeclExtensionContext())) {
- if (auto ED = dyn_cast<ExtensionDecl>(decl)) {
- if (auto type = ED->getExtendedType())
- return type->getAnyNominal();
- return nullptr;
+ auto decl = const_cast<Decl*>(getAsDeclOrDeclExtensionContext()); // const_cast: this method is const but hands back a mutable decl
+ if (!decl) return nullptr; // no declaration backs this context
+
+ auto ext = dyn_cast<ExtensionDecl>(decl);
+ if (!ext) return dyn_cast<GenericTypeDecl>(decl); // not an extension: the decl itself, if it is a generic type decl
+
+ auto type = ext->getExtendedType();
+ if (!type) return nullptr; // extension has no extended type set
+
+ do { // walk typealias references until a nominal type is reached
+ // Expected case: the type names a nominal type (possibly through sugar).
+ if (auto nominal = type->getAnyNominal())
+ return nominal;
+
+ // Early type-checking case: the extended type still refers directly to a
+ // typealias (no nominal type is visible through it yet), so look through
+ // to the alias's underlying type, if one has been set.
+ if (auto typealias =
+ dyn_cast_or_null<TypeAliasDecl>(type->getAnyGeneric())) {
+ type = typealias->getUnderlyingTypeLoc().getType();
+ if (!type) return nullptr; // alias has no underlying type to follow
+
+ continue; // re-examine the alias's underlying type
}
- return dyn_cast<GenericTypeDecl>(decl);
- }
- return nullptr;
+
+ return nullptr; // neither a nominal type nor a typealias
+ } while (true); // exits only through a return; NOTE(review): a cyclic typealias chain would never terminate -- presumably cycles are diagnosed before this runs, confirm
}
/// If this DeclContext is a NominalType declaration or an
diff --git a/lib/Sema/TypeCheckDecl.cpp b/lib/Sema/TypeCheckDecl.cpp
index df3c75b..4d451f2 100644
--- a/lib/Sema/TypeCheckDecl.cpp
+++ b/lib/Sema/TypeCheckDecl.cpp
@@ -8197,46 +8197,139 @@
assert(D->hasAccess());
}
+/// Determine whether \p typealias is a "pass-through" typealias. Such an
+/// alias names a nominal type and has exactly the same generic parameters as
+/// that nominal type. If it is generic at its own level, it must also
+/// specialize the nominal type with its own innermost generic parameters,
+/// in order.
+///
+/// Returns false when the typealias's underlying type is unset (null) or is
+/// not a nominal type.
+bool swift::isPassThroughTypealias(TypeAliasDecl *typealias) {
+  // Pass-through only makes sense when the typealias refers to a nominal
+  // type. getUnderlyingTypeLoc().getType() can yield a null Type (other code
+  // in this change checks for exactly that), so treat a missing underlying
+  // type as "not pass-through" rather than dereferencing it.
+  Type underlyingType = typealias->getUnderlyingTypeLoc().getType();
+  if (!underlyingType) return false;
+
+  auto nominal = underlyingType->getAnyNominal();
+  if (!nominal) return false;
+
+  // Check that the nominal type and the typealias are either both generic
+  // at this level or neither are.
+  if (nominal->isGeneric() != typealias->isGeneric())
+    return false;
+
+  // Make sure either both have generic signatures or neither do.
+  auto nominalSig = nominal->getGenericSignature();
+  auto typealiasSig = typealias->getGenericSignature();
+  if (static_cast<bool>(nominalSig) != static_cast<bool>(typealiasSig))
+    return false;
+
+  // If neither has a generic signature, we're done: it's a pass-through alias.
+  if (!nominalSig) return true;
+
+  // Check that the generic parameters (outer and innermost) are the same the
+  // whole way through.
+  auto nominalGenericParams = nominalSig->getGenericParams();
+  auto typealiasGenericParams = typealiasSig->getGenericParams();
+  if (nominalGenericParams.size() != typealiasGenericParams.size())
+    return false;
+  if (!std::equal(nominalGenericParams.begin(), nominalGenericParams.end(),
+                  typealiasGenericParams.begin(),
+                  [](GenericTypeParamType *gp1, GenericTypeParamType *gp2) {
+                    return gp1->isEqual(gp2);
+                  }))
+    return false;
+
+  // The isGeneric() comparison above guarantees the two agree. If neither is
+  // generic at this level, the matching generic parameters make it a
+  // pass-through typealias.
+  if (!typealias->isGeneric()) return true;
+
+  // Otherwise the alias must spell out a specialization of the nominal type.
+  auto boundGenericType = underlyingType->getAs<BoundGenericType>();
+  if (!boundGenericType) return false;
+
+  // If our arguments line up with our innermost generic parameters, in order,
+  // it's a passthrough typealias (X<A, B> for GX<A, B>, but not X<B, A>).
+  auto innermostGenericParams = typealiasSig->getInnermostGenericParams();
+  auto boundArgs = boundGenericType->getGenericArgs();
+  if (boundArgs.size() != innermostGenericParams.size())
+    return false;
+
+  return std::equal(boundArgs.begin(), boundArgs.end(),
+                    innermostGenericParams.begin(),
+                    [](Type arg, GenericTypeParamType *gp) {
+                      return arg->isEqual(gp);
+                    });
+}
+
/// Form the interface type of an extension from the raw type and the
/// extension's list of generic parameters.
-static Type formExtensionInterfaceType(Type type,
- GenericParamList *genericParams) {
+static Type formExtensionInterfaceType(TypeChecker &tc, ExtensionDecl *ext, // NOTE(review): tc/ext are only forwarded to the recursive call below
+ Type type,
+ GenericParamList *genericParams,
+ bool &mustInferRequirements) { // out-param: set to true (never reset here) when generic typealias sugar is formed
// Find the nominal type declaration and its parent type.
Type parentType;
- NominalTypeDecl *nominal;
+ GenericTypeDecl *genericDecl; // the nominal type or typealias named by the extension
if (auto unbound = type->getAs<UnboundGenericType>()) {
parentType = unbound->getParent();
- nominal = cast<NominalTypeDecl>(unbound->getDecl());
+ genericDecl = unbound->getDecl(); // may be a generic typealias, not only a nominal type
} else {
if (type->is<ProtocolCompositionType>())
type = type->getCanonicalType();
auto nominalType = type->castTo<NominalType>();
parentType = nominalType->getParent();
- nominal = nominalType->getDecl();
+ genericDecl = nominalType->getDecl();
}
// Reconstruct the parent, if there is one.
if (parentType) {
// Build the nested extension type.
- auto parentGenericParams = nominal->getGenericParams()
+ auto parentGenericParams = genericDecl->getGenericParams()
? genericParams->getOuterParameters()
: genericParams;
- parentType = formExtensionInterfaceType(parentType, parentGenericParams);
+ parentType =
+ formExtensionInterfaceType(tc, ext, parentType, parentGenericParams,
+ mustInferRequirements);
}
- // If we don't have generic parameters at this level, just build the result.
- if (!nominal->getGenericParams() || isa<ProtocolDecl>(nominal)) {
- return NominalType::get(nominal, parentType,
- nominal->getASTContext());
+ // Find the nominal type, looking through a typealias to its underlying type.
+ auto nominal = dyn_cast<NominalTypeDecl>(genericDecl);
+ auto typealias = dyn_cast<TypeAliasDecl>(genericDecl); // null when a nominal type was named directly
+ if (!nominal) {
+ Type underlyingType = typealias->getUnderlyingTypeLoc().getType();
+ nominal = underlyingType->getNominalOrBoundGenericNominal(); // NOTE(review): typealias, its underlying type, and this result are all dereferenced without null checks -- assumes callers only pass aliases of nominal types; confirm
}
- // Form the bound generic type with the type parameters provided.
+ // Form the result: the nominal type, bound to the extension's generic parameters if it is generic.
+ Type resultType;
SmallVector<Type, 2> genericArgs;
- for (auto gp : *genericParams) {
- genericArgs.push_back(gp->getDeclaredInterfaceType());
+ if (!nominal->isGeneric() || isa<ProtocolDecl>(nominal)) {
+ resultType = NominalType::get(nominal, parentType,
+ nominal->getASTContext());
+ } else {
+ // Form the bound generic type with the type parameters provided.
+ for (auto gp : *genericParams) {
+ genericArgs.push_back(gp->getDeclaredInterfaceType());
+ }
+
+ resultType = BoundGenericType::get(nominal, parentType, genericArgs);
}
- return BoundGenericType::get(nominal, parentType, genericArgs);
+ // If we extended through a pass-through typealias, wrap the result in typealias sugar.
+ if (typealias && isPassThroughTypealias(typealias)) {
+ auto typealiasSig = typealias->getGenericSignature();
+ if (typealiasSig) {
+ auto subMap =
+ typealiasSig->getSubstitutionMap(
+ [](SubstitutableType *type) -> Type { // identity: each generic parameter maps to itself
+ return Type(type);
+ },
+ [](CanType dependentType, // conformances stay abstract (built from the protocol alone)
+ Type replacementType,
+ ProtocolType *protoType) {
+ auto proto = protoType->getDecl();
+ return ProtocolConformanceRef(proto);
+ });
+
+ resultType = BoundNameAliasType::get(typealias, parentType,
+ subMap, resultType);
+
+ mustInferRequirements = true; // forces checkGenericEnvironment to build a full signature (see its mustInferRequirements parameter)
+ } else {
+ resultType = typealias->getDeclaredInterfaceType(); // no generic signature: use the alias's own declared type
+ }
+ }
+
+ return resultType;
}
/// Visit the given generic parameter lists from the outermost to the innermost,
@@ -8258,7 +8351,10 @@
assert(!ext->getGenericEnvironment());
// Form the interface type of the extension.
- Type extInterfaceType = formExtensionInterfaceType(type, genericParams);
+ bool mustInferRequirements = false;
+ Type extInterfaceType =
+ formExtensionInterfaceType(tc, ext, type, genericParams,
+ mustInferRequirements);
// Prepare all of the generic parameter lists for generic signature
// validation.
@@ -8280,7 +8376,8 @@
auto *env = tc.checkGenericEnvironment(genericParams,
ext->getDeclContext(), nullptr,
/*allowConcreteGenericParams=*/true,
- ext, inferExtendedTypeReqs);
+ ext, inferExtendedTypeReqs,
+ mustInferRequirements);
// Validate the generic parameters for the last time, to splat down
// actual archetypes.
@@ -8319,7 +8416,14 @@
return;
// Validate the nominal type declaration being extended.
- auto nominal = extendedType->getAnyNominal();
+ NominalTypeDecl *nominal = extendedType->getAnyNominal();
+ if (!nominal) {
+ auto unbound = cast<UnboundGenericType>(extendedType.getPointer());
+ auto typealias = cast<TypeAliasDecl>(unbound->getDecl());
+ validateDecl(typealias);
+
+ nominal = typealias->getUnderlyingTypeLoc().getType()->getAnyNominal();
+ }
validateDecl(nominal);
if (nominal->getGenericParamsOfContext()) {
diff --git a/lib/Sema/TypeCheckGeneric.cpp b/lib/Sema/TypeCheckGeneric.cpp
index c435b9e..d620bc7 100644
--- a/lib/Sema/TypeCheckGeneric.cpp
+++ b/lib/Sema/TypeCheckGeneric.cpp
@@ -1147,13 +1147,14 @@
bool allowConcreteGenericParams,
ExtensionDecl *ext,
llvm::function_ref<void(GenericSignatureBuilder &)>
- inferRequirements) {
+ inferRequirements,
+ bool mustInferRequirements) {
assert(genericParams && "Missing generic parameters?");
bool recursivelyVisitGenericParams =
genericParams->getOuterParameters() && !parentSig;
GenericSignature *sig;
- if (!ext || ext->getTrailingWhereClause() ||
+ if (!ext || mustInferRequirements || ext->getTrailingWhereClause() ||
getExtendedTypeGenericDepth(ext) != genericParams->getDepth()) {
// Collect the generic parameters.
SmallVector<GenericTypeParamType *, 4> allGenericParams;
diff --git a/lib/Sema/TypeChecker.cpp b/lib/Sema/TypeChecker.cpp
index 215d3c0..70da39f 100644
--- a/lib/Sema/TypeChecker.cpp
+++ b/lib/Sema/TypeChecker.cpp
@@ -323,7 +323,8 @@
auto extendedNominal = aliasDecl->getDeclaredInterfaceType()->getAnyNominal();
if (extendedNominal) {
extendedType = extendedNominal->getDeclaredType();
- ED->getExtendedTypeLoc().setType(extendedType);
+ if (!isPassThroughTypealias(aliasDecl))
+ ED->getExtendedTypeLoc().setType(extendedType);
}
}
}
diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h
index 76c1f35..0e5d12f 100644
--- a/lib/Sema/TypeChecker.h
+++ b/lib/Sema/TypeChecker.h
@@ -1406,7 +1406,8 @@
bool allowConcreteGenericParams,
ExtensionDecl *ext,
llvm::function_ref<void(GenericSignatureBuilder &)>
- inferRequirements);
+ inferRequirements,
+ bool mustInferRequirements);
/// Construct a new generic environment for the given declaration context.
///
@@ -1426,7 +1427,8 @@
ExtensionDecl *ext) {
return checkGenericEnvironment(genericParams, dc, outerSignature,
allowConcreteGenericParams, ext,
- [&](GenericSignatureBuilder &) { });
+ [&](GenericSignatureBuilder &) { },
+ /*mustInferRequirements=*/false);
}
/// Validate the signature of a generic type.
@@ -2531,7 +2533,27 @@
bool isAcceptableDynamicMemberLookupSubscript(SubscriptDecl *decl,
DeclContext *DC,
TypeChecker &TC);
-
+
+/// Determine whether \p typealias is a "pass-through" typealias: it has the
+/// same type parameters as the nominal type it references and specializes
+/// the underlying nominal type with exactly those type parameters, in order.
+/// For example, the following typealias \c GX is a pass-through typealias:
+///
+/// \code
+/// struct X<T, U> { }
+/// typealias GX<A, B> = X<A, B>
+/// \endcode
+///
+/// whereas \c GX2 and \c GX3 are not pass-through because \c GX2 has
+/// different type parameters and \c GX3 doesn't pass its type parameters
+/// directly through (it swaps them).
+///
+/// \code
+/// typealias GX2<A> = X<A, A>
+/// typealias GX3<A, B> = X<B, A>
+/// \endcode
+bool isPassThroughTypealias(TypeAliasDecl *typealias); // false if the alias does not refer to a nominal type
+
} // end namespace swift
#endif
diff --git a/test/Compatibility/stdlib_generic_typealiases.swift b/test/Compatibility/stdlib_generic_typealiases.swift
new file mode 100644
index 0000000..4a5d6ec
--- /dev/null
+++ b/test/Compatibility/stdlib_generic_typealiases.swift
@@ -0,0 +1,18 @@
+// RUN: %target-typecheck-verify-swift
+
+struct RequiresComparable<T: Comparable> { } // only instantiable with a Comparable argument
+
+extension CountableRange { // expected-warning{{'CountableRange' is deprecated: renamed to 'Range'}}
+ // expected-note@-1{{use 'Range' instead}}{{11-25=Range}}
+ func testComparable() { // Bound must be known to be Comparable inside this extension of the generic typealias
+ _ = RequiresComparable<Bound>()
+ }
+}
+
+struct RequiresHashable<T: Hashable> { } // only instantiable with a Hashable argument
+
+extension DictionaryIndex {
+ func testHashable() { // Key must be known to be Hashable inside this extension of the generic typealias
+ _ = RequiresHashable<Key>()
+ }
+}
diff --git a/test/Generics/requirement_inference.swift b/test/Generics/requirement_inference.swift
index 28f1d9f..02cd9c0 100644
--- a/test/Generics/requirement_inference.swift
+++ b/test/Generics/requirement_inference.swift
@@ -481,10 +481,28 @@
}
// Extend using the inferred requirement.
-// FIXME: Currently broken.
extension X1WithP2 {
func f() {
- _ = X5<T>() // FIXME: expected-error{{type 'T' does not conform to protocol 'P2'}}
+ _ = X5<T>() // okay: T: P2 is inferred from the generic typealias X1WithP2
+ }
+}
+
+extension X1: P1 { // NOTE(review): presumably needed so X1<X1<T>> below satisfies X1's requirements -- confirm against X1's declaration
+ func p1() { }
+}
+
+typealias X1WithP2Changed<T: P2> = X1<X1<T>> // not pass-through: the argument is X1<T>, not T
+typealias X1WithP2MoreArgs<T: P2, U> = X1<T> // not pass-through: extra generic parameter U
+
+extension X1WithP2Changed {
+ func bad1() { // T: P2 must not be inferred from a non-pass-through alias
+ _ = X5<T>() // expected-error{{type 'T' does not conform to protocol 'P2'}}
+ }
+}
+
+extension X1WithP2MoreArgs {
+ func bad2() { // T: P2 must not be inferred from a non-pass-through alias
+ _ = X5<T>() // expected-error{{type 'T' does not conform to protocol 'P2'}}
}
}
diff --git a/validation-test/Serialization/rdar29694978.swift b/validation-test/Serialization/rdar29694978.swift
index 9530375..8e92b38 100644
--- a/validation-test/Serialization/rdar29694978.swift
+++ b/validation-test/Serialization/rdar29694978.swift
@@ -22,7 +22,7 @@
// CHECK-DAG: typealias MyGenericType<T> = GenericType<T>
typealias MyGenericType<T: NSObject> = GenericType<T>
-// CHECK-DAG: extension GenericType where Element : AnyObject
+// CHECK-DAG: extension GenericType where Element : NSObject
extension MyGenericType {}
// CHECK-DAG: extension GenericType where Element == NSObject
extension MyGenericType where Element == NSObject {}