snappy: fix (1) encoding a 0-length input returned garbage, and
(2) decoding into an existing buffer returned excess bytes.
R=bradfitz
CC=golang-dev
http://codereview.appspot.com/6294045
diff --git a/snappy/decode.go b/snappy/decode.go
index d169bea..d93c1b9 100644
--- a/snappy/decode.go
+++ b/snappy/decode.go
@@ -117,5 +117,8 @@
dst[d] = dst[d-offset]
}
}
- return dst, nil
+ if d != dLen {
+ return nil, ErrCorrupt
+ }
+ return dst[:d], nil
}
diff --git a/snappy/encode.go b/snappy/encode.go
index a403ab9..846f971 100644
--- a/snappy/encode.go
+++ b/snappy/encode.go
@@ -96,7 +96,9 @@
// Return early if src is short.
if len(src) <= 4 {
- d += emitLiteral(dst[d:], src)
+ if len(src) != 0 {
+ d += emitLiteral(dst[d:], src)
+ }
return dst[:d], nil
}
diff --git a/snappy/snappy_test.go b/snappy/snappy_test.go
index 13ee5c2..aa74bd9 100644
--- a/snappy/snappy_test.go
+++ b/snappy/snappy_test.go
@@ -13,12 +13,12 @@
"testing"
)
-func roundtrip(b []byte) error {
- e, err := Encode(nil, b)
+func roundtrip(b, ebuf, dbuf []byte) error {
+ e, err := Encode(ebuf, b)
if err != nil {
return fmt.Errorf("encoding error: %v", err)
}
- d, err := Decode(nil, e)
+ d, err := Decode(dbuf, e)
if err != nil {
return fmt.Errorf("decoding error: %v", err)
}
@@ -28,11 +28,21 @@
return nil
}
+func TestEmpty(t *testing.T) {
+ if err := roundtrip(nil, nil, nil); err != nil {
+ t.Fatal(err)
+ }
+}
+
func TestSmallCopy(t *testing.T) {
- for i := 0; i < 32; i++ {
- s := "aaaa" + strings.Repeat("b", i) + "aaaabbbb"
- if err := roundtrip([]byte(s)); err != nil {
- t.Fatalf("i=%d: %v", i, err)
+ for _, ebuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} {
+ for _, dbuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} {
+ for i := 0; i < 32; i++ {
+ s := "aaaa" + strings.Repeat("b", i) + "aaaabbbb"
+ if err := roundtrip([]byte(s), ebuf, dbuf); err != nil {
+ t.Errorf("len(ebuf)=%d, len(dbuf)=%d, i=%d: %v", len(ebuf), len(dbuf), i, err)
+ }
+ }
}
}
}
@@ -44,7 +54,7 @@
for i, _ := range b {
b[i] = uint8(rand.Uint32())
}
- if err := roundtrip(b); err != nil {
+ if err := roundtrip(b, nil, nil); err != nil {
t.Fatal(err)
}
}
@@ -56,7 +66,7 @@
for i, _ := range b {
b[i] = uint8(i%10 + 'a')
}
- if err := roundtrip(b); err != nil {
+ if err := roundtrip(b, nil, nil); err != nil {
t.Fatal(err)
}
}