[lib/inferior_control] Rewrite RefreshAllThreads to fetch koid list first

... and make EnsureThreadMapFresh return void: Any failures should be
internal bugs, which are checked internally.

FLK-94 #comment patch

Change-Id: I2ae83a73ae3c257b85dde81a50570f72668690cc
diff --git a/garnet/lib/inferior_control/process.cc b/garnet/lib/inferior_control/process.cc
index 3258ab3..d4c2551d 100644
--- a/garnet/lib/inferior_control/process.cc
+++ b/garnet/lib/inferior_control/process.cc
@@ -16,6 +16,7 @@
 
 #include "garnet/lib/debugger_utils/breakpoints.h"
 #include "garnet/lib/debugger_utils/jobs.h"
+#include "garnet/lib/debugger_utils/processes.h"
 #include "garnet/lib/debugger_utils/util.h"
 
 #include "process.h"
@@ -553,11 +554,10 @@
   }
 }
 
-bool Process::EnsureThreadMapFresh() {
+void Process::EnsureThreadMapFresh() {
   if (thread_map_stale_) {
-    return RefreshAllThreads();
+    RefreshAllThreads();
   }
-  return true;
 }
 
 Thread* Process::FindThreadById(zx_koid_t thread_id) {
@@ -574,8 +574,7 @@
   }
 
   FXL_DCHECK(handle_);
-  bool fresh = EnsureThreadMapFresh();
-  FXL_DCHECK(fresh);
+  EnsureThreadMapFresh();
 
   const auto iter = threads_.find(thread_id);
   if (iter != threads_.end()) {
@@ -605,8 +604,7 @@
 }
 
 Thread* Process::PickOneThread() {
-  bool fresh = EnsureThreadMapFresh();
-  FXL_DCHECK(fresh);
+  EnsureThreadMapFresh();
 
   if (threads_.empty())
     return nullptr;
@@ -614,68 +612,53 @@
   return threads_.begin()->second.get();
 }
 
-bool Process::RefreshAllThreads() {
+void Process::RefreshAllThreads() {
   FXL_DCHECK(handle_);
 
-  // First get the thread count so that we can allocate an appropriately sized
-  // buffer. This is racy but unless the caller stops all threads that's just
-  // the way things are.
-  size_t num_threads;
-  zx_status_t status = zx_object_get_info(handle_, ZX_INFO_PROCESS_THREADS,
-                                          nullptr, 0, nullptr, &num_threads);
-  if (status != ZX_OK) {
-    FXL_LOG(ERROR) << "Failed to get process thread info (#threads): "
-                   << debugger_utils::ZxErrorString(status);
-    return false;
-  }
+  std::vector<zx_koid_t> threads;
+  size_t num_available_threads;
 
-  auto buffer_size = num_threads * sizeof(zx_koid_t);
-  auto koids = std::make_unique<zx_koid_t[]>(num_threads);
-  size_t records_read;
-  status = zx_object_get_info(handle_, ZX_INFO_PROCESS_THREADS, koids.get(),
-                              buffer_size, &records_read, nullptr);
-  if (status != ZX_OK) {
-    FXL_LOG(ERROR) << "Failed to get process thread info: "
-                   << debugger_utils::ZxErrorString(status);
-    return false;
-  }
+  __UNUSED zx_status_t status =
+      debugger_utils::GetProcessThreadKoids(handle_, kRefreshThreadsTryCount,
+          kNumExtraRefreshThreads, &threads, &num_available_threads);
+  // The only way this can fail is if we have a bug (or the kernel runs out
+  // of memory, but we don't try to cope with that case).
+  // TODO(dje): Verify the handle we are given has sufficient rights.
+  FXL_DCHECK(status == ZX_OK);
 
-  FXL_DCHECK(records_read == num_threads);
+  // The heuristic we use to collect all threads is sufficient that this
+  // will never fail in practice. If it does we need to adjust it.
+  FXL_DCHECK(threads.size() == num_available_threads);
 
-  ThreadMap new_threads;
-  for (size_t i = 0; i < num_threads; ++i) {
-    zx_koid_t thread_id = koids[i];
-    if (threads_.find(thread_id) != threads_.end()) {
+  for (auto tid : threads) {
+    if (threads_.find(tid) != threads_.end()) {
       // We already have this thread.
       continue;
     }
     zx_handle_t thread_handle = ZX_HANDLE_INVALID;
-    status = zx_object_get_child(handle_, thread_id, ZX_RIGHT_SAME_RIGHTS,
+    status = zx_object_get_child(handle_, tid, ZX_RIGHT_SAME_RIGHTS,
                                  &thread_handle);
-    if (status != ZX_OK) {
-      FXL_LOG(ERROR) << "Could not obtain a debug handle to thread: "
-                     << debugger_utils::ZxErrorString(status);
+    // The only way this can otherwise fail is if we have a bug.
+    FXL_DCHECK(status == ZX_OK || status == ZX_ERR_NOT_FOUND);
+    if (status == ZX_ERR_NOT_FOUND) {
+      // Thread died in the interim.
       continue;
     }
-    __UNUSED Thread* thread = AddThread(thread_handle, thread_id);
+    __UNUSED Thread* thread = AddThread(thread_handle, tid);
   }
 
   thread_map_stale_ = false;
-
-  return true;
 }
 
 void Process::ForEachThread(const ThreadCallback& callback) {
-  bool fresh = EnsureThreadMapFresh();
-  FXL_DCHECK(fresh);
+  EnsureThreadMapFresh();
 
   for (const auto& iter : threads_)
     callback(iter.second.get());
 }
 
 void Process::ForEachLiveThread(const ThreadCallback& callback) {
-  bool fresh = EnsureThreadMapFresh();
-  FXL_DCHECK(fresh);
+  EnsureThreadMapFresh();
 
   for (const auto& iter : threads_) {
     Thread* thread = iter.second.get();
@@ -883,7 +866,7 @@
 }
 
 void Process::Dump() {
-  FXL_CHECK(EnsureThreadMapFresh());
+  EnsureThreadMapFresh();
   FXL_LOG(INFO) << "Dump of threads for process " << id_;
 
   ForEachLiveThread([](Thread* thread) {
diff --git a/garnet/lib/inferior_control/process.h b/garnet/lib/inferior_control/process.h
index 18b8af7..6af4b7b 100644
--- a/garnet/lib/inferior_control/process.h
+++ b/garnet/lib/inferior_control/process.h
@@ -196,13 +196,8 @@
   Thread* PickOneThread();
 
   // If the thread map might be stale, refresh it.
-  // Returns true on success.
-  bool EnsureThreadMapFresh();
-
-  // Refreshes the complete Thread list for this process. Returns false if an
-  // error is returned from a syscall.
-  // Pointers to existing threads are maintained.
-  bool RefreshAllThreads();
+  // This may not be called while detached.
+  void EnsureThreadMapFresh();
 
   // Iterates through all cached threads and invokes |callback| for each of
   // them. |callback| is guaranteed to get called only before ForEachThread()
@@ -269,6 +264,22 @@
   zx_vaddr_t ldso_debug_map_addr() const { return ldso_debug_map_addr_; }
 
  private:
+  // When refreshing the thread list, new threads could be created.
+  // Add this to the number of existing threads to account for new ones.
+  // The number is large but the cost is only 8 bytes per extra thread for
+  // the thread's koid.
+  static constexpr size_t kNumExtraRefreshThreads = 20;
+
+  // When refreshing the thread list, if threads are being created faster than
+  // we can keep up, keep looking, but don't keep trying forever.
+  static constexpr size_t kRefreshThreadsTryCount = 4;
+
+  // Refreshes the complete Thread list for this process. Returns false if an
+  // error is returned from a syscall. Any threads that were accumulated up to
+  // that point are retained.
+  // Pointers to existing threads are maintained.
+  void RefreshAllThreads();
+
   // Wrapper on |zx_object_get_property()| to fetch the value of
   // ZX_PROP_PROCESS_DEBUG_ADDR.
   // Returns a boolean indicating success.
@@ -395,13 +406,13 @@
   // The collection of breakpoints that belong to this process.
   ProcessBreakpointSet breakpoints_;
 
-  // The threads owned by this process. This is map is populated lazily when
+  // The threads owned by this process. This map is populated lazily when
   // threads are requested through FindThreadById(). It can also be repopulated
   // from scratch, e.g., when attaching to an already running program.
   using ThreadMap = std::unordered_map<zx_koid_t, std::unique_ptr<Thread>>;
   ThreadMap threads_;
 
-  // If true then |threads_| needs to be recalculated from scratch.
+  // If true then |threads_| needs to be recalculated.
   bool thread_map_stale_ = false;
 
   // List of dsos loaded.
diff --git a/garnet/lib/inferior_control/process_unittest.cc b/garnet/lib/inferior_control/process_unittest.cc
index 143a901..8f08650 100644
--- a/garnet/lib/inferior_control/process_unittest.cc
+++ b/garnet/lib/inferior_control/process_unittest.cc
@@ -3,6 +3,7 @@
 // found in the LICENSE file.
 
 #include <lib/async/cpp/task.h>
+#include <lib/fxl/strings/string_printf.h>
 #include <lib/zx/channel.h>
 #include <string.h>
 
@@ -68,11 +69,16 @@
                                     &pending),
                 ZX_OK);
     }
+
     EXPECT_TRUE(inferior->Detach());
+    EXPECT_FALSE(inferior->IsAttached());
+    EXPECT_EQ(inferior->handle(), ZX_HANDLE_INVALID);
+
     // Sleep a little to hopefully give the inferior a chance to run.
     // We want it to trip over the ld.so breakpoint if we forgot to remove it.
     zx::nanosleep(zx::deadline_after(zx::msec(10)));
     EXPECT_TRUE(inferior->Attach(pid));
+    EXPECT_TRUE(inferior->IsAttached());
     // If attaching failed we'll hang since we won't see the inferior exiting.
     if (!inferior->IsAttached()) {
       QuitMessageLoop(true);
@@ -158,7 +164,8 @@
 
   void OnArchitecturalException(
       Process* process, Thread* thread, zx_handle_t eport,
-      const zx_excp_type_t type, const zx_exception_context_t& context) {
+      const zx_excp_type_t type, const zx_exception_context_t& context)
+      override {
     FXL_LOG(INFO) << "Got exception 0x" << std::hex << type;
     if (type == ZX_EXCP_SW_BREAKPOINT) {
       // The shared libraries should have been loaded by now.
@@ -190,6 +197,8 @@
       // Terminate the inferior, we don't want the exception propagating to
       // the system exception handler.
       zx_task_kill(process->handle());
+    } else {
+      EXPECT_TRUE(thread->TryNext(eport));
     }
   }
 
@@ -258,5 +267,91 @@
   EXPECT_TRUE(kill_requested());
 }
 
+// Test |RefreshThreads()| when a new thread is created while we're collecting
+// the list of threads. This is done by detaching and re-attaching with
+// successive number of new threads, and each time telling |RefreshThreads()|
+// there's only one thread.
+
+class RefreshTest : public TestServer {
+ public:
+  static constexpr size_t kNumIterations = 4;
+
+  RefreshTest() = default;
+
+  void OnThreadStarting(Process* process, Thread* thread, zx_handle_t eport,
+                        const zx_exception_context_t& context) override {
+    ++num_threads_;
+    FXL_LOG(INFO) << "Thread " << thread->id() << " starting, #threads: " << num_threads_;
+    // If this is the main thread then we don't want to do the test yet,
+    // we need to first proceed past the ld.so breakpoint.
+    // We can't currently catch the ld.so breakpoint, so just count started
+    // threads.
+    if (num_threads_ >= 2) {
+      PostQuitMessageLoop(true);
+    }
+
+    // Pass on to baseclass method to resume the thread.
+    TestServer::OnThreadStarting(process, thread, eport, context);
+  }
+
+ private:
+  int num_threads_ = 0;
+};
+
+TEST_F(RefreshTest, RefreshWithNewThreads) {
+  std::vector<std::string> argv{
+      kTestHelperPath,
+      "start-n-threads",
+      fxl::StringPrintf("%zu", kNumIterations),
+  };
+  ASSERT_TRUE(SetupInferior(argv));
+
+  zx::channel our_channel, their_channel;
+  auto status = zx::channel::create(0, &our_channel, &their_channel);
+  ASSERT_EQ(status, ZX_OK);
+
+  EXPECT_TRUE(RunHelperProgram(std::move(their_channel)));
+
+  Process* inferior = current_process();
+  zx_koid_t pid = inferior->id();
+
+  // This can't test new threads appearing while we're building the list,
+  // that is tested by the unittest for |GetProcessThreadKoids()|. But we
+  // can exercise |RefreshThreads()|.
+
+  for (size_t i = 0; i < kNumIterations; ++i) {
+    FXL_VLOG(1) << "Iteration " << i + 1;
+
+    // This won't return until the new thread is running.
+    EXPECT_TRUE(Run());
+
+    // Detaching and re-attaching will cause us to discard the previously
+    // collected set of threads.
+    EXPECT_TRUE(inferior->Detach());
+    EXPECT_TRUE(inferior->Attach(pid));
+
+    inferior->EnsureThreadMapFresh();
+    // There should be the main thread plus one new thread each iteration.
+    size_t count = 0;
+    inferior->ForEachThread([&count](Thread*) { ++count; });
+    EXPECT_EQ(count, i + 2);
+
+    // Reset the quit indicator for the next iteration. Do this before
+    // we allow the inferior to advance and create a new thread.
+    EXPECT_EQ(message_loop().ResetQuit(), ZX_OK);
+
+    // Send the inferior a packet so that it will continue with the next
+    // iteration.
+    FXL_VLOG(1) << "Advancing to next iteration";
+    uint64_t packet = kUint64MagicPacketValue;
+    EXPECT_EQ(our_channel.write(0, &packet, sizeof(packet), nullptr, 0), ZX_OK);
+  }
+
+  // Run the loop one more time to catch the inferior exiting.
+  EXPECT_TRUE(Run());
+
+  EXPECT_TRUE(TestSuccessfulExit());
+}
+
 }  // namespace
 }  // namespace inferior_control
diff --git a/garnet/lib/inferior_control/test_helper.cc b/garnet/lib/inferior_control/test_helper.cc
index a02e9f7..4f83e7fb 100644
--- a/garnet/lib/inferior_control/test_helper.cc
+++ b/garnet/lib/inferior_control/test_helper.cc
@@ -3,12 +3,14 @@
 // found in the LICENSE file.
 
 #include <cstdio>
+#include <cstdlib>
 #include <cstring>
 #include <thread>
 
 #include <lib/fxl/command_line.h>
 #include <lib/fxl/log_settings.h>
 #include <lib/fxl/log_settings_command_line.h>
+#include <lib/fxl/strings/string_number_conversions.h>
 #include <lib/zx/channel.h>
 #include <lib/zx/event.h>
 #include <lib/zx/port.h>
@@ -20,6 +22,8 @@
 #include "garnet/lib/debugger_utils/breakpoints.h"
 #include "garnet/lib/debugger_utils/util.h"
 
+#include "test_helper.h"
+
 using debugger_utils::ZxErrorString;
 
 static void ExceptionHandlerThreadFunc(
@@ -117,6 +121,58 @@
   return 0;
 }
 
+void WaitChannelReadable(const zx::channel& channel) {
+  zx_signals_t pending;
+  FXL_CHECK(channel.wait_one(ZX_CHANNEL_READABLE, zx::time::infinite(),
+                             &pending) == ZX_OK);
+}
+
+static void ReadUint64Packet(const zx::channel& channel,
+                             uint64_t expected_value) {
+  uint64_t packet;
+  uint32_t packet_size;
+  FXL_CHECK(channel.read(0, &packet, sizeof(packet), &packet_size,
+                         nullptr, 0, nullptr) == ZX_OK);
+  FXL_CHECK(packet_size == sizeof(packet));
+  FXL_CHECK(packet == expected_value);
+}
+
+static void StartNThreadsThreadFunc(zx_handle_t done_event) {
+  zx_signals_t pending;
+  zx_status_t status = zx_object_wait_one(done_event, ZX_EVENT_SIGNALED,
+                                          ZX_TIME_INFINITE, &pending);
+  FXL_CHECK(status == ZX_ERR_CANCELED);
+}
+
+static int StartNThreads(const zx::channel& channel, int num_iterations) {
+  std::vector<std::thread> threads;
+
+  // When this is closed the threads will exit.
+  zx::event done_event;
+  FXL_CHECK(zx::event::create(0, &done_event) == ZX_OK);
+
+  // What we want to do here is start a new thread and then wait for the
+  // test to do its thing, and repeat.
+
+  for (int i = 0; i < num_iterations; ++i) {
+    FXL_VLOG(1) << "StartNThreads iteration " << i + 1;
+
+    threads.emplace_back(
+        std::thread{StartNThreadsThreadFunc, done_event.get()});
+
+    WaitChannelReadable(channel);
+    ReadUint64Packet(channel, inferior_control::kUint64MagicPacketValue);
+  }
+
+  // Terminate the threads;
+  done_event.reset();
+  for (auto& thread : threads) {
+    thread.join();
+  }
+
+  return 0;
+}
+
 int main(int argc, char* argv[]) {
   auto cl = fxl::CommandLineFromArgcArgv(argc, argv);
   if (!fxl::SetLogSettingsFromCommandLine(cl)) {
@@ -142,6 +198,19 @@
     if (cmd == "trigger-sw-bkpt-with-handler") {
       return TriggerSoftwareBreakpoint(channel, true);
     }
+    if (cmd == "start-n-threads") {
+      if (args.size() < 2) {
+        FXL_LOG(ERROR) << "Missing iteration count";
+        return 1;
+      }
+      int num_threads = 0;
+      if (!fxl::StringToNumberWithError(args[1], &num_threads) ||
+          num_threads < 1) {
+        FXL_LOG(ERROR) << "Error parsing number of threads";
+        return 1;
+      }
+      return StartNThreads(channel, num_threads);
+    }
     FXL_LOG(ERROR) << "Unrecognized command: " << cmd;
     return 1;
   }
diff --git a/garnet/lib/inferior_control/test_helper.h b/garnet/lib/inferior_control/test_helper.h
index 1a7b5e3..a03c5c8 100644
--- a/garnet/lib/inferior_control/test_helper.h
+++ b/garnet/lib/inferior_control/test_helper.h
@@ -15,6 +15,9 @@
 // A string that appears in the Dso name of the test helper executable.
 constexpr const char kTestHelperDsoName[] = "test_helper";
 
+// A special value to pass between processes as a sanity check.
+constexpr uint64_t kUint64MagicPacketValue = 0x0123456789abcdeful;
+
 }  // namespace inferior_control
 
 #endif // GARNET_LIB_INFERIOR_CONTROL_TEST_HELPER_H_