blob: ca530b9995ce559d12818a1aa38d412ebf1a5a53 [file] [log] [blame]
// Copyright 2020 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dns
import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
"golang.org/x/net/dns/dnsmessage"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
incrementalTimeout = 100 * time.Millisecond
shortLifetime = 1 * time.Nanosecond
shortLifetimeTimeout = 1 * time.Second
middleLifetime = 1 * time.Second
middleLifetimeTimeout = 2 * time.Second
longLifetime = 1 * time.Hour
)
var (
addr1 = tcpip.FullAddress{
Addr: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
Port: defaultDNSPort,
}
// Address is the same as addr1, but differnt port.
addr2 = tcpip.FullAddress{
Addr: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
Port: defaultDNSPort + 1,
}
addr3 = tcpip.FullAddress{
Addr: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02",
Port: defaultDNSPort + 2,
}
// Should assume default port of 53.
addr4 = tcpip.FullAddress{
Addr: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03",
NIC: 5,
}
addr5 = tcpip.FullAddress{
Addr: "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05",
Port: defaultDNSPort,
}
addr6 = tcpip.FullAddress{
Addr: "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06",
Port: defaultDNSPort,
}
addr7 = tcpip.FullAddress{
Addr: "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07",
Port: defaultDNSPort,
}
addr8 = tcpip.FullAddress{
Addr: "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x08",
Port: defaultDNSPort,
}
addr9 = tcpip.FullAddress{
Addr: "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x09",
Port: defaultDNSPort,
}
addr10 = tcpip.FullAddress{
Addr: "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0a",
Port: defaultDNSPort,
}
)
func containsFullAddress(list []tcpip.FullAddress, item tcpip.FullAddress) bool {
for _, i := range list {
if i == item {
return true
}
}
return false
}
func containsAddress(list []tcpip.Address, item tcpip.Address) bool {
for _, i := range list {
if i == item {
return true
}
}
return false
}
func TestGetServersCacheNoDuplicates(t *testing.T) {
addr3 := addr3
addr3.Port = defaultDNSPort
addr4WithPort := addr4
addr4WithPort.Port = defaultDNSPort
// We set stack.Stack to nil since thats not what we are testing here.
c := NewClient(nil)
c.UpdateExpiringServers([]tcpip.FullAddress{addr1, addr2, addr2, addr3, addr4, addr8}, longLifetime)
runtimeServers1 := []tcpip.Address{addr5.Addr, addr5.Addr, addr6.Addr, addr7.Addr}
runtimeServers2 := []tcpip.Address{addr6.Addr, addr7.Addr, addr8.Addr, addr9.Addr}
c.SetRuntimeServers([]*[]tcpip.Address{&runtimeServers1, &runtimeServers2})
c.SetDefaultServers([]tcpip.Address{addr3.Addr, addr9.Addr, addr10.Addr, addr10.Addr})
servers := c.GetServersCache()
if !containsFullAddress(servers, addr1) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr1, servers)
}
if !containsFullAddress(servers, addr2) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr2, servers)
}
if !containsFullAddress(servers, addr3) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr3, servers)
}
if !containsFullAddress(servers, addr4WithPort) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr4WithPort, servers)
}
if !containsFullAddress(servers, addr5) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr5, servers)
}
if !containsFullAddress(servers, addr6) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr6, servers)
}
if !containsFullAddress(servers, addr7) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr7, servers)
}
if !containsFullAddress(servers, addr8) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr8, servers)
}
if !containsFullAddress(servers, addr9) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr9, servers)
}
if !containsFullAddress(servers, addr10) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr10, servers)
}
if l := len(servers); l != 10 {
t.Errorf("got len(servers) = %d, want = 10; servers = %+v", l, servers)
}
}
func TestGetServersCacheOrdering(t *testing.T) {
addr4WithPort := addr4
addr4WithPort.Port = defaultDNSPort
// We set stack.Stack to nil since thats not what we are testing here.
c := NewClient(nil)
c.UpdateExpiringServers([]tcpip.FullAddress{addr1, addr2, addr3, addr4}, longLifetime)
runtimeServers1 := []tcpip.Address{addr5.Addr, addr6.Addr}
runtimeServers2 := []tcpip.Address{addr7.Addr, addr8.Addr}
c.SetRuntimeServers([]*[]tcpip.Address{&runtimeServers1, &runtimeServers2})
c.SetDefaultServers([]tcpip.Address{addr9.Addr, addr10.Addr})
servers := c.GetServersCache()
expiringServers := servers[:4]
runtimeServers := servers[4:8]
defaultServers := servers[8:]
if !containsFullAddress(expiringServers, addr1) {
t.Errorf("expected %+v to be in the expiring server cache, got = %+v", addr1, expiringServers)
}
if !containsFullAddress(expiringServers, addr2) {
t.Errorf("expected %+v to be in the expiring server cache, got = %+v", addr2, expiringServers)
}
if !containsFullAddress(expiringServers, addr3) {
t.Errorf("expected %+v to be in the expiring server cache, got = %+v", addr3, expiringServers)
}
if !containsFullAddress(expiringServers, addr4WithPort) {
t.Errorf("expected %+v to be in the expiring server cache, got = %+v", addr4WithPort, expiringServers)
}
if !containsFullAddress(runtimeServers, addr5) {
t.Errorf("expected %+v to be in the runtime server cache, got = %+v", addr5, runtimeServers)
}
if !containsFullAddress(runtimeServers, addr6) {
t.Errorf("expected %+v to be in the runtime server cache, got = %+v", addr6, runtimeServers)
}
if !containsFullAddress(runtimeServers, addr7) {
t.Errorf("expected %+v to be in the runtime server cache, got = %+v", addr7, runtimeServers)
}
if !containsFullAddress(runtimeServers, addr8) {
t.Errorf("expected %+v to be in the runtime server cache, got = %+v", addr8, runtimeServers)
}
if !containsFullAddress(defaultServers, addr9) {
t.Errorf("expected %+v to be in the default server cache, got = %+v", addr9, defaultServers)
}
if !containsFullAddress(defaultServers, addr10) {
t.Errorf("expected %+v to be in the default server cache, got = %+v", addr10, defaultServers)
}
if l := len(servers); l != 10 {
t.Errorf("got len(servers) = %d, want = 10; servers = %+v", l, servers)
}
}
func TestRemoveAllServersWithNIC(t *testing.T) {
addr3 := addr3
addr3.NIC = addr4.NIC
addr4WithPort := addr4
addr4WithPort.Port = defaultDNSPort
// We set stack.Stack to nil since thats not what we are testing here.
c := NewClient(nil)
expectAllAddresses := func() {
t.Helper()
servers := c.GetServersCache()
if !containsFullAddress(servers, addr1) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr1, servers)
}
if !containsFullAddress(servers, addr2) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr2, servers)
}
if !containsFullAddress(servers, addr3) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr3, servers)
}
if !containsFullAddress(servers, addr4WithPort) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr4WithPort, servers)
}
if !containsFullAddress(servers, addr5) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr5, servers)
}
if !containsFullAddress(servers, addr6) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr6, servers)
}
if !containsFullAddress(servers, addr7) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr7, servers)
}
if !containsFullAddress(servers, addr8) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr8, servers)
}
if !containsFullAddress(servers, addr9) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr9, servers)
}
if !containsFullAddress(servers, addr10) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr10, servers)
}
if l := len(servers); l != 10 {
t.Errorf("got len(servers) = %d, want = 10; servers = %+v", l, servers)
}
if t.Failed() {
t.FailNow()
}
}
c.SetDefaultServers([]tcpip.Address{addr5.Addr, addr6.Addr})
runtimeServers1 := []tcpip.Address{addr7.Addr, addr8.Addr}
runtimeServers2 := []tcpip.Address{addr9.Addr, addr10.Addr}
c.SetRuntimeServers([]*[]tcpip.Address{&runtimeServers1, &runtimeServers2})
c.UpdateExpiringServers([]tcpip.FullAddress{addr1, addr2, addr3, addr4}, longLifetime)
expectAllAddresses()
// Should do nothing since a NIC is not specified.
c.RemoveAllServersWithNIC(0)
expectAllAddresses()
// Should do nothing since none of the addresses are associated with a NIC
// with ID 255.
c.RemoveAllServersWithNIC(255)
expectAllAddresses()
// Should remove addr3 and addr4.
c.RemoveAllServersWithNIC(addr4.NIC)
servers := c.GetServersCache()
if !containsFullAddress(servers, addr1) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr1, servers)
}
if !containsFullAddress(servers, addr2) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr2, servers)
}
if !containsFullAddress(servers, addr5) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr5, servers)
}
if !containsFullAddress(servers, addr6) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr6, servers)
}
if !containsFullAddress(servers, addr7) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr7, servers)
}
if !containsFullAddress(servers, addr8) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr8, servers)
}
if !containsFullAddress(servers, addr9) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr9, servers)
}
if !containsFullAddress(servers, addr10) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr10, servers)
}
if l := len(servers); l != 8 {
t.Errorf("got len(servers) = %d, want = 8; servers = %+v", l, servers)
}
}
func TestResolver(t *testing.T) {
const nicID = 1
const nicIPv4Addr = tcpip.Address("\x01\x02\x03\x04")
const nicIPv6Addr = tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
exampleIPv4Addr1Bytes := [4]byte{192, 168, 0, 1}
exampleIPv4Addr2Bytes := [4]byte{192, 168, 0, 2}
fooExampleIPv4Addr1Bytes := [4]byte{192, 168, 0, 3}
fooExampleIPv4Addr2Bytes := [4]byte{192, 168, 0, 4}
exampleIPv6Addr1Bytes := [16]byte{192, 168, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
exampleIPv6Addr2Bytes := [16]byte{192, 168, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
fooExampleIPv6Addr1Bytes := [16]byte{192, 168, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3}
fooExampleIPv6Addr2Bytes := [16]byte{192, 168, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4}
exampleIPv4Addr1 := tcpip.Address(exampleIPv4Addr1Bytes[:])
exampleIPv4Addr2 := tcpip.Address(exampleIPv4Addr2Bytes[:])
fooExampleIPv4Addr1 := tcpip.Address(fooExampleIPv4Addr1Bytes[:])
fooExampleIPv4Addr2 := tcpip.Address(fooExampleIPv4Addr2Bytes[:])
exampleIPv6Addr1 := tcpip.Address(exampleIPv6Addr1Bytes[:])
exampleIPv6Addr2 := tcpip.Address(exampleIPv6Addr2Bytes[:])
fooExampleIPv6Addr1 := tcpip.Address(fooExampleIPv6Addr1Bytes[:])
fooExampleIPv6Addr2 := tcpip.Address(fooExampleIPv6Addr2Bytes[:])
fakeIPv4AddrBytes := [4]byte{1, 2, 3, 4}
// A simple resolver that returns 1.2.3.4 for all A record questions.
fakeResolver := func(question dnsmessage.Question) (dnsmessage.Name, []dnsmessage.Resource, dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: 0,
Response: true,
},
Questions: []dnsmessage.Question{question},
}
if question.Type == dnsmessage.TypeA {
r.Answers = []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Name: question.Name,
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
Length: 4,
},
Body: &dnsmessage.AResource{
A: fakeIPv4AddrBytes,
},
},
}
}
return question.Name, r.Answers, r, nil
}
fakeResolverResponse := []tcpip.Address{tcpip.Address(fakeIPv4AddrBytes[:])}
// We need a Stack because the default resolver tries to find a route to the
// servers to make sure a route exists. No packets are sent.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{
arp.NewProtocol(),
ipv4.NewProtocol(),
ipv6.NewProtocol(),
},
TransportProtocols: []stack.TransportProtocol{
udp.NewProtocol(),
},
HandleLocal: true,
})
if err := s.CreateNIC(nicID, loopback.New()); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
if err := s.AddAddress(nicID, ipv4.ProtocolNumber, nicIPv4Addr); err != nil {
t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, nicIPv4Addr, err)
}
if err := s.AddAddress(nicID, ipv6.ProtocolNumber, nicIPv6Addr); err != nil {
t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, nicIPv6Addr, err)
}
c := NewClient(s)
// Add some entries to the cache that will be returned when testing the
// default resolver.
//
// We need to add both A and AAAA records to the cache for each domain name
// to make sure we do not try to query a DNS server through s.
c.cache.insertAll([]dnsmessage.Resource{
makeTypeAResource(example, 5, exampleIPv4Addr1Bytes),
makeTypeAResource(example, 5, exampleIPv4Addr2Bytes),
makeTypeAResource(fooExample, 5, fooExampleIPv4Addr1Bytes),
makeTypeAResource(fooExample, 5, fooExampleIPv4Addr2Bytes),
makeTypeAAAAResource(example, 5, exampleIPv6Addr1Bytes),
makeTypeAAAAResource(example, 5, exampleIPv6Addr2Bytes),
makeTypeAAAAResource(fooExample, 5, fooExampleIPv6Addr1Bytes),
makeTypeAAAAResource(fooExample, 5, fooExampleIPv6Addr2Bytes),
})
// We check the default resolver by making sure the entries we populated the
// cache with is returned.
checkDefaultResolver := func() {
t.Helper()
if addrs, err := c.LookupIP(example); err != nil {
t.Fatalf("c.LookupIP(%q): %s", example, err)
} else {
if !containsAddress(addrs, exampleIPv4Addr1) {
t.Errorf("expected %s to bein the list of addresses for %s; got = %s", exampleIPv4Addr1, example, addrs)
}
if !containsAddress(addrs, exampleIPv4Addr2) {
t.Errorf("expected %s to bein the list of addresses for %s; got = %s", exampleIPv4Addr2, example, addrs)
}
if !containsAddress(addrs, exampleIPv6Addr1) {
t.Errorf("expected %s to bein the list of addresses for %s; got = %s", exampleIPv6Addr1, example, addrs)
}
if !containsAddress(addrs, exampleIPv6Addr2) {
t.Errorf("expected %s to bein the list of addresses for %s; got = %s", exampleIPv6Addr2, example, addrs)
}
if l := len(addrs); l != 4 {
t.Errorf("got len(addrs) = %d, want = 4; addrs = %s", l, addrs)
}
}
if addrs, err := c.LookupIP(fooExample); err != nil {
t.Fatalf("c.LookupIP(%q): %s", fooExample, err)
} else {
if !containsAddress(addrs, fooExampleIPv4Addr1) {
t.Errorf("expected %s to bein the list of addresses for %s; got = %s", fooExampleIPv4Addr1, fooExample, addrs)
}
if !containsAddress(addrs, fooExampleIPv4Addr2) {
t.Errorf("expected %s to bein the list of addresses for %s; got = %s", fooExampleIPv4Addr2, fooExample, addrs)
}
if !containsAddress(addrs, fooExampleIPv6Addr1) {
t.Errorf("expected %s to bein the list of addresses for %s; got = %s", fooExampleIPv6Addr1, fooExample, addrs)
}
if !containsAddress(addrs, fooExampleIPv6Addr2) {
t.Errorf("expected %s to bein the list of addresses for %s; got = %s", fooExampleIPv6Addr2, fooExample, addrs)
}
if l := len(addrs); l != 4 {
t.Errorf("got len(addrs) = %d, want = 4; addrs = %s", l, addrs)
}
}
}
// c should be initialized with the default resolver.
checkDefaultResolver()
// Update c to use fakeResolver as its resolver.
c.SetResolver(fakeResolver)
if addrs, err := c.LookupIP(example); err != nil {
t.Fatalf("c.LookupIP(%q): %s", example, err)
} else {
if diff := cmp.Diff(fakeResolverResponse, addrs); diff != "" {
t.Errorf("domain name addresses mismatch (-want +got):\n%s", diff)
}
}
if addrs, err := c.LookupIP(fooExample); err != nil {
t.Fatalf("c.LookupIP(%q): %s", fooExample, err)
} else {
if diff := cmp.Diff(fakeResolverResponse, addrs); diff != "" {
t.Errorf("domain name addresses mismatch (-want +got):\n%s", diff)
}
}
// A nil Resolver should update c's resolver to the default resolver.
c.SetResolver(nil)
checkDefaultResolver()
}
func TestGetServersCache(t *testing.T) {
// We set stack.Stack to nil since thats not what we are testing here.
c := NewClient(nil)
addr3 := addr3
addr3.NIC = addr4.NIC
addr4WithPort := addr4
addr4WithPort.Port = defaultDNSPort
c.SetDefaultServers([]tcpip.Address{addr5.Addr, addr6.Addr})
servers := c.GetServersCache()
if !containsFullAddress(servers, addr5) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr5, servers)
}
if !containsFullAddress(servers, addr6) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr6, servers)
}
if l := len(servers); l != 2 {
t.Errorf("got len(servers) = %d, want = 2; servers = %+v", l, servers)
}
if t.Failed() {
t.FailNow()
}
runtimeServers1 := []tcpip.Address{addr7.Addr, addr8.Addr}
runtimeServers2 := []tcpip.Address{addr9.Addr, addr10.Addr}
c.SetRuntimeServers([]*[]tcpip.Address{&runtimeServers1, &runtimeServers2})
servers = c.GetServersCache()
if !containsFullAddress(servers, addr5) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr5, servers)
}
if !containsFullAddress(servers, addr6) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr6, servers)
}
if !containsFullAddress(servers, addr7) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr7, servers)
}
if !containsFullAddress(servers, addr8) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr8, servers)
}
if !containsFullAddress(servers, addr9) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr9, servers)
}
if !containsFullAddress(servers, addr10) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr10, servers)
}
if l := len(servers); l != 6 {
t.Errorf("got len(servers) = %d, want = 6; servers = %+v", l, servers)
}
if t.Failed() {
t.FailNow()
}
c.UpdateExpiringServers([]tcpip.FullAddress{addr1, addr2, addr3, addr4}, longLifetime)
servers = c.GetServersCache()
if !containsFullAddress(servers, addr1) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr1, servers)
}
if !containsFullAddress(servers, addr2) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr2, servers)
}
if !containsFullAddress(servers, addr3) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr3, servers)
}
if !containsFullAddress(servers, addr4WithPort) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr4WithPort, servers)
}
if !containsFullAddress(servers, addr5) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr5, servers)
}
if !containsFullAddress(servers, addr6) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr6, servers)
}
if !containsFullAddress(servers, addr7) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr7, servers)
}
if !containsFullAddress(servers, addr8) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr8, servers)
}
if !containsFullAddress(servers, addr9) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr9, servers)
}
if !containsFullAddress(servers, addr10) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr10, servers)
}
if l := len(servers); l != 10 {
t.Errorf("got len(servers) = %d, want = 10; servers = %+v", l, servers)
}
// Should get the same results since there were no updates.
if diff := cmp.Diff(servers, c.GetServersCache()); diff != "" {
t.Errorf("c.GetServersCache() mismatch (-want +got):\n%s", diff)
}
if t.Failed() {
t.FailNow()
}
c.SetRuntimeServers(nil)
servers = c.GetServersCache()
if !containsFullAddress(servers, addr1) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr1, servers)
}
if !containsFullAddress(servers, addr2) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr2, servers)
}
if !containsFullAddress(servers, addr3) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr3, servers)
}
if !containsFullAddress(servers, addr4WithPort) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr4WithPort, servers)
}
if !containsFullAddress(servers, addr5) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr5, servers)
}
if !containsFullAddress(servers, addr6) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr6, servers)
}
if l := len(servers); l != 6 {
t.Errorf("got len(servers) = %d, want = 6; servers = %+v", l, servers)
}
// Should remove addr3 and addr4.
c.RemoveAllServersWithNIC(addr4.NIC)
servers = c.GetServersCache()
if !containsFullAddress(servers, addr1) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr1, servers)
}
if !containsFullAddress(servers, addr2) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr2, servers)
}
if !containsFullAddress(servers, addr5) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr5, servers)
}
if !containsFullAddress(servers, addr6) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr6, servers)
}
if l := len(servers); l != 4 {
t.Errorf("got len(servers) = %d, want = 4; servers = %+v", l, servers)
}
if t.Failed() {
t.FailNow()
}
c.SetDefaultServers(nil)
servers = c.GetServersCache()
if !containsFullAddress(servers, addr1) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr1, servers)
}
if !containsFullAddress(servers, addr2) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr2, servers)
}
if l := len(servers); l != 2 {
t.Errorf("got len(servers) = %d, want = 2; servers = %+v", l, servers)
}
if t.Failed() {
t.FailNow()
}
c.UpdateExpiringServers([]tcpip.FullAddress{addr1, addr2}, 0)
servers = c.GetServersCache()
if l := len(servers); l != 0 {
t.Errorf("got len(servers) = %d, want = 0; servers = %+v", l, servers)
}
}
func TestExpiringServersDefaultDNSPort(t *testing.T) {
// We set stack.Stack to nil since thats not what we are testing here.
c := NewClient(nil)
addr4WithPort := addr4
addr4WithPort.Port = defaultDNSPort
c.UpdateExpiringServers([]tcpip.FullAddress{addr4}, longLifetime)
servers := c.GetServersCache()
if !containsFullAddress(servers, addr4WithPort) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr4WithPort, servers)
}
if l := len(servers); l != 1 {
t.Errorf("got len(servers) = %d, want = 1; servers = %+v", l, servers)
}
}
func TestExpiringServersUpdateWithDuplicates(t *testing.T) {
// We set stack.Stack to nil since thats not what we are testing here.
c := NewClient(nil)
c.UpdateExpiringServers([]tcpip.FullAddress{addr1, addr1, addr1}, longLifetime)
servers := c.GetServersCache()
if !containsFullAddress(servers, addr1) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr1, servers)
}
if l := len(servers); l != 1 {
t.Errorf("got len(servers) = %d, want = 1; servers = %+v", l, servers)
}
}
func TestExpiringServersAddAndUpdate(t *testing.T) {
// We set stack.Stack to nil since thats not what we are testing here.
c := NewClient(nil)
c.UpdateExpiringServers([]tcpip.FullAddress{addr1, addr2}, longLifetime)
servers := c.GetServersCache()
if !containsFullAddress(servers, addr1) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr1, servers)
}
if !containsFullAddress(servers, addr2) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr2, servers)
}
if l := len(servers); l != 2 {
t.Errorf("got len(servers) = %d, want = 2; servers = %+v", l, servers)
}
if t.Failed() {
t.FailNow()
}
// Refresh addr1 and addr2, add addr3.
c.UpdateExpiringServers([]tcpip.FullAddress{addr1, addr3, addr2}, longLifetime)
servers = c.GetServersCache()
if !containsFullAddress(servers, addr1) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr1, servers)
}
if !containsFullAddress(servers, addr2) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr2, servers)
}
if !containsFullAddress(servers, addr3) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr3, servers)
}
if l := len(servers); l != 3 {
t.Errorf("got len(servers) = %d, want = 3; servers = %+v", l, servers)
}
if t.Failed() {
t.FailNow()
}
// Lifetime of 0 should remove servers if they exist.
c.UpdateExpiringServers([]tcpip.FullAddress{addr4, addr1}, 0)
servers = c.GetServersCache()
if containsFullAddress(servers, addr1) {
t.Errorf("expected %+v to not be in the server cache, got = %+v", addr1, servers)
}
if containsFullAddress(servers, addr4) {
t.Errorf("expected %+v to not be in the server cache, got = %+v", addr4, servers)
}
if !containsFullAddress(servers, addr2) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr2, servers)
}
if !containsFullAddress(servers, addr3) {
t.Errorf("expected %+v to be in the server cache, got = %+v", addr3, servers)
}
if l := len(servers); l != 2 {
t.Errorf("got len(servers) = %d, want = 2; servers = %+v", l, servers)
}
}
func TestExpiringServersExpireImmediatelyTimer(t *testing.T) {
t.Parallel()
// We set stack.Stack to nil since thats not what we are testing here.
c := NewClient(nil)
c.UpdateExpiringServers([]tcpip.FullAddress{addr1, addr2}, shortLifetime)
for elapsedTime := time.Duration(0); elapsedTime <= shortLifetimeTimeout; elapsedTime += incrementalTimeout {
time.Sleep(incrementalTimeout)
servers := c.GetServersCache()
if l := len(servers); l != 0 {
if elapsedTime < middleLifetimeTimeout {
continue
}
t.Fatalf("got len(servers) = %d, want = 0; servers = %+v", l, servers)
}
break
}
}
func TestExpiringServersExpireAfterUpdate(t *testing.T) {
t.Parallel()
// We set stack.Stack to nil since thats not what we are testing here.
c := NewClient(nil)
c.UpdateExpiringServers([]tcpip.FullAddress{addr1, addr2}, longLifetime)
servers := c.GetServersCache()
if l := len(servers); l != 2 {
t.Fatalf("got len(servers) = %d, want = 2; servers = %+v", l, servers)
}
// addr2 and addr3 should expire, but addr1 should stay.
c.UpdateExpiringServers([]tcpip.FullAddress{addr2, addr3}, shortLifetime)
for elapsedTime := time.Duration(0); elapsedTime <= shortLifetimeTimeout; elapsedTime += incrementalTimeout {
time.Sleep(incrementalTimeout)
servers = c.GetServersCache()
if !containsFullAddress(servers, addr1) {
if elapsedTime < middleLifetimeTimeout {
continue
}
t.Errorf("expected %+v to be in the server cache, got = %+v", addr2, servers)
}
if containsFullAddress(servers, addr2) {
if elapsedTime < middleLifetimeTimeout {
continue
}
t.Errorf("expected %+v to not be in the server cache, got = %+v", addr3, servers)
}
if containsFullAddress(servers, addr3) {
if elapsedTime < middleLifetimeTimeout {
continue
}
t.Errorf("expected %+v to not be in the server cache, got = %+v", addr3, servers)
}
if l := len(servers); l != 1 {
if elapsedTime < middleLifetimeTimeout {
continue
}
t.Errorf("got len(servers) = %d, want = 1; servers = %+v", l, servers)
}
break
}
}
func TestExpiringServersInfiniteLifetime(t *testing.T) {
t.Parallel()
// We set stack.Stack to nil since thats not what we are testing here.
c := NewClient(nil)
c.UpdateExpiringServers([]tcpip.FullAddress{addr1, addr2}, middleLifetime)
servers := c.GetServersCache()
if l := len(servers); l != 2 {
t.Fatalf("got len(servers) = %d, want = 2; servers = %+v", l, servers)
}
// addr1 should expire, but addr2 and addr3 should be valid forever.
c.UpdateExpiringServers([]tcpip.FullAddress{addr2, addr3}, -1)
for elapsedTime := time.Duration(0); elapsedTime < middleLifetimeTimeout; elapsedTime += incrementalTimeout {
time.Sleep(incrementalTimeout)
servers = c.GetServersCache()
if containsFullAddress(servers, addr1) {
if elapsedTime < middleLifetimeTimeout {
continue
}
t.Errorf("expected %+v to not be in the server cache, got = %+v", addr2, servers)
}
if !containsFullAddress(servers, addr2) {
if elapsedTime < middleLifetimeTimeout {
continue
}
t.Errorf("expected %+v to be in the server cache, got = %+v", addr3, servers)
}
if !containsFullAddress(servers, addr3) {
if elapsedTime < middleLifetimeTimeout {
continue
}
t.Errorf("expected %+v to be in the server cache, got = %+v", addr3, servers)
}
if l := len(servers); l != 2 {
if elapsedTime < middleLifetimeTimeout {
continue
}
t.Errorf("got len(servers) = %d, want = 2; servers = %+v", l, servers)
}
break
}
if t.Failed() {
t.FailNow()
}
c.UpdateExpiringServers([]tcpip.FullAddress{addr2, addr3}, middleLifetime)
for elapsedTime := time.Duration(0); elapsedTime <= middleLifetimeTimeout; elapsedTime += incrementalTimeout {
time.Sleep(incrementalTimeout)
servers = c.GetServersCache()
if containsFullAddress(servers, addr1) {
if elapsedTime < middleLifetimeTimeout {
continue
}
t.Errorf("expected %+v to not be in the server cache, got = %+v", addr2, servers)
}
if containsFullAddress(servers, addr2) {
if elapsedTime < middleLifetimeTimeout {
continue
}
t.Errorf("expected %+v to not be in the server cache, got = %+v", addr3, servers)
}
if containsFullAddress(servers, addr3) {
if elapsedTime < middleLifetimeTimeout {
continue
}
t.Errorf("expected %+v to not be in the server cache, got = %+v", addr3, servers)
}
if l := len(servers); l != 0 {
if elapsedTime < middleLifetimeTimeout {
continue
}
t.Errorf("got len(servers) = %d, want = 0; servers = %+v", l, servers)
}
break
}
}