http2: add support for graceful shutdown of Server

This adds support for gracefully shutting down the Server, including
sending a GOAWAY frame to clients and shutting down when things are
idle. For now this support is only when integrated with the standard
library's Server.

A future change could export some of this.

Updates golang/go#4674

Change-Id: I78cd4f58ca529bf9d149054f929d9089e7685875
Reviewed-on: https://go-review.googlesource.com/32412
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
diff --git a/http2/server.go b/http2/server.go
index e04fbe0..370e42e 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -213,6 +213,13 @@
 	return nil
 }
 
+// h1ServerShutdownChan if non-nil provides a func to return a channel
+// that will be closed when the provided *http.Server wants to shut
+// down. This is initialized via an init func in net/http (via its
+// mangled name from x/tools/cmd/bundle). This is only used when http2
+// is bundled into std for now.
+var h1ServerShutdownChan func(*http.Server) <-chan struct{}
+
 // ServeConnOpts are options for the Server.ServeConn method.
 type ServeConnOpts struct {
 	// BaseConfig optionally sets the base configuration
@@ -710,6 +717,11 @@
 		sc.idleTimerCh = sc.idleTimer.C
 	}
 
+	var gracefulShutdownCh <-chan struct{}
+	if sc.hs != nil && h1ServerShutdownChan != nil {
+		gracefulShutdownCh = h1ServerShutdownChan(sc.hs)
+	}
+
 	go sc.readFrames() // closed by defer sc.conn.Close above
 
 	settingsTimer := time.NewTimer(firstSettingsTimeout)
@@ -737,6 +749,9 @@
 		case <-settingsTimer.C:
 			sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
 			return
+		case <-gracefulShutdownCh:
+			gracefulShutdownCh = nil
+			sc.goAwayIn(ErrCodeNo, 0)
 		case <-sc.shutdownTimerCh:
 			sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
 			return
@@ -746,6 +761,10 @@
 		case fn := <-sc.testHookCh:
 			fn(loopNum)
 		}
+
+		if sc.inGoAway && sc.curClientStreams == 0 && !sc.needToSendGoAway && !sc.writingFrame {
+			return
+		}
 	}
 }
 
@@ -1018,7 +1037,7 @@
 			sc.startFrameWrite(FrameWriteRequest{write: writeSettingsAck{}})
 			continue
 		}
-		if !sc.inGoAway {
+		if !sc.inGoAway || sc.goAwayCode == ErrCodeNo {
 			if wr, ok := sc.writeSched.Pop(); ok {
 				sc.startFrameWrite(wr)
 				continue
@@ -1036,14 +1055,23 @@
 
 func (sc *serverConn) goAway(code ErrCode) {
 	sc.serveG.check()
+	var forceCloseIn time.Duration
+	if code != ErrCodeNo {
+		forceCloseIn = 250 * time.Millisecond
+	} else {
+		// TODO: configurable
+		forceCloseIn = 1 * time.Second
+	}
+	sc.goAwayIn(code, forceCloseIn)
+}
+
+func (sc *serverConn) goAwayIn(code ErrCode, forceCloseIn time.Duration) {
+	sc.serveG.check()
 	if sc.inGoAway {
 		return
 	}
-	if code != ErrCodeNo {
-		sc.shutDownIn(250 * time.Millisecond)
-	} else {
-		// TODO: configurable
-		sc.shutDownIn(1 * time.Second)
+	if forceCloseIn != 0 {
+		sc.shutDownIn(forceCloseIn)
 	}
 	sc.inGoAway = true
 	sc.needToSendGoAway = true
diff --git a/http2/server_test.go b/http2/server_test.go
index cbc89d4..642e6ca 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -3530,3 +3530,48 @@
 		<-done
 	}
 }
+
+func TestServerGracefulShutdown(t *testing.T) {
+	shutdownCh := make(chan struct{})
+	defer func() { h1ServerShutdownChan = nil }()
+	h1ServerShutdownChan = func(*http.Server) <-chan struct{} { return shutdownCh }
+
+	var st *serverTester
+	handlerDone := make(chan struct{})
+	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		defer close(handlerDone)
+		close(shutdownCh)
+
+		ga := st.wantGoAway()
+		if ga.ErrCode != ErrCodeNo {
+			t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
+		}
+		if ga.LastStreamID != 1 {
+			t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID)
+		}
+
+		w.Header().Set("x-foo", "bar")
+	})
+	defer st.Close()
+
+	st.greet()
+	st.bodylessReq1()
+
+	<-handlerDone
+	hf := st.wantHeaders()
+	goth := st.decodeHeader(hf.HeaderBlockFragment())
+	wanth := [][2]string{
+		{":status", "200"},
+		{"x-foo", "bar"},
+		{"content-type", "text/plain; charset=utf-8"},
+		{"content-length", "0"},
+	}
+	if !reflect.DeepEqual(goth, wanth) {
+		t.Errorf("Got headers %v; want %v", goth, wanth)
+	}
+
+	n, err := st.cc.Read([]byte{0})
+	if n != 0 || err == nil {
+		t.Errorf("Read = %v, %v; want 0, non-nil", n, err)
+	}
+}