blob: 6fd09ef369b89d2618d570c35f6a6e0a1e36f96f [file] [log] [blame]
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Go's distribution tools attempt to compile everything; this file
// depends on zxwait, which only compiles on Fuchsia.
//go:build fuchsia
package zxsocket
import (
stdio "io"
"strings"
"syscall"
"syscall/zx"
"syscall/zx/fdio"
"syscall/zx/internal/context"
"syscall/zx/io"
"syscall/zx/net"
"syscall/zx/posix/socket"
"syscall/zx/unknown"
"syscall/zx/zxwait"
)
// Signals that recvMsg waits for on a DatagramSocket's event handle.
const (
	signalDatagramIncoming     = zx.Signals(socket.SignalDatagramIncoming)
	signalDatagramShutdownRead = zx.Signals(socket.SignalDatagramShutdownRead)
)

// Compile-time checks that both socket kinds satisfy Socket.
var (
	_ Socket = (*DatagramSocket)(nil)
	_ Socket = (*StreamSocket)(nil)
)
// Socket is the common subset of datagram and stream sockets.
//
// Both *DatagramSocket and *StreamSocket implement it. It embeds fdio.FDIO
// and adds the address and option operations below.
type Socket interface {
	fdio.FDIO

	// Address operations.
	Bind(net.SocketAddress) error
	Connect(net.SocketAddress) error
	GetPeerName() (net.SocketAddress, error)
	GetSockName() (net.SocketAddress, error)

	// Socket options.
	SetKeepAlive(bool) error
}
// stub provides failing implementations of the file-oriented operations that
// sockets do not support. DatagramSocket and StreamSocket embed it. Every
// method returns a *zx.Error whose Status is zx.ErrNotSupported and whose Text
// names the rejected operation.
type stub struct{}

// notSupported builds the error returned by every stub method. The text
// argument identifies the rejected operation.
func notSupported(text string) error {
	return &zx.Error{Status: zx.ErrNotSupported, Text: text}
}

func (*stub) Sync() error { return notSupported("zxsocket.Socket.Sync") }

func (*stub) GetAttr() (io.NodeAttributes, error) {
	return io.NodeAttributes{}, notSupported("zxsocket.Socket.GetAttr")
}

func (*stub) SetAttr(uint32, io.NodeAttributes) error {
	return notSupported("zxsocket.Socket.SetAttr")
}

func (*stub) GetFlags() (io.OpenFlags, error) {
	return 0, notSupported("zxsocket.Socket.GetFlags")
}

func (*stub) ReadAt([]byte, int64) (int, error) {
	return 0, notSupported("zxsocket.Socket.ReadAt")
}

func (*stub) WriteAt([]byte, int64) (int, error) {
	return 0, notSupported("zxsocket.Socket.WriteAt")
}

func (*stub) Seek(int64, int) (int64, error) {
	return 0, notSupported("zxsocket.Socket.Seek")
}

func (*stub) Resize(uint64) error { return notSupported("zxsocket.Socket.Resize") }

func (*stub) Open(string, io.OpenFlags) (fdio.FDIO, error) {
	return nil, notSupported("zxsocket.Socket.Open")
}

func (*stub) Link(string, string) error { return notSupported("zxsocket.Socket.Link") }

func (*stub) Rename(string, string) error { return notSupported("zxsocket.Socket.Rename") }

func (*stub) Unlink(string) error { return notSupported("zxsocket.Socket.Unlink") }

func (*stub) ReadDirents(uint64) ([]byte, error) {
	return nil, notSupported("zxsocket.Socket.ReadDirents")
}

func (*stub) Rewind() error { return notSupported("zxsocket.Socket.Rewind") }
// bind issues a FIDL Bind for addr on client. A transport error is returned
// as-is. A server-reported failure is returned as a syscall.Errno.
func bind(client socket.BaseNetworkSocketWithCtx, addr net.SocketAddress) error {
	res, fidlErr := client.Bind(context.Background(), addr)
	if fidlErr != nil {
		return fidlErr
	}
	which := res.Which()
	if which == socket.BaseNetworkSocketBindResultErr {
		return syscall.Errno(res.Err)
	}
	if which != socket.BaseNetworkSocketBindResultResponse {
		panic("unreachable")
	}
	return nil
}
// connect issues a FIDL Connect to addr on client. A transport error is
// returned as-is. A server-reported failure is returned as a syscall.Errno.
func connect(client socket.BaseNetworkSocketWithCtx, addr net.SocketAddress) error {
	res, fidlErr := client.Connect(context.Background(), addr)
	if fidlErr != nil {
		return fidlErr
	}
	which := res.Which()
	if which == socket.BaseNetworkSocketConnectResultErr {
		return syscall.Errno(res.Err)
	}
	if which != socket.BaseNetworkSocketConnectResultResponse {
		panic("unreachable")
	}
	return nil
}
// getPeerName asks client for the remote address via FIDL GetPeerName. On any
// error the zero SocketAddress is returned. Server-reported failures become a
// syscall.Errno.
func getPeerName(client socket.BaseNetworkSocketWithCtx) (net.SocketAddress, error) {
	var zero net.SocketAddress
	res, fidlErr := client.GetPeerName(context.Background())
	if fidlErr != nil {
		return zero, fidlErr
	}
	which := res.Which()
	if which == socket.BaseNetworkSocketGetPeerNameResultErr {
		return zero, syscall.Errno(res.Err)
	}
	if which != socket.BaseNetworkSocketGetPeerNameResultResponse {
		panic("unreachable")
	}
	return res.Response.Addr, nil
}
// getSockName asks client for the local address via FIDL GetSockName. On any
// error the zero SocketAddress is returned. Server-reported failures become a
// syscall.Errno.
func getSockName(client socket.BaseNetworkSocketWithCtx) (net.SocketAddress, error) {
	var zero net.SocketAddress
	res, fidlErr := client.GetSockName(context.Background())
	if fidlErr != nil {
		return zero, fidlErr
	}
	which := res.Which()
	if which == socket.BaseNetworkSocketGetSockNameResultErr {
		return zero, syscall.Errno(res.Err)
	}
	if which != socket.BaseNetworkSocketGetSockNameResultResponse {
		panic("unreachable")
	}
	return res.Response.Addr, nil
}
// setKeepAlive toggles the keep-alive option on client via FIDL SetKeepAlive.
// A transport error is returned as-is. A server-reported failure is returned
// as a syscall.Errno.
func setKeepAlive(client socket.BaseSocketWithCtx, keepalive bool) error {
	res, fidlErr := client.SetKeepAlive(context.Background(), keepalive)
	if fidlErr != nil {
		return fidlErr
	}
	which := res.Which()
	if which == socket.BaseSocketSetKeepAliveResultErr {
		return syscall.Errno(res.Err)
	}
	if which != socket.BaseSocketSetKeepAliveResultResponse {
		panic("unreachable")
	}
	return nil
}
// closeError aggregates the failures of the individual teardown steps run by
// closeSocket. A nil field means that step succeeded.
type closeError struct {
	// fidl is the error from the FIDL Close call. It is either a transport
	// error or the status reported by the server.
	fidl error
	// channel holds errors from operating on the client channel.
	channel struct {
		// wait is the error from waiting for the channel's peer to close.
		wait error
		// close is the error from closing the channel.
		close error
	}
	// handleClose is the error from closing the auxiliary handle.
	handleClose error
}

// Error renders every non-nil step error as "name=message". Entries appear in
// the order fidl, channel.wait, channel.close, handleClose and are separated
// by single spaces. It returns "" when no step failed.
func (e *closeError) Error() string {
	steps := [...]struct {
		prefix string
		err    error
	}{
		{"fidl=", e.fidl},
		{"channel.wait=", e.channel.wait},
		{"channel.close=", e.channel.close},
		{"handleClose=", e.handleClose},
	}
	var b strings.Builder
	for _, step := range steps {
		if step.err == nil {
			continue
		}
		if b.Len() != 0 {
			b.WriteByte(' ')
		}
		b.WriteString(step.prefix)
		b.WriteString(step.err.Error())
	}
	return b.String()
}
// closeSocket tears down a socket connection in four steps:
//  1. send a FIDL Close on client;
//  2. wait, using a context with no deadline, for channel's peer to close;
//  3. close channel;
//  4. close handle.
//
// Every step runs even if an earlier one failed. Failures are collected into
// a *closeError, and nil is returned only when every step succeeded.
func closeSocket(client socket.BaseSocketWithCtx, channel *zx.Channel, handle *zx.Handle) error {
	var ce closeError

	if res, fidlErr := client.Close(context.Background()); fidlErr != nil {
		ce.fidl = fidlErr
	} else {
		switch res.Which() {
		case unknown.CloseableCloseResultErr:
			ce.fidl = &zx.Error{Status: zx.Status(res.Err), Text: "zxsocket.Socket.Close"}
		case unknown.CloseableCloseResultResponse:
			// Success: leave ce.fidl nil.
		default:
			panic("unknown variant")
		}
	}

	_, ce.channel.wait = zxwait.WaitContext(context.Background(), *channel.Handle(), zx.SignalChannelPeerClosed)
	ce.channel.close = channel.Close()
	ce.handleClose = handle.Close()

	for _, stepErr := range [...]error{ce.fidl, ce.channel.wait, ce.channel.close, ce.handleClose} {
		if stepErr != nil {
			return &ce
		}
	}
	return nil
}
// DatagramSocket is a datagram socket.
type DatagramSocket struct {
	stub
	// client is the FIDL connection used for all socket operations.
	client *socket.SynchronousDatagramSocketWithCtxInterface
	// event is the handle obtained from Describe. recvMsg waits on it for
	// incoming-datagram and shutdown signals.
	event zx.Handle
}

// NewDatagramSocket wraps client and fetches its event handle via Describe.
// It fails with a zx.ErrInvalidArgs error if Describe does not supply an
// event.
func NewDatagramSocket(client *socket.SynchronousDatagramSocketWithCtxInterface) (*DatagramSocket, error) {
	desc, err := client.Describe(context.Background())
	switch {
	case err != nil:
		return nil, err
	case !desc.HasEvent():
		return nil, &zx.Error{
			Status: zx.ErrInvalidArgs,
			Text:   "SynchronousDatagramSocket without event",
		}
	}
	s := &DatagramSocket{client: client}
	s.event = desc.GetEvent()
	return s, nil
}
// clone asks the server, via Clone2, to serve a new connection on a freshly
// created client channel, then wraps that channel in a DatagramSocket.
//
// If either the Clone2 call or the wrapping fails, the new client channel is
// closed so it does not leak. The close error is ignored because this is
// best-effort cleanup on an already failing path.
func (s *DatagramSocket) clone() (*DatagramSocket, error) {
	req, obj, err := socket.NewSynchronousDatagramSocketWithCtxInterfaceRequest()
	if err != nil {
		return nil, err
	}
	if err := s.client.Clone2(context.Background(), (unknown.CloneableWithCtxInterfaceRequest)(req)); err != nil {
		_ = obj.Channel.Close()
		return nil, err
	}
	sock, err := NewDatagramSocket(obj)
	if err != nil {
		_ = obj.Channel.Close()
		return nil, err
	}
	return sock, nil
}

// Clone implements Socket.
//
// On failure it returns a nil fdio.FDIO rather than an interface holding a
// nil *DatagramSocket, so callers comparing the result against nil see the
// expected value.
func (s *DatagramSocket) Clone() (fdio.FDIO, error) {
	sock, err := s.clone()
	if err != nil {
		return nil, err
	}
	return sock, nil
}
// Bind implements Socket by forwarding to bind on the FIDL client.
func (s *DatagramSocket) Bind(addr net.SocketAddress) error { return bind(s.client, addr) }

// Connect implements Socket by forwarding to connect on the FIDL client.
func (s *DatagramSocket) Connect(addr net.SocketAddress) error { return connect(s.client, addr) }

// GetPeerName implements Socket by forwarding to getPeerName on the FIDL client.
func (s *DatagramSocket) GetPeerName() (net.SocketAddress, error) { return getPeerName(s.client) }

// GetSockName implements Socket by forwarding to getSockName on the FIDL client.
func (s *DatagramSocket) GetSockName() (net.SocketAddress, error) { return getSockName(s.client) }

// SetKeepAlive implements Socket by forwarding to setKeepAlive on the FIDL client.
func (s *DatagramSocket) SetKeepAlive(keepalive bool) error { return setKeepAlive(s.client, keepalive) }
// recvMsg receives one datagram, requesting at most dataLen bytes from the
// server, together with its source address.
//
// While the server reports EAGAIN, it waits on s.event for the incoming,
// read-shutdown, or peer-closed signals and retries on incoming data. Shutdown
// or peer closure is reported as a zx.ErrPeerClosed error. The response's
// address field is optional (a pointer). If the server omits it, the zero
// net.SocketAddress is returned instead of dereferencing nil.
func (s *DatagramSocket) recvMsg(dataLen uint32) (net.SocketAddress, []byte, error) {
	for {
		result, err := s.client.RecvMsg(context.Background(), true /* wantAddr */, dataLen, false /* wantControl */, 0)
		if err != nil {
			return net.SocketAddress{}, nil, err
		}
		switch result.Which() {
		case socket.SynchronousDatagramSocketRecvMsgResultErr:
			if err := syscall.Errno(result.Err); err != syscall.EAGAIN {
				return net.SocketAddress{}, nil, err
			}
			obs, err := zxwait.WaitContext(context.Background(), s.event, signalDatagramIncoming|signalDatagramShutdownRead|zx.SignalEpairPeerClosed)
			if err != nil {
				return net.SocketAddress{}, nil, err
			}
			// Prefer draining pending data over reporting shutdown.
			if obs&signalDatagramIncoming != 0 {
				continue
			}
			if obs&(signalDatagramShutdownRead|zx.SignalEpairPeerClosed) != 0 {
				return net.SocketAddress{}, nil, &zx.Error{Status: zx.ErrPeerClosed, Text: "zxsocket.SynchronousDatagramSocket.RecvMsg"}
			}
			panic("unreachable")
		case socket.SynchronousDatagramSocketRecvMsgResultResponse:
			var addr net.SocketAddress
			if result.Response.Addr != nil {
				addr = *result.Response.Addr
			}
			return addr, result.Response.Data, nil
		default:
			panic("unreachable")
		}
	}
}
// RecvMsg implements roughly the recvmsg "system call". Its signature resembles
// that of syscall.Recvmsg on other platforms.
//
// It returns the received payload and the sender's address. maxLen is
// forwarded to the server as the requested data length after conversion to
// uint32, so callers should pass a non-negative value.
func (s *DatagramSocket) RecvMsg(maxLen int) ([]byte, net.SocketAddress, error) {
	addr, data, err := s.recvMsg(uint32(maxLen))
	if err != nil {
		return nil, net.SocketAddress{}, err
	}
	return data, addr, nil
}
// sendMsg sends data as one datagram, addressed to addr when addr is non-nil,
// with empty control data. It returns the length reported by the server.
// Server-reported failures become a syscall.Errno.
func (s *DatagramSocket) sendMsg(addr *net.SocketAddress, data []byte) (int, error) {
	var ctrl socket.DatagramSocketSendControlData
	res, fidlErr := s.client.SendMsg(context.Background(), addr, data, ctrl, 0)
	if fidlErr != nil {
		return 0, fidlErr
	}
	which := res.Which()
	if which == socket.SynchronousDatagramSocketSendMsgResultErr {
		return 0, syscall.Errno(res.Err)
	}
	if which != socket.SynchronousDatagramSocketSendMsgResultResponse {
		panic("unreachable")
	}
	return int(res.Response.Len), nil
}
// SendMsg implements roughly the sendmsg "system call". Its signature resembles
// that of syscall.Sendmsg on other platforms.
//
// addr is used as the destination only when it holds an IPv4 or IPv6
// address. Otherwise the datagram is sent without an explicit destination.
func (s *DatagramSocket) SendMsg(b []byte, addr net.SocketAddress) (int, error) {
	var dst *net.SocketAddress
	if w := addr.Which(); w == net.SocketAddressIpv4 || w == net.SocketAddressIpv6 {
		dst = &addr
	}
	return s.sendMsg(dst, b)
}
// Close implements fdio.FDIO. It tears down the FIDL connection via
// closeSocket and closes the event handle.
func (s *DatagramSocket) Close() error {
	ch, ev := &s.client.Channel, &s.event
	return closeSocket(s.client, ch, ev)
}

// Handles implements fdio.FDIO. It returns the client channel handle
// followed by the event handle.
func (s *DatagramSocket) Handles() []zx.Handle {
	handles := make([]zx.Handle, 0, 2)
	handles = append(handles, *s.client.Handle(), s.event)
	return handles
}

// Read implements fdio.FDIO. It receives one datagram of at most len(data)
// bytes into data and discards the sender's address.
func (s *DatagramSocket) Read(data []byte) (int, error) {
	_, payload, err := s.recvMsg(uint32(len(data)))
	n := copy(data, payload)
	return n, err
}

// Write implements fdio.FDIO. It sends data as one datagram with no explicit
// destination.
func (s *DatagramSocket) Write(data []byte) (int, error) {
	const noDestination *net.SocketAddress = nil
	return s.sendMsg(noDestination, data)
}
// StreamSocket is a stream socket.
type StreamSocket struct {
	stub
	// client is the FIDL connection used for control operations.
	client *socket.StreamSocketWithCtxInterface
	// socket is the zx socket obtained from Describe. Read and Write move
	// payload bytes through it.
	socket zx.Socket
}

// NewStreamSocket wraps client and fetches its data socket via Describe. It
// fails with a zx.ErrInvalidArgs error if Describe does not supply a socket.
func NewStreamSocket(client *socket.StreamSocketWithCtxInterface) (*StreamSocket, error) {
	desc, err := client.Describe(context.Background())
	switch {
	case err != nil:
		return nil, err
	case !desc.HasSocket():
		return nil, &zx.Error{
			Status: zx.ErrInvalidArgs,
			Text:   "StreamSocket without socket",
		}
	}
	s := &StreamSocket{client: client}
	s.socket = desc.GetSocket()
	return s, nil
}
// clone asks the server, via Clone2, to serve a new connection on a freshly
// created client channel, then wraps that channel in a StreamSocket.
//
// If either the Clone2 call or the wrapping fails, the new client channel is
// closed so it does not leak. The close error is ignored because this is
// best-effort cleanup on an already failing path.
func (s *StreamSocket) clone() (*StreamSocket, error) {
	req, obj, err := socket.NewStreamSocketWithCtxInterfaceRequest()
	if err != nil {
		return nil, err
	}
	if err := s.client.Clone2(context.Background(), (unknown.CloneableWithCtxInterfaceRequest)(req)); err != nil {
		_ = obj.Channel.Close()
		return nil, err
	}
	sock, err := NewStreamSocket(obj)
	if err != nil {
		_ = obj.Channel.Close()
		return nil, err
	}
	return sock, nil
}

// Clone implements Socket.
//
// On failure it returns a nil fdio.FDIO rather than an interface holding a
// nil *StreamSocket, so callers comparing the result against nil see the
// expected value.
func (s *StreamSocket) Clone() (fdio.FDIO, error) {
	sock, err := s.clone()
	if err != nil {
		return nil, err
	}
	return sock, nil
}
// Bind implements Socket by forwarding to bind on the FIDL client.
func (s *StreamSocket) Bind(addr net.SocketAddress) error { return bind(s.client, addr) }

// Connect implements Socket by forwarding to connect on the FIDL client.
func (s *StreamSocket) Connect(addr net.SocketAddress) error { return connect(s.client, addr) }

// GetPeerName implements Socket by forwarding to getPeerName on the FIDL client.
func (s *StreamSocket) GetPeerName() (net.SocketAddress, error) { return getPeerName(s.client) }

// GetSockName implements Socket by forwarding to getSockName on the FIDL client.
func (s *StreamSocket) GetSockName() (net.SocketAddress, error) { return getSockName(s.client) }

// SetKeepAlive implements Socket by forwarding to setKeepAlive on the FIDL client.
func (s *StreamSocket) SetKeepAlive(keepalive bool) error { return setKeepAlive(s.client, keepalive) }
// SetNoDelay sets the SOL_TCP, TCP_NODELAY socket option. A server-reported
// failure is returned as a syscall.Errno.
func (s *StreamSocket) SetNoDelay(noDelay bool) error {
	res, fidlErr := s.client.SetTcpNoDelay(context.Background(), noDelay)
	if fidlErr != nil {
		return fidlErr
	}
	which := res.Which()
	if which == socket.StreamSocketSetTcpNoDelayResultErr {
		return syscall.Errno(res.Err)
	}
	if which != socket.StreamSocketSetTcpNoDelayResultResponse {
		panic("unreachable")
	}
	return nil
}
// SetKeepAlivePeriod sets the SOL_TCP, {TCP_KEEPINTVL,TCP_KEEPIDLE} socket options.
//
// Both options receive the same value. The interval is set first. If that
// fails, the idle time is left untouched and the error is returned.
// Server-reported failures become a syscall.Errno.
func (s *StreamSocket) SetKeepAlivePeriod(seconds uint32) error {
	ctx := context.Background()

	intvl, err := s.client.SetTcpKeepAliveInterval(ctx, seconds)
	if err != nil {
		return err
	}
	switch intvl.Which() {
	case socket.StreamSocketSetTcpKeepAliveIntervalResultResponse:
		// Interval accepted; fall through to the idle option.
	case socket.StreamSocketSetTcpKeepAliveIntervalResultErr:
		return syscall.Errno(intvl.Err)
	default:
		panic("unreachable")
	}

	idle, err := s.client.SetTcpKeepAliveIdle(ctx, seconds)
	if err != nil {
		return err
	}
	switch idle.Which() {
	case socket.StreamSocketSetTcpKeepAliveIdleResultResponse:
		return nil
	case socket.StreamSocketSetTcpKeepAliveIdleResultErr:
		return syscall.Errno(idle.Err)
	default:
		panic("unreachable")
	}
}
// Accept accepts an incoming connection.
//
// When wantAddr is true the server is asked for the peer's address, which is
// returned as the second value (it may be nil). Server-reported failures
// become a syscall.Errno. If the accepted connection cannot be wrapped in a
// StreamSocket, its channel is closed so it does not leak. That close is
// best-effort, and its error is ignored.
func (s *StreamSocket) Accept(wantAddr bool) (*StreamSocket, *net.SocketAddress, error) {
	result, err := s.client.Accept(context.Background(), wantAddr)
	if err != nil {
		return nil, nil, err
	}
	switch result.Which() {
	case socket.StreamSocketAcceptResultErr:
		return nil, nil, syscall.Errno(result.Err)
	case socket.StreamSocketAcceptResultResponse:
		newS, err := NewStreamSocket(&result.Response.S)
		if err != nil {
			_ = result.Response.S.Channel.Close()
			return nil, nil, err
		}
		return newS, result.Response.Addr, nil
	default:
		panic("unreachable")
	}
}
// Listen begins listening for incoming connections with the given backlog.
// A server-reported failure is returned as a syscall.Errno.
func (s *StreamSocket) Listen(backlog int16) error {
	res, fidlErr := s.client.Listen(context.Background(), backlog)
	if fidlErr != nil {
		return fidlErr
	}
	which := res.Which()
	if which == socket.StreamSocketListenResultErr {
		return syscall.Errno(res.Err)
	}
	if which != socket.StreamSocketListenResultResponse {
		panic("unreachable")
	}
	return nil
}
// Wait blocks on the receiver's socket handle until any of the provided
// signals is asserted, and returns the observed signals. There is no timeout
// parameter. The wait is bounded only by ctx, which is passed through to
// zxwait.WaitContext.
func (s *StreamSocket) Wait(ctx context.Context, signals zx.Signals) (zx.Signals, error) {
	return zxwait.WaitContext(ctx, *s.socket.Handle(), signals)
}
// Close implements fdio.FDIO. It tears down the FIDL connection via
// closeSocket and closes the data socket handle.
func (s *StreamSocket) Close() error {
	ch, data := &s.client.Channel, s.socket.Handle()
	return closeSocket(s.client, ch, data)
}

// Handles implements fdio.FDIO. It returns the client channel handle
// followed by the data socket handle.
func (s *StreamSocket) Handles() []zx.Handle {
	clientHandle := *s.client.Handle()
	socketHandle := *s.socket.Handle()
	return []zx.Handle{clientHandle, socketHandle}
}
// Read implements fdio.FDIO.
//
// It blocks until data is available. A closed peer is reported as io.EOF. So
// is a wait that fails with zx.ErrBadHandle or zx.ErrCanceled. Any other
// error from the socket or the wait is returned unchanged.
func (s *StreamSocket) Read(data []byte) (int, error) {
	for {
		n, readErr := s.socket.Read(data, 0)
		if readErr == nil {
			return n, nil
		}
		zxErr, ok := readErr.(*zx.Error)
		if !ok {
			return 0, readErr
		}
		if zxErr.Status == zx.ErrPeerClosed {
			return 0, stdio.EOF
		}
		if zxErr.Status != zx.ErrShouldWait {
			return 0, readErr
		}

		observed, waitErr := s.Wait(context.Background(), zx.SignalSocketReadable|zx.SignalSocketPeerClosed)
		if waitErr != nil {
			if we, ok := waitErr.(*zx.Error); ok && (we.Status == zx.ErrBadHandle || we.Status == zx.ErrCanceled) {
				return 0, stdio.EOF
			}
			return 0, waitErr
		}
		// Readable takes priority so buffered data is drained before EOF.
		if observed&zx.SignalSocketReadable != 0 {
			continue
		}
		if observed&zx.SignalSocketPeerClosed != 0 {
			return 0, stdio.EOF
		}
		return 0, readErr
	}
}
// Write implements fdio.FDIO.
//
// It keeps writing until the socket has accepted every byte of data, and
// waits for writability whenever the socket is full. A stream zx socket may
// accept only part of a buffer. The remainder is retried so that, as
// io.Writer requires, a count smaller than len(data) always comes with a
// non-nil error. Peer closure or disabled writes are reported as a
// zx.ErrPeerClosed error.
func (s *StreamSocket) Write(data []byte) (int, error) {
	var total int
	for {
		n, err := s.socket.Write(data, 0)
		total += n
		// Drop whatever the socket accepted so a retry sends only the rest.
		data = data[n:]
		if err == nil {
			if len(data) == 0 {
				return total, nil
			}
			if n == 0 {
				// Success without progress would otherwise spin forever.
				return total, stdio.ErrShortWrite
			}
			continue
		}
		zxErr, ok := err.(*zx.Error)
		if !ok || zxErr.Status != zx.ErrShouldWait {
			return total, err
		}
		obs, err := s.Wait(context.Background(), zx.SignalSocketWritable|zx.SignalSocketPeerClosed|zx.SignalSocketWriteDisabled)
		if err != nil {
			return total, err
		}
		if obs&(zx.SignalSocketPeerClosed|zx.SignalSocketWriteDisabled) != 0 {
			return total, &zx.Error{Status: zx.ErrPeerClosed, Text: "zxsocket.Socket.Write"}
		}
		if obs&zx.SignalSocketWritable != 0 {
			continue
		}
		// This case should be impossible:
		return total, &zx.Error{Status: zx.ErrInternal, Text: "zxsocket.Socket.Write(impossible state)"}
	}
}