http2: implement support for server push
This makes x/net/http2's ResponseWriter implement the new interface,
http.Pusher. This new interface requires Go 1.8. When compiled against
older versions of Go, the ResponseWriter does not have a Push method.
Fixes golang/go#13443
Change-Id: I8486ffe4bb5562a94270ace21e90e8c9a4653da0
Reviewed-on: https://go-review.googlesource.com/29439
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/go18.go b/http2/go18.go
index c2ae167..fd3d652 100644
--- a/http2/go18.go
+++ b/http2/go18.go
@@ -6,6 +6,21 @@
package http2
-import "crypto/tls"
+import (
+ "crypto/tls"
+ "net/http"
+)
func cloneTLSConfig(c *tls.Config) *tls.Config { return c.Clone() }
+
+var _ http.Pusher = (*responseWriter)(nil)
+
+// Push implements http.Pusher.
+func (w *responseWriter) Push(target string, opts *http.PushOptions) error {
+ internalOpts := pushOptions{}
+ if opts != nil {
+ internalOpts.Method = opts.Method
+ internalOpts.Header = opts.Header
+ }
+ return w.push(target, internalOpts)
+}
diff --git a/http2/http2.go b/http2/http2.go
index 68afcc4..b6b0f9a 100644
--- a/http2/http2.go
+++ b/http2/http2.go
@@ -78,13 +78,23 @@
type streamState int
+// HTTP/2 stream states.
+//
+// See http://tools.ietf.org/html/rfc7540#section-5.1.
+//
+// For simplicity, the server code merges "reserved (local)" into
+// "half-closed (remote)". This is one less state transition to track.
+// The only downside is that we send PUSH_PROMISEs slightly less
+// liberally than allowable. More discussion here:
+// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html
+//
+// "reserved (remote)" is omitted since the client code does not
+// support server push.
const (
stateIdle streamState = iota
stateOpen
stateHalfClosedLocal
stateHalfClosedRemote
- stateResvLocal
- stateResvRemote
stateClosed
)
@@ -93,8 +103,6 @@
stateOpen: "Open",
stateHalfClosedLocal: "HalfClosedLocal",
stateHalfClosedRemote: "HalfClosedRemote",
- stateResvLocal: "ResvLocal",
- stateResvRemote: "ResvRemote",
stateClosed: "Closed",
}
diff --git a/http2/server.go b/http2/server.go
index 50bc112..8fed0b1 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -33,6 +33,7 @@
"fmt"
"io"
"log"
+ "math"
"net"
"net/http"
"net/textproto"
@@ -262,9 +263,11 @@
streams: make(map[uint32]*stream),
readFrameCh: make(chan readFrameResult),
wantWriteFrameCh: make(chan FrameWriteRequest, 8),
+ wantStartPushCh: make(chan startPushRequest, 8),
wroteFrameCh: make(chan frameWriteResult, 1), // buffered; one send in writeFrameAsync
bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way
doneServing: make(chan struct{}),
+ clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value"
advMaxStreams: s.maxConcurrentStreams(),
initialWindowSize: initialWindowSize,
maxFrameSize: initialMaxFrameSize,
@@ -361,6 +364,7 @@
doneServing chan struct{} // closed when serverConn.serve ends
readFrameCh chan readFrameResult // written by serverConn.readFrames
wantWriteFrameCh chan FrameWriteRequest // from handlers -> serve
+ wantStartPushCh chan startPushRequest // from handlers -> serve
wroteFrameCh chan frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes
bodyReadCh chan bodyReadMsg // from handlers -> serve
testHookCh chan func(int) // code to run on the serve loop
@@ -378,8 +382,10 @@
unackedSettings int // how many SETTINGS have we sent without ACKs?
clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit)
advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client
- curOpenStreams uint32 // client's number of open streams
- maxStreamID uint32 // max ever seen
+ curClientStreams uint32 // number of open streams initiated by the client
+ curPushedStreams uint32 // number of open streams initiated by server push
+ maxStreamID uint32 // max ever seen from client
+ maxPushPromiseID uint32 // ID of the last push promise, or 0 if there have been no pushes
streams map[uint32]*stream
initialWindowSize int32
maxFrameSize int32
@@ -457,7 +463,7 @@
func (sc *serverConn) state(streamID uint32) (streamState, *stream) {
sc.serveG.check()
- // http://http2.github.io/http2-spec/#rfc.section.5.1
+ // http://tools.ietf.org/html/rfc7540#section-5.1
if st, ok := sc.streams[streamID]; ok {
return st.state, st
}
@@ -701,6 +707,8 @@
select {
case wr := <-sc.wantWriteFrameCh:
sc.writeFrame(wr)
+ case spr := <-sc.wantStartPushCh:
+ sc.startPush(spr)
case res := <-sc.wroteFrameCh:
sc.wroteFrame(res)
case res := <-sc.readFrameCh:
@@ -881,6 +889,16 @@
panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wr))
}
}
+ if wpp, ok := wr.write.(*writePushPromise); ok {
+ var err error
+ wpp.promisedID, err = wpp.allocatePromisedID()
+ if err != nil {
+ if wr.done != nil {
+ wr.done <- err
+ }
+ return
+ }
+ }
sc.writingFrame = true
sc.needsFrameFlush = true
@@ -1204,8 +1222,12 @@
panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state))
}
st.state = stateClosed
- sc.curOpenStreams--
- if sc.curOpenStreams == 0 {
+ if st.isPushed() {
+ sc.curPushedStreams--
+ } else {
+ sc.curClientStreams--
+ }
+ if sc.curClientStreams+sc.curPushedStreams == 0 {
sc.setConnState(http.StateIdle)
}
delete(sc.streams, st.id)
@@ -1388,6 +1410,11 @@
return nil
}
+// isPushed reports whether the stream is server-initiated.
+func (st *stream) isPushed() bool {
+ return st.id%2 == 0
+}
+
// endStream closes a Request.Body's pipe. It is called when a DATA
// frame says a request body is over (or after trailers).
func (st *stream) endStream() {
@@ -1417,12 +1444,12 @@
func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
sc.serveG.check()
- id := f.Header().StreamID
+ id := f.StreamID
if sc.inGoAway {
// Ignore.
return nil
}
- // http://http2.github.io/http2-spec/#rfc.section.5.1.1
+ // http://tools.ietf.org/html/rfc7540#section-5.1.1
// Streams initiated by a client MUST use odd-numbered stream
// identifiers. [...] An endpoint that receives an unexpected
// stream identifier MUST respond with a connection error
@@ -1434,8 +1461,7 @@
// send a trailer for an open one. If we already have a stream
// open, let it process its own HEADERS frame (trailers at this
// point, if it's valid).
- st := sc.streams[f.Header().StreamID]
- if st != nil {
+ if st := sc.streams[f.StreamID]; st != nil {
return st.processTrailerHeaders(f)
}
@@ -1453,48 +1479,31 @@
sc.idleTimer.Stop()
}
- ctx, cancelCtx := contextWithCancel(sc.baseCtx)
- st = &stream{
- sc: sc,
- id: id,
- state: stateOpen,
- ctx: ctx,
- cancelCtx: cancelCtx,
- }
- if f.StreamEnded() {
- st.state = stateHalfClosedRemote
- }
- st.cw.Init()
-
- st.flow.conn = &sc.flow // link to conn-level counter
- st.flow.add(sc.initialWindowSize)
- st.inflow.conn = &sc.inflow // link to conn-level counter
- st.inflow.add(initialWindowSize) // TODO: update this when we send a higher initial window size in the initial settings
-
- sc.streams[id] = st
- sc.writeSched.OpenStream(st.id, OpenStreamOptions{})
- sc.curOpenStreams++
- if sc.curOpenStreams == 1 {
- sc.setConnState(http.StateActive)
- }
- if sc.curOpenStreams > sc.advMaxStreams {
- // "Endpoints MUST NOT exceed the limit set by their
- // peer. An endpoint that receives a HEADERS frame
- // that causes their advertised concurrent stream
- // limit to be exceeded MUST treat this as a stream
- // error (Section 5.4.2) of type PROTOCOL_ERROR or
- // REFUSED_STREAM."
+ // http://tools.ietf.org/html/rfc7540#section-5.1.2
+ // [...] Endpoints MUST NOT exceed the limit set by their peer. An
+ // endpoint that receives a HEADERS frame that causes their
+ // advertised concurrent stream limit to be exceeded MUST treat
+ // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR
+ // or REFUSED_STREAM.
+ if sc.curClientStreams+1 > sc.advMaxStreams {
if sc.unackedSettings == 0 {
// They should know better.
- return streamError(st.id, ErrCodeProtocol)
+ return streamError(id, ErrCodeProtocol)
}
// Assume it's a network race, where they just haven't
// received our last SETTINGS update. But actually
// this can't happen yet, because we don't yet provide
// a way for users to adjust server parameters at
// runtime.
- return streamError(st.id, ErrCodeRefusedStream)
+ return streamError(id, ErrCodeRefusedStream)
}
+
+ initialState := stateOpen
+ if f.StreamEnded() {
+ initialState = stateHalfClosedRemote
+ }
+ st := sc.newStream(id, 0, initialState)
+
if f.HasPriority() {
if err := checkPriority(f.StreamID, f.Priority); err != nil {
return err
@@ -1517,7 +1526,7 @@
if f.Truncated {
// Their header list was too long. Send a 431 error.
handler = handleHeaderListTooLong
- } else if err := checkValidHTTP2Request(req); err != nil {
+ } else if err := checkValidHTTP2RequestHeaders(req.Header); err != nil {
handler = new400Handler(err)
}
@@ -1590,21 +1599,56 @@
return nil
}
+func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream {
+ sc.serveG.check()
+ if id == 0 {
+ panic("internal error: cannot create stream with id 0")
+ }
+
+ ctx, cancelCtx := contextWithCancel(sc.baseCtx)
+ st := &stream{
+ sc: sc,
+ id: id,
+ state: state,
+ ctx: ctx,
+ cancelCtx: cancelCtx,
+ }
+ st.cw.Init()
+ st.flow.conn = &sc.flow // link to conn-level counter
+ st.flow.add(sc.initialWindowSize)
+ st.inflow.conn = &sc.inflow // link to conn-level counter
+ st.inflow.add(initialWindowSize) // TODO: update this when we send a higher initial window size in the initial settings
+
+ sc.streams[id] = st
+ sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID})
+ if st.isPushed() {
+ sc.curPushedStreams++
+ } else {
+ sc.curClientStreams++
+ }
+ if sc.curClientStreams+sc.curPushedStreams == 1 {
+ sc.setConnState(http.StateActive)
+ }
+
+ return st
+}
+
func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) {
sc.serveG.check()
- method := f.PseudoValue("method")
- path := f.PseudoValue("path")
- scheme := f.PseudoValue("scheme")
- authority := f.PseudoValue("authority")
+ rp := requestParam{
+ method: f.PseudoValue("method"),
+ scheme: f.PseudoValue("scheme"),
+ authority: f.PseudoValue("authority"),
+ path: f.PseudoValue("path"),
+ }
- isConnect := method == "CONNECT"
+ isConnect := rp.method == "CONNECT"
if isConnect {
- if path != "" || scheme != "" || authority == "" {
+ if rp.path != "" || rp.scheme != "" || rp.authority == "" {
return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
}
- } else if method == "" || path == "" ||
- (scheme != "https" && scheme != "http") {
+ } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") {
// See 8.1.2.6 Malformed Requests and Responses:
//
// Malformed requests or responses that are detected
@@ -1619,36 +1663,64 @@
}
bodyOpen := !f.StreamEnded()
- if method == "HEAD" && bodyOpen {
+ if rp.method == "HEAD" && bodyOpen {
// HEAD requests can't have bodies
return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
}
- var tlsState *tls.ConnectionState // nil if not scheme https
- if scheme == "https" {
+ rp.header = make(http.Header)
+ for _, hf := range f.RegularFields() {
+ rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value)
+ }
+ if rp.authority == "" {
+ rp.authority = rp.header.Get("Host")
+ }
+
+ rw, req, err := sc.newWriterAndRequestNoBody(st, rp)
+ if err != nil {
+ return nil, nil, err
+ }
+ if bodyOpen {
+ st.reqBuf = getRequestBodyBuf()
+ req.Body.(*requestBody).pipe = &pipe{
+ b: &fixedBuffer{buf: st.reqBuf},
+ }
+
+ if vv, ok := rp.header["Content-Length"]; ok {
+ req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
+ } else {
+ req.ContentLength = -1
+ }
+ }
+ return rw, req, nil
+}
+
+type requestParam struct {
+ method string
+ scheme, authority, path string
+ header http.Header
+}
+
+func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*responseWriter, *http.Request, error) {
+ sc.serveG.check()
+
+ var tlsState *tls.ConnectionState // nil if not scheme https
+ if rp.scheme == "https" {
tlsState = sc.tlsState
}
- header := make(http.Header)
- for _, hf := range f.RegularFields() {
- header.Add(sc.canonicalHeader(hf.Name), hf.Value)
- }
-
- if authority == "" {
- authority = header.Get("Host")
- }
- needsContinue := header.Get("Expect") == "100-continue"
+ needsContinue := rp.header.Get("Expect") == "100-continue"
if needsContinue {
- header.Del("Expect")
+ rp.header.Del("Expect")
}
// Merge Cookie headers into one "; "-delimited value.
- if cookies := header["Cookie"]; len(cookies) > 1 {
- header.Set("Cookie", strings.Join(cookies, "; "))
+ if cookies := rp.header["Cookie"]; len(cookies) > 1 {
+ rp.header.Set("Cookie", strings.Join(cookies, "; "))
}
// Setup Trailers
var trailer http.Header
- for _, v := range header["Trailer"] {
+ for _, v := range rp.header["Trailer"] {
for _, key := range strings.Split(v, ",") {
key = http.CanonicalHeaderKey(strings.TrimSpace(key))
switch key {
@@ -1663,53 +1735,42 @@
}
}
}
- delete(header, "Trailer")
+ delete(rp.header, "Trailer")
+
+ var url_ *url.URL
+ var requestURI string
+ if rp.method == "CONNECT" {
+ url_ = &url.URL{Host: rp.authority}
+ requestURI = rp.authority // mimic HTTP/1 server behavior
+ } else {
+ var err error
+ url_, err = url.ParseRequestURI(rp.path)
+ if err != nil {
+ return nil, nil, streamError(st.id, ErrCodeProtocol)
+ }
+ requestURI = rp.path
+ }
body := &requestBody{
conn: sc,
stream: st,
needsContinue: needsContinue,
}
- var url_ *url.URL
- var requestURI string
- if isConnect {
- url_ = &url.URL{Host: authority}
- requestURI = authority // mimic HTTP/1 server behavior
- } else {
- var err error
- url_, err = url.ParseRequestURI(path)
- if err != nil {
- return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
- }
- requestURI = path
- }
req := &http.Request{
- Method: method,
+ Method: rp.method,
URL: url_,
RemoteAddr: sc.remoteAddrStr,
- Header: header,
+ Header: rp.header,
RequestURI: requestURI,
Proto: "HTTP/2.0",
ProtoMajor: 2,
ProtoMinor: 0,
TLS: tlsState,
- Host: authority,
+ Host: rp.authority,
Body: body,
Trailer: trailer,
}
req = requestWithContext(req, st.ctx)
- if bodyOpen {
- st.reqBuf = getRequestBodyBuf()
- body.pipe = &pipe{
- b: &fixedBuffer{buf: st.reqBuf},
- }
-
- if vv, ok := header["Content-Length"]; ok {
- req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
- } else {
- req.ContentLength = -1
- }
- }
rws := responseWriterStatePool.Get().(*responseWriterState)
bwSave := rws.bw
@@ -2267,6 +2328,194 @@
responseWriterStatePool.Put(rws)
}
+// Push errors.
+var (
+ ErrRecursivePush = errors.New("http2: recursive push not allowed")
+ ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS")
+)
+
+// pushOptions is the internal version of http.PushOptions, which we
+// cannot include here because it's only defined in Go 1.8 and later.
+type pushOptions struct {
+ Method string
+ Header http.Header
+}
+
+func (w *responseWriter) push(target string, opts pushOptions) error {
+ st := w.rws.stream
+ sc := st.sc
+ sc.serveG.checkNotOn()
+
+ // No recursive pushes: "PUSH_PROMISE frames MUST only be sent on a peer-initiated stream."
+ // http://tools.ietf.org/html/rfc7540#section-6.6
+ if st.isPushed() {
+ return ErrRecursivePush
+ }
+
+ // Default options.
+ if opts.Method == "" {
+ opts.Method = "GET"
+ }
+ if opts.Header == nil {
+ opts.Header = http.Header{}
+ }
+ wantScheme := "http"
+ if w.rws.req.TLS != nil {
+ wantScheme = "https"
+ }
+
+ // Validate the request.
+ u, err := url.Parse(target)
+ if err != nil {
+ return err
+ }
+ if u.Scheme == "" {
+ if !strings.HasPrefix(target, "/") {
+ return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target)
+ }
+ u.Scheme = wantScheme
+ u.Host = w.rws.req.Host
+ } else {
+ if u.Scheme != wantScheme {
+ return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme)
+ }
+ if u.Host == "" {
+ return errors.New("URL must have a host")
+ }
+ }
+ for k := range opts.Header {
+ if strings.HasPrefix(k, ":") {
+ return fmt.Errorf("promised request headers cannot include psuedo header %q", k)
+ }
+ // These headers are meaningful only if the request has a body,
+ // but PUSH_PROMISE requests cannot have a body.
+ // http://tools.ietf.org/html/rfc7540#section-8.2
+ // Also disallow Host, since the promised URL must be absolute.
+ switch strings.ToLower(k) {
+ case "content-length", "content-encoding", "trailer", "te", "expect", "host":
+ return fmt.Errorf("promised request headers cannot include %q", k)
+ }
+ }
+ if err := checkValidHTTP2RequestHeaders(opts.Header); err != nil {
+ return err
+ }
+
+ // The RFC effectively limits promised requests to GET and HEAD:
+ // "Promised requests MUST be cacheable [GET, HEAD, or POST], and MUST be safe [GET or HEAD]"
+ // http://tools.ietf.org/html/rfc7540#section-8.2
+ if opts.Method != "GET" && opts.Method != "HEAD" {
+ return fmt.Errorf("method %q must be GET or HEAD", opts.Method)
+ }
+
+ msg := startPushRequest{
+ parent: st,
+ method: opts.Method,
+ url: u,
+ header: cloneHeader(opts.Header),
+ done: errChanPool.Get().(chan error),
+ }
+
+ select {
+ case <-sc.doneServing:
+ return errClientDisconnected
+ case <-st.cw:
+ return errStreamClosed
+ case sc.wantStartPushCh <- msg:
+ }
+
+ select {
+ case <-sc.doneServing:
+ return errClientDisconnected
+ case <-st.cw:
+ return errStreamClosed
+ case err := <-msg.done:
+ errChanPool.Put(msg.done)
+ return err
+ }
+}
+
+type startPushRequest struct {
+ parent *stream
+ method string
+ url *url.URL
+ header http.Header
+ done chan error
+}
+
+func (sc *serverConn) startPush(msg startPushRequest) {
+ sc.serveG.check()
+
+ // http://tools.ietf.org/html/rfc7540#section-6.6.
+ // PUSH_PROMISE frames MUST only be sent on a peer-initiated stream that
+ // is in either the "open" or "half-closed (remote)" state.
+ if msg.parent.state != stateOpen && msg.parent.state != stateHalfClosedRemote {
+ // responseWriter.Push checks that the stream is peer-initiaed.
+ msg.done <- errStreamClosed
+ return
+ }
+
+ // http://tools.ietf.org/html/rfc7540#section-6.6.
+ if !sc.pushEnabled {
+ msg.done <- http.ErrNotSupported
+ return
+ }
+
+ // PUSH_PROMISE frames must be sent in increasing order by stream ID, so
+ // we allocate an ID for the promised stream lazily, when the PUSH_PROMISE
+ // is written. Once the ID is allocated, we start the request handler.
+ allocatePromisedID := func() (uint32, error) {
+ sc.serveG.check()
+
+ // Check this again, just in case. Technically, we might have received
+ // an updated SETTINGS by the time we got around to writing this frame.
+ if !sc.pushEnabled {
+ return 0, http.ErrNotSupported
+ }
+ // http://tools.ietf.org/html/rfc7540#section-6.5.2.
+ if sc.curPushedStreams+1 > sc.clientMaxStreams {
+ return 0, ErrPushLimitReached
+ }
+
+ // http://tools.ietf.org/html/rfc7540#section-5.1.1.
+ // Streams initiated by the server MUST use even-numbered identifiers.
+ sc.maxPushPromiseID += 2
+ promisedID := sc.maxPushPromiseID
+
+ // http://tools.ietf.org/html/rfc7540#section-8.2.
+ // Strictly speaking, the new stream should start in "reserved (local)", then
+ // transition to "half closed (remote)" after sending the initial HEADERS, but
+ // we start in "half closed (remote)" for simplicity.
+ // See further comments at the definition of stateHalfClosedRemote.
+ promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote)
+ rw, req, err := sc.newWriterAndRequestNoBody(promised, requestParam{
+ method: msg.method,
+ scheme: msg.url.Scheme,
+ authority: msg.url.Host,
+ path: msg.url.RequestURI(),
+ header: msg.header,
+ })
+ if err != nil {
+ // Should not happen, since we've already validated msg.url.
+ panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err))
+ }
+
+ go sc.runHandler(rw, req, sc.handler.ServeHTTP)
+ return promisedID, nil
+ }
+
+ sc.writeFrame(FrameWriteRequest{
+ write: &writePushPromise{
+ streamID: msg.parent.id,
+ method: msg.method,
+ url: msg.url,
+ h: msg.header,
+ allocatePromisedID: allocatePromisedID,
+ },
+ stream: msg.parent,
+ done: msg.done,
+ })
+}
+
// foreachHeaderElement splits v according to the "#rule" construction
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
func foreachHeaderElement(v string, fn func(string)) {
@@ -2294,16 +2543,16 @@
"Upgrade",
}
-// checkValidHTTP2Request checks whether req is a valid HTTP/2 request,
+// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request,
// per RFC 7540 Section 8.1.2.2.
// The returned error is reported to users.
-func checkValidHTTP2Request(req *http.Request) error {
- for _, h := range connHeaders {
- if _, ok := req.Header[h]; ok {
- return fmt.Errorf("request header %q is not valid in HTTP/2", h)
+func checkValidHTTP2RequestHeaders(h http.Header) error {
+ for _, k := range connHeaders {
+ if _, ok := h[k]; ok {
+ return fmt.Errorf("request header %q is not valid in HTTP/2", k)
}
}
- te := req.Header["Te"]
+ te := h["Te"]
if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) {
return errors.New(`request header "TE" may only be "trailers" in HTTP/2`)
}
diff --git a/http2/server_push_test.go b/http2/server_push_test.go
new file mode 100644
index 0000000..5668839
--- /dev/null
+++ b/http2/server_push_test.go
@@ -0,0 +1,383 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.8
+
+package http2
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "net/http"
+ "reflect"
+ "strconv"
+ "testing"
+ "time"
+)
+
+func TestServer_Push_Success(t *testing.T) {
+ const (
+ mainBody = "<html>index page</html>"
+ pushedBody = "<html>pushed page</html>"
+ userAgent = "testagent"
+ cookie = "testcookie"
+ )
+
+ var stURL string
+ checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error {
+ if got, want := r.Method, wantMethod; got != want {
+ return fmt.Errorf("promised Req.Method=%q, want %q", got, want)
+ }
+ if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) {
+ return fmt.Errorf("promised Req.Header=%q, want %q", got, want)
+ }
+ if got, want := "https://"+r.Host, stURL; got != want {
+ return fmt.Errorf("promised Req.Host=%q, want %q", got, want)
+ }
+ if r.Body == nil {
+ return fmt.Errorf("nil Body")
+ }
+ if buf, err := ioutil.ReadAll(r.Body); err != nil || len(buf) != 0 {
+ return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err)
+ }
+ return nil
+ }
+
+ errc := make(chan error, 3)
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.RequestURI() {
+ case "/":
+ // Push "/pushed?get" as a GET request, using an absolute URL.
+ opt := &http.PushOptions{
+ Header: http.Header{
+ "User-Agent": {userAgent},
+ },
+ }
+ if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil {
+ errc <- fmt.Errorf("error pushing /pushed?get: %v", err)
+ return
+ }
+ // Push "/pushed?head" as a HEAD request, using a path.
+ opt = &http.PushOptions{
+ Method: "HEAD",
+ Header: http.Header{
+ "User-Agent": {userAgent},
+ "Cookie": {cookie},
+ },
+ }
+ if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil {
+ errc <- fmt.Errorf("error pushing /pushed?head: %v", err)
+ return
+ }
+ w.Header().Set("Content-Type", "text/html")
+ w.Header().Set("Content-Length", strconv.Itoa(len(mainBody)))
+ w.WriteHeader(200)
+ io.WriteString(w, mainBody)
+ errc <- nil
+
+ case "/pushed?get":
+ wantH := http.Header{}
+ wantH.Set("User-Agent", userAgent)
+ if err := checkPromisedReq(r, "GET", wantH); err != nil {
+ errc <- fmt.Errorf("/pushed?get: %v", err)
+ return
+ }
+ w.Header().Set("Content-Type", "text/html")
+ w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody)))
+ w.WriteHeader(200)
+ io.WriteString(w, pushedBody)
+ errc <- nil
+
+ case "/pushed?head":
+ wantH := http.Header{}
+ wantH.Set("User-Agent", userAgent)
+ wantH.Set("Cookie", cookie)
+ if err := checkPromisedReq(r, "HEAD", wantH); err != nil {
+ errc <- fmt.Errorf("/pushed?head: %v", err)
+ return
+ }
+ w.WriteHeader(204)
+ errc <- nil
+
+ default:
+ errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
+ }
+ })
+ stURL = st.ts.URL
+
+ // Send one request, which should push two responses.
+ st.greet()
+ getSlash(st)
+ for k := 0; k < 3; k++ {
+ select {
+ case <-time.After(2 * time.Second):
+ t.Errorf("timeout waiting for handler %d to finish", k)
+ case err := <-errc:
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ }
+
+ checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error {
+ pp, ok := f.(*PushPromiseFrame)
+ if !ok {
+ return fmt.Errorf("got a %T; want *PushPromiseFrame", f)
+ }
+ if !pp.HeadersEnded() {
+ return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame")
+ }
+ if got, want := pp.PromiseID, promiseID; got != want {
+ return fmt.Errorf("got PromiseID %v; want %v", got, want)
+ }
+ gotH := st.decodeHeader(pp.HeaderBlockFragment())
+ if !reflect.DeepEqual(gotH, wantH) {
+ return fmt.Errorf("got promised headers %v; want %v", gotH, wantH)
+ }
+ return nil
+ }
+ checkHeaders := func(f Frame, wantH [][2]string) error {
+ hf, ok := f.(*HeadersFrame)
+ if !ok {
+ return fmt.Errorf("got a %T; want *HeadersFrame", f)
+ }
+ gotH := st.decodeHeader(hf.HeaderBlockFragment())
+ if !reflect.DeepEqual(gotH, wantH) {
+ return fmt.Errorf("got response headers %v; want %v", gotH, wantH)
+ }
+ return nil
+ }
+ checkData := func(f Frame, wantData string) error {
+ df, ok := f.(*DataFrame)
+ if !ok {
+ return fmt.Errorf("got a %T; want *DataFrame", f)
+ }
+ if gotData := string(df.Data()); gotData != wantData {
+ return fmt.Errorf("got response data %q; want %q", gotData, wantData)
+ }
+ return nil
+ }
+
+ // Stream 1 has 2 PUSH_PROMISE + HEADERS + DATA
+ // Stream 2 has HEADERS + DATA
+ // Stream 4 has HEADERS
+ expected := map[uint32][]func(Frame) error{
+ 1: {
+ func(f Frame) error {
+ return checkPushPromise(f, 2, [][2]string{
+ {":method", "GET"},
+ {":scheme", "https"},
+ {":authority", st.ts.Listener.Addr().String()},
+ {":path", "/pushed?get"},
+ {"user-agent", userAgent},
+ })
+ },
+ func(f Frame) error {
+ return checkPushPromise(f, 4, [][2]string{
+ {":method", "HEAD"},
+ {":scheme", "https"},
+ {":authority", st.ts.Listener.Addr().String()},
+ {":path", "/pushed?head"},
+ {"cookie", cookie},
+ {"user-agent", userAgent},
+ })
+ },
+ func(f Frame) error {
+ return checkHeaders(f, [][2]string{
+ {":status", "200"},
+ {"content-type", "text/html"},
+ {"content-length", strconv.Itoa(len(mainBody))},
+ })
+ },
+ func(f Frame) error {
+ return checkData(f, mainBody)
+ },
+ },
+ 2: {
+ func(f Frame) error {
+ return checkHeaders(f, [][2]string{
+ {":status", "200"},
+ {"content-type", "text/html"},
+ {"content-length", strconv.Itoa(len(pushedBody))},
+ })
+ },
+ func(f Frame) error {
+ return checkData(f, pushedBody)
+ },
+ },
+ 4: {
+ func(f Frame) error {
+ return checkHeaders(f, [][2]string{
+ {":status", "204"},
+ })
+ },
+ },
+ }
+
+ consumed := map[uint32]int{}
+ for k := 0; len(expected) > 0; k++ {
+ f, err := st.readFrame()
+ if err != nil {
+ for id, left := range expected {
+ t.Errorf("stream %d: missing %d frames", id, len(left))
+ }
+ t.Fatalf("readFrame %d: %v", k, err)
+ }
+ id := f.Header().StreamID
+ label := fmt.Sprintf("stream %d, frame %d", id, consumed[id])
+ if len(expected[id]) == 0 {
+ t.Fatalf("%s: unexpected frame %#+v", label, f)
+ }
+ check := expected[id][0]
+ expected[id] = expected[id][1:]
+ if len(expected[id]) == 0 {
+ delete(expected, id)
+ }
+ if err := check(f); err != nil {
+ t.Fatalf("%s: %v", label, err)
+ }
+ consumed[id]++
+ }
+}
+
+func TestServer_Push_RejectRecursivePush(t *testing.T) {
+ // Expect two requests, but might get three if there's a bug and the second push succeeds.
+ errc := make(chan error, 3)
+ handler := func(w http.ResponseWriter, r *http.Request) error {
+ baseURL := "https://" + r.Host
+ switch r.URL.Path {
+ case "/":
+ if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil {
+ return fmt.Errorf("first Push()=%v, want nil", err)
+ }
+ return nil
+
+ case "/push1":
+ if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want {
+ return fmt.Errorf("Push()=%v, want %v", got, want)
+ }
+ return nil
+
+ default:
+ return fmt.Errorf("unexpected path: %q", r.URL.Path)
+ }
+ }
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ errc <- handler(w, r)
+ })
+ defer st.Close()
+ st.greet()
+ getSlash(st)
+ if err := <-errc; err != nil {
+ t.Errorf("First request failed: %v", err)
+ }
+ if err := <-errc; err != nil {
+ t.Errorf("Second request failed: %v", err)
+ }
+}
+
+func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
+ // Expect one request, but might get two if there's a bug and the push succeeds.
+ errc := make(chan error, 2)
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ errc <- doPush(w.(http.Pusher), r)
+ })
+ defer st.Close()
+ st.greet()
+ if err := st.fr.WriteSettings(settings...); err != nil {
+ st.t.Fatalf("WriteSettings: %v", err)
+ }
+ st.wantSettingsAck()
+ getSlash(st)
+ if err := <-errc; err != nil {
+ t.Error(err)
+ }
+ // Should not get a PUSH_PROMISE frame.
+ hf := st.wantHeaders()
+ if !hf.StreamEnded() {
+ t.Error("stream should end after headers")
+ }
+}
+
+func TestServer_Push_RejectIfDisabled(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
+ return fmt.Errorf("Push()=%v, want %v", got, want)
+ }
+ return nil
+ },
+ Setting{SettingEnablePush, 0})
+}
+
+func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want {
+ return fmt.Errorf("Push()=%v, want %v", got, want)
+ }
+ return nil
+ },
+ Setting{SettingMaxConcurrentStreams, 0})
+}
+
+func TestServer_Push_RejectWrongScheme(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil {
+ return errors.New("Push() should have failed (push target URL is http)")
+ }
+ return nil
+ })
+}
+
+func TestServer_Push_RejectMissingHost(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ if err := p.Push("https:pushed", nil); err == nil {
+ return errors.New("Push() should have failed (push target URL missing host)")
+ }
+ return nil
+ })
+}
+
+func TestServer_Push_RejectRelativePath(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ if err := p.Push("../test", nil); err == nil {
+ return errors.New("Push() should have failed (push target is a relative path)")
+ }
+ return nil
+ })
+}
+
+func TestServer_Push_RejectForbiddenMethod(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil {
+ return errors.New("Push() should have failed (cannot promise a POST)")
+ }
+ return nil
+ })
+}
+
+func TestServer_Push_RejectForbiddenHeader(t *testing.T) {
+ testServer_Push_RejectSingleRequest(t,
+ func(p http.Pusher, r *http.Request) error {
+ header := http.Header{
+ "Content-Length": {"10"},
+ "Content-Encoding": {"gzip"},
+ "Trailer": {"Foo"},
+ "Te": {"trailers"},
+ "Host": {"test.com"},
+ ":authority": {"test.com"},
+ }
+ if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil {
+ return errors.New("Push() should have failed (forbidden headers)")
+ }
+ return nil
+ })
+}
diff --git a/http2/server_test.go b/http2/server_test.go
index 013abe1..cbc89d4 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -287,37 +287,42 @@
// encodeHeader encodes headers and returns their HPACK bytes. headers
// must contain an even number of key/value pairs. There may be
// multiple pairs for keys (e.g. "cookie"). The :method, :path, and
-// :scheme headers default to GET, / and https.
+// :scheme headers default to GET, / and https. The :authority header
+// defaults to st.ts.Listener.Addr().
func (st *serverTester) encodeHeader(headers ...string) []byte {
if len(headers)%2 == 1 {
panic("odd number of kv args")
}
st.headerBuf.Reset()
+ defaultAuthority := st.ts.Listener.Addr().String()
if len(headers) == 0 {
// Fast path, mostly for benchmarks, so test code doesn't pollute
// profiles when we're looking to improve server allocations.
st.encodeHeaderField(":method", "GET")
- st.encodeHeaderField(":path", "/")
st.encodeHeaderField(":scheme", "https")
+ st.encodeHeaderField(":authority", defaultAuthority)
+ st.encodeHeaderField(":path", "/")
return st.headerBuf.Bytes()
}
if len(headers) == 2 && headers[0] == ":method" {
// Another fast path for benchmarks.
st.encodeHeaderField(":method", headers[1])
- st.encodeHeaderField(":path", "/")
st.encodeHeaderField(":scheme", "https")
+ st.encodeHeaderField(":authority", defaultAuthority)
+ st.encodeHeaderField(":path", "/")
return st.headerBuf.Bytes()
}
pseudoCount := map[string]int{}
- keys := []string{":method", ":path", ":scheme"}
+ keys := []string{":method", ":scheme", ":authority", ":path"}
vals := map[string][]string{
- ":method": {"GET"},
- ":path": {"/"},
- ":scheme": {"https"},
+ ":method": {"GET"},
+ ":scheme": {"https"},
+ ":authority": {defaultAuthority},
+ ":path": {"/"},
}
for len(headers) > 0 {
k, v := headers[0], headers[1]
@@ -512,7 +517,18 @@
if !sf.Header().Flags.Has(FlagSettingsAck) {
st.t.Fatal("Settings Frame didn't have ACK set")
}
+}
+func (st *serverTester) wantPushPromise() *PushPromiseFrame {
+ f, err := st.readFrame()
+ if err != nil {
+ st.t.Fatal(err)
+ }
+ ppf, ok := f.(*PushPromiseFrame)
+ if !ok {
+ st.t.Fatalf("Wanted PushPromise, received %T", ppf)
+ }
+ return ppf
}
func TestServer(t *testing.T) {
@@ -767,7 +783,7 @@
testServerRequest(t, func(st *serverTester) {
st.writeHeaders(HeadersFrameParam{
StreamID: 1, // clients send odd numbers
- BlockFragment: st.encodeHeader("host", host),
+ BlockFragment: st.encodeHeader(":authority", "", "host", host),
EndStream: true,
EndHeaders: true,
})
@@ -3314,40 +3330,40 @@
func TestCheckValidHTTP2Request(t *testing.T) {
tests := []struct {
- req *http.Request
+ h http.Header
want error
}{
{
- req: &http.Request{Header: http.Header{"Te": {"trailers"}}},
+ h: http.Header{"Te": {"trailers"}},
want: nil,
},
{
- req: &http.Request{Header: http.Header{"Te": {"trailers", "bogus"}}},
+ h: http.Header{"Te": {"trailers", "bogus"}},
want: errors.New(`request header "TE" may only be "trailers" in HTTP/2`),
},
{
- req: &http.Request{Header: http.Header{"Foo": {""}}},
+ h: http.Header{"Foo": {""}},
want: nil,
},
{
- req: &http.Request{Header: http.Header{"Connection": {""}}},
+ h: http.Header{"Connection": {""}},
want: errors.New(`request header "Connection" is not valid in HTTP/2`),
},
{
- req: &http.Request{Header: http.Header{"Proxy-Connection": {""}}},
+ h: http.Header{"Proxy-Connection": {""}},
want: errors.New(`request header "Proxy-Connection" is not valid in HTTP/2`),
},
{
- req: &http.Request{Header: http.Header{"Keep-Alive": {""}}},
+ h: http.Header{"Keep-Alive": {""}},
want: errors.New(`request header "Keep-Alive" is not valid in HTTP/2`),
},
{
- req: &http.Request{Header: http.Header{"Upgrade": {""}}},
+ h: http.Header{"Upgrade": {""}},
want: errors.New(`request header "Upgrade" is not valid in HTTP/2`),
},
}
for i, tt := range tests {
- got := checkValidHTTP2Request(tt.req)
+ got := checkValidHTTP2RequestHeaders(tt.h)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("%d. checkValidHTTP2Request = %v; want %v", i, got, tt.want)
}
diff --git a/http2/write.go b/http2/write.go
index a45d6de..1c135fd 100644
--- a/http2/write.go
+++ b/http2/write.go
@@ -9,6 +9,7 @@
"fmt"
"log"
"net/http"
+ "net/url"
"time"
"golang.org/x/net/http2/hpack"
@@ -150,6 +151,33 @@
func (writeSettingsAck) staysWithinBuffer(max int) bool { return frameHeaderLen <= max }
+// splitHeaderBlock splits headerBlock into fragments so that each fragment fits
+// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true
+// for the first/last fragment, respectively.
+func splitHeaderBlock(ctx writeContext, headerBlock []byte, fn func(ctx writeContext, frag []byte, firstFrag, lastFrag bool) error) error {
+ // For now we're lazy and just pick the minimum MAX_FRAME_SIZE
+ // that all peers must support (16KB). Later we could care
+ // more and send larger frames if the peer advertised it, but
+ // there's little point. Most headers are small anyway (so we
+ // generally won't have CONTINUATION frames), and extra frames
+ // only waste 9 bytes anyway.
+ const maxFrameSize = 16384
+
+ first := true
+ for len(headerBlock) > 0 {
+ frag := headerBlock
+ if len(frag) > maxFrameSize {
+ frag = frag[:maxFrameSize]
+ }
+ headerBlock = headerBlock[len(frag):]
+ if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil {
+ return err
+ }
+ first = false
+ }
+ return nil
+}
+
// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames
// for HTTP response headers or trailers from a server handler.
type writeResHeaders struct {
@@ -207,39 +235,69 @@
panic("unexpected empty hpack")
}
- // For now we're lazy and just pick the minimum MAX_FRAME_SIZE
- // that all peers must support (16KB). Later we could care
- // more and send larger frames if the peer advertised it, but
- // there's little point. Most headers are small anyway (so we
- // generally won't have CONTINUATION frames), and extra frames
- // only waste 9 bytes anyway.
- const maxFrameSize = 16384
+ return splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock)
+}
- first := true
- for len(headerBlock) > 0 {
- frag := headerBlock
- if len(frag) > maxFrameSize {
- frag = frag[:maxFrameSize]
- }
- headerBlock = headerBlock[len(frag):]
- endHeaders := len(headerBlock) == 0
- var err error
- if first {
- first = false
- err = ctx.Framer().WriteHeaders(HeadersFrameParam{
- StreamID: w.streamID,
- BlockFragment: frag,
- EndStream: w.endStream,
- EndHeaders: endHeaders,
- })
- } else {
- err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag)
- }
- if err != nil {
- return err
- }
+func (w *writeResHeaders) writeHeaderBlock(ctx writeContext, frag []byte, firstFrag, lastFrag bool) error {
+ if firstFrag {
+ return ctx.Framer().WriteHeaders(HeadersFrameParam{
+ StreamID: w.streamID,
+ BlockFragment: frag,
+ EndStream: w.endStream,
+ EndHeaders: lastFrag,
+ })
+ } else {
+ return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag)
}
- return nil
+}
+
+// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames.
+type writePushPromise struct {
+ streamID uint32 // pusher stream
+ method string // for :method
+ url *url.URL // for :scheme, :authority, :path
+ h http.Header
+
+ // Creates an ID for a pushed stream. This runs on serveG just before
+ // the frame is written. The returned ID is copied to promisedID.
+ allocatePromisedID func() (uint32, error)
+ promisedID uint32
+}
+
+func (w *writePushPromise) staysWithinBuffer(max int) bool {
+ // TODO: see writeResHeaders.staysWithinBuffer
+ return false
+}
+
+func (w *writePushPromise) writeFrame(ctx writeContext) error {
+ enc, buf := ctx.HeaderEncoder()
+ buf.Reset()
+
+ encKV(enc, ":method", w.method)
+ encKV(enc, ":scheme", w.url.Scheme)
+ encKV(enc, ":authority", w.url.Host)
+ encKV(enc, ":path", w.url.RequestURI())
+ encodeHeaders(enc, w.h, nil)
+
+ headerBlock := buf.Bytes()
+ if len(headerBlock) == 0 {
+ panic("unexpected empty hpack")
+ }
+
+ return splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock)
+}
+
+func (w *writePushPromise) writeHeaderBlock(ctx writeContext, frag []byte, firstFrag, lastFrag bool) error {
+ if firstFrag {
+ return ctx.Framer().WritePushPromise(PushPromiseParam{
+ StreamID: w.streamID,
+ PromiseID: w.promisedID,
+ BlockFragment: frag,
+ EndHeaders: lastFrag,
+ })
+ } else {
+ return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag)
+ }
}
type write100ContinueHeadersFrame struct {
@@ -274,6 +332,8 @@
return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n)
}
+// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k])
+// is encoded only only if k is in keys.
func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) {
if keys == nil {
sorter := sorterPool.Get().(*sorter)