blob: 55a51d4adbb6a9fe5f5d2c97a6d3c73be5b57a6e [file] [log] [blame]
from __future__ import annotations
from typing import Callable, Sequence
import mypy.subtypes
from mypy.expandtype import expand_type, expand_unpack_with_variables
from mypy.nodes import ARG_STAR, Context
from mypy.types import (
AnyType,
CallableType,
Instance,
Parameters,
ParamSpecType,
PartialType,
TupleType,
Type,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UnpackType,
get_proper_type,
)
from mypy.typevartuples import find_unpack_in_list, replace_starargs
def get_target_type(
    tvar: TypeVarLikeType,
    type: Type,
    callable: CallableType,
    report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
    context: Context,
    skip_unsatisfied: bool,
) -> Type | None:
    """Pick the type to substitute for ``tvar`` given the inferred argument ``type``.

    ParamSpec and TypeVarTuple variables are passed through unchanged.

    A regular TypeVar with value restrictions is handled as follows:

    * ``type`` is returned as-is when it is ``Any``.
    * ``type`` is returned as-is when it is a type variable whose own
      values are all among the allowed values.
    * Otherwise the narrowest allowed value that ``type`` is a subtype of
      is returned.

    A regular TypeVar with only an upper bound returns ``type`` as-is.

    If ``type`` satisfies neither the values nor the bound:

    * ``None`` is returned when ``skip_unsatisfied`` is true;
    * otherwise the problem is reported through
      ``report_incompatible_typevar_value`` and ``type`` is returned.
    """
    if isinstance(tvar, (ParamSpecType, TypeVarTupleType)):
        return type
    assert isinstance(tvar, TypeVarType)

    allowed = tvar.values
    proper = get_proper_type(type)

    if not allowed:
        # Bound-only type variable: accept anything below the upper bound.
        if mypy.subtypes.is_subtype(type, tvar.upper_bound):
            return type
        if skip_unsatisfied:
            return None
        report_incompatible_typevar_value(callable, type, tvar.name, context)
        return type

    if isinstance(proper, AnyType):
        return type

    if isinstance(proper, TypeVarType) and proper.values:
        # A value-restricted T1 may stand in for T when every value T1 permits
        # is also one of the values T permits.
        if all(
            any(mypy.subtypes.is_same_type(candidate, own) for candidate in allowed)
            for own in proper.values
        ):
            return type

    candidates = [value for value in allowed if mypy.subtypes.is_subtype(type, value)]
    if candidates:
        # Several allowed values may accept the argument; keep the narrowest one.
        narrowest = candidates[0]
        for other in candidates[1:]:
            if mypy.subtypes.is_subtype(other, narrowest):
                narrowest = other
        return narrowest

    if skip_unsatisfied:
        return None
    report_incompatible_typevar_value(callable, type, tvar.name, context)
    return type
def apply_generic_arguments(
    callable: CallableType,
    orig_types: Sequence[Type | None],
    report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
    context: Context,
    skip_unsatisfied: bool = False,
) -> CallableType:
    """Apply generic type arguments to a callable type.

    For example, applying [int] to 'def [T] (T) -> T' results in
    'def (int) -> int'.

    Note that each type can be None; in this case, it will not be applied.

    If `skip_unsatisfied` is True, then just skip the types that don't satisfy type variable
    bound or constraints, instead of giving an error.

    Args:
        callable: The generic callable. Its ``variables`` are matched positionally
            against ``orig_types``.
        orig_types: One entry per type variable of ``callable``.
            * ``None`` leaves the corresponding variable unapplied.
            * A PartialType is never allowed here.
        report_incompatible_typevar_value: Callback invoked (through
            ``get_target_type``) when an argument violates a type variable's
            values or upper bound and ``skip_unsatisfied`` is False.
        context: Passed through unchanged to the error callback.
        skip_unsatisfied: See above.

    Returns:
        A copy of ``callable`` built via ``copy_modified``:
        * the applied type variables are substituted into its argument types,
          return type and TypeGuard;
        * only the unapplied type variables remain in ``variables``.
    """
    tvars = callable.variables
    assert len(tvars) == len(orig_types)
    # Check that inferred type variable values are compatible with allowed
    # values and bounds. Also, promote subtype values to allowed values.
    # Create a map from type variable id to target type.
    id_to_type: dict[TypeVarId, Type] = {}
    for tvar, type in zip(tvars, orig_types):
        assert not isinstance(type, PartialType), "Internal error: must never apply partial type"
        if type is None:
            # Unapplied: this variable is kept in the result's `variables` below.
            continue
        target_type = get_target_type(
            tvar, type, callable, report_incompatible_typevar_value, context, skip_unsatisfied
        )
        # None means the argument was rejected under skip_unsatisfied. The variable
        # is then treated exactly like an unapplied one.
        if target_type is not None:
            id_to_type[tvar.id] = target_type
    # If the callable has a ParamSpec bound to a concrete parameter list
    # (a CallableType or Parameters), splice that into the signature before the
    # per-argument substitutions below.
    param_spec = callable.param_spec()
    if param_spec is not None:
        nt = id_to_type.get(param_spec.id)
        if nt is not None:
            nt = get_proper_type(nt)
            if isinstance(nt, (CallableType, Parameters)):
                callable = callable.expand_param_spec(nt)
    # Apply arguments to argument types.
    var_arg = callable.var_arg()
    if var_arg is not None and isinstance(var_arg.typ, UnpackType):
        # *args is annotated with an unpack (of a tuple or a TypeVarTuple).
        # Expand every other argument now, leave the star argument untouched,
        # and rewrite it below according to what the unpack resolves to.
        star_index = callable.arg_kinds.index(ARG_STAR)
        callable = callable.copy_modified(
            arg_types=(
                [expand_type(at, id_to_type) for at in callable.arg_types[:star_index]]
                + [callable.arg_types[star_index]]
                + [expand_type(at, id_to_type) for at in callable.arg_types[star_index + 1 :]]
            )
        )
        unpacked_type = get_proper_type(var_arg.typ.type)
        if isinstance(unpacked_type, TupleType):
            # Assuming for now that because we convert prefixes to positional arguments,
            # the first argument is always an unpack.
            expanded_tuple = expand_type(unpacked_type, id_to_type)
            if isinstance(expanded_tuple, TupleType):
                # If the expanded tuple still contains an unpack, it cannot be split
                # into fixed positional arguments, so it is kept whole as the star
                # argument's type. Otherwise its items are handed to replace_starargs
                # (presumably one positional argument per item; see mypy.typevartuples).
                # NOTE(review): an older TODO here said the unpack case "will hit an
                # assert below". The branch just below handles it without asserting,
                # so confirm whether that TODO is obsolete.
                expanded_unpack = find_unpack_in_list(expanded_tuple.items)
                if expanded_unpack is not None:
                    callable = callable.copy_modified(
                        arg_types=(
                            callable.arg_types[:star_index]
                            + [expanded_tuple]
                            + callable.arg_types[star_index + 1 :]
                        )
                    )
                else:
                    callable = replace_starargs(callable, expanded_tuple.items)
            else:
                # TODO: handle the case for if we get a variable length tuple.
                assert False, f"mypy bug: unimplemented case, {expanded_tuple}"
        elif isinstance(unpacked_type, TypeVarTupleType):
            # Only two result shapes are accepted here:
            # * a list of non-unpack types, which replaces *args via replace_starargs;
            # * a builtins.tuple Instance, whose single type argument becomes the
            #   (homogeneous) type of *args.
            expanded_tvt = expand_unpack_with_variables(var_arg.typ, id_to_type)
            if isinstance(expanded_tvt, list):
                for t in expanded_tvt:
                    assert not isinstance(t, UnpackType)
                callable = replace_starargs(callable, expanded_tvt)
            else:
                assert isinstance(expanded_tvt, Instance)
                assert expanded_tvt.type.fullname == "builtins.tuple"
                callable = callable.copy_modified(
                    arg_types=(
                        callable.arg_types[:star_index]
                        + [expanded_tvt.args[0]]
                        + callable.arg_types[star_index + 1 :]
                    )
                )
        else:
            assert False, "mypy bug: unhandled case applying unpack"
    else:
        # No unpack in *args (or no *args at all): expand every argument type directly.
        callable = callable.copy_modified(
            arg_types=[expand_type(at, id_to_type) for at in callable.arg_types]
        )
    # Apply arguments to TypeGuard if any.
    if callable.type_guard is not None:
        type_guard = expand_type(callable.type_guard, id_to_type)
    else:
        type_guard = None
    # The callable may retain some type vars if only some were applied.
    remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type]
    return callable.copy_modified(
        ret_type=expand_type(callable.ret_type, id_to_type),
        variables=remaining_tvars,
        type_guard=type_guard,
    )