[cas] Implement presence check (#304)

Implement a batched checker that checks whether a digest is present on the
server. Use https://pkg.go.dev/google.golang.org/api/support/bundler
to bundle digests together, while respecting bundle size limits.

Add google.golang.org/api dependency.
diff --git a/go.mod b/go.mod
index fe0cc96..33faef1 100644
--- a/go.mod
+++ b/go.mod
@@ -17,6 +17,7 @@
 	golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
 	golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208
 	golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f // indirect
+	google.golang.org/api v0.30.0
 	google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d
 	google.golang.org/grpc v1.31.0
 	google.golang.org/protobuf v1.25.0
diff --git a/go.sum b/go.sum
index 401b407..aa4667e 100644
--- a/go.sum
+++ b/go.sum
@@ -305,6 +305,7 @@
 google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
 google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
 google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM=
+google.golang.org/api v0.30.0 h1:yfrXXP61wVuLb0vBcG6qaOoIoqYEzOQS8jum51jkv2w=
 google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc=
 google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
 google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
diff --git a/go/pkg/cas/client.go b/go/pkg/cas/client.go
index 947e238..0a1d513 100644
--- a/go/pkg/cas/client.go
+++ b/go/pkg/cas/client.go
@@ -31,7 +31,8 @@
 
 	// Mockable functions.
 
-	testScheduleCheck func(ctx context.Context, item *uploadItem) error
+	testScheduleCheck  func(ctx context.Context, item *uploadItem) error
+	testScheduleUpload func(ctx context.Context, item *uploadItem) error
 }
 
 // ClientConfig is a config for Client.
@@ -54,6 +55,10 @@
 	// FileIOSize is the size of file reads.
 	FileIOSize int64
 
+	// FindMissingBlobsBatchSize is the maximum number of digests to check in a
+	// single FindMissingBlobs RPC.
+	FindMissingBlobsBatchSize int
+
 	// TODO(nodir): add per-RPC timeouts.
 	// TODO(nodir): add retries.
 }
@@ -77,6 +82,8 @@
 		// GCE docs recommend 4MB IO size for large files.
 		// https://cloud.google.com/compute/docs/disks/optimizing-pd-performance#io-size
 		FileIOSize: 4 * 1024 * 1024, // 4MiB
+
+		FindMissingBlobsBatchSize: 1000,
 	}
 }
 
@@ -96,6 +103,10 @@
 	case c.FileIOSize <= 0:
 		return fmt.Errorf("FileIOSize must be positive")
 
+	// Do not allow more than 10K, otherwise we might run into the request size limits.
+	case c.FindMissingBlobsBatchSize <= 0 || c.FindMissingBlobsBatchSize > 10000:
+		return fmt.Errorf("FindMissingBlobsBatchSize must be in [1, 10000]")
+
 	default:
 		return nil
 	}
diff --git a/go/pkg/cas/upload.go b/go/pkg/cas/upload.go
index bcadf67..1c2125d 100644
--- a/go/pkg/cas/upload.go
+++ b/go/pkg/cas/upload.go
@@ -9,11 +9,13 @@
 	"path/filepath"
 	"sort"
 	"sync"
+	"sync/atomic"
 
 	"github.com/golang/protobuf/proto"
 	"github.com/pkg/errors"
 	"golang.org/x/sync/errgroup"
 	"golang.org/x/sync/semaphore"
+	"google.golang.org/api/support/bundler"
 
 	"github.com/bazelbuild/remote-apis-sdks/go/pkg/cache/singleflightcache"
 	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
@@ -68,8 +70,30 @@
 		fsSem: semaphore.NewWeighted(int64(c.FSConcurrency)),
 	}
 
+	// Initialize checkBundler, which checks if a blob is present on the server.
+	u.checkBundler = bundler.NewBundler(&uploadItem{}, func(items interface{}) {
+		// Handle errors and context cancelation via errgroup.
+		eg.Go(func() error {
+			return u.check(ctx, items.([]*uploadItem))
+		})
+	})
+	// Given that all digests are small (no more than 40 bytes), the count limit
+	// is the bottleneck.
+	// We might run into the request size limits only if we have >100K digests.
+	u.checkBundler.BundleCountThreshold = u.FindMissingBlobsBatchSize
+
 	// Start processing input.
 	eg.Go(func() error {
+		// Before exiting this main goroutine, ensure all the work has been completed.
+		// Just waiting for u.eg isn't enough because some work may be temporarily
+		// in a bundler.
+		defer func() {
+			u.wgFS.Wait()
+
+			// checkBundler can be flushed only after FS walk is done.
+			u.checkBundler.Flush()
+		}()
+
 		for {
 			select {
 			case <-ctx.Done():
@@ -78,9 +102,9 @@
 				if !ok {
 					return nil
 				}
-				eg.Go(func() error {
-					return errors.Wrapf(u.startProcessing(ctx, in), "%q", in.Path)
-				})
+				if err := u.startProcessing(ctx, in); err != nil {
+					return err
+				}
 			}
 		}
 	})
@@ -101,6 +125,8 @@
 	eg    *errgroup.Group
 	stats TransferStats
 
+	// wgFS is used to wait for all FS walking to finish.
+	wgFS sync.WaitGroup
 	// fsCache contains already-processed files.
 	fsCache singleflightcache.Cache
 
@@ -110,6 +136,11 @@
 	// muLargeFile ensures only one large file is read at a time.
 	// TODO(nodir): ensure this doesn't hurt performance on SSDs.
 	muLargeFile sync.Mutex
+
+	// checkBundler bundles digests that need to be checked for presence on the
+	// server.
+	checkBundler *bundler.Bundler
+	seenDigests  sync.Map // TODO: consider making it more global
 }
 
 // startProcessing adds the item to the appropriate stage depending on its type.
@@ -119,20 +150,26 @@
 		return u.scheduleCheck(ctx, uploadItemFromBlob("", in.Content))
 	}
 
-	// Compute the absolute path only once per directory tree.
-	absPath, err := filepath.Abs(in.Path)
-	if err != nil {
-		return errors.Wrapf(err, "failed to get absolute path")
-	}
+	// Schedule a file system walk.
+	u.wgFS.Add(1)
+	u.eg.Go(func() error {
+		defer u.wgFS.Done()
+		// Compute the absolute path only once per directory tree.
+		absPath, err := filepath.Abs(in.Path)
+		if err != nil {
+			return errors.Wrapf(err, "failed to get absolute path")
+		}
 
-	// Do not use os.Stat() here. We want to know if it is a symlink.
-	info, err := os.Lstat(absPath)
-	if err != nil {
-		return errors.Wrapf(err, "lstat failed")
-	}
+		// Do not use os.Stat() here. We want to know if it is a symlink.
+		info, err := os.Lstat(absPath)
+		if err != nil {
+			return errors.Wrapf(err, "lstat failed")
+		}
 
-	_, err = u.visitFile(ctx, absPath, info)
-	return err
+		_, err = u.visitFile(ctx, absPath, info)
+		return err
+	})
+	return nil
 }
 
 // visitFile visits the file/dir depending on its type (regular, dir, symlink).
@@ -255,7 +292,7 @@
 	var mu sync.Mutex
 	dir := &repb.Directory{}
 	var subErr error
-	var wg sync.WaitGroup
+	var wgChildren sync.WaitGroup
 
 	// This sub-function exist to avoid holding the semaphore while waiting for
 	// children.
@@ -284,9 +321,11 @@
 			for _, info := range infos {
 				info := info
 				absChild := joinFilePathsFast(absPath, info.Name())
-				wg.Add(1)
+				wgChildren.Add(1)
+				u.wgFS.Add(1)
 				u.eg.Go(func() error {
-					defer wg.Done()
+					defer wgChildren.Done()
+					defer u.wgFS.Done()
 					node, err := u.visitFile(ctx, absChild, info)
 					mu.Lock()
 					defer mu.Unlock()
@@ -317,8 +356,7 @@
 		return nil, err
 	}
 
-	// Wait for children.
-	wg.Wait()
+	wgChildren.Wait()
 	if subErr != nil {
 		return nil, errors.Wrapf(subErr, "failed to read the directory %q entirely", absPath)
 	}
@@ -354,6 +392,56 @@
 		return u.testScheduleCheck(ctx, item)
 	}
 
+	// Do not check the same digest twice.
+	cacheKey := digest.NewFromProtoUnvalidated(item.Digest)
+	if _, ok := u.seenDigests.LoadOrStore(cacheKey, struct{}{}); ok {
+		return nil
+	}
+	return u.checkBundler.AddWait(ctx, item, 0)
+}
+
+// check checks which items are present on the server, and schedules upload for
+// the missing ones.
+func (u *uploader) check(ctx context.Context, items []*uploadItem) error {
+	req := &repb.FindMissingBlobsRequest{
+		InstanceName: u.InstanceName,
+		BlobDigests:  make([]*repb.Digest, len(items)),
+	}
+	byDigest := make(map[digest.Digest]*uploadItem, len(items))
+	totalBytes := int64(0)
+	for i, item := range items {
+		req.BlobDigests[i] = item.Digest
+		byDigest[digest.NewFromProtoUnvalidated(item.Digest)] = item
+		totalBytes += item.Digest.SizeBytes
+	}
+
+	// TODO(nodir): add retries.
+	// TODO(nodir): add per-RPC timeouts.
+	res, err := u.cas.FindMissingBlobs(ctx, req)
+	if err != nil {
+		return err
+	}
+
+	missingBytes := int64(0)
+	for _, d := range res.MissingBlobDigests {
+		missingBytes += d.SizeBytes
+		item := byDigest[digest.NewFromProtoUnvalidated(d)]
+		if err := u.scheduleUpload(ctx, item); err != nil {
+			return err
+		}
+	}
+	atomic.AddInt64(&u.stats.CacheMisses.Digests, int64(len(res.MissingBlobDigests)))
+	atomic.AddInt64(&u.stats.CacheMisses.Bytes, missingBytes)
+	atomic.AddInt64(&u.stats.CacheHits.Digests, int64(len(items)-len(res.MissingBlobDigests)))
+	atomic.AddInt64(&u.stats.CacheHits.Bytes, totalBytes-missingBytes)
+	return nil
+}
+
+func (u *uploader) scheduleUpload(ctx context.Context, item *uploadItem) error {
+	if u.testScheduleUpload != nil {
+		return u.testScheduleUpload(ctx, item)
+	}
+
 	// TODO(nodir): implement.
 	panic("not implemented")
 }
diff --git a/go/pkg/cas/upload_test.go b/go/pkg/cas/upload_test.go
index cc70ed3..aeeb0e6 100644
--- a/go/pkg/cas/upload_test.go
+++ b/go/pkg/cas/upload_test.go
@@ -8,6 +8,8 @@
 	"sync"
 	"testing"
 
+	"google.golang.org/grpc"
+
 	repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
 	"github.com/golang/protobuf/proto"
 	"github.com/google/go-cmp/cmp"
@@ -103,6 +105,76 @@
 	}
 }
 
+func TestChecks(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	var mu sync.Mutex
+	var gotDigestChecks []*repb.Digest
+	var gotRequestSizes []int
+	var gotScheduleUploadCalls []*uploadItem
+	cas := &fakeCAS{
+		findMissingBlobs: func(ctx context.Context, in *repb.FindMissingBlobsRequest, opts ...grpc.CallOption) (*repb.FindMissingBlobsResponse, error) {
+			mu.Lock()
+			defer mu.Unlock()
+			gotDigestChecks = append(gotDigestChecks, in.BlobDigests...)
+			gotRequestSizes = append(gotRequestSizes, len(in.BlobDigests))
+			return &repb.FindMissingBlobsResponse{MissingBlobDigests: in.BlobDigests[:1]}, nil
+		},
+	}
+	client := &Client{
+		InstanceName: "projects/p/instances/i",
+		ClientConfig: DefaultClientConfig(),
+		cas:          cas,
+		testScheduleUpload: func(ctx context.Context, item *uploadItem) error {
+			mu.Lock()
+			defer mu.Unlock()
+			gotScheduleUploadCalls = append(gotScheduleUploadCalls, item)
+			return nil
+		},
+	}
+	client.FindMissingBlobsBatchSize = 2
+
+	inputC := inputChanFrom(
+		&UploadInput{Content: []byte("a")},
+		&UploadInput{Content: []byte("b")},
+		&UploadInput{Content: []byte("c")},
+		&UploadInput{Content: []byte("d")},
+	)
+	if _, err := client.Upload(ctx, inputC); err != nil {
+		t.Fatalf("failed to upload: %s", err)
+	}
+
+	wantDigestChecks := []*repb.Digest{
+		{Hash: "18ac3e7343f016890c510e93f935261169d9e3f565436429830faf0934f4f8e4", SizeBytes: 1},
+		{Hash: "2e7d2c03a9507ae265ecf5b5356885a53393a2029d241394997265a1a25aefc6", SizeBytes: 1},
+		{Hash: "3e23e8160039594a33894f6564e1b1348bbd7a0088d42c4acb73eeaed59c009d", SizeBytes: 1},
+		{Hash: "ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb", SizeBytes: 1},
+	}
+	sort.Slice(gotDigestChecks, func(i, j int) bool {
+		return gotDigestChecks[i].Hash < gotDigestChecks[j].Hash
+	})
+	if diff := cmp.Diff(wantDigestChecks, gotDigestChecks); diff != "" {
+		t.Error(diff)
+	}
+	if diff := cmp.Diff([]int{2, 2}, gotRequestSizes); diff != "" {
+		t.Error(diff)
+	}
+
+	wantDigestUploads := []string{
+		"2e7d2c03a9507ae265ecf5b5356885a53393a2029d241394997265a1a25aefc6", // c
+		"ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb", // a
+	}
+	gotDigestUploads := make([]string, len(gotScheduleUploadCalls))
+	for i, req := range gotScheduleUploadCalls {
+		gotDigestUploads[i] = req.Digest.Hash
+	}
+	sort.Strings(gotDigestUploads)
+	if diff := cmp.Diff(wantDigestUploads, gotDigestUploads); diff != "" {
+		t.Error(diff)
+	}
+}
+
 func compareUploadItems(x, y *uploadItem) bool {
 	return x.Title == y.Title &&
 		proto.Equal(x.Digest, y.Digest) &&
@@ -129,3 +201,12 @@
 	close(inputC)
 	return inputC
 }
+
+type fakeCAS struct {
+	repb.ContentAddressableStorageClient
+	findMissingBlobs func(ctx context.Context, in *repb.FindMissingBlobsRequest, opts ...grpc.CallOption) (*repb.FindMissingBlobsResponse, error)
+}
+
+func (c *fakeCAS) FindMissingBlobs(ctx context.Context, in *repb.FindMissingBlobsRequest, opts ...grpc.CallOption) (*repb.FindMissingBlobsResponse, error) {
+	return c.findMissingBlobs(ctx, in, opts...)
+}