| """Utilities for checking that internal ir is valid and consistent.""" |
| from __future__ import annotations |
| |
| from mypyc.ir.func_ir import FUNC_STATICMETHOD, FuncIR |
| from mypyc.ir.ops import ( |
| Assign, |
| AssignMulti, |
| BaseAssign, |
| BasicBlock, |
| Box, |
| Branch, |
| Call, |
| CallC, |
| Cast, |
| ComparisonOp, |
| ControlOp, |
| DecRef, |
| Extend, |
| FloatComparisonOp, |
| FloatNeg, |
| FloatOp, |
| GetAttr, |
| GetElementPtr, |
| Goto, |
| IncRef, |
| InitStatic, |
| Integer, |
| IntOp, |
| KeepAlive, |
| LoadAddress, |
| LoadErrorValue, |
| LoadGlobal, |
| LoadLiteral, |
| LoadMem, |
| LoadStatic, |
| MethodCall, |
| Op, |
| OpVisitor, |
| RaiseStandardError, |
| Register, |
| Return, |
| SetAttr, |
| SetMem, |
| Truncate, |
| TupleGet, |
| TupleSet, |
| Unbox, |
| Unreachable, |
| Value, |
| ) |
| from mypyc.ir.pprint import format_func |
| from mypyc.ir.rtypes import ( |
| RArray, |
| RInstance, |
| RPrimitive, |
| RType, |
| RUnion, |
| bytes_rprimitive, |
| dict_rprimitive, |
| int_rprimitive, |
| is_float_rprimitive, |
| is_object_rprimitive, |
| list_rprimitive, |
| range_rprimitive, |
| set_rprimitive, |
| str_rprimitive, |
| tuple_rprimitive, |
| ) |
| |
| |
class FnError:
    """A single problem found while checking a function's IR.

    Attributes are set directly from the constructor arguments:
    ``source`` is the op or basic block the problem is attributed to and
    ``desc`` is a human-readable description of the problem.
    """

    def __init__(self, source: Op | BasicBlock, desc: str) -> None:
        self.source = source
        self.desc = desc

    def __eq__(self, other: object) -> bool:
        """Two errors are equal when both their source and description are equal."""
        if not isinstance(other, FnError):
            return False
        return self.source == other.source and self.desc == other.desc

    def __repr__(self) -> str:
        return f"FnError(source={self.source}, desc={self.desc})"
| |
| |
def check_func_ir(fn: FuncIR) -> list[FnError]:
    """Validate the IR of a single function and return the errors found.

    The check runs in two stages. First the block structure and the op
    sources are validated; if that yields any error, those errors are
    returned immediately. Only otherwise is every op visited with an
    OpChecker, whose collected errors are then returned.
    """
    found = []
    seen_ops = set()

    for blk in fn.blocks:
        if not blk.terminated:
            # Attribute the error to the last op when there is one,
            # otherwise to the (empty) block itself.
            culprit = blk.ops[-1] if blk.ops else blk
            found.append(FnError(source=culprit, desc="Block not terminated"))

        # Only ops before the final one are examined here: a control op
        # in any of those positions has other ops following it.
        for candidate in blk.ops[:-1]:
            if isinstance(candidate, ControlOp):
                found.append(
                    FnError(source=candidate, desc="Block has operations after control op")
                )

            if candidate in seen_ops:
                found.append(FnError(source=candidate, desc="Func has a duplicate op"))
            seen_ops.add(candidate)

    found.extend(check_op_sources_valid(fn))
    if found:
        return found

    checker = OpChecker(fn)
    for blk in fn.blocks:
        for candidate in blk.ops:
            candidate.accept(checker)

    return checker.errors
| |
| |
class IrCheckException(Exception):
    """Raised by assert_func_ir_valid when a function's IR has errors."""
| |
| |
def assert_func_ir_valid(fn: FuncIR) -> None:
    """Raise IrCheckException if check_func_ir reports any error for fn.

    The exception message contains the function as rendered by
    format_func, which is given each error's source and description.
    """
    problems = check_func_ir(fn)
    if not problems:
        return

    error_notes = [(problem.source, problem.desc) for problem in problems]
    rendered = "\n".join(format_func(fn, error_notes))
    raise IrCheckException("Internal error: Generated invalid IR: \n" + rendered)
| |
| |
def check_op_sources_valid(fn: FuncIR) -> list[FnError]:
    """Check that every op only uses sources that belong to the function.

    An op source is accepted when it is an Integer, an op that appears in
    one of the function's blocks, or a register that is an argument
    register, the destination of an assignment, or the source of a
    LoadAddress. Each offending reference yields one FnError attributed
    to the op that uses it.
    """
    problems = []
    known_ops: set[Op] = set()
    known_registers: set[Register] = set()

    # First pass: collect everything that may legitimately be referenced.
    for blk in fn.blocks:
        known_ops.update(blk.ops)

        for current in blk.ops:
            if isinstance(current, BaseAssign):
                known_registers.add(current.dest)
            elif isinstance(current, LoadAddress) and isinstance(current.src, Register):
                known_registers.add(current.src)

    known_registers.update(fn.arg_regs)

    # Second pass: verify the sources of each op against the collected sets.
    for blk in fn.blocks:
        for current in blk.ops:
            for source in current.sources():
                if isinstance(source, Integer):
                    # Integers are always accepted.
                    continue
                if isinstance(source, Op):
                    if source not in known_ops:
                        problems.append(
                            FnError(
                                source=current,
                                desc=f"Invalid op reference to op of type {type(source).__name__}",
                            )
                        )
                elif isinstance(source, Register):
                    if source not in known_registers:
                        problems.append(
                            FnError(
                                source=current,
                                desc=f"Invalid op reference to register {source.name!r}",
                            )
                        )

    return problems
| |
| |
# Names of the primitive types that can_coerce_to treats as mutually
# exclusive: when source and destination are both listed here, their
# names must be identical.
disjoint_types = {
    rtype.name
    for rtype in (
        int_rprimitive,
        bytes_rprimitive,
        str_rprimitive,
        dict_rprimitive,
        list_rprimitive,
        set_rprimitive,
        tuple_rprimitive,
        range_rprimitive,
    )
}
| |
| |
def can_coerce_to(src: RType, dest: RType) -> bool:
    """Check if a value of type src can be assigned to type dest.

    Currently okay to have false positives.
    """
    if isinstance(dest, RUnion):
        # Being assignable to any one member of the union is enough.
        return any(can_coerce_to(src, member) for member in dest.items)

    if not isinstance(dest, RPrimitive):
        # Nothing is checked for other kinds of destination types.
        return True

    if isinstance(src, RPrimitive):
        # Only when both src and dest are disjoint types do the names
        # have to match; otherwise the sizes are compared.
        both_disjoint = src.name in disjoint_types and dest.name in disjoint_types
        return src.name == dest.name if both_disjoint else src.size == dest.size

    if isinstance(src, RInstance):
        return is_object_rprimitive(dest)

    if isinstance(src, RUnion):
        # IR doesn't have the ability to narrow unions based on
        # control flow, so cannot be a strict all() here.
        return any(can_coerce_to(member, dest) for member in src.items)

    return False
| |
| |
class OpChecker(OpVisitor[None]):
    """Visitor that validates individual ops and collects the errors found.

    Each ``visit_*`` method checks one kind of op. Problems are never
    raised; they are appended to ``self.errors`` as FnError objects.
    """

    def __init__(self, parent_fn: FuncIR) -> None:
        # Function that owns the visited ops; consulted for its blocks
        # and its declared return type.
        self.parent_fn = parent_fn
        # Errors recorded so far, in the order they were found.
        self.errors: list[FnError] = []

    def fail(self, source: Op, desc: str) -> None:
        """Record an error described by `desc` and attributed to `source`."""
        self.errors.append(FnError(source=source, desc=desc))

    def check_control_op_targets(self, op: ControlOp) -> None:
        """Report each target of `op` that is not a block of the parent function."""
        for target in op.targets():
            if target not in self.parent_fn.blocks:
                self.fail(source=op, desc=f"Invalid control operation target: {target.label}")

    def check_type_coercion(self, op: Op, src: RType, dest: RType) -> None:
        """Report an error on `op` unless can_coerce_to accepts src -> dest."""
        if not can_coerce_to(src, dest):
            self.fail(
                source=op, desc=f"Cannot coerce source type {src.name} to dest type {dest.name}"
            )

    def check_compatibility(self, op: Op, t: RType, s: RType) -> None:
        """Report an error on `op` unless `t` and `s` coerce to each other both ways."""
        if not can_coerce_to(t, s) or not can_coerce_to(s, t):
            self.fail(source=op, desc=f"{t.name} and {s.name} are not compatible")

    def expect_float(self, op: Op, v: Value) -> None:
        """Report an error on `op` if the type of `v` is not the float primitive."""
        if not is_float_rprimitive(v.type):
            self.fail(op, f"Float expected (actual type is {v.type})")

    def expect_non_float(self, op: Op, v: Value) -> None:
        """Report an error on `op` if the type of `v` is the float primitive."""
        if is_float_rprimitive(v.type):
            self.fail(op, "Float not expected")

    def visit_goto(self, op: Goto) -> None:
        self.check_control_op_targets(op)

    def visit_branch(self, op: Branch) -> None:
        self.check_control_op_targets(op)

    def visit_return(self, op: Return) -> None:
        # The returned value must be coercible to the declared return type.
        self.check_type_coercion(op, op.value.type, self.parent_fn.decl.sig.ret_type)

    def visit_unreachable(self, op: Unreachable) -> None:
        # Unreachables are checked at a higher level since validation
        # requires access to the entire basic block.
        pass

    def visit_assign(self, op: Assign) -> None:
        self.check_type_coercion(op, op.src.type, op.dest.type)

    def visit_assign_multi(self, op: AssignMulti) -> None:
        # Every source value must be coercible to the item type of the
        # destination array.
        for src in op.src:
            assert isinstance(op.dest.type, RArray)
            self.check_type_coercion(op, src.type, op.dest.type.item_type)

    def visit_load_error_value(self, op: LoadErrorValue) -> None:
        # Currently it is assumed that all types have an error value.
        # Once this is fixed we can validate that the rtype here actually
        # has an error value.
        pass

    def check_tuple_items_valid_literals(self, op: LoadLiteral, t: tuple[object, ...]) -> None:
        """Report items of tuple `t` that are not valid literals, recursing into tuples."""
        for x in t:
            if x is not None and not isinstance(x, (str, bytes, bool, int, float, complex, tuple)):
                self.fail(op, f"Invalid type for item of tuple literal: {type(x)}")
            if isinstance(x, tuple):
                self.check_tuple_items_valid_literals(op, x)

    def check_frozenset_items_valid_literals(self, op: LoadLiteral, s: frozenset[object]) -> None:
        """Report items of frozenset `s` that are not valid literals.

        Tuple items are checked recursively with
        check_tuple_items_valid_literals.
        """
        for x in s:
            if x is None or isinstance(x, (str, bytes, bool, int, float, complex)):
                pass
            elif isinstance(x, tuple):
                self.check_tuple_items_valid_literals(op, x)
            else:
                self.fail(op, f"Invalid type for item of frozenset literal: {type(x)}")

    def visit_load_literal(self, op: LoadLiteral) -> None:
        """Check that the op's type matches the Python type of its literal value.

        The op's type name must be either the name expected for the value
        or "builtins.object"; anything else is reported as an error.
        """
        expected_type = None
        if op.value is None:
            expected_type = "builtins.object"
        elif isinstance(op.value, bool):
            # Must be tested before int: bool is a subclass of int, so an
            # int check placed first would also capture True/False and make
            # this branch unreachable.
            expected_type = "builtins.object"
        elif isinstance(op.value, int):
            expected_type = "builtins.int"
        elif isinstance(op.value, str):
            expected_type = "builtins.str"
        elif isinstance(op.value, bytes):
            expected_type = "builtins.bytes"
        elif isinstance(op.value, float):
            expected_type = "builtins.float"
        elif isinstance(op.value, complex):
            expected_type = "builtins.object"
        elif isinstance(op.value, tuple):
            expected_type = "builtins.tuple"
            self.check_tuple_items_valid_literals(op, op.value)
        elif isinstance(op.value, frozenset):
            # There's no frozenset_rprimitive type since it'd be pretty useless so we just pretend
            # it's a set (when it's really a frozenset).
            expected_type = "builtins.set"
            self.check_frozenset_items_valid_literals(op, op.value)

        assert expected_type is not None, "Missed a case for LoadLiteral check"

        if op.type.name not in [expected_type, "builtins.object"]:
            self.fail(
                op,
                f"Invalid literal value for type: value has "
                f"type {expected_type}, but op has type {op.type.name}",
            )

    def visit_get_attr(self, op: GetAttr) -> None:
        # Nothing to do.
        pass

    def visit_set_attr(self, op: SetAttr) -> None:
        # Nothing to do.
        pass

    # Static operations cannot be checked at the function level.
    def visit_load_static(self, op: LoadStatic) -> None:
        pass

    def visit_init_static(self, op: InitStatic) -> None:
        pass

    def visit_tuple_get(self, op: TupleGet) -> None:
        # Nothing to do.
        pass

    def visit_tuple_set(self, op: TupleSet) -> None:
        # Nothing to do.
        pass

    def visit_inc_ref(self, op: IncRef) -> None:
        # Nothing to do.
        pass

    def visit_dec_ref(self, op: DecRef) -> None:
        # Nothing to do.
        pass

    def visit_call(self, op: Call) -> None:
        # Length is checked in constructor, and return type is set
        # in a way that can't be incorrect
        for arg_value, arg_runtime in zip(op.args, op.fn.sig.args):
            self.check_type_coercion(op, arg_value.type, arg_runtime.type)

    def visit_method_call(self, op: MethodCall) -> None:
        """Check the argument count and argument types of a method call."""
        # Similar to above, but we must look up method first.
        method_decl = op.receiver_type.class_ir.method_decl(op.method)
        # Number of leading declared arguments that op.args does not
        # supply: none for a static method, otherwise one (the receiver).
        if method_decl.kind == FUNC_STATICMETHOD:
            decl_index = 0
        else:
            decl_index = 1

        if len(op.args) + decl_index != len(method_decl.sig.args):
            self.fail(op, "Incorrect number of args for method call.")

        # Skip the receiver argument (self)
        for arg_value, arg_runtime in zip(op.args, method_decl.sig.args[decl_index:]):
            self.check_type_coercion(op, arg_value.type, arg_runtime.type)

    # The ops below are accepted without any checks.
    def visit_cast(self, op: Cast) -> None:
        pass

    def visit_box(self, op: Box) -> None:
        pass

    def visit_unbox(self, op: Unbox) -> None:
        pass

    def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
        pass

    def visit_call_c(self, op: CallC) -> None:
        pass

    def visit_truncate(self, op: Truncate) -> None:
        pass

    def visit_extend(self, op: Extend) -> None:
        pass

    def visit_load_global(self, op: LoadGlobal) -> None:
        pass

    def visit_int_op(self, op: IntOp) -> None:
        # Integer operations must not take float operands.
        self.expect_non_float(op, op.lhs)
        self.expect_non_float(op, op.rhs)

    def visit_comparison_op(self, op: ComparisonOp) -> None:
        # Operands must be mutually coercible and must not be floats.
        self.check_compatibility(op, op.lhs.type, op.rhs.type)
        self.expect_non_float(op, op.lhs)
        self.expect_non_float(op, op.rhs)

    def visit_float_op(self, op: FloatOp) -> None:
        self.expect_float(op, op.lhs)
        self.expect_float(op, op.rhs)

    def visit_float_neg(self, op: FloatNeg) -> None:
        self.expect_float(op, op.src)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> None:
        self.expect_float(op, op.lhs)
        self.expect_float(op, op.rhs)

    # The ops below are accepted without any checks.
    def visit_load_mem(self, op: LoadMem) -> None:
        pass

    def visit_set_mem(self, op: SetMem) -> None:
        pass

    def visit_get_element_ptr(self, op: GetElementPtr) -> None:
        pass

    def visit_load_address(self, op: LoadAddress) -> None:
        pass

    def visit_keep_alive(self, op: KeepAlive) -> None:
        pass