| # RUN: %PYTHON %s | FileCheck %s |
| |
| from mlir import ir |
| from mlir.dialects.transform import interpreter as interp |
| |
| |
def test_in_context(f):
    """Run ``f`` once inside a fresh MLIR context and hand it back.

    Intended for use as a decorator: the decorated function is invoked
    immediately, with a new ``ir.Context`` and an unknown ``ir.Location``
    active, and is then returned unchanged.
    """
    with ir.Context():
        # The location is created only after the context has been entered,
        # matching the left-to-right evaluation of a combined ``with``.
        with ir.Location.unknown():
            f()
    return f
| |
| |
# Transform script shared by several tests below: its @__transform_main entry
# point prints the op it receives as %root. Tests substitute the
# "from interpreter" label with their own name via str.replace before parsing.
print_root_module = """
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root: !transform.any_op) {
transform.print %root { name = \"from interpreter\" }: !transform.any_op
transform.yield
}
}"""
| |
| |
@test_in_context
def print_self():
    """Apply the printing sequence to the very module that defines it."""
    script = print_root_module.replace("from interpreter", "print_self")
    module = ir.Module.parse(script)
    # The same module serves as payload and as transform module; its first
    # top-level operation is used as the transform root.
    entry_point = module.body.operations[0]
    interp.apply_named_sequence(module, entry_point, module)
| |
| |
| # CHECK-LABEL: print_self |
| # CHECK: transform.named_sequence @__transform_main |
| # CHECK: transform.print |
| # CHECK: transform.yield |
| |
| |
@test_in_context
def print_other():
    """Apply the printing sequence to a separate payload module."""
    script = print_root_module.replace("from interpreter", "print_other")
    transform_module = ir.Module.parse(script)
    payload_module = ir.Module.parse("module attributes { this.is.payload } {}")
    # The first top-level operation of the transform module is the root.
    entry_point = transform_module.body.operations[0]
    interp.apply_named_sequence(payload_module, entry_point, transform_module)
| |
| |
| # CHECK-LABEL: print_other |
| # CHECK-NOT: transform |
| # CHECK: this.is.payload |
| |
| |
@test_in_context
def transform_options():
    """Apply the printing sequence while passing explicit TransformOptions."""
    opts = interp.TransformOptions()
    opts.expensive_checks = False
    opts.enforce_single_top_level_transform_op = True

    script = print_root_module.replace("from interpreter", "transform_options")
    transform_module = ir.Module.parse(script)
    payload_module = ir.Module.parse("module attributes { this.is.payload } {}")

    # The options object is supplied as the fourth positional argument.
    entry_point = transform_module.body.operations[0]
    interp.apply_named_sequence(payload_module, entry_point, transform_module, opts)
| |
| |
| # CHECK-LABEL: transform_options |
| |
| |
@test_in_context
def failed():
    """Check that a non-transform op is rejected as the transform root.

    The payload module is passed as payload, transform root and transform
    module all at once. The call is expected to raise ``ValueError`` whose
    message mentions the missing TransformOpInterface.

    Raises:
        AssertionError: if the error message differs, or if no error is
            raised at all.
    """
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    try:
        interp.apply_named_sequence(payload, payload, payload)
    except ValueError as e:
        assert (
            "must implement TransformOpInterface to be used as transform root" in str(e)
        )
    else:
        # Without this branch the test would pass silently when the call
        # unexpectedly succeeds.
        raise AssertionError(
            "apply_named_sequence did not reject a root that is not a transform op"
        )
| |
| |
# Main transform module for the include tests: @__transform_main includes
# @callee2. Both @callee1 and @callee2 appear here only as private
# declarations without bodies; the tests bring in their definitions with
# interp.copy_symbols_and_merge_into.
print_root_via_include_module = """
module @print_root_via_include_module attributes {transform.with_named_sequence} {
transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
transform.named_sequence @__transform_main(%root: !transform.any_op) {
transform.include @callee2 failures(propagate)
(%root) : (!transform.any_op) -> ()
transform.yield
}
}"""
| |
# Module that defines @callee2, whose body includes @callee1. @callee1 itself
# is only declared here (private, no body).
callee2_definition = """
module attributes {transform.with_named_sequence} {
transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
transform.include @callee1 failures(propagate)
(%root) : (!transform.any_op) -> ()
transform.yield
}
}
"""
| |
# Module that defines @callee1, whose body prints the op it receives as %root.
callee1_definition = """
module attributes {transform.with_named_sequence} {
transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
transform.print %root { name = \"from interpreter\" }: !transform.any_op
transform.yield
}
}
"""
| |
| |
@test_in_context
def include():
    """Merge both callee definitions into the main module, then apply it."""
    main_module = ir.Module.parse(print_root_via_include_module)
    # Parse both definition modules up front and keep them referenced until
    # the sequence has been applied.
    definitions = [
        ir.Module.parse(text) for text in (callee1_definition, callee2_definition)
    ]
    for definition in definitions:
        interp.copy_symbols_and_merge_into(main_module, definition)

    # CHECK: @print_root_via_include_module
    # CHECK: transform.named_sequence @__transform_main
    # CHECK: transform.include @callee2
    #
    # CHECK: transform.named_sequence @callee1
    # CHECK: transform.print
    #
    # CHECK: transform.named_sequence @callee2
    # CHECK: transform.include @callee1
    entry_point = main_module.body.operations[0]
    interp.apply_named_sequence(main_module, entry_point, main_module)
| |
| |
@test_in_context
def partial_include():
    """Check that applying fails when only @callee2 has been merged in.

    Only ``callee2_definition`` is merged into the main module, so @callee1,
    which @callee2 includes, is left as a body-less declaration. Applying the
    sequence is expected to raise ``ValueError`` containing "Failed to apply".

    Raises:
        AssertionError: if the error message differs, or if no error is
            raised at all.
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    try:
        interp.apply_named_sequence(main, main.body.operations[0], main)
    except ValueError as e:
        assert "Failed to apply" in str(e)
    else:
        # Without this branch the test would pass silently when the call
        # unexpectedly succeeds.
        raise AssertionError(
            "apply_named_sequence succeeded although @callee1 was never defined"
        )
| |
| |
@test_in_context
def repeated_include():
    """Check that merging the same definition twice is reported as an error.

    ``callee2_definition`` is merged into the main module once, and a second
    merge of the same module is then attempted. The second merge is expected
    to raise ``ValueError`` mentioning the doubly defined symbol @callee2.

    Raises:
        AssertionError: if the error message differs, or if no error is
            raised at all.
    """
    main = ir.Module.parse(print_root_via_include_module)
    callee2 = ir.Module.parse(callee2_definition)
    interp.copy_symbols_and_merge_into(main, callee2)

    try:
        interp.copy_symbols_and_merge_into(main, callee2)
    except ValueError as e:
        assert "doubly defined symbol @callee2" in str(e)
    else:
        # Without this branch the test would pass silently when the second
        # merge unexpectedly succeeds.
        raise AssertionError(
            "copy_symbols_and_merge_into accepted a second definition of @callee2"
        )