cse
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c48043b..8008958 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3180,6 +3180,49 @@
return failure();
}
+/// Computes a structural hash of `castOp` from its operands (in order) and its
+/// result types (in order). Equivalent casts hash equally; the converse does
+/// not hold, so callers must confirm with `areEquivalentCasts`.
+static size_t hashCastOp(UnrealizedConversionCastOp castOp) {
+  auto inputs = castOp.getInputs();
+  auto resultTypes = castOp.getResultTypes();
+  return static_cast<size_t>(llvm::hash_combine(
+      llvm::hash_combine_range(inputs.begin(), inputs.end()),
+      llvm::hash_combine_range(resultTypes.begin(), resultTypes.end())));
+}
+
+/// Returns true if `lhs` and `rhs` take the same operands and produce the same
+/// result types, so either one can stand in for the other.
+static bool areEquivalentCasts(UnrealizedConversionCastOp lhs,
+                               UnrealizedConversionCastOp rhs) {
+  return lhs.getInputs() == rhs.getInputs() &&
+         lhs.getResultTypes() == rhs.getResultTypes();
+}
+
+/// Redirects all uses of `duplicate`'s results to the corresponding results of
+/// `replacement`, then erases `duplicate`. Callers have already checked that
+/// the two casts are equivalent and that `replacement` properly dominates
+/// `duplicate`.
+static void replaceDuplicateCast(UnrealizedConversionCastOp duplicate,
+                                 UnrealizedConversionCastOp replacement) {
+  duplicate->replaceAllUsesWith(replacement.getOperation());
+  duplicate.erase();
+}
+
+/// Runs one CSE round over `castOps`. Each cast that has an equivalent cast
+/// properly dominating it is erased, and its uses are redirected to that
+/// cast. Erased ops are removed from `castOps`; the survivors keep their
+/// relative order. Returns true if any op was erased.
+///
+/// Hashes are computed once at the start of the round. A merge can rewrite the
+/// operands of other casts and make them equivalent. The equivalence check
+/// always reads the current operands, so the round stays correct; the newly
+/// equivalent casts are simply picked up by the next round.
+static bool
+cseUnrealizedCastsOnce(SmallVectorImpl<UnrealizedConversionCastOp> &castOps,
+                       DominanceInfo &domInfo) {
+  // Group ops by hash by sorting (hash, index) pairs. A DenseMap keyed on the
+  // raw hash could collide with DenseMap's reserved empty/tombstone keys;
+  // sorting cannot, and it also makes the processing order deterministic.
+  SmallVector<std::pair<size_t, unsigned>> byHash;
+  byHash.reserve(castOps.size());
+  for (unsigned i = 0, e = castOps.size(); i != e; ++i)
+    byHash.emplace_back(hashCastOp(castOps[i]), i);
+  llvm::sort(byHash);
+
+  SmallVector<bool> erased(castOps.size(), false);
+  bool changed = false;
+  const size_t numEntries = byHash.size();
+  for (size_t groupBegin = 0; groupBegin != numEntries;) {
+    size_t groupEnd = groupBegin + 1;
+    while (groupEnd != numEntries &&
+           byHash[groupEnd].first == byHash[groupBegin].first)
+      ++groupEnd;
+
+    // Surviving ops of this hash group. A group can hold non-equivalent ops
+    // (hash collisions), or equivalent ops where neither dominates the other
+    // (e.g. in sibling blocks). So every survivor, not just the first op, is
+    // a potential replacement.
+    SmallVector<unsigned> leaders;
+    for (size_t pos = groupBegin; pos != groupEnd; ++pos) {
+      unsigned candIdx = byHash[pos].second;
+      UnrealizedConversionCastOp cand = castOps[candIdx];
+
+      // Fold the candidate into an equivalent leader that dominates it.
+      auto dominatingLeader = llvm::find_if(leaders, [&](unsigned leaderIdx) {
+        UnrealizedConversionCastOp leader = castOps[leaderIdx];
+        return areEquivalentCasts(leader, cand) &&
+               domInfo.properlyDominates(leader, cand);
+      });
+      if (dominatingLeader != leaders.end()) {
+        // Mark before erasing so no dangling handle is consulted later.
+        erased[candIdx] = true;
+        replaceDuplicateCast(cand, castOps[*dominatingLeader]);
+        changed = true;
+        continue;
+      }
+
+      // Otherwise the candidate survives. It replaces every equivalent leader
+      // it properly dominates; uses of those leaders are dominated by them,
+      // so by transitivity they are also dominated by the candidate.
+      SmallVector<unsigned> keptLeaders;
+      for (unsigned leaderIdx : leaders) {
+        UnrealizedConversionCastOp leader = castOps[leaderIdx];
+        if (areEquivalentCasts(leader, cand) &&
+            domInfo.properlyDominates(cand, leader)) {
+          erased[leaderIdx] = true;
+          replaceDuplicateCast(leader, cand);
+          changed = true;
+          continue;
+        }
+        keptLeaders.push_back(leaderIdx);
+      }
+      keptLeaders.push_back(candIdx);
+      leaders = std::move(keptLeaders);
+    }
+    groupBegin = groupEnd;
+  }
+
+  if (!changed)
+    return false;
+  SmallVector<UnrealizedConversionCastOp> survivors;
+  survivors.reserve(castOps.size());
+  for (unsigned i = 0, e = castOps.size(); i != e; ++i)
+    if (!erased[i])
+      survivors.push_back(castOps[i]);
+  castOps = std::move(survivors);
+  return true;
+}
+
+/// Performs CSE on the given `unrealized_conversion_cast` ops until a fixed
+/// point is reached. A cast is erased, and its uses are redirected, when an
+/// equivalent cast (same operands, same result types) properly dominates it.
+/// Rounds repeat because a merge can make further casts equivalent, such as
+/// casts whose operands were two merged casts. Every round that reports a
+/// change erases at least one op, so the loop terminates.
+///
+/// Returns the surviving ops in their original relative order.
+static SmallVector<UnrealizedConversionCastOp>
+cseUnrealizedCasts(ArrayRef<UnrealizedConversionCastOp> castOps) {
+  SmallVector<UnrealizedConversionCastOp> liveOps(castOps.begin(),
+                                                  castOps.end());
+  DominanceInfo domInfo;
+  bool changed;
+  do {
+    changed = cseUnrealizedCastsOnce(liveOps, domInfo);
+  } while (changed);
+  return liveOps;
+}
+
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
assert(!ops.empty() && "expected at least one operation");
const ConversionTarget &target = opLegalizer.getTarget();
@@ -3233,6 +3276,7 @@
// patterns.)
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
+ remainingCastOps = cseUnrealizedCasts(remainingCastOps);
// Drop markers.
for (UnrealizedConversionCastOp castOp : remainingCastOps)