Remove buf copy when the compressor exist (#1427)
diff --git a/call.go b/call.go
index 797190f..438758f 100644
--- a/call.go
+++ b/call.go
@@ -99,17 +99,17 @@
Client: true,
}
}
- outBuf, err := encode(dopts.codec, args, compressor, cbuf, outPayload)
+ hdr, data, err := encode(dopts.codec, args, compressor, cbuf, outPayload)
if err != nil {
return err
}
if c.maxSendMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
}
- if len(outBuf) > *c.maxSendMessageSize {
- return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(outBuf), *c.maxSendMessageSize)
+ if len(data) > *c.maxSendMessageSize {
+ return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), *c.maxSendMessageSize)
}
- err = t.Write(stream, outBuf, opts)
+ err = t.Write(stream, hdr, data, opts)
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
dopts.copts.StatsHandler.HandleRPC(ctx, outPayload)
diff --git a/call_test.go b/call_test.go
index deb3cb6..f311309 100644
--- a/call_test.go
+++ b/call_test.go
@@ -104,12 +104,12 @@
}
}
// send a response back to end the stream.
- reply, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
+ hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
if err != nil {
t.Errorf("Failed to encode the response: %v", err)
return
}
- h.t.Write(s, reply, &transport.Options{})
+ h.t.Write(s, hdr, data, &transport.Options{})
h.t.WriteStatus(s, status.New(codes.OK, ""))
}
diff --git a/rpc_util.go b/rpc_util.go
index be8444a..caded65 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -288,19 +288,20 @@
return pf, msg, nil
}
-// encode serializes msg and prepends the message header. If msg is nil, it
-// generates the message header of 0 message length.
-func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, error) {
- var (
- b []byte
- length uint
+// encode serializes msg and returns a buffer of message header and a buffer of msg.
+// If msg is nil, it generates the message header and an empty msg buffer.
+func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, []byte, error) {
+ var b []byte
+ const (
+ payloadLen = 1
+ sizeLen = 4
)
+
if msg != nil {
var err error
- // TODO(zhaoq): optimize to reduce memory alloc and copying.
b, err = c.Marshal(msg)
if err != nil {
- return nil, Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
+ return nil, nil, Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
}
if outPayload != nil {
outPayload.Payload = msg
@@ -310,39 +311,28 @@
}
if cp != nil {
if err := cp.Do(cbuf, b); err != nil {
- return nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
+ return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
b = cbuf.Bytes()
}
- length = uint(len(b))
- }
- if length > math.MaxUint32 {
- return nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", length)
}
- const (
- payloadLen = 1
- sizeLen = 4
- )
+ if len(b) > math.MaxUint32 {
+ return nil, nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
+ }
- var buf = make([]byte, payloadLen+sizeLen+len(b))
-
- // Write payload format
+ bufHeader := make([]byte, payloadLen+sizeLen)
if cp == nil {
- buf[0] = byte(compressionNone)
+ bufHeader[0] = byte(compressionNone)
} else {
- buf[0] = byte(compressionMade)
+ bufHeader[0] = byte(compressionMade)
}
// Write length of b into buf
- binary.BigEndian.PutUint32(buf[1:], uint32(length))
- // Copy encoded msg to buf
- copy(buf[5:], b)
-
+ binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b)))
if outPayload != nil {
- outPayload.WireLength = len(buf)
+ outPayload.WireLength = payloadLen + sizeLen + len(b)
}
-
- return buf, nil
+ return bufHeader, b, nil
}
func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) error {
diff --git a/rpc_util_test.go b/rpc_util_test.go
index 7cbad49..23c471e 100644
--- a/rpc_util_test.go
+++ b/rpc_util_test.go
@@ -104,14 +104,15 @@
msg proto.Message
cp Compressor
// outputs
- b []byte
- err error
+ hdr []byte
+ data []byte
+ err error
}{
- {nil, nil, []byte{0, 0, 0, 0, 0}, nil},
+ {nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil},
} {
- b, err := encode(protoCodec{}, test.msg, nil, nil, nil)
- if err != test.err || !bytes.Equal(b, test.b) {
- t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, b, err, test.b, test.err)
+ hdr, data, err := encode(protoCodec{}, test.msg, nil, nil, nil)
+ if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) {
+ t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err)
}
}
}
@@ -164,8 +165,8 @@
// bytes.
func bmEncode(b *testing.B, mSize int) {
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
- encoded, _ := encode(protoCodec{}, msg, nil, nil, nil)
- encodedSz := int64(len(encoded))
+ encodeHdr, encodeData, _ := encode(protoCodec{}, msg, nil, nil, nil)
+ encodedSz := int64(len(encodeHdr) + len(encodeData))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
diff --git a/server.go b/server.go
index 5885c6c..86fe20a 100644
--- a/server.go
+++ b/server.go
@@ -677,15 +677,15 @@
if s.opts.statsHandler != nil {
outPayload = &stats.OutPayload{}
}
- p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
+ hdr, data, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
if err != nil {
grpclog.Errorln("grpc: server failed to encode response: ", err)
return err
}
- if len(p) > s.opts.maxSendMessageSize {
- return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(p), s.opts.maxSendMessageSize)
+ if len(data) > s.opts.maxSendMessageSize {
+ return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize)
}
- err = t.Write(stream, p, opts)
+ err = t.Write(stream, hdr, data, opts)
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
s.opts.statsHandler.HandleRPC(stream.Context(), outPayload)
diff --git a/stream.go b/stream.go
index c155d3d..2fcf368 100644
--- a/stream.go
+++ b/stream.go
@@ -362,7 +362,7 @@
Client: true,
}
}
- out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload)
+ hdr, data, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload)
defer func() {
if cs.cbuf != nil {
cs.cbuf.Reset()
@@ -374,10 +374,10 @@
if cs.c.maxSendMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
}
- if len(out) > *cs.c.maxSendMessageSize {
- return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), *cs.c.maxSendMessageSize)
+ if len(data) > *cs.c.maxSendMessageSize {
+ return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize)
}
- err = cs.t.Write(cs.s, out, &transport.Options{Last: false})
+ err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: false})
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
cs.statsHandler.HandleRPC(cs.statsCtx, outPayload)
@@ -449,7 +449,7 @@
}
func (cs *clientStream) CloseSend() (err error) {
- err = cs.t.Write(cs.s, nil, &transport.Options{Last: true})
+ err = cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true})
defer func() {
if err != nil {
cs.finish(err)
@@ -608,7 +608,7 @@
if ss.statsHandler != nil {
outPayload = &stats.OutPayload{}
}
- out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
+ hdr, data, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
defer func() {
if ss.cbuf != nil {
ss.cbuf.Reset()
@@ -617,10 +617,10 @@
if err != nil {
return err
}
- if len(out) > ss.maxSendMessageSize {
- return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), ss.maxSendMessageSize)
+ if len(data) > ss.maxSendMessageSize {
+ return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize)
}
- if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil {
+ if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil {
return toRPCErr(err)
}
if outPayload != nil {
diff --git a/transport/handler_server.go b/transport/handler_server.go
index 85b8ee0..0489fad 100644
--- a/transport/handler_server.go
+++ b/transport/handler_server.go
@@ -255,9 +255,10 @@
}
}
-func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
+func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
return ht.do(func() {
ht.writeCommonHeaders(s)
+ ht.rw.Write(hdr)
ht.rw.Write(data)
if !opts.Delay {
ht.rw.(http.Flusher).Flush()
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 8546d09..5f22913 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -683,8 +683,15 @@
// should proceed only if Write returns nil.
// TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later
// if it improves the performance.
-func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
- r := bytes.NewBuffer(data)
+func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
+ secondStart := http2MaxFrameLen - len(hdr)%http2MaxFrameLen
+ if len(data) < secondStart {
+ secondStart = len(data)
+ }
+ hdr = append(hdr, data[:secondStart]...)
+ data = data[secondStart:]
+ isLastSlice := (len(data) == 0)
+ r := bytes.NewBuffer(hdr)
var (
p []byte
oqv uint32
@@ -726,9 +733,6 @@
endStream bool
forceFlush bool
)
- if opts.Last && r.Len() == 0 {
- endStream = true
- }
// Indicate there is a writer who is about to write a data frame.
t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the transport.
@@ -768,10 +772,22 @@
t.writableChan <- 0
continue
}
- if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 {
- // Do a force flush iff this is last frame for the entire gRPC message
- // and the caller is the only writer at this moment.
- forceFlush = true
+ if r.Len() == 0 {
+ if isLastSlice {
+ if opts.Last {
+ endStream = true
+ }
+ if t.framer.adjustNumWriters(0) == 1 {
+ // Do a force flush iff this is last frame for the entire gRPC message
+ // and the caller is the only writer at this moment.
+ forceFlush = true
+ }
+ } else {
+ isLastSlice = true
+ if len(data) != 0 {
+ r = bytes.NewBuffer(data)
+ }
+ }
}
// If WriteData fails, all the pending streams will be handled
// by http2Client.Close(). No explicit CloseStream() needs to be
diff --git a/transport/http2_server.go b/transport/http2_server.go
index 6ee6f40..302651b 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -827,8 +827,15 @@
// Write converts the data into HTTP2 data frame and sends it out. Non-nil error
// is returns if it fails (e.g., framing error, transport error).
-func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) {
+func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (err error) {
// TODO(zhaoq): Support multi-writers for a single stream.
+ secondStart := http2MaxFrameLen - len(hdr)%http2MaxFrameLen
+ if len(data) < secondStart {
+ secondStart = len(data)
+ }
+ hdr = append(hdr, data[:secondStart]...)
+ data = data[secondStart:]
+ isLastSlice := (len(data) == 0)
var writeHeaderFrame bool
s.mu.Lock()
if s.state == streamDone {
@@ -842,7 +849,7 @@
if writeHeaderFrame {
t.WriteHeader(s, nil)
}
- r := bytes.NewBuffer(data)
+ r := bytes.NewBuffer(hdr)
var (
p []byte
oqv uint32
@@ -921,8 +928,15 @@
continue
}
var forceFlush bool
- if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 && !opts.Last {
- forceFlush = true
+ if r.Len() == 0 {
+ if isLastSlice {
+ if t.framer.adjustNumWriters(0) == 1 && !opts.Last {
+ forceFlush = true
+ }
+ } else {
+ r = bytes.NewBuffer(data)
+ isLastSlice = true
+ }
}
// Reset ping strikes when sending data since this might cause
// the peer to send ping.
diff --git a/transport/transport.go b/transport/transport.go
index ec0fe67..c5732be 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -564,7 +564,7 @@
// Write sends the data for the given stream. A nil stream indicates
// the write is to be performed on the transport as a whole.
- Write(s *Stream, data []byte, opts *Options) error
+ Write(s *Stream, hdr []byte, data []byte, opts *Options) error
// NewStream creates a Stream for an RPC.
NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
@@ -606,7 +606,7 @@
// Write sends the data for the given stream.
// Write may not be called on all streams.
- Write(s *Stream, data []byte, opts *Options) error
+ Write(s *Stream, hdr []byte, data []byte, opts *Options) error
// WriteStatus sends the status of a stream to the client. WriteStatus is
// the final call made on a stream and always occurs.
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 8610478..be2d8da 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -92,7 +92,7 @@
t.Fatalf("handleStream got %v, want %v", p, req)
}
// send a response back to the client.
- h.t.Write(s, resp, &Options{})
+ h.t.Write(s, resp, nil, &Options{})
// send the trailer to end the stream.
h.t.WriteStatus(s, status.New(codes.OK, ""))
}
@@ -112,7 +112,7 @@
buf[0] = byte(0)
binary.BigEndian.PutUint32(buf[1:], uint32(sz))
copy(buf[5:], msg)
- h.t.Write(s, buf, &Options{})
+ h.t.Write(s, buf, nil, &Options{})
}
}
@@ -190,7 +190,7 @@
t.Fatalf("handleStream got %v, want %v", p, req)
}
// send a response back to the client.
- h.t.Write(s, resp, &Options{})
+ h.t.Write(s, resp, nil, &Options{})
// send the trailer to end the stream.
h.t.WriteStatus(s, status.New(codes.OK, ""))
}
@@ -215,7 +215,7 @@
// Wait before sending. Give time to client to start reading
// before server starts sending.
time.Sleep(2 * time.Second)
- h.t.Write(s, resp, &Options{})
+ h.t.Write(s, resp, nil, &Options{})
// send the trailer to end the stream.
h.t.WriteStatus(s, status.New(codes.OK, ""))
}
@@ -808,7 +808,7 @@
Last: true,
Delay: false,
}
- if err := ct.Write(s1, expectedRequest, &opts); err != nil && err != io.EOF {
+ if err := ct.Write(s1, expectedRequest, nil, &opts); err != nil && err != io.EOF {
t.Fatalf("failed to send data: %v", err)
}
p := make([]byte, len(expectedResponse))
@@ -845,7 +845,7 @@
Last: true,
Delay: false,
}
- if err := ct.Write(s, expectedRequest, &opts); err == nil || err == io.EOF {
+ if err := ct.Write(s, expectedRequest, nil, &opts); err == nil || err == io.EOF {
time.Sleep(5 * time.Millisecond)
// The following s.Recv()'s could error out because the
// underlying transport is gone.
@@ -889,7 +889,7 @@
if err != nil {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
}
- if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
+ if err := ct.Write(s, expectedRequestLarge, nil, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponseLarge))
@@ -921,7 +921,7 @@
if err != nil {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
}
- if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
+ if err := ct.Write(s, expectedRequestLarge, nil, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponseLarge))
@@ -959,7 +959,7 @@
// Give time to server to start reading before client starts sending.
time.Sleep(2 * time.Second)
- if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
+ if err := ct.Write(s, expectedRequestLarge, nil, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponseLarge))
@@ -1005,7 +1005,7 @@
Delay: false,
}
// The stream which was created before graceful close can still proceed.
- if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
+ if err := ct.Write(s, expectedRequest, nil, &opts); err != nil && err != io.EOF {
t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponse))
@@ -1034,7 +1034,7 @@
}
// Write should not be done successfully due to flow control.
msg := make([]byte, initialWindowSize*8)
- err = ct.Write(s, msg, &Options{Last: true, Delay: false})
+ err = ct.Write(s, msg, nil, &Options{Last: true, Delay: false})
expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded)
if err != expectedErr {
t.Fatalf("Write got %v, want %v", err, expectedErr)
@@ -1311,7 +1311,7 @@
t.Fatalf("Failed to create 1st stream. Err: %v", err)
}
// Exhaust server's connection window.
- if err := client.Write(cstream1, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil {
+ if err := client.Write(cstream1, make([]byte, defaultWindowSize), nil, &Options{Last: true}); err != nil {
t.Fatalf("Client failed to write data. Err: %v", err)
}
//Client should be able to create another stream and send data on it.
@@ -1319,7 +1319,7 @@
if err != nil {
t.Fatalf("Failed to create 2nd stream. Err: %v", err)
}
- if err := client.Write(cstream2, make([]byte, defaultWindowSize), &Options{}); err != nil {
+ if err := client.Write(cstream2, make([]byte, defaultWindowSize), nil, &Options{}); err != nil {
t.Fatalf("Client failed to write data. Err: %v", err)
}
// Get the streams on server.
@@ -1474,7 +1474,7 @@
t.Fatalf("Failed to open stream: %v", err)
}
d := make([]byte, 1)
- if err := ct.Write(s, d, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
+ if err := ct.Write(s, d, nil, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
t.Fatalf("Failed to write: %v", err)
}
// Read without window update.
@@ -1516,7 +1516,7 @@
Last: true,
Delay: false,
}
- if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
+ if err := ct.Write(s, expectedRequest, nil, &opts); err != nil && err != io.EOF {
t.Fatalf("Failed to write the request: %v", err)
}
p := make([]byte, http2MaxFrameLen)
@@ -1544,7 +1544,7 @@
Last: true,
Delay: false,
}
- if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
+ if err := ct.Write(s, expectedRequest, nil, &opts); err != nil && err != io.EOF {
t.Fatalf("Failed to write the request: %v", err)
}
p := make([]byte, http2MaxFrameLen)
@@ -1787,7 +1787,7 @@
opts := Options{}
header := make([]byte, 5)
for i := 1; i <= 10; i++ {
- if err := ct.Write(cstream, buf, &opts); err != nil {
+ if err := ct.Write(cstream, buf, nil, &opts); err != nil {
t.Fatalf("Error on client while writing message: %v", err)
}
if _, err := cstream.Read(header); err != nil {