| """Code generation for native function bodies.""" |
| |
| from typing import Union |
| from typing_extensions import Final |
| |
| from mypyc.common import ( |
| REG_PREFIX, NATIVE_PREFIX, STATIC_PREFIX, TYPE_PREFIX, MODULE_PREFIX, |
| ) |
| from mypyc.codegen.emit import Emitter |
| from mypyc.ir.ops import ( |
| OpVisitor, Goto, Branch, Return, Assign, Integer, LoadErrorValue, GetAttr, SetAttr, |
| LoadStatic, InitStatic, TupleGet, TupleSet, Call, IncRef, DecRef, Box, Cast, Unbox, |
| BasicBlock, Value, MethodCall, Unreachable, NAMESPACE_STATIC, NAMESPACE_TYPE, NAMESPACE_MODULE, |
| RaiseStandardError, CallC, LoadGlobal, Truncate, IntOp, LoadMem, GetElementPtr, |
| LoadAddress, ComparisonOp, SetMem, Register |
| ) |
| from mypyc.ir.rtypes import ( |
| RType, RTuple, is_tagged, is_int32_rprimitive, is_int64_rprimitive, RStruct, |
| is_pointer_rprimitive |
| ) |
| from mypyc.ir.func_ir import FuncIR, FuncDecl, FUNC_STATICMETHOD, FUNC_CLASSMETHOD, all_values |
| from mypyc.ir.class_ir import ClassIR |
| from mypyc.ir.pprint import generate_names_for_ir |
| |
# Whether to insert debug asserts for all error handling, to quickly
# catch errors propagating without exceptions set. When true,
# FunctionEmitterVisitor.emit_traceback also emits a C assert that
# PyErr_Occurred() is non-NULL.
DEBUG_ERRORS = False
| |
| |
def native_function_type(fn: FuncIR, emitter: Emitter) -> str:
    """Return the C function-pointer type matching the native function fn.

    The result has the form '<ret> (*)(<arg types>)'. The parameter list is
    spelled 'void' when the joined argument types are empty.
    """
    param_list = ', '.join([emitter.ctype(param.type) for param in fn.args])
    if not param_list:
        param_list = 'void'
    return_ctype = emitter.ctype(fn.ret_type)
    return '%s (*)(%s)' % (return_ctype, param_list)
| |
| |
def native_function_header(fn: FuncDecl, emitter: Emitter) -> str:
    """Return the C header (return type, name and parameters) of a native function.

    Each parameter is rendered as its spaced C type followed by REG_PREFIX
    and the argument name. The parameter list is 'void' when it would
    otherwise be empty. No trailing ';' or '{' is included.
    """
    params = [
        '%s%s%s' % (emitter.ctype_spaced(param.type), REG_PREFIX, param.name)
        for param in fn.sig.args
    ]
    return_ctype = emitter.ctype_spaced(fn.sig.ret_type)
    c_name = emitter.native_function_name(fn)
    param_list = ', '.join(params)
    if not param_list:
        param_list = 'void'
    return '%s%s(%s)' % (return_ctype, c_name, param_list)
| |
| |
def generate_native_function(fn: FuncIR,
                             emitter: Emitter,
                             source_path: str,
                             module_name: str) -> None:
    """Emit the complete C definition of the native function fn into emitter.

    Two scratch emitters are used: one collects the function header and the
    declarations of local values, the other collects the statements produced
    by visiting every op. Both are appended to emitter at the end, with the
    declarations first.

    Args:
        fn: IR of the function to generate
        emitter: destination of the generated code (also used to declare
            tuple structs)
        source_path: passed on to FunctionEmitterVisitor
        module_name: passed on to FunctionEmitterVisitor
    """
    decl_emitter = Emitter(emitter.context)
    value_names = generate_names_for_ir(fn.arg_regs, fn.blocks)
    body_emitter = Emitter(emitter.context, value_names)
    op_visitor = FunctionEmitterVisitor(body_emitter, decl_emitter, source_path, module_name)

    header = native_function_header(fn.decl, emitter)
    decl_emitter.emit_line('%s {' % header)
    body_emitter.indent()

    for value in all_values(fn.arg_regs, fn.blocks):
        value_type = value.type
        if isinstance(value_type, RTuple):
            emitter.declare_tuple_struct(value_type)

        # Arguments are already declared by the function header; every
        # other value gets a local declaration.
        if value not in fn.arg_regs:
            spaced_ctype = emitter.ctype_spaced(value.type)
            decl_emitter.emit_line(
                '%s%s%s;' % (spaced_ctype, REG_PREFIX, value_names[value]))

    # Before we emit the blocks, give them all labels
    for index, basic_block in enumerate(fn.blocks):
        basic_block.label = index

    for basic_block in fn.blocks:
        body_emitter.emit_label(basic_block)
        for ir_op in basic_block.ops:
            ir_op.accept(op_visitor)

    body_emitter.emit_line('}')

    emitter.emit_from_emitter(decl_emitter)
    emitter.emit_from_emitter(body_emitter)
| |
| |
class FunctionEmitterVisitor(OpVisitor[None]):
    """Op visitor that emits C source for each IR op of a single function.

    Statements are written to `emitter`. Additional C locals required by the
    generated statements are declared through `declarations`.
    """

    def __init__(self,
                 emitter: Emitter,
                 declarations: Emitter,
                 source_path: str,
                 module_name: str) -> None:
        """Create a visitor.

        Args:
            emitter: receives the statements of the function body; its
                `names` mapping is also used for naming
            declarations: receives declarations of extra C locals
            source_path: source file path reported in traceback entries
            module_name: module whose 'globals' static is passed to
                traceback entries
        """
        self.emitter = emitter
        self.names = emitter.names
        self.declarations = declarations
        self.source_path = source_path
        self.module_name = module_name

    def temp_name(self) -> str:
        """Return a temporary name obtained from the body emitter."""
        return self.emitter.temp_name()

    def visit_goto(self, op: Goto) -> None:
        """Emit an unconditional jump to the target block."""
        self.emit_line('goto %s;' % self.label(op.label))

    def visit_branch(self, op: Branch) -> None:
        """Emit a two-way conditional jump.

        The condition is either the (optionally negated) boolean value or an
        error-value comparison. When the branch carries a traceback entry,
        the traceback call is emitted inside the true arm before the jump.
        """
        neg = '!' if op.negated else ''

        cond = ''
        if op.op == Branch.BOOL:
            expr_result = self.reg(op.value)
            cond = '{}{}'.format(neg, expr_result)
        elif op.op == Branch.IS_ERROR:
            typ = op.value.type
            compare = '!=' if op.negated else '=='
            if isinstance(typ, RTuple):
                # TODO: What about empty tuple?
                cond = self.emitter.tuple_undefined_check_cond(typ,
                                                               self.reg(op.value),
                                                               self.c_error_value,
                                                               compare)
            else:
                cond = '{} {} {}'.format(self.reg(op.value),
                                         compare,
                                         self.c_error_value(typ))
        else:
            assert False, "Invalid branch"

        # For error checks, tell the compiler the branch is unlikely
        if op.traceback_entry is not None or op.rare:
            cond = 'unlikely({})'.format(cond)

        self.emit_line('if ({}) {{'.format(cond))

        self.emit_traceback(op)

        self.emit_lines(
            'goto %s;' % self.label(op.true),
            '} else',
            ' goto %s;' % self.label(op.false)
        )

    def visit_return(self, op: Return) -> None:
        """Emit a return of the op's value."""
        value_str = self.reg(op.value)
        self.emit_line('return %s;' % value_str)

    def visit_tuple_set(self, op: TupleSet) -> None:
        """Emit construction of a tuple struct from the op's items.

        An empty tuple only gets its error flag cleared; otherwise each item
        is stored in field f<i>. The result's refcount is then incremented.
        """
        dest = self.reg(op)
        tuple_type = op.tuple_type
        self.emitter.declare_tuple_struct(tuple_type)
        if not op.items:  # empty tuple
            self.emit_line('{}.empty_struct_error_flag = 0;'.format(dest))
        else:
            for i, item in enumerate(op.items):
                self.emit_line('{}.f{} = {};'.format(dest, i, self.reg(item)))
        self.emit_inc_ref(dest, tuple_type)

    def visit_assign(self, op: Assign) -> None:
        """Emit a plain assignment from op.src to op.dest."""
        dest = self.reg(op.dest)
        src = self.reg(op.src)
        # clang whines about self assignment (which we might generate
        # for some casts), so don't emit it.
        if dest != src:
            self.emit_line('%s = %s;' % (dest, src))

    def visit_load_error_value(self, op: LoadErrorValue) -> None:
        """Emit code that stores the error value of op.type into op.

        Tuples are built from the undefined value of every item type via a
        temporary initialized with a brace initializer.
        """
        if isinstance(op.type, RTuple):
            values = [self.c_undefined_value(item) for item in op.type.types]
            tmp = self.temp_name()
            self.emit_line('%s %s = { %s };' % (self.ctype(op.type), tmp, ', '.join(values)))
            self.emit_line('%s = %s;' % (self.reg(op), tmp))
        else:
            self.emit_line('%s = %s;' % (self.reg(op),
                                         self.c_error_value(op.type)))

    def get_attr_expr(self, obj: str, op: Union[GetAttr, SetAttr], decl_cl: ClassIR) -> str:
        """Generate attribute accessor for normal (non-property) access.

        This either has a form like obj->attr_name for attributes defined in non-trait
        classes, and *(obj + attr_offset) for attributes defined by traits. We also
        insert all necessary C casts here.

        For the trait form this also emits a statement computing the offset
        and declares the size_t local that holds it.
        """
        cast = '({} *)'.format(op.class_type.struct_name(self.emitter.names))
        if decl_cl.is_trait and op.class_type.class_ir.is_trait:
            # For pure trait access find the offset first, offsets
            # are ordered by attribute position in the cl.attributes dict.
            # TODO: pre-calculate the mapping to make this faster.
            trait_attr_index = list(decl_cl.attributes).index(op.attr)
            # TODO: reuse these names somehow?
            offset = self.emitter.temp_name()
            self.declarations.emit_line('size_t {};'.format(offset))
            self.emitter.emit_line('{} = {};'.format(
                offset,
                'CPy_FindAttrOffset({}, {}, {})'.format(
                    self.emitter.type_struct_name(decl_cl),
                    '({}{})->vtable'.format(cast, obj),
                    trait_attr_index,
                )
            ))
            attr_cast = '({} *)'.format(self.ctype(op.class_type.attr_type(op.attr)))
            return '*{}((char *){} + {})'.format(attr_cast, obj, offset)
        else:
            # Cast to something non-trait. Note: for this to work, all struct
            # members for non-trait classes must obey monotonic linear growth.
            if op.class_type.class_ir.is_trait:
                assert not decl_cl.is_trait
                cast = '({} *)'.format(decl_cl.struct_name(self.emitter.names))
            return '({}{})->{}'.format(
                cast, obj, self.emitter.attr(op.attr)
            )

    def visit_get_attr(self, op: GetAttr) -> None:
        """Emit an attribute read.

        Properties go through the vtable (CPY_GET_ATTR). Plain attributes
        are read directly; for refcounted attribute types an undefined value
        sets AttributeError, otherwise the attribute is inc-ref'd.
        """
        dest = self.reg(op)
        obj = self.reg(op.obj)
        rtype = op.class_type
        cl = rtype.class_ir
        attr_rtype, decl_cl = cl.attr_details(op.attr)
        if cl.get_method(op.attr):
            # Properties are essentially methods, so use vtable access for them.
            version = '_TRAIT' if cl.is_trait else ''
            self.emit_line('%s = CPY_GET_ATTR%s(%s, %s, %d, %s, %s); /* %s */' % (
                dest,
                version,
                obj,
                self.emitter.type_struct_name(rtype.class_ir),
                rtype.getter_index(op.attr),
                rtype.struct_name(self.names),
                self.ctype(rtype.attr_type(op.attr)),
                op.attr))
        else:
            # Otherwise, use direct or offset struct access.
            attr_expr = self.get_attr_expr(obj, op, decl_cl)
            self.emitter.emit_line('{} = {};'.format(dest, attr_expr))
            if attr_rtype.is_refcounted:
                # The undefined check presumably opens the 'if' block whose
                # 'else' arm and closing brace are emitted below.
                self.emitter.emit_undefined_attr_check(
                    attr_rtype, attr_expr, '==', unlikely=True
                )
                exc_class = 'PyExc_AttributeError'
                self.emitter.emit_lines(
                    'PyErr_SetString({}, "attribute {} of {} undefined");'.format(
                        exc_class, repr(op.attr), repr(cl.name)),
                    '} else {')
                self.emitter.emit_inc_ref(attr_expr, attr_rtype)
                self.emitter.emit_line('}')

    def visit_set_attr(self, op: SetAttr) -> None:
        """Emit an attribute write.

        Properties go through the vtable (CPY_SET_ATTR). For plain
        attributes of refcounted type, a defined old value is dec-ref'd
        first; then the new value is stored and the op's result set to 1.
        """
        dest = self.reg(op)
        obj = self.reg(op.obj)
        src = self.reg(op.src)
        rtype = op.class_type
        cl = rtype.class_ir
        attr_rtype, decl_cl = cl.attr_details(op.attr)
        if cl.get_method(op.attr):
            # Again, use vtable access for properties...
            version = '_TRAIT' if cl.is_trait else ''
            self.emit_line('%s = CPY_SET_ATTR%s(%s, %s, %d, %s, %s, %s); /* %s */' % (
                dest,
                version,
                obj,
                self.emitter.type_struct_name(rtype.class_ir),
                rtype.setter_index(op.attr),
                src,
                rtype.struct_name(self.names),
                self.ctype(rtype.attr_type(op.attr)),
                op.attr))
        else:
            # ...and struct access for normal attributes.
            attr_expr = self.get_attr_expr(obj, op, decl_cl)
            if attr_rtype.is_refcounted:
                self.emitter.emit_undefined_attr_check(attr_rtype, attr_expr, '!=')
                self.emitter.emit_dec_ref(attr_expr, attr_rtype)
                self.emitter.emit_line('}')
            # This steals the reference to src, so we don't need to increment the arg
            self.emitter.emit_lines(
                '{} = {};'.format(attr_expr, src),
                '{} = 1;'.format(dest),
            )

    # Maps a static namespace to the name prefix used for its C statics.
    PREFIX_MAP = {
        NAMESPACE_STATIC: STATIC_PREFIX,
        NAMESPACE_TYPE: TYPE_PREFIX,
        NAMESPACE_MODULE: MODULE_PREFIX,
    }  # type: Final

    def _annotation_comment(self, ann: object) -> str:
        """Return a trailing C comment showing repr(ann), or '' if unsuitable.

        Nothing is produced when ann is falsy, or when its repr contains a
        C comment delimiter or a NUL character, since embedding those in a
        comment could corrupt the generated source.
        """
        if not ann:
            return ''
        text = repr(ann)
        if any(x in text for x in ('/*', '*/', '\0')):
            return ''
        return ' /* %s */' % text

    def visit_load_static(self, op: LoadStatic) -> None:
        """Emit a load of a static, casting type statics to PyObject *."""
        dest = self.reg(op)
        prefix = self.PREFIX_MAP[op.namespace]
        name = self.emitter.static_name(op.identifier, op.module_name, prefix)
        if op.namespace == NAMESPACE_TYPE:
            name = '(PyObject *)%s' % name
        ann = self._annotation_comment(op.ann)
        self.emit_line('%s = %s;%s' % (dest, name, ann))

    def visit_init_static(self, op: InitStatic) -> None:
        """Emit a store into a static, followed by an inc-ref of the static.

        Values stored into type statics are cast to PyTypeObject *.
        """
        value = self.reg(op.value)
        prefix = self.PREFIX_MAP[op.namespace]
        name = self.emitter.static_name(op.identifier, op.module_name, prefix)
        if op.namespace == NAMESPACE_TYPE:
            value = '(PyTypeObject *)%s' % value
        self.emit_line('%s = %s;' % (name, value))
        self.emit_inc_ref(name, op.value.type)

    def visit_tuple_get(self, op: TupleGet) -> None:
        """Emit a read of field f<index> of a tuple struct, then inc-ref it."""
        dest = self.reg(op)
        src = self.reg(op.src)
        self.emit_line('{} = {}.f{};'.format(dest, src, op.index))
        self.emit_inc_ref(dest, op.type)

    def get_dest_assign(self, dest: Value) -> str:
        """Return '<reg> = ' for dest, or '' if dest is void."""
        if not dest.is_void:
            return self.reg(dest) + ' = '
        else:
            return ''

    def visit_call(self, op: Call) -> None:
        """Call native function."""
        dest = self.get_dest_assign(op)
        args = ', '.join(self.reg(arg) for arg in op.args)
        lib = self.emitter.get_group_prefix(op.fn)
        cname = op.fn.cname(self.names)
        self.emit_line('%s%s%s%s(%s);' % (dest, lib, NATIVE_PREFIX, cname, args))

    def visit_method_call(self, op: MethodCall) -> None:
        """Call native method.

        Final methods are called directly; all others are looked up through
        the vtable with CPY_GET_METHOD.
        """
        dest = self.get_dest_assign(op)
        obj = self.reg(op.obj)

        rtype = op.receiver_type
        class_ir = rtype.class_ir
        name = op.method
        method_idx = rtype.method_index(name)
        method = rtype.class_ir.get_method(name)
        assert method is not None

        # Can we call the method directly, bypassing vtable?
        is_direct = class_ir.is_method_final(name)

        # The first argument gets omitted for static methods and
        # turned into the class for class methods
        obj_args = (
            [] if method.decl.kind == FUNC_STATICMETHOD else
            ['(PyObject *)Py_TYPE({})'.format(obj)] if method.decl.kind == FUNC_CLASSMETHOD else
            [obj])
        args = ', '.join(obj_args + [self.reg(arg) for arg in op.args])
        mtype = native_function_type(method, self.emitter)
        version = '_TRAIT' if rtype.class_ir.is_trait else ''
        if is_direct:
            # Directly call method, without going through the vtable.
            lib = self.emitter.get_group_prefix(method.decl)
            self.emit_line('{}{}{}{}({});'.format(
                dest, lib, NATIVE_PREFIX, method.cname(self.names), args))
        else:
            # Call using vtable.
            self.emit_line('{}CPY_GET_METHOD{}({}, {}, {}, {}, {})({}); /* {} */'.format(
                dest, version, obj, self.emitter.type_struct_name(rtype.class_ir),
                method_idx, rtype.struct_name(self.names), mtype, args, op.method))

    def visit_inc_ref(self, op: IncRef) -> None:
        """Emit a reference count increment of op.src."""
        src = self.reg(op.src)
        self.emit_inc_ref(src, op.src.type)

    def visit_dec_ref(self, op: DecRef) -> None:
        """Emit a reference count decrement of op.src, honoring op.is_xdec."""
        src = self.reg(op.src)
        self.emit_dec_ref(src, op.src.type, op.is_xdec)

    def visit_box(self, op: Box) -> None:
        """Emit code boxing op.src into op."""
        self.emitter.emit_box(self.reg(op.src), self.reg(op), op.src.type, can_borrow=True)

    def visit_cast(self, op: Cast) -> None:
        """Emit code casting op.src to op.type."""
        self.emitter.emit_cast(self.reg(op.src), self.reg(op), op.type,
                               src_type=op.src.type)

    def visit_unbox(self, op: Unbox) -> None:
        """Emit code unboxing op.src into a value of op.type."""
        self.emitter.emit_unbox(self.reg(op.src), self.reg(op), op.type)

    def visit_unreachable(self, op: Unreachable) -> None:
        """Emit a CPy_Unreachable() marker."""
        self.emitter.emit_line('CPy_Unreachable();')

    def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
        """Emit code that sets a standard exception and stores 0 into op.

        A string value is embedded as a C string literal, a Value is passed
        as the exception object, and no value sets the bare exception.
        """
        if op.value is not None:
            if isinstance(op.value, str):
                # Escape for a C string literal. Backslashes are doubled
                # first so the escapes added afterwards are left intact.
                # TODO: Escape other control characters as well
                message = (op.value.replace('\\', '\\\\')
                           .replace('"', '\\"')
                           .replace('\n', '\\n'))
                self.emitter.emit_line(
                    'PyErr_SetString(PyExc_{}, "{}");'.format(op.class_name, message))
            elif isinstance(op.value, Value):
                self.emitter.emit_line(
                    'PyErr_SetObject(PyExc_{}, {});'.format(op.class_name,
                                                            self.emitter.reg(op.value)))
            else:
                assert False, 'op value type must be either str or Value'
        else:
            self.emitter.emit_line('PyErr_SetNone(PyExc_{});'.format(op.class_name))
        self.emitter.emit_line('{} = 0;'.format(self.reg(op)))

    def visit_call_c(self, op: CallC) -> None:
        """Emit a call to the C function named by op.function_name."""
        # get_dest_assign already yields '' for void ops.
        dest = self.get_dest_assign(op)
        args = ', '.join(self.reg(arg) for arg in op.args)
        self.emitter.emit_line("{}{}({});".format(dest, op.function_name, args))

    def visit_truncate(self, op: Truncate) -> None:
        """Emit an integer truncation."""
        dest = self.reg(op)
        value = self.reg(op.src)
        # for C backend the generated code are straight assignments
        self.emit_line("{} = {};".format(dest, value))

    def visit_load_global(self, op: LoadGlobal) -> None:
        """Emit a load of the C global named by op.identifier."""
        dest = self.reg(op)
        ann = self._annotation_comment(op.ann)
        self.emit_line('%s = %s;%s' % (dest, op.identifier, ann))

    def visit_int_op(self, op: IntOp) -> None:
        """Emit a binary integer operation using the op's C operator."""
        dest = self.reg(op)
        lhs = self.reg(op.lhs)
        rhs = self.reg(op.rhs)
        self.emit_line('%s = %s %s %s;' % (dest, lhs, op.op_str[op.op], rhs))

    def visit_comparison_op(self, op: ComparisonOp) -> None:
        """Emit an integer comparison.

        Ordered comparisons cast both operands so that the C comparison has
        the signedness the op asks for; other comparisons add no cast.
        """
        dest = self.reg(op)
        lhs = self.reg(op.lhs)
        rhs = self.reg(op.rhs)
        lhs_cast = ""
        rhs_cast = ""
        signed_op = {ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE}
        unsigned_op = {ComparisonOp.ULT, ComparisonOp.UGT, ComparisonOp.ULE, ComparisonOp.UGE}
        if op.op in signed_op:
            lhs_cast = self.emit_signed_int_cast(op.lhs.type)
            rhs_cast = self.emit_signed_int_cast(op.rhs.type)
        elif op.op in unsigned_op:
            lhs_cast = self.emit_unsigned_int_cast(op.lhs.type)
            rhs_cast = self.emit_unsigned_int_cast(op.rhs.type)
        self.emit_line('%s = %s%s %s %s%s;' % (dest, lhs_cast, lhs,
                                               op.op_str[op.op], rhs_cast, rhs))

    def visit_load_mem(self, op: LoadMem) -> None:
        """Emit a load through the pointer op.src, read as op.type."""
        dest = self.reg(op)
        src = self.reg(op.src)
        # TODO: we shouldn't dereference to type that are pointer type so far
        loaded_ctype = self.ctype(op.type)
        self.emit_line('%s = *(%s *)%s;' % (dest, loaded_ctype, src))

    def visit_set_mem(self, op: SetMem) -> None:
        """Emit a store of op.src through the pointer op.dest."""
        dest = self.reg(op.dest)
        src = self.reg(op.src)
        dest_type = self.ctype(op.dest_type)
        # clang whines about self assignment (which we might generate
        # for some casts), so don't emit it.
        if dest != src:
            self.emit_line('*(%s *)%s = %s;' % (dest_type, dest, src))

    def visit_get_element_ptr(self, op: GetElementPtr) -> None:
        """Emit code taking the address of a named field of a struct."""
        dest = self.reg(op)
        src = self.reg(op.src)
        # TODO: support tuple type
        assert isinstance(op.src_type, RStruct)
        assert op.field in op.src_type.names, "Invalid field name."
        self.emit_line('%s = (%s)&((%s *)%s)->%s;' % (dest, op.type._ctype, op.src_type.name,
                                                      src, op.field))

    def visit_load_address(self, op: LoadAddress) -> None:
        """Emit code taking the address of a register or of a named C symbol."""
        typ = op.type
        dest = self.reg(op)
        src = self.reg(op.src) if isinstance(op.src, Register) else op.src
        self.emit_line('%s = (%s)&%s;' % (dest, typ._ctype, src))

    # Helpers

    def label(self, label: BasicBlock) -> str:
        """Return the C label of a basic block."""
        return self.emitter.label(label)

    def reg(self, reg: Value) -> str:
        """Return the C expression for a value.

        Integer constants are rendered as literals, with a zero of pointer
        type spelled NULL. Anything else is delegated to the emitter.
        """
        if isinstance(reg, Integer):
            val = reg.value
            if val == 0 and is_pointer_rprimitive(reg.type):
                return "NULL"
            return str(val)
        else:
            return self.emitter.reg(reg)

    def ctype(self, rtype: RType) -> str:
        """Return the C type for rtype, as given by the emitter."""
        return self.emitter.ctype(rtype)

    def c_error_value(self, rtype: RType) -> str:
        """Return the C error value for rtype, as given by the emitter."""
        return self.emitter.c_error_value(rtype)

    def c_undefined_value(self, rtype: RType) -> str:
        """Return the C undefined value for rtype, as given by the emitter."""
        return self.emitter.c_undefined_value(rtype)

    def emit_line(self, line: str) -> None:
        """Emit one line into the function body."""
        self.emitter.emit_line(line)

    def emit_lines(self, *lines: str) -> None:
        """Emit several lines into the function body."""
        self.emitter.emit_lines(*lines)

    def emit_inc_ref(self, dest: str, rtype: RType) -> None:
        """Emit a reference count increment of the C expression dest."""
        self.emitter.emit_inc_ref(dest, rtype)

    def emit_dec_ref(self, dest: str, rtype: RType, is_xdec: bool) -> None:
        """Emit a reference count decrement of the C expression dest."""
        self.emitter.emit_dec_ref(dest, rtype, is_xdec)

    def emit_declaration(self, line: str) -> None:
        """Emit one line into the function's declarations."""
        self.declarations.emit_line(line)

    def emit_traceback(self, op: Branch) -> None:
        """Emit a CPy_AddTraceback call if the branch has a traceback entry.

        Backslashes in the source path are doubled for the C string literal.
        With DEBUG_ERRORS enabled, an assert that an exception is set is
        emitted as well.
        """
        if op.traceback_entry is not None:
            globals_static = self.emitter.static_name('globals', self.module_name)
            self.emit_line('CPy_AddTraceback("%s", "%s", %d, %s);' % (
                self.source_path.replace("\\", "\\\\"),
                op.traceback_entry[0],
                op.traceback_entry[1],
                globals_static))
            if DEBUG_ERRORS:
                self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')

    def emit_signed_int_cast(self, type: RType) -> str:
        """Return the C cast prefix for a signed comparison operand.

        Tagged integers are cast to Py_ssize_t; other types need no cast.
        """
        if is_tagged(type):
            return '(Py_ssize_t)'
        else:
            return ''

    def emit_unsigned_int_cast(self, type: RType) -> str:
        """Return the C cast prefix for an unsigned comparison operand.

        32-bit and 64-bit integers are cast to the unsigned type of the same
        width; other types need no cast.
        """
        if is_int32_rprimitive(type):
            return '(uint32_t)'
        elif is_int64_rprimitive(type):
            return '(uint64_t)'
        else:
            return ''