[mdns][dev_finder] Add unicast receive support.

This allows devices in different subnets or behind port forwarding
to be found (in this case via port forwarding).

This is based off of the result occurring when querying for a device
running in QEMU due to port-forwarding (the destination address is no
longer marked as multicast after hostfwd translation).

See section 5.5

https://tools.ietf.org/html/rfc6762

Refactors connection interfaces for clarity.

Bug: 6537

Change-Id: I069572e340cb934b0bcc4df71a9d41e05ba34d05
diff --git a/net/dev_finder/cmd/common.go b/net/dev_finder/cmd/common.go
index 29f66ec..ac60e83 100644
--- a/net/dev_finder/cmd/common.go
+++ b/net/dev_finder/cmd/common.go
@@ -76,6 +76,10 @@
 	// established a connection to the Fuchsia device (rather than the address of the
 	// Fuchsia device on its own).
 	localResolve bool
+	// Determines whether to accept incoming unicast mDNS responses. This can happen if the
+	// receiving device is on a different subnet, or the receiving device's listener port
+	// has been forwarded to from a non-standard port.
+	acceptUnicast bool
 	// The limit of devices to discover. If this number of devices has been discovered before
 	// the timeout has been reached the program will exit successfully.
 	deviceLimit int
@@ -98,6 +102,7 @@
 	f.StringVar(&cmd.mdnsPorts, "port", "5353,5356", "Comma separated list of ports to issue mDNS queries to.")
 	f.IntVar(&cmd.timeout, "timeout", 2000, "The number of milliseconds before declaring a timeout.")
 	f.BoolVar(&cmd.localResolve, "local", false, "Returns the address of the interface to the host when doing service lookup/domain resolution.")
+	f.BoolVar(&cmd.acceptUnicast, "accept-unicast", false, "Accepts unicast responses. For if the receiving device responds from a different subnet or behind port forwarding.")
 	f.IntVar(&cmd.deviceLimit, "device-limit", 0, "Exits before the timeout at this many devices per resolution (zero means no limit).")
 }
 
@@ -133,6 +138,7 @@
 		m.EnableIPv6()
 	}
 	m.SetAddress(address)
+	m.SetAcceptUnicastResponses(cmd.acceptUnicast)
 	return m
 }
 
@@ -205,10 +211,14 @@
 		case err := <-errChan:
 			return nil, err
 		case device := <-devChan:
-			// Creates a hashable string to remove duplicate devices,
-			// as no two devices on this network should have the same
-			// IP and domain.
-			devices[fmt.Sprintf("%s|%s", string(device.addr), device.domain)] = device
+			// Creates a hashable string to remove duplicate devices.
+			//
+			// There should only be one of each domain on the network, but given how
+			// mcast is sent on each interface, multiple responses with different IP
+			// addresses can be returned for a single device in the case of a device
+			// running on the host in an emulator (this is a special case). Each
+			// IP would be point to the localhost in this case.
+			devices[fmt.Sprintf("%s", device.domain)] = device
 			if cmd.deviceLimit != 0 && len(devices) == cmd.deviceLimit {
 				return sortDeviceMap(devices), nil
 			}
diff --git a/net/mdns/mdns.go b/net/mdns/mdns.go
index bc1a353..c304d94 100644
--- a/net/mdns/mdns.go
+++ b/net/mdns/mdns.go
@@ -56,6 +56,15 @@
 	Additional []Record
 }
 
+// A small struct used to send received UDP packets and
+// information about their interface / source address through a channel.
+type receivedPacketInfo struct {
+	data  []byte
+	iface *net.Interface
+	src   net.Addr
+	err   error
+}
+
 func writeUint16(out io.Writer, val uint16) error {
 	buf := make([]byte, 2)
 	binary.BigEndian.PutUint16(buf, val)
@@ -413,6 +422,11 @@
 	// SetAddress sets a non-default listen address.
 	SetAddress(address string) error
 
+	// Sets whether to accept unicast responses. These can be received when
+	// the receiving device is in a different subnet, or if there is port
+	// forwarding occurring between the host and the device.
+	SetAcceptUnicastResponses(accept bool)
+
 	// ipToSend returns the IP corresponding to the current address.
 	ipToSend() net.IP
 
@@ -446,21 +460,132 @@
 	Close()
 }
 
+// Interface used to read packets into a caller-owned buffer.
+type packetReceiver interface {
+	// Reads a packet from the connection into |buf|.  On success, returns
+	// the packet size in bytes, the interface |iface| on which it was
+	// received, and the source address |src|.
+	//
+	// On error, returns a non-nil |err|.
+	ReadPacket(buf []byte) (size int, iface *net.Interface, src net.Addr, err error)
+	// Implements ipv4/ipv6 JoinGroup functionality, joining a group address
+	// on the interface |iface|.
+	JoinGroup(iface *net.Interface, group net.Addr) error
+	Close() error
+}
+
+type packetReceiver4 struct {
+	conn *ipv4.PacketConn
+}
+
+func (p *packetReceiver4) ReadPacket(buf []byte) (size int, iface *net.Interface, src net.Addr, err error) {
+	var cm *ipv4.ControlMessage
+	size, cm, src, err = p.conn.ReadFrom(buf)
+	if err != nil {
+		return
+	}
+	iface, err = net.InterfaceByIndex(cm.IfIndex)
+	return
+}
+
+func (p *packetReceiver4) JoinGroup(iface *net.Interface, group net.Addr) error {
+	return p.conn.JoinGroup(iface, group)
+}
+
+func (p *packetReceiver4) Close() error {
+	err := p.conn.Close()
+	p.conn = nil
+	return err
+}
+
+type packetReceiver6 struct {
+	conn *ipv6.PacketConn
+}
+
+func (p *packetReceiver6) ReadPacket(buf []byte) (size int, iface *net.Interface, src net.Addr, err error) {
+	var cm *ipv6.ControlMessage
+	size, cm, src, err = p.conn.ReadFrom(buf)
+	if err != nil {
+		return
+	}
+	iface, err = net.InterfaceByIndex(cm.IfIndex)
+	return
+}
+
+func (p *packetReceiver6) JoinGroup(iface *net.Interface, group net.Addr) error {
+	return p.conn.JoinGroup(iface, group)
+}
+
+func (p *packetReceiver6) Close() error {
+	err := p.conn.Close()
+	p.conn = nil
+	return err
+}
+
 type mDNSConn interface {
 	Close() error
 	SetIp(ip net.IP) error
 	getIp() net.IP
+	SetAcceptUnicastResponses(accept bool)
 	SendTo(buf bytes.Buffer, dst *net.UDPAddr) error
 	Send(buf bytes.Buffer) error
-	Listen(port int) error
+	InitReceiver(port int) error
+	// Starts a goroutine to listen for packets incoming on the multicast
+	// connection.
+	//
+	// If |SetAcceptUnicastResponses| has been set to true, starts a
+	// separate goroutine to listen to incoming unicast packets for each
+	// of the machine's interfaces.
+	//
+	// Returns a read-only channel on which received packets are written.
+	Listen() <-chan receivedPacketInfo
 	JoinGroup(iface net.Interface) error
+	// Connects to the specific port and interface.
+	//
+	// This is primarily for sending packets, but if
+	// |SetAcceptUnicastResponses(true)| has been called, then this
+	// connection becomes bidirectional, and is also used for reading
+	// packets when |Listen()| is later invoked.
+	//
+	// If |SetAcceptUnicastResponses| is called between multiple invocations
+	// of this function, then not all connections will be read from when
+	// |Listen| is called.
 	ConnectTo(port int, ip net.IP, iface *net.Interface) error
 	ReadFrom(buf []byte) (size int, iface *net.Interface, src net.Addr, err error)
+	NewPacketReceiver(net.PacketConn) packetReceiver
 }
 
 type mDNSConnBase struct {
-	dst     net.UDPAddr
-	senders []net.PacketConn
+	dst            net.UDPAddr
+	acceptUnicast  bool
+	receiver       packetReceiver
+	senders        []net.PacketConn
+	ucastReceivers []packetReceiver
+}
+
+func (c *mDNSConnBase) Close() error {
+	if c.receiver != nil {
+		err := c.receiver.Close()
+		c.receiver = nil
+		if err != nil {
+			return err
+		}
+	}
+	for _, receiver := range c.ucastReceivers {
+		if err := receiver.Close(); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (c *mDNSConnBase) Listen() <-chan receivedPacketInfo {
+	ch := make(chan receivedPacketInfo, 1)
+	startListenLoop(c.receiver, ch)
+	for _, receiver := range c.ucastReceivers {
+		startListenLoop(receiver, ch)
+	}
+	return ch
 }
 
 func (c *mDNSConnBase) getIp() net.IP {
@@ -476,17 +601,48 @@
 	return nil
 }
 
+func (c *mDNSConnBase) SetAcceptUnicastResponses(accept bool) {
+	c.acceptUnicast = accept
+}
+
 func (c *mDNSConnBase) Send(buf bytes.Buffer) error {
 	return c.SendTo(buf, &c.dst)
 }
 
 type mDNSConn4 struct {
 	mDNSConnBase
-	conn *ipv4.PacketConn
 }
 
 var defaultMDNSMulticastIPv4 = net.ParseIP("224.0.0.251")
 
+// Helper function for the |Listen| method in the |mDNSConn| interface.
+//
+// Accepts write-only channel of receivePacketInfo items.
+//
+// Starts a goroutine which will stop when it cannot read anymore from
+// |receiver|, which will happen in the case of an error (like the connection
+// being closed).  In this case it will send a final receivedPacketInfo instance
+// with only an error code before exiting.
+func startListenLoop(receiver packetReceiver, ch chan<- receivedPacketInfo) {
+	go func() {
+		payloadBuf := make([]byte, 1<<16)
+		for {
+			size, iface, src, err := receiver.ReadPacket(payloadBuf)
+			if err != nil {
+				ch <- receivedPacketInfo{err: err}
+				return
+			}
+			data := make([]byte, size)
+			copy(data, payloadBuf[:size])
+			ch <- receivedPacketInfo{
+				data:  data,
+				iface: iface,
+				src:   src,
+				err:   nil}
+		}
+	}()
+}
+
 func newMDNSConn4() mDNSConn {
 	c := mDNSConn4{}
 	c.SetIp(defaultMDNSMulticastIPv4)
@@ -501,30 +657,25 @@
 	return nil
 }
 
-func (c *mDNSConn4) Close() error {
-	if c.conn != nil {
-		err := c.conn.Close()
-		c.conn = nil
-		return err
-	}
-	return nil
-}
-
-func (c *mDNSConn4) Listen(port int) error {
+func (c *mDNSConn4) InitReceiver(port int) error {
 	c.dst.Port = port
 	conn, err := net.ListenUDP("udp4", &c.dst)
 	if err != nil {
 		return err
 	}
-	// Now we need a low level ipv4 packet connection.
-	c.conn = ipv4.NewPacketConn(conn)
-	c.conn.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
+	c.receiver = c.NewPacketReceiver(conn)
 	return nil
 }
 
+func (c *mDNSConn4) NewPacketReceiver(conn net.PacketConn) packetReceiver {
+	conn4 := ipv4.NewPacketConn(conn)
+	conn4.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
+	return &packetReceiver4{conn4}
+}
+
 // This allows us to listen on this specific interface.
 func (c *mDNSConn4) JoinGroup(iface net.Interface) error {
-	if err := c.conn.JoinGroup(&iface, &c.dst); err != nil {
+	if err := c.receiver.JoinGroup(&iface, &c.dst); err != nil {
 		c.Close()
 		return fmt.Errorf("joining %v%%%v: %v", iface, c.dst, err)
 	}
@@ -541,21 +692,18 @@
 		return err
 	}
 	c.senders = append(c.senders, conn)
+	if c.acceptUnicast {
+		c.ucastReceivers = append(c.ucastReceivers, c.NewPacketReceiver(conn))
+	}
 	return nil
 }
 
-func (c *mDNSConn4) ReadFrom(buf []byte) (size int, iface *net.Interface, src net.Addr, err error) {
-	var cm *ipv4.ControlMessage
-	size, cm, src, err = c.conn.ReadFrom(buf)
-	if err == nil {
-		iface, err = net.InterfaceByIndex(cm.IfIndex)
-	}
-	return
+func (c *mDNSConn4) ReadFrom(buf []byte) (int, *net.Interface, net.Addr, error) {
+	return c.receiver.ReadPacket(buf)
 }
 
 type mDNSConn6 struct {
 	mDNSConnBase
-	conn *ipv6.PacketConn
 }
 
 var defaultMDNSMulticastIPv6 = net.ParseIP("ff02::fb")
@@ -574,30 +722,25 @@
 	return nil
 }
 
-func (c *mDNSConn6) Close() error {
-	if c.conn != nil {
-		err := c.conn.Close()
-		c.conn = nil
-		return err
-	}
-	return nil
-}
-
-func (c *mDNSConn6) Listen(port int) error {
+func (c *mDNSConn6) InitReceiver(port int) error {
 	c.dst.Port = port
 	conn, err := net.ListenUDP("udp6", &c.dst)
 	if err != nil {
 		return err
 	}
-	// Now we need a low level ipv4 packet connection.
-	c.conn = ipv6.NewPacketConn(conn)
-	c.conn.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
+	c.receiver = c.NewPacketReceiver(conn)
 	return nil
 }
 
+func (c *mDNSConn6) NewPacketReceiver(conn net.PacketConn) packetReceiver {
+	conn6 := ipv6.NewPacketConn(conn)
+	conn6.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
+	return &packetReceiver6{conn6}
+}
+
 // This allows us to listen on this specific interface.
 func (c *mDNSConn6) JoinGroup(iface net.Interface) error {
-	if err := c.conn.JoinGroup(&iface, &c.dst); err != nil {
+	if err := c.receiver.JoinGroup(&iface, &c.dst); err != nil {
 		c.Close()
 		return fmt.Errorf("joining %v%%%v: %v", iface, c.dst, err)
 	}
@@ -614,16 +757,14 @@
 		return err
 	}
 	c.senders = append(c.senders, conn)
+	if c.acceptUnicast {
+		c.ucastReceivers = append(c.ucastReceivers, c.NewPacketReceiver(conn))
+	}
 	return nil
 }
 
-func (c *mDNSConn6) ReadFrom(buf []byte) (size int, iface *net.Interface, src net.Addr, err error) {
-	var cm *ipv6.ControlMessage
-	size, cm, src, err = c.conn.ReadFrom(buf)
-	if err == nil {
-		iface, err = net.InterfaceByIndex(cm.IfIndex)
-	}
-	return
+func (c *mDNSConn6) ReadFrom(buf []byte) (int, *net.Interface, net.Addr, error) {
+	return c.receiver.ReadPacket(buf)
 }
 
 type mDNS struct {
@@ -657,6 +798,15 @@
 	}
 }
 
+func (m *mDNS) SetAcceptUnicastResponses(accept bool) {
+	if m.conn4 != nil {
+		m.conn4.SetAcceptUnicastResponses(accept)
+	}
+	if m.conn6 != nil {
+		m.conn6.SetAcceptUnicastResponses(accept)
+	}
+}
+
 func (m *mDNS) Close() {
 	if m.conn4 != nil {
 		m.conn4.Close()
@@ -794,53 +944,16 @@
 	return listenConfig.ListenPacket(context.Background(), network, address)
 }
 
-// receivedPacket is a small struct used to send received UDP packets and
-// information about their interface / source address through a channel.
-type receivedPacket struct {
-	data  []byte
-	iface *net.Interface
-	src   net.Addr
-	err   error
-}
-
-// startListenLoop returns a channel of receivedPacket items, after starting
-// a goroutine that listens for packets on conn, and writes then to the channel
-// in a loop. The goroutine will stop when it cannot read anymore from conn,
-// which will happen in case of error (e.g. when the connection is closed).
-// In this case, it will send a final receivedPacket instance with the error
-// code only before exiting.
-func startListenLoop(conn mDNSConn) <-chan receivedPacket {
-	channel := make(chan receivedPacket, 1)
-	go func() {
-		payloadBuf := make([]byte, 1<<16)
-		for {
-			size, iface, src, err := conn.ReadFrom(payloadBuf)
-			if err != nil {
-				channel <- receivedPacket{err: err}
-				return
-			}
-			data := make([]byte, size)
-			copy(data, payloadBuf[:size])
-			channel <- receivedPacket{
-				data:  data,
-				iface: iface,
-				src:   src,
-				err:   nil}
-		}
-	}()
-	return channel
-}
-
 // Start causes m to start listening for MDNS packets on all interfaces on
 // the specified port. Listening will stop if ctx is done.
 func (m *mDNS) Start(ctx context.Context, port int) error {
 	if m.conn4 != nil {
-		if err := m.conn4.Listen(port); err != nil {
+		if err := m.conn4.InitReceiver(port); err != nil {
 			return err
 		}
 	}
 	if m.conn6 != nil {
-		if err := m.conn6.Listen(port); err != nil {
+		if err := m.conn6.InitReceiver(port); err != nil {
 			return err
 		}
 	}
@@ -905,17 +1018,17 @@
 		// the goroutines started by startListenLoop() to exit.
 		defer m.Close()
 
-		var chan4 <-chan receivedPacket
-		var chan6 <-chan receivedPacket
+		var chan4 <-chan receivedPacketInfo
+		var chan6 <-chan receivedPacketInfo
 
 		if m.conn4 != nil {
-			chan4 = startListenLoop(m.conn4)
+			chan4 = m.conn4.Listen()
 		}
 		if m.conn6 != nil {
-			chan6 = startListenLoop(m.conn6)
+			chan6 = m.conn6.Listen()
 		}
 		for {
-			var received receivedPacket
+			var received receivedPacketInfo
 
 			select {
 			case <-ctx.Done():
@@ -955,13 +1068,13 @@
 	return Packet{
 		Header: Header{QDCount: 2},
 		Questions: []Question{
-			Question{
+			{
 				Domain:  domain,
 				Type:    AAAA,
 				Class:   IN,
 				Unicast: false,
 			},
-			Question{
+			{
 				Domain:  domain,
 				Type:    A,
 				Class:   IN,
@@ -977,7 +1090,7 @@
 	return Packet{
 		Header: Header{ANCount: 1},
 		Answers: []Record{
-			Record{
+			{
 				Domain: domain,
 				Type:   IpToDnsRecordType(ip),
 				Class:  IN,