blob: 0499d89f5091b649fc4d7fc4d4f737d8fa6254d1 [file] [log] [blame]
// Copyright (c) 2023, Google Inc.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
package runner
import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"os"
"sync"
"time"
)
// shimDispatcher owns a single loopback TCP listener shared by all shims.
// Each accepted connection must first send an 8-byte little-endian shim ID,
// which the dispatcher uses to hand the connection to the matching
// shimListener.
type shimDispatcher struct {
	// lock guards nextShimID, shims, and err.
	lock sync.Mutex
	// nextShimID is the ID the next NewShim call will assign.
	nextShimID uint64
	// listener is set once at construction and is not guarded by lock.
	listener *net.TCPListener
	// shims maps shim IDs to listeners that are registered and not yet
	// unregistered or swapped out by closeWithError.
	shims map[uint64]*shimListener
	// err, once set, is the reason the dispatcher was closed; NewShim
	// returns it instead of registering a new shim.
	err error
}
// newShimDispatcher opens a loopback TCP listener on an OS-chosen port. It
// prefers IPv6 (::1) and falls back to IPv4 (127.0.0.1) if that fails. It
// then starts the goroutine that accepts and routes incoming connections.
func newShimDispatcher() (*shimDispatcher, error) {
	ln, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv6loopback})
	if err != nil {
		// IPv6 may be unavailable on this host; retry on IPv4 loopback.
		ln, err = net.ListenTCP("tcp4", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}})
		if err != nil {
			return nil, err
		}
	}
	d := &shimDispatcher{
		listener: ln,
		shims:    make(map[uint64]*shimListener),
	}
	go d.acceptLoop()
	return d, nil
}
// NewShim registers and returns a new shimListener under the next unused
// shim ID. Once the dispatcher has been closed, it returns the error the
// dispatcher was closed with instead.
func (d *shimDispatcher) NewShim() (*shimListener, error) {
	d.lock.Lock()
	defer d.lock.Unlock()

	if closedErr := d.err; closedErr != nil {
		return nil, closedErr
	}

	id := d.nextShimID
	d.nextShimID = id + 1
	shim := &shimListener{
		dispatcher: d,
		shimID:     id,
		connChan:   make(chan net.Conn, 1),
	}
	d.shims[id] = shim
	return shim, nil
}
// unregisterShim removes l from the routing table, so later connections
// presenting its ID are reported as "not found" by dispatch.
func (d *shimDispatcher) unregisterShim(l *shimListener) {
	d.lock.Lock()
	defer d.lock.Unlock()
	delete(d.shims, l.shimID)
}
// acceptLoop accepts connections until the listener returns an error. It then
// shuts the whole dispatcher down with that error. Each accepted connection
// is identified and routed on its own goroutine, so one slow peer does not
// hold up the others.
func (d *shimDispatcher) acceptLoop() {
	for {
		conn, err := d.listener.Accept()
		if err != nil {
			// Any accept failure is treated as fatal for the dispatcher.
			d.closeWithError(err)
			return
		}
		go func(c net.Conn) {
			// To be robust against port scanners and similar noise,
			// undispatchable connections are logged and dropped rather
			// than treated as fatal.
			if dispatchErr := d.dispatch(c); dispatchErr != nil {
				fmt.Fprintf(os.Stderr, "Error dispatching connection: %s\n", dispatchErr)
				c.Close()
			}
		}(conn)
	}
}
// dispatch reads the 8-byte little-endian shim ID that a new connection must
// send first, and hands the connection to the shimListener with that ID.
//
// On any error the caller still owns conn and is responsible for closing it.
// Errors are returned when:
//   - the ID cannot be read within *idleTimeout,
//   - no shim is registered under the ID,
//   - the shim was closed concurrently, or
//   - the shim already has an undelivered connection pending.
func (d *shimDispatcher) dispatch(conn net.Conn) error {
	// Bound how long an unidentified peer can tie up this goroutine.
	conn.SetReadDeadline(time.Now().Add(*idleTimeout))
	var buf [8]byte
	if _, err := io.ReadFull(conn, buf[:]); err != nil {
		return err
	}
	conn.SetReadDeadline(time.Time{})

	shimID := binary.LittleEndian.Uint64(buf[:])
	d.lock.Lock()
	shim, ok := d.shims[shimID]
	d.lock.Unlock()
	if !ok {
		return fmt.Errorf("shim ID %d not found", shimID)
	}

	// Once d.lock is released, the shim may be closed at any moment. Sending
	// on a closed channel panics, so check and send while holding shim.lock,
	// which closeWithError holds while closing connChan. The send must not
	// block while the lock is held, or closeWithError would deadlock behind it.
	shim.lock.Lock()
	defer shim.lock.Unlock()
	if shim.err != nil {
		return fmt.Errorf("shim ID %d is closed: %w", shimID, shim.err)
	}
	select {
	case shim.connChan <- conn:
		return nil
	default:
		return fmt.Errorf("shim ID %d already has a pending connection", shimID)
	}
}
// Close stops the dispatcher by closing it with net.ErrClosed, and returns
// any error from closing the underlying TCP listener.
func (d *shimDispatcher) Close() error {
	reason := net.ErrClosed
	return d.closeWithError(reason)
}
// closeWithError shuts the dispatcher down. It:
//   - closes the TCP listener,
//   - records why the dispatcher closed, so later NewShim calls return it, and
//   - closes every registered shim with that reason.
//
// It may be called more than once, for example by Close and by acceptLoop
// once Accept starts failing. Only the first error is kept, so a later call
// cannot mask the original cause. It returns the error from closing the
// listener.
func (d *shimDispatcher) closeWithError(err error) error {
	if err == nil {
		// A nil error would leave the dispatcher looking open to NewShim.
		err = net.ErrClosed
	}
	closeErr := d.listener.Close()

	d.lock.Lock()
	if d.err == nil {
		d.err = err
	}
	err = d.err
	shims := d.shims
	d.shims = make(map[uint64]*shimListener)
	d.lock.Unlock()

	// Close shims after releasing d.lock; each takes its own lock.
	for _, shim := range shims {
		shim.closeWithError(err)
	}
	return closeErr
}
// shimListener receives the connections the dispatcher routes to one shim ID.
type shimListener struct {
	dispatcher *shimDispatcher
	shimID     uint64
	// connChan contains connections from the dispatcher. It is closed when the
	// listener is closed, either by Close or by a fatal dispatcher error. The
	// reason is then available in err.
	connChan chan net.Conn
	// err is the error the listener was closed with; nil while open.
	err error
	// lock serializes closing: it guards writes to err and close(connChan).
	lock sync.Mutex
}
// Port returns the TCP port on which the shared dispatcher accepts
// connections.
func (l *shimListener) Port() int {
	addr := l.dispatcher.listener.Addr().(*net.TCPAddr)
	return addr.Port
}
// IsIPv6 reports whether the dispatcher's listener is bound to an IPv6
// address rather than an IPv4 one.
func (l *shimListener) IsIPv6() bool {
	ip := l.dispatcher.listener.Addr().(*net.TCPAddr).IP
	// net.IP may store an IPv4 address in 16-byte form, so byte length alone
	// is not a reliable test. Require a valid IP that has no IPv4 form.
	return ip.To16() != nil && ip.To4() == nil
}
// ShimID returns the identifier this listener was registered under. A
// connection is routed here only if it begins with this value encoded as 8
// little-endian bytes.
func (l *shimListener) ShimID() uint64 {
	id := l.shimID
	return id
}
// Close first unregisters the listener from its dispatcher, so no new
// connections are routed to it. It then closes the listener with
// net.ErrClosed, unless it was already closed with another error.
// It always returns nil.
func (l *shimListener) Close() error {
	d := l.dispatcher
	d.unregisterShim(l)
	l.closeWithError(net.ErrClosed)
	return nil
}
// closeWithError marks the listener closed with err and closes connChan.
// Blocked and future Accept calls then return err. It also closes any
// connection that was delivered but never accepted. Only the first call has
// an effect. Multiple goroutines, such as Close and the dispatcher, may call
// it at once, so the lock serializes them.
func (l *shimListener) closeWithError(err error) {
	if err == nil {
		// A nil err would leave the listener looking open: Accept would return
		// (nil, nil), and a second call would close connChan twice and panic.
		err = net.ErrClosed
	}

	l.lock.Lock()
	defer l.lock.Unlock()
	if l.err != nil {
		return
	}
	l.err = err
	close(l.connChan)

	// Closing a buffered channel does not discard its contents. Close any
	// connection that no one will now Accept, so its peer is not left
	// waiting. The loop ends once the buffer is drained because the channel
	// is closed.
	for conn := range l.connChan {
		conn.Close()
	}
}
// Accept waits for the next connection routed to this shim. A zero deadline
// waits indefinitely. If the deadline passes first, or has already passed, it
// returns context.DeadlineExceeded. Once the listener is closed and no
// buffered connection remains, it returns the error the listener was closed
// with.
func (l *shimListener) Accept(deadline time.Time) (net.Conn, error) {
	// A nil channel never becomes ready, so with no deadline the select
	// below simply waits on connChan.
	var expired <-chan time.Time
	if !deadline.IsZero() {
		wait := time.Until(deadline)
		if wait < 0 {
			return nil, context.DeadlineExceeded
		}
		t := time.NewTimer(wait)
		defer t.Stop()
		expired = t.C
	}

	select {
	case conn, open := <-l.connChan:
		if open {
			return conn, nil
		}
		return nil, l.err
	case <-expired:
		return nil, context.DeadlineExceeded
	}
}