| //===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h" |
| |
| #include "mlir/Analysis/SliceAnalysis.h" |
| #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Arith/Utils/Utils.h" |
| #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
| #include "mlir/Dialect/NVGPU/Transforms/Transforms.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
| #include "mlir/Dialect/Utils/IndexingUtils.h" |
| #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Value.h" |
| #include "llvm/ADT/ArrayRef.h" |
| |
| using namespace mlir; |
| using namespace mlir::linalg; |
| using namespace mlir::nvgpu; |
| using namespace mlir::NVVM; |
| using namespace mlir::transform; |
| |
| #define DEBUG_TYPE "nvgpu-transforms" |
| #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| #define DBGSNL() (llvm::dbgs() << "\n") |
| #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
| |
| //===----------------------------------------------------------------------===// |
| // Apply...ConversionPatternsOp |
| //===----------------------------------------------------------------------===// |
| |
| void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns( |
| TypeConverter &typeConverter, RewritePatternSet &patterns) { |
| auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); |
| // Device-side async tokens cannot be materialized in NVVM. We convert them |
| // to a dummy i32 type in order to easily drop them during conversion. |
| llvmTypeConverter.addConversion( |
| [&](nvgpu::DeviceAsyncTokenType type) -> Type { |
| return llvmTypeConverter.convertType( |
| IntegerType::get(type.getContext(), 32)); |
| }); |
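| // An mbarrier token lowers to a plain 64-bit integer. |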
| llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type { |
| return llvmTypeConverter.convertType( |
| IntegerType::get(type.getContext(), 64)); |
| }); |
| llvmTypeConverter.addConversion( |
| [&](nvgpu::WarpgroupAccumulatorType type) -> Type { |
| Type elemType = type.getFragmented().getElementType(); |
| int64_t sizeM = type.getFragmented().getDimSize(0); |
| int64_t sizeN = type.getFragmented().getDimSize(1); |
| |
| unsigned numMembers; |
| if (elemType.isF32() || elemType.isInteger(32)) |
| numMembers = sizeN / 2; |
| else if (elemType.isF16()) |
| numMembers = sizeN / 4; |
| else |
| llvm_unreachable("unsupported type for warpgroup accumulator"); |
| |
| SmallVector<Type> innerStructBody; |
| for (unsigned i = 0; i < numMembers; i++) |
| innerStructBody.push_back(elemType); |
| auto innerStructType = LLVM::LLVMStructType::getLiteral( |
| type.getContext(), innerStructBody); |
| |
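| // The accumulator is tiled along M in chunks of kWgmmaSizeM rows; the outer |
| // struct holds one inner struct per tile. |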
| SmallVector<Type> structBody; |
| for (int i = 0; i < sizeM; i += kWgmmaSizeM) |
| structBody.push_back(innerStructType); |
| |
| auto convertedType = |
| LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); |
| return llvmTypeConverter.convertType(convertedType); |
| }); |
| llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type { |
| return llvmTypeConverter.convertType( |
| getMBarrierMemrefType(type.getContext(), type)); |
| }); |
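| // Warpgroup matrix descriptors are encoded as 64-bit integers. |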
| llvmTypeConverter.addConversion( |
| [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type { |
| return llvmTypeConverter.convertType( |
| IntegerType::get(type.getContext(), 64)); |
| }); |
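| // Tensor map descriptors lower to opaque LLVM pointers. |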
| llvmTypeConverter.addConversion( |
| [&](nvgpu::TensorMapDescriptorType type) -> Type { |
| return LLVM::LLVMPointerType::get(type.getContext()); |
| }); |
| populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns); |
| } |
| |
| LogicalResult |
| transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter( |
| transform::TypeConverterBuilderOpInterface builder) { |
| if (builder.getTypeConverterType() != "LLVMTypeConverter") |
| return emitOpError("expected LLVMTypeConverter"); |
| return success(); |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // CreateAsyncGroupsOp |
| //===---------------------------------------------------------------------===// |
| |
| void transform::CreateAsyncGroupsOp::getEffects( |
| SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| transform::consumesHandle(getTarget(), effects); |
| transform::producesHandle(getResult(), effects); |
| transform::modifiesPayload(effects); |
| } |
| |
| DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne( |
| TransformRewriter &rewriter, Operation *target, |
| ApplyToEachResultList &results, TransformState &state) { |
| nvgpu::createAsyncGroups(rewriter, target, getBypassL1()); |
| results.push_back(target); |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // PipelineSharedMemoryCopiesOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns true if the given type has the default memory space. |
| static bool hasDefaultMemorySpace(BaseMemRefType type) { |
| return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0; |
| } |
| |
| /// Returns true if the given type has the shared (workgroup) memory space. |
| static bool hasSharedMemorySpace(BaseMemRefType type) { |
| auto space = |
| dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace()); |
| return space && |
| space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace(); |
| } |
| |
| /// Returns the value produced by a load from the default memory space. Returns |
| /// null if the operation is not such a load. |
| static Value getValueLoadedFromGlobal(Operation *op) { |
| // TODO: consider an interface or leveraging the memory effects interface. |
| auto load = dyn_cast<vector::TransferReadOp>(op); |
| if (!load) |
| return nullptr; |
| |
| auto loadType = dyn_cast<MemRefType>(load.getSource().getType()); |
| if (!loadType || !hasDefaultMemorySpace(loadType)) |
| return nullptr; |
| return load; |
| } |
| |
| /// Returns true if the operation is storing the given value into shared memory. |
| static bool isStoreToShared(Operation *op, Value v) { |
| // TODO: consider an interface or leveraging the memory effects interface. |
| auto store = dyn_cast<vector::TransferWriteOp>(op); |
| if (!store || store.getVector() != v) |
| return false; |
| |
| auto storeType = dyn_cast<MemRefType>(store.getSource().getType()); |
| return storeType && hasSharedMemorySpace(storeType); |
| } |
| |
| /// Returns true if the operation is a load from the default memory space the |
| /// result of which is only stored into the shared memory space. |
| static bool isLoadFromGlobalStoredToShared(Operation *op) { |
| Value loaded = getValueLoadedFromGlobal(op); |
| if (!loaded || !loaded.hasOneUse()) |
| return false; |
| |
| return isStoreToShared(*loaded.getUsers().begin(), loaded); |
| } |
| |
| /// Populate `ops` with the set of operations that belong to stage 0 of the |
| /// pipelined version of the given loop when pipelining copies to shared |
| /// memory. Specifically, this collects: |
| /// |
| /// 1. all loads from global memory, both sync and async; |
| /// 2. the barriers for async loads. |
| /// |
| /// In particular, barriers are omitted if they do not dominate at least one |
| /// async load for which there is not yet a barrier. |
| static LogicalResult |
| collectStage0PipeliningOps(scf::ForOp forOp, |
| llvm::SmallPtrSet<Operation *, 16> &ops) { |
| |
| llvm::SmallPtrSet<Operation *, 4> barriers; |
| for (Operation &op : *forOp.getBody()) { |
| // Bail on nested ops for now. |
| if (op.getNumRegions() > 0) |
| return failure(); |
| |
| if (isa<gpu::BarrierOp>(op)) { |
| barriers.insert(&op); |
| continue; |
| } |
| |
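| // An async copy (or group creation) promotes every barrier collected so far |
| // to stage 0, since those barriers dominate the copy. |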
| if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) { |
| ops.insert(&op); |
| ops.insert(barriers.begin(), barriers.end()); |
| barriers.clear(); |
| continue; |
| } |
| |
| if (isLoadFromGlobalStoredToShared(&op)) { |
| ops.insert(&op); |
| continue; |
| } |
| } |
| |
| return success(); |
| } |
| |
| /// Hook for the loop pipeliner that sets the "num groups in flight" attribute |
| /// of async wait operations corresponding to pipelined shared memory copies. |
| // TODO: this currently assumes that there are no groups that could be in flight |
| // in the existing code. |
| static void |
| setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op, |
| scf::PipeliningOption::PipelinerPart part, |
| unsigned iteration, unsigned depth) { |
| // Based on the order of copies within the loop we need to set the number |
| // of copies in flight, unless it is already set. |
| auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op); |
| if (!waitOp || waitOp.getNumGroups()) |
| return; |
| |
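| // For example, with depth == 3 the kernel keeps 2 groups in flight; the |
| // epilogue then waits on one fewer group at each of its iterations as the |
| // remaining copies drain. |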
| int numGroupInFlight = 0; |
| if (part == scf::PipeliningOption::PipelinerPart::Kernel || |
| part == scf::PipeliningOption::PipelinerPart::Prologue) { |
| numGroupInFlight = depth - 1; |
| } else { |
| // By construction there should be no wait op in the prologue, as all the |
| // waits should be in the last stage. |
| assert(part == scf::PipeliningOption::PipelinerPart::Epilogue); |
| // Based on the schedule we pick we know how many groups are in flight for |
| // each iteration of the epilogue. |
| numGroupInFlight = depth - 1 - iteration; |
| } |
| waitOp.setNumGroups(numGroupInFlight); |
| } |
| |
| /// Hook for the loop pipeliner that populates `ops` with the stage information |
| /// as follows: |
| /// |
| /// - operations in `stage0Ops` (typically loads from global memory and |
| /// related barriers) are at stage 0; |
| /// - operations in the backward slice of any stage0Ops are all at stage 0; |
| /// - other operations are at stage `depth`; |
| /// - the internal order of the pipelined loop has ops at stage `depth` first, |
| /// then those at stage 0, with relative order within each group preserved. |
| /// |
| static void getPipelineStages( |
| scf::ForOp forOp, |
| std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages, |
| unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) { |
| SetVector<Operation *> dependencies; |
| BackwardSliceOptions options([&](Operation *visited) { |
| return visited->getBlock() == forOp.getBody(); |
| }); |
| options.inclusive = true; |
| for (Operation &op : forOp.getBody()->getOperations()) { |
| if (stage0Ops.contains(&op)) |
| getBackwardSlice(&op, &dependencies, options); |
| } |
| |
| for (Operation &op : forOp.getBody()->getOperations()) { |
| if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op)) |
| opsWithPipelineStages.emplace_back(&op, depth); |
| } |
| for (Operation &op : forOp.getBody()->getOperations()) { |
| if (dependencies.contains(&op)) |
| opsWithPipelineStages.emplace_back(&op, 0); |
| } |
| } |
| |
| /// Hook for the loop pipeliner. Replaces op with a predicated version and |
| /// returns the resulting operation. Returns the original op if the predication |
| /// isn't necessary for the given op. Returns null if predication is needed but |
| /// not supported. |
| static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, |
| Operation *op, Value predicate) { |
| // Some operations may be fine to execute "speculatively" more times than the |
| // original number of iterations, in particular side-effect free operations |
| // and barriers, even if they cannot be predicated. |
| if (isMemoryEffectFree(op) || |
| isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp, |
| nvgpu::DeviceAsyncWaitOp>(op)) { |
| return op; |
| } |
| |
| // Otherwise, only async copies can currently be predicated. |
| auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op); |
| if (!asyncCopyOp) |
| return nullptr; |
| |
| // Create srcElement Value based on `predicate`. The next lines generate |
| // the following code: |
| // |
| // srcElement = (pred) ? prevSrcElements : 0; |
| // |
| Location loc = asyncCopyOp->getLoc(); |
| Value dstElements = |
| rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr()); |
| Value originalSrcElement = |
| asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements; |
| Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| auto srcElements = rewriter.create<arith::SelectOp>( |
| loc, predicate, originalSrcElement, c0Index); |
| auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>( |
| loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()), |
| asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(), |
| asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements, |
| UnitAttr()); |
| rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp); |
| return asyncCopyZeroFillOp; |
| } |
| |
| /// Applies loop pipelining with the given depth to the given loop so that |
| /// copies into the shared memory are pipelined. Doesn't affect other loops. |
| /// Returns a tuple containing the error state and the pipelined op, the latter |
| /// being null in case of any failure. The error state contains a definite error |
| /// if the IR has been modified and a silenceable error otherwise. |
| static std::tuple<DiagnosedSilenceableFailure, scf::ForOp> |
| pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth, |
| bool epiloguePeeling) { |
| llvm::SmallPtrSet<Operation *, 16> stage0Ops; |
| if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) { |
| return std::make_tuple( |
| emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"), |
| scf::ForOp()); |
| } |
| if (stage0Ops.empty()) { |
| return std::make_tuple( |
| emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp()); |
| } |
| |
| scf::PipeliningOption options; |
| unsigned maxDepth = depth; |
| auto setAnnotation = [&](Operation *op, |
| scf::PipeliningOption::PipelinerPart part, |
| unsigned iteration) { |
| return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth); |
| }; |
| options.getScheduleFn = |
| [&](scf::ForOp schedulingFor, |
| std::vector<std::pair<Operation *, unsigned>> &ops) { |
| if (schedulingFor != forOp) |
| return; |
| return getPipelineStages(forOp, ops, maxDepth, stage0Ops); |
| }; |
| options.annotateFn = setAnnotation; |
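| // Without epilogue peeling, the extra iterations executed by the kernel must |
| // be predicated instead; see replaceOpWithPredicatedOp. |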
| if (!epiloguePeeling) { |
| options.peelEpilogue = false; |
| options.predicateFn = replaceOpWithPredicatedOp; |
| } |
| |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPoint(forOp); |
| bool modifiedIR; |
| FailureOr<scf::ForOp> maybePipelined = |
| pipelineForLoop(rewriter, forOp, options, &modifiedIR); |
| if (succeeded(maybePipelined)) { |
| return std::make_tuple(DiagnosedSilenceableFailure::success(), |
| *maybePipelined); |
| } |
| return std::make_tuple( |
| modifiedIR |
| ? DiagnosedSilenceableFailure::definiteFailure() |
| : emitSilenceableFailure(forOp, "pipelining preconditions failed"), |
| scf::ForOp()); |
| } |
| |
| DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne( |
| TransformRewriter &rewriter, scf::ForOp forOp, |
| ApplyToEachResultList &results, TransformState &state) { |
| auto [diag, pipelined] = pipelineForSharedCopies( |
| rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue()); |
| if (diag.succeeded()) { |
| results.push_back(pipelined); |
| return DiagnosedSilenceableFailure::success(); |
| } |
| if (diag.isDefiniteFailure()) { |
| auto diag = emitDefiniteFailure("irreversible pipelining failure"); |
| if (!getPeelEpilogue()) { |
| diag.attachNote(forOp->getLoc()) << "couldn't predicate?"; |
| diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName(); |
| } |
| return diag; |
| } |
| |
| return std::move(diag); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RewriteMatmulAsMmaSyncOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Helper struct to encode a pair of row/column indexings in the form of |
| /// affine expressions. |
| struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> { |
| RowColIndexing(AffineExpr row, AffineExpr col) |
| : std::pair<AffineExpr, AffineExpr>(row, col) {} |
| |
| AffineExpr row() const { return first; }; |
| AffineExpr col() const { return second; }; |
| |
| void print(llvm::raw_ostream &os) const { |
| os << "- indexing: " << first << ", " << second; |
| } |
| }; |
| |
| /// Helper struct to provide a simple mapping from matmul operations to the |
| /// corresponding mma.sync operation. This is constrained to the case where the |
| /// matmul matches the mma.sync operation 1-1. |
| struct MmaSyncBuilder { |
| MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId) |
| : b(b), loc(loc), laneId(laneId) {} |
| |
| using IndexCalculator = |
| std::function<SmallVector<RowColIndexing>(MLIRContext *)>; |
| |
| /// Create the mma.sync operation corresponding to `linalgOp` along with all |
| /// the supporting load/store and vector operations. |
| FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp); |
| |
| private: |
| struct MmaSyncInfo { |
| std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns; |
| std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>> |
| vectorShapes; |
| SmallVector<int64_t> mmaShape; |
| bool tf32Enabled; |
| }; |
| |
| /// Return the specific index calculator for the given operation shape and |
| /// element types, or failure if the combination is not supported. This is the |
| /// toplevel switch that should just be Tablegen'd in the future. |
| FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape, |
| TypeRange elementalTypes); |
| |
| //===--------------------------------------------------------------------===// |
| // Instruction-specific row, column indexing expression builders. |
| // These should all be declaratively specified via Tablegen in the future. |
| // The Tablegen specification should be as straightforward as possible to |
| // only model the existing size and type combinations. |
| //===--------------------------------------------------------------------===// |
| // |
| // TODO: Tablegen all this. |
| //===--------------------------------------------------------------------===// |
| // m16n8k4 tf32 case. |
| //===--------------------------------------------------------------------===// |
| /// From the NVIDIA doc: |
| /// groupID = %laneid >> 2 |
| /// threadIDInGroup = %laneid % 4 |
| /// row = groupID for a0 |
| /// groupID + 8 for a1 |
| /// col = threadIDInGroup |
| static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) { |
| auto dim = getAffineDimExpr(0, ctx); |
| AffineExpr groupID = dim.floorDiv(4); |
| AffineExpr threadIDInGroup = dim % 4; |
| return {RowColIndexing{groupID, threadIDInGroup}, |
| RowColIndexing{groupID + 8, threadIDInGroup}}; |
| } |
| |
| /// From the NVIDIA doc: |
| /// groupID = %laneid >> 2 |
| /// threadIDInGroup = %laneid % 4 |
| /// row = threadIDInGroup |
| /// col = groupID |
| static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) { |
| auto dim = getAffineDimExpr(0, ctx); |
| AffineExpr groupID = dim.floorDiv(4); |
| AffineExpr threadIDInGroup = dim % 4; |
| return {RowColIndexing{threadIDInGroup, groupID}}; |
| } |
| |
| /// From the NVIDIA doc: |
| /// groupID = %laneid >> 2 |
| /// threadIDInGroup = %laneid % 4 |
| /// row = groupID for c0 and c1 |
| /// groupID + 8 for c2 and c3 |
| /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3} |
| static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) { |
| auto dim = getAffineDimExpr(0, ctx); |
| AffineExpr groupID = dim.floorDiv(4); |
| AffineExpr threadIDInGroup = dim % 4; |
| return {RowColIndexing{groupID, threadIDInGroup * 2 + 0}, |
| RowColIndexing{groupID, threadIDInGroup * 2 + 1}, |
| RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, |
| RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}}; |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // m16n8k16 f16 case. |
| //===--------------------------------------------------------------------===// |
| /// From the NVIDIA doc: |
| /// groupID = %laneid >> 2 |
| /// threadIDInGroup = %laneid % 4 |
| /// |
| /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6 |
| /// groupID + 8 Otherwise |
| /// |
| /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4 |
| /// (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4 |
| static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) { |
| auto dim = getAffineDimExpr(0, ctx); |
| AffineExpr groupID = dim.floorDiv(4); |
| AffineExpr threadIDInGroup = dim % 4; |
| // clang-format off |
| return { |
| RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0 |
| RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1 |
| RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2 |
| RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}, // i == 3 |
| RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8}, // i == 4 |
| RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8}, // i == 5 |
| RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6 |
| RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8} // i == 7 |
| }; |
| // clang-format on |
| } |
| |
| /// From the NVIDIA doc: |
| /// groupID = %laneid >> 2 |
| /// threadIDInGroup = %laneid % 4 |
| /// |
| /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2 |
| /// (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2 |
| /// |
| /// col = groupID |
| static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) { |
| auto dim = getAffineDimExpr(0, ctx); |
| AffineExpr groupID = dim.floorDiv(4); |
| AffineExpr threadIDInGroup = dim % 4; |
| // clang-format off |
| return { |
| RowColIndexing{threadIDInGroup * 2 + 0, groupID}, // i == 0 |
| RowColIndexing{threadIDInGroup * 2 + 1, groupID}, // i == 1 |
| RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2 |
| RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID} // i == 3 |
| }; |
| // clang-format on |
| } |
| |
| /// From the NVIDIA doc: |
| /// groupID = %laneid >> 2 |
| /// threadIDInGroup = %laneid % 4 |
| /// |
| /// row = groupID for ci where i < 2 |
| /// groupID + 8 for ci where i >= 2 |
| /// |
| /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3} |
| static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) { |
| auto dim = getAffineDimExpr(0, ctx); |
| AffineExpr groupID = dim.floorDiv(4); |
| AffineExpr threadIDInGroup = dim % 4; |
| // clang-format off |
| return { |
| RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0 |
| RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1 |
| RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2 |
| RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1} // i == 3 |
| }; |
| // clang-format on |
| } |
| |
| //===--------------------------------------------------------------------===// |
| /// Helper functions to create customizable load and store operations. The |
| /// specific shapes of each MMA instruction are passed via the |
| /// IndexCalculator callback. |
| //===--------------------------------------------------------------------===// |
| /// Build a list of memref.load operations indexed at `(row, col)` indices |
| /// that make sense for a particular MMA instruction and specified via the |
| /// IndexCalculator callback. |
| SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc, |
| OpFoldResult laneId, Value memref, |
| IndexCalculator indexFn); |
| |
| /// Perform a distributed load of a vector operand of `vectorShape` for a |
| /// particular MMA instruction whose `(row, col)` indices are specified via |
| /// the IndexCalculator callback. Each `laneId` loads the subportion of the |
| /// data that makes sense for the particular MMA operation. |
| /// The `vectorShape` matches existing NVGPU dialect op specification but |
| /// could also be flattened in the future if needed for simplification. |
| Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc, |
| OpFoldResult laneId, Value memref, |
| IndexCalculator indexFn, |
| ArrayRef<int64_t> vectorShape); |
| |
| /// Build a list of memref.store operations indexed at `(row, col)` indices |
| /// that make sense for a particular MMA instruction and specified via the |
| /// IndexCalculator callback. |
| SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc, |
| ValueRange toStore, |
| OpFoldResult laneId, Value memref, |
| IndexCalculator indexFn); |
| |
| /// Perform a distributed store of a vector operand of `vectorShape` for a |
| /// particular MMA instruction whose `(row, col)` indices are specified via |
| /// the IndexCalculator callback. Each `laneId` stores the subportion of the |
| /// data that makes sense for the particular MMA operation. |
| /// The `vectorShape` matches existing NVGPU dialect op specification but |
| /// could also be flattened in the future if needed for simplification. |
| SmallVector<Operation *> buildMmaSyncMemRefStoreOperand( |
| OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId, |
| Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape); |
| |
| OpBuilder &b; |
| Location loc; |
| OpFoldResult laneId; |
| }; |
| |
| //===--------------------------------------------------------------------===// |
| /// Helper functions to create customizable load and store operations. The |
| /// specific shapes of each MMA instruction are passed via the |
| /// IndexCalculator callback. |
| //===--------------------------------------------------------------------===// |
| |
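| /// Walk every element of `vector` in linearized order. For each element, |
| /// `applyFn` maps the (linear, delinearized) indices to a value and `reduceFn` |
| /// consumes that value together with the same indices. |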
| template <typename ApplyFn, typename ReduceFn> |
| static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn, |
| ReduceFn reduceFn) { |
| VectorType vectorType = cast<VectorType>(vector.getType()); |
| auto vectorShape = vectorType.getShape(); |
| auto strides = computeStrides(vectorShape); |
| for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) { |
| auto indices = delinearize(idx, strides); |
| reduceFn(applyFn(vector, idx, indices), idx, indices); |
| } |
| } |
| |
| SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc, |
| OpFoldResult laneId, |
| Value memref, |
| IndexCalculator indexFn) { |
| auto aff = [&](AffineExpr e) { |
| return affine::makeComposedFoldedAffineApply(b, loc, e, laneId); |
| }; |
| SmallVector<Value> res; |
| SmallVector<RowColIndexing> indexings = indexFn(b.getContext()); |
| for (auto indexing : indexings) { |
| Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row())); |
| Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col())); |
| auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col}); |
| res.push_back(load); |
| } |
| return res; |
| } |
| |
| Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( |
| OpBuilder &b, Location loc, OpFoldResult laneId, Value memref, |
| IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) { |
| auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn); |
| |
| Type elementType = getElementTypeOrSelf(memref.getType()); |
| auto vt = VectorType::get(vectorShape, elementType); |
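| // Splat the first loaded scalar to materialize the vector, then overwrite |
| // every position with its corresponding load. |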
| Value res = b.create<vector::SplatOp>(loc, vt, loads[0]); |
| foreachIndividualVectorElement( |
| res, |
| /*applyFn=*/ |
| [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { |
| return loads[linearIdx]; |
| }, |
| /*reduceFn=*/ |
| [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { |
| res = b.create<vector::InsertOp>(loc, v, res, indices); |
| }); |
| |
| return res; |
| } |
| |
| SmallVector<Operation *> |
| MmaSyncBuilder::buildMemRefStores(OpBuilder &b, Location loc, |
| ValueRange toStore, OpFoldResult laneId, |
| Value memref, IndexCalculator indexFn) { |
| auto aff = [&](AffineExpr e) { |
| return affine::makeComposedFoldedAffineApply(b, loc, e, laneId); |
| }; |
| SmallVector<Operation *> res; |
| for (auto [indexing, val] : |
| llvm::zip_equal(indexFn(b.getContext()), toStore)) { |
| Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row())); |
| Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col())); |
| Operation *store = |
| b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col}); |
| res.push_back(store); |
| } |
| return res; |
| } |
| |
| SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand( |
| OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId, |
| Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) { |
| SmallVector<Value> toStore; |
| toStore.reserve(32); |
| foreachIndividualVectorElement( |
| vectorToStore, |
| /*applyFn=*/ |
| [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { |
| return b.create<vector::ExtractOp>(loc, vectorToStore, indices); |
| }, |
| /*reduceFn=*/ |
| [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { |
| toStore.push_back(v); |
| }); |
| return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn); |
| } |
| |
| static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, |
| SmallVector<int64_t>> |
| makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, |
| ArrayRef<int64_t> res) { |
| SmallVector<int64_t> vlhs{lhs.begin(), lhs.end()}; |
| SmallVector<int64_t> vrhs{rhs.begin(), rhs.end()}; |
| SmallVector<int64_t> vres{res.begin(), res.end()}; |
| return std::make_tuple(vlhs, vrhs, vres); |
| } |
| |
| FailureOr<MmaSyncBuilder::MmaSyncInfo> |
| MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape, |
| TypeRange elementalTypes) { |
| // TODO: Tablegen all this. |
| Type f16 = b.getF16Type(); |
| Type f32 = b.getF32Type(); |
| if (opShape == ArrayRef<int64_t>{16, 8, 4} && |
| elementalTypes == TypeRange{f32, f32, f32}) { |
| return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs, |
| &MmaSyncBuilder::m16n8k4tf32Rhs, |
| &MmaSyncBuilder::m16n8k4tf32Res), |
| makeVectorShapes({2, 1}, {1, 1}, {2, 2}), |
| SmallVector<int64_t>{opShape.begin(), opShape.end()}, |
| /*tf32Enabled=*/true}; |
| } |
| // This is the version with f16 accumulation. |
| // TODO: version with f32 accumulation. |
| if (opShape == ArrayRef<int64_t>{16, 8, 16} && |
| elementalTypes == TypeRange{f16, f16, f16}) { |
| return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs, |
| &MmaSyncBuilder::m16n8k16f16Rhs, |
| &MmaSyncBuilder::m16n8k16f16Res), |
| makeVectorShapes({4, 2}, {2, 2}, {2, 2}), |
| SmallVector<int64_t>{opShape.begin(), opShape.end()}, |
| /*tf32Enabled=*/false}; |
| } |
| return failure(); |
| } |
| |
| FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { |
| Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get(); |
| Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get(); |
| Value resMemRef = linalgOp.getDpsInitOperand(0)->get(); |
| assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 && |
| "expected lhs to be a 2D memref"); |
| assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 && |
| "expected rhs to be a 2D memref"); |
| assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 && |
| "expected res to be a 2D memref"); |
| |
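| // Canonical row-major matmul: lhs is MxK, rhs is KxN, res is MxN. |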
| int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0]; |
| int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1]; |
| int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1]; |
| Type lhsType = getElementTypeOrSelf(lhsMemRef.getType()); |
| Type rhsType = getElementTypeOrSelf(rhsMemRef.getType()); |
| Type resType = getElementTypeOrSelf(resMemRef.getType()); |
| |
| FailureOr<MmaSyncInfo> maybeInfo = |
| getIndexCalculators({m, n, k}, {lhsType, rhsType, resType}); |
| if (failed(maybeInfo)) |
| return failure(); |
| |
| MmaSyncInfo info = *maybeInfo; |
| auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns; |
| auto [lhsShape, rhsShape, resShape] = info.vectorShapes; |
| Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef, |
| lhsIndexFn, lhsShape); |
| Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef, |
| rhsIndexFn, rhsShape); |
| Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef, |
| resIndexFn, resShape); |
| res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape, |
| info.tf32Enabled); |
| buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn, |
| resShape); |
| return res.getDefiningOp(); |
| } |
| |
| DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne( |
| transform::TransformRewriter &rewriter, LinalgOp linalgOp, |
| transform::ApplyToEachResultList &results, |
| transform::TransformState &state) { |
| bool fail = true; |
| // TODO: more robust detection of matmulOp, with transposes etc. |
| if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) { |
| Location loc = linalgOp.getLoc(); |
| // TODO: more robust computation of laneId, for now assume a single warp. |
| Value laneId = rewriter.create<gpu::ThreadIdOp>( |
| loc, rewriter.getIndexType(), gpu::Dimension::x); |
| if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp))) |
| fail = false; |
| } |
| |
| if (fail) { |
| DiagnosedSilenceableFailure diag = emitSilenceableError() |
| << "unsupported target op: " << linalgOp; |
| diag.attachNote(linalgOp->getLoc()) << "target op"; |
| return diag; |
| } |
| |
| rewriter.eraseOp(linalgOp); |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Hopper builders. |
| //===----------------------------------------------------------------------===// |
| |
| /// Helper to create the base Hopper-specific operations that are reused in |
| /// various other places. |
| struct HopperBuilder { |
| HopperBuilder(RewriterBase &rewriter, Location loc) |
| : rewriter(rewriter), loc(loc) {} |
| |
| TypedValue<nvgpu::MBarrierGroupType> |
| buildAndInitBarrierInSharedMemory(OpFoldResult numThreads); |
| |
| /// Create tma descriptor op to initiate transfer from global to shared |
| /// memory. This must be done before the launch op, on the host. |
| TypedValue<nvgpu::TensorMapDescriptorType> |
| buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref, |
| gpu::LaunchOp launchOp); |
| |
| /// Build a tma load from global memory to shared memory using `barrier` to |
| /// synchronize. Return the number of bytes that will be transferred. |
| OpFoldResult |
| buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc, |
| TypedValue<MemRefType> sharedMemref, |
| TypedValue<nvgpu::MBarrierGroupType> barrier, |
| SmallVectorImpl<Operation *> &loadOps); |
| void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier, |
| ArrayRef<OpFoldResult> sizes); |
| |
| /// If threadIdx.x == 0, issues the TMA request and waits; otherwise just waits. |
| /// Return the operation that performs the transfer on thread0. |
| // TODO: In the future, don't hardcode to thread 0 but elect a leader. |
| SmallVector<Operation *> buildPredicateLoadsOnThread0( |
| ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors, |
| ArrayRef<TypedValue<MemRefType>> sharedMemBuffers, |
| TypedValue<nvgpu::MBarrierGroupType> barrier); |
| |
| void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier); |
| |
| RewriterBase &rewriter; |
| Location loc; |
| }; |
| |
| SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0( |
| ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors, |
| ArrayRef<TypedValue<MemRefType>> sharedMemBuffers, |
| TypedValue<nvgpu::MBarrierGroupType> barrier) { |
| SmallVector<Operation *> loadOps; |
| Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x); |
| Value cond = |
| rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero); |
| // clang-format off |
| rewriter.create<scf::IfOp>( |
| /*location=*/loc, |
| /*conditional=*/cond, |
| /*thenBuilder=*/ |
| [&](OpBuilder &lb, Location loc) { |
| SmallVector<OpFoldResult> sizes; |
| sizes.reserve(globalDescriptors.size()); |
| for (auto [desc, shmem] : llvm::zip_equal( |
| globalDescriptors, sharedMemBuffers)) { |
| OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps); |
| sizes.push_back(sz); |
| } |
| // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load. |
| // This may or may not have perf implications. |
| buildBarrierArriveTx(barrier, sizes); |
| rewriter.create<scf::YieldOp>(loc); |
| }, |
| /*elseBuilder=*/ |
| [&](OpBuilder &lb, Location loc) { |
| // TODO: is this for no-thread divergence? |
| // Should we just yield the size and hoist? |
| buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0)); |
| rewriter.create<scf::YieldOp>(loc); |
| }); |
| // clang-format on |
| return loadOps; |
| } |
| |
| static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) { |
| return gpu::AddressSpaceAttr::get( |
| b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace()); |
| } |
| |
| TypedValue<nvgpu::MBarrierGroupType> |
| HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) { |
| auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); |
| Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>( |
| loc, |
| nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace)); |
| Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| rewriter.create<nvgpu::MBarrierInitOp>( |
| loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), |
| zero, Value()); |
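| // Block-level synchronization so every thread observes the initialized |
| // mbarrier before it is used. |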
| rewriter.create<gpu::BarrierOp>(loc); |
| return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier); |
| } |
| |
| TypedValue<nvgpu::TensorMapDescriptorType> |
| HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref, |
| gpu::LaunchOp launchOp) { |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPoint(launchOp); |
| Value unrankedMemRef = rewriter.create<memref::CastOp>( |
| loc, |
| UnrankedMemRefType::get(memref.getType().getElementType(), |
| memref.getType().getMemorySpace()), |
| memref); |
| SmallVector<OpFoldResult> mixedSizes = |
| memref::getMixedSizes(rewriter, loc, memref); |
| SmallVector<Value> sizes = |
| getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes); |
| |
| auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); |
| Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>( |
| loc, |
| nvgpu::TensorMapDescriptorType::get( |
| rewriter.getContext(), |
| MemRefType::Builder(memref.getType()) |
| .setMemorySpace(sharedMemorySpace), |
| TensorMapSwizzleKind::SWIZZLE_NONE, |
| TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO, |
| TensorMapInterleaveKind::INTERLEAVE_NONE), |
| unrankedMemRef, sizes); |
| return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc); |
| } |
| |
| OpFoldResult HopperBuilder::buildTmaAsyncLoad( |
| TypedValue<nvgpu::TensorMapDescriptorType> globalDesc, |
| TypedValue<MemRefType> sharedMemref, |
| TypedValue<nvgpu::MBarrierGroupType> barrier, |
| SmallVectorImpl<Operation *> &loadOps) { |
| MLIRContext *ctx = rewriter.getContext(); |
| Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>( |
| loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero, |
| Value()); |
| loadOps.push_back(loadOp); |
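| // The number of bytes transferred is the product of the shared memref sizes |
| // multiplied by the element byte width. |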
| auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref); |
| SmallVector<AffineExpr> symbols(mixedSizes.size()); |
| bindSymbolsList(ctx, llvm::MutableArrayRef{symbols}); |
| AffineExpr prodExprInBytes = |
| computeProduct(ctx, symbols) * |
| (sharedMemref.getType().getElementTypeBitWidth() / 8); |
| auto res = affine::makeComposedFoldedAffineApply(rewriter, loc, |
| prodExprInBytes, mixedSizes); |
| return res; |
| } |
| |
| void HopperBuilder::buildBarrierArriveTx( |
| TypedValue<nvgpu::MBarrierGroupType> barrier, |
| ArrayRef<OpFoldResult> mixedSizes) { |
| assert(!mixedSizes.empty() && "expected non-empty sizes"); |
| MLIRContext *ctx = rewriter.getContext(); |
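| // The expected transaction count is the sum of the per-load byte counts. |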
| SmallVector<AffineExpr> symbols(mixedSizes.size()); |
| bindSymbolsList(ctx, llvm::MutableArrayRef{symbols}); |
| AffineExpr sumExpr = computeSum(ctx, symbols); |
| OpFoldResult size = |
| affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes); |
| Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size); |
| Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero, |
| Value()); |
| } |
| |
| void HopperBuilder::buildTryWaitParity( |
| TypedValue<nvgpu::MBarrierGroupType> barrier) { |
| Value parity = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| // 10M is an arbitrary number of ticks to wait before retrying; it is neither |
| // too small nor too big. |
| // TODO: hoist this in a default dialect constant. |
| Value ticksBeforeRetry = |
| rewriter.create<arith::ConstantIndexOp>(loc, 10000000); |
| Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity, |
| ticksBeforeRetry, zero); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RewriteCopyAsTmaOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Helper to create the tma operations corresponding to `linalg::CopyOp`. |
| struct CopyBuilder : public HopperBuilder { |
| CopyBuilder(RewriterBase &rewriter, Location loc) |
| : HopperBuilder(rewriter, loc) {} |
| |
| SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps); |
| }; |
| |
| SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) { |
| MLIRContext *ctx = rewriter.getContext(); |
| if (copyOps.empty()) |
| return SmallVector<Operation *>(); |
| |
| auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>(); |
| assert(launchOp && "expected launch op"); |
| |
| // 1. Init a barrier object in shared memory. |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPoint(copyOps.front()); |
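| // The mbarrier expects one arrival per thread in the block: |
| // numThreads = blockDim.x * blockDim.y * blockDim.z. |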
| AffineExpr bx, by, bz; |
| bindSymbols(ctx, bx, by, bz); |
| AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz}); |
| OpFoldResult numThreads = affine::makeComposedFoldedAffineApply( |
| rewriter, loc, prod, |
| ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(), |
| launchOp.getBlockSizeZ()}); |
| |
| TypedValue<nvgpu::MBarrierGroupType> barrier = |
| buildAndInitBarrierInSharedMemory(numThreads); |
| |
| SmallVector<TypedValue<MemRefType>> shmems; |
| SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs; |
| for (Operation *op : copyOps) { |
| auto copyOp = cast<linalg::CopyOp>(op); |
| auto inMemRef = |
| cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get()); |
| assert(inMemRef.getType().getRank() == 2 && |
| "expected the copy input to be a 2D memref"); |
| |
| // 2. Build global memory descriptor. |
| TypedValue<nvgpu::TensorMapDescriptorType> globalDesc = |
| buildGlobalMemRefDescriptor(inMemRef, launchOp); |
| globalDescs.push_back(globalDesc); |
| |
| // 3. Shared memory and descriptor for the tmp array. |
| auto shmem = |
| cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get()); |
| shmems.push_back(shmem); |
| } |
| |
| // 4. Load in from global memory to shared memory using tma. |
| OpBuilder::InsertionGuard g2(rewriter); |
| rewriter.setInsertionPoint(copyOps.front()); |
| SmallVector<Operation *> results = |
| buildPredicateLoadsOnThread0(globalDescs, shmems, barrier); |
| |
| // 5. Spin-loop until data is ready. |
| buildTryWaitParity(barrier); |
| |
| // 6. Erase the ops that have now been rewritten. |
| for (Operation *op : copyOps) |
| rewriter.eraseOp(op); |
| |
| return results; |
| } |
| |
| DiagnosedSilenceableFailure |
| transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter, |
| transform::TransformResults &results, |
| transform::TransformState &state) { |
| auto payloadOps = state.getPayloadOps(getTarget()); |
| gpu::LaunchOp commonLaunchOp; |
| Operation *firstOp = nullptr, *failingOp = nullptr; |
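| // Reject the payload unless every op is a linalg::CopyOp nested under one |
| // common gpu.launch: the TMA descriptors must be created on the host. |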
| if (llvm::any_of(payloadOps, [&](Operation *op) { |
| if (!commonLaunchOp) { |
| commonLaunchOp = op->getParentOfType<gpu::LaunchOp>(); |
| firstOp = op; |
| } |
| auto fail = !op->getParentOfType<gpu::LaunchOp>() || |
| commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() || |
| !isa<linalg::CopyOp>(op); |
| if (fail) |
| failingOp = op; |
| return fail; |
| })) { |
| DiagnosedSilenceableFailure diag = |
| emitSilenceableError() |
| << "target ops must be linalg::CopyOp nested under a common " |
| "gpu.LaunchOp to be rewritten because the tma descriptors need to " |
| "be created on the host.\nBut got: " |
| << *firstOp << "\nand " << *failingOp; |
| return diag; |
| } |
| |
| // TODO: more robust detection of copy, with transposes etc. |
| CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps)); |
| |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Transform op registration |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| class NVGPUTransformDialectExtension |
| : public transform::TransformDialectExtension< |
| NVGPUTransformDialectExtension> { |
| public: |
| NVGPUTransformDialectExtension() { |
| declareGeneratedDialect<arith::ArithDialect>(); |
| declareGeneratedDialect<affine::AffineDialect>(); |
| declareGeneratedDialect<nvgpu::NVGPUDialect>(); |
| declareGeneratedDialect<NVVM::NVVMDialect>(); |
| declareGeneratedDialect<vector::VectorDialect>(); |
| registerTransformOps< |
| #define GET_OP_LIST |
| #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc" |
| >(); |
| } |
| }; |
| } // namespace |
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc" |
| |
| void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry ®istry) { |
| registry.addExtensions<NVGPUTransformDialectExtension>(); |
| } |