blob: 4e256aea0be37aa38a15af7604152b44be3b41b1 [file] [log] [blame]
//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::nvgpu;
using namespace mlir::NVVM;
using namespace mlir::transform;
#define DEBUG_TYPE "nvgpu-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//
void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  // The cast is only valid for LLVMTypeConverter instances; verifyTypeConverter
  // below rejects any other converter kind (presumably checked before this
  // hook runs).
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);

  // GPU address-space attributes -> NVVM numeric address spaces.
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });

  // Device-side async tokens cannot be materialized in NVVM. They are mapped
  // to a dummy i32 so they can be dropped easily during conversion.
  llvmTypeConverter.addConversion(
      [&](nvgpu::DeviceAsyncTokenType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 32));
      });

  // mbarrier tokens are represented as i64.
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });

  // A warpgroup accumulator becomes a literal struct holding one inner struct
  // per kWgmmaSizeM rows of the fragmented vector. Each inner struct has
  // sizeN/2 members (f32/i32) or sizeN/4 members (f16) of the element type.
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
        auto fragmented = type.getFragmented();
        Type elemType = fragmented.getElementType();
        int64_t sizeM = fragmented.getDimSize(0);
        int64_t sizeN = fragmented.getDimSize(1);

        unsigned numMembers;
        if (elemType.isF32() || elemType.isInteger(32))
          numMembers = sizeN / 2;
        else if (elemType.isF16())
          numMembers = sizeN / 4;
        else
          llvm_unreachable("unsupported type for warpgroup accumulator");

        MLIRContext *ctx = type.getContext();
        auto innerStructType = LLVM::LLVMStructType::getLiteral(
            ctx, SmallVector<Type>(numMembers, elemType));

        SmallVector<Type> outerBody;
        for (int row = 0; row < sizeM; row += kWgmmaSizeM)
          outerBody.push_back(innerStructType);

        return llvmTypeConverter.convertType(
            LLVM::LLVMStructType::getLiteral(ctx, outerBody));
      });

  // mbarrier groups are lowered through their memref representation.
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        getMBarrierMemrefType(type.getContext(), type));
  });

  // Warpgroup matrix descriptors are represented as i64.
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });

  // TMA tensor-map descriptors become opaque LLVM pointers.
  llvmTypeConverter.addConversion(
      [&](nvgpu::TensorMapDescriptorType type) -> Type {
        return LLVM::LLVMPointerType::get(type.getContext());
      });

  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
/// Accepts only an LLVMTypeConverter. populatePatterns static_casts the
/// converter it receives to that type.
LogicalResult
transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  const bool isLLVMConverter =
      builder.getTypeConverterType() == "LLVMTypeConverter";
  if (isLLVMConverter)
    return success();
  return emitOpError("expected LLVMTypeConverter");
}
//===---------------------------------------------------------------------===//
// CreateAsyncGroupsOp
//===---------------------------------------------------------------------===//
/// Declares the memory effects of the op. The target handle is consumed, a
/// result handle is produced, and the payload IR is modified.
void transform::CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Handle effects first: the input handle is invalidated and a new one is
  // produced.
  transform::consumesHandle(getTarget(), effects);
  transform::producesHandle(getResult(), effects);
  // Payload effect: async groups are created by rewriting the IR.
  transform::modifiesPayload(effects);
}
/// Creates async copy groups under `target`, honoring the op's bypassL1 flag.
/// The target op itself is returned as the single result.
DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
    TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
  // The rewrite happens underneath `target`, so the same op is forwarded as
  // the result handle.
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// PipelineSharedMemoryCopiesOp
//===----------------------------------------------------------------------===//
/// Returns true if `type` lives in the default memory space. That is the case
/// when it has no memory-space attribute, or when the attribute's integer
/// value is 0.
static bool hasDefaultMemorySpace(BaseMemRefType type) {
  if (!type.getMemorySpace())
    return true;
  return type.getMemorySpaceAsInt() == 0;
}
/// Returns true if `type` carries a gpu.address_space attribute equal to the
/// workgroup (shared) address space. A missing attribute or a non-GPU
/// memory-space attribute yields false.
static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto gpuSpace =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  if (!gpuSpace)
    return false;
  return gpuSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}
/// If `op` is a vector.transfer_read from a memref in the default memory
/// space, returns the value it produces. Otherwise returns a null Value.
static Value getValueLoadedFromGlobal(Operation *op) {
  // TODO: consider an interface or leveraging the memory effects interface.
  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
    auto sourceType = dyn_cast<MemRefType>(readOp.getSource().getType());
    if (sourceType && hasDefaultMemorySpace(sourceType))
      return readOp;
  }
  return nullptr;
}
/// Returns true if `op` is a vector.transfer_write that stores exactly `v`
/// into a memref in the shared (workgroup) memory space.
static bool isStoreToShared(Operation *op, Value v) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;
  // Both conditions must hold. The destination must be a memref (not, e.g., a
  // tensor), and that memref must live in shared memory. This also guarantees
  // that hasSharedMemorySpace is never called with a null type.
  auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
  return storeType && hasSharedMemorySpace(storeType);
}
/// Returns true if `op` loads a value from the default memory space, and that
/// value has exactly one use, which is a store into shared memory.
static bool isLoadFromGlobalStoredToShared(Operation *op) {
  Value loaded = getValueLoadedFromGlobal(op);
  if (loaded && loaded.hasOneUse()) {
    Operation *soleUser = *loaded.getUsers().begin();
    return isStoreToShared(soleUser, loaded);
  }
  return false;
}
/// Populate `ops` with the set of operations that belong to the stage 0 of the
/// pipelined version of the given loop when pipelining copies to shared memory.
/// Specifically, this collects:
///
///   1. all loads from global memory, both sync and async;
///   2. the barriers for async loads.
///
/// In particular, barriers are omitted if they do not dominate at least one
/// async load for which there is not yet a barrier.
///
/// Returns failure if the loop body contains an op with regions, because
/// nested ops are not supported yet.
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
                           llvm::SmallPtrSet<Operation *, 16> &ops) {
  // Barriers seen since the last async copy / group creation.
  llvm::SmallPtrSet<Operation *, 4> barriers;
  for (Operation &op : *forOp.getBody()) {
    // Bail on nested ops for now.
    if (op.getNumRegions() > 0)
      return failure();

    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      // Attach the pending barriers to stage 0, then forget them. Only
      // barriers preceding a not-yet-covered async op are collected. Moving
      // from SmallPtrSet iterators would not empty the set, so clear it
      // explicitly.
      ops.insert(barriers.begin(), barriers.end());
      barriers.clear();
      continue;
    }

    if (isLoadFromGlobalStoredToShared(&op)) {
      ops.insert(&op);
      continue;
    }
  }
  return success();
}
/// Hook for the loop pipeliner that sets the "num groups in flight" attribute
/// of async wait operations corresponding to pipelined shared memory copies.
/// Wait ops that already carry a group count are left untouched.
///
/// - Prologue and kernel: `depth - 1` groups are in flight.
/// - Epilogue: `depth - 1 - iteration` groups are in flight, because each
///   drained epilogue iteration leaves one fewer group outstanding.
// TODO: this currently assumes that there are no groups that could be in flight
// in the existing code.
static void
setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration, unsigned depth) {
  using PipelinerPart = scf::PipeliningOption::PipelinerPart;

  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupsInFlight = depth - 1;
  if (part != PipelinerPart::Kernel && part != PipelinerPart::Prologue) {
    assert(part == scf::PipeliningOption::PipelinerPart::Epilogue);
    numGroupsInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupsInFlight);
}
/// Hook for the loop pipeliner that fills `opsWithPipelineStages` with the
/// stage assignment:
///
///   - ops in `stage0Ops` and everything in their backward slice (restricted
///     to the loop body) are at stage 0;
///   - every other op except the terminator is at stage `depth`;
///   - stage-`depth` ops are listed first, then stage-0 ops, each group in
///     original program order.
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  Block *loopBody = forOp.getBody();

  // Transitive closure (within the loop body) of the stage-0 seed ops.
  SetVector<Operation *> stage0Closure;
  BackwardSliceOptions sliceOptions([loopBody](Operation *visited) {
    return visited->getBlock() == loopBody;
  });
  sliceOptions.inclusive = true;
  for (Operation &op : loopBody->getOperations()) {
    if (!stage0Ops.contains(&op))
      continue;
    getBackwardSlice(&op, &stage0Closure, sliceOptions);
  }

  // Late stage first (the yield is handled by the pipeliner itself).
  for (Operation &op : loopBody->getOperations()) {
    if (stage0Closure.contains(&op) || isa<scf::YieldOp>(op))
      continue;
    opsWithPipelineStages.emplace_back(&op, depth);
  }

  // Then stage 0.
  for (Operation &op : loopBody->getOperations()) {
    if (!stage0Closure.contains(&op))
      continue;
    opsWithPipelineStages.emplace_back(&op, 0);
  }
}
/// Hook for the loop pipeliner. Replaces op with a predicated version and
/// returns the resulting operation. Returns the original op if the predication
/// isn't necessary for the given op. Returns null if predication is needed but
/// not supported.
///
/// Only nvgpu.device_async_copy is predicated. Its source element count
/// becomes `predicate ? srcElements : 0`; srcElements defaults to dstElements
/// when absent. All other properties of the copy, including the bypassL1
/// hint, are preserved.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  // Some operations may be fine to execute "speculatively" more times than the
  // original number of iterations, in particular side-effect free operations
  // and barriers, even if they cannot be predicated.
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op)) {
    return op;
  }

  // Otherwise, only async copies can currently be predicated.
  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  // Create srcElement Value based on `predicate`. The next lines generate
  // the following code:
  //
  //   srcElement = (pred) ?  prevSrcElements : 0;
  //
  Location loc = asyncCopyOp->getLoc();
  Value originalSrcElement = asyncCopyOp.getSrcElements();
  if (!originalSrcElement) {
    // Materialize dstElements only when it is actually needed as the default.
    originalSrcElement = rewriter.create<arith::ConstantOp>(
        loc, asyncCopyOp.getDstElementsAttr());
  }
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto srcElements = rewriter.create<arith::SelectOp>(
      loc, predicate, originalSrcElement, c0Index);

  auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
      loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      // Keep the original cache hint rather than silently dropping it.
      asyncCopyOp.getBypassL1Attr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}
/// Applies loop pipelining with the given depth to the given loop so that
/// copies into the shared memory are pipelined. Doesn't affect other loops.
/// Returns a tuple containing the error state and the pipelined op, the latter
/// being null in case of any failure. The error state contains a definite error
/// if the IR has been modified and a silenceable error otherwise.
///
/// When `epiloguePeeling` is false, the epilogue is not peeled and ops are
/// predicated via replaceOpWithPredicatedOp instead.
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  // Only the requested loop is scheduled; any other loop gets an empty
  // schedule.
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forOp);
  // Initialized because it is read on the failure path below. It must not be
  // indeterminate if the pipeliner bails out without writing it.
  bool modifiedIR = false;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  return std::make_tuple(
      modifiedIR
          ? DiagnosedSilenceableFailure::definiteFailure()
          : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
      scf::ForOp());
}
/// Pipelines shared-memory copies in `forOp`.
/// - On success, the pipelined loop is returned as the result.
/// - Silenceable failures (payload untouched) are forwarded unchanged.
/// - Failures after the IR was modified become a definite error. When
///   predication was in use (no epilogue peeling), the error carries hints.
DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
    TransformRewriter &rewriter, scf::ForOp forOp,
    ApplyToEachResultList &results, TransformState &state) {
  auto [status, pipelinedLoop] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());

  if (status.succeeded()) {
    results.push_back(pipelinedLoop);
    return DiagnosedSilenceableFailure::success();
  }

  if (!status.isDefiniteFailure())
    return std::move(status);

  auto irreversibleFailure =
      emitDefiniteFailure("irreversible pipelining failure");
  if (!getPeelEpilogue()) {
    irreversibleFailure.attachNote(forOp->getLoc()) << "couldn't predicate?";
    irreversibleFailure.attachNote(getLoc())
        << "try setting " << getPeelEpilogueAttrName();
  }
  return irreversibleFailure;
}
//===----------------------------------------------------------------------===//
// RewriteMatmulAsMmaSyncOp
//===----------------------------------------------------------------------===//
/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
RowColIndexing(AffineExpr row, AffineExpr col)
: std::pair<AffineExpr, AffineExpr>(row, col) {}
AffineExpr row() const { return first; };
AffineExpr col() const { return second; };
void print(llvm::raw_ostream &os) const {
os << "- indexing: " << first << ", " << second;
}
};
/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation. This is constrained to the case where the
/// matmul matches the mma.sync operation 1-1.
///
/// All IR is created through the borrowed builder `b` at location `loc`.
/// `laneId` is the value the per-instruction index calculators are evaluated
/// against (bound to affine dimension 0).
struct MmaSyncBuilder {
MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
: b(b), loc(loc), laneId(laneId) {}
/// Produces the ordered list of (row, col) affine expressions, in terms of
/// dimension 0 (the lane id), that a single lane accesses for one MMA operand.
/// The list order matches the row-major element order of the per-lane vector.
using IndexCalculator =
std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
/// Create the mma.sync operation corresponding to `linalgOp` along with all
/// the supporting load/store and vector operations.
FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
private:
/// Everything needed to emit one supported mma.sync flavor.
struct MmaSyncInfo {
/// Index calculators for the (lhs, rhs, res) operands, in that order.
std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
/// Per-lane vector shapes for the (lhs, rhs, res) operands, in that order.
std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
vectorShapes;
/// The (m, n, k) instruction shape passed to nvgpu.mma.sync.
SmallVector<int64_t> mmaShape;
/// Whether the nvgpu.mma.sync op is built with tf32 enabled.
bool tf32Enabled;
};
/// Return the specific index calculator for the given `linalgOp` or failure
/// if the op is not supported. This is the toplevel switch that should just
/// be Tablegen'd in the future.
FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
TypeRange elementalTypes);
//===--------------------------------------------------------------------===//
// Instruction-specific row, column indexing expression builders.
// These should all be declaratively specified via Tablegen in the future.
// The Tablegen specification should be as straightforward as possible to
// only model the existing size and type combinations.
//===--------------------------------------------------------------------===//
//
// TODO: Tablegen all this.
//===--------------------------------------------------------------------===//
// m16n8k4 tf32 case.
//===--------------------------------------------------------------------===//
/// From the NVIDIA doc:
/// groupID           = %laneid >> 2
/// threadIDInGroup = %laneid % 4
/// row =      groupID            for a0
///            groupID + 8        for a1
/// col =  threadIDInGroup
static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
auto dim = getAffineDimExpr(0, ctx);
AffineExpr groupID = dim.floorDiv(4);
AffineExpr threadIDInGroup = dim % 4;
return {RowColIndexing{groupID, threadIDInGroup},
RowColIndexing{groupID + 8, threadIDInGroup}};
}
/// From the NVIDIA doc:
/// groupID           = %laneid >> 2
/// threadIDInGroup = %laneid % 4
/// row =  threadIDInGroup
/// col =  groupID
static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
auto dim = getAffineDimExpr(0, ctx);
AffineExpr groupID = dim.floorDiv(4);
AffineExpr threadIDInGroup = dim % 4;
return {RowColIndexing{threadIDInGroup, groupID}};
}
/// From the NVIDIA doc:
/// groupID          = %laneid >> 2
/// threadIDInGroup = %laneid % 4
/// row =      groupID                            for c0 and c1
///          groupID + 8                          for c2 and c3
/// col =  (threadIDInGroup * 2) + (i & 0x1)    for ci   where i = {0,..,3}
static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
auto dim = getAffineDimExpr(0, ctx);
AffineExpr groupID = dim.floorDiv(4);
AffineExpr threadIDInGroup = dim % 4;
return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
RowColIndexing{groupID, threadIDInGroup * 2 + 1},
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
}
//===--------------------------------------------------------------------===//
// m16n8k16 f16 case.
//===--------------------------------------------------------------------===//
/// From the NVIDIA doc:
/// groupID           = %laneid >> 2
/// threadIDInGroup = %laneid % 4
///
/// row =      groupID            for ai where  0 <= i < 2 || 4 <= i < 6
///           groupID + 8         Otherwise
///
/// col =  (threadIDInGroup * 2) + (i & 0x1)          for ai where i <  4
///        (threadIDInGroup * 2) + (i & 0x1) + 8      for ai where i >= 4
static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
auto dim = getAffineDimExpr(0, ctx);
AffineExpr groupID = dim.floorDiv(4);
AffineExpr threadIDInGroup = dim % 4;
// clang-format off
return {
RowColIndexing{groupID, threadIDInGroup * 2 + 0},         // i == 0
RowColIndexing{groupID, threadIDInGroup * 2 + 1},         // i == 1
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},     // i == 2
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},     // i == 3
RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},     // i == 4
RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},     // i == 5
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}  // i == 7
};
// clang-format on
}
/// From the NVIDIA doc:
/// groupID           = %laneid >> 2
/// threadIDInGroup = %laneid % 4
///
/// row =  (threadIDInGroup * 2) + (i & 0x1)           for bi where i <  2
///        (threadIDInGroup * 2) + (i & 0x1) + 8       for bi where i >= 2
///
/// col = groupID
static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
auto dim = getAffineDimExpr(0, ctx);
AffineExpr groupID = dim.floorDiv(4);
AffineExpr threadIDInGroup = dim % 4;
// clang-format off
return {
RowColIndexing{threadIDInGroup * 2 + 0, groupID},        // i == 0
RowColIndexing{threadIDInGroup * 2 + 1, groupID},        // i == 1
RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID},    // i == 2
RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}     // i == 3
};
// clang-format on
}
/// From the NVIDIA doc:
/// groupID           = %laneid >> 2
/// threadIDInGroup = %laneid % 4
///
/// row =      groupID                               for ci where i <  2
///          groupID + 8                             for ci where i >= 2
///
/// col =  (threadIDInGroup * 2) + (i & 0x1)      for ci where i = {0,..,3}
static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
auto dim = getAffineDimExpr(0, ctx);
AffineExpr groupID = dim.floorDiv(4);
AffineExpr threadIDInGroup = dim % 4;
// clang-format off
return {
RowColIndexing{groupID, threadIDInGroup * 2 + 0},        // i == 0
RowColIndexing{groupID, threadIDInGroup * 2 + 1},        // i == 1
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},    // i == 2
RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}     // i == 3
};
// clang-format on
}
//===--------------------------------------------------------------------===//
/// Helper functions to create customizable load and store operations. The
/// specific shapes of each MMA instruction are passed via the
/// IndexCalculator callback.
//===--------------------------------------------------------------------===//
/// Build a list of memref.load operations indexed at `(row, col)` indices
/// that make sense for a particular MMA instruction and specified via the
/// IndexCalculator callback.
SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
OpFoldResult laneId, Value memref,
const IndexCalculator &indexFn);
/// Perform a distributed load of a vector operand of `vectorShape` for a
/// particular MMA instruction whose `(row, col)` indices are specified via
/// the IndexCalculator callback. Each `laneId` loads the subportion of the
/// data that makes sense for the particular MMA operation.
/// The `vectorShape` matches existing NVGPU dialect op specification but
/// could also be flattened in the future if needed for simplification.
Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
OpFoldResult laneId, Value memref,
IndexCalculator indexFn,
ArrayRef<int64_t> vectorShape);
/// Build a list of memref.store operations indexed at `(row, col)` indices
/// that make sense for a particular MMA instruction and specified via the
/// IndexCalculator callback.
SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
ValueRange toStore,
OpFoldResult laneId, Value memref,
const IndexCalculator &indexFn);
/// Perform a distributed store of a vector operand of `vectorShape` for a
/// particular MMA instruction whose `(row, col)` indices are specified via
/// the IndexCalculator callback. Each `laneId` stores the subportion of the
/// data that makes sense for the particular MMA operation.
/// The `vectorShape` matches existing NVGPU dialect op specification but
/// could also be flattened in the future if needed for simplification.
SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
/// Builder used to create all IR (not owned).
OpBuilder &b;
/// Location attached to all created ops.
Location loc;
/// Lane id substituted for dimension 0 of the index expressions.
OpFoldResult laneId;
};
//===--------------------------------------------------------------------===//
/// Generic per-element vector visitor, followed by the out-of-line definitions
/// of the MmaSyncBuilder helpers that create customizable load and store
/// operations. The specific shapes of each MMA instruction are passed via the
/// IndexCalculator callback.
//===--------------------------------------------------------------------===//
/// Visits every element of the vector value `vector` in row-major order.
/// For each linear index `i` with multi-dimensional position `pos`, it calls
/// `reduceFn(applyFn(vector, i, pos), i, pos)`.
template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  auto vectorType = cast<VectorType>(vector.getType());
  ArrayRef<int64_t> shape = vectorType.getShape();
  auto strides = computeStrides(shape);
  // The leading stride is the product of all trailing dims, so this is the
  // total element count.
  const int64_t numElements = shape.front() * strides.front();
  for (int64_t linearIdx = 0; linearIdx < numElements; ++linearIdx) {
    auto position = delinearize(linearIdx, strides);
    reduceFn(applyFn(vector, linearIdx, position), linearIdx, position);
  }
}
/// Emits one memref.load per (row, col) pair produced by `indexFn`. Each
/// index expression is evaluated at `laneId` (folded to a constant when
/// possible). The loaded values are returned in the same order.
SmallVector<Value>
MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                 OpFoldResult laneId, Value memref,
                                 const IndexCalculator &indexFn) {
  auto materialize = [&](AffineExpr expr) -> Value {
    OpFoldResult folded =
        affine::makeComposedFoldedAffineApply(b, loc, expr, laneId);
    return getValueOrCreateConstantIndexOp(b, loc, folded);
  };

  SmallVector<Value> loadedValues;
  for (const RowColIndexing &rc : indexFn(b.getContext())) {
    Value rowIdx = materialize(rc.row());
    Value colIdx = materialize(rc.col());
    loadedValues.push_back(
        b.create<memref::LoadOp>(loc, memref, ValueRange{rowIdx, colIdx}));
  }
  return loadedValues;
}
/// Loads this lane's scalars for one MMA operand and packs them into a vector
/// of `vectorShape`. The i-th loaded scalar lands at row-major position i.
Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto scalars = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
  auto vectorType =
      VectorType::get(vectorShape, getElementTypeOrSelf(memref.getType()));

  // Seed with a splat of the first scalar, then overwrite each position.
  Value assembled =
      b.create<vector::SplatOp>(loc, vectorType, scalars.front());
  foreachIndividualVectorElement(
      assembled,
      /*applyFn=*/
      [&](Value, int64_t linearIdx, ArrayRef<int64_t>) {
        return scalars[linearIdx];
      },
      /*reduceFn=*/
      [&](Value scalar, int64_t, ArrayRef<int64_t> position) {
        assembled =
            b.create<vector::InsertOp>(loc, scalar, assembled, position);
      });
  return assembled;
}
/// Emits one memref.store per (row, col) pair produced by `indexFn`, storing
/// the matching element of `toStore`. zip_equal requires both sequences to
/// have the same length. Returns the created stores in order.
SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  auto materialize = [&](AffineExpr expr) -> Value {
    return getValueOrCreateConstantIndexOp(
        b, loc, affine::makeComposedFoldedAffineApply(b, loc, expr, laneId));
  };

  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  SmallVector<Operation *> stores;
  for (auto [rc, value] : llvm::zip_equal(indexings, toStore)) {
    Value rowIdx = materialize(rc.row());
    Value colIdx = materialize(rc.col());
    Operation *storeOp = b.create<memref::StoreOp>(
        loc, value, memref, ValueRange{rowIdx, colIdx});
    stores.push_back(storeOp);
  }
  return stores;
}
/// Scalarizes `vectorToStore` in row-major order and stores each element at
/// the coordinates `indexFn` lists for this lane, in the same order.
SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> scalars;
  scalars.reserve(32);
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value, int64_t, ArrayRef<int64_t> position) {
        return b.create<vector::ExtractOp>(loc, vectorToStore, position);
      },
      /*reduceFn=*/
      [&](Value scalar, int64_t, ArrayRef<int64_t>) {
        scalars.push_back(scalar);
      });
  return buildMemRefStores(b, loc, scalars, laneId, memref,
                           std::move(indexFn));
}
/// Copies the (lhs, rhs, res) shape views into an owning tuple of vectors.
static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                  SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
                 ArrayRef<int64_t> res) {
  auto toOwned = [](ArrayRef<int64_t> shape) {
    return SmallVector<int64_t>(shape.begin(), shape.end());
  };
  return std::make_tuple(toOwned(lhs), toOwned(rhs), toOwned(res));
}
/// Maps an (m, n, k) shape and (lhs, rhs, res) element types to a supported
/// mma.sync flavor. Supported flavors:
///   - 16x8x4 with f32/f32/f32, built with tf32 enabled;
///   - 16x8x16 with f16/f16/f16 (f16 accumulation).
/// Any other combination yields failure.
FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  // TODO: Tablegen all this.
  Type f16 = b.getF16Type();
  Type f32 = b.getF32Type();
  SmallVector<int64_t> mmaShape(opShape.begin(), opShape.end());

  const bool isM16N8K4F32 = opShape == ArrayRef<int64_t>{16, 8, 4} &&
                            elementalTypes == TypeRange{f32, f32, f32};
  if (isM16N8K4F32) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}), mmaShape,
                       /*tf32Enabled=*/true};
  }

  // TODO: version with f32 accumulation.
  const bool isM16N8K16F16 = opShape == ArrayRef<int64_t>{16, 8, 16} &&
                             elementalTypes == TypeRange{f16, f16, f16};
  if (isM16N8K16F16) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}), mmaShape,
                       /*tf32Enabled=*/false};
  }

  return failure();
}
/// Lowers a matmul-like `linalgOp` with buffer semantics to one nvgpu.mma.sync,
/// plus the per-lane loads of all three operands and the store of the result.
/// Returns the created mma.sync op. Returns failure (without creating any IR)
/// in two cases:
///   - an operand is not a rank-2 memref;
///   - the shape/element types match no supported instruction.
FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();

  // Only buffer semantics is supported. Bail out gracefully instead of
  // asserting inside `cast` when e.g. a tensor-semantics matmul is passed.
  auto lhsMemRefType = dyn_cast<MemRefType>(lhsMemRef.getType());
  auto rhsMemRefType = dyn_cast<MemRefType>(rhsMemRef.getType());
  auto resMemRefType = dyn_cast<MemRefType>(resMemRef.getType());
  if (!lhsMemRefType || !rhsMemRefType || !resMemRefType)
    return failure();
  if (lhsMemRefType.getRank() != 2 || rhsMemRefType.getRank() != 2 ||
      resMemRefType.getRank() != 2)
    return failure();

  // (m x k) * (k x n) -> (m x n). A dynamic size cannot equal any of the
  // static shapes checked by getIndexCalculators, so it is rejected there.
  int64_t m = lhsMemRefType.getShape()[0];
  int64_t n = rhsMemRefType.getShape()[1];
  int64_t k = lhsMemRefType.getShape()[1];
  Type lhsElemType = lhsMemRefType.getElementType();
  Type rhsElemType = rhsMemRefType.getElementType();
  Type resElemType = resMemRefType.getElementType();

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsElemType, rhsElemType, resElemType});
  if (failed(maybeInfo))
    return failure();

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                   info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}
/// Rewrites a linalg.matmul into an nvgpu.mma.sync sequence and erases the
/// original op. Unsupported targets produce a silenceable error. In that case
/// the payload IR is left unchanged: the speculatively created lane-id op is
/// erased again.
DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  bool fail = true;
  // TODO: more robust detection of matmulOp, with transposes etc.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    Location loc = linalgOp.getLoc();
    // TODO: more robust computation of laneId, for now assume a single warp.
    auto laneIdOp = rewriter.create<gpu::ThreadIdOp>(
        loc, rewriter.getIndexType(), gpu::Dimension::x);
    Value laneId = laneIdOp.getResult();
    if (succeeded(
            MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp))) {
      fail = false;
    } else {
      // A silenceable failure must not modify the payload. buildMmaSync
      // creates no IR on failure, so the unused lane-id op is the only leftover.
      rewriter.eraseOp(laneIdOp);
    }
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Hopper builders.
//===----------------------------------------------------------------------===//
/// Helper to create the base Hopper-specific operations that are reused in
/// various other places. All IR is created through the borrowed `rewriter`
/// at location `loc`.
struct HopperBuilder {
HopperBuilder(RewriterBase &rewriter, Location loc)
: rewriter(rewriter), loc(loc) {}
/// Builds an mbarrier group and initializes it for `numThreads`.
/// NOTE(review): the definition is not visible here; placement in shared
/// memory is inferred from the name only -- confirm against the definition.
TypedValue<nvgpu::MBarrierGroupType>
buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
/// Create tma descriptor op to initiate transfer from global to shared
/// memory. This must be done before the launch op, on the host.
TypedValue<nvgpu::TensorMapDescriptorType>
buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
gpu::LaunchOp launchOp);
/// Build a tma load from global memory to shared memory using `barrier` to
/// synchronize. Return the number of bytes that will be transferred.
/// Created load ops are appended to `loadOps`.
OpFoldResult
buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
TypedValue<MemRefType> sharedMemref,
TypedValue<nvgpu::MBarrierGroupType> barrier,
SmallVectorImpl<Operation *> &loadOps);
/// Presumably arrives on `barrier` while announcing the expected transaction
/// byte count derived from `sizes` (e.g. values returned by buildTmaAsyncLoad).
/// TODO: confirm against the definition, which is not visible here.
void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
ArrayRef<OpFoldResult> sizes);
/// If threadIdx.x == 0 does TMA request + wait, else just wait.
/// Return the operation that performs the transfer on thread0.
/// (The return type is a list; presumably one load op per descriptor/buffer
/// pair -- TODO confirm.)
// TODO: In the future, don't hardcode to thread 0 but elect a leader.
SmallVector<Operation *> buildPredicateLoadsOnThread0(
ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
TypedValue<nvgpu::MBarrierGroupType> barrier);
/// Presumably waits on `barrier` using an mbarrier try-wait on a parity bit.
/// TODO: confirm against the definition, which is not visible here.
void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);
/// Rewriter used to create all operations (not owned).
RewriterBase &rewriter;
/// Location attached to all created operations.
Location loc;
};
/// Emit an `scf.if (threadIdx.x == 0)`:
///   - then: one TMA load per (descriptor, shared buffer) pair, followed by an
///     arrive-expect-tx on `barrier` for the total number of bytes loaded;
///   - else: an arrive-expect-tx on `barrier` with a zero byte count.
/// Returns the TMA load ops created in the then-region.
SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  SmallVector<Operation *> issuedLoads;

  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value threadIdX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value isThread0 = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, threadIdX, c0);

  // `nested` is the builder positioned inside the region being populated.
  auto issueLoadsAndArrive = [&](OpBuilder &nested, Location nestedLoc) {
    SmallVector<OpFoldResult> byteCounts;
    byteCounts.reserve(globalDescriptors.size());
    for (auto [descriptor, buffer] :
         llvm::zip_equal(globalDescriptors, sharedMemBuffers))
      byteCounts.push_back(
          buildTmaAsyncLoad(descriptor, buffer, barrier, issuedLoads));
    // TODO: Note that cutlass predeclares the barrier arrive tx before the
    // tma.async.load. This may or may not have perf implications.
    buildBarrierArriveTx(barrier, byteCounts);
    nested.create<scf::YieldOp>(nestedLoc);
  };

  auto arriveWithoutLoads = [&](OpBuilder &nested, Location nestedLoc) {
    // TODO: is this for no-thread divergence?
    // Should we just yield the size and hoist?
    buildBarrierArriveTx(barrier,
                         getAsIndexOpFoldResult(rewriter.getContext(), 0));
    nested.create<scf::YieldOp>(nestedLoc);
  };

  rewriter.create<scf::IfOp>(loc, isThread0, issueLoadsAndArrive,
                             arriveWithoutLoads);
  return issuedLoads;
}
/// Return the GPU workgroup address-space attribute. It is used as the memory
/// space of the mbarrier group type and of the TMA descriptor's tensor type.
static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  MLIRContext *ctx = b.getContext();
  auto workgroupSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
  return gpu::AddressSpaceAttr::get(ctx, workgroupSpace);
}
/// Create an mbarrier group in workgroup memory and initialize it with
/// `numThreads` (materialized as an index constant if it is static). The
/// initialization is followed by a block-wide `gpu.barrier`.
TypedValue<nvgpu::MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto groupType = nvgpu::MBarrierGroupType::get(
      rewriter.getContext(), getSharedAddressSpaceAttribute(rewriter));
  Value group = rewriter.create<nvgpu::MBarrierCreateOp>(loc, groupType);
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value arrivalCount =
      getValueOrCreateConstantIndexOp(rewriter, loc, numThreads);
  rewriter.create<nvgpu::MBarrierInitOp>(loc, group, arrivalCount, c0,
                                         Value());
  // Presumably so no thread touches the mbarrier before it is initialized.
  rewriter.create<gpu::BarrierOp>(loc);
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(group);
}
/// Build an `nvgpu::TmaCreateDescriptorOp` for the global `memref`.
///
/// The descriptor has to be created on the host, so all ops are inserted right
/// before `launchOp`. The caller's insertion point is restored on return. The
/// descriptor's tensor type is `memref`'s type moved to the workgroup memory
/// space, with no swizzling, no L2 promotion, no interleaving, and zero-fill
/// for out-of-bounds accesses.
TypedValue<nvgpu::TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard insertionGuard(rewriter);
  rewriter.setInsertionPoint(launchOp);

  MemRefType globalType = memref.getType();
  auto unrankedType = UnrankedMemRefType::get(globalType.getElementType(),
                                              globalType.getMemorySpace());
  Value unranked = rewriter.create<memref::CastOp>(loc, unrankedType, memref);

  SmallVector<Value> dimSizes = getValueOrCreateConstantIndexOp(
      rewriter, loc, memref::getMixedSizes(rewriter, loc, memref));

  MemRefType smemTensorType = MemRefType::Builder(globalType).setMemorySpace(
      getSharedAddressSpaceAttribute(rewriter));
  auto descriptorType = nvgpu::TensorMapDescriptorType::get(
      rewriter.getContext(), smemTensorType,
      TensorMapSwizzleKind::SWIZZLE_NONE, TensorMapL2PromoKind::L2PROMO_NONE,
      TensorMapOOBKind::OOB_ZERO, TensorMapInterleaveKind::INTERLEAVE_NONE);

  Value descriptor = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
      loc, descriptorType, unranked, dimSizes);
  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(descriptor);
}
/// Emit an `nvgpu::TmaAsyncLoadOp` from `globalDesc` into `sharedMemref`,
/// synchronized on `barrier`, at all-zero coordinates. The created op is
/// appended to `loadOps`.
///
/// Returns the number of bytes transferred into `sharedMemref`: the product of
/// its sizes times the element size, rounded up to whole bytes. The result is
/// folded to a constant when the sizes are static.
OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  MLIRContext *ctx = rewriter.getContext();
  MemRefType sharedType = sharedMemref.getType();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  // One zero coordinate per dimension of the destination buffer, rather than a
  // hard-coded pair (identical for the 2-D case).
  SmallVector<Value> coordinates(sharedType.getRank(), zero);
  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
      loc, sharedMemref, barrier, globalDesc, coordinates, zero, Value(),
      Value());
  loadOps.push_back(loadOp);

  // bytes = ceil(numElements * elementBitWidth / 8). Multiplying in bits first
  // keeps sub-byte element types from truncating to 0 bytes per element. For
  // byte-sized elements this equals numElements * (elementBitWidth / 8).
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  int64_t elementBitWidth = sharedType.getElementTypeBitWidth();
  AffineExpr prodExprInBytes =
      (computeProduct(ctx, symbols) * elementBitWidth).ceilDiv(8);
  return affine::makeComposedFoldedAffineApply(rewriter, loc, prodExprInBytes,
                                               mixedSizes);
}
/// Emit an `nvgpu::MBarrierArriveExpectTxOp` on `barrier`. Its transaction
/// count is the sum of `mixedSizes`, folded where possible and materialized as
/// an index value.
///
/// `mixedSizes` must not be empty (asserted).
void HopperBuilder::buildBarrierArriveTx(
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  // Sum s0 + s1 + ... over one symbol per size, then compose/fold.
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
                                                   Value());
}
/// Emit an `nvgpu::MBarrierTryWaitParityOp` on `barrier`. The op receives an
/// i1 parity constant of 0 and a retry tick count of 10,000,000.
void HopperBuilder::buildTryWaitParity(
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  Value phaseParity =
      rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), 0);
  // 10M is an arbitrary, not too small or too big number to specify the number
  // of ticks before retry.
  // TODO: hoist this in a default dialect constant.
  constexpr int64_t kTicksBeforeRetry = 10000000;
  Value retryTicks =
      rewriter.create<arith::ConstantIndexOp>(loc, kTicksBeforeRetry);
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, phaseParity,
                                                  retryTicks, c0);
}
//===----------------------------------------------------------------------===//
// RewriteCopyAsTmaOp
//===----------------------------------------------------------------------===//
/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
struct CopyBuilder : public HopperBuilder {
CopyBuilder(RewriterBase &rewriter, Location loc)
: HopperBuilder(rewriter, loc) {}
SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};
/// Replace every `linalg::CopyOp` in `copyOps` with a TMA load into its
/// destination buffer:
///   1. initialize an mbarrier with the number of threads in the launch block,
///   2. build a host-side TMA descriptor for each (2-D) copy source,
///   3. issue the loads on thread 0 right before the first copy,
///   4. make every thread try-wait on the barrier,
///   5. erase the copies.
/// Returns the TMA load ops; an empty `copyOps` is a no-op.
SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  if (copyOps.empty())
    return {};

  Operation *anchor = copyOps.front();
  auto launchOp = anchor->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(anchor);

  // Barrier arrival count: blockDim.x * blockDim.y * blockDim.z of the launch.
  MLIRContext *ctx = rewriter.getContext();
  AffineExpr sx, sy, sz;
  bindSymbols(ctx, sx, sy, sz);
  OpFoldResult blockThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, computeProduct(ctx, ArrayRef<AffineExpr>{sx, sy, sz}),
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});
  TypedValue<nvgpu::MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(blockThreads);

  // Collect, per copy, the global-source descriptor and the shared destination.
  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
  SmallVector<TypedValue<MemRefType>> shmems;
  globalDescs.reserve(copyOps.size());
  shmems.reserve(copyOps.size());
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto src =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(src.getType().getRank() == 2 && "expected in to be a 2D memref");
    // Descriptors are emitted before the launch (see
    // buildGlobalMemRefDescriptor); our insertion point is left untouched.
    globalDescs.push_back(buildGlobalMemRefDescriptor(src, launchOp));
    shmems.push_back(
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get()));
  }

  // Loads go right before the first copy, followed by the spin-wait.
  rewriter.setInsertionPoint(anchor);
  SmallVector<Operation *> tmaLoads =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
  buildTryWaitParity(barrier);

  for (Operation *op : copyOps)
    rewriter.eraseOp(op);
  return tmaLoads;
}
DiagnosedSilenceableFailure
transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
auto payloadOps = state.getPayloadOps(getTarget());
gpu::LaunchOp commonLaunchOp;
Operation *firstOp, *failingOp;
if (llvm::any_of(payloadOps, [&](Operation *op) {
if (!commonLaunchOp) {
commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
firstOp = op;
}
auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
!isa<linalg::CopyOp>(op);
if (fail)
failingOp = op;
return fail;
})) {
DiagnosedSilenceableFailure diag =
emitSilenceableError()
<< "target ops must be linalg::CopyOp nested under a common "
"gpu.LaunchOp to be rewritten because the tma descriptors need to "
"be created on the host.\nBut got: "
<< *firstOp << "\nand " << *failingOp;
return diag;
}
// TODO: more robust detection of copy, with transposes etc.
CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
namespace {
/// Transform dialect extension that registers the NVGPU transform ops of this
/// file and declares the dialects whose ops those transforms may create.
class NVGPUTransformDialectExtension
: public transform::TransformDialectExtension<
NVGPUTransformDialectExtension> {
public:
NVGPUTransformDialectExtension() {
// Dialects whose ops the transforms in this file create (arith constants,
// affine.apply, nvgpu/NVVM ops, ...); presumably declared so they are
// loaded before the transforms run — TODO confirm against the
// TransformDialectExtension documentation.
declareGeneratedDialect<arith::ArithDialect>();
declareGeneratedDialect<affine::AffineDialect>();
declareGeneratedDialect<nvgpu::NVGPUDialect>();
declareGeneratedDialect<NVVM::NVVMDialect>();
declareGeneratedDialect<vector::VectorDialect>();
// The generated .inc, included with GET_OP_LIST defined, supplies the op
// classes as the template arguments of registerTransformOps.
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
>();
}
};
} // namespace
// Pull in the generated op class definitions (the same .inc file, included
// here with GET_OP_CLASSES defined instead of GET_OP_LIST).
#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
/// Add the NVGPU transform dialect extension (which registers the NVGPU
/// transform ops) to `registry`.
void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
registry.addExtensions<NVGPUTransformDialectExtension>();
}