feat(pubsublite): Refactoring and unit tests for retryableStream (#3160)

Minor refactoring and unit tests for retryableStream. Eliminates test flakiness in the mock server by scoping stream verifiers to the current test (via test IDs) and replacing raw channel blocks with two-way barriers for server/client synchronization.
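
The old PushWithBlock API returned a bare channel that only blocked the server; the new PushWithBarrier returns a Barrier that synchronizes both sides before a response is sent. A minimal sketch of the resulting test flow (adapted from streams_test.go in this change; initialReq, initialResp, topic and partition stand in for test fixtures):

    // Register a publish stream verifier whose initial response is
    // held back until the test releases the barrier.
    verifiers := test.NewVerifiers(t)
    stream := test.NewRPCVerifier(t)
    barrier := stream.PushWithBarrier(initialReq, initialResp, nil)
    verifiers.AddPublishStream(topic, partition, stream)

    mockServer.OnTestStart(verifiers)
    defer mockServer.OnTestEnd()

    // ... exercise the client while the stream is still connecting ...

    // Release waits for the server to reach the barrier, then lets it
    // send initialResp to the client.
    barrier.Release()
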
diff --git a/pubsublite/internal/test/mock.go b/pubsublite/internal/test/mock.go
index d27a437..bf74683 100644
--- a/pubsublite/internal/test/mock.go
+++ b/pubsublite/internal/test/mock.go
@@ -15,12 +15,12 @@
 
 import (
 	"context"
-	"fmt"
 	"io"
 	"reflect"
 	"sync"
 
 	"cloud.google.com/go/internal/testutil"
+	"cloud.google.com/go/internal/uid"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
@@ -40,20 +40,11 @@
 // This is the interface that should be used by tests.
 type MockServer interface {
 	// OnTestStart must be called at the start of each test to clear any existing
-	// state and set the verifier for unary RPCs.
-	OnTestStart(globalVerifier *RPCVerifier)
+	// state and set the test verifiers.
+	OnTestStart(*Verifiers)
 	// OnTestEnd should be called at the end of each test to flush the verifiers
 	// (i.e. check whether any expected requests were not sent to the server).
 	OnTestEnd()
-	// AddPublishStream adds a verifier for a publish stream of a topic partition.
-	AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier)
-	// AddSubscribeStream adds a verifier for a subscribe stream of a partition.
-	AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier)
-	// AddCommitStream adds a verifier for a commit stream of a partition.
-	AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier)
-	// AddAssignmentStream adds a verifier for a partition assignment stream for a
-	// subscription.
-	AddAssignmentStream(subscription string, streamVerifier *RPCVerifier)
 }
 
 // NewServer creates a new mock Pub/Sub Lite server.
@@ -97,79 +88,63 @@
 
 	mu sync.Mutex
 
-	// Global list of verifiers for all unary RPCs. This should be set before the
-	// test begins.
-	globalVerifier *RPCVerifier
-
-	// Stream verifiers by key.
-	publishVerifiers    *keyedStreamVerifiers
-	subscribeVerifiers  *keyedStreamVerifiers
-	commitVerifiers     *keyedStreamVerifiers
-	assignmentVerifiers *keyedStreamVerifiers
-
-	nextStreamID  int
-	activeStreams map[int]*streamHolder
-	testActive    bool
-}
-
-func key(path string, partition int) string {
-	return fmt.Sprintf("%s:%d", path, partition)
+	testVerifiers *Verifiers
+	testIDs       *uid.Space
+	currentTestID string
 }
 
 func newMockLiteServer() *mockLiteServer {
 	return &mockLiteServer{
-		publishVerifiers:    newKeyedStreamVerifiers(),
-		subscribeVerifiers:  newKeyedStreamVerifiers(),
-		commitVerifiers:     newKeyedStreamVerifiers(),
-		assignmentVerifiers: newKeyedStreamVerifiers(),
-		activeStreams:       make(map[int]*streamHolder),
+		testIDs: uid.NewSpace("mockLiteServer", nil),
 	}
 }
 
-func (s *mockLiteServer) startStream(stream grpc.ServerStream, verifier *RPCVerifier) (id int) {
+func (s *mockLiteServer) popGlobalVerifiers(request interface{}) (interface{}, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	id = s.nextStreamID
-	s.nextStreamID++
-	s.activeStreams[id] = &streamHolder{stream: stream, verifier: verifier}
-	return
+	if s.testVerifiers == nil {
+		return nil, status.Errorf(codes.FailedPrecondition, "mockserver: previous test has ended")
+	}
+	return s.testVerifiers.GlobalVerifier.Pop(request)
 }
 
-func (s *mockLiteServer) endStream(id int) {
+func (s *mockLiteServer) popStreamVerifier(key string) (*RPCVerifier, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	delete(s.activeStreams, id)
+	if s.testVerifiers == nil {
+		return nil, status.Errorf(codes.FailedPrecondition, "mockserver: previous test has ended")
+	}
+	return s.testVerifiers.streamVerifiers.Pop(key)
 }
 
-func (s *mockLiteServer) popStreamVerifier(key string, keyedVerifiers *keyedStreamVerifiers) (*RPCVerifier, error) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-
-	return keyedVerifiers.Pop(key)
-}
-
-func (s *mockLiteServer) handleStream(stream grpc.ServerStream, req interface{}, requestType reflect.Type, key string, keyedVerifiers *keyedStreamVerifiers) (err error) {
-	verifier, err := s.popStreamVerifier(key, keyedVerifiers)
+func (s *mockLiteServer) handleStream(stream grpc.ServerStream, req interface{}, requestType reflect.Type, key string) (err error) {
+	testID := s.currentTest()
+	if testID == "" {
+		return status.Errorf(codes.FailedPrecondition, "mockserver: previous test has ended")
+	}
+	verifier, err := s.popStreamVerifier(key)
 	if err != nil {
 		return err
 	}
 
-	id := s.startStream(stream, verifier)
-
 	// Verify initial request.
 	retResponse, retErr := verifier.Pop(req)
 	var ok bool
 
 	for {
+		// See comments for RPCVerifier.Push for valid stream request/response
+		// combinations.
 		if retErr != nil {
 			err = retErr
 			break
 		}
-		if err = stream.SendMsg(retResponse); err != nil {
-			err = status.Errorf(codes.FailedPrecondition, "mockserver: stream send error: %v", err)
-			break
+		if retResponse != nil {
+			if err = stream.SendMsg(retResponse); err != nil {
+				err = status.Errorf(codes.FailedPrecondition, "mockserver: stream send error: %v", err)
+				break
+			}
 		}
 
 		// Check whether the next response isn't blocked on a request.
@@ -185,70 +160,47 @@
 			err = status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error: %v", err)
 			break
 		}
+		if testID != s.currentTest() {
+			err = status.Errorf(codes.FailedPrecondition, "mockserver: previous test has ended")
+			break
+		}
 		retResponse, retErr = verifier.Pop(req)
 	}
 
 	// Check whether the stream ended prematurely.
-	verifier.Flush()
-	s.endStream(id)
+	if testID == s.currentTest() {
+		verifier.Flush()
+	}
 	return
 }
 
 // MockServer implementation.
 
-func (s *mockLiteServer) OnTestStart(globalVerifier *RPCVerifier) {
+func (s *mockLiteServer) OnTestStart(verifiers *Verifiers) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	if s.testActive {
+	if s.currentTestID != "" {
 		panic("mockserver is already in use by another test")
 	}
-
-	s.testActive = true
-	s.globalVerifier = globalVerifier
-	s.publishVerifiers.Reset()
-	s.subscribeVerifiers.Reset()
-	s.commitVerifiers.Reset()
-	s.assignmentVerifiers.Reset()
-	s.activeStreams = make(map[int]*streamHolder)
+	s.currentTestID = s.testIDs.New()
+	s.testVerifiers = verifiers
 }
 
 func (s *mockLiteServer) OnTestEnd() {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	s.testActive = false
-	if s.globalVerifier != nil {
-		s.globalVerifier.Flush()
-	}
-
-	for _, as := range s.activeStreams {
-		as.verifier.Flush()
+	s.currentTestID = ""
+	if s.testVerifiers != nil {
+		s.testVerifiers.flush()
 	}
 }
 
-func (s *mockLiteServer) AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier) {
+func (s *mockLiteServer) currentTest() string {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	s.publishVerifiers.Push(key(topic, partition), streamVerifier)
-}
-
-func (s *mockLiteServer) AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	s.subscribeVerifiers.Push(key(subscription, partition), streamVerifier)
-}
-
-func (s *mockLiteServer) AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	s.commitVerifiers.Push(key(subscription, partition), streamVerifier)
-}
-
-func (s *mockLiteServer) AddAssignmentStream(subscription string, streamVerifier *RPCVerifier) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	s.assignmentVerifiers.Push(subscription, streamVerifier)
+	return s.currentTestID
 }
 
 // PublisherService implementation.
@@ -263,8 +215,8 @@
 	}
 
 	initReq := req.GetInitialRequest()
-	k := key(initReq.GetTopic(), int(initReq.GetPartition()))
-	return s.handleStream(stream, req, reflect.TypeOf(pb.PublishRequest{}), k, s.publishVerifiers)
+	k := keyPartition(publishStreamType, initReq.GetTopic(), int(initReq.GetPartition()))
+	return s.handleStream(stream, req, reflect.TypeOf(pb.PublishRequest{}), k)
 }
 
 // SubscriberService implementation.
@@ -279,8 +231,8 @@
 	}
 
 	initReq := req.GetInitial()
-	k := key(initReq.GetSubscription(), int(initReq.GetPartition()))
-	return s.handleStream(stream, req, reflect.TypeOf(pb.SubscribeRequest{}), k, s.subscribeVerifiers)
+	k := keyPartition(subscribeStreamType, initReq.GetSubscription(), int(initReq.GetPartition()))
+	return s.handleStream(stream, req, reflect.TypeOf(pb.SubscribeRequest{}), k)
 }
 
 // CursorService implementation.
@@ -295,8 +247,8 @@
 	}
 
 	initReq := req.GetInitial()
-	k := key(initReq.GetSubscription(), int(initReq.GetPartition()))
-	return s.handleStream(stream, req, reflect.TypeOf(pb.StreamingCommitCursorRequest{}), k, s.commitVerifiers)
+	k := keyPartition(commitStreamType, initReq.GetSubscription(), int(initReq.GetPartition()))
+	return s.handleStream(stream, req, reflect.TypeOf(pb.StreamingCommitCursorRequest{}), k)
 }
 
 // PartitionAssignmentService implementation.
@@ -310,17 +262,14 @@
 		return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial partition assignment request: %v", req)
 	}
 
-	k := req.GetInitial().GetSubscription()
-	return s.handleStream(stream, req, reflect.TypeOf(pb.PartitionAssignmentRequest{}), k, s.assignmentVerifiers)
+	k := key(assignmentStreamType, req.GetInitial().GetSubscription())
+	return s.handleStream(stream, req, reflect.TypeOf(pb.PartitionAssignmentRequest{}), k)
 }
 
 // AdminService implementation.
 
 func (s *mockLiteServer) GetTopicPartitions(ctx context.Context, req *pb.GetTopicPartitionsRequest) (*pb.TopicPartitions, error) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-
-	retResponse, retErr := s.globalVerifier.Pop(req)
+	retResponse, retErr := s.popGlobalVerifiers(req)
 	if retErr != nil {
 		return nil, retErr
 	}
diff --git a/pubsublite/internal/test/verifier.go b/pubsublite/internal/test/verifier.go
index a1e2681..3f57e5e 100644
--- a/pubsublite/internal/test/verifier.go
+++ b/pubsublite/internal/test/verifier.go
@@ -15,6 +15,7 @@
 
 import (
 	"container/list"
+	"fmt"
 	"sync"
 	"testing"
 	"time"
@@ -30,28 +31,72 @@
 	blockWaitTimeout = 30 * time.Second
 )
 
-type rpcMetadata struct {
-	wantRequest   interface{}
-	retResponse   interface{}
-	retErr        error
-	blockResponse chan struct{}
+// Barrier is used to perform two-way synchronization between the server and
+// client (test) to ensure tests are deterministic.
+type Barrier struct {
+	// Used to block until the server is ready to send the response.
+	serverBlock chan struct{}
+	// Used to block until the client wants the server to send the response.
+	clientBlock chan struct{}
+	err         error
 }
 
-// wait until the `blockResponse` is released by the test, or a timeout occurs.
-// Returns immediately if there was no block.
-func (r *rpcMetadata) wait() error {
-	if r.blockResponse == nil {
-		return nil
+func newBarrier() *Barrier {
+	return &Barrier{
+		serverBlock: make(chan struct{}),
+		clientBlock: make(chan struct{}),
 	}
+}
+
+// Release should be called by the test to unblock the server's response.
+func (b *Barrier) Release() {
+	// Wait for the server to reach the barrier.
+	select {
+	case <-time.After(blockWaitTimeout):
+		// Note: avoid returning a retryable code to quickly terminate the test.
+		b.err = status.Errorf(codes.FailedPrecondition, "mockserver: server did not reach barrier within %v", blockWaitTimeout)
+	case <-b.serverBlock:
+	}
+
+	// Then close the client block.
+	close(b.clientBlock)
+}
+
+func (b *Barrier) serverWait() error {
+	if b.err != nil {
+		return b.err
+	}
+
+	// Close the server block to signal the server reaching the point where it is
+	// ready to send the response.
+	close(b.serverBlock)
+
+	// Wait for the test to release the client block.
 	select {
 	case <-time.After(blockWaitTimeout):
 		// Note: avoid returning a retryable code to quickly terminate the test.
 		return status.Errorf(codes.FailedPrecondition, "mockserver: test did not unblock response within %v", blockWaitTimeout)
-	case <-r.blockResponse:
+	case <-b.clientBlock:
 		return nil
 	}
 }
 
+type rpcMetadata struct {
+	wantRequest interface{}
+	retResponse interface{}
+	retErr      error
+	barrier     *Barrier
+}
+
+// wait until the barrier is released by the test, or a timeout occurs.
+// Returns immediately if there is no barrier.
+func (r *rpcMetadata) wait() error {
+	if r.barrier == nil {
+		return nil
+	}
+	return r.barrier.serverWait()
+}
+
 // RPCVerifier stores a queue of requests expected from the client, and the
 // corresponding response or error to return.
 type RPCVerifier struct {
@@ -71,6 +116,15 @@
 }
 
 // Push appends a new {request, response, error} tuple.
+//
+// Valid combinations for unary and streaming RPCs:
+// - {request, response, nil}
+// - {request, nil, error}
+//
+// Additional combinations for streams only:
+// - {nil, response, nil}: send a response without a request (e.g. messages).
+// - {nil, nil, error}: break the stream without a request.
+// - {request, nil, nil}: expect a request, but don't send any response.
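+//
+// For example, a stream that accepts an initial request and later breaks
+// with a retryable error could be primed as:
+//   verifier.Push(initialReq, initialResp, nil)
+//   verifier.Push(nil, nil, status.Error(codes.Unavailable, "server unavailable"))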
 func (v *RPCVerifier) Push(wantRequest interface{}, retResponse interface{}, retErr error) {
 	v.mu.Lock()
 	defer v.mu.Unlock()
@@ -82,21 +136,21 @@
 	})
 }
 
-// PushWithBlock is like Push, but returns a channel that the test should close
-// when it would like the response to be sent to the client. This is useful for
-// synchronizing with work that needs to be done on the client.
-func (v *RPCVerifier) PushWithBlock(wantRequest interface{}, retResponse interface{}, retErr error) chan struct{} {
+// PushWithBarrier is like Push, but returns a Barrier that the test should
+// release when it wants the response to be sent to the client. This is
+// useful for synchronizing with work that needs to be done on the client.
+func (v *RPCVerifier) PushWithBarrier(wantRequest interface{}, retResponse interface{}, retErr error) *Barrier {
 	v.mu.Lock()
 	defer v.mu.Unlock()
 
-	block := make(chan struct{})
+	barrier := newBarrier()
 	v.rpcs.PushBack(&rpcMetadata{
-		wantRequest:   wantRequest,
-		retResponse:   retResponse,
-		retErr:        retErr,
-		blockResponse: block,
+		wantRequest: wantRequest,
+		retResponse: retResponse,
+		retErr:      retErr,
+		barrier:     barrier,
 	})
-	return block
+	return barrier
 }
 
 // Pop validates the received request with the next {request, response, error}
@@ -158,7 +212,11 @@
 	for elem := v.rpcs.Front(); elem != nil; elem = elem.Next() {
 		v.numCalls++
 		rpc, _ := elem.Value.(*rpcMetadata)
-		v.t.Errorf("call(%d): did not receive expected request:\n%v", v.numCalls, rpc.wantRequest)
+		if rpc.wantRequest != nil {
+			v.t.Errorf("call(%d): did not receive expected request:\n%v", v.numCalls, rpc.wantRequest)
+		} else {
+			v.t.Errorf("call(%d): unsent response:\n%v, err = (%v)", v.numCalls, rpc.retResponse, rpc.retErr)
+		}
 	}
 	v.rpcs.Init()
 }
@@ -195,7 +253,15 @@
 	return v, nil
 }
 
-// keyedStreamVerifiers stores indexed streamVerifiers.
+func (sv *streamVerifiers) Flush() {
+	for elem := sv.verifiers.Front(); elem != nil; elem = elem.Next() {
+		v, _ := elem.Value.(*RPCVerifier)
+		v.Flush()
+	}
+}
+
+// keyedStreamVerifiers stores indexed streamVerifiers. Example key format:
+// "streamType:topic_path:partition".
 type keyedStreamVerifiers struct {
 	verifiers map[string]*streamVerifiers
 }
@@ -204,10 +270,6 @@
 	return &keyedStreamVerifiers{verifiers: make(map[string]*streamVerifiers)}
 }
 
-func (kv *keyedStreamVerifiers) Reset() {
-	kv.verifiers = make(map[string]*streamVerifiers)
-}
-
 func (kv *keyedStreamVerifiers) Push(key string, v *RPCVerifier) {
 	sv, ok := kv.verifiers[key]
 	if !ok {
@@ -224,3 +286,97 @@
 	}
 	return sv.Pop()
 }
+
+func (kv *keyedStreamVerifiers) Flush() {
+	for _, sv := range kv.verifiers {
+		sv.Flush()
+	}
+}
+
+// Verifiers contains RPCVerifiers for unary RPCs and streaming RPCs.
+type Verifiers struct {
+	t  *testing.T
+	mu sync.Mutex
+
+	// Global list of verifiers for all unary RPCs.
+	GlobalVerifier *RPCVerifier
+	// Stream verifiers by key.
+	streamVerifiers       *keyedStreamVerifiers
+	activeStreamVerifiers []*RPCVerifier
+}
+
+// NewVerifiers creates a new instance of Verifiers for a test.
+func NewVerifiers(t *testing.T) *Verifiers {
+	return &Verifiers{
+		t:               t,
+		GlobalVerifier:  NewRPCVerifier(t),
+		streamVerifiers: newKeyedStreamVerifiers(),
+	}
+}
+
+// streamType is used as a key prefix for keyedStreamVerifiers.
+type streamType string
+
+const (
+	publishStreamType    streamType = "publish"
+	subscribeStreamType  streamType = "subscribe"
+	commitStreamType     streamType = "commit"
+	assignmentStreamType streamType = "assignment"
+)
+
+func keyPartition(st streamType, path string, partition int) string {
+	return fmt.Sprintf("%s:%s:%d", st, path, partition)
+}
+
+func key(st streamType, path string) string {
+	return fmt.Sprintf("%s:%s", st, path)
+}
+
+// AddPublishStream adds verifiers for a publish stream.
+func (tv *Verifiers) AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier) {
+	tv.mu.Lock()
+	defer tv.mu.Unlock()
+	tv.streamVerifiers.Push(keyPartition(publishStreamType, topic, partition), streamVerifier)
+}
+
+// AddSubscribeStream adds verifiers for a subscribe stream.
+func (tv *Verifiers) AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier) {
+	tv.mu.Lock()
+	defer tv.mu.Unlock()
+	tv.streamVerifiers.Push(keyPartition(subscribeStreamType, subscription, partition), streamVerifier)
+}
+
+// AddCommitStream adds verifiers for a commit stream.
+func (tv *Verifiers) AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier) {
+	tv.mu.Lock()
+	defer tv.mu.Unlock()
+	tv.streamVerifiers.Push(keyPartition(commitStreamType, subscription, partition), streamVerifier)
+}
+
+// AddAssignmentStream adds verifiers for an assignment stream.
+func (tv *Verifiers) AddAssignmentStream(subscription string, streamVerifier *RPCVerifier) {
+	tv.mu.Lock()
+	defer tv.mu.Unlock()
+	tv.streamVerifiers.Push(key(assignmentStreamType, subscription), streamVerifier)
+}
+
+func (tv *Verifiers) popStreamVerifier(key string) (*RPCVerifier, error) {
+	tv.mu.Lock()
+	defer tv.mu.Unlock()
+	v, err := tv.streamVerifiers.Pop(key)
+	if v != nil {
+		tv.activeStreamVerifiers = append(tv.activeStreamVerifiers, v)
+	}
+	return v, err
+}
+
+func (tv *Verifiers) flush() {
+	tv.mu.Lock()
+	defer tv.mu.Unlock()
+
+	tv.GlobalVerifier.Flush()
+	tv.streamVerifiers.Flush()
+	for _, v := range tv.activeStreamVerifiers {
+		v.Flush()
+	}
+}
diff --git a/pubsublite/internal/wire/main_test.go b/pubsublite/internal/wire/main_test.go
new file mode 100644
index 0000000..0f9b401
--- /dev/null
+++ b/pubsublite/internal/wire/main_test.go
@@ -0,0 +1,50 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+
+package wire
+
+import (
+	"flag"
+	"log"
+	"os"
+	"testing"
+
+	"cloud.google.com/go/pubsublite/internal/test"
+	"google.golang.org/api/option"
+	"google.golang.org/grpc"
+)
+
+var (
+	// Initialized in TestMain.
+	mockServer     test.MockServer
+	testClientOpts []option.ClientOption
+)
+
+func TestMain(m *testing.M) {
+	flag.Parse()
+
+	testServer, err := test.NewServer()
+	if err != nil {
+		log.Fatal(err)
+	}
+	mockServer = testServer.LiteServer
+	conn, err := grpc.Dial(testServer.Addr(), grpc.WithInsecure())
+	if err != nil {
+		log.Fatal(err)
+	}
+	testClientOpts = []option.ClientOption{option.WithGRPCConn(conn)}
+
+	exit := m.Run()
+	testServer.Close()
+	os.Exit(exit)
+}
diff --git a/pubsublite/internal/wire/streams.go b/pubsublite/internal/wire/streams.go
index 1df7ed3..9bbc1ed 100644
--- a/pubsublite/internal/wire/streams.go
+++ b/pubsublite/internal/wire/streams.go
@@ -51,14 +51,21 @@
 	// newStream implementations must create the client stream with the given
 	// (cancellable) context.
 	newStream(context.Context) (grpc.ClientStream, error)
-	initialRequest() interface{}
+	// initialRequest should return the initial request and whether an initial
+	// response is expected.
+	initialRequest() (interface{}, bool)
 	validateInitialResponse(interface{}) error
 
 	// onStreamStatusChange is used to notify stream handlers when the stream has
-	// changed state. In particular, the `streamTerminated` state must be handled.
-	// retryableStream.Error() returns the error that caused the stream to
-	// terminate. Stream handlers should perform any necessary reset of state upon
-	// `streamConnected`.
+	// changed state. A `streamReconnecting` status change is fired before
+	// attempting to connect a new stream. A `streamConnected` status change is
+	// fired when the stream is successfully connected. These are followed by
+	// onResponse() calls when responses are received from the server. These
+	// events are guaranteed to occur in this order.
+	//
+	// A final `streamTerminated` status change is fired when a permanent error
+	// occurs. retryableStream.Error() returns the error that caused the stream to
+	// terminate.
 	onStreamStatusChange(streamStatus)
 	// onResponse forwards a response received on the stream to the stream
 	// handler.
@@ -68,10 +75,9 @@
 // retryableStream is a wrapper around a bidirectional gRPC client stream to
 // handle automatic reconnection when the stream breaks.
 //
-// A retryableStream cycles between the following goroutines:
-//   Start() --> reconnect() <--> listen()
-// terminate() can be called at any time, either by the client to force stream
-// closure, or as a result of an unretryable error.
+// The connectStream() goroutine handles each stream connection. terminate() can
+// be called at any time, either by the client to force stream closure, or as a
+// result of an unretryable error.
 //
 // Safe to call capitalized methods from multiple goroutines. All other methods
 // are private implementation.
@@ -114,7 +120,7 @@
 	if rs.status != streamUninitialized {
 		return
 	}
-	go rs.reconnect()
+	go rs.connectStream()
 }
 
 // Stop gracefully closes the stream without error.
@@ -139,7 +145,7 @@
 			// stream. Nothing to do here.
 			break
 		case isRetryableSendError(err):
-			go rs.reconnect()
+			go rs.connectStream()
 		default:
 			rs.mu.Unlock() // terminate acquires the mutex.
 			rs.terminate(err)
@@ -190,14 +196,13 @@
 	rs.cancelStream = cancel
 }
 
-// reconnect attempts to establish a valid connection with the server. Due to
-// the potential high latency, initNewStream() should not be done while holding
-// retryableStream.mu. Hence we need to handle the stream being force terminated
-// during reconnection.
+// connectStream attempts to establish a valid connection with the server. Due
+// to the potential high latency, initNewStream() should not be done while
+// holding retryableStream.mu. Hence we need to handle the stream being force
+// terminated during reconnection.
 //
-// Intended to be called in a goroutine. It ends once the connection has been
-// established or the stream terminated.
-func (rs *retryableStream) reconnect() {
+// Intended to be called in a goroutine. It ends once the client stream closes.
+func (rs *retryableStream) connectStream() {
 	canReconnect := func() bool {
 		rs.mu.Lock()
 		defer rs.mu.Unlock()
@@ -235,13 +240,14 @@
 		rs.status = streamConnected
 		rs.stream = newStream
 		rs.cancelStream = cancelFunc
-		go rs.listen(newStream)
 		return true
 	}
 	if !connected() {
 		return
 	}
+
 	rs.handler.onStreamStatusChange(streamConnected)
+	rs.listen(newStream)
 }
 
 func (rs *retryableStream) initNewStream() (newStream grpc.ClientStream, cancelFunc context.CancelFunc, err error) {
@@ -266,16 +272,19 @@
 			if err != nil {
 				return r.RetryRecv(err)
 			}
-			if err = newStream.SendMsg(rs.handler.initialRequest()); err != nil {
+			initReq, needsResponse := rs.handler.initialRequest()
+			if err = newStream.SendMsg(initReq); err != nil {
 				return r.RetrySend(err)
 			}
-			response := reflect.New(rs.responseType).Interface()
-			if err = newStream.RecvMsg(response); err != nil {
-				return r.RetryRecv(err)
-			}
-			if err = rs.handler.validateInitialResponse(response); err != nil {
-				// An unexpected initial response from the server is a permanent error.
-				return 0, false
+			if needsResponse {
+				response := reflect.New(rs.responseType).Interface()
+				if err = newStream.RecvMsg(response); err != nil {
+					return r.RetryRecv(err)
+				}
+				if err = rs.handler.validateInitialResponse(response); err != nil {
+					// An unexpected initial response from the server is a permanent error.
+					return 0, false
+				}
 			}
 
 			// We have a valid connection and should break from the outer loop.
@@ -285,6 +294,9 @@
 		if !shouldRetry {
 			break
 		}
+		if rs.Status() == streamTerminated {
+			break
+		}
 		if err = gax.Sleep(rs.ctx, backoff); err != nil {
 			break
 		}
@@ -294,8 +306,6 @@
 
 // listen receives responses from the current stream. It initiates reconnection
 // upon retryable errors or terminates the stream upon permanent error.
-//
-// Intended to be called in a goroutine. It ends when recvStream has closed.
 func (rs *retryableStream) listen(recvStream grpc.ClientStream) {
 	for {
 		response := reflect.New(rs.responseType).Interface()
@@ -309,7 +319,7 @@
 		}
 		if err != nil {
 			if isRetryableRecvError(err) {
-				go rs.reconnect()
+				go rs.connectStream()
 			} else {
 				rs.terminate(err)
 			}
diff --git a/pubsublite/internal/wire/streams_test.go b/pubsublite/internal/wire/streams_test.go
new file mode 100644
index 0000000..29b0bb5
--- /dev/null
+++ b/pubsublite/internal/wire/streams_test.go
@@ -0,0 +1,354 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package wire
+
+import (
+	"context"
+	"errors"
+	"reflect"
+	"testing"
+	"time"
+
+	"cloud.google.com/go/internal/testutil"
+	"cloud.google.com/go/pubsublite/internal/test"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+
+	vkit "cloud.google.com/go/pubsublite/apiv1"
+	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
+)
+
+const defaultStreamTimeout = 30 * time.Second
+
+var errInvalidInitialResponse = errors.New("invalid initial response")
+
+// testStreamHandler is a simplified publish stream handler that owns a
+// retryableStream.
+type testStreamHandler struct {
+	Topic      topicPartition
+	InitialReq *pb.PublishRequest
+	Stream     *retryableStream
+
+	t         *testing.T
+	statuses  chan streamStatus
+	responses chan interface{}
+	pubClient *vkit.PublisherClient
+}
+
+func newTestStreamHandler(t *testing.T, timeout time.Duration) *testStreamHandler {
+	ctx := context.Background()
+	pubClient, err := newPublisherClient(ctx, "ignored", testClientOpts...)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	topic := topicPartition{Path: "path/to/topic", Partition: 1}
+	sh := &testStreamHandler{
+		Topic:      topic,
+		InitialReq: initPubReq(topic),
+		t:          t,
+		statuses:   make(chan streamStatus, 3),
+		responses:  make(chan interface{}, 1),
+		pubClient:  pubClient,
+	}
+	sh.Stream = newRetryableStream(ctx, sh, timeout, reflect.TypeOf(pb.PublishResponse{}))
+	return sh
+}
+
+func (sh *testStreamHandler) NextStatus() streamStatus {
+	select {
+	case status := <-sh.statuses:
+		return status
+	case <-time.After(defaultStreamTimeout):
+		sh.t.Errorf("Stream did not change state within %v", defaultStreamTimeout)
+		return streamUninitialized
+	}
+}
+
+func (sh *testStreamHandler) NextResponse() interface{} {
+	select {
+	case response := <-sh.responses:
+		return response
+	case <-time.After(defaultStreamTimeout):
+		sh.t.Errorf("Stream did not receive response within %v", defaultStreamTimeout)
+		return nil
+	}
+}
+
+func (sh *testStreamHandler) newStream(ctx context.Context) (grpc.ClientStream, error) {
+	return sh.pubClient.Publish(ctx)
+}
+
+func (sh *testStreamHandler) validateInitialResponse(response interface{}) error {
+	pubResponse, _ := response.(*pb.PublishResponse)
+	if pubResponse.GetInitialResponse() == nil {
+		return errInvalidInitialResponse
+	}
+	return nil
+}
+
+func (sh *testStreamHandler) initialRequest() (interface{}, bool) {
+	return sh.InitialReq, true
+}
+
+func (sh *testStreamHandler) onStreamStatusChange(status streamStatus) {
+	sh.statuses <- status
+}
+
+func (sh *testStreamHandler) onResponse(response interface{}) {
+	sh.responses <- response
+}
+
+func TestRetryableStreamStartOnce(t *testing.T) {
+	pub := newTestStreamHandler(t, defaultStreamTimeout)
+
+	verifiers := test.NewVerifiers(t)
+	stream := test.NewRPCVerifier(t)
+	stream.Push(pub.InitialReq, initPubResp(), nil)
+	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	// Ensure that new streams are not opened if the publisher is started twice
+	// (note: only 1 stream verifier was added to the mock server above).
+	pub.Stream.Start()
+	pub.Stream.Start()
+	pub.Stream.Start()
+	if got, want := pub.NextStatus(), streamReconnecting; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+	if got, want := pub.NextStatus(), streamConnected; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+
+	pub.Stream.Stop()
+	if got, want := pub.NextStatus(), streamTerminated; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+	if gotErr := pub.Stream.Error(); gotErr != nil {
+		t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
+	}
+}
+
+func TestRetryableStreamStopWhileConnecting(t *testing.T) {
+	pub := newTestStreamHandler(t, defaultStreamTimeout)
+
+	verifiers := test.NewVerifiers(t)
+	stream := test.NewRPCVerifier(t)
+	barrier := stream.PushWithBarrier(pub.InitialReq, initPubResp(), nil)
+	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	pub.Stream.Start()
+	if got, want := pub.NextStatus(), streamReconnecting; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+
+	barrier.Release()
+	pub.Stream.Stop()
+
+	// The stream should transition to terminated and the client stream should be
+	// discarded.
+	if got, want := pub.NextStatus(), streamTerminated; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+	if pub.Stream.currentStream() != nil {
+		t.Error("Client stream should be nil")
+	}
+	if gotErr := pub.Stream.Error(); gotErr != nil {
+		t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
+	}
+}
+
+func TestRetryableStreamStopAbortsRetries(t *testing.T) {
+	pub := newTestStreamHandler(t, defaultStreamTimeout)
+
+	verifiers := test.NewVerifiers(t)
+	stream := test.NewRPCVerifier(t)
+	// Aborted is a retryable error, but the stream should not be retried because
+	// the publisher is stopped.
+	barrier := stream.PushWithBarrier(pub.InitialReq, nil, status.Error(codes.Aborted, "abort retry"))
+	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	pub.Stream.Start()
+	if got, want := pub.NextStatus(), streamReconnecting; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+
+	barrier.Release()
+	pub.Stream.Stop()
+
+	// The stream should transition to terminated and the client stream should be
+	// discarded.
+	if got, want := pub.NextStatus(), streamTerminated; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+	if pub.Stream.currentStream() != nil {
+		t.Error("Client stream should be nil")
+	}
+	if gotErr := pub.Stream.Error(); gotErr != nil {
+		t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
+	}
+}
+
+func TestRetryableStreamConnectRetries(t *testing.T) {
+	pub := newTestStreamHandler(t, defaultStreamTimeout)
+
+	verifiers := test.NewVerifiers(t)
+
+	// First 2 errors are retryable.
+	stream1 := test.NewRPCVerifier(t)
+	stream1.Push(pub.InitialReq, nil, status.Error(codes.Unavailable, "server unavailable"))
+	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream1)
+
+	stream2 := test.NewRPCVerifier(t)
+	stream2.Push(pub.InitialReq, nil, status.Error(codes.Internal, "internal"))
+	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream2)
+
+	// Third stream should succeed.
+	stream3 := test.NewRPCVerifier(t)
+	stream3.Push(pub.InitialReq, initPubResp(), nil)
+	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream3)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	pub.Stream.Start()
+	if got, want := pub.NextStatus(), streamReconnecting; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+	if got, want := pub.NextStatus(), streamConnected; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+
+	pub.Stream.Stop()
+	if got, want := pub.NextStatus(), streamTerminated; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+}
+
+func TestRetryableStreamConnectPermanentFailure(t *testing.T) {
+	pub := newTestStreamHandler(t, defaultStreamTimeout)
+	permanentErr := status.Error(codes.PermissionDenied, "denied")
+
+	verifiers := test.NewVerifiers(t)
+	// The stream connection results in a non-retryable error, so the publisher
+	// cannot start.
+	stream := test.NewRPCVerifier(t)
+	stream.Push(pub.InitialReq, nil, permanentErr)
+	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	pub.Stream.Start()
+	if got, want := pub.NextStatus(), streamReconnecting; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+	if got, want := pub.NextStatus(), streamTerminated; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+	if pub.Stream.currentStream() != nil {
+		t.Error("Client stream should be nil")
+	}
+	if gotErr := pub.Stream.Error(); !test.ErrorEqual(gotErr, permanentErr) {
+		t.Errorf("Stream final err: got (%v), want (%v)", gotErr, permanentErr)
+	}
+}
+
+func TestRetryableStreamConnectTimeout(t *testing.T) {
+	// Set a very low timeout to ensure no retries.
+	timeout := time.Millisecond
+	pub := newTestStreamHandler(t, timeout)
+	wantErr := status.Error(codes.DeadlineExceeded, "timeout")
+
+	verifiers := test.NewVerifiers(t)
+	stream := test.NewRPCVerifier(t)
+	barrier := stream.PushWithBarrier(pub.InitialReq, nil, wantErr)
+	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	pub.Stream.Start()
+	if got, want := pub.NextStatus(), streamReconnecting; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+
+	// Send the initial server response well after the timeout setting.
+	time.Sleep(10 * timeout)
+	barrier.Release()
+
+	if got, want := pub.NextStatus(), streamTerminated; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+	if pub.Stream.currentStream() != nil {
+		t.Error("Client stream should be nil")
+	}
+	if gotErr := pub.Stream.Error(); !test.ErrorEqual(gotErr, wantErr) {
+		t.Errorf("Stream final err: got (%v), want (%v)", gotErr, wantErr)
+	}
+}
+
+func TestRetryableStreamSendReceive(t *testing.T) {
+	pub := newTestStreamHandler(t, defaultStreamTimeout)
+	req := msgPubReq(&pb.PubSubMessage{Data: []byte("msg")})
+	wantResp := msgPubResp(5)
+
+	verifiers := test.NewVerifiers(t)
+	stream := test.NewRPCVerifier(t)
+	barrier := stream.PushWithBarrier(pub.InitialReq, initPubResp(), nil)
+	stream.Push(req, wantResp, nil)
+	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
+
+	mockServer.OnTestStart(verifiers)
+	defer mockServer.OnTestEnd()
+
+	pub.Stream.Start()
+	if got, want := pub.NextStatus(), streamReconnecting; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+
+	// While the stream is reconnecting, requests are discarded.
+	if got, want := pub.Stream.Send(req), false; got != want {
+		t.Errorf("Stream send: got %v, want %v", got, want)
+	}
+
+	barrier.Release()
+	if got, want := pub.NextStatus(), streamConnected; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+
+	if got, want := pub.Stream.Send(req), true; got != want {
+		t.Errorf("Stream send: got %v, want %v", got, want)
+	}
+	if gotResp := pub.NextResponse(); !testutil.Equal(gotResp, wantResp) {
+		t.Errorf("Stream response: got %v, want %v", gotResp, wantResp)
+	}
+
+	pub.Stream.Stop()
+	if got, want := pub.NextStatus(), streamTerminated; got != want {
+		t.Errorf("Stream status change: got %d, want %d", got, want)
+	}
+	if gotErr := pub.Stream.Error(); gotErr != nil {
+		t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
+	}
+}