blob: 1c53ea0f169e854299f856863d4ff4f4ffb5cfd4 [file] [log] [blame]
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Go's distribution tools attempt to compile every file; this file depends on
// zxwait, which doesn't compile on non-Fuchsia platforms, hence the Fuchsia-only
// build constraint below.
//go:build fuchsia
// +build fuchsia
package zxsocket
import (
"io"
"strconv"
"strings"
"syscall"
"syscall/zx"
"syscall/zx/fdio"
"syscall/zx/internal/context"
fidlIo "syscall/zx/io"
"syscall/zx/net"
"syscall/zx/posix/socket"
"syscall/zx/zxwait"
)
// These constants mirror those defined in
// https://cs.opensource.google/fuchsia/fuchsia/+/main:sdk/lib/fdio/socket.cc
//
// Stream socket signals.
const (
	SignalStreamIncoming  = zx.SignalUser0
	SignalStreamConnected = zx.SignalUser3
)

// Datagram socket signals. DatagramSocket observes these on the event handle
// returned by the server's Describe response.
const (
	SignalDatagramIncoming      = zx.SignalUser0
	SignalDatagramOutgoing      = zx.SignalUser1
	SignalDatagramError         = zx.SignalUser2
	SignalDatagramShutdownRead  = zx.SignalUser4
	SignalDatagramShutdownWrite = zx.SignalUser5
)
// Compile-time checks that both concrete socket types satisfy Socket.
var (
	_ Socket = (*DatagramSocket)(nil)
	_ Socket = (*StreamSocket)(nil)
)

// Socket is the common subset of datagram and stream sockets: the whole of
// fdio.FDIO plus the address-level operations both kinds support.
type Socket interface {
	fdio.FDIO

	Bind(addr net.SocketAddress) error
	Connect(addr net.SocketAddress) error
	GetPeerName() (net.SocketAddress, error)
	GetSockName() (net.SocketAddress, error)
	SetKeepAlive(enabled bool) error
}
// NewSocket creates a new Socket from base. It issues a Describe call and
// selects the implementation from the returned node info:
//   - a datagram socket becomes a *DatagramSocket holding the reported event;
//   - a stream socket becomes a *StreamSocket holding the reported zircon socket.
//
// In both cases the returned value shares base's channel. Node kinds that are
// not sockets yield a zx.ErrInternal error, and an unrecognized tag panics.
// NOTE(review): base's channel is not closed on any error path here, so
// ownership presumably stays with the caller — confirm.
func NewSocket(base *socket.BaseSocketWithCtxInterface) (Socket, error) {
	info, err := base.Describe(context.Background())
	if err != nil {
		return nil, err
	}
	tag := info.Which()
	switch tag {
	case fidlIo.NodeInfoDatagramSocket:
		return &DatagramSocket{
			event:  info.DatagramSocket.Event,
			client: socket.DatagramSocketWithCtxInterface{Channel: base.Channel},
		}, nil
	case fidlIo.NodeInfoStreamSocket:
		return &StreamSocket{
			socket: info.StreamSocket.Socket,
			client: socket.StreamSocketWithCtxInterface{Channel: base.Channel},
		}, nil
	case fidlIo.NodeInfoService,
		fidlIo.NodeInfoFile,
		fidlIo.NodeInfoDirectory,
		fidlIo.NodeInfoPipe,
		fidlIo.NodeInfoVmofile,
		fidlIo.NodeInfoDevice,
		fidlIo.NodeInfoTty:
		// Described as a node, but not a socket.
		return nil, &zx.Error{Status: zx.ErrInternal, Text: "zxsocket.NewSocket"}
	}
	panic("unknown node info tag " + strconv.FormatInt(int64(tag), 10))
}
// stub provides the file-oriented operations that have no meaning for
// sockets (presumably required by fdio.FDIO). Every one of them fails with
// zx.ErrNotSupported. DatagramSocket and StreamSocket embed it.
type stub struct{}

// unsupported builds the zx.ErrNotSupported error that every stub method
// returns; text identifies the operation.
func (*stub) unsupported(text string) error {
	return &zx.Error{Status: zx.ErrNotSupported, Text: text}
}

func (s *stub) Sync() error {
	return s.unsupported("zxsocket.Socket.Sync")
}

func (s *stub) GetAttr() (fidlIo.NodeAttributes, error) {
	return fidlIo.NodeAttributes{}, s.unsupported("zxsocket.Socket.GetAttr")
}

func (s *stub) SetAttr(uint32, fidlIo.NodeAttributes) error {
	return s.unsupported("zxsocket.Socket.SetAttr")
}

func (s *stub) ReadAt([]byte, int64) (int, error) {
	return 0, s.unsupported("zxsocket.Socket.ReadAt")
}

func (s *stub) WriteAt([]byte, int64) (int, error) {
	return 0, s.unsupported("zxsocket.Socket.WriteAt")
}

func (s *stub) Seek(int64, int) (int64, error) {
	return 0, s.unsupported("zxsocket.Socket.Seek")
}

func (s *stub) Truncate(uint64) error {
	return s.unsupported("zxsocket.Socket.Truncate")
}

func (s *stub) Open(string, uint32, uint32) (fdio.FDIO, error) {
	return nil, s.unsupported("zxsocket.Socket.Open")
}

func (s *stub) Link(string, string) error {
	return s.unsupported("zxsocket.Socket.Link")
}

func (s *stub) Rename(string, string) error {
	return s.unsupported("zxsocket.Socket.Rename")
}

func (s *stub) Unlink(string) error {
	return s.unsupported("zxsocket.Socket.Unlink")
}

func (s *stub) ReadDirents(uint64) ([]byte, error) {
	return nil, s.unsupported("zxsocket.Socket.ReadDirents")
}

func (s *stub) Rewind() error {
	return s.unsupported("zxsocket.Socket.Rewind")
}
// clone asks the server behind client to clone the connection (flags 0).
// It wraps the new node channel with NewSocket.
func clone(client socket.BaseSocketWithCtx) (Socket, error) {
	req, proxy, err := fidlIo.NewNodeWithCtxInterfaceRequest()
	if err == nil {
		err = client.Clone(context.Background(), 0, req)
	}
	if err != nil {
		return nil, err
	}
	return NewSocket((*socket.BaseSocketWithCtxInterface)(proxy))
}
// bind issues a Bind request for addr. An error from the FIDL call itself
// is returned unchanged; an error reported by the server comes back as
// syscall.Errno.
func bind(client socket.BaseSocketWithCtx, addr net.SocketAddress) error {
	res, callErr := client.Bind(context.Background(), addr)
	if callErr != nil {
		return callErr
	}
	switch res.Which() {
	case socket.BaseSocketBindResultResponse:
		return nil
	case socket.BaseSocketBindResultErr:
		return syscall.Errno(res.Err)
	}
	panic("unreachable")
}
// connect issues a Connect request for addr. An error from the FIDL call
// itself is returned unchanged; an error reported by the server comes back
// as syscall.Errno.
func connect(client socket.BaseSocketWithCtx, addr net.SocketAddress) error {
	res, callErr := client.Connect(context.Background(), addr)
	if callErr != nil {
		return callErr
	}
	switch res.Which() {
	case socket.BaseSocketConnectResultResponse:
		return nil
	case socket.BaseSocketConnectResultErr:
		return syscall.Errno(res.Err)
	}
	panic("unreachable")
}
// getPeerName fetches the peer address via a GetPeerName request. On any
// failure the zero SocketAddress is returned. Server-reported errors are
// converted to syscall.Errno.
func getPeerName(client socket.BaseSocketWithCtx) (net.SocketAddress, error) {
	var addr net.SocketAddress
	res, callErr := client.GetPeerName(context.Background())
	if callErr != nil {
		return addr, callErr
	}
	switch res.Which() {
	case socket.BaseSocketGetPeerNameResultResponse:
		addr = res.Response.Addr
		return addr, nil
	case socket.BaseSocketGetPeerNameResultErr:
		return addr, syscall.Errno(res.Err)
	}
	panic("unreachable")
}
// getSockName fetches the local address via a GetSockName request. On any
// failure the zero SocketAddress is returned. Server-reported errors are
// converted to syscall.Errno.
func getSockName(client socket.BaseSocketWithCtx) (net.SocketAddress, error) {
	var addr net.SocketAddress
	res, callErr := client.GetSockName(context.Background())
	if callErr != nil {
		return addr, callErr
	}
	switch res.Which() {
	case socket.BaseSocketGetSockNameResultResponse:
		addr = res.Response.Addr
		return addr, nil
	case socket.BaseSocketGetSockNameResultErr:
		return addr, syscall.Errno(res.Err)
	}
	panic("unreachable")
}
// setKeepAlive forwards keepalive to the server via SetKeepAlive. An error
// from the FIDL call itself is returned unchanged; a server-reported error
// comes back as syscall.Errno.
func setKeepAlive(client socket.BaseSocketWithCtx, keepalive bool) error {
	res, callErr := client.SetKeepAlive(context.Background(), keepalive)
	if callErr != nil {
		return callErr
	}
	switch res.Which() {
	case socket.BaseSocketSetKeepAliveResultResponse:
		return nil
	case socket.BaseSocketSetKeepAliveResultErr:
		return syscall.Errno(res.Err)
	}
	panic("unreachable")
}
// closeError collects every failure encountered while tearing down a socket.
// Each field is nil when the corresponding step succeeded.
type closeError struct {
	fidl    error // the FIDL Close call or the status it returned
	channel struct {
		wait  error // waiting for the channel peer to close
		close error // closing the channel handle
	}
	handleClose error // closing the auxiliary handle
}

// Error implements error. Each non-nil component is rendered as key=message.
// Components appear in the fixed order fidl, channel.wait, channel.close,
// handleClose and are separated by single spaces.
func (err *closeError) Error() string {
	parts := [...]struct {
		key string
		val error
	}{
		{"fidl=", err.fidl},
		{"channel.wait=", err.channel.wait},
		{"channel.close=", err.channel.close},
		{"handleClose=", err.handleClose},
	}
	var sb strings.Builder
	for _, p := range parts {
		if p.val == nil {
			continue
		}
		if sb.Len() > 0 {
			sb.WriteByte(' ')
		}
		sb.WriteString(p.key)
		sb.WriteString(p.val.Error())
	}
	return sb.String()
}
// closeSocket tears down a socket and reports every failure. It always runs
// all four steps in this order, even if an earlier one fails:
//  1. send a FIDL Close to the server and check the returned status;
//  2. wait on channel for zx.SignalChannelPeerClosed;
//  3. close channel;
//  4. close handle.
//
// It returns nil when every step succeeded. Otherwise it returns a
// *closeError recording each step that failed.
func closeSocket(client socket.BaseSocketWithCtx, channel *zx.Channel, handle *zx.Handle) error {
	var failures closeError

	status, fidlErr := client.Close(context.Background())
	switch {
	case fidlErr != nil:
		failures.fidl = fidlErr
	case zx.Status(status) != zx.ErrOk:
		failures.fidl = &zx.Error{Status: zx.Status(status), Text: "zxsocket.Socket.Close"}
	}

	_, failures.channel.wait = zxwait.WaitContext(context.Background(), *channel.Handle(), zx.SignalChannelPeerClosed)
	failures.channel.close = channel.Close()
	failures.handleClose = handle.Close()

	if failures.fidl != nil || failures.channel.wait != nil || failures.channel.close != nil || failures.handleClose != nil {
		return &failures
	}
	return nil
}
// DatagramSocket is a datagram socket. All operations, including message
// transfer, go over the FIDL client. Readiness is observed on event, where
// the SignalDatagram* signals are waited for.
type DatagramSocket struct {
	stub
	client socket.DatagramSocketWithCtxInterface // FIDL connection to the server
	event  zx.Handle                              // carries SignalDatagram* signals
}

// Clone implements Socket by cloning the underlying FIDL connection.
func (s *DatagramSocket) Clone() (fdio.FDIO, error) { return clone(&s.client) }

// Bind implements Socket.
func (s *DatagramSocket) Bind(addr net.SocketAddress) error { return bind(&s.client, addr) }

// Connect implements Socket.
func (s *DatagramSocket) Connect(addr net.SocketAddress) error { return connect(&s.client, addr) }

// GetPeerName implements Socket.
func (s *DatagramSocket) GetPeerName() (net.SocketAddress, error) { return getPeerName(&s.client) }

// GetSockName implements Socket.
func (s *DatagramSocket) GetSockName() (net.SocketAddress, error) { return getSockName(&s.client) }

// SetKeepAlive implements Socket.
func (s *DatagramSocket) SetKeepAlive(keepalive bool) error { return setKeepAlive(&s.client, keepalive) }
// recvMsg receives one datagram together with the sender's address,
// requesting at most dataLen bytes of payload from the server.
//
// While the server reports EAGAIN, it blocks on the event handle:
//   - SignalDatagramIncoming means it should retry;
//   - SignalDatagramShutdownRead or zx.SignalEpairPeerClosed means it fails
//     with a *zx.Error of status zx.ErrPeerClosed.
//
// Any other server error is returned as syscall.Errno. If the server omits
// the (optional) address even though one was requested, the zero
// net.SocketAddress is returned instead of dereferencing a nil pointer.
func (s *DatagramSocket) recvMsg(dataLen uint32) (net.SocketAddress, []byte, error) {
	for {
		result, err := s.client.RecvMsg(context.Background(), /* wantAddr */ true, dataLen, /* wantControl */ false, 0)
		if err != nil {
			return net.SocketAddress{}, nil, err
		}
		switch result.Which() {
		case socket.DatagramSocketRecvMsgResultErr:
			if err := syscall.Errno(result.Err); err != syscall.EAGAIN {
				return net.SocketAddress{}, nil, err
			}
			// Nothing queued yet: wait for data or for the read side to go away.
			obs, err := zxwait.WaitContext(context.Background(), s.event, SignalDatagramIncoming|SignalDatagramShutdownRead|zx.SignalEpairPeerClosed)
			if err != nil {
				return net.SocketAddress{}, nil, err
			}
			if obs&SignalDatagramIncoming != 0 {
				continue
			}
			if obs&(SignalDatagramShutdownRead|zx.SignalEpairPeerClosed) != 0 {
				return net.SocketAddress{}, nil, &zx.Error{Status: zx.ErrPeerClosed, Text: "zxsocket.DatagramSocket.RecvMsg"}
			}
			panic("unreachable")
		case socket.DatagramSocketRecvMsgResultResponse:
			// Addr is optional in the response; guard against a nil pointer.
			var addr net.SocketAddress
			if result.Response.Addr != nil {
				addr = *result.Response.Addr
			}
			return addr, result.Response.Data, nil
		default:
			panic("unreachable")
		}
	}
}
// RecvMsg roughly implements the recvmsg "system call"; its signature
// resembles that of syscall.Recvmsg on other platforms. It returns the next
// datagram's payload, requesting at most maxLen bytes, along with the
// sender's address. On error the payload is nil and the address is zero.
// NOTE(review): a negative maxLen wraps when converted to uint32.
func (s *DatagramSocket) RecvMsg(maxLen int) ([]byte, net.SocketAddress, error) {
	from, payload, err := s.recvMsg(uint32(maxLen))
	if err != nil {
		from, payload = net.SocketAddress{}, nil
	}
	return payload, from, err
}
// sendMsg sends data as one datagram, to addr when it is non-nil. No control
// data and no flags are sent. It returns the length the server reports.
// Server-reported errors are converted to syscall.Errno.
func (s *DatagramSocket) sendMsg(addr *net.SocketAddress, data []byte) (int, error) {
	res, callErr := s.client.SendMsg(context.Background(), addr, data, socket.SendControlData{}, 0)
	if callErr != nil {
		return 0, callErr
	}
	switch res.Which() {
	case socket.DatagramSocketSendMsgResultResponse:
		return int(res.Response.Len), nil
	case socket.DatagramSocketSendMsgResultErr:
		return 0, syscall.Errno(res.Err)
	}
	panic("unreachable")
}
// SendMsg roughly implements the sendmsg "system call"; its signature
// resembles that of syscall.Sendmsg on other platforms. addr is used as the
// destination only when it holds an IPv4 or IPv6 address. Otherwise the
// datagram is sent without an explicit destination.
func (s *DatagramSocket) SendMsg(b []byte, addr net.SocketAddress) (int, error) {
	var dst *net.SocketAddress
	if w := addr.Which(); w == net.SocketAddressIpv4 || w == net.SocketAddressIpv6 {
		dst = &addr
	}
	return s.sendMsg(dst, b)
}
// Close implements fdio.FDIO. It closes the socket via closeSocket, which
// tears down the FIDL channel and the event handle.
func (s *DatagramSocket) Close() error {
	ch := &s.client.Channel
	return closeSocket(&s.client, ch, &s.event)
}

// Handles implements fdio.FDIO. It returns the channel handle followed by
// the event handle.
func (s *DatagramSocket) Handles() []zx.Handle {
	hs := make([]zx.Handle, 0, 2)
	hs = append(hs, *s.client.Handle(), s.event)
	return hs
}

// Read implements fdio.FDIO. It receives one datagram of at most len(data)
// bytes into data and discards the sender address.
func (s *DatagramSocket) Read(data []byte) (int, error) {
	_, payload, err := s.recvMsg(uint32(len(data)))
	n := copy(data, payload)
	return n, err
}

// Write implements fdio.FDIO. It sends data as one datagram without an
// explicit destination (presumably reaching the connected peer).
func (s *DatagramSocket) Write(data []byte) (int, error) {
	var noDest *net.SocketAddress
	return s.sendMsg(noDest, data)
}
// StreamSocket is a stream socket. Control operations go over the FIDL
// client, while Read and Write move payload bytes through the zircon socket.
type StreamSocket struct {
	stub
	client socket.StreamSocketWithCtxInterface // FIDL connection to the server
	socket zx.Socket                            // data path for Read/Write
}

// Clone implements Socket by cloning the underlying FIDL connection.
func (s *StreamSocket) Clone() (fdio.FDIO, error) { return clone(&s.client) }

// Bind implements Socket.
func (s *StreamSocket) Bind(addr net.SocketAddress) error { return bind(&s.client, addr) }

// Connect implements Socket.
func (s *StreamSocket) Connect(addr net.SocketAddress) error { return connect(&s.client, addr) }

// GetPeerName implements Socket.
func (s *StreamSocket) GetPeerName() (net.SocketAddress, error) { return getPeerName(&s.client) }

// GetSockName implements Socket.
func (s *StreamSocket) GetSockName() (net.SocketAddress, error) { return getSockName(&s.client) }

// SetKeepAlive implements Socket.
func (s *StreamSocket) SetKeepAlive(keepalive bool) error { return setKeepAlive(&s.client, keepalive) }
// SetNoDelay sets the SOL_TCP, TCP_NODELAY socket option to noDelay.
// Server-reported errors are converted to syscall.Errno.
func (s *StreamSocket) SetNoDelay(noDelay bool) error {
	res, callErr := s.client.SetTcpNoDelay(context.Background(), noDelay)
	if callErr != nil {
		return callErr
	}
	switch res.Which() {
	case socket.StreamSocketSetTcpNoDelayResultResponse:
		return nil
	case socket.StreamSocketSetTcpNoDelayResultErr:
		return syscall.Errno(res.Err)
	}
	panic("unreachable")
}
// SetKeepAlivePeriod sets both SOL_TCP TCP_KEEPINTVL and TCP_KEEPIDLE to
// seconds, in that order. If setting the interval fails, the idle time is
// not touched. Server-reported errors are converted to syscall.Errno.
func (s *StreamSocket) SetKeepAlivePeriod(seconds uint32) error {
	ctx := context.Background()

	intvl, err := s.client.SetTcpKeepAliveInterval(ctx, seconds)
	if err != nil {
		return err
	}
	switch intvl.Which() {
	case socket.StreamSocketSetTcpKeepAliveIntervalResultResponse:
		// Fall through to the idle setting.
	case socket.StreamSocketSetTcpKeepAliveIntervalResultErr:
		return syscall.Errno(intvl.Err)
	default:
		panic("unreachable")
	}

	idle, err := s.client.SetTcpKeepAliveIdle(ctx, seconds)
	if err != nil {
		return err
	}
	switch idle.Which() {
	case socket.StreamSocketSetTcpKeepAliveIdleResultResponse:
		return nil
	case socket.StreamSocketSetTcpKeepAliveIdleResultErr:
		return syscall.Errno(idle.Err)
	default:
		panic("unreachable")
	}
}
// Accept accepts an incoming connection and returns it as a *StreamSocket.
// wantAddr is forwarded to the server. The returned address is whatever the
// server reported (presumably nil when wantAddr is false).
//
// If NewSocket cannot wrap the accepted channel, the channel is closed
// before the error is returned; NewSocket itself does not release it. If
// the server describes the new connection as something other than a stream
// socket, that socket is closed and a zx.ErrInternal error is returned
// instead of panicking on the type assertion.
func (s *StreamSocket) Accept(wantAddr bool) (*StreamSocket, *net.SocketAddress, error) {
	result, err := s.client.Accept(context.Background(), wantAddr)
	if err != nil {
		return nil, nil, err
	}
	switch result.Which() {
	case socket.StreamSocketAcceptResultErr:
		return nil, nil, syscall.Errno(result.Err)
	case socket.StreamSocketAcceptResultResponse:
		base := socket.BaseSocketWithCtxInterface{Channel: result.Response.S.Channel}
		newS, err := NewSocket(&base)
		if err != nil {
			// We own the accepted channel; release it. The close error is
			// secondary to err and deliberately dropped.
			_ = base.Channel.Close()
			return nil, nil, err
		}
		stream, ok := newS.(*StreamSocket)
		if !ok {
			// Best-effort cleanup of the unexpected socket kind.
			_ = newS.Close()
			return nil, nil, &zx.Error{Status: zx.ErrInternal, Text: "zxsocket.StreamSocket.Accept"}
		}
		return stream, result.Response.Addr, nil
	default:
		panic("unreachable")
	}
}
// Listen begins listening for incoming connections. backlog is passed
// through to the server (presumably the pending-connection limit, as in
// listen(2)). Server-reported errors are converted to syscall.Errno.
func (s *StreamSocket) Listen(backlog int16) error {
	res, callErr := s.client.Listen(context.Background(), backlog)
	if callErr != nil {
		return callErr
	}
	switch res.Which() {
	case socket.StreamSocketListenResultResponse:
		return nil
	case socket.StreamSocketListenResultErr:
		return syscall.Errno(res.Err)
	}
	panic("unreachable")
}
// Wait blocks on the receiver's zircon socket handle until at least one of
// signals is asserted, and returns what zxwait.WaitContext reports. There is
// no separate timeout: ctx is the only way to bound or cancel the wait.
func (s *StreamSocket) Wait(ctx context.Context, signals zx.Signals) (zx.Signals, error) {
	return zxwait.WaitContext(ctx, *s.socket.Handle(), signals)
}
// Close implements fdio.FDIO. It closes the socket via closeSocket, which
// tears down the FIDL channel and the zircon socket handle.
func (s *StreamSocket) Close() error {
	sockHandle := s.socket.Handle()
	return closeSocket(&s.client, &s.client.Channel, sockHandle)
}

// Handles implements fdio.FDIO. It returns the channel handle followed by
// the zircon socket handle.
func (s *StreamSocket) Handles() []zx.Handle {
	hs := make([]zx.Handle, 0, 2)
	hs = append(hs, *s.client.Handle(), *s.socket.Handle())
	return hs
}
// Read implements fdio.FDIO. It reads from the zircon socket, blocking while
// no data is buffered.
//
// It returns io.EOF in these cases:
//   - the read reports zx.ErrPeerClosed;
//   - the wait observes zx.SignalSocketPeerClosed without readability;
//   - the wait itself fails with zx.ErrBadHandle or zx.ErrCanceled
//     (presumably a concurrent Close).
//
// Any other error is returned as-is.
func (s *StreamSocket) Read(data []byte) (int, error) {
	for {
		n, readErr := s.socket.Read(data, 0)
		if readErr == nil {
			return n, nil
		}
		zxErr, isZx := readErr.(*zx.Error)
		if !isZx {
			return 0, readErr
		}
		switch zxErr.Status {
		case zx.ErrPeerClosed:
			return 0, io.EOF
		case zx.ErrShouldWait:
			// Handled below.
		default:
			return 0, readErr
		}

		// Nothing buffered: block until data arrives or the peer goes away.
		obs, waitErr := s.Wait(context.Background(), zx.SignalSocketReadable|zx.SignalSocketPeerClosed)
		if waitErr != nil {
			if we, ok := waitErr.(*zx.Error); ok && (we.Status == zx.ErrBadHandle || we.Status == zx.ErrCanceled) {
				return 0, io.EOF
			}
			return 0, waitErr
		}
		// Readability wins over peer-closed so buffered data is drained first.
		if obs&zx.SignalSocketReadable != 0 {
			continue
		}
		if obs&zx.SignalSocketPeerClosed != 0 {
			return 0, io.EOF
		}
		return 0, readErr
	}
}
// Write implements fdio.FDIO. It writes all of data to the zircon socket,
// blocking while the socket is full, and returns the number of bytes written.
//
// A stream-socket write may accept only part of the buffer and still
// succeed, so the remainder is retried until everything is written. As
// io.Writer requires, a count smaller than len(data) is always paired with a
// non-nil error. If the peer has closed or writing has been disabled, a
// *zx.Error with status zx.ErrPeerClosed is returned.
func (s *StreamSocket) Write(data []byte) (int, error) {
	var total int
	for {
		n, err := s.socket.Write(data, 0)
		total += n
		data = data[n:]
		if err == nil {
			if len(data) == 0 {
				return total, nil
			}
			if n == 0 {
				// No progress and no error: report it rather than spin.
				return total, io.ErrShortWrite
			}
			// Short write; retry with the remainder. If the socket is now full
			// the next write reports zx.ErrShouldWait and we block below.
			continue
		}
		if zxErr, ok := err.(*zx.Error); ok && zxErr.Status == zx.ErrShouldWait {
			obs, waitErr := s.Wait(context.Background(), zx.SignalSocketWritable|zx.SignalSocketPeerClosed|zx.SignalSocketWriteDisabled)
			if waitErr != nil {
				return total, waitErr
			}
			if obs&(zx.SignalSocketPeerClosed|zx.SignalSocketWriteDisabled) != 0 {
				return total, &zx.Error{Status: zx.ErrPeerClosed, Text: "zxsocket.Socket.Write"}
			}
			if obs&zx.SignalSocketWritable != 0 {
				continue
			}
			// This case should be impossible:
			return total, &zx.Error{Status: zx.ErrInternal, Text: "zxsocket.Socket.Write(impossible state)"}
		}
		return total, err
	}
}