[zxwait] Implement context cancellation
Change-Id: I0a6c66affe8c365076c7a86da7e34f86bd6adbae
Reviewed-on: https://fuchsia-review.googlesource.com/c/third_party/go/+/554242
Reviewed-by: Bruno Dal Bo <brunodalbo@google.com>
Commit-Queue: Tamir Duberstein <tamird@google.com>
diff --git a/src/syscall/zx/fidl/interface.go b/src/syscall/zx/fidl/interface.go
index cda347b..3ba9f15 100644
--- a/src/syscall/zx/fidl/interface.go
+++ b/src/syscall/zx/fidl/interface.go
@@ -7,6 +7,7 @@
import (
"sync"
"syscall/zx"
+ "syscall/zx/internal/context"
)
// ServiceRequest is an abstraction over a FIDL interface request which is
@@ -142,7 +143,7 @@
}
// Write the encoded bytes to the channel.
- return withRetry(func() error {
+ return withRetryContext(context.Background(), func() error {
return channelWriteEtc(&p.Channel, respb[:nb], resphd[:nh], 0)
}, *p.Channel.Handle(), zx.SignalChannelWritable, zx.SignalChannelPeerClosed)
}
@@ -177,7 +178,7 @@
func (p *ChannelProxy) read(header *MessageHeader, poolBytes []byte, poolHandleInfos []zx.HandleInfo) ([]byte, []zx.HandleInfo, error) {
var nb, nh uint32
- if err := withRetry(func() error {
+ if err := withRetryContext(context.Background(), func() error {
var err error
nb, nh, err = channelReadEtc(&p.Channel, poolBytes[:], poolHandleInfos[:], 0)
return err
diff --git a/src/syscall/zx/fidl/interface_test.go b/src/syscall/zx/fidl/interface_test.go
index c3f8418..439d181 100644
--- a/src/syscall/zx/fidl/interface_test.go
+++ b/src/syscall/zx/fidl/interface_test.go
@@ -4,11 +4,13 @@
// `go mod` ignores file names for the purpose of resolving
// dependencies, and zxwait doesn't build on not-Fuchsia.
+//go:build fuchsia
// +build fuchsia
package fidl_test
import (
+ "context"
"math/rand"
"sync"
"syscall/zx"
@@ -162,7 +164,7 @@
defer wg.Done()
receiver := &fidl.ChannelProxy{Channel: c1}
- if err := zxwait.WithRetry(func() error {
+ if err := zxwait.WithRetryContext(context.Background(), func() error {
return receiver.Recv(ordinal, &message{})
}, *receiver.Channel.Handle(), zx.SignalChannelReadable, zx.SignalChannelPeerClosed); err != nil {
t.Error(err)
@@ -338,7 +340,7 @@
var header fidl.MessageHeader
var msg message
- if err := zxwait.WithRetry(func() error {
+ if err := zxwait.WithRetryContext(context.Background(), func() error {
nb, _, err := ch.Read(respb[:], nil, 0)
if err != nil {
return err
diff --git a/src/syscall/zx/fidl/zx_fuchsia.go b/src/syscall/zx/fidl/zx_fuchsia.go
index abff2cd..eb9d20c 100644
--- a/src/syscall/zx/fidl/zx_fuchsia.go
+++ b/src/syscall/zx/fidl/zx_fuchsia.go
@@ -4,6 +4,7 @@
// `go mod` ignores file names for the purpose of resolving
// dependencies, and zxwait doesn't build on not-Fuchsia.
+//go:build fuchsia
// +build fuchsia
package fidl
@@ -21,4 +22,4 @@
var channelReadEtc = (*zx.Channel).ReadEtc
var channelWriteEtc = (*zx.Channel).WriteEtc
-var withRetry = zxwait.WithRetry
+var withRetryContext = zxwait.WithRetryContext
diff --git a/src/syscall/zx/fidl/zx_notfuchsia.go b/src/syscall/zx/fidl/zx_notfuchsia.go
index 0ded79f..ae85ffc 100644
--- a/src/syscall/zx/fidl/zx_notfuchsia.go
+++ b/src/syscall/zx/fidl/zx_notfuchsia.go
@@ -2,11 +2,15 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+//go:build !fuchsia
// +build !fuchsia
package fidl
-import "syscall/zx"
+import (
+ "syscall/zx"
+ "syscall/zx/internal/context"
+)
// Return nil rather than panicking so that tests can use fake handles on non-fuchsia.
// Tests on non-fuchsia end up calling this because unmarshalers for strict types close
@@ -37,6 +41,6 @@
panic("channel write etc only supported on fuchsia")
}
-func withRetry(fn func() error, _ zx.Handle, _, _ zx.Signals) error {
+func withRetryContext(_ context.Context, fn func() error, _ zx.Handle, _, _ zx.Signals) error {
return fn()
}
diff --git a/src/syscall/zx/zxwait/zxwait.go b/src/syscall/zx/zxwait/zxwait.go
index 537bf5f..030da11 100644
--- a/src/syscall/zx/zxwait/zxwait.go
+++ b/src/syscall/zx/zxwait/zxwait.go
@@ -4,6 +4,7 @@
// Go's distribution tools attempt to compile everything; this file
// depends on types that don't compile in not-Fuchsia.
+//go:build fuchsia
// +build fuchsia
// Package zxwait implements a Zircon port waiter compatible with goroutines.
@@ -17,6 +18,7 @@
"sync"
"sync/atomic"
"syscall/zx"
+ "syscall/zx/internal/context"
_ "unsafe" // for go:linkname
)
@@ -37,19 +39,28 @@
}
return observed, nil
}
+ return WaitContext(context.Background(), handle, signals)
+}
+
+// WaitContext waits for signals on handle.
+func WaitContext(ctx context.Context, handle zx.Handle, signals zx.Signals) (zx.Signals, error) {
sysWaiterOnce.Do(sysWaiterInit)
- return sysWaiter.Wait(handle, signals)
+ return sysWaiter.Wait(ctx, handle, signals)
}
func WithRetry(fn func() error, handle zx.Handle, ready, closed zx.Signals) error {
+ return WithRetryContext(context.Background(), fn, handle, ready, closed)
+}
+
+func WithRetryContext(ctx context.Context, fn func() error, handle zx.Handle, ready, closed zx.Signals) error {
signals := ready | closed
for {
err := fn()
if err, ok := err.(*zx.Error); ok && err.Status == zx.ErrShouldWait {
- obs, err := Wait(
+ obs, err := WaitContext(
+ ctx,
handle,
signals,
- zx.TimensecInfinite,
)
if err != nil {
return err
@@ -85,7 +96,8 @@
g uintptr
- obs zx.Signals
+ done chan<- struct{}
+ obs zx.Signals
}
// A waiter is a zircon port that parks goroutines waiting on signals.
@@ -141,10 +153,14 @@
waiting.obs = pkt.Signal().Observed
}
- switch g := atomic.SwapUintptr(&waiting.g, 0); g {
- case 0, preparingG:
- default:
- goready(g, 0)
+ if done := waiting.done; done != nil {
+ close(done)
+ } else {
+ switch g := atomic.SwapUintptr(&waiting.g, 0); g {
+ case 0, preparingG:
+ default:
+ goready(g, 0)
+ }
}
}
}
@@ -155,10 +171,12 @@
for waiting := range w.mu.allByHandle[handle] {
switch status := zx.Sys_port_cancel(zx.Handle(w.port), handle, waiting.key); status {
case zx.ErrOk:
- if err := w.port.Queue(&zx.Packet{Hdr: zx.PacketHeader{
- Key: waiting.key,
- Type: zx.PortPacketTypeUser,
- }}); err != nil {
+ if err := w.port.Queue(&zx.Packet{
+ Hdr: zx.PacketHeader{
+ Key: waiting.key,
+ Type: zx.PortPacketTypeUser,
+ },
+ }); err != nil {
return err
}
case zx.ErrNotFound:
@@ -173,7 +191,7 @@
// Wait waits for signals on handle.
//
// See the package function Wait for more commentary.
-func (w *waiter) Wait(handle zx.Handle, signals zx.Signals) (zx.Signals, error) {
+func (w *waiter) Wait(ctx context.Context, handle zx.Handle, signals zx.Signals) (zx.Signals, error) {
var waiting *waitingG
w.mu.Lock()
@@ -192,6 +210,32 @@
m = make(map[*waitingG]struct{})
w.mu.allByHandle[handle] = m
}
+ var wait func() (zx.Signals, error)
+ if done := ctx.Done(); done != nil {
+ ch := make(chan struct{})
+ waiting.done = ch
+ wait = func() (zx.Signals, error) {
+ select {
+ case <-ch:
+ // Wait complete.
+ return waiting.obs, nil
+ case <-done:
+ // Context canceled.
+ switch status := zx.Sys_port_cancel(zx.Handle(w.port), handle, waiting.key); status {
+ case zx.ErrOk:
+ return 0, ctx.Err()
+ case zx.ErrNotFound:
+ // We lost the race.
+ <-ch
+ return waiting.obs, nil
+ default:
+ return 0, &zx.Error{Status: status, Text: "zx.Port.Cancel"}
+ }
+ }
+ }
+ } else {
+ waiting.done = nil
+ }
m[waiting] = struct{}{}
// waiting must be fully initialized before a wakeup in dequeue is possible -
// after the call to wait_async and when the mutex is not held.
@@ -212,6 +256,10 @@
return 0, err
}
+ if wait != nil {
+ return wait()
+ }
+
const waitReasonIOWait = 2
const traceEvGoBlockSelect = 24
gopark(w.unlockf, waiting, waitReasonIOWait, traceEvGoBlockSelect, 0)
diff --git a/src/syscall/zx/zxwait/zxwait_test.go b/src/syscall/zx/zxwait/zxwait_test.go
index a0bbd33..c4c418b 100644
--- a/src/syscall/zx/zxwait/zxwait_test.go
+++ b/src/syscall/zx/zxwait/zxwait_test.go
@@ -4,11 +4,13 @@
// `go mod` ignores file names for the purpose of resolving
// dependencies, and zxwait doesn't build on not-Fuchsia.
+//go:build fuchsia
// +build fuchsia
package zxwait_test
import (
+ "context"
"fmt"
"runtime"
"sync"
@@ -17,7 +19,7 @@
"testing"
)
-func TestWaitPreexisting(t *testing.T) {
+func TestWaitContextPreexisting(t *testing.T) {
c0, c1, err := zx.NewChannel(0)
if err != nil {
t.Fatal(err)
@@ -28,7 +30,7 @@
}()
for _, ch := range [...]zx.Channel{c0, c1} {
- obs, err := zxwait.Wait(*ch.Handle(), zx.SignalChannelWritable, zx.TimensecInfinite)
+ obs, err := zxwait.WaitContext(context.Background(), *ch.Handle(), zx.SignalChannelWritable)
if err != nil {
t.Fatal(err)
}
@@ -43,7 +45,7 @@
err error
}
-func TestWait(t *testing.T) {
+func TestWaitContext(t *testing.T) {
var pairs [][2]zx.Channel
defer func() {
for _, pair := range pairs {
@@ -62,7 +64,7 @@
pair := [...]zx.Channel{ch1, ch2}
pairs = append(pairs, pair)
go func() {
- obs, err := zxwait.Wait(*pair[0].Handle(), zx.SignalChannelReadable, zx.TimensecInfinite)
+ obs, err := zxwait.WaitContext(context.Background(), *pair[0].Handle(), zx.SignalChannelReadable)
ch <- waitResult{obs: obs, err: err}
}()
}
@@ -85,7 +87,50 @@
}
}
-func TestWait_LocalClose(t *testing.T) {
+func TestWaitContext_Cancel(t *testing.T) {
+ event, err := zx.NewEvent(0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ if err := event.Close(); err != nil {
+ t.Fatal(err)
+ }
+ }()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ ch := make(chan waitResult, 100)
+ var wg sync.WaitGroup
+ for i := 0; i < cap(ch)-1; i++ {
+ wg.Add(1)
+ go func() {
+ wg.Done()
+ obs, err := zxwait.WaitContext(ctx, *event.Handle(), 0)
+ ch <- waitResult{obs: obs, err: err}
+ }()
+ }
+ // Wait for all goroutines to be scheduled.
+ wg.Wait()
+ cancel()
+ // Guarantee at least one result happens after cancel.
+ {
+ obs, err := zxwait.WaitContext(ctx, *event.Handle(), 0)
+ ch <- waitResult{obs: obs, err: err}
+ }
+ for i := 0; i < cap(ch); i++ {
+ waitResult := <-ch
+ if got, want := waitResult.obs, zx.Signals(0); got != want {
+ t.Errorf("%d: got obs = %b, want = %b", i, got, want)
+ }
+ if got, want := waitResult.err, context.Canceled; got != want {
+ t.Errorf("%d: got zxwait.WaitContext(<closed handle>) = (_, %s), want = (_, %s)", i, got, want)
+ }
+ }
+}
+
+func TestWaitContext_LocalClose(t *testing.T) {
event, err := zx.NewEvent(0)
if err != nil {
t.Fatal(err)
@@ -99,7 +144,7 @@
wg.Add(1)
go func() {
wg.Done()
- obs, err := zxwait.Wait(*event.Handle(), 0, zx.TimensecInfinite)
+ obs, err := zxwait.WaitContext(context.Background(), *event.Handle(), 0)
ch <- waitResult{obs: obs, err: err}
}()
}
@@ -110,7 +155,7 @@
}
// Guarantee at least one result happens after local close.
{
- obs, err := zxwait.Wait(*event.Handle(), 0, zx.TimensecInfinite)
+ obs, err := zxwait.WaitContext(context.Background(), *event.Handle(), 0)
ch <- waitResult{obs: obs, err: err}
}
var badHandle, cancelled int
@@ -136,7 +181,7 @@
continue
}
}
- t.Errorf("%d: got zxwait.Wait(<closed handle>) = (_, %s), want = (_, %s or %s)", i, err, zx.ErrBadHandle, zx.ErrCanceled)
+ t.Errorf("%d: got zxwait.WaitContext(<closed handle>) = (_, %s), want = (_, %s or %s)", i, err, zx.ErrBadHandle, zx.ErrCanceled)
}
if badHandle == 0 {
t.Errorf("failed to observe post-close condition")
@@ -146,7 +191,7 @@
}
}
-func TestWait_LocalCloseRace(t *testing.T) {
+func TestWaitContext_LocalCloseRace(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
@@ -173,13 +218,13 @@
// which resets its receiver.
go func(event zx.Handle) {
ch <- func() error {
- obs, err := zxwait.Wait(event, zx.SignalUser0, zx.TimensecInfinite)
+ obs, err := zxwait.WaitContext(context.Background(), event, zx.SignalUser0)
if err != nil {
if err, ok := err.(*zx.Error); ok {
switch err.Status {
case zx.ErrCanceled:
if obs != zx.SignalHandleClosed {
- return fmt.Errorf("got zxwait.Wait(..., %b, ...) = %b", zx.SignalUser0, obs)
+ return fmt.Errorf("got zxwait.WaitContext(..., %b, ...) = %b", zx.SignalUser0, obs)
}
fallthrough
case zx.ErrBadHandle:
@@ -190,7 +235,7 @@
return fmt.Errorf("failed to zxwait: %w", err)
}
if obs != zx.SignalUser0 {
- return fmt.Errorf("got zxwait.Wait(..., %b, ...) = %b", zx.SignalUser0, obs)
+ return fmt.Errorf("got zxwait.WaitContext(..., %b, ...) = %b", zx.SignalUser0, obs)
}
return nil
}()
@@ -232,12 +277,12 @@
return fmt.Errorf("failed to signal event: %w", err)
}
- obs, err := zxwait.Wait(event, zx.SignalUser1, zx.TimensecInfinite)
+ obs, err := zxwait.WaitContext(context.Background(), event, zx.SignalUser1)
if err != nil {
return fmt.Errorf("failed to zxwait: %w", err)
}
if obs != zx.SignalUser1 {
- return fmt.Errorf("got zxwait.Wait(..., %b, ...) = %b", zx.SignalUser1, obs)
+ return fmt.Errorf("got zxwait.WaitContext(..., %b, ...) = %b", zx.SignalUser1, obs)
}
if err := event.Close(); err != nil {