blob: 7402c644a1c186f977e5c18e47c6892ff80088fb [file] [log] [blame]
# RUN: %PYTHON %s | FileCheck %s
import gc
from mlir.ir import *
def run(f):
  """Run one test case and verify that it leaked no MLIR contexts.

  Prints a "TEST: <name>" header (the anchor each test's label directive
  matches), calls ``f``, forces a garbage collection, and then requires that
  ``Context._get_live_count()`` is zero.

  Args:
    f: Zero-argument test function.

  Raises:
    AssertionError: if any Context object is still alive after ``f`` returns.
  """
  print("\nTEST:", f.__name__)
  f()
  # Collect reference cycles so that contexts reachable only from garbage are
  # released before they are counted.
  gc.collect()
  # An explicit raise (rather than a bare `assert`) keeps this leak check
  # active even when Python runs with -O, and names the offending test.
  live_count = Context._get_live_count()
  if live_count != 0:
    raise AssertionError(
        "{}: {} Context object(s) still alive after the test returned".format(
            f.__name__, live_count))
# CHECK-LABEL: TEST: testParsePrint
def testParsePrint():
  """A parsed type survives dropping its context and prints as text/repr."""
  context = Context()
  i32_type = Type.parse("i32", context)
  assert i32_type.context is context
  # Drop the only direct reference to the context; the parsed type is
  # expected to keep it alive across the collection below.
  del context
  gc.collect()
  # CHECK: i32
  print(str(i32_type))
  # CHECK: Type(i32)
  print(repr(i32_type))


run(testParsePrint)
# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
def testParseError():
  """Parsing an unknown type name raises ValueError naming the bad input."""
  context = Context()
  try:
    Type.parse("BAD_TYPE_DOES_NOT_EXIST", context)
  except ValueError as err:
    # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST'
    print("testParseError:", err)
  else:
    print("Exception not produced")


run(testParseError)
# CHECK-LABEL: TEST: testTypeEq
def testTypeEq():
  """Types parsed from identical text compare equal; None never does."""
  context = Context()
  i32_a = Type.parse("i32", context)
  f32 = Type.parse("f32", context)
  i32_b = Type.parse("i32", context)
  # Every comparison uses i32_a on the left-hand side; the labels keep the
  # historical t1/t2/t3 naming.
  # CHECK: t1 == t1: True
  # CHECK: t1 == t2: False
  # CHECK: t1 == t3: True
  # CHECK: t1 == None: False
  comparisons = (
      ("t1 == t1:", i32_a),
      ("t1 == t2:", f32),
      ("t1 == t3:", i32_b),
      ("t1 == None:", None),
  )
  for label, other in comparisons:
    print(label, i32_a == other)


run(testTypeEq)
# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
def testTypeEqDoesNotRaise():
  """Comparing a Type with non-Type objects yields a bool instead of raising."""
  context = Context()
  i32 = Type.parse("i32", context)
  not_a_type = "foo"
  # The comparisons are wrapped in lambdas so each one is evaluated only
  # right before its result is printed.
  # CHECK: False
  # CHECK: False
  # CHECK: True
  comparisons = (
      lambda: i32 == not_a_type,
      lambda: i32 == None,
      lambda: i32 != None,
  )
  for compare in comparisons:
    print(compare())


run(testTypeEqDoesNotRaise)
# CHECK-LABEL: TEST: testTypeCapsule
def testTypeCapsule():
  """A Type round-trips through its C-API capsule and keeps the same context."""
  with Context() as context:
    original = Type.parse("i32", context)
  capsule = original._CAPIPtr
  # CHECK: mlir.ir.Type._CAPIPtr
  print(capsule)
  recreated = Type._CAPICreate(capsule)
  assert recreated == original
  assert recreated.context is context


run(testTypeCapsule)
# CHECK-LABEL: TEST: testStandardTypeCasts
def testStandardTypeCasts():
  """IntegerType accepts integer types (including itself) and rejects f32."""
  context = Context()
  as_integer = IntegerType(Type.parse("i32", context))
  # Re-wrapping an IntegerType in IntegerType must also succeed.
  IntegerType(as_integer)
  # CHECK: Type(i32)
  print(repr(as_integer))
  try:
    IntegerType(Type.parse("f32", context))
  except ValueError as err:
    # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
    print("ValueError:", err)
  else:
    print("Exception not produced")


run(testStandardTypeCasts)
# CHECK-LABEL: TEST: testIntegerType
def testIntegerType():
  """Width/signedness queries on IntegerType plus the get_* constructors."""

  def show_signedness(label, int_type):
    # Print the three signedness predicates, each prefixed by `label`.
    print(label + " signless:", int_type.is_signless)
    print(label + " signed:", int_type.is_signed)
    print(label + " unsigned:", int_type.is_unsigned)

  with Context():
    i32 = IntegerType(Type.parse("i32"))
    # CHECK: i32 width: 32
    print("i32 width:", i32.width)
    # CHECK: i32 signless: True
    # CHECK: i32 signed: False
    # CHECK: i32 unsigned: False
    show_signedness("i32", i32)
    # CHECK: s32 signless: False
    # CHECK: s32 signed: True
    # CHECK: s32 unsigned: False
    show_signedness("s32", IntegerType(Type.parse("si32")))
    # CHECK: u32 signless: False
    # CHECK: u32 signed: False
    # CHECK: u32 unsigned: True
    show_signedness("u32", IntegerType(Type.parse("ui32")))
    # CHECK: signless: i16
    print("signless:", IntegerType.get_signless(16))
    # CHECK: signed: si8
    print("signed:", IntegerType.get_signed(8))
    # CHECK: unsigned: ui64
    print("unsigned:", IntegerType.get_unsigned(64))


run(testIntegerType)
# CHECK-LABEL: TEST: testIndexType
def testIndexType():
  """IndexType.get() builds the builtin `index` type."""
  with Context():
    index_type = IndexType.get()
    # CHECK: index type: index
    print("index type:", index_type)


run(testIndexType)
# CHECK-LABEL: TEST: testFloatType
def testFloatType():
  """Each builtin float type class builds and prints its own spelling."""
  with Context():
    # CHECK: float: bf16
    # CHECK: float: f16
    # CHECK: float: f32
    # CHECK: float: f64
    for float_class in (BF16Type, F16Type, F32Type, F64Type):
      print("float:", float_class.get())


run(testFloatType)
# CHECK-LABEL: TEST: testNoneType
def testNoneType():
  """NoneType.get() builds the builtin `none` type."""
  with Context():
    none_type = NoneType.get()
    # CHECK: none type: none
    print("none type:", none_type)


run(testNoneType)
# CHECK-LABEL: TEST: testComplexType
def testComplexType():
  """ComplexType: element access, construction, and element-type validation."""
  with Context():
    parsed = ComplexType(Type.parse("complex<i32>"))
    # CHECK: complex type element: i32
    print("complex type element:", parsed.element_type)
    # CHECK: complex type: complex<f32>
    print("complex type:", ComplexType.get(F32Type.get()))
    index = IndexType.get()
    try:
      ComplexType.get(index)
    except ValueError as err:
      # CHECK: invalid 'Type(index)' and expected floating point or integer type.
      print(err)
    else:
      print("Exception not produced")


run(testComplexType)
# CHECK-LABEL: TEST: testConcreteShapedType
# ShapedType is not a concrete builtin type; it is the base class for vectors,
# memrefs and tensors, so this test case uses a vector instance to exercise the
# ShapedType API. The class hierarchy is preserved on the Python side.
def testConcreteShapedType():
  """ShapedType queries, exercised through a concrete VectorType instance."""
  with Context():
    vector = VectorType(Type.parse("vector<2x3xf32>"))
    # CHECK: element type: f32
    # CHECK: whether the given shaped type is ranked: True
    # CHECK: rank: 2
    # CHECK: whether the shaped type has a static shape: True
    # CHECK: whether the dim-th dimension is dynamic: False
    # CHECK: dim size: 3
    # CHECK: is_dynamic_size: False
    # CHECK: is_dynamic_stride_or_offset: False
    # CHECK: isinstance(ShapedType): True
    # Each query is a lambda so it runs only right before its label is printed.
    queries = (
        ("element type:", lambda: vector.element_type),
        ("whether the given shaped type is ranked:", lambda: vector.has_rank),
        ("rank:", lambda: vector.rank),
        ("whether the shaped type has a static shape:",
         lambda: vector.has_static_shape),
        ("whether the dim-th dimension is dynamic:",
         lambda: vector.is_dynamic_dim(0)),
        ("dim size:", lambda: vector.get_dim_size(1)),
        ("is_dynamic_size:", lambda: vector.is_dynamic_size(3)),
        ("is_dynamic_stride_or_offset:",
         lambda: vector.is_dynamic_stride_or_offset(1)),
        ("isinstance(ShapedType):", lambda: isinstance(vector, ShapedType)),
    )
    for label, query in queries:
      print(label, query())


run(testConcreteShapedType)
# CHECK-LABEL: TEST: testAbstractShapedType
# Tests that ShapedType operates as an abstract base class of a concrete
# shaped type (using vector as an example).
def testAbstractShapedType():
  """ShapedType can directly wrap a concrete shaped type (here a vector)."""
  context = Context()
  shaped = ShapedType(Type.parse("vector<2x3xf32>", context))
  # CHECK: element type: f32
  print("element type:", shaped.element_type)


run(testAbstractShapedType)
# CHECK-LABEL: TEST: testVectorType
def testVectorType():
  """VectorType.get builds a vector and rejects a `none` element type."""
  with Context(), Location.unknown():
    element = F32Type.get()
    dims = [2, 3]
    # CHECK: vector type: vector<2x3xf32>
    print("vector type:", VectorType.get(dims, element))
    bad_element = NoneType.get()
    try:
      VectorType.get(dims, bad_element)
    except ValueError as err:
      # CHECK: invalid 'Type(none)' and expected floating point or integer type.
      print(err)
    else:
      print("Exception not produced")


run(testVectorType)
# CHECK-LABEL: TEST: testRankedTensorType
def testRankedTensorType():
  """RankedTensorType.get builds a tensor and rejects a `none` element type."""
  # The `with` statement already installs an unknown Location, so no separate
  # Location object is needed in the body.
  with Context(), Location.unknown():
    f32 = F32Type.get()
    shape = [2, 3]
    # CHECK: ranked tensor type: tensor<2x3xf32>
    print("ranked tensor type:", RankedTensorType.get(shape, f32))
    none = NoneType.get()
    try:
      RankedTensorType.get(shape, none)
    except ValueError as e:
      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
      # CHECK: or complex type.
      print(e)
    else:
      print("Exception not produced")


run(testRankedTensorType)
# CHECK-LABEL: TEST: testUnrankedTensorType
def testUnrankedTensorType():
  """UnrankedTensorType: printing, rank-only accessors, element validation."""
  # The `with` statement already installs an unknown Location, so no separate
  # Location object is needed in the body.
  with Context(), Location.unknown():
    f32 = F32Type.get()
    unranked_tensor = UnrankedTensorType.get(f32)
    # CHECK: unranked tensor type: tensor<*xf32>
    print("unranked tensor type:", unranked_tensor)

    # These accessors need a ranked type; each is expected to raise
    # ValueError on an unranked tensor. Lambdas defer each access into the
    # try block below.
    # CHECK: calling this method requires that the type has a rank.
    # CHECK: calling this method requires that the type has a rank.
    # CHECK: calling this method requires that the type has a rank.
    rank_only_queries = (
        lambda: unranked_tensor.rank,
        lambda: unranked_tensor.is_dynamic_dim(0),
        lambda: unranked_tensor.get_dim_size(1),
    )
    for query in rank_only_queries:
      try:
        query()
      except ValueError as e:
        print(e)
      else:
        print("Exception not produced")

    none = NoneType.get()
    try:
      UnrankedTensorType.get(none)
    except ValueError as e:
      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
      # CHECK: or complex type.
      print(e)
    else:
      print("Exception not produced")


run(testUnrankedTensorType)
# CHECK-LABEL: TEST: testMemRefType
def testMemRefType():
  """MemRefType.get: memory space, affine layout maps, element validation."""
  # The `with` statement already installs an unknown Location, so no separate
  # Location object is needed in the body.
  with Context(), Location.unknown():
    f32 = F32Type.get()
    shape = [2, 3]

    memref = MemRefType.get(shape, f32, memory_space=2)
    # CHECK: memref type: memref<2x3xf32, 2>
    print("memref type:", memref)
    # CHECK: number of affine layout maps: 0
    print("number of affine layout maps:", len(memref.layout))
    # CHECK: memory space: 2
    print("memory space:", memref.memory_space)

    layout = AffineMap.get_permutation([1, 0])
    memref_layout = MemRefType.get(shape, f32, [layout])
    # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
    print("memref type:", memref_layout)
    assert len(memref_layout.layout) == 1
    # CHECK: memref layout: (d0, d1) -> (d1, d0)
    print("memref layout:", memref_layout.layout[0])
    # CHECK: memory space: 0
    print("memory space:", memref_layout.memory_space)

    none = NoneType.get()
    try:
      MemRefType.get(shape, none)
    except ValueError as e:
      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
      # CHECK: or complex type.
      print(e)
    else:
      print("Exception not produced")


run(testMemRefType)
# CHECK-LABEL: TEST: testUnrankedMemRefType
def testUnrankedMemRefType():
  """UnrankedMemRefType: printing, rank-only accessors, element validation."""
  # The `with` statement already installs an unknown Location, so no separate
  # Location object is needed in the body.
  with Context(), Location.unknown():
    f32 = F32Type.get()
    unranked_memref = UnrankedMemRefType.get(f32, 2)
    # CHECK: unranked memref type: memref<*xf32, 2>
    print("unranked memref type:", unranked_memref)

    # These accessors need a ranked type; each is expected to raise
    # ValueError on an unranked memref. Lambdas defer each access into the
    # try block below.
    # CHECK: calling this method requires that the type has a rank.
    # CHECK: calling this method requires that the type has a rank.
    # CHECK: calling this method requires that the type has a rank.
    rank_only_queries = (
        lambda: unranked_memref.rank,
        lambda: unranked_memref.is_dynamic_dim(0),
        lambda: unranked_memref.get_dim_size(1),
    )
    for query in rank_only_queries:
      try:
        query()
      except ValueError as e:
        print(e)
      else:
        print("Exception not produced")

    none = NoneType.get()
    try:
      UnrankedMemRefType.get(none, 2)
    except ValueError as e:
      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
      # CHECK: or complex type.
      print(e)
    else:
      print("Exception not produced")


run(testUnrankedMemRefType)
# CHECK-LABEL: TEST: testTupleType
def testTupleType():
  """TupleType.get_tuple builds a tuple type and exposes its member types."""
  with Context():
    members = [
        IntegerType(Type.parse("i32")),
        F32Type.get(),
        VectorType(Type.parse("vector<2x3xf32>")),
    ]
    tuple_type = TupleType.get_tuple(members)
    # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
    print("tuple type:", tuple_type)
    # CHECK: number of types: 3
    print("number of types:", tuple_type.num_types)
    # CHECK: pos-th type in the tuple type: f32
    print("pos-th type in the tuple type:", tuple_type.get_type(1))


run(testTupleType)
# CHECK-LABEL: TEST: testFunctionType
def testFunctionType():
  """FunctionType.get builds a signature exposing its inputs and results."""
  with Context():
    signature = FunctionType.get(
        [IntegerType.get_signless(32), IntegerType.get_signless(16)],
        [IndexType.get()])
    # CHECK: INPUTS: [Type(i32), Type(i16)]
    print("INPUTS:", signature.inputs)
    # CHECK: RESULTS: [Type(index)]
    print("RESULTS:", signature.results)


run(testFunctionType)