blob: 3a04f605fb937c1aa9a628d103c002b7f73ffe4a [file] [log] [blame]
// Copyright 2016 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ipv4_test
import (
"context"
"encoding/binary"
"testing"
"time"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/link/channel"
"github.com/google/netstack/tcpip/link/sniffer"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/stack"
"github.com/google/netstack/waiter"
)
const stackAddr = "\x0a\x00\x00\x01"
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
s *stack.Stack
}
// newTestContext builds a stack with the IPv4 network protocol and the IPv4
// ping transport protocol. It attaches one channel link endpoint as NIC 1,
// wrapped in a sniffer when tests run verbosely, and assigns stackAddr to it.
// It then installs a single route with an all-zero destination and mask via
// NIC 1. Any setup failure aborts the test.
func newTestContext(t *testing.T) *testContext {
	const defaultMTU = 65536

	stk := stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName})

	linkID, ep := channel.New(256, defaultMTU, "")
	if testing.Verbose() {
		linkID = sniffer.New(linkID)
	}
	if err := stk.CreateNIC(1, linkID); err != nil {
		t.Fatalf("CreateNIC failed: %v", err)
	}
	if err := stk.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
		t.Fatalf("AddAddress failed: %v", err)
	}

	// An all-zero destination with an all-zero mask: a default route.
	defaultRoute := tcpip.Route{
		Destination: "\x00\x00\x00\x00",
		Mask:        "\x00\x00\x00\x00",
		Gateway:     "",
		NIC:         1,
	}
	stk.SetRouteTable([]tcpip.Route{defaultRoute})

	return &testContext{t: t, s: stk, linkEP: ep}
}
// cleanup closes the channel on which the link endpoint publishes the packets
// written to it. This ends the range loop of the goroutine started by
// loopback, if any.
func (c *testContext) cleanup() {
	pkts := c.linkEP.C
	close(pkts)
}
// loopback starts a goroutine that takes every packet written to the link
// endpoint and injects it straight back into the same endpoint as an inbound
// packet, so the stack receives its own traffic. The goroutine exits once
// cleanup closes the channel.
func (c *testContext) loopback() {
	go func() {
		for {
			pkt, ok := <-c.linkEP.C
			if !ok {
				return
			}
			// Concatenate header and payload into one freshly allocated view.
			hdrLen := len(pkt.Header)
			frame := make(buffer.View, hdrLen+len(pkt.Payload))
			copy(frame[:hdrLen], pkt.Header)
			copy(frame[hdrLen:], pkt.Payload)

			vv := frame.ToVectorisedView([1]buffer.View{})
			c.linkEP.Inject(pkt.Proto, &vv)
		}
	}()
}
// TestEcho pings the stack's own address once over the loopback link and
// expects the single reply to carry no error.
func TestEcho(t *testing.T) {
	c := newTestContext(t)
	defer c.cleanup()
	c.loopback()

	replies := make(chan ipv4.PingReply, 1)
	pinger := ipv4.Pinger{
		Stack: c.s,
		NICID: 1,
		Addr:  stackAddr,
		Wait:  10 * time.Millisecond,
		Count: 1, // one ping only
	}
	if err := pinger.Ping(context.Background(), replies); err != nil {
		t.Fatalf("icmp.Ping failed: %v", err)
	}

	if reply := <-replies; reply.Error != nil {
		t.Errorf("bad ping response: %v", reply.Error)
	}
}
// TestEchoSequence sends several pings over the loopback link. It checks that
// each reply has no error and that replies arrive with sequence numbers
// 0, 1, 2, ... in order.
func TestEchoSequence(t *testing.T) {
	const numPings = 3

	c := newTestContext(t)
	defer c.cleanup()
	c.loopback()

	replies := make(chan ipv4.PingReply, numPings)
	pinger := ipv4.Pinger{
		Stack: c.s,
		NICID: 1,
		Addr:  stackAddr,
		Wait:  10 * time.Millisecond,
		Count: numPings,
	}
	if err := pinger.Ping(context.Background(), replies); err != nil {
		t.Fatalf("icmp.Ping failed: %v", err)
	}

	var want uint16
	for want < numPings {
		reply := <-replies
		if reply.Error != nil {
			t.Errorf("i=%d bad ping response: %v", want, reply.Error)
		}
		if reply.SeqNumber != want {
			t.Errorf("SeqNumber=%d, want %d", reply.SeqNumber, want)
		}
		want++
	}
}
// IPv4 addresses used by newTestEndpointContext. Note that they are
// cross-assigned there: NIC 1 holds stackAddr1 and NIC 2 holds stackAddr0.
const (
	stackAddr0 = "\x0a\x00\x00\x02" // 10.0.0.2
	stackAddr1 = "\x0a\x00\x00\x03" // 10.0.0.3
)

// Link addresses of the channel endpoints behind NIC 1 and NIC 2 respectively.
const (
	linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
	linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
)
// testEndpointContext is a stack with two channel-backed NICs whose outgoing
// packets are forwarded to each other by routePackets. icmpCh receives the
// ICMPv4 message of every forwarded packet so the test can count them.
type testEndpointContext struct {
	t                *testing.T
	s                *stack.Stack
	linkEP0, linkEP1 *channel.Endpoint
	icmpCh           chan header.ICMPv4
}
// cleanup closes the packet channels of both link endpoints, NIC 1's first.
// This terminates the range loops in the routePackets goroutines.
func (c *testEndpointContext) cleanup() {
	for _, ep := range []*channel.Endpoint{c.linkEP0, c.linkEP1} {
		close(ep.C)
	}
}
// newTestEndpointContext builds a stack with two channel link endpoints:
//   - NIC 1 has link address linkAddr0 and IPv4 address stackAddr1.
//   - NIC 2 has link address linkAddr1 and IPv4 address stackAddr0.
//
// Host routes send stackAddr0 out NIC 1 and stackAddr1 out NIC 2. Two
// goroutines forward each endpoint's outgoing packets into the other, so
// traffic between the two addresses crosses the emulated link. Any setup
// failure aborts the test.
func newTestEndpointContext(t *testing.T) *testEndpointContext {
	const defaultMTU = 65536

	c := &testEndpointContext{
		t:      t,
		s:      stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName}),
		icmpCh: make(chan header.ICMPv4, 10),
	}

	nic1ID, nic1EP := channel.New(256, defaultMTU, linkAddr0)
	c.linkEP0 = nic1EP
	if testing.Verbose() {
		nic1ID = sniffer.New(nic1ID)
	}
	if err := c.s.CreateNIC(1, nic1ID); err != nil {
		t.Fatalf("CreateNIC s: %v", err)
	}

	nic2ID, nic2EP := channel.New(256, defaultMTU, linkAddr1)
	c.linkEP1 = nic2EP
	if testing.Verbose() {
		nic2ID = sniffer.New(nic2ID)
	}
	if err := c.s.CreateNIC(2, nic2ID); err != nil {
		t.Fatalf("CreateNIC s: %v", err)
	}

	if err := c.s.AddAddress(2, ipv4.ProtocolNumber, stackAddr0); err != nil {
		t.Fatalf("AddAddress failed: %v", err)
	}
	if err := c.s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
		t.Fatalf("AddAddress failed: %v", err)
	}

	// One all-ones-mask (single address) route per peer.
	const hostMask = "\xFF\xFF\xFF\xFF"
	c.s.SetRouteTable([]tcpip.Route{
		{Destination: stackAddr0, Mask: hostMask, Gateway: "", NIC: 1},
		{Destination: stackAddr1, Mask: hostMask, Gateway: "", NIC: 2},
	})

	// Cross-connect the links: packets written to NIC 1's endpoint are
	// injected into NIC 2's endpoint, and vice versa.
	go c.routePackets(c.linkEP0.C, c.linkEP1)
	go c.routePackets(c.linkEP1.C, c.linkEP0)
	return c
}
// countPacket publishes a copy of the ICMPv4 message carried by pkt (the bytes
// after the IPv4 header, followed by the payload) on c.icmpCh so the test
// goroutine can tally it.
//
// It runs on the routePackets goroutines, where t.Fatal/FailNow must not be
// called. Malformed packets are therefore reported with t.Errorf and dropped
// rather than uncounted-and-fatal. The message is built in a freshly allocated
// slice so that it never aliases pkt.Header, which routePackets goes on to
// inject into the stack.
func (c *testEndpointContext) countPacket(pkt channel.PacketInfo) {
	if pkt.Proto != header.IPv4ProtocolNumber {
		c.t.Errorf("Received non IPV4 packet: 0x%x", pkt.Proto)
		return
	}
	if len(pkt.Header) < header.IPv4MinimumSize {
		c.t.Errorf("IPv4 header too short: %d bytes", len(pkt.Header))
		return
	}
	ip := header.IPv4(pkt.Header)
	hdrLen := int(ip.HeaderLength())
	if hdrLen > len(pkt.Header) {
		c.t.Errorf("IPv4 header length %d exceeds %d available bytes", hdrLen, len(pkt.Header))
		return
	}

	msg := make(header.ICMPv4, 0, len(pkt.Header)-hdrLen+len(pkt.Payload))
	msg = append(msg, pkt.Header[hdrLen:]...)
	msg = append(msg, pkt.Payload...)
	c.icmpCh <- msg
}
// routePackets passes every packet received on ch to countPacket and then
// injects it into ep, passing ep's own link address to InjectLinkAddr. It
// returns when ch is closed (see cleanup).
func (c *testEndpointContext) routePackets(ch <-chan channel.PacketInfo, ep *channel.Endpoint) {
	for pkt := range ch {
		c.countPacket(pkt)

		// Build the frame in a fresh buffer, as loopback does. Appending onto
		// pkt.Header would write into its backing array whenever it has spare
		// capacity, aliasing memory the sender still owns.
		hdrLen := len(pkt.Header)
		v := make(buffer.View, hdrLen+len(pkt.Payload))
		copy(v, pkt.Header)
		copy(v[hdrLen:], pkt.Payload)

		vv := buffer.NewVectorisedView(len(v), []buffer.View{v})
		ep.InjectLinkAddr(pkt.Proto, ep.LinkAddress(), &vv)
	}
}
// callbackStub lets a test install an ad-hoc function as the Callback of a
// waiter.Entry.
type callbackStub struct {
	f func(*waiter.Entry)
}

// Callback invokes the wrapped function with e.
func (c *callbackStub) Callback(e *waiter.Entry) {
	fn := c.f
	fn(e)
}
// TestEndpoints has two ping endpoints, bound to different NICs of one stack,
// each send echo requests to the other side. It checks every reply the
// endpoints read. It also counts the ICMPv4 messages crossing the routed link
// and expects exactly echos*2 requests and echos*2 replies.
//
// Each endpoint is driven by its own goroutine. t.Fatal/FailNow may only be
// called from the test goroutine, so the pinger goroutines report failures
// with t.Errorf and stop. The test goroutine always waits for both pingers to
// exit before the endpoints are closed, so no pinger reads a closed endpoint
// or logs after the test has completed.
func TestEndpoints(t *testing.T) {
	c := newTestEndpointContext(t)
	defer c.cleanup()

	wq0 := &waiter.Queue{}
	ep0, err := c.s.NewEndpoint(ipv4.PingProtocolNumber, ipv4.ProtocolNumber, wq0)
	if err != nil {
		t.Fatalf("NewEndpoint failed: %v", err)
	}
	defer ep0.Close()

	wq1 := &waiter.Queue{}
	ep1, err := c.s.NewEndpoint(ipv4.PingProtocolNumber, ipv4.ProtocolNumber, wq1)
	if err != nil {
		t.Fatalf("NewEndpoint failed: %v", err)
	}
	defer ep1.Close()

	if err := ep0.Bind(tcpip.FullAddress{NIC: 1, Addr: stackAddr0}, nil); err != nil {
		t.Fatalf("Bind failed: %v", err)
	}
	if err := ep1.Bind(tcpip.FullAddress{NIC: 2, Addr: stackAddr1}, nil); err != nil {
		t.Fatalf("Bind failed: %v", err)
	}

	// echos is the number of echo requests each endpoint sends.
	const echos = 64
	// pingers is the number of pinger goroutines started below.
	const pingers = 2

	// stop is closed to make pingers give up early; pingDone receives one
	// value as each pinger exits.
	stop := make(chan struct{})
	pingDone := make(chan struct{}, pingers)

	// pingOnce writes outPkt with sequence number seqno on ep and validates
	// the reply read back from it. It reports problems with t.Errorf (safe
	// off the test goroutine) and returns false when the pinger should stop.
	pingOnce := func(wq *waiter.Queue, ep tcpip.Endpoint, outPkt []byte, seqno uint16) bool {
		binary.BigEndian.PutUint16(outPkt[header.ICMPv4MinimumSize+2:], seqno)

		// Register with the waiter queue before writing, since the
		// notification that the endpoint received a response may arrive
		// immediately. A buffered channel with a non-blocking send tolerates
		// repeated notifications, where a second close(ready) would panic.
		ready := make(chan struct{}, 1)
		e := waiter.Entry{Callback: &callbackStub{func(*waiter.Entry) {
			select {
			case ready <- struct{}{}:
			default:
			}
		}}}
		wq.EventRegister(&e, waiter.EventIn)
		defer wq.EventUnregister(&e)

		n, err := ep.Write(buffer.View(outPkt), nil)
		if err != nil {
			t.Errorf("Write failed: %v", err)
			return false
		}
		if n != uintptr(len(outPkt)) {
			t.Errorf("Write was short: %v", n)
			return false
		}

		// Avoid reading until we have something to read.
		select {
		case <-stop:
			return false
		case <-time.After(1 * time.Second):
			t.Errorf("Timed out waiting for socket to be readable")
			return false
		case <-ready:
		}

		inPkt, err := ep.Read(nil)
		if err != nil {
			t.Errorf("Read failed: %v", err)
			return false
		}

		// Verify the contents of the packet we just read. Check the length
		// first: a short packet would otherwise panic on the slicing below
		// and take down the whole test binary.
		if len(inPkt) < header.ICMPv4MinimumSize+4 || len(inPkt) < header.ICMPv4EchoMinimumSize {
			t.Errorf("Read packet of unexpected length: %d", len(inPkt))
			return false
		}
		icmp := header.ICMPv4(inPkt)
		if icmp.Type() != header.ICMPv4EchoReply {
			t.Errorf("Unexpected packet type: %d", icmp.Type())
			return false
		}
		inSeqno := binary.BigEndian.Uint16(inPkt[header.ICMPv4MinimumSize+2 : header.ICMPv4MinimumSize+4])
		if inSeqno != seqno {
			t.Errorf("Unexpected sequence number: %d", inSeqno)
			return false
		}
		outData := outPkt[header.ICMPv4EchoMinimumSize:]
		inData := inPkt[header.ICMPv4EchoMinimumSize:]
		if len(outData) != len(inData) {
			t.Errorf("Read packet of unexpected length: %d", len(inData))
			return false
		}
		if string(inData) != string(outData) {
			t.Errorf("Data mismatch")
			return false
		}
		return true
	}

	// ping sends echos echo requests carrying data from ep, validating each
	// reply. It always signals pingDone when it exits.
	ping := func(wq *waiter.Queue, ep tcpip.Endpoint, data []byte) {
		defer func() { pingDone <- struct{}{} }()

		outPkt := make([]byte, header.ICMPv4MinimumSize+4+len(data))
		icmpv4 := header.ICMPv4(outPkt[:header.ICMPv4MinimumSize])
		icmpv4.SetType(header.ICMPv4Echo)
		copy(outPkt[header.ICMPv4MinimumSize+4:], data)
		// The two bytes ahead of the sequence number are left zero; the
		// sequence number itself is filled in by pingOnce.
		binary.BigEndian.PutUint16(outPkt[header.ICMPv4MinimumSize:], 0)
		for seqno := uint16(1); seqno <= echos; seqno++ {
			if !pingOnce(wq, ep, outPkt, seqno) {
				return
			}
		}
	}

	// This defer runs before the endpoint Close defers above (defers are
	// LIFO), so the pingers are finished before their endpoints are closed.
	// If the test has already failed, ask them to stop instead of letting
	// them run into their timeouts.
	defer func() {
		if t.Failed() {
			close(stop)
		}
		for i := 0; i < pingers; i++ {
			<-pingDone
		}
	}()
	go ping(wq0, ep0, []byte{0xaa, 0xab, 0xac})
	go ping(wq1, ep1, []byte{0xad, 0xae, 0xaf})

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	stats := make(map[header.ICMPv4Type]int)
	for {
		select {
		case <-ctx.Done():
			t.Errorf("Timeout waiting for ICMP, got: %#+v", stats)
			return
		case icmp := <-c.icmpCh:
			if icmp.Type() != header.ICMPv4Echo && icmp.Type() != header.ICMPv4EchoReply {
				t.Fatalf("Unexpected type: %d", icmp.Type())
			}
			stats[icmp.Type()]++
			if stats[icmp.Type()] > echos*2 {
				t.Fatalf("Too many (%d) packets of type %d", stats[icmp.Type()], icmp.Type())
			}
			if len(stats) == 2 && stats[header.ICMPv4Echo] == echos*2 && stats[header.ICMPv4EchoReply] == echos*2 {
				return
			}
		}
	}
}