http2: add support for graceful shutdown of Server
This adds support for gracefully shutting down the Server, including
sending a GOAWAY frame to clients and shutting down when things are
idle. For now this support is only when integrated with the standard
library's Server.
A future change could export some of this.
Updates golang/go#4674
Change-Id: I78cd4f58ca529bf9d149054f929d9089e7685875
Reviewed-on: https://go-review.googlesource.com/32412
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
diff --git a/http2/server.go b/http2/server.go
index e04fbe0..370e42e 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -213,6 +213,13 @@
return nil
}
+// h1ServerShutdownChan if non-nil provides a func to return a channel
+// that will be closed when the provided *http.Server wants to shut
+// down. This is initialized via an init func in net/http (via its
+// mangled name from x/tools/cmd/bundle). This is only used when http2
+// is bundled into std for now.
+var h1ServerShutdownChan func(*http.Server) <-chan struct{}
+
// ServeConnOpts are options for the Server.ServeConn method.
type ServeConnOpts struct {
// BaseConfig optionally sets the base configuration
@@ -710,6 +717,11 @@
sc.idleTimerCh = sc.idleTimer.C
}
+ var gracefulShutdownCh <-chan struct{}
+ if sc.hs != nil && h1ServerShutdownChan != nil {
+ gracefulShutdownCh = h1ServerShutdownChan(sc.hs)
+ }
+
go sc.readFrames() // closed by defer sc.conn.Close above
settingsTimer := time.NewTimer(firstSettingsTimeout)
@@ -737,6 +749,9 @@
case <-settingsTimer.C:
sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
return
+ case <-gracefulShutdownCh:
+ gracefulShutdownCh = nil
+ sc.goAwayIn(ErrCodeNo, 0)
case <-sc.shutdownTimerCh:
sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
return
@@ -746,6 +761,10 @@
case fn := <-sc.testHookCh:
fn(loopNum)
}
+
+ if sc.inGoAway && sc.curClientStreams == 0 && !sc.needToSendGoAway && !sc.writingFrame {
+ return
+ }
}
}
@@ -1018,7 +1037,7 @@
sc.startFrameWrite(FrameWriteRequest{write: writeSettingsAck{}})
continue
}
- if !sc.inGoAway {
+ if !sc.inGoAway || sc.goAwayCode == ErrCodeNo {
if wr, ok := sc.writeSched.Pop(); ok {
sc.startFrameWrite(wr)
continue
@@ -1036,14 +1055,23 @@
func (sc *serverConn) goAway(code ErrCode) {
sc.serveG.check()
+ var forceCloseIn time.Duration
+ if code != ErrCodeNo {
+ forceCloseIn = 250 * time.Millisecond
+ } else {
+ // TODO: configurable
+ forceCloseIn = 1 * time.Second
+ }
+ sc.goAwayIn(code, forceCloseIn)
+}
+
+func (sc *serverConn) goAwayIn(code ErrCode, forceCloseIn time.Duration) {
+ sc.serveG.check()
if sc.inGoAway {
return
}
- if code != ErrCodeNo {
- sc.shutDownIn(250 * time.Millisecond)
- } else {
- // TODO: configurable
- sc.shutDownIn(1 * time.Second)
+ if forceCloseIn != 0 {
+ sc.shutDownIn(forceCloseIn)
}
sc.inGoAway = true
sc.needToSendGoAway = true
diff --git a/http2/server_test.go b/http2/server_test.go
index cbc89d4..642e6ca 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -3530,3 +3530,48 @@
<-done
}
}
+
+func TestServerGracefulShutdown(t *testing.T) {
+ shutdownCh := make(chan struct{})
+ defer func() { h1ServerShutdownChan = nil }()
+ h1ServerShutdownChan = func(*http.Server) <-chan struct{} { return shutdownCh }
+
+ var st *serverTester
+ handlerDone := make(chan struct{})
+ st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ defer close(handlerDone)
+ close(shutdownCh)
+
+ ga := st.wantGoAway()
+ if ga.ErrCode != ErrCodeNo {
+ t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
+ }
+ if ga.LastStreamID != 1 {
+ t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID)
+ }
+
+ w.Header().Set("x-foo", "bar")
+ })
+ defer st.Close()
+
+ st.greet()
+ st.bodylessReq1()
+
+ <-handlerDone
+ hf := st.wantHeaders()
+ goth := st.decodeHeader(hf.HeaderBlockFragment())
+ wanth := [][2]string{
+ {":status", "200"},
+ {"x-foo", "bar"},
+ {"content-type", "text/plain; charset=utf-8"},
+ {"content-length", "0"},
+ }
+ if !reflect.DeepEqual(goth, wanth) {
+ t.Errorf("Got headers %v; want %v", goth, wanth)
+ }
+
+ n, err := st.cc.Read([]byte{0})
+ if n != 0 || err == nil {
+ t.Errorf("Read = %v, %v; want 0, non-nil", n, err)
+ }
+}