[zxwait] Implement context cancellation

Change-Id: I0a6c66affe8c365076c7a86da7e34f86bd6adbae
Reviewed-on: https://fuchsia-review.googlesource.com/c/third_party/go/+/554242
Reviewed-by: Bruno Dal Bo <brunodalbo@google.com>
Commit-Queue: Tamir Duberstein <tamird@google.com>
diff --git a/src/syscall/zx/fidl/interface.go b/src/syscall/zx/fidl/interface.go
index cda347b..3ba9f15 100644
--- a/src/syscall/zx/fidl/interface.go
+++ b/src/syscall/zx/fidl/interface.go
@@ -7,6 +7,7 @@
 import (
 	"sync"
 	"syscall/zx"
+	"syscall/zx/internal/context"
 )
 
 // ServiceRequest is an abstraction over a FIDL interface request which is
@@ -142,7 +143,7 @@
 	}
 
 	// Write the encoded bytes to the channel.
-	return withRetry(func() error {
+	return withRetryContext(context.Background(), func() error {
 		return channelWriteEtc(&p.Channel, respb[:nb], resphd[:nh], 0)
 	}, *p.Channel.Handle(), zx.SignalChannelWritable, zx.SignalChannelPeerClosed)
 }
@@ -177,7 +178,7 @@
 
 func (p *ChannelProxy) read(header *MessageHeader, poolBytes []byte, poolHandleInfos []zx.HandleInfo) ([]byte, []zx.HandleInfo, error) {
 	var nb, nh uint32
-	if err := withRetry(func() error {
+	if err := withRetryContext(context.Background(), func() error {
 		var err error
 		nb, nh, err = channelReadEtc(&p.Channel, poolBytes[:], poolHandleInfos[:], 0)
 		return err
diff --git a/src/syscall/zx/fidl/interface_test.go b/src/syscall/zx/fidl/interface_test.go
index c3f8418..439d181 100644
--- a/src/syscall/zx/fidl/interface_test.go
+++ b/src/syscall/zx/fidl/interface_test.go
@@ -4,11 +4,13 @@
 
 // `go mod` ignores file names for the purpose of resolving
 // dependencies, and zxwait doesn't build on not-Fuchsia.
+//go:build fuchsia
 // +build fuchsia
 
 package fidl_test
 
 import (
+	"context"
 	"math/rand"
 	"sync"
 	"syscall/zx"
@@ -162,7 +164,7 @@
 			defer wg.Done()
 
 			receiver := &fidl.ChannelProxy{Channel: c1}
-			if err := zxwait.WithRetry(func() error {
+			if err := zxwait.WithRetryContext(context.Background(), func() error {
 				return receiver.Recv(ordinal, &message{})
 			}, *receiver.Channel.Handle(), zx.SignalChannelReadable, zx.SignalChannelPeerClosed); err != nil {
 				t.Error(err)
@@ -338,7 +340,7 @@
 
 	var header fidl.MessageHeader
 	var msg message
-	if err := zxwait.WithRetry(func() error {
+	if err := zxwait.WithRetryContext(context.Background(), func() error {
 		nb, _, err := ch.Read(respb[:], nil, 0)
 		if err != nil {
 			return err
diff --git a/src/syscall/zx/fidl/zx_fuchsia.go b/src/syscall/zx/fidl/zx_fuchsia.go
index abff2cd..eb9d20c 100644
--- a/src/syscall/zx/fidl/zx_fuchsia.go
+++ b/src/syscall/zx/fidl/zx_fuchsia.go
@@ -4,6 +4,7 @@
 
 // `go mod` ignores file names for the purpose of resolving
 // dependencies, and zxwait doesn't build on not-Fuchsia.
+//go:build fuchsia
 // +build fuchsia
 
 package fidl
@@ -21,4 +22,4 @@
 var channelReadEtc = (*zx.Channel).ReadEtc
 var channelWriteEtc = (*zx.Channel).WriteEtc
 
-var withRetry = zxwait.WithRetry
+var withRetryContext = zxwait.WithRetryContext
diff --git a/src/syscall/zx/fidl/zx_notfuchsia.go b/src/syscall/zx/fidl/zx_notfuchsia.go
index 0ded79f..ae85ffc 100644
--- a/src/syscall/zx/fidl/zx_notfuchsia.go
+++ b/src/syscall/zx/fidl/zx_notfuchsia.go
@@ -2,11 +2,15 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
+//go:build !fuchsia
 // +build !fuchsia
 
 package fidl
 
-import "syscall/zx"
+import (
+	"syscall/zx"
+	"syscall/zx/internal/context"
+)
 
 // Return nil rather than panicking so that tests can use fake handles on non-fuchsia.
 // Tests on non-fuchsia end up calling this because unmarshalers for strict types close
@@ -37,6 +41,6 @@
 	panic("channel write etc only supported on fuchsia")
 }
 
-func withRetry(fn func() error, _ zx.Handle, _, _ zx.Signals) error {
+func withRetryContext(_ context.Context, fn func() error, _ zx.Handle, _, _ zx.Signals) error {
 	return fn()
 }
diff --git a/src/syscall/zx/zxwait/zxwait.go b/src/syscall/zx/zxwait/zxwait.go
index 537bf5f..030da11 100644
--- a/src/syscall/zx/zxwait/zxwait.go
+++ b/src/syscall/zx/zxwait/zxwait.go
@@ -4,6 +4,7 @@
 
 // Go's distribution tools attempt to compile everything; this file
 // depends on types that don't compile in not-Fuchsia.
+//go:build fuchsia
 // +build fuchsia
 
 // Package zxwait implements a Zircon port waiter compatible with goroutines.
@@ -17,6 +18,7 @@
 	"sync"
 	"sync/atomic"
 	"syscall/zx"
+	"syscall/zx/internal/context"
 	_ "unsafe" // for go:linkname
 )
 
@@ -37,19 +39,28 @@
 		}
 		return observed, nil
 	}
+	return WaitContext(context.Background(), handle, signals)
+}
+
+// WaitContext waits for signals on handle.
+func WaitContext(ctx context.Context, handle zx.Handle, signals zx.Signals) (zx.Signals, error) {
 	sysWaiterOnce.Do(sysWaiterInit)
-	return sysWaiter.Wait(handle, signals)
+	return sysWaiter.Wait(ctx, handle, signals)
 }
 
 func WithRetry(fn func() error, handle zx.Handle, ready, closed zx.Signals) error {
+	return WithRetryContext(context.Background(), fn, handle, ready, closed)
+}
+
+func WithRetryContext(ctx context.Context, fn func() error, handle zx.Handle, ready, closed zx.Signals) error {
 	signals := ready | closed
 	for {
 		err := fn()
 		if err, ok := err.(*zx.Error); ok && err.Status == zx.ErrShouldWait {
-			obs, err := Wait(
+			obs, err := WaitContext(
+				ctx,
 				handle,
 				signals,
-				zx.TimensecInfinite,
 			)
 			if err != nil {
 				return err
@@ -85,7 +96,8 @@
 
 	g uintptr
 
-	obs zx.Signals
+	done chan<- struct{}
+	obs  zx.Signals
 }
 
 // A waiter is a zircon port that parks goroutines waiting on signals.
@@ -141,10 +153,14 @@
 			waiting.obs = pkt.Signal().Observed
 		}
 
-		switch g := atomic.SwapUintptr(&waiting.g, 0); g {
-		case 0, preparingG:
-		default:
-			goready(g, 0)
+		if done := waiting.done; done != nil {
+			close(done)
+		} else {
+			switch g := atomic.SwapUintptr(&waiting.g, 0); g {
+			case 0, preparingG:
+			default:
+				goready(g, 0)
+			}
 		}
 	}
 }
@@ -155,10 +171,12 @@
 	for waiting := range w.mu.allByHandle[handle] {
 		switch status := zx.Sys_port_cancel(zx.Handle(w.port), handle, waiting.key); status {
 		case zx.ErrOk:
-			if err := w.port.Queue(&zx.Packet{Hdr: zx.PacketHeader{
-				Key:  waiting.key,
-				Type: zx.PortPacketTypeUser,
-			}}); err != nil {
+			if err := w.port.Queue(&zx.Packet{
+				Hdr: zx.PacketHeader{
+					Key:  waiting.key,
+					Type: zx.PortPacketTypeUser,
+				},
+			}); err != nil {
 				return err
 			}
 		case zx.ErrNotFound:
@@ -173,7 +191,7 @@
 // Wait waits for signals on handle.
 //
 // See the package function Wait for more commentary.
-func (w *waiter) Wait(handle zx.Handle, signals zx.Signals) (zx.Signals, error) {
+func (w *waiter) Wait(ctx context.Context, handle zx.Handle, signals zx.Signals) (zx.Signals, error) {
 	var waiting *waitingG
 
 	w.mu.Lock()
@@ -192,6 +210,32 @@
 		m = make(map[*waitingG]struct{})
 		w.mu.allByHandle[handle] = m
 	}
+	var wait func() (zx.Signals, error)
+	if done := ctx.Done(); done != nil {
+		ch := make(chan struct{})
+		waiting.done = ch
+		wait = func() (zx.Signals, error) {
+			select {
+			case <-ch:
+				// Wait complete.
+				return waiting.obs, nil
+			case <-done:
+				// Context canceled.
+				switch status := zx.Sys_port_cancel(zx.Handle(w.port), handle, waiting.key); status {
+				case zx.ErrOk:
+					return 0, ctx.Err()
+				case zx.ErrNotFound:
+					// We lost the race.
+					<-ch
+					return waiting.obs, nil
+				default:
+					return 0, &zx.Error{Status: status, Text: "zx.Port.Cancel"}
+				}
+			}
+		}
+	} else {
+		waiting.done = nil
+	}
 	m[waiting] = struct{}{}
 	// waiting must be fully initialized before a wakeup in dequeue is possible -
 	// after the call to wait_async and when the mutex is not held.
@@ -212,6 +256,10 @@
 		return 0, err
 	}
 
+	if wait != nil {
+		return wait()
+	}
+
 	const waitReasonIOWait = 2
 	const traceEvGoBlockSelect = 24
 	gopark(w.unlockf, waiting, waitReasonIOWait, traceEvGoBlockSelect, 0)
diff --git a/src/syscall/zx/zxwait/zxwait_test.go b/src/syscall/zx/zxwait/zxwait_test.go
index a0bbd33..c4c418b 100644
--- a/src/syscall/zx/zxwait/zxwait_test.go
+++ b/src/syscall/zx/zxwait/zxwait_test.go
@@ -4,11 +4,13 @@
 
 // `go mod` ignores file names for the purpose of resolving
 // dependencies, and zxwait doesn't build on not-Fuchsia.
+//go:build fuchsia
 // +build fuchsia
 
 package zxwait_test
 
 import (
+	"context"
 	"fmt"
 	"runtime"
 	"sync"
@@ -17,7 +19,7 @@
 	"testing"
 )
 
-func TestWaitPreexisting(t *testing.T) {
+func TestWaitContextPreexisting(t *testing.T) {
 	c0, c1, err := zx.NewChannel(0)
 	if err != nil {
 		t.Fatal(err)
@@ -28,7 +30,7 @@
 	}()
 
 	for _, ch := range [...]zx.Channel{c0, c1} {
-		obs, err := zxwait.Wait(*ch.Handle(), zx.SignalChannelWritable, zx.TimensecInfinite)
+		obs, err := zxwait.WaitContext(context.Background(), *ch.Handle(), zx.SignalChannelWritable)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -43,7 +45,7 @@
 	err error
 }
 
-func TestWait(t *testing.T) {
+func TestWaitContext(t *testing.T) {
 	var pairs [][2]zx.Channel
 	defer func() {
 		for _, pair := range pairs {
@@ -62,7 +64,7 @@
 		pair := [...]zx.Channel{ch1, ch2}
 		pairs = append(pairs, pair)
 		go func() {
-			obs, err := zxwait.Wait(*pair[0].Handle(), zx.SignalChannelReadable, zx.TimensecInfinite)
+			obs, err := zxwait.WaitContext(context.Background(), *pair[0].Handle(), zx.SignalChannelReadable)
 			ch <- waitResult{obs: obs, err: err}
 		}()
 	}
@@ -85,7 +87,50 @@
 	}
 }
 
-func TestWait_LocalClose(t *testing.T) {
+func TestWaitContext_Cancel(t *testing.T) {
+	event, err := zx.NewEvent(0)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		if err := event.Close(); err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ch := make(chan waitResult, 100)
+	var wg sync.WaitGroup
+	for i := 0; i < cap(ch)-1; i++ {
+		wg.Add(1)
+		go func() {
+			wg.Done()
+			obs, err := zxwait.WaitContext(ctx, *event.Handle(), 0)
+			ch <- waitResult{obs: obs, err: err}
+		}()
+	}
+	// Wait for all goroutines to be scheduled.
+	wg.Wait()
+	cancel()
+	// Guarantee at least one result happens after cancel.
+	{
+		obs, err := zxwait.WaitContext(ctx, *event.Handle(), 0)
+		ch <- waitResult{obs: obs, err: err}
+	}
+	for i := 0; i < cap(ch); i++ {
+		waitResult := <-ch
+		if got, want := waitResult.obs, zx.Signals(0); got != want {
+			t.Errorf("%d: got obs = %b, want = %b", i, got, want)
+		}
+		if got, want := waitResult.err, context.Canceled; got != want {
+			t.Errorf("%d: got zxwait.WaitContext(<closed handle>) = (_, %s), want = (_, %s)", i, got, want)
+		}
+	}
+}
+
+func TestWaitContext_LocalClose(t *testing.T) {
 	event, err := zx.NewEvent(0)
 	if err != nil {
 		t.Fatal(err)
@@ -99,7 +144,7 @@
 		wg.Add(1)
 		go func() {
 			wg.Done()
-			obs, err := zxwait.Wait(*event.Handle(), 0, zx.TimensecInfinite)
+			obs, err := zxwait.WaitContext(context.Background(), *event.Handle(), 0)
 			ch <- waitResult{obs: obs, err: err}
 		}()
 	}
@@ -110,7 +155,7 @@
 	}
 	// Guarantee at least one result happens after local close.
 	{
-		obs, err := zxwait.Wait(*event.Handle(), 0, zx.TimensecInfinite)
+		obs, err := zxwait.WaitContext(context.Background(), *event.Handle(), 0)
 		ch <- waitResult{obs: obs, err: err}
 	}
 	var badHandle, cancelled int
@@ -136,7 +181,7 @@
 				continue
 			}
 		}
-		t.Errorf("%d: got zxwait.Wait(<closed handle>) = (_, %s), want = (_, %s or %s)", i, err, zx.ErrBadHandle, zx.ErrCanceled)
+		t.Errorf("%d: got zxwait.WaitContext(<closed handle>) = (_, %s), want = (_, %s or %s)", i, err, zx.ErrBadHandle, zx.ErrCanceled)
 	}
 	if badHandle == 0 {
 		t.Errorf("failed to observe post-close condition")
@@ -146,7 +191,7 @@
 	}
 }
 
-func TestWait_LocalCloseRace(t *testing.T) {
+func TestWaitContext_LocalCloseRace(t *testing.T) {
 	var wg sync.WaitGroup
 
 	for i := 0; i < 100; i++ {
@@ -173,13 +218,13 @@
 				// which resets its receiver.
 				go func(event zx.Handle) {
 					ch <- func() error {
-						obs, err := zxwait.Wait(event, zx.SignalUser0, zx.TimensecInfinite)
+						obs, err := zxwait.WaitContext(context.Background(), event, zx.SignalUser0)
 						if err != nil {
 							if err, ok := err.(*zx.Error); ok {
 								switch err.Status {
 								case zx.ErrCanceled:
 									if obs != zx.SignalHandleClosed {
-										return fmt.Errorf("got zxwait.Wait(..., %b, ...) = %b", zx.SignalUser0, obs)
+										return fmt.Errorf("got zxwait.WaitContext(..., %b, ...) = %b", zx.SignalUser0, obs)
 									}
 									fallthrough
 								case zx.ErrBadHandle:
@@ -190,7 +235,7 @@
 							return fmt.Errorf("failed to zxwait: %w", err)
 						}
 						if obs != zx.SignalUser0 {
-							return fmt.Errorf("got zxwait.Wait(..., %b, ...) = %b", zx.SignalUser0, obs)
+							return fmt.Errorf("got zxwait.WaitContext(..., %b, ...) = %b", zx.SignalUser0, obs)
 						}
 						return nil
 					}()
@@ -232,12 +277,12 @@
 					return fmt.Errorf("failed to signal event: %w", err)
 				}
 
-				obs, err := zxwait.Wait(event, zx.SignalUser1, zx.TimensecInfinite)
+				obs, err := zxwait.WaitContext(context.Background(), event, zx.SignalUser1)
 				if err != nil {
 					return fmt.Errorf("failed to zxwait: %w", err)
 				}
 				if obs != zx.SignalUser1 {
-					return fmt.Errorf("got zxwait.Wait(..., %b, ...) = %b", zx.SignalUser1, obs)
+					return fmt.Errorf("got zxwait.WaitContext(..., %b, ...) = %b", zx.SignalUser1, obs)
 				}
 
 				if err := event.Close(); err != nil {