| // Copyright 2018 The gVisor Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package tcp |
| |
import (
	"encoding/binary"

	"gvisor.dev/gvisor/pkg/rand"
	"gvisor.dev/gvisor/pkg/sleep"
	"gvisor.dev/gvisor/pkg/sync"
	"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)
| |
// epQueue is a queue of endpoints awaiting processing by a processor
// goroutine. Membership is tracked via endpoint.pendingProcessing so an
// endpoint is never queued twice.
type epQueue struct {
	// mu protects list and, while held, the pendingProcessing flag of
	// endpoints being added to or removed from list.
	mu sync.Mutex
	// list holds the endpoints in FIFO order.
	list endpointList
}
| |
| // enqueue adds e to the queue if the endpoint is not already on the queue. |
| func (q *epQueue) enqueue(e *endpoint) { |
| q.mu.Lock() |
| if e.pendingProcessing { |
| q.mu.Unlock() |
| return |
| } |
| q.list.PushBack(e) |
| e.pendingProcessing = true |
| q.mu.Unlock() |
| } |
| |
| // dequeue removes and returns the first element from the queue if available, |
| // returns nil otherwise. |
| func (q *epQueue) dequeue() *endpoint { |
| q.mu.Lock() |
| if e := q.list.Front(); e != nil { |
| q.list.Remove(e) |
| e.pendingProcessing = false |
| q.mu.Unlock() |
| return e |
| } |
| q.mu.Unlock() |
| return nil |
| } |
| |
| // empty returns true if the queue is empty, false otherwise. |
| func (q *epQueue) empty() bool { |
| q.mu.Lock() |
| v := q.list.Empty() |
| q.mu.Unlock() |
| return v |
| } |
| |
// processor is responsible for processing packets queued to a tcp endpoint.
type processor struct {
	// epQ holds endpoints with segments pending processing.
	epQ epQueue
	// newEndpointWaker is asserted when an endpoint is added to epQ.
	newEndpointWaker sleep.Waker
	// closeWaker is asserted to ask the processor goroutine to exit.
	closeWaker sleep.Waker
	// id identifies this processor within the dispatcher's pool.
	id int
	// wg is used by wait to block until the processor goroutine exits.
	wg sync.WaitGroup
}
| |
| func newProcessor(id int) *processor { |
| p := &processor{ |
| id: id, |
| } |
| p.wg.Add(1) |
| go p.handleSegments() |
| return p |
| } |
| |
// close asks the processor goroutine to exit by asserting its close
// waker. It does not wait for the goroutine to terminate; use wait for
// that.
func (p *processor) close() {
	p.closeWaker.Assert()
}
| |
// wait blocks until the processor goroutine has exited.
func (p *processor) wait() {
	p.wg.Wait()
}
| |
// queueEndpoint queues ep for processing and wakes the processor
// goroutine. Queuing is idempotent: if ep is already queued, only the
// waker is asserted again.
func (p *processor) queueEndpoint(ep *endpoint) {
	// Queue an endpoint for processing by the processor goroutine.
	p.epQ.enqueue(ep)
	p.newEndpointWaker.Assert()
}
| |
| func (p *processor) handleSegments() { |
| const newEndpointWaker = 1 |
| const closeWaker = 2 |
| s := sleep.Sleeper{} |
| s.AddWaker(&p.newEndpointWaker, newEndpointWaker) |
| s.AddWaker(&p.closeWaker, closeWaker) |
| defer s.Done() |
| for { |
| id, ok := s.Fetch(true) |
| if ok && id == closeWaker { |
| p.wg.Done() |
| return |
| } |
| for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { |
| if ep.segmentQueue.empty() { |
| continue |
| } |
| |
| // If socket has transitioned out of connected state |
| // then just let the worker handle the packet. |
| // |
| // NOTE: We read this outside of e.mu lock which means |
| // that by the time we get to handleSegments the |
| // endpoint may not be in ESTABLISHED. But this should |
| // be fine as all normal shutdown states are handled by |
| // handleSegments and if the endpoint moves to a |
| // CLOSED/ERROR state then handleSegments is a noop. |
| if ep.EndpointState() != StateEstablished { |
| ep.newSegmentWaker.Assert() |
| continue |
| } |
| |
| if !ep.mu.TryLock() { |
| ep.newSegmentWaker.Assert() |
| continue |
| } |
| // If the endpoint is in a connected state then we do |
| // direct delivery to ensure low latency and avoid |
| // scheduler interactions. |
| if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose { |
| // Send any active resets if required. |
| if err != nil { |
| ep.resetConnectionLocked(err) |
| } |
| ep.notifyProtocolGoroutine(notifyTickleWorker) |
| ep.mu.Unlock() |
| continue |
| } |
| |
| if !ep.segmentQueue.empty() { |
| p.epQ.enqueue(ep) |
| } |
| |
| ep.mu.Unlock() |
| } |
| } |
| } |
| |
// dispatcher manages a pool of TCP endpoint processors which are responsible
// for the processing of inbound segments. This fixed pool of processor
// goroutines do full tcp processing. The processor is selected based on the
// hash of the endpoint id to ensure that delivery for the same endpoint happens
// in-order.
type dispatcher struct {
	// processors is the fixed pool created by newDispatcher; it is not
	// mutated afterwards.
	processors []*processor
	// seed randomizes the endpoint-id hash used by selectProcessor.
	seed uint32
}
| |
| func newDispatcher(nProcessors int) *dispatcher { |
| processors := []*processor{} |
| for i := 0; i < nProcessors; i++ { |
| processors = append(processors, newProcessor(i)) |
| } |
| return &dispatcher{ |
| processors: processors, |
| seed: generateRandUint32(), |
| } |
| } |
| |
| func (d *dispatcher) close() { |
| for _, p := range d.processors { |
| p.close() |
| } |
| } |
| |
| func (d *dispatcher) wait() { |
| for _, p := range d.processors { |
| p.wait() |
| } |
| } |
| |
// queuePacket parses pkt into a TCP segment and delivers it to stackEP's
// segment queue, dispatching established endpoints to a processor for
// fast-path handling. Malformed or checksum-failing segments are counted
// and dropped.
func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
	ep := stackEP.(*endpoint)
	s := newSegment(r, id, pkt)
	// Drop segments whose TCP header fails to parse.
	if !s.parse() {
		ep.stack.Stats().MalformedRcvdPackets.Increment()
		ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
		ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
		s.decRef()
		return
	}

	// Drop segments with a bad checksum.
	if !s.csumValid {
		ep.stack.Stats().MalformedRcvdPackets.Increment()
		ep.stack.Stats().TCP.ChecksumErrors.Increment()
		ep.stats.ReceiveErrors.ChecksumErrors.Increment()
		s.decRef()
		return
	}

	ep.stack.Stats().TCP.ValidSegmentsReceived.Increment()
	ep.stats.SegmentsReceived.Increment()
	if (s.flags & header.TCPFlagRst) != 0 {
		ep.stack.Stats().TCP.ResetsReceived.Increment()
	}

	// enqueueSegment returns false when the endpoint's queue cannot
	// accept the segment; drop our reference in that case.
	if !ep.enqueueSegment(s) {
		s.decRef()
		return
	}

	// For sockets not in established state let the worker goroutine
	// handle the packets.
	if ep.EndpointState() != StateEstablished {
		ep.newSegmentWaker.Assert()
		return
	}

	// Established endpoints are handed to a processor chosen by hashing
	// the endpoint id, keeping per-connection delivery in-order.
	d.selectProcessor(id).queueEndpoint(ep)
}
| |
| func generateRandUint32() uint32 { |
| b := make([]byte, 4) |
| if _, err := rand.Read(b); err != nil { |
| panic(err) |
| } |
| return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 |
| } |
| |
| func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { |
| payload := []byte{ |
| byte(id.LocalPort), |
| byte(id.LocalPort >> 8), |
| byte(id.RemotePort), |
| byte(id.RemotePort >> 8)} |
| |
| h := jenkins.Sum32(d.seed) |
| h.Write(payload) |
| h.Write([]byte(id.LocalAddress)) |
| h.Write([]byte(id.RemoteAddress)) |
| |
| return d.processors[h.Sum32()%uint32(len(d.processors))] |
| } |