| # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html |
| # For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE |
| # Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt |
| |
| from __future__ import annotations |
| |
| import inspect |
| import random |
| |
| from astroid import nodes |
| from astroid.context import InferenceContext |
| from astroid.exceptions import UseInferenceDefault |
| from astroid.inference_tip import inference_tip |
| from astroid.manager import AstroidManager |
| from astroid.util import UninferableBase, safe_infer |
| |
# Node classes that infer_random_sample accepts as the inferred value of the
# first (population) argument; any other inferred type makes it fall back to
# default inference.
ACCEPTED_ITERABLES_FOR_SAMPLE = (nodes.List, nodes.Set, nodes.Tuple)
| |
| |
def _clone_node_with_lineno(node, parent, lineno):
    """Build a new node of the same class as *node*, placed under *parent*.

    If *node* is a ``nodes.EvaluatedObject``, its ``original`` node is the one
    that gets copied. The copy takes ``lineno`` from the argument and
    ``col_offset``, ``end_lineno`` and ``end_col_offset`` from the source node.

    Each name in the node's ``_other_fields`` is forwarded to the constructor
    when ``__init__`` has a parameter of that name; otherwise it is assigned
    as a plain attribute once the copy exists. The values of the node's
    ``_astroid_fields`` are handed to ``postinit`` as-is (they are not cloned).

    :param node: The node to copy.
    :param parent: The parent to give to the new node.
    :param lineno: The line number to give to the new node.
    :returns: The newly created node.
    """
    if isinstance(node, nodes.EvaluatedObject):
        node = node.original

    node_cls = node.__class__
    extra_fields = node._other_fields
    child_fields = node._astroid_fields

    ctor_kwargs = dict(
        lineno=lineno,
        col_offset=node.col_offset,
        parent=parent,
        end_lineno=node.end_lineno,
        end_col_offset=node.end_col_offset,
    )

    child_values = {}
    for field_name in child_fields:
        child_values[field_name] = getattr(node, field_name)

    # Only pass the extra fields that this class's constructor can take;
    # remember the rest so they can be set on the instance afterwards.
    accepted = set(inspect.signature(node_cls.__init__).parameters)
    deferred = []
    for field_name in extra_fields:
        if field_name in accepted:
            ctor_kwargs[field_name] = getattr(node, field_name)
        else:
            deferred.append(field_name)

    clone = node_cls(**ctor_kwargs)
    if hasattr(node, "postinit") and child_fields:
        clone.postinit(**child_values)

    for field_name in deferred:
        setattr(clone, field_name, getattr(node, field_name))
    return clone
| |
| |
def infer_random_sample(node, context: InferenceContext | None = None):
    """Infer a ``sample(population, k)`` call as a list of sampled elements.

    The call must have exactly two positional arguments. The second must
    infer to a ``nodes.Const`` holding an ``int``; the first must infer to
    one of ``ACCEPTED_ITERABLES_FOR_SAMPLE`` whose elements contain no
    ``UninferableBase`` and number at least ``k``.

    :param node: The call node being inferred.
    :param context: Optional inference context forwarded to ``safe_infer``.
    :returns: An iterator yielding a single ``nodes.List`` whose elements are
        clones of the randomly chosen elements, positioned at the call's line.
    :raises UseInferenceDefault: If any of the conditions above is not met,
        or if ``random.sample`` raises ``ValueError``.
    """
    if len(node.args) != 2:
        raise UseInferenceDefault

    population_arg, count_arg = node.args

    sample_size = safe_infer(count_arg, context=context)
    if not (
        isinstance(sample_size, nodes.Const) and isinstance(sample_size.value, int)
    ):
        raise UseInferenceDefault

    population = safe_infer(population_arg, context=context)
    if not population or not isinstance(population, ACCEPTED_ITERABLES_FOR_SAMPLE):
        raise UseInferenceDefault

    candidates = population.elts
    if sample_size.value > len(candidates):
        # Asking for more items than available would be a ValueError at runtime.
        raise UseInferenceDefault

    for candidate in candidates:
        if isinstance(candidate, UninferableBase):
            raise UseInferenceDefault

    try:
        picked = random.sample(candidates, sample_size.value)
    except ValueError as exc:
        raise UseInferenceDefault from exc

    result = nodes.List(
        lineno=node.lineno,
        col_offset=node.col_offset,
        parent=node.scope(),
        end_lineno=node.end_lineno,
        end_col_offset=node.end_col_offset,
    )
    # Clone each chosen element so it is parented to the new list and carries
    # the list's line number.
    clones = []
    for chosen in picked:
        clones.append(
            _clone_node_with_lineno(chosen, parent=result, lineno=result.lineno)
        )
    result.postinit(clones)
    return iter((result,))
| |
| |
def _looks_like_random_sample(node) -> bool:
    """Tell whether *node* calls something named ``sample``.

    The callee may be an attribute access (``<expr>.sample(...)``) or a bare
    name (``sample(...)``); only the name itself is compared. Any other kind
    of callee yields ``False``.
    """
    callee = node.func
    if isinstance(callee, nodes.Attribute):
        called_name = callee.attrname
    elif isinstance(callee, nodes.Name):
        called_name = callee.name
    else:
        return False
    return called_name == "sample"
| |
| |
def register(manager: AstroidManager) -> None:
    """Register the ``sample`` inference tip on *manager*.

    The tip wraps ``infer_random_sample`` and is registered for ``nodes.Call``
    nodes, with ``_looks_like_random_sample`` as the predicate.
    """
    sample_tip = inference_tip(infer_random_sample)
    manager.register_transform(nodes.Call, sample_tip, _looks_like_random_sample)