blob: 6a6afc1b485d5fd0f6ddc0b05234d2b54628a10e [file] [log] [blame]
# Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
# For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt
"""Tests for pylint.pyreverse.utils."""
from __future__ import annotations
from typing import Any
from unittest.mock import patch
import astroid
import pytest
from astroid import nodes
from pylint.pyreverse.utils import (
get_annotation,
get_annotation_label,
get_visibility,
infer_node,
)
@pytest.mark.parametrize(
    "names, expected",
    [
        (["__reduce_ex__", "__setattr__"], "special"),
        (["__g_", "____dsf", "__23_9"], "private"),
        (["simple"], "public"),
        (
            ["_", "__", "___", "____", "_____", "___e__", "_nextsimple", "_filter_it_"],
            "protected",
        ),
    ],
)
def test_get_visibility(names: list[str], expected: str) -> None:
    """Every name in ``names`` must be classified with the ``expected`` visibility."""
    for candidate in names:
        visibility = get_visibility(candidate)
        assert (
            visibility == expected
        ), f"got {visibility} instead of {expected} for value {candidate}"
@pytest.mark.parametrize(
    "assign, label",
    [
        ("a: str = None", "Optional[str]"),
        ("a: str = 'mystr'", "str"),
        ("a: Optional[str] = 'str'", "Optional[str]"),
        ("a: Optional[str] = None", "Optional[str]"),
    ],
)
def test_get_annotation_annassign(assign: str, label: str) -> None:
    """AnnAssign: the annotation found from the assigned value has the expected label.

    Per the cases above, a ``None`` value on a plain ``str`` annotation is
    expected to be reported as ``Optional[str]``.
    """
    node = astroid.extract_node(assign)
    # Narrow the node type before touching ``node.value`` so an unexpected parse
    # result fails with a clear assertion instead of an AttributeError.
    assert isinstance(node, nodes.AnnAssign)
    annotation = get_annotation(node.value)
    assert annotation is not None
    got = annotation.name
    assert got == label, f"got {got} instead of {label} for value {node}"
@pytest.mark.parametrize(
    "init_method, label",
    [
        ("def __init__(self, x: str): self.x = x", "str"),
        ("def __init__(self, x: str = 'str'): self.x = x", "str"),
        ("def __init__(self, x: str = None): self.x = x", "Optional[str]"),
        ("def __init__(self, x: Optional[str]): self.x = x", "Optional[str]"),
        ("def __init__(self, x: Optional[str] = None): self.x = x", "Optional[str]"),
        ("def __init__(self, x: Optional[str] = 'str'): self.x = x", "Optional[str]"),
    ],
)
def test_get_annotation_assignattr(init_method: str, label: str) -> None:
    """AssignAttr: an instance attribute assigned from an annotated ``__init__``
    argument reports that argument's annotation label.
    """
    assign = rf"""
        class A:
            {init_method}
    """
    node = astroid.extract_node(assign)
    assert isinstance(node, nodes.ClassDef)
    instance_attrs = node.instance_attrs
    # Guard against a vacuous pass: without this, the loops below would check
    # nothing if no instance attribute were collected from ``__init__``.
    assert instance_attrs, f"no instance attributes found for value {node}"
    for assign_attrs in instance_attrs.values():
        for assign_attr in assign_attrs:
            assert isinstance(assign_attr, nodes.AssignAttr)
            annotation = get_annotation(assign_attr)
            assert annotation is not None
            got = annotation.name
            assert got == label, f"got {got} instead of {label} for value {node}"
@pytest.mark.parametrize(
    "node_text, expected_label",
    [
        ("def f() -> None: pass", "None"),
        ("def f() -> int: pass", "int"),
        ("def f(a) -> Optional[int]: return 1 if a else None", "Optional[int]"),
        ("def f() -> 'MyType': pass", "'MyType'"),
    ],
)
def test_get_annotation_label_of_return_type(
    node_text: str, expected_label: str
) -> None:
    """The label built from a function's return annotation matches the expected text."""
    function_def = astroid.extract_node(node_text)
    assert isinstance(function_def, nodes.FunctionDef)
    returns_annotation = function_def.returns
    assert get_annotation_label(returns_annotation) == expected_label
@patch("pylint.pyreverse.utils.get_annotation")
@patch("astroid.nodes.NodeNG.infer", side_effect=astroid.InferenceError)
def test_infer_node_1(mock_infer: Any, mock_get_annotation: Any) -> None:
    """Return set() when astroid.InferenceError is raised and an annotation has
    not been returned.
    """
    # ``@patch`` decorators are applied bottom-up, so the first mock argument is
    # the ``NodeNG.infer`` patch and the second is the ``get_annotation`` patch.
    mock_get_annotation.return_value = None
    node = astroid.extract_node("a: str = 'mystr'")
    # No ``mock_infer.return_value`` is set: ``side_effect`` raises
    # ``InferenceError`` on every call, so a return value would never be used.
    assert infer_node(node) == set()
    assert mock_infer.called
@patch("pylint.pyreverse.utils.get_annotation")
@patch("astroid.nodes.NodeNG.infer")
def test_infer_node_2(mock_infer: Any, mock_get_annotation: Any) -> None:
    """Return set(node.infer()) when InferenceError is not raised and an
    annotation has not been returned.
    """
    mock_get_annotation.return_value = None
    node = astroid.extract_node("a: str = 'mystr'")
    # The result is built as ``set(node.infer())``, so mock ``infer`` with an
    # iterable of results. A bare string only worked by coincidence: iterating a
    # one-character string yields that same character.
    mock_infer.return_value = ["x"]
    assert infer_node(node) == {"x"}
    assert mock_infer.called
def test_infer_node_3() -> None:
    """Return a set containing a nodes.ClassDef object when the attribute
    has a type annotation.
    """
    node = astroid.extract_node(
        """
        class Component:
            pass

        class Composite:
            def __init__(self, component: Component):
                self.component = component
        """
    )
    # Index directly so a missing attribute fails with a KeyError naming it,
    # rather than a confusing ``None[0]`` TypeError from ``dict.get``.
    instance_attr = node.instance_attrs["component"][0]
    inferred = infer_node(instance_attr)
    assert isinstance(inferred, set)
    # Exactly one result is expected: the class named in the annotation.
    assert len(inferred) == 1
    inferred_class = inferred.pop()
    assert isinstance(inferred_class, nodes.ClassDef)
    assert inferred_class.name == "Component"
def test_infer_node_4() -> None:
    """Verify the label for an argument with a typehint of the type
    nodes.Subscript.
    """
    node = astroid.extract_node(
        """
        class MyClass:
            def __init__(self, my_int: Optional[int] = None):
                self.my_test_int = my_int
        """
    )
    # Index directly so a missing attribute fails with a KeyError naming it.
    instance_attr = node.instance_attrs["my_test_int"][0]
    assert isinstance(instance_attr, nodes.AssignAttr)
    inferred_set = infer_node(instance_attr)
    # Exactly one node is expected (the ``Optional[int]`` annotation checked
    # below); ``pop()`` on a larger set would pick an arbitrary element.
    assert len(inferred_set) == 1
    inferred = inferred_set.pop()
    assert isinstance(inferred, nodes.Subscript)
    assert inferred.name == "Optional[int]"