feat(pubsublite): Committer implementation (#3198)

Commits desired offsets based on the ack prefix offset. Commits are batched.
diff --git a/pubsublite/internal/wire/acks.go b/pubsublite/internal/wire/acks.go
index 0a7c208..411a3ea 100644
--- a/pubsublite/internal/wire/acks.go
+++ b/pubsublite/internal/wire/acks.go
@@ -83,7 +83,7 @@
 
 // ackTracker manages outstanding message acks, i.e. messages that have been
 // delivered to the user, but not yet acked. It is used by the committer and
-// wireSubscriber, so requires its own mutex.
+// subscribeStream, so requires its own mutex.
 type ackTracker struct {
 	// Guards access to fields below.
 	mu sync.Mutex
@@ -163,6 +163,13 @@
 	at.outstandingAcks.Init()
 }
 
+// Empty returns true when there are no outstanding acks; it holds at.mu while reading the ack list length.
+func (at *ackTracker) Empty() bool {
+	at.mu.Lock()
+	defer at.mu.Unlock()
+	return at.outstandingAcks.Len() == 0
+}
+
 // commitCursorTracker tracks pending and last successful committed offsets.
 // It is only accessed by the committer.
 type commitCursorTracker struct {
diff --git a/pubsublite/internal/wire/committer.go b/pubsublite/internal/wire/committer.go
new file mode 100644
index 0000000..d540412
--- /dev/null
+++ b/pubsublite/internal/wire/committer.go
@@ -0,0 +1,219 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and limitations under the License.
+
+package wire
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"reflect"
+	"time"
+
+	"google.golang.org/grpc"
+
+	vkit "cloud.google.com/go/pubsublite/apiv1"
+	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
+)
+
+var ( // Sentinel errors reported when the server sends an unexpected kind of response on the commit stream.
+	errInvalidInitialCommitResponse = errors.New("pubsublite: first response from server was not an initial response for streaming commit")
+	errInvalidCommitResponse        = errors.New("pubsublite: received invalid commit response from server")
+)
+
+// commitCursorPeriod is the period given to the periodic task that sends batched cursor commits.
+const commitCursorPeriod = 50 * time.Millisecond
+
+// committer wraps a commit cursor stream for a subscription and partition.
+// A periodic background task reads the latest desired cursor offset (via the
+// `commitCursorTracker`, which is constructed from the `ackTracker`) and sends
+// a commit request to the stream if the cursor needs to be updated. The
+// `commitCursorTracker` is also used to manage in-flight commit requests.
+type committer struct {
+	// Immutable after creation.
+	cursorClient *vkit.CursorClient
+	initialReq   *pb.StreamingCommitCursorRequest
+
+	// Fields below must be guarded with the mutex `mu` (promoted from the embedded abstractService).
+	stream        *retryableStream
+	acks          *ackTracker
+	cursorTracker *commitCursorTracker
+	pollCommits   *periodicTask
+
+	abstractService
+}
+
+// newCommitter builds a committer for the given subscription partition; the returned committer is not yet
+// started. When disableTasks is true, the periodic commit is replaced by a no-op (tests in this package do so).
+func newCommitter(ctx context.Context, cursor *vkit.CursorClient, settings ReceiveSettings,
+	subscription subscriptionPartition, acks *ackTracker, disableTasks bool) *committer {
+	initial := &pb.InitialCommitCursorRequest{
+		Subscription: subscription.Path,
+		Partition:    int64(subscription.Partition),
+	}
+	c := &committer{
+		cursorClient: cursor,
+		initialReq: &pb.StreamingCommitCursorRequest{
+			Request: &pb.StreamingCommitCursorRequest_Initial{Initial: initial},
+		},
+		acks:          acks,
+		cursorTracker: newCommitCursorTracker(acks),
+	}
+	c.stream = newRetryableStream(ctx, c, settings.Timeout, reflect.TypeOf(pb.StreamingCommitCursorResponse{}))
+
+	pollFn := c.commitOffsetToStream
+	if disableTasks {
+		pollFn = func() {}
+	}
+	c.pollCommits = newPeriodicTask(commitCursorPeriod, pollFn)
+	return c
+}
+
+// Start attempts to establish a streaming commit cursor connection and starts the periodic commit task.
+func (c *committer) Start() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if updated := c.unsafeUpdateStatus(serviceStarting, nil); !updated {
+		return // The status was not updated to serviceStarting, so there is nothing to start.
+	}
+	c.stream.Start()
+	c.pollCommits.Start()
+}
+
+// Stop initiates a graceful shutdown of the committer (target status serviceTerminating). The commit stream
+// remains open to process all outstanding acks and send the final commit offset.
+func (c *committer) Stop() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.unsafeInitiateShutdown(serviceTerminating, nil)
+}
+
+// newStream opens a new streaming commit cursor RPC on the cursor client.
+func (c *committer) newStream(ctx context.Context) (grpc.ClientStream, error) {
+	return c.cursorClient.StreamingCommitCursor(ctx)
+}
+
+// initialRequest returns the stream's initial request and reports (needsResp) that a response to it is expected.
+func (c *committer) initialRequest() (req interface{}, needsResp bool) {
+	return c.initialReq, true
+}
+
+// validateInitialResponse returns errInvalidInitialCommitResponse unless response is an initial commit response.
+func (c *committer) validateInitialResponse(response interface{}) error {
+	if resp, _ := response.(*pb.StreamingCommitCursorResponse); resp.GetInitial() != nil {
+		return nil
+	}
+	return errInvalidInitialCommitResponse
+}
+
+// onStreamStatusChange reacts to a status change reported for the commit stream.
+func (c *committer) onStreamStatusChange(status streamStatus) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	switch status {
+	case streamReconnecting:
+		c.pollCommits.Stop() // Suspend periodic commits until the stream is connected again.
+	case streamTerminated:
+		// Shut down immediately, reporting the stream's error (if any).
+		c.unsafeInitiateShutdown(serviceTerminated, c.stream.Error())
+	case streamConnected:
+		c.unsafeUpdateStatus(serviceActive, nil)
+		// Once the stream connects, unconfirmed commits are discarded and the latest
+		// desired commit offset is sent right away; periodic commits then (re)start.
+		c.cursorTracker.ClearPending()
+		c.unsafeCommitOffsetToStream()
+		c.pollCommits.Start()
+	}
+}
+
+// onResponse processes a response received on the commit stream. Only commit responses that
+// acknowledge a positive number of commits are accepted here.
+func (c *committer) onResponse(response interface{}) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// If an inconsistency is detected in the server's responses, the committer is
+	// terminated immediately, as correct processing of commits cannot be guaranteed.
+	fail := func(err error) { c.unsafeInitiateShutdown(serviceTerminated, err) }
+	commitResponse, _ := response.(*pb.StreamingCommitCursorResponse)
+	commit := commitResponse.GetCommit()
+	if commit == nil {
+		fail(errInvalidCommitResponse)
+		return
+	}
+	numAcked := commit.GetAcknowledgedCommits()
+	if numAcked <= 0 {
+		fail(fmt.Errorf("pubsublite: server acknowledged an invalid commit count: %d", numAcked))
+		return
+	}
+	if err := c.cursorTracker.ConfirmOffsets(numAcked); err != nil {
+		fail(err)
+		return
+	}
+	c.unsafeCheckDone()
+}
+
+// commitOffsetToStream is called by the periodic background task; it takes c.mu and delegates to unsafeCommitOffsetToStream.
+func (c *committer) commitOffsetToStream() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.unsafeCommitOffsetToStream()
+}
+
+// unsafeCommitOffsetToStream sends the next desired commit offset, if there is one, to the stream and
+// records it as pending when the send is accepted. All callers visible in this file hold c.mu.
+func (c *committer) unsafeCommitOffsetToStream() {
+	offset := c.cursorTracker.NextOffset()
+	if offset == nilCursorOffset {
+		return // The tracker has no offset to commit.
+	}
+
+	commit := &pb.SequencedCommitCursorRequest{Cursor: &pb.Cursor{Offset: offset}}
+	req := &pb.StreamingCommitCursorRequest{
+		Request: &pb.StreamingCommitCursorRequest_Commit{Commit: commit},
+	}
+	if sent := c.stream.Send(req); !sent {
+		return
+	}
+	c.cursorTracker.AddPending(offset)
+}
+
+// unsafeInitiateShutdown moves the committer to targetStatus; it does nothing if the status is not updated.
+func (c *committer) unsafeInitiateShutdown(targetStatus serviceStatus, err error) {
+	if updated := c.unsafeUpdateStatus(targetStatus, err); !updated {
+		return
+	}
+	if targetStatus != serviceTerminating {
+		// Not a graceful shutdown: immediately terminate the stream.
+		c.unsafeTerminate()
+		return
+	}
+	// Graceful shutdown: expedite sending final commits to the stream.
+	c.unsafeCommitOffsetToStream()
+	c.unsafeCheckDone()
+}
+
+// unsafeCheckDone terminates the committer during a graceful shutdown (serviceTerminating): after the user stops
+// the subscriber, the commit stream stays open until cursorTracker.Done() holds and no acks remain outstanding.
+func (c *committer) unsafeCheckDone() {
+	if c.status == serviceTerminating && c.cursorTracker.Done() && c.acks.Empty() {
+		c.unsafeTerminate()
+	}
+}
+
+// unsafeTerminate calls Release on the ack tracker, then stops the periodic commit task and the stream.
+func (c *committer) unsafeTerminate() {
+	c.acks.Release()
+	c.pollCommits.Stop()
+	c.stream.Stop()
+}
diff --git a/pubsublite/internal/wire/committer_test.go b/pubsublite/internal/wire/committer_test.go
new file mode 100644
index 0000000..1c1ea2e
--- /dev/null
+++ b/pubsublite/internal/wire/committer_test.go
@@ -0,0 +1,252 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and limitations under the License.
+
+package wire
+
+import (
+	"context"
+	"testing"
+
+	"cloud.google.com/go/pubsublite/internal/test"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+)
+
+// testCommitter wraps a committer for ease of testing; the embedded serviceTestProxy adds start/stop wait helpers.
+type testCommitter struct {
+	cmt *committer
+	serviceTestProxy
+}
+
+// newTestCommitter creates a committer with its periodic task disabled (disableTasks=true), wraps it in a
+// testCommitter and starts it. Failure to create the cursor client is fatal to the calling test.
+func newTestCommitter(t *testing.T, subscription subscriptionPartition, acks *ackTracker) *testCommitter {
+	t.Helper() // Attribute a fatal failure to the calling test's line rather than to this helper.
+	ctx := context.Background()
+	cursorClient, err := newCursorClient(ctx, "ignored", testClientOpts...)
+	if err != nil {
+		t.Fatalf("newCursorClient() got err: (%v)", err)
+	}
+	tc := &testCommitter{cmt: newCommitter(ctx, cursorClient, testReceiveSettings(), subscription, acks, true)}
+	tc.initAndStart(t, tc.cmt, "Committer")
+	return tc
+}
+
+// SendBatchCommit directly invokes the function the periodic background task would run (commitOffsetToStream).
+// The periodic task itself is disabled in tests, so commits are only sent when this is called.
+func (tc *testCommitter) SendBatchCommit() {
+	tc.cmt.commitOffsetToStream()
+}
+
+// TestCommitterStreamReconnect checks that after a retryable stream error the committer reconnects and sends the latest commit offset.
+func TestCommitterStreamReconnect(t *testing.T) {
+	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+	ack1 := newAckConsumer(33, 0, nil)
+	ack2 := newAckConsumer(55, 0, nil)
+	acks := newAckTracker()
+	acks.Push(ack1)
+	acks.Push(ack2)
+
+	verifiers := test.NewVerifiers(t)
+
+	// Simulate a transient error that results in a reconnect.
+	stream1 := test.NewRPCVerifier(t)
+	stream1.Push(initCommitReq(subscription), initCommitResp(), nil)
+	barrier := stream1.PushWithBarrier(commitReq(34), nil, status.Error(codes.Unavailable, "server unavailable"))
+	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream1)
+
+	// When the stream reconnects, the latest commit offset should be sent to the server.
+	stream2 := test.NewRPCVerifier(t)
+	stream2.Push(initCommitReq(subscription), initCommitResp(), nil)
+	stream2.Push(commitReq(56), commitResp(1), nil)
+	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream2)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	cmt := newTestCommitter(t, subscription, acks)
+	if gotErr := cmt.StartError(); gotErr != nil {
+		t.Errorf("Start() got err: (%v)", gotErr)
+	}
+
+	// Send 2 commits.
+	ack1.Ack()
+	cmt.SendBatchCommit()
+	ack2.Ack()
+	cmt.SendBatchCommit()
+
+	// Then send the retryable error, which results in reconnect.
+	barrier.Release()
+	cmt.StopVerifyNoError()
+}
+
+// TestCommitterStopFlushesCommits checks that Stop() flushes the current commit and that later acks are still committed.
+func TestCommitterStopFlushesCommits(t *testing.T) {
+	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+	ack1 := newAckConsumer(33, 0, nil)
+	ack2 := newAckConsumer(55, 0, nil)
+	acks := newAckTracker()
+	acks.Push(ack1)
+	acks.Push(ack2)
+
+	verifiers := test.NewVerifiers(t)
+	stream := test.NewRPCVerifier(t)
+	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
+	stream.Push(commitReq(34), commitResp(1), nil) // Expected when Stop() flushes after ack1.
+	stream.Push(commitReq(56), commitResp(1), nil) // Expected from SendBatchCommit() after ack2.
+	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+	cmt := newTestCommitter(t, subscription, acks)
+	if gotErr := cmt.StartError(); gotErr != nil {
+		t.Errorf("Start() got err: (%v)", gotErr)
+	}
+
+	ack1.Ack()
+	cmt.Stop() // Stop should flush the first offset
+	ack2.Ack() // Acks after Stop() are still processed
+	cmt.SendBatchCommit()
+	// Committer terminates when all acks are processed.
+	if gotErr := cmt.FinalError(); gotErr != nil {
+		t.Errorf("Final err: (%v), want: <nil>", gotErr)
+	}
+}
+
+// TestCommitterPermanentStreamError checks that an error returned for the initial request is reported by StartError.
+func TestCommitterPermanentStreamError(t *testing.T) {
+	sub := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+	acks := newAckTracker()
+	permanentErr := status.Error(codes.FailedPrecondition, "failed")
+
+	verifiers := test.NewVerifiers(t)
+	commitStream := test.NewRPCVerifier(t)
+	commitStream.Push(initCommitReq(sub), nil, permanentErr)
+	verifiers.AddCommitStream(sub.Path, sub.Partition, commitStream)
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	tc := newTestCommitter(t, sub, acks)
+	if got := tc.StartError(); !test.ErrorEqual(got, permanentErr) {
+		t.Errorf("Start() got err: (%v), want: (%v)", got, permanentErr)
+	}
+}
+
+// TestCommitterInvalidInitialResponse checks that a non-initial first response fails both start and final status.
+func TestCommitterInvalidInitialResponse(t *testing.T) {
+	sub := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+	acks := newAckTracker()
+
+	verifiers := test.NewVerifiers(t)
+	commitStream := test.NewRPCVerifier(t)
+	// Invalid initial response: a commit response is sent in reply to the initial request.
+	commitStream.Push(initCommitReq(sub), commitResp(1234), nil)
+	verifiers.AddCommitStream(sub.Path, sub.Partition, commitStream)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	tc := newTestCommitter(t, sub, acks)
+	if got, want := tc.StartError(), errInvalidInitialCommitResponse; !test.ErrorEqual(got, want) {
+		t.Errorf("Start() got err: (%v), want: (%v)", got, want)
+	}
+	if got, want := tc.FinalError(), errInvalidInitialCommitResponse; !test.ErrorEqual(got, want) {
+		t.Errorf("Final err: (%v), want: (%v)", got, want)
+	}
+}
+
+// TestCommitterInvalidCommitResponse checks that a non-commit reply to a commit request terminates the committer.
+func TestCommitterInvalidCommitResponse(t *testing.T) {
+	sub := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+	ackConsumer := newAckConsumer(33, 0, nil)
+	acks := newAckTracker()
+	acks.Push(ackConsumer)
+
+	verifiers := test.NewVerifiers(t)
+	commitStream := test.NewRPCVerifier(t)
+	commitStream.Push(initCommitReq(sub), initCommitResp(), nil)
+	commitStream.Push(commitReq(34), initCommitResp(), nil) // Invalid commit response
+	verifiers.AddCommitStream(sub.Path, sub.Partition, commitStream)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	tc := newTestCommitter(t, sub, acks)
+	if err := tc.StartError(); err != nil {
+		t.Errorf("Start() got err: (%v)", err)
+	}
+
+	ackConsumer.Ack()
+	tc.SendBatchCommit()
+	if got, want := tc.FinalError(), errInvalidCommitResponse; !test.ErrorEqual(got, want) {
+		t.Errorf("Final err: (%v), want: (%v)", got, want)
+	}
+}
+
+// TestCommitterExcessConfirmedOffsets checks that acknowledging more commits than were sent terminates the committer.
+func TestCommitterExcessConfirmedOffsets(t *testing.T) {
+	sub := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+	ackConsumer := newAckConsumer(33, 0, nil)
+	acks := newAckTracker()
+	acks.Push(ackConsumer)
+
+	verifiers := test.NewVerifiers(t)
+	commitStream := test.NewRPCVerifier(t)
+	commitStream.Push(initCommitReq(sub), initCommitResp(), nil)
+	commitStream.Push(commitReq(34), commitResp(2), nil) // More confirmed offsets than committed
+	verifiers.AddCommitStream(sub.Path, sub.Partition, commitStream)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	tc := newTestCommitter(t, sub, acks)
+	if err := tc.StartError(); err != nil {
+		t.Errorf("Start() got err: (%v)", err)
+	}
+
+	ackConsumer.Ack()
+	tc.SendBatchCommit()
+
+	if got, wantMsg := tc.FinalError(), "server acknowledged 2 cursor commits"; !test.ErrorHasMsg(got, wantMsg) {
+		t.Errorf("Final err: (%v), want msg: (%v)", got, wantMsg)
+	}
+}
+
+// TestCommitterZeroConfirmedOffsets checks that a commit response acknowledging zero commits terminates the committer.
+func TestCommitterZeroConfirmedOffsets(t *testing.T) {
+	sub := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
+	ackConsumer := newAckConsumer(33, 0, nil)
+	acks := newAckTracker()
+	acks.Push(ackConsumer)
+
+	verifiers := test.NewVerifiers(t)
+	commitStream := test.NewRPCVerifier(t)
+	commitStream.Push(initCommitReq(sub), initCommitResp(), nil)
+	commitStream.Push(commitReq(34), commitResp(0), nil) // Zero confirmed offsets (invalid)
+	verifiers.AddCommitStream(sub.Path, sub.Partition, commitStream)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	tc := newTestCommitter(t, sub, acks)
+	if err := tc.StartError(); err != nil {
+		t.Errorf("Start() got err: (%v)", err)
+	}
+
+	ackConsumer.Ack()
+	tc.SendBatchCommit()
+
+	if got, wantMsg := tc.FinalError(), "server acknowledged an invalid commit count"; !test.ErrorHasMsg(got, wantMsg) {
+		t.Errorf("Final err: (%v), want msg: (%v)", got, wantMsg)
+	}
+}
diff --git a/pubsublite/internal/wire/service_util_test.go b/pubsublite/internal/wire/service_util_test.go
new file mode 100644
index 0000000..e55d160
--- /dev/null
+++ b/pubsublite/internal/wire/service_util_test.go
@@ -0,0 +1,91 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and limitations under the License.
+
+package wire
+
+import (
+	"fmt"
+	"testing"
+	"time"
+)
+
+func testReceiveSettings() ReceiveSettings {
+	settings := DefaultReceiveSettings // Start from the package defaults.
+	settings.Timeout = 5 * time.Second // Only the timeout is overridden for tests.
+	return settings
+}
+
+const serviceTestWaitTimeout = 30 * time.Second // Upper bound on how long StartError and FinalError wait.
+
+// serviceTestProxy wraps a `service` and provides some convenience methods for
+// testing, such as waiting for the service to become active or to terminate.
+type serviceTestProxy struct {
+	t          *testing.T
+	service    service
+	name       string
+	started    chan struct{} // Closed by onStatusChange when serviceActive is reported.
+	terminated chan struct{} // Closed by onStatusChange when serviceTerminated is reported.
+}
+
+// initAndStart initializes the proxy for service s, registers onStatusChange as a status receiver and starts s.
+func (sp *serviceTestProxy) initAndStart(t *testing.T, s service, name string) {
+	sp.t, sp.service, sp.name = t, s, name
+	sp.started, sp.terminated = make(chan struct{}), make(chan struct{})
+	// The receiver is registered before the service is started.
+	s.AddStatusChangeReceiver(nil, sp.onStatusChange)
+	s.Start()
+}
+
+// onStatusChange closes the started or terminated channel when the matching status is reported.
+func (sp *serviceTestProxy) onStatusChange(_ serviceHandle, status serviceStatus, _ error) {
+	switch status {
+	case serviceActive:
+		close(sp.started)
+	case serviceTerminated:
+		close(sp.terminated)
+	}
+}
+
+func (sp *serviceTestProxy) Start() { sp.service.Start() } // Start forwards to the wrapped service.
+func (sp *serviceTestProxy) Stop()  { sp.service.Stop() }  // Stop forwards to the wrapped service.
+
+// StartError waits (up to serviceTestWaitTimeout) for the service to start or terminate and returns its error.
+func (sp *serviceTestProxy) StartError() error {
+	timeout := time.After(serviceTestWaitTimeout)
+	select {
+	case <-sp.started:
+	case <-sp.terminated:
+	case <-timeout:
+		return fmt.Errorf("%s did not start within %v", sp.name, serviceTestWaitTimeout)
+	}
+	return sp.service.Error()
+}
+
+// FinalError waits (up to serviceTestWaitTimeout) for the service to terminate and returns its error.
+func (sp *serviceTestProxy) FinalError() error {
+	select {
+	case <-sp.terminated:
+		return sp.service.Error()
+	case <-time.After(serviceTestWaitTimeout):
+	}
+	return fmt.Errorf("%s did not terminate within %v", sp.name, serviceTestWaitTimeout)
+}
+
+// StopVerifyNoError stops the service, waits for it to terminate and reports a test error if it ended with one.
+func (sp *serviceTestProxy) StopVerifyNoError() {
+	sp.t.Helper() // Attribute a failure to the calling test's line rather than to this helper.
+	sp.service.Stop()
+	if gotErr := sp.FinalError(); gotErr != nil {
+		sp.t.Errorf("%s final err: (%v), want: <nil>", sp.name, gotErr)
+	}
+}