| """Data-flow analyses.""" |
| |
| from __future__ import annotations |
| |
| from abc import abstractmethod |
| from typing import Dict, Generic, Iterable, Iterator, Set, Tuple, TypeVar |
| |
| from mypyc.ir.func_ir import all_values |
| from mypyc.ir.ops import ( |
| Assign, |
| AssignMulti, |
| BasicBlock, |
| Box, |
| Branch, |
| Call, |
| CallC, |
| Cast, |
| ComparisonOp, |
| ControlOp, |
| Extend, |
| Float, |
| FloatComparisonOp, |
| FloatNeg, |
| FloatOp, |
| GetAttr, |
| GetElementPtr, |
| Goto, |
| InitStatic, |
| Integer, |
| IntOp, |
| KeepAlive, |
| LoadAddress, |
| LoadErrorValue, |
| LoadGlobal, |
| LoadLiteral, |
| LoadMem, |
| LoadStatic, |
| MethodCall, |
| Op, |
| OpVisitor, |
| RaiseStandardError, |
| RegisterOp, |
| Return, |
| SetAttr, |
| SetMem, |
| Truncate, |
| TupleGet, |
| TupleSet, |
| Unbox, |
| Unreachable, |
| Value, |
| ) |
| |
| |
| class CFG: |
| """Control-flow graph. |
| |
| Node 0 is always assumed to be the entry point. There must be a |
| non-empty set of exits. |
| """ |
| |
| def __init__( |
| self, |
| succ: dict[BasicBlock, list[BasicBlock]], |
| pred: dict[BasicBlock, list[BasicBlock]], |
| exits: set[BasicBlock], |
| ) -> None: |
| assert exits |
| self.succ = succ |
| self.pred = pred |
| self.exits = exits |
| |
| def __str__(self) -> str: |
| lines = [] |
| lines.append("exits: %s" % sorted(self.exits, key=lambda e: int(e.label))) |
| lines.append("succ: %s" % self.succ) |
| lines.append("pred: %s" % self.pred) |
| return "\n".join(lines) |
| |
| |
| def get_cfg(blocks: list[BasicBlock]) -> CFG: |
| """Calculate basic block control-flow graph. |
| |
| The result is a dictionary like this: |
| |
| basic block index -> (successors blocks, predecesssor blocks) |
| """ |
| succ_map = {} |
| pred_map: dict[BasicBlock, list[BasicBlock]] = {} |
| exits = set() |
| for block in blocks: |
| |
| assert not any( |
| isinstance(op, ControlOp) for op in block.ops[:-1] |
| ), "Control-flow ops must be at the end of blocks" |
| |
| succ = list(block.terminator.targets()) |
| if not succ: |
| exits.add(block) |
| |
| # Errors can occur anywhere inside a block, which means that |
| # we can't assume that the entire block has executed before |
| # jumping to the error handler. In our CFG construction, we |
| # model this as saying that a block can jump to its error |
| # handler or the error handlers of any of its normal |
| # successors (to represent an error before that next block |
| # completes). This works well for analyses like "must |
| # defined", where it implies that registers assigned in a |
| # block may be undefined in its error handler, but is in |
| # general not a precise representation of reality; any |
| # analyses that require more fidelity must wait until after |
| # exception insertion. |
| for error_point in [block] + succ: |
| if error_point.error_handler: |
| succ.append(error_point.error_handler) |
| |
| succ_map[block] = succ |
| pred_map[block] = [] |
| for prev, nxt in succ_map.items(): |
| for label in nxt: |
| pred_map[label].append(prev) |
| return CFG(succ_map, pred_map, exits) |
| |
| |
| def get_real_target(label: BasicBlock) -> BasicBlock: |
| if len(label.ops) == 1 and isinstance(label.ops[-1], Goto): |
| label = label.ops[-1].label |
| return label |
| |
| |
| def cleanup_cfg(blocks: list[BasicBlock]) -> None: |
| """Cleanup the control flow graph. |
| |
| This eliminates obviously dead basic blocks and eliminates blocks that contain |
| nothing but a single jump. |
| |
| There is a lot more that could be done. |
| """ |
| changed = True |
| while changed: |
| # First collapse any jumps to basic block that only contain a goto |
| for block in blocks: |
| for i, tgt in enumerate(block.terminator.targets()): |
| block.terminator.set_target(i, get_real_target(tgt)) |
| |
| # Then delete any blocks that have no predecessors |
| changed = False |
| cfg = get_cfg(blocks) |
| orig_blocks = blocks[:] |
| blocks.clear() |
| for i, block in enumerate(orig_blocks): |
| if i == 0 or cfg.pred[block]: |
| blocks.append(block) |
| else: |
| changed = True |
| |
| |
| T = TypeVar("T") |
| |
| AnalysisDict = Dict[Tuple[BasicBlock, int], Set[T]] |
| |
| |
| class AnalysisResult(Generic[T]): |
| def __init__(self, before: AnalysisDict[T], after: AnalysisDict[T]) -> None: |
| self.before = before |
| self.after = after |
| |
| def __str__(self) -> str: |
| return f"before: {self.before}\nafter: {self.after}\n" |
| |
| |
| GenAndKill = Tuple[Set[T], Set[T]] |
| |
| |
| class BaseAnalysisVisitor(OpVisitor[GenAndKill[T]]): |
| def visit_goto(self, op: Goto) -> GenAndKill[T]: |
| return set(), set() |
| |
| @abstractmethod |
| def visit_register_op(self, op: RegisterOp) -> GenAndKill[T]: |
| raise NotImplementedError |
| |
| @abstractmethod |
| def visit_assign(self, op: Assign) -> GenAndKill[T]: |
| raise NotImplementedError |
| |
| @abstractmethod |
| def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[T]: |
| raise NotImplementedError |
| |
| @abstractmethod |
| def visit_set_mem(self, op: SetMem) -> GenAndKill[T]: |
| raise NotImplementedError |
| |
| def visit_call(self, op: Call) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_method_call(self, op: MethodCall) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_load_error_value(self, op: LoadErrorValue) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_load_literal(self, op: LoadLiteral) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_get_attr(self, op: GetAttr) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_set_attr(self, op: SetAttr) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_load_static(self, op: LoadStatic) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_init_static(self, op: InitStatic) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_tuple_get(self, op: TupleGet) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_tuple_set(self, op: TupleSet) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_box(self, op: Box) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_unbox(self, op: Unbox) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_cast(self, op: Cast) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_call_c(self, op: CallC) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_truncate(self, op: Truncate) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_extend(self, op: Extend) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_load_global(self, op: LoadGlobal) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_int_op(self, op: IntOp) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_float_op(self, op: FloatOp) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_float_neg(self, op: FloatNeg) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_comparison_op(self, op: ComparisonOp) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_float_comparison_op(self, op: FloatComparisonOp) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_load_mem(self, op: LoadMem) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_load_address(self, op: LoadAddress) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| def visit_keep_alive(self, op: KeepAlive) -> GenAndKill[T]: |
| return self.visit_register_op(op) |
| |
| |
| class DefinedVisitor(BaseAnalysisVisitor[Value]): |
| """Visitor for finding defined registers. |
| |
| Note that this only deals with registers and not temporaries, on |
| the assumption that we never access temporaries when they might be |
| undefined. |
| |
| If strict_errors is True, then we regard any use of LoadErrorValue |
| as making a register undefined. Otherwise we only do if |
| `undefines` is set on the error value. |
| |
| This lets us only consider the things we care about during |
| uninitialized variable checking while capturing all possibly |
| undefined things for refcounting. |
| """ |
| |
| def __init__(self, strict_errors: bool = False) -> None: |
| self.strict_errors = strict_errors |
| |
| def visit_branch(self, op: Branch) -> GenAndKill[Value]: |
| return set(), set() |
| |
| def visit_return(self, op: Return) -> GenAndKill[Value]: |
| return set(), set() |
| |
| def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]: |
| return set(), set() |
| |
| def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]: |
| return set(), set() |
| |
| def visit_assign(self, op: Assign) -> GenAndKill[Value]: |
| # Loading an error value may undefine the register. |
| if isinstance(op.src, LoadErrorValue) and (op.src.undefines or self.strict_errors): |
| return set(), {op.dest} |
| else: |
| return {op.dest}, set() |
| |
| def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]: |
| # Array registers are special and we don't track the definedness of them. |
| return set(), set() |
| |
| def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]: |
| return set(), set() |
| |
| |
| def analyze_maybe_defined_regs( |
| blocks: list[BasicBlock], cfg: CFG, initial_defined: set[Value] |
| ) -> AnalysisResult[Value]: |
| """Calculate potentially defined registers at each CFG location. |
| |
| A register is defined if it has a value along some path from the initial location. |
| """ |
| return run_analysis( |
| blocks=blocks, |
| cfg=cfg, |
| gen_and_kill=DefinedVisitor(), |
| initial=initial_defined, |
| backward=False, |
| kind=MAYBE_ANALYSIS, |
| ) |
| |
| |
| def analyze_must_defined_regs( |
| blocks: list[BasicBlock], |
| cfg: CFG, |
| initial_defined: set[Value], |
| regs: Iterable[Value], |
| strict_errors: bool = False, |
| ) -> AnalysisResult[Value]: |
| """Calculate always defined registers at each CFG location. |
| |
| This analysis can work before exception insertion, since it is a |
| sound assumption that registers defined in a block might not be |
| initialized in its error handler. |
| |
| A register is defined if it has a value along all paths from the |
| initial location. |
| """ |
| return run_analysis( |
| blocks=blocks, |
| cfg=cfg, |
| gen_and_kill=DefinedVisitor(strict_errors=strict_errors), |
| initial=initial_defined, |
| backward=False, |
| kind=MUST_ANALYSIS, |
| universe=set(regs), |
| ) |
| |
| |
| class BorrowedArgumentsVisitor(BaseAnalysisVisitor[Value]): |
| def __init__(self, args: set[Value]) -> None: |
| self.args = args |
| |
| def visit_branch(self, op: Branch) -> GenAndKill[Value]: |
| return set(), set() |
| |
| def visit_return(self, op: Return) -> GenAndKill[Value]: |
| return set(), set() |
| |
| def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]: |
| return set(), set() |
| |
| def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]: |
| return set(), set() |
| |
| def visit_assign(self, op: Assign) -> GenAndKill[Value]: |
| if op.dest in self.args: |
| return set(), {op.dest} |
| return set(), set() |
| |
| def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]: |
| return set(), set() |
| |
| def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]: |
| return set(), set() |
| |
| |
| def analyze_borrowed_arguments( |
| blocks: list[BasicBlock], cfg: CFG, borrowed: set[Value] |
| ) -> AnalysisResult[Value]: |
| """Calculate arguments that can use references borrowed from the caller. |
| |
| When assigning to an argument, it no longer is borrowed. |
| """ |
| return run_analysis( |
| blocks=blocks, |
| cfg=cfg, |
| gen_and_kill=BorrowedArgumentsVisitor(borrowed), |
| initial=borrowed, |
| backward=False, |
| kind=MUST_ANALYSIS, |
| universe=borrowed, |
| ) |
| |
| |
| class UndefinedVisitor(BaseAnalysisVisitor[Value]): |
| def visit_branch(self, op: Branch) -> GenAndKill[Value]: |
| return set(), set() |
| |
| def visit_return(self, op: Return) -> GenAndKill[Value]: |
| return set(), set() |
| |
| def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]: |
| return set(), set() |
| |
| def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]: |
| return set(), {op} if not op.is_void else set() |
| |
| def visit_assign(self, op: Assign) -> GenAndKill[Value]: |
| return set(), {op.dest} |
| |
| def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]: |
| return set(), {op.dest} |
| |
| def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]: |
| return set(), set() |
| |
| |
| def analyze_undefined_regs( |
| blocks: list[BasicBlock], cfg: CFG, initial_defined: set[Value] |
| ) -> AnalysisResult[Value]: |
| """Calculate potentially undefined registers at each CFG location. |
| |
| A register is undefined if there is some path from initial block |
| where it has an undefined value. |
| |
| Function arguments are assumed to be always defined. |
| """ |
| initial_undefined = set(all_values([], blocks)) - initial_defined |
| return run_analysis( |
| blocks=blocks, |
| cfg=cfg, |
| gen_and_kill=UndefinedVisitor(), |
| initial=initial_undefined, |
| backward=False, |
| kind=MAYBE_ANALYSIS, |
| ) |
| |
| |
| def non_trivial_sources(op: Op) -> set[Value]: |
| result = set() |
| for source in op.sources(): |
| if not isinstance(source, (Integer, Float)): |
| result.add(source) |
| return result |
| |
| |
| class LivenessVisitor(BaseAnalysisVisitor[Value]): |
| def visit_branch(self, op: Branch) -> GenAndKill[Value]: |
| return non_trivial_sources(op), set() |
| |
| def visit_return(self, op: Return) -> GenAndKill[Value]: |
| if not isinstance(op.value, (Integer, Float)): |
| return {op.value}, set() |
| else: |
| return set(), set() |
| |
| def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]: |
| return set(), set() |
| |
| def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]: |
| gen = non_trivial_sources(op) |
| if not op.is_void: |
| return gen, {op} |
| else: |
| return gen, set() |
| |
| def visit_assign(self, op: Assign) -> GenAndKill[Value]: |
| return non_trivial_sources(op), {op.dest} |
| |
| def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]: |
| return non_trivial_sources(op), {op.dest} |
| |
| def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]: |
| return non_trivial_sources(op), set() |
| |
| |
| def analyze_live_regs(blocks: list[BasicBlock], cfg: CFG) -> AnalysisResult[Value]: |
| """Calculate live registers at each CFG location. |
| |
| A register is live at a location if it can be read along some CFG path starting |
| from the location. |
| """ |
| return run_analysis( |
| blocks=blocks, |
| cfg=cfg, |
| gen_and_kill=LivenessVisitor(), |
| initial=set(), |
| backward=True, |
| kind=MAYBE_ANALYSIS, |
| ) |
| |
| |
| # Analysis kinds |
| MUST_ANALYSIS = 0 |
| MAYBE_ANALYSIS = 1 |
| |
| |
| def run_analysis( |
| blocks: list[BasicBlock], |
| cfg: CFG, |
| gen_and_kill: OpVisitor[GenAndKill[T]], |
| initial: set[T], |
| kind: int, |
| backward: bool, |
| universe: set[T] | None = None, |
| ) -> AnalysisResult[T]: |
| """Run a general set-based data flow analysis. |
| |
| Args: |
| blocks: All basic blocks |
| cfg: Control-flow graph for the code |
| gen_and_kill: Implementation of gen and kill functions for each op |
| initial: Value of analysis for the entry points (for a forward analysis) or the |
| exit points (for a backward analysis) |
| kind: MUST_ANALYSIS or MAYBE_ANALYSIS |
| backward: If False, the analysis is a forward analysis; it's backward otherwise |
| universe: For a must analysis, the set of all possible values. This is the starting |
| value for the work list algorithm, which will narrow this down until reaching a |
| fixed point. For a maybe analysis the iteration always starts from an empty set |
| and this argument is ignored. |
| |
| Return analysis results: (before, after) |
| """ |
| block_gen = {} |
| block_kill = {} |
| |
| # Calculate kill and gen sets for entire basic blocks. |
| for block in blocks: |
| gen: set[T] = set() |
| kill: set[T] = set() |
| ops = block.ops |
| if backward: |
| ops = list(reversed(ops)) |
| for op in ops: |
| opgen, opkill = op.accept(gen_and_kill) |
| gen = (gen - opkill) | opgen |
| kill = (kill - opgen) | opkill |
| block_gen[block] = gen |
| block_kill[block] = kill |
| |
| # Set up initial state for worklist algorithm. |
| worklist = list(blocks) |
| if not backward: |
| worklist = worklist[::-1] # Reverse for a small performance improvement |
| workset = set(worklist) |
| before: dict[BasicBlock, set[T]] = {} |
| after: dict[BasicBlock, set[T]] = {} |
| for block in blocks: |
| if kind == MAYBE_ANALYSIS: |
| before[block] = set() |
| after[block] = set() |
| else: |
| assert universe is not None, "Universe must be defined for a must analysis" |
| before[block] = set(universe) |
| after[block] = set(universe) |
| |
| if backward: |
| pred_map = cfg.succ |
| succ_map = cfg.pred |
| else: |
| pred_map = cfg.pred |
| succ_map = cfg.succ |
| |
| # Run work list algorithm to generate in and out sets for each basic block. |
| while worklist: |
| label = worklist.pop() |
| workset.remove(label) |
| if pred_map[label]: |
| new_before: set[T] | None = None |
| for pred in pred_map[label]: |
| if new_before is None: |
| new_before = set(after[pred]) |
| elif kind == MAYBE_ANALYSIS: |
| new_before |= after[pred] |
| else: |
| new_before &= after[pred] |
| assert new_before is not None |
| else: |
| new_before = set(initial) |
| before[label] = new_before |
| new_after = (new_before - block_kill[label]) | block_gen[label] |
| if new_after != after[label]: |
| for succ in succ_map[label]: |
| if succ not in workset: |
| worklist.append(succ) |
| workset.add(succ) |
| after[label] = new_after |
| |
| # Run algorithm for each basic block to generate opcode-level sets. |
| op_before: dict[tuple[BasicBlock, int], set[T]] = {} |
| op_after: dict[tuple[BasicBlock, int], set[T]] = {} |
| for block in blocks: |
| label = block |
| cur = before[label] |
| ops_enum: Iterator[tuple[int, Op]] = enumerate(block.ops) |
| if backward: |
| ops_enum = reversed(list(ops_enum)) |
| for idx, op in ops_enum: |
| op_before[label, idx] = cur |
| opgen, opkill = op.accept(gen_and_kill) |
| cur = (cur - opkill) | opgen |
| op_after[label, idx] = cur |
| if backward: |
| op_after, op_before = op_before, op_after |
| |
| return AnalysisResult(op_before, op_after) |