Fix CNAME handling for DNS cache.
Also ensure we don't cache the same entry multiple times.
NET-45 #done
Change-Id: Ib830a38ec0ddfa1966e5598d03ae4807f6160b03
diff --git a/dns/cache.go b/dns/cache.go
index 43670b6..9248057 100644
--- a/dns/cache.go
+++ b/dns/cache.go
@@ -5,6 +5,7 @@
package dns
import (
+ "log"
"time"
"sync"
@@ -16,20 +17,25 @@
maxEntries = 1024
)
+const debug = false
+
var testHookNow = func() time.Time { return time.Now() }
// Single entry in the cache, like a TypeA resource holding an IPv4 address.
-// TODO(mpcomplete): I don't think we need the whole Resource. Depends on the questions
-// we want to answer.
type cacheEntry struct {
rr dnsmessage.Resource // the resource
ttd time.Time // when this entry expires
}
// Returns true if this entry is a CNAME that points at something no longer in our cache.
+// NOTE: Assumes cache mutex is locked.
func (entry *cacheEntry) isDanglingCNAME(cache *cacheInfo) bool {
- h := entry.rr.Header()
- return h.Type == dnsmessage.TypeCNAME && cache.m[h.Name] == nil
+ switch rr := entry.rr.(type) {
+ case *dnsmessage.CNAMEResource:
+ return cache.m[rr.CNAME] == nil
+ default:
+ return false
+ }
}
// The full cache.
@@ -52,24 +58,78 @@
rrs := []dnsmessage.Resource{}
for _, entry := range entries {
h := entry.rr.Header()
- if h.Class == question.Class && h.Type == question.Type && h.Name == question.Name {
- rrs = append(rrs, entry.rr)
+ if h.Class == question.Class && h.Name == question.Name {
+ switch rr := entry.rr.(type) {
+ case *dnsmessage.CNAMEResource:
+ cnamerrs := cache.lookup(&dnsmessage.Question{
+ Name: rr.CNAME,
+ Class: question.Class,
+ Type: question.Type,
+ })
+ rrs = append(rrs, cnamerrs...)
+ default:
+ if h.Type == question.Type {
+ rrs = append(rrs, rr)
+ }
+ }
}
}
return rrs
}
+func resourceEqual(r1 dnsmessage.Resource, r2 dnsmessage.Resource) bool {
+ h1 := r1.Header()
+ h2 := r2.Header()
+ if h1.Class != h2.Class || h1.Type != h2.Type || h1.Name != h2.Name {
+ return false
+ }
+ switch r1 := r1.(type) {
+ case *dnsmessage.AResource:
+ return r1.A == r2.(*dnsmessage.AResource).A
+ case *dnsmessage.AAAAResource:
+ return r1.AAAA == r2.(*dnsmessage.AAAAResource).AAAA
+ case *dnsmessage.CNAMEResource:
+ return r1.CNAME == r2.(*dnsmessage.CNAMEResource).CNAME
+ }
+ panic("unexpected resource type")
+}
+
+// Searches `entries` for an exact resource match, returning the entry if found.
+func findExact(entries []cacheEntry, rr dnsmessage.Resource) *cacheEntry {
+ for _, entry := range entries {
+ if resourceEqual(entry.rr, rr) {
+ return &entry
+ }
+ }
+ return nil
+}
+
// Attempts to add a new entry into the cache. Can fail if the cache is full.
-func (cache *cacheInfo) insert(h *dnsmessage.ResourceHeader, rr dnsmessage.Resource) {
+func (cache *cacheInfo) insert(rr dnsmessage.Resource) {
+ h := rr.Header()
newEntry := cacheEntry{
ttd: testHookNow().Add(time.Duration(h.TTL) * time.Second),
rr: rr,
}
+
cache.Lock()
- if cache.numEntries+1 <= maxEntries {
- // TODO(mpcomplete): might be better to evict the LRU entry instead.
- cache.m[h.Name] = append(cache.m[h.Name], newEntry)
+ entries := cache.m[h.Name]
+ if existing := findExact(entries, rr); existing != nil {
+ if debug {
+ log.Printf("DNS cache update: %v(%v)", h.Name, h.Type)
+ }
+ if newEntry.ttd.After(existing.ttd) {
+ existing.ttd = newEntry.ttd
+ }
+ } else if cache.numEntries+1 <= maxEntries {
+ if debug {
+ log.Printf("DNS cache insert: %v(%v)", h.Name, h.Type)
+ }
+ cache.m[h.Name] = append(entries, newEntry)
cache.numEntries++
+ } else {
+ // TODO(mpcomplete): might be better to evict the LRU entry instead.
+ log.Printf("DNS cache is full; insert failed: %v(%v)", h.Name, h.Type)
}
cache.Unlock()
}
@@ -82,7 +142,7 @@
if h.Class == dnsmessage.ClassINET {
switch h.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA, dnsmessage.TypeCNAME:
- cache.insert(h, rr)
+ cache.insert(rr)
}
}
}
@@ -119,10 +179,16 @@
return func(c *Client, question dnsmessage.Question) (string, []dnsmessage.Resource, *dnsmessage.Message, error) {
rrs := cache.lookup(&question)
if len(rrs) != 0 {
+ if debug {
+ log.Printf("DNS cache lookup %v(%v) => %v", question.Name, question.Type, rrs)
+ }
return "", rrs, nil, nil
}
cname, rrs, msg, err := fallback(c, question)
+ if debug {
+ log.Printf("DNS cache lookup failed: %v(%v) => %v", question.Name, question.Type, rrs)
+ }
if err == nil {
go cache.insertAll(msg.Answers)
}
diff --git a/dns/cache_test.go b/dns/cache_test.go
index 470a657..378bda2 100644
--- a/dns/cache_test.go
+++ b/dns/cache_test.go
@@ -102,7 +102,7 @@
cache.insertAll([]dnsmessage.Resource{
&dnsmessage.AResource{
ResourceHeader: makeResourceHeader("example.com.", 5),
- A: [4]byte{127, byte(i << 16), byte(i << 8), byte(i)},
+ A: [4]byte{byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i)},
},
})
}
@@ -143,3 +143,38 @@
t.Errorf("cache.insertAll failed. Got %d. Want %d.", len(rrs), 1)
}
}
+
+// Tests that we get results when looking up a domain alias.
+func TestCNAME(t *testing.T) {
+ cache := newCache()
+ cache.insertAll(smallTestResources)
+
+ // One CNAME record that points at an existing record.
+ cache.insertAll([]dnsmessage.Resource{
+ &dnsmessage.CNAMEResource{
+ ResourceHeader: makeResourceHeader("foobar.com.", 10),
+ CNAME: "example.com.",
+ },
+ })
+
+ rrs := cache.lookup(makeQuestion("foobar.com."))
+ if len(rrs) != 2 {
+ t.Errorf("cache.lookup failed. Got %d. Want %d.", len(rrs), 2)
+ }
+ for _, rr := range rrs {
+ if rr.Header().Name != "example.com." {
+ t.Errorf("cache.lookup failed. Got '%q'. Want 'example.com.'", rr.Header().Name)
+ }
+ }
+}
+
+// Tests that the cache doesn't store multiple identical records.
+func TestDupe(t *testing.T) {
+ cache := newCache()
+ cache.insertAll(smallTestResources)
+ cache.insertAll(smallTestResources)
+ rrs := cache.lookup(smallTestQuestion)
+ if len(rrs) != 2 {
+ t.Errorf("cache.lookup failed. Got %d. Want %d.", len(rrs), 2)
+ }
+}