blob: 08bb80e3c9ad8683de2a421252d1bc98c7a3da65 [file] [log] [blame]
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
"""Classes representing different types of constraints on inference values."""
from __future__ import annotations
import sys
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import TYPE_CHECKING, Union
from astroid import nodes, util
from astroid.typing import InferenceResult
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
if TYPE_CHECKING:
from astroid import bases
# Node types that may be the subject of a constraint; used as the type of the
# `node`/`expr` parameters of Constraint.match, get_constraints and
# _match_constraint.
# Built with typing.Union rather than `X | Y`: this is a runtime assignment, not
# an annotation, so `from __future__ import annotations` does not defer it —
# presumably kept this way for Python versions before 3.10 (confirm against the
# supported-versions policy).
_NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name]
class Constraint(ABC):
    """Abstract base for one condition that restricts a variable's possible values.

    Concrete subclasses recognise a syntactic pattern in a test expression
    (see :meth:`match`) and judge whether an inferred value is compatible with
    that condition (see :meth:`satisfied_by`).
    """

    def __init__(self, node: nodes.NodeNG, negate: bool) -> None:
        self.node = node
        """The node this constraint restricts."""
        self.negate = negate
        """Whether the condition is inverted, e.g. ``is not`` rather than ``is``."""

    @classmethod
    @abstractmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Build a constraint on node from expr when expr fits this pattern.

        Returns None when expr does not have the shape this constraint
        class understands. A true negate requests the inverted constraint.
        """

    @abstractmethod
    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Tell whether the inferred value is compatible with this constraint."""
class NoneConstraint(Constraint):
    """Represents an "is None" or "is not None" constraint."""

    # Shared Const(None) node that operands are compared against via _matches.
    CONST_NONE: nodes.Const = nodes.Const(None)

    @classmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Return a new constraint for node matched from expr, if expr matches
        the constraint pattern.

        The pattern is a single ``is`` / ``is not`` comparison between node
        and ``None``. Both operand orders are recognised (``node is None`` and
        ``None is node``), because identity comparison is symmetric.

        Negate the constraint based on the value of negate; an ``is not``
        comparison flips that polarity.
        """
        if not (isinstance(expr, nodes.Compare) and len(expr.ops) == 1):
            return None
        op, right = expr.ops[0]
        if op not in {"is", "is not"}:
            return None
        left = expr.left
        node_is_left = _matches(left, node) and _matches(right, cls.CONST_NONE)
        # Reversed operand order, e.g. `None is x`.
        node_is_right = _matches(left, cls.CONST_NONE) and _matches(right, node)
        if node_is_left or node_is_right:
            # "is not" inverts whatever polarity the caller requested.
            negate = (op == "is" and negate) or (op == "is not" and not negate)
            return cls(node=node, negate=negate)
        return None

    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Return True if this constraint is satisfied by the given inferred value."""
        # Assume true if uninferable
        if isinstance(inferred, util.UninferableBase):
            return True

        # Return the XOR of self.negate and matches(inferred, self.CONST_NONE)
        return self.negate ^ _matches(inferred, self.CONST_NONE)
def get_constraints(
    expr: _NameNodes, frame: nodes.LocalsDictNodeNG
) -> dict[nodes.If, set[Constraint]]:
    """Collect the constraints that enclosing ``if`` statements impose on expr.

    Walks upward from expr through its parents, stopping on reaching frame
    (or running out of parents). For every ``If`` whose body or else-branch
    contains the current node, the ``If`` test is matched against the known
    constraint classes. Matches found from the else-branch are inverted.

    Returns a dict mapping each such ``If`` node to the non-empty set of
    constraints it contributes. Only ``if`` conditions are considered.
    """
    found: dict[nodes.If, set[Constraint]] = {}
    child: nodes.NodeNG | None = expr
    while child is not None and child is not frame:
        ancestor = child.parent
        if isinstance(ancestor, nodes.If):
            field, _ = ancestor.locate_child(child)
            # Nodes inside the condition itself contribute nothing.
            if field in ("body", "orelse"):
                in_else = field == "orelse"
                matched = set(_match_constraint(expr, ancestor.test, invert=in_else))
                if matched:
                    found[ancestor] = matched
        child = ancestor
    return found
# Every constraint class that _match_constraint tries against a test expression.
ALL_CONSTRAINT_CLASSES = frozenset({NoneConstraint})
"""The constraint types currently supported."""
def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool:
    """Tell whether two nodes denote the same reference or constant.

    Name nodes match on their name. Attribute nodes match on attrname, with
    their owner expressions compared recursively. Const nodes match on
    value equality. Any other combination does not match.
    """
    if isinstance(node1, nodes.Const):
        return isinstance(node2, nodes.Const) and node1.value == node2.value
    if isinstance(node1, nodes.Attribute):
        return (
            isinstance(node2, nodes.Attribute)
            and node1.attrname == node2.attrname
            and _matches(node1.expr, node2.expr)
        )
    if isinstance(node1, nodes.Name):
        return isinstance(node2, nodes.Name) and node1.name == node2.name
    return False
def _match_constraint(
    node: _NameNodes, expr: nodes.NodeNG, invert: bool = False
) -> Iterator[Constraint]:
    """Lazily produce every constraint on node that expr gives rise to.

    Each class in ALL_CONSTRAINT_CLASSES is asked to match expr, with invert
    forwarded as its negate flag. Falsy results (no match) are skipped.
    """
    candidates = (
        kind.match(node, expr, invert) for kind in ALL_CONSTRAINT_CLASSES
    )
    yield from filter(None, candidates)