| from contextlib import contextmanager |
| from typing import Generator, List, Optional, Tuple |
| |
| from mypy.nodes import MatchStmt, NameExpr, TypeInfo |
| from mypy.patterns import ( |
| AsPattern, |
| ClassPattern, |
| MappingPattern, |
| OrPattern, |
| Pattern, |
| SequencePattern, |
| SingletonPattern, |
| StarredPattern, |
| ValuePattern, |
| ) |
| from mypy.traverser import TraverserVisitor |
| from mypy.types import Instance, TupleType, get_proper_type |
| from mypyc.ir.ops import BasicBlock, Value |
| from mypyc.ir.rtypes import object_rprimitive |
| from mypyc.irbuild.builder import IRBuilder |
| from mypyc.primitives.dict_ops import ( |
| dict_copy, |
| dict_del_item, |
| mapping_has_key, |
| supports_mapping_protocol, |
| ) |
| from mypyc.primitives.generic_ops import generic_ssize_t_len_op |
| from mypyc.primitives.list_ops import ( |
| sequence_get_item, |
| sequence_get_slice, |
| supports_sequence_protocol, |
| ) |
| from mypyc.primitives.misc_ops import fast_isinstance_op, slow_isinstance_op |
| |
# Fully qualified names of the builtin classes that PEP 634 lets a class
# pattern match with a single positional subpattern.
# See: https://peps.python.org/pep-0634/#class-patterns
MATCHABLE_BUILTINS = {
    f"builtins.{type_name}"
    for type_name in (
        "bool",
        "bytearray",
        "bytes",
        "dict",
        "float",
        "frozenset",
        "int",
        "list",
        "set",
        "str",
        "tuple",
    )
}
| |
| |
class MatchVisitor(TraverserVisitor):
    """Emit IR for a ``match`` statement by visiting each case pattern.

    Every ``visit_*_pattern`` method emits the checks for one pattern against
    ``self.subject`` and finishes by branching (or jumping) to
    ``self.code_block`` when the pattern matched and to ``self.next_block``
    when it did not.  Neither of those blocks is active when the method
    returns; the caller activates them.
    """

    builder: IRBuilder
    # Block that is jumped to once the checks emitted so far have succeeded.
    code_block: BasicBlock
    # Block that is jumped to as soon as the current pattern fails.
    next_block: BasicBlock
    # Block that follows the whole match statement.
    final_block: BasicBlock
    # Value the current (sub)pattern is matched against.
    subject: Value
    match: MatchStmt

    # Enclosing "<pattern> as <name>" whose name has not been bound yet.
    as_pattern: Optional[AsPattern] = None

    def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None:
        """Evaluate the match subject and set up the initial blocks."""
        self.builder = builder

        self.code_block = BasicBlock()
        self.next_block = BasicBlock()
        self.final_block = BasicBlock()

        self.match = match_node
        self.subject = builder.accept(match_node.subject)

    def _advance_code_block(self) -> None:
        """Activate the pending success block and allocate the next one.

        Code emitted after this call only runs if every check emitted so far
        has succeeded; the new ``self.code_block`` becomes the success target
        of the checks that follow.
        """
        self.builder.activate_block(self.code_block)
        self.code_block = BasicBlock()

    def build_match_body(self, index: int) -> None:
        """Emit the guard (if any) and the body of the case at ``index``."""
        self.builder.activate_block(self.code_block)

        guard = self.match.guards[index]

        if guard:
            self.code_block = BasicBlock()

            # A false guard falls through to the next case.
            cond = self.builder.accept(guard)
            self.builder.add_bool_branch(cond, self.code_block, self.next_block)

            self.builder.activate_block(self.code_block)

        self.builder.accept(self.match.bodies[index])
        self.builder.goto(self.final_block)

    def visit_match_stmt(self, m: MatchStmt) -> None:
        """Emit every case of ``m`` in order, then continue in the final block."""
        for i, pattern in enumerate(m.patterns):
            self.code_block = BasicBlock()
            self.next_block = BasicBlock()

            pattern.accept(self)

            self.build_match_body(i)
            self.builder.activate_block(self.next_block)

        self.builder.goto_and_activate(self.final_block)

    def visit_value_pattern(self, pattern: ValuePattern) -> None:
        """Match the subject against the pattern's expression using ``==``."""
        value = self.builder.accept(pattern.expr)

        cond = self.builder.binary_op(self.subject, value, "==", pattern.expr.line)

        # NOTE(review): this binds the pattern's value, not the subject, and
        # does so before the branch. PEP 634 specifies that an AS pattern
        # binds the subject -- confirm whether this is intentional.
        self.bind_as_pattern(value)

        self.builder.add_bool_branch(cond, self.code_block, self.next_block)

    def visit_or_pattern(self, pattern: OrPattern) -> None:
        """Try each alternative in turn; the first one that matches wins."""
        backup_block = self.next_block
        self.next_block = BasicBlock()

        for p in pattern.patterns:
            # Hack to ensure the as pattern is bound to each pattern in the
            # "or" pattern, but not every subpattern
            backup = self.as_pattern
            p.accept(self)
            self.as_pattern = backup

            # A failed alternative continues with the following alternative.
            self.builder.activate_block(self.next_block)
            self.next_block = BasicBlock()

        # Reached only when every alternative failed.
        self.next_block = backup_block
        self.builder.goto(self.next_block)

    def _match_args_names(self, pattern: ClassPattern) -> List[str]:
        """Return the attribute names listed in the class's ``__match_args__``.

        The names are read from the literal values recorded in the tuple type
        of ``__match_args__``.
        """
        node = pattern.class_ref.node
        assert isinstance(node, TypeInfo)

        # TypeInfo.get() walks the MRO, so a "__match_args__" inherited from a
        # base class is found too (node.names only holds the class's own
        # symbols).
        ty = node.get("__match_args__")
        assert ty

        match_args_type = get_proper_type(ty.type)
        assert isinstance(match_args_type, TupleType)

        match_args: List[str] = []

        for item in match_args_type.items:
            proper_item = get_proper_type(item)
            assert isinstance(proper_item, Instance) and proper_item.last_known_value

            match_arg = proper_item.last_known_value.value
            assert isinstance(match_arg, str)

            match_args.append(match_arg)

        return match_args

    def visit_class_pattern(self, pattern: ClassPattern) -> None:
        """Emit an isinstance check followed by the attribute subpatterns."""
        # TODO: use faster instance check for native classes (while still
        # making sure to account for inheritance)
        isinstance_op = (
            fast_isinstance_op
            if self.builder.is_builtin_ref_expr(pattern.class_ref)
            else slow_isinstance_op
        )

        cond = self.builder.call_c(
            isinstance_op, [self.subject, self.builder.accept(pattern.class_ref)], pattern.line
        )

        self.builder.add_bool_branch(cond, self.code_block, self.next_block)

        self.bind_as_pattern(self.subject, new_block=True)

        if pattern.positionals:
            if pattern.class_ref.fullname in MATCHABLE_BUILTINS:
                self._advance_code_block()

                # For these builtins the single positional subpattern is
                # matched against the subject itself, not one of its
                # attributes.  Keyword subpatterns (handled below) still
                # apply, so do not return early here.
                pattern.positionals[0].accept(self)

            else:
                match_args = self._match_args_names(pattern)

                for i, expr in enumerate(pattern.positionals):
                    self._advance_code_block()

                    # TODO: use faster "get_attr" method instead when calling on
                    # native or builtin objects
                    positional = self.builder.py_get_attr(
                        self.subject, match_args[i], expr.line
                    )

                    with self.enter_subpattern(positional):
                        expr.accept(self)

        for key, value in zip(pattern.keyword_keys, pattern.keyword_values):
            self._advance_code_block()

            # TODO: same as above "get_attr" comment
            attr = self.builder.py_get_attr(self.subject, key, value.line)

            with self.enter_subpattern(attr):
                value.accept(self)

    def visit_as_pattern(self, pattern: AsPattern) -> None:
        """Handle ``<pattern> as <name>``, a bare capture name, or a wildcard."""
        if pattern.pattern:
            # Defer the binding to the wrapped pattern, which calls
            # bind_as_pattern() once it knows what to bind.
            old_pattern = self.as_pattern
            self.as_pattern = pattern
            pattern.pattern.accept(self)
            self.as_pattern = old_pattern

        elif pattern.name:
            target = self.builder.get_assignment_target(pattern.name)

            self.builder.assign(target, self.subject, pattern.line)

        self.builder.goto(self.code_block)

    def visit_singleton_pattern(self, pattern: SingletonPattern) -> None:
        """Match ``None``, ``True`` or ``False`` by identity (``is``)."""
        if pattern.value is None:
            obj = self.builder.none_object()
        elif pattern.value is True:
            obj = self.builder.true()
        else:
            obj = self.builder.false()

        cond = self.builder.binary_op(self.subject, obj, "is", pattern.line)

        self.builder.add_bool_branch(cond, self.code_block, self.next_block)

    def visit_mapping_pattern(self, pattern: MappingPattern) -> None:
        """Check the mapping protocol, each key, and each value subpattern."""
        is_dict = self.builder.call_c(supports_mapping_protocol, [self.subject], pattern.line)

        self.builder.add_bool_branch(is_dict, self.code_block, self.next_block)

        # Evaluated key values, kept so they can be removed from "**rest".
        keys: List[Value] = []

        for key, value in zip(pattern.keys, pattern.values):
            self._advance_code_block()

            key_value = self.builder.accept(key)
            keys.append(key_value)

            exists = self.builder.call_c(mapping_has_key, [self.subject, key_value], pattern.line)

            self.builder.add_bool_branch(exists, self.code_block, self.next_block)
            self._advance_code_block()

            item = self.builder.gen_method_call(
                self.subject, "__getitem__", [key_value], object_rprimitive, pattern.line
            )

            with self.enter_subpattern(item):
                value.accept(self)

        if pattern.rest:
            self._advance_code_block()

            # "**rest" receives a copy of the subject minus the matched keys.
            rest = self.builder.call_c(dict_copy, [self.subject], pattern.rest.line)

            target = self.builder.get_assignment_target(pattern.rest)

            self.builder.assign(target, rest, pattern.rest.line)

            for key_expr, key_value in zip(pattern.keys, keys):
                self.builder.call_c(dict_del_item, [rest, key_value], key_expr.line)

            self.builder.goto(self.code_block)

    def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None:
        """Check the sequence protocol, the length, and each item subpattern."""
        star_index, capture, patterns = prep_sequence_pattern(seq_pattern)

        is_list = self.builder.call_c(supports_sequence_protocol, [self.subject], seq_pattern.line)

        self.builder.add_bool_branch(is_list, self.code_block, self.next_block)

        self._advance_code_block()

        actual_len = self.builder.call_c(generic_ssize_t_len_op, [self.subject], seq_pattern.line)
        min_len = len(patterns)

        # Without a star pattern the length must match exactly; with one, the
        # subject only needs enough items for the non-starred subpatterns.
        is_long_enough = self.builder.binary_op(
            actual_len,
            self.builder.load_int(min_len),
            "==" if star_index is None else ">=",
            seq_pattern.line,
        )

        self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block)

        for i, pattern in enumerate(patterns):
            self._advance_code_block()

            if star_index is not None and i >= star_index:
                # Subpatterns after the star are indexed from the end of the
                # subject: actual_len - (number of subpatterns from i onwards).
                current = self.builder.binary_op(
                    actual_len, self.builder.load_int(min_len - i), "-", pattern.line
                )

            else:
                current = self.builder.load_int(i)

            item = self.builder.call_c(sequence_get_item, [self.subject, current], pattern.line)

            with self.enter_subpattern(item):
                pattern.accept(self)

        if capture and star_index is not None:
            self._advance_code_block()

            # The star capture spans from star_index up to the first item
            # claimed by the subpatterns that follow the star.
            capture_end = self.builder.binary_op(
                actual_len, self.builder.load_int(min_len - star_index), "-", capture.line
            )

            rest = self.builder.call_c(
                sequence_get_slice,
                [self.subject, self.builder.load_int(star_index), capture_end],
                capture.line,
            )

            target = self.builder.get_assignment_target(capture)
            self.builder.assign(target, rest, capture.line)

            self.builder.goto(self.code_block)

    def bind_as_pattern(self, value: Value, new_block: bool = False) -> None:
        """Bind ``value`` to the pending ``as`` name, if there is one.

        Does nothing unless an enclosing ``<pattern> as <name>`` is waiting to
        be bound.  With ``new_block=True`` the assignment is emitted in the
        current success block, followed by a jump to a fresh success block;
        otherwise it is emitted in whichever block is currently active.
        """
        if self.as_pattern and self.as_pattern.pattern and self.as_pattern.name:
            if new_block:
                self._advance_code_block()

            target = self.builder.get_assignment_target(self.as_pattern.name)
            self.builder.assign(target, value, self.as_pattern.pattern.line)

            # Clear the pending binding so nested subpatterns do not bind it again.
            self.as_pattern = None

            if new_block:
                self.builder.goto(self.code_block)

    @contextmanager
    def enter_subpattern(self, subject: Value) -> Generator[None, None, None]:
        """Temporarily make ``subject`` the value that patterns match against."""
        old_subject = self.subject
        self.subject = subject
        try:
            yield
        finally:
            # Restore the outer subject even if visiting the subpattern raised.
            self.subject = old_subject
| |
| |
def prep_sequence_pattern(
    seq_pattern: SequencePattern,
) -> Tuple[Optional[int], Optional[NameExpr], List[Pattern]]:
    """Separate the star element of a sequence pattern from the other subpatterns.

    Returns a ``(star_index, capture, patterns)`` tuple:

    * ``star_index`` -- position of the starred subpattern inside
      ``seq_pattern.patterns``, or ``None`` if there is none.
    * ``capture`` -- the ``capture`` attribute of that starred subpattern
      (``None`` if there is no starred subpattern).
    * ``patterns`` -- the non-starred subpatterns in their original order.
    """
    star_position: Optional[int] = None
    star_capture: Optional[NameExpr] = None
    fixed_patterns: List[Pattern] = []

    for position, subpattern in enumerate(seq_pattern.patterns):
        if not isinstance(subpattern, StarredPattern):
            fixed_patterns.append(subpattern)
            continue

        star_position = position
        star_capture = subpattern.capture

    return star_position, star_capture, fixed_patterns