Expand stream's flow control in case of an active read. (#1248)
* First commit
* Imported tests from the original PR by @apolcyn.
* Formatting fixes.
* More formating fixes
* more golint
* Make logs more informative.
* post-review update
* Added test to check flow control accounts after sending large messages.
* post-review update
* Empty commit to kickstart travis.
* Post-review update.
diff --git a/rpc_util.go b/rpc_util.go
index 11558d7..8ddacb9 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -278,7 +278,7 @@
// that the underlying io.Reader must not return an incompatible
// error.
func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
- if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
+ if _, err := p.r.Read(p.header[:]); err != nil {
return 0, nil, err
}
@@ -294,7 +294,7 @@
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message:
msg = make([]byte, int(length))
- if _, err := io.ReadFull(p.r, msg); err != nil {
+ if _, err := p.r.Read(msg); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
diff --git a/rpc_util_test.go b/rpc_util_test.go
index d832b12..cbaaa52 100644
--- a/rpc_util_test.go
+++ b/rpc_util_test.go
@@ -47,6 +47,14 @@
"google.golang.org/grpc/transport"
)
+type fullReader struct {
+ reader io.Reader
+}
+
+func (f fullReader) Read(p []byte) (int, error) {
+ return io.ReadFull(f.reader, p)
+}
+
var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface
func TestSimpleParsing(t *testing.T) {
@@ -67,7 +75,7 @@
// Check that messages with length >= 2^24 are parsed.
{append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone},
} {
- buf := bytes.NewReader(test.p)
+ buf := fullReader{bytes.NewReader(test.p)}
parser := &parser{r: buf}
pt, b, err := parser.recvMsg(math.MaxInt32)
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
@@ -79,7 +87,7 @@
func TestMultipleParsing(t *testing.T) {
// Set a byte stream consists of 3 messages with their headers.
p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'}
- b := bytes.NewReader(p)
+ b := fullReader{bytes.NewReader(p)}
parser := &parser{r: b}
wantRecvs := []struct {
diff --git a/test/end2end_test.go b/test/end2end_test.go
index b028e83..52f8ce7 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -449,6 +449,7 @@
streamServerInt grpc.StreamServerInterceptor
unknownHandler grpc.StreamHandler
sc <-chan grpc.ServiceConfig
+ customCodec grpc.Codec
serverInitialWindowSize int32
serverInitialConnWindowSize int32
clientInitialWindowSize int32
@@ -555,6 +556,9 @@
case "clientTimeoutCreds":
sopts = append(sopts, grpc.Creds(&clientTimeoutCreds{}))
}
+ if te.customCodec != nil {
+ sopts = append(sopts, grpc.CustomCodec(te.customCodec))
+ }
s := grpc.NewServer(sopts...)
te.srv = s
if te.e.httpHandler {
@@ -641,6 +645,9 @@
if te.perRPCCreds != nil {
opts = append(opts, grpc.WithPerRPCCredentials(te.perRPCCreds))
}
+ if te.customCodec != nil {
+ opts = append(opts, grpc.WithCodec(te.customCodec))
+ }
var err error
te.cc, err = grpc.Dial(te.srvAddr, opts...)
if err != nil {
@@ -3271,26 +3278,51 @@
}
+func generatePayloadSizes() [][]int {
+ reqSizes := [][]int{
+ {27182, 8, 1828, 45904},
+ }
+
+ num8KPayloads := 1024
+ eightKPayloads := []int{}
+ for i := 0; i < num8KPayloads; i++ {
+ eightKPayloads = append(eightKPayloads, (1 << 13))
+ }
+ reqSizes = append(reqSizes, eightKPayloads)
+
+ num2MPayloads := 8
+ twoMPayloads := []int{}
+ for i := 0; i < num2MPayloads; i++ {
+ twoMPayloads = append(twoMPayloads, (1 << 21))
+ }
+ reqSizes = append(reqSizes, twoMPayloads)
+
+ return reqSizes
+}
+
func TestClientStreaming(t *testing.T) {
defer leakCheck(t)()
- for _, e := range listTestEnv() {
- testClientStreaming(t, e)
+ for _, s := range generatePayloadSizes() {
+ for _, e := range listTestEnv() {
+ testClientStreaming(t, e, s)
+ }
}
}
-func testClientStreaming(t *testing.T, e env) {
+func testClientStreaming(t *testing.T, e env, sizes []int) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
- stream, err := tc.StreamingInputCall(te.ctx)
+ ctx, _ := context.WithTimeout(te.ctx, time.Second*30)
+ stream, err := tc.StreamingInputCall(ctx)
if err != nil {
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want <nil>", tc, err)
}
var sum int
- for _, s := range reqSizes {
+ for _, s := range sizes {
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(s))
if err != nil {
t.Fatal(err)
diff --git a/test/servertester.go b/test/servertester.go
index f65bc3d..ed3724d 100644
--- a/test/servertester.go
+++ b/test/servertester.go
@@ -287,3 +287,9 @@
st.t.Fatalf("Error writing RST_STREAM: %v", err)
}
}
+
+func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, padding []byte) {
+ if err := st.fr.WriteDataPadded(streamID, endStream, data, padding); err != nil {
+ st.t.Fatalf("Error writing DATA with padding: %v", err)
+ }
+}
diff --git a/transport/control.go b/transport/control.go
index 8d29aee..68dfdd5 100644
--- a/transport/control.go
+++ b/transport/control.go
@@ -58,6 +58,8 @@
defaultServerKeepaliveTime = time.Duration(2 * time.Hour)
defaultServerKeepaliveTimeout = time.Duration(20 * time.Second)
defaultKeepalivePolicyMinTime = time.Duration(5 * time.Minute)
+ // max window limit set by HTTP2 Specs.
+ maxWindowSize = math.MaxInt32
)
// The following defines various control items which could flow through
@@ -167,6 +169,40 @@
// The amount of data the application has consumed but grpc has not sent
// window update for them. Used to reduce window update frequency.
pendingUpdate uint32
+ // delta is the extra window update given by receiver when an application
+ // is reading data bigger in size than the inFlow limit.
+ delta uint32
+}
+
+func (f *inFlow) maybeAdjust(n uint32) uint32 {
+ if n > uint32(math.MaxInt32) {
+ n = uint32(math.MaxInt32)
+ }
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ // estSenderQuota is the receiver's view of the maximum number of bytes the sender
+ // can send without a window update.
+ estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate))
+ // estUntransmittedData is the maximum number of bytes the sends might not have put
+ // on the wire yet. A value of 0 or less means that we have already received all or
+ // more bytes than the application is requesting to read.
+ estUntransmittedData := int32(n - f.pendingData) // Casting into int32 since it could be negative.
+ // This implies that unless we send a window update, the sender won't be able to send all the bytes
+ // for this message. Therefore we must send an update over the limit since there's an active read
+ // request from the application.
+ if estUntransmittedData > estSenderQuota {
+ // Sender's window shouldn't go more than 2^31 - 1 as speecified in the HTTP spec.
+ if f.limit+n > maxWindowSize {
+ f.delta = maxWindowSize - f.limit
+ } else {
+ // Send a window update for the whole message and not just the difference between
+ // estUntransmittedData and estSenderQuota. This will be helpful in case the message
+ // is padded; We will fallback on the current available window(at least a 1/4th of the limit).
+ f.delta = n
+ }
+ return f.delta
+ }
+ return 0
}
// onData is invoked when some data frame is received. It updates pendingData.
@@ -174,7 +210,7 @@
f.mu.Lock()
defer f.mu.Unlock()
f.pendingData += n
- if f.pendingData+f.pendingUpdate > f.limit {
+ if f.pendingData+f.pendingUpdate > f.limit+f.delta {
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate, f.limit)
}
return nil
@@ -189,6 +225,13 @@
return 0
}
f.pendingData -= n
+ if n > f.delta {
+ n -= f.delta
+ f.delta = 0
+ } else {
+ f.delta -= n
+ n = 0
+ }
f.pendingUpdate += n
if f.pendingUpdate >= f.limit/4 {
wu := f.pendingUpdate
diff --git a/transport/handler_server.go b/transport/handler_server.go
index 31b0570..93144fc 100644
--- a/transport/handler_server.go
+++ b/transport/handler_server.go
@@ -316,13 +316,12 @@
req := ht.req
s := &Stream{
- id: 0, // irrelevant
- windowHandler: func(int) {}, // nothing
- cancel: cancel,
- buf: newRecvBuffer(),
- st: ht,
- method: req.URL.Path,
- recvCompress: req.Header.Get("grpc-encoding"),
+ id: 0, // irrelevant
+ cancel: cancel,
+ buf: newRecvBuffer(),
+ st: ht,
+ method: req.URL.Path,
+ recvCompress: req.Header.Get("grpc-encoding"),
}
pr := &peer.Peer{
Addr: ht.RemoteAddr(),
@@ -333,7 +332,7 @@
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
ctx = peer.NewContext(ctx, pr)
s.ctx = newContextWithStream(ctx, s)
- s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
+ s.trReader = &recvBufferReader{ctx: s.ctx, recv: s.buf}
// readerDone is closed when the Body.Read-ing goroutine exits.
readerDone := make(chan struct{})
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 7db73d3..713f762 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -173,9 +173,9 @@
conn, err := dial(ctx, opts.Dialer, addr.Addr)
if err != nil {
if opts.FailOnNonTempDialError {
- return nil, connectionErrorf(isTemporary(err), err, "transport: %v", err)
+ return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err)
}
- return nil, connectionErrorf(true, err, "transport: %v", err)
+ return nil, connectionErrorf(true, err, "transport: Error while dialing %v", err)
}
// Any further errors will close the underlying connection
defer func(conn net.Conn) {
@@ -194,7 +194,7 @@
// Credentials handshake errors are typically considered permanent
// to avoid retrying on e.g. bad certificates.
temp := isTemporary(err)
- return nil, connectionErrorf(temp, err, "transport: %v", err)
+ return nil, connectionErrorf(temp, err, "transport: authentication handshake failed: %v", err)
}
isSecure = true
}
@@ -269,7 +269,7 @@
n, err := t.conn.Write(clientPreface)
if err != nil {
t.Close()
- return nil, connectionErrorf(true, err, "transport: %v", err)
+ return nil, connectionErrorf(true, err, "transport: failed to write client preface: %v", err)
}
if n != len(clientPreface) {
t.Close()
@@ -285,13 +285,13 @@
}
if err != nil {
t.Close()
- return nil, connectionErrorf(true, err, "transport: %v", err)
+ return nil, connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err)
}
// Adjust the connection flow control window if needed.
if delta := uint32(icwz - defaultWindowSize); delta > 0 {
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
t.Close()
- return nil, connectionErrorf(true, err, "transport: %v", err)
+ return nil, connectionErrorf(true, err, "transport: failed to write window update: %v", err)
}
}
go t.controller()
@@ -316,18 +316,24 @@
headerChan: make(chan struct{}),
}
t.nextID += 2
- s.windowHandler = func(n int) {
- t.updateWindow(s, uint32(n))
+ s.requestRead = func(n int) {
+ t.adjustWindow(s, uint32(n))
}
// The client side stream context should have exactly the same life cycle with the user provided context.
// That means, s.ctx should be read-only. And s.ctx is done iff ctx is done.
// So we use the original context here instead of creating a copy.
s.ctx = ctx
- s.dec = &recvBufferReader{
- ctx: s.ctx,
- goAway: s.goAway,
- recv: s.buf,
+ s.trReader = &transportReader{
+ reader: &recvBufferReader{
+ ctx: s.ctx,
+ goAway: s.goAway,
+ recv: s.buf,
+ },
+ windowHandler: func(n int) {
+ t.updateWindow(s, uint32(n))
+ },
}
+
return s
}
@@ -802,6 +808,20 @@
return s, ok
}
+// adjustWindow sends out extra window update over the initial window size
+// of stream if the application is requesting data larger in size than
+// the window.
+func (t *http2Client) adjustWindow(s *Stream, n uint32) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.state == streamDone {
+ return
+ }
+ if w := s.fc.maybeAdjust(n); w > 0 {
+ t.controlBuf.put(&windowUpdate{s.id, w})
+ }
+}
+
// updateWindow adjusts the inbound quota for the stream and the transport.
// Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold.
diff --git a/transport/http2_server.go b/transport/http2_server.go
index a1ec3a8..559d28d 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -274,10 +274,14 @@
if len(state.mdata) > 0 {
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
}
-
- s.dec = &recvBufferReader{
- ctx: s.ctx,
- recv: s.buf,
+ s.trReader = &transportReader{
+ reader: &recvBufferReader{
+ ctx: s.ctx,
+ recv: s.buf,
+ },
+ windowHandler: func(n int) {
+ t.updateWindow(s, uint32(n))
+ },
}
s.recvCompress = state.encoding
s.method = state.method
@@ -316,8 +320,8 @@
t.idle = time.Time{}
}
t.mu.Unlock()
- s.windowHandler = func(n int) {
- t.updateWindow(s, uint32(n))
+ s.requestRead = func(n int) {
+ t.adjustWindow(s, uint32(n))
}
s.ctx = traceCtx(s.ctx, s.method)
if t.stats != nil {
@@ -361,7 +365,7 @@
return
}
if err != nil {
- grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
+ grpclog.Printf("transport: http2Server.HandleStreams failed to read initial settings frame: %v", err)
t.Close()
return
}
@@ -435,6 +439,20 @@
return s, true
}
+// adjustWindow sends out extra window update over the initial window size
+// of stream if the application is requesting data larger in size than
+// the window.
+func (t *http2Server) adjustWindow(s *Stream, n uint32) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.state == streamDone {
+ return
+ }
+ if w := s.fc.maybeAdjust(n); w > 0 {
+ t.controlBuf.put(&windowUpdate{s.id, w})
+ }
+}
+
// updateWindow adjusts the inbound quota for the stream and the transport.
// Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold.
diff --git a/transport/transport.go b/transport/transport.go
index 6bec5bd..c1be4da 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -185,14 +185,17 @@
recvCompress string
sendCompress string
buf *recvBuffer
- dec io.Reader
+ trReader io.Reader
fc *inFlow
recvQuota uint32
+
+ // TODO: Remote this unused variable.
// The accumulated inbound quota pending for window update.
updateQuota uint32
- // The handler to control the window update procedure for both this
- // particular stream and the associated transport.
- windowHandler func(int)
+
+ // Callback to state application's intentions to read data. This
+ // is used to adjust flow control, if need be.
+ requestRead func(int)
sendQuotaPool *quotaPool
// Close headerChan to indicate the end of reception of header metadata.
@@ -320,16 +323,35 @@
s.buf.put(&m)
}
-// Read reads all the data available for this Stream from the transport and
+// Read reads all p bytes from the wire for this stream.
+func (s *Stream) Read(p []byte) (n int, err error) {
+ // Don't request a read if there was an error earlier
+ if er := s.trReader.(*transportReader).er; er != nil {
+ return 0, er
+ }
+ s.requestRead(len(p))
+ return io.ReadFull(s.trReader, p)
+}
+
+// tranportReader reads all the data available for this Stream from the transport and
// passes them into the decoder, which converts them into a gRPC message stream.
// The error is io.EOF when the stream is done or another non-nil error if
// the stream broke.
-func (s *Stream) Read(p []byte) (n int, err error) {
- n, err = s.dec.Read(p)
+type transportReader struct {
+ reader io.Reader
+ // The handler to control the window update procedure for both this
+ // particular stream and the associated transport.
+ windowHandler func(int)
+ er error
+}
+
+func (t *transportReader) Read(p []byte) (n int, err error) {
+ n, err = t.reader.Read(p)
if err != nil {
+ t.er = err
return
}
- s.windowHandler(n)
+ t.windowHandler(n)
return
}
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 0b534d2..72bd104 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -36,6 +36,8 @@
import (
"bufio"
"bytes"
+ "encoding/binary"
+ "errors"
"fmt"
"io"
"math"
@@ -84,6 +86,9 @@
misbehaved
encodingRequiredStatus
invalidHeaderField
+ delayRead
+ delayWrite
+ pingpong
)
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
@@ -94,7 +99,7 @@
resp = expectedResponseLarge
}
p := make([]byte, len(req))
- _, err := io.ReadFull(s, p)
+ _, err := s.Read(p)
if err != nil {
return
}
@@ -107,6 +112,25 @@
h.t.WriteStatus(s, status.New(codes.OK, ""))
}
+func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
+ header := make([]byte, 5)
+ for i := 0; i < 10; i++ {
+ if _, err := s.Read(header); err != nil {
+ t.Fatalf("Error on server while reading data header: %v", err)
+ }
+ sz := binary.BigEndian.Uint32(header[1:])
+ msg := make([]byte, int(sz))
+ if _, err := s.Read(msg); err != nil {
+ t.Fatalf("Error on server while reading message: %v", err)
+ }
+ buf := make([]byte, sz+5)
+ buf[0] = byte(0)
+ binary.BigEndian.PutUint32(buf[1:], uint32(sz))
+ copy(buf[5:], msg)
+ h.t.Write(s, buf, &Options{})
+ }
+}
+
// handleStreamSuspension blocks until s.ctx is canceled.
func (h *testStreamHandler) handleStreamSuspension(s *Stream) {
go func() {
@@ -159,6 +183,58 @@
h.t.writableChan <- 0
}
+func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
+ req := expectedRequest
+ resp := expectedResponse
+ if s.Method() == "foo.Large" {
+ req = expectedRequestLarge
+ resp = expectedResponseLarge
+ }
+ p := make([]byte, len(req))
+
+ // Wait before reading. Give time to client to start sending
+ // before server starts reading.
+ time.Sleep(2 * time.Second)
+ _, err := s.Read(p)
+ if err != nil {
+ t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err)
+ return
+ }
+
+ if !bytes.Equal(p, req) {
+ t.Fatalf("handleStream got %v, want %v", p, req)
+ }
+ // send a response back to the client.
+ h.t.Write(s, resp, &Options{})
+ // send the trailer to end the stream.
+ h.t.WriteStatus(s, status.New(codes.OK, ""))
+}
+
+func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) {
+ req := expectedRequest
+ resp := expectedResponse
+ if s.Method() == "foo.Large" {
+ req = expectedRequestLarge
+ resp = expectedResponseLarge
+ }
+ p := make([]byte, len(req))
+ _, err := s.Read(p)
+ if err != nil {
+ t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err)
+ return
+ }
+ if !bytes.Equal(p, req) {
+ t.Fatalf("handleStream got %v, want %v", p, req)
+ }
+
+ // Wait before sending. Give time to client to start reading
+ // before server starts sending.
+ time.Sleep(2 * time.Second)
+ h.t.Write(s, resp, &Options{})
+ // send the trailer to end the stream.
+ h.t.WriteStatus(s, status.New(codes.OK, ""))
+}
+
// start starts server. Other goroutines should block on s.readyChan for further operations.
func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) {
var err error
@@ -221,6 +297,24 @@
}, func(ctx context.Context, method string) context.Context {
return ctx
})
+ case delayRead:
+ go transport.HandleStreams(func(s *Stream) {
+ go h.handleStreamDelayRead(t, s)
+ }, func(ctx context.Context, method string) context.Context {
+ return ctx
+ })
+ case delayWrite:
+ go transport.HandleStreams(func(s *Stream) {
+ go h.handleStreamDelayWrite(t, s)
+ }, func(ctx context.Context, method string) context.Context {
+ return ctx
+ })
+ case pingpong:
+ go transport.HandleStreams(func(s *Stream) {
+ go h.handleStreamPingPong(t, s)
+ }, func(ctx context.Context, method string) context.Context {
+ return ctx
+ })
default:
go transport.HandleStreams(func(s *Stream) {
go h.handleStream(t, s)
@@ -696,11 +790,11 @@
t.Fatalf("failed to send data: %v", err)
}
p := make([]byte, len(expectedResponse))
- _, recvErr := io.ReadFull(s1, p)
+ _, recvErr := s1.Read(p)
if recvErr != nil || !bytes.Equal(p, expectedResponse) {
t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse)
}
- _, recvErr = io.ReadFull(s1, p)
+ _, recvErr = s1.Read(p)
if recvErr != io.EOF {
t.Fatalf("Error: %v; want <EOF>", recvErr)
}
@@ -736,9 +830,9 @@
//
// Read response
p := make([]byte, len(expectedResponse))
- io.ReadFull(s, p)
+ s.Read(p)
// Read io.EOF
- io.ReadFull(s, p)
+ s.Read(p)
}
}
@@ -777,10 +871,80 @@
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponseLarge))
- if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
- t.Errorf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", err, p, expectedResponse)
+ if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
+ t.Errorf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
}
- if _, err = io.ReadFull(s, p); err != io.EOF {
+ if _, err = s.Read(p); err != io.EOF {
+ t.Errorf("Failed to complete the stream %v; want <EOF>", err)
+ }
+ }()
+ }
+ wg.Wait()
+ ct.Close()
+ server.stop()
+}
+
+func TestLargeMessageWithDelayRead(t *testing.T) {
+ server, ct := setUp(t, 0, math.MaxUint32, delayRead)
+ callHdr := &CallHdr{
+ Host: "localhost",
+ Method: "foo.Large",
+ }
+ var wg sync.WaitGroup
+ for i := 0; i < 2; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ s, err := ct.NewStream(context.Background(), callHdr)
+ if err != nil {
+ t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
+ }
+ if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
+ t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
+ }
+ p := make([]byte, len(expectedResponseLarge))
+
+ // Give time to server to begin sending before client starts reading.
+ time.Sleep(2 * time.Second)
+ if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
+ t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
+ }
+ if _, err = s.Read(p); err != io.EOF {
+ t.Errorf("Failed to complete the stream %v; want <EOF>", err)
+ }
+ }()
+ }
+ wg.Wait()
+ ct.Close()
+ server.stop()
+}
+
+func TestLargeMessageDelayWrite(t *testing.T) {
+ server, ct := setUp(t, 0, math.MaxUint32, delayWrite)
+ callHdr := &CallHdr{
+ Host: "localhost",
+ Method: "foo.Large",
+ }
+ var wg sync.WaitGroup
+ for i := 0; i < 2; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ s, err := ct.NewStream(context.Background(), callHdr)
+ if err != nil {
+ t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
+ }
+
+ // Give time to server to start reading before client starts sending.
+ time.Sleep(2 * time.Second)
+ if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
+ t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
+ }
+ p := make([]byte, len(expectedResponseLarge))
+ if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
+ t.Errorf("io.ReadFull(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
+ }
+ if _, err = s.Read(p); err != io.EOF {
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
}
}()
@@ -823,10 +987,10 @@
t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponse))
- if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponse) {
- t.Fatalf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", err, p, expectedResponse)
+ if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponse) {
+ t.Fatalf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
}
- if _, err = io.ReadFull(s, p); err != io.EOF {
+ if _, err = s.Read(p); err != io.EOF {
t.Fatalf("Failed to complete the stream %v; want <EOF>", err)
}
wg.Wait()
@@ -1074,7 +1238,7 @@
}
// Server sent a resetStream for s already.
code := http2ErrConvTab[http2.ErrCodeFlowControl]
- if _, err := io.ReadFull(s, make([]byte, 1)); err != io.EOF {
+ if _, err := s.Read(make([]byte, 1)); err != io.EOF {
t.Fatalf("%v got err %v want <EOF>", s, err)
}
if s.status.Code() != code {
@@ -1125,7 +1289,7 @@
// Read without window update.
for {
p := make([]byte, http2MaxFrameLen)
- if _, err = s.dec.Read(p); err != nil {
+ if _, err = s.trReader.(*transportReader).reader.Read(p); err != nil {
break
}
}
@@ -1184,7 +1348,7 @@
t.Fatalf("Failed to write the request: %v", err)
}
p := make([]byte, http2MaxFrameLen)
- if _, err := s.dec.Read(p); err != io.EOF {
+ if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF {
t.Fatalf("Read got error %v, want %v", err, io.EOF)
}
if !reflect.DeepEqual(s.Status(), encodingTestStatus) {
@@ -1212,7 +1376,7 @@
t.Fatalf("Failed to write the request: %v", err)
}
p := make([]byte, http2MaxFrameLen)
- _, err = s.dec.Read(p)
+ _, err = s.trReader.(*transportReader).Read(p)
if se, ok := err.(StreamError); !ok || se.Code != codes.FailedPrecondition || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.FailedPrecondition, expectedInvalidHeaderField)
}
@@ -1269,6 +1433,13 @@
}
}
+func max(a, b int32) int32 {
+ if a > b {
+ return a
+ }
+ return b
+}
+
type windowSizeConfig struct {
serverStream int32
serverConn int32
@@ -1348,6 +1519,7 @@
}
return false, nil
})
+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
serverSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire())
if err != nil {
@@ -1395,6 +1567,166 @@
}
}
+// Check accounting on both sides after sending and receiving large messages.
+func TestAccountCheckExpandingWindow(t *testing.T) {
+ server, client := setUp(t, 0, 0, pingpong)
+ defer server.stop()
+ defer client.Close()
+ waitWhileTrue(t, func() (bool, error) {
+ server.mu.Lock()
+ defer server.mu.Unlock()
+ if len(server.conns) == 0 {
+ return true, fmt.Errorf("timed out while waiting for server transport to be created")
+ }
+ return false, nil
+ })
+ var st *http2Server
+ server.mu.Lock()
+ for k := range server.conns {
+ st = k.(*http2Server)
+ }
+ server.mu.Unlock()
+ ct := client.(*http2Client)
+ cstream, err := client.NewStream(context.Background(), &CallHdr{Flush: true})
+ if err != nil {
+ t.Fatalf("Failed to create stream. Err: %v", err)
+ }
+
+ msgSize := 65535 * 16 * 2
+ msg := make([]byte, msgSize)
+ buf := make([]byte, msgSize+5)
+ buf[0] = byte(0)
+ binary.BigEndian.PutUint32(buf[1:], uint32(msgSize))
+ copy(buf[5:], msg)
+ opts := Options{}
+ header := make([]byte, 5)
+ for i := 1; i <= 10; i++ {
+ if err := ct.Write(cstream, buf, &opts); err != nil {
+ t.Fatalf("Error on client while writing message: %v", err)
+ }
+ if _, err := cstream.Read(header); err != nil {
+ t.Fatalf("Error on client while reading data frame header: %v", err)
+ }
+ sz := binary.BigEndian.Uint32(header[1:])
+ recvMsg := make([]byte, int(sz))
+ if _, err := cstream.Read(recvMsg); err != nil {
+ t.Fatalf("Error on client while reading data: %v", err)
+ }
+ if len(recvMsg) != len(msg) {
+ t.Fatalf("Length of message received by client: %v, want: %v", len(recvMsg), len(msg))
+ }
+ }
+ var sstream *Stream
+ st.mu.Lock()
+ for _, v := range st.activeStreams {
+ sstream = v
+ }
+ st.mu.Unlock()
+
+ waitWhileTrue(t, func() (bool, error) {
+ // Check that pendingData and delta on flow control windows on both sides are 0.
+ cstream.fc.mu.Lock()
+ if cstream.fc.delta != 0 {
+ cstream.fc.mu.Unlock()
+ return true, fmt.Errorf("delta on flow control window of client stream is non-zero")
+ }
+ if cstream.fc.pendingData != 0 {
+ cstream.fc.mu.Unlock()
+ return true, fmt.Errorf("pendingData on flow control window of client stream is non-zero")
+ }
+ cstream.fc.mu.Unlock()
+ sstream.fc.mu.Lock()
+ if sstream.fc.delta != 0 {
+ sstream.fc.mu.Unlock()
+ return true, fmt.Errorf("delta on flow control window of server stream is non-zero")
+ }
+ if sstream.fc.pendingData != 0 {
+ sstream.fc.mu.Unlock()
+ return true, fmt.Errorf("pendingData on flow control window of sercer stream is non-zero")
+ }
+ sstream.fc.mu.Unlock()
+ ct.fc.mu.Lock()
+ if ct.fc.delta != 0 {
+ ct.fc.mu.Unlock()
+ return true, fmt.Errorf("delta on flow control window of client transport is non-zero")
+ }
+ if ct.fc.pendingData != 0 {
+ ct.fc.mu.Unlock()
+ return true, fmt.Errorf("pendingData on flow control window of client transport is non-zero")
+ }
+ ct.fc.mu.Unlock()
+ st.fc.mu.Lock()
+ if st.fc.delta != 0 {
+ st.fc.mu.Unlock()
+ return true, fmt.Errorf("delta on flow control window of server transport is non-zero")
+ }
+ if st.fc.pendingData != 0 {
+ st.fc.mu.Unlock()
+ return true, fmt.Errorf("pendingData on flow control window of server transport is non-zero")
+ }
+ st.fc.mu.Unlock()
+
+ // Check flow conrtrol window on client stream is equal to out flow on server stream.
+ ctx, _ := context.WithTimeout(context.Background(), time.Second)
+ serverStreamSendQuota, err := wait(ctx, nil, nil, nil, sstream.sendQuotaPool.acquire())
+ if err != nil {
+ return true, fmt.Errorf("error while acquiring server stream send quota. Err: %v", err)
+ }
+ sstream.sendQuotaPool.add(serverStreamSendQuota)
+ cstream.fc.mu.Lock()
+ if uint32(serverStreamSendQuota) != cstream.fc.limit-cstream.fc.pendingUpdate {
+ cstream.fc.mu.Unlock()
+ return true, fmt.Errorf("server stream outflow: %v, estimated by client: %v", serverStreamSendQuota, cstream.fc.limit-cstream.fc.pendingUpdate)
+ }
+ cstream.fc.mu.Unlock()
+
+ // Check flow control window on server stream is equal to out flow on client stream.
+ ctx, _ = context.WithTimeout(context.Background(), time.Second)
+ clientStreamSendQuota, err := wait(ctx, nil, nil, nil, cstream.sendQuotaPool.acquire())
+ if err != nil {
+ return true, fmt.Errorf("error while acquiring client stream send quota. Err: %v", err)
+ }
+ cstream.sendQuotaPool.add(clientStreamSendQuota)
+ sstream.fc.mu.Lock()
+ if uint32(clientStreamSendQuota) != sstream.fc.limit-sstream.fc.pendingUpdate {
+ sstream.fc.mu.Unlock()
+ return true, fmt.Errorf("client stream outflow: %v. estimated by server: %v", clientStreamSendQuota, sstream.fc.limit-sstream.fc.pendingUpdate)
+ }
+ sstream.fc.mu.Unlock()
+
+ // Check flow control window on client transport is equal to out flow of server transport.
+ ctx, _ = context.WithTimeout(context.Background(), time.Second)
+ serverTrSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire())
+ if err != nil {
+ return true, fmt.Errorf("error while acquring server transport send quota. Err: %v", err)
+ }
+ st.sendQuotaPool.add(serverTrSendQuota)
+ ct.fc.mu.Lock()
+ if uint32(serverTrSendQuota) != ct.fc.limit-ct.fc.pendingUpdate {
+ ct.fc.mu.Unlock()
+ return true, fmt.Errorf("server transport outflow: %v, estimated by client: %v", serverTrSendQuota, ct.fc.limit-ct.fc.pendingUpdate)
+ }
+ ct.fc.mu.Unlock()
+
+ // Check flow control window on server transport is equal to out flow of client transport.
+ ctx, _ = context.WithTimeout(context.Background(), time.Second)
+ clientTrSendQuota, err := wait(ctx, nil, nil, nil, ct.sendQuotaPool.acquire())
+ if err != nil {
+ return true, fmt.Errorf("error while acquiring client transport send quota. Err: %v", err)
+ }
+ ct.sendQuotaPool.add(clientTrSendQuota)
+ st.fc.mu.Lock()
+ if uint32(clientTrSendQuota) != st.fc.limit-st.fc.pendingUpdate {
+ st.fc.mu.Unlock()
+ return true, fmt.Errorf("client transport outflow: %v, estimated by client: %v", clientTrSendQuota, st.fc.limit-st.fc.pendingUpdate)
+ }
+ st.fc.mu.Unlock()
+
+ return false, nil
+ })
+
+}
+
func waitWhileTrue(t *testing.T, condition func() (bool, error)) {
var (
wait bool
@@ -1576,7 +1908,8 @@
stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh)
defer cleanUp()
want := httpStatusConvTab[httpStatus]
- _, err := stream.Read([]byte{})
+ buf := make([]byte, 8)
+ _, err := stream.Read(buf)
if err == nil {
t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want)
}
@@ -1592,7 +1925,8 @@
func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) {
stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader)
defer cleanUp()
- _, err := stream.Read([]byte{})
+ buf := make([]byte, 8)
+ _, err := stream.Read(buf)
if err != io.EOF {
t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err)
}
@@ -1607,3 +1941,50 @@
func TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader(t *testing.T) {
testHTTPToGRPCStatusMapping(t, http.StatusUnauthorized, writeTwoHeaders)
}
+
+// If any error occurs on a call to Stream.Read, future calls
+// should continue to return that same error.
+func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
+ testRecvBuffer := newRecvBuffer()
+ s := &Stream{
+ ctx: context.Background(),
+ goAway: make(chan struct{}),
+ buf: testRecvBuffer,
+ requestRead: func(int) {},
+ }
+ s.trReader = &transportReader{
+ reader: &recvBufferReader{
+ ctx: s.ctx,
+ goAway: s.goAway,
+ recv: s.buf,
+ },
+ windowHandler: func(int) {},
+ }
+ testData := make([]byte, 1)
+ testData[0] = 5
+ testErr := errors.New("test error")
+ s.write(recvMsg{data: testData, err: testErr})
+
+ inBuf := make([]byte, 1)
+ actualCount, actualErr := s.Read(inBuf)
+ if actualCount != 0 {
+ t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount)
+ }
+ if actualErr.Error() != testErr.Error() {
+ t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
+ }
+
+ s.write(recvMsg{data: testData, err: nil})
+ s.write(recvMsg{data: testData, err: errors.New("different error from first")})
+
+ for i := 0; i < 2; i++ {
+ inBuf := make([]byte, 1)
+ actualCount, actualErr := s.Read(inBuf)
+ if actualCount != 0 {
+ t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount)
+ }
+ if actualErr.Error() != testErr.Error() {
+ t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
+ }
+ }
+}