Client should have a check on maximum size of received message size.
diff --git a/call.go b/call.go
index 81b52be..c1588c6 100644
--- a/call.go
+++ b/call.go
@@ -36,7 +36,6 @@
import (
"bytes"
"io"
- "math"
"time"
"golang.org/x/net/context"
@@ -73,7 +72,7 @@
}
}
for {
- if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32, inPayload); err != nil {
+ if err = recv(p, dopts.codec, stream, dopts.dc, reply, dopts.maxMsgSize, inPayload); err != nil {
if err == io.EOF {
break
}
diff --git a/clientconn.go b/clientconn.go
index b8e3198..1bf824b 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -36,6 +36,7 @@
import (
"errors"
"fmt"
+ "math"
"net"
"strings"
"sync"
@@ -87,23 +88,33 @@
// dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial.
type dialOptions struct {
- unaryInt UnaryClientInterceptor
- streamInt StreamClientInterceptor
- codec Codec
- cp Compressor
- dc Decompressor
- bs backoffStrategy
- balancer Balancer
- block bool
- insecure bool
- timeout time.Duration
- scChan <-chan ServiceConfig
- copts transport.ConnectOptions
+ unaryInt UnaryClientInterceptor
+ streamInt StreamClientInterceptor
+ codec Codec
+ cp Compressor
+ dc Decompressor
+ bs backoffStrategy
+ balancer Balancer
+ block bool
+ insecure bool
+ timeout time.Duration
+ scChan <-chan ServiceConfig
+ copts transport.ConnectOptions
+ maxMsgSize int
}
+const defaultClientMaxMsgSize = math.MaxInt32
+
// DialOption configures how we set up the connection.
type DialOption func(*dialOptions)
+// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive.
+func WithMaxMsgSize(s int) DialOption {
+ return func(o *dialOptions) {
+ o.maxMsgSize = s
+ }
+}
+
// WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling.
func WithCodec(c Codec) DialOption {
return func(o *dialOptions) {
@@ -304,6 +315,9 @@
ctx, cancel = context.WithTimeout(ctx, cc.dopts.timeout)
defer cancel()
}
+ if cc.dopts.maxMsgSize == 0 {
+ cc.dopts.maxMsgSize = defaultClientMaxMsgSize
+ }
defer func() {
select {
diff --git a/stream.go b/stream.go
index bb468dc..0ef2077 100644
--- a/stream.go
+++ b/stream.go
@@ -37,7 +37,6 @@
"bytes"
"errors"
"io"
- "math"
"sync"
"time"
@@ -208,13 +207,14 @@
break
}
cs := &clientStream{
- opts: opts,
- c: c,
- desc: desc,
- codec: cc.dopts.codec,
- cp: cc.dopts.cp,
- dc: cc.dopts.dc,
- cancel: cancel,
+ opts: opts,
+ c: c,
+ desc: desc,
+ codec: cc.dopts.codec,
+ cp: cc.dopts.cp,
+ dc: cc.dopts.dc,
+ maxMsgSize: cc.dopts.maxMsgSize,
+ cancel: cancel,
put: put,
t: t,
@@ -259,17 +259,18 @@
// clientStream implements a client side Stream.
type clientStream struct {
- opts []CallOption
- c callInfo
- t transport.ClientTransport
- s *transport.Stream
- p *parser
- desc *StreamDesc
- codec Codec
- cp Compressor
- cbuf *bytes.Buffer
- dc Decompressor
- cancel context.CancelFunc
+ opts []CallOption
+ c callInfo
+ t transport.ClientTransport
+ s *transport.Stream
+ p *parser
+ desc *StreamDesc
+ codec Codec
+ cp Compressor
+ cbuf *bytes.Buffer
+ dc Decompressor
+ maxMsgSize int
+ cancel context.CancelFunc
tracing bool // set to EnableTracing when the clientStream is created.
@@ -382,7 +383,7 @@
Client: true,
}
}
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, inPayload)
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, inPayload)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
@@ -405,7 +406,7 @@
}
// Special handling for client streaming rpc.
// This recv expects EOF or errors, so we don't collect inPayload.
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, nil)
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, nil)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
diff --git a/test/end2end_test.go b/test/end2end_test.go
index d743623..8aee5a1 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -570,6 +570,9 @@
if te.streamClientInt != nil {
opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt))
}
+ if te.maxMsgSize > 0 {
+ opts = append(opts, grpc.WithMaxMsgSize(te.maxMsgSize))
+ }
switch te.e.security {
case "tls":
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
@@ -1427,22 +1430,33 @@
tc := testpb.NewTestServiceClient(te.clientConn())
argSize := int32(te.maxMsgSize + 1)
- const respSize = 1
+ const smallSize = 1
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
if err != nil {
t.Fatal(err)
}
+ smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // test on server side for unary RPC
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
- ResponseSize: proto.Int32(respSize),
+ ResponseSize: proto.Int32(smallSize),
Payload: payload,
}
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal)
}
+ // test on client side for unary RPC
+ req.ResponseSize = proto.Int32(int32(te.maxMsgSize) + 1)
+ if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
+ t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal)
+ }
+ // test on server side for streaming RPC
stream, err := tc.FullDuplexCall(te.ctx)
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
@@ -1469,6 +1483,21 @@
if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal)
}
+
+ // test on client side for streaming RPC
+ stream, err = tc.FullDuplexCall(te.ctx)
+ if err != nil {
+ t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
+ }
+ respParam[0].Size = proto.Int32(int32(te.maxMsgSize) + 1)
+ sreq.Payload = smallPayload
+ if err := stream.Send(sreq); err != nil {
+ t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
+ }
+ if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
+ t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal)
+ }
+
}
func TestPeerClientSide(t *testing.T) {