[mdns] Implement basic mDNS lookup
We need to replace netaddr with something else. This library
attempts to implement mDNS in a OS agnostic way for use as
a host tool or in botanist.
US-545
Change-Id: I9df7540b3ff36dde3b8ec69326d06943419c89c7
diff --git a/cmd/mdnstool/main.go b/cmd/mdnstool/main.go
new file mode 100644
index 0000000..7a2e1da
--- /dev/null
+++ b/cmd/mdnstool/main.go
@@ -0,0 +1,191 @@
+package main
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "os"
+ "time"
+
+ "fuchsia.googlesource.com/tools/mdns"
+)
+
+// TODO(jakehehrlich): This doesn't retry or anything, it just times out. It would
+// be nice to make this more robust.
+// sends out a request for the ip of domain. Waits for up to |dur| amount of time.
+// Will stop waiting for a response if ctx is done. The default port for mDNS is
+// 5353 but you're allowed to specify via port.
+func mDNSResolve(ctx context.Context, domain string, port int, dur time.Duration) (net.IP, error) {
+ var m mdns.MDNS
+ out := make(chan net.IP)
+ // Add all of our handlers
+ m.AddHandler(func(iface net.Interface, addr net.Addr, packet mdns.Packet) {
+ for _, a := range packet.Answers {
+ if a.Class == mdns.IN && a.Type == mdns.A && a.Domain == domain {
+ out <- net.IP(a.Data)
+ return
+ }
+ }
+ })
+ m.AddWarningHandler(func(addr net.Addr, err error) {
+ log.Printf("from: %v warn: %v", addr, err)
+ })
+ errs := make(chan error)
+ m.AddErrorHandler(func(err error) {
+ errs <- err
+ })
+ // Before we start we need to add a context to timeout with
+ ctx, cancel := context.WithTimeout(ctx, dur)
+ defer cancel()
+ // Start up the mdns loop
+ if err := m.Start(ctx, port); err != nil {
+ return nil, fmt.Errorf("starting mdns: %v", err)
+ }
+ // Send a packet requesting an answer to "what is the IP of |domain|?"
+ m.Send(mdns.QuestionPacket(domain))
+ // Now wait for either a timeout, an error, or an answer.
+ select {
+ case <-ctx.Done():
+ return nil, fmt.Errorf("timeout")
+ case err := <-errs:
+ return nil, err
+ case ip := <-out:
+ return ip, nil
+ }
+}
+
+// TODO(jakehehrlich): Add support for unicast.
+// mDNSPublish will respond to requests for the ip of domain by responding with ip.
+// It is assumed that ip is an ipv4 address. You can stop the server by canceling
+// ctx. Even though mDNS is generally on 5353 you can specify any port via port.
+func mDNSPublish(ctx context.Context, domain string, port int, ip net.IP) error {
+ // Now create and mDNS server
+ var m mdns.MDNS
+ m.AddHandler(func(iface net.Interface, addr net.Addr, packet mdns.Packet) {
+ log.Printf("from %v packet %v", addr, packet)
+ for _, q := range packet.Questions {
+ if q.Class == mdns.IN && q.Type == mdns.A && q.Domain == domain {
+ // We ignore the Unicast bit here but in theory this could be handled via SendTo and addr.
+ m.Send(mdns.AnswerPacket(domain, ip))
+ }
+ }
+ })
+ m.AddWarningHandler(func(addr net.Addr, err error) {
+ log.Printf("from: %v warn: %v", addr, err)
+ })
+ errs := make(chan error)
+ m.AddErrorHandler(func(err error) {
+ errs <- err
+ })
+ // Now start the server.
+ if err := m.Start(ctx, port); err != nil {
+ return fmt.Errorf("starting mdns: %v", err)
+ }
+ // Now wait for either a timeout, an error, or an answer.
+ select {
+ case <-ctx.Done():
+ return nil
+ case err := <-errs:
+ return err
+ }
+}
+
+// This function makes a faulty assumption. It assumes that the first
+// multicast interface it finds with an ipv4 address will be the
+// address the user wants. There isn't really a way to guess exactly
+// the address that the user will want.
+func getMulticastIP() net.IP {
+ ifaces, err := net.Interfaces()
+ if err != nil {
+ return nil
+ }
+ for _, i := range ifaces {
+ if i.Flags&net.FlagMulticast == 0 {
+ continue
+ }
+ addrs, err := i.Addrs()
+ if err != nil {
+ return nil
+ }
+ for _, addr := range addrs {
+ switch v := addr.(type) {
+ case *net.IPNet:
+ if ip4 := v.IP.To4(); ip4 != nil {
+ return ip4
+ }
+ }
+ }
+ }
+ return nil
+}
+
+var (
+ port int
+ timeout int
+ ipAddr string
+)
+
+func init() {
+ flag.StringVar(&ipAddr, "ip", "", "the ip to respond with when servering.")
+ flag.IntVar(&port, "port", 5353, "the port your mDNS servers operate on")
+ flag.IntVar(&timeout, "timeout", 2000, "the number of milliseconds before declaring a timeout")
+}
+
+func publish(args ...string) error {
+ if len(args) < 1 {
+ return fmt.Errorf("missing domain to serve")
+ }
+ var ip net.IP
+ if ipAddr == "" {
+ ip = getMulticastIP()
+ if ip = getMulticastIP(); ip == nil {
+ return fmt.Errorf("could not find a suitable ip")
+ }
+ } else {
+ if ip = net.ParseIP(ipAddr); ip == nil {
+ return fmt.Errorf("'%s' is not a valid ip address", ipAddr)
+ }
+ }
+ domain := args[0]
+ if err := mDNSPublish(context.Background(), domain, port, ip); err != nil {
+ return err
+ }
+ return nil
+}
+
+func resolve(args ...string) error {
+ if len(args) < 1 {
+ return fmt.Errorf("missing domain to request")
+ }
+ domain := args[0]
+ ip, err := mDNSResolve(context.Background(), domain, port, time.Duration(timeout)*time.Millisecond)
+ if err != nil {
+ return err
+ }
+ fmt.Printf("%v\n", ip)
+ return nil
+}
+
+func main() {
+ flag.Parse()
+ args := flag.Args()
+ if len(args) < 1 {
+ log.Printf("error: no command given")
+ os.Exit(1)
+ }
+ mp := map[string]func(...string) error{
+ "publish": publish,
+ "resolve": resolve,
+ }
+ if f, ok := mp[args[0]]; ok {
+ if err := f(args[1:]...); err != nil {
+ log.Printf("error: %v", err)
+ }
+ return
+ } else {
+ log.Printf("error: %s is not a command", args[0])
+ }
+ os.Exit(1)
+}
diff --git a/go.mod b/go.mod
index 23be58f..0d3c6ee 100644
--- a/go.mod
+++ b/go.mod
@@ -4,5 +4,6 @@
github.com/google/subcommands v0.0.0-20181012225330-46f0354f6315
github.com/google/uuid v1.1.0
go.chromium.org/luci v0.0.0-20181205024016-0c89bd1bcf4f
+ golang.org/x/net v0.0.0-20181029044818-c44066c5c816
golang.org/x/sys v0.0.0-20181128092732-4ed8d59d0b35
)
diff --git a/go.sum b/go.sum
index 4766bd7..664dcc7 100644
--- a/go.sum
+++ b/go.sum
@@ -4,5 +4,7 @@
github.com/google/uuid v1.1.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
go.chromium.org/luci v0.0.0-20181205024016-0c89bd1bcf4f h1:YHpVpP44+EX/FdtkLuityBjE7JHl6Yhm9SVVToihk+s=
go.chromium.org/luci v0.0.0-20181205024016-0c89bd1bcf4f/go.mod h1:MIQewVTLvOvc0UioV0JNqTNO/RspKFS0XEeoKrOxsdM=
+golang.org/x/net v0.0.0-20181029044818-c44066c5c816 h1:mVFkLpejdFLXVUv9E42f3XJVfMdqd0IVLVIVLjZWn5o=
+golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/sys v0.0.0-20181128092732-4ed8d59d0b35 h1:YAFjXN64LMvktoUZH9zgY4lGc/msGN7HQfoSuKCgaDU=
golang.org/x/sys v0.0.0-20181128092732-4ed8d59d0b35/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
diff --git a/mdns/mdns.go b/mdns/mdns.go
new file mode 100644
index 0000000..eff713c
--- /dev/null
+++ b/mdns/mdns.go
@@ -0,0 +1,598 @@
+package mdns
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "strings"
+ "syscall"
+ "unicode/utf8"
+
+ "golang.org/x/net/ipv4"
+)
+
+type Header struct {
+ ID uint16
+ Flags uint16
+ QDCount uint16
+ ANCount uint16
+ NSCount uint16
+ ARCount uint16
+}
+
+type Record struct {
+ Domain string
+ Type uint16
+ Class uint16
+ Flush bool
+ TTL uint32
+ Data []byte
+}
+
+type Question struct {
+ Domain string
+ Type uint16
+ Class uint16
+ Unicast bool
+}
+
+type Packet struct {
+ Header Header
+ Questions []Question
+ Answers []Record
+ Authority []Record
+ Additional []Record
+}
+
+func writeUint16(out io.Writer, val uint16) error {
+ buf := make([]byte, 2)
+ binary.BigEndian.PutUint16(buf, val)
+ _, err := out.Write(buf)
+ return err
+}
+
+func (h Header) serialize(out io.Writer) error {
+ if err := writeUint16(out, h.ID); err != nil {
+ return err
+ }
+ if err := writeUint16(out, h.Flags); err != nil {
+ return err
+ }
+ if err := writeUint16(out, h.QDCount); err != nil {
+ return err
+ }
+ if err := writeUint16(out, h.ANCount); err != nil {
+ return err
+ }
+ if err := writeUint16(out, h.NSCount); err != nil {
+ return err
+ }
+ if err := writeUint16(out, h.ARCount); err != nil {
+ return err
+ }
+ return nil
+}
+
+func writeDomain(out io.Writer, domain string) error {
+ domain = strings.TrimSuffix(domain, ".")
+ parts := strings.Split(domain, ".")
+ // TODO(jakehehrlich): Add check that each label is ASCII.
+ // TODO(jakehehrlich): Add check that each label is <= 63 in length.
+ // TODO(jakehehrlich): Add support for compression.
+ for _, dpart := range parts {
+ ascii := []byte(dpart)
+ if _, err := out.Write([]byte{byte(len(ascii))}); err != nil {
+ return err
+ }
+ if _, err := out.Write(ascii); err != nil {
+ return err
+ }
+ }
+ _, err := out.Write([]byte{0})
+ return err
+}
+
+func (q Question) serialize(out io.Writer) error {
+ if err := writeDomain(out, q.Domain); err != nil {
+ return err
+ }
+ if err := writeUint16(out, q.Type); err != nil {
+ return err
+ }
+ var unicast uint16
+ if q.Unicast {
+ unicast = 1 << 15
+ }
+ if err := writeUint16(out, unicast|q.Class); err != nil {
+ return err
+ }
+ return nil
+}
+
+func writeUint32(out io.Writer, val uint32) error {
+ buf := make([]byte, 4)
+ binary.BigEndian.PutUint32(buf, val)
+ _, err := out.Write(buf)
+ return err
+}
+
+func (r Record) serialize(out io.Writer) error {
+ if err := writeDomain(out, r.Domain); err != nil {
+ return err
+ }
+ if err := writeUint16(out, r.Type); err != nil {
+ return err
+ }
+ var flush uint16
+ if r.Flush {
+ flush = 1 << 15
+ }
+ if err := writeUint16(out, flush|r.Class); err != nil {
+ return err
+ }
+ if err := writeUint32(out, r.TTL); err != nil {
+ return err
+ }
+ if err := writeUint16(out, uint16(len(r.Data))); err != nil {
+ return err
+ }
+ if _, err := out.Write(r.Data); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (p Packet) serialize(out io.Writer) error {
+ if err := p.Header.serialize(out); err != nil {
+ return err
+ }
+ for _, question := range p.Questions {
+ if err := question.serialize(out); err != nil {
+ return err
+ }
+ }
+ for _, answer := range p.Answers {
+ if err := answer.serialize(out); err != nil {
+ return err
+ }
+ }
+ for _, authority := range p.Authority {
+ if err := authority.serialize(out); err != nil {
+ return err
+ }
+ }
+ for _, addon := range p.Additional {
+ if err := addon.serialize(out); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func readUint16(in io.Reader, out *uint16) error {
+ buf := make([]byte, 2)
+ _, err := in.Read(buf)
+ if err != nil {
+ return err
+ }
+ *out = binary.BigEndian.Uint16(buf)
+ return nil
+}
+
+func (h *Header) deserialize(data []byte, in io.Reader) error {
+ if err := readUint16(in, &h.ID); err != nil {
+ return err
+ }
+ if err := readUint16(in, &h.Flags); err != nil {
+ return err
+ }
+ if err := readUint16(in, &h.QDCount); err != nil {
+ return err
+ }
+ if err := readUint16(in, &h.ANCount); err != nil {
+ return err
+ }
+ if err := readUint16(in, &h.NSCount); err != nil {
+ return err
+ }
+ if err := readUint16(in, &h.ARCount); err != nil {
+ return err
+ }
+ return nil
+}
+
+func readDomain(data []byte, in io.Reader, domain *string) error {
+ // TODO(jakehehrlich): Don't stack overflow when domain contains cycle.
+
+ var d bytes.Buffer
+ for {
+ sizeBuf := make([]byte, 1)
+ if _, err := in.Read(sizeBuf); err != nil {
+ return err
+ }
+ size := sizeBuf[0]
+ // A size of zero indicates that we're done.
+ if size == 0 {
+ break
+ }
+ // We don't support compressed domains right now.
+ if size > 63 {
+ if size < 192 {
+ return fmt.Errorf("invalid size for label")
+ }
+ if _, err := in.Read(sizeBuf); err != nil {
+ return err
+ }
+ offset := ((size & 0x3f) << 8) | sizeBuf[0]
+ var pDomain string
+ readDomain(data, bytes.NewBuffer(data[offset:]), &pDomain)
+ if _, err := d.WriteString(pDomain); err != nil {
+ return err
+ }
+ if err := d.WriteByte(byte('.')); err != nil {
+ return err
+ }
+ break
+ }
+ // Read in the specified bytes (max length 256)
+ buf := make([]byte, size)
+ if _, err := in.Read(buf); err != nil {
+ return err
+ }
+ // Make sure the string is ASCII
+ for _, b := range buf {
+ if b >= utf8.RuneSelf {
+ return fmt.Errorf("Found non-ASCII byte %v in domain", b)
+ }
+ }
+ // Now add this to a temporary domain
+ if _, err := d.Write(buf); err != nil {
+ return err
+ }
+ // Add the trailing "." as seen in the RFC.
+ if err := d.WriteByte(byte('.')); err != nil {
+ return err
+ }
+ }
+ *domain = string(d.Bytes())
+ // Remove the trailing '.' to canonicalize.
+ *domain = strings.TrimSuffix(*domain, ".")
+ return nil
+}
+
+func (q *Question) deserialize(data []byte, in io.Reader) error {
+ if err := readDomain(data, in, &q.Domain); err != nil {
+ return fmt.Errorf("reading domain: %v", err)
+ }
+ if err := readUint16(in, &q.Type); err != nil {
+ return err
+ }
+ var tmp uint16
+ if err := readUint16(in, &tmp); err != nil {
+ return err
+ }
+ // Extract class and unicast bit.
+ q.Unicast = (tmp >> 15) != 0
+ q.Class = (tmp << 1) >> 1
+ return nil
+}
+
+func readUint32(in io.Reader, out *uint32) error {
+ buf := make([]byte, 4)
+ _, err := in.Read(buf)
+ if err != nil {
+ return err
+ }
+ *out = binary.BigEndian.Uint32(buf)
+ return nil
+}
+
+func (r *Record) deserialize(data []byte, in io.Reader) error {
+ if err := readDomain(data, in, &r.Domain); err != nil {
+ return err
+ }
+ if err := readUint16(in, &r.Type); err != nil {
+ return err
+ }
+ var tmp uint16
+ if err := readUint16(in, &tmp); err != nil {
+ return err
+ }
+ // Extract class and flush bit.
+ r.Flush = (tmp >> 15) != 0
+ r.Class = (tmp << 1) >> 1
+ if err := readUint32(in, &r.TTL); err != nil {
+ return err
+ }
+
+ var dataLength uint16
+ if err := readUint16(in, &dataLength); err != nil {
+ return err
+ }
+ // Now read the data (max allocation size of 64k)
+ r.Data = make([]byte, dataLength)
+ if _, err := in.Read(r.Data); err != nil {
+ return err
+ }
+ return nil
+}
+
+// TODO(jakehehrlich): Handle truncation.
+func (p *Packet) deserialize(data []byte, in io.Reader) error {
+ if err := p.Header.deserialize(data, in); err != nil {
+ return err
+ }
+ p.Questions = make([]Question, p.Header.QDCount)
+ for i := uint16(0); i < p.Header.QDCount; i++ {
+ if err := p.Questions[i].deserialize(data, in); err != nil {
+ return err
+ }
+ }
+ p.Answers = make([]Record, p.Header.ANCount)
+ for i := uint16(0); i < p.Header.ANCount; i++ {
+ if err := p.Answers[i].deserialize(data, in); err != nil {
+ return err
+ }
+ }
+ p.Authority = make([]Record, p.Header.NSCount)
+ for i := uint16(0); i < p.Header.NSCount; i++ {
+ if err := p.Authority[i].deserialize(data, in); err != nil {
+ return err
+ }
+ }
+ p.Additional = make([]Record, p.Header.ARCount)
+ for i := uint16(0); i < p.Header.ARCount; i++ {
+ if err := p.Additional[i].deserialize(data, in); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// getFlag constructs the flag field of a header for the tiny subset of
+// flag options that we need.
+// TODO(jakehehrlich): Implement response code error handling.
+// TODO(jakehehrlich): Implement truncation.
+func getFlag(query bool, authority bool) uint16 {
+ var out uint16
+ if !query {
+ out |= 1
+ }
+ if authority {
+ out |= 1 << 5
+ }
+ return out
+}
+
+const (
+ // A is the DNS Type for ipv4
+ A = 1
+ // AAAA is the DNS Type for ipv6
+ AAAA = 28
+ // PTR is the DNS Type for domain name pointers
+ PTR = 12
+ // SRV is the DNS Type for services
+ SRV = 33
+ // IN is the Internet DNS Class
+ IN = 1
+)
+
+// MDNS is the central type though which requests are sent and received.
+// This implementation is agnostic to use case and asynchronous.
+// To handle various responses add Handlers. To send a packet you may use
+// either SendTo (generally used for unicast) or Send (generally used for
+// multicast).
+type MDNS struct {
+ conn *ipv4.PacketConn
+ senders []net.PacketConn
+ port int
+ pHandlers []func(net.Interface, net.Addr, Packet)
+ wHandlers []func(net.Addr, error)
+ eHandlers []func(error)
+}
+
+// AddHandler calls f on every Packet received.
+func (m *MDNS) AddHandler(f func(net.Interface, net.Addr, Packet)) {
+ m.pHandlers = append(m.pHandlers, f)
+}
+
+// AddWarningHandler calls f on every non-fatal error.
+func (m *MDNS) AddWarningHandler(f func(net.Addr, error)) {
+ m.wHandlers = append(m.wHandlers, f)
+}
+
+// AddErrorHandler calls f on every fatal error. After
+// all active handlers are called, m will stop listening and
+// close it's connection so this function will not be called twice.
+func (m *MDNS) AddErrorHandler(f func(error)) {
+ m.eHandlers = append(m.eHandlers, f)
+}
+
+// SendTo serializes and sends packet to dst. If dst is a multicast
+// address then packet is multicast to the corresponding group on
+// all interfaces. Note that start must be called prior to making this
+// call.
+func (m *MDNS) SendTo(packet Packet, dst *net.UDPAddr) error {
+ var buf bytes.Buffer
+ // TODO(jakehehrlich): Add checking that the packet is well formed.
+ if err := packet.serialize(&buf); err != nil {
+ return err
+ }
+ for _, sender := range m.senders {
+ if _, err := sender.WriteTo(buf.Bytes(), dst); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+var mdnsMulticastIPv4 = net.ParseIP("224.0.0.251")
+
+// Send serializes and sends packet out as a multicast to all interfaces
+// using the port that m is listening on. Note that Start must be
+// called prior to making this call.
+func (m *MDNS) Send(packet Packet) error {
+ dst := net.UDPAddr{IP: mdnsMulticastIPv4, Port: m.port}
+ return m.SendTo(packet, &dst)
+}
+
+func makeUnixIpv4Socket(port int, ip net.IP) (net.PacketConn, error) {
+ fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
+ if err != nil {
+ return nil, fmt.Errorf("creating socket: %v", err)
+ }
+ // SO_REUSEADDR and SO_REUSEPORT allows binding to the same port multiple
+ // times which is necessary in the case when there are multiple instances.
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, 0x2 /*SO_REUSEADDR*/, 1); err != nil {
+ syscall.Close(fd)
+ return nil, fmt.Errorf("setting reuse addr: %v", err)
+ }
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, 0xf /*SO_REUSEPORT*/, 1); err != nil {
+ syscall.Close(fd)
+ return nil, fmt.Errorf("setting reuse port: %v", err)
+ }
+ // Bind the socket to the specified port.
+ var ipArray [4]byte
+ copy(ipArray[:], []byte(ip))
+ if err := syscall.Bind(fd, &syscall.SockaddrInet4{Addr: ipArray, Port: port}); err != nil {
+ syscall.Close(fd)
+ return nil, fmt.Errorf("binding to %v: %v", ip, err)
+ }
+ // Now make a socket.
+ f := os.NewFile(uintptr(fd), "")
+ conn, err := net.FilePacketConn(f)
+ f.Close()
+ if err != nil {
+ return nil, fmt.Errorf("creating packet conn: %v", err)
+ }
+ return conn, nil
+}
+
+// Start causes m to start listening for MDNS packets on all interfaces on
+// the specified port. Listening will stop if ctx is done.
+func (m *MDNS) Start(ctx context.Context, port int) error {
+ dst := &net.UDPAddr{IP: mdnsMulticastIPv4, Port: port}
+ conn, err := net.ListenUDP("udp4", dst)
+ if err != nil {
+ return err
+ }
+ // Now we need a low level ipv4 packet connection.
+ m.conn = ipv4.NewPacketConn(conn)
+ m.conn.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
+ m.port = port
+ // Now we need to join this connection to every interface that supports
+ // Multicast.
+ ifaces, err := net.Interfaces()
+ if err != nil {
+ conn.Close()
+ return fmt.Errorf("listing interfaces: %v", err)
+ }
+ // We need to make sure to handle each interface.
+ for _, iface := range ifaces {
+ if iface.Flags&net.FlagMulticast == 0 || iface.Flags&net.FlagUp == 0 {
+ continue
+ }
+ // This allows us to listen on this specific interface.
+ if err := m.conn.JoinGroup(&iface, dst); err != nil {
+ conn.Close()
+ return fmt.Errorf("joining %v%%%v: %v", iface, dst, err)
+ }
+ addrs, err := iface.Addrs()
+ if err != nil {
+ return fmt.Errorf("getting addresses of %v: %v", iface, err)
+ }
+ for _, addr := range addrs {
+ var ip net.IP
+ switch v := addr.(type) {
+ case *net.IPNet:
+ ip = v.IP
+ case *net.IPAddr:
+ ip = v.IP
+ }
+ if ip == nil || ip.To4() == nil {
+ continue
+ }
+ conn, err := makeUnixIpv4Socket(port, ip.To4())
+ if err != nil {
+ return fmt.Errorf("creating socket for %v via %v: %v", iface, ip, err)
+ }
+ m.senders = append(m.senders, conn)
+ break
+ }
+ }
+ go func() {
+ defer conn.Close()
+ // Now that we've joined every possible interface we can handle the main loop.
+ payloadBuf := make([]byte, 1<<16)
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+ size, cm, src, err := m.conn.ReadFrom(payloadBuf)
+ if err != nil {
+ for _, e := range m.eHandlers {
+ go e(err)
+ }
+ return
+ }
+ iface, err := net.InterfaceByIndex(cm.IfIndex)
+ if err != nil {
+ for _, e := range m.eHandlers {
+ go e(err)
+ }
+ return
+ }
+ var packet Packet
+ data := payloadBuf[:size]
+ if err := packet.deserialize(data, bytes.NewBuffer(data)); err != nil {
+ for _, w := range m.wHandlers {
+ go w(src, err)
+ }
+ continue
+ }
+ for _, p := range m.pHandlers {
+ go p(*iface, src, packet)
+ }
+ }
+ }()
+ return nil
+}
+
+// QuestionPacket constructs and returns a packet that
+// requests the ip address associated with domain.
+func QuestionPacket(domain string) Packet {
+ return Packet{
+ Header: Header{QDCount: 1},
+ Questions: []Question{
+ Question{
+ Domain: domain,
+ Type: A,
+ Class: IN,
+ Unicast: false,
+ },
+ },
+ }
+}
+
+// AnswerPacket constructs and returns a packet that
+// gives a response to the
+func AnswerPacket(domain string, ip net.IP) Packet {
+ return Packet{
+ Header: Header{ANCount: 1},
+ Answers: []Record{
+ Record{
+ Domain: domain,
+ Type: A,
+ Class: IN,
+ Flush: false,
+ Data: []byte(ip),
+ },
+ },
+ }
+}
diff --git a/mdns/mdns_test.go b/mdns/mdns_test.go
new file mode 100644
index 0000000..8832eb8
--- /dev/null
+++ b/mdns/mdns_test.go
@@ -0,0 +1,117 @@
+package mdns
+
+import (
+ "bytes"
+ "testing"
+)
+
+func testUint16(t *testing.T) {
+ var buf bytes.Buffer
+ v := uint16(6857)
+ writeUint16(&buf, v)
+ var v2 uint16
+ readUint16(&buf, &v2)
+ if v != v2 {
+ t.Fatal()
+ }
+}
+
+func testUint32(t *testing.T) {
+ var buf bytes.Buffer
+ v := uint32(6857)
+ writeUint32(&buf, v)
+ var v2 uint32
+ readUint32(&buf, &v2)
+ if v != v2 {
+ t.Fatal()
+ }
+}
+
+func testHeader(t *testing.T) {
+ var buf bytes.Buffer
+ v := Header{
+ ID: 593,
+ Flags: 795,
+ QDCount: 5839,
+ ANCount: 9009,
+ NSCount: 8583,
+ ARCount: 7764,
+ }
+ v.serialize(&buf)
+ var v2 Header
+ v.deserialize(buf.Bytes(), &buf)
+ if v != v2 {
+ t.Fatal()
+ }
+}
+
+func testDomain(t *testing.T) {
+ var buf bytes.Buffer
+ v := "this.is.a.random.domain.to.check"
+ writeDomain(&buf, v)
+ var v2 string
+ readDomain(buf.Bytes(), &buf, &v2)
+ if v != v2 {
+ t.Fatal()
+ }
+}
+
+func testQuestion(t *testing.T) {
+ var buf bytes.Buffer
+ v := Question{
+ Domain: "some.random.thing.local",
+ Type: 5954,
+ Unicast: true,
+ }
+ v.serialize(&buf)
+ var v2 Question
+ v2.deserialize(buf.Bytes(), &buf)
+ if v != v2 {
+ t.Fatal()
+ }
+}
+
+func equalBytes(a, b []byte) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i, ai := range a {
+ if ai != b[i] {
+ return false
+ }
+ }
+ return true
+}
+
+func testRecord(t *testing.T) {
+ var buf bytes.Buffer
+ v := Record{
+ Domain: "some.random.thing",
+ Type: 1234,
+ Class: 8765,
+ Flush: true,
+ TTL: 18656,
+ Data: []byte{45, 145, 253, 167, 34, 74},
+ }
+ v.serialize(&buf)
+ var v2 Record
+ v2.deserialize(buf.Bytes(), &buf)
+ if v.Domain != v2.Domain {
+ t.Fatal()
+ }
+ if v.Type != v2.Type {
+ t.Fatal()
+ }
+ if v.Class != v2.Class {
+ t.Fatal()
+ }
+ if v.Flush != v2.Flush {
+ t.Fatal()
+ }
+ if v.TTL != v2.TTL {
+ t.Fatal()
+ }
+ if !equalBytes(v.Data, v2.Data) {
+ t.Fatal()
+ }
+}