Flush keep alives before taking a branch
diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 310f28d..d515470 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py
@@ -970,12 +970,14 @@ is_short_int_rprimitive(rhs.type)))): # We can skip the tag check check = self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line) + self.flush_keep_alives() self.add(Branch(check, true, false, Branch.BOOL)) return op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op] int_block, short_int_block = BasicBlock(), BasicBlock() check_lhs = self.check_tagged_short_int(lhs, line, negated=True) if is_eq or is_short_int_rprimitive(rhs.type): + self.flush_keep_alives() self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL)) else: # For non-equality logical ops (less/greater than, etc.), need to check both sides @@ -983,6 +985,7 @@ self.add(Branch(check_lhs, int_block, rhs_block, Branch.BOOL)) self.activate_block(rhs_block) check_rhs = self.check_tagged_short_int(rhs, line, negated=True) + self.flush_keep_alives() self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL)) # Arbitrary integers (slow path) self.activate_block(int_block) @@ -994,6 +997,7 @@ if negate_result: self.add(Branch(call, false, true, Branch.BOOL)) else: + self.flush_keep_alives() self.add(Branch(call, true, false, Branch.BOOL)) # Short integers (fast path) self.activate_block(short_int_block)
diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index 9fe5aba..5a574ac 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test
@@ -1289,3 +1289,47 @@ r1 = r0.e r2 = r1.x return r2 + +[case testBorrowResultOfCustomGetItemInIfStatement] +from typing import List + +class C: + def __getitem__(self, x: int) -> List[int]: + return [] + +def f(x: C) -> None: + # In this case the keep_alive must come before the branch, as otherwise + # reference count transform will get confused. + if x[1][0] == 2: + y = 1 + else: + y = 2 +[out] +def C.__getitem__(self, x): + self :: __main__.C + x :: int + r0 :: list +L0: + r0 = PyList_New(0) + return r0 +def f(x): + x :: __main__.C + r0 :: list + r1 :: object + r2 :: int + r3 :: bit + y :: int +L0: + r0 = x.__getitem__(2) + r1 = CPyList_GetItemShortBorrow(r0, 0) + r2 = unbox(int, r1) + r3 = r2 == 4 + keep_alive r0 + if r3 goto L1 else goto L2 :: bool +L1: + y = 2 + goto L3 +L2: + y = 4 +L3: + return 1