Use the same hpack encoder on a transport and share it between RPCs. (#1536)
diff --git a/stats/stats.go b/stats/stats.go
index b64c429..d5aa2f7 100644
--- a/stats/stats.go
+++ b/stats/stats.go
@@ -135,8 +135,6 @@
type OutHeader struct {
// Client is true if this OutHeader is from client side.
Client bool
- // WireLength is the wire length of header.
- WireLength int
// The following fields are valid only if Client is true.
// FullMethod is the full RPC method string, i.e., /package.service/method.
diff --git a/stats/stats_test.go b/stats/stats_test.go
index 8865d3f..d66485f 100644
--- a/stats/stats_test.go
+++ b/stats/stats_test.go
@@ -444,10 +444,6 @@
if d.ctx == nil {
t.Fatalf("d.ctx = nil, want <non-nil>")
}
- // TODO check real length, not just > 0.
- if st.WireLength <= 0 {
- t.Fatalf("st.Lenght = 0, want > 0")
- }
if !d.client {
if st.FullMethod != e.method {
t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
@@ -530,18 +526,13 @@
func checkInTrailer(t *testing.T, d *gotData, e *expectedData) {
var (
ok bool
- st *stats.InTrailer
)
- if st, ok = d.s.(*stats.InTrailer); !ok {
+ if _, ok = d.s.(*stats.InTrailer); !ok {
t.Fatalf("got %T, want InTrailer", d.s)
}
if d.ctx == nil {
t.Fatalf("d.ctx = nil, want <non-nil>")
}
- // TODO check real length, not just > 0.
- if st.WireLength <= 0 {
- t.Fatalf("st.Lenght = 0, want > 0")
- }
}
func checkOutHeader(t *testing.T, d *gotData, e *expectedData) {
@@ -555,10 +546,6 @@
if d.ctx == nil {
t.Fatalf("d.ctx = nil, want <non-nil>")
}
- // TODO check real length, not just > 0.
- if st.WireLength <= 0 {
- t.Fatalf("st.Lenght = 0, want > 0")
- }
if d.client {
if st.FullMethod != e.method {
t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
@@ -642,10 +629,6 @@
if st.Client {
t.Fatalf("st IsClient = true, want false")
}
- // TODO check real length, not just > 0.
- if st.WireLength <= 0 {
- t.Fatalf("st.Lenght = 0, want > 0")
- }
}
func checkEnd(t *testing.T, d *gotData, e *expectedData) {
diff --git a/transport/control.go b/transport/control.go
index 77914de..dd1a8d4 100644
--- a/transport/control.go
+++ b/transport/control.go
@@ -26,6 +26,7 @@
"time"
"golang.org/x/net/http2"
+ "golang.org/x/net/http2/hpack"
)
const (
@@ -56,7 +57,9 @@
// control tasks, e.g., flow control, settings, streaming resetting, etc.
type headerFrame struct {
- p http2.HeadersFrameParam
+ streamID uint32
+ hf []hpack.HeaderField
+ endStream bool
}
func (*headerFrame) item() {}
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 31fed9e..92ad868 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -193,6 +193,7 @@
icwz = opts.InitialConnWindowSize
dynamicWindow = false
}
+ var buf bytes.Buffer
t := &http2Client{
ctx: ctx,
target: addr.Addr,
@@ -209,6 +210,8 @@
goAway: make(chan struct{}),
awakenKeepalive: make(chan struct{}, 1),
framer: newFramer(conn),
+ hBuf: &buf,
+ hEnc: hpack.NewEncoder(&buf),
controlBuf: newControlBuffer(),
fc: &inFlow{limit: uint32(icwz)},
sendQuotaPool: newQuotaPool(defaultWindowSize),
@@ -361,7 +364,7 @@
authData[k] = v
}
}
- callAuthData := make(map[string]string)
+ callAuthData := map[string]string{}
// Check if credentials.PerRPCCredentials were provided via call options.
// Note: if these credentials are provided both via dial options and call
// options, then both sets of credentials will be applied.
@@ -401,40 +404,40 @@
if sq > 1 {
t.streamsQuota.add(sq - 1)
}
- // HPACK encodes various headers.
- hBuf := bytes.NewBuffer([]byte{})
- hEnc := hpack.NewEncoder(hBuf)
- hEnc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
- hEnc.WriteField(hpack.HeaderField{Name: ":scheme", Value: t.scheme})
- hEnc.WriteField(hpack.HeaderField{Name: ":path", Value: callHdr.Method})
- hEnc.WriteField(hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
- hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
- hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
- hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"})
+ // TODO(mmukhi): Benchmark if the performance gets better if we count the metadata and other header fields
+ // first and create a slice of that exact size.
+ // Make the slice of a certain predictable size to reduce allocations made by append.
+ hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te
+ hfLen += len(authData) + len(callAuthData)
+ headerFields := make([]hpack.HeaderField, 0, hfLen)
+ headerFields = append(headerFields, hpack.HeaderField{Name: ":method", Value: "POST"})
+ headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme})
+ headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method})
+ headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"})
if callHdr.SendCompress != "" {
- hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
}
if dl, ok := ctx.Deadline(); ok {
// Send out timeout regardless its value. The server can detect timeout context by itself.
+ // TODO(mmukhi): Perhaps this field should be updated when actually writing out to the wire.
timeout := dl.Sub(time.Now())
- hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
}
-
for k, v := range authData {
- hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+ headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
for k, v := range callAuthData {
- hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+ headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
- var (
- endHeaders bool
- )
if b := stats.OutgoingTags(ctx); b != nil {
- hEnc.WriteField(hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)})
}
if b := stats.OutgoingTrace(ctx); b != nil {
- hEnc.WriteField(hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)})
}
if md, ok := metadata.FromOutgoingContext(ctx); ok {
for k, vv := range md {
@@ -443,7 +446,7 @@
continue
}
for _, v := range vv {
- hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+ headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
}
}
@@ -453,7 +456,7 @@
continue
}
for _, v := range vv {
- hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+ headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
}
}
@@ -482,34 +485,11 @@
default:
}
}
- first := true
- bufLen := hBuf.Len()
- // Sends the headers in a single batch even when they span multiple frames.
- for !endHeaders {
- size := hBuf.Len()
- if size > http2MaxFrameLen {
- size = http2MaxFrameLen
- } else {
- endHeaders = true
- }
- if first {
- // Sends a HeadersFrame to server to start a new stream.
- p := http2.HeadersFrameParam{
- StreamID: s.id,
- BlockFragment: hBuf.Next(size),
- EndStream: false,
- EndHeaders: endHeaders,
- }
- // Do a force flush for the buffered frames iff it is the last headers frame
- // and there is header metadata to be sent. Otherwise, there is flushing until
- // the corresponding data frame is written.
- t.controlBuf.put(&headerFrame{p})
- first = false
- } else {
- // Sends Continuation frames for the leftover headers.
- t.controlBuf.put(&continuationFrame{streamID: s.id, endHeaders: endHeaders, headerBlockFragment: hBuf.Next(size)})
- }
- }
+ t.controlBuf.put(&headerFrame{
+ streamID: s.id,
+ hf: headerFields,
+ endStream: false,
+ })
t.mu.Unlock()
s.mu.Lock()
@@ -519,7 +499,6 @@
if t.statsHandler != nil {
outHeader := &stats.OutHeader{
Client: true,
- WireLength: bufLen,
FullMethod: callHdr.Method,
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
@@ -770,7 +749,7 @@
return
}
if w := s.fc.maybeAdjust(n); w > 0 {
- // Piggyback conneciton's window update along.
+ // Piggyback connection's window update along.
if cw := t.fc.resetPendingUpdate(); cw > 0 {
t.controlBuf.put(&windowUpdate{0, cw})
}
@@ -1200,6 +1179,9 @@
}
}
+// TODO(mmukhi): A lot of this code (and code in other places in the transport layer)
+// is duplicated between the client and the server.
+// The transport layer needs to be refactored to take care of this.
func (t *http2Client) itemHandler(i item) error {
var err error
defer func() {
@@ -1214,9 +1196,38 @@
i.f()
}
case *headerFrame:
- err = t.framer.fr.WriteHeaders(i.p)
- case *continuationFrame:
- err = t.framer.fr.WriteContinuation(i.streamID, i.endHeaders, i.headerBlockFragment)
+ t.hBuf.Reset()
+ for _, f := range i.hf {
+ t.hEnc.WriteField(f)
+ }
+ endHeaders := false
+ first := true
+ for !endHeaders {
+ size := t.hBuf.Len()
+ if size > http2MaxFrameLen {
+ size = http2MaxFrameLen
+ } else {
+ endHeaders = true
+ }
+ if first {
+ first = false
+ err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{
+ StreamID: i.streamID,
+ BlockFragment: t.hBuf.Next(size),
+ EndStream: i.endStream,
+ EndHeaders: endHeaders,
+ })
+ } else {
+ err = t.framer.fr.WriteContinuation(
+ i.streamID,
+ endHeaders,
+ t.hBuf.Next(size),
+ )
+ }
+ if err != nil {
+ return err
+ }
+ }
case *windowUpdate:
err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
case *settings:
diff --git a/transport/http2_server.go b/transport/http2_server.go
index 4f62cba..0f0e759 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -63,6 +63,8 @@
// blocking forever after Close.
shutdownChan chan struct{}
framer *framer
+ hBuf *bytes.Buffer // the buffer for HPACK encoding
+ hEnc *hpack.Encoder // HPACK encoder
// The max number of concurrent streams.
maxStreams uint32
// controlBuf delivers all the control related tasks (e.g., window
@@ -175,6 +177,7 @@
if kep.MinTime == 0 {
kep.MinTime = defaultKeepalivePolicyMinTime
}
+ var buf bytes.Buffer
t := &http2Server{
ctx: context.Background(),
conn: conn,
@@ -182,6 +185,8 @@
localAddr: conn.LocalAddr(),
authInfo: config.AuthInfo,
framer: framer,
+ hBuf: &buf,
+ hEnc: hpack.NewEncoder(&buf),
maxStreams: maxStreams,
inTapHandle: config.InTapHandle,
controlBuf: newControlBuffer(),
@@ -639,7 +644,7 @@
t.mu.Unlock()
if ns < 1 && !t.kep.PermitWithoutStream {
// Keepalive shouldn't be active thus, this new ping should
- // have come after atleast defaultPingTimeout.
+ // have come after at least defaultPingTimeout.
if t.lastPingAt.Add(defaultPingTimeout).After(now) {
t.pingStrikes++
}
@@ -669,34 +674,6 @@
}
}
-func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) error {
- first := true
- endHeaders := false
- // Sends the headers in a single batch.
- for !endHeaders {
- size := b.Len()
- if size > http2MaxFrameLen {
- size = http2MaxFrameLen
- } else {
- endHeaders = true
- }
- if first {
- p := http2.HeadersFrameParam{
- StreamID: s.id,
- BlockFragment: b.Next(size),
- EndStream: endStream,
- EndHeaders: endHeaders,
- }
- t.controlBuf.put(&headerFrame{p})
- first = false
- } else {
- t.controlBuf.put(&continuationFrame{streamID: s.id, endHeaders: endHeaders, headerBlockFragment: b.Next(size)})
- }
- }
- atomic.StoreUint32(&t.resetPingStrikes, 1)
- return nil
-}
-
// WriteHeader sends the header metedata md back to the client.
func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
select {
@@ -722,13 +699,13 @@
}
md = s.header
s.mu.Unlock()
-
- hBuf := bytes.NewBuffer([]byte{}) // TODO(mmukhi): Try and re-use this memory later.
- hEnc := hpack.NewEncoder(hBuf)
- hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+ // TODO(mmukhi): Benchmark if the performance gets better if we count the metadata and other header fields
+ // first and create a slice of that exact size.
+ headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else.
+ headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
if s.sendCompress != "" {
- hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
}
for k, vv := range md {
if isReservedHeader(k) {
@@ -736,16 +713,17 @@
continue
}
for _, v := range vv {
- hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+ headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
}
- bufLen := hBuf.Len()
- if err := t.writeHeaders(s, hBuf, false); err != nil {
- return err
- }
+ t.controlBuf.put(&headerFrame{
+ streamID: s.id,
+ hf: headerFields,
+ endStream: false,
+ })
if t.stats != nil {
outHeader := &stats.OutHeader{
- WireLength: bufLen,
+ //WireLength: // TODO(mmukhi): Revisit this later, if needed.
}
t.stats.HandleRPC(s.Context(), outHeader)
}
@@ -782,18 +760,15 @@
headersSent = true
}
- hBuf := bytes.NewBuffer([]byte{}) // TODO(mmukhi): Try and re-use this memory.
- hEnc := hpack.NewEncoder(hBuf)
+ // TODO(mmukhi): Benchmark if the performance gets better if we count the metadata and other header fields
+ // first and create a slice of that exact size.
+ headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else.
if !headersSent {
- hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
- hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+ headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
}
- hEnc.WriteField(
- hpack.HeaderField{
- Name: "grpc-status",
- Value: strconv.Itoa(int(st.Code())),
- })
- hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
if p := st.Proto(); p != nil && len(p.Details) > 0 {
stBytes, err := proto.Marshal(p)
@@ -802,7 +777,7 @@
panic(err)
}
- hEnc.WriteField(hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
+ headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
}
// Attach the trailer metadata.
@@ -812,19 +787,16 @@
continue
}
for _, v := range vv {
- hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+ headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
}
- bufLen := hBuf.Len()
- if err := t.writeHeaders(s, hBuf, true); err != nil {
- t.Close()
- return err
- }
+ t.controlBuf.put(&headerFrame{
+ streamID: s.id,
+ hf: headerFields,
+ endStream: true,
+ })
if t.stats != nil {
- outTrailer := &stats.OutTrailer{
- WireLength: bufLen,
- }
- t.stats.HandleRPC(s.Context(), outTrailer)
+ t.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
}
t.closeStream(s)
return nil
@@ -904,7 +876,6 @@
atomic.StoreUint32(&t.resetPingStrikes, 1)
success := func() {
t.controlBuf.put(&dataFrame{streamID: s.id, endStream: false, d: p, f: func() {
- //fmt.Println("Adding quota back to localEendQuota", ps)
s.localSendQuota.add(ps)
}})
if ps < sq {
@@ -1007,6 +978,9 @@
var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}}
+// TODO(mmukhi): A lot of this code (and code in other places in the transport layer)
+// is duplicated between the client and the server.
+// The transport layer needs to be refactored to take care of this.
func (t *http2Server) itemHandler(i item) error {
var err error
defer func() {
@@ -1022,9 +996,39 @@
i.f()
}
case *headerFrame:
- err = t.framer.fr.WriteHeaders(i.p)
- case *continuationFrame:
- err = t.framer.fr.WriteContinuation(i.streamID, i.endHeaders, i.headerBlockFragment)
+ t.hBuf.Reset()
+ for _, f := range i.hf {
+ t.hEnc.WriteField(f)
+ }
+ first := true
+ endHeaders := false
+ for !endHeaders {
+ size := t.hBuf.Len()
+ if size > http2MaxFrameLen {
+ size = http2MaxFrameLen
+ } else {
+ endHeaders = true
+ }
+ if first {
+ first = false
+ err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{
+ StreamID: i.streamID,
+ BlockFragment: t.hBuf.Next(size),
+ EndStream: i.endStream,
+ EndHeaders: endHeaders,
+ })
+ } else {
+ err = t.framer.fr.WriteContinuation(
+ i.streamID,
+ endHeaders,
+ t.hBuf.Next(size),
+ )
+ }
+ if err != nil {
+ return err
+ }
+ }
+ atomic.StoreUint32(&t.resetPingStrikes, 1)
case *windowUpdate:
err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
case *settings:
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 1d12b17..f30ebc6 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -163,12 +163,13 @@
}
func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) {
- hBuf := bytes.NewBuffer([]byte{})
- hEnc := hpack.NewEncoder(hBuf)
- hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
- if err := h.t.writeHeaders(s, hBuf, false); err != nil {
- t.Fatalf("Failed to write headers: %v", err)
- }
+ headerFields := []hpack.HeaderField{}
+ headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
+ h.t.controlBuf.put(&headerFrame{
+ streamID: s.id,
+ hf: headerFields,
+ endStream: false,
+ })
}
func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {