| from __future__ import annotations |
| |
| from collections.abc import Iterable, Sequence |
| from typing import Callable |
| |
| import mypy.subtypes |
| from mypy.erasetype import erase_typevars |
| from mypy.expandtype import expand_type |
| from mypy.nodes import Context, TypeInfo |
| from mypy.type_visitor import TypeTranslator |
| from mypy.typeops import get_all_type_vars |
| from mypy.types import ( |
| AnyType, |
| CallableType, |
| Instance, |
| Parameters, |
| ParamSpecFlavor, |
| ParamSpecType, |
| PartialType, |
| ProperType, |
| Type, |
| TypeAliasType, |
| TypeVarId, |
| TypeVarLikeType, |
| TypeVarTupleType, |
| TypeVarType, |
| UninhabitedType, |
| UnpackType, |
| get_proper_type, |
| remove_dups, |
| ) |
| |
| |
def get_target_type(
    tvar: TypeVarLikeType,
    type: Type,
    callable: CallableType,
    report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
    context: Context,
    skip_unsatisfied: bool,
) -> Type | None:
    """Return the type to substitute for `tvar` when applying `type` to `callable`.

    * An uninhabited `type` for a type variable that has a default yields that default.
    * ParamSpec and TypeVarTuple arguments are returned unchanged (not validated here).
    * For a value-restricted TypeVar, `Any` and value-restricted TypeVars whose
      values are all allowed pass through unchanged. Otherwise, the narrowest
      allowed value that `type` is a subtype of is returned.
    * For any other TypeVar, `type` must be a subtype of the upper bound.

    If `type` violates the restriction, return None when `skip_unsatisfied` is
    set. Otherwise, report it via `report_incompatible_typevar_value` and return
    `type` unchanged.
    """
    proper = get_proper_type(type)
    if isinstance(proper, UninhabitedType) and tvar.has_default():
        return tvar.default
    if isinstance(tvar, (ParamSpecType, TypeVarTupleType)):
        return type
    assert isinstance(tvar, TypeVarType)

    allowed_values = tvar.values
    if not allowed_values:
        upper_bound = tvar.upper_bound
        if tvar.name == "Self":
            # Internally constructed Self-types contain class type variables in upper bound,
            # so we need to erase them to avoid false positives. This is safe because we do
            # not support type variables in upper bounds of user defined types.
            upper_bound = erase_typevars(upper_bound)
        satisfied = mypy.subtypes.is_subtype(type, upper_bound)
    else:
        if isinstance(proper, AnyType):
            return type
        if isinstance(proper, TypeVarType) and proper.values:
            # A value-restricted TypeVar may stand in for this one when every one
            # of its allowed values is also allowed here.
            if all(
                any(mypy.subtypes.is_same_type(allowed, theirs) for allowed in allowed_values)
                for theirs in proper.values
            ):
                return type
        candidates = [
            allowed for allowed in allowed_values if mypy.subtypes.is_subtype(type, allowed)
        ]
        if candidates:
            # Several allowed values may match; prefer the narrowest one.
            narrowest = candidates[0]
            for candidate in candidates[1:]:
                if mypy.subtypes.is_subtype(candidate, narrowest):
                    narrowest = candidate
            return narrowest
        satisfied = False

    if satisfied:
        return type
    if skip_unsatisfied:
        return None
    report_incompatible_typevar_value(callable, type, tvar.name, context)
    return type
| |
| |
def apply_generic_arguments(
    callable: CallableType,
    orig_types: Sequence[Type | None],
    report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
    context: Context,
    skip_unsatisfied: bool = False,
) -> CallableType:
    """Apply generic type arguments to a callable type.

    For example, applying [int] to 'def [T] (T) -> T' results in
    'def (int) -> int'.

    `orig_types` is matched positionally against `callable.variables` and may be
    shorter. A None entry leaves the corresponding type variable unapplied.

    If `skip_unsatisfied` is True, arguments that violate a type variable's bound
    or value restriction are left unapplied instead of being reported. Otherwise,
    the violation is reported through `report_incompatible_typevar_value` and the
    argument is still applied.
    """
    tvars = callable.variables
    assert len(orig_types) <= len(tvars)

    # Validate each argument against its type variable's bound / allowed values
    # (promoting to an allowed value when needed) and map variable id -> target type.
    id_to_type: dict[TypeVarId, Type] = {}
    for tvar, arg in zip(tvars, orig_types):
        assert not isinstance(arg, PartialType), "Internal error: must never apply partial type"
        if arg is None:
            continue
        target_type = get_target_type(
            tvar, arg, callable, report_incompatible_typevar_value, context, skip_unsatisfied
        )
        if target_type is not None:
            id_to_type[tvar.id] = target_type

    # TODO: validate arg_kinds/arg_names for ParamSpec and TypeVarTuple replacements,
    # not just type variable bounds above.
    # ParamSpec expansion and callables with a variadic (unpacked) *args are
    # special-cased: the callable must be expanded as a whole, not argument by argument.
    param_spec = callable.param_spec()
    expand_as_whole = param_spec is not None and id_to_type.get(param_spec.id) is not None
    if not expand_as_whole:
        var_arg = callable.var_arg()
        expand_as_whole = var_arg is not None and isinstance(var_arg.typ, UnpackType)
    if expand_as_whole:
        expanded = expand_type(callable, id_to_type)
        assert isinstance(expanded, CallableType)
        return expanded.copy_modified(variables=[tv for tv in tvars if tv.id not in id_to_type])

    # Ordinary case: substitute into each argument type individually.
    callable = callable.copy_modified(
        arg_types=[expand_type(at, id_to_type) for at in callable.arg_types]
    )

    # Substitute into TypeGuard / TypeIs narrowing targets, if any.
    type_guard = (
        expand_type(callable.type_guard, id_to_type) if callable.type_guard is not None else None
    )
    type_is = expand_type(callable.type_is, id_to_type) if callable.type_is is not None else None

    # The callable may retain some type vars if only some were applied.
    # TODO: move apply_poly() logic here when new inference
    # becomes universally used (i.e. in all passes + in unification).
    # With this new logic we can actually *add* some new free variables.
    remaining_tvars: list[TypeVarLikeType] = []
    for tv in tvars:
        if tv.id in id_to_type:
            continue
        kept = tv
        if tv.has_default():
            # The variable itself stays free; only its default may mention
            # variables that were just applied, so expand that.
            expanded_tv = expand_type(tv, id_to_type)
            assert isinstance(expanded_tv, TypeVarLikeType)
            kept = expanded_tv
        remaining_tvars.append(kept)

    return callable.copy_modified(
        ret_type=expand_type(callable.ret_type, id_to_type),
        variables=remaining_tvars,
        type_guard=type_guard,
        type_is=type_is,
    )
| |
| |
def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> CallableType | None:
    """Make free type variables generic in the type if possible.

    Translate `tp`, binding each of the type variables in `poly_tvars` on the
    innermost callable whose arguments mention it. This follows the same rules
    as semantic analysis, for example:
    * Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T
    * Callable[[], Callable[[T], T]] -> def () -> def [T] (T) -> T
    * List[T] -> None (not possible)

    Return None when some variable in `poly_tvars` cannot be bound this way.
    """
    try:
        # A fresh translator per component, so that no binding state is shared.
        new_arg_types = [arg.accept(PolyTranslator(poly_tvars)) for arg in tp.arg_types]
        new_ret_type = tp.ret_type.accept(PolyTranslator(poly_tvars))
    except PolyTranslationError:
        return None
    return tp.copy_modified(arg_types=new_arg_types, ret_type=new_ret_type, variables=[])
| |
| |
class PolyTranslationError(Exception):
    """Raised by PolyTranslator when a free type variable cannot be made generic."""
| |
| |
class PolyTranslator(TypeTranslator):
    """Make free type variables generic in the type if possible.

    See docstring for apply_poly() for details. Any variable from `poly_tvars`
    that is reached while not bound by an enclosing callable (or ParamSpec
    Parameters) raises PolyTranslationError.
    """

    def __init__(
        self,
        poly_tvars: Iterable[TypeVarLikeType],
        bound_tvars: frozenset[TypeVarLikeType] = frozenset(),
        seen_aliases: frozenset[TypeInfo] = frozenset(),
    ) -> None:
        super().__init__()
        # Variables we are trying to make generic.
        self.poly_tvars = set(poly_tvars)
        # This is a simplified version of TypeVarScope used during semantic analysis.
        self.bound_tvars = bound_tvars
        # Callback-protocol types already being expanded; used to stop infinite recursion.
        self.seen_aliases = seen_aliases

    @staticmethod
    def _normalize(tv: TypeVarLikeType) -> TypeVarLikeType:
        """Strip a ParamSpec's flavor and prefix so it compares equal to its bare declaration."""
        if isinstance(tv, ParamSpecType):
            return tv.copy_modified(flavor=ParamSpecFlavor.BARE, prefix=Parameters([], [], []))
        return tv

    def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
        """Return poly variables mentioned in the argument types of `t` that are not yet bound."""
        candidates: list[TypeVarLikeType] = []
        for arg in t.arg_types:
            for tv in get_all_type_vars(arg):
                norm = self._normalize(tv)
                if norm in self.poly_tvars and norm not in self.bound_tvars:
                    candidates.append(norm)
        return remove_dups(candidates)

    def _reject_if_free(self, t: TypeVarLikeType) -> None:
        """Raise PolyTranslationError when `t` is a poly variable with no binder in scope."""
        if t in self.poly_tvars and t not in self.bound_tvars:
            raise PolyTranslationError()

    def visit_callable_type(self, t: CallableType) -> Type:
        # Bind the variables this callable's arguments introduce while translating
        # it, then attach them to the translated callable's own variables.
        newly_bound = self.collect_vars(t)
        newly_bound_set = set(newly_bound)
        self.bound_tvars |= newly_bound_set
        translated = super().visit_callable_type(t)
        self.bound_tvars -= newly_bound_set

        assert isinstance(translated, ProperType) and isinstance(translated, CallableType)
        translated.variables = list(translated.variables) + newly_bound
        return translated

    def visit_type_var(self, t: TypeVarType) -> Type:
        self._reject_if_free(t)
        return super().visit_type_var(t)

    def visit_param_spec(self, t: ParamSpecType) -> Type:
        self._reject_if_free(t)
        return super().visit_param_spec(t)

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
        self._reject_if_free(t)
        return super().visit_type_var_tuple(t)

    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
        if not t.args:
            return t.copy_modified()
        if t.is_recursive:
            # We can't handle polymorphic application for recursive generic aliases
            # without risking an infinite recursion, just give up for now.
            raise PolyTranslationError()
        return get_proper_type(t).accept(self)

    def _translate_param_spec_instance(self, t: Instance) -> Type | None:
        """Bind variables introduced by an instance's ParamSpec argument, if it is Parameters.

        Return None (without translating) when the ParamSpec argument is not a
        Parameters type.
        """
        ps_index = next(
            i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType)
        )
        params = get_proper_type(t.args[ps_index])
        if not isinstance(params, Parameters):
            return None
        newly_bound = self.collect_vars(params)
        newly_bound_set = set(newly_bound)
        self.bound_tvars |= newly_bound_set
        new_args = [a.accept(self) for a in t.args]
        self.bound_tvars -= newly_bound_set

        repl = new_args[ps_index]
        assert isinstance(repl, ProperType) and isinstance(repl, Parameters)
        repl.variables = list(repl.variables) + list(newly_bound)
        return t.copy_modified(args=new_args)

    def visit_instance(self, t: Instance) -> Type:
        if t.type.has_param_spec_type:
            # We need this special-casing to preserve the possibility to store a
            # generic function in an instance type. Things like
            #     forall T . Foo[[x: T], T]
            # are not really expressible in current type system, but this looks like
            # a useful feature, so let's keep it.
            translated = self._translate_param_spec_instance(t)
            if translated is not None:
                return translated
        # There is the same problem with callback protocols as with aliases
        # (callback protocols are essentially more flexible aliases to callables).
        if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
            if t.type in self.seen_aliases:
                raise PolyTranslationError()
            call = mypy.subtypes.find_member("__call__", t, t, is_operator=True)
            assert call is not None
            return call.accept(
                PolyTranslator(self.poly_tvars, self.bound_tvars, self.seen_aliases | {t.type})
            )
        return super().visit_instance(t)