blob: a3f1647b7048a1153e48a2b90456174528e3b322 [file] [log] [blame]
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//go:build !build_with_native_toolchain
package netstack
import (
"context"
"fmt"
"syscall/zx"
"testing"
"time"
"go.fuchsia.dev/fuchsia/src/connectivity/network/netstack/udp_serde"
"go.fuchsia.dev/fuchsia/src/connectivity/network/netstack/util"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/waiter"
)
// TestDatagramSocketWithBlockingEndpoint drives the datagram socket send path
// against a link endpoint that has been put into blocking mode. It fills the
// zircon socket until writes return zx.ErrShouldWait, checks how many payloads
// the link endpoint reports as enqueued, and then verifies Close() both while
// the link endpoint is still blocked and after it has been unblocked.
func TestDatagramSocketWithBlockingEndpoint(t *testing.T) {
	for _, test := range []struct {
		name              string
		closeWhileBlocked bool
	}{
		{name: "closeWhileBlocked", closeWhileBlocked: true},
		{name: "closeAfterUnblocked", closeWhileBlocked: false},
	} {
		t.Run(test.name, func(t *testing.T) {
			ns, _ := newNetstack(t, netstackTestOptions{})

			// Install an interface backed by a link endpoint in blocking mode.
			linkEp := &sentinelEndpoint{}
			linkEp.SetBlocking(true)
			ifState := installAndValidateIface(t, ns, func(t *testing.T, ns *Netstack, name string) *ifState {
				return addLinkEndpoint(t, ns, name, linkEp)
			})
			t.Cleanup(ifState.RemoveByUser)

			addr := tcpip.ProtocolAddress{
				Protocol: ipv4.ProtocolNumber,
				AddressWithPrefix: tcpip.AddressWithPrefix{
					Address:   util.Parse("240.240.240.240"),
					PrefixLen: 24,
				},
			}
			addAddressAndRoute(t, ns, ifState, addr)

			wq := new(waiter.Queue)
			ep, newEpErr := ns.stack.NewEndpoint(header.UDPProtocolNumber, ipv4.ProtocolNumber, wq)
			if newEpErr != nil {
				t.Fatalf("NewEndpoint(header.UDPProtocolNumber, ipv4.ProtocolNumber, _) = %s", newEpErr)
			}

			s, err := newDatagramSocketImpl(ns, ipv4.ProtocolNumber, ep, wq)
			if err != nil {
				t.Fatalf("got newDatagramSocketImpl(_, %d, _, _): %s", ipv4.ProtocolNumber, err)
			}

			// Increment refcount and provide a cancel callback so the endpoint can be
			// closed below.
			s.endpoint.incRef()
			ctx, cancel := context.WithCancel(context.Background())
			s.cancel = cancel

			desc, err := s.Describe(context.Background())
			if err != nil {
				t.Fatalf("got s.Describe(): %s", err)
			}

			// Each message written to the zircon socket is a serialized metadata
			// prelude of TxMetaBufSize bytes followed by the payload.
			data := []byte{0, 1, 2, 3, 4}
			preludeSize := desc.TxMetaBufSize
			buf := make([]byte, len(data)+int(preludeSize))
			toAddr := &tcpip.FullAddress{
				Addr: addr.AddressWithPrefix.Address,
				Port: 42,
			}
			if err := udp_serde.SerializeSendMsgMeta(
				ipv4.ProtocolNumber,
				*toAddr,
				tcpip.SendableControlMessages{},
				buf[:preludeSize],
			); err != nil {
				t.Fatalf("SerializeSendMsgMeta(%d, %#v, _, _): %s", ipv4.ProtocolNumber, toAddr, err)
			}
			copy(buf[preludeSize:], data)

			// writeUntilBlocked writes buf to the zircon socket repeatedly until
			// a write fails with zx.ErrShouldWait, and returns the number of
			// successful writes. Any other error, or a short write, is fatal.
			writeUntilBlocked := func() uint {
				var written uint
				for {
					n, err := desc.Socket.Write(buf, 0)
					if err != nil {
						if err, ok := err.(*zx.Error); ok && err.Status == zx.ErrShouldWait {
							return written
						}
						t.Fatalf("got zx.socket.Write(_) = (_, %s), want (_, %s)", err, zx.ErrShouldWait)
					}
					if got, want := n, len(buf); got != want {
						t.Fatalf("got zx.socket.Write(_) = (%d, nil), want (%d, nil)", got, want)
					}
					written++
				}
			}

			// Size the send buffer so that exactly numPayloadsFittingInSendBuf
			// payloads (including UDP and maximum IPv4 header overhead) fit.
			const numPayloadsFittingInSendBuf = 10
			bytesPerPayload := len(data) + header.UDPMinimumSize + header.IPv4MaximumHeaderSize
			ep.SocketOptions().SetSendBufferSize(int64(numPayloadsFittingInSendBuf*bytesPerPayload), false)

			// expectLinkEpEnqueued checks that the link endpoint's cumulative
			// enqueued count grew by exactly `expected` since the last
			// successful check, and records the new total.
			var enqueuedSoFar uint
			expectLinkEpEnqueued := func(expected uint) error {
				if got, want := linkEp.Enqueued(), expected+enqueuedSoFar; got != want {
					return fmt.Errorf("got linkEp.Enqueued() = %d, want %d", got, want)
				}
				enqueuedSoFar += expected
				return nil
			}

			// Expect that the sender becomes blocked once the link endpoint has enqueued
			// enough payloads to exhaust the send buffer.
			inflightPayloads := func() uint {
				enqueued := linkEp.WaitFor(numPayloadsFittingInSendBuf)
				wrote := writeUntilBlocked()
				if wrote < numPayloadsFittingInSendBuf {
					t.Fatalf("wrote %d payloads, want at least %d", wrote, numPayloadsFittingInSendBuf)
				}
				<-enqueued
				if err := expectLinkEpEnqueued(numPayloadsFittingInSendBuf); err != nil {
					t.Fatal(err)
				}
				// Whatever was written but not yet enqueued is still in flight.
				return wrote - numPayloadsFittingInSendBuf
			}()

			// Expect draining N packets lets N more be processed.
			{
				drained, processed := linkEp.Drain()
				if got, want := drained, uint(numPayloadsFittingInSendBuf); got != want {
					t.Fatalf("got linkEp.Drain() = %d, want %d", got, want)
				}
				if inflightPayloads < drained {
					t.Fatalf("wrote %d payloads, want at least %d", inflightPayloads, drained)
				}
				<-processed
				if err := expectLinkEpEnqueued(drained); err != nil {
					t.Fatal(err)
				}
				inflightPayloads -= drained
			}

			// validateClose closes the socket and checks that the cancel callback
			// fires as a result of Close(), not before it.
			validateClose := func() error {
				// Expect the cancel routine is not called before the endpoint was closed.
				if err := ctx.Err(); err != nil {
					return fmt.Errorf("ctx unexpectedly closed with error: %w", err)
				}
				if _, err := s.Close(context.Background()); err != nil {
					return fmt.Errorf("s.Close(): %w", err)
				}
				// Expect the cancel routine is called when the endpoint is closed.
				<-ctx.Done()
				return nil
			}

			if test.closeWhileBlocked {
				if err := validateClose(); err != nil {
					t.Fatal(err)
				}
				// Closing the endpoint while it is blocked drops outgoing payloads
				// on the floor.
				if err := expectLinkEpEnqueued(0); err != nil {
					t.Fatal(err)
				}
				return
			}

			linkEp.SetBlocking(false)
			// Wait until the write loop becomes unblocked and begins writing again before
			// closing the socket; otherwise the notifications of the endpoint being
			// writable and the socket closing can race, and the write loop will exit before
			// enqueueing the remaining packets in the zircon socket.
			<-linkEp.WaitFor(1)
			if err := validateClose(); err != nil {
				t.Fatal(err)
			}
			// When the endpoint is unblocked, Close() should block until all
			// remaining payloads are sent.
			if err := expectLinkEpEnqueued(inflightPayloads); err != nil {
				t.Fatal(err)
			}
		})
	}
}
// newNetstackAndEndpoint creates a test Netstack and a new IPv4 endpoint for
// the given transport protocol, bound to a fresh waiter queue. It returns the
// Netstack, the clock returned by newNetstack, the waiter queue and the
// endpoint. The test is failed immediately if the endpoint cannot be created.
func newNetstackAndEndpoint(t *testing.T, transProto tcpip.TransportProtocolNumber) (*Netstack, *faketime.ManualClock, *waiter.Queue, tcpip.Endpoint) {
	t.Helper()

	ns, clock := newNetstack(t, netstackTestOptions{})
	wq := new(waiter.Queue)

	// Call t.Fatalf directly from this function rather than from a nested
	// closure: t.Helper() only marks the function that calls it, so a failure
	// raised inside an anonymous function would not be attributed to the
	// caller of this helper.
	ep, err := ns.stack.NewEndpoint(transProto, ipv4.ProtocolNumber, wq)
	if err != nil {
		t.Fatalf("NewEndpoint(%d, %d, _): %s", transProto, ipv4.ProtocolNumber, err)
	}
	return ns, clock, wq, ep
}
// addEndpoint registers ep (with its socket option stats) with the Netstack
// via onAddEndpoint and then takes an additional reference on ep.
func addEndpoint(t *testing.T, ns *Netstack, ep *endpoint, stats socketOptionStats) {
	t.Helper()

	// Register first, then bump the refcount, in that order.
	ns.onAddEndpoint(ep, stats)
	ep.incRef()
}
// verifyZirconSocketClosed reports a (non-fatal) test error for each of the
// endpoint's local and peer zircon socket handles that is still valid.
func verifyZirconSocketClosed(t *testing.T, e *endpointWithSocket) {
	t.Helper()

	if localValid := e.local.Handle().IsValid(); localValid {
		t.Error("got e.local.Handle().IsValid() = true, want = false")
	}

	if peerValid := e.peer.Handle().IsValid(); peerValid {
		t.Error("got e.peer.Handle().IsValid() = true, want = false")
	}
}
// TestCloseDatagramSocketClosesHandles checks that closing a datagram socket
// invalidates the zircon socket handles as well as both ends of the
// destination cache and control-message cache handles held in its shared
// state.
func TestCloseDatagramSocketClosesHandles(t *testing.T) {
	ns, _, wq, ep := newNetstackAndEndpoint(t, header.UDPProtocolNumber)

	s, err := newDatagramSocketImpl(ns, ipv4.ProtocolNumber, ep, wq)
	if err != nil {
		t.Fatalf("newDatagramSocketImpl(_, %d, _, _): %s", ipv4.ProtocolNumber, err)
	}
	addEndpoint(t, ns, &s.endpoint, &s.endpointWithSocket.endpoint.sockOptStats)

	// Provide a cancel callback so the endpoint can be closed below.
	s.cancel = func() {}

	if _, err := s.Close(context.Background()); err != nil {
		t.Fatalf("s.Close(): %s", err)
	}

	// Verify that the handles associated with the socket have been closed.
	verifyZirconSocketClosed(t, s.endpointWithSocket)

	// Every cache handle must be invalid after Close(); report each one that
	// is still valid, in the same order as they are listed.
	for _, check := range []struct {
		stillValid bool
		failure    string
	}{
		{
			stillValid: s.sharedState.destinationCacheMu.destinationCache.local.IsValid(),
			failure:    "got s.sharedState.destinationCacheMu.destinationCache.local.IsValid() = true, want = false",
		},
		{
			stillValid: s.sharedState.destinationCacheMu.destinationCache.peer.IsValid(),
			failure:    "got s.sharedState.destinationCacheMu.destinationCache.peer.IsValid() = true, want = false",
		},
		{
			stillValid: s.sharedState.cmsgCacheMu.cmsgCache.local.IsValid(),
			failure:    "got s.sharedState.cmsgCacheMu.cmsgCache.local.IsValid() = true, want = false",
		},
		{
			stillValid: s.sharedState.cmsgCacheMu.cmsgCache.peer.IsValid(),
			failure:    "got s.sharedState.cmsgCacheMu.cmsgCache.peer.IsValid() = true, want = false",
		},
	} {
		if check.stillValid {
			t.Error(check.failure)
		}
	}
}
// TestCloseSynchronousDatagramSocketClosesHandles checks that closing a
// synchronous datagram socket invalidates both the local and peer handles of
// its endpointWithEvent.
func TestCloseSynchronousDatagramSocketClosesHandles(t *testing.T) {
	ns, _, wq, ep := newNetstackAndEndpoint(t, header.UDPProtocolNumber)

	s, err := makeSynchronousDatagramSocket(ep, ipv4.ProtocolNumber, header.UDPProtocolNumber, wq, ns)
	if err != nil {
		t.Fatalf("makeSynchronousDatagramSocket(_, %d, %d, _, _): %s", ipv4.ProtocolNumber, header.UDPProtocolNumber, err)
	}
	addEndpoint(t, ns, &s.endpoint, &s.endpointWithEvent.endpoint.sockOptStats)

	// Provide a cancel callback so the endpoint can be closed below.
	s.cancel = func() {}

	if _, err := s.Close(context.Background()); err != nil {
		t.Fatalf("s.Close(): %s", err)
	}

	// Verify that the handles associated with the socket have been closed.
	if localValid := s.endpointWithEvent.local.IsValid(); localValid {
		t.Error("got s.endpointWithEvent.local.IsValid() = true, want = false")
	}
	if peerValid := s.endpointWithEvent.peer.IsValid(); peerValid {
		t.Error("got s.endpointWithEvent.peer.IsValid() = true, want = false")
	}
}
// newNetstackAndSreamSocket creates a test Netstack with a TCP/IPv4 endpoint
// wrapped in a stream socket implementation. The returned socket has had its
// endpoint refcount incremented and a no-op cancel callback installed so that
// callers can Close() it. The clock returned by newNetstackAndEndpoint is
// returned alongside the socket. The test is failed immediately if the
// socket-backed endpoint cannot be created.
func newNetstackAndSreamSocket(t *testing.T) (*faketime.ManualClock, *streamSocketImpl) {
	// Mark as a helper, like the other helpers in this file, so that a
	// failure below is attributed to the calling test.
	t.Helper()

	ns, clock, wq, ep := newNetstackAndEndpoint(t, header.TCPProtocolNumber)
	socketEp, err := newEndpointWithSocket(ep, wq, header.TCPProtocolNumber, ipv4.ProtocolNumber, ns, zx.SocketStream)
	if err != nil {
		t.Fatalf("newEndpointWithSocket(_, _, %d, %d, _, _): %s", header.TCPProtocolNumber, ipv4.ProtocolNumber, err)
	}
	s := makeStreamSocketImpl(socketEp)
	s.endpoint.incRef()
	// Provide a cancel callback so the endpoint can be closed below.
	s.cancel = func() {}
	return clock, &s
}
// TestCloseStreamSocketClosesHandles checks that closing a stream socket
// invalidates both of its zircon socket handles.
func TestCloseStreamSocketClosesHandles(t *testing.T) {
	_, s := newNetstackAndSreamSocket(t)

	_, err := s.Close(context.Background())
	if err != nil {
		t.Fatalf("s.Close(): %s", err)
	}

	// Verify that the handles associated with the socket have been closed.
	verifyZirconSocketClosed(t, s.endpointWithSocket)
}
// TestCloseUnblockLoopWrite checks that, after a stream socket is closed and
// the test clock is advanced by the case's closeAfter duration, the socket's
// unblockLoopWrite channel can be received from without blocking. Each case
// configures SO_LINGER and the TCP linger timeout, and optionally leaves one
// byte of unread data in the zircon socket before closing.
func TestCloseUnblockLoopWrite(t *testing.T) {
	tests := []struct {
		name string
		// lingerOpt is applied through SocketOptions().SetLinger.
		lingerOpt tcpip.LingerOption
		// tcpLingerOpt is applied through SetSockOpt.
		tcpLingerOpt tcpip.TCPLingerTimeoutOption
		// dataInZXSocket, when true, writes a single byte into the local end
		// of the zircon socket before the socket is closed.
		dataInZXSocket bool
		// closeAfter is how far the clock is advanced after Close() before
		// unblockLoopWrite is expected to be readable.
		closeAfter time.Duration
	}{
		{
			name: "SO_LINGER enabled",
			// Linger must actually be enabled for this case to exercise the
			// path its name describes; closeAfter matches the linger timeout.
			lingerOpt:    tcpip.LingerOption{Enabled: true, Timeout: time.Second},
			tcpLingerOpt: tcpip.TCPLingerTimeoutOption(time.Hour),
			closeAfter:   time.Second,
		},
		{
			name:           "SO_LINGER disabled with empty zx.Socket",
			lingerOpt:      tcpip.LingerOption{Enabled: false, Timeout: time.Hour},
			tcpLingerOpt:   tcpip.TCPLingerTimeoutOption(time.Hour),
			dataInZXSocket: false,
			closeAfter:     0,
		},
		{
			name:           "SO_LINGER disabled with non-empty zx.Socket",
			lingerOpt:      tcpip.LingerOption{Enabled: false, Timeout: time.Hour},
			tcpLingerOpt:   tcpip.TCPLingerTimeoutOption(time.Second),
			dataInZXSocket: true,
			closeAfter:     time.Second,
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			clock, s := newNetstackAndSreamSocket(t)

			s.endpoint.ep.SocketOptions().SetLinger(test.lingerOpt)
			if err := s.endpoint.ep.SetSockOpt(&test.tcpLingerOpt); err != nil {
				t.Fatalf("s.endpoint.ep.SetSockOpt(&%T): %s", test.tcpLingerOpt, err)
			}

			if test.dataInZXSocket {
				var data [1]byte
				n, err := s.endpointWithSocket.local.Write(data[:], 0 /* flags */)
				if err != nil {
					t.Fatalf("s.endpointWithSocket.local.Write(_, 0): %s", err)
				}
				if n != len(data) {
					t.Fatalf("got s.endpointWithSocket.local.Write(_, 0) = %d, want = %d", n, len(data))
				}
			}

			if _, err := s.Close(context.Background()); err != nil {
				t.Fatalf("s.Close(): %s", err)
			}

			clock.Advance(test.closeAfter)

			// Non-blocking receive: the channel must already be readable once
			// the clock has been advanced.
			select {
			case <-s.unblockLoopWrite:
			default:
				t.Error("expected s.unblockLoopWrite to be readable")
			}
		})
	}
}