blob: c92998a4c6cc543a2800d9f54e9a19e13b6b29c4 [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package netstack
import (
"context"
"sync"
"syscall/zx/dispatch"
"testing"
"time"
"netstack/dns"
"netstack/fidlconv"
"fidl/fuchsia/net"
"fidl/fuchsia/net/name"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/tcpip"
)
// Raw-byte DNS server addresses shared by the tests in this file.
const (
	// defaultServerIpv4 is the four-byte address 8.8.8.8.
	defaultServerIpv4 = tcpip.Address("\x08\x08\x08\x08")
	// altServerIpv4 is the four-byte address 8.8.4.4.
	altServerIpv4 = tcpip.Address("\x08\x08\x04\x04")
	// defaultServerIpv6 is the sixteen-byte address 2001:4860:4860::8888.
	defaultServerIpv6 = tcpip.Address("\x20\x01\x48\x60\x48\x60\x00\x00\x00\x00\x00\x00\x00\x00\x88\x88")
)
// Pre-built socket addresses and server source used as expectations by the
// tests in this file. Every full address leaves NIC at its zero value and uses
// dns.DefaultDNSPort as the port.
var (
	// defaultServerIPv4SocketAddress is defaultServerIpv4 converted through
	// fidlconv.ToNetSocketAddress.
	defaultServerIPv4SocketAddress = fidlconv.ToNetSocketAddress(tcpip.FullAddress{
		Addr: defaultServerIpv4,
		Port: dns.DefaultDNSPort,
	})

	// altServerIpv4SocketAddress is altServerIpv4 converted through
	// fidlconv.ToNetSocketAddress.
	altServerIpv4SocketAddress = fidlconv.ToNetSocketAddress(tcpip.FullAddress{
		Addr: altServerIpv4,
		Port: dns.DefaultDNSPort,
	})

	// defaultServerIpv6FullAddress is kept in its tcpip form because
	// UpdateNdpServers is called with tcpip.FullAddress values.
	defaultServerIpv6FullAddress = tcpip.FullAddress{
		Addr: defaultServerIpv6,
		Port: dns.DefaultDNSPort,
	}

	// defaultServerIpv6SocketAddress is defaultServerIpv6FullAddress converted
	// through fidlconv.ToNetSocketAddress.
	defaultServerIpv6SocketAddress = fidlconv.ToNetSocketAddress(defaultServerIpv6FullAddress)

	// staticDnsSource is a DnsServerSource whose tag selects the static-source
	// variant; no variant payload is set.
	staticDnsSource = name.DnsServerSource{
		I_dnsServerSourceTag: name.DnsServerSourceStaticSource,
	}
)
// makeDnsServer returns a name.DnsServer with its address and source set to
// the given values.
func makeDnsServer(address net.SocketAddress, source name.DnsServerSource) name.DnsServer {
	server := name.DnsServer{}
	server.SetAddress(address)
	server.SetSource(source)
	return server
}
// createCollection builds a dnsServerWatcherCollection on top of a fresh DNS
// client (created with dns.NewClient(nil)) and passes the collection's
// NotifyServersChanged method to the client's SetOnServersChanged.
func createCollection(dispatcher *dispatch.Dispatcher) *dnsServerWatcherCollection {
	client := dns.NewClient(nil)
	collection := newDnsServerWatcherCollection(dispatcher, client)

	onServersChanged := collection.NotifyServersChanged
	collection.dnsClient.SetOnServersChanged(onServersChanged)

	return collection
}
// bindWatcher creates a DnsServerWatcher channel pair, binds the request end to
// watcherCollection, and returns the client end.
//
// Any failure aborts the calling test via t.Fatalf. The function is marked as
// a test helper so that such failures are reported at the caller's line.
func bindWatcher(t *testing.T, watcherCollection *dnsServerWatcherCollection) *name.DnsServerWatcherWithCtxInterface {
	t.Helper()

	request, watcher, err := name.NewDnsServerWatcherWithCtxInterfaceRequest()
	if err != nil {
		t.Fatalf("failed to create DnsServerWatcher channel pair: %s", err)
	}
	if err := watcherCollection.Bind(request); err != nil {
		t.Fatalf("failed to bind watcher: %s", err)
	}
	return watcher
}
// TestDnsWatcherResolvesAndBlocks checks that the first WatchServers call
// returns the configured default servers, that a second call does not complete
// within 50ms while the server list is unchanged, and that it then completes
// with an empty list once the default servers are cleared.
func TestDnsWatcherResolvesAndBlocks(t *testing.T) {
	dispatcher, err := dispatch.NewDispatcher()
	if err != nil {
		t.Fatal(err)
	}

	// Serve the dispatcher in the background; on exit, close it and wait for
	// the serving goroutine to return.
	var serving sync.WaitGroup
	serving.Add(1)
	go func() {
		defer serving.Done()
		dispatcher.Serve()
	}()
	defer func() {
		dispatcher.Close()
		serving.Wait()
	}()

	collection := createCollection(dispatcher)
	collection.dnsClient.SetDefaultServers([]tcpip.Address{defaultServerIpv4, defaultServerIpv6})

	watcher := bindWatcher(t, collection)
	defer func() {
		_ = watcher.Close()
	}()

	// The first watch is expected to report the current list.
	initial, err := watcher.WatchServers(context.Background())
	if err != nil {
		t.Fatalf("failed to call WatchServers: %s", err)
	}
	wantInitial := []name.DnsServer{
		makeDnsServer(defaultServerIPv4SocketAddress, staticDnsSource),
		makeDnsServer(defaultServerIpv6SocketAddress, staticDnsSource),
	}
	if diff := cmp.Diff(wantInitial, initial, cmpopts.IgnoreTypes(struct{}{})); diff != "" {
		t.Fatalf("WatchServers() mismatch (-want +got):\n%s", diff)
	}

	type watchServersResult struct {
		servers []name.DnsServer
		err     error
	}
	pending := make(chan watchServersResult)
	go func() {
		servers, err := watcher.WatchServers(context.Background())
		pending <- watchServersResult{servers, err}
	}()

	// Nothing has changed yet, so the second watch must still be outstanding
	// after the grace period.
	select {
	case <-time.After(50 * time.Millisecond):
	case early := <-pending:
		t.Fatalf("WatchServers finished too early with %+v. Should've timed out.", early)
	}

	// Clear the list of servers, now we should get something on the channel.
	collection.dnsClient.SetDefaultServers(nil)
	final := <-pending
	if final.err != nil {
		t.Fatalf("WatchServers failed: %s", final.err)
	}
	if diff := cmp.Diff([]name.DnsServer{}, final.servers); diff != "" {
		t.Fatalf("WatchServers() mismatch (-want +got):\n%s", diff)
	}
}
// TestDnsWatcherCancelledContext calls WatchServers directly on a
// dnsServerWatcher with an already-cancelled context and checks that the call
// fails and leaves the watcher flagged as dead but not hanging.
func TestDnsWatcherCancelledContext(t *testing.T) {
	dispatcher, err := dispatch.NewDispatcher()
	if err != nil {
		t.Fatal(err)
	}
	defer dispatcher.Close()

	collection := createCollection(dispatcher)
	// Build the watcher by hand, sharing the collection's client, dispatcher
	// and broadcast, so its internal state can be inspected afterwards.
	w := dnsServerWatcher{
		dnsClient:  collection.dnsClient,
		dispatcher: collection.dispatcher,
		broadcast:  &collection.broadcast,
	}

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	_, watchErr := w.WatchServers(ctx)
	if watchErr == nil {
		t.Fatal("expected cancelled context to produce an error")
	}

	w.mu.Lock()
	defer w.mu.Unlock()
	if dead := w.mu.isDead; !dead {
		t.Errorf("watcher must be marked as dead on context cancellation")
	}
	if hanging := w.mu.isHanging; hanging {
		t.Errorf("watcher must not be marked as hanging on context cancellation")
	}
}
// TestDnsWatcherDisallowMultiplePending is currently skipped. When enabled, it
// issues two concurrent WatchServers calls on the same watcher and expects
// each of them to return an error.
//
// NOTE(review): unlike the other tests in this file that bind a watcher, this
// one never starts dispatcher.Serve; confirm whether that is needed before
// removing the skip.
func TestDnsWatcherDisallowMultiplePending(t *testing.T) {
	t.Skip("Go bindings don't currently support simultaneous calls, test is invalid.")

	dispatcher, err := dispatch.NewDispatcher()
	if err != nil {
		t.Fatal(err)
	}
	defer dispatcher.Close()

	watcherCollection := createCollection(dispatcher)
	watcher := bindWatcher(t, watcherCollection)
	defer func() {
		_ = watcher.Close()
	}()

	// Deferred last so it runs first: both calls must return before the
	// watcher and dispatcher are closed.
	var wg sync.WaitGroup
	defer wg.Wait()
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_, err := watcher.WatchServers(context.Background())
			if err == nil {
				// The failing condition is a nil error, so say so.
				t.Errorf("WatchServers returned a nil error, expected an error")
			}
		}()
	}
}
// TestDnsWatcherMultipleWatchers binds two watchers to the same collection,
// starts a WatchServers call on each, then sets a single default server and
// checks that both calls return exactly that server.
func TestDnsWatcherMultipleWatchers(t *testing.T) {
	dispatcher, err := dispatch.NewDispatcher()
	if err != nil {
		t.Fatal(err)
	}

	// Serve the dispatcher in the background; on exit, close it and wait for
	// the serving goroutine to return.
	var serving sync.WaitGroup
	serving.Add(1)
	go func() {
		defer serving.Done()
		dispatcher.Serve()
	}()
	defer func() {
		dispatcher.Close()
		serving.Wait()
	}()

	collection := createCollection(dispatcher)

	first := bindWatcher(t, collection)
	defer func() {
		_ = first.Close()
	}()
	second := bindWatcher(t, collection)
	defer func() {
		_ = second.Close()
	}()

	// expectDefaultServer performs one watch on w and reports (without
	// aborting the test) any error or any result other than the single
	// default IPv4 server.
	expectDefaultServer := func(w *name.DnsServerWatcherWithCtxInterface) {
		got, err := w.WatchServers(context.Background())
		if err != nil {
			t.Errorf("WatchServers failed: %s", err)
			return
		}
		want := []name.DnsServer{
			makeDnsServer(defaultServerIPv4SocketAddress, staticDnsSource),
		}
		if diff := cmp.Diff(want, got, cmpopts.IgnoreTypes(struct{}{})); diff != "" {
			t.Errorf("WatchServers() mismatch (-want +got):\n%s", diff)
		}
	}

	func() {
		// Wait for the watching goroutines to finish before the outer defers
		// are allowed to tear down the watchers.
		var pending sync.WaitGroup
		defer pending.Wait()

		watchers := []*name.DnsServerWatcherWithCtxInterface{first, second}
		for i := range watchers {
			w := watchers[i]
			pending.Add(1)
			go func() {
				defer pending.Done()
				expectDefaultServer(w)
			}()
		}

		collection.dnsClient.SetDefaultServers([]tcpip.Address{defaultServerIpv4})
	}()
}
func TestDnsWatcherServerListEquality(t *testing.T) {
addr1 := dns.Server{
Address: tcpip.FullAddress{
NIC: 0,
Addr: defaultServerIpv4,
Port: dns.DefaultDNSPort,
},
Source: staticDnsSource,
}
addr2 := dns.Server{
Address: tcpip.FullAddress{
NIC: 0,
Addr: defaultServerIpv6,
Port: dns.DefaultDNSPort,
},
Source: staticDnsSource,
}
addr3 := dns.Server{
Address: tcpip.FullAddress{
NIC: 0,
Addr: defaultServerIpv4,
Port: 8080,
},
Source: staticDnsSource,
}
testEquality := func(equals bool, l, r []dns.Server) {
if serverListEquals(l, r) != equals {
t.Errorf("list comparison failed %+v == %+v should be %v, got %v", l, r, equals, !equals)
}
}
testEquality(true, []dns.Server{addr1, addr2, addr3}, []dns.Server{addr1, addr2, addr3})
testEquality(true, []dns.Server{}, nil)
testEquality(false, []dns.Server{addr1, addr2, addr3}, []dns.Server{addr3, addr2, addr1})
testEquality(false, []dns.Server{addr1}, []dns.Server{addr3, addr2, addr1})
testEquality(false, []dns.Server{addr1}, []dns.Server{addr2})
}
// TestDnsWatcherDifferentAddressTypes registers one default, one DHCP and one
// NDP server on the client, then removes them one at a time, checking after
// each step that WatchServers reports the expected addresses in the expected
// order.
func TestDnsWatcherDifferentAddressTypes(t *testing.T) {
	dispatcher, err := dispatch.NewDispatcher()
	if err != nil {
		t.Fatal(err)
	}
	// Serve the dispatcher in the background; on exit, close it and wait for
	// the serving goroutine to return.
	var wg sync.WaitGroup
	defer func() {
		dispatcher.Close()
		wg.Wait()
	}()
	wg.Add(1)
	go func() {
		defer wg.Done()
		dispatcher.Serve()
	}()

	watcherCollection := createCollection(dispatcher)
	watcher := bindWatcher(t, watcherCollection)
	defer func() {
		_ = watcher.Close()
	}()

	watcherCollection.dnsClient.SetDefaultServers([]tcpip.Address{defaultServerIpv4})
	watcherCollection.dnsClient.UpdateDhcpServers(0, &[]tcpip.Address{altServerIpv4})
	watcherCollection.dnsClient.UpdateNdpServers([]tcpip.FullAddress{defaultServerIpv6FullAddress}, -1)

	// checkServers performs one watch and aborts the test unless the returned
	// servers carry exactly the addresses in expect, in the same order. It is
	// marked as a helper so that a failure is reported at the calling line,
	// which identifies the step that went wrong.
	checkServers := func(expect []net.SocketAddress) {
		t.Helper()
		servers, err := watcher.WatchServers(context.Background())
		if err != nil {
			t.Fatalf("WatchServers failed: %s", err)
		}
		if len(servers) != len(expect) {
			t.Fatalf("bad WatchServers result, expected %d servers got: %+v", len(expect), servers)
		}
		for i := range expect {
			if servers[i].Address != expect[i] {
				t.Fatalf("bad server at %d.\nExpected: %+v\nGot: %+v", i, expect[i], servers[i])
			}
		}
	}

	// All three sources populated: NDP first, then DHCP, then default.
	checkServers([]net.SocketAddress{defaultServerIpv6SocketAddress, altServerIpv4SocketAddress, defaultServerIPv4SocketAddress})

	// Remove the expiring server and verify result.
	watcherCollection.dnsClient.UpdateNdpServers([]tcpip.FullAddress{defaultServerIpv6FullAddress}, 0)
	checkServers([]net.SocketAddress{altServerIpv4SocketAddress, defaultServerIPv4SocketAddress})

	// Remove DHCP server and verify result.
	watcherCollection.dnsClient.UpdateDhcpServers(0, nil)
	checkServers([]net.SocketAddress{defaultServerIPv4SocketAddress})

	// Remove the last server and verify result.
	watcherCollection.dnsClient.SetDefaultServers(nil)
	checkServers([]net.SocketAddress{})
}