[mlir][spirv] Add lowering for load/store zero-rank memref from std to SPIR-V.

Differential Revision: https://reviews.llvm.org/D74874
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 7730661..3cf5046 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -69,6 +69,9 @@
     if (!elementSize) {
       return llvm::None;
     }
+    if (memRefType.getRank() == 0) {
+      return elementSize;
+    }
     auto dims = memRefType.getShape();
     if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
         offset == MemRefType::getDynamicStrideOrOffset() ||
@@ -325,8 +328,12 @@
   }
   SmallVector<Value, 2> linearizedIndices;
   // Add a '0' at the start to index into the struct.
-  linearizedIndices.push_back(builder.create<spirv::ConstantOp>(
-      loc, indexType, IntegerAttr::get(indexType, 0)));
+  auto zero = spirv::ConstantOp::getZero(indexType, loc, &builder);
+  linearizedIndices.push_back(zero);
+  // If it is a zero-rank memref type, extract the element directly.
+  if (!ptrLoc) {
+    ptrLoc = zero;
+  }
   linearizedIndices.push_back(ptrLoc);
   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
 }
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
index 1ec8b4c..341df27 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
@@ -312,3 +312,41 @@
 func @memref_type(%arg0: memref<3xi1>) {
   return
 }
+
+// CHECK-LABEL: @load_store_zero_rank_float
+// CHECK: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>,
+// CHECK: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>)
+func @load_store_zero_rank_float(%arg0: memref<f32>, %arg1: memref<f32>) {
+  //      CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32
+  //      CHECK: spv.AccessChain [[ARG0]][
+  // CHECK-SAME: [[ZERO1]], [[ZERO1]]
+  // CHECK-SAME: ] :
+  //      CHECK: spv.Load "StorageBuffer" %{{.*}} : f32
+  %0 = load %arg0[] : memref<f32>
+  //      CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32
+  //      CHECK: spv.AccessChain [[ARG1]][
+  // CHECK-SAME: [[ZERO2]], [[ZERO2]]
+  // CHECK-SAME: ] :
+  //      CHECK: spv.Store "StorageBuffer" %{{.*}} : f32
+  store %0, %arg1[] : memref<f32>
+  return
+}
+
+// CHECK-LABEL: @load_store_zero_rank_int
+// CHECK: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>,
+// CHECK: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>)
+func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
+  //      CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32
+  //      CHECK: spv.AccessChain [[ARG0]][
+  // CHECK-SAME: [[ZERO1]], [[ZERO1]]
+  // CHECK-SAME: ] :
+  //      CHECK: spv.Load "StorageBuffer" %{{.*}} : i32
+  %0 = load %arg0[] : memref<i32>
+  //      CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32
+  //      CHECK: spv.AccessChain [[ARG1]][
+  // CHECK-SAME: [[ZERO2]], [[ZERO2]]
+  // CHECK-SAME: ] :
+  //      CHECK: spv.Store "StorageBuffer" %{{.*}} : i32
+  store %0, %arg1[] : memref<i32>
+  return
+}
diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
index 9c8be4e..d89f1ff 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
@@ -23,3 +23,37 @@
     spv.Return
   }
 }
+
+// -----
+
+spv.module "Logical" "GLSL450" {
+  spv.func @load_store_zero_rank_float(%arg0: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>, %arg1: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>) "None" {
+    // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>
+    // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : f32
+    %0 = spv.constant 0 : i32
+    %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>
+    %2 = spv.Load "StorageBuffer" %1 : f32
+
+    // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>
+    // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32
+    %3 = spv.constant 0 : i32
+    %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>
+    spv.Store "StorageBuffer" %4, %2 : f32
+    spv.Return
+  }
+
+  spv.func @load_store_zero_rank_int(%arg0: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>, %arg1: !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>) "None" {
+    // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>
+    // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : i32
+    %0 = spv.constant 0 : i32
+    %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>
+    %2 = spv.Load "StorageBuffer" %1 : i32
+
+    // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>
+    // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32
+    %3 = spv.constant 0 : i32
+    %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x i32 [4]> [0]>, StorageBuffer>
+    spv.Store "StorageBuffer" %4, %2 : i32
+    spv.Return
+  }
+}