| // Copyright 2016 The Netstack Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package stack |
| |
| import ( |
| "sync" |
| "sync/atomic" |
| |
| "github.com/google/netstack/tcpip" |
| "github.com/google/netstack/tcpip/buffer" |
| "github.com/google/netstack/tcpip/header" |
| ) |
| |
| type protocolIDs struct { |
| network tcpip.NetworkProtocolNumber |
| transport tcpip.TransportProtocolNumber |
| } |
| |
| // transportEndpoints manages all endpoints of a given protocol. It has its own |
| // mutex so as to reduce interference between protocols. |
| type transportEndpoints struct { |
| mu sync.RWMutex |
| endpoints map[TransportEndpointID]TransportEndpoint |
| } |
| |
| // transportDemuxer demultiplexes packets targeted at a transport endpoint |
| // (i.e., after they've been parsed by the network layer). It does two levels |
| // of demultiplexing: first based on the network and transport protocols, then |
| // based on endpoints IDs. |
| type transportDemuxer struct { |
| protocol map[protocolIDs]*transportEndpoints |
| } |
| |
| func newTransportDemuxer(stack *Stack) *transportDemuxer { |
| d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)} |
| |
| // Add each network and and transport pair to the demuxer. |
| for netProto := range stack.networkProtocols { |
| for proto := range stack.transportProtocols { |
| d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{endpoints: make(map[TransportEndpointID]TransportEndpoint)} |
| } |
| } |
| |
| return d |
| } |
| |
| // registerEndpoint registers the given endpoint with the dispatcher such that |
| // packets that match the endpoint ID are delivered to it. |
| func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { |
| for i, n := range netProtos { |
| if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil { |
| d.unregisterEndpoint(netProtos[:i], protocol, id) |
| return err |
| } |
| } |
| |
| return nil |
| } |
| |
| func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { |
| eps, ok := d.protocol[protocolIDs{netProto, protocol}] |
| if !ok { |
| return nil |
| } |
| |
| eps.mu.Lock() |
| defer eps.mu.Unlock() |
| |
| if _, ok := eps.endpoints[id]; ok { |
| return tcpip.ErrPortInUse |
| } |
| |
| eps.endpoints[id] = ep |
| |
| return nil |
| } |
| |
| // unregisterEndpoint unregisters the endpoint with the given id such that it |
| // won't receive any more packets. |
| func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) { |
| for _, n := range netProtos { |
| if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { |
| eps.mu.Lock() |
| delete(eps.endpoints, id) |
| eps.mu.Unlock() |
| } |
| } |
| } |
| |
| // deliverPacket attempts to deliver the given packet. Returns true if it found |
| // an endpoint, false otherwise. |
| func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView, id TransportEndpointID) bool { |
| eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] |
| if !ok { |
| return false |
| } |
| |
| eps.mu.RLock() |
| b := d.deliverPacketLocked(r, eps, vv, id) |
| eps.mu.RUnlock() |
| |
| // UDP packet could not be delivered to an unknown destination port |
| if !b && protocol == header.UDPProtocolNumber { |
| atomic.AddUint64(&r.MutableStats().UDP.UnknownPortErrors, 1) |
| } |
| |
| return b |
| } |
| |
| func (d *transportDemuxer) deliverPacketLocked(r *Route, eps *transportEndpoints, vv *buffer.VectorisedView, id TransportEndpointID) bool { |
| // Try to find a match with the id as provided. |
| if ep := eps.endpoints[id]; ep != nil { |
| ep.HandlePacket(r, id, vv) |
| return true |
| } |
| |
| // Try to find a match with the id minus the local address. |
| nid := id |
| |
| nid.LocalAddress = "" |
| if ep := eps.endpoints[nid]; ep != nil { |
| ep.HandlePacket(r, id, vv) |
| return true |
| } |
| |
| // Try to find a match with the id minus the remote part. |
| nid.LocalAddress = id.LocalAddress |
| nid.RemoteAddress = "" |
| nid.RemotePort = 0 |
| if ep := eps.endpoints[nid]; ep != nil { |
| ep.HandlePacket(r, id, vv) |
| return true |
| } |
| |
| // Try to find a match with only the local port. |
| nid.LocalAddress = "" |
| if ep := eps.endpoints[nid]; ep != nil { |
| ep.HandlePacket(r, id, vv) |
| return true |
| } |
| |
| return false |
| } |