[netstack2] Fix endpoint double close bug and handle receiving TCP RST immediately

This change addresses a bug which leads to an endpoint (that can have references
to it by multiple socketImpls) may attempt to cleanup its resources multiple
times, resulting in a panic when we attempt to close a closed go channel.

The bug:
1) Accept is called on a listening socket.
2) Create an endpoint and a socketImpl referencing the endpoint, return the
   socketImpl.
3) Receive a TCP RST which closes the endpoint (it only had a ref count of 1
   which drops to 0 so it gets closed).
4) Clone the socketImpl which results in a new socketImpl but with a reference
   to the original endpoint (the ref count of the endpoint increases to 1).
5) Close the new socketImpl (which closes the endpoint again since the ref count
   drops back to 0).

Step 5 is where the bug occurs (before this change, the cleanup would would be
done twice).

This change also addresses another bug where a TCP connection may be closed
immediately after being accepted (receiving a TCP RST on an endpoint that was
just Accepted but not yet returned to the caller of Accept).

Test: Test that closing an endpoint twice will not result in a panic.

Bug: 37475

Change-Id: Ie19f45cb2d864ee13f64ae1685eaa205a3802a46
diff --git a/src/connectivity/network/netstack/netstack_test.go b/src/connectivity/network/netstack/netstack_test.go
index 73d9ad8..9993ed2 100644
--- a/src/connectivity/network/netstack/netstack_test.go
+++ b/src/connectivity/network/netstack/netstack_test.go
@@ -30,6 +30,8 @@
 	"github.com/google/netstack/tcpip/network/ipv4"
 	"github.com/google/netstack/tcpip/network/ipv6"
 	tcpipstack "github.com/google/netstack/tcpip/stack"
+	"github.com/google/netstack/tcpip/transport/tcp"
+	"github.com/google/netstack/waiter"
 )
 
 const (
@@ -39,6 +41,71 @@
 	testV6Address  tcpip.Address = tcpip.Address("\xc0\xa8\x2a\x10\xc0\xa8\x2a\x10\xc0\xa8\x2a\x10\xc0\xa8\x2a\x10")
 )
 
+// TestEndpointDoubleClose tests that closing an endpoint that has already been
+// closed once will not panic. This is in response to a bug where a socketImpl
+// (whose endpoint had already been closed) was cloned, resulting in a second
+// socketImpl with a reference to the already closed endpoint. When this second
+// socketImpl closes, it will attempt to close the endpoint (which is already
+// closed) resulting in a panic.
+func TestEndpointDoubleClose(t *testing.T) {
+	ns := newNetstack(t)
+	wq := &waiter.Queue{}
+	ep, err := ns.mu.stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint = %s", err)
+	}
+
+	ios := &endpoint{
+		netProto:      ipv6.ProtocolNumber,
+		transProto:    tcp.ProtocolNumber,
+		wq:            wq,
+		ep:            ep,
+		loopReadDone:  make(chan struct{}),
+		loopWriteDone: make(chan struct{}),
+		closing:       make(chan struct{}),
+	}
+
+	{
+		localS, peerS, err := zx.NewSocket(uint32(zx.SocketStream))
+		if err != nil {
+			t.Fatalf("zx.NewSocket = %s", err)
+		}
+
+		ios.local = localS
+		ios.peer = peerS
+	}
+
+	// We set clones to 2 initially to make sure that resources only get
+	// cleaned up when the ref count drops from 1 to 0.
+	ios.clones = 2
+	if refcount := ios.close(); refcount != 1 {
+		t.Fatalf("got refcount = %d, want = 1", refcount)
+	}
+	select {
+	case <-ios.closing:
+		t.Fatal("ios.closing is closed")
+	default:
+	}
+
+	// Ref count is now 1 so when we call close again, it will cleanup
+	// associated resources.
+	if refcount := ios.close(); refcount != 0 {
+		t.Fatalf("got refcount = %d, want = 0", refcount)
+	}
+	select {
+	case <-ios.closing:
+	default:
+		t.Fatal("ios.closing is not closed")
+	}
+
+	// Set ref count to 1 so it drops to 0 and make sure we do not
+	// do the work of closing again, and therefore should not panic.
+	ios.clones = 1
+	if refcount := ios.close(); refcount != 0 {
+		t.Fatalf("got refcount = %d, want = 0", refcount)
+	}
+}
+
 func TestNicName(t *testing.T) {
 	ns := newNetstack(t)
 
@@ -307,6 +374,9 @@
 			ipv4.NewProtocol(),
 			ipv6.NewProtocol(),
 		},
+		TransportProtocols: []tcpipstack.TransportProtocol{
+			tcp.NewProtocol(),
+		},
 	})
 
 	// We need to initialize the DNS client, since adding/removing interfaces
diff --git a/src/connectivity/network/netstack/socket_server.go b/src/connectivity/network/netstack/socket_server.go
index 9f379df..e4d34a1 100644
--- a/src/connectivity/network/netstack/socket_server.go
+++ b/src/connectivity/network/netstack/socket_server.go
@@ -73,6 +73,10 @@
 	//  - loop{Read,Write}Done are signaled iff loop{Read,Write} have
 	//    exited, respectively.
 	closing, loopReadDone, loopWriteDone chan struct{}
+
+	// This is used to make sure that endpoint.close only cleans up its
+	// resources once - the first time it was closed.
+	closeOnce sync.Once
 }
 
 // loopWrite connects libc write to the network stack.
@@ -494,39 +498,44 @@
 // When called, close signals loopRead and loopWrite (via endpoint.closing and
 // ios.local) to exit, and then blocks until its arguments are signaled. close
 // is typically called with ios.loop{Read,Write}Done.
+//
+// Note, calling close on an endpoint that has already been closed is safe as
+// the cleanup work will only be done once.
 func (ios *endpoint) close(loopDone ...<-chan struct{}) int64 {
 	clones := atomic.AddInt64(&ios.clones, -1)
 
 	if clones == 0 {
-		// Interrupt waits on notification channels. Notification reads
-		// are always combined with ios.closing in a select statement.
-		close(ios.closing)
+		ios.closeOnce.Do(func() {
+			// Interrupt waits on notification channels. Notification reads
+			// are always combined with ios.closing in a select statement.
+			close(ios.closing)
 
-		// Interrupt waits on endpoint.local. Handle waits always
-		// include localSignalClosing.
-		if err := ios.local.Handle().Signal(0, localSignalClosing); err != nil {
-			panic(err)
-		}
+			// Interrupt waits on endpoint.local. Handle waits always
+			// include localSignalClosing.
+			if err := ios.local.Handle().Signal(0, localSignalClosing); err != nil {
+				panic(err)
+			}
 
-		// The interruptions above cause our loops to exit. Wait until
-		// they do before releasing resources they may be using.
-		for _, ch := range loopDone {
-			<-ch
-		}
+			// The interruptions above cause our loops to exit. Wait until
+			// they do before releasing resources they may be using.
+			for _, ch := range loopDone {
+				<-ch
+			}
 
-		ios.ep.Close()
+			ios.ep.Close()
 
-		// HACK(crbug.com/1005300): chromium mojo code expects this; it doesn't
-		// care if the socket is closed.
-		ios.local.Shutdown(zx.SocketShutdownRead | zx.SocketShutdownWrite)
+			// HACK(crbug.com/1005300): chromium mojo code expects this; it doesn't
+			// care if the socket is closed.
+			ios.local.Shutdown(zx.SocketShutdownRead | zx.SocketShutdownWrite)
 
-		if err := ios.local.Close(); err != nil {
-			panic(err)
-		}
+			if err := ios.local.Close(); err != nil {
+				panic(err)
+			}
 
-		if err := ios.peer.Close(); err != nil {
-			panic(err)
-		}
+			if err := ios.peer.Close(); err != nil {
+				panic(err)
+			}
+		})
 	}
 
 	return clones
@@ -705,14 +714,30 @@
 	}
 
 	localAddr, err := ep.GetLocalAddress()
-	if err != nil {
+	if err == tcpip.ErrNotConnected {
+		// This should never happen as of writing as GetLocalAddress
+		// does not actually return any errors. However, we handle
+		// the tcpip.ErrNotConnected case now for the same reasons
+		// as mentioned below for the ep.GetRemoteAddress case.
+		syslog.VLogTf(syslog.DebugVerbosity, "accept", "%p: disconnected", s.endpoint)
+	} else if err != nil {
 		panic(err)
+	} else {
+		// GetRemoteAddress returns a tcpip.ErrNotConnected error if ep is no
+		// longer connected. This can happen if the endpoint was closed after
+		// the call to Accept returned, but before this point. A scenario this
+		// was actually witnessed was when a TCP RST was received after the call
+		// to Accept returned, but before this point. If GetRemoteAddress
+		// returns other (unexpected) errors, panic.
+		remoteAddr, err := ep.GetRemoteAddress()
+		if err == tcpip.ErrNotConnected {
+			syslog.VLogTf(syslog.DebugVerbosity, "accept", "%p: local=%+v, disconnected", s.endpoint, localAddr)
+		} else if err != nil {
+			panic(err)
+		} else {
+			syslog.VLogTf(syslog.DebugVerbosity, "accept", "%p: local=%+v, remote=%+v", s.endpoint, localAddr, remoteAddr)
+		}
 	}
-	remoteAddr, err := ep.GetRemoteAddress()
-	if err != nil {
-		panic(err)
-	}
-	syslog.VLogTf(syslog.DebugVerbosity, "accept", "%p: local=%+v, remote=%+v", s.endpoint, localAddr, remoteAddr)
 
 	{
 		controlInterface, err := newSocket(s.endpoint.netProto, s.endpoint.transProto, wq, ep, s.controlService)
diff --git a/src/connectivity/network/tests/BUILD.gn b/src/connectivity/network/tests/BUILD.gn
index d453e37..3fbc5e7 100644
--- a/src/connectivity/network/tests/BUILD.gn
+++ b/src/connectivity/network/tests/BUILD.gn
@@ -134,6 +134,7 @@
 
   deps = [
     "//src/lib/fxl/test:gtest_main",
+    "//zircon/public/lib/fbl",
     "//zircon/public/lib/fdio",
     "//zircon/public/lib/sync",
     "//zircon/system/fidl/fuchsia-posix-socket",
diff --git a/src/connectivity/network/tests/bsdsocket_test.cc b/src/connectivity/network/tests/bsdsocket_test.cc
index 25c3352d..4a946ad 100644
--- a/src/connectivity/network/tests/bsdsocket_test.cc
+++ b/src/connectivity/network/tests/bsdsocket_test.cc
@@ -21,70 +21,6 @@
 
 namespace {
 
-static void fill_stream_send_buf(int fd, int peer_fd) {
-  // We're about to fill the send buffer; shrink it and the other side's receive buffer to the
-  // minimum allowed.
-  {
-    const int bufsize = 1;
-    socklen_t optlen = sizeof(bufsize);
-
-    EXPECT_EQ(setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &bufsize, optlen), 0) << strerror(errno);
-    EXPECT_EQ(setsockopt(peer_fd, SOL_SOCKET, SO_RCVBUF, &bufsize, optlen), 0) << strerror(errno);
-  }
-
-  int sndbuf_opt;
-  socklen_t sndbuf_optlen = sizeof(sndbuf_opt);
-  EXPECT_EQ(getsockopt(fd, SOL_SOCKET, SO_SNDBUF, &sndbuf_opt, &sndbuf_optlen), 0)
-      << strerror(errno);
-  EXPECT_EQ(sndbuf_optlen, sizeof(sndbuf_opt));
-
-  int rcvbuf_opt;
-  socklen_t rcvbuf_optlen = sizeof(rcvbuf_opt);
-  EXPECT_EQ(getsockopt(peer_fd, SOL_SOCKET, SO_RCVBUF, &rcvbuf_opt, &rcvbuf_optlen), 0)
-      << strerror(errno);
-  EXPECT_EQ(rcvbuf_optlen, sizeof(rcvbuf_opt));
-
-  // Now that the buffers involved are minimal, we can temporarily make the socket non-blocking on
-  // Linux without introducing flakiness. We can't do that on Fuchsia because of the asynchronous
-  // copy from the zircon socket to the "real" send buffer, which takes a bit of time, so we use
-  // a small timeout which was empirically tested to ensure no flakiness is introduced.
-#if defined(__linux__)
-  int flags;
-  EXPECT_GE(flags = fcntl(fd, F_GETFL), 0) << strerror(errno);
-  EXPECT_EQ(fcntl(fd, F_SETFL, flags | O_NONBLOCK), 0) << strerror(errno);
-#else
-  struct timeval original_tv;
-  socklen_t tv_len = sizeof(original_tv);
-  EXPECT_EQ(getsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &original_tv, &tv_len), 0) << strerror(errno);
-  EXPECT_EQ(tv_len, sizeof(original_tv));
-  const struct timeval tv = {
-      .tv_sec = 0,
-      .tv_usec = 1 << 16,  // ~65ms
-  };
-  EXPECT_EQ(setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), 0) << strerror(errno);
-#endif
-
-  // buf size should be neither too small in which case too many writes operation is required
-  // to fill out the sending buffer nor too big in which case a big stack is needed for the buf
-  // array.
-  int cnt = 0;
-  {
-    char buf[sndbuf_opt + rcvbuf_opt];
-    int size;
-    while ((size = write(fd, buf, sizeof(buf))) > 0) {
-      cnt += size;
-    }
-  }
-  EXPECT_GT(cnt, 0);
-  ASSERT_TRUE(errno == EAGAIN || errno == EWOULDBLOCK) << strerror(errno);
-
-#if defined(__linux__)
-  EXPECT_EQ(fcntl(fd, F_SETFL, flags), 0) << strerror(errno);
-#else
-  EXPECT_EQ(setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &original_tv, tv_len), 0) << strerror(errno);
-#endif
-}
-
 // Raw sockets are typically used for implementing custom protocols. We intend to support custom
 // protocols through structured FIDL APIs in the future, so this test ensures that raw sockets are
 // disabled to prevent them from accidentally becoming load-bearing.
diff --git a/src/connectivity/network/tests/fdio_test.cc b/src/connectivity/network/tests/fdio_test.cc
index 32c613e..949c334 100644
--- a/src/connectivity/network/tests/fdio_test.cc
+++ b/src/connectivity/network/tests/fdio_test.cc
@@ -5,14 +5,17 @@
 // These tests ensure the zircon libc can talk to netstack.
 // No network connection is required, only a running netstack binary.
 
-#include <thread>
-
 #include <fuchsia/posix/socket/cpp/fidl.h>
 #include <lib/fdio/fd.h>
 #include <lib/sync/completion.h>
+#include <poll.h>
 #include <zircon/status.h>
 #include <zircon/syscalls.h>
 
+#include <thread>
+
+#include <fbl/unique_fd.h>
+
 #include "gtest/gtest.h"
 #include "util.h"
 
@@ -155,3 +158,74 @@
             ZX_OK)
       << zx_status_get_string(status);
 }
+
+TEST(SocketTest, CloseClonedSocketAfterTcpRst) {
+  // Create the listening endpoint (server).
+  fbl::unique_fd serverfd;
+  ASSERT_TRUE(serverfd = fbl::unique_fd(socket(AF_INET, SOCK_STREAM, 0))) << strerror(errno);
+
+  struct sockaddr_in addr = {};
+  addr.sin_family = AF_INET;
+  addr.sin_addr.s_addr = htonl(INADDR_ANY);
+  ASSERT_EQ(bind(serverfd.get(), reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)), 0)
+      << strerror(errno);
+  ASSERT_EQ(listen(serverfd.get(), 1), 0) << strerror(errno);
+
+  // Get the address the server is listening on.
+  socklen_t addrlen = sizeof(addr);
+  ASSERT_EQ(getsockname(serverfd.get(), reinterpret_cast<struct sockaddr*>(&addr), &addrlen), 0)
+      << strerror(errno);
+  ASSERT_EQ(addrlen, sizeof(addr));
+
+  // Connect to the listening endpoint (client).
+  fbl::unique_fd clientfd;
+  ASSERT_TRUE(clientfd = fbl::unique_fd(socket(AF_INET, SOCK_STREAM, 0))) << strerror(errno);
+  ASSERT_EQ(connect(clientfd.get(), reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr)),
+            0)
+      << strerror(errno);
+
+  // Accept the new connection (client) on the listening endpoint (server).
+  fbl::unique_fd connfd;
+  ASSERT_TRUE(connfd = fbl::unique_fd(accept(serverfd.get(), nullptr, nullptr))) << strerror(errno);
+  ASSERT_EQ(close(serverfd.release()), 0) << strerror(errno);
+
+  // Fill up the rcvbuf (client-side).
+  fill_stream_send_buf(connfd.get(), clientfd.get());
+
+  // Closing the client-side connection while it has data that has not been
+  // read by the client should trigger a TCP RST.
+  ASSERT_EQ(close(clientfd.release()), 0) << strerror(errno);
+
+  struct pollfd pfd = {};
+  pfd.fd = connfd.get();
+  pfd.events = POLLOUT;
+  int n = poll(&pfd, 1, kTimeout);
+  ASSERT_GE(n, 0) << strerror(errno);
+  ASSERT_EQ(n, 1);
+  // TODO(crbug.com/1005300): we should check that revents is exactly
+  // OUT|ERR|HUP. Currently, this is a bit racey, and we might see OUT and HUP
+  // but not ERR due to the hack in socket_server.go which references this same
+  // bug.
+  ASSERT_TRUE(pfd.revents & (POLLOUT | POLLHUP)) << pfd.revents;
+
+  // Now that the socket's endpoint has been closed, clone the socket (twice
+  // to increase the endpoint's reference count to at least 1), then close all
+  // copies of the socket.
+  zx_status_t status;
+  zx::channel channel1, channel2;
+  ASSERT_EQ(status = fdio_fd_clone(connfd.get(), channel1.reset_and_get_address()), ZX_OK)
+      << zx_status_get_string(status);
+  ASSERT_EQ(status = fdio_fd_clone(connfd.get(), channel2.reset_and_get_address()), ZX_OK)
+      << zx_status_get_string(status);
+
+  zx_status_t io_status;
+  fuchsia::posix::socket::Control_SyncProxy control1(std::move(channel1));
+  ASSERT_EQ(io_status = control1.Close(&status), ZX_OK) << zx_status_get_string(io_status);
+  ASSERT_EQ(status, ZX_OK) << zx_status_get_string(status);
+
+  fuchsia::posix::socket::Control_SyncProxy control2(std::move(channel2));
+  ASSERT_EQ(io_status = control2.Close(&status), ZX_OK) << zx_status_get_string(io_status);
+  ASSERT_EQ(status, ZX_OK) << zx_status_get_string(status);
+
+  ASSERT_EQ(close(connfd.release()), 0) << strerror(errno);
+}
diff --git a/src/connectivity/network/tests/util.cc b/src/connectivity/network/tests/util.cc
index a0b46ef..9dcc71d 100644
--- a/src/connectivity/network/tests/util.cc
+++ b/src/connectivity/network/tests/util.cc
@@ -5,6 +5,7 @@
 #include "util.h"
 
 #include <arpa/inet.h>
+#include <fcntl.h>
 #include <poll.h>
 #include <sys/socket.h>
 
@@ -208,3 +209,67 @@
 
   NotifySuccess(ntfyfd);
 }
+
+void fill_stream_send_buf(int fd, int peer_fd) {
+  // We're about to fill the send buffer; shrink it and the other side's receive buffer to the
+  // minimum allowed.
+  {
+    const int bufsize = 1;
+    socklen_t optlen = sizeof(bufsize);
+
+    EXPECT_EQ(setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &bufsize, optlen), 0) << strerror(errno);
+    EXPECT_EQ(setsockopt(peer_fd, SOL_SOCKET, SO_RCVBUF, &bufsize, optlen), 0) << strerror(errno);
+  }
+
+  int sndbuf_opt;
+  socklen_t sndbuf_optlen = sizeof(sndbuf_opt);
+  EXPECT_EQ(getsockopt(fd, SOL_SOCKET, SO_SNDBUF, &sndbuf_opt, &sndbuf_optlen), 0)
+      << strerror(errno);
+  EXPECT_EQ(sndbuf_optlen, sizeof(sndbuf_opt));
+
+  int rcvbuf_opt;
+  socklen_t rcvbuf_optlen = sizeof(rcvbuf_opt);
+  EXPECT_EQ(getsockopt(peer_fd, SOL_SOCKET, SO_RCVBUF, &rcvbuf_opt, &rcvbuf_optlen), 0)
+      << strerror(errno);
+  EXPECT_EQ(rcvbuf_optlen, sizeof(rcvbuf_opt));
+
+  // Now that the buffers involved are minimal, we can temporarily make the socket non-blocking on
+  // Linux without introducing flakiness. We can't do that on Fuchsia because of the asynchronous
+  // copy from the zircon socket to the "real" send buffer, which takes a bit of time, so we use
+  // a small timeout which was empirically tested to ensure no flakiness is introduced.
+#if defined(__linux__)
+  int flags;
+  EXPECT_GE(flags = fcntl(fd, F_GETFL), 0) << strerror(errno);
+  EXPECT_EQ(fcntl(fd, F_SETFL, flags | O_NONBLOCK), 0) << strerror(errno);
+#else
+  struct timeval original_tv;
+  socklen_t tv_len = sizeof(original_tv);
+  EXPECT_EQ(getsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &original_tv, &tv_len), 0) << strerror(errno);
+  EXPECT_EQ(tv_len, sizeof(original_tv));
+  const struct timeval tv = {
+      .tv_sec = 0,
+      .tv_usec = 1 << 16,  // ~65ms
+  };
+  EXPECT_EQ(setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), 0) << strerror(errno);
+#endif
+
+  // buf size should be neither too small in which case too many writes operation is required
+  // to fill out the sending buffer nor too big in which case a big stack is needed for the buf
+  // array.
+  int cnt = 0;
+  {
+    char buf[sndbuf_opt + rcvbuf_opt];
+    int size;
+    while ((size = write(fd, buf, sizeof(buf))) > 0) {
+      cnt += size;
+    }
+  }
+  EXPECT_GT(cnt, 0);
+  ASSERT_TRUE(errno == EAGAIN || errno == EWOULDBLOCK) << strerror(errno);
+
+#if defined(__linux__)
+  EXPECT_EQ(fcntl(fd, F_SETFL, flags), 0) << strerror(errno);
+#else
+  EXPECT_EQ(setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &original_tv, tv_len), 0) << strerror(errno);
+#endif
+}
diff --git a/src/connectivity/network/tests/util.h b/src/connectivity/network/tests/util.h
index d5ce829..86bb937 100644
--- a/src/connectivity/network/tests/util.h
+++ b/src/connectivity/network/tests/util.h
@@ -23,4 +23,7 @@
                   int ntfyfd, int timeout);
 void DatagramReadWrite(int recvfd, int ntfyfd);
 void DatagramReadWriteV6(int recvfd, int ntfyfd);
+
+void fill_stream_send_buf(int fd, int peer_fd);
+
 #endif  // SRC_CONNECTIVITY_NETWORK_TESTS_UTIL_H_