[mdns] Implement IPv6 listening.

Modify the implementation to support listening on IPv6
multicast addresses as well.

Add another question record to each QuestionPacket()
result to ensure each one asks for both the IPv4 and IPv6
address of a given domain.

Note however that AnswerPacket() creates a packet with a
single answer record, whose record type (A or AAAA) depends
on the configured IP address.

+ Modify dev_finder to listen on both IPv4 and IPv6 addresses.

+ Modify mdnstool so that the "publish" and "resolve" command
  support both IPv4 and IPv6 addresses. Note that "publish"
  can support only one address type at a time.

BUG=DX-1664

Change-Id: Iddc9439f86c1958a131bf07e0b856e8bdcb2f85d
diff --git a/botanist/ip.go b/botanist/ip.go
index 267aa97..17aa44b 100644
--- a/botanist/ip.go
+++ b/botanist/ip.go
@@ -27,6 +27,7 @@
 // TODO(joshuaseaton): Refactor dev_finder to share 'resolve' logic with botanist.
 func ResolveIPv4(ctx context.Context, nodename string, timeout time.Duration) (net.IP, error) {
 	m := mdns.NewMDNS()
+	m.EnableIPv4()
 	out := make(chan net.IP)
 	domain := getLocalDomain(nodename)
 	m.AddHandler(func(iface net.Interface, addr net.Addr, packet mdns.Packet) {
diff --git a/cmd/dev_finder/common.go b/cmd/dev_finder/common.go
index 36780e1..2641ef4 100644
--- a/cmd/dev_finder/common.go
+++ b/cmd/dev_finder/common.go
@@ -93,7 +93,7 @@
 
 func (cmd *devFinderCmd) SetCommonFlags(f *flag.FlagSet) {
 	f.BoolVar(&cmd.json, "json", false, "Outputs in JSON format.")
-	f.StringVar(&cmd.mdnsAddrs, "addr", "224.0.0.251,224.0.0.250", "Comma separated list of addresses to issue mDNS queries to.")
+	f.StringVar(&cmd.mdnsAddrs, "addr", "224.0.0.251,224.0.0.250,ff02::fb", "Comma separated list of addresses to issue mDNS queries to.")
 	f.StringVar(&cmd.mdnsPorts, "port", "5353,5356", "Comma separated list of ports to issue mDNS queries to.")
 	f.IntVar(&cmd.timeout, "timeout", 2000, "The number of milliseconds before declaring a timeout.")
 	f.BoolVar(&cmd.localResolve, "local", false, "Returns the address of the interface to the host when doing service lookup/domain resolution.")
@@ -125,6 +125,12 @@
 		return cmd.newMDNSFunc(address)
 	}
 	m := mdns.NewMDNS()
+	ip := net.ParseIP(address)
+	if ip.To4() != nil {
+		m.EnableIPv4()
+	} else {
+		m.EnableIPv6()
+	}
 	m.SetAddress(address)
 	return m
 }
diff --git a/cmd/mdnstool/main.go b/cmd/mdnstool/main.go
index 3cb33ea..67048c1 100644
--- a/cmd/mdnstool/main.go
+++ b/cmd/mdnstool/main.go
@@ -23,11 +23,16 @@
 // 5353 but you're allowed to specify via port.
 func mDNSResolve(ctx context.Context, domain string, port int, dur time.Duration) (net.IP, error) {
 	m := mdns.NewMDNS()
+	m.EnableIPv4()
+	m.EnableIPv6()
 	out := make(chan net.IP)
 	// Add all of our handlers
 	m.AddHandler(func(iface net.Interface, addr net.Addr, packet mdns.Packet) {
 		for _, a := range packet.Answers {
-			if a.Class == mdns.IN && a.Type == mdns.A && a.Domain == domain {
+			if a.Class != mdns.IN || a.Domain != domain {
+				continue
+			}
+			if a.Type == mdns.A || a.Type == mdns.AAAA {
 				out <- net.IP(a.Data)
 				return
 			}
@@ -62,15 +67,19 @@
 
 // TODO(jakehehrlich): Add support for unicast.
 // mDNSPublish will respond to requests for the ip of domain by responding with ip.
-// It is assumed that ip is an ipv4 address. You can stop the server by canceling
-// ctx. Even though mDNS is generally on 5353 you can specify any port via port.
+// Note that this responds on both IPv4 and IPv6 interfaces, independent on the type
+// of ip itself. You can stop the server by canceling ctx. Even though mDNS is
+// generally on 5353 you can specify any port via port.
 func mDNSPublish(ctx context.Context, domain string, port int, ip net.IP) error {
 	// Now create and mDNS server
 	m := mdns.NewMDNS()
+	m.EnableIPv4()
+	m.EnableIPv6()
+	addrType := mdns.IpToDnsRecordType(ip)
 	m.AddHandler(func(iface net.Interface, addr net.Addr, packet mdns.Packet) {
 		log.Printf("from %v packet %v", addr, packet)
 		for _, q := range packet.Questions {
-			if q.Class == mdns.IN && q.Type == mdns.A && q.Domain == domain {
+			if q.Class == mdns.IN && q.Type == addrType && q.Domain == domain {
 				// We ignore the Unicast bit here but in theory this could be handled via SendTo and addr.
 				m.Send(mdns.AnswerPacket(domain, ip))
 			}
@@ -99,7 +108,8 @@
 // This function makes a faulty assumption. It assumes that the first
 // multicast interface it finds with an ipv4 address will be the
 // address the user wants. There isn't really a way to guess exactly
-// the address that the user will want.
+// the address that the user will want. If an IPv6 address is needed, then
+// using the -ip <address> option is required.
 func getMulticastIP() net.IP {
 	ifaces, err := net.Interfaces()
 	if err != nil {
@@ -132,7 +142,7 @@
 )
 
 func init() {
-	flag.StringVar(&ipAddr, "ip", "", "the ip to respond with when servering.")
+	flag.StringVar(&ipAddr, "ip", "", "the ip to respond with when serving.")
 	flag.IntVar(&port, "port", 5353, "the port your mDNS servers operate on")
 	flag.IntVar(&timeout, "timeout", 2000, "the number of milliseconds before declaring a timeout")
 }
diff --git a/mdns/mdns.go b/mdns/mdns.go
index d33c15e..bc1a353 100644
--- a/mdns/mdns.go
+++ b/mdns/mdns.go
@@ -16,6 +16,7 @@
 	"unicode/utf8"
 
 	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
 	"golang.org/x/sys/unix"
 )
 
@@ -388,12 +389,27 @@
 	IN = 1
 )
 
+// IpToDnsRecordType returns either A or AAAA based on the type of ip.
+func IpToDnsRecordType(ip net.IP) uint16 {
+	if ip4 := ip.To4(); ip4 != nil {
+		return A
+	} else {
+		return AAAA
+	}
+}
+
 // MDNS is the central interface through which requests are sent and received.
 // This implementation is agnostic to use case and asynchronous.
 // To handle various responses add Handlers. To send a packet you may use
 // either SendTo (generally used for unicast) or Send (generally used for
 // multicast).
 type MDNS interface {
+	// EnableIPv4 enables listening on IPv4 network interfaces.
+	EnableIPv4()
+
+	// EnableIPv6 enables listening on IPv6 network interfaces.
+	EnableIPv6()
+
 	// SetAddress sets a non-default listen address.
 	SetAddress(address string) error
 
@@ -447,14 +463,6 @@
 	senders []net.PacketConn
 }
 
-func (c *mDNSConnBase) SetIp(ip net.IP) error {
-	if ip4 := ip.To4(); ip4 == nil {
-		panic(fmt.Errorf("Not an IPv-4 address: %v", ip))
-	}
-	c.dst.IP = ip
-	return nil
-}
-
 func (c *mDNSConnBase) getIp() net.IP {
 	return c.dst.IP
 }
@@ -485,6 +493,14 @@
 	return &c
 }
 
+func (c *mDNSConn4) SetIp(ip net.IP) error {
+	if ip4 := ip.To4(); ip4 == nil {
+		panic(fmt.Errorf("Not an IPv-4 address: %v", ip))
+	}
+	c.dst.IP = ip
+	return nil
+}
+
 func (c *mDNSConn4) Close() error {
 	if c.conn != nil {
 		err := c.conn.Close()
@@ -537,22 +553,119 @@
 	return
 }
 
+type mDNSConn6 struct {
+	mDNSConnBase
+	conn *ipv6.PacketConn
+}
+
+var defaultMDNSMulticastIPv6 = net.ParseIP("ff02::fb")
+
+func newMDNSConn6() mDNSConn {
+	c := mDNSConn6{}
+	c.SetIp(defaultMDNSMulticastIPv6)
+	return &c
+}
+
+func (c *mDNSConn6) SetIp(ip net.IP) error {
+	if ip6 := ip.To16(); ip6 == nil {
+		panic(fmt.Errorf("Not an IPv6 address: %v", ip))
+	}
+	c.dst.IP = ip
+	return nil
+}
+
+func (c *mDNSConn6) Close() error {
+	if c.conn != nil {
+		err := c.conn.Close()
+		c.conn = nil
+		return err
+	}
+	return nil
+}
+
+func (c *mDNSConn6) Listen(port int) error {
+	c.dst.Port = port
+	conn, err := net.ListenUDP("udp6", &c.dst)
+	if err != nil {
+		return err
+	}
+	// Now we need a low level ipv4 packet connection.
+	c.conn = ipv6.NewPacketConn(conn)
+	c.conn.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
+	return nil
+}
+
+// This allows us to listen on this specific interface.
+func (c *mDNSConn6) JoinGroup(iface net.Interface) error {
+	if err := c.conn.JoinGroup(&iface, &c.dst); err != nil {
+		c.Close()
+		return fmt.Errorf("joining %v%%%v: %v", iface, c.dst, err)
+	}
+	return nil
+}
+
+func (c *mDNSConn6) ConnectTo(port int, ip net.IP, iface *net.Interface) error {
+	ip6 := ip.To16()
+	if ip6 == nil {
+		return fmt.Errorf("Not a valid IPv6 address: %v", ip)
+	}
+	conn, err := makeUdpSocketWithReusePort(port, ip6, iface)
+	if err != nil {
+		return err
+	}
+	c.senders = append(c.senders, conn)
+	return nil
+}
+
+func (c *mDNSConn6) ReadFrom(buf []byte) (size int, iface *net.Interface, src net.Addr, err error) {
+	var cm *ipv6.ControlMessage
+	size, cm, src, err = c.conn.ReadFrom(buf)
+	if err == nil {
+		iface, err = net.InterfaceByIndex(cm.IfIndex)
+	}
+	return
+}
+
 type mDNS struct {
 	conn4     mDNSConn
+	conn6     mDNSConn
 	port      int
 	pHandlers []func(net.Interface, net.Addr, Packet)
 	wHandlers []func(net.Addr, error)
 	eHandlers []func(error)
 }
 
+// NewMDNS creates a new object implementing the MDNS interface. Do not forget
+// to call EnableIPv4() or EnableIPv6() to enable listening on interfaces of
+// the corresponding type, or nothing will work.
 func NewMDNS() MDNS {
 	m := mDNS{}
-	m.conn4 = newMDNSConn4()
+	m.conn4 = nil
+	m.conn6 = nil
 	return &m
 }
 
+func (m *mDNS) EnableIPv4() {
+	if m.conn4 == nil {
+		m.conn4 = newMDNSConn4()
+	}
+}
+
+func (m *mDNS) EnableIPv6() {
+	if m.conn6 == nil {
+		m.conn6 = newMDNSConn6()
+	}
+}
+
 func (m *mDNS) Close() {
-	m.conn4.Close()
+	if m.conn4 != nil {
+		m.conn4.Close()
+		m.conn4 = nil
+	}
+	if m.conn6 != nil {
+		m.conn6.Close()
+		m.conn6 = nil
+	}
 }
 
 func (m *mDNS) SetAddress(address string) error {
@@ -560,11 +673,27 @@
 	if ip == nil {
 		return fmt.Errorf("Not a valid IP address: %v", address)
 	}
-	return m.conn4.SetIp(ip)
+	if ip4 := ip.To4(); ip4 != nil {
+		if m.conn4 == nil {
+			return fmt.Errorf("mDNS IPv4 support is disabled")
+		}
+		return m.conn4.SetIp(ip4)
+	} else {
+		if m.conn6 == nil {
+			return fmt.Errorf("mDNS IPv6 support is disabled")
+		}
+		return m.conn6.SetIp(ip.To16())
+	}
 }
 
 func (m *mDNS) ipToSend() net.IP {
-	return m.conn4.getIp()
+	if m.conn4 != nil {
+		return m.conn4.getIp()
+	}
+	if m.conn6 != nil {
+		return m.conn6.getIp()
+	}
+	return nil
 }
 
 // AddHandler calls f on every Packet received.
@@ -594,7 +723,19 @@
 	if err := packet.serialize(&buf); err != nil {
 		return err
 	}
-	return m.conn4.SendTo(buf, dst)
+	if dst.IP.To4() != nil {
+		if m.conn4 != nil {
+			return m.conn4.SendTo(buf, dst)
+		} else {
+			return fmt.Errorf("IPv4 was not enabled!")
+		}
+	} else {
+		if m.conn6 != nil {
+			return m.conn6.SendTo(buf, dst)
+		} else {
+			return fmt.Errorf("IPv6 was not enabled!")
+		}
+	}
 }
 
 // Send serializes and sends packet out as a multicast to all interfaces
@@ -606,7 +747,18 @@
 	if err := packet.serialize(&buf); err != nil {
 		return err
 	}
-	return m.conn4.Send(buf)
+	var err4 error
+	if m.conn4 != nil {
+		err4 = m.conn4.Send(buf)
+	}
+	var err6 error
+	if m.conn6 != nil {
+		err6 = m.conn6.Send(buf)
+	}
+	if err4 != nil {
+		return err4
+	}
+	return err6
 }
 
 func makeUdpSocketWithReusePort(port int, ip net.IP, iface *net.Interface) (net.PacketConn, error) {
@@ -642,11 +794,55 @@
 	return listenConfig.ListenPacket(context.Background(), network, address)
 }
 
+// receivedPacket is a small struct used to send received UDP packets and
+// information about their interface / source address through a channel.
+type receivedPacket struct {
+	data  []byte
+	iface *net.Interface
+	src   net.Addr
+	err   error
+}
+
+// startListenLoop returns a channel of receivedPacket items, after starting
+// a goroutine that listens for packets on conn, and writes then to the channel
+// in a loop. The goroutine will stop when it cannot read anymore from conn,
+// which will happen in case of error (e.g. when the connection is closed).
+// In this case, it will send a final receivedPacket instance with the error
+// code only before exiting.
+func startListenLoop(conn mDNSConn) <-chan receivedPacket {
+	channel := make(chan receivedPacket, 1)
+	go func() {
+		payloadBuf := make([]byte, 1<<16)
+		for {
+			size, iface, src, err := conn.ReadFrom(payloadBuf)
+			if err != nil {
+				channel <- receivedPacket{err: err}
+				return
+			}
+			data := make([]byte, size)
+			copy(data, payloadBuf[:size])
+			channel <- receivedPacket{
+				data:  data,
+				iface: iface,
+				src:   src,
+				err:   nil}
+		}
+	}()
+	return channel
+}
+
 // Start causes m to start listening for MDNS packets on all interfaces on
 // the specified port. Listening will stop if ctx is done.
 func (m *mDNS) Start(ctx context.Context, port int) error {
-	if err := m.conn4.Listen(port); err != nil {
-		return err
+	if m.conn4 != nil {
+		if err := m.conn4.Listen(port); err != nil {
+			return err
+		}
+	}
+	if m.conn6 != nil {
+		if err := m.conn6.Listen(port); err != nil {
+			return err
+		}
 	}
 	// Now we need to join this connection to every interface that supports
 	// Multicast.
@@ -660,14 +856,24 @@
 		if iface.Flags&net.FlagMulticast == 0 || iface.Flags&net.FlagUp == 0 {
 			continue
 		}
-		if err := m.conn4.JoinGroup(iface); err != nil {
-			m.Close()
-			return fmt.Errorf("joining %v: %v", iface, err)
+		if m.conn4 != nil {
+			if err := m.conn4.JoinGroup(iface); err != nil {
+				m.Close()
+				return fmt.Errorf("joining %v: %v", iface, err)
+			}
+		}
+		if m.conn6 != nil {
+			if err := m.conn6.JoinGroup(iface); err != nil {
+				m.Close()
+				return fmt.Errorf("joining %v: %v", iface, err)
+			}
 		}
 		addrs, err := iface.Addrs()
 		if err != nil {
 			return fmt.Errorf("getting addresses of %v: %v", iface, err)
 		}
+		// When both IPv4 and IPv6 are enabled, only connect to the interface
+		// through its IPv6 address if it doesn't have an IPv4 one.
 		for _, addr := range addrs {
 			var ip net.IP
 			switch v := addr.(type) {
@@ -676,43 +882,67 @@
 			case *net.IPAddr:
 				ip = v.IP
 			}
-			if ip == nil || ip.To4() == nil {
+			if ip == nil {
 				continue
 			}
-			err := m.conn4.ConnectTo(port, ip.To4(), &iface)
-			if err != nil {
-				return fmt.Errorf("creating socket for %v via %v: %v", iface, ip, err)
+			ip4 := ip.To4()
+			if m.conn4 != nil && ip4 != nil {
+				if err := m.conn4.ConnectTo(port, ip4, &iface); err != nil {
+					return fmt.Errorf("creating socket for %v via %v: %v", iface, ip, err)
+				}
+				break
 			}
-			break
+			if m.conn6 != nil && ip4 == nil {
+				if err := m.conn6.ConnectTo(port, ip, &iface); err != nil {
+					return fmt.Errorf("creating socket for %v via %v: %v", iface, ip, err)
+				}
+				break
+			}
 		}
 	}
 	go func() {
+		// NOTE: This defer statement will close connections, which will force
+		// the goroutines started by startListenLoop() to exit.
 		defer m.Close()
-		// Now that we've joined every possible interface we can handle the main loop.
-		payloadBuf := make([]byte, 1<<16)
+
+		var chan4 <-chan receivedPacket
+		var chan6 <-chan receivedPacket
+
+		if m.conn4 != nil {
+			chan4 = startListenLoop(m.conn4)
+		}
+		if m.conn6 != nil {
+			chan6 = startListenLoop(m.conn6)
+		}
 		for {
+			var received receivedPacket
+
 			select {
 			case <-ctx.Done():
 				return
-			default:
+			case received = <-chan4:
+				break
+			case received = <-chan6:
+				break
 			}
-			size, iface, src, err := m.conn4.ReadFrom(payloadBuf)
-			if err != nil {
+
+			if received.err != nil {
 				for _, e := range m.eHandlers {
-					go e(err)
+					go e(received.err)
 				}
 				return
 			}
+
 			var packet Packet
-			data := payloadBuf[:size]
-			if err := packet.deserialize(data, bytes.NewBuffer(data)); err != nil {
+			if err := packet.deserialize(received.data, bytes.NewBuffer(received.data)); err != nil {
 				for _, w := range m.wHandlers {
-					go w(src, err)
+					go w(received.src, err)
 				}
 				continue
 			}
+
 			for _, p := range m.pHandlers {
-				go p(*iface, src, packet)
+				go p(*received.iface, received.src, packet)
 			}
 		}
 	}()
@@ -723,10 +953,16 @@
 // requests the ip address associated with domain.
 func QuestionPacket(domain string) Packet {
 	return Packet{
-		Header: Header{QDCount: 1},
+		Header: Header{QDCount: 2},
 		Questions: []Question{
 			Question{
 				Domain:  domain,
+				Type:    AAAA,
+				Class:   IN,
+				Unicast: false,
+			},
+			Question{
+				Domain:  domain,
 				Type:    A,
 				Class:   IN,
 				Unicast: false,
@@ -743,7 +979,7 @@
 		Answers: []Record{
 			Record{
 				Domain: domain,
-				Type:   A,
+				Type:   IpToDnsRecordType(ip),
 				Class:  IN,
 				Flush:  false,
 				Data:   []byte(ip),
diff --git a/mdns/mdns_test.go b/mdns/mdns_test.go
index 168007f..fec008a 100644
--- a/mdns/mdns_test.go
+++ b/mdns/mdns_test.go
@@ -96,8 +96,20 @@
 	}
 }
 
+func TestIpToDnsRecordType(t *testing.T) {
+	ip1 := "224.0.0.251"
+	if addrType := IpToDnsRecordType(net.ParseIP(ip1)); addrType != A {
+		t.Errorf("IpToDnsRecordType(%s) mismatch %v, wanted %v", ip1, addrType, A)
+	}
+	ip2 := "ff2e::fb"
+	if addrType := IpToDnsRecordType(net.ParseIP(ip2)); addrType != AAAA {
+		t.Errorf("IpToDnsRecordType(%s) mismatch %v, wanted %v", ip2, addrType, AAAA)
+	}
+}
+
 func TestSetAddress(t *testing.T) {
 	m := NewMDNS()
+	m.EnableIPv4()
 	got := m.ipToSend()
 	// Should send to the default address.
 	want := net.ParseIP("224.0.0.251")
@@ -105,11 +117,33 @@
 		t.Errorf("ipToSend (default): mismatch (-want +got)\n%s", d)
 	}
 
-	m.SetAddress("11.22.33.44")
+	if err := m.SetAddress("11.22.33.44"); err != nil {
+		t.Errorf("SetAddress() returned error: %s", err)
+	} else {
+		got = m.ipToSend()
+		// Should send to the given custom address.
+		want = net.ParseIP("11.22.33.44")
+		if d := cmp.Diff(want, got); d != "" {
+			t.Errorf("ipToSend (custom): mismatch (-want +got)\n%s", d)
+		}
+	}
+
+	m = NewMDNS()
+	m.EnableIPv6()
 	got = m.ipToSend()
-	// Should send to the given custom address.
-	want = net.ParseIP("11.22.33.44")
+	want = net.ParseIP("ff02::fb")
 	if d := cmp.Diff(want, got); d != "" {
-		t.Errorf("ipToSend (custom): mismatch (-want +got)\n%s", d)
+		t.Errorf("ipToSend (default): mismatch (-want +got)\n%s", d)
+	}
+
+	if err := m.SetAddress("11:22::33:44"); err != nil {
+		t.Errorf("SetAddress() returned error: %s", err)
+	} else {
+		got = m.ipToSend()
+		// Should send to the given custom address.
+		want = net.ParseIP("11:22::33:44")
+		if d := cmp.Diff(want, got); d != "" {
+			t.Errorf("ipToSend (custom): mismatch (-want +got)\n%s", d)
+		}
 	}
 }