| """Constant folding of IR values. |
| |
| For example, 3 + 5 can be constant folded into 8. |
| """ |
| |
| from typing import Optional, Union |
| from typing_extensions import Final |
| |
| from mypy.nodes import Expression, IntExpr, StrExpr, OpExpr, UnaryExpr, NameExpr, MemberExpr, Var |
| from mypyc.irbuild.builder import IRBuilder |
| |
| |
| # All possible result types of constant folding |
| ConstantValue = Union[int, str] |
| CONST_TYPES: Final = (int, str) |
| |
| |
| def constant_fold_expr(builder: IRBuilder, expr: Expression) -> Optional[ConstantValue]: |
| """Return the constant value of an expression for supported operations. |
| |
| Return None otherwise. |
| """ |
| if isinstance(expr, IntExpr): |
| return expr.value |
| if isinstance(expr, StrExpr): |
| return expr.value |
| elif isinstance(expr, NameExpr): |
| node = expr.node |
| if isinstance(node, Var) and node.is_final: |
| value = node.final_value |
| if isinstance(value, (CONST_TYPES)): |
| return value |
| elif isinstance(expr, MemberExpr): |
| final = builder.get_final_ref(expr) |
| if final is not None: |
| fn, final_var, native = final |
| if final_var.is_final: |
| value = final_var.final_value |
| if isinstance(value, (CONST_TYPES)): |
| return value |
| elif isinstance(expr, OpExpr): |
| left = constant_fold_expr(builder, expr.left) |
| right = constant_fold_expr(builder, expr.right) |
| if isinstance(left, int) and isinstance(right, int): |
| return constant_fold_binary_int_op(expr.op, left, right) |
| elif isinstance(left, str) and isinstance(right, str): |
| return constant_fold_binary_str_op(expr.op, left, right) |
| elif isinstance(expr, UnaryExpr): |
| value = constant_fold_expr(builder, expr.expr) |
| if isinstance(value, int): |
| return constant_fold_unary_int_op(expr.op, value) |
| return None |
| |
| |
| def constant_fold_binary_int_op(op: str, left: int, right: int) -> Optional[int]: |
| if op == '+': |
| return left + right |
| if op == '-': |
| return left - right |
| elif op == '*': |
| return left * right |
| elif op == '//': |
| if right != 0: |
| return left // right |
| elif op == '%': |
| if right != 0: |
| return left % right |
| elif op == '&': |
| return left & right |
| elif op == '|': |
| return left | right |
| elif op == '^': |
| return left ^ right |
| elif op == '<<': |
| if right >= 0: |
| return left << right |
| elif op == '>>': |
| if right >= 0: |
| return left >> right |
| elif op == '**': |
| if right >= 0: |
| return left ** right |
| return None |
| |
| |
| def constant_fold_unary_int_op(op: str, value: int) -> Optional[int]: |
| if op == '-': |
| return -value |
| elif op == '~': |
| return ~value |
| elif op == '+': |
| return value |
| return None |
| |
| |
| def constant_fold_binary_str_op(op: str, left: str, right: str) -> Optional[str]: |
| if op == '+': |
| return left + right |
| return None |