Merge pull request #1071 from MakMukhi/issue_1060
Issue #1060: the maximum number of concurrent streams on the client should be capped by default (100).
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 9bcea03..d743623 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -2702,6 +2702,48 @@
}
}
+// defaultMaxStreamsClient mirrors the unexported defaultMaxStreamsClient
+// constant in transport/control.go and must be kept in sync with it: it is
+// the number of concurrent streams the client allows itself when the server
+// does not advertise a MaxConcurrentStreams limit.
+const defaultMaxStreamsClient = 100
+
+// TestExceedDefaultMaxStreamsLimit runs testExceedDefaultMaxStreamsLimit once
+// for each environment returned by listTestEnv. The deferred leakCheck(t)()
+// runs only after every environment has finished.
+func TestExceedDefaultMaxStreamsLimit(t *testing.T) {
+	defer leakCheck(t)()
+	envs := listTestEnv()
+	for _, tenv := range envs {
+		testExceedDefaultMaxStreamsLimit(t, tenv)
+	}
+}
+
+// testExceedDefaultMaxStreamsLimit checks that the client enforces its own
+// default cap of defaultMaxStreamsClient concurrent streams when the server
+// advertises no MaxConcurrentStreams limit. That many streams must open
+// successfully, and one more must fail with DeadlineExceeded under a
+// one-second deadline.
+func testExceedDefaultMaxStreamsLimit(t *testing.T, e env) {
+	te := newTest(t, e)
+	te.declareLogNoise(
+		"http2Client.notifyError got notified that the client transport was broken",
+		"Conn.resetTransport failed to create client transport",
+		"grpc: the connection is closing",
+	)
+	// With maxStream set to 0 the server sends no settings frame for
+	// MaxConcurrentStreams. Without a client-side default, the client could
+	// open effectively unlimited (math.MaxInt32) streams.
+	te.maxStream = 0
+	te.startServer(&testServer{security: e.security})
+	defer te.tearDown()
+
+	tc := testpb.NewTestServiceClient(te.clientConn())
+
+	// Use up the client's whole stream quota. The streams are intentionally
+	// left open here so that each one keeps holding its share of the quota.
+	for n := 0; n < defaultMaxStreamsClient; n++ {
+		if _, err := tc.StreamingInputCall(te.ctx); err != nil {
+			t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
+		}
+	}
+
+	// A stream beyond the cap must not be granted before the deadline.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	if _, err := tc.StreamingInputCall(ctx); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %s", tc, err, codes.DeadlineExceeded)
+	}
+}
+
func TestStreamsQuotaRecovery(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
diff --git a/transport/control.go b/transport/control.go
index 2586cba..33de7b6 100644
--- a/transport/control.go
+++ b/transport/control.go
@@ -44,8 +44,9 @@
// The default value of flow control window size in HTTP2 spec.
defaultWindowSize = 65535
// The initial window size for flow control.
- initialWindowSize = defaultWindowSize // for an RPC
- initialConnWindowSize = defaultWindowSize * 16 // for a connection
+ initialWindowSize = defaultWindowSize // for an RPC
+ initialConnWindowSize = defaultWindowSize * 16 // for a connection
+ defaultMaxStreamsClient = 100
)
// The following defines various control items which could flow through
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 892f8ba..001522b 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -208,7 +208,8 @@
state: reachable,
activeStreams: make(map[uint32]*Stream),
creds: opts.PerRPCCredentials,
- maxStreams: math.MaxInt32,
+ maxStreams: defaultMaxStreamsClient,
+ streamsQuota: newQuotaPool(defaultMaxStreamsClient),
streamSendQuota: defaultWindowSize,
statsHandler: opts.StatsHandler,
}
@@ -337,21 +338,18 @@
t.mu.Unlock()
return nil, ErrConnClosing
}
- checkStreamsQuota := t.streamsQuota != nil
t.mu.Unlock()
- if checkStreamsQuota {
- sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
- if err != nil {
- return nil, err
- }
- // Returns the quota balance back.
- if sq > 1 {
- t.streamsQuota.add(sq - 1)
- }
+ sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
+ if err != nil {
+ return nil, err
+ }
+ // Returns the quota balance back.
+ if sq > 1 {
+ t.streamsQuota.add(sq - 1)
}
if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
// Return the quota back now because there is no stream returned to the caller.
- if _, ok := err.(StreamError); ok && checkStreamsQuota {
+ if _, ok := err.(StreamError); ok {
t.streamsQuota.add(1)
}
return nil, err
@@ -359,9 +357,7 @@
t.mu.Lock()
if t.state == draining {
t.mu.Unlock()
- if checkStreamsQuota {
- t.streamsQuota.add(1)
- }
+ t.streamsQuota.add(1)
// Need to make t writable again so that the rpc in flight can still proceed.
t.writableChan <- 0
return nil, ErrStreamDrain
@@ -374,16 +370,7 @@
s.clientStatsCtx = userCtx
t.activeStreams[s.id] = s
- // This stream is not counted when applySetings(...) initialize t.streamsQuota.
- // Reset t.streamsQuota to the right value.
- var reset bool
- if !checkStreamsQuota && t.streamsQuota != nil {
- reset = true
- }
t.mu.Unlock()
- if reset {
- t.streamsQuota.add(-1)
- }
// HPACK encodes various headers. Note that once WriteField(...) is
// called, the corresponding headers/continuation frame has to be sent
@@ -491,15 +478,11 @@
// CloseStream clears the footprint of a stream when the stream is not needed any more.
// This must not be executed in reader's goroutine.
func (t *http2Client) CloseStream(s *Stream, err error) {
- var updateStreams bool
t.mu.Lock()
if t.activeStreams == nil {
t.mu.Unlock()
return
}
- if t.streamsQuota != nil {
- updateStreams = true
- }
delete(t.activeStreams, s.id)
if t.state == draining && len(t.activeStreams) == 0 {
// The transport is draining and s is the last live stream on t.
@@ -508,10 +491,25 @@
return
}
t.mu.Unlock()
- if updateStreams {
- t.streamsQuota.add(1)
- }
+ // rstStream is true when the stream is being closed on the client side
+ // and the server needs to be notified of this by sending a RST_STREAM
+ // frame.
+ // To make sure this frame is written to the wire before the headers of the
+ // next stream waiting for streamsQuota, we add to streamsQuota pool only
+ // after having acquired the writableChan to send RST_STREAM out (look at
+ // the controller() routine).
+ var rstStream bool
+ defer func() {
+ // In case, the client doesn't have to send RST_STREAM to server
+ // we can safely add back to streamsQuota pool now.
+ if !rstStream {
+ t.streamsQuota.add(1)
+ return
+ }
+ t.controlBuf.put(&resetStream{s.id, http2.ErrCodeCancel})
+ }()
s.mu.Lock()
+ rstStream = s.rstStream
if q := s.fc.resetPendingData(); q > 0 {
if n := t.fc.onRead(q); n > 0 {
t.controlBuf.put(&windowUpdate{0, n})
@@ -528,7 +526,7 @@
s.state = streamDone
s.mu.Unlock()
if se, ok := err.(StreamError); ok && se.Code != codes.DeadlineExceeded {
- t.controlBuf.put(&resetStream{s.id, http2.ErrCodeCancel})
+ rstStream = true
}
}
@@ -769,10 +767,10 @@
s.state = streamDone
s.statusCode = codes.Internal
s.statusDesc = err.Error()
+ s.rstStream = true
close(s.done)
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
- t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
return
}
s.mu.Unlock()
@@ -1043,16 +1041,10 @@
s.Val = math.MaxInt32
}
t.mu.Lock()
- reset := t.streamsQuota != nil
- if !reset {
- t.streamsQuota = newQuotaPool(int(s.Val) - len(t.activeStreams))
- }
ms := t.maxStreams
t.maxStreams = int(s.Val)
t.mu.Unlock()
- if reset {
- t.streamsQuota.add(int(s.Val) - ms)
- }
+ t.streamsQuota.add(int(s.Val) - ms)
case http2.SettingInitialWindowSize:
t.mu.Lock()
for _, stream := range t.activeStreams {
@@ -1085,6 +1077,12 @@
t.framer.writeSettings(true, i.ss...)
}
case *resetStream:
+ // If the server needs to be notified that the stream is closing,
+ // then we need to make sure the RST_STREAM frame is written to
+ // the wire before the headers of the next stream waiting on
+ // streamsQuota. We ensure this by adding to the streamsQuota pool
+ // only after having acquired the writableChan to send RST_STREAM.
+ t.streamsQuota.add(1)
t.framer.writeRSTStream(true, i.streamID, i.code)
case *flushIO:
t.framer.flushWrite()
diff --git a/transport/transport.go b/transport/transport.go
index 7a462ec..aed75d5 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -213,6 +213,9 @@
// the status received from the server.
statusCode codes.Code
statusDesc string
+ // rstStream indicates whether a RST_STREAM frame needs to be sent
+ // to the server to signify that this stream is closing.
+ rstStream bool
}
// RecvCompress returns the compression algorithm applied to the inbound
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 1ca6eb1..e91fc6e 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -507,7 +507,10 @@
case <-cc.streamsQuota.acquire():
t.Fatalf("streamsQuota.acquire() becomes readable mistakenly.")
default:
- if cc.streamsQuota.quota != 0 {
+ cc.streamsQuota.mu.Lock()
+ quota := cc.streamsQuota.quota
+ cc.streamsQuota.mu.Unlock()
+ if quota != 0 {
t.Fatalf("streamsQuota.quota got non-zero quota mistakenly.")
}
}