# blob: bbdd670e6aa769e25a8c44355ac6bc1d8d29a855 [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
import gc
import io
import itertools
from mlir.ir import *
def run(f):
    """Execute one test function and verify that no Context outlives it.

    A blank line followed by ``TEST: <function name>`` is printed before
    ``f`` is invoked. After ``f`` returns, a garbage collection pass is
    forced and the number of live ``Context`` objects is asserted to be zero.

    Args:
      f: Zero-argument callable to execute; its ``__name__`` is printed.

    Raises:
      AssertionError: If a ``Context`` is still alive after collection.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Force a collection so that anything only reachable through reference
    # cycles is released before the live-context count is inspected.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
# CHECK-LABEL: TEST: test_insert_at_block_end
def test_insert_at_block_end():
    """Insert through an InsertionPoint constructed directly from a block.

    The parsed function body contains only "custom.op1". After inserting
    "custom.op2", the printed module is expected to show "custom.op1"
    followed by "custom.op2".
    """
    context = Context()
    # Allows ops from unregistered dialects; presumably required for the
    # "custom.*" ops used below.
    context.allow_unregistered_dialects = True
    unknown_loc = context.get_unknown_location()
    parsed = context.parse_module(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
}
""")
    # First op of the module body is the function; take its first block.
    func_op = parsed.body.operations[0]
    body_block = func_op.regions[0].blocks[0]
    point = InsertionPoint(body_block)
    point.insert(context.create_operation("custom.op2", unknown_loc))
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    parsed.operation.print()


run(test_insert_at_block_end)
# CHECK-LABEL: TEST: test_insert_before_operation
def test_insert_before_operation():
    """Insert through an InsertionPoint constructed from an operation.

    The insertion point is built from the second op of the block
    ("custom.op2"). The printed module is expected to show "custom.op1",
    then the inserted "custom.op3", then "custom.op2".
    """
    context = Context()
    # Allows ops from unregistered dialects; presumably required for the
    # "custom.*" ops used below.
    context.allow_unregistered_dialects = True
    unknown_loc = context.get_unknown_location()
    parsed = context.parse_module(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
"custom.op2"() : () -> ()
}
""")
    # First op of the module body is the function; take its first block.
    func_op = parsed.body.operations[0]
    body_block = func_op.regions[0].blocks[0]
    anchor_op = body_block.operations[1]
    point = InsertionPoint(anchor_op)
    point.insert(context.create_operation("custom.op3", unknown_loc))
    # CHECK: "custom.op1"
    # CHECK: "custom.op3"
    # CHECK: "custom.op2"
    parsed.operation.print()


run(test_insert_before_operation)
# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
    """Insert through ``InsertionPoint.at_block_begin``.

    The parsed function body contains only "custom.op2". After inserting
    "custom.op1", the printed module is expected to show "custom.op1"
    followed by "custom.op2".
    """
    context = Context()
    # Allows ops from unregistered dialects; presumably required for the
    # "custom.*" ops used below.
    context.allow_unregistered_dialects = True
    unknown_loc = context.get_unknown_location()
    parsed = context.parse_module(r"""
func @foo() -> () {
"custom.op2"() : () -> ()
}
""")
    # First op of the module body is the function; take its first block.
    func_op = parsed.body.operations[0]
    body_block = func_op.regions[0].blocks[0]
    point = InsertionPoint.at_block_begin(body_block)
    point.insert(context.create_operation("custom.op1", unknown_loc))
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    parsed.operation.print()


run(test_insert_at_block_begin)
# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
def test_insert_at_block_begin_empty():
    """Placeholder test; currently performs no work.

    Going by its name, this is meant to cover inserting at the beginning of
    an empty block once such a block can be constructed.
    """
    # TODO: Write this test case when we can create such a situation.


run(test_insert_at_block_begin_empty)
# CHECK-LABEL: TEST: test_insert_at_terminator
def test_insert_at_terminator():
    """Insert through ``InsertionPoint.at_block_terminator``.

    The parsed function body contains "custom.op1" followed by a ``return``
    terminator. After inserting "custom.op2", the printed module is expected
    to show "custom.op1", then "custom.op2", and only then ``return``. The
    final expectation on ``return`` is what distinguishes an insertion before
    the terminator from an insertion at the very end of the block.
    """
    ctx = Context()
    # Allows ops from unregistered dialects; presumably required for the
    # "custom.*" ops used below.
    ctx.allow_unregistered_dialects = True
    loc = ctx.get_unknown_location()
    module = ctx.parse_module(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
return
}
""")
    # First op of the module body is the function; take its first block.
    entry_block = module.body.operations[0].regions[0].blocks[0]
    ip = InsertionPoint.at_block_terminator(entry_block)
    ip.insert(ctx.create_operation("custom.op2", loc))
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    # CHECK: return
    module.operation.print()


run(test_insert_at_terminator)
# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
def test_insert_at_block_terminator_missing():
    """Request a terminator insertion point on a block without a terminator.

    The parsed function body contains only "custom.op1". Calling
    ``InsertionPoint.at_block_terminator`` on it is expected to raise
    ``ValueError``; the exception text is printed so the expected message
    can be matched against the output.

    Raises:
      AssertionError: If no ``ValueError`` is raised.
    """
    ctx = Context()
    # Allows ops from unregistered dialects; presumably required for the
    # "custom.*" op used below.
    ctx.allow_unregistered_dialects = True
    module = ctx.parse_module(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
}
""")
    # First op of the module body is the function; take its first block.
    entry_block = module.body.operations[0].regions[0].blocks[0]
    try:
        # Only the raised exception matters; the result is not needed.
        InsertionPoint.at_block_terminator(entry_block)
    except ValueError as e:
        # CHECK: Block has no terminator
        print(e)
    else:
        # Raised explicitly instead of `assert False`, which would be
        # stripped when Python runs with -O and let the test pass silently.
        raise AssertionError("Expected exception")


run(test_insert_at_block_terminator_missing)
# CHECK-LABEL: TEST: test_insertion_point_context
def test_insertion_point_context():
    """Use InsertionPoint objects as nested context managers.

    Operations are created without an explicit ``insert`` call while an
    InsertionPoint is active. The outer ``with`` targets the block itself and
    the inner one targets the block beginning. The printed module is expected
    to show "custom.opa", "custom.opb", "custom.op1", "custom.op2" and
    "custom.op3", in that order.
    """
    context = Context()
    # Allows ops from unregistered dialects; presumably required for the
    # "custom.*" ops used below.
    context.allow_unregistered_dialects = True
    unknown_loc = context.get_unknown_location()
    parsed = context.parse_module(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
}
""")
    # First op of the module body is the function; take its first block.
    func_op = parsed.body.operations[0]
    body_block = func_op.regions[0].blocks[0]
    with InsertionPoint(body_block):
        context.create_operation("custom.op2", unknown_loc)
        with InsertionPoint.at_block_begin(body_block):
            context.create_operation("custom.opa", unknown_loc)
            context.create_operation("custom.opb", unknown_loc)
        # Created after the inner `with` has exited, while the outer
        # insertion point is still active.
        context.create_operation("custom.op3", unknown_loc)
    # CHECK: "custom.opa"
    # CHECK: "custom.opb"
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    # CHECK: "custom.op3"
    parsed.operation.print()


run(test_insertion_point_context)