Merge https://github.com/google/netstack
Change-Id: I8f075dc77fe7de97b5c1726171f63a97f47671ca
diff --git a/dhcp/client.go b/dhcp/client.go
index 9a73e76..c40c3e0 100644
--- a/dhcp/client.go
+++ b/dhcp/client.go
@@ -196,17 +196,18 @@
var opts options
for {
var addr tcpip.FullAddress
- v, err := epin.Read(&addr)
- if err == tcpip.ErrWouldBlock {
+ v, e := epin.Read(&addr)
+ if e == tcpip.ErrWouldBlock {
select {
case <-ch:
continue
case <-ctx.Done():
- return Config{}, tcpip.ErrAborted
+ return Config{}, fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted)
}
}
h = header(v)
var valid bool
+ var err error
opts, valid, err = loadDHCPReply(h, dhcpOFFER, xid[:])
if !valid {
if err != nil {
@@ -226,7 +227,7 @@
addr := tcpip.Address(h.yiaddr())
if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil {
if err != tcpip.ErrDuplicateAddress {
- return Config{}, err
+ return Config{}, fmt.Errorf("adding address: %v", err)
}
}
defer func() {
@@ -283,17 +284,18 @@
// DHCPACK
for {
var addr tcpip.FullAddress
- v, err := epin.Read(&addr)
- if err == tcpip.ErrWouldBlock {
+ v, e := epin.Read(&addr)
+ if e == tcpip.ErrWouldBlock {
select {
case <-ch:
continue
case <-ctx.Done():
- return Config{}, tcpip.ErrAborted
+ return Config{}, fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted)
}
}
h = header(v)
var valid bool
+ var err error
opts, valid, err = loadDHCPReply(h, dhcpACK, xid[:])
if !valid {
if err != nil {
@@ -317,13 +319,13 @@
if !h.isValid() || h.op() != opReply || !bytes.Equal(h.xidbytes(), xid[:]) {
return nil, false, nil
}
- opts, err = h.options()
- if err != nil {
- return nil, false, err
+ opts, e := h.options()
+ if e != nil {
+ return nil, false, fmt.Errorf("dhcp ack: %v", e)
}
- msgtype, err := opts.dhcpMsgType()
- if err != nil {
- return nil, false, err
+ msgtype, e := opts.dhcpMsgType()
+ if e != nil {
+ return nil, false, fmt.Errorf("dhcp ack: %v", e)
}
if msgtype != typ {
return nil, false, nil
diff --git a/dhcp/server.go b/dhcp/server.go
index 8162acc..56a1d33 100644
--- a/dhcp/server.go
+++ b/dhcp/server.go
@@ -77,13 +77,13 @@
return nil, tcpip.FullAddress{}, io.EOF
}
}
- return v, addr, err
+ return v, addr, fmt.Errorf("dhcp: %v", err)
}
}
func (c *epConn) Write(b []byte, addr *tcpip.FullAddress) error {
_, err := c.ep.Write(b, addr)
- return err
+ return fmt.Errorf("dhcp: %v", err)
}
// NewServer creates a new DHCP server and begins serving.
diff --git a/dns/client.go b/dns/client.go
index 120f2bb..53b5a96 100644
--- a/dns/client.go
+++ b/dns/client.go
@@ -97,7 +97,7 @@
for len(b) > 0 {
n, err := ep.Write(b, nil)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("dns: write: %v", err)
}
b = b[n:]
@@ -120,11 +120,11 @@
case <-notifyCh:
continue
case <-ctx.Done():
- return nil, tcpip.ErrTimeout
+ return nil, fmt.Errorf("dns: read: %v", tcpip.ErrTimeout)
}
}
- return nil, err
+ return nil, fmt.Errorf("dns: read: %v", err)
}
b = append(b, []byte(v)...)
@@ -160,7 +160,7 @@
var wq waiter.Queue
ep, err := c.stack.NewEndpoint(transport, ipv4.ProtocolNumber, &wq)
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("dns: %v", err)
}
// Issue connect request and wait for it to complete.
@@ -173,12 +173,12 @@
case <-notifyCh:
err = ep.GetSockOpt(tcpip.ErrorOption{})
case <-ctx.Done():
- return nil, nil, tcpip.ErrTimeout
+ err = tcpip.ErrTimeout
}
}
if err != nil {
- return nil, nil, err
+ return nil, nil, fmt.Errorf("dns: %v", err)
}
return ep, &wq, nil
diff --git a/tcpip/adapters/gonet/gonet.go b/tcpip/adapters/gonet/gonet.go
index 5d9031c..8abaa7d 100644
--- a/tcpip/adapters/gonet/gonet.go
+++ b/tcpip/adapters/gonet/gonet.go
@@ -42,7 +42,7 @@
var wq waiter.Queue
tcpEP, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
if err != nil {
- return nil, err
+ return nil, errors.New(err.String())
}
if err := tcpEP.Bind(addr, nil); err != nil {
@@ -51,7 +51,7 @@
Op: "bind",
Net: "tcp",
Addr: fullToTCPAddr(addr),
- Err: err,
+ Err: errors.New(err.String()),
}
}
@@ -61,7 +61,7 @@
Op: "listen",
Net: "tcp",
Addr: fullToTCPAddr(addr),
- Err: err,
+ Err: errors.New(err.String()),
}
}
@@ -160,7 +160,7 @@
Op: "accept",
Net: "tcp",
Addr: l.Addr(),
- Err: err,
+ Err: errors.New(err.String()),
}
}
@@ -184,7 +184,7 @@
}
if len(c.read) == 0 {
- var err error
+ var err *tcpip.Error
c.read, err = c.ep.Read(nil)
if err == tcpip.ErrWouldBlock {
@@ -210,7 +210,7 @@
}
if err != nil {
- return 0, c.newOpError("read", err)
+ return 0, c.newOpError("read", errors.New(err.String()))
}
}
@@ -250,7 +250,7 @@
// There is no guarantee that all of the condition #1s will occur before
// all of the condition #2s or visa-versa.
var (
- err error
+ err *tcpip.Error
nbytes int
reg bool
notifyCh chan struct{}
@@ -288,7 +288,7 @@
return nbytes, nil
}
- return 0, c.newOpError("write", err)
+ return 0, c.newOpError("write", errors.New(err.String()))
}
// Close implements net.Conn.Close.
diff --git a/tcpip/adapters/gonet/gonet_test.go b/tcpip/adapters/gonet/gonet_test.go
index ec0f791..b9fd1d1 100644
--- a/tcpip/adapters/gonet/gonet_test.go
+++ b/tcpip/adapters/gonet/gonet_test.go
@@ -42,7 +42,7 @@
}
}
-func newLoopbackStack() (*stack.Stack, error) {
+func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
// Create the stack and add a NIC.
s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}).(*stack.Stack)
@@ -79,7 +79,7 @@
ep tcpip.Endpoint
}
-func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, error) {
+func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) {
wq := &waiter.Queue{}
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
@@ -117,9 +117,9 @@
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
- l, err := NewListener(s, addr, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatalf("NewListener() = %v", err)
+ l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ if e != nil {
+ t.Fatalf("NewListener() = %v", e)
}
done := make(chan struct{})
go func() {
@@ -141,7 +141,7 @@
n, err := c.Read(buf)
got, ok := err.(*net.OpError)
want := tcpip.ErrConnectionAborted
- if n != 0 || !ok || got.Err != want {
+ if n != 0 || !ok || got.Err.Error() != want.String() {
t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want)
}
t.Logf("c.Read() = %d, %v", n, err)
@@ -194,13 +194,13 @@
buf := make([]byte, 256)
t.Log("c.Read()")
- n, err := c.Read(buf)
- got, ok := err.(*net.OpError)
+ n, e := c.Read(buf)
+ got, ok := e.(*net.OpError)
want := tcpip.ErrConnectionAborted
- if n != 0 || !ok || got.Err != want {
- t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want)
+ if n != 0 || !ok || got.Err.Error() != want.String() {
+ t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want)
}
- t.Logf("c.Read() = %d, %v", n, err)
+ t.Logf("c.Read() = %d, %v", n, e)
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
@@ -228,9 +228,9 @@
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
- l, err := NewListener(s, addr, ipv4.ProtocolNumber)
- if err != nil {
- t.Fatalf("NewListener() = %v", err)
+ l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ if e != nil {
+ t.Fatalf("NewListener() = %v", e)
}
done := make(chan struct{})
go func() {
diff --git a/tcpip/link/channel/channel.go b/tcpip/link/channel/channel.go
index 8ae1848..aaf23f9 100644
--- a/tcpip/link/channel/channel.go
+++ b/tcpip/link/channel/channel.go
@@ -91,7 +91,7 @@
}
// WritePacket stores outbound packets into the channel.
-func (e *Endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) error {
+func (e *Endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
p := PacketInfo{
Header: hdr.View(),
Proto: protocol,
diff --git a/tcpip/link/fdbased/endpoint.go b/tcpip/link/fdbased/endpoint.go
index b313759..91003e8 100644
--- a/tcpip/link/fdbased/endpoint.go
+++ b/tcpip/link/fdbased/endpoint.go
@@ -33,7 +33,7 @@
// closed is a function to be called when the FD's peer (if any) closes
// its end of the communication pipe.
- closed func(error)
+ closed func(*tcpip.Error)
vv *buffer.VectorisedView
iovecs []syscall.Iovec
@@ -41,7 +41,7 @@
}
// New creates a new fd-based endpoint.
-func New(fd int, mtu uint32, closed func(error)) tcpip.LinkEndpointID {
+func New(fd int, mtu uint32, closed func(*tcpip.Error)) tcpip.LinkEndpointID {
syscall.SetNonblock(fd, true)
e := &endpoint{
@@ -81,7 +81,7 @@
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) error {
+func (e *endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
if payload == nil {
return rawfile.NonBlockingWrite(e.fd, hdr.UsedBytes())
@@ -117,7 +117,7 @@
}
// dispatch reads one packet from the file descriptor and dispatches it.
-func (e *endpoint) dispatch(d stack.NetworkDispatcher, largeV buffer.View) (bool, error) {
+func (e *endpoint) dispatch(d stack.NetworkDispatcher, largeV buffer.View) (bool, *tcpip.Error) {
e.allocateViews(BufConfig)
n, err := rawfile.BlockingReadv(e.fd, e.iovecs)
@@ -157,7 +157,7 @@
// dispatchLoop reads packets from the file descriptor in a loop and dispatches
// them to the network stack.
-func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) error {
+func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) *tcpip.Error {
v := buffer.NewView(header.MaxIPPacketSize)
for {
cont, err := e.dispatch(d, v)
diff --git a/tcpip/link/loopback/loopback.go b/tcpip/link/loopback/loopback.go
index 098557f..0506fea 100644
--- a/tcpip/link/loopback/loopback.go
+++ b/tcpip/link/loopback/loopback.go
@@ -51,7 +51,7 @@
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) error {
+func (e *endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
var views [1]buffer.View
if len(payload) == 0 {
// We don't have a payload, so just use the buffer from the
diff --git a/tcpip/link/rawfile/errors.go b/tcpip/link/rawfile/errors.go
new file mode 100644
index 0000000..d470676
--- /dev/null
+++ b/tcpip/link/rawfile/errors.go
@@ -0,0 +1,40 @@
+package rawfile
+
+import (
+ "syscall"
+
+ "github.com/google/netstack/tcpip"
+)
+
+var translations = map[syscall.Errno]*tcpip.Error{
+ syscall.EEXIST: tcpip.ErrDuplicateAddress,
+ syscall.ENETUNREACH: tcpip.ErrNoRoute,
+ syscall.EINVAL: tcpip.ErrInvalidEndpointState,
+ syscall.EALREADY: tcpip.ErrAlreadyConnecting,
+ syscall.EISCONN: tcpip.ErrAlreadyConnected,
+ syscall.EADDRINUSE: tcpip.ErrPortInUse,
+ syscall.EADDRNOTAVAIL: tcpip.ErrBadLocalAddress,
+ syscall.EPIPE: tcpip.ErrClosedForSend,
+ syscall.EWOULDBLOCK: tcpip.ErrWouldBlock,
+ syscall.ECONNREFUSED: tcpip.ErrConnectionRefused,
+ syscall.ETIMEDOUT: tcpip.ErrTimeout,
+ syscall.EINPROGRESS: tcpip.ErrConnectStarted,
+ syscall.EDESTADDRREQ: tcpip.ErrDestinationRequired,
+ syscall.ENOTSUP: tcpip.ErrNotSupported,
+ syscall.ENOTTY: tcpip.ErrQueueSizeNotSupported,
+ syscall.ENOTCONN: tcpip.ErrNotConnected,
+ syscall.ECONNRESET: tcpip.ErrConnectionReset,
+ syscall.ECONNABORTED: tcpip.ErrConnectionAborted,
+}
+
+// TranslateErrno translates an errno from the syscall package into a
+// *tcpip.Error.
+//
+// Not all errnos are supported; unrecognized errnos are translated to
+// tcpip.ErrInvalidEndpointState.
+func TranslateErrno(e syscall.Errno) *tcpip.Error {
+ if err, ok := translations[e]; ok {
+ return err
+ }
+ return tcpip.ErrInvalidEndpointState
+}
diff --git a/tcpip/link/rawfile/rawfile_unsafe.go b/tcpip/link/rawfile/rawfile_unsafe.go
index c602cb1..aea4569 100644
--- a/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/tcpip/link/rawfile/rawfile_unsafe.go
@@ -2,12 +2,15 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// Package rawfile contains utilities for using the netstack with raw host
+// files on Linux hosts.
package rawfile
import (
- "fmt"
"syscall"
"unsafe"
+
+ "github.com/google/netstack/tcpip"
)
// TODO: Placed here to avoid breakage caused by coverage
@@ -42,19 +45,15 @@
// NonBlockingWrite writes the given buffer to a file descriptor. It fails if
// partial data is written.
-func NonBlockingWrite(fd int, buf []byte) error {
+func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
var ptr unsafe.Pointer
if len(buf) > 0 {
ptr = unsafe.Pointer(&buf[0])
}
- n, _, e := syscall.RawSyscall(syscall.SYS_WRITE, uintptr(fd), uintptr(ptr), uintptr(len(buf)))
+ _, _, e := syscall.RawSyscall(syscall.SYS_WRITE, uintptr(fd), uintptr(ptr), uintptr(len(buf)))
if e != 0 {
- return e
- }
-
- if n != uintptr(len(buf)) {
- return fmt.Errorf("wrong number of bytes written: expected %d, got %d", len(buf), n)
+ return TranslateErrno(e)
}
return nil
@@ -62,7 +61,7 @@
// NonBlockingWrite2 writes up to two byte slices to a file descriptor in a
// single syscall. It fails if partial data is written.
-func NonBlockingWrite2(fd int, b1, b2 []byte) error {
+func NonBlockingWrite2(fd int, b1, b2 []byte) *tcpip.Error {
// If the is no second buffer, issue a regular write.
if len(b2) == 0 {
return NonBlockingWrite(fd, b1)
@@ -81,13 +80,9 @@
},
}
- n, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), 2)
+ _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), 2)
if e != 0 {
- return e
- }
-
- if n != uintptr(len(b1)+len(b2)) {
- return fmt.Errorf("wrong number of bytes written: expected %d, got %d", len(b1)+len(b2), n)
+ return TranslateErrno(e)
}
return nil
@@ -96,7 +91,7 @@
// BlockingRead reads from a file descriptor that is set up as non-blocking. If
// no data is available, it will block in a poll() syscall until the file
// descirptor becomes readable.
-func BlockingRead(fd int, b []byte) (int, error) {
+func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
for {
n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
if e == 0 {
@@ -114,7 +109,7 @@
_, e = blockingPoll(unsafe.Pointer(&event), 1, -1)
if e != 0 && e != syscall.EINTR {
- return 0, e
+ return 0, TranslateErrno(e)
}
}
}
@@ -122,7 +117,7 @@
// BlockingReadv reads from a file descriptor that is set up as non-blocking and
// stores the data in a list of iovecs buffers. If no data is available, it will
// block in a poll() syscall until the file descirptor becomes readable.
-func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, error) {
+func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) {
for {
n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
if e == 0 {
@@ -140,7 +135,7 @@
_, e = blockingPoll(unsafe.Pointer(&event), 1, -1)
if e != 0 && e != syscall.EINTR {
- return 0, e
+ return 0, TranslateErrno(e)
}
}
}
diff --git a/tcpip/link/sharedmem/pipe/pipe.go b/tcpip/link/sharedmem/pipe/pipe.go
new file mode 100644
index 0000000..1173a60
--- /dev/null
+++ b/tcpip/link/sharedmem/pipe/pipe.go
@@ -0,0 +1,68 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package pipe implements a shared memory ring buffer on which a single reader
+// and a single writer can operate (read/write) concurrently. The ring buffer
+// allows for data of different sizes to be written, and preserves the boundary
+// of the written data.
+//
+// Example usage is as follows:
+//
+// wb := t.Push(20)
+// // Write data to wb.
+// t.Flush()
+//
+// rb := r.Pull()
+// // Do something with data in rb.
+// r.Flush()
+package pipe
+
+import (
+ "math"
+)
+
+const (
+ jump uint64 = math.MaxUint32 + 1
+ offsetMask uint64 = math.MaxUint32
+ revolutionMask uint64 = ^offsetMask
+
+ sizeOfSlotHeader = 8 // sizeof(uint64)
+ slotFree uint64 = 1 << 63
+ slotSizeMask uint64 = math.MaxUint32
+)
+
+// payloadToSlotSize calculates the total size of a slot based on its payload
+// size. The total size is the header size, plus the payload size, plus padding
+// if necessary to make the total size a multiple of sizeOfSlotHeader.
+func payloadToSlotSize(payloadSize uint64) uint64 {
+ s := sizeOfSlotHeader + payloadSize
+ return (s + sizeOfSlotHeader - 1) &^ (sizeOfSlotHeader - 1)
+}
+
+// slotToPayloadSize calculates the payload size of a slot based on the total
+// size of the slot. This is only meant to be used when creating slots that
+// don't carry information (e.g., free slots or wrap slots).
+func slotToPayloadSize(offset uint64) uint64 {
+ return offset - sizeOfSlotHeader
+}
+
+// pipe is a basic data structure used by both (transmit & receive) ends of a
+// pipe. Indices into this pipe are split into two fields: offset, which counts
+// the number of bytes from the beginning of the buffer, and revolution, which
+// counts the number of times the index has wrapped around.
+type pipe struct {
+ buffer []byte
+}
+
+// init initializes the pipe buffer such that its size is a multiple of the size
+// of the slot header.
+func (p *pipe) init(b []byte) {
+ p.buffer = b[:len(b)&^(sizeOfSlotHeader-1)]
+}
+
+// data returns a section of the buffer starting at the given index (which may
+// include revolution information) and with the given size.
+func (p *pipe) data(idx uint64, size uint64) []byte {
+ return p.buffer[(idx&offsetMask)+sizeOfSlotHeader:][:size]
+}
diff --git a/tcpip/link/sharedmem/pipe/pipe_test.go b/tcpip/link/sharedmem/pipe/pipe_test.go
new file mode 100644
index 0000000..d35e7c9
--- /dev/null
+++ b/tcpip/link/sharedmem/pipe/pipe_test.go
@@ -0,0 +1,509 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package pipe
+
+import (
+ "math/rand"
+ "reflect"
+ "runtime"
+ "sync"
+ "testing"
+)
+
+func TestSimpleReadWrite(t *testing.T) {
+ // Check that a simple write can be properly read from the rx side.
+ tr := rand.New(rand.NewSource(99))
+ rr := rand.New(rand.NewSource(99))
+
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ wb := tx.Push(10)
+ if wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ for i := range wb {
+ wb[i] = byte(tr.Intn(256))
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ rb := rx.Pull()
+ if len(rb) != 10 {
+ t.Fatalf("Bad buffer size returned: got %v, want %v", len(rb), 10)
+ }
+
+ for i := range rb {
+ if v := byte(rr.Intn(256)); v != rb[i] {
+ t.Fatalf("Bad read buffer at index %v: got %v, want %v", i, rb[i], v)
+ }
+ }
+ rx.Flush()
+}
+
+func TestEmptyRead(t *testing.T) {
+ // Check that pulling from an empty pipe fails.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+}
+
+func TestTooLargeWrite(t *testing.T) {
+ // Check that writes that are too large are properly rejected.
+ b := make([]byte, 96)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(96); wb != nil {
+ t.Fatalf("Write of 96 bytes succeeded on 96-byte pipe")
+ }
+
+ if wb := tx.Push(88); wb != nil {
+ t.Fatalf("Write of 88 bytes succeeded on 96-byte pipe")
+ }
+
+ if wb := tx.Push(80); wb == nil {
+ t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
+ }
+}
+
+func TestFullWrite(t *testing.T) {
+ // Check that writes fail when the pipe is full.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(80); wb == nil {
+ t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
+ }
+
+ if wb := tx.Push(1); wb != nil {
+ t.Fatalf("Write succeeded on full pipe")
+ }
+}
+
+func TestFullAndFlushedWrite(t *testing.T) {
+ // Check that writes fail when the pipe is full and has already been
+ // flushed.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(80); wb == nil {
+ t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
+ }
+
+ tx.Flush()
+
+ if wb := tx.Push(1); wb != nil {
+ t.Fatalf("Write succeeded on full pipe")
+ }
+}
+
+func TestTxFlushTwice(t *testing.T) {
+ // Checks that a second consecutive tx flush is a no-op.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ // Make copy of original tx queue, flush it, then check that it didn't
+ // change.
+ orig := tx
+ tx.Flush()
+
+ if !reflect.DeepEqual(orig, tx) {
+ t.Fatalf("Flush mutated tx pipe: got %v, want %v", tx, orig)
+ }
+}
+
+func TestRxFlushTwice(t *testing.T) {
+ // Checks that a second consecutive rx flush is a no-op.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // Make copy of original rx queue, flush it, then check that it didn't
+ // change.
+ orig := rx
+ rx.Flush()
+
+ if !reflect.DeepEqual(orig, rx) {
+ t.Fatalf("Flush mutated rx pipe: got %v, want %v", rx, orig)
+ }
+}
+
+func TestWrapInMiddleOfTransaction(t *testing.T) {
+ // Check that writes are not flushed when we need to wrap the buffer
+ // around.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // At this point the ring buffer is empty, but the write is at offset
+ // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment).
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ // We haven't flushed yet, so pull must return nil.
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on non-flushed pipe")
+ }
+
+ tx.Flush()
+
+ // The two buffers must be available now.
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+}
+
+func TestWriteAbort(t *testing.T) {
+ // Check that a read fails on a pipe that has had data pushed to it but
+ // has aborted the push.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Write failed on empty pipe")
+ }
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+
+ tx.Abort()
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+}
+
+func TestWrappedWriteAbort(t *testing.T) {
+ // Check that writes are properly aborted even if the writes wrap
+ // around.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // At this point the ring buffer is empty, but the write is at offset
+ // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment).
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ // We haven't flushed yet, so pull must return nil.
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on non-flushed pipe")
+ }
+
+ tx.Abort()
+
+ // The pushes were aborted, so no data should be readable.
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on non-flushed pipe")
+ }
+
+ // Try the same transactions again, but flush this time.
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ tx.Flush()
+
+ // The two buffers must be available now.
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+}
+
+func TestEmptyReadOnNonFlushedWrite(t *testing.T) {
+ // Check that a read fails on a pipe that has had data pushed to it
+ // but not yet flushed.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Write failed on empty pipe")
+ }
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+
+ tx.Flush()
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull on failed on non-empty pipe")
+ }
+}
+
+func TestPullAfterPullingEntirePipe(t *testing.T) {
+ // Check that Pull fails when the pipe is full, but all of it has
+ // already been pulled but not yet flushed.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // At this point the ring buffer is empty, but the write is at offset
+ // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 3
+ // buffers that will fill the pipe.
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ if wb := tx.Push(20); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ if wb := tx.Push(24); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ tx.Flush()
+
+ // The three buffers must be available now.
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ // Fourth pull must fail.
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+}
+
+func TestNoRoomToWrapOnPush(t *testing.T) {
+ // Check that Push fails when it tries to allocate room to add a wrap
+ // message.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // At this point the ring buffer is empty, but the write is at offset
+ // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 20,
+ // which won't fit (64+20+8+padding = 96, which wouldn't leave room for
+ // the padding), so it wraps around.
+ if wb := tx.Push(20); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ tx.Flush()
+
+ // Buffer offset is at 28. Try to write 70, which would require a wrap
+ // slot which cannot be created now.
+ if wb := tx.Push(70); wb != nil {
+ t.Fatalf("Push succeeded on pipe with no room for wrap message")
+ }
+}
+
+func TestRxImplicitFlushOfWrapMessage(t *testing.T) {
+ // Check that if the first read is that of a wrapping message, it gets
+ // immediately flushed.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ // This will cause a wrapping message to be written.
+ if wb := tx.Push(60); wb != nil {
+ t.Fatalf("Push succeeded when there is no room in pipe")
+ }
+
+ var rx Rx
+ rx.Init(b)
+
+ // Read the first message.
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // This should fail because the wrapping message is taking up space.
+ if wb := tx.Push(60); wb != nil {
+ t.Fatalf("Push succeeded when there is no room in pipe")
+ }
+
+ // Try to read the next one. This should consume the wrapping message.
+ rx.Pull()
+
+ // This must now succeed.
+ if wb := tx.Push(60); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+}
+
+func TestConcurrentReaderWriter(t *testing.T) {
+ // Push a million buffers of random sizes and random contents. Check
+ // that buffers read match what was written.
+ tr := rand.New(rand.NewSource(99))
+ rr := rand.New(rand.NewSource(99))
+
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ var rx Rx
+ rx.Init(b)
+
+ const count = 1000000
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ runtime.Gosched()
+ total := 0
+ for i := 0; i < count; i++ {
+ n := 1 + tr.Intn(80)
+ total += n
+ wb := tx.Push(uint64(n))
+ for wb == nil {
+ wb = tx.Push(uint64(n))
+ }
+
+ for j := range wb {
+ wb[j] = byte(tr.Intn(256))
+ }
+
+ tx.Flush()
+ }
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ runtime.Gosched()
+ for i := 0; i < count; i++ {
+ n := 1 + rr.Intn(80)
+ rb := rx.Pull()
+ for rb == nil {
+ rb = rx.Pull()
+ }
+
+ if n != len(rb) {
+ t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n)
+ }
+
+ for j := range rb {
+ if v := byte(rr.Intn(256)); v != rb[j] {
+ t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v)
+ }
+ }
+
+ rx.Flush()
+ }
+ }()
+
+ wg.Wait()
+}
diff --git a/tcpip/link/sharedmem/pipe/pipe_unsafe.go b/tcpip/link/sharedmem/pipe/pipe_unsafe.go
new file mode 100644
index 0000000..d536abe
--- /dev/null
+++ b/tcpip/link/sharedmem/pipe/pipe_unsafe.go
@@ -0,0 +1,25 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package pipe
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+func (p *pipe) write(idx uint64, v uint64) {
+ ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0]))
+ *ptr = v
+}
+
+func (p *pipe) writeAtomic(idx uint64, v uint64) {
+ ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0]))
+ atomic.StoreUint64(ptr, v)
+}
+
+func (p *pipe) readAtomic(idx uint64) uint64 {
+ ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0]))
+ return atomic.LoadUint64(ptr)
+}
diff --git a/tcpip/link/sharedmem/pipe/rx.go b/tcpip/link/sharedmem/pipe/rx.go
new file mode 100644
index 0000000..5bf2b91
--- /dev/null
+++ b/tcpip/link/sharedmem/pipe/rx.go
@@ -0,0 +1,78 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package pipe
+
+// Rx is the receive side of the shared memory ring buffer.
+type Rx struct {
+ p pipe
+
+ tail uint64
+ head uint64
+}
+
+// Init initializes the receive end of the pipe. In the initial state, the next
+// slot to be inspected is the very first one.
+func (r *Rx) Init(b []byte) {
+ r.p.init(b)
+ r.tail = 0xfffffffe * jump
+ r.head = r.tail
+}
+
+// Pull reads the next buffer from the pipe, returning nil if there isn't one
+// currently available.
+//
+// The returned slice is available until Flush() is next called. After that, it
+// must not be touched.
+func (r *Rx) Pull() []byte {
+ if r.head == r.tail+jump {
+ // We've already pulled the whole pipe.
+ return nil
+ }
+
+ header := r.p.readAtomic(r.head)
+ if header&slotFree != 0 {
+ // The next slot is free, we can't pull it yet.
+ return nil
+ }
+
+ payloadSize := header & slotSizeMask
+ newHead := r.head + payloadToSlotSize(payloadSize)
+ headWrap := (r.head & revolutionMask) | uint64(len(r.p.buffer))
+
+ // Check if this is a wrapping slot. If that's the case, it carries no
+ // data, so we just skip it and try again from the first slot.
+ if int64(newHead-headWrap) >= 0 {
+ if int64(newHead-headWrap) > int64(jump) || newHead&offsetMask != 0 {
+ return nil
+ }
+
+ if r.tail == r.head {
+ // If this is the first pull since the last Flush()
+ // call, we flush the state so that the sender can use
+ // this space if it needs to.
+ r.p.writeAtomic(r.head, slotFree|slotToPayloadSize(newHead-r.head))
+ r.tail = newHead
+ }
+
+ r.head = newHead
+ return r.Pull()
+ }
+
+ // Grab the buffer before updating r.head.
+ b := r.p.data(r.head, payloadSize)
+ r.head = newHead
+ return b
+}
+
+// Flush tells the transmitter that all buffers pulled since the last Flush()
+// have been used, so the transmitter is free to use their slots for further
+// transmission.
+func (r *Rx) Flush() {
+ if r.head == r.tail {
+ return
+ }
+ r.p.writeAtomic(r.tail, slotFree|slotToPayloadSize(r.head-r.tail))
+ r.tail = r.head
+}
diff --git a/tcpip/link/sharedmem/pipe/tx.go b/tcpip/link/sharedmem/pipe/tx.go
new file mode 100644
index 0000000..10938e1
--- /dev/null
+++ b/tcpip/link/sharedmem/pipe/tx.go
@@ -0,0 +1,138 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package pipe
+
+// Tx is the transmit side of the shared memory ring buffer.
+type Tx struct {
+ p pipe
+ maxPayloadSize uint64
+
+ head uint64
+ tail uint64
+ next uint64
+
+ tailHeader uint64
+}
+
+// Init initializes the transmit end of the pipe. In the initial state, the next
+// slot to be written is the very first one, and the transmitter has the whole
+// ring buffer available to it.
+func (t *Tx) Init(b []byte) {
+ t.p.init(b)
+ // maxPayloadSize excludes the header of the payload, and the header
+ // of the wrapping message.
+ t.maxPayloadSize = uint64(len(t.p.buffer)) - 2*sizeOfSlotHeader
+ t.tail = 0xfffffffe * jump
+ t.next = t.tail
+ t.head = t.tail + jump
+ t.p.write(t.tail, slotFree)
+}
+
+// Push reserves "payloadSize" bytes for transmission in the pipe. The caller
+// populates the returned slice with the data to be transferred and eventually
+// calls Flush() to make the data visible to the reader, or Abort() to make the
+// pipe forget all Push() calls since the last Flush().
+//
+// The returned slice is available until Flush() or Abort() is next called.
+// After that, it must not be touched.
+func (t *Tx) Push(payloadSize uint64) []byte {
+ // Fail request if we know we will never have enough room.
+ if payloadSize > t.maxPayloadSize {
+ return nil
+ }
+
+ totalLen := payloadToSlotSize(payloadSize)
+ newNext := t.next + totalLen
+ nextWrap := (t.next & revolutionMask) | uint64(len(t.p.buffer))
+ if int64(newNext-nextWrap) >= 0 {
+ // The new buffer would overflow the pipe, so we push a wrapping
+ // slot, then try to add the actual slot to the front of the
+ // pipe.
+ newNext = (newNext & revolutionMask) + jump
+ wrappingPayloadSize := slotToPayloadSize(newNext - t.next)
+ if !t.reclaim(newNext) {
+ return nil
+ }
+
+ oldNext := t.next
+ t.next = newNext
+ if oldNext != t.tail {
+ t.p.write(oldNext, wrappingPayloadSize)
+ } else {
+ t.tailHeader = wrappingPayloadSize
+ t.Flush()
+ }
+
+ newNext += totalLen
+ }
+
+ // Check that we have enough room for the buffer.
+ if !t.reclaim(newNext) {
+ return nil
+ }
+
+ if t.next != t.tail {
+ t.p.write(t.next, payloadSize)
+ } else {
+ t.tailHeader = payloadSize
+ }
+
+ // Grab the buffer before updating t.next.
+ b := t.p.data(t.next, payloadSize)
+ t.next = newNext
+
+ return b
+}
+
+// reclaim attempts to advance the head until at least newNext. If the head is
+// already at or beyond newNext, nothing happens and true is returned; otherwise
+// it tries to reclaim slots that have already been consumed by the receive end
+// of the pipe (they will be marked as free) and returns a boolean indicating
+// whether it was successful in reclaiming enough slots.
+func (t *Tx) reclaim(newNext uint64) bool {
+ for int64(newNext-t.head) > 0 {
+ // Can't reclaim if slot is not free.
+ header := t.p.readAtomic(t.head)
+ if header&slotFree == 0 {
+ return false
+ }
+
+ payloadSize := header & slotSizeMask
+ newHead := t.head + payloadToSlotSize(payloadSize)
+
+ // Check newHead is within bounds and valid.
+ if int64(newHead-t.tail) > int64(jump) || newHead&offsetMask >= uint64(len(t.p.buffer)) {
+ return false
+ }
+
+ t.head = newHead
+ }
+
+ return true
+}
+
+// Abort causes all Push() calls since the last Flush() to be forgotten and
+// therefore they will not be made visible to the receiver.
+func (t *Tx) Abort() {
+ t.next = t.tail
+}
+
+// Flush causes all buffers pushed since the last Flush() [or Abort(), whichever
+// is the most recent] to be made visible to the receiver.
+func (t *Tx) Flush() {
+ if t.next == t.tail {
+ // Nothing to do if there are no pushed buffers.
+ return
+ }
+
+ if t.next != t.head {
+ // The receiver will spin on the slot at t.next, so we must make
+ // sure that its slotFree bit is set.
+ t.p.write(t.next, slotFree)
+ }
+
+ t.p.writeAtomic(t.tail, t.tailHeader)
+ t.tail = t.next
+}
diff --git a/tcpip/link/sharedmem/queue/queue_test.go b/tcpip/link/sharedmem/queue/queue_test.go
new file mode 100644
index 0000000..7a2c842
--- /dev/null
+++ b/tcpip/link/sharedmem/queue/queue_test.go
@@ -0,0 +1,507 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package queue
+
+import (
+ "encoding/binary"
+ "reflect"
+ "testing"
+
+ "github.com/google/netstack/tcpip/link/sharedmem/pipe"
+)
+
+func TestBasicTxQueue(t *testing.T) {
+ // Tests that a basic transmit on a queue works, and that completion
+ // gets properly reported as well.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Tx
+ q.Init(pb1, pb2)
+
+ // Enqueue two buffers.
+ b := []TxBuffer{
+ {nil, 100, 60},
+ {nil, 200, 40},
+ }
+
+ b[0].Next = &b[1]
+
+ const usedID = 1002
+ const usedTotalSize = 100
+ if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) {
+ t.Fatalf("Enqueue failed on empty queue")
+ }
+
+ // Check the contents of the pipe.
+ d := rxp.Pull()
+ if d == nil {
+ t.Fatalf("Tx pipe is empty after Enqueue")
+ }
+
+ want := []byte{
+ 234, 3, 0, 0, 0, 0, 0, 0, // id
+ 100, 0, 0, 0, // total size
+ 0, 0, 0, 0, // reserved
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 60, 0, 0, 0, // size 1
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 40, 0, 0, 0, // size 2
+ }
+
+ if !reflect.DeepEqual(want, d) {
+ t.Fatalf("Bad posted packet: got %v, want %v", d, want)
+ }
+
+ rxp.Flush()
+
+ // Check that there are no completions yet.
+ if _, ok := q.CompletedPacket(); ok {
+ t.Fatalf("Packet reported as completed too soon")
+ }
+
+ // Post a completion.
+ d = txp.Push(8)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+ binary.LittleEndian.PutUint64(d, usedID)
+ txp.Flush()
+
+ // Check that completion is properly reported.
+ id, ok := q.CompletedPacket()
+ if !ok {
+ t.Fatalf("Completion not reported")
+ }
+
+ if id != usedID {
+ t.Fatalf("Bad completion id: got %v, want %v", id, usedID)
+ }
+}
+
+func TestBasicRxQueue(t *testing.T) {
+ // Tests that a basic receive on a queue works.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Rx
+ q.Init(pb1, pb2, nil)
+
+ // Post two buffers.
+ b := []RxBuffer{
+ {100, 60, 1077},
+ {200, 40, 2123},
+ }
+
+ if !q.PostBuffers(b) {
+ t.Fatalf("PostBuffers failed on empty queue")
+ }
+
+ // Check the contents of the pipe.
+ want := [][]byte{
+ {
+ 100, 0, 0, 0, 0, 0, 0, 0, // Offset1
+ 60, 0, 0, 0, // Size1
+ 0, 0, 0, 0, // Remaining in group 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // User data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+ },
+ {
+ 200, 0, 0, 0, 0, 0, 0, 0, // Offset2
+ 40, 0, 0, 0, // Size2
+ 0, 0, 0, 0, // Remaining in group 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // User data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ },
+ }
+
+ for i := range b {
+ d := rxp.Pull()
+ if d == nil {
+ t.Fatalf("Tx pipe is empty after PostBuffers")
+ }
+
+ if !reflect.DeepEqual(want[i], d) {
+ t.Fatalf("Bad posted packet: got %v, want %v", d, want[i])
+ }
+
+ rxp.Flush()
+ }
+
+ // Check that there are no completions.
+ if _, n := q.Dequeue(nil); n != 0 {
+ t.Fatalf("Packet reported as received too soon")
+ }
+
+ // Post a completion.
+ d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+
+ copy(d, []byte{
+ 100, 0, 0, 0, // packet size
+ 0, 0, 0, 0, // reserved
+
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 60, 0, 0, 0, // size 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 40, 0, 0, 0, // size 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ })
+
+ txp.Flush()
+
+ // Check that completion is properly reported.
+ bufs, n := q.Dequeue(nil)
+ if n != 100 {
+ t.Fatalf("Bad packet size: got %v, want %v", n, 100)
+ }
+
+ if !reflect.DeepEqual(bufs, b) {
+ t.Fatalf("Bad returned buffers: got %v, want %v", bufs, b)
+ }
+}
+
+func TestBadTxCompletion(t *testing.T) {
+ // Check that tx completions with bad sizes are properly ignored.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Tx
+ q.Init(pb1, pb2)
+
+ // Post a completion that is too short, and check that it is ignored.
+ if d := txp.Push(7); d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+ txp.Flush()
+
+ if _, ok := q.CompletedPacket(); ok {
+ t.Fatalf("Bad completion not ignored")
+ }
+
+ // Post a completion that is too long, and check that it is ignored.
+ if d := txp.Push(10); d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+ txp.Flush()
+
+ if _, ok := q.CompletedPacket(); ok {
+ t.Fatalf("Bad completion not ignored")
+ }
+}
+
+func TestBadRxCompletion(t *testing.T) {
+ // Check that bad rx completions are properly ignored.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Rx
+ q.Init(pb1, pb2, nil)
+
+ // Post a completion that is too short, and check that it is ignored.
+ if d := txp.Push(7); d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+ txp.Flush()
+
+ if b, _ := q.Dequeue(nil); b != nil {
+ t.Fatalf("Bad completion not ignored")
+ }
+
+ // Post a completion whose buffer sizes add up to less than the total
+ // size.
+ d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+
+ copy(d, []byte{
+ 100, 0, 0, 0, // packet size
+ 0, 0, 0, 0, // reserved
+
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 10, 0, 0, 0, // size 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 10, 0, 0, 0, // size 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ })
+
+ txp.Flush()
+ if b, _ := q.Dequeue(nil); b != nil {
+ t.Fatalf("Bad completion not ignored")
+ }
+
+ // Post a completion whose buffer sizes will cause a 32-bit overflow,
+ // but adds up to the right number.
+ d = txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+
+ copy(d, []byte{
+ 100, 0, 0, 0, // packet size
+ 0, 0, 0, 0, // reserved
+
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 255, 255, 255, 255, // size 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 101, 0, 0, 0, // size 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ })
+
+ txp.Flush()
+ if b, _ := q.Dequeue(nil); b != nil {
+ t.Fatalf("Bad completion not ignored")
+ }
+}
+
+func TestFillTxPipe(t *testing.T) {
+ // Check that transmitting a new buffer when the buffer pipe is full
+ // fails gracefully.
+ pb1 := make([]byte, 104)
+ pb2 := make([]byte, 104)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Tx
+ q.Init(pb1, pb2)
+
+ // Transmit twice, which should fill the tx pipe.
+ b := []TxBuffer{
+ {nil, 100, 60},
+ {nil, 200, 40},
+ }
+
+ b[0].Next = &b[1]
+
+ const usedID = 1002
+ const usedTotalSize = 100
+ for i := uint64(0); i < 2; i++ {
+ if !q.Enqueue(usedID+i, usedTotalSize, 2, &b[0]) {
+ t.Fatalf("Failed to transmit buffer")
+ }
+ }
+
+ // Transmit another packet now that the tx pipe is full.
+ if q.Enqueue(usedID+2, usedTotalSize, 2, &b[0]) {
+ t.Fatalf("Enqueue succeeded when tx pipe is full")
+ }
+}
+
+func TestFillRxPipe(t *testing.T) {
+ // Check that posting a new buffer when the buffer pipe is full fails
+ // gracefully.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Rx
+ q.Init(pb1, pb2, nil)
+
+ // Post a buffer twice, it should fill the tx pipe.
+ b := []RxBuffer{
+ {100, 60, 1077},
+ }
+
+ for i := 0; i < 2; i++ {
+ if !q.PostBuffers(b) {
+ t.Fatalf("PostBuffers failed on non-full queue")
+ }
+ }
+
+ // Post another buffer now that the tx pipe is full.
+ if q.PostBuffers(b) {
+ t.Fatalf("PostBuffers succeeded on full queue")
+ }
+}
+
+func TestLotsOfTransmissions(t *testing.T) {
+ // Make sure pipes are being properly flushed when transmitting packets.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Tx
+ q.Init(pb1, pb2)
+
+ // Prepare packet with two buffers.
+ b := []TxBuffer{
+ {nil, 100, 60},
+ {nil, 200, 40},
+ }
+
+ b[0].Next = &b[1]
+
+ const usedID = 1002
+ const usedTotalSize = 100
+
+ // Post 100000 packets and completions.
+ for i := 100000; i > 0; i-- {
+ if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) {
+ t.Fatalf("Enqueue failed on non-full queue")
+ }
+
+ if d := rxp.Pull(); d == nil {
+ t.Fatalf("Tx pipe is empty after Enqueue")
+ }
+ rxp.Flush()
+
+ d := txp.Push(8)
+ if d == nil {
+ t.Fatalf("Unable to write to rx pipe")
+ }
+ binary.LittleEndian.PutUint64(d, usedID)
+ txp.Flush()
+ if _, ok := q.CompletedPacket(); !ok {
+ t.Fatalf("Completion not returned")
+ }
+ }
+}
+
+func TestLotsOfReceptions(t *testing.T) {
+ // Make sure pipes are being properly flushed when receiving packets.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Rx
+ q.Init(pb1, pb2, nil)
+
+ // Prepare for posting two buffers.
+ b := []RxBuffer{
+ {100, 60, 1077},
+ {200, 40, 2123},
+ }
+
+ // Post 100000 buffers and completions.
+ for i := 100000; i > 0; i-- {
+ if !q.PostBuffers(b) {
+ t.Fatalf("PostBuffers failed on non-full queue")
+ }
+
+ if d := rxp.Pull(); d == nil {
+ t.Fatalf("Tx pipe is empty after PostBuffers")
+ }
+ rxp.Flush()
+
+ if d := rxp.Pull(); d == nil {
+ t.Fatalf("Tx pipe is empty after PostBuffers")
+ }
+ rxp.Flush()
+
+ d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+
+ copy(d, []byte{
+ 100, 0, 0, 0, // packet size
+ 0, 0, 0, 0, // reserved
+
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 60, 0, 0, 0, // size 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 40, 0, 0, 0, // size 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ })
+
+ txp.Flush()
+
+ if _, n := q.Dequeue(nil); n == 0 {
+ t.Fatalf("Dequeue failed when there is a completion")
+ }
+ }
+}
+
+func TestRxEnableNotification(t *testing.T) {
+ // Check that enabling notifications results in properly updated state.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var state uint32
+ var q Rx
+ q.Init(pb1, pb2, &state)
+
+ q.EnableNotification()
+ if state != eventFDEnabled {
+ t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDEnabled)
+ }
+}
+
+func TestRxDisableNotification(t *testing.T) {
+ // Check that disabling notifications results in properly updated state.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var state uint32
+ var q Rx
+ q.Init(pb1, pb2, &state)
+
+ q.DisableNotification()
+ if state != eventFDDisabled {
+ t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDDisabled)
+ }
+}
diff --git a/tcpip/link/sharedmem/queue/rx.go b/tcpip/link/sharedmem/queue/rx.go
new file mode 100644
index 0000000..4d50993
--- /dev/null
+++ b/tcpip/link/sharedmem/queue/rx.go
@@ -0,0 +1,168 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package queue provides the implementation of transmit and receive queues
+// based on shared memory ring buffers.
+package queue
+
+import (
+ "encoding/binary"
+ "sync/atomic"
+
+ "github.com/google/netstack/tcpip/link/sharedmem/pipe"
+ "log"
+)
+
+const (
+ // Offsets within a posted buffer.
+ postedOffset = 0
+ postedSize = 8
+ postedRemainingInGroup = 12
+ postedUserData = 16
+ postedID = 24
+
+ sizeOfPostedBuffer = 32
+
+ // Offsets within a received packet header.
+ consumedPacketSize = 0
+ consumedPacketReserved = 4
+
+ sizeOfConsumedPacketHeader = 8
+
+ // Offsets within a consumed buffer.
+ consumedOffset = 0
+ consumedSize = 8
+ consumedUserData = 12
+ consumedID = 20
+
+ sizeOfConsumedBuffer = 28
+
+ // The following are the allowed states of the shared data area.
+ eventFDUninitialized = 0
+ eventFDDisabled = 1
+ eventFDEnabled = 2
+)
+
+// RxBuffer is the descriptor of a receive buffer.
+type RxBuffer struct {
+ Offset uint64
+ Size uint32
+ ID uint64
+}
+
+// Rx is a receive queue. It is implemented with one tx and one rx pipe: the tx
+// pipe is used to "post" buffers, while the rx pipe is used to receive packets
+// whose contents have been written to previously posted buffers.
+//
+// This struct is thread-compatible.
+type Rx struct {
+ tx pipe.Tx
+ rx pipe.Rx
+ sharedEventFDState *uint32
+}
+
+// Init initializes the receive queue with the given pipes, and shared state
+// pointer -- the latter is used to enable/disable eventfd notifications.
+func (r *Rx) Init(tx, rx []byte, sharedEventFDState *uint32) {
+ r.sharedEventFDState = sharedEventFDState
+ r.tx.Init(tx)
+ r.rx.Init(rx)
+}
+
+// EnableNotification updates the shared state such that the peer will notify
+// the eventfd when there are packets to be dequeued.
+func (r *Rx) EnableNotification() {
+ atomic.StoreUint32(r.sharedEventFDState, eventFDEnabled)
+}
+
+// DisableNotification updates the shared state such that the peer will not
+// notify the eventfd.
+func (r *Rx) DisableNotification() {
+ atomic.StoreUint32(r.sharedEventFDState, eventFDDisabled)
+}
+
+// PostBuffers makes the given buffers available for receiving data from the
+// peer. Once they are posted, the peer is free to write to them and will
+// eventually post them back for consumption.
+func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
+ for i := range buffers {
+ b := r.tx.Push(sizeOfPostedBuffer)
+ if b == nil {
+ r.tx.Abort()
+ return false
+ }
+
+ pb := &buffers[i]
+ binary.LittleEndian.PutUint64(b[postedOffset:], pb.Offset)
+ binary.LittleEndian.PutUint32(b[postedSize:], pb.Size)
+ binary.LittleEndian.PutUint32(b[postedRemainingInGroup:], 0)
+ binary.LittleEndian.PutUint64(b[postedUserData:], 0)
+ binary.LittleEndian.PutUint64(b[postedID:], pb.ID)
+ }
+
+ r.tx.Flush()
+
+ return true
+}
+
+// Dequeue receives buffers that have been previously posted by PostBuffers()
+// and that have been filled by the peer and posted back.
+//
+// This is similar to append() in that new buffers are appended to "bufs", with
+// reallocation only if "bufs" doesn't have enough capacity.
+func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) {
+ for {
+ outBufs := bufs
+
+ // Pull the next descriptor from the rx pipe.
+ b := r.rx.Pull()
+ if b == nil {
+ return bufs, 0
+ }
+
+ if len(b) < sizeOfConsumedPacketHeader {
+ log.Printf("Ignoring packet header: size (%v) is less than header size (%v)", len(b), sizeOfConsumedPacketHeader)
+ r.rx.Flush()
+ continue
+ }
+
+ totalDataSize := binary.LittleEndian.Uint32(b[consumedPacketSize:])
+
+ // Calculate the number of buffer descriptors and copy them
+ // over to the output.
+ count := (len(b) - sizeOfConsumedPacketHeader) / sizeOfConsumedBuffer
+ offset := sizeOfConsumedPacketHeader
+ buffersSize := uint32(0)
+ for i := count; i > 0; i-- {
+ s := binary.LittleEndian.Uint32(b[offset+consumedSize:])
+ buffersSize += s
+ if buffersSize < s {
+ // The buffer size overflows an unsigned 32-bit
+ // integer, so break out and force it to be
+ // ignored.
+ totalDataSize = 1
+ buffersSize = 0
+ break
+ }
+
+ outBufs = append(outBufs, RxBuffer{
+ Offset: binary.LittleEndian.Uint64(b[offset+consumedOffset:]),
+ Size: s,
+ ID: binary.LittleEndian.Uint64(b[offset+consumedID:]),
+ })
+
+ offset += sizeOfConsumedBuffer
+ }
+
+ r.rx.Flush()
+
+ if buffersSize < totalDataSize {
+ // The descriptor is corrupted, ignore it.
+ log.Printf("Ignoring packet: actual data size (%v) less than expected size (%v)", buffersSize, totalDataSize)
+ continue
+ }
+
+ return outBufs, totalDataSize
+ }
+}
diff --git a/tcpip/link/sharedmem/queue/tx.go b/tcpip/link/sharedmem/queue/tx.go
new file mode 100644
index 0000000..d16e726
--- /dev/null
+++ b/tcpip/link/sharedmem/queue/tx.go
@@ -0,0 +1,103 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package queue
+
+import (
+ "encoding/binary"
+
+ "github.com/google/netstack/tcpip/link/sharedmem/pipe"
+ "log"
+)
+
+const (
+ // Offsets within a packet header.
+ packetID = 0
+ packetSize = 8
+ packetReserved = 12
+
+ sizeOfPacketHeader = 16
+
+ // Offsets within a buffer descriptor.
+ bufferOffset = 0
+ bufferSize = 8
+
+ sizeOfBufferDescriptor = 12
+)
+
+// TxBuffer is the descriptor of a transmit buffer.
+type TxBuffer struct {
+ Next *TxBuffer
+ Offset uint64
+ Size uint32
+}
+
+// Tx is a transmit queue. It is implemented with one tx and one rx pipe: the
+// tx pipe is used to request the transmission of packets, while the rx pipe
+// is used to receive which transmissions have completed.
+//
+// This struct is thread-compatible.
+type Tx struct {
+ tx pipe.Tx
+ rx pipe.Rx
+}
+
+// Init initializes the transmit queue with the given pipes.
+func (t *Tx) Init(tx, rx []byte) {
+ t.tx.Init(tx)
+ t.rx.Init(rx)
+}
+
+// Enqueue queues the given linked list of buffers for transmission as one
+// packet. While it is queued, the caller must not modify them.
+func (t *Tx) Enqueue(id uint64, totalDataLen, bufferCount uint32, buffer *TxBuffer) bool {
+ // Reserve room in the tx pipe.
+ totalLen := sizeOfPacketHeader + uint64(bufferCount)*sizeOfBufferDescriptor
+
+ b := t.tx.Push(totalLen)
+ if b == nil {
+ return false
+ }
+
+ // Initialize the packet and buffer descriptors.
+ binary.LittleEndian.PutUint64(b[packetID:], id)
+ binary.LittleEndian.PutUint32(b[packetSize:], totalDataLen)
+ binary.LittleEndian.PutUint32(b[packetReserved:], 0)
+
+ offset := sizeOfPacketHeader
+ for i := bufferCount; i != 0; i-- {
+ binary.LittleEndian.PutUint64(b[offset+bufferOffset:], buffer.Offset)
+ binary.LittleEndian.PutUint32(b[offset+bufferSize:], buffer.Size)
+ offset += sizeOfBufferDescriptor
+ buffer = buffer.Next
+ }
+
+ t.tx.Flush()
+
+ return true
+}
+
+// CompletedPacket returns the id of the next completed transmission pulled
+// from the rx pipe, skipping completions whose size is not exactly 8 bytes.
+// The returned id, if any, refers to a value passed to a previous Enqueue().
+func (t *Tx) CompletedPacket() (id uint64, ok bool) {
+ for {
+ b := t.rx.Pull()
+ if b == nil {
+ return 0, false
+ }
+
+ if len(b) != 8 {
+ t.rx.Flush()
+ log.Printf("Ignoring completed packet: size (%v) does not match expected size (%v)", len(b), 8)
+ continue
+ }
+
+ v := binary.LittleEndian.Uint64(b)
+
+ t.rx.Flush()
+
+ return v, true
+ }
+}
diff --git a/tcpip/link/sniffer/sniffer.go b/tcpip/link/sniffer/sniffer.go
index ef335a5..1b7d975 100644
--- a/tcpip/link/sniffer/sniffer.go
+++ b/tcpip/link/sniffer/sniffer.go
@@ -12,6 +12,7 @@
import (
"fmt"
+ "sync/atomic"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
@@ -20,6 +21,10 @@
"log"
)
+// LogPackets is a flag used to enable or disable packet logging. Valid values
+// are 0 (disabled) or 1 (enabled); it is read with atomic.LoadUint32.
+var LogPackets uint32 = 1
+
type endpoint struct {
dispatcher stack.NetworkDispatcher
lower stack.LinkEndpoint
@@ -37,7 +42,9 @@
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
- LogPacket("recv", protocol, vv.First(), nil)
+ if atomic.LoadUint32(&LogPackets) == 1 {
+ LogPacket("recv", protocol, vv.First(), nil)
+ }
e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, protocol, vv)
}
@@ -68,8 +75,10 @@
// WritePacket implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and forwards
// the request to the lower endpoint.
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) error {
- LogPacket("send", protocol, hdr.UsedBytes(), payload)
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ if atomic.LoadUint32(&LogPackets) == 1 {
+ LogPacket("send", protocol, hdr.UsedBytes(), payload)
+ }
return e.lower.WritePacket(r, hdr, payload, protocol)
}
diff --git a/tcpip/network/arp/arp.go b/tcpip/network/arp/arp.go
index c0ac551..5707f97 100644
--- a/tcpip/network/arp/arp.go
+++ b/tcpip/network/arp/arp.go
@@ -16,8 +16,6 @@
package arp
import (
- "fmt"
-
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
@@ -65,7 +63,7 @@
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -109,9 +107,9 @@
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, error) {
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
if addr != ProtocolAddress {
- return nil, fmt.Errorf("arp: invalid endpoint address %q", addr)
+ return nil, tcpip.ErrBadLocalAddress
}
return &endpoint{
nicid: nicid,
@@ -125,7 +123,7 @@
return header.IPv4ProtocolNumber
}
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) error {
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
RemoteLinkAddress: broadcastMAC,
}
diff --git a/tcpip/network/fragmentation/frag_heap_test.go b/tcpip/network/fragmentation/frag_heap_test.go
index 6d92d2d..ba074ac 100644
--- a/tcpip/network/fragmentation/frag_heap_test.go
+++ b/tcpip/network/fragmentation/frag_heap_test.go
@@ -20,56 +20,56 @@
{
comment: "Non-overlapping in-order",
in: []fragment{
- fragment{offset: 0, vv: vv(1, "0")},
- fragment{offset: 1, vv: vv(1, "1")},
+ {offset: 0, vv: vv(1, "0")},
+ {offset: 1, vv: vv(1, "1")},
},
want: vv(2, "0", "1"),
},
{
comment: "Non-overlapping out-of-order",
in: []fragment{
- fragment{offset: 1, vv: vv(1, "1")},
- fragment{offset: 0, vv: vv(1, "0")},
+ {offset: 1, vv: vv(1, "1")},
+ {offset: 0, vv: vv(1, "0")},
},
want: vv(2, "0", "1"),
},
{
comment: "Duplicated packets",
in: []fragment{
- fragment{offset: 0, vv: vv(1, "0")},
- fragment{offset: 0, vv: vv(1, "0")},
+ {offset: 0, vv: vv(1, "0")},
+ {offset: 0, vv: vv(1, "0")},
},
want: vv(1, "0"),
},
{
comment: "Overlapping in-order",
in: []fragment{
- fragment{offset: 0, vv: vv(2, "01")},
- fragment{offset: 1, vv: vv(2, "12")},
+ {offset: 0, vv: vv(2, "01")},
+ {offset: 1, vv: vv(2, "12")},
},
want: vv(3, "01", "2"),
},
{
comment: "Overlapping out-of-order",
in: []fragment{
- fragment{offset: 1, vv: vv(2, "12")},
- fragment{offset: 0, vv: vv(2, "01")},
+ {offset: 1, vv: vv(2, "12")},
+ {offset: 0, vv: vv(2, "01")},
},
want: vv(3, "01", "2"),
},
{
comment: "Overlapping subset in-order",
in: []fragment{
- fragment{offset: 0, vv: vv(3, "012")},
- fragment{offset: 1, vv: vv(1, "1")},
+ {offset: 0, vv: vv(3, "012")},
+ {offset: 1, vv: vv(1, "1")},
},
want: vv(3, "012"),
},
{
comment: "Overlapping subset out-of-order",
in: []fragment{
- fragment{offset: 1, vv: vv(1, "1")},
- fragment{offset: 0, vv: vv(3, "012")},
+ {offset: 1, vv: vv(1, "1")},
+ {offset: 0, vv: vv(3, "012")},
},
want: vv(3, "012"),
},
diff --git a/tcpip/network/fragmentation/fragmentation_test.go b/tcpip/network/fragmentation/fragmentation_test.go
index 0967fce..8ef89fd 100644
--- a/tcpip/network/fragmentation/fragmentation_test.go
+++ b/tcpip/network/fragmentation/fragmentation_test.go
@@ -49,27 +49,27 @@
{
comment: "One ID",
in: []processInput{
- processInput{id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
- processInput{id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
- processOutput{vv: emptyVv(), done: false},
- processOutput{vv: vv(4, "01", "23"), done: true},
+ {vv: emptyVv(), done: false},
+ {vv: vv(4, "01", "23"), done: true},
},
},
{
comment: "Two IDs",
in: []processInput{
- processInput{id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
- processInput{id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")},
- processInput{id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")},
- processInput{id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")},
+ {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")},
+ {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
- processOutput{vv: emptyVv(), done: false},
- processOutput{vv: emptyVv(), done: false},
- processOutput{vv: vv(4, "ab", "cd"), done: true},
- processOutput{vv: vv(4, "01", "23"), done: true},
+ {vv: emptyVv(), done: false},
+ {vv: emptyVv(), done: false},
+ {vv: vv(4, "ab", "cd"), done: true},
+ {vv: vv(4, "01", "23"), done: true},
},
},
}
@@ -132,3 +132,16 @@
t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
}
}
+
+func TestFragmentationViewsDoNotEscape(t *testing.T) {
+ f := NewFragmentation(1024, DefaultReassembleTimeout)
+ in := vv(2, "0", "1")
+ f.Process(0, 0, 1, true, in)
+ // Modify input view.
+ in.RemoveFirst()
+ got, _ := f.Process(0, 2, 2, false, vv(1, "2"))
+ want := vv(3, "0", "1", "2")
+ if !reflect.DeepEqual(got, *want) {
+ t.Errorf("Process() returned a wrong vv. Got %v. Want %v", got, *want)
+ }
+}
diff --git a/tcpip/network/fragmentation/reassembler.go b/tcpip/network/fragmentation/reassembler.go
index dabc713..46dd26b 100644
--- a/tcpip/network/fragmentation/reassembler.go
+++ b/tcpip/network/fragmentation/reassembler.go
@@ -80,7 +80,8 @@
}
if r.updateHoles(first, last, more) {
// We store the incoming packet only if it filled some holes.
- heap.Push(&r.heap, fragment{offset: first, vv: vv})
+ uu := vv.Clone(nil)
+ heap.Push(&r.heap, fragment{offset: first, vv: &uu})
consumed = vv.Size()
r.size += consumed
}
diff --git a/tcpip/network/fragmentation/reassembler_test.go b/tcpip/network/fragmentation/reassembler_test.go
index 131664b..b646043 100644
--- a/tcpip/network/fragmentation/reassembler_test.go
+++ b/tcpip/network/fragmentation/reassembler_test.go
@@ -24,38 +24,38 @@
{
comment: "No fragments. Expected holes: {[0 -> inf]}.",
in: []updateHolesInput{},
- want: []hole{hole{first: 0, last: math.MaxUint16, deleted: false}},
+ want: []hole{{first: 0, last: math.MaxUint16, deleted: false}},
},
{
comment: "One fragment at beginning. Expected holes: {[2, inf]}.",
in: []updateHolesInput{{first: 0, last: 1, more: true}},
want: []hole{
- hole{first: 0, last: math.MaxUint16, deleted: true},
- hole{first: 2, last: math.MaxUint16, deleted: false},
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 2, last: math.MaxUint16, deleted: false},
},
},
{
comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.",
in: []updateHolesInput{{first: 1, last: 2, more: true}},
want: []hole{
- hole{first: 0, last: math.MaxUint16, deleted: true},
- hole{first: 0, last: 0, deleted: false},
- hole{first: 3, last: math.MaxUint16, deleted: false},
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 0, last: 0, deleted: false},
+ {first: 3, last: math.MaxUint16, deleted: false},
},
},
{
comment: "One fragment at the end. Expected holes: {[0, 0]}.",
in: []updateHolesInput{{first: 1, last: 2, more: false}},
want: []hole{
- hole{first: 0, last: math.MaxUint16, deleted: true},
- hole{first: 0, last: 0, deleted: false},
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 0, last: 0, deleted: false},
},
},
{
comment: "One fragment completing a packet. Expected holes: {}.",
in: []updateHolesInput{{first: 0, last: 1, more: false}},
want: []hole{
- hole{first: 0, last: math.MaxUint16, deleted: true},
+ {first: 0, last: math.MaxUint16, deleted: true},
},
},
{
@@ -65,8 +65,8 @@
{first: 2, last: 3, more: false},
},
want: []hole{
- hole{first: 0, last: math.MaxUint16, deleted: true},
- hole{first: 2, last: math.MaxUint16, deleted: true},
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 2, last: math.MaxUint16, deleted: true},
},
},
{
@@ -76,8 +76,8 @@
{first: 2, last: 3, more: false},
},
want: []hole{
- hole{first: 0, last: math.MaxUint16, deleted: true},
- hole{first: 3, last: math.MaxUint16, deleted: true},
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 3, last: math.MaxUint16, deleted: true},
},
},
}
diff --git a/tcpip/network/ip_test.go b/tcpip/network/ip_test.go
index 315f754..5ac1efe 100644
--- a/tcpip/network/ip_test.go
+++ b/tcpip/network/ip_test.go
@@ -89,7 +89,7 @@
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) error {
+func (t *testObject) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
diff --git a/tcpip/network/ipv4/icmp.go b/tcpip/network/ipv4/icmp.go
index d8f3a3e..72bcfb8 100644
--- a/tcpip/network/ipv4/icmp.go
+++ b/tcpip/network/ipv4/icmp.go
@@ -62,7 +62,7 @@
}
}
-func sendICMPv4(r *stack.Route, typ header.ICMPv4Type, code byte, data buffer.View) error {
+func sendICMPv4(r *stack.Route, typ header.ICMPv4Type, code byte, data buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
@@ -85,7 +85,7 @@
// Ping sends echo requests to an ICMPv4 endpoint.
// Responses are streamed to the channel ch.
-func (p *Pinger) Ping(ctx context.Context, ch chan<- PingReply) error {
+func (p *Pinger) Ping(ctx context.Context, ch chan<- PingReply) *tcpip.Error {
count := p.Count
if count == 0 {
count = 1<<16 - 1
@@ -110,13 +110,13 @@
RemoteAddress: p.Addr,
}
- _, err = p.Stack.PickEphemeralPort(func(port uint16) (bool, error) {
+ _, err = p.Stack.PickEphemeralPort(func(port uint16) (bool, *tcpip.Error) {
id.LocalPort = port
err := p.Stack.RegisterTransportEndpoint(p.NICID, netProtos, pingProtocolNumber, id, ep)
switch err {
case nil:
return true, nil
- case tcpip.ErrDuplicateAddress:
+ case tcpip.ErrPortInUse:
return false, nil
default:
return false, err
@@ -174,14 +174,14 @@
// PingReply summarizes an ICMP echo reply.
type PingReply struct {
- Error error // reports any errors sending a ping request
+ Error *tcpip.Error // reports any errors sending a ping request
Duration time.Duration
SeqNumber uint16
}
type pingProtocol struct{}
-func (*pingProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, error) {
+func (*pingProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return nil, tcpip.ErrNotSupported // endpoints are created directly
}
@@ -189,7 +189,7 @@
func (*pingProtocol) MinimumPacketSize() int { return header.ICMPv4EchoMinimumSize }
-func (*pingProtocol) ParsePorts(v buffer.View) (src, dst uint16, err error) {
+func (*pingProtocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
ident := binary.BigEndian.Uint16(v[4:])
return 0, ident, nil
}
diff --git a/tcpip/network/ipv4/ipv4.go b/tcpip/network/ipv4/ipv4.go
index 4da9b91..4b64763 100644
--- a/tcpip/network/ipv4/ipv4.go
+++ b/tcpip/network/ipv4/ipv4.go
@@ -99,7 +99,7 @@
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
length := uint16(hdr.UsedLength() + len(payload))
id := uint32(0)
@@ -185,7 +185,7 @@
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, error) {
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return newEndpoint(nicid, addr, dispatcher, linkEP), nil
}
diff --git a/tcpip/network/ipv6/icmp.go b/tcpip/network/ipv6/icmp.go
index 485d86a..0c7e5e6 100644
--- a/tcpip/network/ipv6/icmp.go
+++ b/tcpip/network/ipv6/icmp.go
@@ -81,7 +81,7 @@
return header.IPv6ProtocolNumber
}
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) error {
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
// Solicited-Node multicast address, used for NDP. Described in RFC 4291.
snaddr := "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff" + addr[len(addr)-3:]
r := &stack.Route{
diff --git a/tcpip/network/ipv6/ipv6.go b/tcpip/network/ipv6/ipv6.go
index a98d7fb..f06efdc 100644
--- a/tcpip/network/ipv6/ipv6.go
+++ b/tcpip/network/ipv6/ipv6.go
@@ -74,7 +74,7 @@
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
length := uint16(hdr.UsedLength())
if payload != nil {
length += uint16(len(payload))
@@ -139,7 +139,7 @@
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, error) {
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
e := &endpoint{
nicid: nicid,
linkEP: linkEP,
diff --git a/tcpip/ports/ports.go b/tcpip/ports/ports.go
index 82f104a..7fc78a2 100644
--- a/tcpip/ports/ports.go
+++ b/tcpip/ports/ports.go
@@ -62,7 +62,7 @@
// possible ephemeral ports, allowing the caller to decide whether a given port
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
-func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, error)) (port uint16, err error) {
+func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
count := uint16(math.MaxUint16 - firstEphemeral + 1)
offset := uint16(rand.Int31n(int32(count)))
@@ -85,7 +85,7 @@
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(network []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) (reservedPort uint16, err error) {
+func (s *PortManager) ReservePort(network []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -99,7 +99,7 @@
}
// A port wasn't specified, so try to find one.
- return s.PickEphemeralPort(func(p uint16) (bool, error) {
+ return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
return s.reserveSpecificPort(network, transport, addr, p), nil
})
}
diff --git a/tcpip/ports/ports_test.go b/tcpip/ports/ports_test.go
index 375905b..8ce7a3b 100644
--- a/tcpip/ports/ports_test.go
+++ b/tcpip/ports/ports_test.go
@@ -4,7 +4,6 @@
package ports
import (
- "errors"
"testing"
"github.com/google/netstack/tcpip"
@@ -25,7 +24,7 @@
for _, test := range []struct {
port uint16
ip tcpip.Address
- want error
+ want *tcpip.Error
}{
{
port: 80,
@@ -84,30 +83,30 @@
func TestPickEphemeralPort(t *testing.T) {
pm := NewPortManager()
- customErr := errors.New("fake error")
+ customErr := &tcpip.Error{}
for _, test := range []struct {
name string
- f func(port uint16) (bool, error)
- wantErr error
+ f func(port uint16) (bool, *tcpip.Error)
+ wantErr *tcpip.Error
wantPort uint16
}{
{
name: "no-port-available",
- f: func(port uint16) (bool, error) {
+ f: func(port uint16) (bool, *tcpip.Error) {
return false, nil
},
wantErr: tcpip.ErrNoPortAvailable,
},
{
name: "port-tester-error",
- f: func(port uint16) (bool, error) {
+ f: func(port uint16) (bool, *tcpip.Error) {
return false, customErr
},
wantErr: customErr,
},
{
name: "only-port-16042-available",
- f: func(port uint16) (bool, error) {
+ f: func(port uint16) (bool, *tcpip.Error) {
if port == firstEphemeral+42 {
return true, nil
}
@@ -117,7 +116,7 @@
},
{
name: "only-port-under-16000-available",
- f: func(port uint16) (bool, error) {
+ f: func(port uint16) (bool, *tcpip.Error) {
if port < firstEphemeral {
return true, nil
}
diff --git a/tcpip/sample/tun_tcp_connect/main.go b/tcpip/sample/tun_tcp_connect/main.go
index ddd08da..1391d90 100644
--- a/tcpip/sample/tun_tcp_connect/main.go
+++ b/tcpip/sample/tun_tcp_connect/main.go
@@ -146,9 +146,9 @@
// Create TCP endpoint.
var wq waiter.Queue
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
+ if e != nil {
- log.Fatal(err)
+ log.Fatal(e)
}
// Bind if a port is specified.
@@ -161,16 +161,16 @@
// Issue connect request and wait for it to complete.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventOut)
- err = ep.Connect(remote)
- if err == tcpip.ErrConnectStarted {
+ terr := ep.Connect(remote)
+ if terr == tcpip.ErrConnectStarted {
fmt.Println("Connect is pending...")
<-notifyCh
- err = ep.GetSockOpt(tcpip.ErrorOption{})
+ terr = ep.GetSockOpt(tcpip.ErrorOption{})
}
wq.EventUnregister(&waitEntry)
- if err != nil {
- log.Fatal("Unable to connect: ", err)
+ if terr != nil {
+ log.Fatal("Unable to connect: ", terr)
}
fmt.Println("Connected")
diff --git a/tcpip/sample/tun_tcp_echo/main.go b/tcpip/sample/tun_tcp_echo/main.go
index 381e4b7..92989df 100644
--- a/tcpip/sample/tun_tcp_echo/main.go
+++ b/tcpip/sample/tun_tcp_echo/main.go
@@ -120,9 +120,9 @@
// Create TCP endpoint, bind it, then start listening.
var wq waiter.Queue
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
+ ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
- if err != nil {
+ if e != nil {
- log.Fatal(err)
+ log.Fatal(e)
}
defer ep.Close()
diff --git a/tcpip/stack/nic.go b/tcpip/stack/nic.go
index 062129b..a19424f 100644
--- a/tcpip/stack/nic.go
+++ b/tcpip/stack/nic.go
@@ -88,7 +88,7 @@
return ref
}
-func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, replace bool) (*referencedNetworkEndpoint, error) {
+func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
return nil, tcpip.ErrUnknownProtocol
@@ -137,7 +137,7 @@
// AddAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) error {
+func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
_, err := n.addAddressLocked(protocol, addr, false)
@@ -196,7 +196,7 @@
}
// RemoveAddress removes an address from n.
-func (n *NIC) RemoveAddress(addr tcpip.Address) error {
+func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
r := n.endpoints[NetworkEndpointID{addr}]
if r == nil || !r.holdsInsertRef {
diff --git a/tcpip/stack/registration.go b/tcpip/stack/registration.go
index 25e97ca..0954ed8 100644
--- a/tcpip/stack/registration.go
+++ b/tcpip/stack/registration.go
@@ -52,7 +52,7 @@
Number() tcpip.TransportProtocolNumber
// NewEndpoint creates a new endpoint of the transport protocol.
- NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, error)
+ NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
// MinimumPacketSize returns the minimum valid packet size of this
// transport protocol. The stack automatically drops any packets smaller
@@ -61,7 +61,7 @@
// ParsePorts returns the source and destination ports stored in a
// packet of this protocol.
- ParsePorts(v buffer.View) (src, dst uint16, err error)
+ ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this
// protocol but that don't match any existing endpoint. For example,
@@ -97,7 +97,7 @@
// WritePacket writes a packet to the given destination address and
// protocol.
- WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) error
+ WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
@@ -133,7 +133,7 @@
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, error)
+ NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
}
// NetworkDispatcher contains the methods used by the network stack to deliver
@@ -167,7 +167,7 @@
// WritePacket writes a packet with the given protocol through the given
// route.
- WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) error
+ WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
@@ -182,7 +182,7 @@
//
// A valid response will cause the discovery protocol's network
// endpoint to call AddLinkAddress.
- LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) error
+ LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
// LinkAddressProtocol returns the network protocol of the
- // addresses this this resolver can resolve.
+ // addresses this resolver can resolve.
diff --git a/tcpip/stack/route.go b/tcpip/stack/route.go
index 1cf6d99..c3eca94 100644
--- a/tcpip/stack/route.go
+++ b/tcpip/stack/route.go
@@ -76,7 +76,7 @@
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) error {
+func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
if r.RemoteLinkAddress == "" && r.ref.linkRes != nil && !isLoopback(r.RemoteAddress) {
nextAddr := r.NextHop
if nextAddr == "" {
diff --git a/tcpip/stack/stack.go b/tcpip/stack/stack.go
index 7c1ddbf..3d9584d 100644
--- a/tcpip/stack/stack.go
+++ b/tcpip/stack/stack.go
@@ -132,7 +132,7 @@
}
// NewEndpoint creates a new transport layer endpoint of the given protocol.
-func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, error) {
+func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
t, ok := s.transportProtocols[transport]
if !ok {
return nil, tcpip.ErrUnknownProtocol
@@ -143,7 +143,7 @@
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
-func (s *Stack) createNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID, enabled bool) error {
+func (s *Stack) createNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error {
ep := FindLinkEndpoint(linkEP)
if ep == nil {
return tcpip.ErrBadLinkEndpoint
@@ -168,20 +168,20 @@
}
// CreateNIC creates a NIC with the provided id and link-layer endpoint.
-func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) error {
+func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
return s.createNIC(id, linkEP, true)
}
// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
- // but leave it disable. Stack.EnableNIC must be called before the link-layer
+ // but leaves it disabled. Stack.EnableNIC must be called before the link-layer
// endpoint starts delivering packets to it.
-func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) error {
+func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
return s.createNIC(id, linkEP, false)
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
// delivering packets to it.
-func (s *Stack) EnableNIC(id tcpip.NICID) error {
+func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -209,7 +209,7 @@
}
// AddAddress adds a new network-layer address to the specified NIC.
-func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) error {
+func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -222,7 +222,7 @@
}
// AddSubnet adds a subnet range to the specified NIC.
-func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) error {
+func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -237,7 +237,7 @@
// RemoveAddress removes an existing network-layer address from the specified
// NIC.
-func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) error {
+func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -251,7 +251,7 @@
// FindRoute creates a route to the given destination address, leaving through
// the given nic and local address (if provided).
-func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (Route, error) {
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -328,7 +328,7 @@
}
// SetPromiscuousMode enables or disables promiscuous mode in the given NIC.
-func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) error {
+func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -363,7 +363,7 @@
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) error {
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
if nicID == 0 {
return s.demux.registerEndpoint(netProtos, protocol, id, ep)
}
@@ -399,7 +399,7 @@
// JoinGroup joins the given multicast group on every interface that
// matches the given interface address.
// TODO: notify network of subscription via igmp protocol
-func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, interfaceAddr tcpip.Address, multicastAddr tcpip.Address) error {
+func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, interfaceAddr tcpip.Address, multicastAddr tcpip.Address) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -422,7 +422,7 @@
// LeaveGroup leaves the given multicast group on every interface that
// matches the given interface address.
-func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, interfaceAddr tcpip.Address, multicastAddr tcpip.Address) error {
+func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, interfaceAddr tcpip.Address, multicastAddr tcpip.Address) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
diff --git a/tcpip/stack/stack_test.go b/tcpip/stack/stack_test.go
index 4033b5a..25c459f 100644
--- a/tcpip/stack/stack_test.go
+++ b/tcpip/stack/stack_test.go
@@ -74,7 +74,7 @@
return 0
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -109,7 +109,7 @@
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, error) {
+func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &fakeNetworkEndpoint{
nicid: nicid,
id: stack.NetworkEndpointID{addr},
diff --git a/tcpip/stack/transport_demuxer.go b/tcpip/stack/transport_demuxer.go
index 2531efc..210905c 100644
--- a/tcpip/stack/transport_demuxer.go
+++ b/tcpip/stack/transport_demuxer.go
@@ -46,7 +46,7 @@
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
for i, n := range netProtos {
if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil {
d.unregisterEndpoint(netProtos[:i], protocol, id)
@@ -57,7 +57,7 @@
return nil
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
return nil
@@ -67,7 +67,7 @@
defer eps.mu.Unlock()
if _, ok := eps.endpoints[id]; ok {
- return tcpip.ErrDuplicateAddress
+ return tcpip.ErrPortInUse
}
eps.endpoints[id] = ep
diff --git a/tcpip/stack/transport_test.go b/tcpip/stack/transport_test.go
index 7308e4e..b6d057b 100644
--- a/tcpip/stack/transport_test.go
+++ b/tcpip/stack/transport_test.go
@@ -5,7 +5,6 @@
package stack_test
import (
- "io"
"testing"
"github.com/google/netstack/tcpip"
@@ -47,11 +46,11 @@
return mask
}
-func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, error) {
+func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, *tcpip.Error) {
return buffer.View{}, nil
}
-func (f *fakeTransportEndpoint) Write(v buffer.View, _ *tcpip.FullAddress) (uintptr, error) {
+func (f *fakeTransportEndpoint) Write(v buffer.View, _ *tcpip.FullAddress) (uintptr, *tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
return 0, tcpip.ErrNoRoute
}
@@ -65,17 +64,17 @@
return uintptr(len(v)), nil
}
-func (f *fakeTransportEndpoint) Peek(io.Writer) (uintptr, error) {
+func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, *tcpip.Error) {
return 0, nil
}
// SetSockOpt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOpt(interface{}) error {
+func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) error {
+func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch opt.(type) {
case tcpip.ErrorOption:
return nil
@@ -83,7 +82,7 @@
return tcpip.ErrInvalidEndpointState
}
-func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) error {
+func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
f.peerAddr = addr.Addr
// Find the route.
@@ -105,34 +104,34 @@
return nil
}
-func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) error {
+func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
return nil
}
-func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) error {
+func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error {
return nil
}
func (*fakeTransportEndpoint) Reset() {
}
-func (*fakeTransportEndpoint) Listen(int) error {
+func (*fakeTransportEndpoint) Listen(int) *tcpip.Error {
return nil
}
-func (*fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, error) {
+func (*fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, nil
}
-func (*fakeTransportEndpoint) Bind(_ tcpip.FullAddress, commit func() error) error {
+func (*fakeTransportEndpoint) Bind(_ tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
return commit()
}
-func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, error) {
+func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, nil
}
-func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, error) {
+func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, nil
}
@@ -151,7 +150,7 @@
return fakeTransNumber
}
-func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, error) {
+func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newFakeTransportEndpoint(stack, f, netProto), nil
}
@@ -159,7 +158,7 @@
return fakeTransHeaderLen
}
-func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err error) {
+func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) {
return 0, 0, nil
}
diff --git a/tcpip/tcpip.go b/tcpip/tcpip.go
index f9b5e5e..3faef89 100644
--- a/tcpip/tcpip.go
+++ b/tcpip/tcpip.go
@@ -21,41 +21,52 @@
import (
"errors"
"fmt"
- "io"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/waiter"
)
+// Error represents an error in the netstack error space. Using a special type
+// ensures that errors outside of this space are not accidentally introduced.
+type Error struct {
+ string
+}
+
+// String implements fmt.Stringer.String.
+func (e *Error) String() string {
+ return e.string
+}
+
// Errors that can be returned by the network stack.
var (
- ErrUnknownProtocol = errors.New("unknown protocol")
- ErrUnknownNICID = errors.New("unknown nic id")
- ErrDuplicateNICID = errors.New("duplicate nic id")
- ErrDuplicateAddress = errors.New("duplicate address")
- ErrNoRoute = errors.New("no route")
- ErrBadLinkEndpoint = errors.New("bad link layer endpoint")
- ErrAlreadyBound = errors.New("endpoint already bound")
- ErrInvalidEndpointState = errors.New("endpoint is in invalid state")
- ErrAlreadyConnecting = errors.New("endpoint is already connecting")
- ErrAlreadyConnected = errors.New("endpoint is already connected")
- ErrNoPortAvailable = errors.New("no ports are available")
- ErrPortInUse = errors.New("port is in use")
- ErrBadLocalAddress = errors.New("bad local address")
- ErrClosedForSend = errors.New("endpoint is closed for send")
- ErrClosedForReceive = errors.New("endpoint is closed for receive")
- ErrWouldBlock = errors.New("operation would block")
- ErrConnectionRefused = errors.New("connection was refused")
- ErrTimeout = errors.New("operation timed out")
- ErrAborted = errors.New("operation aborted")
- ErrConnectStarted = errors.New("connection attempt started")
- ErrDestinationRequired = errors.New("destination address is required")
- ErrNotSupported = errors.New("operation not supported")
- ErrQueueSizeNotSupported = errors.New("queue size querying not supported")
- ErrNotConnected = errors.New("endpoint not connected")
- ErrConnectionReset = errors.New("connection reset by peer")
- ErrConnectionAborted = errors.New("connection aborted")
- ErrNoLinkAddress = errors.New("no remote link address")
+ ErrUnknownProtocol = &Error{"unknown protocol"}
+ ErrUnknownNICID = &Error{"unknown nic id"}
+ ErrUnknownProtocolOption = &Error{"unknown option for protocol"}
+ ErrDuplicateNICID = &Error{"duplicate nic id"}
+ ErrDuplicateAddress = &Error{"duplicate address"}
+ ErrNoRoute = &Error{"no route"}
+ ErrBadLinkEndpoint = &Error{"bad link layer endpoint"}
+ ErrAlreadyBound = &Error{"endpoint already bound"}
+ ErrInvalidEndpointState = &Error{"endpoint is in invalid state"}
+ ErrAlreadyConnecting = &Error{"endpoint is already connecting"}
+ ErrAlreadyConnected = &Error{"endpoint is already connected"}
+ ErrNoPortAvailable = &Error{"no ports are available"}
+ ErrPortInUse = &Error{"port is in use"}
+ ErrBadLocalAddress = &Error{"bad local address"}
+ ErrClosedForSend = &Error{"endpoint is closed for send"}
+ ErrClosedForReceive = &Error{"endpoint is closed for receive"}
+ ErrWouldBlock = &Error{"operation would block"}
+ ErrConnectionRefused = &Error{"connection was refused"}
+ ErrTimeout = &Error{"operation timed out"}
+ ErrAborted = &Error{"operation aborted"}
+ ErrConnectStarted = &Error{"connection attempt started"}
+ ErrDestinationRequired = &Error{"destination address is required"}
+ ErrNotSupported = &Error{"operation not supported"}
+ ErrQueueSizeNotSupported = &Error{"queue size querying not supported"}
+ ErrNotConnected = &Error{"endpoint not connected"}
+ ErrConnectionReset = &Error{"connection reset by peer"}
+ ErrConnectionAborted = &Error{"connection aborted"}
+ ErrNoLinkAddress = &Error{"no remote link address"}
)
// Errors related to Subnet
@@ -178,17 +189,17 @@
// Read reads data from the endpoint and optionally returns the sender.
// This method does not block if there is no data pending.
// It will also either return an error or data, never both.
- Read(*FullAddress) (buffer.View, error)
+ Read(*FullAddress) (buffer.View, *Error)
// Write writes data to the endpoint's peer, or the provided address if
// one is specified. This method does not block if the data cannot be
// written.
- Write(buffer.View, *FullAddress) (uintptr, error)
+ Write(buffer.View, *FullAddress) (uintptr, *Error)
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
- Peek(io.Writer) (uintptr, error)
+ Peek([][]byte) (uintptr, *Error)
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
@@ -201,22 +212,22 @@
// the endpoint becomes writable. (This mimics the
// connect(2) syscall behavior.)
// Anything else -- the attempt to connect failed.
- Connect(address FullAddress) error
+ Connect(address FullAddress) *Error
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
- Shutdown(flags ShutdownFlags) error
+ Shutdown(flags ShutdownFlags) *Error
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
- Listen(backlog int) error
+ Listen(backlog int) *Error
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode. This method does not
// block if no new connections are available.
//
// The returned Queue is the wait queue for the newly created endpoint.
- Accept() (Endpoint, *waiter.Queue, error)
+ Accept() (Endpoint, *waiter.Queue, *Error)
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
@@ -224,25 +235,25 @@
// An optional commit function will be executed atomically with respect
// to binding the endpoint. If this returns an error, the bind will not
// occur and the error will be propagated back to the caller.
- Bind(address FullAddress, commit func() error) error
+ Bind(address FullAddress, commit func() *Error) *Error
// GetLocalAddress returns the address to which the endpoint is bound.
- GetLocalAddress() (FullAddress, error)
+ GetLocalAddress() (FullAddress, *Error)
// GetRemoteAddress returns the address to which the endpoint is
// connected.
- GetRemoteAddress() (FullAddress, error)
+ GetRemoteAddress() (FullAddress, *Error)
// Readiness returns the current readiness of the endpoint. For example,
// if waiter.EventIn is set, the endpoint is immediately readable.
Readiness(mask waiter.EventMask) waiter.EventMask
// SetSockOpt sets a socket option. opt should be one of the *Option types.
- SetSockOpt(opt interface{}) error
+ SetSockOpt(opt interface{}) *Error
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
- GetSockOpt(opt interface{}) error
+ GetSockOpt(opt interface{}) *Error
}
// ErrorOption is used in GetSockOpt to specify that the last error reported by
@@ -257,6 +268,10 @@
// receive buffer size option.
type ReceiveBufferSizeOption int
+// SendQueueSizeOption is used in GetSockOpt to specify that the number of
+// unread bytes in the output buffer should be returned.
+type SendQueueSizeOption int
+
// ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of
// unread bytes in the input buffer should be returned.
type ReceiveQueueSizeOption int
@@ -349,17 +364,17 @@
type Stack interface {
// NewEndpoint creates a new transport layer endpoint of the given
// protocol.
- NewEndpoint(transport TransportProtocolNumber, network NetworkProtocolNumber, waiterQueue *waiter.Queue) (Endpoint, error)
+ NewEndpoint(transport TransportProtocolNumber, network NetworkProtocolNumber, waiterQueue *waiter.Queue) (Endpoint, *Error)
// SetRouteTable assigns the route table to be used by this stack. It
// specifies which NICs to use for given destination address ranges.
SetRouteTable(table []Route)
// CreateNIC creates a NIC with the provided id and link-layer sender.
- CreateNIC(id NICID, linkEndpoint LinkEndpointID) error
+ CreateNIC(id NICID, linkEndpoint LinkEndpointID) *Error
// AddAddress adds a new network-layer address to the specified NIC.
- AddAddress(id NICID, protocol NetworkProtocolNumber, addr Address) error
+ AddAddress(id NICID, protocol NetworkProtocolNumber, addr Address) *Error
// Stats returns a snapshot of the current stats.
// TODO: Make stats available in sentry for debugging/diag.
diff --git a/tcpip/transport/queue/queue.go b/tcpip/transport/queue/queue.go
index a269232..b9cd3d1 100644
--- a/tcpip/transport/queue/queue.go
+++ b/tcpip/transport/queue/queue.go
@@ -86,7 +86,7 @@
//
// If notify is true, ReaderQueue.Notify must be called:
// q.ReaderQueue.Notify(waiter.EventIn)
-func (q *Queue) Enqueue(e Entry) (notify bool, err error) {
+func (q *Queue) Enqueue(e Entry) (notify bool, err *tcpip.Error) {
q.mu.Lock()
if q.closed {
@@ -112,7 +112,7 @@
//
// If notify is true, WriterQueue.Notify must be called:
// q.WriterQueue.Notify(waiter.EventOut)
-func (q *Queue) Dequeue() (e Entry, notify bool, err error) {
+func (q *Queue) Dequeue() (e Entry, notify bool, err *tcpip.Error) {
q.mu.Lock()
if q.dataList.Front() == nil {
@@ -139,7 +139,7 @@
}
// Peek returns the first entry in the data queue, if one exists.
-func (q *Queue) Peek() (Entry, error) {
+func (q *Queue) Peek() (Entry, *tcpip.Error) {
q.mu.Lock()
defer q.mu.Unlock()
@@ -159,3 +159,8 @@
func (q *Queue) QueuedSize() int64 {
return q.used
}
+
+// MaxQueueSize returns the maximum number of bytes storable in the queue.
+func (q *Queue) MaxQueueSize() int64 {
+ return q.limit
+}
diff --git a/tcpip/transport/tcp/accept.go b/tcpip/transport/tcp/accept.go
index 5561601..c5eecc6 100644
--- a/tcpip/transport/tcp/accept.go
+++ b/tcpip/transport/tcp/accept.go
@@ -183,7 +183,7 @@
// createConnectedEndpoint creates a new connected endpoint, with the connection
// parameters given by the arguments.
-func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, mss uint16, sndWndScale int) (*endpoint, error) {
+func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, mss uint16, sndWndScale int) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -218,7 +218,7 @@
// createEndpoint creates a new endpoint in connected state and then performs
// the TCP 3-way handshake.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, mss uint16, sndWndScale int) (*endpoint, error) {
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, mss uint16, sndWndScale int) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
cookie := l.createCookie(s.id, irs, encodeMSS(mss))
@@ -313,7 +313,7 @@
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
-func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) error {
+func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
// handleSynSegment() from attempting to queue new connections
diff --git a/tcpip/transport/tcp/connect.go b/tcpip/transport/tcp/connect.go
index 883fa78..a64fe87 100644
--- a/tcpip/transport/tcp/connect.go
+++ b/tcpip/transport/tcp/connect.go
@@ -75,7 +75,7 @@
rcvWndScale int
}
-func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, error) {
+func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, *tcpip.Error) {
h := handshake{
ep: ep,
active: true,
@@ -108,10 +108,10 @@
// resetState resets the state of the handshake object such that it becomes
// ready for a new 3-way handshake.
-func (h *handshake) resetState() error {
+func (h *handshake) resetState() *tcpip.Error {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
- return err
+ panic(err)
}
h.state = handshakeSynSent
@@ -164,7 +164,7 @@
// synSentState handles a segment received when the TCP 3-way handshake is in
// the SYN-SENT state.
-func (h *handshake) synSentState(s *segment) error {
+func (h *handshake) synSentState(s *segment) *tcpip.Error {
// RFC 793, page 37, states that in the SYN-SENT state, a reset is
// acceptable if the ack field acknowledges the SYN.
if s.flagIsSet(flagRst) {
@@ -215,7 +215,7 @@
// synRcvdState handles a segment received when the TCP 3-way handshake is in
// the SYN-RCVD state.
-func (h *handshake) synRcvdState(s *segment) error {
+func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
if s.flagIsSet(flagRst) {
// RFC 793, page 37, states that in the SYN-RCVD state, a reset
// is acceptable if the sequence number is in the window.
@@ -264,7 +264,7 @@
// processSegments goes through the segment queue and processes up to
// maxSegmentsPerWake (if they're available).
-func (h *handshake) processSegments() error {
+func (h *handshake) processSegments() *tcpip.Error {
for i := 0; i < maxSegmentsPerWake; i++ {
s := h.ep.segmentQueue.dequeue()
if s == nil {
@@ -276,7 +276,7 @@
h.sndWnd <<= uint8(h.sndWndScale)
}
- var err error
+ var err *tcpip.Error
switch h.state {
case handshakeSynRcvd:
err = h.synRcvdState(s)
@@ -306,7 +306,7 @@
}
// execute executes the TCP 3-way handshake.
-func (h *handshake) execute() error {
+func (h *handshake) execute() *tcpip.Error {
// Initialize the resend timer.
resendWaker := sleep.Waker{}
timeOut := time.Duration(time.Second)
@@ -404,7 +404,7 @@
return mss, ws, true
}
-func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, rcvWndScale int) error {
+func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, rcvWndScale int) *tcpip.Error {
// Initialize the options.
mss := r.MTU() - header.TCPMinimumSize
options := []byte{
@@ -425,7 +425,7 @@
// sendTCPWithOptions sends a TCP segment with the provided options via the
// provided network endpoint and under the provided identity.
-func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) error {
+func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error {
// Allocate a buffer for the TCP header.
hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + len(opts))
@@ -460,7 +460,7 @@
// sendTCP sends a TCP segment via the provided network endpoint and under the
// provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) error {
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
// Allocate a buffer for the TCP header.
hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()))
@@ -493,7 +493,7 @@
}
// sendRaw sends a TCP segment to the endpoint's peer.
-func (e *endpoint) sendRaw(data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) error {
+func (e *endpoint) sendRaw(data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
return sendTCP(&e.route, e.id, data, flags, seq, ack, rcvWnd)
}
@@ -537,7 +537,7 @@
// resetConnection sends a RST segment and puts the endpoint in an error state
// with the given error code.
// This method must only be called from the protocol goroutine.
-func (e *endpoint) resetConnection(err error) {
+func (e *endpoint) resetConnection(err *tcpip.Error) {
e.sendRaw(nil, flagAck|flagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
e.mu.Lock()
@@ -614,7 +614,7 @@
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
-func (e *endpoint) protocolMainLoop(passive bool) error {
+func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
var closeTimer *time.Timer
var closeWaker sleep.Waker
diff --git a/tcpip/transport/tcp/dual_stack_test.go b/tcpip/transport/tcp/dual_stack_test.go
index 90e2ba8..ca3258b 100644
--- a/tcpip/transport/tcp/dual_stack_test.go
+++ b/tcpip/transport/tcp/dual_stack_test.go
@@ -72,7 +72,7 @@
}
func (c *testContext) createV6Endpoint(v4only bool) {
- var err error
+ var err *tcpip.Error
c.ep, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
@@ -666,7 +666,7 @@
defer c.cleanup()
// Create TCP endpoint.
- var err error
+ var err *tcpip.Error
c.ep, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
diff --git a/tcpip/transport/tcp/endpoint.go b/tcpip/transport/tcp/endpoint.go
index 2bebc1e..e925245 100644
--- a/tcpip/transport/tcp/endpoint.go
+++ b/tcpip/transport/tcp/endpoint.go
@@ -5,7 +5,6 @@
package tcp
import (
- "io"
"math"
"sync"
"sync/atomic"
@@ -63,7 +62,7 @@
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
lastErrorMu sync.Mutex
- lastError error
+ lastError *tcpip.Error
// The following fields are used to manage the receive queue. The
// protocol goroutine adds ready-for-delivery segments to rcvList,
@@ -98,7 +97,7 @@
// hardError is meaningful only when state is stateError, it stores the
// error to be returned when read/write syscalls are called and the
// endpoint is in this state.
- hardError error
+ hardError *tcpip.Error
// workerRunning specifies if a worker goroutine is running.
workerRunning bool
@@ -161,7 +160,6 @@
stack: stack,
netProto: netProto,
waiterQueue: waiterQueue,
- v6only: true,
rcvBufSize: defaultBufferSize,
sndBufSize: defaultBufferSize,
noDelay: true,
@@ -295,20 +293,17 @@
}
// Read reads data from the endpoint.
-func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, error) {
+func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, *tcpip.Error) {
e.mu.RLock()
- // The endpoint cannot be read from if it's not connected.
- if s := e.state; s != stateConnected {
+ // The endpoint can be read if it's connected, or if it's already closed
+ // but has some pending unread data.
+ if s := e.state; s != stateConnected && s != stateClosed {
e.mu.RUnlock()
- switch s {
- case stateClosed:
- return buffer.View{}, tcpip.ErrClosedForReceive
- case stateError:
+ if s == stateError {
return buffer.View{}, e.hardError
- default:
- return buffer.View{}, tcpip.ErrInvalidEndpointState
}
+ return buffer.View{}, tcpip.ErrInvalidEndpointState
}
e.rcvListMu.Lock()
@@ -320,9 +315,9 @@
return v, err
}
-func (e *endpoint) readLocked() (buffer.View, error) {
+func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
- if e.rcvClosed {
+ if e.rcvClosed || e.state != stateConnected {
return buffer.View{}, tcpip.ErrClosedForReceive
}
return buffer.View{}, tcpip.ErrWouldBlock
@@ -349,10 +344,9 @@
}
// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(v buffer.View, to *tcpip.FullAddress) (uintptr, error) {
- if to != nil {
- return 0, tcpip.ErrAlreadyConnected
- }
+func (e *endpoint) Write(v buffer.View, to *tcpip.FullAddress) (uintptr, *tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag).
e.mu.RLock()
defer e.mu.RUnlock()
@@ -363,7 +357,7 @@
case stateError:
return 0, e.hardError
default:
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, tcpip.ErrClosedForSend
}
}
@@ -412,44 +406,57 @@
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
-func (e *endpoint) Peek(w io.Writer) (uintptr, error) {
+func (e *endpoint) Peek(vec [][]byte) (uintptr, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- // The endpoint cannot be read from if it's not connected.
- if e.state != stateConnected {
- switch e.state {
- case stateClosed:
- return 0, tcpip.ErrClosedForReceive
- case stateError:
+ // The endpoint can be read if it's connected, or if it's already closed
+ // but has some pending unread data.
+ if s := e.state; s != stateConnected && s != stateClosed {
+ if s == stateError {
return 0, e.hardError
- default:
- return 0, tcpip.ErrInvalidEndpointState
}
+ return 0, tcpip.ErrInvalidEndpointState
}
e.rcvListMu.Lock()
defer e.rcvListMu.Unlock()
if e.rcvBufUsed == 0 {
- if e.rcvClosed {
+ if e.rcvClosed || e.state != stateConnected {
return 0, tcpip.ErrClosedForReceive
}
return 0, tcpip.ErrWouldBlock
}
+ // Make a copy of vec so we can modify the slice headers.
+ vec = append([][]byte(nil), vec...)
+
var num uintptr
for s := e.rcvList.Front(); s != nil; s = s.Next() {
views := s.data.Views()
+
for i := s.viewToDeliver; i < len(views); i++ {
- n, err := w.Write(views[i])
- num += uintptr(n)
- if err != nil {
- return num, err
+ v := views[i]
+
+ for len(v) > 0 {
+ if len(vec) == 0 {
+ return num, nil
+ }
+ if len(vec[0]) == 0 {
+ vec = vec[1:]
+ continue
+ }
+
+ n := copy(vec[0], v)
+ v = v[n:]
+ vec[0] = vec[0][n:]
+ num += uintptr(n)
}
}
}
+
return num, nil
}
@@ -466,7 +473,7 @@
}
// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) error {
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
switch v := opt.(type) {
case tcpip.NoDelayOption:
e.mu.Lock()
@@ -533,7 +540,7 @@
}
// readyReceiveSize returns the number of bytes ready to be received.
-func (e *endpoint) readyReceiveSize() (int, error) {
+func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -549,7 +556,7 @@
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) error {
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
e.lastErrorMu.Lock()
@@ -604,7 +611,7 @@
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
+ return tcpip.ErrUnknownProtocolOption
}
e.mu.Lock()
@@ -618,10 +625,10 @@
return nil
}
- return tcpip.ErrInvalidEndpointState
+ return tcpip.ErrUnknownProtocolOption
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, error) {
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
netProto := e.netProto
if header.IsV4MappedAddress(addr.Addr) {
// Fail if using a v4 mapped address on a v6only endpoint.
@@ -646,7 +653,7 @@
}
// Connect connects the endpoint to its peer.
-func (e *endpoint) Connect(addr tcpip.FullAddress) error {
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -708,13 +715,13 @@
} else {
// The endpoint doesn't have a local port yet, so try to get
// one.
- _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, error) {
+ _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
e.id.LocalPort = p
err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e)
switch err {
case nil:
return true, nil
- case tcpip.ErrDuplicateAddress:
+ case tcpip.ErrPortInUse:
return false, nil
default:
return false, err
@@ -746,13 +753,13 @@
}
// ConnectEndpoint is not supported.
-func (*endpoint) ConnectEndpoint(tcpip.Endpoint) error {
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
// Shutdown closes the read and/or write end of the endpoint connection to its
// peer.
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) error {
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -797,7 +804,7 @@
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
-func (e *endpoint) Listen(backlog int) error {
+func (e *endpoint) Listen(backlog int) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -850,7 +857,7 @@
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode.
-func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, error) {
+func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -875,7 +882,7 @@
}
// Bind binds the endpoint to a specific local port and optionally address.
-func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() error) (retErr error) {
+func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (retErr *tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
@@ -951,7 +958,7 @@
}
// GetLocalAddress returns the address to which the endpoint is bound.
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, error) {
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -963,7 +970,7 @@
}
// GetRemoteAddress returns the address to which the endpoint is connected.
-func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, error) {
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
diff --git a/tcpip/transport/tcp/forwarder.go b/tcpip/transport/tcp/forwarder.go
index 2490a85..6bd184e 100644
--- a/tcpip/transport/tcp/forwarder.go
+++ b/tcpip/transport/tcp/forwarder.go
@@ -137,7 +137,7 @@
// CreateEndpoint creates a TCP endpoint for the connection request, performing
// the 3-way handshake in the process.
-func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, error) {
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
r.mu.Lock()
defer r.mu.Unlock()
diff --git a/tcpip/transport/tcp/protocol.go b/tcpip/transport/tcp/protocol.go
index 9efc86a..7e8dcf9 100644
--- a/tcpip/transport/tcp/protocol.go
+++ b/tcpip/transport/tcp/protocol.go
@@ -35,7 +35,7 @@
}
// NewEndpoint creates a new tcp endpoint.
-func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, error) {
+func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, waiterQueue), nil
}
@@ -46,7 +46,7 @@
// ParsePorts returns the source and destination ports stored in the given tcp
// packet.
-func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err error) {
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
h := header.TCP(v)
return h.SourcePort(), h.DestinationPort(), nil
}
diff --git a/tcpip/transport/tcp/snd.go b/tcpip/transport/tcp/snd.go
index 5497824..0e0bbc0 100644
--- a/tcpip/transport/tcp/snd.go
+++ b/tcpip/transport/tcp/snd.go
@@ -9,6 +9,7 @@
"time"
"github.com/google/netstack/sleep"
+ "github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/seqnum"
@@ -479,9 +480,12 @@
ackLeft := acked
originalOutsanding := s.outstanding
- for s.writeList.Front() != nil {
+ for ackLeft > 0 {
+ // We use logicalLen here because we can have FIN
+ // segments (which are always at the end of list) that
+ // have no data, but do consume a sequence number.
seg := s.writeList.Front()
- datalen := seqnum.Size(seg.data.Size())
+ datalen := seg.logicalLen()
if datalen > ackLeft {
seg.data.TrimFront(int(ackLeft))
@@ -527,7 +531,7 @@
// sendSegment sends a new segment containing the given payload, flags and
// sequence number.
-func (s *sender) sendSegment(data *buffer.VectorisedView, flags byte, seq seqnum.Value) error {
+func (s *sender) sendSegment(data *buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
s.lastSendTime = time.Now()
if seq == s.rttMeasureSeqNum {
s.rttMeasureTime = s.lastSendTime
diff --git a/tcpip/transport/tcp/tcp_test.go b/tcpip/transport/tcp/tcp_test.go
index 29f6f80..86782bf 100644
--- a/tcpip/transport/tcp/tcp_test.go
+++ b/tcpip/transport/tcp/tcp_test.go
@@ -211,7 +211,7 @@
func (c *testContext) createConnectedWithOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
// Create TCP endpoint.
- var err error
+ var err *tcpip.Error
c.ep, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
@@ -1401,9 +1401,9 @@
defer c.cleanup()
s := c.s.(*stack.Stack)
- ch := make(chan error, 1)
+ ch := make(chan *tcpip.Error, 1)
f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) {
- var err error
+ var err *tcpip.Error
c.ep, err = r.CreateEndpoint(&c.wq)
ch <- err
})
@@ -1432,7 +1432,7 @@
defer c.cleanup()
// Create TCP endpoint.
- var err error
+ var err *tcpip.Error
c.ep, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
@@ -1774,6 +1774,89 @@
)
}
+func TestFinWithPendingDataCwndFull(t *testing.T) {
+ c := newTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createConnected(789, 30000, nil)
+
+ // Write something out but don't ACK it yet.
+ view := buffer.NewView(10)
+ if _, err := c.ep.Write(view, nil); err != nil {
+ t.Fatalf("Unexpected error from Write: %v", err)
+ }
+
+ next := uint32(c.irs) + 1
+ checker.IPv4(c.t, c.getPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(testPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ next += uint32(len(view))
+
+ // Shutdown the connection, check that the FIN segment isn't sent
+ // because the congestion window doesn't allow it. Wait until a
+ // retransmit is received.
+ if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Unexpected error from Shutdown: %v", err)
+ }
+
+ checker.IPv4(c.t, c.getPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(testPort),
+ checker.SeqNum(next-uint32(len(view))),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ // Send the ACK that will allow the FIN to be sent as well.
+ c.sendPacket(nil, &headers{
+ srcPort: testPort,
+ dstPort: c.port,
+ flags: header.TCPFlagAck,
+ seqNum: 790,
+ ackNum: seqnum.Value(next),
+ rcvWnd: 30000,
+ })
+
+ checker.IPv4(c.t, c.getPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(testPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+ next++
+
+ // Send a FIN that acknowledges everything. Get an ACK back.
+ c.sendPacket(nil, &headers{
+ srcPort: testPort,
+ dstPort: c.port,
+ flags: header.TCPFlagAck | header.TCPFlagFin,
+ seqNum: 790,
+ ackNum: seqnum.Value(next),
+ rcvWnd: 30000,
+ })
+
+ checker.IPv4(c.t, c.getPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(testPort),
+ checker.SeqNum(next),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
func TestFinWithPendingData(t *testing.T) {
c := newTestContext(t, defaultMTU)
defer c.cleanup()
@@ -2383,3 +2466,100 @@
}
}
}
+
+func TestReadAfterClosedState(t *testing.T) {
+ // This test ensures that calling Read() or Peek() after the endpoint
+ // has transitioned to stateClosed still works if there is pending
+ // data. To transition to stateClosed without calling Close(), we must
+ // shut down the send path and the peer must send its own FIN.
+ c := newTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createConnected(789, 30000, nil)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.wq.EventRegister(&we, waiter.EventIn)
+ defer c.wq.EventUnregister(&we)
+
+ if _, err := c.ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("Unexpected error from Read: %v", err)
+ }
+
+ // Shutdown immediately for write, check that we get a FIN.
+ if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Unexpected error from Shutdown: %v", err)
+ }
+
+ checker.IPv4(c.t, c.getPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(testPort),
+ checker.SeqNum(uint32(c.irs)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ // Send some data and acknowledge the FIN.
+ data := []byte{1, 2, 3}
+ c.sendPacket(data, &headers{
+ srcPort: testPort,
+ dstPort: c.port,
+ flags: header.TCPFlagAck | header.TCPFlagFin,
+ seqNum: 790,
+ ackNum: c.irs.Add(2),
+ rcvWnd: 30000,
+ })
+
+ // Check that ACK is received.
+ checker.IPv4(c.t, c.getPacket(),
+ checker.TCP(
+ checker.DstPort(testPort),
+ checker.SeqNum(uint32(c.irs)+2),
+ checker.AckNum(uint32(791+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Give the stack the chance to transition to closed state.
+ time.Sleep(1 * time.Second)
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Check that peek works.
+ peekBuf := make([]byte, 10)
+ n, err := c.ep.Peek([][]byte{peekBuf})
+ if err != nil {
+ t.Fatalf("Unexpected error from Peek: %v", err)
+ }
+
+ peekBuf = peekBuf[:n]
+ if bytes.Compare(data, peekBuf) != 0 {
+ t.Fatalf("Data is different: expected %v, got %v", data, peekBuf)
+ }
+
+ // Receive data.
+ v, err := c.ep.Read(nil)
+ if err != nil {
+ t.Fatalf("Unexpected error from Read: %v", err)
+ }
+
+ if bytes.Compare(data, v) != 0 {
+ t.Fatalf("Data is different: expected %v, got %v", data, v)
+ }
+
+ // Now that we drained the queue, check that functions fail with the
+ // right error code.
+ if _, err := c.ep.Read(nil); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("Unexpected return from Read: got %v, want %v", err, tcpip.ErrClosedForReceive)
+ }
+
+ if _, err := c.ep.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("Unexpected return from Peek: got %v, want %v", err, tcpip.ErrClosedForReceive)
+ }
+}
diff --git a/tcpip/transport/udp/endpoint.go b/tcpip/transport/udp/endpoint.go
index 20fbe97..8d74fc0 100644
--- a/tcpip/transport/udp/endpoint.go
+++ b/tcpip/transport/udp/endpoint.go
@@ -5,8 +5,6 @@
package udp
import (
- "errors"
- "io"
"sync"
"github.com/google/netstack/tcpip"
@@ -34,8 +32,6 @@
stateClosed
)
-var errRetryPrepare = errors.New("prepare operation must be retried")
-
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -97,7 +93,7 @@
// NewConnectedEndpoint creates a new endpoint in the connected state using the
// provided route.
-func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.TransportEndpointID, waiterQueue *waiter.Queue) (tcpip.Endpoint, error) {
+func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.TransportEndpointID, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(stack, r.NetProto, waiterQueue)
// Register new endpoint so that packets are routed to it.
@@ -153,7 +149,7 @@
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, error) {
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, *tcpip.Error) {
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -182,20 +178,20 @@
// binds it if it's still in the initial state. To do so, it must first
// reacquire the mutex in exclusive mode.
//
-// Returns errRetryPrepare if preparation should be retried.
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) error {
+// Returns true for retry if preparation should be retried.
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
switch e.state {
case stateInitial:
case stateConnected:
- return nil
+ return false, nil
case stateBound:
if to == nil {
- return tcpip.ErrDestinationRequired
+ return false, tcpip.ErrDestinationRequired
}
- return nil
+ return false, nil
default:
- return tcpip.ErrInvalidEndpointState
+ return false, tcpip.ErrInvalidEndpointState
}
e.mu.RUnlock()
@@ -207,32 +203,32 @@
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
if e.state != stateInitial {
- return errRetryPrepare
+ return true, nil
}
// The state is still 'initial', so try to bind the endpoint.
if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil {
- return err
+ return false, err
}
- return errRetryPrepare
+ return true, nil
}
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(v buffer.View, to *tcpip.FullAddress) (uintptr, error) {
+func (e *endpoint) Write(v buffer.View, to *tcpip.FullAddress) (uintptr, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
// Prepare for write.
for {
- err := e.prepareForWrite(to)
- if err == nil {
- break
+ retry, err := e.prepareForWrite(to)
+ if err != nil {
+ return 0, err
}
- if err != errRetryPrepare {
- return 0, err
+ if !retry {
+ break
}
}
@@ -277,12 +273,12 @@
}
// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek(io.Writer) (uintptr, error) {
+func (e *endpoint) Peek([][]byte) (uintptr, *tcpip.Error) {
return 0, nil
}
// SetSockOpt sets a socket option. Currently not supported.
-func (e *endpoint) SetSockOpt(opt interface{}) error {
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// TODO: Actually implement this.
switch v := opt.(type) {
case tcpip.V6OnlyOption:
@@ -330,7 +326,7 @@
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) error {
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
return nil
@@ -350,7 +346,7 @@
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
+ return tcpip.ErrUnknownProtocolOption
}
e.mu.Lock()
@@ -381,12 +377,12 @@
return nil
}
- return tcpip.ErrInvalidEndpointState
+ return tcpip.ErrUnknownProtocolOption
}
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.View, localPort, remotePort uint16) error {
+func sendUDP(r *stack.Route, data buffer.View, localPort, remotePort uint16) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -411,7 +407,7 @@
return r.WritePacket(&hdr, data, ProtocolNumber)
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, error) {
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
netProto := e.netProto
if header.IsV4MappedAddress(addr.Addr) {
// Fail if using a v4 mapped address on a v6only endpoint.
@@ -436,7 +432,7 @@
}
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
-func (e *endpoint) Connect(addr tcpip.FullAddress) error {
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState
@@ -520,13 +516,13 @@
}
// ConnectEndpoint is not supported.
-func (*endpoint) ConnectEndpoint(tcpip.Endpoint) error {
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) error {
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -551,16 +547,16 @@
}
// Listen is not supported by UDP, it just fails.
-func (*endpoint) Listen(int) error {
+func (*endpoint) Listen(int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept is not supported by UDP, it just fails.
-func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, error) {
+func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
-func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, error) {
+func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
@@ -569,13 +565,13 @@
}
// We need to find a port for the endpoint.
- _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, error) {
+ _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
switch err {
case nil:
return true, nil
- case tcpip.ErrDuplicateAddress:
+ case tcpip.ErrPortInUse:
return false, nil
default:
return false, err
@@ -585,7 +581,7 @@
return id, err
}
-func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() error) error {
+func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
if e.state != stateInitial {
@@ -647,7 +643,7 @@
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
-func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() error) error {
+func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -663,7 +659,7 @@
}
// GetLocalAddress returns the address to which the endpoint is bound.
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, error) {
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -675,12 +671,12 @@
}
// GetRemoteAddress returns the address to which the endpoint is connected.
-func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, error) {
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
if e.state != stateConnected {
- return tcpip.FullAddress{}, tcpip.ErrInvalidEndpointState
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
return tcpip.FullAddress{
diff --git a/tcpip/transport/udp/protocol.go b/tcpip/transport/udp/protocol.go
index 87e56e9..ee2715c 100644
--- a/tcpip/transport/udp/protocol.go
+++ b/tcpip/transport/udp/protocol.go
@@ -34,7 +34,7 @@
}
// NewEndpoint creates a new udp endpoint.
-func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, error) {
+func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, waiterQueue), nil
}
@@ -45,7 +45,7 @@
// ParsePorts returns the source and destination ports stored in the given udp
// packet.
-func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err error) {
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
h := header.UDP(v)
return h.SourcePort(), h.DestinationPort(), nil
}
diff --git a/tcpip/transport/udp/udp_test.go b/tcpip/transport/udp/udp_test.go
index 6b405b9..19069c1 100644
--- a/tcpip/transport/udp/udp_test.go
+++ b/tcpip/transport/udp/udp_test.go
@@ -103,7 +103,7 @@
}
func (c *testContext) createV6Endpoint(v4only bool) {
- var err error
+ var err *tcpip.Error
c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
@@ -386,7 +386,7 @@
defer c.cleanup()
// Create v4 UDP endpoint.
- var err error
+ var err *tcpip.Error
c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
diff --git a/tcpip/transport/unix/connectioned.go b/tcpip/transport/unix/connectioned.go
index b4316b5..34dd051 100644
--- a/tcpip/transport/unix/connectioned.go
+++ b/tcpip/transport/unix/connectioned.go
@@ -20,7 +20,7 @@
}
// A ConnectingEndpoint is a connectioned unix endpoint that is attempting to
-// connect to a ConnectionedEndpoint.
+// establish a bidirectional connection with a BoundEndpoint.
type ConnectingEndpoint interface {
// ID returns the endpoint's globally unique identifier. This identifier
// must be used to determine locking order if more than one endpoint is
@@ -37,7 +37,7 @@
Type() SockType
// GetLocalAddress returns the bound path.
- GetLocalAddress() (tcpip.FullAddress, error)
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
// Locker protects the following methods. While locked, only the holder of
// the lock can change the return value of the protected methods.
@@ -212,11 +212,16 @@
}
// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
-func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) error {
+func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *tcpip.Error {
if ce.Type() != e.stype {
return tcpip.ErrConnectionRefused
}
+ // Check if ce is e to avoid a deadlock.
+ if ce, ok := ce.(*connectionedEndpoint); ok && ce == e {
+ return tcpip.ErrInvalidEndpointState
+ }
+
// Do a dance to safely acquire locks on both endpoints.
if e.id < ce.ID() {
e.Lock()
@@ -299,13 +304,13 @@
}
// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
-func (e *connectionedEndpoint) UnidirectionalConnect() (ConnectedEndpoint, error) {
+func (e *connectionedEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *tcpip.Error) {
return nil, tcpip.ErrConnectionRefused
}
// Connect attempts to directly connect to another Endpoint.
// Implements Endpoint.Connect.
-func (e *connectionedEndpoint) Connect(server BoundEndpoint) error {
+func (e *connectionedEndpoint) Connect(server BoundEndpoint) *tcpip.Error {
returnConnect := func(r Receiver, ce ConnectedEndpoint) {
e.receiver = r
e.connected = ce
@@ -315,7 +320,7 @@
}
// Listen starts listening on the connection.
-func (e *connectionedEndpoint) Listen(backlog int) error {
+func (e *connectionedEndpoint) Listen(backlog int) *tcpip.Error {
e.Lock()
defer e.Unlock()
if e.Listening() {
@@ -342,7 +347,7 @@
}
// Accept accepts a new connection.
-func (e *connectionedEndpoint) Accept() (Endpoint, error) {
+func (e *connectionedEndpoint) Accept() (Endpoint, *tcpip.Error) {
e.Lock()
defer e.Unlock()
@@ -368,12 +373,9 @@
//
// Bind will fail only if the socket is connected, bound or the passed address
// is invalid (the empty string).
-func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() error) error {
+func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
e.Lock()
defer e.Unlock()
- if e.Connected() {
- return tcpip.ErrAlreadyConnected
- }
if e.isBound() || e.Listening() {
return tcpip.ErrAlreadyBound
}
diff --git a/tcpip/transport/unix/connectionless.go b/tcpip/transport/unix/connectionless.go
index de1dead..1b37fed 100644
--- a/tcpip/transport/unix/connectionless.go
+++ b/tcpip/transport/unix/connectionless.go
@@ -59,12 +59,12 @@
}
// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
-func (e *connectionlessEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) error {
+func (e *connectionlessEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *tcpip.Error {
return tcpip.ErrConnectionRefused
}
// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
-func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, error) {
+func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *tcpip.Error) {
return &connectedEndpoint{
endpoint: e,
writeQueue: e.receiver.(*queueReceiver).readQueue,
@@ -73,7 +73,7 @@
// SendMsg writes data and a control message to the specified endpoint.
// This method does not block if the data cannot be written.
-func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, error) {
+func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *tcpip.Error) {
if to == nil {
return e.baseEndpoint.SendMsg(data, c, nil)
}
@@ -103,7 +103,7 @@
}
// Connect attempts to connect directly to server.
-func (e *connectionlessEndpoint) Connect(server BoundEndpoint) error {
+func (e *connectionlessEndpoint) Connect(server BoundEndpoint) *tcpip.Error {
connected, err := server.UnidirectionalConnect()
if err != nil {
return err
@@ -117,12 +117,12 @@
}
// Listen starts listening on the connection.
-func (e *connectionlessEndpoint) Listen(int) error {
+func (e *connectionlessEndpoint) Listen(int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept accepts a new connection.
-func (e *connectionlessEndpoint) Accept() (Endpoint, error) {
+func (e *connectionlessEndpoint) Accept() (Endpoint, *tcpip.Error) {
return nil, tcpip.ErrNotSupported
}
@@ -134,7 +134,7 @@
//
// Bind will fail only if the socket is connected, bound or the passed address
// is invalid (the empty string).
-func (e *connectionlessEndpoint) Bind(addr tcpip.FullAddress, commit func() error) error {
+func (e *connectionlessEndpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
e.Lock()
defer e.Unlock()
if e.isBound() {
diff --git a/tcpip/transport/unix/unix.go b/tcpip/transport/unix/unix.go
index 97b949a..cdaffdb 100644
--- a/tcpip/transport/unix/unix.go
+++ b/tcpip/transport/unix/unix.go
@@ -68,13 +68,13 @@
// If peek is true, no data should be consumed from the Endpoint. Any and
// all data returned from a peek should be available in the next call to
// RecvMsg.
- RecvMsg(data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, ControlMessages, error)
+ RecvMsg(data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, ControlMessages, *tcpip.Error)
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
//
// SendMsg does not take ownership of any of its arguments on error.
- SendMsg([][]byte, ControlMessages, BoundEndpoint) (uintptr, error)
+ SendMsg([][]byte, ControlMessages, BoundEndpoint) (uintptr, *tcpip.Error)
// Connect connects this endpoint directly to another.
//
@@ -82,22 +82,22 @@
// endpoint passed in as a parameter.
//
// The error codes are the same as Connect.
- Connect(server BoundEndpoint) error
+ Connect(server BoundEndpoint) *tcpip.Error
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
- Shutdown(flags tcpip.ShutdownFlags) error
+ Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
- Listen(backlog int) error
+ Listen(backlog int) *tcpip.Error
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode. This method does not
// block if no new connections are available.
//
// The returned Queue is the wait queue for the newly created endpoint.
- Accept() (Endpoint, error)
+ Accept() (Endpoint, *tcpip.Error)
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
@@ -105,26 +105,26 @@
// An optional commit function will be executed atomically with respect
// to binding the endpoint. If this returns an error, the bind will not
// occur and the error will be propagated back to the caller.
- Bind(address tcpip.FullAddress, commit func() error) error
+ Bind(address tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error
// Type return the socket type, typically either SockStream, SockDgram
// or SockSeqpacket.
Type() SockType
// GetLocalAddress returns the address to which the endpoint is bound.
- GetLocalAddress() (tcpip.FullAddress, error)
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
// GetRemoteAddress returns the address to which the endpoint is
// connected.
- GetRemoteAddress() (tcpip.FullAddress, error)
+ GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error)
// SetSockOpt sets a socket option. opt should be one of the tcpip.*Option
// types.
- SetSockOpt(opt interface{}) error
+ SetSockOpt(opt interface{}) *tcpip.Error
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
- GetSockOpt(opt interface{}) error
+ GetSockOpt(opt interface{}) *tcpip.Error
}
// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket
@@ -159,13 +159,13 @@
//
// This method will return tcpip.ErrConnectionRefused on endpoints with a
// type that isn't SockStream or SockSeqpacket.
- BidirectionalConnect(ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) error
+ BidirectionalConnect(ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *tcpip.Error
// UnidirectionalConnect establishes a write-only connection to a unix endpoint.
//
// This method will return tcpip.ErrConnectionRefused on a non-SockDgram
// endpoint.
- UnidirectionalConnect() (ConnectedEndpoint, error)
+ UnidirectionalConnect() (ConnectedEndpoint, *tcpip.Error)
// Release releases any resources held by the BoundEndpoint. It must be
// called before dropping all references to a BoundEndpoint returned by a
@@ -216,7 +216,7 @@
// Recv receives a single message. This method does not block.
//
// notify indicates if RecvNotify should be called.
- Recv(data [][]byte, numRights uintptr, peek bool) (n uintptr, cm ControlMessages, source tcpip.FullAddress, notify bool, err error)
+ Recv(data [][]byte, numRights uintptr, peek bool) (n uintptr, cm ControlMessages, source tcpip.FullAddress, notify bool, err *tcpip.Error)
// RecvNotify notifies the Receiver of a successful Recv. This must not be
// called while holding any endpoint locks.
@@ -235,9 +235,13 @@
// includes when read has been shutdown.
Readable() bool
- // QueuedSize returns the total amount of data currently receivable.
- // QueuedSize should return -1 if the operation isn't supported.
- QueuedSize() int64
+ // RecvQueuedSize returns the total amount of data currently receivable.
+ // RecvQueuedSize should return -1 if the operation isn't supported.
+ RecvQueuedSize() int64
+
+ // RecvMaxQueueSize returns the maximum value for RecvQueuedSize.
+ // RecvMaxQueueSize should return -1 if the operation isn't supported.
+ RecvMaxQueueSize() int64
// Release releases any resources owned by the Receiver. It should be
// called before droping all references to a Receiver.
@@ -250,10 +254,10 @@
}
// Recv implements Receiver.Recv.
-func (q *queueReceiver) Recv(data [][]byte, numRights uintptr, peek bool) (uintptr, ControlMessages, tcpip.FullAddress, bool, error) {
+func (q *queueReceiver) Recv(data [][]byte, numRights uintptr, peek bool) (uintptr, ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) {
var m queue.Entry
var notify bool
- var err error
+ var err *tcpip.Error
if peek {
m, err = q.readQueue.Peek()
} else {
@@ -294,11 +298,16 @@
return q.readQueue.IsReadable()
}
-// QueuedSize implements Receiver.QueuedSize.
-func (q *queueReceiver) QueuedSize() int64 {
+// RecvQueuedSize implements Receiver.RecvQueuedSize.
+func (q *queueReceiver) RecvQueuedSize() int64 {
return q.readQueue.QueuedSize()
}
+// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
+func (q *queueReceiver) RecvMaxQueueSize() int64 {
+ return q.readQueue.MaxQueueSize()
+}
+
// Release implements Receiver.Release.
func (*queueReceiver) Release() {}
@@ -312,8 +321,22 @@
addr tcpip.FullAddress
}
+func vecCopy(data [][]byte, buf []byte) (uintptr, [][]byte, []byte) {
+ var copied uintptr
+ for len(data) > 0 && len(buf) > 0 {
+ n := copy(data[0], buf)
+ copied += uintptr(n)
+ buf = buf[n:]
+ data[0] = data[0][n:]
+ if len(data[0]) == 0 {
+ data = data[1:]
+ }
+ }
+ return copied, data, buf
+}
+
// Recv implements Receiver.Recv.
-func (q *streamQueueReceiver) Recv(data [][]byte, numRights uintptr, peek bool) (uintptr, ControlMessages, tcpip.FullAddress, bool, error) {
+func (q *streamQueueReceiver) Recv(data [][]byte, numRights uintptr, peek bool) (uintptr, ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) {
q.mu.Lock()
defer q.mu.Unlock()
@@ -333,26 +356,57 @@
q.control = msg.Control
q.addr = msg.Address
}
- buf := q.buffer
+
var copied uintptr
- for _, d := range data {
- if len(buf) == 0 {
+ if peek {
+ var c ControlMessages
+ if q.control != nil {
+ // Don't consume control message if we are peeking.
+ c = q.control.Clone()
+ }
+
+ // Don't consume data since we are peeking.
+ copied, data, _ = vecCopy(data, q.buffer)
+
+ return copied, c, q.addr, notify, nil
+ }
+
+ // Consume data and control message since we are not peeking.
+ copied, data, q.buffer = vecCopy(data, q.buffer)
+ c := q.control
+ if c != nil {
+ q.control = c.CloneCreds()
+ }
+
+ if c != nil {
+ // FIXME: We don't support coalescing messages
+ // containing control messages.
+ return copied, c, q.addr, notify, nil
+ }
+
+ for len(data) > 0 {
+ m, n, err := q.readQueue.Dequeue()
+ if err != nil {
+ // We already got some data, so ignore this error. This will
+ // manifest as a short read to the user, which is what Linux
+ // does.
break
}
- n := copy(d, buf)
- copied += uintptr(n)
- buf = buf[n:]
- }
- c := q.control
- if !peek {
- // Consume data and control message if we are not peeking.
- if c != nil {
- q.control = c.CloneCreds()
+ notify = notify || n
+ msg := m.(*message)
+ q.buffer = []byte(msg.Data)
+ q.control = msg.Control
+ q.addr = msg.Address
+
+ if q.control != nil {
+ // FIXME: We don't support coalescing messages
+ // containing control messages.
+ break
}
- q.buffer = buf
- } else if q.control != nil {
- // Don't consume control message if we are peeking.
- c = c.Clone()
+
+ var cpd uintptr
+ cpd, data, q.buffer = vecCopy(data, q.buffer)
+ copied += cpd
}
return copied, c, q.addr, notify, nil
}
@@ -363,12 +417,12 @@
Passcred() bool
// GetLocalAddress implements Endpoint.GetLocalAddress.
- GetLocalAddress() (tcpip.FullAddress, error)
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
// Send sends a single message. This method does not block.
//
// notify indicates if SendNotify should be called.
- Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n uintptr, notify bool, err error)
+ Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n uintptr, notify bool, err *tcpip.Error)
// SendNotify notifies the ConnectedEndpoint of a successful Send. This
// must not be called while holding any endpoint locks.
@@ -391,6 +445,15 @@
// have changed.
EventUpdate()
+ // SendQueuedSize returns the total amount of data currently queued for
+ // sending. SendQueuedSize should return -1 if the operation isn't
+ // supported.
+ SendQueuedSize() int64
+
+ // SendMaxQueueSize returns the maximum value for SendQueuedSize.
+ // SendMaxQueueSize should return -1 if the operation isn't supported.
+ SendMaxQueueSize() int64
+
// Release releases any resources owned by the ConnectedEndpoint. It should
// be called before droping all references to a ConnectedEndpoint.
Release()
@@ -406,7 +469,7 @@
Passcred() bool
// GetLocalAddress implements Endpoint.GetLocalAddress.
- GetLocalAddress() (tcpip.FullAddress, error)
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
// Type implements Endpoint.Type.
Type() SockType
@@ -421,12 +484,12 @@
}
// GetLocalAddress implements ConnectedEndpoint.GetLocalAddress.
-func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, error) {
+func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return e.endpoint.GetLocalAddress()
}
// Send implements ConnectedEndpoint.Send.
-func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (uintptr, bool, error) {
+func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) {
var l int
for _, d := range data {
l += len(d)
@@ -473,6 +536,16 @@
// EventUpdate implements ConnectedEndpoint.EventUpdate.
func (*connectedEndpoint) EventUpdate() {}
+// SendQueuedSize implements ConnectedEndpoint.SendQueuedSize.
+func (e *connectedEndpoint) SendQueuedSize() int64 {
+ return e.writeQueue.QueuedSize()
+}
+
+// SendMaxQueueSize implements ConnectedEndpoint.SendMaxQueueSize.
+func (e *connectedEndpoint) SendMaxQueueSize() int64 {
+ return e.writeQueue.MaxQueueSize()
+}
+
// Release implements ConnectedEndpoint.Release.
func (*connectedEndpoint) Release() {}
@@ -548,7 +621,7 @@
}
// RecvMsg reads data and a control message from the endpoint.
-func (e *baseEndpoint) RecvMsg(data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, ControlMessages, error) {
+func (e *baseEndpoint) RecvMsg(data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, ControlMessages, *tcpip.Error) {
e.Lock()
if e.receiver == nil {
@@ -574,7 +647,7 @@
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
-func (e *baseEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, error) {
+func (e *baseEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *tcpip.Error) {
e.Lock()
if !e.Connected() {
e.Unlock()
@@ -599,7 +672,7 @@
}
// SetSockOpt sets a socket option. Currently not supported.
-func (e *baseEndpoint) SetSockOpt(opt interface{}) error {
+func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
switch v := opt.(type) {
case tcpip.PasscredOption:
e.setPasscred(v != 0)
@@ -609,17 +682,30 @@
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *baseEndpoint) GetSockOpt(opt interface{}) error {
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
return nil
+ case *tcpip.SendQueueSizeOption:
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return tcpip.ErrNotConnected
+ }
+ qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize())
+ e.Unlock()
+ if qs < 0 {
+ return tcpip.ErrQueueSizeNotSupported
+ }
+ *o = qs
+ return nil
case *tcpip.ReceiveQueueSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
return tcpip.ErrNotConnected
}
- qs := tcpip.ReceiveQueueSizeOption(e.receiver.QueuedSize())
+ qs := tcpip.ReceiveQueueSizeOption(e.receiver.RecvQueuedSize())
e.Unlock()
if qs < 0 {
return tcpip.ErrQueueSizeNotSupported
@@ -633,13 +719,39 @@
*o = tcpip.PasscredOption(0)
}
return nil
+ case *tcpip.SendBufferSizeOption:
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return tcpip.ErrNotConnected
+ }
+ qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize())
+ e.Unlock()
+ if qs < 0 {
+ return tcpip.ErrQueueSizeNotSupported
+ }
+ *o = qs
+ return nil
+ case *tcpip.ReceiveBufferSizeOption:
+ e.Lock()
+ if e.receiver == nil {
+ e.Unlock()
+ return tcpip.ErrNotConnected
+ }
+ qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize())
+ e.Unlock()
+ if qs < 0 {
+ return tcpip.ErrQueueSizeNotSupported
+ }
+ *o = qs
+ return nil
}
- return tcpip.ErrInvalidEndpointState
+ return tcpip.ErrUnknownProtocolOption
}
// Shutdown closes the read and/or write end of the endpoint connection to its
// peer.
-func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) error {
+func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.Lock()
if !e.Connected() {
e.Unlock()
@@ -668,7 +780,7 @@
}
// GetLocalAddress returns the bound path.
-func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, error) {
+func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
e.Lock()
defer e.Unlock()
return tcpip.FullAddress{Addr: tcpip.Address(e.path)}, nil
@@ -676,7 +788,7 @@
// GetRemoteAddress returns the local address of the connected endpoint (if
// available).
-func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, error) {
+func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.Lock()
c := e.connected
e.Unlock()
diff --git a/waiter/waiter.go b/waiter/waiter.go
index f3c5775..0cc09f2 100644
--- a/waiter/waiter.go
+++ b/waiter/waiter.go
@@ -88,14 +88,8 @@
EventUnregister(e *Entry)
}
-// Entry represents a waiter that can be add to the a wait queue. It can
-// only be in one queue at a time, and is added "intrusively" to the queue with
-// no extra memory allocations.
-type Entry struct {
- // Context stores any state the waiter may wish to store in the entry
- // itself, which may be used at wake up time.
- Context interface{}
-
+// EntryCallback provides a notify callback.
+type EntryCallback interface {
// Callback is the function to be called when the waiter entry is
// notified. It is responsible for doing whatever is needed to wake up
// the waiter.
@@ -103,13 +97,38 @@
// The callback is supposed to perform minimal work, and cannot call
// any method on the queue itself because it will be locked while the
// callback is running.
- Callback func(e *Entry)
+ Callback(e *Entry)
+}
+
+// Entry represents a waiter that can be added to a wait queue. It can
+// only be in one queue at a time, and is added "intrusively" to the queue with
+// no extra memory allocations.
+type Entry struct {
+ // Context stores any state the waiter may wish to store in the entry
+ // itself, which may be used at wake up time.
+ //
+ // Note that use of this field is optional and state may alternatively be
+ // stored in the callback itself.
+ Context interface{}
+
+ Callback EntryCallback
// The following fields are protected by the queue lock.
mask EventMask
ilist.Entry
}
+type channelCallback struct{}
+
+// Callback implements EntryCallback.Callback.
+func (*channelCallback) Callback(e *Entry) {
+ ch := e.Context.(chan struct{})
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+}
+
// NewChannelEntry initializes a new Entry that does a non-blocking write to a
// struct{} channel when the callback is called. It returns the new Entry
// instance and the channel being used.
@@ -122,16 +141,7 @@
c = make(chan struct{}, 1)
}
- return Entry{
- Context: c,
- Callback: func(e *Entry) {
- ch := e.Context.(chan struct{})
- select {
- case ch <- struct{}{}:
- default:
- }
- },
- }, c
+ return Entry{Context: c, Callback: &channelCallback{}}, c
}
// Queue represents the wait queue where waiters can be added and
@@ -166,7 +176,7 @@
for it := q.list.Front(); it != nil; it = it.Next() {
e := it.(*Entry)
if (mask & e.mask) != 0 {
- e.Callback(e)
+ e.Callback.Callback(e)
}
}
q.mu.RUnlock()
diff --git a/waiter/waiter_test.go b/waiter/waiter_test.go
index e392c03..1a20335 100644
--- a/waiter/waiter_test.go
+++ b/waiter/waiter_test.go
@@ -9,6 +9,15 @@
"testing"
)
+type callbackStub struct {
+ f func(e *Entry)
+}
+
+// Callback implements EntryCallback.Callback.
+func (c *callbackStub) Callback(e *Entry) {
+ c.f(e)
+}
+
func TestEmptyQueue(t *testing.T) {
var q Queue
@@ -17,7 +26,7 @@
// Register then unregister a waiter, then notify the queue.
cnt := 0
- e := Entry{Callback: func(*Entry) { cnt++ }}
+ e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}}
q.EventRegister(&e, EventIn)
q.EventUnregister(&e)
q.Notify(EventIn)
@@ -30,7 +39,7 @@
// Register a waiter.
var q Queue
var cnt int
- e := Entry{Callback: func(*Entry) { cnt++ }}
+ e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}}
q.EventRegister(&e, EventIn|EventErr)
// Notify with an overlapping mask.
@@ -82,12 +91,12 @@
for i := 0; i < concurrency; i++ {
go func() {
var e Entry
- e.Callback = func(entry *Entry) {
+ e.Callback = &callbackStub{func(entry *Entry) {
cnt++
if entry != &e {
t.Errorf("entry = %p, want %p", entry, &e)
}
- }
+ }}
// Wait for notification, then register.
<-ch1
@@ -139,12 +148,12 @@
// Register waiters.
for i := 0; i < waiterCount; i++ {
var e Entry
- e.Callback = func(entry *Entry) {
+ e.Callback = &callbackStub{func(entry *Entry) {
atomic.AddInt32(&cnt, 1)
if entry != &e {
t.Errorf("entry = %p, want %p", entry, &e)
}
- }
+ }}
q.EventRegister(&e, EventIn|EventErr)
}