// blob: 666ac56c6cd5cd87be606a96c619222de7ea23c3 (git web-view header; not part of the source)
//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace tensor {
namespace {
/// Returns true iff every entry of `ofrs` is a constant integer equal to
/// `value`.
static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
  for (OpFoldResult ofr : ofrs)
    if (!isConstantIntValue(ofr, value))
      return false;
  return true;
}
/// Returns the number of shape sizes that is either dynamic or greater than 1.
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
  int64_t numGtOne = 0;
  for (int64_t size : shape)
    if (ShapedType::isDynamic(size) || size > 1)
      ++numGtOne;
  return numGtOne;
}
/// Returns success() if there is only 1 dimension size in non-packed domain
/// being greater than 1 and packing only happens on the dimension.
/// Note: this method should only be used by pack/unpack to reshape conversion.
/// It assumes that non-unit inner tile size must be used by the non-unit
/// dimension.
static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
                                ArrayRef<int64_t> srcShape,
                                ArrayRef<int64_t> innerPackTileSize) {
  // More than one non-unit (or dynamic) dimension in the non-packed domain
  // means this is not a 1-D pack/unpack.
  if (getNumGtOneDims(srcShape) > 1)
    return rewriter.notifyMatchFailure(
        op, "expects non-packed domain to have at most one non-unit dims");

  // A non-unit inner tile size must be used by the non-unit dimension;
  // otherwise computing the reassociation maps would fail.
  if (getNumGtOneDims(innerPackTileSize) > 1)
    return rewriter.notifyMatchFailure(
        op, "expects at most one non-unit inner tiles");

  return success();
}
/// Packing one-dimensional tensor can be expressed as an expand shape op.
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  /// Wraps `operand` in a tensor.expand_shape producing `newOperandType`, or
  /// returns it unchanged when the type already matches.
  Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
                     Type newOperandType, ArrayAttr reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
                                                  reassociation);
  }

  /// Returns success() if it is only packing on the innermost dimension.
  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
                                     PackOp packOp) const {
    auto perm = packOp.getOuterDimsPerm();
    if (!perm.empty() && !isIdentityPermutation(perm)) {
      return rewriter.notifyMatchFailure(
          packOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    ArrayRef<int64_t> innerPos = packOp.getInnerDimsPos();
    bool packsInnerMostOnly =
        innerPos.size() == 1 && innerPos[0] + 1 == packOp.getSourceRank();
    if (!packsInnerMostOnly) {
      return rewriter.notifyMatchFailure(
          packOp, "expects packing at the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

    // Foldable when packing only the innermost dim, or when the whole op is
    // effectively one-dimensional.
    RankedTensorType srcType = packOp.getSourceType();
    bool foldable = succeeded(isPackOnInnerMostDim(rewriter, packOp)) ||
                    succeeded(isPackOn1D(rewriter, packOp, srcType.getShape(),
                                         packOp.getStaticTiles()));
    if (!foldable)
      return failure();

    RankedTensorType dstType = packOp.getDestType();
    auto reassociation = getReassociationIndicesForReshape(srcType, dstType);
    if (!reassociation)
      return failure();
    Value expanded = insertExpand(
        rewriter, packOp.getLoc(), packOp.getSource(), dstType,
        getReassociationIndicesAttribute(rewriter, *reassociation));
    rewriter.replaceOp(packOp, expanded);
    return success();
  }
};
/// Unpacking from a one-dimensional packed domain can be expressed as a
/// collapse shape op.
struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  /// Wraps `operand` in a tensor.collapse_shape producing `newOperandType`,
  /// or returns it unchanged when the type already matches.
  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
                       Type newOperandType, ArrayAttr reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
                                                    operand, reassociation);
  }

  /// Returns success() if it is unpacking on the innermost dimension.
  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
                                       UnPackOp unpackOp) const {
    auto perm = unpackOp.getOuterDimsPerm();
    if (!perm.empty() && !isIdentityPermutation(perm)) {
      return rewriter.notifyMatchFailure(
          unpackOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    RankedTensorType srcType = unpackOp.getSourceType();
    RankedTensorType dstType = unpackOp.getDestType();
    if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");

    ArrayRef<int64_t> innerPos = unpackOp.getInnerDimsPos();
    if (innerPos.size() != 1 || innerPos[0] + 1 != dstType.getRank()) {
      return rewriter.notifyMatchFailure(
          unpackOp, "expects unpacking on the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override {
    // Foldable when unpacking only the innermost dim, or when the whole op is
    // effectively one-dimensional.
    RankedTensorType dstType = unpackOp.getDestType();
    bool foldable =
        succeeded(isUnpackOnInnerMostDim(rewriter, unpackOp)) ||
        succeeded(isPackOn1D(rewriter, unpackOp, dstType.getShape(),
                             unpackOp.getStaticTiles()));
    if (!foldable)
      return failure();

    auto reassociation =
        getReassociationIndicesForReshape(unpackOp.getSourceType(), dstType);
    if (!reassociation)
      return failure();
    Value collapsed = insertCollapse(
        rewriter, unpackOp.getLoc(), unpackOp.getSource(), dstType,
        getReassociationIndicesAttribute(rewriter, *reassociation));
    rewriter.replaceOp(unpackOp, collapsed);
    return success();
  }
};
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<PadOp>();
    // Only fold pads that are foldable and only pad at the high end.
    if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
      return failure();

    // The pad value must be a known constant ...
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();

    // ... and must agree with the pack's padding value, when it has one.
    if (auto packPadValue = packOp.getPaddingValue())
      if (!isEqualConstantIntOrValue(packPadValue, padValue))
        return failure();

    // Re-create the pack directly on the pad's source, carrying the pad's
    // constant as the pack padding value.
    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
        packOp.getMixedTiles(), padValue, packOp.getOuterDimsPerm());
    return success();
  }
};
/// Fold an `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
    if (!unpackOp)
      return failure();

    if (sliceOp.getResultType().getRank() !=
        unpackOp.getDestType().getRank()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "rank-reduced folding is not supported");
    }

    // The slice must start at the origin with unit strides.
    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expects offsets to be 0s and strides to be 1s");
    }

    // Unpack directly into an empty tensor of the sliced shape.
    Type elemType = unpackOp.getDestType().getElementType();
    Value newDest = rewriter.create<EmptyOp>(sliceOp.getLoc(),
                                             sliceOp.getMixedSizes(), elemType);
    rewriter.replaceOpWithNewOp<UnPackOp>(
        sliceOp, unpackOp.getSource(), newDest, unpackOp.getInnerDimsPos(),
        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
    return success();
  }
};
// Applies 'permutation' on 'inVec' and stores the result in resVec.
// 'inVec' may be empty, in that case it's one-to-one mapping with permutation.
// `rank` sets the boundary for permutation i.e., the permutation dim can't be
// greater than the rank specified. If it's so then return false.
// For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in
// permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is
// not allowed since `3` exceeds the value of the rank in the given range.
static bool checkAndPermute(ArrayRef<int64_t> permutation,
                            ArrayRef<int64_t> inVec,
                            SmallVectorImpl<int64_t> &resVec, int64_t rank) {
  for (int64_t i = 0; i < rank; ++i) {
    int64_t remappedPosition = permutation[i];
    // The rank bound applies regardless of whether `inVec` is provided.
    // Previously this check was skipped for an empty `inVec`, letting an
    // out-of-range position (a tile dim transposed with a non-tile dim) leak
    // into `resVec` and produce an invalid outer_dims_perm in the callers.
    if (remappedPosition >= rank)
      return false;
    if (!inVec.empty())
      remappedPosition = inVec[remappedPosition];
    resVec.push_back(remappedPosition);
  }

  return true;
}
/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
/// semantics.
struct FoldProducerPackWithConsumerLinalgTransposeOp
    : public OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
    if (!packOp)
      return failure();

    auto transposePerm = transposeOp.getPermutation();
    int64_t srcRank = packOp.getSourceRank();

    // The first `srcRank` entries of the transpose permutation become the new
    // outer_dims_perm of the pack.
    SmallVector<int64_t> newOuterDimsPermVec;
    if (!checkAndPermute(transposePerm, packOp.getOuterDimsPerm(),
                         newOuterDimsPermVec, srcRank))
      return rewriter.notifyMatchFailure(
          transposeOp,
          "Cannot fold in tensor.pack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // The remaining entries permute the tiled inner dimensions.
    auto innerDimsPos = packOp.getInnerDimsPos();
    auto mixedInnerTiles = packOp.getMixedTiles();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    for (int64_t i = srcRank, e = transposePerm.size(); i < e; ++i) {
      int64_t tileIdx = transposePerm[i] - srcRank;
      newMixedInnerTilesVec.push_back(mixedInnerTiles[tileIdx]);
      newInnerDimsPosVec.push_back(innerDimsPos[tileIdx]);
    }

    Value output = packOp.createDestinationTensor(
        rewriter, transposeOp.getLoc(), packOp.getSource(),
        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
    rewriter.replaceOpWithNewOp<PackOp>(
        transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
    return success();
  }
};
/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
/// semantics.
struct FoldConsumerPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
    if (!transposeOp)
      return failure();

    auto perm = transposeOp.getPermutation();
    auto outerDimsPerm = packOp.getOuterDimsPerm();

    // Compose the transpose permutation with the pack's outer_dims_perm.
    SmallVector<int64_t> newOuterDimsPermVec = llvm::to_vector(perm);
    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Can't use applyPermutationToVector for newInnerDimsPosVec since input
    // and permutation rank won't necessarily be equal in all cases.
    SmallVector<int64_t> newInnerDimsPosVec;
    for (int64_t dim : packOp.getInnerDimsPos())
      newInnerDimsPosVec.push_back(perm[dim]);

    Value output = packOp.createDestinationTensor(
        rewriter, packOp.getLoc(), transposeOp.getOperand(0),
        packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
        packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
    return success();
  }
};
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
    : public OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
    if (!unPackOp)
      return failure();

    auto perm = transposeOp.getPermutation();
    auto outerDimsPerm = unPackOp.getOuterDimsPerm();

    // Compose the transpose permutation with the unpack's outer_dims_perm.
    SmallVector<int64_t> newOuterDimsPermVec = llvm::to_vector(perm);
    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Can't use applyPermutationToVector for newInnerDimsPosVec since input
    // and permutation rank won't necessarily be equal in all cases.
    SmallVector<int64_t> newInnerDimsPosVec;
    for (int64_t dim : unPackOp.getInnerDimsPos())
      newInnerDimsPosVec.push_back(perm[dim]);

    Value output = unPackOp.createDestinationTensor(
        rewriter, transposeOp.getLoc(), unPackOp.getSource(),
        unPackOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
    rewriter.replaceOpWithNewOp<UnPackOp>(
        transposeOp, unPackOp.getSource(), output, newInnerDimsPosVec,
        unPackOp.getMixedTiles(), newOuterDimsPermVec);
    return success();
  }
};
/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldConsumerUnPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    auto transposeOp =
        unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
    if (!transposeOp)
      return failure();

    auto perm = transposeOp.getPermutation();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    auto mixedInnerTiles = unPackOp.getMixedTiles();
    // Unpacked (destination) rank: source rank minus the number of tiled
    // inner dimensions.
    int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();

    // The first `destRank` entries of the permutation become the new
    // outer_dims_perm of the unpack.
    SmallVector<int64_t> newOuterDimsPermVec;
    if (!checkAndPermute(perm, unPackOp.getOuterDimsPerm(),
                         newOuterDimsPermVec, destRank))
      return rewriter.notifyMatchFailure(
          unPackOp,
          "Cannot fold in tensor.unpack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // The remaining entries permute the tiled inner dimensions.
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    for (int64_t i = destRank, e = perm.size(); i < e; ++i) {
      int64_t tileIdx = perm[i] - destRank;
      newMixedInnerTilesVec.push_back(mixedInnerTiles[tileIdx]);
      newInnerDimsPosVec.push_back(innerDimsPos[tileIdx]);
    }

    Value output = unPackOp.createDestinationTensor(
        rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
    rewriter.replaceOpWithNewOp<UnPackOp>(
        unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, newOuterDimsPermVec);
    return success();
  }
};
} // namespace
/// Registers the pack/unpack fusion patterns (pad, extract_slice, and
/// linalg.transpose folds) on `patterns`.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
               FoldProducerPackWithConsumerLinalgTransposeOp,
               FoldConsumerPackWithProducerLinalgTransposeOp,
               FoldConsumerUnPackWithProducerLinalgTransposeOp,
               FoldProducerUnPackWithConsumerLinalgTransposeOp>(ctx);
}
/// Registers the pack/unpack-to-reshape simplification patterns on `patterns`.
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(ctx);
}
} // namespace tensor
} // namespace mlir