[mlir] [VectorOps] Initial framework for progressively lowering vector.contract

Summary:
This sets the basic framework for lowering vector.contract progressively
into simpler vector.contract operations until a direct vector.reduction
operation is reached. More details will be filled out progressively as well.

Reviewers: nicolasvasilache

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74520
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
index 990b477..1aee56f 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
@@ -54,6 +54,12 @@
 void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context);
 
+/// Collect a set of vector contraction transformation patterns
+/// that express all vector.contract ops in terms of more elementary
+/// extraction and reduction ops.
+void populateVectorContractLoweringPatterns(OwningRewritePatternList &patterns,
+                                            MLIRContext *context);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 074a6d0..5ead876 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -216,6 +216,29 @@
   }];
 }
 
+// TODO(ajcbik): quick version with "fused" accumulator; next step
+//               will merge Reduction/ReductionV2 into one with
+//               an optional accumulator instead
+def Vector_ReductionV2Op :
+  Vector_Op<"reductionv2", [NoSideEffect]>,
+    Arguments<(ins StrAttr:$kind, VectorOf<[F32, F64]>:$vector, AnyType:$acc)>,
+    Results<(outs AnyType:$dest)> {
+  let summary = "reduction operation";
+  let description = [{
+     As vector.reduction, but with a fused accumulator (add/mul for fp only).
+  }];
+  let verifier = ?;
+  let assemblyFormat = [{
+    $kind `,` $vector `,` $acc attr-dict `:`
+      type($vector) `,` type($acc) `into` type($dest)
+  }];
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+  }];
+}
+
 def Vector_BroadcastOp :
   Vector_Op<"broadcast", [NoSideEffect,
      PredOpTrait<"source operand and result have same element type",
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9fcad2f..c43fc0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -340,6 +340,33 @@
   }
 };
 
+// TODO(ajcbik): merge Reduction and ReductionV2
+class VectorReductionV2OpConversion : public LLVMOpLowering {
+public:
+  explicit VectorReductionV2OpConversion(MLIRContext *context,
+                                         LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::ReductionV2Op::getOperationName(), context,
+                       typeConverter) {}
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reductionOp = cast<vector::ReductionV2Op>(op);
+    auto kind = reductionOp.kind();
+    Type eltType = reductionOp.dest().getType();
+    Type llvmType = lowering.convertType(eltType);
+    if (kind == "add") {
+      rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
+          op, llvmType, operands[1], operands[0]);
+      return matchSuccess();
+    } else if (kind == "mul") {
+      rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
+          op, llvmType, operands[1], operands[0]);
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
 class VectorShuffleOpConversion : public LLVMOpLowering {
 public:
   explicit VectorShuffleOpConversion(MLIRContext *context,
@@ -1125,11 +1152,12 @@
                   VectorInsertStridedSliceOpSameRankRewritePattern,
                   VectorStridedSliceOpConversion>(ctx);
   patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
-                  VectorShuffleOpConversion, VectorExtractElementOpConversion,
-                  VectorExtractOpConversion, VectorFMAOp1DConversion,
-                  VectorInsertElementOpConversion, VectorInsertOpConversion,
-                  VectorOuterProductOpConversion, VectorTypeCastOpConversion,
-                  VectorPrintOpConversion>(ctx, converter);
+                  VectorReductionV2OpConversion, VectorShuffleOpConversion,
+                  VectorExtractElementOpConversion, VectorExtractOpConversion,
+                  VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+                  VectorInsertOpConversion, VectorOuterProductOpConversion,
+                  VectorTypeCastOpConversion, VectorPrintOpConversion>(
+      ctx, converter);
 }
 
 namespace {
@@ -1139,11 +1167,12 @@
 } // namespace
 
 void LowerVectorToLLVMPass::runOnModule() {
-  // Perform progressive lowering of operations on "slices".
-  // Folding and DCE get rid of all non-leaking tuple ops.
+  // Perform progressive lowering of operations on "slices" and
+  // all contraction operations. Also applies folding and DCE.
   {
     OwningRewritePatternList patterns;
     populateVectorSlicesLoweringPatterns(patterns, &getContext());
+    populateVectorContractLoweringPatterns(patterns, &getContext());
     applyPatternsGreedily(getModule(), patterns);
   }
 
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 8bdeb92..fe62666 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -538,6 +538,7 @@
 }
 
 namespace {
+
 // Splits vector TransferReadOp into smaller TransferReadOps based on slicing
 // scheme of its unique ExtractSlicesOp user.
 struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
@@ -862,6 +863,72 @@
   }
 };
 
+/// Progressive lowering of ContractionOp.
+class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(vector::ContractionOp op,
+                                     PatternRewriter &rewriter) const override {
+    // TODO(ajcbik): implement masks
+    if (llvm::size(op.masks()) != 0)
+      return matchFailure();
+
+    auto loc = op.getLoc();
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    Type resType = op.getResultType();
+
+    // Find first batch dimension in lhs/rhs, and lower when found.
+    std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
+    if (!batchDimMap.empty()) {
+      // TODO(ajcbik): implement batch
+      return matchFailure();
+    }
+
+    // Collect contracting dimensions.
+    std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
+        op.getContractingDimMap();
+    DenseSet<int64_t> lhsContractingDimSet;
+    DenseSet<int64_t> rhsContractingDimSet;
+    for (auto &dimPair : contractingDimMap) {
+      lhsContractingDimSet.insert(dimPair.first);
+      rhsContractingDimSet.insert(dimPair.second);
+    }
+
+    // Find free dimension in lhs/rhs, and lower first when found.
+    for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
+      if (lhsContractingDimSet.count(i) == 0) {
+        // TODO(ajcbik): implement free
+        return matchFailure();
+      }
+    }
+    for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
+      if (rhsContractingDimSet.count(i) == 0) {
+        // TODO(ajcbik): implement free
+        return matchFailure();
+      }
+    }
+
+    // Only contraction dimensions remain.
+    if (!resType.isa<VectorType>() && lhsType.getRank() == 1 &&
+        rhsType.getRank() == 1) {
+      // Handle reduction into scalar.
+      Value zero = rewriter.create<ConstantOp>(loc, resType,
+                                               rewriter.getZeroAttr(resType));
+      Value splat = rewriter.create<SplatOp>(loc, lhsType, zero);
+      Value fma =
+          rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), splat);
+      StringAttr kind = rewriter.getStringAttr("add");
+      rewriter.replaceOpWithNewOp<vector::ReductionV2Op>(op, resType, kind, fma,
+                                                         op.acc());
+      return matchSuccess();
+    }
+    // TODO(ajcbik): implement more contraction
+    return matchFailure();
+  }
+};
+
 } // namespace
 
 // TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
@@ -876,3 +943,8 @@
     OwningRewritePatternList &patterns, MLIRContext *context) {
   patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
 }
+
+void mlir::vector::populateVectorContractLoweringPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<ContractionOpLowering>(context);
+}
diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
new file mode 100644
index 0000000..6c4cb5f
--- /dev/null
+++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
+
+#dotp_accesses = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> ()>
+]
+#dotp_trait = {
+  indexing_maps = #dotp_accesses,
+  iterator_types = ["reduction"]
+}
+
+// CHECK-LABEL: func @extract_contract1
+// CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
+// CHECK-SAME: %[[C:.*2]]: f32
+// CHECK:      %[[Z:.*]] = constant dense<0.000000e+00>
+// CHECK:      %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32>
+// CHECK:      %[[R:.*]] = vector.reductionv2 "add", %[[F]], %[[C]]
+// CHECK:      return %[[R]] : f32
+
+func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
+  %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
+    : vector<4xf32>, vector<4xf32> into f32
+  return %0 : f32
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index d0ac718..3f35f81 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -42,16 +42,29 @@
   }
 };
 
+struct TestVectorContractionConversion
+    : public FunctionPass<TestVectorContractionConversion> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    populateVectorContractLoweringPatterns(patterns, &getContext());
+    applyPatternsGreedily(getFunction(), patterns);
+  }
+};
+
 } // end anonymous namespace
 
 namespace mlir {
 void registerTestVectorConversions() {
-  PassRegistration<TestVectorToVectorConversion> pass(
+  PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
       "test-vector-to-vector-conversion",
       "Test conversion patterns between ops in the vector dialect");
 
-  PassRegistration<TestVectorSlicesConversion> slices_pass(
+  PassRegistration<TestVectorSlicesConversion> slicesPass(
       "test-vector-slices-conversion",
       "Test conversion patterns that lower slices ops in the vector dialect");
+
+  PassRegistration<TestVectorContractionConversion> contractionPass(
+      "test-vector-contraction-conversion",
+      "Test conversion patterns that lower contract ops in the vector dialect");
 }
 } // namespace mlir