Add compressed reads to cas.go. (#237)
It follows the spec proposed in bazelbuild/remote-apis#168, and it is
similar to #232. Note that while the API still has some room to change,
it is mostly finalized and worth implementing.
A caveat of this implementation is that while the `offset` in reads
refers to the uncompressed bytes, the `limit` refers to the compressed
bytes.
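For illustration only, here is a rough sketch (not part of this change) of how a caller might exercise the compressed read path. It assumes an existing `*client.Client` and a `digest.Digest`; `CompressedBytestreamThreshold` and `ReadBlobStreamed` are the names touched by this diff, the `fetchBlob` helper is hypothetical, and the resource name layout (`<instance>/compressed-blobs/zstd/<hash>/<size>`) follows the spec referenced above.

```go
package example

import (
	"context"
	"io"

	"github.com/bazelbuild/remote-apis-sdks/go/pkg/client"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
)

// fetchBlob is a hypothetical helper: blobs at or above the threshold are read
// via the "<instance>/compressed-blobs/zstd/<hash>/<size>" resource name and
// decompressed client-side; smaller blobs use plain "<instance>/blobs/..." reads.
func fetchBlob(ctx context.Context, c *client.Client, d digest.Digest, w io.Writer) error {
	c.CompressedBytestreamThreshold = 1024 * 1024 // compress reads for blobs >= 1 MiB
	_, err := c.ReadBlobStreamed(ctx, d, w)
	return err
}
```

Per the caveat above, the `offset` passed to a ranged read still counts uncompressed bytes, while a non-zero `limit` counts compressed bytes.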
diff --git a/go/pkg/client/BUILD.bazel b/go/pkg/client/BUILD.bazel
index 909f9c4..18d871e 100644
--- a/go/pkg/client/BUILD.bazel
+++ b/go/pkg/client/BUILD.bazel
@@ -27,6 +27,7 @@
"@com_github_golang_glog//:go_default_library",
"@com_github_golang_protobuf//proto:go_default_library",
"@com_github_golang_protobuf//ptypes:go_default_library_gen",
+ "@com_github_klauspost_compress//zstd:go_default_library",
"@com_github_pborman_uuid//:go_default_library",
"@com_github_pkg_errors//:go_default_library",
"@go_googleapis//google/bytestream:bytestream_go_proto",
diff --git a/go/pkg/client/cas.go b/go/pkg/client/cas.go
index c15d289..4e0baec 100644
--- a/go/pkg/client/cas.go
+++ b/go/pkg/client/cas.go
@@ -18,6 +18,7 @@
"github.com/bazelbuild/remote-apis-sdks/go/pkg/filemetadata"
"github.com/bazelbuild/remote-apis-sdks/go/pkg/uploadinfo"
"github.com/golang/protobuf/proto"
+ "github.com/klauspost/compress/zstd"
"github.com/pborman/uuid"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"
@@ -27,8 +28,9 @@
log "github.com/golang/glog"
)
-// DefaultCompressedBytestreamThreshold is the default threshold for transferring blobs compressed on ByteStream.Write RPCs.
-const DefaultCompressedBytestreamThreshold = 1024
+// DefaultCompressedBytestreamThreshold is the default threshold, in bytes, for
+// transferring blobs compressed on ByteStream.Write RPCs.
+const DefaultCompressedBytestreamThreshold = 1024 * 1024
const logInterval = 25
@@ -489,6 +491,22 @@
return dg, c.WriteChunked(ctx, c.ResourceNameWrite(dg.Hash, dg.Size), ch)
}
+// maybeCompressReadBlob will, depending on the client configuration, request the blob to be
+// read compressed. It returns the appropriate resource name and the writer to use.
+func (c *Client) maybeCompressReadBlob(hash string, sizeBytes int64, w io.WriteCloser) (string, io.WriteCloser, chan error, error) {
+ if !c.shouldCompress(sizeBytes) {
+ // If we aren't compressing the data, there's nothing to wait on.
+ dummyDone := make(chan error, 1)
+ dummyDone <- nil
+ return c.resourceNameRead(hash, sizeBytes), w, dummyDone, nil
+ }
+ cw, done, err := NewCompressedWriteBuffer(w)
+ if err != nil {
+ return "", nil, nil, err
+ }
+ return c.resourceNameCompressedRead(hash, sizeBytes), cw, done, nil
+}
+
// BatchWriteBlobs uploads a number of blobs to the CAS. They must collectively be below the
// maximum total size for a batch upload, which is about 4 MB (see MaxBatchSize). Digests must be
// computed in advance by the caller. In case multiple errors occur during the blob upload, the
@@ -746,14 +764,36 @@
}
func (c *Client) readBlobToFile(ctx context.Context, hash string, sizeBytes int64, fpath string) (int64, error) {
- n, err := c.readToFile(ctx, c.resourceNameRead(hash, sizeBytes), fpath)
+ f, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, c.RegularMode)
if err != nil {
- return n, err
+ return 0, err
}
- if n != sizeBytes {
- return n, fmt.Errorf("CAS fetch read %d bytes but %d were expected", n, sizeBytes)
+ defer f.Close()
+ return c.readBlobStreamed(ctx, hash, sizeBytes, 0, 0, f)
+}
+
+func NewCompressedWriteBuffer(w io.Writer) (io.WriteCloser, chan error, error) {
+ r, nw := io.Pipe()
+
+ // TODO(rubensf): Reuse decoders when possible to save the effort of starting/closing goroutines.
+ decoder, err := zstd.NewReader(r)
+ if err != nil {
+ return nil, nil, err
}
- return n, nil
+
+ done := make(chan error, 1)
+ go func() {
+ _, err := decoder.WriteTo(w)
+ if err != nil {
+ // Because WriteTo returned early, writes to the pipe still
+ // have to go somewhere or they'll block execution.
+ io.Copy(ioutil.Discard, r)
+ }
+ decoder.Close()
+ done <- err
+ }()
+
+ return nw, done, nil
}
// ReadBlobStreamed fetches a blob with a provided digest from the CAS.
@@ -762,20 +802,45 @@
return c.readBlobStreamed(ctx, d.Hash, d.Size, 0, 0, w)
}
+type writerTracker struct {
+ io.Writer
+ writtenBytes int64
+}
+
+func (wc *writerTracker) Write(p []byte) (int, error) {
+ n, err := wc.Writer.Write(p)
+ wc.writtenBytes += int64(n)
+ return n, err
+}
+
+func (wc *writerTracker) Close() error { return nil }
+
func (c *Client) readBlobStreamed(ctx context.Context, hash string, sizeBytes, offset, limit int64, w io.Writer) (int64, error) {
if sizeBytes == 0 {
// Do not download empty blobs.
return 0, nil
}
- n, err := c.readStreamed(ctx, c.resourceNameRead(hash, sizeBytes), offset, limit, w)
+ wt := &writerTracker{Writer: w}
+ name, wc, done, err := c.maybeCompressReadBlob(hash, sizeBytes, wt)
if err != nil {
+ return 0, err
+ }
+ n, err := c.readStreamed(ctx, name, offset, limit, wc)
+ if err != nil {
+ return n, err
+ }
+ if err = wc.Close(); err != nil {
return n, err
}
sz := sizeBytes - offset
if limit > 0 && limit < sz {
sz = limit
}
- if n != sz {
+ if err := <-done; err != nil {
+ return n, fmt.Errorf("Failed to finalize writing downloaded data downstream: %v", err)
+ }
+ close(done)
+ if wt.writtenBytes != sz {
return n, fmt.Errorf("CAS fetch read %d bytes but %d were expected", n, sz)
}
return n, nil
@@ -856,6 +921,12 @@
return fmt.Sprintf("%s/blobs/%s/%d", c.InstanceName, hash, sizeBytes)
}
+// TODO(rubensf): Converge compressor to proto in https://github.com/bazelbuild/remote-apis/pull/168 once
+// that gets merged in.
+func (c *Client) resourceNameCompressedRead(hash string, sizeBytes int64) string {
+ return fmt.Sprintf("%s/compressed-blobs/zstd/%s/%d", c.InstanceName, hash, sizeBytes)
+}
+
// ResourceNameWrite generates a valid write resource name.
func (c *Client) ResourceNameWrite(hash string, sizeBytes int64) string {
return fmt.Sprintf("%s/uploads/%s/blobs/%s/%d", c.InstanceName, uuid.New(), hash, sizeBytes)
diff --git a/go/pkg/client/cas_test.go b/go/pkg/client/cas_test.go
index 84ea727..2d9dc46 100644
--- a/go/pkg/client/cas_test.go
+++ b/go/pkg/client/cas_test.go
@@ -217,7 +217,7 @@
}
for _, tc := range tests {
- t.Run(tc.name, func(t *testing.T) {
+ testFunc := func(t *testing.T) {
*fake = tc.fake
fake.Validate(t)
@@ -243,7 +243,17 @@
if !bytes.Equal(want, got) {
t.Errorf("c.ReadBlobRange(ctx, digest, %d, %d) gave diff: want %v, got %v", tc.offset, tc.limit, want, got)
}
- })
+ }
+
+ // Harder to express as a loop over thresholds, since -1/0 isn't an intuitive "enabled/disabled" toggle.
+ c.CompressedBytestreamThreshold = -1
+ t.Run(tc.name+" - no compression", testFunc)
+ if tc.limit == 0 {
+ // Limit tests don't work well with compression, as the limit refers to the compressed bytes
+ // while offset, per spec, refers to uncompressed bytes.
+ c.CompressedBytestreamThreshold = 0
+ t.Run(tc.name+" - with compression", testFunc)
+ }
}
}
diff --git a/go/pkg/fakes/cas.go b/go/pkg/fakes/cas.go
index 940e6bb..48976ba 100644
--- a/go/pkg/fakes/cas.go
+++ b/go/pkg/fakes/cas.go
@@ -28,6 +28,8 @@
bspb "google.golang.org/genproto/googleapis/bytestream"
)
+var zstdEncoder, _ = zstd.NewWriter(nil, zstd.WithZeroFrames(true))
+
// Reader implements ByteStream's Read interface, returning one blob.
type Reader struct {
// Blob is the blob being read.
@@ -55,18 +57,34 @@
// Read implements the corresponding RE API function.
func (f *Reader) Read(req *bspb.ReadRequest, stream bsgrpc.ByteStream_ReadServer) error {
path := strings.Split(req.ResourceName, "/")
- if len(path) != 4 || path[0] != "instance" || path[1] != "blobs" {
- return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/blobs/<hash>/<size>\"")
+ if (len(path) != 4 && len(path) != 5) || path[0] != "instance" || (path[1] != "blobs" && path[1] != "compressed-blobs") {
+ return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
}
+ // indexOffset shifts the path indices below, since `compressed-blobs` names carry one extra URI element (the compressor).
+ indexOffset := 0
+ if path[1] == "compressed-blobs" {
+ indexOffset = 1
+ }
+
dg := digest.NewFromBlob(f.Blob)
- if path[2] != dg.Hash || path[3] != strconv.FormatInt(dg.Size, 10) {
- return status.Errorf(codes.NotFound, "test fake only has blob with digest %s, but %s/%s was requested", dg, path[2], path[3])
+ if path[2+indexOffset] != dg.Hash || path[3+indexOffset] != strconv.FormatInt(dg.Size, 10) {
+ return status.Errorf(codes.NotFound, "test fake only has blob with digest %s, but %s/%s was requested", dg, path[2+indexOffset], path[3+indexOffset])
}
offset := req.ReadOffset
limit := req.ReadLimit
blob := f.Blob
chunks := f.Chunks
+ if path[1] == "compressed-blobs" {
+ if path[2] != "zstd" {
+ return status.Error(codes.InvalidArgument, "test fake expected valid compressor, eg zstd")
+ }
+ blob = zstdEncoder.EncodeAll(blob[offset:], nil)
+ offset = 0
+ // For simplicity in coordinating test server & client, compressed blobs are returned as
+ // one chunk.
+ chunks = []int{len(blob)}
+ }
for len(chunks) > 0 {
buf := blob[:chunks[0]]
if offset >= int64(len(buf)) {
@@ -252,6 +270,7 @@
BatchSize: client.DefaultMaxBatchSize,
PerDigestBlockFn: make(map[digest.Digest]func()),
}
+
c.Clear()
var err error
@@ -632,14 +651,20 @@
}
path := strings.Split(req.ResourceName, "/")
- if len(path) != 4 || path[0] != "instance" || path[1] != "blobs" {
- return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/blobs/<hash>/<size>\"")
+ if (len(path) != 4 && len(path) != 5) || path[0] != "instance" || (path[1] != "blobs" && path[1] != "compressed-blobs") {
+ return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
}
- size, err := strconv.Atoi(path[3])
+ // indexOffset shifts the path indices below, since `compressed-blobs` names carry one extra URI element (the compressor).
+ indexOffset := 0
+ if path[1] == "compressed-blobs" {
+ indexOffset = 1
+ }
+
+ size, err := strconv.Atoi(path[3+indexOffset])
if err != nil {
- return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/blobs/<hash>/<size>\"")
+ return status.Error(codes.InvalidArgument, "test fake expected resource name of the form \"instance/blobs|compressed-blobs/<compressor?>/<hash>/<size>\"")
}
- dg := digest.TestNew(path[2], int64(size))
+ dg := digest.TestNew(path[2+indexOffset], int64(size))
f.maybeSleep()
f.maybeBlock(dg)
blob, ok := f.blobs[dg]
@@ -655,6 +680,13 @@
if err != nil {
return status.Errorf(codes.Internal, "test fake failed to create chunker: %v", err)
}
+ if path[1] == "compressed-blobs" {
+ if path[2] != "zstd" {
+ return status.Error(codes.InvalidArgument, "test fake expected valid compressor, eg zstd")
+ }
+ blob = zstdEncoder.EncodeAll(blob, nil)
+ }
+
resp := &bspb.ReadResponse{}
for ch.HasNext() {
chunk, err := ch.Next()