blob: 1b6eed1464143d328156223a695d40226a2e3150 [file] [log] [blame]
// Copyright 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build linux
package fdbased
import (
"bytes"
"fmt"
"math/rand"
"reflect"
"syscall"
"testing"
"time"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/stack"
)
// Fixture values shared by the tests in this file.
const (
	// mtu is the MTU that every test endpoint is configured with.
	mtu = 1500
	// proto is the network protocol number used for packets written to and
	// delivered by the endpoint in these tests.
	proto = 10

	// laddr is the link address assigned to the endpoint under test.
	laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66")
	// raddr is the link address used for the remote side of a packet.
	raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc")
)
// packetInfo records one packet handed to context.DeliverNetworkPacket so that
// a test can compare it against what it expected to receive.
//
// NOTE: DeliverNetworkPacket builds this struct with a positional literal, so
// the field order below must not change.
type packetInfo struct {
// raddr is the remote link address reported with the packet.
raddr tcpip.LinkAddress
// proto is the network protocol number reported with the packet.
proto tcpip.NetworkProtocolNumber
// contents is the packet data, flattened into a single view.
contents buffer.View
}
// context holds the state of one endpoint-under-test: the socket pair that
// backs it, the endpoint itself, and the channels through which the endpoint
// reports delivered packets and closure. It is attached to the endpoint as its
// dispatcher (see DeliverNetworkPacket).
type context struct {
// t is the test that owns this context.
t *testing.T
// fds is the socket pair created by newContext. fds[1] is given to the
// endpoint; the tests read from and write to fds[0].
fds [2]int
// ep is the link endpoint under test.
ep stack.LinkEndpoint
// ch receives one packetInfo per packet delivered by the endpoint.
ch chan packetInfo
// done is signalled from the endpoint's ClosedFunc callback.
done chan struct{}
}
// newContext creates a SOCK_SEQPACKET unix socket pair, builds an fd-based
// endpoint on top of fds[1] using opt, and attaches the returned context to it
// as the network dispatcher.
//
// opt is modified in place: its FD and ClosedFunc fields are overwritten. The
// test is aborted with t.Fatalf if the socket pair cannot be created.
func newContext(t *testing.T, opt *Options) *context {
	fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
	if err != nil {
		t.Fatalf("Socketpair failed: %v", err)
	}

	// Buffered so that the ClosedFunc callback never blocks on the send, even
	// if nobody is waiting in cleanup yet.
	closed := make(chan struct{}, 1)
	opt.FD = fds[1]
	opt.ClosedFunc = func(*tcpip.Error) {
		closed <- struct{}{}
	}

	c := &context{
		t:    t,
		fds:  fds,
		ch:   make(chan packetInfo, 100),
		done: closed,
	}

	linkEP := stack.FindLinkEndpoint(New(opt)).(*endpoint)
	c.ep = linkEP
	linkEP.Attach(c)
	return c
}
// cleanup tears the context down: it closes the test's end of the socket pair,
// waits for the endpoint's ClosedFunc callback to fire, and only then closes
// the endpoint's own descriptor.
//
// NOTE(review): this presumably relies on the endpoint's read loop exiting
// once the peer socket is closed; if it never does, cleanup blocks forever.
func (c *context) cleanup() {
	testFD, endpointFD := c.fds[0], c.fds[1]

	syscall.Close(testFD)
	// Wait until the endpoint has reported that it is closed before its
	// descriptor is released.
	<-c.done
	syscall.Close(endpointFD)
}
// DeliverNetworkPacket is the dispatcher callback invoked by the endpoint for
// each received packet. It records the remote link address, the protocol and a
// flattened copy of the payload on c.ch; linkEP and localLinkAddr are ignored.
func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, localLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
	received := packetInfo{
		raddr:    remoteLinkAddr,
		proto:    protocol,
		contents: vv.ToView(),
	}
	c.ch <- received
}
// TestNoEthernetProperties checks that an endpoint created without an ethernet
// header reports a zero maximum header length and the configured MTU.
func TestNoEthernetProperties(t *testing.T) {
	c := newContext(t, &Options{MTU: mtu})
	defer c.cleanup()

	gotHdrLen := c.ep.MaxHeaderLength()
	if wantHdrLen := uint16(0); gotHdrLen != wantHdrLen {
		t.Fatalf("MaxHeaderLength() = %v, want %v", gotHdrLen, wantHdrLen)
	}

	gotMTU := c.ep.MTU()
	if wantMTU := uint32(mtu); gotMTU != wantMTU {
		t.Fatalf("MTU() = %v, want %v", gotMTU, wantMTU)
	}
}
// TestEthernetProperties checks that an endpoint created with an ethernet
// header reports header.EthernetMinimumSize as its maximum header length and
// the configured MTU.
func TestEthernetProperties(t *testing.T) {
	c := newContext(t, &Options{EthernetHeader: true, MTU: mtu})
	defer c.cleanup()

	gotHdrLen := c.ep.MaxHeaderLength()
	if wantHdrLen := uint16(header.EthernetMinimumSize); gotHdrLen != wantHdrLen {
		t.Fatalf("MaxHeaderLength() = %v, want %v", gotHdrLen, wantHdrLen)
	}

	gotMTU := c.ep.MTU()
	if wantMTU := uint32(mtu); gotMTU != wantMTU {
		t.Fatalf("MTU() = %v, want %v", gotMTU, wantMTU)
	}
}
// TestAddress checks that LinkAddress() returns exactly the address the
// endpoint was configured with, including the empty address.
func TestAddress(t *testing.T) {
	for _, addr := range []tcpip.LinkAddress{"", "abc", "def"} {
		name := fmt.Sprintf("Address: %q", addr)
		t.Run(name, func(t *testing.T) {
			c := newContext(t, &Options{Address: addr, MTU: mtu})
			defer c.cleanup()

			got := c.ep.LinkAddress()
			if got != addr {
				t.Fatalf("LinkAddress() = %v, want %v", got, addr)
			}
		})
	}
}
// TestWritePacket writes a packet made of a 100-byte random header and a
// random payload of varying length through the endpoint, with and without an
// ethernet header, and checks what arrives on the other end of the socket
// pair: the ethernet fields (when enabled) and the header+payload bytes.
func TestWritePacket(t *testing.T) {
	lengths := []int{0, 100, 1000}
	eths := []bool{true, false}

	for _, eth := range eths {
		for _, plen := range lengths {
			t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) {
				c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth})
				defer c.cleanup()

				r := &stack.Route{
					RemoteLinkAddress: raddr,
				}

				// Build header.
				hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100)
				b := hdr.Prepend(100)
				for i := range b {
					b[i] = uint8(rand.Intn(256))
				}

				// Build payload.
				payload := make(buffer.View, plen)
				for i := range payload {
					payload[i] = uint8(rand.Intn(256))
				}

				// Snapshot the expected bytes into a fresh slice. Appending an
				// empty payload directly to hdr.View() would return a slice that
				// shares storage with the buffer handed to WritePacket, so any
				// corruption of the header by the endpoint would also change the
				// expectation and go unnoticed.
				hv := hdr.View()
				want := make([]byte, 0, len(hv)+len(payload))
				want = append(want, hv...)
				want = append(want, payload...)

				if err := c.ep.WritePacket(r, hdr, payload.ToVectorisedView(), proto); err != nil {
					t.Fatalf("WritePacket failed: %v", err)
				}

				// Read from fd, then compare with what we wrote.
				b = make([]byte, mtu)
				n, err := syscall.Read(c.fds[0], b)
				if err != nil {
					t.Fatalf("Read failed: %v", err)
				}
				b = b[:n]

				if eth {
					// Fail cleanly instead of panicking on the slice expression
					// below if the frame is too short to hold an ethernet header.
					if n < header.EthernetMinimumSize {
						t.Fatalf("Read returned %v bytes, want at least %v", n, header.EthernetMinimumSize)
					}
					h := header.Ethernet(b)
					b = b[header.EthernetMinimumSize:]

					if a := h.SourceAddress(); a != laddr {
						t.Fatalf("SourceAddress() = %v, want %v", a, laddr)
					}
					if a := h.DestinationAddress(); a != raddr {
						t.Fatalf("DestinationAddress() = %v, want %v", a, raddr)
					}
					if et := h.Type(); et != proto {
						t.Fatalf("Type() = %v, want %v", et, proto)
					}
				}

				if len(b) != len(want) {
					t.Fatalf("Read returned %v bytes, want %v", len(b), len(want))
				}
				if !bytes.Equal(b, want) {
					t.Fatalf("Read returned %x, want %x", b, want)
				}
			})
		}
	}
}
// TestPreserveSrcAddress checks that when the route carries a LocalLinkAddress
// (here a "bridged" address different from the endpoint's own), the ethernet
// frame written by the endpoint uses that address as its source.
func TestPreserveSrcAddress(t *testing.T) {
	baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99")

	c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: true})
	defer c.cleanup()

	// Set LocalLinkAddress in route to the value of the bridged address.
	r := &stack.Route{
		RemoteLinkAddress: raddr,
		LocalLinkAddress:  baddr,
	}

	// WritePacket panics given a prependable with anything less than
	// the minimum size of the ethernet header.
	hdr := buffer.NewPrependable(header.EthernetMinimumSize)
	if err := c.ep.WritePacket(r, hdr, buffer.VectorisedView{}, proto); err != nil {
		t.Fatalf("WritePacket failed: %v", err)
	}

	// Read from the FD, then compare with what we wrote.
	b := make([]byte, mtu)
	n, err := syscall.Read(c.fds[0], b)
	if err != nil {
		t.Fatalf("Read failed: %v", err)
	}
	// Report a short frame as a test failure rather than letting the header
	// accessor below index past the end of the slice.
	if n < header.EthernetMinimumSize {
		t.Fatalf("Read returned %v bytes, want at least %v", n, header.EthernetMinimumSize)
	}
	b = b[:n]

	h := header.Ethernet(b)
	if a := h.SourceAddress(); a != baddr {
		t.Fatalf("SourceAddress() = %v, want %v", a, baddr)
	}
}
// TestDeliverPacket writes a random packet into the test's end of the socket
// pair and checks that the endpoint hands it to the dispatcher with the
// expected remote address, protocol and contents. With an ethernet header the
// address and protocol come from that header; without one the packet is made
// to look like IPv4 and is expected to be delivered with an empty remote
// address.
func TestDeliverPacket(t *testing.T) {
	for _, eth := range []bool{true, false} {
		for _, plen := range []int{100, 1000} {
			t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) {
				c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth})
				defer c.cleanup()

				// Random payload; this is what the dispatcher should see.
				payload := make([]byte, plen)
				for i := range payload {
					payload[i] = uint8(rand.Intn(256))
				}

				// frame is what goes on the wire; want is what should come out.
				frame := payload
				want := packetInfo{contents: payload}
				if eth {
					ethHdr := make(header.Ethernet, header.EthernetMinimumSize)
					ethHdr.Encode(&header.EthernetFields{
						SrcAddr: raddr,
						DstAddr: laddr,
						Type:    proto,
					})
					frame = append(ethHdr, payload...)
					want.raddr = raddr
					want.proto = proto
				} else {
					// So that it looks like an IPv4 packet.
					payload[0] = 0x40
					want.raddr = ""
					want.proto = header.IPv4ProtocolNumber
				}

				// Write packet via the file descriptor.
				if _, err := syscall.Write(c.fds[0], frame); err != nil {
					t.Fatalf("Write failed: %v", err)
				}

				// Receive packet through the endpoint.
				select {
				case <-time.After(10 * time.Second):
					t.Fatalf("Timed out waiting for packet")
				case pi := <-c.ch:
					if !reflect.DeepEqual(want, pi) {
						t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want)
					}
				}
			})
		}
	}
}
// TestBufConfigMaxLength checks that the views described by BufConfig add up
// to at least header.MaxIPPacketSize bytes.
func TestBufConfigMaxLength(t *testing.T) {
	total := 0
	for i := range BufConfig {
		total += BufConfig[i]
	}

	if limit := header.MaxIPPacketSize; total < limit {
		t.Errorf("total buffer size is invalid: got %d, want >= %d", total, limit)
	}
}
// TestBufConfigFirst checks that the first view described by BufConfig is at
// least 120 bytes long.
func TestBufConfigFirst(t *testing.T) {
	// The stack assumes that the TCP/IP header is entirely contained in the
	// first view, so that view must be able to hold the largest possible
	// TCP/IP header: 60 bytes of IP plus 60 bytes of TCP.
	if first, floor := BufConfig[0], 120; first < floor {
		t.Errorf("first view has an invalid size: got %d, want >= %d", first, floor)
	}
}
// build returns a bare endpoint (no file descriptor, no dispatcher) whose
// views and iovecs slices have one entry per element of bufConfig and whose
// views have been allocated via allocateViews(bufConfig).
func build(bufConfig []int) *endpoint {
	count := len(bufConfig)

	e := new(endpoint)
	e.views = make([]buffer.View, count)
	e.iovecs = make([]syscall.Iovec, count)
	e.allocateViews(bufConfig)
	return e
}
// capLengthTestCases is the table driving TestCapLength. Each entry builds an
// endpoint from config, calls capViews(n, config) on it, and states the
// expected return value and the expected length of every view afterwards.
var capLengthTestCases = []struct {
// comment names the case in failure messages.
comment string
// config is the list of view sizes passed to build and to capViews.
config []int
// n is the byte count passed to capViews.
n int
// wantUsed is the value capViews is expected to return.
wantUsed int
// wantLengths holds the expected len() of each view after capViews.
wantLengths []int
}{
// n is smaller than the only view: the view is expected to shrink to n.
{
comment: "Single slice",
config: []int{2},
n: 1,
wantUsed: 1,
wantLengths: []int{1},
},
// n ends inside the second view: that view is expected to be cut short.
{
comment: "Multiple slices",
config: []int{1, 2},
n: 2,
wantUsed: 2,
wantLengths: []int{1, 1},
},
// n equals the total configured size: no view is expected to change.
{
comment: "Entire buffer",
config: []int{1, 2},
n: 3,
wantUsed: 2,
wantLengths: []int{1, 2},
},
// n exactly fills the first two views: the third is reported as unused and
// is expected to keep its configured length.
{
comment: "Entire buffer but not on the last slice",
config: []int{1, 2, 3},
n: 3,
wantUsed: 2,
wantLengths: []int{1, 2, 3},
},
}
// TestCapLength runs every entry of capLengthTestCases: it builds an endpoint
// from the entry's config, calls capViews, and compares both the returned
// count and the resulting length of each view with the expected values. All
// mismatches are reported; a failing case does not stop the remaining ones.
func TestCapLength(t *testing.T) {
	for _, tc := range capLengthTestCases {
		e := build(tc.config)

		if used := e.capViews(tc.n, tc.config); used != tc.wantUsed {
			t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %d. Want %d", tc.comment, tc.n, tc.config, used, tc.wantUsed)
		}

		// Collect the post-call length of every view for comparison.
		gotLengths := make([]int, 0, len(e.views))
		for _, view := range e.views {
			gotLengths = append(gotLengths, len(view))
		}
		if !reflect.DeepEqual(gotLengths, tc.wantLengths) {
			t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %v. Want %v", tc.comment, tc.n, tc.config, gotLengths, tc.wantLengths)
		}
	}
}