blob: f15beceff8cec7a79fe3f425e3599160ce8067ee [file] [log] [blame] [edit]
"""Utilities for checking that internal ir is valid and consistent."""
from typing import List, Union
from mypyc.ir.pprint import format_func
from mypyc.ir.ops import (
OpVisitor, BasicBlock, Op, ControlOp, Goto, Branch, Return, Unreachable,
Assign, AssignMulti, LoadErrorValue, LoadLiteral, GetAttr, SetAttr, LoadStatic,
InitStatic, TupleGet, TupleSet, IncRef, DecRef, Call, MethodCall, Cast,
Box, Unbox, RaiseStandardError, CallC, Truncate, LoadGlobal, IntOp, ComparisonOp,
LoadMem, SetMem, GetElementPtr, LoadAddress, KeepAlive
)
from mypyc.ir.func_ir import FuncIR
class FnError(object):
def __init__(self, source: Union[Op, BasicBlock], desc: str) -> None:
self.source = source
self.desc = desc
def __eq__(self, other: object) -> bool:
return isinstance(other, FnError) and self.source == other.source and \
self.desc == other.desc
def __repr__(self) -> str:
return f"FnError(source={self.source}, desc={self.desc})"
def check_func_ir(fn: FuncIR) -> List[FnError]:
"""Applies validations to a given function ir and returns a list of errors found."""
errors = []
for block in fn.blocks:
if not block.terminated:
errors.append(FnError(
source=block.ops[-1] if block.ops else block,
desc="Block not terminated",
))
op_checker = OpChecker(fn)
for block in fn.blocks:
for op in block.ops:
op.accept(op_checker)
return errors + op_checker.errors
class IrCheckException(Exception):
pass
def assert_func_ir_valid(fn: FuncIR) -> None:
errors = check_func_ir(fn)
if errors:
raise IrCheckException("Internal error: Generated invalid IR: \n" + "\n".join(
format_func(fn, [(e.source, e.desc) for e in errors])),
)
class OpChecker(OpVisitor[None]):
def __init__(self, parent_fn: FuncIR) -> None:
self.parent_fn = parent_fn
self.errors: List[FnError] = []
def fail(self, source: Op, desc: str) -> None:
self.errors.append(FnError(source=source, desc=desc))
def check_control_op_targets(self, op: ControlOp) -> None:
for target in op.targets():
if target not in self.parent_fn.blocks:
self.fail(source=op, desc=f"Invalid control operation target: {target.label}")
def visit_goto(self, op: Goto) -> None:
self.check_control_op_targets(op)
def visit_branch(self, op: Branch) -> None:
self.check_control_op_targets(op)
def visit_return(self, op: Return) -> None:
pass
def visit_unreachable(self, op: Unreachable) -> None:
pass
def visit_assign(self, op: Assign) -> None:
pass
def visit_assign_multi(self, op: AssignMulti) -> None:
pass
def visit_load_error_value(self, op: LoadErrorValue) -> None:
pass
def visit_load_literal(self, op: LoadLiteral) -> None:
pass
def visit_get_attr(self, op: GetAttr) -> None:
pass
def visit_set_attr(self, op: SetAttr) -> None:
pass
def visit_load_static(self, op: LoadStatic) -> None:
pass
def visit_init_static(self, op: InitStatic) -> None:
pass
def visit_tuple_get(self, op: TupleGet) -> None:
pass
def visit_tuple_set(self, op: TupleSet) -> None:
pass
def visit_inc_ref(self, op: IncRef) -> None:
pass
def visit_dec_ref(self, op: DecRef) -> None:
pass
def visit_call(self, op: Call) -> None:
pass
def visit_method_call(self, op: MethodCall) -> None:
pass
def visit_cast(self, op: Cast) -> None:
pass
def visit_box(self, op: Box) -> None:
pass
def visit_unbox(self, op: Unbox) -> None:
pass
def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
pass
def visit_call_c(self, op: CallC) -> None:
pass
def visit_truncate(self, op: Truncate) -> None:
pass
def visit_load_global(self, op: LoadGlobal) -> None:
pass
def visit_int_op(self, op: IntOp) -> None:
pass
def visit_comparison_op(self, op: ComparisonOp) -> None:
pass
def visit_load_mem(self, op: LoadMem) -> None:
pass
def visit_set_mem(self, op: SetMem) -> None:
pass
def visit_get_element_ptr(self, op: GetElementPtr) -> None:
pass
def visit_load_address(self, op: LoadAddress) -> None:
pass
def visit_keep_alive(self, op: KeepAlive) -> None:
pass