blob: 60c4e07a118cb841dce941c903ee75fa44f7c708 [file] [log] [blame]
//===- Transforms.cpp ---------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "TransformsDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <iterator>
#include <numeric>
namespace mlir::mesh {
namespace {
/// Rewrite pattern that expands `mesh.process_multi_index`.
///
/// The replacement values are produced by delinearizing the result of
/// `mesh.process_linear_index` over the results of `mesh.mesh_shape`, and then
/// picking the components that correspond to the axes requested by the op.
struct ProcessMultiIndexOpLowering
    : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
  using OpRewritePatternWithSymbolTableCollection::
      OpRewritePatternWithSymbolTableCollection;

  /// Fails when the mesh referenced by `op` cannot be resolved; otherwise
  /// emits the replacement ops right after `op` and redirects all uses of the
  /// results of `op` to them.
  LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
                                PatternRewriter &rewriter) const override {
    MeshOp mesh = getMesh(op, symbolTableCollection);
    if (!mesh)
      return failure();

    // All new operations are placed directly after the op being lowered.
    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
    builder.setInsertionPointAfter(op.getOperation());

    // Multi-index over the whole mesh, obtained from the linear index.
    Value processLinearIndex = builder.create<ProcessLinearIndexOp>(mesh);
    ValueRange fullMeshShape = builder.create<MeshShapeOp>(mesh).getResults();
    SmallVector<Value> fullMultiIndex =
        builder
            .create<affine::AffineDelinearizeIndexOp>(processLinearIndex,
                                                      fullMeshShape)
            .getMultiIndex();

    // An empty axes list selects every mesh axis, in increasing order.
    SmallVector<MeshAxis> everyAxis;
    ArrayRef<MeshAxis> selectedAxes = op.getAxes();
    if (selectedAxes.empty()) {
      everyAxis.resize(mesh.getRank());
      std::iota(everyAxis.begin(), everyAxis.end(), 0);
      selectedAxes = everyAxis;
    }

    // Keep only the components that belong to the selected axes.
    SmallVector<Value> replacement;
    for (MeshAxis axis : selectedAxes)
      replacement.push_back(fullMultiIndex[axis]);

    rewriter.replaceAllUsesWith(op.getResults(), replacement);
    return success();
  }
};
/// Rewrite pattern that expands `mesh.all_slice` into a
/// `tensor.extract_slice` of the operand followed by a `tensor.cast` to the
/// result type of the original op.
struct AllSliceOpLowering
    : OpRewritePatternWithSymbolTableCollection<AllSliceOp> {
  using OpRewritePatternWithSymbolTableCollection::
      OpRewritePatternWithSymbolTableCollection;

  /// Fails when the mesh referenced by `op` cannot be resolved; otherwise
  /// emits the replacement ops right after `op` and redirects all uses of the
  /// result of `op` to the newly created value.
  LogicalResult matchAndRewrite(AllSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Outline of the lowering:
    //
    // 1. Linearize the multi-index of the process within its process group.
    //
    // 2. Take a slice of the input tensor.
    //    Every axis other than the slice axis is taken in full.
    //    The slice axis is cut into equally sized parts, one per process of
    //    the collective process group induced by the mesh axes.
    //    The linear index of a process inside the group selects its part.
    //
    // No collective that requires communication is emitted; each process
    // works on its local tensor only.
    MeshOp mesh = getMesh(op, symbolTableCollection);
    if (!mesh)
      return failure();

    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
    builder.setInsertionPointAfter(op.getOperation());

    Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));

    // Position of the process within its group, and the extents of the group.
    Operation::result_range groupMultiIndex =
        builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
            .getResults();
    Operation::result_range groupShape =
        builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
            .getResult();
    Value groupSize =
        createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);

    int64_t sliceAxis = op.getSliceAxis().getSExtValue();
    Value inputAxisSize =
        builder.create<tensor::DimOp>(op.getOperand(), sliceAxis);

    // Runtime guard: the slice axis must split evenly across the group.
    Value remainder =
        builder.create<arith::RemUIOp>(inputAxisSize, groupSize);
    Value isEvenlyDivisible = builder.create<arith::CmpIOp>(
        arith::CmpIPredicate::eq, remainder, zero);
    builder.create<cf::AssertOp>(isEvenlyDivisible,
                                 "Slicing a tensor with axis size that is "
                                 "not exactly divisible by the "
                                 "mesh process group size is not supported.");
    Value outputAxisSize =
        builder.create<arith::DivUIOp>(inputAxisSize, groupSize);

    OpFoldResult groupLinearIndex = affine::linearizeIndex(
        llvm::to_vector_of<OpFoldResult>(groupMultiIndex),
        llvm::to_vector_of<OpFoldResult>(groupShape), builder);

    // Assemble the operands of tensor.extract_slice.
    RankedTensorType operandType =
        cast<RankedTensorType>(op.getOperand().getType());
    const int64_t rank = operandType.getRank();

    // Sizes: the per-process part along the slice axis, full extent elsewhere.
    SmallVector<OpFoldResult> sizes;
    for (int64_t dim = 0; dim < rank; ++dim) {
      if (dim != sliceAxis) {
        Value fullExtent = builder.create<tensor::DimOp>(op.getOperand(), dim);
        sizes.emplace_back(fullExtent);
        continue;
      }
      sizes.emplace_back(outputAxisSize);
    }

    // Offsets: zero everywhere except along the slice axis, where the part
    // starts at (linear index in group) * (part size).
    SmallVector<OpFoldResult> offsets(
        rank, getAsIndexOpFoldResult(builder.getContext(), 0));
    Value groupLinearIndexValue = getValueOrCreateConstantIndexOp(
        builder, builder.getLoc(), groupLinearIndex);
    offsets[sliceAxis] = ArithBuilder(builder, builder.getLoc())
                             .mul(groupLinearIndexValue, outputAxisSize);

    // Strides: unit stride along every axis.
    SmallVector<OpFoldResult> strides(
        rank, getAsIndexOpFoldResult(builder.getContext(), 1));

    Value slice = builder.create<tensor::ExtractSliceOp>(
        op.getOperand(), offsets, sizes, strides);
    Value replacement =
        builder.create<tensor::CastOp>(op.getResult().getType(), slice);
    rewriter.replaceAllUsesWith(op.getResult(), replacement);
    return success();
  }
};
} // namespace
/// Appends the `ProcessMultiIndexOpLowering` pattern to `patterns`. The
/// pattern is constructed with `symbolTableCollection` and the context of the
/// pattern set.
void populateProcessMultiIndexOpLoweringPatterns(
    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
  auto *context = patterns.getContext();
  patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection, context);
}
/// Inserts into `registry` the dialects registered for the
/// `mesh.process_multi_index` lowering: affine and mesh.
void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry) {
  registry.insert<affine::AffineDialect>();
  registry.insert<mesh::MeshDialect>();
}
/// Appends the `AllSliceOpLowering` pattern to `patterns`. The pattern is
/// constructed with `symbolTableCollection` and the context of the pattern
/// set.
void populateAllSliceOpLoweringPatterns(
    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
  auto *context = patterns.getContext();
  patterns.add<AllSliceOpLowering>(symbolTableCollection, context);
}
/// Inserts into `registry` the dialects registered for the `mesh.all_slice`
/// lowering: affine, arith, cf, mesh and tensor.
void registerAllSliceOpLoweringDialects(DialectRegistry &registry) {
  registry.insert<affine::AffineDialect>();
  registry.insert<arith::ArithDialect>();
  registry.insert<cf::ControlFlowDialect>();
  registry.insert<mesh::MeshDialect>();
  registry.insert<tensor::TensorDialect>();
}
/// Convenience entry point that appends every op lowering pattern defined in
/// this file to `patterns`: first the `mesh.process_multi_index` lowering,
/// then the `mesh.all_slice` lowering.
void populateAllOpLoweringPatterns(
    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
  populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection);
  populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
}
/// Convenience entry point that inserts into `registry` the dialects of every
/// op lowering defined in this file: first those of the
/// `mesh.process_multi_index` lowering, then those of the `mesh.all_slice`
/// lowering.
void registerAllOpLoweringDialects(DialectRegistry &registry) {
  registerProcessMultiIndexOpLoweringDialects(registry);
  registerAllSliceOpLoweringDialects(registry);
}
/// Emits, at the insertion point of `builder`, a `mesh.mesh_shape` op for
/// `mesh` restricted to `axes` and the product of its results as an `index`
/// value.
///
/// \param mesh    Mesh whose shape is queried.
/// \param axes    Mesh axes passed to the `mesh.mesh_shape` op.
/// \param builder Builder providing the insertion point and location.
/// \returns The product of the queried axis sizes.
TypedValue<IndexType>
createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
                                 ImplicitLocOpBuilder &builder) {
  Operation::result_range axisSizeResults =
      builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
  SmallVector<Value> axisSizes = llvm::to_vector_of<Value>(axisSizeResults);
  Value groupSize = arith::createProduct(builder, builder.getLoc(), axisSizes,
                                         builder.getIndexType());
  return cast<TypedValue<IndexType>>(groupSize);
}
/// Emits, at the insertion point of `builder`, the computation of the linear
/// index of the current process along the given mesh axes.
///
/// A `mesh.process_multi_index` op and a `mesh.mesh_shape` op are created for
/// `meshAxes`, and the multi-index is linearized over that shape with
/// `affine::linearizeIndex`.
///
/// \param mesh     Symbol name of the mesh.
/// \param meshAxes Mesh axes along which the linear index is computed.
/// \param builder  Builder providing the insertion point and location.
/// \returns The linear index as an `index` value.
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
                                               ArrayRef<MeshAxis> meshAxes,
                                               ImplicitLocOpBuilder &builder) {
  ResultRange processInGroupMultiIndex =
      builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults();
  Operation::result_range processGroupShape =
      builder.create<MeshShapeOp>(mesh, meshAxes).getResult();
  OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
      llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
      llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
  // The linearized index is an OpFoldResult and is therefore not guaranteed to
  // hold a Value: it may have been folded to a constant attribute (presumably
  // when there is nothing to linearize, e.g. no axes). Materialize a constant
  // index op in that case instead of asserting on the Value alternative. When
  // it already holds a Value, that Value is returned unchanged.
  return cast<TypedValue<IndexType>>(getValueOrCreateConstantIndexOp(
      builder, builder.getLoc(), processInGroupLinearIndex));
}
} // namespace mlir::mesh