Merge master and resolve conflicts
diff --git a/call.go b/call.go
index 13ca5b7..3f0549f 100644
--- a/call.go
+++ b/call.go
@@ -52,7 +52,7 @@
//
// TODO(zhaoq): Check whether the received message sequence is valid.
// TODO ctx is used for stats collection and processing. It is the context passed from the application.
-func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) {
+func recvResponse(ctx context.Context, dopts dialOptions, msgSizeLimit int, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) {
// Try to acquire header metadata from the server if there is any.
defer func() {
if err != nil {
@@ -73,7 +73,7 @@
}
}
for {
- if err = recv(p, dopts.codec, stream, dopts.dc, reply, dopts.maxMsgSize, inPayload); err != nil {
+ if err = recv(p, dopts.codec, stream, dopts.dc, reply, msgSizeLimit, inPayload); err != nil {
if err == io.EOF {
break
}
@@ -93,7 +93,7 @@
}
// sendRequest writes out various information of an RPC such as Context and Message.
-func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
+func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, msgSizeLimit int, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
stream, err := t.NewStream(ctx, callHdr)
if err != nil {
return nil, err
@@ -122,6 +122,9 @@
if err != nil {
return nil, Errorf(codes.Internal, "grpc: %v", err)
}
+ if len(outBuf) > msgSizeLimit {
+ return nil, Errorf(codes.InvalidArgument, "grpc: Sent message larger than max (%d vs. %d)", len(outBuf), msgSizeLimit)
+ }
err = t.Write(stream, outBuf, opts)
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
@@ -149,13 +152,41 @@
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) {
c := defaultCallInfo
- if mc, ok := cc.getMethodConfig(method); ok {
- c.failFast = !mc.WaitForReady
- if mc.Timeout > 0 {
+ maxReceiveMessageSize := defaultClientMaxReceiveMessageSize
+ maxSendMessageSize := defaultClientMaxSendMessageSize
+ if mc, ok := cc.GetMethodConfig(method); ok {
+ if mc.WaitForReady != nil {
+ c.failFast = !*mc.WaitForReady
+ }
+
+ if mc.Timeout != nil && *mc.Timeout >= 0 {
var cancel context.CancelFunc
- ctx, cancel = context.WithTimeout(ctx, mc.Timeout)
+ ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
defer cancel()
}
+
+ if mc.MaxReqSize != nil && cc.dopts.maxSendMessageSize >= 0 {
+ maxSendMessageSize = min(*mc.MaxReqSize, cc.dopts.maxSendMessageSize)
+ } else if mc.MaxReqSize != nil {
+ maxSendMessageSize = *mc.MaxReqSize
+ } else if mc.MaxReqSize == nil && cc.dopts.maxSendMessageSize >= 0 {
+ maxSendMessageSize = cc.dopts.maxSendMessageSize
+ }
+
+ if mc.MaxRespSize != nil && cc.dopts.maxReceiveMessageSize >= 0 {
+ maxReceiveMessageSize = min(*mc.MaxRespSize, cc.dopts.maxReceiveMessageSize)
+ } else if mc.MaxRespSize != nil {
+ maxReceiveMessageSize = *mc.MaxRespSize
+ } else if mc.MaxRespSize == nil && cc.dopts.maxReceiveMessageSize >= 0 {
+ maxReceiveMessageSize = cc.dopts.maxReceiveMessageSize
+ }
+ } else {
+ if cc.dopts.maxSendMessageSize >= 0 {
+ maxSendMessageSize = cc.dopts.maxSendMessageSize
+ }
+ if cc.dopts.maxReceiveMessageSize >= 0 {
+ maxReceiveMessageSize = cc.dopts.maxReceiveMessageSize
+ }
}
for _, o := range opts {
if err := o.before(&c); err != nil {
@@ -246,7 +277,7 @@
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
}
- stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, callHdr, t, args, topts)
+ stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, maxSendMessageSize, callHdr, t, args, topts)
if err != nil {
if put != nil {
put()
@@ -263,7 +294,7 @@
}
return toRPCErr(err)
}
- err = recvResponse(ctx, cc.dopts, t, &c, stream, reply)
+ err = recvResponse(ctx, cc.dopts, maxReceiveMessageSize, t, &c, stream, reply)
if err != nil {
if put != nil {
put()
diff --git a/clientconn.go b/clientconn.go
index aff4f5c..e01b138 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -36,8 +36,8 @@
import (
"errors"
"fmt"
- "math"
"net"
+ "strings"
"sync"
"time"
@@ -86,30 +86,48 @@
// dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial.
type dialOptions struct {
- unaryInt UnaryClientInterceptor
- streamInt StreamClientInterceptor
- codec Codec
- cp Compressor
- dc Decompressor
- bs backoffStrategy
- balancer Balancer
- block bool
- insecure bool
- timeout time.Duration
- scChan <-chan ServiceConfig
- copts transport.ConnectOptions
- maxMsgSize int
+ unaryInt UnaryClientInterceptor
+ streamInt StreamClientInterceptor
+ codec Codec
+ cp Compressor
+ dc Decompressor
+ bs backoffStrategy
+ balancer Balancer
+ block bool
+ insecure bool
+ timeout time.Duration
+ scChan <-chan ServiceConfig
+ copts transport.ConnectOptions
+ maxReceiveMessageSize int
+ maxSendMessageSize int
}
-const defaultClientMaxMsgSize = math.MaxInt32
+const (
+ defaultClientMaxReceiveMessageSize = 1024 * 1024 * 4
+ defaultClientMaxSendMessageSize = 1024 * 1024 * 4
+ defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4
+ defaultServerMaxSendMessageSize = 1024 * 1024 * 4
+)
// DialOption configures how we set up the connection.
type DialOption func(*dialOptions)
-// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive.
+// WithMaxMsgSize Deprecated: use WithMaxReceiveMessageSize instead.
func WithMaxMsgSize(s int) DialOption {
+ return WithMaxReceiveMessageSize(s)
+}
+
+// WithMaxReceiveMessageSize returns a DialOption which sets the maximum message size the client can receive. Negative input is invalid and has the same effect as not setting the field.
+func WithMaxReceiveMessageSize(s int) DialOption {
return func(o *dialOptions) {
- o.maxMsgSize = s
+ o.maxReceiveMessageSize = s
+ }
+}
+
+// WithMaxSendMessageSize returns a DialOption which sets the maximum message size the client can send. Negative input is invalid and has the same effect as not setting the field.
+func WithMaxSendMessageSize(s int) DialOption {
+ return func(o *dialOptions) {
+ o.maxSendMessageSize = s
}
}
@@ -305,7 +323,11 @@
conns: make(map[Address]*addrConn),
}
cc.ctx, cc.cancel = context.WithCancel(context.Background())
- cc.dopts.maxMsgSize = defaultClientMaxMsgSize
+
+ // initialize maxReceiveMessageSize and maxSendMessageSize to -1 before applying DialOption functions to distinguish whether the user set the message limit or not.
+ cc.dopts.maxReceiveMessageSize = -1
+ cc.dopts.maxSendMessageSize = -1
+
for _, opt := range opts {
opt(&cc.dopts)
}
@@ -337,14 +359,13 @@
}()
if cc.dopts.scChan != nil {
- // Wait for the initial service config.
+ // Try to get an initial service config.
select {
case sc, ok := <-cc.dopts.scChan:
if ok {
cc.sc = sc
}
- case <-ctx.Done():
- return nil, ctx.Err()
+ default:
}
}
// Set defaults.
@@ -616,11 +637,16 @@
return nil
}
+// GetMethodConfig gets the method config of the input method. If there's no exact match for the input method (i.e. /service/method), we will return the default config for all methods under the service (/service/).
// TODO: Avoid the locking here.
-func (cc *ClientConn) getMethodConfig(method string) (m MethodConfig, ok bool) {
+func (cc *ClientConn) GetMethodConfig(method string) (m MethodConfig, ok bool) {
cc.mu.RLock()
defer cc.mu.RUnlock()
m, ok = cc.sc.Methods[method]
+ if !ok {
+ i := strings.LastIndex(method, "/")
+ m, ok = cc.sc.Methods[method[:i+1]]
+ }
return
}
diff --git a/rpc_util.go b/rpc_util.go
index 4d12528..18d1f0d 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -214,7 +214,7 @@
// No other error values or types must be returned, which also means
// that the underlying io.Reader must not return an incompatible
// error.
-func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
+func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
return 0, nil, err
}
@@ -225,8 +225,8 @@
if length == 0 {
return pf, nil, nil
}
- if length > uint32(maxMsgSize) {
- return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize)
+ if int64(length) > int64(maxReceiveMessageSize) {
+ return 0, nil, Errorf(codes.InvalidArgument, "grpc: Received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
}
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message:
@@ -310,8 +310,8 @@
return nil
}
-func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int, inPayload *stats.InPayload) error {
- pf, d, err := p.recvMsg(maxMsgSize)
+func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload) error {
+ pf, d, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return err
}
@@ -327,10 +327,10 @@
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
}
}
- if len(d) > maxMsgSize {
+ if len(d) > maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
- return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize)
+ return Errorf(codes.InvalidArgument, "grpc: Received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize)
}
if err := c.Unmarshal(d, m); err != nil {
return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
@@ -433,24 +433,22 @@
// WaitForReady indicates whether RPCs sent to this method should wait until
// the connection is ready by default (!failfast). The value specified via the
// gRPC client API will override the value set here.
- WaitForReady bool
+ WaitForReady *bool
// Timeout is the default timeout for RPCs sent to this method. The actual
// deadline used will be the minimum of the value specified here and the value
// set by the application via the gRPC client API. If either one is not set,
// then the other will be used. If neither is set, then the RPC has no deadline.
- Timeout time.Duration
+ Timeout *time.Duration
// MaxReqSize is the maximum allowed payload size for an individual request in a
// stream (client->server) in bytes. The size which is measured is the serialized
// payload after per-message compression (but before stream compression) in bytes.
// The actual value used is the minumum of the value specified here and the value set
// by the application via the gRPC client API. If either one is not set, then the other
// will be used. If neither is set, then the built-in default is used.
- // TODO: support this.
- MaxReqSize uint32
+ MaxReqSize *int
// MaxRespSize is the maximum allowed payload size for an individual response in a
// stream (server->client) in bytes.
- // TODO: support this.
- MaxRespSize uint32
+ MaxRespSize *int
}
// ServiceConfig is provided by the service provider and contains parameters for how
@@ -461,6 +459,9 @@
// via grpc.WithBalancer will override this.
LB Balancer
// Methods contains a map for the methods in this service.
+ // If there is an exact match for a method (i.e. /service/method) in the map, use the corresponding MethodConfig.
+ // If there's no exact match, look for the default config for all methods under the service (/service/) and use the corresponding MethodConfig.
+ // Otherwise, the method has no MethodConfig to use.
Methods map[string]MethodConfig
}
@@ -474,3 +475,10 @@
// Version is the current grpc version.
const Version = "1.3.0-dev"
+
+func min(a, b int) int {
+ if a < b {
+ return a
+ }
+ return b
+}
diff --git a/server.go b/server.go
index b15f71c..10fa50d 100644
--- a/server.go
+++ b/server.go
@@ -107,24 +107,24 @@
}
type options struct {
- creds credentials.TransportCredentials
- codec Codec
- cp Compressor
- dc Decompressor
- maxMsgSize int
- unaryInt UnaryServerInterceptor
- streamInt StreamServerInterceptor
- inTapHandle tap.ServerInHandle
- statsHandler stats.Handler
- maxConcurrentStreams uint32
- useHandlerImpl bool // use http.Handler-based server
- unknownStreamDesc *StreamDesc
- keepaliveParams keepalive.ServerParameters
- keepalivePolicy keepalive.EnforcementPolicy
+ creds credentials.TransportCredentials
+ codec Codec
+ cp Compressor
+ dc Decompressor
+ maxMsgSize int
+ unaryInt UnaryServerInterceptor
+ streamInt StreamServerInterceptor
+ inTapHandle tap.ServerInHandle
+ statsHandler stats.Handler
+ maxConcurrentStreams uint32
+ maxReceiveMessageSize int
+ maxSendMessageSize int
+ useHandlerImpl bool // use http.Handler-based server
+ unknownStreamDesc *StreamDesc
+ keepaliveParams keepalive.ServerParameters
+ keepalivePolicy keepalive.EnforcementPolicy
}
-var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
-
// A ServerOption sets options.
type ServerOption func(*options)
@@ -163,11 +163,24 @@
}
}
-// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages.
-// If this is not set, gRPC uses the default 4MB.
+// MaxMsgSize Deprecated: use MaxReceiveMessageSize instead.
func MaxMsgSize(m int) ServerOption {
+ return MaxReceiveMessageSize(m)
+}
+
+// MaxReceiveMessageSize returns a ServerOption to set the max message size in bytes for inbound messages.
+// If this is not set, gRPC uses the default 4MB.
+func MaxReceiveMessageSize(m int) ServerOption {
return func(o *options) {
- o.maxMsgSize = m
+ o.maxReceiveMessageSize = m
+ }
+}
+
+// MaxSendMessageSize returns a ServerOption to set the max message size in bytes for outbound messages.
+// If this is not set, gRPC uses the default 4MB.
+func MaxSendMessageSize(m int) ServerOption {
+ return func(o *options) {
+ o.maxSendMessageSize = m
}
}
@@ -229,7 +242,7 @@
// UnknownServiceHandler returns a ServerOption that allows for adding a custom
// unknown service handler. The provided method is a bidi-streaming RPC service
-// handler that will be invoked instead of returning the the "unimplemented" gRPC
+// handler that will be invoked instead of returning the "unimplemented" gRPC
// error whenever a request is received for an unregistered service or method.
// The handling function has full access to the Context of the request and the
// stream, and the invocation passes through interceptors.
@@ -249,7 +262,8 @@
// started to accept requests yet.
func NewServer(opt ...ServerOption) *Server {
var opts options
- opts.maxMsgSize = defaultMaxMsgSize
+ opts.maxReceiveMessageSize = defaultServerMaxReceiveMessageSize
+ opts.maxSendMessageSize = defaultServerMaxSendMessageSize
for _, o := range opt {
o(&opts)
}
@@ -629,6 +643,9 @@
// the optimal option.
grpclog.Fatalf("grpc: Server failed to encode response %v", err)
}
+ if len(p) > s.opts.maxSendMessageSize {
+ return status.Errorf(codes.InvalidArgument, "grpc: Sent message larger than max (%d vs. %d)", len(p), s.opts.maxSendMessageSize)
+ }
err = t.Write(stream, p, opts)
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
@@ -673,7 +690,7 @@
}
p := &parser{r: stream}
for { // TODO: delete
- pf, req, err := p.recvMsg(s.opts.maxMsgSize)
+ pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize)
if err == io.EOF {
// The entire stream is done (for unary RPC only).
return err
@@ -731,10 +748,10 @@
return Errorf(codes.Internal, err.Error())
}
}
- if len(req) > s.opts.maxMsgSize {
+ if len(req) > s.opts.maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with
// java implementation.
- return status.Errorf(codes.Internal, "grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
+ return status.Errorf(codes.InvalidArgument, "grpc: Received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize)
}
if err := s.opts.codec.Unmarshal(req, v); err != nil {
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
@@ -830,15 +847,16 @@
stream.SetSendCompress(s.opts.cp.Type())
}
ss := &serverStream{
- t: t,
- s: stream,
- p: &parser{r: stream},
- codec: s.opts.codec,
- cp: s.opts.cp,
- dc: s.opts.dc,
- maxMsgSize: s.opts.maxMsgSize,
- trInfo: trInfo,
- statsHandler: sh,
+ t: t,
+ s: stream,
+ p: &parser{r: stream},
+ codec: s.opts.codec,
+ cp: s.opts.cp,
+ dc: s.opts.dc,
+ maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
+ maxSendMessageSize: s.opts.maxSendMessageSize,
+ trInfo: trInfo,
+ statsHandler: sh,
}
if ss.cp != nil {
ss.cbuf = new(bytes.Buffer)
diff --git a/stream.go b/stream.go
index 008ff10..02ec28f 100644
--- a/stream.go
+++ b/stream.go
@@ -113,10 +113,38 @@
cancel context.CancelFunc
)
c := defaultCallInfo
- if mc, ok := cc.getMethodConfig(method); ok {
- c.failFast = !mc.WaitForReady
- if mc.Timeout > 0 {
- ctx, cancel = context.WithTimeout(ctx, mc.Timeout)
+ maxReceiveMessageSize := defaultClientMaxReceiveMessageSize
+ maxSendMessageSize := defaultClientMaxSendMessageSize
+ if mc, ok := cc.GetMethodConfig(method); ok {
+ if mc.WaitForReady != nil {
+ c.failFast = !*mc.WaitForReady
+ }
+
+ if mc.Timeout != nil && *mc.Timeout >= 0 {
+ ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
+ }
+
+ if mc.MaxReqSize != nil && cc.dopts.maxSendMessageSize >= 0 {
+ maxSendMessageSize = min(*mc.MaxReqSize, cc.dopts.maxSendMessageSize)
+ } else if mc.MaxReqSize != nil {
+ maxSendMessageSize = *mc.MaxReqSize
+ } else if mc.MaxReqSize == nil && cc.dopts.maxSendMessageSize >= 0 {
+ maxSendMessageSize = cc.dopts.maxSendMessageSize
+ }
+
+ if mc.MaxRespSize != nil && cc.dopts.maxReceiveMessageSize >= 0 {
+ maxReceiveMessageSize = min(*mc.MaxRespSize, cc.dopts.maxReceiveMessageSize)
+ } else if mc.MaxRespSize != nil {
+ maxReceiveMessageSize = *mc.MaxRespSize
+ } else if mc.MaxRespSize == nil && cc.dopts.maxReceiveMessageSize >= 0 {
+ maxReceiveMessageSize = cc.dopts.maxReceiveMessageSize
+ }
+ } else {
+ if cc.dopts.maxSendMessageSize >= 0 {
+ maxSendMessageSize = cc.dopts.maxSendMessageSize
+ }
+ if cc.dopts.maxReceiveMessageSize >= 0 {
+ maxReceiveMessageSize = cc.dopts.maxReceiveMessageSize
}
}
for _, o := range opts {
@@ -208,14 +236,15 @@
break
}
cs := &clientStream{
- opts: opts,
- c: c,
- desc: desc,
- codec: cc.dopts.codec,
- cp: cc.dopts.cp,
- dc: cc.dopts.dc,
- maxMsgSize: cc.dopts.maxMsgSize,
- cancel: cancel,
+ opts: opts,
+ c: c,
+ desc: desc,
+ codec: cc.dopts.codec,
+ cp: cc.dopts.cp,
+ dc: cc.dopts.dc,
+ maxReceiveMessageSize: maxReceiveMessageSize,
+ maxSendMessageSize: maxSendMessageSize,
+ cancel: cancel,
put: put,
t: t,
@@ -256,18 +285,19 @@
// clientStream implements a client side Stream.
type clientStream struct {
- opts []CallOption
- c callInfo
- t transport.ClientTransport
- s *transport.Stream
- p *parser
- desc *StreamDesc
- codec Codec
- cp Compressor
- cbuf *bytes.Buffer
- dc Decompressor
- maxMsgSize int
- cancel context.CancelFunc
+ opts []CallOption
+ c callInfo
+ t transport.ClientTransport
+ s *transport.Stream
+ p *parser
+ desc *StreamDesc
+ codec Codec
+ cp Compressor
+ cbuf *bytes.Buffer
+ dc Decompressor
+ maxReceiveMessageSize int
+ maxSendMessageSize int
+ cancel context.CancelFunc
tracing bool // set to EnableTracing when the clientStream is created.
@@ -351,6 +381,9 @@
if err != nil {
return Errorf(codes.Internal, "grpc: %v", err)
}
+ if len(out) > cs.maxSendMessageSize {
+ return Errorf(codes.InvalidArgument, "grpc: Sent message larger than max (%d vs. %d)", len(out), cs.maxSendMessageSize)
+ }
err = cs.t.Write(cs.s, out, &transport.Options{Last: false})
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
@@ -366,7 +399,7 @@
Client: true,
}
}
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, inPayload)
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxReceiveMessageSize, inPayload)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
@@ -389,7 +422,7 @@
}
// Special handling for client streaming rpc.
// This recv expects EOF or errors, so we don't collect inPayload.
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, nil)
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxReceiveMessageSize, nil)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@@ -510,15 +543,16 @@
// serverStream implements a server side Stream.
type serverStream struct {
- t transport.ServerTransport
- s *transport.Stream
- p *parser
- codec Codec
- cp Compressor
- dc Decompressor
- cbuf *bytes.Buffer
- maxMsgSize int
- trInfo *traceInfo
+ t transport.ServerTransport
+ s *transport.Stream
+ p *parser
+ codec Codec
+ cp Compressor
+ dc Decompressor
+ cbuf *bytes.Buffer
+ maxReceiveMessageSize int
+ maxSendMessageSize int
+ trInfo *traceInfo
statsHandler stats.Handler
@@ -577,6 +611,9 @@
err = Errorf(codes.Internal, "grpc: %v", err)
return err
}
+ if len(out) > ss.maxSendMessageSize {
+ return Errorf(codes.InvalidArgument, "grpc: Sent message larger than max (%d vs. %d)", len(out), ss.maxSendMessageSize)
+ }
if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil {
return toRPCErr(err)
}
@@ -606,7 +643,7 @@
if ss.statsHandler != nil {
inPayload = &stats.InPayload{}
}
- if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, inPayload); err != nil {
+ if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload); err != nil {
if err == io.EOF {
return err
}
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 54840ee..881b1a4 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -429,20 +429,25 @@
cancel context.CancelFunc
// Configurable knobs, after newTest returns:
- testServer testpb.TestServiceServer // nil means none
- healthServer *health.Server // nil means disabled
- maxStream uint32
- tapHandle tap.ServerInHandle
- maxMsgSize int
- userAgent string
- clientCompression bool
- serverCompression bool
- unaryClientInt grpc.UnaryClientInterceptor
- streamClientInt grpc.StreamClientInterceptor
- unaryServerInt grpc.UnaryServerInterceptor
- streamServerInt grpc.StreamServerInterceptor
- unknownHandler grpc.StreamHandler
- sc <-chan grpc.ServiceConfig
+ testServer testpb.TestServiceServer // nil means none
+ healthServer *health.Server // nil means disabled
+ maxStream uint32
+ tapHandle tap.ServerInHandle
+ maxMsgSize int
+ maxClientReceiveMsgSize int
+ maxClientSendMsgSize int
+ maxServerReceiveMsgSize int
+ maxServerSendMsgSize int
+ userAgent string
+ clientCompression bool
+ serverCompression bool
+ timeout time.Duration
+ unaryClientInt grpc.UnaryClientInterceptor
+ streamClientInt grpc.StreamClientInterceptor
+ unaryServerInt grpc.UnaryServerInterceptor
+ streamServerInt grpc.StreamServerInterceptor
+ unknownHandler grpc.StreamHandler
+ sc <-chan grpc.ServiceConfig
// srv and srvAddr are set once startServer is called.
srv *grpc.Server
@@ -478,6 +483,12 @@
t: t,
e: e,
maxStream: math.MaxUint32,
+ // Default value 0 is meaningful (0 byte msg size limit), thus using -1 to indiciate the field is unset.
+ maxClientReceiveMsgSize: -1,
+ maxClientSendMsgSize: -1,
+ maxServerReceiveMsgSize: -1,
+ maxServerSendMsgSize: -1,
+ maxMsgSize: -1,
}
te.ctx, te.cancel = context.WithCancel(context.Background())
return te
@@ -489,9 +500,15 @@
te.testServer = ts
te.t.Logf("Running test in %s environment...", te.e.name)
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)}
- if te.maxMsgSize > 0 {
+ if te.maxMsgSize >= 0 {
sopts = append(sopts, grpc.MaxMsgSize(te.maxMsgSize))
}
+ if te.maxServerReceiveMsgSize >= 0 {
+ sopts = append(sopts, grpc.MaxReceiveMessageSize(te.maxServerReceiveMsgSize))
+ }
+ if te.maxServerSendMsgSize >= 0 {
+ sopts = append(sopts, grpc.MaxSendMessageSize(te.maxServerSendMsgSize))
+ }
if te.tapHandle != nil {
sopts = append(sopts, grpc.InTapHandle(te.tapHandle))
}
@@ -583,9 +600,18 @@
if te.streamClientInt != nil {
opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt))
}
- if te.maxMsgSize > 0 {
+ if te.maxMsgSize >= 0 {
opts = append(opts, grpc.WithMaxMsgSize(te.maxMsgSize))
}
+ if te.maxClientReceiveMsgSize >= 0 {
+ opts = append(opts, grpc.WithMaxReceiveMessageSize(te.maxClientReceiveMsgSize))
+ }
+ if te.maxClientSendMsgSize >= 0 {
+ opts = append(opts, grpc.WithMaxSendMessageSize(te.maxClientSendMsgSize))
+ }
+ if te.timeout > 0 {
+ opts = append(opts, grpc.WithTimeout(te.timeout))
+ }
switch te.e.security {
case "tls":
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
@@ -1043,13 +1069,17 @@
func TestServiceConfig(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
- testServiceConfig(t, e)
+ testGetMethodConfig(t, e)
+ testServiceConfigWaitForReady(t, e)
+ // Timeout logic (min of service config and client API) is implemented implicitly in context. WithTimeout(). No need to test here.
+ testServiceConfigMaxMsgSize(t, e)
}
}
-func testServiceConfig(t *testing.T, e env) {
+func testServiceConfigSetup(t *testing.T, e env) (*test, chan grpc.ServiceConfig) {
te := newTest(t, e)
- ch := make(chan grpc.ServiceConfig)
+ // We write before read.
+ ch := make(chan grpc.ServiceConfig, 1)
te.sc = ch
te.userAgent = testAppUA
te.declareLogNoise(
@@ -1058,37 +1088,78 @@
"grpc: addrConn.resetTransport failed to create client transport: connection error",
"Failed to dial : context canceled; please retry.",
)
+ return te, ch
+}
+
+func newBool(b bool) (a *bool) {
+ a = new(bool)
+ *a = b
+ return
+}
+
+func newInt(b int) (a *int) {
+ a = new(int)
+ *a = b
+ return
+}
+
+func newDuration(b time.Duration) (a *time.Duration) {
+ a = new(time.Duration)
+ *a = b
+ return
+}
+
+func testGetMethodConfig(t *testing.T, e env) {
+ te, ch := testServiceConfigSetup(t, e)
defer te.tearDown()
- var wg sync.WaitGroup
- wg.Add(1)
- go func() {
- defer wg.Done()
- mc := grpc.MethodConfig{
- WaitForReady: true,
- Timeout: time.Millisecond,
- }
- m := make(map[string]grpc.MethodConfig)
- m["/grpc.testing.TestService/EmptyCall"] = mc
- m["/grpc.testing.TestService/FullDuplexCall"] = mc
- sc := grpc.ServiceConfig{
- Methods: m,
- }
- ch <- sc
- }()
+ mc1 := grpc.MethodConfig{
+ WaitForReady: newBool(true),
+ Timeout: newDuration(time.Millisecond),
+ }
+ mc2 := grpc.MethodConfig{WaitForReady: newBool(false)}
+ m := make(map[string]grpc.MethodConfig)
+ m["/grpc.testing.TestService/EmptyCall"] = mc1
+ m["/grpc.testing.TestService/"] = mc2
+ sc := grpc.ServiceConfig{
+ Methods: m,
+ }
+ ch <- sc
+
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
// The following RPCs are expected to become non-fail-fast ones with 1ms deadline.
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
}
- if _, err := tc.FullDuplexCall(context.Background()); grpc.Code(err) != codes.DeadlineExceeded {
- t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.DeadlineExceeded)
+
+ m = make(map[string]grpc.MethodConfig)
+ m["/grpc.testing.TestService/UnaryCall"] = mc1
+ m["/grpc.testing.TestService/"] = mc2
+ sc = grpc.ServiceConfig{
+ Methods: m,
}
- wg.Wait()
- // Generate a service config update.
+ ch <- sc
+ // Wait for the new service config to propagate.
+ for {
+ if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) == codes.DeadlineExceeded {
+ continue
+ }
+ break
+ }
+ if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
+ t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.Unavailable)
+ }
+}
+
+// testServiceConfigWaitForReady checks how the FailFast call option combines
+// with the service config's wait_for_ready setting. The service config also
+// sets a 1ms timeout, so an RPC that waits for readiness is expected to end
+// with codes.DeadlineExceeded (presumably because no server is reachable here:
+// unlike testServiceConfigMaxMsgSize, this test never calls te.startServer).
+func testServiceConfigWaitForReady(t *testing.T, e env) {
+ te, ch := testServiceConfigSetup(t, e)
+ defer te.tearDown()
+
+ // Case1: the client sets FailFast(false) and the service config sets wait_for_ready to false; the client API wins, so the RPC waits until the deadline is exceeded.
 mc := grpc.MethodConfig{
- WaitForReady: false,
+ WaitForReady: newBool(false),
+ Timeout: newDuration(time.Millisecond),
 }
 m := make(map[string]grpc.MethodConfig)
 m["/grpc.testing.TestService/EmptyCall"] = mc
@@ -1097,19 +1168,506 @@
 Methods: m,
 }
 ch <- sc
- // Loop until the new update becomes effective.
+
+ cc := te.clientConn()
+ tc := testpb.NewTestServiceClient(cc)
+ // The following RPCs are expected to become non-fail-fast ones with 1ms deadline.
+ if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
+ t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
+ }
+ if _, err := tc.FullDuplexCall(context.Background(), grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
+ t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.DeadlineExceeded)
+ }
+
+ // Generate a service config update.
+ // Case2: the client does not set FailFast and the service config sets wait_for_ready to true, so the RPC waits until the deadline is exceeded.
+ mc.WaitForReady = newBool(true)
+ m = make(map[string]grpc.MethodConfig)
+ m["/grpc.testing.TestService/EmptyCall"] = mc
+ m["/grpc.testing.TestService/FullDuplexCall"] = mc
+ sc = grpc.ServiceConfig{
+ Methods: m,
+ }
+ ch <- sc
+
+ // Wait for the new service config to take effect.
+ // Poll the client's method config every 100ms until WaitForReady reads true.
+ // NOTE(review): this loop has no upper bound, and it dereferences
+ // mc.WaitForReady without a nil check; both hold only because every config
+ // pushed in this test sets WaitForReady.
+ mc, ok := cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall")
 for {
- if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
+ if ok && !*mc.WaitForReady {
+ time.Sleep(100 * time.Millisecond)
+ mc, ok = cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall")
 continue
 }
 break
 }
- // The following RPCs are expected to become fail-fast.
- if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable {
- t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.Unavailable)
+ // The following RPCs are expected to become non-fail-fast ones with 1ms deadline.
+ if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded {
+ t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded)
 }
- if _, err := tc.FullDuplexCall(context.Background()); grpc.Code(err) != codes.Unavailable {
- t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.Unavailable)
+ if _, err := tc.FullDuplexCall(context.Background()); grpc.Code(err) != codes.DeadlineExceeded {
+ t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.DeadlineExceeded)
+ }
+}
+
+// testServiceConfigMaxMsgSize checks that the service config's per-method
+// MaxReqSize/MaxRespSize limits are enforced for unary and streaming RPCs, and
+// how they combine with the client-side maxClientSendMsgSize and
+// maxClientReceiveMsgSize settings. The expectations below encode that the
+// smaller of the two limits is the one that applies:
+//   - Case1: service config only (2048 send / 2048 recv).
+//   - Case2: client API 1024/1024 vs service config 2048/2048 -> 1024 applies.
+//   - Case3: client API 4096/4096 vs service config 2048/2048 -> 2048 applies.
+//
+// Every rejected message is expected to surface as codes.InvalidArgument.
+func testServiceConfigMaxMsgSize(t *testing.T, e env) {
+	// Setting up values and objects shared across all test cases.
+	const (
+		smallSize      = 1
+		largeSize      = 1024
+		extraLargeSize = 2048
+	)
+
+	smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize)
+	if err != nil {
+		t.Fatal(err)
+	}
+	largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize)
+	if err != nil {
+		t.Fatal(err)
+	}
+	extraLargePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, extraLargeSize)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	mc := grpc.MethodConfig{
+		MaxReqSize:  newInt(extraLargeSize),
+		MaxRespSize: newInt(extraLargeSize),
+	}
+
+	m := make(map[string]grpc.MethodConfig)
+	m["/grpc.testing.TestService/UnaryCall"] = mc
+	m["/grpc.testing.TestService/FullDuplexCall"] = mc
+	sc := grpc.ServiceConfig{
+		Methods: m,
+	}
+
+	// Case1: the service config sets maxReqSize to 2048 (send) and maxRespSize
+	// to 2048 (recv); no client API limits are set.
+	te1, ch1 := testServiceConfigSetup(t, e)
+	te1.startServer(&testServer{security: e.security})
+	defer te1.tearDown()
+
+	ch1 <- sc
+	tc := testpb.NewTestServiceClient(te1.clientConn())
+
+	req := &testpb.SimpleRequest{
+		ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+		ResponseSize: proto.Int32(int32(extraLargeSize)),
+		Payload:      smallPayload,
+	}
+	// Test for unary RPC recv. A 2048-byte payload plus message overhead is
+	// larger than the 2048-byte limit.
+	if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
+	}
+
+	// Test for unary RPC send.
+	req.Payload = extraLargePayload
+	req.ResponseSize = proto.Int32(int32(smallSize))
+	if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
+	}
+
+	// Test for streaming RPC recv.
+	// sreq holds respParam's *ResponseParameters, so later writes to
+	// respParam[0].Size also change what sreq asks for.
+	respParam := []*testpb.ResponseParameters{
+		{
+			Size: proto.Int32(int32(extraLargeSize)),
+		},
+	}
+	sreq := &testpb.StreamingOutputCallRequest{
+		ResponseType:       testpb.PayloadType_COMPRESSABLE.Enum(),
+		ResponseParameters: respParam,
+		Payload:            smallPayload,
+	}
+	stream, err := tc.FullDuplexCall(te1.ctx)
+	if err != nil {
+		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+	}
+	if err := stream.Send(sreq); err != nil {
+		t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
+	}
+	if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument)
+	}
+
+	// Test for streaming RPC send.
+	respParam[0].Size = proto.Int32(int32(smallSize))
+	sreq.Payload = extraLargePayload
+	stream, err = tc.FullDuplexCall(te1.ctx)
+	if err != nil {
+		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+	}
+	if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.InvalidArgument)
+	}
+
+	// Case2: the client API sets maxReqSize to 1024 (send) and maxRespSize to
+	// 1024 (recv); the service config sets 2048/2048. The 1024 limits apply.
+	te2, ch2 := testServiceConfigSetup(t, e)
+	te2.maxClientReceiveMsgSize = 1024
+	te2.maxClientSendMsgSize = 1024
+	te2.startServer(&testServer{security: e.security})
+	defer te2.tearDown()
+	ch2 <- sc
+	tc = testpb.NewTestServiceClient(te2.clientConn())
+
+	// Test for unary RPC recv.
+	req.Payload = smallPayload
+	req.ResponseSize = proto.Int32(int32(largeSize))
+
+	if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
+	}
+
+	// Test for unary RPC send.
+	req.Payload = largePayload
+	req.ResponseSize = proto.Int32(int32(smallSize))
+	if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
+	}
+
+	// Test for streaming RPC recv.
+	respParam[0].Size = proto.Int32(int32(largeSize))
+	sreq.Payload = smallPayload
+	stream, err = tc.FullDuplexCall(te2.ctx)
+	if err != nil {
+		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+	}
+	if err := stream.Send(sreq); err != nil {
+		t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
+	}
+	if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument)
+	}
+
+	// Test for streaming RPC send.
+	respParam[0].Size = proto.Int32(int32(smallSize))
+	sreq.Payload = largePayload
+	stream, err = tc.FullDuplexCall(te2.ctx)
+	if err != nil {
+		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+	}
+	if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.InvalidArgument)
+	}
+
+	// Case3: the client API sets maxReqSize to 4096 (send) and maxRespSize to
+	// 4096 (recv); the service config sets 2048/2048. The 2048 limits apply, so
+	// 1024-byte payloads pass and 2048-byte payloads are rejected.
+	te3, ch3 := testServiceConfigSetup(t, e)
+	te3.maxClientReceiveMsgSize = 4096
+	te3.maxClientSendMsgSize = 4096
+	te3.startServer(&testServer{security: e.security})
+	defer te3.tearDown()
+	ch3 <- sc
+	tc = testpb.NewTestServiceClient(te3.clientConn())
+
+	// Test for unary RPC recv.
+	req.Payload = smallPayload
+	req.ResponseSize = proto.Int32(int32(largeSize))
+
+	if _, err := tc.UnaryCall(context.Background(), req); err != nil {
+		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want <nil>", err)
+	}
+
+	req.ResponseSize = proto.Int32(int32(extraLargeSize))
+	if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
+	}
+
+	// Test for unary RPC send.
+	req.Payload = largePayload
+	req.ResponseSize = proto.Int32(int32(smallSize))
+	if _, err := tc.UnaryCall(context.Background(), req); err != nil {
+		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want <nil>", err)
+	}
+
+	req.Payload = extraLargePayload
+	if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
+	}
+
+	// Test for streaming RPC recv.
+	stream, err = tc.FullDuplexCall(te3.ctx)
+	if err != nil {
+		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+	}
+	respParam[0].Size = proto.Int32(int32(largeSize))
+	sreq.Payload = smallPayload
+
+	if err := stream.Send(sreq); err != nil {
+		t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
+	}
+	if _, err := stream.Recv(); err != nil {
+		t.Fatalf("%v.Recv() = _, %v, want <nil>", stream, err)
+	}
+
+	respParam[0].Size = proto.Int32(int32(extraLargeSize))
+
+	if err := stream.Send(sreq); err != nil {
+		t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
+	}
+	if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument)
+	}
+
+	// Test for streaming RPC send.
+	respParam[0].Size = proto.Int32(int32(smallSize))
+	sreq.Payload = largePayload
+	stream, err = tc.FullDuplexCall(te3.ctx)
+	if err != nil {
+		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+	}
+	if err := stream.Send(sreq); err != nil {
+		t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
+	}
+	sreq.Payload = extraLargePayload
+	if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.InvalidArgument)
+	}
+}
+
+// TestMsgSizeDefaultAndAPI runs, for every test environment and in order, the
+// client-default, client-API and server-API message size limit checks.
+func TestMsgSizeDefaultAndAPI(t *testing.T) {
+	defer leakCheck(t)()
+	checks := []func(*testing.T, env){
+		testMaxMsgSizeClientDefault,
+		testMaxMsgSizeClientAPI,
+		testMaxMsgSizeServerAPI,
+	}
+	for _, e := range listTestEnv() {
+		for _, check := range checks {
+			check(t, e)
+		}
+	}
+}
+
+// testMaxMsgSizeClientDefault checks that, with no client-side size options
+// set, a 4MB response cannot be received and a 4MB request cannot be sent, for
+// both unary and streaming RPCs; each failure must carry codes.InvalidArgument.
+//
+// NOTE(review): this relies on the client's default send and receive limits
+// being at most 4MB, so that a 4MB payload plus message overhead exceeds them.
+// Confirm against the defaults in the grpc package.
+func testMaxMsgSizeClientDefault(t *testing.T, e env) {
+	te := newTest(t, e)
+	te.userAgent = testAppUA
+	// Raise the server's send limit above 4MB so that the oversized response is
+	// rejected by the client's receive limit rather than by the server.
+	te.maxServerSendMsgSize = 5 * 1024 * 1024
+	te.declareLogNoise(
+		"transport: http2Client.notifyError got notified that the client transport was broken EOF",
+		"grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing",
+		"grpc: addrConn.resetTransport failed to create client transport: connection error",
+		"Failed to dial : context canceled; please retry.",
+	)
+	te.startServer(&testServer{security: e.security})
+
+	defer te.tearDown()
+	tc := testpb.NewTestServiceClient(te.clientConn())
+
+	const (
+		smallSize = 1
+		largeSize = 4 * 1024 * 1024
+	)
+	smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize)
+	if err != nil {
+		t.Fatal(err)
+	}
+	req := &testpb.SimpleRequest{
+		ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+		ResponseSize: proto.Int32(int32(largeSize)),
+		Payload:      smallPayload,
+	}
+	// Test for unary RPC recv.
+	if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
+	}
+
+	// Test for unary RPC send.
+	req.Payload = largePayload
+	req.ResponseSize = proto.Int32(int32(smallSize))
+	if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
+	}
+
+	// sreq holds respParam's *ResponseParameters, so later writes to
+	// respParam[0].Size also change what sreq asks for.
+	respParam := []*testpb.ResponseParameters{
+		{
+			Size: proto.Int32(int32(largeSize)),
+		},
+	}
+	sreq := &testpb.StreamingOutputCallRequest{
+		ResponseType:       testpb.PayloadType_COMPRESSABLE.Enum(),
+		ResponseParameters: respParam,
+		Payload:            smallPayload,
+	}
+
+	// Test for streaming RPC recv.
+	stream, err := tc.FullDuplexCall(te.ctx)
+	if err != nil {
+		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+	}
+	if err := stream.Send(sreq); err != nil {
+		t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
+	}
+	if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument)
+	}
+
+	// Test for streaming RPC send.
+	respParam[0].Size = proto.Int32(int32(smallSize))
+	sreq.Payload = largePayload
+	stream, err = tc.FullDuplexCall(te.ctx)
+	if err != nil {
+		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+	}
+	if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.InvalidArgument {
+		t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.InvalidArgument)
+	}
+}
+
+// testMaxMsgSizeClientAPI checks that client-side limits set through the API
+// (maxClientSendMsgSize and maxClientReceiveMsgSize, both 1024 here) reject
+// oversized unary and streaming messages with codes.InvalidArgument.
+func testMaxMsgSizeClientAPI(t *testing.T, e env) {
+ te := newTest(t, e)
+ te.userAgent = testAppUA
+ // Raise the server's send limit so oversized responses are rejected by the client's 1024-byte receive limit, not by the server.
+ te.maxServerSendMsgSize = 5 * 1024 * 1024
+ te.maxClientReceiveMsgSize = 1024
+ te.maxClientSendMsgSize = 1024
+ te.declareLogNoise(
+ "transport: http2Client.notifyError got notified that the client transport was broken EOF",
+ "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing",
+ "grpc: addrConn.resetTransport failed to create client transport: connection error",
+ "Failed to dial : context canceled; please retry.",
+ )
+ te.startServer(&testServer{security: e.security})
+
+ defer te.tearDown()
+ tc := testpb.NewTestServiceClient(te.clientConn())
+
+ // A largeSize (1024-byte) payload plus message overhead exceeds the 1024-byte client limits.
+ const smallSize = 1
+ const largeSize = 1024
+ smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req := &testpb.SimpleRequest{
+ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+ ResponseSize: proto.Int32(int32(largeSize)),
+ Payload: smallPayload,
+ }
+ // Test for unary RPC recv.
+ if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
+ }
+
+ // Test for unary RPC send.
+ req.Payload = largePayload
+ req.ResponseSize = proto.Int32(int32(smallSize))
+ if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
+ }
+
+ // sreq holds respParam's *ResponseParameters, so later writes to respParam[0].Size also change sreq.
+ respParam := []*testpb.ResponseParameters{
+ {
+ Size: proto.Int32(int32(largeSize)),
+ },
+ }
+ sreq := &testpb.StreamingOutputCallRequest{
+ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+ ResponseParameters: respParam,
+ Payload: smallPayload,
+ }
+
+ // Test for streaming RPC recv.
+ stream, err := tc.FullDuplexCall(te.ctx)
+ if err != nil {
+ t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+ }
+ if err := stream.Send(sreq); err != nil {
+ t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
+ }
+ if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument)
+ }
+
+ // Test for streaming RPC send.
+ respParam[0].Size = proto.Int32(int32(smallSize))
+ sreq.Payload = largePayload
+ stream, err = tc.FullDuplexCall(te.ctx)
+ if err != nil {
+ t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+ }
+ if err := stream.Send(sreq); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("%v.Send(%v) = %v, want _, error code: %s", stream, sreq, err, codes.InvalidArgument)
+ }
+}
+
+// testMaxMsgSizeServerAPI checks that server-side limits
+// (maxServerReceiveMsgSize and maxServerSendMsgSize, both 1024 here) reject
+// oversized unary and streaming messages. The client is expected to observe
+// each rejection as codes.InvalidArgument.
+func testMaxMsgSizeServerAPI(t *testing.T, e env) {
+ te := newTest(t, e)
+ te.userAgent = testAppUA
+ te.maxServerReceiveMsgSize = 1024
+ te.maxServerSendMsgSize = 1024
+ te.declareLogNoise(
+ "transport: http2Client.notifyError got notified that the client transport was broken EOF",
+ "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing",
+ "grpc: addrConn.resetTransport failed to create client transport: connection error",
+ "Failed to dial : context canceled; please retry.",
+ )
+ te.startServer(&testServer{security: e.security})
+
+ defer te.tearDown()
+ tc := testpb.NewTestServiceClient(te.clientConn())
+
+ // A largeSize (1024-byte) payload plus message overhead exceeds the 1024-byte server limits.
+ const smallSize = 1
+ const largeSize = 1024
+ smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ largePayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, largeSize)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req := &testpb.SimpleRequest{
+ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+ ResponseSize: proto.Int32(int32(largeSize)),
+ Payload: smallPayload,
+ }
+ // Test for unary RPC send: the server's send limit rejects the large response.
+ if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
+ }
+
+ // Test for unary RPC recv: the server's receive limit rejects the large request.
+ req.Payload = largePayload
+ req.ResponseSize = proto.Int32(int32(smallSize))
+ if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
+ }
+
+ // sreq holds respParam's *ResponseParameters, so later writes to respParam[0].Size also change sreq.
+ respParam := []*testpb.ResponseParameters{
+ {
+ Size: proto.Int32(int32(largeSize)),
+ },
+ }
+ sreq := &testpb.StreamingOutputCallRequest{
+ ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
+ ResponseParameters: respParam,
+ Payload: smallPayload,
+ }
+
+ // Test for streaming RPC send: the server's send limit rejects the large response.
+ stream, err := tc.FullDuplexCall(te.ctx)
+ if err != nil {
+ t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+ }
+ if err := stream.Send(sreq); err != nil {
+ t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
+ }
+ if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument)
+ }
+
+ // Test for streaming RPC recv: the server's receive limit rejects the large request.
+ respParam[0].Size = proto.Int32(int32(smallSize))
+ sreq.Payload = largePayload
+ stream, err = tc.FullDuplexCall(te.ctx)
+ if err != nil {
+ t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+ }
+ if err := stream.Send(sreq); err != nil {
+ t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
+ }
+ if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument)
}
}
@@ -1429,6 +1987,7 @@
}
}
+// TestExceedMsgLimit tests the backward-compatible API for setting the message size limit.
func TestExceedMsgLimit(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
@@ -1455,23 +2014,23 @@
t.Fatal(err)
}
- // test on server side for unary RPC
+ // Test on server side for unary RPC.
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(smallSize),
Payload: payload,
}
- if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
- t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal)
+ if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
}
- // test on client side for unary RPC
+ // Test on client side for unary RPC.
req.ResponseSize = proto.Int32(int32(te.maxMsgSize) + 1)
req.Payload = smallPayload
- if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
- t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal)
+ if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.InvalidArgument)
}
- // test on server side for streaming RPC
+ // Test on server side for streaming RPC.
stream, err := tc.FullDuplexCall(te.ctx)
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
@@ -1495,11 +2054,11 @@
if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
}
- if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
- t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal)
+ if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument)
}
- // test on client side for streaming RPC
+ // Test on client side for streaming RPC.
stream, err = tc.FullDuplexCall(te.ctx)
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
@@ -1509,8 +2068,8 @@
if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
}
- if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
- t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal)
+ if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.InvalidArgument {
+ t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.InvalidArgument)
}
}