blob: 3d96826f3b78be621356e55d21a6aac288f6beb5 [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build fuchsia
package dispatch
import (
"strconv"
"sync"
"syscall/zx"
)
// WaitResult is the value a Handler returns to tell the dispatcher what to do
// with its wait once the callback has run.
type WaitResult int
const (
// WaitFinished indicates the wait is done and must not be re-queued. It is a
// valid result regardless of whether the handler was invoked with an error.
WaitFinished WaitResult = iota
// WaitAgain asks the dispatcher to re-queue the wait on the same handle and
// signals. It is only a valid result when the handler was invoked with a nil
// error.
WaitAgain
)
// WaitID identifies a particular call to BeginWait and doubles as the key of
// the port packet delivered for that wait. IDs are taken from an increasing
// counter that skips any value which still has a registered wait. Note that an
// ID may become invalidated if its wait is dequeued.
type WaitID uint64
const (
// shutdownKey is the key of the user packet that ShutdownOne queues on the
// port; a Serve loop that receives a user packet with this key returns.
shutdownKey uint64 = 1
)
// Handler is the callback that will be called when the wait is complete.
//
// For a completed wait the handler receives a nil error and the signal packet.
// If the handler asks to wait again and re-arming the wait fails, the handler
// is called a second time with that error and a nil packet. The returned
// WaitResult tells the dispatcher whether the wait should be re-queued.
type Handler func(error, *zx.PacketSignal) WaitResult
// waitContext is a bookkeeping structure for in-progress waits. It keeps
// everything needed to invoke the handler and to re-arm or cancel the wait.
type waitContext struct {
// object is the handle the wait was registered on.
object zx.Handle
// callback is the handler invoked when the wait's packet is dispatched.
callback Handler
// trigger is the set of signals the wait was registered for.
trigger zx.Signals
}
// Dispatcher can read messages from a handle and assign them to a callback.
// Waits are registered with BeginWait and delivered to their handlers by
// goroutines running Serve.
type Dispatcher struct {
// port is the underlying port used to wait on signals and dispatch.
port zx.Port
// mu groups the mutable bookkeeping state together with the lock that the
// methods of Dispatcher take before touching it.
mu struct {
sync.RWMutex
// objects is a map that manages wait contexts.
objects map[WaitID]*waitContext
// nextWaitID is the first candidate BeginWait tries when picking an ID.
nextWaitID WaitID
// serveLoops is the count of calls to dispatcher.Serve outstanding.
serveLoops int
// closed is whether Close has already been called.
closed bool
}
}
// NewDispatcher creates a new dispatcher backed by a freshly created port and
// an empty set of waits. It returns the error from port creation, if any.
func NewDispatcher() (*Dispatcher, error) {
	port, err := zx.NewPort(0)
	if err != nil {
		return nil, err
	}

	dispatcher := &Dispatcher{port: port}
	// The wait table must exist before the first BeginWait writes to it.
	dispatcher.mu.objects = make(map[WaitID]*waitContext)
	return dispatcher, nil
}
// assertWaitResult panics when a handler returned a result that is not valid
// for the status it was invoked with. WaitFinished is always acceptable;
// WaitAgain is acceptable only when err is nil. Any other combination,
// including a result outside the declared constants, is a handler bug.
func assertWaitResult(result WaitResult, err error) {
	if result == WaitFinished || (result == WaitAgain && err == nil) {
		return
	}
	// err is nil when a handler returns an out-of-range result on the success
	// path. Calling err.Error() there would replace this diagnostic with a nil
	// dereference, so render the missing status explicitly.
	status := "<nil>"
	if err != nil {
		status = err.Error()
	}
	panic("expected " + strconv.Itoa(int(result)) + " for status " + status)
}
// Close closes a dispatcher. It queues one shutdown packet for every
// outstanding Serve loop, drops all registered wait contexts and marks the
// dispatcher closed so that subsequent BeginWait calls fail. Calling Close on
// an already closed dispatcher is a no-op.
//
// Close panics if a shutdown packet cannot be queued: a Serve loop that never
// receives its packet stays blocked on the port and cannot be stopped any
// other way.
func (d *Dispatcher) Close() {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.mu.closed {
		return
	}
	for i := 0; i < d.mu.serveLoops; i++ {
		err := d.ShutdownOne()
		if err == nil {
			continue
		}
		if zxErr, ok := err.(*zx.Error); ok {
			switch zxErr.Status {
			case zx.ErrShouldWait:
				// //docs/reference/syscalls/port_queue.md@e2c644: The port has too
				// many pending packets. This may be because Serve loops are stuck in
				// a dispatch callback and can't dequeue packets. Such a loop can't
				// be shut down otherwise - panic so that these Serve loops can be
				// killed and the system can restart the offending process.
				panic("dispatcher.Close(): zx_port_queue: too many pending packets to send shutdown packet")
			case zx.ErrBadHandle:
				// Port was closed out from under us; this shouldn't be possible
				// since we checked shutdown above.
				panic("dispatcher.Close(): attempted to zx_port_queue on closed port")
			}
		}
		// Every remaining failure, whether or not it is a *zx.Error, means the
		// shutdown packet was not queued. Never ignore it silently.
		panic("dispatcher.Close(): zx_port_queue unexpected error: " + err.Error())
	}
	d.mu.objects = make(map[WaitID]*waitContext)
	d.mu.closed = true
}
// BeginWait registers a wait on handle h for signals t and arranges for the
// handler c to be called when the wait completes. It returns the WaitID
// assigned to the wait.
//
// BeginWait fails with zx.ErrBadState once the dispatcher has been closed, and
// otherwise returns whatever error the port reports when the wait cannot be
// queued.
func (d *Dispatcher) BeginWait(h zx.Handle, t zx.Signals, c Handler) (WaitID, error) {
	d.mu.Lock()
	defer d.mu.Unlock()

	if d.mu.closed {
		return 0, &zx.Error{Status: zx.ErrBadState}
	}

	// Advance the cursor past every ID that still has a registered wait.
	for _, taken := d.mu.objects[d.mu.nextWaitID]; taken; _, taken = d.mu.objects[d.mu.nextWaitID] {
		d.mu.nextWaitID++
	}
	id := d.mu.nextWaitID

	// Queue the wait on the port first; the context is only recorded once the
	// port has accepted it.
	err := d.port.WaitAsync(h, uint64(id), t, zx.PortWaitAsyncOnce)
	if err != nil {
		return id, err
	}

	ctx := new(waitContext)
	ctx.object = h
	ctx.callback = c
	ctx.trigger = t
	d.mu.objects[id] = ctx
	return id, nil
}
// CancelWait cancels the wait with the given WaitID. It returns
// zx.ErrNotFound when no wait is registered under id, and otherwise the result
// of cancelling the wait on the port.
func (d *Dispatcher) CancelWait(id WaitID) error {
	d.mu.Lock()
	defer d.mu.Unlock()

	ctx, registered := d.mu.objects[id]
	if !registered {
		return &zx.Error{Status: zx.ErrNotFound}
	}

	delete(d.mu.objects, id)
	// The port cancellation has to happen while the lock is still held so that
	// a new wait cannot take over this ID and get cancelled by mistake.
	return d.port.Cancel(ctx.object, uint64(id))
}
// dispatch delivers a completed wait to its handler and acts on the result.
//
// The handler is called with a nil error and the signal packet. If it returns
// WaitAgain the wait is re-armed on the port under the same id; should
// re-arming fail, the handler is called once more with that error and a nil
// packet. When the wait is over (WaitFinished, or a failed re-arm) its context
// is removed from the wait table so finished waits do not accumulate.
//
// dispatch must be called without d.mu held: handlers may call back into the
// dispatcher, and dispatch itself takes the lock to drop the context.
func (d *Dispatcher) dispatch(id WaitID, wc *waitContext, signals *zx.PacketSignal) {
	// Call the handler.
	result := wc.callback(nil, signals)
	assertWaitResult(result, nil)
	if result == WaitAgain {
		err := d.port.WaitAsync(wc.object, uint64(id), wc.trigger, zx.PortWaitAsyncOnce)
		if err == nil {
			// Successfully re-armed; the context stays registered.
			return
		}
		// If we fail to re-arm, notify the handler of what happened.
		assertWaitResult(wc.callback(err, nil), err)
	}

	// The wait is finished. Drop its context, but only if the slot still holds
	// this exact context: the handler (or another goroutine) may have cancelled
	// the wait and a new BeginWait may have been given the same id meanwhile.
	d.mu.Lock()
	if d.mu.objects[id] == wc {
		delete(d.mu.objects, id)
	}
	d.mu.Unlock()
}
// ShutdownOne signals that one Serve loop should exit by queueing a user
// packet keyed with shutdownKey on the dispatcher's port.
//
// The returned error reflects failure to queue the shutdown signal.
func (d *Dispatcher) ShutdownOne() error {
	var shutdown zx.Packet
	shutdown.Hdr = zx.PacketHeader{
		Key:    shutdownKey,
		Type:   zx.PortPacketTypeUser,
		Status: zx.ErrOk,
	}
	return d.port.Queue(&shutdown)
}
// Serve blocks on the dispatcher's port and hands every signal packet to the
// handler registered for it.
//
// Serve returns immediately if the dispatcher is already closed. Otherwise it
// runs until it receives a shutdown packet or notices, after handling a
// packet, that the dispatcher has been closed. The last Serve loop to exit
// closes the port.
func (d *Dispatcher) Serve() {
	// Register this loop, unless the dispatcher is already closed.
	d.mu.Lock()
	if d.mu.closed {
		d.mu.Unlock()
		return
	}
	d.mu.serveLoops++
	d.mu.Unlock()

	// Deregister on exit; the last loop out closes the port.
	defer func() {
		d.mu.Lock()
		d.mu.serveLoops--
		last := d.mu.serveLoops == 0
		if last {
			d.port.Close()
		}
		d.mu.Unlock()
	}()

	for {
		var pkt zx.Packet
		err := d.port.Wait(&pkt, zx.TimensecInfinite)
		if err != nil {
			// Every error zx_port_wait can return here points at a bug in this
			// implementation, so panic:
			//
			//   ZX_ERR_BAD_HANDLE    - the port was closed; it is only closed
			//                          after all loops have been shut down.
			//   ZX_ERR_INVALID_ARGS  - the packet pointer is not valid.
			//   ZX_ERR_ACCESS_DENIED - rights are missing.
			//   ZX_ERR_TIMED_OUT     - impossible with an infinite timeout.
			panic("dispatcher.Serve(): zx_port_wait unexpected error: " + err.Error())
		}

		switch pkt.Hdr.Type {
		case zx.PortPacketTypeUser:
			// The only user packet this dispatcher queues is the shutdown signal.
			if pkt.Hdr.Key == shutdownKey {
				return
			}
			panic("unhandled user packet key=" + strconv.FormatUint(pkt.Hdr.Key, 10))
		case zx.PortPacketTypeSignalOne:
			id := WaitID(pkt.Hdr.Key)
			// Look the context up under the read lock, but run the handler
			// without it.
			d.mu.RLock()
			ctx, registered := d.mu.objects[id]
			d.mu.RUnlock()
			if registered {
				d.dispatch(id, ctx, pkt.Signal())
			}
		default:
			panic("unhandled packet type=" + strconv.FormatUint(uint64(pkt.Hdr.Type), 10))
		}

		// Stop once Close has been called, even if this loop's own shutdown
		// packet has not been dequeued yet.
		d.mu.RLock()
		stop := d.mu.closed
		d.mu.RUnlock()
		if stop {
			return
		}
	}
}