journal: detect invalid chunk type
diff --git a/leveldb/journal/journal.go b/leveldb/journal/journal.go
index 891098b..d094c3d 100644
--- a/leveldb/journal/journal.go
+++ b/leveldb/journal/journal.go
@@ -180,34 +180,37 @@
checksum := binary.LittleEndian.Uint32(r.buf[r.j+0 : r.j+4])
length := binary.LittleEndian.Uint16(r.buf[r.j+4 : r.j+6])
chunkType := r.buf[r.j+6]
-
+ unprocBlock := r.n - r.j
if checksum == 0 && length == 0 && chunkType == 0 {
// Drop entire block.
- m := r.n - r.j
r.i = r.n
r.j = r.n
- return r.corrupt(m, "zero header", false)
- } else {
- m := r.n - r.j
- r.i = r.j + headerSize
- r.j = r.j + headerSize + int(length)
- if r.j > r.n {
- // Drop entire block.
- r.i = r.n
- r.j = r.n
- return r.corrupt(m, "chunk length overflows block", false)
- } else if r.checksum && checksum != util.NewCRC(r.buf[r.i-1:r.j]).Value() {
- // Drop entire block.
- r.i = r.n
- r.j = r.n
- return r.corrupt(m, "checksum mismatch", false)
- }
+ return r.corrupt(unprocBlock, "zero header", false)
+ }
+ if chunkType < fullChunkType || chunkType > lastChunkType {
+ // Drop entire block.
+ r.i = r.n
+ r.j = r.n
+ return r.corrupt(unprocBlock, fmt.Sprintf("invalid chunk type %#x", chunkType), false)
+ }
+ r.i = r.j + headerSize
+ r.j = r.j + headerSize + int(length)
+ if r.j > r.n {
+ // Drop entire block.
+ r.i = r.n
+ r.j = r.n
+ return r.corrupt(unprocBlock, "chunk length overflows block", false)
+ } else if r.checksum && checksum != util.NewCRC(r.buf[r.i-1:r.j]).Value() {
+ // Drop entire block.
+ r.i = r.n
+ r.j = r.n
+ return r.corrupt(unprocBlock, "checksum mismatch", false)
}
if first && chunkType != fullChunkType && chunkType != firstChunkType {
- m := r.j - r.i
+ chunkLength := (r.j - r.i) + headerSize
r.i = r.j
// Report the error, but skip it.
- return r.corrupt(m+headerSize, "orphan chunk", true)
+ return r.corrupt(chunkLength, "orphan chunk", true)
}
r.last = chunkType == fullChunkType || chunkType == lastChunkType
return nil