| # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html |
| # For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE |
| # Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt |
| |
| """Classes representing different types of constraints on inference values.""" |
| from __future__ import annotations |
| |
| import sys |
| from abc import ABC, abstractmethod |
| from collections.abc import Iterator |
| from typing import TYPE_CHECKING, Union |
| |
| from astroid import nodes, util |
| from astroid.typing import InferenceResult |
| |
| if sys.version_info >= (3, 11): |
| from typing import Self |
| else: |
| from typing_extensions import Self |
| |
| if TYPE_CHECKING: |
| from astroid import bases |
| |
# Node types that a constraint can be attached to: the ``node`` argument of
# ``Constraint.match`` and the ``expr`` argument of ``get_constraints``.
# NOTE(review): spelled with ``Union`` rather than ``X | Y`` presumably because
# this alias is evaluated at runtime and the file still supports Python < 3.10.
_NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name]
| |
| |
class Constraint(ABC):
    """Abstract base for a single constraint on the value of a variable.

    A concrete subclass knows how to recognise its pattern in a test
    expression (``match``) and how to decide whether an inferred value is
    compatible with the constraint (``satisfied_by``).
    """

    def __init__(self, node: nodes.NodeNG, negate: bool) -> None:
        self.negate = negate
        """Whether the constraint is inverted, e.g. "is not" rather than "is"."""
        self.node = node
        """The node the constraint is attached to."""

    @classmethod
    @abstractmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Build a constraint on ``node`` from ``expr`` when ``expr`` fits this
        constraint's pattern, otherwise return None.

        When ``negate`` is True the resulting constraint is inverted.
        """

    @abstractmethod
    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Tell whether the inferred value ``inferred`` meets this constraint."""
| |
| |
class NoneConstraint(Constraint):
    """Represents an "is None" or "is not None" constraint."""

    # Shared ``None`` literal that candidate expressions and inferred values
    # are compared against via ``_matches``.
    CONST_NONE: nodes.Const = nodes.Const(None)

    @classmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Return a new constraint for node matched from expr, if expr matches
        the constraint pattern.

        Recognised patterns are a single ``is`` / ``is not`` comparison between
        ``node`` and ``None``, in either operand order: ``x is None`` and
        ``None is x`` are equivalent because identity comparison is symmetric.

        The constraint is negated for ``is not``, and its sense is flipped once
        more when ``negate`` is True. For example, ``negate=True`` combined with
        ``x is not None`` yields a plain "is None" constraint.

        Returns None when ``expr`` does not fit the pattern.
        """
        if not isinstance(expr, nodes.Compare) or len(expr.ops) != 1:
            return None

        op, right = expr.ops[0]
        if op not in {"is", "is not"}:
            return None

        left = expr.left
        forward = _matches(left, node) and _matches(right, cls.CONST_NONE)
        # ``None is x`` / ``None is not x``: same meaning as the forward form.
        reverse = _matches(left, cls.CONST_NONE) and _matches(right, node)
        if not (forward or reverse):
            return None

        # "is not" inverts the test; the caller's ``negate`` inverts it again
        # (``get_constraints`` passes ``invert=True`` for the ``else`` branch).
        return cls(node=node, negate=negate ^ (op == "is not"))

    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Return True if this constraint is satisfied by the given inferred value.

        Uninferable values are treated as satisfying the constraint, so they
        are never filtered out.
        """
        if isinstance(inferred, util.UninferableBase):
            return True

        # "is None" holds when the value is None; "is not None" when it is not.
        return self.negate ^ _matches(inferred, self.CONST_NONE)
| |
| |
def get_constraints(
    expr: _NameNodes, frame: nodes.LocalsDictNodeNG
) -> dict[nodes.If, set[Constraint]]:
    """Collect the constraints that apply to ``expr`` from enclosing ``if`` statements.

    The tree is walked upwards from ``expr`` until ``frame`` (or the root) is
    reached. For every ``If`` ancestor whose ``body`` contains ``expr``, the
    constraints matched from its test are recorded. When ``expr`` sits in the
    ``orelse`` branch, the inverted constraints are recorded instead. An
    ``If`` that yields no constraint is left out of the result.

    Returns a mapping from each contributing ``If`` node to its set of
    constraints. Only ``if`` conditions are currently considered.
    """
    mapping: dict[nodes.If, set[Constraint]] = {}
    child: nodes.NodeNG | None = expr

    while child is not None and child is not frame:
        ancestor = child.parent
        if isinstance(ancestor, nodes.If):
            branch_name, _ = ancestor.locate_child(child)
            # Nodes inside the ``test`` itself produce no constraint.
            if branch_name in ("body", "orelse"):
                invert = branch_name == "orelse"
                found = set(_match_constraint(expr, ancestor.test, invert=invert))
                if found:
                    mapping[ancestor] = found
        child = ancestor

    return mapping
| |
| |
ALL_CONSTRAINT_CLASSES = frozenset({NoneConstraint})
"""Every constraint type tried by ``_match_constraint`` (iteration order unspecified)."""
| |
| |
def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool:
    """Tell whether two nodes denote the same name, attribute chain or constant.

    - ``Name`` nodes match when their identifiers are equal.
    - ``Attribute`` nodes match when their attribute names are equal and their
      receiver expressions match recursively (so ``self.a.b`` matches
      ``self.a.b``).
    - ``Const`` nodes match when their values compare equal with ``==``.

    Any other combination, including mismatched node types, does not match.
    """
    if isinstance(node1, nodes.Name):
        if isinstance(node2, nodes.Name):
            return node1.name == node2.name
    elif isinstance(node1, nodes.Attribute):
        if isinstance(node2, nodes.Attribute):
            if node1.attrname != node2.attrname:
                return False
            return _matches(node1.expr, node2.expr)
    elif isinstance(node1, nodes.Const):
        if isinstance(node2, nodes.Const):
            # NOTE(review): ``==`` equates e.g. 1 and True; harmless while the
            # only constant compared against is None.
            return node1.value == node2.value
    return False
| |
| |
def _match_constraint(
    node: _NameNodes, expr: nodes.NodeNG, invert: bool = False
) -> Iterator[Constraint]:
    """Lazily yield each constraint on ``node`` that some registered type matches in ``expr``.

    Every class in ``ALL_CONSTRAINT_CLASSES`` is asked to ``match``, with
    ``invert`` forwarded as its ``negate`` flag. Falsy results (no match) are
    dropped.
    """
    candidates = (
        constraint_cls.match(node, expr, invert)
        for constraint_cls in ALL_CONSTRAINT_CLASSES
    )
    yield from filter(None, candidates)