1:N memref to LLVM

update some more code

update

update

update

update

some progress

update

update

more improvements
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
index d5055f0..119106e 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -30,13 +30,13 @@
 /// Helper class to produce LLVM dialect operations extracting or inserting
-/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
-/// The Value may be null, in which case none of the operations are valid.
-class MemRefDescriptor : public StructBuilder {
+/// elements of a MemRef descriptor. Wraps the individual values that make up
+/// the descriptor (see `getElements`) rather than a packed struct value.
+class MemRefDescriptor {
 public:
-  /// Construct a helper for the given descriptor value.
-  explicit MemRefDescriptor(Value descriptor);
+  /// Construct a helper for the given descriptor elements.
+  explicit MemRefDescriptor(ValueRange elements);
   /// Builds IR creating a `poison` value of the descriptor type.
   static MemRefDescriptor poison(OpBuilder &builder, Location loc,
-                                 Type descriptorType);
+                                 TypeRange descriptorTypes);
   /// Builds IR creating a MemRef descriptor that represents `type` and
   /// populates it with static shape and stride information extracted from the
   /// type.
@@ -49,6 +49,11 @@
                   const LLVMTypeConverter &typeConverter, MemRefType type,
                   Value memory, Value alignedMemory);
 
+  /// Builds IR extracting the individual elements of the packed MemRef
+  /// descriptor struct `packed` and returns a descriptor wrapping them.
+  static MemRefDescriptor fromPackedStruct(OpBuilder &builder, Location loc,
+                                           Value packed);
+
   /// Builds IR extracting the allocated pointer from the descriptor.
   Value allocatedPtr(OpBuilder &builder, Location loc);
   /// Builds IR inserting the allocated pointer into the descriptor.
@@ -98,6 +103,9 @@
   Value bufferPtr(OpBuilder &builder, Location loc,
                   const LLVMTypeConverter &converter, MemRefType type);
 
+  /// Returns the rank of the memref described by this descriptor.
+  int64_t getRank();
+
   /// Builds IR populating a MemRef descriptor structure from a list of
   /// individual values composing that descriptor, in the following order:
   /// - allocated pointer;
@@ -106,20 +114,22 @@
   /// - <rank> sizes;
   /// - <rank> strides;
-  /// where <rank> is the MemRef rank as provided in `type`.
-  static Value pack(OpBuilder &builder, Location loc,
-                    const LLVMTypeConverter &converter, MemRefType type,
-                    ValueRange values);
-
-  /// Builds IR extracting individual elements of a MemRef descriptor structure
-  /// and returning them as `results` list.
-  static void unpack(OpBuilder &builder, Location loc, Value packed,
-                     MemRefType type, SmallVectorImpl<Value> &results);
+  /// where <rank> is the MemRef rank.
+  Value packStruct(OpBuilder &builder, Location loc);
 
-  /// Returns the number of non-aggregate values that would be produced by
-  /// `unpack`.
+  /// Returns the number of individual values that make up a descriptor of the
+  /// given memref type.
   static unsigned getNumUnpackedValues(MemRefType type);
 
+  /// Returns the individual values that make up this descriptor.
+  ValueRange getElements() const { return elements; }
+
+  /// Implicitly converts the descriptor to the range of its elements, so that
+  /// it can be passed wherever a ValueRange is expected.
+  /*implicit*/ operator ValueRange() const { return elements; }
+
 private:
+  SmallVector<Value> elements;
+
   // Cached index type.
   Type indexType;
 };
@@ -155,13 +165,18 @@
   ValueRange elements;
 };
 
-class UnrankedMemRefDescriptor : public StructBuilder {
+class UnrankedMemRefDescriptor {
 public:
-  /// Construct a helper for the given descriptor value.
-  explicit UnrankedMemRefDescriptor(Value descriptor);
-  /// Builds IR creating an `undef` value of the descriptor type.
+  /// Construct a helper for the given descriptor elements.
+  explicit UnrankedMemRefDescriptor(ValueRange elements);
+  /// Builds IR creating a `poison` value of the descriptor type.
   static UnrankedMemRefDescriptor poison(OpBuilder &builder, Location loc,
-                                         Type descriptorType);
+                                         TypeRange descriptorTypes);
+
+  /// Builds IR extracting the individual elements of the packed unranked
+  /// MemRef descriptor struct `packed` and returns a descriptor wrapping them.
+  static UnrankedMemRefDescriptor fromPackedStruct(OpBuilder &builder,
+                                                   Location loc, Value packed);
 
   /// Builds IR extracting the rank from the descriptor
   Value rank(OpBuilder &builder, Location loc) const;
@@ -176,14 +191,7 @@
   /// of individual constituent values in the following order:
   /// - rank of the memref;
   /// - pointer to the memref descriptor.
-  static Value pack(OpBuilder &builder, Location loc,
-                    const LLVMTypeConverter &converter, UnrankedMemRefType type,
-                    ValueRange values);
-
-  /// Builds IR extracting individual elements that compose an unranked memref
-  /// descriptor and returns them as `results` list.
-  static void unpack(OpBuilder &builder, Location loc, Value packed,
-                     SmallVectorImpl<Value> &results);
+  Value packStruct(OpBuilder &builder, Location loc);
 
   /// Returns the number of non-aggregate values that would be produced by
   /// `unpack`.
@@ -269,6 +277,16 @@
   static void setStride(OpBuilder &builder, Location loc,
                         const LLVMTypeConverter &typeConverter,
                         Value strideBasePtr, Value index, Value stride);
+
+  /// Returns the individual values that make up this descriptor.
+  ValueRange getElements() const { return elements; }
+
+  /// Implicitly converts the descriptor to the range of its elements, so that
+  /// it can be passed wherever a ValueRange is expected.
+  /*implicit*/ operator ValueRange() const { return elements; }
+
+private:
+  SmallVector<Value> elements;
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index e78f174..2d743a9 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -76,7 +76,9 @@
 
   // This is a strided getElementPtr variant that linearizes subscripts as:
   //   `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
-  Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
+  // `memRefDesc` holds the individual (unpacked) memref descriptor elements.
+  Value getStridedElementPtr(Location loc, MemRefType type,
+                             ValueRange memRefDesc,
                              ValueRange indices,
                              ConversionPatternRewriter &rewriter) const;
 
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 38b5e49..a65f136 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -91,6 +91,22 @@
   Type convertCallingConventionType(Type type,
                                     bool useBarePointerCallConv = false) const;
 
+  /// Convert a memref type into the LLVM type(s) that capture the relevant
+  /// data and return them in `result`. If `packed` is set, a single LLVM
+  /// struct type holding the entire descriptor is produced; otherwise, one
+  /// type per descriptor element is produced.
+  LogicalResult convertMemRefType(MemRefType type,
+                                  SmallVectorImpl<Type> &result,
+                                  bool packed = false) const;
+
+  /// Convert an unranked memref type into the LLVM type(s) that capture the
+  /// runtime rank and a pointer to the static ranked memref descriptor, and
+  /// return them in `result`. `packed` has the same meaning as for
+  /// `convertMemRefType`.
+  LogicalResult convertUnrankedMemRefType(UnrankedMemRefType type,
+                                          SmallVectorImpl<Type> &result,
+                                          bool packed = false) const;
+
   /// Promote the bare pointers in 'values' that resulted from memrefs to
   /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
   /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
@@ -111,7 +127,8 @@
   /// of the platform-specific C/C++ ABI lowering related to struct argument
   /// passing.
   SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
-                                        ValueRange operands, OpBuilder &builder,
+                                        ArrayRef<ValueRange> operands,
+                                        OpBuilder &builder,
                                         bool useBarePtrCallConv = false) const;
 
   /// Promote the LLVM struct representation of one MemRef descriptor to stack
@@ -245,13 +262,6 @@
   /// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported.
   Type convertComplexType(ComplexType type) const;
 
-  /// Convert a memref type into an LLVM type that captures the relevant data.
-  Type convertMemRefType(MemRefType type) const;
-
-  /// Convert an unranked memref type to an LLVM type that captures the
-  /// runtime rank and a pointer to the static ranked memref desc
-  Type convertUnrankedMemRefType(UnrankedMemRefType type) const;
-
   /// Convert a memref type to a bare pointer to the memref element type.
   Type convertMemRefToBarePtr(BaseMemRefType type) const;
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3acd470..e5f70c4 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -178,6 +178,8 @@
   LogicalResult
   matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // TODO: Port this lowering to 1:N memref descriptors.
+#if 0
     Location loc = op.getLoc();
     Value memRef = adaptor.getSource();
     Value unconvertedMemref = op.getSource();
@@ -241,6 +243,10 @@
     }
     rewriter.replaceOp(op, result);
     return success();
+#else
+    return rewriter.notifyMatchFailure(
+        op, "lowering not yet ported to 1:N memref descriptors");
+#endif
   }
 };
 
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index debfd00..bc6613d 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -125,24 +125,34 @@
   return rewriter.applySignatureConversion(block, *conversion, converter);
 }
 
+/// Concatenates the given value ranges into a single flat list of values.
+static SmallVector<Value> flattenValueRanges(ArrayRef<ValueRange> ranges) {
+  SmallVector<Value> result;
+  for (ValueRange range : ranges)
+    llvm::append_range(result, range);
+  return result;
+}
+
 /// Convert the destination block signature (if necessary) and lower the branch
 /// op to llvm.br.
 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
   using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
+  matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> flattenedOperands =
+        flattenValueRanges(adaptor.getOperands());
     FailureOr<Block *> convertedBlock =
         getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
-                          TypeRange(adaptor.getOperands()));
+                          TypeRange(ValueRange(flattenedOperands)));
     if (failed(convertedBlock))
       return failure();
     Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
-        op, adaptor.getOperands(), *convertedBlock);
+        op, flattenedOperands, *convertedBlock);
     // TODO: We should not just forward all attributes like that. But there are
     // existing Flang tests that depend on this behavior.
-    newOp->setAttrs(op->getAttrDictionary());
+    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
     return success();
   }
 };
@@ -151,28 +161,37 @@
 /// branch op to llvm.cond_br.
 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
   using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(cf::CondBranchOp op,
-                  typename cf::CondBranchOp::Adaptor adaptor,
+  matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // The condition must map to exactly one converted value.
+    if (!llvm::hasSingleElement(adaptor.getCondition()))
+      return rewriter.notifyMatchFailure(op,
+                                         "expected single element condition");
+    SmallVector<Value> flattenedTrueDestOperands =
+        flattenValueRanges(adaptor.getTrueDestOperands());
     FailureOr<Block *> convertedTrueBlock =
         getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
-                          TypeRange(adaptor.getTrueDestOperands()));
+                          TypeRange(ValueRange(flattenedTrueDestOperands)));
     if (failed(convertedTrueBlock))
       return failure();
+    SmallVector<Value> flattenedFalseDestOperands =
+        flattenValueRanges(adaptor.getFalseDestOperands());
     FailureOr<Block *> convertedFalseBlock =
         getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
-                          TypeRange(adaptor.getFalseDestOperands()));
+                          TypeRange(ValueRange(flattenedFalseDestOperands)));
     if (failed(convertedFalseBlock))
       return failure();
     Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
-        op, adaptor.getCondition(), *convertedTrueBlock,
-        adaptor.getTrueDestOperands(), *convertedFalseBlock,
-        adaptor.getFalseDestOperands());
+        op, llvm::getSingleElement(adaptor.getCondition()), *convertedTrueBlock,
+        flattenedTrueDestOperands, *convertedFalseBlock,
+        flattenedFalseDestOperands);
     // TODO: We should not just forward all attributes like that. But there are
     // existing Flang tests that depend on this behavior.
-    newOp->setAttrs(op->getAttrDictionary());
+    // NOTE(review): inherent attributes of the cf op are no longer forwarded;
+    // confirm none of them need to be carried over to llvm.cond_br.
+    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 55f0a9a..c5c0817 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -140,15 +140,29 @@
   for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
     Value arg = wrapperFuncOp.getArgument(index + argOffset);
     if (auto memrefType = dyn_cast<MemRefType>(argType)) {
+      SmallVector<Type> convertedTypes;
+      LogicalResult status = typeConverter.convertMemRefType(
+          memrefType, convertedTypes, /*packed=*/true);
+      (void)status;
+      assert(succeeded(status) && "failed to convert memref type");
       Value loaded = rewriter.create<LLVM::LoadOp>(
-          loc, typeConverter.convertType(memrefType), arg);
-      MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
+          loc, llvm::getSingleElement(convertedTypes), arg);
+      MemRefDescriptor desc =
+          MemRefDescriptor::fromPackedStruct(rewriter, loc, loaded);
+      llvm::append_range(args, desc.getElements());
       continue;
     }
-    if (isa<UnrankedMemRefType>(argType)) {
+    if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(argType)) {
+      SmallVector<Type> convertedTypes;
+      LogicalResult status = typeConverter.convertUnrankedMemRefType(
+          unrankedMemrefType, convertedTypes, /*packed=*/true);
+      (void)status;
+      assert(succeeded(status) && "failed to convert unranked memref type");
       Value loaded = rewriter.create<LLVM::LoadOp>(
-          loc, typeConverter.convertType(argType), arg);
-      UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
+          loc, llvm::getSingleElement(convertedTypes), arg);
+      UnrankedMemRefDescriptor desc =
+          UnrankedMemRefDescriptor::fromPackedStruct(rewriter, loc, loaded);
+      llvm::append_range(args, desc.getElements());
       continue;
     }
 
@@ -231,14 +245,12 @@
       numToDrop = memRefType
                       ? MemRefDescriptor::getNumUnpackedValues(memRefType)
                       : UnrankedMemRefDescriptor::getNumUnpackedValues();
-      Value packed =
-          memRefType
-              ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
-                                       wrapperArgsRange.take_front(numToDrop))
-              : UnrankedMemRefDescriptor::pack(
-                    builder, loc, typeConverter, unrankedMemRefType,
-                    wrapperArgsRange.take_front(numToDrop));
-
+      ValueRange descVals = wrapperArgsRange.take_front(numToDrop);
+      Value packed =
+          memRefType
+              ? MemRefDescriptor(descVals).packStruct(builder, loc)
+              : UnrankedMemRefDescriptor(descVals).packStruct(builder, loc);
+
       auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
       Value one = builder.create<LLVM::ConstantOp>(
           loc, typeConverter.convertType(builder.getIndexType()),
@@ -515,9 +527,10 @@
   using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
   using Super = CallOpInterfaceLowering<CallOpType>;
   using Base = ConvertOpToLLVMPattern<CallOpType>;
+  using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;
 
   LogicalResult matchAndRewriteImpl(CallOpType callOp,
-                                    typename CallOpType::Adaptor adaptor,
+                                    Adaptor adaptor,
                                     ConversionPatternRewriter &rewriter,
                                     bool useBarePtrCallConv = false) const {
     // Pack the result types into a struct.
@@ -579,7 +592,24 @@
       return failure();
     }
 
-    rewriter.replaceOp(callOp, results);
+    // Unpack the memref descriptor structs returned by the call into their
+    // individual elements, as expected by the 1:N replacement below.
+    SmallVector<SmallVector<Value>> unpackedResults;
+    for (auto [resultType, result] : llvm::zip_equal(resultTypes, results)) {
+      SmallVector<Value> &unpacked = unpackedResults.emplace_back();
+      if (isa<MemRefType>(resultType)) {
+        auto desc = MemRefDescriptor::fromPackedStruct(
+            rewriter, callOp.getLoc(), result);
+        llvm::append_range(unpacked, desc.getElements());
+      } else if (isa<UnrankedMemRefType>(resultType)) {
+        auto desc = UnrankedMemRefDescriptor::fromPackedStruct(
+            rewriter, callOp.getLoc(), result);
+        llvm::append_range(unpacked, desc.getElements());
+      } else {
+        unpacked.push_back(result);
+      }
+    }
+    rewriter.replaceOpWithMultiple(callOp, unpackedResults);
     return success();
   }
 };
@@ -593,7 +623,7 @@
         symbolTable(symbolTable) {}
 
   LogicalResult
-  matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+  matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     bool useBarePtrCallConv = false;
     if (getTypeConverter()->getOptions().useBarePtrCallConv) {
@@ -623,7 +653,7 @@
   using Super::Super;
 
   LogicalResult
-  matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
+  matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
   }
@@ -666,7 +696,7 @@
   using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     unsigned numArguments = op.getNumOperands();
@@ -680,20 +710,42 @@
       // be returned from the memref descriptor.
-      for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
-        Type oldTy = std::get<0>(it).getType();
-        Value newOperand = std::get<1>(it);
-        if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
-                                          cast<BaseMemRefType>(oldTy))) {
-          MemRefDescriptor memrefDesc(newOperand);
-          newOperand = memrefDesc.allocatedPtr(rewriter, loc);
+      for (auto [operand, adaptorVals] :
+           llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
+        Type oldTy = operand.getType();
+        if (auto memrefTy = dyn_cast<MemRefType>(oldTy)) {
+          MemRefDescriptor memrefDesc(adaptorVals);
+          // Memrefs that cannot be passed as bare pointers are still returned
+          // as packed descriptor structs.
+          updatedOperands.push_back(
+              getTypeConverter()->canConvertToBarePtr(memrefTy)
+                  ? memrefDesc.allocatedPtr(rewriter, loc)
+                  : memrefDesc.packStruct(rewriter, loc));
         } else if (isa<UnrankedMemRefType>(oldTy)) {
           // Unranked memref is not supported in the bare pointer calling
           // convention.
           return failure();
+        } else {
+          assert(adaptorVals.size() == 1 &&
+                 "1:N conversion not supported for non-memref types");
+          updatedOperands.push_back(adaptorVals.front());
         }
-        updatedOperands.push_back(newOperand);
       }
     } else {
-      updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
+      // Pack memref descriptors into the LLVM structs that are returned.
+      for (auto [operand, adaptorVals] :
+           llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
+        Type oldTy = operand.getType();
+        if (isa<MemRefType>(oldTy)) {
+          MemRefDescriptor memrefDesc(adaptorVals);
+          updatedOperands.push_back(memrefDesc.packStruct(rewriter, loc));
+        } else if (isa<UnrankedMemRefType>(oldTy)) {
+          UnrankedMemRefDescriptor memrefDesc(adaptorVals);
+          updatedOperands.push_back(memrefDesc.packStruct(rewriter, loc));
+        } else {
+          assert(adaptorVals.size() == 1 &&
+                 "1:N conversion not supported for non-memref types");
+          updatedOperands.push_back(adaptorVals.front());
+        }
+      }
       (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
                                     updatedOperands,
                                     /*toDynamic=*/true);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index f22ad1f..79bd1582 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -76,310 +76,8 @@
 LogicalResult
 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
-  Location loc = gpuFuncOp.getLoc();
 
-  SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
-  if (encodeWorkgroupAttributionsAsArguments) {
-    // Append an `llvm.ptr` argument to the function signature to encode
-    // workgroup attributions.
-
-    ArrayRef<BlockArgument> workgroupAttributions =
-        gpuFuncOp.getWorkgroupAttributions();
-    size_t numAttributions = workgroupAttributions.size();
-
-    // Insert all arguments at the end.
-    unsigned index = gpuFuncOp.getNumArguments();
-    SmallVector<unsigned> argIndices(numAttributions, index);
-
-    // New arguments will simply be `llvm.ptr` with the correct address space
-    Type workgroupPtrType =
-        rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
-    SmallVector<Type> argTypes(numAttributions, workgroupPtrType);
-
-    // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
-    std::array attrs{
-        rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(),
-                              rewriter.getUnitAttr()),
-        rewriter.getNamedAttr(
-            getDialect().getWorkgroupAttributionAttrHelper().getName(),
-            rewriter.getUnitAttr()),
-    };
-    SmallVector<DictionaryAttr> argAttrs;
-    for (BlockArgument attribution : workgroupAttributions) {
-      auto attributionType = cast<MemRefType>(attribution.getType());
-      IntegerAttr numElements =
-          rewriter.getI64IntegerAttr(attributionType.getNumElements());
-      Type llvmElementType =
-          getTypeConverter()->convertType(attributionType.getElementType());
-      if (!llvmElementType)
-        return failure();
-      TypeAttr type = TypeAttr::get(llvmElementType);
-      attrs.back().setValue(
-          rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type));
-      argAttrs.push_back(rewriter.getDictionaryAttr(attrs));
-    }
-
-    // Location match function location
-    SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc());
-
-    // Perform signature modification
-    rewriter.modifyOpInPlace(
-        gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
-          static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
-              argIndices, argTypes, argAttrs, argLocs);
-        });
-  } else {
-    workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
-    for (auto [idx, attribution] :
-         llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
-      auto type = dyn_cast<MemRefType>(attribution.getType());
-      assert(type && type.hasStaticShape() && "unexpected type in attribution");
-
-      uint64_t numElements = type.getNumElements();
-
-      auto elementType =
-          cast<Type>(typeConverter->convertType(type.getElementType()));
-      auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
-      std::string name =
-          std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
-      uint64_t alignment = 0;
-      if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(
-              gpuFuncOp.getWorkgroupAttributionAttr(
-                  idx, LLVM::LLVMDialect::getAlignAttrName())))
-        alignment = alignAttr.getInt();
-      auto globalOp = rewriter.create<LLVM::GlobalOp>(
-          gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
-          LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
-          workgroupAddrSpace);
-      workgroupBuffers.push_back(globalOp);
-    }
-  }
-
-  // Remap proper input types.
-  TypeConverter::SignatureConversion signatureConversion(
-      gpuFuncOp.front().getNumArguments());
-
-  Type funcType = getTypeConverter()->convertFunctionSignature(
-      gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
-      getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
-  if (!funcType) {
-    return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
-      diag << "failed to convert function signature type for: "
-           << gpuFuncOp.getFunctionType();
-    });
-  }
-
-  // Create the new function operation. Only copy those attributes that are
-  // not specific to function modeling.
-  SmallVector<NamedAttribute, 4> attributes;
-  ArrayAttr argAttrs;
-  for (const auto &attr : gpuFuncOp->getAttrs()) {
-    if (attr.getName() == SymbolTable::getSymbolAttrName() ||
-        attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
-        attr.getName() ==
-            gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
-        attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
-        attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() ||
-        attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() ||
-        attr.getName() == gpuFuncOp.getKnownGridSizeAttrName())
-      continue;
-    if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
-      argAttrs = gpuFuncOp.getArgAttrsAttr();
-      continue;
-    }
-    attributes.push_back(attr);
-  }
-
-  DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
-  DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
-  // Ensure we don't lose information if the function is lowered before its
-  // surrounding context.
-  auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
-  if (knownBlockSize)
-    attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
-                            knownBlockSize);
-  if (knownGridSize)
-    attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(),
-                            knownGridSize);
-
-  // Add a dialect specific kernel attribute in addition to GPU kernel
-  // attribute. The former is necessary for further translation while the
-  // latter is expected by gpu.launch_func.
-  if (gpuFuncOp.isKernel()) {
-    if (kernelAttributeName)
-      attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
-    // Set the dialect-specific block size attribute if there is one.
-    if (kernelBlockSizeAttributeName && knownBlockSize) {
-      attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize);
-    }
-  }
-  LLVM::CConv callingConvention = gpuFuncOp.isKernel()
-                                      ? kernelCallingConvention
-                                      : nonKernelCallingConvention;
-  auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
-      gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
-      LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
-      /*comdat=*/nullptr, attributes);
-
-  {
-    // Insert operations that correspond to converted workgroup and private
-    // memory attributions to the body of the function. This must operate on
-    // the original function, before the body region is inlined in the new
-    // function to maintain the relation between block arguments and the
-    // parent operation that assigns their semantics.
-    OpBuilder::InsertionGuard guard(rewriter);
-
-    // Rewrite workgroup memory attributions to addresses of global buffers.
-    rewriter.setInsertionPointToStart(&gpuFuncOp.front());
-    unsigned numProperArguments = gpuFuncOp.getNumArguments();
-
-    if (encodeWorkgroupAttributionsAsArguments) {
-      // Build a MemRefDescriptor with each of the arguments added above.
-
-      unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
-      assert(numProperArguments >= numAttributions &&
-             "Expecting attributions to be encoded as arguments already");
-
-      // Arguments encoding workgroup attributions will be in positions
-      // [numProperArguments, numProperArguments+numAttributions)
-      ArrayRef<BlockArgument> attributionArguments =
-          gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
-                                         numAttributions);
-      for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
-               gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
-        auto [attribution, arg] = vals;
-        auto type = cast<MemRefType>(attribution.getType());
-
-        // Arguments are of llvm.ptr type and attributions are of memref type:
-        // we need to wrap them in memref descriptors.
-        Value descr = MemRefDescriptor::fromStaticShape(
-            rewriter, loc, *getTypeConverter(), type, arg);
-
-        // And remap the arguments
-        signatureConversion.remapInput(numProperArguments + idx, descr);
-      }
-    } else {
-      for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
-        auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
-                                                  global.getAddrSpace());
-        Value address = rewriter.create<LLVM::AddressOfOp>(
-            loc, ptrType, global.getSymNameAttr());
-        Value memory =
-            rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
-                                         address, ArrayRef<LLVM::GEPArg>{0, 0});
-
-        // Build a memref descriptor pointing to the buffer to plug with the
-        // existing memref infrastructure. This may use more registers than
-        // otherwise necessary given that memref sizes are fixed, but we can try
-        // and canonicalize that away later.
-        Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
-        auto type = cast<MemRefType>(attribution.getType());
-        Value descr = MemRefDescriptor::fromStaticShape(
-            rewriter, loc, *getTypeConverter(), type, memory);
-        signatureConversion.remapInput(numProperArguments + idx, descr);
-      }
-    }
-
-    // Rewrite private memory attributions to alloca'ed buffers.
-    unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
-    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
-    for (const auto [idx, attribution] :
-         llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
-      auto type = cast<MemRefType>(attribution.getType());
-      assert(type && type.hasStaticShape() && "unexpected type in attribution");
-
-      // Explicitly drop memory space when lowering private memory
-      // attributions since NVVM models it as `alloca`s in the default
-      // memory space and does not support `alloca`s with addrspace(5).
-      Type elementType = typeConverter->convertType(type.getElementType());
-      auto ptrType =
-          LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
-      Value numElements = rewriter.create<LLVM::ConstantOp>(
-          gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
-      uint64_t alignment = 0;
-      if (auto alignAttr =
-              dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
-                  idx, LLVM::LLVMDialect::getAlignAttrName())))
-        alignment = alignAttr.getInt();
-      Value allocated = rewriter.create<LLVM::AllocaOp>(
-          gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
-      Value descr = MemRefDescriptor::fromStaticShape(
-          rewriter, loc, *getTypeConverter(), type, allocated);
-      signatureConversion.remapInput(
-          numProperArguments + numWorkgroupAttributions + idx, descr);
-    }
-  }
-
-  // Move the region to the new function, update the entry block signature.
-  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
-                              llvmFuncOp.end());
-  if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
-                                         &signatureConversion)))
-    return failure();
-
-  // Get memref type from function arguments and set the noalias to
-  // pointer arguments.
-  for (const auto [idx, argTy] :
-       llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
-    auto remapping = signatureConversion.getInputMapping(idx);
-    NamedAttrList argAttr =
-        argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
-    auto copyAttribute = [&](StringRef attrName) {
-      Attribute attr = argAttr.erase(attrName);
-      if (!attr)
-        return;
-      for (size_t i = 0, e = remapping->size; i < e; ++i)
-        llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
-    };
-    auto copyPointerAttribute = [&](StringRef attrName) {
-      Attribute attr = argAttr.erase(attrName);
-
-      if (!attr)
-        return;
-      if (remapping->size > 1 &&
-          attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
-        emitWarning(llvmFuncOp.getLoc(),
-                    "Cannot copy noalias with non-bare pointers.\n");
-        return;
-      }
-      for (size_t i = 0, e = remapping->size; i < e; ++i) {
-        if (isa<LLVM::LLVMPointerType>(
-                llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
-          llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
-        }
-      }
-    };
-
-    if (argAttr.empty())
-      continue;
-
-    copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
-    copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
-    copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
-    bool lowersToPointer = false;
-    for (size_t i = 0, e = remapping->size; i < e; ++i) {
-      lowersToPointer |= isa<LLVM::LLVMPointerType>(
-          llvmFuncOp.getArgument(remapping->inputNo + i).getType());
-    }
-
-    if (lowersToPointer) {
-      copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
-      copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
-      copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
-      copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
-      copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
-      copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
-      copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
-      copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
-      copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
-      copyPointerAttribute(
-          LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
-      copyPointerAttribute(
-          LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr());
-    }
-  }
-  rewriter.eraseOp(gpuFuncOp);
-  return success();
+  return failure();
 }
 
 LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 512820b..f0b1602 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -723,8 +723,10 @@
   auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
 
-  auto arguments = getTypeConverter()->promoteOperands(
-      loc, op->getOperands(), adaptor.getOperands(), rewriter);
+  llvm_unreachable("TODO"); // Disabled during the 1:N memref migration.
+  SmallVector<Value> arguments; // Placeholder; see the disabled call below.
+  // auto arguments = getTypeConverter()->promoteOperands(
+  //     loc, op->getOperands(), adaptor.getOperands(), rewriter);
   arguments.push_back(elementSize);
   hostRegisterCallBuilder.create(loc, rewriter, arguments);
 
@@ -745,8 +747,10 @@
   auto elementType = cast<UnrankedMemRefType>(memRefType).getElementType();
   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
 
-  auto arguments = getTypeConverter()->promoteOperands(
-      loc, op->getOperands(), adaptor.getOperands(), rewriter);
+  llvm_unreachable("TODO"); // Disabled during the 1:N memref migration.
+  SmallVector<Value> arguments; // Placeholder; see the disabled call below.
+  // auto arguments = getTypeConverter()->promoteOperands(
+  //     loc, op->getOperands(), adaptor.getOperands(), rewriter);
   arguments.push_back(elementSize);
   hostUnregisterCallBuilder.create(loc, rewriter, arguments);
 
@@ -805,9 +809,12 @@
 
+  // TODO: The replaceOp calls below are disabled pending 1:N memref support.
+  // Bail out rather than return success without replacing `allocOp`.
+  return rewriter.notifyMatchFailure(allocOp, "1:N replacement not supported");
   if (allocOp.getAsyncToken()) {
     // Async alloc: make dependent ops use the same stream.
-    rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
+    // rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
   } else {
-    rewriter.replaceOp(allocOp, {memRefDescriptor});
+    // rewriter.replaceOp(allocOp, {memRefDescriptor});
   }
 
   return success();
@@ -977,9 +981,11 @@
   // Note: If `useBarePtrCallConv` is set in the type converter's options,
   // the value of `kernelBarePtrCallConv` will be ignored.
   OperandRange origArguments = launchOp.getKernelOperands();
-  SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
-      loc, origArguments, adaptor.getKernelOperands(), rewriter,
-      /*useBarePtrCallConv=*/kernelBarePtrCallConv);
+  llvm_unreachable("TODO"); // Disabled during the 1:N memref migration.
+  SmallVector<Value, 8> llvmArguments; // Placeholder; see the call below.
+  // SmallVector<Value, 8> llvmArguments = getTypeConverter()->promoteOperands(
+  //     loc, origArguments, adaptor.getKernelOperands(), rewriter,
+  //     /*useBarePtrCallConv=*/kernelBarePtrCallConv);
   SmallVector<Value, 8> llvmArgumentsWithSizes;
 
   // Intersperse size information if requested.
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 86d6643..9f8030a 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -21,19 +21,23 @@
 //===----------------------------------------------------------------------===//
 
-/// Construct a helper for the given descriptor value.
-MemRefDescriptor::MemRefDescriptor(Value descriptor)
-    : StructBuilder(descriptor) {
-  assert(value != nullptr && "value cannot be null");
-  indexType = cast<LLVM::LLVMStructType>(value.getType())
-                  .getBody()[kOffsetPosInMemRefDescriptor];
+/// Construct a helper wrapping the unpacked descriptor `elements`.
+MemRefDescriptor::MemRefDescriptor(ValueRange elements) : elements(elements) {
+  indexType = elements[kOffsetPosInMemRefDescriptor].getType();
 }
 
-/// Builds IR creating an `undef` value of the descriptor type.
+/// Builds IR creating `poison` values for each descriptor element type.
 MemRefDescriptor MemRefDescriptor::poison(OpBuilder &builder, Location loc,
-                                          Type descriptorType) {
-
-  Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
-  return MemRefDescriptor(descriptor);
+                                          TypeRange descriptorTypes) {
+  // Materialize a single poison op per distinct element type and reuse it.
+  DenseMap<Type, Value> poisonValues;
+  SmallVector<Value> elements;
+  for (Type t : descriptorTypes) {
+    Value &cached = poisonValues[t];
+    if (!cached)
+      cached = builder.create<LLVM::PoisonOp>(loc, t);
+    elements.push_back(cached);
+  }
+  return MemRefDescriptor(elements);
 }
 
 /// Builds IR creating a MemRef descriptor that represents `type` and
@@ -57,10 +61,11 @@
   assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
          "expected static strides");
 
-  auto convertedType = typeConverter.convertType(type);
-  assert(convertedType && "unexpected failure in memref type conversion");
+  SmallVector<Type> convertedTypes;
+  if (failed(typeConverter.convertType(type, convertedTypes)))
+    llvm_unreachable("unexpected failure in memref type conversion");
 
-  auto descr = MemRefDescriptor::poison(builder, loc, convertedType);
+  auto descr = MemRefDescriptor::poison(builder, loc, convertedTypes);
   descr.setAllocatedPtr(builder, loc, memory);
   descr.setAlignedPtr(builder, loc, alignedMemory);
   descr.setConstantOffset(builder, loc, offset);
@@ -73,26 +78,97 @@
   return descr;
 }
 
+/// Builds IR extracting the element at position `idx` of the `packed` struct.
+static Value extractStructElement(OpBuilder &builder, Location loc,
+                                  Value packed, ArrayRef<int64_t> idx) {
+  return builder.create<LLVM::ExtractValueOp>(loc, packed, idx);
+}
+
+/// Builds IR inserting `val` at position `idx` of the `packed` struct and
+/// returns the updated struct value.
+static Value insertStructElement(OpBuilder &builder, Location loc, Value packed,
+                                 Value val, ArrayRef<int64_t> idx) {
+  return builder.create<LLVM::InsertValueOp>(loc, packed, val, idx);
+}
+
+/// Builds IR extracting the individual elements of a packed MemRef descriptor
+/// struct, in the order: allocated pointer, aligned pointer, offset, `rank`
+/// sizes and `rank` strides.
+MemRefDescriptor MemRefDescriptor::fromPackedStruct(OpBuilder &builder,
+                                                    Location loc,
+                                                    Value packed) {
+  auto llvmStruct = cast<LLVM::LLVMStructType>(packed.getType());
+  SmallVector<Value> elements;
+  // Scalar fields: allocated pointer, aligned pointer and offset.
+  for (int64_t pos = 0; pos < kSizePosInMemRefDescriptor; ++pos)
+    elements.push_back(extractStructElement(builder, loc, packed, pos));
+  // Rank-0 descriptors carry no sizes/strides arrays.
+  if (llvmStruct.getBody().size() > kSizePosInMemRefDescriptor) {
+    auto llvmArray = cast<LLVM::LLVMArrayType>(
+        llvmStruct.getBody()[kSizePosInMemRefDescriptor]);
+    int64_t rank = llvmArray.getNumElements();
+    for (int64_t i = 0; i < rank; ++i)
+      elements.push_back(extractStructElement(
+          builder, loc, packed, {kSizePosInMemRefDescriptor, i}));
+    for (int64_t i = 0; i < rank; ++i)
+      elements.push_back(extractStructElement(
+          builder, loc, packed, {kStridePosInMemRefDescriptor, i}));
+  }
+  return MemRefDescriptor(elements);
+}
+
+/// Builds IR packing the descriptor elements into an LLVM literal struct with
+/// fields {allocated ptr, aligned ptr, offset, sizes array, strides array}.
+/// The two arrays are omitted for rank-0 descriptors.
+Value MemRefDescriptor::packStruct(OpBuilder &builder, Location loc) {
+  int64_t rank = getRank();
+  Type indexTy = elements[kOffsetPosInMemRefDescriptor].getType();
+  SmallVector<Type> fields;
+  fields.push_back(elements[kAllocatedPtrPosInMemRefDescriptor].getType());
+  fields.push_back(elements[kAlignedPtrPosInMemRefDescriptor].getType());
+  fields.push_back(indexTy);
+  if (rank > 0) {
+    auto llvmArray =
+        LLVM::LLVMArrayType::get(builder.getContext(), indexTy, rank);
+    fields.push_back(llvmArray);
+    fields.push_back(llvmArray);
+  }
+  // Every field is overwritten below, so start from a `poison` struct.
+  Value desc = builder.create<LLVM::PoisonOp>(
+      loc, LLVM::LLVMStructType::getLiteral(builder.getContext(), fields));
+  for (int64_t pos = 0; pos < kSizePosInMemRefDescriptor; ++pos)
+    desc = insertStructElement(builder, loc, desc, elements[pos], pos);
+  for (int64_t i = 0; i < rank; ++i)
+    desc = insertStructElement(builder, loc, desc,
+                               elements[kSizePosInMemRefDescriptor + i],
+                               {kSizePosInMemRefDescriptor, i});
+  for (int64_t i = 0; i < rank; ++i)
+    desc = insertStructElement(builder, loc, desc,
+                               elements[kSizePosInMemRefDescriptor + rank + i],
+                               {kStridePosInMemRefDescriptor, i});
+  return desc;
+}
+
-/// Builds IR extracting the allocated pointer from the descriptor.
+/// Returns the allocated pointer element of the descriptor.
 Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
-  return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
+  return elements[kAllocatedPtrPosInMemRefDescriptor];
 }
 
-/// Builds IR inserting the allocated pointer into the descriptor.
+/// Replaces the allocated pointer element of the descriptor with `ptr`.
 void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
                                        Value ptr) {
-  setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
+  elements[kAllocatedPtrPosInMemRefDescriptor] = ptr;
 }
 
-/// Builds IR extracting the aligned pointer from the descriptor.
+/// Returns the aligned pointer element of the descriptor.
 Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
-  return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
+  return elements[kAlignedPtrPosInMemRefDescriptor];
 }
 
-/// Builds IR inserting the aligned pointer into the descriptor.
+/// Replaces the aligned pointer element of the descriptor with `ptr`.
 void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
                                      Value ptr) {
-  setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
+  elements[kAlignedPtrPosInMemRefDescriptor] = ptr;
 }
 
 // Creates a constant Op producing a value of `resultType` from an index-typed
@@ -105,28 +165,25 @@
 
-/// Builds IR extracting the offset from the descriptor.
+/// Returns the offset element of the descriptor.
 Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
-  return builder.create<LLVM::ExtractValueOp>(loc, value,
-                                              kOffsetPosInMemRefDescriptor);
+  return elements[kOffsetPosInMemRefDescriptor];
 }
 
-/// Builds IR inserting the offset into the descriptor.
+/// Replaces the offset element of the descriptor with `offset`.
 void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
                                  Value offset) {
-  value = builder.create<LLVM::InsertValueOp>(loc, value, offset,
-                                              kOffsetPosInMemRefDescriptor);
+  elements[kOffsetPosInMemRefDescriptor] = offset;
 }
 
-/// Builds IR inserting the offset into the descriptor.
+/// Sets the offset to a newly built index constant with value `offset`.
 void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
                                          uint64_t offset) {
   setOffset(builder, loc,
             createIndexAttrConstant(builder, loc, indexType, offset));
 }
 
-/// Builds IR extracting the pos-th size from the descriptor.
+/// Returns the pos-th size element of the descriptor.
 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
-  return builder.create<LLVM::ExtractValueOp>(
-      loc, value, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
+  return elements[kSizePosInMemRefDescriptor + pos];
 }
 
 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
@@ -137,8 +194,13 @@
 
   // Copy size values to stack-allocated memory.
   auto one = createIndexAttrConstant(builder, loc, indexType, 1);
-  auto sizes = builder.create<LLVM::ExtractValueOp>(
-      loc, value, llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor}));
+  // Rebuild the sizes array from the unpacked size elements, using the same
+  // `arrayTy` as the stack slot it is stored to.
+  Value sizes = builder.create<LLVM::PoisonOp>(loc, arrayTy);
+  ValueRange sizeVals =
+      ValueRange(elements).slice(kSizePosInMemRefDescriptor, rank);
+  for (auto [idx, sizeVal] : llvm::enumerate(sizeVals))
+    sizes = builder.create<LLVM::InsertValueOp>(loc, sizes, sizeVal, idx);
   auto sizesPtr = builder.create<LLVM::AllocaOp>(loc, ptrTy, arrayTy, one,
                                                  /*alignment=*/0);
   builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);
@@ -152,40 +215,35 @@
 /// Builds IR inserting the pos-th size into the descriptor
 void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
                                Value size) {
-  value = builder.create<LLVM::InsertValueOp>(
-      loc, value, size, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos}));
+  elements[kSizePosInMemRefDescriptor + pos] = size;
 }
 
 void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
                                        unsigned pos, uint64_t size) {
   setSize(builder, loc, pos,
           createIndexAttrConstant(builder, loc, indexType, size));
 }
 
-/// Builds IR extracting the pos-th stride from the descriptor.
+/// Returns the pos-th stride element; strides follow the `rank` sizes.
 Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
-  return builder.create<LLVM::ExtractValueOp>(
-      loc, value, ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos}));
+  return elements[kSizePosInMemRefDescriptor + getRank() + pos];
 }
 
-/// Builds IR inserting the pos-th stride into the descriptor
+/// Replaces the pos-th stride element of the descriptor with `stride`.
 void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
                                  Value stride) {
-  value = builder.create<LLVM::InsertValueOp>(
-      loc, value, stride,
-      ArrayRef<int64_t>({kStridePosInMemRefDescriptor, pos}));
+  elements[kSizePosInMemRefDescriptor + getRank() + pos] = stride;
 }
 
 void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
                                          unsigned pos, uint64_t stride) {
   setStride(builder, loc, pos,
             createIndexAttrConstant(builder, loc, indexType, stride));
 }
 
 LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
   return cast<LLVM::LLVMPointerType>(
-      cast<LLVM::LLVMStructType>(value.getType())
-          .getBody()[kAlignedPtrPosInMemRefDescriptor]);
+      elements[kAlignedPtrPosInMemRefDescriptor].getType());
 }
 
 Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
@@ -212,51 +270,6 @@
   return ptr;
 }
 
-/// Creates a MemRef descriptor structure from a list of individual values
-/// composing that descriptor, in the following order:
-/// - allocated pointer;
-/// - aligned pointer;
-/// - offset;
-/// - <rank> sizes;
-/// - <rank> strides;
-/// where <rank> is the MemRef rank as provided in `type`.
-Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
-                             const LLVMTypeConverter &converter,
-                             MemRefType type, ValueRange values) {
-  Type llvmType = converter.convertType(type);
-  auto d = MemRefDescriptor::poison(builder, loc, llvmType);
-
-  d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
-  d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
-  d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);
-
-  int64_t rank = type.getRank();
-  for (unsigned i = 0; i < rank; ++i) {
-    d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
-    d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
-  }
-
-  return d;
-}
-
-/// Builds IR extracting individual elements of a MemRef descriptor structure
-/// and returning them as `results` list.
-void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
-                              MemRefType type,
-                              SmallVectorImpl<Value> &results) {
-  int64_t rank = type.getRank();
-  results.reserve(results.size() + getNumUnpackedValues(type));
-
-  MemRefDescriptor d(packed);
-  results.push_back(d.allocatedPtr(builder, loc));
-  results.push_back(d.alignedPtr(builder, loc));
-  results.push_back(d.offset(builder, loc));
-  for (int64_t i = 0; i < rank; ++i)
-    results.push_back(d.size(builder, loc, i));
-  for (int64_t i = 0; i < rank; ++i)
-    results.push_back(d.stride(builder, loc, i));
-}
-
-/// Returns the number of non-aggregate values that would be produced by
-/// `unpack`.
+/// Returns the number of non-aggregate values that make up an unpacked MemRef
+/// descriptor of `type`: three scalar fields plus `rank` sizes and strides.
 unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
@@ -264,6 +277,12 @@
   return 3 + 2 * type.getRank();
 }
 
+/// Returns the rank encoded by the number of unpacked elements: the three
+/// scalar fields are followed by `rank` sizes and `rank` strides.
+int64_t MemRefDescriptor::getRank() {
+  return (elements.size() - kSizePosInMemRefDescriptor) / 2;
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefDescriptorView implementation.
 //===----------------------------------------------------------------------===//
@@ -296,57 +311,64 @@
 //===----------------------------------------------------------------------===//
 
-/// Construct a helper for the given descriptor value.
-UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
-    : StructBuilder(descriptor) {}
+/// Construct a helper wrapping the unpacked descriptor `elements`.
+UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(ValueRange elements)
+    : elements(elements) {}
 
-/// Builds IR creating an `undef` value of the descriptor type.
-UnrankedMemRefDescriptor UnrankedMemRefDescriptor::poison(OpBuilder &builder,
-                                                          Location loc,
-                                                          Type descriptorType) {
-  Value descriptor = builder.create<LLVM::PoisonOp>(loc, descriptorType);
-  return UnrankedMemRefDescriptor(descriptor);
+/// Builds IR creating `poison` values for each descriptor element type.
+UnrankedMemRefDescriptor
+UnrankedMemRefDescriptor::poison(OpBuilder &builder, Location loc,
+                                 TypeRange descriptorTypes) {
+  // Materialize a single poison op per distinct element type and reuse it.
+  DenseMap<Type, Value> poisonValues;
+  SmallVector<Value> elements;
+  for (Type t : descriptorTypes) {
+    Value &cached = poisonValues[t];
+    if (!cached)
+      cached = builder.create<LLVM::PoisonOp>(loc, t);
+    elements.push_back(cached);
+  }
+  return UnrankedMemRefDescriptor(elements);
 }
+
+/// Builds IR extracting the rank and the underlying descriptor pointer from a
+/// packed unranked MemRef descriptor struct.
+UnrankedMemRefDescriptor
+UnrankedMemRefDescriptor::fromPackedStruct(OpBuilder &builder, Location loc,
+                                           Value packed) {
+  SmallVector<Value> elements;
+  elements.push_back(extractStructElement(builder, loc, packed, 0));
+  elements.push_back(extractStructElement(builder, loc, packed, 1));
+  return UnrankedMemRefDescriptor(elements);
+}
+
+/// Builds IR packing the rank and the descriptor pointer into an LLVM struct.
+Value UnrankedMemRefDescriptor::packStruct(OpBuilder &builder, Location loc) {
+  SmallVector<Type> fields;
+  fields.push_back(elements[0].getType());
+  fields.push_back(elements[1].getType());
+  // Both fields are overwritten below, so start from a `poison` struct.
+  Value desc = builder.create<LLVM::PoisonOp>(
+      loc, LLVM::LLVMStructType::getLiteral(builder.getContext(), fields));
+  desc = insertStructElement(builder, loc, desc, elements[0], 0);
+  desc = insertStructElement(builder, loc, desc, elements[1], 1);
+  return desc;
+}
+
 Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) const {
-  return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
+  return elements[kRankInUnrankedMemRefDescriptor];
 }
 void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
                                        Value v) {
-  setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
+  elements[kRankInUnrankedMemRefDescriptor] = v;
 }
 Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
                                               Location loc) const {
-  return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
+  return elements[kPtrInUnrankedMemRefDescriptor];
 }
 void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
                                                 Location loc, Value v) {
-  setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
-}
-
-/// Builds IR populating an unranked MemRef descriptor structure from a list
-/// of individual constituent values in the following order:
-/// - rank of the memref;
-/// - pointer to the memref descriptor.
-Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
-                                     const LLVMTypeConverter &converter,
-                                     UnrankedMemRefType type,
-                                     ValueRange values) {
-  Type llvmType = converter.convertType(type);
-  auto d = UnrankedMemRefDescriptor::poison(builder, loc, llvmType);
-
-  d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
-  d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
-  return d;
-}
-
-/// Builds IR extracting individual elements that compose an unranked memref
-/// descriptor and returns them as `results` list.
-void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
-                                      Value packed,
-                                      SmallVectorImpl<Value> &results) {
-  UnrankedMemRefDescriptor d(packed);
-  results.reserve(results.size() + 2);
-  results.push_back(d.rank(builder, loc));
-  results.push_back(d.memRefDescPtr(builder, loc));
+  elements[kPtrInUnrankedMemRefDescriptor] = v;
 }
 
 void UnrankedMemRefDescriptor::computeSizes(
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 71b6861..c5af470 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -59,7 +59,7 @@
 }
 
 Value ConvertToLLVMPattern::getStridedElementPtr(
-    Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
+    Location loc, MemRefType type, ValueRange memRefDesc, ValueRange indices,
     ConversionPatternRewriter &rewriter) const {
 
   auto [strides, offset] = type.getStridesAndOffset();
@@ -217,34 +217,31 @@
     Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
     ArrayRef<Value> sizes, ArrayRef<Value> strides,
     ConversionPatternRewriter &rewriter) const {
-  auto structType = typeConverter->convertType(memRefType);
-  auto memRefDescriptor = MemRefDescriptor::poison(rewriter, loc, structType);
+  // The descriptor is built as a flat list of its unpacked elements.
+  SmallVector<Value> elements;
 
   // Field 1: Allocated pointer, used for malloc/free.
-  memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);
+  elements.push_back(allocatedPtr);
 
   // Field 2: Actual aligned pointer to payload.
-  memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
+  elements.push_back(alignedPtr);
 
   // Field 3: Offset in aligned pointer.
   Type indexType = getIndexType();
-  memRefDescriptor.setOffset(
-      rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
+  elements.push_back(createIndexAttrConstant(rewriter, loc, indexType, 0));
 
   // Fields 4: Sizes.
-  for (const auto &en : llvm::enumerate(sizes))
-    memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());
+  llvm::append_range(elements, sizes);
 
   // Field 5: Strides.
-  for (const auto &en : llvm::enumerate(strides))
-    memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());
+  llvm::append_range(elements, strides);
 
-  return memRefDescriptor;
+  return MemRefDescriptor(elements);
 }
 
 LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
     OpBuilder &builder, Location loc, TypeRange origTypes,
     SmallVectorImpl<Value> &operands, bool toDynamic) const {
+  // TODO: Pass unpacked structs to this function.
   assert(origTypes.size() == operands.size() &&
          "expected as may original types as operands");
 
@@ -253,7 +239,8 @@
   SmallVector<unsigned> unrankedAddressSpaces;
   for (unsigned i = 0, e = operands.size(); i < e; ++i) {
     if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
-      unrankedMemrefs.emplace_back(operands[i]);
+      unrankedMemrefs.push_back(UnrankedMemRefDescriptor::fromPackedStruct(
+          builder, loc, operands[i]));
       FailureOr<unsigned> addressSpace =
           getTypeConverter()->getMemRefAddressSpace(memRefType);
       if (failed(addressSpace))
@@ -294,7 +281,8 @@
     if (!isa<UnrankedMemRefType>(type))
       continue;
     Value allocationSize = sizes[unrankedMemrefPos++];
-    UnrankedMemRefDescriptor desc(operands[i]);
+    UnrankedMemRefDescriptor desc =
+        UnrankedMemRefDescriptor::fromPackedStruct(builder, loc, operands[i]);
 
     // Allocate memory, copy, and free the source if necessary.
     Value memory =
@@ -315,16 +303,16 @@
     // times, attempting to modify its pointer can lead to memory leaks
     // (allocated twice and overwritten) or double frees (the caller does not
     // know if the descriptor points to the same memory).
-    Type descriptorType = getTypeConverter()->convertType(type);
-    if (!descriptorType)
+    SmallVector<Type> descriptorTypes;
+    if (failed(getTypeConverter()->convertType(type, descriptorTypes)))
       return failure();
     auto updatedDesc =
-        UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
+        UnrankedMemRefDescriptor::poison(builder, loc, descriptorTypes);
     Value rank = desc.rank(builder, loc);
     updatedDesc.setRank(builder, loc, rank);
     updatedDesc.setMemRefDescPtr(builder, loc, memory);
 
-    operands[i] = updatedDesc;
+    operands[i] = updatedDesc.packStruct(builder, loc);
   }
 
   return success();
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4..2113bd3 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -50,68 +50,6 @@
          isa<LLVM::LLVMPointerType>(values.front().getType());
 }
 
-/// Pack SSA values into an unranked memref descriptor struct.
-static Value packUnrankedMemRefDesc(OpBuilder &builder,
-                                    UnrankedMemRefType resultType,
-                                    ValueRange inputs, Location loc,
-                                    const LLVMTypeConverter &converter) {
-  // Note: Bare pointers are not supported for unranked memrefs because a
-  // memref descriptor cannot be built just from a bare pointer.
-  if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
-    return Value();
-  return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
-                                        inputs);
-}
-
-/// Pack SSA values into a ranked memref descriptor struct.
-static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
-                                  ValueRange inputs, Location loc,
-                                  const LLVMTypeConverter &converter) {
-  assert(resultType && "expected non-null result type");
-  if (isBarePointer(inputs))
-    return MemRefDescriptor::fromStaticShape(builder, loc, converter,
-                                             resultType, inputs[0]);
-  if (TypeRange(inputs) ==
-      converter.getMemRefDescriptorFields(resultType,
-                                          /*unpackAggregates=*/true))
-    return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
-  // The inputs are neither a bare pointer nor an unpacked memref descriptor.
-  // This materialization function cannot be used.
-  return Value();
-}
-
-/// MemRef descriptor elements -> UnrankedMemRefType
-static Value unrankedMemRefMaterialization(OpBuilder &builder,
-                                           UnrankedMemRefType resultType,
-                                           ValueRange inputs, Location loc,
-                                           const LLVMTypeConverter &converter) {
-  // A source materialization must return a value of type
-  // `resultType`, so insert a cast from the memref descriptor type
-  // (!llvm.struct) to the original memref type.
-  Value packed =
-      packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
-  if (!packed)
-    return Value();
-  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
-      .getResult(0);
-}
-
-/// MemRef descriptor elements -> MemRefType
-static Value rankedMemRefMaterialization(OpBuilder &builder,
-                                         MemRefType resultType,
-                                         ValueRange inputs, Location loc,
-                                         const LLVMTypeConverter &converter) {
-  // A source materialization must return a value of type `resultType`,
-  // so insert a cast from the memref descriptor type (!llvm.struct) to the
-  // original memref type.
-  Value packed =
-      packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
-  if (!packed)
-    return Value();
-  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
-      .getResult(0);
-}
-
 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                      const LowerToLLVMOptions &options,
@@ -126,9 +64,22 @@
   addConversion([&](FunctionType type) { return convertFunctionType(type); });
   addConversion([&](IndexType type) { return convertIndexType(type); });
   addConversion([&](IntegerType type) { return convertIntegerType(type); });
-  addConversion([&](MemRefType type) { return convertMemRefType(type); });
   addConversion(
-      [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
+      [&](MemRefType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        // Forward convertMemRefType's result as-is: returning std::nullopt on
+        // failure would mean "not applicable" and let other conversions be
+        // tried instead of reporting a conversion failure.
+        return convertMemRefType(type, result);
+      });
+  addConversion(
+      [&](UnrankedMemRefType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        // As above: forward failure rather than mapping it to std::nullopt.
+        // The pre-1:N code returned a null Type here, which TypeConverter
+        // likewise treats as a conversion failure.
+        return convertUnrankedMemRefType(type, result);
+      });
   addConversion([&](VectorType type) -> std::optional<Type> {
     FailureOr<Type> llvmType = convertVectorType(type);
     if (failed(llvmType))
@@ -228,42 +179,26 @@
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   });
-  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs, Location loc) {
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
-        .getResult(0);
+  addTargetMaterialization([&](OpBuilder &builder, TypeRange resultTypes,
+                               ValueRange inputs,
+                               Location loc) -> SmallVector<Value> {
+    auto castOp =
+        builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
+    return llvm::map_to_vector(castOp.getResults(),
+                               [](OpResult r) -> Value { return r; });
   });
 
-  // Source materializations convert from the new block argument types
-  // (multiple SSA values that make up a memref descriptor) back to the
-  // original block argument type.
-  addSourceMaterialization([&](OpBuilder &builder,
-                               UnrankedMemRefType resultType, ValueRange inputs,
-                               Location loc) {
-    return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
-                                         *this);
-  });
   addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                ValueRange inputs, Location loc) {
-    return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
-  });
-
-  // Bare pointer -> Packed MemRef descriptor
-  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs, Location loc,
-                               Type originalType) -> Value {
-    // The original MemRef type is required to build a MemRef descriptor
-    // because the sizes/strides of the MemRef cannot be inferred from just the
-    // bare pointer.
-    if (!originalType)
-      return Value();
-    if (resultType != convertType(originalType))
-      return Value();
-    if (auto memrefType = dyn_cast<MemRefType>(originalType))
-      return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
-    if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
-      return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
-                                    *this);
+    if (isBarePointer(inputs)) {
+      MemRefDescriptor desc = MemRefDescriptor::fromStaticShape(
+          builder, loc, *this, resultType, inputs[0]);
+      return builder
+          .create<UnrealizedConversionCastOp>(loc, resultType,
+                                              desc.getElements())
+          .getResult(0);
+    }
+    // Default materialization creates unrealized_conversion_cast.
     return Value();
   });
 
@@ -430,8 +365,10 @@
   Type resultType = type.getNumResults() == 0
                         ? LLVM::LLVMVoidType::get(&getContext())
                         : packFunctionResults(type.getResults());
-  if (!resultType)
+  if (!resultType) {
+    // Packing the results failed: report it with a null type, don't abort.
     return {};
+  }
 
   auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
   auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
@@ -443,9 +380,11 @@
   }
 
   for (Type t : type.getInputs()) {
-    auto converted = convertType(t);
-    if (!converted || !LLVM::isCompatibleType(converted))
+    auto converted = convertCallingConventionType(t);
+    if (!converted || !LLVM::isCompatibleType(converted)) {
+      // Unconvertible input: report failure with a null type, don't abort.
       return {};
+    }
     if (isa<MemRefType, UnrankedMemRefType>(t))
       converted = ptrType;
     inputs.push_back(converted);
@@ -533,14 +472,18 @@
 
 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
-Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
-  // When converting a MemRefType to a struct with descriptor fields, do not
-  // unpack the `sizes` and `strides` arrays.
-  SmallVector<Type, 5> types =
-      getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
-  if (types.empty())
-    return {};
-  return LLVM::LLVMStructType::getLiteral(&getContext(), types);
+LogicalResult LLVMTypeConverter::convertMemRefType(
+    MemRefType type, SmallVectorImpl<Type> &result, bool packed) const {
+  SmallVector<Type, 5> fields =
+      getMemRefDescriptorFields(type, /*unpackAggregates=*/!packed);
+  if (fields.empty())
+    return failure();
+  if (packed) {
+    result.push_back(LLVM::LLVMStructType::getLiteral(&getContext(), fields));
+  } else {
+    llvm::append_range(result, fields);
+  }
+  return success();
 }
 
 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
@@ -563,12 +506,17 @@
          llvm::divideCeil(getPointerBitwidth(space), 8);
 }
 
-Type LLVMTypeConverter::convertUnrankedMemRefType(
-    UnrankedMemRefType type) const {
+LogicalResult LLVMTypeConverter::convertUnrankedMemRefType(
+    UnrankedMemRefType type, SmallVectorImpl<Type> &result, bool packed) const {
   if (!convertType(type.getElementType()))
-    return {};
-  return LLVM::LLVMStructType::getLiteral(&getContext(),
-                                          getUnrankedMemRefDescriptorFields());
+    return failure();
+  if (packed) {
+    result.push_back(LLVM::LLVMStructType::getLiteral(
+        &getContext(), getUnrankedMemRefDescriptorFields()));
+  } else {
+    llvm::append_range(result, getUnrankedMemRefDescriptorFields());
+  }
+  return success();
 }
 
 FailureOr<unsigned>
@@ -665,6 +613,20 @@
     if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
       return convertMemRefToBarePtr(memrefTy);
 
+  // Memref types become a single packed descriptor struct, not unpacked fields.
+  if (auto memrefTy = dyn_cast<MemRefType>(type)) {
+    SmallVector<Type> convertedType;
+    if (failed(convertMemRefType(memrefTy, convertedType, /*packed=*/true)))
+      return Type();
+    return llvm::getSingleElement(convertedType);
+  }
+  if (auto unrankedMemrefTy = dyn_cast<UnrankedMemRefType>(type)) {
+    SmallVector<Type> convertedType;
+    if (failed(convertUnrankedMemRefType(unrankedMemrefTy, convertedType,
+                                         /*packed=*/true)))
+      return Type();
+    return llvm::getSingleElement(convertedType);
+  }
   return convertType(type);
 }
 
@@ -674,12 +636,15 @@
 void LLVMTypeConverter::promoteBarePtrsToDescriptors(
     ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
     SmallVectorImpl<Value> &values) const {
-  assert(stdTypes.size() == values.size() &&
-         "The number of types and values doesn't match");
-  for (unsigned i = 0, end = values.size(); i < end; ++i)
-    if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
-      values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
-                                                    memrefTy, values[i]);
+  /* TODO: port to the ValueRange-based MemRefDescriptor. Old 1:1 code:
+    assert(stdTypes.size() == values.size() &&
+           "The number of types and values doesn't match");
+    for (unsigned i = 0, end = values.size(); i < end; ++i)
+      if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
+        values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
+                                                      memrefTy, values[i]);
+  */
+  llvm_unreachable("not implemented");
 }
 
 /// Convert a non-empty list of types of values produced by an operation into an
@@ -743,38 +708,28 @@
 
 SmallVector<Value, 4>
 LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
-                                   ValueRange operands, OpBuilder &builder,
-                                   bool useBarePtrCallConv) const {
+                                   ArrayRef<ValueRange> operands,
+                                   OpBuilder &builder,
+                                   bool useBarePtrCallConv) const {
   SmallVector<Value, 4> promotedOperands;
   promotedOperands.reserve(operands.size());
   useBarePtrCallConv |= options.useBarePtrCallConv;
   for (auto it : llvm::zip(opOperands, operands)) {
     auto operand = std::get<0>(it);
-    auto llvmOperand = std::get<1>(it);
+    auto llvmOperands = std::get<1>(it);
 
     if (useBarePtrCallConv) {
       // For the bare-ptr calling convention, we only have to extract the
       // aligned pointer of a memref.
       if (dyn_cast<MemRefType>(operand.getType())) {
-        MemRefDescriptor desc(llvmOperand);
-        llvmOperand = desc.alignedPtr(builder, loc);
+        MemRefDescriptor desc(llvmOperands);
+        promotedOperands.push_back(desc.alignedPtr(builder, loc));
+        continue;
       } else if (isa<UnrankedMemRefType>(operand.getType())) {
         llvm_unreachable("Unranked memrefs are not supported");
       }
-    } else {
-      if (isa<UnrankedMemRefType>(operand.getType())) {
-        UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
-                                         promotedOperands);
-        continue;
-      }
-      if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
-        MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
-                                 promotedOperands);
-        continue;
-      }
     }
-
-    promotedOperands.push_back(llvmOperand);
+    llvm::append_range(promotedOperands, llvmOperands);
   }
   return promotedOperands;
 }
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index c5b2e83..c072723 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -195,6 +195,6 @@
       loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);
 
   // Return the final value of the descriptor.
-  rewriter.replaceOp(op, {memRefDescriptor});
+  rewriter.replaceOpWithMultiple(op, {memRefDescriptor});
   return success();
 }
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index cb4317e..a12507b 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Pass/Pass.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Support/MathExtras.h"
 #include <optional>
@@ -185,15 +186,14 @@
       : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}
 
   LogicalResult
-  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
+  matchAndRewrite(memref::AssumeAlignmentOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Value memref = adaptor.getMemref();
     unsigned alignment = op.getAlignment();
     auto loc = op.getLoc();
 
     auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
-    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
-                                     rewriter);
+    Value ptr = getStridedElementPtr(loc, srcMemRefType, adaptor.getMemref(),
+                                     /*indices=*/{}, rewriter);
 
     // Emit llvm.assume(true) ["align"(memref, alignment)].
     // This is more direct than ptrtoint-based checks, is explicitly supported,
@@ -220,7 +220,7 @@
       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
 
   LogicalResult
-  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
+  matchAndRewrite(memref::DeallocOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Insert the `free` declaration if it is not already present.
     FailureOr<LLVM::LLVMFuncOp> freeFunc =
@@ -253,21 +253,20 @@
   using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
+  matchAndRewrite(memref::DimOp dimOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type operandType = dimOp.getSource().getType();
     if (isa<UnrankedMemRefType>(operandType)) {
-      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
-          operandType, dimOp, adaptor.getOperands(), rewriter);
+      FailureOr<Value> extractedSize =
+          extractSizeOfUnrankedMemRef(operandType, dimOp, adaptor, rewriter);
       if (failed(extractedSize))
         return failure();
       rewriter.replaceOp(dimOp, {*extractedSize});
       return success();
     }
     if (isa<MemRefType>(operandType)) {
-      rewriter.replaceOp(
-          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
-                                            adaptor.getOperands(), rewriter)});
+      rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
+                                                           adaptor, rewriter)});
       return success();
     }
     llvm_unreachable("expected MemRefType or UnrankedMemRefType");
@@ -276,7 +275,7 @@
 private:
   FailureOr<Value>
   extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
-                              OpAdaptor adaptor,
+                              OneToNOpAdaptor &adaptor,
                               ConversionPatternRewriter &rewriter) const {
     Location loc = dimOp.getLoc();
 
@@ -298,20 +297,24 @@
     UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
 
-    Type elementType = typeConverter->convertType(scalarMemRefType);
+    // Packed descriptor struct type; used as the GEP element type below.
+    SmallVector<Type> convertedMemRefType;
+    if (failed(getTypeConverter()->convertMemRefType(
+            scalarMemRefType, convertedMemRefType, /*packed=*/true)))
+      return failure();
 
     // Get pointer to offset field of memref<element_type> descriptor.
     auto indexPtrTy =
         LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
     Value offsetPtr = rewriter.create<LLVM::GEPOp>(
-        loc, indexPtrTy, elementType, underlyingRankedDesc,
-        ArrayRef<LLVM::GEPArg>{0, 2});
+        loc, indexPtrTy, llvm::getSingleElement(convertedMemRefType),
+        underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2});
 
     // The size value that we have to extract can be obtained using GEPop with
     // `dimOp.index() + 1` index argument.
     Value idxPlusOne = rewriter.create<LLVM::AddOp>(
         loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
-        adaptor.getIndex());
+        llvm::getSingleElement(adaptor.getIndex()));
     Value sizePtr = rewriter.create<LLVM::GEPOp>(
         loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
         idxPlusOne);
@@ -331,7 +334,7 @@
   }
 
   Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
-                                  OpAdaptor adaptor,
+                                  OneToNOpAdaptor &adaptor,
                                   ConversionPatternRewriter &rewriter) const {
     Location loc = dimOp.getLoc();
 
@@ -351,7 +354,7 @@
         return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
       }
     }
-    Value index = adaptor.getIndex();
+    Value index = llvm::getSingleElement(adaptor.getIndex());
     int64_t rank = memRefType.getRank();
     MemRefDescriptor memrefDescriptor(adaptor.getSource());
     return memrefDescriptor.size(rewriter, loc, index, rank);
@@ -400,7 +403,7 @@
   using Base::Base;
 
   LogicalResult
-  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
+  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = atomicOp.getLoc();
     Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
@@ -416,8 +419,12 @@
     // Compute the loaded value and branch to the loop block.
     rewriter.setInsertionPointToEnd(initBlock);
     auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
+    SmallVector<Value> indices =
+        llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) {
+          return llvm::getSingleElement(r);
+        });
     auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
-                                        adaptor.getIndices(), rewriter);
+                                        indices, rewriter);
     Value init = rewriter.create<LLVM::LoadOp>(
         loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
     rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
@@ -579,13 +586,15 @@
   using Base::Base;
 
   LogicalResult
-  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
+  matchAndRewrite(memref::LoadOp loadOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto type = loadOp.getMemRefType();
-
-    Value dataPtr =
-        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
-                             adaptor.getIndices(), rewriter);
+    SmallVector<Value> indices =
+        llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) {
+          return llvm::getSingleElement(r);
+        });
+    Value dataPtr = getStridedElementPtr(
+        loadOp.getLoc(), type, adaptor.getMemref(), indices, rewriter);
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
         loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
         false, loadOp.getNontemporal());
@@ -599,14 +608,18 @@
   using Base::Base;
 
   LogicalResult
-  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+  matchAndRewrite(memref::StoreOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto type = op.getMemRefType();
-
+    SmallVector<Value> indices =
+        llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) {
+          return llvm::getSingleElement(r);
+        });
     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
-                                         adaptor.getIndices(), rewriter);
-    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
-                                               0, false, op.getNontemporal());
+                                         indices, rewriter);
+    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
+        op, llvm::getSingleElement(adaptor.getValue()), dataPtr, 0, false,
+        op.getNontemporal());
     return success();
   }
 };
@@ -617,13 +630,16 @@
   using Base::Base;
 
   LogicalResult
-  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
+  matchAndRewrite(memref::PrefetchOp prefetchOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto type = prefetchOp.getMemRefType();
     auto loc = prefetchOp.getLoc();
-
-    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
-                                         adaptor.getIndices(), rewriter);
+    SmallVector<Value> indices =
+        llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) {
+          return llvm::getSingleElement(r);
+        });
+    Value dataPtr =
+        getStridedElementPtr(loc, type, adaptor.getMemref(), indices, rewriter);
 
     // Replace with llvm.prefetch.
     IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
@@ -640,7 +656,7 @@
   using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
+  matchAndRewrite(memref::RankOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Type operandType = op.getMemref().getType();
@@ -664,8 +680,9 @@
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
+  matchAndRewrite(memref::CastOp memRefCastOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto loc = memRefCastOp.getLoc();
     Type srcType = memRefCastOp.getOperand().getType();
     Type dstType = memRefCastOp.getType();
 
@@ -674,21 +691,21 @@
     // and require source and result type to have the same rank. Therefore,
     // perform a sanity check that the underlying structs are the same. Once op
     // semantics are relaxed we can revisit.
+    SmallVector<Type> convertedSrc, convertedDst;
+    if (failed(typeConverter->convertType(srcType, convertedSrc)) ||
+        failed(typeConverter->convertType(dstType, convertedDst)))
+      return failure();
     if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
-      if (typeConverter->convertType(srcType) !=
-          typeConverter->convertType(dstType))
+      if (!llvm::equal(convertedSrc, convertedDst))
         return failure();
 
     // Unranked to unranked cast is disallowed
     if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
       return failure();
 
-    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
-    auto loc = memRefCastOp.getLoc();
-
     // For ranked/ranked case, just keep the original descriptor.
     if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
-      rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
+      rewriter.replaceOpWithMultiple(memRefCastOp, {adaptor.getSource()});
       return success();
     }
 
@@ -701,19 +718,20 @@
       int64_t rank = srcMemRefType.getRank();
       // ptr = AllocaOp sizeof(MemRefDescriptor)
       auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
-          loc, adaptor.getSource(), rewriter);
+          loc, MemRefDescriptor(adaptor.getSource()).packStruct(rewriter, loc),
+          rewriter);
 
       // rank = ConstantOp srcRank
       auto rankVal = rewriter.create<LLVM::ConstantOp>(
           loc, getIndexType(), rewriter.getIndexAttr(rank));
       // poison = PoisonOp
       UnrankedMemRefDescriptor memRefDesc =
-          UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
+          UnrankedMemRefDescriptor::poison(rewriter, loc, convertedDst);
       // d1 = InsertValueOp poison, rank, 0
       memRefDesc.setRank(rewriter, loc, rankVal);
       // d2 = InsertValueOp d1, ptr, 1
       memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
-      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);
+      rewriter.replaceOpWithMultiple(memRefCastOp, {memRefDesc});
 
     } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
       // Casting from unranked type to ranked.
@@ -722,10 +740,16 @@
       UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
       // ptr = ExtractValueOp src, 1
       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
-
       // struct = LoadOp ptr
-      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
-      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
+      SmallVector<Type> targetStructType;
+      if (failed(getTypeConverter()->convertMemRefType(
+              cast<MemRefType>(dstType), targetStructType, /*packed=*/true)))
+        return failure();
+      auto loadOp = rewriter.create<LLVM::LoadOp>(
+          loc, llvm::getSingleElement(targetStructType), ptr);
+      rewriter.replaceOpWithMultiple(memRefCastOp,
+                                     {MemRefDescriptor::fromPackedStruct(
+                                         rewriter, loc, loadOp.getResult())});
     } else {
       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
     }
@@ -743,7 +767,7 @@
   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
+  lowerToMemCopyIntrinsic(memref::CopyOp op, OneToNOpAdaptor adaptor,
                           ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
     auto srcType = dyn_cast<MemRefType>(op.getSource().getType());
@@ -782,74 +806,75 @@
     return success();
   }
 
+  /*
+    LogicalResult
+    lowerToMemCopyFunctionCall(memref::CopyOp op, OneToNOpAdaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
+      auto loc = op.getLoc();
+      auto srcType = cast<BaseMemRefType>(op.getSource().getType());
+      auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
+
+      // First make sure we have an unranked memref descriptor representation.
+      auto makeUnranked = [&, this](Value ranked, MemRefType type) {
+        auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+                                                      type.getRank());
+        auto *typeConverter = getTypeConverter();
+        auto ptr =
+            typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
+
+        auto unrankedType = UnrankedMemRefType::get(type.getElementType(),
+                                                    type.getMemorySpace());
+        return UnrankedMemRefDescriptor::pack(
+            rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
+      };
+
+      // Save stack position before promoting descriptors
+      auto stackSaveOp =
+          rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
+
+      auto srcMemRefType = dyn_cast<MemRefType>(srcType);
+      Value unrankedSource =
+          srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
+                        : adaptor.getSource();
+      auto targetMemRefType = dyn_cast<MemRefType>(targetType);
+      Value unrankedTarget =
+          targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
+                           : adaptor.getTarget();
+
+      // Now promote the unranked descriptors to the stack.
+      auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+                                                   rewriter.getIndexAttr(1));
+      auto promote = [&](Value desc) {
+        auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+        auto allocated =
+            rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
+        rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
+        return allocated;
+      };
+
+      auto sourcePtr = promote(unrankedSource);
+      auto targetPtr = promote(unrankedTarget);
+
+      // Derive size from llvm.getelementptr which will account for any
+      // potential alignment
+      auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
+      auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
+          op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
+      if (failed(copyFn))
+        return failure();
+      rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
+                                    ValueRange{elemSize, sourcePtr, targetPtr});
+
+      // Restore stack used for descriptors
+      rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
+
+      rewriter.eraseOp(op);
+
+      return success();
+    }
+  */
   LogicalResult
-  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
-                             ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
-    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
-
-    // First make sure we have an unranked memref descriptor representation.
-    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
-      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
-                                                    type.getRank());
-      auto *typeConverter = getTypeConverter();
-      auto ptr =
-          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
-
-      auto unrankedType =
-          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
-      return UnrankedMemRefDescriptor::pack(
-          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
-    };
-
-    // Save stack position before promoting descriptors
-    auto stackSaveOp =
-        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
-
-    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
-    Value unrankedSource =
-        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
-                      : adaptor.getSource();
-    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
-    Value unrankedTarget =
-        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
-                         : adaptor.getTarget();
-
-    // Now promote the unranked descriptors to the stack.
-    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
-                                                 rewriter.getIndexAttr(1));
-    auto promote = [&](Value desc) {
-      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
-      auto allocated =
-          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
-      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
-      return allocated;
-    };
-
-    auto sourcePtr = promote(unrankedSource);
-    auto targetPtr = promote(unrankedTarget);
-
-    // Derive size from llvm.getelementptr which will account for any
-    // potential alignment
-    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
-    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
-        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
-    if (failed(copyFn))
-      return failure();
-    rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
-                                  ValueRange{elemSize, sourcePtr, targetPtr});
-
-    // Restore stack used for descriptors
-    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
-
-    rewriter.eraseOp(op);
-
-    return success();
-  }
-
-  LogicalResult
-  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
+  matchAndRewrite(memref::CopyOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcType = cast<BaseMemRefType>(op.getSource().getType());
     auto targetType = cast<BaseMemRefType>(op.getTarget().getType());
@@ -868,7 +893,8 @@
     if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
 
-    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
+    // TODO: re-enable lowerToMemCopyFunctionCall for non-contiguous copies.
+    return failure();
   }
 };
 
@@ -878,26 +904,23 @@
       memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
+  matchAndRewrite(memref::MemorySpaceCastOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-
     Type resultType = op.getDest().getType();
-    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
-      auto resultDescType =
-          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
-      Type newPtrType = resultDescType.getBody()[0];
+    SmallVector<Type> convertedResultTypes;
+    if (failed(typeConverter->convertType(resultType, convertedResultTypes)))
+      return failure();
 
-      SmallVector<Value> descVals;
-      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
-                               descVals);
+    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
+      Type newPtrType = convertedResultTypes[0];
+
+      SmallVector<Value> descVals = llvm::to_vector(adaptor.getSource());
       descVals[0] =
           rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
       descVals[1] =
           rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
-      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
-                                            resultTypeR, descVals);
-      rewriter.replaceOp(op, result);
+      rewriter.replaceOpWithMultiple(op, {descVals});
       return success();
     }
     if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
@@ -922,8 +945,8 @@
       Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);
 
       // Create and allocate storage for new memref descriptor.
-      auto result = UnrankedMemRefDescriptor::poison(
-          rewriter, loc, typeConverter->convertType(resultTypeU));
+      auto result =
+          UnrankedMemRefDescriptor::poison(rewriter, loc, convertedResultTypes);
       result.setRank(rewriter, loc, rank);
       SmallVector<Value, 1> sizes;
       UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
@@ -972,7 +995,7 @@
       rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                       copySize, /*isVolatile=*/false);
 
-      rewriter.replaceOp(op, ValueRange{result});
+      rewriter.replaceOpWithMultiple(op, {result.getElements()});
       return success();
     }
     return rewriter.notifyMatchFailure(loc, "unexpected memref type");
@@ -986,7 +1009,7 @@
                                      ConversionPatternRewriter &rewriter,
                                      const LLVMTypeConverter &typeConverter,
                                      Value originalOperand,
-                                     Value convertedOperand,
+                                     ValueRange convertedOperand,
                                      Value *allocatedPtr, Value *alignedPtr,
                                      Value *offset = nullptr) {
   Type operandType = originalOperand.getType();
@@ -1026,33 +1049,32 @@
       memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
+  matchAndRewrite(memref::ReinterpretCastOp castOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type srcType = castOp.getSource().getType();
 
-    Value descriptor;
+    SmallVector<Value> descriptor;
     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                                adaptor, &descriptor)))
       return failure();
-    rewriter.replaceOp(castOp, {descriptor});
+    rewriter.replaceOpWithMultiple(castOp, {descriptor});
     return success();
   }
 
 private:
   LogicalResult convertSourceMemRefToDescriptor(
       ConversionPatternRewriter &rewriter, Type srcType,
-      memref::ReinterpretCastOp castOp,
-      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
+      memref::ReinterpretCastOp castOp, OneToNOpAdaptor adaptor,
+      SmallVector<Value> *descriptor) const {
     MemRefType targetMemRefType =
         cast<MemRefType>(castOp.getResult().getType());
-    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
-        typeConverter->convertType(targetMemRefType));
-    if (!llvmTargetDescriptorTy)
+    SmallVector<Type> convertedTypes;
+    if (failed(typeConverter->convertType(targetMemRefType, convertedTypes)))
       return failure();
 
     // Create descriptor.
     Location loc = castOp.getLoc();
-    auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
+    auto desc = MemRefDescriptor::poison(rewriter, loc, convertedTypes);
 
     // Set allocated and aligned pointers.
     Value allocatedPtr, alignedPtr;
@@ -1064,7 +1086,8 @@
 
     // Set offset.
     if (castOp.isDynamicOffset(0))
-      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
+      desc.setOffset(rewriter, loc,
+                     llvm::getSingleElement(adaptor.getOffsets()[0]));
     else
       desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));
 
@@ -1073,16 +1096,19 @@
     unsigned dynStrideId = 0;
     for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
       if (castOp.isDynamicSize(i))
-        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
+        desc.setSize(rewriter, loc, i,
+                     llvm::getSingleElement(adaptor.getSizes()[dynSizeId++]));
       else
         desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));
 
       if (castOp.isDynamicStride(i))
-        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
+        desc.setStride(
+            rewriter, loc, i,
+            llvm::getSingleElement(adaptor.getStrides()[dynStrideId++]));
       else
         desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
     }
-    *descriptor = desc;
+    llvm::append_range(*descriptor, desc.getElements());
     return success();
   }
 };
@@ -1092,15 +1118,15 @@
   using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
+  matchAndRewrite(memref::ReshapeOp reshapeOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type srcType = reshapeOp.getSource().getType();
 
-    Value descriptor;
+    SmallVector<Value> descriptor;
     if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                                adaptor, &descriptor)))
       return failure();
-    rewriter.replaceOp(reshapeOp, {descriptor});
+    rewriter.replaceOpWithMultiple(reshapeOp, {descriptor});
     return success();
   }
 
@@ -1108,21 +1134,19 @@
   LogicalResult
   convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                   Type srcType, memref::ReshapeOp reshapeOp,
-                                  memref::ReshapeOp::Adaptor adaptor,
-                                  Value *descriptor) const {
+                                  OneToNOpAdaptor adaptor,
+                                  SmallVector<Value> *descriptor) const {
     auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
     if (shapeMemRefType.hasStaticShape()) {
       MemRefType targetMemRefType =
           cast<MemRefType>(reshapeOp.getResult().getType());
-      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
-          typeConverter->convertType(targetMemRefType));
-      if (!llvmTargetDescriptorTy)
+      SmallVector<Type> convertedTypes;
+      if (failed(typeConverter->convertType(targetMemRefType, convertedTypes)))
         return failure();
 
       // Create descriptor.
       Location loc = reshapeOp.getLoc();
-      auto desc =
-          MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
+      auto desc = MemRefDescriptor::poison(rewriter, loc, convertedTypes);
 
       // Set allocated and aligned pointers.
       Value allocatedPtr, alignedPtr;
@@ -1188,7 +1212,7 @@
         stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
       }
 
-      *descriptor = desc;
+      llvm::append_range(*descriptor, desc.getElements());
       return success();
     }
 
@@ -1204,8 +1228,11 @@
 
     // Create the unranked memref descriptor that holds the ranked one. The
     // inner descriptor is allocated on stack.
+    SmallVector<Type> convertedTypes;
+    if (failed(typeConverter->convertType(targetType, convertedTypes)))
+      return failure();
     auto targetDesc = UnrankedMemRefDescriptor::poison(
-        rewriter, loc, typeConverter->convertType(targetType));
+        rewriter, loc, convertedTypes);
     targetDesc.setRank(rewriter, loc, resultRank);
     SmallVector<Value, 4> sizes;
     UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
@@ -1303,7 +1330,7 @@
     // Reset position to beginning of new remainder block.
     rewriter.setInsertionPointToStart(remainder);
 
-    *descriptor = targetDesc;
+    llvm::append_range(*descriptor, targetDesc.getElements());
     return success();
   }
 };
@@ -1315,10 +1342,11 @@
     : public ConvertOpToLLVMPattern<ReshapeOp> {
 public:
   using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
-  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
+  using ReshapeOpAdaptor =
+      typename ConvertOpToLLVMPattern<ReshapeOp>::OneToNOpAdaptor;
 
   LogicalResult
-  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
+  matchAndRewrite(ReshapeOp reshapeOp, ReshapeOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     return rewriter.notifyMatchFailure(
         reshapeOp,
@@ -1332,7 +1360,7 @@
   using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
+  matchAndRewrite(memref::SubViewOp subViewOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     return rewriter.notifyMatchFailure(
         subViewOp, "subview operations should have been expanded beforehand");
@@ -1351,7 +1379,7 @@
   using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
+  matchAndRewrite(memref::TransposeOp transposeOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = transposeOp.getLoc();
     MemRefDescriptor viewMemRef(adaptor.getIn());
@@ -1360,9 +1388,14 @@
-    if (transposeOp.getPermutation().isIdentity())
-      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();
+    if (transposeOp.getPermutation().isIdentity()) {
+      // Identity permutation: the result is the source descriptor unchanged.
+      rewriter.replaceOpWithMultiple(transposeOp, {viewMemRef.getElements()});
+      return success();
+    }
 
-    auto targetMemRef = MemRefDescriptor::poison(
-        rewriter, loc,
-        typeConverter->convertType(transposeOp.getIn().getType()));
+    SmallVector<Type> convertedTypes;
+    if (failed(typeConverter->convertType(transposeOp.getIn().getType(),
+                                          convertedTypes)))
+      return failure();
+    auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, convertedTypes);
 
     // Copy the base and aligned pointers from the old descriptor to the new
     // one.
@@ -1388,7 +1418,7 @@
                              viewMemRef.stride(rewriter, loc, sourcePos));
     }
 
-    rewriter.replaceOp(transposeOp, {targetMemRef});
+    rewriter.replaceOpWithMultiple(transposeOp, {targetMemRef.getElements()});
     return success();
   }
 };
@@ -1434,17 +1464,19 @@
   }
 
   LogicalResult
-  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
+  matchAndRewrite(memref::ViewOp viewOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = viewOp.getLoc();
 
     auto viewMemRefType = viewOp.getType();
     auto targetElementTy =
         typeConverter->convertType(viewMemRefType.getElementType());
-    auto targetDescTy = typeConverter->convertType(viewMemRefType);
-    if (!targetDescTy || !targetElementTy ||
-        !LLVM::isCompatibleType(targetElementTy) ||
-        !LLVM::isCompatibleType(targetDescTy))
+    // Both the (1:N) descriptor types and the element type must be LLVM types.
+    SmallVector<Type> targetDescTy;
+    if (failed(typeConverter->convertType(viewMemRefType, targetDescTy)) ||
+        !llvm::all_of(targetDescTy,
+                      [](Type t) { return LLVM::isCompatibleType(t); }) ||
+        !targetElementTy || !LLVM::isCompatibleType(targetElementTy))
       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
              failure();
 
@@ -1475,7 +1507,7 @@
     alignedPtr = rewriter.create<LLVM::GEPOp>(
         loc, alignedPtr.getType(),
         typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
-        adaptor.getByteShift());
+        llvm::getSingleElement(adaptor.getByteShift()));
 
     targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);
 
@@ -1493,10 +1525,14 @@
 
     // Fields 4 and 5: Update sizes and strides.
     Value stride = nullptr, nextSize = nullptr;
+    SmallVector<Value> sizes =
+        llvm::map_to_vector(adaptor.getSizes(), [](ValueRange r) {
+          return llvm::getSingleElement(r);
+        });
     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
       // Update size.
-      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
-                           adaptor.getSizes(), i, indexType);
+      Value size = getSize(rewriter, loc, viewMemRefType.getShape(), sizes, i,
+                           indexType);
       targetMemRef.setSize(rewriter, loc, i, size);
       // Update stride.
       stride =
@@ -1505,7 +1541,7 @@
       nextSize = size;
     }
 
-    rewriter.replaceOp(viewOp, {targetMemRef});
+    rewriter.replaceOpWithMultiple(viewOp, {targetMemRef.getElements()});
     return success();
   }
 };
@@ -1551,7 +1587,7 @@
   using Base::Base;
 
   LogicalResult
-  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
+  matchAndRewrite(memref::AtomicRMWOp atomicOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto maybeKind = matchSimpleAtomicOp(atomicOp);
     if (!maybeKind)
@@ -1561,11 +1597,15 @@
     int64_t offset;
     if (failed(memRefType.getStridesAndOffset(strides, offset)))
       return failure();
-    auto dataPtr =
-        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
-                             adaptor.getIndices(), rewriter);
+    SmallVector<Value> indices =
+        llvm::map_to_vector(adaptor.getIndices(), [](ValueRange r) {
+          return llvm::getSingleElement(r);
+        });
+    auto dataPtr = getStridedElementPtr(atomicOp.getLoc(), memRefType,
+                                        adaptor.getMemref(), indices, rewriter);
     rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
-        atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
+        atomicOp, *maybeKind, dataPtr,
+        llvm::getSingleElement(adaptor.getValue()),
         LLVM::AtomicOrdering::acq_rel);
     return success();
   }
@@ -1580,7 +1620,7 @@
 
   LogicalResult
   matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
-                  OpAdaptor adaptor,
+                  OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     BaseMemRefType sourceTy = extractOp.getSource().getType();
 
@@ -1616,12 +1656,8 @@
 
   LogicalResult
   matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
-                  OpAdaptor adaptor,
+                  OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
-      return failure();
-
     // Create the descriptor.
     MemRefDescriptor sourceMemRef(adaptor.getSource());
     Location loc = extractStridedMetadataOp.getLoc();
@@ -1629,7 +1665,7 @@
 
     auto sourceMemRefType = cast<MemRefType>(source.getType());
     int64_t rank = sourceMemRefType.getRank();
-    SmallVector<Value> results;
+    SmallVector<ValueRange> results;
     results.reserve(2 + rank * 2);
 
     // Base buffer.
@@ -1639,19 +1675,20 @@
         rewriter, loc, *getTypeConverter(),
         cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
         baseBuffer, alignedBuffer);
-    results.push_back((Value)dstMemRef);
-
-    // Offset.
-    results.push_back(sourceMemRef.offset(rewriter, loc));
-
-    // Sizes.
-    for (unsigned i = 0; i < rank; ++i)
-      results.push_back(sourceMemRef.size(rewriter, loc, i));
-    // Strides.
-    for (unsigned i = 0; i < rank; ++i)
-      results.push_back(sourceMemRef.stride(rewriter, loc, i));
-
-    rewriter.replaceOp(extractStridedMetadataOp, results);
+    results.push_back(dstMemRef.getElements());
+    // Offset, sizes and strides of the source memref, queried through the
+    // descriptor accessors instead of hard-coded element positions. Each
+    // result gets a single-element range into `metadata`, which is fully
+    // populated first so the ranges cannot dangle.
+    SmallVector<Value> metadata;
+    metadata.push_back(sourceMemRef.offset(rewriter, loc));
+    for (int64_t i = 0; i < rank; ++i)
+      metadata.push_back(sourceMemRef.size(rewriter, loc, i));
+    for (int64_t i = 0; i < rank; ++i)
+      metadata.push_back(sourceMemRef.stride(rewriter, loc, i));
+    for (const Value &value : metadata)
+      results.push_back(ValueRange(value));
+    rewriter.replaceOpWithMultiple(extractStridedMetadataOp, results);
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 51507c6..4613b90 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1179,8 +1179,10 @@
 
     Value tensorElementType =
         elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
-    auto promotedOperands = getTypeConverter()->promoteOperands(
-        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
+    // TODO: port operand promotion to the 1:N memref lowering.
+    SmallVector<Value> promotedOperands;
+    return rewriter.notifyMatchFailure(
+        op, "operand promotion is not supported with 1:N memref lowering");
 
     Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
                                                  makeI64Const(b, 5));
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 213f737..23525cc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -95,7 +95,7 @@
 // Add an index vector component to a base pointer.
 static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
                             const LLVMTypeConverter &typeConverter,
-                            MemRefType memRefType, Value llvmMemref, Value base,
+                            MemRefType memRefType, ValueRange llvmMemref, Value base,
                             Value index, VectorType vectorType) {
   assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
          "unsupported memref type");
@@ -185,8 +185,9 @@
 /// Overloaded utility that replaces a vector.load, vector.store,
 /// vector.maskedload and vector.maskedstore with their respective LLVM
 /// couterparts.
+template<typename Adaptor>
 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
-                                 vector::LoadOpAdaptor adaptor,
+                                 Adaptor adaptor,
                                  VectorType vectorTy, Value ptr, unsigned align,
                                  ConversionPatternRewriter &rewriter) {
   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
@@ -194,29 +195,32 @@
                                             loadOp.getNontemporal());
 }
 
+template<typename Adaptor>
 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
-                                 vector::MaskedLoadOpAdaptor adaptor,
+                                 Adaptor adaptor,
                                  VectorType vectorTy, Value ptr, unsigned align,
                                  ConversionPatternRewriter &rewriter) {
   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
-      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
+      loadOp, vectorTy, ptr, llvm::getSingleElement(adaptor.getMask()), llvm::getSingleElement(adaptor.getPassThru()), align);
 }
 
+template<typename Adaptor>
 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
-                                 vector::StoreOpAdaptor adaptor,
+                                 Adaptor adaptor,
                                  VectorType vectorTy, Value ptr, unsigned align,
                                  ConversionPatternRewriter &rewriter) {
-  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
+  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, llvm::getSingleElement(adaptor.getValueToStore()),
                                              ptr, align, /*volatile_=*/false,
                                              storeOp.getNontemporal());
 }
 
+template<typename Adaptor>
 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
-                                 vector::MaskedStoreOpAdaptor adaptor,
+                                 Adaptor adaptor,
                                  VectorType vectorTy, Value ptr, unsigned align,
                                  ConversionPatternRewriter &rewriter) {
   rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
-      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
+      storeOp, llvm::getSingleElement(adaptor.getValueToStore()), ptr, llvm::getSingleElement(adaptor.getMask()), align);
 }
 
 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
@@ -225,10 +229,11 @@
 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 public:
   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
+  using Adaptor = typename ConvertOpToLLVMPattern<LoadOrStoreOp>::OneToNOpAdaptor;
 
   LogicalResult
   matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
-                  typename LoadOrStoreOp::Adaptor adaptor,
+                  Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Only 1-D vectors can be lowered to LLVM.
     VectorType vectorTy = loadOrStoreOp.getVectorType();
@@ -244,10 +249,11 @@
       return failure();
 
     // Resolve address.
+    SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); });
     auto vtype = cast<VectorType>(
         this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
     Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
-                                               adaptor.getIndices(), rewriter);
+                                               indices, rewriter);
     replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
                          rewriter);
     return success();
@@ -261,7 +267,7 @@
   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
+  matchAndRewrite(vector::GatherOp gather, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = gather->getLoc();
     MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
@@ -284,17 +290,18 @@
     }
 
     // Resolve address.
+    SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); });
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
-    Value base = adaptor.getBase();
+                                     indices, rewriter);
+    ValueRange base = adaptor.getBase();
     Value ptrs =
         getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
-                       base, ptr, adaptor.getIndexVec(), vType);
+                       base, ptr, llvm::getSingleElement(adaptor.getIndexVec()), vType);
 
     // Replace with the gather intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
-        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
-        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+        gather, typeConverter->convertType(vType), ptrs, llvm::getSingleElement(adaptor.getMask()),
+        llvm::getSingleElement(adaptor.getPassThru()), rewriter.getI32IntegerAttr(align));
     return success();
   }
 };
@@ -306,7 +313,7 @@
   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
+  matchAndRewrite(vector::ScatterOp scatter, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = scatter->getLoc();
     MemRefType memRefType = scatter.getMemRefType();
@@ -328,15 +335,16 @@
     }
 
     // Resolve address.
+    SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); });
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+                                     indices, rewriter);
     Value ptrs =
         getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
-                       adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
+                       adaptor.getBase(), ptr, llvm::getSingleElement(adaptor.getIndexVec()), vType);
 
     // Replace with the scatter intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
-        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
+        scatter, llvm::getSingleElement(adaptor.getValueToStore()), ptrs, llvm::getSingleElement(adaptor.getMask()),
         rewriter.getI32IntegerAttr(align));
     return success();
   }
@@ -349,18 +357,19 @@
   using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
+  matchAndRewrite(vector::ExpandLoadOp expand, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = expand->getLoc();
     MemRefType memRefType = expand.getMemRefType();
 
     // Resolve address.
     auto vtype = typeConverter->convertType(expand.getVectorType());
+    SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); });
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+                                     indices, rewriter);
 
     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
-        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
+        expand, vtype, ptr, llvm::getSingleElement(adaptor.getMask()), llvm::getSingleElement(adaptor.getPassThru()));
     return success();
   }
 };
@@ -372,17 +381,18 @@
   using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
+  matchAndRewrite(vector::CompressStoreOp compress, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = compress->getLoc();
     MemRefType memRefType = compress.getMemRefType();
 
     // Resolve address.
+    SmallVector<Value> indices = llvm::map_to_vector(adaptor.getIndices(), [](ValueRange range) { return llvm::getSingleElement(range); });
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                     adaptor.getIndices(), rewriter);
+                                     indices, rewriter);
 
     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
-        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
+        compress, llvm::getSingleElement(adaptor.getValueToStore()), ptr, llvm::getSingleElement(adaptor.getMask()));
     return success();
   }
 };
@@ -1416,7 +1426,7 @@
   using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::TypeCastOp castOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = castOp->getLoc();
     MemRefType sourceMemRefType =
@@ -1428,15 +1438,10 @@
         !targetMemRefType.hasStaticShape())
       return failure();
 
-    auto llvmSourceDescriptorTy =
-        dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
-    if (!llvmSourceDescriptorTy)
-      return failure();
     MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
 
-    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
-        typeConverter->convertType(targetMemRefType));
-    if (!llvmTargetDescriptorTy)
+    SmallVector<Type> llvmTargetDescriptorTypes;
+    if (failed(typeConverter->convertType(targetMemRefType, llvmTargetDescriptorTypes)))
       return failure();
 
     // Only contiguous source buffers supported atm.
@@ -1453,7 +1458,7 @@
     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
 
     // Create descriptor.
-    auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
+    auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTypes);
     // Set allocated ptr.
     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
     desc.setAllocatedPtr(rewriter, loc, allocated);
@@ -1480,7 +1485,7 @@
       desc.setStride(rewriter, loc, index, stride);
     }
 
-    rewriter.replaceOp(castOp, {desc});
+    rewriter.replaceOpWithMultiple(castOp, {desc.getElements()});
     return success();
   }
 };