Use the same hpack encoder on a transport and share it between RPCs. (#1536)

diff --git a/stats/stats.go b/stats/stats.go
index b64c429..d5aa2f7 100644
--- a/stats/stats.go
+++ b/stats/stats.go
@@ -135,8 +135,6 @@
 type OutHeader struct {
 	// Client is true if this OutHeader is from client side.
 	Client bool
-	// WireLength is the wire length of header.
-	WireLength int
 
 	// The following fields are valid only if Client is true.
 	// FullMethod is the full RPC method string, i.e., /package.service/method.
diff --git a/stats/stats_test.go b/stats/stats_test.go
index 8865d3f..d66485f 100644
--- a/stats/stats_test.go
+++ b/stats/stats_test.go
@@ -444,10 +444,6 @@
 	if d.ctx == nil {
 		t.Fatalf("d.ctx = nil, want <non-nil>")
 	}
-	// TODO check real length, not just > 0.
-	if st.WireLength <= 0 {
-		t.Fatalf("st.Lenght = 0, want > 0")
-	}
 	if !d.client {
 		if st.FullMethod != e.method {
 			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
@@ -530,18 +526,13 @@
 func checkInTrailer(t *testing.T, d *gotData, e *expectedData) {
 	var (
 		ok bool
-		st *stats.InTrailer
 	)
-	if st, ok = d.s.(*stats.InTrailer); !ok {
+	if _, ok = d.s.(*stats.InTrailer); !ok {
 		t.Fatalf("got %T, want InTrailer", d.s)
 	}
 	if d.ctx == nil {
 		t.Fatalf("d.ctx = nil, want <non-nil>")
 	}
-	// TODO check real length, not just > 0.
-	if st.WireLength <= 0 {
-		t.Fatalf("st.Lenght = 0, want > 0")
-	}
 }
 
 func checkOutHeader(t *testing.T, d *gotData, e *expectedData) {
@@ -555,10 +546,6 @@
 	if d.ctx == nil {
 		t.Fatalf("d.ctx = nil, want <non-nil>")
 	}
-	// TODO check real length, not just > 0.
-	if st.WireLength <= 0 {
-		t.Fatalf("st.Lenght = 0, want > 0")
-	}
 	if d.client {
 		if st.FullMethod != e.method {
 			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
@@ -642,10 +629,6 @@
 	if st.Client {
 		t.Fatalf("st IsClient = true, want false")
 	}
-	// TODO check real length, not just > 0.
-	if st.WireLength <= 0 {
-		t.Fatalf("st.Lenght = 0, want > 0")
-	}
 }
 
 func checkEnd(t *testing.T, d *gotData, e *expectedData) {
diff --git a/transport/control.go b/transport/control.go
index 77914de..dd1a8d4 100644
--- a/transport/control.go
+++ b/transport/control.go
@@ -26,6 +26,7 @@
 	"time"
 
 	"golang.org/x/net/http2"
+	"golang.org/x/net/http2/hpack"
 )
 
 const (
@@ -56,7 +57,9 @@
 // control tasks, e.g., flow control, settings, streaming resetting, etc.
 
 type headerFrame struct {
-	p http2.HeadersFrameParam
+	streamID  uint32
+	hf        []hpack.HeaderField
+	endStream bool
 }
 
 func (*headerFrame) item() {}
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 31fed9e..92ad868 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -193,6 +193,7 @@
 		icwz = opts.InitialConnWindowSize
 		dynamicWindow = false
 	}
+	var buf bytes.Buffer
 	t := &http2Client{
 		ctx:        ctx,
 		target:     addr.Addr,
@@ -209,6 +210,8 @@
 		goAway:            make(chan struct{}),
 		awakenKeepalive:   make(chan struct{}, 1),
 		framer:            newFramer(conn),
+		hBuf:              &buf,
+		hEnc:              hpack.NewEncoder(&buf),
 		controlBuf:        newControlBuffer(),
 		fc:                &inFlow{limit: uint32(icwz)},
 		sendQuotaPool:     newQuotaPool(defaultWindowSize),
@@ -361,7 +364,7 @@
 			authData[k] = v
 		}
 	}
-	callAuthData := make(map[string]string)
+	callAuthData := map[string]string{}
 	// Check if credentials.PerRPCCredentials were provided via call options.
 	// Note: if these credentials are provided both via dial options and call
 	// options, then both sets of credentials will be applied.
@@ -401,40 +404,40 @@
 	if sq > 1 {
 		t.streamsQuota.add(sq - 1)
 	}
-	// HPACK encodes various headers.
-	hBuf := bytes.NewBuffer([]byte{})
-	hEnc := hpack.NewEncoder(hBuf)
-	hEnc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
-	hEnc.WriteField(hpack.HeaderField{Name: ":scheme", Value: t.scheme})
-	hEnc.WriteField(hpack.HeaderField{Name: ":path", Value: callHdr.Method})
-	hEnc.WriteField(hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
-	hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
-	hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
-	hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"})
+	// TODO(mmukhi): Benchmark whether performance improves if we count the metadata and other header fields
+	// first and create a slice of that exact size.
+	// Make the slice of a certain predictable size to reduce allocations made by append.
+	hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te
+	hfLen += len(authData) + len(callAuthData)
+	headerFields := make([]hpack.HeaderField, 0, hfLen)
+	headerFields = append(headerFields, hpack.HeaderField{Name: ":method", Value: "POST"})
+	headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme})
+	headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method})
+	headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
+	headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+	headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent})
+	headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"})
 
 	if callHdr.SendCompress != "" {
-		hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
 	}
 	if dl, ok := ctx.Deadline(); ok {
 		// Send out timeout regardless its value. The server can detect timeout context by itself.
+		// TODO(mmukhi): Perhaps this field should be updated when actually writing out to the wire.
 		timeout := dl.Sub(time.Now())
-		hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
 	}
-
 	for k, v := range authData {
-		hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+		headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 	}
 	for k, v := range callAuthData {
-		hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+		headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 	}
-	var (
-		endHeaders bool
-	)
 	if b := stats.OutgoingTags(ctx); b != nil {
-		hEnc.WriteField(hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-tags-bin", Value: encodeBinHeader(b)})
 	}
 	if b := stats.OutgoingTrace(ctx); b != nil {
-		hEnc.WriteField(hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-trace-bin", Value: encodeBinHeader(b)})
 	}
 	if md, ok := metadata.FromOutgoingContext(ctx); ok {
 		for k, vv := range md {
@@ -443,7 +446,7 @@
 				continue
 			}
 			for _, v := range vv {
-				hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+				headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 			}
 		}
 	}
@@ -453,7 +456,7 @@
 				continue
 			}
 			for _, v := range vv {
-				hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+				headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 			}
 		}
 	}
@@ -482,34 +485,11 @@
 		default:
 		}
 	}
-	first := true
-	bufLen := hBuf.Len()
-	// Sends the headers in a single batch even when they span multiple frames.
-	for !endHeaders {
-		size := hBuf.Len()
-		if size > http2MaxFrameLen {
-			size = http2MaxFrameLen
-		} else {
-			endHeaders = true
-		}
-		if first {
-			// Sends a HeadersFrame to server to start a new stream.
-			p := http2.HeadersFrameParam{
-				StreamID:      s.id,
-				BlockFragment: hBuf.Next(size),
-				EndStream:     false,
-				EndHeaders:    endHeaders,
-			}
-			// Do a force flush for the buffered frames iff it is the last headers frame
-			// and there is header metadata to be sent. Otherwise, there is flushing until
-			// the corresponding data frame is written.
-			t.controlBuf.put(&headerFrame{p})
-			first = false
-		} else {
-			// Sends Continuation frames for the leftover headers.
-			t.controlBuf.put(&continuationFrame{streamID: s.id, endHeaders: endHeaders, headerBlockFragment: hBuf.Next(size)})
-		}
-	}
+	t.controlBuf.put(&headerFrame{
+		streamID:  s.id,
+		hf:        headerFields,
+		endStream: false,
+	})
 	t.mu.Unlock()
 
 	s.mu.Lock()
@@ -519,7 +499,6 @@
 	if t.statsHandler != nil {
 		outHeader := &stats.OutHeader{
 			Client:      true,
-			WireLength:  bufLen,
 			FullMethod:  callHdr.Method,
 			RemoteAddr:  t.remoteAddr,
 			LocalAddr:   t.localAddr,
@@ -770,7 +749,7 @@
 		return
 	}
 	if w := s.fc.maybeAdjust(n); w > 0 {
-		// Piggyback conneciton's window update along.
+		// Piggyback connection's window update along.
 		if cw := t.fc.resetPendingUpdate(); cw > 0 {
 			t.controlBuf.put(&windowUpdate{0, cw})
 		}
@@ -1200,6 +1179,9 @@
 	}
 }
 
+// TODO(mmukhi): A lot of this code (and code in other places in the transport layer)
+// is duplicated between the client and the server.
+// The transport layer needs to be refactored to take care of this.
 func (t *http2Client) itemHandler(i item) error {
 	var err error
 	defer func() {
@@ -1214,9 +1196,38 @@
 			i.f()
 		}
 	case *headerFrame:
-		err = t.framer.fr.WriteHeaders(i.p)
-	case *continuationFrame:
-		err = t.framer.fr.WriteContinuation(i.streamID, i.endHeaders, i.headerBlockFragment)
+		t.hBuf.Reset()
+		for _, f := range i.hf {
+			t.hEnc.WriteField(f)
+		}
+		endHeaders := false
+		first := true
+		for !endHeaders {
+			size := t.hBuf.Len()
+			if size > http2MaxFrameLen {
+				size = http2MaxFrameLen
+			} else {
+				endHeaders = true
+			}
+			if first {
+				first = false
+				err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{
+					StreamID:      i.streamID,
+					BlockFragment: t.hBuf.Next(size),
+					EndStream:     i.endStream,
+					EndHeaders:    endHeaders,
+				})
+			} else {
+				err = t.framer.fr.WriteContinuation(
+					i.streamID,
+					endHeaders,
+					t.hBuf.Next(size),
+				)
+			}
+			if err != nil {
+				return err
+			}
+		}
 	case *windowUpdate:
 		err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
 	case *settings:
diff --git a/transport/http2_server.go b/transport/http2_server.go
index 4f62cba..0f0e759 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -63,6 +63,8 @@
 	// blocking forever after Close.
 	shutdownChan chan struct{}
 	framer       *framer
+	hBuf         *bytes.Buffer  // the buffer for HPACK encoding
+	hEnc         *hpack.Encoder // HPACK encoder
 	// The max number of concurrent streams.
 	maxStreams uint32
 	// controlBuf delivers all the control related tasks (e.g., window
@@ -175,6 +177,7 @@
 	if kep.MinTime == 0 {
 		kep.MinTime = defaultKeepalivePolicyMinTime
 	}
+	var buf bytes.Buffer
 	t := &http2Server{
 		ctx:               context.Background(),
 		conn:              conn,
@@ -182,6 +185,8 @@
 		localAddr:         conn.LocalAddr(),
 		authInfo:          config.AuthInfo,
 		framer:            framer,
+		hBuf:              &buf,
+		hEnc:              hpack.NewEncoder(&buf),
 		maxStreams:        maxStreams,
 		inTapHandle:       config.InTapHandle,
 		controlBuf:        newControlBuffer(),
@@ -639,7 +644,7 @@
 	t.mu.Unlock()
 	if ns < 1 && !t.kep.PermitWithoutStream {
 		// Keepalive shouldn't be active thus, this new ping should
-		// have come after atleast defaultPingTimeout.
+		// have come after at least defaultPingTimeout.
 		if t.lastPingAt.Add(defaultPingTimeout).After(now) {
 			t.pingStrikes++
 		}
@@ -669,34 +674,6 @@
 	}
 }
 
-func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) error {
-	first := true
-	endHeaders := false
-	// Sends the headers in a single batch.
-	for !endHeaders {
-		size := b.Len()
-		if size > http2MaxFrameLen {
-			size = http2MaxFrameLen
-		} else {
-			endHeaders = true
-		}
-		if first {
-			p := http2.HeadersFrameParam{
-				StreamID:      s.id,
-				BlockFragment: b.Next(size),
-				EndStream:     endStream,
-				EndHeaders:    endHeaders,
-			}
-			t.controlBuf.put(&headerFrame{p})
-			first = false
-		} else {
-			t.controlBuf.put(&continuationFrame{streamID: s.id, endHeaders: endHeaders, headerBlockFragment: b.Next(size)})
-		}
-	}
-	atomic.StoreUint32(&t.resetPingStrikes, 1)
-	return nil
-}
-
 // WriteHeader sends the header metedata md back to the client.
 func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
 	select {
@@ -722,13 +699,13 @@
 	}
 	md = s.header
 	s.mu.Unlock()
-
-	hBuf := bytes.NewBuffer([]byte{}) // TODO(mmukhi): Try and re-use this memory later.
-	hEnc := hpack.NewEncoder(hBuf)
-	hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
-	hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+	// TODO(mmukhi): Benchmark whether performance improves if we count the metadata and other header fields
+	// first and create a slice of that exact size.
+	headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else.
+	headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
+	headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
 	if s.sendCompress != "" {
-		hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
 	}
 	for k, vv := range md {
 		if isReservedHeader(k) {
@@ -736,16 +713,17 @@
 			continue
 		}
 		for _, v := range vv {
-			hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+			headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 		}
 	}
-	bufLen := hBuf.Len()
-	if err := t.writeHeaders(s, hBuf, false); err != nil {
-		return err
-	}
+	t.controlBuf.put(&headerFrame{
+		streamID:  s.id,
+		hf:        headerFields,
+		endStream: false,
+	})
 	if t.stats != nil {
 		outHeader := &stats.OutHeader{
-			WireLength: bufLen,
+		// TODO(mmukhi): Revisit reporting the header wire length here later, if needed.
 		}
 		t.stats.HandleRPC(s.Context(), outHeader)
 	}
@@ -782,18 +760,15 @@
 		headersSent = true
 	}
 
-	hBuf := bytes.NewBuffer([]byte{}) // TODO(mmukhi): Try and re-use this memory.
-	hEnc := hpack.NewEncoder(hBuf)
+	// TODO(mmukhi): Benchmark whether performance improves if we count the metadata and other header fields
+	// first and create a slice of that exact size.
+	headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else.
 	if !headersSent {
-		hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
-		hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+		headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
 	}
-	hEnc.WriteField(
-		hpack.HeaderField{
-			Name:  "grpc-status",
-			Value: strconv.Itoa(int(st.Code())),
-		})
-	hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
+	headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})
+	headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
 
 	if p := st.Proto(); p != nil && len(p.Details) > 0 {
 		stBytes, err := proto.Marshal(p)
@@ -802,7 +777,7 @@
 			panic(err)
 		}
 
-		hEnc.WriteField(hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
+		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
 	}
 
 	// Attach the trailer metadata.
@@ -812,19 +787,16 @@
 			continue
 		}
 		for _, v := range vv {
-			hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
+			headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 		}
 	}
-	bufLen := hBuf.Len()
-	if err := t.writeHeaders(s, hBuf, true); err != nil {
-		t.Close()
-		return err
-	}
+	t.controlBuf.put(&headerFrame{
+		streamID:  s.id,
+		hf:        headerFields,
+		endStream: true,
+	})
 	if t.stats != nil {
-		outTrailer := &stats.OutTrailer{
-			WireLength: bufLen,
-		}
-		t.stats.HandleRPC(s.Context(), outTrailer)
+		t.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
 	}
 	t.closeStream(s)
 	return nil
@@ -904,7 +876,6 @@
 			atomic.StoreUint32(&t.resetPingStrikes, 1)
 			success := func() {
 				t.controlBuf.put(&dataFrame{streamID: s.id, endStream: false, d: p, f: func() {
-					//fmt.Println("Adding quota back to localEendQuota", ps)
 					s.localSendQuota.add(ps)
 				}})
 				if ps < sq {
@@ -1007,6 +978,9 @@
 
 var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}}
 
+// TODO(mmukhi): A lot of this code (and code in other places in the transport layer)
+// is duplicated between the client and the server.
+// The transport layer needs to be refactored to take care of this.
 func (t *http2Server) itemHandler(i item) error {
 	var err error
 	defer func() {
@@ -1022,9 +996,39 @@
 			i.f()
 		}
 	case *headerFrame:
-		err = t.framer.fr.WriteHeaders(i.p)
-	case *continuationFrame:
-		err = t.framer.fr.WriteContinuation(i.streamID, i.endHeaders, i.headerBlockFragment)
+		t.hBuf.Reset()
+		for _, f := range i.hf {
+			t.hEnc.WriteField(f)
+		}
+		first := true
+		endHeaders := false
+		for !endHeaders {
+			size := t.hBuf.Len()
+			if size > http2MaxFrameLen {
+				size = http2MaxFrameLen
+			} else {
+				endHeaders = true
+			}
+			if first {
+				first = false
+				err = t.framer.fr.WriteHeaders(http2.HeadersFrameParam{
+					StreamID:      i.streamID,
+					BlockFragment: t.hBuf.Next(size),
+					EndStream:     i.endStream,
+					EndHeaders:    endHeaders,
+				})
+			} else {
+				err = t.framer.fr.WriteContinuation(
+					i.streamID,
+					endHeaders,
+					t.hBuf.Next(size),
+				)
+			}
+			if err != nil {
+				return err
+			}
+		}
+		atomic.StoreUint32(&t.resetPingStrikes, 1)
 	case *windowUpdate:
 		err = t.framer.fr.WriteWindowUpdate(i.streamID, i.increment)
 	case *settings:
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 1d12b17..f30ebc6 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -163,12 +163,13 @@
 }
 
 func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) {
-	hBuf := bytes.NewBuffer([]byte{})
-	hEnc := hpack.NewEncoder(hBuf)
-	hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
-	if err := h.t.writeHeaders(s, hBuf, false); err != nil {
-		t.Fatalf("Failed to write headers: %v", err)
-	}
+	headerFields := []hpack.HeaderField{}
+	headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField})
+	h.t.controlBuf.put(&headerFrame{
+		streamID:  s.id,
+		hf:        headerFields,
+		endStream: false,
+	})
 }
 
 func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {