transport: fix race sending RPC status that could lead to a panic (#1687)
WriteStatus can be called concurrently: one by SendMsg,
the other by RecvMsg. Then, closing writes channel
becomes racey without proper locking.
Make transport closing synchronous in such case.
diff --git a/transport/handler_server.go b/transport/handler_server.go
index f1f6caf..7e0fdb3 100644
--- a/transport/handler_server.go
+++ b/transport/handler_server.go
@@ -123,10 +123,9 @@
// when WriteStatus is called.
writes chan func()
- mu sync.Mutex
- // streamDone indicates whether WriteStatus has been called and writes channel
- // has been closed.
- streamDone bool
+ // block concurrent WriteStatus calls
+ // e.g. grpc/(*serverStream).SendMsg/RecvMsg
+ writeStatusMu sync.Mutex
}
func (ht *serverHandlerTransport) Close() error {
@@ -177,13 +176,9 @@
}
func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error {
- ht.mu.Lock()
- if ht.streamDone {
- ht.mu.Unlock()
- return nil
- }
- ht.streamDone = true
- ht.mu.Unlock()
+ ht.writeStatusMu.Lock()
+ defer ht.writeStatusMu.Unlock()
+
err := ht.do(func() {
ht.writeCommonHeaders(s)
@@ -222,7 +217,11 @@
}
}
})
- close(ht.writes)
+
+ if err == nil { // transport has not been closed
+ ht.Close()
+ close(ht.writes)
+ }
return err
}
diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go
index 06fe813..8505e1a 100644
--- a/transport/handler_server_test.go
+++ b/transport/handler_server_test.go
@@ -391,9 +391,10 @@
}
}
+// TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that
+// concurrent "WriteStatus"s do not panic writing to closed "writes" channel.
func TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
- st := newHandleStreamTest(t)
- handleStream := func(s *Stream) {
+ testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}
@@ -408,9 +409,27 @@
}()
}
wg.Wait()
- }
+ })
+}
+
+// TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write"
+// following "WriteStatus" does not panic writing to closed "writes" channel.
+func TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
+ testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
+ if want := "/service/foo.bar"; s.method != want {
+ t.Errorf("stream method = %q; want %q", s.method, want)
+ }
+ st.bodyw.Close() // no body
+
+ st.ht.WriteStatus(s, status.New(codes.OK, ""))
+ st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{})
+ })
+}
+
+func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
+ st := newHandleStreamTest(t)
st.ht.HandleStreams(
- func(s *Stream) { go handleStream(s) },
+ func(s *Stream) { go handleStream(st, s) },
func(ctx context.Context, method string) context.Context { return ctx },
)
}