Merge "Pull in upstream version of go/net/dns/dnsmessage package"
diff --git a/dns/cache.go b/dns/cache.go
index 5d60847..43670b6 100644
--- a/dns/cache.go
+++ b/dns/cache.go
@@ -9,7 +9,7 @@
 
 	"sync"
 
-	"github.com/google/netstack/dns/message"
+	"github.com/google/netstack/dns/dnsmessage"
 )
 
 const (
@@ -22,14 +22,14 @@
 // TODO(mpcomplete): I don't think we need the whole Resource. Depends on the questions
 // we want to answer.
 type cacheEntry struct {
-	rr  message.Resource // the resource
-	ttd time.Time        // when this entry expires
+	rr  dnsmessage.Resource // the resource
+	ttd time.Time           // when this entry expires
 }
 
 // Returns true if this entry is a CNAME that points at something no longer in our cache.
 func (entry *cacheEntry) isDanglingCNAME(cache *cacheInfo) bool {
 	h := entry.rr.Header()
-	return h.Type == message.TypeCNAME && cache.m[h.Name] == nil
+	return h.Type == dnsmessage.TypeCNAME && cache.m[h.Name] == nil
 }
 
 // The full cache.
@@ -44,12 +44,12 @@
 }
 
 // Returns a list of Resources that match the given Question (same class and type and matching domain name).
-func (cache *cacheInfo) lookup(question *message.Question) []message.Resource {
+func (cache *cacheInfo) lookup(question *dnsmessage.Question) []dnsmessage.Resource {
 	cache.RLock()
 	entries := cache.m[question.Name]
 	cache.RUnlock()
 
-	rrs := []message.Resource{}
+	rrs := []dnsmessage.Resource{}
 	for _, entry := range entries {
 		h := entry.rr.Header()
 		if h.Class == question.Class && h.Type == question.Type && h.Name == question.Name {
@@ -60,7 +60,7 @@
 }
 
 // Attempts to add a new entry into the cache. Can fail if the cache is full.
-func (cache *cacheInfo) insert(h *message.ResourceHeader, rr message.Resource) {
+func (cache *cacheInfo) insert(h *dnsmessage.ResourceHeader, rr dnsmessage.Resource) {
 	newEntry := cacheEntry{
 		ttd: testHookNow().Add(time.Duration(h.TTL) * time.Second),
 		rr:  rr,
@@ -75,13 +75,13 @@
 }
 
 // Attempts to add each Resource as a new entry in the cache. Can fail if the cache is full.
-func (cache *cacheInfo) insertAll(rrs []message.Resource) {
+func (cache *cacheInfo) insertAll(rrs []dnsmessage.Resource) {
 	cache.prune()
 	for _, rr := range rrs {
 		h := rr.Header()
-		if h.Class == message.ClassINET {
+		if h.Class == dnsmessage.ClassINET {
 			switch h.Type {
-			case message.TypeA, message.TypeAAAA, message.TypeCNAME:
+			case dnsmessage.TypeA, dnsmessage.TypeAAAA, dnsmessage.TypeCNAME:
 				cache.insert(h, rr)
 			}
 		}
@@ -116,7 +116,7 @@
 var cache = newCache()
 
 func newCachedResolver(fallback Resolver) Resolver {
-	return func(c *Client, question message.Question) (string, []message.Resource, *message.Message, error) {
+	return func(c *Client, question dnsmessage.Question) (string, []dnsmessage.Resource, *dnsmessage.Message, error) {
 		rrs := cache.lookup(&question)
 		if len(rrs) != 0 {
 			return "", rrs, nil, nil
diff --git a/dns/cache_test.go b/dns/cache_test.go
index 786d8a7..470a657 100644
--- a/dns/cache_test.go
+++ b/dns/cache_test.go
@@ -8,32 +8,32 @@
 	"testing"
 	"time"
 
-	"github.com/google/netstack/dns/message"
+	"github.com/google/netstack/dns/dnsmessage"
 )
 
-func makeResourceHeader(name string, ttl uint32) message.ResourceHeader {
-	return message.ResourceHeader{
+func makeResourceHeader(name string, ttl uint32) dnsmessage.ResourceHeader {
+	return dnsmessage.ResourceHeader{
 		Name:  name,
-		Type:  message.TypeA,
-		Class: message.ClassINET,
+		Type:  dnsmessage.TypeA,
+		Class: dnsmessage.ClassINET,
 		TTL:   ttl,
 	}
 }
 
-func makeQuestion(name string) *message.Question {
-	return &message.Question{
+func makeQuestion(name string) *dnsmessage.Question {
+	return &dnsmessage.Question{
 		Name:  name,
-		Type:  message.TypeA,
-		Class: message.ClassINET,
+		Type:  dnsmessage.TypeA,
+		Class: dnsmessage.ClassINET,
 	}
 }
 
-var smallTestResources = []message.Resource{
-	&message.AResource{
+var smallTestResources = []dnsmessage.Resource{
+	&dnsmessage.AResource{
 		ResourceHeader: makeResourceHeader("example.com.", 5),
 		A:              [4]byte{127, 0, 0, 1},
 	},
-	&message.AResource{
+	&dnsmessage.AResource{
 		ResourceHeader: makeResourceHeader("example.com.", 5),
 		A:              [4]byte{127, 0, 0, 2},
 	},
@@ -90,8 +90,8 @@
 	testHookNow = func() time.Time { return testTime }
 
 	// One record that expires at 10 seconds.
-	cache.insertAll([]message.Resource{
-		&message.AResource{
+	cache.insertAll([]dnsmessage.Resource{
+		&dnsmessage.AResource{
 			ResourceHeader: makeResourceHeader("example.com.", 10),
 			A:              [4]byte{127, 0, 0, 1},
 		},
@@ -99,8 +99,8 @@
 
 	// A bunch that expire at 5 seconds.
 	for i := 0; i < maxEntries; i++ {
-		cache.insertAll([]message.Resource{
-			&message.AResource{
+		cache.insertAll([]dnsmessage.Resource{
+			&dnsmessage.AResource{
 				ResourceHeader: makeResourceHeader("example.com.", 5),
 				A:              [4]byte{127, byte(i << 16), byte(i << 8), byte(i)},
 			},
@@ -113,8 +113,8 @@
 	}
 
 	// Cache is at capacity. Can't insert anymore.
-	cache.insertAll([]message.Resource{
-		&message.AResource{
+	cache.insertAll([]dnsmessage.Resource{
+		&dnsmessage.AResource{
 			ResourceHeader: makeResourceHeader("foo.example.com.", 5),
 			A:              [4]byte{192, 168, 0, 1},
 		},
@@ -126,8 +126,8 @@
 
 	// Advance the clock so the 5 second entries expire. Insert should succeed.
 	testTime = testTime.Add(6 * time.Second)
-	cache.insertAll([]message.Resource{
-		&message.AResource{
+	cache.insertAll([]dnsmessage.Resource{
+		&dnsmessage.AResource{
 			ResourceHeader: makeResourceHeader("foo.example.com.", 5),
 			A:              [4]byte{192, 168, 0, 1},
 		},
diff --git a/dns/client.go b/dns/client.go
index ea60207..42fa03c 100644
--- a/dns/client.go
+++ b/dns/client.go
@@ -26,7 +26,7 @@
 	"sync"
 	"time"
 
-	"github.com/google/netstack/dns/message"
+	"github.com/google/netstack/dns/dnsmessage"
 	"github.com/google/netstack/tcpip"
 	"github.com/google/netstack/tcpip/network/ipv4"
 	"github.com/google/netstack/tcpip/transport/tcp"
@@ -55,7 +55,7 @@
 }
 
 // A Resolver answers DNS Questions.
-type Resolver func(c *Client, question message.Question) (cname string, rrs []message.Resource, msg *message.Message, err error)
+type Resolver func(c *Client, question dnsmessage.Question) (cname string, rrs []dnsmessage.Resource, msg *dnsmessage.Message, err error)
 
 // Error represents an error while issuing a DNS query for a hostname.
 type Error struct {
@@ -83,7 +83,7 @@
 // roundTrip writes the query to and reads the response from the Endpoint.
 // The message format is slightly different depending on the transport protocol
 // (for TCP, a 2 byte message length is prepended). See RFC 1035.
-func roundTrip(ctx context.Context, transport tcpip.TransportProtocolNumber, ep tcpip.Endpoint, wq *waiter.Queue, query *message.Message) (response *message.Message, err error) {
+func roundTrip(ctx context.Context, transport tcpip.TransportProtocolNumber, ep tcpip.Endpoint, wq *waiter.Queue, query *dnsmessage.Message) (response *dnsmessage.Message, err error) {
 	b, err := query.Pack()
 	if err != nil {
 		return nil, err
@@ -143,7 +143,7 @@
 			bcontents = b
 		}
 
-		response = &message.Message{}
+		response = &dnsmessage.Message{}
 		if err := response.Unpack(bcontents); err != nil {
 			// Ignore invalid responses as they may be malicious
 			// forgery attempts. Instead continue waiting until
@@ -185,13 +185,13 @@
 }
 
 // exchange sends a query on the connection and hopes for a response.
-func (c *Client) exchange(server tcpip.FullAddress, name string, qtype message.Type, timeout time.Duration) (response *message.Message, err error) {
-	query := message.Message{
-		Header: message.Header{
+func (c *Client) exchange(server tcpip.FullAddress, name string, qtype dnsmessage.Type, timeout time.Duration) (response *dnsmessage.Message, err error) {
+	query := dnsmessage.Message{
+		Header: dnsmessage.Header{
 			RecursionDesired: true,
 		},
-		Questions: []message.Question{
-			{name, qtype, message.ClassINET},
+		Questions: []dnsmessage.Question{
+			{name, qtype, dnsmessage.ClassINET},
 		},
 	}
 
@@ -222,7 +222,7 @@
 
 // Do a lookup for a single name, which must be rooted
 // (otherwise answer will not find the answers).
-func (c *Client) tryOneName(cfg *dnsConfig, name string, qtype message.Type) (string, []message.Resource, *message.Message, error) {
+func (c *Client) tryOneName(cfg *dnsConfig, name string, qtype dnsmessage.Type) (string, []dnsmessage.Resource, *dnsmessage.Message, error) {
 	if len(cfg.servers) == 0 {
 		return "", nil, nil, &Error{Err: "no DNS servers", Name: name}
 	}
@@ -246,7 +246,7 @@
 			}
 			// libresolv continues to the next server when it receives
 			// an invalid referral response. See golang.org/issue/15434.
-			if msg.RCode == message.RCodeSuccess && !msg.Authoritative && !msg.RecursionAvailable && len(msg.Answers) == 0 && len(msg.Additionals) == 0 {
+			if msg.RCode == dnsmessage.RCodeSuccess && !msg.Authoritative && !msg.RecursionAvailable && len(msg.Answers) == 0 && len(msg.Additionals) == 0 {
 				lastErr = &Error{Err: "lame referral", Name: name, Server: &server}
 				continue
 			}
@@ -255,7 +255,7 @@
 			// it means the response in msg was not useful and trying another
 			// server probably won't help. Return now in those cases.
 			// TODO: indicate this in a more obvious way, such as a field on Error?
-			if err == nil || msg.RCode == message.RCodeSuccess || msg.RCode == message.RCodeNameError {
+			if err == nil || msg.RCode == dnsmessage.RCodeSuccess || msg.RCode == dnsmessage.RCodeNameError {
 				return cname, rrs, msg, err
 			}
 			lastErr = err
@@ -266,13 +266,13 @@
 
 // addrRecordList converts and returns a list of IP addresses from DNS
 // address records (both A and AAAA). Other record types are ignored.
-func addrRecordList(rrs []message.Resource) []tcpip.Address {
+func addrRecordList(rrs []dnsmessage.Resource) []tcpip.Address {
 	addrs := make([]tcpip.Address, 0, 4)
 	for _, rr := range rrs {
 		switch rr := rr.(type) {
-		case *message.AResource:
+		case *dnsmessage.AResource:
 			addrs = append(addrs, tcpip.Address(rr.A[:]))
-		case *message.AAAAResource:
+		case *dnsmessage.AAAAResource:
 			addrs = append(addrs, tcpip.Address(rr.AAAA[:]))
 		}
 	}
@@ -311,7 +311,7 @@
 var resolvConf resolverConfig
 
 func newNetworkResolver(config *dnsConfig) Resolver {
-	return func(c *Client, question message.Question) (string, []message.Resource, *message.Message, error) {
+	return func(c *Client, question dnsmessage.Question) (string, []dnsmessage.Resource, *dnsmessage.Message, error) {
 		return c.tryOneName(config, question.Name, question.Type)
 	}
 }
@@ -431,16 +431,16 @@
 	resolvConf.mu.RUnlock()
 	type racer struct {
 		fqdn string
-		rrs  []message.Resource
+		rrs  []dnsmessage.Resource
 		error
 	}
 	lane := make(chan racer, 1)
-	qtypes := [...]message.Type{message.TypeA, message.TypeAAAA}
+	qtypes := [...]dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
 	var lastErr error
 	for _, fqdn := range conf.nameList(name) {
 		for _, qtype := range qtypes {
-			go func(qtype message.Type) {
-				_, rrs, _, err := conf.resolver(c, message.Question{Name: fqdn, Type: qtype, Class: message.ClassINET})
+			go func(qtype dnsmessage.Type) {
+				_, rrs, _, err := conf.resolver(c, dnsmessage.Question{Name: fqdn, Type: qtype, Class: dnsmessage.ClassINET})
 				lane <- racer{fqdn, rrs, err}
 			}(qtype)
 		}
@@ -478,13 +478,13 @@
 // for (name, qtype) from the response message msg, which
 // is assumed to have come from server.
 // It is exported mainly for use by registered helpers.
-func answer(name string, server tcpip.FullAddress, msg *message.Message, qtype message.Type) (cname string, addrs []message.Resource, err error) {
-	addrs = make([]message.Resource, 0, len(msg.Answers))
+func answer(name string, server tcpip.FullAddress, msg *dnsmessage.Message, qtype dnsmessage.Type) (cname string, addrs []dnsmessage.Resource, err error) {
+	addrs = make([]dnsmessage.Resource, 0, len(msg.Answers))
 
-	if msg.RCode == message.RCodeNameError {
+	if msg.RCode == dnsmessage.RCodeNameError {
 		return "", nil, &Error{Err: noSuchHost, Name: name, Server: &server}
 	}
-	if msg.RCode != message.RCodeSuccess {
+	if msg.RCode != dnsmessage.RCodeSuccess {
 		// None of the error codes make sense
 		// for the query we sent.  If we didn't get
 		// a name error and we didn't get success,
@@ -502,13 +502,13 @@
 		addrs = addrs[0:0]
 		for _, rr := range msg.Answers {
 			h := rr.Header()
-			if h.Class == message.ClassINET && equalASCIILabel(h.Name, name) {
+			if h.Class == dnsmessage.ClassINET && equalASCIILabel(h.Name, name) {
 				switch h.Type {
 				case qtype:
 					addrs = append(addrs, rr)
-				case message.TypeCNAME:
+				case dnsmessage.TypeCNAME:
 					// redirect to cname
-					name = rr.(*message.CNAMEResource).CNAME
+					name = rr.(*dnsmessage.CNAMEResource).CNAME
 					continue Cname
 				}
 			}
diff --git a/dns/message/message.go b/dns/dnsmessage/message.go
similarity index 98%
rename from dns/message/message.go
rename to dns/dnsmessage/message.go
index 1af0cf4..da43b0b 100644
--- a/dns/message/message.go
+++ b/dns/dnsmessage/message.go
@@ -2,16 +2,12 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-// This is a fork of igudger@'s patch to go's net library, found at
-// https://go-review.googlesource.com/#/c/35237/
-// Can remove if/when the code is accepted into the net library.
-
-// Package message provides a mostly RFC 1035 compliant implementation of DNS
-// message packing and unpacking.
+// Package dnsmessage provides a mostly RFC 1035 compliant implementation of
+// DNS message packing and unpacking.
 //
 // This implementation is designed to minimize heap allocations and avoid
 // unnecessary packing and unpacking as much as possible.
-package message
+package dnsmessage
 
 import (
 	"errors"
@@ -366,6 +362,16 @@
 }
 
 func (p *Parser) skipResource(sec section) error {
+	if p.resHeaderValid {
+		newOff := p.off + int(p.resHeader.Length)
+		if newOff > len(p.msg) {
+			return errResourceLen
+		}
+		p.off = newOff
+		p.resHeaderValid = false
+		p.index++
+		return nil
+	}
 	if err := p.checkAdvance(sec); err != nil {
 		return err
 	}
diff --git a/dns/message/message_test.go b/dns/dnsmessage/message_test.go
similarity index 87%
rename from dns/message/message_test.go
rename to dns/dnsmessage/message_test.go
index cafa70d..46edd72 100644
--- a/dns/message/message_test.go
+++ b/dns/dnsmessage/message_test.go
@@ -2,11 +2,13 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
-package message
+package dnsmessage
 
 import (
 	"fmt"
+	"net"
 	"reflect"
+	"strings"
 	"testing"
 )
 
@@ -265,6 +267,115 @@
 	}
 }
 
+// Example_headerSearch shows how Parser can scan a packed message,
+// decoding only the answer records whose headers match a wanted name.
+func Example_headerSearch() {
+	msg := Message{
+		Header: Header{Response: true, Authoritative: true},
+		Questions: []Question{
+			{
+				Name:  "foo.bar.example.com.",
+				Type:  TypeA,
+				Class: ClassINET,
+			},
+			{
+				Name:  "bar.example.com.",
+				Type:  TypeA,
+				Class: ClassINET,
+			},
+		},
+		Answers: []Resource{
+			&AResource{
+				ResourceHeader: ResourceHeader{
+					Name:  "foo.bar.example.com.",
+					Type:  TypeA,
+					Class: ClassINET,
+				},
+				A: [4]byte{127, 0, 0, 1},
+			},
+			&AResource{
+				ResourceHeader: ResourceHeader{
+					Name:  "bar.example.com.",
+					Type:  TypeA,
+					Class: ClassINET,
+				},
+				A: [4]byte{127, 0, 0, 2},
+			},
+		},
+	}
+
+	buf, err := msg.Pack()
+	if err != nil {
+		panic(err)
+	}
+
+	wantName := "bar.example.com."
+
+	var p Parser
+	if _, err := p.Start(buf); err != nil {
+		panic(err)
+	}
+
+	for {
+		q, err := p.Question()
+		if err == ErrSectionDone {
+			break
+		}
+		if err != nil {
+			panic(err)
+		}
+
+		if !strings.EqualFold(q.Name, wantName) {
+			continue
+		}
+
+		fmt.Println("Found question for name", wantName)
+		if err := p.SkipAllQuestions(); err != nil {
+			panic(err)
+		}
+		break
+	}
+
+	var gotIPs []net.IP
+	for {
+		h, err := p.AnswerHeader()
+		if err == ErrSectionDone {
+			break
+		}
+		if err != nil {
+			panic(err)
+		}
+
+		// Skip unwanted records; calling AnswerHeader again would not advance.
+		if (h.Type != TypeA && h.Type != TypeAAAA) || h.Class != ClassINET ||
+			!strings.EqualFold(h.Name, wantName) {
+			if err := p.SkipAnswer(); err != nil {
+				panic(err)
+			}
+			continue
+		}
+		a, err := p.Answer()
+		if err != nil {
+			panic(err)
+		}
+
+		switch r := a.(type) {
+		default:
+			panic(fmt.Sprintf("unknown type: %T", r))
+		case *AResource:
+			gotIPs = append(gotIPs, r.A[:])
+		case *AAAAResource:
+			gotIPs = append(gotIPs, r.AAAA[:])
+		}
+	}
+
+	fmt.Printf("Found A/AAAA records for name %s: %v\n", wantName, gotIPs)
+
+	// Output:
+	// Found question for name bar.example.com.
+	// Found A/AAAA records for name bar.example.com.: [127.0.0.2]
+}
+
 func largeTestMsg() Message {
 	return Message{
 		Header: Header{Response: true, Authoritative: true},