[Coroutines] Enhance symmetric transfer for constant CmpInst

This fixes bug52896.

In short, some symmetric transfer optimization opportunities were lost
because commit 822b92a removed some optimization passes that used to run
inline. This can cause stack overflows in situations that the coroutine
design is meant to avoid. This patch fixes the issue by folding constant
CmpInst instructions, a transformation the deleted passes used to perform.

Reviewed By: rjmccall, junparser

Differential Revision: https://reviews.llvm.org/D116327

(cherry picked from commit 403772ff1ce5618c8d02316531386b415312274a)
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index b6932db..fc83bef 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -29,6 +29,7 @@
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/CallGraphSCCPass.h"
+#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
@@ -1174,6 +1175,15 @@
 static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
   DenseMap<Value *, Value *> ResolvedValues;
   BasicBlock *UnconditionalSucc = nullptr;
+  assert(InitialInst->getModule());
+  const DataLayout &DL = InitialInst->getModule()->getDataLayout();
+
+  auto TryResolveConstant = [&ResolvedValues](Value *V) {
+    auto It = ResolvedValues.find(V);
+    if (It != ResolvedValues.end())
+      V = It->second;
+    return dyn_cast<ConstantInt>(V);
+  };
 
   Instruction *I = InitialInst;
   while (I->isTerminator() ||
@@ -1190,47 +1200,65 @@
     }
     if (auto *BR = dyn_cast<BranchInst>(I)) {
       if (BR->isUnconditional()) {
-        BasicBlock *BB = BR->getSuccessor(0);
+        BasicBlock *Succ = BR->getSuccessor(0);
         if (I == InitialInst)
-          UnconditionalSucc = BB;
-        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
-        I = BB->getFirstNonPHIOrDbgOrLifetime();
+          UnconditionalSucc = Succ;
+        scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues);
+        I = Succ->getFirstNonPHIOrDbgOrLifetime();
+        continue;
+      }
+
+      BasicBlock *BB = BR->getParent();
+      // Handle the case the condition of the conditional branch is constant.
+      // e.g.,
+      //
+      //     br i1 false, label %cleanup, label %CoroEnd
+      //
+      // It is possible during the transformation. We could continue the
+      // simplifying in this case.
+      if (ConstantFoldTerminator(BB, /*DeleteDeadConditions=*/true)) {
+        // Handle this branch in next iteration.
+        I = BB->getTerminator();
         continue;
       }
     } else if (auto *CondCmp = dyn_cast<CmpInst>(I)) {
+      // If the case number of suspended switch instruction is reduced to
+      // 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
       auto *BR = dyn_cast<BranchInst>(I->getNextNode());
-      if (BR && BR->isConditional() && CondCmp == BR->getCondition()) {
-        // If the case number of suspended switch instruction is reduced to
-        // 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator.
-        // And the comparsion looks like : %cond = icmp eq i8 %V, constant.
-        ConstantInt *CondConst = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
-        if (CondConst && CondCmp->getPredicate() == CmpInst::ICMP_EQ) {
-          Value *V = CondCmp->getOperand(0);
-          auto it = ResolvedValues.find(V);
-          if (it != ResolvedValues.end())
-            V = it->second;
+      if (!BR || !BR->isConditional() || CondCmp != BR->getCondition())
+        return false;
 
-          if (ConstantInt *Cond0 = dyn_cast<ConstantInt>(V)) {
-            BasicBlock *BB = Cond0->equalsInt(CondConst->getZExtValue())
-                                 ? BR->getSuccessor(0)
-                                 : BR->getSuccessor(1);
-            scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
-            I = BB->getFirstNonPHIOrDbgOrLifetime();
-            continue;
-          }
-        }
-      }
+      // And the comparsion looks like : %cond = icmp eq i8 %V, constant.
+      // So we try to resolve constant for the first operand only since the
+      // second operand should be literal constant by design.
+      ConstantInt *Cond0 = TryResolveConstant(CondCmp->getOperand(0));
+      auto *Cond1 = dyn_cast<ConstantInt>(CondCmp->getOperand(1));
+      if (!Cond0 || !Cond1)
+        return false;
+
+      // Both operands of the CmpInst are Constant. So that we could evaluate
+      // it immediately to get the destination.
+      auto *ConstResult =
+          dyn_cast_or_null<ConstantInt>(ConstantFoldCompareInstOperands(
+              CondCmp->getPredicate(), Cond0, Cond1, DL));
+      if (!ConstResult)
+        return false;
+
+      CondCmp->replaceAllUsesWith(ConstResult);
+      CondCmp->eraseFromParent();
+
+      // Handle this branch in next iteration.
+      I = BR;
+      continue;
     } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
-      Value *V = SI->getCondition();
-      auto it = ResolvedValues.find(V);
-      if (it != ResolvedValues.end())
-        V = it->second;
-      if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
-        BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
-        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
-        I = BB->getFirstNonPHIOrDbgOrLifetime();
-        continue;
-      }
+      ConstantInt *Cond = TryResolveConstant(SI->getCondition());
+      if (!Cond)
+        return false;
+
+      BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
+      scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
+      I = BB->getFirstNonPHIOrDbgOrLifetime();
+      continue;
     }
     return false;
   }
diff --git a/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll b/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll
new file mode 100644
index 0000000..0d73d94
--- /dev/null
+++ b/llvm/test/Transforms/Coroutines/coro-split-musttail4.ll
@@ -0,0 +1,65 @@
+; Tests that coro-split converts a call before coro.suspend into a musttail call
+; when the user of the coro.suspend is an icmp instruction.
+; RUN: opt < %s -passes='cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s
+
+define void @fakeresume1(i8*)  { ; Trivial callee; @f's call to it is the one expected to become musttail in @f.resume.
+entry:
+  ret void;
+}
+
+define void @f() #0 { ; Coroutine under test; #0 carries the "coroutine.presplit" attribute.
+entry:
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %alloc = call i8* @malloc(i64 16) #3
+  %vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)
+
+  %save = call token @llvm.coro.save(i8* null)
+
+  %init_suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %init_suspend, label %coro.end [ ; 0 -> %await.ready; 1 and default -> %coro.end
+    i8 0, label %await.ready
+    i8 1, label %coro.end
+  ]
+await.ready:
+  %save2 = call token @llvm.coro.save(i8* null)
+
+  call fastcc void @fakeresume1(i8* align 8 null) ; expected to become a musttail call followed by ret in @f.resume
+  %suspend = call i8 @llvm.coro.suspend(token %save2, i1 true)
+  %switch = icmp ult i8 %suspend, 2 ; suspend result consumed by an icmp, not a switch: the case under test
+  br i1 %switch, label %cleanup, label %coro.end ; %cleanup when %suspend is 0 or 1, %coro.end otherwise
+
+cleanup:
+  %free.handle = call i8* @llvm.coro.free(token %id, i8* %vFrame)
+  %.not = icmp eq i8* %free.handle, null
+  br i1 %.not, label %coro.end, label %coro.free
+
+coro.free:
+  call void @delete(i8* nonnull %free.handle) #2
+  br label %coro.end
+
+coro.end:
+  call i1 @llvm.coro.end(i8* null, i1 false)
+  ret void
+}
+
+; CHECK-LABEL: @f.resume(
+; CHECK:          musttail call fastcc void @fakeresume1(
+; CHECK-NEXT:     ret void
+
+declare token @llvm.coro.id(i32, i8* readnone, i8* nocapture readonly, i8*) #1
+declare i1 @llvm.coro.alloc(token) #2 ; not used by @f
+declare i64 @llvm.coro.size.i64() #3 ; not used by @f
+declare i8* @llvm.coro.begin(token, i8* writeonly) #2
+declare token @llvm.coro.save(i8*) #2
+declare i8* @llvm.coro.frame() #3 ; not used by @f
+declare i8 @llvm.coro.suspend(token, i1) #2
+declare i8* @llvm.coro.free(token, i8* nocapture readonly) #1
+declare i1 @llvm.coro.end(i8*, i1) #2
+declare i8* @llvm.coro.subfn.addr(i8* nocapture readonly, i8) #1 ; not used by @f
+declare i8* @malloc(i64)
+declare void @delete(i8* nonnull) #2
+
+attributes #0 = { "coroutine.presplit"="1" } ; applied to @f
+attributes #1 = { argmemonly nounwind readonly }
+attributes #2 = { nounwind }
+attributes #3 = { nounwind readnone } ; NOTE(review): also attached to the @malloc call in @f; readnone on an allocation call looks unintended - confirm