Fix DuplicateBasesError crash in dataclass transform (#2970)
diff --git a/ChangeLog b/ChangeLog index 7d15bf0..b6621fa 100644 --- a/ChangeLog +++ b/ChangeLog
@@ -13,6 +13,13 @@ ============================ Release date: TBA +* Fix ``DuplicateBasesError`` crash in dataclass transform when a class has + duplicate bases in its MRO (e.g., ``Protocol`` appearing both directly and + indirectly). Catch ``MroError`` at ``.mro()`` call sites in + ``brain_dataclasses.py``, consistent with the existing pattern elsewhere. + + Closes #2628 + * Catch ``MemoryError`` when inferring f-strings with extremely large format widths (e.g. ``f'{0:11111111111}'``) so that inference yields ``Uninferable`` instead of crashing.
diff --git a/astroid/brain/brain_dataclasses.py b/astroid/brain/brain_dataclasses.py index 244665e..b6b1956 100644 --- a/astroid/brain/brain_dataclasses.py +++ b/astroid/brain/brain_dataclasses.py
@@ -21,7 +21,12 @@ from astroid.brain.helpers import is_class_var from astroid.builder import parse from astroid.const import PY313_PLUS -from astroid.exceptions import AstroidSyntaxError, InferenceError, UseInferenceDefault +from astroid.exceptions import ( + AstroidSyntaxError, + InferenceError, + MroError, + UseInferenceDefault, +) from astroid.inference_tip import inference_tip from astroid.manager import AstroidManager from astroid.typing import InferenceResult @@ -174,7 +179,12 @@ # See TODO down below # all_have_defaults = True - for base in reversed(node.mro()): + try: + mro = node.mro() + except MroError: + return pos_only_store, kw_only_store + + for base in reversed(mro): if not base.is_dataclass: continue try: @@ -224,7 +234,12 @@ def _get_previous_field_default(node: nodes.ClassDef, name: str) -> nodes.NodeNG | None: """Get the default value of a previously defined field.""" - for base in reversed(node.mro()): + try: + mro = node.mro() + except MroError: + return None + + for base in reversed(mro): if not base.is_dataclass: continue if name in base.locals:
diff --git a/tests/brain/test_dataclasses.py b/tests/brain/test_dataclasses.py index d6ab13e..4795327 100644 --- a/tests/brain/test_dataclasses.py +++ b/tests/brain/test_dataclasses.py
@@ -1237,3 +1237,73 @@ fourth_init: bases.UnboundMethod = next(fourth.infer()) assert [a.name for a in fourth_init.args.args] == ["self", "other_attr", "attr"] assert [a.name for a in fourth_init.args.defaults] == ["Uninferable"] + + +def test_dataclass_with_duplicate_bases_no_crash(): + """Regression test for https://github.com/pylint-dev/astroid/issues/2628. + + A dataclass inheriting from a class with duplicate bases in MRO + (e.g., Protocol appearing both directly and indirectly) should not + crash with DuplicateBasesError during AST transformation. + """ + code = """ + import dataclasses + from typing import TypeVar, Protocol + + BaseT = TypeVar("BaseT") + T = TypeVar("T", bound=BaseT) + + class ConfigBase(Protocol[BaseT]): + ... + + class Config(ConfigBase[T], Protocol[T]): + ... + + @dataclasses.dataclass + class DatasetConfig(Config[T]): + name: str = "default" + + DatasetConfig.__init__ #@ + """ + node = astroid.extract_node(code) + # Should not raise DuplicateBasesError — graceful degradation instead + inferred = next(node.infer()) + assert inferred is not None + + +def test_dataclass_with_duplicate_bases_field_default(): + """Regression test for _get_previous_field_default with broken MRO. + + When a parent dataclass defines a field with a default and a child (with + duplicate bases in its MRO) re-annotates that field without a value, + _get_previous_field_default should not crash with DuplicateBasesError. + + See https://github.com/pylint-dev/astroid/issues/2628. + """ + code = """ + import dataclasses + from typing import TypeVar, Protocol + + BaseT = TypeVar("BaseT") + T = TypeVar("T", bound=BaseT) + + class ConfigBase(Protocol[BaseT]): + ... + + class Config(ConfigBase[T], Protocol[T]): + ... + + @dataclasses.dataclass + class BaseConfig(Config[T]): + name: str = "default" + + @dataclasses.dataclass + class ChildConfig(BaseConfig[T]): + name: str + + ChildConfig.__init__ #@ + """ + node = astroid.extract_node(code) + # Should not raise DuplicateBasesError in _get_previous_field_default + inferred = next(node.infer()) + assert inferred is not None