[icmp] Add support for client-side ICMP requests sent at the transport layer
Change-Id: I25fbc818dea496dbbf020b16c716c1e0b03aebed
diff --git a/tcpip/header/icmpv4.go b/tcpip/header/icmpv4.go
index dd99a14..4787097 100644
--- a/tcpip/header/icmpv4.go
+++ b/tcpip/header/icmpv4.go
@@ -58,3 +58,8 @@
func (b ICMPv4) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[2:], checksum)
}
+
+// CalculateChecksum returns the result of Checksum over every byte of b,
+// i.e. over the ICMPv4 message held in b (not the IPv4 header). prev is
+// handed to Checksum unchanged, which lets callers chain partial results.
+func (b ICMPv4) CalculateChecksum(prev uint16) uint16 {
+ return Checksum(b[:], prev)
+}
diff --git a/tcpip/network/ipv4/icmp.go b/tcpip/network/ipv4/icmp.go
index 8533436..7e3d4dc 100644
--- a/tcpip/network/ipv4/icmp.go
+++ b/tcpip/network/ipv4/icmp.go
@@ -7,6 +7,7 @@
import (
"context"
"encoding/binary"
+ "sync"
"time"
"github.com/google/netstack/tcpip"
@@ -20,10 +21,10 @@
// Use it when constructing a stack that intends to use ipv4.Ping.
const PingProtocolName = "icmpv4ping"
-// pingProtocolNumber is a fake transport protocol used to
-// deliver incoming ICMP echo replies. The ICMP identifier
+// PingProtocolNumber is a transport protocol used to
+// transmit and deliver ICMP messages. The ICMP identifier
// number is used as a port number for multiplexing.
-const pingProtocolNumber tcpip.TransportProtocolNumber = 256 + 11
+const PingProtocolNumber tcpip.TransportProtocolNumber = 1
func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) {
v := vv.First()
@@ -44,8 +45,8 @@
default:
req.r.Release()
}
- case header.ICMPv4EchoReply:
- e.dispatcher.DeliverTransportPacket(r, pingProtocolNumber, vv)
+ case header.ICMPv4EchoReply, header.ICMPv4InfoReply, header.ICMPv4TimestampReply:
+ e.dispatcher.DeliverTransportPacket(r, PingProtocolNumber, vv)
}
// TODO(crawshaw): Handle other ICMP types.
}
@@ -83,6 +84,22 @@
Count uint16 // if zero, defaults to MaxUint16
}
+// pingerEndpoint is the transport endpoint that Pinger.Ping registers with
+// the stack under PingProtocolNumber to receive packets for its identifier.
+type pingerEndpoint struct {
+ stack *stack.Stack
+ // pktCh carries the result of vv.ToView() for each packet HandlePacket
+ // accepts.
+ pktCh chan buffer.View
+}
+
+// Close closes pktCh.
+// NOTE(review): HandlePacket sends on pktCh, and a send on a closed channel
+// panics — confirm the stack can no longer deliver to this endpoint by the
+// time Close is called.
+func (e *pingerEndpoint) Close() {
+ close(e.pktCh)
+}
+
+// HandlePacket offers vv.ToView() to pktCh without blocking; when the
+// channel cannot accept it the packet is silently dropped. r and id are not
+// used.
+func (e *pingerEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
+ select {
+ case e.pktCh <- vv.ToView():
+ default:
+ }
+}
+
// Ping sends echo requests to an ICMPv4 endpoint.
// Responses are streamed to the channel ch.
func (p *Pinger) Ping(ctx context.Context, ch chan<- PingReply) *tcpip.Error {
@@ -101,7 +118,7 @@
}
netProtos := []tcpip.NetworkProtocolNumber{ProtocolNumber}
- ep := &pingEndpoint{
+ ep := &pingerEndpoint{
stack: p.Stack,
pktCh: make(chan buffer.View, 1),
}
@@ -112,7 +129,7 @@
_, err = p.Stack.PickEphemeralPort(func(port uint16) (bool, *tcpip.Error) {
id.LocalPort = port
- err := p.Stack.RegisterTransportEndpoint(p.NICID, netProtos, pingProtocolNumber, id, ep)
+ err := p.Stack.RegisterTransportEndpoint(p.NICID, netProtos, PingProtocolNumber, id, ep)
switch err {
case nil:
return true, nil
@@ -125,7 +142,7 @@
if err != nil {
return err
}
- defer p.Stack.UnregisterTransportEndpoint(p.NICID, netProtos, pingProtocolNumber, id)
+ defer p.Stack.UnregisterTransportEndpoint(p.NICID, netProtos, PingProtocolNumber, id)
v := buffer.NewView(4)
binary.BigEndian.PutUint16(v[0:], id.LocalPort)
@@ -179,23 +196,101 @@
SeqNumber uint16
}
-type pingProtocol struct{}
+// endpointState records where a pingEndpoint is in its lifecycle.
+type endpointState int
-func (*pingProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return nil, tcpip.ErrNotSupported // endpoints are created directly
+const (
+ // stateInitial is the zero value: the endpoint has not been bound yet.
+ stateInitial endpointState = iota
+ // stateConnected is set by bindLocked once the endpoint holds a route and
+ // has been registered with the stack.
+ stateConnected
+ // stateClosed is set by Close; Write refuses to send in this state.
+ stateClosed
+)
+
+// pingEndpoint is the endpoint type returned by pingProtocol.NewEndpoint. It
+// sends caller-supplied ICMPv4 messages through Write and queues packets
+// delivered by the stack for Read.
+type pingEndpoint struct {
+ stack *stack.Stack
+ netProto tcpip.NetworkProtocolNumber
+ // waiterQueue is notified with waiter.EventIn when HandlePacket queues a
+ // packet.
+ waiterQueue *waiter.Queue
+
+ // mu is held by Close, Write and Bind around their accesses to state,
+ // route, nic and id. Read and HandlePacket use pktCh without taking it.
+ mu sync.RWMutex
+ // pktCh queues received packets; Close closes it.
+ pktCh chan buffer.View
+ state endpointState
+ // route, nic and id are filled in by bindLocked.
+ route stack.Route
+ nic tcpip.NICID
+ id stack.TransportEndpointID
+}
-func (*pingProtocol) Number() tcpip.TransportProtocolNumber { return pingProtocolNumber }
-
-func (*pingProtocol) MinimumPacketSize() int { return header.ICMPv4EchoMinimumSize }
-
-func (*pingProtocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
- ident := binary.BigEndian.Uint16(v[4:])
- return 0, ident, nil
+// Close marks the endpoint closed and closes pktCh. It does nothing when the
+// endpoint is already closed. A connected endpoint is first unregistered
+// from the stack and its route is released.
+// NOTE(review): HandlePacket sends on pktCh without taking mu, and a send on
+// a closed channel panics — confirm the stack cannot still be delivering a
+// packet to this endpoint when pktCh is closed here.
+func (e *pingEndpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if e.state == stateClosed {
+ return
+ }
+ if e.state == stateConnected {
+ netProtos := []tcpip.NetworkProtocolNumber{ProtocolNumber}
+ e.stack.UnregisterTransportEndpoint(e.nic, netProtos, PingProtocolNumber, e.id)
+ e.route.Release()
+ }
+ close(e.pktCh)
+ e.state = stateClosed
+}
-func (*pingProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
- return true
+// Read returns the next queued packet without blocking.
+//
+// It returns tcpip.ErrWouldBlock when nothing is queued, and
+// tcpip.ErrClosedForReceive once the endpoint has been closed and its queue
+// has been drained. The address argument is ignored; the sender is not
+// reported.
+func (e *pingEndpoint) Read(a *tcpip.FullAddress) (buffer.View, *tcpip.Error) {
+	select {
+	case v, ok := <-e.pktCh:
+		if !ok {
+			// Close closed pktCh. A receive on a closed channel succeeds
+			// immediately with the zero value, which must not be handed
+			// back as if it were a packet.
+			return buffer.View{}, tcpip.ErrClosedForReceive
+		}
+		return v, nil
+	default:
+		return buffer.View{}, tcpip.ErrWouldBlock
+	}
+}
+
+func (e *pingEndpoint) Write(v buffer.View, to *tcpip.FullAddress) (uintptr, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ switch state := e.state; state {
+ case stateInitial:
+ if to == nil {
+ return 0, tcpip.ErrNotSupported
+ } else if err := e.bindLocked(*to, nil); err != nil {
+ return 0, err
+ }
+ case stateConnected:
+ if to != nil {
+ prev := tcpip.FullAddress{
+ NIC: e.nic,
+ Addr: e.id.RemoteAddress,
+ Port: e.id.RemotePort,
+ }
+
+ if prev != *to {
+ return 0, tcpip.ErrAlreadyConnected
+ }
+ }
+ default:
+ return 0, tcpip.ErrClosedForSend
+ }
+
+ if len(v) < header.ICMPv4MinimumSize {
+ return 0, tcpip.ErrNotSupported
+ }
+
+ hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(e.route.MaxHeaderLength()))
+ icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ copy(icmpv4, v[:header.ICMPv4MinimumSize])
+ icmpv4.SetCode(0)
+ data := v[header.ICMPv4MinimumSize:]
+ // Overwrite the ID with the port number
+ binary.BigEndian.PutUint16(data[0:], e.id.LocalPort)
+ // Overwrite the checksum of the packet
+ icmpv4.SetChecksum(0)
+ chksum := header.ICMPv4(data).CalculateChecksum(icmpv4.CalculateChecksum(0))
+ icmpv4.SetChecksum(^chksum)
+
+ if err := e.route.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber); err != nil {
+ return 0, err
+ }
+ return uintptr(len(v)), nil
+}
+
+// Peek is not supported; it always returns tcpip.ErrNotSupported.
+func (e *pingEndpoint) Peek(data [][]byte) (uintptr, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
// SetOption implements TransportProtocol.SetOption.
@@ -209,18 +304,120 @@
})
}
-type pingEndpoint struct {
- stack *stack.Stack
- pktCh chan buffer.View
+// Connect is not supported; it always returns tcpip.ErrNotSupported. The
+// destination is chosen through Bind or through the address given to Write.
+func (e *pingEndpoint) Connect(address tcpip.FullAddress) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
-func (e *pingEndpoint) Close() {
- close(e.pktCh)
+// Shutdown ignores flags and closes the endpoint entirely by calling Close.
+// It always returns nil.
+func (e *pingEndpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.Close()
+ return nil
+}
+
+// Listen is not supported; it always returns tcpip.ErrNotSupported.
+func (e *pingEndpoint) Listen(backlog int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept is not supported; it always returns tcpip.ErrNotSupported.
+func (e *pingEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+// Bind takes mu and delegates to bindLocked. Note that bindLocked uses
+// addr.Addr as the remote address of the route and of the endpoint ID, so
+// addr names the destination rather than a local address.
+func (e *pingEndpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.bindLocked(addr, commit)
+}
+
+// bindLocked attaches the endpoint to the destination to. The caller must
+// hold e.mu.
+//
+// It finds a route to to.Addr on to.NIC, then registers the endpoint with the
+// stack under an ephemeral port; that port doubles as the ICMP identifier
+// used to demultiplex replies. If commit is non-nil it is called after
+// registration, and an error from it undoes the registration.
+//
+// It returns tcpip.ErrAlreadyConnected unless the endpoint is in its initial
+// state, and otherwise any error from route lookup, port selection or
+// commit. On every error path the route is released and the endpoint is left
+// in its initial state.
+func (e *pingEndpoint) bindLocked(to tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+	if e.state != stateInitial {
+		return tcpip.ErrAlreadyConnected
+	}
+	r, err := e.stack.FindRoute(to.NIC, "", to.Addr, e.netProto)
+	if err != nil {
+		return err
+	}
+
+	netProtos := []tcpip.NetworkProtocolNumber{e.netProto}
+	id := stack.TransportEndpointID{
+		LocalAddress:  r.LocalAddress,
+		RemoteAddress: to.Addr,
+	}
+
+	if _, err := e.stack.PickEphemeralPort(func(port uint16) (bool, *tcpip.Error) {
+		id.LocalPort = port
+		switch regErr := e.stack.RegisterTransportEndpoint(to.NIC, netProtos, PingProtocolNumber, id, e); regErr {
+		case nil:
+			return true, nil
+		case tcpip.ErrPortInUse:
+			// Another endpoint owns this identifier; try the next port.
+			return false, nil
+		default:
+			return false, regErr
+		}
+	}); err != nil {
+		// No port could be registered. Without this check the endpoint
+		// would be marked connected while unknown to the stack, and the
+		// route would never be released.
+		r.Release()
+		return err
+	}
+
+	if commit != nil {
+		if err := commit(); err != nil {
+			e.stack.UnregisterTransportEndpoint(to.NIC, netProtos, PingProtocolNumber, id)
+			r.Release()
+			return err
+		}
+	}
+
+	e.state = stateConnected
+	e.route = r
+	e.nic = to.NIC
+	e.id = id
+	return nil
+}
+
+// GetLocalAddress is not supported; it always returns
+// tcpip.ErrNotSupported.
+func (e *pingEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+// GetRemoteAddress is not supported; it always returns
+// tcpip.ErrNotSupported.
+func (e *pingEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+func (e *pingEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return 0
+}
+
+// SetSockOpt is not supported; it always returns tcpip.ErrNotSupported.
+func (e *pingEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// GetSockOpt is not supported; it always returns tcpip.ErrNotSupported.
+func (e *pingEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+// HandlePacket offers vv.ToView() to pktCh without blocking. When the packet
+// is queued, waiterQueue is notified with waiter.EventIn; when pktCh cannot
+// accept it, the packet is dropped and nobody is notified. r and id are not
+// used.
func (e *pingEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
select {
case e.pktCh <- vv.ToView():
+ e.waiterQueue.Notify(waiter.EventIn)
default:
}
}
+
+// pingProtocol is the stateless type whose methods create ping endpoints and
+// extract the identifier from incoming packets.
+type pingProtocol struct{}
+
+// NewEndpoint returns a new, unbound pingEndpoint for netProto whose receive
+// queue holds up to 10 packets. It never returns an error.
+func (p *pingProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return &pingEndpoint{
+ stack: stack,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ pktCh: make(chan buffer.View, 10),
+ }, nil
+}
+
+// Number returns PingProtocolNumber.
+func (*pingProtocol) Number() tcpip.TransportProtocolNumber { return PingProtocolNumber }
+
+// MinimumPacketSize returns header.ICMPv4EchoMinimumSize.
+func (*pingProtocol) MinimumPacketSize() int { return header.ICMPv4EchoMinimumSize }
+
+// ParsePorts reads the big-endian 16-bit value at offset 4 of v and returns
+// it as the destination port; the source port is always 0 and the error is
+// always nil.
+// NOTE(review): v must be at least 6 bytes long or the read panics —
+// presumably callers enforce MinimumPacketSize first; confirm.
+func (*pingProtocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ ident := binary.BigEndian.Uint16(v[4:])
+ return 0, ident, nil
+}
+
+// HandleUnknownDestinationPacket ignores its arguments and returns true.
+// NOTE(review): presumably true tells the stack the packet needs no further
+// handling — confirm against the stack's use of the return value.
+func (*pingProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
+ return true
+}
diff --git a/tcpip/network/ipv4/icmp_test.go b/tcpip/network/ipv4/icmp_test.go
index 4eb5818..3a04f60 100644
--- a/tcpip/network/ipv4/icmp_test.go
+++ b/tcpip/network/ipv4/icmp_test.go
@@ -6,15 +6,18 @@
import (
"context"
+ "encoding/binary"
"testing"
"time"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
+ "github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/link/channel"
"github.com/google/netstack/tcpip/link/sniffer"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/stack"
+ "github.com/google/netstack/waiter"
)
const stackAddr = "\x0a\x00\x00\x01"
@@ -122,3 +125,212 @@
}
}
}
+
+// Addresses used by the two-NIC endpoint test. newTestEndpointContext
+// assigns stackAddr0 (10.0.0.2) to NIC 2 and stackAddr1 (10.0.0.3) to NIC 1.
+const (
+ stackAddr0 = "\x0a\x00\x00\x02"
+ stackAddr1 = "\x0a\x00\x00\x03"
+ linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+)
+
+// testEndpointContext bundles a stack with two channel-backed NICs whose
+// outgoing packets are forwarded to each other by routePackets.
+type testEndpointContext struct {
+ t *testing.T
+ s *stack.Stack
+
+ // linkEP0 backs NIC 1 and linkEP1 backs NIC 2.
+ linkEP0 *channel.Endpoint
+ linkEP1 *channel.Endpoint
+
+ // icmpCh receives the ICMPv4 portion of every packet seen by countPacket.
+ icmpCh chan header.ICMPv4
+}
+
+// cleanup closes both link endpoints' packet channels, which ends the range
+// loops of the routePackets goroutines reading from them.
+// NOTE(review): anything that still writes to these channels after cleanup
+// would panic on the closed channel — confirm the stack is idle by then.
+func (c *testEndpointContext) cleanup() {
+ close(c.linkEP0.C)
+ close(c.linkEP1.C)
+}
+
+// newTestEndpointContext builds a single stack with two channel-backed NICs
+// and starts two goroutines that forward every packet leaving one NIC into
+// the other.
+//
+// stackAddr0 is assigned to NIC 2 and stackAddr1 to NIC 1, while the route
+// table sends traffic for stackAddr0 out of NIC 1 and traffic for stackAddr1
+// out of NIC 2, so a packet addressed to either address crosses the
+// forwarding goroutines before reaching the NIC that owns the address.
+// When the test runs verbosely each link endpoint is wrapped in a sniffer.
+func newTestEndpointContext(t *testing.T) *testEndpointContext {
+ c := &testEndpointContext{
+ t: t,
+ s: stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName}),
+ icmpCh: make(chan header.ICMPv4, 10),
+ }
+
+ const defaultMTU = 65536
+ id0, linkEP := channel.New(256, defaultMTU, linkAddr0)
+ c.linkEP0 = linkEP
+ if testing.Verbose() {
+ id0 = sniffer.New(id0)
+ }
+ if err := c.s.CreateNIC(1, id0); err != nil {
+ t.Fatalf("CreateNIC s: %v", err)
+ }
+ id1, linkEP := channel.New(256, defaultMTU, linkAddr1)
+ c.linkEP1 = linkEP
+ if testing.Verbose() {
+ id1 = sniffer.New(id1)
+ }
+ if err := c.s.CreateNIC(2, id1); err != nil {
+ t.Fatalf("CreateNIC s: %v", err)
+ }
+ if err := c.s.AddAddress(2, ipv4.ProtocolNumber, stackAddr0); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+ if err := c.s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+ c.s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: stackAddr0,
+ Mask: "\xFF\xFF\xFF\xFF",
+ Gateway: "",
+ NIC: 1,
+ },
+ {
+ Destination: stackAddr1,
+ Mask: "\xFF\xFF\xFF\xFF",
+ Gateway: "",
+ NIC: 2,
+ },
+ })
+
+ // Cross-wire the NICs: whatever leaves one is injected into the other.
+ go c.routePackets(c.linkEP0.C, c.linkEP1)
+ go c.routePackets(c.linkEP1.C, c.linkEP0)
+ return c
+}
+
+// countPacket checks that pkt is an IPv4 packet and publishes everything
+// after its IPv4 header (the ICMPv4 message) on icmpCh.
+//
+// It runs on the routePackets goroutines, so a failure is reported with
+// Errorf and a plain return: Fatalf may only be called from the goroutine
+// running the test function.
+func (c *testEndpointContext) countPacket(pkt channel.PacketInfo) {
+	if pkt.Proto != header.IPv4ProtocolNumber {
+		c.t.Errorf("Received non IPV4 packet: 0x%x", pkt.Proto)
+		return
+	}
+	// Named ip rather than ipv4 so the imported ipv4 package is not shadowed.
+	ip := header.IPv4(pkt.Header)
+	c.icmpCh <- header.ICMPv4(append(pkt.Header[ip.HeaderLength():], pkt.Payload...))
+}
+
+// routePackets forwards every packet read from ch into ep. Each packet is
+// first passed to countPacket, then its header and payload are joined into
+// a single view and injected into ep, with ep's own link address as the
+// address argument. It returns when ch is closed.
+func (c *testEndpointContext) routePackets(ch <-chan channel.PacketInfo, ep *channel.Endpoint) {
+ for pkt := range ch {
+ c.countPacket(pkt)
+ v := buffer.View(append(pkt.Header, pkt.Payload...))
+ vs := []buffer.View{v}
+ vv := buffer.NewVectorisedView(len(v), vs)
+ ep.InjectLinkAddr(pkt.Proto, ep.LinkAddress(), &vv)
+ }
+}
+
+// callbackStub wraps a plain function so it can be used as the Callback of a
+// waiter.Entry.
+type callbackStub struct {
+ f func(e *waiter.Entry)
+}
+
+// Callback invokes the wrapped function with e.
+func (c *callbackStub) Callback(e *waiter.Entry) {
+ c.f(e)
+}
+
+func TestEndpoints(t *testing.T) {
+ c := newTestEndpointContext(t)
+ defer c.cleanup()
+
+ wq0 := &waiter.Queue{}
+ ep0, err := c.s.NewEndpoint(ipv4.PingProtocolNumber, ipv4.ProtocolNumber, wq0)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep0.Close()
+ wq1 := &waiter.Queue{}
+ ep1, err := c.s.NewEndpoint(ipv4.PingProtocolNumber, ipv4.ProtocolNumber, wq1)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep1.Close()
+
+ if err := ep0.Bind(tcpip.FullAddress{NIC: 1, Addr: stackAddr0}, nil); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+ if err := ep1.Bind(tcpip.FullAddress{NIC: 2, Addr: stackAddr1}, nil); err != nil {
+ c.t.Fatalf("Bind failed: %v", err)
+ }
+
+ echos := 64
+
+ ping := func(wq *waiter.Queue, ep tcpip.Endpoint, data []byte) {
+ outPkt := make([]byte, header.ICMPv4MinimumSize+4+len(data))
+ icmpv4 := header.ICMPv4(outPkt[:header.ICMPv4MinimumSize])
+ icmpv4.SetType(header.ICMPv4Echo)
+ copy(outPkt[header.ICMPv4MinimumSize+4:], data)
+ binary.BigEndian.PutUint16(outPkt[header.ICMPv4MinimumSize:], 0)
+
+ for seqno := uint16(1); seqno <= uint16(echos); seqno++ {
+ binary.BigEndian.PutUint16(outPkt[header.ICMPv4MinimumSize+2:], seqno)
+
+ // We need to register with the waiter queue before we try writing, since
+ // the notification that the endpoint received a response may arrive immediately.
+ ready := make(chan struct{})
+ e := waiter.Entry{Callback: &callbackStub{func(*waiter.Entry) { close(ready) }}}
+ wq.EventRegister(&e, waiter.EventIn)
+ n, err := ep.Write(buffer.View(outPkt), nil)
+ if err != nil {
+ c.t.Fatalf("Write failed: %v\n", err)
+ } else if n != uintptr(len(outPkt)) {
+ c.t.Fatalf("Write was short: %v\n", n)
+ }
+
+ // Avoid reading until we have something to read
+ select {
+ case <-time.After(1 * time.Second):
+ c.t.Fatalf("Timed out waiting for socket to be readable")
+ case <-ready:
+ }
+ wq.EventUnregister(&e)
+ inPkt, err := ep.Read(nil)
+ if err != nil {
+ c.t.Fatalf("Read failed: %v\n", err)
+ }
+
+ // Verify the contents of the packet we just read.
+ var icmp header.ICMPv4 = []byte(inPkt)
+ if icmp.Type() != header.ICMPv4EchoReply {
+ c.t.Fatalf("Unexpected packet type: %d", icmp.Type())
+ }
+ inSeqno := binary.BigEndian.Uint16(inPkt[header.ICMPv4MinimumSize+2 : header.ICMPv4MinimumSize+4])
+ if inSeqno != seqno {
+ c.t.Fatalf("Unexpected sequence number: %d", inSeqno)
+ }
+ outData := outPkt[header.ICMPv4EchoMinimumSize:]
+ inData := inPkt[header.ICMPv4EchoMinimumSize:]
+ if len(outData) != len(inData) {
+ c.t.Fatalf("Read packet of unexpected length: %d\n", len(inData))
+ }
+ for i := range inData {
+ if inData[i] != outData[i] {
+ c.t.Fatalf("Data mismatch")
+ }
+ }
+ }
+ }
+
+ data := []byte{0xaa, 0xab, 0xac}
+ go ping(wq0, ep0, data)
+ data = []byte{0xad, 0xae, 0xaf}
+ go ping(wq1, ep1, data)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ stats := make(map[header.ICMPv4Type]int)
+ for {
+ select {
+ case <-ctx.Done():
+ t.Errorf("Timeout waiting for ICMP, got: %#+v", stats)
+ return
+ case icmp := <-c.icmpCh:
+ if icmp.Type() != header.ICMPv4Echo && icmp.Type() != header.ICMPv4EchoReply {
+ c.t.Fatalf("Unexpected type: %d", icmp.Type())
+ }
+ stats[icmp.Type()]++
+ if stats[icmp.Type()] > echos*2 {
+ c.t.Fatalf("Too many (%d) packets of type %d", stats[icmp.Type()], icmp.Type())
+ }
+ if len(stats) == 2 && stats[header.ICMPv4Echo] == echos*2 && stats[header.ICMPv4EchoReply] == echos*2 {
+ return
+ }
+ }
+ }
+}