[mlir][Transforms] Greedy pattern rewriter: fix infinite folding loop
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7cfd6d3..04daed8c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -393,7 +393,7 @@
OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
// addi(x, 0) -> x
- if (matchPattern(adaptor.getRhs(), m_Zero()))
+ if (matchPattern(adaptor.getRhs(), m_Zero()) && getLhs() != *this)
return getLhs();
// addi(subi(a, b), b) -> a
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 74e4a82..93468dd 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -555,7 +555,8 @@
replacements.push_back(constOp->getResult(0));
}
- if (materializationSucceeded) {
+ if (materializationSucceeded &&
+ !llvm::equal(replacements, op->getResults())) {
rewriter.replaceOp(op, replacements);
changed = true;
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 8e02c06..ed987e5 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1248,3 +1248,12 @@
%u = index.castu %const : index to i64
return %u: i64
}
+
+// -----
+
+// Make sure that the canonicalizer does not fold infinitely: once the constant is moved to the RHS,
+// `addi(x, 0) -> x` folds this self-referencing op (legal only in a graph region, e.g. the module) to itself.
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+%c0 = arith.constant 0 : index
+// CHECK: %[[add:.*]] = arith.addi %[[add]], %[[c0]] : index
+%0 = arith.addi %c0, %0 : index