| import numpy as np |
| from mlir import ir |
| from mlir.dialects import arith |
| from mlir.dialects import func |
| from mlir.dialects import gpu |
| from mlir.dialects import memref |
| from mlir.dialects import nvgpu |
| from mlir.dialects import nvvm |
| from mlir.dialects import llvm |
| from mlir.dialects import builtin |
| from mlir.dialects import scf |
| from mlir.dialects import vector |
| from mlir.extras import types as T |
| |
| TMA_LAST_DIM_F16 = 64 # 128-byte TMA box / 2 bytes per float16 = 64 elements |
| WARP_SIZE = 32 |
| WARP_GROUP_SIZE = WARP_SIZE * 4 |
| |
| PRODUCER_REGISTER_SIZE = 40 |
| CONSUMER_REGISTER_SIZE = 232 |
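| # Register reallocation for warp specialization: the producer warpgroup shrinks |
| # its per-thread register budget (nvvm.setmaxregister decrease) so the consumer |
| # warpgroup can grow its budget for the WGMMA accumulators. The 40/232 split |
| # below mirrors the usual Hopper warp-specialization setup. |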
| |
| PRODUCER_PRIMARY_THREAD = 128 |
| CONSUMER_PRIMARY_THREAD = 0 |
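| # Thread 128 is the first thread of the second warpgroup (producer / TMA); |
| # thread 0 is the first thread of the first warpgroup (consumer / MMA). |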
| |
| # MLIR's sentinel for a dynamic dimension/offset (ShapedType::kDynamic on the |
| # C++ side, i.e. INT64_MIN). |
| MLIR_DYNAMIC = -9223372036854775808 |
| |
| DEBUG = False |
| |
| |
| class TmaDescriptorBuilder: |
| """A class that builds a TMA descriptor.""" |
| |
| def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty): |
| self.swizzle = swizzle # mlir.nvgpu.TensorMapSwizzleKind |
| self.l2promo = l2promo # mlir.nvgpu.TensorMapL2PromoKind |
| self.oob = oob # mlir.nvgpu.TensorMapOOBKind |
| self.interleave = interleave # mlir.nvgpu.TensorMapInterleaveKind |
| self.tma_box_shape = tma_box_shape |
| self.memref_ty = memref_ty # MemRefType |
| |
| @property |
| def tensormap_descriptor_ty(self): |
| """Returns a tensormap descriptor type.""" |
| tensorMemrefType = ir.MemRefType.get( |
| self.tma_box_shape, |
| self.memref_ty.element_type, |
| memory_space=ir.Attribute.parse("3"), |
| ) |
| return nvgpu.TensorMapDescriptorType.get( |
| tensorMemrefType, |
| self.swizzle, |
| self.l2promo, |
| self.oob, |
| self.interleave, |
| ) |
| |
| def tma_descriptor_op(self, device_ptr): |
| """Returns a tensormap descriptor op.""" |
| tma_descriptor_ty = self.tensormap_descriptor_ty |
| device_unranked_memref = memref.CastOp( |
| ir.UnrankedMemRefType.get( |
| self.memref_ty.element_type, self.memref_ty.memory_space |
| ), |
| device_ptr, |
| ) |
| tma_descriptor_op = nvgpu.TmaCreateDescriptorOp( |
| tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape) |
| ) |
| return tma_descriptor_op.result |
| |
| |
| def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False): |
| if not DEBUG and not forcePrint: |
| return |
| type_formats = [] |
| for arg in args: |
| ty_format = None |
| if ir.IndexType.isinstance(arg.type): |
| ty_format = "%llu" |
| if ir.IntegerType.isinstance(arg.type): |
| width = ir.IntegerType(arg.type).width |
| if width == 64: |
| ty_format = "%llu" |
| elif width == 32: |
| ty_format = "%d" |
| elif width == 1: |
| ty_format = "%i" |
| if ir.F32Type.isinstance(arg.type): |
| ty_format = "%f" |
| if ty_format is None: |
| raise NotImplementedError(arg.type) |
| type_formats.append(ty_format) |
| if threadNumber != -1: |
| tidx = gpu.thread_id(gpu.Dimension.x) |
| predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber)) |
| if_op = scf.IfOp(predicate) |
| with ir.InsertionPoint(if_op.then_block): |
| gpu.printf(fmt.format(*type_formats) + "\n", args) |
| scf.yield_([]) |
| |
| |
| def get_type_size(ty): |
| if ir.FloatType.isinstance(ty): |
| return ir.FloatType(ty).width // 8 |
| if ir.IntegerType.isinstance(ty): |
| return ir.IntegerType(ty).width // 8 |
| raise NotImplementedError(ty) |
| |
| |
| def get_mlir_ty(dtype): |
| if dtype == np.float16: |
| return T.f16() |
| if dtype == np.float32: |
| return T.f32() |
| if dtype == np.float64: |
| return T.f64() |
| if dtype == np.int32: |
| return T.i32() |
| if dtype == np.int64: |
| return T.i64() |
| raise NotImplementedError(dtype) |
| |
| |
| def c(value, ty=None): |
| ty = T.index() if ty is None else ty |
| return arith.constant(ty, value) |
| |
| |
| def make_kernel_name( |
| input_type=np.float16, |
| output_type=np.float32, |
| M=4096, |
| N=4096, |
| K=4096, |
| BLOCK_M=128, |
| BLOCK_N=128, |
| BLOCK_K=128, |
| num_stages=3, |
| use_warp_specialization=False, |
| ): |
| kernelName = "warpspecialized" if use_warp_specialization else "multistage" |
| return ( |
| kernelName |
| + "_" |
| + str(M) |
| + "x" |
| + str(N) |
| + "x" |
| + str(K) |
| + "_" |
| + str(BLOCK_M) |
| + "x" |
| + str(BLOCK_N) |
| + "x" |
| + str(BLOCK_K) |
| + "_" |
| + str(num_stages) |
| ) |
| |
| |
| def generate_matmul_ws( |
| input_type=np.float16, |
| output_type=np.float32, |
| M=4096, |
| N=4096, |
| K=4096, |
| BLOCK_M=128, |
| BLOCK_N=128, |
| BLOCK_K=64, |
| num_stages=3, |
| ): |
| # Limitations for now |
| assert input_type == np.float16 |
| assert output_type == np.float32 |
| assert BLOCK_M == 128 |
| assert BLOCK_N == 128 |
| assert BLOCK_K == 64 |
| assert M % BLOCK_M == 0 |
| assert N % BLOCK_N == 0 |
| assert K % BLOCK_K == 0 |
| |
| module = ir.Module.create() |
| token_ty = ir.Type.parse("!gpu.async.token") |
| a_elem_ty = get_mlir_ty(input_type) |
| b_elem_ty = get_mlir_ty(input_type) |
| c_elem_ty = get_mlir_ty(output_type) |
| a_ty = ir.MemRefType.get([M, K], a_elem_ty) |
| b_ty = ir.MemRefType.get((K, N), b_elem_ty) |
| c_ty = ir.MemRefType.get((M, N), c_elem_ty) |
| a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16) |
| b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16) |
| b_tile_shape = (BLOCK_K, BLOCK_N) |
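| # With 128-byte swizzling, one TMA box covers TMA_LAST_DIM_F16 = 64 f16 |
| # elements in the innermost dimension. The A tile (BLOCK_M x BLOCK_K) fits in a |
| # single box, while the B tile (BLOCK_K x BLOCK_N = 64 x 128) is fetched as two |
| # adjacent 64 x 64 boxes (see the two b_tma_slice views below). |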
| txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + ( |
| a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty) |
| ) |
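| # txcount above is the number of bytes each TMA stage deposits into shared |
| # memory, used as the expected-transaction count of mbarrier_arrive_expect_tx. |
| # For the default f16 tiles this is 128*64*2 + 64*128*2 = 32768 bytes per stage. |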
| smem_space_str = "#gpu.address_space<workgroup>" |
| smem_space = ir.Attribute.parse(smem_space_str) |
| mbar_ty = ir.Type.parse( |
| "!nvgpu.mbarrier.group<memorySpace = " |
| + str(smem_space) |
| + ", num_barriers = " |
| + str(num_stages) |
| + ">" |
| ) |
| acc_ty = ir.Type.parse( |
| "!nvgpu.warpgroup.accumulator<fragmented=vector<" |
| + str(BLOCK_M) |
| + "x" |
| + str(BLOCK_N) |
| + "x" |
| + str(c_elem_ty) |
| + ">>" |
| ) |
| a_wgmma_ty = ir.Type.parse( |
| "!nvgpu.warpgroup.descriptor<tensor=memref<" |
| + str(BLOCK_M) |
| + "x" |
| + str(BLOCK_K) |
| + "x" |
| + str(a_elem_ty) |
| + ", " |
| + smem_space_str |
| + ">>" |
| ) |
| b_wgmma_ty = ir.Type.parse( |
| "!nvgpu.warpgroup.descriptor<tensor=memref<" |
| + str(BLOCK_K) |
| + "x" |
| + str(BLOCK_N) |
| + "x" |
| + str(a_elem_ty) |
| + ", " |
| + smem_space_str |
| + ">>" |
| ) |
| kernelName = make_kernel_name( |
| input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, True |
| ) |
| with ir.InsertionPoint(module.body): |
| fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], [])) |
| with ir.InsertionPoint(fop.add_entry_block()): |
| a_host = fop.arguments[0] |
| b_host = fop.arguments[1] |
| c_host = fop.arguments[2] |
| lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty) |
| rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty) |
| smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages |
| smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty) |
| smem_size = max(smem_size_input, smem_size_output) |
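| # Worked example for the defaults (BLOCK_M = BLOCK_N = 128, BLOCK_K = 64, |
| # num_stages = 3): the input buffers need (16 KiB + 16 KiB) * 3 = 96 KiB, the |
| # f32 output tile needs 64 KiB, and both reuse the same dynamic shared memory, |
| # hence the max. |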
| |
| # Step 1. Allocate device memory and memcpy |
| t1 = gpu.wait(token_ty, []) |
| a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], []) |
| b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], []) |
| c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], []) |
| t5 = gpu.memcpy(token_ty, [t4], a_device, a_host) |
| t6 = gpu.memcpy(token_ty, [t5], b_device, b_host) |
| t7 = gpu.wait(token_ty, [t6]) |
| |
| # Step 2. Create TMA Descriptors |
| a_tma_desc = TmaDescriptorBuilder( |
| nvgpu.TensorMapSwizzleKind.SWIZZLE_128B, |
| nvgpu.TensorMapL2PromoKind.L2PROMO_NONE, |
| nvgpu.TensorMapOOBKind.OOB_ZERO, |
| nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE, |
| a_tma_shape, |
| a_ty, |
| ) |
| |
| b_tma_desc = TmaDescriptorBuilder( |
| nvgpu.TensorMapSwizzleKind.SWIZZLE_128B, |
| nvgpu.TensorMapL2PromoKind.L2PROMO_NONE, |
| nvgpu.TensorMapOOBKind.OOB_ZERO, |
| nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE, |
| b_tma_shape, |
| b_ty, |
| ) |
| |
| a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device) |
| b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device) |
| |
| # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer |
| cta_m = M // BLOCK_M |
| cta_n = N // BLOCK_N |
| assert M % BLOCK_M == 0 and N % BLOCK_N == 0 |
| grid = (cta_m, cta_n, 1) |
| block = (WARP_GROUP_SIZE * 2, 1, 1) |
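| # Each CTA computes one BLOCK_M x BLOCK_N tile of C, so the grid is |
| # (M / BLOCK_M, N / BLOCK_N). The block holds two warpgroups (256 threads): |
| # one producer (TMA) and one consumer (MMA). |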
| launch_op = gpu.LaunchOp( |
| token_ty, |
| [t7], |
| *map(c, grid), |
| *map(c, block), |
| dynamicSharedMemorySize=c(smem_size, ty=T.i32()), |
| ) |
| launch_op.body.blocks.append(*([T.index()] * 12)) |
| with ir.InsertionPoint(launch_op.body.blocks[0]): |
| # GPU Step 0. This is needed for vectorized ld/st |
| memref.assume_alignment(c_device, 16) |
| dynamic_smem = gpu.dynamic_shared_memory( |
| ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space) |
| ) |
| ticks = c(10000000) |
| |
| # GPU Step 1. Bootstrapping: find the primary thread, warp and warpgroup IDs, etc. |
| tidx = gpu.thread_id(gpu.Dimension.x) |
| wgPrimaryThread = arith.cmpi( |
| arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0) |
| ) |
| warp_id = arith.divui(tidx, c(32)) |
| warpgroup_id = arith.divui(warp_id, c(4)) |
| is_producer = arith.cmpi( |
| arith.CmpIPredicate.eq, |
| warpgroup_id, |
| c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0), |
| ) |
| is_consumer = arith.cmpi( |
| arith.CmpIPredicate.eq, |
| warpgroup_id, |
| c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1), |
| ) |
| producerPrimaryThread = arith.cmpi( |
| arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD) |
| ) |
| consumerPrimaryThread = arith.cmpi( |
| arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD) |
| ) |
| bidx = gpu.block_id(gpu.Dimension.x) |
| bidy = gpu.block_id(gpu.Dimension.y) |
| dimX = arith.muli(bidx, c(BLOCK_M)) |
| dimY = arith.muli(bidy, c(BLOCK_N)) |
| |
| # GPU Step 2. Initialize mbarrier groups |
| mbarTMA = nvgpu.mbarrier_create(mbar_ty) |
| mbarDONE = nvgpu.mbarrier_create(mbar_ty) |
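| # mbarTMA[stage] is signaled by the producer once that stage's tiles are in |
| # shared memory; mbarDONE[stage] is signaled by the consumer once that buffer |
| # can be reused. Each barrier has an arrival count of 1 (only the primary |
| # thread arrives); TMA completion is tracked via the expected-transaction count. |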
| for i in range(num_stages): |
| nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread) |
| nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread) |
| gpu.barrier() |
| |
| # GPU Step 3. Prefetch TMA descriptors |
| nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread) |
| nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread) |
| |
| ns = num_stages if num_stages == 1 else num_stages - 1 |
| # GPU Step 5. Producer Warpgroup (TMA Warpgroup) |
| with ir.InsertionPoint(scf.IfOp(is_producer).then_block): |
| # Step 5.1. Reduce register size |
| nvvm.setmaxregister( |
| PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease |
| ) |
| |
| # Step 5.2. TMA Main Loop |
| for_op = scf.ForOp( |
| c(0), c(K // BLOCK_K), c(1), [arith.constant(T.bool(), 1)] |
| ) |
| with ir.InsertionPoint(for_op.body): |
| phaseParity = for_op.inner_iter_args[0] |
| iv = for_op.induction_variable |
| stage = arith.remui(iv, c(num_stages)) |
| |
| # Step 5.2.1. Wait mbarDONE |
| debug_print( |
| "[prod] iv={} | mbarDONE[{}] try_wait phase={}", |
| iv, |
| stage, |
| phaseParity, |
| predicate=producerPrimaryThread, |
| ) |
| nvgpu.MBarrierTryWaitParityOp( |
| mbarDONE, phaseParity, ticks, mbarId=stage |
| ) |
| debug_print( |
| "[prod] iv={} | mbarDONE[{}] try_wait phase={} [done]", |
| iv, |
| stage, |
| phaseParity, |
| predicate=producerPrimaryThread, |
| ) |
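| # The phase parity bit must flip each time the ring of num_stages barriers |
| # wraps around, i.e. after the last stage has been waited on. |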
| p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1)) |
| phaseParity = arith.select( |
| p, |
| arith.xori(phaseParity, arith.constant(T.bool(), 1)), |
| phaseParity, |
| ) |
| |
| # Step 5.2.2. Load TMA |
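| # Shared-memory layout per pipeline: all A buffers first, then all B buffers, |
| # so A for this stage starts at stage * lhs_tile_bytes and B starts at |
| # lhs_tile_bytes * num_stages + stage * rhs_tile_bytes. The two halves of the |
| # B tile are two adjacent 64 x 64 boxes. |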
| a_offset = arith.muli(stage, c(lhs_tile_bytes)) |
| a_tma_slice = memref.view( |
| ir.MemRefType.get( |
| a_tma_shape, a_elem_ty, memory_space=smem_space |
| ), |
| dynamic_smem, |
| a_offset, |
| [], |
| ) |
| b_offset = arith.addi( |
| arith.muli(stage, c(rhs_tile_bytes)), |
| c(lhs_tile_bytes * num_stages), |
| ) |
| b_tma_slice_1 = memref.view( |
| ir.MemRefType.get( |
| b_tma_shape, b_elem_ty, memory_space=smem_space |
| ), |
| dynamic_smem, |
| b_offset, |
| [], |
| ) |
| b_offset2 = arith.addi( |
| b_offset, |
| c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)), |
| ) |
| b_tma_slice_2 = memref.view( |
| ir.MemRefType.get( |
| b_tma_shape, b_elem_ty, memory_space=smem_space |
| ), |
| dynamic_smem, |
| b_offset2, |
| [], |
| ) |
| debug_print( |
| "[prod] a_offset={} b_offset={} b_offset2={}", |
| a_offset, |
| b_offset, |
| b_offset2, |
| predicate=producerPrimaryThread, |
| ) |
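| # coord = iv * 64 is the K offset of this tile (BLOCK_K == 64). TMA coordinates |
| # are passed innermost dimension first, so A (M x K) uses (k, m) = [coord, dimX] |
| # and B (K x N) uses (n, k) = [dimY, coord]; the second B box starts 64 columns |
| # further at dimY + 64. |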
| coord = arith.muli(c(64), iv) |
| nvgpu.TmaAsyncLoadOp( |
| a_tma_slice, |
| mbarTMA, |
| a_tma_desc_op, |
| coordinates=[coord, dimX], |
| mbarId=stage, |
| predicate=producerPrimaryThread, |
| ) |
| nvgpu.TmaAsyncLoadOp( |
| b_tma_slice_1, |
| mbarTMA, |
| b_tma_desc_op, |
| coordinates=[dimY, coord], |
| mbarId=stage, |
| predicate=producerPrimaryThread, |
| ) |
| dimY2 = arith.addi(dimY, c(64)) |
| nvgpu.TmaAsyncLoadOp( |
| b_tma_slice_2, |
| mbarTMA, |
| b_tma_desc_op, |
| coordinates=[dimY2, coord], |
| mbarId=stage, |
| predicate=producerPrimaryThread, |
| ) |
| |
| # Step 5.2.3. Arrive mbarTMA |
| debug_print( |
| "[prod] iv={} | mbarTMA[{}] arrive", |
| iv, |
| stage, |
| predicate=producerPrimaryThread, |
| ) |
| nvgpu.mbarrier_arrive_expect_tx( |
| mbarTMA, c(txcount), stage, predicate=producerPrimaryThread |
| ) |
| debug_print( |
| "[prod] iv={} | mbarTMA[{}] arrive [done]", |
| iv, |
| stage, |
| predicate=producerPrimaryThread, |
| ) |
| scf.yield_([phaseParity]) |
| scf.yield_([]) |
| |
| # GPU Step 6. Consumer Warpgroup (MMA Warpgroup) |
| if_op = scf.IfOp(is_consumer) |
| with ir.InsertionPoint(if_op.then_block): |
| # Step 6.1. Increase register size |
| nvvm.setmaxregister( |
| CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase |
| ) |
| |
| # Step 6.2. Initialize MMA registers |
| acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty) |
| |
| # Step 6.3. MMA Main Loop |
| for_op = scf.ForOp( |
| c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)] |
| ) |
| with ir.InsertionPoint(for_op.body): |
| # Step 6.3.1. Wait mbar1 |
| phaseParity = for_op.inner_iter_args[1] |
| iv = for_op.induction_variable |
| stage = arith.remui(iv, c(num_stages)) |
| debug_print( |
| "[cons] iv={} | mbarTMA[{}] try_wait phase={}", |
| iv, |
| stage, |
| phaseParity, |
| predicate=consumerPrimaryThread, |
| ) |
| nvgpu.MBarrierTryWaitParityOp( |
| mbarTMA, phaseParity, ticks, mbarId=stage |
| ) |
| debug_print( |
| "[cons] iv={} | mbarTMA[{}] try_wait phase={} [done]", |
| iv, |
| stage, |
| phaseParity, |
| predicate=consumerPrimaryThread, |
| ) |
| |
| # Step 6.3.2. Create WGMMA Descriptors |
| a_offset = arith.muli(stage, c(lhs_tile_bytes)) |
| a_tile_slice = memref.view( |
| ir.MemRefType.get( |
| a_tile_shape, a_elem_ty, memory_space=smem_space |
| ), |
| dynamic_smem, |
| a_offset, |
| [], |
| ) |
| b_offset = arith.addi( |
| arith.muli(stage, c(rhs_tile_bytes)), |
| c(lhs_tile_bytes * num_stages), |
| ) |
| b_tile_slice = memref.view( |
| ir.MemRefType.get( |
| b_tile_shape, b_elem_ty, memory_space=smem_space |
| ), |
| dynamic_smem, |
| b_offset, |
| [], |
| ) |
| debug_print( |
| "[cons] a_offset={} b_offset={}", |
| a_offset, |
| b_offset, |
| predicate=consumerPrimaryThread, |
| ) |
| da = nvgpu.WarpgroupGenerateDescriptorOp( |
| a_wgmma_ty, a_tile_slice, a_tma_desc_op |
| ) |
| db = nvgpu.WarpgroupGenerateDescriptorOp( |
| b_wgmma_ty, b_tile_slice, b_tma_desc_op |
| ) |
| |
| # Step 6.3.3. MMA |
| carry_acc = for_op.inner_iter_args[0] |
| new_acc = nvgpu.WarpgroupMmaOp( |
| acc.type, da, db, carry_acc, transposeB=True |
| ) |
| |
| # Step 6.3.4. Arrive mbarDONE |
| if num_stages == 1: |
| p_arrive = consumerPrimaryThread |
| else: |
| p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0)) |
| p_arrive = arith.andi(consumerPrimaryThread, p1) |
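| # Release the previous stage's buffers back to the producer: barId is |
| # stage - 1, wrapping to num_stages - 1, and (for multi-stage pipelines) the |
| # arrive is skipped on the first iteration since nothing has been consumed yet. |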
| with ir.InsertionPoint(scf.IfOp(p_arrive).then_block): |
| p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0)) |
| barId = arith.select( |
| p, c(num_stages - 1), arith.subi(stage, c(1)) |
| ) |
| debug_print( |
| "[cons] iv={} | mbarDONE[{}] arrive ", |
| iv, |
| barId, |
| predicate=consumerPrimaryThread, |
| ) |
| nvgpu.mbarrier_arrive( |
| ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId |
| ) |
| debug_print( |
| "[cons] iv={} | mbarDONE[{}] arrive [done]", |
| iv, |
| barId, |
| predicate=consumerPrimaryThread, |
| ) |
| scf.yield_([]) |
| |
| p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1)) |
| phaseParity = arith.select( |
| p, |
| arith.xori(phaseParity, arith.constant(T.bool(), 1)), |
| phaseParity, |
| ) |
| |
| # Step 6.3.5. Yield |
| scf.yield_([new_acc, phaseParity]) |
| |
| # Step 6.4. Wait for all WGMMA groups |
| nvvm.WgmmaWaitGroupSyncOp(0) |
| |
| with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block): |
| barId = c((K // BLOCK_K) % num_stages) |
| nvgpu.mbarrier_arrive( |
| ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId |
| ) |
| scf.yield_([]) |
| |
| # Step 6.5. Epilogue (registers --> shared memory) |
| acc_smem_ty = ir.MemRefType.get( |
| (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space |
| ) |
| acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), []) |
| debug_print("[cons] | Storing", predicate=consumerPrimaryThread) |
| nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem) |
| scf.yield_([]) |
| gpu.barrier() |
| |
| # GPU Step 7. Epilogue (shared memory --> global memory) |
| fd = ir.MemRefType.get( |
| [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space |
| ) |
| collapsed_smem = memref.view(fd, dynamic_smem, c(0), []) |
| rty = ir.MemRefType.get( |
| (BLOCK_M, BLOCK_N), |
| c_elem_ty, |
| ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"), |
| ) |
| c_device_per_block = memref.SubViewOp( |
| rty, |
| c_device, |
| [dimX, dimY], |
| [], |
| [], |
| [MLIR_DYNAMIC, MLIR_DYNAMIC], |
| [BLOCK_M, BLOCK_N], |
| [1, 1], |
| ) |
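| # All 256 threads stride over the BLOCK_M x BLOCK_N accumulator tile in shared |
| # memory and copy it into the per-block C sub-view. vlen is kept at 1, so the |
| # vector loads/stores below are effectively scalar. |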
| vlen = 1 |
| for_op = scf.ForOp( |
| tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2) |
| ) |
| with ir.InsertionPoint(for_op.body): |
| x = arith.divui(for_op.induction_variable, c(BLOCK_N)) |
| y = arith.remui(for_op.induction_variable, c(BLOCK_N)) |
| vdata = vector.load( |
| ir.VectorType.get((vlen,), c_elem_ty), |
| collapsed_smem, |
| [for_op.induction_variable], |
| ) |
| vector.store(vdata, c_device_per_block, [x, y]) |
| scf.yield_([]) |
| |
| gpu.terminator() |
| |
| # Step 4. Copy back to host |
| t8 = gpu.wait(token_ty, [launch_op]) |
| t9 = gpu.memcpy(token_ty, [t8], c_host, c_device) |
| gpu.dealloc(token_ty, [t8], a_device) |
| gpu.dealloc(token_ty, [t8], b_device) |
| gpu.wait(token_ty, [t9]) |
| gpu.dealloc(token_ty, [t8], c_device) |
| func.ReturnOp([]) |
| |
| fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() |
| module.operation.verify() |
| return module |
| |
| |
| def generate_matmul_multistage( |
| input_type=np.float16, |
| output_type=np.float32, |
| M=4096, |
| N=4096, |
| K=4096, |
| BLOCK_M=128, |
| BLOCK_N=128, |
| BLOCK_K=64, |
| num_stages=3, |
| ): |
| # Limitations for now |
| assert input_type == np.float16 |
| assert output_type == np.float32 |
| assert BLOCK_M == 128 |
| assert BLOCK_N == 128 |
| assert BLOCK_K == 64 |
| assert M % BLOCK_M == 0 |
| assert N % BLOCK_N == 0 |
| assert K % BLOCK_K == 0 |
| |
| module = ir.Module.create() |
| token_ty = ir.Type.parse("!gpu.async.token") |
| a_elem_ty = get_mlir_ty(input_type) |
| b_elem_ty = get_mlir_ty(input_type) |
| c_elem_ty = get_mlir_ty(output_type) |
| a_ty = ir.MemRefType.get([M, K], a_elem_ty) |
| b_ty = ir.MemRefType.get((K, N), b_elem_ty) |
| c_ty = ir.MemRefType.get((M, N), c_elem_ty) |
| a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16) |
| b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16) |
| b_tile_shape = (BLOCK_K, BLOCK_N) |
| txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + ( |
| a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty) |
| ) |
| smem_space_str = "#gpu.address_space<workgroup>" |
| smem_space = ir.Attribute.parse(smem_space_str) |
| mbar_ty = ir.Type.parse( |
| "!nvgpu.mbarrier.group<memorySpace = " |
| + str(smem_space) |
| + ", num_barriers = " |
| + str(num_stages) |
| + ">" |
| ) |
| acc_ty = ir.Type.parse( |
| "!nvgpu.warpgroup.accumulator<fragmented=vector<" |
| + str(BLOCK_M) |
| + "x" |
| + str(BLOCK_N) |
| + "x" |
| + str(c_elem_ty) |
| + ">>" |
| ) |
| a_wgmma_ty = ir.Type.parse( |
| "!nvgpu.warpgroup.descriptor<tensor=memref<" |
| + str(BLOCK_M) |
| + "x" |
| + str(BLOCK_K) |
| + "x" |
| + str(a_elem_ty) |
| + ", " |
| + smem_space_str |
| + ">>" |
| ) |
| b_wgmma_ty = ir.Type.parse( |
| "!nvgpu.warpgroup.descriptor<tensor=memref<" |
| + str(BLOCK_K) |
| + "x" |
| + str(BLOCK_N) |
| + "x" |
| + str(a_elem_ty) |
| + ", " |
| + smem_space_str |
| + ">>" |
| ) |
| |
| with ir.InsertionPoint(module.body): |
| kernelName = make_kernel_name( |
| input_type, |
| output_type, |
| M, |
| N, |
| K, |
| BLOCK_M, |
| BLOCK_N, |
| BLOCK_K, |
| num_stages, |
| False, |
| ) |
| fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], [])) |
| with ir.InsertionPoint(fop.add_entry_block()): |
| a_host = fop.arguments[0] |
| b_host = fop.arguments[1] |
| c_host = fop.arguments[2] |
| lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty) |
| rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty) |
| smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages |
| smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty) |
| smem_size = max(smem_size_input, smem_size_output) |
| |
| # Step 1. Allocate device memory and memcpy |
| t1 = gpu.wait(token_ty, []) |
| a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], []) |
| b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], []) |
| c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], []) |
| t5 = gpu.memcpy(token_ty, [t4], a_device, a_host) |
| t6 = gpu.memcpy(token_ty, [t5], b_device, b_host) |
| t7 = gpu.wait(token_ty, [t6]) |
| |
| # Step 2. Create TMA Descriptors |
| a_tma_desc = TmaDescriptorBuilder( |
| nvgpu.TensorMapSwizzleKind.SWIZZLE_128B, |
| nvgpu.TensorMapL2PromoKind.L2PROMO_NONE, |
| nvgpu.TensorMapOOBKind.OOB_ZERO, |
| nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE, |
| a_tma_shape, |
| a_ty, |
| ) |
| |
| b_tma_desc = TmaDescriptorBuilder( |
| nvgpu.TensorMapSwizzleKind.SWIZZLE_128B, |
| nvgpu.TensorMapL2PromoKind.L2PROMO_NONE, |
| nvgpu.TensorMapOOBKind.OOB_ZERO, |
| nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE, |
| b_tma_shape, |
| b_ty, |
| ) |
| |
| a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device) |
| b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device) |
| |
| # Step 3. Launch Kernel with 1 Warpgroup |
| cta_m = M // BLOCK_M |
| cta_n = N // BLOCK_N |
| assert M % BLOCK_M == 0 and N % BLOCK_N == 0 |
| grid = (cta_m, cta_n, 1) |
| block = (WARP_GROUP_SIZE, 1, 1) |
| launch_op = gpu.LaunchOp( |
| token_ty, |
| [t7], |
| *map(c, grid), |
| *map(c, block), |
| dynamicSharedMemorySize=c(smem_size, ty=T.i32()), |
| ) |
| launch_op.body.blocks.append(*([T.index()] * 12)) |
| with ir.InsertionPoint(launch_op.body.blocks[0]): |
| # GPU Step 0. Bootstrapping |
| memref.assume_alignment(c_device, 16) |
| dynamic_smem = gpu.dynamic_shared_memory( |
| ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space) |
| ) |
| ticks = c(10000000) |
| tidx = gpu.thread_id(gpu.Dimension.x) |
| primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0)) |
| warpId = arith.divui(tidx, c(32)) |
| bidx = gpu.block_id(gpu.Dimension.x) |
| bidy = gpu.block_id(gpu.Dimension.y) |
| dimX = arith.muli(bidx, c(BLOCK_M)) |
| dimY = arith.muli(bidy, c(BLOCK_N)) |
| |
| # GPU Step 1. Initialize mbarrier groups |
| mbarTMA = nvgpu.mbarrier_create(mbar_ty) |
| for i in range(num_stages): |
| nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread) |
| gpu.barrier() |
| |
| # GPU Step 2. Prefetch TMA descriptors |
| nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread) |
| nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread) |
| |
| # GPU Step 3. Prologue (global memory --> shared memory) |
| ns = num_stages if num_stages == 1 else num_stages - 1 |
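| # The prologue issues TMA loads for the first ns = num_stages - 1 tiles (or a |
| # single tile when num_stages == 1), so the loads stay ns iterations ahead of |
| # the MMAs in the main loop. |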
| for_op = scf.ForOp(c(0), c(ns), c(1)) |
| with ir.InsertionPoint(for_op.body): |
| iv = for_op.induction_variable |
| |
| # Step 3.1. Calculate offsets |
| a_offset = arith.muli(iv, c(lhs_tile_bytes)) |
| a_tma_slice = memref.view( |
| ir.MemRefType.get( |
| a_tma_shape, a_elem_ty, memory_space=smem_space |
| ), |
| dynamic_smem, |
| a_offset, |
| [], |
| ) |
| b_offset = arith.addi( |
| arith.muli(iv, c(rhs_tile_bytes)), |
| c(lhs_tile_bytes * num_stages), |
| ) |
| b_tma_slice_1 = memref.view( |
| ir.MemRefType.get( |
| b_tma_shape, b_elem_ty, memory_space=smem_space |
| ), |
| dynamic_smem, |
| b_offset, |
| [], |
| ) |
| b_offset2 = arith.addi( |
| b_offset, |
| c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)), |
| ) |
| b_tma_slice_2 = memref.view( |
| ir.MemRefType.get( |
| b_tma_shape, b_elem_ty, memory_space=smem_space |
| ), |
| dynamic_smem, |
| b_offset2, |
| [], |
| ) |
| |
| # Step 3.2. TMA Load |
| coord = arith.muli(c(64), iv) |
| dimY2 = arith.addi(dimY, c(64)) |
| debug_print( |
| "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})", |
| a_offset, |
| b_offset, |
| b_offset2, |
| coord, |
| dimX, |
| dimY, |
| coord, |
| predicate=primaryThread, |
| ) |
| nvgpu.TmaAsyncLoadOp( |
| a_tma_slice, |
| mbarTMA, |
| a_tma_desc_op, |
| coordinates=[coord, dimX], |
| mbarId=iv, |
| predicate=primaryThread, |
| ) |
| nvgpu.TmaAsyncLoadOp( |
| b_tma_slice_1, |
| mbarTMA, |
| b_tma_desc_op, |
| coordinates=[dimY, coord], |
| mbarId=iv, |
| predicate=primaryThread, |
| ) |
| nvgpu.TmaAsyncLoadOp( |
| b_tma_slice_2, |
| mbarTMA, |
| b_tma_desc_op, |
| coordinates=[dimY2, coord], |
| mbarId=iv, |
| predicate=primaryThread, |
| ) |
| |
| # Step 3.3. mbarTMA arrive |
| debug_print( |
| "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread |
| ) |
| nvgpu.mbarrier_arrive_expect_tx( |
| mbarTMA, c(txcount), iv, predicate=primaryThread |
| ) |
| debug_print( |
| "[Prologue] mbarTMA[{}] arrive [done]", |
| iv, |
| predicate=primaryThread, |
| ) |
| scf.yield_([]) |
| |
| # GPU Step 4. Main Loop |
| acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty) |
| for_op = scf.ForOp( |
| c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)] |
| ) |
| with ir.InsertionPoint(for_op.body): |
| # Step 4.1. Wait mbarTMA |
| phaseParity = for_op.inner_iter_args[1] |
| iv = for_op.induction_variable |
| stage = arith.remui(iv, c(num_stages)) |
| debug_print( |
| "[MainLoop] mbarTMA[{}] try_wait phase={}", |
| stage, |
| phaseParity, |
| predicate=primaryThread, |
| ) |
| nvgpu.MBarrierTryWaitParityOp( |
| mbarTMA, phaseParity, ticks, mbarId=stage |
| ) |
| debug_print( |
| "[MainLoop] mbarTMA[{}] try_wait phase={} [done]", |
| stage, |
| phaseParity, |
| predicate=primaryThread, |
| ) |
| |
| # Step 4.2. Create WGMMA Descriptors |
| a_offset = arith.muli(stage, c(lhs_tile_bytes)) |
| a_tile_slice = memref.view( |
| ir.MemRefType.get( |
| a_tile_shape, a_elem_ty, memory_space=smem_space |
| ), |
| dynamic_smem, |
| a_offset, |
| [], |
| ) |
| b_offset = arith.addi( |
| arith.muli(stage, c(rhs_tile_bytes)), |
| c(lhs_tile_bytes * num_stages), |
| ) |
| b_tile_slice = memref.view( |
| ir.MemRefType.get( |
| b_tile_shape, b_elem_ty, memory_space=smem_space |
| ), |
| dynamic_smem, |
| b_offset, |
| [], |
| ) |
| debug_print( |
| "[MainLoop] iv={} MMA a_offset={} b_offset={}", |
| iv, |
| a_offset, |
| b_offset, |
| predicate=primaryThread, |
| ) |
| da = nvgpu.WarpgroupGenerateDescriptorOp( |
| a_wgmma_ty, a_tile_slice, a_tma_desc_op |
| ) |
| db = nvgpu.WarpgroupGenerateDescriptorOp( |
| b_wgmma_ty, b_tile_slice, b_tma_desc_op |
| ) |
| |
| # Step 4.3. MMA |
| carry_acc = for_op.inner_iter_args[0] |
| new_acc = nvgpu.WarpgroupMmaOp( |
| acc.type, da, db, carry_acc, transposeB=True |
| ) |
| if num_stages == 1: |
| nvvm.WgmmaWaitGroupSyncOp(0) |
| |
| # Step 4.4. Load TMA for next stage |
| p1 = arith.cmpi( |
| arith.CmpIPredicate.ult, |
| arith.addi(iv, c(ns)), |
| c(K // BLOCK_K), |
| ) |
| p = arith.andi(primaryThread, p1) |
| nextStage = arith.addi(iv, c(ns)) |
| nextSlot = arith.remui(nextStage, c(num_stages)) |
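| # Software pipelining: while computing tile iv, the primary thread prefetches |
| # tile iv + ns into ring slot (iv + ns) % num_stages, predicated by p so no |
| # load is issued past the last K tile. |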
| a_offset = arith.muli(nextSlot, c(lhs_tile_bytes)) |
| |
| debug_print( |
| "[MainLoop] mbarTMA[{}] arrive", |
| nextSlot, |
| predicate=p, |
| ) |
| nvgpu.mbarrier_arrive_expect_tx( |
| mbarTMA, c(txcount), nextSlot, predicate=p |
| ) |
| debug_print( |
| "[MainLoop] mbarTMA[{}] arrive [done]", |
| nextSlot, |
| predicate=p, |
| ) |
| |
| a_tma_slice = memref.view( |
| ir.MemRefType.get( |
| a_tma_shape, a_elem_ty, memory_space=smem_space |
| ), |
| dynamic_smem, |
| a_offset, |
| [], |
| ) |
| b_offset = arith.addi( |
| arith.muli(nextSlot, c(rhs_tile_bytes)), |
| c(lhs_tile_bytes * num_stages), |
| ) |
| b_tma_slice_1 = memref.view( |
| ir.MemRefType.get( |
| b_tma_shape, b_elem_ty, memory_space=smem_space |
| ), |
| dynamic_smem, |
| b_offset, |
| [], |
| ) |
| b_offset2 = arith.addi( |
| b_offset, |
| c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)), |
| ) |
| b_tma_slice_2 = memref.view( |
| ir.MemRefType.get( |
| b_tma_shape, b_elem_ty, memory_space=smem_space |
| ), |
| dynamic_smem, |
| b_offset2, |
| [], |
| ) |
| |
| coord = arith.muli(c(64), nextStage) |
| debug_print( |
| "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})", |
| iv, |
| a_offset, |
| b_offset, |
| b_offset2, |
| coord, |
| dimX, |
| dimY, |
| coord, |
| predicate=p, |
| ) |
| nvgpu.TmaAsyncLoadOp( |
| a_tma_slice, |
| mbarTMA, |
| a_tma_desc_op, |
| coordinates=[coord, dimX], |
| mbarId=nextSlot, |
| predicate=p, |
| ) |
| nvgpu.TmaAsyncLoadOp( |
| b_tma_slice_1, |
| mbarTMA, |
| b_tma_desc_op, |
| coordinates=[dimY, coord], |
| mbarId=nextSlot, |
| predicate=p, |
| ) |
| dimY2 = arith.addi(dimY, c(64)) |
| nvgpu.TmaAsyncLoadOp( |
| b_tma_slice_2, |
| mbarTMA, |
| b_tma_desc_op, |
| coordinates=[dimY2, coord], |
| mbarId=nextSlot, |
| predicate=p, |
| ) |
| # Step 4.5. Change the phaseParity |
| p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1)) |
| phaseParity = arith.select( |
| p, |
| arith.xori(phaseParity, arith.constant(T.bool(), 1)), |
| phaseParity, |
| ) |
| |
| # Step 4.6. Yield |
| scf.yield_([new_acc, phaseParity]) |
| |
| # Step 5. Wait All WGMMA groups |
| nvvm.WgmmaWaitGroupSyncOp(0) |
| |
| # Step 6. Epilogue (registers --> shared memory) |
| acc_smem_ty = ir.MemRefType.get( |
| (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space |
| ) |
| acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), []) |
| debug_print("Storing", predicate=primaryThread) |
| nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem) |
| gpu.barrier() |
| |
| # GPU Step 7. Epilogue (shared memory --> global memory) |
| fd = ir.MemRefType.get( |
| [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space |
| ) |
| collapsed_smem = memref.view(fd, dynamic_smem, c(0), []) |
| rty = ir.MemRefType.get( |
| (BLOCK_M, BLOCK_N), |
| c_elem_ty, |
| ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"), |
| ) |
| c_device_per_block = memref.SubViewOp( |
| rty, |
| c_device, |
| [dimX, dimY], |
| [], |
| [], |
| [MLIR_DYNAMIC, MLIR_DYNAMIC], |
| [BLOCK_M, BLOCK_N], |
| [1, 1], |
| ) |
| vlen = 1 |
| for_op = scf.ForOp( |
| tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE) |
| ) |
| with ir.InsertionPoint(for_op.body): |
| x = arith.divui(for_op.induction_variable, c(BLOCK_N)) |
| y = arith.remui(for_op.induction_variable, c(BLOCK_N)) |
| vdata = vector.load( |
| ir.VectorType.get((vlen,), c_elem_ty), |
| collapsed_smem, |
| [for_op.induction_variable], |
| ) |
| vector.store(vdata, c_device_per_block, [x, y]) |
| scf.yield_([]) |
| |
| gpu.terminator() |
| |
| # Step 4. Copy back to host |
| t8 = gpu.wait(token_ty, [launch_op]) |
| t9 = gpu.memcpy(token_ty, [t8], c_host, c_device) |
| gpu.dealloc(token_ty, [t8], a_device) |
| gpu.dealloc(token_ty, [t8], b_device) |
| gpu.wait(token_ty, [t9]) |
| gpu.dealloc(token_ty, [t8], c_device) |
| func.ReturnOp([]) |
| |
| fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() |
| module.operation.verify() |
| return module |
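| |
| |
| # Example driver (a minimal sketch, not part of the original builders): both |
| # generator functions assume an active MLIR context and location, e.g. |
| # |
| #   with ir.Context(), ir.Location.unknown(): |
| #       module = generate_matmul_ws(np.float16, np.float32, 128, 128, 4096) |
| #       print(module) |
| # |
| # Lowering to NVVM/cubin and executing the kernel are expected to be handled by |
| # a separate test driver. |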