Ensure that Decode doesn't write past the decoded length.
diff --git a/decode.go b/decode.go
index a62e26f..400873b 100644
--- a/decode.go
+++ b/decode.go
@@ -52,7 +52,9 @@
if err != nil {
return nil, err
}
- if len(dst) < dLen {
+ if dLen <= len(dst) {
+ dst = dst[:dLen]
+ } else {
dst = make([]byte, dLen)
}
diff --git a/snappy_test.go b/snappy_test.go
index 0efc9b3..4376600 100644
--- a/snappy_test.go
+++ b/snappy_test.go
@@ -282,11 +282,46 @@
nil,
}}
+ // notPresent is a byte value that is not present in either the input or
+ // the output. It is written to dBuf to check that Decode does not write
+ // bytes past the end of dBuf[:dLen].
+ const notPresent = 0xfe
+
+ var dBuf [100]byte
+loop:
for i, tc := range testCases {
- g, gotErr := Decode(nil, []byte(tc.input))
+ input := []byte(tc.input)
+ for _, x := range input {
+ if x == notPresent {
+ t.Errorf("#%d (%s): input shouldn't contain byte value %#02x", i, tc.desc, notPresent)
+ continue loop
+ }
+ }
+
+ dLen, n := binary.Uvarint(input)
+ if n <= 0 {
+ t.Errorf("#%d (%s): invalid varint-encoded dLen", i, tc.desc)
+ continue
+ }
+ if dLen > uint64(len(dBuf)) {
+ t.Errorf("#%d (%s): dLen %d is too large", i, tc.desc, dLen)
+ continue
+ }
+
+ for j := range dBuf {
+ dBuf[j] = notPresent
+ }
+ g, gotErr := Decode(dBuf[:], input)
if got := string(g); got != tc.want || gotErr != tc.wantErr {
t.Errorf("#%d (%s):\ngot %q, %v\nwant %q, %v",
i, tc.desc, got, gotErr, tc.want, tc.wantErr)
+ continue
+ }
+ for _, x := range dBuf[dLen:] {
+ if x != notPresent {
+ t.Errorf("#%d (%s): dBuf[dLen:] should all be byte value %#02x", i, tc.desc, notPresent)
+ continue loop
+ }
}
}
}
@@ -297,12 +332,20 @@
const (
prefix = "abcdefghijkl"
suffix = "ABCDEFGHIJKL"
+
+ // notPresent is a byte value that is not present in either the input
+ // or the output. It is written to gotBuf to check that Decode does not
+ // write bytes past the end of gotBuf[:totalLen].
+ notPresent = 0xfe
)
var gotBuf, wantBuf, inputBuf [256]byte
for length := 1; length < 12; length++ {
for offset := 1; offset < 12; offset++ {
+ loop:
for suffixLen := 0; suffixLen < 12; suffixLen++ {
- inputLen := binary.PutUvarint(inputBuf[:], uint64(len(prefix)+length+suffixLen))
+ totalLen := uint64(len(prefix) + length + suffixLen)
+
+ inputLen := binary.PutUvarint(inputBuf[:], totalLen)
inputBuf[inputLen] = tagLiteral + 4*byte(len(prefix)-1)
inputLen++
inputLen += copy(inputBuf[inputLen:], prefix)
@@ -317,9 +360,12 @@
}
input := inputBuf[:inputLen]
+ for i := range gotBuf {
+ gotBuf[i] = notPresent
+ }
got, err := Decode(gotBuf[:], input)
if err != nil {
- t.Errorf("length=%d, offset=%d; suffixLen=%d: %v", length, offset, suffixLen)
+ t.Errorf("length=%d, offset=%d; suffixLen=%d: %v", length, offset, suffixLen, err)
continue
}
@@ -332,6 +378,28 @@
wantLen += copy(wantBuf[wantLen:], suffix[:suffixLen])
want := wantBuf[:wantLen]
+ for _, x := range input {
+ if x == notPresent {
+ t.Errorf("length=%d, offset=%d; suffixLen=%d: input shouldn't contain byte value %#02x",
+ length, offset, suffixLen, notPresent)
+ continue loop
+ }
+ }
+ for _, x := range gotBuf[totalLen:] {
+ if x != notPresent {
+ t.Errorf("length=%d, offset=%d; suffixLen=%d: gotBuf[totalLen:] should all be byte value %#02x",
+ length, offset, suffixLen, notPresent)
+ continue loop
+ }
+ }
+ for _, x := range want {
+ if x == notPresent {
+ t.Errorf("length=%d, offset=%d; suffixLen=%d: want shouldn't contain byte value %#02x",
+ length, offset, suffixLen, notPresent)
+ continue loop
+ }
+ }
+
if !bytes.Equal(got, want) {
t.Errorf("length=%d, offset=%d; suffixLen=%d:\ninput % x\ngot % x\nwant % x",
length, offset, suffixLen, input, got, want)