[icmp] Add support for client-side ICMP requests sent at the transport layer

Change-Id: I25fbc818dea496dbbf020b16c716c1e0b03aebed
diff --git a/tcpip/header/icmpv4.go b/tcpip/header/icmpv4.go
index dd99a14..4787097 100644
--- a/tcpip/header/icmpv4.go
+++ b/tcpip/header/icmpv4.go
@@ -58,3 +58,8 @@
 func (b ICMPv4) SetChecksum(checksum uint16) {
 	binary.BigEndian.PutUint16(b[2:], checksum)
 }
+
+// CalculateChecksum calculates the checksum of the ipv4 header.
+func (b ICMPv4) CalculateChecksum(prev uint16) uint16 {
+	return Checksum(b[:], prev)
+}
diff --git a/tcpip/network/ipv4/icmp.go b/tcpip/network/ipv4/icmp.go
index 8533436..7e3d4dc 100644
--- a/tcpip/network/ipv4/icmp.go
+++ b/tcpip/network/ipv4/icmp.go
@@ -7,6 +7,7 @@
 import (
 	"context"
 	"encoding/binary"
+	"sync"
 	"time"
 
 	"github.com/google/netstack/tcpip"
@@ -20,10 +21,10 @@
 // Use it when constructing a stack that intends to use ipv4.Ping.
 const PingProtocolName = "icmpv4ping"
 
-// pingProtocolNumber is a fake transport protocol used to
-// deliver incoming ICMP echo replies. The ICMP identifier
+// PingProtocolNumber is a transport protocol used to
+// transmit and deliver ICMP messages. The ICMP identifier
 // number is used as a port number for multiplexing.
-const pingProtocolNumber tcpip.TransportProtocolNumber = 256 + 11
+const PingProtocolNumber tcpip.TransportProtocolNumber = 1
 
 func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) {
 	v := vv.First()
@@ -44,8 +45,8 @@
 		default:
 			req.r.Release()
 		}
-	case header.ICMPv4EchoReply:
-		e.dispatcher.DeliverTransportPacket(r, pingProtocolNumber, vv)
+	case header.ICMPv4EchoReply, header.ICMPv4InfoReply, header.ICMPv4TimestampReply:
+		e.dispatcher.DeliverTransportPacket(r, PingProtocolNumber, vv)
 	}
 	// TODO(crawshaw): Handle other ICMP types.
 }
@@ -83,6 +84,22 @@
 	Count     uint16        // if zero, defaults to MaxUint16
 }
 
+type pingerEndpoint struct {
+	stack *stack.Stack
+	pktCh chan buffer.View
+}
+
+func (e *pingerEndpoint) Close() {
+	close(e.pktCh)
+}
+
+func (e *pingerEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
+	select {
+	case e.pktCh <- vv.ToView():
+	default:
+	}
+}
+
 // Ping sends echo requests to an ICMPv4 endpoint.
 // Responses are streamed to the channel ch.
 func (p *Pinger) Ping(ctx context.Context, ch chan<- PingReply) *tcpip.Error {
@@ -101,7 +118,7 @@
 	}
 
 	netProtos := []tcpip.NetworkProtocolNumber{ProtocolNumber}
-	ep := &pingEndpoint{
+	ep := &pingerEndpoint{
 		stack: p.Stack,
 		pktCh: make(chan buffer.View, 1),
 	}
@@ -112,7 +129,7 @@
 
 	_, err = p.Stack.PickEphemeralPort(func(port uint16) (bool, *tcpip.Error) {
 		id.LocalPort = port
-		err := p.Stack.RegisterTransportEndpoint(p.NICID, netProtos, pingProtocolNumber, id, ep)
+		err := p.Stack.RegisterTransportEndpoint(p.NICID, netProtos, PingProtocolNumber, id, ep)
 		switch err {
 		case nil:
 			return true, nil
@@ -125,7 +142,7 @@
 	if err != nil {
 		return err
 	}
-	defer p.Stack.UnregisterTransportEndpoint(p.NICID, netProtos, pingProtocolNumber, id)
+	defer p.Stack.UnregisterTransportEndpoint(p.NICID, netProtos, PingProtocolNumber, id)
 
 	v := buffer.NewView(4)
 	binary.BigEndian.PutUint16(v[0:], id.LocalPort)
@@ -179,23 +196,101 @@
 	SeqNumber uint16
 }
 
-type pingProtocol struct{}
+type endpointState int
 
-func (*pingProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
-	return nil, tcpip.ErrNotSupported // endpoints are created directly
+const (
+	stateInitial endpointState = iota
+	stateConnected
+	stateClosed
+)
+
+type pingEndpoint struct {
+	stack       *stack.Stack
+	netProto    tcpip.NetworkProtocolNumber
+	waiterQueue *waiter.Queue
+
+	mu    sync.RWMutex
+	pktCh chan buffer.View
+	state endpointState
+	route stack.Route
+	nic   tcpip.NICID
+	id    stack.TransportEndpointID
 }
 
-func (*pingProtocol) Number() tcpip.TransportProtocolNumber { return pingProtocolNumber }
-
-func (*pingProtocol) MinimumPacketSize() int { return header.ICMPv4EchoMinimumSize }
-
-func (*pingProtocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
-	ident := binary.BigEndian.Uint16(v[4:])
-	return 0, ident, nil
+func (e *pingEndpoint) Close() {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	if e.state == stateClosed {
+		return
+	}
+	if e.state == stateConnected {
+		netProtos := []tcpip.NetworkProtocolNumber{ProtocolNumber}
+		e.stack.UnregisterTransportEndpoint(e.nic, netProtos, PingProtocolNumber, e.id)
+		e.route.Release()
+	}
+	close(e.pktCh)
+	e.state = stateClosed
 }
 
-func (*pingProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
-	return true
+func (e *pingEndpoint) Read(a *tcpip.FullAddress) (buffer.View, *tcpip.Error) {
+	select {
+	case v := <-e.pktCh:
+		return v, nil
+	default:
+		return buffer.View{}, tcpip.ErrWouldBlock
+	}
+}
+
+func (e *pingEndpoint) Write(v buffer.View, to *tcpip.FullAddress) (uintptr, *tcpip.Error) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	switch state := e.state; state {
+	case stateInitial:
+		if to == nil {
+			return 0, tcpip.ErrNotSupported
+		} else if err := e.bindLocked(*to, nil); err != nil {
+			return 0, err
+		}
+	case stateConnected:
+		if to != nil {
+			prev := tcpip.FullAddress{
+				NIC:  e.nic,
+				Addr: e.id.RemoteAddress,
+				Port: e.id.RemotePort,
+			}
+
+			if prev != *to {
+				return 0, tcpip.ErrAlreadyConnected
+			}
+		}
+	default:
+		return 0, tcpip.ErrClosedForSend
+	}
+
+	if len(v) < header.ICMPv4MinimumSize {
+		return 0, tcpip.ErrNotSupported
+	}
+
+	hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(e.route.MaxHeaderLength()))
+	icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+	copy(icmpv4, v[:header.ICMPv4MinimumSize])
+	icmpv4.SetCode(0)
+	data := v[header.ICMPv4MinimumSize:]
+	// Overwrite the ID with the port number
+	binary.BigEndian.PutUint16(data[0:], e.id.LocalPort)
+	// Overwrite the checksum of the packet
+	icmpv4.SetChecksum(0)
+	chksum := header.ICMPv4(data).CalculateChecksum(icmpv4.CalculateChecksum(0))
+	icmpv4.SetChecksum(^chksum)
+
+	if err := e.route.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber); err != nil {
+		return 0, err
+	}
+	return uintptr(len(v)), nil
+}
+
+func (e *pingEndpoint) Peek(data [][]byte) (uintptr, *tcpip.Error) {
+	return 0, tcpip.ErrNotSupported
 }
 
 // SetOption implements TransportProtocol.SetOption.
@@ -209,18 +304,120 @@
 	})
 }
 
-type pingEndpoint struct {
-	stack *stack.Stack
-	pktCh chan buffer.View
+func (e *pingEndpoint) Connect(address tcpip.FullAddress) *tcpip.Error {
+	return tcpip.ErrNotSupported
 }
 
-func (e *pingEndpoint) Close() {
-	close(e.pktCh)
+func (e *pingEndpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+	e.Close()
+	return nil
+}
+
+func (e *pingEndpoint) Listen(backlog int) *tcpip.Error {
+	return tcpip.ErrNotSupported
+}
+
+func (e *pingEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+	return nil, nil, tcpip.ErrNotSupported
+}
+
+func (e *pingEndpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	return e.bindLocked(addr, commit)
+}
+
+func (e *pingEndpoint) bindLocked(to tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+	if e.state != stateInitial {
+		return tcpip.ErrAlreadyConnected
+	}
+	r, err := e.stack.FindRoute(to.NIC, "", to.Addr, e.netProto)
+	if err != nil {
+		return err
+	}
+
+	netProtos := []tcpip.NetworkProtocolNumber{e.netProto}
+	id := stack.TransportEndpointID{
+		LocalAddress:  r.LocalAddress,
+		RemoteAddress: to.Addr,
+	}
+
+	_, err = e.stack.PickEphemeralPort(func(port uint16) (bool, *tcpip.Error) {
+		id.LocalPort = port
+		err := e.stack.RegisterTransportEndpoint(to.NIC, netProtos, PingProtocolNumber, id, e)
+		switch err {
+		case nil:
+			return true, nil
+		case tcpip.ErrPortInUse:
+			return false, nil
+		default:
+			return false, err
+		}
+	})
+
+	if commit != nil {
+		if err := commit(); err != nil {
+			e.stack.UnregisterTransportEndpoint(to.NIC, netProtos, PingProtocolNumber, id)
+			r.Release()
+			return err
+		}
+	}
+
+	e.state = stateConnected
+	e.route = r
+	e.nic = to.NIC
+	e.id = id
+	return nil
+}
+
+func (e *pingEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+	return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+func (e *pingEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+	return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+func (e *pingEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+	return 0
+}
+
+func (e *pingEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+	return tcpip.ErrNotSupported
+}
+
+func (e *pingEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+	return tcpip.ErrNotSupported
 }
 
 func (e *pingEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
 	select {
 	case e.pktCh <- vv.ToView():
+		e.waiterQueue.Notify(waiter.EventIn)
 	default:
 	}
 }
+
+type pingProtocol struct{}
+
+func (p *pingProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+	return &pingEndpoint{
+		stack:       stack,
+		netProto:    netProto,
+		waiterQueue: waiterQueue,
+		pktCh:       make(chan buffer.View, 10),
+	}, nil
+}
+
+func (*pingProtocol) Number() tcpip.TransportProtocolNumber { return PingProtocolNumber }
+
+func (*pingProtocol) MinimumPacketSize() int { return header.ICMPv4EchoMinimumSize }
+
+func (*pingProtocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+	ident := binary.BigEndian.Uint16(v[4:])
+	return 0, ident, nil
+}
+
+func (*pingProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
+	return true
+}
diff --git a/tcpip/network/ipv4/icmp_test.go b/tcpip/network/ipv4/icmp_test.go
index 4eb5818..3a04f60 100644
--- a/tcpip/network/ipv4/icmp_test.go
+++ b/tcpip/network/ipv4/icmp_test.go
@@ -6,15 +6,18 @@
 
 import (
 	"context"
+	"encoding/binary"
 	"testing"
 	"time"
 
 	"github.com/google/netstack/tcpip"
 	"github.com/google/netstack/tcpip/buffer"
+	"github.com/google/netstack/tcpip/header"
 	"github.com/google/netstack/tcpip/link/channel"
 	"github.com/google/netstack/tcpip/link/sniffer"
 	"github.com/google/netstack/tcpip/network/ipv4"
 	"github.com/google/netstack/tcpip/stack"
+	"github.com/google/netstack/waiter"
 )
 
 const stackAddr = "\x0a\x00\x00\x01"
@@ -122,3 +125,212 @@
 		}
 	}
 }
+
+const (
+	stackAddr0 = "\x0a\x00\x00\x02"
+	stackAddr1 = "\x0a\x00\x00\x03"
+	linkAddr0  = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
+	linkAddr1  = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+)
+
+type testEndpointContext struct {
+	t *testing.T
+	s *stack.Stack
+
+	linkEP0 *channel.Endpoint
+	linkEP1 *channel.Endpoint
+
+	icmpCh chan header.ICMPv4
+}
+
+func (c *testEndpointContext) cleanup() {
+	close(c.linkEP0.C)
+	close(c.linkEP1.C)
+}
+
+func newTestEndpointContext(t *testing.T) *testEndpointContext {
+	c := &testEndpointContext{
+		t:      t,
+		s:      stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName}),
+		icmpCh: make(chan header.ICMPv4, 10),
+	}
+
+	const defaultMTU = 65536
+	id0, linkEP := channel.New(256, defaultMTU, linkAddr0)
+	c.linkEP0 = linkEP
+	if testing.Verbose() {
+		id0 = sniffer.New(id0)
+	}
+	if err := c.s.CreateNIC(1, id0); err != nil {
+		t.Fatalf("CreateNIC s: %v", err)
+	}
+	id1, linkEP := channel.New(256, defaultMTU, linkAddr1)
+	c.linkEP1 = linkEP
+	if testing.Verbose() {
+		id1 = sniffer.New(id1)
+	}
+	if err := c.s.CreateNIC(2, id1); err != nil {
+		t.Fatalf("CreateNIC s: %v", err)
+	}
+	if err := c.s.AddAddress(2, ipv4.ProtocolNumber, stackAddr0); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+	if err := c.s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+	c.s.SetRouteTable([]tcpip.Route{
+		{
+			Destination: stackAddr0,
+			Mask:        "\xFF\xFF\xFF\xFF",
+			Gateway:     "",
+			NIC:         1,
+		},
+		{
+			Destination: stackAddr1,
+			Mask:        "\xFF\xFF\xFF\xFF",
+			Gateway:     "",
+			NIC:         2,
+		},
+	})
+
+	go c.routePackets(c.linkEP0.C, c.linkEP1)
+	go c.routePackets(c.linkEP1.C, c.linkEP0)
+	return c
+}
+
+func (c *testEndpointContext) countPacket(pkt channel.PacketInfo) {
+	if pkt.Proto != header.IPv4ProtocolNumber {
+		c.t.Fatalf("Received non IPV4 packet: 0x%x", pkt.Proto)
+	}
+	ipv4 := header.IPv4(pkt.Header)
+	c.icmpCh <- header.ICMPv4(append(pkt.Header[ipv4.HeaderLength():], pkt.Payload...))
+}
+
+func (c *testEndpointContext) routePackets(ch <-chan channel.PacketInfo, ep *channel.Endpoint) {
+	for pkt := range ch {
+		c.countPacket(pkt)
+		v := buffer.View(append(pkt.Header, pkt.Payload...))
+		vs := []buffer.View{v}
+		vv := buffer.NewVectorisedView(len(v), vs)
+		ep.InjectLinkAddr(pkt.Proto, ep.LinkAddress(), &vv)
+	}
+}
+
+type callbackStub struct {
+	f func(e *waiter.Entry)
+}
+
+func (c *callbackStub) Callback(e *waiter.Entry) {
+	c.f(e)
+}
+
+func TestEndpoints(t *testing.T) {
+	c := newTestEndpointContext(t)
+	defer c.cleanup()
+
+	wq0 := &waiter.Queue{}
+	ep0, err := c.s.NewEndpoint(ipv4.PingProtocolNumber, ipv4.ProtocolNumber, wq0)
+	if err != nil {
+		c.t.Fatalf("NewEndpoint failed: %v", err)
+	}
+	defer ep0.Close()
+	wq1 := &waiter.Queue{}
+	ep1, err := c.s.NewEndpoint(ipv4.PingProtocolNumber, ipv4.ProtocolNumber, wq1)
+	if err != nil {
+		c.t.Fatalf("NewEndpoint failed: %v", err)
+	}
+	defer ep1.Close()
+
+	if err := ep0.Bind(tcpip.FullAddress{NIC: 1, Addr: stackAddr0}, nil); err != nil {
+		c.t.Fatalf("Bind failed: %v", err)
+	}
+	if err := ep1.Bind(tcpip.FullAddress{NIC: 2, Addr: stackAddr1}, nil); err != nil {
+		c.t.Fatalf("Bind failed: %v", err)
+	}
+
+	echos := 64
+
+	ping := func(wq *waiter.Queue, ep tcpip.Endpoint, data []byte) {
+		outPkt := make([]byte, header.ICMPv4MinimumSize+4+len(data))
+		icmpv4 := header.ICMPv4(outPkt[:header.ICMPv4MinimumSize])
+		icmpv4.SetType(header.ICMPv4Echo)
+		copy(outPkt[header.ICMPv4MinimumSize+4:], data)
+		binary.BigEndian.PutUint16(outPkt[header.ICMPv4MinimumSize:], 0)
+
+		for seqno := uint16(1); seqno <= uint16(echos); seqno++ {
+			binary.BigEndian.PutUint16(outPkt[header.ICMPv4MinimumSize+2:], seqno)
+
+			// We need to register with the waiter queue before we try writing, since
+			// the notification that the endpoint received a response may arrive immediately.
+			ready := make(chan struct{})
+			e := waiter.Entry{Callback: &callbackStub{func(*waiter.Entry) { close(ready) }}}
+			wq.EventRegister(&e, waiter.EventIn)
+			n, err := ep.Write(buffer.View(outPkt), nil)
+			if err != nil {
+				c.t.Fatalf("Write failed: %v\n", err)
+			} else if n != uintptr(len(outPkt)) {
+				c.t.Fatalf("Write was short: %v\n", n)
+			}
+
+			// Avoid reading until we have something to read
+			select {
+			case <-time.After(1 * time.Second):
+				c.t.Fatalf("Timed out waiting for socket to be readable")
+			case <-ready:
+			}
+			wq.EventUnregister(&e)
+			inPkt, err := ep.Read(nil)
+			if err != nil {
+				c.t.Fatalf("Read failed: %v\n", err)
+			}
+
+			// Verify the contents of the packet we just read.
+			var icmp header.ICMPv4 = []byte(inPkt)
+			if icmp.Type() != header.ICMPv4EchoReply {
+				c.t.Fatalf("Unexpected packet type: %d", icmp.Type())
+			}
+			inSeqno := binary.BigEndian.Uint16(inPkt[header.ICMPv4MinimumSize+2 : header.ICMPv4MinimumSize+4])
+			if inSeqno != seqno {
+				c.t.Fatalf("Unexpected sequence number: %d", inSeqno)
+			}
+			outData := outPkt[header.ICMPv4EchoMinimumSize:]
+			inData := inPkt[header.ICMPv4EchoMinimumSize:]
+			if len(outData) != len(inData) {
+				c.t.Fatalf("Read packet of unexpected length: %d\n", len(inData))
+			}
+			for i := range inData {
+				if inData[i] != outData[i] {
+					c.t.Fatalf("Data mismatch")
+				}
+			}
+		}
+	}
+
+	data := []byte{0xaa, 0xab, 0xac}
+	go ping(wq0, ep0, data)
+	data = []byte{0xad, 0xae, 0xaf}
+	go ping(wq1, ep1, data)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	stats := make(map[header.ICMPv4Type]int)
+	for {
+		select {
+		case <-ctx.Done():
+			t.Errorf("Timeout waiting for ICMP, got: %#+v", stats)
+			return
+		case icmp := <-c.icmpCh:
+			if icmp.Type() != header.ICMPv4Echo && icmp.Type() != header.ICMPv4EchoReply {
+				c.t.Fatalf("Unexpected type: %d", icmp.Type())
+			}
+			stats[icmp.Type()]++
+			if stats[icmp.Type()] > echos*2 {
+				c.t.Fatalf("Too many (%d) packets of type %d", stats[icmp.Type()], icmp.Type())
+			}
+			if len(stats) == 2 && stats[header.ICMPv4Echo] == echos*2 && stats[header.ICMPv4EchoReply] == echos*2 {
+				return
+			}
+		}
+	}
+}