Merge https://github.com/google/netstack
Change-Id: I8f075dc77fe7de97b5c1726171f63a97f47671ca
diff --git a/dhcp/client.go b/dhcp/client.go
index 85778f5..c40c3e0 100644
--- a/dhcp/client.go
+++ b/dhcp/client.go
@@ -9,7 +9,6 @@
"context"
"crypto/rand"
"fmt"
- "log"
"sync"
"time"
@@ -23,9 +22,10 @@
// Client is a DHCP client.
type Client struct {
- stack *stack.Stack
- nicid tcpip.NICID
- linkAddr tcpip.LinkAddress
+ stack *stack.Stack
+ nicid tcpip.NICID
+ linkAddr tcpip.LinkAddress
+ acquiredFunc func(old, new tcpip.Address, cfg Config)
mu sync.Mutex
addr tcpip.Address
@@ -37,29 +37,57 @@
// NewClient creates a DHCP client.
//
// TODO(crawshaw): add s.LinkAddr(nicid) to *stack.Stack.
+//
+// acquiredFunc may be nil. When non-nil, Request calls it at the end of
+// any session that got as far as adding the offered address to the stack,
+// whether or not the server then acknowledged it. old is the client's
+// previous address and new is the acquired address. On failure, new is
+// empty and cfg holds only the failure in cfg.Error.
-func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress) *Client {
+func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress, acquiredFunc func(old, new tcpip.Address, cfg Config)) *Client {
return &Client{
- stack: s,
- nicid: nicid,
- linkAddr: linkAddr,
+ stack: s,
+ nicid: nicid,
+ linkAddr: linkAddr,
+ acquiredFunc: acquiredFunc,
}
}
-// Start starts the DHCP client.
+// Run starts the DHCP client in a new goroutine and returns immediately.
// It will periodically search for an IP address using the Request method.
+// After each acknowledged lease it waits LeaseLength and then asks to renew
+// the same address. When ctx is done the goroutine exits and removes any
+// address the client still holds from the stack.
-func (c *Client) Start() {
- go func() {
- for {
- log.Print("DHCP request")
- ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- err := c.Request(ctx, "")
- cancel()
- if err == nil {
- break
+func (c *Client) Run(ctx context.Context) {
+ go c.run(ctx)
+}
+
+// run is the body of the goroutine started by Run.
+func (c *Client) run(ctx context.Context) {
+ defer func() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.addr != "" {
+ c.stack.RemoveAddress(c.nicid, c.addr)
+ }
+ }()
+
+ var renewAddr tcpip.Address
+ for {
+ reqCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
+ cfg, err := c.Request(reqCtx, renewAddr)
+ cancel()
+
+ // Next time, ask for whatever address the client holds now. A failed
+ // session that got past the offer stage clears c.addr, so the next
+ // attempt starts a fresh discovery instead of re-requesting a
+ // refused address.
+ c.mu.Lock()
+ renewAddr = c.addr
+ c.mu.Unlock()
+
+ if err != nil {
+ select {
+ case <-time.After(1 * time.Second):
+ // loop and try again
+ case <-ctx.Done():
+ return
}
+ continue
}
- log.Printf("DHCP acquired IP %s for %s", c.Address(), c.Config().LeaseLength)
- }()
+
+ if cfg.LeaseLength <= 0 {
+ // The server sent no lease time (cfg.decode leaves it zero), so
+ // there is nothing to renew. A zero-length timer here would make
+ // the loop re-request without any pause. Keep the address until
+ // ctx is done.
+ <-ctx.Done()
+ return
+ }
+
+ timer := time.NewTimer(cfg.LeaseLength)
+ select {
+ case <-ctx.Done():
+ timer.Stop()
+ return
+ case <-timer.C:
+ // loop and make a renewal request
+ }
+ }
}
// Address reports the IP address acquired by the DHCP client.
@@ -76,56 +104,55 @@
return c.cfg
}
-// Shutdown relinquishes any lease and ends any outstanding renewal timers.
-func (c *Client) Shutdown() {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.addr != "" {
- c.stack.RemoveAddress(c.nicid, c.addr)
- }
- if c.cancelRenew != nil {
- c.cancelRenew()
- }
-}
-
// Request executes a DHCP request session.
//
// On success, it adds a new address to this client's TCPIP stack.
// If the server sets a lease limit a timer is set to automatically
// renew it.
-func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error {
+func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg Config, reterr error) {
+ if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil && err != tcpip.ErrDuplicateAddress {
+ return Config{}, fmt.Errorf("dhcp: %v", err)
+ }
+ if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil && err != tcpip.ErrDuplicateAddress {
+ return Config{}, fmt.Errorf("dhcp: %v", err)
+ }
+ defer c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff")
+ defer c.stack.RemoveAddress(c.nicid, "\x00\x00\x00\x00")
+
var wq waiter.Queue
ep, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- return fmt.Errorf("dhcp: outbound endpoint: %v", err)
+ return Config{}, fmt.Errorf("dhcp: outbound endpoint: %v", err)
}
err = ep.Bind(tcpip.FullAddress{
Addr: "\x00\x00\x00\x00",
Port: clientPort,
+ NIC: c.nicid,
}, nil)
defer ep.Close()
if err != nil {
- return fmt.Errorf("dhcp: connect failed: %v", err)
+ return Config{}, fmt.Errorf("dhcp: connect failed: %v", err)
}
epin, err := c.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- return fmt.Errorf("dhcp: inbound endpoint: %v", err)
+ return Config{}, fmt.Errorf("dhcp: inbound endpoint: %v", err)
}
err = epin.Bind(tcpip.FullAddress{
Addr: "\xff\xff\xff\xff",
Port: clientPort,
+ NIC: c.nicid,
}, nil)
defer epin.Close()
if err != nil {
- return fmt.Errorf("dhcp: connect failed: %v", err)
+ return Config{}, fmt.Errorf("dhcp: connect failed: %v", err)
}
var xid [4]byte
rand.Read(xid[:])
// DHCPDISCOVERY
- options := options{
+ discOpts := options{
{optDHCPMsgType, []byte{byte(dhcpDISCOVER)}},
{optParamReq, []byte{
1, // request subnet mask
@@ -135,22 +162,30 @@
}},
}
if requestedAddr != "" {
- options = append(options, option{optReqIPAddr, []byte(requestedAddr)})
+ discOpts = append(discOpts, option{optReqIPAddr, []byte(requestedAddr)})
}
- h := make(header, headerBaseSize+options.len())
+ var clientID []byte
+ if len(c.linkAddr) == 6 {
+ clientID = make([]byte, 7)
+ clientID[0] = 1 // htype: ARP Ethernet from RFC 1700
+ copy(clientID[1:], c.linkAddr)
+ discOpts = append(discOpts, option{optClientID, clientID})
+ }
+ h := make(header, headerBaseSize+discOpts.len())
h.init()
h.setOp(opRequest)
copy(h.xidbytes(), xid[:])
h.setBroadcast()
copy(h.chaddr(), c.linkAddr)
- h.setOptions(options)
+ h.setOptions(discOpts)
serverAddr := &tcpip.FullAddress{
Addr: "\xff\xff\xff\xff",
Port: serverPort,
+ NIC: c.nicid,
}
if _, err := ep.Write(buffer.View(h), serverAddr); err != nil {
- return fmt.Errorf("dhcp discovery write: %v", err)
+ return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -158,115 +193,142 @@
defer wq.EventUnregister(&we)
// DHCPOFFER
+ var opts options
for {
var addr tcpip.FullAddress
- v, err := epin.Read(&addr)
- if err == tcpip.ErrWouldBlock {
+ v, e := epin.Read(&addr)
+ if e == tcpip.ErrWouldBlock {
select {
case <-ch:
continue
case <-ctx.Done():
- return fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted)
+ return Config{}, fmt.Errorf("reading dhcp offer: %v", tcpip.ErrAborted)
}
}
h = header(v)
- if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) {
- break
+ var valid bool
+ var err error
+ opts, valid, err = loadDHCPReply(h, dhcpOFFER, xid[:])
+ if !valid {
+ if err != nil {
+ // TODO: report malformed server responses
+ }
+ continue
}
- }
- if _, err := h.options(); err != nil {
- return fmt.Errorf("dhcp offer: %v", err)
+ break
}
var ack bool
- var cfg Config
+ if err := cfg.decode(opts); err != nil {
+ return Config{}, fmt.Errorf("dhcp offer: %v", err)
+ }
// DHCPREQUEST
addr := tcpip.Address(h.yiaddr())
if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil {
if err != tcpip.ErrDuplicateAddress {
- return fmt.Errorf("adding address: %v", err)
+ return Config{}, fmt.Errorf("adding address: %v", err)
}
}
defer func() {
- if ack {
- c.mu.Lock()
- c.addr = addr
- c.cfg = cfg
- c.mu.Unlock()
- } else {
+ if !ack || reterr != nil {
c.stack.RemoveAddress(c.nicid, addr)
+ addr = ""
+ cfg = Config{Error: reterr}
+ }
+
+ c.mu.Lock()
+ oldAddr := c.addr
+ c.addr = addr
+ c.cfg = cfg
+ c.mu.Unlock()
+
+ // Clean up broadcast addresses before calling acquiredFunc
+ // so nothing else uses them by mistake.
+ //
+ // (The deferred RemoveAddress calls above silently error.)
+ c.stack.RemoveAddress(c.nicid, "\xff\xff\xff\xff")
+ c.stack.RemoveAddress(c.nicid, "\x00\x00\x00\x00")
+
+ if c.acquiredFunc != nil {
+ c.acquiredFunc(oldAddr, addr, cfg)
+ }
+ if requestedAddr != "" && requestedAddr != addr {
+ c.stack.RemoveAddress(c.nicid, requestedAddr)
}
}()
+ h.init()
h.setOp(opRequest)
for i, b := 0, h.yiaddr(); i < len(b); i++ {
b[i] = 0
}
- h.setOptions([]option{
+ for i, b := 0, h.siaddr(); i < len(b); i++ {
+ b[i] = 0
+ }
+ for i, b := 0, h.giaddr(); i < len(b); i++ {
+ b[i] = 0
+ }
+ reqOpts := []option{
{optDHCPMsgType, []byte{byte(dhcpREQUEST)}},
{optReqIPAddr, []byte(addr)},
- {optDHCPServer, h.siaddr()},
- })
+ {optDHCPServer, []byte(cfg.ServerAddress)},
+ }
+ if len(clientID) != 0 {
+ reqOpts = append(reqOpts, option{optClientID, clientID})
+ }
+ h.setOptions(reqOpts)
if _, err := ep.Write([]byte(h), serverAddr); err != nil {
- return fmt.Errorf("dhcp discovery write: %v", err)
+ return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
}
// DHCPACK
for {
var addr tcpip.FullAddress
- v, err := epin.Read(&addr)
- if err == tcpip.ErrWouldBlock {
+ v, e := epin.Read(&addr)
+ if e == tcpip.ErrWouldBlock {
select {
case <-ch:
continue
case <-ctx.Done():
- return fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted)
+ return Config{}, fmt.Errorf("reading dhcp ack: %v", tcpip.ErrAborted)
}
}
h = header(v)
- if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) {
- break
+ var valid bool
+ var err error
+ opts, valid, err = loadDHCPReply(h, dhcpACK, xid[:])
+ if !valid {
+ if err != nil {
+ // TODO: report malformed server responses
+ }
+ if opts, valid, _ = loadDHCPReply(h, dhcpNAK, xid[:]); valid {
+ if msg := opts.message(); msg != "" {
+ return Config{}, fmt.Errorf("dhcp: NAK %q", msg)
+ }
+ return Config{}, fmt.Errorf("dhcp: NAK with no message")
+ }
+ continue
}
+ break
+ }
+ ack = true
+ return cfg, nil
+}
+
+// loadDHCPReply checks whether h is a valid server reply for transaction
+// xid that carries DHCP message type typ. On a match it returns the reply's
+// options and valid == true. Packets for another transaction, or of another
+// message type, give valid == false and a nil error. A non-nil error means
+// a matching packet's options or message type could not be parsed. Errors
+// name typ because this helper loads offers as well as acks.
+func loadDHCPReply(h header, typ dhcpMsgType, xid []byte) (opts options, valid bool, err error) {
+ if !h.isValid() || h.op() != opReply || !bytes.Equal(h.xidbytes(), xid) {
+ return nil, false, nil
}
opts, e := h.options()
if e != nil {
- return fmt.Errorf("dhcp ack: %v", e)
- }
- if err := cfg.decode(opts); err != nil {
- return fmt.Errorf("dhcp ack bad options: %v", err)
+ return nil, false, fmt.Errorf("dhcp: reading %v: %v", typ, e)
}
msgtype, e := opts.dhcpMsgType()
if e != nil {
- return fmt.Errorf("dhcp ack: %v", e)
+ return nil, false, fmt.Errorf("dhcp: reading %v: %v", typ, e)
}
- ack = msgtype == dhcpACK
- if !ack {
- return fmt.Errorf("dhcp: request not acknowledged")
+ if msgtype != typ {
+ return nil, false, nil
}
- if cfg.LeaseLength != 0 {
- go c.renewAfter(cfg.LeaseLength)
- }
- return nil
-}
-
-func (c *Client) renewAfter(d time.Duration) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.cancelRenew != nil {
- c.cancelRenew()
- }
- ctx, cancel := context.WithCancel(context.Background())
- c.cancelRenew = cancel
- go func() {
- timer := time.NewTimer(d)
- defer timer.Stop()
- select {
- case <-ctx.Done():
- case <-timer.C:
- if err := c.Request(ctx, c.addr); err != nil {
- log.Printf("address renewal failed: %v", err)
- go c.renewAfter(1 * time.Minute)
- }
- }
- }()
+ return opts, true, nil
}
diff --git a/dhcp/dhcp.go b/dhcp/dhcp.go
index 4487c8c..de0f65f 100644
--- a/dhcp/dhcp.go
+++ b/dhcp/dhcp.go
@@ -16,18 +16,19 @@
// Config is standard DHCP configuration.
type Config struct {
- ServerAddress tcpip.Address // address of the server
- SubnetMask tcpip.AddressMask // client address subnet mask
- Gateway tcpip.Address // client default gateway
- DomainNameServer tcpip.Address // client domain name server
- LeaseLength time.Duration // length of the address lease
+ // Error is non-nil in a Config reported by the client for a failed
+ // acquisition. In that case every other field is zero.
+ Error error
+ ServerAddress tcpip.Address // address of the server
+ SubnetMask tcpip.AddressMask // client address subnet mask
+ Gateway tcpip.Address // client default gateway
+ DNS []tcpip.Address // client domain name servers
+ // LeaseLength is zero when the server sent no lease time option.
+ LeaseLength time.Duration // length of the address lease
}
func (cfg *Config) decode(opts []option) error {
*cfg = Config{}
for _, opt := range opts {
b := opt.body
- if l := opt.code.len(); l != -1 && l != len(b) {
+ if !opt.code.lenValid(len(b)) {
return fmt.Errorf("%s bad length: %d", opt.code, len(b))
}
switch opt.code {
@@ -41,7 +42,12 @@
case optDefaultGateway:
cfg.Gateway = tcpip.Address(b)
case optDomainNameServer:
- cfg.DomainNameServer = tcpip.Address(b)
+ for ; len(b) > 0; b = b[4:] {
+ if len(b) < 4 {
+ return fmt.Errorf("DNS bad length: %d", len(b))
+ }
+ cfg.DNS = append(cfg.DNS, tcpip.Address(b[:4]))
+ }
}
}
return nil
@@ -57,8 +63,12 @@
if cfg.Gateway != "" {
opts = append(opts, option{optDefaultGateway, []byte(cfg.Gateway)})
}
- if cfg.DomainNameServer != "" {
- opts = append(opts, option{optDomainNameServer, []byte(cfg.DomainNameServer)})
+ if len(cfg.DNS) > 0 {
+ dns := make([]byte, 0, 4*len(cfg.DNS))
+ for _, addr := range cfg.DNS {
+ dns = append(dns, addr...)
+ }
+ opts = append(opts, option{optDomainNameServer, dns})
}
if l := cfg.LeaseLength / time.Second; l != 0 {
v := make([]byte, 4)
@@ -97,10 +107,10 @@
if o := h.op(); o != opRequest && o != opReply {
return false
}
- if h[1] != 0x01 || h[2] != 0x06 || h[3] != 0x00 {
+ if h[1] != 0x01 || h[2] != 0x06 {
return false
}
- return bytes.Equal(h[236:240], magicCookie) && h[len(h)-1] == 0
+ return bytes.Equal(h[236:240], magicCookie)
}
func (h header) op() op { return op(h[0]) }
@@ -131,7 +141,7 @@
}
optlen := int(h[i+1])
if len(h) < i+2+optlen {
- return nil, fmt.Errorf("option too long")
+ return nil, fmt.Errorf("option %v too long i=%d, optlen=%d", optionCode(h[i]), i, optlen)
}
opts = append(opts, option{
code: optionCode(h[i]),
@@ -172,47 +182,31 @@
optSubnetMask optionCode = 1
optDefaultGateway optionCode = 3
optDomainNameServer optionCode = 6
+ optDomainName optionCode = 15
optReqIPAddr optionCode = 50
optLeaseTime optionCode = 51
optDHCPMsgType optionCode = 53 // dhcpMsgType
optDHCPServer optionCode = 54
optParamReq optionCode = 55
+ optMessage optionCode = 56
+ optClientID optionCode = 61
)
-func (code optionCode) len() int {
+// lenValid reports whether l is an acceptable body length for an option
+// with this code. Fixed-size options must match exactly, DNS server lists
+// must be a multiple of 4 bytes, the textual options need at least one
+// byte, and any other code is accepted.
+func (code optionCode) lenValid(l int) bool {
switch code {
- case optSubnetMask, optDefaultGateway, optDomainNameServer,
+ case optSubnetMask, optDefaultGateway,
optReqIPAddr, optLeaseTime, optDHCPServer:
- return 4
+ return l == 4
case optDHCPMsgType:
- return 1
- case optParamReq:
- return -1 // no fixed length
- default:
- return -1
- }
-}
-
-func (code optionCode) String() string {
- switch code {
- case optSubnetMask:
- return "option(subnet-mask)"
- case optDefaultGateway:
- return "option(default-gateway)"
+ return l == 1
case optDomainNameServer:
- return "option(dns)"
- case optReqIPAddr:
- return "option(request-ip-address)"
- case optLeaseTime:
- return "option(least-time)"
- case optDHCPMsgType:
- return "option(message-type)"
- case optDHCPServer:
- return "option(server)"
+ // NOTE(review): l == 0 also passes here. RFC 2132 asks for at least
+ // one address; confirm that accepting an empty list is intended.
+ return l%4 == 0
+ case optMessage, optDomainName, optClientID:
+ return l >= 1
case optParamReq:
- return "option(parameter-request)"
+ return true // no fixed length
default:
- return fmt.Sprintf("option(%d)", code)
+ return true // unknown option, assume ok
}
}
@@ -222,7 +216,7 @@
for _, opt := range opts {
if opt.code == optDHCPMsgType {
if len(opt.body) != 1 {
- return 0, fmt.Errorf("%s: bad length: %d", optDHCPMsgType, len(opt.body))
+ return 0, fmt.Errorf("%s: wrong length: %d", optDHCPMsgType, len(opt.body))
}
v := opt.body[0]
if v <= 0 || v >= 8 {
@@ -234,6 +228,15 @@
return 0, nil
}
+// message returns the body of the first optMessage option, such as the
+// text a server attaches to a NAK, or "" when opts has none.
+func (opts options) message() string {
+ for i := range opts {
+ if opts[i].code != optMessage {
+ continue
+ }
+ return string(opts[i].body)
+ }
+ return ""
+}
+
func (opts options) len() int {
l := 0
for _, opt := range opts {
diff --git a/dhcp/dhcp_string.go b/dhcp/dhcp_string.go
new file mode 100644
index 0000000..68b23a4
--- /dev/null
+++ b/dhcp/dhcp_string.go
@@ -0,0 +1,101 @@
+package dhcp
+
+import (
+ "bytes"
+ "fmt"
+
+ "github.com/google/netstack/tcpip"
+)
+
+// String formats h for debugging. It prints the message type, the
+// transaction ID, the four address fields, the client hardware address
+// and every option. Headers that fail validation are summarized instead.
+func (h header) String() string {
+ // The invalid-header summary below reads up to byte 240 (the end of the
+ // magic cookie). Describe anything shorter by length alone rather than
+ // panic inside a String method.
+ if len(h) < 240 {
+ return fmt.Sprintf("DHCP invalid, len=%d", len(h))
+ }
+ opts, err := h.options()
+ var msgtype dhcpMsgType
+ if err == nil {
+ msgtype, err = opts.dhcpMsgType()
+ }
+ if !h.isValid() || err != nil {
+ return fmt.Sprintf("DHCP invalid, %v %v h[1:4]=%x cookie=%x len=%d (%v)", h.op(), h.xid(), []byte(h[1:4]), []byte(h[236:240]), len(h), err)
+ }
+ buf := new(bytes.Buffer)
+ fmt.Fprintf(buf, "%v %v len=%d\n", msgtype, h.xid(), len(h))
+ fmt.Fprintf(buf, "\tciaddr:%v yiaddr:%v siaddr:%v giaddr:%v\n",
+ tcpip.Address(h.ciaddr()),
+ tcpip.Address(h.yiaddr()),
+ tcpip.Address(h.siaddr()),
+ tcpip.Address(h.giaddr()))
+ fmt.Fprintf(buf, "\tchaddr:%x", h.chaddr())
+ for _, opt := range opts {
+ fmt.Fprintf(buf, "\n\t%v", opt)
+ }
+ return buf.String()
+}
+
+// String renders opt as "<code>: <hex body>" for debug output.
+func (opt option) String() string {
+ return fmt.Sprintf("%v: %x", opt.code, opt.body)
+}
+
+// String names code for debug output, for example "option(dns)". Codes
+// without a name are shown numerically, for example "option(99)".
+func (code optionCode) String() (s string) {
+ switch code {
+ case optSubnetMask:
+ s = "option(subnet-mask)"
+ case optDefaultGateway:
+ s = "option(default-gateway)"
+ case optDomainNameServer:
+ s = "option(dns)"
+ case optDomainName:
+ s = "option(domain-name)"
+ case optReqIPAddr:
+ s = "option(request-ip-address)"
+ case optLeaseTime:
+ s = "option(lease-time)"
+ case optDHCPMsgType:
+ s = "option(message-type)"
+ case optDHCPServer:
+ s = "option(server)"
+ case optParamReq:
+ s = "option(parameter-request)"
+ case optMessage:
+ s = "option(message)"
+ case optClientID:
+ s = "option(client-id)"
+ default:
+ s = fmt.Sprintf("option(%d)", code)
+ }
+ return s
+}
+
+// String returns a debug name for o: "op(request)", "op(reply)", or
+// "op(UNKNOWN:<n>)" for any other value.
+func (o op) String() string {
+ if o == opRequest {
+ return "op(request)"
+ }
+ if o == opReply {
+ return "op(reply)"
+ }
+ return fmt.Sprintf("op(UNKNOWN:%d)", int(o))
+}
+
+// String returns the conventional name of t, such as "DHCPOFFER".
+// Unrecognized values are shown as "DHCP(<n>)".
+func (t dhcpMsgType) String() (name string) {
+ switch t {
+ case dhcpDISCOVER:
+ name = "DHCPDISCOVER"
+ case dhcpOFFER:
+ name = "DHCPOFFER"
+ case dhcpREQUEST:
+ name = "DHCPREQUEST"
+ case dhcpDECLINE:
+ name = "DHCPDECLINE"
+ case dhcpACK:
+ name = "DHCPACK"
+ case dhcpNAK:
+ name = "DHCPNAK"
+ case dhcpRELEASE:
+ name = "DHCPRELEASE"
+ default:
+ name = fmt.Sprintf("DHCP(%d)", int(t))
+ }
+ return name
+}
+
+// String formats the transaction ID as "xid:" followed by its hex value.
+// It converts to uint32 first so that %x does not recurse into String.
+func (v xid) String() string {
+ n := uint32(v)
+ return fmt.Sprintf("xid:%x", n)
+}
diff --git a/dhcp/dhcp_test.go b/dhcp/dhcp_test.go
index 5dc8f6f..f385e97 100644
--- a/dhcp/dhcp_test.go
+++ b/dhcp/dhcp_test.go
@@ -17,9 +17,13 @@
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/stack"
"github.com/google/netstack/tcpip/transport/udp"
+ "github.com/google/netstack/waiter"
)
-func TestDHCP(t *testing.T) {
+// Fixtures shared by createStack and the tests: the NIC that carries the
+// DHCP traffic, and the server's address on it (192.168.3.1).
+const (
+ nicid = tcpip.NICID(1)
+ serverAddr = tcpip.Address("\xc0\xa8\x03\x01")
+)
+
+func createStack(t *testing.T) *stack.Stack {
const defaultMTU = 65536
id, linkEP := channel.New(256, defaultMTU, "")
if testing.Verbose() {
@@ -38,17 +42,9 @@
s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}).(*stack.Stack)
- const nicid tcpip.NICID = 1
if err := s.CreateNIC(nicid, id); err != nil {
t.Fatal(err)
}
- if err := s.AddAddress(nicid, ipv4.ProtocolNumber, "\x00\x00\x00\x00"); err != nil {
- t.Fatal(err)
- }
- if err := s.AddAddress(nicid, ipv4.ProtocolNumber, "\xff\xff\xff\xff"); err != nil {
- t.Fatal(err)
- }
- const serverAddr = tcpip.Address("\xc0\xa8\x03\x01")
if err := s.AddAddress(nicid, ipv4.ProtocolNumber, serverAddr); err != nil {
t.Fatal(err)
}
@@ -60,14 +56,21 @@
NIC: nicid,
}})
- var clientAddrs = []tcpip.Address{"\xc0\xa8\x03\x02", "\xc0\xa8\x03\x03"}
+ return s
+}
+
+func TestDHCP(t *testing.T) {
+ s := createStack(t)
+ clientAddrs := []tcpip.Address{"\xc0\xa8\x03\x02", "\xc0\xa8\x03\x03"}
serverCfg := Config{
- ServerAddress: serverAddr,
- SubnetMask: "\xff\xff\xff\x00",
- Gateway: "\xc0\xa8\x03\xF0",
- DomainNameServer: "\x08\x08\x08\x08",
- LeaseLength: 24 * time.Hour,
+ ServerAddress: serverAddr,
+ SubnetMask: "\xff\xff\xff\x00",
+ Gateway: "\xc0\xa8\x03\xF0",
+ DNS: []tcpip.Address{
+ "\x08\x08\x08\x08", "\x08\x08\x04\x04",
+ },
+ LeaseLength: 24 * time.Hour,
}
serverCtx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -77,14 +80,14 @@
}
const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
- c0 := NewClient(s, nicid, clientLinkAddr0)
- if err := c0.Request(context.Background(), ""); err != nil {
+ c0 := NewClient(s, nicid, clientLinkAddr0, nil)
+ if _, err := c0.Request(context.Background(), ""); err != nil {
t.Fatal(err)
}
if got, want := c0.Address(), clientAddrs[0]; got != want {
t.Errorf("c.Addr()=%s, want=%s", got, want)
}
- if err := c0.Request(context.Background(), ""); err != nil {
+ if _, err := c0.Request(context.Background(), ""); err != nil {
t.Fatal(err)
}
if got, want := c0.Address(), clientAddrs[0]; got != want {
@@ -92,22 +95,221 @@
}
const clientLinkAddr1 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x53")
- c1 := NewClient(s, nicid, clientLinkAddr1)
- if err := c1.Request(context.Background(), ""); err != nil {
+ c1 := NewClient(s, nicid, clientLinkAddr1, nil)
+ if _, err := c1.Request(context.Background(), ""); err != nil {
t.Fatal(err)
}
if got, want := c1.Address(), clientAddrs[1]; got != want {
t.Errorf("c.Addr()=%s, want=%s", got, want)
}
- if err := c0.Request(context.Background(), ""); err != nil {
+ if _, err := c0.Request(context.Background(), ""); err != nil {
t.Fatal(err)
}
if got, want := c0.Address(), clientAddrs[0]; got != want {
t.Errorf("c.Addr()=%s, want=%s", got, want)
}
- if got, want := c0.Config(), serverCfg; got != want {
+ if got, want := c0.Config(), serverCfg; !equalConfig(got, want) {
t.Errorf("client config:\n\t%#+v\nwant:\n\t%#+v", got, want)
}
}
+
+// equalConfig reports whether c0 and c1 are identical field by field.
+// Config holds a slice (DNS), so it cannot be compared with ==.
+func equalConfig(c0, c1 Config) bool {
+ switch {
+ case c0.Error != c1.Error,
+ c0.ServerAddress != c1.ServerAddress,
+ c0.SubnetMask != c1.SubnetMask,
+ c0.Gateway != c1.Gateway,
+ c0.LeaseLength != c1.LeaseLength,
+ len(c0.DNS) != len(c1.DNS):
+ return false
+ }
+ for i, addr := range c0.DNS {
+ if addr != c1.DNS[i] {
+ return false
+ }
+ }
+ return true
+}
+
+// TestRenew runs a client against a server that hands out 1s leases. It
+// checks that the first acquisition succeeds and that the renewal returns
+// the same address.
+func TestRenew(t *testing.T) {
+ s := createStack(t)
+ clientAddrs := []tcpip.Address{"\xc0\xa8\x03\x02"}
+
+ serverCfg := Config{
+ ServerAddress: serverAddr,
+ SubnetMask: "\xff\xff\xff\x00",
+ Gateway: "\xc0\xa8\x03\xF0",
+ DNS: []tcpip.Address{"\x08\x08\x08\x08"},
+ LeaseLength: 1 * time.Second,
+ }
+ serverCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ _, err := NewServer(serverCtx, s, clientAddrs, serverCfg)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ clientCtx, clientCancel := context.WithCancel(context.Background())
+ // Stops the client goroutine on every exit path, including t.Fatal.
+ defer clientCancel()
+
+ type acquisition struct {
+ oldAddr, newAddr tcpip.Address
+ cfg Config
+ }
+ acqCh := make(chan acquisition)
+ // acquiredFunc runs on the client's goroutine, where t.Fatal must not be
+ // called. It hands each result to the test goroutine instead. Once the
+ // client is cancelled it stops waiting, so a late renewal cannot block
+ // the client forever.
+ acquiredFunc := func(oldAddr, newAddr tcpip.Address, cfg Config) {
+ select {
+ case acqCh <- acquisition{oldAddr, newAddr, cfg}:
+ case <-clientCtx.Done():
+ }
+ }
+
+ const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
+ c := NewClient(s, nicid, clientLinkAddr0, acquiredFunc)
+ c.Run(clientCtx)
+
+ count := 0
+ var curAddr tcpip.Address
+ // next waits for the next acquisition and validates it on the test
+ // goroutine.
+ next := func(timeoutMsg string) tcpip.Address {
+ var a acquisition
+ select {
+ case a = <-acqCh:
+ case <-time.After(5 * time.Second):
+ t.Fatal(timeoutMsg)
+ }
+ if err := a.cfg.Error; err != nil {
+ t.Fatalf("acquisition %d failed: %v", count, err)
+ }
+ if a.oldAddr != curAddr {
+ t.Fatalf("acquisition %d: curAddr=%v, oldAddr=%v", count, curAddr, a.oldAddr)
+ }
+ if a.cfg.LeaseLength != time.Second {
+ t.Fatalf("acquisition %d: lease length: %v, want %v", count, a.cfg.LeaseLength, time.Second)
+ }
+ count++
+ curAddr = a.newAddr
+ return a.newAddr
+ }
+
+ addr := next("timeout acquiring initial address")
+ t.Logf("got first address: %v", addr)
+
+ newAddr := next("timeout waiting for address renewal")
+ t.Logf("got renewal: %v", newAddr)
+ if newAddr != addr {
+ t.Fatalf("renewal address is %v, want %v", newAddr, addr)
+ }
+}
+
+// Regression test for https://fuchsia.atlassian.net/browse/NET-17
+// The reply below ends with the 0xff end option and no trailing zero byte.
+// It must still validate and its options must parse.
+func TestNoNullTerminator(t *testing.T) {
+ const raw = "\x02\x01\x06\x00" +
+ "\xc8\x37\xbe\x73\x00\x00\x80\x00\x00\x00\x00\x00\xc0\xa8\x2b\x92" +
+ "\xc0\xa8\x2b\x01\x00\x00\x00\x00\x00\x0f\x60\x0a\x23\x93\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x63\x82\x53\x63\x35\x01\x02\x36" +
+ "\x04\xc0\xa8\x2b\x01\x33\x04\x00\x00\x0e\x10\x3a\x04\x00\x00\x07" +
+ "\x08\x3b\x04\x00\x00\x0c\x4e\x01\x04\xff\xff\xff\x00\x1c\x04\xc0" +
+ "\xa8\x2b\xff\x03\x04\xc0\xa8\x2b\x01\x06\x04\xc0\xa8\x2b\x01\x2b" +
+ "\x0f\x41\x4e\x44\x52\x4f\x49\x44\x5f\x4d\x45\x54\x45\x52\x45\x44" +
+ "\xff"
+ h := header(raw)
+ if valid := h.isValid(); !valid {
+ t.Error("failed to decode header")
+ }
+ if got, want := h.op(), opReply; got != want {
+ t.Errorf("bad opcode: %v expected: %v", got, want)
+ }
+ if _, err := h.options(); err != nil {
+ t.Errorf("bad options: %v", err)
+ }
+}
+
+// teeConn splits c so that two readers see the same packets. Every Read
+// on the first returned conn is also queued for the second. Up to 8
+// results are buffered; beyond that, reads on the first block until the
+// second catches up. Both halves write straight through to c.
+func teeConn(c conn) (conn, conn) {
+ dup := make(chan connMsg, 8)
+ primary := &dupConn{c: c, dup: dup}
+ secondary := &chConn{c: c, ch: dup}
+ return primary, secondary
+}
+
+// connMsg captures the full result of one conn.Read (payload, sender and
+// error) so that teeConn can replay it to a second reader.
+type connMsg struct {
+ buf buffer.View
+ addr tcpip.FullAddress
+ err error
+}
+
+// dupConn is the primary half of teeConn. It reads from c and forwards a
+// copy of every result, errors included, on dup.
+type dupConn struct {
+ c conn
+ dup chan connMsg
+}
+
+func (d *dupConn) Read() (buffer.View, tcpip.FullAddress, error) {
+ var msg connMsg
+ msg.buf, msg.addr, msg.err = d.c.Read()
+ d.dup <- msg
+ return msg.buf, msg.addr, msg.err
+}
+
+func (d *dupConn) Write(b []byte, addr *tcpip.FullAddress) error {
+ return d.c.Write(b, addr)
+}
+
+// chConn is the secondary half of teeConn. It replays the results
+// captured by the paired dupConn and writes through to c.
+type chConn struct {
+ ch chan connMsg
+ c conn
+}
+
+func (r *chConn) Read() (buffer.View, tcpip.FullAddress, error) {
+ m := <-r.ch
+ return m.buf, m.addr, m.err
+}
+
+func (r *chConn) Write(b []byte, addr *tcpip.FullAddress) error {
+ return r.c.Write(b, addr)
+}
+
+// TestTwoServers runs two DHCP servers that share one endpoint, so both
+// see every client packet. It checks that a client still completes a
+// session with one of them.
+func TestTwoServers(t *testing.T) {
+ s := createStack(t)
+
+ wq := new(waiter.Queue)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("dhcp: server endpoint: %v", err)
+ }
+ // Runs after the deferred cancel below, once the servers have stopped.
+ defer ep.Close()
+ if err = ep.Bind(tcpip.FullAddress{Port: serverPort}, nil); err != nil {
+ t.Fatalf("dhcp: server bind: %v", err)
+ }
+
+ serverCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ c1, c2 := teeConn(newEPConn(serverCtx, wq, ep))
+
+ // newServer returns a plain error, which cannot be assigned to err
+ // (the *tcpip.Error from NewEndpoint), so use fresh variables.
+ if _, err := newServer(serverCtx, c1, []tcpip.Address{"\xc0\xa8\x03\x02"}, Config{
+ ServerAddress: "\xc0\xa8\x03\x01",
+ SubnetMask: "\xff\xff\xff\x00",
+ Gateway: "\xc0\xa8\x03\xF0",
+ DNS: []tcpip.Address{"\x08\x08\x08\x08"},
+ LeaseLength: 30 * time.Minute,
+ }); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := newServer(serverCtx, c2, []tcpip.Address{"\xc0\xa8\x04\x02"}, Config{
+ ServerAddress: "\xc0\xa8\x04\x01",
+ SubnetMask: "\xff\xff\xff\x00",
+ Gateway: "\xc0\xa8\x03\xF0",
+ DNS: []tcpip.Address{"\x08\x08\x08\x08"},
+ LeaseLength: 30 * time.Minute,
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ const clientLinkAddr0 = tcpip.LinkAddress("\x52\x11\x22\x33\x44\x52")
+ c := NewClient(s, nicid, clientLinkAddr0, nil)
+ if _, err := c.Request(context.Background(), ""); err != nil {
+ t.Fatal(err)
+ }
+ // The client takes whichever offer arrives first, so either server's
+ // address is acceptable.
+ switch got := c.Address(); got {
+ case "\xc0\xa8\x03\x02", "\xc0\xa8\x04\x02":
+ default:
+ t.Errorf("c.Address()=%s, want an address from one of the two servers", got)
+ }
+}
diff --git a/dhcp/server.go b/dhcp/server.go
index 58f656f..56a1d33 100644
--- a/dhcp/server.go
+++ b/dhcp/server.go
@@ -7,6 +7,7 @@
import (
"context"
"fmt"
+ "io"
"log"
"sync"
"time"
@@ -21,10 +22,8 @@
// Server is a DHCP server.
type Server struct {
- stack *stack.Stack
+ conn conn
broadcast tcpip.FullAddress
- wq waiter.Queue
- ep tcpip.Endpoint
addrs []tcpip.Address // TODO: use a tcpip.AddressMask or range structure
cfg Config
cfgopts []option // cfg to send to client
@@ -35,11 +34,79 @@
leases map[tcpip.LinkAddress]serverLease
}
+// conn is a blocking read/write network endpoint.
+//
+// Read waits for the next packet and returns its payload and source
+// address. The server's reader goroutine stops at the first Read error.
+// Write sends the given bytes to the given address.
+type conn interface {
+ Read() (buffer.View, tcpip.FullAddress, error)
+ Write([]byte, *tcpip.FullAddress) error
+}
+
+// epConn adapts a tcpip.Endpoint to the blocking conn interface. A read
+// that would block waits for an EventIn notification on inCh, until ctx
+// is done.
+type epConn struct {
+ ctx context.Context
+ wq *waiter.Queue
+ ep tcpip.Endpoint
+ // we is registered on wq for EventIn by newEPConn; inCh is its channel.
+ we waiter.Entry
+ inCh chan struct{}
+}
+
+// newEPConn wraps ep, whose readiness events arrive on wq, as a blocking
+// conn. It registers for EventIn on wq and starts a goroutine that
+// unregisters again once ctx is done.
+func newEPConn(ctx context.Context, wq *waiter.Queue, ep tcpip.Endpoint) *epConn {
+ we, inCh := waiter.NewChannelEntry(nil)
+ c := &epConn{
+ ctx: ctx,
+ wq: wq,
+ ep: ep,
+ we: we,
+ inCh: inCh,
+ }
+ // Register the entry stored in c, not the local copy.
+ wq.EventRegister(&c.we, waiter.EventIn)
+ go func() {
+ <-ctx.Done()
+ wq.EventUnregister(&c.we)
+ }()
+ return c
+}
+
+// Read blocks until a packet can be read from the endpoint and returns its
+// payload and sender with a nil error. It returns io.EOF if ctx is done
+// while waiting, and a wrapped error for any other endpoint failure.
+func (c *epConn) Read() (buffer.View, tcpip.FullAddress, error) {
+ for {
+ var addr tcpip.FullAddress
+ v, err := c.ep.Read(&addr)
+ if err == tcpip.ErrWouldBlock {
+ select {
+ case <-c.inCh:
+ continue
+ case <-c.ctx.Done():
+ return nil, tcpip.FullAddress{}, io.EOF
+ }
+ }
+ // Wrap only real failures. fmt.Errorf on a nil *tcpip.Error still
+ // yields a non-nil error, which would end the server's reader on
+ // its first successful read.
+ if err != nil {
+ return nil, tcpip.FullAddress{}, fmt.Errorf("dhcp: %v", err)
+ }
+ return v, addr, nil
+ }
+}
+
+// Write sends b to addr through the endpoint. It returns nil on success
+// and a wrapped error if the endpoint write fails.
+func (c *epConn) Write(b []byte, addr *tcpip.FullAddress) error {
+ if _, err := c.ep.Write(b, addr); err != nil {
+ return fmt.Errorf("dhcp: %v", err)
+ }
+ return nil
+}
+
// NewServer creates a new DHCP server and begins serving.
// The server continues serving until ctx is done.
+// If setup fails after the UDP endpoint is created, the endpoint is closed
+// again before the error is returned.
func NewServer(ctx context.Context, stack *stack.Stack, addrs []tcpip.Address, cfg Config) (*Server, error) {
+ wq := new(waiter.Queue)
+ ep, err := stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ return nil, fmt.Errorf("dhcp: server endpoint: %v", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: serverPort}, nil); err != nil {
+ ep.Close()
+ return nil, fmt.Errorf("dhcp: server bind: %v", err)
+ }
+ c := newEPConn(ctx, wq, ep)
+ s, serr := newServer(ctx, c, addrs, cfg)
+ if serr != nil {
+ // For example, cfg has no ServerAddress. Don't leave the endpoint
+ // bound to the server port.
+ ep.Close()
+ return nil, serr
+ }
+ return s, nil
+}
+
+func newServer(ctx context.Context, c conn, addrs []tcpip.Address, cfg Config) (*Server, error) {
+ if cfg.ServerAddress == "" {
+ return nil, fmt.Errorf("dhcp: server requires explicit server address")
+ }
s := &Server{
- stack: stack,
+ conn: c,
addrs: addrs,
cfg: cfg,
cfgopts: cfg.encode(),
@@ -52,19 +119,6 @@
leases: make(map[tcpip.LinkAddress]serverLease),
}
- var err *tcpip.Error
- s.ep, err = s.stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &s.wq)
- if err != nil {
- return nil, fmt.Errorf("dhcp: server endpoint: %v", err)
- }
- serverBroadcast := tcpip.FullAddress{
- Addr: "",
- Port: serverPort,
- }
- if err := s.ep.Bind(serverBroadcast, nil); err != nil {
- return nil, fmt.Errorf("dhcp: server bind: %v", err)
- }
-
for i := 0; i < len(s.handlers); i++ {
ch := make(chan header, 8)
s.handlers[i] = ch
@@ -99,20 +153,10 @@
// reader listens for all incoming DHCP packets and fans them out to
// handling goroutines based on XID as session identifiers.
func (s *Server) reader(ctx context.Context) {
- we, ch := waiter.NewChannelEntry(nil)
- s.wq.EventRegister(&we, waiter.EventIn)
- defer s.wq.EventUnregister(&we)
-
for {
- var addr tcpip.FullAddress
- v, err := s.ep.Read(&addr)
- if err == tcpip.ErrWouldBlock {
- select {
- case <-ch:
- continue
- case <-ctx.Done():
- return
- }
+ v, _, err := s.conn.Read()
+ if err != nil {
+ return
}
h := header(v)
@@ -230,16 +274,45 @@
h.setOp(opReply)
copy(h.xidbytes(), hreq.xidbytes())
copy(h.yiaddr(), lease.addr)
- copy(h.siaddr(), s.cfg.ServerAddress)
copy(h.chaddr(), hreq.chaddr())
h.setOptions(opts)
- s.ep.Write(buffer.View(h), &s.broadcast)
+ s.conn.Write([]byte(h), &s.broadcast)
+}
+
+// nack sends a DHCPNAK in reply to hreq. The NAK echoes the request's
+// transaction ID and client hardware address and names this server. Like
+// the server's other replies it is written to s.broadcast, and write
+// errors are ignored.
+func (s *Server) nack(hreq header) {
+ // DHCPNACK
+ var opts options
+ opts = append(opts,
+ option{optDHCPMsgType, []byte{byte(dhcpNAK)}},
+ option{optDHCPServer, []byte(s.cfg.ServerAddress)},
+ )
+ h := make(header, headerBaseSize+opts.len())
+ h.init()
+ h.setOp(opReply)
+ copy(h.xidbytes(), hreq.xidbytes())
+ copy(h.chaddr(), hreq.chaddr())
+ h.setOptions(opts)
+ s.conn.Write([]byte(h), &s.broadcast)
}
func (s *Server) handleRequest(hreq header, opts options) {
linkAddr := tcpip.LinkAddress(hreq.chaddr()[:6])
xid := hreq.xid()
+ reqopts, err := hreq.options()
+ if err != nil {
+ s.nack(hreq)
+ return
+ }
+ var reqcfg Config
+ if err := reqcfg.decode(reqopts); err != nil {
+ s.nack(hreq)
+ return
+ }
+ if reqcfg.ServerAddress != s.cfg.ServerAddress {
+ // This request is for a different DHCP server. Ignore it.
+ return
+ }
+
s.mu.Lock()
lease := s.leases[linkAddr]
switch lease.state {
@@ -267,10 +340,9 @@
h.setOp(opReply)
copy(h.xidbytes(), hreq.xidbytes())
copy(h.yiaddr(), lease.addr)
- copy(h.siaddr(), s.cfg.ServerAddress)
copy(h.chaddr(), hreq.chaddr())
h.setOptions(opts)
- s.ep.Write(buffer.View(h), &s.broadcast)
+ s.conn.Write([]byte(h), &s.broadcast)
}
type leaseState int
diff --git a/dns/addrselect.go b/dns/addrselect.go
new file mode 100644
index 0000000..6373759
--- /dev/null
+++ b/dns/addrselect.go
@@ -0,0 +1,472 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Minimal RFC 6724 address selection.
+
+package dns
+
+import (
+ "context"
+ "sort"
+ "time"
+
+ "github.com/google/netstack/tcpip"
+ "github.com/google/netstack/tcpip/transport/udp"
+)
+
+// sortByRFC6724 orders addrs in place, most preferred destination first, per RFC 6724.
+func sortByRFC6724(c *Client, addrs []tcpip.Address) {
+	if len(addrs) >= 2 {
+		sortByRFC6724withSrcs(addrs, srcAddrs(c, addrs))
+	}
+}
+
+// sortByRFC6724withSrcs stably sorts addrs (and srcs alongside it), where srcs[i]
+// is the source address used to reach addrs[i], or "" if it is unreachable.
+func sortByRFC6724withSrcs(addrs []tcpip.Address, srcs []tcpip.Address) {
+	if len(srcs) != len(addrs) {
+		panic("internal error")
+	}
+	s := &byRFC6724{
+		addrs:    addrs,
+		addrAttr: make([]ipAttr, len(addrs)),
+		srcs:     srcs,
+		srcAttr:  make([]ipAttr, len(srcs)),
+	}
+	for i := range addrs {
+		s.addrAttr[i], s.srcAttr[i] = ipAttrOf(addrs[i]), ipAttrOf(srcs[i])
+	}
+	sort.Stable(s)
+}
+
+// srcAddrs UDP-connects to each address to learn the local address the stack
+// would send from; srcs[i] stays "" when addrs[i] cannot be connected to. No
+// packets are sent (connecting UDP only picks a route), so port 9 is arbitrary.
+func srcAddrs(c *Client, addrs []tcpip.Address) []tcpip.Address {
+	srcs := make([]tcpip.Address, len(addrs))
+	dst := tcpip.FullAddress{Port: 9, NIC: c.nicid}
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*1)
+	defer cancel() // one deadline shared by all the probes below
+	for i, addr := range addrs {
+		dst.Addr = addr
+		ep, _, err := c.connect(ctx, udp.ProtocolNumber, dst)
+		if err == nil {
+			if src, err := ep.GetLocalAddress(); err == nil {
+				srcs[i] = src.Addr
+			}
+			ep.Close()
+		}
+	}
+	return srcs
+}
+
+type ipSubnet struct { // NOTE(review): not referenced anywhere in this file; possibly unused — confirm before removing
+	Addr tcpip.Address
+	Mask tcpip.Address // presumably a byte mask as long as Addr — TODO confirm
+}
+
+type ipAttr struct { // RFC 6724 attributes of one address, as computed by ipAttrOf
+	Scope      scope // from classifyScope
+	Precedence uint8 // from rfc6724policyTable.Classify
+	Label      uint8 // from rfc6724policyTable.Classify
+}
+
+func ipAttrOf(ip tcpip.Address) ipAttr {
+ if ip == "" {
+ return ipAttr{}
+ }
+ match := rfc6724policyTable.Classify(ip)
+ return ipAttr{
+ Scope: classifyScope(ip),
+ Precedence: match.Precedence,
+ Label: match.Label,
+ }
+}
+
+type byRFC6724 struct { // sort.Interface over four parallel slices that are indexed alike
+	addrs    []tcpip.Address // addrs to sort
+	addrAttr []ipAttr        // addrAttr[i] is ipAttrOf(addrs[i])
+	srcs     []tcpip.Address // srcs[i] is the source for addrs[i], or "" if unreachable
+	srcAttr  []ipAttr        // srcAttr[i] is ipAttrOf(srcs[i])
+}
+
+// Len and Swap implement sort.Interface; Swap moves all four parallel slices in step.
+func (s *byRFC6724) Len() int { return len(s.addrs) }
+func (s *byRFC6724) Swap(i, j int) {
+	s.addrs[j], s.addrs[i] = s.addrs[i], s.addrs[j]
+	s.addrAttr[j], s.addrAttr[i] = s.addrAttr[i], s.addrAttr[j]
+	s.srcs[j], s.srcs[i] = s.srcs[i], s.srcs[j]
+	s.srcAttr[j], s.srcAttr[i] = s.srcAttr[i], s.srcAttr[j]
+}
+
+// Less reports whether i is a better destination address for this
+// host than j. An empty srcs entry means that destination is unreachable.
+//
+// The algorithm and variable names come from RFC 6724 section 6.
+func (s *byRFC6724) Less(i, j int) bool {
+	DA := s.addrs[i]
+	DB := s.addrs[j]
+	SourceDA := s.srcs[i]
+	SourceDB := s.srcs[j]
+	attrDA := &s.addrAttr[i]
+	attrDB := &s.addrAttr[j]
+	attrSourceDA := &s.srcAttr[i]
+	attrSourceDB := &s.srcAttr[j]
+
+	const preferDA = true
+	const preferDB = false
+
+	// Rule 1: Avoid unusable destinations.
+	// If DB is known to be unreachable or if Source(DB) is undefined, then
+	// prefer DA. Similarly, if DA is known to be unreachable or if
+	// Source(DA) is undefined, then prefer DB.
+	if SourceDA == "" && SourceDB == "" {
+		return false // "equal"
+	}
+	if SourceDB == "" {
+		return preferDA
+	}
+	if SourceDA == "" {
+		return preferDB
+	}
+
+	// Rule 2: Prefer matching scope.
+	// If Scope(DA) = Scope(Source(DA)) and Scope(DB) <> Scope(Source(DB)),
+	// then prefer DA. Similarly, if Scope(DA) <> Scope(Source(DA)) and
+	// Scope(DB) = Scope(Source(DB)), then prefer DB.
+	if attrDA.Scope == attrSourceDA.Scope && attrDB.Scope != attrSourceDB.Scope {
+		return preferDA
+	}
+	if attrDA.Scope != attrSourceDA.Scope && attrDB.Scope == attrSourceDB.Scope {
+		return preferDB
+	}
+
+	// Rule 3: Avoid deprecated addresses.
+	// If Source(DA) is deprecated and Source(DB) is not, then prefer DB.
+	// Similarly, if Source(DA) is not deprecated and Source(DB) is
+	// deprecated, then prefer DA.
+
+	// TODO(bradfitz): implement? low priority for now.
+
+	// Rule 4: Prefer home addresses.
+	// If Source(DA) is simultaneously a home address and care-of address
+	// and Source(DB) is not, then prefer DA. Similarly, if Source(DB) is
+	// simultaneously a home address and care-of address and Source(DA) is
+	// not, then prefer DB.
+
+	// TODO(bradfitz): implement? low priority for now.
+
+	// Rule 5: Prefer matching label.
+	// If Label(Source(DA)) = Label(DA) and Label(Source(DB)) <> Label(DB),
+	// then prefer DA. Similarly, if Label(Source(DA)) <> Label(DA) and
+	// Label(Source(DB)) = Label(DB), then prefer DB.
+	if attrSourceDA.Label == attrDA.Label &&
+		attrSourceDB.Label != attrDB.Label {
+		return preferDA
+	}
+	if attrSourceDA.Label != attrDA.Label &&
+		attrSourceDB.Label == attrDB.Label {
+		return preferDB
+	}
+
+	// Rule 6: Prefer higher precedence.
+	// If Precedence(DA) > Precedence(DB), then prefer DA. Similarly, if
+	// Precedence(DA) < Precedence(DB), then prefer DB.
+	if attrDA.Precedence > attrDB.Precedence {
+		return preferDA
+	}
+	if attrDA.Precedence < attrDB.Precedence {
+		return preferDB
+	}
+
+	// Rule 7: Prefer native transport.
+	// If DA is reached via an encapsulating transition mechanism (e.g.,
+	// IPv6 in IPv4) and DB is not, then prefer DB. Similarly, if DB is
+	// reached via encapsulation and DA is not, then prefer DA.
+
+	// TODO(bradfitz): implement? low priority for now.
+
+	// Rule 8: Prefer smaller scope.
+	// If Scope(DA) < Scope(DB), then prefer DA. Similarly, if Scope(DA) >
+	// Scope(DB), then prefer DB.
+	if attrDA.Scope < attrDB.Scope {
+		return preferDA
+	}
+	if attrDA.Scope > attrDB.Scope {
+		return preferDB
+	}
+
+	// Rule 9: Use longest matching prefix.
+	// When DA and DB belong to the same address family (both are IPv6 or
+	// both are IPv4): If CommonPrefixLen(Source(DA), DA) >
+	// CommonPrefixLen(Source(DB), DB), then prefer DA. Similarly, if
+	// CommonPrefixLen(Source(DA), DA) < CommonPrefixLen(Source(DB), DB),
+	// then prefer DB.
+	da4 := DA.To4() != ""
+	db4 := DB.To4() != ""
+	if da4 == db4 {
+		commonA := commonPrefixLen(SourceDA, DA)
+		commonB := commonPrefixLen(SourceDB, DB)
+
+		// CommonPrefixLen doesn't really make sense for IPv4, and even
+		// causes problems for common load balancing practices
+		// (e.g., https://golang.org/issue/13283). Glibc instead only
+		// uses CommonPrefixLen for IPv4 when the source and destination
+		// addresses are on the same subnet, but that requires extra
+		// work to find the netmask for our source addresses. As a
+		// simpler heuristic, we limit its use to when the source and
+		// destination belong to the same special purpose block.
+		if da4 {
+			if !sameIPv4SpecialPurposeBlock(SourceDA, DA) {
+				commonA = 0
+			}
+			if !sameIPv4SpecialPurposeBlock(SourceDB, DB) {
+				commonB = 0
+			}
+		}
+
+		if commonA > commonB {
+			return preferDA
+		}
+		if commonA < commonB {
+			return preferDB
+		}
+	}
+
+	// Rule 10: Otherwise, leave the order unchanged.
+	// If DA preceded DB in the original list, prefer DA.
+	// Otherwise, prefer DB.
+	return false // "equal"
+}
+
+type policyTableEntry struct { // one row of the RFC 6724 section 2.1 policy table
+	Prefix     *tcpip.Subnet // addresses this row applies to
+	Precedence uint8         // compared by Less under Rule 6 (higher preferred)
+	Label      uint8         // compared by Less under Rule 5 (source/destination match)
+}
+
+type policyTable []policyTableEntry // searched in order by Classify
+
+// rfc6724policyTable is the default policy table of RFC 6724 section 2.1; init sorts it longest prefix first.
+var rfc6724policyTable = policyTable{
+	{
+		// ::1/128
+		Prefix:     makeSubnet(tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"), 128),
+		Precedence: 50,
+		Label:      0,
+	},
+	{
+		// ::/0
+		Prefix:     makeSubnet(tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), 0),
+		Precedence: 40,
+		Label:      1,
+	},
+	{
+		// ::ffff:0:0/96
+		Prefix:     makeSubnet(tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"), 96),
+		Precedence: 35,
+		Label:      4,
+	},
+	{
+		// 2002::/16
+		Prefix:     makeSubnet(tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), 16),
+		Precedence: 30,
+		Label:      2,
+	},
+	{
+		// 2001::/32
+		Prefix:     makeSubnet(tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), 32),
+		Precedence: 5,
+		Label:      5,
+	},
+	{
+		// fc00::/7
+		Prefix:     makeSubnet(tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), 7),
+		Precedence: 3,
+		Label:      13,
+	},
+	{
+		// ::/96
+		Prefix:     makeSubnet(tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), 96),
+		Precedence: 1,
+		Label:      3,
+	},
+	{
+		// fec0::/10
+		Prefix:     makeSubnet(tcpip.Address("\xfe\xc0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), 10),
+		Precedence: 1,
+		Label:      11,
+	},
+	{
+		// 3ffe::/16
+		Prefix:     makeSubnet(tcpip.Address("\x3f\xfe\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), 16),
+		Precedence: 1,
+		Label:      12,
+	},
+}
+
+func init() {
+	sort.Sort(sort.Reverse(byMaskLength(rfc6724policyTable))) // longest prefix first, as Classify requires
+}
+
+// byMaskLength sorts policyTableEntry values by Prefix mask length, shortest first.
+type byMaskLength []policyTableEntry
+
+func (s byMaskLength) Len() int      { return len(s) }
+func (s byMaskLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+func (s byMaskLength) Less(i, j int) bool {
+	isize, _ := s[i].Prefix.Bits()
+	jsize, _ := s[j].Prefix.Bits()
+	return isize < jsize
+}
+
+// Classify returns the entry of t (sorted longest mask first) whose prefix contains ip.
+func (t policyTable) Classify(ip tcpip.Address) policyTableEntry {
+	if len(ip) == 4 { // t holds only 16-byte prefixes: match IPv4 as ::ffff:a.b.c.d (RFC 6724 2.1)
+		ip = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + ip
+	}
+	for _, ent := range t {
+		if ent.Prefix.Contains(ip) {
+			return ent
+		}
+	}
+	return policyTableEntry{}
+}
+
+// scope is an address scope for RFC 6724 section 3.1; the values are the IPv6 multicast scope codes (RFC 4291).
+type scope uint8 // Less prefers smaller values under Rule 8
+
+const (
+	scopeInterfaceLocal scope = 0x1
+	scopeLinkLocal      scope = 0x2
+	scopeAdminLocal     scope = 0x4
+	scopeSiteLocal      scope = 0x5
+	scopeOrgLocal       scope = 0x8
+	scopeGlobal         scope = 0xe
+)
+
+// classifyScope returns the RFC 6724 scope of ip: link-local for loopback and
+// link-local unicast addresses, the embedded scope for IPv6 multicast,
+// site-local for fec0::/10, and global otherwise (as Go's package net does).
+func classifyScope(ip tcpip.Address) scope {
+	ip4 := ip.To4()
+	ipv6 := len(ip) == 16 && ip4 == ""
+	switch {
+	case ip4 != "" && (ip4[0] == 127 || (ip4[0] == 169 && ip4[1] == 254)), // IPv4 127/8, 169.254/16
+		ipv6 && (ip == "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" || (ip[0] == 0xfe && ip[1]&0xc0 == 0x80)): // ::1, fe80::/10
+		return scopeLinkLocal
+	case ipv6 && ip[0] == 0xff: // multicast: the scope is the low nibble of the second byte
+		return scope(ip[1] & 0xf)
+	case ipv6 && ip[0] == 0xfe && ip[1]&0xc0 == 0xc0: // site-local: RFC 3513 2.5.6, deprecated by RFC 3879
+		return scopeSiteLocal
+	}
+	return scopeGlobal
+}
+
+// commonPrefixLen reports the length of the longest prefix (looking
+// at the most significant, or leftmost, bits) that the
+// two addresses have in common, up to the length of a's prefix (i.e.,
+// the portion of the address not including the interface ID).
+//
+// If a or b is an IPv4 address as an IPv6 address, the IPv4 addresses
+// are compared (with max common prefix length of 32).
+// If a and b are different IP versions, 0 is returned.
+//
+// See https://tools.ietf.org/html/rfc6724#section-2.2
+func commonPrefixLen(a, b tcpip.Address) (cpl int) {
+ if a4 := a.To4(); a4 != "" {
+ a = a4
+ }
+ if b4 := b.To4(); b4 != "" {
+ b = b4
+ }
+ if len(a) != len(b) {
+ return 0
+ }
+ // If IPv6, only up to the prefix (first 64 bits)
+ if len(a) > 8 {
+ a = a[:8]
+ b = b[:8]
+ }
+ for len(a) > 0 {
+ if a[0] == b[0] {
+ cpl += 8
+ a = a[1:]
+ b = b[1:]
+ continue
+ }
+ bits := 8
+ ab, bb := a[0], b[0]
+ for {
+ ab >>= 1
+ bb >>= 1
+ bits--
+ if ab == bb {
+ cpl += bits
+ return
+ }
+ }
+ }
+ return
+}
+
+// sameIPv4SpecialPurposeBlock reports whether a and b are both IPv4 addresses
+// inside the same block of the IANA IPv4 Special-Purpose Address Registry:
+// http://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml
+//
+// Only the few blocks likely to matter in practice are recognized; any other
+// pair, or a pair that is not both IPv4, is reported as not matching.
+func sameIPv4SpecialPurposeBlock(a, b tcpip.Address) bool {
+	a4, b4 := a.To4(), b.To4()
+	if a4 == "" || b4 == "" || a4[0] != b4[0] {
+		return false
+	}
+	switch a4[0] {
+	case 10, 127: // 10.0.0.0/8: Private-Use; 127.0.0.0/8: Loopback
+		return true
+	case 169: // 169.254.0.0/16: Link Local
+		return a4[1] == 254 && b4[1] == 254
+	case 172: // 172.16.0.0/12: Private-Use
+		return a4[1]&0xf0 == 16 && b4[1]&0xf0 == 16
+	case 192: // 192.168.0.0/16: Private-Use
+		return a4[1] == 168 && b4[1] == 168
+	default:
+		return false
+	}
+}
+
+// makeSubnet builds the subnet addr/bits from a 16-byte IPv6 address and returns
+// a pointer to it. It panics if tcpip.NewSubnet rejects the pair.
+func makeSubnet(addr tcpip.Address, bits int) *tcpip.Subnet {
+	s, err := tcpip.NewSubnet(addr, makeMask(bits, 128))
+	if err == nil {
+		return &s
+	}
+	panic(err.Error())
+}
+
+// makeMask returns a bits/8-byte AddressMask whose first `ones' bits are 1 and
+// whose remaining bits are 0, or "" unless bits is 32 or 128 and 0 <= ones <= bits.
+// For a mask of this form, makeMask is the inverse of tcpip.Subnet.Bits.
+func makeMask(ones, bits int) tcpip.AddressMask {
+	switch {
+	case bits != 32 && bits != 128:
+		return ""
+	case ones < 0, ones > bits:
+		return ""
+	}
+	mask := make([]byte, bits/8)
+	remaining := ones
+	for i := range mask {
+		switch {
+		case remaining >= 8:
+			mask[i] = 0xff
+			remaining -= 8
+		case remaining > 0:
+			mask[i] = 0xff << uint(8-remaining)
+			remaining = 0
+		}
+	}
+	return tcpip.AddressMask(mask)
+}
diff --git a/dns/cache.go b/dns/cache.go
new file mode 100644
index 0000000..61c647c
--- /dev/null
+++ b/dns/cache.go
@@ -0,0 +1,260 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "fmt"
+ "log"
+ "math"
+ "sync"
+ "time"
+
+ "github.com/google/netstack/dns/dnsmessage"
+)
+
+const (
+	// maxEntries caps the total number of cached records. TODO: Think about a good value. dnsmasq defaults to 150 names.
+	maxEntries = 1024
+)
+
+const debug = false // when true, cache hits, misses, inserts and updates are logged
+
+var testHookNow = func() time.Time { return time.Now() } // clock used for TTL expiry; tests replace it
+
+// cacheEntry is one cached record (for example an A resource) and its expiry time.
+type cacheEntry struct {
+	rr  dnsmessage.Resource // the resource
+	ttd time.Time           // time-to-die: prune drops the entry once the clock is past it
+}
+
+// Returns true if this entry is a CNAME that points at something no longer in our cache.
+func (entry *cacheEntry) isDanglingCNAME(cache *cacheInfo) bool {
+ switch rr := entry.rr.(type) {
+ case *dnsmessage.CNAMEResource:
+ return cache.m[rr.CNAME] == nil
+ default:
+ return false
+ }
+}
+
+// cacheInfo is the DNS cache; its methods do not lock, so callers hold mu (see newCachedResolver).
+type cacheInfo struct {
+	mu         sync.Mutex
+	m          map[string][]*cacheEntry // entries keyed by record name
+	numEntries int                      // total entries across m, kept at or below maxEntries
+}
+
+func newCache() cacheInfo {
+	return cacheInfo{m: map[string][]*cacheEntry{}}
+}
+
+// lookup returns the cached Resources matching question's class, type and name,
+// following CNAME aliases. Each name is expanded once, so CNAME cycles terminate.
+func (cache *cacheInfo) lookup(question *dnsmessage.Question) []dnsmessage.Resource {
+	rrs := []dnsmessage.Resource{}
+	seen := map[string]bool{}
+	var visit func(name string)
+	visit = func(name string) {
+		// Without this check a cached CNAME loop (a -> b -> a) would recurse forever.
+		if seen[name] {
+			return
+		}
+		seen[name] = true
+		for _, entry := range cache.m[name] {
+			if h := entry.rr.Header(); h.Class == question.Class && h.Name == name {
+				if cname, ok := entry.rr.(*dnsmessage.CNAMEResource); ok {
+					visit(cname.CNAME)
+				} else if h.Type == question.Type {
+					rrs = append(rrs, entry.rr)
+				}
+			}
+		}
+	}
+	visit(question.Name)
+	return rrs
+}
+
+// resourceEqual reports whether r1 and r2 have the same class, type, name and data.
+// A NegativeResource r1 equals any same-header r2; otherwise r2 must share r1's concrete type or this panics.
+func resourceEqual(r1 dnsmessage.Resource, r2 dnsmessage.Resource) bool {
+	if a, b := r1.Header(), r2.Header(); a.Name != b.Name || a.Type != b.Type || a.Class != b.Class {
+		return false
+	}
+	switch x := r1.(type) {
+	case *dnsmessage.NegativeResource:
+		return true
+	case *dnsmessage.CNAMEResource:
+		return x.CNAME == r2.(*dnsmessage.CNAMEResource).CNAME
+	case *dnsmessage.AAAAResource:
+		return x.AAAA == r2.(*dnsmessage.AAAAResource).AAAA
+	case *dnsmessage.AResource:
+		return x.A == r2.(*dnsmessage.AResource).A
+	}
+	panic("unexpected resource type")
+}
+
+// findExact returns the first entry whose record is resourceEqual to rr, or nil.
+func findExact(entries []*cacheEntry, rr dnsmessage.Resource) *cacheEntry {
+	for i := 0; i < len(entries); i++ {
+		if e := entries[i]; resourceEqual(e.rr, rr) {
+			return e
+		}
+	}
+	return nil
+}
+
+// Finds the minimum TTL value of any SOA resource in a response. Returns 0 if not found.
+// This is used for caching a failed DNS query. See RFC 2308.
+func findSOAMinTTL(auths []dnsmessage.Resource) uint32 {
+ minTTL := uint32(math.MaxUint32)
+ foundSOA := false
+ for _, auth := range auths {
+ if auth.Header().Class == dnsmessage.ClassINET {
+ switch soa := auth.(type) {
+ case *dnsmessage.SOAResource:
+ foundSOA = true
+ if soa.MinTTL < minTTL {
+ minTTL = soa.MinTTL
+ }
+ }
+ }
+ }
+ if foundSOA {
+ return minTTL
+ }
+ return 0
+}
+
+// insert adds rr to the cache with an expiry TTL seconds after testHookNow().
+// If an equal record is already cached it is refreshed instead (a cached negative
+// record is replaced by rr). When maxEntries records are cached, rr is dropped.
+func (cache *cacheInfo) insert(rr dnsmessage.Resource) {
+	h := rr.Header()
+	ttd := testHookNow().Add(time.Duration(h.TTL) * time.Second)
+	entries := cache.m[h.Name]
+	existing := findExact(entries, rr)
+
+	switch {
+	case existing != nil:
+		if _, negative := existing.rr.(*dnsmessage.NegativeResource); negative {
+			// A cached negative record is replaced outright by rr.
+			existing.rr, existing.ttd = rr, ttd
+		} else if ttd.After(existing.ttd) {
+			existing.ttd = ttd
+		}
+		if debug {
+			log.Printf("DNS cache update: %v(%v) expires %v", h.Name, h.Type, existing.ttd)
+		}
+	case cache.numEntries < maxEntries:
+		if debug {
+			log.Printf("DNS cache insert: %v(%v) expires %v", h.Name, h.Type, ttd)
+		}
+		cache.m[h.Name] = append(entries, &cacheEntry{rr: rr, ttd: ttd})
+		cache.numEntries++
+	default:
+		// TODO(mpcomplete): might be better to evict the LRU entry instead.
+		// TODO(mpcomplete): RFC 1035 7.4 says that if we can't cache this RR, we
+		// shouldn't cache any other RRs for the same name in this response.
+		log.Printf("DNS cache is full; insert failed: %v(%v)", h.Name, h.Type)
+	}
+}
+
+// insertAll first prunes expired and dangling entries, then inserts every
+// ClassINET A, AAAA or CNAME record in rrs. Other records are ignored, and any
+// insert may fail once the cache is full.
+func (cache *cacheInfo) insertAll(rrs []dnsmessage.Resource) {
+	cache.prune()
+	for i := range rrs {
+		h := rrs[i].Header()
+		cacheable := h.Type == dnsmessage.TypeA || h.Type == dnsmessage.TypeAAAA || h.Type == dnsmessage.TypeCNAME
+		if h.Class == dnsmessage.ClassINET && cacheable {
+			cache.insert(rrs[i])
+		}
+	}
+}
+
+// insertNegative caches a negative answer for question that lives for the minimum
+// SOA TTL in msg.Authorities (RFC 2308); without such a TTL nothing is cached.
+func (cache *cacheInfo) insertNegative(question *dnsmessage.Question, msg *dnsmessage.Message) {
+	cache.prune()
+	minTTL := findSOAMinTTL(msg.Authorities)
+	if minTTL == 0 {
+		return // Don't cache without a TTL value.
+	}
+	for _, e := range cache.m[question.Name] {
+		// A positive record of this type may have been cached meanwhile; keep it (inserting would panic in resourceEqual).
+		if _, neg := e.rr.(*dnsmessage.NegativeResource); !neg && e.rr.Header().Type == question.Type {
+			return
+		}
+	}
+	hdr := dnsmessage.ResourceHeader{Name: question.Name, Type: question.Type, Class: dnsmessage.ClassINET, TTL: minTTL}
+	cache.insert(&dnsmessage.NegativeResource{ResourceHeader: hdr})
+}
+
+// prune removes every expired entry and every CNAME whose target has no entries.
+func (cache *cacheInfo) prune() {
+	now := testHookNow()
+	for name, entries := range cache.m {
+		removed := false
+		for i := 0; i < len(entries); { // i advances only past entries that are kept
+			if now.After(entries[i].ttd) || entries[i].isDanglingCNAME(cache) { // NOTE(review): a target deleted later in this pass is only noticed on the next prune
+				entries[i] = entries[len(entries)-1] // swap-remove: order within a name is not preserved
+				entries = entries[:len(entries)-1]
+				cache.numEntries--
+				removed = true
+			} else {
+				i++
+			}
+		}
+		if len(entries) == 0 {
+			delete(cache.m, name) // deleting during range is permitted in Go
+		} else if removed {
+			cache.m[name] = entries
+		}
+	}
+}
+
+// debugString formats rr as "[r1, r2, ]" (each element via %v) for debug logs.
+func debugString(rr []dnsmessage.Resource) string {
+	out := "["
+	for i := range rr {
+		out += fmt.Sprintf("%v, ", rr[i])
+	}
+	return out + "]"
+}
+
+var cache = newCache() // shared by every Resolver from newCachedResolver; hold cache.mu while using it
+
+// newCachedResolver returns a Resolver that answers ClassINET A/AAAA questions
+// from the package cache and consults fallback (caching its result) only on a miss.
+func newCachedResolver(fallback Resolver) Resolver {
+	return func(c *Client, question dnsmessage.Question) (string, []dnsmessage.Resource, *dnsmessage.Message, error) {
+		isAddrQuery := question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA
+		if question.Class != dnsmessage.ClassINET || !isAddrQuery {
+			panic("unexpected question type")
+		}
+		cache.mu.Lock()
+		cached := cache.lookup(&question)
+		cache.mu.Unlock()
+		if len(cached) > 0 {
+			if debug {
+				log.Printf("DNS cache hit %v(%v) => %v", question.Name, question.Type, debugString(cached))
+			}
+			return "", cached, nil, nil
+		}
+		cname, rrs, msg, err := fallback(c, question)
+		if debug {
+			log.Printf("DNS cache miss, server returned %v(%v) => %v; err=%v", question.Name, question.Type, debugString(rrs), err)
+		}
+		cache.mu.Lock()
+		if dnsErr, ok := err.(*Error); err == nil {
+			cache.insertAll(msg.Answers)
+		} else if ok && dnsErr.CacheNegative {
+			cache.insertNegative(&question, msg)
+		}
+		cache.mu.Unlock()
+		return cname, rrs, msg, err
+	}
+}
diff --git a/dns/cache_test.go b/dns/cache_test.go
new file mode 100644
index 0000000..9e7e75d
--- /dev/null
+++ b/dns/cache_test.go
@@ -0,0 +1,228 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dns
+
+import (
+ "testing"
+ "time"
+
+ "github.com/google/netstack/dns/dnsmessage"
+)
+
+func makeResourceHeader(name string, ttl uint32) dnsmessage.ResourceHeader {
+ return dnsmessage.ResourceHeader{
+ Name: name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ TTL: ttl,
+ }
+}
+
+// makeQuestion returns a ClassINET, TypeA question for name.
+func makeQuestion(name string) *dnsmessage.Question {
+	q := new(dnsmessage.Question)
+	q.Name = name
+	q.Type, q.Class = dnsmessage.TypeA, dnsmessage.ClassINET
+	return q
+}
+
+var smallTestResources = []dnsmessage.Resource{ // two A records for example.com., both with a 5s TTL
+	&dnsmessage.AResource{
+		ResourceHeader: makeResourceHeader("example.com.", 5),
+		A:              [4]byte{127, 0, 0, 1},
+	},
+	&dnsmessage.AResource{
+		ResourceHeader: makeResourceHeader("example.com.", 5),
+		A:              [4]byte{127, 0, 0, 2},
+	},
+}
+
+var smallTestQuestion = makeQuestion("example.com.") // the A/INET query that smallTestResources answer
+
+var soaAuthority = dnsmessage.SOAResource{ // authority whose MinTTL (12s) sets the negative-cache lifetime
+	ResourceHeader: makeResourceHeader("example.com.", 5),
+	MinTTL:         12,
+}
+
+// TestLookup checks a simple insert followed by a lookup of the same name.
+func TestLookup(t *testing.T) {
+	cache := newCache()
+	cache.insertAll(smallTestResources)
+	rrs := cache.lookup(smallTestQuestion)
+	if len(rrs) != 2 {
+		t.Errorf("cache.lookup failed. Got %d. Want %d.", len(rrs), 2)
+	}
+	for _, rr := range rrs {
+		if rr.Header().Name != "example.com." {
+			t.Errorf("cache.lookup failed. Got %q. Want %q.", rr.Header().Name, "example.com.")
+		}
+	}
+}
+
+// TestExpires checks that entries survive pruning until their TTL (5s) has
+// passed and are removed by the first prune after that, driving the clock
+// through testHookNow.
+func TestExpires(t *testing.T) {
+	cache := newCache()
+	now := time.Now()
+	testHookNow = func() time.Time { return now }
+	cache.insertAll(smallTestResources) // TTL 5s
+
+	steps := []struct {
+		advance time.Duration
+		want    int
+	}{
+		{4 * time.Second, 2}, // t=4s: still cached
+		{2 * time.Second, 0}, // t=6s: expired
+	}
+
+	for _, step := range steps {
+		now = now.Add(step.advance)
+		cache.prune()
+		if got := len(cache.lookup(smallTestQuestion)); got != step.want {
+			t.Errorf("cache.prune failed. Got %d. Want %d.", got, step.want)
+		}
+	}
+}
+
+// TestMaxEntries checks that inserts beyond maxEntries are dropped, and that
+// inserting works again once expired entries are gone (insertAll prunes
+// before it inserts).
+func TestMaxEntries(t *testing.T) {
+	cache := newCache()
+
+	now := time.Now()
+	testHookNow = func() time.Time { return now }
+
+	// A single long-lived record: it expires at t=10s.
+	cache.insertAll([]dnsmessage.Resource{
+		&dnsmessage.AResource{
+			ResourceHeader: makeResourceHeader("example.com.", 10),
+			A:              [4]byte{127, 0, 0, 1},
+		},
+	})
+
+	// maxEntries short-lived records (t=5s). Together with the record above this
+	// is one more than fits, so the final insert is rejected.
+	for n := 0; n < maxEntries; n++ {
+		addr := [4]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}
+		cache.insertAll([]dnsmessage.Resource{
+			&dnsmessage.AResource{
+				ResourceHeader: makeResourceHeader("example.com.", 5),
+				A:              addr,
+			},
+		})
+	}
+
+	// 1 long-lived + (maxEntries-1) short-lived records were stored.
+	if got := len(cache.lookup(smallTestQuestion)); got != maxEntries {
+		t.Errorf("cache.insertAll failed. Got %d. Want %d.", got, maxEntries)
+	}
+
+	// The cache is full, so a record for a new name is not stored.
+	fooRecord := []dnsmessage.Resource{
+		&dnsmessage.AResource{
+			ResourceHeader: makeResourceHeader("foo.example.com.", 5),
+			A:              [4]byte{192, 168, 0, 1},
+		},
+	}
+	cache.insertAll(fooRecord)
+	fooQuestion := makeQuestion("foo.example.com.")
+	if got := len(cache.lookup(fooQuestion)); got != 0 {
+		t.Errorf("cache.insertAll failed. Got %d. Want %d.", got, 0)
+	}
+
+	// Move the clock to t=6s: insertAll prunes the expired 5s records first,
+	// which frees room for the new one.
+	now = now.Add(6 * time.Second)
+	cache.insertAll(fooRecord)
+
+	if got := len(cache.lookup(fooQuestion)); got != 1 {
+		t.Errorf("cache.insertAll failed. Got %d. Want %d.", got, 1)
+	}
+
+	// Only the 10s record remains for example.com.
+	if got := len(cache.lookup(makeQuestion("example.com."))); got != 1 {
+		t.Errorf("cache.insertAll failed. Got %d. Want %d.", got, 1)
+	}
+}
+
+// TestCNAME checks that looking up an alias returns the records of its target.
+func TestCNAME(t *testing.T) {
+	cache := newCache()
+	cache.insertAll(smallTestResources)
+
+	// One CNAME record that points at an existing name.
+	alias := &dnsmessage.CNAMEResource{
+		ResourceHeader: makeResourceHeader("foobar.com.", 10),
+		CNAME:          "example.com.",
+	}
+	alias.ResourceHeader.Type = dnsmessage.TypeCNAME // as it would arrive in a real answer
+	cache.insertAll([]dnsmessage.Resource{alias})
+
+	rrs := cache.lookup(makeQuestion("foobar.com."))
+	if len(rrs) != 2 {
+		t.Errorf("cache.lookup failed. Got %d. Want %d.", len(rrs), 2)
+	}
+	for _, rr := range rrs {
+		if rr.Header().Name != "example.com." {
+			t.Errorf("cache.lookup failed. Got %q. Want %q.", rr.Header().Name, "example.com.")
+		}
+	}
+}
+
+// TestDupe checks that inserting the same records twice stores them only once.
+func TestDupe(t *testing.T) {
+	cache := newCache()
+	for round := 0; round < 2; round++ {
+		cache.insertAll(smallTestResources)
+	}
+	if got := len(cache.lookup(smallTestQuestion)); got != 2 {
+		t.Errorf("cache.lookup failed. Got %d. Want %d.", got, 2)
+	}
+}
+
+// TestNegative checks that a negative answer is cached for the SOA MinTTL
+// (12s, from soaAuthority) and pruned once that has passed.
+func TestNegative(t *testing.T) {
+	cache := newCache()
+	now := time.Now()
+	testHookNow = func() time.Time { return now }
+
+	msg := &dnsmessage.Message{
+		Questions:   []dnsmessage.Question{*smallTestQuestion},
+		Authorities: []dnsmessage.Resource{&soaAuthority},
+	}
+	cache.insertNegative(smallTestQuestion, msg)
+
+	steps := []struct {
+		advance time.Duration
+		want    int
+	}{
+		{11 * time.Second, 1}, // t=11s: still cached
+		{2 * time.Second, 0},  // t=13s: expired
+	}
+	for _, step := range steps {
+		now = now.Add(step.advance)
+		cache.prune()
+		if got := len(cache.lookup(smallTestQuestion)); got != step.want {
+			t.Errorf("cache.prune failed. Got %d. Want %d.", got, step.want)
+		}
+	}
+}
+
+// TestNegativeUpdate checks that real records replace a cached negative answer.
+func TestNegativeUpdate(t *testing.T) {
+	cache := newCache()
+	negative := &dnsmessage.Message{
+		Questions:   []dnsmessage.Question{*smallTestQuestion},
+		Authorities: []dnsmessage.Resource{&soaAuthority},
+	}
+	cache.insertNegative(smallTestQuestion, negative)
+	cache.insertAll(smallTestResources)
+	if rrs := cache.lookup(smallTestQuestion); len(rrs) != 2 {
+		t.Errorf("cache.lookup failed. Got %s. Want %d.", debugString(rrs), 2)
+	}
+}
diff --git a/dns/client.go b/dns/client.go
new file mode 100644
index 0000000..53b5a96
--- /dev/null
+++ b/dns/client.go
@@ -0,0 +1,625 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// DNS client: see RFC 1035.
+// Has to be linked into package net for Dial.
+
+// TODO(rsc):
+// Could potentially handle many outstanding lookups faster.
+// Could have a small cache.
+// Random UDP source port (net.Dial should do that for us).
+// Random request IDs.
+// TODO(mpcomplete):
+// Cleanup
+// Decide whether we need DNSSEC, EDNS0, reverse DNS or other query types
+// We don't support ipv6 zones. Do we need to?
+
+package dns
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "math/rand"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/google/netstack/dns/dnsmessage"
+ "github.com/google/netstack/tcpip"
+ "github.com/google/netstack/tcpip/network/ipv4"
+ "github.com/google/netstack/tcpip/transport/tcp"
+ "github.com/google/netstack/tcpip/transport/udp"
+ "github.com/google/netstack/waiter"
+)
+
// TODO(mpcomplete): Use FIDL to fetch the DNS config from the parent process.
//
// tmpDNSConfig is the configuration that readConfig hands out, as a copy.
// It lists no static servers. Lookups therefore fail with "no DNS servers"
// until servers are supplied through Client.SetRuntimeServers.
var tmpDNSConfig = dnsConfig{
	servers:    []tcpip.FullAddress{},
	search:     []string{},
	ndots:      100,
	timeout:    3 * time.Second, // per query exchange
	attempts:   3,               // passes over the server list in tryOneName
	rotate:     true,
	unknownOpt: false,
	lookup:     []string{},
	err:        nil,
	mtime:      time.Now(), // evaluated once, at package initialization
}
+
// Client is a DNS client.
//
// Queries are sent over endpoints created from stack (see connect). Every
// server address is rewritten to use the NIC nicid (see tryOneName).
type Client struct {
	stack tcpip.Stack // stack used to create the UDP/TCP query endpoints
	nicid tcpip.NICID // NIC through which all queries are routed
}
+
// A Resolver answers DNS Questions.
//
// On success it returns three things: the name the records were found under
// (after any CNAMEs were followed), the matching resource records, and the
// raw response message. The resolver built by newNetworkResolver is a thin
// wrapper around Client.tryOneName.
type Resolver func(c *Client, question dnsmessage.Question) (cname string, rrs []dnsmessage.Resource, msg *dnsmessage.Message, err error)
+
// Error represents an error while issuing a DNS query for a hostname.
// tryOneName and answer report their failures as *Error values.
type Error struct {
	Err           string             // a general error string
	Name          string             // the hostname being queried
	Server        *tcpip.FullAddress // optional DNS server; nil when the error is not tied to one server
	CacheNegative bool               // true if this represents a negative response that should be cached (RFC 2308)
}
+
+func (e *Error) Error() string {
+ if e.Server != nil {
+ return fmt.Sprintf("lookup %s on %v: %s", e.Name, e.Server, e.Err)
+ }
+ return fmt.Sprintf("lookup %s: %s", e.Name, e.Err)
+}
+
+// NewClient creates a DHCP client.
+func NewClient(s tcpip.Stack, nicid tcpip.NICID) *Client {
+ return &Client{
+ stack: s,
+ nicid: nicid,
+ }
+}
+
+// roundTrip writes the query to and reads the response from the Endpoint.
+// The message format is slightly different depending on the transport protocol
+// (for TCP, a 2 byte message length is prepended). See RFC 1035.
+func roundTrip(ctx context.Context, transport tcpip.TransportProtocolNumber, ep tcpip.Endpoint, wq *waiter.Queue, query *dnsmessage.Message) (response *dnsmessage.Message, err error) {
+ b, err := query.Pack()
+ if err != nil {
+ return nil, err
+ }
+ if transport == tcp.ProtocolNumber {
+ l := len(b)
+ b = append([]byte{byte(l >> 8), byte(l)}, b...)
+ }
+
+ // Write to endpoint.
+ for len(b) > 0 {
+ n, err := ep.Write(b, nil)
+ if err != nil {
+ return nil, fmt.Errorf("dns: write: %v", err)
+ }
+
+ b = b[n:]
+ }
+
+ // Read from endpoint.
+ b = []byte{}
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventIn)
+ defer wq.EventUnregister(&waitEntry)
+ for {
+ v, err := ep.Read(nil)
+ if err != nil {
+ if err == tcpip.ErrClosedForReceive {
+ break
+ }
+
+ if err == tcpip.ErrWouldBlock {
+ select {
+ case <-notifyCh:
+ continue
+ case <-ctx.Done():
+ return nil, fmt.Errorf("dns: read: %v", tcpip.ErrTimeout)
+ }
+ }
+
+ return nil, fmt.Errorf("dns: read: %v", err)
+ }
+
+ b = append(b, []byte(v)...)
+
+ // Get the contents of the response.
+ var bcontents []byte
+ switch transport {
+ case tcp.ProtocolNumber:
+ if len(b) > 2 {
+ l := int(b[0])<<8 | int(b[1])
+ bcontents = b[2:(l + 2)]
+ } else {
+ continue
+ }
+ case udp.ProtocolNumber:
+ bcontents = b
+ }
+
+ response = &dnsmessage.Message{}
+ if err := response.Unpack(bcontents); err != nil {
+ // Ignore invalid responses as they may be malicious
+ // forgery attempts. Instead continue waiting until
+ // timeout. See golang.org/issue/13281.
+ continue
+ }
+ break
+ }
+
+ return response, nil
+}
+
+func (c *Client) connect(ctx context.Context, transport tcpip.TransportProtocolNumber, server tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, error) {
+ var wq waiter.Queue
+ ep, err := c.stack.NewEndpoint(transport, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ return nil, nil, fmt.Errorf("dns: %v", err)
+ }
+
+ // Issue connect request and wait for it to complete.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventOut)
+ err = ep.Connect(server)
+ defer wq.EventUnregister(&waitEntry)
+ if err == tcpip.ErrConnectStarted {
+ select {
+ case <-notifyCh:
+ err = ep.GetSockOpt(tcpip.ErrorOption{})
+ case <-ctx.Done():
+ err = tcpip.ErrTimeout
+ }
+ }
+
+ if err != nil {
+ return nil, nil, fmt.Errorf("dns: %v", err)
+ }
+
+ return ep, &wq, nil
+}
+
+// exchange sends a query on the connection and hopes for a response.
+func (c *Client) exchange(server tcpip.FullAddress, name string, qtype dnsmessage.Type, timeout time.Duration) (response *dnsmessage.Message, err error) {
+ query := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ RecursionDesired: true,
+ },
+ Questions: []dnsmessage.Question{
+ {name, qtype, dnsmessage.ClassINET},
+ },
+ }
+
+ protos := []tcpip.TransportProtocolNumber{udp.ProtocolNumber, tcp.ProtocolNumber}
+ for _, proto := range protos {
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+
+ ep, wq, err := c.connect(ctx, proto, server)
+ defer func() {
+ if ep != nil {
+ ep.Close()
+ }
+ }()
+ if err != nil {
+ cancel()
+ return nil, err
+ }
+
+ query.ID = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
+ response, err = roundTrip(ctx, proto, ep, wq, &query)
+ cancel()
+
+ if err != nil {
+ return nil, err
+ }
+ if response.Truncated { // see RFC 5966
+ continue
+ }
+ return response, nil
+ }
+ return nil, errors.New("no answer from the DNS server")
+}
+
+// Do a lookup for a single name, which must be rooted
+// (otherwise answer will not find the answers).
+func (c *Client) tryOneName(cfg *dnsConfig, name string, qtype dnsmessage.Type) (string, []dnsmessage.Resource, *dnsmessage.Message, error) {
+ if len(cfg.servers) == 0 {
+ return "", nil, nil, &Error{Err: "no DNS servers", Name: name}
+ }
+
+ var lastErr error
+ for i := 0; i < cfg.attempts; i++ {
+ for _, server := range cfg.servers {
+ server := tcpip.FullAddress{
+ NIC: c.nicid,
+ Addr: server.Addr,
+ Port: server.Port,
+ }
+ msg, err := c.exchange(server, name, qtype, cfg.timeout)
+ if err != nil {
+ lastErr = &Error{
+ Err: err.Error(),
+ Name: name,
+ Server: &server,
+ }
+ continue
+ }
+ // libresolv continues to the next server when it receives
+ // an invalid referral response. See golang.org/issue/15434.
+ if msg.RCode == dnsmessage.RCodeSuccess && !msg.Authoritative && !msg.RecursionAvailable && len(msg.Answers) == 0 && len(msg.Additionals) == 0 {
+ lastErr = &Error{Err: "lame referral", Name: name, Server: &server}
+ continue
+ }
+
+ cname, rrs, err := answer(name, server, msg, qtype)
+ // If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError,
+ // it means the response in msg was not useful and trying another
+ // server probably won't help. Return now in those cases.
+ // TODO: indicate this in a more obvious way, such as a field on Error?
+ if err == nil || msg.RCode == dnsmessage.RCodeSuccess || msg.RCode == dnsmessage.RCodeNameError {
+ return cname, rrs, msg, err
+ }
+ lastErr = err
+ }
+ }
+ return "", nil, nil, lastErr
+}
+
+// addrRecordList converts and returns a list of IP addresses from DNS
+// address records (both A and AAAA). Other record types are ignored.
+func addrRecordList(rrs []dnsmessage.Resource) []tcpip.Address {
+ addrs := make([]tcpip.Address, 0, 4)
+ for _, rr := range rrs {
+ switch rr := rr.(type) {
+ case *dnsmessage.AResource:
+ addrs = append(addrs, tcpip.Address(rr.A[:]))
+ case *dnsmessage.AAAAResource:
+ addrs = append(addrs, tcpip.Address(rr.AAAA[:]))
+ }
+ }
+ return addrs
+}
+
// A clientConfig holds a DNS stub resolver configuration and the resolver
// built from it. The package uses a single process-wide instance, clientConf.
type clientConfig struct {
	initOnce sync.Once // guards init of clientConfig

	// ch is used as a semaphore that only allows one lookup at a
	// time to refresh the configuration (see tryUpdate).
	ch          chan struct{} // guards lastChecked
	lastChecked time.Time     // last time the configuration was refreshed

	mu             sync.RWMutex        // protects the following vars
	dnsConfig      *dnsConfig          // active configuration used in lookups (static servers plus runtimeServers)
	runtimeServers []tcpip.FullAddress // servers added while running (e.g. by DHCP)
	resolver       Resolver            // a handler which answers DNS Questions
}
+
// dnsConfig is a DNS stub resolver configuration. Its fields follow the
// options of a resolv.conf(5) file, but here the values come from readConfig
// rather than from a file.
type dnsConfig struct {
	servers    []tcpip.FullAddress // server addresses (host and port) to use; tryOneName substitutes the client's NIC
	search     []string            // rooted suffixes to append to local name
	ndots      int                 // number of dots in name to trigger absolute lookup
	timeout    time.Duration       // timeout for each query exchange (the UDP and TCP attempts are timed separately)
	attempts   int                 // number of passes tryOneName makes over servers
	rotate     bool                // round robin among servers (NOTE(review): not consulted by tryOneName)
	unknownOpt bool                // anything unknown was encountered (not read in this file)
	lookup     []string            // OpenBSD top-level database "lookup" order (not read in this file)
	err        error               // any error that occurs during open of resolv.conf (not read in this file)
	mtime      time.Time           // time of resolv.conf modification (not read in this file)
}
+
// clientConf is the process-wide DNS configuration shared by every Client.
// LookupIP reads it, and SetRuntimeServers modifies it.
var clientConf clientConfig
+
+func newNetworkResolver(cfg *dnsConfig) Resolver {
+ return func(c *Client, question dnsmessage.Question) (string, []dnsmessage.Resource, *dnsmessage.Message, error) {
+ return c.tryOneName(cfg, question.Name, question.Type)
+ }
+}
+
+func readConfig() *dnsConfig {
+ cfg := tmpDNSConfig
+ return &cfg
+}
+
+// updateConfigLocked replaces the dnsConfig with the given value.
+func (conf *clientConfig) updateConfigLocked(dnsConfig *dnsConfig) {
+ conf.dnsConfig = dnsConfig
+ if conf.runtimeServers != nil {
+ conf.dnsConfig.servers = append(conf.dnsConfig.servers, conf.runtimeServers...)
+ }
+ conf.resolver = newCachedResolver(newNetworkResolver(conf.dnsConfig))
+}
+
+// init initializes conf and is only called via conf.initOnce.
+func (conf *clientConfig) init() {
+ // Set dnsConfig and lastChecked so we don't parse
+ // resolv.conf twice the first time.
+ if conf.dnsConfig == nil {
+ conf.updateConfigLocked(readConfig())
+ }
+ conf.lastChecked = time.Now()
+
+ // Prepare ch so that only one update of clientConfig may
+ // run at once.
+ conf.ch = make(chan struct{}, 1)
+}
+
+// tryUpdate tries to update conf with the named resolv.conf file.
+// The name variable only exists for testing. It is otherwise always
+// "/etc/resolv.conf".
+func (conf *clientConfig) tryUpdate() {
+ conf.initOnce.Do(conf.init)
+
+ // Ensure only one update at a time checks resolv.conf.
+ if !conf.tryAcquireSema() {
+ return
+ }
+ defer conf.releaseSema()
+
+ now := time.Now()
+ if conf.lastChecked.After(now.Add(-5 * time.Second)) {
+ return
+ }
+ conf.lastChecked = now
+
+ dnsConf := readConfig()
+ conf.mu.Lock()
+ conf.updateConfigLocked(dnsConf)
+ conf.mu.Unlock()
+}
+
+func (conf *clientConfig) tryAcquireSema() bool {
+ select {
+ case conf.ch <- struct{}{}:
+ return true
+ default:
+ return false
+ }
+}
+
+func (conf *clientConfig) releaseSema() {
+ <-conf.ch
+}
+
// avoidDNS reports whether name must not be resolved via DNS. This covers
// the empty name and .onion names (RFC 7686, golang.org/issue/13705); a
// single trailing root dot is ignored when matching. Names under .local
// (RFC 6762) are not covered; see golang.org/issue/16739.
func avoidDNS(name string) bool {
	if len(name) == 0 {
		return true
	}
	return strings.HasSuffix(strings.TrimSuffix(name, "."), ".onion")
}
+
+// nameList returns a list of names for sequential DNS queries.
+func (conf *dnsConfig) nameList(name string) []string {
+ if avoidDNS(name) {
+ return nil
+ }
+
+ // If name is rooted (trailing dot), try only that name.
+ rooted := len(name) > 0 && name[len(name)-1] == '.'
+ if rooted {
+ return []string{name}
+ }
+
+ // hasNdots := count(name, '.') >= conf.ndots
+ hasNdots := false
+ name += "."
+
+ // Build list of search choices.
+ names := make([]string, 0, 1+len(conf.search))
+ // If name has enough dots, try unsuffixed first.
+ if hasNdots {
+ names = append(names, name)
+ }
+ // Try suffixes.
+ for _, suffix := range conf.search {
+ names = append(names, name+suffix)
+ }
+ // Try unsuffixed, if not tried first above.
+ if !hasNdots {
+ names = append(names, name)
+ }
+ return names
+}
+
+// SetRuntimeServers sets the list of runtime servers to query. Servers are checked sequentially, in order.
+func (c *Client) SetRuntimeServers(addrs []tcpip.Address) {
+ clientConf.mu.Lock()
+ defer clientConf.mu.Unlock()
+
+ clientConf.runtimeServers = nil
+ for _, addr := range addrs {
+ server := tcpip.FullAddress{
+ Addr: addr,
+ Port: 53,
+ }
+ clientConf.runtimeServers = append(clientConf.runtimeServers, server)
+ }
+
+ if clientConf.dnsConfig != nil {
+ clientConf.updateConfigLocked(clientConf.dnsConfig)
+ }
+}
+
+// LookupIP returns a list of IP addresses that are registered for the give domain name.
+func (c *Client) LookupIP(name string) (addrs []tcpip.Address, err error) {
+ if !isDomainName(name) {
+ return nil, &Error{Err: "invalid domain name", Name: name}
+ }
+ clientConf.tryUpdate()
+ clientConf.mu.RLock()
+ conf := clientConf.dnsConfig
+ resolver := clientConf.resolver
+ clientConf.mu.RUnlock()
+ type racer struct {
+ fqdn string
+ rrs []dnsmessage.Resource
+ error
+ }
+ lane := make(chan racer, 1)
+ qtypes := [...]dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
+ var lastErr error
+ for _, fqdn := range conf.nameList(name) {
+ for _, qtype := range qtypes {
+ go func(qtype dnsmessage.Type) {
+ _, rrs, _, err := resolver(c, dnsmessage.Question{Name: fqdn, Type: qtype, Class: dnsmessage.ClassINET})
+ lane <- racer{fqdn, rrs, err}
+ }(qtype)
+ }
+ for range qtypes {
+ racer := <-lane
+ if racer.error != nil {
+ // Prefer error for original name.
+ if lastErr == nil || racer.fqdn == name+"." {
+ lastErr = racer.error
+ }
+ continue
+ }
+ addrs = append(addrs, addrRecordList(racer.rrs)...)
+ }
+ if len(addrs) > 0 {
+ break
+ }
+ }
+ if lastErr, ok := lastErr.(*Error); ok {
+ // Show original name passed to lookup, not suffixed one.
+ // In general we might have tried many suffixes; showing
+ // just one is misleading. See also golang.org/issue/6324.
+ lastErr.Name = name
+ }
+ sortByRFC6724(c, addrs)
+ if len(addrs) == 0 && lastErr != nil {
+ return nil, lastErr
+ }
+ return addrs, nil
+}
+
// noSuchHost is the Error.Err text used when a name does not exist or has no
// records of the requested type.
const noSuchHost = "no such host"
+
+// Answer extracts the appropriate answer for a DNS lookup
+// for (name, qtype) from the response message msg, which
+// is assumed to have come from server.
+// It is exported mainly for use by registered helpers.
+func answer(name string, server tcpip.FullAddress, msg *dnsmessage.Message, qtype dnsmessage.Type) (cname string, addrs []dnsmessage.Resource, err error) {
+ addrs = make([]dnsmessage.Resource, 0, len(msg.Answers))
+ if msg.RCode == dnsmessage.RCodeNameError {
+ // TODO: There seem to be some cases where we should cache a name error, but not all. The
+ // spec is confusing on this point. See RFC 2308.
+ return "", nil, &Error{Err: noSuchHost, Name: name, Server: &server, CacheNegative: false}
+ }
+ if msg.RCode != dnsmessage.RCodeSuccess {
+ // None of the error codes make sense
+ // for the query we sent. If we didn't get
+ // a name error and we didn't get success,
+ // the server is behaving incorrectly.
+ return "", nil, &Error{Err: "server misbehaving", Name: name, Server: &server}
+ }
+
+ // Look for the name.
+ // Presotto says it's okay to assume that servers listed in
+ // /etc/resolv.conf are recursive resolvers.
+ // We asked for recursion, so it should have included
+ // all the answers we need in this one packet.
+Cname:
+ for cnameloop := 0; cnameloop < 10; cnameloop++ {
+ addrs = addrs[0:0]
+ for _, rr := range msg.Answers {
+ h := rr.Header()
+ if h.Class == dnsmessage.ClassINET && equalASCIILabel(h.Name, name) {
+ switch h.Type {
+ case qtype:
+ addrs = append(addrs, rr)
+ case dnsmessage.TypeCNAME:
+ // redirect to cname
+ name = rr.(*dnsmessage.CNAMEResource).CNAME
+ continue Cname
+ }
+ }
+ }
+ if len(addrs) == 0 {
+ return "", nil, &Error{Err: noSuchHost, Name: name, Server: &server, CacheNegative: true}
+ }
+ return name, addrs, nil
+ }
+
+ return "", nil, &Error{Err: "too many redirects", Name: name, Server: &server}
+}
+
// equalASCIILabel reports whether x and y are equal byte-for-byte, treating
// ASCII letters 'A'-'Z' as equal to their lowercase forms. No other bytes
// (including non-ASCII ones) are case-folded.
func equalASCIILabel(x, y string) bool {
	if len(x) != len(y) {
		return false
	}
	lower := func(b byte) byte {
		if b >= 'A' && b <= 'Z' {
			return b + ('a' - 'A')
		}
		return b
	}
	for i := len(x) - 1; i >= 0; i-- {
		if lower(x[i]) != lower(y[i]) {
			return false
		}
	}
	return true
}
+
// isDomainName reports whether s is a syntactically valid domain name
// (see RFC 1035, RFC 3696). A valid name:
//   - is 1 to 255 bytes long;
//   - consists of non-empty labels of at most 63 bytes separated by single
//     dots, with an optional trailing dot;
//   - has labels made of letters, digits, '_' and '-';
//   - has no label that starts or ends with '-';
//   - contains at least one letter or '_'.
func isDomainName(s string) bool {
	if len(s) == 0 || len(s) > 255 {
		return false
	}

	prev := byte('.') // treat the start of s as following a dot
	sawLetter := false
	labelLen := 0
	for i := 0; i < len(s); i++ {
		ch := s[i]
		isLetter := 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_'
		isDigit := '0' <= ch && ch <= '9'
		switch {
		case isLetter:
			sawLetter = true
			labelLen++
		case isDigit:
			labelLen++
		case ch == '-':
			// A label cannot start with a dash.
			if prev == '.' {
				return false
			}
			labelLen++
		case ch == '.':
			// A dot cannot follow a dot or a dash, and it must close a
			// label of 1 to 63 bytes.
			if prev == '.' || prev == '-' || labelLen == 0 || labelLen > 63 {
				return false
			}
			labelLen = 0
		default:
			return false
		}
		prev = ch
	}
	return prev != '-' && labelLen <= 63 && sawLetter
}
diff --git a/dns/dnsmessage/message.go b/dns/dnsmessage/message.go
new file mode 100644
index 0000000..e94b535
--- /dev/null
+++ b/dns/dnsmessage/message.go
@@ -0,0 +1,1429 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package dnsmessage provides a mostly RFC 1035 compliant implementation of
+// DNS message packing and unpacking.
+//
+// This implementation is designed to minimize heap allocations and avoid
+// unnecessary packing and unpacking as much as possible.
+package dnsmessage
+
+import "errors"
+
+// Packet formats
+
// A Type is a type of DNS request and response: the TYPE field of a resource
// record, or the QTYPE field of a Question (see the Type* constants).
type Type uint16

// A Class is a type of network: the CLASS field of a resource record, or the
// QCLASS field of a Question (see the Class* constants).
type Class uint16

// An OpCode is a DNS operation code, carried in the message Header.
type OpCode uint16

// An RCode is a DNS response status code, carried in the message Header
// (see the RCode* constants).
type RCode uint16
+
// Wire constants: the numeric values used in DNS messages on the wire.
const (
	// ResourceHeader.Type and Question.Type
	TypeA     Type = 1
	TypeNS    Type = 2
	TypeCNAME Type = 5
	TypeSOA   Type = 6
	TypePTR   Type = 12
	TypeMX    Type = 15
	TypeTXT   Type = 16
	TypeAAAA  Type = 28
	TypeSRV   Type = 33

	// Question.Type
	TypeWKS   Type = 11
	TypeHINFO Type = 13
	TypeMINFO Type = 14
	TypeAXFR  Type = 252
	TypeALL   Type = 255

	// ResourceHeader.Class and Question.Class
	ClassINET   Class = 1
	ClassCSNET  Class = 2
	ClassCHAOS  Class = 3
	ClassHESIOD Class = 4

	// Question.Class
	ClassANY Class = 255

	// Message.RCode (the response code in the message Header)
	RCodeSuccess        RCode = 0
	RCodeFormatError    RCode = 1
	RCodeServerFailure  RCode = 2
	RCodeNameError      RCode = 3
	RCodeNotImplemented RCode = 4
	RCodeRefused        RCode = 5
)
+
var (
	// ErrNotStarted indicates that the prerequisite information isn't
	// available yet because the previous records haven't been appropriately
	// parsed or skipped.
	ErrNotStarted = errors.New("parsing of this type isn't available yet")

	// ErrSectionDone indicates that all records in the section have been
	// parsed.
	ErrSectionDone = errors.New("parsing of this section has completed")

	// Internal errors reported while packing or unpacking. Most of them
	// reach callers wrapped in a nestedError that names the failing field.
	errBaseLen            = errors.New("insufficient data for base length type")
	errCalcLen            = errors.New("insufficient data for calculated length type")
	errReserved           = errors.New("segment prefix is reserved")
	errTooManyPtr         = errors.New("too many pointers (>10)")
	errInvalidPtr         = errors.New("invalid pointer")
	errResourceLen        = errors.New("insufficient data for resource body length")
	errSegTooLong         = errors.New("segment length too long")
	errZeroSegLen         = errors.New("zero length segment")
	errResTooLong         = errors.New("resource length too long")
	errTooManyQuestions   = errors.New("too many Questions to pack (>65535)")
	errTooManyAnswers     = errors.New("too many Answers to pack (>65535)")
	errTooManyAuthorities = errors.New("too many Authorities to pack (>65535)")
	errTooManyAdditionals = errors.New("too many Additionals to pack (>65535)")
)
+
// nestedError wraps an error with a short description of where it occurred,
// such as a field or section name. Its message has the form "s: err".
type nestedError struct {
	// s is the current level's error message.
	s string

	// err is the nested error.
	err error
}
+
+// nestedError implements error.Error.
+func (e *nestedError) Error() string {
+ return e.s + ": " + e.err.Error()
+}
+
// Header is a representation of a DNS message header.
//
// The boolean fields correspond to the flag bits defined by the headerBit*
// constants. OpCode and RCode occupy 4-bit fields of the packed flags word
// (see header.header).
type Header struct {
	ID                 uint16 // identifier copied from query to response
	Response           bool   // QR: set for responses, clear for queries
	OpCode             OpCode // kind of query
	Authoritative      bool   // AA: answer comes from an authoritative server
	Truncated          bool   // TC: message was truncated
	RecursionDesired   bool   // RD: recursion requested by the client
	RecursionAvailable bool   // RA: server supports recursion
	RCode              RCode  // response status code
}
+
+func (m *Header) pack() (id uint16, bits uint16) {
+ id = m.ID
+ bits = uint16(m.OpCode)<<11 | uint16(m.RCode)
+ if m.RecursionAvailable {
+ bits |= headerBitRA
+ }
+ if m.RecursionDesired {
+ bits |= headerBitRD
+ }
+ if m.Truncated {
+ bits |= headerBitTC
+ }
+ if m.Authoritative {
+ bits |= headerBitAA
+ }
+ if m.Response {
+ bits |= headerBitQR
+ }
+ return
+}
+
// Message is a representation of a DNS message: a Header followed by the
// question, answer, authority and additional sections, in wire order.
type Message struct {
	Header
	Questions   []Question
	Answers     []Resource
	Authorities []Resource
	Additionals []Resource
}
+
// section identifies a part of a DNS message. The values follow the order in
// which sections appear on the wire. The Parser relies on that ordering when
// it compares sections and advances from one to the next (see checkAdvance).
type section uint8

const (
	sectionHeader section = iota
	sectionQuestions
	sectionAnswers
	sectionAuthorities
	sectionAdditionals
	sectionDone // reached after the last Additional has been consumed

	// Flag bits within the 16-bit header flags word.
	headerBitQR = 1 << 15 // query/response (response=1)
	headerBitAA = 1 << 10 // authoritative
	headerBitTC = 1 << 9  // truncated
	headerBitRD = 1 << 8  // recursion desired
	headerBitRA = 1 << 7  // recursion available
)
+
// sectionNames maps each record-carrying section (plus the header) to the
// name used in parser error messages.
var sectionNames = map[section]string{
	sectionHeader:      "header",
	sectionQuestions:   "Question",
	sectionAnswers:     "Answer",
	sectionAuthorities: "Authority",
	sectionAdditionals: "Additional",
}
+
// header is the wire format for a DNS message header: six 16-bit values,
// packed and unpacked in the order declared below.
type header struct {
	id          uint16 // message ID
	bits        uint16 // flags, OpCode and RCode (see Header.pack)
	questions   uint16 // number of Questions
	answers     uint16 // number of Answer records
	authorities uint16 // number of Authority records
	additionals uint16 // number of Additional records
}
+
+func (h *header) count(sec section) uint16 {
+ switch sec {
+ case sectionQuestions:
+ return h.questions
+ case sectionAnswers:
+ return h.answers
+ case sectionAuthorities:
+ return h.authorities
+ case sectionAdditionals:
+ return h.additionals
+ }
+ return 0
+}
+
+func (h *header) pack(msg []byte) []byte {
+ msg = packUint16(msg, h.id)
+ msg = packUint16(msg, h.bits)
+ msg = packUint16(msg, h.questions)
+ msg = packUint16(msg, h.answers)
+ msg = packUint16(msg, h.authorities)
+ return packUint16(msg, h.additionals)
+}
+
+func (h *header) unpack(msg []byte, off int) (int, error) {
+ newOff := off
+ var err error
+ if h.id, newOff, err = unpackUint16(msg, newOff); err != nil {
+ return off, &nestedError{"id", err}
+ }
+ if h.bits, newOff, err = unpackUint16(msg, newOff); err != nil {
+ return off, &nestedError{"bits", err}
+ }
+ if h.questions, newOff, err = unpackUint16(msg, newOff); err != nil {
+ return off, &nestedError{"questions", err}
+ }
+ if h.answers, newOff, err = unpackUint16(msg, newOff); err != nil {
+ return off, &nestedError{"answers", err}
+ }
+ if h.authorities, newOff, err = unpackUint16(msg, newOff); err != nil {
+ return off, &nestedError{"authorities", err}
+ }
+ if h.additionals, newOff, err = unpackUint16(msg, newOff); err != nil {
+ return off, &nestedError{"additionals", err}
+ }
+ return newOff, nil
+}
+
+func (h *header) header() Header {
+ return Header{
+ ID: h.id,
+ Response: (h.bits & headerBitQR) != 0,
+ OpCode: OpCode(h.bits>>11) & 0xF,
+ Authoritative: (h.bits & headerBitAA) != 0,
+ Truncated: (h.bits & headerBitTC) != 0,
+ RecursionDesired: (h.bits & headerBitRD) != 0,
+ RecursionAvailable: (h.bits & headerBitRA) != 0,
+ RCode: RCode(h.bits & 0xF),
+ }
+}
+
// A Resource is a DNS resource record.
type Resource interface {
	// Header returns the Resource's ResourceHeader. Packing writes to it
	// (Type and Length are filled in by packResource), so it must be a
	// pointer into the Resource itself.
	Header() *ResourceHeader

	// pack packs a Resource except for its header.
	pack(msg []byte, compression map[string]int) ([]byte, error)

	// realType returns the actual type of the Resource. This is used to
	// fill in the header Type field.
	realType() Type
}
+
+func packResource(msg []byte, resource Resource, compression map[string]int) ([]byte, error) {
+ oldMsg := msg
+ resource.Header().Type = resource.realType()
+ msg, length, err := resource.Header().pack(msg, compression)
+ if err != nil {
+ return msg, &nestedError{"ResourceHeader", err}
+ }
+ preLen := len(msg)
+ msg, err = resource.pack(msg, compression)
+ if err != nil {
+ return msg, &nestedError{"content", err}
+ }
+ conLen := len(msg) - preLen
+ if conLen > int(^uint16(0)) {
+ return oldMsg, errResTooLong
+ }
+ // Fill in the length now that we know how long the content is.
+ packUint16(length[:0], uint16(conLen))
+ resource.Header().Length = uint16(conLen)
+ return msg, nil
+}
+
// A Parser allows incrementally parsing a DNS message.
//
// When parsing is started, the Header is parsed. Next, each Question can be
// either parsed or skipped. Alternatively, all Questions can be skipped at
// once. When all Questions have been parsed, attempting to parse Questions
// will return (nil, nil) and attempting to skip Questions will return
// (true, nil). After all Questions have been either parsed or skipped, all
// Answers, Authorities and Additionals can be either parsed or skipped in the
// same way, and each type of Resource must be fully parsed or skipped before
// proceeding to the next type of Resource.
//
// NOTE(review): the paragraph above predates the current API. In the code
// below, exhausting a section is signalled by ErrSectionDone rather than by
// (nil, nil) or (true, nil).
//
// Note that there is no requirement to fully skip or parse the message.
type Parser struct {
	msg    []byte // message being parsed
	header header // unpacked wire header

	section        section        // section currently being parsed
	off            int            // offset in msg of the next unread byte
	index          int            // number of records consumed in the current section
	resHeaderValid bool           // resHeader was parsed but its body has not been consumed yet
	resHeader      ResourceHeader // most recent header parsed by resourceHeader
}
+
+// Start parses the header and enables the parsing of Questions.
+func (p *Parser) Start(msg []byte) (Header, error) {
+ if p.msg != nil {
+ *p = Parser{}
+ }
+ p.msg = msg
+ var err error
+ if p.off, err = p.header.unpack(msg, 0); err != nil {
+ return Header{}, &nestedError{"unpacking header", err}
+ }
+ p.section = sectionQuestions
+ return p.header.header(), nil
+}
+
+func (p *Parser) checkAdvance(sec section) error {
+ if p.section < sec {
+ return ErrNotStarted
+ }
+ if p.section > sec {
+ return ErrSectionDone
+ }
+ p.resHeaderValid = false
+ if p.index == int(p.header.count(sec)) {
+ p.index = 0
+ p.section++
+ return ErrSectionDone
+ }
+ return nil
+}
+
+func (p *Parser) resource(sec section) (Resource, error) {
+ var r Resource
+ hdr, err := p.resourceHeader(sec)
+ if err != nil {
+ return r, err
+ }
+ p.resHeaderValid = false
+ r, p.off, err = unpackResource(p.msg, p.off, hdr)
+ if err != nil {
+ return nil, &nestedError{"unpacking " + sectionNames[sec], err}
+ }
+ p.index++
+ return r, nil
+}
+
+func (p *Parser) resourceHeader(sec section) (ResourceHeader, error) {
+ if p.resHeaderValid {
+ return p.resHeader, nil
+ }
+ if err := p.checkAdvance(sec); err != nil {
+ return ResourceHeader{}, err
+ }
+ var hdr ResourceHeader
+ off, err := hdr.unpack(p.msg, p.off)
+ if err != nil {
+ return ResourceHeader{}, err
+ }
+ p.resHeaderValid = true
+ p.resHeader = hdr
+ p.off = off
+ return hdr, nil
+}
+
+func (p *Parser) skipResource(sec section) error {
+ if p.resHeaderValid {
+ newOff := p.off + int(p.resHeader.Length)
+ if newOff > len(p.msg) {
+ return errResourceLen
+ }
+ p.off = newOff
+ p.resHeaderValid = false
+ p.index++
+ return nil
+ }
+ if err := p.checkAdvance(sec); err != nil {
+ return err
+ }
+ var err error
+ p.off, err = skipResource(p.msg, p.off)
+ if err != nil {
+ return &nestedError{"skipping: " + sectionNames[sec], err}
+ }
+ p.index++
+ return nil
+}
+
+// Question parses a single Question.
+func (p *Parser) Question() (Question, error) {
+ if err := p.checkAdvance(sectionQuestions); err != nil {
+ return Question{}, err
+ }
+ name, off, err := unpackName(p.msg, p.off)
+ if err != nil {
+ return Question{}, &nestedError{"unpacking Question.Name", err}
+ }
+ typ, off, err := unpackType(p.msg, off)
+ if err != nil {
+ return Question{}, &nestedError{"unpacking Question.Type", err}
+ }
+ class, off, err := unpackClass(p.msg, off)
+ if err != nil {
+ return Question{}, &nestedError{"unpacking Question.Class", err}
+ }
+ p.off = off
+ p.index++
+ return Question{name, typ, class}, nil
+}
+
+// AllQuestions parses all Questions.
+func (p *Parser) AllQuestions() ([]Question, error) {
+ qs := make([]Question, 0, p.header.questions)
+ for {
+ q, err := p.Question()
+ if err == ErrSectionDone {
+ return qs, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ qs = append(qs, q)
+ }
+}
+
+// SkipQuestion skips a single Question.
+func (p *Parser) SkipQuestion() error {
+ if err := p.checkAdvance(sectionQuestions); err != nil {
+ return err
+ }
+ off, err := skipName(p.msg, p.off)
+ if err != nil {
+ return &nestedError{"skipping Question Name", err}
+ }
+ if off, err = skipType(p.msg, off); err != nil {
+ return &nestedError{"skipping Question Type", err}
+ }
+ if off, err = skipClass(p.msg, off); err != nil {
+ return &nestedError{"skipping Question Class", err}
+ }
+ p.off = off
+ p.index++
+ return nil
+}
+
+// SkipAllQuestions skips all Questions.
+func (p *Parser) SkipAllQuestions() error {
+ for {
+ if err := p.SkipQuestion(); err == ErrSectionDone {
+ return nil
+ } else if err != nil {
+ return err
+ }
+ }
+}
+
+// AnswerHeader parses a single Answer ResourceHeader.
+func (p *Parser) AnswerHeader() (ResourceHeader, error) {
+ return p.resourceHeader(sectionAnswers)
+}
+
+// Answer parses a single Answer Resource.
+func (p *Parser) Answer() (Resource, error) {
+ return p.resource(sectionAnswers)
+}
+
+// AllAnswers parses all Answer Resources.
+func (p *Parser) AllAnswers() ([]Resource, error) {
+ as := make([]Resource, 0, p.header.answers)
+ for {
+ a, err := p.Answer()
+ if err == ErrSectionDone {
+ return as, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ as = append(as, a)
+ }
+}
+
+// SkipAnswer skips a single Answer Resource.
+func (p *Parser) SkipAnswer() error {
+ return p.skipResource(sectionAnswers)
+}
+
+// SkipAllAnswers skips all Answer Resources.
+func (p *Parser) SkipAllAnswers() error {
+ for {
+ if err := p.SkipAnswer(); err == ErrSectionDone {
+ return nil
+ } else if err != nil {
+ return err
+ }
+ }
+}
+
+// AuthorityHeader parses a single Authority ResourceHeader.
+func (p *Parser) AuthorityHeader() (ResourceHeader, error) {
+ return p.resourceHeader(sectionAuthorities)
+}
+
+// Authority parses a single Authority Resource.
+func (p *Parser) Authority() (Resource, error) {
+ return p.resource(sectionAuthorities)
+}
+
+// AllAuthorities parses all Authority Resources.
+func (p *Parser) AllAuthorities() ([]Resource, error) {
+ as := make([]Resource, 0, p.header.authorities)
+ for {
+ a, err := p.Authority()
+ if err == ErrSectionDone {
+ return as, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ as = append(as, a)
+ }
+}
+
+// SkipAuthority skips a single Authority Resource.
+func (p *Parser) SkipAuthority() error {
+ return p.skipResource(sectionAuthorities)
+}
+
+// SkipAllAuthorities skips all Authority Resources.
+func (p *Parser) SkipAllAuthorities() error {
+ for {
+ if err := p.SkipAuthority(); err == ErrSectionDone {
+ return nil
+ } else if err != nil {
+ return err
+ }
+ }
+}
+
+// AdditionalHeader parses the ResourceHeader of the next Additional Resource.
+func (p *Parser) AdditionalHeader() (ResourceHeader, error) {
+	return p.resourceHeader(sectionAdditionals)
+}
+
+// Additional parses the next Additional Resource in full.
+func (p *Parser) Additional() (Resource, error) {
+	return p.resource(sectionAdditionals)
+}
+
+// AllAdditionals parses every remaining Additional Resource. It returns the
+// collected Resources once the section is exhausted, or nil together with the
+// first error other than ErrSectionDone.
+func (p *Parser) AllAdditionals() ([]Resource, error) {
+	// Pre-size from the count announced in the message header.
+	rs := make([]Resource, 0, p.header.additionals)
+	for {
+		r, err := p.Additional()
+		switch err {
+		case nil:
+			rs = append(rs, r)
+		case ErrSectionDone:
+			return rs, nil
+		default:
+			return nil, err
+		}
+	}
+}
+
+// SkipAdditional advances past a single Additional Resource without decoding it.
+func (p *Parser) SkipAdditional() error {
+	return p.skipResource(sectionAdditionals)
+}
+
+// SkipAllAdditionals skips every remaining Additional Resource; ErrSectionDone
+// is treated as success.
+func (p *Parser) SkipAllAdditionals() error {
+	for {
+		switch err := p.SkipAdditional(); err {
+		case nil:
+			// Keep skipping.
+		case ErrSectionDone:
+			return nil
+		default:
+			return err
+		}
+	}
+}
+
+// Unpack parses a full Message from msg.
+//
+// The header is decoded first, then the question, answer, authority and
+// additional sections in wire order. Parsing stops at the first error, which
+// is returned unchanged; m may be partially populated in that case.
+func (m *Message) Unpack(msg []byte) error {
+	var (
+		p   Parser
+		err error
+	)
+	m.Header, err = p.Start(msg)
+	if err != nil {
+		return err
+	}
+	m.Questions, err = p.AllQuestions()
+	if err != nil {
+		return err
+	}
+	m.Answers, err = p.AllAnswers()
+	if err != nil {
+		return err
+	}
+	m.Authorities, err = p.AllAuthorities()
+	if err != nil {
+		return err
+	}
+	m.Additionals, err = p.AllAdditionals()
+	return err
+}
+
+// Pack packs a full Message into DNS wire format.
+//
+// Each section can hold at most 65535 entries because the header stores the
+// counts as 16-bit values. Exceeding that yields errTooManyQuestions,
+// errTooManyAnswers, errTooManyAuthorities or errTooManyAdditionals. Any
+// error while packing an entry is wrapped with the section it came from.
+func (m *Message) Pack() ([]byte, error) {
+	const maxCount = int(^uint16(0))
+
+	// Validate the lengths. It is very unlikely that anyone will try to pack
+	// more than 65535 of any particular type, but it is possible and we
+	// should fail gracefully instead of truncating the 16-bit counts.
+	switch {
+	case len(m.Questions) > maxCount:
+		return nil, errTooManyQuestions
+	case len(m.Answers) > maxCount:
+		return nil, errTooManyAnswers
+	case len(m.Authorities) > maxCount:
+		return nil, errTooManyAuthorities
+	case len(m.Additionals) > maxCount:
+		return nil, errTooManyAdditionals
+	}
+
+	var h header
+	h.id, h.bits = m.Header.pack()
+	h.questions = uint16(len(m.Questions))
+	h.answers = uint16(len(m.Answers))
+	h.authorities = uint16(len(m.Authorities))
+	h.additionals = uint16(len(m.Additionals))
+
+	// The starting capacity doesn't matter too much, but most DNS responses
+	// will be <= 512 bytes as that is the limit for DNS over UDP.
+	msg := h.pack(make([]byte, 0, 512))
+
+	// RFC 1035 allows (but does not require) compression for packing and
+	// requires unpacking implementations to support it, so unconditionally
+	// enabling it is fine. Without compression many responses would exceed
+	// the 512-byte UDP limit, so it also helps ensure compliance.
+	compression := map[string]int{}
+
+	var err error
+	for _, q := range m.Questions {
+		if msg, err = q.pack(msg, compression); err != nil {
+			return nil, &nestedError{"packing Question", err}
+		}
+	}
+
+	// The three resource sections are packed identically, in wire order.
+	sections := [...]struct {
+		what string
+		rs   []Resource
+	}{
+		{"packing Answer", m.Answers},
+		{"packing Authority", m.Authorities},
+		{"packing Additional", m.Additionals},
+	}
+	for _, sec := range sections {
+		for _, r := range sec.rs {
+			if msg, err = packResource(msg, r, compression); err != nil {
+				return nil, &nestedError{sec.what, err}
+			}
+		}
+	}
+	return msg, nil
+}
+
+// A ResourceHeader is the header of a DNS resource record. There are
+// many types of DNS resource records, but they all share the same header.
+type ResourceHeader struct {
+	// Name is the domain name to which this resource record pertains.
+	Name string
+
+	// Type is the type of DNS resource record.
+	//
+	// This field will be set automatically during packing.
+	Type Type
+
+	// Class is the class of network to which this DNS resource record
+	// pertains.
+	Class Class
+
+	// TTL is the length of time (measured in seconds) which this resource
+	// record is valid for (time to live). All Resources in a set should
+	// have the same TTL (RFC 2181 Section 5.2).
+	TTL uint32
+
+	// Length is the length of data in the resource record after the header.
+	//
+	// This field will be set automatically during packing.
+	Length uint16
+}
+
+// Header implements Resource.Header.
+//
+// It returns h itself rather than a copy, so the header of any record type
+// that embeds ResourceHeader can be read and modified in place through it.
+func (h *ResourceHeader) Header() *ResourceHeader {
+	return h
+}
+
+// pack appends the wire form of h to oldMsg: the (possibly compressed) Name,
+// then Type, Class and TTL, then a two-byte placeholder holding h.Length.
+//
+// The placeholder is returned as length so the caller can fill in the real
+// data length after the rest of the Resource has been packed. Note that
+// length aliases the backing array of the returned msg: if the caller appends
+// more data and append has to reallocate, later writes through length land in
+// the old array and do not show up in the new msg.
+//
+// On error oldMsg is returned unchanged together with a nil length.
+func (h *ResourceHeader) pack(oldMsg []byte, compression map[string]int) (msg []byte, length []byte, err error) {
+	msg, err = packName(oldMsg, h.Name, compression)
+	if err != nil {
+		return oldMsg, nil, &nestedError{"Name", err}
+	}
+	msg = packClass(packType(msg, h.Type), h.Class)
+	msg = packUint32(msg, h.TTL)
+
+	// Remember where the length field starts so it can be handed back.
+	at := len(msg)
+	msg = packUint16(msg, h.Length)
+	return msg, msg[at:], nil
+}
+
+// unpack decodes a resource record header starting at msg[off] into h and
+// returns the offset just past it.
+//
+// On failure it returns the original off and an error naming the field that
+// could not be decoded. Fields decoded before the failure, and the failing
+// field itself, have already been overwritten in h.
+func (h *ResourceHeader) unpack(msg []byte, off int) (int, error) {
+	var err error
+	cur := off
+
+	h.Name, cur, err = unpackName(msg, cur)
+	if err != nil {
+		return off, &nestedError{"Name", err}
+	}
+	h.Type, cur, err = unpackType(msg, cur)
+	if err != nil {
+		return off, &nestedError{"Type", err}
+	}
+	h.Class, cur, err = unpackClass(msg, cur)
+	if err != nil {
+		return off, &nestedError{"Class", err}
+	}
+	h.TTL, cur, err = unpackUint32(msg, cur)
+	if err != nil {
+		return off, &nestedError{"TTL", err}
+	}
+	h.Length, cur, err = unpackUint16(msg, cur)
+	if err != nil {
+		return off, &nestedError{"Length", err}
+	}
+	return cur, nil
+}
+
+func skipResource(msg []byte, off int) (int, error) {
+ newOff, err := skipName(msg, off)
+ if err != nil {
+ return off, &nestedError{"Name", err}
+ }
+ if newOff, err = skipType(msg, newOff); err != nil {
+ return off, &nestedError{"Type", err}
+ }
+ if newOff, err = skipClass(msg, newOff); err != nil {
+ return off, &nestedError{"Class", err}
+ }
+ if newOff, err = skipUint32(msg, newOff); err != nil {
+ return off, &nestedError{"TTL", err}
+ }
+ length, newOff, err := unpackUint16(msg, newOff)
+ if err != nil {
+ return off, &nestedError{"Length", err}
+ }
+ if newOff += int(length); newOff > len(msg) {
+ return off, errResourceLen
+ }
+ return newOff, nil
+}
+
+// packUint16 appends field to msg in big-endian (network) byte order.
+func packUint16(msg []byte, field uint16) []byte {
+	hi, lo := byte(field>>8), byte(field)
+	return append(msg, hi, lo)
+}
+
+// unpackUint16 reads a big-endian uint16 at msg[off] and returns it with the
+// offset just past it, or errBaseLen (and off unchanged) if fewer than two
+// bytes remain.
+func unpackUint16(msg []byte, off int) (uint16, int, error) {
+	end := off + 2
+	if end > len(msg) {
+		return 0, off, errBaseLen
+	}
+	v := uint16(msg[off])<<8 | uint16(msg[off+1])
+	return v, end, nil
+}
+
+// skipUint16 advances past a two-byte field, or returns errBaseLen if the
+// field does not fit in msg.
+func skipUint16(msg []byte, off int) (int, error) {
+	if end := off + 2; end <= len(msg) {
+		return end, nil
+	}
+	return off, errBaseLen
+}
+
+// packType appends the 16-bit wire form of a record Type.
+func packType(msg []byte, field Type) []byte {
+	v := uint16(field)
+	return packUint16(msg, v)
+}
+
+// unpackType reads a 16-bit record Type at msg[off]; offsets and errors
+// behave as in unpackUint16.
+func unpackType(msg []byte, off int) (Type, int, error) {
+	v, next, err := unpackUint16(msg, off)
+	return Type(v), next, err
+}
+
+// skipType advances past a 16-bit record Type.
+func skipType(msg []byte, off int) (int, error) {
+	next, err := skipUint16(msg, off)
+	return next, err
+}
+
+// packClass appends the 16-bit wire form of a record Class.
+func packClass(msg []byte, field Class) []byte {
+	v := uint16(field)
+	return packUint16(msg, v)
+}
+
+// unpackClass reads a 16-bit record Class at msg[off]; offsets and errors
+// behave as in unpackUint16.
+func unpackClass(msg []byte, off int) (Class, int, error) {
+	v, next, err := unpackUint16(msg, off)
+	return Class(v), next, err
+}
+
+// skipClass advances past a 16-bit record Class.
+func skipClass(msg []byte, off int) (int, error) {
+	next, err := skipUint16(msg, off)
+	return next, err
+}
+
+// packUint32 appends field to msg in big-endian (network) byte order.
+func packUint32(msg []byte, field uint32) []byte {
+	b0, b1, b2, b3 := byte(field>>24), byte(field>>16), byte(field>>8), byte(field)
+	return append(msg, b0, b1, b2, b3)
+}
+
+// unpackUint32 reads a big-endian uint32 at msg[off] and returns it with the
+// offset just past it, or errBaseLen (and off unchanged) if fewer than four
+// bytes remain.
+func unpackUint32(msg []byte, off int) (uint32, int, error) {
+	end := off + 4
+	if end > len(msg) {
+		return 0, off, errBaseLen
+	}
+	b := msg[off:end]
+	return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]), end, nil
+}
+
+// skipUint32 advances past a four-byte field, or returns errBaseLen if the
+// field does not fit in msg.
+func skipUint32(msg []byte, off int) (int, error) {
+	if end := off + 4; end <= len(msg) {
+		return end, nil
+	}
+	return off, errBaseLen
+}
+
+func packText(msg []byte, field string) []byte {
+ for len(field) > 0 {
+ l := len(field)
+ if l > 255 {
+ l = 255
+ }
+ msg = append(msg, byte(l))
+ msg = append(msg, field[:l]...)
+ field = field[l:]
+ }
+ return msg
+}
+
+func unpackText(msg []byte, off int) (string, int, error) {
+ if off >= len(msg) {
+ return "", off, errBaseLen
+ }
+ beginOff := off + 1
+ endOff := beginOff + int(msg[off])
+ if endOff > len(msg) {
+ return "", off, errCalcLen
+ }
+ return string(msg[beginOff:endOff]), endOff, nil
+}
+
+func skipText(msg []byte, off int) (int, error) {
+ if off >= len(msg) {
+ return off, errBaseLen
+ }
+ endOff := off + 1 + int(msg[off])
+ if endOff > len(msg) {
+ return off, errCalcLen
+ }
+ return endOff, nil
+}
+
+// packBytes appends field to msg verbatim.
+func packBytes(msg []byte, field []byte) []byte {
+	out := append(msg, field...)
+	return out
+}
+
+// unpackBytes copies len(field) bytes starting at msg[off] into field and
+// returns the offset just past them. If msg is too short it returns
+// errBaseLen, leaves field untouched and returns off unchanged.
+func unpackBytes(msg []byte, off int, field []byte) (int, error) {
+	end := off + len(field)
+	if end > len(msg) {
+		return off, errBaseLen
+	}
+	return off + copy(field, msg[off:end]), nil
+}
+
+// skipBytes advances past len(field) bytes; field is used only for its length.
+func skipBytes(msg []byte, off int, field []byte) (int, error) {
+	if end := off + len(field); end <= len(msg) {
+		return end, nil
+	}
+	return off, errBaseLen
+}
+
+// packName packs a domain name.
+//
+// Domain names are a sequence of counted strings split at the dots. They end
+// with a zero-length string. Compression can be used to reuse domain suffixes.
+//
+// The compression map will be updated with new domain suffixes. If compression
+// is nil, compression will not be used.
+func packName(msg []byte, name string, compression map[string]int) ([]byte, error) {
+ oldMsg := msg
+
+ // Add a trailing dot to canonicalize name.
+ if n := len(name); n == 0 || name[n-1] != '.' {
+ name += "."
+ }
+
+ // Allow root domain.
+ if name == "." {
+ return append(msg, 0), nil
+ }
+
+ // Emit sequence of counted strings, chopping at dots.
+ for i, begin := 0, 0; i < len(name); i++ {
+ // Check for the end of the segment.
+ if name[i] == '.' {
+ // The two most significant bits have special meaning.
+ // It isn't allowed for segments to be long enough to
+ // need them.
+ if i-begin >= 1<<6 {
+ return oldMsg, errSegTooLong
+ }
+
+ // Segments must have a non-zero length.
+ if i-begin == 0 {
+ return oldMsg, errZeroSegLen
+ }
+
+ msg = append(msg, byte(i-begin))
+
+ for j := begin; j < i; j++ {
+ msg = append(msg, name[j])
+ }
+
+ begin = i + 1
+ continue
+ }
+
+ // We can only compress domain suffixes starting with a new
+ // segment. A pointer is two bytes with the two most significant
+ // bits set to 1 to indicate that it is a pointer.
+ if (i == 0 || name[i-1] == '.') && compression != nil {
+ if ptr, ok := compression[name[i:]]; ok {
+ // Hit. Emit a pointer instead of the rest of
+ // the domain.
+ return append(msg, byte(ptr>>8|0xC0), byte(ptr)), nil
+ }
+
+ // Miss. Add the suffix to the compression table if the
+ // offset can be stored in the available 14 bytes.
+ if len(msg) <= int(^uint16(0)>>2) {
+ compression[name[i:]] = len(msg)
+ }
+ }
+ }
+ return append(msg, 0), nil
+}
+
+// unpackName unpacks a domain name.
+//
+// It decodes the (possibly compressed) name starting at msg[off] and returns
+// it in dotted form with a trailing dot ("." for the root name). The returned
+// offset is where the next field begins. That is just past the terminating
+// zero byte, or just past the first compression pointer if one was followed.
+// Bytes reached through pointers belong to other names and are not counted.
+//
+// On error off is returned unchanged:
+//   - errBaseLen when decoding runs off the end of msg,
+//   - errCalcLen when a label's data is truncated,
+//   - errInvalidPtr when a pointer's second byte is missing,
+//   - errTooManyPtr after more than 10 pointers (guards against loops),
+//   - errReserved for the reserved 0x40/0x80 label prefixes.
+func unpackName(msg []byte, off int) (string, int, error) {
+	// currOff is the current working offset.
+	currOff := off
+
+	// newOff is the offset where the next record will start. Pointers lead
+	// to data that belongs to other names and thus doesn't count towards
+	// the usage of this name.
+	newOff := off
+
+	// name is the domain name being unpacked.
+	// NOTE(review): 255 is only the initial capacity; the decoded name's
+	// total length is never checked against RFC 1035's 255-octet limit.
+	name := make([]byte, 0, 255)
+
+	// ptr is the number of pointers followed.
+	var ptr int
+Loop:
+	for {
+		if currOff >= len(msg) {
+			return "", off, errBaseLen
+		}
+		c := int(msg[currOff])
+		currOff++
+		// The top two bits of the length byte select the label kind.
+		switch c & 0xC0 {
+		case 0x00: // String segment
+			if c == 0x00 {
+				// A zero length signals the end of the name.
+				break Loop
+			}
+			endOff := currOff + c
+			if endOff > len(msg) {
+				return "", off, errCalcLen
+			}
+			name = append(name, msg[currOff:endOff]...)
+			name = append(name, '.')
+			currOff = endOff
+		case 0xC0: // Pointer
+			if currOff >= len(msg) {
+				return "", off, errInvalidPtr
+			}
+			c1 := msg[currOff]
+			currOff++
+			// This name's own bytes end at its first pointer; remember
+			// where the following field starts before jumping away.
+			if ptr == 0 {
+				newOff = currOff
+			}
+			// Don't follow too many pointers, maybe there's a loop.
+			if ptr++; ptr > 10 {
+				return "", off, errTooManyPtr
+			}
+			// Clear the two tag bits to get the 14-bit target offset.
+			currOff = (c^0xC0)<<8 | int(c1)
+		default:
+			// Prefixes 0x80 and 0x40 are reserved.
+			return "", off, errReserved
+		}
+	}
+	// No labels at all means the root name.
+	if len(name) == 0 {
+		name = append(name, '.')
+	}
+	// Without pointers the name ends right after its zero terminator.
+	if ptr == 0 {
+		newOff = currOff
+	}
+	return string(name), newOff, nil
+}
+
+// skipName advances past the domain name starting at msg[off] without
+// decoding it, and returns the offset where the next field begins.
+//
+// A compression pointer ends this name's own bytes, so the pointer is not
+// followed; only its two bytes are consumed. On error off is returned
+// unchanged:
+//   - errBaseLen if the name runs off the end of msg,
+//   - errCalcLen if a label's data is truncated,
+//   - errInvalidPtr if a pointer's second byte is missing (as unpackName does),
+//   - errReserved for the reserved 0x40/0x80 label prefixes.
+func skipName(msg []byte, off int) (int, error) {
+	// newOff is the offset where the next record will start. Pointers lead
+	// to data that belongs to other names and thus doesn't count towards
+	// the usage of this name.
+	newOff := off
+
+Loop:
+	for {
+		if newOff >= len(msg) {
+			return off, errBaseLen
+		}
+		c := int(msg[newOff])
+		newOff++
+		switch c & 0xC0 {
+		case 0x00:
+			if c == 0x00 {
+				// A zero length signals the end of the name.
+				break Loop
+			}
+			// literal string
+			newOff += c
+			if newOff > len(msg) {
+				return off, errCalcLen
+			}
+		case 0xC0:
+			// Pointer to somewhere else in msg.
+
+			// Pointers are two bytes. The second one must actually be
+			// present; otherwise the returned offset would point past
+			// the end of msg while reporting success.
+			newOff++
+			if newOff > len(msg) {
+				return off, errInvalidPtr
+			}
+
+			// Don't follow the pointer as the data here has ended.
+			break Loop
+		default:
+			// Prefixes 0x80 and 0x40 are reserved.
+			return off, errReserved
+		}
+	}
+
+	return newOff, nil
+}
+
+// A Question is a DNS query: a Name together with the record Type and Class
+// being asked for.
+type Question struct {
+	Name  string
+	Type  Type
+	Class Class
+}
+
+// pack appends the wire form of q (Name, then Type, then Class) to msg. If the
+// name cannot be packed, the error is wrapped with the field name and the
+// slice returned by packName is passed through.
+func (q *Question) pack(msg []byte, compression map[string]int) ([]byte, error) {
+	out, err := packName(msg, q.Name, compression)
+	if err != nil {
+		return out, &nestedError{"Name", err}
+	}
+	return packClass(packType(out, q.Type), q.Class), nil
+}
+
+// unpackResource decodes the data of a resource record whose header hdr has
+// already been parsed. off is the offset of that data within msg. It returns
+// the typed Resource and the offset just past the data (off+hdr.Length).
+//
+// On error off is returned unchanged:
+//   - errResourceLen when hdr.Length extends past the end of msg,
+//   - a decoder error wrapped with the record type name,
+//   - an "invalid resource type: N" error for types this package does not
+//     decode.
+func unpackResource(msg []byte, off int, hdr ResourceHeader) (Resource, int, error) {
+	// The record data must lie inside msg. Without this check the fixed-size
+	// decoders below would read into whatever follows a truncated record,
+	// and the returned offset could point beyond the end of the message.
+	end := off + int(hdr.Length)
+	if end > len(msg) {
+		return nil, off, errResourceLen
+	}
+
+	var (
+		r    Resource
+		err  error
+		name string
+	)
+	switch hdr.Type {
+	case TypeA:
+		r, err = unpackAResource(hdr, msg, off)
+		name = "A"
+	case TypeNS:
+		r, err = unpackNSResource(hdr, msg, off)
+		name = "NS"
+	case TypeCNAME:
+		r, err = unpackCNAMEResource(hdr, msg, off)
+		name = "CNAME"
+	case TypeSOA:
+		r, err = unpackSOAResource(hdr, msg, off)
+		name = "SOA"
+	case TypePTR:
+		r, err = unpackPTRResource(hdr, msg, off)
+		name = "PTR"
+	case TypeMX:
+		r, err = unpackMXResource(hdr, msg, off)
+		name = "MX"
+	case TypeTXT:
+		r, err = unpackTXTResource(hdr, msg, off)
+		name = "TXT"
+	case TypeAAAA:
+		r, err = unpackAAAAResource(hdr, msg, off)
+		name = "AAAA"
+	case TypeSRV:
+		r, err = unpackSRVResource(hdr, msg, off)
+		name = "SRV"
+	}
+	if err != nil {
+		return nil, off, &nestedError{name + " record", err}
+	}
+	if r != nil {
+		return r, end, nil
+	}
+
+	// Render the numeric type in decimal. Converting the integer directly
+	// with string(...) yields a single rune, which is wrong for anything
+	// but 0-9 and is rejected by vet's stringintconv check.
+	var digits [5]byte
+	i := len(digits)
+	for t := uint16(hdr.Type); ; t /= 10 {
+		i--
+		digits[i] = byte('0' + t%10)
+		if t < 10 {
+			break
+		}
+	}
+	return nil, off, errors.New("invalid resource type: " + string(digits[i:]))
+}
+
+// A CNAMEResource is a CNAME Resource record.
+type CNAMEResource struct {
+	ResourceHeader
+
+	CNAME string
+}
+
+// realType returns TypeCNAME, the record type this struct represents.
+func (r *CNAMEResource) realType() Type {
+	return TypeCNAME
+}
+
+// pack appends the (possibly compressed) canonical name as the record data.
+func (r *CNAMEResource) pack(msg []byte, compression map[string]int) ([]byte, error) {
+	out, err := packName(msg, r.CNAME, compression)
+	return out, err
+}
+
+// unpackCNAMEResource decodes CNAME data at msg[off]. The offset after the
+// name is discarded because unpackResource advances by hdr.Length instead.
+func unpackCNAMEResource(hdr ResourceHeader, msg []byte, off int) (*CNAMEResource, error) {
+	target, _, err := unpackName(msg, off)
+	if err != nil {
+		return nil, err
+	}
+	return &CNAMEResource{ResourceHeader: hdr, CNAME: target}, nil
+}
+
+// An MXResource is an MX Resource record.
+type MXResource struct {
+	ResourceHeader
+
+	Pref uint16
+	MX   string
+}
+
+// realType returns TypeMX, the record type this struct represents.
+func (r *MXResource) realType() Type {
+	return TypeMX
+}
+
+// pack appends the 16-bit preference followed by the (possibly compressed)
+// exchange name. On failure the caller's msg is returned unchanged.
+func (r *MXResource) pack(msg []byte, compression map[string]int) ([]byte, error) {
+	out, err := packName(packUint16(msg, r.Pref), r.MX, compression)
+	if err != nil {
+		return msg, &nestedError{"MXResource.MX", err}
+	}
+	return out, nil
+}
+
+// unpackMXResource decodes MX data (preference, then exchange name) at
+// msg[off], wrapping errors with the failing field's name.
+func unpackMXResource(hdr ResourceHeader, msg []byte, off int) (*MXResource, error) {
+	pref, next, err := unpackUint16(msg, off)
+	if err != nil {
+		return nil, &nestedError{"Pref", err}
+	}
+	exchange, _, err := unpackName(msg, next)
+	if err != nil {
+		return nil, &nestedError{"MX", err}
+	}
+	return &MXResource{ResourceHeader: hdr, Pref: pref, MX: exchange}, nil
+}
+
+// An NSResource is an NS Resource record.
+type NSResource struct {
+	ResourceHeader
+
+	NS string
+}
+
+// realType returns TypeNS, the record type this struct represents.
+func (r *NSResource) realType() Type {
+	return TypeNS
+}
+
+// pack appends the (possibly compressed) name server name as the record data.
+func (r *NSResource) pack(msg []byte, compression map[string]int) ([]byte, error) {
+	out, err := packName(msg, r.NS, compression)
+	return out, err
+}
+
+// unpackNSResource decodes NS data at msg[off]. The offset after the name is
+// discarded because unpackResource advances by hdr.Length instead.
+func unpackNSResource(hdr ResourceHeader, msg []byte, off int) (*NSResource, error) {
+	server, _, err := unpackName(msg, off)
+	if err != nil {
+		return nil, err
+	}
+	return &NSResource{ResourceHeader: hdr, NS: server}, nil
+}
+
+// A PTRResource is a PTR Resource record.
+type PTRResource struct {
+	ResourceHeader
+
+	PTR string
+}
+
+// realType returns TypePTR, the record type this struct represents.
+func (r *PTRResource) realType() Type {
+	return TypePTR
+}
+
+// pack appends the (possibly compressed) target name as the record data.
+func (r *PTRResource) pack(msg []byte, compression map[string]int) ([]byte, error) {
+	out, err := packName(msg, r.PTR, compression)
+	return out, err
+}
+
+// unpackPTRResource decodes PTR data at msg[off]. The offset after the name is
+// discarded because unpackResource advances by hdr.Length instead.
+func unpackPTRResource(hdr ResourceHeader, msg []byte, off int) (*PTRResource, error) {
+	target, _, err := unpackName(msg, off)
+	if err != nil {
+		return nil, err
+	}
+	return &PTRResource{ResourceHeader: hdr, PTR: target}, nil
+}
+
+// An SOAResource is an SOA Resource record.
+type SOAResource struct {
+	ResourceHeader
+
+	NS      string
+	MBox    string
+	Serial  uint32
+	Refresh uint32
+	Retry   uint32
+	Expire  uint32
+
+	// MinTTL is the default TTL of Resources records which did not
+	// contain a TTL value and the TTL of negative responses. (RFC 2308
+	// Section 4)
+	MinTTL uint32
+}
+
+// realType returns TypeSOA, the record type this struct represents.
+func (r *SOAResource) realType() Type {
+	return TypeSOA
+}
+
+// pack appends the two (possibly compressed) names followed by the five
+// 32-bit counters in wire order. On failure the caller's msg is returned
+// unchanged.
+func (r *SOAResource) pack(msg []byte, compression map[string]int) ([]byte, error) {
+	out, err := packName(msg, r.NS, compression)
+	if err != nil {
+		return msg, &nestedError{"SOAResource.NS", err}
+	}
+	if out, err = packName(out, r.MBox, compression); err != nil {
+		return msg, &nestedError{"SOAResource.MBox", err}
+	}
+	for _, v := range [...]uint32{r.Serial, r.Refresh, r.Retry, r.Expire, r.MinTTL} {
+		out = packUint32(out, v)
+	}
+	return out, nil
+}
+
+// unpackSOAResource decodes SOA data at msg[off], wrapping any error with the
+// name of the field that failed.
+func unpackSOAResource(hdr ResourceHeader, msg []byte, off int) (*SOAResource, error) {
+	soa := &SOAResource{ResourceHeader: hdr}
+	var err error
+	if soa.NS, off, err = unpackName(msg, off); err != nil {
+		return nil, &nestedError{"NS", err}
+	}
+	if soa.MBox, off, err = unpackName(msg, off); err != nil {
+		return nil, &nestedError{"MBox", err}
+	}
+
+	// The five 32-bit counters follow in wire order.
+	counters := [...]struct {
+		name string
+		dst  *uint32
+	}{
+		{"Serial", &soa.Serial},
+		{"Refresh", &soa.Refresh},
+		{"Retry", &soa.Retry},
+		{"Expire", &soa.Expire},
+		{"MinTTL", &soa.MinTTL},
+	}
+	for _, c := range counters {
+		if *c.dst, off, err = unpackUint32(msg, off); err != nil {
+			return nil, &nestedError{c.name, err}
+		}
+	}
+	return soa, nil
+}
+
+// A TXTResource is a TXT Resource record.
+type TXTResource struct {
+	ResourceHeader
+
+	Txt string // Not a domain name.
+}
+
+// realType returns TypeTXT, the record type this struct represents.
+func (r *TXTResource) realType() Type {
+	return TypeTXT
+}
+
+// pack appends r.Txt as one or more character-strings of at most 255 bytes.
+//
+// RFC 1035 Section 3.3.14 requires TXT data to hold at least one
+// character-string. packText emits nothing for an empty string, so an empty
+// Txt is encoded explicitly as a single zero-length character-string instead
+// of as empty (malformed) record data. It still unpacks to "".
+func (r *TXTResource) pack(msg []byte, compression map[string]int) ([]byte, error) {
+	if r.Txt == "" {
+		return append(msg, 0), nil
+	}
+	return packText(msg, r.Txt), nil
+}
+
+// unpackTXTResource decodes the hdr.Length bytes of TXT data at msg[off]. The
+// data is a run of character-strings whose contents are concatenated into
+// Txt. errCalcLen is returned if a character-string extends past hdr.Length.
+func unpackTXTResource(hdr ResourceHeader, msg []byte, off int) (*TXTResource, error) {
+	var txt string
+	for n := uint16(0); n < hdr.Length; {
+		var t string
+		var err error
+		if t, off, err = unpackText(msg, off); err != nil {
+			return nil, &nestedError{"text", err}
+		}
+		// Check if we got too many bytes (length byte + data).
+		if hdr.Length-n < uint16(len(t))+1 {
+			return nil, errCalcLen
+		}
+		n += uint16(len(t)) + 1
+		txt += t
+	}
+	return &TXTResource{hdr, txt}, nil
+}
+
+// An SRVResource is an SRV Resource record.
+type SRVResource struct {
+	ResourceHeader
+
+	Priority uint16
+	Weight   uint16
+	Port     uint16
+	Target   string // Not compressed as per RFC 2782.
+}
+
+// realType returns TypeSRV, the record type this struct represents.
+func (r *SRVResource) realType() Type {
+	return TypeSRV
+}
+
+// pack appends Priority, Weight and Port followed by the uncompressed Target
+// name. On failure the caller's msg is returned unchanged.
+func (r *SRVResource) pack(msg []byte, compression map[string]int) ([]byte, error) {
+	out := msg
+	for _, v := range [...]uint16{r.Priority, r.Weight, r.Port} {
+		out = packUint16(out, v)
+	}
+	// RFC 2782 forbids compressing the target, hence the nil map.
+	out, err := packName(out, r.Target, nil)
+	if err != nil {
+		return msg, &nestedError{"SRVResource.Target", err}
+	}
+	return out, nil
+}
+
+// unpackSRVResource decodes SRV data at msg[off], wrapping any error with the
+// name of the field that failed.
+func unpackSRVResource(hdr ResourceHeader, msg []byte, off int) (*SRVResource, error) {
+	srv := &SRVResource{ResourceHeader: hdr}
+	var err error
+
+	// The three 16-bit fields precede the target in wire order.
+	fields := [...]struct {
+		name string
+		dst  *uint16
+	}{
+		{"Priority", &srv.Priority},
+		{"Weight", &srv.Weight},
+		{"Port", &srv.Port},
+	}
+	for _, f := range fields {
+		if *f.dst, off, err = unpackUint16(msg, off); err != nil {
+			return nil, &nestedError{f.name, err}
+		}
+	}
+	if srv.Target, _, err = unpackName(msg, off); err != nil {
+		return nil, &nestedError{"Target", err}
+	}
+	return srv, nil
+}
+
+// An AResource is an A Resource record.
+type AResource struct {
+	ResourceHeader
+
+	A [4]byte
+}
+
+// realType returns TypeA, the record type this struct represents.
+func (r *AResource) realType() Type {
+	return TypeA
+}
+
+// pack appends the four address bytes verbatim.
+func (r *AResource) pack(msg []byte, compression map[string]int) ([]byte, error) {
+	out := packBytes(msg, r.A[:])
+	return out, nil
+}
+
+// unpackAResource copies four address bytes from msg[off].
+func unpackAResource(hdr ResourceHeader, msg []byte, off int) (*AResource, error) {
+	rec := &AResource{ResourceHeader: hdr}
+	if _, err := unpackBytes(msg, off, rec.A[:]); err != nil {
+		return nil, err
+	}
+	return rec, nil
+}
+
+// An AAAAResource is an AAAA Resource record.
+type AAAAResource struct {
+	ResourceHeader
+
+	AAAA [16]byte
+}
+
+// realType returns TypeAAAA, the record type this struct represents.
+func (r *AAAAResource) realType() Type {
+	return TypeAAAA
+}
+
+// pack appends the sixteen address bytes verbatim.
+func (r *AAAAResource) pack(msg []byte, compression map[string]int) ([]byte, error) {
+	out := packBytes(msg, r.AAAA[:])
+	return out, nil
+}
+
+// unpackAAAAResource copies sixteen address bytes from msg[off].
+func unpackAAAAResource(hdr ResourceHeader, msg []byte, off int) (*AAAAResource, error) {
+	rec := &AAAAResource{ResourceHeader: hdr}
+	if _, err := unpackBytes(msg, off, rec.AAAA[:]); err != nil {
+		return nil, err
+	}
+	return rec, nil
+}
+
+// NegativeResource indicates that the DNS entry doesn't exist. Used only for caching.
+type NegativeResource struct {
+	ResourceHeader
+}
+
+// realType returns the zero Type; a NegativeResource has no wire record type.
+func (r *NegativeResource) realType() Type {
+	return Type(0)
+}
+
+// pack always fails because a NegativeResource has no wire representation.
+// Like every other pack method, it hands the caller's msg back unchanged
+// alongside the error instead of discarding the buffer.
+func (r *NegativeResource) pack(msg []byte, compression map[string]int) ([]byte, error) {
+	return msg, errors.New("cannot pack negativeResource")
+}
diff --git a/dns/dnsmessage/message_test.go b/dns/dnsmessage/message_test.go
new file mode 100644
index 0000000..46edd72
--- /dev/null
+++ b/dns/dnsmessage/message_test.go
@@ -0,0 +1,575 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package dnsmessage
+
+import (
+ "fmt"
+ "net"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+// String renders m for test diagnostics: the header, then each non-empty
+// section under a "-- <Section>" heading with one %#v-formatted entry per line.
+func (m *Message) String() string {
+	parts := []string{fmt.Sprintf("Message: %#v\n", &m.Header)}
+	if len(m.Questions) > 0 {
+		parts = append(parts, "-- Questions\n")
+		for _, q := range m.Questions {
+			parts = append(parts, fmt.Sprintf("%#v\n", q))
+		}
+	}
+
+	sections := []struct {
+		title string
+		rs    []Resource
+	}{
+		{"-- Answers\n", m.Answers},
+		{"-- Authorities\n", m.Authorities},
+		{"-- Additionals\n", m.Additionals},
+	}
+	for _, sec := range sections {
+		if len(sec.rs) == 0 {
+			continue
+		}
+		parts = append(parts, sec.title)
+		for _, r := range sec.rs {
+			parts = append(parts, fmt.Sprintf("%#v\n", r))
+		}
+	}
+	return strings.Join(parts, "")
+}
+
+func TestQuestionPackUnpack(t *testing.T) {
+ want := Question{
+ Name: ".",
+ Type: TypeA,
+ Class: ClassINET,
+ }
+ buf, err := want.pack(make([]byte, 1, 50), map[string]int{})
+ if err != nil {
+ t.Fatal("Packing failed:", err)
+ }
+ var p Parser
+ p.msg = buf
+ p.header.questions = 1
+ p.section = sectionQuestions
+ p.off = 1
+ got, err := p.Question()
+ if err != nil {
+ t.Fatalf("Unpacking failed: %v\n%s", err, string(buf[1:]))
+ }
+ if p.off != len(buf) {
+ t.Errorf("Unpacked different amount than packed: got n = %d, want = %d", p.off, len(buf))
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Got = %+v, want = %+v", got, want)
+ }
+}
+
+// TestNamePackUnpack checks name canonicalization (trailing dot), rejection
+// of empty segments, and that unpacking consumes exactly what was packed.
+func TestNamePackUnpack(t *testing.T) {
+	cases := []struct {
+		in   string
+		want string
+		err  error
+	}{
+		{"", ".", nil},
+		{".", ".", nil},
+		{"google..com", "", errZeroSegLen},
+		{"google.com", "google.com.", nil},
+		{"google..com.", "", errZeroSegLen},
+		{"google.com.", "google.com.", nil},
+		{".google.com.", "", errZeroSegLen},
+		{"www..google.com.", "", errZeroSegLen},
+		{"www.google.com.", "www.google.com.", nil},
+	}
+
+	for _, tc := range cases {
+		buf, err := packName(make([]byte, 0, 30), tc.in, map[string]int{})
+		switch {
+		case err != tc.err:
+			t.Errorf("Packing of %s: got err = %v, want err = %v", tc.in, err, tc.err)
+			continue
+		case tc.err != nil:
+			// Expected failure observed; nothing to unpack.
+			continue
+		}
+
+		got, n, err := unpackName(buf, 0)
+		if err != nil {
+			t.Errorf("Unpacking for %s failed: %v", tc.in, err)
+			continue
+		}
+		if n != len(buf) {
+			t.Errorf("Unpacked different amount than packed for %s: got n = %d, want = %d", tc.in, n, len(buf))
+		}
+		if got != tc.want {
+			t.Errorf("Unpacking packing of %s: got = %s, want = %s", tc.in, got, tc.want)
+		}
+	}
+}
+
+// TestDNSPackUnpack round-trips whole Messages through Pack and Unpack.
+func TestDNSPackUnpack(t *testing.T) {
+	// The Unpack side builds sections with make, so empty sections compare
+	// equal only when spelled as non-nil empty slices here.
+	rootAAAA := Message{
+		Questions: []Question{
+			{Name: ".", Type: TypeAAAA, Class: ClassINET},
+		},
+		Answers:     []Resource{},
+		Authorities: []Resource{},
+		Additionals: []Resource{},
+	}
+
+	for i, want := range []Message{rootAAAA, largeTestMsg()} {
+		b, err := want.Pack()
+		if err != nil {
+			t.Fatalf("%d: packing failed: %v", i, err)
+		}
+		var got Message
+		if err := got.Unpack(b); err != nil {
+			t.Fatalf("%d: unpacking failed: %v", i, err)
+		}
+		if !reflect.DeepEqual(got, want) {
+			t.Errorf("%d: got = %+v, want = %+v", i, &got, &want)
+		}
+	}
+}
+
+// TestSkipAll verifies that each SkipAll* method succeeds on a real message
+// and keeps returning nil when called again after its section is exhausted.
+func TestSkipAll(t *testing.T) {
+	msg := largeTestMsg()
+	buf, err := msg.Pack()
+	if err != nil {
+		t.Fatal("Packing large test message:", err)
+	}
+	var p Parser
+	if _, err := p.Start(buf); err != nil {
+		t.Fatal(err)
+	}
+
+	steps := []struct {
+		name string
+		f    func() error
+	}{
+		{"SkipAllQuestions", p.SkipAllQuestions},
+		{"SkipAllAnswers", p.SkipAllAnswers},
+		{"SkipAllAuthorities", p.SkipAllAuthorities},
+		{"SkipAllAdditionals", p.SkipAllAdditionals},
+	}
+	for _, s := range steps {
+		for call := 1; call <= 3; call++ {
+			if err := s.f(); err != nil {
+				t.Errorf("Call #%d to %s(): %v", call, s.name, err)
+			}
+		}
+	}
+}
+
+// TestSkipNotStarted verifies that skipping on a Parser that was never
+// started reports ErrNotStarted.
+func TestSkipNotStarted(t *testing.T) {
+	var p Parser
+
+	skips := map[string]func() error{
+		"SkipAllQuestions":   p.SkipAllQuestions,
+		"SkipAllAnswers":     p.SkipAllAnswers,
+		"SkipAllAuthorities": p.SkipAllAuthorities,
+		"SkipAllAdditionals": p.SkipAllAdditionals,
+	}
+	for _, name := range []string{"SkipAllQuestions", "SkipAllAnswers", "SkipAllAuthorities", "SkipAllAdditionals"} {
+		if err := skips[name](); err != ErrNotStarted {
+			t.Errorf("Got %s() = %v, want = %v", name, err, ErrNotStarted)
+		}
+	}
+}
+
+// TestTooManyRecords checks that Pack refuses any section whose entry count
+// does not fit in the header's 16-bit counters.
+func TestTooManyRecords(t *testing.T) {
+	// One more than a 16-bit section count can represent.
+	const recs = int(^uint16(0)) + 1
+
+	cases := []struct {
+		name string
+		msg  Message
+		want error
+	}{
+		{name: "Questions", msg: Message{Questions: make([]Question, recs)}, want: errTooManyQuestions},
+		{name: "Answers", msg: Message{Answers: make([]Resource, recs)}, want: errTooManyAnswers},
+		{name: "Authorities", msg: Message{Authorities: make([]Resource, recs)}, want: errTooManyAuthorities},
+		{name: "Additionals", msg: Message{Additionals: make([]Resource, recs)}, want: errTooManyAdditionals},
+	}
+
+	for _, tc := range cases {
+		_, got := tc.msg.Pack()
+		if got != tc.want {
+			t.Errorf("Packing %d %s: got = %v, want = %v", recs, tc.name, got, tc.want)
+		}
+	}
+}
+
+// TestVeryLongTxt round-trips a TXT record whose text is far longer than 255
+// bytes, so it must be split into several character-strings and reassembled.
+func TestVeryLongTxt(t *testing.T) {
+	want := &TXTResource{
+		ResourceHeader: ResourceHeader{Name: "foo.bar.example.com.", Type: TypeTXT, Class: ClassINET},
+		Txt:            loremIpsum,
+	}
+
+	buf, err := packResource(make([]byte, 0, 8000), want, map[string]int{})
+	if err != nil {
+		t.Fatal("Packing failed:", err)
+	}
+
+	var hdr ResourceHeader
+	dataOff, err := hdr.unpack(buf, 0)
+	if err != nil {
+		t.Fatal("Unpacking ResourceHeader failed:", err)
+	}
+	got, n, err := unpackResource(buf, dataOff, hdr)
+	if err != nil {
+		t.Fatal("Unpacking failed:", err)
+	}
+	if n != len(buf) {
+		t.Errorf("Unpacked different amount than packed: got n = %d, want = %d", n, len(buf))
+	}
+	if !reflect.DeepEqual(got, want) {
+		t.Errorf("Got = %+v, want = %+v", got, want)
+	}
+}
+
+// ExampleHeaderSearch shows how to use the streaming Parser to find the
+// A/AAAA answers for one name without decoding unrelated records.
+func ExampleHeaderSearch() {
+	msg := Message{
+		Header: Header{Response: true, Authoritative: true},
+		Questions: []Question{
+			{
+				Name:  "foo.bar.example.com.",
+				Type:  TypeA,
+				Class: ClassINET,
+			},
+			{
+				Name:  "bar.example.com.",
+				Type:  TypeA,
+				Class: ClassINET,
+			},
+		},
+		Answers: []Resource{
+			&AResource{
+				ResourceHeader: ResourceHeader{
+					Name:  "foo.bar.example.com.",
+					Type:  TypeA,
+					Class: ClassINET,
+				},
+				A: [4]byte{127, 0, 0, 1},
+			},
+			&AResource{
+				ResourceHeader: ResourceHeader{
+					Name:  "bar.example.com.",
+					Type:  TypeA,
+					Class: ClassINET,
+				},
+				A: [4]byte{127, 0, 0, 2},
+			},
+		},
+	}
+
+	buf, err := msg.Pack()
+	if err != nil {
+		panic(err)
+	}
+
+	wantName := "bar.example.com."
+
+	var p Parser
+	if _, err := p.Start(buf); err != nil {
+		panic(err)
+	}
+
+	// Scan the questions for wantName, then skip the rest of the section so
+	// the parser moves on to the answers.
+	for {
+		q, err := p.Question()
+		if err == ErrSectionDone {
+			break
+		}
+		if err != nil {
+			panic(err)
+		}
+
+		if q.Name != wantName {
+			continue
+		}
+
+		fmt.Println("Found question for name", wantName)
+		if err := p.SkipAllQuestions(); err != nil {
+			panic(err)
+		}
+		break
+	}
+
+	var gotIPs []net.IP
+	for {
+		h, err := p.AnswerHeader()
+		if err == ErrSectionDone {
+			break
+		}
+		if err != nil {
+			panic(err)
+		}
+
+		// AnswerHeader only reads the header. An unwanted record must still
+		// be consumed with SkipAnswer (as in the name check below) before
+		// continuing; otherwise the parser never advances past it.
+		if (h.Type != TypeA && h.Type != TypeAAAA) || h.Class != ClassINET {
+			if err := p.SkipAnswer(); err != nil {
+				panic(err)
+			}
+			continue
+		}
+
+		if !strings.EqualFold(h.Name, wantName) {
+			if err := p.SkipAnswer(); err != nil {
+				panic(err)
+			}
+			continue
+		}
+		a, err := p.Answer()
+		if err != nil {
+			panic(err)
+		}
+
+		switch r := a.(type) {
+		default:
+			panic(fmt.Sprintf("unknown type: %T", r))
+		case *AResource:
+			gotIPs = append(gotIPs, r.A[:])
+		case *AAAAResource:
+			gotIPs = append(gotIPs, r.AAAA[:])
+		}
+	}
+
+	fmt.Printf("Found A/AAAA records for name %s: %v\n", wantName, gotIPs)
+
+	// Output:
+	// Found question for name bar.example.com.
+	// Found A/AAAA records for name bar.example.com.: [127.0.0.2]
+}
+
+func largeTestMsg() Message {
+ return Message{
+ Header: Header{Response: true, Authoritative: true},
+ Questions: []Question{
+ {
+ Name: "foo.bar.example.com.",
+ Type: TypeA,
+ Class: ClassINET,
+ },
+ },
+ Answers: []Resource{
+ &AResource{
+ ResourceHeader: ResourceHeader{
+ Name: "foo.bar.example.com.",
+ Type: TypeA,
+ Class: ClassINET,
+ },
+ A: [4]byte{127, 0, 0, 1},
+ },
+ &AResource{
+ ResourceHeader: ResourceHeader{
+ Name: "foo.bar.example.com.",
+ Type: TypeA,
+ Class: ClassINET,
+ },
+ A: [4]byte{127, 0, 0, 2},
+ },
+ },
+ Authorities: []Resource{
+ &NSResource{
+ ResourceHeader: ResourceHeader{
+ Name: "foo.bar.example.com.",
+ Type: TypeNS,
+ Class: ClassINET,
+ },
+ NS: "ns1.example.com.",
+ },
+ &NSResource{
+ ResourceHeader: ResourceHeader{
+ Name: "foo.bar.example.com.",
+ Type: TypeNS,
+ Class: ClassINET,
+ },
+ NS: "ns2.example.com.",
+ },
+ },
+ Additionals: []Resource{
+ &TXTResource{
+ ResourceHeader: ResourceHeader{
+ Name: "foo.bar.example.com.",
+ Type: TypeTXT,
+ Class: ClassINET,
+ },
+ Txt: "So Long, and Thanks for All the Fish",
+ },
+ &TXTResource{
+ ResourceHeader: ResourceHeader{
+ Name: "foo.bar.example.com.",
+ Type: TypeTXT,
+ Class: ClassINET,
+ },
+ Txt: "Hamster Huey and the Gooey Kablooie",
+ },
+ },
+ }
+}
+// loremIpsum is multi-paragraph filler text; presumably used as a large payload in tests — TODO confirm.
+const loremIpsum = `
+Lorem ipsum dolor sit amet, nec enim antiopam id, an ullum choro
+nonumes qui, pro eu debet honestatis mediocritatem. No alia enim eos,
+magna signiferumque ex vis. Mei no aperiri dissentias, cu vel quas
+regione. Malorum quaeque vim ut, eum cu semper aliquid invidunt, ei
+nam ipsum assentior.
+
+Nostrum appellantur usu no, vis ex probatus adipiscing. Cu usu illum
+facilis eleifend. Iusto conceptam complectitur vim id. Tale omnesque
+no usu, ei oblique sadipscing vim. At nullam voluptua usu, mei laudem
+reformidans et. Qui ei eros porro reformidans, ius suas veritus
+torquatos ex. Mea te facer alterum consequat.
+
+Soleat torquatos democritum sed et, no mea congue appareat, facer
+aliquam nec in. Has te ipsum tritani. At justo dicta option nec, movet
+phaedrum ad nam. Ea detracto verterem liberavisse has, delectus
+suscipiantur in mei. Ex nam meliore complectitur. Ut nam omnis
+honestatis quaerendum, ea mea nihil affert detracto, ad vix rebum
+mollis.
+
+Ut epicurei praesent neglegentur pri, prima fuisset intellegebat ad
+vim. An habemus comprehensam usu, at enim dignissim pro. Eam reque
+vivendum adipisci ea. Vel ne odio choro minimum. Sea admodum
+dissentiet ex. Mundi tamquam evertitur ius cu. Homero postea iisque ut
+pro, vel ne saepe senserit consetetur.
+
+Nulla utamur facilisis ius ea, in viderer diceret pertinax eum. Mei no
+enim quodsi facilisi, ex sed aeterno appareat mediocritatem, eum
+sententiae deterruisset ut. At suas timeam euismod cum, offendit
+appareat interpretaris ne vix. Vel ea civibus albucius, ex vim quidam
+accusata intellegebat, noluisse instructior sea id. Nec te nonumes
+habemus appellantur, quis dignissim vituperata eu nam.
+
+At vix apeirian patrioque vituperatoribus, an usu agam assum. Debet
+iisque an mea. Per eu dicant ponderum accommodare. Pri alienum
+placerat senserit an, ne eum ferri abhorreant vituperatoribus. Ut mea
+eligendi disputationi. Ius no tation everti impedit, ei magna quidam
+mediocritatem pri.
+
+Legendos perpetua iracundia ne usu, no ius ullum epicurei intellegam,
+ad modus epicuri lucilius eam. In unum quaerendum usu. Ne diam paulo
+has, ea veri virtute sed. Alia honestatis conclusionemque mea eu, ut
+iudico albucius his.
+
+Usu essent probatus eu, sed omnis dolor delicatissimi ex. No qui augue
+dissentias dissentiet. Laudem recteque no usu, vel an velit noluisse,
+an sed utinam eirmod appetere. Ne mea fuisset inimicus ocurreret. At
+vis dicant abhorreant, utinam forensibus nec ne, mei te docendi
+consequat. Brute inermis persecuti cum id. Ut ipsum munere propriae
+usu, dicit graeco disputando id has.
+
+Eros dolore quaerendum nam ei. Timeam ornatus inciderint pro id. Nec
+torquatos sadipscing ei, ancillae molestie per in. Malis principes duo
+ea, usu liber postulant ei.
+
+Graece timeam voluptatibus eu eam. Alia probatus quo no, ea scripta
+feugiat duo. Congue option meliore ex qui, noster invenire appellantur
+ea vel. Eu exerci legendos vel. Consetetur repudiandae vim ut. Vix an
+probo minimum, et nam illud falli tempor.
+
+Cum dico signiferumque eu. Sed ut regione maiorum, id veritus insolens
+tacimates vix. Eu mel sint tamquam lucilius, duo no oporteat
+tacimates. Atqui augue concludaturque vix ei, id mel utroque menandri.
+
+Ad oratio blandit aliquando pro. Vis et dolorum rationibus
+philosophia, ad cum nulla molestie. Hinc fuisset adversarium eum et,
+ne qui nisl verear saperet, vel te quaestio forensibus. Per odio
+option delenit an. Alii placerat has no, in pri nihil platonem
+cotidieque. Est ut elit copiosae scaevola, debet tollit maluisset sea
+an.
+
+Te sea hinc debet pericula, liber ridens fabulas cu sed, quem mutat
+accusam mea et. Elitr labitur albucius et pri, an labore feugait mel.
+Velit zril melius usu ea. Ad stet putent interpretaris qui. Mel no
+error volumus scripserit. In pro paulo iudico, quo ei dolorem
+verterem, affert fabellas dissentiet ea vix.
+
+Vis quot deserunt te. Error aliquid detraxit eu usu, vis alia eruditi
+salutatus cu. Est nostrud bonorum an, ei usu alii salutatus. Vel at
+nisl primis, eum ex aperiri noluisse reformidans. Ad veri velit
+utroque vis, ex equidem detraxit temporibus has.
+
+Inermis appareat usu ne. Eros placerat periculis mea ad, in dictas
+pericula pro. Errem postulant at usu, ea nec amet ornatus mentitum. Ad
+mazim graeco eum, vel ex percipit volutpat iudicabit, sit ne delicata
+interesset. Mel sapientem prodesset abhorreant et, oblique suscipit
+eam id.
+
+An maluisset disputando mea, vidit mnesarchum pri et. Malis insolens
+inciderint no sea. Ea persius maluisset vix, ne vim appellantur
+instructior, consul quidam definiebas pri id. Cum integre feugiat
+pericula in, ex sed persius similique, mel ne natum dicit percipitur.
+
+Primis discere ne pri, errem putent definitionem at vis. Ei mel dolore
+neglegentur, mei tincidunt percipitur ei. Pro ad simul integre
+rationibus. Eu vel alii honestatis definitiones, mea no nonumy
+reprehendunt.
+
+Dicta appareat legendos est cu. Eu vel congue dicunt omittam, no vix
+adhuc minimum constituam, quot noluisse id mel. Eu quot sale mutat
+duo, ex nisl munere invenire duo. Ne nec ullum utamur. Pro alterum
+debitis nostrum no, ut vel aliquid vivendo.
+
+Aliquip fierent praesent quo ne, id sit audiam recusabo delicatissimi.
+Usu postulant incorrupte cu. At pro dicit tibique intellegam, cibo
+dolore impedit id eam, et aeque feugait assentior has. Quando sensibus
+nec ex. Possit sensibus pri ad, unum mutat periculis cu vix.
+
+Mundi tibique vix te, duo simul partiendo qualisque id, est at vidit
+sonet tempor. No per solet aeterno deseruisse. Petentium salutandi
+definiebas pri cu. Munere vivendum est in. Ei justo congue eligendi
+vis, modus offendit omittantur te mel.
+
+Integre voluptaria in qui, sit habemus tractatos constituam no. Utinam
+melius conceptam est ne, quo in minimum apeirian delicata, ut ius
+porro recusabo. Dicant expetenda vix no, ludus scripserit sed ex, eu
+his modo nostro. Ut etiam sonet his, quodsi inciderint philosophia te
+per. Nullam lobortis eu cum, vix an sonet efficiendi repudiandae. Vis
+ad idque fabellas intellegebat.
+
+Eum commodo senserit conclusionemque ex. Sed forensibus sadipscing ut,
+mei in facer delicata periculis, sea ne hinc putent cetero. Nec ne
+alia corpora invenire, alia prima soleat te cum. Eleifend posidonium
+nam at.
+
+Dolorum indoctum cu quo, ex dolor legendos recteque eam, cu pri zril
+discere. Nec civibus officiis dissentiunt ex, est te liber ludus
+elaboraret. Cum ea fabellas invenire. Ex vim nostrud eripuit
+comprehensam, nam te inermis delectus, saepe inermis senserit.
+`
diff --git a/tcpip/header/icmpv6.go b/tcpip/header/icmpv6.go
new file mode 100644
index 0000000..18d8d4f
--- /dev/null
+++ b/tcpip/header/icmpv6.go
@@ -0,0 +1,67 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "github.com/google/netstack/tcpip"
+)
+
+// ICMPv6 represents an ICMPv6 header stored in a byte slice; any message body follows the 4-byte header.
+type ICMPv6 []byte
+
+const (
+ // ICMPv6MinimumSize is the minimum size of a valid ICMPv6 packet.
+ ICMPv6MinimumSize = 4
+
+ // ICMPv6ProtocolNumber is the ICMPv6 transport protocol number.
+ ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58
+
+ // ICMPv6NeighborSolicitMinimumSize is the minimum size of a neighbor
+ // solicitation packet: header, 4 reserved bytes and a 16-byte target address.
+ ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16
+
+ // ICMPv6NeighborAdvertSize is the size of a neighbor advertisement carrying an 8-byte link-layer address option.
+ ICMPv6NeighborAdvertSize = ICMPv6MinimumSize + 4 + 16 + 8
+
+ // ICMPv6EchoMinimumSize is the minimum size of a valid ICMPv6 echo packet.
+ ICMPv6EchoMinimumSize = 8
+)
+
+// ICMPv6Type is the ICMPv6 message type field, described in RFC 4443 and, for NDP messages, RFC 4861.
+type ICMPv6Type byte
+
+// Typical values of ICMPv6Type defined in RFC 4443.
+const (
+ ICMPv6DstUnreachable ICMPv6Type = 1
+ ICMPv6PacketTooBig ICMPv6Type = 2
+ ICMPv6TimeExceeded ICMPv6Type = 3
+ ICMPv6ParamProblem ICMPv6Type = 4
+ ICMPv6EchoRequest ICMPv6Type = 128
+ ICMPv6EchoReply ICMPv6Type = 129
+
+ // Neighbor Discovery Protocol (NDP) messages, see RFC 4861.
+
+ ICMPv6RouterSolicit ICMPv6Type = 133
+ ICMPv6RouterAdvert ICMPv6Type = 134
+ ICMPv6NeighborSolicit ICMPv6Type = 135
+ ICMPv6NeighborAdvert ICMPv6Type = 136
+ ICMPv6RedirectMsg ICMPv6Type = 137
+)
+
+// Type returns the ICMPv6 message type, stored in the first byte of the header.
+func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
+
+// SetType stores t in the ICMPv6 type field, the first byte of the header.
+func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) }
+
+// Code returns the ICMPv6 code field (second header byte). Its meaning depends on the value of Type.
+func (b ICMPv6) Code() byte { return b[1] }
+
+// SetChecksum stores checksum, big-endian, in the checksum field (bytes 2-3); it does not compute it.
+func (b ICMPv6) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[2:], checksum)
+}
diff --git a/tcpip/header/ipv4.go b/tcpip/header/ipv4.go
index f5ad882..a496b5b 100644
--- a/tcpip/header/ipv4.go
+++ b/tcpip/header/ipv4.go
@@ -83,6 +83,9 @@
// IPv4Version is the version of the ipv4 procotol.
IPv4Version = 4
+
+ // IPv4Loopback is the loopback address of the IPv4 procotol.
+ IPv4Loopback tcpip.Address = "\x7f\x00\x00\x01"
)
// Flags that may be set in an IPv4 packet.
@@ -249,3 +252,12 @@
return true
}
+
+// IsV4MulticastAddress reports whether addr is a 4-byte IPv4 multicast
+// address, i.e. its first octet lies in 224-239 (224.0.0.0/4).
+func IsV4MulticastAddress(addr tcpip.Address) bool {
+ if len(addr) == IPv4AddressSize {
+ return addr[0] >= 0xe0 && addr[0] <= 0xef
+ }
+ return false
+}
diff --git a/tcpip/header/ipv6.go b/tcpip/header/ipv6.go
index f413bbb..872fc3b 100644
--- a/tcpip/header/ipv6.go
+++ b/tcpip/header/ipv6.go
@@ -66,6 +66,9 @@
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
// section 5.
IPv6MinimumMTU = 1280
+
+ // IPv6Loopback is the loopback address of the IPv6 procotol.
+ IPv6Loopback tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
)
// PayloadLength returns the value of the "payload length" field of the ipv6
@@ -184,3 +187,12 @@
return true
}
+
+// IsV6MulticastAddress reports whether addr is a 16-byte IPv6 multicast
+// address, i.e. one whose first byte is 0xff (ff00::/8).
+func IsV6MulticastAddress(addr tcpip.Address) bool {
+ if len(addr) == IPv6AddressSize {
+ return addr[0] == 0xff
+ }
+ return false
+}
diff --git a/tcpip/link/channel/channel.go b/tcpip/link/channel/channel.go
index 393d82c..aaf23f9 100644
--- a/tcpip/link/channel/channel.go
+++ b/tcpip/link/channel/channel.go
@@ -61,6 +61,12 @@
e.dispatcher.DeliverNetworkPacket(e, "", protocol, &uu)
}
+// InjectLinkAddr delivers a clone of vv to the dispatcher as an inbound packet sent from remoteLinkAddr.
+func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, vv *buffer.VectorisedView) {
+ clone := vv.Clone(nil)
+ e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, protocol, &clone)
+}
+
// Attach saves the stack network-layer dispatcher for use later when packets
// are injected.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
diff --git a/tcpip/network/arp/arp.go b/tcpip/network/arp/arp.go
index cf08a39..5707f97 100644
--- a/tcpip/network/arp/arp.go
+++ b/tcpip/network/arp/arp.go
@@ -58,6 +58,9 @@
return e.linkEP.MaxHeaderLength() + header.ARPSize
}
+// SetTTL implements stack.NetworkEndpoint.SetTTL. It is a no-op for ARP.
+func (e *endpoint) SetTTL(_ uint8) {}
+
func (e *endpoint) Close() {}
func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
diff --git a/tcpip/network/ipv4/ipv4.go b/tcpip/network/ipv4/ipv4.go
index bf2950a..4b64763 100644
--- a/tcpip/network/ipv4/ipv4.go
+++ b/tcpip/network/ipv4/ipv4.go
@@ -46,6 +46,7 @@
dispatcher stack.TransportDispatcher
echoRequests chan echoRequest
fragmentation fragmentation.Fragmentation
+ ttl uint8
}
func newEndpoint(nicid tcpip.NICID, addr tcpip.Address, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) *endpoint {
@@ -55,6 +56,7 @@
dispatcher: dispatcher,
echoRequests: make(chan echoRequest, 10),
fragmentation: fragmentation.NewFragmentation(fragmentation.MemoryLimit, fragmentation.DefaultReassembleTimeout),
+ ttl: 65,
}
copy(e.address[:], addr)
e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])}
@@ -90,6 +92,12 @@
return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
}
+// SetTTL sets the TTL that WritePacket puts in the header of packets
+// subsequently sent through this endpoint (newEndpoint initializes it to 65).
+func (e *endpoint) SetTTL(ttl uint8) {
+ e.ttl = ttl // NOTE(review): WritePacket reads e.ttl without locking; confirm SetTTL is never concurrent with writes
+}
+
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
@@ -104,7 +112,7 @@
IHL: header.IPv4MinimumSize,
TotalLength: length,
ID: uint16(id),
- TTL: 65,
+ TTL: e.ttl,
Protocol: uint8(protocol),
SrcAddr: tcpip.Address(e.address[:]),
DstAddr: r.RemoteAddress,
diff --git a/tcpip/network/ipv6/icmp.go b/tcpip/network/ipv6/icmp.go
new file mode 100644
index 0000000..0c7e5e6
--- /dev/null
+++ b/tcpip/network/ipv6/icmp.go
@@ -0,0 +1,130 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ipv6
+
+import (
+ "encoding/binary"
+ "log"
+
+ "github.com/google/netstack/tcpip"
+ "github.com/google/netstack/tcpip/buffer"
+ "github.com/google/netstack/tcpip/header"
+ "github.com/google/netstack/tcpip/stack"
+)
+
+const (
+ ndpSolicitedFlag = 0x40 // NA "S" flag: the advertisement answers a solicitation
+ ndpOverrideFlag = 0x20 // NA "O" flag: override any cached link address
+
+ ndpOptSrcLinkAddr = 1 // NDP option type: source link-layer address
+ ndpOptDstLinkAddr = 2 // NDP option type: target link-layer address
+)
+
+// handleICMP answers neighbor solicitations and echo requests, learns link addresses from NDP traffic and logs other ICMPv6 types.
+func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) {
+ v := vv.First()
+ if len(v) < header.ICMPv6MinimumSize {
+ return
+ }
+ h := header.ICMPv6(v)
+
+ switch h.Type() {
+ case header.ICMPv6NeighborSolicit:
+ if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
+ return
+ }
+ targetAddr := tcpip.Address(v[8 : 8+16])
+ if e.linkAddrCache.CheckLocalAddress(e.nicid, targetAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ pkt[4] = ndpSolicitedFlag | ndpOverrideFlag
+ copy(pkt[8:24], v[8:])
+ pkt[24] = ndpOptDstLinkAddr
+ pkt[25] = 1 // option length, in units of 8 bytes
+ copy(pkt[26:], r.LocalLinkAddress[:])
+ r.LocalAddress = targetAddr
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, nil))
+ r.WritePacket(&hdr, nil, header.ICMPv6ProtocolNumber)
+ // Also cache the solicitor's link address (presumably the incoming frame's source).
+ e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+
+ case header.ICMPv6NeighborAdvert:
+ if len(v) >= header.ICMPv6NeighborSolicitMinimumSize { // NA fixed part is also 24 bytes; the option is optional (RFC 4861 §7.1.2)
+ e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+ }
+
+ case header.ICMPv6EchoRequest:
+ if len(v) < header.ICMPv6EchoMinimumSize {
+ return
+ }
+ data := v[header.ICMPv6EchoMinimumSize:]
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ copy(pkt, h)
+ pkt.SetType(header.ICMPv6EchoReply)
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, data))
+ r.WritePacket(&hdr, data, header.ICMPv6ProtocolNumber)
+ default:
+ log.Printf("got ICMPv6: type=%v, code=%v, len(v)=%d", h.Type(), h.Code(), len(v))
+ }
+ // TODO case header.ICMPv6EchoReply
+}
+
+var broadcastMAC = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff") // all-ones 48-bit link address, ff:ff:ff:ff:ff:ff
+
+// LinkAddressProtocol returns the network protocol (IPv6) whose addresses
+// LinkAddressRequest solicits link addresses for.
+func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv6ProtocolNumber }
+
+// LinkAddressRequest sends an NDP neighbor solicitation for addr, sourced from localAddr, over linkEP.
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+ r := &stack.Route{
+ LocalAddress: localAddr,
+ RemoteAddress: SolicitedNodeAddr(addr), // solicited-node multicast group of addr (RFC 4291)
+ RemoteLinkAddress: broadcastMAC,
+ }
+ // 24-byte solicitation plus an 8-byte source link-layer option: the same 32 bytes as ICMPv6NeighborAdvertSize.
+ hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ copy(pkt[8:24], addr)
+ pkt[24] = ndpOptSrcLinkAddr
+ pkt[25] = 1 // option length, in units of 8 bytes
+ copy(pkt[26:], linkEP.LinkAddress())
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, nil))
+
+ length := uint16(hdr.UsedLength())
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: length,
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255, // RFC 4861 requires 255 on NDP messages
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+
+ return linkEP.WritePacket(r, &hdr, nil, ProtocolNumber)
+}
+
+// icmpChecksum returns the ICMPv6 checksum of h followed by data over the IPv6 pseudo-header, ignoring h's current checksum field.
+func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, data []byte) uint16 {
+ // Pseudo-header (RFC 2460 §8.1): addresses, upper-layer length, next header; then the payload.
+ var length [4]byte
+ binary.BigEndian.PutUint32(length[:], uint32(len(h)+len(data)))
+ nextHeader := []byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}
+ var sum uint16
+ for _, part := range [][]byte{[]byte(src), []byte(dst), length[:], nextHeader, data} {
+ sum = header.Checksum(part, sum)
+ }
+ // Sum h with its checksum field temporarily zeroed, then restore the field.
+ saved := [2]byte{h[2], h[3]}
+ h[2], h[3] = 0, 0
+ sum = ^header.Checksum(h, sum)
+ h[2], h[3] = saved[0], saved[1]
+ return sum
+}
diff --git a/tcpip/network/ipv6/icmp_test.go b/tcpip/network/ipv6/icmp_test.go
new file mode 100644
index 0000000..6ab590c
--- /dev/null
+++ b/tcpip/network/ipv6/icmp_test.go
@@ -0,0 +1,160 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ipv6
+
+import (
+ "context"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/google/netstack/tcpip"
+ "github.com/google/netstack/tcpip/buffer"
+ "github.com/google/netstack/tcpip/header"
+ "github.com/google/netstack/tcpip/link/channel"
+ "github.com/google/netstack/tcpip/link/sniffer"
+ "github.com/google/netstack/tcpip/stack"
+)
+
+const (
+ linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") // link address of s0's NIC
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") // link address of s1's NIC
+)
+
+var (
+ lladdr0 = LinkLocalAddr(linkAddr0) // fe80::302:3ff:fe04:506
+ lladdr1 = LinkLocalAddr(linkAddr1) // fe80::80b:cff:fe0d:e0f
+)
+
+type testContext struct {
+ t *testing.T
+ s0 *stack.Stack
+ s1 *stack.Stack
+ // linkEP0 and linkEP1 are the channel endpoints beneath NIC 1 of s0 and s1.
+ linkEP0 *channel.Endpoint
+ linkEP1 *channel.Endpoint
+ // icmpCh receives the type of every ICMPv6 packet routed between the stacks.
+ icmpCh chan header.ICMPv6Type
+}
+
+// newTestContext returns two IPv6 stacks whose NIC 1s are cross-connected through channel endpoints.
+func newTestContext(t *testing.T) *testContext {
+ c := &testContext{
+ t: t,
+ s0: stack.New([]string{ProtocolName}, nil).(*stack.Stack),
+ s1: stack.New([]string{ProtocolName}, nil).(*stack.Stack),
+ icmpCh: make(chan header.ICMPv6Type, 10),
+ }
+
+ const defaultMTU = 65536
+ id0, linkEP0 := channel.New(256, defaultMTU, linkAddr0)
+ c.linkEP0 = linkEP0
+ if testing.Verbose() {
+ id0 = sniffer.New(id0)
+ }
+ if err := c.s0.CreateNIC(1, id0); err != nil {
+ t.Fatalf("CreateNIC s0: %v", err)
+ }
+ if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress lladdr0: %v", err)
+ }
+ if err := c.s0.AddAddress(1, ProtocolNumber, SolicitedNodeAddr(lladdr0)); err != nil {
+ t.Fatalf("AddAddress sn lladdr0: %v", err)
+ }
+
+ id1, linkEP1 := channel.New(256, defaultMTU, linkAddr1)
+ c.linkEP1 = linkEP1
+ if err := c.s1.CreateNIC(1, id1); err != nil {
+ t.Fatalf("CreateNIC s1: %v", err)
+ }
+ if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
+ t.Fatalf("AddAddress lladdr1: %v", err)
+ }
+ if err := c.s1.AddAddress(1, ProtocolNumber, SolicitedNodeAddr(lladdr1)); err != nil {
+ t.Fatalf("AddAddress sn lladdr1: %v", err)
+ }
+
+ routeTable := []tcpip.Route{{
+ Destination: tcpip.Address(strings.Repeat("\x00", 16)),
+ Mask: tcpip.Address(strings.Repeat("\x00", 16)),
+ NIC: 1,
+ }}
+ c.s0.SetRouteTable(routeTable)
+ c.s1.SetRouteTable(routeTable)
+
+ go c.routePackets(linkEP0.C, linkEP1)
+ go c.routePackets(linkEP1.C, linkEP0)
+ return c
+}
+
+// countPacket sends the type of each ICMPv6 packet to c.icmpCh and ignores everything else.
+func (c *testContext) countPacket(pkt channel.PacketInfo) {
+ // The ICMPv6 header is written into the header view right after the IPv6 header; skip anything too short.
+ if pkt.Proto != header.IPv6ProtocolNumber || len(pkt.Header) < header.IPv6MinimumSize+header.ICMPv6MinimumSize {
+ return
+ }
+ ipv6 := header.IPv6(pkt.Header)
+ if tcpip.TransportProtocolNumber(ipv6.NextHeader()) != header.ICMPv6ProtocolNumber {
+ return
+ }
+ icmp := header.ICMPv6(pkt.Header[header.IPv6MinimumSize:])
+ c.icmpCh <- icmp.Type()
+}
+
+func (c *testContext) routePackets(ch <-chan channel.PacketInfo, ep *channel.Endpoint) {
+ // Packets injected into ep were sent by its peer, so tag them with the peer's link address.
+ peerLinkAddr := map[tcpip.LinkAddress]tcpip.LinkAddress{linkAddr0: linkAddr1, linkAddr1: linkAddr0}[ep.LinkAddress()]
+ for pkt := range ch {
+ c.countPacket(pkt)
+ vv := buffer.NewVectorisedView(len(pkt.Header)+len(pkt.Payload), []buffer.View{pkt.Header, pkt.Payload})
+ ep.InjectLinkAddr(pkt.Proto, peerLinkAddr, &vv)
+ }
+}
+
+func (c *testContext) cleanup() {
+ close(c.linkEP0.C) // ends the routePackets goroutine ranging over this channel
+ close(c.linkEP1.C) // NOTE(review): a later send on C by either stack would panic — confirm nothing writes after the test
+}
+
+func TestLinkResolution(t *testing.T) {
+ c := newTestContext(t)
+ defer c.cleanup()
+ r, err := c.s0.FindRoute(1, lladdr0, lladdr1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer r.Release()
+ // s0 has no cached link address for lladdr1 yet, so this write must resolve it via NDP first.
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, nil))
+ if err := r.WritePacket(&hdr, nil, header.ICMPv6ProtocolNumber); err != nil {
+ t.Fatal(err)
+ }
+
+ // This actually takes about 10 milliseconds, so no need to wait for
+ // a multi-minute go test timeout if something is broken.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+ // Done once a solicit, an advert, an echo request and an echo reply have each been seen.
+ stats := make(map[header.ICMPv6Type]int)
+ for {
+ select {
+ case <-ctx.Done():
+ t.Errorf("timeout waiting for ICMP, got: %#+v", stats)
+ return
+ case typ := <-c.icmpCh:
+ stats[typ]++
+
+ if stats[header.ICMPv6NeighborSolicit] > 0 &&
+ stats[header.ICMPv6NeighborAdvert] > 0 &&
+ stats[header.ICMPv6EchoRequest] > 0 &&
+ stats[header.ICMPv6EchoReply] > 0 {
+ return
+ }
+ }
+ }
+}
diff --git a/tcpip/network/ipv6/ipv6.go b/tcpip/network/ipv6/ipv6.go
index dbd30a7..f06efdc 100644
--- a/tcpip/network/ipv6/ipv6.go
+++ b/tcpip/network/ipv6/ipv6.go
@@ -32,18 +32,13 @@
type address [header.IPv6AddressSize]byte
type endpoint struct {
- nicid tcpip.NICID
- id stack.NetworkEndpointID
- address address
- linkEP stack.LinkEndpoint
- dispatcher stack.TransportDispatcher
-}
-
-func newEndpoint(nicid tcpip.NICID, addr tcpip.Address, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) *endpoint {
- e := &endpoint{nicid: nicid, linkEP: linkEP, dispatcher: dispatcher}
- copy(e.address[:], addr)
- e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])}
- return e
+ nicid tcpip.NICID
+ id stack.NetworkEndpointID
+ address address
+ linkEP stack.LinkEndpoint
+ linkAddrCache stack.LinkAddressCache // consulted and updated by handleICMP for NDP
+ dispatcher stack.TransportDispatcher
+ ttl uint8 // hop limit WritePacket puts in outgoing headers; see SetTTL
}
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
@@ -72,6 +67,12 @@
return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
}
+// SetTTL sets the hop limit that WritePacket puts in the IPv6 header of
+// packets subsequently sent through this endpoint (NewEndpoint starts it at 255).
+func (e *endpoint) SetTTL(ttl uint8) {
+ e.ttl = ttl // NOTE(review): read by WritePacket without locking; confirm callers serialize
+}
+
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
length := uint16(hdr.UsedLength())
@@ -82,8 +83,8 @@
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(protocol),
- HopLimit: 65,
- SrcAddr: tcpip.Address(e.address[:]),
+ HopLimit: e.ttl,
+ SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
@@ -100,7 +101,12 @@
vv.TrimFront(header.IPv6MinimumSize)
vv.CapLength(int(h.PayloadLength()))
- e.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(h.NextHeader()), vv)
+ p := tcpip.TransportProtocolNumber(h.NextHeader())
+ if p == header.ICMPv6ProtocolNumber {
+ e.handleICMP(r, vv)
+ return
+ }
+ e.dispatcher.DeliverTransportPacket(r, p, vv)
}
// Close cleans up resources associated with the endpoint.
@@ -134,9 +140,48 @@
// NewEndpoint creates a new ipv6 endpoint.
func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
- return newEndpoint(nicid, addr, dispatcher, linkEP), nil
+ e := &endpoint{
+ nicid: nicid,
+ linkEP: linkEP,
+ linkAddrCache: linkAddrCache,
+ dispatcher: dispatcher,
+ ttl: 255, // initial hop limit for WritePacket; changeable via SetTTL
+ }
+ copy(e.address[:], addr)
+ e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])}
+ return e, nil
}
func init() {
stack.RegisterNetworkProtocol(ProtocolName, NewProtocol())
}
+
+// LinkLocalAddr computes the default IPv6 link-local address from
+// a link-layer (MAC) address.
+func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
+ // Convert a 48-bit MAC to an EUI-64 and then prepend the
+ // link-local header, FE80::.
+ //
+ // The conversion is very nearly:
+ // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
+ // Note the capital A. The conversion aa->Aa involves a bit flip.
+ lladdrb := [16]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ 8: linkAddr[0] ^ 2,
+ 9: linkAddr[1],
+ 10: linkAddr[2],
+ 11: 0xFF,
+ 12: 0xFE,
+ 13: linkAddr[3],
+ 14: linkAddr[4],
+ 15: linkAddr[5],
+ }
+ return tcpip.Address(lladdrb[:])
+}
+
+// SolicitedNodeAddr computes the solicited-node multicast address.
+// This is used for NDP. Described in RFC 4291.
+func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
+ return "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff" + addr[len(addr)-3:]
+}
diff --git a/tcpip/stack/linkaddrcache.go b/tcpip/stack/linkaddrcache.go
index af6a4ec..73caf3c 100644
--- a/tcpip/stack/linkaddrcache.go
+++ b/tcpip/stack/linkaddrcache.go
@@ -5,6 +5,7 @@
package stack
import (
+ "context"
"sync"
"time"
@@ -23,6 +24,7 @@
cache map[tcpip.FullAddress]*linkAddrEntry
next int // array index of next available entry
entries [linkAddrCacheSize]linkAddrEntry
+ waiters map[tcpip.FullAddress]map[chan tcpip.LinkAddress]struct{}
}
// A linkAddrEntry is an entry in the linkAddrCache.
@@ -59,22 +61,60 @@
if c.next == len(c.entries) {
c.next = 0
}
+ for ch := range c.waiters[k] {
+ ch <- v
+ }
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress) (linkAddr tcpip.LinkAddress) {
+func (c *linkAddrCache) get(k tcpip.FullAddress, timeout time.Duration) (linkAddr tcpip.LinkAddress) {
c.mu.RLock()
if entry, found := c.cache[k]; found && c.valid(entry) {
linkAddr = entry.linkAddr
}
c.mu.RUnlock()
- return linkAddr
+ if linkAddr != "" || timeout == 0 { // cache hit, or the caller does not want to wait
+ return linkAddr
+ }
+ c.mu.Lock()
+ if entry, found := c.cache[k]; found && c.valid(entry) { // check again: an entry may have been added after RUnlock
+ c.mu.Unlock()
+ return entry.linkAddr
+ }
+ ch := make(chan tcpip.LinkAddress, 1) // buffered so a single send from add does not block
+ m := c.waiters[k]
+ if m == nil {
+ m = make(map[chan tcpip.LinkAddress]struct{})
+ c.waiters[k] = m
+ }
+ m[ch] = struct{}{}
+ c.mu.Unlock()
+ // NOTE(review): a second add(k) before the deferred cleanup below would block on the full ch (deadlock if add holds c.mu while sending) — confirm.
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer func() {
+ cancel()
+ c.mu.Lock()
+ m := c.waiters[k]
+ delete(m, ch)
+ if len(m) == 0 {
+ delete(c.waiters, k)
+ }
+ c.mu.Unlock()
+ }()
+
+ select { // wait for add to deliver an address, or give up at the timeout
+ case linkAddr := <-ch:
+ return linkAddr
+ case <-ctx.Done():
+ return ""
+ }
}
func newLinkAddrCache(ageLimit time.Duration) *linkAddrCache {
c := &linkAddrCache{
ageLimit: ageLimit,
cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
+ waiters: make(map[tcpip.FullAddress]map[chan tcpip.LinkAddress]struct{}),
}
return c
}
diff --git a/tcpip/stack/linkaddrcache_test.go b/tcpip/stack/linkaddrcache_test.go
index 13b593d..4ff71cf 100644
--- a/tcpip/stack/linkaddrcache_test.go
+++ b/tcpip/stack/linkaddrcache_test.go
@@ -35,21 +35,21 @@
for i := len(testaddrs) - 1; i >= 0; i-- {
e := testaddrs[i]
c.add(e.addr, e.linkAddr)
- if got, want := c.get(e.addr), e.linkAddr; got != want {
+ if got, want := c.get(e.addr, 0), e.linkAddr; got != want {
t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, want)
}
}
// Expect to find at least half of the most recent entries.
for i := 0; i < linkAddrCacheSize/2; i++ {
e := testaddrs[i]
- if got, want := c.get(e.addr), e.linkAddr; got != want {
+ if got, want := c.get(e.addr, 0), e.linkAddr; got != want {
t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, want)
}
}
// The earliest entries should no longer be in the cache.
for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- {
e := testaddrs[i]
- if got := c.get(e.addr); got != "" {
+ if got := c.get(e.addr, 0); got != "" {
t.Errorf("check %d, c.get(%q)=%q, want no entry", i, string(e.addr.Addr), got)
}
}
@@ -64,7 +64,7 @@
go func() {
for _, e := range testaddrs {
c.add(e.addr, e.linkAddr)
- c.get(e.addr) // make work for gotsan
+ c.get(e.addr, 0) // make work for gotsan
}
wg.Done()
}()
@@ -75,11 +75,11 @@
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
e := testaddrs[len(testaddrs)-1]
- if got, want := c.get(e.addr), e.linkAddr; got != want {
+ if got, want := c.get(e.addr, 0), e.linkAddr; got != want {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, want)
}
e = testaddrs[0]
- if got := c.get(e.addr); got != "" {
+ if got := c.get(e.addr, 0); got != "" {
t.Errorf("c.get(%q)=%q, want no entry", string(e.addr.Addr), got)
}
}
@@ -89,7 +89,7 @@
e := testaddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
- if got := c.get(e.addr); got != "" {
+ if got := c.get(e.addr, 0); got != "" {
t.Errorf("c.get(%q)=%q, want no stale entry", string(e.addr.Addr), got)
}
}
@@ -99,11 +99,11 @@
e := testaddrs[0]
l2 := e.linkAddr + "2"
c.add(e.addr, e.linkAddr)
- if got := c.get(e.addr); got != e.linkAddr {
+ if got := c.get(e.addr, 0); got != e.linkAddr {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
c.add(e.addr, l2)
- if got := c.get(e.addr); got != l2 {
+ if got := c.get(e.addr, 0); got != l2 {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2)
}
diff --git a/tcpip/stack/nic.go b/tcpip/stack/nic.go
index 360f42d..a19424f 100644
--- a/tcpip/stack/nic.go
+++ b/tcpip/stack/nic.go
@@ -109,7 +109,18 @@
n.removeEndpointLocked(ref)
}
- ref := newReferencedNetworkEndpoint(ep, protocol, n)
+ ref := &referencedNetworkEndpoint{
+ refs: 1,
+ ep: ep,
+ nic: n,
+ protocol: protocol,
+ holdsInsertRef: true,
+ }
+ if linkRes := n.stack.linkAddrResolvers[protocol]; linkRes != nil {
+ ref.linkRes = linkRes
+ ref.linkCache = n.stack
+ ref.linkEP = n.linkEP
+ }
n.endpoints[id] = ref
@@ -319,10 +330,13 @@
type referencedNetworkEndpoint struct {
ilist.Entry
- refs int32
- ep NetworkEndpoint
- nic *NIC
- protocol tcpip.NetworkProtocolNumber
+ refs int32
+ ep NetworkEndpoint
+ nic *NIC
+ protocol tcpip.NetworkProtocolNumber
+ linkRes LinkAddressResolver
+ linkCache LinkAddressCache
+ linkEP LinkEndpoint
// holdsInsertRef is protected by the NIC's mutex. It indicates whether
// the reference count is biased by 1 due to the insertion of the
@@ -331,16 +345,6 @@
holdsInsertRef bool
}
-func newReferencedNetworkEndpoint(ep NetworkEndpoint, protocol tcpip.NetworkProtocolNumber, nic *NIC) *referencedNetworkEndpoint {
- return &referencedNetworkEndpoint{
- refs: 1,
- ep: ep,
- nic: nic,
- protocol: protocol,
- holdsInsertRef: true,
- }
-}
-
// decRef decrements the ref count and cleans up the endpoint once it reaches
// zero.
func (r *referencedNetworkEndpoint) decRef() {
diff --git a/tcpip/stack/registration.go b/tcpip/stack/registration.go
index 1900ada..0954ed8 100644
--- a/tcpip/stack/registration.go
+++ b/tcpip/stack/registration.go
@@ -6,6 +6,7 @@
import (
"sync"
+ "time"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
@@ -104,6 +105,10 @@
// NICID returns the id of the NIC this endpoint belongs to.
NICID() tcpip.NICID
+ // SetTTL sets the default time-to-live value for packets sent through
+ // this endpoint.
+ SetTTL(ttl uint8)
+
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint.
HandlePacket(r *Route, vv *buffer.VectorisedView)
@@ -193,6 +198,8 @@
// AddLinkAddress adds a link address to the cache.
AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
+
+ GetLinkAddress(nicid tcpip.NICID, addr tcpip.Address, timeout time.Duration) tcpip.LinkAddress
}
var (
diff --git a/tcpip/stack/route.go b/tcpip/stack/route.go
index 52c1e17..c3eca94 100644
--- a/tcpip/stack/route.go
+++ b/tcpip/stack/route.go
@@ -5,6 +5,8 @@
package stack
import (
+ "time"
+
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
@@ -64,8 +66,33 @@
return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress)
}
+// SetTTL forwards the call to the network endpoint's implementation.
+func (r *Route) SetTTL(ttl uint8) {
+ r.ref.ep.SetTTL(ttl)
+}
+
+func isLoopback(addr tcpip.Address) bool {
+ return (len(addr) == 4 && addr[0] == 127) || addr == header.IPv6Loopback
+}
+
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+ if r.RemoteLinkAddress == "" && r.ref.linkRes != nil && !isLoopback(r.RemoteAddress) {
+ nextAddr := r.NextHop
+ if nextAddr == "" {
+ nextAddr = r.RemoteAddress
+ }
+
+ nicid := r.ref.nic.ID()
+ r.RemoteLinkAddress = r.ref.linkCache.GetLinkAddress(nicid, nextAddr, 0)
+ if r.RemoteLinkAddress == "" {
+ r.ref.linkRes.LinkAddressRequest(nextAddr, r.LocalAddress, r.ref.linkEP)
+ r.RemoteLinkAddress = r.ref.linkCache.GetLinkAddress(nicid, nextAddr, 250*time.Millisecond)
+ }
+ if r.RemoteLinkAddress == "" {
+ return tcpip.ErrNoLinkAddress
+ }
+ }
return r.ref.ep.WritePacket(r, hdr, payload, protocol)
}
diff --git a/tcpip/stack/stack.go b/tcpip/stack/stack.go
index 71975b0..3d9584d 100644
--- a/tcpip/stack/stack.go
+++ b/tcpip/stack/stack.go
@@ -14,6 +14,7 @@
package stack
import (
+ "log"
"sync"
"time"
@@ -276,7 +277,6 @@
}
r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, ref)
- r.RemoteLinkAddress = s.linkAddrCache.get(tcpip.FullAddress{NIC: nic.ID(), Addr: remoteAddr})
r.NextHop = s.routeTable[i].Gateway
return r, nil
}
@@ -351,6 +351,14 @@
// for a particular address has been called.
}
+func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr tcpip.Address, timeout time.Duration) tcpip.LinkAddress {
+ if addr == "\xff\xff\xff\xff" {
+ return "\xff\xff\xff\xff\xff\xff"
+ }
+ fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ return s.linkAddrCache.get(fullAddr, timeout)
+}
+
// RegisterTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
@@ -387,3 +395,50 @@
nic.demux.unregisterEndpoint(netProtos, protocol, id)
}
}
+
+// JoinGroup joins the given multicast group on every interface that
+// matches the given interface address.
+// TODO: notify network of subscription via igmp protocol
+func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, interfaceAddr tcpip.Address, multicastAddr tcpip.Address) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ for _, n := range s.nics {
+ ref := n.findEndpoint(interfaceAddr)
+ if ref == nil {
+ continue
+ }
+ ref.decRef()
+
+ // This NIC matches the interface address.
+ err := n.AddAddress(protocol, multicastAddr)
+ if err != nil {
+ log.Printf("igmpJoinGroup failed for %v: %v", multicastAddr, err)
+ return err
+ }
+ }
+ return nil
+}
+
+// LeaveGroup leaves the given multicast group on every interface that
+// matches the given interface address.
+func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, interfaceAddr tcpip.Address, multicastAddr tcpip.Address) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ for _, n := range s.nics {
+ ref := n.findEndpoint(interfaceAddr)
+ if ref == nil {
+ continue
+ }
+ ref.decRef()
+
+ // This NIC matches the interface address.
+ err := n.RemoveAddress(multicastAddr)
+ if err != nil {
+ log.Printf("igmpLeaveGroup failed for %v: %v", multicastAddr, err)
+ return err
+ }
+ }
+ return nil
+}
diff --git a/tcpip/tcpip.go b/tcpip/tcpip.go
index 0cde934..3faef89 100644
--- a/tcpip/tcpip.go
+++ b/tcpip/tcpip.go
@@ -66,6 +66,7 @@
ErrNotConnected = &Error{"endpoint not connected"}
ErrConnectionReset = &Error{"connection reset by peer"}
ErrConnectionAborted = &Error{"connection aborted"}
+ ErrNoLinkAddress = &Error{"no remote link address"}
)
// Errors related to Subnet
@@ -294,6 +295,27 @@
// Only supported on Unix sockets.
type PasscredOption int
+// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
+// TTL value for multicast messages. The default is 1.
+type MulticastTTLOption uint8
+
+// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to
+// AddMembershipOption and RemoveMembershipOption
+type MembershipOption struct {
+ InterfaceAddr Address
+ MulticastAddr Address
+}
+
+// AddMembershipOption is used by SetSockOpt/GetSockOpt to join a multicast
+// group identified by the given multicast address, on the interface matching
+// the given interface address.
+type AddMembershipOption MembershipOption
+
+// RemoveMembershipOption is used by SetSockOpt/GetSockOpt to leave a multicast
+// group identified by the given multicast address, on the interface matching
+// the given interface address.
+type RemoveMembershipOption MembershipOption
+
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination adddress in the row.
@@ -390,11 +412,52 @@
switch len(a) {
case 4:
return fmt.Sprintf("%d.%d.%d.%d", int(a[0]), int(a[1]), int(a[2]), int(a[3]))
+ case 16:
+ // Find the longest subsequence of hexadecimal zeros.
+ start, end := -1, -1
+ for i := 0; i < len(a); i += 2 {
+ j := i
+ for j < len(a) && a[j] == 0 && a[j+1] == 0 {
+ j += 2
+ }
+ if j > i+2 && j-i > end-start {
+ start, end = i, j
+ }
+ }
+
+ var b []byte
+ for i := 0; i < len(a); i += 2 {
+ if i == start {
+ b = append(b, "::"...)
+ i = end
+ if end >= len(a) {
+ break
+ }
+ } else if i > 0 {
+ b = append(b, ':')
+ }
+ v := uint16(a[i+0])<<8 | uint16(a[i+1])
+ b = appendHex(b, v)
+ }
+ return string(b)
default:
return fmt.Sprintf("%x", []byte(a))
}
}
+func appendHex(b []byte, v uint16) []byte {
+ if v == 0 {
+ return append(b, '0')
+ }
+ const digits = "0123456789abcdef"
+ for i := uint(3); i < 4; i-- {
+ if v := v >> (i * 4); v != 0 {
+ b = append(b, digits[v&0xf])
+ }
+ }
+ return b
+}
+
// To4 converts the IPv4 address to a 4-byte representation.
// If the address is not an IPv4 address, To4 returns "".
func (a Address) To4() Address {
@@ -424,6 +487,109 @@
return true
}
+// Parse parses the string representation of an IPv4 or IPv6 address.
+func Parse(src string) Address {
+ for i := 0; i < len(src); i++ {
+ switch src[i] {
+ case '.':
+ return parseIP4(src)
+ case ':':
+ return parseIP6(src)
+ }
+ }
+ return ""
+}
+
+func parseIP4(src string) Address {
+ var addr [4]byte
+ _, err := fmt.Sscanf(src, "%d.%d.%d.%d", &addr[0], &addr[1], &addr[2], &addr[3])
+ if err != nil {
+ return ""
+ }
+ return Address(addr[:])
+}
+
+func parseIP6(src string) (res Address) {
+ a := make([]byte, 0, 16) // cap(a) is constant throughout
+ expansion := -1 // index of '::' expansion in a
+
+ if len(src) >= 2 && src[:2] == "::" {
+ if len(src) == 2 {
+ return Address(a[:cap(a)])
+ }
+ expansion = 0
+ src = src[2:]
+ }
+
+ for len(a) < cap(a) && len(src) > 0 {
+ var x uint16
+ var ok bool
+ x, src, ok = parseHex(src)
+ if !ok {
+ return ""
+ }
+ a = append(a, uint8(x>>8), uint8(x))
+
+ if len(src) == 0 {
+ break
+ }
+
+ // Next is either ":..." or "::[...]".
+ if src[0] != ':' || len(src) == 1 {
+ return ""
+ }
+ src = src[1:]
+ if src[0] == ':' {
+ if expansion >= 0 {
+ return "" // only one expansion allowed
+ }
+ expansion = len(a)
+ src = src[1:]
+ }
+ }
+ if len(src) != 0 {
+ return ""
+ }
+
+ if missing := cap(a) - len(a); missing > 0 {
+ if expansion < 0 {
+ return ""
+ }
+ a = a[:cap(a)]
+ copy(a[expansion+missing:], a[expansion:])
+ for i := 0; i < missing; i++ {
+ a[i+expansion] = 0
+ }
+ }
+
+ return Address(a)
+}
+
+func parseHex(src string) (x uint16, remaining string, ok bool) {
+ if len(src) == 0 {
+ return 0, src, false
+ }
+loop:
+ for len(src) > 0 {
+ v := src[0]
+ switch {
+ case '0' <= v && v <= '9':
+ v = v - '0'
+ case 'a' <= v && v <= 'f':
+ v = v - 'a' + 10
+ case 'A' <= v && v <= 'F':
+ v = v - 'A' + 10
+ case v == ':':
+ break loop
+ default:
+ return 0, src, false
+ }
+ src = src[1:]
+ x = (x << 4) | uint16(v)
+ }
+ return x, src, true
+}
+
// LinkAddress is a byte slice cast as a string that represents a link address.
// It is typically a 6-byte MAC address.
type LinkAddress string
diff --git a/tcpip/tcpip_test.go b/tcpip/tcpip_test.go
index fd4d834..413ab14 100644
--- a/tcpip/tcpip_test.go
+++ b/tcpip/tcpip_test.go
@@ -5,6 +5,7 @@
package tcpip
import (
+ "strings"
"testing"
)
@@ -128,3 +129,43 @@
}
}
}
+
+func TestParse(t *testing.T) {
+ tests := []struct {
+ txt string
+ addr Address
+ }{
+ {"::", Address(strings.Repeat("\x00", 16))},
+ {"8::", Address("\x00\x08" + strings.Repeat("\x00", 14))},
+ {"::8a", Address(strings.Repeat("\x00", 14) + "\x00\x8a")},
+ {"fe80::1234:5678", "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x12\x34\x56\x78"},
+ {"fe80::b097:c9ff:fe02:477", "\xfe\x80\x00\x00\x00\x00\x00\x00\xb0\x97\xc9\xff\xfe\x02\x04\x77"},
+ {"a:b:c:d:1:2:3:4", "\x00\x0a\x00\x0b\x00\x0c\x00\x0d\x00\x01\x00\x02\x00\x03\x00\x04"},
+ {"a:b:c::2:3:4", "\x00\x0a\x00\x0b\x00\x0c\x00\x00\x00\x00\x00\x02\x00\x03\x00\x04"},
+ {"000a:000b:000c::", "\x00\x0a\x00\x0b\x00\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
+ {"0000:0000:0000::0001", Address(strings.Repeat("\x00", 15) + "\x01")},
+ {"0:0::1", Address(strings.Repeat("\x00", 15) + "\x01")},
+ }
+
+ for _, test := range tests {
+ got := Parse(test.txt)
+ if got != test.addr {
+ t.Errorf("Parse(%v)=%v, want %v", test.txt, got, test.addr)
+ }
+ }
+}
+
+func TestAddressString(t *testing.T) {
+ tests := []string{
+ "a:b:c::2:3:4",
+ "8::",
+ "fe80::5054:ff:fe12:3456",
+ "::1",
+ }
+ for _, want := range tests {
+ addr := Parse(want)
+ if got := addr.String(); got != want {
+ t.Errorf("Address(%x).String()=%q, want %q", addr, got, want)
+ }
+ }
+}
diff --git a/tcpip/transport/tcp/connect.go b/tcpip/transport/tcp/connect.go
index bfa4192..a64fe87 100644
--- a/tcpip/transport/tcp/connect.go
+++ b/tcpip/transport/tcp/connect.go
@@ -717,9 +717,9 @@
}
if n¬ifyClose != 0 && closeTimer == nil {
- // Reset the connection 3 seconds after the
+ // Reset the connection 60 seconds after the
// endpoint has been closed.
- closeTimer = time.AfterFunc(3*time.Second, func() {
+ closeTimer = time.AfterFunc(60*time.Second, func() {
closeWaker.Assert()
})
}
diff --git a/tcpip/transport/udp/endpoint.go b/tcpip/transport/udp/endpoint.go
index 48feefe..8d74fc0 100644
--- a/tcpip/transport/udp/endpoint.go
+++ b/tcpip/transport/udp/endpoint.go
@@ -53,16 +53,21 @@
rcvClosed bool
// The following fields are protected by the mu mutex.
- mu sync.RWMutex
- sndBufSize int
- id stack.TransportEndpointID
- state endpointState
- bindNICID tcpip.NICID
- bindAddr tcpip.Address
- regNICID tcpip.NICID
- route stack.Route
- dstPort uint16
- v6only bool
+ mu sync.RWMutex
+ sndBufSize int
+ id stack.TransportEndpointID
+ state endpointState
+ bindNICID tcpip.NICID
+ bindAddr tcpip.Address
+ regNICID tcpip.NICID
+ route stack.Route
+ dstPort uint16
+ v6only bool
+ multicastTTL uint8
+
+ // A list of multicast memberships that we need to remove when the endpoint
+ // is closed. Protected by the mu mutex.
+ multicastMemberships []tcpip.AddMembershipOption
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -79,6 +84,8 @@
stack: stack,
netProto: netProto,
waiterQueue: waiterQueue,
+ v6only: true,
+ multicastTTL: 1,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
}
@@ -108,6 +115,9 @@
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
+ // Shutdown the endpoint so that we notify waiters that the endpoint is closed.
+ e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+
e.mu.Lock()
defer e.mu.Unlock()
@@ -116,6 +126,11 @@
e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
}
+ for _, mem := range e.multicastMemberships {
+ e.stack.LeaveGroup(e.netProto, mem.InterfaceAddr, mem.MulticastAddr)
+ }
+ e.multicastMemberships = nil
+
// Close the receive list and drain it.
e.rcvMu.Lock()
e.rcvClosed = true
@@ -249,6 +264,10 @@
dstPort = to.Port
}
+ if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
+ route.SetTTL(e.multicastTTL)
+ }
+
sendUDP(route, v, e.id.LocalPort, dstPort)
return uintptr(len(v)), nil
}
@@ -277,6 +296,31 @@
}
e.v6only = v != 0
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.multicastTTL = uint8(v)
+ case tcpip.AddMembershipOption:
+ if err := e.stack.JoinGroup(e.netProto, v.InterfaceAddr, v.MulticastAddr); err == nil {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.multicastMemberships = append(e.multicastMemberships, v)
+ }
+ case tcpip.RemoveMembershipOption:
+ if err := e.stack.LeaveGroup(e.netProto, v.InterfaceAddr, v.MulticastAddr); err == nil {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for i, mem := range e.multicastMemberships {
+ if mem.InterfaceAddr == v.InterfaceAddr && mem.MulticastAddr == v.MulticastAddr {
+ // Only remove the first match, so that each added membership above is
+ // paired with exactly 1 removal.
+ e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1]
+ e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+ break
+ }
+ }
+ }
}
return nil
}
@@ -315,6 +359,12 @@
}
return nil
+ case *tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastTTLOption(e.multicastTTL)
+ e.mu.Unlock()
+ return nil
+
case *tcpip.ReceiveQueueSizeOption:
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -476,7 +526,9 @@
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != stateConnected {
+ // A socket in the bound state can still receive multicast messages,
+ // so we need to notify waiters on shutdown.
+ if e.state != stateBound && e.state != stateConnected {
return tcpip.ErrNotConnected
}