blob: e936aa728cd0c5d686e0d3219b855b8b49467a5d [file] [log] [blame]
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
const (
// maxPendingResolutions is the maximum number of pending link-address
// resolutions. When a new resolution would exceed it, the oldest pending
// resolution is evicted and its queued packets are treated as failed.
maxPendingResolutions = 64
// maxPendingPacketsPerResolution is the maximum number of entries queued
// for a single link-address resolution. When a new entry would exceed it,
// the oldest queued entry is dropped and counted as an outgoing packet
// error.
maxPendingPacketsPerResolution = 256
)
// pendingPacketBuffer is a pending packet buffer.
//
// It is satisfied by both *PacketBuffer (a single packet) and
// *PacketBufferList (a list of packets) so either can be queued.
//
// TODO(gvisor.dev/issue/5331): Drop this when we drop WritePacket and only use
// WritePackets so we can use a PacketBufferList everywhere.
type pendingPacketBuffer interface {
// len returns the number of packets the buffer represents. It is used as
// the increment when counting outgoing packet errors and as the packet
// count reported by enqueue.
len() int
}
// len implements pendingPacketBuffer.len.
//
// A single PacketBuffer always counts as exactly one packet.
func (*PacketBuffer) len() int {
return 1
}
// len implements pendingPacketBuffer.len.
//
// It reports whatever the list's exported Len method returns.
func (p *PacketBufferList) len() int {
return p.Len()
}
// pendingPacket is one queued entry awaiting link resolution: a packet buffer
// (or list of buffers) together with the arguments needed to write it out
// later.
type pendingPacket struct {
// routeInfo is the route information captured when the entry was enqueued.
// dequeuePackets sets RemoteLinkAddress on its copy before writing.
routeInfo RouteInfo
// gso is passed through unchanged to the NIC when the entry is written.
gso *GSO
// proto is the network protocol number used when writing the entry and
// when looking up the network endpoint for stats and failure handling.
proto tcpip.NetworkProtocolNumber
// pkt is the queued packet buffer or packet buffer list.
pkt pendingPacketBuffer
}
// packetsPendingLinkResolution is a queue of packets pending link resolution.
//
// Once link resolution completes successfully, the packets will be written.
type packetsPendingLinkResolution struct {
// nic is the NIC that queued packets are written through; its stack and
// network endpoints supply the stats updated when packets are dropped.
nic *nic
// mu bundles the mutex together with the state that enqueue and dequeue
// access while holding it.
mu struct {
sync.Mutex
// The packets to send once the resolver completes.
//
// The link resolution channel is used as the key for this map.
packets map[<-chan struct{}][]pendingPacket
// FIFO of channels used to cancel the oldest goroutine waiting for
// link-address resolution.
//
// cancelChans holds the same channels that are used as keys to packets.
cancelChans []<-chan struct{}
}
}
// incrementOutgoingPacketErrors records pkt as failed outgoing traffic.
//
// The stack-wide IP OutgoingPacketErrors counter is bumped by the number of
// packets pkt represents. If the stats of the network endpoint registered for
// proto implement IPNetworkEndpointStats, that endpoint's counter is bumped by
// the same amount.
func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) {
	count := uint64(pkt.len())
	f.nic.stack.stats.IP.OutgoingPacketErrors.IncrementBy(count)

	epStats, isIP := f.nic.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats)
	if !isIP {
		// Endpoint does not expose IP stats; only the stack-wide counter applies.
		return
	}
	epStats.IPStats().OutgoingPacketErrors.IncrementBy(count)
}
// init prepares f for use with the given NIC.
//
// It stores nic and allocates an empty pending-packet map, both while holding
// f's lock.
func (f *packetsPendingLinkResolution) init(nic *nic) {
	f.mu.Lock()
	f.nic = nic
	f.mu.packets = make(map[<-chan struct{}][]pendingPacket)
	f.mu.Unlock()
}
// dequeue any pending packets associated with ch.
//
// If nothing is queued for ch, dequeue does nothing. Otherwise the queue and
// ch's entry in the cancellation FIFO are removed while holding f's lock, and
// the packets are processed after the lock has been released.
//
// If err is nil, packets will be written and sent to the given remote link
// address. Otherwise every packet is handled as a link resolution failure.
func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpip.LinkAddress, err tcpip.Error) {
	f.mu.Lock()
	packets, ok := f.mu.packets[ch]
	if !ok {
		f.mu.Unlock()
		return
	}
	delete(f.mu.packets, ch)

	chans := f.mu.cancelChans
	for i, cancelChan := range chans {
		if cancelChan != ch {
			continue
		}
		// Shift the tail left over the removed entry, then clear the vacated
		// last slot so the backing array does not keep a stale reference to a
		// channel (mirrors the zeroing done when evicting the oldest channel).
		last := len(chans) - 1
		copy(chans[i:], chans[i+1:])
		chans[last] = nil
		f.mu.cancelChans = chans[:last]
		break
	}
	f.mu.Unlock()

	// Packets are written/failed outside the lock.
	f.dequeuePackets(packets, linkAddr, err)
}
// enqueue a packet to be sent once link resolution completes.
//
// If the route resolves immediately, the packet is written right away and the
// result of f.nic.writePacketBuffer is returned. If resolution fails with
// anything other than ErrWouldBlock, that error is returned and nothing is
// queued. Otherwise the packet is queued and (pkt.len(), nil) is returned.
//
// If the queue for this resolution would exceed
// maxPendingPacketsPerResolution, its oldest entry is dropped and counted as an
// outgoing packet error.
//
// If the maximum number of pending resolutions is reached, the packets
// associated with the oldest link resolution will be dequeued as if they failed
// link resolution.
func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
	f.mu.Lock()

	// Resolution is attempted with f's lock held so that it cannot complete
	// between obtaining the resolution channel and queueing the packet:
	//
	//   A @ T1: Call resolvedFields (get link resolution channel)
	//   B @ T2: Complete link resolution, dequeue pending packets
	//   C @ T1: Enqueue packet that already completed link resolution (which will
	//       never dequeue)
	//
	// Performing A and C under the lock prevents B from interleaving with them.
	routeInfo, ch, err := r.resolvedFields(nil)
	if err == nil {
		// The route resolved immediately, so there is nothing to wait for.
		f.mu.Unlock()
		return f.nic.writePacketBuffer(routeInfo, gso, proto, pkt)
	}
	if _, mustWait := err.(*tcpip.ErrWouldBlock); !mustWait {
		f.mu.Unlock()
		return 0, err
	}

	// Link resolution is in progress; queue the packet until it completes.
	defer f.mu.Unlock()

	queue, alreadyTracked := f.mu.packets[ch]
	queue = append(queue, pendingPacket{
		routeInfo: routeInfo,
		gso:       gso,
		proto:     proto,
		pkt:       pkt,
	})

	if len(queue) > maxPendingPacketsPerResolution {
		// Over the limit: drop the oldest entry, clearing its slot so the
		// backing array does not keep it alive.
		oldest := &queue[0]
		f.incrementOutgoingPacketErrors(oldest.proto, oldest.pkt)
		*oldest = pendingPacket{}
		queue = queue[1:]

		if numPackets := len(queue); numPackets != maxPendingPacketsPerResolution {
			panic(fmt.Sprintf("holding more queued packets than expected; got = %d, want <= %d", numPackets, maxPendingPacketsPerResolution))
		}
	}
	f.mu.packets[ch] = queue

	if !alreadyTracked {
		// First packet for this resolution: register its channel, which may
		// evict the oldest pending resolution.
		if cancelledPackets := f.newCancelChannelLocked(ch); len(cancelledPackets) != 0 {
			// Dequeue the pending packets in a new goroutine to not hold up the
			// current goroutine as handling link resolution failures may be a
			// costly operation.
			go f.dequeuePackets(cancelledPackets, "" /* linkAddr */, &tcpip.ErrAborted{})
		}
	}

	return pkt.len(), nil
}
// newCancelChannelLocked records newCH at the back of the cancellation FIFO.
//
// While the FIFO holds at most maxPendingResolutions channels, nil is returned.
// Once that limit is exceeded, the channel at the front of the FIFO is removed,
// its packet queue is deleted from the pending map, and those packets are
// returned to the caller.
//
// It panics if the FIFO is still over the limit after eviction or if the
// evicted channel has no packet queue.
func (f *packetsPendingLinkResolution) newCancelChannelLocked(newCH <-chan struct{}) []pendingPacket {
	chans := append(f.mu.cancelChans, newCH)
	f.mu.cancelChans = chans
	if len(chans) <= maxPendingResolutions {
		return nil
	}

	// Pop the oldest channel, clearing its slot before re-slicing so the
	// backing array does not retain it.
	oldest := chans[0]
	chans[0] = nil
	chans = chans[1:]
	f.mu.cancelChans = chans

	if l := len(chans); l > maxPendingResolutions {
		panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
	}

	evicted, found := f.mu.packets[oldest]
	if !found {
		panic("must have a packet queue for an uncancelled channel")
	}
	delete(f.mu.packets, oldest)
	return evicted
}
// dequeuePackets processes packets that have been removed from the pending
// queue.
//
// When err is nil, each entry is written through the NIC with linkAddr as the
// remote link address; write results are deliberately discarded.
//
// When err is non-nil, each entry is counted as an outgoing packet error and,
// if the network endpoint for its protocol implements
// LinkResolvableNetworkEndpoint, every contained packet buffer is passed to
// HandleLinkResolutionFailure. It panics on an unknown pending buffer type.
func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, linkAddr tcpip.LinkAddress, err tcpip.Error) {
	for _, p := range packets {
		if err == nil {
			// p is a copy, so the queued entry itself is not modified.
			p.routeInfo.RemoteLinkAddress = linkAddr
			_, _ = f.nic.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt)
			continue
		}

		f.incrementOutgoingPacketErrors(p.proto, p.pkt)

		linkResolvableEP, ok := f.nic.getNetworkEndpoint(p.proto).(LinkResolvableNetworkEndpoint)
		if !ok {
			continue
		}

		switch pkt := p.pkt.(type) {
		case *PacketBuffer:
			linkResolvableEP.HandleLinkResolutionFailure(pkt)
		case *PacketBufferList:
			for pb := pkt.Front(); pb != nil; pb = pb.Next() {
				linkResolvableEP.HandleLinkResolutionFailure(pb)
			}
		default:
			panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt))
		}
	}
}