blob: 73caf3caec88d77442b65d5c71978219d7a20cde [file] [log] [blame]
// Copyright 2016 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stack
import (
"context"
"sync"
"time"
"github.com/google/netstack/tcpip"
)
// linkAddrCacheSize is the maximum number of entries a linkAddrCache can
// hold; it is also the length of the cache's ring buffer.
const linkAddrCacheSize = 512
// linkAddrCache is a fixed-size cache mapping IP addresses to link addresses.
//
// Entries live in a ring buffer whose slots are reused in round-robin order,
// so the least recently written slot is the one overwritten next.
type linkAddrCache struct {
	// ageLimit is how long an entry stays valid after it is added.
	ageLimit time.Duration

	// mu guards all of the fields below.
	mu sync.RWMutex

	// cache maps each address to the slot in entries holding its most recent
	// entry. The entry it points at may already have expired.
	cache map[tcpip.FullAddress]*linkAddrEntry

	// next is the index in entries of the slot the next add will use.
	next int

	// entries is the ring buffer backing the cache.
	entries [linkAddrCacheSize]linkAddrEntry

	// waiters holds, per address, the notification channels of goroutines
	// blocked in get until add supplies a link address for that address.
	waiters map[tcpip.FullAddress]map[chan tcpip.LinkAddress]struct{}
}
// linkAddrEntry is a single slot of the linkAddrCache ring buffer.
type linkAddrEntry struct {
	addr       tcpip.FullAddress // address this slot resolves
	linkAddr   tcpip.LinkAddress // link address resolved for addr
	expiration time.Time         // entry is valid only strictly before this instant
}
// valid reports whether e has not yet expired, i.e. whether its expiration
// instant still lies strictly in the future.
func (c *linkAddrCache) valid(e *linkAddrEntry) bool {
	return e.expiration.After(time.Now())
}
// add records the mapping k -> v in the cache and notifies any goroutines
// blocked in get waiting for k.
//
// If k already maps to v and that entry has not expired, the cache is left
// unchanged and no notification is sent. Otherwise the next slot of the ring
// buffer is reused for the new entry, evicting whatever entry occupied it.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
	c.mu.Lock()
	defer c.mu.Unlock()

	entry := c.cache[k]
	if entry != nil && entry.linkAddr == v && c.valid(entry) {
		return // Keep existing entry.
	}

	// Take next entry. The slot's previous occupant is removed from the map
	// only if the map still points at this slot: a later add for the same
	// address may already have re-pointed that key at a different slot.
	entry = &c.entries[c.next]
	if c.cache[entry.addr] == entry {
		delete(c.cache, entry.addr)
	}
	*entry = linkAddrEntry{
		addr:       k,
		linkAddr:   v,
		expiration: time.Now().Add(c.ageLimit),
	}
	c.cache[k] = entry
	c.next++
	if c.next == len(c.entries) {
		c.next = 0
	}

	// Notify waiters without blocking. get creates each waiter channel with a
	// buffer of one and reads at most one value from it, so a full buffer
	// means the waiter already has an address pending (or has given up on
	// timeout). A blocking send here would wait forever while holding c.mu,
	// and the waiter's cleanup in get needs c.mu to deregister: a deadlock.
	for ch := range c.waiters[k] {
		select {
		case ch <- v:
		default:
		}
	}
}
// get returns the link address cached for k.
//
// If no unexpired entry is cached and timeout is zero, get returns ""
// immediately. Otherwise it registers a waiter for k and blocks until add
// delivers an address for k or timeout elapses; on timeout it returns "".
// The waiter is always deregistered before get returns.
func (c *linkAddrCache) get(k tcpip.FullAddress, timeout time.Duration) (linkAddr tcpip.LinkAddress) {
	// Fast path: a shared lock is enough to consult the cache.
	c.mu.RLock()
	entry, found := c.cache[k]
	if found && c.valid(entry) {
		linkAddr = entry.linkAddr
	}
	c.mu.RUnlock()
	if timeout == 0 || linkAddr != "" {
		return linkAddr
	}

	c.mu.Lock()
	// Look again under the exclusive lock: add may have run between the
	// RUnlock above and this Lock.
	if e, ok := c.cache[k]; ok && c.valid(e) {
		c.mu.Unlock()
		return e.linkAddr
	}
	notify := make(chan tcpip.LinkAddress, 1)
	pending := c.waiters[k]
	if pending == nil {
		pending = make(map[chan tcpip.LinkAddress]struct{})
		c.waiters[k] = pending
	}
	pending[notify] = struct{}{}
	c.mu.Unlock()

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer func() {
		cancel()
		// Deregister so add stops delivering to this channel, dropping the
		// per-address set once it is empty.
		c.mu.Lock()
		defer c.mu.Unlock()
		set := c.waiters[k]
		delete(set, notify)
		if len(set) == 0 {
			delete(c.waiters, k)
		}
	}()

	select {
	case addr := <-notify:
		return addr
	case <-ctx.Done():
		return ""
	}
}
// newLinkAddrCache returns an empty linkAddrCache whose entries remain valid
// for ageLimit after they are added.
func newLinkAddrCache(ageLimit time.Duration) *linkAddrCache {
	return &linkAddrCache{
		waiters:  make(map[tcpip.FullAddress]map[chan tcpip.LinkAddress]struct{}),
		cache:    make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
		ageLimit: ageLimit,
	}
}