[netstack] remove socketServer
This structure is not particularly useful - all the useful state is
already in `iostate`.
Test: CQ
Change-Id: I21b77540d8e35160955a22c04605f3482a4ce7d9
diff --git a/go/src/netstack/fuchsia_net_socket.go b/go/src/netstack/fuchsia_net_socket.go
index a7a13e8..03a1202 100644
--- a/go/src/netstack/fuchsia_net_socket.go
+++ b/go/src/netstack/fuchsia_net_socket.go
@@ -19,6 +19,7 @@
"github.com/google/netstack/tcpip/transport/ping"
"github.com/google/netstack/tcpip/transport/tcp"
"github.com/google/netstack/tcpip/transport/udp"
+ "github.com/google/netstack/waiter"
)
// #cgo CFLAGS: -I${SRCDIR}/../../../../zircon/system/ulib/zxs/include
@@ -67,11 +68,27 @@
return zx.Socket(zx.HandleInvalid), int32(errStatus(err)), nil
}
- s, err := sp.ns.socketServer.opSocket(netProto, transProto)
- if err != nil {
- return zx.Socket(zx.HandleInvalid), int32(errStatus(err)), nil
+ {
+ wq := new(waiter.Queue)
+ ep, err := sp.ns.mu.stack.NewEndpoint(transProto, netProto, wq)
+ if err != nil {
+ if debug {
+ log.Printf("socket: new endpoint: %v", err)
+ }
+ return zx.Socket(zx.HandleInvalid), int32(zx.ErrInternal), nil
+ }
+ {
+ _, peerS, err := newIostate(sp.ns, netProto, transProto, wq, ep, false)
+ if err != nil {
+ if debug {
+ log.Printf("socket: new iostate: %v", err)
+ }
+ return zx.Socket(zx.HandleInvalid), int32(errStatus(err)), nil
+ }
+
+ return peerS, 0, nil
+ }
}
- return s, 0, nil
}
func (sp *socketProviderImpl) GetAddrInfo(node *string, service *string, hints *net.AddrInfoHints) (net.AddrInfoStatus, uint32, [4]net.AddrInfo, error) {
diff --git a/go/src/netstack/main.go b/go/src/netstack/main.go
index 73d104d..7c47fe5 100644
--- a/go/src/netstack/main.go
+++ b/go/src/netstack/main.go
@@ -54,12 +54,6 @@
log.Fatalf("method SetTransportProtocolOption(%v, tcp.SACKEnabled(true)) failed: %v", tcp.ProtocolNumber, err)
}
- s, err := newSocketServer(stk, ctx)
- if err != nil {
- log.Fatal(err)
- }
- log.Print("socket server started")
-
arena, err := eth.NewArena()
if err != nil {
log.Fatalf("ethernet: %s", err)
@@ -74,7 +68,6 @@
ns := &Netstack{
arena: arena,
- socketServer: s,
dnsClient: dns.NewClient(stk),
deviceSettings: ds,
ifStates: make(map[tcpip.NICID]*ifState),
@@ -85,8 +78,6 @@
log.Fatalf("loopback: %s", err)
}
- s.setNetstack(ns)
-
// TODO(NET-1263): register resolver admin service once clients don't crash netstack
// var dnsService netstack.ResolverAdminService
var netstackService netstack.NetstackService
diff --git a/go/src/netstack/netstack.go b/go/src/netstack/netstack.go
index 2111b4f..bd52664 100644
--- a/go/src/netstack/netstack.go
+++ b/go/src/netstack/netstack.go
@@ -47,8 +47,7 @@
// A Netstack tracks all of the running state of the network stack.
type Netstack struct {
- arena *eth.Arena
- socketServer *socketServer
+ arena *eth.Arena
deviceSettings *devicesettings.DeviceSettingsManagerInterface
dnsClient *dns.Client
diff --git a/go/src/netstack/socket_server.go b/go/src/netstack/socket_server.go
index c8848ac..7542a28 100644
--- a/go/src/netstack/socket_server.go
+++ b/go/src/netstack/socket_server.go
@@ -16,14 +16,11 @@
"syscall/zx/zxwait"
"time"
- "app/context"
-
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/network/ipv6"
- "github.com/google/netstack/tcpip/stack"
"github.com/google/netstack/tcpip/transport/ping"
"github.com/google/netstack/tcpip/transport/tcp"
"github.com/google/netstack/tcpip/transport/udp"
@@ -82,32 +79,18 @@
return sendSignal(s, ZXSIO_SIGNAL_CONNECTED, false)
}
-func newSocketServer(stk *stack.Stack, ctx *context.Context) (*socketServer, error) {
- a := socketServer{
- stack: stk,
- io: make(map[cookie]*iostate),
- next: 1,
- }
- return &a, nil
-}
-
-func (s *socketServer) setNetstack(ns *Netstack) {
- s.ns = ns
-}
-
-type cookie int64
-
type iostate struct {
wq *waiter.Queue
ep tcpip.Endpoint
+ ns *Netstack
+
netProto tcpip.NetworkProtocolNumber // IPv4 or IPv6
transProto tcpip.TransportProtocolNumber // TCP or UDP
dataHandle zx.Socket // used to communicate with libc
mu sync.Mutex
- refs int
lastError *tcpip.Error // if not-nil, next error returned via getsockopt
loopWriteDone chan struct{} // report that loop[Stream|Dgram]Write finished
@@ -121,7 +104,7 @@
// As written, we have two netstack threads per socket.
// That's not so bad for small client work, but even a client OS is
// eventually going to feel the overhead of this.
-func (ios *iostate) loopStreamWrite(stk *stack.Stack) {
+func (ios *iostate) loopStreamWrite() {
// Warm up.
_, err := zxwait.Wait(zx.Handle(ios.dataHandle),
zx.SignalSocketReadable|zx.SignalSocketReadDisabled|
@@ -228,7 +211,7 @@
}
// loopStreamRead connects libc read to the network stack for TCP sockets.
-func (ios *iostate) loopStreamRead(stk *stack.Stack) {
+func (ios *iostate) loopStreamRead() {
// Warm up.
writable := false
connected := false
@@ -364,7 +347,7 @@
}
// loopDgramRead connects libc read to the network stack for UDP messages.
-func (ios *iostate) loopDgramRead(stk *stack.Stack) {
+func (ios *iostate) loopDgramRead() {
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
for {
ios.wq.EventRegister(&waitEntry, waiter.EventIn)
@@ -434,7 +417,7 @@
}
// loopDgramWrite connects libc write to the network stack for UDP messages.
-func (ios *iostate) loopDgramWrite(stk *stack.Stack) {
+func (ios *iostate) loopDgramWrite() {
for {
v := buffer.NewView(2048)
n, err := ios.dataHandle.Read([]byte(v), 0)
@@ -520,11 +503,11 @@
}
}
-func (ios *iostate) loopControl(s *socketServer, cookie int64) {
+func (ios *iostate) loopControl() {
synthesizeClose := true
defer func() {
if synthesizeClose {
- switch err := zxsocket.Handler(0, zxsocket.ServerHandler(s.zxsocketHandler), cookie); mxerror.Status(err) {
+ switch err := zxsocket.Handler(0, zxsocket.ServerHandler(ios.zxsocketHandler), 0); mxerror.Status(err) {
case zx.ErrOk:
default:
log.Printf("synethsize close failed: %v", err)
@@ -537,7 +520,7 @@
}()
for {
- switch err := zxsocket.Handler(ios.dataHandle, zxsocket.ServerHandler(s.zxsocketHandler), cookie); mxerror.Status(err) {
+ switch err := zxsocket.Handler(ios.dataHandle, zxsocket.ServerHandler(ios.zxsocketHandler), 0); mxerror.Status(err) {
case zx.ErrOk:
// Success. Pass the data to the endpoint and loop.
case zx.ErrBadState:
@@ -576,7 +559,7 @@
}
}
-func (s *socketServer) newIostate(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, wq *waiter.Queue, ep tcpip.Endpoint, isAccept bool) (zx.Socket, zx.Socket, error) {
+func newIostate(ns *Netstack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, wq *waiter.Queue, ep tcpip.Endpoint, isAccept bool) (zx.Socket, zx.Socket, error) {
var t uint32
switch transProto {
case tcp.ProtocolNumber:
@@ -600,69 +583,29 @@
transProto: transProto,
wq: wq,
ep: ep,
- refs: 1,
+ ns: ns,
dataHandle: localS,
loopWriteDone: make(chan struct{}),
closing: make(chan struct{}),
}
- s.mu.Lock()
- newCookie := s.next
- s.next++
- s.io[newCookie] = ios
- s.mu.Unlock()
-
- go ios.loopControl(s, int64(newCookie))
+ go ios.loopControl()
go func() {
defer close(ios.loopWriteDone)
switch transProto {
case tcp.ProtocolNumber:
- go ios.loopStreamRead(s.stack)
- ios.loopStreamWrite(s.stack)
+ go ios.loopStreamRead()
+ ios.loopStreamWrite()
case udp.ProtocolNumber, ping.ProtocolNumber4:
- go ios.loopDgramRead(s.stack)
- ios.loopDgramWrite(s.stack)
+ go ios.loopDgramRead()
+ ios.loopDgramWrite()
}
}()
return localS, peerS, nil
}
-type socketServer struct {
- stack *stack.Stack
- ns *Netstack
-
- mu sync.Mutex
- next cookie
- io map[cookie]*iostate
-}
-
-func (s *socketServer) opSocket(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (zx.Socket, error) {
- wq := new(waiter.Queue)
- ep, e := s.stack.NewEndpoint(transProto, netProto, wq)
- if e != nil {
- if debug {
- log.Printf("socket: new endpoint: %v", e)
- }
- return zx.Socket(zx.HandleInvalid), mxerror.Errorf(zx.ErrInternal, "socket: new endpoint: %v", e)
- }
- if netProto == ipv6.ProtocolNumber {
- if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
- log.Printf("socket: setsockopt v6only option failed: %v", err)
- }
- }
- _, peerS, err := s.newIostate(netProto, transProto, wq, ep, false)
- if err != nil {
- if debug {
- log.Printf("socket: new iostate: %v", err)
- }
- return zx.Socket(zx.HandleInvalid), err
- }
-
- return peerS, nil
-}
-
func errStatus(err error) zx.Status {
if err == nil {
return zx.ErrOk
@@ -740,7 +683,7 @@
}
}
-func (s *socketServer) opGetSockOpt(ios *iostate, msg *zxsocket.Msg) zx.Status {
+func (ios *iostate) opGetSockOpt(msg *zxsocket.Msg) zx.Status {
var val C.struct_zxrio_sockopt_req_reply
if err := val.Unmarshal(msg.Data[:msg.Datalen]); err != nil {
if debug {
@@ -884,7 +827,7 @@
return zx.ErrOk
}
-func (s *socketServer) opSetSockOpt(ios *iostate, msg *zxsocket.Msg) zx.Status {
+func (ios *iostate) opSetSockOpt(msg *zxsocket.Msg) zx.Status {
var val C.struct_zxrio_sockopt_req_reply
if err := val.Unmarshal(msg.Data[:msg.Datalen]); err != nil {
if debug {
@@ -908,7 +851,7 @@
return zx.ErrOk
}
-func (s *socketServer) opBind(ios *iostate, msg *zxsocket.Msg) (status zx.Status) {
+func (ios *iostate) opBind(msg *zxsocket.Msg) (status zx.Status) {
// TODO(tamird): are we really sending raw sockaddr_storage here? why aren't we using
// zxrio_sockaddr_reply? come to think of it, why does zxrio_sockaddr_reply exist?
addr, err := func() (tcpip.FullAddress, error) {
@@ -945,13 +888,13 @@
return zx.ErrOk
}
-func (s *socketServer) buildIfInfos() *C.netc_get_if_info_t {
+func (ios *iostate) buildIfInfos() *C.netc_get_if_info_t {
rep := &C.netc_get_if_info_t{}
- s.ns.mu.Lock()
- defer s.ns.mu.Unlock()
+ ios.ns.mu.Lock()
+ defer ios.ns.mu.Unlock()
var index C.uint
- for nicid, ifs := range s.ns.ifStates {
+ for nicid, ifs := range ios.ns.ifStates {
if ifs.nic.Addr == ipv4Loopback {
continue
}
@@ -997,11 +940,11 @@
// a race condition if the interface list changes between calls to ioctlNetcGetIfInfoAt.
var lastIfInfo *C.netc_get_if_info_t
-func (s *socketServer) opIoctl(ios *iostate, msg *zxsocket.Msg) zx.Status {
+func (ios *iostate) opIoctl(msg *zxsocket.Msg) zx.Status {
switch msg.IoctlOp() {
// TODO(ZX-766): remove when dart/runtime/bin/socket_base_fuchsia.cc uses getifaddrs().
case ioctlNetcGetNumIfs:
- lastIfInfo = s.buildIfInfos()
+ lastIfInfo = ios.buildIfInfos()
binary.LittleEndian.PutUint32(msg.Data[:msg.Arg], uint32(lastIfInfo.n_info))
msg.Datalen = 4
return zx.ErrOk
@@ -1037,7 +980,7 @@
msg.Datalen = uint32(n)
return zx.ErrOk
case ioctlNetcGetNodename:
- nodename := s.ns.getNodeName()
+ nodename := ios.ns.getNodeName()
msg.Datalen = uint32(copy(msg.Data[:msg.Arg], nodename))
msg.Data[msg.Datalen] = 0
return zx.ErrOk
@@ -1070,7 +1013,7 @@
return zx.ErrOk
}
-func (s *socketServer) opGetSockName(ios *iostate, msg *zxsocket.Msg) zx.Status {
+func (ios *iostate) opGetSockName(msg *zxsocket.Msg) zx.Status {
a, err := ios.ep.GetLocalAddress()
if err != nil {
return zxNetError(err)
@@ -1090,7 +1033,7 @@
return fdioSockAddrReply(a, msg)
}
-func (s *socketServer) opGetPeerName(ios *iostate, msg *zxsocket.Msg) (status zx.Status) {
+func (ios *iostate) opGetPeerName(msg *zxsocket.Msg) (status zx.Status) {
a, err := ios.ep.GetRemoteAddress()
if err != nil {
return zxNetError(err)
@@ -1098,7 +1041,7 @@
return fdioSockAddrReply(a, msg)
}
-func (s *socketServer) loopListen(ios *iostate, inCh chan struct{}) {
+func (ios *iostate) loopListen(inCh chan struct{}) {
// When an incoming connection is available, wait for the listening socket to
// enter a shareable state, then share it with the client.
for {
@@ -1152,7 +1095,7 @@
}
}
- localS, peerS, err := s.newIostate(ios.netProto, ios.transProto, newwq, newep, true)
+ localS, peerS, err := newIostate(ios.ns, ios.netProto, ios.transProto, newwq, newep, true)
if err != nil {
if debug {
log.Printf("listen: newIostate failed: %v", err)
@@ -1172,7 +1115,7 @@
}
}
-func (s *socketServer) opListen(ios *iostate, msg *zxsocket.Msg) (status zx.Status) {
+func (ios *iostate) opListen(msg *zxsocket.Msg) (status zx.Status) {
d := msg.Data[:msg.Datalen]
if len(d) != 4 {
if debug {
@@ -1201,7 +1144,7 @@
ios.loopListenDone = make(chan struct{})
go func() {
defer close(ios.loopListenDone)
- s.loopListen(ios, inCh)
+ ios.loopListen(inCh)
ios.wq.EventUnregister(&inEntry)
}()
@@ -1210,7 +1153,7 @@
return zx.ErrOk
}
-func (s *socketServer) opConnect(ios *iostate, msg *zxsocket.Msg) (status zx.Status) {
+func (ios *iostate) opConnect(msg *zxsocket.Msg) (status zx.Status) {
if msg.Datalen == 0 {
if ios.transProto == udp.ProtocolNumber {
// connect() can be called with no address to
@@ -1307,11 +1250,7 @@
return zx.ErrOk
}
-func (s *socketServer) opClose(ios *iostate, cookie cookie) zx.Status {
- s.mu.Lock()
- delete(s.io, cookie)
- s.mu.Unlock()
-
+func (ios *iostate) opClose(cookie int64) zx.Status {
// Signal that we're about to close. This tells the various message loops to finish
// processing, and let us know when they're done.
err := mxerror.Status(ios.dataHandle.Handle().Signal(0, LOCAL_SIGNAL_CLOSING))
@@ -1330,44 +1269,31 @@
return err
}
-func (s *socketServer) zxsocketHandler(msg *zxsocket.Msg, rh zx.Socket, cookieVal int64) zx.Status {
- cookie := cookie(cookieVal)
+func (ios *iostate) zxsocketHandler(msg *zxsocket.Msg, rh zx.Socket, cookie int64) zx.Status {
op := msg.Op()
if debug {
log.Printf("zxsocketHandler: op=%v, len=%d, arg=%v, hcount=%d", op, msg.Datalen, msg.Arg, msg.Hcount)
}
- s.mu.Lock()
- ios := s.io[cookie]
- s.mu.Unlock()
- if ios == nil {
- if op == zxsocket.OpClose && rh == 0 {
- // The close op was synthesized by Dispatcher (because the peer channel was closed).
- return zx.ErrOk
- }
- log.Printf("zxsioHandler: request (op:%v) dropped because of the state mismatch", op)
- return zx.ErrBadState
- }
-
switch op {
case zxsocket.OpConnect:
- return s.opConnect(ios, msg)
+ return ios.opConnect(msg)
case zxsocket.OpClose:
- return s.opClose(ios, cookie)
+ return ios.opClose(cookie)
case zxsocket.OpBind:
- return s.opBind(ios, msg)
+ return ios.opBind(msg)
case zxsocket.OpListen:
- return s.opListen(ios, msg)
+ return ios.opListen(msg)
case zxsocket.OpIoctl:
- return s.opIoctl(ios, msg)
+ return ios.opIoctl(msg)
case zxsocket.OpGetSockname:
- return s.opGetSockName(ios, msg)
+ return ios.opGetSockName(msg)
case zxsocket.OpGetPeerName:
- return s.opGetPeerName(ios, msg)
+ return ios.opGetPeerName(msg)
case zxsocket.OpGetSockOpt:
- return s.opGetSockOpt(ios, msg)
+ return ios.opGetSockOpt(msg)
case zxsocket.OpSetSockOpt:
- return s.opSetSockOpt(ios, msg)
+ return ios.opSetSockOpt(msg)
default:
log.Printf("zxsocketHandler: unknown socket op: %v", op)
return zx.ErrNotSupported