Acquire all stream related quota and cache it locally since no more than one write can happen in parallel on stream (#1614)

* Acquire all the stream related quotas and cache them locally since only one write can happen on a stream at a time.

* Added new tests.

* Fix flake

* Post-review updates

* Post-review update
diff --git a/transport/http2_client.go b/transport/http2_client.go
index f665ef0..1057512 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -659,44 +659,51 @@
 	}
 	hdr = append(hdr, data[:emptyLen]...)
 	data = data[emptyLen:]
+	var (
+		streamQuota    int
+		streamQuotaVer uint32
+		localSendQuota int
+		err            error
+		sqChan         <-chan int
+	)
 	for idx, r := range [][]byte{hdr, data} {
 		for len(r) > 0 {
 			size := http2MaxFrameLen
-			// Wait until the stream has some quota to send the data.
-			quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion()
-			sq, err := wait(s.ctx, t.ctx, s.done, s.goAway, quotaChan)
-			if err != nil {
-				return err
+			if size > len(r) {
+				size = len(r)
 			}
+			if streamQuota == 0 { // Used up all the locally cached stream quota.
+				sqChan, streamQuotaVer = s.sendQuotaPool.acquireWithVersion()
+				// Wait until the stream has some quota to send the data.
+				streamQuota, err = wait(s.ctx, t.ctx, s.done, s.goAway, sqChan)
+				if err != nil {
+					return err
+				}
+			}
+			if localSendQuota <= 0 { // Being a soft limit, it can go negative.
+				// Acquire local send quota to be able to write to the controlBuf.
+				localSendQuota, err = wait(s.ctx, t.ctx, s.done, s.goAway, s.localSendQuota.acquire())
+				if err != nil {
+					return err
+				}
+			}
+			if size > streamQuota {
+				size = streamQuota
+			} // No need to do that for localSendQuota since that's only a soft limit.
 			// Wait until the transport has some quota to send the data.
 			tq, err := wait(s.ctx, t.ctx, s.done, s.goAway, t.sendQuotaPool.acquire())
 			if err != nil {
 				return err
 			}
-			if sq < size {
-				size = sq
-			}
 			if tq < size {
 				size = tq
 			}
-			if size > len(r) {
-				size = len(r)
+			if tq > size { // Overbooked transport quota. Return it back.
+				t.sendQuotaPool.add(tq - size)
 			}
+			streamQuota -= size
+			localSendQuota -= size
 			p := r[:size]
-			ps := len(p)
-			if ps < tq {
-				// Overbooked transport quota. Return it back.
-				t.sendQuotaPool.add(tq - ps)
-			}
-			// Acquire local send quota to be able to write to the controlBuf.
-			ltq, err := wait(s.ctx, t.ctx, s.done, s.goAway, s.localSendQuota.acquire())
-			if err != nil {
-				if _, ok := err.(ConnectionError); !ok {
-					t.sendQuotaPool.add(ps)
-				}
-				return err
-			}
-			s.localSendQuota.add(ltq - ps) // It's ok if we make it negative.
 			var endStream bool
 			// See if this is the last frame to be written.
 			if opts.Last {
@@ -711,21 +718,28 @@
 				}
 			}
 			success := func() {
-				t.controlBuf.put(&dataFrame{streamID: s.id, endStream: endStream, d: p, f: func() { s.localSendQuota.add(ps) }})
-				if ps < sq {
-					s.sendQuotaPool.lockedAdd(sq - ps)
-				}
-				r = r[ps:]
+				sz := size
+				t.controlBuf.put(&dataFrame{streamID: s.id, endStream: endStream, d: p, f: func() { s.localSendQuota.add(sz) }})
+				r = r[size:]
 			}
-			failure := func() {
-				s.sendQuotaPool.lockedAdd(sq)
+			failure := func() { // The stream quota version must have changed.
+				// Our streamQuota cache is invalidated now, so give it back.
+				s.sendQuotaPool.lockedAdd(streamQuota + size)
 			}
-			if !s.sendQuotaPool.compareAndExecute(quotaVer, success, failure) {
-				t.sendQuotaPool.add(ps)
-				s.localSendQuota.add(ps)
+			if !s.sendQuotaPool.compareAndExecute(streamQuotaVer, success, failure) {
+				// Couldn't send this chunk out.
+				t.sendQuotaPool.add(size)
+				localSendQuota += size
+				streamQuota = 0
 			}
 		}
 	}
+	if streamQuota > 0 { // Add the left over quota back to stream.
+		s.sendQuotaPool.add(streamQuota)
+	}
+	if localSendQuota > 0 {
+		s.localSendQuota.add(localSendQuota)
+	}
 	if !opts.Last {
 		return nil
 	}
diff --git a/transport/http2_server.go b/transport/http2_server.go
index f84d70a..6582024 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -813,7 +813,7 @@
 
 // Write converts the data into HTTP2 data frame and sends it out. Non-nil error
 // is returns if it fails (e.g., framing error, transport error).
-func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (err error) {
+func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
 	select {
 	case <-s.ctx.Done():
 		return ContextErr(s.ctx.Err())
@@ -842,66 +842,79 @@
 	}
 	hdr = append(hdr, data[:emptyLen]...)
 	data = data[emptyLen:]
+	var (
+		streamQuota    int
+		streamQuotaVer uint32
+		localSendQuota int
+		err            error
+		sqChan         <-chan int
+	)
 	for _, r := range [][]byte{hdr, data} {
 		for len(r) > 0 {
 			size := http2MaxFrameLen
-			// Wait until the stream has some quota to send the data.
-			quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion()
-			sq, err := wait(s.ctx, t.ctx, nil, nil, quotaChan)
-			if err != nil {
-				return err
+			if size > len(r) {
+				size = len(r)
 			}
+			if streamQuota == 0 { // Used up all the locally cached stream quota.
+				sqChan, streamQuotaVer = s.sendQuotaPool.acquireWithVersion()
+				// Wait until the stream has some quota to send the data.
+				streamQuota, err = wait(s.ctx, t.ctx, nil, nil, sqChan)
+				if err != nil {
+					return err
+				}
+			}
+			if localSendQuota <= 0 {
+				localSendQuota, err = wait(s.ctx, t.ctx, nil, nil, s.localSendQuota.acquire())
+				if err != nil {
+					return err
+				}
+			}
+			if size > streamQuota {
+				size = streamQuota
+			} // No need to do that for localSendQuota since that's only a soft limit.
 			// Wait until the transport has some quota to send the data.
 			tq, err := wait(s.ctx, t.ctx, nil, nil, t.sendQuotaPool.acquire())
 			if err != nil {
 				return err
 			}
-			if sq < size {
-				size = sq
-			}
 			if tq < size {
 				size = tq
 			}
-			if size > len(r) {
-				size = len(r)
+			if tq > size {
+				t.sendQuotaPool.add(tq - size)
 			}
+			streamQuota -= size
+			localSendQuota -= size
 			p := r[:size]
-			ps := len(p)
-			if ps < tq {
-				// Overbooked transport quota. Return it back.
-				t.sendQuotaPool.add(tq - ps)
-			}
-			// Acquire local send quota to be able to write to the controlBuf.
-			ltq, err := wait(s.ctx, t.ctx, nil, nil, s.localSendQuota.acquire())
-			if err != nil {
-				if _, ok := err.(ConnectionError); !ok {
-					t.sendQuotaPool.add(ps)
-				}
-				return err
-			}
-			s.localSendQuota.add(ltq - ps) // It's ok we make this negative.
 			// Reset ping strikes when sending data since this might cause
 			// the peer to send ping.
 			atomic.StoreUint32(&t.resetPingStrikes, 1)
 			success := func() {
+				sz := size
 				t.controlBuf.put(&dataFrame{streamID: s.id, endStream: false, d: p, f: func() {
-					s.localSendQuota.add(ps)
+					s.localSendQuota.add(sz)
 				}})
-				if ps < sq {
-					// Overbooked stream quota. Return it back.
-					s.sendQuotaPool.lockedAdd(sq - ps)
-				}
-				r = r[ps:]
+				r = r[size:]
 			}
-			failure := func() {
-				s.sendQuotaPool.lockedAdd(sq)
+			failure := func() { // The stream quota version must have changed.
+				// Our streamQuota cache is invalidated now, so give it back.
+				s.sendQuotaPool.lockedAdd(streamQuota + size)
 			}
-			if !s.sendQuotaPool.compareAndExecute(quotaVer, success, failure) {
-				t.sendQuotaPool.add(ps)
-				s.localSendQuota.add(ps)
+			if !s.sendQuotaPool.compareAndExecute(streamQuotaVer, success, failure) {
+				// Couldn't send this chunk out.
+				t.sendQuotaPool.add(size)
+				localSendQuota += size
+				streamQuota = 0
 			}
 		}
 	}
+	if streamQuota > 0 {
+		// ADd the left over quota back to stream.
+		s.sendQuotaPool.add(streamQuota)
+	}
+	if localSendQuota > 0 {
+		s.localSendQuota.add(localSendQuota)
+	}
 	return nil
 }
 
diff --git a/transport/transport_test.go b/transport/transport_test.go
index e1dd080..6960254 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -115,8 +115,12 @@
 
 func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
 	header := make([]byte, 5)
-	for i := 0; i < 10; i++ {
+	for {
 		if _, err := s.Read(header); err != nil {
+			if err == io.EOF {
+				h.t.WriteStatus(s, status.New(codes.OK, ""))
+				return
+			}
 			t.Fatalf("Error on server while reading data header: %v", err)
 		}
 		sz := binary.BigEndian.Uint32(header[1:])
@@ -1786,6 +1790,12 @@
 			t.Fatalf("Length of message received by client: %v, want: %v", len(recvMsg), len(msg))
 		}
 	}
+	defer func() {
+		ct.Write(cstream, nil, nil, &Options{Last: true}) // Close the stream.
+		if _, err := cstream.Read(header); err != io.EOF {
+			t.Fatalf("Client expected an EOF from the server. Got: %v", err)
+		}
+	}()
 	var sstream *Stream
 	st.mu.Lock()
 	for _, v := range st.activeStreams {
@@ -2156,3 +2166,73 @@
 		}
 	}
 }
+
+func TestPingPong1B(t *testing.T) {
+	runPingPongTest(t, 1)
+}
+
+func TestPingPong1KB(t *testing.T) {
+	runPingPongTest(t, 1024)
+}
+
+func TestPingPong64KB(t *testing.T) {
+	runPingPongTest(t, 65536)
+}
+
+func TestPingPong1MB(t *testing.T) {
+	runPingPongTest(t, 1048576)
+}
+
+//This is a stress-test of flow control logic.
+func runPingPongTest(t *testing.T, msgSize int) {
+	server, client := setUp(t, 0, 0, pingpong)
+	defer server.stop()
+	defer client.Close()
+	waitWhileTrue(t, func() (bool, error) {
+		server.mu.Lock()
+		defer server.mu.Unlock()
+		if len(server.conns) == 0 {
+			return true, fmt.Errorf("timed out while waiting for server transport to be created")
+		}
+		return false, nil
+	})
+	ct := client.(*http2Client)
+	stream, err := client.NewStream(context.Background(), &CallHdr{})
+	if err != nil {
+		t.Fatalf("Failed to create stream. Err: %v", err)
+	}
+	msg := make([]byte, msgSize)
+	outgoingHeader := make([]byte, 5)
+	outgoingHeader[0] = byte(0)
+	binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(msgSize))
+	opts := &Options{}
+	incomingHeader := make([]byte, 5)
+	done := make(chan struct{})
+	go func() {
+		timer := time.NewTimer(time.Second * 5)
+		<-timer.C
+		close(done)
+	}()
+	for {
+		select {
+		case <-done:
+			ct.Write(stream, nil, nil, &Options{Last: true})
+			if _, err := stream.Read(incomingHeader); err != io.EOF {
+				t.Fatalf("Client expected EOF from the server. Got: %v", err)
+			}
+			return
+		default:
+			if err := ct.Write(stream, outgoingHeader, msg, opts); err != nil {
+				t.Fatalf("Error on client while writing message. Err: %v", err)
+			}
+			if _, err := stream.Read(incomingHeader); err != nil {
+				t.Fatalf("Error on client while reading data header. Err: %v", err)
+			}
+			sz := binary.BigEndian.Uint32(incomingHeader[1:])
+			recvMsg := make([]byte, int(sz))
+			if _, err := stream.Read(recvMsg); err != nil {
+				t.Fatalf("Error on client while reading data. Err: %v", err)
+			}
+		}
+	}
+}