//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;
/// Conversion patterns.
namespace {
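/// Converts `shape.any` by forwarding one of its operands; all operands are
/// equally valid results, so the first one is simply chosen below.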
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
using OpConversionPattern<AnyOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
AnyOp::Adaptor transformed(operands);
// Replace `any` with its first operand.
// Any operand would be a valid substitution.
rewriter.replaceOp(op, {transformed.inputs().front()});
return success();
}

namespace {
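/// Converts error-free binary shape ops operating on `index` values to their
/// standard-dialect counterparts, e.g. `shape.add` to `addi` and `shape.mul`
/// to `muli` (see the pattern registration at the bottom of this file).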
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
using OpConversionPattern<SrcOpTy>::OpConversionPattern;
LogicalResult
matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
typename SrcOpTy::Adaptor transformed(operands);
// For now, only error-free types are supported by this lowering.
if (op.getType().template isa<SizeType>())
return failure();
rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
transformed.rhs());
return success();
}
};
} // namespace

namespace {
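/// Converts `shape.broadcast` on extent tensors to a
/// `dynamic_tensor_from_elements` op whose single dynamic extent is the
/// greater of the two operand ranks. Each resulting extent is taken from the
/// greater-rank operand, unless that extent is 1, in which case the aligned
/// extent of the lesser-rank operand is used. For example, broadcasting the
/// extent tensors [2, 1] and [4] yields [2, 4].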
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
using OpConversionPattern<BroadcastOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
BroadcastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// For now, this lowering is only defined on `tensor<?xindex>` operands, not
// on shapes.
if (op.getType().isa<ShapeType>())
return failure();
assert(!op.lhs().getType().isa<ShapeType>() &&
!op.rhs().getType().isa<ShapeType>());
auto loc = op.getLoc();
BroadcastOp::Adaptor transformed(operands);
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
// Find the smaller and the greater rank, along with the corresponding extent
// tensors.
Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
Value lhsRankULE =
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
Type indexTy = rewriter.getIndexType();
Value lesserRank =
rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
Value greaterRank =
rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
auto erasedRankType =
RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
Value rankErasedLhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
Value rankErasedRhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
Value lesserRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
Value greaterRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
Value rankDiff =
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
[&](OpBuilder &b, Location loc, ValueRange args) {
Value outputDimension = args[0];
Value isUnchallengedDimension = b.create<CmpIOp>(
loc, CmpIPredicate::ult, outputDimension, rankDiff);
Value greaterRankOperandExtent = b.create<ExtractElementOp>(
loc, greaterRankOperand, outputDimension);
// The initial dimensions of the greater-rank operand are unchallenged,
// so we can take them as-is. Otherwise, we need to do a comparison.
// We need an actual branch here (instead of a select) because the
// lesser-rank operand might be rank 0, so any extract_element would be
// invalid.
auto ifOp = b.create<IfOp>(
loc, TypeRange{indexTy}, isUnchallengedDimension,
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
},
[&](OpBuilder &b, Location loc) {
// The broadcasting logic is:
// - if one extent (here we arbitrarily choose the extent from
// the greater-rank operand) is equal to 1, then take the extent
// from the other operand
// - otherwise, take the extent as-is.
// Note that this logic remains correct in the presence of
// dimensions of zero extent.
Value lesserRankOperandDimension =
b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
loc, lesserRankOperand,
ValueRange{lesserRankOperandDimension});
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
Value broadcastedExtent = b.create<SelectOp>(
loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
greaterRankOperandExtent);
b.create<scf::YieldOp>(loc, broadcastedExtent);
});
b.create<mlir::YieldOp>(loc, ifOp.getResult(0));
});
return success();
}

namespace {
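/// Converts `shape.const_shape` to a `tensor_from_elements` op of constant
/// indices, cast to `tensor<?xindex>`. Schematically, `shape.const_shape
/// [1, 2]` lowers to something like:
///
///   %c1 = constant 1 : index
///   %c2 = constant 2 : index
///   %tensor = tensor_from_elements %c1, %c2 : tensor<2xindex>
///   %result = tensor_cast %tensor : tensor<2xindex> to tensor<?xindex>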
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
ConstShapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// For now, this lowering supports only extent tensors, not `shape.shape`
// types.
if (op.getType().isa<ShapeType>())
return failure();
auto loc = op.getLoc();
SmallVector<Value, 4> extentOperands;
for (auto extent : op.shape()) {
extentOperands.push_back(
rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
}
Type indexTy = rewriter.getIndexType();
Value tensor =
rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
return success();
}

namespace {
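/// Converts `shape.const_size` to a `constant` op of index type.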
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
ConstSizeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
return success();
}

namespace {
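/// Converts `shape.is_broadcastable` to an `scf.for` loop that reduces over
/// the overlapping extents of both operands: the shapes are broadcastable iff
/// every such pair of extents is equal or contains a 1. The leading extents
/// of the greater-rank operand need no check, as they always broadcast.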
struct IsBroadcastableOpConverter
: public OpConversionPattern<IsBroadcastableOp> {
using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
IsBroadcastableOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// For now, this lowering is only defined on `tensor<?xindex>` operands, not
// on shapes.
IsBroadcastableOp::Adaptor transformed(operands);
if (transformed.lhs().getType().isa<ShapeType>() ||
transformed.rhs().getType().isa<ShapeType>())
return failure();
auto loc = op.getLoc();
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
// Find the smaller and the greater rank, along with the corresponding extent
// tensors.
Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
Value lhsRankULE =
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
Type indexTy = rewriter.getIndexType();
Value lesserRank =
rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
Value greaterRank =
rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
auto erasedRankType =
RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
Value rankErasedLhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
Value rankErasedRhs =
rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
Value lesserRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
Value greaterRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
Value rankDiff =
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
Type i1Ty = rewriter.getI1Type();
Value init =
rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
// Determine if all overlapping extents are broadcastable.
auto reduceResult = rewriter.create<ForOp>(
loc, rankDiff, greaterRank, one, ValueRange{init},
[&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
Value greaterRankOperandExtent =
b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
loc, lesserRankOperand, ValueRange{ivShifted});
Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
Value extentsAreEqual =
b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
lesserRankOperandExtent);
Value broadcastableExtents = b.create<AndOp>(
loc, iterArgs[0],
b.create<OrOp>(loc,
b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
lesserRankOperandExtentIsOne),
extentsAreEqual));
b.create<scf::YieldOp>(loc, broadcastableExtents);
});
rewriter.replaceOp(op, reduceResult.results().front());
return success();
}

namespace {
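/// Converts `shape.get_extent` to `extract_element`, or directly to a `dim`
/// op on the original tensor when the shape operand stems from a
/// `shape.shape_of` op.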
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
using OpConversionPattern<GetExtentOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
GetExtentOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
GetExtentOp::Adaptor transformed(operands);
// For now, only error-free types are supported by this lowering.
if (op.getType().isa<SizeType>())
return failure();
// If the shape originates from a `shape.shape_of` op, derive the extent
// directly from the shaped operand; this avoids materializing the shape in
// memory.
if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
if (shapeOfOp.arg().getType().isa<ShapedType>()) {
rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
transformed.dim());
return success();
}
}
rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
transformed.shape(),
ValueRange{transformed.dim()});
return success();
}

namespace {
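/// Converts `shape.rank` to a `dim` op on the extent tensor: the length of
/// the extent tensor is the rank it describes.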
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
using OpConversionPattern<shape::RankOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// For now, this lowering supports only error-free types.
if (op.getType().isa<SizeType>())
return failure();
shape::RankOp::Adaptor transformed(operands);
rewriter.replaceOpWithNewOp<DimOp>(op, transformed.shape(), 0);
return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
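///
/// For example, a reduction that multiplies all extents,
///
///   %nelems = shape.reduce(%shape, %c1) : tensor<?xindex> -> index {
///     ^bb0(%i : index, %extent : index, %acc : index):
///       %next = muli %acc, %extent : index
///       shape.yield %next : index
///   }
///
/// becomes roughly an `scf.for` from 0 to the rank that extracts each extent
/// with `extract_element` and clones the reduction body into the loop.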
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// For now, this lowering is only defined on `tensor<?xindex>` operands.
if (op.shape().getType().isa<ShapeType>())
return failure();
auto loc = op.getLoc();
shape::ReduceOp::Adaptor transformed(operands);
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
Type indexTy = rewriter.getIndexType();
Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
auto loop = rewriter.create<scf::ForOp>(
loc, zero, rank, one, op.initVals(),
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
SmallVector<Value, 2> mappedValues{iv, extent};
mappedValues.append(args.begin(), args.end());
BlockAndValueMapping mapping;
Block *reduceBody = op.getBody();
mapping.map(reduceBody->getArguments(), mappedValues);
for (auto &nested : reduceBody->without_terminator())
b.clone(nested, mapping);
SmallVector<Value, 2> mappedResults;
for (auto result : reduceBody->getTerminator()->getOperands())
mappedResults.push_back(mapping.lookup(result));
b.create<scf::YieldOp>(loc, mappedResults);
});
rewriter.replaceOp(op, loop.getResults());
return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares the ranks and, if they are equal, checks every extent pairwise
/// for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
/// %c1 = constant 1 : index
/// %true = constant true
/// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
/// %5 = extract_element %arg0[%arg2] : tensor<?xindex>
/// %6 = extract_element %arg1[%arg2] : tensor<?xindex>
/// %7 = cmpi "eq", %5, %6 : index
/// %8 = and %arg3, %7 : i1
/// scf.yield %8 : i1
/// }
/// scf.yield %4 : i1
/// } else {
/// %false = constant false
/// scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// For now, this lowering is only defined on `tensor<?xindex>` operands, not
// on shapes.
if (op.lhs().getType().isa<ShapeType>() ||
op.rhs().getType().isa<ShapeType>()) {
return failure();
}
ShapeEqOp::Adaptor transformed(operands);
auto loc = op.getLoc();
Type indexTy = rewriter.getIndexType();
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
Value eqRank =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
Type i1Ty = rewriter.getI1Type();
rewriter.replaceOpWithNewOp<IfOp>(
op, i1Ty, eqRank,
[&](OpBuilder &b, Location loc) {
Value one = b.create<ConstantIndexOp>(loc, 1);
Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
auto loop = b.create<scf::ForOp>(
loc, zero, lhsRank, one, ValueRange{init},
[&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
Value conj = args[0];
Value lhsExtent =
b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
Value rhsExtent =
b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
lhsExtent, rhsExtent);
Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
});
b.create<scf::YieldOp>(loc, loop.getResults());
},
[&](OpBuilder &b, Location loc) {
Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
b.create<scf::YieldOp>(loc, result);
});
return success();
}

namespace {
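/// Converts `shape.shape_of`. For ranked tensor arguments, every extent is
/// either known statically or obtainable with a `dim` op, so the shape is
/// materialized with `tensor_from_elements`; unranked arguments fall back to
/// a `dynamic_tensor_from_elements` op over the dynamic rank.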
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
ShapeOfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// For now, only error-free types are supported by this lowering.
if (op.getType().isa<ShapeType>())
return failure();
// For ranked tensor arguments, lower to `tensor_from_elements`.
auto loc = op.getLoc();
ShapeOfOp::Adaptor transformed(operands);
Value tensor = transformed.arg();
Type tensorTy = tensor.getType();
if (tensorTy.isa<RankedTensorType>()) {
// Build values for individual extents.
SmallVector<Value, 8> extentValues;
RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
int64_t rank = rankedTensorTy.getRank();
for (int64_t i = 0; i < rank; i++) {
if (rankedTensorTy.isDynamicDim(i)) {
Value extent = rewriter.create<DimOp>(loc, tensor, i);
extentValues.push_back(extent);
} else {
Value extent =
rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
extentValues.push_back(extent);
}
}
// Materialize extent tensor.
Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
loc, rewriter.getIndexType(), extentValues);
rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
op.getType());
return success();
}
// Lower to `dynamic_tensor_from_elements` otherwise.
auto *ctx = rewriter.getContext();
Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
op, getExtentTensorType(ctx), ValueRange{rank},
[&](OpBuilder &b, Location loc, ValueRange args) {
Value dim = args.front();
Value extent = b.create<DimOp>(loc, tensor, dim);
b.create<mlir::YieldOp>(loc, extent);
});
return success();
}

namespace {
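/// Converts `shape.to_extent_tensor` to a `tensor_cast`, a metadata-only
/// operation at runtime. The pattern applies to ranked tensor inputs only.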
class ToExtentTensorOpConversion
: public OpConversionPattern<ToExtentTensorOp> {
public:
using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
ToExtentTensorOpAdaptor adaptor(operands);
if (!adaptor.input().getType().isa<RankedTensorType>())
return rewriter.notifyMatchFailure(op, "input needs to be a ranked tensor");
rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
op.getType());
return success();
}
};
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
: public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
// Set up target legality.
MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
target.addLegalDialect<StandardOpsDialect, SCFDialect>();
target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
// Set up conversion patterns.
OwningRewritePatternList patterns;
populateShapeToStandardConversionPatterns(patterns, &ctx);
// Apply conversion.
auto module = getOperation();
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
// clang-format off
patterns.insert<
AnyOpConversion,
BinaryOpConversion<AddOp, AddIOp>,
BinaryOpConversion<MulOp, MulIOp>,
BroadcastOpConverter,
ConstShapeOpConverter,
ConstSizeOpConversion,
IsBroadcastableOpConverter,
GetExtentOpConverter,
RankOpConverter,
ReduceOpConverter,
ShapeEqOpConverter,
ShapeOfOpConversion,
ToExtentTensorOpConversion>(ctx);
// clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
return std::make_unique<ConvertShapeToStandardPass>();
}