Make NFA threads shared and effectively immutable.

Instead of copying threads all the time, the NFA execution engine now
increments reference counters in AddToThreadq() and decrements them in
Step(). It now copies threads only when recording captures.

Twiddling reference counters is cheaper than copying a thread's capture
pointers because there are at least two pointers for the zeroth submatch
and, of course, there may be arbitrarily many submatches.

This probably will not help much with memory footprint except when fanout
is high, but it seems likely to be friendlier in terms of cache effects.

Change-Id: I90e9f6c0164cb4d06554ec16a89bc8ce76f500a3
Reviewed-on: https://code-review.googlesource.com/4670
Reviewed-by: Paul Wankadia <junyer@google.com>
diff --git a/re2/nfa.cc b/re2/nfa.cc
index 16bbe31..0a5cfd1 100644
--- a/re2/nfa.cc
+++ b/re2/nfa.cc
@@ -55,22 +55,24 @@
 
  private:
   struct Thread {
-    Thread* next;  // when on free list
+    union {
+      int ref;
+      Thread* next;  // when on free list
+    };
     const char** capture;
   };
 
   // State for explicit stack in AddToThreadq.
   struct AddState {
-    int id;           // Inst to process
-    int j;
-    const char* cap_j;  // if j>=0, set capture[j] = cap_j before processing ip
+    int id;     // Inst to process
+    Thread* t;  // if not null, set t0 = t before processing id
 
     AddState()
-      : id(0), j(-1), cap_j(NULL) {}
+        : id(0), t(NULL) {}
     explicit AddState(int id)
-      : id(id), j(-1), cap_j(NULL) {}
-    AddState(int id, const char* cap_j, int j)
-      : id(id), j(j), cap_j(cap_j) {}
+        : id(id), t(NULL) {}
+    AddState(int id, Thread* t)
+        : id(id), t(t) {}
   };
 
   // Threadq is a list of threads.  The list is sorted by the order
@@ -79,12 +81,13 @@
   typedef SparseArray<Thread*> Threadq;
 
   inline Thread* AllocThread();
-  inline void FreeThread(Thread*);
+  inline Thread* Incref(Thread* t);
+  inline void Decref(Thread* t);
 
-  // Add id (or its children, following unlabeled arrows)
+  // Add id0 (or its children, following unlabeled arrows)
   // to the workqueue q with associated capture info.
-  void AddToThreadq(Threadq* q, int id, int flag,
-                    const char* p, const char** capture);
+  void AddToThreadq(Threadq* q, int id0, int flag,
+                    const char* p, Thread* t0);
 
   // Run runq on byte c, appending new states to nextq.
   // Updates matched_ and match_ as new, better matches are found.
@@ -154,24 +157,36 @@
   }
 }
 
-void NFA::FreeThread(Thread *t) {
-  if (t == NULL)
-    return;
-  t->next = free_threads_;
-  free_threads_ = t;
-}
-
 NFA::Thread* NFA::AllocThread() {
   Thread* t = free_threads_;
   if (t == NULL) {
     t = new Thread;
+    t->ref = 1;
     t->capture = new const char*[ncapture_];
     return t;
   }
   free_threads_ = t->next;
+  t->ref = 1;
   return t;
 }
 
+NFA::Thread* NFA::Incref(Thread* t) {
+  DCHECK(t != NULL);
+  t->ref++;
+  return t;
+}
+
+void NFA::Decref(Thread* t) {
+  if (t == NULL)
+    return;
+  t->ref--;
+  if (t->ref > 0)
+    return;
+  DCHECK_EQ(t->ref, 0);
+  t->next = free_threads_;
+  free_threads_ = t;
+}
+
 void NFA::CopyCapture(const char** dst, const char** src) {
   for (int i = 0; i < ncapture_; i+=2) {
     dst[i] = src[i];
@@ -181,10 +196,9 @@
 
 // Follows all empty arrows from id0 and enqueues all the states reached.
 // The bits in flag (Bol, Eol, etc.) specify whether ^, $ and \b match.
-// The pointer p is the current input position, and m is the
-// current set of match boundaries.
+// p is the current input position, and t0 is the current thread.
 void NFA::AddToThreadq(Threadq* q, int id0, int flag,
-                       const char* p, const char** capture) {
+                       const char* p, Thread* t0) {
   if (id0 == 0)
     return;
 
@@ -204,15 +218,19 @@
     AddState a = stk[--nstk];
 
   Loop:
-    if (a.j >= 0)
-      capture[a.j] = a.cap_j;
+    if (a.t != NULL) {
+      // t0 was a thread that we allocated and copied in order to
+      // record the capture, so we must now decref it.
+      Decref(t0);
+      t0 = a.t;
+    }
 
     int id = a.id;
     if (id == 0)
       continue;
     if (q->has_index(id)) {
       if (Debug)
-        fprintf(stderr, "  [%d%s]\n", id, FormatCapture(capture).c_str());
+        fprintf(stderr, "  [%d%s]\n", id, FormatCapture(t0->capture).c_str());
       continue;
     }
 
@@ -235,8 +253,7 @@
 
     case kInstAltMatch:
       // Save state; will pick up at next byte.
-      t = AllocThread();
-      CopyCapture(t->capture, capture);
+      t = Incref(t0);
       *tp = t;
 
       DCHECK(!ip->last());
@@ -256,12 +273,15 @@
         stk[nstk++] = AddState(id+1);
 
       if ((j=ip->cap()) < ncapture_) {
-        // Push a dummy whose only job is to restore capture[j]
+        // Push a dummy whose only job is to restore t0
         // once we finish exploring this possibility.
-        stk[nstk++] = AddState(0, capture[j], j);
+        stk[nstk++] = AddState(0, t0);
 
         // Record capture.
-        capture[j] = p;
+        t = AllocThread();
+        CopyCapture(t->capture, t0->capture);
+        t->capture[j] = p;
+        t0 = t;
       }
       a = AddState(ip->out());
       goto Loop;
@@ -269,11 +289,10 @@
     case kInstMatch:
     case kInstByteRange:
       // Save state; will pick up at next byte.
-      t = AllocThread();
-      CopyCapture(t->capture, capture);
+      t = Incref(t0);
       *tp = t;
       if (Debug)
-        fprintf(stderr, " + %d%s [%p]\n", id, FormatCapture(t->capture).c_str(), t);
+        fprintf(stderr, " + %d%s\n", id, FormatCapture(t0->capture).c_str());
 
       if (ip->last())
         break;
@@ -312,7 +331,7 @@
     if (longest_) {
       // Can skip any threads started after our current best match.
       if (matched_ && match_[0] < t->capture[0]) {
-        FreeThread(t);
+        Decref(t);
         continue;
       }
     }
@@ -328,7 +347,7 @@
 
       case kInstByteRange:
         if (ip->Matches(c))
-          AddToThreadq(nextq, ip->out(), flag, p+1, t->capture);
+          AddToThreadq(nextq, ip->out(), flag, p+1, t);
         break;
 
       case kInstAltMatch:
@@ -337,11 +356,12 @@
         // The match is ours if we want it.
         if (ip->greedy(prog_) || longest_) {
           CopyCapture(match_, t->capture);
-          FreeThread(t);
-          for (++i; i != runq->end(); ++i)
-            FreeThread(i->second);
-          runq->clear();
           matched_ = true;
+
+          Decref(t);
+          for (++i; i != runq->end(); ++i)
+            Decref(i->second);
+          runq->clear();
           if (ip->greedy(prog_))
             return ip->out1();
           return ip->out();
@@ -352,33 +372,35 @@
         if (endmatch_ && p != etext_)
           break;
 
-        t->capture[1] = p;
         if (longest_) {
           // Leftmost-longest mode: save this match only if
           // it is either farther to the left or at the same
           // point but longer than an existing match.
           if (!matched_ || t->capture[0] < match_[0] ||
-              (t->capture[0] == match_[0] && t->capture[1] > match_[1]))
+              (t->capture[0] == match_[0] && p > match_[1])) {
             CopyCapture(match_, t->capture);
+            match_[1] = p;
+            matched_ = true;
+          }
         } else {
           // Leftmost-biased mode: this match is by definition
           // better than what we've already found (see next line).
           CopyCapture(match_, t->capture);
+          match_[1] = p;
+          matched_ = true;
 
           // Cut off the threads that can only find matches
           // worse than the one we just found: don't run the
           // rest of the current Threadq.
-          FreeThread(t);
+          Decref(t);
           for (++i; i != runq->end(); ++i)
-            FreeThread(i->second);
+            Decref(i->second);
           runq->clear();
-          matched_ = true;
           return 0;
         }
-        matched_ = true;
         break;
     }
-    FreeThread(t);
+    Decref(t);
   }
   runq->clear();
   return 0;
@@ -454,7 +476,6 @@
 
   match_ = new const char*[ncapture_];
   matched_ = false;
-  memset(match_, 0, ncapture_*sizeof match_[0]);
 
   // For debugging prints.
   btext_ = context.begin();
@@ -580,8 +601,11 @@
         flag = Prog::EmptyFlags(context, p);
       }
 
-      match_[0] = p;
-      AddToThreadq(runq, start_, flag, p, match_);
+      Thread* t = AllocThread();
+      CopyCapture(t->capture, match_);
+      t->capture[0] = p;
+      AddToThreadq(runq, start_, flag, p, t);
+      Decref(t);
     }
 
     // If all the threads have died, stop early.
@@ -601,7 +625,7 @@
   }
 
   for (Threadq::iterator i = runq->begin(); i != runq->end(); ++i)
-    FreeThread(i->second);
+    Decref(i->second);
 
   if (matched_) {
     for (int i = 0; i < nsubmatch; i++)