Fix CNAME handling for DNS cache.

Also ensure we don't cache the same entry multiple times.

NET-45 #done

Change-Id: Ib830a38ec0ddfa1966e5598d03ae4807f6160b03
diff --git a/dns/cache.go b/dns/cache.go
index 43670b6..9248057 100644
--- a/dns/cache.go
+++ b/dns/cache.go
@@ -5,6 +5,7 @@
 package dns
 
 import (
+	"log"
 	"time"
 
 	"sync"
@@ -16,20 +17,25 @@
 	maxEntries = 1024
 )
 
+const debug = false
+
 var testHookNow = func() time.Time { return time.Now() }
 
 // Single entry in the cache, like a TypeA resource holding an IPv4 address.
-// TODO(mpcomplete): I don't think we need the whole Resource. Depends on the questions
-// we want to answer.
 type cacheEntry struct {
 	rr  dnsmessage.Resource // the resource
 	ttd time.Time           // when this entry expires
 }
 
 // Returns true if this entry is a CNAME that points at something no longer in our cache.
+// NOTE: Assumes cache mutex is locked.
 func (entry *cacheEntry) isDanglingCNAME(cache *cacheInfo) bool {
-	h := entry.rr.Header()
-	return h.Type == dnsmessage.TypeCNAME && cache.m[h.Name] == nil
+	switch rr := entry.rr.(type) {
+	case *dnsmessage.CNAMEResource:
+		return cache.m[rr.CNAME] == nil
+	default:
+		return false
+	}
 }
 
 // The full cache.
@@ -52,24 +58,78 @@
 	rrs := []dnsmessage.Resource{}
 	for _, entry := range entries {
 		h := entry.rr.Header()
-		if h.Class == question.Class && h.Type == question.Type && h.Name == question.Name {
-			rrs = append(rrs, entry.rr)
+		if h.Class == question.Class && h.Name == question.Name {
+			switch rr := entry.rr.(type) {
+			case *dnsmessage.CNAMEResource:
+				cnamerrs := cache.lookup(&dnsmessage.Question{
+					Name:  rr.CNAME,
+					Class: question.Class,
+					Type:  question.Type,
+				})
+				rrs = append(rrs, cnamerrs...)
+			default:
+				if h.Type == question.Type {
+					rrs = append(rrs, rr)
+				}
+			}
 		}
 	}
 	return rrs
 }
 
+func resourceEqual(r1 dnsmessage.Resource, r2 dnsmessage.Resource) bool {
+	h1 := r1.Header()
+	h2 := r2.Header()
+	if h1.Class != h2.Class || h1.Type != h2.Type || h1.Name != h2.Name {
+		return false
+	}
+	switch r1 := r1.(type) {
+	case *dnsmessage.AResource:
+		return r1.A == r2.(*dnsmessage.AResource).A
+	case *dnsmessage.AAAAResource:
+		return r1.AAAA == r2.(*dnsmessage.AAAAResource).AAAA
+	case *dnsmessage.CNAMEResource:
+		return r1.CNAME == r2.(*dnsmessage.CNAMEResource).CNAME
+	}
+	panic("unexpected resource type")
+}
+
+// Searches `entries` for an exact resource match, returning the entry if found.
+func findExact(entries []cacheEntry, rr dnsmessage.Resource) *cacheEntry {
+	for _, entry := range entries {
+		if resourceEqual(entry.rr, rr) {
+			return &entry
+		}
+	}
+	return nil
+}
+
 // Attempts to add a new entry into the cache. Can fail if the cache is full.
-func (cache *cacheInfo) insert(h *dnsmessage.ResourceHeader, rr dnsmessage.Resource) {
+func (cache *cacheInfo) insert(rr dnsmessage.Resource) {
+	h := rr.Header()
 	newEntry := cacheEntry{
 		ttd: testHookNow().Add(time.Duration(h.TTL) * time.Second),
 		rr:  rr,
 	}
+
 	cache.Lock()
-	if cache.numEntries+1 <= maxEntries {
-		// TODO(mpcomplete): might be better to evict the LRU entry instead.
-		cache.m[h.Name] = append(cache.m[h.Name], newEntry)
+	entries := cache.m[h.Name]
+	if existing := findExact(entries, rr); existing != nil {
+		if debug {
+			log.Printf("DNS cache update: %v(%v)", h.Name, h.Type)
+		}
+		if newEntry.ttd.After(existing.ttd) {
+			existing.ttd = newEntry.ttd
+		}
+	} else if cache.numEntries+1 <= maxEntries {
+		if debug {
+			log.Printf("DNS cache insert: %v(%v)", h.Name, h.Type)
+		}
+		cache.m[h.Name] = append(entries, newEntry)
 		cache.numEntries++
+	} else {
+		// TODO(mpcomplete): might be better to evict the LRU entry instead.
+		log.Printf("DNS cache is full; insert failed: %v(%v)", h.Name, h.Type)
 	}
 	cache.Unlock()
 }
@@ -82,7 +142,7 @@
 		if h.Class == dnsmessage.ClassINET {
 			switch h.Type {
 			case dnsmessage.TypeA, dnsmessage.TypeAAAA, dnsmessage.TypeCNAME:
-				cache.insert(h, rr)
+				cache.insert(rr)
 			}
 		}
 	}
@@ -119,10 +179,16 @@
 	return func(c *Client, question dnsmessage.Question) (string, []dnsmessage.Resource, *dnsmessage.Message, error) {
 		rrs := cache.lookup(&question)
 		if len(rrs) != 0 {
+			if debug {
+				log.Printf("DNS cache lookup %v(%v) => %v", question.Name, question.Type, rrs)
+			}
 			return "", rrs, nil, nil
 		}
 
 		cname, rrs, msg, err := fallback(c, question)
+		if debug {
+			log.Printf("DNS cache lookup failed: %v(%v) => %v", question.Name, question.Type, rrs)
+		}
 		if err == nil {
 			go cache.insertAll(msg.Answers)
 		}
diff --git a/dns/cache_test.go b/dns/cache_test.go
index 470a657..378bda2 100644
--- a/dns/cache_test.go
+++ b/dns/cache_test.go
@@ -102,7 +102,7 @@
 		cache.insertAll([]dnsmessage.Resource{
 			&dnsmessage.AResource{
 				ResourceHeader: makeResourceHeader("example.com.", 5),
-				A:              [4]byte{127, byte(i << 16), byte(i << 8), byte(i)},
+				A:              [4]byte{byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i)},
 			},
 		})
 	}
@@ -143,3 +143,38 @@
 		t.Errorf("cache.insertAll failed. Got %d. Want %d.", len(rrs), 1)
 	}
 }
+
+// Tests that we get results when looking up a domain alias.
+func TestCNAME(t *testing.T) {
+	cache := newCache()
+	cache.insertAll(smallTestResources)
+
+	// One CNAME record that points at an existing record.
+	cache.insertAll([]dnsmessage.Resource{
+		&dnsmessage.CNAMEResource{
+			ResourceHeader: makeResourceHeader("foobar.com.", 10),
+			CNAME:          "example.com.",
+		},
+	})
+
+	rrs := cache.lookup(makeQuestion("foobar.com."))
+	if len(rrs) != 2 {
+		t.Errorf("cache.lookup failed. Got %d. Want %d.", len(rrs), 2)
+	}
+	for _, rr := range rrs {
+		if rr.Header().Name != "example.com." {
+			t.Errorf("cache.lookup failed. Got '%q'. Want 'example.com.'", rr.Header().Name)
+		}
+	}
+}
+
+// Tests that the cache doesn't store multiple identical records.
+func TestDupe(t *testing.T) {
+	cache := newCache()
+	cache.insertAll(smallTestResources)
+	cache.insertAll(smallTestResources)
+	rrs := cache.lookup(smallTestQuestion)
+	if len(rrs) != 2 {
+		t.Errorf("cache.lookup failed. Got %d. Want %d.", len(rrs), 2)
+	}
+}