| // Copyright 2009 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // DNS client: see RFC 1035. |
| // Has to be linked into package net for Dial. |
| |
| // TODO(rsc): |
| // Could potentially handle many outstanding lookups faster. |
| // Could have a small cache. |
| // Random UDP source port (net.Dial should do that for us). |
| // Random request IDs. |
| // TODO(mpcomplete): |
| // Cleanup |
| // Decide whether we need DNSSEC, EDNS0, reverse DNS or other query types |
| // We don't support ipv6 zones. Do we need to? |
| |
| package dns |
| |
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "math/rand" |
| "strings" |
| "sync" |
| "time" |
| |
| "github.com/google/netstack/dns/dnsmessage" |
| "github.com/google/netstack/tcpip" |
| "github.com/google/netstack/tcpip/network/ipv4" |
| "github.com/google/netstack/tcpip/transport/tcp" |
| "github.com/google/netstack/tcpip/transport/udp" |
| "github.com/google/netstack/waiter" |
| ) |
| |
| // TODO(mpcomplete): Use FIDL to fetch the DNS config from the parent process. |
| var tmpDNSConfig = dnsConfig{ |
| servers: []tcpip.FullAddress{{Addr: tcpip.Address([]byte{8, 8, 8, 8}), Port: 53}}, |
| search: []string{}, |
| ndots: 100, |
| timeout: 3 * time.Second, |
| attempts: 3, |
| rotate: true, |
| unknownOpt: false, |
| lookup: []string{}, |
| err: nil, |
| mtime: time.Now(), |
| } |
| |
// Client is a DNS client.
//
// Each query opens its own UDP or TCP endpoint on stack, and every server
// address it connects to carries nicid as its NIC (see tryOneName).
type Client struct {
	stack tcpip.Stack // creates the per-query transport endpoints
	nicid tcpip.NICID // NIC placed in each server FullAddress
}
| |
// A Resolver answers DNS Questions.
//
// It returns the name the answer records were found under (after any CNAME
// redirects, as computed by answer for the network resolver), the matching
// resource records, the raw response message, and an error. LookupIP uses
// only the records and the error.
type Resolver func(c *Client, question dnsmessage.Question) (cname string, rrs []dnsmessage.Resource, msg *dnsmessage.Message, err error)
| |
// Error represents an error while issuing a DNS query for a hostname.
//
// Its Error method renders it as "lookup <Name>[ on <Server>]: <Err>".
type Error struct {
	Err       string             // a general error string
	Name      string             // the hostname being queried
	Server    *tcpip.FullAddress // optional DNS server; included in the message when non-nil
	IsTimeout bool               // true if the operation timed out
}
| |
| func (e *Error) Error() string { |
| if e.Server != nil { |
| return fmt.Sprintf("lookup %s on %v: %s", e.Name, e.Server, e.Err) |
| } |
| return fmt.Sprintf("lookup %s: %s", e.Name, e.Err) |
| } |
| |
| // NewClient creates a DHCP client. |
| func NewClient(s tcpip.Stack, nicid tcpip.NICID) *Client { |
| return &Client{ |
| stack: s, |
| nicid: nicid, |
| } |
| } |
| |
| // roundTrip writes the query to and reads the response from the Endpoint. |
| // The message format is slightly different depending on the transport protocol |
| // (for TCP, a 2 byte message length is prepended). See RFC 1035. |
| func roundTrip(ctx context.Context, transport tcpip.TransportProtocolNumber, ep tcpip.Endpoint, wq *waiter.Queue, query *dnsmessage.Message) (response *dnsmessage.Message, err error) { |
| b, err := query.Pack() |
| if err != nil { |
| return nil, err |
| } |
| if transport == tcp.ProtocolNumber { |
| l := len(b) |
| b = append([]byte{byte(l >> 8), byte(l)}, b...) |
| } |
| |
| // Write to endpoint. |
| for len(b) > 0 { |
| n, err := ep.Write(b, nil) |
| if err != nil { |
| return nil, err |
| } |
| |
| b = b[n:] |
| } |
| |
| // Read from endpoint. |
| b = []byte{} |
| waitEntry, notifyCh := waiter.NewChannelEntry(nil) |
| wq.EventRegister(&waitEntry, waiter.EventIn) |
| defer wq.EventUnregister(&waitEntry) |
| for { |
| v, err := ep.Read(nil) |
| if err != nil { |
| if err == tcpip.ErrClosedForReceive { |
| break |
| } |
| |
| if err == tcpip.ErrWouldBlock { |
| select { |
| case <-notifyCh: |
| continue |
| case <-ctx.Done(): |
| return nil, tcpip.ErrTimeout |
| } |
| } |
| |
| return nil, err |
| } |
| |
| b = append(b, []byte(v)...) |
| |
| // Get the contents of the response. |
| var bcontents []byte |
| switch transport { |
| case tcp.ProtocolNumber: |
| if len(b) > 2 { |
| l := int(b[0])<<8 | int(b[1]) |
| bcontents = b[2:(l + 2)] |
| } else { |
| continue |
| } |
| case udp.ProtocolNumber: |
| bcontents = b |
| } |
| |
| response = &dnsmessage.Message{} |
| if err := response.Unpack(bcontents); err != nil { |
| // Ignore invalid responses as they may be malicious |
| // forgery attempts. Instead continue waiting until |
| // timeout. See golang.org/issue/13281. |
| continue |
| } |
| break |
| } |
| |
| return response, nil |
| } |
| |
| func (c *Client) connect(ctx context.Context, transport tcpip.TransportProtocolNumber, server tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, error) { |
| var wq waiter.Queue |
| ep, err := c.stack.NewEndpoint(transport, ipv4.ProtocolNumber, &wq) |
| if err != nil { |
| return nil, nil, err |
| } |
| |
| // Issue connect request and wait for it to complete. |
| waitEntry, notifyCh := waiter.NewChannelEntry(nil) |
| wq.EventRegister(&waitEntry, waiter.EventOut) |
| err = ep.Connect(server) |
| defer wq.EventUnregister(&waitEntry) |
| if err == tcpip.ErrConnectStarted { |
| select { |
| case <-notifyCh: |
| err = ep.GetSockOpt(tcpip.ErrorOption{}) |
| case <-ctx.Done(): |
| return nil, nil, tcpip.ErrTimeout |
| } |
| } |
| |
| if err != nil { |
| return nil, nil, err |
| } |
| |
| return ep, &wq, nil |
| } |
| |
| // exchange sends a query on the connection and hopes for a response. |
| func (c *Client) exchange(server tcpip.FullAddress, name string, qtype dnsmessage.Type, timeout time.Duration) (response *dnsmessage.Message, err error) { |
| query := dnsmessage.Message{ |
| Header: dnsmessage.Header{ |
| RecursionDesired: true, |
| }, |
| Questions: []dnsmessage.Question{ |
| {name, qtype, dnsmessage.ClassINET}, |
| }, |
| } |
| |
| protos := []tcpip.TransportProtocolNumber{udp.ProtocolNumber, tcp.ProtocolNumber} |
| for _, proto := range protos { |
| ctx, cancel := context.WithTimeout(context.Background(), timeout) |
| |
| ep, wq, err := c.connect(ctx, proto, server) |
| if err != nil { |
| cancel() |
| return nil, err |
| } |
| |
| query.ID = uint16(rand.Int()) ^ uint16(time.Now().UnixNano()) |
| response, err = roundTrip(ctx, proto, ep, wq, &query) |
| cancel() |
| |
| if err != nil { |
| return nil, err |
| } |
| if response.Truncated { // see RFC 5966 |
| continue |
| } |
| return response, nil |
| } |
| return nil, errors.New("no answer from the DNS server") |
| } |
| |
| // Do a lookup for a single name, which must be rooted |
| // (otherwise answer will not find the answers). |
| func (c *Client) tryOneName(cfg *dnsConfig, name string, qtype dnsmessage.Type) (string, []dnsmessage.Resource, *dnsmessage.Message, error) { |
| if len(cfg.servers) == 0 { |
| return "", nil, nil, &Error{Err: "no DNS servers", Name: name} |
| } |
| |
| var lastErr error |
| for i := 0; i < cfg.attempts; i++ { |
| for _, server := range cfg.servers { |
| server := tcpip.FullAddress{ |
| NIC: c.nicid, |
| Addr: server.Addr, |
| Port: server.Port, |
| } |
| msg, err := c.exchange(server, name, qtype, cfg.timeout) |
| if err != nil { |
| lastErr = &Error{ |
| Err: err.Error(), |
| Name: name, |
| Server: &server, |
| } |
| continue |
| } |
| // libresolv continues to the next server when it receives |
| // an invalid referral response. See golang.org/issue/15434. |
| if msg.RCode == dnsmessage.RCodeSuccess && !msg.Authoritative && !msg.RecursionAvailable && len(msg.Answers) == 0 && len(msg.Additionals) == 0 { |
| lastErr = &Error{Err: "lame referral", Name: name, Server: &server} |
| continue |
| } |
| cname, rrs, err := answer(name, server, msg, qtype) |
| // If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError, |
| // it means the response in msg was not useful and trying another |
| // server probably won't help. Return now in those cases. |
| // TODO: indicate this in a more obvious way, such as a field on Error? |
| if err == nil || msg.RCode == dnsmessage.RCodeSuccess || msg.RCode == dnsmessage.RCodeNameError { |
| return cname, rrs, msg, err |
| } |
| lastErr = err |
| } |
| } |
| return "", nil, nil, lastErr |
| } |
| |
| // addrRecordList converts and returns a list of IP addresses from DNS |
| // address records (both A and AAAA). Other record types are ignored. |
| func addrRecordList(rrs []dnsmessage.Resource) []tcpip.Address { |
| addrs := make([]tcpip.Address, 0, 4) |
| for _, rr := range rrs { |
| switch rr := rr.(type) { |
| case *dnsmessage.AResource: |
| addrs = append(addrs, tcpip.Address(rr.A[:])) |
| case *dnsmessage.AAAAResource: |
| addrs = append(addrs, tcpip.Address(rr.AAAA[:])) |
| } |
| } |
| return addrs |
| } |
| |
// A resolverConfig represents a DNS stub resolver configuration together with
// the bookkeeping used to refresh it (see tryUpdate).
type resolverConfig struct {
	initOnce sync.Once // guards init of resolverConfig

	// NOTE(review): stk is not read by any code in this file.
	stk tcpip.Stack

	// ch is a single-slot semaphore (tryAcquireSema/releaseSema) that lets
	// only one lookup at a time refresh the configuration.
	ch          chan struct{} // guards lastChecked
	lastChecked time.Time     // last time the configuration was refreshed

	mu        sync.RWMutex // protects dnsConfig
	dnsConfig *dnsConfig   // configuration (built by readConfig) used in lookups
}
| |
// dnsConfig is the resolver configuration consulted by lookups. Its fields
// mirror the resolv.conf options of the Go standard-library resolver this
// file derives from. NOTE(review): rotate, unknownOpt, lookup, err and mtime
// are not read by any code in this file.
type dnsConfig struct {
	servers    []tcpip.FullAddress // server addresses (host and port) to use
	search     []string            // rooted suffixes to append to local name
	ndots      int                 // number of dots in name to trigger absolute lookup
	timeout    time.Duration       // timeout for each transport attempt of a single exchange; not a total across retries
	attempts   int                 // number of passes tryOneName makes over servers
	rotate     bool                // round robin among servers
	unknownOpt bool                // anything unknown was encountered
	lookup     []string            // OpenBSD top-level database "lookup" order
	err        error               // any error that occurred while loading the configuration
	mtime      time.Time           // modification time of the configuration source
	resolver   Resolver            // a handler which answers DNS Questions
}
| |
// resolvConf is the package-wide resolver configuration that LookupIP
// refreshes via tryUpdate and reads under resolvConf.mu.
var resolvConf resolverConfig
| |
| func newNetworkResolver(config *dnsConfig) Resolver { |
| return func(c *Client, question dnsmessage.Question) (string, []dnsmessage.Resource, *dnsmessage.Message, error) { |
| return c.tryOneName(config, question.Name, question.Type) |
| } |
| } |
| |
| func readConfig() *dnsConfig { |
| config := tmpDNSConfig |
| config.resolver = newCachedResolver(newNetworkResolver(&config)) |
| return &config |
| } |
| |
| // init initializes conf and is only called via conf.initOnce. |
| func (conf *resolverConfig) init() { |
| // Set dnsConfig and lastChecked so we don't parse |
| // resolv.conf twice the first time. |
| if conf.dnsConfig == nil { |
| conf.dnsConfig = readConfig() |
| } |
| conf.lastChecked = time.Now() |
| |
| // Prepare ch so that only one update of resolverConfig may |
| // run at once. |
| conf.ch = make(chan struct{}, 1) |
| } |
| |
| // tryUpdate tries to update conf with the named resolv.conf file. |
| // The name variable only exists for testing. It is otherwise always |
| // "/etc/resolv.conf". |
| func (conf *resolverConfig) tryUpdate() { |
| conf.initOnce.Do(conf.init) |
| |
| // Ensure only one update at a time checks resolv.conf. |
| if !conf.tryAcquireSema() { |
| return |
| } |
| defer conf.releaseSema() |
| |
| now := time.Now() |
| if conf.lastChecked.After(now.Add(-5 * time.Second)) { |
| return |
| } |
| conf.lastChecked = now |
| |
| dnsConf := readConfig() |
| conf.mu.Lock() |
| conf.dnsConfig = dnsConf |
| conf.mu.Unlock() |
| } |
| |
| func (conf *resolverConfig) tryAcquireSema() bool { |
| select { |
| case conf.ch <- struct{}{}: |
| return true |
| default: |
| return false |
| } |
| } |
| |
| func (conf *resolverConfig) releaseSema() { |
| <-conf.ch |
| } |
| |
// avoidDNS reports whether name must not be resolved through DNS. That is the
// case for the empty name and for .onion names, with or without a single
// trailing dot (RFC 7686, golang.org/issue/13705). .local names (RFC 6762) are
// deliberately not covered; see golang.org/issue/16739.
func avoidDNS(name string) bool {
	if len(name) == 0 {
		return true
	}
	unrooted := strings.TrimSuffix(name, ".")
	return strings.HasSuffix(unrooted, ".onion")
}
| |
| // nameList returns a list of names for sequential DNS queries. |
| func (conf *dnsConfig) nameList(name string) []string { |
| if avoidDNS(name) { |
| return nil |
| } |
| |
| // If name is rooted (trailing dot), try only that name. |
| rooted := len(name) > 0 && name[len(name)-1] == '.' |
| if rooted { |
| return []string{name} |
| } |
| |
| // hasNdots := count(name, '.') >= conf.ndots |
| hasNdots := false |
| name += "." |
| |
| // Build list of search choices. |
| names := make([]string, 0, 1+len(conf.search)) |
| // If name has enough dots, try unsuffixed first. |
| if hasNdots { |
| names = append(names, name) |
| } |
| // Try suffixes. |
| for _, suffix := range conf.search { |
| names = append(names, name+suffix) |
| } |
| // Try unsuffixed, if not tried first above. |
| if !hasNdots { |
| names = append(names, name) |
| } |
| return names |
| } |
| |
| // LookupIP returns a list of IP addresses that are registered for the give domain name. |
| func (c *Client) LookupIP(name string) (addrs []tcpip.Address, err error) { |
| if !isDomainName(name) { |
| return nil, &Error{Err: "invalid domain name", Name: name} |
| } |
| resolvConf.tryUpdate() |
| resolvConf.mu.RLock() |
| conf := resolvConf.dnsConfig |
| resolvConf.mu.RUnlock() |
| type racer struct { |
| fqdn string |
| rrs []dnsmessage.Resource |
| error |
| } |
| lane := make(chan racer, 1) |
| qtypes := [...]dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA} |
| var lastErr error |
| for _, fqdn := range conf.nameList(name) { |
| for _, qtype := range qtypes { |
| go func(qtype dnsmessage.Type) { |
| _, rrs, _, err := conf.resolver(c, dnsmessage.Question{Name: fqdn, Type: qtype, Class: dnsmessage.ClassINET}) |
| lane <- racer{fqdn, rrs, err} |
| }(qtype) |
| } |
| for range qtypes { |
| racer := <-lane |
| if racer.error != nil { |
| // Prefer error for original name. |
| if lastErr == nil || racer.fqdn == name+"." { |
| lastErr = racer.error |
| } |
| continue |
| } |
| addrs = append(addrs, addrRecordList(racer.rrs)...) |
| } |
| if len(addrs) > 0 { |
| break |
| } |
| } |
| if lastErr, ok := lastErr.(*Error); ok { |
| // Show original name passed to lookup, not suffixed one. |
| // In general we might have tried many suffixes; showing |
| // just one is misleading. See also golang.org/issue/6324. |
| lastErr.Name = name |
| } |
| sortByRFC6724(c, addrs) |
| if len(addrs) == 0 && lastErr != nil { |
| return nil, lastErr |
| } |
| return addrs, nil |
| } |
| |
// noSuchHost is the Error.Err text used by answer when a server reports a
// name error or its response holds no records matching the question.
const noSuchHost = "no such host"
| |
| // Answer extracts the appropriate answer for a DNS lookup |
| // for (name, qtype) from the response message msg, which |
| // is assumed to have come from server. |
| // It is exported mainly for use by registered helpers. |
| func answer(name string, server tcpip.FullAddress, msg *dnsmessage.Message, qtype dnsmessage.Type) (cname string, addrs []dnsmessage.Resource, err error) { |
| addrs = make([]dnsmessage.Resource, 0, len(msg.Answers)) |
| |
| if msg.RCode == dnsmessage.RCodeNameError { |
| return "", nil, &Error{Err: noSuchHost, Name: name, Server: &server} |
| } |
| if msg.RCode != dnsmessage.RCodeSuccess { |
| // None of the error codes make sense |
| // for the query we sent. If we didn't get |
| // a name error and we didn't get success, |
| // the server is behaving incorrectly. |
| return "", nil, &Error{Err: "server misbehaving", Name: name, Server: &server} |
| } |
| |
| // Look for the name. |
| // Presotto says it's okay to assume that servers listed in |
| // /etc/resolv.conf are recursive resolvers. |
| // We asked for recursion, so it should have included |
| // all the answers we need in this one packet. |
| Cname: |
| for cnameloop := 0; cnameloop < 10; cnameloop++ { |
| addrs = addrs[0:0] |
| for _, rr := range msg.Answers { |
| h := rr.Header() |
| if h.Class == dnsmessage.ClassINET && equalASCIILabel(h.Name, name) { |
| switch h.Type { |
| case qtype: |
| addrs = append(addrs, rr) |
| case dnsmessage.TypeCNAME: |
| // redirect to cname |
| name = rr.(*dnsmessage.CNAMEResource).CNAME |
| continue Cname |
| } |
| } |
| } |
| if len(addrs) == 0 { |
| return "", nil, &Error{Err: noSuchHost, Name: name, Server: &server} |
| } |
| return name, addrs, nil |
| } |
| |
| return "", nil, &Error{Err: "too many redirects", Name: name, Server: &server} |
| } |
| |
// equalASCIILabel reports whether x and y are equal when ASCII letters are
// compared case-insensitively. Every other byte, including the bytes of
// non-ASCII UTF-8 sequences, must match exactly. Unlike strings.EqualFold, no
// Unicode case folding is applied.
func equalASCIILabel(x, y string) bool {
	if len(x) != len(y) {
		return false
	}
	toLower := func(ch byte) byte {
		if ch >= 'A' && ch <= 'Z' {
			return ch + ('a' - 'A')
		}
		return ch
	}
	for i, xb := range []byte(x) {
		if toLower(xb) != toLower(y[i]) {
			return false
		}
	}
	return true
}
| |
// isDomainName reports whether s is a syntactically valid domain name in
// presentation format (see RFC 1035 and RFC 3696). The rules are:
//   - at most 253 bytes, or 254 when s ends with the root dot, so that the
//     rooted name fits the 255-octet wire limit (RFC 1035 section 3.1);
//   - every label is 1 to 63 bytes of ASCII letters, digits, '-' or '_';
//   - no label starts or ends with '-';
//   - at least one letter or '_' appears (all-numeric names are rejected).
func isDomainName(s string) bool {
	// See RFC 1035, RFC 3696.
	// Every label but the first is preceded by a dot, and the trailing root
	// dot is optional. Encoding the name in wire format adds a length octet
	// for the first label and a zero octet for the root, so the effective
	// maximum is 253 bytes, or 254 if the last byte is a dot.
	l := len(s)
	if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
		return false
	}

	last := byte('.')
	ok := false // Ok once we've seen a letter.
	partlen := 0
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		default:
			return false
		case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
			ok = true
			partlen++
		case '0' <= c && c <= '9':
			// fine
			partlen++
		case c == '-':
			// Byte before dash cannot be dot.
			if last == '.' {
				return false
			}
			partlen++
		case c == '.':
			// Byte before dot cannot be dot, dash.
			if last == '.' || last == '-' {
				return false
			}
			if partlen > 63 || partlen == 0 {
				return false
			}
			partlen = 0
		}
		last = c
	}
	if last == '-' || partlen > 63 {
		return false
	}

	return ok
}