Acquire all stream related quota and cache it locally since no more than one write can happen in parallel on stream (#1614)
* Acquire all the stream related quotas and cache them locally since only one write can happen on a stream at a time.
* Added new tests.
* Fix flake
* Post-review updates
* Post-review update
diff --git a/transport/http2_client.go b/transport/http2_client.go
index f665ef0..1057512 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -659,44 +659,51 @@
}
hdr = append(hdr, data[:emptyLen]...)
data = data[emptyLen:]
+ var (
+ streamQuota int
+ streamQuotaVer uint32
+ localSendQuota int
+ err error
+ sqChan <-chan int
+ )
for idx, r := range [][]byte{hdr, data} {
for len(r) > 0 {
size := http2MaxFrameLen
- // Wait until the stream has some quota to send the data.
- quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion()
- sq, err := wait(s.ctx, t.ctx, s.done, s.goAway, quotaChan)
- if err != nil {
- return err
+ if size > len(r) {
+ size = len(r)
}
+ if streamQuota == 0 { // Used up all the locally cached stream quota.
+ sqChan, streamQuotaVer = s.sendQuotaPool.acquireWithVersion()
+ // Wait until the stream has some quota to send the data.
+ streamQuota, err = wait(s.ctx, t.ctx, s.done, s.goAway, sqChan)
+ if err != nil {
+ return err
+ }
+ }
+ if localSendQuota <= 0 { // Being a soft limit, it can go negative.
+ // Acquire local send quota to be able to write to the controlBuf.
+ localSendQuota, err = wait(s.ctx, t.ctx, s.done, s.goAway, s.localSendQuota.acquire())
+ if err != nil {
+ return err
+ }
+ }
+ if size > streamQuota {
+ size = streamQuota
+ } // No need to do that for localSendQuota since that's only a soft limit.
// Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, t.ctx, s.done, s.goAway, t.sendQuotaPool.acquire())
if err != nil {
return err
}
- if sq < size {
- size = sq
- }
if tq < size {
size = tq
}
- if size > len(r) {
- size = len(r)
+ if tq > size { // Overbooked transport quota. Return it back.
+ t.sendQuotaPool.add(tq - size)
}
+ streamQuota -= size
+ localSendQuota -= size
p := r[:size]
- ps := len(p)
- if ps < tq {
- // Overbooked transport quota. Return it back.
- t.sendQuotaPool.add(tq - ps)
- }
- // Acquire local send quota to be able to write to the controlBuf.
- ltq, err := wait(s.ctx, t.ctx, s.done, s.goAway, s.localSendQuota.acquire())
- if err != nil {
- if _, ok := err.(ConnectionError); !ok {
- t.sendQuotaPool.add(ps)
- }
- return err
- }
- s.localSendQuota.add(ltq - ps) // It's ok if we make it negative.
var endStream bool
// See if this is the last frame to be written.
if opts.Last {
@@ -711,21 +718,28 @@
}
}
success := func() {
- t.controlBuf.put(&dataFrame{streamID: s.id, endStream: endStream, d: p, f: func() { s.localSendQuota.add(ps) }})
- if ps < sq {
- s.sendQuotaPool.lockedAdd(sq - ps)
- }
- r = r[ps:]
+ sz := size
+ t.controlBuf.put(&dataFrame{streamID: s.id, endStream: endStream, d: p, f: func() { s.localSendQuota.add(sz) }})
+ r = r[size:]
}
- failure := func() {
- s.sendQuotaPool.lockedAdd(sq)
+ failure := func() { // The stream quota version must have changed.
+ // Our streamQuota cache is invalidated now, so give it back.
+ s.sendQuotaPool.lockedAdd(streamQuota + size)
}
- if !s.sendQuotaPool.compareAndExecute(quotaVer, success, failure) {
- t.sendQuotaPool.add(ps)
- s.localSendQuota.add(ps)
+ if !s.sendQuotaPool.compareAndExecute(streamQuotaVer, success, failure) {
+ // Couldn't send this chunk out.
+ t.sendQuotaPool.add(size)
+ localSendQuota += size
+ streamQuota = 0
}
}
}
+ if streamQuota > 0 { // Add the left over quota back to stream.
+ s.sendQuotaPool.add(streamQuota)
+ }
+ if localSendQuota > 0 {
+ s.localSendQuota.add(localSendQuota)
+ }
if !opts.Last {
return nil
}
diff --git a/transport/http2_server.go b/transport/http2_server.go
index f84d70a..6582024 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -813,7 +813,7 @@
// Write converts the data into HTTP2 data frame and sends it out. Non-nil error
// is returns if it fails (e.g., framing error, transport error).
-func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (err error) {
+func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
select {
case <-s.ctx.Done():
return ContextErr(s.ctx.Err())
@@ -842,66 +842,79 @@
}
hdr = append(hdr, data[:emptyLen]...)
data = data[emptyLen:]
+ var (
+ streamQuota int
+ streamQuotaVer uint32
+ localSendQuota int
+ err error
+ sqChan <-chan int
+ )
for _, r := range [][]byte{hdr, data} {
for len(r) > 0 {
size := http2MaxFrameLen
- // Wait until the stream has some quota to send the data.
- quotaChan, quotaVer := s.sendQuotaPool.acquireWithVersion()
- sq, err := wait(s.ctx, t.ctx, nil, nil, quotaChan)
- if err != nil {
- return err
+ if size > len(r) {
+ size = len(r)
}
+ if streamQuota == 0 { // Used up all the locally cached stream quota.
+ sqChan, streamQuotaVer = s.sendQuotaPool.acquireWithVersion()
+ // Wait until the stream has some quota to send the data.
+ streamQuota, err = wait(s.ctx, t.ctx, nil, nil, sqChan)
+ if err != nil {
+ return err
+ }
+ }
+ if localSendQuota <= 0 {
+ localSendQuota, err = wait(s.ctx, t.ctx, nil, nil, s.localSendQuota.acquire())
+ if err != nil {
+ return err
+ }
+ }
+ if size > streamQuota {
+ size = streamQuota
+ } // No need to do that for localSendQuota since that's only a soft limit.
// Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, t.ctx, nil, nil, t.sendQuotaPool.acquire())
if err != nil {
return err
}
- if sq < size {
- size = sq
- }
if tq < size {
size = tq
}
- if size > len(r) {
- size = len(r)
+ if tq > size {
+ t.sendQuotaPool.add(tq - size)
}
+ streamQuota -= size
+ localSendQuota -= size
p := r[:size]
- ps := len(p)
- if ps < tq {
- // Overbooked transport quota. Return it back.
- t.sendQuotaPool.add(tq - ps)
- }
- // Acquire local send quota to be able to write to the controlBuf.
- ltq, err := wait(s.ctx, t.ctx, nil, nil, s.localSendQuota.acquire())
- if err != nil {
- if _, ok := err.(ConnectionError); !ok {
- t.sendQuotaPool.add(ps)
- }
- return err
- }
- s.localSendQuota.add(ltq - ps) // It's ok we make this negative.
// Reset ping strikes when sending data since this might cause
// the peer to send ping.
atomic.StoreUint32(&t.resetPingStrikes, 1)
success := func() {
+ sz := size
t.controlBuf.put(&dataFrame{streamID: s.id, endStream: false, d: p, f: func() {
- s.localSendQuota.add(ps)
+ s.localSendQuota.add(sz)
}})
- if ps < sq {
- // Overbooked stream quota. Return it back.
- s.sendQuotaPool.lockedAdd(sq - ps)
- }
- r = r[ps:]
+ r = r[size:]
}
- failure := func() {
- s.sendQuotaPool.lockedAdd(sq)
+ failure := func() { // The stream quota version must have changed.
+ // Our streamQuota cache is invalidated now, so give it back.
+ s.sendQuotaPool.lockedAdd(streamQuota + size)
}
- if !s.sendQuotaPool.compareAndExecute(quotaVer, success, failure) {
- t.sendQuotaPool.add(ps)
- s.localSendQuota.add(ps)
+ if !s.sendQuotaPool.compareAndExecute(streamQuotaVer, success, failure) {
+ // Couldn't send this chunk out.
+ t.sendQuotaPool.add(size)
+ localSendQuota += size
+ streamQuota = 0
}
}
}
+ if streamQuota > 0 {
+ // ADd the left over quota back to stream.
+ s.sendQuotaPool.add(streamQuota)
+ }
+ if localSendQuota > 0 {
+ s.localSendQuota.add(localSendQuota)
+ }
return nil
}
diff --git a/transport/transport_test.go b/transport/transport_test.go
index e1dd080..6960254 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -115,8 +115,12 @@
func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
header := make([]byte, 5)
- for i := 0; i < 10; i++ {
+ for {
if _, err := s.Read(header); err != nil {
+ if err == io.EOF {
+ h.t.WriteStatus(s, status.New(codes.OK, ""))
+ return
+ }
t.Fatalf("Error on server while reading data header: %v", err)
}
sz := binary.BigEndian.Uint32(header[1:])
@@ -1786,6 +1790,12 @@
t.Fatalf("Length of message received by client: %v, want: %v", len(recvMsg), len(msg))
}
}
+ defer func() {
+ ct.Write(cstream, nil, nil, &Options{Last: true}) // Close the stream.
+ if _, err := cstream.Read(header); err != io.EOF {
+ t.Fatalf("Client expected an EOF from the server. Got: %v", err)
+ }
+ }()
var sstream *Stream
st.mu.Lock()
for _, v := range st.activeStreams {
@@ -2156,3 +2166,73 @@
}
}
}
+
+func TestPingPong1B(t *testing.T) {
+ runPingPongTest(t, 1)
+}
+
+func TestPingPong1KB(t *testing.T) {
+ runPingPongTest(t, 1024)
+}
+
+func TestPingPong64KB(t *testing.T) {
+ runPingPongTest(t, 65536)
+}
+
+func TestPingPong1MB(t *testing.T) {
+ runPingPongTest(t, 1048576)
+}
+
+//This is a stress-test of flow control logic.
+func runPingPongTest(t *testing.T, msgSize int) {
+ server, client := setUp(t, 0, 0, pingpong)
+ defer server.stop()
+ defer client.Close()
+ waitWhileTrue(t, func() (bool, error) {
+ server.mu.Lock()
+ defer server.mu.Unlock()
+ if len(server.conns) == 0 {
+ return true, fmt.Errorf("timed out while waiting for server transport to be created")
+ }
+ return false, nil
+ })
+ ct := client.(*http2Client)
+ stream, err := client.NewStream(context.Background(), &CallHdr{})
+ if err != nil {
+ t.Fatalf("Failed to create stream. Err: %v", err)
+ }
+ msg := make([]byte, msgSize)
+ outgoingHeader := make([]byte, 5)
+ outgoingHeader[0] = byte(0)
+ binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(msgSize))
+ opts := &Options{}
+ incomingHeader := make([]byte, 5)
+ done := make(chan struct{})
+ go func() {
+ timer := time.NewTimer(time.Second * 5)
+ <-timer.C
+ close(done)
+ }()
+ for {
+ select {
+ case <-done:
+ ct.Write(stream, nil, nil, &Options{Last: true})
+ if _, err := stream.Read(incomingHeader); err != io.EOF {
+ t.Fatalf("Client expected EOF from the server. Got: %v", err)
+ }
+ return
+ default:
+ if err := ct.Write(stream, outgoingHeader, msg, opts); err != nil {
+ t.Fatalf("Error on client while writing message. Err: %v", err)
+ }
+ if _, err := stream.Read(incomingHeader); err != nil {
+ t.Fatalf("Error on client while reading data header. Err: %v", err)
+ }
+ sz := binary.BigEndian.Uint32(incomingHeader[1:])
+ recvMsg := make([]byte, int(sz))
+ if _, err := stream.Read(recvMsg); err != nil {
+ t.Fatalf("Error on client while reading data. Err: %v", err)
+ }
+ }
+ }
+}