fix(v2/callctx): fix SetHeader race by cloning header map (#326)
diff --git a/v2/callctx/callctx.go b/v2/callctx/callctx.go
index 9aab3d9..f5af5c9 100644
--- a/v2/callctx/callctx.go
+++ b/v2/callctx/callctx.go
@@ -74,9 +74,27 @@
h, ok := ctx.Value(headerKey).(map[string][]string)
if !ok {
h = make(map[string][]string)
+ } else {
+ h = cloneHeaders(h)
}
+
for i := 0; i < len(keyvals); i = i + 2 {
h[keyvals[i]] = append(h[keyvals[i]], keyvals[i+1])
}
return context.WithValue(ctx, headerKey, h)
}
+
+// cloneHeaders makes a new key-value map while reusing the value slices.
+// As such, new values should be appended to the value slice, and modifying
+// indexed values is not thread safe.
+//
+// TODO: Replace this with maps.Clone when Go 1.21 is the minimum version.
+func cloneHeaders(h map[string][]string) map[string][]string {
+ c := make(map[string][]string, len(h))
+ for k, v := range h {
+ vc := make([]string, len(v))
+ copy(vc, v)
+ c[k] = vc
+ }
+ return c
+}
diff --git a/v2/callctx/callctx_test.go b/v2/callctx/callctx_test.go
index 46d91b0..e644d55 100644
--- a/v2/callctx/callctx_test.go
+++ b/v2/callctx/callctx_test.go
@@ -31,6 +31,7 @@
import (
"context"
+ "sync"
"testing"
"github.com/google/go-cmp/cmp"
@@ -77,3 +78,45 @@
ctx := context.Background()
SetHeaders(ctx, "1", "2", "3")
}
+
+func TestSetHeaders_reuse(t *testing.T) {
+ c := SetHeaders(context.Background(), "key", "value1")
+ v1 := HeadersFromContext(c)
+ c = SetHeaders(c, "key", "value2")
+ v2 := HeadersFromContext(c)
+
+ if cmp.Diff(v2, v1) == "" {
+ t.Errorf("Second header set did not differ from first header set as expected")
+ }
+}
+
+func TestSetHeaders_race(t *testing.T) {
+ key := "key"
+ value := "value"
+ want := map[string][]string{
+ key: []string{value, value},
+ }
+
+ // Init the ctx so a value already exists to be "shared".
+ cctx := SetHeaders(context.Background(), key, value)
+
+ // Reusing the same cctx and adding to the same header key
+ // should *not* produce a race condition when run with -race.
+ var wg sync.WaitGroup
+ for i := 0; i < 3; i++ {
+ wg.Add(1)
+ go func(ctx context.Context) {
+ defer wg.Done()
+ c := SetHeaders(ctx, key, value)
+ h := HeadersFromContext(c)
+
+ // Additionally, if there was a race condition,
+ // we may see that one instance of these headers
+ // contains extra values.
+ if diff := cmp.Diff(h, want); diff != "" {
+ t.Errorf("got(-),want(+):\n%s", diff)
+ }
+ }(cctx)
+ }
+ wg.Wait()
+}