Remove buf copy when the compressor exist (#1427)

diff --git a/call.go b/call.go
index 797190f..438758f 100644
--- a/call.go
+++ b/call.go
@@ -99,17 +99,17 @@
 			Client: true,
 		}
 	}
-	outBuf, err := encode(dopts.codec, args, compressor, cbuf, outPayload)
+	hdr, data, err := encode(dopts.codec, args, compressor, cbuf, outPayload)
 	if err != nil {
 		return err
 	}
 	if c.maxSendMessageSize == nil {
 		return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
 	}
-	if len(outBuf) > *c.maxSendMessageSize {
-		return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(outBuf), *c.maxSendMessageSize)
+	if len(data) > *c.maxSendMessageSize {
+		return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), *c.maxSendMessageSize)
 	}
-	err = t.Write(stream, outBuf, opts)
+	err = t.Write(stream, hdr, data, opts)
 	if err == nil && outPayload != nil {
 		outPayload.SentTime = time.Now()
 		dopts.copts.StatsHandler.HandleRPC(ctx, outPayload)
diff --git a/call_test.go b/call_test.go
index deb3cb6..f311309 100644
--- a/call_test.go
+++ b/call_test.go
@@ -104,12 +104,12 @@
 		}
 	}
 	// send a response back to end the stream.
-	reply, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
+	hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
 	if err != nil {
 		t.Errorf("Failed to encode the response: %v", err)
 		return
 	}
-	h.t.Write(s, reply, &transport.Options{})
+	h.t.Write(s, hdr, data, &transport.Options{})
 	h.t.WriteStatus(s, status.New(codes.OK, ""))
 }
 
diff --git a/rpc_util.go b/rpc_util.go
index be8444a..caded65 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -288,19 +288,20 @@
 	return pf, msg, nil
 }
 
-// encode serializes msg and prepends the message header. If msg is nil, it
-// generates the message header of 0 message length.
-func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, error) {
-	var (
-		b      []byte
-		length uint
+// encode serializes msg and returns a buffer of message header and a buffer of msg.
+// If msg is nil, it generates the message header and an empty msg buffer.
+func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, []byte, error) {
+	var b []byte
+	const (
+		payloadLen = 1
+		sizeLen    = 4
 	)
+
 	if msg != nil {
 		var err error
-		// TODO(zhaoq): optimize to reduce memory alloc and copying.
 		b, err = c.Marshal(msg)
 		if err != nil {
-			return nil, Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
+			return nil, nil, Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
 		}
 		if outPayload != nil {
 			outPayload.Payload = msg
@@ -310,39 +311,28 @@
 		}
 		if cp != nil {
 			if err := cp.Do(cbuf, b); err != nil {
-				return nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
+				return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
 			}
 			b = cbuf.Bytes()
 		}
-		length = uint(len(b))
-	}
-	if length > math.MaxUint32 {
-		return nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", length)
 	}
 
-	const (
-		payloadLen = 1
-		sizeLen    = 4
-	)
+	if len(b) > math.MaxUint32 {
+		return nil, nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
+	}
 
-	var buf = make([]byte, payloadLen+sizeLen+len(b))
-
-	// Write payload format
+	bufHeader := make([]byte, payloadLen+sizeLen)
 	if cp == nil {
-		buf[0] = byte(compressionNone)
+		bufHeader[0] = byte(compressionNone)
 	} else {
-		buf[0] = byte(compressionMade)
+		bufHeader[0] = byte(compressionMade)
 	}
 	// Write length of b into buf
-	binary.BigEndian.PutUint32(buf[1:], uint32(length))
-	// Copy encoded msg to buf
-	copy(buf[5:], b)
-
+	binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b)))
 	if outPayload != nil {
-		outPayload.WireLength = len(buf)
+		outPayload.WireLength = payloadLen + sizeLen + len(b)
 	}
-
-	return buf, nil
+	return bufHeader, b, nil
 }
 
 func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) error {
diff --git a/rpc_util_test.go b/rpc_util_test.go
index 7cbad49..23c471e 100644
--- a/rpc_util_test.go
+++ b/rpc_util_test.go
@@ -104,14 +104,15 @@
 		msg proto.Message
 		cp  Compressor
 		// outputs
-		b   []byte
-		err error
+		hdr  []byte
+		data []byte
+		err  error
 	}{
-		{nil, nil, []byte{0, 0, 0, 0, 0}, nil},
+		{nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil},
 	} {
-		b, err := encode(protoCodec{}, test.msg, nil, nil, nil)
-		if err != test.err || !bytes.Equal(b, test.b) {
-			t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, b, err, test.b, test.err)
+		hdr, data, err := encode(protoCodec{}, test.msg, nil, nil, nil)
+		if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) {
+			t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err)
 		}
 	}
 }
@@ -164,8 +165,8 @@
 // bytes.
 func bmEncode(b *testing.B, mSize int) {
 	msg := &perfpb.Buffer{Body: make([]byte, mSize)}
-	encoded, _ := encode(protoCodec{}, msg, nil, nil, nil)
-	encodedSz := int64(len(encoded))
+	encodeHdr, encodeData, _ := encode(protoCodec{}, msg, nil, nil, nil)
+	encodedSz := int64(len(encodeHdr) + len(encodeData))
 	b.ReportAllocs()
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
diff --git a/server.go b/server.go
index 5885c6c..86fe20a 100644
--- a/server.go
+++ b/server.go
@@ -677,15 +677,15 @@
 	if s.opts.statsHandler != nil {
 		outPayload = &stats.OutPayload{}
 	}
-	p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
+	hdr, data, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
 	if err != nil {
 		grpclog.Errorln("grpc: server failed to encode response: ", err)
 		return err
 	}
-	if len(p) > s.opts.maxSendMessageSize {
-		return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(p), s.opts.maxSendMessageSize)
+	if len(data) > s.opts.maxSendMessageSize {
+		return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize)
 	}
-	err = t.Write(stream, p, opts)
+	err = t.Write(stream, hdr, data, opts)
 	if err == nil && outPayload != nil {
 		outPayload.SentTime = time.Now()
 		s.opts.statsHandler.HandleRPC(stream.Context(), outPayload)
diff --git a/stream.go b/stream.go
index c155d3d..2fcf368 100644
--- a/stream.go
+++ b/stream.go
@@ -362,7 +362,7 @@
 			Client: true,
 		}
 	}
-	out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload)
+	hdr, data, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload)
 	defer func() {
 		if cs.cbuf != nil {
 			cs.cbuf.Reset()
@@ -374,10 +374,10 @@
 	if cs.c.maxSendMessageSize == nil {
 		return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
 	}
-	if len(out) > *cs.c.maxSendMessageSize {
-		return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), *cs.c.maxSendMessageSize)
+	if len(data) > *cs.c.maxSendMessageSize {
+		return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize)
 	}
-	err = cs.t.Write(cs.s, out, &transport.Options{Last: false})
+	err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: false})
 	if err == nil && outPayload != nil {
 		outPayload.SentTime = time.Now()
 		cs.statsHandler.HandleRPC(cs.statsCtx, outPayload)
@@ -449,7 +449,7 @@
 }
 
 func (cs *clientStream) CloseSend() (err error) {
-	err = cs.t.Write(cs.s, nil, &transport.Options{Last: true})
+	err = cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true})
 	defer func() {
 		if err != nil {
 			cs.finish(err)
@@ -608,7 +608,7 @@
 	if ss.statsHandler != nil {
 		outPayload = &stats.OutPayload{}
 	}
-	out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
+	hdr, data, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
 	defer func() {
 		if ss.cbuf != nil {
 			ss.cbuf.Reset()
@@ -617,10 +617,10 @@
 	if err != nil {
 		return err
 	}
-	if len(out) > ss.maxSendMessageSize {
-		return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), ss.maxSendMessageSize)
+	if len(data) > ss.maxSendMessageSize {
+		return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize)
 	}
-	if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil {
+	if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil {
 		return toRPCErr(err)
 	}
 	if outPayload != nil {
diff --git a/transport/handler_server.go b/transport/handler_server.go
index 85b8ee0..0489fad 100644
--- a/transport/handler_server.go
+++ b/transport/handler_server.go
@@ -255,9 +255,10 @@
 	}
 }
 
-func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
+func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
 	return ht.do(func() {
 		ht.writeCommonHeaders(s)
+		ht.rw.Write(hdr)
 		ht.rw.Write(data)
 		if !opts.Delay {
 			ht.rw.(http.Flusher).Flush()
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 8546d09..5f22913 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -683,8 +683,15 @@
 // should proceed only if Write returns nil.
 // TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later
 // if it improves the performance.
-func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
-	r := bytes.NewBuffer(data)
+func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
+	secondStart := http2MaxFrameLen - len(hdr)%http2MaxFrameLen
+	if len(data) < secondStart {
+		secondStart = len(data)
+	}
+	hdr = append(hdr, data[:secondStart]...)
+	data = data[secondStart:]
+	isLastSlice := (len(data) == 0)
+	r := bytes.NewBuffer(hdr)
 	var (
 		p   []byte
 		oqv uint32
@@ -726,9 +733,6 @@
 			endStream  bool
 			forceFlush bool
 		)
-		if opts.Last && r.Len() == 0 {
-			endStream = true
-		}
 		// Indicate there is a writer who is about to write a data frame.
 		t.framer.adjustNumWriters(1)
 		// Got some quota. Try to acquire writing privilege on the transport.
@@ -768,10 +772,22 @@
 			t.writableChan <- 0
 			continue
 		}
-		if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 {
-			// Do a force flush iff this is last frame for the entire gRPC message
-			// and the caller is the only writer at this moment.
-			forceFlush = true
+		if r.Len() == 0 {
+			if isLastSlice {
+				if opts.Last {
+					endStream = true
+				}
+				if t.framer.adjustNumWriters(0) == 1 {
+					// Do a force flush iff this is last frame for the entire gRPC message
+					// and the caller is the only writer at this moment.
+					forceFlush = true
+				}
+			} else {
+				isLastSlice = true
+				if len(data) != 0 {
+					r = bytes.NewBuffer(data)
+				}
+			}
 		}
 		// If WriteData fails, all the pending streams will be handled
 		// by http2Client.Close(). No explicit CloseStream() needs to be
diff --git a/transport/http2_server.go b/transport/http2_server.go
index 6ee6f40..302651b 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -827,8 +827,15 @@
 
 // Write converts the data into HTTP2 data frame and sends it out. Non-nil error
 // is returns if it fails (e.g., framing error, transport error).
-func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) {
+func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (err error) {
 	// TODO(zhaoq): Support multi-writers for a single stream.
+	secondStart := http2MaxFrameLen - len(hdr)%http2MaxFrameLen
+	if len(data) < secondStart {
+		secondStart = len(data)
+	}
+	hdr = append(hdr, data[:secondStart]...)
+	data = data[secondStart:]
+	isLastSlice := (len(data) == 0)
 	var writeHeaderFrame bool
 	s.mu.Lock()
 	if s.state == streamDone {
@@ -842,7 +849,7 @@
 	if writeHeaderFrame {
 		t.WriteHeader(s, nil)
 	}
-	r := bytes.NewBuffer(data)
+	r := bytes.NewBuffer(hdr)
 	var (
 		p   []byte
 		oqv uint32
@@ -921,8 +928,15 @@
 			continue
 		}
 		var forceFlush bool
-		if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 && !opts.Last {
-			forceFlush = true
+		if r.Len() == 0 {
+			if isLastSlice {
+				if t.framer.adjustNumWriters(0) == 1 && !opts.Last {
+					forceFlush = true
+				}
+			} else {
+				r = bytes.NewBuffer(data)
+				isLastSlice = true
+			}
 		}
 		// Reset ping strikes when sending data since this might cause
 		// the peer to send ping.
diff --git a/transport/transport.go b/transport/transport.go
index ec0fe67..c5732be 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -564,7 +564,7 @@
 
 	// Write sends the data for the given stream. A nil stream indicates
 	// the write is to be performed on the transport as a whole.
-	Write(s *Stream, data []byte, opts *Options) error
+	Write(s *Stream, hdr []byte, data []byte, opts *Options) error
 
 	// NewStream creates a Stream for an RPC.
 	NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
@@ -606,7 +606,7 @@
 
 	// Write sends the data for the given stream.
 	// Write may not be called on all streams.
-	Write(s *Stream, data []byte, opts *Options) error
+	Write(s *Stream, hdr []byte, data []byte, opts *Options) error
 
 	// WriteStatus sends the status of a stream to the client.  WriteStatus is
 	// the final call made on a stream and always occurs.
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 8610478..be2d8da 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -92,7 +92,7 @@
 		t.Fatalf("handleStream got %v, want %v", p, req)
 	}
 	// send a response back to the client.
-	h.t.Write(s, resp, &Options{})
+	h.t.Write(s, resp, nil, &Options{})
 	// send the trailer to end the stream.
 	h.t.WriteStatus(s, status.New(codes.OK, ""))
 }
@@ -112,7 +112,7 @@
 		buf[0] = byte(0)
 		binary.BigEndian.PutUint32(buf[1:], uint32(sz))
 		copy(buf[5:], msg)
-		h.t.Write(s, buf, &Options{})
+		h.t.Write(s, buf, nil, &Options{})
 	}
 }
 
@@ -190,7 +190,7 @@
 		t.Fatalf("handleStream got %v, want %v", p, req)
 	}
 	// send a response back to the client.
-	h.t.Write(s, resp, &Options{})
+	h.t.Write(s, resp, nil, &Options{})
 	// send the trailer to end the stream.
 	h.t.WriteStatus(s, status.New(codes.OK, ""))
 }
@@ -215,7 +215,7 @@
 	// Wait before sending. Give time to client to start reading
 	// before server starts sending.
 	time.Sleep(2 * time.Second)
-	h.t.Write(s, resp, &Options{})
+	h.t.Write(s, resp, nil, &Options{})
 	// send the trailer to end the stream.
 	h.t.WriteStatus(s, status.New(codes.OK, ""))
 }
@@ -808,7 +808,7 @@
 		Last:  true,
 		Delay: false,
 	}
-	if err := ct.Write(s1, expectedRequest, &opts); err != nil && err != io.EOF {
+	if err := ct.Write(s1, expectedRequest, nil, &opts); err != nil && err != io.EOF {
 		t.Fatalf("failed to send data: %v", err)
 	}
 	p := make([]byte, len(expectedResponse))
@@ -845,7 +845,7 @@
 		Last:  true,
 		Delay: false,
 	}
-	if err := ct.Write(s, expectedRequest, &opts); err == nil || err == io.EOF {
+	if err := ct.Write(s, expectedRequest, nil, &opts); err == nil || err == io.EOF {
 		time.Sleep(5 * time.Millisecond)
 		// The following s.Recv()'s could error out because the
 		// underlying transport is gone.
@@ -889,7 +889,7 @@
 			if err != nil {
 				t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
 			}
-			if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
+			if err := ct.Write(s, expectedRequestLarge, nil, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
 				t.Errorf("%v.Write(_, _, _) = %v, want  <nil>", ct, err)
 			}
 			p := make([]byte, len(expectedResponseLarge))
@@ -921,7 +921,7 @@
 			if err != nil {
 				t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
 			}
-			if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
+			if err := ct.Write(s, expectedRequestLarge, nil, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
 				t.Errorf("%v.Write(_, _, _) = %v, want  <nil>", ct, err)
 			}
 			p := make([]byte, len(expectedResponseLarge))
@@ -959,7 +959,7 @@
 
 			// Give time to server to start reading before client starts sending.
 			time.Sleep(2 * time.Second)
-			if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
+			if err := ct.Write(s, expectedRequestLarge, nil, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
 				t.Errorf("%v.Write(_, _, _) = %v, want  <nil>", ct, err)
 			}
 			p := make([]byte, len(expectedResponseLarge))
@@ -1005,7 +1005,7 @@
 		Delay: false,
 	}
 	// The stream which was created before graceful close can still proceed.
-	if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
+	if err := ct.Write(s, expectedRequest, nil, &opts); err != nil && err != io.EOF {
 		t.Fatalf("%v.Write(_, _, _) = %v, want  <nil>", ct, err)
 	}
 	p := make([]byte, len(expectedResponse))
@@ -1034,7 +1034,7 @@
 	}
 	// Write should not be done successfully due to flow control.
 	msg := make([]byte, initialWindowSize*8)
-	err = ct.Write(s, msg, &Options{Last: true, Delay: false})
+	err = ct.Write(s, msg, nil, &Options{Last: true, Delay: false})
 	expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded)
 	if err != expectedErr {
 		t.Fatalf("Write got %v, want %v", err, expectedErr)
@@ -1311,7 +1311,7 @@
 		t.Fatalf("Failed to create 1st stream. Err: %v", err)
 	}
 	// Exhaust server's connection window.
-	if err := client.Write(cstream1, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil {
+	if err := client.Write(cstream1, make([]byte, defaultWindowSize), nil, &Options{Last: true}); err != nil {
 		t.Fatalf("Client failed to write data. Err: %v", err)
 	}
 	//Client should be able to create another stream and send data on it.
@@ -1319,7 +1319,7 @@
 	if err != nil {
 		t.Fatalf("Failed to create 2nd stream. Err: %v", err)
 	}
-	if err := client.Write(cstream2, make([]byte, defaultWindowSize), &Options{}); err != nil {
+	if err := client.Write(cstream2, make([]byte, defaultWindowSize), nil, &Options{}); err != nil {
 		t.Fatalf("Client failed to write data. Err: %v", err)
 	}
 	// Get the streams on server.
@@ -1474,7 +1474,7 @@
 		t.Fatalf("Failed to open stream: %v", err)
 	}
 	d := make([]byte, 1)
-	if err := ct.Write(s, d, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
+	if err := ct.Write(s, d, nil, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
 		t.Fatalf("Failed to write: %v", err)
 	}
 	// Read without window update.
@@ -1516,7 +1516,7 @@
 		Last:  true,
 		Delay: false,
 	}
-	if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
+	if err := ct.Write(s, expectedRequest, nil, &opts); err != nil && err != io.EOF {
 		t.Fatalf("Failed to write the request: %v", err)
 	}
 	p := make([]byte, http2MaxFrameLen)
@@ -1544,7 +1544,7 @@
 		Last:  true,
 		Delay: false,
 	}
-	if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
+	if err := ct.Write(s, expectedRequest, nil, &opts); err != nil && err != io.EOF {
 		t.Fatalf("Failed to write the request: %v", err)
 	}
 	p := make([]byte, http2MaxFrameLen)
@@ -1787,7 +1787,7 @@
 	opts := Options{}
 	header := make([]byte, 5)
 	for i := 1; i <= 10; i++ {
-		if err := ct.Write(cstream, buf, &opts); err != nil {
+		if err := ct.Write(cstream, buf, nil, &opts); err != nil {
 			t.Fatalf("Error on client while writing message: %v", err)
 		}
 		if _, err := cstream.Read(header); err != nil {