dhcp: govern client lifecycle with a context

Also move the renewal logic into the same even loop
that drives the initial request, to make it a little
easier to understand.

Change-Id: I14197ebce925e746504e9b9fcfb424aad67ee6ba
diff --git a/dhcp/client.go b/dhcp/client.go
index f474b33..b266af0 100644
--- a/dhcp/client.go
+++ b/dhcp/client.go
@@ -9,7 +9,6 @@
 	"context"
 	"crypto/rand"
 	"fmt"
-	"log"
 	"sync"
 	"time"
 
@@ -47,20 +46,48 @@
 	}
 }
 
-// Start starts the DHCP client.
+// Run starts the DHCP client.
 // It will periodically search for an IP address using the Request method.
-func (c *Client) Start() {
-	go func() {
-		for {
-			ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
-			err := c.Request(ctx, "")
-			cancel()
-			if err == nil {
-				break
-			}
-			time.Sleep(1 * time.Second)
+func (c *Client) Run(ctx context.Context) {
+	go c.run(ctx)
+}
+
+func (c *Client) run(ctx context.Context) {
+	defer func() {
+		c.mu.Lock()
+		defer c.mu.Unlock()
+		if c.addr != "" {
+			c.stack.RemoveAddress(c.nicid, c.addr)
 		}
 	}()
+
+	var renewAddr tcpip.Address
+	for {
+		reqCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
+		cfg, err := c.Request(reqCtx, renewAddr)
+		cancel()
+		if err != nil {
+			select {
+			case <-time.After(1 * time.Second):
+				// loop and try again
+			case <-ctx.Done():
+				return
+			}
+		}
+
+		c.mu.Lock()
+		renewAddr = c.addr
+		c.mu.Unlock()
+
+		timer := time.NewTimer(cfg.LeaseLength)
+		select {
+		case <-ctx.Done():
+			timer.Stop()
+			return
+		case <-timer.C:
+			// loop and make a renewal request
+		}
+	}
 }
 
 // Address reports the IP address acquired by the DHCP client.
@@ -77,29 +104,17 @@
 	return c.cfg
 }
 
-// Shutdown relinquishes any lease and ends any outstanding renewal timers.
-func (c *Client) Shutdown() {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	if c.addr != "" {
-		c.stack.RemoveAddress(c.nicid, c.addr)
-	}
-	if c.cancelRenew != nil {
-		c.cancelRenew()
-	}
-}
-
 // Request executes a DHCP request session.
 //
 // On success, it adds a new address to this client's TCPIP stack.
 // If the server sets a lease limit a timer is set to automatically
 // renew it.
-func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (reterr error) {
+func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg Config, reterr error) {
 	if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil && err != tcpip.ErrDuplicateAddress {
-		return fmt.Errorf("dhcp: %v", err)
+		return Config{}, fmt.Errorf("dhcp: %v", err)
 	}
 	if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil && err != tcpip.ErrDuplicateAddress {
-		return fmt.Errorf("dhcp: %v", err)
+		return Config{}, fmt.Errorf("dhcp: %v", err)
 	}
 	defer c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff")
 	defer c.stack.RemoveAddress(c.nicid, "\x00\x00\x00\x00")
@@ -107,7 +122,7 @@
 	var wq waiter.Queue
 	ep, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
 	if err != nil {
-		return fmt.Errorf("dhcp: outbound endpoint: %v", err)
+		return Config{}, fmt.Errorf("dhcp: outbound endpoint: %v", err)
 	}
 	err = ep.Bind(tcpip.FullAddress{
 		Addr: "\x00\x00\x00\x00",
@@ -116,12 +131,12 @@
 	}, nil)
 	defer ep.Close()
 	if err != nil {
-		return fmt.Errorf("dhcp: connect failed: %v", err)
+		return Config{}, fmt.Errorf("dhcp: connect failed: %v", err)
 	}
 
 	epin, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
 	if err != nil {
-		return fmt.Errorf("dhcp: inbound endpoint: %v", err)
+		return Config{}, fmt.Errorf("dhcp: inbound endpoint: %v", err)
 	}
 	err = epin.Bind(tcpip.FullAddress{
 		Addr: "\xff\xff\xff\xff",
@@ -130,7 +145,7 @@
 	}, nil)
 	defer epin.Close()
 	if err != nil {
-		return fmt.Errorf("dhcp: connect failed: %v", err)
+		return Config{}, fmt.Errorf("dhcp: connect failed: %v", err)
 	}
 
 	var xid [4]byte
@@ -170,7 +185,7 @@
 		NIC:  c.nicid,
 	}
 	if _, err := ep.Write(buffer.View(h), serverAddr); err != nil {
-		return fmt.Errorf("dhcp discovery write: %v", err)
+		return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
 	}
 
 	we, ch := waiter.NewChannelEntry(nil)
@@ -187,7 +202,7 @@
 			case <-ch:
 				continue
 			case <-ctx.Done():
-				return tcpip.ErrAborted
+				return Config{}, tcpip.ErrAborted
 			}
 		}
 		h = header(v)
@@ -203,16 +218,15 @@
 	}
 
 	var ack bool
-	var cfg Config
 	if err := cfg.decode(opts); err != nil {
-		return fmt.Errorf("dhcp offer: %v", err)
+		return Config{}, fmt.Errorf("dhcp offer: %v", err)
 	}
 
 	// DHCPREQUEST
 	addr := tcpip.Address(h.yiaddr())
 	if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil {
 		if err != tcpip.ErrDuplicateAddress {
-			return err
+			return Config{}, err
 		}
 	}
 	defer func() {
@@ -256,7 +270,7 @@
 	}
 	h.setOptions(reqOpts)
 	if _, err := ep.Write([]byte(h), serverAddr); err != nil {
-		return fmt.Errorf("dhcp discovery write: %v", err)
+		return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
 	}
 
 	// DHCPACK
@@ -268,7 +282,7 @@
 			case <-ch:
 				continue
 			case <-ctx.Done():
-				return tcpip.ErrAborted
+				return Config{}, tcpip.ErrAborted
 			}
 		}
 		h = header(v)
@@ -280,41 +294,16 @@
 			}
 			if opts, valid, _ = loadDHCPReply(h, dhcpNAK, xid[:]); valid {
 				if msg := opts.message(); msg != "" {
-					return fmt.Errorf("dhcp: NAK %q", msg)
+					return Config{}, fmt.Errorf("dhcp: NAK %q", msg)
 				}
-				return fmt.Errorf("dhcp: NAK with no message")
+				return Config{}, fmt.Errorf("dhcp: NAK with no message")
 			}
 			continue
 		}
 		break
 	}
 	ack = true
-	if cfg.LeaseLength != 0 {
-		go c.renewAfter(cfg.LeaseLength)
-	}
-	return nil
-}
-
-func (c *Client) renewAfter(d time.Duration) {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	if c.cancelRenew != nil {
-		c.cancelRenew()
-	}
-	ctx, cancel := context.WithCancel(context.Background())
-	c.cancelRenew = cancel
-	go func() {
-		timer := time.NewTimer(d)
-		defer timer.Stop()
-		select {
-		case <-ctx.Done():
-		case <-timer.C:
-			if err := c.Request(ctx, c.addr); err != nil {
-				log.Printf("address renewal failed: %v", err)
-				go c.renewAfter(1 * time.Minute)
-			}
-		}
-	}()
+	return cfg, nil
 }
 
 func loadDHCPReply(h header, typ dhcpMsgType, xid []byte) (opts options, valid bool, err error) {
diff --git a/dhcp/dhcp_test.go b/dhcp/dhcp_test.go
index 0f6d5da..f385e97 100644
--- a/dhcp/dhcp_test.go
+++ b/dhcp/dhcp_test.go
@@ -81,13 +81,13 @@
 
 	const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
 	c0 := NewClient(s, nicid, clientLinkAddr0, nil)
-	if err := c0.Request(context.Background(), ""); err != nil {
+	if _, err := c0.Request(context.Background(), ""); err != nil {
 		t.Fatal(err)
 	}
 	if got, want := c0.Address(), clientAddrs[0]; got != want {
 		t.Errorf("c.Addr()=%s, want=%s", got, want)
 	}
-	if err := c0.Request(context.Background(), ""); err != nil {
+	if _, err := c0.Request(context.Background(), ""); err != nil {
 		t.Fatal(err)
 	}
 	if got, want := c0.Address(), clientAddrs[0]; got != want {
@@ -96,14 +96,14 @@
 
 	const clientLinkAddr1 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x53")
 	c1 := NewClient(s, nicid, clientLinkAddr1, nil)
-	if err := c1.Request(context.Background(), ""); err != nil {
+	if _, err := c1.Request(context.Background(), ""); err != nil {
 		t.Fatal(err)
 	}
 	if got, want := c1.Address(), clientAddrs[1]; got != want {
 		t.Errorf("c.Addr()=%s, want=%s", got, want)
 	}
 
-	if err := c0.Request(context.Background(), ""); err != nil {
+	if _, err := c0.Request(context.Background(), ""); err != nil {
 		t.Fatal(err)
 	}
 	if got, want := c0.Address(), clientAddrs[0]; got != want {
@@ -166,9 +166,10 @@
 		addrCh <- newAddr
 	}
 
+	clientCtx, cancel := context.WithCancel(context.Background())
 	const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
 	c := NewClient(s, nicid, clientLinkAddr0, acquiredFunc)
-	c.Start()
+	c.Run(clientCtx)
 
 	var addr tcpip.Address
 	select {
@@ -188,7 +189,7 @@
 		t.Fatal("timeout waiting for address renewal")
 	}
 
-	c.Shutdown()
+	cancel()
 }
 
 // Regression test for https://fuchsia.atlassian.net/browse/NET-17
@@ -308,7 +309,7 @@
 
 	const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
 	c := NewClient(s, nicid, clientLinkAddr0, nil)
-	if err := c.Request(context.Background(), ""); err != nil {
+	if _, err := c.Request(context.Background(), ""); err != nil {
 		t.Fatal(err)
 	}
 }