//===- IntNarrowing.cpp - Integer bitwidth reduction optimizations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>

namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTNARROWING
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith

namespace mlir::arith {
namespace {
//===----------------------------------------------------------------------===//
// Common Helpers
//===----------------------------------------------------------------------===//

/// The base class for integer bitwidth narrowing patterns.
template <typename SourceOp>
struct NarrowingPattern : OpRewritePattern<SourceOp> {
  NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(ctx, benefit),
        supportedBitwidths(options.bitwidthsSupported.begin(),
                           options.bitwidthsSupported.end()) {
    assert(!supportedBitwidths.empty() && "Invalid options");
    assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth");
    llvm::sort(supportedBitwidths);
  }

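  /// Returns the narrowest supported bitwidth that can hold `bitsRequired`
  /// bits, or failure when no supported bitwidth is wide enough.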
  FailureOr<unsigned>
  getNarrowestCompatibleBitwidth(unsigned bitsRequired) const {
    for (unsigned candidate : supportedBitwidths)
      if (candidate >= bitsRequired)
        return candidate;

    return failure();
  }

  /// Returns the narrowest supported type that fits `bitsRequired`.
  FailureOr<Type> getNarrowType(unsigned bitsRequired, Type origTy) const {
    assert(origTy);
    FailureOr<unsigned> bestBitwidth =
        getNarrowestCompatibleBitwidth(bitsRequired);
    if (failed(bestBitwidth))
      return failure();

    Type elemTy = getElementTypeOrSelf(origTy);
    if (!isa<IntegerType>(elemTy))
      return failure();

    auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth);
    if (newElemTy == elemTy)
      return failure();

    if (origTy == elemTy)
      return newElemTy;

    if (auto shapedTy = dyn_cast<ShapedType>(origTy))
      if (isa<IntegerType>(shapedTy.getElementType()))
        return shapedTy.clone(shapedTy.getShape(), newElemTy);

    return failure();
  }

private:
  // Supported integer bitwidths, sorted in ascending order.
  llvm::SmallVector<unsigned, 6> supportedBitwidths;
};

/// Returns the integer bitwidth required to represent `type`.
FailureOr<unsigned> calculateBitsRequired(Type type) {
  assert(type);
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(type)))
    return intTy.getWidth();

  return failure();
}

enum class ExtensionKind { Sign, Zero };

/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts
/// away the exact op type. Exposes helper functions to query the types,
/// operands, and the result. This is so that we can handle both extension
/// kinds without needing to use templates or branching.
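/// A typical use, mirroring the patterns below:
/// ```
/// FailureOr<ExtensionOp> ext = ExtensionOp::from(value.getDefiningOp());
/// if (failed(ext))
///   return failure();
/// ext->recreateAndReplace(rewriter, opToReplace, narrowedValue);
/// ```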
class ExtensionOp {
public:
  /// Attempts to create a new extension op wrapper from `op`. Returns the
  /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure
  /// otherwise.
  static FailureOr<ExtensionOp> from(Operation *op) {
    if (isa_and_nonnull<arith::ExtSIOp>(op))
      return ExtensionOp{op, ExtensionKind::Sign};
    if (isa_and_nonnull<arith::ExtUIOp>(op))
      return ExtensionOp{op, ExtensionKind::Zero};

    return failure();
  }

  ExtensionOp(const ExtensionOp &) = default;
  ExtensionOp &operator=(const ExtensionOp &) = default;

  /// Creates a new extension op of the same kind.
  Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType,
                      Value in) {
    if (kind == ExtensionKind::Sign)
      return rewriter.create<arith::ExtSIOp>(loc, newType, in);

    return rewriter.create<arith::ExtUIOp>(loc, newType, in);
  }

  /// Replaces `toReplace` with a new extension op of the same kind.
  void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
                          Value in) {
    assert(toReplace->getNumResults() == 1);
    Type newType = toReplace->getResult(0).getType();
    Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in);
    rewriter.replaceOp(toReplace, newOp->getResult(0));
  }

  ExtensionKind getKind() { return kind; }

  Value getResult() { return op->getResult(0); }
  Value getIn() { return op->getOperand(0); }

  Type getType() { return getResult().getType(); }
  Type getElementType() { return getElementTypeOrSelf(getType()); }
  Type getInType() { return getIn().getType(); }
  Type getInElementType() { return getElementTypeOrSelf(getInType()); }

private:
  ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
    assert(op);
    assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
  }
  Operation *op = nullptr;
  ExtensionKind kind = {};
};

/// Returns the integer bitwidth required to represent `value`, treating it as
/// unsigned when `lookThroughExtension` is `Zero` and as signed otherwise.
unsigned calculateBitsRequired(const APInt &value,
                               ExtensionKind lookThroughExtension) {
  // For unsigned values, we only need the active bits. As a special case, zero
  // requires one bit.
  if (lookThroughExtension == ExtensionKind::Zero)
    return std::max(value.getActiveBits(), 1u);

  // If a signed value is nonnegative, we need one extra bit for the sign.
  if (value.isNonNegative())
    return value.getActiveBits() + 1;

  // For the signed min, we need all the bits.
  if (value.isMinSignedValue())
    return value.getBitWidth();

  // For negative values, we need all the non-sign bits and one extra bit for
  // the sign.
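  // For example, -2 in i4 is 0b1110 and has three sign bits, so it needs
  // 4 - 3 + 1 == 2 bits (0b10 in i2).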
  return value.getBitWidth() - value.getNumSignBits() + 1;
}

/// Returns the integer bitwidth required to represent `value`.
/// Looks through either sign- or zero-extension as specified by
/// `lookThroughExtension`.
FailureOr<unsigned> calculateBitsRequired(Value value,
                                          ExtensionKind lookThroughExtension) {
  // Handle constants.
  if (TypedAttr attr; matchPattern(value, m_Constant(&attr))) {
    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
      return calculateBitsRequired(intAttr.getValue(), lookThroughExtension);

    if (auto elemsAttr = dyn_cast<DenseElementsAttr>(attr)) {
      if (elemsAttr.getElementType().isIntOrIndex()) {
        if (elemsAttr.isSplat())
          return calculateBitsRequired(elemsAttr.getSplatValue<APInt>(),
                                       lookThroughExtension);

        unsigned maxBits = 1;
        for (const APInt &elemValue : elemsAttr.getValues<APInt>())
          maxBits = std::max(
              maxBits, calculateBitsRequired(elemValue, lookThroughExtension));
        return maxBits;
      }
    }
  }

  if (lookThroughExtension == ExtensionKind::Sign) {
    if (auto sext = value.getDefiningOp<arith::ExtSIOp>())
      return calculateBitsRequired(sext.getIn().getType());
  } else if (lookThroughExtension == ExtensionKind::Zero) {
    if (auto zext = value.getDefiningOp<arith::ExtUIOp>())
      return calculateBitsRequired(zext.getIn().getType());
  }

  // If nothing else worked, fall back to the bitwidth of the value's type.
  return calculateBitsRequired(value.getType());
}

/// Base pattern for arith binary ops.
/// Example:
/// ```
/// %lhs = arith.extsi %a : i8 to i32
/// %rhs = arith.extsi %b : i8 to i32
/// %r = arith.addi %lhs, %rhs : i32
/// ==>
/// %lhs = arith.extsi %a : i8 to i16
/// %rhs = arith.extsi %b : i8 to i16
/// %add = arith.addi %lhs, %rhs : i16
/// %r = arith.extsi %add : i16 to i32
/// ```
template <typename BinaryOp>
struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
  using NarrowingPattern<BinaryOp>::NarrowingPattern;

  /// Returns the number of bits required to represent the full result,
  /// assuming that both operands are `operandBits`-wide. Derived classes must
  /// implement this, taking into account `BinaryOp` semantics.
  virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;

  /// Customization point for patterns that should apply only to a specific
  /// extension kind (sign or zero) of the arguments.
  virtual bool isSupported(ExtensionOp) const { return true; }

  LogicalResult matchAndRewrite(BinaryOp op,
                                PatternRewriter &rewriter) const final {
    Type origTy = op.getType();
    FailureOr<unsigned> resultBits = calculateBitsRequired(origTy);
    if (failed(resultBits))
      return failure();

    // For the optimization to apply, we expect the lhs to be an extension op,
    // and the rhs to be either an extension op of the same kind or a constant.
    FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp());
    if (failed(ext) || !isSupported(*ext))
      return failure();

    FailureOr<unsigned> lhsBitsRequired =
        calculateBitsRequired(ext->getIn(), ext->getKind());
    if (failed(lhsBitsRequired) || *lhsBitsRequired >= *resultBits)
      return failure();

    FailureOr<unsigned> rhsBitsRequired =
        calculateBitsRequired(op.getRhs(), ext->getKind());
    if (failed(rhsBitsRequired) || *rhsBitsRequired >= *resultBits)
      return failure();

    // Negotiate a common bit requirement for both lhs and rhs, accounting for
    // the result requiring more bits than the operands.
    unsigned commonBitsRequired =
        getResultBitsProduced(std::max(*lhsBitsRequired, *rhsBitsRequired));
    FailureOr<Type> narrowTy = this->getNarrowType(commonBitsRequired, origTy);
    if (failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *resultBits)
      return failure();

    Location loc = op.getLoc();
    Value newLhs =
        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
    Value newRhs =
        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
    Value newBinOp = rewriter.create<BinaryOp>(loc, newLhs, newRhs);
    ext->recreateAndReplace(rewriter, op, newBinOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AddIOp Pattern
//===----------------------------------------------------------------------===//

struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // Addition may require one extra bit for the result.
  // Example: `UINT8_MAX + 1 == 255 + 1 == 256`.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits + 1;
  }
};

//===----------------------------------------------------------------------===//
// SubIOp Pattern
//===----------------------------------------------------------------------===//

struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // This optimization only applies to signed arguments.
  bool isSupported(ExtensionOp ext) const override {
    return ext.getKind() == ExtensionKind::Sign;
  }

  // Subtraction may require one extra bit for the result.
  // Example: `INT8_MAX - (-1) == 127 - (-1) == 128`.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits + 1;
  }
};

//===----------------------------------------------------------------------===//
// MulIOp Pattern
//===----------------------------------------------------------------------===//

struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // Multiplication may require up to double the operand bits.
  // Example: `UINT8_MAX * UINT8_MAX == 255 * 255 == 65025`.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return 2 * operandBits;
  }
};

//===----------------------------------------------------------------------===//
// DivSIOp Pattern
//===----------------------------------------------------------------------===//

struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // This optimization only applies to signed arguments.
  bool isSupported(ExtensionOp ext) const override {
    return ext.getKind() == ExtensionKind::Sign;
  }

  // Unlike multiplication, signed division requires only one extra result bit.
  // Example: `INT8_MIN / (-1) == -128 / (-1) == 128`.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits + 1;
  }
};

//===----------------------------------------------------------------------===//
// DivUIOp Pattern
//===----------------------------------------------------------------------===//

struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // This optimization only applies to unsigned arguments.
  bool isSupported(ExtensionOp ext) const override {
    return ext.getKind() == ExtensionKind::Zero;
  }

  // Unsigned division does not require any extra result bits.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits;
  }
};

//===----------------------------------------------------------------------===//
// Min/Max Patterns
//===----------------------------------------------------------------------===//

template <typename MinMaxOp, ExtensionKind Kind>
struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> {
  using BinaryOpNarrowingPattern<MinMaxOp>::BinaryOpNarrowingPattern;

  bool isSupported(ExtensionOp ext) const override {
    return ext.getKind() == Kind;
  }

  // Min/max returns one of the arguments and does not require any extra result
  // bits.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits;
  }
};
using MaxSIPattern = MinMaxPattern<arith::MaxSIOp, ExtensionKind::Sign>;
using MaxUIPattern = MinMaxPattern<arith::MaxUIOp, ExtensionKind::Zero>;
using MinSIPattern = MinMaxPattern<arith::MinSIOp, ExtensionKind::Sign>;
using MinUIPattern = MinMaxPattern<arith::MinUIOp, ExtensionKind::Zero>;

//===----------------------------------------------------------------------===//
// *IToFPOp Patterns
//===----------------------------------------------------------------------===//

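/// Narrows the integer input of an `arith.sitofp` or `arith.uitofp` before the
/// conversion. For example (assuming i8 is a supported bitwidth):
/// ```
/// %ext = arith.extsi %a : i8 to i32
/// %r = arith.sitofp %ext : i32 to f32
/// ==>
/// %r = arith.sitofp %a : i8 to f32
/// ```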
template <typename IToFPOp, ExtensionKind Extension>
struct IToFPPattern final : NarrowingPattern<IToFPOp> {
  using NarrowingPattern<IToFPOp>::NarrowingPattern;

  LogicalResult matchAndRewrite(IToFPOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<unsigned> narrowestWidth =
        calculateBitsRequired(op.getIn(), Extension);
    if (failed(narrowestWidth))
      return failure();

    FailureOr<Type> narrowTy =
        this->getNarrowType(*narrowestWidth, op.getIn().getType());
    if (failed(narrowTy))
      return failure();

    Value newIn = rewriter.createOrFold<arith::TruncIOp>(op.getLoc(),
                                                         *narrowTy, op.getIn());
    rewriter.replaceOpWithNewOp<IToFPOp>(op, op.getType(), newIn);
    return success();
  }
};
using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;

//===----------------------------------------------------------------------===//
// Index Cast Patterns
//===----------------------------------------------------------------------===//

// These patterns rely on the `ValueBoundsOpInterface` to compute constant
// bounds of index values. For example, the bounds of loop induction variables
// are often statically known.

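/// Example (assuming the value bounds analysis proves that `%idx` fits in i8
/// and that i8 is a supported bitwidth):
/// ```
/// %r = arith.index_cast %idx : index to i32
/// ==>
/// %c = arith.index_cast %idx : index to i8
/// %r = arith.extsi %c : i8 to i32
/// ```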
template <typename CastOp, ExtensionKind Kind>
struct IndexCastPattern final : NarrowingPattern<CastOp> {
  using NarrowingPattern<CastOp>::NarrowingPattern;

  LogicalResult matchAndRewrite(CastOp op,
                                PatternRewriter &rewriter) const override {
    Value in = op.getIn();
    // We only support scalar index -> integer casts.
    if (!isa<IndexType>(in.getType()))
      return failure();

    // Check the lower bound in both the signed and unsigned cast cases. We
    // conservatively assume that even unsigned casts may be performed on
    // negative indices.
    FailureOr<int64_t> lb = ValueBoundsConstraintSet::computeConstantBound(
        presburger::BoundType::LB, in);
    if (failed(lb))
      return failure();

    FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
        presburger::BoundType::UB, in,
        /*stopCondition=*/nullptr, /*closedUB=*/true);
    if (failed(ub))
      return failure();

    assert(*lb <= *ub && "Invalid bounds");
    unsigned lbBitsRequired = calculateBitsRequired(APInt(64, *lb), Kind);
    unsigned ubBitsRequired = calculateBitsRequired(APInt(64, *ub), Kind);
    unsigned bitsRequired = std::max(lbBitsRequired, ubBitsRequired);

    IntegerType resultTy = cast<IntegerType>(op.getType());
    if (resultTy.getWidth() <= bitsRequired)
      return failure();

    FailureOr<Type> narrowTy = this->getNarrowType(bitsRequired, resultTy);
    if (failed(narrowTy))
      return failure();

    Value newCast = rewriter.create<CastOp>(op.getLoc(), *narrowTy, op.getIn());

    if (Kind == ExtensionKind::Sign)
      rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, resultTy, newCast);
    else
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, resultTy, newCast);
    return success();
  }
};
using IndexCastSIPattern =
    IndexCastPattern<arith::IndexCastOp, ExtensionKind::Sign>;
using IndexCastUIPattern =
    IndexCastPattern<arith::IndexCastUIOp, ExtensionKind::Zero>;

//===----------------------------------------------------------------------===//
// Patterns to Commute Extension Ops
//===----------------------------------------------------------------------===//

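/// Commutes an extension over a `vector.broadcast`. For example:
/// ```
/// %ext = arith.extsi %a : i8 to i32
/// %r = vector.broadcast %ext : i32 to vector<4xi32>
/// ==>
/// %b = vector.broadcast %a : i8 to vector<4xi8>
/// %r = arith.extsi %b : vector<4xi8> to vector<4xi32>
/// ```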
struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getSource().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getResultVectorType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newBroadcast =
        rewriter.create<vector::BroadcastOp>(op.getLoc(), newTy, ext->getIn());
    ext->recreateAndReplace(rewriter, op, newBroadcast);
    return success();
  }
};

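/// Commutes an extension over a `vector.extract`. Illustrative example:
/// ```
/// %ext = arith.extsi %v : vector<4xi8> to vector<4xi32>
/// %r = vector.extract %ext[1] : vector<4xi32>
/// ==>
/// %e = vector.extract %v[1] : vector<4xi8>
/// %r = arith.extsi %e : i8 to i32
/// ```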
struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    Value newExtract = rewriter.create<vector::ExtractOp>(
        op.getLoc(), ext->getIn(), op.getMixedPosition());
    ext->recreateAndReplace(rewriter, op, newExtract);
    return success();
  }
};

struct ExtensionOverExtractElement final
    : NarrowingPattern<vector::ExtractElementOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ExtractElementOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    Value newExtract = rewriter.create<vector::ExtractElementOp>(
        op.getLoc(), ext->getIn(), op.getPosition());
    ext->recreateAndReplace(rewriter, op, newExtract);
    return success();
  }
};

struct ExtensionOverExtractStridedSlice final
    : NarrowingPattern<vector::ExtractStridedSliceOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getType();
    VectorType extractTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
        op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
        op.getStrides());
    ext->recreateAndReplace(rewriter, op, newExtract);
    return success();
  }
};

/// Base pattern for narrowing patterns over vector insertion ops
/// (`vector.insert`, `vector.insertelement`, `vector.insert_strided_slice`).
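/// Example with `vector.insert`, assuming both the inserted value and the
/// truncated destination fit in i8 (illustrative):
/// ```
/// %ext = arith.extsi %a : i8 to i32
/// %r = vector.insert %ext, %dst[0] : i32 into vector<4xi32>
/// ==>
/// %d = arith.trunci %dst : vector<4xi32> to vector<4xi8>
/// %i = vector.insert %a, %d[0] : i8 into vector<4xi8>
/// %r = arith.extsi %i : vector<4xi8> to vector<4xi32>
/// ```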
template <typename InsertionOp>
struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> {
  using NarrowingPattern<InsertionOp>::NarrowingPattern;

  /// Derived classes must provide a function to create the matching insertion
  /// op based on the original op and new arguments.
  virtual InsertionOp createInsertionOp(PatternRewriter &rewriter,
                                        InsertionOp origInsert,
                                        Value narrowValue,
                                        Value narrowDest) const = 0;

  LogicalResult matchAndRewrite(InsertionOp op,
                                PatternRewriter &rewriter) const final {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getSource().getDefiningOp());
    if (failed(ext))
      return failure();

    FailureOr<InsertionOp> newInsert = createNarrowInsert(op, rewriter, *ext);
    if (failed(newInsert))
      return failure();
    ext->recreateAndReplace(rewriter, op, *newInsert);
    return success();
  }

  FailureOr<InsertionOp> createNarrowInsert(InsertionOp op,
                                            PatternRewriter &rewriter,
                                            ExtensionOp insValue) const {
    // Calculate the operand and result bitwidths. We can only apply narrowing
    // when the inserted source value and destination vector require fewer bits
    // than the result. Because the source and destination may have different
    // bitwidth requirements, we have to find a common narrow bitwidth that is
    // greater than or equal to both operand bitwidth requirements and still
    // narrower than the result.
    FailureOr<unsigned> origBitsRequired = calculateBitsRequired(op.getType());
    if (failed(origBitsRequired))
      return failure();

    // TODO: We could relax this check by disregarding bitwidth requirements of
    // elements that we know will be replaced by the insertion.
    FailureOr<unsigned> destBitsRequired =
        calculateBitsRequired(op.getDest(), insValue.getKind());
    if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
      return failure();

    FailureOr<unsigned> insertedBitsRequired =
        calculateBitsRequired(insValue.getIn(), insValue.getKind());
    if (failed(insertedBitsRequired) ||
        *insertedBitsRequired >= *origBitsRequired)
      return failure();

    // Find a narrower element type that satisfies the bitwidth requirements of
    // both the source and the destination values.
    unsigned newInsertionBits =
        std::max(*destBitsRequired, *insertedBitsRequired);
    FailureOr<Type> newVecTy =
        this->getNarrowType(newInsertionBits, op.getType());
    if (failed(newVecTy) || *newVecTy == op.getType())
      return failure();

    FailureOr<Type> newInsertedValueTy =
        this->getNarrowType(newInsertionBits, insValue.getType());
    if (failed(newInsertedValueTy))
      return failure();

    Location loc = op.getLoc();
    Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
        loc, *newInsertedValueTy, insValue.getResult());
    Value narrowDest =
        rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
    return createInsertionOp(rewriter, op, narrowValue, narrowDest);
  }
};

struct ExtensionOverInsert final
    : ExtensionOverInsertionPattern<vector::InsertOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  vector::InsertOp createInsertionOp(PatternRewriter &rewriter,
                                     vector::InsertOp origInsert,
                                     Value narrowValue,
                                     Value narrowDest) const override {
    return rewriter.create<vector::InsertOp>(origInsert.getLoc(), narrowValue,
                                             narrowDest,
                                             origInsert.getMixedPosition());
  }
};

struct ExtensionOverInsertElement final
    : ExtensionOverInsertionPattern<vector::InsertElementOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter,
                                            vector::InsertElementOp origInsert,
                                            Value narrowValue,
                                            Value narrowDest) const override {
    return rewriter.create<vector::InsertElementOp>(
        origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
  }
};

struct ExtensionOverInsertStridedSlice final
    : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  vector::InsertStridedSliceOp
  createInsertionOp(PatternRewriter &rewriter,
                    vector::InsertStridedSliceOp origInsert, Value narrowValue,
                    Value narrowDest) const override {
    return rewriter.create<vector::InsertStridedSliceOp>(
        origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(),
        origInsert.getStrides());
  }
};

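/// Commutes an extension over a `vector.shape_cast`. For example:
/// ```
/// %ext = arith.extsi %v : vector<4x2xi8> to vector<4x2xi32>
/// %r = vector.shape_cast %ext : vector<4x2xi32> to vector<8xi32>
/// ==>
/// %c = vector.shape_cast %v : vector<4x2xi8> to vector<8xi8>
/// %r = arith.extsi %c : vector<8xi8> to vector<8xi32>
/// ```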
struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getSource().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getResultVectorType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newCast =
        rewriter.create<vector::ShapeCastOp>(op.getLoc(), newTy, ext->getIn());
    ext->recreateAndReplace(rewriter, op, newCast);
    return success();
  }
};

struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getResultVectorType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newTranspose = rewriter.create<vector::TransposeOp>(
        op.getLoc(), newTy, ext->getIn(), op.getPermutation());
    ext->recreateAndReplace(rewriter, op, newTranspose);
    return success();
  }
};

struct ExtensionOverFlatTranspose final
    : NarrowingPattern<vector::FlatTransposeOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getMatrix().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newTranspose = rewriter.create<vector::FlatTransposeOp>(
        op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(),
        op.getColumnsAttr());
    ext->recreateAndReplace(rewriter, op, newTranspose);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass Definitions
//===----------------------------------------------------------------------===//

struct ArithIntNarrowingPass final
    : impl::ArithIntNarrowingBase<ArithIntNarrowingPass> {
  using ArithIntNarrowingBase::ArithIntNarrowingBase;

  void runOnOperation() override {
    if (bitwidthsSupported.empty() ||
        llvm::is_contained(bitwidthsSupported, 0)) {
      // Invalid pass options.
      return signalPassFailure();
    }

    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    RewritePatternSet patterns(ctx);
    populateArithIntNarrowingPatterns(
        patterns, ArithIntNarrowingOptions{bitwidthsSupported});
    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Public API
//===----------------------------------------------------------------------===//

void populateArithIntNarrowingPatterns(
    RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
  // Add commute patterns with a higher benefit so that they apply first. This
  // exposes more optimization opportunities to the narrowing patterns.
  patterns.add<ExtensionOverBroadcast, ExtensionOverExtract,
               ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
               ExtensionOverInsert, ExtensionOverInsertElement,
               ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
               ExtensionOverTranspose, ExtensionOverFlatTranspose>(
      patterns.getContext(), options, PatternBenefit(2));

  patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
               DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
               MinUIPattern, SIToFPPattern, UIToFPPattern, IndexCastSIPattern,
               IndexCastUIPattern>(patterns.getContext(), options);
}

} // namespace mlir::arith