| # RUN: %PYTHON %s | FileCheck %s |
| |
import mlir.dialects.arith as arith
import mlir.dialects.memref as memref
import mlir.extras.types as T
from mlir.dialects.memref import _infer_memref_subview_result_type
from mlir.ir import *
import numpy as np
| |
| |
def run(f):
    """Announce ``f`` by name, execute it immediately, and return it unchanged.

    Intended for use as a decorator: each test body runs at import time,
    preceded by a ``TEST: <name>`` line that the label directives anchor on.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    return f
| |
| |
| # CHECK-LABEL: TEST: testSubViewAccessors |
@run
def testSubViewAccessors():
    """Exercise the operand-group accessors of a parsed ``memref.subview``.

    The subview draws its offsets, sizes and strides from six index
    constants (0 through 5), so walking offsets, sizes, then strides should
    print those constants in ascending order.
    """
    ctx = Context()
    module = Module.parse(
        r"""
  func.func @f1(%arg0: memref<?x?xf32>) {
    %0 = arith.constant 0 : index
    %1 = arith.constant 1 : index
    %2 = arith.constant 2 : index
    %3 = arith.constant 3 : index
    %4 = arith.constant 4 : index
    %5 = arith.constant 5 : index
    memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
    return
  }
""",
        ctx,
    )
    entry_block = module.body.operations[0].regions[0].blocks[0]
    # Ops 0-5 are the index constants; the subview is the seventh op.
    subview = entry_block.operations[6]

    assert subview.source == subview.operands[0]
    for operand_group in (subview.offsets, subview.sizes, subview.strides):
        assert len(operand_group) == 2
    assert subview.result == subview.results[0]

    # CHECK: SubViewOp
    print(type(subview).__name__)

    # CHECK: constant 0
    # CHECK: constant 1
    # CHECK: constant 2
    # CHECK: constant 3
    # CHECK: constant 4
    # CHECK: constant 5
    for operand in [*subview.offsets, *subview.sizes, *subview.strides]:
        print(operand)
| |
| |
| # CHECK-LABEL: TEST: testCustomBuidlers |
@run
def testCustomBuidlers():
    """Build a ``memref.load`` through the Python ``LoadOp`` builder.

    The load reads ``%arg0`` at the indices ``(%arg1, %arg2)`` and is inserted
    right before the function's terminator; the module must still verify.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.parse(
            r"""
    func.func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
      return
    }
  """
        )
        func_op = module.body.operations[0]
        entry_block = func_op.regions[0].blocks[0]
        func_args = func_op.arguments
        load_source, load_indices = func_args[0], func_args[1:]
        with InsertionPoint.at_block_terminator(entry_block):
            memref.LoadOp(load_source, load_indices)

        # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
        # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
        print(module)
        assert module.operation.verify()
| |
| |
| # CHECK-LABEL: TEST: testMemRefAttr |
@run
def testMemRefAttr():
    """Emit a ``memref.global`` into the module body via ``memref.global_``."""
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        global_type = T.memref(16, T.i32())
        with InsertionPoint(module.body):
            memref.global_("objFifo_in0", global_type)
        # CHECK: memref.global @objFifo_in0 : memref<16xi32>
        print(module)
| |
| |
| # CHECK-LABEL: TEST: testSubViewOpInferReturnTypeSemantics |
@run
def testSubViewOpInferReturnTypeSemantics():
    """Exercise result-type inference of the ``memref.subview`` builder.

    The cases below cover:

    * Static offsets/sizes/strides, given as Python ints or ``arith.constant``
      values. These infer a fully static strided layout.
    * A non-constant offset (an ``arith.addi`` result). This infers a dynamic
      (``?``) offset.
    * A dynamic size. Inference fails with ``ValueError``, both in the
      private inference helper and in ``memref.subview`` itself.
    * A source whose layout has a dynamic offset. Here ``result_type`` is
      supplied explicitly.

    NOTE: the expected output embeds SSA names such as ``%0`` and ``%c42``.
    These depend on the order in which ops are inserted into the module, so
    the statement order in this function is significant.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            x = memref.alloc(T.memref(10, 10, T.i32()), [], [])
            # CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<10x10xi32>
            print(x.owner)

            # All-static operands on a contiguous 10x10 source:
            # offset = 1 * 10 + 1 = 11, strides stay [10, 1].
            y = memref.subview(x, [1, 1], [3, 3], [1, 1])
            assert y.owner.verify()
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(y.owner)

            # An arith.constant offset is folded into the static layout just
            # like a Python int.
            z = memref.subview(
                x,
                [arith.constant(T.index(), 1), 1],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 11>>
            print(z.owner)

            z = memref.subview(
                x,
                [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                [3, 3],
                [1, 1],
            )
            # CHECK: %{{.*}} = memref.subview %[[ALLOC]][3, 4] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: 34>>
            print(z.owner)

            # A computed (non-constant) offset can only yield a dynamic offset.
            s = arith.addi(arith.constant(T.index(), 3), arith.constant(T.index(), 4))
            z = memref.subview(
                x,
                [s, 0],
                [3, 3],
                [1, 1],
            )
            # CHECK: {{.*}} = memref.subview %[[ALLOC]][%0, 0] [3, 3] [1, 1] : memref<10x10xi32> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(z)

            # A dynamic size cannot be inferred; the helper reports why.
            try:
                _infer_memref_subview_result_type(
                    x.type,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as e:
                # CHECK: Only inferring from python or mlir integer constant is supported
                print(e)

            # The public builder surfaces the same situation with its own message.
            try:
                memref.subview(
                    x,
                    [arith.constant(T.index(), 3), arith.constant(T.index(), 4)],
                    [ShapedType.get_dynamic_size(), 3],
                    [1, 1],
                )
            except ValueError as e:
                # CHECK: mixed static/dynamic offset/sizes/strides requires explicit result type
                print(e)

            # The layout's dynamic *offset* uses the stride/offset sentinel,
            # not the dynamic-dimension (size) sentinel.
            layout = StridedLayoutAttr.get(
                ShapedType.get_dynamic_stride_or_offset(), [10, 1]
            )
            x = memref.alloc(
                T.memref(
                    10,
                    10,
                    T.i32(),
                    layout=layout,
                ),
                [],
                [arith.constant(T.index(), 42)],
            )
            # CHECK: %[[DYNAMICALLOC:.*]] = memref.alloc()[%c42] : memref<10x10xi32, strided<[10, 1], offset: ?>>
            print(x.owner)
            y = memref.subview(
                x,
                [1, 1],
                [3, 3],
                [1, 1],
                result_type=T.memref(3, 3, T.i32(), layout=layout),
            )
            # CHECK: %{{.*}} = memref.subview %[[DYNAMICALLOC]][1, 1] [3, 3] [1, 1] : memref<10x10xi32, strided<[10, 1], offset: ?>> to memref<3x3xi32, strided<[10, 1], offset: ?>>
            print(y.owner)
| |
| |
| # CHECK-LABEL: TEST: testSubViewOpInferReturnTypeExtensiveSlicing |
@run
def testSubViewOpInferReturnTypeExtensiveSlicing():
    """Cross-check inferred subview layouts against numpy slicing.

    For many static subviews, the inferred result layout's strides and
    offset (in elements) must equal the strides and offset numpy reports for
    the equivalent slice of a C-ordered int32 array of the same shape.
    """

    def check_strides_offset(view, np_view):
        """Assert that ``view``'s strided layout matches numpy view ``np_view``.

        numpy reports strides, and the data-pointer distance from the base
        array, in bytes. Both are divided by the item size to get element
        counts before comparing.
        """
        layout = view.type.layout
        itemsize = np_view.dtype.itemsize
        golden_strides = (np.array(np_view.strides) // itemsize).tolist()
        golden_offset = (np_view.ctypes.data - np_view.base.ctypes.data) // itemsize

        actual = (layout.strides, layout.offset)
        expected = (golden_strides, golden_offset)
        assert actual == expected, f"layout {actual} != numpy {expected}"

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            shape = (10, 22, 3, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem1 = memref.alloc(T.memref(*shape, T.i32()), [], [])

            # Unit-stride slices of size 1 along every combination of axes.
            # fmt: off
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 0), (1, 22, 3, 44), (1, 1, 1, 1)), golden_mem[1:2, ...])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 0), (10, 1, 3, 44), (1, 1, 1, 1)), golden_mem[:, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 0), (10, 22, 1, 44), (1, 1, 1, 1)), golden_mem[:, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 0, 0, 1), (10, 22, 3, 1), (1, 1, 1, 1)), golden_mem[:, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 0, 1), (10, 1, 3, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 0, 1), (1, 22, 3, 1), (1, 1, 1, 1)), golden_mem[1:2, :, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 0), (1, 1, 3, 44), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, :])
            check_strides_offset(memref.subview(mem1, (0, 0, 1, 1), (10, 22, 1, 1), (1, 1, 1, 1)), golden_mem[:, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 0), (10, 1, 1, 44), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 0), (1, 22, 1, 44), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, :])
            check_strides_offset(memref.subview(mem1, (1, 1, 0, 1), (1, 1, 3, 1), (1, 1, 1, 1)), golden_mem[1:2, 1:2, :, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 0, 1, 1), (1, 22, 1, 1), (1, 1, 1, 1)), golden_mem[1:2, :, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (0, 1, 1, 1), (10, 1, 1, 1), (1, 1, 1, 1)), golden_mem[:, 1:2, 1:2, 1:2])
            check_strides_offset(memref.subview(mem1, (1, 1, 1, 0), (1, 1, 1, 44), (1, 1, 1, 1)), golden_mem[1:2, 1:2, 1:2, :])
            # fmt: on

            # A full-extent, unit-stride subview gets the default identity
            # affine-map layout rather than a strided layout attribute.
            assert memref.subview(
                mem1, (0, 0, 0, 0), (10, 22, 3, 44), (1, 1, 1, 1)
            ).type.layout == AffineMapAttr.get(AffineMap.get_identity(4))

            # Non-unit strides.
            shape = (7, 22, 30, 44)
            golden_mem = np.zeros(shape, dtype=np.int32)
            mem2 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            # fmt: off
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 3, 44), (1, 2, 1, 1)), golden_mem[:, 0:22:2])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 44), (1, 2, 30, 1)), golden_mem[:, 0:22:2, 0:330:30])
            check_strides_offset(memref.subview(mem2, (0, 0, 0, 0), (7, 11, 11, 11), (1, 2, 30, 400)), golden_mem[:, 0:22:2, 0:330:30, 0:4400:400])
            # fmt: on

            # 2-D tiles, including one at a non-zero origin.
            shape = (8, 8)
            golden_mem = np.zeros(shape, dtype=np.int32)
            # fmt: off
            mem3 = memref.alloc(T.memref(*shape, T.i32()), [], [])
            check_strides_offset(memref.subview(mem3, (0, 0), (4, 4), (1, 1)), golden_mem[0:4, 0:4])
            check_strides_offset(memref.subview(mem3, (4, 4), (4, 4), (1, 1)), golden_mem[4:8, 4:8])
            # fmt: on