| // Copyright 2017 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
// Go's distribution tools attempt to compile everything; this file
// depends on zxwait, which doesn't compile outside of Fuchsia.
| //go:build fuchsia |
| // +build fuchsia |
| |
| package zxsocket |
| |
| import ( |
| "io" |
| "strconv" |
| "strings" |
| "syscall" |
| "syscall/zx" |
| "syscall/zx/fdio" |
| "syscall/zx/internal/context" |
| fidlIo "syscall/zx/io" |
| "syscall/zx/net" |
| "syscall/zx/posix/socket" |
| "syscall/zx/zxwait" |
| ) |
| |
| // These constants mirror those defined in |
| // https://cs.opensource.google/fuchsia/fuchsia/+/main:sdk/lib/fdio/socket.cc |
| const ( |
| SignalStreamIncoming = zx.SignalUser0 |
| SignalStreamConnected = zx.SignalUser3 |
| |
| SignalDatagramIncoming = zx.SignalUser0 |
| SignalDatagramOutgoing = zx.SignalUser1 |
| SignalDatagramError = zx.SignalUser2 |
| SignalDatagramShutdownRead = zx.SignalUser4 |
| SignalDatagramShutdownWrite = zx.SignalUser5 |
| ) |
| |
| var _ Socket = (*DatagramSocket)(nil) |
| var _ Socket = (*StreamSocket)(nil) |
| |
| // Socket is the common subset of datagram and stream sockets. |
| type Socket interface { |
| fdio.FDIO |
| Bind(net.SocketAddress) error |
| Connect(net.SocketAddress) error |
| GetPeerName() (net.SocketAddress, error) |
| GetSockName() (net.SocketAddress, error) |
| SetKeepAlive(bool) error |
| } |
| |
| // NewSocket creates a new Socket. |
| func NewSocket(base *socket.BaseSocketWithCtxInterface) (Socket, error) { |
| info, err := base.Describe(context.Background()) |
| if err != nil { |
| return nil, err |
| } |
| switch w := info.Which(); w { |
| case fidlIo.NodeInfoService, fidlIo.NodeInfoFile, fidlIo.NodeInfoDirectory, fidlIo.NodeInfoPipe, fidlIo.NodeInfoVmofile, fidlIo.NodeInfoDevice, fidlIo.NodeInfoTty: |
| return nil, &zx.Error{Status: zx.ErrInternal, Text: "zxsocket.NewSocket"} |
| case fidlIo.NodeInfoDatagramSocket: |
| return &DatagramSocket{ |
| client: socket.DatagramSocketWithCtxInterface{Channel: base.Channel}, |
| event: info.DatagramSocket.Event, |
| }, nil |
| case fidlIo.NodeInfoStreamSocket: |
| return &StreamSocket{ |
| client: socket.StreamSocketWithCtxInterface{Channel: base.Channel}, |
| socket: info.StreamSocket.Socket, |
| }, nil |
| default: |
| panic("unknown node info tag " + strconv.FormatInt(int64(w), 10)) |
| } |
| } |
| |
| type stub struct{} |
| |
| func (*stub) Sync() error { |
| return &zx.Error{Status: zx.ErrNotSupported, Text: "zxsocket.Socket.Sync"} |
| } |
| |
| func (*stub) GetAttr() (fidlIo.NodeAttributes, error) { |
| return fidlIo.NodeAttributes{}, &zx.Error{Status: zx.ErrNotSupported, Text: "zxsocket.Socket.GetAttr"} |
| } |
| |
| func (*stub) SetAttr(uint32, fidlIo.NodeAttributes) error { |
| return &zx.Error{Status: zx.ErrNotSupported, Text: "zxsocket.Socket.SetAttr"} |
| } |
| |
| func (*stub) ReadAt([]byte, int64) (int, error) { |
| return 0, &zx.Error{Status: zx.ErrNotSupported, Text: "zxsocket.Socket.ReadAt"} |
| } |
| |
| func (*stub) WriteAt([]byte, int64) (int, error) { |
| return 0, &zx.Error{Status: zx.ErrNotSupported, Text: "zxsocket.Socket.WriteAt"} |
| } |
| |
| func (*stub) Seek(int64, int) (int64, error) { |
| return 0, &zx.Error{Status: zx.ErrNotSupported, Text: "zxsocket.Socket.Seek"} |
| } |
| |
| func (*stub) Truncate(uint64) error { |
| return &zx.Error{Status: zx.ErrNotSupported, Text: "zxsocket.Socket.Truncate"} |
| } |
| |
| func (*stub) Open(string, uint32, uint32) (fdio.FDIO, error) { |
| return nil, &zx.Error{Status: zx.ErrNotSupported, Text: "zxsocket.Socket.Open"} |
| } |
| |
| func (*stub) Link(string, string) error { |
| return &zx.Error{Status: zx.ErrNotSupported, Text: "zxsocket.Socket.Link"} |
| } |
| |
| func (*stub) Rename(string, string) error { |
| return &zx.Error{Status: zx.ErrNotSupported, Text: "zxsocket.Socket.Rename"} |
| } |
| |
| func (*stub) Unlink(string) error { |
| return &zx.Error{Status: zx.ErrNotSupported, Text: "zxsocket.Socket.Unlink"} |
| } |
| |
| func (*stub) ReadDirents(uint64) ([]byte, error) { |
| return nil, &zx.Error{Status: zx.ErrNotSupported, Text: "zxsocket.Socket.ReadDirents"} |
| } |
| |
| func (*stub) Rewind() error { |
| return &zx.Error{Status: zx.ErrNotSupported, Text: "zxsocket.Socket.Rewind"} |
| } |
| |
| func clone(client socket.BaseSocketWithCtx) (Socket, error) { |
| req, obj, err := fidlIo.NewNodeWithCtxInterfaceRequest() |
| if err != nil { |
| return nil, err |
| } |
| if err := client.Clone(context.Background(), 0, req); err != nil { |
| return nil, err |
| } |
| return NewSocket((*socket.BaseSocketWithCtxInterface)(obj)) |
| } |
| |
| func bind(client socket.BaseSocketWithCtx, addr net.SocketAddress) error { |
| result, err := client.Bind(context.Background(), addr) |
| if err != nil { |
| return err |
| } |
| switch result.Which() { |
| case socket.BaseSocketBindResultErr: |
| return syscall.Errno(result.Err) |
| case socket.BaseSocketBindResultResponse: |
| return nil |
| default: |
| panic("unreachable") |
| } |
| } |
| |
| func connect(client socket.BaseSocketWithCtx, addr net.SocketAddress) error { |
| result, err := client.Connect(context.Background(), addr) |
| if err != nil { |
| return err |
| } |
| switch result.Which() { |
| case socket.BaseSocketConnectResultErr: |
| return syscall.Errno(result.Err) |
| case socket.BaseSocketConnectResultResponse: |
| return nil |
| default: |
| panic("unreachable") |
| } |
| } |
| |
| func getPeerName(client socket.BaseSocketWithCtx) (net.SocketAddress, error) { |
| result, err := client.GetPeerName(context.Background()) |
| if err != nil { |
| return net.SocketAddress{}, err |
| } |
| switch result.Which() { |
| case socket.BaseSocketGetPeerNameResultErr: |
| return net.SocketAddress{}, syscall.Errno(result.Err) |
| case socket.BaseSocketGetPeerNameResultResponse: |
| return result.Response.Addr, nil |
| default: |
| panic("unreachable") |
| } |
| } |
| |
| func getSockName(client socket.BaseSocketWithCtx) (net.SocketAddress, error) { |
| result, err := client.GetSockName(context.Background()) |
| if err != nil { |
| return net.SocketAddress{}, err |
| } |
| switch result.Which() { |
| case socket.BaseSocketGetSockNameResultErr: |
| return net.SocketAddress{}, syscall.Errno(result.Err) |
| case socket.BaseSocketGetSockNameResultResponse: |
| return result.Response.Addr, nil |
| default: |
| panic("unreachable") |
| } |
| } |
| |
| func setKeepAlive(client socket.BaseSocketWithCtx, keepalive bool) error { |
| result, err := client.SetKeepAlive(context.Background(), keepalive) |
| if err != nil { |
| return err |
| } |
| switch result.Which() { |
| case socket.BaseSocketSetKeepAliveResultErr: |
| return syscall.Errno(result.Err) |
| case socket.BaseSocketSetKeepAliveResultResponse: |
| return nil |
| default: |
| panic("unreachable") |
| } |
| } |
| |
// closeError collects the failures of the independent teardown steps that
// closeSocket performs. A nil field means that step succeeded.
type closeError struct {
	// fidl is the error from the FIDL Close call (transport or status).
	fidl error
	// channel holds the errors from waiting for the peer to close the FIDL
	// channel and from closing the channel itself.
	channel struct {
		wait  error
		close error
	}
	// handleClose is the error from closing the socket's secondary handle.
	handleClose error
}

// Error renders every failed step as "name=message", separated by single
// spaces, in the order fidl, channel.wait, channel.close, handleClose.
func (e *closeError) Error() string {
	steps := [...]struct {
		name string
		err  error
	}{
		{name: "fidl", err: e.fidl},
		{name: "channel.wait", err: e.channel.wait},
		{name: "channel.close", err: e.channel.close},
		{name: "handleClose", err: e.handleClose},
	}
	var b strings.Builder
	for _, step := range steps {
		if step.err == nil {
			continue
		}
		if b.Len() > 0 {
			b.WriteByte(' ')
		}
		b.WriteString(step.name)
		b.WriteByte('=')
		b.WriteString(step.err.Error())
	}
	return b.String()
}
| |
| func closeSocket(client socket.BaseSocketWithCtx, channel *zx.Channel, handle *zx.Handle) error { |
| var err closeError |
| err.fidl = func() error { |
| status, err := client.Close(context.Background()) |
| if err != nil { |
| return err |
| } |
| if status := zx.Status(status); status != zx.ErrOk { |
| return &zx.Error{Status: status, Text: "zxsocket.Socket.Close"} |
| } |
| return nil |
| }() |
| _, err.channel.wait = zxwait.WaitContext(context.Background(), *channel.Handle(), zx.SignalChannelPeerClosed) |
| err.channel.close = channel.Close() |
| err.handleClose = handle.Close() |
| if err.fidl == nil && err.channel.wait == nil && err.channel.close == nil && err.handleClose == nil { |
| return nil |
| } |
| return &err |
| } |
| |
| // DatagramSocket is a datagram socket. |
| type DatagramSocket struct { |
| stub |
| |
| client socket.DatagramSocketWithCtxInterface |
| event zx.Handle |
| } |
| |
| // Clone implements Socket. |
| func (s *DatagramSocket) Clone() (fdio.FDIO, error) { |
| return clone(&s.client) |
| } |
| |
| // Bind implements Socket. |
| func (s *DatagramSocket) Bind(addr net.SocketAddress) error { |
| return bind(&s.client, addr) |
| } |
| |
| // Connect implements Socket. |
| func (s *DatagramSocket) Connect(addr net.SocketAddress) error { |
| return connect(&s.client, addr) |
| } |
| |
| // GetPeerName implements Socket. |
| func (s *DatagramSocket) GetPeerName() (net.SocketAddress, error) { |
| return getPeerName(&s.client) |
| } |
| |
| // GetSockName implements Socket. |
| func (s *DatagramSocket) GetSockName() (net.SocketAddress, error) { |
| return getSockName(&s.client) |
| } |
| |
| // SetKeepAlive implements Socket. |
| func (s *DatagramSocket) SetKeepAlive(keepalive bool) error { |
| return setKeepAlive(&s.client, keepalive) |
| } |
| |
| func (s *DatagramSocket) recvMsg(dataLen uint32) (net.SocketAddress, []byte, error) { |
| for { |
| result, err := s.client.RecvMsg(context.Background() /* wantAddr */, true, dataLen /* wantControl */, false, 0) |
| if err != nil { |
| return net.SocketAddress{}, nil, err |
| } |
| switch result.Which() { |
| case socket.DatagramSocketRecvMsgResultErr: |
| if err := syscall.Errno(result.Err); err != syscall.EAGAIN { |
| return net.SocketAddress{}, nil, err |
| } |
| |
| obs, err := zxwait.WaitContext(context.Background(), s.event, SignalDatagramIncoming|SignalDatagramShutdownRead|zx.SignalEpairPeerClosed) |
| if err != nil { |
| return net.SocketAddress{}, nil, err |
| } |
| if obs&SignalDatagramIncoming != 0 { |
| continue |
| } |
| if obs&(SignalDatagramShutdownRead|zx.SignalEpairPeerClosed) != 0 { |
| return net.SocketAddress{}, nil, &zx.Error{Status: zx.ErrPeerClosed, Text: "zxsocket.DatagramSocket.RecvMsg"} |
| } |
| panic("unreachable") |
| |
| case socket.DatagramSocketRecvMsgResultResponse: |
| return *result.Response.Addr, result.Response.Data, nil |
| default: |
| panic("unreachable") |
| } |
| } |
| } |
| |
| // RecvMsg implements roughly the recvmsg "system call". Its signature resembles |
| // that of syscall.Revcvmsg on other platforms. |
| func (s *DatagramSocket) RecvMsg(maxLen int) ([]byte, net.SocketAddress, error) { |
| addr, data, err := s.recvMsg(uint32(maxLen)) |
| if err != nil { |
| return nil, net.SocketAddress{}, err |
| } |
| return data, addr, nil |
| } |
| |
| func (s *DatagramSocket) sendMsg(addr *net.SocketAddress, data []byte) (int, error) { |
| result, err := s.client.SendMsg(context.Background(), addr, data, socket.SendControlData{}, 0) |
| if err != nil { |
| return 0, err |
| } |
| switch result.Which() { |
| case socket.DatagramSocketSendMsgResultErr: |
| return 0, syscall.Errno(result.Err) |
| case socket.DatagramSocketSendMsgResultResponse: |
| return int(result.Response.Len), nil |
| default: |
| panic("unreachable") |
| } |
| } |
| |
| // SendMsg implements roughly the sendmsg "system call". Its signature resembles |
| // that of syscall.Sendmsg on other platforms. |
| func (s *DatagramSocket) SendMsg(b []byte, addr net.SocketAddress) (int, error) { |
| switch addr.Which() { |
| case net.SocketAddressIpv4, net.SocketAddressIpv6: |
| return s.sendMsg(&addr, b) |
| default: |
| return s.sendMsg(nil, b) |
| } |
| } |
| |
| // Close implements fdio.FDIO. |
| func (s *DatagramSocket) Close() error { |
| return closeSocket(&s.client, &s.client.Channel, &s.event) |
| } |
| |
| // Handles implements fdio.FDIO. |
| func (s *DatagramSocket) Handles() []zx.Handle { |
| return []zx.Handle{*s.client.Handle(), s.event} |
| } |
| |
| // Read implements fdio.FDIO. |
| func (s *DatagramSocket) Read(data []byte) (int, error) { |
| _, b, err := s.recvMsg(uint32(len(data))) |
| return copy(data, b), err |
| } |
| |
| // Write implements fdio.FDIO. |
| func (s *DatagramSocket) Write(data []byte) (int, error) { |
| return s.sendMsg(nil, data) |
| } |
| |
| // StreamSocket is a stream socket. |
| type StreamSocket struct { |
| stub |
| |
| client socket.StreamSocketWithCtxInterface |
| socket zx.Socket |
| } |
| |
| // Clone implements Socket. |
| func (s *StreamSocket) Clone() (fdio.FDIO, error) { |
| return clone(&s.client) |
| } |
| |
| // Bind implements Socket. |
| func (s *StreamSocket) Bind(addr net.SocketAddress) error { |
| return bind(&s.client, addr) |
| } |
| |
| // Connect implements Socket. |
| func (s *StreamSocket) Connect(addr net.SocketAddress) error { |
| return connect(&s.client, addr) |
| } |
| |
| // GetPeerName implements Socket. |
| func (s *StreamSocket) GetPeerName() (net.SocketAddress, error) { |
| return getPeerName(&s.client) |
| } |
| |
| // GetSockName implements Socket. |
| func (s *StreamSocket) GetSockName() (net.SocketAddress, error) { |
| return getSockName(&s.client) |
| } |
| |
| // SetKeepAlive implements Socket. |
| func (s *StreamSocket) SetKeepAlive(keepalive bool) error { |
| return setKeepAlive(&s.client, keepalive) |
| } |
| |
| // SetNoDelay sets the SOL_TCP, TCP_NODELAY socket option. |
| func (s *StreamSocket) SetNoDelay(noDelay bool) error { |
| result, err := s.client.SetTcpNoDelay(context.Background(), noDelay) |
| if err != nil { |
| return err |
| } |
| switch result.Which() { |
| case socket.StreamSocketSetTcpNoDelayResultErr: |
| return syscall.Errno(result.Err) |
| case socket.StreamSocketSetTcpNoDelayResultResponse: |
| return nil |
| default: |
| panic("unreachable") |
| } |
| } |
| |
| // SetKeepAlivePeriod sets the SOL_TCP, {TCP_KEEPINTVL,TCP_KEEPIDLE} socket options. |
| func (s *StreamSocket) SetKeepAlivePeriod(seconds uint32) error { |
| { |
| result, err := s.client.SetTcpKeepAliveInterval(context.Background(), seconds) |
| if err != nil { |
| return err |
| } |
| switch result.Which() { |
| case socket.StreamSocketSetTcpKeepAliveIntervalResultErr: |
| return syscall.Errno(result.Err) |
| case socket.StreamSocketSetTcpKeepAliveIntervalResultResponse: |
| default: |
| panic("unreachable") |
| } |
| } |
| { |
| result, err := s.client.SetTcpKeepAliveIdle(context.Background(), seconds) |
| if err != nil { |
| return err |
| } |
| switch result.Which() { |
| case socket.StreamSocketSetTcpKeepAliveIdleResultErr: |
| return syscall.Errno(result.Err) |
| case socket.StreamSocketSetTcpKeepAliveIdleResultResponse: |
| default: |
| panic("unreachable") |
| } |
| } |
| return nil |
| } |
| |
| // Accept accepts an incoming connection. |
| func (s *StreamSocket) Accept(wantAddr bool) (*StreamSocket, *net.SocketAddress, error) { |
| result, err := s.client.Accept(context.Background(), wantAddr) |
| if err != nil { |
| return nil, nil, err |
| } |
| switch result.Which() { |
| case socket.StreamSocketAcceptResultErr: |
| return nil, nil, syscall.Errno(result.Err) |
| case socket.StreamSocketAcceptResultResponse: |
| newS, err := NewSocket(&socket.BaseSocketWithCtxInterface{Channel: result.Response.S.Channel}) |
| if err != nil { |
| return nil, nil, err |
| } |
| return newS.(*StreamSocket), result.Response.Addr, nil |
| default: |
| panic("unreachable") |
| } |
| } |
| |
| // Listen begins listening for incoming connections. |
| func (s *StreamSocket) Listen(backlog int16) error { |
| result, err := s.client.Listen(context.Background(), backlog) |
| if err != nil { |
| return err |
| } |
| switch result.Which() { |
| case socket.StreamSocketListenResultErr: |
| return syscall.Errno(result.Err) |
| case socket.StreamSocketListenResultResponse: |
| return nil |
| default: |
| panic("unreachable") |
| } |
| } |
| |
| // Wait waits on the receiver's socket handle until any of the provided signals |
| // are asserted, or until the timeout elapses. |
| func (s *StreamSocket) Wait(ctx context.Context, signals zx.Signals) (zx.Signals, error) { |
| return zxwait.WaitContext(ctx, *s.socket.Handle(), signals) |
| } |
| |
| // Close implements fdio.FDIO. |
| func (s *StreamSocket) Close() error { |
| return closeSocket(&s.client, &s.client.Channel, s.socket.Handle()) |
| } |
| |
| // Handles implements fdio.FDIO. |
| func (s *StreamSocket) Handles() []zx.Handle { |
| return []zx.Handle{*s.client.Handle(), *s.socket.Handle()} |
| } |
| |
| // Read implements fdio.FDIO. |
| func (s *StreamSocket) Read(data []byte) (int, error) { |
| for { |
| n, err := s.socket.Read(data, 0) |
| if err != nil { |
| if err, ok := err.(*zx.Error); ok { |
| switch err.Status { |
| case zx.ErrPeerClosed: |
| return 0, io.EOF |
| case zx.ErrShouldWait: |
| obs, err := s.Wait(context.Background(), zx.SignalSocketReadable|zx.SignalSocketPeerClosed) |
| if err != nil { |
| if err, ok := err.(*zx.Error); ok { |
| switch err.Status { |
| case zx.ErrBadHandle, zx.ErrCanceled: |
| return 0, io.EOF |
| } |
| } |
| return 0, err |
| } |
| switch { |
| case obs&zx.SignalSocketReadable != 0: |
| continue |
| case obs&zx.SignalSocketPeerClosed != 0: |
| return 0, io.EOF |
| } |
| } |
| } |
| return 0, err |
| } |
| return n, nil |
| } |
| } |
| |
| // Write implements fdio.FDIO. |
| func (s *StreamSocket) Write(data []byte) (int, error) { |
| var total int |
| for { |
| n, err := s.socket.Write(data, 0) |
| total += n |
| if err != nil { |
| if err, ok := err.(*zx.Error); ok { |
| switch err.Status { |
| case zx.ErrShouldWait: |
| obs, err := s.Wait(context.Background(), zx.SignalSocketWritable|zx.SignalSocketPeerClosed|zx.SignalSocketWriteDisabled) |
| if err != nil { |
| return total, err |
| } |
| if obs&zx.SignalSocketPeerClosed != 0 || obs&zx.SignalSocketWriteDisabled != 0 { |
| return total, &zx.Error{Status: zx.ErrPeerClosed, Text: "zxsocket.Socket.Write"} |
| } |
| if obs&zx.SignalSocketWritable != 0 { |
| data = data[n:] |
| continue |
| } |
| // This case should be impossible: |
| return total, &zx.Error{Status: zx.ErrInternal, Text: "zxsocket.Socket.Write(impossible state)"} |
| } |
| } |
| return total, err |
| } |
| return total, nil |
| } |
| } |