Fix crash on invalid recursive variadic alias (#21572)
Fixes https://github.com/python/mypy/issues/21125
This one was tricky because the above issue actually exposed
_two_ different crash scenarios:
* A crash on invalid constructs like `*tuple[Ts]` (must be
`*tuple[*Ts]`).
* An infinite recursion when trying to detect pathological and divergent
aliases.
And while working on this I discovered two more cases:
* A crash where an invalid type (like a union) appears in an unpack in a
recursive alias definition.
* A crash on a non-normalizable recursive tuple.
I fix the first by tightening logic in `typeanal.py` w.r.t. where
exactly a `TypeVarTuple` is allowed. I fix the second and third by
avoiding `get_proper_type()` calls in `expand_type()` for recursive
tuples. The fourth is the most problematic, and is kind of a fundamental
thing. This PR only avoids an immediate crash for such aliases. We will
still need to update various call sites where we special-case tuples to
expect non-normal ones.
A couple more related things:
* I fix a couple of issues with `is_recursive` cache invalidation.
* I added a fast path to `detect_diverging_alias()` to avoid creating
sets unless really needed.
diff --git a/mypy/expandtype.py b/mypy/expandtype.py
index b576d9f..186429a 100644
--- a/mypy/expandtype.py
+++ b/mypy/expandtype.py
@@ -228,11 +228,11 @@
if t.type.fullname == "builtins.tuple":
# Normalize Tuple[*Tuple[X, ...], ...] -> Tuple[X, ...]
arg = args[0]
- if isinstance(arg, UnpackType):
+ if isinstance(arg, UnpackType) and not (
+ isinstance(arg.type, TypeAliasType) and arg.type.is_recursive
+ ):
unpacked = get_proper_type(arg.type)
if isinstance(unpacked, Instance):
- # TODO: this and similar asserts below may be unsafe because get_proper_type()
- # may be called during semantic analysis before all invalid types are removed.
assert unpacked.type.fullname == "builtins.tuple"
args = list(unpacked.args)
return t.copy_modified(args=args)
@@ -536,7 +536,9 @@
if len(items) == 1:
# Normalize Tuple[*Tuple[X, ...]] -> Tuple[X, ...]
item = items[0]
- if isinstance(item, UnpackType):
+ if isinstance(item, UnpackType) and not (
+ isinstance(item.type, TypeAliasType) and item.type.is_recursive
+ ):
unpacked = get_proper_type(item.type)
if isinstance(unpacked, Instance):
# expand_type() may be called during semantic analysis, before invalid unpacks are fixed.
diff --git a/mypy/semanal.py b/mypy/semanal.py
index 58c152f..e010273 100644
--- a/mypy/semanal.py
+++ b/mypy/semanal.py
@@ -4263,6 +4263,8 @@
# An alias gets updated.
updated = False
if isinstance(existing.node, TypeAlias):
+ # Invalidate recursive status cache in case it was previously set.
+ existing.node._is_recursive = None
if existing.node.target != res:
# Copy expansion to the existing alias, this matches how we update base classes
# for a TypeInfo _in place_ if there are nested placeholders.
@@ -4271,8 +4273,6 @@
existing.node.alias_tvars = alias_tvars
existing.node.no_args = no_args
updated = True
- # Invalidate recursive status cache in case it was previously set.
- existing.node._is_recursive = None
else:
# Otherwise just replace existing placeholder with type alias *in place*.
existing._node = alias_node
@@ -5830,6 +5830,8 @@
):
updated = False
if isinstance(existing.node, TypeAlias):
+ # Invalidate recursive status cache in case it was previously set.
+ existing.node._is_recursive = None
if (
existing.node.target != res
or existing.node.alias_tvars != alias_node.alias_tvars
@@ -5840,8 +5842,6 @@
existing.node.default_depends = default_depends
existing.node.alias_tvars = alias_tvars
updated = True
- # Invalidate recursive status cache in case it was previously set.
- existing.node._is_recursive = None
else:
# Otherwise just replace existing placeholder with type alias *in place*.
existing._node = alias_node
diff --git a/mypy/semanal_typeargs.py b/mypy/semanal_typeargs.py
index 0f62a4a..5b6d9b1 100644
--- a/mypy/semanal_typeargs.py
+++ b/mypy/semanal_typeargs.py
@@ -108,7 +108,10 @@
self.seen_aliases.discard(t)
def visit_tuple_type(self, t: TupleType) -> None:
- t.items = flatten_nested_tuples(t.items)
+ # Unfortunately, universal normalization of tuples is not possible in presence of
+ # recursive aliases, see testNoCrashOnNonNormalRecursiveTuple for an example.
+ # TODO: update the places where we handle tuples to always expect non-normal ones.
+ t.items = flatten_nested_tuples(t.items, handle_recursive=False)
for i, it in enumerate(t.items):
if self.check_non_paramspec(it, "tuple", t):
t.items[i] = AnyType(TypeOfAny.from_error)
diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py
index 075bf7c..5b72371 100644
--- a/mypy/server/astmerge.py
+++ b/mypy/server/astmerge.py
@@ -340,6 +340,8 @@
super().visit_var(node)
def visit_type_alias(self, node: TypeAlias) -> None:
+ # Updating alias target can invalidate its recursive status.
+ node._is_recursive = None
self.fixup_type(node.target)
for v in node.alias_tvars:
self.fixup_type(v)
diff --git a/mypy/typeanal.py b/mypy/typeanal.py
index a303621..aa5d14b 100644
--- a/mypy/typeanal.py
+++ b/mypy/typeanal.py
@@ -75,6 +75,7 @@
BoolTypeQuery,
CallableArgument,
CallableType,
+ CollectAliasesVisitor,
DeletedType,
EllipsisType,
ErasedType,
@@ -275,7 +276,9 @@
self.prohibit_special_class_field_types = prohibit_special_class_field_types
# Allow variables typed as Type[Any] and type (useful for base classes).
self.allow_type_any = allow_type_any
- self.allow_type_var_tuple = False
+ # Level of nesting at which a TypeVarTuple is allowed. Note we specify exact level
+ # to prohibit things like Unpack[list[Ts]], which are not supported.
+ self.allow_type_var_tuple = -1
self.allow_unpack = allow_unpack
# Set when we are analyzing a default of a type variable.
self.analyzing_tvar_def = analyzing_tvar_def
@@ -453,7 +456,7 @@
self.fail(msg, t, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
assert isinstance(tvar_def, TypeVarTupleType)
- if not self.allow_type_var_tuple:
+ if self.allow_type_var_tuple != self.nesting_level:
self.fail(
f'TypeVarTuple "{t.name}" is only valid with an unpack',
t,
@@ -808,9 +811,9 @@
if not self.allow_unpack:
self.fail(message_registry.INVALID_UNPACK_POSITION, t, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
- self.allow_type_var_tuple = True
+ self.allow_type_var_tuple = self.nesting_level + 1
result = UnpackType(self.anal_type(t.args[0]), line=t.line, column=t.column)
- self.allow_type_var_tuple = False
+ self.allow_type_var_tuple = -1
return result
elif fullname in SELF_TYPE_NAMES:
if t.args:
@@ -1161,9 +1164,9 @@
if not self.allow_unpack:
self.fail(message_registry.INVALID_UNPACK_POSITION, t.type, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
- self.allow_type_var_tuple = True
+ self.allow_type_var_tuple = self.nesting_level + 1
result = UnpackType(self.anal_type(t.type), from_star_syntax=t.from_star_syntax)
- self.allow_type_var_tuple = False
+ self.allow_type_var_tuple = -1
return result
def visit_parameters(self, t: Parameters) -> Type:
@@ -2523,6 +2526,15 @@
They may be handy in rare cases, e.g. to express a union of non-mixed nested lists:
Nested = Union[T, Nested[List[T]]] ~> Union[T, List[T], List[List[T]], ...]
"""
+ is_recursive = node._is_recursive
+ if is_recursive is None:
+ is_recursive = node in node.target.accept(CollectAliasesVisitor())
+ if not is_recursive:
+ # Fast path: this is not a recursive alias at all.
+ return False
+ # Note we only cache positive case, caching negative case is risky, as this type alias
+ # (or more importantly any other alias it uses) may be not ready yet.
+ node._is_recursive = True
visitor = DivergingAliasDetector({node})
_ = target.accept(visitor)
return visitor.diverging
diff --git a/mypy/types.py b/mypy/types.py
index 129d988..324135d 100644
--- a/mypy/types.py
+++ b/mypy/types.py
@@ -4301,12 +4301,12 @@
# Funky code here avoids mypyc narrowing the type of unpack_index.
old_index = unpack_index
assert old_index is None
- # Don't return so that we can also sanity check there is only one.
+ # Don't return so that we can also sanity-check there is only one.
unpack_index = i
return unpack_index
-def flatten_nested_tuples(types: Iterable[Type]) -> list[Type]:
+def flatten_nested_tuples(types: Iterable[Type], handle_recursive: bool = True) -> list[Type]:
"""Recursively flatten TupleTypes nested with Unpack.
For example this will transform
@@ -4320,7 +4320,12 @@
res.append(typ)
continue
p_type = get_proper_type(typ.type)
- if not isinstance(p_type, TupleType):
+ if (
+ not isinstance(p_type, TupleType)
+ or not handle_recursive
+ and isinstance(typ.type, TypeAliasType)
+ and typ.type.is_recursive
+ ):
res.append(typ)
continue
if isinstance(typ.type, TypeAliasType):
diff --git a/test-data/unit/check-typevar-tuple.test b/test-data/unit/check-typevar-tuple.test
index 3f0765b..55f125d 100644
--- a/test-data/unit/check-typevar-tuple.test
+++ b/test-data/unit/check-typevar-tuple.test
@@ -882,7 +882,39 @@
reveal_type(x) # N: Revealed type is "Any"
reveal_type(y) # N: Revealed type is "Any"
reveal_type(z) # N: Revealed type is "tuple[builtins.int, Unpack[builtins.tuple[Any, ...]]]"
+[builtins fixtures/tuple.pyi]
+[case testBanPathologicalRecursiveTuplesGeneric]
+from typing import TypeVarTuple, Unpack
+
+Ts = TypeVarTuple("Ts")
+A = tuple[Unpack[B[Unpack[Ts]]]] # E: Invalid recursive alias: a tuple item of itself \
+ # E: Name "B" is used before definition
+B = tuple[Unpack[A[Unpack[Ts]]]]
+[builtins fixtures/tuple.pyi]
+
+[case testNoCrashOnInvalidRecursiveUnpackOfUnion]
+from typing import Unpack
+
+A = tuple[int, str] | list[tuple[Unpack[A]]] # E: "tuple[int, str] | list[tuple[Unpack[A]]]" cannot be unpacked (must be tuple or TypeVarTuple)
+[builtins fixtures/tuple.pyi]
+
+[case testNoCrashOnNonNormalRecursiveTuple]
+from typing import Unpack
+
+A = tuple[int, list[tuple[str, Unpack[A]]]]
+a: A
+x, y = a
+y[0] = 1 # E: Incompatible types in assignment (expression has type "int", target has type "tuple[str, Unpack[A]]")
+[builtins fixtures/list.pyi]
+
+[case testBanTypeVarTupleNotImmediatelyInsideUnpack]
+from typing import TypeVarTuple, Unpack
+
+Ts = TypeVarTuple("Ts")
+A = tuple[Unpack[tuple[Ts]]] # E: TypeVarTuple "Ts" is only valid with an unpack
+x: A[int, str]
+reveal_type(x) # N: Revealed type is "tuple[Any]"
[builtins fixtures/tuple.pyi]
[case testInferenceAgainstGenericVariadicWithBadType]