blob: 46b6e78a4fe284ba9dd9940118e9613cea927fcc [file] [log] [blame]
// Copyright 2016 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dhcp
import (
"bytes"
"context"
"crypto/rand"
"fmt"
"sync"
"time"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/stack"
"github.com/google/netstack/tcpip/transport/udp"
"github.com/google/netstack/waiter"
)
// Client is a DHCP client.
//
// A Client is bound to a single NIC of a network stack. It obtains an IPv4
// address for that NIC via Request, either on demand or periodically from
// the goroutine started by Run.
type Client struct {
// stack is the network stack on which addresses are added and removed
// and on which the UDP endpoints used for DHCP traffic are created.
stack *stack.Stack
// nicid identifies the NIC used for every address and endpoint operation.
nicid tcpip.NICID
// linkAddr is copied into the chaddr field of outgoing messages. When it
// is exactly 6 bytes long it is also used to build the client-identifier
// option.
linkAddr tcpip.LinkAddress
// acquiredFunc, if non-nil, is called by Request with the previously
// held address, the new address ("" when the exchange failed after an
// offer was received) and the resulting configuration.
acquiredFunc func(old, new tcpip.Address, cfg Config)
// mu guards addr and cfg.
mu sync.Mutex
// addr is the address currently held by the client, or "" if none.
addr tcpip.Address
// cfg is the configuration stored by the most recent Request that got
// as far as processing an offer.
cfg Config
// lease is not referenced in the visible part of this file.
// NOTE(review): possibly vestigial — confirm before relying on it.
lease time.Duration
// cancelRenew is not referenced in the visible part of this file.
// NOTE(review): possibly vestigial — confirm before relying on it.
cancelRenew func()
}
// NewClient returns a DHCP client operating on NIC nicid of stack s.
//
// linkAddr is the hardware address the client presents in its DHCP
// messages. acquiredFunc may be nil; when set, Request calls it with the
// previous address, the new address and the associated configuration.
//
// TODO(crawshaw): add s.LinkAddr(nicid) to *stack.Stack.
func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress, acquiredFunc func(old, new tcpip.Address, cfg Config)) *Client {
	c := new(Client)
	c.stack = s
	c.nicid = nicid
	c.linkAddr = linkAddr
	c.acquiredFunc = acquiredFunc
	return c
}
// Run starts the DHCP client.
//
// It returns immediately; the work happens on a new goroutine that
// repeatedly calls the Request method to obtain, and later renew, an IP
// address. The goroutine exits once ctx is cancelled, at which point any
// address the client still holds is removed from the stack. Run provides
// no way to wait for that goroutine to finish.
func (c *Client) Run(ctx context.Context) {
go c.run(ctx)
}
// run is the body of the goroutine started by Run.
//
// It loops until ctx is cancelled: each iteration performs one Request
// (bounded to 3 seconds), then either backs off for a second after a
// failure or sleeps for the granted lease length before renewing. On exit
// it removes the address currently held by the client, if any, from the
// stack.
func (c *Client) run(ctx context.Context) {
	defer func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		if c.addr != "" {
			c.stack.RemoveAddress(c.nicid, c.addr)
		}
	}()

	var renewAddr tcpip.Address
	for {
		reqCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
		cfg, err := c.Request(reqCtx, renewAddr)
		cancel()

		// Request may have replaced or cleared c.addr even when it
		// failed, so refresh the address to ask for next time in both
		// the success and the failure case.
		c.mu.Lock()
		renewAddr = c.addr
		c.mu.Unlock()

		if err != nil {
			// Back off briefly, then retry explicitly rather than
			// falling through to the lease timer with a zero-valued cfg.
			retry := time.NewTimer(1 * time.Second)
			select {
			case <-retry.C:
				continue
			case <-ctx.Done():
				retry.Stop()
				return
			}
		}

		// NOTE(review): a zero cfg.LeaseLength makes this timer fire
		// immediately, so a lease granted without a lease time is
		// re-requested back-to-back — confirm Config.decode always
		// yields a positive LeaseLength.
		timer := time.NewTimer(cfg.LeaseLength)
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
			// loop and make a renewal request
		}
	}
}
// Address returns the IP address the client currently holds. The value is
// read under the client's mutex; it is "" when no address is held.
func (c *Client) Address() tcpip.Address {
	c.mu.Lock()
	current := c.addr
	c.mu.Unlock()
	return current
}
// Config returns the DHCP configuration stored alongside the client's
// address. The value is read under the client's mutex.
func (c *Client) Config() Config {
	c.mu.Lock()
	current := c.cfg
	c.mu.Unlock()
	return current
}
// Request executes a DHCP request session.
//
// The session is DISCOVER -> OFFER -> REQUEST -> ACK over two UDP endpoints
// bound to the client port on the all-zeros and broadcast addresses, which
// are temporarily added to the NIC for the duration of the call.
// requestedAddr, if non-empty, is sent as the requested-IP option of the
// DISCOVER message.
//
// On success, the offered address has been added to this client's TCPIP
// stack, the client's address and configuration are updated, and the
// configuration decoded from the offer is returned. Request does not
// schedule renewals itself; the loop started by Run calls it again when
// the lease expires.
//
// Once an offer has been decoded, acquiredFunc (if set) is always invoked
// before Request returns: with the new address on success, or with "" and
// a Config carrying the error if the exchange fails afterwards. Failures
// before that point return without touching the client's stored address.
//
// Waiting for replies is abandoned when ctx is done.
func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg Config, reterr error) {
	// The helper addresses are needed to send from 0.0.0.0 and to receive
	// broadcasts. Each removal is deferred as soon as its add succeeds so
	// that a failure of the second add does not leave the first behind.
	if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil && err != tcpip.ErrDuplicateAddress {
		return Config{}, fmt.Errorf("dhcp: %v", err)
	}
	defer c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff")
	if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil && err != tcpip.ErrDuplicateAddress {
		return Config{}, fmt.Errorf("dhcp: %v", err)
	}
	defer c.stack.RemoveAddress(c.nicid, "\x00\x00\x00\x00")

	// Both endpoints share one waiter queue so a single channel entry
	// (registered below) is notified of inbound data.
	var wq waiter.Queue

	// Outbound endpoint: sends from 0.0.0.0:clientPort.
	ep, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
	if err != nil {
		return Config{}, fmt.Errorf("dhcp: outbound endpoint: %v", err)
	}
	defer ep.Close()
	err = ep.Bind(tcpip.FullAddress{
		Addr: "\x00\x00\x00\x00",
		Port: clientPort,
		NIC:  c.nicid,
	}, nil)
	if err != nil {
		return Config{}, fmt.Errorf("dhcp: connect failed: %v", err)
	}

	// Inbound endpoint: receives on 255.255.255.255:clientPort.
	epin, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
	if err != nil {
		return Config{}, fmt.Errorf("dhcp: inbound endpoint: %v", err)
	}
	defer epin.Close()
	err = epin.Bind(tcpip.FullAddress{
		Addr: "\xff\xff\xff\xff",
		Port: clientPort,
		NIC:  c.nicid,
	}, nil)
	if err != nil {
		return Config{}, fmt.Errorf("dhcp: connect failed: %v", err)
	}

	// The transaction ID ties replies to this session; do not proceed with
	// an unset ID if the random source fails.
	var xid [4]byte
	if _, err := rand.Read(xid[:]); err != nil {
		return Config{}, fmt.Errorf("dhcp: %v", err)
	}

	// DHCPDISCOVERY
	discOpts := options{
		{optDHCPMsgType, []byte{byte(dhcpDISCOVER)}},
		{optParamReq, []byte{
			1,  // request subnet mask
			3,  // request router
			15, // domain name
			6,  // domain name server
		}},
	}
	if requestedAddr != "" {
		discOpts = append(discOpts, option{optReqIPAddr, []byte(requestedAddr)})
	}
	var clientID []byte
	if len(c.linkAddr) == 6 {
		clientID = make([]byte, 7)
		clientID[0] = 1 // htype: ARP Ethernet from RFC 1700
		copy(clientID[1:], c.linkAddr)
		discOpts = append(discOpts, option{optClientID, clientID})
	}
	h := make(header, headerBaseSize+discOpts.len()+1)
	h.init()
	h.setOp(opRequest)
	copy(h.xidbytes(), xid[:])
	h.setBroadcast()
	copy(h.chaddr(), c.linkAddr)
	h.setOptions(discOpts)

	serverAddr := &tcpip.FullAddress{
		Addr: "\xff\xff\xff\xff",
		Port: serverPort,
		NIC:  c.nicid,
	}
	if _, err := ep.Write(buffer.View(h), serverAddr); err != nil {
		return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
	}

	we, ch := waiter.NewChannelEntry(nil)
	wq.EventRegister(&we, waiter.EventIn)
	defer wq.EventUnregister(&we)

	// DHCPOFFER
	var opts options
	for {
		var sender tcpip.FullAddress
		v, e := epin.Read(&sender)
		if e == tcpip.ErrWouldBlock {
			select {
			case <-ch:
				continue
			case <-ctx.Done():
				return Config{}, fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted)
			}
		}
		if e != nil {
			// Any other read error would otherwise make this loop spin
			// without ever consulting ctx.
			return Config{}, fmt.Errorf("reading dhcp offer: %v", e)
		}
		h = header(v)
		var valid bool
		var err error
		opts, valid, err = loadDHCPReply(h, dhcpOFFER, xid[:])
		if !valid {
			if err != nil {
				// TODO: report malformed server responses
			}
			continue
		}
		break
	}

	var ack bool
	if err := cfg.decode(opts); err != nil {
		return Config{}, fmt.Errorf("dhcp offer: %v", err)
	}

	// DHCPREQUEST
	addr := tcpip.Address(h.yiaddr())
	if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil {
		if err != tcpip.ErrDuplicateAddress {
			return Config{}, fmt.Errorf("adding address: %v", err)
		}
	}
	// From here on every return path publishes the outcome: the client's
	// stored address/config are updated and acquiredFunc is notified.
	defer func() {
		if !ack || reterr != nil {
			c.stack.RemoveAddress(c.nicid, addr)
			addr = ""
			cfg = Config{Error: reterr}
		}
		c.mu.Lock()
		oldAddr := c.addr
		c.addr = addr
		c.cfg = cfg
		c.mu.Unlock()

		// Clean up broadcast addresses before calling acquiredFunc
		// so nothing else uses them by mistake.
		//
		// (The deferred RemoveAddress calls above silently error.)
		c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff")
		c.stack.RemoveAddress(c.nicid, "\x00\x00\x00\x00")

		if c.acquiredFunc != nil {
			c.acquiredFunc(oldAddr, addr, cfg)
		}
		if requestedAddr != "" && requestedAddr != addr {
			c.stack.RemoveAddress(c.nicid, requestedAddr)
		}
	}()

	// h now aliases the buffer of the received offer; it is re-initialised
	// in place and reused for the request message.
	h.init()
	h.setOp(opRequest)
	for i, b := 0, h.yiaddr(); i < len(b); i++ {
		b[i] = 0
	}
	for i, b := 0, h.siaddr(); i < len(b); i++ {
		b[i] = 0
	}
	for i, b := 0, h.giaddr(); i < len(b); i++ {
		b[i] = 0
	}
	reqOpts := []option{
		{optDHCPMsgType, []byte{byte(dhcpREQUEST)}},
		{optReqIPAddr, []byte(addr)},
		{optDHCPServer, []byte(cfg.ServerAddress)},
	}
	if len(clientID) != 0 {
		reqOpts = append(reqOpts, option{optClientID, clientID})
	}
	h.setOptions(reqOpts)
	if _, err := ep.Write([]byte(h), serverAddr); err != nil {
		return Config{}, fmt.Errorf("dhcp request write: %v", err)
	}

	// DHCPACK
	for {
		var sender tcpip.FullAddress
		v, e := epin.Read(&sender)
		if e == tcpip.ErrWouldBlock {
			select {
			case <-ch:
				continue
			case <-ctx.Done():
				return Config{}, fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted)
			}
		}
		if e != nil {
			// See the offer loop: do not spin on a persistent read error.
			return Config{}, fmt.Errorf("reading dhcp ack: %v", e)
		}
		h = header(v)
		var valid bool
		var err error
		opts, valid, err = loadDHCPReply(h, dhcpACK, xid[:])
		if !valid {
			if err != nil {
				// TODO: report malformed server responses
			}
			// Not an ACK for this session; a NAK ends it with an error,
			// anything else is skipped.
			if opts, valid, _ = loadDHCPReply(h, dhcpNAK, xid[:]); valid {
				if msg := opts.message(); msg != "" {
					return Config{}, fmt.Errorf("dhcp: NAK %q", msg)
				}
				return Config{}, fmt.Errorf("dhcp: NAK with no message")
			}
			continue
		}
		break
	}
	ack = true
	return cfg, nil
}
// loadDHCPReply inspects h as a candidate reply of message type typ for the
// transaction identified by xid.
//
// It returns valid == true together with the parsed options only when h
// passes isValid, is a reply op, carries the same transaction ID and its
// DHCP message type equals typ. A message that simply does not match is
// reported as (nil, false, nil); a message whose options or message type
// cannot be parsed is reported as (nil, false, err).
func loadDHCPReply(h header, typ dhcpMsgType, xid []byte) (opts options, valid bool, err error) {
	// Cases are evaluated left to right and stop at the first match, so
	// later accessors are not called on a header that failed isValid.
	switch {
	case !h.isValid(), h.op() != opReply, !bytes.Equal(h.xidbytes(), xid):
		return nil, false, nil
	}

	opts, optErr := h.options()
	if optErr != nil {
		return nil, false, fmt.Errorf("dhcp ack: %v", optErr)
	}

	got, typErr := opts.dhcpMsgType()
	if typErr != nil {
		return nil, false, fmt.Errorf("dhcp ack: %v", typErr)
	}

	if got == typ {
		return opts, true, nil
	}
	return nil, false, nil
}