blob: 8eede7684065e59456acd3771e71f8ac88997e93 [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can
// found in the LICENSE file.
// +build !darwin
package serial
import (
"context"
"errors"
"fmt"
"log"
"net"
"os"
"syscall"
"time"
)
const (
	// maxConnections is the size hint handed to EpollCreate and the number of
	// ready events Run collects per EpollWait call. Despite the name, nothing
	// in this file limits how many clients may connect.
	maxConnections = 32

	// bufferSize is the size of the buffers Run hands to copyData, i.e. the
	// most bytes moved by a single read.
	bufferSize = 4 << 10
)
// Server is a serial multiplexing server that redirects i/o over
// unix domain sockets.
type Server struct {
	// ServerOptions carries optional settings such as the debug Logger.
	ServerOptions

	// serialPath is the filesystem path of the serial device. Run stats it to
	// notice when the device disappears and reopens the device from it.
	serialPath string

	// serialDevice is the currently open serial device.
	serialDevice *os.File

	// listener accepts client connections on the unix domain socket.
	listener *net.UnixListener

	// clientFds is the set of connected client socket descriptors.
	clientFds map[int]bool
}
// ServerOptions specifies additional options that modify the behavior
// of the server. The zero value is valid and disables debug logging.
type ServerOptions struct {
	Logger *log.Logger // when non-nil, receives debugging output
}
// NewServer instantiates a new serial server that multiplexes i/o
// from the given serial path and socket path.
//
// The serial device is polled every 500ms until Open succeeds, giving up
// after two minutes. Any existing file at serialSocketPath is removed, a unix
// socket listener is created there, and the socket is chmod'ed to 0o777 so
// any local user can connect.
//
// opts may be nil. When non-nil, its fields are copied into the returned
// Server. On error, anything this call opened (serial device, listener) is
// closed before returning.
func NewServer(serialPath, serialSocketPath string, opts *ServerOptions) (*Server, error) {
	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()
	timeout := time.NewTimer(2 * time.Minute)
	defer timeout.Stop()

	var serialDevice *os.File
	var lastErr error
openLoop:
	for {
		select {
		case <-ticker.C:
			dev, err := Open(serialPath)
			if err == nil {
				serialDevice = dev
				break openLoop
			}
			// Remember why the device could not be opened for the timeout error.
			lastErr = err
		case <-timeout.C:
			if lastErr == nil {
				return nil, errors.New("timeout")
			}
			return nil, fmt.Errorf("timeout opening serial device %q: %w", serialPath, lastErr)
		}
	}

	// Remove a stale socket left by a previous run; a missing file is fine.
	if err := os.Remove(serialSocketPath); err != nil && !os.IsNotExist(err) {
		serialDevice.Close()
		return nil, err
	}
	addr := &net.UnixAddr{Name: serialSocketPath, Net: "unix"}
	l, err := net.ListenUnix("unix", addr)
	if err != nil {
		serialDevice.Close()
		return nil, err
	}
	if err := os.Chmod(serialSocketPath, 0o777); err != nil {
		l.Close()
		serialDevice.Close()
		return nil, err
	}

	s := &Server{
		serialPath:   serialPath,
		serialDevice: serialDevice,
		listener:     l,
		clientFds:    make(map[int]bool),
	}
	if opts != nil {
		s.ServerOptions = *opts
	}
	return s, nil
}
// copyData performs one read from src into buffer and forwards whatever was
// read to every fd in dests.
//
// When sync is true, the destinations are written one after another before
// copyData returns. When sync is false, each destination is written from its
// own goroutine so a slow destination cannot stall the caller. Those
// goroutines receive a private copy of the data, so the caller may reuse
// buffer as soon as copyData returns.
//
// Only a read error is returned. Write failures are reported through logf and
// otherwise dropped, so one bad destination does not prevent delivery to the
// others.
func (s *Server) copyData(sync bool, buffer []byte, src int, dests ...int) error {
	n, err := syscall.Read(src, buffer)
	if err != nil {
		return err
	}
	if n <= 0 || len(dests) == 0 {
		// Nothing was read (end of file) or there is nobody to forward to.
		return nil
	}
	data := buffer[:n]
	if !sync {
		// The goroutines below outlive this call, so they must not alias the
		// caller's buffer.
		data = append([]byte(nil), data...)
	}
	for _, dest := range dests {
		if sync {
			s.writeAll(dest, data)
		} else {
			go s.writeAll(dest, data)
		}
	}
	return nil
}

// writeAll writes all of data to fd, retrying after short writes and EINTR.
// On any other error it logs through logf and gives up on the remaining bytes.
// This includes EAGAIN from a nonblocking fd whose send buffer is full.
func (s *Server) writeAll(fd int, data []byte) {
	for len(data) > 0 {
		n, err := syscall.Write(fd, data)
		if err == syscall.EINTR {
			continue
		}
		if err != nil {
			s.logf("write to fd %d failed: %v", fd, err)
			return
		}
		if n <= 0 {
			// No progress without an error; bail out rather than spin.
			return
		}
		data = data[n:]
	}
}
// acceptConnection accepts one pending connection on listenerFd, makes it
// nonblocking, registers it with the epoll instance epfd for EPOLLIN, and
// records it in s.clientFds.
//
// If any step after the accept fails, the new connection is closed so that
// its descriptor is not leaked, and the error is returned.
func (s *Server) acceptConnection(epfd int, listenerFd int) error {
	connFd, _, err := syscall.Accept(listenerFd)
	if err != nil {
		return err
	}
	s.logf("fd %d has connected", connFd)
	// Nonblocking client sockets make reads and writes on them return
	// immediately instead of stalling the server.
	if err := syscall.SetNonblock(connFd, true); err != nil {
		syscall.Close(connFd)
		return fmt.Errorf("setting fd %d nonblocking: %w", connFd, err)
	}
	event := &syscall.EpollEvent{
		Events: syscall.EPOLLIN,
		Fd:     int32(connFd),
	}
	if err := syscall.EpollCtl(epfd, syscall.EPOLL_CTL_ADD, connFd, event); err != nil {
		syscall.Close(connFd)
		return fmt.Errorf("registering fd %d with epoll: %w", connFd, err)
	}
	s.clientFds[connFd] = true
	return nil
}
// removeConnection forgets the client connFd, deregisters it from the epoll
// instance epfd, and closes it.
//
// All three steps are always attempted, so a failure in one does not leave
// the fd open or still listed in s.clientFds. The first failure, if any, is
// returned.
func (s *Server) removeConnection(epfd int, connFd int) error {
	s.logf("fd %d has disconnected", connFd)
	// Forget the fd first so serial output is never forwarded to it again,
	// even if the cleanup below fails.
	delete(s.clientFds, connFd)
	delErr := syscall.EpollCtl(epfd, syscall.EPOLL_CTL_DEL, connFd, nil)
	closeErr := syscall.Close(connFd)
	if delErr != nil {
		return fmt.Errorf("removing fd %d from epoll: %w", connFd, delErr)
	}
	if closeErr != nil {
		return fmt.Errorf("closing fd %d: %w", connFd, closeErr)
	}
	return nil
}
// Run begins the server and blocks until the context signals done or an error
// is encountered reading from serial. While running, all serial i/o is
// forwarded to and from the any connection accepted by the listener.
// The listener is closed by the time Run returns, as is the serial device.
func (s *Server) Run(ctx context.Context) error {
serialFd := int(s.serialDevice.Fd())
defer s.serialDevice.Close()
// Get a copy of the fd to the listener.
listenerFile, err := s.listener.File()
if err != nil {
return err
}
defer listenerFile.Close()
// We have to close the listener as well, because listener.File() creates
// and returns a copy of the desired file descriptor.
defer s.listener.Close()
listenerFd := int(listenerFile.Fd())
epfd, err := syscall.EpollCreate(maxConnections)
if err != nil {
return err
}
defer syscall.Close(epfd)
// Register the serial line and socket with the epoll instance.
event := &syscall.EpollEvent{
Events: syscall.EPOLLIN,
Fd: int32(serialFd),
}
if err := syscall.EpollCtl(epfd, syscall.EPOLL_CTL_ADD, serialFd, event); err != nil {
return err
}
event = &syscall.EpollEvent{
Events: syscall.EPOLLIN,
Fd: int32(listenerFd),
}
if err := syscall.EpollCtl(epfd, syscall.EPOLL_CTL_ADD, listenerFd, event); err != nil {
return err
}
events := make([]syscall.EpollEvent, maxConnections)
for {
// If the context has been canceled, close all sockets and exit out.
if err := ctx.Err(); err != nil {
for connFd := range s.clientFds {
syscall.Close(connFd)
}
break
}
if _, err := os.Stat(s.serialPath); os.IsNotExist(err) {
log.Printf("reopening serial line: %s", s.serialPath)
// Normally, we'd want to delete the serialFd from the Epoll file descriptor,
// but udev appears to attempt to close all fds, which then automatically removes it
// from the epoll fd.
s.serialDevice.Close()
for {
var err error
s.serialDevice, err = Open(s.serialPath)
if err != nil {
time.Sleep(1 * time.Second)
} else {
break
}
}
serialFd = int(s.serialDevice.Fd())
event := &syscall.EpollEvent{
Events: syscall.EPOLLIN,
Fd: int32(serialFd),
}
if err := syscall.EpollCtl(epfd, syscall.EPOLL_CTL_ADD, serialFd, event); err != nil {
return err
}
log.Printf("done reopening: %s", s.serialPath)
}
numFds, err := syscall.EpollWait(epfd, events, 1000) // Wait for 1s before timing out.
if err != nil {
// EINTR is (primarily) triggered by EpollWait timeouts, which are
// part of proper functionality in this server.
if err == syscall.EINTR {
continue
}
return err
}
serialBuffer := make([]byte, bufferSize)
connBuffer := make([]byte, bufferSize)
for _, event := range events[0:numFds] {
eventFd := int(event.Fd)
if eventFd == listenerFd {
if (event.Events & syscall.EPOLLIN) > 0 {
// Handle new incoming connections.
if err := s.acceptConnection(epfd, listenerFd); err != nil {
return err
}
} else {
// The listener was closed or corrupted, exit.
return fmt.Errorf("unable to handle event 0x%x on listener", event.Events)
}
} else if eventFd == serialFd {
if (event.Events & syscall.EPOLLIN) > 0 {
// Handle input from the serial device.
connSlice := []int{}
for connFd := range s.clientFds {
connSlice = append(connSlice, connFd)
}
if err := s.copyData(false, serialBuffer, int(serialFd), connSlice...); err != nil {
return err
}
} else {
// The serial device was closed or corrupted, exit.
return fmt.Errorf("unable to handle event 0x%x on listener", event.Events)
}
} else {
// Handle input from a client.
if (event.Events & syscall.EPOLLIN) > 0 {
// Handle incoming data on connections. We force blocking writes when performing
// this copy, as we don't want input from multiple clients interleaved when
// writing to the serial line.
if err := s.copyData(true, connBuffer, int(event.Fd), serialFd); err != nil {
log.Printf("got error reading from client: %s", err)
}
}
// Handle a hangup from a client.
if (event.Events & syscall.EPOLLHUP) > 0 {
if err := s.removeConnection(epfd, eventFd); err != nil {
return err
}
}
}
}
}
return nil
}
// logf formats a debugging message and hands it to the configured Logger.
// It is a no-op when no Logger has been set.
func (s *Server) logf(format string, args ...interface{}) {
	logger := s.Logger
	if logger == nil {
		return
	}
	logger.Printf(format, args...)
}