blob: 75981263313616e959de57d1d12644eca498d135 [file] [log] [blame]
//===- LinalgToSPIRV.cpp - Linalg to SPIR-V dialect conversion ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
/// Returns a `Value` holding the `dim`-th component of the SPIR-V
/// `LocalInvocationId` builtin, i.e. the current invocation's ID along
/// dimension `dim` within its local workgroup. Note: despite the function name,
/// this is the invocation *ID* along that dimension, not the workgroup's size.
/// The builtin variable is obtained through `spirv::getBuiltinVariableValue`
/// for the region containing `op`; the component extraction is created with
/// `builder` at location `loc`.
///
/// `dim` must be in [0, 3): the local invocation ID only has three dimensions.
static Value getLocalInvocationDimSize(Operation *op, int dim, Location loc,
                                       OpBuilder *builder) {
  assert(dim >= 0 && dim < 3 && "local invocation only has three dimensions");
  Value invocation = spirv::getBuiltinVariableValue(
      op, spirv::BuiltIn::LocalInvocationId, *builder);
  // The builtin is a vector; every component shares its element type.
  Type elementType = invocation.getType().cast<ShapedType>().getElementType();
  return builder->create<spirv::CompositeExtractOp>(
      loc, elementType, invocation, builder->getI32ArrayAttr({dim}));
}
//===----------------------------------------------------------------------===//
// Reduction (single workgroup)
//===----------------------------------------------------------------------===//
namespace {
/// Conversion pattern lowering a `linalg.generic` op to SPIR-V ops, applicable
/// only when the op describes a reduction whose entire workload fits inside a
/// single workgroup.
struct SingleWorkgroupReduction final
    : public SPIRVOpLowering<linalg::GenericOp> {
  // Reuse the base pattern's constructors (context + type converter).
  using SPIRVOpLowering<linalg::GenericOp>::SPIRVOpLowering;

  /// Checks whether `genericOp` is a reduction this pattern can lower.
  /// Returns the combining scalar binary op kind on success, or `llvm::None`
  /// when the op does not match.
  static Optional<linalg::RegionMatcher::BinaryOpKind>
  matchAsPerformingReduction(linalg::GenericOp genericOp);

  /// Performs the actual rewrite of a matching `genericOp` into SPIR-V ops.
  PatternMatchResult
  matchAndRewrite(linalg::GenericOp genericOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
/// Returns the scalar binary op kind combining elements if `genericOp` is a
/// reduction of a static 1-D memref into a single-element memref, using the
/// indexing maps `(i) -> (i)` for the input and `(i) -> (0)` for the output,
/// with a single reduction loop. Returns `llvm::None` otherwise.
Optional<linalg::RegionMatcher::BinaryOpKind>
SingleWorkgroupReduction::matchAsPerformingReduction(
    linalg::GenericOp genericOp) {
  Operation *op = genericOp.getOperation();
  MLIRContext *context = op->getContext();

  // Make sure the linalg.generic is working on memrefs.
  if (!genericOp.hasBufferSemantics())
    return llvm::None;

  // Make sure this is a reduction with one input and one output.
  if (genericOp.args_in().getZExtValue() != 1 ||
      genericOp.args_out().getZExtValue() != 1)
    return llvm::None;

  // Safe after the checks above: buffer semantics plus one input and one
  // output guarantee operands 0 and 1 exist and are memrefs.
  auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
  auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();

  // Make sure the original input has one dimension.
  if (!originalInputType.hasStaticShape() || originalInputType.getRank() != 1)
    return llvm::None;
  // Make sure the original output has one element.
  if (!originalOutputType.hasStaticShape() ||
      originalOutputType.getNumElements() != 1)
    return llvm::None;

  if (!genericOp.hasSingleReductionLoop())
    return llvm::None;

  ArrayRef<Attribute> indexingMaps = genericOp.indexing_maps().getValue();
  if (indexingMaps.size() != 2)
    return llvm::None;

  // TODO(nicolasvasilache): create utility functions for these checks in Linalg
  // and use them.
  auto inputMap = indexingMaps[0].cast<AffineMapAttr>();
  auto outputMap = indexingMaps[1].cast<AffineMapAttr>();
  // The indexing map for the input should be `(i) -> (i)`.
  if (inputMap.getValue() !=
      AffineMap::get(1, 0, {getAffineDimExpr(0, context)}))
    return llvm::None;
  // The indexing map for the output should be `(i) -> (0)`.
  if (outputMap.getValue() !=
      AffineMap::get(1, 0, {getAffineConstantExpr(0, context)}))
    return llvm::None;

  return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp);
}
/// Rewrites a matching single-workgroup reduction into SPIR-V:
///   1. each invocation loads the input element indexed by its local
///      invocation ID (x component);
///   2. the loaded values are combined with a subgroup-scoped non-uniform
///      group reduction;
///   3. the invocation chosen by spv.GroupNonUniformElect (subgroup scope)
///      atomically combines that partial result into the output element.
/// NOTE(review): the atomic accumulation presumably expects the output to
/// already hold the reduction's initial value — confirm with callers.
PatternMatchResult SingleWorkgroupReduction::matchAndRewrite(
    linalg::GenericOp genericOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // Run the structural match before touching operand types:
  // matchAsPerformingReduction verifies buffer semantics and exactly one input
  // and one output, which the `cast<MemRefType>` calls below rely on. Casting
  // first would assert on tensor operands instead of failing the match.
  auto binaryOpKind = matchAsPerformingReduction(genericOp);
  if (!binaryOpKind)
    return matchFailure();

  Operation *op = genericOp.getOperation();
  auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
  auto originalOutputType = op->getOperand(1).getType().cast<MemRefType>();

  // Query the shader interface for local workgroup size to make sure the
  // invocation configuration fits with the input memref's shape: the first
  // dimension must equal the input length and every other dimension must be 1.
  DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp);
  // Also reject an empty attribute so `begin()` below can be dereferenced.
  if (!localSize || localSize.getNumElements() == 0)
    return matchFailure();
  if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
    return matchFailure();
  if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1),
                   [](const APInt &size) { return !size.isOneValue(); }))
    return matchFailure();

  // TODO(antiagainst): Query the target environment to make sure the current
  // workload fits in a local workgroup.

  Value convertedInput = operands[0], convertedOutput = operands[1];
  Location loc = genericOp.getLoc();

  // Get the invocation ID.
  Value x = getLocalInvocationDimSize(genericOp, /*dim=*/0, loc, &rewriter);

  // TODO(antiagainst): Load to Workgroup storage class first.

  // Get the input element accessed by this invocation.
  Value inputElementPtr = spirv::getElementPtr(
      typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
  Value inputElement = rewriter.create<spirv::LoadOp>(loc, inputElementPtr);

  // Perform the group reduction operation.
  Value groupOperation;
#define CREATE_GROUP_NON_UNIFORM_BIN_OP(opKind, spvOp)                         \
  case linalg::RegionMatcher::BinaryOpKind::opKind: {                          \
    groupOperation = rewriter.create<spirv::spvOp>(                            \
        loc, originalInputType.getElementType(), spirv::Scope::Subgroup,       \
        spirv::GroupOperation::Reduce, inputElement,                           \
        /*cluster_size=*/ArrayRef<Value>());                                   \
  } break
  switch (*binaryOpKind) {
    CREATE_GROUP_NON_UNIFORM_BIN_OP(IAdd, GroupNonUniformIAddOp);
  }
#undef CREATE_GROUP_NON_UNIFORM_BIN_OP

  // Get the output element accessed by this reduction.
  Value zero = spirv::ConstantOp::getZero(
      typeConverter.getIndexType(rewriter.getContext()), loc, &rewriter);
  SmallVector<Value, 1> zeroIndices(originalOutputType.getRank(), zero);
  Value outputElementPtr =
      spirv::getElementPtr(typeConverter, originalOutputType, convertedOutput,
                           zeroIndices, loc, rewriter);

  // Write out the final reduction result. This should be only conducted by one
  // invocation. We use spv.GroupNonUniformElect to find the invocation with the
  // lowest ID.
  //
  // ```
  // if (spv.GroupNonUniformElect) { output = ... }
  // ```

  Value condition = rewriter.create<spirv::GroupNonUniformElectOp>(
      loc, spirv::Scope::Subgroup);

  auto createAtomicOp = [&](OpBuilder *builder) {
#define CREATE_ATOMIC_BIN_OP(opKind, spvOp)                                    \
  case linalg::RegionMatcher::BinaryOpKind::opKind: {                          \
    builder->create<spirv::spvOp>(loc, outputElementPtr, spirv::Scope::Device, \
                                  spirv::MemorySemantics::AcquireRelease,      \
                                  groupOperation);                             \
  } break
    switch (*binaryOpKind) { CREATE_ATOMIC_BIN_OP(IAdd, AtomicIAddOp); }
#undef CREATE_ATOMIC_BIN_OP
  };

  spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, &rewriter);

  rewriter.eraseOp(genericOp);
  return matchSuccess();
}
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
/// Adds every Linalg-to-SPIR-V conversion pattern defined in this file to
/// `patterns`. Currently that is only `SingleWorkgroupReduction`, which is
/// constructed with `context` and `typeConverter`.
void mlir::populateLinalgToSPIRVPatterns(MLIRContext *context,
                                         SPIRVTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns) {
  patterns.insert<SingleWorkgroupReduction>(/*context=*/context,
                                            /*typeConverter=*/typeConverter);
}