blob: 72b3d6f26618e4e7853eacae5283fd2cb6c3a3fa [file] [log] [blame]
"""Type inference constraint solving"""
from __future__ import annotations
from collections import defaultdict
from typing import Iterable, Sequence
from typing_extensions import TypeAlias as _TypeAlias
from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints, neg_op
from mypy.expandtype import expand_type
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
from mypy.join import join_types
from mypy.meet import meet_type_list, meet_types
from mypy.subtypes import is_subtype
from mypy.typeops import get_type_vars
from mypy.types import (
AnyType,
Instance,
NoneType,
ProperType,
Type,
TypeOfAny,
TypeVarId,
TypeVarLikeType,
TypeVarType,
UninhabitedType,
UnionType,
get_proper_type,
remove_dups,
)
from mypy.typestate import type_state
Bounds: _TypeAlias = "dict[TypeVarId, set[Type]]"
Graph: _TypeAlias = "set[tuple[TypeVarId, TypeVarId]]"
Solutions: _TypeAlias = "dict[TypeVarId, Type | None]"
def solve_constraints(
original_vars: Sequence[TypeVarLikeType],
constraints: list[Constraint],
strict: bool = True,
allow_polymorphic: bool = False,
) -> tuple[list[Type | None], list[TypeVarLikeType]]:
"""Solve type constraints.
Return the best type(s) for type variables; each type can be None if the value of
the variable could not be solved.
If a variable has no constraints, if strict=True then arbitrarily
pick UninhabitedType as the value of the type variable. If strict=False, pick AnyType.
If allow_polymorphic=True, then use the full algorithm that can potentially return
free type variables in solutions (these require special care when applying). Otherwise,
use a simplified algorithm that just solves each type variable individually if possible.
"""
vars = [tv.id for tv in original_vars]
if not vars:
return [], []
originals = {tv.id: tv for tv in original_vars}
extra_vars: list[TypeVarId] = []
# Get additional type variables from generic actuals.
for c in constraints:
extra_vars.extend([v.id for v in c.extra_tvars if v.id not in vars + extra_vars])
originals.update({v.id: v for v in c.extra_tvars if v.id not in originals})
if allow_polymorphic:
# Constraints like T :> S and S <: T are semantically the same, but they are
# represented differently. Normalize the constraint list w.r.t this equivalence.
constraints = normalize_constraints(constraints, vars + extra_vars)
# Collect a list of constraints for each type variable.
cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars + extra_vars}
for con in constraints:
if con.type_var in vars + extra_vars:
cmap[con.type_var].append(con)
if allow_polymorphic:
if constraints:
solutions, free_vars = solve_with_dependent(
vars + extra_vars, constraints, vars, originals
)
else:
solutions = {}
free_vars = []
else:
solutions = {}
free_vars = []
for tv, cs in cmap.items():
if not cs:
continue
lowers = [c.target for c in cs if c.op == SUPERTYPE_OF]
uppers = [c.target for c in cs if c.op == SUBTYPE_OF]
solution = solve_one(lowers, uppers)
# Do not leak type variables in non-polymorphic solutions.
if solution is None or not get_vars(
solution, [tv for tv in extra_vars if tv not in vars]
):
solutions[tv] = solution
res: list[Type | None] = []
for v in vars:
if v in solutions:
res.append(solutions[v])
else:
# No constraints for type variable -- 'UninhabitedType' is the most specific type.
candidate: Type
if strict:
candidate = UninhabitedType()
candidate.ambiguous = True
else:
candidate = AnyType(TypeOfAny.special_form)
res.append(candidate)
return res, free_vars
def solve_with_dependent(
vars: list[TypeVarId],
constraints: list[Constraint],
original_vars: list[TypeVarId],
originals: dict[TypeVarId, TypeVarLikeType],
) -> tuple[Solutions, list[TypeVarLikeType]]:
"""Solve set of constraints that may depend on each other, like T <: List[S].
The whole algorithm consists of five steps:
* Propagate via linear constraints and use secondary constraints to get transitive closure
* Find dependencies between type variables, group them in SCCs, and sort topologically
* Check that all SCC are intrinsically linear, we can't solve (express) T <: List[T]
* Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC)
* Solve constraints iteratively starting from leafs, updating bounds after each step.
"""
graph, lowers, uppers = transitive_closure(vars, constraints)
dmap = compute_dependencies(vars, graph, lowers, uppers)
sccs = list(strongly_connected_components(set(vars), dmap))
if not all(check_linear(scc, lowers, uppers) for scc in sccs):
return {}, []
raw_batches = list(topsort(prepare_sccs(sccs, dmap)))
free_vars = []
free_solutions = {}
for scc in raw_batches[0]:
# If there are no bounds on this SCC, then the only meaningful solution we can
# express, is that each variable is equal to a new free variable. For example,
# if we have T <: S, S <: U, we deduce: T = S = U = <free>.
if all(not lowers[tv] and not uppers[tv] for tv in scc):
best_free = choose_free([originals[tv] for tv in scc], original_vars)
if best_free:
free_vars.append(best_free.id)
free_solutions[best_free.id] = best_free
# Update lowers/uppers with free vars, so these can now be used
# as valid solutions.
for l, u in graph:
if l in free_vars:
lowers[u].add(free_solutions[l])
if u in free_vars:
uppers[l].add(free_solutions[u])
# Flatten the SCCs that are independent, we can solve them together,
# since we don't need to update any targets in between.
batches = []
for batch in raw_batches:
next_bc = []
for scc in batch:
next_bc.extend(list(scc))
batches.append(next_bc)
solutions: dict[TypeVarId, Type | None] = {}
for flat_batch in batches:
res = solve_iteratively(flat_batch, graph, lowers, uppers)
solutions.update(res)
return solutions, [free_solutions[tv] for tv in free_vars]
def solve_iteratively(
batch: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds
) -> Solutions:
"""Solve transitive closure sequentially, updating upper/lower bounds after each step.
Transitive closure is represented as a linear graph plus lower/upper bounds for each
type variable, see transitive_closure() docstring for details.
We solve for type variables that appear in `batch`. If a bound is not constant (i.e. it
looks like T :> F[S, ...]), we substitute solutions found so far in the target F[S, ...]
after solving the batch.
Importantly, after solving each variable in a batch, we move it from linear graph to
upper/lower bounds, this way we can guarantee consistency of solutions (see comment below
for an example when this is important).
"""
solutions = {}
s_batch = set(batch)
while s_batch:
for tv in sorted(s_batch, key=lambda x: x.raw_id):
if lowers[tv] or uppers[tv]:
solvable_tv = tv
break
else:
break
# Solve each solvable type variable separately.
s_batch.remove(solvable_tv)
result = solve_one(lowers[solvable_tv], uppers[solvable_tv])
solutions[solvable_tv] = result
if result is None:
# TODO: support backtracking lower/upper bound choices and order within SCCs.
# (will require switching this function from iterative to recursive).
continue
# Update the (transitive) bounds from graph if there is a solution.
# This is needed to guarantee solutions will never contradict the initial
# constraints. For example, consider {T <: S, T <: A, S :> B} with A :> B.
# If we would not update the uppers/lowers from graph, we would infer T = A, S = B
# which is not correct.
for l, u in graph.copy():
if l == u:
continue
if l == solvable_tv:
lowers[u].add(result)
graph.remove((l, u))
if u == solvable_tv:
uppers[l].add(result)
graph.remove((l, u))
# We can update uppers/lowers only once after solving the whole SCC,
# since uppers/lowers can't depend on type variables in the SCC
# (and we would reject such SCC as non-linear and therefore not solvable).
subs = {tv: s for (tv, s) in solutions.items() if s is not None}
for tv in lowers:
lowers[tv] = {expand_type(lt, subs) for lt in lowers[tv]}
for tv in uppers:
uppers[tv] = {expand_type(ut, subs) for ut in uppers[tv]}
return solutions
def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None:
"""Solve constraints by finding by using meets of upper bounds, and joins of lower bounds."""
bottom: Type | None = None
top: Type | None = None
candidate: Type | None = None
# Process each bound separately, and calculate the lower and upper
# bounds based on constraints. Note that we assume that the constraint
# targets do not have constraint references.
for target in lowers:
if bottom is None:
bottom = target
else:
if type_state.infer_unions:
# This deviates from the general mypy semantics because
# recursive types are union-heavy in 95% of cases.
bottom = UnionType.make_union([bottom, target])
else:
bottom = join_types(bottom, target)
for target in uppers:
if top is None:
top = target
else:
top = meet_types(top, target)
p_top = get_proper_type(top)
p_bottom = get_proper_type(bottom)
if isinstance(p_top, AnyType) or isinstance(p_bottom, AnyType):
source_any = top if isinstance(p_top, AnyType) else bottom
assert isinstance(source_any, ProperType) and isinstance(source_any, AnyType)
return AnyType(TypeOfAny.from_another_any, source_any=source_any)
elif bottom is None:
if top:
candidate = top
else:
# No constraints for type variable
return None
elif top is None:
candidate = bottom
elif is_subtype(bottom, top):
candidate = bottom
else:
candidate = None
return candidate
def choose_free(
scc: list[TypeVarLikeType], original_vars: list[TypeVarId]
) -> TypeVarLikeType | None:
"""Choose the best solution for an SCC containing only type variables.
This is needed to preserve e.g. the upper bound in a situation like this:
def dec(f: Callable[[T], S]) -> Callable[[T], S]: ...
@dec
def test(x: U) -> U: ...
where U <: A.
"""
if len(scc) == 1:
# Fast path, choice is trivial.
return scc[0]
common_upper_bound = meet_type_list([t.upper_bound for t in scc])
common_upper_bound_p = get_proper_type(common_upper_bound)
# We include None for when strict-optional is disabled.
if isinstance(common_upper_bound_p, (UninhabitedType, NoneType)):
# This will cause to infer <nothing>, which is better than a free TypeVar
# that has an upper bound <nothing>.
return None
values: list[Type] = []
for tv in scc:
if isinstance(tv, TypeVarType) and tv.values:
if values:
# It is too tricky to support multiple TypeVars with values
# within the same SCC.
return None
values = tv.values.copy()
if values and not is_trivial_bound(common_upper_bound_p):
# If there are both values and upper bound present, we give up,
# since type variables having both are not supported.
return None
# For convenience with current type application machinery, we use a stable
# choice that prefers the original type variables (not polymorphic ones) in SCC.
best = sorted(scc, key=lambda x: (x.id not in original_vars, x.id.raw_id))[0]
if isinstance(best, TypeVarType):
return best.copy_modified(values=values, upper_bound=common_upper_bound)
if is_trivial_bound(common_upper_bound_p):
# TODO: support more cases for ParamSpecs/TypeVarTuples
return best
return None
def is_trivial_bound(tp: ProperType) -> bool:
return isinstance(tp, Instance) and tp.type.fullname == "builtins.object"
def normalize_constraints(
constraints: list[Constraint], vars: list[TypeVarId]
) -> list[Constraint]:
"""Normalize list of constraints (to simplify life for the non-linear solver).
This includes two things currently:
* Complement T :> S by S <: T
* Remove strict duplicates
* Remove constrains for unrelated variables
"""
res = constraints.copy()
for c in constraints:
if isinstance(c.target, TypeVarType):
res.append(Constraint(c.target, neg_op(c.op), c.origin_type_var))
return [c for c in remove_dups(constraints) if c.type_var in vars]
def transitive_closure(
tvars: list[TypeVarId], constraints: list[Constraint]
) -> tuple[Graph, Bounds, Bounds]:
"""Find transitive closure for given constraints on type variables.
Transitive closure gives maximal set of lower/upper bounds for each type variable,
such that we cannot deduce any further bounds by chaining other existing bounds.
The transitive closure is represented by:
* A set of lower and upper bounds for each type variable, where only constant and
non-linear terms are included in the bounds.
* A graph of linear constraints between type variables (represented as a set of pairs)
Such separation simplifies reasoning, and allows an efficient and simple incremental
transitive closure algorithm that we use here.
For example if we have initial constraints [T <: S, S <: U, U <: int], the transitive
closure is given by:
* {} <: T <: {int}
* {} <: S <: {int}
* {} <: U <: {int}
* {T <: S, S <: U, T <: U}
"""
uppers: Bounds = defaultdict(set)
lowers: Bounds = defaultdict(set)
graph: Graph = {(tv, tv) for tv in tvars}
remaining = set(constraints)
while remaining:
c = remaining.pop()
if isinstance(c.target, TypeVarType) and c.target.id in tvars:
if c.op == SUBTYPE_OF:
lower, upper = c.type_var, c.target.id
else:
lower, upper = c.target.id, c.type_var
if (lower, upper) in graph:
continue
graph |= {
(l, u) for l in tvars for u in tvars if (l, lower) in graph and (upper, u) in graph
}
for u in tvars:
if (upper, u) in graph:
lowers[u] |= lowers[lower]
for l in tvars:
if (l, lower) in graph:
uppers[l] |= uppers[upper]
for lt in lowers[lower]:
for ut in uppers[upper]:
# TODO: what if secondary constraints result in inference
# against polymorphic actual (also in below branches)?
remaining |= set(infer_constraints(lt, ut, SUBTYPE_OF))
remaining |= set(infer_constraints(ut, lt, SUPERTYPE_OF))
elif c.op == SUBTYPE_OF:
if c.target in uppers[c.type_var]:
continue
for l in tvars:
if (l, c.type_var) in graph:
uppers[l].add(c.target)
for lt in lowers[c.type_var]:
remaining |= set(infer_constraints(lt, c.target, SUBTYPE_OF))
remaining |= set(infer_constraints(c.target, lt, SUPERTYPE_OF))
else:
assert c.op == SUPERTYPE_OF
if c.target in lowers[c.type_var]:
continue
for u in tvars:
if (c.type_var, u) in graph:
lowers[u].add(c.target)
for ut in uppers[c.type_var]:
remaining |= set(infer_constraints(ut, c.target, SUPERTYPE_OF))
remaining |= set(infer_constraints(c.target, ut, SUBTYPE_OF))
return graph, lowers, uppers
def compute_dependencies(
tvars: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds
) -> dict[TypeVarId, list[TypeVarId]]:
"""Compute dependencies between type variables induced by constraints.
If we have a constraint like T <: List[S], we say that T depends on S, since
we will need to solve for S first before we can solve for T.
"""
res = {}
for tv in tvars:
deps = set()
for lt in lowers[tv]:
deps |= get_vars(lt, tvars)
for ut in uppers[tv]:
deps |= get_vars(ut, tvars)
for other in tvars:
if other == tv:
continue
if (tv, other) in graph or (other, tv) in graph:
deps.add(other)
res[tv] = list(deps)
return res
def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool:
"""Check there are only linear constraints between type variables in SCC.
Linear are constraints like T <: S (while T <: F[S] are non-linear).
"""
for tv in scc:
if any(get_vars(lt, list(scc)) for lt in lowers[tv]):
return False
if any(get_vars(ut, list(scc)) for ut in uppers[tv]):
return False
return True
def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]:
"""Find type variables for which we are solving in a target type."""
return {tv.id for tv in get_type_vars(target)} & set(vars)