blob: 1cef0cddadfa379d7fb707029722f224a42c8244 [file] [log] [blame]
// +build !darwin
package main
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"strings"
"syscall"
"time"
"go.fuchsia.dev/fuchsia/tools/net/netboot"
devicePkg "go.fuchsia.dev/infra/devices"
"golang.org/x/sys/unix"
)
// broadcastRetryTimeout is the upper bound on how long reads of broadcast
// (advertisement) packets keep being retried after being interrupted.
const broadcastRetryTimeout time.Duration = 10 * time.Second
type (
	// netbootHeader is the fixed-size header at the start of every netboot
	// protocol message. The field order and widths define the on-the-wire
	// layout, because messages are decoded with encoding/binary; do not
	// reorder or resize these fields.
	// NOTE(review): the exact meaning of each field (magic constant, cookie,
	// command code, argument) is defined by the netboot protocol, not here.
	netbootHeader struct {
		Magic  uint32
		Cookie uint32
		Cmd    uint32
		Arg    uint32
	}

	// netbootMessage is a complete netboot protocol message: a header followed
	// by a fixed 1024-byte payload area. Textual payloads are stored as a
	// NUL-terminated string within Data.
	netbootMessage struct {
		Header netbootHeader
		Data   [1024]byte
	}
)
// netbootString returns the bytes of bs up to, but not including, the first
// NUL byte, converted to a string. It returns an error if bs contains no NUL
// byte (including when bs is empty).
func netbootString(bs []byte) (string, error) {
	nul := bytes.IndexByte(bs, 0)
	if nul < 0 {
		return "", errors.New("no null terminated string found")
	}
	return string(bs[:nul]), nil
}
// checkBroadcasting verifies that the device is sending netboot advertisement
// packets on its network interface. It is a no-op for devices whose Type() is
// "nuc".
//
// It opens an IPv6 UDP socket bound to n.AdvertPort on device.Interface(),
// reads a single packet, decodes it as a netbootMessage and succeeds if the
// payload contains a "nodename=<device.Nodename()>" field. It returns an error
// if the socket cannot be set up, no packet can be read, the packet cannot be
// decoded, or the packet does not advertise this device's nodename.
func checkBroadcasting(n *netboot.Client, device *devicePkg.DeviceTarget) error {
	if device.Type() == "nuc" {
		return nil
	}

	socket, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
	if err != nil {
		return err
	}
	// Release the socket on every return path, including early errors and
	// success. The close error is deliberately ignored: the socket is only
	// read from, so there is nothing to flush or report.
	defer func() { _ = syscall.Close(socket) }()

	// Set reuseport on the socket so that multiple invocations do not
	// trample each other.
	if err := syscall.SetsockoptInt(socket, syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
		return err
	}
	// Bound each blocking read so we don't wait forever for a packet. The
	// receive timeout is derived from broadcastRetryTimeout so the two limits
	// cannot drift apart.
	timeout := syscall.NsecToTimeval(broadcastRetryTimeout.Nanoseconds())
	if err := syscall.SetsockoptTimeval(socket, syscall.SOL_SOCKET, syscall.SO_RCVTIMEO, &timeout); err != nil {
		return err
	}
	// Only receive packets arriving on this device's interface.
	if err := syscall.SetsockoptString(socket, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, device.Interface()); err != nil {
		return err
	}
	// Bind to the advertisement port on the unspecified address (the zero
	// value of Addr).
	if err := syscall.Bind(socket, &syscall.SockaddrInet6{Port: n.AdvertPort}); err != nil {
		return err
	}

	// Read a single packet. Only reads interrupted by a signal (EINTR) are
	// retried, for at most broadcastRetryTimeout. Any other error is returned
	// immediately; on Linux this includes the EAGAIN produced when
	// SO_RCVTIMEO expires.
	b := make([]byte, 4096)
	start := time.Now()
	for {
		_, err = syscall.Read(socket, b)
		if err == nil {
			break
		}
		if err != syscall.EINTR {
			return err
		}
		if time.Since(start) >= broadcastRetryTimeout {
			return errors.New("failed to read broadcast packets")
		}
	}

	// Decode the whole freshly-zeroed buffer rather than only the bytes read,
	// so a packet shorter than netbootMessage is zero-padded instead of
	// rejected with an unexpected-EOF error.
	var res netbootMessage
	if err := binary.Read(bytes.NewReader(b), binary.LittleEndian, &res); err != nil {
		return err
	}
	data, err := netbootString(res.Data[:])
	if err != nil {
		return err
	}

	// The payload contains fields separated by ';', each in key=value format.
	// The "nodename" field names the advertising device.
	want := device.Nodename()
	for _, field := range strings.Split(data, ";") {
		kv := strings.SplitN(field, "=", 2)
		// A field without '=' yields a single element. Check the length so a
		// bare "nodename" field cannot cause an index-out-of-range panic.
		if len(kv) == 2 && kv[0] == "nodename" && kv[1] == want {
			return nil
		}
	}
	return fmt.Errorf("device %s is not broadcasting", want)
}