| """Helpers for implementing generic IR to IR transforms.""" |
| |
| from __future__ import annotations |
| |
| from typing import Final, Optional |
| |
| from mypyc.ir.ops import ( |
| Assign, |
| AssignMulti, |
| BasicBlock, |
| Box, |
| Branch, |
| Call, |
| CallC, |
| Cast, |
| ComparisonOp, |
| DecRef, |
| Extend, |
| FloatComparisonOp, |
| FloatNeg, |
| FloatOp, |
| GetAttr, |
| GetElementPtr, |
| Goto, |
| IncRef, |
| InitStatic, |
| IntOp, |
| KeepAlive, |
| LoadAddress, |
| LoadErrorValue, |
| LoadGlobal, |
| LoadLiteral, |
| LoadMem, |
| LoadStatic, |
| MethodCall, |
| Op, |
| OpVisitor, |
| PrimitiveOp, |
| RaiseStandardError, |
| Return, |
| SetAttr, |
| SetMem, |
| Truncate, |
| TupleGet, |
| TupleSet, |
| Unborrow, |
| Unbox, |
| Unreachable, |
| Value, |
| ) |
| from mypyc.irbuild.ll_builder import LowLevelIRBuilder |
| |
| |
class IRTransform(OpVisitor[Optional[Value]]):
    """Base class for IR-to-IR transforms that, by default, changes nothing.

    To modify IR, derive from this class and override the OpVisitor
    visit_* methods for the ops that should be rewritten. The default
    implementations implement an identity transform.

    Returning None from a visit method removes the op. In that case the
    transform is responsible for making sure that no op still uses the
    removed op as a source once the transform is done.

    Ops are allowed to keep references to old BasicBlocks and old ops;
    the transform patches these references automatically as needed.
    """

    def __init__(self, builder: LowLevelIRBuilder) -> None:
        self.builder = builder
        # Maps an original op/register to its replacement. Subclasses may add
        # their own entries. A None value marks the op/register as deleted.
        self.op_map: dict[Value, Value | None] = {}

    def transform_blocks(self, blocks: list[BasicBlock]) -> None:
        """Run the transform over the basic blocks of a single function.

        The transformed blocks are collected at self.builder.blocks.
        """
        old_to_new: dict[BasicBlock, BasicBlock] = {}
        replacements = self.op_map
        removable: set[BasicBlock] = set()
        for old_block in blocks:
            fresh = BasicBlock()
            old_to_new[old_block] = fresh
            self.builder.activate_block(fresh)
            # Set after activation; remapped to the new handler block later.
            fresh.error_handler = old_block.error_handler
            for old_op in old_block.ops:
                result = old_op.accept(self)
                if result is not old_op:
                    replacements[old_op] = result
            # A transform can leave a block empty; such blocks are dropped,
            # unless the block was already empty before the transform.
            if is_empty_block(fresh) and not is_empty_block(old_block):
                removable.add(fresh)
        self.builder.blocks = [
            candidate for candidate in self.builder.blocks if candidate not in removable
        ]
        # Redirect every op/block reference to its transformed counterpart.
        patcher = PatchVisitor(replacements, old_to_new)
        for current in self.builder.blocks:
            for current_op in current.ops:
                current_op.accept(patcher)
            handler = current.error_handler
            if handler is not None:
                current.error_handler = old_to_new.get(handler, handler)

    def add(self, op: Op) -> Value:
        """Add op through the builder and return the builder's result."""
        return self.builder.add(op)

    # Control-flow ops: re-added as is; these visit methods return nothing.

    def visit_goto(self, op: Goto) -> None:
        self.add(op)

    def visit_branch(self, op: Branch) -> None:
        self.add(op)

    def visit_return(self, op: Return) -> None:
        self.add(op)

    def visit_unreachable(self, op: Unreachable) -> None:
        self.add(op)

    # Remaining ops: re-added as is, returning whatever self.add() returns.

    def visit_assign(self, op: Assign) -> Value | None:
        """Re-add the assignment unless its source is recorded as deleted."""
        mapping = self.op_map
        source = op.src
        if source not in mapping or mapping[source] is not None:
            return self.add(op)
        # Special case: allow removing register initialization assignments
        return None

    def visit_assign_multi(self, op: AssignMulti) -> Value | None:
        return self.add(op)

    def visit_load_error_value(self, op: LoadErrorValue) -> Value | None:
        return self.add(op)

    def visit_load_literal(self, op: LoadLiteral) -> Value | None:
        return self.add(op)

    def visit_get_attr(self, op: GetAttr) -> Value | None:
        return self.add(op)

    def visit_set_attr(self, op: SetAttr) -> Value | None:
        return self.add(op)

    def visit_load_static(self, op: LoadStatic) -> Value | None:
        return self.add(op)

    def visit_init_static(self, op: InitStatic) -> Value | None:
        return self.add(op)

    def visit_tuple_get(self, op: TupleGet) -> Value | None:
        return self.add(op)

    def visit_tuple_set(self, op: TupleSet) -> Value | None:
        return self.add(op)

    def visit_inc_ref(self, op: IncRef) -> Value | None:
        return self.add(op)

    def visit_dec_ref(self, op: DecRef) -> Value | None:
        return self.add(op)

    def visit_call(self, op: Call) -> Value | None:
        return self.add(op)

    def visit_method_call(self, op: MethodCall) -> Value | None:
        return self.add(op)

    def visit_cast(self, op: Cast) -> Value | None:
        return self.add(op)

    def visit_box(self, op: Box) -> Value | None:
        return self.add(op)

    def visit_unbox(self, op: Unbox) -> Value | None:
        return self.add(op)

    def visit_raise_standard_error(self, op: RaiseStandardError) -> Value | None:
        return self.add(op)

    def visit_call_c(self, op: CallC) -> Value | None:
        return self.add(op)

    def visit_primitive_op(self, op: PrimitiveOp) -> Value | None:
        return self.add(op)

    def visit_truncate(self, op: Truncate) -> Value | None:
        return self.add(op)

    def visit_extend(self, op: Extend) -> Value | None:
        return self.add(op)

    def visit_load_global(self, op: LoadGlobal) -> Value | None:
        return self.add(op)

    def visit_int_op(self, op: IntOp) -> Value | None:
        return self.add(op)

    def visit_comparison_op(self, op: ComparisonOp) -> Value | None:
        return self.add(op)

    def visit_float_op(self, op: FloatOp) -> Value | None:
        return self.add(op)

    def visit_float_neg(self, op: FloatNeg) -> Value | None:
        return self.add(op)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> Value | None:
        return self.add(op)

    def visit_load_mem(self, op: LoadMem) -> Value | None:
        return self.add(op)

    def visit_set_mem(self, op: SetMem) -> Value | None:
        return self.add(op)

    def visit_get_element_ptr(self, op: GetElementPtr) -> Value | None:
        return self.add(op)

    def visit_load_address(self, op: LoadAddress) -> Value | None:
        return self.add(op)

    def visit_keep_alive(self, op: KeepAlive) -> Value | None:
        return self.add(op)

    def visit_unborrow(self, op: Unborrow) -> Value | None:
        return self.add(op)
| |
| |
class PatchVisitor(OpVisitor[None]):
    """Visitor that redirects the references held by an op, in place.

    Value references are looked up in op_map and BasicBlock references in
    block_map; a reference with no entry in the relevant map is left as is.
    """

    def __init__(
        self, op_map: dict[Value, Value | None], block_map: dict[BasicBlock, BasicBlock]
    ) -> None:
        self.op_map: Final = op_map
        self.block_map: Final = block_map

    def fix_op(self, op: Value) -> Value:
        """Return the replacement recorded for op, or op itself if none.

        A None entry in op_map means the value was removed, so reaching one
        here trips the assertion.
        """
        replacement = self.op_map.get(op, op)
        assert replacement is not None, "use of removed op"
        return replacement

    def fix_block(self, block: BasicBlock) -> BasicBlock:
        """Return the block recorded for block in block_map, or block itself."""
        return self.block_map.get(block, block)

    def visit_goto(self, op: Goto) -> None:
        op.label = self.fix_block(op.label)

    def visit_branch(self, op: Branch) -> None:
        op.value = self.fix_op(op.value)
        op.true = self.fix_block(op.true)
        op.false = self.fix_block(op.false)

    def visit_return(self, op: Return) -> None:
        op.value = self.fix_op(op.value)

    def visit_unreachable(self, op: Unreachable) -> None:
        # Nothing is patched for this op.
        return None

    def visit_assign(self, op: Assign) -> None:
        op.src = self.fix_op(op.src)

    def visit_assign_multi(self, op: AssignMulti) -> None:
        op.src = list(map(self.fix_op, op.src))

    def visit_load_error_value(self, op: LoadErrorValue) -> None:
        # Nothing is patched for this op.
        return None

    def visit_load_literal(self, op: LoadLiteral) -> None:
        # Nothing is patched for this op.
        return None

    def visit_get_attr(self, op: GetAttr) -> None:
        op.obj = self.fix_op(op.obj)

    def visit_set_attr(self, op: SetAttr) -> None:
        op.obj = self.fix_op(op.obj)
        op.src = self.fix_op(op.src)

    def visit_load_static(self, op: LoadStatic) -> None:
        # Nothing is patched for this op.
        return None

    def visit_init_static(self, op: InitStatic) -> None:
        op.value = self.fix_op(op.value)

    def visit_tuple_get(self, op: TupleGet) -> None:
        op.src = self.fix_op(op.src)

    def visit_tuple_set(self, op: TupleSet) -> None:
        op.items = list(map(self.fix_op, op.items))

    def visit_inc_ref(self, op: IncRef) -> None:
        op.src = self.fix_op(op.src)

    def visit_dec_ref(self, op: DecRef) -> None:
        op.src = self.fix_op(op.src)

    def visit_call(self, op: Call) -> None:
        op.args = list(map(self.fix_op, op.args))

    def visit_method_call(self, op: MethodCall) -> None:
        op.obj = self.fix_op(op.obj)
        op.args = list(map(self.fix_op, op.args))

    def visit_cast(self, op: Cast) -> None:
        op.src = self.fix_op(op.src)

    def visit_box(self, op: Box) -> None:
        op.src = self.fix_op(op.src)

    def visit_unbox(self, op: Unbox) -> None:
        op.src = self.fix_op(op.src)

    def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
        # The value is only patched when it is a Value.
        if not isinstance(op.value, Value):
            return
        op.value = self.fix_op(op.value)

    def visit_call_c(self, op: CallC) -> None:
        op.args = list(map(self.fix_op, op.args))

    def visit_primitive_op(self, op: PrimitiveOp) -> None:
        op.args = list(map(self.fix_op, op.args))

    def visit_truncate(self, op: Truncate) -> None:
        op.src = self.fix_op(op.src)

    def visit_extend(self, op: Extend) -> None:
        op.src = self.fix_op(op.src)

    def visit_load_global(self, op: LoadGlobal) -> None:
        # Nothing is patched for this op.
        return None

    def visit_int_op(self, op: IntOp) -> None:
        op.lhs = self.fix_op(op.lhs)
        op.rhs = self.fix_op(op.rhs)

    def visit_comparison_op(self, op: ComparisonOp) -> None:
        op.lhs = self.fix_op(op.lhs)
        op.rhs = self.fix_op(op.rhs)

    def visit_float_op(self, op: FloatOp) -> None:
        op.lhs = self.fix_op(op.lhs)
        op.rhs = self.fix_op(op.rhs)

    def visit_float_neg(self, op: FloatNeg) -> None:
        op.src = self.fix_op(op.src)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> None:
        op.lhs = self.fix_op(op.lhs)
        op.rhs = self.fix_op(op.rhs)

    def visit_load_mem(self, op: LoadMem) -> None:
        op.src = self.fix_op(op.src)

    def visit_set_mem(self, op: SetMem) -> None:
        op.dest = self.fix_op(op.dest)
        op.src = self.fix_op(op.src)

    def visit_get_element_ptr(self, op: GetElementPtr) -> None:
        op.src = self.fix_op(op.src)

    def visit_load_address(self, op: LoadAddress) -> None:
        # The source is only patched when it is a LoadStatic, and the
        # replacement must be a LoadStatic as well.
        if not isinstance(op.src, LoadStatic):
            return
        patched = self.fix_op(op.src)
        assert isinstance(patched, LoadStatic)
        op.src = patched

    def visit_keep_alive(self, op: KeepAlive) -> None:
        op.src = list(map(self.fix_op, op.src))

    def visit_unborrow(self, op: Unborrow) -> None:
        op.src = self.fix_op(op.src)
| |
| |
def is_empty_block(block: BasicBlock) -> bool:
    """Return True if block holds exactly one op and that op is an Unreachable."""
    ops = block.ops
    if len(ops) != 1:
        return False
    return isinstance(ops[0], Unreachable)