cse
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c48043b..8008958 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3180,6 +3180,102 @@
   return failure();
 }
 
+/// Eliminates redundant `unrealized_conversion_cast` ops from `castOps`.
+///
+/// Two casts are considered equivalent when they have the same inputs (in the
+/// same order) and the same result types. Whenever one of two equivalent casts
+/// properly dominates the other, all uses of the dominated cast are redirected
+/// to the dominating one and the dominated cast is erased. Equivalent casts
+/// that do not dominate each other are both kept.
+///
+/// NOTE(review): attributes of the casts are not compared; confirm that casts
+/// differing only in their attributes may be merged.
+///
+/// Ops listed in `castOps` may be erased from the IR; the list itself is not
+/// modified. Returns the casts that were not erased, in their original order.
+static SmallVector<UnrealizedConversionCastOp>
+cseUnrealizedCasts(SmallVectorImpl<UnrealizedConversionCastOp> &castOps) {
+  DominanceInfo domInfo;
+
+  // Bucket the casts by an order-sensitive hash of their result types and
+  // inputs. The hash only narrows down the candidates: equivalence is checked
+  // exactly below, so hash collisions are harmless.
+  DenseMap<unsigned, SmallVector<UnrealizedConversionCastOp>> hashedOps;
+  // Repeated entries are bucketed only once so that no op is visited (and
+  // possibly erased) twice.
+  DenseSet<UnrealizedConversionCastOp> seenOps;
+  for (UnrealizedConversionCastOp castOp : castOps) {
+    if (!seenOps.insert(castOp).second)
+      continue;
+    unsigned hash = 0;
+    for (Type type : castOp.getResultTypes())
+      hash = llvm::hash_combine(hash, type);
+    for (Value value : castOp.getInputs())
+      hash = llvm::hash_combine(hash, value);
+    hashedOps[hash].push_back(castOp);
+  }
+
+  auto isEquivalent = [](UnrealizedConversionCastOp lhs,
+                         UnrealizedConversionCastOp rhs) {
+    if (lhs.getInputs() != rhs.getInputs())
+      return false;
+    if (lhs.getResultTypes() != rhs.getResultTypes())
+      return false;
+    return true;
+  };
+
+  DenseSet<UnrealizedConversionCastOp> erasedOps;
+  auto replaceAndErase = [&](UnrealizedConversionCastOp from,
+                             UnrealizedConversionCastOp to) {
+    // Record the op before erasing it, so that the handle is not used after
+    // the underlying operation has been destroyed.
+    erasedOps.insert(from);
+    from.replaceAllUsesWith(to);
+    from.erase();
+  };
+
+  // TODO: This should run to a fixed point: redirecting the uses of an erased
+  // cast changes the inputs of its users, which can make further casts
+  // equivalent.
+  for (auto &it : hashedOps) {
+    // Casts of this bucket that have not been erased. No survivor properly
+    // dominates an equivalent survivor.
+    SmallVector<UnrealizedConversionCastOp> survivors;
+    for (UnrealizedConversionCastOp castOp : it.second) {
+      // Reuse an equivalent survivor that dominates this cast, if any.
+      auto *dominator =
+          llvm::find_if(survivors, [&](UnrealizedConversionCastOp survivor) {
+            return isEquivalent(survivor, castOp) &&
+                   domInfo.properlyDominates(survivor, castOp);
+          });
+      if (dominator != survivors.end()) {
+        replaceAndErase(castOp, *dominator);
+        continue;
+      }
+
+      // Otherwise this cast stays. Fold every equivalent survivor that it
+      // dominates into it and keep the remaining survivors.
+      SmallVector<UnrealizedConversionCastOp> remaining;
+      for (UnrealizedConversionCastOp survivor : survivors) {
+        if (isEquivalent(survivor, castOp) &&
+            domInfo.properlyDominates(castOp, survivor))
+          replaceAndErase(survivor, castOp);
+        else
+          remaining.push_back(survivor);
+      }
+      remaining.push_back(castOp);
+      survivors = std::move(remaining);
+    }
+  }
+
+  // Collect the casts that are still alive, preserving their original order.
+  SmallVector<UnrealizedConversionCastOp> result;
+  for (UnrealizedConversionCastOp castOp : castOps)
+    if (!erasedOps.contains(castOp))
+      result.push_back(castOp);
+  return result;
+}
+
 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   assert(!ops.empty() && "expected at least one operation");
   const ConversionTarget &target = opLegalizer.getTarget();
@@ -3233,6 +3329,7 @@
   // patterns.)
   SmallVector<UnrealizedConversionCastOp> remainingCastOps;
   reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
+  remainingCastOps = cseUnrealizedCasts(remainingCastOps);
 
   // Drop markers.
   for (UnrealizedConversionCastOp castOp : remainingCastOps)