| # RUN: %PYTHON %s | FileCheck %s |
| |
| import gc |
| from mlir.ir import * |
| |
def run(f):
  """Run test callable `f` under a banner and assert no Context is left alive.

  Prints a blank line followed by "TEST: <function name>", calls `f` with no
  arguments, forces a garbage collection and then asserts that
  `Context._get_live_count()` reports zero.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  # Collect first so that unreachable objects are gone before counting.
  gc.collect()
  live_contexts = Context._get_live_count()
  assert live_contexts == 0
| |
| |
# CHECK-LABEL: TEST: testParsePrint
def testParsePrint():
  """Parse `i32` and print it after the local Context reference is cleared.

  The type is parsed in an explicit context; the local name for that context
  is then rebound to None and a collection is forced before the `str` and
  `repr` forms of the type are printed.
  """
  context = Context()
  parsed = Type.parse("i32", context)
  assert parsed.context is context
  # Drop the local reference before printing.
  context = None
  gc.collect()
  # CHECK: i32
  print(str(parsed))
  # CHECK: Type(i32)
  print(repr(parsed))

run(testParsePrint)
| |
| |
# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
def testParseError():
  """Print the ValueError raised when parsing an unknown type string."""
  context = Context()
  try:
    unexpected = Type.parse("BAD_TYPE_DOES_NOT_EXIST", context)
  except ValueError as err:
    # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST'
    print("testParseError:", err)
  else:
    # Reached only when parsing did not raise.
    print("Exception not produced")

run(testParseError)
| |
| |
# CHECK-LABEL: TEST: testTypeEq
def testTypeEq():
  """Print `==` results between parsed types and against None."""
  context = Context()
  i32_first = Type.parse("i32", context)
  f32 = Type.parse("f32", context)
  i32_second = Type.parse("i32", context)
  # CHECK: t1 == t1: True
  print("t1 == t1:", i32_first == i32_first)
  # CHECK: t1 == t2: False
  print("t1 == t2:", i32_first == f32)
  # CHECK: t1 == t3: True
  print("t1 == t3:", i32_first == i32_second)
  # The comparison operator itself is under test here, hence `== None`
  # rather than `is None`.
  # CHECK: t1 == None: False
  print("t1 == None:", i32_first == None)

run(testTypeEq)
| |
| |
# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
def testTypeEqDoesNotRaise():
  """Print `==` / `!=` results of a Type against a str and against None."""
  context = Context()
  i32 = Type.parse("i32", context)
  non_type_operand = "foo"
  # CHECK: False
  print(i32 == non_type_operand)
  # The operators themselves are exercised, so `is None` is not used below.
  # CHECK: False
  print(i32 == None)
  # CHECK: True
  print(i32 != None)

run(testTypeEqDoesNotRaise)
| |
| |
# CHECK-LABEL: TEST: testTypeCapsule
def testTypeCapsule():
  """Round-trip a Type through its `_CAPIPtr` capsule.

  Prints the capsule object, rebuilds a Type from it with `_CAPICreate` and
  asserts that the rebuilt type equals the original and reports the same
  Context object.
  """
  with Context() as ctx:
    original = Type.parse("i32", ctx)
    # CHECK: mlir.ir.Type._CAPIPtr
    capsule = original._CAPIPtr
    print(capsule)
    rebuilt = Type._CAPICreate(capsule)
    assert rebuilt == original
    assert rebuilt.context is ctx

run(testTypeCapsule)
| |
| |
# CHECK-LABEL: TEST: testStandardTypeCasts
def testStandardTypeCasts():
  """Cast a generic Type to IntegerType and print the error for a bad cast."""
  context = Context()
  generic_i32 = Type.parse("i32", context)
  as_integer = IntegerType(generic_i32)
  # Constructing an IntegerType from an IntegerType is exercised as well.
  recast = IntegerType(as_integer)
  # CHECK: Type(i32)
  print(repr(as_integer))
  try:
    bad_cast = IntegerType(Type.parse("f32", context))
  except ValueError as err:
    # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
    print("ValueError:", err)
  else:
    # Reached only when the cast did not raise.
    print("Exception not produced")

run(testStandardTypeCasts)
| |
| |
# CHECK-LABEL: TEST: testIntegerType
def testIntegerType():
  """Print width/signedness of parsed integer types and the `get_*` builders."""
  with Context() as ctx:
    signless = IntegerType(Type.parse("i32"))
    # CHECK: i32 width: 32
    print("i32 width:", signless.width)
    # CHECK: i32 signless: True
    print("i32 signless:", signless.is_signless)
    # CHECK: i32 signed: False
    print("i32 signed:", signless.is_signed)
    # CHECK: i32 unsigned: False
    print("i32 unsigned:", signless.is_unsigned)

    signed = IntegerType(Type.parse("si32"))
    # CHECK: s32 signless: False
    print("s32 signless:", signed.is_signless)
    # CHECK: s32 signed: True
    print("s32 signed:", signed.is_signed)
    # CHECK: s32 unsigned: False
    print("s32 unsigned:", signed.is_unsigned)

    unsigned = IntegerType(Type.parse("ui32"))
    # CHECK: u32 signless: False
    print("u32 signless:", unsigned.is_signless)
    # CHECK: u32 signed: False
    print("u32 signed:", unsigned.is_signed)
    # CHECK: u32 unsigned: True
    print("u32 unsigned:", unsigned.is_unsigned)

    # Builders that take only a bit width.
    # CHECK: signless: i16
    print("signless:", IntegerType.get_signless(16))
    # CHECK: signed: si8
    print("signed:", IntegerType.get_signed(8))
    # CHECK: unsigned: ui64
    print("unsigned:", IntegerType.get_unsigned(64))

run(testIntegerType)
| |
# CHECK-LABEL: TEST: testIndexType
def testIndexType():
  """Print the type returned by `IndexType.get()`."""
  with Context() as ctx:
    index_type = IndexType.get()
    # CHECK: index type: index
    print("index type:", index_type)

run(testIndexType)
| |
# CHECK-LABEL: TEST: testFloatType
def testFloatType():
  """Print the type returned by `get()` for each float type class, in order."""
  with Context():
    # The directives below are matched in sequence against the loop's output.
    # CHECK: float: bf16
    # CHECK: float: f16
    # CHECK: float: f32
    # CHECK: float: f64
    for float_cls in (BF16Type, F16Type, F32Type, F64Type):
      print("float:", float_cls.get())

run(testFloatType)
| |
# CHECK-LABEL: TEST: testNoneType
def testNoneType():
  """Print the type returned by `NoneType.get()`."""
  with Context():
    none_type = NoneType.get()
    # CHECK: none type: none
    print("none type:", none_type)

run(testNoneType)
| |
# CHECK-LABEL: TEST: testComplexType
def testComplexType():
  """Exercise ComplexType: cast, element type, `get`, and a rejected element."""
  with Context() as ctx:
    parsed_complex = ComplexType(Type.parse("complex<i32>"))
    # CHECK: complex type element: i32
    print("complex type element:", parsed_complex.element_type)

    float_element = F32Type.get()
    # CHECK: complex type: complex<f32>
    print("complex type:", ComplexType.get(float_element))

    index_element = IndexType.get()
    try:
      rejected = ComplexType.get(index_element)
    except ValueError as err:
      # CHECK: invalid 'Type(index)' and expected floating point or integer type.
      print(err)
    else:
      # Reached only when `get` did not raise.
      print("Exception not produced")

run(testComplexType)
| |
| |
# CHECK-LABEL: TEST: testConcreteShapedType
# Shaped type is not a kind of builtin type; it is the base class for vectors,
# memrefs and tensors, so this test case uses a vector instance to test the
# shaped type. The class hierarchy is preserved on the python side.
def testConcreteShapedType():
  """Print the shaped-type accessors of a parsed `vector<2x3xf32>`."""
  with Context() as ctx:
    vec = VectorType(Type.parse("vector<2x3xf32>"))
    # CHECK: element type: f32
    print("element type:", vec.element_type)
    # CHECK: whether the given shaped type is ranked: True
    print("whether the given shaped type is ranked:", vec.has_rank)
    # CHECK: rank: 2
    print("rank:", vec.rank)
    # CHECK: whether the shaped type has a static shape: True
    print("whether the shaped type has a static shape:",
          vec.has_static_shape)
    # CHECK: whether the dim-th dimension is dynamic: False
    print("whether the dim-th dimension is dynamic:", vec.is_dynamic_dim(0))
    # CHECK: dim size: 3
    print("dim size:", vec.get_dim_size(1))
    # CHECK: is_dynamic_size: False
    print("is_dynamic_size:", vec.is_dynamic_size(3))
    # CHECK: is_dynamic_stride_or_offset: False
    print("is_dynamic_stride_or_offset:",
          vec.is_dynamic_stride_or_offset(1))
    # CHECK: isinstance(ShapedType): True
    print("isinstance(ShapedType):", isinstance(vec, ShapedType))

run(testConcreteShapedType)
| |
# CHECK-LABEL: TEST: testAbstractShapedType
# Tests that ShapedType operates as an abstract base class of a concrete
# shaped type (using vector as an example).
def testAbstractShapedType():
  """Wrap a parsed vector type in ShapedType and print its element type."""
  context = Context()
  parsed = Type.parse("vector<2x3xf32>", context)
  shaped = ShapedType(parsed)
  # CHECK: element type: f32
  print("element type:", shaped.element_type)

run(testAbstractShapedType)
| |
# CHECK-LABEL: TEST: testVectorType
def testVectorType():
  """Build `vector<2x3xf32>` and print the error for a `none` element type."""
  with Context(), Location.unknown():
    element = F32Type.get()
    dims = [2, 3]
    # CHECK: vector type: vector<2x3xf32>
    print("vector type:", VectorType.get(dims, element))

    none_element = NoneType.get()
    try:
      rejected = VectorType.get(dims, none_element)
    except ValueError as err:
      # CHECK: invalid 'Type(none)' and expected floating point or integer type.
      print(err)
    else:
      # Reached only when `get` did not raise.
      print("Exception not produced")

run(testVectorType)
| |
# CHECK-LABEL: TEST: testRankedTensorType
def testRankedTensorType():
  """Build `tensor<2x3xf32>` and print the error for a `none` element type.

  Runs inside a fresh Context with `Location.unknown()` entered via the
  `with` statement.
  """
  with Context(), Location.unknown():
    f32 = F32Type.get()
    shape = [2, 3]
    # CHECK: ranked tensor type: tensor<2x3xf32>
    print("ranked tensor type:",
          RankedTensorType.get(shape, f32))

    none = NoneType.get()
    try:
      tensor_invalid = RankedTensorType.get(shape, none)
    except ValueError as e:
      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
      # CHECK: or complex type.
      print(e)
    else:
      # Reached only when `get` did not raise.
      print("Exception not produced")

run(testRankedTensorType)
| |
# CHECK-LABEL: TEST: testUnrankedTensorType
def testUnrankedTensorType():
  """Build `tensor<*xf32>` and print the errors of rank-dependent accessors.

  After printing the type, `rank`, `is_dynamic_dim` and `get_dim_size` are
  each invoked and the resulting ValueError message is printed. Finally the
  error for constructing the type with a `none` element type is printed.
  """
  with Context(), Location.unknown():
    f32 = F32Type.get()
    unranked_tensor = UnrankedTensorType.get(f32)
    # CHECK: unranked tensor type: tensor<*xf32>
    print("unranked tensor type:", unranked_tensor)
    try:
      invalid_rank = unranked_tensor.rank
    except ValueError as e:
      # CHECK: calling this method requires that the type has a rank.
      print(e)
    else:
      print("Exception not produced")
    try:
      invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0)
    except ValueError as e:
      # CHECK: calling this method requires that the type has a rank.
      print(e)
    else:
      print("Exception not produced")
    try:
      invalid_get_dim_size = unranked_tensor.get_dim_size(1)
    except ValueError as e:
      # CHECK: calling this method requires that the type has a rank.
      print(e)
    else:
      print("Exception not produced")

    none = NoneType.get()
    try:
      tensor_invalid = UnrankedTensorType.get(none)
    except ValueError as e:
      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
      # CHECK: or complex type.
      print(e)
    else:
      print("Exception not produced")

run(testUnrankedTensorType)
| |
# CHECK-LABEL: TEST: testMemRefType
def testMemRefType():
  """Exercise MemRefType construction, its accessors, and a rejected element.

  Builds one memref with `memory_space=2` and no layout maps, one with a
  permutation layout map and the default memory space, printing the type,
  layout and memory space of each. Finally prints the error raised for a
  `none` element type.
  """
  with Context(), Location.unknown():
    f32 = F32Type.get()
    shape = [2, 3]
    memref = MemRefType.get(shape, f32, memory_space=2)
    # CHECK: memref type: memref<2x3xf32, 2>
    print("memref type:", memref)
    # CHECK: number of affine layout maps: 0
    print("number of affine layout maps:", len(memref.layout))
    # CHECK: memory space: 2
    print("memory space:", memref.memory_space)

    layout = AffineMap.get_permutation([1, 0])
    memref_layout = MemRefType.get(shape, f32, [layout])
    # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
    print("memref type:", memref_layout)
    assert len(memref_layout.layout) == 1
    # CHECK: memref layout: (d0, d1) -> (d1, d0)
    print("memref layout:", memref_layout.layout[0])
    # CHECK: memory space: 0
    print("memory space:", memref_layout.memory_space)

    none = NoneType.get()
    try:
      memref_invalid = MemRefType.get(shape, none)
    except ValueError as e:
      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
      # CHECK: or complex type.
      print(e)
    else:
      # Reached only when `get` did not raise.
      print("Exception not produced")

run(testMemRefType)
| |
# CHECK-LABEL: TEST: testUnrankedMemRefType
def testUnrankedMemRefType():
  """Build `memref<*xf32, 2>` and print the errors of rank-dependent accessors.

  After printing the type, `rank`, `is_dynamic_dim` and `get_dim_size` are
  each invoked and the resulting ValueError message is printed. Finally the
  error for constructing the type with a `none` element type is printed.
  """
  with Context(), Location.unknown():
    f32 = F32Type.get()
    unranked_memref = UnrankedMemRefType.get(f32, 2)
    # CHECK: unranked memref type: memref<*xf32, 2>
    print("unranked memref type:", unranked_memref)
    try:
      invalid_rank = unranked_memref.rank
    except ValueError as e:
      # CHECK: calling this method requires that the type has a rank.
      print(e)
    else:
      print("Exception not produced")
    try:
      invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0)
    except ValueError as e:
      # CHECK: calling this method requires that the type has a rank.
      print(e)
    else:
      print("Exception not produced")
    try:
      invalid_get_dim_size = unranked_memref.get_dim_size(1)
    except ValueError as e:
      # CHECK: calling this method requires that the type has a rank.
      print(e)
    else:
      print("Exception not produced")

    none = NoneType.get()
    try:
      memref_invalid = UnrankedMemRefType.get(none, 2)
    except ValueError as e:
      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
      # CHECK: or complex type.
      print(e)
    else:
      print("Exception not produced")

run(testUnrankedMemRefType)
| |
# CHECK-LABEL: TEST: testTupleType
def testTupleType():
  """Build a tuple type from three member types and print its accessors."""
  with Context() as ctx:
    int_member = IntegerType(Type.parse("i32"))
    float_member = F32Type.get()
    vector_member = VectorType(Type.parse("vector<2x3xf32>"))
    members = [int_member, float_member, vector_member]
    tuple_type = TupleType.get_tuple(members)
    # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
    print("tuple type:", tuple_type)
    # CHECK: number of types: 3
    print("number of types:", tuple_type.num_types)
    # CHECK: pos-th type in the tuple type: f32
    print("pos-th type in the tuple type:", tuple_type.get_type(1))

run(testTupleType)
| |
| |
# CHECK-LABEL: TEST: testFunctionType
def testFunctionType():
  """Build a function type from input/result lists and print both back."""
  with Context() as ctx:
    params = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
    returns = [IndexType.get()]
    signature = FunctionType.get(params, returns)
    # CHECK: INPUTS: [Type(i32), Type(i16)]
    print("INPUTS:", signature.inputs)
    # CHECK: RESULTS: [Type(index)]
    print("RESULTS:", signature.results)


run(testFunctionType)