[cas] Implement per-request timeouts in a stream (#325)
* [cas] Implement per-request timeouts in a stream
grpc package does not have native support for per-request timeouts for
streaming RPC. Add a helper function to accommodate that.
The func is not specific to CAS, but it wasn't obvious which package
is a better place for it, so place it in `cas/client.go` for now.
diff --git a/go/pkg/cas/client.go b/go/pkg/cas/client.go
index c85362f..309ad7b 100644
--- a/go/pkg/cas/client.go
+++ b/go/pkg/cas/client.go
@@ -320,3 +320,21 @@
return nil
}
+
+// withPerCallTimeout returns a function wrapper that cancels the context if
+// fn does not return within the timeout.
+func withPerCallTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc, func(fn func())) {
+ ctx, cancel := context.WithCancel(ctx)
+ return ctx, cancel, func(fn func()) {
+ stop := make(chan struct{})
+ defer close(stop)
+ go func() {
+ select {
+ case <-time.After(timeout):
+ cancel()
+ case <-stop:
+ }
+ }()
+ fn()
+ }
+}
diff --git a/go/pkg/cas/client_test.go b/go/pkg/cas/client_test.go
new file mode 100644
index 0000000..5ffae60
--- /dev/null
+++ b/go/pkg/cas/client_test.go
@@ -0,0 +1,35 @@
+package cas
+
+import (
+ "context"
+ "testing"
+ "time"
+)
+
+func TestPerCallTimeout(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ ctx, cancel, withTimeout := withPerCallTimeout(ctx, time.Millisecond)
+ defer cancel()
+
+ t.Run("succeeded", func(t *testing.T) {
+ withTimeout(func() {})
+ if ctx.Err() != nil {
+ t.Fatalf("want nil, got %s", ctx.Err())
+ }
+ })
+
+ t.Run("canceled", func(t *testing.T) {
+ withTimeout(func() {
+ select {
+ case <-ctx.Done():
+ case <-time.After(time.Second):
+ }
+ })
+
+ if ctx.Err() != context.Canceled {
+ t.Fatalf("want %s, got %s", context.Canceled, ctx.Err())
+ }
+ })
+}
diff --git a/go/pkg/cas/upload.go b/go/pkg/cas/upload.go
index fc75ba4..fb68e97 100644
--- a/go/pkg/cas/upload.go
+++ b/go/pkg/cas/upload.go
@@ -695,7 +695,6 @@
}
defer r.Close()
- // TODO(nodir): implement per-RPC timeouts. No nice way to do it.
rewind := false
return u.withRetries(ctx, func(ctx context.Context) error {
// TODO(nodir): add support for resumable uploads.
@@ -753,6 +752,9 @@
}
func (u *uploader) streamFromReader(ctx context.Context, r io.Reader, digest *repb.Digest, compressed, updateCacheStats bool) error {
+ ctx, cancel, withTimeout := withPerCallTimeout(ctx, u.ByteStreamWrite.Timeout)
+ defer cancel()
+
stream, err := u.byteStream.Write(ctx)
if err != nil {
return err
@@ -788,7 +790,10 @@
req.Data = buf[:n] // must limit by `:n` in ErrUnexpectedEOF case
// Send the chunk.
- switch err = stream.Send(req); {
+ withTimeout(func() {
+ err = stream.Send(req)
+ })
+ switch {
case err == io.EOF:
// The server closed the stream.
// Most likely the file is already uploaded, see the CommittedSize check below.