blob: 6a1a3e575fe88bec7d22a685995ba7345593a9f5 [file] [log] [blame]
// Copyright 2016 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stack
import (
"sync"
"sync/atomic"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
)
type (
	// protocolIDs is the key for the demuxer's first level of lookup: a
	// network protocol number paired with a transport protocol number.
	protocolIDs struct {
		network   tcpip.NetworkProtocolNumber
		transport tcpip.TransportProtocolNumber
	}

	// transportEndpoints manages all endpoints of a given protocol. It has its
	// own mutex so as to reduce interference between protocols.
	transportEndpoints struct {
		// mu guards endpoints.
		mu        sync.RWMutex
		endpoints map[TransportEndpointID]TransportEndpoint
	}

	// transportDemuxer demultiplexes packets targeted at a transport endpoint
	// (i.e., after they've been parsed by the network layer). It does two
	// levels of demultiplexing: first based on the network and transport
	// protocols, then based on endpoint IDs.
	transportDemuxer struct {
		// protocol maps each (network, transport) pair to its endpoint table.
		protocol map[protocolIDs]*transportEndpoints
	}
)
// newTransportDemuxer creates a demuxer holding one empty endpoint table for
// every combination of a network protocol and a transport protocol found in
// stack.
func newTransportDemuxer(stack *Stack) *transportDemuxer {
	tables := make(map[protocolIDs]*transportEndpoints)
	// Add each network and transport pair to the demuxer.
	for netProto := range stack.networkProtocols {
		for transProto := range stack.transportProtocols {
			key := protocolIDs{network: netProto, transport: transProto}
			tables[key] = &transportEndpoints{
				endpoints: make(map[TransportEndpointID]TransportEndpoint),
			}
		}
	}
	return &transportDemuxer{protocol: tables}
}
// registerEndpoint registers ep under id for the given transport protocol on
// each network protocol in netProtos, so that packets matching id are
// delivered to it. If any single registration fails, the registrations this
// call already made for earlier entries of netProtos are removed again and
// the failing error is returned.
func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
	for idx := 0; idx < len(netProtos); idx++ {
		err := d.singleRegisterEndpoint(netProtos[idx], protocol, id, ep)
		if err == nil {
			continue
		}
		// Roll back the registrations that succeeded before this one.
		d.unregisterEndpoint(netProtos[:idx], protocol, id)
		return err
	}
	return nil
}
// singleRegisterEndpoint adds ep under id to the endpoint table for the
// (netProto, protocol) pair, holding that table's write lock. It returns
// tcpip.ErrPortInUse if id is already present in the table. When the demuxer
// has no table for the pair, nothing is registered and nil is returned.
func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
	table, found := d.protocol[protocolIDs{network: netProto, transport: protocol}]
	if !found {
		// Protocol pairs without a table are silently ignored.
		return nil
	}

	table.mu.Lock()
	defer table.mu.Unlock()

	if _, taken := table.endpoints[id]; taken {
		return tcpip.ErrPortInUse
	}
	table.endpoints[id] = ep
	return nil
}
// unregisterEndpoint removes id from the endpoint table of the given transport
// protocol for every network protocol in netProtos, so the endpoint stops
// receiving packets. Pairs for which the demuxer has no table are skipped.
func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
	for idx := range netProtos {
		table, found := d.protocol[protocolIDs{network: netProtos[idx], transport: protocol}]
		if !found {
			continue
		}
		table.mu.Lock()
		delete(table.endpoints, id)
		table.mu.Unlock()
	}
}
// deliverPacket looks up the endpoint table for r.NetProto and protocol and,
// while holding that table's read lock, tries to hand vv to an endpoint
// matching id (see deliverPacketLocked). It reports whether an endpoint was
// found. If a table exists but no endpoint accepts a UDP packet, the UDP
// UnknownPortErrors counter from r.MutableStats() is incremented atomically.
func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView, id TransportEndpointID) bool {
	table, found := d.protocol[protocolIDs{network: r.NetProto, transport: protocol}]
	if !found {
		return false
	}

	table.mu.RLock()
	delivered := d.deliverPacketLocked(r, table, vv, id)
	table.mu.RUnlock()

	if delivered {
		return true
	}
	// The UDP packet was addressed to a destination port nobody listens on.
	if protocol == header.UDPProtocolNumber {
		atomic.AddUint64(&r.MutableStats().UDP.UnknownPortErrors, 1)
	}
	return false
}
// deliverPacketLocked searches eps for an endpoint matching id, trying keys
// from most to least specific:
//  1. id exactly as given;
//  2. id with the local address cleared;
//  3. id with the remote address and remote port cleared;
//  4. id with the local address, remote address and remote port cleared.
//
// The first endpoint found receives the packet through HandlePacket, which is
// always given the original id rather than the wildcarded key. It reports
// whether any endpoint was found. Callers are expected to hold eps.mu;
// deliverPacket holds it for reading.
func (d *transportDemuxer) deliverPacketLocked(r *Route, eps *transportEndpoints, vv *buffer.VectorisedView, id TransportEndpointID) bool {
	anyLocal := id
	anyLocal.LocalAddress = ""

	anyRemote := id
	anyRemote.RemoteAddress = ""
	anyRemote.RemotePort = 0

	portOnly := anyRemote
	portOnly.LocalAddress = ""

	for _, key := range [...]TransportEndpointID{id, anyLocal, anyRemote, portOnly} {
		if ep := eps.endpoints[key]; ep != nil {
			ep.HandlePacket(r, id, vv)
			return true
		}
	}
	return false
}