Merge remote-tracking branch 'upstream/master' into HEAD
Change-Id: Id3c90ecf27135f97e563c39d0b2288e3faf9dcf7
diff --git a/tcpip/header/ipv6.go b/tcpip/header/ipv6.go
index 47ed6ec..0eeb895 100644
--- a/tcpip/header/ipv6.go
+++ b/tcpip/header/ipv6.go
@@ -27,7 +27,7 @@
nextHdr = 6
hopLimit = 7
v6SrcAddr = 8
- v6DstAddr = 24
+ v6DstAddr = v6SrcAddr + IPv6AddressSize // the destination address immediately follows the source address
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
@@ -119,13 +119,13 @@
// SourceAddress returns the "source address" field of the ipv6 header.
func (b IPv6) SourceAddress() tcpip.Address {
- return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize])
+ return tcpip.Address(b[v6SrcAddr:][:IPv6AddressSize])
}
// DestinationAddress returns the "destination address" field of the ipv6
// header.
func (b IPv6) DestinationAddress() tcpip.Address {
- return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize])
+ return tcpip.Address(b[v6DstAddr:][:IPv6AddressSize])
}
// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
@@ -153,13 +153,13 @@
// SetSourceAddress sets the "source address" field of the ipv6 header.
func (b IPv6) SetSourceAddress(addr tcpip.Address) {
- copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr)
+ copy(b[v6SrcAddr:][:IPv6AddressSize], addr)
}
// SetDestinationAddress sets the "destination address" field of the ipv6
// header.
func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
- copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr)
+ copy(b[v6DstAddr:][:IPv6AddressSize], addr)
}
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
@@ -178,8 +178,8 @@
b.SetPayloadLength(i.PayloadLength)
b[nextHdr] = i.NextHeader
b[hopLimit] = i.HopLimit
- copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr)
- copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr)
+ b.SetSourceAddress(i.SrcAddr)
+ b.SetDestinationAddress(i.DstAddr)
}
// IsValid performs basic validation on the packet.
diff --git a/tcpip/link/fdbased/mmap.go b/tcpip/link/fdbased/mmap.go
index 2dca173..d1a0a7c 100644
--- a/tcpip/link/fdbased/mmap.go
+++ b/tcpip/link/fdbased/mmap.go
@@ -12,12 +12,183 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build !linux !amd64
+// +build linux,amd64 linux,arm64
package fdbased
-// Stubbed out version for non-linux/non-amd64 platforms.
+import (
+ "encoding/binary"
+ "syscall"
-func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
- return nil, nil
+ "github.com/google/netstack/tcpip"
+ "github.com/google/netstack/tcpip/buffer"
+ "github.com/google/netstack/tcpip/header"
+ "github.com/google/netstack/tcpip/link/rawfile"
+ "golang.org/x/sys/unix"
+)
+
+const (
+ tPacketAlignment = uintptr(16)
+ tpStatusKernel = 0
+ tpStatusUser = 1
+ tpStatusCopy = 2
+ tpStatusLosing = 4
+)
+
+// We overallocate the frame size to accommodate space for the
+// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
+//
+// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR, i.e.
+// 33 * 64 KiB (a little over 2 MiB), holding tpFrameNR = 32 frames.
+//
+// NOTE:
+// Frames need to be aligned at 16 byte boundaries.
+// BlockSize needs to be page aligned: the kernel rejects a PACKET_RX_RING
+// request whose block size is not a multiple of the page size. It is rounded
+// up to a multiple of tpMaxPageSize so that it is aligned both on 4 KiB page
+// kernels and on arm64 kernels built with 64 KiB pages.
+//
+// For details see PACKET_MMAP setting constraints in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+const (
+ tpMaxPageSize = 64 << 10
+ tpFrameSize = 65536 + 128
+ tpBlockSize = (tpFrameSize*32 + tpMaxPageSize - 1) &^ (tpMaxPageSize - 1)
+ tpBlockNR = 1
+ // Frames never straddle blocks, so count whole frames per block.
+ tpFrameNR = (tpBlockSize / tpFrameSize) * tpBlockNR
+)
+
+// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
+// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
+func tPacketAlign(v uintptr) uintptr {
+ return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
+}
+
+// tPacketReq is the tpacket_req structure as described in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+type tPacketReq struct {
+ tpBlockSize uint32
+ tpBlockNR uint32
+ tpFrameSize uint32
+ tpFrameNR uint32
+}
+
+// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
+type tPacketHdr []byte
+
+const (
+ tpStatusOffset = 0
+ tpLenOffset = 8
+ tpSnapLenOffset = 12
+ tpMacOffset = 16
+ tpNetOffset = 18
+ tpSecOffset = 20
+ tpUSecOffset = 24
+)
+
+func (t tPacketHdr) tpLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpLenOffset:])
+}
+
+func (t tPacketHdr) tpSnapLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
+}
+
+func (t tPacketHdr) tpMac() uint16 {
+ return binary.LittleEndian.Uint16(t[tpMacOffset:])
+}
+
+func (t tPacketHdr) tpNet() uint16 {
+ return binary.LittleEndian.Uint16(t[tpNetOffset:])
+}
+
+func (t tPacketHdr) tpSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSecOffset:])
+}
+
+func (t tPacketHdr) tpUSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpUSecOffset:])
+}
+
+func (t tPacketHdr) Payload() []byte {
+ return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
+}
+
+// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
+// See: mmap_unsafe.go for implementation details.
+type packetMMapDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // ringBuffer is only used when PacketMMap dispatcher is used and points
+ // to the start of the mmapped PACKET_RX_RING buffer.
+ ringBuffer []byte
+
+ // ringOffset is the current offset into the ring buffer where the next
+ // inbound packet will be placed by the kernel.
+ ringOffset int
+}
+
+// readMMappedPacket blocks until the kernel hands over the frame at
+// ringOffset. It then copies the frame's payload into a newly allocated
+// buffer, returns the frame to the kernel, and advances ringOffset. Frames
+// flagged tpStatusCopy (truncated) are returned to the kernel and skipped.
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
+ for {
+ hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ status := hdr.tpStatus()
+ if status&tpStatusUser == 0 {
+ // The kernel still owns this frame; wait until it is filled.
+ event := rawfile.PollEvent{
+ FD: int32(d.fd),
+ Events: unix.POLLIN | unix.POLLERR,
+ }
+ if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 && errno != syscall.EINTR {
+ return nil, rawfile.TranslateErrno(errno)
+ }
+ continue
+ }
+ if status&tpStatusCopy != 0 {
+ // This frame is truncated so skip it after flipping the
+ // buffer to the kernel. This is checked for every ready frame,
+ // including frames that were already ready before any poll.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ continue
+ }
+
+ // Copy out the packet from the mmapped frame to a locally owned buffer.
+ pkt := make([]byte, hdr.tpSnapLen())
+ copy(pkt, hdr.Payload())
+ // Release packet to kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ return pkt, nil
+ }
+}
+
+// dispatch reads packets from an mmaped ring buffer and dispatches them to the
+// network stack.
+func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
+ pkt, err := d.readMMappedPacket()
+ if err != nil {
+ return false, err
+ }
+ if len(pkt) < d.e.hdrSize {
+ // Too short to contain the link-layer header; parsing it below would
+ // index past the end of pkt. Drop the frame and keep dispatching.
+ return true, nil
+ }
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ )
+ if d.e.hdrSize > 0 {
+ eth := header.Ethernet(pkt)
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(pkt) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ pkt = pkt[d.e.hdrSize:]
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
+ return true, nil
}
diff --git a/tcpip/link/fdbased/mmap_amd64.go b/tcpip/link/fdbased/mmap_amd64.go
deleted file mode 100644
index e69e1e8..0000000
--- a/tcpip/link/fdbased/mmap_amd64.go
+++ /dev/null
@@ -1,194 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build linux,amd64
-
-package fdbased
-
-import (
- "encoding/binary"
- "syscall"
-
- "github.com/google/netstack/tcpip"
- "github.com/google/netstack/tcpip/buffer"
- "github.com/google/netstack/tcpip/header"
- "github.com/google/netstack/tcpip/link/rawfile"
- "golang.org/x/sys/unix"
-)
-
-const (
- tPacketAlignment = uintptr(16)
- tpStatusKernel = 0
- tpStatusUser = 1
- tpStatusCopy = 2
- tpStatusLosing = 4
-)
-
-// We overallocate the frame size to accommodate space for the
-// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
-//
-// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB
-//
-// NOTE:
-// Frames need to be aligned at 16 byte boundaries.
-// BlockSize needs to be page aligned.
-//
-// For details see PACKET_MMAP setting constraints in
-// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
-const (
- tpFrameSize = 65536 + 128
- tpBlockSize = tpFrameSize * 32
- tpBlockNR = 1
- tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
-)
-
-// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
-// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
-func tPacketAlign(v uintptr) uintptr {
- return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
-}
-
-// tPacketReq is the tpacket_req structure as described in
-// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
-type tPacketReq struct {
- tpBlockSize uint32
- tpBlockNR uint32
- tpFrameSize uint32
- tpFrameNR uint32
-}
-
-// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
-type tPacketHdr []byte
-
-const (
- tpStatusOffset = 0
- tpLenOffset = 8
- tpSnapLenOffset = 12
- tpMacOffset = 16
- tpNetOffset = 18
- tpSecOffset = 20
- tpUSecOffset = 24
-)
-
-func (t tPacketHdr) tpLen() uint32 {
- return binary.LittleEndian.Uint32(t[tpLenOffset:])
-}
-
-func (t tPacketHdr) tpSnapLen() uint32 {
- return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
-}
-
-func (t tPacketHdr) tpMac() uint16 {
- return binary.LittleEndian.Uint16(t[tpMacOffset:])
-}
-
-func (t tPacketHdr) tpNet() uint16 {
- return binary.LittleEndian.Uint16(t[tpNetOffset:])
-}
-
-func (t tPacketHdr) tpSec() uint32 {
- return binary.LittleEndian.Uint32(t[tpSecOffset:])
-}
-
-func (t tPacketHdr) tpUSec() uint32 {
- return binary.LittleEndian.Uint32(t[tpUSecOffset:])
-}
-
-func (t tPacketHdr) Payload() []byte {
- return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
-}
-
-// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
-// See: mmap_amd64_unsafe.go for implementation details.
-type packetMMapDispatcher struct {
- // fd is the file descriptor used to send and receive packets.
- fd int
-
- // e is the endpoint this dispatcher is attached to.
- e *endpoint
-
- // ringBuffer is only used when PacketMMap dispatcher is used and points
- // to the start of the mmapped PACKET_RX_RING buffer.
- ringBuffer []byte
-
- // ringOffset is the current offset into the ring buffer where the next
- // inbound packet will be placed by the kernel.
- ringOffset int
-}
-
-func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
- hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
- for hdr.tpStatus()&tpStatusUser == 0 {
- event := rawfile.PollEvent{
- FD: int32(d.fd),
- Events: unix.POLLIN | unix.POLLERR,
- }
- if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 {
- if errno == syscall.EINTR {
- continue
- }
- return nil, rawfile.TranslateErrno(errno)
- }
- if hdr.tpStatus()&tpStatusCopy != 0 {
- // This frame is truncated so skip it after flipping the
- // buffer to the kernel.
- hdr.setTPStatus(tpStatusKernel)
- d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
- continue
- }
- }
-
- // Copy out the packet from the mmapped frame to a locally owned buffer.
- pkt := make([]byte, hdr.tpSnapLen())
- copy(pkt, hdr.Payload())
- // Release packet to kernel.
- hdr.setTPStatus(tpStatusKernel)
- d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- return pkt, nil
-}
-
-// dispatch reads packets from an mmaped ring buffer and dispatches them to the
-// network stack.
-func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
- pkt, err := d.readMMappedPacket()
- if err != nil {
- return false, err
- }
- var (
- p tcpip.NetworkProtocolNumber
- remote, local tcpip.LinkAddress
- )
- if d.e.hdrSize > 0 {
- eth := header.Ethernet(pkt)
- p = eth.Type()
- remote = eth.SourceAddress()
- local = eth.DestinationAddress()
- } else {
- // We don't get any indication of what the packet is, so try to guess
- // if it's an IPv4 or IPv6 packet.
- switch header.IPVersion(pkt) {
- case header.IPv4Version:
- p = header.IPv4ProtocolNumber
- case header.IPv6Version:
- p = header.IPv6ProtocolNumber
- default:
- return true, nil
- }
- }
-
- pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
- return true, nil
-}
diff --git a/tcpip/link/fdbased/mmap_stub.go b/tcpip/link/fdbased/mmap_stub.go
new file mode 100644
index 0000000..67be52d
--- /dev/null
+++ b/tcpip/link/fdbased/mmap_stub.go
@@ -0,0 +1,23 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !linux !amd64,!arm64
+
+package fdbased
+
+// Stubbed out version for every platform other than linux/amd64 and
+// linux/arm64 (the build tag reads: not linux, or neither amd64 nor arm64),
+// i.e. where the PACKET_RX_RING dispatcher is not built.
+
+// newPacketMMapDispatcher always returns a nil dispatcher and a nil error here.
+// NOTE(review): confirm callers treat a nil dispatcher as "mmap dispatch
+// unavailable" rather than using it.
+func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ return nil, nil
+}
diff --git a/tcpip/link/fdbased/mmap_amd64_unsafe.go b/tcpip/link/fdbased/mmap_unsafe.go
similarity index 98%
rename from tcpip/link/fdbased/mmap_amd64_unsafe.go
rename to tcpip/link/fdbased/mmap_unsafe.go
index 47cb1d1..3894185 100644
--- a/tcpip/link/fdbased/mmap_amd64_unsafe.go
+++ b/tcpip/link/fdbased/mmap_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,amd64
+// +build linux,amd64 linux,arm64
package fdbased
diff --git a/tcpip/link/rawfile/blockingpoll_arm64.s b/tcpip/link/rawfile/blockingpoll_arm64.s
new file mode 100644
index 0000000..b62888b
--- /dev/null
+++ b/tcpip/link/rawfile/blockingpoll_arm64.s
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// BlockingPoll makes the ppoll() syscall while calling the version of
+// entersyscall that relinquishes the P so that other Gs can run. This is meant
+// to be called in cases when the syscall is expected to block.
+//
+// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno)
+TEXT ·BlockingPoll(SB),NOSPLIT,$0-40 // 3 args + 2 results, 8 bytes each
+ BL ·callEntersyscallblock(SB) // called before loading args: BL clobbers registers
+ MOVD fds+0(FP), R0 // ppoll arg 1
+ MOVD nfds+8(FP), R1 // ppoll arg 2
+ MOVD timeout+16(FP), R2 // ppoll arg 3
+ MOVD $0x0, R3 // sigmask parameter which isn't used here
+ MOVD $0x49, R8 // SYS_PPOLL
+ SVC // R0 now holds the result count or a negated errno
+ CMP $0xfffffffffffff001, R0
+ BLS ok // unsigned R0 <= 0xfffffffffffff001 is treated as success
+ MOVD $-1, R1 // failure: n = -1
+ MOVD R1, n+24(FP)
+ NEG R0, R0 // failure: err = -R0
+ MOVD R0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
+ok:
+ MOVD R0, n+24(FP)
+ MOVD $0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
diff --git a/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go b/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
new file mode 100644
index 0000000..621ab8d
--- /dev/null
+++ b/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux,!amd64,!arm64
+
+package rawfile
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// BlockingPoll is just a stub function that forwards to the ppoll() system call
+// on non-amd64 and non-arm64 platforms. Unlike the amd64/arm64 assembly
+// version, it does not call entersyscallblock; it goes through
+// syscall.Syscall6 and passes a nil sigmask as ppoll's fourth argument. It
+// returns ppoll's result and the errno reported by Syscall6 (0 on success).
+func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) {
+ n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)),
+ uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0)
+
+ return int(n), e
+}
diff --git a/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go b/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
similarity index 83%
rename from tcpip/link/rawfile/blockingpoll_amd64_unsafe.go
rename to tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index 47039a4..dda3b10 100644
--- a/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go
+++ b/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,amd64
+// +build linux,amd64 linux,arm64
// +build go1.12
// +build !go1.14
@@ -25,6 +25,12 @@
_ "unsafe" // for go:linkname
)
+// BlockingPoll on amd64/arm64 makes the ppoll() syscall while calling the
+// version of entersyscall that relinquishes the P so that other Gs can
+// run. This is meant to be called in cases when the syscall is expected to
+// block. Other linux platforms use blockingpoll_noyield_unsafe.go instead,
+// which just forwards to the ppoll() system call.
+//
+// NOTE(review): the build tags above limit this file to go1.12 and go1.13;
+// confirm BlockingPoll is provided elsewhere for linux/amd64 and linux/arm64
+// on other Go versions, since the noyield file excludes those platforms.
+//
//go:noescape
func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno)
diff --git a/tcpip/network/arp/arp.go b/tcpip/network/arp/arp.go
index d565c11..8282c90 100644
--- a/tcpip/network/arp/arp.go
+++ b/tcpip/network/arp/arp.go
@@ -112,11 +112,7 @@
copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender())
copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
- fallthrough // also fill the cache from requests
case header.ARPReply:
+ // NOTE(review): neither requests nor replies call
+ // e.linkAddrCache.AddLinkAddress here any more; confirm the link address
+ // cache is populated elsewhere, otherwise pending IPv4 resolutions would
+ // presumably only ever time out.
- addr := tcpip.Address(h.ProtocolAddressSender())
- linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
}
}
diff --git a/tcpip/network/ipv6/icmp.go b/tcpip/network/ipv6/icmp.go
index c544a55..1ba1f6f 100644
--- a/tcpip/network/ipv6/icmp.go
+++ b/tcpip/network/ipv6/icmp.go
@@ -100,13 +100,11 @@
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
-
if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:16])
+ targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
// We don't have a useful answer; the best we can do is ignore the request.
return
@@ -146,7 +144,7 @@
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:16])
+ targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
if targetAddr != r.RemoteAddress {
e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
diff --git a/tcpip/stack/linkaddrcache.go b/tcpip/stack/linkaddrcache.go
index 4cadb9e..aac5c9a 100644
--- a/tcpip/stack/linkaddrcache.go
+++ b/tcpip/stack/linkaddrcache.go
@@ -42,10 +42,11 @@
// resolved before failing.
resolutionAttempts int
- mu sync.Mutex
- cache map[tcpip.FullAddress]*linkAddrEntry
- next int // array index of next available entry
- entries [linkAddrCacheSize]linkAddrEntry
+ // cache holds the entries and is guarded by its embedded mutex. table maps
+ // each address to its entry; lru orders the same entries from most
+ // recently used (front) to least recently used (back), and the back entry
+ // is evicted and reused once table reaches linkAddrCacheSize entries.
+ cache struct {
+ sync.Mutex
+ table map[tcpip.FullAddress]*linkAddrEntry
+ lru linkAddrEntryList
+ }
}
// entryState controls the state of a single entry in the cache.
@@ -60,9 +61,6 @@
// failed means that address resolution timed out and the address
// could not be resolved.
failed
- // expired means that the cache entry has expired and the address must be
- // resolved again.
- expired
)
// String implements Stringer.
@@ -74,8 +72,6 @@
return "ready"
case failed:
return "failed"
- case expired:
- return "expired"
default:
return fmt.Sprintf("unknown(%d)", s)
}
@@ -84,119 +80,102 @@
// A linkAddrEntry is an entry in the linkAddrCache.
// This struct is thread-compatible.
type linkAddrEntry struct {
+ // Links this entry into linkAddrCache.cache.lru. NOTE(review): presumably
+ // a generated intrusive-list entry type; confirm.
+ linkAddrEntryEntry
+
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
// wakers is a set of waiters for address resolution result. Anytime
- // state transitions out of 'incomplete' these waiters are notified.
+ // state transitions out of incomplete these waiters are notified.
wakers map[*sleep.Waker]struct{}
+ // done is used to allow callers to wait on address resolution. It is
+ // non-nil only while s is incomplete and resolution is in progress; it is
+ // closed and reset to nil when the entry transitions out of incomplete.
done chan struct{}
}
-func (e *linkAddrEntry) state() entryState {
- if e.s != expired && time.Now().After(e.expiration) {
- // Force the transition to ensure waiters are notified.
- e.changeState(expired)
- }
- return e.s
-}
-
-func (e *linkAddrEntry) changeState(ns entryState) {
- if e.s == ns {
- return
- }
-
- // Validate state transition.
- switch e.s {
- case incomplete:
- // All transitions are valid.
- case ready, failed:
- if ns != expired {
- panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
- }
- case expired:
- // Terminal state.
- panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
- default:
- panic(fmt.Sprintf("invalid state: %s", e.s))
- }
-
+// changeState sets the entry's state to ns, notifying any waiters.
+//
+// The entry's expiration is bumped up to the greater of itself and the passed
+// expiration; the zero value indicates immediate expiration, and is set
+// unconditionally - this is an implementation detail that allows for entries
+// to be reused.
+func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
// Notify whoever is waiting on address resolution when transitioning
- // out of 'incomplete'.
- if e.s == incomplete {
+ // out of incomplete.
+ if e.s == incomplete && ns != incomplete {
for w := range e.wakers {
w.Assert()
}
e.wakers = nil
- if e.done != nil {
- close(e.done)
+ if ch := e.done; ch != nil {
+ close(ch)
}
+ e.done = nil
+ }
+
+ if expiration.IsZero() || expiration.After(e.expiration) {
+ e.expiration = expiration
}
e.s = ns
}
-func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) {
- if w != nil {
- e.wakers[w] = struct{}{}
- }
-}
-
func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
delete(e.wakers, w)
}
// add adds a k -> v mapping to the cache.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
- c.mu.Lock()
- defer c.mu.Unlock()
+ // Calculate expiration time before acquiring the lock, since expiration is
+ // relative to the time when information was learned, rather than when it
+ // happened to be inserted into the cache.
+ expiration := time.Now().Add(c.ageLimit)
- entry, ok := c.cache[k]
- if ok {
- s := entry.state()
- if s != expired && entry.linkAddr == v {
- // Disregard repeated calls.
- return
- }
- // Check if entry is waiting for address resolution.
- if s == incomplete {
- entry.linkAddr = v
- } else {
- // Otherwise create a new entry to replace it.
- entry = c.makeAndAddEntry(k, v)
- }
- } else {
- entry = c.makeAndAddEntry(k, v)
- }
+ c.cache.Lock()
+ entry := c.getOrCreateEntryLocked(k)
+ entry.linkAddr = v
- entry.changeState(ready)
+ entry.changeState(ready, expiration)
+ c.cache.Unlock()
}
-// makeAndAddEntry is a helper function to create and add a new
-// entry to the cache map and evict older entry as needed.
-func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry {
- // Take over the next entry.
- entry := &c.entries[c.next]
- if c.cache[entry.addr] == entry {
- delete(c.cache, entry.addr)
+// getOrCreateEntryLocked retrieves a cache entry associated with k. The
+// returned entry is always refreshed in the cache (it is reachable via the
+// map, and its place is bumped in LRU).
+//
+// If a matching entry exists in the cache, it is returned. If no matching
+// entry exists and the cache is full, an existing entry is evicted via LRU,
+// reset to state incomplete, and returned. If no matching entry exists and the
+// cache is not full, a new entry with state incomplete is allocated and
+// returned.
+func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry {
+ if entry, ok := c.cache.table[k]; ok {
+ c.cache.lru.Remove(entry)
+ c.cache.lru.PushFront(entry)
+ return entry
}
+ var entry *linkAddrEntry
+ if len(c.cache.table) == linkAddrCacheSize {
+ entry = c.cache.lru.Back()
- // Mark the soon-to-be-replaced entry as expired, just in case there is
- // someone waiting for address resolution on it.
- entry.changeState(expired)
+ delete(c.cache.table, entry.addr)
+ c.cache.lru.Remove(entry)
+
+ // Wake waiters and mark the soon-to-be-reused entry as expired. Note
+ // that the state passed doesn't matter when the zero time is passed.
+ entry.changeState(failed, time.Time{})
+ } else {
+ entry = new(linkAddrEntry)
+ }
*entry = linkAddrEntry{
- addr: k,
- linkAddr: v,
- expiration: time.Now().Add(c.ageLimit),
- wakers: make(map[*sleep.Waker]struct{}),
- done: make(chan struct{}),
+ addr: k,
+ s: incomplete,
}
-
- c.cache[k] = entry
- c.next = (c.next + 1) % len(c.entries)
+ c.cache.table[k] = entry
+ c.cache.lru.PushFront(entry)
return entry
}
@@ -208,43 +187,55 @@
}
}
- c.mu.Lock()
- defer c.mu.Unlock()
- if entry, ok := c.cache[k]; ok {
- switch s := entry.state(); s {
- case expired:
- case ready:
- return entry.linkAddr, nil, nil
- case failed:
- return "", nil, tcpip.ErrNoLinkAddress
- case incomplete:
- // Address resolution is still in progress.
- entry.maybeAddWaker(waker)
- return "", entry.done, tcpip.ErrWouldBlock
- default:
- panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ if linkRes == nil {
+ // Without a resolver a miss can never be filled in by this call, so
+ // don't create an entry for k: in a full cache that would evict a
+ // usable entry to make room for one that stays incomplete.
+ if _, ok := c.cache.table[k]; !ok {
+ return "", nil, tcpip.ErrNoLinkAddress
+ }
+ }
+ entry := c.getOrCreateEntryLocked(k)
+ switch s := entry.s; s {
+ case ready, failed:
+ if !time.Now().After(entry.expiration) {
+ // Not expired.
+ switch s {
+ case ready:
+ return entry.linkAddr, nil, nil
+ case failed:
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
}
+
+ entry.changeState(incomplete, time.Time{})
+ fallthrough
+ case incomplete:
+ if entry.done == nil {
+ // Address resolution needs to be initiated.
+ if linkRes == nil {
+ // Nothing will resolve this entry, so don't register waker;
+ // it would otherwise linger until some later state change.
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ }
+
+ entry.done = make(chan struct{})
+ go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done)
+ }
+
+ // c.cache is still held, so add/checkLinkRequest cannot move the
+ // entry out of incomplete before waker is registered.
+ if waker != nil {
+ if entry.wakers == nil {
+ entry.wakers = make(map[*sleep.Waker]struct{})
+ }
+ entry.wakers[waker] = struct{}{}
+ }
+
+ return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
-
- if linkRes == nil {
- return "", nil, tcpip.ErrNoLinkAddress
- }
-
- // Add 'incomplete' entry in the cache to mark that resolution is in progress.
- e := c.makeAndAddEntry(k, "")
- e.maybeAddWaker(waker)
-
- go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done)
-
- return "", e.done, tcpip.ErrWouldBlock
}
// removeWaker removes a waker previously added through get().
func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
- c.mu.Lock()
- defer c.mu.Unlock()
+ c.cache.Lock()
+ defer c.cache.Unlock()
- if entry, ok := c.cache[k]; ok {
+ if entry, ok := c.cache.table[k]; ok {
entry.removeWaker(waker)
}
}
@@ -256,8 +247,8 @@
linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
select {
- case <-time.After(c.resolutionTimeout):
- if stop := c.checkLinkRequest(k, i); stop {
+ case now := <-time.After(c.resolutionTimeout):
+ if stop := c.checkLinkRequest(now, k, i); stop {
return
}
case <-done:
@@ -269,38 +260,36 @@
// checkLinkRequest checks whether previous attempt to resolve address has succeeded
// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
// can stop, false if another request should be sent.
-func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- entry, ok := c.cache[k]
+//
+// now is the time the resolution timeout fired; an entry marked failed here
+// expires c.ageLimit after it.
+func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ entry, ok := c.cache.table[k]
if !ok {
// Entry was evicted from the cache.
return true
}
-
- switch s := entry.state(); s {
- case ready, failed, expired:
+ switch s := entry.s; s {
+ case ready, failed:
// Entry was made ready by resolver or failed. Either way we're done.
- return true
case incomplete:
- if attempt+1 >= c.resolutionAttempts {
- // Max number of retries reached, mark entry as failed.
- entry.changeState(failed)
- return true
+ if attempt+1 < c.resolutionAttempts {
+ // No response yet, need to send another ARP request.
+ return false
}
- // No response yet, need to send another ARP request.
- return false
+ // Max number of retries reached, mark entry as failed.
+ entry.changeState(failed, now.Add(c.ageLimit))
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
+ return true
}
func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
- return &linkAddrCache{
+ c := &linkAddrCache{
ageLimit: ageLimit,
resolutionTimeout: resolutionTimeout,
resolutionAttempts: resolutionAttempts,
- cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
}
+ c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize)
+ return c
}
diff --git a/tcpip/stack/linkaddrcache_test.go b/tcpip/stack/linkaddrcache_test.go
index 59eae3d..0966077 100644
--- a/tcpip/stack/linkaddrcache_test.go
+++ b/tcpip/stack/linkaddrcache_test.go
@@ -17,6 +17,7 @@
import (
"fmt"
"sync"
+ "sync/atomic"
"testing"
"time"
@@ -29,25 +30,34 @@
linkAddr tcpip.LinkAddress
}
-var testaddrs []testaddr
+// testAddrs holds 4*linkAddrCacheSize distinct address/link-address pairs,
+// enough to overflow the cache several times over.
+var testAddrs = func() []testaddr {
+ var addrs []testaddr
+ for i := 0; i < 4*linkAddrCacheSize; i++ {
+ addr := fmt.Sprintf("Addr%06d", i)
+ addrs = append(addrs, testaddr{
+ addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
+ linkAddr: tcpip.LinkAddress("Link" + addr),
+ })
+ }
+ return addrs
+}()
type testLinkAddressResolver struct {
- cache *linkAddrCache
- delay time.Duration
+ cache *linkAddrCache
+ delay time.Duration
+ // onLinkAddressRequest, if non-nil, is called synchronously from every
+ // LinkAddressRequest after the fake reply has been scheduled.
+ onLinkAddressRequest func()
}
func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
- go func() {
- if r.delay > 0 {
- time.Sleep(r.delay)
- }
- r.fakeRequest(addr)
- }()
+ time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
+ if f := r.onLinkAddressRequest; f != nil {
+ f()
+ }
return nil
}
func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
- for _, ta := range testaddrs {
+ for _, ta := range testAddrs {
if ta.addr.Addr == addr {
r.cache.add(ta.addr, ta.linkAddr)
break
@@ -80,20 +90,10 @@
}
}
-func init() {
- for i := 0; i < 4*linkAddrCacheSize; i++ {
- addr := fmt.Sprintf("Addr%06d", i)
- testaddrs = append(testaddrs, testaddr{
- addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
- linkAddr: tcpip.LinkAddress("Link" + addr),
- })
- }
-}
-
func TestCacheOverflow(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- for i := len(testaddrs) - 1; i >= 0; i-- {
- e := testaddrs[i]
+ for i := len(testAddrs) - 1; i >= 0; i-- {
+ e := testAddrs[i]
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
@@ -105,7 +105,7 @@
}
// Expect to find at least half of the most recent entries.
for i := 0; i < linkAddrCacheSize/2; i++ {
- e := testaddrs[i]
+ e := testAddrs[i]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
@@ -115,8 +115,8 @@
}
}
// The earliest entries should no longer be in the cache.
- for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- {
- e := testaddrs[i]
+ for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
+ e := testAddrs[i]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
}
@@ -130,7 +130,7 @@
for r := 0; r < 16; r++ {
wg.Add(1)
go func() {
- for _, e := range testaddrs {
+ for _, e := range testAddrs {
c.add(e.addr, e.linkAddr)
c.get(e.addr, nil, "", nil, nil) // make work for gotsan
}
@@ -142,7 +142,7 @@
// All goroutines add in the same order and add more values than
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
- e := testaddrs[len(testaddrs)-1]
+ e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -151,7 +151,7 @@
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
- e = testaddrs[0]
+ e = testAddrs[0]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
@@ -159,7 +159,7 @@
func TestCacheAgeLimit(t *testing.T) {
c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
- e := testaddrs[0]
+ e := testAddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
@@ -169,7 +169,7 @@
func TestCacheReplace(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- e := testaddrs[0]
+ e := testAddrs[0]
l2 := e.linkAddr + "2"
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
@@ -193,7 +193,7 @@
func TestCacheResolution(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
linkRes := &testLinkAddressResolver{cache: c}
- for i, ta := range testaddrs {
+ for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
@@ -205,7 +205,7 @@
// Check that after resolved, address stays in the cache and never returns WouldBlock.
for i := 0; i < 10; i++ {
- e := testaddrs[len(testaddrs)-1]
+ e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -220,8 +220,13 @@
c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
linkRes := &testLinkAddressResolver{cache: c}
+ var requestCount uint32
+ linkRes.onLinkAddressRequest = func() {
+ atomic.AddUint32(&requestCount, 1)
+ }
+
// First, sanity check that resolution is working...
- e := testaddrs[0]
+ e := testAddrs[0]
got, err := getBlocking(c, e.addr, linkRes)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -230,10 +235,16 @@
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
+ before := atomic.LoadUint32(&requestCount)
+
e.addr.Addr += "2"
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
+
+ if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
+ t.Errorf("got link address request count = %d, want = %d", got, want)
+ }
}
func TestCacheResolutionTimeout(t *testing.T) {
@@ -242,7 +253,7 @@
c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
- e := testaddrs[0]
+ e := testAddrs[0]
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
diff --git a/tcpip/stack/nic.go b/tcpip/stack/nic.go
index 1359fdd..c2251c2 100644
--- a/tcpip/stack/nic.go
+++ b/tcpip/stack/nic.go
@@ -531,6 +531,13 @@
return nil
}
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) {
+ r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
+ r.RemoteLinkAddress = remotelinkAddr
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+}
+
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
// the NIC receives a packet from the physical interface.
@@ -558,6 +565,8 @@
src, dst := netProto.ParseAddresses(vv.First())
+ n.stack.AddLinkAddress(n.id, src, remote)
+
// If the packet is destined to the IPv4 Broadcast address, then make a
// route to each IPv4 network endpoint and let each endpoint handle the
// packet.
@@ -566,10 +575,7 @@
n.mu.RLock()
for _, ref := range n.endpoints {
if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
- r.RemoteLinkAddress = remote
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
+ handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
}
}
n.mu.RUnlock()
@@ -577,10 +583,7 @@
}
if ref := n.getRef(protocol, dst); ref != nil {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
- r.RemoteLinkAddress = remote
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
+ handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
return
}
diff --git a/tcpip/transport/tcp/snd.go b/tcpip/transport/tcp/snd.go
index a7087fa..a2f65b7 100644
--- a/tcpip/transport/tcp/snd.go
+++ b/tcpip/transport/tcp/snd.go
@@ -664,7 +664,14 @@
segEnd = seg.sequenceNumber.Add(1)
// Transition to FIN-WAIT1 state since we're initiating an active close.
s.ep.mu.Lock()
- s.ep.state = StateFinWait1
+ switch s.ep.state {
+ case StateCloseWait:
+ // We've already received a FIN and are now sending our own. The
+ // sender is now awaiting a final ACK for this FIN.
+ s.ep.state = StateLastAck
+ default:
+ s.ep.state = StateFinWait1
+ }
s.ep.mu.Unlock()
} else {
// We're sending a non-FIN segment.