blob: 605e5bc46ae4b57a5724a829f30a095ef606caa4 [file] [log] [blame]
"""Bool register elimination optimization.

Example input:

  L1:
    r0 = f()
    b = r0
    goto L3
  L2:
    r1 = g()
    b = r1
    goto L3
  L3:
    if b goto L4 else goto L5

The register b is redundant, so we replace the assignments with two copies of
the branch in L3:

  L1:
    r0 = f()
    if r0 goto L4 else goto L5
  L2:
    r1 = g()
    if r1 goto L4 else goto L5

This helps generate simpler IR for tagged integer comparisons, for example.
"""
from __future__ import annotations
from mypyc.ir.func_ir import FuncIR
from mypyc.ir.ops import Assign, BasicBlock, Branch, Goto, Register, Unreachable
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
from mypyc.options import CompilerOptions
from mypyc.transform.ir_transform import IRTransform
def do_flag_elimination(fn: FuncIR, options: CompilerOptions) -> None:
    """Replace single-use flag registers feeding a branch with copies of that branch.

    A register ``r`` is eliminated only when all of the following hold:

    * ``r`` is not one of ``fn.arg_regs``;
    * ``r`` is read exactly once, by a Branch that is the first (and, being a
      terminator, the only) op of its block ``L``;
    * every Assign to ``r`` is immediately followed by ``goto L``;
    * ``L`` is entered only through those ``r = x; goto L`` sequences: it is
      not the entry block, not an error handler, and no other Goto/Branch
      targets it.

    Each qualifying assignment is then replaced by a copy of the branch that
    tests the assigned value directly, and the original branch is replaced by
    Unreachable (see FlagEliminationTransform).

    Args:
        fn: Function to rewrite; ``fn.blocks`` is replaced with the rebuilt
            blocks.
        options: Compiler options handed to the LowLevelIRBuilder used for the
            rebuild.
    """
    # Find registers that are used exactly once as source, and in a branch.
    counts: dict[Register, int] = {}
    branches: dict[Register, Branch] = {}
    labels: dict[Register, BasicBlock] = {}
    # Number of control-flow references into each block. A candidate's branch
    # block becomes Unreachable after the rewrite, so every one of these
    # references must come from a validated "r = x; goto L" sequence.
    incoming: dict[BasicBlock, int] = {}
    if fn.blocks:
        # Entering the function counts as one way into the first block.
        incoming[fn.blocks[0]] = 1
    for block in fn.blocks:
        if block.error_handler is not None:
            handler = block.error_handler
            incoming[handler] = incoming.get(handler, 0) + 1
        for i, op in enumerate(block.ops):
            for src in op.sources():
                if isinstance(src, Register):
                    counts[src] = counts.get(src, 0) + 1
            if isinstance(op, Goto):
                incoming[op.label] = incoming.get(op.label, 0) + 1
            elif isinstance(op, Branch):
                for target in (op.true, op.false):
                    incoming[target] = incoming.get(target, 0) + 1
            if i == 0 and isinstance(op, Branch) and isinstance(op.value, Register):
                branches[op.value] = op
                labels[op.value] = block

    # Based on these we can find the candidate registers.
    candidates: set[Register] = {
        r for r in branches if counts.get(r, 0) == 1 and r not in fn.arg_regs
    }

    # Remove candidates with invalid assignments, and count the valid
    # "r = x; goto L" sequences of the remaining ones.
    valid_entries: dict[Register, int] = {}
    for block in fn.blocks:
        for i, op in enumerate(block.ops):
            if isinstance(op, Assign) and op.dest in candidates:
                # Guard against malformed IR where the Assign ends the block.
                next_op = block.ops[i + 1] if i + 1 < len(block.ops) else None
                if isinstance(next_op, Goto) and next_op.label is labels[op.dest]:
                    valid_entries[op.dest] = valid_entries.get(op.dest, 0) + 1
                else:
                    candidates.remove(op.dest)

    # Drop candidates whose branch block can also be reached some other way;
    # such a path would otherwise end up at the Unreachable left behind.
    candidates = {
        r for r in candidates if incoming.get(labels[r], 0) == valid_entries.get(r, 0)
    }

    builder = LowLevelIRBuilder(None, options)
    transform = FlagEliminationTransform(
        builder, {r: branch for r, branch in branches.items() if r in candidates}
    )
    transform.transform_blocks(fn.blocks)
    fn.blocks = builder.blocks
class FlagEliminationTransform(IRTransform):
    """IR transform that replaces flag-register assignments with branch copies.

    ``branch_map`` maps each register to eliminate to the Branch that tests
    it. An Assign to such a register is replaced by a copy of that Branch
    testing the assigned value, and the original Branch is replaced by
    Unreachable.
    """

    def __init__(self, builder: LowLevelIRBuilder, branch_map: dict[Register, Branch]) -> None:
        super().__init__(builder)
        self.branch_map = branch_map
        # The original branches, each of which is dropped in visit_branch.
        self.branches = set(branch_map.values())

    def visit_assign(self, op: Assign) -> None:
        """Emit a copy of the register's branch in place of the assignment."""
        old_branch = self.branch_map.get(op.dest)
        if old_branch is None:
            # Not an eliminated register: keep the assignment unchanged.
            self.add(op)
            return
        # Replace assignment with a copy of the old branch, which is in a
        # separate basic block. The old branch will be deleted in visit_branch.
        new_branch = Branch(
            op.src,
            old_branch.true,
            old_branch.false,
            old_branch.op,
            old_branch.line,
            rare=old_branch.rare,
        )
        # These are not constructor arguments above, so copy them explicitly.
        new_branch.negated = old_branch.negated
        new_branch.traceback_entry = old_branch.traceback_entry
        self.add(new_branch)

    def visit_goto(self, op: Goto) -> None:
        """Re-emit a goto unless the current block is already terminated."""
        # This is a no-op if basic block already terminated, e.g. by a branch
        # copy emitted by visit_assign for the preceding assignment.
        self.builder.goto(op.label)

    def visit_branch(self, op: Branch) -> None:
        """Drop branches that were copied into their predecessors."""
        if op in self.branches:
            # This branch is optimized away
            self.add(Unreachable())
        else:
            self.add(op)