[mdns] add support for custom addresses to the mdns library

Bug: http://fxb/DX-1453
Change-Id: I8f176519b93833758e299b6faa4e9498c5a1ef66
diff --git a/mdns/mdns.go b/mdns/mdns.go
index 61baced..0f7f9a2 100644
--- a/mdns/mdns.go
+++ b/mdns/mdns.go
@@ -396,12 +396,22 @@
 type MDNS struct {
 	conn      *ipv4.PacketConn
 	senders   []net.PacketConn
+	Address   string
 	port      int
 	pHandlers []func(net.Interface, net.Addr, Packet)
 	wHandlers []func(net.Addr, error)
 	eHandlers []func(error)
 }
 
+var defaultMDNSMulticastIPv4 = net.ParseIP("224.0.0.251")
+
+func (m *MDNS) ipToSend() net.IP {
+	if m.Address == "" {
+		return defaultMDNSMulticastIPv4
+	}
+	return net.ParseIP(m.Address)
+}
+
 // AddHandler calls f on every Packet received.
 func (m *MDNS) AddHandler(f func(net.Interface, net.Addr, Packet)) {
 	m.pHandlers = append(m.pHandlers, f)
@@ -437,13 +447,11 @@
 	return nil
 }
 
-var mdnsMulticastIPv4 = net.ParseIP("224.0.0.251")
-
 // Send serializes and sends packet out as a multicast to all interfaces
 // using the port that m is listening on. Note that Start must be
 // called prior to making this call.
 func (m *MDNS) Send(packet Packet) error {
-	dst := net.UDPAddr{IP: mdnsMulticastIPv4, Port: m.port}
+	dst := net.UDPAddr{IP: m.ipToSend(), Port: m.port}
 	return m.SendTo(packet, &dst)
 }
 
@@ -482,7 +490,7 @@
 // Start causes m to start listening for MDNS packets on all interfaces on
 // the specified port. Listening will stop if ctx is done.
 func (m *MDNS) Start(ctx context.Context, port int) error {
-	dst := &net.UDPAddr{IP: mdnsMulticastIPv4, Port: port}
+	dst := &net.UDPAddr{IP: m.ipToSend(), Port: port}
 	conn, err := net.ListenUDP("udp4", dst)
 	if err != nil {
 		return err
diff --git a/mdns/mdns_test.go b/mdns/mdns_test.go
index 29637da..18d3c62 100644
--- a/mdns/mdns_test.go
+++ b/mdns/mdns_test.go
@@ -6,7 +6,10 @@
 
 import (
 	"bytes"
+	"net"
 	"testing"
+
+	"github.com/google/go-cmp/cmp"
 )
 
 func TestUint16(t *testing.T) {
@@ -15,8 +18,8 @@
 	writeUint16(&buf, v)
 	var v2 uint16
 	readUint16(&buf, &v2)
-	if v != v2 {
-		t.Errorf("read/writeUint16 mismatch: wrote %v, read %v", v, v2)
+	if d := cmp.Diff(v, v2); d != "" {
+		t.Errorf("read/writeUint16: mismatch (-wrote +read)\n%s", d)
 	}
 }
 
@@ -26,8 +29,8 @@
 	writeUint32(&buf, v)
 	var v2 uint32
 	readUint32(&buf, &v2)
-	if v != v2 {
-		t.Errorf("read/writeUint32 mismatch: wrote %v, read %v", v, v2)
+	if d := cmp.Diff(v, v2); d != "" {
+		t.Errorf("read/writeUint32: mismatch (-wrote +read)\n%s", d)
 	}
 }
 
@@ -44,8 +47,8 @@
 	v.serialize(&buf)
 	var v2 Header
 	v2.deserialize(buf.Bytes(), &buf)
-	if v != v2 {
-		t.Errorf("header (de)serialize mismatch: wrote %v, read %v", v, v2)
+	if d := cmp.Diff(v, v2); d != "" {
+		t.Errorf("header (de)serialize: mismatch (-serialize +deserialize)\n%s", d)
 	}
 }
 
@@ -55,8 +58,8 @@
 	writeDomain(&buf, v)
 	var v2 string
 	readDomain(buf.Bytes(), &buf, &v2)
-	if v != v2 {
-		t.Errorf("read/writeDomain mismatch: wrote %v, read %v", v, v2)
+	if d := cmp.Diff(v, v2); d != "" {
+		t.Errorf("read/writeDomain: mismatch (-wrote +read)\n%s", d)
 	}
 }
 
@@ -70,23 +73,11 @@
 	v.serialize(&buf)
 	var v2 Question
 	v2.deserialize(buf.Bytes(), &buf)
-	if v != v2 {
-		t.Errorf("question (de)serialize mismatch: wrote %v, read %v", v, v2)
+	if d := cmp.Diff(v, v2); d != "" {
+		t.Errorf("question (de)serialize: mismatch (-serialize +deserialize)\n%s", d)
 	}
 }
 
-func equalBytes(a, b []byte) bool {
-	if len(a) != len(b) {
-		return false
-	}
-	for i, ai := range a {
-		if ai != b[i] {
-			return false
-		}
-	}
-	return true
-}
-
 func TestRecord(t *testing.T) {
 	var buf bytes.Buffer
 	v := Record{
@@ -100,22 +91,25 @@
 	v.serialize(&buf)
 	var v2 Record
 	v2.deserialize(buf.Bytes(), &buf)
-	if v.Domain != v2.Domain {
-		t.Errorf("record (de)serialize mismatch (domain): wrote %v, read %v", v.Domain, v2.Domain)
+	if d := cmp.Diff(v, v2); d != "" {
+		t.Errorf("record (de)serialize: mismatch (-serialize +deserialize)\n%s", d)
 	}
-	if v.Type != v2.Type {
-		t.Errorf("record (de)serialize mismatch (type): wrote %v, read %v", v.Type, v2.Type)
+}
+
+func TestIPToSend(t *testing.T) {
+	m := MDNS{}
+	got := m.ipToSend()
+	// Should send to the default address.
+	want := net.ParseIP("224.0.0.251")
+	if d := cmp.Diff(want, got); d != "" {
+		t.Errorf("ipToSend (default): mismatch (-want +got)\n%s", d)
 	}
-	if v.Class != v2.Class {
-		t.Errorf("record (de)serialize mismatch (class): wrote %v, read %v", v.Class, v2.Class)
-	}
-	if v.Flush != v2.Flush {
-		t.Errorf("record (de)serialize mismatch (flush): wrote %v, read %v", v.Flush, v2.Flush)
-	}
-	if v.TTL != v2.TTL {
-		t.Errorf("record (de)serialize mismatch (ttl): wrote %v, read %v", v.TTL, v2.TTL)
-	}
-	if !equalBytes(v.Data, v2.Data) {
-		t.Errorf("record (de)serialize mismatch (data): wrote %v, read %v", v.Data, v2.Data)
+
+	m = MDNS{Address: "11.22.33.44"}
+	got = m.ipToSend()
+	// Should send to the given custom address.
+	want = net.ParseIP("11.22.33.44")
+	if d := cmp.Diff(want, got); d != "" {
+		t.Errorf("ipToSend (custom): mismatch (-want +got)\n%s", d)
 	}
 }