blob: d1f3e6b4a61ad6d56577144640b2791f62492e6f [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
import gc
import mlir
def run(f):
print("\nTEST:", f.__name__)
f()
gc.collect()
assert mlir.ir.Context._get_live_count() == 0
# CHECK-LABEL: TEST: testParsePrint
def testParsePrint():
ctx = mlir.ir.Context()
t = ctx.parse_attr('"hello"')
assert t.context is ctx
ctx = None
gc.collect()
# CHECK: "hello"
print(str(t))
# CHECK: Attribute("hello")
print(repr(t))
run(testParsePrint)
# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
def testParseError():
ctx = mlir.ir.Context()
try:
t = ctx.parse_attr("BAD_ATTR_DOES_NOT_EXIST")
except ValueError as e:
# CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
print("testParseError:", e)
else:
print("Exception not produced")
run(testParseError)
# CHECK-LABEL: TEST: testAttrEq
def testAttrEq():
ctx = mlir.ir.Context()
a1 = ctx.parse_attr('"attr1"')
a2 = ctx.parse_attr('"attr2"')
a3 = ctx.parse_attr('"attr1"')
# CHECK: a1 == a1: True
print("a1 == a1:", a1 == a1)
# CHECK: a1 == a2: False
print("a1 == a2:", a1 == a2)
# CHECK: a1 == a3: True
print("a1 == a3:", a1 == a3)
# CHECK: a1 == None: False
print("a1 == None:", a1 == None)
run(testAttrEq)
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
def testAttrEqDoesNotRaise():
ctx = mlir.ir.Context()
a1 = ctx.parse_attr('"attr1"')
not_an_attr = "foo"
# CHECK: False
print(a1 == not_an_attr)
# CHECK: False
print(a1 == None)
# CHECK: True
print(a1 != None)
run(testAttrEqDoesNotRaise)
# CHECK-LABEL: TEST: testStandardAttrCasts
def testStandardAttrCasts():
ctx = mlir.ir.Context()
a1 = ctx.parse_attr('"attr1"')
astr = mlir.ir.StringAttr(a1)
aself = mlir.ir.StringAttr(astr)
# CHECK: Attribute("attr1")
print(repr(astr))
try:
tillegal = mlir.ir.StringAttr(ctx.parse_attr("1.0"))
except ValueError as e:
# CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
print("ValueError:", e)
else:
print("Exception not produced")
run(testStandardAttrCasts)
# CHECK-LABEL: TEST: testFloatAttr
def testFloatAttr():
ctx = mlir.ir.Context()
fattr = mlir.ir.FloatAttr(ctx.parse_attr("42.0 : f32"))
# CHECK: fattr value: 42.0
print("fattr value:", fattr.value)
# Test factory methods.
loc = ctx.get_unknown_location()
# CHECK: default_get: 4.200000e+01 : f32
print("default_get:", mlir.ir.FloatAttr.get(
mlir.ir.F32Type.get(ctx), 42.0, loc))
# CHECK: f32_get: 4.200000e+01 : f32
print("f32_get:", mlir.ir.FloatAttr.get_f32(ctx, 42.0))
# CHECK: f64_get: 4.200000e+01 : f64
print("f64_get:", mlir.ir.FloatAttr.get_f64(ctx, 42.0))
try:
fattr_invalid = mlir.ir.FloatAttr.get(
mlir.ir.IntegerType.get_signless(ctx, 32), 42, loc)
except ValueError as e:
# CHECK: invalid 'Type(i32)' and expected floating point type.
print(e)
else:
print("Exception not produced")
run(testFloatAttr)
# CHECK-LABEL: TEST: testIntegerAttr
def testIntegerAttr():
ctx = mlir.ir.Context()
iattr = mlir.ir.IntegerAttr(ctx.parse_attr("42"))
# CHECK: iattr value: 42
print("iattr value:", iattr.value)
# CHECK: iattr type: i64
print("iattr type:", iattr.type)
# Test factory methods.
# CHECK: default_get: 42 : i32
print("default_get:", mlir.ir.IntegerAttr.get(
mlir.ir.IntegerType.get_signless(ctx, 32), 42))
run(testIntegerAttr)
# CHECK-LABEL: TEST: testBoolAttr
def testBoolAttr():
ctx = mlir.ir.Context()
battr = mlir.ir.BoolAttr(ctx.parse_attr("true"))
# CHECK: iattr value: 1
print("iattr value:", battr.value)
# Test factory methods.
# CHECK: default_get: true
print("default_get:", mlir.ir.BoolAttr.get(ctx, True))
run(testBoolAttr)
# CHECK-LABEL: TEST: testStringAttr
def testStringAttr():
ctx = mlir.ir.Context()
sattr = mlir.ir.StringAttr(ctx.parse_attr('"stringattr"'))
# CHECK: sattr value: stringattr
print("sattr value:", sattr.value)
# Test factory methods.
# CHECK: default_get: "foobar"
print("default_get:", mlir.ir.StringAttr.get(ctx, "foobar"))
# CHECK: typed_get: "12345" : i32
print("typed_get:", mlir.ir.StringAttr.get_typed(
mlir.ir.IntegerType.get_signless(ctx, 32), "12345"))
run(testStringAttr)
# CHECK-LABEL: TEST: testNamedAttr
def testNamedAttr():
ctx = mlir.ir.Context()
a = ctx.parse_attr('"stringattr"')
named = a.get_named("foobar") # Note: under the small object threshold
# CHECK: attr: "stringattr"
print("attr:", named.attr)
# CHECK: name: foobar
print("name:", named.name)
# CHECK: named: NamedAttribute(foobar="stringattr")
print("named:", named)
run(testNamedAttr)