| """Bool register elimination optimization. |
| |
| Example input: |
| |
| L1: |
| r0 = f() |
| b = r0 |
| goto L3 |
| L2: |
| r1 = g() |
| b = r1 |
| goto L3 |
| L3: |
| if b goto L4 else goto L5 |
| |
| The register b is redundant and we replace the assignments with two copies of |
| the branch in L3: |
| |
| L1: |
| r0 = f() |
| if r0 goto L4 else goto L5 |
| L2: |
| r1 = g() |
| if r1 goto L4 else goto L5 |
| |
| This helps generate simpler IR for tagged integers comparisons, for example. |
| """ |
| |
| from __future__ import annotations |
| |
| from mypyc.ir.func_ir import FuncIR |
| from mypyc.ir.ops import Assign, BasicBlock, Branch, Goto, Register, Unreachable |
| from mypyc.irbuild.ll_builder import LowLevelIRBuilder |
| from mypyc.options import CompilerOptions |
| from mypyc.transform.ir_transform import IRTransform |
| |
| |
| def do_flag_elimination(fn: FuncIR, options: CompilerOptions) -> None: |
| # Find registers that are used exactly once as source, and in a branch. |
| counts: dict[Register, int] = {} |
| branches: dict[Register, Branch] = {} |
| labels: dict[Register, BasicBlock] = {} |
| for block in fn.blocks: |
| for i, op in enumerate(block.ops): |
| for src in op.sources(): |
| if isinstance(src, Register): |
| counts[src] = counts.get(src, 0) + 1 |
| if i == 0 and isinstance(op, Branch) and isinstance(op.value, Register): |
| branches[op.value] = op |
| labels[op.value] = block |
| |
| # Based on these we can find the candidate registers. |
| candidates: set[Register] = { |
| r for r in branches if counts.get(r, 0) == 1 and r not in fn.arg_regs |
| } |
| |
| # Remove candidates with invalid assignments. |
| for block in fn.blocks: |
| for i, op in enumerate(block.ops): |
| if isinstance(op, Assign) and op.dest in candidates: |
| next_op = block.ops[i + 1] |
| if not (isinstance(next_op, Goto) and next_op.label is labels[op.dest]): |
| # Not right |
| candidates.remove(op.dest) |
| |
| builder = LowLevelIRBuilder(None, options) |
| transform = FlagEliminationTransform( |
| builder, {x: y for x, y in branches.items() if x in candidates} |
| ) |
| transform.transform_blocks(fn.blocks) |
| fn.blocks = builder.blocks |
| |
| |
| class FlagEliminationTransform(IRTransform): |
| def __init__(self, builder: LowLevelIRBuilder, branch_map: dict[Register, Branch]) -> None: |
| super().__init__(builder) |
| self.branch_map = branch_map |
| self.branches = set(branch_map.values()) |
| |
| def visit_assign(self, op: Assign) -> None: |
| old_branch = self.branch_map.get(op.dest) |
| if old_branch: |
| # Replace assignment with a copy of the old branch, which is in a |
| # separate basic block. The old branch will be deletecd in visit_branch. |
| new_branch = Branch( |
| op.src, |
| old_branch.true, |
| old_branch.false, |
| old_branch.op, |
| old_branch.line, |
| rare=old_branch.rare, |
| ) |
| new_branch.negated = old_branch.negated |
| new_branch.traceback_entry = old_branch.traceback_entry |
| self.add(new_branch) |
| else: |
| self.add(op) |
| |
| def visit_goto(self, op: Goto) -> None: |
| # This is a no-op if basic block already terminated |
| self.builder.goto(op.label) |
| |
| def visit_branch(self, op: Branch) -> None: |
| if op in self.branches: |
| # This branch is optimized away |
| self.add(Unreachable()) |
| else: |
| self.add(op) |