| // Copyright 2017 The Netstack Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package dns |
| |
| import ( |
| "log" |
| "time" |
| |
| "sync" |
| |
| "github.com/google/netstack/dns/dnsmessage" |
| ) |
| |
| const ( |
| maxEntries = 1024 |
| ) |
| |
| const debug = false |
| |
| var testHookNow = func() time.Time { return time.Now() } |
| |
// cacheEntry is a single entry in the cache, such as a TypeA resource
// holding an IPv4 address, together with its expiry time.
type cacheEntry struct {
	rr  dnsmessage.Resource // the cached resource record
	ttd time.Time           // time-to-die: the entry counts as expired once the current time is after this
}
| |
| // Returns true if this entry is a CNAME that points at something no longer in our cache. |
| // NOTE: Assumes cache mutex is locked. |
| func (entry *cacheEntry) isDanglingCNAME(cache *cacheInfo) bool { |
| switch rr := entry.rr.(type) { |
| case *dnsmessage.CNAMEResource: |
| return cache.m[rr.CNAME] == nil |
| default: |
| return false |
| } |
| } |
| |
// cacheInfo is the full cache. The embedded RWMutex guards both m and
// numEntries: insert and prune take the write lock, lookup the read lock.
type cacheInfo struct {
	sync.RWMutex
	m          map[string][]cacheEntry // cached entries, keyed by the resource header's Name
	numEntries int                     // total entries across all of m; insert refuses to grow it past maxEntries
}
| |
| func newCache() cacheInfo { |
| return cacheInfo{m: make(map[string][]cacheEntry)} |
| } |
| |
// lookup returns the cached resources that answer question: unexpired
// entries stored under question.Name with the same class and name and with
// type question.Type. A matching CNAME entry is followed to its target name
// (same class and type), and the target's matches are included in place of
// the CNAME itself. The result is never nil.
//
// Expired entries are skipped rather than returned; they stay in the map
// until prune removes them. Without this check a cached name would be served
// stale forever, because prune only runs after a cache miss.
//
// The read lock is held for the entire walk. insert and prune modify the
// stored slices in place, so they must not be read after unlocking.
func (cache *cacheInfo) lookup(question *dnsmessage.Question) []dnsmessage.Resource {
	now := testHookNow()

	cache.RLock()
	defer cache.RUnlock()

	// visited records every name already expanded, so a CNAME loop
	// (a -> b -> a) terminates instead of recursing until the stack overflows.
	visited := make(map[string]bool)
	return cache.appendMatchesLocked([]dnsmessage.Resource{}, question.Name, question, now, visited)
}

// appendMatchesLocked appends to rrs the unexpired resources cached under
// name that match question's class and type, following CNAME entries to
// their targets, and returns the extended slice. Names already in visited
// are not expanded again. The caller must hold cache's read or write lock.
func (cache *cacheInfo) appendMatchesLocked(rrs []dnsmessage.Resource, name string, question *dnsmessage.Question, now time.Time, visited map[string]bool) []dnsmessage.Resource {
	if visited[name] {
		return rrs
	}
	visited[name] = true

	for _, entry := range cache.m[name] {
		if now.After(entry.ttd) {
			continue
		}
		h := entry.rr.Header()
		if h.Class != question.Class || h.Name != name {
			continue
		}
		switch rr := entry.rr.(type) {
		case *dnsmessage.CNAMEResource:
			rrs = cache.appendMatchesLocked(rrs, rr.CNAME, question, now, visited)
		default:
			if h.Type == question.Type {
				rrs = append(rrs, rr)
			}
		}
	}
	return rrs
}
| |
| func resourceEqual(r1 dnsmessage.Resource, r2 dnsmessage.Resource) bool { |
| h1 := r1.Header() |
| h2 := r2.Header() |
| if h1.Class != h2.Class || h1.Type != h2.Type || h1.Name != h2.Name { |
| return false |
| } |
| switch r1 := r1.(type) { |
| case *dnsmessage.AResource: |
| return r1.A == r2.(*dnsmessage.AResource).A |
| case *dnsmessage.AAAAResource: |
| return r1.AAAA == r2.(*dnsmessage.AAAAResource).AAAA |
| case *dnsmessage.CNAMEResource: |
| return r1.CNAME == r2.(*dnsmessage.CNAMEResource).CNAME |
| } |
| panic("unexpected resource type") |
| } |
| |
| // Searches `entries` for an exact resource match, returning the entry if found. |
| func findExact(entries []cacheEntry, rr dnsmessage.Resource) *cacheEntry { |
| for _, entry := range entries { |
| if resourceEqual(entry.rr, rr) { |
| return &entry |
| } |
| } |
| return nil |
| } |
| |
// insert stores rr in the cache with an expiry of its header TTL (seconds)
// from now.
//
// If findExact reports an equal resource already cached, that entry's expiry
// is moved later (never earlier) instead of adding a duplicate. Otherwise
// the resource is appended, unless the cache already holds maxEntries
// entries, in which case the insert is dropped and logged.
func (cache *cacheInfo) insert(rr dnsmessage.Resource) {
	h := rr.Header()
	expiry := testHookNow().Add(time.Duration(h.TTL) * time.Second)

	cache.Lock()
	defer cache.Unlock()

	list := cache.m[h.Name]
	match := findExact(list, rr)
	switch {
	case match != nil:
		if debug {
			log.Printf("DNS cache update: %v(%v)", h.Name, h.Type)
		}
		if expiry.After(match.ttd) {
			match.ttd = expiry
		}
	case cache.numEntries < maxEntries:
		if debug {
			log.Printf("DNS cache insert: %v(%v)", h.Name, h.Type)
		}
		cache.m[h.Name] = append(list, cacheEntry{rr: rr, ttd: expiry})
		cache.numEntries++
	default:
		// TODO(mpcomplete): might be better to evict the LRU entry instead.
		log.Printf("DNS cache is full; insert failed: %v(%v)", h.Name, h.Type)
	}
}
| |
| // Attempts to add each Resource as a new entry in the cache. Can fail if the cache is full. |
| func (cache *cacheInfo) insertAll(rrs []dnsmessage.Resource) { |
| cache.prune() |
| for _, rr := range rrs { |
| h := rr.Header() |
| if h.Class == dnsmessage.ClassINET { |
| switch h.Type { |
| case dnsmessage.TypeA, dnsmessage.TypeAAAA, dnsmessage.TypeCNAME: |
| cache.insert(rr) |
| } |
| } |
| } |
| } |
| |
| // Removes every expired/dangling entry from the cache. |
| func (cache *cacheInfo) prune() { |
| now := testHookNow() |
| cache.Lock() |
| for name, entries := range cache.m { |
| removed := false |
| for i := 0; i < len(entries); { |
| if now.After(entries[i].ttd) || entries[i].isDanglingCNAME(cache) { |
| entries[i] = entries[len(entries)-1] |
| entries = entries[:len(entries)-1] |
| cache.numEntries-- |
| removed = true |
| } else { |
| i++ |
| } |
| } |
| if len(entries) == 0 { |
| delete(cache.m, name) |
| } else if removed { |
| cache.m[name] = entries |
| } |
| } |
| cache.Unlock() |
| } |
| |
| var cache = newCache() |
| |
| func newCachedResolver(fallback Resolver) Resolver { |
| return func(c *Client, question dnsmessage.Question) (string, []dnsmessage.Resource, *dnsmessage.Message, error) { |
| rrs := cache.lookup(&question) |
| if len(rrs) != 0 { |
| if debug { |
| log.Printf("DNS cache lookup %v(%v) => %v", question.Name, question.Type, rrs) |
| } |
| return "", rrs, nil, nil |
| } |
| |
| cname, rrs, msg, err := fallback(c, question) |
| if debug { |
| log.Printf("DNS cache lookup failed: %v(%v) => %v", question.Name, question.Type, rrs) |
| } |
| if err == nil { |
| go cache.insertAll(msg.Answers) |
| } |
| return cname, rrs, msg, err |
| } |
| } |