blob: 6728370c31c5458178059a66d90d29d683805742 [file] [log] [blame]
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"bytes"
"fmt"
"testing"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
// TestPacketHeaderPush builds outbound packets by pushing headers into the
// reserved headroom (transport first, link last). It then checks the
// packet-wide sizes, each header's view, and PayloadSince for every header.
func TestPacketHeaderPush(t *testing.T) {
	tests := []struct {
		name      string
		reserved  int
		link      []byte
		network   []byte
		transport []byte
		data      []byte
	}{
		{
			name: "construct empty packet",
		},
		{
			name:     "construct link header only packet",
			reserved: 60,
			link:     makeView(10),
		},
		{
			name:     "construct link and network header only packet",
			reserved: 60,
			link:     makeView(10),
			network:  makeView(20),
		},
		{
			name:      "construct header only packet",
			reserved:  60,
			link:      makeView(10),
			network:   makeView(20),
			transport: makeView(30),
		},
		{
			name: "construct data only packet",
			data: makeView(40),
		},
		{
			name:      "construct L3 packet",
			reserved:  60,
			network:   makeView(20),
			transport: makeView(30),
			data:      makeView(40),
		},
		{
			name:      "construct L2 packet",
			reserved:  60,
			link:      makeView(10),
			network:   makeView(20),
			transport: makeView(30),
			data:      makeView(40),
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Give the packet its own copy of the payload so nothing it does
			// can alter the expected bytes held in tc.data.
			pkt := NewPacketBuffer(PacketBufferOptions{
				ReserveHeaderBytes: tc.reserved,
				Data:               buffer.NewViewFromBytes(tc.data).ToVectorisedView(),
			})

			checkInitialPacketBuffer(t, pkt, PacketBufferOptions{
				ReserveHeaderBytes: tc.reserved,
				Data:               buffer.View(tc.data).ToVectorisedView(),
			})

			// Headers are prepended, so they go in innermost first. A header
			// with no bytes is never pushed at all.
			push := func(h PacketHeader, hdr []byte) {
				if len(hdr) > 0 {
					copy(h.Push(len(hdr)), hdr)
				}
			}
			push(pkt.TransportHeader(), tc.transport)
			push(pkt.NetworkHeader(), tc.network)
			push(pkt.LinkHeader(), tc.link)

			hdrLen := len(tc.link) + len(tc.network) + len(tc.transport)
			if got, want := pkt.ReservedHeaderBytes(), tc.reserved; got != want {
				t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want)
			}
			if got, want := pkt.AvailableHeaderBytes(), tc.reserved-hdrLen; got != want {
				t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want)
			}
			if got, want := pkt.HeaderSize(), hdrLen; got != want {
				t.Errorf("After pk.HeaderSize() = %d, want %d", got, want)
			}
			if got, want := pkt.Size(), hdrLen+len(tc.data); got != want {
				t.Errorf("After pk.Size() = %d, want %d", got, want)
			}

			// Expected bytes from each header to the end of the packet, whose
			// layout is link | network | transport | payload.
			fromLink := concatViews(tc.link, tc.network, tc.transport, tc.data)
			fromNetwork := concatViews(tc.network, tc.transport, tc.data)
			fromTransport := concatViews(tc.transport, tc.data)

			checkData(t, pkt, tc.data)
			checkViewEqual(t, "After pk.Views()", concatViews(pkt.Views()...), fromLink)

			checkPacketHeader(t, "After pk.LinkHeader", pkt.LinkHeader(), tc.link)
			checkPacketHeader(t, "After pk.NetworkHeader", pkt.NetworkHeader(), tc.network)
			checkPacketHeader(t, "After pk.TransportHeader", pkt.TransportHeader(), tc.transport)

			checkViewEqual(t, "After PayloadSince(LinkHeader)", PayloadSince(pkt.LinkHeader()), fromLink)
			checkViewEqual(t, "After PayloadSince(NetworkHeader)", PayloadSince(pkt.NetworkHeader()), fromNetwork)
			checkViewEqual(t, "After PayloadSince(TransportHeader)", PayloadSince(pkt.TransportHeader()), fromTransport)
		})
	}
}
// TestPacketHeaderConsume parses inbound packets by consuming headers from the
// front of the data (link first, transport last). It then checks the
// packet-wide sizes, each header's view, and PayloadSince for every header.
func TestPacketHeaderConsume(t *testing.T) {
	tests := []struct {
		name      string
		data      []byte
		link      int
		network   int
		transport int
	}{
		{
			name:      "parse L2 packet",
			data:      concatViews(makeView(10), makeView(20), makeView(30), makeView(40)),
			link:      10,
			network:   20,
			transport: 30,
		},
		{
			name:      "parse L3 packet",
			data:      concatViews(makeView(20), makeView(30), makeView(40)),
			network:   20,
			transport: 30,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			pkt := NewPacketBuffer(PacketBufferOptions{
				// Hand over a private copy so the expected bytes in tc.data
				// cannot be modified through the packet.
				Data: buffer.NewViewFromBytes(tc.data).ToVectorisedView(),
			})
			checkInitialPacketBuffer(t, pkt, PacketBufferOptions{
				Data: buffer.View(tc.data).ToVectorisedView(),
			})

			// Parse the outermost header first; a zero size means the header
			// is absent and is not consumed at all.
			if n := tc.link; n > 0 {
				if _, ok := pkt.LinkHeader().Consume(n); !ok {
					t.Fatalf("pk.LinkHeader().Consume() = false, want true")
				}
			}
			if n := tc.network; n > 0 {
				if _, ok := pkt.NetworkHeader().Consume(n); !ok {
					t.Fatalf("pk.NetworkHeader().Consume() = false, want true")
				}
			}
			if n := tc.transport; n > 0 {
				if _, ok := pkt.TransportHeader().Consume(n); !ok {
					t.Fatalf("pk.TransportHeader().Consume() = false, want true")
				}
			}

			// Offsets of each region inside tc.data.
			netStart := tc.link
			tpStart := netStart + tc.network
			hdrLen := tpStart + tc.transport

			// Consumed headers live in the data, not in reserved headroom.
			if got, want := pkt.ReservedHeaderBytes(), 0; got != want {
				t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want)
			}
			if got, want := pkt.AvailableHeaderBytes(), 0; got != want {
				t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want)
			}
			if got, want := pkt.HeaderSize(), hdrLen; got != want {
				t.Errorf("After pk.HeaderSize() = %d, want %d", got, want)
			}
			if got, want := pkt.Size(), len(tc.data); got != want {
				t.Errorf("After pk.Size() = %d, want %d", got, want)
			}

			link := tc.data[:netStart]
			network := tc.data[netStart:tpStart]
			transport := tc.data[tpStart:hdrLen]
			payload := tc.data[hdrLen:]

			checkData(t, pkt, payload)
			checkViewEqual(t, "After pk.Views()", concatViews(pkt.Views()...), tc.data)

			checkPacketHeader(t, "After pk.LinkHeader", pkt.LinkHeader(), link)
			checkPacketHeader(t, "After pk.NetworkHeader", pkt.NetworkHeader(), network)
			checkPacketHeader(t, "After pk.TransportHeader", pkt.TransportHeader(), transport)

			checkViewEqual(t, "After PayloadSince(LinkHeader)",
				PayloadSince(pkt.LinkHeader()),
				concatViews(link, network, transport, payload))
			checkViewEqual(t, "After PayloadSince(NetworkHeader)",
				PayloadSince(pkt.NetworkHeader()),
				concatViews(network, transport, payload))
			checkViewEqual(t, "After PayloadSince(TransportHeader)",
				PayloadSince(pkt.TransportHeader()),
				concatViews(transport, payload))
		})
	}
}
// TestPacketHeaderConsumeDataTooShort checks that Consume reports failure for
// every header type when asked for more bytes than the packet data holds. It
// also checks that such failures leave the packet in its initial state.
func TestPacketHeaderConsumeDataTooShort(t *testing.T) {
	payload := makeView(10)
	// The packet gets a private copy so the expected bytes stay untouched.
	pkt := NewPacketBuffer(PacketBufferOptions{
		Data: buffer.NewViewFromBytes(payload).ToVectorisedView(),
	})

	// Request one byte more than the packet carries.
	tooMany := len(payload) + 1
	if _, ok := pkt.LinkHeader().Consume(tooMany); ok {
		t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false")
	}
	if _, ok := pkt.NetworkHeader().Consume(tooMany); ok {
		t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false")
	}
	if _, ok := pkt.TransportHeader().Consume(tooMany); ok {
		t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false")
	}

	// Failed Consume calls must not have changed anything.
	checkInitialPacketBuffer(t, pkt, PacketBufferOptions{
		Data: buffer.View(payload).ToVectorisedView(),
	})
}
// TestPacketHeaderPushCalledAtMostOnce checks that pushing the same header a
// second time panics, for every header type.
func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) {
	const headerSize = 10
	// Reserve headerSize bytes of headroom per header type so that every
	// first Push fits.
	pkt := NewPacketBuffer(PacketBufferOptions{
		ReserveHeaderBytes: headerSize * int(numHeaderType),
	})
	headers := []PacketHeader{
		pkt.TransportHeader(),
		pkt.NetworkHeader(),
		pkt.LinkHeader(),
	}
	for _, hdr := range headers {
		t.Run("PushedTwice/"+hdr.typ.String(), func(t *testing.T) {
			hdr.Push(headerSize)
			// Registered only after the first Push, so only a panic from the
			// second Push is recovered and counted as success.
			defer func() {
				if recover() == nil {
					t.Error("Second push should have panicked")
				}
			}()
			hdr.Push(headerSize)
		})
	}
}
// TestPacketHeaderConsumeCalledAtMostOnce checks that consuming the same header
// a second time panics, for every header type.
func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) {
	const headerSize = 10
	// The payload is sized at headerSize bytes per header type.
	pkt := NewPacketBuffer(PacketBufferOptions{
		Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(),
	})
	headers := []PacketHeader{
		pkt.LinkHeader(),
		pkt.NetworkHeader(),
		pkt.TransportHeader(),
	}
	for _, hdr := range headers {
		t.Run("ConsumedTwice/"+hdr.typ.String(), func(t *testing.T) {
			if _, ok := hdr.Consume(headerSize); !ok {
				t.Fatal("First consume should succeed")
			}
			// Only a panic from the second Consume is recovered.
			defer func() {
				if recover() == nil {
					t.Error("Second consume should have panicked")
				}
			}()
			hdr.Consume(headerSize)
		})
	}
}
// TestPacketHeaderPushThenConsumePanics checks that Consume panics on a header
// that was already pushed, for every header type.
func TestPacketHeaderPushThenConsumePanics(t *testing.T) {
	const headerSize = 10
	pkt := NewPacketBuffer(PacketBufferOptions{
		ReserveHeaderBytes: headerSize * int(numHeaderType),
	})
	headers := [...]PacketHeader{
		pkt.TransportHeader(),
		pkt.NetworkHeader(),
		pkt.LinkHeader(),
	}
	for _, hdr := range headers {
		t.Run(hdr.typ.String(), func(t *testing.T) {
			hdr.Push(headerSize)
			// Only a panic from the Consume below is recovered.
			defer func() {
				if recover() == nil {
					t.Error("Consume should have panicked")
				}
			}()
			hdr.Consume(headerSize)
		})
	}
}
// TestPacketHeaderConsumeThenPushPanics checks that Push panics on a header
// that was already consumed, for every header type.
func TestPacketHeaderConsumeThenPushPanics(t *testing.T) {
	const headerSize = 10
	// The data gives each header something to consume. Headroom is reserved as
	// well, so that each Push has room to succeed if it were allowed. Without
	// the reservation there is no headroom at all, and the test could not tell
	// a rejected consume-then-push apart from a Push that simply lacks space.
	pk := NewPacketBuffer(PacketBufferOptions{
		ReserveHeaderBytes: headerSize * int(numHeaderType),
		Data:               makeView(headerSize * int(numHeaderType)).ToVectorisedView(),
	})
	for _, h := range []PacketHeader{
		pk.LinkHeader(),
		pk.NetworkHeader(),
		pk.TransportHeader(),
	} {
		t.Run(h.typ.String(), func(t *testing.T) {
			// If Consume failed, a later panic would prove nothing about the
			// consume-then-push rule, so require it to succeed.
			if _, ok := h.Consume(headerSize); !ok {
				t.Fatal("Consume should succeed")
			}
			// Only a panic from the Push below is recovered.
			defer func() {
				if recover() == nil {
					t.Error("Push should have panicked")
				}
			}()
			h.Push(headerSize)
		})
	}
}
// TestPacketBufferData exercises the accessor returned by pkt.Data()
// (PullUp, TrimFront, CapLength, Views, AppendView, ReadFromData, ReadFromVV,
// ExtractVV and Replace). It runs them against two packets that carry the same
// payload, "DATA":
//   - an inbound packet, whose headers were consumed from the front of its data;
//   - an outbound packet, whose headers were pushed into reserved headroom.
//
// In both cases pkt.Data() is expected to cover only the payload, never the
// headers.
func TestPacketBufferData(t *testing.T) {
for _, tc := range []struct {
name string
// makePkt builds a brand-new packet. Every subtest calls it, so mutations made
// by one subtest cannot leak into another.
makePkt func(*testing.T) *PacketBuffer
// data is the payload pkt.Data() should hold before any mutation.
data string
}{
{
name: "inbound packet",
makePkt: func(*testing.T) *PacketBuffer {
// Raw bytes: 2-byte link ("aa"), 4-byte network ("bbbb"),
// 6-byte transport ("cccccc") header, then the payload.
pkt := NewPacketBuffer(PacketBufferOptions{
Data: vv("aabbbbccccccDATA"),
})
pkt.LinkHeader().Consume(2)
pkt.NetworkHeader().Consume(4)
pkt.TransportHeader().Consume(6)
return pkt
},
data: "DATA",
},
{
name: "outbound packet",
makePkt: func(*testing.T) *PacketBuffer {
// 12 bytes of headroom hold exactly the 6+4+2 header bytes
// pushed below (innermost header first).
pkt := NewPacketBuffer(PacketBufferOptions{
ReserveHeaderBytes: 12,
Data: vv("DATA"),
})
copy(pkt.TransportHeader().Push(6), []byte("cccccc"))
copy(pkt.NetworkHeader().Push(4), []byte("bbbb"))
copy(pkt.LinkHeader().Push(2), []byte("aa"))
return pkt
},
data: "DATA",
},
} {
t.Run(tc.name, func(t *testing.T) {
// PullUp
// Pulling up a prefix of the payload returns exactly that prefix.
for _, n := range []int{1, len(tc.data)} {
t.Run(fmt.Sprintf("PullUp%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
v, ok := pkt.Data().PullUp(n)
wantV := []byte(tc.data)[:n]
if !ok || !bytes.Equal(v, wantV) {
t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want %q, true", n, v, ok, wantV)
}
})
}
// Asking for one byte more than the payload must fail and return nil.
t.Run("PullUpOutOfBounds", func(t *testing.T) {
n := len(tc.data) + 1
pkt := tc.makePkt(t)
v, ok := pkt.Data().PullUp(n)
if ok || v != nil {
t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want nil, false", n, v, ok)
}
})
// TrimFront
// Trimming n bytes leaves the remaining payload suffix.
for _, n := range []int{1, len(tc.data)} {
t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
pkt.Data().TrimFront(n)
checkData(t, pkt, []byte(tc.data)[n:])
})
}
// CapLength
// Capping to n keeps the first n bytes; a cap at or above the current
// size leaves the payload unchanged.
for _, n := range []int{0, 1, len(tc.data)} {
t.Run(fmt.Sprintf("CapLength%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
pkt.Data().CapLength(n)
want := []byte(tc.data)
if n < len(want) {
want = want[:n]
}
checkData(t, pkt, want)
})
}
// Views
// An unmodified packet exposes exactly the payload.
t.Run("Views", func(t *testing.T) {
pkt := tc.makePkt(t)
checkData(t, pkt, []byte(tc.data))
})
// AppendView
// Appended bytes follow the existing payload.
t.Run("AppendView", func(t *testing.T) {
s := "APPEND"
pkt := tc.makePkt(t)
pkt.Data().AppendView(buffer.View(s))
checkData(t, pkt, []byte(tc.data+s))
})
// ReadFromData/VV
// The source holds "TO READ" twice (14 bytes in two views). Reading n bytes
// appends the first min(n, 14) of them, so n = 20 covers a request larger
// than the source, and n = 7/10 cover reads ending at or crossing the view
// boundary.
for _, n := range []int{0, 1, 2, 7, 10, 14, 20} {
t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) {
s := "TO READ"
otherPkt := NewPacketBuffer(PacketBufferOptions{
Data: vv(s, s),
})
s += s
pkt := tc.makePkt(t)
pkt.Data().ReadFromData(otherPkt.Data(), n)
if n < len(s) {
s = s[:n]
}
checkData(t, pkt, []byte(tc.data+s))
})
t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) {
s := "TO READ"
srcVV := vv(s, s)
s += s
pkt := tc.makePkt(t)
pkt.Data().ReadFromVV(&srcVV, n)
if n < len(s) {
s = s[:n]
}
checkData(t, pkt, []byte(tc.data+s))
})
}
// ExtractVV
// The extracted VectorisedView must hold the payload only, not the headers.
t.Run("ExtractVV", func(t *testing.T) {
pkt := tc.makePkt(t)
extractedVV := pkt.Data().ExtractVV()
got := extractedVV.ToOwnedView()
want := []byte(tc.data)
if !bytes.Equal(got, want) {
t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want)
}
})
// Replace
// Replace discards the previous payload entirely.
t.Run("Replace", func(t *testing.T) {
s := "REPLACED"
pkt := tc.makePkt(t)
pkt.Data().Replace(vv(s))
checkData(t, pkt, []byte(s))
})
})
}
}
// checkInitialPacketBuffer reports errors unless pk looks like a packet freshly
// built from opts:
//   - opts.ReserveHeaderBytes is both reserved and still available;
//   - there are no header bytes, and every header view is empty;
//   - the data, and PayloadSince for every header, equal opts.Data.
func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) {
	t.Helper()
	data := opts.Data.ToView()

	sizeChecks := []struct {
		format    string
		got, want int
	}{
		{"Initial pk.ReservedHeaderBytes() = %d, want %d", pk.ReservedHeaderBytes(), opts.ReserveHeaderBytes},
		{"Initial pk.AvailableHeaderBytes() = %d, want %d", pk.AvailableHeaderBytes(), opts.ReserveHeaderBytes},
		{"Initial pk.HeaderSize() = %d, want %d", pk.HeaderSize(), 0},
		{"Initial pk.Size() = %d, want %d", pk.Size(), len(data)},
	}
	for _, c := range sizeChecks {
		if c.got != c.want {
			t.Errorf(c.format, c.got, c.want)
		}
	}

	checkData(t, pk, data)
	checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data)

	// With nothing pushed or consumed, every header view should be empty and
	// the bytes from any header onwards should be the entire data.
	headers := []struct {
		hdr       PacketHeader
		viewName  string
		sinceName string
	}{
		{pk.LinkHeader(), "Initial pk.LinkHeader", "Initial PayloadSince(LinkHeader)"},
		{pk.NetworkHeader(), "Initial pk.NetworkHeader", "Initial PayloadSince(NetworkHeader)"},
		{pk.TransportHeader(), "Initial pk.TransportHeader", "Initial PayloadSince(TransportHeader)"},
	}
	for _, h := range headers {
		checkPacketHeader(t, h.viewName, h.hdr, nil)
	}
	for _, h := range headers {
		checkViewEqual(t, h.sinceName, PayloadSince(h.hdr), data)
	}
}
// checkPacketHeader reports an error, labelled "<name>.View()", if the bytes
// returned by h.View() differ from want.
func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) {
	t.Helper()
	label := name + ".View()"
	checkViewEqual(t, label, h.View(), want)
}
// checkViewEqual reports an error describing what, with both values in hex,
// when got and want do not hold the same bytes. Nil and empty views compare
// equal.
func checkViewEqual(t *testing.T, what string, got, want buffer.View) {
	t.Helper()
	if bytes.Equal(got, want) {
		return
	}
	t.Errorf("%s = %x, want %x", what, got, want)
}
// checkData reports errors unless pkt.Data() holds exactly want. The check
// covers its concatenated views and its size. In an "AsRange" subtest, it also
// covers the full range plus sub-ranges and capped ranges taken at the
// boundaries 0, 1, len(want) and len(want)+1.
func checkData(t *testing.T, pkt *PacketBuffer, want []byte) {
	t.Helper()
	joined := concatViews(pkt.Data().Views()...)
	if !bytes.Equal(joined, want) {
		t.Errorf("pkt.Data().Views() = %x, want %x", joined, want)
	}
	if size := pkt.Data().Size(); size != len(want) {
		t.Errorf("pkt.Data().Size() = %d, want %d", size, len(want))
	}

	// The start, one byte in, the exact end, and one byte past the end.
	bounds := []int{0, 1, len(want), len(want) + 1}

	t.Run("AsRange", func(t *testing.T) {
		// Full range
		checkRange(t, pkt.Data().AsRange(), want)

		// SubRange
		for _, off := range bounds {
			t.Run(fmt.Sprintf("SubRange%d", off), func(t *testing.T) {
				// An offset at or past the end yields an empty range.
				var tail []byte
				if off < len(want) {
					tail = want[off:]
				}
				checkRange(t, pkt.Data().AsRange().SubRange(off), tail)
			})
		}

		// Capped
		for _, n := range bounds {
			t.Run(fmt.Sprintf("Capped%d", n), func(t *testing.T) {
				// A cap at or past the end keeps everything.
				head := want
				if n < len(head) {
					head = head[:n]
				}
				checkRange(t, pkt.Data().AsRange().Capped(n), head)
			})
		}
	})
}
// checkRange reports errors unless every accessor of r agrees with data:
//   - Size equals len(data);
//   - AsView and ToOwnedView return the same bytes as data;
//   - Checksum equals header.Checksum(data, 0).
func checkRange(t *testing.T, r Range, data []byte) {
	// Mark this as a helper like the other check* functions, so failures are
	// attributed to the caller rather than to this function.
	t.Helper()
	if got, want := r.Size(), len(data); got != want {
		t.Errorf("r.Size() = %d, want %d", got, want)
	}
	if got := r.AsView(); !bytes.Equal(got, data) {
		t.Errorf("r.AsView() = %x, want %x", got, data)
	}
	if got := r.ToOwnedView(); !bytes.Equal(got, data) {
		t.Errorf("r.ToOwnedView() = %x, want %x", got, data)
	}
	if got, want := r.Checksum(), header.Checksum(data, 0 /* initial */); got != want {
		t.Errorf("r.Checksum() = %x, want %x", got, want)
	}
}
// vv builds a VectorisedView with one view per piece, in the given order. Each
// piece is copied into its own view, and the view's size is the sum of the
// piece lengths.
func vv(pieces ...string) buffer.VectorisedView {
	var (
		views []buffer.View
		total int
	)
	for i := range pieces {
		total += len(pieces[i])
		views = append(views, buffer.View(pieces[i]))
	}
	return buffer.NewVectorisedView(total, views)
}
// makeView returns a view of length size in which every byte equals
// byte(size). The fill value therefore wraps modulo 256, which still gives
// views of different sizes distinguishable contents in the tests above.
func makeView(size int) buffer.View {
	v := make(buffer.View, size)
	for i := range v {
		v[i] = byte(size)
	}
	return v
}
// concatViews returns the bytes of all views joined in order. The result never
// aliases any input, because every byte is appended to a slice that starts out
// nil. It is nil when the views contain no bytes at all.
func concatViews(views ...buffer.View) buffer.View {
	var joined buffer.View
	for i := range views {
		joined = append(joined, views[i]...)
	}
	return joined
}