[mypyc] Specialize `s[i] == 'x'` to a codepoint int compare (#21579)
7th PR of #21418
Lowers `s[i] == 'x'` (plus `s[i] != 'x'` and the operand-swapped forms
`'x' == s[i]` / `'x' != s[i]`) to a bounds-checked codepoint read + int
compare, instead of `CPyStr_GetItem` + `CPyStr_EqualLiteral`, which may
allocate a 1-character `PyUnicode` per iteration. No annotations are
required for this optimization.
On microbenchmarks (one compare per iteration in a hot loop over a
~2.5M-codepoint SQL-like string), the comparison is ~3.6x faster.
<br />
Some follow-up optimizations that may be worth doing (I can work on them):
- `in` operator, e.g. `s[i] in ('a', 'b', 'c')` --> fuse into one check with N
  int comparisons
- Ordering comparisons, e.g. `s[i] < 'x'` --> need to expand the supported op set
- `s[i] == s[j]` --> need to drop the literal-only guard
diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py
index e8d22a0..f953dd3 100644
--- a/mypyc/irbuild/expression.py
+++ b/mypyc/irbuild/expression.py
@@ -93,9 +93,12 @@
is_list_rprimitive,
is_none_rprimitive,
is_object_rprimitive,
+ is_str_rprimitive,
+ is_tagged,
is_tuple_rprimitive,
object_rprimitive,
set_rprimitive,
+ short_int_rprimitive,
vec_api_by_item_type,
)
from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional
@@ -119,6 +122,7 @@
apply_dunder_specialization,
apply_function_specialization,
apply_method_specialization,
+ translate_getitem_with_bounds_check,
translate_object_new,
translate_object_setattr,
)
@@ -137,7 +141,12 @@
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op
from mypyc.primitives.set_ops import set_add_op, set_in_op, set_update_op
-from mypyc.primitives.str_ops import str_slice_op
+from mypyc.primitives.str_ops import (
+ str_adjust_index_op,
+ str_get_item_unsafe_as_int_op,
+ str_range_check_op,
+ str_slice_op,
+)
from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op
# Name and attribute references
@@ -918,6 +927,16 @@
return result
if len(e.operators) == 1:
+ # s[i] == 'x' / s[i] != 'x' (and the symmetric RHS) -> int compare of
+ # codepoints. Skips the per-iteration 1-char str allocation/lookup and
+ # generic str equality call.
+ if first_op in ("==", "!="):
+ result = try_specialize_str_index_compare(
+ builder, first_op, e.operands[0], e.operands[1], e.line
+ )
+ if result is not None:
+ return result
+
# Special some common simple cases
if first_op in ("is", "is not"):
right_expr = e.operands[1]
@@ -960,6 +979,50 @@
return go(0, builder.accept(e.operands[0]))
+def try_specialize_str_index_compare(
+ builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int
+) -> Value | None:
+ """Specialize `s[i] == 'x'` / `s[i] != 'x'` (and the symmetric form with
+ operands swapped) into an int compare of codepoints.
+
+ Returns None if the pattern doesn't match: the indexed base must be str,
+ the index must be an integer, and the literal must be a 1-character str.
+ Multi-character or empty literals fall through to the generic str compare
+ (which still returns False for them, matching today's behavior).
+ """
+ # Normalize so the IndexExpr is on the left.
+ if isinstance(rhs, IndexExpr) and not isinstance(lhs, IndexExpr):
+ tmp = lhs
+ lhs, rhs = rhs, tmp
+ # Shape: s[i] {==, !=} "x" where "x" is exactly one codepoint.
+ if (
+ not isinstance(lhs, IndexExpr)
+ or not isinstance(rhs, StrExpr)
+ or len(rhs.value) != 1
+ or not is_str_rprimitive(builder.node_type(lhs.base))
+ ):
+ return None
+ index_type = builder.node_type(lhs.index)
+ if not (is_tagged(index_type) or is_fixed_width_rtype(index_type)):
+ return None
+
+ # ord(s[i]) with bounds check; raises IndexError for out-of-range indices,
+ # matching the behavior of the generic s[i] path.
+ codepoint = translate_getitem_with_bounds_check(
+ builder,
+ lhs.base,
+ [lhs.index],
+ lhs,
+ str_adjust_index_op,
+ str_range_check_op,
+ str_get_item_unsafe_as_int_op,
+ )
+ if codepoint is None:
+ return None
+ literal_cp = Integer(ord(rhs.value), short_int_rprimitive, line)
+ return builder.binary_op(codepoint, literal_cp, op, line)
+
+
def try_specialize_in_expr(
builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int
) -> Value | None:
diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test
index 81cd5bd..b16057f 100644
--- a/mypyc/test-data/irbuild-str.test
+++ b/mypyc/test-data/irbuild-str.test
@@ -1025,3 +1025,153 @@
L0:
r0 = CPyStr_IsDigit(x)
return r0
+
+[case testStrIndexEqLiteral_64bit]
+def is_comma(s: str, i: int) -> bool:
+ return s[i] == ","
+def is_comma_swapped(s: str, i: int) -> bool:
+ return "," == s[i]
+def is_comma_ne(s: str, i: int) -> bool:
+ return s[i] != ","
+[out]
+def is_comma(s, i):
+ s :: str
+ i :: int
+ r0 :: native_int
+ r1 :: bit
+ r2, r3 :: i64
+ r4 :: ptr
+ r5 :: c_ptr
+ r6, r7 :: i64
+ r8, r9 :: bool
+ r10 :: short_int
+ r11 :: bit
+L0:
+ r0 = i & 1
+ r1 = r0 == 0
+ if r1 goto L1 else goto L2 :: bool
+L1:
+ r2 = i >> 1
+ r3 = r2
+ goto L3
+L2:
+ r4 = i ^ 1
+ r5 = r4
+ r6 = CPyLong_AsInt64(r5)
+ r3 = r6
+ keep_alive i
+L3:
+ r7 = CPyStr_AdjustIndex(s, r3)
+ r8 = CPyStr_RangeCheck(s, r7)
+ if r8 goto L5 else goto L4 :: bool
+L4:
+ r9 = raise IndexError('index out of range')
+ unreachable
+L5:
+ r10 = CPyStr_GetItemUnsafeAsInt(s, r7)
+ r11 = int_eq r10, 88
+ return r11
+def is_comma_swapped(s, i):
+ s :: str
+ i :: int
+ r0 :: native_int
+ r1 :: bit
+ r2, r3 :: i64
+ r4 :: ptr
+ r5 :: c_ptr
+ r6, r7 :: i64
+ r8, r9 :: bool
+ r10 :: short_int
+ r11 :: bit
+L0:
+ r0 = i & 1
+ r1 = r0 == 0
+ if r1 goto L1 else goto L2 :: bool
+L1:
+ r2 = i >> 1
+ r3 = r2
+ goto L3
+L2:
+ r4 = i ^ 1
+ r5 = r4
+ r6 = CPyLong_AsInt64(r5)
+ r3 = r6
+ keep_alive i
+L3:
+ r7 = CPyStr_AdjustIndex(s, r3)
+ r8 = CPyStr_RangeCheck(s, r7)
+ if r8 goto L5 else goto L4 :: bool
+L4:
+ r9 = raise IndexError('index out of range')
+ unreachable
+L5:
+ r10 = CPyStr_GetItemUnsafeAsInt(s, r7)
+ r11 = int_eq r10, 88
+ return r11
+def is_comma_ne(s, i):
+ s :: str
+ i :: int
+ r0 :: native_int
+ r1 :: bit
+ r2, r3 :: i64
+ r4 :: ptr
+ r5 :: c_ptr
+ r6, r7 :: i64
+ r8, r9 :: bool
+ r10 :: short_int
+ r11 :: bit
+L0:
+ r0 = i & 1
+ r1 = r0 == 0
+ if r1 goto L1 else goto L2 :: bool
+L1:
+ r2 = i >> 1
+ r3 = r2
+ goto L3
+L2:
+ r4 = i ^ 1
+ r5 = r4
+ r6 = CPyLong_AsInt64(r5)
+ r3 = r6
+ keep_alive i
+L3:
+ r7 = CPyStr_AdjustIndex(s, r3)
+ r8 = CPyStr_RangeCheck(s, r7)
+ if r8 goto L5 else goto L4 :: bool
+L4:
+ r9 = raise IndexError('index out of range')
+ unreachable
+L5:
+ r10 = CPyStr_GetItemUnsafeAsInt(s, r7)
+ r11 = int_ne r10, 88
+ return r11
+
+[case testStrIndexEqLiteralNoSpecialize]
+def two_char_literal(s: str, i: int) -> bool:
+ # Multi-char literals don't match the specialization; falls through to
+ # the generic str equality path.
+ return s[i] == "ab"
+def empty_literal(s: str, i: int) -> bool:
+ # Empty string literals also fall through; the generic path returns False.
+ return s[i] == ""
+[out]
+def two_char_literal(s, i):
+ s :: str
+ i :: int
+ r0, r1 :: str
+ r2 :: bool
+L0:
+ r0 = CPyStr_GetItem(s, i)
+ r1 = 'ab'
+ r2 = CPyStr_EqualLiteral(r0, r1, 2)
+ return r2
+def empty_literal(s, i):
+ s :: str
+ i :: int
+ r0, r1 :: str
+ r2 :: bool
+L0:
+ r0 = CPyStr_GetItem(s, i)
+ r1 = ''
+ r2 = CPyStr_EqualLiteral(r0, r1, 0)
+ return r2
diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test
index ec662da..81b8558 100644
--- a/mypyc/test-data/run-strings.test
+++ b/mypyc/test-data/run-strings.test
@@ -1412,3 +1412,64 @@
assert not "\u00e9\u00e8".isdigit()
assert not "123\u00e9".isdigit()
assert not "\U0001d7ce!".isdigit()
+
+[case testStrIndexEqLiteralSpecialize]
+from testutil import assertRaises
+
+# The specializer fires on the AST shape `IndexExpr == StrLiteral` (or the
+# symmetric swap, and `!=`). The literal has to be a real source-level
+# string literal (can't be passed in as a parameter), so each test
+# function pins one distinct shape.
+
+def eq_comma(s: str, i: int) -> bool:
+ # Specialized: s[i] == "x".
+ return s[i] == ","
+
+def ne_comma(s: str, i: int) -> bool:
+ # Specialized: s[i] != "x".
+ return s[i] != ","
+
+def comma_eq(s: str, i: int) -> bool:
+ # Specialized: "x" == s[i]. Operand-swap is normalized.
+ return "," == s[i]
+
+def eq_two_chars(s: str, i: int) -> bool:
+ # Not specialized: literal isn't 1 char. Falls through to the generic
+ # str compare, which returns False since s[i] is always 1 codepoint.
+ return s[i] == "ab"
+
+def eq_empty(s: str, i: int) -> bool:
+ # Not specialized: empty literal. Same fall-through.
+ return s[i] == ""
+
+def test_specialized_path() -> None:
+ s = "a,b" # comma at index 1
+ assert eq_comma(s, 1)
+ assert not eq_comma(s, 0)
+ assert not eq_comma(s, 2)
+ # != inverts.
+ assert ne_comma(s, 0)
+ assert not ne_comma(s, 1)
+ # Literal on the LHS is normalized to the same shape.
+ assert comma_eq(s, 1)
+ assert not comma_eq(s, 0)
+
+def test_negative_index_is_adjusted() -> None:
+ s = "a,b"
+ assert eq_comma(s, -2) # -2 -> 1 (',')
+ assert not eq_comma(s, -1) # -1 -> 2 ('b')
+
+def test_non_1char_literal_falls_through() -> None:
+ s = "a,b"
+ # Generic str compare answers False because s[i] has length 1.
+ assert not eq_two_chars(s, 0)
+ assert not eq_two_chars(s, 1)
+ assert not eq_empty(s, 0)
+
+def test_out_of_range_raises_indexerror() -> None:
+ # Bounds-check semantics match the unspecialized s[i] path.
+ s = "a,b"
+ with assertRaises(IndexError):
+ eq_comma(s, 3)
+ with assertRaises(IndexError):
+ eq_comma(s, -4)