make max size a pointer type and initialize function a CallOption
diff --git a/call.go b/call.go
index a512060..688efed 100644
--- a/call.go
+++ b/call.go
@@ -52,7 +52,7 @@
//
// TODO(zhaoq): Check whether the received message sequence is valid.
// TODO ctx is used for stats collection and processing. It is the context passed from the application.
-func recvResponse(ctx context.Context, dopts dialOptions, msgSizeLimit int, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) {
+func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) {
// Try to acquire header metadata from the server if there is any.
defer func() {
if err != nil {
@@ -73,7 +73,7 @@
}
}
for {
- if err = recv(p, dopts.codec, stream, dopts.dc, reply, msgSizeLimit, inPayload); err != nil {
+ if err = recv(p, dopts.codec, stream, dopts.dc, reply, *c.maxReceiveMessageSize, inPayload); err != nil {
if err == io.EOF {
break
}
@@ -93,7 +93,7 @@
}
// sendRequest writes out various information of an RPC such as Context and Message.
-func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, msgSizeLimit int, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
+func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, c *callInfo, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
stream, err := t.NewStream(ctx, callHdr)
if err != nil {
return nil, err
@@ -122,8 +122,8 @@
if err != nil {
return nil, Errorf(codes.Internal, "grpc: %v", err)
}
- if len(outBuf) > msgSizeLimit {
- return nil, Errorf(codes.ResourceExhausted, "Sent message larger than max (%d vs. %d)", len(outBuf), msgSizeLimit)
+ if len(outBuf) > *c.maxSendMessageSize {
+ return nil, Errorf(codes.ResourceExhausted, "Sent message larger than max (%d vs. %d)", len(outBuf), *c.maxSendMessageSize)
}
err = t.Write(stream, outBuf, opts)
if err == nil && outPayload != nil {
@@ -152,7 +152,7 @@
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) {
c := defaultCallInfo
- mc := cc.GetMethodConfig(method)
+ mc, _ := cc.GetMethodConfig(method)
if mc.WaitForReady != nil {
c.failFast = !*mc.WaitForReady
}
@@ -163,9 +163,7 @@
defer cancel()
}
- maxSendMessageSize := getMaxSize(mc.MaxReqSize, cc.dopts.maxSendMessageSize, defaultClientMaxSendMessageSize)
- maxReceiveMessageSize := getMaxSize(mc.MaxRespSize, cc.dopts.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
-
+ opts = append(cc.dopts.callOptions, opts...)
for _, o := range opts {
if err := o.before(&c); err != nil {
return toRPCErr(err)
@@ -176,6 +174,10 @@
o.after(&c)
}
}()
+
+ c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize)
+ c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
+
if EnableTracing {
c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
defer c.traceInfo.tr.Finish()
@@ -255,7 +257,7 @@
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
}
- stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, maxSendMessageSize, callHdr, t, args, topts)
+ stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, &c, callHdr, t, args, topts)
if err != nil {
if put != nil {
put()
@@ -272,7 +274,7 @@
}
return toRPCErr(err)
}
- err = recvResponse(ctx, cc.dopts, maxReceiveMessageSize, t, &c, stream, reply)
+ err = recvResponse(ctx, cc.dopts, t, &c, stream, reply)
if err != nil {
if put != nil {
put()
diff --git a/clientconn.go b/clientconn.go
index 6556c3a..029899f 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -86,20 +86,19 @@
// dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial.
type dialOptions struct {
- unaryInt UnaryClientInterceptor
- streamInt StreamClientInterceptor
- codec Codec
- cp Compressor
- dc Decompressor
- bs backoffStrategy
- balancer Balancer
- block bool
- insecure bool
- timeout time.Duration
- scChan <-chan ServiceConfig
- copts transport.ConnectOptions
- maxReceiveMessageSize *int
- maxSendMessageSize *int
+ unaryInt UnaryClientInterceptor
+ streamInt StreamClientInterceptor
+ codec Codec
+ cp Compressor
+ dc Decompressor
+ bs backoffStrategy
+ balancer Balancer
+ block bool
+ insecure bool
+ timeout time.Duration
+ scChan <-chan ServiceConfig
+ copts transport.ConnectOptions
+ callOptions []CallOption
}
const (
@@ -114,20 +113,13 @@
// WithMaxMsgSize Deprecated: use WithMaxReceiveMessageSize instead.
func WithMaxMsgSize(s int) DialOption {
- return WithMaxReceiveMessageSize(s)
+ return WithDefaultCallOptions(WithMaxReceiveMessageSize(s))
}
-// WithMaxReceiveMessageSize returns a DialOption which sets the maximum message size the client can receive. Negative input is invalid and has the same effect as not setting the field.
-func WithMaxReceiveMessageSize(s int) DialOption {
+// WithDefaultCallOptions returns a DialOption which sets the default CallOptions for calls over the connection.
+func WithDefaultCallOptions(cos ...CallOption) DialOption {
return func(o *dialOptions) {
- *o.maxReceiveMessageSize = s
- }
-}
-
-// WithMaxSendMessageSize returns a DialOption which sets the maximum message size the client can send. Negative input is invalid and has the same effect as not seeting the field.
-func WithMaxSendMessageSize(s int) DialOption {
- return func(o *dialOptions) {
- *o.maxSendMessageSize = s
+ o.callOptions = append(o.callOptions, cos...)
}
}
@@ -642,13 +634,13 @@
// GetMethodConfig gets the method config of the input method. If there's no exact match for the input method (i.e. /service/method), we will return the default config for all methods under the service (/service/).
// TODO: Avoid the locking here.
-func (cc *ClientConn) GetMethodConfig(method string) (m MethodConfig) {
+func (cc *ClientConn) GetMethodConfig(method string) (m MethodConfig, ok bool) {
cc.mu.RLock()
defer cc.mu.RUnlock()
- m, ok := cc.sc.Methods[method]
+ m, ok = cc.sc.Methods[method]
if !ok {
i := strings.LastIndex(method, "/")
- m, _ = cc.sc.Methods[method[:i+1]]
+ m, ok = cc.sc.Methods[method[:i+1]]
}
return
}
diff --git a/rpc_util.go b/rpc_util.go
index f33d504..606294b 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -111,11 +111,13 @@
// callInfo contains all related configuration and information about an RPC.
type callInfo struct {
- failFast bool
- headerMD metadata.MD
- trailerMD metadata.MD
- peer *peer.Peer
- traceInfo traceInfo // in trace.go
+ failFast bool
+ headerMD metadata.MD
+ trailerMD metadata.MD
+ peer *peer.Peer
+ traceInfo traceInfo // in trace.go
+ maxReceiveMessageSize *int
+ maxSendMessageSize *int
}
var defaultCallInfo = callInfo{failFast: true}
@@ -181,6 +183,22 @@
})
}
+// WithMaxReceiveMessageSize returns a DialOption which sets the maximum message size the client can receive. Negative input is invalid and has the same effect as not setting the field.
+func WithMaxReceiveMessageSize(s int) CallOption {
+ return beforeCall(func(o *callInfo) error {
+ o.maxReceiveMessageSize = &s
+ return nil
+ })
+}
+
+// WithMaxSendMessageSize returns a DialOption which sets the maximum message size the client can send. Negative input is invalid and has the same effect as not seeting the field.
+func WithMaxSendMessageSize(s int) CallOption {
+ return beforeCall(func(o *callInfo) error {
+ o.maxSendMessageSize = &s
+ return nil
+ })
+}
+
// The format of the payload: compressed or not?
type payloadFormat uint8
@@ -476,24 +494,24 @@
// Version is the current grpc version.
const Version = "1.3.0-dev"
-func min(a, b int) int {
- if a < b {
+func min(a, b *int) *int {
+ if *a < *b {
return a
}
return b
}
-func getMaxSize(mcMax, doptMax *int, defaultVal int) int {
+func getMaxSize(mcMax, doptMax *int, defaultVal int) *int {
if mcMax == nil && doptMax == nil {
- return defaultVal
+ return &defaultVal
}
if mcMax != nil && doptMax != nil {
- return min(*mcMax, *doptMax)
+ return min(mcMax, doptMax)
}
if mcMax != nil {
- return *mcMax
+ return mcMax
}
- return *doptMax
+ return doptMax
}
const grpcUA = "grpc-go/" + Version
diff --git a/stream.go b/stream.go
index cfe7b44..8dcd062 100644
--- a/stream.go
+++ b/stream.go
@@ -113,7 +113,7 @@
cancel context.CancelFunc
)
c := defaultCallInfo
- mc := cc.GetMethodConfig(method)
+ mc, _ := cc.GetMethodConfig(method)
if mc.WaitForReady != nil {
c.failFast = !*mc.WaitForReady
}
@@ -124,13 +124,15 @@
defer cancel()
}
- maxSendMessageSize := getMaxSize(mc.MaxReqSize, cc.dopts.maxSendMessageSize, defaultClientMaxSendMessageSize)
- maxReceiveMessageSize := getMaxSize(mc.MaxRespSize, cc.dopts.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
+ opts = append(cc.dopts.callOptions, opts...)
for _, o := range opts {
if err := o.before(&c); err != nil {
return nil, toRPCErr(err)
}
}
+ c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize)
+ c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
+
callHdr := &transport.CallHdr{
Host: cc.authority,
Method: method,
@@ -221,8 +223,8 @@
codec: cc.dopts.codec,
cp: cc.dopts.cp,
dc: cc.dopts.dc,
- maxReceiveMessageSize: maxReceiveMessageSize,
- maxSendMessageSize: maxSendMessageSize,
+ maxReceiveMessageSize: *c.maxReceiveMessageSize,
+ maxSendMessageSize: *c.maxSendMessageSize,
cancel: cancel,
put: put,
diff --git a/test/end2end_test.go b/test/end2end_test.go
index c85c33d..bc3c8fa 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -433,11 +433,11 @@
healthServer *health.Server // nil means disabled
maxStream uint32
tapHandle tap.ServerInHandle
- maxMsgSize int
- maxClientReceiveMsgSize int
- maxClientSendMsgSize int
- maxServerReceiveMsgSize int
- maxServerSendMsgSize int
+ maxMsgSize *int
+ maxClientReceiveMsgSize *int
+ maxClientSendMsgSize *int
+ maxServerReceiveMsgSize *int
+ maxServerSendMsgSize *int
userAgent string
clientCompression bool
serverCompression bool
@@ -483,12 +483,6 @@
t: t,
e: e,
maxStream: math.MaxUint32,
- // Default value 0 is meaningful (0 byte msg size limit), thus using -1 to indiciate the field is unset.
- maxClientReceiveMsgSize: -1,
- maxClientSendMsgSize: -1,
- maxServerReceiveMsgSize: -1,
- maxServerSendMsgSize: -1,
- maxMsgSize: -1,
}
te.ctx, te.cancel = context.WithCancel(context.Background())
return te
@@ -500,14 +494,14 @@
te.testServer = ts
te.t.Logf("Running test in %s environment...", te.e.name)
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)}
- if te.maxMsgSize >= 0 {
- sopts = append(sopts, grpc.MaxMsgSize(te.maxMsgSize))
+ if te.maxMsgSize != nil {
+ sopts = append(sopts, grpc.MaxMsgSize(*te.maxMsgSize))
}
- if te.maxServerReceiveMsgSize >= 0 {
- sopts = append(sopts, grpc.MaxReceiveMessageSize(te.maxServerReceiveMsgSize))
+ if te.maxServerReceiveMsgSize != nil {
+ sopts = append(sopts, grpc.MaxReceiveMessageSize(*te.maxServerReceiveMsgSize))
}
- if te.maxServerSendMsgSize >= 0 {
- sopts = append(sopts, grpc.MaxSendMessageSize(te.maxServerSendMsgSize))
+ if te.maxServerSendMsgSize != nil {
+ sopts = append(sopts, grpc.MaxSendMessageSize(*te.maxServerSendMsgSize))
}
if te.tapHandle != nil {
sopts = append(sopts, grpc.InTapHandle(te.tapHandle))
@@ -600,14 +594,14 @@
if te.streamClientInt != nil {
opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt))
}
- if te.maxMsgSize >= 0 {
- opts = append(opts, grpc.WithMaxMsgSize(te.maxMsgSize))
+ if te.maxMsgSize != nil {
+ opts = append(opts, grpc.WithMaxMsgSize(*te.maxMsgSize))
}
- if te.maxClientReceiveMsgSize >= 0 {
- opts = append(opts, grpc.WithMaxReceiveMessageSize(te.maxClientReceiveMsgSize))
+ if te.maxClientReceiveMsgSize != nil {
+ opts = append(opts, grpc.WithDefaultCallOptions(grpc.WithMaxReceiveMessageSize(*te.maxClientReceiveMsgSize)))
}
- if te.maxClientSendMsgSize >= 0 {
- opts = append(opts, grpc.WithMaxSendMessageSize(te.maxClientSendMsgSize))
+ if te.maxClientSendMsgSize != nil {
+ opts = append(opts, grpc.WithDefaultCallOptions(grpc.WithMaxSendMessageSize(*te.maxClientSendMsgSize)))
}
if te.timeout > 0 {
opts = append(opts, grpc.WithTimeout(te.timeout))
@@ -1334,8 +1328,8 @@
// Case2: Client API set maxReqSize to 1024 (send), maxRespSize to 1024 (recv). Sc sets maxReqSize to 2048 (send), maxRespSize to 2048 (recv).
te2, ch2 := testServiceConfigSetup(t, e)
- te2.maxClientReceiveMsgSize = 1024
- te2.maxClientSendMsgSize = 1024
+ te2.maxClientReceiveMsgSize = newInt(1024)
+ te2.maxClientSendMsgSize = newInt(1024)
te2.startServer(&testServer{security: e.security})
defer te2.tearDown()
ch2 <- sc
@@ -1383,8 +1377,8 @@
// Case3: Client API set maxReqSize to 4096 (send), maxRespSize to 4096 (recv). Sc sets maxReqSize to 2048 (send), maxRespSize to 2048 (recv).
te3, ch3 := testServiceConfigSetup(t, e)
- te3.maxClientReceiveMsgSize = 4096
- te3.maxClientSendMsgSize = 4096
+ te3.maxClientReceiveMsgSize = newInt(4096)
+ te3.maxClientSendMsgSize = newInt(4096)
te3.startServer(&testServer{security: e.security})
defer te3.tearDown()
ch3 <- sc
@@ -1468,7 +1462,7 @@
te := newTest(t, e)
te.userAgent = testAppUA
// To avoid error on server side.
- te.maxServerSendMsgSize = 5 * 1024 * 1024
+ te.maxServerSendMsgSize = newInt(5 * 1024 * 1024)
te.declareLogNoise(
"transport: http2Client.notifyError got notified that the client transport was broken EOF",
"grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing",
@@ -1547,9 +1541,9 @@
te := newTest(t, e)
te.userAgent = testAppUA
// To avoid error on server side.
- te.maxServerSendMsgSize = 5 * 1024 * 1024
- te.maxClientReceiveMsgSize = 1024
- te.maxClientSendMsgSize = 1024
+ te.maxServerSendMsgSize = newInt(5 * 1024 * 1024)
+ te.maxClientReceiveMsgSize = newInt(1024)
+ te.maxClientSendMsgSize = newInt(1024)
te.declareLogNoise(
"transport: http2Client.notifyError got notified that the client transport was broken EOF",
"grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing",
@@ -1627,8 +1621,8 @@
func testMaxMsgSizeServerAPI(t *testing.T, e env) {
te := newTest(t, e)
te.userAgent = testAppUA
- te.maxServerReceiveMsgSize = 1024
- te.maxServerSendMsgSize = 1024
+ te.maxServerReceiveMsgSize = newInt(1024)
+ te.maxServerSendMsgSize = newInt(1024)
te.declareLogNoise(
"transport: http2Client.notifyError got notified that the client transport was broken EOF",
"grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing",
@@ -2032,12 +2026,12 @@
func testExceedMsgLimit(t *testing.T, e env) {
te := newTest(t, e)
- te.maxMsgSize = 1024
+ te.maxMsgSize = newInt(1024)
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
- argSize := int32(te.maxMsgSize + 1)
+ argSize := int32(*te.maxMsgSize + 1)
const smallSize = 1
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
@@ -2059,7 +2053,7 @@
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted)
}
// Test on client side for unary RPC.
- req.ResponseSize = proto.Int32(int32(te.maxMsgSize) + 1)
+ req.ResponseSize = proto.Int32(int32(*te.maxMsgSize) + 1)
req.Payload = smallPayload
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.ResourceExhausted {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.ResourceExhausted)
@@ -2076,7 +2070,7 @@
},
}
- spayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(te.maxMsgSize+1))
+ spayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(*te.maxMsgSize+1))
if err != nil {
t.Fatal(err)
}
@@ -2098,7 +2092,7 @@
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
- respParam[0].Size = proto.Int32(int32(te.maxMsgSize) + 1)
+ respParam[0].Size = proto.Int32(int32(*te.maxMsgSize) + 1)
sreq.Payload = smallPayload
if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)