blob: 7e3d4dce93a8d642654a9dea8066fc3e18a2a68e [file] [log] [blame]
// Copyright 2016 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ipv4
import (
"context"
"encoding/binary"
"sync"
"time"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/stack"
"github.com/google/netstack/waiter"
)
const (
	// PingProtocolName is the name under which the ICMPv4 ping pseudo
	// transport protocol is registered; that protocol handles ping replies.
	// Use it when constructing a stack that intends to use ipv4.Ping.
	PingProtocolName = "icmpv4ping"

	// PingProtocolNumber is the transport protocol number used to transmit
	// and deliver ICMP messages. The ICMP identifier field stands in for a
	// port number when multiplexing replies to endpoints.
	PingProtocolNumber tcpip.TransportProtocolNumber = 1
)
// handleICMP dispatches an inbound ICMPv4 packet that arrived over route r.
//
// Echo requests have the fixed ICMP header stripped and are queued on
// e.echoRequests for echoReplier; if that queue is full the request is
// dropped. Echo, information and timestamp replies are delivered, header
// intact, to the ping pseudo transport protocol. Packets whose first view is
// shorter than the minimum ICMP header, and all other types, are ignored.
func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) {
	first := vv.First()
	if len(first) < header.ICMPv4MinimumSize {
		return
	}

	switch header.ICMPv4(first).Type() {
	case header.ICMPv4EchoReply, header.ICMPv4InfoReply, header.ICMPv4TimestampReply:
		e.dispatcher.DeliverTransportPacket(r, PingProtocolNumber, vv)

	case header.ICMPv4Echo:
		if len(first) < header.ICMPv4EchoMinimumSize {
			return
		}
		vv.TrimFront(header.ICMPv4MinimumSize)
		// The queued request owns a clone of the route; its final holder
		// must Release it.
		req := echoRequest{r: r.Clone(), v: vv.ToView()}
		select {
		case e.echoRequests <- req:
		default:
			// The replier is backed up: drop the request and its route clone.
			req.r.Release()
		}
	}
	// TODO(crawshaw): Handle other ICMP types.
}
// echoRequest is an inbound ICMPv4 echo request queued by handleICMP for
// echoReplier.
type echoRequest struct {
	// r is a clone of the route the request arrived on; whoever consumes the
	// request must Release it.
	r stack.Route
	// v is the request with the fixed ICMP header (header.ICMPv4MinimumSize
	// bytes) trimmed off; it is echoed back verbatim as the reply payload.
	v buffer.View
}
// echoReplier answers queued echo requests until e.echoRequests is closed.
// Each reply carries the request's payload back over the request's route,
// after which the cloned route is released. Send errors are ignored.
func (e *endpoint) echoReplier() {
	for {
		req, ok := <-e.echoRequests
		if !ok {
			return
		}
		// Best effort: a reply that fails to send is simply dropped.
		_ = sendICMPv4(&req.r, header.ICMPv4EchoReply, 0, req.v)
		req.r.Release()
	}
}
// sendICMPv4 writes an ICMPv4 message with the given type and code, followed
// by data, over route r. The checksum covers both the header and data. Any
// error from r.WritePacket is returned unchanged.
func sendICMPv4(r *stack.Route, typ header.ICMPv4Type, code byte, data buffer.View) *tcpip.Error {
	// Leave room in front for the headers that lower layers will prepend.
	reserve := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
	hdr := buffer.NewPrependable(reserve)

	msg := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
	msg.SetType(typ)
	msg.SetCode(code)

	// Checksum the payload first, then fold that into the sum over the
	// header. The checksum field has not been set yet (presumably zero in the
	// freshly allocated buffer).
	payloadSum := header.Checksum(data, 0)
	msg.SetChecksum(^header.Checksum(msg, payloadSum))

	return r.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber)
}
// A Pinger can send echo requests to an address.
//
// Fill in the fields and call Ping.
type Pinger struct {
	Stack     *stack.Stack  // stack used to route and send the requests
	NICID     tcpip.NICID   // NIC passed to FindRoute and used to register the reply endpoint
	Addr      tcpip.Address // destination of the echo requests
	LocalAddr tcpip.Address // optional
	Wait      time.Duration // interval between requests; if zero, defaults to 1 second
	Count     uint16        // number of requests; if zero, defaults to MaxUint16
}
// pingerEndpoint is the transport endpoint that Ping registers under its ICMP
// identifier to receive echo replies.
type pingerEndpoint struct {
	// stack is stored but not read by any code in this file.
	stack *stack.Stack
	// pktCh carries received replies, ICMP header included, to the receiving
	// goroutine started by Ping.
	pktCh chan buffer.View
}
// Close closes the reply channel.
//
// Any HandlePacket call made afterwards would panic, because it sends on
// the closed channel; callers must unregister the endpoint first.
func (e *pingerEndpoint) Close() {
	close(e.pktCh)
}
// HandlePacket hands a received reply to Ping without blocking. If pktCh is
// already full, the reply is dropped.
func (e *pingerEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
	pkt := vv.ToView()
	select {
	case e.pktCh <- pkt:
	default:
		// The receiver is behind; drop the reply.
	}
}
// Ping sends echo requests to an ICMPv4 endpoint.
// Responses are streamed to the channel ch.
//
// One request is sent every p.Wait (the first after one interval) until
// p.Count requests have been sent or ctx is done. Each reply is reported on
// ch with the round-trip time measured from when that sequence number was
// sent. A request that cannot be sent is reported on ch with Error set.
//
// After the last request, Ping waits at most one more interval for
// outstanding replies and then returns. It returns a non-nil error only if it
// could not find a route or register its reply endpoint. It never closes ch.
func (p *Pinger) Ping(ctx context.Context, ch chan<- PingReply) *tcpip.Error {
	count := p.Count
	if count == 0 {
		count = 1<<16 - 1
	}
	wait := p.Wait
	if wait == 0 {
		wait = 1 * time.Second
	}

	r, err := p.Stack.FindRoute(p.NICID, p.LocalAddr, p.Addr, ProtocolNumber)
	if err != nil {
		return err
	}
	// Drop the route reference on every return path.
	defer r.Release()

	netProtos := []tcpip.NetworkProtocolNumber{ProtocolNumber}
	ep := &pingerEndpoint{
		stack: p.Stack,
		pktCh: make(chan buffer.View, 1),
	}
	id := stack.TransportEndpointID{
		LocalAddress:  r.LocalAddress,
		RemoteAddress: p.Addr,
	}
	// The ICMP identifier doubles as the local port: pick a free one and
	// register ep under it so that replies are delivered to ep.
	_, err = p.Stack.PickEphemeralPort(func(port uint16) (bool, *tcpip.Error) {
		id.LocalPort = port
		err := p.Stack.RegisterTransportEndpoint(p.NICID, netProtos, PingProtocolNumber, id, ep)
		switch err {
		case nil:
			return true, nil
		case tcpip.ErrPortInUse:
			return false, nil
		default:
			return false, err
		}
	})
	if err != nil {
		return err
	}
	defer p.Stack.UnregisterTransportEndpoint(p.NICID, netProtos, PingProtocolNumber, id)

	// Echo payload: identifier (the local port), then the sequence number.
	v := buffer.NewView(4)
	binary.BigEndian.PutUint16(v[0:], id.LocalPort)

	// sentAt records when each sequence number went out, so that replies
	// report their real round-trip time.
	var mu sync.Mutex
	sentAt := make(map[uint16]time.Time)

	// The receiver stops after count replies, when the caller's ctx is done,
	// or when cancel is called on the way out of Ping.
	ctx, cancel := context.WithCancel(ctx)
	done := make(chan struct{})
	go func() {
		defer close(done)
		for received := 0; received < int(count); {
			select {
			case pkt := <-ep.pktCh:
				// Identifier and sequence number follow the ICMP header;
				// ignore replies too short to carry them.
				if len(pkt) < header.ICMPv4MinimumSize+4 {
					continue
				}
				seq := binary.BigEndian.Uint16(pkt[header.ICMPv4MinimumSize+2:])
				mu.Lock()
				sent, ok := sentAt[seq]
				mu.Unlock()
				if !ok {
					// Not a sequence number this Ping has sent.
					continue
				}
				received++
				select {
				case ch <- PingReply{Duration: time.Since(sent), SeqNumber: seq}:
				case <-ctx.Done():
					return
				}
			case <-ctx.Done():
				return
			}
		}
	}()
	// Stop the receiver and wait for it before the endpoint is unregistered
	// and the route released (those defers run after this one).
	defer func() {
		cancel()
		<-done
	}()

	t := time.NewTicker(wait)
	defer t.Stop()
	for seq := uint16(0); seq < count; seq++ {
		select {
		case <-t.C:
		case <-ctx.Done():
			return nil
		}
		binary.BigEndian.PutUint16(v[2:], seq)
		sent := time.Now()
		// Record the send time before sending so a fast reply can find it.
		mu.Lock()
		sentAt[seq] = sent
		mu.Unlock()
		if err := sendICMPv4(&r, header.ICMPv4Echo, 0, v); err != nil {
			select {
			case ch <- PingReply{Error: err, Duration: time.Since(sent), SeqNumber: seq}:
			case <-ctx.Done():
				return nil
			}
		}
	}

	// Give replies to the final request(s) one more interval to arrive,
	// instead of waiting forever for replies that may have been lost.
	grace := time.NewTimer(wait)
	defer grace.Stop()
	select {
	case <-done:
	case <-grace.C:
	case <-ctx.Done():
	}
	return nil
}
// PingReply summarizes an ICMP echo reply, or a failure to send an echo
// request.
type PingReply struct {
	Error     *tcpip.Error  // reports any errors sending a ping request
	Duration  time.Duration // for replies, an estimate of the round-trip time; for send errors, time spent sending
	SeqNumber uint16        // sequence number of the request this entry refers to
}
// endpointState is the lifecycle state of a pingEndpoint.
type endpointState int

const (
	// stateInitial: created, not yet bound to a remote address.
	stateInitial endpointState = iota
	// stateConnected: bound by bindLocked; the route and identifier are held.
	stateConnected
	// stateClosed: Close has run; Write fails and binding is refused.
	stateClosed
)
// pingEndpoint is a tcpip.Endpoint that sends and receives ICMPv4 messages,
// using the ICMP identifier as its local port.
type pingEndpoint struct {
	stack       *stack.Stack
	netProto    tcpip.NetworkProtocolNumber // network protocol used for routing and registration
	waiterQueue *waiter.Queue               // notified with EventIn when a packet is queued

	// mu guards state, route, nic and id (Close, Write and Bind hold it).
	// pktCh is a channel and is used without the lock by Read and HandlePacket.
	mu    sync.RWMutex
	pktCh chan buffer.View
	state endpointState
	route stack.Route
	nic   tcpip.NICID
	id    stack.TransportEndpointID
}
// Close shuts the endpoint down. It is idempotent.
//
// A connected endpoint unregisters its identifier from the stack and releases
// its route. In every case the receive channel is closed and the endpoint
// moves to stateClosed.
func (e *pingEndpoint) Close() {
	e.mu.Lock()
	defer e.mu.Unlock()

	if e.state == stateClosed {
		return
	}
	if e.state == stateConnected {
		// Unregister under the same network protocol that bindLocked
		// registered with (e.netProto); using a different one would leave the
		// registration behind.
		netProtos := []tcpip.NetworkProtocolNumber{e.netProto}
		e.stack.UnregisterTransportEndpoint(e.nic, netProtos, PingProtocolNumber, e.id)
		e.route.Release()
	}
	// NOTE(review): HandlePacket sends on pktCh without holding e.mu. If a
	// delivery can still be in flight after UnregisterTransportEndpoint
	// returns, that send would panic on the closed channel — confirm.
	close(e.pktCh)
	e.state = stateClosed
}
// Read returns the next queued ICMP message without blocking.
//
// It fails with tcpip.ErrWouldBlock when nothing is queued. It fails with
// tcpip.ErrClosedForReceive once pktCh has been closed and every buffered
// message has been consumed. The sender's address is not reported; a is left
// untouched.
func (e *pingEndpoint) Read(a *tcpip.FullAddress) (buffer.View, *tcpip.Error) {
	select {
	case v, ok := <-e.pktCh:
		if !ok {
			// Without this check a closed endpoint would look like it had
			// received an empty message.
			return buffer.View{}, tcpip.ErrClosedForReceive
		}
		return v, nil
	default:
		return buffer.View{}, tcpip.ErrWouldBlock
	}
}
// Write sends v, a complete ICMPv4 message (header followed by payload), to
// the endpoint's peer.
//
// An endpoint in stateInitial is first bound to *to (to must be non-nil). A
// connected endpoint only accepts a nil to, or one equal to its current peer.
// The code field is forced to 0, the identifier is replaced by the endpoint's
// local port, and the checksum is recomputed. The caller's buffer is left
// unmodified. On success the full length of v is returned.
func (e *pingEndpoint) Write(v buffer.View, to *tcpip.FullAddress) (uintptr, *tcpip.Error) {
	e.mu.Lock()
	defer e.mu.Unlock()

	switch e.state {
	case stateInitial:
		if to == nil {
			return 0, tcpip.ErrNotSupported
		}
		if err := e.bindLocked(*to, nil); err != nil {
			return 0, err
		}
	case stateConnected:
		if to != nil {
			prev := tcpip.FullAddress{
				NIC:  e.nic,
				Addr: e.id.RemoteAddress,
				Port: e.id.RemotePort,
			}
			if prev != *to {
				return 0, tcpip.ErrAlreadyConnected
			}
		}
	default:
		return 0, tcpip.ErrClosedForSend
	}

	// The message must contain the ICMP header plus the 2-byte identifier
	// that is overwritten below; anything shorter would index out of range.
	if len(v) < header.ICMPv4MinimumSize+2 {
		return 0, tcpip.ErrNotSupported
	}

	hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(e.route.MaxHeaderLength()))
	icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
	copy(icmpv4, v[:header.ICMPv4MinimumSize])
	icmpv4.SetCode(0)

	// Work on a copy of the payload so that stamping the identifier does not
	// modify the caller's buffer.
	data := buffer.NewView(len(v) - header.ICMPv4MinimumSize)
	copy(data, v[header.ICMPv4MinimumSize:])
	// Overwrite the identifier with the local port so that replies are
	// demultiplexed back to this endpoint.
	binary.BigEndian.PutUint16(data[0:], e.id.LocalPort)

	// Recompute the checksum over the rewritten message.
	icmpv4.SetChecksum(0)
	chksum := header.ICMPv4(data).CalculateChecksum(icmpv4.CalculateChecksum(0))
	icmpv4.SetChecksum(^chksum)

	if err := e.route.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber); err != nil {
		return 0, err
	}
	return uintptr(len(v)), nil
}
// Peek is not supported by ping endpoints. It always fails with
// tcpip.ErrNotSupported and reads nothing into data.
func (e *pingEndpoint) Peek(data [][]byte) (uintptr, *tcpip.Error) {
	return 0, tcpip.ErrNotSupported
}
// SetOption implements TransportProtocol.SetOption.
//
// The ping protocol has no configurable options, so every option is rejected
// with tcpip.ErrUnknownProtocolOption.
func (p *pingProtocol) SetOption(option interface{}) *tcpip.Error {
	return tcpip.ErrUnknownProtocolOption
}
// init registers a factory for the ping pseudo transport protocol under
// PingProtocolName.
func init() {
	newPingProtocol := func() stack.TransportProtocol { return new(pingProtocol) }
	stack.RegisterTransportProtocolFactory(PingProtocolName, newPingProtocol)
}
// Connect is not supported. A ping endpoint is associated with its peer
// through Bind, or implicitly by the first Write with a destination.
func (e *pingEndpoint) Connect(address tcpip.FullAddress) *tcpip.Error {
	return tcpip.ErrNotSupported
}
// Shutdown closes the whole endpoint regardless of flags; partial shutdown
// is not distinguished. It always reports success.
func (e *pingEndpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
	e.Close()
	return nil
}
// Listen is not supported: ping endpoints are not connection-oriented.
func (e *pingEndpoint) Listen(backlog int) *tcpip.Error {
	return tcpip.ErrNotSupported
}
// Accept is not supported: ping endpoints never have pending connections.
func (e *pingEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
	return nil, nil, tcpip.ErrNotSupported
}
// Bind associates the endpoint with addr under e.mu; see bindLocked.
//
// addr.Addr becomes the endpoint's remote address and addr.NIC its NIC. A
// free ICMP identifier is chosen as the local port. commit, if non-nil, runs
// before the new state is recorded; if it fails, the binding is undone.
func (e *pingEndpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
	e.mu.Lock()
	defer e.mu.Unlock()

	return e.bindLocked(addr, commit)
}
// bindLocked connects the endpoint to to.Addr on NIC to.NIC. e.mu must be
// held.
//
// It finds a route, registers the endpoint under a free ICMP identifier
// (used as the local port), runs commit if it is non-nil, and then records
// the connected state. It fails with tcpip.ErrAlreadyConnected unless the
// endpoint is in stateInitial. On any failure the endpoint stays in
// stateInitial and the route and registration are released.
func (e *pingEndpoint) bindLocked(to tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
	if e.state != stateInitial {
		return tcpip.ErrAlreadyConnected
	}
	r, err := e.stack.FindRoute(to.NIC, "", to.Addr, e.netProto)
	if err != nil {
		return err
	}
	netProtos := []tcpip.NetworkProtocolNumber{e.netProto}
	id := stack.TransportEndpointID{
		LocalAddress:  r.LocalAddress,
		RemoteAddress: to.Addr,
	}
	_, err = e.stack.PickEphemeralPort(func(port uint16) (bool, *tcpip.Error) {
		id.LocalPort = port
		err := e.stack.RegisterTransportEndpoint(to.NIC, netProtos, PingProtocolNumber, id, e)
		switch err {
		case nil:
			return true, nil
		case tcpip.ErrPortInUse:
			// Identifier already taken; let PickEphemeralPort try another.
			return false, nil
		default:
			return false, err
		}
	})
	if err != nil {
		// No identifier was registered. Drop the route instead of marking the
		// endpoint connected without a registration.
		r.Release()
		return err
	}
	if commit != nil {
		if err := commit(); err != nil {
			e.stack.UnregisterTransportEndpoint(to.NIC, netProtos, PingProtocolNumber, id)
			r.Release()
			return err
		}
	}
	e.state = stateConnected
	e.route = r
	e.nic = to.NIC
	e.id = id
	return nil
}
// GetLocalAddress is not supported by ping endpoints; it always returns a
// zero address and tcpip.ErrNotSupported.
func (e *pingEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
	var none tcpip.FullAddress
	return none, tcpip.ErrNotSupported
}
// GetRemoteAddress is not supported by ping endpoints; it always returns a
// zero address and tcpip.ErrNotSupported.
func (e *pingEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
	var none tcpip.FullAddress
	return none, tcpip.ErrNotSupported
}
// Readiness returns the subset of mask whose events are currently ready.
//
// EventIn is ready while at least one received message is queued; this
// matches HandlePacket, which notifies EventIn when it queues one. EventOut
// is always ready, because Write never returns tcpip.ErrWouldBlock.
func (e *pingEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
	result := waiter.EventOut & mask
	if mask&waiter.EventIn != 0 && len(e.pktCh) > 0 {
		result |= waiter.EventIn
	}
	return result
}
// SetSockOpt is not supported; no socket options can be set on a ping
// endpoint.
func (e *pingEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
	return tcpip.ErrNotSupported
}
// GetSockOpt is not supported; opt is never filled in.
func (e *pingEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
	return tcpip.ErrNotSupported
}
// HandlePacket queues a received ICMP message for Read and notifies
// EventIn waiters. If the receive queue is full, the message is dropped and
// no notification is sent.
//
// NOTE(review): this runs without e.mu, so it may race with Close closing
// pktCh, and a send on a closed channel panics even inside select. Confirm
// whether the stack guarantees no delivery is in flight once
// UnregisterTransportEndpoint returns.
func (e *pingEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
	pkt := vv.ToView()
	select {
	case e.pktCh <- pkt:
	default:
		return
	}
	e.waiterQueue.Notify(waiter.EventIn)
}
// pingProtocol implements the ping pseudo transport protocol registered as
// PingProtocolName. It carries no state.
type pingProtocol struct{}
// NewEndpoint creates an unbound ping endpoint (stateInitial) for netProto.
// Its readers are notified through waiterQueue, and up to 10 received
// messages are buffered before further ones are dropped.
func (p *pingProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
	const rcvQueueLen = 10
	ep := &pingEndpoint{
		stack:       stack,
		netProto:    netProto,
		waiterQueue: waiterQueue,
		pktCh:       make(chan buffer.View, rcvQueueLen),
	}
	return ep, nil
}
// Number returns PingProtocolNumber, the transport protocol number of ping.
func (*pingProtocol) Number() tcpip.TransportProtocolNumber {
	return PingProtocolNumber
}
// MinimumPacketSize returns header.ICMPv4EchoMinimumSize, the smallest
// message the ping protocol accepts.
func (*pingProtocol) MinimumPacketSize() int {
	return header.ICMPv4EchoMinimumSize
}
// ParsePorts extracts the demultiplexing "ports" from an ICMPv4 message.
//
// The 16-bit ICMP identifier at byte offset 4, just after the type, code and
// checksum fields, is reported as the destination port. The source port is
// always 0. v is assumed to be at least 6 bytes long, presumably guaranteed
// by the stack via MinimumPacketSize (TODO confirm).
func (*pingProtocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
	const identOffset = 4
	dst = binary.BigEndian.Uint16(v[identOffset:])
	return 0, dst, nil
}
// HandleUnknownDestinationPacket is called for ping messages that match no
// registered endpoint. It unconditionally reports true without doing
// anything. Presumably this tells the stack the packet was handled, which
// suppresses any further processing — TODO confirm against the stack.
func (*pingProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
	const handled = true
	return handled
}