| from __future__ import annotations |
| |
| from functools import partial |
| from typing import Callable, Final |
| |
| import mypy.errorcodes as codes |
| from mypy import message_registry |
| from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr |
| from mypy.plugin import ( |
| AttributeContext, |
| ClassDefContext, |
| FunctionContext, |
| FunctionSigContext, |
| MethodContext, |
| MethodSigContext, |
| Plugin, |
| ) |
| from mypy.plugins.common import try_getting_str_literals |
| from mypy.subtypes import is_subtype |
| from mypy.typeops import is_literal_type_like, make_simplified_union |
| from mypy.types import ( |
| TPDICT_FB_NAMES, |
| AnyType, |
| CallableType, |
| FunctionLike, |
| Instance, |
| LiteralType, |
| NoneType, |
| TupleType, |
| Type, |
| TypedDictType, |
| TypeOfAny, |
| TypeVarType, |
| UnionType, |
| get_proper_type, |
| get_proper_types, |
| ) |
| |
| |
class DefaultPlugin(Plugin):
    """Type checker plugin that is enabled by default.

    Every ``get_*_hook`` method maps the fully qualified name it is given to
    the callback registered for that name, or returns None when this plugin
    has no special handling for it.  The modules that implement the callbacks
    are imported inside the methods rather than at module level.
    """

    def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
        """Return the callback for calls to the function ``fullname``, if any."""
        from mypy.plugins import ctypes, enums, singledispatch

        if fullname == "_ctypes.Array":
            return ctypes.array_constructor_callback
        elif fullname == "functools.singledispatch":
            return singledispatch.create_singledispatch_function_callback
        elif fullname == "functools.partial":
            import mypy.plugins.functools

            return mypy.plugins.functools.partial_new_callback
        elif fullname == "enum.member":
            return enums.enum_member_callback

        return None

    def get_function_signature_hook(
        self, fullname: str
    ) -> Callable[[FunctionSigContext], FunctionLike] | None:
        """Return the callback that adjusts the signature of function ``fullname``, if any."""
        from mypy.plugins import attrs, dataclasses

        if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
            return attrs.evolve_function_sig_callback
        elif fullname in ("attr.fields", "attrs.fields"):
            return attrs.fields_function_sig_callback
        elif fullname == "dataclasses.replace":
            return dataclasses.replace_function_sig_callback
        return None

    def get_method_signature_hook(
        self, fullname: str
    ) -> Callable[[MethodSigContext], FunctionLike] | None:
        """Return the callback that adjusts the signature of method ``fullname``, if any."""
        from mypy.plugins import ctypes, singledispatch

        # TypedDict methods are matched as "<name>.<method>" for each name in
        # TPDICT_FB_NAMES.  Split the name once instead of rebuilding sets of
        # concatenated strings on every lookup; method names contain no dots,
        # so this accepts exactly the same full names.
        owner, _, method = fullname.rpartition(".")
        on_typed_dict = owner in TPDICT_FB_NAMES

        if fullname == "typing.Mapping.get":
            return typed_dict_get_signature_callback
        elif on_typed_dict and method == "setdefault":
            return typed_dict_setdefault_signature_callback
        elif on_typed_dict and method == "pop":
            return typed_dict_pop_signature_callback
        elif fullname == "_ctypes.Array.__setitem__":
            return ctypes.array_setitem_callback
        elif fullname == singledispatch.SINGLEDISPATCH_CALLABLE_CALL_METHOD:
            return singledispatch.call_singledispatch_function_callback
        elif on_typed_dict and method in ("update", "__or__", "__ror__", "__ior__"):
            return typed_dict_update_signature_callback

        return None

    def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
        """Return the callback that infers the result of calling method ``fullname``, if any."""
        from mypy.plugins import ctypes, singledispatch

        # See get_method_signature_hook: one split replaces per-call set building.
        owner, _, method = fullname.rpartition(".")
        on_typed_dict = owner in TPDICT_FB_NAMES

        if fullname == "typing.Mapping.get":
            return typed_dict_get_callback
        elif fullname == "builtins.int.__pow__":
            return int_pow_callback
        elif fullname == "builtins.int.__neg__":
            return int_neg_callback
        elif fullname == "builtins.int.__pos__":
            return int_pos_callback
        elif fullname in ("builtins.tuple.__mul__", "builtins.tuple.__rmul__"):
            return tuple_mul_callback
        elif on_typed_dict and method == "setdefault":
            return typed_dict_setdefault_callback
        elif on_typed_dict and method == "pop":
            return typed_dict_pop_callback
        elif on_typed_dict and method == "__delitem__":
            return typed_dict_delitem_callback
        elif fullname == "_ctypes.Array.__getitem__":
            return ctypes.array_getitem_callback
        elif fullname == "_ctypes.Array.__iter__":
            return ctypes.array_iter_callback
        elif fullname == singledispatch.SINGLEDISPATCH_REGISTER_METHOD:
            return singledispatch.singledispatch_register_callback
        elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
            return singledispatch.call_singledispatch_function_after_register_argument
        elif fullname == "functools.partial.__call__":
            import mypy.plugins.functools

            return mypy.plugins.functools.partial_call_callback
        return None

    def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
        """Return the callback that infers the type of attribute ``fullname``, if any."""
        from mypy.plugins import ctypes, enums

        if fullname == "_ctypes.Array.value":
            return ctypes.array_value_callback
        elif fullname == "_ctypes.Array.raw":
            return ctypes.array_raw_callback
        elif fullname in enums.ENUM_NAME_ACCESS:
            return enums.enum_name_callback
        elif fullname in enums.ENUM_VALUE_ACCESS:
            return enums.enum_value_callback
        return None

    def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
        """Return the tagging callback for the class decorator ``fullname``, if any."""
        from mypy.plugins import attrs, dataclasses

        # These dataclass and attrs hooks run in the main semantic analysis pass
        # and only tag known dataclasses/attrs classes, so that the second
        # hooks (in get_class_decorator_hook_2) can detect dataclasses/attrs classes
        # in the MRO.
        if fullname in dataclasses.dataclass_makers:
            return dataclasses.dataclass_tag_callback
        if (
            fullname in attrs.attr_class_makers
            or fullname in attrs.attr_dataclass_makers
            or fullname in attrs.attr_frozen_makers
            or fullname in attrs.attr_define_makers
        ):
            return attrs.attr_tag_callback

        return None

    def get_class_decorator_hook_2(
        self, fullname: str
    ) -> Callable[[ClassDefContext], bool] | None:
        """Return the second-stage callback for the class decorator ``fullname``, if any.

        The attrs variants all share ``attrs.attr_class_maker_callback`` and
        differ only in the defaults bound with ``partial``.
        """
        import mypy.plugins.functools
        from mypy.plugins import attrs, dataclasses

        if fullname in dataclasses.dataclass_makers:
            return dataclasses.dataclass_class_maker_callback
        elif fullname in mypy.plugins.functools.functools_total_ordering_makers:
            return mypy.plugins.functools.functools_total_ordering_maker_callback
        elif fullname in attrs.attr_class_makers:
            return attrs.attr_class_maker_callback
        elif fullname in attrs.attr_dataclass_makers:
            return partial(attrs.attr_class_maker_callback, auto_attribs_default=True)
        elif fullname in attrs.attr_frozen_makers:
            return partial(
                attrs.attr_class_maker_callback, auto_attribs_default=None, frozen_default=True
            )
        elif fullname in attrs.attr_define_makers:
            return partial(
                attrs.attr_class_maker_callback, auto_attribs_default=None, slots_default=True
            )

        return None
| |
| |
def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Refine the signature of a two-argument TypedDict.get call.

    When the key is a string literal naming a known item, the type of the
    second parameter becomes a union of that item's type and the signature's
    type variable, which gives the default argument a useful type context.
    In every other case the default signature is returned untouched.
    """
    signature = ctx.default_signature
    if not (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(signature.arg_types) == 2
        and len(signature.variables) == 1
        and len(ctx.args[1]) == 1
    ):
        return signature

    item_type = get_proper_type(ctx.type.items.get(ctx.args[0][0].value))
    if not item_type:
        # The literal key is not an item of this TypedDict.
        return signature

    default_expr = ctx.args[1][0]
    default_is_empty_dict = isinstance(default_expr, DictExpr) and len(default_expr.items) == 0
    if isinstance(item_type, TypedDictType) and default_is_empty_dict:
        # An empty dict {} was passed as the default for a typed dict item.
        item_type = item_type.copy_modified(required_keys=set())

    # The union with the type variable accepts everything, so the item type
    # only serves as inference context for the default argument.
    type_var = signature.variables[0]
    assert isinstance(type_var, TypeVarType)
    default_param_type = make_simplified_union([item_type, type_var])
    return signature.copy_modified(
        arg_types=[signature.arg_types[0], default_param_type], ret_type=signature.ret_type
    )
| |
| |
def typed_dict_get_callback(ctx: MethodContext) -> Type:
    """Compute the result type of TypedDict.get when the key is a string literal.

    The result is the simplified union of the item types for every possible
    key, plus None when no default position exists, or plus the default's
    type when exactly one default was passed.  Non-literal or unknown keys
    fall back to the default return type.
    """
    if not (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) >= 1
        and len(ctx.arg_types[0]) == 1
    ):
        return ctx.default_return_type

    keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
    if keys is None:
        return ctx.default_return_type

    key_only = len(ctx.arg_types) == 1
    result_items: list[Type] = []
    for key in keys:
        item_type = get_proper_type(ctx.type.items.get(key))
        if item_type is None:
            return ctx.default_return_type

        if key_only:
            result_items.append(item_type)
            continue

        if not (
            len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1
        ):
            continue

        default_expr = ctx.args[1][0]
        if (
            isinstance(default_expr, DictExpr)
            and len(default_expr.items) == 0
            and isinstance(item_type, TypedDictType)
        ):
            # '{}' as the default for a typed dict item: drop required keys.
            result_items.append(item_type.copy_modified(required_keys=set()))
        else:
            result_items.append(item_type)
        result_items.append(ctx.arg_types[1][0])

    if key_only:
        result_items.append(NoneType())

    return make_simplified_union(result_items)
| |
| |
def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Refine the signature of TypedDict.pop.

    The first parameter type is always replaced by ``builtins.str``.  For a
    two-argument call whose key is a string literal naming a known item, the
    second parameter and the return type both become the union of that item's
    type and the signature's type variable, so the default argument gets a
    useful type context.
    """
    signature = ctx.default_signature
    str_type = ctx.api.named_generic_type("builtins.str", [])

    item_type = None
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(signature.arg_types) == 2
        and len(signature.variables) == 1
        and len(ctx.args[1]) == 1
    ):
        item_type = ctx.type.items.get(ctx.args[0][0].value)

    if not item_type:
        return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])

    # The union with the type variable accepts everything, so the item type
    # only serves as inference context.
    type_var = signature.variables[0]
    assert isinstance(type_var, TypeVarType)
    widened = make_simplified_union([item_type, type_var])
    return signature.copy_modified(arg_types=[str_type, widened], ret_type=widened)
| |
| |
def typed_dict_pop_callback(ctx: MethodContext) -> Type:
    """Type check and infer a precise return type for TypedDict.pop.

    The key must be a string literal (or a union of them); otherwise an error
    is reported and Any is returned.  Popping removes the key, so required
    keys and read-only keys are both reported as non-deletable, matching the
    check done for ``__delitem__``.  An unknown key is reported and yields Any.

    Returns the simplified union of the popped item types, joined with the
    default argument's type when a default was passed.
    """
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) >= 1
        and len(ctx.arg_types[0]) == 1
    ):
        key_expr = ctx.args[0][0]
        keys = try_getting_str_literals(key_expr, ctx.arg_types[0][0])
        if keys is None:
            ctx.api.fail(
                message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
                key_expr,
                code=codes.LITERAL_REQ,
            )
            return AnyType(TypeOfAny.from_error)

        value_types = []
        for key in keys:
            # pop() deletes the item, so it must obey the same restrictions
            # as `del td[key]`: neither required nor read-only keys may go.
            if key in ctx.type.required_keys or key in ctx.type.readonly_keys:
                ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, key_expr)

            value_type = ctx.type.items.get(key)
            if value_type:
                value_types.append(value_type)
            else:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, key_expr)
                return AnyType(TypeOfAny.from_error)

        # The guard above only guarantees the key position exists, so do not
        # index the default position unconditionally.
        if len(ctx.args) < 2 or len(ctx.args[1]) == 0:
            return make_simplified_union(value_types)
        elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
            return make_simplified_union([*value_types, ctx.arg_types[1][0]])
    return ctx.default_return_type
| |
| |
def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Refine the signature of TypedDict.setdefault.

    The first parameter type is always replaced by ``builtins.str``.  For a
    two-argument call whose key is a string literal naming a known item, the
    second parameter type becomes that item's type, which gives the default
    argument a useful type context.
    """
    signature = ctx.default_signature
    str_type = ctx.api.named_generic_type("builtins.str", [])

    item_type = None
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(signature.arg_types) == 2
        and len(ctx.args[1]) == 1
    ):
        item_type = ctx.type.items.get(ctx.args[0][0].value)

    default_param_type = item_type if item_type else signature.arg_types[1]
    return signature.copy_modified(arg_types=[str_type, default_param_type])
| |
| |
def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
    """Check a TypedDict.setdefault call and compute its precise result type.

    Reports an error when the key is not a string literal, when a read-only
    key is targeted, when a key is unknown, or when the default is not a
    subtype of an item type it could be stored into.  On success the result
    is the simplified union of the item types for all possible keys.
    """
    if not (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) == 2
        and len(ctx.arg_types[0]) == 1
        and len(ctx.arg_types[1]) == 1
    ):
        return ctx.default_return_type

    key_expr = ctx.args[0][0]
    keys = try_getting_str_literals(key_expr, ctx.arg_types[0][0])
    if keys is None:
        ctx.api.fail(
            message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
            key_expr,
            code=codes.LITERAL_REQ,
        )
        return AnyType(TypeOfAny.from_error)

    readonly_targets = ctx.type.readonly_keys & set(keys)
    if readonly_targets:
        ctx.api.msg.readonly_keys_mutated(readonly_targets, context=key_expr)

    default_type = ctx.arg_types[1][0]
    default_expr = ctx.args[1][0]

    item_types = []
    for key in keys:
        item_type = ctx.type.items.get(key)
        if item_type is None:
            ctx.api.msg.typeddict_key_not_found(ctx.type, key, key_expr)
            return AnyType(TypeOfAny.from_error)

        # The signature callback cannot always produce the right signature
        # (e.g. when the key expression is a variable of Literal str type),
        # so verify here that the default fits every item that may be set.
        if not is_subtype(default_type, item_type):
            ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
                default_type, item_type, default_expr
            )
            return AnyType(TypeOfAny.from_error)

        item_types.append(item_type)

    return make_simplified_union(item_types)
| |
| |
def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
    """Check a TypedDict.__delitem__ call.

    The key must be a string literal; required and read-only keys are reported
    as non-deletable and unknown keys as missing.  The default return type is
    returned unless the key is not a literal, in which case Any is returned.
    """
    if not (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) == 1
        and len(ctx.arg_types[0]) == 1
    ):
        return ctx.default_return_type

    key_expr = ctx.args[0][0]
    keys = try_getting_str_literals(key_expr, ctx.arg_types[0][0])
    if keys is None:
        ctx.api.fail(
            message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
            key_expr,
            code=codes.LITERAL_REQ,
        )
        return AnyType(TypeOfAny.from_error)

    typed_dict = ctx.type
    for key in keys:
        protected = key in typed_dict.required_keys or key in typed_dict.readonly_keys
        if protected:
            ctx.api.msg.typeddict_key_cannot_be_deleted(typed_dict, key, key_expr)
        elif key not in typed_dict.items:
            ctx.api.msg.typeddict_key_not_found(typed_dict, key, key_expr)
    return ctx.default_return_type
| |
| |
# Values of ``signature.name`` for which typed_dict_update_signature_callback
# flags the argument type as about to be mutated in place.
_TP_DICT_MUTATING_METHODS: Final = frozenset({"update of TypedDict", "__ior__ of TypedDict"})


def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for methods that update `TypedDict`.

    This includes: `TypedDict.update`, `TypedDict.__or__`, `TypedDict.__ror__`,
    and `TypedDict.__ior__`.

    The default signature is returned unchanged unless the receiver is a
    TypedDict and the signature has exactly one argument whose type is itself
    a TypedDict.  In that case the argument type is replaced by a copy with no
    required keys; if an argument expression was passed, its inferred type is
    used to refine which keys are required.
    """
    signature = ctx.default_signature
    if isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1:
        arg_type = get_proper_type(signature.arg_types[0])
        if not isinstance(arg_type, TypedDictType):
            return signature
        arg_type = arg_type.as_anonymous()
        # Start from an argument type in which no key is required.
        arg_type = arg_type.copy_modified(required_keys=set())
        if ctx.args and ctx.args[0]:
            if signature.name in _TP_DICT_MUTATING_METHODS:
                # If we want to mutate this object in place, we need to set this flag,
                # it will trigger an extra check in TypedDict's checker.
                arg_type.to_be_mutated = True
            # NOTE(review): presumably this suppresses every error other than
            # TYPEDDICT_READONLY_MUTATED while the argument is inferred -- confirm
            # against the filter_errors contract.
            with ctx.api.msg.filter_errors(
                filter_errors=lambda name, info: info.code != codes.TYPEDDICT_READONLY_MUTATED,
                save_filtered_errors=True,
            ):
                # Infer the actual argument using the relaxed type as context.
                inferred = get_proper_type(
                    ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type)
                )
            if arg_type.to_be_mutated:
                arg_type.to_be_mutated = False  # Done!
            # Collect the TypedDict candidates: the inferred type itself, or
            # the TypedDict members of an inferred union.
            possible_tds = []
            if isinstance(inferred, TypedDictType):
                possible_tds = [inferred]
            elif isinstance(inferred, UnionType):
                possible_tds = [
                    t
                    for t in get_proper_types(inferred.relevant_items())
                    if isinstance(t, TypedDictType)
                ]
            items = []
            for td in possible_tds:
                # A key is required if the candidate requires it, but only
                # keys that exist in the declared argument type are kept.
                item = arg_type.copy_modified(
                    required_keys=(arg_type.required_keys | td.required_keys)
                    & arg_type.items.keys()
                )
                if not ctx.api.options.extra_checks:
                    # NOTE(review): presumably limits the item to the names the
                    # candidate actually has -- confirm copy_modified(item_names=...).
                    item = item.copy_modified(item_names=list(td.items))
                items.append(item)
            if items:
                arg_type = make_simplified_union(items)
        return signature.copy_modified(arg_types=[arg_type])
    return signature
| |
| |
def int_pow_callback(ctx: MethodContext) -> Type:
    """Pick int or float as the result of int.__pow__ from a literal exponent.

    A non-negative integer literal exponent gives ``builtins.int`` and a
    negated integer literal giving a negative value gives ``builtins.float``.
    Anything else keeps the default return type.
    """
    positions = ctx.arg_types
    # int.__pow__ has an optional modulo argument, so there are two argument
    # positions; only the form without a modulo is refined.
    if not (len(positions) == 2 and len(positions[0]) == 1 and len(positions[1]) == 0):
        return ctx.default_return_type

    operand = ctx.args[0][0]
    if isinstance(operand, IntExpr):
        non_negative = operand.value >= 0
    elif isinstance(operand, UnaryExpr) and operand.op == "-" and isinstance(operand.expr, IntExpr):
        non_negative = -operand.expr.value >= 0
    else:
        # Right operand is neither an int literal nor a negated one -- give up.
        return ctx.default_return_type

    result_name = "builtins.int" if non_negative else "builtins.float"
    return ctx.api.named_generic_type(result_name, [])
| |
| |
def int_neg_callback(ctx: MethodContext, multiplier: int = -1) -> Type:
    """Infer a literal-aware result for int.__neg__ (and int.__pos__).

    When the receiver carries a known integer value, either as an Instance
    with ``last_known_value`` or as a LiteralType, the result carries that
    value multiplied by ``multiplier``.  Otherwise the default return type
    is kept.
    """
    receiver = ctx.type
    if isinstance(receiver, Instance) and receiver.last_known_value is not None:
        known = receiver.last_known_value
        if isinstance(known.value, int):
            if is_literal_type_like(ctx.api.type_context[-1]):
                # A literal is expected here, so produce a LiteralType directly.
                return LiteralType(value=multiplier * known.value, fallback=known.fallback)
            updated_known = LiteralType(
                value=multiplier * known.value,
                fallback=known.fallback,
                line=receiver.line,
                column=receiver.column,
            )
            return receiver.copy_modified(last_known_value=updated_known)
    elif isinstance(receiver, LiteralType):
        if isinstance(receiver.value, int):
            return LiteralType(value=multiplier * receiver.value, fallback=receiver.fallback)
    return ctx.default_return_type
| |
| |
def int_pos_callback(ctx: MethodContext) -> Type:
    """Infer a literal-aware result for int.__pos__.

    Delegates to int_neg_callback with a multiplier of +1, so the known value
    is carried over without being negated.
    """
    return int_neg_callback(ctx, multiplier=+1)
| |
| |
def tuple_mul_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.

    When the receiver is a TupleType and the operand has a known integer
    value (an Instance with ``last_known_value`` or a LiteralType), the
    result is the receiver with its items repeated that many times.  In all
    other cases, including a call with no operand, the default return type
    is kept.
    """
    if not isinstance(ctx.type, TupleType):
        return ctx.default_return_type

    # Do not assume an operand was actually passed: indexing an empty
    # argument position would raise instead of falling back.
    if not ctx.arg_types or not ctx.arg_types[0]:
        return ctx.default_return_type

    arg_type = get_proper_type(ctx.arg_types[0][0])
    if isinstance(arg_type, Instance) and arg_type.last_known_value is not None:
        literal = arg_type.last_known_value
    elif isinstance(arg_type, LiteralType):
        literal = arg_type
    else:
        return ctx.default_return_type

    if isinstance(literal.value, int):
        return ctx.type.copy_modified(items=ctx.type.items * literal.value)
    return ctx.default_return_type