feat(pubsublite): Flow controller and offset tracker for the subscriber (#3132)
Ports the flow control token counter & batcher, and offset tracker from the Pub/Sub Lite Java client library.
diff --git a/pubsublite/go.mod b/pubsublite/go.mod
index c5d7c50..9150918 100644
--- a/pubsublite/go.mod
+++ b/pubsublite/go.mod
@@ -5,6 +5,7 @@
require (
cloud.google.com/go v0.71.0
github.com/golang/protobuf v1.4.3
+ github.com/google/go-cmp v0.5.2
github.com/googleapis/gax-go/v2 v2.0.5
golang.org/x/tools v0.0.0-20201102212025-f46e4245211d // indirect
google.golang.org/api v0.34.0
diff --git a/pubsublite/internal/wire/flow_control.go b/pubsublite/internal/wire/flow_control.go
new file mode 100644
index 0000000..785cdd0
--- /dev/null
+++ b/pubsublite/internal/wire/flow_control.go
@@ -0,0 +1,180 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package wire
+
+import (
+ "errors"
+ "fmt"
+ "math"
+
+ pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
+)
+
+// Sentinel errors returned by the flow control token counter and the
+// subscriber offset tracker in this file.
+var (
+ // Returned by tokenCounter.Sub when the byte delta is larger than the bytes
+ // currently held by the counter.
+ errTokenCounterBytesNegative = errors.New("pubsublite: received messages that account for more bytes than were requested")
+ // Returned by tokenCounter.Sub when the message delta is larger than the
+ // messages currently held by the counter.
+ errTokenCounterMessagesNegative = errors.New("pubsublite: received more messages than were requested")
+ // Returned by subscriberOffsetTracker.OnMessages when a message other than
+ // the first in a batch has an offset lower than expected.
+ errOutOfOrderMessages = errors.New("pubsublite: server delivered messages out of order")
+)
+
+// flowControlTokens is a pair of byte and message token amounts. It is the
+// delta type accepted by tokenCounter.Add and tokenCounter.Sub.
+type flowControlTokens struct {
+	Bytes, Messages int64
+}
+
+// tokenCounter holds the number of outstanding byte and message flow control
+// tokens that the client believes exist for the stream. The zero value is an
+// empty counter.
+type tokenCounter struct {
+	Bytes, Messages int64
+}
+
+// saturatedAdd returns sum + delta, clamped to the int64 range instead of
+// wrapping around on overflow.
+//
+// The bounds checks are arranged so that no intermediate expression can itself
+// overflow: math.MaxInt64-delta is only evaluated for a positive delta and
+// math.MinInt64-delta only for a negative delta. This keeps the result correct
+// when sum is negative, where computing math.MaxInt64-sum would wrap around.
+func saturatedAdd(sum, delta int64) int64 {
+	switch {
+	case delta > 0 && sum > math.MaxInt64-delta:
+		return math.MaxInt64
+	case delta < 0 && sum < math.MinInt64-delta:
+		return math.MinInt64
+	}
+	return sum + delta
+}
+
+// Add increases both token counts by delta. Each count is combined with
+// saturatedAdd, so it stops at math.MaxInt64 rather than wrapping around.
+func (tc *tokenCounter) Add(delta flowControlTokens) {
+	tc.Bytes, tc.Messages = saturatedAdd(tc.Bytes, delta.Bytes), saturatedAdd(tc.Messages, delta.Messages)
+}
+
+// Sub decreases both token counts by delta.
+//
+// Both counts are validated before either is modified, so the counter is left
+// unchanged when an error is returned. The byte count is checked first:
+// errTokenCounterBytesNegative is returned if delta.Bytes exceeds the current
+// bytes, otherwise errTokenCounterMessagesNegative if delta.Messages exceeds
+// the current messages.
+func (tc *tokenCounter) Sub(delta flowControlTokens) error {
+	switch {
+	case tc.Bytes < delta.Bytes:
+		return errTokenCounterBytesNegative
+	case tc.Messages < delta.Messages:
+		return errTokenCounterMessagesNegative
+	}
+	tc.Bytes, tc.Messages = tc.Bytes-delta.Bytes, tc.Messages-delta.Messages
+	return nil
+}
+
+// Reset clears the counter, setting both bytes and messages back to zero.
+func (tc *tokenCounter) Reset() {
+	*tc = tokenCounter{}
+}
+
+// ToFlowControlRequest converts the counter into a FlowControlRequest carrying
+// the current byte and message counts. It returns nil when neither count is
+// positive, i.e. when there are no tokens to grant.
+func (tc *tokenCounter) ToFlowControlRequest() *pb.FlowControlRequest {
+	if hasTokens := tc.Bytes > 0 || tc.Messages > 0; !hasTokens {
+		return nil
+	}
+	req := new(pb.FlowControlRequest)
+	req.AllowedBytes = tc.Bytes
+	req.AllowedMessages = tc.Messages
+	return req
+}
+
+// flowControlBatcher tracks flow control tokens and manages batching of flow
+// control requests to avoid overwhelming the server. It is only accessed by
+// the wireSubscriber.
+//
+// The zero value is ready to use: both counters start empty.
+type flowControlBatcher struct {
+ // The current amount of outstanding byte and message flow control tokens.
+ // Increased by OnClientFlow, decreased by OnMessages, and reported in full
+ // by RequestForRestart.
+ clientTokens tokenCounter
+ // The pending batch flow control request that needs to be sent to the stream.
+ // Increased by OnClientFlow and cleared by both ReleasePendingRequest and
+ // RequestForRestart.
+ pendingTokens tokenCounter
+}
+
+// expediteBatchRequestRatio is the pending-to-client token ratio at or above
+// which exceedsExpediteRatio reports true.
+const expediteBatchRequestRatio = 0.5
+
+// exceedsExpediteRatio reports whether pending tokens make up at least
+// expediteBatchRequestRatio of the client tokens. It is always false when
+// client is not positive, which also avoids dividing by zero.
+func exceedsExpediteRatio(pending, client int64) bool {
+	if client <= 0 {
+		return false
+	}
+	ratio := float64(pending) / float64(client)
+	return ratio >= expediteBatchRequestRatio
+}
+
+// OnClientFlow adds tokens to both the outstanding client tokens and the
+// pending batch request. This occurs when:
+// - Initialization from ReceiveSettings.
+// - The user acks messages.
+func (fc *flowControlBatcher) OnClientFlow(tokens flowControlTokens) {
+	for _, counter := range []*tokenCounter{&fc.clientTokens, &fc.pendingTokens} {
+		counter.Add(tokens)
+	}
+}
+
+// OnMessages deducts the tokens consumed by a batch of messages received from
+// the server: one message token per message plus the sum of their SizeBytes.
+// It returns the error from tokenCounter.Sub when the batch accounts for more
+// tokens than the client has outstanding.
+func (fc *flowControlBatcher) OnMessages(msgs []*pb.SequencedMessage) error {
+	received := flowControlTokens{Messages: int64(len(msgs))}
+	for i := range msgs {
+		received.Bytes += msgs[i].GetSizeBytes()
+	}
+	return fc.clientTokens.Sub(received)
+}
+
+// RequestForRestart returns the FlowControlRequest to send when a new
+// subscriber stream is connected. The request carries all outstanding client
+// tokens, so any pending batch is discarded. May return nil.
+func (fc *flowControlBatcher) RequestForRestart() *pb.FlowControlRequest {
+	restartReq := fc.clientTokens.ToFlowControlRequest()
+	fc.pendingTokens.Reset()
+	return restartReq
+}
+
+// ReleasePendingRequest hands over the batched FlowControlRequest, if any, and
+// clears the pending tokens. The result is non-nil only when there is a batch
+// to send to the stream.
+func (fc *flowControlBatcher) ReleasePendingRequest() *pb.FlowControlRequest {
+	// The request is built from the pending counts before the deferred Reset
+	// clears them.
+	defer fc.pendingTokens.Reset()
+	return fc.pendingTokens.ToFlowControlRequest()
+}
+
+// ShouldExpediteBatchRequest reports whether the batch FlowControlRequest
+// should be sent ASAP to avoid starving the client of messages, which occurs
+// when the client is rapidly acking messages. It is true when either the
+// pending bytes or the pending messages reach the expedite ratio of the
+// corresponding client tokens.
+func (fc *flowControlBatcher) ShouldExpediteBatchRequest() bool {
+	pending, client := &fc.pendingTokens, &fc.clientTokens
+	return exceedsExpediteRatio(pending.Bytes, client.Bytes) ||
+		exceedsExpediteRatio(pending.Messages, client.Messages)
+}
+
+// subscriberOffsetTracker tracks the expected offset of the next message
+// received from the server. It is only accessed by the wireSubscriber.
+type subscriberOffsetTracker struct {
+ // Lowest offset accepted for the next message. OnMessages advances it to one
+ // past the offset of the last message in each accepted batch; it is zero
+ // until then.
+ minNextOffset int64
+}
+
+// RequestForRestart builds the seek request to send when a new subscribe
+// stream reconnects, targeting a cursor at the next expected offset. It
+// returns nil if the subscriber has just started (no positive offset has been
+// recorded), in which case the server returns the offset of the last
+// committed cursor.
+func (ot *subscriberOffsetTracker) RequestForRestart() *pb.SeekRequest {
+	if next := ot.minNextOffset; next > 0 {
+		cursor := &pb.Cursor{Offset: next}
+		return &pb.SeekRequest{
+			Target: &pb.SeekRequest_Cursor{Cursor: cursor},
+		}
+	}
+	return nil
+}
+
+// OnMessages checks that a batch of messages is delivered in order and, if
+// so, advances the next expected offset to one past the last message.
+//
+// Gaps between offsets are allowed; only an offset lower than expected is
+// rejected. A bad first message yields an error describing the expected start
+// offset, while a bad later message yields errOutOfOrderMessages. The tracker
+// is left unchanged when an error is returned.
+func (ot *subscriberOffsetTracker) OnMessages(msgs []*pb.SequencedMessage) error {
+	expected := ot.minNextOffset
+	for idx := range msgs {
+		got := msgs[idx].GetCursor().GetOffset()
+		switch {
+		case got >= expected:
+			expected = got + 1
+		case idx == 0:
+			return fmt.Errorf("pubsublite: server delivered messages with start offset = %d, expected >= %d", got, ot.minNextOffset)
+		default:
+			return errOutOfOrderMessages
+		}
+	}
+	ot.minNextOffset = expected
+	return nil
+}
diff --git a/pubsublite/internal/wire/flow_control_test.go b/pubsublite/internal/wire/flow_control_test.go
new file mode 100644
index 0000000..ca6d066
--- /dev/null
+++ b/pubsublite/internal/wire/flow_control_test.go
@@ -0,0 +1,325 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package wire
+
+import (
+ "math"
+ "testing"
+
+ "cloud.google.com/go/internal/testutil"
+ "cloud.google.com/go/pubsublite/internal/test"
+ "github.com/golang/protobuf/proto"
+ "github.com/google/go-cmp/cmp"
+
+ pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
+)
+
+// flowControlReq builds the FlowControlRequest that grants the given tokens.
+func flowControlReq(tokens flowControlTokens) *pb.FlowControlRequest {
+	req := new(pb.FlowControlRequest)
+	req.AllowedBytes = tokens.Bytes
+	req.AllowedMessages = tokens.Messages
+	return req
+}
+
+// seqMsgWithOffset builds a SequencedMessage whose cursor is at offset; all
+// other fields are left unset.
+func seqMsgWithOffset(offset int64) *pb.SequencedMessage {
+	cursor := &pb.Cursor{Offset: offset}
+	return &pb.SequencedMessage{Cursor: cursor}
+}
+
+// seqMsgWithSizeBytes builds a SequencedMessage that reports the given size;
+// all other fields are left unset.
+func seqMsgWithSizeBytes(size int64) *pb.SequencedMessage {
+	msg := new(pb.SequencedMessage)
+	msg.SizeBytes = size
+	return msg
+}
+
+// TestTokenCounterAdd applies a sequence of deltas to a single tokenCounter
+// and checks the accumulated totals after each step, finishing with a delta
+// that must saturate the counter instead of overflowing.
+func TestTokenCounterAdd(t *testing.T) {
+ // Note: tests are applied to this counter instance in order.
+ counter := tokenCounter{}
+
+ for _, tc := range []struct {
+ desc string
+ delta flowControlTokens
+ want tokenCounter
+ }{
+ {
+ desc: "Initialize",
+ delta: flowControlTokens{Bytes: 9876543, Messages: 1234},
+ want: tokenCounter{Bytes: 9876543, Messages: 1234},
+ },
+ {
+ // Builds on the totals left by the previous case.
+ desc: "Add delta",
+ delta: flowControlTokens{Bytes: 1, Messages: 2},
+ want: tokenCounter{Bytes: 9876544, Messages: 1236},
+ },
+ {
+ // The counter is already non-zero, so adding MaxInt64 would overflow;
+ // both counts are expected to be clamped at MaxInt64.
+ desc: "Overflow",
+ delta: flowControlTokens{Bytes: math.MaxInt64, Messages: math.MaxInt64},
+ want: tokenCounter{Bytes: math.MaxInt64, Messages: math.MaxInt64},
+ },
+ } {
+ t.Run(tc.desc, func(t *testing.T) {
+ counter.Add(tc.delta)
+ if !testutil.Equal(counter, tc.want) {
+ t.Errorf("tokenCounter.Add(%v): got %v, want %v", tc.delta, counter, tc.want)
+ }
+ })
+ }
+}
+
+// TestTokenCounterSub checks tokenCounter.Sub against independent counters:
+// successful subtractions must produce the expected remainder, and a delta
+// exceeding either count must return the matching error while leaving the
+// counter untouched.
+func TestTokenCounterSub(t *testing.T) {
+	type subCase struct {
+		desc    string
+		counter tokenCounter
+		delta   flowControlTokens
+		want    tokenCounter
+		wantErr error
+	}
+	cases := []subCase{
+		{
+			desc:    "Result zero",
+			counter: tokenCounter{Bytes: 9876543, Messages: 1234},
+			delta:   flowControlTokens{Bytes: 9876543, Messages: 1234},
+			want:    tokenCounter{Bytes: 0, Messages: 0},
+		},
+		{
+			desc:    "Result non-zero",
+			counter: tokenCounter{Bytes: 9876543, Messages: 1234},
+			delta:   flowControlTokens{Bytes: 9876500, Messages: 1200},
+			want:    tokenCounter{Bytes: 43, Messages: 34},
+		},
+		{
+			desc:    "Bytes negative",
+			counter: tokenCounter{Bytes: 9876543, Messages: 1234},
+			delta:   flowControlTokens{Bytes: 9876544, Messages: 1234},
+			want:    tokenCounter{Bytes: 9876543, Messages: 1234},
+			wantErr: errTokenCounterBytesNegative,
+		},
+		{
+			desc:    "Messages negative",
+			counter: tokenCounter{Bytes: 9876543, Messages: 1234},
+			delta:   flowControlTokens{Bytes: 9876543, Messages: 1235},
+			want:    tokenCounter{Bytes: 9876543, Messages: 1234},
+			wantErr: errTokenCounterMessagesNegative,
+		},
+	}
+
+	for i := range cases {
+		// Each case owns its counter, so it can be mutated in place.
+		c := &cases[i]
+		t.Run(c.desc, func(t *testing.T) {
+			gotErr := c.counter.Sub(c.delta)
+			if !testutil.Equal(c.counter, c.want) {
+				t.Errorf("tokenCounter.Sub(%v): got %v, want %v", c.delta, c.counter, c.want)
+			}
+			if !test.ErrorEqual(gotErr, c.wantErr) {
+				t.Errorf("tokenCounter.Sub(%v) error: got %v, want %v", c.delta, gotErr, c.wantErr)
+			}
+		})
+	}
+}
+
+// TestTokenCounterToFlowControlRequest checks that an empty counter converts
+// to a nil request and that any positive count produces a request carrying
+// both counts.
+func TestTokenCounterToFlowControlRequest(t *testing.T) {
+	type convCase struct {
+		desc    string
+		counter tokenCounter
+		want    *pb.FlowControlRequest
+	}
+	cases := []convCase{
+		{
+			desc:    "Uninitialized counter",
+			counter: tokenCounter{},
+			want:    nil,
+		},
+		{
+			desc:    "Bytes non-zero",
+			counter: tokenCounter{Bytes: 1},
+			want:    &pb.FlowControlRequest{AllowedBytes: 1},
+		},
+		{
+			desc:    "Messages non-zero",
+			counter: tokenCounter{Messages: 1},
+			want:    &pb.FlowControlRequest{AllowedMessages: 1},
+		},
+		{
+			desc:    "Messages and bytes",
+			counter: tokenCounter{Bytes: 56, Messages: 32},
+			want:    &pb.FlowControlRequest{AllowedBytes: 56, AllowedMessages: 32},
+		},
+	}
+
+	for i := range cases {
+		c := &cases[i]
+		t.Run(c.desc, func(t *testing.T) {
+			if got := c.counter.ToFlowControlRequest(); !proto.Equal(got, c.want) {
+				t.Errorf("tokenCounter(%v).ToFlowControlRequest(): got %v, want %v", c.counter, got, c.want)
+			}
+		})
+	}
+}
+
+// TestFlowControlBatcher drives a single flowControlBatcher through a sequence
+// of token grants and message deliveries. The subtests share the batcher and
+// must run in the order written; each one checks the expedite decision, the
+// released pending request, and the restart request for the state so far.
+func TestFlowControlBatcher(t *testing.T) {
+ var batcher flowControlBatcher
+
+ // Zero-value batcher: no tokens, so nothing to expedite or send.
+ t.Run("Uninitialized", func(t *testing.T) {
+ if got, want := batcher.ShouldExpediteBatchRequest(), false; got != want {
+ t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want)
+ }
+ if got, want := batcher.ReleasePendingRequest(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) {
+ t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want)
+ }
+ if got, want := batcher.RequestForRestart(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) {
+ t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want)
+ }
+ })
+
+ // First grant: pending equals client tokens (ratio 1), so the batch should
+ // be expedited. Client tokens become 500 bytes / 10 messages.
+ t.Run("OnClientFlow-1", func(t *testing.T) {
+ deltaTokens := flowControlTokens{Bytes: 500, Messages: 10}
+ batcher.OnClientFlow(deltaTokens)
+
+ if got, want := batcher.ShouldExpediteBatchRequest(), true; got != want {
+ t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want)
+ }
+ if got, want := batcher.ReleasePendingRequest(), flowControlReq(deltaTokens); !proto.Equal(got, want) {
+ t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want)
+ }
+ if got, want := batcher.RequestForRestart(), flowControlReq(deltaTokens); !proto.Equal(got, want) {
+ t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want)
+ }
+ })
+
+ // Second grant: pending (100 / 1) is below half of the client tokens
+ // (600 / 11), so no expedite. The restart request reflects the total.
+ t.Run("OnClientFlow-2", func(t *testing.T) {
+ deltaTokens := flowControlTokens{Bytes: 100, Messages: 1}
+ batcher.OnClientFlow(deltaTokens)
+
+ if got, want := batcher.ShouldExpediteBatchRequest(), false; got != want {
+ t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want)
+ }
+ if got, want := batcher.ReleasePendingRequest(), flowControlReq(deltaTokens); !proto.Equal(got, want) {
+ t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want)
+ }
+ if got, want := batcher.RequestForRestart(), flowControlReq(flowControlTokens{Bytes: 600, Messages: 11}); !proto.Equal(got, want) {
+ t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want)
+ }
+ })
+
+ // Two messages totalling 30 bytes are deducted from the client tokens,
+ // leaving 570 bytes / 9 messages. Nothing is pending.
+ t.Run("OnMessages-Valid", func(t *testing.T) {
+ msgs := []*pb.SequencedMessage{seqMsgWithSizeBytes(10), seqMsgWithSizeBytes(20)}
+ if gotErr := batcher.OnMessages(msgs); gotErr != nil {
+ t.Errorf("flowControlBatcher.OnMessages(): got err (%v), want err <nil>", gotErr)
+ }
+
+ if got, want := batcher.ShouldExpediteBatchRequest(), false; got != want {
+ t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want)
+ }
+ if got, want := batcher.ReleasePendingRequest(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) {
+ t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want)
+ }
+ if got, want := batcher.RequestForRestart(), flowControlReq(flowControlTokens{Bytes: 570, Messages: 9}); !proto.Equal(got, want) {
+ t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want)
+ }
+ })
+
+ // 600 bytes exceeds the 570 bytes outstanding: an error is expected and the
+ // client tokens must remain at 570 bytes / 9 messages.
+ t.Run("OnMessages-Underflow", func(t *testing.T) {
+ msgs := []*pb.SequencedMessage{seqMsgWithSizeBytes(400), seqMsgWithSizeBytes(200)}
+ if gotErr, wantErr := batcher.OnMessages(msgs), errTokenCounterBytesNegative; !test.ErrorEqual(gotErr, wantErr) {
+ t.Errorf("flowControlBatcher.OnMessages(): got err (%v), want err (%v)", gotErr, wantErr)
+ }
+
+ if got, want := batcher.ShouldExpediteBatchRequest(), false; got != want {
+ t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want)
+ }
+ if got, want := batcher.ReleasePendingRequest(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) {
+ t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want)
+ }
+ if got, want := batcher.RequestForRestart(), flowControlReq(flowControlTokens{Bytes: 570, Messages: 9}); !proto.Equal(got, want) {
+ t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want)
+ }
+ })
+}
+
+// TestOffsetTrackerRequestForRestart checks that a fresh tracker produces no
+// seek request and that a tracker with a positive next offset seeks to a
+// cursor at that offset.
+func TestOffsetTrackerRequestForRestart(t *testing.T) {
+	type restartCase struct {
+		desc    string
+		tracker subscriberOffsetTracker
+		want    *pb.SeekRequest
+	}
+	seekToOne := &pb.SeekRequest{
+		Target: &pb.SeekRequest_Cursor{
+			Cursor: &pb.Cursor{Offset: 1},
+		},
+	}
+	cases := []restartCase{
+		{
+			desc:    "Uninitialized tracker",
+			tracker: subscriberOffsetTracker{},
+			want:    nil,
+		},
+		{
+			desc:    "Next offset positive",
+			tracker: subscriberOffsetTracker{minNextOffset: 1},
+			want:    seekToOne,
+		},
+	}
+
+	for i := range cases {
+		c := &cases[i]
+		t.Run(c.desc, func(t *testing.T) {
+			if got := c.tracker.RequestForRestart(); !proto.Equal(got, c.want) {
+				t.Errorf("subscriberOffsetTracker(%v).RequestForRestart(): got %v, want %v", c.tracker, got, c.want)
+			}
+		})
+	}
+}
+
+// TestOffsetTrackerOnMessages checks subscriberOffsetTracker.OnMessages
+// against independent trackers: in-order batches (with or without gaps) must
+// advance the next offset, while a batch that starts too early or is out of
+// order must fail and leave the tracker unchanged.
+func TestOffsetTrackerOnMessages(t *testing.T) {
+	type onMessagesCase struct {
+		desc    string
+		tracker subscriberOffsetTracker
+		msgs    []*pb.SequencedMessage
+		want    subscriberOffsetTracker
+		wantErr bool
+	}
+	cases := []onMessagesCase{
+		{
+			desc:    "Uninitialized tracker",
+			tracker: subscriberOffsetTracker{},
+			msgs:    []*pb.SequencedMessage{seqMsgWithOffset(0)},
+			want:    subscriberOffsetTracker{minNextOffset: 1},
+		},
+		{
+			desc:    "Consecutive message offsets",
+			tracker: subscriberOffsetTracker{minNextOffset: 5},
+			msgs:    []*pb.SequencedMessage{seqMsgWithOffset(5), seqMsgWithOffset(6), seqMsgWithOffset(7)},
+			want:    subscriberOffsetTracker{minNextOffset: 8},
+		},
+		{
+			desc:    "Skip message offsets",
+			tracker: subscriberOffsetTracker{minNextOffset: 5},
+			msgs:    []*pb.SequencedMessage{seqMsgWithOffset(10), seqMsgWithOffset(15)},
+			want:    subscriberOffsetTracker{minNextOffset: 16},
+		},
+		{
+			desc:    "Start offset before minNextOffset",
+			tracker: subscriberOffsetTracker{minNextOffset: 5},
+			msgs:    []*pb.SequencedMessage{seqMsgWithOffset(4)},
+			want:    subscriberOffsetTracker{minNextOffset: 5},
+			wantErr: true,
+		},
+		{
+			desc:    "Unordered messages",
+			tracker: subscriberOffsetTracker{minNextOffset: 5},
+			msgs:    []*pb.SequencedMessage{seqMsgWithOffset(5), seqMsgWithOffset(10), seqMsgWithOffset(9)},
+			want:    subscriberOffsetTracker{minNextOffset: 5},
+			wantErr: true,
+		},
+	}
+
+	// The tracker's field is unexported, so the comparison must opt in to it.
+	allowUnexported := cmp.AllowUnexported(subscriberOffsetTracker{})
+	for i := range cases {
+		c := &cases[i]
+		t.Run(c.desc, func(t *testing.T) {
+			err := c.tracker.OnMessages(c.msgs)
+			if !testutil.Equal(c.tracker, c.want, allowUnexported) {
+				t.Errorf("subscriberOffsetTracker().OnMessages(): got %v, want %v", c.tracker, c.want)
+			}
+			if (err != nil) != c.wantErr {
+				t.Errorf("subscriberOffsetTracker().OnMessages() error: got (%v), want err=%v", err, c.wantErr)
+			}
+		})
+	}
+}