Merge remote-tracking branch 'upstream/master' into HEAD
Change-Id: I83e74e1786dafcab7f189904b43ee0ae919f40c8
diff --git a/tcpip/buffer/view.go b/tcpip/buffer/view.go
index 1a9d407..150310c 100644
--- a/tcpip/buffer/view.go
+++ b/tcpip/buffer/view.go
@@ -50,7 +50,7 @@
return NewVectorisedView(len(v), []View{v})
}
-// VectorisedView is a vectorised version of View using non contigous memory.
+// VectorisedView is a vectorised version of View using non-contiguous memory.
// It supports all the convenience methods supported by View.
//
// +stateify savable
diff --git a/tcpip/link/fdbased/endpoint.go b/tcpip/link/fdbased/endpoint.go
index 252aeec..531de28 100644
--- a/tcpip/link/fdbased/endpoint.go
+++ b/tcpip/link/fdbased/endpoint.go
@@ -21,6 +21,22 @@
// FD based endpoints can be used in the networking stack by calling New() to
// create a new endpoint, and then passing it as an argument to
// Stack.CreateNIC().
+//
+// FD based endpoints can use more than one file descriptor to read incoming
+// packets. If more than one FD is specified and the underlying FD is an
+// AF_PACKET then the endpoint will enable FANOUT mode on the socket so that the
+// host kernel will consistently hash the packets to the sockets. This ensures
+// that packets for the same TCP streams are not reordered.
+//
+// Similarly, if more than one FD is specified and the underlying FD is not
+// AF_PACKET then it's the caller's responsibility to ensure that all inbound
+// packets on the descriptors are consistently 5 tuple hashed to one of the
+// descriptors to prevent TCP reordering.
+//
+// Since netstack today does not compute 5 tuple hashes for outgoing packets we
+// only use the first FD to write outbound packets. Once 5 tuple hashes for
+// all outbound packets are available we will make use of all underlying FDs to
+// write outbound packets.
package fdbased
import (
@@ -32,6 +48,7 @@
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/link/rawfile"
"github.com/google/netstack/tcpip/stack"
+ "golang.org/x/sys/unix"
)
// linkDispatcher reads packets from the link FD and dispatches them to the
@@ -65,8 +82,10 @@
)
type endpoint struct {
- // fd is the file descriptor used to send and receive packets.
- fd int
+ // fds is the set of file descriptors each identifying one inbound/outbound
+	// channel. The endpoint dispatches packets read from all inbound channels,
+	// but currently writes all outbound packets to fds[0] only.
+ fds []int
// mtu (maximum transmission unit) is the maximum size of a packet.
mtu uint32
@@ -85,8 +104,8 @@
// its end of the communication pipe.
closed func(*tcpip.Error)
- inboundDispatcher linkDispatcher
- dispatcher stack.NetworkDispatcher
+ inboundDispatchers []linkDispatcher
+ dispatcher stack.NetworkDispatcher
// packetDispatchMode controls the packet dispatcher used by this
// endpoint.
@@ -99,17 +118,47 @@
// Options specify the details about the fd-based endpoint to be created.
type Options struct {
- FD int
- MTU uint32
- EthernetHeader bool
- ClosedFunc func(*tcpip.Error)
- Address tcpip.LinkAddress
- SaveRestore bool
- DisconnectOk bool
- GSOMaxSize uint32
+ // FDs is a set of FDs used to read/write packets.
+ FDs []int
+
+ // MTU is the mtu to use for this endpoint.
+ MTU uint32
+
+ // EthernetHeader if true, indicates that the endpoint should read/write
+ // ethernet frames instead of IP packets.
+ EthernetHeader bool
+
+ // ClosedFunc is a function to be called when an endpoint's peer (if
+ // any) closes its end of the communication pipe.
+ ClosedFunc func(*tcpip.Error)
+
+ // Address is the link address for this endpoint. Only used if
+ // EthernetHeader is true.
+ Address tcpip.LinkAddress
+
+ // SaveRestore if true, indicates that this NIC capability set should
+	// include CapabilitySaveRestore.
+ SaveRestore bool
+
+ // DisconnectOk if true, indicates that this NIC capability set should
+ // include CapabilityDisconnectOk.
+ DisconnectOk bool
+
+ // GSOMaxSize is the maximum GSO packet size. It is zero if GSO is
+ // disabled.
+ GSOMaxSize uint32
+
+ // PacketDispatchMode specifies the type of inbound dispatcher to be
+ // used for this endpoint.
PacketDispatchMode PacketDispatchMode
- TXChecksumOffload bool
- RXChecksumOffload bool
+
+	// TXChecksumOffload if true, indicates that this endpoint's capability
+ // set should include CapabilityTXChecksumOffload.
+ TXChecksumOffload bool
+
+	// RXChecksumOffload if true, indicates that this endpoint's capability
+ // set should include CapabilityRXChecksumOffload.
+ RXChecksumOffload bool
}
// New creates a new fd-based endpoint.
@@ -117,10 +166,6 @@
// Makes fd non-blocking, but does not take ownership of fd, which must remain
// open for the lifetime of the returned endpoint.
func New(opts *Options) (tcpip.LinkEndpointID, error) {
- if err := syscall.SetNonblock(opts.FD, true); err != nil {
- return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", opts.FD, err)
- }
-
caps := stack.LinkEndpointCapabilities(0)
if opts.RXChecksumOffload {
caps |= stack.CapabilityRXChecksumOffload
@@ -144,8 +189,12 @@
caps |= stack.CapabilityDisconnectOk
}
+ if len(opts.FDs) == 0 {
+ return 0, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
+ }
+
e := &endpoint{
- fd: opts.FD,
+ fds: opts.FDs,
mtu: opts.MTU,
caps: caps,
closed: opts.ClosedFunc,
@@ -154,46 +203,71 @@
packetDispatchMode: opts.PacketDispatchMode,
}
- isSocket, err := isSocketFD(e.fd)
- if err != nil {
- return 0, err
- }
- if isSocket {
- if opts.GSOMaxSize != 0 {
- e.caps |= stack.CapabilityGSO
- e.gsoMaxSize = opts.GSOMaxSize
+ // Create per channel dispatchers.
+ for i := 0; i < len(e.fds); i++ {
+ fd := e.fds[i]
+ if err := syscall.SetNonblock(fd, true); err != nil {
+ return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err)
}
- }
- e.inboundDispatcher, err = createInboundDispatcher(e, isSocket)
- if err != nil {
- return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err)
+
+ isSocket, err := isSocketFD(fd)
+ if err != nil {
+ return 0, err
+ }
+ if isSocket {
+ if opts.GSOMaxSize != 0 {
+ e.caps |= stack.CapabilityGSO
+ e.gsoMaxSize = opts.GSOMaxSize
+ }
+ }
+ inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket)
+ if err != nil {
+ return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err)
+ }
+ e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher)
}
return stack.RegisterLinkEndpoint(e), nil
}
-func createInboundDispatcher(e *endpoint, isSocket bool) (linkDispatcher, error) {
+func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) {
// By default use the readv() dispatcher as it works with all kinds of
// FDs (tap/tun/unix domain sockets and af_packet).
- inboundDispatcher, err := newReadVDispatcher(e.fd, e)
+ inboundDispatcher, err := newReadVDispatcher(fd, e)
if err != nil {
- return nil, fmt.Errorf("newReadVDispatcher(%d, %+v) = %v", e.fd, e, err)
+ return nil, fmt.Errorf("newReadVDispatcher(%d, %+v) = %v", fd, e, err)
}
if isSocket {
+ sa, err := unix.Getsockname(fd)
+ if err != nil {
+ return nil, fmt.Errorf("unix.Getsockname(%d) = %v", fd, err)
+ }
+ switch sa.(type) {
+ case *unix.SockaddrLinklayer:
+			// Enable PACKET_FANOUT mode if the underlying socket is
+ // of type AF_PACKET.
+ const fanoutID = 1
+ const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG
+ fanoutArg := fanoutID | fanoutType<<16
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
+ return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err)
+ }
+ }
+
switch e.packetDispatchMode {
case PacketMMap:
- inboundDispatcher, err = newPacketMMapDispatcher(e.fd, e)
+ inboundDispatcher, err = newPacketMMapDispatcher(fd, e)
if err != nil {
- return nil, fmt.Errorf("newPacketMMapDispatcher(%d, %+v) = %v", e.fd, e, err)
+ return nil, fmt.Errorf("newPacketMMapDispatcher(%d, %+v) = %v", fd, e, err)
}
case RecvMMsg:
// If the provided FD is a socket then we optimize
// packet reads by using recvmmsg() instead of read() to
// read packets in a batch.
- inboundDispatcher, err = newRecvMMsgDispatcher(e.fd, e)
+ inboundDispatcher, err = newRecvMMsgDispatcher(fd, e)
if err != nil {
- return nil, fmt.Errorf("newRecvMMsgDispatcher(%d, %+v) = %v", e.fd, e, err)
+ return nil, fmt.Errorf("newRecvMMsgDispatcher(%d, %+v) = %v", fd, e, err)
}
}
}
@@ -215,7 +289,9 @@
// Link endpoints are not savable. When transportation endpoints are
// saved, they stop sending outgoing packets and all incoming packets
// are rejected.
- go e.dispatchLoop()
+ for i := range e.inboundDispatchers {
+ go e.dispatchLoop(e.inboundDispatchers[i])
+ }
}
// IsAttached implements stack.LinkEndpoint.IsAttached.
@@ -305,26 +381,26 @@
}
}
- return rawfile.NonBlockingWrite3(e.fd, vnetHdrBuf, hdr.View(), payload.ToView())
+ return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, hdr.View(), payload.ToView())
}
if payload.Size() == 0 {
- return rawfile.NonBlockingWrite(e.fd, hdr.View())
+ return rawfile.NonBlockingWrite(e.fds[0], hdr.View())
}
- return rawfile.NonBlockingWrite3(e.fd, hdr.View(), payload.ToView(), nil)
+ return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil)
}
// WriteRawPacket writes a raw packet directly to the file descriptor.
func (e *endpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error {
- return rawfile.NonBlockingWrite(e.fd, packet)
+ return rawfile.NonBlockingWrite(e.fds[0], packet)
}
// dispatchLoop reads packets from the file descriptor in a loop and dispatches
// them to the network stack.
-func (e *endpoint) dispatchLoop() *tcpip.Error {
+func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) *tcpip.Error {
for {
- cont, err := e.inboundDispatcher.dispatch()
+ cont, err := inboundDispatcher.dispatch()
if err != nil || !cont {
if e.closed != nil {
e.closed(err)
@@ -363,7 +439,7 @@
syscall.SetNonblock(fd, true)
e := &InjectableEndpoint{endpoint: endpoint{
- fd: fd,
+ fds: []int{fd},
mtu: mtu,
caps: capabilities,
}}
diff --git a/tcpip/link/fdbased/endpoint_test.go b/tcpip/link/fdbased/endpoint_test.go
index e9ac164..b14edc5 100644
--- a/tcpip/link/fdbased/endpoint_test.go
+++ b/tcpip/link/fdbased/endpoint_test.go
@@ -67,7 +67,7 @@
done <- struct{}{}
}
- opt.FD = fds[1]
+ opt.FDs = []int{fds[1]}
epID, err := New(opt)
if err != nil {
t.Fatalf("Failed to create FD endpoint: %v", err)
diff --git a/tcpip/link/rawfile/errors.go b/tcpip/link/rawfile/errors.go
index 312645a..75a41c3 100644
--- a/tcpip/link/rawfile/errors.go
+++ b/tcpip/link/rawfile/errors.go
@@ -30,7 +30,7 @@
// TranslateErrno translate an errno from the syscall package into a
// *tcpip.Error.
//
-// Valid, but unreconigized errnos will be translated to
+// Valid, but unrecognized errnos will be translated to
// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos.
func TranslateErrno(e syscall.Errno) *tcpip.Error {
if err := translations[e]; err != nil {
diff --git a/tcpip/link/rawfile/rawfile_unsafe.go b/tcpip/link/rawfile/rawfile_unsafe.go
index 41f8562..7f12acb 100644
--- a/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/tcpip/link/rawfile/rawfile_unsafe.go
@@ -110,7 +110,7 @@
// BlockingRead reads from a file descriptor that is set up as non-blocking. If
// no data is available, it will block in a poll() syscall until the file
-// descirptor becomes readable.
+// descriptor becomes readable.
func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
for {
n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
diff --git a/tcpip/link/sharedmem/sharedmem_test.go b/tcpip/link/sharedmem/sharedmem_test.go
index a329ff8..4c90b6a 100644
--- a/tcpip/link/sharedmem/sharedmem_test.go
+++ b/tcpip/link/sharedmem/sharedmem_test.go
@@ -636,7 +636,7 @@
syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
// Wait for packet to be received, then check it.
- c.waitForPackets(1, time.After(time.Second), "Error waiting for packet")
+ c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
c.mu.Lock()
rcvd := []byte(c.packets[0].vv.First())
c.packets = c.packets[:0]
diff --git a/tcpip/link/sniffer/sniffer.go b/tcpip/link/sniffer/sniffer.go
index 59044d3..8e18041 100644
--- a/tcpip/link/sniffer/sniffer.go
+++ b/tcpip/link/sniffer/sniffer.go
@@ -118,7 +118,7 @@
// logs the packet before forwarding to the actual dispatcher.
func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("recv", protocol, vv.First())
+ logPacket("recv", protocol, vv.First(), nil)
}
if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
vs := vv.Views()
@@ -198,7 +198,7 @@
// the request to the lower endpoint.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("send", protocol, hdr.View())
+ logPacket("send", protocol, hdr.View(), gso)
}
if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
hdrBuf := hdr.View()
@@ -240,7 +240,7 @@
return e.lower.WritePacket(r, gso, hdr, payload, protocol)
}
-func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View) {
+func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@@ -404,5 +404,9 @@
return
}
+ if gso != nil {
+ details += fmt.Sprintf(" gso: %+v", gso)
+ }
+
log.Printf("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
}
diff --git a/tcpip/network/fragmentation/fragmentation.go b/tcpip/network/fragmentation/fragmentation.go
index e2873a4..83c4ec0 100644
--- a/tcpip/network/fragmentation/fragmentation.go
+++ b/tcpip/network/fragmentation/fragmentation.go
@@ -60,7 +60,7 @@
// lowMemoryLimit specifies the limit on which we will reach by dropping
// fragments after reaching highMemoryLimit.
//
-// reassemblingTimeout specifes the maximum time allowed to reassemble a packet.
+// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new a packet with an
// already existing fragmentation-id arrives after the timeout.
func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
@@ -80,7 +80,7 @@
}
}
-// Process processes an incoming fragment beloning to an ID
+// Process processes an incoming fragment belonging to an ID
// and returns a complete packet when all the packets belonging to that ID have been received.
func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) {
f.mu.Lock()
diff --git a/tcpip/sample/tun_tcp_connect/main.go b/tcpip/sample/tun_tcp_connect/main.go
index 6cc61dc..ae281c9 100644
--- a/tcpip/sample/tun_tcp_connect/main.go
+++ b/tcpip/sample/tun_tcp_connect/main.go
@@ -137,7 +137,7 @@
log.Fatal(err)
}
- linkID, err := fdbased.New(&fdbased.Options{FD: fd, MTU: mtu})
+ linkID, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
if err != nil {
log.Fatal(err)
}
diff --git a/tcpip/sample/tun_tcp_echo/main.go b/tcpip/sample/tun_tcp_echo/main.go
index 9f0bae8..f5b37bd 100644
--- a/tcpip/sample/tun_tcp_echo/main.go
+++ b/tcpip/sample/tun_tcp_echo/main.go
@@ -129,7 +129,7 @@
}
linkID, err := fdbased.New(&fdbased.Options{
- FD: fd,
+ FDs: []int{fd},
MTU: mtu,
EthernetHeader: *tap,
Address: tcpip.LinkAddress(maddr),
diff --git a/tcpip/stack/stack.go b/tcpip/stack/stack.go
index 8670b12..ab130f8 100644
--- a/tcpip/stack/stack.go
+++ b/tcpip/stack/stack.go
@@ -225,6 +225,45 @@
MaxSACKED seqnum.Value
}
+// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning.
+type RcvBufAutoTuneParams struct {
+ // MeasureTime is the time at which the current measurement
+ // was started.
+ MeasureTime time.Time
+
+ // CopiedBytes is the number of bytes copied to user space since
+ // this measure began.
+ CopiedBytes int
+
+ // PrevCopiedBytes is the number of bytes copied to user space in
+ // the previous RTT period.
+ PrevCopiedBytes int
+
+ // RcvBufSize is the auto tuned receive buffer size.
+ RcvBufSize int
+
+ // RTT is the smoothed RTT as measured by observing the time between
+ // when a byte is first acknowledged and the receipt of data that is at
+ // least one window beyond the sequence number that was acknowledged.
+ RTT time.Duration
+
+ // RTTVar is the "round-trip time variation" as defined in section 2
+ // of RFC6298.
+ RTTVar time.Duration
+
+ // RTTMeasureSeqNumber is the highest acceptable sequence number at the
+ // time this RTT measurement period began.
+ RTTMeasureSeqNumber seqnum.Value
+
+ // RTTMeasureTime is the absolute time at which the current RTT
+ // measurement period began.
+ RTTMeasureTime time.Time
+
+ // Disabled is true if an explicit receive buffer is set for the
+ // endpoint.
+ Disabled bool
+}
+
// TCPEndpointState is a copy of the internal state of a TCP endpoint.
type TCPEndpointState struct {
// ID is a copy of the TransportEndpointID for the endpoint.
@@ -240,6 +279,10 @@
// buffer for the endpoint.
RcvBufUsed int
+	// RcvAutoParams is used to hold state variables to compute
+ // the auto tuned receive buffer size.
+ RcvAutoParams RcvBufAutoTuneParams
+
// RcvClosed if true, indicates the endpoint has been closed for reading.
RcvClosed bool
diff --git a/tcpip/stack/transport_test.go b/tcpip/stack/transport_test.go
index 8cbf9b8..f4ab791 100644
--- a/tcpip/stack/transport_test.go
+++ b/tcpip/stack/transport_test.go
@@ -188,6 +188,13 @@
f.proto.controlCount++
}
+func (f *fakeTransportEndpoint) State() uint32 {
+ return 0
+}
+
+func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {
+}
+
type fakeTransportGoodOption bool
type fakeTransportBadOption bool
diff --git a/tcpip/tcpip.go b/tcpip/tcpip.go
index 3c5d790..09478f4 100644
--- a/tcpip/tcpip.go
+++ b/tcpip/tcpip.go
@@ -377,6 +377,17 @@
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
GetSockOpt(opt interface{}) *Error
+
+ // State returns a socket's lifecycle state. The returned value is
+ // protocol-specific and is primarily used for diagnostics.
+ State() uint32
+
+	// ModerateRecvBuf should be called every time data is copied to the user
+ // space. This allows for dynamic tuning of recv buffer space for a
+ // given socket.
+ //
+ // NOTE: This method is a no-op for sockets other than TCP.
+ ModerateRecvBuf(copied int)
}
// WriteOptions contains options for Endpoint.Write.
@@ -468,6 +479,18 @@
// closed.
type KeepaliveCountOption int
+// CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get
+// the current congestion control algorithm.
+type CongestionControlOption string
+
+// AvailableCongestionControlOption is used to query the supported congestion
+// control algorithms.
+type AvailableCongestionControlOption string
+
+// ModerateReceiveBufferOption allows the caller to enable/disable TCP receive
+// buffer moderation.
+type ModerateReceiveBufferOption bool
+
// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
// TTL value for multicast messages. The default is 1.
type MulticastTTLOption uint8
@@ -511,7 +534,7 @@
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
-// masked target address matches the destination adddress in the row.
+// masked target address matches the destination address in the row.
type Route struct {
// Destination is the address that must be matched against the masked
// target address to check if this row is viable.
diff --git a/tcpip/transport/icmp/endpoint.go b/tcpip/transport/icmp/endpoint.go
index a707081..38e17be 100644
--- a/tcpip/transport/icmp/endpoint.go
+++ b/tcpip/transport/icmp/endpoint.go
@@ -127,6 +127,9 @@
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (e *endpoint) ModerateRecvBuf(copied int) {}
+
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -708,3 +711,9 @@
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}
+
+// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
+// expose internal socket state.
+func (e *endpoint) State() uint32 {
+ return 0
+}
diff --git a/tcpip/transport/raw/endpoint.go b/tcpip/transport/raw/endpoint.go
index af5b50e..c727339 100644
--- a/tcpip/transport/raw/endpoint.go
+++ b/tcpip/transport/raw/endpoint.go
@@ -16,7 +16,7 @@
// sockets allow applications to:
//
// * manually write and inspect transport layer headers and payloads
-// * receive all traffic of a given transport protcol (e.g. ICMP or UDP)
+// * receive all traffic of a given transport protocol (e.g. ICMP or UDP)
// * optionally write and inspect network layer and link layer headers for
// packets
//
@@ -147,6 +147,9 @@
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+func (ep *endpoint) ModerateRecvBuf(copied int) {}
+
// Read implements tcpip.Endpoint.Read.
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
ep.rcvMu.Lock()
@@ -519,3 +522,8 @@
ep.waiterQueue.Notify(waiter.EventIn)
}
}
+
+// State implements tcpip.Endpoint.State.
+func (ep *endpoint) State() uint32 {
+ return 0
+}
diff --git a/tcpip/transport/tcp/accept.go b/tcpip/transport/tcp/accept.go
index aa97c50..57e7dca 100644
--- a/tcpip/transport/tcp/accept.go
+++ b/tcpip/transport/tcp/accept.go
@@ -213,6 +213,7 @@
n.route = s.route.Clone()
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
n.rcvBufSize = int(l.rcvWnd)
+ n.amss = mssForRoute(&n.route)
n.maybeEnableTimestamp(rcvdSynOpts)
n.maybeEnableSACKPermitted(rcvdSynOpts)
@@ -226,14 +227,17 @@
}
n.isRegistered = true
- n.state = stateConnecting
// Create sender and receiver.
//
// The receiver at least temporarily has a zero receive window scale,
// but the caller may change it (before starting the protocol loop).
n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
- n.rcv = newReceiver(n, irs, l.rcvWnd, 0)
+ n.rcv = newReceiver(n, irs, seqnum.Size(n.initialReceiveWindow()), 0, seqnum.Size(n.receiveBufferSize()))
+ // Bootstrap the auto tuning algorithm. Starting at zero will result in
+ // a large step function on the first window adjustment causing the
+ // window to grow to a really large value.
+ n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
return n, nil
}
@@ -250,16 +254,17 @@
}
// Perform the 3-way handshake.
- h := newHandshake(ep, l.rcvWnd)
+ h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
- h.resetToSynRcvd(cookie, irs, opts, l.listenEP)
+ h.resetToSynRcvd(cookie, irs, opts)
if err := h.execute(); err != nil {
ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.Close()
return nil, err
}
-
- ep.state = stateConnected
+ ep.mu.Lock()
+ ep.state = StateEstablished
+ ep.mu.Unlock()
// Update the receive window scaling. We can't do it before the
// handshake because it's possible that the peer doesn't support window
@@ -276,7 +281,7 @@
e.mu.RLock()
state := e.state
e.mu.RUnlock()
- if state == stateListen {
+ if state == StateListen {
e.acceptedChan <- n
e.waiterQueue.Notify(waiter.EventIn)
} else {
@@ -294,7 +299,6 @@
defer decSynRcvdCount()
defer e.decSynRcvdCount()
defer s.decRef()
-
n, err := ctx.createEndpointAndPerformHandshake(s, opts)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
@@ -306,7 +310,7 @@
func (e *endpoint) incSynRcvdCount() bool {
e.mu.Lock()
- if l, c := len(e.acceptedChan), cap(e.acceptedChan); l == c && e.synRcvdCount >= c {
+ if e.synRcvdCount >= cap(e.acceptedChan) {
e.mu.Unlock()
return false
}
@@ -321,6 +325,16 @@
e.mu.Unlock()
}
+func (e *endpoint) acceptQueueIsFull() bool {
+ e.mu.Lock()
+ if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c {
+ e.mu.Unlock()
+ return true
+ }
+ e.mu.Unlock()
+ return false
+}
+
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
@@ -328,35 +342,48 @@
case header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
if incSynRcvdCount() {
- // Drop the SYN if the listen endpoint's accept queue is
- // overflowing.
- if e.incSynRcvdCount() {
+ // Only handle the syn if the following conditions hold
+ // - accept queue is not full.
+ // - number of connections in synRcvd state is less than the
+ // backlog.
+ if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
s.incRef()
go e.handleSynSegment(ctx, s, &opts)
return
}
+ decSynRcvdCount()
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return
} else {
+ // If cookies are in use but the endpoint accept queue
+ // is full then drop the syn.
+ if e.acceptQueueIsFull() {
+ e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
- // Send SYN with window scaling because we currently
+
+ // Send SYN without window scaling because we currently
// dont't encode this information in the cookie.
//
// Enable Timestamp option if the original syn did have
// the timestamp option specified.
+ mss := mssForRoute(&s.route)
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
TSVal: tcpTimeStamp(timeStampOffset()),
TSEcr: opts.TSVal,
+ MSS: uint16(mss),
}
sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
}
case header.TCPFlagAck:
- if len(e.acceptedChan) == cap(e.acceptedChan) {
+ if e.acceptQueueIsFull() {
// Silently drop the ack as the application can't accept
// the connection at this point. The ack will be
// retransmitted by the sender anyway and we can
@@ -406,7 +433,7 @@
n.tsOffset = 0
// Switch state to connected.
- n.state = stateConnected
+ n.state = StateEstablished
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
@@ -429,7 +456,7 @@
// handleSynSegment() from attempting to queue new connections
// to the endpoint.
e.mu.Lock()
- e.state = stateClosed
+ e.state = StateClose
// Do cleanup if needed.
e.completeWorkerLocked()
diff --git a/tcpip/transport/tcp/connect.go b/tcpip/transport/tcp/connect.go
index 4c6020d..0e93598 100644
--- a/tcpip/transport/tcp/connect.go
+++ b/tcpip/transport/tcp/connect.go
@@ -60,12 +60,11 @@
// handshake holds the state used during a TCP 3-way handshake.
type handshake struct {
- ep *endpoint
- listenEP *endpoint // only non nil when doing passive connects.
- state handshakeState
- active bool
- flags uint8
- ackNum seqnum.Value
+ ep *endpoint
+ state handshakeState
+ active bool
+ flags uint8
+ ackNum seqnum.Value
// iss is the initial send sequence number, as defined in RFC 793.
iss seqnum.Value
@@ -79,6 +78,9 @@
// mss is the maximum segment size received from the peer.
mss uint16
+ // amss is the maximum segment size advertised by us to the peer.
+ amss uint16
+
// sndWndScale is the send window scale, as defined in RFC 1323. A
// negative value means no scaling is supported by the peer.
sndWndScale int
@@ -88,11 +90,24 @@
}
func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
+ rcvWndScale := ep.rcvWndScaleForHandshake()
+
+	// Round down rcvWnd to a multiple of 1<<rcvWndScale. This ensures that the
+ // window offered in SYN won't be reduced due to the loss of precision if
+ // window scaling is enabled after the handshake.
+ rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale)
+
+ // Ensure we can always accept at least 1 byte if the scale specified
+ // was too high for the provided rcvWnd.
+ if rcvWnd == 0 {
+ rcvWnd = 1
+ }
+
h := handshake{
ep: ep,
active: true,
rcvWnd: rcvWnd,
- rcvWndScale: FindWndScale(rcvWnd),
+ rcvWndScale: int(rcvWndScale),
}
h.resetState()
return h
@@ -142,7 +157,7 @@
// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
// state.
-func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, listenEP *endpoint) {
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) {
h.active = false
h.state = handshakeSynRcvd
h.flags = header.TCPFlagSyn | header.TCPFlagAck
@@ -150,7 +165,9 @@
h.ackNum = irs + 1
h.mss = opts.MSS
h.sndWndScale = opts.WS
- h.listenEP = listenEP
+ h.ep.mu.Lock()
+ h.ep.state = StateSynRecv
+ h.ep.mu.Unlock()
}
// checkAck checks if the ACK number, if present, of a segment received during
@@ -219,8 +236,11 @@
// but resend our own SYN and wait for it to be acknowledged in the
// SYN-RCVD state.
h.state = handshakeSynRcvd
+ h.ep.mu.Lock()
+ h.ep.state = StateSynRecv
+ h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
- WS: h.rcvWndScale,
+ WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTS,
@@ -229,6 +249,7 @@
// permits SACK. This is not explicitly defined in the RFC but
// this is the behaviour implemented by Linux.
SACKPermitted: rcvSynOpts.SACKPermitted,
+ MSS: h.ep.amss,
}
sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
@@ -273,6 +294,7 @@
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTS,
SACKPermitted: h.ep.sackPermitted,
+ MSS: h.ep.amss,
}
sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
return nil
@@ -281,23 +303,6 @@
// We have previously received (and acknowledged) the peer's SYN. If the
// peer acknowledges our SYN, the handshake is completed.
if s.flagIsSet(header.TCPFlagAck) {
- // listenContext is also used by a tcp.Forwarder and in that
- // context we do not have a listening endpoint to check the
- // backlog. So skip this check if listenEP is nil.
- if h.listenEP != nil {
- h.listenEP.mu.Lock()
- if len(h.listenEP.acceptedChan) == cap(h.listenEP.acceptedChan) {
- h.listenEP.mu.Unlock()
- // If there is no space in the accept queue to accept
- // this endpoint then silently drop this ACK. The peer
- // will anyway resend the ack and we can complete the
- // connection the next time it's retransmitted.
- h.ep.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
- h.ep.stack.Stats().DroppedPackets.Increment()
- return nil
- }
- h.listenEP.mu.Unlock()
- }
// If the timestamp option is negotiated and the segment does
// not carry a timestamp option then the segment must be dropped
// as per https://tools.ietf.org/html/rfc7323#section-3.2.
@@ -432,12 +437,15 @@
// Send the initial SYN segment and loop until the handshake is
// completed.
+ h.ep.amss = mssForRoute(&h.ep.route)
+
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: true,
TSVal: h.ep.timestamp(),
TSEcr: h.ep.recentTS,
SACKPermitted: bool(sackEnabled),
+ MSS: h.ep.amss,
}
// Execute is also called in a listen context so we want to make sure we
@@ -446,6 +454,11 @@
if h.state == handshakeSynRcvd {
synOpts.TS = h.ep.sendTSOk
synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
+ if h.sndWndScale < 0 {
+ // Disable window scaling if the peer did not send us
+ // the window scaling option.
+ synOpts.WS = -1
+ }
}
sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
for h.state != handshakeCompleted {
@@ -567,13 +580,6 @@
}
func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
- // The MSS in opts is automatically calculated as this function is
- // called from many places and we don't want every call point being
- // embedded with the MSS calculation.
- if opts.MSS == 0 {
- opts.MSS = uint16(r.MTU() - header.TCPMinimumSize)
- }
-
options := makeSynOptions(opts)
err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil)
putOptions(options)
@@ -668,7 +674,7 @@
// sendRaw sends a TCP segment to the endpoint's peer.
func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
var sackBlocks []header.SACKBlock
- if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
@@ -719,8 +725,7 @@
// protocol goroutine.
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
-
- e.state = stateError
+ e.state = StateError
e.hardError = err
}
@@ -875,15 +880,21 @@
// This is an active connection, so we must initiate the 3-way
// handshake, and then inform potential waiters about its
// completion.
- h := newHandshake(e, seqnum.Size(e.receiveBufferAvailable()))
+ initialRcvWnd := e.initialReceiveWindow()
+ h := newHandshake(e, seqnum.Size(initialRcvWnd))
+ e.mu.Lock()
+ h.ep.state = StateSynSent
+ e.mu.Unlock()
+
if err := h.execute(); err != nil {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
e.mu.Lock()
- e.state = stateError
+ e.state = StateError
e.hardError = err
+
// Lock released below.
epilogue()
@@ -895,8 +906,14 @@
// (indicated by a negative send window scale).
e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+ rcvBufSize := seqnum.Size(e.receiveBufferSize())
e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ // Bootstrap the auto-tuning algorithm. Starting at zero would
+ // result in a large step on the first proper RTT measurement,
+ // causing the window to jump to a very large value after the
+ // first RTT itself.
+ e.rcvAutoParams.prevCopied = initialRcvWnd
e.rcvListMu.Unlock()
}
@@ -905,7 +922,7 @@
// Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
- e.state = stateConnected
+ e.state = StateEstablished
drained := e.drainDone != nil
e.mu.Unlock()
if drained {
@@ -1005,7 +1022,7 @@
return err
}
}
- if e.state != stateError {
+ if e.state != StateError {
close(e.drainDone)
<-e.undrain
}
@@ -1061,8 +1078,8 @@
// Mark endpoint as closed.
e.mu.Lock()
- if e.state != stateError {
- e.state = stateClosed
+ if e.state != StateError {
+ e.state = StateClose
}
// Lock released below.
epilogue()
diff --git a/tcpip/transport/tcp/cubic.go b/tcpip/transport/tcp/cubic.go
index e618cd2..fb0e58b 100644
--- a/tcpip/transport/tcp/cubic.go
+++ b/tcpip/transport/tcp/cubic.go
@@ -23,6 +23,7 @@
// control algorithm state.
//
// See: https://tools.ietf.org/html/rfc8312.
+// +stateify savable
type cubicState struct {
// wLastMax is the previous wMax value.
wLastMax float64
diff --git a/tcpip/transport/tcp/endpoint.go b/tcpip/transport/tcp/endpoint.go
index 1692202..1b84f63 100644
--- a/tcpip/transport/tcp/endpoint.go
+++ b/tcpip/transport/tcp/endpoint.go
@@ -17,6 +17,7 @@
import (
"fmt"
"math"
+ "strings"
"sync"
"sync/atomic"
"time"
@@ -32,18 +33,81 @@
"github.com/google/netstack/waiter"
)
-type endpointState int
+// EndpointState represents the state of a TCP endpoint.
+type EndpointState uint32
+// Endpoint states. Note that these are represented in a netstack-specific manner and
+// may not be meaningful externally. Specifically, they need to be translated to
+// Linux's representation for these states if presented to userspace.
const (
- stateInitial endpointState = iota
- stateBound
- stateListen
- stateConnecting
- stateConnected
- stateClosed
- stateError
+ // Endpoint states internal to netstack. These map to the TCP state CLOSED.
+ StateInitial EndpointState = iota
+ StateBound
+ StateConnecting // Connect() called, but the initial SYN hasn't been sent.
+ StateError
+
+ // TCP protocol states.
+ StateEstablished
+ StateSynSent
+ StateSynRecv
+ StateFinWait1
+ StateFinWait2
+ StateTimeWait
+ StateClose
+ StateCloseWait
+ StateLastAck
+ StateListen
+ StateClosing
)
+// connected is the set of states where an endpoint is connected to a peer.
+func (s EndpointState) connected() bool {
+ switch s {
+ case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ return true
+ default:
+ return false
+ }
+}
+
+// String implements fmt.Stringer.String.
+func (s EndpointState) String() string {
+ switch s {
+ case StateInitial:
+ return "INITIAL"
+ case StateBound:
+ return "BOUND"
+ case StateConnecting:
+ return "CONNECTING"
+ case StateError:
+ return "ERROR"
+ case StateEstablished:
+ return "ESTABLISHED"
+ case StateSynSent:
+ return "SYN-SENT"
+ case StateSynRecv:
+ return "SYN-RCVD"
+ case StateFinWait1:
+ return "FIN-WAIT1"
+ case StateFinWait2:
+ return "FIN-WAIT2"
+ case StateTimeWait:
+ return "TIME-WAIT"
+ case StateClose:
+ return "CLOSED"
+ case StateCloseWait:
+ return "CLOSE-WAIT"
+ case StateLastAck:
+ return "LAST-ACK"
+ case StateListen:
+ return "LISTEN"
+ case StateClosing:
+ return "CLOSING"
+ default:
+ panic("unreachable")
+ }
+}
+
// Reasons for notifying the protocol goroutine.
const (
notifyNonZeroReceiveWindow = 1 << iota
@@ -68,6 +132,42 @@
NumBlocks int
}
+// rcvBufAutoTuneParams are used to hold state variables to compute
+// the auto tuned recv buffer size.
+//
+// +stateify savable
+type rcvBufAutoTuneParams struct {
+ // measureTime is the time at which the current measurement
+ // was started.
+ measureTime time.Time
+
+ // copied is the number of bytes copied out of the receive
+ // buffers since this measure began.
+ copied int
+
+ // prevCopied is the number of bytes copied out of the receive
+ // buffers in the previous RTT period.
+ prevCopied int
+
+ // rtt is the non-smoothed minimum RTT as measured by observing the time
+ // between when a byte is first acknowledged and the receipt of data
+ // that is at least one window beyond the sequence number that was
+ // acknowledged.
+ rtt time.Duration
+
+ // rttMeasureSeqNumber is the highest acceptable sequence number at the
+ // time this RTT measurement period began.
+ rttMeasureSeqNumber seqnum.Value
+
+ // rttMeasureTime is the absolute time at which the current rtt
+ // measurement period began.
+ rttMeasureTime time.Time
+
+ // disabled is true if an explicit receive buffer is set for the
+ // endpoint.
+ disabled bool
+}
+
// endpoint represents a TCP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
@@ -101,16 +201,27 @@
// to indicate to users that no more data is coming.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex
- rcvList segmentList
- rcvClosed bool
- rcvBufSize int
- rcvBufUsed int
+ rcvListMu sync.Mutex
+ rcvList segmentList
+ rcvClosed bool
+ rcvBufSize int
+ rcvBufUsed int
+ rcvAutoParams rcvBufAutoTuneParams
+ // zeroWindow indicates that the window was closed due to receive buffer
+ // space being filled up. This is set by the worker goroutine before
+ // moving a segment to the rcvList. This setting is cleared by the
+ // endpoint when a Read() call reads enough data for the new window to
+ // be non-zero.
+ zeroWindow bool
// The following fields are protected by the mutex.
- mu sync.RWMutex
- id stack.TransportEndpointID
- state endpointState
+ mu sync.RWMutex
+ id stack.TransportEndpointID
+
+ // state is the endpoint's current EndpointState. Like the fields that
+ // follow, it is protected by mu.
+ state EndpointState
+
isPortReserved bool
isRegistered bool
boundNICID tcpip.NICID
@@ -219,7 +330,7 @@
// cc stores the name of the Congestion Control algorithm to use for
// this endpoint.
- cc CongestionControlOption
+ cc tcpip.CongestionControlOption
// The following are used when a "packet too big" control packet is
// received. They are protected by sndBufMu. They are used to
@@ -271,6 +382,9 @@
bindAddress tcpip.Address
connectingAddress tcpip.Address
+ // amss is the advertised MSS to the peer by this endpoint.
+ amss uint16
+
gso *stack.GSO
}
@@ -304,8 +418,9 @@
stack: stack,
netProto: netProto,
waiterQueue: waiterQueue,
- rcvBufSize: DefaultBufferSize,
- sndBufSize: DefaultBufferSize,
+ state: StateInitial,
+ rcvBufSize: DefaultReceiveBufferSize,
+ sndBufSize: DefaultSendBufferSize,
sndMTU: int(math.MaxInt32),
reuseAddr: true,
keepalive: keepalive{
@@ -326,11 +441,16 @@
e.rcvBufSize = rs.Default
}
- var cs CongestionControlOption
+ var cs tcpip.CongestionControlOption
if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
e.cc = cs
}
+ var mrb tcpip.ModerateReceiveBufferOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &mrb); err == nil {
+ e.rcvAutoParams.disabled = !bool(mrb)
+ }
+
if p := stack.GetTCPProbe(); p != nil {
e.probe = p
}
@@ -339,6 +459,7 @@
e.workMu.Init()
e.workMu.Lock()
e.tsOffset = timeStampOffset()
+
return e
}
@@ -351,14 +472,14 @@
defer e.mu.RUnlock()
switch e.state {
- case stateInitial, stateBound, stateConnecting:
+ case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
- case stateClosed, stateError:
+ case StateClose, StateError:
// Ready for anything.
result = mask
- case stateListen:
+ case StateListen:
// Check if there's anything in the accepted channel.
if (mask & waiter.EventIn) != 0 {
if len(e.acceptedChan) > 0 {
@@ -366,7 +487,7 @@
}
}
- case stateConnected:
+ case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
// Determine if the endpoint is writable if requested.
if (mask & waiter.EventOut) != 0 {
e.sndBufMu.Lock()
@@ -427,7 +548,7 @@
// are immediately available for reuse after Close() is called. If also
// registered, we unregister as well otherwise the next user would fail
// in Listen() when trying to register.
- if e.state == stateListen && e.isPortReserved {
+ if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
e.isRegistered = false
@@ -482,20 +603,97 @@
tcpip.DeleteDanglingEndpoint(e)
}
+// initialReceiveWindow returns the initial receive window to advertise in the
+// SYN/SYN-ACK.
+func (e *endpoint) initialReceiveWindow() int {
+ rcvWnd := e.receiveBufferAvailable()
+ if rcvWnd > math.MaxUint16 {
+ rcvWnd = math.MaxUint16
+ }
+ routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2
+ if rcvWnd > routeWnd {
+ rcvWnd = routeWnd
+ }
+ return rcvWnd
+}
+
+// ModerateRecvBuf adjusts the receive buffer and the advertised window
+// based on the number of bytes copied to user space during the last RTT.
+func (e *endpoint) ModerateRecvBuf(copied int) {
+ e.rcvListMu.Lock()
+ if e.rcvAutoParams.disabled {
+ e.rcvListMu.Unlock()
+ return
+ }
+ now := time.Now()
+ if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
+ e.rcvAutoParams.copied += copied
+ e.rcvListMu.Unlock()
+ return
+ }
+ prevRTTCopied := e.rcvAutoParams.copied + copied
+ prevCopied := e.rcvAutoParams.prevCopied
+ rcvWnd := 0
+ if prevRTTCopied > prevCopied {
+ // The minimal receive window based on what was copied by the app
+ // in the immediately preceding RTT, plus extra buffer for 16
+ // segments to account for variations.
+ // We multiply by 2 to account for packet losses.
+ rcvWnd = prevRTTCopied*2 + 16*int(e.amss)
+
+ // Slow-start growth: bytes copied this RTT vs previous (needs prevCopied != 0).
+ grow := (rcvWnd * (prevRTTCopied - prevCopied)) / prevCopied
+
+ // Multiply growth factor by 2 again to account for sender being
+ // in slow-start where the sender grows its congestion window
+ // by 100% per RTT.
+ rcvWnd += grow * 2
+
+ // Make sure the auto-tuned buffer size can always receive up to 2x
+ // the initial window of InitialCwnd segments.
+ if minRcvWnd := int(e.amss) * InitialCwnd * 2; rcvWnd < minRcvWnd {
+ rcvWnd = minRcvWnd
+ }
+
+ // Cap the auto-tuned buffer size by the maximum permissible
+ // receive buffer size.
+ if max := e.maxReceiveBufferSize(); rcvWnd > max {
+ rcvWnd = max
+ }
+
+ // We do not adjust downwards as that can cause the receiver to
+ // reject valid data that might already be in flight as the
+ // acceptable window will shrink.
+ if rcvWnd > e.rcvBufSize {
+ e.rcvBufSize = rcvWnd
+ e.notifyProtocolGoroutine(notifyReceiveWindowChanged)
+ }
+
+ // We only update prevCopied when we grow the buffer because when
+ // prevCopied >= prevRTTCopied the existing buffer is already big
+ // enough to handle the current rate and we don't need to do any
+ // adjustments.
+ e.rcvAutoParams.prevCopied = prevRTTCopied
+ }
+ e.rcvAutoParams.measureTime = now
+ e.rcvAutoParams.copied = 0
+ e.rcvListMu.Unlock()
+}
+
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.mu.RLock()
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data. Also note that a RST being received
- // would cause the state to become stateError so we should allow the
+ // would cause the state to become StateError so we should allow the
// reads to proceed before returning a ECONNRESET.
e.rcvListMu.Lock()
bufUsed := e.rcvBufUsed
- if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 {
+ if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.hardError
e.mu.RUnlock()
- if s == stateError {
+ if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
@@ -511,7 +709,7 @@
func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
- if e.rcvClosed || e.state != stateConnected {
+ if e.rcvClosed || !e.state.connected() {
return buffer.View{}, tcpip.ErrClosedForReceive
}
return buffer.View{}, tcpip.ErrWouldBlock
@@ -527,10 +725,12 @@
s.decRef()
}
- scale := e.rcv.rcvWndScale
- wasZero := e.zeroReceiveWindow(scale)
e.rcvBufUsed -= len(v)
- if wasZero && !e.zeroReceiveWindow(scale) {
+ // If the window was zero before this read and if the read freed up
+ // enough buffer space for the scaled window to be non-zero then notify
+ // the protocol goroutine to send a window update.
+ if e.zeroWindow && !e.zeroReceiveWindow(e.rcv.rcvWndScale) {
+ e.zeroWindow = false
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -547,9 +747,9 @@
defer e.mu.RUnlock()
// The endpoint cannot be written to if it's not connected.
- if e.state != stateConnected {
+ if !e.state.connected() {
switch e.state {
- case stateError:
+ case StateError:
return 0, nil, e.hardError
default:
return 0, nil, tcpip.ErrClosedForSend
@@ -612,8 +812,8 @@
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data.
- if s := e.state; s != stateConnected && s != stateClosed {
- if s == stateError {
+ if s := e.state; !s.connected() && s != StateClose {
+ if s == StateError {
return 0, tcpip.ControlMessages{}, e.hardError
}
return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
@@ -623,7 +823,7 @@
defer e.rcvListMu.Unlock()
if e.rcvBufUsed == 0 {
- if e.rcvClosed || e.state != stateConnected {
+ if e.rcvClosed || !e.state.connected() {
return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
@@ -750,9 +950,10 @@
size = math.MaxInt32 / 2
}
- wasZero := e.zeroReceiveWindow(scale)
e.rcvBufSize = size
- if wasZero && !e.zeroReceiveWindow(scale) {
+ e.rcvAutoParams.disabled = true
+ if e.zeroWindow && !e.zeroReceiveWindow(scale) {
+ e.zeroWindow = false
mask |= notifyNonZeroReceiveWindow
}
e.rcvListMu.Unlock()
@@ -789,7 +990,7 @@
defer e.mu.Unlock()
// We only allow this to be set when we're in the initial state.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -830,6 +1031,40 @@
e.mu.Unlock()
return nil
+ case tcpip.CongestionControlOption:
+ // Query the available cc algorithms in the stack and
+ // validate that the specified algorithm is actually
+ // supported in the stack.
+ var avail tcpip.AvailableCongestionControlOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil {
+ return err
+ }
+ availCC := strings.Split(string(avail), " ")
+ for _, cc := range availCC {
+ if v == tcpip.CongestionControlOption(cc) {
+ // Acquire the work mutex as we may need to
+ // reinitialize the congestion control state.
+ e.mu.Lock()
+ state := e.state
+ e.cc = v
+ e.mu.Unlock()
+ switch state {
+ case StateEstablished:
+ e.workMu.Lock()
+ e.mu.Lock()
+ if e.state == state {
+ e.snd.cc = e.snd.initCongestionControl(e.cc)
+ }
+ e.mu.Unlock()
+ e.workMu.Unlock()
+ }
+ return nil
+ }
+ }
+
+ // Linux returns ENOENT when an invalid congestion
+ // control algorithm is specified.
+ return tcpip.ErrNoSuchFile
default:
return nil
}
@@ -841,7 +1076,7 @@
defer e.mu.RUnlock()
// The endpoint cannot be in listen state.
- if e.state == stateListen {
+ if e.state == StateListen {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -999,6 +1234,12 @@
}
return nil
+ case *tcpip.CongestionControlOption:
+ e.mu.Lock()
+ *o = e.cc
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -1057,7 +1298,7 @@
nicid := addr.NIC
switch e.state {
- case stateBound:
+ case StateBound:
// If we're already bound to a NIC but the caller is requesting
// that we use a different one now, we cannot proceed.
if e.boundNICID == 0 {
@@ -1070,16 +1311,16 @@
nicid = e.boundNICID
- case stateInitial:
- // Nothing to do. We'll eventually fill-in the gaps in the ID
- // (if any) when we find a route.
+ case StateInitial:
+ // Nothing to do. We'll eventually fill-in the gaps in the ID (if any)
+ // when we find a route.
- case stateConnecting:
- // A connection request has already been issued but hasn't
- // completed yet.
+ case StateConnecting, StateSynSent, StateSynRecv:
+ // A connection request has already been issued but hasn't completed
+ // yet.
return tcpip.ErrAlreadyConnecting
- case stateConnected:
+ case StateEstablished:
// The endpoint is already connected. If caller hasn't been notified yet, return success.
if !e.isConnectNotified {
e.isConnectNotified = true
@@ -1088,7 +1329,7 @@
// Otherwise return that it's already connected.
return tcpip.ErrAlreadyConnected
- case stateError:
+ case StateError:
return e.hardError
default:
@@ -1154,7 +1395,7 @@
}
e.isRegistered = true
- e.state = stateConnecting
+ e.state = StateConnecting
e.route = r.Clone()
e.boundNICID = nicid
e.effectiveNetProtos = netProtos
@@ -1175,7 +1416,7 @@
}
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
- e.state = stateConnected
+ e.state = StateEstablished
}
if run {
@@ -1199,8 +1440,8 @@
defer e.mu.Unlock()
e.shutdownFlags |= flags
- switch e.state {
- case stateConnected:
+ switch {
+ case e.state.connected():
// Close for read.
if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
// Mark read side as closed.
@@ -1241,7 +1482,7 @@
e.sndCloseWaker.Assert()
}
- case stateListen:
+ case e.state == StateListen:
// Tell protocolListenLoop to stop.
if flags&tcpip.ShutdownRead != 0 {
e.notifyProtocolGoroutine(notifyClose)
@@ -1269,7 +1510,7 @@
// When the endpoint shuts down, it sets workerCleanup to true, and from
// that point onward, acceptedChan is the responsibility of the cleanup()
// method (and should not be touched anywhere else, including here).
- if e.state == stateListen && !e.workerCleanup {
+ if e.state == StateListen && !e.workerCleanup {
// Adjust the size of the channel iff we can fix existing
// pending connections into the new one.
if len(e.acceptedChan) > backlog {
@@ -1288,7 +1529,7 @@
}
// Endpoint must be bound before it can transition to listen mode.
- if e.state != stateBound {
+ if e.state != StateBound {
return tcpip.ErrInvalidEndpointState
}
@@ -1298,7 +1539,7 @@
}
e.isRegistered = true
- e.state = stateListen
+ e.state = StateListen
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
@@ -1325,7 +1566,7 @@
defer e.mu.RUnlock()
// Endpoint must be in listen state before it can accept connections.
- if e.state != stateListen {
+ if e.state != StateListen {
return nil, nil, tcpip.ErrInvalidEndpointState
}
@@ -1353,7 +1594,7 @@
// Don't allow binding once endpoint is not in the initial state
// anymore. This is because once the endpoint goes into a connected or
// listen state, it is already bound.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrAlreadyBound
}
@@ -1408,7 +1649,7 @@
}
// Mark endpoint as bound.
- e.state = stateBound
+ e.state = StateBound
return nil
}
@@ -1430,7 +1671,7 @@
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != stateConnected {
+ if !e.state.connected() {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -1514,6 +1755,13 @@
if s != nil {
s.incRef()
e.rcvBufUsed += s.data.Size()
+ // Check if the receive window is now closed. If so make sure
+ // we set the zero window before we deliver the segment to ensure
+ // that a subsequent read of the segment will correctly trigger
+ // a non-zero notification.
+ if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 {
+ e.zeroWindow = true
+ }
e.rcvList.PushBack(s)
} else {
e.rcvClosed = true
@@ -1523,21 +1771,26 @@
e.waiterQueue.Notify(waiter.EventIn)
}
+// receiveBufferAvailableLocked calculates how many bytes are still available
+// in the receive buffer.
+// rcvListMu must be held when this function is called.
+func (e *endpoint) receiveBufferAvailableLocked() int {
+ // We may use more bytes than the buffer size when the receive buffer
+ // shrinks.
+ if e.rcvBufUsed >= e.rcvBufSize {
+ return 0
+ }
+
+ return e.rcvBufSize - e.rcvBufUsed
+}
+
// receiveBufferAvailable calculates how many bytes are still available in the
// receive buffer.
func (e *endpoint) receiveBufferAvailable() int {
e.rcvListMu.Lock()
- size := e.rcvBufSize
- used := e.rcvBufUsed
+ available := e.receiveBufferAvailableLocked()
e.rcvListMu.Unlock()
-
- // We may use more bytes than the buffer size when the receive buffer
- // shrinks.
- if used >= size {
- return 0
- }
-
- return size - used
+ return available
}
func (e *endpoint) receiveBufferSize() int {
@@ -1548,6 +1801,33 @@
return size
}
+func (e *endpoint) maxReceiveBufferSize() int {
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
+ // As a fallback return the hardcoded max buffer size.
+ return MaxBufferSize
+ }
+ return rs.Max
+}
+
+// rcvWndScaleForHandshake computes the receive window scale to offer to the
+// peer when window scaling is enabled (true by default). If auto-tuning is
+// disabled then the window scaling factor is based on the size of the
+// receiveBuffer otherwise we use the max permissible receive buffer size to
+// compute the scale.
+func (e *endpoint) rcvWndScaleForHandshake() int {
+ bufSizeForScale := e.receiveBufferSize()
+
+ e.rcvListMu.Lock()
+ autoTuningDisabled := e.rcvAutoParams.disabled
+ e.rcvListMu.Unlock()
+ if autoTuningDisabled {
+ return FindWndScale(seqnum.Size(bufSizeForScale))
+ }
+
+ return FindWndScale(seqnum.Size(e.maxReceiveBufferSize()))
+}
+
// updateRecentTimestamp updates the recent timestamp using the algorithm
// described in https://tools.ietf.org/html/rfc7323#section-4.3
func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
@@ -1640,6 +1920,13 @@
s.RcvBufSize = e.rcvBufSize
s.RcvBufUsed = e.rcvBufUsed
s.RcvClosed = e.rcvClosed
+ s.RcvAutoParams.MeasureTime = e.rcvAutoParams.measureTime
+ s.RcvAutoParams.CopiedBytes = e.rcvAutoParams.copied
+ s.RcvAutoParams.PrevCopiedBytes = e.rcvAutoParams.prevCopied
+ s.RcvAutoParams.RTT = e.rcvAutoParams.rtt
+ s.RcvAutoParams.RTTMeasureSeqNumber = e.rcvAutoParams.rttMeasureSeqNumber
+ s.RcvAutoParams.RTTMeasureTime = e.rcvAutoParams.rttMeasureTime
+ s.RcvAutoParams.Disabled = e.rcvAutoParams.disabled
e.rcvListMu.Unlock()
// Endpoint TCP Option state.
@@ -1693,13 +1980,13 @@
RTTMeasureTime: e.snd.rttMeasureTime,
Closed: e.snd.closed,
RTO: e.snd.rto,
- SRTTInited: e.snd.srttInited,
MaxPayloadSize: e.snd.maxPayloadSize,
SndWndScale: e.snd.sndWndScale,
MaxSentAck: e.snd.maxSentAck,
}
e.snd.rtt.Lock()
s.Sender.SRTT = e.snd.rtt.srtt
+ s.Sender.SRTTInited = e.snd.rtt.srttInited
e.snd.rtt.Unlock()
if cubic, ok := e.snd.cc.(*cubicState); ok {
@@ -1739,3 +2026,15 @@
gso.MaxSize = e.route.GSOMaxSize()
e.gso = gso
}
+
+// State implements tcpip.Endpoint.State. It exports the endpoint's current
+// EndpointState for diagnostics; reading it only needs mu's read lock.
+func (e *endpoint) State() uint32 {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return uint32(e.state)
+}
+
+func mssForRoute(r *stack.Route) uint16 {
+ return uint16(r.MTU() - header.TCPMinimumSize)
+}
diff --git a/tcpip/transport/tcp/forwarder.go b/tcpip/transport/tcp/forwarder.go
index 6fe5afe..7e308a7 100644
--- a/tcpip/transport/tcp/forwarder.go
+++ b/tcpip/transport/tcp/forwarder.go
@@ -47,7 +47,7 @@
// If rcvWnd is set to zero, the default buffer size is used instead.
func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder {
if rcvWnd == 0 {
- rcvWnd = DefaultBufferSize
+ rcvWnd = DefaultReceiveBufferSize
}
return &Forwarder{
maxInFlight: maxInFlight,
diff --git a/tcpip/transport/tcp/protocol.go b/tcpip/transport/tcp/protocol.go
index 17ebfa1..c889dee 100644
--- a/tcpip/transport/tcp/protocol.go
+++ b/tcpip/transport/tcp/protocol.go
@@ -41,13 +41,18 @@
ProtocolNumber = header.TCPProtocolNumber
// MinBufferSize is the smallest size of a receive or send buffer.
- minBufferSize = 4 << 10 // 4096 bytes.
+ MinBufferSize = 4 << 10 // 4096 bytes.
- // DefaultBufferSize is the default size of the receive and send buffers.
- DefaultBufferSize = 1 << 20 // 1MB
+ // DefaultSendBufferSize is the default size of the send buffer for
+ // an endpoint.
+ DefaultSendBufferSize = 1 << 20 // 1MB
- // MaxBufferSize is the largest size a receive and send buffer can grow to.
- maxBufferSize = 4 << 20 // 4MB
+ // DefaultReceiveBufferSize is the default size of the receive buffer
+ // for an endpoint.
+ DefaultReceiveBufferSize = 1 << 20 // 1MB
+
+ // MaxBufferSize is the largest size a receive/send buffer can grow to.
+ MaxBufferSize = 4 << 20 // 4MB
// MaxUnprocessedSegments is the maximum number of unprocessed segments
// that can be queued for a given endpoint.
@@ -79,13 +84,6 @@
ccCubic = "cubic"
)
-// CongestionControlOption sets the current congestion control algorithm.
-type CongestionControlOption string
-
-// AvailableCongestionControlOption returns the supported congestion control
-// algorithms.
-type AvailableCongestionControlOption string
-
type protocol struct {
mu sync.Mutex
sackEnabled bool
@@ -93,7 +91,7 @@
recvBufferSize ReceiveBufferSizeOption
congestionControl string
availableCongestionControl []string
- allowedCongestionControl []string
+ moderateReceiveBuffer bool
}
// Number returns the tcp protocol number.
@@ -188,7 +186,7 @@
p.mu.Unlock()
return nil
- case CongestionControlOption:
+ case tcpip.CongestionControlOption:
for _, c := range p.availableCongestionControl {
if string(v) == c {
p.mu.Lock()
@@ -197,7 +195,16 @@
return nil
}
}
- return tcpip.ErrInvalidOptionValue
+ // Linux returns ENOENT when an invalid congestion control
+ // algorithm is specified.
+ return tcpip.ErrNoSuchFile
+
+ case tcpip.ModerateReceiveBufferOption:
+ p.mu.Lock()
+ p.moderateReceiveBuffer = bool(v)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -223,16 +230,25 @@
*v = p.recvBufferSize
p.mu.Unlock()
return nil
- case *CongestionControlOption:
+
+ case *tcpip.CongestionControlOption:
p.mu.Lock()
- *v = CongestionControlOption(p.congestionControl)
+ *v = tcpip.CongestionControlOption(p.congestionControl)
p.mu.Unlock()
return nil
- case *AvailableCongestionControlOption:
+
+ case *tcpip.AvailableCongestionControlOption:
p.mu.Lock()
- *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
p.mu.Unlock()
return nil
+
+ case *tcpip.ModerateReceiveBufferOption:
+ p.mu.Lock()
+ *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -241,8 +257,8 @@
func init() {
stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
return &protocol{
- sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
+ sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
}
diff --git a/tcpip/transport/tcp/rcv.go b/tcpip/transport/tcp/rcv.go
index e0d55a0..a834f96 100644
--- a/tcpip/transport/tcp/rcv.go
+++ b/tcpip/transport/tcp/rcv.go
@@ -16,6 +16,7 @@
import (
"container/heap"
+ "time"
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/seqnum"
@@ -38,6 +39,9 @@
// shrinking it.
rcvAcc seqnum.Value
+ // rcvWnd is the non-scaled receive window last advertised to the peer.
+ rcvWnd seqnum.Size
+
rcvWndScale uint8
closed bool
@@ -47,13 +51,14 @@
pendingBufSize seqnum.Size
}
-func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
+func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver {
return &receiver{
ep: ep,
rcvNxt: irs + 1,
rcvAcc: irs.Add(rcvWnd + 1),
+ rcvWnd: rcvWnd,
rcvWndScale: rcvWndScale,
- pendingBufSize: rcvWnd,
+ pendingBufSize: pendingBufSize,
}
}
@@ -72,14 +77,16 @@
// getSendParams returns the parameters needed by the sender when building
// segments to send.
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
- // Calculate the window size based on the current buffer size.
- n := r.ep.receiveBufferAvailable()
- acc := r.rcvNxt.Add(seqnum.Size(n))
+ // Calculate the window size based on the available buffer space.
+ receiveBufferAvailable := r.ep.receiveBufferAvailable()
+ acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable))
if r.rcvAcc.LessThan(acc) {
r.rcvAcc = acc
}
-
- return r.rcvNxt, r.rcvNxt.Size(r.rcvAcc) >> r.rcvWndScale
+ // Stash away the non-scaled receive window as we use it for measuring
+ // receiver's estimated RTT.
+ r.rcvWnd = r.rcvNxt.Size(r.rcvAcc)
+ return r.rcvNxt, r.rcvWnd >> r.rcvWndScale
}
// nonZeroWindow is called when the receive window grows from zero to nonzero;
@@ -130,10 +137,26 @@
// Update the segment that we're expecting to consume.
r.rcvNxt = segSeq.Add(segLen)
+ // In cases of a misbehaving sender which could send more than the
+ // advertised window, we could end up in a situation where we get a
+ // segment that exceeds the window advertised. Instead of partially
+ // accepting the segment and discarding bytes beyond the advertised
+ // window, we accept the whole segment and make sure r.rcvAcc is moved
+ // forward to match r.rcvNxt to indicate that the window is now closed.
+ //
+ // In the absence of this check, r.acceptable() misbehaves and accepts
+ // segments that should be dropped because rcvWnd is calculated as
+ // the size of the interval (rcvNxt, rcvAcc] which becomes extremely
+ // large if rcvAcc is ever less than rcvNxt.
+ if r.rcvAcc.LessThan(r.rcvNxt) {
+ r.rcvAcc = r.rcvNxt
+ }
+
// Trim SACK Blocks to remove any SACK information that covers
// sequence numbers that have been consumed.
TrimSACKBlockList(&r.ep.sack, r.rcvNxt)
+ // Handle FIN or FIN-ACK.
if s.flagIsSet(header.TCPFlagFin) {
r.rcvNxt++
@@ -144,6 +167,25 @@
r.closed = true
r.ep.readyToRead(nil)
+ // We just received a FIN; our next state depends on whether we sent a
+ // FIN already or not.
+ r.ep.mu.Lock()
+ switch r.ep.state {
+ case StateEstablished:
+ r.ep.state = StateCloseWait
+ case StateFinWait1:
+ if s.flagIsSet(header.TCPFlagAck) {
+ // FIN-ACK, transition to TIME-WAIT.
+ r.ep.state = StateTimeWait
+ } else {
+ // Simultaneous close, expecting a final ACK.
+ r.ep.state = StateClosing
+ }
+ case StateFinWait2:
+ r.ep.state = StateTimeWait
+ }
+ r.ep.mu.Unlock()
+
// Flush out any pending segments, except the very first one if
// it happens to be the one we're handling now because the
// caller is using it.
@@ -156,11 +198,61 @@
r.pendingRcvdSegments[i].decRef()
}
r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
+
+ return true
+ }
+
+ // Handle ACK (not FIN-ACK, which we handled above) during one of the
+ // shutdown states.
+ if s.flagIsSet(header.TCPFlagAck) {
+ r.ep.mu.Lock()
+ switch r.ep.state {
+ case StateFinWait1:
+ r.ep.state = StateFinWait2
+ case StateClosing:
+ r.ep.state = StateTimeWait
+ case StateLastAck:
+ r.ep.state = StateClose
+ }
+ r.ep.mu.Unlock()
}
return true
}
+// updateRTT updates the receiver RTT measurement based on the sequence number
+// of the received segment.
+func (r *receiver) updateRTT() {
+ // From: https://public.lanl.gov/radiant/pubs/drs/sc2001-poster.pdf
+ //
+ // A system that is only transmitting acknowledgements can still
+ // estimate the round-trip time by observing the time between when a byte
+ // is first acknowledged and the receipt of data that is at least one
+ // window beyond the sequence number that was acknowledged.
+ r.ep.rcvListMu.Lock()
+ if r.ep.rcvAutoParams.rttMeasureTime.IsZero() {
+ // New measurement.
+ r.ep.rcvAutoParams.rttMeasureTime = time.Now()
+ r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd)
+ r.ep.rcvListMu.Unlock()
+ return
+ }
+ if r.rcvNxt.LessThan(r.ep.rcvAutoParams.rttMeasureSeqNumber) {
+ r.ep.rcvListMu.Unlock()
+ return
+ }
+ rtt := time.Since(r.ep.rcvAutoParams.rttMeasureTime)
+ // We only store the minimum observed RTT here as this is only used in
+ // absence of an SRTT available from either timestamps or a sender
+ // measurement of RTT.
+ if r.ep.rcvAutoParams.rtt == 0 || rtt < r.ep.rcvAutoParams.rtt {
+ r.ep.rcvAutoParams.rtt = rtt
+ }
+ r.ep.rcvAutoParams.rttMeasureTime = time.Now()
+ r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd)
+ r.ep.rcvListMu.Unlock()
+}
+
// handleRcvdSegment handles TCP segments directed at the connection managed by
// r as they arrive. It is called by the protocol main loop.
func (r *receiver) handleRcvdSegment(s *segment) {
@@ -189,10 +281,9 @@
r.pendingBufUsed += s.logicalLen()
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
+ UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
}
- UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
-
// Immediately send an ack so that the peer knows it may
// have to retransmit.
r.ep.snd.sendAck()
@@ -200,6 +291,12 @@
return
}
+ // Since we consumed a segment, update the receiver's RTT estimate
+ // if required.
+ if segLen > 0 {
+ r.updateRTT()
+ }
+
// By consuming the current segment, we may have filled a gap in the
// sequence number domain that allows pending segments to be consumed
// now. So try to do it.
diff --git a/tcpip/transport/tcp/sack.go b/tcpip/transport/tcp/sack.go
index be749e4..e3c6291 100644
--- a/tcpip/transport/tcp/sack.go
+++ b/tcpip/transport/tcp/sack.go
@@ -31,6 +31,13 @@
// segment identified by segStart->segEnd.
func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) {
newSB := header.SACKBlock{Start: segStart, End: segEnd}
+
+ // Ignore any invalid SACK blocks or blocks that are before rcvNxt as
+ // those bytes have already been acked.
+ if newSB.End.LessThanEq(newSB.Start) || newSB.End.LessThan(rcvNxt) {
+ return
+ }
+
if sack.NumBlocks == 0 {
sack.Blocks[0] = newSB
sack.NumBlocks = 1
@@ -39,9 +46,8 @@
var n = 0
for i := 0; i < sack.NumBlocks; i++ {
start, end := sack.Blocks[i].Start, sack.Blocks[i].End
- if end.LessThanEq(start) || start.LessThanEq(rcvNxt) {
- // Discard any invalid blocks where end is before start
- // and discard any sack blocks that are before rcvNxt as
+ if end.LessThanEq(rcvNxt) {
+ // Discard any sack blocks that are before rcvNxt as
// those have already been acked.
continue
}
diff --git a/tcpip/transport/tcp/snd.go b/tcpip/transport/tcp/snd.go
index 80e4bdd..175b152 100644
--- a/tcpip/transport/tcp/snd.go
+++ b/tcpip/transport/tcp/snd.go
@@ -121,9 +121,8 @@
// rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time",
// "round-trip time variation" and "retransmit timeout", as defined in
// section 2 of RFC 6298.
- rtt rtt
- rto time.Duration
- srttInited bool
+ rtt rtt
+ rto time.Duration
// maxPayloadSize is the maximum size of the payload of a given segment.
// It is initialized on demand.
@@ -150,8 +149,9 @@
type rtt struct {
sync.Mutex
- srtt time.Duration
- rttvar time.Duration
+ srtt time.Duration
+ rttvar time.Duration
+ srttInited bool
}
// fastRecovery holds information related to fast recovery from a packet loss.
@@ -194,8 +194,6 @@
s := &sender{
ep: ep,
- sndCwnd: InitialCwnd,
- sndSsthresh: math.MaxInt64,
sndWnd: sndWnd,
sndUna: iss + 1,
sndNxt: iss + 1,
@@ -238,7 +236,13 @@
return s
}
-func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl {
+// initCongestionControl initializes the specified congestion control module and
+// returns a handle to it. It also initializes the sndCwnd and sndSsthresh to
+// their initial values.
+func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl {
+ s.sndCwnd = InitialCwnd
+ s.sndSsthresh = math.MaxInt64
+
switch congestionControlName {
case ccCubic:
return newCubicCC(s)
@@ -319,10 +323,10 @@
// available. This is done in accordance with section 2 of RFC 6298.
func (s *sender) updateRTO(rtt time.Duration) {
s.rtt.Lock()
- if !s.srttInited {
+ if !s.rtt.srttInited {
s.rtt.rttvar = rtt / 2
s.rtt.srtt = rtt
- s.srttInited = true
+ s.rtt.srttInited = true
} else {
diff := s.rtt.srtt - rtt
if diff < 0 {
@@ -632,6 +636,10 @@
}
seg.flags = header.TCPFlagAck | header.TCPFlagFin
segEnd = seg.sequenceNumber.Add(1)
+ // Transition to FIN-WAIT1 state since we're initiating an active close.
+ s.ep.mu.Lock()
+ s.ep.state = StateFinWait1
+ s.ep.mu.Unlock()
} else {
// We're sending a non-FIN segment.
if seg.flags&header.TCPFlagFin != 0 {
@@ -779,7 +787,7 @@
break
}
dataSent = true
- s.outstanding++
+ s.outstanding += s.pCount(seg)
s.writeNext = seg.Next()
}
}
diff --git a/tcpip/transport/tcp/tcp_test.go b/tcpip/transport/tcp/tcp_test.go
index 7d60642..a9f602f 100644
--- a/tcpip/transport/tcp/tcp_test.go
+++ b/tcpip/transport/tcp/tcp_test.go
@@ -168,8 +168,8 @@
// Receive the SYN-ACK reply.
b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcp.SequenceNumber())
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
ackHeaders := &context.Headers{
SrcPort: context.TestPort,
@@ -269,8 +269,8 @@
time.Sleep(3 * time.Second)
for {
b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- if tcp.Flags() == header.TCPFlagAck|header.TCPFlagFin {
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin {
// This is a retransmit of the FIN, ignore it.
continue
}
@@ -553,9 +553,13 @@
// We shouldn't consume a sequence number on RST.
checker.SeqNum(uint32(c.IRS)+1),
))
+ // The RST puts the endpoint into an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
- // This final should be ignored because an ACK on a reset doesn't
- // mean anything.
+ // This final ACK should be ignored because an ACK on a reset doesn't mean
+ // anything.
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
@@ -618,6 +622,10 @@
checker.SeqNum(uint32(c.IRS)+1),
))
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
// Cause a RST to be generated by closing the read end now since we have
// unread data.
c.EP.Shutdown(tcpip.ShutdownRead)
@@ -630,6 +638,10 @@
// We shouldn't consume a sequence number on RST.
checker.SeqNum(uint32(c.IRS)+1),
))
+ // The RST puts the endpoint into an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
// The ACK to the FIN should now be rejected since the connection has been
// closed by a RST.
@@ -1098,8 +1110,9 @@
t.Fatalf("Listen failed: %v", err)
}
- // Do 3-way handshake.
- c.PassiveConnect(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS})
+ // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN
+ // should not carry the window scaling option.
+ c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -1510,8 +1523,8 @@
for bytesReceived != dataLen {
b := c.GetPacket()
numPackets++
- tcp := header.TCP(header.IPv4(b).Payload())
- payloadLen := len(tcp.Payload())
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ payloadLen := len(tcpHdr.Payload())
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
@@ -1522,7 +1535,7 @@
)
pdata := data[bytesReceived : bytesReceived+payloadLen]
- if p := tcp.Payload(); !bytes.Equal(pdata, p) {
+ if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) {
t.Fatalf("got data = %v, want = %v", p, pdata)
}
bytesReceived += payloadLen
@@ -1530,7 +1543,7 @@
if c.TimeStampEnabled {
// If timestamp option is enabled, echo back the timestamp and increment
// the TSEcr value included in the packet and send that back as the TSVal.
- parsedOpts := tcp.ParsedOptions()
+ parsedOpts := tcpHdr.ParsedOptions()
tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:])
options = tsOpt[:]
@@ -1588,7 +1601,6 @@
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- const wndScale = 2
if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1602,7 +1614,7 @@
}
// Do 3-way handshake.
- c.PassiveConnect(maxPayload, wndScale, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
+ c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
@@ -1701,7 +1713,7 @@
s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
// Do 3-way handshake.
- c.PassiveConnect(maxPayload, 1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
+ c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
// Wait for connection to be available.
select {
@@ -1757,8 +1769,8 @@
),
)
- tcp := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcp.SequenceNumber())
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
// Wait for retransmit.
time.Sleep(1 * time.Second)
@@ -1766,8 +1778,8 @@
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagSyn),
- checker.SrcPort(tcp.SourcePort()),
- checker.SeqNum(tcp.SequenceNumber()),
+ checker.SrcPort(tcpHdr.SourcePort()),
+ checker.SeqNum(tcpHdr.SequenceNumber()),
checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}),
),
)
@@ -1775,8 +1787,8 @@
// Send SYN-ACK.
iss := seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
- SrcPort: tcp.DestinationPort(),
- DstPort: tcp.SourcePort(),
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
Flags: header.TCPFlagSyn | header.TCPFlagAck,
SeqNum: iss,
AckNum: c.IRS.Add(1),
@@ -2523,8 +2535,8 @@
checker.TCPFlags(header.TCPFlagAck),
),
)
- tcp := header.TCP(header.IPv4(b).Payload())
- ack := seqnum.Value(tcp.AckNumber())
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ ack := seqnum.Value(tcpHdr.AckNumber())
if ack == last {
break
}
@@ -2568,6 +2580,10 @@
),
)
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
// Send some data and acknowledge the FIN.
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -2589,9 +2605,15 @@
),
)
- // Give the stack the chance to transition to closed state.
+ // Give the stack the chance to transition to closed state. Note that since
+ // both the sender and receiver are now closed, we effectively skip the
+ // TIME-WAIT state.
time.Sleep(1 * time.Second)
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
// Wait for receive to be notified.
select {
case <-ch:
@@ -2745,11 +2767,11 @@
}
}()
- checkSendBufferSize(t, ep, tcp.DefaultBufferSize)
- checkRecvBufferSize(t, ep, tcp.DefaultBufferSize)
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize)
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default send buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize * 2, tcp.DefaultBufferSize * 20}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
@@ -2759,11 +2781,11 @@
t.Fatalf("NewEndpoint failed; %v", err)
}
- checkSendBufferSize(t, ep, tcp.DefaultBufferSize*2)
- checkRecvBufferSize(t, ep, tcp.DefaultBufferSize)
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default receive buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize * 3, tcp.DefaultBufferSize * 30}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
@@ -2773,8 +2795,8 @@
t.Fatalf("NewEndpoint failed; %v", err)
}
- checkSendBufferSize(t, ep, tcp.DefaultBufferSize*2)
- checkRecvBufferSize(t, ep, tcp.DefaultBufferSize*3)
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3)
}
func TestMinMaxBufferSizes(t *testing.T) {
@@ -2788,11 +2810,11 @@
defer ep.Close()
// Change the min/max values for send/receive
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultBufferSize * 2, tcp.DefaultBufferSize * 20}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultBufferSize * 3, tcp.DefaultBufferSize * 30}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
@@ -2810,17 +2832,17 @@
checkSendBufferSize(t, ep, 300)
// Set values above the max.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultBufferSize*20)); err != nil {
+ if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
- checkRecvBufferSize(t, ep, tcp.DefaultBufferSize*20)
+ checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultBufferSize*30)); err != nil {
+ if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
- checkSendBufferSize(t, ep, tcp.DefaultBufferSize*30)
+ checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
}
func makeStack() (*stack.Stack, *tcpip.Error) {
@@ -3183,13 +3205,14 @@
}
}
-func TestSetCongestionControl(t *testing.T) {
+func TestStackSetCongestionControl(t *testing.T) {
testCases := []struct {
- cc tcp.CongestionControlOption
- mustPass bool
+ cc tcpip.CongestionControlOption
+ err *tcpip.Error
}{
- {"reno", true},
- {"cubic", true},
+ {"reno", nil},
+ {"cubic", nil},
+ {"blahblah", tcpip.ErrNoSuchFile},
}
for _, tc := range testCases {
@@ -3199,62 +3222,135 @@
s := c.Stack()
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != nil && tc.mustPass {
- t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want not-nil", tcp.ProtocolNumber, tc.cc, err)
+ var oldCC tcpip.CongestionControlOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err)
}
- var cc tcp.CongestionControlOption
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
+ t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err)
+ }
+
+ var cc tcpip.CongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
}
- if got, want := cc, tc.cc; got != want {
+
+ got, want := cc, oldCC
+ // If SetTransportProtocolOption is expected to succeed
+ // then the returned value for congestion control should
+ // match the one specified in the
+ // SetTransportProtocolOption call above, else it should
+ // be what it was before the call to
+ // SetTransportProtocolOption.
+ if tc.err == nil {
+ want = tc.cc
+ }
+ if got != want {
t.Fatalf("got congestion control: %v, want: %v", got, want)
}
})
}
}
-func TestAvailableCongestionControl(t *testing.T) {
+func TestStackAvailableCongestionControl(t *testing.T) {
c := context.New(t, 1500)
defer c.Cleanup()
s := c.Stack()
// Query permitted congestion control algorithms.
- var aCC tcp.AvailableCongestionControlOption
+ var aCC tcpip.AvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err)
}
- if got, want := aCC, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
}
}
-func TestSetAvailableCongestionControl(t *testing.T) {
+func TestStackSetAvailableCongestionControl(t *testing.T) {
c := context.New(t, 1500)
defer c.Cleanup()
s := c.Stack()
// Setting AvailableCongestionControlOption should fail.
- aCC := tcp.AvailableCongestionControlOption("xyz")
+ aCC := tcpip.AvailableCongestionControlOption("xyz")
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC)
}
// Verify that we still get the expected list of congestion control options.
- var cc tcp.AvailableCongestionControlOption
+ var cc tcpip.AvailableCongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
}
- if got, want := cc, tcp.AvailableCongestionControlOption("reno cubic"); got != want {
- t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want)
+ if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
+ t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+ }
+}
+
+func TestEndpointSetCongestionControl(t *testing.T) {
+ testCases := []struct {
+ cc tcpip.CongestionControlOption
+ err *tcpip.Error
+ }{
+ {"reno", nil},
+ {"cubic", nil},
+ {"blahblah", tcpip.ErrNoSuchFile},
+ }
+
+ for _, connected := range []bool{false, true} {
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ var oldCC tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&oldCC); err != nil {
+ t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err)
+ }
+
+ if connected {
+ c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil)
+ }
+
+ if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
+ t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err)
+ }
+
+ var cc tcpip.CongestionControlOption
+ if err := c.EP.GetSockOpt(&cc); err != nil {
+ t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err)
+ }
+
+ got, want := cc, oldCC
+ // If SetSockOpt is expected to succeed then the
+ // returned value for congestion control should match
+ // the one specified in the SetSockOpt above, else it
+ // should be what it was before the call to SetSockOpt.
+ if tc.err == nil {
+ want = tc.cc
+ }
+ if got != want {
+ t.Fatalf("got congestion control: %v, want: %v", got, want)
+ }
+ })
+ }
}
}
func enableCUBIC(t *testing.T, c *context.Context) {
t.Helper()
- opt := tcp.CongestionControlOption("cubic")
+ opt := tcpip.CongestionControlOption("cubic")
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err)
}
@@ -3383,7 +3479,7 @@
RcvWnd: 30000,
})
- // Receive the SYN-ACK reply.
+ // Receive the SYN-ACK reply.
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
iss = seqnum.Value(tcp.SequenceNumber())
@@ -3447,12 +3543,18 @@
time.Sleep(50 * time.Millisecond)
- // Now execute one more handshake. This should not be completed and
- // delivered on an Accept() call as the backlog is full at this point.
- irs, iss := executeHandshake(t, c, context.TestPort+uint16(listenBacklog), false /* synCookieInUse */)
+ // Now send one more SYN. The stack should not respond as the backlog
+ // is full at this point.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + 2,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(789),
+ RcvWnd: 30000,
+ })
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
- time.Sleep(50 * time.Millisecond)
- // Try to accept the connection.
+ // Try to accept the connections in the backlog.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
@@ -3484,16 +3586,8 @@
}
}
- // Now craft the ACK again and verify that the connection is now ready
- // to be accepted.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + uint16(listenBacklog),
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
+ // Now a new handshake must succeed.
+ executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */)
newEP, _, err := c.EP.Accept()
if err == tcpip.ErrWouldBlock {
@@ -3509,6 +3603,7 @@
t.Fatalf("Timed out waiting for accept")
}
}
+
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
@@ -3519,6 +3614,110 @@
}
}
+func TestListenSynRcvdQueueFull(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ // Start listening.
+ listenBacklog := 1
+ if err := c.EP.Listen(listenBacklog); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ // Send two SYNs: the first one should get a SYN-ACK, the
+ // second one should not get any response and is dropped as
+ // the synRcvd count will be equal to backlog.
+ irs := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(789),
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcp := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcp.SequenceNumber())
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(irs) + 1),
+ }
+ checker.IPv4(t, b, checker.TCP(tcpCheckers...))
+
+ // Now send one more SYN. The stack should not respond as the backlog
+ // is full at this point.
+ //
+ // NOTE: we did not complete the handshake for the previous one so the
+ // accept backlog should be empty and there should be one connection in
+ // synRcvd state.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + 1,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(889),
+ RcvWnd: 30000,
+ })
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
+
+ // Now complete the previous connection and verify that there is a connection
+ // to accept.
+ // Send ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+
+ // Try to accept the connections in the backlog.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ newEP, _, err := c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ newEP, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %v", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Now verify that the TCP socket is usable and in a connected state.
+ data := "Don't panic"
+ newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ pkt := c.GetPacket()
+ tcp = header.TCP(header.IPv4(pkt).Payload())
+ if string(tcp.Payload()) != data {
+ t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
+ }
+}
+
func TestListenBacklogFullSynCookieInUse(t *testing.T) {
saved := tcp.SynRcvdCountThreshold
defer func() {
@@ -3554,26 +3753,17 @@
// Wait for this to be delivered to the accept queue.
time.Sleep(50 * time.Millisecond)
- nonCookieIRS, nonCookieISS := executeHandshake(t, c, context.TestPort+portOffset, false)
-
- // Since the backlog is full at this point this connection will not
- // transition out of handshake and ignore the ACK.
- //
- // At this point there should be 1 completed connection in the backlog
- // and one incomplete one pending for a final ACK and hence not ready to be
- // delivered to the endpoint.
- //
- // Now execute one more handshake. This should not be completed and
- // delivered on an Accept() call as the backlog is full at this point
- // and there is already 1 pending endpoint.
- //
- // This one should use a SYN cookie as the synRcvdCount is equal to the
- // SynRcvdCountThreshold.
- time.Sleep(50 * time.Millisecond)
- portOffset++
- irs, iss := executeHandshake(t, c, context.TestPort+portOffset, true)
-
- time.Sleep(50 * time.Millisecond)
+ // Send a SYN request.
+ irs := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+ // The SYN should be dropped as the endpoint's backlog is full.
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
// Verify that there is only one acceptable connection at this point.
we, ch := waiter.NewChannelEntry(nil)
@@ -3604,68 +3794,6 @@
case <-time.After(1 * time.Second):
}
}
-
- // Now send an ACK for the half completed connection
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + portOffset - 1,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: nonCookieIRS + 1,
- AckNum: nonCookieISS + 1,
- RcvWnd: 30000,
- })
-
- // Verify that the connection is now delivered to the backlog.
- _, _, err = c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Finally send an ACK for the connection that used a cookie and verify that
- // it's also completed and delivered.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + portOffset,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs,
- AckNum: iss,
- RcvWnd: 30000,
- })
-
- time.Sleep(50 * time.Millisecond)
- newEP, _, err := c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- newEP, _, err = c.EP.Accept()
- if err != nil {
- t.Fatalf("Accept failed: %v", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
- // Now verify that the TCP socket is usable and in a connected state.
- data := "Don't panic"
- newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
- b := c.GetPacket()
- tcp := header.TCP(header.IPv4(b).Payload())
- if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
- }
}
func TestPassiveConnectionAttemptIncrement(t *testing.T) {
@@ -3680,9 +3808,15 @@
if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %v", err)
}
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
if err := c.EP.Listen(1); err != nil {
t.Fatalf("Listen failed: %v", err)
}
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
stats := c.Stack().Stats()
want := stats.TCP.PassiveConnectionOpenings.Value() + 1
@@ -3733,18 +3867,12 @@
}
srcPort := uint16(context.TestPort)
- // Now attempt 3 handshakes, the first two will fill up the accept and the SYN-RCVD
- // queue for the endpoint.
+ // Now attempt a handshake; it will fill up the accept backlog.
executeHandshake(t, c, srcPort, false)
// Give time for the final ACK to be processed as otherwise the next handshake could
// get accepted before the previous one based on goroutine scheduling.
time.Sleep(50 * time.Millisecond)
- irs, iss := executeHandshake(t, c, srcPort+1, false)
-
- // Wait for a short while for the accepted connection to be delivered to
- // the channel before trying to send the 3rd SYN.
- time.Sleep(40 * time.Millisecond)
want := stats.TCP.ListenOverflowSynDrop.Value() + 1
@@ -3782,26 +3910,44 @@
t.Fatalf("Timed out waiting for accept")
}
}
+}
- // Now complete the next connection in SYN-RCVD state as it should
- // have dropped the final ACK to the handshake due to accept queue
- // being full.
- c.SendPacket(nil, &context.Headers{
- SrcPort: srcPort + 1,
- DstPort: context.StackPort,
- Flags: header.TCPFlagAck,
- SeqNum: irs + 1,
- AckNum: iss + 1,
- RcvWnd: 30000,
- })
+func TestEndpointBindListenAcceptState(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
- // Now check that there is one more acceptable connections.
- _, _, err = c.EP.Accept()
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
+ c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ aep, _, err := ep.Accept()
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
- _, _, err = c.EP.Accept()
+ aep, _, err = ep.Accept()
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@@ -3810,19 +3956,293 @@
t.Fatalf("Timed out waiting for accept")
}
}
+ if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ // Listening endpoint remains in listen state.
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
- // Try and accept a 3rd one this should fail.
- _, _, err = c.EP.Accept()
- if err == tcpip.ErrWouldBlock {
- // Wait for connection to be established.
- select {
- case <-ch:
- ep, _, err = c.EP.Accept()
- if err == nil {
- t.Fatalf("Accept succeeded when it should have failed got: %+v", ep)
- }
+ ep.Close()
+ // Give worker goroutines time to receive the close notification.
+ time.Sleep(1 * time.Second)
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ // Accepted endpoint remains open when the listen endpoint is closed.
+ if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
+ t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
- case <-time.After(1 * time.Second):
+}
+
+// This test verifies that the auto tuning does not grow the receive buffer if
+// the application is not reading the data actively.
+func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
+ const mtu = 1500
+ const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ stk := c.Stack()
+ // Set lower limits for auto-tuning tests. This is required because the
+ // test stops the worker which can cause packets to be dropped because
+ // the segment queue holding unprocessed packets is limited to 500.
+ const receiveBufferSize = 80 << 10 // 80KB.
+ const maxReceiveBufferSize = receiveBufferSize * 10
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ // Enable auto-tuning.
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+ // Change the expected window scale to match the value needed for the
+ // maximum buffer size defined above.
+ c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
+
+ rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+
+ // NOTE: The timestamp values in the sent packets are meaningless to the
+ // peer so we just increment the timestamp value by 1 every batch as we
+ // are not really using them for anything. Send a single byte to verify
+ // the advertised window.
+ tsVal := rawEP.TSVal + 1
+
+ // Introduce a 25ms latency by delaying the first byte.
+ latency := 25 * time.Millisecond
+ time.Sleep(latency)
+ rawEP.SendPacketWithTS([]byte{1}, tsVal)
+
+ // Verify that the ACK has the expected window.
+ wantRcvWnd := receiveBufferSize
+ wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale))
+ rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1))
+ time.Sleep(25 * time.Millisecond)
+
+ // Allocate a large enough payload for the test.
+ b := make([]byte, int(receiveBufferSize)*2)
+ offset := 0
+ payloadSize := receiveBufferSize - 1
+ worker := (c.EP).(interface {
+ StopWork()
+ ResumeWork()
+ })
+ tsVal++
+
+ // Stop the worker goroutine.
+ worker.StopWork()
+ start := offset
+ end := offset + payloadSize
+ packetsSent := 0
+ for ; start < end; start += mss {
+ rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+ packetsSent++
+ }
+ // Resume the worker so that it only sees the packets once all of them
+ // are waiting to be read.
+ worker.ResumeWork()
+
+ // Since we read no bytes the window should go to zero until the
+ // application reads some of the data.
+ // Discard all intermediate acks except the last one.
+ if packetsSent > 100 {
+ for i := 0; i < (packetsSent / 100); i++ {
+ _ = c.GetPacket()
}
}
+ rawEP.VerifyACKRcvWnd(0)
+
+ time.Sleep(25 * time.Millisecond)
+ // Verify that sending more data when window is closed is dropped and
+ // not acked.
+ rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+
+ // Verify that the stack replies with a zero-window ACK that still acks
+ // only up to the start of the last packet, i.e. that packet was dropped.
+ p := c.GetPacket()
+ checker.IPv4(t, p, checker.TCP(
+ checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
+ checker.Window(0),
+ ))
+
+ // Now read all the data from the endpoint so that the stack can reopen
+ // the receive window.
+ for {
+ _, _, err := c.EP.Read(nil)
+ if err == tcpip.ErrWouldBlock {
+ break
+ }
+ }
+
+ // Verify that we receive a non-zero window update ACK. When running
+ // under the thread sanitizer the stack may send more than one ACK
+ // here, so only require a non-zero window no larger than wantRcvWnd.
+ p = c.GetPacket()
+ checker.IPv4(t, p, checker.TCP(
+ checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
+ func(t *testing.T, h header.Transport) {
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+ if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) {
+ t.Errorf("expected a non-zero window: got %d, want <= %d", w, wantRcvWnd)
+ }
+ },
+ ))
+}
+
+// This test verifies that auto tuning grows the receive buffer when the
+// application is actively reading the data.
+func TestReceiveBufferAutoTuning(t *testing.T) {
+ const mtu = 1500
+ const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ // Enable Auto-tuning.
+ stk := c.Stack()
+ // Set lower limits for auto-tuning tests. This is required because the
+ // test stops the worker which can cause packets to be dropped because
+ // the segment queue holding unprocessed packets is limited to 500.
+ const receiveBufferSize = 80 << 10 // 80KB.
+ const maxReceiveBufferSize = receiveBufferSize * 10
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ // Enable auto-tuning.
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+ // Change the expected window scale to match the value needed for the
+ // maximum buffer size used by stack.
+ c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
+
+ rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+
+ wantRcvWnd := receiveBufferSize
+ scaleRcvWnd := func(rcvWnd int) uint16 {
+ return uint16(rcvWnd >> uint16(c.WindowScale))
+ }
+ // Allocate a large array to send to the endpoint.
+ b := make([]byte, receiveBufferSize*48)
+
+ // In every iteration we will send double the number of bytes sent in
+ // the previous iteration and read the same from the app. The received
+ // window should grow by at least 2x of bytes read by the app in every
+ // RTT.
+ offset := 0
+ payloadSize := receiveBufferSize / 8
+ worker := (c.EP).(interface {
+ StopWork()
+ ResumeWork()
+ })
+ tsVal := rawEP.TSVal
+ // We are going to do our own computation of what the moderated receive
+ // buffer should be based on sent/copied data per RTT and verify that
+ // the advertised window by the stack matches our calculations.
+ prevCopied := 0
+ done := false
+ latency := 1 * time.Millisecond
+ for i := 0; !done; i++ {
+ tsVal++
+
+ // Stop the worker goroutine.
+ worker.StopWork()
+ start := offset
+ end := offset + payloadSize
+ totalSent := 0
+ packetsSent := 0
+ for ; start < end; start += mss {
+ rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+ totalSent += mss
+ packetsSent++
+ }
+ // Resume it so that it only sees the packets once all of them
+ // are waiting to be read.
+ worker.ResumeWork()
+
+ // Give 1ms for the worker to process the packets.
+ time.Sleep(1 * time.Millisecond)
+
+ // Verify that the advertised window on the ACK is reduced by
+ // the total bytes sent.
+ expectedWnd := wantRcvWnd - totalSent
+ if packetsSent > 100 {
+ for i := 0; i < (packetsSent / 100); i++ {
+ _ = c.GetPacket()
+ }
+ }
+ rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd))
+
+ // Now read all the data from the endpoint and invoke the
+ // moderation API to allow for receive buffer auto-tuning
+ // to happen before we measure the new window.
+ totalCopied := 0
+ for {
+ b, _, err := c.EP.Read(nil)
+ if err == tcpip.ErrWouldBlock {
+ break
+ }
+ totalCopied += len(b)
+ }
+
+ // Invoke the moderation API. This is required for auto-tuning
+ // to happen. This method is normally expected to be invoked
+ // from a higher layer than tcpip.Endpoint. So we simulate
+ // copying to user-space by invoking it explicitly here.
+ c.EP.ModerateRecvBuf(totalCopied)
+
+ // Now send a keep-alive packet to trigger an ACK so that we can
+ // measure the new window.
+ rawEP.NextSeqNum--
+ rawEP.SendPacketWithTS(nil, tsVal)
+ rawEP.NextSeqNum++
+
+ if i == 0 {
+ // In the first iteration the receiver based RTT is not
+ // yet known as a result the moderation code should not
+ // increase the advertised window.
+ rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd))
+ prevCopied = totalCopied
+ } else {
+ rttCopied := totalCopied
+ if i == 1 {
+ // The moderation code accumulates copied bytes till
+ // RTT is established. So add in the bytes sent in
+ // the first iteration to the total bytes for this
+ // RTT.
+ rttCopied += prevCopied
+ // Now reset it to the initial value used by the
+ // auto tuning logic.
+ prevCopied = tcp.InitialCwnd * mss * 2
+ }
+ newWnd := rttCopied<<1 + 16*mss
+ grow := (newWnd * (rttCopied - prevCopied)) / prevCopied
+ newWnd += (grow << 1)
+ if newWnd > maxReceiveBufferSize {
+ newWnd = maxReceiveBufferSize
+ done = true
+ }
+ rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd))
+ wantRcvWnd = newWnd
+ prevCopied = rttCopied
+ // Increase the latency after first two iterations to
+ // establish a low RTT value in the receiver since it
+ // only tracks the lowest value. This ensures that when
+ // ModerateRcvBuf is called the elapsed time is always >
+ // rtt. Without this the test is flaky due to delays due
+ // to scheduling/wakeup etc.
+ latency += 50 * time.Millisecond
+ }
+ time.Sleep(latency)
+ offset += payloadSize
+ payloadSize *= 2
+ }
}
diff --git a/tcpip/transport/tcp/tcp_timestamp_test.go b/tcpip/transport/tcp/tcp_timestamp_test.go
index 5688c57..9a3e638 100644
--- a/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -182,7 +182,7 @@
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+ {false, 5, 0x8000}, // 0x8000 = DefaultReceiveBufferSize (1MB) >> 5.
}
for _, tc := range testCases {
timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
@@ -239,7 +239,7 @@
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
- {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+ {false, 5, 0x8000}, // 0x8000 = DefaultReceiveBufferSize (1MB) >> 5.
}
for _, tc := range testCases {
timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
diff --git a/tcpip/transport/tcp/testing/context/context.go b/tcpip/transport/tcp/testing/context/context.go
index 426026e..7dc5dce 100644
--- a/tcpip/transport/tcp/testing/context/context.go
+++ b/tcpip/transport/tcp/testing/context/context.go
@@ -72,12 +72,6 @@
testInitialSequenceNumber = 789
)
-// defaultWindowScale value specified here depends on the tcp.DefaultBufferSize
-// constant defined in the tcp/endpoint.go because the tcp.DefaultBufferSize is
-// used in tcp.newHandshake to determine the window scale to use when sending a
-// SYN/SYN-ACK.
-var defaultWindowScale = tcp.FindWndScale(tcp.DefaultBufferSize)
-
// Headers is used to represent the TCP header fields when building a
// new packet.
type Headers struct {
@@ -134,6 +128,10 @@
// TimeStampEnabled is true if ep is connected with the timestamp option
// enabled.
TimeStampEnabled bool
+
+ // WindowScale is the expected window scale in SYN packets sent by
+ // the stack.
+ WindowScale uint8
}
// New allocates and initializes a test context containing a new
@@ -142,11 +140,11 @@
s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
// Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
@@ -184,9 +182,10 @@
})
return &Context{
- t: t,
- s: s,
- linkEP: linkEP,
+ t: t,
+ s: s,
+ linkEP: linkEP,
+ WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
}
}
@@ -520,32 +519,21 @@
c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
+// Connect performs the 3-way handshake for c.EP with the provided Initial
+// Sequence Number (iss) and receive window (rcvWnd) and any options if
+// specified.
//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
+// It does not create c.EP or change its receive buffer; use
+// CreateConnectedWithRawOptions for that.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
- // Create TCP endpoint.
- var err *tcpip.Error
- c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- if epRcvBuf != nil {
- if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
- }
-
+//
+// PreCondition: c.EP must already be created.
+func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
// Start connection attempt.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
defer c.WQ.EventUnregister(&waitEntry)
- err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
- if err != tcpip.ErrConnectStarted {
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted {
c.t.Fatalf("Unexpected return value from Connect: %v", err)
}
@@ -557,13 +545,16 @@
checker.TCPFlags(header.TCPFlagSyn),
),
)
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
- tcp := header.TCP(header.IPv4(b).Payload())
- c.IRS = seqnum.Value(tcp.SequenceNumber())
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
c.SendPacket(nil, &Headers{
- SrcPort: tcp.DestinationPort(),
- DstPort: tcp.SourcePort(),
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
Flags: header.TCPFlagSyn | header.TCPFlagAck,
SeqNum: iss,
AckNum: c.IRS.Add(1),
@@ -584,15 +575,38 @@
// Wait for connection to be established.
select {
case <-notifyCh:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
c.t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
c.t.Fatalf("Timed out waiting for connection")
}
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
- c.Port = tcp.SourcePort()
+ c.Port = tcpHdr.SourcePort()
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// If epRcvBuf is non-nil, the endpoint's receive buffer is set to that
+// value before the connection is started.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if epRcvBuf != nil {
+ if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ }
+ c.Connect(iss, rcvWnd, options)
}
// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
@@ -657,6 +671,21 @@
r.RecentTS = opts.TSVal
}
+// VerifyACKRcvWnd reads the next packet sent by the stack and verifies that it
+// is an ACK matching r's port/seq/ack numbers with a window field of rcvWnd.
+func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) {
+ ackPacket := r.C.GetPacket()
+ checker.IPv4(r.C.t, ackPacket,
+ checker.TCP(
+ checker.DstPort(r.SrcPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(r.AckNum)),
+ checker.AckNum(uint32(r.NextSeqNum)),
+ checker.Window(rcvWnd),
+ ),
+ )
+}
+
// VerifyACKNoSACK verifies that the ACK does not contain a SACK block.
func (r *RawEndpoint) VerifyACKNoSACK() {
r.VerifyACKHasSACK(nil)
@@ -690,6 +719,9 @@
if err != nil {
c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
}
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
// Start connection attempt.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -714,17 +746,24 @@
checker.TCPSynOptions(header.TCPSynOptions{
MSS: mss,
TS: true,
- WS: defaultWindowScale,
+ WS: int(c.WindowScale),
SACKPermitted: c.SACKEnabled(),
}),
),
)
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
tcpSeg := header.TCP(header.IPv4(b).Payload())
synOptions := header.ParseSynOptions(tcpSeg.Options(), false)
// Build options w/ tsVal to be sent in the SYN-ACK.
synAckOptions := make([]byte, header.TCPOptionsMaximumSize)
offset := 0
+ if wantOptions.WS != -1 {
+ offset += header.EncodeWSOption(wantOptions.WS, synAckOptions[offset:])
+ }
if wantOptions.TS {
offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:])
}
@@ -782,6 +821,9 @@
case <-time.After(1 * time.Second):
c.t.Fatalf("Timed out waiting for connection")
}
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
+ c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
// Store the source port in use by the endpoint.
c.Port = tcpSeg.SourcePort()
@@ -821,10 +863,16 @@
if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
+ c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
if err := ep.Listen(10); err != nil {
c.t.Fatalf("Listen failed: %v", err)
}
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
+ c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
@@ -847,6 +895,10 @@
c.t.Fatalf("Timed out waiting for accept")
}
}
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
+ c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+
return rep
}
diff --git a/tcpip/transport/udp/endpoint.go b/tcpip/transport/udp/endpoint.go
index 34d5c87..059c68d 100644
--- a/tcpip/transport/udp/endpoint.go
+++ b/tcpip/transport/udp/endpoint.go
@@ -169,6 +169,9 @@
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
+// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. It is a no-op for UDP.
+func (e *endpoint) ModerateRecvBuf(copied int) {}
+
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
@@ -1000,3 +1003,9 @@
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}
+
+// State implements socket.Socket.State. UDP endpoints currently always report 0.
+func (e *endpoint) State() uint32 {
+ // TODO(b/112063468): Translate internal state to values returned by Linux.
+ return 0
+}