dhcp: handle multiple DHCP servers

The client was not being careful enough and could be confused by
multiple DHCP servers on a network.

After sending a DISCOVER, the client would receive two OFFERs.
It would Read() the first, send a REQUEST, and then Read() again
waiting for an ACK. The first thing the second Read() got was a
second OFFER. This was being misinterpreted as a bad ACK, and
the process started over. A successful DHCP connection required
everything to be negotiated with one server before the other
responded, which could take minutes.

This CL explicitly checks the incoming message types, so
subsequent OFFER messages can be ignored.

It also fixes a small server bug: an incoming REQUEST for a
different server should be ignored, not NAKed.

This deals with the first problem in NET-43.

Change-Id: I133b1183c2a019041cd47d9bfaf4e09eb97cfcd3
diff --git a/dhcp/client.go b/dhcp/client.go
index 879c456..f474b33 100644
--- a/dhcp/client.go
+++ b/dhcp/client.go
@@ -137,7 +137,7 @@
 	rand.Read(xid[:])
 
 	// DHCPDISCOVERY
-	options := options{
+	discOpts := options{
 		{optDHCPMsgType, []byte{byte(dhcpDISCOVER)}},
 		{optParamReq, []byte{
 			1,  // request subnet mask
@@ -147,22 +147,22 @@
 		}},
 	}
 	if requestedAddr != "" {
-		options = append(options, option{optReqIPAddr, []byte(requestedAddr)})
+		discOpts = append(discOpts, option{optReqIPAddr, []byte(requestedAddr)})
 	}
 	var clientID []byte
 	if len(c.linkAddr) == 6 {
 		clientID = make([]byte, 7)
 		clientID[0] = 1 // htype: ARP Ethernet from RFC 1700
 		copy(clientID[1:], c.linkAddr)
-		options = append(options, option{optClientID, clientID})
+		discOpts = append(discOpts, option{optClientID, clientID})
 	}
-	h := make(header, headerBaseSize+options.len())
+	h := make(header, headerBaseSize+discOpts.len())
 	h.init()
 	h.setOp(opRequest)
 	copy(h.xidbytes(), xid[:])
 	h.setBroadcast()
 	copy(h.chaddr(), c.linkAddr)
-	h.setOptions(options)
+	h.setOptions(discOpts)
 
 	serverAddr := &tcpip.FullAddress{
 		Addr: "\xff\xff\xff\xff",
@@ -178,6 +178,7 @@
 	defer wq.EventUnregister(&we)
 
 	// DHCPOFFER
+	var opts options
 	for {
 		var addr tcpip.FullAddress
 		v, err := epin.Read(&addr)
@@ -190,13 +191,15 @@
 			}
 		}
 		h = header(v)
-		if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) {
-			break
+		var valid bool
+		opts, valid, err = loadDHCPReply(h, dhcpOFFER, xid[:])
+		if !valid {
+			if err != nil {
+				// TODO: report malformed server responses
+			}
+			continue
 		}
-	}
-	opts, err := h.options()
-	if err != nil {
-		return fmt.Errorf("dhcp offer: %v", err)
+		break
 	}
 
 	var ack bool
@@ -243,15 +246,15 @@
 	for i, b := 0, h.giaddr(); i < len(b); i++ {
 		b[i] = 0
 	}
-	options = []option{
+	reqOpts := []option{
 		{optDHCPMsgType, []byte{byte(dhcpREQUEST)}},
 		{optReqIPAddr, []byte(addr)},
 		{optDHCPServer, []byte(cfg.ServerAddress)},
 	}
 	if len(clientID) != 0 {
-		options = append(options, option{optClientID, clientID})
+		reqOpts = append(reqOpts, option{optClientID, clientID})
 	}
-	h.setOptions(options)
+	h.setOptions(reqOpts)
 	if _, err := ep.Write([]byte(h), serverAddr); err != nil {
 		return fmt.Errorf("dhcp discovery write: %v", err)
 	}
@@ -269,31 +272,23 @@
 			}
 		}
 		h = header(v)
-		if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) {
-			break
+		var valid bool
+		opts, valid, err = loadDHCPReply(h, dhcpACK, xid[:])
+		if !valid {
+			if err != nil {
+				// TODO: report malformed server responses
+			}
+			if opts, valid, _ = loadDHCPReply(h, dhcpNAK, xid[:]); valid {
+				if msg := opts.message(); msg != "" {
+					return fmt.Errorf("dhcp: NAK %q", msg)
+				}
+				return fmt.Errorf("dhcp: NAK with no message")
+			}
+			continue
 		}
+		break
 	}
-	opts, err = h.options()
-	if err != nil {
-		return fmt.Errorf("dhcp ack: %v", err)
-	}
-	if err := cfg.decode(opts); err != nil {
-		return fmt.Errorf("dhcp ack bad options: %v", err)
-	}
-	msgtype, err := opts.dhcpMsgType()
-	if err != nil {
-		return fmt.Errorf("dhcp ack: %v", err)
-	}
-	if msgtype == dhcpNAK {
-		if msg := opts.message(); msg != "" {
-			return fmt.Errorf("dhcp: NAK %q", msg)
-		}
-		return fmt.Errorf("dhcp: NAK with no message")
-	}
-	ack = msgtype == dhcpACK
-	if !ack {
-		return fmt.Errorf("dhcp: request not acknowledged")
-	}
+	ack = true
 	if cfg.LeaseLength != 0 {
 		go c.renewAfter(cfg.LeaseLength)
 	}
@@ -321,3 +316,21 @@
 		}
 	}()
 }
+
+func loadDHCPReply(h header, typ dhcpMsgType, xid []byte) (opts options, valid bool, err error) {
+	if !h.isValid() || h.op() != opReply || !bytes.Equal(h.xidbytes(), xid[:]) {
+		return nil, false, nil
+	}
+	opts, err = h.options()
+	if err != nil {
+		return nil, false, err
+	}
+	msgtype, err := opts.dhcpMsgType()
+	if err != nil {
+		return nil, false, err
+	}
+	if msgtype != typ {
+		return nil, false, nil
+	}
+	return opts, true, nil
+}
diff --git a/dhcp/dhcp_test.go b/dhcp/dhcp_test.go
index 92c5c41..0f6d5da 100644
--- a/dhcp/dhcp_test.go
+++ b/dhcp/dhcp_test.go
@@ -17,6 +17,7 @@
 	"github.com/google/netstack/tcpip/network/ipv4"
 	"github.com/google/netstack/tcpip/stack"
 	"github.com/google/netstack/tcpip/transport/udp"
+	"github.com/google/netstack/waiter"
 )
 
 const nicid = tcpip.NICID(1)
@@ -226,3 +227,88 @@
 		t.Errorf("bad options: %v", err)
 	}
 }
+
+func teeConn(c conn) (conn, conn) {
+	dup1 := &dupConn{
+		c:   c,
+		dup: make(chan connMsg, 8),
+	}
+	dup2 := &chConn{
+		c:  c,
+		ch: dup1.dup,
+	}
+	return dup1, dup2
+}
+
+type connMsg struct {
+	buf  buffer.View
+	addr tcpip.FullAddress
+	err  error
+}
+
+type dupConn struct {
+	c   conn
+	dup chan connMsg
+}
+
+func (c *dupConn) Read() (buffer.View, tcpip.FullAddress, error) {
+	v, addr, err := c.c.Read()
+	c.dup <- connMsg{v, addr, err}
+	return v, addr, err
+}
+func (c *dupConn) Write(b []byte, addr *tcpip.FullAddress) error { return c.c.Write(b, addr) }
+
+type chConn struct {
+	ch chan connMsg
+	c  conn
+}
+
+func (c *chConn) Read() (buffer.View, tcpip.FullAddress, error) {
+	msg := <-c.ch
+	return msg.buf, msg.addr, msg.err
+}
+func (c *chConn) Write(b []byte, addr *tcpip.FullAddress) error { return c.c.Write(b, addr) }
+
+func TestTwoServers(t *testing.T) {
+	s := createStack(t)
+
+	wq := new(waiter.Queue)
+	ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		t.Fatalf("dhcp: server endpoint: %v", err)
+	}
+	if err = ep.Bind(tcpip.FullAddress{Port: serverPort}, nil); err != nil {
+		t.Fatalf("dhcp: server bind: %v", err)
+	}
+
+	serverCtx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	c1, c2 := teeConn(newEPConn(serverCtx, wq, ep))
+
+	_, err = newServer(serverCtx, c1, []tcpip.Address{"\xc0\xa8\x03\x02"}, Config{
+		ServerAddress: "\xc0\xa8\x03\x01",
+		SubnetMask:    "\xff\xff\xff\x00",
+		Gateway:       "\xc0\xa8\x03\xF0",
+		DNS:           []tcpip.Address{"\x08\x08\x08\x08"},
+		LeaseLength:   30 * time.Minute,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	_, err = newServer(serverCtx, c2, []tcpip.Address{"\xc0\xa8\x04\x02"}, Config{
+		ServerAddress: "\xc0\xa8\x04\x01",
+		SubnetMask:    "\xff\xff\xff\x00",
+		Gateway:       "\xc0\xa8\x03\xF0",
+		DNS:           []tcpip.Address{"\x08\x08\x08\x08"},
+		LeaseLength:   30 * time.Minute,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
+	c := NewClient(s, nicid, clientLinkAddr0, nil)
+	if err := c.Request(context.Background(), ""); err != nil {
+		t.Fatal(err)
+	}
+}
diff --git a/dhcp/server.go b/dhcp/server.go
index 90ec04b..8162acc 100644
--- a/dhcp/server.go
+++ b/dhcp/server.go
@@ -7,6 +7,7 @@
 import (
 	"context"
 	"fmt"
+	"io"
 	"log"
 	"sync"
 	"time"
@@ -21,10 +22,8 @@
 
 // Server is a DHCP server.
 type Server struct {
-	stack     *stack.Stack
+	conn      conn
 	broadcast tcpip.FullAddress
-	wq        waiter.Queue
-	ep        tcpip.Endpoint
 	addrs     []tcpip.Address // TODO: use a tcpip.AddressMask or range structure
 	cfg       Config
 	cfgopts   []option // cfg to send to client
@@ -35,14 +34,79 @@
 	leases map[tcpip.LinkAddress]serverLease
 }
 
+// conn is a blocking read/write network endpoint.
+type conn interface {
+	Read() (buffer.View, tcpip.FullAddress, error)
+	Write([]byte, *tcpip.FullAddress) error
+}
+
+type epConn struct {
+	ctx  context.Context
+	wq   *waiter.Queue
+	ep   tcpip.Endpoint
+	we   waiter.Entry
+	inCh chan struct{}
+}
+
+func newEPConn(ctx context.Context, wq *waiter.Queue, ep tcpip.Endpoint) *epConn {
+	c := &epConn{
+		ctx: ctx,
+		wq:  wq,
+		ep:  ep,
+	}
+	c.we, c.inCh = waiter.NewChannelEntry(nil)
+	wq.EventRegister(&c.we, waiter.EventIn)
+
+	go func() {
+		<-ctx.Done()
+		wq.EventUnregister(&c.we)
+	}()
+
+	return c
+}
+
+func (c *epConn) Read() (buffer.View, tcpip.FullAddress, error) {
+	for {
+		var addr tcpip.FullAddress
+		v, err := c.ep.Read(&addr)
+		if err == tcpip.ErrWouldBlock {
+			select {
+			case <-c.inCh:
+				continue
+			case <-c.ctx.Done():
+				return nil, tcpip.FullAddress{}, io.EOF
+			}
+		}
+		return v, addr, err
+	}
+}
+
+func (c *epConn) Write(b []byte, addr *tcpip.FullAddress) error {
+	_, err := c.ep.Write(b, addr)
+	return err
+}
+
 // NewServer creates a new DHCP server and begins serving.
 // The server continues serving until ctx is done.
 func NewServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) {
+	wq := new(waiter.Queue)
+	ep, err := stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		return nil, fmt.Errorf("dhcp: server endpoint: %v", err)
+	}
+	if err := ep.Bind(tcpip.FullAddress{Port: serverPort}, nil); err != nil {
+		return nil, fmt.Errorf("dhcp: server bind: %v", err)
+	}
+	c := newEPConn(ctx, wq, ep)
+	return newServer(ctx, c, addrs, cfg)
+}
+
+func newServer(ctx context.Context, c conn, addrs []tcpip.Address, cfg Config) (*Server, error) {
 	if cfg.ServerAddress == "" {
 		return nil, fmt.Errorf("dhcp: server requires explicit server address")
 	}
 	s := &Server{
-		stack:   stack,
+		conn:    c,
 		addrs:   addrs,
 		cfg:     cfg,
 		cfgopts: cfg.encode(),
@@ -55,19 +119,6 @@
 		leases:   make(map[tcpip.LinkAddress]serverLease),
 	}
 
-	var err error
-	s.ep, err = s.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &s.wq)
-	if err != nil {
-		return nil, fmt.Errorf("dhcp: server endpoint: %v", err)
-	}
-	serverBroadcast := tcpip.FullAddress{
-		Addr: "",
-		Port: serverPort,
-	}
-	if err = s.ep.Bind(serverBroadcast, nil); err != nil {
-		return nil, fmt.Errorf("dhcp: server bind: %v", err)
-	}
-
 	for i := 0; i < len(s.handlers); i++ {
 		ch := make(chan header, 8)
 		s.handlers[i] = ch
@@ -102,20 +153,10 @@
 // reader listens for all incoming DHCP packets and fans them out to
 // handling goroutines based on XID as session identifiers.
 func (s *Server) reader(ctx context.Context) {
-	we, ch := waiter.NewChannelEntry(nil)
-	s.wq.EventRegister(&we, waiter.EventIn)
-	defer s.wq.EventUnregister(&we)
-
 	for {
-		var addr tcpip.FullAddress
-		v, err := s.ep.Read(&addr)
-		if err == tcpip.ErrWouldBlock {
-			select {
-			case <-ch:
-				continue
-			case <-ctx.Done():
-				return
-			}
+		v, _, err := s.conn.Read()
+		if err != nil {
+			return
 		}
 
 		h := header(v)
@@ -235,7 +276,7 @@
 	copy(h.yiaddr(), lease.addr)
 	copy(h.chaddr(), hreq.chaddr())
 	h.setOptions(opts)
-	s.ep.Write(buffer.View(h), &s.broadcast)
+	s.conn.Write([]byte(h), &s.broadcast)
 }
 
 func (s *Server) nack(hreq header) {
@@ -250,7 +291,7 @@
 	copy(h.xidbytes(), hreq.xidbytes())
 	copy(h.chaddr(), hreq.chaddr())
 	h.setOptions(opts)
-	s.ep.Write(buffer.View(h), &s.broadcast)
+	s.conn.Write([]byte(h), &s.broadcast)
 }
 
 func (s *Server) handleRequest(hreq header, opts options) {
@@ -268,7 +309,7 @@
 		return
 	}
 	if reqcfg.ServerAddress != s.cfg.ServerAddress {
-		s.nack(hreq)
+		// This request is for a different DHCP server. Ignore it.
 		return
 	}
 
@@ -301,7 +342,7 @@
 	copy(h.yiaddr(), lease.addr)
 	copy(h.chaddr(), hreq.chaddr())
 	h.setOptions(opts)
-	s.ep.Write(buffer.View(h), &s.broadcast)
+	s.conn.Write([]byte(h), &s.broadcast)
 }
 
 type leaseState int