Merge "Pull in upstream version of go/net/dns/dnsmessage package"
diff --git a/dns/cache.go b/dns/cache.go
index 5d60847..43670b6 100644
--- a/dns/cache.go
+++ b/dns/cache.go
@@ -9,7 +9,7 @@
"sync"
- "github.com/google/netstack/dns/message"
+ "github.com/google/netstack/dns/dnsmessage"
)
const (
@@ -22,14 +22,14 @@
// TODO(mpcomplete): I don't think we need the whole Resource. Depends on the questions
// we want to answer.
type cacheEntry struct {
- rr message.Resource // the resource
- ttd time.Time // when this entry expires
+ rr dnsmessage.Resource // the resource
+ ttd time.Time // when this entry expires
}
// Returns true if this entry is a CNAME that points at something no longer in our cache.
func (entry *cacheEntry) isDanglingCNAME(cache *cacheInfo) bool {
-	h := entry.rr.Header()
-	return h.Type == message.TypeCNAME && cache.m[h.Name] == nil
+	cname, ok := entry.rr.(*dnsmessage.CNAMEResource) // check the CNAME *target*, not the owner name this entry is stored under
+	return ok && len(cache.m[cname.CNAME]) == 0
}
// The full cache.
@@ -44,12 +44,12 @@
}
// Returns a list of Resources that match the given Question (same class and type and matching domain name).
-func (cache *cacheInfo) lookup(question *message.Question) []message.Resource {
+func (cache *cacheInfo) lookup(question *dnsmessage.Question) []dnsmessage.Resource {
cache.RLock()
entries := cache.m[question.Name]
cache.RUnlock()
- rrs := []message.Resource{}
+ rrs := []dnsmessage.Resource{}
for _, entry := range entries {
h := entry.rr.Header()
if h.Class == question.Class && h.Type == question.Type && h.Name == question.Name {
@@ -60,7 +60,7 @@
}
// Attempts to add a new entry into the cache. Can fail if the cache is full.
-func (cache *cacheInfo) insert(h *message.ResourceHeader, rr message.Resource) {
+func (cache *cacheInfo) insert(h *dnsmessage.ResourceHeader, rr dnsmessage.Resource) {
newEntry := cacheEntry{
ttd: testHookNow().Add(time.Duration(h.TTL) * time.Second),
rr: rr,
@@ -75,13 +75,13 @@
}
// Attempts to add each Resource as a new entry in the cache. Can fail if the cache is full.
-func (cache *cacheInfo) insertAll(rrs []message.Resource) {
+func (cache *cacheInfo) insertAll(rrs []dnsmessage.Resource) {
cache.prune()
for _, rr := range rrs {
h := rr.Header()
- if h.Class == message.ClassINET {
+ if h.Class == dnsmessage.ClassINET {
switch h.Type {
- case message.TypeA, message.TypeAAAA, message.TypeCNAME:
+ case dnsmessage.TypeA, dnsmessage.TypeAAAA, dnsmessage.TypeCNAME:
cache.insert(h, rr)
}
}
@@ -116,7 +116,7 @@
var cache = newCache()
func newCachedResolver(fallback Resolver) Resolver {
- return func(c *Client, question message.Question) (string, []message.Resource, *message.Message, error) {
+ return func(c *Client, question dnsmessage.Question) (string, []dnsmessage.Resource, *dnsmessage.Message, error) {
rrs := cache.lookup(&question)
if len(rrs) != 0 {
return "", rrs, nil, nil
diff --git a/dns/cache_test.go b/dns/cache_test.go
index 786d8a7..470a657 100644
--- a/dns/cache_test.go
+++ b/dns/cache_test.go
@@ -8,32 +8,32 @@
"testing"
"time"
- "github.com/google/netstack/dns/message"
+ "github.com/google/netstack/dns/dnsmessage"
)
-func makeResourceHeader(name string, ttl uint32) message.ResourceHeader {
- return message.ResourceHeader{
+func makeResourceHeader(name string, ttl uint32) dnsmessage.ResourceHeader {
+ return dnsmessage.ResourceHeader{
Name: name,
- Type: message.TypeA,
- Class: message.ClassINET,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
TTL: ttl,
}
}
-func makeQuestion(name string) *message.Question {
- return &message.Question{
+func makeQuestion(name string) *dnsmessage.Question {
+ return &dnsmessage.Question{
Name: name,
- Type: message.TypeA,
- Class: message.ClassINET,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
}
}
-var smallTestResources = []message.Resource{
- &message.AResource{
+var smallTestResources = []dnsmessage.Resource{
+ &dnsmessage.AResource{
ResourceHeader: makeResourceHeader("example.com.", 5),
A: [4]byte{127, 0, 0, 1},
},
- &message.AResource{
+ &dnsmessage.AResource{
ResourceHeader: makeResourceHeader("example.com.", 5),
A: [4]byte{127, 0, 0, 2},
},
@@ -90,8 +90,8 @@
testHookNow = func() time.Time { return testTime }
// One record that expires at 10 seconds.
- cache.insertAll([]message.Resource{
- &message.AResource{
+ cache.insertAll([]dnsmessage.Resource{
+ &dnsmessage.AResource{
ResourceHeader: makeResourceHeader("example.com.", 10),
A: [4]byte{127, 0, 0, 1},
},
@@ -99,8 +99,8 @@
// A bunch that expire at 5 seconds.
for i := 0; i < maxEntries; i++ {
- cache.insertAll([]message.Resource{
- &message.AResource{
+ cache.insertAll([]dnsmessage.Resource{
+ &dnsmessage.AResource{
ResourceHeader: makeResourceHeader("example.com.", 5),
A: [4]byte{127, byte(i << 16), byte(i << 8), byte(i)},
},
@@ -113,8 +113,8 @@
}
// Cache is at capacity. Can't insert anymore.
- cache.insertAll([]message.Resource{
- &message.AResource{
+ cache.insertAll([]dnsmessage.Resource{
+ &dnsmessage.AResource{
ResourceHeader: makeResourceHeader("foo.example.com.", 5),
A: [4]byte{192, 168, 0, 1},
},
@@ -126,8 +126,8 @@
// Advance the clock so the 5 second entries expire. Insert should succeed.
testTime = testTime.Add(6 * time.Second)
- cache.insertAll([]message.Resource{
- &message.AResource{
+ cache.insertAll([]dnsmessage.Resource{
+ &dnsmessage.AResource{
ResourceHeader: makeResourceHeader("foo.example.com.", 5),
A: [4]byte{192, 168, 0, 1},
},
diff --git a/dns/client.go b/dns/client.go
index ea60207..42fa03c 100644
--- a/dns/client.go
+++ b/dns/client.go
@@ -26,7 +26,7 @@
"sync"
"time"
- "github.com/google/netstack/dns/message"
+ "github.com/google/netstack/dns/dnsmessage"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/transport/tcp"
@@ -55,7 +55,7 @@
}
// A Resolver answers DNS Questions.
-type Resolver func(c *Client, question message.Question) (cname string, rrs []message.Resource, msg *message.Message, err error)
+type Resolver func(c *Client, question dnsmessage.Question) (cname string, rrs []dnsmessage.Resource, msg *dnsmessage.Message, err error)
// Error represents an error while issuing a DNS query for a hostname.
type Error struct {
@@ -83,7 +83,7 @@
// roundTrip writes the query to and reads the response from the Endpoint.
// The message format is slightly different depending on the transport protocol
// (for TCP, a 2 byte message length is prepended). See RFC 1035.
-func roundTrip(ctx context.Context, transport tcpip.TransportProtocolNumber, ep tcpip.Endpoint, wq *waiter.Queue, query *message.Message) (response *message.Message, err error) {
+func roundTrip(ctx context.Context, transport tcpip.TransportProtocolNumber, ep tcpip.Endpoint, wq *waiter.Queue, query *dnsmessage.Message) (response *dnsmessage.Message, err error) {
b, err := query.Pack()
if err != nil {
return nil, err
@@ -143,7 +143,7 @@
bcontents = b
}
- response = &message.Message{}
+ response = &dnsmessage.Message{}
if err := response.Unpack(bcontents); err != nil {
// Ignore invalid responses as they may be malicious
// forgery attempts. Instead continue waiting until
@@ -185,13 +185,13 @@
}
// exchange sends a query on the connection and hopes for a response.
-func (c *Client) exchange(server tcpip.FullAddress, name string, qtype message.Type, timeout time.Duration) (response *message.Message, err error) {
- query := message.Message{
- Header: message.Header{
+func (c *Client) exchange(server tcpip.FullAddress, name string, qtype dnsmessage.Type, timeout time.Duration) (response *dnsmessage.Message, err error) {
+ query := dnsmessage.Message{
+ Header: dnsmessage.Header{
RecursionDesired: true,
},
- Questions: []message.Question{
- {name, qtype, message.ClassINET},
+ Questions: []dnsmessage.Question{
+ {name, qtype, dnsmessage.ClassINET},
},
}
@@ -222,7 +222,7 @@
// Do a lookup for a single name, which must be rooted
// (otherwise answer will not find the answers).
-func (c *Client) tryOneName(cfg *dnsConfig, name string, qtype message.Type) (string, []message.Resource, *message.Message, error) {
+func (c *Client) tryOneName(cfg *dnsConfig, name string, qtype dnsmessage.Type) (string, []dnsmessage.Resource, *dnsmessage.Message, error) {
if len(cfg.servers) == 0 {
return "", nil, nil, &Error{Err: "no DNS servers", Name: name}
}
@@ -246,7 +246,7 @@
}
// libresolv continues to the next server when it receives
// an invalid referral response. See golang.org/issue/15434.
- if msg.RCode == message.RCodeSuccess && !msg.Authoritative && !msg.RecursionAvailable && len(msg.Answers) == 0 && len(msg.Additionals) == 0 {
+ if msg.RCode == dnsmessage.RCodeSuccess && !msg.Authoritative && !msg.RecursionAvailable && len(msg.Answers) == 0 && len(msg.Additionals) == 0 {
lastErr = &Error{Err: "lame referral", Name: name, Server: &server}
continue
}
@@ -255,7 +255,7 @@
// it means the response in msg was not useful and trying another
// server probably won't help. Return now in those cases.
// TODO: indicate this in a more obvious way, such as a field on Error?
- if err == nil || msg.RCode == message.RCodeSuccess || msg.RCode == message.RCodeNameError {
+ if err == nil || msg.RCode == dnsmessage.RCodeSuccess || msg.RCode == dnsmessage.RCodeNameError {
return cname, rrs, msg, err
}
lastErr = err
@@ -266,13 +266,13 @@
// addrRecordList converts and returns a list of IP addresses from DNS
// address records (both A and AAAA). Other record types are ignored.
-func addrRecordList(rrs []message.Resource) []tcpip.Address {
+func addrRecordList(rrs []dnsmessage.Resource) []tcpip.Address {
addrs := make([]tcpip.Address, 0, 4)
for _, rr := range rrs {
switch rr := rr.(type) {
- case *message.AResource:
+ case *dnsmessage.AResource:
addrs = append(addrs, tcpip.Address(rr.A[:]))
- case *message.AAAAResource:
+ case *dnsmessage.AAAAResource:
addrs = append(addrs, tcpip.Address(rr.AAAA[:]))
}
}
@@ -311,7 +311,7 @@
var resolvConf resolverConfig
func newNetworkResolver(config *dnsConfig) Resolver {
- return func(c *Client, question message.Question) (string, []message.Resource, *message.Message, error) {
+ return func(c *Client, question dnsmessage.Question) (string, []dnsmessage.Resource, *dnsmessage.Message, error) {
return c.tryOneName(config, question.Name, question.Type)
}
}
@@ -431,16 +431,16 @@
resolvConf.mu.RUnlock()
type racer struct {
fqdn string
- rrs []message.Resource
+ rrs []dnsmessage.Resource
error
}
lane := make(chan racer, 1)
- qtypes := [...]message.Type{message.TypeA, message.TypeAAAA}
+ qtypes := [...]dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
var lastErr error
for _, fqdn := range conf.nameList(name) {
for _, qtype := range qtypes {
- go func(qtype message.Type) {
- _, rrs, _, err := conf.resolver(c, message.Question{Name: fqdn, Type: qtype, Class: message.ClassINET})
+ go func(qtype dnsmessage.Type) {
+ _, rrs, _, err := conf.resolver(c, dnsmessage.Question{Name: fqdn, Type: qtype, Class: dnsmessage.ClassINET})
lane <- racer{fqdn, rrs, err}
}(qtype)
}
@@ -478,13 +478,13 @@
// for (name, qtype) from the response message msg, which
// is assumed to have come from server.
// It is exported mainly for use by registered helpers.
-func answer(name string, server tcpip.FullAddress, msg *message.Message, qtype message.Type) (cname string, addrs []message.Resource, err error) {
- addrs = make([]message.Resource, 0, len(msg.Answers))
+func answer(name string, server tcpip.FullAddress, msg *dnsmessage.Message, qtype dnsmessage.Type) (cname string, addrs []dnsmessage.Resource, err error) {
+ addrs = make([]dnsmessage.Resource, 0, len(msg.Answers))
- if msg.RCode == message.RCodeNameError {
+ if msg.RCode == dnsmessage.RCodeNameError {
return "", nil, &Error{Err: noSuchHost, Name: name, Server: &server}
}
- if msg.RCode != message.RCodeSuccess {
+ if msg.RCode != dnsmessage.RCodeSuccess {
// None of the error codes make sense
// for the query we sent. If we didn't get
// a name error and we didn't get success,
@@ -502,13 +502,13 @@
addrs = addrs[0:0]
for _, rr := range msg.Answers {
h := rr.Header()
- if h.Class == message.ClassINET && equalASCIILabel(h.Name, name) {
+ if h.Class == dnsmessage.ClassINET && equalASCIILabel(h.Name, name) {
switch h.Type {
case qtype:
addrs = append(addrs, rr)
- case message.TypeCNAME:
+ case dnsmessage.TypeCNAME:
// redirect to cname
- name = rr.(*message.CNAMEResource).CNAME
+ name = rr.(*dnsmessage.CNAMEResource).CNAME
continue Cname
}
}
diff --git a/dns/message/message.go b/dns/dnsmessage/message.go
similarity index 98%
rename from dns/message/message.go
rename to dns/dnsmessage/message.go
index 1af0cf4..da43b0b 100644
--- a/dns/message/message.go
+++ b/dns/dnsmessage/message.go
@@ -2,16 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// This is a fork of igudger@'s patch to go's net library, found at
-// https://go-review.googlesource.com/#/c/35237/
-// Can remove if/when the code is accepted into the net library.
-
-// Package message provides a mostly RFC 1035 compliant implementation of DNS
-// message packing and unpacking.
+// Package dnsmessage provides a mostly RFC 1035 compliant implementation of
+// DNS message packing and unpacking.
//
// This implementation is designed to minimize heap allocations and avoid
// unnecessary packing and unpacking as much as possible.
-package message
+package dnsmessage
import (
"errors"
@@ -366,6 +362,16 @@
}
func (p *Parser) skipResource(sec section) error {
+ if p.resHeaderValid {
+ newOff := p.off + int(p.resHeader.Length)
+ if newOff > len(p.msg) {
+ return errResourceLen
+ }
+ p.off = newOff
+ p.resHeaderValid = false
+ p.index++
+ return nil
+ }
if err := p.checkAdvance(sec); err != nil {
return err
}
diff --git a/dns/message/message_test.go b/dns/dnsmessage/message_test.go
similarity index 87%
rename from dns/message/message_test.go
rename to dns/dnsmessage/message_test.go
index cafa70d..46edd72 100644
--- a/dns/message/message_test.go
+++ b/dns/dnsmessage/message_test.go
@@ -2,11 +2,13 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package message
+package dnsmessage
import (
"fmt"
+ "net"
"reflect"
+ "strings"
"testing"
)
@@ -265,6 +267,115 @@
}
}
+// ExampleParser walks a packed Message with Parser to find one name's A/AAAA answers.
+func ExampleParser() {
+	msg := Message{
+		Header: Header{Response: true, Authoritative: true},
+		Questions: []Question{
+			{
+				Name:  "foo.bar.example.com.",
+				Type:  TypeA,
+				Class: ClassINET,
+			},
+			{
+				Name:  "bar.example.com.",
+				Type:  TypeA,
+				Class: ClassINET,
+			},
+		},
+		Answers: []Resource{
+			&AResource{
+				ResourceHeader: ResourceHeader{
+					Name:  "foo.bar.example.com.",
+					Type:  TypeA,
+					Class: ClassINET,
+				},
+				A: [4]byte{127, 0, 0, 1},
+			},
+			&AResource{
+				ResourceHeader: ResourceHeader{
+					Name:  "bar.example.com.",
+					Type:  TypeA,
+					Class: ClassINET,
+				},
+				A: [4]byte{127, 0, 0, 2},
+			},
+		},
+	}
+
+	buf, err := msg.Pack()
+	if err != nil {
+		panic(err)
+	}
+
+	wantName := "bar.example.com."
+
+	var p Parser
+	if _, err := p.Start(buf); err != nil {
+		panic(err)
+	}
+
+	for {
+		q, err := p.Question()
+		if err == ErrSectionDone {
+			break
+		}
+		if err != nil {
+			panic(err)
+		}
+
+		if q.Name != wantName {
+			continue
+		}
+
+		fmt.Println("Found question for name", wantName)
+		if err := p.SkipAllQuestions(); err != nil {
+			panic(err)
+		}
+		break
+	}
+
+	var gotIPs []net.IP
+	for {
+		h, err := p.AnswerHeader()
+		if err == ErrSectionDone {
+			break
+		}
+		if err != nil {
+			panic(err)
+		}
+
+		isAddr := (h.Type == TypeA || h.Type == TypeAAAA) && h.Class == ClassINET
+		if !isAddr || !strings.EqualFold(h.Name, wantName) {
+			// Every unwanted record must be consumed with SkipAnswer so the
+			// parser advances; a bare continue would re-read this header.
+			if err := p.SkipAnswer(); err != nil {
+				panic(err)
+			}
+			continue
+		}
+		a, err := p.Answer()
+		if err != nil {
+			panic(err)
+		}
+
+		switch r := a.(type) {
+		default:
+			panic(fmt.Sprintf("unknown type: %T", r))
+		case *AResource:
+			gotIPs = append(gotIPs, r.A[:])
+		case *AAAAResource:
+			gotIPs = append(gotIPs, r.AAAA[:])
+		}
+	}
+
+	fmt.Printf("Found A/AAAA records for name %s: %v\n", wantName, gotIPs)
+
+	// Output:
+	// Found question for name bar.example.com.
+	// Found A/AAAA records for name bar.example.com.: [127.0.0.2]
+}
+
func largeTestMsg() Message {
return Message{
Header: Header{Response: true, Authoritative: true},