blob: 9a23f78b880e0fd87fd7194773ed2b2a0f939cd2 [file] [log] [blame]
// Copyright 2016 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dhcp
import (
"context"
"strings"
"testing"
"time"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/link/channel"
"github.com/google/netstack/tcpip/link/sniffer"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/stack"
"github.com/google/netstack/tcpip/transport/udp"
)
// Values shared by every test in this file: nicid is the NIC that
// createStack creates, and serverAddr (192.168.3.1) is the address assigned
// to that NIC and handed to the DHCP server as its ServerAddress.
const (
	nicid      = tcpip.NICID(1)
	serverAddr = tcpip.Address("\xc0\xa8\x03\x01")
)
// createStack returns an IPv4/UDP stack with a single NIC (nicid) that owns
// serverAddr and has a default route (all-zero 4-byte destination and mask)
// through that NIC.
//
// The NIC is backed by a channel link endpoint. A background goroutine injects
// every packet written to that endpoint back into it as an inbound packet, so a
// DHCP server and clients built on this stack talk to each other over a loopback.
// When the tests run with -v, the endpoint is wrapped with sniffer.New
// (presumably to log traffic).
func createStack(t *testing.T) *stack.Stack {
	t.Helper() // report setup failures at the calling test's line

	const defaultMTU = 65536
	id, linkEP := channel.New(256, defaultMTU, "")
	if testing.Verbose() {
		id = sniffer.New(id)
	}

	s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}).(*stack.Stack)
	if err := s.CreateNIC(nicid, id); err != nil {
		t.Fatal(err)
	}
	if err := s.AddAddress(nicid, ipv4.ProtocolNumber, serverAddr); err != nil {
		t.Fatal(err)
	}
	s.SetRouteTable([]tcpip.Route{{
		Destination: tcpip.Address(strings.Repeat("\x00", 4)),
		Mask:        tcpip.Address(strings.Repeat("\x00", 4)),
		Gateway:     "",
		NIC:         nicid,
	}})

	// Start the loopback only once setup has succeeded, so a t.Fatal above
	// does not leave it behind. Packets written in the meantime wait in
	// linkEP.C (capacity 256).
	// NOTE(review): nothing here closes linkEP.C, so this goroutine still
	// outlives the test that created the stack.
	go func() {
		for pkt := range linkEP.C {
			v := make(buffer.View, len(pkt.Header)+len(pkt.Payload))
			copy(v, pkt.Header)
			copy(v[len(pkt.Header):], pkt.Payload)
			vv := v.ToVectorisedView([1]buffer.View{})
			linkEP.Inject(pkt.Proto, &vv)
		}
	}()

	return s
}
// TestDHCP runs a server with a two-address pool and checks that:
//   - the first client is given clientAddrs[0], and keeps it when it
//     requests again;
//   - a second client is given clientAddrs[1];
//   - the first client still holds clientAddrs[0] after requesting again;
//   - the configuration the client reports via Config() equals the server's.
func TestDHCP(t *testing.T) {
	s := createStack(t)
	clientAddrs := []tcpip.Address{"\xc0\xa8\x03\x02", "\xc0\xa8\x03\x03"}
	serverCfg := Config{
		ServerAddress:    serverAddr,
		SubnetMask:       "\xff\xff\xff\x00",
		Gateway:          "\xc0\xa8\x03\xF0",
		DomainNameServer: "\x08\x08\x08\x08",
		LeaseLength:      24 * time.Hour,
	}
	serverCtx, cancel := context.WithCancel(context.Background())
	defer cancel()
	if _, err := NewServer(serverCtx, s, clientAddrs, serverCfg); err != nil {
		t.Fatal(err)
	}

	// Bound every Request so that a lost packet or a client bug fails this
	// test instead of hanging it until the global `go test` timeout.
	reqCtx, reqCancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer reqCancel()

	// First client: initial request, then a repeat request for the same address.
	const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
	c0 := NewClient(s, nicid, clientLinkAddr0, nil)
	if err := c0.Request(reqCtx, ""); err != nil {
		t.Fatal(err)
	}
	if got, want := c0.Address(), clientAddrs[0]; got != want {
		t.Errorf("c0.Address()=%s, want=%s", got, want)
	}
	if err := c0.Request(reqCtx, ""); err != nil {
		t.Fatal(err)
	}
	if got, want := c0.Address(), clientAddrs[0]; got != want {
		t.Errorf("c0.Address()=%s, want=%s", got, want)
	}

	// Second client: must receive the next address in the pool.
	const clientLinkAddr1 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x53")
	c1 := NewClient(s, nicid, clientLinkAddr1, nil)
	if err := c1.Request(reqCtx, ""); err != nil {
		t.Fatal(err)
	}
	if got, want := c1.Address(), clientAddrs[1]; got != want {
		t.Errorf("c1.Address()=%s, want=%s", got, want)
	}

	// First client again: its address must not have been handed to c1.
	if err := c0.Request(reqCtx, ""); err != nil {
		t.Fatal(err)
	}
	if got, want := c0.Address(), clientAddrs[0]; got != want {
		t.Errorf("c0.Address()=%s, want=%s", got, want)
	}
	if got, want := c0.Config(), serverCfg; got != want {
		t.Errorf("client config:\n\t%#+v\nwant:\n\t%#+v", got, want)
	}
}
// TestRenew starts a client against a server whose leases last one second.
// It checks that the client acquires an address and then renews it, keeping
// the same address.
//
// The acquisition callback runs on the client's goroutine, where t.Fatal must
// not be called. The callback therefore only forwards what it observed over a
// channel, and all assertions run on the test goroutine.
func TestRenew(t *testing.T) {
	s := createStack(t)
	clientAddrs := []tcpip.Address{"\xc0\xa8\x03\x02"}
	serverCfg := Config{
		ServerAddress:    serverAddr,
		SubnetMask:       "\xff\xff\xff\x00",
		Gateway:          "\xc0\xa8\x03\xF0",
		DomainNameServer: "\x08\x08\x08\x08",
		LeaseLength:      1 * time.Second,
	}
	serverCtx, cancel := context.WithCancel(context.Background())
	defer cancel()
	if _, err := NewServer(serverCtx, s, clientAddrs, serverCfg); err != nil {
		t.Fatal(err)
	}

	// acquisition records the arguments of one call to the acquired callback.
	type acquisition struct {
		oldAddr, newAddr tcpip.Address
		cfg              Config
	}
	acquisitions := make(chan acquisition)
	// done is closed when the test stops listening. This releases any callback
	// blocked on sending a later acquisition (e.g. a second renewal), so the
	// client goroutine is never stuck and Shutdown can complete.
	done := make(chan struct{})
	acquiredFunc := func(oldAddr, newAddr tcpip.Address, cfg Config) {
		select {
		case acquisitions <- acquisition{oldAddr: oldAddr, newAddr: newAddr, cfg: cfg}:
		case <-done:
		}
	}

	const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
	c := NewClient(s, nicid, clientLinkAddr0, acquiredFunc)
	c.Start()
	// Runs on success and on t.Fatal alike. done must be closed before
	// Shutdown so that a callback blocked on send can return.
	defer func() {
		close(done)
		c.Shutdown()
	}()

	// next waits up to 5s for acquisition number count and validates it: it
	// must carry no error, report curAddr as the previous address, and carry a
	// one-second lease. It returns the newly acquired address.
	var curAddr tcpip.Address
	next := func(count int, timeoutMsg string) tcpip.Address {
		var acq acquisition
		select {
		case acq = <-acquisitions:
		case <-time.After(5 * time.Second):
			t.Fatal(timeoutMsg)
		}
		if err := acq.cfg.Error; err != nil {
			t.Fatalf("acquisition %d failed: %v", count, err)
		}
		if acq.oldAddr != curAddr {
			t.Fatalf("acquisition %d: curAddr=%v, oldAddr=%v", count, curAddr, acq.oldAddr)
		}
		if acq.cfg.LeaseLength != time.Second {
			t.Fatalf("acquisition %d: lease length: %v, want %v", count, acq.cfg.LeaseLength, time.Second)
		}
		curAddr = acq.newAddr
		return acq.newAddr
	}

	addr := next(0, "timeout acquiring initial address")
	t.Logf("got first address: %v", addr)

	newAddr := next(1, "timeout waiting for address renewal")
	t.Logf("got renewal: %v", newAddr)
	if newAddr != addr {
		t.Fatalf("renewal address is %v, want %v", newAddr, addr)
	}
}
// Regression test for https://fuchsia.atlassian.net/browse/NET-17
//
// TestNoNullTerminator decodes a captured DHCP reply (op 2, magic cookie
// 63 82 53 63) whose option list ends with option 43 (0x2b) carrying the
// 15-byte string "ANDROID_METERED", with no trailing NUL, followed by the end
// option 0xff. It checks that the header validates, that its opcode is
// opReply, and that h.options() parses without error.
// NOTE(review): the packet looks like it came from an Android hotspot — confirm
// against NET-17.
func TestNoNullTerminator(t *testing.T) {
	v := "\x02\x01\x06\x00" +
		"\xc8\x37\xbe\x73\x00\x00\x80\x00\x00\x00\x00\x00\xc0\xa8\x2b\x92" +
		"\xc0\xa8\x2b\x01\x00\x00\x00\x00\x00\x0f\x60\x0a\x23\x93\x00\x00" +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
		"\x00\x00\x00\x00\x00\x00\x00\x00\x63\x82\x53\x63\x35\x01\x02\x36" +
		"\x04\xc0\xa8\x2b\x01\x33\x04\x00\x00\x0e\x10\x3a\x04\x00\x00\x07" +
		"\x08\x3b\x04\x00\x00\x0c\x4e\x01\x04\xff\xff\xff\x00\x1c\x04\xc0" +
		"\xa8\x2b\xff\x03\x04\xc0\xa8\x2b\x01\x06\x04\xc0\xa8\x2b\x01\x2b" +
		"\x0f\x41\x4e\x44\x52\x4f\x49\x44\x5f\x4d\x45\x54\x45\x52\x45\x44" +
		"\xff"
	h := header(v)
	if !h.isValid() {
		// The accessors below assume a valid header, so stop here rather
		// than risk a confusing failure or panic on a bad decode.
		t.Fatal("failed to decode header")
	}
	if op := h.op(); op != opReply {
		t.Errorf("bad opcode: %v expected: %v", op, opReply)
	}
	if _, err := h.options(); err != nil {
		t.Errorf("bad options: %v", err)
	}
}