feat(pubsublite): Committer implementation (#3198)
Commits desired offsets based on the ack prefix offset. Commits are batched.
diff --git a/pubsublite/internal/wire/acks.go b/pubsublite/internal/wire/acks.go
index 0a7c208..411a3ea 100644
--- a/pubsublite/internal/wire/acks.go
+++ b/pubsublite/internal/wire/acks.go
@@ -83,7 +83,7 @@
// ackTracker manages outstanding message acks, i.e. messages that have been
// delivered to the user, but not yet acked. It is used by the committer and
-// wireSubscriber, so requires its own mutex.
+// subscribeStream, so requires its own mutex.
type ackTracker struct {
// Guards access to fields below.
mu sync.Mutex
@@ -163,6 +163,13 @@
at.outstandingAcks.Init()
}
+// Empty when there are no outstanding acks.
+func (at *ackTracker) Empty() bool {
+ at.mu.Lock()
+ defer at.mu.Unlock()
+ return at.outstandingAcks.Len() == 0
+}
+
// commitCursorTracker tracks pending and last successful committed offsets.
// It is only accessed by the committer.
type commitCursorTracker struct {
diff --git a/pubsublite/internal/wire/committer.go b/pubsublite/internal/wire/committer.go
new file mode 100644
index 0000000..d540412
--- /dev/null
+++ b/pubsublite/internal/wire/committer.go
@@ -0,0 +1,219 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+
+package wire
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "reflect"
+ "time"
+
+ "google.golang.org/grpc"
+
+ vkit "cloud.google.com/go/pubsublite/apiv1"
+ pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
+)
+
+var (
+ errInvalidInitialCommitResponse = errors.New("pubsublite: first response from server was not an initial response for streaming commit")
+ errInvalidCommitResponse = errors.New("pubsublite: received invalid commit response from server")
+)
+
+// The frequency of batched cursor commits.
+const commitCursorPeriod = 50 * time.Millisecond
+
+// committer wraps a commit cursor stream for a subscription and partition.
+// A background task periodically effectively reads the latest desired cursor
+// offset from the `ackTracker` and sends a commit request to the stream if the
+// cursor needs to be updated. The `commitCursorTracker` is used to manage
+// in-flight commit requests.
+type committer struct {
+ // Immutable after creation.
+ cursorClient *vkit.CursorClient
+ initialReq *pb.StreamingCommitCursorRequest
+
+ // Fields below must be guarded with mutex.
+ stream *retryableStream
+ acks *ackTracker
+ cursorTracker *commitCursorTracker
+ pollCommits *periodicTask
+
+ abstractService
+}
+
+func newCommitter(ctx context.Context, cursor *vkit.CursorClient, settings ReceiveSettings,
+ subscription subscriptionPartition, acks *ackTracker, disableTasks bool) *committer {
+
+ c := &committer{
+ cursorClient: cursor,
+ initialReq: &pb.StreamingCommitCursorRequest{
+ Request: &pb.StreamingCommitCursorRequest_Initial{
+ Initial: &pb.InitialCommitCursorRequest{
+ Subscription: subscription.Path,
+ Partition: int64(subscription.Partition),
+ },
+ },
+ },
+ acks: acks,
+ cursorTracker: newCommitCursorTracker(acks),
+ }
+ c.stream = newRetryableStream(ctx, c, settings.Timeout, reflect.TypeOf(pb.StreamingCommitCursorResponse{}))
+
+ backgroundTask := c.commitOffsetToStream
+ if disableTasks {
+ backgroundTask = func() {}
+ }
+ c.pollCommits = newPeriodicTask(commitCursorPeriod, backgroundTask)
+ return c
+}
+
+// Start attempts to establish a streaming commit cursor connection.
+func (c *committer) Start() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if c.unsafeUpdateStatus(serviceStarting, nil) {
+ c.stream.Start()
+ c.pollCommits.Start()
+ }
+}
+
+// Stop initiates shutdown of the committer. The commit stream remains open to
+// process all outstanding acks and send the final commit offset.
+func (c *committer) Stop() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.unsafeInitiateShutdown(serviceTerminating, nil)
+}
+
+func (c *committer) newStream(ctx context.Context) (grpc.ClientStream, error) {
+ return c.cursorClient.StreamingCommitCursor(ctx)
+}
+
+func (c *committer) initialRequest() (req interface{}, needsResp bool) {
+ req = c.initialReq
+ needsResp = true
+ return
+}
+
+func (c *committer) validateInitialResponse(response interface{}) error {
+ commitResponse, _ := response.(*pb.StreamingCommitCursorResponse)
+ if commitResponse.GetInitial() == nil {
+ return errInvalidInitialCommitResponse
+ }
+ return nil
+}
+
+func (c *committer) onStreamStatusChange(status streamStatus) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ switch status {
+ case streamConnected:
+ c.unsafeUpdateStatus(serviceActive, nil)
+ // Once the stream connects, clear unconfirmed commits and immediately send
+ // the latest desired commit offset.
+ c.cursorTracker.ClearPending()
+ c.unsafeCommitOffsetToStream()
+ c.pollCommits.Start()
+
+ case streamReconnecting:
+ c.pollCommits.Stop()
+
+ case streamTerminated:
+ c.unsafeInitiateShutdown(serviceTerminated, c.stream.Error())
+ }
+}
+
+func (c *committer) onResponse(response interface{}) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ // If an inconsistency is detected in the server's responses, immediately
+ // terminate the committer, as correct processing of commits cannot be
+ // guaranteed.
+ processResponse := func() error {
+ commitResponse, _ := response.(*pb.StreamingCommitCursorResponse)
+ if commitResponse.GetCommit() == nil {
+ return errInvalidCommitResponse
+ }
+ numAcked := commitResponse.GetCommit().GetAcknowledgedCommits()
+ if numAcked <= 0 {
+ return fmt.Errorf("pubsublite: server acknowledged an invalid commit count: %d", numAcked)
+ }
+ if err := c.cursorTracker.ConfirmOffsets(numAcked); err != nil {
+ return err
+ }
+ c.unsafeCheckDone()
+ return nil
+ }
+ if err := processResponse(); err != nil {
+ c.unsafeInitiateShutdown(serviceTerminated, err)
+ }
+}
+
+// commitOffsetToStream is called by the periodic background task.
+func (c *committer) commitOffsetToStream() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.unsafeCommitOffsetToStream()
+}
+
+func (c *committer) unsafeCommitOffsetToStream() {
+ nextOffset := c.cursorTracker.NextOffset()
+ if nextOffset == nilCursorOffset {
+ return
+ }
+
+ req := &pb.StreamingCommitCursorRequest{
+ Request: &pb.StreamingCommitCursorRequest_Commit{
+ Commit: &pb.SequencedCommitCursorRequest{
+ Cursor: &pb.Cursor{Offset: nextOffset},
+ },
+ },
+ }
+ if c.stream.Send(req) {
+ c.cursorTracker.AddPending(nextOffset)
+ }
+}
+
+func (c *committer) unsafeInitiateShutdown(targetStatus serviceStatus, err error) {
+ if !c.unsafeUpdateStatus(targetStatus, err) {
+ return
+ }
+
+ // If it's a graceful shutdown, expedite sending final commits to the stream.
+ if targetStatus == serviceTerminating {
+ c.unsafeCommitOffsetToStream()
+ c.unsafeCheckDone()
+ return
+ }
+ // Otherwise immediately terminate the stream.
+ c.unsafeTerminate()
+}
+
+func (c *committer) unsafeCheckDone() {
+ // If the user stops the subscriber, they will no longer receive messages, but
+ // the commit stream remains open to process acks for outstanding messages.
+ if c.status == serviceTerminating && c.cursorTracker.Done() && c.acks.Empty() {
+ c.unsafeTerminate()
+ }
+}
+
+func (c *committer) unsafeTerminate() {
+ c.acks.Release()
+ c.pollCommits.Stop()
+ c.stream.Stop()
+}
diff --git a/pubsublite/internal/wire/committer_test.go b/pubsublite/internal/wire/committer_test.go
new file mode 100644
index 0000000..1c1ea2e
--- /dev/null
+++ b/pubsublite/internal/wire/committer_test.go
@@ -0,0 +1,252 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+
+package wire
+
+import (
+ "context"
+ "testing"
+
+ "cloud.google.com/go/pubsublite/internal/test"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+)
+
+// testCommitter wraps a committer for ease of testing.
+type testCommitter struct {
+ cmt *committer
+ serviceTestProxy
+}
+
+func newTestCommitter(t *testing.T, subscription subscriptionPartition, acks *ackTracker) *testCommitter {
+ ctx := context.Background()
+ cursorClient, err := newCursorClient(ctx, "ignored", testClientOpts...)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tc := &testCommitter{
+ cmt: newCommitter(ctx, cursorClient, testReceiveSettings(), subscription, acks, true),
+ }
+ tc.initAndStart(t, tc.cmt, "Committer")
+ return tc
+}
+
+// SendBatchCommit invokes the periodic background batch commit. Note that the
+// periodic task is disabled in tests.
+func (tc *testCommitter) SendBatchCommit() {
+ tc.cmt.commitOffsetToStream()
+}
+
+func TestCommitterStreamReconnect(t *testing.T) {
+ subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+ ack1 := newAckConsumer(33, 0, nil)
+ ack2 := newAckConsumer(55, 0, nil)
+ acks := newAckTracker()
+ acks.Push(ack1)
+ acks.Push(ack2)
+
+ verifiers := test.NewVerifiers(t)
+
+ // Simulate a transient error that results in a reconnect.
+ stream1 := test.NewRPCVerifier(t)
+ stream1.Push(initCommitReq(subscription), initCommitResp(), nil)
+ barrier := stream1.PushWithBarrier(commitReq(34), nil, status.Error(codes.Unavailable, "server unavailable"))
+ verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream1)
+
+ // When the stream reconnects, the latest commit offset should be sent to the
+ // server.
+ stream2 := test.NewRPCVerifier(t)
+ stream2.Push(initCommitReq(subscription), initCommitResp(), nil)
+ stream2.Push(commitReq(56), commitResp(1), nil)
+ verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream2)
+
+ mockServer.OnTestStart(verifiers)
+ defer mockServer.OnTestEnd()
+
+ cmt := newTestCommitter(t, subscription, acks)
+ if gotErr := cmt.StartError(); gotErr != nil {
+ t.Errorf("Start() got err: (%v)", gotErr)
+ }
+
+ // Send 2 commits.
+ ack1.Ack()
+ cmt.SendBatchCommit()
+ ack2.Ack()
+ cmt.SendBatchCommit()
+
+ // Then send the retryable error, which results in reconnect.
+ barrier.Release()
+ cmt.StopVerifyNoError()
+}
+
+func TestCommitterStopFlushesCommits(t *testing.T) {
+ subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+ ack1 := newAckConsumer(33, 0, nil)
+ ack2 := newAckConsumer(55, 0, nil)
+ acks := newAckTracker()
+ acks.Push(ack1)
+ acks.Push(ack2)
+
+ verifiers := test.NewVerifiers(t)
+ stream := test.NewRPCVerifier(t)
+ stream.Push(initCommitReq(subscription), initCommitResp(), nil)
+ stream.Push(commitReq(34), commitResp(1), nil)
+ stream.Push(commitReq(56), commitResp(1), nil)
+ verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
+
+ mockServer.OnTestStart(verifiers)
+ defer mockServer.OnTestEnd()
+
+ cmt := newTestCommitter(t, subscription, acks)
+ if gotErr := cmt.StartError(); gotErr != nil {
+ t.Errorf("Start() got err: (%v)", gotErr)
+ }
+
+ ack1.Ack()
+ cmt.Stop() // Stop should flush the first offset
+ ack2.Ack() // Acks after Stop() are still processed
+ cmt.SendBatchCommit()
+ // Committer terminates when all acks are processed.
+ if gotErr := cmt.FinalError(); gotErr != nil {
+ t.Errorf("Final err: (%v), want: <nil>", gotErr)
+ }
+}
+
+func TestCommitterPermanentStreamError(t *testing.T) {
+ subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+ acks := newAckTracker()
+ wantErr := status.Error(codes.FailedPrecondition, "failed")
+
+ verifiers := test.NewVerifiers(t)
+ stream := test.NewRPCVerifier(t)
+ stream.Push(initCommitReq(subscription), nil, wantErr)
+ verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
+
+ mockServer.OnTestStart(verifiers)
+ defer mockServer.OnTestEnd()
+
+ cmt := newTestCommitter(t, subscription, acks)
+ if gotErr := cmt.StartError(); !test.ErrorEqual(gotErr, wantErr) {
+ t.Errorf("Start() got err: (%v), want: (%v)", gotErr, wantErr)
+ }
+}
+
+func TestCommitterInvalidInitialResponse(t *testing.T) {
+ subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+ acks := newAckTracker()
+
+ verifiers := test.NewVerifiers(t)
+ stream := test.NewRPCVerifier(t)
+ stream.Push(initCommitReq(subscription), commitResp(1234), nil) // Invalid initial response
+ verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
+
+ mockServer.OnTestStart(verifiers)
+ defer mockServer.OnTestEnd()
+
+ cmt := newTestCommitter(t, subscription, acks)
+
+ wantErr := errInvalidInitialCommitResponse
+ if gotErr := cmt.StartError(); !test.ErrorEqual(gotErr, wantErr) {
+ t.Errorf("Start() got err: (%v), want: (%v)", gotErr, wantErr)
+ }
+ if gotErr := cmt.FinalError(); !test.ErrorEqual(gotErr, wantErr) {
+ t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr)
+ }
+}
+
+func TestCommitterInvalidCommitResponse(t *testing.T) {
+ subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+ ack := newAckConsumer(33, 0, nil)
+ acks := newAckTracker()
+ acks.Push(ack)
+
+ verifiers := test.NewVerifiers(t)
+ stream := test.NewRPCVerifier(t)
+ stream.Push(initCommitReq(subscription), initCommitResp(), nil)
+ stream.Push(commitReq(34), initCommitResp(), nil) // Invalid commit response
+ verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
+
+ mockServer.OnTestStart(verifiers)
+ defer mockServer.OnTestEnd()
+
+ cmt := newTestCommitter(t, subscription, acks)
+ if gotErr := cmt.StartError(); gotErr != nil {
+ t.Errorf("Start() got err: (%v)", gotErr)
+ }
+
+ ack.Ack()
+ cmt.SendBatchCommit()
+
+ if gotErr, wantErr := cmt.FinalError(), errInvalidCommitResponse; !test.ErrorEqual(gotErr, wantErr) {
+ t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr)
+ }
+}
+
+func TestCommitterExcessConfirmedOffsets(t *testing.T) {
+ subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+ ack := newAckConsumer(33, 0, nil)
+ acks := newAckTracker()
+ acks.Push(ack)
+
+ verifiers := test.NewVerifiers(t)
+ stream := test.NewRPCVerifier(t)
+ stream.Push(initCommitReq(subscription), initCommitResp(), nil)
+ stream.Push(commitReq(34), commitResp(2), nil) // More confirmed offsets than committed
+ verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
+
+ mockServer.OnTestStart(verifiers)
+ defer mockServer.OnTestEnd()
+
+ cmt := newTestCommitter(t, subscription, acks)
+ if gotErr := cmt.StartError(); gotErr != nil {
+ t.Errorf("Start() got err: (%v)", gotErr)
+ }
+
+ ack.Ack()
+ cmt.SendBatchCommit()
+
+ wantMsg := "server acknowledged 2 cursor commits"
+ if gotErr := cmt.FinalError(); !test.ErrorHasMsg(gotErr, wantMsg) {
+ t.Errorf("Final err: (%v), want msg: (%v)", gotErr, wantMsg)
+ }
+}
+
+func TestCommitterZeroConfirmedOffsets(t *testing.T) {
+ subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+ ack := newAckConsumer(33, 0, nil)
+ acks := newAckTracker()
+ acks.Push(ack)
+
+ verifiers := test.NewVerifiers(t)
+ stream := test.NewRPCVerifier(t)
+ stream.Push(initCommitReq(subscription), initCommitResp(), nil)
+ stream.Push(commitReq(34), commitResp(0), nil) // Zero confirmed offsets (invalid)
+ verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
+
+ mockServer.OnTestStart(verifiers)
+ defer mockServer.OnTestEnd()
+
+ cmt := newTestCommitter(t, subscription, acks)
+ if gotErr := cmt.StartError(); gotErr != nil {
+ t.Errorf("Start() got err: (%v)", gotErr)
+ }
+
+ ack.Ack()
+ cmt.SendBatchCommit()
+
+ wantMsg := "server acknowledged an invalid commit count"
+ if gotErr := cmt.FinalError(); !test.ErrorHasMsg(gotErr, wantMsg) {
+ t.Errorf("Final err: (%v), want msg: (%v)", gotErr, wantMsg)
+ }
+}
diff --git a/pubsublite/internal/wire/service_util_test.go b/pubsublite/internal/wire/service_util_test.go
new file mode 100644
index 0000000..e55d160
--- /dev/null
+++ b/pubsublite/internal/wire/service_util_test.go
@@ -0,0 +1,91 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+
+package wire
+
+import (
+ "fmt"
+ "testing"
+ "time"
+)
+
+func testReceiveSettings() ReceiveSettings {
+ settings := DefaultReceiveSettings
+ settings.Timeout = 5 * time.Second
+ return settings
+}
+
+const serviceTestWaitTimeout = 30 * time.Second
+
+// serviceTestProxy wraps a `service` and provides some convenience methods for
+// testing.
+type serviceTestProxy struct {
+ t *testing.T
+ service service
+ name string
+ started chan struct{}
+ terminated chan struct{}
+}
+
+func (sp *serviceTestProxy) initAndStart(t *testing.T, s service, name string) {
+ sp.t = t
+ sp.service = s
+ sp.name = name
+ sp.started = make(chan struct{})
+ sp.terminated = make(chan struct{})
+ s.AddStatusChangeReceiver(nil, sp.onStatusChange)
+ s.Start()
+}
+
+func (sp *serviceTestProxy) onStatusChange(_ serviceHandle, status serviceStatus, _ error) {
+ if status == serviceActive {
+ close(sp.started)
+ }
+ if status == serviceTerminated {
+ close(sp.terminated)
+ }
+}
+
+func (sp *serviceTestProxy) Start() { sp.service.Start() }
+func (sp *serviceTestProxy) Stop() { sp.service.Stop() }
+
+// StartError waits for the service to start and returns the error.
+func (sp *serviceTestProxy) StartError() error {
+ select {
+ case <-time.After(serviceTestWaitTimeout):
+ return fmt.Errorf("%s did not start within %v", sp.name, serviceTestWaitTimeout)
+ case <-sp.terminated:
+ return sp.service.Error()
+ case <-sp.started:
+ return sp.service.Error()
+ }
+}
+
+// FinalError waits for the service to terminate and returns the error.
+func (sp *serviceTestProxy) FinalError() error {
+ select {
+ case <-time.After(serviceTestWaitTimeout):
+ return fmt.Errorf("%s did not terminate within %v", sp.name, serviceTestWaitTimeout)
+ case <-sp.terminated:
+ return sp.service.Error()
+ }
+}
+
+// StopVerifyNoError stops the service, waits for it to terminate and verifies
+// that there is no error.
+func (sp *serviceTestProxy) StopVerifyNoError() {
+ sp.service.Stop()
+ if gotErr := sp.FinalError(); gotErr != nil {
+ sp.t.Errorf("%s final err: (%v), want: <nil>", sp.name, gotErr)
+ }
+}