blob: 60e81efb6f7d35d790313f192bc0bdccf2535022 [file] [log] [blame]
// Copyright 2016 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package gonet provides a Go net package compatible wrapper for a tcpip stack.
package gonet
import (
"errors"
"io"
"net"
"sync"
"time"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/transport/tcp"
"github.com/google/netstack/tcpip/transport/udp"
"github.com/google/netstack/waiter"
)
var errCanceled = errors.New("operation canceled")
// timeoutError is how the net package reports timeouts.
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
// A Listener is a wrapper around a tcpip endpoint that implements
// net.Listener.
type Listener struct {
stack tcpip.Stack
ep tcpip.Endpoint
wq *waiter.Queue
cancel chan struct{}
}
// NewListener creates a new Listener.
func NewListener(s tcpip.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) {
// Create TCP endpoint, bind it, then start listening.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
if err != nil {
return nil, errors.New(err.String())
}
if err := ep.Bind(addr, nil); err != nil {
ep.Close()
return nil, &net.OpError{
Op: "bind",
Net: "tcp",
Addr: fullToTCPAddr(addr),
Err: errors.New(err.String()),
}
}
if err := ep.Listen(10); err != nil {
ep.Close()
return nil, &net.OpError{
Op: "listen",
Net: "tcp",
Addr: fullToTCPAddr(addr),
Err: errors.New(err.String()),
}
}
return &Listener{
stack: s,
ep: ep,
wq: &wq,
cancel: make(chan struct{}),
}, nil
}
// Close implements net.Listener.Close.
func (l *Listener) Close() error {
l.ep.Close()
return nil
}
// Shutdown stops the HTTP server.
func (l *Listener) Shutdown() {
l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
close(l.cancel) // broadcast cancellation
}
// Addr implements net.Listener.Addr.
func (l *Listener) Addr() net.Addr {
a, err := l.ep.GetLocalAddress()
if err != nil {
return nil
}
return fullToTCPAddr(a)
}
type deadlineTimer struct {
// mu protects the fields below.
mu sync.Mutex
readTimer *time.Timer
readCancelCh chan struct{}
writeTimer *time.Timer
writeCancelCh chan struct{}
}
func (d *deadlineTimer) init() {
d.readCancelCh = make(chan struct{})
d.writeCancelCh = make(chan struct{})
}
func (d *deadlineTimer) readCancel() <-chan struct{} {
d.mu.Lock()
c := d.readCancelCh
d.mu.Unlock()
return c
}
func (d *deadlineTimer) writeCancel() <-chan struct{} {
d.mu.Lock()
c := d.writeCancelCh
d.mu.Unlock()
return c
}
// setDeadline contains the shared logic for setting a deadline.
//
// cancelCh and timer must be pointers to deadlineTimer.readCancelCh and
// deadlineTimer.readTimer or deadlineTimer.writeCancelCh and
// deadlineTimer.writeTimer.
//
// setDeadline must only be called while holding d.mu.
func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) {
if *timer != nil && !(*timer).Stop() {
*cancelCh = make(chan struct{})
}
// "A zero value for t means I/O operations will not time out."
// - net.Conn.SetDeadline
if !t.IsZero() {
timeout := t.Sub(time.Now())
if timeout <= 0 {
close(*cancelCh)
return
}
// Timer.Stop returns whether or not the AfterFunc has started, but
// does not indicate whether or not it has completed. Make a copy of
// the cancel channel to prevent this code from racing with the next
// call of setDeadline replacing *cancelCh.
ch := *cancelCh
*timer = time.AfterFunc(timeout, func() {
close(ch)
})
}
}
// SetReadDeadline implements net.Conn.SetReadDeadline and
// net.PacketConn.SetReadDeadline.
func (d *deadlineTimer) SetReadDeadline(t time.Time) error {
d.mu.Lock()
d.setDeadline(&d.readCancelCh, &d.readTimer, t)
d.mu.Unlock()
return nil
}
// SetWriteDeadline implements net.Conn.SetWriteDeadline and
// net.PacketConn.SetWriteDeadline.
func (d *deadlineTimer) SetWriteDeadline(t time.Time) error {
d.mu.Lock()
d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
d.mu.Unlock()
return nil
}
// SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline.
func (d *deadlineTimer) SetDeadline(t time.Time) error {
d.mu.Lock()
d.setDeadline(&d.readCancelCh, &d.readTimer, t)
d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
d.mu.Unlock()
return nil
}
// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn
// interface.
type Conn struct {
deadlineTimer
wq *waiter.Queue
ep tcpip.Endpoint
// readMu serializes reads and implicitly protects read.
//
// Lock ordering:
// If both readMu and deadlineTimer.mu are to be used in a single
// request, readMu must be aquired before deadlineTimer.mu.
readMu sync.Mutex
// read contains bytes that have been read from the endpoint,
// but haven't yet been returned.
read buffer.View
}
// NewConn creates a new Conn.
func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
c := &Conn{
wq: wq,
ep: ep,
}
c.deadlineTimer.init()
return c
}
// Accept implements net.Conn.Accept.
func (l *Listener) Accept() (net.Conn, error) {
n, wq, err := l.ep.Accept()
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
l.wq.EventRegister(&waitEntry, waiter.EventIn)
defer l.wq.EventUnregister(&waitEntry)
for {
n, wq, err = l.ep.Accept()
if err != tcpip.ErrWouldBlock {
break
}
select {
case <-l.cancel:
return nil, errCanceled
case <-notifyCh:
}
}
}
if err != nil {
return nil, &net.OpError{
Op: "accept",
Net: "tcp",
Addr: l.Addr(),
Err: errors.New(err.String()),
}
}
return NewConn(wq, n), nil
}
type opErrorer interface {
newOpError(op string, err error) *net.OpError
}
// commonRead implements the common logic between net.Conn.Read and
// net.PacketConn.ReadFrom.
func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) ([]byte, error) {
read, err := ep.Read(addr)
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventIn)
defer wq.EventUnregister(&waitEntry)
for {
read, err = ep.Read(addr)
if err != tcpip.ErrWouldBlock {
break
}
select {
case <-deadline:
return nil, errorer.newOpError("read", &timeoutError{})
case <-notifyCh:
}
}
}
if err == tcpip.ErrClosedForReceive {
return nil, io.EOF
}
if err != nil {
return nil, errorer.newOpError("read", errors.New(err.String()))
}
return read, nil
}
// Read implements net.Conn.Read.
func (c *Conn) Read(b []byte) (int, error) {
c.readMu.Lock()
defer c.readMu.Unlock()
deadline := c.readCancel()
// Check if deadline has already expired.
select {
case <-deadline:
return 0, c.newOpError("read", &timeoutError{})
default:
}
if len(c.read) == 0 {
var err error
c.read, err = commonRead(c.ep, c.wq, deadline, nil, c)
if err != nil {
return 0, err
}
}
n := copy(b, c.read)
c.read.TrimFront(n)
if len(c.read) == 0 {
c.read = nil
}
return n, nil
}
// Write implements net.Conn.Write.
func (c *Conn) Write(b []byte) (int, error) {
deadline := c.writeCancel()
// Check if deadlineTimer has already expired.
select {
case <-deadline:
return 0, c.newOpError("write", &timeoutError{})
default:
}
v := buffer.NewView(len(b))
copy(v, b)
// We must handle two soft failure conditions simultaneously:
// 1. Write may write nothing and return tcpip.ErrWouldBlock.
// If this happens, we need to register for notifications if we have
// not already and wait to try again.
// 2. Write may write fewer than the full number of bytes and return
// without error. In this case we need to try writing the remaining
// bytes again. I do not need to register for notifications.
//
// What is more, these two soft failure conditions can be interspersed.
// There is no guarantee that all of the condition #1s will occur before
// all of the condition #2s or visa-versa.
var (
err *tcpip.Error
nbytes int
reg bool
notifyCh chan struct{}
)
for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) {
if err == tcpip.ErrWouldBlock {
if !reg {
// Only register once.
reg = true
// Create wait queue entry that notifies a channel.
var waitEntry waiter.Entry
waitEntry, notifyCh = waiter.NewChannelEntry(nil)
c.wq.EventRegister(&waitEntry, waiter.EventOut)
defer c.wq.EventUnregister(&waitEntry)
} else {
// Don't wait immediately after registration in case more data
// became available between when we last checked and when we setup
// the notification.
select {
case <-deadline:
return 0, c.newOpError("write", &timeoutError{})
case <-notifyCh:
}
}
}
var n uintptr
n, err = c.ep.Write(v, nil)
nbytes += int(n)
v.TrimFront(int(n))
}
if err == nil {
return nbytes, nil
}
return 0, c.newOpError("write", errors.New(err.String()))
}
// Close implements net.Conn.Close.
func (c *Conn) Close() error {
c.ep.Close()
return nil
}
// LocalAddr implements net.Conn.LocalAddr.
func (c *Conn) LocalAddr() net.Addr {
a, err := c.ep.GetLocalAddress()
if err != nil {
return nil
}
return fullToTCPAddr(a)
}
// RemoteAddr implements net.Conn.RemoteAddr.
func (c *Conn) RemoteAddr() net.Addr {
a, err := c.ep.GetRemoteAddress()
if err != nil {
return nil
}
return fullToTCPAddr(a)
}
func (c *Conn) newOpError(op string, err error) *net.OpError {
return &net.OpError{
Op: op,
Net: "tcp",
Source: c.LocalAddr(),
Addr: c.RemoteAddr(),
Err: err,
}
}
func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr {
return &net.TCPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
}
func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
}
// DialTCP creates a new TCP Conn connected to the specified address.
func DialTCP(s tcpip.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
// Create TCP endpoint, then connect.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
if err != nil {
return nil, errors.New(err.String())
}
// Create wait queue entry that notifies a channel.
//
// We do this unconditionally as Connect will always return an error.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventOut)
defer wq.EventUnregister(&waitEntry)
err = ep.Connect(addr)
for err != nil {
if err != tcpip.ErrConnectStarted {
ep.Close()
return nil, &net.OpError{
Op: "connect",
Net: "tcp",
Addr: fullToTCPAddr(addr),
Err: errors.New(err.String()),
}
}
<-notifyCh
err = ep.GetSockOpt(tcpip.ErrorOption{})
}
return NewConn(&wq, ep), nil
}
// A PacketConn is a wrapper around a tcpip endpoint that implements
// net.PacketConn.
type PacketConn struct {
deadlineTimer
stack tcpip.Stack
ep tcpip.Endpoint
wq *waiter.Queue
}
// NewPacketConn creates a new PacketConn.
func NewPacketConn(s tcpip.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
// Create UDP endpoint and bind it.
var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
if err != nil {
return nil, errors.New(err.String())
}
if err := ep.Bind(addr, nil); err != nil {
ep.Close()
return nil, &net.OpError{
Op: "bind",
Net: "udp",
Addr: fullToUDPAddr(addr),
Err: errors.New(err.String()),
}
}
c := &PacketConn{
stack: s,
ep: ep,
wq: &wq,
}
c.deadlineTimer.init()
return c, nil
}
func (c *PacketConn) newOpError(op string, err error) *net.OpError {
return c.newRemoteOpError(op, nil, err)
}
func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
return &net.OpError{
Op: op,
Net: "udp",
Source: c.LocalAddr(),
Addr: remote,
Err: err,
}
}
// ReadFrom implements net.PacketConn.ReadFrom.
func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
deadline := c.readCancel()
// Check if deadline has already expired.
select {
case <-deadline:
return 0, nil, c.newOpError("read", &timeoutError{})
default:
}
var addr tcpip.FullAddress
read, err := commonRead(c.ep, c.wq, deadline, &addr, c)
if err != nil {
return 0, nil, err
}
return copy(b, read), fullToUDPAddr(addr), nil
}
// WriteTo implements net.PacketConn.WriteTo.
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
deadline := c.writeCancel()
// Check if deadline has already expired.
select {
case <-deadline:
return 0, c.newRemoteOpError("write", addr, &timeoutError{})
default:
}
ua := addr.(*net.UDPAddr)
fullAddr := tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)}
v := buffer.NewView(len(b))
copy(v, b)
n, err := c.ep.Write(v, &fullAddr)
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.wq.EventRegister(&waitEntry, waiter.EventOut)
defer c.wq.EventUnregister(&waitEntry)
for {
n, err = c.ep.Write(v, &fullAddr)
if err != tcpip.ErrWouldBlock {
break
}
select {
case <-deadline:
return 0, c.newRemoteOpError("write", addr, &timeoutError{})
case <-notifyCh:
}
}
}
if err == nil {
return int(n), nil
}
return 0, c.newRemoteOpError("write", addr, errors.New(err.String()))
}
// Close implements net.PacketConn.Close.
func (c *PacketConn) Close() error {
c.ep.Close()
return nil
}
// LocalAddr implements net.PacketConn.LocalAddr.
func (c *PacketConn) LocalAddr() net.Addr {
a, err := c.ep.GetLocalAddress()
if err != nil {
return nil
}
return fullToUDPAddr(a)
}