| # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html |
| # For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE |
| # Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt |
| |
| import pytest |
| |
| import astroid |
| from astroid import bases, nodes |
| from astroid.exceptions import InferenceError |
| from astroid.util import Uninferable |
| |
# Runs a decorated test once for every module that provides a ``dataclass``
# decorator exercised by this file.
parametrize_module = pytest.mark.parametrize(
    "module",
    ["dataclasses", "pydantic.dataclasses", "marshmallow_dataclass"],
)
| |
| |
@parametrize_module
def test_inference_attribute_no_default(module: str):
    """Check inference of an annotated dataclass attribute without a default.

    Class-level access must raise ``InferenceError``; instance-level access
    yields a single instance of the annotated type. The argument given to the
    constructor is ignored by the inference.
    """
    source = f"""
    from {module} import dataclass

    @dataclass
    class A:
        name: str

    A.name #@
    A('hi').name #@
    """
    class_access, instance_access = astroid.extract_node(source)

    with pytest.raises(InferenceError):
        class_access.inferred()

    results = instance_access.inferred()
    assert len(results) == 1
    annotated = results[0]
    assert isinstance(annotated, bases.Instance)
    assert annotated.name == "str"
| |
| |
@parametrize_module
def test_inference_non_field_default(module: str):
    """Check inference of a dataclass attribute whose default is a plain value.

    On the class the attribute infers to the default constant only; on an
    instance it infers to that constant followed by an instance of the
    annotated type.
    """
    source = f"""
    from {module} import dataclass

    @dataclass
    class A:
        name: str = 'hi'

    A.name #@
    A().name #@
    """
    class_access, instance_access = astroid.extract_node(source)

    class_results = class_access.inferred()
    assert len(class_results) == 1
    default = class_results[0]
    assert isinstance(default, nodes.Const)
    assert default.value == "hi"

    instance_results = instance_access.inferred()
    assert len(instance_results) == 2
    default, annotated = instance_results
    assert isinstance(default, nodes.Const)
    assert default.value == "hi"

    assert isinstance(annotated, bases.Instance)
    assert annotated.name == "str"
| |
| |
@parametrize_module
def test_inference_field_default(module: str):
    """Check inference of a dataclass attribute set through ``field(default=...)``.

    The expectations mirror a plain default: the class attribute infers to the
    constant, the instance attribute to the constant plus an instance of the
    annotated type.
    """
    source = f"""
    from {module} import dataclass
    from dataclasses import field

    @dataclass
    class A:
        name: str = field(default='hi')

    A.name #@
    A().name #@
    """
    class_access, instance_access = astroid.extract_node(source)

    class_results = class_access.inferred()
    assert len(class_results) == 1
    default = class_results[0]
    assert isinstance(default, nodes.Const)
    assert default.value == "hi"

    instance_results = instance_access.inferred()
    assert len(instance_results) == 2
    default, annotated = instance_results
    assert isinstance(default, nodes.Const)
    assert default.value == "hi"

    assert isinstance(annotated, bases.Instance)
    assert annotated.name == "str"
| |
| |
@parametrize_module
def test_inference_field_default_factory(module: str):
    """Check inference of an attribute set through ``field(default_factory=...)``.

    With ``list`` as the factory, the class attribute infers to an empty list
    node; the instance attribute infers to that node followed by an instance
    of the annotated ``list`` type.
    """
    source = f"""
    from {module} import dataclass
    from dataclasses import field

    @dataclass
    class A:
        name: list = field(default_factory=list)

    A.name #@
    A().name #@
    """
    class_access, instance_access = astroid.extract_node(source)

    class_results = class_access.inferred()
    assert len(class_results) == 1
    produced = class_results[0]
    assert isinstance(produced, nodes.List)
    assert produced.elts == []

    instance_results = instance_access.inferred()
    assert len(instance_results) == 2
    produced, annotated = instance_results
    assert isinstance(produced, nodes.List)
    assert produced.elts == []

    assert isinstance(annotated, bases.Instance)
    assert annotated.name == "list"
| |
| |
@parametrize_module
def test_inference_method(module: str):
    """Check that a method looked up on a ``default_factory`` field is bound.

    The attribute is accessed from inside a method of the dataclass itself.

    Based on https://github.com/pylint-dev/pylint/issues/2600
    """
    source = f"""
    from typing import Dict
    from {module} import dataclass
    from dataclasses import field

    @dataclass
    class TestClass:
        foo: str
        bar: str
        baz_dict: Dict[str, str] = field(default_factory=dict)

        def some_func(self) -> None:
            f = self.baz_dict.items #@
            for key, value in f():
                print(key)
                print(value)
    """
    assignment = astroid.extract_node(source)

    # Only the first inference result of the assigned value is examined.
    first_result = next(assignment.value.infer())
    assert isinstance(first_result, bases.BoundMethod)
| |
| |
@parametrize_module
def test_inference_no_annotation(module: str):
    """Check that un-annotated class variables stay out of ``instance_attrs``.

    The class is still flagged as a dataclass, and the variable remains
    reachable from both the class and an instance.
    """
    source = f"""
    from {module} import dataclass

    @dataclass
    class A:
        name = 'hi'

    A #@
    A.name #@
    A().name #@
    """
    class_ref, class_access, instance_access = astroid.extract_node(source)

    class_node = next(class_ref.infer())
    assert isinstance(class_node, nodes.ClassDef)
    assert class_node.instance_attrs == {}
    assert class_node.is_dataclass

    # Both the class and instance can still access the attribute
    for access in (class_access, instance_access):
        assert isinstance(access, nodes.NodeNG)
        results = access.inferred()
        assert len(results) == 1
        value = results[0]
        assert isinstance(value, nodes.Const)
        assert value.value == "hi"
| |
| |
@parametrize_module
def test_inference_class_var(module: str):
    """Check that ``ClassVar``-annotated variables stay out of ``instance_attrs``.

    The class is still flagged as a dataclass, and the variable remains
    reachable from both the class and an instance.
    """
    source = f"""
    from {module} import dataclass
    from typing import ClassVar

    @dataclass
    class A:
        name: ClassVar[str] = 'hi'

    A #@
    A.name #@
    A().name #@
    """
    class_ref, class_access, instance_access = astroid.extract_node(source)

    class_node = next(class_ref.infer())
    assert isinstance(class_node, nodes.ClassDef)
    assert class_node.instance_attrs == {}
    assert class_node.is_dataclass

    # Both the class and instance can still access the attribute
    for access in (class_access, instance_access):
        assert isinstance(access, nodes.NodeNG)
        results = access.inferred()
        assert len(results) == 1
        value = results[0]
        assert isinstance(value, nodes.Const)
        assert value.value == "hi"
| |
| |
@parametrize_module
def test_inference_init_var(module: str):
    """Check that ``InitVar``-annotated variables stay out of ``instance_attrs``.

    The class is still flagged as a dataclass, and the variable remains
    reachable from both the class and an instance.
    """
    source = f"""
    from {module} import dataclass
    from dataclasses import InitVar

    @dataclass
    class A:
        name: InitVar[str] = 'hi'

    A #@
    A.name #@
    A().name #@
    """
    class_ref, class_access, instance_access = astroid.extract_node(source)

    class_node = next(class_ref.infer())
    assert isinstance(class_node, nodes.ClassDef)
    assert class_node.instance_attrs == {}
    assert class_node.is_dataclass

    # Both the class and instance can still access the attribute
    for access in (class_access, instance_access):
        assert isinstance(access, nodes.NodeNG)
        results = access.inferred()
        assert len(results) == 1
        value = results[0]
        assert isinstance(value, nodes.Const)
        assert value.value == "hi"
| |
| |
@parametrize_module
def test_inference_generic_collection_attribute(module: str):
    """Check inference of attributes annotated with ``typing`` generic collections.

    Each attribute's first inference result must be an instance named after
    the ``typing`` alias used in its annotation.
    """
    source = f"""
    from {module} import dataclass
    from dataclasses import field
    import typing

    @dataclass
    class A:
        dict_prop: typing.Dict[str, str]
        frozenset_prop: typing.FrozenSet[str]
        list_prop: typing.List[str]
        set_prop: typing.Set[str]
        tuple_prop: typing.Tuple[int, str]

    a = A({{}}, frozenset(), [], set(), (1, 'hi'))
    a.dict_prop #@
    a.frozenset_prop #@
    a.list_prop #@
    a.set_prop #@
    a.tuple_prop #@
    """
    attribute_accesses = astroid.extract_node(source)

    # Listed in the same order as the marked attribute accesses above.
    expected_names = ("Dict", "FrozenSet", "List", "Set", "Tuple")
    for access, expected in zip(attribute_accesses, expected_names):
        first_result = next(access.infer())
        assert isinstance(first_result, bases.Instance)
        assert first_result.name == expected
| |
| |
@pytest.mark.parametrize(
    ("module", "typing_module"),
    [
        ("dataclasses", "typing"),
        ("pydantic.dataclasses", "typing"),
        ("pydantic.dataclasses", "collections.abc"),
        ("marshmallow_dataclass", "typing"),
        ("marshmallow_dataclass", "collections.abc"),
    ],
)
def test_inference_callable_attribute(module: str, typing_module: str):
    """Check that a ``Callable``-annotated attribute infers to ``Uninferable``.

    See issue #1129 and pylint-dev/pylint#4895
    """
    source = f"""
    from {module} import dataclass
    from {typing_module} import Any, Callable

    @dataclass
    class A:
        enabled: Callable[[Any], bool]

    A(lambda x: x == 42).enabled #@
    """
    attribute_access = astroid.extract_node(source)

    first_result = next(attribute_access.infer())
    assert first_result is Uninferable
| |
| |
@parametrize_module
def test_inference_inherited(module: str):
    """Check that a child dataclass inherits attributes from a parent dataclass."""
    source = f"""
    from {module} import dataclass

    @dataclass
    class A:
        value: int
        name: str = "hi"

    @dataclass
    class B(A):
        new_attr: bool = True

    B.value #@
    B(1).value #@
    B.name #@
    B(1).name #@
    """
    (
        value_on_class,
        value_on_instance,
        name_on_class,
        name_on_instance,
    ) = astroid.extract_node(source)

    with pytest.raises(InferenceError):  # B.value is not defined
        value_on_class.inferred()

    value_results = value_on_instance.inferred()
    annotated = value_results[0]
    assert isinstance(annotated, bases.Instance)
    assert annotated.name == "int"

    class_results = name_on_class.inferred()
    assert len(class_results) == 1
    default = class_results[0]
    assert isinstance(default, nodes.Const)
    assert default.value == "hi"

    instance_results = name_on_instance.inferred()
    assert len(instance_results) == 2
    default, annotated = instance_results
    assert isinstance(default, nodes.Const)
    assert default.value == "hi"
    assert isinstance(annotated, bases.Instance)
    assert annotated.name == "str"
| |
| |
def test_dataclass_order_of_inherited_attributes():
    """Test that an attribute in a child does not get put at the end of the init.

    A field re-declared in a child keeps the position it has in the parent's
    ``__init__``. The same is checked when the parent and/or the child is
    declared with ``kw_only=True``.
    """
    # The snippet imports from ``dataclasses``, the module the other tests in
    # this file use; the original imported from a module named ``dataclass``.
    child, normal, keyword_only = astroid.extract_node("""
    from dataclasses import dataclass


    @dataclass
    class Parent:
        a: str
        b: str


    @dataclass
    class Child(Parent):
        c: str
        a: str


    @dataclass(kw_only=True)
    class KeywordOnlyParent:
        a: int
        b: str


    @dataclass
    class NormalChild(KeywordOnlyParent):
        c: str
        a: str


    @dataclass(kw_only=True)
    class KeywordOnlyChild(KeywordOnlyParent):
        c: str
        a: str


    Child.__init__ #@
    NormalChild.__init__ #@
    KeywordOnlyChild.__init__ #@
    """)
    child_init: bases.UnboundMethod = next(child.infer())
    assert [a.name for a in child_init.args.args] == ["self", "a", "b", "c"]

    # ``a`` is re-declared in a non-kw_only child and becomes positional,
    # while the untouched parent field ``b`` stays keyword-only.
    normal_init: bases.UnboundMethod = next(normal.infer())
    assert [a.name for a in normal_init.args.args] == ["self", "a", "c"]
    assert [a.name for a in normal_init.args.kwonlyargs] == ["b"]

    keyword_only_init: bases.UnboundMethod = next(keyword_only.infer())
    assert [a.name for a in keyword_only_init.args.args] == ["self"]
    assert [a.name for a in keyword_only_init.args.kwonlyargs] == ["a", "b", "c"]
| |
| |
def test_pydantic_field() -> None:
    """Check that attributes defaulted with ``pydantic.Field`` are Uninferable.

    (Eventually, we can extend the brain to support pydantic.Field)
    """
    source = """
    from pydantic import Field
    from pydantic.dataclasses import dataclass

    @dataclass
    class A:
        name: str = Field("hi")

    A.name #@
    A().name #@
    """
    class_access, instance_access = astroid.extract_node(source)

    class_results = class_access.inferred()
    assert len(class_results) == 1
    assert class_results[0] is Uninferable

    instance_results = instance_access.inferred()
    assert len(instance_results) == 2
    default, annotated = instance_results
    assert default is Uninferable
    assert isinstance(annotated, bases.Instance)
    assert annotated.name == "str"
| |
| |
@parametrize_module
def test_init_empty(module: str):
    """Check the generated ``__init__`` of a dataclass without any attribute."""
    source = f"""
    from {module} import dataclass

    @dataclass
    class A:
        pass

    A.__init__ #@
    """
    init_access = astroid.extract_node(source)

    generated = next(init_access.infer())
    positional = [arg.name for arg in generated.args.args]
    assert positional == ["self"]
| |
| |
@parametrize_module
def test_init_no_defaults(module: str):
    """Check the generated ``__init__`` for attributes that have no defaults.

    Both the parameter names and their annotations are verified.
    """
    source = f"""
    from {module} import dataclass
    from typing import List

    @dataclass
    class A:
        x: int
        y: str
        z: List[bool]

    A.__init__ #@
    """
    init_access = astroid.extract_node(source)

    generated = next(init_access.infer())
    positional = [arg.name for arg in generated.args.args]
    assert positional == ["self", "x", "y", "z"]

    # ``self`` carries no annotation, hence the leading ``None``.
    annotations = [
        ann.as_string() if ann else None for ann in generated.args.annotations
    ]
    assert annotations == [None, "int", "str", "List[bool]"]
| |
| |
@parametrize_module
def test_init_defaults(module: str):
    """Check the generated ``__init__`` when some attributes have defaults.

    Covers a plain default, ``field(default=...)`` and
    ``field(default_factory=...)``.
    """
    source = f"""
    from {module} import dataclass
    from dataclasses import field
    from typing import List

    @dataclass
    class A:
        w: int
        x: int = 10
        y: str = field(default="hi")
        z: List[bool] = field(default_factory=list)

    A.__init__ #@
    """
    init_access = astroid.extract_node(source)

    generated = next(init_access.infer())
    positional = [arg.name for arg in generated.args.args]
    assert positional == ["self", "w", "x", "y", "z"]

    # ``self`` carries no annotation, hence the leading ``None``.
    annotations = [
        ann.as_string() if ann else None for ann in generated.args.annotations
    ]
    assert annotations == [None, "int", "int", "str", "List[bool]"]

    defaults = [
        default.as_string() if default else None
        for default in generated.args.defaults
    ]
    assert defaults == ["10", "'hi'", "_HAS_DEFAULT_FACTORY"]
| |
| |
@parametrize_module
def test_init_initvar(module: str):
    """Check the generated ``__init__`` when one attribute is an ``InitVar``.

    The ``InitVar`` attribute appears as a parameter annotated with its inner
    type.
    """
    source = f"""
    from {module} import dataclass
    from dataclasses import InitVar
    from typing import List

    @dataclass
    class A:
        x: int
        y: str
        init_var: InitVar[int]
        z: List[bool]

    A.__init__ #@
    """
    init_access = astroid.extract_node(source)

    generated = next(init_access.infer())
    positional = [arg.name for arg in generated.args.args]
    assert positional == ["self", "x", "y", "init_var", "z"]

    # ``self`` carries no annotation, hence the leading ``None``.
    annotations = [
        ann.as_string() if ann else None for ann in generated.args.annotations
    ]
    assert annotations == [None, "int", "str", "int", "List[bool]"]
| |
| |
@parametrize_module
def test_init_decorator_init_false(module: str):
    """Check that ``@dataclass(init=False)`` generates no ``__init__``.

    The inferred ``__init__`` must be the one defined on ``object``.
    """
    source = f"""
    from {module} import dataclass
    from typing import List

    @dataclass(init=False)
    class A:
        x: int
        y: str
        z: List[bool]

    A.__init__ #@
    """
    init_access = astroid.extract_node(source)

    resolved = next(init_access.infer())
    owner = resolved._proxied.parent
    assert owner.name == "object"
| |
| |
@parametrize_module
def test_init_field_init_false(module: str):
    """Check that ``field(init=False)`` attributes are left out of ``__init__``."""
    source = f"""
    from {module} import dataclass
    from dataclasses import field
    from typing import List

    @dataclass
    class A:
        x: int
        y: str
        z: List[bool] = field(init=False)

    A.__init__ #@
    """
    init_access = astroid.extract_node(source)

    generated = next(init_access.infer())
    positional = [arg.name for arg in generated.args.args]
    assert positional == ["self", "x", "y"]

    # ``self`` carries no annotation, hence the leading ``None``.
    annotations = [
        ann.as_string() if ann else None for ann in generated.args.annotations
    ]
    assert annotations == [None, "int", "str"]
| |
| |
@parametrize_module
def test_init_override(module: str):
    """Check that a dataclass ``__init__`` overrides a superclass initializer.

    The parent is a regular class with a hand-written ``__init__``; only the
    dataclass child's own fields appear in the generated signature.

    Based on https://github.com/pylint-dev/pylint/issues/3201
    """
    source = f"""
    from {module} import dataclass
    from typing import List

    class A:
        arg0: str = None

        def __init__(self, arg0):
            raise NotImplementedError

    @dataclass
    class B(A):
        arg1: int = None
        arg2: str = None

    B.__init__ #@
    """
    init_access = astroid.extract_node(source)

    generated = next(init_access.infer())
    positional = [arg.name for arg in generated.args.args]
    assert positional == ["self", "arg1", "arg2"]

    # ``self`` carries no annotation, hence the leading ``None``.
    annotations = [
        ann.as_string() if ann else None for ann in generated.args.annotations
    ]
    assert annotations == [None, "int", "str"]
| |
| |
@parametrize_module
def test_init_attributes_from_superclasses(module: str):
    """Check ``__init__`` for a dataclass inheriting and overriding parent fields.

    An overridden field keeps the parent's position but takes the child's
    annotation.

    Based on https://github.com/pylint-dev/pylint/issues/3201
    """
    source = f"""
    from {module} import dataclass
    from typing import List

    @dataclass
    class A:
        arg0: float
        arg2: str

    @dataclass
    class B(A):
        arg1: int
        arg2: list  # Overrides arg2 from A

    B.__init__ #@
    """
    init_access = astroid.extract_node(source)

    generated = next(init_access.infer())
    positional = [arg.name for arg in generated.args.args]
    assert positional == ["self", "arg0", "arg2", "arg1"]

    annotations = [
        ann.as_string() if ann else None for ann in generated.args.annotations
    ]
    assert annotations == [
        None,
        "float",
        "list",  # not str
        "int",
    ]
| |
| |
@parametrize_module
def test_invalid_init(module: str):
    """Check that no ``__init__`` is generated for an invalid attribute order.

    A field without a default follows a field with one, so the inferred
    ``__init__`` must be the one defined on ``object``.
    """
    source = f"""
    from {module} import dataclass

    @dataclass
    class A:
        arg1: float = 0.0
        arg2: str

    A.__init__ #@
    """
    init_access = astroid.extract_node(source)

    resolved = next(init_access.infer())
    owner = resolved._proxied.parent
    assert owner.name == "object"
| |
| |
@parametrize_module
def test_annotated_enclosed_field_call(module: str):
    """Check a dataclass attribute whose ``field`` call is wrapped in another call."""
    source = f"""
    from {module} import dataclass, field
    from typing import cast

    @dataclass
    class A:
        attribute: int = cast(int, field(default_factory=dict))
    """
    class_statement = astroid.extract_node(source)

    results = class_statement.inferred()
    assert len(results) == 1
    class_node = results[0]
    assert isinstance(class_node, nodes.ClassDef)
    assert "attribute" in class_node.instance_attrs
    assert class_node.is_dataclass
| |
| |
@parametrize_module
def test_invalid_field_call(module: str) -> None:
    """Check that a ``field()`` call used as an annotation does not crash inference."""
    source = f"""
    from {module} import dataclass, field

    @dataclass
    class A:
        val: field()
    """
    class_statement = astroid.extract_node(source)

    results = class_statement.inferred()
    assert len(results) == 1
    class_node = results[0]
    assert isinstance(class_node, nodes.ClassDef)
    assert class_node.is_dataclass
| |
| |
def test_non_dataclass_is_not_dataclass() -> None:
    """Check that classes which are not dataclasses are not flagged as such.

    Covers an undecorated class and a class decorated with a locally defined
    function that merely happens to be named ``dataclass``.
    """
    source = """
    class A:
        val: field()

    def dataclass():
        return

    @dataclass
    class B:
        val: field()
    """
    tree = astroid.parse(source)

    # Statement 0 is class ``A`` and statement 2 is class ``B``; statement 1
    # is the local ``dataclass`` function.
    for statement_index in (0, 2):
        results = tree.body[statement_index].inferred()
        assert len(results) == 1
        class_node = results[0]
        assert isinstance(class_node, nodes.ClassDef)
        assert not class_node.is_dataclass
| |
| |
def test_kw_only_sentinel() -> None:
    """Check that the ``KW_ONLY`` sentinel is not added to the fields.

    The sentinel is exercised under its own name and under an import alias.
    """
    source = """
    from dataclasses import dataclass, KW_ONLY
    from dataclasses import KW_ONLY as keyword_only

    @dataclass
    class A:
        _: KW_ONLY
        y: str

    A.__init__ #@

    @dataclass
    class B:
        _: keyword_only
        y: str

    B.__init__ #@
    """
    init_accesses = astroid.extract_node(source)
    plain_access, aliased_access = init_accesses

    expected_names = ["self", "y"]
    for init_access in (plain_access, aliased_access):
        generated = next(init_access.infer())
        assert [arg.name for arg in generated.args.args] == expected_names
| |
| |
def test_kw_only_decorator() -> None:
    """Check that the ``kw_only`` decorator keyword shapes the signature.

    A chain of four classes alternates ``kw_only=True`` and ``kw_only=False``;
    each class contributes its own fields as keyword-only or positional
    accordingly.
    """
    source = """
    from dataclasses import dataclass

    @dataclass(kw_only=True)
    class Foo:
        a: int
        e: str


    @dataclass(kw_only=False)
    class Bar(Foo):
        c: int


    @dataclass(kw_only=False)
    class Cee(Bar):
        d: int


    @dataclass(kw_only=True)
    class Dee(Cee):
        ee: int


    Foo.__init__ #@
    Bar.__init__ #@
    Cee.__init__ #@
    Dee.__init__ #@
    """
    foo_access, bar_access, cee_access, dee_access = astroid.extract_node(source)

    foo_init: bases.UnboundMethod = next(foo_access.infer())
    assert [arg.name for arg in foo_init.args.args] == ["self"]
    assert [arg.name for arg in foo_init.args.kwonlyargs] == ["a", "e"]

    bar_init: bases.UnboundMethod = next(bar_access.infer())
    assert [arg.name for arg in bar_init.args.args] == ["self", "c"]
    assert [arg.name for arg in bar_init.args.kwonlyargs] == ["a", "e"]

    cee_init: bases.UnboundMethod = next(cee_access.infer())
    assert [arg.name for arg in cee_init.args.args] == ["self", "c", "d"]
    assert [arg.name for arg in cee_init.args.kwonlyargs] == ["a", "e"]

    dee_init: bases.UnboundMethod = next(dee_access.infer())
    assert [arg.name for arg in dee_init.args.args] == ["self", "c", "d"]
    assert [arg.name for arg in dee_init.args.kwonlyargs] == ["a", "e", "ee"]
| |
| |
def test_kw_only_in_field_call() -> None:
    """Check that ``field(kw_only=...)`` places fields correctly in ``__init__``.

    A per-field ``kw_only`` value takes precedence over the decorator-level
    setting, as shown by ``p2`` staying positional in a ``kw_only=True`` class.
    """
    source = """
    from dataclasses import dataclass, field

    @dataclass
    class Parent:
        p1: int = field(kw_only=True, default=0)

    @dataclass
    class Child(Parent):
        c1: str

    @dataclass(kw_only=True)
    class GrandChild(Child):
        p2: int = field(kw_only=False, default=1)
        p3: int = field(kw_only=True, default=2)

    Parent.__init__ #@
    Child.__init__ #@
    GrandChild.__init__ #@
    """
    parent_access, child_access, grandchild_access = astroid.extract_node(source)

    parent_init: bases.UnboundMethod = next(parent_access.infer())
    assert [arg.name for arg in parent_init.args.args] == ["self"]
    assert [arg.name for arg in parent_init.args.kwonlyargs] == ["p1"]
    assert [default.value for default in parent_init.args.kw_defaults] == [0]

    child_init: bases.UnboundMethod = next(child_access.infer())
    assert [arg.name for arg in child_init.args.args] == ["self", "c1"]
    assert [arg.name for arg in child_init.args.kwonlyargs] == ["p1"]
    assert [default.value for default in child_init.args.kw_defaults] == [0]

    grandchild_init: bases.UnboundMethod = next(grandchild_access.infer())
    assert [arg.name for arg in grandchild_init.args.args] == ["self", "c1", "p2"]
    assert [arg.name for arg in grandchild_init.args.kwonlyargs] == ["p1", "p3"]
    assert [default.value for default in grandchild_init.args.defaults] == [1]
    assert [default.value for default in grandchild_init.args.kw_defaults] == [0, 2]
| |
| |
def test_dataclass_with_unknown_base() -> None:
    """Regression test: a dataclass deriving from an unresolvable base class.

    Instantiating such a class must still produce an inference result.

    Reported in https://github.com/pylint-dev/pylint/issues/7418
    """
    source = """
    import dataclasses

    from unknown import Unknown


    @dataclasses.dataclass
    class MyDataclass(Unknown):
        pass

    MyDataclass()
    """
    instantiation = astroid.extract_node(source)

    first_result = next(instantiation.infer())
    assert first_result
| |
| |
def test_dataclass_with_unknown_typing() -> None:
    """Regression test for dataclasses with a bare ``InitVar`` annotation.

    The field is annotated with ``InitVar`` without a subscripted type; it
    must still show up as a parameter of the generated ``__init__``.

    Reported in https://github.com/pylint-dev/pylint/issues/7422
    """
    node = astroid.extract_node("""
    from dataclasses import dataclass, InitVar


    @dataclass
    class TestClass:
        '''Test Class'''

        config: InitVar = None

    TestClass.__init__ #@
    """)

    init_def: bases.UnboundMethod = next(node.infer())
    assert [a.name for a in init_def.args.args] == ["self", "config"]
| |
| |
def test_dataclass_with_default_factory() -> None:
    """Regression test for dataclasses with default values.

    A child re-declares a parent field and gives it a default. This is checked
    once where the parent annotation is a ``Union`` and once where it is a
    plain ``str``; in both cases the generated ``__init__`` must carry the
    default.

    Reported in https://github.com/pylint-dev/pylint/issues/7425
    """
    bad_node, good_node = astroid.extract_node("""
    from dataclasses import dataclass
    from typing import Union

    @dataclass
    class BadExampleParentClass:
        xyz: Union[str, int]

    @dataclass
    class BadExampleClass(BadExampleParentClass):
        xyz: str = ""

    BadExampleClass.__init__ #@

    @dataclass
    class GoodExampleParentClass:
        xyz: str

    @dataclass
    class GoodExampleClass(GoodExampleParentClass):
        xyz: str = ""

    GoodExampleClass.__init__ #@
    """)

    bad_init: bases.UnboundMethod = next(bad_node.infer())
    assert bad_init.args.defaults
    assert [a.name for a in bad_init.args.args] == ["self", "xyz"]

    good_init: bases.UnboundMethod = next(good_node.infer())
    # Check ``good_init`` here: the original re-asserted ``bad_init`` and so
    # never verified the defaults of the second example.
    assert good_init.args.defaults
    assert [a.name for a in good_init.args.args] == ["self", "xyz"]
| |
| |
def test_dataclass_with_multiple_inheritance() -> None:
    """Regression test for dataclasses that use multiple inheritance.

    Checks the parameter order and defaults of the generated ``__init__`` for
    several base-class layouts, including a diamond-shaped hierarchy.

    Reported in https://github.com/pylint-dev/pylint/issues/7427
    Reported in https://github.com/pylint-dev/pylint/issues/7434
    """
    layouts_source = """
    from dataclasses import dataclass

    @dataclass
    class BaseParent:
        _abc: int = 1

    @dataclass
    class AnotherParent:
        ef: int = 2

    @dataclass
    class FirstChild(BaseParent, AnotherParent):
        ghi: int = 3

    @dataclass
    class ConvolutedParent(AnotherParent):
        '''Convoluted Parent'''

    @dataclass
    class SecondChild(BaseParent, ConvolutedParent):
        jkl: int = 4

    @dataclass
    class OverwritingParent:
        ef: str = "2"

    @dataclass
    class OverwrittenChild(OverwritingParent, AnotherParent):
        '''Overwritten Child'''

    @dataclass
    class OverwritingChild(BaseParent, AnotherParent):
        _abc: float = 1.0
        ef: float = 2.0

    class NotADataclassParent:
        ef: int = 2

    @dataclass
    class ChildWithMixedParents(BaseParent, NotADataclassParent):
        ghi: int = 3

    FirstChild.__init__ #@
    SecondChild.__init__ #@
    OverwrittenChild.__init__ #@
    OverwritingChild.__init__ #@
    ChildWithMixedParents.__init__ #@
    """
    (
        first_access,
        second_access,
        overwritten_access,
        overwriting_access,
        mixed_access,
    ) = astroid.extract_node(layouts_source)

    first_init: bases.UnboundMethod = next(first_access.infer())
    assert [arg.name for arg in first_init.args.args] == ["self", "ef", "_abc", "ghi"]
    assert [default.value for default in first_init.args.defaults] == [2, 1, 3]

    second_init: bases.UnboundMethod = next(second_access.infer())
    assert [arg.name for arg in second_init.args.args] == ["self", "ef", "_abc", "jkl"]
    assert [default.value for default in second_init.args.defaults] == [2, 1, 4]

    overwritten_init: bases.UnboundMethod = next(overwritten_access.infer())
    assert [arg.name for arg in overwritten_init.args.args] == ["self", "ef"]
    assert [default.value for default in overwritten_init.args.defaults] == ["2"]

    overwriting_init: bases.UnboundMethod = next(overwriting_access.infer())
    assert [arg.name for arg in overwriting_init.args.args] == ["self", "ef", "_abc"]
    assert [default.value for default in overwriting_init.args.defaults] == [2.0, 1.0]

    mixed_init: bases.UnboundMethod = next(mixed_access.infer())
    assert [arg.name for arg in mixed_init.args.args] == ["self", "_abc", "ghi"]
    assert [default.value for default in mixed_init.args.defaults] == [1, 3]

    # Diamond: both parents of GrandChild derive from the same base.
    diamond_source = """
    from dataclasses import dataclass

    @dataclass
    class BaseParent:
        required: bool

    @dataclass
    class FirstChild(BaseParent):
        ...

    @dataclass
    class SecondChild(BaseParent):
        optional: bool = False

    @dataclass
    class GrandChild(FirstChild, SecondChild):
        ...

    GrandChild.__init__ #@
    """
    grandchild_access = astroid.extract_node(diamond_source)

    grandchild_init: bases.UnboundMethod = next(grandchild_access.infer())
    assert [arg.name for arg in grandchild_init.args.args] == [
        "self",
        "required",
        "optional",
    ]
    assert [default.value for default in grandchild_init.args.defaults] == [False]
| |
| |
@pytest.mark.xfail(reason="Transforms returning Uninferable isn't supported.")
def test_dataclass_non_default_argument_after_default() -> None:
    """Check that a non-default argument after a default one is rejected.

    This should succeed, but the dataclass brain is a transform, and a
    transform currently can't return an Uninferable correctly. The dataclass
    ClassDef node therefore cannot be set to Uninferable for now.
    Eventually this can be merged into test_dataclass_with_multiple_inheritance.
    """
    source = """
    from dataclasses import dataclass

    @dataclass
    class BaseParent:
        required: bool

    @dataclass
    class FirstChild(BaseParent):
        ...

    @dataclass
    class SecondChild(BaseParent):
        optional: bool = False

    @dataclass
    class ThirdChild:
        other: bool = False

    @dataclass
    class ImpossibleGrandChild(FirstChild, SecondChild, ThirdChild):
        ...

    ImpossibleGrandChild() #@
    """
    instantiation = astroid.extract_node(source)

    first_result = next(instantiation.infer())
    assert first_result is Uninferable
| |
| |
def test_dataclass_with_field_init_is_false() -> None:
    """Check that a ``field(init=False)`` attribute stays out of ``__init__``.

    Also covers descendants that re-declare the attribute: it returns to the
    signature and carries the default recorded on the ``init=False`` field.
    """
    source = """
    from dataclasses import dataclass, field


    @dataclass
    class First:
        a: int

    @dataclass
    class Second(First):
        a: int = field(init=False, default=1)

    @dataclass
    class SecondChild(Second):
        a: float

    @dataclass
    class ThirdChild(SecondChild):
        a: str

    @dataclass
    class Third(First):
        a: str

    First.__init__ #@
    Second.__init__ #@
    SecondChild.__init__ #@
    ThirdChild.__init__ #@
    Third.__init__ #@
    """
    (
        first_access,
        second_access,
        second_child_access,
        third_child_access,
        third_access,
    ) = astroid.extract_node(source)

    first_init: bases.UnboundMethod = next(first_access.infer())
    assert [arg.name for arg in first_init.args.args] == ["self", "a"]
    assert [default.value for default in first_init.args.defaults] == []

    second_init: bases.UnboundMethod = next(second_access.infer())
    assert [arg.name for arg in second_init.args.args] == ["self"]
    assert [default.value for default in second_init.args.defaults] == []

    second_child_init: bases.UnboundMethod = next(second_child_access.infer())
    assert [arg.name for arg in second_child_init.args.args] == ["self", "a"]
    assert [default.value for default in second_child_init.args.defaults] == [1]

    third_child_init: bases.UnboundMethod = next(third_child_access.infer())
    assert [arg.name for arg in third_child_init.args.args] == ["self", "a"]
    assert [default.value for default in third_child_init.args.defaults] == [1]

    third_init: bases.UnboundMethod = next(third_access.infer())
    assert [arg.name for arg in third_init.args.args] == ["self", "a"]
    assert [default.value for default in third_init.args.defaults] == []
| |
| |
def test_dataclass_inits_of_non_dataclasses() -> None:
    """Regression test: ``__init__`` of non-dataclasses must not be mangled.

    Guards against regressions from the changes covered by
    test_dataclass_with_multiple_inheritance.
    """
    source = """
    from dataclasses import dataclass

    @dataclass
    class DataclassParent:
        _abc: int = 1


    class NotADataclassParent:
        ef: int = 2


    class FirstChild(DataclassParent, NotADataclassParent):
        ghi: int = 3


    class SecondChild(DataclassParent, NotADataclassParent):
        ghi: int = 3

        def __init__(self, ef: int = 3):
            self.ef = ef


    class ThirdChild(NotADataclassParent, DataclassParent):
        ghi: int = 3

        def __init__(self, ef: int = 3):
            self.ef = ef

    FirstChild.__init__ #@
    SecondChild.__init__ #@
    ThirdChild.__init__ #@
    """
    first_access, second_access, third_access = astroid.extract_node(source)

    # FirstChild defines no __init__, so the dataclass parent's one is found.
    first_init: bases.UnboundMethod = next(first_access.infer())
    assert [arg.name for arg in first_init.args.args] == ["self", "_abc"]
    assert [default.value for default in first_init.args.defaults] == [1]

    # The remaining two classes define their own __init__.
    second_init: bases.UnboundMethod = next(second_access.infer())
    assert [arg.name for arg in second_init.args.args] == ["self", "ef"]
    assert [default.value for default in second_init.args.defaults] == [3]

    third_init: bases.UnboundMethod = next(third_access.infer())
    assert [arg.name for arg in third_init.args.args] == ["self", "ef"]
    assert [default.value for default in third_init.args.defaults] == [3]
| |
| |
def test_dataclass_with_properties() -> None:
    """Check ``__init__`` creation for dataclasses whose fields are properties."""
    constant_property_source = """
    from dataclasses import dataclass

    @dataclass
    class Dataclass:
        attr: int

        @property
        def attr(self) -> int:
            return 1

        @attr.setter
        def attr(self, value: int) -> None:
            pass

    class ParentOne(Dataclass):
        '''Docstring'''

    @dataclass
    class ParentTwo(Dataclass):
        '''Docstring'''

    Dataclass.__init__ #@
    ParentOne.__init__ #@
    ParentTwo.__init__ #@
    """
    base_access, plain_child_access, dataclass_child_access = astroid.extract_node(
        constant_property_source
    )

    base_init: bases.UnboundMethod = next(base_access.infer())
    assert [arg.name for arg in base_init.args.args] == ["self", "attr"]
    assert [default.value for default in base_init.args.defaults] == [1]

    plain_child_init: bases.UnboundMethod = next(plain_child_access.infer())
    assert [arg.name for arg in plain_child_init.args.args] == ["self", "attr"]
    assert [default.value for default in plain_child_init.args.defaults] == [1]

    dataclass_child_init: bases.UnboundMethod = next(dataclass_child_access.infer())
    assert [arg.name for arg in dataclass_child_init.args.args] == ["self", "attr"]
    assert [default.value for default in dataclass_child_init.args.defaults] == [1]

    # Here the property getter depends on another field of the instance.
    dependent_property_source = """
    from dataclasses import dataclass

    @dataclass
    class Dataclass:
        other_attr: str
        attr: str

        @property
        def attr(self) -> str:
            return self.other_attr[-1]

        @attr.setter
        def attr(self, value: int) -> None:
            pass

    Dataclass.__init__ #@
    """
    dependent_access = astroid.extract_node(dependent_property_source)

    dependent_init: bases.UnboundMethod = next(dependent_access.infer())
    assert [arg.name for arg in dependent_init.args.args] == [
        "self",
        "other_attr",
        "attr",
    ]
    assert [default.name for default in dependent_init.args.defaults] == [
        "Uninferable"
    ]
| |
| |
def test_dataclass_with_duplicate_bases_no_crash():
    """Regression test for https://github.com/pylint-dev/astroid/issues/2628.

    The dataclass derives from a class whose MRO contains duplicate bases
    (``Protocol`` shows up both directly and indirectly). Building and
    inferring its ``__init__`` must not fail with DuplicateBasesError
    during AST transformation.
    """
    source = """
    import dataclasses
    from typing import TypeVar, Protocol

    BaseT = TypeVar("BaseT")
    T = TypeVar("T", bound=BaseT)

    class ConfigBase(Protocol[BaseT]):
        ...

    class Config(ConfigBase[T], Protocol[T]):
        ...

    @dataclasses.dataclass
    class DatasetConfig(Config[T]):
        name: str = "default"

    DatasetConfig.__init__ #@
    """
    init_access = astroid.extract_node(source)

    # Should not raise DuplicateBasesError — graceful degradation instead
    first_result = next(init_access.infer())
    assert first_result is not None
| |
| |
def test_dataclass_with_duplicate_bases_field_default():
    """Regression test for _get_previous_field_default with a broken MRO.

    A parent dataclass gives a field a default; a child whose MRO contains
    duplicate bases re-annotates that field without a value. In that case
    _get_previous_field_default should not crash with DuplicateBasesError.

    See https://github.com/pylint-dev/astroid/issues/2628.
    """
    source = """
    import dataclasses
    from typing import TypeVar, Protocol

    BaseT = TypeVar("BaseT")
    T = TypeVar("T", bound=BaseT)

    class ConfigBase(Protocol[BaseT]):
        ...

    class Config(ConfigBase[T], Protocol[T]):
        ...

    @dataclasses.dataclass
    class BaseConfig(Config[T]):
        name: str = "default"

    @dataclasses.dataclass
    class ChildConfig(BaseConfig[T]):
        name: str

    ChildConfig.__init__ #@
    """
    init_access = astroid.extract_node(source)

    # Should not raise DuplicateBasesError in _get_previous_field_default
    first_result = next(init_access.infer())
    assert first_result is not None