[netstack3] Remove #[async_trait] workaround

Use the newly-approved async_fn_in_trait unstable feature in Netstack3
instead of the #[async_trait] macro workaround. This removes an instance
of heap allocation and dynamic dispatch on the socket control (and data,
for UDP/ICMP) path.

Bug: 122546
Change-Id: I8835873be858f1a60ca7c1f4de55d9d4fd16d305
Reviewed-on: https://fuchsia-review.googlesource.com/c/fuchsia/+/810109
Fuchsia-Auto-Submit: Alex Konradi <akonradi@google.com>
Reviewed-by: Jeff Martin <martinjeffrey@google.com>
Commit-Queue: Auto-Submit <auto-submit@fuchsia-infra.iam.gserviceaccount.com>
Reviewed-by: Ghanan Gowripalan <ghanan@google.com>
diff --git a/src/connectivity/network/netstack3/BUILD.gn b/src/connectivity/network/netstack3/BUILD.gn
index c1a882a..82ebfeb 100644
--- a/src/connectivity/network/netstack3/BUILD.gn
+++ b/src/connectivity/network/netstack3/BUILD.gn
@@ -45,10 +45,6 @@
   "//third_party/rust_crates:thiserror",
   "//third_party/rust_crates:tracing",
   "//zircon/system/ulib/backtrace-request/rust",
-
-  # TODO(https://fxbug.dev/122464): Replace this with
-  # #![feature(async_fn_in_trait)] when available.
-  "//third_party/rust_crates:async-trait",
 ]
 
 common_test_deps = [
@@ -98,6 +94,7 @@
   test_deps = common_test_deps
 
   sources = common_sources
+  configs += [ "//build/config/rust:async_fn_in_trait" ]
   configs -= [ "//build/config/rust/lints:allow_unused_results" ]
 }
 
@@ -112,6 +109,7 @@
   test_deps = common_test_deps
 
   sources = common_sources
+  configs += [ "//build/config/rust:async_fn_in_trait" ]
   configs -= [ "//build/config/rust/lints:allow_unused_results" ]
 }
 
@@ -170,6 +168,7 @@
     "core:netstack3-core-benchmarks",
     "//src/developer/fuchsia-criterion",
   ]
+  configs += [ "//build/config/rust:async_fn_in_trait" ]
   configs -= [ "//build/config/rust/lints:allow_unused_results" ]
 }
 
diff --git a/src/connectivity/network/netstack3/src/bindings/socket/datagram.rs b/src/connectivity/network/netstack3/src/bindings/socket/datagram.rs
index 2400fc7..c2bddca 100644
--- a/src/connectivity/network/netstack3/src/bindings/socket/datagram.rs
+++ b/src/connectivity/network/netstack3/src/bindings/socket/datagram.rs
@@ -19,12 +19,10 @@
 use fidl_fuchsia_posix_socket as fposix_socket;
 
 use assert_matches::assert_matches;
-// TODO(https://fxbug.dev/122464): Use #![feature(async_fn_in_trait)] when
-// available.
-use async_trait::async_trait;
 use explicit::ResultExt as _;
 use fidl::endpoints::RequestStream as _;
 use fidl_fuchsia_unknown::CloseableCloseResult;
+use fuchsia_async as fasync;
 use fuchsia_zircon::{self as zx, prelude::HandleBased as _, Peered as _};
 use log::{error, trace, warn};
 use net_types::{
@@ -1434,18 +1432,27 @@
 ) -> Result<(), fposix::Errno> {
     match (domain, proto) {
         (fposix_socket::Domain::Ipv4, fposix_socket::DatagramSocketProtocol::Udp) => {
-            SocketWorker::<BindingData<Ipv4, Udp>>::spawn(ctx, properties, events)
+            fasync::Task::spawn(SocketWorker::<BindingData<Ipv4, Udp>>::serve_stream(
+                ctx, properties, events,
+            ))
         }
         (fposix_socket::Domain::Ipv6, fposix_socket::DatagramSocketProtocol::Udp) => {
-            SocketWorker::<BindingData<Ipv6, Udp>>::spawn(ctx, properties, events)
+            fasync::Task::spawn(SocketWorker::<BindingData<Ipv6, Udp>>::serve_stream(
+                ctx, properties, events,
+            ))
         }
         (fposix_socket::Domain::Ipv4, fposix_socket::DatagramSocketProtocol::IcmpEcho) => {
-            SocketWorker::<BindingData<Ipv4, IcmpEcho>>::spawn(ctx, properties, events)
+            fasync::Task::spawn(SocketWorker::<BindingData<Ipv4, IcmpEcho>>::serve_stream(
+                ctx, properties, events,
+            ))
         }
         (fposix_socket::Domain::Ipv6, fposix_socket::DatagramSocketProtocol::IcmpEcho) => {
-            SocketWorker::<BindingData<Ipv6, IcmpEcho>>::spawn(ctx, properties, events)
+            fasync::Task::spawn(SocketWorker::<BindingData<Ipv6, IcmpEcho>>::serve_stream(
+                ctx, properties, events,
+            ))
         }
     }
+    .detach();
     Ok(())
 }
 
@@ -1455,7 +1462,6 @@
     }
 }
 
-#[async_trait]
 impl<I, T> worker::SocketWorkerHandler for BindingData<I, T>
 where
     I: SocketCollectionIpExt<T> + IpExt + IpSockAddrExt,
diff --git a/src/connectivity/network/netstack3/src/bindings/socket/stream.rs b/src/connectivity/network/netstack3/src/bindings/socket/stream.rs
index 697e334..b8d8d73 100644
--- a/src/connectivity/network/netstack3/src/bindings/socket/stream.rs
+++ b/src/connectivity/network/netstack3/src/bindings/socket/stream.rs
@@ -13,7 +13,6 @@
 };
 
 use assert_matches::assert_matches;
-use async_trait::async_trait;
 use explicit::ResultExt as _;
 use fidl::{
     endpoints::{ClientEnd, RequestStream as _},
@@ -486,9 +485,6 @@
     }
 }
 
-// TODO(https://fxbug.dev/122464): Use #![feature(async_fn_in_trait)] when
-// available.
-#[async_trait]
 impl<I: IpExt + IpSockAddrExt> worker::SocketWorkerHandler for BindingData<I>
 where
     DeviceId<BindingsNonSyncCtxImpl>:
@@ -553,12 +549,21 @@
 {
     match (domain, proto) {
         (fposix_socket::Domain::Ipv4, fposix_socket::StreamSocketProtocol::Tcp) => {
-            SocketWorker::<BindingData<Ipv4>>::spawn(ctx, SocketWorkerProperties {}, request_stream)
+            fasync::Task::spawn(SocketWorker::<BindingData<Ipv4>>::serve_stream(
+                ctx,
+                SocketWorkerProperties {},
+                request_stream,
+            ))
         }
         (fposix_socket::Domain::Ipv6, fposix_socket::StreamSocketProtocol::Tcp) => {
-            SocketWorker::<BindingData<Ipv6>>::spawn(ctx, SocketWorkerProperties {}, request_stream)
+            fasync::Task::spawn(SocketWorker::<BindingData<Ipv6>>::serve_stream(
+                ctx,
+                SocketWorkerProperties {},
+                request_stream,
+            ))
         }
     }
+    .detach()
 }
 
 impl IntoErrno for AcceptError {
@@ -842,17 +847,7 @@
                     fidl::endpoints::create_request_stream::<fposix_socket::StreamSocketMarker>()
                         .expect("failed to create new fidl endpoints");
                 spawn_send_task::<I>(ctx.clone(), socket, watcher, accepted);
-                SocketWorker::<BindingData<I>>::spawn_with(
-                    ctx.clone(),
-                    move |_: &mut SyncCtx<_>,
-                          _: &mut BindingsNonSyncCtxImpl,
-                          SocketWorkerProperties {}| BindingData {
-                        id: SocketId::Connection(accepted, true),
-                        peer,
-                    },
-                    SocketWorkerProperties {},
-                    request_stream,
-                );
+                spawn_connected_socket_task(ctx.clone(), accepted, peer, request_stream);
                 Ok((want_addr.then(|| Box::new(addr.into_sock_addr())), client))
             }
             SocketId::Unbound(_, _) | SocketId::Connection(_, _) | SocketId::Bound(_, _) => {
@@ -1502,6 +1497,28 @@
     }
 }
 
+fn spawn_connected_socket_task<I: IpExt + IpSockAddrExt>(
+    ctx: NetstackContext,
+    accepted: ConnectionId<I>,
+    peer: zx::Socket,
+    request_stream: fposix_socket::StreamSocketRequestStream,
+) where
+    DeviceId<BindingsNonSyncCtxImpl>:
+        TryFromFidlWithContext<<I::SocketAddress as SockAddr>::Zone, Error = DeviceNotFoundError>,
+    WeakDeviceId<BindingsNonSyncCtxImpl>:
+        TryIntoFidlWithContext<<I::SocketAddress as SockAddr>::Zone, Error = DeviceNotFoundError>,
+{
+    fasync::Task::spawn(SocketWorker::<BindingData<I>>::serve_stream_with(
+        ctx,
+        move |_: &mut SyncCtx<_>, _: &mut BindingsNonSyncCtxImpl, SocketWorkerProperties {}| {
+            BindingData { id: SocketId::Connection(accepted, true), peer }
+        },
+        SocketWorkerProperties {},
+        request_stream,
+    ))
+    .detach();
+}
+
 impl<A: IpAddress, D> TryIntoFidlWithContext<<A::Version as IpSockAddrExt>::SocketAddress>
     for SocketAddr<A, D>
 where
diff --git a/src/connectivity/network/netstack3/src/bindings/socket/worker.rs b/src/connectivity/network/netstack3/src/bindings/socket/worker.rs
index 25f55a4..b40ecac 100644
--- a/src/connectivity/network/netstack3/src/bindings/socket/worker.rs
+++ b/src/connectivity/network/netstack3/src/bindings/socket/worker.rs
@@ -4,12 +4,10 @@
 
 use std::ops::{ControlFlow, DerefMut};
 
-use async_trait::async_trait;
 use async_utils::stream::OneOrMany;
 use fidl::endpoints::{ControlHandle, RequestStream};
 use fidl_fuchsia_unknown::CloseableCloseResult;
-use fuchsia_async as fasync;
-use futures::{StreamExt as _, TryFutureExt as _};
+use futures::StreamExt as _;
 use log::error;
 use netstack3_core::{Ctx, SyncCtx};
 
@@ -33,7 +31,6 @@
 /// handler instance.
 // TODO(https://fxbug.dev/122464): Use #![feature(async_fn_in_trait)] when
 // available.
-#[async_trait]
 pub(crate) trait SocketWorkerHandler: Send + 'static {
     /// The type of request that this worker can handle.
     type Request: Send;
@@ -95,21 +92,22 @@
 
 impl<H: SocketWorkerHandler> SocketWorker<H> {
     /// Starts servicing events from the provided event stream.
-    pub(crate) fn spawn(
+    pub(crate) async fn serve_stream(
         ctx: NetstackContext,
         properties: SocketWorkerProperties,
         events: H::RequestStream,
     ) {
-        Self::spawn_with(
+        Self::serve_stream_with(
             ctx,
             |sync_ctx, non_sync_ctx, properties| H::new(sync_ctx, non_sync_ctx, properties),
             properties,
             events,
-        );
+        )
+        .await
     }
 
     /// Starts servicing events from the provided state and event stream.
-    pub(crate) fn spawn_with<
+    pub(crate) async fn serve_stream_with<
         F: FnOnce(
                 &mut SyncCtx<BindingsNonSyncCtxImpl>,
                 &mut BindingsNonSyncCtxImpl,
@@ -123,27 +121,22 @@
         properties: SocketWorkerProperties,
         events: H::RequestStream,
     ) {
-        fasync::Task::spawn(
-            async move {
-                let data = {
-                    let mut guard = ctx.lock().await;
-                    let Ctx { sync_ctx, non_sync_ctx } = guard.deref_mut();
+        let data = {
+            let mut guard = ctx.lock().await;
+            let Ctx { sync_ctx, non_sync_ctx } = guard.deref_mut();
 
-                    make_data(sync_ctx, non_sync_ctx, properties)
-                };
-                let worker = Self { ctx, data };
+            make_data(sync_ctx, non_sync_ctx, properties)
+        };
+        let worker = Self { ctx, data };
 
-                worker.handle_stream(events).await
-            }
-            // When the closure above finishes, that means `self` goes out of
-            // scope and is dropped, meaning that the event stream's underlying
-            // channel is closed. If any errors occurred as a result of the
-            // closure, we just log them.
-            .unwrap_or_else(|e: fidl::Error| error!("socket control request error: {:?}", e)),
-        )
-        // TODO(https://fxbug.dev/122464): Move the detach higher up so callers
-        // have to deal with it.
-        .detach();
+        // When the worker finishes, that means `self` goes out of scope and is
+        // dropped, meaning that the event stream's underlying channel is
+        // closed. If any errors occurred as a result of the closure, we just
+        // log them.
+        worker
+            .handle_stream(events)
+            .await
+            .unwrap_or_else(|e: fidl::Error| error!("socket control request error: {:?}", e))
     }
 
     /// Handles a stream of POSIX socket requests.
diff --git a/src/connectivity/network/netstack3/src/main.rs b/src/connectivity/network/netstack3/src/main.rs
index 2d5c2b0..ca4f5ee 100644
--- a/src/connectivity/network/netstack3/src/main.rs
+++ b/src/connectivity/network/netstack3/src/main.rs
@@ -5,6 +5,8 @@
 //! A networking stack.
 #![deny(missing_docs, unreachable_patterns, unused)]
 #![recursion_limit = "256"]
+#![allow(incomplete_features)]
+#![feature(async_fn_in_trait)]
 
 #[cfg(feature = "instrumented")]
 extern crate netstack3_core_instrumented as netstack3_core;