blob: 3635cd319a07856a2d4f3f2f986c9c70179e92a3 [file] [log] [blame]
//===- LowerContractionToSMMLAPattern.cpp - Contract to SMMLA ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to
// arm_neon.intr.smmla.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "lower-contract-to-arm-neon"
using namespace mlir;
using namespace mlir::arm_neon;
namespace {
/// Re-wraps `element` in the same kind of container as `container`.
///
/// If `container` is a ShapedType, the result is a clone of it whose element
/// type is replaced by `element` (same shape). Otherwise `element` is returned
/// unchanged.
static Type matchContainerType(Type element, Type container) {
  auto shapedContainer = dyn_cast<ShapedType>(container);
  if (!shapedContainer)
    return element;
  return shapedContainer.clone(element);
}
/// Lowering from a vector::contractOp to the arm neon smmla intrinsic. This
/// will tile any vector.contract into multiple smmla instructions with
/// unrolling so long as [2,2,8] is a divisor of its shape. It can also process
/// vecmats with dimM = 1 (either explicitly or inferred if LHS has only dimK).
/// If no unrolling is necessary, a single smmla instruction is emitted.
///
/// Only contractions of the form `acc + extsi(lhs) * extsi(rhs)^T` are
/// matched:
///   * the RHS is laid out as N x K (see hasSupportedIndexingMaps),
///   * both multiplicands are produced by arith.extsi from integers of at
///     most 8 bits (so they can be narrowed to i8 losslessly),
///   * the accumulator/result has i32 elements (the only smmla acc type),
///   * vectors are fixed-length (smmla is a NEON instruction).
class LowerContractionToSMMLAPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
    // Note: the RHS is expected in N x K layout (i.e. already transposed).
    VectorType lhsType = op.getLhsType();
    VectorType rhsType = op.getRhsType();
    // Avoid 0-D lhs and anything but a 2-D (N x K) rhs.
    if (lhsType.getRank() == 0 || rhsType.getRank() != 2)
      return failure();
    // smmla operates on fixed-length NEON registers; leave scalable vectors to
    // other (e.g. ArmSVE) lowerings instead of mis-lowering them here.
    if (lhsType.isScalable() || rhsType.isScalable())
      return failure();

    // Check iterator types for contract: (m, n, k) or (n, k) for vecmat. All
    // iterators except the inner-most dimension must be parallel.
    auto iteratorTypes = op.getIteratorTypesArray();
    if (iteratorTypes.size() < 2 || iteratorTypes.size() > 3 ||
        iteratorTypes.back() != vector::IteratorType::reduction)
      return failure();
    if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(),
                     [](vector::IteratorType iteratorType) {
                       return iteratorType != vector::IteratorType::parallel;
                     }))
      return failure();

    // The tiling below hard-codes a transposed-RHS layout; any other indexing
    // maps would be silently miscompiled.
    if (!hasSupportedIndexingMaps(op))
      return failure();

    // The smmla accumulator (and result) is a vector of i32, so only i32
    // accumulation can be expressed.
    auto accType = dyn_cast<VectorType>(op.getResultType());
    if (!accType || !accType.getElementType().isInteger(32))
      return failure();

    int64_t dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
    int64_t dimN = rhsType.getDimSize(0);
    int64_t dimK = rhsType.getDimSize(1);
    bool isVecmat = dimM == 1;
    if (lhsType.getDimSize(lhsType.getRank() - 1) != dimK)
      return failure(); // dimK mismatch
    // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
    // tiling.
    if ((!isVecmat && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
      return failure();

    // Both multiplicands must be sign-extended from integers of at most 8 bits
    // so that they can be re-extended to i8 without losing information.
    auto origLhsExtOp = op.getLhs().getDefiningOp<arith::ExtSIOp>();
    auto origRhsExtOp = op.getRhs().getDefiningOp<arith::ExtSIOp>();
    if (!origLhsExtOp || !origRhsExtOp)
      return failure();
    auto lhsExtInType = dyn_cast<VectorType>(origLhsExtOp.getIn().getType());
    auto rhsExtInType = dyn_cast<VectorType>(origRhsExtOp.getIn().getType());
    if (!lhsExtInType || !rhsExtInType ||
        lhsExtInType.getElementTypeBitWidth() > 8 ||
        rhsExtInType.getElementTypeBitWidth() > 8)
      return failure();

    // ---- Matching is complete. Everything below modifies the IR, so the
    // ---- pattern must not return failure() from here on.

    // Match any iX to i32 for X<=8 and turn it into an i8 operand to feed the
    // smmla instruction (an i8 -> i8 extsi folds away).
    Value extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(
        loc, matchContainerType(rewriter.getI8Type(), lhsExtInType),
        origLhsExtOp.getIn());
    Value extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(
        loc, matchContainerType(rewriter.getI8Type(), rhsExtInType),
        origRhsExtOp.getIn());

    // Initial accumulator for the final result. This is the un-tiled result if
    // tiling is done.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));

    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
    SmallVector<int64_t> smmlaShape{2, 8};
    SmallVector<int64_t> loopOrder{0, 1};
    if (unrolledSize.size() == 3) {
      smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
      loopOrder.push_back(2);
    }

    // Loop-invariant data used for every tile.
    auto indexingMaps = op.getIndexingMapsArray();
    AffineMap lhsPermutationMap = indexingMaps[0];
    AffineMap rhsPermutationMap = indexingMaps[1];
    AffineMap accPermutationMap = indexingMaps[2];
    Type inputElementType = rewriter.getI8Type();
    Type accElementType = accType.getElementType();
    // smmla computes a 2x2 result from a 2x8 LHS and a 2x8 (transposed) RHS,
    // all passed as flat 1-D vectors.
    auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
    auto outputExpandedType = VectorType::get({2, 2}, accElementType);
    auto collapsedInputType =
        VectorType::get(inputExpandedType.getNumElements(), inputElementType);
    auto collapsedOutputType =
        VectorType::get(outputExpandedType.getNumElements(), accElementType);

    // Helper to compute the new shape of each operand and extract the slice.
    auto extractOperand = [&](Value operand, AffineMap permutationMap,
                              ArrayRef<int64_t> operandOffsets) {
      SmallVector<int64_t> operandShape =
          applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
      SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
      return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, operand, operandOffsets, operandShape, operandStrides);
    };
    // With vecmat, a tiled LHS/ACC holds only one of the 2 rows smmla needs
    // along dimM; pad it with a zero row to the full smmla operand shape.
    auto expandForSMMLA = [&](Value tiledOperand, VectorType expandedType) {
      auto emptyOperand = rewriter.create<arith::ConstantOp>(
          loc, expandedType, rewriter.getZeroAttr(expandedType));
      SmallVector<int64_t> zeroOffsets(expandedType.getRank(), 0);
      SmallVector<int64_t> strides(
          cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
      return rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tiledOperand, emptyOperand, zeroOffsets, strides);
    };

    // Keep track of the previous accumulator when tiling over K.
    Value kAcc;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
      // Extract tiled lhs and rhs.
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));

      if (isVecmat)
        tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);

      // Collapse tiled operands to 1D vectors required by smmla intrinsic.
      Value collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
      Value collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
          tiledRhs.getLoc(), collapsedInputType, tiledRhs);

      // The first tile along K (K is the inner-most loop) starts from the
      // original accumulator; subsequent K tiles chain onto the previous
      // smmla result, so the acc slice is only extracted when needed.
      bool initialKAcc = offsets.back() == 0;
      Value collapsedRes;
      if (initialKAcc) {
        Value tiledAcc =
            extractOperand(op.getAcc(), accPermutationMap, accOffsets);
        if (isVecmat)
          tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
        collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
            tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
      } else {
        collapsedRes = kAcc;
      }

      // Insert contract op.
      kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
          loc, collapsedRes.getType(), collapsedRes, collapsedLhs,
          collapsedRhs);

      // Reshape output back to 2D. The (possibly expanded) acc tile is always
      // 2x2: [2, 2] from the acc map in the matmul case, or padded to 2x2 in
      // the vecmat case.
      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
          kAcc.getLoc(), outputExpandedType, kAcc);
      // With vecmat, only one row of the tiled ACC belongs in the final result.
      if (isVecmat)
        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);

      // Insert the tiled result back into the non tiled result of the
      // contract op.
      SmallVector<int64_t> strides(
          cast<ShapedType>(tiledRes.getType()).getRank(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tiledRes, result, accOffsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Returns true if `op` uses indexing maps this pattern knows how to tile,
  /// i.e. a matrix multiplication with a transposed (N x K) RHS:
  ///   lhs: (d0, d1, d2) -> (d0, d2)
  ///   rhs: (d0, d1, d2) -> (d1, d2)
  ///   acc: (d0, d1, d2) -> (d0, d1)
  /// or its vecmat form with a 1-D LHS:
  ///   lhs: (d0, d1) -> (d1)
  ///   rhs: (d0, d1) -> (d0, d1)
  ///   acc: (d0, d1) -> (d0)
  static bool hasSupportedIndexingMaps(vector::ContractionOp op) {
    MLIRContext *ctx = op.getContext();
    auto maps = op.getIndexingMapsArray();
    unsigned numDims = maps[0].getNumDims();
    // Builds the projection (d0, ..., d{numDims-1}) -> (d{dims[0]}, ...).
    auto makeMap = [&](ArrayRef<unsigned> dims) {
      SmallVector<AffineExpr> exprs;
      for (unsigned dim : dims)
        exprs.push_back(getAffineDimExpr(dim, ctx));
      return AffineMap::get(numDims, /*symbolCount=*/0, exprs, ctx);
    };
    if (numDims == 3)
      return maps[0] == makeMap({0, 2}) && maps[1] == makeMap({1, 2}) &&
             maps[2] == makeMap({0, 1});
    if (numDims == 2)
      return maps[0] == makeMap({1}) && maps[1] == makeMap({0, 1}) &&
             maps[2] == makeMap({0});
    return false;
  }
};
} // namespace
/// Adds the vector.contract -> arm_neon.intr.smmla lowering pattern to
/// `patterns`, registered with a benefit of 1.
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
    RewritePatternSet &patterns) {
  patterns.add<LowerContractionToSMMLAPattern>(patterns.getContext(),
                                               /*benefit=*/1);
}