| // Copyright 2017 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| // Package netboot implements the Zircon netboot protocol. |
| package netboot |
| |
| import ( |
| "bytes" |
| "context" |
| "encoding/binary" |
| "errors" |
| "fmt" |
| "net" |
| "reflect" |
| "regexp" |
| "strings" |
| "time" |
| |
| "go.fuchsia.dev/fuchsia/tools/lib/logger" |
| "golang.org/x/net/ipv6" |
| ) |
| |
| // NodenameWildcard is the wildcard for discovering all nodes. |
| const NodenameWildcard = "*" |
| |
| // Magic constants used by the netboot protocol. |
| const ( |
| baseCookie = uint32(0x12345678) |
| magic = 0xAA774217 // see //zircon/system/public/zircon/boot/netboot.h |
| ) |
| |
| // Port numbers used by the netboot protocol. |
| const ( |
| serverPort = 33330 // netboot server port |
| advertPort = 33331 // advertisement port |
| clientPortStart = 33332 // client port range start. |
| clientPortEnd = 33339 // client port range end. |
| ) |
| |
| // Commands supported by the netboot protocol. |
| const ( |
| cmdAck = uint32(0) // ack |
| cmdCommand = uint32(1) // command |
| cmdSendFile = uint32(2) // send file |
| cmdData = uint32(3) // data |
| cmdBoot = uint32(4) // boot command |
| cmdQuery = uint32(5) // query command |
| cmdShell = uint32(6) // shell command |
| cmdOpen = uint32(7) // open file |
| cmdRead = uint32(8) // read data |
| cmdWrite = uint32(9) // write data |
| cmdClose = uint32(10) // close file |
| cmdLastData = uint32(11) // |
| cmdReboot = uint32(12) // reboot command |
| ) |
| |
| // Client implements the netboot protocol. |
| type Client struct { |
| ServerPort int |
| AdvertPort int |
| Cookie uint32 |
| Timeout time.Duration |
| Wait bool |
| } |
| |
| // netbootHeader is the netboot protocol message header. |
| type netbootHeader struct { |
| Magic uint32 |
| Cookie uint32 |
| Cmd uint32 |
| Arg uint32 |
| } |
| |
| // netbootMessage is the netboot protocol message. |
| type netbootMessage struct { |
| Header netbootHeader |
| Data [1024]byte |
| } |
| |
| // Target defines a netboot protocol target, which includes information about how |
| // to find said target on the network: the target's nodename, its address, the |
| // address from the target back to the host, and the interface used to connect |
| // from the host to the target. |
| type Target struct { |
| // Nodename is target's nodename: thumb-set-human-neon is an example. |
| // This is derived from the NIC mac address. |
| Nodename string |
| |
| // TargetAddress is the address of the target from the host. |
| TargetAddress net.IP |
| |
| // HostAddress is the "local" address, i.e. the one to which the target |
| // is responding. This would be the address the target would send to in |
| // order to communicate with the host. |
| HostAddress net.IP |
| |
| // Interface is the index of the "local" interface connecting to |
| // the Fuchsia device. nil if this does not apply. |
| Interface *net.Interface |
| |
| // Error is the error associated with the device when returned via |
| // the StartDiscover function. |
| Error error |
| } |
| |
| // NewClient creates a new Client instance. |
| func NewClient(timeout time.Duration) *Client { |
| return &Client{ |
| Timeout: timeout, |
| ServerPort: serverPort, |
| AdvertPort: advertPort, |
| Cookie: baseCookie, |
| } |
| } |
| |
| type netbootQuery struct { |
| message netbootMessage |
| |
| conn6 *ipv6.PacketConn |
| conn *net.UDPConn |
| port int // The port to write on. |
| |
| isOpen bool |
| } |
| |
| func bindNetbootPort() (*net.UDPConn, error) { |
| var err error |
| var conn *net.UDPConn |
| // https://fuchsia.googlesource.com/fuchsia/+/0e30059/zircon/tools/netprotocol/netprotocol.c#59 |
| for i := clientPortStart; i <= clientPortEnd; i++ { |
| // Don't use the debugPort which is used by loglistener. |
| if i == debugPort { |
| continue |
| } |
| conn, err = net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: i}) |
| if err == nil { |
| break |
| } |
| } |
| return conn, err |
| } |
| |
| func newNetbootQuery(nodename string, cookie uint32, port int) (*netbootQuery, error) { |
| conn, err := bindNetbootPort() |
| if err != nil { |
| return nil, err |
| } |
| req := netbootMessage{ |
| Header: netbootHeader{ |
| Magic: magic, |
| Cookie: cookie, |
| Cmd: cmdQuery, |
| Arg: 0, |
| }, |
| } |
| copy(req.Data[:], nodename) |
| conn6 := ipv6.NewPacketConn(conn) |
| conn6.SetControlMessage(ipv6.FlagDst|ipv6.FlagSrc|ipv6.FlagInterface, true) |
| return &netbootQuery{ |
| message: req, |
| conn: conn, |
| conn6: conn6, |
| port: port, |
| isOpen: true}, nil |
| } |
| |
| func (n *netbootQuery) write(ctx context.Context) error { |
| // Cleanup function is used here in favor of defer to be explicit about |
| // what is being returned. It is difficult to reason about otherwise. |
| cleanup := func(e error) error { |
| if e != nil { |
| n.conn.Close() |
| } |
| return e |
| } |
| var buf bytes.Buffer |
| if err := binary.Write(&buf, binary.LittleEndian, n.message); err != nil { |
| return cleanup(err) |
| } |
| ifaces, err := net.Interfaces() |
| if err != nil { |
| return cleanup(err) |
| } |
| wrote := false |
| // Tracks last write error (in the event that all writes fail, for debugging). |
| var lastWriteErr error |
| for _, iface := range ifaces { |
| if iface.Flags&net.FlagUp == 0 { |
| continue |
| } |
| if iface.Flags&net.FlagLoopback != 0 { |
| continue |
| } |
| _, err := n.conn.WriteToUDP(buf.Bytes(), &net.UDPAddr{ |
| IP: net.IPv6linklocalallnodes, |
| Port: n.port, |
| Zone: iface.Name, |
| }) |
| |
| logger.Debugf(ctx, "writing on %s: %v", iface.Name, err) |
| |
| // Skip errors here, as it may be possible to write on |
| // some interfaces but not others. Track last error in |
| // case all writes fail on all interfaces. |
| if err != nil { |
| lastWriteErr = err |
| continue |
| } |
| wrote = true |
| } |
| if !wrote { |
| return cleanup(fmt.Errorf("write on any iface. Last err: %v", lastWriteErr)) |
| } |
| return nil |
| } |
| |
| func (n *netbootQuery) read() (*Target, error) { |
| b := make([]byte, 4096) |
| _, cm, _, err := n.conn6.ReadFrom(b) |
| // If there was an error, as the connection was already valid at the time |
| // of creation, this means that some timeout somewhere else has closed |
| // before this function was called (this is going to be called in a loop |
| // in a goroutine in most cases). There is no way to determine if the |
| // connection is still open unless there is an attempt at reading on it |
| // unfortunately. |
| if err != nil { |
| n.isOpen = false |
| return nil, nil |
| } |
| node, err := n.parse(b) |
| if err != nil { |
| return nil, err |
| } |
| if len(node) == 0 { |
| return nil, nil |
| } |
| var iface *net.Interface |
| if cm.IfIndex > 0 { |
| iface, err = net.InterfaceByIndex(cm.IfIndex) |
| if err != nil { |
| return nil, fmt.Errorf("query iface lookup: err") |
| } |
| } |
| return &Target{ |
| Nodename: node, |
| TargetAddress: cm.Src, |
| HostAddress: cm.Dst, |
| Interface: iface, |
| }, nil |
| } |
| |
| func (n *netbootQuery) parse(b []byte) (string, error) { |
| r := bytes.NewReader(b) |
| var res netbootMessage |
| if err := binary.Read(r, binary.LittleEndian, &res); err != nil { |
| return "", fmt.Errorf("query parse error: %v", err) |
| } |
| if res.Header.Magic != n.message.Header.Magic || res.Header.Cookie != n.message.Header.Cookie || res.Header.Cmd != cmdAck { |
| return "", nil |
| } |
| data, err := netbootString(res.Data[:]) |
| if err != nil { |
| return "", err |
| } |
| return data, nil |
| } |
| |
| func (n *netbootQuery) close() error { |
| return n.conn.Close() |
| } |
| |
| // Discover returns the netsvc address of nodename. |
| func (n *Client) Discover(ctx context.Context, nodename string) (*net.UDPAddr, error) { |
| ctx, cancel := context.WithTimeout(ctx, n.Timeout) |
| defer cancel() |
| t := make(chan *Target) |
| cleanup, err := n.StartDiscover(ctx, t, nodename) |
| if err != nil { |
| return nil, err |
| } |
| defer cleanup() |
| for { |
| select { |
| case <-ctx.Done(): |
| return nil, fmt.Errorf("discover waiting for results: %w", ctx.Err()) |
| case target := <-t: |
| if err := target.Error; err != nil { |
| return nil, err |
| } |
| ifaceName := "" |
| if target.Interface != nil { |
| ifaceName = target.Interface.Name |
| } |
| if nodename == NodenameWildcard { |
| logger.Debugf(ctx, "found nodename %s", target.Nodename) |
| } |
| return &net.UDPAddr{IP: target.TargetAddress, Zone: ifaceName}, nil |
| } |
| } |
| } |
| |
| // DiscoverAll returns all netsvc addresses on the local network. |
| // |
| // If no devices are found, returns a nil array. |
| func (n *Client) DiscoverAll(ctx context.Context) ([]*Target, error) { |
| ctx, cancel := context.WithTimeout(ctx, n.Timeout) |
| defer cancel() |
| t := make(chan *Target) |
| cleanup, err := n.StartDiscover(ctx, t, NodenameWildcard) |
| if err != nil { |
| return nil, err |
| } |
| defer cleanup() |
| results := []*Target{} |
| for { |
| select { |
| case <-ctx.Done(): |
| if len(results) == 0 { |
| return nil, nil |
| } |
| return results, nil |
| case target := <-t: |
| // If there's an error and devices hav already been found, |
| // this isn't an error. |
| if err := target.Error; err != nil && len(results) == 0 { |
| return nil, err |
| } |
| if err := target.Error; err == nil { |
| results = append(results, target) |
| } |
| } |
| } |
| } |
| |
| // StartDiscover takes a channel and returns every Target as it is found. Returns |
| // a cleanup function for closing the discovery connection or an error if there |
| // is a failure. |
| // |
| // Errors within discovery will be propagated up the channel via the Target.Error |
| // field. |
| // |
| // The Timeout field is not used with this function, and as such it is the |
| // caller's responsibility to handle timeouts when using this function. |
| // |
| // Example: |
| // ctx, cancel := context.WithTimeout(context.Background(), timeout) |
| // defer cancel() |
| // cleanup, err := c.StartDiscover(ctx, t, true) |
| // if err != nil { |
| // return err |
| // } |
| // defer cleanup() |
| // for { |
| // select { |
| // case target := <-t: |
| // // Do something with the target. |
| // case <-ctx.Done(): |
| // // Do something now that the parent context is |
| // // completed. |
| // } |
| // } |
| func (n *Client) StartDiscover(ctx context.Context, t chan<- *Target, nodename string) (func() error, error) { |
| n.Cookie++ |
| q, err := newNetbootQuery(nodename, n.Cookie, n.ServerPort) |
| if err != nil { |
| return nil, err |
| } |
| go func() { |
| defer q.close() |
| if err := q.write(ctx); err != nil { |
| t <- &Target{Error: err} |
| return |
| } |
| |
| for { |
| logger.Debugf(ctx, "discovering nodename=%s", nodename) |
| target, err := q.read() |
| if err != nil { |
| t <- &Target{Error: err} |
| return |
| } |
| if target != nil { |
| // Only skip if there's a name mismatch. |
| if nodename != NodenameWildcard && !strings.Contains(target.Nodename, nodename) { |
| logger.Debugf(ctx, "discarding response from %s, want=%s", target.Nodename, nodename) |
| continue |
| } |
| t <- target |
| } |
| // Simple cleanup to avoid extra cycles. |
| if !q.isOpen { |
| return |
| } |
| } |
| }() |
| return q.close, nil |
| } |
| |
| // Advertisement represents the information in an advertisement sent from a device. |
| type Advertisement struct { |
| Nodename string |
| BootloaderVersion string |
| } |
| |
| func (n *Client) beacon(conn *net.UDPConn) (*net.UDPAddr, *Advertisement, error) { |
| conn.SetReadDeadline(time.Now().Add(n.Timeout)) |
| |
| b := make([]byte, 4096) |
| _, addr, err := conn.ReadFromUDP(b) |
| if err != nil { |
| return nil, nil, err |
| } |
| msg, err := n.parseBeacon(b) |
| if err != nil { |
| return nil, nil, err |
| } |
| return addr, msg, err |
| } |
| |
| func (n *Client) parseBeacon(b []byte) (*Advertisement, error) { |
| r := bytes.NewReader(b) |
| var res netbootMessage |
| if err := binary.Read(r, binary.LittleEndian, &res); err != nil { |
| return nil, err |
| } |
| |
| data, err := netbootString(res.Data[:]) |
| if err != nil { |
| return nil, err |
| } |
| msg := &Advertisement{} |
| // The query packet payload contains fields separated by ;. |
| // Each field has a key=value format. |
| nodenameRegex := regexp.MustCompile("nodename=([^;]+)") |
| versionRegex := regexp.MustCompile("version=([^;]+)") |
| submatches := nodenameRegex.FindStringSubmatch(data) |
| if len(submatches) == 2 { |
| msg.Nodename = submatches[1] |
| } |
| submatches = versionRegex.FindStringSubmatch(data) |
| if len(submatches) == 2 { |
| msg.BootloaderVersion = submatches[1] |
| } |
| if msg.Nodename == "" || msg.BootloaderVersion == "" { |
| return nil, errors.New("no valid beacon") |
| } |
| |
| return msg, nil |
| } |
| |
| // BeaconOnInterface receives the beacon packet on a particular interface. |
| func (n *Client) BeaconOnInterface(networkInterface string) (*net.UDPAddr, error) { |
| conn, err := UDPConnWithReusablePort(n.AdvertPort, networkInterface, true) |
| if err != nil { |
| return nil, err |
| } |
| defer conn.Close() |
| addr, _, err := n.beacon(conn) |
| return addr, err |
| } |
| |
| func (n *Client) beaconForDevice(ctx context.Context, conn *net.UDPConn, nodename string, ipv6Addr *net.UDPAddr) (*net.UDPAddr, *Advertisement, error) { |
| for { |
| addr, msg, err := n.beacon(conn) |
| if err != nil { |
| return nil, nil, err |
| } |
| if ipv6Addr != nil { |
| if !reflect.DeepEqual(addr.IP, ipv6Addr.IP) || (ipv6Addr.Zone != "" && addr.Zone != ipv6Addr.Zone) { |
| logger.Debugf(ctx, "ignoring message not from allowed address %q", ipv6Addr) |
| continue |
| } |
| } |
| if nodename == NodenameWildcard { |
| logger.Debugf(ctx, "found nodename %s", msg.Nodename) |
| return addr, msg, err |
| } |
| if msg.Nodename != nodename { |
| logger.Debugf(ctx, "ignoring nodename %s (expecting %s)", msg.Nodename, nodename) |
| continue |
| } |
| return addr, msg, err |
| } |
| } |
| |
| // BeaconForDevice receives the beacon packet for a particular device with the given nodename or ipv6 address. |
| func (n *Client) BeaconForDevice(ctx context.Context, nodename string, ipv6Addr *net.UDPAddr, reusable bool) (*net.UDPAddr, *Advertisement, func(), error) { |
| ctx, cancel := context.WithTimeout(ctx, n.Timeout) |
| defer cancel() |
| conn, err := UDPConnWithReusablePort(n.AdvertPort, "", reusable) |
| if err != nil { |
| return nil, nil, func() {}, err |
| } |
| cleanup := func() { |
| conn.Close() |
| } |
| |
| type response struct { |
| addr *net.UDPAddr |
| msg *Advertisement |
| err error |
| } |
| r := make(chan response) |
| go func() { |
| addr, msg, err := n.beaconForDevice(ctx, conn, nodename, ipv6Addr) |
| r <- response{addr, msg, err} |
| }() |
| |
| select { |
| case <-ctx.Done(): |
| cleanup() |
| return nil, nil, func() {}, fmt.Errorf("beacon waiting for results: %w", ctx.Err()) |
| case resp := <-r: |
| if resp.err != nil { |
| cleanup() |
| return nil, nil, func() {}, resp.err |
| } |
| return resp.addr, resp.msg, cleanup, nil |
| } |
| } |
| |
| // Beacon receives the beacon packet, returning the address of the sender. |
| func (n *Client) Beacon() (*net.UDPAddr, error) { |
| conn, err := UDPConnWithReusablePort(n.AdvertPort, "", true) |
| if err != nil { |
| return nil, err |
| } |
| defer conn.Close() |
| addr, _, err := n.beacon(conn) |
| return addr, err |
| } |
| |
| // Boot sends a boot packet to the address. |
| func (n *Client) Boot(addr *net.UDPAddr) error { |
| n.Cookie++ |
| msg := &netbootHeader{ |
| Magic: magic, |
| Cookie: n.Cookie, |
| Cmd: cmdBoot, |
| Arg: 0, |
| } |
| if err := sendPacket(msg, addr, n.ServerPort); err != nil { |
| return fmt.Errorf("send boot command: %v\n", err) |
| } |
| return nil |
| } |
| |
| // Reboot sends a reboot packet the address. |
| func (n *Client) Reboot(addr *net.UDPAddr) error { |
| n.Cookie++ |
| msg := &netbootHeader{ |
| Magic: magic, |
| Cookie: n.Cookie, |
| Cmd: cmdReboot, |
| Arg: 0, |
| } |
| if err := sendPacket(msg, addr, n.ServerPort); err != nil { |
| return fmt.Errorf("send reboot command: %v\n", err) |
| } |
| return nil |
| } |
| |
| func sendPacket(msg *netbootHeader, addr *net.UDPAddr, port int) error { |
| if msg == nil { |
| return errors.New("no message provided") |
| } |
| var buf bytes.Buffer |
| if err := binary.Write(&buf, binary.LittleEndian, *msg); err != nil { |
| return err |
| } |
| |
| conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero}) |
| if err != nil { |
| return fmt.Errorf("create a socket: %v\n", err) |
| } |
| defer conn.Close() |
| |
| _, err = conn.WriteToUDP(buf.Bytes(), &net.UDPAddr{ |
| IP: addr.IP, |
| Port: port, |
| Zone: addr.Zone, |
| }) |
| return err |
| } |
| |
| func netbootString(bs []byte) (string, error) { |
| for i, b := range bs { |
| if b == 0 { |
| return string(bs[:i]), nil |
| } |
| } |
| return "", errors.New("no null terminated string found") |
| } |