Merge pull request #1071 from MakMukhi/issue_1060

Issue #1060: cap the maximum number of concurrent streams on the client (default 100) when the server does not advertise a limit.
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 9bcea03..d743623 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -2702,6 +2702,55 @@
 	}
 }
 
+// defaultMaxStreamsClient mirrors the client-side default cap on concurrent
+// streams defined in the transport package (transport.defaultMaxStreamsClient).
+const defaultMaxStreamsClient = 100
+
+// TestExceedDefaultMaxStreamsLimit checks that a client is limited to
+// defaultMaxStreamsClient concurrent streams when the server does not send a
+// SETTINGS frame for MaxConcurrentStreams.
+func TestExceedDefaultMaxStreamsLimit(t *testing.T) {
+	defer leakCheck(t)()
+	for _, e := range listTestEnv() {
+		testExceedDefaultMaxStreamsLimit(t, e)
+	}
+}
+
+// testExceedDefaultMaxStreamsLimit runs the default stream-cap check in env e.
+func testExceedDefaultMaxStreamsLimit(t *testing.T, e env) {
+	te := newTest(t, e)
+	te.declareLogNoise(
+		"http2Client.notifyError got notified that the client transport was broken",
+		"Conn.resetTransport failed to create client transport",
+		"grpc: the connection is closing",
+	)
+	// When maxStream is set to 0 the server doesn't send a settings frame for
+	// MaxConcurrentStreams, essentially allowing infinite (math.MaxInt32) streams.
+	// In such a case, there should be a default cap on the client-side.
+	te.maxStream = 0
+	te.startServer(&testServer{security: e.security})
+	defer te.tearDown()
+
+	cc := te.clientConn()
+	tc := testpb.NewTestServiceClient(cc)
+
+	// Create as many streams as a client can.
+	for i := 0; i < defaultMaxStreamsClient; i++ {
+		if _, err := tc.StreamingInputCall(te.ctx); err != nil {
+			t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
+		}
+	}
+
+	// Creating one more stream must block waiting for stream quota and
+	// therefore time out once the context deadline expires.
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	_, err := tc.StreamingInputCall(ctx)
+	if err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+		t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %s", tc, err, codes.DeadlineExceeded)
+	}
+}
+
 func TestStreamsQuotaRecovery(t *testing.T) {
 	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
diff --git a/transport/control.go b/transport/control.go
index 2586cba..33de7b6 100644
--- a/transport/control.go
+++ b/transport/control.go
@@ -44,8 +44,9 @@
 	// The default value of flow control window size in HTTP2 spec.
 	defaultWindowSize = 65535
 	// The initial window size for flow control.
-	initialWindowSize     = defaultWindowSize      // for an RPC
-	initialConnWindowSize = defaultWindowSize * 16 // for a connection
+	initialWindowSize       = defaultWindowSize      // for an RPC
+	initialConnWindowSize   = defaultWindowSize * 16 // for a connection
+	defaultMaxStreamsClient = 100
 )
 
 // The following defines various control items which could flow through
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 892f8ba..001522b 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -208,7 +208,8 @@
 		state:           reachable,
 		activeStreams:   make(map[uint32]*Stream),
 		creds:           opts.PerRPCCredentials,
-		maxStreams:      math.MaxInt32,
+		maxStreams:      defaultMaxStreamsClient,
+		streamsQuota:    newQuotaPool(defaultMaxStreamsClient),
 		streamSendQuota: defaultWindowSize,
 		statsHandler:    opts.StatsHandler,
 	}
@@ -337,21 +338,18 @@
 		t.mu.Unlock()
 		return nil, ErrConnClosing
 	}
-	checkStreamsQuota := t.streamsQuota != nil
 	t.mu.Unlock()
-	if checkStreamsQuota {
-		sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
-		if err != nil {
-			return nil, err
-		}
-		// Returns the quota balance back.
-		if sq > 1 {
-			t.streamsQuota.add(sq - 1)
-		}
+	sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
+	if err != nil {
+		return nil, err
+	}
+	// Returns the quota balance back.
+	if sq > 1 {
+		t.streamsQuota.add(sq - 1)
 	}
 	if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
 		// Return the quota back now because there is no stream returned to the caller.
-		if _, ok := err.(StreamError); ok && checkStreamsQuota {
+		if _, ok := err.(StreamError); ok {
 			t.streamsQuota.add(1)
 		}
 		return nil, err
@@ -359,9 +357,7 @@
 	t.mu.Lock()
 	if t.state == draining {
 		t.mu.Unlock()
-		if checkStreamsQuota {
-			t.streamsQuota.add(1)
-		}
+		t.streamsQuota.add(1)
 		// Need to make t writable again so that the rpc in flight can still proceed.
 		t.writableChan <- 0
 		return nil, ErrStreamDrain
@@ -374,16 +370,7 @@
 	s.clientStatsCtx = userCtx
 	t.activeStreams[s.id] = s
 
-	// This stream is not counted when applySetings(...) initialize t.streamsQuota.
-	// Reset t.streamsQuota to the right value.
-	var reset bool
-	if !checkStreamsQuota && t.streamsQuota != nil {
-		reset = true
-	}
 	t.mu.Unlock()
-	if reset {
-		t.streamsQuota.add(-1)
-	}
 
 	// HPACK encodes various headers. Note that once WriteField(...) is
 	// called, the corresponding headers/continuation frame has to be sent
@@ -491,15 +478,11 @@
 // CloseStream clears the footprint of a stream when the stream is not needed any more.
 // This must not be executed in reader's goroutine.
 func (t *http2Client) CloseStream(s *Stream, err error) {
-	var updateStreams bool
 	t.mu.Lock()
 	if t.activeStreams == nil {
 		t.mu.Unlock()
 		return
 	}
-	if t.streamsQuota != nil {
-		updateStreams = true
-	}
 	delete(t.activeStreams, s.id)
 	if t.state == draining && len(t.activeStreams) == 0 {
 		// The transport is draining and s is the last live stream on t.
@@ -508,10 +491,25 @@
 		return
 	}
 	t.mu.Unlock()
-	if updateStreams {
-		t.streamsQuota.add(1)
-	}
+	// rstStream is true in case the stream is being closed at the client-side
+	// and the server needs to be intimated about it by sending a RST_STREAM
+	// frame.
+	// To make sure this frame is written to the wire before the headers of the
+	// next stream waiting for streamsQuota, we add to streamsQuota pool only
+	// after having acquired the writableChan to send RST_STREAM out (look at
+	// the controller() routine).
+	var rstStream bool
+	defer func() {
+		// In case, the client doesn't have to send RST_STREAM to server
+		// we can safely add back to streamsQuota pool now.
+		if !rstStream {
+			t.streamsQuota.add(1)
+			return
+		}
+		t.controlBuf.put(&resetStream{s.id, http2.ErrCodeCancel})
+	}()
 	s.mu.Lock()
+	rstStream = s.rstStream
 	if q := s.fc.resetPendingData(); q > 0 {
 		if n := t.fc.onRead(q); n > 0 {
 			t.controlBuf.put(&windowUpdate{0, n})
@@ -528,7 +526,7 @@
 	s.state = streamDone
 	s.mu.Unlock()
 	if se, ok := err.(StreamError); ok && se.Code != codes.DeadlineExceeded {
-		t.controlBuf.put(&resetStream{s.id, http2.ErrCodeCancel})
+		rstStream = true
 	}
 }
 
@@ -769,10 +767,10 @@
 			s.state = streamDone
 			s.statusCode = codes.Internal
 			s.statusDesc = err.Error()
+			s.rstStream = true
 			close(s.done)
 			s.mu.Unlock()
 			s.write(recvMsg{err: io.EOF})
-			t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
 			return
 		}
 		s.mu.Unlock()
@@ -1043,16 +1041,10 @@
 				s.Val = math.MaxInt32
 			}
 			t.mu.Lock()
-			reset := t.streamsQuota != nil
-			if !reset {
-				t.streamsQuota = newQuotaPool(int(s.Val) - len(t.activeStreams))
-			}
 			ms := t.maxStreams
 			t.maxStreams = int(s.Val)
 			t.mu.Unlock()
-			if reset {
-				t.streamsQuota.add(int(s.Val) - ms)
-			}
+			t.streamsQuota.add(int(s.Val) - ms)
 		case http2.SettingInitialWindowSize:
 			t.mu.Lock()
 			for _, stream := range t.activeStreams {
@@ -1085,6 +1077,12 @@
 						t.framer.writeSettings(true, i.ss...)
 					}
 				case *resetStream:
+					// If the server needs to be notified that the stream is closing,
+					// the RST_STREAM frame must be written to the wire before the
+					// headers of the next stream waiting on streamsQuota. We ensure
+					// this by adding to the streamsQuota pool only after having
+					// acquired the writableChan to send RST_STREAM.
+					t.streamsQuota.add(1)
 					t.framer.writeRSTStream(true, i.streamID, i.code)
 				case *flushIO:
 					t.framer.flushWrite()
diff --git a/transport/transport.go b/transport/transport.go
index 7a462ec..aed75d5 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -213,6 +213,9 @@
 	// the status received from the server.
 	statusCode codes.Code
 	statusDesc string
+	// rstStream indicates whether a RST_STREAM frame needs to be sent
+	// to the server to signify that this stream is closing.
+	rstStream bool
 }
 
 // RecvCompress returns the compression algorithm applied to the inbound
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 1ca6eb1..e91fc6e 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -507,7 +507,10 @@
 			case <-cc.streamsQuota.acquire():
 				t.Fatalf("streamsQuota.acquire() becomes readable mistakenly.")
 			default:
-				if cc.streamsQuota.quota != 0 {
+				cc.streamsQuota.mu.Lock()
+				quota := cc.streamsQuota.quota
+				cc.streamsQuota.mu.Unlock()
+				if quota != 0 {
 					t.Fatalf("streamsQuota.quota got non-zero quota mistakenly.")
 				}
 			}