// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// +build !build_with_native_toolchain

package dhcp

import (
	"bytes"
	"context"
	"fmt"
	"math/rand"
	"net"
	"sync/atomic"
	"time"

	syslog "go.fuchsia.dev/fuchsia/src/lib/syslog/go"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/packet"
	"gvisor.dev/gvisor/pkg/waiter"
)

// tag is the syslog tag attached to log lines emitted by this client.
const tag = "DHCP"

// defaultLeaseLength is the lease length assumed when a server's reply does
// not specify one: twelve hours.
const defaultLeaseLength Seconds = 12 * 60 * 60

// AcquiredFunc receives address changes from a Client. After a successful
// acquisition it is called with the previously reported address (zero if
// none), the newly acquired address, and the lease configuration; when a
// previously reported address is withdrawn it is called with a zero newAddr
// and a zero Config.
type AcquiredFunc func(oldAddr tcpip.AddressWithPrefix, newAddr tcpip.AddressWithPrefix, cfg Config)

// Client is a DHCPv4 client operating on a single NIC of a gVisor stack.
type Client struct {
	stack *stack.Stack

	// info always holds a value of type Info; readers obtain a copy through
	// the Info method and Run replaces it wholesale as its state evolves.
	info atomic.Value

	// acquiredFunc, when non-nil, is told about acquired and withdrawn
	// addresses.
	acquiredFunc AcquiredFunc

	// wq is the waiter queue handed to the packet endpoints opened by acquire.
	wq waiter.Queue

	// sem is a one-slot semaphore so that at most one Run goroutine is active
	// per client. Quickly toggling the client on and off could otherwise
	// start a new Run while the previous one is still finishing, breaking its
	// invariants (TestDhcpConfiguration used to trigger panics this way).
	sem chan struct{}

	stats Stats

	// Absolute lease deadlines, set after each successful acquisition and
	// cleared after a DHCPNAK.
	leaseExpirationTime time.Time
	renewTime           time.Time
	rebindTime          time.Time

	// Hooks that tests may replace.
	rand           *rand.Rand
	retransTimeout func(time.Duration) <-chan time.Time
	acquire        func(context.Context, *Client, string, *Info) (Config, error)
	now            func() time.Time
}

// Stats holds the DHCP counters maintained by a single Client.
type Stats struct {
	// Acquisition attempts, by the state the attempt started in.
	InitAcquire, RenewAcquire, RebindAcquire tcpip.StatCounter

	// Messages successfully sent or accepted.
	SendDiscovers, RecvOffers, SendRequests, RecvAcks, RecvNaks tcpip.StatCounter

	// Failures to send a DHCPDISCOVER or DHCPREQUEST.
	SendDiscoverErrors, SendRequestErrors tcpip.StatCounter

	// Problems while waiting for a DHCPOFFER.
	RecvOfferErrors, RecvOfferUnexpectedType, RecvOfferOptsDecodeErrors tcpip.StatCounter

	// DHCPOFFER waits ended by the retransmission timer or by the
	// transaction's context.
	RecvOfferTimeout, RecvOfferAcquisitionTimeout tcpip.StatCounter

	// Problems while waiting for a DHCPACK or DHCPNAK.
	RecvAckErrors, RecvNakErrors, RecvAckOptsDecodeErrors, RecvAckAddrErrors, RecvAckUnexpectedType tcpip.StatCounter

	// DHCPACK waits ended by the retransmission timer or by the
	// transaction's context.
	RecvAckTimeout, RecvAckAcquisitionTimeout tcpip.StatCounter

	// Acquisitions restarted because the server declined with a DHCPNAK.
	ReacquireAfterNAK tcpip.StatCounter
}

// Info is a snapshot of a Client's configuration and DHCP state.
type Info struct {
	// NICID identifies the NIC the client operates on.
	NICID tcpip.NICID
	// LinkAddr is that NIC's link address; it is sent as the client hardware
	// address (chaddr) of outgoing messages.
	LinkAddr tcpip.LinkAddress
	// Acquisition bounds how long one complete DHCP transaction may take
	// before it is abandoned.
	Acquisition time.Duration
	// Backoff is how long the client waits after a failed transaction before
	// starting another one.
	Backoff time.Duration
	// Retransmission is the base delay before a DISCOVER or REQUEST is resent
	// within a transaction.
	Retransmission time.Duration
	// Addr is the address obtained from the most recent DHCPACK.
	Addr tcpip.AddressWithPrefix
	// Server is the address of the DHCP server the client is talking to.
	Server tcpip.Address
	// State is the client's position in the DHCP state machine.
	State dhcpClientState
	// OldAddr is the address most recently reported to acquiredFunc.
	OldAddr tcpip.AddressWithPrefix
}

// NewClient returns a DHCP client for the NIC nicid of stack s; the client
// does nothing until Run is called.
//
// acquisition, backoff and retransmission seed the corresponding Info
// fields. acquiredFunc, if non-nil, is invoked after every acquisition and
// whenever an address is withdrawn, and is expected to apply those changes to
// the stack.
func NewClient(
	s *stack.Stack,
	nicid tcpip.NICID,
	linkAddr tcpip.LinkAddress,
	acquisition,
	backoff,
	retransmission time.Duration,
	acquiredFunc AcquiredFunc,
) *Client {
	initial := Info{
		NICID:          nicid,
		LinkAddr:       linkAddr,
		Acquisition:    acquisition,
		Backoff:        backoff,
		Retransmission: retransmission,
	}
	seed := time.Now().UnixNano()
	c := &Client{
		stack:          s,
		acquiredFunc:   acquiredFunc,
		sem:            make(chan struct{}, 1),
		rand:           rand.New(rand.NewSource(seed)),
		now:            time.Now,
		retransTimeout: time.After,
		acquire:        acquire,
	}
	c.info.Store(initial)
	return c
}

// Info returns a copy of the client's most recently published Info.
func (c *Client) Info() Info {
	snapshot := c.info.Load().(Info)
	return snapshot
}

// Stats returns a pointer to the client's live counters.
func (c *Client) Stats() *Stats {
	counters := &c.stats
	return counters
}

// Run runs the DHCP client until ctx is done.
//
// Starting in the INIT/SELECTING state, it repeatedly acquires, renews, or
// rebinds a lease through c.acquire and reports every acquired address via
// acquiredFunc. After a successful acquisition it sleeps until the renewal
// time (T1); after a failed attempt that leaves the state unchanged it sleeps
// for info.Backoff. Time-driven transitions to another state other than
// RENEWING (e.g. RENEWING->REBINDING, REBINDING->INIT) happen without waiting.
//
// At most one Run executes per Client at a time; a second call blocks until
// the first returns. On return, any address previously reported through
// acquiredFunc is withdrawn and the final Info is published.
func (c *Client) Run(ctx context.Context) {
	// Take the semaphore before snapshotting Info: a previous Run publishes
	// its cleaned-up Info (OldAddr cleared) just before releasing the
	// semaphore, so reading Info earlier could resurrect a stale OldAddr.
	c.sem <- struct{}{}
	defer func() { <-c.sem }()

	info := c.Info()

	nicName := c.stack.FindNICNameFromID(info.NICID)

	// For the initial iteration of the acquisition loop, the client should
	// be in the initSelecting state, corresponding to the
	// INIT->SELECTING->REQUESTING->BOUND state transition:
	// https://tools.ietf.org/html/rfc2131#section-4.4
	info.State = initSelecting

	defer func() {
		_ = syslog.InfoTf(tag, "%s: client is stopping, cleaning up", nicName)
		c.cleanup(&info)
		// cleanup mutates info.
		c.info.Store(info)
	}()

	var timer *time.Timer
	defer func() {
		if timer != nil {
			timer.Stop()
		}
	}()

	for {
		if err := func() error {
			acquisitionTimeout := info.Acquisition

			// Adjust the timeout to make sure client is not stuck in retransmission
			// when it should transition to the next state. This can only happen for
			// two time-driven transitions: RENEW->REBIND, REBIND->INIT.
			//
			// Another time-driven transition BOUND->RENEW is not handled here because
			// the client does not have to send out any request during BOUND.
			switch s := info.State; s {
			case initSelecting:
				// Nothing to do. The client is initializing, no leases have been acquired.
				// Thus no times are set for renew, rebind, and lease expiration.
				c.stats.InitAcquire.Increment()
			case renewing:
				c.stats.RenewAcquire.Increment()
				// Instead of `time.Until`, use `now` stored on the client so
				// it can be stubbed out in test for consistency.
				if tilRebind := c.rebindTime.Sub(c.now()); tilRebind < acquisitionTimeout {
					acquisitionTimeout = tilRebind
				}
			case rebinding:
				c.stats.RebindAcquire.Increment()
				// Instead of `time.Until`, use `now` stored on the client so
				// it can be stubbed out in test for consistency.
				if tilLeaseExpire := c.leaseExpirationTime.Sub(c.now()); tilLeaseExpire < acquisitionTimeout {
					acquisitionTimeout = tilLeaseExpire
				}
			default:
				panic(fmt.Sprintf("unexpected state before acquire: %s", s))
			}

			ctx, cancel := context.WithTimeout(ctx, acquisitionTimeout)
			defer cancel()

			cfg, err := c.acquire(ctx, c, nicName, &info)
			if err != nil {
				return err
			}
			if cfg.Declined {
				c.stats.ReacquireAfterNAK.Increment()
				c.cleanup(&info)
				// Reset all the times so the client will re-acquire.
				c.leaseExpirationTime = time.Time{}
				c.renewTime = time.Time{}
				c.rebindTime = time.Time{}
				return nil
			}

			if cfg.LeaseLength == 0 {
				_ = syslog.WarnTf(tag, "%s: unspecified lease length; proceeding with default (%s)", nicName, defaultLeaseLength)
				cfg.LeaseLength = defaultLeaseLength
			}
			{
				// Based on RFC 2131 Sec. 4.4.5, this defaults to (0.5 * duration_of_lease).
				defaultRenewTime := cfg.LeaseLength / 2
				if cfg.RenewTime == 0 {
					_ = syslog.WarnTf(tag, "%s: unspecified renew time; proceeding with default (%s)", nicName, defaultRenewTime)
					cfg.RenewTime = defaultRenewTime
				}
				if cfg.RenewTime >= cfg.LeaseLength {
					_ = syslog.WarnTf(tag, "%s: renew time (%s) >= lease length (%s); proceeding with default (%s)", nicName, cfg.RenewTime, cfg.LeaseLength, defaultRenewTime)
					cfg.RenewTime = defaultRenewTime
				}
			}
			{
				// Based on RFC 2131 Sec. 4.4.5, this defaults to (0.875 * duration_of_lease).
				defaultRebindTime := cfg.LeaseLength * 875 / 1000
				if cfg.RebindTime == 0 {
					cfg.RebindTime = defaultRebindTime
				}
				if cfg.RebindTime <= cfg.RenewTime {
					_ = syslog.WarnTf(tag, "%s: rebind time (%s) <= renew time (%s); proceeding with default (%s)", nicName, cfg.RebindTime, cfg.RenewTime, defaultRebindTime)
					cfg.RebindTime = defaultRebindTime
				}
			}

			now := c.now()
			c.leaseExpirationTime = now.Add(cfg.LeaseLength.Duration())
			c.renewTime = now.Add(cfg.RenewTime.Duration())
			c.rebindTime = now.Add(cfg.RebindTime.Duration())

			if fn := c.acquiredFunc; fn != nil {
				fn(info.OldAddr, info.Addr, cfg)
			}
			info.OldAddr = info.Addr
			info.State = bound

			return nil
		}(); err != nil {
			if ctx.Err() != nil {
				return
			}
			_ = syslog.InfoTf(tag, "%s: %s; retrying %s", nicName, err, info.State)
		}

		// Synchronize info after attempt to acquire is complete.
		c.info.Store(info)

		// RFC 2131 Section 4.4.5
		// https://tools.ietf.org/html/rfc2131#section-4.4.5
		//
		//   T1 MUST be earlier than T2, which, in turn, MUST be earlier than
		//   the time at which the client's lease will expire.
		var next dhcpClientState
		var waitDuration time.Duration
		switch now := c.now(); {
		case !now.Before(c.leaseExpirationTime):
			next = initSelecting
		case !now.Before(c.rebindTime):
			next = rebinding
		case !now.Before(c.renewTime):
			next = renewing
		default:
			switch s := info.State; s {
			case renewing, rebinding:
				// This means the client is stuck in a bad state, because if
				// the timers are correctly set, previous cases should have matched.
				panic(fmt.Sprintf(
					"invalid client state %s, now=%s, leaseExpirationTime=%s, renewTime=%s, rebindTime=%s",
					s, now, c.leaseExpirationTime, c.renewTime, c.rebindTime,
				))
			}
			waitDuration = c.renewTime.Sub(now)
			next = renewing
		}

		// No state transition occurred, the client is retrying.
		if info.State == next {
			waitDuration = info.Backoff
		}

		if info.State != next && next != renewing {
			// Transition immediately for RENEW->REBIND, REBIND->INIT.
			if ctx.Err() != nil {
				return
			}
		} else {
			// The timer is only created or reset on this path, and the select
			// below either drains timer.C or returns from Run. The timer is
			// therefore always expired and drained here, so Reset is safe.
			//
			// It must be reset even for a zero duration (e.g. T1 already passed
			// when entering BOUND, or a zero Backoff): skipping the reset would
			// block forever on the already-drained timer.C.
			//
			// https://golang.org/pkg/time/#Timer.Reset
			if timer == nil {
				timer = time.NewTimer(waitDuration)
			} else {
				timer.Reset(waitDuration)
			}
			_ = syslog.InfoTf(tag, "%s: scheduling renewal in %.fs", nicName, waitDuration.Seconds())
			select {
			case <-ctx.Done():
				return
			case <-timer.C:
			}
		}

		if info.State != initSelecting && next == initSelecting {
			_ = syslog.WarnTf(tag, "%s: lease time expired, cleaning up", nicName)
			c.cleanup(&info)
		}

		info.State = next

		// Synchronize info after any state updates.
		c.info.Store(info)
	}
}

// cleanup withdraws the address last reported through acquiredFunc, if any.
// It invokes acquiredFunc with that address, a zero new address and a zero
// Config, then clears info.OldAddr. When info.OldAddr is already zero it does
// nothing.
func (c *Client) cleanup(info *Info) {
	if info.OldAddr != (tcpip.AddressWithPrefix{}) {
		// Tell the owner to remove the old address and configuration.
		if fn := c.acquiredFunc; fn != nil {
			fn(info.OldAddr, tcpip.AddressWithPrefix{}, Config{})
		}
		info.OldAddr = tcpip.AddressWithPrefix{}
	}
}

// maxBackoff caps the base retransmission delay computed by exponentialBackoff.
const maxBackoff = 64 * time.Second

// exponentialBackoff returns the delay to wait before retransmitting after
// attempt number iteration (0-indexed).
//
// The delay is Info().Retransmission doubled iteration times and capped at
// maxBackoff, plus a uniformly random jitter in [-1s, +1s]. The result is
// never negative. A non-positive Retransmission contributes no base delay,
// so only the (clamped) jitter remains.
//
// RFC 2131 section 4.1
// https://tools.ietf.org/html/rfc2131#section-4.1
//
//   The delay between retransmissions SHOULD be
//   chosen to allow sufficient time for replies from the server to be
//   delivered based on the characteristics of the internetwork between
//   the client and the server.  For example, in a 10Mb/sec Ethernet
//   internetwork, the delay before the first retransmission SHOULD be 4
//   seconds randomized by the value of a uniform random number chosen
//   from the range -1 to +1.  Clients with clocks that provide resolution
//   granularity of less than one second may choose a non-integer
//   randomization value.  The delay before the next retransmission SHOULD
//   be 8 seconds randomized by the value of a uniform number chosen from
//   the range -1 to +1.  The retransmission delay SHOULD be doubled with
//   subsequent retransmissions up to a maximum of 64 seconds.
func (c *Client) exponentialBackoff(iteration uint) time.Duration {
	jitter := time.Duration(c.rand.Int63n(int64(2*time.Second+1))) - time.Second // [-1s, +1s]

	var backoff time.Duration
	switch retransmission := c.Info().Retransmission; {
	case retransmission <= 0:
		// Avoids dividing by zero below; with no base delay only jitter applies.
		backoff = 0
	case (maxBackoff/retransmission)>>iteration != 0:
		// Here retransmission*2^iteration <= maxBackoff, so the product cannot
		// overflow.
		backoff = retransmission * (1 << iteration)
	default:
		backoff = maxBackoff
	}

	backoff += jitter
	if backoff < 0 {
		return 0
	}
	return backoff
}

// acquire runs one DHCP transaction for the client's current state and
// returns the configuration carried by the server's DHCPACK.
//
// In the initSelecting state it first broadcasts DHCPDISCOVER, retransmitting
// with exponential backoff, and selects the first DHCPOFFER it receives. It
// then sends DHCPREQUEST and waits for a DHCPACK or DHCPNAK, retransmitting
// on timeout. In the renewing state the DHCPREQUEST is sent to info.Server;
// in the rebinding state it is broadcast.
//
// Side effects on info: Server is overwritten with the selected offer's
// server, and Addr is set to the acquired address on DHCPACK. A DHCPNAK
// without a message yields a Config with Declined set. Errors are returned
// when sending or receiving fails, when ctx is done, or when the DHCPACK
// carries a different address or prefix than the one requested.
func acquire(ctx context.Context, c *Client, nicName string, info *Info) (Config, error) {
	netEP, err := c.stack.GetNetworkEndpoint(info.NICID, header.IPv4ProtocolNumber)
	if err != nil {
		return Config{}, fmt.Errorf("stack.GetNetworkEndpoint(%d, header.IPv4ProtocolNumber): %s", info.NICID, err)
	}

	// https://tools.ietf.org/html/rfc2131#section-4.3.6 Client messages:
	//
	// ---------------------------------------------------------------------
	// |              |INIT-REBOOT  |SELECTING    |RENEWING     |REBINDING |
	// ---------------------------------------------------------------------
	// |broad/unicast |broadcast    |broadcast    |unicast      |broadcast |
	// |server-ip     |MUST NOT     |MUST         |MUST NOT     |MUST NOT  |
	// |requested-ip  |MUST         |MUST         |MUST NOT     |MUST NOT  |
	// |ciaddr        |zero         |zero         |IP address   |IP address|
	// ---------------------------------------------------------------------
	writeTo := tcpip.FullAddress{
		Addr: header.IPv4Broadcast,
		Port: ServerPort,
		NIC:  info.NICID,
	}

	ep, err := packet.NewEndpoint(c.stack, true /* cooked */, header.IPv4ProtocolNumber, &c.wq)
	if err != nil {
		return Config{}, fmt.Errorf("packet.NewEndpoint(_, true, header.IPv4ProtocolNumber, _): %s", err)
	}
	defer ep.Close()

	recvOn := tcpip.FullAddress{
		NIC: info.NICID,
	}
	if err := ep.Bind(recvOn); err != nil {
		return Config{}, fmt.Errorf("ep.Bind(%+v): %s", recvOn, err)
	}

	switch info.State {
	case initSelecting:
	case renewing:
		writeTo.Addr = info.Server
	case rebinding:
	default:
		panic(fmt.Sprintf("unknown client state: c.State=%s", info.State))
	}

	we, ch := waiter.NewChannelEntry(nil)
	c.wq.EventRegister(&we, waiter.EventIn)
	defer c.wq.EventUnregister(&we)

	var xid [4]byte
	if _, err := c.rand.Read(xid[:]); err != nil {
		return Config{}, fmt.Errorf("c.rand.Read(): %w", err)
	}

	commonOpts := options{
		{optParamReq, []byte{
			1,  // request subnet mask
			3,  // request router
			15, // domain name
			6,  // domain name server
		}},
	}
	requestedAddr := info.Addr
	if info.State == initSelecting {
		discOpts := append(options{
			{optDHCPMsgType, []byte{byte(dhcpDISCOVER)}},
		}, commonOpts...)
		if len(requestedAddr.Address) != 0 {
			discOpts = append(discOpts, option{optReqIPAddr, []byte(requestedAddr.Address)})
		}

	retransmitDiscover:
		for i := uint(0); ; i++ {
			if err := c.send(
				ctx,
				nicName,
				info,
				netEP,
				discOpts,
				writeTo,
				xid[:],
				false, /* broadcast */
				false, /* ciaddr */
			); err != nil {
				c.stats.SendDiscoverErrors.Increment()
				return Config{}, fmt.Errorf("%s: %w", dhcpDISCOVER, err)
			}
			c.stats.SendDiscovers.Increment()

			// Receive a DHCPOFFER message from a responding DHCP server.
			retransmit := c.retransTimeout(c.exponentialBackoff(i))
			for {
				result, timedOut, err := c.recv(ctx, nicName, ep, ch, xid[:], retransmit)
				if err != nil {
					if timedOut {
						c.stats.RecvOfferAcquisitionTimeout.Increment()
					} else {
						c.stats.RecvOfferErrors.Increment()
					}
					return Config{}, fmt.Errorf("recv %s: %w", dhcpOFFER, err)
				}
				if timedOut {
					c.stats.RecvOfferTimeout.Increment()
					_ = syslog.WarnTf(tag, "%s: recv timeout waiting for %s; retransmitting %s", nicName, dhcpOFFER, dhcpDISCOVER)
					continue retransmitDiscover
				}

				if result.typ != dhcpOFFER {
					c.stats.RecvOfferUnexpectedType.Increment()
					_ = syslog.InfoTf(tag, "%s: got DHCP type = %s from %s, want = %s; discarding", nicName, result.typ, result.source, dhcpOFFER)
					continue
				}
				c.stats.RecvOffers.Increment()

				var cfg Config
				if err := cfg.decode(result.options); err != nil {
					c.stats.RecvOfferOptsDecodeErrors.Increment()
					return Config{}, fmt.Errorf("%s decode: %w", result.typ, err)
				}

				// We can overwrite the client's server notion, since there's no
				// atomicity required for correctness.
				//
				// We do not perform sophisticated offer selection and instead merely
				// select the first valid offer we receive.
				info.Server = cfg.ServerAddress

				// Without a subnet mask option, fall back to the classful default
				// mask of the *offered* address (not any previously held one). The
				// DHCPACK handling below applies the same fallback so that both
				// derive the same prefix.
				if len(cfg.SubnetMask) == 0 {
					cfg.SubnetMask = tcpip.AddressMask(net.IP(result.yiaddr).DefaultMask())
				}

				prefixLen, _ := net.IPMask(cfg.SubnetMask).Size()
				requestedAddr = tcpip.AddressWithPrefix{
					Address:   result.yiaddr,
					PrefixLen: prefixLen,
				}

				_ = syslog.InfoTf(
					tag,
					"%s: got %s from %s: Address=%s, server=%s, leaseLength=%s, renewTime=%s, rebindTime=%s",
					nicName,
					result.typ,
					result.source,
					requestedAddr,
					info.Server,
					cfg.LeaseLength,
					cfg.RenewTime,
					cfg.RebindTime,
				)

				break retransmitDiscover
			}
		}
	}

	reqOpts := append(options{
		{optDHCPMsgType, []byte{byte(dhcpREQUEST)}},
	}, commonOpts...)
	if info.State == initSelecting {
		reqOpts = append(reqOpts,
			options{
				{optDHCPServer, []byte(info.Server)},
				{optReqIPAddr, []byte(requestedAddr.Address)},
			}...)
	}

retransmitRequest:
	for i := uint(0); ; i++ {
		if err := c.send(
			ctx,
			nicName,
			info,
			netEP,
			reqOpts,
			writeTo,
			xid[:],
			false,                       /* broadcast */
			info.State != initSelecting, /* ciaddr */
		); err != nil {
			c.stats.SendRequestErrors.Increment()
			return Config{}, fmt.Errorf("%s: %w", dhcpREQUEST, err)
		}
		c.stats.SendRequests.Increment()

		// RFC 2131 Section 4.4.5
		// https://tools.ietf.org/html/rfc2131#section-4.4.5
		//
		//   In both RENEWING and REBINDING states, if the client receives no
		//   response to its DHCPREQUEST message, the client SHOULD wait one-half of
		//   the remaining time until T2 (in RENEWING state) and one-half of the
		//   remaining lease time (in REBINDING state), down to a minimum of 60
		//   seconds, before retransmitting the DHCPREQUEST message.
		var retransmitAfter time.Duration
		switch info.State {
		case initSelecting:
			retransmitAfter = c.exponentialBackoff(i)
		case renewing:
			retransmitAfter = c.rebindTime.Sub(c.now()) / 2
			if min := 60 * time.Second; retransmitAfter < min {
				retransmitAfter = min
			}
		case rebinding:
			retransmitAfter = c.leaseExpirationTime.Sub(c.now()) / 2
			if min := 60 * time.Second; retransmitAfter < min {
				retransmitAfter = min
			}
		default:
			panic(fmt.Sprintf("invalid client state %s", info.State))
		}

		// Receive a DHCPACK/DHCPNAK from the server.
		retransmit := c.retransTimeout(retransmitAfter)
		for {
			result, timedOut, err := c.recv(ctx, nicName, ep, ch, xid[:], retransmit)
			if err != nil {
				if timedOut {
					c.stats.RecvAckAcquisitionTimeout.Increment()
				} else {
					c.stats.RecvAckErrors.Increment()
				}
				return Config{}, fmt.Errorf("recv %s: %w", dhcpACK, err)
			}
			if timedOut {
				c.stats.RecvAckTimeout.Increment()
				_ = syslog.WarnTf(tag, "%s: recv timeout waiting for %s; retransmitting %s", nicName, dhcpACK, dhcpREQUEST)
				continue retransmitRequest
			}

			switch result.typ {
			case dhcpACK:
				var cfg Config
				if err := cfg.decode(result.options); err != nil {
					c.stats.RecvAckOptsDecodeErrors.Increment()
					return Config{}, fmt.Errorf("%s decode: %w", result.typ, err)
				}
				// Same fallback as for DHCPOFFER, so an ACK without a subnet mask
				// agrees with the prefix derived when the offer was accepted.
				if len(cfg.SubnetMask) == 0 {
					cfg.SubnetMask = tcpip.AddressMask(net.IP(result.yiaddr).DefaultMask())
				}
				prefixLen, _ := net.IPMask(cfg.SubnetMask).Size()
				addr := tcpip.AddressWithPrefix{
					Address:   result.yiaddr,
					PrefixLen: prefixLen,
				}
				if addr != requestedAddr {
					c.stats.RecvAckAddrErrors.Increment()
					return Config{}, fmt.Errorf("%s with unexpected address=%s expected=%s", result.typ, addr, requestedAddr)
				}
				c.stats.RecvAcks.Increment()

				// Now that we've successfully acquired the address, update the client state.
				info.Addr = requestedAddr
				_ = syslog.InfoTf(tag, "%s: got %s from %s with leaseLength=%s", nicName, result.typ, result.source, cfg.LeaseLength)
				return cfg, nil
			case dhcpNAK:
				// NOTE(review): RFC 2131 section 4.4.5 says a DHCPNAK restarts the
				// configuration process; here a NAK carrying a message is reported
				// as an error instead, so Run retries in the same state. Confirm
				// this is intended.
				if msg := result.options.message(); len(msg) != 0 {
					c.stats.RecvNakErrors.Increment()
					return Config{}, fmt.Errorf("%s: %x", result.typ, msg)
				}
				c.stats.RecvNaks.Increment()
				_ = syslog.InfoTf(tag, "%s: got %s from %s", nicName, result.typ, result.source)
				// We lost the lease.
				return Config{
					Declined: true,
				}, nil
			default:
				c.stats.RecvAckUnexpectedType.Increment()
				_ = syslog.InfoTf(tag, "%s: got DHCP type = %s from %s, want = %s or %s; discarding", nicName, result.typ, result.source, dhcpACK, dhcpNAK)
				continue
			}
		}
	}
}

// send builds a DHCP client message and writes it to writeTo.Addr on NIC
// writeTo.NIC. The IPv4 and UDP headers are constructed by hand, and the frame
// is handed to stack.WritePacketToRemote.
//
// The DHCP payload carries opts, the transaction ID xid, and info.LinkAddr as
// chaddr. It sets the broadcast flag when broadcast is true and uses
// info.Addr.Address as ciaddr when ciaddr is true. It is wrapped in a UDP
// header (source port ClientPort, destination writeTo.Port) and an IPv4 header
// (source info.Addr.Address, destination writeTo.Addr, Don't Fragment set,
// TTL from ep). The destination link address is resolved through the stack;
// if resolution completes asynchronously, send waits for it or for ctx.
//
// It returns an error if link resolution fails, ctx is done while resolving,
// or the write fails. It panics if xid, info.Addr.Address or info.LinkAddr do
// not fit their header fields, or if opts has no valid DHCP message type;
// these indicate programming errors.
func (c *Client) send(
	ctx context.Context,
	nicName string,
	info *Info,
	ep stack.NetworkEndpoint,
	opts options,
	writeTo tcpip.FullAddress,
	xid []byte,
	broadcast,
	ciaddr bool,
) error {
	// NOTE(review): the extra byte presumably accounts for the End option
	// written by setOptions — TODO confirm against opts.len().
	dhcpLength := headerBaseSize + opts.len() + 1
	// Headers are prepended back to front: DHCP payload, then UDP, then IPv4.
	b := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize + dhcpLength)
	dhcpPayload := hdr(b.Prepend(dhcpLength))
	dhcpPayload.init()
	dhcpPayload.setOp(opRequest)
	if l := copy(dhcpPayload.xidbytes(), xid); l != len(xid) {
		panic(fmt.Sprintf("failed to copy xid bytes, want=%d got=%d", len(xid), l))
	}
	if broadcast {
		dhcpPayload.setBroadcast()
	}
	if ciaddr {
		if l := copy(dhcpPayload.ciaddr(), info.Addr.Address); l != len(info.Addr.Address) {
			panic(fmt.Sprintf("failed to copy info.Addr.Address bytes, want=%d got=%d", len(info.Addr.Address), l))
		}
	}

	if l := copy(dhcpPayload.chaddr(), info.LinkAddr); l != len(info.LinkAddr) {
		panic(fmt.Sprintf("failed to copy all info.LinkAddr bytes, want=%d got=%d", len(info.LinkAddr), l))
	}
	dhcpPayload.setOptions(opts)

	// The message type is only needed for the log line below; every caller
	// is expected to include it in opts.
	typ, err := opts.dhcpMsgType()
	if err != nil {
		panic(err)
	}

	_ = syslog.InfoTf(
		tag,
		"%s: send %s from %s:%d to %s:%d on NIC:%d (bcast=%t ciaddr=%t)",
		nicName,
		typ,
		info.Addr.Address,
		ClientPort,
		writeTo.Addr,
		writeTo.Port,
		writeTo.NIC,
		broadcast,
		ciaddr,
	)

	// TODO(https://gvisor.dev/issues/4957): Use more streamlined serialization
	// functions when available.

	// Initialize the UDP header.
	udp := header.UDP(b.Prepend(header.UDPMinimumSize))
	// At this point the buffer holds exactly the UDP header plus the DHCP
	// payload, so UsedLength is the UDP datagram length.
	length := uint16(b.UsedLength())
	udp.Encode(&header.UDPFields{
		SrcPort: ClientPort,
		DstPort: writeTo.Port,
		Length:  length,
	})
	// Seed the checksum with the pseudo-header and the payload;
	// CalculateChecksum then folds in the UDP header itself.
	xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, info.Addr.Address, writeTo.Addr, length)
	xsum = header.Checksum(dhcpPayload, xsum)
	udp.SetChecksum(^udp.CalculateChecksum(xsum))

	// Initialize the IP header.
	ip := header.IPv4(b.Prepend(header.IPv4MinimumSize))
	ip.Encode(&header.IPv4Fields{
		TotalLength: uint16(b.UsedLength()),
		Flags:       header.IPv4FlagDontFragment,
		ID:          0,
		TTL:         ep.DefaultTTL(),
		TOS:         stack.DefaultTOS,
		Protocol:    uint8(header.UDPProtocolNumber),
		SrcAddr:     info.Addr.Address,
		DstAddr:     writeTo.Addr,
	})
	ip.SetChecksum(^ip.CalculateChecksum())

	// Resolve the destination's link address. A nil error reads the result
	// from ch without a ctx select, which assumes the callback has already
	// run (or runs promptly) in that case; ErrWouldBlock waits for either the
	// callback or ctx. Any other error falls through to the check below.
	var linkAddress tcpip.LinkAddress
	{
		ch := make(chan stack.LinkResolutionResult, 1)
		err := c.stack.GetLinkAddress(info.NICID, writeTo.Addr, info.Addr.Address, header.IPv4ProtocolNumber, func(result stack.LinkResolutionResult) {
			ch <- result
		})
		switch err.(type) {
		case nil:
			result := <-ch
			if result.Success {
				linkAddress = result.LinkAddress
			} else {
				err = &tcpip.ErrTimeout{}
			}
		case *tcpip.ErrWouldBlock:
			select {
			case result := <-ch:
				if result.Success {
					linkAddress = result.LinkAddress
					err = nil
				} else {
					err = &tcpip.ErrTimeout{}
				}
			case <-ctx.Done():
				return fmt.Errorf("client address resolution: %w", ctx.Err())
			}
		}
		if err != nil {
			return fmt.Errorf("failed to resolve link address: %s", err)
		}
	}

	if err := c.stack.WritePacketToRemote(
		writeTo.NIC,
		linkAddress,
		header.IPv4ProtocolNumber,
		b.View().ToVectorisedView(),
	); err != nil {
		return fmt.Errorf("failed to write packet: %s", err)
	}
	return nil
}

// recvResult describes a DHCP reply accepted by recv.
type recvResult struct {
	// source is the IPv4 source address of the reply.
	source tcpip.Address
	// yiaddr is the reply's "your IP address" field.
	yiaddr tcpip.Address
	// options holds the reply's parsed DHCP options.
	options options
	// typ is the DHCP message type found in options.
	typ dhcpMsgType
}

// recv waits for the next DHCP reply on ep that carries transaction ID xid.
//
// It reads packets from ep, blocking on read whenever nothing is queued. It
// silently skips packets that are not IPv4, not addressed to this host or
// broadcast, not UDP to ClientPort, or belonging to another transaction.
// Malformed, fragmented, or corrupted IPv4/UDP frames are logged and skipped.
//
// Return values:
//   - (result, false, nil) for an accepted reply;
//   - (zero, true, nil) when retransmit fires;
//   - (zero, true, err) when ctx is done;
//   - (zero, false, err) on a read failure, or for a reply with an invalid
//     DHCP header, a non-reply op-code, or unparseable options/type.
func (c *Client) recv(
	ctx context.Context,
	nicName string,
	ep tcpip.Endpoint,
	read <-chan struct{},
	xid []byte,
	retransmit <-chan time.Time,
) (recvResult, bool, error) {
	var b bytes.Buffer
	for {
		b.Reset()

		res, err := ep.Read(&b, tcpip.ReadOptions{
			NeedRemoteAddr:     true,
			NeedLinkPacketInfo: true,
		})
		senderAddr := tcpip.LinkAddress(res.RemoteAddr.Addr)
		if _, ok := err.(*tcpip.ErrWouldBlock); ok {
			// Nothing queued: wait for more data, the retransmission timer,
			// or cancellation.
			select {
			case <-read:
				continue
			case <-retransmit:
				return recvResult{}, true, nil
			case <-ctx.Done():
				return recvResult{}, true, fmt.Errorf("read: %w", ctx.Err())
			}
		}
		if err != nil {
			return recvResult{}, false, fmt.Errorf("read: %s", err)
		}

		if res.LinkPacketInfo.Protocol != header.IPv4ProtocolNumber {
			continue
		}

		// Only consider frames addressed to this host or broadcast.
		switch res.LinkPacketInfo.PktType {
		case tcpip.PacketHost, tcpip.PacketBroadcast:
		default:
			continue
		}

		v := b.Bytes()
		ip := header.IPv4(v)
		if !ip.IsValid(len(v)) {
			_ = syslog.WarnTf(
				tag,
				"%s: received malformed IP frame from %s; discarding %d bytes",
				nicName,
				senderAddr,
				len(v),
			)
			continue
		}
		// TODO(https://gvisor.dev/issues/5049): Abstract away checksum validation when possible.
		if ip.CalculateChecksum() != 0xffff {
			_ = syslog.WarnTf(
				tag,
				"%s: received damaged IP frame from %s; discarding %d bytes",
				nicName,
				senderAddr,
				len(v),
			)
			continue
		}
		if ip.More() || ip.FragmentOffset() != 0 {
			_ = syslog.WarnTf(
				tag,
				"%s: received fragmented IP frame from %s; discarding %d bytes",
				nicName,
				senderAddr,
				len(v),
			)
			continue
		}
		if ip.TransportProtocol() != header.UDPProtocolNumber {
			continue
		}
		udp := header.UDP(ip.Payload())
		if len(udp) < header.UDPMinimumSize {
			_ = syslog.WarnTf(
				tag,
				"%s: received malformed UDP frame (%s@%s -> %s); discarding %d bytes",
				nicName,
				ip.SourceAddress(),
				senderAddr,
				ip.DestinationAddress(),
				len(udp),
			)
			continue
		}
		if udp.DestinationPort() != ClientPort {
			continue
		}
		// The declared UDP length must cover at least the UDP header and must
		// not exceed the bytes actually present.
		if l := udp.Length(); l < header.UDPMinimumSize || l > uint16(len(udp)) {
			_ = syslog.WarnTf(
				tag,
				"%s: received malformed UDP frame (%s@%s -> %s); discarding %d bytes",
				nicName,
				ip.SourceAddress(),
				senderAddr,
				ip.DestinationAddress(),
				len(udp),
			)
			continue
		}
		// Trim anything past the declared UDP length so the checksum and the
		// DHCP parser see exactly the datagram.
		udp = udp[:udp.Length()]
		payload := udp.Payload()
		// A zero UDP checksum means the sender did not compute one.
		if xsum := udp.Checksum(); xsum != 0 {
			xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, ip.DestinationAddress(), ip.SourceAddress(), udp.Length())
			xsum = header.Checksum(payload, xsum)
			if udp.CalculateChecksum(xsum) != 0xffff {
				_ = syslog.WarnTf(
					tag,
					"%s: received damaged UDP frame (%s@%s -> %s); discarding %d bytes",
					nicName,
					ip.SourceAddress(),
					senderAddr,
					ip.DestinationAddress(),
					len(udp),
				)
				continue
			}
		}

		h := hdr(payload)
		if !h.isValid() {
			return recvResult{}, false, fmt.Errorf("invalid hdr: %x", h)
		}

		if op := h.op(); op != opReply {
			return recvResult{}, false, fmt.Errorf("op-code=%s, want=%s", op, opReply)
		}

		if !bytes.Equal(h.xidbytes(), xid) {
			// This message is for another client, ignore silently.
			continue
		}

		opts, err := h.options()
		if err != nil {
			return recvResult{}, false, fmt.Errorf("invalid options: %w", err)
		}

		typ, err := opts.dhcpMsgType()
		if err != nil {
			return recvResult{}, false, fmt.Errorf("invalid type: %w", err)
		}

		return recvResult{
			source:  ip.SourceAddress(),
			yiaddr:  tcpip.Address(h.yiaddr()),
			options: opts,
			typ:     typ,
		}, false, nil
	}
}
