| from __future__ import annotations |
| |
| from typing import NamedTuple |
| |
| from mypy.argmap import map_actuals_to_formals |
| from mypy.fixup import TypeFixer |
| from mypy.nodes import ( |
| ARG_POS, |
| MDEF, |
| SYMBOL_FUNCBASE_TYPES, |
| Argument, |
| Block, |
| CallExpr, |
| ClassDef, |
| Decorator, |
| Expression, |
| FuncDef, |
| JsonDict, |
| NameExpr, |
| Node, |
| OverloadedFuncDef, |
| PassStmt, |
| RefExpr, |
| SymbolTableNode, |
| TypeInfo, |
| Var, |
| ) |
| from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface |
| from mypy.semanal_shared import ( |
| ALLOW_INCOMPATIBLE_OVERRIDE, |
| parse_bool, |
| require_bool_literal_argument, |
| set_callable_name, |
| ) |
| from mypy.typeops import try_getting_str_literals as try_getting_str_literals |
| from mypy.types import ( |
| AnyType, |
| CallableType, |
| Instance, |
| LiteralType, |
| NoneType, |
| Overloaded, |
| Type, |
| TypeOfAny, |
| TypeType, |
| TypeVarType, |
| deserialize_type, |
| get_proper_type, |
| ) |
| from mypy.types_utils import is_overlapping_none |
| from mypy.typevars import fill_typevars |
| from mypy.util import get_unique_redefinition_name |
| |
| |
| def _get_decorator_bool_argument(ctx: ClassDefContext, name: str, default: bool) -> bool: |
| """Return the bool argument for the decorator. |
| |
| This handles both @decorator(...) and @decorator. |
| """ |
| if isinstance(ctx.reason, CallExpr): |
| return _get_bool_argument(ctx, ctx.reason, name, default) |
| else: |
| return default |
| |
| |
| def _get_bool_argument(ctx: ClassDefContext, expr: CallExpr, name: str, default: bool) -> bool: |
| """Return the boolean value for an argument to a call or the |
| default if it's not found. |
| """ |
| attr_value = _get_argument(expr, name) |
| if attr_value: |
| return require_bool_literal_argument(ctx.api, attr_value, name, default) |
| return default |
| |
| |
| def _get_argument(call: CallExpr, name: str) -> Expression | None: |
| """Return the expression for the specific argument.""" |
| # To do this we use the CallableType of the callee to find the FormalArgument, |
| # then walk the actual CallExpr looking for the appropriate argument. |
| # |
| # Note: I'm not hard-coding the index so that in the future we can support other |
| # attrib and class makers. |
| callee_type = _get_callee_type(call) |
| if not callee_type: |
| return None |
| |
| argument = callee_type.argument_by_name(name) |
| if not argument: |
| return None |
| assert argument.name |
| |
| for i, (attr_name, attr_value) in enumerate(zip(call.arg_names, call.args)): |
| if argument.pos is not None and not attr_name and i == argument.pos: |
| return attr_value |
| if attr_name == argument.name: |
| return attr_value |
| |
| return None |
| |
| |
| def find_shallow_matching_overload_item(overload: Overloaded, call: CallExpr) -> CallableType: |
| """Perform limited lookup of a matching overload item. |
| |
| Full overload resolution is only supported during type checking, but plugins |
| sometimes need to resolve overloads. This can be used in some such use cases. |
| |
| Resolve overloads based on these things only: |
| |
| * Match using argument kinds and names |
| * If formal argument has type None, only accept the "None" expression in the callee |
| * If formal argument has type Literal[True] or Literal[False], only accept the |
| relevant bool literal |
| |
| Return the first matching overload item, or the last one if nothing matches. |
| """ |
| for item in overload.items[:-1]: |
| ok = True |
| mapped = map_actuals_to_formals( |
| call.arg_kinds, |
| call.arg_names, |
| item.arg_kinds, |
| item.arg_names, |
| lambda i: AnyType(TypeOfAny.special_form), |
| ) |
| |
| # Look for extra actuals |
| matched_actuals = set() |
| for actuals in mapped: |
| matched_actuals.update(actuals) |
| if any(i not in matched_actuals for i in range(len(call.args))): |
| ok = False |
| |
| for arg_type, kind, actuals in zip(item.arg_types, item.arg_kinds, mapped): |
| if kind.is_required() and not actuals: |
| # Missing required argument |
| ok = False |
| break |
| elif actuals: |
| args = [call.args[i] for i in actuals] |
| arg_type = get_proper_type(arg_type) |
| arg_none = any(isinstance(arg, NameExpr) and arg.name == "None" for arg in args) |
| if isinstance(arg_type, NoneType): |
| if not arg_none: |
| ok = False |
| break |
| elif ( |
| arg_none |
| and not is_overlapping_none(arg_type) |
| and not ( |
| isinstance(arg_type, Instance) |
| and arg_type.type.fullname == "builtins.object" |
| ) |
| and not isinstance(arg_type, AnyType) |
| ): |
| ok = False |
| break |
| elif isinstance(arg_type, LiteralType) and isinstance(arg_type.value, bool): |
| if not any(parse_bool(arg) == arg_type.value for arg in args): |
| ok = False |
| break |
| if ok: |
| return item |
| return overload.items[-1] |
| |
| |
| def _get_callee_type(call: CallExpr) -> CallableType | None: |
| """Return the type of the callee, regardless of its syntactic form.""" |
| |
| callee_node: Node | None = call.callee |
| |
| if isinstance(callee_node, RefExpr): |
| callee_node = callee_node.node |
| |
| # Some decorators may be using typing.dataclass_transform, which is itself a decorator, so we |
| # need to unwrap them to get at the true callee |
| if isinstance(callee_node, Decorator): |
| callee_node = callee_node.func |
| |
| if isinstance(callee_node, (Var, SYMBOL_FUNCBASE_TYPES)) and callee_node.type: |
| callee_node_type = get_proper_type(callee_node.type) |
| if isinstance(callee_node_type, Overloaded): |
| return find_shallow_matching_overload_item(callee_node_type, call) |
| elif isinstance(callee_node_type, CallableType): |
| return callee_node_type |
| |
| return None |
| |
| |
| def add_method( |
| ctx: ClassDefContext, |
| name: str, |
| args: list[Argument], |
| return_type: Type, |
| self_type: Type | None = None, |
| tvar_def: TypeVarType | None = None, |
| is_classmethod: bool = False, |
| is_staticmethod: bool = False, |
| ) -> None: |
| """ |
| Adds a new method to a class. |
| Deprecated, use add_method_to_class() instead. |
| """ |
| add_method_to_class( |
| ctx.api, |
| ctx.cls, |
| name=name, |
| args=args, |
| return_type=return_type, |
| self_type=self_type, |
| tvar_def=tvar_def, |
| is_classmethod=is_classmethod, |
| is_staticmethod=is_staticmethod, |
| ) |
| |
| |
| class MethodSpec(NamedTuple): |
| """Represents a method signature to be added, except for `name`.""" |
| |
| args: list[Argument] |
| return_type: Type |
| self_type: Type | None = None |
| tvar_defs: list[TypeVarType] | None = None |
| |
| |
| def add_method_to_class( |
| api: SemanticAnalyzerPluginInterface | CheckerPluginInterface, |
| cls: ClassDef, |
| name: str, |
| # MethodSpec items kept for backward compatibility: |
| args: list[Argument], |
| return_type: Type, |
| self_type: Type | None = None, |
| tvar_def: list[TypeVarType] | TypeVarType | None = None, |
| is_classmethod: bool = False, |
| is_staticmethod: bool = False, |
| ) -> FuncDef | Decorator: |
| """Adds a new method to a class definition.""" |
| _prepare_class_namespace(cls, name) |
| |
| if tvar_def is not None and not isinstance(tvar_def, list): |
| tvar_def = [tvar_def] |
| |
| func, sym = _add_method_by_spec( |
| api, |
| cls.info, |
| name, |
| MethodSpec(args=args, return_type=return_type, self_type=self_type, tvar_defs=tvar_def), |
| is_classmethod=is_classmethod, |
| is_staticmethod=is_staticmethod, |
| ) |
| cls.info.names[name] = sym |
| cls.info.defn.defs.body.append(func) |
| return func |
| |
| |
| def add_overloaded_method_to_class( |
| api: SemanticAnalyzerPluginInterface | CheckerPluginInterface, |
| cls: ClassDef, |
| name: str, |
| items: list[MethodSpec], |
| is_classmethod: bool = False, |
| is_staticmethod: bool = False, |
| ) -> OverloadedFuncDef: |
| """Adds a new overloaded method to a class definition.""" |
| assert len(items) >= 2, "Overloads must contain at least two cases" |
| |
| # Save old definition, if it exists. |
| _prepare_class_namespace(cls, name) |
| |
| # Create function bodies for each passed method spec. |
| funcs: list[Decorator | FuncDef] = [] |
| for item in items: |
| func, _sym = _add_method_by_spec( |
| api, |
| cls.info, |
| name=name, |
| spec=item, |
| is_classmethod=is_classmethod, |
| is_staticmethod=is_staticmethod, |
| ) |
| if isinstance(func, FuncDef): |
| var = Var(func.name, func.type) |
| var.set_line(func.line) |
| func.is_decorated = True |
| |
| deco = Decorator(func, [], var) |
| else: |
| deco = func |
| deco.is_overload = True |
| funcs.append(deco) |
| |
| # Create the final OverloadedFuncDef node: |
| overload_def = OverloadedFuncDef(funcs) |
| overload_def.info = cls.info |
| overload_def.is_class = is_classmethod |
| overload_def.is_static = is_staticmethod |
| sym = SymbolTableNode(MDEF, overload_def) |
| sym.plugin_generated = True |
| |
| cls.info.names[name] = sym |
| cls.info.defn.defs.body.append(overload_def) |
| return overload_def |
| |
| |
| def _prepare_class_namespace(cls: ClassDef, name: str) -> None: |
| info = cls.info |
| assert info |
| |
| # First remove any previously generated methods with the same name |
| # to avoid clashes and problems in the semantic analyzer. |
| if name in info.names: |
| sym = info.names[name] |
| if sym.plugin_generated and isinstance(sym.node, FuncDef): |
| cls.defs.body.remove(sym.node) |
| |
| # NOTE: we would like the plugin generated node to dominate, but we still |
| # need to keep any existing definitions so they get semantically analyzed. |
| if name in info.names: |
| # Get a nice unique name instead. |
| r_name = get_unique_redefinition_name(name, info.names) |
| info.names[r_name] = info.names[name] |
| |
| |
| def _add_method_by_spec( |
| api: SemanticAnalyzerPluginInterface | CheckerPluginInterface, |
| info: TypeInfo, |
| name: str, |
| spec: MethodSpec, |
| *, |
| is_classmethod: bool, |
| is_staticmethod: bool, |
| ) -> tuple[FuncDef | Decorator, SymbolTableNode]: |
| args, return_type, self_type, tvar_defs = spec |
| |
| assert not ( |
| is_classmethod is True and is_staticmethod is True |
| ), "Can't add a new method that's both staticmethod and classmethod." |
| |
| if isinstance(api, SemanticAnalyzerPluginInterface): |
| function_type = api.named_type("builtins.function") |
| else: |
| function_type = api.named_generic_type("builtins.function", []) |
| |
| if is_classmethod: |
| self_type = self_type or TypeType(fill_typevars(info)) |
| first = [Argument(Var("_cls"), self_type, None, ARG_POS, True)] |
| elif is_staticmethod: |
| first = [] |
| else: |
| self_type = self_type or fill_typevars(info) |
| first = [Argument(Var("self"), self_type, None, ARG_POS)] |
| args = first + args |
| |
| arg_types, arg_names, arg_kinds = [], [], [] |
| for arg in args: |
| assert arg.type_annotation, "All arguments must be fully typed." |
| arg_types.append(arg.type_annotation) |
| arg_names.append(arg.variable.name) |
| arg_kinds.append(arg.kind) |
| |
| signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type) |
| if tvar_defs: |
| signature.variables = tvar_defs |
| |
| func = FuncDef(name, args, Block([PassStmt()])) |
| func.info = info |
| func.type = set_callable_name(signature, func) |
| func.is_class = is_classmethod |
| func.is_static = is_staticmethod |
| func._fullname = info.fullname + "." + name |
| func.line = info.line |
| |
| # Add decorator for is_staticmethod. It's unnecessary for is_classmethod. |
| if is_staticmethod: |
| func.is_decorated = True |
| v = Var(name, func.type) |
| v.info = info |
| v._fullname = func._fullname |
| v.is_staticmethod = True |
| dec = Decorator(func, [], v) |
| dec.line = info.line |
| sym = SymbolTableNode(MDEF, dec) |
| sym.plugin_generated = True |
| return dec, sym |
| |
| sym = SymbolTableNode(MDEF, func) |
| sym.plugin_generated = True |
| return func, sym |
| |
| |
| def add_attribute_to_class( |
| api: SemanticAnalyzerPluginInterface, |
| cls: ClassDef, |
| name: str, |
| typ: Type, |
| final: bool = False, |
| no_serialize: bool = False, |
| override_allow_incompatible: bool = False, |
| fullname: str | None = None, |
| is_classvar: bool = False, |
| overwrite_existing: bool = False, |
| ) -> Var: |
| """ |
| Adds a new attribute to a class definition. |
| This currently only generates the symbol table entry and no corresponding AssignmentStatement |
| """ |
| info = cls.info |
| |
| # NOTE: we would like the plugin generated node to dominate, but we still |
| # need to keep any existing definitions so they get semantically analyzed. |
| if name in info.names and not overwrite_existing: |
| # Get a nice unique name instead. |
| r_name = get_unique_redefinition_name(name, info.names) |
| info.names[r_name] = info.names[name] |
| |
| node = Var(name, typ) |
| node.info = info |
| node.is_final = final |
| node.is_classvar = is_classvar |
| if name in ALLOW_INCOMPATIBLE_OVERRIDE: |
| node.allow_incompatible_override = True |
| else: |
| node.allow_incompatible_override = override_allow_incompatible |
| |
| if fullname: |
| node._fullname = fullname |
| else: |
| node._fullname = info.fullname + "." + name |
| |
| info.names[name] = SymbolTableNode( |
| MDEF, node, plugin_generated=True, no_serialize=no_serialize |
| ) |
| return node |
| |
| |
| def deserialize_and_fixup_type(data: str | JsonDict, api: SemanticAnalyzerPluginInterface) -> Type: |
| typ = deserialize_type(data) |
| typ.accept(TypeFixer(api.modules, allow_missing=False)) |
| return typ |