blob: 8162acc6996a335053a93006d4ad3eecf2691aae [file] [log] [blame]
// Copyright 2016 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dhcp
import (
"context"
"fmt"
"io"
"log"
"sync"
"time"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/stack"
"github.com/google/netstack/tcpip/transport/udp"
"github.com/google/netstack/waiter"
)
// Server is a DHCP server.
//
// It reads requests from conn, fans them out to a fixed set of handler
// goroutines, and records the lease held by each client link address.
type Server struct {
// conn is the endpoint requests are read from and replies are written to.
conn conn
// broadcast is the destination address used for every reply the server sends.
broadcast tcpip.FullAddress
// addrs is the pool of addresses that may be handed out to clients.
addrs []tcpip.Address // TODO: use a tcpip.AddressMask or range structure
// cfg is the configuration supplied when the server was created.
cfg Config
cfgopts []option // cfg to send to client
// handlers holds one channel per handler goroutine; reader picks one by xid.
handlers []chan header
// mu is held around every access to leases in this file.
mu sync.Mutex
// leases maps a client's link address to the lease recorded for it.
leases map[tcpip.LinkAddress]serverLease
}
// conn is a blocking read/write network endpoint.
//
// The server only depends on this interface, so newServer can be handed
// any implementation; epConn is the one backed by a tcpip.Endpoint.
type conn interface {
// Read returns the next packet together with a tcpip.FullAddress for it.
Read() (buffer.View, tcpip.FullAddress, error)
// Write sends the given bytes to the given address.
Write([]byte, *tcpip.FullAddress) error
}
// epConn implements conn on top of a tcpip.Endpoint, turning the
// endpoint's would-block reads into blocking ones.
type epConn struct {
// ctx bounds blocking reads: Read stops waiting and returns io.EOF when it is done.
ctx context.Context
// wq is the waiter queue that we is registered on.
wq *waiter.Queue
// ep is the underlying endpoint that is read from and written to.
ep tcpip.Endpoint
// we is the waiter entry registered on wq for waiter.EventIn.
we waiter.Entry
// inCh is the channel created together with we; Read waits on it when ep would block.
inCh chan struct{}
}
// newEPConn wraps ep in an epConn whose reads block until input is
// signalled on wq or ctx is done.
//
// A channel-backed waiter entry is registered on wq for waiter.EventIn,
// and a goroutine unregisters it again once ctx is done.
func newEPConn(ctx context.Context, wq *waiter.Queue, ep tcpip.Endpoint) *epConn {
	entry, notify := waiter.NewChannelEntry(nil)
	conn := &epConn{
		ctx:  ctx,
		wq:   wq,
		ep:   ep,
		we:   entry,
		inCh: notify,
	}
	// Register the entry stored in the struct so the address stays stable.
	wq.EventRegister(&conn.we, waiter.EventIn)

	go func() {
		<-ctx.Done()
		wq.EventUnregister(&conn.we)
	}()

	return conn
}
// Read returns the next packet from the endpoint and the address the
// endpoint reported for it.
//
// When the endpoint reports tcpip.ErrWouldBlock, Read waits for either an
// input notification, after which it retries, or for ctx to be done, in
// which case it returns io.EOF. Any other result from the endpoint,
// including other errors, is returned as-is.
func (c *epConn) Read() (buffer.View, tcpip.FullAddress, error) {
	for {
		var from tcpip.FullAddress
		data, err := c.ep.Read(&from)
		if err != tcpip.ErrWouldBlock {
			return data, from, err
		}

		// Nothing to read yet: wait for input or cancellation.
		select {
		case <-c.ctx.Done():
			return nil, tcpip.FullAddress{}, io.EOF
		case <-c.inCh:
			// Input was signalled; loop around and read again.
		}
	}
}
// Write sends b to addr through the underlying endpoint. The endpoint's
// first return value is discarded; only the error is reported.
func (c *epConn) Write(b []byte, addr *tcpip.FullAddress) error {
	_, writeErr := c.ep.Write(b, addr)
	return writeErr
}
// NewServer creates a new DHCP server and begins serving.
// The server continues serving until ctx is done.
//
// It creates a UDP/IPv4 endpoint on stack, binds it to serverPort, and
// serves the addresses in addrs using cfg. An error is returned if the
// endpoint cannot be created or bound, or if cfg is rejected by the
// server setup. On the latter two failures the endpoint is closed before
// returning so that it is not leaked.
func NewServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) {
	wq := new(waiter.Queue)
	ep, err := stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, wq)
	if err != nil {
		return nil, fmt.Errorf("dhcp: server endpoint: %v", err)
	}
	if err := ep.Bind(tcpip.FullAddress{Port: serverPort}, nil); err != nil {
		// Nobody else holds a reference to the endpoint; release it.
		ep.Close()
		return nil, fmt.Errorf("dhcp: server bind: %v", err)
	}

	c := newEPConn(ctx, wq, ep)
	srv, srvErr := newServer(ctx, c, addrs, cfg)
	if srvErr != nil {
		// No server was started, so nothing will ever use the endpoint.
		ep.Close()
		return nil, srvErr
	}
	return srv, nil
}
// newServer builds a Server that talks over c and starts its goroutines:
// one handler per handler channel, the lease expirer, and the packet
// reader. All of them are given ctx.
//
// It returns an error when cfg does not specify a server address.
func newServer(ctx context.Context, c conn, addrs []tcpip.Address, cfg Config) (*Server, error) {
	if cfg.ServerAddress == "" {
		return nil, fmt.Errorf("dhcp: server requires explicit server address")
	}

	const (
		handlerCount = 8 // number of handler goroutines
		queueDepth   = 8 // buffered packets per handler
	)

	srv := &Server{
		conn:    c,
		addrs:   addrs,
		cfg:     cfg,
		cfgopts: cfg.encode(),
		broadcast: tcpip.FullAddress{
			Addr: "\xff\xff\xff\xff",
			Port: clientPort,
		},
		handlers: make([]chan header, handlerCount),
		leases:   make(map[tcpip.LinkAddress]serverLease),
	}

	for i := range srv.handlers {
		srv.handlers[i] = make(chan header, queueDepth)
		go srv.handler(ctx, srv.handlers[i])
	}
	go srv.expirer(ctx)
	go srv.reader(ctx)

	return srv, nil
}
// expirer wakes up once a minute and marks every lease whose age exceeds
// cfg.LeaseLength as leaseExpired. It runs until ctx is done.
func (s *Server) expirer(ctx context.Context) {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		s.mu.Lock()
		for mac, l := range s.leases {
			if time.Since(l.start) <= s.cfg.LeaseLength {
				continue
			}
			// The map stores values, so write the modified copy back.
			l.state = leaseExpired
			s.leases[mac] = l
		}
		s.mu.Unlock()
	}
}
// reader listens for all incoming DHCP packets and fans them out to
// handling goroutines based on XID as session identifiers.
//
// Packets that are not valid or are not requests are skipped. If the
// chosen handler's queue is full the packet is dropped rather than
// blocking the read loop. reader returns when the connection reports a
// read error or ctx is done.
func (s *Server) reader(ctx context.Context) {
	for {
		v, _, err := s.conn.Read()
		if err != nil {
			return
		}
		h := header(v)
		if !h.isValid() || h.op() != opRequest {
			continue
		}

		// Fan out the packet to a handler goroutine.
		//
		// Use a consistent handler for a given xid, so that
		// packets from a particular client are processed
		// in order.
		//
		// The remainder is taken in unsigned arithmetic: converting the
		// xid to int first can yield a negative value where int is 32
		// bits wide, and a negative remainder would panic as an index.
		idx := uint64(h.xid()) % uint64(len(s.handlers))
		ch := s.handlers[idx]
		select {
		case <-ctx.Done():
			return
		case ch <- h:
		default:
			// drop the packet
		}
	}
}
// handler processes the packets delivered on ch one at a time until ctx
// is done or a nil header is received.
//
// Packets whose options or message type cannot be parsed are skipped.
// DHCPDISCOVER and DHCPREQUEST are dispatched to their handlers; every
// other message type is ignored.
func (s *Server) handler(ctx context.Context, ch chan header) {
	for {
		var req header
		select {
		case <-ctx.Done():
			return
		case req = <-ch:
		}
		if req == nil {
			return
		}

		reqOpts, err := req.options()
		if err != nil {
			continue
		}
		// TODO(crawshaw): Handle DHCPRELEASE and DHCPDECLINE.
		kind, err := reqOpts.dhcpMsgType()
		if err != nil {
			continue
		}

		switch kind {
		case dhcpDISCOVER:
			s.handleDiscover(req, reqOpts)
		case dhcpREQUEST:
			s.handleRequest(req, reqOpts)
		}
	}
}
// handleDiscover answers a DHCPDISCOVER with a DHCPOFFER sent to the
// broadcast address.
//
// The client is identified by the first six bytes of the request's
// chaddr. A client that already has a lease entry (offered, acknowledged
// or expired) is re-offered the address recorded for it. A client with no
// entry is given the first address in s.addrs that no lease is using; if
// none is free, the address of an expired lease is taken over and that
// lease is removed. When no address can be found, a message is logged and
// no reply is sent.
//
// The incoming opts are not consulted; the parameter is reused to build
// the reply's options.
func (s *Server) handleDiscover(hreq header, opts options) {
	linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6])
	xid := hreq.xid()

	s.mu.Lock()
	lease := s.leases[linkAddr]
	switch lease.state {
	case leaseNew:
		// leaseNew is the zero value, i.e. this client has no entry yet.
		var addr tcpip.Address
		found := false

		if len(s.leases) < len(s.addrs) {
			// Find an unused address.
			// TODO(crawshaw): avoid building this state on each request.
			alloced := make(map[tcpip.Address]bool, len(s.leases))
			for _, l := range s.leases {
				alloced[l.addr] = true
			}
			for _, a := range s.addrs {
				if !alloced[a] {
					addr, found = a, true
					break
				}
			}
		}
		if !found {
			// No free address: take over one whose lease has expired,
			// dropping the old owner's entry.
			for k, oldLease := range s.leases {
				if oldLease.state == leaseExpired {
					delete(s.leases, k)
					addr, found = oldLease.addr, true
					break
				}
			}
		}
		if !found {
			s.mu.Unlock()
			log.Printf("server has no more addresses")
			return
		}

		lease = serverLease{
			start: time.Now(),
			addr:  addr,
			xid:   xid,
			state: leaseOffer,
		}
		s.leases[linkAddr] = lease
	case leaseOffer, leaseAck, leaseExpired:
		// Known client: offer the address already recorded for it.
		lease = serverLease{
			start: time.Now(),
			addr:  lease.addr,
			xid:   xid,
			state: leaseOffer,
		}
		s.leases[linkAddr] = lease
	}
	s.mu.Unlock()

	// DHCPOFFER
	opts = options{{optDHCPMsgType, []byte{byte(dhcpOFFER)}}}
	opts = append(opts, s.cfgopts...)
	h := make(header, headerBaseSize+opts.len())
	h.init()
	h.setOp(opReply)
	copy(h.xidbytes(), hreq.xidbytes())
	copy(h.yiaddr(), lease.addr)
	copy(h.chaddr(), hreq.chaddr())
	h.setOptions(opts)
	// The write error is dropped here, as for every other reply in this file.
	s.conn.Write([]byte(h), &s.broadcast)
}
// nack broadcasts a DHCPNAK for hreq. The reply carries the request's xid
// and chaddr plus this server's address as the server option. The result
// of the write is not checked.
func (s *Server) nack(hreq header) {
	// DHCPNACK
	nakOpts := options{
		{optDHCPMsgType, []byte{byte(dhcpNAK)}},
		{optDHCPServer, []byte(s.cfg.ServerAddress)},
	}

	reply := make(header, headerBaseSize+nakOpts.len())
	reply.init()
	reply.setOp(opReply)
	copy(reply.xidbytes(), hreq.xidbytes())
	copy(reply.chaddr(), hreq.chaddr())
	reply.setOptions(nakOpts)

	s.conn.Write([]byte(reply), &s.broadcast)
}
// handleRequest answers a DHCPREQUEST.
//
// The request's options are parsed again from hreq (the opts argument is
// not read); if they cannot be parsed or decoded into a Config, a NAK is
// sent. Requests naming a different server address are ignored. If the
// client has a lease entry in the offered, acknowledged or expired state,
// the entry is refreshed and marked acknowledged, and a DHCPACK carrying
// the leased address is broadcast. Clients without an entry get no reply.
func (s *Server) handleRequest(hreq header, opts options) {
	linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6])
	xid := hreq.xid()

	parsed, err := hreq.options()
	if err != nil {
		s.nack(hreq)
		return
	}
	var wanted Config
	if err := wanted.decode(parsed); err != nil {
		s.nack(hreq)
		return
	}
	if wanted.ServerAddress != s.cfg.ServerAddress {
		// This request is for a different DHCP server. Ignore it.
		return
	}

	s.mu.Lock()
	lease := s.leases[linkAddr]
	if st := lease.state; st == leaseOffer || st == leaseAck || st == leaseExpired {
		lease = serverLease{
			start: time.Now(),
			addr:  lease.addr,
			xid:   xid,
			state: leaseAck,
		}
		s.leases[linkAddr] = lease
	}
	s.mu.Unlock()

	if lease.state == leaseNew {
		// TODO(crawshaw): NACK or accept request
		return
	}

	// DHCPACK
	ackOpts := options{{optDHCPMsgType, []byte{byte(dhcpACK)}}}
	ackOpts = append(ackOpts, s.cfgopts...)

	reply := make(header, headerBaseSize+ackOpts.len())
	reply.init()
	reply.setOp(opReply)
	copy(reply.xidbytes(), hreq.xidbytes())
	copy(reply.yiaddr(), lease.addr)
	copy(reply.chaddr(), hreq.chaddr())
	reply.setOptions(ackOpts)

	s.conn.Write([]byte(reply), &s.broadcast)
}
// leaseState is the state of a lease tracked by the server.
type leaseState int
// Lease states. The order matters only in that leaseNew must be zero.
const (
// leaseNew is the zero value, so looking up a client with no entry in
// the leases map yields a lease in this state.
leaseNew leaseState = iota
// leaseOffer is set by handleDiscover when an address is offered.
leaseOffer
// leaseAck is set by handleRequest when a request is acknowledged.
leaseAck
// leaseExpired is set by expirer once a lease is older than cfg.LeaseLength.
leaseExpired
)
// serverLease is the server's record of the lease for one client; the
// Server keeps these in a map keyed by the client's link address.
type serverLease struct {
// start is when the lease was last offered or acknowledged; expirer
// compares its age against cfg.LeaseLength.
start time.Time
// addr is the address offered or assigned to the client.
addr tcpip.Address
// xid is the xid of the request that last updated this lease.
xid xid
// state is the lease's current state.
state leaseState
}