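Reduce the number of Write calls to the underlying io.Writer.

The stream header, the per-chunk header and the compressed chunk body now
share a single output buffer and are emitted with one Write call. A chunk
that is stored uncompressed needs a second Write call for its raw payload.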
diff --git a/encode.go b/encode.go
index d68d441..834e3b0 100644
--- a/encode.go
+++ b/encode.go
@@ -190,7 +190,7 @@
func NewWriter(w io.Writer) *Writer {
return &Writer{
w: w,
- obuf: make([]byte, MaxEncodedLen(maxUncompressedChunkLen)),
+ obuf: make([]byte, obufLen),
}
}
@@ -205,7 +205,7 @@
return &Writer{
w: w,
ibuf: make([]byte, 0, maxUncompressedChunkLen),
- obuf: make([]byte, MaxEncodedLen(maxUncompressedChunkLen)),
+ obuf: make([]byte, obufLen),
}
}
@@ -224,11 +224,6 @@
// obuf is a buffer for the outgoing (compressed) bytes.
obuf []byte
- // chunkHeaderBuf is a buffer for the per-chunk header (chunk type, length
- // and checksum), not to be confused with the magic string that forms the
- // stream header.
- chunkHeaderBuf [checksumSize + chunkHeaderSize]byte
-
// wroteStreamHeader is whether we have written the stream header.
wroteStreamHeader bool
}
@@ -284,17 +279,14 @@
if w.err != nil {
return 0, w.err
}
- if !w.wroteStreamHeader {
- if copy(w.obuf, magicChunk) != len(magicChunk) {
- panic("unreachable")
- }
- if _, err := w.w.Write(w.obuf[:len(magicChunk)]); err != nil {
- w.err = err
- return nRet, err
- }
- w.wroteStreamHeader = true
- }
for len(p) > 0 {
+ obufStart := len(magicChunk)
+ if !w.wroteStreamHeader {
+ w.wroteStreamHeader = true
+ copy(w.obuf, magicChunk)
+ obufStart = 0
+ }
+
var uncompressed []byte
if len(p) > maxUncompressedChunkLen {
uncompressed, p = p[:maxUncompressedChunkLen], p[maxUncompressedChunkLen:]
@@ -305,28 +297,35 @@
// Compress the buffer, discarding the result if the improvement
// isn't at least 12.5%.
+ compressed := Encode(w.obuf[obufHeaderLen:], uncompressed)
chunkType := uint8(chunkTypeCompressedData)
- chunkBody := Encode(w.obuf, uncompressed)
- if len(chunkBody) >= len(uncompressed)-len(uncompressed)/8 {
- chunkType, chunkBody = chunkTypeUncompressedData, uncompressed
+ chunkLen := 4 + len(compressed)
+ obufEnd := obufHeaderLen + len(compressed)
+ if len(compressed) >= len(uncompressed)-len(uncompressed)/8 {
+ chunkType = chunkTypeUncompressedData
+ chunkLen = 4 + len(uncompressed)
+ obufEnd = obufHeaderLen
}
- chunkLen := 4 + len(chunkBody)
- w.chunkHeaderBuf[0] = chunkType
- w.chunkHeaderBuf[1] = uint8(chunkLen >> 0)
- w.chunkHeaderBuf[2] = uint8(chunkLen >> 8)
- w.chunkHeaderBuf[3] = uint8(chunkLen >> 16)
- w.chunkHeaderBuf[4] = uint8(checksum >> 0)
- w.chunkHeaderBuf[5] = uint8(checksum >> 8)
- w.chunkHeaderBuf[6] = uint8(checksum >> 16)
- w.chunkHeaderBuf[7] = uint8(checksum >> 24)
- if _, err := w.w.Write(w.chunkHeaderBuf[:]); err != nil {
+ // Fill in the per-chunk header that comes before the body.
+ w.obuf[len(magicChunk)+0] = chunkType
+ w.obuf[len(magicChunk)+1] = uint8(chunkLen >> 0)
+ w.obuf[len(magicChunk)+2] = uint8(chunkLen >> 8)
+ w.obuf[len(magicChunk)+3] = uint8(chunkLen >> 16)
+ w.obuf[len(magicChunk)+4] = uint8(checksum >> 0)
+ w.obuf[len(magicChunk)+5] = uint8(checksum >> 8)
+ w.obuf[len(magicChunk)+6] = uint8(checksum >> 16)
+ w.obuf[len(magicChunk)+7] = uint8(checksum >> 24)
+
+ if _, err := w.w.Write(w.obuf[obufStart:obufEnd]); err != nil {
w.err = err
return nRet, err
}
- if _, err := w.w.Write(chunkBody); err != nil {
- w.err = err
- return nRet, err
+ if chunkType == chunkTypeUncompressedData {
+ if _, err := w.w.Write(uncompressed); err != nil {
+ w.err = err
+ return nRet, err
+ }
}
nRet += len(uncompressed)
}
diff --git a/snappy.go b/snappy.go
index 15af18d..1c2b671 100644
--- a/snappy.go
+++ b/snappy.go
@@ -49,6 +49,15 @@
// https://github.com/google/snappy/blob/master/framing_format.txt says
// that "the uncompressed data in a chunk must be no longer than 65536 bytes".
maxUncompressedChunkLen = 65536
+
+ // maxEncodedLenOfMaxUncompressedChunkLen equals
+ // MaxEncodedLen(maxUncompressedChunkLen), but is hard coded to be a const
+ // instead of a variable, so that obufLen can also be a const. Their
+ // equivalence is confirmed by TestMaxEncodedLenOfMaxUncompressedChunkLen.
+ maxEncodedLenOfMaxUncompressedChunkLen = 76490
+
+ obufHeaderLen = len(magicChunk) + checksumSize + chunkHeaderSize
+ obufLen = obufHeaderLen + maxEncodedLenOfMaxUncompressedChunkLen
)
const (
diff --git a/snappy_test.go b/snappy_test.go
index 905dba0..99fa45f 100644
--- a/snappy_test.go
+++ b/snappy_test.go
@@ -23,6 +23,14 @@
testdata = flag.String("testdata", "testdata", "Directory containing the test data")
)
+func TestMaxEncodedLenOfMaxUncompressedChunkLen(t *testing.T) {
+ got := maxEncodedLenOfMaxUncompressedChunkLen
+ want := MaxEncodedLen(maxUncompressedChunkLen)
+ if got != want {
+ t.Fatalf("got %d, want %d", got, want)
+ }
+}
+
func roundtrip(b, ebuf, dbuf []byte) error {
d, err := Decode(dbuf, Encode(ebuf, b))
if err != nil {
@@ -141,6 +149,32 @@
}
}
+func TestWriterGoldenOutput(t *testing.T) {
+ buf := new(bytes.Buffer)
+ w := NewBufferedWriter(buf)
+ defer w.Close()
+ w.Write([]byte("abcd")) // Not compressible.
+ w.Flush()
+ w.Write(bytes.Repeat([]byte{'A'}, 100)) // Compressible.
+ w.Flush()
+ got := buf.String()
+ want := strings.Join([]string{
+ magicChunk,
+ "\x01\x08\x00\x00", // Uncompressed chunk, 8 bytes long (including 4 byte checksum).
+ "\x68\x10\xe6\xb6", // Checksum.
+ "\x61\x62\x63\x64", // Uncompressed payload: "abcd".
+ "\x00\x0d\x00\x00", // Compressed chunk, 13 bytes long (including 4 byte checksum).
+ "\x37\xcb\xbc\x9d", // Checksum.
+ "\x64", // Compressed payload: Uncompressed length (varint encoded): 100.
+ "\x00\x41", // Compressed payload: tagLiteral, length=1, "A".
+ "\xfe\x01\x00", // Compressed payload: tagCopy2, length=64, offset=1.
+ "\x8a\x01\x00", // Compressed payload: tagCopy2, length=35, offset=1.
+ }, "")
+ if got != want {
+ t.Fatalf("\ngot: % x\nwant: % x", got, want)
+ }
+}
+
func TestNewBufferedWriter(t *testing.T) {
// Test all 32 possible sub-sequences of these 5 input slices.
//
@@ -314,6 +348,46 @@
}
}
+type writeCounter int
+
+func (c *writeCounter) Write(p []byte) (int, error) {
+ *c++
+ return len(p), nil
+}
+
+// TestNumUnderlyingWrites tests that each Writer flush makes only one Write
+// call on its underlying io.Writer if the flushed buffer was compressible, or
+// two if it was not. The stream header shares the first flush's first call.
+func TestNumUnderlyingWrites(t *testing.T) {
+ testCases := []struct {
+ input []byte
+ want int
+ }{
+ {bytes.Repeat([]byte{'x'}, 100), 1},
+ {bytes.Repeat([]byte{'y'}, 100), 1},
+ {[]byte("ABCDEFGHIJKLMNOPQRST"), 2},
+ }
+
+ var c writeCounter
+ w := NewBufferedWriter(&c)
+ defer w.Close()
+ for i, tc := range testCases {
+ c = 0
+ if _, err := w.Write(tc.input); err != nil {
+ t.Errorf("#%d: Write: %v", i, err)
+ continue
+ }
+ if err := w.Flush(); err != nil {
+ t.Errorf("#%d: Flush: %v", i, err)
+ continue
+ }
+ if int(c) != tc.want {
+ t.Errorf("#%d: got %d underlying writes, want %d", i, c, tc.want)
+ continue
+ }
+ }
+}
+
func benchDecode(b *testing.B, src []byte) {
encoded := Encode(nil, src)
// Bandwidth is in amount of uncompressed data.