Merge pull request #10185 from nkcsgexi/api-digester-simplify-context

diff --git a/tools/swift-api-digester/swift-api-digester.cpp b/tools/swift-api-digester/swift-api-digester.cpp
index 8d0d927..b1f3667 100644
--- a/tools/swift-api-digester/swift-api-digester.cpp
+++ b/tools/swift-api-digester/swift-api-digester.cpp
@@ -173,19 +173,6 @@
 typedef std::map<NodePtr, NodePtr> ParentMap;
 typedef std::vector<NodePtr> NodeVector;
 
-class SDKContext {
-  llvm::StringSet<> TextData;
-  llvm::BumpPtrAllocator Allocator;
-
-public:
-  llvm::BumpPtrAllocator &allocator() {
-    return Allocator;
-  }
-  StringRef buffer(StringRef Text) {
-    return TextData.insert(Text).first->getKey();
-  }
-};
-
 // The interface used to visit the SDK tree.
 class SDKNodeVisitor {
   friend SDKNode;
@@ -216,6 +203,56 @@
   virtual ~MatchedNodeListener() = default;
 };
 
+using NodePairVector = llvm::MapVector<NodePtr, NodePtr>;
+
+// This map keeps track of updated nodes; thus we can conveniently find out what
+// is the counterpart of a node before or after being updated.
+//
+// One instance is owned by SDKContext and handed by reference to the diff
+// passes, so both copy operations are deleted: every pass must see the same
+// map rather than a silently diverging copy.
+class UpdatedNodesMap : public MatchedNodeListener {
+  NodePairVector MapImpl;
+
+public:
+  UpdatedNodesMap() = default;
+  UpdatedNodesMap(const UpdatedNodesMap &) = delete;
+  UpdatedNodesMap &operator=(const UpdatedNodesMap &) = delete;
+
+  // Returns the other half of the recorded pair that contains \p Node,
+  // whether \p Node was recorded as the left (before) or right (after) side.
+  NodePtr findUpdateCounterpart(const SDKNode *Node) const;
+
+  // Records that \p Left was updated into \p Right. MapVector::insert keeps
+  // the first pairing if \p Left has already been recorded.
+  void foundMatch(NodePtr Left, NodePtr Right) override {
+    assert(Left && Right && "Not update operation.");
+    MapImpl.insert({Left, Right});
+  }
+};
+
+// Context shared by the digester passes: interned strings, a bump allocator,
+// and the updated-nodes map that the drivers hand to PrunePass,
+// ChangeRefinementPass and DiagnosisEmitter.
+class SDKContext {
+  llvm::StringSet<> TextData;
+  llvm::BumpPtrAllocator Allocator;
+  UpdatedNodesMap UpdateMap;
+
+public:
+  llvm::BumpPtrAllocator &allocator() {
+    return Allocator;
+  }
+  // Interns \p Text in TextData and returns a reference to the stored copy.
+  StringRef buffer(StringRef Text) {
+    return TextData.insert(Text).first->getKey();
+  }
+  // Passes store this reference, so the context must outlive them.
+  UpdatedNodesMap &getNodeUpdateMap() {
+    return UpdateMap;
+  }
+};
+
 // A node matcher will traverse two trees of SDKNode and find matched nodes
 struct NodeMatcher {
   virtual void match() = 0;
@@ -368,6 +388,25 @@
   bool isStatic() const { return IsStatic; };
 };
 
+// Given either side of a pair recorded by foundMatch(), returns the other
+// side. \p Node must be annotated as Updated.
+NodePtr UpdatedNodesMap::findUpdateCounterpart(const SDKNode *Node) const {
+  assert(Node->isAnnotatedAs(NodeAnnotation::Updated) && "Not update operation.");
+  // MapImpl is keyed by the left node only while Node may be either side, so
+  // this is a linear scan. NOTE(review): ChangeRefinementPass calls this for
+  // each updated type node; a reverse index would avoid O(n^2) total work.
+  auto FoundPair = std::find_if(MapImpl.begin(), MapImpl.end(),
+                      [&](const std::pair<NodePtr, NodePtr> &Pair) {
+    return Pair.second == Node || Pair.first == Node;
+  });
+  assert(FoundPair != MapImpl.end() && "Cannot find update counterpart.");
+  // With assertions compiled out, fail with a null result instead of
+  // dereferencing end().
+  if (FoundPair == MapImpl.end())
+    return nullptr;
+  return Node == FoundPair->first ? FoundPair->second : FoundPair->first;
+}
+
 class SDKNodeType : public SDKNode {
   std::vector<TypeAttrKind> TypeAttributes;
 
@@ -1934,30 +1964,6 @@
   virtual ~SDKTreeDiffPass() {}
 };
 
-using NodePairVector = llvm::MapVector<NodePtr, NodePtr>;
-
-// This map keeps track of updated nodes; thus we can conveniently find out what
-// is the counterpart of a node before or after being updated.
-class UpdatedNodesMap : public MatchedNodeListener {
-  NodePairVector MapImpl;
-
-public:
-  void foundMatch(NodePtr Left, NodePtr Right) override {
-    assert(Left && Right && "Not update operation.");
-    MapImpl.insert({Left, Right});
-  }
-
-  NodePtr findUpdateCounterpart(const SDKNode *Node) const {
-    assert(Node->isAnnotatedAs(NodeAnnotation::Updated) && "Not update operation.");
-    auto FoundPair = std::find_if(MapImpl.begin(), MapImpl.end(),
-                        [&](std::pair<NodePtr, NodePtr> Pair) {
-      return Pair.second == Node || Pair.first == Node;
-    });
-    assert(FoundPair != MapImpl.end() && "Cannot find update counterpart.");
-    return Node == FoundPair->first ? FoundPair->second : FoundPair->first;
-  }
-};
-
 static void detectFuncDeclChange(NodePtr L, NodePtr R) {
   assert(L->getKind() == R->getKind());
   if (auto LF = dyn_cast<SDKNodeAbstractFunc>(L)) {
@@ -2015,10 +2021,11 @@
       Right->removeChild(R);
   }
 
-  std::unique_ptr<UpdatedNodesMap> UpdateMap;
+  // Not owned; the drivers pass SDKContext's map, which must outlive the pass.
+  UpdatedNodesMap &UpdateMap;
 
 public:
-  PrunePass() : UpdateMap(new UpdatedNodesMap()) {}
+  explicit PrunePass(UpdatedNodesMap &UpdateMap) : UpdateMap(UpdateMap) {}
 
   void foundRemoveAddMatch(NodePtr Left, NodePtr Right) override {
     if (!Left)
@@ -2038,7 +2044,7 @@
     Left->annotate(NodeAnnotation::Updated);
     Right->annotate(NodeAnnotation::Updated);
     // Push the updated node to the map for future reference.
-    UpdateMap->foundMatch(Left, Right);
+    UpdateMap.foundMatch(Left, Right);
 
     if (Left->getKind() != Right->getKind()) {
       assert(isa<SDKNodeType>(Left) && isa<SDKNodeType>(Right) &&
@@ -2093,10 +2099,6 @@
   void pass(NodePtr Left, NodePtr Right) override {
     foundMatch(Left, Right);
   }
-
-  std::unique_ptr<UpdatedNodesMap> getNodeUpdateMap() {
-    return std::move(UpdateMap);
-  }
 };
 
 // For a given SDK node tree, this will build up a mapping from USR to node
@@ -2222,7 +2224,7 @@
 
 class ChangeRefinementPass : public SDKTreeDiffPass, public SDKNodeVisitor {
   bool IsVisitingLeft;
-  std::unique_ptr<UpdatedNodesMap> UpdateMap;
+  UpdatedNodesMap &UpdateMap;
 
 #define ANNOTATE(Node, Counter, X, Y)                                          \
   auto ToAnnotate = IsVisitingLeft ? Node : Counter;                           \
@@ -2294,14 +2296,14 @@
   }
 
   bool isUnhandledCase(SDKNodeType *Node) {
-    auto Counter = UpdateMap->findUpdateCounterpart(Node)->getAs<SDKNodeType>();
+    auto Counter = UpdateMap.findUpdateCounterpart(Node)->getAs<SDKNodeType>();
     return Node->getTypeKind() == KnownTypeKind::Void ||
            Counter->getTypeKind() == KnownTypeKind::Void;
   }
 
 public:
-  ChangeRefinementPass(std::unique_ptr<UpdatedNodesMap> UpdateMap) :
-    UpdateMap(std::move(UpdateMap)) {}
+  explicit ChangeRefinementPass(UpdatedNodesMap &UpdateMap)
+      : UpdateMap(UpdateMap) {}
 
   void pass(NodePtr Left, NodePtr Right) override {
 
@@ -2317,7 +2318,7 @@
     if (!Node || !Node->isAnnotatedAs(NodeAnnotation::Updated) ||
         isUnhandledCase(Node))
       return;
-    auto Counter = const_cast<SDKNodeType*>(UpdateMap->
+    auto Counter = const_cast<SDKNodeType*>(UpdateMap.
       findUpdateCounterpart(Node)->getAs<SDKNodeType>());
 
     bool Result = detectWrapOptional(Node, Counter)||
@@ -2328,10 +2329,6 @@
     (void) Result;
     return;
   }
-
-  std::unique_ptr<UpdatedNodesMap> getNodeUpdateMap() {
-    return std::move(UpdateMap);
-  }
 };
 
 typedef std::vector<CommonDiffItem> DiffVector;
@@ -2643,7 +2640,7 @@
   DiagBag<MovedDeclDiag> MovedDecls;
   DiagBag<RemovedDeclDiag> RemovedDecls;
 
-  UpdatedNodesMap UpdateMap;
+  UpdatedNodesMap &UpdateMap;
   DiagnosisEmitter(UpdatedNodesMap &UpdateMap) : UpdateMap(UpdateMap) {}
 public:
   static void diagnosis(NodePtr LeftRoot, NodePtr RightRoot,
@@ -3019,12 +3016,11 @@
   RightCollector.deSerialize(RightPath);
   auto LeftModule = LeftCollector.getSDKRoot();
   auto RightModule = RightCollector.getSDKRoot();
-  PrunePass Prune;
+  PrunePass Prune(Ctx.getNodeUpdateMap());
   Prune.pass(LeftModule, RightModule);
-  ChangeRefinementPass RefinementPass(Prune.getNodeUpdateMap());
+  ChangeRefinementPass RefinementPass(Ctx.getNodeUpdateMap());
   RefinementPass.pass(LeftModule, RightModule);
-  DiagnosisEmitter::diagnosis(LeftModule, RightModule,
-                              *RefinementPass.getNodeUpdateMap());
+  DiagnosisEmitter::diagnosis(LeftModule, RightModule, Ctx.getNodeUpdateMap());
   return 0;
 }
 
@@ -3055,10 +3051,10 @@
   TypeMemberDiffVector typeMemberDiffs;
   findTypeMemberDiffs(LeftModule, RightModule, typeMemberDiffs);
 
-  PrunePass Prune;
+  PrunePass Prune(Ctx.getNodeUpdateMap());
   Prune.pass(LeftModule, RightModule);
   llvm::errs() << "Finished pruning" << "\n";
-  ChangeRefinementPass RefinementPass(Prune.getNodeUpdateMap());
+  ChangeRefinementPass RefinementPass(Ctx.getNodeUpdateMap());
   RefinementPass.pass(LeftModule, RightModule);
   DiffVector AllItems;
   DiffItemEmitter::collectDiffItems(LeftModule, AllItems);