blob: 1267ad23a67789df0f6edb0cdda127002bed9fe9 [file] [log] [blame]
// Copyright 2020 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dns
import (
"testing"
"github.com/google/go-cmp/cmp"
"golang.org/x/net/dns/dnsmessage"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
// TestResolver checks that a Client starts out using the default resolver
// (answered here from pre-populated cache entries), switches to a custom
// resolver installed with SetResolver, and falls back to the default resolver
// after SetResolver(nil).
func TestResolver(t *testing.T) {
	const nicID = 1
	const nicIPv4Addr = tcpip.Address("\x01\x02\x03\x04")
	const nicIPv6Addr = tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")

	exampleIPv4Addr1Bytes := [4]byte{192, 168, 0, 1}
	exampleIPv4Addr2Bytes := [4]byte{192, 168, 0, 2}
	fooExampleIPv4Addr1Bytes := [4]byte{192, 168, 0, 3}
	fooExampleIPv4Addr2Bytes := [4]byte{192, 168, 0, 4}
	exampleIPv6Addr1Bytes := [16]byte{192, 168, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
	exampleIPv6Addr2Bytes := [16]byte{192, 168, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
	fooExampleIPv6Addr1Bytes := [16]byte{192, 168, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3}
	fooExampleIPv6Addr2Bytes := [16]byte{192, 168, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4}
	exampleIPv4Addr1 := tcpip.Address(exampleIPv4Addr1Bytes[:])
	exampleIPv4Addr2 := tcpip.Address(exampleIPv4Addr2Bytes[:])
	fooExampleIPv4Addr1 := tcpip.Address(fooExampleIPv4Addr1Bytes[:])
	fooExampleIPv4Addr2 := tcpip.Address(fooExampleIPv4Addr2Bytes[:])
	exampleIPv6Addr1 := tcpip.Address(exampleIPv6Addr1Bytes[:])
	exampleIPv6Addr2 := tcpip.Address(exampleIPv6Addr2Bytes[:])
	fooExampleIPv6Addr1 := tcpip.Address(fooExampleIPv6Addr1Bytes[:])
	fooExampleIPv6Addr2 := tcpip.Address(fooExampleIPv6Addr2Bytes[:])
	fakeIPv4AddrBytes := [4]byte{1, 2, 3, 4}

	// A simple resolver that returns 1.2.3.4 for all A record questions.
	fakeResolver := func(question dnsmessage.Question) (dnsmessage.Name, []dnsmessage.Resource, dnsmessage.Message, error) {
		r := dnsmessage.Message{
			Header: dnsmessage.Header{
				ID:       0,
				Response: true,
			},
			Questions: []dnsmessage.Question{question},
		}
		if question.Type == dnsmessage.TypeA {
			r.Answers = []dnsmessage.Resource{
				{
					Header: dnsmessage.ResourceHeader{
						Name:   question.Name,
						Type:   dnsmessage.TypeA,
						Class:  dnsmessage.ClassINET,
						Length: 4,
					},
					Body: &dnsmessage.AResource{
						A: fakeIPv4AddrBytes,
					},
				},
			}
		}
		return question.Name, r.Answers, r, nil
	}
	fakeResolverResponse := []tcpip.Address{tcpip.Address(fakeIPv4AddrBytes[:])}

	// We need a Stack because the default resolver tries to find a route to the
	// servers to make sure a route exists. No packets are sent.
	s := stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocol{
			arp.NewProtocol(),
			ipv4.NewProtocol(),
			ipv6.NewProtocol(),
		},
		TransportProtocols: []stack.TransportProtocol{
			udp.NewProtocol(),
		},
		HandleLocal: true,
	})
	if err := s.CreateNIC(nicID, loopback.New()); err != nil {
		t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
	}
	if err := s.AddAddress(nicID, ipv4.ProtocolNumber, nicIPv4Addr); err != nil {
		t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, nicIPv4Addr, err)
	}
	if err := s.AddAddress(nicID, ipv6.ProtocolNumber, nicIPv6Addr); err != nil {
		t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, nicIPv6Addr, err)
	}

	c := NewClient(s)

	// Add some entries to the cache that will be returned when testing the
	// default resolver.
	//
	// We need to add both A and AAAA records to the cache for each domain name
	// to make sure we do not try to query a DNS server through s.
	c.cache.insertAll([]dnsmessage.Resource{
		makeTypeAResource(example, 5, exampleIPv4Addr1Bytes),
		makeTypeAResource(example, 5, exampleIPv4Addr2Bytes),
		makeTypeAResource(fooExample, 5, fooExampleIPv4Addr1Bytes),
		makeTypeAResource(fooExample, 5, fooExampleIPv4Addr2Bytes),
		makeTypeAAAAResource(example, 5, exampleIPv6Addr1Bytes),
		makeTypeAAAAResource(example, 5, exampleIPv6Addr2Bytes),
		makeTypeAAAAResource(fooExample, 5, fooExampleIPv6Addr1Bytes),
		makeTypeAAAAResource(fooExample, 5, fooExampleIPv6Addr2Bytes),
	})

	// checkAddrs reports an error for every address in want that is missing
	// from addrs (the result of looking up name), and another if addrs does
	// not have exactly len(want) entries. name is only used for messages, so
	// it is taken as interface{} to accept whatever type example and
	// fooExample are declared with.
	checkAddrs := func(name interface{}, addrs []tcpip.Address, want ...tcpip.Address) {
		t.Helper()
		for _, addr := range want {
			if !containsAddress(addrs, addr) {
				t.Errorf("expected %s to be in the list of addresses for %s; got = %s", addr, name, addrs)
			}
		}
		if got, wantLen := len(addrs), len(want); got != wantLen {
			t.Errorf("got len(addrs) = %d, want = %d; addrs = %s", got, wantLen, addrs)
		}
	}

	// We check the default resolver by making sure the entries we populated the
	// cache with are returned.
	checkDefaultResolver := func() {
		t.Helper()

		addrs, err := c.LookupIP(example)
		if err != nil {
			t.Fatalf("c.LookupIP(%q): %s", example, err)
		}
		checkAddrs(example, addrs, exampleIPv4Addr1, exampleIPv4Addr2, exampleIPv6Addr1, exampleIPv6Addr2)

		addrs, err = c.LookupIP(fooExample)
		if err != nil {
			t.Fatalf("c.LookupIP(%q): %s", fooExample, err)
		}
		checkAddrs(fooExample, addrs, fooExampleIPv4Addr1, fooExampleIPv4Addr2, fooExampleIPv6Addr1, fooExampleIPv6Addr2)
	}

	// c should be initialized with the default resolver.
	checkDefaultResolver()

	// Update c to use fakeResolver as its resolver; every name should now
	// resolve to the single fake A record.
	c.SetResolver(fakeResolver)
	addrs, err := c.LookupIP(example)
	if err != nil {
		t.Fatalf("c.LookupIP(%q): %s", example, err)
	}
	if diff := cmp.Diff(fakeResolverResponse, addrs); diff != "" {
		t.Errorf("domain name addresses mismatch (-want +got):\n%s", diff)
	}
	addrs, err = c.LookupIP(fooExample)
	if err != nil {
		t.Fatalf("c.LookupIP(%q): %s", fooExample, err)
	}
	if diff := cmp.Diff(fakeResolverResponse, addrs); diff != "" {
		t.Errorf("domain name addresses mismatch (-want +got):\n%s", diff)
	}

	// A nil Resolver should update c's resolver to the default resolver.
	c.SetResolver(nil)
	checkDefaultResolver()
}