server: after GracefulStop, ensure connections are closed when final RPC completes (#5968)
Fixes https://github.com/grpc/grpc-go/issues/5930
diff --git a/internal/transport/controlbuf.go b/internal/transport/controlbuf.go
index a5b7513..9097385 100644
--- a/internal/transport/controlbuf.go
+++ b/internal/transport/controlbuf.go
@@ -527,6 +527,9 @@
// As an optimization, to increase the batch size for each flush, loopy yields the processor, once
// if the batch size is too low to give stream goroutines a chance to fill it up.
func (l *loopyWriter) run() (err error) {
+ // Always flush the writer before exiting in case there are pending frames
+ // to be sent.
+ defer l.framer.writer.Flush()
for {
it, err := l.cbuf.get(true)
if err != nil {
@@ -759,7 +762,7 @@
return err
}
}
- if l.side == clientSide && l.draining && len(l.estdStreams) == 0 {
+ if l.draining && len(l.estdStreams) == 0 {
return errors.New("finished processing active streams while in draining mode")
}
return nil
@@ -814,7 +817,6 @@
}
func (l *loopyWriter) closeConnectionHandler() error {
- l.framer.writer.Flush()
// Exit loopyWriter entirely by returning an error here. This will lead to
// the transport closing the connection, and, ultimately, transport
// closure.
diff --git a/test/gracefulstop_test.go b/test/gracefulstop_test.go
index a5a8448..15e0611 100644
--- a/test/gracefulstop_test.go
+++ b/test/gracefulstop_test.go
@@ -26,6 +26,7 @@
"testing"
"time"
+ "golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
@@ -164,3 +165,53 @@
cancel()
wg.Wait()
}
+
+func (s) TestGracefulStopClosesConnAfterLastStream(t *testing.T) {
+ // This test ensures that a server closes the connections to its clients
+ // when the final stream has completed after a GOAWAY.
+
+ handlerCalled := make(chan struct{})
+ gracefulStopCalled := make(chan struct{})
+
+ ts := &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error {
+ close(handlerCalled) // Initiate call to GracefulStop.
+ <-gracefulStopCalled // Wait for GOAWAYs to be received by the client.
+ return nil
+ }}
+
+ te := newTest(t, tcpClearEnv)
+ te.startServer(ts)
+ defer te.tearDown()
+
+ te.withServerTester(func(st *serverTester) {
+ st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall", false)
+
+ <-handlerCalled // Wait for the server to invoke its handler.
+
+ // Gracefully stop the server.
+ gracefulStopDone := make(chan struct{})
+ go func() {
+ te.srv.GracefulStop()
+ close(gracefulStopDone)
+ }()
+ st.wantGoAway(http2.ErrCodeNo) // Server sends a GOAWAY due to GracefulStop.
+ pf := st.wantPing() // Server sends a ping to verify client receipt.
+ st.writePing(true, pf.Data) // Send ping ack to confirm.
+ st.wantGoAway(http2.ErrCodeNo) // Wait for subsequent GOAWAY to indicate no new stream processing.
+
+ close(gracefulStopCalled) // Unblock server handler.
+
+ fr := st.wantAnyFrame() // Wait for trailer.
+ hdr, ok := fr.(*http2.MetaHeadersFrame)
+ if !ok {
+ t.Fatalf("Received unexpected frame of type (%T) from server: %v; want HEADERS", fr, fr)
+ }
+ if !hdr.StreamEnded() {
+ t.Fatalf("Received unexpected HEADERS frame from server: %v; want END_STREAM set", fr)
+ }
+
+ st.wantRSTStream(http2.ErrCodeNo) // Server should send RST_STREAM because client did not half-close.
+
+ <-gracefulStopDone // Wait for GracefulStop to return.
+ })
+}
diff --git a/test/servertester.go b/test/servertester.go
index bf7bd8b..3701a0e 100644
--- a/test/servertester.go
+++ b/test/servertester.go
@@ -138,19 +138,46 @@
}
}
+func (st *serverTester) wantGoAway(errCode http2.ErrCode) *http2.GoAwayFrame {
+ f, err := st.readFrame()
+ if err != nil {
+ st.t.Fatalf("Error while expecting an RST frame: %v", err)
+ }
+ gaf, ok := f.(*http2.GoAwayFrame)
+ if !ok {
+ st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f)
+ }
+ if gaf.ErrCode != errCode {
+ st.t.Fatalf("expected GOAWAY error code '%v', got '%v'", errCode.String(), gaf.ErrCode.String())
+ }
+ return gaf
+}
+
+func (st *serverTester) wantPing() *http2.PingFrame {
+ f, err := st.readFrame()
+ if err != nil {
+ st.t.Fatalf("Error while expecting an RST frame: %v", err)
+ }
+ pf, ok := f.(*http2.PingFrame)
+ if !ok {
+ st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f)
+ }
+ return pf
+}
+
func (st *serverTester) wantRSTStream(errCode http2.ErrCode) *http2.RSTStreamFrame {
f, err := st.readFrame()
if err != nil {
st.t.Fatalf("Error while expecting an RST frame: %v", err)
}
- sf, ok := f.(*http2.RSTStreamFrame)
+ rf, ok := f.(*http2.RSTStreamFrame)
if !ok {
st.t.Fatalf("got a %T; want *http2.RSTStreamFrame", f)
}
- if sf.ErrCode != errCode {
- st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), sf.ErrCode.String())
+ if rf.ErrCode != errCode {
+ st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), rf.ErrCode.String())
}
- return sf
+ return rf
}
func (st *serverTester) wantSettings() *http2.SettingsFrame {
diff --git a/test/stream_cleanup_test.go b/test/stream_cleanup_test.go
index 83dd685..53298ea 100644
--- a/test/stream_cleanup_test.go
+++ b/test/stream_cleanup_test.go
@@ -46,7 +46,7 @@
return &testpb.Empty{}, nil
},
}
- if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(callRecvMsgSize))), grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
+ if err := ss.Start(nil, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(callRecvMsgSize))), grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
@@ -79,7 +79,7 @@
})
},
}
- if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}, grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
+ if err := ss.Start(nil, grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
@@ -132,6 +132,6 @@
case <-gracefulStopDone:
timer.Stop()
case <-timer.C:
- t.Fatalf("s.GracefulStop() didn't finish without 1 second after the last RPC")
+ t.Fatalf("s.GracefulStop() didn't finish within 1 second after the last RPC")
}
}