[mdns] add support for custom addresses to the mdns library
Bug: http://fxb/DX-1453
Change-Id: I8f176519b93833758e299b6faa4e9498c5a1ef66
diff --git a/mdns/mdns.go b/mdns/mdns.go
index 61baced..0f7f9a2 100644
--- a/mdns/mdns.go
+++ b/mdns/mdns.go
@@ -396,12 +396,22 @@
type MDNS struct {
conn *ipv4.PacketConn
senders []net.PacketConn
+ Address string
port int
pHandlers []func(net.Interface, net.Addr, Packet)
wHandlers []func(net.Addr, error)
eHandlers []func(error)
}
+var defaultMDNSMulticastIPv4 = net.ParseIP("224.0.0.251")
+
+func (m *MDNS) ipToSend() net.IP {
+ if m.Address == "" {
+ return defaultMDNSMulticastIPv4
+ }
+ return net.ParseIP(m.Address)
+}
+
// AddHandler calls f on every Packet received.
func (m *MDNS) AddHandler(f func(net.Interface, net.Addr, Packet)) {
m.pHandlers = append(m.pHandlers, f)
@@ -437,13 +447,11 @@
return nil
}
-var mdnsMulticastIPv4 = net.ParseIP("224.0.0.251")
-
// Send serializes and sends packet out as a multicast to all interfaces
// using the port that m is listening on. Note that Start must be
// called prior to making this call.
func (m *MDNS) Send(packet Packet) error {
- dst := net.UDPAddr{IP: mdnsMulticastIPv4, Port: m.port}
+ dst := net.UDPAddr{IP: m.ipToSend(), Port: m.port}
return m.SendTo(packet, &dst)
}
@@ -482,7 +490,7 @@
// Start causes m to start listening for MDNS packets on all interfaces on
// the specified port. Listening will stop if ctx is done.
func (m *MDNS) Start(ctx context.Context, port int) error {
- dst := &net.UDPAddr{IP: mdnsMulticastIPv4, Port: port}
+ dst := &net.UDPAddr{IP: m.ipToSend(), Port: port}
conn, err := net.ListenUDP("udp4", dst)
if err != nil {
return err
diff --git a/mdns/mdns_test.go b/mdns/mdns_test.go
index 29637da..18d3c62 100644
--- a/mdns/mdns_test.go
+++ b/mdns/mdns_test.go
@@ -6,7 +6,10 @@
import (
"bytes"
+ "net"
"testing"
+
+ "github.com/google/go-cmp/cmp"
)
func TestUint16(t *testing.T) {
@@ -15,8 +18,8 @@
writeUint16(&buf, v)
var v2 uint16
readUint16(&buf, &v2)
- if v != v2 {
- t.Errorf("read/writeUint16 mismatch: wrote %v, read %v", v, v2)
+ if d := cmp.Diff(v, v2); d != "" {
+ t.Errorf("read/writeUint16: mismatch (-wrote +read)\n%s", d)
}
}
@@ -26,8 +29,8 @@
writeUint32(&buf, v)
var v2 uint32
readUint32(&buf, &v2)
- if v != v2 {
- t.Errorf("read/writeUint32 mismatch: wrote %v, read %v", v, v2)
+ if d := cmp.Diff(v, v2); d != "" {
+ t.Errorf("read/writeUint32: mismatch (-wrote +read)\n%s", d)
}
}
@@ -44,8 +47,8 @@
v.serialize(&buf)
var v2 Header
v2.deserialize(buf.Bytes(), &buf)
- if v != v2 {
- t.Errorf("header (de)serialize mismatch: wrote %v, read %v", v, v2)
+ if d := cmp.Diff(v, v2); d != "" {
+ t.Errorf("header (de)serialize: mismatch (-serialize +deserialize)\n%s", d)
}
}
@@ -55,8 +58,8 @@
writeDomain(&buf, v)
var v2 string
readDomain(buf.Bytes(), &buf, &v2)
- if v != v2 {
- t.Errorf("read/writeDomain mismatch: wrote %v, read %v", v, v2)
+ if d := cmp.Diff(v, v2); d != "" {
+ t.Errorf("read/writeDomain: mismatch (-wrote +read)\n%s", d)
}
}
@@ -70,23 +73,11 @@
v.serialize(&buf)
var v2 Question
v2.deserialize(buf.Bytes(), &buf)
- if v != v2 {
- t.Errorf("question (de)serialize mismatch: wrote %v, read %v", v, v2)
+ if d := cmp.Diff(v, v2); d != "" {
+ t.Errorf("question (de)serialize: mismatch (-serialize +deserialize)\n%s", d)
}
}
-func equalBytes(a, b []byte) bool {
- if len(a) != len(b) {
- return false
- }
- for i, ai := range a {
- if ai != b[i] {
- return false
- }
- }
- return true
-}
-
func TestRecord(t *testing.T) {
var buf bytes.Buffer
v := Record{
@@ -100,22 +91,25 @@
v.serialize(&buf)
var v2 Record
v2.deserialize(buf.Bytes(), &buf)
- if v.Domain != v2.Domain {
- t.Errorf("record (de)serialize mismatch (domain): wrote %v, read %v", v.Domain, v2.Domain)
+ if d := cmp.Diff(v, v2); d != "" {
+ t.Errorf("record (de)serialize: mismatch (-serialize +deserialize)\n%s", d)
}
- if v.Type != v2.Type {
- t.Errorf("record (de)serialize mismatch (type): wrote %v, read %v", v.Type, v2.Type)
+}
+
+func TestIPToSend(t *testing.T) {
+ m := MDNS{}
+ got := m.ipToSend()
+ // Should send to the default address.
+ want := net.ParseIP("224.0.0.251")
+ if d := cmp.Diff(want, got); d != "" {
+ t.Errorf("ipToSend (default): mismatch (-want +got)\n%s", d)
}
- if v.Class != v2.Class {
- t.Errorf("record (de)serialize mismatch (class): wrote %v, read %v", v.Class, v2.Class)
- }
- if v.Flush != v2.Flush {
- t.Errorf("record (de)serialize mismatch (flush): wrote %v, read %v", v.Flush, v2.Flush)
- }
- if v.TTL != v2.TTL {
- t.Errorf("record (de)serialize mismatch (ttl): wrote %v, read %v", v.TTL, v2.TTL)
- }
- if !equalBytes(v.Data, v2.Data) {
- t.Errorf("record (de)serialize mismatch (data): wrote %v, read %v", v.Data, v2.Data)
+
+ m = MDNS{Address: "11.22.33.44"}
+ got = m.ipToSend()
+ // Should send to the given custom address.
+ want = net.ParseIP("11.22.33.44")
+ if d := cmp.Diff(want, got); d != "" {
+ t.Errorf("ipToSend (custom): mismatch (-want +got)\n%s", d)
}
}