[mlir] [VectorOps] Initial framework for progressively lowering vector.contract
Summary:
This sets the basic framework for lowering vector.contract progressively
into simpler vector.contract operations until a direct vector.reduction
operation is reached. More details will be filled out progressively as well.
Reviewers: nicolasvasilache
Reviewed By: nicolasvasilache
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D74520
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
index 990b477..1aee56f 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
@@ -54,6 +54,12 @@
void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
MLIRContext *context);
+/// Collect a set of vector contraction transformation patterns
+/// that express all vector.contract ops in terms of more elementary
+/// extraction and reduction ops.
+void populateVectorContractLoweringPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *context);
+
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 074a6d0..5ead876 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -216,6 +216,29 @@
}];
}
+// Variant of vector.reduction that takes an extra "fused" accumulator operand.
+// TODO(ajcbik): this is a quick version; the next step will merge
+// Reduction/ReductionV2 into one op with an optional accumulator instead.
+def Vector_ReductionV2Op :
+ Vector_Op<"reductionv2", [NoSideEffect]>,
+ Arguments<(ins StrAttr:$kind, VectorOf<[F32, F64]>:$vector, AnyType:$acc)>,
+ Results<(outs AnyType:$dest)> {
+ let summary = "reduction operation";
+ let description = [{
+ As vector.reduction, but with a fused accumulator (add/mul for fp only).
+ }];
+ let verifier = ?;
+ let assemblyFormat = [{
+ $kind `,` $vector `,` $acc attr-dict `:`
+ type($vector) `,` type($acc) `into` type($dest)
+ }];
+ let extraClassDeclaration = [{
+ VectorType getVectorType() {
+ return vector().getType().cast<VectorType>();
+ }
+ }];
+}
+
def Vector_BroadcastOp :
Vector_Op<"broadcast", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9fcad2f..c43fc0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -340,6 +340,33 @@
}
};
+// Converts vector.reductionv2 of kind "add"/"mul" into the LLVM v2 fadd/fmul reduction ops. TODO(ajcbik): merge Reduction and ReductionV2
+class VectorReductionV2OpConversion : public LLVMOpLowering {
+public:
+ explicit VectorReductionV2OpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::ReductionV2Op::getOperationName(), context,
+ typeConverter) {}
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto reductionOp = cast<vector::ReductionV2Op>(op);
+ auto kind = reductionOp.kind();
+ Type eltType = reductionOp.dest().getType(); // type of the op's result
+ Type llvmType = lowering.convertType(eltType);
+ if (kind == "add") {
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
+ op, llvmType, operands[1], operands[0]);
+ return matchSuccess();
+ } else if (kind == "mul") {
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
+ op, llvmType, operands[1], operands[0]);
+ return matchSuccess();
+ }
+ return matchFailure(); // any kind other than "add"/"mul" is not converted here
+ }
+};
+
class VectorShuffleOpConversion : public LLVMOpLowering {
public:
explicit VectorShuffleOpConversion(MLIRContext *context,
@@ -1125,11 +1152,12 @@
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorStridedSliceOpConversion>(ctx);
patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
- VectorShuffleOpConversion, VectorExtractElementOpConversion,
- VectorExtractOpConversion, VectorFMAOp1DConversion,
- VectorInsertElementOpConversion, VectorInsertOpConversion,
- VectorOuterProductOpConversion, VectorTypeCastOpConversion,
- VectorPrintOpConversion>(ctx, converter);
+ VectorReductionV2OpConversion, VectorShuffleOpConversion,
+ VectorExtractElementOpConversion, VectorExtractOpConversion,
+ VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+ VectorInsertOpConversion, VectorOuterProductOpConversion,
+ VectorTypeCastOpConversion, VectorPrintOpConversion>(
+ ctx, converter);
}
namespace {
@@ -1139,11 +1167,12 @@
} // namespace
void LowerVectorToLLVMPass::runOnModule() {
- // Perform progressive lowering of operations on "slices".
- // Folding and DCE get rid of all non-leaking tuple ops.
+ // Perform progressive lowering of operations on "slices" and
+ // all contraction operations. Also applies folding and DCE.
{
OwningRewritePatternList patterns;
populateVectorSlicesLoweringPatterns(patterns, &getContext());
+ populateVectorContractLoweringPatterns(patterns, &getContext());
applyPatternsGreedily(getModule(), patterns);
}
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 8bdeb92..fe62666 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -538,6 +538,7 @@
}
namespace {
+
// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
// scheme of its unique ExtractSlicesOp user.
struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
@@ -862,6 +863,72 @@
}
};
+/// Progressive lowering of ContractionOp; so far only a rank-1 x rank-1 contraction into a non-vector result is rewritten.
+class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+ // TODO(ajcbik): implement masks; until then any op that carries masks is not matched.
+ if (llvm::size(op.masks()) != 0)
+ return matchFailure();
+
+ auto loc = op.getLoc();
+ VectorType lhsType = op.getLhsType();
+ VectorType rhsType = op.getRhsType();
+ Type resType = op.getResultType();
+
+ // Batch dimensions in lhs/rhs are meant to be lowered first; for now their presence fails the match.
+ std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
+ if (!batchDimMap.empty()) {
+ // TODO(ajcbik): implement batch
+ return matchFailure();
+ }
+
+ // Collect the contracting dimensions, kept as separate index sets for lhs and rhs.
+ std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
+ op.getContractingDimMap();
+ DenseSet<int64_t> lhsContractingDimSet;
+ DenseSet<int64_t> rhsContractingDimSet;
+ for (auto &dimPair : contractingDimMap) {
+ lhsContractingDimSet.insert(dimPair.first);
+ rhsContractingDimSet.insert(dimPair.second);
+ }
+
+ // Free (non-contracting) dimensions in lhs/rhs are meant to be lowered next; for now any such dimension fails the match.
+ for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
+ if (lhsContractingDimSet.count(i) == 0) {
+ // TODO(ajcbik): implement free
+ return matchFailure();
+ }
+ }
+ for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
+ if (rhsContractingDimSet.count(i) == 0) {
+ // TODO(ajcbik): implement free
+ return matchFailure();
+ }
+ }
+
+ // Only contraction dimensions remain.
+ if (!resType.isa<VectorType>() && lhsType.getRank() == 1 &&
+ rhsType.getRank() == 1) {
+ // Handle reduction into scalar: fma(lhs, rhs, splat(0)) followed by an "add" vector.reductionv2 that takes op.acc() as its accumulator.
+ Value zero = rewriter.create<ConstantOp>(loc, resType,
+ rewriter.getZeroAttr(resType));
+ Value splat = rewriter.create<SplatOp>(loc, lhsType, zero);
+ Value fma =
+ rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), splat);
+ StringAttr kind = rewriter.getStringAttr("add");
+ rewriter.replaceOpWithNewOp<vector::ReductionV2Op>(op, resType, kind, fma,
+ op.acc());
+ return matchSuccess();
+ }
+ // TODO(ajcbik): implement more contraction; every remaining form is left untouched.
+ return matchFailure();
+ }
+};
+
} // namespace
// TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
@@ -876,3 +943,8 @@
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
}
+
+void mlir::vector::populateVectorContractLoweringPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ patterns.insert<ContractionOpLowering>(context); // the only pattern registered by this function
+}
diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
new file mode 100644
index 0000000..6c4cb5f
--- /dev/null
+++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
+
+#dotp_accesses = [
+ affine_map<(i) -> (i)>,
+ affine_map<(i) -> (i)>,
+ affine_map<(i) -> ()>
+]
+#dotp_trait = {
+ indexing_maps = #dotp_accesses,
+ iterator_types = ["reduction"]
+}
+
+// CHECK-LABEL: func @extract_contract1
+// CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
+// CHECK-SAME: %[[C:.*2]]: f32
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00>
+// CHECK: %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32>
+// CHECK: %[[R:.*]] = vector.reductionv2 "add", %[[F]], %[[C]]
+// CHECK: return %[[R]] : f32
+
+func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
+ %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
+ : vector<4xf32>, vector<4xf32> into f32
+ return %0 : f32
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index d0ac718..3f35f81 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -42,16 +42,29 @@
}
};
+struct TestVectorContractionConversion // test pass exercising the vector.contract lowering patterns
+ : public FunctionPass<TestVectorContractionConversion> {
+ void runOnFunction() override {
+ OwningRewritePatternList patterns;
+ populateVectorContractLoweringPatterns(patterns, &getContext());
+ applyPatternsGreedily(getFunction(), patterns); // greedy application over the current function
+ }
+};
+
} // end anonymous namespace
namespace mlir {
void registerTestVectorConversions() {
- PassRegistration<TestVectorToVectorConversion> pass(
+ PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
"test-vector-to-vector-conversion",
"Test conversion patterns between ops in the vector dialect");
- PassRegistration<TestVectorSlicesConversion> slices_pass(
+ PassRegistration<TestVectorSlicesConversion> slicesPass(
"test-vector-slices-conversion",
"Test conversion patterns that lower slices ops in the vector dialect");
+
+ PassRegistration<TestVectorContractionConversion> contractionPass(
+ "test-vector-contraction-conversion",
+ "Test conversion patterns that lower contract ops in the vector dialect");
}
} // namespace mlir