| //===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements lowering patterns from vector.contract to |
| // arm_neon.intr.smmla |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" |
| #include "mlir/Dialect/ArmNeon/Transforms.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/Dialect/Utils/IndexingUtils.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| #define DEBUG_TYPE "lower-contract-to-arm-neon" |
| |
| using namespace mlir; |
| using namespace mlir::arm_neon; |
| |
| namespace { |
| |
| /// Return the shaped type with new element type. |
| static Type matchContainerType(Type element, Type container) { |
| if (auto shapedTy = dyn_cast<ShapedType>(container)) { |
| return shapedTy.clone(element); |
| } |
| return element; |
| } |
| |
| /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile |
| /// any vector.contract into multiple smmla instructions with unrolling so long |
| /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM |
| /// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is |
| /// necessary, a single smmla instruction is emitted. |
| class LowerContractionToSMMLAPattern |
| : public OpRewritePattern<vector::ContractionOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(vector::ContractionOp op, |
| PatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim. |
| // Note: RHS is not transposed. |
| mlir::VectorType lhsType = op.getLhsType(); |
| mlir::VectorType rhsType = op.getRhsType(); |
| // Avoid 0-D vectors and 1-D rhs: |
| if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2) |
| return failure(); |
| auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0); |
| auto dimN = rhsType.getDimSize(0); |
| auto dimK = rhsType.getDimSize(1); |
| bool isVecmat = dimM == 1 ? true : false; |
| if (lhsType.getDimSize(lhsType.getRank() - 1) != |
| rhsType.getDimSize(rhsType.getRank() - 1)) { |
| return failure(); // dimK mismatch |
| } |
| // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for |
| // tiling. |
| if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) { |
| return failure(); |
| } |
| |
| // Check iterator types for contract. All iterators except inner-most |
| // dimension must be parallel. |
| auto iteratorTypes = op.getIteratorTypesArray(); |
| if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] != |
| vector::IteratorType::reduction) { |
| return failure(); |
| } |
| if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1), |
| [](vector::IteratorType iteratorType) { |
| return iteratorType != vector::IteratorType::parallel; |
| })) { |
| return failure(); |
| } |
| |
| // Check two extsi inputs Rhs Lhs for contract. |
| arith::ExtSIOp origLhsExtOp = |
| dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp()); |
| arith::ExtSIOp origRhsExtOp = |
| dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp()); |
| if (!origLhsExtOp || !origRhsExtOp) { |
| return failure(); |
| } |
| |
| // Match any iX to i32 for X<8 then turn into an i8 output. Feed into |
| // following neon instruction. Check inputs for extsi are <=i8 |
| Value extsiLhs; |
| Value extsiRhs; |
| if (auto lhsExtInType = |
| dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) { |
| if (lhsExtInType.getElementTypeBitWidth() <= 8) { |
| Type targetLhsExtTy = |
| matchContainerType(rewriter.getI8Type(), lhsExtInType); |
| extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy, |
| origLhsExtOp.getIn()); |
| } |
| } |
| if (auto rhsExtInType = |
| dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) { |
| if (rhsExtInType.getElementTypeBitWidth() <= 8) { |
| Type targetRhsExtTy = |
| matchContainerType(rewriter.getI8Type(), rhsExtInType); |
| extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy, |
| origRhsExtOp.getIn()); |
| } |
| } |
| |
| if (!extsiLhs || !extsiRhs) { |
| return failure(); |
| } |
| |
| // Initial accumulator for the final result. This is the un-tiled result if |
| // tiling is done. |
| Value result = rewriter.create<arith::ConstantOp>( |
| loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType())); |
| |
| SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll(); |
| SmallVector<int64_t> smmlaShape{2, 8}; |
| SmallVector<int64_t> loopOrder{0, 1}; |
| if (unrolledSize.size() == 3) { |
| smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2); |
| loopOrder.push_back(2); |
| } |
| |
| // Keep track of the previous accumulator when tiling over K. |
| Value kAcc; |
| for (SmallVector<int64_t> offsets : |
| StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) { |
| // Helper to compute the new shape of each operand and extract the slice. |
| auto extractOperand = [&](Value operand, AffineMap permutationMap, |
| ArrayRef<int64_t> operandOffsets) { |
| SmallVector<int64_t> operandShape = |
| applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape)); |
| SmallVector<int64_t> operandStrides(operandOffsets.size(), 1); |
| return rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| loc, operand, operandOffsets, operandShape, operandStrides); |
| }; |
| |
| // Extract tiled lhs, rhs, and acc |
| AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0]; |
| SmallVector<int64_t> lhsOffsets = |
| applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets)); |
| Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets); |
| AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1]; |
| SmallVector<int64_t> rhsOffsets = |
| applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets)); |
| Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets); |
| AffineMap accPermutationMap = op.getIndexingMapsArray()[2]; |
| SmallVector<int64_t> accOffsets = |
| applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets)); |
| Value tiledAcc = |
| extractOperand(op.getAcc(), accPermutationMap, accOffsets); |
| |
| auto inputElementType = |
| cast<ShapedType>(tiledLhs.getType()).getElementType(); |
| auto accElementType = |
| cast<ShapedType>(tiledAcc.getType()).getElementType(); |
| auto inputExpandedType = VectorType::get({2, 8}, inputElementType); |
| auto outputExpandedType = VectorType::get({2, 2}, accElementType); |
| |
| // With vecmat, tiled LHS and ACC will contain only one of 2 necessary |
| // rows along dimM. Expand their shapes to match the smmla op. |
| if (isVecmat) { |
| auto expandForSMMLA = [&](Value tiledOperand, |
| VectorType expandedTypeType) { |
| auto emptyOperand = rewriter.create<arith::ConstantOp>( |
| loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType)); |
| SmallVector<int64_t> offsets( |
| cast<ShapedType>(emptyOperand.getType()).getRank(), 0); |
| SmallVector<int64_t> strides( |
| cast<ShapedType>(tiledOperand.getType()).getRank(), 1); |
| return rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, tiledOperand, emptyOperand, offsets, strides); |
| }; |
| tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType); |
| tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType); |
| } |
| |
| // Collapse tiled operands to 1D vectors required by smmla intrinsic |
| auto collapsedInputType = |
| VectorType::get(inputExpandedType.getNumElements(), inputElementType); |
| auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>( |
| tiledLhs.getLoc(), collapsedInputType, tiledLhs); |
| auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>( |
| tiledRhs.getLoc(), collapsedInputType, tiledRhs); |
| auto collapsedOutputType = |
| VectorType::get(outputExpandedType.getNumElements(), accElementType); |
| |
| bool initialKAcc = offsets.back() == 0; |
| Value collapsedRes; |
| if (!initialKAcc) { |
| collapsedRes = kAcc; |
| } else { |
| collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>( |
| tiledAcc.getLoc(), collapsedOutputType, tiledAcc); |
| } |
| |
| // Insert contract op |
| kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>( |
| op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs, |
| collapsedRhs); |
| |
| // Reshape output back to 2D |
| Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>( |
| kAcc.getLoc(), tiledAcc.getType(), kAcc); |
| |
| // With vecmat, only one row of tiled ACC can be inserted into file result |
| if (isVecmat) { |
| tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0); |
| } |
| |
| // Insert the tiled result back into the non tiled result of the |
| // contract op. |
| SmallVector<int64_t> strides( |
| cast<ShapedType>(tiledRes.getType()).getRank(), 1); |
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| loc, tiledRes, result, accOffsets, strides); |
| } |
| |
| rewriter.replaceOp(op, result); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns( |
| RewritePatternSet &patterns) { |
| MLIRContext *context = patterns.getContext(); |
| patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1); |
| } |