[netstack] unify TCP/UDP read/write loops
Test: CQ
Change-Id: I5d19015c7eb9535d4987ccf22740878eddd6b2eb
diff --git a/go/src/netstack/socket_server.go b/go/src/netstack/socket_server.go
index 1e617ed..a3fe91f 100644
--- a/go/src/netstack/socket_server.go
+++ b/go/src/netstack/socket_server.go
@@ -17,7 +17,6 @@
"time"
"github.com/google/netstack/tcpip"
- "github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/network/ipv6"
@@ -73,7 +72,7 @@
func signalConnectSuccess(s zx.Socket) error {
// CONNECTED should be sent to the peer before it is sent locally.
// That ensures the peer detects the connection before any data is written by
- // loopStreamRead.
+ // loopRead.
err := sendSignal(s, ZXSIO_SIGNAL_OUTGOING|ZXSIO_SIGNAL_CONNECTED, true)
if err != nil {
return err
@@ -95,412 +94,252 @@
mu sync.Mutex
lastError *tcpip.Error // if not-nil, next error returned via getsockopt
- loopWriteDone chan struct{} // report that loop[Stream|Dgram]Write finished
+ loopWriteDone chan struct{} // report that loopWrite finished
loopListenDone chan struct{} // report that loopListen finished
closing chan struct{}
}
-// loopStreamWrite connects libc write to the network stack for TCP sockets.
-//
-// As written, we have two netstack threads per socket.
-// That's not so bad for small client work, but even a client OS is
-// eventually going to feel the overhead of this.
-func (ios *iostate) loopStreamWrite() {
- // Warm up.
- _, err := zxwait.Wait(zx.Handle(ios.dataHandle),
- zx.SignalSocketReadable|zx.SignalSocketReadDisabled|
- zx.SignalSocketPeerClosed|LOCAL_SIGNAL_CLOSING,
- zx.TimensecInfinite)
- switch mxerror.Status(err) {
- case zx.ErrOk:
- // NOP
- case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
- return
- default:
- log.Printf("loopStreamWrite: warmup failed: %v", err)
- }
-
- // The client might have written some data into the socket.
- // Always continue to the 'for' loop below and try to read them
- // even if the signals show the client has closed the dataHandle.
+// loopWrite connects libc write to the network stack.
+func (ios *iostate) loopWrite() error {
+ const sigs = zx.SignalSocketReadable | zx.SignalSocketReadDisabled |
+ zx.SignalSocketPeerClosed | LOCAL_SIGNAL_CLOSING
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ defer ios.wq.EventUnregister(&waitEntry)
for {
// TODO: obviously allocating for each read is silly.
// A quick hack we can do is store these in a ring buffer,
// as the lifecycle of this buffer.View starts here, and
// ends in nearby code we control in link.go.
- v := buffer.NewView(2048)
- n, err := ios.dataHandle.Read([]byte(v), 0)
- switch mxerror.Status(err) {
+ v := make([]byte, 0, 2048)
+ switch n, err := ios.dataHandle.Read(v[:cap(v)], 0); mxerror.Status(err) {
case zx.ErrOk:
// Success. Pass the data to the endpoint and loop.
+ v = v[:n]
case zx.ErrBadState:
// This side of the socket is closed.
- err := ios.ep.Shutdown(tcpip.ShutdownWrite)
- if err != nil {
- log.Printf("loopStreamWrite: ShutdownWrite failed: %v", err)
+ if err := ios.ep.Shutdown(tcpip.ShutdownWrite); err != nil && err != tcpip.ErrNotConnected {
+ return fmt.Errorf("Endpoint.Shutdown(ShutdownWrite): %s", err)
}
- return
+ return nil
case zx.ErrShouldWait:
- obs, err := zxwait.Wait(zx.Handle(ios.dataHandle),
- zx.SignalSocketReadable|zx.SignalSocketReadDisabled|
- zx.SignalSocketPeerClosed|LOCAL_SIGNAL_CLOSING,
- zx.TimensecInfinite)
- switch mxerror.Status(err) {
+ switch obs, err := zxwait.Wait(zx.Handle(ios.dataHandle), sigs, zx.TimensecInfinite); mxerror.Status(err) {
case zx.ErrOk:
- // Handle signal below.
- case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
- return
- default:
- log.Printf("loopStreamWrite: wait failed: %v", err)
- return
- }
- switch {
- case obs&zx.SignalSocketReadDisabled != 0:
+ switch {
+ case obs&zx.SignalSocketReadDisabled != 0:
// The next Read will return zx.BadState.
+ // Go switch cases do not fall through, so retry the read
+ // explicitly; leaving this switch without a continue would run
+ // the write path below with a zero-length v as if the client
+ // had sent it.
continue
- case obs&zx.SignalSocketReadable != 0:
- continue
- case obs&LOCAL_SIGNAL_CLOSING != 0:
- return
- case obs&zx.SignalSocketPeerClosed != 0:
- return
+ case obs&zx.SignalSocketReadable != 0:
+ // The client might have written some data into the socket.
+ // Always continue to the 'for' loop below and try to read them
+ // even if the signals show the client has closed the dataHandle.
+ continue
+ case obs&zx.SignalSocketPeerClosed != 0:
+ return nil
+ case obs&LOCAL_SIGNAL_CLOSING != 0:
+ return nil
+ }
+ case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
+ return nil
+ default:
+ return err
}
case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
- return
+ return nil
default:
- log.Printf("socket read failed: %v", err) // TODO: communicate this
- continue
+ return err
}
- v = v[:n]
if debug {
- log.Printf("loopStreamWrite: sending packet n=%d, v=%q", n, v)
+ log.Printf("%p: loopWrite: sending packet n=%d, v=%q", ios, len(v), v)
}
- if err := func() *tcpip.Error {
- ios.wq.EventRegister(&waitEntry, waiter.EventOut)
- defer ios.wq.EventUnregister(&waitEntry)
-
- for {
- n, resCh, err := ios.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
- if resCh != nil {
- if err != tcpip.ErrNoLinkAddress {
- panic(fmt.Sprintf("err=%v inconsistent with presence of resCh", err))
- }
- panic(fmt.Sprintf("TCP link address resolutions happen on connect; saw %d/%d", n, len(v)))
- }
- if err == tcpip.ErrWouldBlock {
- // Note that Close should not interrupt this wait.
- <-notifyCh
- continue
- }
+ var opts tcpip.WriteOptions
+ if ios.transProto != tcp.ProtocolNumber {
+ // Datagram sockets prefix every message with an fdio_socket_msg
+ // header. v[:HEADER_SIZE] would silently reslice up to cap(v) and
+ // v[HEADER_SIZE:] would panic if the client wrote fewer bytes than
+ // the header, so reject short messages up front.
+ if len(v) < C.FDIO_SOCKET_MSG_HEADER_SIZE {
+ return fmt.Errorf("short socket message: got %d bytes, want at least %d", len(v), C.FDIO_SOCKET_MSG_HEADER_SIZE)
+ }
+ var fdioSocketMsg C.struct_fdio_socket_msg
+ if err := fdioSocketMsg.Unmarshal(v[:C.FDIO_SOCKET_MSG_HEADER_SIZE]); err != nil {
+ return err
+ }
+ if fdioSocketMsg.addrlen != 0 {
+ addr, err := fdioSocketMsg.addr.Decode()
if err != nil {
return err
}
- v = v[n:]
- if len(v) == 0 {
- return nil
+ opts.To = &addr
+ }
+ v = v[C.FDIO_SOCKET_MSG_HEADER_SIZE:]
+ }
+
+ ios.wq.EventRegister(&waitEntry, waiter.EventOut)
+ for {
+ n, resCh, err := ios.ep.Write(tcpip.SlicePayload(v), opts)
+ if resCh != nil {
+ if err != tcpip.ErrNoLinkAddress {
+ panic(fmt.Sprintf("err=%v inconsistent with presence of resCh", err))
+ }
+ if ios.transProto == tcp.ProtocolNumber {
+ panic(fmt.Sprintf("TCP link address resolutions happen on connect; saw %d/%d", n, len(v)))
+ }
+ <-resCh
+ continue
+ }
+ if err == tcpip.ErrWouldBlock {
+ switch ios.transProto {
+ case tcp.ProtocolNumber:
+ default:
+ panic(fmt.Sprintf("UDP writes are nonblocking; saw %d/%d", n, len(v)))
+ }
+ // Note that Close should not interrupt this wait.
+ <-notifyCh
+ continue
+ }
+ if err != nil {
+ return fmt.Errorf("Endpoint.Write(...): %s", err)
+ }
+ if ios.transProto != tcp.ProtocolNumber {
+ if int(n) < len(v) {
+ panic(fmt.Sprintf("UDP disallows short writes; saw: %d/%d", n, len(v)))
}
}
- }(); err != nil {
- log.Printf("loopStreamWrite: got endpoint error: %v (TODO)", err)
- return
+ v = v[n:]
+ if len(v) == 0 {
+ break
+ }
}
+ ios.wq.EventUnregister(&waitEntry)
}
}
-// loopStreamRead connects libc read to the network stack for TCP sockets.
-func (ios *iostate) loopStreamRead() {
- // Warm up.
- writable := false
- connected := false
- for !(writable && connected) {
- sigs := zx.Signals(zx.SignalSocketWriteDisabled | zx.SignalSocketPeerClosed | LOCAL_SIGNAL_CLOSING)
- if !writable {
- sigs |= zx.SignalSocketWritable
- }
- if !connected {
- sigs |= ZXSIO_SIGNAL_CONNECTED
- }
- obs, err := zxwait.Wait(zx.Handle(ios.dataHandle), sigs, zx.TimensecInfinite)
- switch mxerror.Status(err) {
+// loopRead connects libc read to the network stack.
+func (ios *iostate) loopRead() error {
+ if ios.transProto == tcp.ProtocolNumber {
+ switch obs, err := zxwait.Wait(
+ zx.Handle(ios.dataHandle),
+ zx.SignalSocketWriteDisabled|zx.SignalSocketPeerClosed|LOCAL_SIGNAL_CLOSING|ZXSIO_SIGNAL_CONNECTED,
+ zx.TimensecInfinite,
+ ); mxerror.Status(err) {
case zx.ErrOk:
- // NOP
- case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
- return
- default:
- log.Printf("loopStreamRead: warmup failed: %v", err)
- }
- if obs&zx.SignalSocketWritable != 0 {
- writable = true
- }
- if obs&ZXSIO_SIGNAL_CONNECTED != 0 {
- connected = true
- }
- if obs&zx.SignalSocketPeerClosed != 0 {
- return
- }
- if obs&LOCAL_SIGNAL_CLOSING != 0 {
- return
- }
- if obs&zx.SignalSocketWriteDisabled != 0 {
- err := ios.ep.Shutdown(tcpip.ShutdownRead)
- if err != nil {
- log.Printf("loopStreamRead: ShutdownRead failed: %v", err)
+ if obs&ZXSIO_SIGNAL_CONNECTED != 0 {
+ // We're connected.
}
- return
+ if obs&zx.SignalSocketPeerClosed != 0 {
+ return nil
+ }
+ if obs&LOCAL_SIGNAL_CLOSING != 0 {
+ return nil
+ }
+ if obs&zx.SignalSocketWriteDisabled != 0 {
+ if err := ios.ep.Shutdown(tcpip.ShutdownRead); err != nil {
+ return fmt.Errorf("Endpoint.Shutdown(ShutdownRead): %s", err)
+ }
+ return nil
+ }
+ case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
+ return nil
+ default:
+ return err
}
}
+ const sigs = zx.SignalSocketWritable | zx.SignalSocketWriteDisabled |
+ zx.SignalSocketPeerClosed | LOCAL_SIGNAL_CLOSING
+
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ defer ios.wq.EventUnregister(&waitEntry)
+
+ var sender tcpip.FullAddress
for {
ios.wq.EventRegister(&waitEntry, waiter.EventIn)
- var v buffer.View
- var err *tcpip.Error
+ var v []byte
for {
- v, _, err = ios.ep.Read(nil)
- if err == nil {
- break
- } else if err == tcpip.ErrWouldBlock || err == tcpip.ErrInvalidEndpointState || err == tcpip.ErrNotConnected {
- if debug {
- log.Printf("loopStreamRead read err=%v", err)
- }
+ var err *tcpip.Error
+ v, _, err = ios.ep.Read(&sender)
+ if err == tcpip.ErrWouldBlock {
select {
case <-notifyCh:
continue
case <-ios.closing:
// TODO: write a unit test that exercises this.
- return
+ return nil
}
- } else if err == tcpip.ErrClosedForReceive || err == tcpip.ErrConnectionRefused {
- if err == tcpip.ErrConnectionRefused {
- ios.lastError = err
- }
- err := ios.dataHandle.Shutdown(zx.SocketShutdownWrite)
- switch mxerror.Status(err) {
- case zx.ErrOk:
- case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
- default:
- log.Printf("socket read: shutdown failed: %v", err)
- }
- return
}
- log.Printf("loopStreamRead got endpoint error: %v (TODO)", err)
- return
+ if err == tcpip.ErrClosedForReceive {
+ // The endpoint will deliver no more data; half-close the socket so
+ // the client observes EOF. A handle that is already closed,
+ // canceled or whose peer is gone is a normal way for this loop to
+ // end, as with every other dataHandle operation in this function.
+ switch err := ios.dataHandle.Shutdown(zx.SocketShutdownWrite); mxerror.Status(err) {
+ case zx.ErrOk, zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
+ return nil
+ default:
+ return err
+ }
+ }
+ if err != nil {
+ return fmt.Errorf("Endpoint.Read(): %s", err)
+ }
+ break
}
ios.wq.EventUnregister(&waitEntry)
if debug {
- log.Printf("loopStreamRead: got a buffer, len(v)=%d", len(v))
+ log.Printf("%p: loopRead: received packet n=%d, v=%q", ios, len(v), v)
}
- writeLoop:
- for len(v) > 0 {
- n, err := ios.dataHandle.Write([]byte(v), 0)
- v = v[n:]
- switch mxerror.Status(err) {
- case zx.ErrOk:
- // Success. Loop and keep writing.
- case zx.ErrBadState:
- // This side of the socket is closed.
- err := ios.ep.Shutdown(tcpip.ShutdownRead)
- if err != nil {
- log.Printf("loopStreamRead: ShutdownRead failed: %v", err)
- }
- return
- case zx.ErrShouldWait:
- if debug {
- log.Printf("loopStreamRead: got zx.ErrShouldWait")
- }
- obs, err := zxwait.Wait(zx.Handle(ios.dataHandle),
- zx.SignalSocketWritable|zx.SignalSocketWriteDisabled|
- zx.SignalSocketPeerClosed|LOCAL_SIGNAL_CLOSING,
- zx.TimensecInfinite)
- switch mxerror.Status(err) {
- case zx.ErrOk:
- // Handle signal below.
- case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
- return
- default:
- log.Printf("loopStreamRead: wait failed: %v", err)
- return
- }
- switch {
- case obs&zx.SignalSocketPeerClosed != 0:
- return
- case obs&LOCAL_SIGNAL_CLOSING != 0:
- return
- case obs&zx.SignalSocketWriteDisabled != 0:
- // The next Write will return zx.ErrBadState.
- continue
- case obs&zx.SignalSocketWritable != 0:
- continue
- }
- case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
- return
- default:
- log.Printf("socket write failed: %v", err) // TODO: communicate this
- break writeLoop
- }
- }
- }
-}
-
-// loopDgramRead connects libc read to the network stack for UDP messages.
-func (ios *iostate) loopDgramRead() {
- waitEntry, notifyCh := waiter.NewChannelEntry(nil)
- for {
- ios.wq.EventRegister(&waitEntry, waiter.EventIn)
- var sender tcpip.FullAddress
- var v buffer.View
- var err *tcpip.Error
- for {
- v, _, err = ios.ep.Read(&sender)
- if err == nil {
- break
- } else if err == tcpip.ErrWouldBlock {
- select {
- case <-notifyCh:
- continue
- case <-ios.closing:
- return
- }
- } else if err == tcpip.ErrClosedForReceive {
- if debug {
- log.Printf("TODO loopDgramRead closed")
- }
- // TODO _, err := ios.dataHandle.Write(nil, ZX_SOCKET_HALF_CLOSE)
- return
- }
- // TODO communicate to user
- log.Printf("loopDgramRead got endpoint error: %v (TODO)", err)
- return
- }
- ios.wq.EventUnregister(&waitEntry)
-
- out := make([]byte, C.FDIO_SOCKET_MSG_HEADER_SIZE+len(v))
- if err := func() error {
- var fdioSocketMsg C.struct_fdio_socket_msg
- n, err := fdioSocketMsg.addr.Encode(sender)
- if err != nil {
- return err
- }
- fdioSocketMsg.addrlen = C.socklen_t(n)
- if _, err := fdioSocketMsg.MarshalTo(out[:C.FDIO_SOCKET_MSG_HEADER_SIZE]); err != nil {
- return err
- }
- return nil
- }(); err != nil {
- // TODO communicate to user
- log.Printf("writeSocketMsgHdr failed: %v (TODO)", err)
- }
- if n := copy(out[C.FDIO_SOCKET_MSG_HEADER_SIZE:], v); n < len(v) {
- panic(fmt.Sprintf("copied %d/%d bytes", n, len(v)))
- }
-
- writeLoop:
- for {
- _, err := ios.dataHandle.Write(out, 0)
- switch mxerror.Status(err) {
- case zx.ErrOk:
- break writeLoop
- case zx.ErrBadState:
- return // This side of the socket is closed.
- case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
- return
- default:
- log.Printf("socket write failed: %v", err) // TODO: communicate this
- break writeLoop
- }
- }
- }
-}
-
-// loopDgramWrite connects libc write to the network stack for UDP messages.
-func (ios *iostate) loopDgramWrite() {
- for {
- v := buffer.NewView(2048)
- n, err := ios.dataHandle.Read([]byte(v), 0)
- switch mxerror.Status(err) {
- case zx.ErrOk:
- // Success. Pass the data to the endpoint and loop.
- case zx.ErrBadState:
- return // This side of the socket is closed.
- case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
- return
- case zx.ErrShouldWait:
- obs, err := zxwait.Wait(zx.Handle(ios.dataHandle),
- zx.SignalSocketReadable|zx.SignalSocketPeerClosed|LOCAL_SIGNAL_CLOSING,
- zx.TimensecInfinite)
- switch mxerror.Status(err) {
- case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
- return
- case zx.ErrOk:
- switch {
- case obs&zx.SignalChannelReadable != 0:
- continue
- case obs&LOCAL_SIGNAL_CLOSING != 0:
- return
- case obs&zx.SignalSocketPeerClosed != 0:
- return
- }
- default:
- log.Printf("loopDgramWrite wait failed: %v", err)
- return
- }
- default:
- log.Printf("loopDgramWrite failed: %v", err) // TODO: communicate this
- continue
- }
- v = v[:n]
-
- receiver, err := func() (*tcpip.FullAddress, error) {
- var fdioSocketMsg C.struct_fdio_socket_msg
- if err := fdioSocketMsg.Unmarshal(v[:C.FDIO_SOCKET_MSG_HEADER_SIZE]); err != nil {
- return nil, err
- }
- if fdioSocketMsg.addrlen == 0 {
- return nil, nil
- }
- addr, err := fdioSocketMsg.addr.Decode()
- if err != nil {
- return nil, err
- }
- return &addr, nil
- }()
- if err != nil {
- // TODO communicate
- log.Printf("loopDgramWrite: bad socket msg header: %v", err)
- continue
- }
- v = v[C.FDIO_SOCKET_MSG_HEADER_SIZE:]
-
- if err := func() *tcpip.Error {
- for {
- n, resCh, err := ios.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{To: receiver})
- if resCh != nil {
- if err != tcpip.ErrNoLinkAddress {
- panic(fmt.Sprintf("err=%v inconsistent with presence of resCh", err))
- }
- <-resCh
- continue
- }
- if err == tcpip.ErrWouldBlock {
- panic(fmt.Sprintf("UDP writes are nonblocking; saw %d/%d", n, len(v)))
- }
+ if ios.transProto != tcp.ProtocolNumber {
+ out := make([]byte, C.FDIO_SOCKET_MSG_HEADER_SIZE+len(v))
+ if err := func() error {
+ var fdioSocketMsg C.struct_fdio_socket_msg
+ n, err := fdioSocketMsg.addr.Encode(sender)
if err != nil {
return err
}
- if int(n) < len(v) {
- panic(fmt.Sprintf("UDP disallowes short writes; saw: %d/%d", n, len(v)))
+ fdioSocketMsg.addrlen = C.socklen_t(n)
+ if _, err := fdioSocketMsg.MarshalTo(out[:C.FDIO_SOCKET_MSG_HEADER_SIZE]); err != nil {
+ return err
}
return nil
+ }(); err != nil {
+ return err
}
- }(); err != nil {
- log.Printf("loopDgramWrite: got endpoint error: %v (TODO)", err)
- return
+ if n := copy(out[C.FDIO_SOCKET_MSG_HEADER_SIZE:], v); n < len(v) {
+ panic(fmt.Sprintf("copied %d/%d bytes", n, len(v)))
+ }
+ v = out
+ }
+
+ writeLoop:
+ for {
+ switch n, err := ios.dataHandle.Write(v, 0); mxerror.Status(err) {
+ case zx.ErrOk:
+ if ios.transProto != tcp.ProtocolNumber {
+ if n < len(v) {
+ panic(fmt.Sprintf("UDP disallows short writes; saw: %d/%d", n, len(v)))
+ }
+ }
+ v = v[n:]
+ if len(v) == 0 {
+ break writeLoop
+ }
+ case zx.ErrBadState:
+ // This side of the socket is closed.
+ if err := ios.ep.Shutdown(tcpip.ShutdownRead); err != nil {
+ return fmt.Errorf("Endpoint.Shutdown(ShutdownRead): %s", err)
+ }
+ return nil
+ case zx.ErrShouldWait:
+ switch obs, err := zxwait.Wait(zx.Handle(ios.dataHandle), sigs, zx.TimensecInfinite); mxerror.Status(err) {
+ case zx.ErrOk:
+ switch {
+ case obs&zx.SignalSocketWriteDisabled != 0:
+ // The next Write will return zx.ErrBadState.
+ case obs&zx.SignalSocketWritable != 0:
+ continue
+ case obs&zx.SignalSocketPeerClosed != 0:
+ return nil
+ case obs&LOCAL_SIGNAL_CLOSING != 0:
+ return nil
+ }
+ case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
+ return nil
+ default:
+ return err
+ }
+ case zx.ErrBadHandle, zx.ErrCanceled, zx.ErrPeerClosed:
+ return nil
+ default:
+ return err
+ }
}
}
}
@@ -593,15 +432,15 @@
go ios.loopControl()
go func() {
+ if err := ios.loopRead(); err != nil {
+ log.Printf("%p: loopRead: %s", ios, err)
+ }
+ }()
+ go func() {
defer close(ios.loopWriteDone)
- switch transProto {
- case tcp.ProtocolNumber:
- go ios.loopStreamRead()
- ios.loopStreamWrite()
- case udp.ProtocolNumber, ping.ProtocolNumber4:
- go ios.loopDgramRead()
- ios.loopDgramWrite()
+ if err := ios.loopWrite(); err != nil {
+ log.Printf("%p: loopWrite: %s", ios, err)
}
}()