BreakCriticalEdges: Update PostDominatorTree

llvm-svn: 354673
diff --git a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
index e997ed0..a7bbe28 100644
--- a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
@@ -35,6 +35,7 @@
 class MDNode;
 class MemoryDependenceResults;
 class MemorySSAUpdater;
+class PostDominatorTree;
 class ReturnInst;
 class TargetLibraryInfo;
 class Value;
@@ -103,6 +104,7 @@
 /// during critical edge splitting.
 struct CriticalEdgeSplittingOptions {
   DominatorTree *DT;
+  PostDominatorTree *PDT;
   LoopInfo *LI;
   MemorySSAUpdater *MSSAU;
   bool MergeIdenticalEdges = false;
@@ -111,8 +113,9 @@
 
   CriticalEdgeSplittingOptions(DominatorTree *DT = nullptr,
                                LoopInfo *LI = nullptr,
-                               MemorySSAUpdater *MSSAU = nullptr)
-      : DT(DT), LI(LI), MSSAU(MSSAU) {}
+                               MemorySSAUpdater *MSSAU = nullptr,
+                               PostDominatorTree *PDT = nullptr)
+      : DT(DT), PDT(PDT), LI(LI), MSSAU(MSSAU) {}
 
   CriticalEdgeSplittingOptions &setMergeIdenticalEdges() {
     MergeIdenticalEdges = true;
diff --git a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp
index 2944c37..3b4b0b5 100644
--- a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp
+++ b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp
@@ -23,6 +23,7 @@
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/PostDominators.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
@@ -48,10 +49,14 @@
     bool runOnFunction(Function &F) override {
       auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
       auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
+
+      auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>();
+      auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr;
+
       auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
       auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
       unsigned N =
-          SplitAllCriticalEdges(F, CriticalEdgeSplittingOptions(DT, LI));
+          SplitAllCriticalEdges(F, CriticalEdgeSplittingOptions(DT, LI, nullptr, PDT));
       NumBroken += N;
       return N > 0;
     }
@@ -201,16 +206,17 @@
 
   // If we have nothing to update, just return.
   auto *DT = Options.DT;
+  auto *PDT = Options.PDT;
   auto *LI = Options.LI;
   auto *MSSAU = Options.MSSAU;
   if (MSSAU)
     MSSAU->wireOldPredecessorsToNewImmediatePredecessor(
         DestBB, NewBB, {TIBB}, Options.MergeIdenticalEdges);
 
-  if (!DT && !LI)
+  if (!DT && !PDT && !LI)
     return NewBB;
 
-  if (DT) {
+  if (DT || PDT) {
     // Update the DominatorTree.
     //       ---> NewBB -----\
     //      /                 V
@@ -226,7 +232,10 @@
     if (llvm::find(successors(TIBB), DestBB) == succ_end(TIBB))
       Updates.push_back({DominatorTree::Delete, TIBB, DestBB});
 
-    DT->applyUpdates(Updates);
+    if (DT)
+      DT->applyUpdates(Updates);
+    if (PDT)
+      PDT->applyUpdates(Updates);
   }
 
   // Update LoopInfo if it is around.
diff --git a/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp b/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
index db30837..2d3731c 100644
--- a/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/BasicBlockUtilsTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Analysis/PostDominators.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Dominators.h"
@@ -49,3 +50,31 @@
   SplitBlockPredecessors(&F->getEntryBlock(), {}, "split.entry", &DT);
   EXPECT_TRUE(DT.verify());
 }
+
+TEST(BasicBlockUtils, SplitCriticalEdge) {
+  LLVMContext C;
+
+  std::unique_ptr<Module> M = parseIR(
+    C,
+    "define void @crit_edge(i1 %cond0, i1 %cond1) {\n"
+    "entry:\n"
+    "  br i1 %cond0, label %bb0, label %bb1\n"
+    "bb0:\n"
+    "  br label %bb1\n"
+    "bb1:\n"
+    "  br label %bb2\n"
+    "bb2:\n"
+    "  ret void\n"
+    "}\n"
+    "\n"
+    );
+
+  auto *F = M->getFunction("crit_edge");
+  DominatorTree DT(*F);
+  PostDominatorTree PDT(*F);
+
+  CriticalEdgeSplittingOptions CESO(&DT, nullptr, nullptr, &PDT);
+  EXPECT_EQ(1u, SplitAllCriticalEdges(*F, CESO));
+  EXPECT_TRUE(DT.verify());
+  EXPECT_TRUE(PDT.verify());
+}