blob: cc7a2e048b98e78e519266a8e06406b27ea907d3 [file] [log] [blame]
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"math/rand"
"sync"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/hash/jenkins"
"github.com/google/netstack/tcpip/header"
)
// protocolIDs identifies one (network protocol, transport protocol) pair and is
// the key of transportDemuxer.protocol. Field order matters: literals such as
// protocolIDs{netProto, proto} in this file are positional.
type protocolIDs struct {
network tcpip.NetworkProtocolNumber
transport tcpip.TransportProtocolNumber
}
// transportEndpoints manages all endpoints of a given (network, transport)
// protocol pair. It has its own mutex so as to reduce interference between
// protocols.
type transportEndpoints struct {
// mu protects all fields of the transportEndpoints.
mu sync.RWMutex
// endpoints maps an endpoint ID (local/remote address and port) to the
// per-NIC bindings registered under that ID.
endpoints map[TransportEndpointID]*endpointsByNic
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
rawEndpoints []RawTransportEndpoint
}
// endpointsByNic holds, for a single TransportEndpointID, the group of
// endpoints bound on each NIC. The entry for NIC ID 0 is used as a fallback
// when a packet's NIC has no entry of its own (presumably endpoints that are
// not bound to any device — confirm with callers of registerEndpoint).
type endpointsByNic struct {
// mu protects endpoints. The contained multiPortEndpoints are also only
// mutated (registered into / unregistered from) while mu is held for writing.
mu sync.RWMutex
// endpoints maps the bound-to NIC ID to the group of endpoints bound there.
endpoints map[tcpip.NICID]*multiPortEndpoint
// seed is a random secret for a jenkins hash. It is passed to
// selectEndpoint to pick one endpoint out of a multi-endpoint group.
seed uint32
}
// handlePacket delivers a packet addressed to this endpoint ID. It uses the
// group of endpoints bound on the NIC the packet arrived on and falls back to
// the group registered under NIC ID 0. Multicast and broadcast packets go to
// every endpoint in that group. Any other packet goes to a single endpoint
// chosen by hashing id. A packet with no matching group is dropped.
func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
	epsByNic.mu.RLock()

	mpep, found := epsByNic.endpoints[r.ref.nic.ID()]
	if !found {
		mpep, found = epsByNic.endpoints[0]
	}

	switch {
	case !found:
		// Nothing bound on this NIC or on NIC 0: drop the packet.
	case isMulticastOrBroadcast(id.LocalAddress):
		mpep.handlePacketAll(r, id, vv)
	default:
		// multiPortEndpoints are guaranteed to have at least one element.
		selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv)
	}

	// Unlocked explicitly instead of via defer: this runs once per packet.
	epsByNic.mu.RUnlock()
}
// handleControlPacket passes a control packet for this endpoint ID to one
// endpoint. It uses the group bound on NIC n and falls back to the group
// registered under NIC ID 0. The endpoint is picked by hashing id, the same
// way unicast data packets are routed. If neither group exists the packet is
// dropped.
func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
	epsByNic.mu.RLock()
	defer epsByNic.mu.RUnlock()

	mpep, found := epsByNic.endpoints[n.ID()]
	if !found {
		if mpep, found = epsByNic.endpoints[0]; !found {
			return
		}
	}

	// TODO(eyalsoha): Why don't we look at id to see if this packet needs to
	// broadcast like handlePacket does?
	target := selectEndpoint(id, mpep, epsByNic.seed) // groups are never empty
	target.HandleControlPacket(id, typ, extra, vv)
}
// registerEndpoint adds t to the group of endpoints bound to bindToDevice. If
// no group exists for that device yet, it creates one that allows port reuse
// iff reusePort is set. It returns nil on success. It returns the error from
// multiPortEndpoint.singleRegisterEndpoint (tcpip.ErrPortInUse) when the
// existing group cannot accept another endpoint.
func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
	epsByNic.mu.Lock()
	defer epsByNic.mu.Unlock()

	multiPortEp, ok := epsByNic.endpoints[bindToDevice]
	if !ok {
		// This is a new binding: create its (initially empty) group. The
		// lock is held throughout, so no one observes the empty group before
		// t is added below.
		multiPortEp = &multiPortEndpoint{
			endpointsMap: make(map[TransportEndpoint]int),
			reuse:        reusePort,
		}
		epsByNic.endpoints[bindToDevice] = multiPortEp
	}
	return multiPortEp.singleRegisterEndpoint(t, reusePort)
}
// unregisterEndpoint removes t from the group bound to bindToDevice and drops
// that group once its last endpoint is gone. It reports whether this
// endpointsByNic no longer holds any group, in which case the caller must
// remove it. An unknown bindToDevice is a no-op that reports false.
func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
	epsByNic.mu.Lock()
	defer epsByNic.mu.Unlock()

	group, found := epsByNic.endpoints[bindToDevice]
	if !found {
		return false
	}
	if emptied := group.unregisterEndpoint(t); emptied {
		delete(epsByNic.endpoints, bindToDevice)
	}
	return len(epsByNic.endpoints) == 0
}
// unregisterEndpoint removes ep, registered under id and bindToDevice, so that
// it no longer receives packets. The entry for id is deleted once no NIC
// binding remains under it. Unknown ids are ignored.
func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
	eps.mu.Lock()
	defer eps.mu.Unlock()

	if epsByNic, ok := eps.endpoints[id]; ok && epsByNic.unregisterEndpoint(bindToDevice, ep) {
		delete(eps.endpoints, id)
	}
}
// transportDemuxer demultiplexes packets targeted at a transport endpoint
// (i.e., after they've been parsed by the network layer). It first
// demultiplexes on the network and transport protocols, then on endpoint IDs
// (and finally, inside endpointsByNic, on the NIC). It should only be
// instantiated via newTransportDemuxer.
type transportDemuxer struct {
// protocol is immutable: it is filled in once by newTransportDemuxer and only
// read afterwards in this file, which is why lookups in it take no lock.
protocol map[protocolIDs]*transportEndpoints
}
// newTransportDemuxer returns a demuxer holding one empty endpoint table for
// every (network protocol, transport protocol) combination known to stack.
func newTransportDemuxer(stack *Stack) *transportDemuxer {
	tables := make(map[protocolIDs]*transportEndpoints)
	for netProto := range stack.networkProtocols {
		for transProto := range stack.transportProtocols {
			key := protocolIDs{network: netProto, transport: transProto}
			tables[key] = &transportEndpoints{
				endpoints: make(map[TransportEndpointID]*endpointsByNic),
			}
		}
	}
	return &transportDemuxer{protocol: tables}
}
// registerEndpoint registers ep under id for each network protocol in
// netProtos (paired with protocol), so that matching packets are delivered to
// it. It is all-or-nothing: if any registration fails, the ones already made
// for earlier entries of netProtos are undone and that error is returned.
func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
	for i := 0; i < len(netProtos); i++ {
		err := d.singleRegisterEndpoint(netProtos[i], protocol, id, ep, reusePort, bindToDevice)
		if err == nil {
			continue
		}
		// Roll back the registrations made for netProtos[:i].
		d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice)
		return err
	}
	return nil
}
// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpointsArr always has at least one
// element (outside the brief window in endpointsByNic.registerEndpoint between
// creating the container and adding its first endpoint).
type multiPortEndpoint struct {
// mu protects endpointsArr and endpointsMap.
// NOTE(review): selectEndpoint reads endpointsArr without taking mu. It
// relies on its callers holding the owning endpointsByNic.mu, which every
// mutation in this file also holds for writing.
mu sync.RWMutex
// endpointsArr lists the endpoints in the group; selectEndpoint indexes into it.
endpointsArr []TransportEndpoint
// endpointsMap maps each endpoint to its index in endpointsArr so that it
// can be removed without scanning the array.
endpointsMap map[TransportEndpoint]int
// reuse indicates if more than one endpoint is allowed.
reuse bool
}
// reciprocalScale maps val uniformly onto [0, n) by taking the high 32 bits of
// the 64-bit product val*n. It is a cheaper substitute for val % n (it avoids
// a division). For n == 0 the result is 0.
// See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
func reciprocalScale(val, n uint32) uint32 {
	product := uint64(val) * uint64(n)
	return uint32(product >> 32)
}
// selectEndpoint picks one endpoint of mpep for the packet identified by id.
// A single-member group short-circuits. Otherwise the choice is a seeded
// Jenkins hash of the local and remote ports and addresses, scaled into
// [0, len(mpep.endpointsArr)). All packets with the same ID therefore reach
// the same endpoint. mpep must not be empty. mpep.mu is not taken here;
// callers in this file hold the owning endpointsByNic.mu for reading.
func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
	candidates := mpep.endpointsArr
	if len(candidates) == 1 {
		return candidates[0]
	}

	hasher := jenkins.Sum32(seed)
	hasher.Write([]byte{
		byte(id.LocalPort), byte(id.LocalPort >> 8),
		byte(id.RemotePort), byte(id.RemotePort >> 8),
	})
	hasher.Write([]byte(id.LocalAddress))
	hasher.Write([]byte(id.RemoteAddress))

	return candidates[reciprocalScale(hasher.Sum32(), uint32(len(candidates)))]
}
// handlePacketAll delivers a packet to every endpoint in the group, in
// endpointsArr order. HandlePacket modifies the VectorisedView it receives.
// Every endpoint except the last therefore gets a private copy of the payload,
// and the last one receives vv itself.
func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
	ep.mu.RLock()
	last := len(ep.endpointsArr) - 1
	if last > 0 {
		// Flatten vv once and copy from that view for each endpoint, rather
		// than calling vv.ToView() (which may allocate) once per endpoint.
		// vv itself is not handed to anyone until this loop is done, so the
		// flattened view stays valid throughout.
		flat := vv.ToView()
		for _, endpoint := range ep.endpointsArr[:last] {
			vvCopy := buffer.NewView(vv.Size())
			copy(vvCopy, flat)
			endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
		}
	}
	if last >= 0 {
		ep.endpointsArr[last].HandlePacket(r, id, vv)
	}
	ep.mu.RUnlock() // Don't use defer for performance reasons.
}
// singleRegisterEndpoint appends t to this group. An empty group always
// accepts it. A non-empty group accepts it only when both the group (reuse)
// and the caller (reusePort) asked for port reuse; otherwise
// tcpip.ErrPortInUse is returned and the group is left unchanged.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
	ep.mu.Lock()
	defer ep.mu.Unlock()

	if len(ep.endpointsArr) > 0 && !(ep.reuse && reusePort) {
		return tcpip.ErrPortInUse
	}

	// Record t's position in endpointsMap so unregisterEndpoint can find and
	// remove it without scanning endpointsArr.
	pos := len(ep.endpointsArr)
	ep.endpointsArr = append(ep.endpointsArr, t)
	ep.endpointsMap[t] = pos
	return nil
}
// unregisterEndpoint removes t from the group. It reports true when t was the
// group's only endpoint, meaning the caller must drop this multiPortEndpoint.
// It reports false when other endpoints remain or when t was not registered.
func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
	ep.mu.Lock()
	defer ep.mu.Unlock()

	idx, ok := ep.endpointsMap[t]
	if !ok {
		return false
	}
	delete(ep.endpointsMap, t)

	l := len(ep.endpointsArr)
	if l <= 1 {
		// t was the last endpoint; the whole container is discarded.
		return true
	}

	// Swap-remove: move the last endpoint into the vacated slot and record its
	// new index, so removal stays O(1).
	lastEp := ep.endpointsArr[l-1]
	ep.endpointsArr[idx] = lastEp
	ep.endpointsMap[lastEp] = idx
	// Clear the slot that falls off the end. Without this the backing array
	// keeps referencing an endpoint beyond len (the removed one itself when
	// idx == l-1), which keeps it from being garbage-collected.
	ep.endpointsArr[l-1] = nil
	ep.endpointsArr = ep.endpointsArr[:l-1]
	return false
}
// singleRegisterEndpoint registers ep under id in the table for the
// (netProto, protocol) pair, creating the per-ID entry (with a fresh random
// hash seed) if needed. Port reuse is only honored when id has no remote
// port. It returns tcpip.ErrUnknownProtocol for an unknown protocol pair,
// and otherwise whatever endpointsByNic.registerEndpoint returns.
func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
	// TODO(eyalsoha): Why is reuse disabled for IDs with a remote port?
	reusePort = reusePort && id.RemotePort == 0

	eps, ok := d.protocol[protocolIDs{netProto, protocol}]
	if !ok {
		return tcpip.ErrUnknownProtocol
	}

	eps.mu.Lock()
	defer eps.mu.Unlock()

	epsByNic, exists := eps.endpoints[id]
	if !exists {
		epsByNic = &endpointsByNic{
			endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
			seed:      rand.Uint32(),
		}
		eps.endpoints[id] = epsByNic
	}
	return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
// unregisterEndpoint removes ep, registered under id and bindToDevice, from the
// table of every network protocol in netProtos paired with protocol, so it
// receives no more packets. Protocol pairs the demuxer does not know are
// skipped.
func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
	for _, netProto := range netProtos {
		eps, ok := d.protocol[protocolIDs{netProto, protocol}]
		if !ok {
			continue
		}
		eps.unregisterEndpoint(id, ep, bindToDevice)
	}
}
// loopbackSubnet is the IPv4 subnet 127.0.0.0/8. Building it panics if
// tcpip.NewSubnet rejects the hard-coded address/mask pair, which can only
// be a programming error.
var loopbackSubnet = func() tcpip.Subnet {
	const (
		loopbackAddr = "\x7f\x00\x00\x00" // 127.0.0.0
		loopbackMask = "\xff\x00\x00\x00" // 255.0.0.0
	)
	subnet, err := tcpip.NewSubnet(loopbackAddr, loopbackMask)
	if err != nil {
		panic(err)
	}
	return subnet
}()
// deliverPacket finds the transport endpoint(s) matching id and hands the
// packet to them. UDP multicast/broadcast packets go to every matching
// per-ID entry; everything else goes to the most specific match only. It
// reports whether any endpoint was found. For UDP, a miss also increments
// the UnknownPortErrors statistic. netHeader is not used here.
func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool {
	eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
	if !ok {
		return false
	}

	isUDP := protocol == header.UDPProtocolNumber

	// Collect destinations under the read lock; deliver after releasing it.
	eps.mu.RLock()
	var destEps []*endpointsByNic
	if isUDP && isMulticastOrBroadcast(id.LocalAddress) {
		destEps = d.findAllEndpointsLocked(eps, vv, id)
	} else if best := d.findEndpointLocked(eps, vv, id); best != nil {
		destEps = []*endpointsByNic{best}
	}
	eps.mu.RUnlock()

	if len(destEps) == 0 {
		// UDP packet could not be delivered to an unknown destination port.
		if isUDP {
			r.Stats().UDP.UnknownPortErrors.Increment()
		}
		return false
	}

	for _, dest := range destEps {
		dest.handlePacket(r, id, vv)
	}
	return true
}
// deliverRawPacket hands the packet to every raw endpoint registered for the
// (r.NetProto, protocol) pair. It reports whether there was at least one such
// endpoint. As in net/ipv4/ip_input.c:ip_local_deliver, delivery via raw
// endpoints is attempted first, and if there are several raw endpoints they
// all receive the packet.
func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool {
	eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
	if !ok {
		return false
	}

	eps.mu.RLock()
	rawEPs := eps.rawEndpoints
	if len(rawEPs) > 0 {
		// Each endpoint gets its own copy of the packet for the sake of
		// save/restore. vv.ToView() alone does not guarantee that: for a
		// single-view vv it may return the underlying view directly, so the
		// payload is always copied into a freshly allocated View. Flattening
		// happens once, outside the loop.
		flat := vv.ToView()
		for _, rawEP := range rawEPs {
			payload := buffer.NewView(len(flat))
			copy(payload, flat)
			rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), payload.ToVectorisedView())
		}
	}
	eps.mu.RUnlock()
	return len(rawEPs) > 0
}
// deliverControlPacket hands a control packet to the endpoint that most
// closely matches id in the (net, trans) table. It reports whether such an
// endpoint was found. The lookup runs under the table's read lock, and
// delivery happens after the lock is released.
func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
	eps, ok := d.protocol[protocolIDs{net, trans}]
	if !ok {
		return false
	}

	eps.mu.RLock()
	target := d.findEndpointLocked(eps, vv, id)
	eps.mu.RUnlock()

	if target != nil {
		target.handleControlPacket(n, id, typ, extra, vv)
	}
	return target != nil
}
// findAllEndpointsLocked returns every per-ID entry in eps that matches id,
// ordered from most to least specific. eps.mu must be held by the caller.
// The candidate keys are:
//   - id exactly as provided;
//   - id with a wildcard local address;
//   - id with the remote address and port cleared;
//   - only the local port.
func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic {
	var keys [4]TransportEndpointID
	keys[0] = id

	keys[1] = id
	keys[1].LocalAddress = ""

	keys[2] = id
	keys[2].RemoteAddress = ""
	keys[2].RemotePort = 0

	keys[3] = keys[2]
	keys[3].LocalAddress = ""

	var matched []*endpointsByNic
	for i := range keys {
		if epsByNic, ok := eps.endpoints[keys[i]]; ok {
			matched = append(matched, epsByNic)
		}
	}
	return matched
}
// findEndpointLocked returns the per-ID entry in eps that most closely matches
// id, or nil if none does. eps.mu must be held by the caller. It checks the
// same keys, in the same precedence order, as findAllEndpointsLocked. Unlike
// findAllEndpointsLocked, it stops at the first hit, so the per-packet unicast
// path does no extra map lookups and allocates no result slice.
func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
	// Try to find a match with the id as provided.
	if ep, ok := eps.endpoints[id]; ok {
		return ep
	}

	// Try to find a match with the id minus the local address.
	nid := id
	nid.LocalAddress = ""
	if ep, ok := eps.endpoints[nid]; ok {
		return ep
	}

	// Try to find a match with the id minus the remote part.
	nid.LocalAddress = id.LocalAddress
	nid.RemoteAddress = ""
	nid.RemotePort = 0
	if ep, ok := eps.endpoints[nid]; ok {
		return ep
	}

	// Try to find a match with only the local port.
	nid.LocalAddress = ""
	if ep, ok := eps.endpoints[nid]; ok {
		return ep
	}
	return nil
}
// registerRawEndpoint adds ep to the raw endpoints of the (netProto,
// transProto) pair. Raw endpoints receive every packet of that pair;
// deliverRawPacket handles the delivery. A single packet can reach any number
// of raw endpoints as well as a non-raw endpoint. It returns
// tcpip.ErrNotSupported for an unknown protocol pair.
func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
	eps, known := d.protocol[protocolIDs{netProto, transProto}]
	if !known {
		return tcpip.ErrNotSupported
	}

	eps.mu.Lock()
	eps.rawEndpoints = append(eps.rawEndpoints, ep)
	eps.mu.Unlock()
	return nil
}
// unregisterRawEndpoint removes ep from the raw endpoints of the (netProto,
// transProto) pair so it receives no more packets. The relative order of the
// remaining endpoints is kept. Removing an endpoint that is not registered is
// a no-op. An unknown protocol pair is a programming error and panics.
func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
	if !ok {
		panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto))
	}

	eps.mu.Lock()
	defer eps.mu.Unlock()
	for i, rawEP := range eps.rawEndpoints {
		if rawEP != ep {
			continue
		}
		// Shift the tail down over the removed entry, then clear the slot
		// that falls off the end. A plain append(s[:i], s[i+1:]...) would
		// leave a stale reference there in the backing array.
		last := len(eps.rawEndpoints) - 1
		copy(eps.rawEndpoints[i:], eps.rawEndpoints[i+1:])
		eps.rawEndpoints[last] = nil
		eps.rawEndpoints = eps.rawEndpoints[:last]
		return
	}
}
// isMulticastOrBroadcast reports whether addr equals header.IPv4Broadcast or
// is an IPv4 or IPv6 multicast address, as classified by the header package.
func isMulticastOrBroadcast(addr tcpip.Address) bool {
	if addr == header.IPv4Broadcast {
		return true
	}
	if header.IsV4MulticastAddress(addr) {
		return true
	}
	return header.IsV6MulticastAddress(addr)
}