blob: 5e9237de99188310f1d491c22749a1964f54b400 [file] [log] [blame]
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack_test
import (
	"math"
	"math/rand"
	"strconv"
	"testing"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
	"gvisor.dev/gvisor/pkg/waiter"
)
// Addresses used by the tests in this file, written as raw network-order
// bytes.
const (
	// stackV6Addr is the 16-byte IPv6 address assigned to the stack's NICs and
	// used as the destination of injected packets.
	stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"

	// testV6Addr is the 16-byte IPv6 address used as the source of injected
	// packets.
	testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"

	// stackAddr is the 4-byte IPv4 address assigned to the stack's NICs.
	stackAddr = "\x0a\x00\x00\x01"
)

// Ports used by the tests in this file.
const (
	// stackPort is the port the receiving endpoints bind to.
	stackPort = 1234

	// testPort is the base source port of injected packets.
	testPort = 4096
)
// testContext bundles the state shared by the test helpers in this file: the
// stack under test, the link endpoints of its NICs, and an optional transport
// endpoint owned by the context.
type testContext struct {
// t is the test used by the helpers to report fatal failures.
t *testing.T
// linkEps maps each NIC ID to the channel endpoint backing that NIC; packets
// are injected into the stack through these endpoints.
linkEps map[tcpip.NICID]*channel.Endpoint
// s is the network stack under test.
s *stack.Stack
// ep is the endpoint created by createV6Endpoint and closed by cleanup. It is
// nil until createV6Endpoint has been called.
ep tcpip.Endpoint
// wq is the waiter queue handed to the stack when ep is created.
wq waiter.Queue
}
// cleanup releases the endpoint owned by the context. It is a no-op when no
// endpoint has been created.
func (c *testContext) cleanup() {
	if c.ep == nil {
		return
	}
	c.ep.Close()
}
// createV6Endpoint creates a UDP endpoint over IPv6 on the context's stack,
// stores it in c.ep, and sets the endpoint's V6Only option to v6only.
//
// Any failure aborts the test through c.t.Fatalf.
func (c *testContext) createV6Endpoint(v6only bool) {
	// Attribute failures to the calling test rather than to this helper.
	c.t.Helper()

	var err *tcpip.Error
	c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
	if err != nil {
		c.t.Fatalf("NewEndpoint failed: %v", err)
	}
	if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
		c.t.Fatalf("SetSockOptBool(V6OnlyOption, %t) failed: %v", v6only, err)
	}
}
// newDualTestContextMultiNIC builds a stack with IPv4, IPv6 and UDP enabled,
// creates one channel-backed NIC for every ID in linkEpIDs, and returns a
// testContext wrapping the stack and the NICs' link endpoints.
//
// Each NIC is created with the given mtu and is assigned both stackAddr (IPv4)
// and stackV6Addr (IPv6). The default IPv4 and IPv6 routes always point at
// NIC 1, regardless of which IDs appear in linkEpIDs. Any setup failure aborts
// the test through t.Fatalf.
func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
	// Attribute setup failures to the calling test rather than to this helper.
	t.Helper()

	s := stack.New(stack.Options{
		NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
		TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})

	linkEps := make(map[tcpip.NICID]*channel.Endpoint, len(linkEpIDs))
	for _, linkEpID := range linkEpIDs {
		channelEp := channel.New(256, mtu, "")
		if err := s.CreateNIC(linkEpID, channelEp); err != nil {
			t.Fatalf("CreateNIC(%d) failed: %v", linkEpID, err)
		}
		linkEps[linkEpID] = channelEp

		if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil {
			t.Fatalf("AddAddress IPv4 on NIC %d failed: %v", linkEpID, err)
		}

		if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
			t.Fatalf("AddAddress IPv6 on NIC %d failed: %v", linkEpID, err)
		}
	}

	// Both default routes are hard-wired to NIC 1.
	s.SetRouteTable([]tcpip.Route{
		{
			Destination: header.IPv4EmptySubnet,
			NIC:         1,
		},
		{
			Destination: header.IPv6EmptySubnet,
			NIC:         1,
		},
	})

	return &testContext{
		t:       t,
		s:       s,
		linkEps: linkEps,
	}
}
// headers carries the UDP source and destination ports written into a packet
// built by sendV6Packet.
type headers struct {
	srcPort, dstPort uint16
}
// newPayload returns a slice of pseudo-random bytes whose length is drawn
// from the range [30, 130).
func newPayload() []byte {
	size := 30 + rand.Intn(100)
	payload := make([]byte, 0, size)
	for len(payload) < size {
		payload = append(payload, byte(rand.Intn(256)))
	}
	return payload
}
// sendV6Packet wraps payload in UDP and IPv6 headers and injects the resulting
// packet as inbound traffic on the link endpoint of NIC linkEpID.
//
// The packet travels from testV6Addr to stackV6Addr; the UDP ports are taken
// from h.
func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
	// udpLen is the size of the UDP datagram: UDP header plus payload.
	udpLen := header.UDPMinimumSize + len(payload)

	// One contiguous view holds the IPv6 header, the UDP header and the payload,
	// in that order.
	pkt := buffer.NewView(header.IPv6MinimumSize + udpLen)
	copy(pkt[header.IPv6MinimumSize+header.UDPMinimumSize:], payload)

	// Fill in the IPv6 header.
	ipHdr := header.IPv6(pkt)
	ipHdr.Encode(&header.IPv6Fields{
		PayloadLength: uint16(udpLen),
		NextHeader:    uint8(udp.ProtocolNumber),
		HopLimit:      65,
		SrcAddr:       testV6Addr,
		DstAddr:       stackV6Addr,
	})

	// Fill in the UDP header.
	udpHdr := header.UDP(pkt[header.IPv6MinimumSize:])
	udpHdr.Encode(&header.UDPFields{
		SrcPort: h.srcPort,
		DstPort: h.dstPort,
		Length:  uint16(udpLen),
	})

	// The checksum starts from the pseudo-header, folds in the payload, and is
	// completed over the UDP header itself before being stored complemented.
	pseudo := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(udpLen))
	partial := header.Checksum(payload, pseudo)
	udpHdr.SetChecksum(^udpHdr.CalculateChecksum(partial))

	// Hand the packet to the stack as if it had arrived on the NIC.
	linkEp := c.linkEps[linkEpID]
	linkEp.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
		Data: pkt.ToVectorisedView(),
	})
}
// TestTransportDemuxerRegister checks the result of registering a transport
// endpoint on a stack that only has IPv4 enabled: registering for IPv6 must
// yield tcpip.ErrUnknownProtocol, while registering for IPv4 must succeed.
func TestTransportDemuxerRegister(t *testing.T) {
	type registerCase struct {
		name  string
		proto tcpip.NetworkProtocolNumber
		want  *tcpip.Error
	}

	cases := []registerCase{
		{name: "failure", proto: ipv6.ProtocolNumber, want: tcpip.ErrUnknownProtocol},
		{name: "success", proto: ipv4.ProtocolNumber, want: nil},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The stack deliberately knows about IPv4 only.
			opts := stack.Options{
				NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
				TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
			}
			s := stack.New(opts)

			netProtos := []tcpip.NetworkProtocolNumber{tc.proto}
			got := s.RegisterTransportEndpoint(0, netProtos, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0)
			if want := tc.want; got != want {
				t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
			}
		})
	}
}
// TestDistribution injects varied packets on input devices and checks that
// the distribution of packets received matches expectations.
//
// For every table entry and every NIC listed in its wantedDistributions, a
// fresh stack is created, one UDP/IPv6 endpoint per entry in endpoints is
// bound to stackV6Addr:stackPort with the given socket options, and npackets
// packets are injected on that NIC from nports distinct source ports. The test
// then verifies that packets from one source port always reach the same
// endpoint and that each endpoint's share of the traffic is close to the
// wanted ratio.
func TestDistribution(t *testing.T) {
	type endpointSockopts struct {
		// reuse is the value passed to tcpip.ReusePortOption.
		reuse int
		// bindToDevice is the NIC ID passed to tcpip.BindToDeviceOption; the
		// table below uses 0 for endpoints that are not bound to a device.
		bindToDevice tcpip.NICID
	}
	for _, test := range []struct {
		name string
		// endpoints holds the socket options of the endpoints that will receive
		// the injected packets.
		endpoints []endpointSockopts
		// wantedDistributions is the wanted ratio of packets received on each
		// endpoint for each NIC on which packets are injected.
		wantedDistributions map[tcpip.NICID][]float64
	}{
		{
			"BindPortReuse",
			// 5 endpoints that all have reuse set.
			[]endpointSockopts{
				{1, 0},
				{1, 0},
				{1, 0},
				{1, 0},
				{1, 0},
			},
			map[tcpip.NICID][]float64{
				// Packets injected on NIC 1 get distributed evenly.
				1: {0.2, 0.2, 0.2, 0.2, 0.2},
			},
		},
		{
			"BindToDevice",
			// 3 endpoints with various bindings.
			[]endpointSockopts{
				{0, 1},
				{0, 2},
				{0, 3},
			},
			map[tcpip.NICID][]float64{
				// Packets injected on NIC 1 go only to the endpoint bound to NIC 1.
				1: {1, 0, 0},
				// Packets injected on NIC 2 go only to the endpoint bound to NIC 2.
				2: {0, 1, 0},
				// Packets injected on NIC 3 go only to the endpoint bound to NIC 3.
				3: {0, 0, 1},
			},
		},
		{
			"ReuseAndBindToDevice",
			// 6 endpoints with various bindings.
			[]endpointSockopts{
				{1, 1},
				{1, 1},
				{1, 2},
				{1, 2},
				{1, 2},
				{1, 0},
			},
			map[tcpip.NICID][]float64{
				// Packets injected on NIC 1 get distributed among the endpoints
				// bound to NIC 1.
				1: {0.5, 0.5, 0, 0, 0, 0},
				// Packets injected on NIC 2 get distributed among the endpoints
				// bound to NIC 2; the unbound endpoint is expected to get none.
				2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
				// Packets injected on NIC 1000, to which no endpoint is bound, go
				// only to the unbound endpoint.
				1000: {0, 0, 0, 0, 0, 1},
			},
		},
	} {
		t.Run(test.name, func(t *testing.T) {
			for device, wantedDistribution := range test.wantedDistributions {
				// A NIC ID is numeric, so string(device) would produce the single
				// rune with that code point; format it as a decimal number to get
				// a readable subtest name.
				t.Run(strconv.Itoa(int(device)), func(t *testing.T) {
					// Validate the table before doing any work: wantedDistribution is
					// indexed by endpoint position further down.
					if got, want := len(test.endpoints), len(wantedDistribution); got != want {
						t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
					}

					// Create every NIC mentioned by this table entry, not only the one
					// packets are injected on.
					var devices []tcpip.NICID
					for d := range test.wantedDistributions {
						devices = append(devices, d)
					}
					c := newDualTestContextMultiNIC(t, defaultMTU, devices)
					defer c.cleanup()

					c.createV6Endpoint(false)

					// eps maps each receiving endpoint to its index in test.endpoints.
					eps := make(map[tcpip.Endpoint]int)

					// pollChannel receives an endpoint each time that endpoint's
					// waiter channel signals.
					pollChannel := make(chan tcpip.Endpoint)
					for i, endpoint := range test.endpoints {
						// Try to receive the data.
						//
						// The defers below are intentionally inside the loop: the
						// endpoints must stay alive until the subtest returns.
						wq := waiter.Queue{}
						we, ch := waiter.NewChannelEntry(nil)
						wq.EventRegister(&we, waiter.EventIn)
						defer wq.EventUnregister(&we)
						defer close(ch)

						ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
						if err != nil {
							t.Fatalf("NewEndpoint failed: %v", err)
						}
						eps[ep] = i

						// Forward every notification for this endpoint to pollChannel;
						// the goroutine ends when ch is closed.
						go func(ep tcpip.Endpoint) {
							for range ch {
								pollChannel <- ep
							}
						}(ep)

						defer ep.Close()
						reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
						if err := ep.SetSockOpt(reusePortOption); err != nil {
							t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
						}
						bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
						if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
							t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
						}
						if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
							t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
						}
					}

					npackets := 100000
					nports := 10000
					// ports remembers which endpoint first received traffic from each
					// source-port offset; stats counts packets per endpoint.
					ports := make(map[uint16]tcpip.Endpoint)
					stats := make(map[tcpip.Endpoint]int)
					for i := 0; i < npackets; i++ {
						// Send a packet.
						port := uint16(i % nports)
						payload := newPayload()
						c.sendV6Packet(payload,
							&headers{
								srcPort: testPort + port,
								dstPort: stackPort},
							device)

						// Wait for whichever endpoint was notified and drain the packet
						// from it.
						var addr tcpip.FullAddress
						ep := <-pollChannel
						_, _, err := ep.Read(&addr)
						if err != nil {
							t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
						}
						stats[ep]++
						if i < nports {
							// First packet from this source port: record its endpoint.
							ports[port] = ep
						} else {
							// Check that all packets from one client are handled by the same
							// socket.
							if want, got := ports[port], ep; want != got {
								t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
							}
						}
					}

					// Check that a packet distribution is as expected.
					for ep, i := range eps {
						wantedRatio := wantedDistribution[i]
						wantedRecv := wantedRatio * float64(npackets)
						actualRecv := stats[ep]
						actualRatio := float64(stats[ep]) / float64(npackets)
						// Each endpoint's observed share may differ from the wanted share
						// by at most 0.05, i.e. 5 percentage points.
						if math.Abs(actualRatio-wantedRatio) > 0.05 {
							t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
						}
					}
				})
			}
		})
	}
}