| """Utilities related to determining the reachability of code (in semantic analysis).""" |
| |
| from typing import Tuple, TypeVar, Union, Optional |
| from typing_extensions import Final |
| |
| from mypy.nodes import ( |
| Expression, IfStmt, Block, AssertStmt, NameExpr, UnaryExpr, MemberExpr, OpExpr, ComparisonExpr, |
| StrExpr, UnicodeExpr, CallExpr, IntExpr, TupleExpr, IndexExpr, SliceExpr, Import, ImportFrom, |
| ImportAll, LITERAL_YES |
| ) |
| from mypy.options import Options |
| from mypy.traverser import TraverserVisitor |
| from mypy.literals import literal |
| |
# Possible inferred truth values of a condition expression.  The two MYPY_*
# values describe conditions whose value differs between type checking and
# runtime.
ALWAYS_TRUE: Final = 1
MYPY_TRUE: Final = 2  # True in mypy, False at runtime
ALWAYS_FALSE: Final = 3
MYPY_FALSE: Final = 4  # False in mypy, True at runtime
TRUTH_VALUE_UNKNOWN: Final = 5

# Truth value of ``not <cond>`` given the truth value of ``<cond>``.  A
# mypy-only truth flips to a mypy-only falsehood and vice versa; an unknown
# value stays unknown.
inverted_truth_mapping: Final = {
    ALWAYS_TRUE: ALWAYS_FALSE,
    ALWAYS_FALSE: ALWAYS_TRUE,
    TRUTH_VALUE_UNKNOWN: TRUTH_VALUE_UNKNOWN,
    MYPY_TRUE: MYPY_FALSE,
    MYPY_FALSE: MYPY_TRUE,
}

# Operator that keeps a comparison's meaning when its two operands are
# swapped: ``a <op> b`` is equivalent to ``b <reverse_op[op]> a``.
reverse_op: Final = {
    '==': '==',
    '!=': '!=',
    '<': '>',
    '>': '<',
    '<=': '>=',
    '>=': '<=',
}
| |
| |
def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None:
    """Mark the statically-dead branches of an if/elif/else chain.

    Conditions are examined in order:

    * a condition that is always false (or false under mypy) makes its own
      body unreachable, and evaluation continues with the next condition;
    * the first condition that is always true (or true under mypy) makes
      every later elif body and the else body unreachable, and evaluation
      stops there;
    * an undecidable condition leaves its body alone and evaluation continues.
    """
    for idx, condition in enumerate(s.expr):
        branch = s.body[idx]
        verdict = infer_condition_value(condition, options)
        if verdict in (ALWAYS_FALSE, MYPY_FALSE):
            mark_block_unreachable(branch)
            continue
        if verdict not in (ALWAYS_TRUE, MYPY_TRUE):
            continue

        if verdict == MYPY_TRUE:
            # The branch is taken only while type checking (false at
            # runtime); this will affect import priorities.
            mark_block_mypy_only(branch)
        for later_branch in s.body[idx + 1:]:
            mark_block_unreachable(later_branch)

        # Materialize an else block if there is none and mark it unreachable,
        # so the type checker always knows that every control-flow path goes
        # through this branch.
        if not s.else_body:
            s.else_body = Block([])
        mark_block_unreachable(s.else_body)
        break
| |
| |
def assert_will_always_fail(s: AssertStmt, options: Options) -> bool:
    """Return True if the asserted condition is known to be false.

    Both ALWAYS_FALSE and MYPY_FALSE count, so an assert whose condition is
    false only while type checking also qualifies.
    """
    verdict = infer_condition_value(s.expr, options)
    return verdict == ALWAYS_FALSE or verdict == MYPY_FALSE
| |
| |
def infer_condition_value(expr: Expression, options: Options) -> int:
    """Infer whether the given condition is always true/false.

    Return ALWAYS_TRUE if always true, ALWAYS_FALSE if always false,
    MYPY_TRUE if true under mypy and false at runtime, MYPY_FALSE if
    false under mypy and true at runtime, else TRUTH_VALUE_UNKNOWN.

    Recognized forms:

    * ``sys.version_info`` comparisons and ``sys.platform`` checks;
    * the names ``PY2``, ``PY3``, ``MYPY`` and ``TYPE_CHECKING``, either bare
      or as the final attribute of a member expression;
    * names listed in ``options.always_true`` / ``options.always_false``;
    * ``and`` / ``or`` combinations of the above;
    * any of the above preceded by one or more ``not``.
    """
    pyversion = options.python_version
    name = ''
    # Peel off any number of leading ``not``s; each one flips the answer.
    negated = False
    while isinstance(expr, UnaryExpr) and expr.op == 'not':
        expr = expr.expr
        negated = not negated
    result = TRUTH_VALUE_UNKNOWN
    if isinstance(expr, NameExpr):
        name = expr.name
    elif isinstance(expr, MemberExpr):
        name = expr.name
    elif isinstance(expr, OpExpr) and expr.op in ('and', 'or'):
        left = infer_condition_value(expr.left, options)
        if ((left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == 'and') or
                (left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == 'or')):
            # Either `True and <other>` or `False or <other>`: the result will
            # always be the right-hand-side.
            result = infer_condition_value(expr.right, options)
        else:
            # The result will always be the left-hand-side (e.g. ALWAYS_* or
            # TRUTH_VALUE_UNKNOWN).
            result = left
        # Apply a leading `not` here too: returning the sub-result directly
        # would turn `not (A and B)` into `A and B`.
        return inverted_truth_mapping[result] if negated else result
    else:
        result = consider_sys_version_info(expr, pyversion)
        if result == TRUTH_VALUE_UNKNOWN:
            result = consider_sys_platform(expr, options.platform)
    if result == TRUTH_VALUE_UNKNOWN:
        if name == 'PY2':
            result = ALWAYS_TRUE if pyversion[0] == 2 else ALWAYS_FALSE
        elif name == 'PY3':
            result = ALWAYS_TRUE if pyversion[0] == 3 else ALWAYS_FALSE
        elif name == 'MYPY' or name == 'TYPE_CHECKING':
            result = MYPY_TRUE
        elif name in options.always_true:
            result = ALWAYS_TRUE
        elif name in options.always_false:
            result = ALWAYS_FALSE
    if negated:
        result = inverted_truth_mapping[result]
    return result
| |
| |
def consider_sys_version_info(expr: Expression, pyversion: Tuple[int, ...]) -> int:
    """Consider whether expr is a comparison involving sys.version_info.

    Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.

    Only the first two components of ``pyversion`` (indices 0 and 1) are
    ever consulted.
    """
    # Cases supported:
    # - sys.version_info[<int>] <compare_op> <int>
    # - sys.version_info[:<int>] <compare_op> <tuple_of_n_ints>
    # - sys.version_info <compare_op> <tuple_of_1_or_2_ints>
    # The operands may also appear in the opposite order.
    if not isinstance(expr, ComparisonExpr):
        return TRUTH_VALUE_UNKNOWN
    # Let's not yet support chained comparisons.
    if len(expr.operators) > 1:
        return TRUTH_VALUE_UNKNOWN
    op = expr.operators[0]
    if op not in ('==', '!=', '<=', '>=', '<', '>'):
        return TRUTH_VALUE_UNKNOWN

    index = contains_sys_version_info(expr.operands[0])
    thing = contains_int_or_tuple_of_ints(expr.operands[1])
    if index is None or thing is None:
        # Try `<constant> <op> sys.version_info...` by swapping the operands.
        index = contains_sys_version_info(expr.operands[1])
        thing = contains_int_or_tuple_of_ints(expr.operands[0])
        op = reverse_op[op]
    if isinstance(index, int) and isinstance(thing, int):
        # sys.version_info[i] <compare_op> k
        if 0 <= index <= 1:
            return fixed_comparison(pyversion[index], op, thing)
        else:
            return TRUTH_VALUE_UNKNOWN
    elif isinstance(index, tuple) and isinstance(thing, tuple):
        lo, hi = index
        # An open-ended slice (including bare `sys.version_info`, reported as
        # `[:]`) still carries micro, releaselevel and serial at runtime, so
        # it is strictly longer than the (major, minor) prefix known here.
        # E.g. on 3.8.x, `sys.version_info == (3, 8)` and `<= (3, 8)` are
        # False, and `> (3, 8)` is True.
        open_ended = hi is None
        if lo is None:
            lo = 0
        if hi is None:
            hi = 2
        if 0 <= lo < hi <= 2:
            val = pyversion[lo:hi]
            if len(val) == len(thing) or len(val) > len(thing) and op not in ('==', '!='):
                if open_ended:
                    # Pad with a dummy element so `val` is longer than `thing`
                    # and orders like the runtime value.  The dummy's value is
                    # never compared because len(thing) <= len(val) here.
                    val = val + (0,)
                return fixed_comparison(val, op, thing)
    return TRUTH_VALUE_UNKNOWN
| |
| |
def consider_sys_platform(expr: Expression, platform: str) -> int:
    """Consider whether expr is a comparison involving sys.platform.

    Return ALWAYS_TRUE, ALWAYS_FALSE, or TRUTH_VALUE_UNKNOWN.
    """
    # Cases supported:
    # - sys.platform == 'posix'
    # - sys.platform != 'win32'
    # - 'linux' == sys.platform   (either operand order for == and !=)
    # - sys.platform.startswith('win')
    if isinstance(expr, ComparisonExpr):
        # Let's not yet support chained comparisons.
        if len(expr.operators) > 1:
            return TRUTH_VALUE_UNKNOWN
        op = expr.operators[0]
        if op not in ('==', '!='):
            return TRUTH_VALUE_UNKNOWN
        first, second = expr.operands[0], expr.operands[1]
        # == and != are symmetric, so sys.platform may sit on either side
        # (consider_sys_version_info likewise accepts swapped operands).
        if is_sys_attr(first, 'platform'):
            other = second
        elif is_sys_attr(second, 'platform'):
            other = first
        else:
            return TRUTH_VALUE_UNKNOWN
        if not isinstance(other, (StrExpr, UnicodeExpr)):
            return TRUTH_VALUE_UNKNOWN
        return fixed_comparison(platform, op, other.value)
    elif isinstance(expr, CallExpr):
        if not isinstance(expr.callee, MemberExpr):
            return TRUTH_VALUE_UNKNOWN
        if len(expr.args) != 1 or not isinstance(expr.args[0], (StrExpr, UnicodeExpr)):
            return TRUTH_VALUE_UNKNOWN
        if not is_sys_attr(expr.callee.expr, 'platform'):
            return TRUTH_VALUE_UNKNOWN
        if expr.callee.name != 'startswith':
            return TRUTH_VALUE_UNKNOWN
        if platform.startswith(expr.args[0].value):
            return ALWAYS_TRUE
        else:
            return ALWAYS_FALSE
    else:
        return TRUTH_VALUE_UNKNOWN
| |
| |
Targ = TypeVar('Targ', int, str, Tuple[int, ...])


def fixed_comparison(left: Targ, op: str, right: Targ) -> int:
    """Evaluate ``left <op> right`` for two constant operands.

    Return ALWAYS_TRUE or ALWAYS_FALSE for the six comparison operators
    ``==``, ``!=``, ``<=``, ``>=``, ``<`` and ``>``. Return
    TRUTH_VALUE_UNKNOWN for any other operator string.
    """
    if op == '==':
        outcome = left == right
    elif op == '!=':
        outcome = left != right
    elif op == '<=':
        outcome = left <= right
    elif op == '>=':
        outcome = left >= right
    elif op == '<':
        outcome = left < right
    elif op == '>':
        outcome = left > right
    else:
        return TRUTH_VALUE_UNKNOWN
    return ALWAYS_TRUE if outcome else ALWAYS_FALSE
| |
| |
def contains_int_or_tuple_of_ints(expr: Expression
                                  ) -> Union[None, int, Tuple[int], Tuple[int, ...]]:
    """Return the constant value of an int literal or a tuple of int literals.

    An ``IntExpr`` yields its int value. A literal ``TupleExpr`` whose items
    are all ``IntExpr`` yields a tuple of their values. Anything else yields
    None.
    """
    if isinstance(expr, IntExpr):
        return expr.value
    if not isinstance(expr, TupleExpr) or literal(expr) != LITERAL_YES:
        return None
    values = [item.value for item in expr.items if isinstance(item, IntExpr)]
    if len(values) != len(expr.items):
        # At least one item was not an int literal.
        return None
    return tuple(values)
| |
| |
def contains_sys_version_info(expr: Expression
                              ) -> Union[None, int, Tuple[Optional[int], Optional[int]]]:
    """Describe which part of ``sys.version_info`` *expr* refers to, if any.

    Returns:
      * ``(None, None)`` for bare ``sys.version_info`` (same as ``[:]``);
      * ``i`` for ``sys.version_info[i]`` with an int literal ``i``;
      * ``(begin, end)`` for ``sys.version_info[begin:end]``, where each bound
        is an int literal or omitted (``None``) and any step is the literal 1;
      * ``None`` for anything else.
    """
    if is_sys_attr(expr, 'version_info'):
        return (None, None)
    if not isinstance(expr, IndexExpr) or not is_sys_attr(expr.base, 'version_info'):
        return None
    subscript = expr.index
    if isinstance(subscript, IntExpr):
        return subscript.value
    if not isinstance(subscript, SliceExpr):
        return None
    step = subscript.stride
    if step is not None and (not isinstance(step, IntExpr) or step.value != 1):
        return None
    start = subscript.begin_index
    if start is not None and not isinstance(start, IntExpr):
        return None
    stop = subscript.end_index
    if stop is not None and not isinstance(stop, IntExpr):
        return None
    return (start.value if start is not None else None,
            stop.value if stop is not None else None)
| |
| |
def is_sys_attr(expr: Expression, name: str) -> bool:
    """Return True if *expr* is literally the member expression ``sys.<name>``.

    Only the plain spelling is recognized.
    """
    # TODO: This currently doesn't work with code like this:
    # - import sys as _sys
    # - from sys import version_info
    # TODO: Guard against a local named sys, etc.
    # (Though later passes will still do most checking.)
    return (
        isinstance(expr, MemberExpr)
        and expr.name == name
        and isinstance(expr.expr, NameExpr)
        and expr.expr.name == 'sys'
    )
| |
| |
def mark_block_unreachable(block: Block) -> None:
    """Flag *block* as unreachable, together with every import nested in it."""
    block.is_unreachable = True
    import_marker = MarkImportsUnreachableVisitor()
    block.accept(import_marker)
| |
| |
class MarkImportsUnreachableVisitor(TraverserVisitor):
    """Set ``is_unreachable`` on every import statement nested within a node.

    Only the three import node kinds are overridden; the traversal itself is
    inherited from TraverserVisitor.
    """

    def _flag(self, node: Union[Import, ImportFrom, ImportAll]) -> None:
        # Shared by the three import visit methods below.
        node.is_unreachable = True

    def visit_import(self, node: Import) -> None:
        self._flag(node)

    def visit_import_from(self, node: ImportFrom) -> None:
        self._flag(node)

    def visit_import_all(self, node: ImportAll) -> None:
        self._flag(node)
| |
| |
def mark_block_mypy_only(block: Block) -> None:
    """Flag every import nested in *block* as mypy-only (affects import priority).

    No flag is set on *block* itself.
    """
    import_marker = MarkImportsMypyOnlyVisitor()
    block.accept(import_marker)
| |
| |
class MarkImportsMypyOnlyVisitor(TraverserVisitor):
    """Set ``is_mypy_only`` on every import statement nested within a node.

    The flag affects import priority. Only the three import node kinds are
    overridden; the traversal itself is inherited from TraverserVisitor.
    """

    def _flag(self, node: Union[Import, ImportFrom, ImportAll]) -> None:
        # Shared by the three import visit methods below.
        node.is_mypy_only = True

    def visit_import(self, node: Import) -> None:
        self._flag(node)

    def visit_import_from(self, node: ImportFrom) -> None:
        self._flag(node)

    def visit_import_all(self, node: ImportAll) -> None:
        self._flag(node)