[fidl][llcpp] Clean up ResponseContexts.

See go/fidl-llcpp-mt-client.

This CL adds the ability to clean up any outstanding ResponseContexts
when the ClientBase is destroyed or if an error occurs on a response
prior to the reply callback being invoked.

Bug: 7685
Test: runtests -t fidl-async-test-test -r 100
Change-Id: I2b2ba970245fb511eb894a155ecb98d0c8aaee00
Reviewed-on: https://fuchsia-review.googlesource.com/c/fuchsia/+/376879
Commit-Queue: Madhav Iyengar <madhaviyengar@google.com>
Reviewed-by: Andres Oportus <andresoportus@google.com>
Reviewed-by: Yifei Teng <yifeit@google.com>
Testability-Review: Andres Oportus <andresoportus@google.com>
Testability-Review: Yifei Teng <yifeit@google.com>
diff --git a/zircon/system/ulib/fidl-async/include/lib/fidl-async/cpp/async_bind_internal.h b/zircon/system/ulib/fidl-async/include/lib/fidl-async/cpp/async_bind_internal.h
index aa2a5d9..0bc7a66 100644
--- a/zircon/system/ulib/fidl-async/include/lib/fidl-async/cpp/async_bind_internal.h
+++ b/zircon/system/ulib/fidl-async/include/lib/fidl-async/cpp/async_bind_internal.h
@@ -54,6 +54,8 @@
     UnbindInternal(std::move(calling_ref), is_server_ ? &epitaph : nullptr);
   }
 
+  zx::unowned_channel channel() const { return zx::unowned_channel(channel_); }
+
  protected:
   AsyncBinding(async_dispatcher_t* dispatcher, zx::channel channel, void* impl, bool is_server,
                TypeErasedOnUnboundFn on_unbound_fn, DispatchFn dispatch_fn);
@@ -86,8 +88,6 @@
     delete unbound_task;
   }
 
-  zx::unowned_channel channel() const { return zx::unowned_channel(channel_); }
-
   void MessageHandler(zx_status_t status, const zx_packet_signal_t* signal) __TA_EXCLUDES(lock_);
 
   // Used by both Close() and Unbind().
diff --git a/zircon/system/ulib/fidl-async/include/lib/fidl-async/cpp/client_base.h b/zircon/system/ulib/fidl-async/include/lib/fidl-async/cpp/client_base.h
index 5e2ea9e..bafae70 100644
--- a/zircon/system/ulib/fidl-async/include/lib/fidl-async/cpp/client_base.h
+++ b/zircon/system/ulib/fidl-async/include/lib/fidl-async/cpp/client_base.h
@@ -18,12 +18,26 @@
 namespace fidl {
 namespace internal {
 
-// Struct used to track an outstanding asynchronous transaction.
-struct ResponseContext {
-  // Intrusive list node for tracking this ResponseContext.
-  // TODO(madhaviyengar): Replace this once an intrusive tree/map is added to the SDK.
-  list_node_t node = LIST_INITIAL_CLEARED_VALUE;
-  zx_txid_t txid = 0;  // txid of outstanding transaction.
+// ResponseContext contains the state for an outstanding asynchronous transaction. It inherits from
+// an intrusive container node so that ClientBase can track it.
+// TODO(madhaviyengar): Replace list_node_t once an intrusive tree/map is added to the SDK.
+class ResponseContext : private list_node_t {
+ public:
+  ResponseContext() : list_node_t(LIST_INITIAL_CLEARED_VALUE) {}
+  virtual ~ResponseContext() = default;
+
+  zx_txid_t Txid() const { return txid_; }
+
+  // Invoked if an error occurs handling the response message prior to invoking the user-specified
+  // callback or if the ClientBase is destroyed with the transaction outstanding. Note that
+  // OnError() may be invoked within ~ClientBase(), so the user must ensure that a ClientBase is not
+  // destroyed while holding any locks OnError() would take.
+  virtual void OnError() = 0;
+
+ private:
+  friend class ClientBase;
+
+  zx_txid_t txid_ = 0;  // Zircon txid of outstanding transaction.
 };
 
 // Base LLCPP client class supporting use with a multithreaded asynchronous dispatcher, safe error
@@ -56,11 +70,11 @@
 
   // Returns a strong reference to binding to prevent channel deletion during a zx_channel_call() or
   // zx_channel_write(). The caller is responsible for releasing the reference.
-  std::shared_ptr<AsyncBinding> GetBinding();
+  std::shared_ptr<AsyncBinding> GetBinding() { return binding_.lock(); }
 
   // Invoked by InternalDispatch() below. If `context` is non-null, the message was the response to
   // an asynchronous transaction. Otherwise, the message was an event.
-  virtual void Dispatch(fidl_msg_t* msg, ResponseContext* context) = 0;
+  virtual zx_status_t Dispatch(fidl_msg_t* msg, ResponseContext* context) = 0;
 
   // Dispatch function invoked by AsyncBinding on incoming message. The client only requires `msg`.
   void InternalDispatch(std::shared_ptr<AsyncBinding>&, fidl_msg_t* msg, bool*,
@@ -73,7 +87,8 @@
   std::mutex lock_;
   // The base node of an intrusive container of ResponseContexts corresponding to outstanding
   // asynchronous transactions.
-  ResponseContext __TA_GUARDED(lock_) contexts_ = {};
+  list_node_t contexts_ __TA_GUARDED(lock_) = LIST_INITIAL_VALUE(contexts_);
+  zx_txid_t txid_base_ __TA_GUARDED(lock_) = 0;  // Value used to compute the next txid.
 };
 
 }  // namespace internal
diff --git a/zircon/system/ulib/fidl-async/llcpp_client_base.cc b/zircon/system/ulib/fidl-async/llcpp_client_base.cc
index 26bead2..482e0b4a 100644
--- a/zircon/system/ulib/fidl-async/llcpp_client_base.cc
+++ b/zircon/system/ulib/fidl-async/llcpp_client_base.cc
@@ -16,6 +16,18 @@
 
 ClientBase::~ClientBase() {
   Unbind();
+  // Release any managed ResponseContexts.
+  list_node_t delete_list = LIST_INITIAL_CLEARED_VALUE;
+  {
+    std::scoped_lock lock(lock_);
+    list_move(&contexts_, &delete_list);
+  }
+  list_node_t* node = nullptr;
+  list_node_t* temp_node = nullptr;
+  list_for_every_safe(&delete_list, node, temp_node) {
+    list_delete(node);
+    static_cast<ResponseContext*>(node)->OnError();
+  }
 }
 
 void ClientBase::Unbind() {
@@ -25,7 +37,6 @@
 
 ClientBase::ClientBase(zx::channel channel, async_dispatcher_t* dispatcher,
                        TypeErasedOnUnboundFn on_unbound) {
-  list_initialize(&contexts_.node);
   binding_ = AsyncBinding::CreateClientBinding(
       dispatcher, std::move(channel), this, fit::bind_member(this, &ClientBase::InternalDispatch),
       std::move(on_unbound));
@@ -45,11 +56,11 @@
   do {
     found = false;
     do {
-      context->txid = ++contexts_.txid & kUserspaceTxidMask;  // txid must be within mask.
-    } while (!context->txid);  // txid must be non-zero.
-    ResponseContext* entry = nullptr;
-    list_for_every_entry(&contexts_.node, entry, ResponseContext, node) {
-      if (entry->txid == context->txid) {
+      context->txid_ = ++txid_base_ & kUserspaceTxidMask;  // txid must be within mask.
+    } while (!context->txid_);  // txid must be non-zero.
+    list_node_t* node = nullptr;
+    list_for_every(&contexts_, node) {
+      if (static_cast<ResponseContext*>(node)->txid_ == context->txid_) {
         found = true;
         break;
       }
@@ -57,17 +68,14 @@
   } while (found);
 
   // Insert the ResponseContext.
-  list_add_tail(&contexts_.node, &context->node);
+  list_add_tail(&contexts_, static_cast<list_node_t*>(context));
 }
 
 void ClientBase::ForgetAsyncTxn(ResponseContext* context) {
+  auto* node = static_cast<list_node_t*>(context);
   std::scoped_lock lock(lock_);
-  ZX_ASSERT(list_in_list(&context->node));
-  list_delete(&context->node);
-}
-
-std::shared_ptr<AsyncBinding> ClientBase::GetBinding() {
-  return binding_.lock();
+  ZX_ASSERT(list_in_list(node));
+  list_delete(node);
 }
 
 void ClientBase::InternalDispatch(std::shared_ptr<AsyncBinding>&, fidl_msg_t* msg, bool*,
@@ -84,11 +92,12 @@
   if (hdr->txid) {
     {
       std::scoped_lock lock(lock_);
-      ResponseContext* entry = nullptr;
-      list_for_every_entry(&contexts_.node, entry, ResponseContext, node) {
-        if (entry->txid == hdr->txid) {
+      list_node_t* node = nullptr;
+      list_for_every(&contexts_, node) {
+        auto* entry = static_cast<ResponseContext*>(node);
+        if (entry->txid_ == hdr->txid) {
           context = entry;
-          list_delete(&entry->node);  // This is safe since we break immediately after.
+          list_delete(node);  // This is safe since we break immediately after.
           break;
         }
       }
@@ -103,7 +112,7 @@
   }
 
   // Dispatch the message
-  Dispatch(msg, context);
+  *status = Dispatch(msg, context);
 }
 
 }  // namespace internal
diff --git a/zircon/system/ulib/fidl-async/test/llcpp_client_base_test.cc b/zircon/system/ulib/fidl-async/test/llcpp_client_base_test.cc
index 0d17950..0a70c2e 100644
--- a/zircon/system/ulib/fidl-async/test/llcpp_client_base_test.cc
+++ b/zircon/system/ulib/fidl-async/test/llcpp_client_base_test.cc
@@ -34,14 +34,14 @@
   void PrepareAsyncTxn(ResponseContext* context) {
     ClientBase::PrepareAsyncTxn(context);
     std::unique_lock lock(lock_);
-    EXPECT_FALSE(txids_.count(context->txid));
-    txids_.insert(context->txid);
+    EXPECT_FALSE(txids_.count(context->Txid()));
+    txids_.insert(context->Txid());
   }
 
   void ForgetAsyncTxn(ResponseContext* context) {
     {
       std::unique_lock lock(lock_);
-      txids_.erase(context->txid);
+      txids_.erase(context->Txid());
     }
     ClientBase::ForgetAsyncTxn(context);
   }
@@ -52,17 +52,23 @@
 
   // For responses, find and remove the entry for the matching txid. For events, increment the
   // event count.
-  void Dispatch(fidl_msg_t* msg, ResponseContext* context) override {
+  zx_status_t Dispatch(fidl_msg_t* msg, ResponseContext* context) override {
     auto* hdr = reinterpret_cast<fidl_message_header_t*>(msg->bytes);
-    ASSERT_EQ(!hdr->txid, !context);  // hdr->txid == 0 iff context == nullptr.
+    EXPECT_EQ(!hdr->txid, !context);  // hdr->txid == 0 iff context == nullptr.
+    if (!hdr->txid != !context) {
+      return ZX_OK;  // This is a failure, but let the test continue.
+    }
     std::unique_lock lock(lock_);
     if (hdr->txid) {
       auto txid_it = txids_.find(hdr->txid);
-      ASSERT_TRUE(txid_it != txids_.end());  // the transaction must be found.
-      txids_.erase(txid_it);
+      EXPECT_TRUE(txid_it != txids_.end());  // the transaction must be found.
+      if (txid_it != txids_.end()) {
+        txids_.erase(txid_it);
+      }
     } else {
       event_count_++;
     }
+    return ZX_OK;
   }
 
   uint32_t GetEventCount() {
@@ -76,7 +82,7 @@
   }
 
   size_t GetTxidCount() {
-    auto internal_count = list_length(&contexts_.node);
+    auto internal_count = list_length(&contexts_);
     std::unique_lock lock(lock_);
     EXPECT_EQ(txids_.size(), internal_count);
     return internal_count;
@@ -87,6 +93,12 @@
   uint32_t event_count_ = 0;
 };
 
+class TestResponseContext : public ResponseContext {
+ public:
+  TestResponseContext() = default;
+  void OnError() {}
+};
+
 TEST(ClientBaseTestCase, AsyncTxn) {
   async::Loop loop(&kAsyncLoopConfigNoAttachToCurrentThread);
   ASSERT_OK(loop.StartThread());
@@ -109,11 +121,11 @@
 
   // Generate a txid for a ResponseContext. Send a "response" message with the same txid from the
   // remote end of the channel.
-  ResponseContext context;
+  TestResponseContext context;
   client->PrepareAsyncTxn(&context);
-  EXPECT_TRUE(client->IsPending(context.txid));
+  EXPECT_TRUE(client->IsPending(context.Txid()));
   fidl_message_header_t hdr;
-  fidl_init_txn_header(&hdr, context.txid, 0);
+  fidl_init_txn_header(&hdr, context.Txid(), 0);
   ASSERT_OK(remote.write(0, &hdr, sizeof(fidl_message_header_t), nullptr, 0));
 
   // Trigger unbound handler.
@@ -143,14 +155,14 @@
 
   // In parallel, simulate 10 async transactions and send "response" messages from the remote end of
   // the channel.
-  ResponseContext contexts[10];
+  TestResponseContext contexts[10];
   std::thread threads[10];
   for (int i = 0; i < 10; ++i) {
     threads[i] = std::thread([context = &contexts[i], &remote, client]{
       client->PrepareAsyncTxn(context);
-      EXPECT_TRUE(client->IsPending(context->txid));
+      EXPECT_TRUE(client->IsPending(context->Txid()));
       fidl_message_header_t hdr;
-      fidl_init_txn_header(&hdr, context->txid, 0);
+      fidl_init_txn_header(&hdr, context->Txid(), 0);
       ASSERT_OK(remote.write(0, &hdr, sizeof(fidl_message_header_t), nullptr, 0));
     });
   }
@@ -172,9 +184,9 @@
   TestClient client(std::move(local), loop.dispatcher(), nullptr);
 
   // Generate a txid for a ResponseContext.
-  ResponseContext context;
+  TestResponseContext context;
   client.PrepareAsyncTxn(&context);
-  EXPECT_TRUE(client.IsPending(context.txid));
+  EXPECT_TRUE(client.IsPending(context.Txid()));
 
   // Forget the transaction.
   client.ForgetAsyncTxn(&context);
@@ -320,6 +332,34 @@
   EXPECT_OK(sync_completion_wait(&unbound, ZX_TIME_INFINITE));
 }
 
+TEST(ClientBaseTestCase, ReleaseOutstandingTxnsOnDestroy) {
+  class ReleaseTestResponseContext : public ResponseContext {
+   public:
+    ReleaseTestResponseContext(sync_completion_t* done) : done_(done) {}
+    void OnError() {
+      sync_completion_signal(done_);
+      delete this;
+    }
+    sync_completion_t* done_;
+  };
+
+  async::Loop loop(&kAsyncLoopConfigNoAttachToCurrentThread);
+  ASSERT_OK(loop.StartThread());
+
+  zx::channel local, remote;
+  ASSERT_OK(zx::channel::create(0, &local, &remote));
+
+  auto* client = new TestClient(std::move(local), loop.dispatcher(), nullptr);
+
+  // Create and register a response context which will signal when deleted.
+  sync_completion_t done;
+  client->PrepareAsyncTxn(new ReleaseTestResponseContext(&done));
+
+  // Delete the client and ensure that the response context is deleted.
+  delete client;
+  EXPECT_OK(sync_completion_wait(&done, ZX_TIME_INFINITE));
+}
+
 }  // namespace
 }  // namespace internal
 }  // namespace fidl