New type inference: add support for upper bounds and values (#15813)
This is a third PR in series following
https://github.com/python/mypy/pull/15287 and
https://github.com/python/mypy/pull/15754. This one is quite simple: I
just add basic support for polymorphic inference involving type
variables with upper bounds and values. A complete support would be
quite complicated, and it will be a corner case to already rare
situation. Finally, it is written in a way that is easy to tune in the
future.
I also use this PR to add some unit tests for all three PRs so far,
other two PRs only added integration tests (and I clean up existing unit
tests as well).
diff --git a/mypy/solve.py b/mypy/solve.py
index 02df90a..72b3d6f 100644
--- a/mypy/solve.py
+++ b/mypy/solve.py
@@ -10,11 +10,13 @@
from mypy.expandtype import expand_type
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
from mypy.join import join_types
-from mypy.meet import meet_types
+from mypy.meet import meet_type_list, meet_types
from mypy.subtypes import is_subtype
from mypy.typeops import get_type_vars
from mypy.types import (
AnyType,
+ Instance,
+ NoneType,
ProperType,
Type,
TypeOfAny,
@@ -108,7 +110,7 @@
else:
candidate = AnyType(TypeOfAny.special_form)
res.append(candidate)
- return res, [originals[tv] for tv in free_vars]
+ return res, free_vars
def solve_with_dependent(
@@ -116,7 +118,7 @@
constraints: list[Constraint],
original_vars: list[TypeVarId],
originals: dict[TypeVarId, TypeVarLikeType],
-) -> tuple[Solutions, list[TypeVarId]]:
+) -> tuple[Solutions, list[TypeVarLikeType]]:
"""Solve set of constraints that may depend on each other, like T <: List[S].
The whole algorithm consists of five steps:
@@ -135,23 +137,24 @@
raw_batches = list(topsort(prepare_sccs(sccs, dmap)))
free_vars = []
+ free_solutions = {}
for scc in raw_batches[0]:
# If there are no bounds on this SCC, then the only meaningful solution we can
# express, is that each variable is equal to a new free variable. For example,
# if we have T <: S, S <: U, we deduce: T = S = U = <free>.
if all(not lowers[tv] and not uppers[tv] for tv in scc):
- # For convenience with current type application machinery, we use a stable
- # choice that prefers the original type variables (not polymorphic ones) in SCC.
- # TODO: be careful about upper bounds (or values) when introducing free vars.
- free_vars.append(sorted(scc, key=lambda x: (x not in original_vars, x.raw_id))[0])
+ best_free = choose_free([originals[tv] for tv in scc], original_vars)
+ if best_free:
+ free_vars.append(best_free.id)
+ free_solutions[best_free.id] = best_free
# Update lowers/uppers with free vars, so these can now be used
# as valid solutions.
- for l, u in graph.copy():
+ for l, u in graph:
if l in free_vars:
- lowers[u].add(originals[l])
+ lowers[u].add(free_solutions[l])
if u in free_vars:
- uppers[l].add(originals[u])
+ uppers[l].add(free_solutions[u])
# Flatten the SCCs that are independent, we can solve them together,
# since we don't need to update any targets in between.
@@ -166,7 +169,7 @@
for flat_batch in batches:
res = solve_iteratively(flat_batch, graph, lowers, uppers)
solutions.update(res)
- return solutions, free_vars
+ return solutions, [free_solutions[tv] for tv in free_vars]
def solve_iteratively(
@@ -276,6 +279,61 @@
return candidate
+def choose_free(
+ scc: list[TypeVarLikeType], original_vars: list[TypeVarId]
+) -> TypeVarLikeType | None:
+ """Choose the best solution for an SCC containing only type variables.
+
+ This is needed to preserve e.g. the upper bound in a situation like this:
+ def dec(f: Callable[[T], S]) -> Callable[[T], S]: ...
+
+ @dec
+ def test(x: U) -> U: ...
+
+ where U <: A.
+ """
+
+ if len(scc) == 1:
+ # Fast path, choice is trivial.
+ return scc[0]
+
+ common_upper_bound = meet_type_list([t.upper_bound for t in scc])
+ common_upper_bound_p = get_proper_type(common_upper_bound)
+ # We include None for when strict-optional is disabled.
+ if isinstance(common_upper_bound_p, (UninhabitedType, NoneType)):
+ # This will cause to infer <nothing>, which is better than a free TypeVar
+ # that has an upper bound <nothing>.
+ return None
+
+ values: list[Type] = []
+ for tv in scc:
+ if isinstance(tv, TypeVarType) and tv.values:
+ if values:
+ # It is too tricky to support multiple TypeVars with values
+ # within the same SCC.
+ return None
+ values = tv.values.copy()
+
+ if values and not is_trivial_bound(common_upper_bound_p):
+ # If there are both values and upper bound present, we give up,
+ # since type variables having both are not supported.
+ return None
+
+ # For convenience with current type application machinery, we use a stable
+ # choice that prefers the original type variables (not polymorphic ones) in SCC.
+ best = sorted(scc, key=lambda x: (x.id not in original_vars, x.id.raw_id))[0]
+ if isinstance(best, TypeVarType):
+ return best.copy_modified(values=values, upper_bound=common_upper_bound)
+ if is_trivial_bound(common_upper_bound_p):
+ # TODO: support more cases for ParamSpecs/TypeVarTuples
+ return best
+ return None
+
+
+def is_trivial_bound(tp: ProperType) -> bool:
+ return isinstance(tp, Instance) and tp.type.fullname == "builtins.object"
+
+
def normalize_constraints(
constraints: list[Constraint], vars: list[TypeVarId]
) -> list[Constraint]:
diff --git a/mypy/test/testsolve.py b/mypy/test/testsolve.py
index 5d67203..6566b03 100644
--- a/mypy/test/testsolve.py
+++ b/mypy/test/testsolve.py
@@ -3,10 +3,10 @@
from __future__ import annotations
from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint
-from mypy.solve import solve_constraints
+from mypy.solve import Bounds, Graph, solve_constraints, transitive_closure
from mypy.test.helpers import Suite, assert_equal
from mypy.test.typefixture import TypeFixture
-from mypy.types import Type, TypeVarLikeType, TypeVarType
+from mypy.types import Type, TypeVarId, TypeVarLikeType, TypeVarType
class SolveSuite(Suite):
@@ -17,11 +17,11 @@
self.assert_solve([], [], [])
def test_simple_supertype_constraints(self) -> None:
- self.assert_solve([self.fx.t], [self.supc(self.fx.t, self.fx.a)], [(self.fx.a, self.fx.o)])
+ self.assert_solve([self.fx.t], [self.supc(self.fx.t, self.fx.a)], [self.fx.a])
self.assert_solve(
[self.fx.t],
[self.supc(self.fx.t, self.fx.a), self.supc(self.fx.t, self.fx.b)],
- [(self.fx.a, self.fx.o)],
+ [self.fx.a],
)
def test_simple_subtype_constraints(self) -> None:
@@ -36,7 +36,7 @@
self.assert_solve(
[self.fx.t],
[self.supc(self.fx.t, self.fx.b), self.subc(self.fx.t, self.fx.a)],
- [(self.fx.b, self.fx.a)],
+ [self.fx.b],
)
def test_unsatisfiable_constraints(self) -> None:
@@ -49,7 +49,7 @@
self.assert_solve(
[self.fx.t],
[self.supc(self.fx.t, self.fx.b), self.subc(self.fx.t, self.fx.b)],
- [(self.fx.b, self.fx.b)],
+ [self.fx.b],
)
def test_multiple_variables(self) -> None:
@@ -60,7 +60,7 @@
self.supc(self.fx.s, self.fx.c),
self.subc(self.fx.t, self.fx.a),
],
- [(self.fx.b, self.fx.a), (self.fx.c, self.fx.o)],
+ [self.fx.b, self.fx.c],
)
def test_no_constraints_for_var(self) -> None:
@@ -69,36 +69,32 @@
self.assert_solve(
[self.fx.t, self.fx.s],
[self.supc(self.fx.s, self.fx.a)],
- [self.fx.uninhabited, (self.fx.a, self.fx.o)],
+ [self.fx.uninhabited, self.fx.a],
)
def test_simple_constraints_with_dynamic_type(self) -> None:
- self.assert_solve(
- [self.fx.t], [self.supc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)]
- )
+ self.assert_solve([self.fx.t], [self.supc(self.fx.t, self.fx.anyt)], [self.fx.anyt])
self.assert_solve(
[self.fx.t],
[self.supc(self.fx.t, self.fx.anyt), self.supc(self.fx.t, self.fx.anyt)],
- [(self.fx.anyt, self.fx.anyt)],
+ [self.fx.anyt],
)
self.assert_solve(
[self.fx.t],
[self.supc(self.fx.t, self.fx.anyt), self.supc(self.fx.t, self.fx.a)],
- [(self.fx.anyt, self.fx.anyt)],
+ [self.fx.anyt],
)
- self.assert_solve(
- [self.fx.t], [self.subc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)]
- )
+ self.assert_solve([self.fx.t], [self.subc(self.fx.t, self.fx.anyt)], [self.fx.anyt])
self.assert_solve(
[self.fx.t],
[self.subc(self.fx.t, self.fx.anyt), self.subc(self.fx.t, self.fx.anyt)],
- [(self.fx.anyt, self.fx.anyt)],
+ [self.fx.anyt],
)
# self.assert_solve([self.fx.t],
# [self.subc(self.fx.t, self.fx.anyt),
# self.subc(self.fx.t, self.fx.a)],
- # [(self.fx.anyt, self.fx.anyt)])
+ # [self.fx.anyt])
# TODO: figure out what this should be after changes to meet(any, X)
def test_both_normal_and_any_types_in_results(self) -> None:
@@ -107,29 +103,180 @@
self.assert_solve(
[self.fx.t],
[self.supc(self.fx.t, self.fx.a), self.subc(self.fx.t, self.fx.anyt)],
- [(self.fx.anyt, self.fx.anyt)],
+ [self.fx.anyt],
)
self.assert_solve(
[self.fx.t],
[self.supc(self.fx.t, self.fx.anyt), self.subc(self.fx.t, self.fx.a)],
- [(self.fx.anyt, self.fx.anyt)],
+ [self.fx.anyt],
+ )
+
+ def test_poly_no_constraints(self) -> None:
+ self.assert_solve(
+ [self.fx.t, self.fx.u],
+ [],
+ [self.fx.uninhabited, self.fx.uninhabited],
+ allow_polymorphic=True,
+ )
+
+ def test_poly_trivial_free(self) -> None:
+ self.assert_solve(
+ [self.fx.t, self.fx.u],
+ [self.subc(self.fx.t, self.fx.a)],
+ [self.fx.a, self.fx.u],
+ [self.fx.u],
+ allow_polymorphic=True,
+ )
+
+ def test_poly_free_pair(self) -> None:
+ self.assert_solve(
+ [self.fx.t, self.fx.u],
+ [self.subc(self.fx.t, self.fx.u)],
+ [self.fx.t, self.fx.t],
+ [self.fx.t],
+ allow_polymorphic=True,
+ )
+
+ def test_poly_free_pair_with_bounds(self) -> None:
+ t_prime = self.fx.t.copy_modified(upper_bound=self.fx.b)
+ self.assert_solve(
+ [self.fx.t, self.fx.ub],
+ [self.subc(self.fx.t, self.fx.ub)],
+ [t_prime, t_prime],
+ [t_prime],
+ allow_polymorphic=True,
+ )
+
+ def test_poly_free_pair_with_bounds_uninhabited(self) -> None:
+ self.assert_solve(
+ [self.fx.ub, self.fx.uc],
+ [self.subc(self.fx.ub, self.fx.uc)],
+ [self.fx.uninhabited, self.fx.uninhabited],
+ [],
+ allow_polymorphic=True,
+ )
+
+ def test_poly_bounded_chain(self) -> None:
+ # B <: T <: U <: S <: A
+ self.assert_solve(
+ [self.fx.t, self.fx.u, self.fx.s],
+ [
+ self.supc(self.fx.t, self.fx.b),
+ self.subc(self.fx.t, self.fx.u),
+ self.subc(self.fx.u, self.fx.s),
+ self.subc(self.fx.s, self.fx.a),
+ ],
+ [self.fx.b, self.fx.b, self.fx.b],
+ allow_polymorphic=True,
+ )
+
+ def test_poly_reverse_overlapping_chain(self) -> None:
+ # A :> T <: S :> B
+ self.assert_solve(
+ [self.fx.t, self.fx.s],
+ [
+ self.subc(self.fx.t, self.fx.s),
+ self.subc(self.fx.t, self.fx.a),
+ self.supc(self.fx.s, self.fx.b),
+ ],
+ [self.fx.a, self.fx.a],
+ allow_polymorphic=True,
+ )
+
+ def test_poly_reverse_split_chain(self) -> None:
+ # B :> T <: S :> A
+ self.assert_solve(
+ [self.fx.t, self.fx.s],
+ [
+ self.subc(self.fx.t, self.fx.s),
+ self.subc(self.fx.t, self.fx.b),
+ self.supc(self.fx.s, self.fx.a),
+ ],
+ [self.fx.b, self.fx.a],
+ allow_polymorphic=True,
+ )
+
+ def test_poly_unsolvable_chain(self) -> None:
+ # A <: T <: U <: S <: B
+ self.assert_solve(
+ [self.fx.t, self.fx.u, self.fx.s],
+ [
+ self.supc(self.fx.t, self.fx.a),
+ self.subc(self.fx.t, self.fx.u),
+ self.subc(self.fx.u, self.fx.s),
+ self.subc(self.fx.s, self.fx.b),
+ ],
+ [None, None, None],
+ allow_polymorphic=True,
+ )
+
+ def test_simple_chain_closure(self) -> None:
+ self.assert_transitive_closure(
+ [self.fx.t.id, self.fx.s.id],
+ [
+ self.supc(self.fx.t, self.fx.b),
+ self.subc(self.fx.t, self.fx.s),
+ self.subc(self.fx.s, self.fx.a),
+ ],
+ {(self.fx.t.id, self.fx.s.id)},
+ {self.fx.t.id: {self.fx.b}, self.fx.s.id: {self.fx.b}},
+ {self.fx.t.id: {self.fx.a}, self.fx.s.id: {self.fx.a}},
+ )
+
+ def test_reverse_chain_closure(self) -> None:
+ self.assert_transitive_closure(
+ [self.fx.t.id, self.fx.s.id],
+ [
+ self.subc(self.fx.t, self.fx.s),
+ self.subc(self.fx.t, self.fx.a),
+ self.supc(self.fx.s, self.fx.b),
+ ],
+ {(self.fx.t.id, self.fx.s.id)},
+ {self.fx.t.id: set(), self.fx.s.id: {self.fx.b}},
+ {self.fx.t.id: {self.fx.a}, self.fx.s.id: set()},
+ )
+
+ def test_secondary_constraint_closure(self) -> None:
+ self.assert_transitive_closure(
+ [self.fx.t.id, self.fx.s.id],
+ [self.supc(self.fx.s, self.fx.gt), self.subc(self.fx.s, self.fx.ga)],
+ set(),
+ {self.fx.t.id: set(), self.fx.s.id: {self.fx.gt}},
+ {self.fx.t.id: {self.fx.a}, self.fx.s.id: {self.fx.ga}},
)
def assert_solve(
self,
vars: list[TypeVarLikeType],
constraints: list[Constraint],
- results: list[None | Type | tuple[Type, Type]],
+ results: list[None | Type],
+ free_vars: list[TypeVarLikeType] | None = None,
+ allow_polymorphic: bool = False,
) -> None:
- res: list[Type | None] = []
- for r in results:
- if isinstance(r, tuple):
- res.append(r[0])
- else:
- res.append(r)
- actual, _ = solve_constraints(vars, constraints)
- assert_equal(str(actual), str(res))
+ if free_vars is None:
+ free_vars = []
+ actual, actual_free = solve_constraints(
+ vars, constraints, allow_polymorphic=allow_polymorphic
+ )
+ assert_equal(actual, results)
+ assert_equal(actual_free, free_vars)
+
+ def assert_transitive_closure(
+ self,
+ vars: list[TypeVarId],
+ constraints: list[Constraint],
+ graph: Graph,
+ lowers: Bounds,
+ uppers: Bounds,
+ ) -> None:
+ actual_graph, actual_lowers, actual_uppers = transitive_closure(vars, constraints)
+ # Add trivial elements.
+ for v in vars:
+ graph.add((v, v))
+ assert_equal(actual_graph, graph)
+ assert_equal(dict(actual_lowers), lowers)
+ assert_equal(dict(actual_uppers), uppers)
def supc(self, type_var: TypeVarType, bound: Type) -> Constraint:
return Constraint(type_var, SUPERTYPE_OF, bound)
diff --git a/mypy/test/typefixture.py b/mypy/test/typefixture.py
index bf1500a..81af765 100644
--- a/mypy/test/typefixture.py
+++ b/mypy/test/typefixture.py
@@ -219,6 +219,10 @@
self._add_bool_dunder(self.bool_type_info)
self._add_bool_dunder(self.ai)
+ # TypeVars with non-trivial bounds
+ self.ub = make_type_var("UB", 5, [], self.b, variance) # UB`5 (type variable)
+ self.uc = make_type_var("UC", 6, [], self.c, variance) # UC`6 (type variable)
+
def make_type_var_tuple(name: str, id: int, upper_bound: Type) -> TypeVarTupleType:
return TypeVarTupleType(
name,
diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test
index 5c510a1..d1842a7 100644
--- a/test-data/unit/check-generics.test
+++ b/test-data/unit/check-generics.test
@@ -3007,3 +3007,31 @@
c: C
reveal_type(c.test()) # N: Revealed type is "__main__.C"
+
+[case testInferenceAgainstGenericBoundsAndValues]
+# flags: --new-type-inference
+from typing import TypeVar, Callable, List
+
+class B: ...
+class C(B): ...
+
+S = TypeVar('S')
+T = TypeVar('T')
+UB = TypeVar('UB', bound=B)
+UC = TypeVar('UC', bound=C)
+V = TypeVar('V', int, str)
+
+def dec1(f: Callable[[S], T]) -> Callable[[S], List[T]]:
+ ...
+def dec2(f: Callable[[UC], T]) -> Callable[[UC], List[T]]:
+ ...
+def id1(x: UB) -> UB:
+ ...
+def id2(x: V) -> V:
+ ...
+
+reveal_type(dec1(id1)) # N: Revealed type is "def [S <: __main__.B] (S`1) -> builtins.list[S`1]"
+reveal_type(dec1(id2)) # N: Revealed type is "def [S in (builtins.int, builtins.str)] (S`3) -> builtins.list[S`3]"
+reveal_type(dec2(id1)) # N: Revealed type is "def [UC <: __main__.C] (UC`5) -> builtins.list[UC`5]"
+reveal_type(dec2(id2)) # N: Revealed type is "def (<nothing>) -> builtins.list[<nothing>]" \
+ # E: Argument 1 to "dec2" has incompatible type "Callable[[V], V]"; expected "Callable[[<nothing>], <nothing>]"