[buffers] Catch duplicate buffers.push() calls
Previously, if the same buffer gets returned to the buffer pool multiple
times (which should never happen, but could happen if a bug was
introduced that led to `buffers.push()` being called multiple times with
the same buffer), the buffer could subsequently be returned to multiple
different goroutines by `buffers.get()`. This would lead to parallelism
bugs as multiple goroutines try to write to the same buffer.
Now `buffers.push()` will panic if the buffer being pushed is already in
the buffer pool.
Change-Id: I701b2f45cdce8085bcdbb4d1b2bbb2247dd0ee21
Reviewed-on: https://fuchsia-review.googlesource.com/c/shac-project/shac/+/929154
Reviewed-by: Ina Huh <ihuh@google.com>
Fuchsia-Auto-Submit: Oliver Newman <olivernewman@google.com>
Commit-Queue: Auto-Submit <auto-submit@fuchsia-infra.iam.gserviceaccount.com>
diff --git a/internal/engine/buffers.go b/internal/engine/buffers.go
index 3f15fb6..dac0916 100644
--- a/internal/engine/buffers.go
+++ b/internal/engine/buffers.go
@@ -16,6 +16,7 @@
import (
"bytes"
+ "log"
"sync"
)
@@ -23,26 +24,31 @@
//
// Fill up 3 large buffers to accelerate the bootstrap.
var buffers = buffersImpl{
- b: []*bytes.Buffer{
- bytes.NewBuffer(make([]byte, 0, 16*1024)),
- bytes.NewBuffer(make([]byte, 0, 16*1024)),
- bytes.NewBuffer(make([]byte, 0, 16*1024)),
+ b: map[*bytes.Buffer]struct{}{
+ bytes.NewBuffer(make([]byte, 0, 16*1024)): {},
+ bytes.NewBuffer(make([]byte, 0, 16*1024)): {},
+ bytes.NewBuffer(make([]byte, 0, 16*1024)): {},
},
}
type buffersImpl struct {
mu sync.Mutex
- b []*bytes.Buffer
+ // Track buffers in a map to prevent storing duplicates in the pool.
+ b map[*bytes.Buffer]struct{}
}
func (i *buffersImpl) get() *bytes.Buffer {
var b *bytes.Buffer
i.mu.Lock()
- if l := len(i.b); l == 0 {
+ if len(i.b) == 0 {
b = &bytes.Buffer{}
} else {
- b = i.b[l-1]
- i.b = i.b[:l-1]
+ // Choose a random element from the pool by taking whatever buffer is
+ // returned first when iterating over the pool.
+ for b = range i.b {
+ break
+ }
+ delete(i.b, b)
}
i.mu.Unlock()
return b
@@ -52,6 +58,10 @@
// Reset keeps the buffer, so that the next execution will reuse the same allocation.
b.Reset()
i.mu.Lock()
- i.b = append(i.b, b)
+ if _, ok := i.b[b]; ok {
+ i.mu.Unlock()
+ log.Panicf("buffer at %p has already been returned to the pool", b)
+ }
+ i.b[b] = struct{}{}
i.mu.Unlock()
}
diff --git a/internal/engine/buffers_test.go b/internal/engine/buffers_test.go
new file mode 100644
index 0000000..337fc70
--- /dev/null
+++ b/internal/engine/buffers_test.go
@@ -0,0 +1,80 @@
+// Copyright 2023 The Shac Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package engine
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+)
+
+func TestBuffers(t *testing.T) {
+ t.Parallel()
+
+ t.Run("get buffer from empty pool", func(t *testing.T) {
+ t.Parallel()
+ // Avoid using the shared global buffer pool, for determinism.
+ buffers := buffersImpl{b: make(map[*bytes.Buffer]struct{})}
+
+ b := buffers.get()
+ if b.Len() != 0 {
+ t.Errorf("Expected new buffer to have length 0, got %d", b.Len())
+ }
+ if b.Cap() != 0 {
+ t.Errorf("Expected new buffer to have capacity 0, got %d", b.Cap())
+ }
+
+ _, err := b.WriteString("hello, world")
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantCap := b.Cap()
+
+ buffers.push(b)
+
+ b2 := buffers.get()
+ if b2 != b {
+ t.Errorf("buffers.get() should return the existing buffer")
+ }
+ if b2.Len() != 0 {
+ t.Errorf("Expected reused buffer to be empty, but it contains: %q", b2)
+ }
+ if b2.Cap() != wantCap {
+ t.Errorf("Expected reused buffer to have capacity %d, got %d", wantCap, b2.Cap())
+ }
+ })
+
+ t.Run("push the same buffer multiple times", func(t *testing.T) {
+ t.Parallel()
+ // Avoid using the shared global buffer pool, for determinism.
+ buffers := buffersImpl{b: make(map[*bytes.Buffer]struct{})}
+
+ b := buffers.get()
+
+ buffers.push(b)
+
+ defer func() {
+ msg := recover()
+ if msg == nil {
+ t.Errorf("Expected a panic")
+ }
+ want := fmt.Sprintf("buffer at %p has already been returned to the pool", b)
+ if msg != want {
+ t.Errorf("Got wrong panic message: %s", msg)
+ }
+ }()
+ buffers.push(b)
+ })
+}