blob: 3bc29ff4cb0e9ac4fb5b1d2d3226a29705a97553 [file] [log] [blame]
// Copyright 2016 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package udp_test
import (
"bytes"
"math/rand"
"testing"
"time"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/checker"
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/link/channel"
"github.com/google/netstack/tcpip/link/sniffer"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/network/ipv6"
"github.com/google/netstack/tcpip/stack"
"github.com/google/netstack/tcpip/transport/udp"
"github.com/google/netstack/waiter"
)
const (
// testLinkAddr is the link-layer address registered for testV6Addr and
// testV4MappedAddr in newDualTestContext.
testLinkAddr = "\x00\x00\x00\x00\x00\x02"
// stackV6Addr (::1) is the IPv6 address of the stack under test and
// testV6Addr (::2) is the IPv6 address of the remote test peer.
stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
// IPv4-mapped IPv6 forms (::ffff:a.b.c.d) of stackAddr, testAddr and
// multicastAddr, used to drive IPv4 traffic through IPv6 endpoints.
stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr
// V4MappedWildcardAddr is the IPv4-mapped form of 0.0.0.0 (::ffff:0.0.0.0).
V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
// stackAddr:stackPort (10.0.0.1:1234) is the local IPv4 address and UDP
// port of the stack under test.
stackAddr = "\x0a\x00\x00\x01"
stackPort = 1234
// testAddr:testPort (10.0.0.2:4096) is the IPv4 address and UDP port of the
// remote test peer.
testAddr = "\x0a\x00\x00\x02"
testPort = 4096
// multicastAddr:multicastPort (232.43.211.234:1234) is the IPv4 multicast
// destination used by TestMulticastTTL.
multicastAddr = "\xe8\x2b\xd3\xea"
multicastPort = 1234
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
// of loopback interfaces on Linux systems.
defaultMTU = 65536
)
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
s *stack.Stack
ep tcpip.Endpoint
wq waiter.Queue
}
// headers carries the UDP source and destination ports that sendPacket and
// sendV6Packet write into injected datagrams.
type headers struct {
	srcPort, dstPort uint16
}
// newDualTestContext returns a testContext whose stack has IPv4, IPv6 and UDP
// enabled on a single channel-backed NIC (ID 1) with the given MTU. The NIC
// is assigned stackAddr and stackV6Addr, all-zero (default) routes for both
// families point at it, and testV6Addr and testV4MappedAddr are pre-mapped to
// testLinkAddr. When tests run with -v, the link endpoint is wrapped with
// sniffer.New (presumably to log traffic).
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
	const nicID = 1

	s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName})

	linkID, linkEP := channel.New(256, mtu, "")
	if testing.Verbose() {
		linkID = sniffer.New(linkID)
	}
	if err := s.CreateNIC(nicID, linkID); err != nil {
		t.Fatalf("CreateNIC failed: %v", err)
	}

	for _, a := range []struct {
		proto tcpip.NetworkProtocolNumber
		addr  tcpip.Address
	}{
		{ipv4.ProtocolNumber, stackAddr},
		{ipv6.ProtocolNumber, stackV6Addr},
	} {
		if err := s.AddAddress(nicID, a.proto, a.addr); err != nil {
			t.Fatalf("AddAddress failed: %v", err)
		}
	}

	// An all-zero destination with an all-zero mask matches every address
	// of that family, so every packet is routed out of the single NIC.
	const (
		v4Any = "\x00\x00\x00\x00"
		v6Any = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
	)
	s.SetRouteTable([]tcpip.Route{
		{Destination: v4Any, Mask: v4Any, NIC: nicID},
		{Destination: v6Any, Mask: v6Any, NIC: nicID},
	})

	// Pre-populate the link address cache for the remote peer.
	s.AddLinkAddress(nicID, testV6Addr, testLinkAddr)
	s.AddLinkAddress(nicID, testV4MappedAddr, testLinkAddr)

	return &testContext{
		t:      t,
		s:      s,
		linkEP: linkEP,
	}
}
// cleanup closes the endpoint under test, if the test created one.
func (c *testContext) cleanup() {
	if c.ep == nil {
		return
	}
	c.ep.Close()
}
// createV6Endpoint creates an IPv6 UDP endpoint registered on c.wq and stores
// it in c.ep. The endpoint's V6OnlyOption is set to 1 when v6only is true and
// to 0 otherwise. The tests in this file pass false when they need the
// endpoint to handle IPv4-mapped traffic, and true when they expect
// IPv4-mapped destinations to be rejected.
func (c *testContext) createV6Endpoint(v6only bool) {
	var err *tcpip.Error
	c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
	if err != nil {
		c.t.Fatalf("NewEndpoint failed: %v", err)
	}

	var v tcpip.V6OnlyOption
	if v6only {
		v = 1
	}
	if err := c.ep.SetSockOpt(v); err != nil {
		c.t.Fatalf("SetSockOpt failed: %v", err)
	}
}
// readOutboundPacket waits up to two seconds for the stack to write a packet
// to the link endpoint. It returns the packet's header and payload joined
// into one contiguous buffer. The test fails if nothing arrives in time or if
// the packet's network protocol is not proto.
func (c *testContext) readOutboundPacket(proto tcpip.NetworkProtocolNumber) []byte {
	select {
	case <-time.After(2 * time.Second):
		c.t.Fatalf("Packet wasn't written out")
		return nil
	case p := <-c.linkEP.C:
		if p.Proto != proto {
			c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, proto)
		}
		pkt := make([]byte, 0, len(p.Header)+len(p.Payload))
		pkt = append(pkt, p.Header...)
		return append(pkt, p.Payload...)
	}
}

// getV6Packet returns the next packet written by the stack. It fails the test
// unless the packet is IPv6 from stackV6Addr to testV6Addr.
func (c *testContext) getV6Packet() []byte {
	pkt := c.readOutboundPacket(ipv6.ProtocolNumber)
	checker.IPv6(c.t, pkt, checker.SrcAddr(stackV6Addr), checker.DstAddr(testV6Addr))
	return pkt
}

// getPacket returns the next packet written by the stack. It fails the test
// unless the packet is IPv4 from stackAddr to testAddr.
func (c *testContext) getPacket() []byte {
	pkt := c.readOutboundPacket(ipv4.ProtocolNumber)
	checker.IPv4(c.t, pkt, checker.SrcAddr(stackAddr), checker.DstAddr(testAddr))
	return pkt
}

// getMCPacket returns the next packet written by the stack. It fails the test
// unless the packet is IPv4 from stackAddr to multicastAddr.
func (c *testContext) getMCPacket() []byte {
	pkt := c.readOutboundPacket(ipv4.ProtocolNumber)
	checker.IPv4(c.t, pkt, checker.SrcAddr(stackAddr), checker.DstAddr(multicastAddr))
	return pkt
}
// sendV6Packet injects an IPv6/UDP datagram into the link endpoint. The
// datagram goes from testV6Addr to stackV6Addr, carries payload, uses the
// ports in h, and has a UDP checksum computed over the IPv6 pseudo-header.
func (c *testContext) sendV6Packet(payload []byte, h *headers) {
	udpLen := header.UDPMinimumSize + len(payload)
	pkt := buffer.NewView(header.IPv6MinimumSize + udpLen)
	copy(pkt[header.IPv6MinimumSize+header.UDPMinimumSize:], payload)

	ipHdr := header.IPv6(pkt)
	ipHdr.Encode(&header.IPv6Fields{
		PayloadLength: uint16(udpLen),
		NextHeader:    uint8(udp.ProtocolNumber),
		HopLimit:      65,
		SrcAddr:       testV6Addr,
		DstAddr:       stackV6Addr,
	})

	udpHdr := header.UDP(pkt[header.IPv6MinimumSize:])
	udpHdr.Encode(&header.UDPFields{
		SrcPort: h.srcPort,
		DstPort: h.dstPort,
		Length:  uint16(udpLen),
	})

	// Sum the pseudo-header (addresses, protocol) and then the payload, in
	// that order, and pass the running sum and UDP length to
	// CalculateChecksum.
	sum := header.Checksum([]byte(testV6Addr), 0)
	sum = header.Checksum([]byte(stackV6Addr), sum)
	sum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, sum)
	sum = header.Checksum(payload, sum)
	udpHdr.SetChecksum(^udpHdr.CalculateChecksum(sum, uint16(udpLen)))

	var views [1]buffer.View
	vv := pkt.ToVectorisedView(views)
	c.linkEP.Inject(ipv6.ProtocolNumber, &vv)
}
// sendPacket injects an IPv4/UDP datagram into the link endpoint. The
// datagram goes from testAddr to stackAddr, carries payload, uses the ports
// in h, and has both the IPv4 header checksum and the UDP checksum (over the
// IPv4 pseudo-header) filled in.
func (c *testContext) sendPacket(payload []byte, h *headers) {
	udpLen := header.UDPMinimumSize + len(payload)
	pkt := buffer.NewView(header.IPv4MinimumSize + udpLen)
	copy(pkt[header.IPv4MinimumSize+header.UDPMinimumSize:], payload)

	ipHdr := header.IPv4(pkt)
	ipHdr.Encode(&header.IPv4Fields{
		IHL:         header.IPv4MinimumSize,
		TotalLength: uint16(len(pkt)),
		TTL:         65,
		Protocol:    uint8(udp.ProtocolNumber),
		SrcAddr:     testAddr,
		DstAddr:     stackAddr,
	})
	ipHdr.SetChecksum(^ipHdr.CalculateChecksum())

	udpHdr := header.UDP(pkt[header.IPv4MinimumSize:])
	udpHdr.Encode(&header.UDPFields{
		SrcPort: h.srcPort,
		DstPort: h.dstPort,
		Length:  uint16(udpLen),
	})

	// Sum the pseudo-header (addresses, protocol) and then the payload, in
	// that order, and pass the running sum and UDP length to
	// CalculateChecksum.
	sum := header.Checksum([]byte(testAddr), 0)
	sum = header.Checksum([]byte(stackAddr), sum)
	sum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, sum)
	sum = header.Checksum(payload, sum)
	udpHdr.SetChecksum(^udpHdr.CalculateChecksum(sum, uint16(udpLen)))

	var views [1]buffer.View
	vv := pkt.ToVectorisedView(views)
	c.linkEP.Inject(ipv4.ProtocolNumber, &vv)
}
// newPayload returns a buffer of 30 to 129 bytes, inclusive, filled with
// pseudo-random bytes from the math/rand global source.
func newPayload() []byte {
	size := 30 + rand.Intn(100)
	payload := make([]byte, size)
	for i := 0; i < size; i++ {
		payload[i] = byte(rand.Intn(256))
	}
	return payload
}
// testV4Read injects an IPv4 UDP datagram from testAddr:testPort to stackPort.
// It then checks that c.ep returns the same payload and reports testAddr as
// the sender. If nothing is queued yet, it waits up to one second for an
// EventIn notification before reading again. Any Read error other than the
// initial ErrWouldBlock fails the test.
func testV4Read(c *testContext) {
	payload := newPayload()
	c.sendPacket(payload, &headers{
		srcPort: testPort,
		dstPort: stackPort,
	})

	// Register for readability before the first Read so that a wakeup
	// between that Read and the wait below cannot be missed.
	we, ch := waiter.NewChannelEntry(nil)
	c.wq.EventRegister(&we, waiter.EventIn)
	defer c.wq.EventUnregister(&we)

	var addr tcpip.FullAddress
	v, err := c.ep.Read(&addr)
	if err == tcpip.ErrWouldBlock {
		// Wait for data to become available.
		select {
		case <-ch:
			v, err = c.ep.Read(&addr)
		case <-time.After(1 * time.Second):
			c.t.Fatalf("Timed out waiting for data")
		}
	}
	// Checks both the immediate Read and the one after the wakeup; any
	// other error must not be silently ignored.
	if err != nil {
		c.t.Fatalf("Read failed: %v", err)
	}

	if addr.Addr != testAddr {
		c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
	}
	if !bytes.Equal(payload, v) {
		c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
	}
}
// TestV4ReadOnV6 checks that a dual-stack IPv6 endpoint bound to the
// unspecified address receives IPv4 datagrams.
func TestV4ReadOnV6(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(false)

	wildcard := tcpip.FullAddress{Port: stackPort}
	if err := c.ep.Bind(wildcard, nil); err != nil {
		t.Fatalf("Bind failed: %v", err)
	}

	testV4Read(c)
}
// TestV4ReadOnBoundToV4MappedWildcard checks that a dual-stack endpoint bound
// to the IPv4-mapped wildcard (::ffff:0.0.0.0) receives IPv4 datagrams.
func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(false)

	local := tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}
	if err := c.ep.Bind(local, nil); err != nil {
		t.Fatalf("Bind failed: %v", err)
	}

	testV4Read(c)
}
// TestV4ReadOnBoundToV4Mapped checks that a dual-stack endpoint bound to the
// IPv4-mapped form of the stack's own address receives IPv4 datagrams.
func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(false)

	// Bind to the local address.
	local := tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}
	if err := c.ep.Bind(local, nil); err != nil {
		t.Fatalf("Bind failed: %v", err)
	}

	testV4Read(c)
}
// TestV6ReadOnV6 checks that a dual-stack IPv6 endpoint bound to the
// unspecified address receives a native IPv6 datagram. The endpoint must
// report testV6Addr as the sender and return the injected payload.
func TestV6ReadOnV6(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(false)

	// Bind to wildcard.
	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
		c.t.Fatalf("Bind failed: %v", err)
	}

	// Send a packet.
	payload := newPayload()
	c.sendV6Packet(payload, &headers{
		srcPort: testPort,
		dstPort: stackPort,
	})

	// Register for readability before the first Read so a wakeup cannot be
	// missed.
	we, ch := waiter.NewChannelEntry(nil)
	c.wq.EventRegister(&we, waiter.EventIn)
	defer c.wq.EventUnregister(&we)

	var addr tcpip.FullAddress
	v, err := c.ep.Read(&addr)
	if err == tcpip.ErrWouldBlock {
		// Wait for data to become available.
		select {
		case <-ch:
			v, err = c.ep.Read(&addr)
		case <-time.After(1 * time.Second):
			c.t.Fatalf("Timed out waiting for data")
		}
	}
	// Errors other than ErrWouldBlock from the first Read are failures too.
	if err != nil {
		c.t.Fatalf("Read failed: %v", err)
	}

	// Check the peer address; report the IPv6 address actually expected.
	if addr.Addr != testV6Addr {
		c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testV6Addr)
	}

	// Check the payload.
	if !bytes.Equal(payload, v) {
		c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
	}
}
// TestV4ReadOnV4 checks that a plain IPv4 UDP endpoint bound to the
// unspecified address receives IPv4 datagrams.
func TestV4ReadOnV4(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
	c.ep = ep
	if err != nil {
		t.Fatalf("NewEndpoint failed: %v", err)
	}

	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
		t.Fatalf("Bind failed: %v", err)
	}

	testV4Read(c)
}
// testDualWrite writes one datagram to testV4MappedAddr and one to testV6Addr
// (both on testPort) through c.ep. It checks that the first leaves as IPv4 and
// the second as IPv6, that each carries the written payload, and that both
// use the same UDP source port. It returns that source port.
func testDualWrite(c *testContext) uint16 {
	// First datagram: IPv4-mapped destination, expected on the wire as IPv4.
	v4Payload := buffer.View(newPayload())
	written, err := c.ep.Write(v4Payload, &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort})
	if err != nil {
		c.t.Fatalf("Write failed: %v", err)
	}
	if written != uintptr(len(v4Payload)) {
		c.t.Fatalf("Bad number of bytes written: got %v, want %v", written, len(v4Payload))
	}

	v4Pkt := c.getPacket()
	v4UDP := header.UDP(header.IPv4(v4Pkt).Payload())
	checker.IPv4(c.t, v4Pkt, checker.UDP(checker.DstPort(testPort)))
	srcPort := v4UDP.SourcePort()
	if !bytes.Equal(v4Payload, v4UDP.Payload()) {
		c.t.Fatalf("Bad payload: got %x, want %x", v4UDP.Payload(), v4Payload)
	}

	// Second datagram: native IPv6 destination. It must reuse the source
	// port chosen for the IPv4 datagram.
	v6Payload := buffer.View(newPayload())
	written, err = c.ep.Write(v6Payload, &tcpip.FullAddress{Addr: testV6Addr, Port: testPort})
	if err != nil {
		c.t.Fatalf("Write failed: %v", err)
	}
	if written != uintptr(len(v6Payload)) {
		c.t.Fatalf("Bad number of bytes written: got %v, want %v", written, len(v6Payload))
	}

	v6Pkt := c.getV6Packet()
	v6UDP := header.UDP(header.IPv6(v6Pkt).Payload())
	checker.IPv6(c.t, v6Pkt, checker.UDP(checker.DstPort(testPort), checker.SrcPort(srcPort)))
	if !bytes.Equal(v6Payload, v6UDP.Payload()) {
		c.t.Fatalf("Bad payload: got %x, want %x", v6UDP.Payload(), v6Payload)
	}

	return srcPort
}
// TestDualWriteUnbound runs testDualWrite on a dual-stack endpoint that was
// never bound or connected.
func TestDualWriteUnbound(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(false)

	testDualWrite(c)
}
// TestDualWriteBoundToWildcard checks that once a dual-stack endpoint is bound
// to the unspecified address on stackPort, both its IPv4 and IPv6 datagrams
// are sent from stackPort.
func TestDualWriteBoundToWildcard(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(false)

	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
		t.Fatalf("Bind failed: %v", err)
	}

	if got := testDualWrite(c); got != stackPort {
		t.Fatalf("Bad port: got %v, want %v", got, stackPort)
	}
}
// TestDualWriteConnectedToV6 connects a dual-stack endpoint to
// testV6Addr:testPort and then runs testDualWrite, which writes to explicit
// IPv4-mapped and IPv6 destinations.
func TestDualWriteConnectedToV6(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(false)

	// Connect to v6 address.
	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
		c.t.Fatalf("Connect failed: %v", err)
	}

	testDualWrite(c)
}
// TestDualWriteConnectedToV4Mapped connects a dual-stack endpoint to
// testV4MappedAddr:testPort and then runs testDualWrite, which writes to
// explicit IPv4-mapped and IPv6 destinations.
func TestDualWriteConnectedToV4Mapped(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(false)

	// Connect to v4 mapped address.
	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
		c.t.Fatalf("Connect failed: %v", err)
	}

	testDualWrite(c)
}
// TestV4WriteOnV6Only checks that an endpoint created with V6OnlyOption set
// fails with ErrNoRoute when writing to an IPv4-mapped destination.
func TestV4WriteOnV6Only(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(true)

	dst := tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}
	if _, err := c.ep.Write(buffer.View(newPayload()), &dst); err != tcpip.ErrNoRoute {
		t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute)
	}
}
// TestV6WriteOnBoundToV4Mapped checks that an endpoint bound to an
// IPv4-mapped local address fails with ErrNoRoute when writing to a native
// IPv6 destination.
func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(false)

	local := tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}
	if err := c.ep.Bind(local, nil); err != nil {
		t.Fatalf("Bind failed: %v", err)
	}

	dst := tcpip.FullAddress{Addr: testV6Addr, Port: testPort}
	if _, err := c.ep.Write(buffer.View(newPayload()), &dst); err != tcpip.ErrNoRoute {
		t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute)
	}
}
// TestV6WriteOnConnected checks that a Write with no destination on an
// endpoint connected to testV6Addr:testPort produces an IPv6 datagram to that
// peer carrying the written payload.
func TestV6WriteOnConnected(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(false)

	peer := tcpip.FullAddress{Addr: testV6Addr, Port: testPort}
	if err := c.ep.Connect(peer); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}

	payload := buffer.View(newPayload())
	n, err := c.ep.Write(payload, nil)
	if err != nil {
		t.Fatalf("Write failed: %v", err)
	}
	if n != uintptr(len(payload)) {
		t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
	}

	pkt := c.getV6Packet()
	u := header.UDP(header.IPv6(pkt).Payload())
	checker.IPv6(t, pkt, checker.UDP(checker.DstPort(testPort)))
	if !bytes.Equal(payload, u.Payload()) {
		t.Fatalf("Bad payload: got %x, want %x", u.Payload(), payload)
	}
}
// TestV4WriteOnConnected checks that a Write with no destination on an
// endpoint connected to testV4MappedAddr:testPort produces an IPv4 datagram to
// testAddr carrying the written payload.
func TestV4WriteOnConnected(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(false)

	peer := tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}
	if err := c.ep.Connect(peer); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}

	payload := buffer.View(newPayload())
	n, err := c.ep.Write(payload, nil)
	if err != nil {
		t.Fatalf("Write failed: %v", err)
	}
	if n != uintptr(len(payload)) {
		t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
	}

	pkt := c.getPacket()
	u := header.UDP(header.IPv4(pkt).Payload())
	checker.IPv4(t, pkt, checker.UDP(checker.DstPort(testPort)))
	if !bytes.Equal(payload, u.Payload()) {
		t.Fatalf("Bad payload: got %x, want %x", u.Payload(), payload)
	}
}
// TestMulticastTTL sets MulticastTTLOption on a dual-stack endpoint. It then
// checks that a datagram to an IPv4 multicast address carries that TTL, while
// a unicast datagram still carries header.IPv4DefaultTTL.
func TestMulticastTTL(t *testing.T) {
	const multicastTTL = 42

	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(false)

	// The original ignored this error; if the option is rejected the TTL
	// assertions below would fail with a misleading message.
	if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
		c.t.Fatalf("SetSockOpt failed: %v", err)
	}

	payload := buffer.View(newPayload())

	// Write a multicast packet. Its TTL value should be the above multicast value.
	{
		n, err := c.ep.Write(payload, &tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: multicastPort})
		if err != nil {
			c.t.Fatalf("Write failed: %v", err)
		}
		if n != uintptr(len(payload)) {
			c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
		}

		// Check that we received the packet and that it has the multicastTTL value.
		b := c.getMCPacket()
		checker.IPv4(c.t, b,
			checker.TTL(multicastTTL),
			checker.UDP(
				checker.DstPort(multicastPort),
			),
		)
	}

	// Write a regular packet. Its TTL value should be the default.
	{
		n, err := c.ep.Write(payload, &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort})
		if err != nil {
			c.t.Fatalf("Write failed: %v", err)
		}
		if n != uintptr(len(payload)) {
			c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
		}

		b := c.getPacket()
		checker.IPv4(c.t, b,
			checker.TTL(header.IPv4DefaultTTL),
			checker.UDP(
				checker.DstPort(testPort),
			),
		)
	}
}
// TestReadIncrementsPacketsReceived checks that receiving one datagram on a
// fresh stack leaves the stack's UDP PacketsReceived counter at exactly 1.
func TestReadIncrementsPacketsReceived(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	// Use a plain IPv4 UDP endpoint bound to the unspecified address.
	ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
	c.ep = ep
	if err != nil {
		t.Fatalf("NewEndpoint failed: %v", err)
	}
	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
		t.Fatalf("Bind failed: %v", err)
	}

	testV4Read(c)

	const want uint64 = 1
	if got := c.s.Stats().UDP.PacketsReceived; got != want {
		t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want)
	}
}
// TestWriteIncrementsPacketsSent checks that the two datagrams written by
// testDualWrite leave the stack's UDP PacketsSent counter at exactly 2.
func TestWriteIncrementsPacketsSent(t *testing.T) {
	c := newDualTestContext(t, defaultMTU)
	defer c.cleanup()

	c.createV6Endpoint(false)
	testDualWrite(c)

	const want uint64 = 2
	if got := c.s.Stats().UDP.PacketsSent; got != want {
		t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
	}
}