| # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html |
| # For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE |
| # Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt |
| |
| """Generic classes/functions for pyreverse core/extensions.""" |
| |
| from __future__ import annotations |
| |
| import os |
| import re |
| import shutil |
| import subprocess |
| import sys |
| from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union |
| |
| import astroid |
| from astroid import nodes |
| from astroid.typing import InferenceResult |
| |
if TYPE_CHECKING:
    from pylint.pyreverse.diagrams import ClassDiagram, PackageDiagram

    # The aliases below subscript ClassDiagram / PackageDiagram, which are only
    # imported for type checkers. They must therefore stay inside this guard;
    # every use of them is an annotation, which is lazy thanks to
    # ``from __future__ import annotations``.

    # Signature of a visit_* / leave_* handler looked up by LocalsVisitor.
    _CallbackT = Callable[
        [nodes.NodeNG],
        Union[Tuple[ClassDiagram], Tuple[PackageDiagram, ClassDiagram], None],
    ]
    # (enter handler, leave handler) pair; either may be missing (None).
    _CallbackTupleT = Tuple[Optional[_CallbackT], Optional[_CallbackT]]


# Name of the per-user options file looked up in the home directory.
RCFILE = ".pyreverserc"
| |
| |
def get_default_options() -> list[str]:
    """Return the whitespace-separated options stored in the user's rc file.

    The file ``RCFILE`` is looked up in the directory named by the ``HOME``
    environment variable. An unset or empty ``HOME``, or a file that cannot be
    opened or read, yields an empty list.
    """
    home_dir = os.environ.get("HOME", "")
    if not home_dir:
        return []
    config_path = os.path.join(home_dir, RCFILE)
    try:
        with open(config_path, encoding="utf-8") as stream:
            return stream.read().split()
    except OSError:
        # A missing or unreadable rc file simply means "no default options".
        return []
| |
| |
def insert_default_options() -> None:
    """Splice the rc-file options into ``sys.argv`` right after index 0.

    The options keep the order returned by ``get_default_options`` and end up
    in front of the arguments given on the real command line.
    """
    for option in reversed(get_default_options()):
        # Inserting back-to-front at position 1 preserves the original order.
        sys.argv.insert(1, option)
| |
| |
| # astroid utilities ########################################################### |
# Dunder names such as ``__init__``.
SPECIAL = re.compile(r"^__([^\W_]_*)+__$")
# Name-mangled names: two leading underscores, at most one trailing one.
PRIVATE = re.compile(r"^__(_*[^\W_])+_?$")
# Any other name starting with an underscore.
PROTECTED = re.compile(r"^_\w*$")


def get_visibility(name: str) -> str:
    """Classify ``name`` as "special", "private", "protected" or "public".

    Patterns are tried in that order and the first match wins, so a dunder
    name is reported as special even though it also starts with underscores.
    """
    ordered_rules = (
        (SPECIAL, "special"),
        (PRIVATE, "private"),
        (PROTECTED, "protected"),
    )
    for pattern, label in ordered_rules:
        if pattern.match(name):
            return label
    return "public"
| |
| |
def is_exception(node: nodes.ClassDef) -> bool:
    """Tell whether ``node`` is an exception class.

    Kept for backward compatibility: it only compares the class node's
    ``type`` attribute with the ``"exception"`` marker.
    """
    class_kind = node.type
    return class_kind == "exception"  # type: ignore[no-any-return]
| |
| |
| # Helpers ##################################################################### |
| |
# One bit per visibility that a filter mode can hide.
_SPECIAL = 2
_PROTECTED = 4
_PRIVATE = 8
# Filter mode name -> bitmask of visibilities to hide ("ALL" hides nothing).
MODES = {
    "ALL": 0,
    "PUB_ONLY": _SPECIAL | _PROTECTED | _PRIVATE,
    "SPECIAL": _SPECIAL,
    "OTHER": _PROTECTED | _PRIVATE,
}
# Visibility label (as produced by get_visibility) -> its bit; public has none.
VIS_MOD = {
    "special": _SPECIAL,
    "protected": _PROTECTED,
    "private": _PRIVATE,
    "public": 0,
}
| |
| |
class FilterMixIn:
    """Filter nodes according to a mode and nodes' visibility.

    ``mode`` is a ``+``-separated list of keys of ``MODES`` (for example
    ``"PUB_ONLY"`` or ``"SPECIAL+OTHER"``). Each key contributes a bitmask of
    visibilities to hide, and :meth:`show_attr` rejects any name whose
    visibility bit is set in the combined mask.
    """

    def __init__(self, mode: str) -> None:
        """Combine the requested filter modes into a single bitmask.

        Unknown mode names are reported on stderr and otherwise ignored.
        """
        mode_mask = 0
        for nummod in mode.split("+"):
            try:
                # Bitwise OR, not addition: overlapping modes such as
                # "PUB_ONLY+SPECIAL" must not carry into unrelated bits.
                mode_mask |= MODES[nummod]
            except KeyError as ex:
                print(f"Unknown filter mode {ex}", file=sys.stderr)
        self.__mode = mode_mask

    def show_attr(self, node: nodes.NodeNG | str) -> bool:
        """Return true if the node should be treated.

        ``node`` may be a node carrying a ``name`` attribute or a bare name
        string; that name's visibility decides whether it is filtered out.
        """
        visibility = get_visibility(getattr(node, "name", node))
        return not self.__mode & VIS_MOD[visibility]
| |
| |
class LocalsVisitor:
    """Walk a project by following each node's ``locals`` mapping.

    On entering a node, the handler ``visit_<classname>`` is called, and on
    leaving it ``leave_<classname>``, where ``<classname>`` is the node's
    class name in lower case. ``visit_default`` / ``leave_default`` serve as
    fall-backs when defined. Each node is processed at most once per visitor.
    """

    def __init__(self) -> None:
        # (enter, leave) handler pair memoized per node class.
        self._cache: dict[type[nodes.NodeNG], _CallbackTupleT] = {}
        # Nodes already processed; prevents visiting the same node twice.
        self._visited: set[nodes.NodeNG] = set()

    def get_callbacks(self, node: nodes.NodeNG) -> _CallbackTupleT:
        """Return the ``(enter, leave)`` handlers for ``node``'s class."""
        node_class = node.__class__
        cached = self._cache.get(node_class)
        if cached is not None:
            return cached
        suffix = node_class.__name__.lower()
        enter = getattr(
            self, f"visit_{suffix}", getattr(self, "visit_default", None)
        )
        leave = getattr(
            self, f"leave_{suffix}", getattr(self, "leave_default", None)
        )
        self._cache[node_class] = (enter, leave)
        return enter, leave

    def visit(self, node: nodes.NodeNG) -> Any:
        """Traverse ``node`` and, recursively, the values of its locals.

        Returns the leave handler's result, or ``None`` when there is no leave
        handler or the node has already been visited.
        """
        if node in self._visited:
            return None
        self._visited.add(node)

        enter, leave = self.get_callbacks(node)
        if enter is not None:
            enter(node)
        # Only nodes exposing ``locals`` are descended into (this skips
        # Instance and other proxies).
        if hasattr(node, "locals"):
            for child in node.values():
                self.visit(child)
        return None if leave is None else leave(node)
| |
| |
def get_annotation_label(ann: nodes.Name | nodes.NodeNG) -> str:
    """Return the text used to display the annotation ``ann``.

    A ``Name`` node whose ``name`` is set gives that name directly, and any
    other astroid node is rendered with ``as_string()``. Values that are not
    astroid nodes (such as ``None``) give an empty string.
    """
    if isinstance(ann, nodes.Name) and ann.name is not None:
        return ann.name  # type: ignore[no-any-return]
    if not isinstance(ann, nodes.NodeNG):
        return ""
    return ann.as_string()  # type: ignore[no-any-return]
| |
| |
def get_annotation(
    node: nodes.AssignAttr | nodes.AssignName,
) -> nodes.Name | nodes.Subscript | None:
    """Return the annotation for `node`.

    The annotation is looked up in one of two places:

    * if ``node``'s parent is an ``AnnAssign`` (``x: T = ...``), that
      statement's annotation is used;
    * for an ``AssignAttr`` (``self.x = y``), the annotation of the enclosing
      function's parameter whose name equals the assigned value ``y`` is used.

    Any other ``AssignName`` returns ``None`` immediately.

    If the first value inferred for ``node`` is ``None`` and the annotation
    does not already allow ``None``, the label is wrapped as
    ``Optional[...]``. "Already allows" means the label starts with
    "Optional", or the annotation is a ``BinOp`` with a ``None`` constant
    operand, such as ``X | None``.

    Side effect: when a label is computed, it is written onto the returned
    annotation node's ``name`` attribute, mutating the AST node in place.

    NOTE(review): the declared return type is narrower than reality. Any
    annotation node can come back (e.g. a ``BinOp`` for ``X | None``), so
    confirm that callers cope with that.
    """
    ann = None
    if isinstance(node.parent, nodes.AnnAssign):
        ann = node.parent.annotation
    elif isinstance(node, nodes.AssignAttr):
        # ``self.x = y`` inside a method: ``node.parent`` is the assignment and
        # ``node.parent.parent`` is expected to be the function defining ``y``.
        init_method = node.parent.parent
        try:
            # NOTE(review): zip pairs local names with parameter annotations
            # positionally. This assumes the parameters are the first locals,
            # in declaration order.
            annotations = dict(zip(init_method.locals, init_method.args.annotations))
            ann = annotations.get(node.parent.value.name)
        except AttributeError:
            # The structure did not match (e.g. the assigned value has no
            # ``name``, or the grandparent has no ``args``); keep ann as None.
            pass
    else:
        # Plain AssignName without an annotation: nothing to report.
        return ann

    try:
        # Only the first inferred value is inspected.
        # NOTE(review): an empty inference result would raise ValueError here;
        # presumably astroid raises InferenceError instead. Confirm this.
        default, *_ = node.infer()
    except astroid.InferenceError:
        # A string has no ``value`` attribute, so no Optional wrapping happens.
        default = ""

    label = get_annotation_label(ann)

    if (
        ann
        # The inferred value is None. Objects without ``value`` fall back to
        # the string "value", which is never None.
        and getattr(default, "value", "value") is None
        and not label.startswith("Optional")
        # Skip ``X | None`` style annotations that already include None.
        and (
            not isinstance(ann, nodes.BinOp)
            or not any(
                isinstance(child, nodes.Const) and child.value is None
                for child in ann.get_children()
            )
        )
    ):
        label = rf"Optional[{label}]"

    if label and ann:
        # Mutates the annotation node so later consumers see the final label.
        ann.name = label
    return ann
| |
| |
def infer_node(node: nodes.AssignAttr | nodes.AssignName) -> set[InferenceResult]:
    """Return the possible types of ``node`` as a set.

    When ``node`` has an annotation, it takes precedence:

    * ``Subscript`` annotations and ``|`` unions are returned as-is;
    * any other annotation is inferred.

    Without an annotation, ``node`` itself is inferred. If inference fails,
    the set holds just the annotation, or is empty when there is none.
    """
    ann = get_annotation(node)
    try:
        if not ann:
            return set(node.infer())
        keep_as_is = isinstance(ann, nodes.Subscript) or (
            isinstance(ann, nodes.BinOp) and ann.op == "|"
        )
        if keep_as_is:
            return {ann}
        return set(ann.infer())
    except astroid.InferenceError:
        return {ann} if ann else set()
| |
| |
def check_graphviz_availability() -> None:
    """Exit with status 32 unless the Graphviz ``dot`` command is on PATH.

    Image output formats rely on ``dot`` to convert the generated *.dot or
    *.gv file into the final output format, so there is no point going on
    without it.
    """
    dot_path = shutil.which("dot")
    if dot_path is not None:
        return
    print("'Graphviz' needs to be installed for your chosen output format.")
    sys.exit(32)
| |
| |
def check_if_graphviz_supports_format(output_format: str) -> None:
    """Check if the ``dot`` command supports the requested output format.

    This is needed if image output is desired and ``dot`` is used to convert
    from *.gv into the final output format.

    ``dot -T?`` is run, and the whitespace-separated list following
    "Use one of:" in its stderr is taken as the supported formats.

    If ``dot`` cannot be executed or its output cannot be parsed, a notice is
    printed and the check is skipped. If ``output_format`` is not in the
    parsed list, the process exits with status 32.
    """
    match = None
    try:
        dot_output = subprocess.run(
            ["dot", "-T?"],
            capture_output=True,
            check=False,
            encoding="utf-8",
            # Tolerate non-UTF-8 bytes (e.g. in plugin warnings) instead of
            # crashing with UnicodeDecodeError.
            errors="replace",
        )
    except OSError:
        # ``dot`` is missing or not executable. Fall through to the "unable to
        # determine" notice and let later steps report the real problem.
        pass
    else:
        # re.search + DOTALL: dot may print warnings (e.g. failed plugin
        # loads) on lines before the "Use one of:" list.
        match = re.search(
            r"Use one of:\s*(?P<formats>.*)",
            dot_output.stderr.strip(),
            re.DOTALL,
        )
    if not match:
        print(
            "Unable to determine Graphviz supported output formats. "
            "Pyreverse will continue, but subsequent error messages "
            "regarding the output format may come from Graphviz directly."
        )
        return
    supported_formats = match.group("formats")
    if output_format not in supported_formats.split():
        print(
            f"Format {output_format} is not supported by Graphviz. It supports: {supported_formats}"
        )
        sys.exit(32)