Serve() should not return error on Stop() or GracefulStop() (#1485)
* Serve() should not return error on Stop() or GracefulStop()
* Block Serve() until Stop() or GracefulStop() returns
diff --git a/server.go b/server.go
index 5537fdf..def301a 100644
--- a/server.go
+++ b/server.go
@@ -96,6 +96,11 @@
cv *sync.Cond
m map[string]*service // service name -> service info
events trace.EventLog
+
+ quit chan struct{}
+ done chan struct{}
+ quitOnce sync.Once
+ doneOnce sync.Once
}
type options struct {
@@ -307,6 +312,8 @@
opts: opts,
conns: make(map[io.Closer]bool),
m: make(map[string]*service),
+ quit: make(chan struct{}),
+ done: make(chan struct{}),
}
s.cv = sync.NewCond(&s.mu)
s.ctx, s.cancel = context.WithCancel(context.Background())
@@ -485,6 +492,14 @@
s.mu.Lock()
s.printf("done serving; Accept = %v", err)
s.mu.Unlock()
+
+ // If Stop or GracefulStop is called, block until they are done and return nil
+ select {
+ case <-s.quit:
+ <-s.done
+ return nil
+ default:
+ }
return err
}
tempDelay = 0
@@ -1052,6 +1067,16 @@
// pending RPCs on the client side will get notified by connection
// errors.
func (s *Server) Stop() {
+ s.quitOnce.Do(func() {
+ close(s.quit)
+ })
+
+ defer func() {
+ s.doneOnce.Do(func() {
+ close(s.done)
+ })
+ }()
+
s.mu.Lock()
listeners := s.lis
s.lis = nil
@@ -1081,6 +1106,16 @@
// accepting new connections and RPCs and blocks until all the pending RPCs are
// finished.
func (s *Server) GracefulStop() {
+ s.quitOnce.Do(func() {
+ close(s.quit)
+ })
+
+ defer func() {
+ s.doneOnce.Do(func() {
+ close(s.done)
+ })
+ }()
+
s.mu.Lock()
defer s.mu.Unlock()
if s.conns == nil {
diff --git a/server_test.go b/server_test.go
index 6438b5f..cd2f2c0 100644
--- a/server_test.go
+++ b/server_test.go
@@ -23,6 +23,7 @@
"reflect"
"strings"
"testing"
+ "time"
"google.golang.org/grpc/test/leakcheck"
)
@@ -53,6 +54,27 @@
}
}
+func TestGracefulStop(t *testing.T) {
+ defer leakcheck.Check(t)
+
+ lis, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("failed to create listener: %v", err)
+ }
+
+ server := NewServer()
+ go func() {
+ // make sure Serve() is called
+ time.Sleep(time.Millisecond * 500)
+ server.GracefulStop()
+ }()
+
+ err = server.Serve(lis)
+ if err != nil {
+ t.Fatalf("Serve() returned non-nil error on GracefulStop: %v", err)
+ }
+}
+
func TestGetServiceInfo(t *testing.T) {
defer leakcheck.Check(t)
testSd := ServiceDesc{