blob: 43670b609011010068053759e278b75e8b3cd3da [file] [log] [blame]
// Copyright 2017 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dns
import (
"time"
"sync"
"github.com/google/netstack/dns/dnsmessage"
)
// maxEntries caps the total number of records the cache will hold.
const maxEntries = 1024

// testHookNow supplies the current time for expiry calculations. It is a
// variable rather than a direct call so that it can be overridden.
var testHookNow = time.Now
// cacheEntry is a single record in the cache (for example a TypeA resource
// holding an IPv4 address) paired with the instant at which it expires.
// TODO(mpcomplete): I don't think we need the whole Resource. Depends on the questions
// we want to answer.
type cacheEntry struct {
rr dnsmessage.Resource // the cached resource record, as received
ttd time.Time // time-to-die: insertion time plus the record's TTL
}
// isDanglingCNAME reports whether entry is a CNAME record whose alias target
// no longer has any entries in cache. Non-CNAME entries are never dangling.
//
// It reads cache.m directly, so it must be called with cache's lock held.
func (entry *cacheEntry) isDanglingCNAME(cache *cacheInfo) bool {
	// Look up the CNAME *target*, not the record's own owner name: the entry
	// itself is stored under cache.m[owner name], so that key is always
	// populated and checking it could never detect a dangling alias.
	cname, ok := entry.rr.(*dnsmessage.CNAMEResource)
	if !ok {
		return false
	}
	return cache.m[cname.CNAME] == nil
}
// cacheInfo is the full cache: resource records grouped by owner name.
// The embedded RWMutex guards both m and numEntries; readers take RLock,
// insert and prune take Lock.
type cacheInfo struct {
sync.RWMutex
m map[string][]cacheEntry // records keyed by ResourceHeader.Name; empty names are deleted, never stored as empty slices
numEntries int // total entries across all of m; insert keeps it <= maxEntries
}
// newCache returns an empty cache whose record map is already allocated, so
// it is ready for insert and lookup without further setup.
func newCache() cacheInfo {
	byName := make(map[string][]cacheEntry)
	return cacheInfo{m: byName}
}
// lookup returns the cached Resources that answer question: same class, same
// type and an exactly equal (case-sensitive) name. Entries whose TTL has
// elapsed are skipped even if prune has not removed them yet. The result is
// never nil; it is empty on a miss.
func (cache *cacheInfo) lookup(question *dnsmessage.Question) []dnsmessage.Resource {
	now := testHookNow()
	rrs := []dnsmessage.Resource{}

	// Hold the read lock for the whole scan. prune swap-removes elements in
	// place and insert appends into the same backing array, so iterating the
	// slice after releasing the lock would race with those writers.
	cache.RLock()
	defer cache.RUnlock()
	for _, entry := range cache.m[question.Name] {
		if now.After(entry.ttd) {
			// Expired; don't serve it. prune reclaims it on the next insertAll.
			continue
		}
		h := entry.rr.Header()
		if h.Class == question.Class && h.Type == question.Type && h.Name == question.Name {
			rrs = append(rrs, entry.rr)
		}
	}
	return rrs
}
// insert adds rr, whose header is h, to the cache. The entry expires h.TTL
// seconds from now.
//
// Records with a TTL of zero are not stored: per RFC 1035 §3.2.1 they are
// valid only for the transaction in progress and must not be cached. If the
// cache already holds maxEntries entries, the record is silently dropped.
func (cache *cacheInfo) insert(h *dnsmessage.ResourceHeader, rr dnsmessage.Resource) {
	if h.TTL == 0 {
		return
	}
	newEntry := cacheEntry{
		ttd: testHookNow().Add(time.Duration(h.TTL) * time.Second),
		rr:  rr,
	}

	cache.Lock()
	defer cache.Unlock()
	if cache.numEntries >= maxEntries {
		// TODO(mpcomplete): might be better to evict the LRU entry instead.
		return
	}
	cache.m[h.Name] = append(cache.m[h.Name], newEntry)
	cache.numEntries++
}
// insertAll first prunes stale entries, then tries to cache every
// Internet-class (ClassINET) A, AAAA and CNAME record in rrs. Records of any
// other class or type are ignored. Records that insert cannot store, for
// example because the cache is full, are dropped without error.
func (cache *cacheInfo) insertAll(rrs []dnsmessage.Resource) {
	// Prune up front so expired/dangling entries free room under maxEntries.
	cache.prune()
	for _, record := range rrs {
		hdr := record.Header()
		if hdr.Class != dnsmessage.ClassINET {
			continue
		}
		switch hdr.Type {
		case dnsmessage.TypeA, dnsmessage.TypeAAAA, dnsmessage.TypeCNAME:
			cache.insert(hdr, record)
		}
	}
}
// prune removes from the cache every entry whose expiry time has passed and
// every entry for which isDanglingCNAME reports true, decrementing numEntries
// for each removal. Names left with no entries are deleted from the map.
//
// Removal swaps the victim with the last element, so the relative order of
// the surviving entries for a name may change. isDanglingCNAME consults
// cache.m while this loop is still rewriting it; because map iteration order
// is random, an entry that only becomes dangling due to removals made later
// in the same pass survives until the next prune.
func (cache *cacheInfo) prune() {
now := testHookNow()
cache.Lock()
for name, entries := range cache.m {
// removed tracks whether entries (a local slice header) must be written back.
removed := false
for i := 0; i < len(entries); {
if now.After(entries[i].ttd) || entries[i].isDanglingCNAME(cache) {
// Overwrite slot i with the last element and shrink. i is not
// advanced, so the element just moved into slot i is examined next.
entries[i] = entries[len(entries)-1]
entries = entries[:len(entries)-1]
cache.numEntries--
removed = true
} else {
i++
}
}
if len(entries) == 0 {
// Deleting the current key during a range over the map is allowed in Go.
delete(cache.m, name)
} else if removed {
cache.m[name] = entries
}
}
cache.Unlock()
}
// cache is the package-level cache shared by every resolver returned from
// newCachedResolver.
var cache = newCache()
// newCachedResolver wraps fallback with the package-level cache.
//
// If the cache holds records matching the question, they are returned
// directly with an empty cname, a nil *dnsmessage.Message and a nil error.
// Otherwise the question is forwarded to fallback and its results are returned
// unchanged. On success, the answers of the returned message are added to the
// cache asynchronously.
func newCachedResolver(fallback Resolver) Resolver {
	return func(c *Client, question dnsmessage.Question) (string, []dnsmessage.Resource, *dnsmessage.Message, error) {
		if rrs := cache.lookup(&question); len(rrs) != 0 {
			// NOTE(review): cache hits report no cname and no message; confirm
			// all callers tolerate a nil *dnsmessage.Message here.
			return "", rrs, nil, nil
		}
		cname, rrs, msg, err := fallback(c, question)
		// Guard msg: dereferencing nil inside the detached goroutine below
		// would panic and crash the whole process, not just fail this query.
		if err == nil && msg != nil {
			go cache.insertAll(msg.Answers)
		}
		return cname, rrs, msg, err
	}
}