[cas] Implement per-request timeouts in a stream (#325)

* [cas] Implement per-request timeouts in a stream

The grpc package does not natively support per-request timeouts for
streaming RPCs. Add a helper function to accommodate that.

The function is not specific to CAS, but it wasn't obvious which package
would be a better home for it, so place it in `cas/client.go` for now.
diff --git a/go/pkg/cas/client.go b/go/pkg/cas/client.go
index c85362f..309ad7b 100644
--- a/go/pkg/cas/client.go
+++ b/go/pkg/cas/client.go
@@ -320,3 +320,21 @@
 
 	return nil
 }
+
+// withPerCallTimeout returns a function wrapper that cancels the context if
+// fn does not return within the timeout.
+func withPerCallTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc, func(fn func())) {
+	ctx, cancel := context.WithCancel(ctx)
+	return ctx, cancel, func(fn func()) {
+		stop := make(chan struct{})
+		defer close(stop)
+		go func() {
+			select {
+			case <-time.After(timeout):
+				cancel()
+			case <-stop:
+			}
+		}()
+		fn()
+	}
+}
diff --git a/go/pkg/cas/client_test.go b/go/pkg/cas/client_test.go
new file mode 100644
index 0000000..5ffae60
--- /dev/null
+++ b/go/pkg/cas/client_test.go
@@ -0,0 +1,35 @@
+package cas
+
+import (
+	"context"
+	"testing"
+	"time"
+)
+
+func TestPerCallTimeout(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+	ctx, cancel, withTimeout := withPerCallTimeout(ctx, time.Millisecond)
+	defer cancel()
+
+	t.Run("succeeded", func(t *testing.T) {
+		withTimeout(func() {})
+		if ctx.Err() != nil {
+			t.Fatalf("want nil, got %s", ctx.Err())
+		}
+	})
+
+	t.Run("canceled", func(t *testing.T) {
+		withTimeout(func() {
+			select {
+			case <-ctx.Done():
+			case <-time.After(time.Second):
+			}
+		})
+
+		if ctx.Err() != context.Canceled {
+			t.Fatalf("want %s, got %s", context.Canceled, ctx.Err())
+		}
+	})
+}
diff --git a/go/pkg/cas/upload.go b/go/pkg/cas/upload.go
index fc75ba4..fb68e97 100644
--- a/go/pkg/cas/upload.go
+++ b/go/pkg/cas/upload.go
@@ -695,7 +695,6 @@
 	}
 	defer r.Close()
 
-	// TODO(nodir): implement per-RPC timeouts. No nice way to do it.
 	rewind := false
 	return u.withRetries(ctx, func(ctx context.Context) error {
 		// TODO(nodir): add support for resumable uploads.
@@ -753,6 +752,9 @@
 }
 
 func (u *uploader) streamFromReader(ctx context.Context, r io.Reader, digest *repb.Digest, compressed, updateCacheStats bool) error {
+	ctx, cancel, withTimeout := withPerCallTimeout(ctx, u.ByteStreamWrite.Timeout)
+	defer cancel()
+
 	stream, err := u.byteStream.Write(ctx)
 	if err != nil {
 		return err
@@ -788,7 +790,10 @@
 		req.Data = buf[:n] // must limit by `:n` in ErrUnexpectedEOF case
 
 		// Send the chunk.
-		switch err = stream.Send(req); {
+		withTimeout(func() {
+			err = stream.Send(req)
+		})
+		switch {
 		case err == io.EOF:
 			// The server closed the stream.
 			// Most likely the file is already uploaded, see the CommittedSize check below.