feat(pubsublite): Flow controller and offset tracker for the subscriber (#3132)

Ports the flow control token counter & batcher, and offset tracker from the Pub/Sub Lite Java client library.
diff --git a/pubsublite/go.mod b/pubsublite/go.mod
index c5d7c50..9150918 100644
--- a/pubsublite/go.mod
+++ b/pubsublite/go.mod
@@ -5,6 +5,7 @@
 require (
 	cloud.google.com/go v0.71.0
 	github.com/golang/protobuf v1.4.3
+	github.com/google/go-cmp v0.5.2
 	github.com/googleapis/gax-go/v2 v2.0.5
 	golang.org/x/tools v0.0.0-20201102212025-f46e4245211d // indirect
 	google.golang.org/api v0.34.0
diff --git a/pubsublite/internal/wire/flow_control.go b/pubsublite/internal/wire/flow_control.go
new file mode 100644
index 0000000..785cdd0
--- /dev/null
+++ b/pubsublite/internal/wire/flow_control.go
@@ -0,0 +1,180 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+
+package wire
+
+import (
+	"errors"
+	"fmt"
+	"math"
+
+	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
+)
+
+var (
+	errTokenCounterBytesNegative    = errors.New("pubsublite: received messages that account for more bytes than were requested")
+	errTokenCounterMessagesNegative = errors.New("pubsublite: received more messages than were requested")
+	errOutOfOrderMessages           = errors.New("pubsublite: server delivered messages out of order")
+)
+
+type flowControlTokens struct {
+	Bytes    int64
+	Messages int64
+}
+
+// A TokenCounter stores the amount of outstanding byte and message flow control
+// tokens that the client believes exists for the stream.
+type tokenCounter struct {
+	Bytes    int64
+	Messages int64
+}
+
+func saturatedAdd(sum, delta int64) int64 {
+	remainder := math.MaxInt64 - sum
+	if delta >= remainder {
+		return math.MaxInt64
+	}
+	return sum + delta
+}
+
+func (tc *tokenCounter) Add(delta flowControlTokens) {
+	tc.Bytes = saturatedAdd(tc.Bytes, delta.Bytes)
+	tc.Messages = saturatedAdd(tc.Messages, delta.Messages)
+}
+
+func (tc *tokenCounter) Sub(delta flowControlTokens) error {
+	if delta.Bytes > tc.Bytes {
+		return errTokenCounterBytesNegative
+	}
+	if delta.Messages > tc.Messages {
+		return errTokenCounterMessagesNegative
+	}
+	tc.Bytes -= delta.Bytes
+	tc.Messages -= delta.Messages
+	return nil
+}
+
+func (tc *tokenCounter) Reset() {
+	tc.Bytes = 0
+	tc.Messages = 0
+}
+
+func (tc *tokenCounter) ToFlowControlRequest() *pb.FlowControlRequest {
+	if tc.Bytes <= 0 && tc.Messages <= 0 {
+		return nil
+	}
+	return &pb.FlowControlRequest{
+		AllowedBytes:    tc.Bytes,
+		AllowedMessages: tc.Messages,
+	}
+}
+
+// flowControlBatcher tracks flow control tokens and manages batching of flow
+// control requests to avoid overwhelming the server. It is only accessed by
+// the wireSubscriber.
+type flowControlBatcher struct {
+	// The current amount of outstanding byte and message flow control tokens.
+	clientTokens tokenCounter
+	// The pending batch flow control request that needs to be sent to the stream.
+	pendingTokens tokenCounter
+}
+
+const expediteBatchRequestRatio = 0.5
+
+func exceedsExpediteRatio(pending, client int64) bool {
+	return client > 0 && (float64(pending)/float64(client)) >= expediteBatchRequestRatio
+}
+
+// OnClientFlow increments flow control tokens. This occurs when:
+// - Initialization from ReceiveSettings.
+// - The user acks messages.
+func (fc *flowControlBatcher) OnClientFlow(tokens flowControlTokens) {
+	fc.clientTokens.Add(tokens)
+	fc.pendingTokens.Add(tokens)
+}
+
+// OnMessages decrements flow control tokens when messages are received from the
+// server.
+func (fc *flowControlBatcher) OnMessages(msgs []*pb.SequencedMessage) error {
+	var totalBytes int64
+	for _, msg := range msgs {
+		totalBytes += msg.GetSizeBytes()
+	}
+	return fc.clientTokens.Sub(flowControlTokens{Bytes: totalBytes, Messages: int64(len(msgs))})
+}
+
+// RequestForRestart returns a FlowControlRequest that should be sent when a new
+// subscriber stream is connected. May return nil.
+func (fc *flowControlBatcher) RequestForRestart() *pb.FlowControlRequest {
+	fc.pendingTokens.Reset()
+	return fc.clientTokens.ToFlowControlRequest()
+}
+
+// ReleasePendingRequest returns a non-nil request when there is a batch
+// FlowControlRequest to send to the stream.
+func (fc *flowControlBatcher) ReleasePendingRequest() *pb.FlowControlRequest {
+	req := fc.pendingTokens.ToFlowControlRequest()
+	fc.pendingTokens.Reset()
+	return req
+}
+
+// ShouldExpediteBatchRequest returns true if a batch FlowControlRequest should
+// be sent ASAP to avoid starving the client of messages. This occurs when the
+// client is rapidly acking messages.
+func (fc *flowControlBatcher) ShouldExpediteBatchRequest() bool {
+	if exceedsExpediteRatio(fc.pendingTokens.Bytes, fc.clientTokens.Bytes) {
+		return true
+	}
+	if exceedsExpediteRatio(fc.pendingTokens.Messages, fc.clientTokens.Messages) {
+		return true
+	}
+	return false
+}
+
+// subscriberOffsetTracker tracks the expected offset of the next message
+// received from the server. It is only accessed by the wireSubscriber.
+type subscriberOffsetTracker struct {
+	minNextOffset int64
+}
+
+// RequestForRestart returns the seek request to send when a new subscribe
+// stream reconnects. Returns nil if the subscriber has just started, in which
+// case the server returns the offset of the last committed cursor.
+func (ot *subscriberOffsetTracker) RequestForRestart() *pb.SeekRequest {
+	if ot.minNextOffset <= 0 {
+		return nil
+	}
+	return &pb.SeekRequest{
+		Target: &pb.SeekRequest_Cursor{
+			Cursor: &pb.Cursor{Offset: ot.minNextOffset},
+		},
+	}
+}
+
+// OnMessages verifies that messages are delivered in order and updates the next
+// expected offset.
+func (ot *subscriberOffsetTracker) OnMessages(msgs []*pb.SequencedMessage) error {
+	nextOffset := ot.minNextOffset
+	for i, msg := range msgs {
+		offset := msg.GetCursor().GetOffset()
+		if offset < nextOffset {
+			if i == 0 {
+				return fmt.Errorf("pubsublite: server delivered messages with start offset = %d, expected >= %d", offset, ot.minNextOffset)
+			}
+			return errOutOfOrderMessages
+		}
+		nextOffset = offset + 1
+	}
+	ot.minNextOffset = nextOffset
+	return nil
+}
diff --git a/pubsublite/internal/wire/flow_control_test.go b/pubsublite/internal/wire/flow_control_test.go
new file mode 100644
index 0000000..ca6d066
--- /dev/null
+++ b/pubsublite/internal/wire/flow_control_test.go
@@ -0,0 +1,325 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+
+package wire
+
+import (
+	"math"
+	"testing"
+
+	"cloud.google.com/go/internal/testutil"
+	"cloud.google.com/go/pubsublite/internal/test"
+	"github.com/golang/protobuf/proto"
+	"github.com/google/go-cmp/cmp"
+
+	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
+)
+
+func flowControlReq(tokens flowControlTokens) *pb.FlowControlRequest {
+	return &pb.FlowControlRequest{
+		AllowedBytes:    tokens.Bytes,
+		AllowedMessages: tokens.Messages,
+	}
+}
+
+func seqMsgWithOffset(offset int64) *pb.SequencedMessage {
+	return &pb.SequencedMessage{
+		Cursor: &pb.Cursor{Offset: offset},
+	}
+}
+
+func seqMsgWithSizeBytes(size int64) *pb.SequencedMessage {
+	return &pb.SequencedMessage{
+		SizeBytes: size,
+	}
+}
+
+func TestTokenCounterAdd(t *testing.T) {
+	// Note: tests are applied to this counter instance in order.
+	counter := tokenCounter{}
+
+	for _, tc := range []struct {
+		desc  string
+		delta flowControlTokens
+		want  tokenCounter
+	}{
+		{
+			desc:  "Initialize",
+			delta: flowControlTokens{Bytes: 9876543, Messages: 1234},
+			want:  tokenCounter{Bytes: 9876543, Messages: 1234},
+		},
+		{
+			desc:  "Add delta",
+			delta: flowControlTokens{Bytes: 1, Messages: 2},
+			want:  tokenCounter{Bytes: 9876544, Messages: 1236},
+		},
+		{
+			desc:  "Overflow",
+			delta: flowControlTokens{Bytes: math.MaxInt64, Messages: math.MaxInt64},
+			want:  tokenCounter{Bytes: math.MaxInt64, Messages: math.MaxInt64},
+		},
+	} {
+		t.Run(tc.desc, func(t *testing.T) {
+			counter.Add(tc.delta)
+			if !testutil.Equal(counter, tc.want) {
+				t.Errorf("tokenCounter.Add(%v): got %v, want %v", tc.delta, counter, tc.want)
+			}
+		})
+	}
+}
+
+func TestTokenCounterSub(t *testing.T) {
+	for _, tc := range []struct {
+		desc    string
+		counter tokenCounter
+		delta   flowControlTokens
+		want    tokenCounter
+		wantErr error
+	}{
+		{
+			desc:    "Result zero",
+			counter: tokenCounter{Bytes: 9876543, Messages: 1234},
+			delta:   flowControlTokens{Bytes: 9876543, Messages: 1234},
+			want:    tokenCounter{Bytes: 0, Messages: 0},
+		},
+		{
+			desc:    "Result non-zero",
+			counter: tokenCounter{Bytes: 9876543, Messages: 1234},
+			delta:   flowControlTokens{Bytes: 9876500, Messages: 1200},
+			want:    tokenCounter{Bytes: 43, Messages: 34},
+		},
+		{
+			desc:    "Bytes negative",
+			counter: tokenCounter{Bytes: 9876543, Messages: 1234},
+			delta:   flowControlTokens{Bytes: 9876544, Messages: 1234},
+			want:    tokenCounter{Bytes: 9876543, Messages: 1234},
+			wantErr: errTokenCounterBytesNegative,
+		},
+		{
+			desc:    "Messages negative",
+			counter: tokenCounter{Bytes: 9876543, Messages: 1234},
+			delta:   flowControlTokens{Bytes: 9876543, Messages: 1235},
+			want:    tokenCounter{Bytes: 9876543, Messages: 1234},
+			wantErr: errTokenCounterMessagesNegative,
+		},
+	} {
+		t.Run(tc.desc, func(t *testing.T) {
+			gotErr := tc.counter.Sub(tc.delta)
+			if !testutil.Equal(tc.counter, tc.want) {
+				t.Errorf("tokenCounter.Sub(%v): got %v, want %v", tc.delta, tc.counter, tc.want)
+			}
+			if !test.ErrorEqual(gotErr, tc.wantErr) {
+				t.Errorf("tokenCounter.Sub(%v) error: got %v, want %v", tc.delta, gotErr, tc.wantErr)
+			}
+		})
+	}
+}
+
+func TestTokenCounterToFlowControlRequest(t *testing.T) {
+	for _, tc := range []struct {
+		desc    string
+		counter tokenCounter
+		want    *pb.FlowControlRequest
+	}{
+		{
+			desc:    "Uninitialized counter",
+			counter: tokenCounter{},
+			want:    nil,
+		},
+		{
+			desc:    "Bytes non-zero",
+			counter: tokenCounter{Bytes: 1},
+			want:    &pb.FlowControlRequest{AllowedBytes: 1},
+		},
+		{
+			desc:    "Messages non-zero",
+			counter: tokenCounter{Messages: 1},
+			want:    &pb.FlowControlRequest{AllowedMessages: 1},
+		},
+		{
+			desc:    "Messages and bytes",
+			counter: tokenCounter{Bytes: 56, Messages: 32},
+			want:    &pb.FlowControlRequest{AllowedBytes: 56, AllowedMessages: 32},
+		},
+	} {
+		t.Run(tc.desc, func(t *testing.T) {
+			got := tc.counter.ToFlowControlRequest()
+			if !proto.Equal(got, tc.want) {
+				t.Errorf("tokenCounter(%v).ToFlowControlRequest(): got %v, want %v", tc.counter, got, tc.want)
+			}
+		})
+	}
+}
+
+func TestFlowControlBatcher(t *testing.T) {
+	var batcher flowControlBatcher
+
+	t.Run("Uninitialized", func(t *testing.T) {
+		if got, want := batcher.ShouldExpediteBatchRequest(), false; got != want {
+			t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want)
+		}
+		if got, want := batcher.ReleasePendingRequest(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) {
+			t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want)
+		}
+		if got, want := batcher.RequestForRestart(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) {
+			t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want)
+		}
+	})
+
+	t.Run("OnClientFlow-1", func(t *testing.T) {
+		deltaTokens := flowControlTokens{Bytes: 500, Messages: 10}
+		batcher.OnClientFlow(deltaTokens)
+
+		if got, want := batcher.ShouldExpediteBatchRequest(), true; got != want {
+			t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want)
+		}
+		if got, want := batcher.ReleasePendingRequest(), flowControlReq(deltaTokens); !proto.Equal(got, want) {
+			t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want)
+		}
+		if got, want := batcher.RequestForRestart(), flowControlReq(deltaTokens); !proto.Equal(got, want) {
+			t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want)
+		}
+	})
+
+	t.Run("OnClientFlow-2", func(t *testing.T) {
+		deltaTokens := flowControlTokens{Bytes: 100, Messages: 1}
+		batcher.OnClientFlow(deltaTokens)
+
+		if got, want := batcher.ShouldExpediteBatchRequest(), false; got != want {
+			t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want)
+		}
+		if got, want := batcher.ReleasePendingRequest(), flowControlReq(deltaTokens); !proto.Equal(got, want) {
+			t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want)
+		}
+		if got, want := batcher.RequestForRestart(), flowControlReq(flowControlTokens{Bytes: 600, Messages: 11}); !proto.Equal(got, want) {
+			t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want)
+		}
+	})
+
+	t.Run("OnMessages-Valid", func(t *testing.T) {
+		msgs := []*pb.SequencedMessage{seqMsgWithSizeBytes(10), seqMsgWithSizeBytes(20)}
+		if gotErr := batcher.OnMessages(msgs); gotErr != nil {
+			t.Errorf("flowControlBatcher.OnMessages(): got err (%v), want err <nil>", gotErr)
+		}
+
+		if got, want := batcher.ShouldExpediteBatchRequest(), false; got != want {
+			t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want)
+		}
+		if got, want := batcher.ReleasePendingRequest(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) {
+			t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want)
+		}
+		if got, want := batcher.RequestForRestart(), flowControlReq(flowControlTokens{Bytes: 570, Messages: 9}); !proto.Equal(got, want) {
+			t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want)
+		}
+	})
+
+	t.Run("OnMessages-Underflow", func(t *testing.T) {
+		msgs := []*pb.SequencedMessage{seqMsgWithSizeBytes(400), seqMsgWithSizeBytes(200)}
+		if gotErr, wantErr := batcher.OnMessages(msgs), errTokenCounterBytesNegative; !test.ErrorEqual(gotErr, wantErr) {
+			t.Errorf("flowControlBatcher.OnMessages(): got err (%v), want err (%v)", gotErr, wantErr)
+		}
+
+		if got, want := batcher.ShouldExpediteBatchRequest(), false; got != want {
+			t.Errorf("flowControlBatcher.ShouldExpediteBatchRequest(): got %v, want %v", got, want)
+		}
+		if got, want := batcher.ReleasePendingRequest(), (*pb.FlowControlRequest)(nil); !proto.Equal(got, want) {
+			t.Errorf("flowControlBatcher.ReleasePendingRequest(): got %v, want %v", got, want)
+		}
+		if got, want := batcher.RequestForRestart(), flowControlReq(flowControlTokens{Bytes: 570, Messages: 9}); !proto.Equal(got, want) {
+			t.Errorf("flowControlBatcher.RequestForRestart(): got %v, want %v", got, want)
+		}
+	})
+}
+
+func TestOffsetTrackerRequestForRestart(t *testing.T) {
+	for _, tc := range []struct {
+		desc    string
+		tracker subscriberOffsetTracker
+		want    *pb.SeekRequest
+	}{
+		{
+			desc:    "Uninitialized tracker",
+			tracker: subscriberOffsetTracker{},
+			want:    nil,
+		},
+		{
+			desc:    "Next offset positive",
+			tracker: subscriberOffsetTracker{minNextOffset: 1},
+			want: &pb.SeekRequest{
+				Target: &pb.SeekRequest_Cursor{
+					Cursor: &pb.Cursor{Offset: 1},
+				},
+			},
+		},
+	} {
+		t.Run(tc.desc, func(t *testing.T) {
+			got := tc.tracker.RequestForRestart()
+			if !proto.Equal(got, tc.want) {
+				t.Errorf("subscriberOffsetTracker(%v).RequestForRestart(): got %v, want %v", tc.tracker, got, tc.want)
+			}
+		})
+	}
+}
+
+func TestOffsetTrackerOnMessages(t *testing.T) {
+	for _, tc := range []struct {
+		desc    string
+		tracker subscriberOffsetTracker
+		msgs    []*pb.SequencedMessage
+		want    subscriberOffsetTracker
+		wantErr bool
+	}{
+		{
+			desc:    "Uninitialized tracker",
+			tracker: subscriberOffsetTracker{},
+			msgs:    []*pb.SequencedMessage{seqMsgWithOffset(0)},
+			want:    subscriberOffsetTracker{minNextOffset: 1},
+		},
+		{
+			desc:    "Consecutive message offsets",
+			tracker: subscriberOffsetTracker{minNextOffset: 5},
+			msgs:    []*pb.SequencedMessage{seqMsgWithOffset(5), seqMsgWithOffset(6), seqMsgWithOffset(7)},
+			want:    subscriberOffsetTracker{minNextOffset: 8},
+		},
+		{
+			desc:    "Skip message offsets",
+			tracker: subscriberOffsetTracker{minNextOffset: 5},
+			msgs:    []*pb.SequencedMessage{seqMsgWithOffset(10), seqMsgWithOffset(15)},
+			want:    subscriberOffsetTracker{minNextOffset: 16},
+		},
+		{
+			desc:    "Start offset before minNextOffset",
+			tracker: subscriberOffsetTracker{minNextOffset: 5},
+			msgs:    []*pb.SequencedMessage{seqMsgWithOffset(4)},
+			want:    subscriberOffsetTracker{minNextOffset: 5},
+			wantErr: true,
+		},
+		{
+			desc:    "Unordered messages",
+			tracker: subscriberOffsetTracker{minNextOffset: 5},
+			msgs:    []*pb.SequencedMessage{seqMsgWithOffset(5), seqMsgWithOffset(10), seqMsgWithOffset(9)},
+			want:    subscriberOffsetTracker{minNextOffset: 5},
+			wantErr: true,
+		},
+	} {
+		t.Run(tc.desc, func(t *testing.T) {
+			err := tc.tracker.OnMessages(tc.msgs)
+			if !testutil.Equal(tc.tracker, tc.want, cmp.AllowUnexported(subscriberOffsetTracker{})) {
+				t.Errorf("subscriberOffsetTracker().OnMessages(): got %v, want %v", tc.tracker, tc.want)
+			}
+			if gotErr := err != nil; gotErr != tc.wantErr {
+				t.Errorf("subscriberOffsetTracker().OnMessages() error: got (%v), want err=%v", err, tc.wantErr)
+			}
+		})
+	}
+}