| """Plugin that provides support for dataclasses.""" |
| |
| from __future__ import annotations |
| |
| from collections.abc import Iterator |
| from typing import TYPE_CHECKING, Final, Literal |
| |
| from mypy import errorcodes, message_registry |
| from mypy.expandtype import expand_type, expand_type_by_instance |
| from mypy.meet import meet_types |
| from mypy.messages import format_type_bare |
| from mypy.nodes import ( |
| ARG_NAMED, |
| ARG_NAMED_OPT, |
| ARG_OPT, |
| ARG_POS, |
| ARG_STAR, |
| ARG_STAR2, |
| MDEF, |
| Argument, |
| AssignmentStmt, |
| Block, |
| CallExpr, |
| ClassDef, |
| Context, |
| DataclassTransformSpec, |
| Decorator, |
| EllipsisExpr, |
| Expression, |
| FuncDef, |
| FuncItem, |
| IfStmt, |
| JsonDict, |
| NameExpr, |
| Node, |
| PlaceholderNode, |
| RefExpr, |
| Statement, |
| SymbolTableNode, |
| TempNode, |
| TypeAlias, |
| TypeInfo, |
| TypeVarExpr, |
| Var, |
| ) |
| from mypy.plugin import ClassDefContext, FunctionSigContext, SemanticAnalyzerPluginInterface |
| from mypy.plugins.common import ( |
| _get_callee_type, |
| _get_decorator_bool_argument, |
| add_attribute_to_class, |
| add_method_to_class, |
| deserialize_and_fixup_type, |
| ) |
| from mypy.semanal_shared import find_dataclass_transform_spec, require_bool_literal_argument |
| from mypy.server.trigger import make_wildcard_trigger |
| from mypy.state import state |
| from mypy.typeops import map_type_from_supertype, try_getting_literals_from_type |
| from mypy.types import ( |
| AnyType, |
| CallableType, |
| FunctionLike, |
| Instance, |
| LiteralType, |
| NoneType, |
| ProperType, |
| TupleType, |
| Type, |
| TypeOfAny, |
| TypeVarId, |
| TypeVarType, |
| UninhabitedType, |
| UnionType, |
| get_proper_type, |
| ) |
| from mypy.typevars import fill_typevars |
| |
| if TYPE_CHECKING: |
| from mypy.checker import TypeChecker |
| |
| # The set of decorators that generate dataclasses. |
| dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"} |
| # Default field specifiers for dataclasses |
| DATACLASS_FIELD_SPECIFIERS: Final = ("dataclasses.Field", "dataclasses.field") |
| |
| |
| SELF_TVAR_NAME: Final = "_DT" |
| _TRANSFORM_SPEC_FOR_DATACLASSES: Final = DataclassTransformSpec( |
| eq_default=True, |
| order_default=False, |
| kw_only_default=False, |
| frozen_default=False, |
| field_specifiers=DATACLASS_FIELD_SPECIFIERS, |
| ) |
| _INTERNAL_REPLACE_SYM_NAME: Final = "__mypy-replace" |
| _INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-post_init" |
| |
| |
| class DataclassAttribute: |
| def __init__( |
| self, |
| name: str, |
| alias: str | None, |
| is_in_init: bool, |
| is_init_var: bool, |
| has_default: bool, |
| line: int, |
| column: int, |
| type: Type | None, |
| info: TypeInfo, |
| kw_only: bool, |
| is_neither_frozen_nor_nonfrozen: bool, |
| api: SemanticAnalyzerPluginInterface, |
| ) -> None: |
| self.name = name |
| self.alias = alias |
| self.is_in_init = is_in_init |
| self.is_init_var = is_init_var |
| self.has_default = has_default |
| self.line = line |
| self.column = column |
| self.type = type # Type as __init__ argument |
| self.info = info |
| self.kw_only = kw_only |
| self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen |
| self._api = api |
| |
| def to_argument( |
| self, current_info: TypeInfo, *, of: Literal["__init__", "replace", "__post_init__"] |
| ) -> Argument: |
| if of == "__init__": |
| arg_kind = ARG_POS |
| if self.kw_only and self.has_default: |
| arg_kind = ARG_NAMED_OPT |
| elif self.kw_only and not self.has_default: |
| arg_kind = ARG_NAMED |
| elif not self.kw_only and self.has_default: |
| arg_kind = ARG_OPT |
| elif of == "replace": |
| arg_kind = ARG_NAMED if self.is_init_var and not self.has_default else ARG_NAMED_OPT |
| elif of == "__post_init__": |
| # We always use `ARG_POS` without a default value, because it is practical. |
| # Consider this case: |
| # |
| # @dataclass |
| # class My: |
| # y: dataclasses.InitVar[str] = 'a' |
| # def __post_init__(self, y: str) -> None: ... |
| # |
| # We would be *required* to specify `y: str = ...` if default is added here. |
| # But, most people won't care about adding default values to `__post_init__`, |
| # because it is not designed to be called directly, and duplicating default values |
| # for the sake of type-checking is unpleasant. |
| arg_kind = ARG_POS |
| return Argument( |
| variable=self.to_var(current_info), |
| type_annotation=self.expand_type(current_info), |
| initializer=EllipsisExpr() if self.has_default else None, # Only used by stubgen |
| kind=arg_kind, |
| ) |
| |
| def expand_type(self, current_info: TypeInfo) -> Type | None: |
| if self.type is not None and self.info.self_type is not None: |
| # In general, it is not safe to call `expand_type()` during semantic analysis, |
| # however this plugin is called very late, so all types should be fully ready. |
| # Also, it is tricky to avoid eager expansion of Self types here (e.g. because |
| # we serialize attributes). |
| with state.strict_optional_set(self._api.options.strict_optional): |
| return expand_type( |
| self.type, {self.info.self_type.id: fill_typevars(current_info)} |
| ) |
| return self.type |
| |
| def to_var(self, current_info: TypeInfo) -> Var: |
| return Var(self.alias or self.name, self.expand_type(current_info)) |
| |
| def serialize(self) -> JsonDict: |
| assert self.type |
| return { |
| "name": self.name, |
| "alias": self.alias, |
| "is_in_init": self.is_in_init, |
| "is_init_var": self.is_init_var, |
| "has_default": self.has_default, |
| "line": self.line, |
| "column": self.column, |
| "type": self.type.serialize(), |
| "kw_only": self.kw_only, |
| "is_neither_frozen_nor_nonfrozen": self.is_neither_frozen_nor_nonfrozen, |
| } |
| |
| @classmethod |
| def deserialize( |
| cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface |
| ) -> DataclassAttribute: |
| data = data.copy() |
| typ = deserialize_and_fixup_type(data.pop("type"), api) |
| return cls(type=typ, info=info, **data, api=api) |
| |
| def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: |
| """Expands type vars in the context of a subtype when an attribute is inherited |
| from a generic super type.""" |
| if self.type is not None: |
| with state.strict_optional_set(self._api.options.strict_optional): |
| self.type = map_type_from_supertype(self.type, sub_type, self.info) |
| |
| |
| class DataclassTransformer: |
| """Implement the behavior of @dataclass. |
| |
| Note that this may be executed multiple times on the same class, so |
| everything here must be idempotent. |
| |
| This runs after the main semantic analysis pass, so you can assume that |
| there are no placeholders. |
| """ |
| |
| def __init__( |
| self, |
| cls: ClassDef, |
| # Statement must also be accepted since class definition itself may be passed as the reason |
| # for subclass/metaclass-based uses of `typing.dataclass_transform` |
| reason: Expression | Statement, |
| spec: DataclassTransformSpec, |
| api: SemanticAnalyzerPluginInterface, |
| ) -> None: |
| self._cls = cls |
| self._reason = reason |
| self._spec = spec |
| self._api = api |
| |
| def transform(self) -> bool: |
| """Apply all the necessary transformations to the underlying |
| dataclass so as to ensure it is fully type checked according |
| to the rules in PEP 557. |
| """ |
| info = self._cls.info |
| attributes = self.collect_attributes() |
| if attributes is None: |
| # Some definitions are not ready. We need another pass. |
| return False |
| for attr in attributes: |
| if attr.type is None: |
| return False |
| decorator_arguments = { |
| "init": self._get_bool_arg("init", True), |
| "eq": self._get_bool_arg("eq", self._spec.eq_default), |
| "order": self._get_bool_arg("order", self._spec.order_default), |
| "frozen": self._get_bool_arg("frozen", self._spec.frozen_default), |
| "slots": self._get_bool_arg("slots", False), |
| "match_args": self._get_bool_arg("match_args", True), |
| } |
| py_version = self._api.options.python_version |
| |
| # If there are no attributes, it may be that the semantic analyzer has not |
| # processed them yet. In order to work around this, we can simply skip generating |
| # __init__ if there are no attributes, because if the user truly did not define any, |
| # then the object default __init__ with an empty signature will be present anyway. |
| if ( |
| decorator_arguments["init"] |
| and ("__init__" not in info.names or info.names["__init__"].plugin_generated) |
| and attributes |
| ): |
| args = [ |
| attr.to_argument(info, of="__init__") |
| for attr in attributes |
| if attr.is_in_init and not self._is_kw_only_type(attr.type) |
| ] |
| |
| if info.fallback_to_any: |
| # Make positional args optional since we don't know their order. |
| # This will at least allow us to typecheck them if they are called |
| # as kwargs |
| for arg in args: |
| if arg.kind == ARG_POS: |
| arg.kind = ARG_OPT |
| |
| existing_args_names = {arg.variable.name for arg in args} |
| gen_args_name = "generated_args" |
| while gen_args_name in existing_args_names: |
| gen_args_name += "_" |
| gen_kwargs_name = "generated_kwargs" |
| while gen_kwargs_name in existing_args_names: |
| gen_kwargs_name += "_" |
| args = [ |
| Argument(Var(gen_args_name), AnyType(TypeOfAny.explicit), None, ARG_STAR), |
| *args, |
| Argument(Var(gen_kwargs_name), AnyType(TypeOfAny.explicit), None, ARG_STAR2), |
| ] |
| |
| add_method_to_class( |
| self._api, self._cls, "__init__", args=args, return_type=NoneType() |
| ) |
| |
| if ( |
| decorator_arguments["eq"] |
| and info.get("__eq__") is None |
| or decorator_arguments["order"] |
| ): |
| # Type variable for self types in generated methods. |
| obj_type = self._api.named_type("builtins.object") |
| self_tvar_expr = TypeVarExpr( |
| SELF_TVAR_NAME, |
| info.fullname + "." + SELF_TVAR_NAME, |
| [], |
| obj_type, |
| AnyType(TypeOfAny.from_omitted_generics), |
| ) |
| info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr) |
| |
| # Add <, >, <=, >=, but only if the class has an eq method. |
| if decorator_arguments["order"]: |
| if not decorator_arguments["eq"]: |
| self._api.fail('"eq" must be True if "order" is True', self._reason) |
| |
| for method_name in ["__lt__", "__gt__", "__le__", "__ge__"]: |
| # Like for __eq__ and __ne__, we want "other" to match |
| # the self type. |
| obj_type = self._api.named_type("builtins.object") |
| order_tvar_def = TypeVarType( |
| SELF_TVAR_NAME, |
| f"{info.fullname}.{SELF_TVAR_NAME}", |
| id=TypeVarId(-1, namespace=f"{info.fullname}.{method_name}"), |
| values=[], |
| upper_bound=obj_type, |
| default=AnyType(TypeOfAny.from_omitted_generics), |
| ) |
| order_return_type = self._api.named_type("builtins.bool") |
| order_args = [ |
| Argument(Var("other", order_tvar_def), order_tvar_def, None, ARG_POS) |
| ] |
| |
| existing_method = info.get(method_name) |
| if existing_method is not None and not existing_method.plugin_generated: |
| assert existing_method.node |
| self._api.fail( |
| f'You may not have a custom "{method_name}" method when "order" is True', |
| existing_method.node, |
| ) |
| |
| add_method_to_class( |
| self._api, |
| self._cls, |
| method_name, |
| args=order_args, |
| return_type=order_return_type, |
| self_type=order_tvar_def, |
| tvar_def=order_tvar_def, |
| ) |
| |
| parent_decorator_arguments = [] |
| for parent in info.mro[1:-1]: |
| parent_args = parent.metadata.get("dataclass") |
| |
| # Ignore parent classes that directly specify a dataclass transform-decorated metaclass |
| # when searching for usage of the frozen parameter. PEP 681 states that a class that |
| # directly specifies such a metaclass must be treated as neither frozen nor non-frozen. |
| if parent_args and not _has_direct_dataclass_transform_metaclass(parent): |
| parent_decorator_arguments.append(parent_args) |
| |
| if decorator_arguments["frozen"]: |
| if any(not parent["frozen"] for parent in parent_decorator_arguments): |
| self._api.fail("Frozen dataclass cannot inherit from a non-frozen dataclass", info) |
| self._propertize_callables(attributes, settable=False) |
| self._freeze(attributes) |
| else: |
| if any(parent["frozen"] for parent in parent_decorator_arguments): |
| self._api.fail("Non-frozen dataclass cannot inherit from a frozen dataclass", info) |
| self._propertize_callables(attributes) |
| |
| if decorator_arguments["slots"]: |
| self.add_slots(info, attributes, correct_version=py_version >= (3, 10)) |
| |
| self.reset_init_only_vars(info, attributes) |
| |
| if ( |
| decorator_arguments["match_args"] |
| and ( |
| "__match_args__" not in info.names or info.names["__match_args__"].plugin_generated |
| ) |
| and py_version >= (3, 10) |
| ): |
| str_type = self._api.named_type("builtins.str") |
| literals: list[Type] = [ |
| LiteralType(attr.name, str_type) |
| for attr in attributes |
| if attr.is_in_init and not attr.kw_only |
| ] |
| match_args_type = TupleType(literals, self._api.named_type("builtins.tuple")) |
| add_attribute_to_class(self._api, self._cls, "__match_args__", match_args_type) |
| |
| self._add_dataclass_fields_magic_attribute() |
| self._add_internal_replace_method(attributes) |
| if self._api.options.python_version >= (3, 13): |
| self._add_dunder_replace(attributes) |
| |
| if "__post_init__" in info.names: |
| self._add_internal_post_init_method(attributes) |
| |
| info.metadata["dataclass"] = { |
| "attributes": [attr.serialize() for attr in attributes], |
| "frozen": decorator_arguments["frozen"], |
| } |
| |
| return True |
| |
| def _add_dunder_replace(self, attributes: list[DataclassAttribute]) -> None: |
| """Add a `__replace__` method to the class, which is used to replace attributes in the `copy` module.""" |
| args = [ |
| attr.to_argument(self._cls.info, of="replace") |
| for attr in attributes |
| if attr.is_in_init |
| ] |
| type_vars = [tv for tv in self._cls.type_vars] |
| add_method_to_class( |
| self._api, |
| self._cls, |
| "__replace__", |
| args=args, |
| return_type=Instance(self._cls.info, type_vars), |
| ) |
| |
| def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) -> None: |
| """ |
| Stashes the signature of 'dataclasses.replace(...)' for this specific dataclass |
| to be used later whenever 'dataclasses.replace' is called for this dataclass. |
| """ |
| add_method_to_class( |
| self._api, |
| self._cls, |
| _INTERNAL_REPLACE_SYM_NAME, |
| args=[attr.to_argument(self._cls.info, of="replace") for attr in attributes], |
| return_type=NoneType(), |
| is_staticmethod=True, |
| ) |
| |
| def _add_internal_post_init_method(self, attributes: list[DataclassAttribute]) -> None: |
| add_method_to_class( |
| self._api, |
| self._cls, |
| _INTERNAL_POST_INIT_SYM_NAME, |
| args=[ |
| attr.to_argument(self._cls.info, of="__post_init__") |
| for attr in attributes |
| if attr.is_init_var |
| ], |
| return_type=NoneType(), |
| ) |
| |
| def add_slots( |
| self, info: TypeInfo, attributes: list[DataclassAttribute], *, correct_version: bool |
| ) -> None: |
| if not correct_version: |
| # This means that version is lower than `3.10`, |
| # it is just a non-existent argument for `dataclass` function. |
| self._api.fail( |
| 'Keyword argument "slots" for "dataclass" is only valid in Python 3.10 and higher', |
| self._reason, |
| ) |
| return |
| |
| generated_slots = {attr.name for attr in attributes} |
| if (info.slots is not None and info.slots != generated_slots) or info.names.get( |
| "__slots__" |
| ): |
| # This means we have a slots conflict. |
| # Class explicitly specifies a different `__slots__` field. |
| # And `@dataclass(slots=True)` is used. |
| # In runtime this raises a type error. |
| self._api.fail( |
| '"{}" both defines "__slots__" and is used with "slots=True"'.format( |
| self._cls.name |
| ), |
| self._cls, |
| ) |
| return |
| |
| if any(p.slots is None for p in info.mro[1:-1]): |
| # At least one type in mro (excluding `self` and `object`) |
| # does not have concrete `__slots__` defined. Ignoring. |
| return |
| |
| info.slots = generated_slots |
| |
| # Now, insert `.__slots__` attribute to class namespace: |
| slots_type = TupleType( |
| [self._api.named_type("builtins.str") for _ in generated_slots], |
| self._api.named_type("builtins.tuple"), |
| ) |
| add_attribute_to_class(self._api, self._cls, "__slots__", slots_type) |
| |
| def reset_init_only_vars(self, info: TypeInfo, attributes: list[DataclassAttribute]) -> None: |
| """Remove init-only vars from the class and reset init var declarations.""" |
| for attr in attributes: |
| if attr.is_init_var: |
| if attr.name in info.names: |
| del info.names[attr.name] |
| else: |
| # Nodes of superclass InitVars not used in __init__ cannot be reached. |
| assert attr.is_init_var |
| for stmt in info.defn.defs.body: |
| if isinstance(stmt, AssignmentStmt) and stmt.unanalyzed_type: |
| lvalue = stmt.lvalues[0] |
| if isinstance(lvalue, NameExpr) and lvalue.name == attr.name: |
| # Reset node so that another semantic analysis pass will |
| # recreate a symbol node for this attribute. |
| lvalue.node = None |
| |
| def _get_assignment_statements_from_if_statement( |
| self, stmt: IfStmt |
| ) -> Iterator[AssignmentStmt]: |
| for body in stmt.body: |
| if not body.is_unreachable: |
| yield from self._get_assignment_statements_from_block(body) |
| if stmt.else_body is not None and not stmt.else_body.is_unreachable: |
| yield from self._get_assignment_statements_from_block(stmt.else_body) |
| |
| def _get_assignment_statements_from_block(self, block: Block) -> Iterator[AssignmentStmt]: |
| for stmt in block.body: |
| if isinstance(stmt, AssignmentStmt): |
| yield stmt |
| elif isinstance(stmt, IfStmt): |
| yield from self._get_assignment_statements_from_if_statement(stmt) |
| |
| def collect_attributes(self) -> list[DataclassAttribute] | None: |
| """Collect all attributes declared in the dataclass and its parents. |
| |
| All assignments of the form |
| |
| a: SomeType |
| b: SomeOtherType = ... |
| |
| are collected. |
| |
| Return None if some dataclass base class hasn't been processed |
| yet and thus we'll need to ask for another pass. |
| """ |
| cls = self._cls |
| |
| # First, collect attributes belonging to any class in the MRO, ignoring duplicates. |
| # |
| # We iterate through the MRO in reverse because attrs defined in the parent must appear |
| # earlier in the attributes list than attrs defined in the child. See: |
| # https://docs.python.org/3/library/dataclasses.html#inheritance |
| # |
| # However, we also want attributes defined in the subtype to override ones defined |
| # in the parent. We can implement this via a dict without disrupting the attr order |
| # because dicts preserve insertion order in Python 3.7+. |
| found_attrs: dict[str, DataclassAttribute] = {} |
| found_dataclass_supertype = False |
| for info in reversed(cls.info.mro[1:-1]): |
| if "dataclass_tag" in info.metadata and "dataclass" not in info.metadata: |
| # We haven't processed the base class yet. Need another pass. |
| return None |
| if "dataclass" not in info.metadata: |
| continue |
| |
| # Each class depends on the set of attributes in its dataclass ancestors. |
| self._api.add_plugin_dependency(make_wildcard_trigger(info.fullname)) |
| found_dataclass_supertype = True |
| |
| for data in info.metadata["dataclass"]["attributes"]: |
| name: str = data["name"] |
| |
| attr = DataclassAttribute.deserialize(info, data, self._api) |
| # TODO: We shouldn't be performing type operations during the main |
| # semantic analysis pass, since some TypeInfo attributes might |
| # still be in flux. This should be performed in a later phase. |
| attr.expand_typevar_from_subtype(cls.info) |
| found_attrs[name] = attr |
| |
| sym_node = cls.info.names.get(name) |
| if sym_node and sym_node.node and not isinstance(sym_node.node, Var): |
| self._api.fail( |
| "Dataclass attribute may only be overridden by another attribute", |
| sym_node.node, |
| ) |
| |
| # Second, collect attributes belonging to the current class. |
| current_attr_names: set[str] = set() |
| kw_only = self._get_bool_arg("kw_only", self._spec.kw_only_default) |
| for stmt in self._get_assignment_statements_from_block(cls.defs): |
| # Any assignment that doesn't use the new type declaration |
| # syntax can be ignored out of hand. |
| if not stmt.new_syntax: |
| continue |
| |
| # a: int, b: str = 1, 'foo' is not supported syntax so we |
| # don't have to worry about it. |
| lhs = stmt.lvalues[0] |
| if not isinstance(lhs, NameExpr): |
| continue |
| |
| sym = cls.info.names.get(lhs.name) |
| if sym is None: |
| # There was probably a semantic analysis error. |
| continue |
| |
| node = sym.node |
| assert not isinstance(node, PlaceholderNode) |
| |
| if isinstance(node, TypeAlias): |
| self._api.fail( |
| ("Type aliases inside dataclass definitions are not supported at runtime"), |
| node, |
| ) |
| # Skip processing this node. This doesn't match the runtime behaviour, |
| # but the only alternative would be to modify the SymbolTable, |
| # and it's a little hairy to do that in a plugin. |
| continue |
| if isinstance(node, Decorator): |
| # This might be a property / field name clash. |
| # We will issue an error later. |
| continue |
| |
| assert isinstance(node, Var) |
| |
| # x: ClassVar[int] is ignored by dataclasses. |
| if node.is_classvar: |
| continue |
| |
| # x: InitVar[int] is turned into x: int and is removed from the class. |
| is_init_var = False |
| node_type = get_proper_type(node.type) |
| if ( |
| isinstance(node_type, Instance) |
| and node_type.type.fullname == "dataclasses.InitVar" |
| ): |
| is_init_var = True |
| node.type = node_type.args[0] |
| |
| if self._is_kw_only_type(node_type): |
| kw_only = True |
| |
| has_field_call, field_args = self._collect_field_args(stmt.rvalue) |
| |
| is_in_init_param = field_args.get("init") |
| if is_in_init_param is None: |
| is_in_init = self._get_default_init_value_for_field_specifier(stmt.rvalue) |
| else: |
| is_in_init = bool(self._api.parse_bool(is_in_init_param)) |
| |
| has_default = False |
| # Ensure that something like x: int = field() is rejected |
| # after an attribute with a default. |
| if has_field_call: |
| has_default = ( |
| "default" in field_args |
| or "default_factory" in field_args |
| # alias for default_factory defined in PEP 681 |
| or "factory" in field_args |
| ) |
| |
| # All other assignments are already type checked. |
| elif not isinstance(stmt.rvalue, TempNode): |
| has_default = True |
| |
| if not has_default and self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES: |
| # Make all non-default dataclass attributes implicit because they are de-facto |
| # set on self in the generated __init__(), not in the class body. On the other |
| # hand, we don't know how custom dataclass transforms initialize attributes, |
| # so we don't treat them as implicit. This is required to support descriptors |
| # (https://github.com/python/mypy/issues/14868). |
| sym.implicit = True |
| |
| is_kw_only = kw_only |
| # Use the kw_only field arg if it is provided. Otherwise use the |
| # kw_only value from the decorator parameter. |
| field_kw_only_param = field_args.get("kw_only") |
| if field_kw_only_param is not None: |
| value = self._api.parse_bool(field_kw_only_param) |
| if value is not None: |
| is_kw_only = value |
| else: |
| self._api.fail('"kw_only" argument must be a boolean literal', stmt.rvalue) |
| |
| if sym.type is None and node.is_final and node.is_inferred: |
| # This is a special case, assignment like x: Final = 42 is classified |
| # annotated above, but mypy strips the `Final` turning it into x = 42. |
| # We do not support inferred types in dataclasses, so we can try inferring |
| # type for simple literals, and otherwise require an explicit type |
| # argument for Final[...]. |
| typ = self._api.analyze_simple_literal_type(stmt.rvalue, is_final=True) |
| if typ: |
| node.type = typ |
| else: |
| self._api.fail( |
| "Need type argument for Final[...] with non-literal default in dataclass", |
| stmt, |
| ) |
| node.type = AnyType(TypeOfAny.from_error) |
| |
| alias = None |
| if "alias" in field_args: |
| alias = self._api.parse_str_literal(field_args["alias"]) |
| if alias is None: |
| self._api.fail( |
| message_registry.DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL, |
| stmt.rvalue, |
| code=errorcodes.LITERAL_REQ, |
| ) |
| |
| current_attr_names.add(lhs.name) |
| with state.strict_optional_set(self._api.options.strict_optional): |
| init_type = self._infer_dataclass_attr_init_type(sym, lhs.name, stmt) |
| found_attrs[lhs.name] = DataclassAttribute( |
| name=lhs.name, |
| alias=alias, |
| is_in_init=is_in_init, |
| is_init_var=is_init_var, |
| has_default=has_default, |
| line=stmt.line, |
| column=stmt.column, |
| type=init_type, |
| info=cls.info, |
| kw_only=is_kw_only, |
| is_neither_frozen_nor_nonfrozen=_has_direct_dataclass_transform_metaclass( |
| cls.info |
| ), |
| api=self._api, |
| ) |
| |
| all_attrs = list(found_attrs.values()) |
| if found_dataclass_supertype: |
| all_attrs.sort(key=lambda a: a.kw_only) |
| |
| # Third, ensure that arguments without a default don't follow |
| # arguments that have a default and that the KW_ONLY sentinel |
| # is only provided once. |
| found_default = False |
| found_kw_sentinel = False |
| for attr in all_attrs: |
| # If we find any attribute that is_in_init, not kw_only, and that |
| # doesn't have a default after one that does have one, |
| # then that's an error. |
| if found_default and attr.is_in_init and not attr.has_default and not attr.kw_only: |
| # If the issue comes from merging different classes, report it |
| # at the class definition point. |
| context: Context = cls |
| if attr.name in current_attr_names: |
| context = Context(line=attr.line, column=attr.column) |
| self._api.fail( |
| "Attributes without a default cannot follow attributes with one", context |
| ) |
| |
| found_default = found_default or (attr.has_default and attr.is_in_init) |
| if found_kw_sentinel and self._is_kw_only_type(attr.type): |
| context = cls |
| if attr.name in current_attr_names: |
| context = Context(line=attr.line, column=attr.column) |
| self._api.fail( |
| "There may not be more than one field with the KW_ONLY type", context |
| ) |
| found_kw_sentinel = found_kw_sentinel or self._is_kw_only_type(attr.type) |
| return all_attrs |
| |
| def _freeze(self, attributes: list[DataclassAttribute]) -> None: |
| """Converts all attributes to @property methods in order to |
| emulate frozen classes. |
| """ |
| info = self._cls.info |
| for attr in attributes: |
| # Classes that directly specify a dataclass_transform metaclass must be neither frozen |
| # non non-frozen per PEP681. Though it is surprising, this means that attributes from |
| # such a class must be writable even if the rest of the class hierarchy is frozen. This |
| # matches the behavior of Pyright (the reference implementation). |
| if attr.is_neither_frozen_nor_nonfrozen: |
| continue |
| |
| sym_node = info.names.get(attr.name) |
| if sym_node is not None: |
| var = sym_node.node |
| if isinstance(var, Var): |
| if var.is_final: |
| continue # do not turn `Final` attrs to `@property` |
| var.is_property = True |
| else: |
| var = attr.to_var(info) |
| var.info = info |
| var.is_property = True |
| var._fullname = info.fullname + "." + var.name |
| info.names[var.name] = SymbolTableNode(MDEF, var) |
| |
| def _propertize_callables( |
| self, attributes: list[DataclassAttribute], settable: bool = True |
| ) -> None: |
| """Converts all attributes with callable types to @property methods. |
| |
| This avoids the typechecker getting confused and thinking that |
| `my_dataclass_instance.callable_attr(foo)` is going to receive a |
| `self` argument (it is not). |
| |
| """ |
| info = self._cls.info |
| for attr in attributes: |
| if isinstance(get_proper_type(attr.type), CallableType): |
| var = attr.to_var(info) |
| var.info = info |
| var.is_property = True |
| var.is_settable_property = settable |
| var._fullname = info.fullname + "." + var.name |
| info.names[var.name] = SymbolTableNode(MDEF, var) |
| |
| def _is_kw_only_type(self, node: Type | None) -> bool: |
| """Checks if the type of the node is the KW_ONLY sentinel value.""" |
| if node is None: |
| return False |
| node_type = get_proper_type(node) |
| if not isinstance(node_type, Instance): |
| return False |
| return node_type.type.fullname == "dataclasses.KW_ONLY" |
| |
| def _add_dataclass_fields_magic_attribute(self) -> None: |
| attr_name = "__dataclass_fields__" |
| any_type = AnyType(TypeOfAny.explicit) |
| # For `dataclasses`, use the type `dict[str, Field[Any]]` for accuracy. For dataclass |
| # transforms, it's inaccurate to use `Field` since a given transform may use a completely |
| # different type (or none); fall back to `Any` there. |
| # |
| # In either case, we're aiming to match the Typeshed stub for `is_dataclass`, which expects |
| # the instance to have a `__dataclass_fields__` attribute of type `dict[str, Field[Any]]`. |
| if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES: |
| field_type = self._api.named_type_or_none("dataclasses.Field", [any_type]) or any_type |
| else: |
| field_type = any_type |
| attr_type = self._api.named_type( |
| "builtins.dict", [self._api.named_type("builtins.str"), field_type] |
| ) |
| var = Var(name=attr_name, type=attr_type) |
| var.info = self._cls.info |
| var._fullname = self._cls.info.fullname + "." + attr_name |
| var.is_classvar = True |
| self._cls.info.names[attr_name] = SymbolTableNode( |
| kind=MDEF, node=var, plugin_generated=True |
| ) |
| |
| def _collect_field_args(self, expr: Expression) -> tuple[bool, dict[str, Expression]]: |
| """Returns a tuple where the first value represents whether or not |
| the expression is a call to dataclass.field and the second is a |
| dictionary of the keyword arguments that field() was called with. |
| """ |
| if ( |
| isinstance(expr, CallExpr) |
| and isinstance(expr.callee, RefExpr) |
| and expr.callee.fullname in self._spec.field_specifiers |
| ): |
| # field() only takes keyword arguments. |
| args = {} |
| for name, arg, kind in zip(expr.arg_names, expr.args, expr.arg_kinds): |
| if not kind.is_named(): |
| if kind.is_named(star=True): |
| # This means that `field` is used with `**` unpacking, |
| # the best we can do for now is not to fail. |
| # TODO: we can infer what's inside `**` and try to collect it. |
| message = 'Unpacking **kwargs in "field()" is not supported' |
| elif self._spec is not _TRANSFORM_SPEC_FOR_DATACLASSES: |
| # dataclasses.field can only be used with keyword args, but this |
| # restriction is only enforced for the *standardized* arguments to |
| # dataclass_transform field specifiers. If this is not a |
| # dataclasses.dataclass class, we can just skip positional args safely. |
| continue |
| else: |
| message = '"field()" does not accept positional arguments' |
| self._api.fail(message, expr) |
| return True, {} |
| assert name is not None |
| args[name] = arg |
| return True, args |
| return False, {} |
| |
| def _get_bool_arg(self, name: str, default: bool) -> bool: |
| # Expressions are always CallExprs (either directly or via a wrapper like Decorator), so |
| # we can use the helpers from common |
| if isinstance(self._reason, Expression): |
| return _get_decorator_bool_argument( |
| ClassDefContext(self._cls, self._reason, self._api), name, default |
| ) |
| |
| # Subclass/metaclass use of `typing.dataclass_transform` reads the parameters from the |
| # class's keyword arguments (ie `class Subclass(Parent, kwarg1=..., kwarg2=...)`) |
| expression = self._cls.keywords.get(name) |
| if expression is not None: |
| return require_bool_literal_argument(self._api, expression, name, default) |
| return default |
| |
| def _get_default_init_value_for_field_specifier(self, call: Expression) -> bool: |
| """ |
| Find a default value for the `init` parameter of the specifier being called. If the |
| specifier's type signature includes an `init` parameter with a type of `Literal[True]` or |
| `Literal[False]`, return the appropriate boolean value from the literal. Otherwise, |
| fall back to the standard default of `True`. |
| """ |
| if not isinstance(call, CallExpr): |
| return True |
| |
| specifier_type = _get_callee_type(call) |
| if specifier_type is None: |
| return True |
| |
| parameter = specifier_type.argument_by_name("init") |
| if parameter is None: |
| return True |
| |
| literals = try_getting_literals_from_type(parameter.typ, bool, "builtins.bool") |
| if literals is None or len(literals) != 1: |
| return True |
| |
| return literals[0] |
| |
| def _infer_dataclass_attr_init_type( |
| self, sym: SymbolTableNode, name: str, context: Context |
| ) -> Type | None: |
| """Infer __init__ argument type for an attribute. |
| |
| In particular, possibly use the signature of __set__. |
| """ |
| default = sym.type |
| if sym.implicit: |
| return default |
| t = get_proper_type(sym.type) |
| |
| # Perform a simple-minded inference from the signature of __set__, if present. |
| # We can't use mypy.checkmember here, since this plugin runs before type checking. |
| # We only support some basic scanerios here, which is hopefully sufficient for |
| # the vast majority of use cases. |
| if not isinstance(t, Instance): |
| return default |
| setter = t.type.get("__set__") |
| if setter: |
| if isinstance(setter.node, FuncDef): |
| super_info = t.type.get_containing_type_info("__set__") |
| assert super_info |
| if setter.type: |
| setter_type = get_proper_type( |
| map_type_from_supertype(setter.type, t.type, super_info) |
| ) |
| else: |
| return AnyType(TypeOfAny.unannotated) |
| if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [ |
| ARG_POS, |
| ARG_POS, |
| ARG_POS, |
| ]: |
| return expand_type_by_instance(setter_type.arg_types[2], t) |
| else: |
| self._api.fail( |
| f'Unsupported signature for "__set__" in "{t.type.name}"', context |
| ) |
| else: |
| self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context) |
| |
| return default |
| |
| |
| def add_dataclass_tag(info: TypeInfo) -> None: |
| # The value is ignored, only the existence matters. |
| info.metadata["dataclass_tag"] = {} |
| |
| |
| def dataclass_tag_callback(ctx: ClassDefContext) -> None: |
| """Record that we have a dataclass in the main semantic analysis pass. |
| |
| The later pass implemented by DataclassTransformer will use this |
| to detect dataclasses in base classes. |
| """ |
| add_dataclass_tag(ctx.cls.info) |
| |
| |
| def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool: |
| """Hooks into the class typechecking process to add support for dataclasses.""" |
| if any(i.is_named_tuple for i in ctx.cls.info.mro): |
| ctx.api.fail("A NamedTuple cannot be a dataclass", ctx=ctx.cls.info) |
| return True |
| transformer = DataclassTransformer( |
| ctx.cls, ctx.reason, _get_transform_spec(ctx.reason), ctx.api |
| ) |
| return transformer.transform() |
| |
| |
| def _get_transform_spec(reason: Expression) -> DataclassTransformSpec: |
| """Find the relevant transform parameters from the decorator/parent class/metaclass that |
| triggered the dataclasses plugin. |
| |
| Although the resulting DataclassTransformSpec is based on the typing.dataclass_transform |
| function, we also use it for traditional dataclasses.dataclass classes as well for simplicity. |
| In those cases, we return a default spec rather than one based on a call to |
| `typing.dataclass_transform`. |
| """ |
| if _is_dataclasses_decorator(reason): |
| return _TRANSFORM_SPEC_FOR_DATACLASSES |
| |
| spec = find_dataclass_transform_spec(reason) |
| assert spec is not None, ( |
| "trying to find dataclass transform spec, but reason is neither dataclasses.dataclass nor " |
| "decorated with typing.dataclass_transform" |
| ) |
| return spec |
| |
| |
| def _is_dataclasses_decorator(node: Node) -> bool: |
| if isinstance(node, CallExpr): |
| node = node.callee |
| if isinstance(node, RefExpr): |
| return node.fullname in dataclass_makers |
| return False |
| |
| |
| def _has_direct_dataclass_transform_metaclass(info: TypeInfo) -> bool: |
| return ( |
| info.declared_metaclass is not None |
| and info.declared_metaclass.type.dataclass_transform_spec is not None |
| ) |
| |
| |
| def _get_expanded_dataclasses_fields( |
| ctx: FunctionSigContext, typ: ProperType, display_typ: ProperType, parent_typ: ProperType |
| ) -> list[CallableType] | None: |
| """ |
| For a given type, determine what dataclasses it can be: for each class, return the field types. |
| For generic classes, the field types are expanded. |
| If the type contains Any or a non-dataclass, returns None; in the latter case, also reports an error. |
| """ |
| if isinstance(typ, UnionType): |
| ret: list[CallableType] | None = [] |
| for item in typ.relevant_items(): |
| item = get_proper_type(item) |
| item_types = _get_expanded_dataclasses_fields(ctx, item, item, parent_typ) |
| if ret is not None and item_types is not None: |
| ret += item_types |
| else: |
| ret = None # but keep iterating to emit all errors |
| return ret |
| elif isinstance(typ, TypeVarType): |
| return _get_expanded_dataclasses_fields( |
| ctx, get_proper_type(typ.upper_bound), display_typ, parent_typ |
| ) |
| elif isinstance(typ, Instance): |
| replace_sym = typ.type.get_method(_INTERNAL_REPLACE_SYM_NAME) |
| if replace_sym is None: |
| return None |
| replace_sig = replace_sym.type |
| assert isinstance(replace_sig, ProperType) |
| assert isinstance(replace_sig, CallableType) |
| return [expand_type_by_instance(replace_sig, typ)] |
| else: |
| return None |
| |
| |
| # TODO: we can potentially get the function signature hook to allow returning a union |
| # and leave this to the regular machinery of resolving a union of callables |
| # (https://github.com/python/mypy/issues/15457) |
| def _meet_replace_sigs(sigs: list[CallableType]) -> CallableType: |
| """ |
| Produces the lowest bound of the 'replace' signatures of multiple dataclasses. |
| """ |
| args = { |
| name: (typ, kind) |
| for name, typ, kind in zip(sigs[0].arg_names, sigs[0].arg_types, sigs[0].arg_kinds) |
| } |
| |
| for sig in sigs[1:]: |
| sig_args = { |
| name: (typ, kind) |
| for name, typ, kind in zip(sig.arg_names, sig.arg_types, sig.arg_kinds) |
| } |
| for name in (*args.keys(), *sig_args.keys()): |
| sig_typ, sig_kind = args.get(name, (UninhabitedType(), ARG_NAMED_OPT)) |
| sig2_typ, sig2_kind = sig_args.get(name, (UninhabitedType(), ARG_NAMED_OPT)) |
| args[name] = ( |
| meet_types(sig_typ, sig2_typ), |
| ARG_NAMED_OPT if sig_kind == sig2_kind == ARG_NAMED_OPT else ARG_NAMED, |
| ) |
| |
| return sigs[0].copy_modified( |
| arg_names=list(args.keys()), |
| arg_types=[typ for typ, _ in args.values()], |
| arg_kinds=[kind for _, kind in args.values()], |
| ) |
| |
| |
| def replace_function_sig_callback(ctx: FunctionSigContext) -> CallableType: |
| """ |
| Returns a signature for the 'dataclasses.replace' function that's dependent on the type |
| of the first positional argument. |
| """ |
| if len(ctx.args) != 2: |
| # Ideally the name and context should be callee's, but we don't have it in FunctionSigContext. |
| ctx.api.fail(f'"{ctx.default_signature.name}" has unexpected type annotation', ctx.context) |
| return ctx.default_signature |
| |
| if len(ctx.args[0]) != 1: |
| return ctx.default_signature # leave it to the type checker to complain |
| |
| obj_arg = ctx.args[0][0] |
| obj_type = get_proper_type(ctx.api.get_expression_type(obj_arg)) |
| inst_type_str = format_type_bare(obj_type, ctx.api.options) |
| |
| replace_sigs = _get_expanded_dataclasses_fields(ctx, obj_type, obj_type, obj_type) |
| if replace_sigs is None: |
| return ctx.default_signature |
| replace_sig = _meet_replace_sigs(replace_sigs) |
| |
| return replace_sig.copy_modified( |
| arg_names=[None, *replace_sig.arg_names], |
| arg_kinds=[ARG_POS, *replace_sig.arg_kinds], |
| arg_types=[obj_type, *replace_sig.arg_types], |
| ret_type=obj_type, |
| fallback=ctx.default_signature.fallback, |
| name=f"{ctx.default_signature.name} of {inst_type_str}", |
| ) |
| |
| |
| def is_processed_dataclass(info: TypeInfo) -> bool: |
| return bool(info) and "dataclass" in info.metadata |
| |
| |
| def check_post_init(api: TypeChecker, defn: FuncItem, info: TypeInfo) -> None: |
| if defn.type is None: |
| return |
| assert isinstance(defn.type, FunctionLike) |
| |
| ideal_sig_method = info.get_method(_INTERNAL_POST_INIT_SYM_NAME) |
| assert ideal_sig_method is not None and ideal_sig_method.type is not None |
| ideal_sig = ideal_sig_method.type |
| assert isinstance(ideal_sig, ProperType) # we set it ourselves |
| assert isinstance(ideal_sig, CallableType) |
| ideal_sig = ideal_sig.copy_modified(name="__post_init__") |
| |
| api.check_override( |
| override=defn.type, |
| original=ideal_sig, |
| name="__post_init__", |
| name_in_super="__post_init__", |
| supertype="dataclass", |
| original_class_or_static=False, |
| override_class_or_static=False, |
| node=defn, |
| ) |