// Copyright 2017 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package dns

import (
	"fmt"
	"log"
	"math"
	"sync"
	"time"

	"github.com/google/netstack/dns/dnsmessage"
)

const (
	// TODO: Think about a good value. dnsmasq defaults to 150 names.
	maxEntries = 1024
)

const debug = false

var testHookNow = func() time.Time { return time.Now() }

// Single entry in the cache, like a TypeA resource holding an IPv4 address.
type cacheEntry struct {
	rr  dnsmessage.Resource // the resource
	ttd time.Time           // when this entry expires
}

// Returns true if this entry is a CNAME that points at something no longer in our cache.
func (entry *cacheEntry) isDanglingCNAME(cache *cacheInfo) bool {
	switch rr := entry.rr.(type) {
	case *dnsmessage.CNAMEResource:
		return cache.m[rr.CNAME] == nil
	default:
		return false
	}
}

// The full cache.
type cacheInfo struct {
	mu         sync.Mutex
	m          map[string][]*cacheEntry
	numEntries int
}

func newCache() cacheInfo {
	return cacheInfo{m: make(map[string][]*cacheEntry)}
}

// Returns a list of Resources that match the given Question (same class and type and matching domain name).
func (cache *cacheInfo) lookup(question *dnsmessage.Question) []dnsmessage.Resource {
	entries := cache.m[question.Name]

	rrs := []dnsmessage.Resource{}
	for _, entry := range entries {
		h := entry.rr.Header()
		if h.Class == question.Class && h.Name == question.Name {
			switch rr := entry.rr.(type) {
			case *dnsmessage.CNAMEResource:
				cnamerrs := cache.lookup(&dnsmessage.Question{
					Name:  rr.CNAME,
					Class: question.Class,
					Type:  question.Type,
				})
				rrs = append(rrs, cnamerrs...)
			default:
				if h.Type == question.Type {
					rrs = append(rrs, rr)
				}
			}
		}
	}
	return rrs
}

func resourceEqual(r1 dnsmessage.Resource, r2 dnsmessage.Resource) bool {
	h1 := r1.Header()
	h2 := r2.Header()
	if h1.Class != h2.Class || h1.Type != h2.Type || h1.Name != h2.Name {
		return false
	}
	switch r1 := r1.(type) {
	case *dnsmessage.AResource:
		return r1.A == r2.(*dnsmessage.AResource).A
	case *dnsmessage.AAAAResource:
		return r1.AAAA == r2.(*dnsmessage.AAAAResource).AAAA
	case *dnsmessage.CNAMEResource:
		return r1.CNAME == r2.(*dnsmessage.CNAMEResource).CNAME
	case *dnsmessage.NegativeResource:
		return true
	}
	panic("unexpected resource type")
}

// Searches `entries` for an exact resource match, returning the entry if found.
func findExact(entries []*cacheEntry, rr dnsmessage.Resource) *cacheEntry {
	for _, entry := range entries {
		if resourceEqual(entry.rr, rr) {
			return entry
		}
	}
	return nil
}

// Finds the minimum TTL value of any SOA resource in a response. Returns 0 if not found.
// This is used for caching a failed DNS query. See RFC 2308.
func findSOAMinTTL(auths []dnsmessage.Resource) uint32 {
	minTTL := uint32(math.MaxUint32)
	foundSOA := false
	for _, auth := range auths {
		if auth.Header().Class == dnsmessage.ClassINET {
			switch soa := auth.(type) {
			case *dnsmessage.SOAResource:
				foundSOA = true
				if soa.MinTTL < minTTL {
					minTTL = soa.MinTTL
				}
			}
		}
	}
	if foundSOA {
		return minTTL
	}
	return 0
}

// Attempts to add a new entry into the cache. Can fail if the cache is full.
func (cache *cacheInfo) insert(rr dnsmessage.Resource) {
	h := rr.Header()
	newEntry := cacheEntry{
		ttd: testHookNow().Add(time.Duration(h.TTL) * time.Second),
		rr:  rr,
	}

	entries := cache.m[h.Name]
	if existing := findExact(entries, rr); existing != nil {
		if _, ok := existing.rr.(*dnsmessage.NegativeResource); ok {
			// We have a valid record now; replace the negative resource entirely.
			existing.rr = rr
			existing.ttd = newEntry.ttd
		} else if newEntry.ttd.After(existing.ttd) {
			existing.ttd = newEntry.ttd
		}
		if debug {
			log.Printf("DNS cache update: %v(%v) expires %v", h.Name, h.Type, existing.ttd)
		}
	} else if cache.numEntries+1 <= maxEntries {
		if debug {
			log.Printf("DNS cache insert: %v(%v) expires %v", h.Name, h.Type, newEntry.ttd)
		}
		cache.m[h.Name] = append(entries, &newEntry)
		cache.numEntries++
	} else {
		// TODO(mpcomplete): might be better to evict the LRU entry instead.
		// TODO(mpcomplete): RFC 1035 7.4 says that if we can't cache this RR, we
		// shouldn't cache any other RRs for the same name in this response.
		log.Printf("DNS cache is full; insert failed: %v(%v)", h.Name, h.Type)
	}
}

// Attempts to add each Resource as a new entry in the cache. Can fail if the cache is full.
func (cache *cacheInfo) insertAll(rrs []dnsmessage.Resource) {
	cache.prune()
	for _, rr := range rrs {
		h := rr.Header()
		if h.Class == dnsmessage.ClassINET {
			switch h.Type {
			case dnsmessage.TypeA, dnsmessage.TypeAAAA, dnsmessage.TypeCNAME:
				cache.insert(rr)
			}
		}
	}
}

func (cache *cacheInfo) insertNegative(question *dnsmessage.Question, msg *dnsmessage.Message) {
	cache.prune()
	minTTL := findSOAMinTTL(msg.Authorities)
	if minTTL == 0 {
		// Don't cache without a TTL value.
		return
	}
	rr := &dnsmessage.NegativeResource{
		ResourceHeader: dnsmessage.ResourceHeader{
			Name:  question.Name,
			Type:  question.Type,
			Class: dnsmessage.ClassINET,
			TTL:   minTTL,
		},
	}
	cache.insert(rr)
}

// Removes every expired/dangling entry from the cache.
func (cache *cacheInfo) prune() {
	now := testHookNow()
	for name, entries := range cache.m {
		removed := false
		for i := 0; i < len(entries); {
			if now.After(entries[i].ttd) || entries[i].isDanglingCNAME(cache) {
				entries[i] = entries[len(entries)-1]
				entries = entries[:len(entries)-1]
				cache.numEntries--
				removed = true
			} else {
				i++
			}
		}
		if len(entries) == 0 {
			delete(cache.m, name)
		} else if removed {
			cache.m[name] = entries
		}
	}
}

func debugString(rr []dnsmessage.Resource) string {
	str := "["
	for _, rr := range rr {
		str += fmt.Sprintf("%v, ", rr)
	}
	str += "]"
	return str
}

var cache = newCache()

func newCachedResolver(fallback Resolver) Resolver {
	return func(c *Client, question dnsmessage.Question) (string, []dnsmessage.Resource, *dnsmessage.Message, error) {
		if !(question.Class == dnsmessage.ClassINET && (question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA)) {
			panic("unexpected question type")
		}

		cache.mu.Lock()
		rrs := cache.lookup(&question)
		cache.mu.Unlock()
		if len(rrs) != 0 {
			if debug {
				log.Printf("DNS cache hit %v(%v) => %v", question.Name, question.Type, debugString(rrs))
			}
			return "", rrs, nil, nil
		}

		cname, rrs, msg, err := fallback(c, question)
		if debug {
			log.Printf("DNS cache miss, server returned %v(%v) => %v; err=%v", question.Name, question.Type, debugString(rrs), err)
		}
		cache.mu.Lock()
		if err == nil {
			cache.insertAll(msg.Answers)
		} else if err, ok := err.(*Error); ok && err.CacheNegative {
			cache.insertNegative(&question, msg)
		}
		cache.mu.Unlock()

		return cname, rrs, msg, err
	}
}
