Introduce new Compressor/Decompressor API (#1428)
diff --git a/call.go b/call.go
index 4bd673c..368980d 100644
--- a/call.go
+++ b/call.go
@@ -19,7 +19,6 @@
package grpc
import (
- "bytes"
"io"
"time"
@@ -27,6 +26,7 @@
"golang.org/x/net/trace"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
+ "google.golang.org/grpc/encoding"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
@@ -62,7 +62,7 @@
if c.maxReceiveMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
}
- if err = recv(p, dopts.codec, stream, dopts.dc, reply, *c.maxReceiveMessageSize, inPayload); err != nil {
+ if err = recv(p, dopts.codec, stream, dopts.dc, reply, *c.maxReceiveMessageSize, inPayload, encoding.GetCompressor(c.compressorType)); err != nil {
if err == io.EOF {
break
}
@@ -89,18 +89,17 @@
}
}()
var (
- cbuf *bytes.Buffer
outPayload *stats.OutPayload
)
- if compressor != nil {
- cbuf = new(bytes.Buffer)
- }
if dopts.copts.StatsHandler != nil {
outPayload = &stats.OutPayload{
Client: true,
}
}
- hdr, data, err := encode(dopts.codec, args, compressor, cbuf, outPayload)
+ if c.compressorType != "" && encoding.GetCompressor(c.compressorType) == nil {
+ return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", c.compressorType)
+ }
+ hdr, data, err := encode(dopts.codec, args, compressor, outPayload, encoding.GetCompressor(c.compressorType))
if err != nil {
return err
}
@@ -223,7 +222,9 @@
Host: cc.authority,
Method: method,
}
- if cc.dopts.cp != nil {
+ if c.compressorType != "" {
+ callHdr.SendCompress = c.compressorType
+ } else if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
if c.creds != nil {
diff --git a/clientconn.go b/clientconn.go
index 5f5aac4..2c5d8ee 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -104,6 +104,16 @@
// DialOption configures how we set up the connection.
type DialOption func(*dialOptions)
+// UseCompressor returns a CallOption which sets the compressor used when sending the request.
+// If WithCompressor is set, UseCompressor has higher priority.
+// This API is EXPERIMENTAL.
+func UseCompressor(name string) CallOption {
+ return beforeCall(func(c *callInfo) error {
+ c.compressorType = name
+ return nil
+ })
+}
+
// WithWriteBufferSize lets you set the size of write buffer, this determines how much data can be batched
// before doing a write on the wire.
func WithWriteBufferSize(s int) DialOption {
@@ -156,7 +166,8 @@
}
// WithCompressor returns a DialOption which sets a CompressorGenerator for generating message
-// compressor.
+// compressor. It has lower priority than the compressor set by RegisterCompressor.
+// This function is deprecated.
func WithCompressor(cp Compressor) DialOption {
return func(o *dialOptions) {
o.cp = cp
@@ -164,7 +175,8 @@
}
// WithDecompressor returns a DialOption which sets a DecompressorGenerator for generating
-// message decompressor.
+// message decompressor. It has higher priority than the decompressor set by RegisterCompressor.
+// This function is deprecated.
func WithDecompressor(dc Decompressor) DialOption {
return func(o *dialOptions) {
o.dc = dc
diff --git a/encoding/encoding.go b/encoding/encoding.go
new file mode 100644
index 0000000..f6cc3d6
--- /dev/null
+++ b/encoding/encoding.go
@@ -0,0 +1,57 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package encoding defines the interface for the compressor and the functions
+// to register and get the compossor.
+// This package is EXPERIMENTAL.
+package encoding
+
+import (
+ "io"
+)
+
+var registerCompressor = make(map[string]Compressor)
+
+// Compressor is used for compressing and decompressing when sending or receiving messages.
+type Compressor interface {
+ // Compress writes the data written to wc to w after compressing it. If an error
+ // occurs while initializing the compressor, that error is returned instead.
+ Compress(w io.Writer) (io.WriteCloser, error)
+ // Decompress reads data from r, decompresses it, and provides the uncompressed data
+ // via the returned io.Reader. If an error occurs while initializing the decompressor, that error
+ // is returned instead.
+ Decompress(r io.Reader) (io.Reader, error)
+ // Name is the name of the compression codec and is used to set the content coding header.
+ Name() string
+}
+
+// RegisterCompressor registers the compressor with gRPC by its name. It can be activated when
+// sending an RPC via grpc.UseCompressor(). It will be automatically accessed when receiving a
+// message based on the content coding header. Servers also use it to send a response with the
+// same encoding as the request.
+//
+// NOTE: this function must only be called during initialization time (i.e. in an init() function). If
+// multiple Compressors are registered with the same name, the one registered last will take effect.
+func RegisterCompressor(c Compressor) {
+ registerCompressor[c.Name()] = c
+}
+
+// GetCompressor returns Compressor for the given compressor name.
+func GetCompressor(name string) Compressor {
+ return registerCompressor[name]
+}
diff --git a/encoding/gzip/gzip.go b/encoding/gzip/gzip.go
new file mode 100644
index 0000000..fb4385e
--- /dev/null
+++ b/encoding/gzip/gzip.go
@@ -0,0 +1,93 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package gzip implements and registers the gzip compressor
+// during the initialization.
+// This package is EXPERIMENTAL.
+package gzip
+
+import (
+ "compress/gzip"
+ "io"
+ "io/ioutil"
+ "sync"
+
+ "google.golang.org/grpc/encoding"
+)
+
+func init() {
+ c := &compressor{}
+ c.poolCompressor.New = func() interface{} {
+ return &writer{Writer: gzip.NewWriter(ioutil.Discard), pool: &c.poolCompressor}
+ }
+ encoding.RegisterCompressor(c)
+}
+
+type writer struct {
+ *gzip.Writer
+ pool *sync.Pool
+}
+
+func (c *compressor) Compress(w io.Writer) (io.WriteCloser, error) {
+ z := c.poolCompressor.Get().(*writer)
+ z.Writer.Reset(w)
+ return z, nil
+}
+
+func (z *writer) Close() error {
+ defer z.pool.Put(z)
+ return z.Writer.Close()
+}
+
+type reader struct {
+ *gzip.Reader
+ pool *sync.Pool
+}
+
+func (c *compressor) Decompress(r io.Reader) (io.Reader, error) {
+ z, inPool := c.poolDecompressor.Get().(*reader)
+ if !inPool {
+ newZ, err := gzip.NewReader(r)
+ if err != nil {
+ return nil, err
+ }
+ return &reader{Reader: newZ, pool: &c.poolDecompressor}, nil
+ }
+ if err := z.Reset(r); err != nil {
+ c.poolDecompressor.Put(z)
+ return nil, err
+ }
+ return z, nil
+}
+
+func (z *reader) Read(p []byte) (n int, err error) {
+ n, err = z.Reader.Read(p)
+ if err == io.EOF {
+ z.pool.Put(z)
+ }
+ return n, err
+}
+
+func (c *compressor) Name() string {
+ return "gzip"
+}
+
+type compressor struct {
+ poolCompressor sync.Pool
+ poolDecompressor sync.Pool
+}
diff --git a/rpc_util.go b/rpc_util.go
index 9c8d881..eccf84d 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -31,6 +31,7 @@
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/encoding"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
@@ -122,6 +123,7 @@
// callInfo contains all related configuration and information about an RPC.
type callInfo struct {
+ compressorType string
failFast bool
headerMD metadata.MD
trailerMD metadata.MD
@@ -294,13 +296,16 @@
// encode serializes msg and returns a buffer of message header and a buffer of msg.
// If msg is nil, it generates the message header and an empty msg buffer.
-func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, []byte, error) {
- var b []byte
+// TODO(ddyihai): eliminate extra Compressor parameter.
+func encode(c Codec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) {
+ var (
+ b []byte
+ cbuf *bytes.Buffer
+ )
const (
payloadLen = 1
sizeLen = 4
)
-
if msg != nil {
var err error
b, err = c.Marshal(msg)
@@ -313,24 +318,35 @@
outPayload.Data = b
outPayload.Length = len(b)
}
- if cp != nil {
- if err := cp.Do(cbuf, b); err != nil {
- return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
+ if compressor != nil || cp != nil {
+ cbuf = new(bytes.Buffer)
+ // Has compressor, check Compressor is set by UseCompressor first.
+ if compressor != nil {
+ z, _ := compressor.Compress(cbuf)
+ if _, err := z.Write(b); err != nil {
+ return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
+ }
+ z.Close()
+ } else {
+ // If Compressor is not set by UseCompressor, use default Compressor
+ if err := cp.Do(cbuf, b); err != nil {
+ return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
+ }
}
b = cbuf.Bytes()
}
}
-
if uint(len(b)) > math.MaxUint32 {
return nil, nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
}
bufHeader := make([]byte, payloadLen+sizeLen)
- if cp == nil {
- bufHeader[0] = byte(compressionNone)
- } else {
+ if compressor != nil || cp != nil {
bufHeader[0] = byte(compressionMade)
+ } else {
+ bufHeader[0] = byte(compressionNone)
}
+
// Write length of b into buf
binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b)))
if outPayload != nil {
@@ -343,7 +359,7 @@
switch pf {
case compressionNone:
case compressionMade:
- if dc == nil || recvCompress != dc.Type() {
+ if (dc == nil || recvCompress != dc.Type()) && encoding.GetCompressor(recvCompress) == nil {
return Errorf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
default:
@@ -352,7 +368,9 @@
return nil
}
-func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload) error {
+// TODO(ddyihai): eliminate extra Compressor parameter.
+func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int,
+ inPayload *stats.InPayload, compressor encoding.Compressor) error {
pf, d, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return err
@@ -364,9 +382,22 @@
return err
}
if pf == compressionMade {
- d, err = dc.Do(bytes.NewReader(d))
- if err != nil {
- return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
+ // To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
+ // use this decompressor as the default.
+ if dc != nil {
+ d, err = dc.Do(bytes.NewReader(d))
+ if err != nil {
+ return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
+ }
+ } else {
+ dcReader, err := compressor.Decompress(bytes.NewReader(d))
+ if err != nil {
+ return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
+ }
+ d, err = ioutil.ReadAll(dcReader)
+ if err != nil {
+ return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
+ }
}
}
if len(d) > maxReceiveMessageSize {
diff --git a/server.go b/server.go
index def301a..2c26db8 100644
--- a/server.go
+++ b/server.go
@@ -32,11 +32,14 @@
"sync"
"time"
+ "io/ioutil"
+
"golang.org/x/net/context"
"golang.org/x/net/http2"
"golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/encoding"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/keepalive"
@@ -187,6 +190,8 @@
}
// RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
+// It has lower priority than the compressor set by RegisterCompressor.
+// This function is deprecated.
func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) {
o.cp = cp
@@ -194,6 +199,8 @@
}
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
+// It has higher priority than the decompressor set by RegisterCompressor.
+// This function is deprecated.
func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) {
o.dc = dc
@@ -701,16 +708,18 @@
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error {
var (
- cbuf *bytes.Buffer
outPayload *stats.OutPayload
)
- if cp != nil {
- cbuf = new(bytes.Buffer)
- }
if s.opts.statsHandler != nil {
outPayload = &stats.OutPayload{}
}
- hdr, data, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
+ if stream.RecvCompress() != "" {
+ // Server receives compressor, check compressor set by register and default.
+ if encoding.GetCompressor(stream.RecvCompress()) == nil && (cp == nil || cp != nil && cp.Type() != stream.RecvCompress()) {
+ return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", stream.RecvCompress())
+ }
+ }
+ hdr, data, err := encode(s.opts.codec, msg, cp, outPayload, encoding.GetCompressor(stream.RecvCompress()))
if err != nil {
grpclog.Errorln("grpc: server failed to encode response: ", err)
return err
@@ -754,7 +763,9 @@
}
}()
}
- if s.opts.cp != nil {
+ if stream.RecvCompress() != "" {
+ stream.SetSendCompress(stream.RecvCompress())
+ } else if s.opts.cp != nil {
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
stream.SetSendCompress(s.opts.cp.Type())
}
@@ -786,7 +797,6 @@
}
return err
}
-
if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil {
if st, ok := status.FromError(err); ok {
if e := t.WriteStatus(stream, st); e != nil {
@@ -812,9 +822,18 @@
}
if pf == compressionMade {
var err error
- req, err = s.opts.dc.Do(bytes.NewReader(req))
- if err != nil {
- return Errorf(codes.Internal, err.Error())
+ if s.opts.dc != nil {
+ req, err = s.opts.dc.Do(bytes.NewReader(req))
+ if err != nil {
+ return Errorf(codes.Internal, err.Error())
+ }
+ } else {
+ dcReader := encoding.GetCompressor(stream.RecvCompress())
+ tmp, _ := dcReader.Decompress(bytes.NewReader(req))
+ req, err = ioutil.ReadAll(tmp)
+ if err != nil {
+ return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
+ }
}
}
if len(req) > s.opts.maxReceiveMessageSize {
@@ -909,16 +928,19 @@
sh.HandleRPC(stream.Context(), end)
}()
}
- if s.opts.cp != nil {
+ if stream.RecvCompress() != "" {
+ stream.SetSendCompress(stream.RecvCompress())
+ } else if s.opts.cp != nil {
stream.SetSendCompress(s.opts.cp.Type())
}
ss := &serverStream{
- t: t,
- s: stream,
- p: &parser{r: stream},
- codec: s.opts.codec,
- cp: s.opts.cp,
- dc: s.opts.dc,
+ t: t,
+ s: stream,
+ p: &parser{r: stream},
+ codec: s.opts.codec,
+ cpType: stream.RecvCompress(),
+ cp: s.opts.cp,
+ dc: s.opts.dc,
maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
maxSendMessageSize: s.opts.maxSendMessageSize,
trInfo: trInfo,
diff --git a/stream.go b/stream.go
index b58f7f8..a659f14 100644
--- a/stream.go
+++ b/stream.go
@@ -19,7 +19,6 @@
package grpc
import (
- "bytes"
"errors"
"io"
"sync"
@@ -29,6 +28,7 @@
"golang.org/x/net/trace"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
+ "google.golang.org/grpc/encoding"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
@@ -151,7 +151,9 @@
// time soon, so we ask the transport to flush the header.
Flush: desc.ClientStreams,
}
- if cc.dopts.cp != nil {
+ if c.compressorType != "" {
+ callHdr.SendCompress = c.compressorType
+ } else if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
if c.creds != nil {
@@ -242,6 +244,7 @@
c: c,
desc: desc,
codec: cc.dopts.codec,
+ cpType: c.compressorType,
cp: cc.dopts.cp,
dc: cc.dopts.dc,
cancel: cancel,
@@ -292,6 +295,7 @@
p *parser
desc *StreamDesc
codec Codec
+ cpType string
cp Compressor
dc Decompressor
cancel context.CancelFunc
@@ -369,7 +373,10 @@
Client: true,
}
}
- hdr, data, err := encode(cs.codec, m, cs.cp, bytes.NewBuffer([]byte{}), outPayload)
+ if cs.cpType != "" && encoding.GetCompressor(cs.cpType) == nil {
+ return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", cs.cpType)
+ }
+ hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, encoding.GetCompressor(cs.cpType))
if err != nil {
return err
}
@@ -397,7 +404,7 @@
if cs.c.maxReceiveMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
}
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload)
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, encoding.GetCompressor(cs.cpType))
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
@@ -423,7 +430,7 @@
if cs.c.maxReceiveMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)")
}
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil)
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, encoding.GetCompressor(cs.cpType))
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@@ -552,6 +559,7 @@
s *transport.Stream
p *parser
codec Codec
+ cpType string
cp Compressor
dc Decompressor
maxReceiveMessageSize int
@@ -609,7 +617,12 @@
if ss.statsHandler != nil {
outPayload = &stats.OutPayload{}
}
- hdr, data, err := encode(ss.codec, m, ss.cp, bytes.NewBuffer([]byte{}), outPayload)
+ if ss.cpType != "" {
+ if encoding.GetCompressor(ss.cpType) == nil && (ss.cp == nil || ss.cp != nil && ss.cp.Type() != ss.cpType) {
+ return Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ss.cpType)
+ }
+ }
+ hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, encoding.GetCompressor(ss.cpType))
if err != nil {
return err
}
@@ -649,7 +662,7 @@
if ss.statsHandler != nil {
inPayload = &stats.InPayload{}
}
- if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload); err != nil {
+ if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, encoding.GetCompressor(ss.cpType)); err != nil {
if err == io.EOF {
return err
}
diff --git a/test/end2end_test.go b/test/end2end_test.go
index a5e81a9..1518dca 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -51,6 +51,7 @@
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
+ _ "google.golang.org/grpc/encoding/gzip"
_ "google.golang.org/grpc/grpclog/glogger"
"google.golang.org/grpc/health"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
@@ -437,18 +438,24 @@
cancel context.CancelFunc
// Configurable knobs, after newTest returns:
- testServer testpb.TestServiceServer // nil means none
- healthServer *health.Server // nil means disabled
- maxStream uint32
- tapHandle tap.ServerInHandle
- maxMsgSize *int
- maxClientReceiveMsgSize *int
- maxClientSendMsgSize *int
- maxServerReceiveMsgSize *int
- maxServerSendMsgSize *int
- userAgent string
- clientCompression bool
- serverCompression bool
+ testServer testpb.TestServiceServer // nil means none
+ healthServer *health.Server // nil means disabled
+ maxStream uint32
+ tapHandle tap.ServerInHandle
+ maxMsgSize *int
+ maxClientReceiveMsgSize *int
+ maxClientSendMsgSize *int
+ maxServerReceiveMsgSize *int
+ maxServerSendMsgSize *int
+ userAgent string
+ // clientCompression and serverCompression are set to test the deprecated API
+ // WithCompressor and WithDecompressor.
+ clientCompression bool
+ serverCompression bool
+ // clientUseCompression is set to test the new compressor registration API UseCompressor.
+ clientUseCompression bool
+ // clientNopCompression is set to create a compressor whose type is not supported.
+ clientNopCompression bool
unaryClientInt grpc.UnaryClientInterceptor
streamClientInt grpc.StreamClientInterceptor
unaryServerInt grpc.UnaryServerInterceptor
@@ -594,6 +601,32 @@
te.srvAddr = addr
}
+type nopCompressor struct {
+ grpc.Compressor
+}
+
+// NewNopCompressor creates a compressor to test the case that type is not supported.
+func NewNopCompressor() grpc.Compressor {
+ return &nopCompressor{grpc.NewGZIPCompressor()}
+}
+
+func (c *nopCompressor) Type() string {
+ return "nop"
+}
+
+type nopDecompressor struct {
+ grpc.Decompressor
+}
+
+// NewNopDecompressor creates a decompressor to test the case that type is not supported.
+func NewNopDecompressor() grpc.Decompressor {
+ return &nopDecompressor{grpc.NewGZIPDecompressor()}
+}
+
+func (d *nopDecompressor) Type() string {
+ return "nop"
+}
+
func (te *test) clientConn() *grpc.ClientConn {
if te.cc != nil {
return te.cc
@@ -613,6 +646,15 @@
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
)
}
+ if te.clientUseCompression {
+ opts = append(opts, grpc.WithDefaultCallOptions(grpc.UseCompressor("gzip")))
+ }
+ if te.clientNopCompression {
+ opts = append(opts,
+ grpc.WithCompressor(NewNopCompressor()),
+ grpc.WithDecompressor(NewNopDecompressor()),
+ )
+ }
if te.unaryClientInt != nil {
opts = append(opts, grpc.WithUnaryInterceptor(te.unaryClientInt))
}
@@ -3749,7 +3791,8 @@
func testCompressServerHasNoSupport(t *testing.T, e env) {
te := newTest(t, e)
te.serverCompression = false
- te.clientCompression = true
+ te.clientCompression = false
+ te.clientNopCompression = true
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
@@ -5572,3 +5615,65 @@
t.Fatalf("Invoke with method %q, got %q, %v, want %q, true", testMethod, method, ok, testMethod)
}
}
+
+func TestCompressorRegister(t *testing.T) {
+ defer leakcheck.Check(t)
+ for _, e := range listTestEnv() {
+ testCompressorRegister(t, e)
+ }
+}
+
+func testCompressorRegister(t *testing.T, e env) {
+ te := newTest(t, e)
+ te.clientCompression = false
+ te.serverCompression = false
+ te.clientUseCompression = true
+
+ te.startServer(&testServer{security: e.security})
+ defer te.tearDown()
+ tc := testpb.NewTestServiceClient(te.clientConn())
+
+ // Unary call
+ const argSize = 271828
+ const respSize = 314159
+ payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req := &testpb.SimpleRequest{
+ ResponseType: testpb.PayloadType_COMPRESSABLE,
+ ResponseSize: respSize,
+ Payload: payload,
+ }
+ ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("something", "something"))
+ if _, err := tc.UnaryCall(ctx, req); err != nil {
+ t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
+ }
+ // Streaming RPC
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ stream, err := tc.FullDuplexCall(ctx)
+ if err != nil {
+ t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+ }
+ respParam := []*testpb.ResponseParameters{
+ {
+ Size: 31415,
+ },
+ }
+ payload, err = newPayload(testpb.PayloadType_COMPRESSABLE, int32(31415))
+ if err != nil {
+ t.Fatal(err)
+ }
+ sreq := &testpb.StreamingOutputCallRequest{
+ ResponseType: testpb.PayloadType_COMPRESSABLE,
+ ResponseParameters: respParam,
+ Payload: payload,
+ }
+ if err := stream.Send(sreq); err != nil {
+ t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
+ }
+ if _, err := stream.Recv(); err != nil {
+ t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
+ }
+}