Don't allow EOF in the middle of a framing chunk.
diff --git a/decode.go b/decode.go
index 521e146..819c717 100644
--- a/decode.go
+++ b/decode.go
@@ -108,9 +108,9 @@
r.readHeader = false
}
-func (r *Reader) readFull(p []byte) (ok bool) {
+func (r *Reader) readFull(p []byte, allowEOF bool) (ok bool) {
if _, r.err = io.ReadFull(r.r, p); r.err != nil {
- if r.err == io.ErrUnexpectedEOF {
+ if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
r.err = ErrCorrupt
}
return false
@@ -129,7 +129,7 @@
r.i += n
return n, nil
}
- if !r.readFull(r.buf[:4]) {
+ if !r.readFull(r.buf[:4], true) {
return 0, r.err
}
chunkType := r.buf[0]
@@ -156,7 +156,7 @@
return 0, r.err
}
buf := r.buf[:chunkLen]
- if !r.readFull(buf) {
+ if !r.readFull(buf, false) {
return 0, r.err
}
checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
@@ -189,7 +189,7 @@
return 0, r.err
}
buf := r.buf[:checksumSize]
- if !r.readFull(buf) {
+ if !r.readFull(buf, false) {
return 0, r.err
}
checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
@@ -199,7 +199,7 @@
r.err = ErrCorrupt
return 0, r.err
}
- if !r.readFull(r.decoded[:n]) {
+ if !r.readFull(r.decoded[:n], false) {
return 0, r.err
}
if crc(r.decoded[:n]) != checksum {
@@ -215,7 +215,7 @@
r.err = ErrCorrupt
return 0, r.err
}
- if !r.readFull(r.buf[:len(magicBody)]) {
+ if !r.readFull(r.buf[:len(magicBody)], false) {
return 0, r.err
}
for i := 0; i < len(magicBody); i++ {
@@ -234,7 +234,7 @@
}
// Section 4.4 Padding (chunk type 0xfe).
// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
- if !r.readFull(r.buf[:chunkLen]) {
+ if !r.readFull(r.buf[:chunkLen], false) {
return 0, r.err
}
}
diff --git a/snappy_test.go b/snappy_test.go
index e8a1d4d..31ad750 100644
--- a/snappy_test.go
+++ b/snappy_test.go
@@ -725,6 +725,31 @@
}
}
+func TestReaderUncompressedDataOK(t *testing.T) {
+ r := NewReader(strings.NewReader(magicChunk +
+ "\x01\x08\x00\x00" + // Uncompressed chunk, 8 bytes long (including 4 byte checksum).
+ "\x68\x10\xe6\xb6" + // Checksum.
+ "\x61\x62\x63\x64", // Uncompressed payload: "abcd".
+ ))
+ g, err := ioutil.ReadAll(r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := string(g), "abcd"; got != want {
+ t.Fatalf("got %q, want %q", got, want)
+ }
+}
+
+func TestReaderUncompressedDataNoPayload(t *testing.T) {
+ r := NewReader(strings.NewReader(magicChunk +
+ "\x01\x04\x00\x00" + // Uncompressed chunk, 4 bytes long.
+ "", // No payload; corrupt input.
+ ))
+ if _, err := ioutil.ReadAll(r); err != ErrCorrupt {
+ t.Fatalf("got %v, want %v", err, ErrCorrupt)
+ }
+}
+
func TestReaderUncompressedDataTooLong(t *testing.T) {
// https://github.com/google/snappy/blob/master/framing_format.txt section
// 4.3 says that "the maximum legal chunk length... is 65540", or 0x10004.