| """Constant folding of IR values. |
| |
| For example, 3 + 5 can be constant folded into 8. |
| |
| This is mostly like mypy.constant_fold, but we can bind some additional |
| NameExpr and MemberExpr references here, since we have more knowledge |
| about which definitions can be trusted -- we constant fold only references |
| to other compiled modules in the same compilation unit. |
| """ |
| |
| from __future__ import annotations |
| |
| from typing import Final, Union |
| |
| from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op |
| from mypy.nodes import ( |
| BytesExpr, |
| ComplexExpr, |
| Expression, |
| FloatExpr, |
| IntExpr, |
| MemberExpr, |
| NameExpr, |
| OpExpr, |
| StrExpr, |
| UnaryExpr, |
| Var, |
| ) |
| from mypyc.irbuild.builder import IRBuilder |
| from mypyc.irbuild.util import bytes_from_str |
| |
| # All possible result types of constant folding |
| ConstantValue = Union[int, float, complex, str, bytes] |
| CONST_TYPES: Final = (int, float, complex, str, bytes) |
| |
| |
| def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None: |
| """Return the constant value of an expression for supported operations. |
| |
| Return None otherwise. |
| """ |
| if isinstance(expr, IntExpr): |
| return expr.value |
| if isinstance(expr, FloatExpr): |
| return expr.value |
| if isinstance(expr, StrExpr): |
| return expr.value |
| if isinstance(expr, BytesExpr): |
| return bytes_from_str(expr.value) |
| if isinstance(expr, ComplexExpr): |
| return expr.value |
| elif isinstance(expr, NameExpr): |
| node = expr.node |
| if isinstance(node, Var) and node.is_final: |
| final_value = node.final_value |
| if isinstance(final_value, (CONST_TYPES)): |
| return final_value |
| elif isinstance(expr, MemberExpr): |
| final = builder.get_final_ref(expr) |
| if final is not None: |
| fn, final_var, native = final |
| if final_var.is_final: |
| final_value = final_var.final_value |
| if isinstance(final_value, (CONST_TYPES)): |
| return final_value |
| elif isinstance(expr, OpExpr): |
| left = constant_fold_expr(builder, expr.left) |
| right = constant_fold_expr(builder, expr.right) |
| if left is not None and right is not None: |
| return constant_fold_binary_op_extended(expr.op, left, right) |
| elif isinstance(expr, UnaryExpr): |
| value = constant_fold_expr(builder, expr.expr) |
| if value is not None and not isinstance(value, bytes): |
| return constant_fold_unary_op(expr.op, value) |
| return None |
| |
| |
| def constant_fold_binary_op_extended( |
| op: str, left: ConstantValue, right: ConstantValue |
| ) -> ConstantValue | None: |
| """Like mypy's constant_fold_binary_op(), but includes bytes support. |
| |
| mypy cannot use constant folded bytes easily so it's simpler to only support them in mypyc. |
| """ |
| if not isinstance(left, bytes) and not isinstance(right, bytes): |
| return constant_fold_binary_op(op, left, right) |
| |
| if op == "+" and isinstance(left, bytes) and isinstance(right, bytes): |
| return left + right |
| elif op == "*" and isinstance(left, bytes) and isinstance(right, int): |
| return left * right |
| elif op == "*" and isinstance(left, int) and isinstance(right, bytes): |
| return left * right |
| |
| return None |