| from __future__ import annotations |
| |
| from enum import Enum |
| |
| from mypy import checker, errorcodes |
| from mypy.messages import MessageBuilder |
| from mypy.nodes import ( |
| AssertStmt, |
| AssignmentExpr, |
| AssignmentStmt, |
| BreakStmt, |
| ClassDef, |
| Context, |
| ContinueStmt, |
| DictionaryComprehension, |
| Expression, |
| ExpressionStmt, |
| ForStmt, |
| FuncDef, |
| FuncItem, |
| GeneratorExpr, |
| GlobalDecl, |
| IfStmt, |
| Import, |
| ImportFrom, |
| LambdaExpr, |
| ListExpr, |
| Lvalue, |
| MatchStmt, |
| MypyFile, |
| NameExpr, |
| NonlocalDecl, |
| RaiseStmt, |
| ReturnStmt, |
| StarExpr, |
| SymbolTable, |
| TryStmt, |
| TupleExpr, |
| WhileStmt, |
| WithStmt, |
| implicit_module_attrs, |
| ) |
| from mypy.options import Options |
| from mypy.patterns import AsPattern, StarredPattern |
| from mypy.reachability import ALWAYS_TRUE, infer_pattern_value |
| from mypy.traverser import ExtendedTraverserVisitor |
| from mypy.types import Type, UninhabitedType |
| |
| |
| class BranchState: |
| """BranchState contains information about variable definition at the end of a branching statement. |
| `if` and `match` are examples of branching statements. |
| |
| `may_be_defined` contains variables that were defined in only some branches. |
| `must_be_defined` contains variables that were defined in all branches. |
| """ |
| |
| def __init__( |
| self, |
| must_be_defined: set[str] | None = None, |
| may_be_defined: set[str] | None = None, |
| skipped: bool = False, |
| ) -> None: |
| if may_be_defined is None: |
| may_be_defined = set() |
| if must_be_defined is None: |
| must_be_defined = set() |
| |
| self.may_be_defined = set(may_be_defined) |
| self.must_be_defined = set(must_be_defined) |
| self.skipped = skipped |
| |
| def copy(self) -> BranchState: |
| return BranchState( |
| must_be_defined=set(self.must_be_defined), |
| may_be_defined=set(self.may_be_defined), |
| skipped=self.skipped, |
| ) |
| |
| |
| class BranchStatement: |
| def __init__(self, initial_state: BranchState | None = None) -> None: |
| if initial_state is None: |
| initial_state = BranchState() |
| self.initial_state = initial_state |
| self.branches: list[BranchState] = [ |
| BranchState( |
| must_be_defined=self.initial_state.must_be_defined, |
| may_be_defined=self.initial_state.may_be_defined, |
| ) |
| ] |
| |
| def copy(self) -> BranchStatement: |
| result = BranchStatement(self.initial_state) |
| result.branches = [b.copy() for b in self.branches] |
| return result |
| |
| def next_branch(self) -> None: |
| self.branches.append( |
| BranchState( |
| must_be_defined=self.initial_state.must_be_defined, |
| may_be_defined=self.initial_state.may_be_defined, |
| ) |
| ) |
| |
| def record_definition(self, name: str) -> None: |
| assert len(self.branches) > 0 |
| self.branches[-1].must_be_defined.add(name) |
| self.branches[-1].may_be_defined.discard(name) |
| |
| def delete_var(self, name: str) -> None: |
| assert len(self.branches) > 0 |
| self.branches[-1].must_be_defined.discard(name) |
| self.branches[-1].may_be_defined.discard(name) |
| |
| def record_nested_branch(self, state: BranchState) -> None: |
| assert len(self.branches) > 0 |
| current_branch = self.branches[-1] |
| if state.skipped: |
| current_branch.skipped = True |
| return |
| current_branch.must_be_defined.update(state.must_be_defined) |
| current_branch.may_be_defined.update(state.may_be_defined) |
| current_branch.may_be_defined.difference_update(current_branch.must_be_defined) |
| |
| def skip_branch(self) -> None: |
| assert len(self.branches) > 0 |
| self.branches[-1].skipped = True |
| |
| def is_possibly_undefined(self, name: str) -> bool: |
| assert len(self.branches) > 0 |
| return name in self.branches[-1].may_be_defined |
| |
| def is_undefined(self, name: str) -> bool: |
| assert len(self.branches) > 0 |
| branch = self.branches[-1] |
| return name not in branch.may_be_defined and name not in branch.must_be_defined |
| |
| def is_defined_in_a_branch(self, name: str) -> bool: |
| assert len(self.branches) > 0 |
| for b in self.branches: |
| if name in b.must_be_defined or name in b.may_be_defined: |
| return True |
| return False |
| |
| def done(self) -> BranchState: |
| # First, compute all vars, including skipped branches. We include skipped branches |
| # because our goal is to capture all variables that semantic analyzer would |
| # consider defined. |
| all_vars = set() |
| for b in self.branches: |
| all_vars.update(b.may_be_defined) |
| all_vars.update(b.must_be_defined) |
| # For the rest of the things, we only care about branches that weren't skipped. |
| non_skipped_branches = [b for b in self.branches if not b.skipped] |
| if non_skipped_branches: |
| must_be_defined = non_skipped_branches[0].must_be_defined |
| for b in non_skipped_branches[1:]: |
| must_be_defined.intersection_update(b.must_be_defined) |
| else: |
| must_be_defined = set() |
| # Everything that wasn't defined in all branches but was defined |
| # in at least one branch should be in `may_be_defined`! |
| may_be_defined = all_vars.difference(must_be_defined) |
| return BranchState( |
| must_be_defined=must_be_defined, |
| may_be_defined=may_be_defined, |
| skipped=len(non_skipped_branches) == 0, |
| ) |
| |
| |
| class ScopeType(Enum): |
| Global = 1 |
| Class = 2 |
| Func = 3 |
| Generator = 4 |
| |
| |
| class Scope: |
| def __init__(self, stmts: list[BranchStatement], scope_type: ScopeType) -> None: |
| self.branch_stmts: list[BranchStatement] = stmts |
| self.scope_type = scope_type |
| self.undefined_refs: dict[str, set[NameExpr]] = {} |
| |
| def copy(self) -> Scope: |
| result = Scope([s.copy() for s in self.branch_stmts], self.scope_type) |
| result.undefined_refs = self.undefined_refs.copy() |
| return result |
| |
| def record_undefined_ref(self, o: NameExpr) -> None: |
| if o.name not in self.undefined_refs: |
| self.undefined_refs[o.name] = set() |
| self.undefined_refs[o.name].add(o) |
| |
| def pop_undefined_ref(self, name: str) -> set[NameExpr]: |
| return self.undefined_refs.pop(name, set()) |
| |
| |
| class DefinedVariableTracker: |
| """DefinedVariableTracker manages the state and scope for the UndefinedVariablesVisitor.""" |
| |
| def __init__(self) -> None: |
| # There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement. |
| self.scopes: list[Scope] = [Scope([BranchStatement()], ScopeType.Global)] |
| # disable_branch_skip is used to disable skipping a branch due to a return/raise/etc. This is useful |
| # in things like try/except/finally statements. |
| self.disable_branch_skip = False |
| |
| def copy(self) -> DefinedVariableTracker: |
| result = DefinedVariableTracker() |
| result.scopes = [s.copy() for s in self.scopes] |
| result.disable_branch_skip = self.disable_branch_skip |
| return result |
| |
| def _scope(self) -> Scope: |
| assert len(self.scopes) > 0 |
| return self.scopes[-1] |
| |
| def enter_scope(self, scope_type: ScopeType) -> None: |
| assert len(self._scope().branch_stmts) > 0 |
| initial_state = None |
| if scope_type == ScopeType.Generator: |
| # Generators are special because they inherit the outer scope. |
| initial_state = self._scope().branch_stmts[-1].branches[-1] |
| self.scopes.append(Scope([BranchStatement(initial_state)], scope_type)) |
| |
| def exit_scope(self) -> None: |
| self.scopes.pop() |
| |
| def in_scope(self, scope_type: ScopeType) -> bool: |
| return self._scope().scope_type == scope_type |
| |
| def start_branch_statement(self) -> None: |
| assert len(self._scope().branch_stmts) > 0 |
| self._scope().branch_stmts.append( |
| BranchStatement(self._scope().branch_stmts[-1].branches[-1]) |
| ) |
| |
| def next_branch(self) -> None: |
| assert len(self._scope().branch_stmts) > 1 |
| self._scope().branch_stmts[-1].next_branch() |
| |
| def end_branch_statement(self) -> None: |
| assert len(self._scope().branch_stmts) > 1 |
| result = self._scope().branch_stmts.pop().done() |
| self._scope().branch_stmts[-1].record_nested_branch(result) |
| |
| def skip_branch(self) -> None: |
| # Only skip branch if we're outside of "root" branch statement. |
| if len(self._scope().branch_stmts) > 1 and not self.disable_branch_skip: |
| self._scope().branch_stmts[-1].skip_branch() |
| |
| def record_definition(self, name: str) -> None: |
| assert len(self.scopes) > 0 |
| assert len(self.scopes[-1].branch_stmts) > 0 |
| self._scope().branch_stmts[-1].record_definition(name) |
| |
| def delete_var(self, name: str) -> None: |
| assert len(self.scopes) > 0 |
| assert len(self.scopes[-1].branch_stmts) > 0 |
| self._scope().branch_stmts[-1].delete_var(name) |
| |
| def record_undefined_ref(self, o: NameExpr) -> None: |
| """Records an undefined reference. These can later be retrieved via `pop_undefined_ref`.""" |
| assert len(self.scopes) > 0 |
| self._scope().record_undefined_ref(o) |
| |
| def pop_undefined_ref(self, name: str) -> set[NameExpr]: |
| """If name has previously been reported as undefined, the NameExpr that was called will be returned.""" |
| assert len(self.scopes) > 0 |
| return self._scope().pop_undefined_ref(name) |
| |
| def is_possibly_undefined(self, name: str) -> bool: |
| assert len(self._scope().branch_stmts) > 0 |
| # A variable is undefined if it's in a set of `may_be_defined` but not in `must_be_defined`. |
| return self._scope().branch_stmts[-1].is_possibly_undefined(name) |
| |
| def is_defined_in_different_branch(self, name: str) -> bool: |
| """This will return true if a variable is defined in a branch that's not the current branch.""" |
| assert len(self._scope().branch_stmts) > 0 |
| stmt = self._scope().branch_stmts[-1] |
| if not stmt.is_undefined(name): |
| return False |
| for stmt in self._scope().branch_stmts: |
| if stmt.is_defined_in_a_branch(name): |
| return True |
| return False |
| |
| def is_undefined(self, name: str) -> bool: |
| assert len(self._scope().branch_stmts) > 0 |
| return self._scope().branch_stmts[-1].is_undefined(name) |
| |
| |
| class Loop: |
| def __init__(self) -> None: |
| self.has_break = False |
| |
| |
| class PossiblyUndefinedVariableVisitor(ExtendedTraverserVisitor): |
| """Detects the following cases: |
| - A variable that's defined only part of the time. |
| - If a variable is used before definition |
| |
| An example of a partial definition: |
| if foo(): |
| x = 1 |
| print(x) # Error: "x" may be undefined. |
| |
| Example of a used before definition: |
| x = y |
| y: int = 2 |
| |
| Note that this code does not detect variables not defined in any of the branches -- that is |
| handled by the semantic analyzer. |
| """ |
| |
| def __init__( |
| self, |
| msg: MessageBuilder, |
| type_map: dict[Expression, Type], |
| options: Options, |
| names: SymbolTable, |
| ) -> None: |
| self.msg = msg |
| self.type_map = type_map |
| self.options = options |
| self.builtins = SymbolTable() |
| builtins_mod = names.get("__builtins__", None) |
| if builtins_mod: |
| assert isinstance(builtins_mod.node, MypyFile) |
| self.builtins = builtins_mod.node.names |
| self.loops: list[Loop] = [] |
| self.try_depth = 0 |
| self.tracker = DefinedVariableTracker() |
| for name in implicit_module_attrs: |
| self.tracker.record_definition(name) |
| |
| def var_used_before_def(self, name: str, context: Context) -> None: |
| if self.msg.errors.is_error_code_enabled(errorcodes.USED_BEFORE_DEF): |
| self.msg.var_used_before_def(name, context) |
| |
| def variable_may_be_undefined(self, name: str, context: Context) -> None: |
| if self.msg.errors.is_error_code_enabled(errorcodes.POSSIBLY_UNDEFINED): |
| self.msg.variable_may_be_undefined(name, context) |
| |
| def process_definition(self, name: str) -> None: |
| # Was this name previously used? If yes, it's a used-before-definition error. |
| if not self.tracker.in_scope(ScopeType.Class): |
| refs = self.tracker.pop_undefined_ref(name) |
| for ref in refs: |
| if self.loops: |
| self.variable_may_be_undefined(name, ref) |
| else: |
| self.var_used_before_def(name, ref) |
| else: |
| # Errors in class scopes are caught by the semantic analyzer. |
| pass |
| self.tracker.record_definition(name) |
| |
| def visit_global_decl(self, o: GlobalDecl) -> None: |
| for name in o.names: |
| self.process_definition(name) |
| super().visit_global_decl(o) |
| |
| def visit_nonlocal_decl(self, o: NonlocalDecl) -> None: |
| for name in o.names: |
| self.process_definition(name) |
| super().visit_nonlocal_decl(o) |
| |
| def process_lvalue(self, lvalue: Lvalue | None) -> None: |
| if isinstance(lvalue, NameExpr): |
| self.process_definition(lvalue.name) |
| elif isinstance(lvalue, StarExpr): |
| self.process_lvalue(lvalue.expr) |
| elif isinstance(lvalue, (ListExpr, TupleExpr)): |
| for item in lvalue.items: |
| self.process_lvalue(item) |
| |
| def visit_assignment_stmt(self, o: AssignmentStmt) -> None: |
| for lvalue in o.lvalues: |
| self.process_lvalue(lvalue) |
| super().visit_assignment_stmt(o) |
| |
| def visit_assignment_expr(self, o: AssignmentExpr) -> None: |
| o.value.accept(self) |
| self.process_lvalue(o.target) |
| |
| def visit_if_stmt(self, o: IfStmt) -> None: |
| for e in o.expr: |
| e.accept(self) |
| self.tracker.start_branch_statement() |
| for b in o.body: |
| if b.is_unreachable: |
| continue |
| b.accept(self) |
| self.tracker.next_branch() |
| if o.else_body: |
| if not o.else_body.is_unreachable: |
| o.else_body.accept(self) |
| else: |
| self.tracker.skip_branch() |
| self.tracker.end_branch_statement() |
| |
| def visit_match_stmt(self, o: MatchStmt) -> None: |
| o.subject.accept(self) |
| self.tracker.start_branch_statement() |
| for i in range(len(o.patterns)): |
| pattern = o.patterns[i] |
| pattern.accept(self) |
| guard = o.guards[i] |
| if guard is not None: |
| guard.accept(self) |
| if not o.bodies[i].is_unreachable: |
| o.bodies[i].accept(self) |
| else: |
| self.tracker.skip_branch() |
| is_catchall = infer_pattern_value(pattern) == ALWAYS_TRUE |
| if not is_catchall: |
| self.tracker.next_branch() |
| self.tracker.end_branch_statement() |
| |
| def visit_func_def(self, o: FuncDef) -> None: |
| self.process_definition(o.name) |
| super().visit_func_def(o) |
| |
| def visit_func(self, o: FuncItem) -> None: |
| if o.is_dynamic() and not self.options.check_untyped_defs: |
| return |
| |
| args = o.arguments or [] |
| # Process initializers (defaults) outside the function scope. |
| for arg in args: |
| if arg.initializer is not None: |
| arg.initializer.accept(self) |
| |
| self.tracker.enter_scope(ScopeType.Func) |
| for arg in args: |
| self.process_definition(arg.variable.name) |
| super().visit_var(arg.variable) |
| o.body.accept(self) |
| self.tracker.exit_scope() |
| |
| def visit_generator_expr(self, o: GeneratorExpr) -> None: |
| self.tracker.enter_scope(ScopeType.Generator) |
| for idx in o.indices: |
| self.process_lvalue(idx) |
| super().visit_generator_expr(o) |
| self.tracker.exit_scope() |
| |
| def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None: |
| self.tracker.enter_scope(ScopeType.Generator) |
| for idx in o.indices: |
| self.process_lvalue(idx) |
| super().visit_dictionary_comprehension(o) |
| self.tracker.exit_scope() |
| |
| def visit_for_stmt(self, o: ForStmt) -> None: |
| o.expr.accept(self) |
| self.process_lvalue(o.index) |
| o.index.accept(self) |
| self.tracker.start_branch_statement() |
| loop = Loop() |
| self.loops.append(loop) |
| o.body.accept(self) |
| self.tracker.next_branch() |
| self.tracker.end_branch_statement() |
| if o.else_body is not None: |
| # If the loop has a `break` inside, `else` is executed conditionally. |
| # If the loop doesn't have a `break` either the function will return or |
| # execute the `else`. |
| has_break = loop.has_break |
| if has_break: |
| self.tracker.start_branch_statement() |
| self.tracker.next_branch() |
| o.else_body.accept(self) |
| if has_break: |
| self.tracker.end_branch_statement() |
| self.loops.pop() |
| |
| def visit_return_stmt(self, o: ReturnStmt) -> None: |
| super().visit_return_stmt(o) |
| self.tracker.skip_branch() |
| |
| def visit_lambda_expr(self, o: LambdaExpr) -> None: |
| self.tracker.enter_scope(ScopeType.Func) |
| super().visit_lambda_expr(o) |
| self.tracker.exit_scope() |
| |
| def visit_assert_stmt(self, o: AssertStmt) -> None: |
| super().visit_assert_stmt(o) |
| if checker.is_false_literal(o.expr): |
| self.tracker.skip_branch() |
| |
| def visit_raise_stmt(self, o: RaiseStmt) -> None: |
| super().visit_raise_stmt(o) |
| self.tracker.skip_branch() |
| |
| def visit_continue_stmt(self, o: ContinueStmt) -> None: |
| super().visit_continue_stmt(o) |
| self.tracker.skip_branch() |
| |
| def visit_break_stmt(self, o: BreakStmt) -> None: |
| super().visit_break_stmt(o) |
| if self.loops: |
| self.loops[-1].has_break = True |
| self.tracker.skip_branch() |
| |
| def visit_expression_stmt(self, o: ExpressionStmt) -> None: |
| if isinstance(self.type_map.get(o.expr, None), UninhabitedType): |
| self.tracker.skip_branch() |
| super().visit_expression_stmt(o) |
| |
| def visit_try_stmt(self, o: TryStmt) -> None: |
| """ |
| Note that finding undefined vars in `finally` requires different handling from |
| the rest of the code. In particular, we want to disallow skipping branches due to jump |
| statements in except/else clauses for finally but not for other cases. Imagine a case like: |
| def f() -> int: |
| try: |
| x = 1 |
| except: |
| # This jump statement needs to be handled differently depending on whether or |
| # not we're trying to process `finally` or not. |
| return 0 |
| finally: |
| # `x` may be undefined here. |
| pass |
| # `x` is always defined here. |
| return x |
| """ |
| self.try_depth += 1 |
| if o.finally_body is not None: |
| # In order to find undefined vars in `finally`, we need to |
| # process try/except with branch skipping disabled. However, for the rest of the code |
| # after finally, we need to process try/except with branch skipping enabled. |
| # Therefore, we need to process try/finally twice. |
| # Because processing is not idempotent, we should make a copy of the tracker. |
| old_tracker = self.tracker.copy() |
| self.tracker.disable_branch_skip = True |
| self.process_try_stmt(o) |
| self.tracker = old_tracker |
| self.process_try_stmt(o) |
| self.try_depth -= 1 |
| |
| def process_try_stmt(self, o: TryStmt) -> None: |
| """ |
| Processes try statement decomposing it into the following: |
| if ...: |
| body |
| else_body |
| elif ...: |
| except 1 |
| elif ...: |
| except 2 |
| else: |
| except n |
| finally |
| """ |
| self.tracker.start_branch_statement() |
| o.body.accept(self) |
| if o.else_body is not None: |
| o.else_body.accept(self) |
| if len(o.handlers) > 0: |
| assert len(o.handlers) == len(o.vars) == len(o.types) |
| for i in range(len(o.handlers)): |
| self.tracker.next_branch() |
| exc_type = o.types[i] |
| if exc_type is not None: |
| exc_type.accept(self) |
| var = o.vars[i] |
| if var is not None: |
| self.process_definition(var.name) |
| var.accept(self) |
| o.handlers[i].accept(self) |
| if var is not None: |
| self.tracker.delete_var(var.name) |
| self.tracker.end_branch_statement() |
| |
| if o.finally_body is not None: |
| o.finally_body.accept(self) |
| |
| def visit_while_stmt(self, o: WhileStmt) -> None: |
| o.expr.accept(self) |
| self.tracker.start_branch_statement() |
| loop = Loop() |
| self.loops.append(loop) |
| o.body.accept(self) |
| has_break = loop.has_break |
| if not checker.is_true_literal(o.expr): |
| # If this is a loop like `while True`, we can consider the body to be |
| # a single branch statement (we're guaranteed that the body is executed at least once). |
| # If not, call next_branch() to make all variables defined there conditional. |
| self.tracker.next_branch() |
| self.tracker.end_branch_statement() |
| if o.else_body is not None: |
| # If the loop has a `break` inside, `else` is executed conditionally. |
| # If the loop doesn't have a `break` either the function will return or |
| # execute the `else`. |
| if has_break: |
| self.tracker.start_branch_statement() |
| self.tracker.next_branch() |
| if o.else_body: |
| o.else_body.accept(self) |
| if has_break: |
| self.tracker.end_branch_statement() |
| self.loops.pop() |
| |
| def visit_as_pattern(self, o: AsPattern) -> None: |
| if o.name is not None: |
| self.process_lvalue(o.name) |
| super().visit_as_pattern(o) |
| |
| def visit_starred_pattern(self, o: StarredPattern) -> None: |
| if o.capture is not None: |
| self.process_lvalue(o.capture) |
| super().visit_starred_pattern(o) |
| |
| def visit_name_expr(self, o: NameExpr) -> None: |
| if o.name in self.builtins and self.tracker.in_scope(ScopeType.Global): |
| return |
| if self.tracker.is_possibly_undefined(o.name): |
| # A variable is only defined in some branches. |
| self.variable_may_be_undefined(o.name, o) |
| # We don't want to report the error on the same variable multiple times. |
| self.tracker.record_definition(o.name) |
| elif self.tracker.is_defined_in_different_branch(o.name): |
| # A variable is defined in one branch but used in a different branch. |
| if self.loops or self.try_depth > 0: |
| # If we're in a loop or in a try, we can't be sure that this variable |
| # is undefined. Report it as "may be undefined". |
| self.variable_may_be_undefined(o.name, o) |
| else: |
| self.var_used_before_def(o.name, o) |
| elif self.tracker.is_undefined(o.name): |
| # A variable is undefined. It could be due to two things: |
| # 1. A variable is just totally undefined |
| # 2. The variable is defined later in the code. |
| # Case (1) will be caught by semantic analyzer. Case (2) is a forward ref that should |
| # be caught by this visitor. Save the ref for later, so that if we see a definition, |
| # we know it's a used-before-definition scenario. |
| self.tracker.record_undefined_ref(o) |
| super().visit_name_expr(o) |
| |
| def visit_with_stmt(self, o: WithStmt) -> None: |
| for expr, idx in zip(o.expr, o.target): |
| expr.accept(self) |
| self.process_lvalue(idx) |
| o.body.accept(self) |
| |
| def visit_class_def(self, o: ClassDef) -> None: |
| self.process_definition(o.name) |
| self.tracker.enter_scope(ScopeType.Class) |
| super().visit_class_def(o) |
| self.tracker.exit_scope() |
| |
| def visit_import(self, o: Import) -> None: |
| for mod, alias in o.ids: |
| if alias is not None: |
| self.tracker.record_definition(alias) |
| else: |
| # When you do `import x.y`, only `x` becomes defined. |
| names = mod.split(".") |
| if names: |
| # `names` should always be nonempty, but we don't want mypy |
| # to crash on invalid code. |
| self.tracker.record_definition(names[0]) |
| super().visit_import(o) |
| |
| def visit_import_from(self, o: ImportFrom) -> None: |
| for mod, alias in o.names: |
| name = alias |
| if name is None: |
| name = mod |
| self.tracker.record_definition(name) |
| super().visit_import_from(o) |