# blob: 50c42102f66d33c585c3c8b21621c10319526fe8 [file] [log] [blame] [edit]
# RUN: %PYTHON %s 2>&1 | FileCheck %s
import gc, sys
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import pdl
from mlir.rewrite import *
def log(*args):
    """Print ``args`` to stderr and flush the stream right away.

    Args:
        *args: Values forwarded unchanged to ``print``.
    """
    # flush=True makes print() flush sys.stderr after writing, so each
    # message is pushed out as soon as it is produced.
    print(*args, file=sys.stderr, flush=True)
def run(f):
    """Run test function ``f`` immediately and check that no Context leaks.

    Intended to be used as a decorator: the test body executes at decoration
    time. Nothing is returned, so the decorated name ends up bound to
    ``None``.

    Args:
        f: Zero-argument callable holding the test body.

    Raises:
        AssertionError: If any ``Context`` is still live after ``f`` has
            returned and garbage has been collected.
    """
    # Banner on stderr identifying the test by its function name.
    log("\nTEST:", f.__name__)
    f()
    # Collect before counting so objects that are merely awaiting collection
    # are not counted as live.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0, f"{live_contexts} live Context(s) after {f.__name__}"
def make_pdl_module():
    """Build a module holding a PDL pattern that turns ``arith.addi`` into ``arith.muli``.

    The pattern (``sym_name="addi_to_mul"``, ``benefit=1``) matches an
    ``arith.addi`` whose two operands and single result type are the signless
    64-bit integer type, and replaces it with an ``arith.muli`` on the same
    operands and type.

    No ``Context`` is created here; the caller is expected to have one active
    (``testCustomPass`` calls this inside ``with Context()``).

    Returns:
        The ``Module`` created here, with the pattern inserted into its body.
    """
    with Location.unknown():
        pdl_module = Module.create()
        with InsertionPoint(pdl_module.body):
            # Change all arith.addi with i64 types to arith.muli.
            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
            def pat():
                # Match arith.addi with i64 operands and an i64 result.
                i64_type = pdl.TypeOp(IntegerType.get_signless(64))
                operand0 = pdl.OperandOp(i64_type)
                operand1 = pdl.OperandOp(i64_type)
                op0 = pdl.OperationOp(
                    name="arith.addi", args=[operand0, operand1], types=[i64_type]
                )

                # Replace the matched op with arith.muli on the same operands.
                @pdl.rewrite()
                def rew():
                    new_op = pdl.OperationOp(
                        name="arith.muli", args=[operand0, operand1], types=[i64_type]
                    )
                    pdl.ReplaceOp(op0, with_op=new_op)

        return pdl_module
# CHECK-LABEL: TEST: testCustomPass
@run
def testCustomPass():
    """Exercise Python-defined passes added to a ``PassManager``.

    First pipeline: a plain function pass, a callable-object pass that applies
    the frozen PDL patterns from ``make_pdl_module``, then the registered
    ``convert-arith-to-llvm`` pass. IR printing is enabled and the FileCheck
    directives below expect a dump after each of the three passes.

    Second pipeline: a single pass that calls ``signal_pass_failure``; the
    exception raised by ``pm.run`` is caught and printed.
    """
    with Context():
        pdl_module = make_pdl_module()
        frozen = PDLModule(pdl_module).freeze()

        # The IR text is a runtime string and is kept exactly as written,
        # which is why it sits flush-left inside the raw literal.
        module = ModuleOp.parse(
            r"""
module {
func.func @add(%a: i64, %b: i64) -> i64 {
%sum = arith.addi %a, %b : i64
return %sum : i64
}
}
"""
        )

        # Pass written as a plain function; it only announces itself on stderr.
        def custom_pass_1(op, pass_):
            print("hello from pass 1!!!", file=sys.stderr)

        # Pass written as a callable object; it applies the frozen PDL
        # patterns greedily to the op it is handed.
        class CustomPass2:
            def __call__(self, op, pass_):
                apply_patterns_and_fold_greedily(op, frozen)

        custom_pass_2 = CustomPass2()

        pm = PassManager("any")
        pm.enable_ir_printing()

        # CHECK: hello from pass 1!!!
        # CHECK-LABEL: Dump After custom_pass_1
        pm.add(custom_pass_1)
        # CHECK-LABEL: Dump After CustomPass2
        # CHECK: arith.muli
        pm.add(custom_pass_2, "CustomPass2")
        # CHECK-LABEL: Dump After ArithToLLVMConversionPass
        # CHECK: llvm.mul
        pm.add("convert-arith-to-llvm")
        pm.run(module)

        # test signal_pass_failure
        # NOTE(review): the two prints below have no file= argument, so they
        # go to stdout, unlike custom_pass_1 which writes to stderr.
        def custom_pass_that_fails(op, pass_):
            print("hello from pass that fails")
            pass_.signal_pass_failure()

        pm = PassManager("any")
        pm.add(custom_pass_that_fails, "CustomPassThatFails")
        # CHECK: hello from pass that fails
        # CHECK: caught exception: Failure while executing pass pipeline
        try:
            pm.run(module)
        except Exception as e:
            # Deliberately broad: the test only checks the printed message.
            print(f"caught exception: {e}")