[WIP] 1:N conversion pattern

Do not build argument materializations anymore

Fix more tests

Fix decompose call graph test
diff --git a/mlir/artifacts/jq-linux64 b/mlir/artifacts/jq-linux64
new file mode 100755
index 0000000..f48b0ca
--- /dev/null
+++ b/mlir/artifacts/jq-linux64
Binary files differ
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f3bf5b6..6751c3e 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -143,6 +143,8 @@
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;
+  using OneToNOpAdaptor =
+      typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;
 
   explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
                                   PatternBenefit benefit = 1)
@@ -153,8 +155,13 @@
   /// Wrappers around the RewritePattern methods that pass the derived op type.
   void rewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
-            rewriter);
+    auto sourceOp = cast<SourceOp>(op);
+    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+  }
+  void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
   }
   LogicalResult match(Operation *op) const final {
     return match(cast<SourceOp>(op));
@@ -162,8 +169,15 @@
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    return matchAndRewrite(cast<SourceOp>(op),
-                           OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
+    auto sourceOp = cast<SourceOp>(op);
+    return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+  }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+                           rewriter);
   }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override rewrite or matchAndRewrite");
   }
+  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@
     rewrite(op, adaptor, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
 
 private:
   using ConvertToLLVMPattern::match;
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index de47765..4c555e1 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -537,6 +537,10 @@
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("unimplemented rewrite");
   }
+  virtual void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+                       ConversionPatternRewriter &rewriter) const {
+    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
 
   /// Hook for derived classes to implement combined matching and rewriting.
   virtual LogicalResult
@@ -547,6 +551,11 @@
     rewrite(op, operands, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+                  ConversionPatternRewriter &rewriter) const {
+    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
 
   /// Attempt to match and rewrite the IR root at the specified operation.
   LogicalResult matchAndRewrite(Operation *op,
@@ -574,6 +583,9 @@
       : RewritePattern(std::forward<Args>(args)...),
         typeConverter(&typeConverter) {}
 
+  SmallVector<Value>
+  getOneToOneAdaptorOperands(ArrayRef<ArrayRef<Value>> operands) const;
+
 protected:
   /// An optional type converter for use by this pattern.
   const TypeConverter *typeConverter = nullptr;
@@ -589,6 +601,8 @@
 class OpConversionPattern : public ConversionPattern {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;
+  using OneToNOpAdaptor =
+      typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;
 
   OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
       : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +621,24 @@
     auto sourceOp = cast<SourceOp>(op);
     rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
+  void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
+  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto sourceOp = cast<SourceOp>(op);
     return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+                           rewriter);
+  }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
   /// overridden by the derived pattern class.
@@ -623,6 +649,12 @@
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override matchAndRewrite or a rewrite method");
   }
+  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
@@ -631,6 +663,13 @@
     rewrite(op, adaptor, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
 
 private:
   using ConversionPattern::matchAndRewrite;
@@ -656,11 +695,20 @@
                ConversionPatternRewriter &rewriter) const final {
     rewrite(cast<SourceOp>(op), operands, rewriter);
   }
+  void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    rewrite(cast<SourceOp>(op), operands, rewriter);
+  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
   }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+  }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
   /// overridden by the derived pattern class.
@@ -668,6 +716,10 @@
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override matchAndRewrite or a rewrite method");
   }
+  virtual void rewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands,
+                       ConversionPatternRewriter &rewriter) const {
+    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
@@ -676,6 +728,11 @@
     rewrite(op, operands, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands,
+                  ConversionPatternRewriter &rewriter) const {
+    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
 
 private:
   using ConversionPattern::matchAndRewrite;
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ce91424..20a2a10 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,6 +153,7 @@
                                        type.isVarArg());
   });
 
+/*
   // Argument materializations convert from the new block argument types
   // (multiple SSA values that make up a memref descriptor) back to the
   // original block argument type. The dialect conversion framework will then
@@ -198,16 +199,62 @@
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
         .getResult(0);
   });
+
+*/
   // Add generic source and target materializations to handle cases where
   // non-LLVM types persist after an LLVM conversion.
   addSourceMaterialization([&](OpBuilder &builder, Type resultType,
                                ValueRange inputs, Location loc) {
-    if (inputs.size() != 1)
-      return Value();
+    // 1:N inputs are accepted here: the former `inputs.size() != 1` bail-out
+    // was dropped and all inputs are forwarded to a single cast below.
 
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   });
+  // Ranked memrefs: materialize from a bare pointer or descriptor elements.
+  addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
+                               ValueRange inputs, Location loc) {
+    // A single struct-typed input is not handled by this materialization.
+    if (inputs.size() == 1 &&
+        isa<LLVM::LLVMStructType>(inputs.front().getType()))
+      return Value();
+    Value desc;
+    if (inputs.size() == 1 &&
+        isa<LLVM::LLVMPointerType>(inputs.front().getType())) {
+      // Bare pointers are allowed only for function entry block arguments.
+      BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
+      if (!barePtr)
+        return Value();
+      Block *block = barePtr.getOwner();
+      if (!block->isEntryBlock() ||
+          !isa<FunctionOpInterface>(block->getParentOp()))
+        return Value();
+      desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
+                                               inputs[0]);
+    } else {
+      desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    }
+    // A source materialization must return a value of type `resultType`,
+    // so insert a cast from the memref descriptor type (!llvm.struct) to the
+    // original memref type.
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+        .getResult(0);
+  });
+  addSourceMaterialization([&](OpBuilder &builder,
+                               UnrankedMemRefType resultType,
+                               ValueRange inputs, Location loc) {
+    // Bare pointers are not supported for unranked memrefs because a memref
+    // descriptor cannot be built just from a bare pointer.
+    if (inputs.size() == 1)
+      return Value();
+    Value desc =
+        UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    // A source materialization must return a value of type
+    // `resultType`, so insert a cast from the memref descriptor type
+    // (!llvm.struct) to the original memref type.
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+        .getResult(0);
+  });
   addTargetMaterialization([&](OpBuilder &builder, Type resultType,
                                ValueRange inputs, Location loc) {
     if (inputs.size() != 1)
@@ -216,6 +263,51 @@
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   });
+  // Target materialization that builds a memref descriptor. The original
+  // (pre-conversion) type selects which kind of descriptor is constructed.
+  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+                               ValueRange inputs, Location loc,
+                               Type originalType) -> Value {
+    // Without an original type there is nothing to dispatch on.
+    if (!originalType)
+      return Value();
+    if (auto memrefType = dyn_cast<MemRefType>(originalType)) {
+      assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
+      if (inputs.size() == 1) {
+        Value input = inputs.front();
+        // Look through a cast whose only input is an LLVM pointer.
+        if (auto castOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
+          if (castOp.getInputs().size() == 1 &&
+              isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType()))
+            input = castOp.getInputs()[0];
+        }
+        if (!isa<LLVM::LLVMPointerType>(input.getType()))
+          return Value();
+        // Bare pointers are allowed only for function entry block arguments.
+        BlockArgument barePtr = dyn_cast<BlockArgument>(input);
+        if (!barePtr)
+          return Value();
+        Block *block = barePtr.getOwner();
+        if (!block->isEntryBlock() ||
+            !isa<FunctionOpInterface>(block->getParentOp()))
+          return Value();
+        return MemRefDescriptor::fromStaticShape(builder, loc, *this,
+                                                 memrefType, input);
+      }
+      return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs);
+    }
+    if (auto memrefType = dyn_cast<UnrankedMemRefType>(originalType)) {
+      assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
+      // Bare pointers are not supported for unranked memrefs because a memref
+      // descriptor cannot be built just from a bare pointer.
+      if (inputs.size() == 1)
+        return Value();
+      return UnrankedMemRefDescriptor::pack(builder, loc, *this, memrefType,
+                                            inputs);
+    }
+
+    return Value();
+  });
 
   // Integer memory spaces map to themselves.
   addTypeAttributeConversion(
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index a087643..03be003 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -14,40 +14,6 @@
 using namespace mlir::func;
 
 //===----------------------------------------------------------------------===//
-// Helper functions
-//===----------------------------------------------------------------------===//
-
-/// If the given value can be decomposed with the type converter, decompose it.
-/// Otherwise, return the given value.
-// TODO: Value decomposition should happen automatically through a 1:N adaptor.
-// This function will disappear when the 1:1 and 1:N drivers are merged.
-static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
-                                         Value value,
-                                         const TypeConverter *converter) {
-  // Try to convert the given value's type. If that fails, just return the
-  // given value.
-  SmallVector<Type> convertedTypes;
-  if (failed(converter->convertType(value.getType(), convertedTypes)))
-    return {value};
-  if (convertedTypes.empty())
-    return {};
-
-  // If the given value's type is already legal, just return the given value.
-  TypeRange convertedTypeRange(convertedTypes);
-  if (convertedTypeRange == TypeRange(value.getType()))
-    return {value};
-
-  // Try to materialize a target conversion. If the materialization did not
-  // produce values of the requested type, the materialization failed. Just
-  // return the given value in that case.
-  SmallVector<Value> result = converter->materializeTargetConversion(
-      builder, loc, convertedTypeRange, value);
-  if (result.empty())
-    return {value};
-  return result;
-}
-
-//===----------------------------------------------------------------------===//
 // DecomposeCallGraphTypesForFuncArgs
 //===----------------------------------------------------------------------===//
 
@@ -102,16 +68,11 @@
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     SmallVector<Value, 2> newOperands;
-    for (Value operand : adaptor.getOperands()) {
-      // TODO: We can directly take the values from the adaptor once this is a
-      // 1:N conversion pattern.
-      llvm::append_range(newOperands,
-                         decomposeValue(rewriter, operand.getLoc(), operand,
-                                        getTypeConverter()));
-    }
+    for (ValueRange operand : adaptor.getOperands())
+      llvm::append_range(newOperands, operand);
     rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
     return success();
   }
@@ -128,18 +89,13 @@
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CallOp op, OpAdaptor adaptor,
+  matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
 
     // Create the operands list of the new `CallOp`.
     SmallVector<Value, 2> newOperands;
-    for (Value operand : adaptor.getOperands()) {
-      // TODO: We can directly take the values from the adaptor once this is a
-      // 1:N conversion pattern.
-      llvm::append_range(newOperands,
-                         decomposeValue(rewriter, operand.getLoc(), operand,
-                                        getTypeConverter()));
-    }
+    for (ValueRange operand : adaptor.getOperands())
+      llvm::append_range(newOperands, operand);
 
     // Create the new result types for the new `CallOp` and track the number of
     // replacement types for each original op result.
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 93a7805..4d154b0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -16,20 +16,16 @@
 
 namespace {
 
-// Unpacks the single unrealized_conversion_cast using the list of inputs
-// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
-static void unpackUnrealizedConversionCast(Value v,
-                                           SmallVectorImpl<Value> &unpacked) {
-  if (auto cast =
-          dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
-    if (cast.getInputs().size() != 1) {
-      // 1 : N type conversion.
-      unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
-      return;
-    }
-  }
-  // 1 : 1 type conversion.
-  unpacked.push_back(v);
+static SmallVector<Value> flattenValues(ArrayRef<ArrayRef<Value>> values) {
+  SmallVector<Value> result;
+  for (ArrayRef<Value> v : values)
+    llvm::append_range(result, v);
+  return result;
+}
+
+static Value getSingleValue(ArrayRef<Value> values) {
+  assert(values.size() == 1 && "expected single value");
+  return values.front();
 }
 
 // CRTP
@@ -40,19 +36,21 @@
 public:
   using OpConversionPattern<SourceOp>::typeConverter;
   using OpConversionPattern<SourceOp>::OpConversionPattern;
-  using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
+  using OneToNOpAdaptor =
+      typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;
 
   //
   // Derived classes should provide the following method which performs the
   // actual conversion. It should return std::nullopt upon conversion failure
   // and return the converted operation upon success.
   //
-  // std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
-  //                                    ConversionPatternRewriter &rewriter,
-  //                                    TypeRange dstTypes) const;
+  // std::optional<SourceOp> convertSourceOp(
+  //     SourceOp op, OneToNOpAdaptor adaptor,
+  //     ConversionPatternRewriter &rewriter,
+  //     TypeRange dstTypes) const;
 
   LogicalResult
-  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Type> dstTypes;
     SmallVector<unsigned> offsets;
@@ -73,28 +71,15 @@
       return rewriter.notifyMatchFailure(op, "could not convert operation");
 
     // Packs the return value.
-    SmallVector<Value> packedRets;
+    SmallVector<ValueRange> packedRets;
     for (unsigned i = 1, e = offsets.size(); i < e; i++) {
       unsigned start = offsets[i - 1], end = offsets[i];
       unsigned len = end - start;
       ValueRange mappedValue = newOp->getResults().slice(start, len);
-      if (len != 1) {
-        // 1 : N type conversion.
-        Type origType = op.getResultTypes()[i - 1];
-        Value mat = typeConverter->materializeSourceConversion(
-            rewriter, op.getLoc(), origType, mappedValue);
-        if (!mat) {
-          return rewriter.notifyMatchFailure(
-              op, "Failed to materialize 1:N type conversion");
-        }
-        packedRets.push_back(mat);
-      } else {
-        // 1 : 1 type conversion.
-        packedRets.push_back(mappedValue.front());
-      }
+      packedRets.push_back(mappedValue);
     }
 
-    rewriter.replaceOp(op, packedRets);
+    rewriter.replaceOpWithMultiple(op, packedRets);
     return success();
   }
 };
@@ -105,7 +90,7 @@
   using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
 
   // The callback required by CRTP.
-  std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
+  std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        TypeRange dstTypes) const {
     // Create a empty new op and inline the regions from the old op.
@@ -129,16 +114,13 @@
     if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
       return std::nullopt;
 
-    // Unpacked the iteration arguments.
-    SmallVector<Value> flatArgs;
-    for (Value arg : adaptor.getInitArgs())
-      unpackUnrealizedConversionCast(arg, flatArgs);
-
     // We can not do clone as the number of result types after conversion
     // might be different.
-    ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
-                                         adaptor.getUpperBound(),
-                                         adaptor.getStep(), flatArgs);
+    ForOp newOp = rewriter.create<ForOp>(
+        op.getLoc(), getSingleValue(adaptor.getLowerBound()),
+        getSingleValue(adaptor.getUpperBound()),
+        getSingleValue(adaptor.getStep()),
+        flattenValues(adaptor.getInitArgs()));
 
     // Reserve whatever attributes in the original op.
     newOp->setAttrs(op->getAttrs());
@@ -160,12 +142,12 @@
 public:
   using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
 
-  std::optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
+  std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
                                       ConversionPatternRewriter &rewriter,
                                       TypeRange dstTypes) const {
 
-    IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
-                                       adaptor.getCondition(), true);
+    IfOp newOp = rewriter.create<IfOp>(
+        op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
     newOp->setAttrs(op->getAttrs());
 
     // We do not need the empty blocks created by rewriter.
@@ -189,15 +171,11 @@
 public:
   using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
 
-  std::optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
+  std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
                                          ConversionPatternRewriter &rewriter,
                                          TypeRange dstTypes) const {
-    // Unpacked the iteration arguments.
-    SmallVector<Value> flatArgs;
-    for (Value arg : adaptor.getOperands())
-      unpackUnrealizedConversionCast(arg, flatArgs);
-
-    auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
+    auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes,
+                                          flattenValues(adaptor.getOperands()));
 
     for (auto i : {0u, 1u}) {
       if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
@@ -218,13 +196,10 @@
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
+  matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Value> unpackedYield;
-    for (Value operand : adaptor.getOperands())
-      unpackUnrealizedConversionCast(operand, unpackedYield);
-
-    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(
+        op, flattenValues(adaptor.getOperands()));
     return success();
   }
 };
@@ -235,13 +210,10 @@
 public:
   using OpConversionPattern<ConditionOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
+  matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Value> unpackedYield;
-    for (Value operand : adaptor.getOperands())
-      unpackUnrealizedConversionCast(operand, unpackedYield);
-
-    rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); });
+    rewriter.modifyOpInPlace(
+        op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 9abb1d3..0fa9f26 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -39,25 +39,16 @@
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Flattens a list of operands that may contain sparse tensors.
-static void flattenOperands(ValueRange operands,
-                            SmallVectorImpl<Value> &flattened) {
-  // In case of
-  // sparse_tensor, c, sparse_tensor
-  // ==>
-  // memref ..., c, memref ...
-  for (auto operand : operands) {
-    if (getSparseTensorEncoding(operand.getType())) {
-      auto tuple = getTuple(operand);
-      // An unrealized_conversion_cast will be inserted by type converter to
-      // inter-mix the gap between 1:N conversion between sparse tensors and
-      // fields. In this case, take the operands in the cast and replace the
-      // sparse tensor output with the flattened type array.
-      flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
-    } else {
-      flattened.push_back(operand);
-    }
-  }
+static SmallVector<Value> flattenValues(ArrayRef<ArrayRef<Value>> values) {
+  SmallVector<Value> result;
+  for (ArrayRef<Value> v : values)
+    llvm::append_range(result, v);
+  return result;
+}
+
+static Value getSingleValue(ArrayRef<Value> values) {
+  assert(values.size() == 1 && "expected single value");
+  return values.front();
 }
 
 /// Generates a load with proper `index` typing.
@@ -567,12 +558,11 @@
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Value> flattened;
-    flattenOperands(adaptor.getOperands(), flattened);
     // Create a return with the flattened value extracted from sparse tensors.
-    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
+    rewriter.replaceOpWithNewOp<func::ReturnOp>(
+        op, flattenValues(adaptor.getOperands()));
     return success();
   }
 };
@@ -583,7 +573,7 @@
   // The default CallOp converter can not handle 1:N type conversion.
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+  matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     // In case of:
@@ -596,10 +586,8 @@
       return failure();
 
     // (1) Generates new call with flattened return value.
-    SmallVector<Value> flattened;
-    flattenOperands(adaptor.getOperands(), flattened);
-    auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
-                                                 finalRetTy, flattened);
+    auto newCall = rewriter.create<func::CallOp>(
+        loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
     // (2) Gather sparse tensor returns.
     SmallVector<SmallVector<Value>> packedResultVals;
     // Tracks the offset of current return value (of the original call)
@@ -643,13 +631,15 @@
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
+  matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     std::optional<int64_t> lvl = op.getConstantLvlIndex();
-    if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
+    if (!lvl || !getSparseTensorEncoding(op.getSource().getType()))
       return failure();
 
-    auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+    SparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getSource().getType())),
+        adaptor.getSource());
     auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
 
     rewriter.replaceOp(op, sz);
@@ -661,7 +651,7 @@
 struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
+  matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
@@ -675,8 +665,10 @@
     assert(dstStt.hasSameDimToLvl(srcStt));
 
     // We don't need a mutable descriptor here as we perform sorting in-place.
-    auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
-    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
+    SparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getInputCoo().getType())),
+        adaptor.getInputCoo());
+    auto nnz = desc.getValMemSize(rewriter, op.getLoc());
     auto crd = desc.getAOSMemRef();
     auto val = desc.getValMemRef();
 
@@ -691,7 +683,7 @@
 
     // Since we do in-place sorting, the destinate tensor will have the same set
     // of memrefs as the source tensor.
-    rewriter.replaceOp(op, adaptor.getInputCoo());
+    rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
     return success();
   }
 };
@@ -701,10 +693,13 @@
 public:
   using OpConversionPattern<Op>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+  matchAndRewrite(Op op,
+                  typename OpConversionPattern<Op>::OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Simply lowers to specifer.get <field> operation.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
+    SparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getSlice().getType())),
+        adaptor.getSlice());
     auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
                                     op.getDim().getZExtValue());
 
@@ -718,14 +713,14 @@
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
+  matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Only rewrite identically annotated source/dest.
     auto encDst = getSparseTensorEncoding(op.getType());
     auto encSrc = getSparseTensorEncoding(op.getSource().getType());
     if (!encDst || encDst != encSrc)
       return failure();
-    rewriter.replaceOp(op, adaptor.getOperands());
+    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
     return success();
   }
 };
@@ -734,10 +729,10 @@
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+  matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Simply fold the operation.
-    rewriter.replaceOp(op, adaptor.getSource());
+    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
     return success();
   }
 };
@@ -753,7 +748,7 @@
         enableBufferInitialization(enableInit) {}
 
   LogicalResult
-  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
+  matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     const auto resType = getSparseTensorType(op);
     if (!resType.hasEncoding())
@@ -762,7 +757,9 @@
     Location loc = op.getLoc();
     // Deal with copy.
     if (op.getCopy()) {
-      auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
+      SparseTensorDescriptor desc(
+          SparseTensorType(cast<RankedTensorType>(op.getCopy().getType())),
+          adaptor.getCopy());
       SmallVector<Value> fields;
       fields.reserve(desc.getNumFields());
       // Memcpy on memref fields.
@@ -787,7 +784,8 @@
     }
     // Level size equals to dimension size since lvl2dim map is an identity map.
     SmallVector<Value> lvlSizesValues;
-    createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
+    createDimSizes(rewriter, loc, resType,
+                   flattenValues(adaptor.getDynamicSizes()),
                    /*dimSizesValues=*/lvlSizesValues);
 
     // Construct allocation for each field.
@@ -857,7 +855,7 @@
         createDeallocs(createDeallocs) {}
 
   LogicalResult
-  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
+  matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto enc = getSparseTensorEncoding(op.getTensor().getType());
     if (!enc)
@@ -868,7 +866,9 @@
     if (createDeallocs) {
       // Replace the sparse tensor deallocation with field deallocations.
       Location loc = op.getLoc();
-      auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+      SparseTensorDescriptor desc(
+          SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())),
+          adaptor.getTensor());
       for (auto input : desc.getMemRefFields())
         // Deallocate every buffer used to store the sparse tensor handler.
         rewriter.create<memref::DeallocOp>(loc, input);
@@ -886,10 +886,12 @@
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
+  matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Prepare descriptor.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    SparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())),
+        adaptor.getTensor());
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
@@ -904,12 +906,14 @@
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
+  matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (!getSparseTensorEncoding(op.getTensor().getType()))
       return failure();
     Location loc = op->getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    SparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())),
+        adaptor.getTensor());
     const auto srcType = getSparseTensorType(op.getTensor());
     Type eltType = srcType.getElementType();
     Type boolType = rewriter.getIntegerType(1);
@@ -955,15 +959,18 @@
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
+  matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
-    Value values = adaptor.getValues();
-    Value filled = adaptor.getFilled();
-    Value added = adaptor.getAdded();
-    Value count = adaptor.getCount();
+    llvm::append_range(fields, adaptor.getTensor());
+    MutSparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())),
+        fields);
+    Value values = getSingleValue(adaptor.getValues());
+    Value filled = getSingleValue(adaptor.getFilled());
+    Value added = getSingleValue(adaptor.getAdded());
+    Value count = getSingleValue(adaptor.getCount());
     const SparseTensorType dstType(desc.getRankedTensorType());
     Type eltType = dstType.getElementType();
 
@@ -996,7 +1003,8 @@
     SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
     SmallVector<Type> flatSpTensorTps = llvm::to_vector(
         llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
-    params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
+    SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
+    params.append(flatLvlCoords.begin(), flatLvlCoords.end());
     params.push_back(crd);
     params.push_back(value);
     SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
@@ -1024,19 +1032,22 @@
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
+  matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto stt = getSparseTensorType(adaptor.getDest());
+    auto stt = getSparseTensorType(op.getDest());
     if (!stt.hasEncoding())
       return failure();
     assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
 
     Location loc = op.getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
+    SparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getDest().getType())),
+        adaptor.getDest());
     TypeRange flatSpTensorTps = desc.getFields().getTypes();
     SmallVector<Value> params = llvm::to_vector(desc.getFields());
-    params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
-    params.push_back(adaptor.getScalar());
+    SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
+    params.append(flatIndices.begin(), flatIndices.end());
+    params.push_back(getSingleValue(adaptor.getScalar()));
     SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                     params, /*genCall=*/true);
     SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
@@ -1052,14 +1063,16 @@
   using OpAdaptor = typename ToPositionsOp::Adaptor;
   using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
+  matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Replace the requested position access with corresponding field.
     // The view is restricted to the actual size to ensure clients
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = op.getLevel();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    SparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())),
+        adaptor.getTensor());
     auto mem = desc.getPosMemRef(lvl);
     auto size = desc.getPosMemSize(rewriter, loc, lvl);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1074,14 +1087,16 @@
   using OpAdaptor = typename ToCoordinatesOp::Adaptor;
   using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
+  matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Replace the requested coordinates access with corresponding field.
     // The view is restricted to the actual size to ensure clients
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = op.getLevel();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    SparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())),
+        adaptor.getTensor());
     auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
     if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
       auto size = desc.getCrdMemSize(rewriter, loc, lvl);
@@ -1099,14 +1114,16 @@
   using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
   using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
+  matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Replace the requested coordinates access with corresponding field.
     // The view is restricted to the actual size to ensure clients
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
     Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    SparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())),
+        adaptor.getTensor());
     auto mem = desc.getAOSMemRef();
     auto size = desc.getCrdMemSize(rewriter, loc, lvl);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1120,13 +1137,15 @@
   using OpAdaptor = typename ToValuesOp::Adaptor;
   using OpConversionPattern<ToValuesOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
+  matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Replace the requested values access with corresponding field.
     // The view is restricted to the actual size to ensure clients
     // of this operation truly observe size, not capacity!
     Location loc = op.getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    SparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())),
+        adaptor.getTensor());
     auto mem = desc.getValMemRef();
     auto size = desc.getValMemSize(rewriter, loc);
     rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
@@ -1139,7 +1158,7 @@
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
+  matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
     SparseTensorEncodingAttr encSrc =
@@ -1159,7 +1178,7 @@
     Type srcElemTp = op.getSource().getType().getElementType();
     // Fold the trivial cases.
     if (retElemTp == srcElemTp && encDst == encSrc) {
-      rewriter.replaceOp(op, adaptor.getSource());
+      rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
       return success();
     }
     //
@@ -1172,7 +1191,9 @@
     //   else:
     //     dst = memref.copy(src)
     Location loc = op.getLoc();
-    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
+    SparseTensorDescriptor srcDesc(
+        SparseTensorType(cast<RankedTensorType>(op.getSource().getType())),
+        adaptor.getSource());
     SmallVector<Value> fields;
     foreachFieldAndTypeInSparseTensor(
         SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
@@ -1224,7 +1245,7 @@
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
+  matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
@@ -1236,7 +1257,10 @@
     assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
 
     SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
+    llvm::append_range(fields, adaptor.getSource());
+    MutSparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getSource().getType())),
+        fields);
 
     auto newSpec = rewriter.create<StorageSpecifierInitOp>(
         loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
@@ -1280,13 +1304,15 @@
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+  matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Query memSizes for the actually stored values.
     // FIXME: the nse value computed in this way might be wrong when there is
     // any "loose_compressed" level.
-    rewriter.replaceOp(
-        op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
+    SparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())),
+        adaptor.getTensor());
+    rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
     return success();
   }
 };
@@ -1413,9 +1439,11 @@
       : OpConversionPattern(typeConverter, context) {}
 
   LogicalResult
-  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
+  matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    SparseTensorDescriptor desc(
+        SparseTensorType(cast<RankedTensorType>(op.getTensor().getType())),
+        adaptor.getTensor());
     Location loc = op.getLoc();
     SmallVector<Value> retMem;
     SmallVector<Value> retLen;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 33fa9e4..9b96821 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -54,8 +54,6 @@
   });
 }
 
-/// Helper function that computes an insertion point where the given value is
-/// defined and can be used without a dominance violation.
 static OpBuilder::InsertPoint computeInsertPoint(Value value) {
   Block *insertBlock = value.getParentBlock();
   Block::iterator insertPt = insertBlock->begin();
@@ -64,6 +62,27 @@
   return OpBuilder::InsertPoint(insertBlock, insertPt);
 }
 
+/// Helper function that computes an insertion point where all of the given
+/// values (which must live in the same block) are defined and can be used
+/// without a dominance violation, i.e., the latest per-value insertion point.
+static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
+  assert(!vals.empty() && "expected at least one value");
+  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
+  for (Value v : vals.drop_front()) {
+    OpBuilder::InsertPoint pt2 = computeInsertPoint(v);
+    assert(pt.getBlock() == pt2.getBlock() && "expected values in same block");
+    Block *block = pt.getBlock();
+    // Block start is the earliest point and block end the latest; neither
+    // may reach `isBeforeInBlock` below (end() cannot be dereferenced).
+    if (pt.getPoint() == block->end() || pt2.getPoint() == block->begin())
+      continue;
+    if (pt.getPoint() == block->begin() || pt2.getPoint() == block->end() ||
+        pt.getPoint()->isBeforeInBlock(&*pt2.getPoint()))
+      pt = pt2;
+  }
+  return pt;
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
@@ -73,89 +92,220 @@
 using ReplacementValues = SmallVector<Value, 1>;
 
 namespace {
+/// DenseMap key info for vectors of values. DenseMap requires the empty and
+/// tombstone keys to differ from each other and from every real key, so both
+/// are built from null values (one vs. two); real keys must be non-null.
+struct SmallVectorMapInfo {
+  using Key = SmallVector<Value, 1>;
+  static Key getEmptyKey() { return Key{Value()}; }
+  static Key getTombstoneKey() { return Key{Value(), Value()}; }
+  static ::llvm::hash_code getHashValue(const Key &val) {
+    return ::llvm::hash_combine_range(val.begin(), val.end());
+  }
+  static bool isEqual(const Key &lhs, const Key &rhs) { return lhs == rhs; }
+};
+
 /// This class wraps a IRMapping to provide recursive lookup
 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
 struct ConversionValueMapping {
-  /// Lookup the most recently mapped value with the desired type in the
-  /// mapping.
-  ///
-  /// Special cases:
-  /// - If the desired type is "null", simply return the most recently mapped
-  ///   value.
-  /// - If there is no mapping to the desired type, also return the most
-  ///   recently mapped value.
-  /// - If there is no mapping for the given value at all, return the given
-  ///   value.
-  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
+  /// Find the most recently mapped values for the given value. If the value is
+  /// not mapped at all, return the given value.
+  SmallVector<Value, 1> lookupOrDefault(Value from) const;
 
-  /// Lookup a mapped value within the map, or return null if a mapping does not
-  /// exist. If a mapping exists, this follows the same behavior of
-  /// `lookupOrDefault`.
-  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
+  /// TODO: Find most recently mapped or materialization with matching type. May
+  /// return the given value if the type matches.
+  SmallVector<Value, 1>
+  lookupOrDefault(Value from, SmallVector<Type, 1> desiredTypes) const;
 
-  /// Map a value to the one provided.
-  void map(Value oldVal, Value newVal) {
-    LLVM_DEBUG({
-      for (Value it = newVal; it; it = mapping.lookupOrNull(it))
-        assert(it != oldVal && "inserting cyclic mapping");
-    });
-    mapping.map(oldVal, newVal);
+  /// Return the direct (non-recursive) replacement of the given value if it
+  /// was mapped to exactly one value. Return a null value if the given value
+  /// is not mapped or was mapped to multiple values. In contrast to
+  /// `lookupOrNull`, mappings of the replacement itself are not followed.
+  Value lookupDirectSingleReplacement(Value from) const {
+    auto it = mapping.find(from);
+    // The value is not mapped at all.
+    if (it == mapping.end())
+      return Value();
+    const SmallVector<Value, 1> &repl = it->second;
+    // A 1:N replacement cannot be expressed as a single value.
+    if (repl.size() != 1)
+      return Value();
+    return repl.front();
+  }
 
-  /// Try to map a value to the one provided. Returns false if a transitive
-  /// mapping from the new value to the old value already exists, true if the
-  /// map was updated.
-  bool tryMap(Value oldVal, Value newVal);
+  /// Find the most recently mapped values for the given value. If the value is
+  /// not mapped at all, return an empty vector.
+  SmallVector<Value, 1> lookupOrNull(Value from) const;
 
-  /// Drop the last mapping for the given value.
-  void erase(Value value) { mapping.erase(value); }
+  /// Find the most recently mapped values for the given value. If those values
+  /// have the desired types, return them. Otherwise, try to find a
+  /// materialization to the desired types.
+  ///
+  /// If the given value is not mapped at all or if there are no mapped values/
+  /// materialization results with the desired types, return an empty vector.
+  SmallVector<Value, 1> lookupOrNull(Value from,
+                                     SmallVector<Type, 1> desiredTypes) const;
+
+  /// Single-type convenience overload of `lookupOrNull`: return the mapped
+  /// value (or materialization result) of the desired type, or a null value.
+  Value lookupOrNull(Value from, Type desiredType) const {
+    SmallVector<Value, 1> vals =
+        lookupOrNull(from, SmallVector<Type, 1>{desiredType});
+    assert(vals.size() <= 1 && "expected single value");
+    return vals.empty() ? Value() : vals.front();
+  }
+
+  void erase(Value from) { mapping.erase(from); }
+
+  /// Map `from` to the given block arguments (convenience overload).
+  void map(Value from, ArrayRef<BlockArgument> to) {
+    // `BlockArgument` converts to `Value`, so the range can be copied as is.
+    SmallVector<Value> replacements(to.begin(), to.end());
+    map(from, replacements);
+  }
+
+  /// Map `from` to the replacement values `to`. `from` must not be mapped yet
+  /// and all values must be non-null. Cyclic mappings are not detected.
+  void map(Value from, ArrayRef<Value> to) {
+    assert(from && "expected non-null value");
+    assert(!to.empty() && "cannot map to zero values");
+    assert(llvm::all_of(to, [](Value v) { return static_cast<bool>(v); }) &&
+           "expected non-null value");
+    // TODO: Check for cyclic mapping (including `from` being mapped to
+    // itself).
+    assert(!mapping.contains(from) && "value is already mapped");
+    mapping[from].assign(to.begin(), to.end());
+  }
+
+  void mapMaterialization(SmallVector<Value, 1> from,
+                          SmallVector<Value, 1> to) {
+#ifndef NDEBUG
+    assert(!from.empty() && "from cannot be empty");
+    assert(!to.empty() && "to cannot be empty");
+    for (Value v : from) {
+      assert(v && "expected non-null value");
+      assert(!mapping.contains(v) &&
+             "cannot add materialization for mapped value");
+    }
+    for (Value v : to) {
+      assert(v && "expected non-null value");
+    }
+    assert(TypeRange(from) != TypeRange(to) &&
+           "cannot add materialization for identical type");
+    for (const SmallVector<Value, 1> &mat : materializations[from])
+      assert(TypeRange(mat) != TypeRange(to) &&
+             "cannot register duplicate materialization");
+#endif // NDEBUG
+    materializations[from].push_back(to);
+  }
+
+  /// Remove the materialization `from` -> `to`, if it was registered.
+  void eraseMaterialization(SmallVector<Value, 1> from,
+                            SmallVector<Value, 1> to) {
+    // find() instead of operator[]: do not insert entries for unknown keys.
+    if (auto it = materializations.find(from); it != materializations.end())
+      llvm::erase(it->second, to);
+  }
 
   /// Returns the inverse raw value mapping (without recursive query support).
   DenseMap<Value, SmallVector<Value>> getInverse() const {
     DenseMap<Value, SmallVector<Value>> inverse;
-    for (auto &it : mapping.getValueMap())
-      inverse[it.second].push_back(it.first);
+
+    for (auto &it : mapping)
+      for (Value v : it.second)
+        inverse[v].push_back(it.first);
+
+    for (auto &it : materializations)
+      for (const SmallVector<Value, 1> &mat : it.second)
+        for (Value v : mat)
+          for (Value v2 : it.first)
+            inverse[v].push_back(v2);
+
     return inverse;
   }
 
 private:
-  /// Current value mappings.
-  IRMapping mapping;
+  /// Replacement mapping: Value -> ValueRange
+  DenseMap<Value, SmallVector<Value, 1>> mapping;
+
+  /// Materializations: ValueRange -> ValueRange*
+  DenseMap<SmallVector<Value, 1>, SmallVector<SmallVector<Value, 1>>,
+           SmallVectorMapInfo>
+      materializations;
 };
 } // namespace
 
-Value ConversionValueMapping::lookupOrDefault(Value from,
-                                              Type desiredType) const {
-  // Try to find the deepest value that has the desired type. If there is no
-  // such value, simply return the deepest value.
-  Value desiredValue;
-  do {
-    if (!desiredType || from.getType() == desiredType)
-      desiredValue = from;
-
-    Value mappedValue = mapping.lookupOrNull(from);
-    if (!mappedValue)
-      break;
-    from = mappedValue;
-  } while (true);
-
-  // If the desired value was found use it, otherwise default to the leaf value.
-  return desiredValue ? desiredValue : from;
+SmallVector<Value, 1>
+ConversionValueMapping::lookupOrDefault(Value from) const {
+  SmallVector<Value, 1> to = lookupOrNull(from);
+  return to.empty() ? SmallVector<Value, 1>{from} : to;
 }
 
-Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
-  Value result = lookupOrDefault(from, desiredType);
-  if (result == from || (desiredType && result.getType() != desiredType))
-    return nullptr;
+SmallVector<Value, 1> ConversionValueMapping::lookupOrDefault(
+    Value from, SmallVector<Type, 1> desiredTypes) const {
+#ifndef NDEBUG
+  assert(desiredTypes.size() > 0 && "expected non-empty types");
+  for (Type t : desiredTypes)
+    assert(t && "expected non-null type");
+#endif // NDEBUG
+
+  // Start from the most recently mapped values; this is `from` itself if the
+  // value is not mapped at all. Resolving the mapping once here avoids a
+  // second recursive traversal of the mapping.
+  SmallVector<Value, 1> vals = lookupOrDefault(from);
+  // The (mapped) values already have the desired types.
+  if (TypeRange(vals) == desiredTypes)
+    return vals;
+
+  // Otherwise, try to find a materialization to the desired types.
+  auto it = materializations.find(vals);
+  if (it != materializations.end())
+    for (const SmallVector<Value, 1> &mat : it->second)
+      if (TypeRange(mat) == desiredTypes)
+        return mat;
+  // No values with the desired types. TODO: This returns an empty vector.
+  return {};
+}
+
+SmallVector<Value, 1> ConversionValueMapping::lookupOrNull(Value from) const {
+  auto it = mapping.find(from);
+  if (it == mapping.end())
+    return {};
+  SmallVector<Value, 1> result;
+  for (Value v : it->second) {
+    llvm::append_range(result, lookupOrDefault(v));
+  }
   return result;
 }
 
-bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) {
-  for (Value it = newVal; it; it = mapping.lookupOrNull(it))
-    if (it == oldVal)
-      return false;
-  map(oldVal, newVal);
-  return true;
+SmallVector<Value, 1>
+ConversionValueMapping::lookupOrNull(Value from,
+                                     SmallVector<Type, 1> desiredTypes) const {
+#ifndef NDEBUG
+  assert(desiredTypes.size() > 0 && "expected non-empty types");
+  for (Type t : desiredTypes)
+    assert(t && "expected non-null type");
+#endif // NDEBUG
+
+  SmallVector<Value, 1> vals = lookupOrNull(from);
+  if (vals.empty())
+    return {};
+
+  // There is a mapping and the types match.
+  if (TypeRange(vals) == desiredTypes)
+    return vals;
+
+  // There is a mapping, but the types do not match. Try to find a matching
+  // materialization.
+  auto it = materializations.find(vals);
+  if (it == materializations.end())
+    return {};
+  for (const SmallVector<Value, 1> &mat : it->second)
+    if (TypeRange(mat) == desiredTypes)
+      return mat;
+
+  // No materialization found. Return an empty vector.
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
@@ -781,7 +931,7 @@
   LogicalResult remapValues(StringRef valueDiagTag,
                             std::optional<Location> inputLoc,
                             PatternRewriter &rewriter, ValueRange values,
-                            SmallVectorImpl<Value> &remapped);
+                            SmallVector<SmallVector<Value, 1>> &remapped);
 
   /// Return "true" if the given operation is ignored, and does not need to be
   /// converted.
@@ -817,27 +967,12 @@
 
   /// Build an unresolved materialization operation given an output type and set
   /// of input operands.
-  Value buildUnresolvedMaterialization(MaterializationKind kind,
-                                       OpBuilder::InsertPoint ip, Location loc,
-                                       ValueRange inputs, Type outputType,
-                                       Type originalType,
-                                       const TypeConverter *converter);
-
-  /// Build an N:1 materialization for the given original value that was
-  /// replaced with the given replacement values.
-  ///
-  /// This is a workaround around incomplete 1:N support in the dialect
-  /// conversion driver. The conversion mapping can store only 1:1 replacements
-  /// and the conversion patterns only support single Value replacements in the
-  /// adaptor, so N values must be converted back to a single value. This
-  /// function will be deleted when full 1:N support has been added.
-  ///
-  /// This function inserts an argument materialization back to the original
-  /// type, followed by a target materialization to the legalized type (if
-  /// applicable).
-  void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
-                                 ValueRange replacements, Value originalValue,
-                                 const TypeConverter *converter);
+  ValueRange buildUnresolvedMaterialization(MaterializationKind kind,
+                                            OpBuilder::InsertPoint ip,
+                                            Location loc, ValueRange inputs,
+                                            TypeRange outputTypes,
+                                            Type originalType,
+                                            const TypeConverter *converter);
 
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
@@ -1072,10 +1207,8 @@
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
-  if (getMaterializationKind() == MaterializationKind::Target) {
-    for (Value input : op->getOperands())
-      rewriterImpl.mapping.erase(input);
-  }
+  rewriterImpl.mapping.eraseMaterialization(op->getOperands(),
+                                            op->getResults());
   rewriterImpl.unresolvedMaterializations.erase(getOperation());
   op->erase();
 }
@@ -1120,7 +1253,7 @@
 LogicalResult ConversionPatternRewriterImpl::remapValues(
     StringRef valueDiagTag, std::optional<Location> inputLoc,
     PatternRewriter &rewriter, ValueRange values,
-    SmallVectorImpl<Value> &remapped) {
+    SmallVector<SmallVector<Value, 1>> &remapped) {
   remapped.reserve(llvm::size(values));
 
   for (const auto &it : llvm::enumerate(values)) {
@@ -1132,7 +1265,8 @@
       // The current pattern does not have a type converter. I.e., it does not
       // distinguish between legal and illegal types. For each operand, simply
       // pass through the most recently mapped value.
-      remapped.push_back(mapping.lookupOrDefault(operand));
+      SmallVector<Value, 1> vals = mapping.lookupOrDefault(operand);
+      remapped.push_back(vals);
       continue;
     }
 
@@ -1146,36 +1280,29 @@
       return failure();
     }
 
-    if (legalTypes.size() != 1) {
-      // TODO: Parts of the dialect conversion infrastructure do not support
-      // 1->N type conversions yet. Therefore, if a type is converted to 0 or
-      // multiple types, the only thing that we can do for now is passing
-      // through the most recently mapped value. Fixing this requires
-      // improvements to the `ConversionValueMapping` (to be able to store 1:N
-      // mappings) and to the `ConversionPattern` adaptor handling (to be able
-      // to pass multiple remapped values for a single operand to the adaptor).
-      remapped.push_back(mapping.lookupOrDefault(operand));
+    // The type converts to zero types (1:0): remap to an empty value list.
+    if (legalTypes.empty()) {
+      remapped.push_back({});
       continue;
     }
 
-    // Handle 1->1 type conversions.
-    Type desiredType = legalTypes.front();
-    // Try to find a mapped value with the desired type. (Or the operand itself
-    // if the value is not mapped at all.)
-    Value newOperand = mapping.lookupOrDefault(operand, desiredType);
-    if (newOperand.getType() != desiredType) {
-      // If the looked up value's type does not have the desired type, it means
-      // that the value was replaced with a value of different type and no
-      // source materialization was created yet.
-      Value castValue = buildUnresolvedMaterialization(
-          MaterializationKind::Target, computeInsertPoint(newOperand),
-          operandLoc,
-          /*inputs=*/newOperand, /*outputType=*/desiredType,
-          /*originalType=*/origType, currentTypeConverter);
-      mapping.map(newOperand, castValue);
-      newOperand = castValue;
+    SmallVector<Value, 1> mat = mapping.lookupOrDefault(operand, legalTypes);
+    if (!mat.empty()) {
+      // Found values of the desired types: the mapped values already have
+      // them, a materialization to them already exists, or the operand is
+      // not mapped at all and already has the correct type.
+      remapped.push_back(mat);
+      continue;
     }
-    remapped.push_back(newOperand);
+
+    // Create a materialization for the most recently mapped value.
+    SmallVector<Value, 1> vals = mapping.lookupOrDefault(operand);
+    ValueRange castValues = buildUnresolvedMaterialization(
+        MaterializationKind::Target, computeInsertPoint(vals), operandLoc,
+        /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType,
+        currentTypeConverter);
+    mapping.mapMaterialization(vals, castValues);
+    remapped.push_back(castValues);
   }
   return success();
 }
@@ -1287,7 +1414,7 @@
           MaterializationKind::Source,
           OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
           /*inputs=*/ValueRange(),
-          /*outputType=*/origArgType, /*originalType=*/Type(), converter);
+          /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)[0];
       mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
       continue;
@@ -1303,15 +1430,10 @@
       continue;
     }
 
-    // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
-    // dialect conversion. Therefore, we need an argument materialization to
-    // turn the replacement block arguments into a single SSA value that can be
-    // used as a replacement.
+    // Map to replacement arguments.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    insertNTo1Materialization(
-        OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-        /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+    mapping.map(origArg, replArgs);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
   }
 
@@ -1330,59 +1452,27 @@
 
-/// Build an unresolved materialization operation given an output type and set
-/// of input operands.
-Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+/// Build an unresolved materialization operation given a list of output types
+/// and a set of input operands. If the input types already match the output
+/// types, no operation is created and the inputs are returned as they are.
+ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    ValueRange inputs, Type outputType, Type originalType,
+    ValueRange inputs, TypeRange outputTypes, Type originalType,
     const TypeConverter *converter) {
   assert((!originalType || kind == MaterializationKind::Target) &&
          "original type is valid only for target materializations");
 
   // Avoid materializing an unnecessary cast.
-  if (inputs.size() == 1 && inputs.front().getType() == outputType)
-    return inputs.front();
+  if (TypeRange(inputs) == outputTypes)
+    return inputs;
 
   // Create an unresolved materialization. We use a new OpBuilder to avoid
   // tracking the materialization like we do for other operations.
-  OpBuilder builder(outputType.getContext());
+  OpBuilder builder(outputTypes.front().getContext());
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
-      builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+      builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
                                                   originalType);
-  return convertOp.getResult(0);
+  return convertOp.getResults();
 }
-
-void ConversionPatternRewriterImpl::insertNTo1Materialization(
-    OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
-    Value originalValue, const TypeConverter *converter) {
-  // Insert argument materialization back to the original type.
-  Type originalType = originalValue.getType();
-  Value argMat =
-      buildUnresolvedMaterialization(MaterializationKind::Argument, ip, loc,
-                                     /*inputs=*/replacements, originalType,
-                                     /*originalType=*/Type(), converter);
-  mapping.map(originalValue, argMat);
-
-  // Insert target materialization to the legalized type.
-  Type legalOutputType;
-  if (converter) {
-    legalOutputType = converter->convertType(originalType);
-  } else if (replacements.size() == 1) {
-    // When there is no type converter, assume that the replacement value
-    // types are legal. This is reasonable to assume because they were
-    // specified by the user.
-    // FIXME: This won't work for 1->N conversions because multiple output
-    // types are not supported in parts of the dialect conversion. In such a
-    // case, we currently use the original value type.
-    legalOutputType = replacements[0].getType();
-  }
-  if (legalOutputType && legalOutputType != originalType) {
-    Value targetMat = buildUnresolvedMaterialization(
-        MaterializationKind::Target, computeInsertPoint(argMat), loc,
-        /*inputs=*/argMat, /*outputType=*/legalOutputType,
-        /*originalType=*/originalType, converter);
-    mapping.map(argMat, targetMat);
-  }
-}
 
 //===----------------------------------------------------------------------===//
@@ -1432,12 +1516,11 @@
       }
 
       // Materialize a replacement value "out of thin air".
-      Value sourceMat = buildUnresolvedMaterialization(
+      repl = buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(result),
           result.getLoc(), /*inputs=*/ValueRange(),
-          /*outputType=*/result.getType(), /*originalType=*/Type(),
+          /*outputTypes=*/result.getType(), /*originalType=*/Type(),
           currentTypeConverter);
-      repl.push_back(sourceMat);
     } else {
       // Make sure that the user does not mess with unresolved materializations
       // that were inserted by the conversion driver. We keep track of these
@@ -1450,18 +1533,8 @@
     }
 
     // Remap result to replacement value.
-    if (repl.empty())
-      continue;
-
-    if (repl.size() == 1) {
-      // Single replacement value: replace directly.
-      mapping.map(result, repl.front());
-    } else {
-      // Multiple replacement values: insert N:1 materialization.
-      insertNTo1Materialization(computeInsertPoint(result), result.getLoc(),
-                                /*replacements=*/repl, /*outputValue=*/result,
-                                currentTypeConverter);
-    }
+    if (!repl.empty())
+      mapping.map(result, repl);
   }
 
   appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1612,15 +1685,19 @@
                              << "'(" << from.getOwner()->getParentOp() << ")\n";
   });
   impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
-  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+  SmallVector<Value, 1> mapped = impl->mapping.lookupOrDefault(from);
+  assert(mapped.size() == 1 &&
+         "replaceUsesOfBlockArgument is not supported for 1:N replacements");
+  impl->mapping.map(mapped.front(), to);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
-  SmallVector<Value> remappedValues;
+  SmallVector<SmallVector<Value, 1>> remappedValues;
   if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
                                remappedValues)))
     return nullptr;
-  return remappedValues.front();
+  assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
+  return remappedValues.front().front();
 }
 
 LogicalResult
@@ -1628,8 +1704,15 @@
                                              SmallVectorImpl<Value> &results) {
   if (keys.empty())
     return success();
-  return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
-                           results);
+  SmallVector<SmallVector<Value, 1>> remapped;
+  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+                               remapped)))
+    return failure();
+  for (const auto &values : remapped) {
+    assert(values.size() == 1 && "1:N conversion not supported");
+    results.push_back(values.front());
+  }
+  return success();
 }
 
 void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
@@ -1723,6 +1806,23 @@
 // ConversionPattern
 //===----------------------------------------------------------------------===//
 
+/// Flatten 1:N adaptor operands into a 1:1 operand list. Aborts with a fatal
+/// error if any operand was not remapped to exactly one value.
+SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
+    ArrayRef<ArrayRef<Value>> operands) const {
+  SmallVector<Value> flattened;
+  flattened.reserve(operands.size());
+  for (ArrayRef<Value> replacements : operands) {
+    if (replacements.size() == 1) {
+      flattened.push_back(replacements.front());
+      continue;
+    }
+    llvm::report_fatal_error("pattern '" + getDebugName() +
+                             "' does not support 1:N conversion");
+  }
+  return flattened;
+}
+
 LogicalResult
 ConversionPattern::matchAndRewrite(Operation *op,
                                    PatternRewriter &rewriter) const {
@@ -1734,12 +1830,19 @@
                                              getTypeConverter());
 
   // Remap the operands of the operation.
-  SmallVector<Value, 4> operands;
+  SmallVector<SmallVector<Value, 1>> remapped;
   if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
-                                      op->getOperands(), operands))) {
+                                      op->getOperands(), remapped))) {
     return failure();
   }
-  return matchAndRewrite(op, operands, dialectRewriter);
+
+  // Build non-owning views; `remapped` keeps owning the values for the call.
+  // TODO: This should not be necessary.
+  SmallVector<ArrayRef<Value>> remappedArrayRef;
+  remappedArrayRef.reserve(remapped.size());
+  for (const auto &vals : remapped)
+    remappedArrayRef.push_back(vals);
+  return matchAndRewrite(op, remappedArrayRef, dialectRewriter);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2483,45 +2585,44 @@
   assert(!op.use_empty() &&
          "expected that dead materializations have already been DCE'd");
   Operation::operand_range inputOperands = op.getOperands();
-  Type outputType = op.getResultTypes()[0];
 
   // Try to materialize the conversion.
   if (const TypeConverter *converter = rewrite->getConverter()) {
     rewriter.setInsertionPoint(op);
-    Value newMaterialization;
+    SmallVector<Value> newMaterialization;
     switch (rewrite->getMaterializationKind()) {
     case MaterializationKind::Argument:
-      // Try to materialize an argument conversion.
-      newMaterialization = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      if (newMaterialization)
-        break;
-      // If an argument materialization failed, fallback to trying a target
-      // materialization.
-      [[fallthrough]];
+      llvm_unreachable("argument materializations have been removed");
     case MaterializationKind::Target:
       newMaterialization = converter->materializeTargetConversion(
-          rewriter, op->getLoc(), outputType, inputOperands,
+          rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
           rewrite->getOriginalType());
       break;
     case MaterializationKind::Source:
-      newMaterialization = converter->materializeSourceConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
+      assert(op.getNumResults() == 1 &&
+             "*:N source materializations are not supported");
+      // Declare the result inside the `if` so that no case label can ever
+      // jump past its initialization.
+      if (Value sourceMat = converter->materializeSourceConversion(
+              rewriter, op->getLoc(), op.getResultTypes().front(),
+              inputOperands))
+        newMaterialization.push_back(sourceMat);
       break;
     }
-    if (newMaterialization) {
-      assert(newMaterialization.getType() == outputType &&
+    if (!newMaterialization.empty()) {
+      assert(TypeRange(newMaterialization) == op.getResultTypes() &&
              "materialization callback produced value of incorrect type");
       rewriter.replaceOp(op, newMaterialization);
       return success();
     }
   }
 
   InFlightDiagnostic diag =
       op->emitError() << "failed to legalize unresolved materialization "
                          "from ("
-                      << inputOperands.getTypes() << ") to (" << outputType
+                      << inputOperands.getTypes() << ") to ("
+                      << op.getResultTypes()
                       << ") that remained live after conversion";
   diag.attachNote(op->getUsers().begin()->getLoc())
       << "see existing live user here: " << *op->getUsers().begin();
   return failure();
@@ -2642,6 +2739,12 @@
     std::tie(replacedValues, converter) =
         getReplacedValues(rewriterImpl.rewrites[i].get());
     for (Value originalValue : replacedValues) {
+      // If this value is directly replaced with a value of the same type,
+      // there is nothing to do.
+      Value repl =
+          rewriterImpl.mapping.lookupDirectSingleReplacement(originalValue);
+      if (repl && repl.getType() == originalValue.getType())
+        continue;
       // If the type of this value changed and the value is still live, we need
       // to materialize a conversion.
       if (rewriterImpl.mapping.lookupOrNull(originalValue,
@@ -2653,16 +2755,16 @@
         continue;
 
       // Legalize this value replacement.
-      Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
-      assert(newValue && "replacement value not found");
+      SmallVector<Value, 1> newValues =
+          rewriterImpl.mapping.lookupOrNull(originalValue);
+      assert(!newValues.empty() && "replacement value not found");
       Value castValue = rewriterImpl.buildUnresolvedMaterialization(
-          MaterializationKind::Source, computeInsertPoint(newValue),
+          MaterializationKind::Source, computeInsertPoint(newValues),
           originalValue.getLoc(),
-          /*inputs=*/newValue, /*outputType=*/originalValue.getType(),
-          /*originalType=*/Type(), converter);
-      rewriterImpl.mapping.map(originalValue, castValue);
-      inverseMapping[castValue].push_back(originalValue);
-      llvm::erase(inverseMapping[newValue], originalValue);
+          /*inputs=*/newValues, /*outputTypes=*/originalValue.getType(),
+          /*originalType=*/Type(), converter)[0];
+      rewriterImpl.mapping.mapMaterialization(newValues, {castValue});
+      llvm::append_range(inverseMapping[castValue], newValues);
     }
   }
 }
diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir
index b8fad63..4e64131 100644
--- a/mlir/test/Transforms/decompose-call-graph-types.mlir
+++ b/mlir/test/Transforms/decompose-call-graph-types.mlir
@@ -9,10 +9,7 @@
 // CHECK-LABEL:   func @identity(
 // CHECK-SAME:                   %[[ARG0:.*]]: i1,
 // CHECK-SAME:                   %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK:           %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
-// CHECK:           %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
-// CHECK:           %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
-// CHECK:           return %[[RET0]], %[[RET1]] : i1, i32
+// CHECK:           return %[[ARG0]], %[[ARG1]] : i1, i32
 // CHECK-12N-LABEL:   func @identity(
 // CHECK-12N-SAME:                   %[[ARG0:.*]]: i1,
 // CHECK-12N-SAME:                   %[[ARG1:.*]]: i32) -> (i1, i32) {
@@ -56,18 +53,7 @@
 // CHECK-LABEL:   func @mixed_recursive_decomposition(
 // CHECK-SAME:                 %[[ARG0:.*]]: i1,
 // CHECK-SAME:                 %[[ARG1:.*]]: i2) -> (i1, i2) {
-// CHECK:           %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
-// CHECK:           %[[V1:.*]] = "test.make_tuple"(%[[ARG0]]) : (i1) -> tuple<i1>
-// CHECK:           %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
-// CHECK:           %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple<i2>) -> tuple<tuple<i2>>
-// CHECK:           %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple<i1>, tuple<tuple<i2>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
-// CHECK:           %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<>
-// CHECK:           %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 1 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1>
-// CHECK:           %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<i1>) -> i1
-// CHECK:           %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 2 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
-// CHECK:           %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
-// CHECK:           %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
-// CHECK:           return %[[V7]], %[[V10]] : i1, i2
+// CHECK:           return %[[ARG0]], %[[ARG1]] : i1, i2
 // CHECK-12N-LABEL:   func @mixed_recursive_decomposition(
 // CHECK-12N-SAME:                 %[[ARG0:.*]]: i1,
 // CHECK-12N-SAME:                 %[[ARG1:.*]]: i2) -> (i1, i2) {
@@ -87,14 +73,8 @@
 // CHECK-LABEL:   func @caller(
 // CHECK-SAME:                 %[[ARG0:.*]]: i1,
 // CHECK-SAME:                 %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK:           %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
-// CHECK:           %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
-// CHECK:           %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
-// CHECK:           %[[DECOMPOSED:.*]]:2 = call @callee(%[[CALL_ARG0]], %[[CALL_ARG1]]) : (i1, i32) -> (i1, i32)
-// CHECK:           %[[CALL_RESULT_RECOMPOSED:.*]] = "test.make_tuple"(%[[DECOMPOSED]]#0, %[[DECOMPOSED]]#1) : (i1, i32) -> tuple<i1, i32>
-// CHECK:           %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
-// CHECK:           %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
-// CHECK:           return %[[RET0]], %[[RET1]] : i1, i32
+// CHECK:           %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32)
+// CHECK:           return %[[V0]]#0, %[[V0]]#1 : i1, i32
 // CHECK-12N-LABEL:   func @caller(
 // CHECK-12N-SAME:                 %[[ARG0:.*]]: i1,
 // CHECK-12N-SAME:                 %[[ARG1:.*]]: i32) -> (i1, i32) {
@@ -190,14 +170,8 @@
 // CHECK-SAME:                 %[[I4:.*]]: i4,
 // CHECK-SAME:                 %[[I5:.*]]: i5,
 // CHECK-SAME:                 %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) {
-// CHECK:           %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[I4]], %[[I5]]) : (i4, i5) -> tuple<i4, i5>
-// CHECK:           %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
-// CHECK:           %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
-// CHECK:           %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[ARG_TUPLE_0]], %[[ARG_TUPLE_1]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
-// CHECK:           %[[RET_TUPLE:.*]] = "test.make_tuple"(%[[CALL]]#3, %[[CALL]]#4) : (i4, i5) -> tuple<i4, i5>
-// CHECK:           %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
-// CHECK:           %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
-// CHECK:           return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
+// CHECK:           %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
+// CHECK:           return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
 // CHECK-12N-LABEL:   func @caller(
 // CHECK-12N-SAME:                 %[[I1:.*]]: i1,
 // CHECK-12N-SAME:                 %[[I2:.*]]: i2,
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index de511c5..0b8d4c0 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -139,7 +139,7 @@
           tupleType.getFlattenedTypes(types);
           return success();
         });
-    typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+    typeConverter.addSourceMaterialization(buildMakeTupleOp);
     typeConverter.addTargetMaterialization(buildDecomposeTuple);
 
     populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns);
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3df6cff..9154964 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1076,6 +1076,7 @@
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
+    assert(!operands.empty() && "expected at least one remapped operand");
     // Verify that the incoming operand has been successfully remapped to F64.
     if (!operands[0].getType().isF64())
       return failure();