blob: 326a5baca1e74195afac0ac47b2231ffa1cee4f7 [file] [log] [blame]
"""Helpers for implementing generic IR to IR transforms."""
from __future__ import annotations
from typing import Final, Optional
from mypyc.ir.ops import (
Assign,
AssignMulti,
BasicBlock,
Box,
Branch,
Call,
CallC,
Cast,
ComparisonOp,
DecRef,
Extend,
FloatComparisonOp,
FloatNeg,
FloatOp,
GetAttr,
GetElementPtr,
Goto,
IncRef,
InitStatic,
IntOp,
KeepAlive,
LoadAddress,
LoadErrorValue,
LoadGlobal,
LoadLiteral,
LoadMem,
LoadStatic,
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Return,
SetAttr,
SetMem,
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Unreachable,
Value,
)
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
class IRTransform(OpVisitor[Optional[Value]]):
    """Identity IR-to-IR transform, meant to be subclassed.

    Out of the box every visit_* method re-emits the visited op unchanged via
    self.builder, so running the transform simply copies the IR. To change the
    IR, subclass IRTransform and override only the OpVisitor visit_* methods
    for the ops that should be rewritten.

    How transform_blocks() interprets a visit_* return value:

    * A value other than the visited op is recorded in self.op_map as the
      replacement for that op. Later uses of the old op are redirected to it.
    * None records the op as removed. A visitor that drops an op (does not
      emit it) must make sure no op still uses it after the transform.
      The control-flow visitors (goto/branch/return/unreachable) also return
      None, because those ops produce no value, even though they re-emit
      the op.

    Emitted ops may keep referring to the original BasicBlocks and ops.
    transform_blocks() patches those references to the transformed versions.
    """

    def __init__(self, builder: LowLevelIRBuilder) -> None:
        self.builder = builder
        # Original op/register -> replacement. Subclasses may add entries;
        # a None value marks the op/register as deleted.
        self.op_map: dict[Value, Value | None] = {}

    def transform_blocks(self, blocks: list[BasicBlock]) -> None:
        """Transform the basic blocks that make up a single function.

        The transformed blocks end up in self.builder.blocks.
        """
        replacements = self.op_map
        old_to_new: dict[BasicBlock, BasicBlock] = {}
        dropped: set[BasicBlock] = set()

        # Phase 1: re-emit every op of every block into a fresh block.
        for old_block in blocks:
            fresh = BasicBlock()
            old_to_new[old_block] = fresh
            self.builder.activate_block(fresh)
            # Carry over the original handler; it is remapped in phase 3.
            fresh.error_handler = old_block.error_handler
            for old_op in old_block.ops:
                result = old_op.accept(self)
                if result is not old_op:
                    replacements[old_op] = result
            # A block the transform reduced to a lone Unreachable (when the
            # original was not one) is discarded.
            if is_empty_block(fresh) and not is_empty_block(old_block):
                dropped.add(fresh)

        # Phase 2: remove the discarded blocks.
        self.builder.blocks = [b for b in self.builder.blocks if b not in dropped]

        # Phase 3: point every op/block reference at its transformed version.
        patcher = PatchVisitor(replacements, old_to_new)
        for emitted_block in self.builder.blocks:
            for emitted_op in emitted_block.ops:
                emitted_op.accept(patcher)
            handler = emitted_block.error_handler
            if handler is not None:
                emitted_block.error_handler = old_to_new.get(handler, handler)

    def add(self, op: Op) -> Value:
        """Emit op through self.builder and return the builder's result."""
        return self.builder.add(op)

    # Control-flow ops: re-emit and return None (they produce no value).

    def visit_goto(self, op: Goto) -> None:
        self.add(op)

    def visit_branch(self, op: Branch) -> None:
        self.add(op)

    def visit_return(self, op: Return) -> None:
        self.add(op)

    def visit_unreachable(self, op: Unreachable) -> None:
        self.add(op)

    # Value-producing and side-effecting ops: re-emit unchanged by default.

    def visit_assign(self, op: Assign) -> Value | None:
        # Drop an assignment whose source op was deleted, so that register
        # initialization assignments can be removed together with their source.
        source_deleted = op.src in self.op_map and self.op_map[op.src] is None
        if source_deleted:
            return None
        return self.add(op)

    def visit_assign_multi(self, op: AssignMulti) -> Value | None:
        return self.add(op)

    def visit_load_error_value(self, op: LoadErrorValue) -> Value | None:
        return self.add(op)

    def visit_load_literal(self, op: LoadLiteral) -> Value | None:
        return self.add(op)

    def visit_get_attr(self, op: GetAttr) -> Value | None:
        return self.add(op)

    def visit_set_attr(self, op: SetAttr) -> Value | None:
        return self.add(op)

    def visit_load_static(self, op: LoadStatic) -> Value | None:
        return self.add(op)

    def visit_init_static(self, op: InitStatic) -> Value | None:
        return self.add(op)

    def visit_tuple_get(self, op: TupleGet) -> Value | None:
        return self.add(op)

    def visit_tuple_set(self, op: TupleSet) -> Value | None:
        return self.add(op)

    def visit_inc_ref(self, op: IncRef) -> Value | None:
        return self.add(op)

    def visit_dec_ref(self, op: DecRef) -> Value | None:
        return self.add(op)

    def visit_call(self, op: Call) -> Value | None:
        return self.add(op)

    def visit_method_call(self, op: MethodCall) -> Value | None:
        return self.add(op)

    def visit_cast(self, op: Cast) -> Value | None:
        return self.add(op)

    def visit_box(self, op: Box) -> Value | None:
        return self.add(op)

    def visit_unbox(self, op: Unbox) -> Value | None:
        return self.add(op)

    def visit_raise_standard_error(self, op: RaiseStandardError) -> Value | None:
        return self.add(op)

    def visit_call_c(self, op: CallC) -> Value | None:
        return self.add(op)

    def visit_primitive_op(self, op: PrimitiveOp) -> Value | None:
        return self.add(op)

    def visit_truncate(self, op: Truncate) -> Value | None:
        return self.add(op)

    def visit_extend(self, op: Extend) -> Value | None:
        return self.add(op)

    def visit_load_global(self, op: LoadGlobal) -> Value | None:
        return self.add(op)

    def visit_int_op(self, op: IntOp) -> Value | None:
        return self.add(op)

    def visit_comparison_op(self, op: ComparisonOp) -> Value | None:
        return self.add(op)

    def visit_float_op(self, op: FloatOp) -> Value | None:
        return self.add(op)

    def visit_float_neg(self, op: FloatNeg) -> Value | None:
        return self.add(op)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> Value | None:
        return self.add(op)

    def visit_load_mem(self, op: LoadMem) -> Value | None:
        return self.add(op)

    def visit_set_mem(self, op: SetMem) -> Value | None:
        return self.add(op)

    def visit_get_element_ptr(self, op: GetElementPtr) -> Value | None:
        return self.add(op)

    def visit_load_address(self, op: LoadAddress) -> Value | None:
        return self.add(op)

    def visit_keep_alive(self, op: KeepAlive) -> Value | None:
        return self.add(op)

    def visit_unborrow(self, op: Unborrow) -> Value | None:
        return self.add(op)
class PatchVisitor(OpVisitor[None]):
    """Patch op and block references in place after an IRTransform pass.

    Each visit_* method replaces the operands of the visited op with their
    entries in op_map. For Goto and Branch it also replaces the target
    blocks with their entries in block_map. References that have no entry
    in the maps are left unchanged.
    """

    def __init__(
        self, op_map: dict[Value, Value | None], block_map: dict[BasicBlock, BasicBlock]
    ) -> None:
        self.op_map: Final = op_map
        self.block_map: Final = block_map

    def fix_op(self, op: Value) -> Value:
        """Return the replacement recorded for op, or op itself if there is none.

        Asserts that op was not deleted (mapped to None): a removed op must
        not have any remaining uses.
        """
        if op not in self.op_map:
            return op
        replacement = self.op_map[op]
        assert replacement is not None, "use of removed op"
        return replacement

    def fix_block(self, block: BasicBlock) -> BasicBlock:
        """Return the transformed counterpart of block, or block itself."""
        if block in self.block_map:
            return self.block_map[block]
        return block

    def _fix_binary_operands(self, op: IntOp | ComparisonOp | FloatOp | FloatComparisonOp) -> None:
        """Patch the lhs and then the rhs operand of a two-operand op."""
        op.lhs = self.fix_op(op.lhs)
        op.rhs = self.fix_op(op.rhs)

    def visit_goto(self, op: Goto) -> None:
        op.label = self.fix_block(op.label)

    def visit_branch(self, op: Branch) -> None:
        op.value = self.fix_op(op.value)
        # fix_block is a pure lookup, so both targets can be patched together.
        op.true, op.false = self.fix_block(op.true), self.fix_block(op.false)

    def visit_return(self, op: Return) -> None:
        op.value = self.fix_op(op.value)

    def visit_unreachable(self, op: Unreachable) -> None:
        pass  # no operands to patch

    def visit_assign(self, op: Assign) -> None:
        op.src = self.fix_op(op.src)

    def visit_assign_multi(self, op: AssignMulti) -> None:
        op.src = list(map(self.fix_op, op.src))

    def visit_load_error_value(self, op: LoadErrorValue) -> None:
        pass  # no operands to patch

    def visit_load_literal(self, op: LoadLiteral) -> None:
        pass  # no operands to patch

    def visit_get_attr(self, op: GetAttr) -> None:
        op.obj = self.fix_op(op.obj)

    def visit_set_attr(self, op: SetAttr) -> None:
        op.obj = self.fix_op(op.obj)
        op.src = self.fix_op(op.src)

    def visit_load_static(self, op: LoadStatic) -> None:
        pass  # no operands to patch

    def visit_init_static(self, op: InitStatic) -> None:
        op.value = self.fix_op(op.value)

    def visit_tuple_get(self, op: TupleGet) -> None:
        op.src = self.fix_op(op.src)

    def visit_tuple_set(self, op: TupleSet) -> None:
        op.items = list(map(self.fix_op, op.items))

    def visit_inc_ref(self, op: IncRef) -> None:
        op.src = self.fix_op(op.src)

    def visit_dec_ref(self, op: DecRef) -> None:
        op.src = self.fix_op(op.src)

    def visit_call(self, op: Call) -> None:
        op.args = list(map(self.fix_op, op.args))

    def visit_method_call(self, op: MethodCall) -> None:
        op.obj = self.fix_op(op.obj)
        op.args = list(map(self.fix_op, op.args))

    def visit_cast(self, op: Cast) -> None:
        op.src = self.fix_op(op.src)

    def visit_box(self, op: Box) -> None:
        op.src = self.fix_op(op.src)

    def visit_unbox(self, op: Unbox) -> None:
        op.src = self.fix_op(op.src)

    def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
        # op.value is only patched when it is an IR Value.
        value = op.value
        if isinstance(value, Value):
            op.value = self.fix_op(value)

    def visit_call_c(self, op: CallC) -> None:
        op.args = list(map(self.fix_op, op.args))

    def visit_primitive_op(self, op: PrimitiveOp) -> None:
        op.args = list(map(self.fix_op, op.args))

    def visit_truncate(self, op: Truncate) -> None:
        op.src = self.fix_op(op.src)

    def visit_extend(self, op: Extend) -> None:
        op.src = self.fix_op(op.src)

    def visit_load_global(self, op: LoadGlobal) -> None:
        pass  # no operands to patch

    def visit_int_op(self, op: IntOp) -> None:
        self._fix_binary_operands(op)

    def visit_comparison_op(self, op: ComparisonOp) -> None:
        self._fix_binary_operands(op)

    def visit_float_op(self, op: FloatOp) -> None:
        self._fix_binary_operands(op)

    def visit_float_neg(self, op: FloatNeg) -> None:
        op.src = self.fix_op(op.src)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> None:
        self._fix_binary_operands(op)

    def visit_load_mem(self, op: LoadMem) -> None:
        op.src = self.fix_op(op.src)

    def visit_set_mem(self, op: SetMem) -> None:
        op.dest = self.fix_op(op.dest)
        op.src = self.fix_op(op.src)

    def visit_get_element_ptr(self, op: GetElementPtr) -> None:
        op.src = self.fix_op(op.src)

    def visit_load_address(self, op: LoadAddress) -> None:
        # Only a LoadStatic source is remapped; any other source is left as is.
        source = op.src
        if not isinstance(source, LoadStatic):
            return
        replacement = self.fix_op(source)
        assert isinstance(replacement, LoadStatic)
        op.src = replacement

    def visit_keep_alive(self, op: KeepAlive) -> None:
        op.src = list(map(self.fix_op, op.src))

    def visit_unborrow(self, op: Unborrow) -> None:
        op.src = self.fix_op(op.src)
def is_empty_block(block: BasicBlock) -> bool:
    """Return True if block holds exactly one op and that op is an Unreachable."""
    if len(block.ops) != 1:
        return False
    return isinstance(block.ops[0], Unreachable)