[Backport maintenance/4.0.x] Wrong inference with default argument values (#2924)
Fix overzealous filtering of `IfExp` inference (#2914)
(cherry picked from commit 178a796d01b43240638921400bc71212c1b2b05e)
Co-authored-by: jkmnt <git@firewood.fastmail.com>
diff --git a/ChangeLog b/ChangeLog
index a41a6aa..3d78737 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -13,12 +13,16 @@
============================
Release date: TBA
+* Fix inference of ``IfExp`` (ternary expression) nodes to avoid prematurely narrowing
+ results in the face of inference ambiguity.
+
+ Closes #2899
+
* Fix base class inference for dataclasses using the PEP 695 typing syntax.
Refs pylint-dev/pylint#10788
-
What's New in astroid 4.0.2?
============================
Release date: 2025-11-09
diff --git a/astroid/nodes/node_classes.py b/astroid/nodes/node_classes.py
index cfd78e6..b9077de 100644
--- a/astroid/nodes/node_classes.py
+++ b/astroid/nodes/node_classes.py
@@ -3108,31 +3108,37 @@
to inferring both branches. Otherwise, we infer either branch
depending on the condition.
"""
- both_branches = False
+
# We use two separate contexts for evaluating lhs and rhs because
# evaluating lhs may leave some undesired entries in context.path
# which may not let us infer right value of rhs.
-
context = context or InferenceContext()
lhs_context = copy_context(context)
rhs_context = copy_context(context)
+
+ # Infer bool condition. Stop inferring if in doubt and fallback to
+ # evaluating both branches.
+ condition: bool | None = None
try:
- test = next(self.test.infer(context=context.clone()))
- except (InferenceError, StopIteration):
- both_branches = True
- else:
- test_bool_value = test.bool_value()
- if not isinstance(test, util.UninferableBase) and not isinstance(
- test_bool_value, util.UninferableBase
- ):
- if test_bool_value:
- yield from self.body.infer(context=lhs_context)
- else:
- yield from self.orelse.infer(context=rhs_context)
- else:
- both_branches = True
- if both_branches:
+ for test in self.test.infer(context=context.clone()):
+ if isinstance(test, util.UninferableBase):
+ condition = None
+ break
+ test_bool_value = test.bool_value()
+ if isinstance(test_bool_value, util.UninferableBase):
+ condition = None
+ break
+ if condition is None:
+ condition = test_bool_value
+ elif test_bool_value != condition:
+ condition = None
+ break
+ except InferenceError:
+ condition = None
+
+ if condition is True or condition is None:
yield from self.body.infer(context=lhs_context)
+ if condition is False or condition is None:
yield from self.orelse.infer(context=rhs_context)
diff --git a/tests/test_inference.py b/tests/test_inference.py
index bf41e8c..6a19220 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -6442,6 +6442,98 @@
assert [third[0].value, third[1].value] == [1, 2]
+def test_ifexp_with_default_arguments() -> None:
+ code = """
+ def with_default(foo: str | None = None):
+ a = 1 if foo else "bar" #@
+
+ def without_default(foo: str):
+ a = 1 if foo else "bar" #@
+
+ def some_ifexps(foo: str | None = None):
+ a = 1 if foo else 2
+ b = 3 if a else 4 #@
+ c = 4 if b else 5 #@
+ d = 5 if not foo else foo #@
+ e = d if not foo else foo #@
+ """
+
+ ast_nodes = extract_node(code)
+
+ first = ast_nodes[0].value.inferred()
+ second = ast_nodes[1].value.inferred()
+ third = ast_nodes[2].value.inferred()
+ fourth = ast_nodes[3].value.inferred()
+ fifth = ast_nodes[4].value.inferred()
+ sixth = ast_nodes[5].value.inferred()
+
+ assert len(first) == 2
+ assert [first[0].value, first[1].value] == [1, "bar"]
+
+ assert len(second) == 2
+ assert [second[0].value, second[1].value] == [1, "bar"]
+
+ assert len(third) == 1
+ assert third[0].value == 3
+
+ assert len(fourth) == 1
+ assert fourth[0].value == 4
+
+ assert len(fifth) == 2
+ assert [fifth[0].value, fifth[1].value] == [5, Uninferable]
+
+ assert len(sixth) == 3
+ assert [sixth[0].value, sixth[1].value, sixth[2].value] == [
+ 5,
+ Uninferable,
+ Uninferable,
+ ]
+
+
+def test_ifexp_with_uninferables() -> None:
+ code = """
+ def truthy_and_falsy():
+ return False if unknown() else True
+
+ def truthy_and_uninferable():
+ return False if unknown() else unknown()
+
+ def calls_truthy_and_falsy():
+ return 1 if truthy_and_falsy() else 2
+
+ def calls_truthy_and_uninferable():
+ return 1 if range(10) else truthy_and_uninferable()
+
+ truthy_and_falsy() #@
+ truthy_and_uninferable() #@
+ calls_truthy_and_falsy() #@
+ calls_truthy_and_uninferable() #@
+ """
+
+ ast_nodes = extract_node(code)
+
+ first = ast_nodes[0].inferred()
+ second = ast_nodes[1].inferred()
+ third = ast_nodes[2].inferred()
+ fourth = ast_nodes[3].inferred()
+
+ assert len(first) == 2
+ assert [first[0].value, first[1].value] == [False, True]
+
+ assert len(second) == 2
+ assert [second[0].value, second[1].value] == [False, Uninferable]
+
+ assert len(third) == 2
+ assert [third[0].value, third[1].value] == [1, 2]
+
+ assert len(fourth) == 3
+ assert [fourth[0].value, fourth[1].value, fourth[2].value] == [
+ 1,
+ False,
+ Uninferable,
+ ]
+
+
def test_assert_last_function_returns_none_on_inference() -> None:
code = """
def check_equal(a, b):