| //===- TestMathToVCIXConversion.cpp - Test conversion to VCIX ops ---------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/LLVMIR/VCIXDialect.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Pass/PassManager.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| namespace mlir { |
| namespace { |
| |
| /// Return number of extracts required to make input VectorType \vt legal and |
| /// also return thatlegal vector type. |
| /// For fixed vectors nothing special is needed. Scalable vectors are legalizes |
| /// according to LLVM's encoding: |
| /// https://lists.llvm.org/pipermail/llvm-dev/2020-October/145850.html |
| static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) { |
| VectorType vt = cast<VectorType>(type); |
| // To simplify test pass, avoid multi-dimensional vectors. |
| if (!vt || vt.getRank() != 1) |
| return {0, nullptr}; |
| |
| if (!vt.isScalable()) |
| return {1, vt}; |
| |
| Type eltTy = vt.getElementType(); |
| unsigned sew = 0; |
| if (eltTy.isF32()) |
| sew = 32; |
| else if (eltTy.isF64()) |
| sew = 64; |
| else if (auto intTy = dyn_cast<IntegerType>(eltTy)) |
| sew = intTy.getWidth(); |
| else |
| return {0, nullptr}; |
| |
| unsigned eltCount = vt.getShape()[0]; |
| const unsigned lmul = eltCount * sew / 64; |
| |
| unsigned n = lmul > 8 ? llvm::Log2_32(lmul) - 2 : 1; |
| return {n, VectorType::get({eltCount >> (n - 1)}, eltTy, {true})}; |
| } |
| |
| /// Replace math.cos(v) operation with vcix.v.iv(v). |
| struct MathCosToVCIX final : OpRewritePattern<math::CosOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(math::CosOp op, |
| PatternRewriter &rewriter) const override { |
| const Type opType = op.getOperand().getType(); |
| auto [n, legalType] = legalizeVectorType(opType); |
| if (!legalType) |
| return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV"); |
| Location loc = op.getLoc(); |
| Value vec = op.getOperand(); |
| Attribute immAttr = rewriter.getI32IntegerAttr(0); |
| Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); |
| Value rvl = nullptr; |
| if (legalType.isScalable()) |
| // Use arbitrary runtime vector length when vector type is scalable. |
| // Proper conversion pass should take it from the IR. |
| rvl = rewriter.create<arith::ConstantOp>(loc, |
| rewriter.getI64IntegerAttr(9)); |
| Value res; |
| if (n == 1) { |
| res = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, vec, |
| immAttr, rvl); |
| } else { |
| const unsigned eltCount = legalType.getShape()[0]; |
| Type eltTy = legalType.getElementType(); |
| Value zero = rewriter.create<arith::ConstantOp>( |
| loc, eltTy, rewriter.getZeroAttr(eltTy)); |
| res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); |
| for (unsigned i = 0; i < n; ++i) { |
| Value extracted = rewriter.create<vector::ScalableExtractOp>( |
| loc, legalType, vec, i * eltCount); |
| Value v = rewriter.create<vcix::BinaryImmOp>(loc, legalType, opcodeAttr, |
| extracted, immAttr, rvl); |
| res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, |
| i * eltCount); |
| } |
| } |
| rewriter.replaceOp(op, res); |
| return success(); |
| } |
| }; |
| |
| // Replace math.sin(v) operation with vcix.v.sv(v, v). |
| struct MathSinToVCIX final : OpRewritePattern<math::SinOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(math::SinOp op, |
| PatternRewriter &rewriter) const override { |
| const Type opType = op.getOperand().getType(); |
| auto [n, legalType] = legalizeVectorType(opType); |
| if (!legalType) |
| return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV"); |
| Location loc = op.getLoc(); |
| Value vec = op.getOperand(); |
| Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); |
| Value rvl = nullptr; |
| if (legalType.isScalable()) |
| // Use arbitrary runtime vector length when vector type is scalable. |
| // Proper conversion pass should take it from the IR. |
| rvl = rewriter.create<arith::ConstantOp>(loc, |
| rewriter.getI64IntegerAttr(9)); |
| Value res; |
| if (n == 1) { |
| res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, |
| vec, rvl); |
| } else { |
| const unsigned eltCount = legalType.getShape()[0]; |
| Type eltTy = legalType.getElementType(); |
| Value zero = rewriter.create<arith::ConstantOp>( |
| loc, eltTy, rewriter.getZeroAttr(eltTy)); |
| res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); |
| for (unsigned i = 0; i < n; ++i) { |
| Value extracted = rewriter.create<vector::ScalableExtractOp>( |
| loc, legalType, vec, i * eltCount); |
| Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, |
| extracted, extracted, rvl); |
| res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, |
| i * eltCount); |
| } |
| } |
| rewriter.replaceOp(op, res); |
| return success(); |
| } |
| }; |
| |
| // Replace math.tan(v) operation with vcix.v.sv(v, 0.0f). |
| struct MathTanToVCIX final : OpRewritePattern<math::TanOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(math::TanOp op, |
| PatternRewriter &rewriter) const override { |
| const Type opType = op.getOperand().getType(); |
| auto [n, legalType] = legalizeVectorType(opType); |
| Type eltTy = legalType.getElementType(); |
| if (!legalType) |
| return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV"); |
| Location loc = op.getLoc(); |
| Value vec = op.getOperand(); |
| Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); |
| Value zero = rewriter.create<arith::ConstantOp>( |
| loc, eltTy, rewriter.getZeroAttr(eltTy)); |
| Value rvl = nullptr; |
| if (legalType.isScalable()) |
| // Use arbitrary runtime vector length when vector type is scalable. |
| // Proper conversion pass should take it from the IR. |
| rvl = rewriter.create<arith::ConstantOp>(loc, |
| rewriter.getI64IntegerAttr(9)); |
| Value res; |
| if (n == 1) { |
| res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, |
| zero, rvl); |
| } else { |
| const unsigned eltCount = legalType.getShape()[0]; |
| res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); |
| for (unsigned i = 0; i < n; ++i) { |
| Value extracted = rewriter.create<vector::ScalableExtractOp>( |
| loc, legalType, vec, i * eltCount); |
| Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, |
| extracted, zero, rvl); |
| res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, |
| i * eltCount); |
| } |
| } |
| rewriter.replaceOp(op, res); |
| return success(); |
| } |
| }; |
| |
| // Replace math.log(v) operation with vcix.v.sv(v, 0). |
| struct MathLogToVCIX final : OpRewritePattern<math::LogOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(math::LogOp op, |
| PatternRewriter &rewriter) const override { |
| const Type opType = op.getOperand().getType(); |
| auto [n, legalType] = legalizeVectorType(opType); |
| if (!legalType) |
| return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV"); |
| Location loc = op.getLoc(); |
| Value vec = op.getOperand(); |
| Attribute opcodeAttr = rewriter.getI64IntegerAttr(0); |
| Value rvl = nullptr; |
| Value zeroInt = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); |
| if (legalType.isScalable()) |
| // Use arbitrary runtime vector length when vector type is scalable. |
| // Proper conversion pass should take it from the IR. |
| rvl = rewriter.create<arith::ConstantOp>(loc, |
| rewriter.getI64IntegerAttr(9)); |
| Value res; |
| if (n == 1) { |
| res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec, |
| zeroInt, rvl); |
| } else { |
| const unsigned eltCount = legalType.getShape()[0]; |
| Type eltTy = legalType.getElementType(); |
| Value zero = rewriter.create<arith::ConstantOp>( |
| loc, eltTy, rewriter.getZeroAttr(eltTy)); |
| res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/); |
| for (unsigned i = 0; i < n; ++i) { |
| Value extracted = rewriter.create<vector::ScalableExtractOp>( |
| loc, legalType, vec, i * eltCount); |
| Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, |
| extracted, zeroInt, rvl); |
| res = rewriter.create<vector::ScalableInsertOp>(loc, v, res, |
| i * eltCount); |
| } |
| } |
| rewriter.replaceOp(op, res); |
| return success(); |
| } |
| }; |
| |
| struct TestMathToVCIX |
| : PassWrapper<TestMathToVCIX, OperationPass<func::FuncOp>> { |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMathToVCIX) |
| |
| StringRef getArgument() const final { return "test-math-to-vcix"; } |
| |
| StringRef getDescription() const final { |
| return "Test lowering patterns that converts some vector operations to " |
| "VCIX. Since DLA can implement VCIX instructions in completely " |
| "different way, conversions of that test pass only lives here."; |
| } |
| |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<arith::ArithDialect, func::FuncDialect, math::MathDialect, |
| vcix::VCIXDialect, vector::VectorDialect>(); |
| } |
| |
| void runOnOperation() override { |
| MLIRContext *ctx = &getContext(); |
| RewritePatternSet patterns(ctx); |
| patterns.add<MathCosToVCIX, MathSinToVCIX, MathTanToVCIX, MathLogToVCIX>( |
| ctx); |
| (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
| } |
| }; |
| |
| } // namespace |
| |
| namespace test { |
| void registerTestMathToVCIXPass() { PassRegistration<TestMathToVCIX>(); } |
| } // namespace test |
| } // namespace mlir |