feat(pubsublite): Refactoring and unit tests for retryableStream (#3160)
Minor refactoring and unit tests for retryableStream. Eliminated test flakiness in mock server.
diff --git a/pubsublite/internal/test/mock.go b/pubsublite/internal/test/mock.go
index d27a437..bf74683 100644
--- a/pubsublite/internal/test/mock.go
+++ b/pubsublite/internal/test/mock.go
@@ -15,12 +15,12 @@
import (
"context"
- "fmt"
"io"
"reflect"
"sync"
"cloud.google.com/go/internal/testutil"
+ "cloud.google.com/go/internal/uid"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@@ -40,20 +40,11 @@
// This is the interface that should be used by tests.
type MockServer interface {
// OnTestStart must be called at the start of each test to clear any existing
- // state and set the verifier for unary RPCs.
- OnTestStart(globalVerifier *RPCVerifier)
+ // state and set the test verifiers.
+ OnTestStart(*Verifiers)
// OnTestEnd should be called at the end of each test to flush the verifiers
// (i.e. check whether any expected requests were not sent to the server).
OnTestEnd()
- // AddPublishStream adds a verifier for a publish stream of a topic partition.
- AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier)
- // AddSubscribeStream adds a verifier for a subscribe stream of a partition.
- AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier)
- // AddCommitStream adds a verifier for a commit stream of a partition.
- AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier)
- // AddAssignmentStream adds a verifier for a partition assignment stream for a
- // subscription.
- AddAssignmentStream(subscription string, streamVerifier *RPCVerifier)
}
// NewServer creates a new mock Pub/Sub Lite server.
@@ -97,79 +88,63 @@
mu sync.Mutex
- // Global list of verifiers for all unary RPCs. This should be set before the
- // test begins.
- globalVerifier *RPCVerifier
-
- // Stream verifiers by key.
- publishVerifiers *keyedStreamVerifiers
- subscribeVerifiers *keyedStreamVerifiers
- commitVerifiers *keyedStreamVerifiers
- assignmentVerifiers *keyedStreamVerifiers
-
- nextStreamID int
- activeStreams map[int]*streamHolder
- testActive bool
-}
-
-func key(path string, partition int) string {
- return fmt.Sprintf("%s:%d", path, partition)
+ testVerifiers *Verifiers
+ testIDs *uid.Space
+ currentTestID string
}
func newMockLiteServer() *mockLiteServer {
return &mockLiteServer{
- publishVerifiers: newKeyedStreamVerifiers(),
- subscribeVerifiers: newKeyedStreamVerifiers(),
- commitVerifiers: newKeyedStreamVerifiers(),
- assignmentVerifiers: newKeyedStreamVerifiers(),
- activeStreams: make(map[int]*streamHolder),
+ testIDs: uid.NewSpace("mockLiteServer", nil),
}
}
-func (s *mockLiteServer) startStream(stream grpc.ServerStream, verifier *RPCVerifier) (id int) {
+func (s *mockLiteServer) popGlobalVerifiers(request interface{}) (interface{}, error) {
s.mu.Lock()
defer s.mu.Unlock()
- id = s.nextStreamID
- s.nextStreamID++
- s.activeStreams[id] = &streamHolder{stream: stream, verifier: verifier}
- return
+ if s.testVerifiers == nil {
+ return nil, status.Errorf(codes.FailedPrecondition, "mockserver: previous test has ended")
+ }
+ return s.testVerifiers.GlobalVerifier.Pop(request)
}
-func (s *mockLiteServer) endStream(id int) {
+func (s *mockLiteServer) popStreamVerifier(key string) (*RPCVerifier, error) {
s.mu.Lock()
defer s.mu.Unlock()
- delete(s.activeStreams, id)
+ if s.testVerifiers == nil {
+ return nil, status.Errorf(codes.FailedPrecondition, "mockserver: previous test has ended")
+ }
+ return s.testVerifiers.popStreamVerifier(key) // Locks Verifiers.mu and records the verifier as active so flush() covers it at test end.
}
-func (s *mockLiteServer) popStreamVerifier(key string, keyedVerifiers *keyedStreamVerifiers) (*RPCVerifier, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- return keyedVerifiers.Pop(key)
-}
-
-func (s *mockLiteServer) handleStream(stream grpc.ServerStream, req interface{}, requestType reflect.Type, key string, keyedVerifiers *keyedStreamVerifiers) (err error) {
- verifier, err := s.popStreamVerifier(key, keyedVerifiers)
+func (s *mockLiteServer) handleStream(stream grpc.ServerStream, req interface{}, requestType reflect.Type, key string) (err error) {
+ testID := s.currentTest()
+ if testID == "" {
+ return status.Errorf(codes.FailedPrecondition, "mockserver: previous test has ended")
+ }
+ verifier, err := s.popStreamVerifier(key)
if err != nil {
return err
}
- id := s.startStream(stream, verifier)
-
// Verify initial request.
retResponse, retErr := verifier.Pop(req)
var ok bool
for {
+ // See comments for RPCVerifier.Push for valid stream request/response
+ // combinations.
if retErr != nil {
err = retErr
break
}
- if err = stream.SendMsg(retResponse); err != nil {
- err = status.Errorf(codes.FailedPrecondition, "mockserver: stream send error: %v", err)
- break
+ if retResponse != nil {
+ if err = stream.SendMsg(retResponse); err != nil {
+ err = status.Errorf(codes.FailedPrecondition, "mockserver: stream send error: %v", err)
+ break
+ }
}
// Check whether the next response isn't blocked on a request.
@@ -185,70 +160,47 @@
err = status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error: %v", err)
break
}
+ if testID != s.currentTest() {
+ err = status.Errorf(codes.FailedPrecondition, "mockserver: previous test has ended")
+ break
+ }
retResponse, retErr = verifier.Pop(req)
}
// Check whether the stream ended prematurely.
- verifier.Flush()
- s.endStream(id)
+ if testID == s.currentTest() {
+ verifier.Flush()
+ }
return
}
// MockServer implementation.
-func (s *mockLiteServer) OnTestStart(globalVerifier *RPCVerifier) {
+func (s *mockLiteServer) OnTestStart(verifiers *Verifiers) {
s.mu.Lock()
defer s.mu.Unlock()
- if s.testActive {
+ if s.currentTestID != "" {
panic("mockserver is already in use by another test")
}
-
- s.testActive = true
- s.globalVerifier = globalVerifier
- s.publishVerifiers.Reset()
- s.subscribeVerifiers.Reset()
- s.commitVerifiers.Reset()
- s.assignmentVerifiers.Reset()
- s.activeStreams = make(map[int]*streamHolder)
+ s.currentTestID = s.testIDs.New()
+ s.testVerifiers = verifiers
}
func (s *mockLiteServer) OnTestEnd() {
s.mu.Lock()
defer s.mu.Unlock()
- s.testActive = false
- if s.globalVerifier != nil {
- s.globalVerifier.Flush()
- }
-
- for _, as := range s.activeStreams {
- as.verifier.Flush()
+ s.currentTestID = ""
+ if s.testVerifiers != nil {
+ s.testVerifiers.flush()
}
}
-func (s *mockLiteServer) AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier) {
+func (s *mockLiteServer) currentTest() string {
s.mu.Lock()
defer s.mu.Unlock()
- s.publishVerifiers.Push(key(topic, partition), streamVerifier)
-}
-
-func (s *mockLiteServer) AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.subscribeVerifiers.Push(key(subscription, partition), streamVerifier)
-}
-
-func (s *mockLiteServer) AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.commitVerifiers.Push(key(subscription, partition), streamVerifier)
-}
-
-func (s *mockLiteServer) AddAssignmentStream(subscription string, streamVerifier *RPCVerifier) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.assignmentVerifiers.Push(subscription, streamVerifier)
+ return s.currentTestID
}
// PublisherService implementation.
@@ -263,8 +215,8 @@
}
initReq := req.GetInitialRequest()
- k := key(initReq.GetTopic(), int(initReq.GetPartition()))
- return s.handleStream(stream, req, reflect.TypeOf(pb.PublishRequest{}), k, s.publishVerifiers)
+ k := keyPartition(publishStreamType, initReq.GetTopic(), int(initReq.GetPartition()))
+ return s.handleStream(stream, req, reflect.TypeOf(pb.PublishRequest{}), k)
}
// SubscriberService implementation.
@@ -279,8 +231,8 @@
}
initReq := req.GetInitial()
- k := key(initReq.GetSubscription(), int(initReq.GetPartition()))
- return s.handleStream(stream, req, reflect.TypeOf(pb.SubscribeRequest{}), k, s.subscribeVerifiers)
+ k := keyPartition(subscribeStreamType, initReq.GetSubscription(), int(initReq.GetPartition()))
+ return s.handleStream(stream, req, reflect.TypeOf(pb.SubscribeRequest{}), k)
}
// CursorService implementation.
@@ -295,8 +247,8 @@
}
initReq := req.GetInitial()
- k := key(initReq.GetSubscription(), int(initReq.GetPartition()))
- return s.handleStream(stream, req, reflect.TypeOf(pb.StreamingCommitCursorRequest{}), k, s.commitVerifiers)
+ k := keyPartition(commitStreamType, initReq.GetSubscription(), int(initReq.GetPartition()))
+ return s.handleStream(stream, req, reflect.TypeOf(pb.StreamingCommitCursorRequest{}), k)
}
// PartitionAssignmentService implementation.
@@ -310,17 +262,14 @@
return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial partition assignment request: %v", req)
}
- k := req.GetInitial().GetSubscription()
- return s.handleStream(stream, req, reflect.TypeOf(pb.PartitionAssignmentRequest{}), k, s.assignmentVerifiers)
+ k := key(assignmentStreamType, req.GetInitial().GetSubscription())
+ return s.handleStream(stream, req, reflect.TypeOf(pb.PartitionAssignmentRequest{}), k)
}
// AdminService implementation.
func (s *mockLiteServer) GetTopicPartitions(ctx context.Context, req *pb.GetTopicPartitionsRequest) (*pb.TopicPartitions, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- retResponse, retErr := s.globalVerifier.Pop(req)
+ retResponse, retErr := s.popGlobalVerifiers(req)
if retErr != nil {
return nil, retErr
}
diff --git a/pubsublite/internal/test/verifier.go b/pubsublite/internal/test/verifier.go
index a1e2681..3f57e5e 100644
--- a/pubsublite/internal/test/verifier.go
+++ b/pubsublite/internal/test/verifier.go
@@ -15,6 +15,7 @@
import (
"container/list"
+ "fmt"
"sync"
"testing"
"time"
@@ -30,28 +31,72 @@
blockWaitTimeout = 30 * time.Second
)
-type rpcMetadata struct {
- wantRequest interface{}
- retResponse interface{}
- retErr error
- blockResponse chan struct{}
+// Barrier is used to perform two-way synchronization between the server and
+// client (test) to ensure tests are deterministic.
+type Barrier struct {
+ // Used to block until the server is ready to send the response.
+ serverBlock chan struct{}
+ // Used to block until the client wants the server to send the response.
+ clientBlock chan struct{}
+ err error
}
-// wait until the `blockResponse` is released by the test, or a timeout occurs.
-// Returns immediately if there was no block.
-func (r *rpcMetadata) wait() error {
- if r.blockResponse == nil {
- return nil
+func newBarrier() *Barrier {
+ return &Barrier{
+ serverBlock: make(chan struct{}),
+ clientBlock: make(chan struct{}),
}
+}
+
+// Release should be called by the test.
+func (b *Barrier) Release() {
+ // Wait for the server to reach the barrier.
+ select {
+ case <-time.After(blockWaitTimeout):
+ // Note: avoid returning a retryable code to quickly terminate the test.
+ b.err = status.Errorf(codes.FailedPrecondition, "mockserver: server did not reach barrier within %v", blockWaitTimeout)
+ case <-b.serverBlock:
+ }
+
+ // Then close the client block.
+ close(b.clientBlock)
+}
+
+func (b *Barrier) serverWait() error {
+ if b.err != nil {
+ return b.err
+ }
+
+ // Close the server block to signal the server reaching the point where it is
+ // ready to send the response.
+ close(b.serverBlock)
+
+ // Wait for the test to release the client block.
select {
case <-time.After(blockWaitTimeout):
// Note: avoid returning a retryable code to quickly terminate the test.
return status.Errorf(codes.FailedPrecondition, "mockserver: test did not unblock response within %v", blockWaitTimeout)
- case <-r.blockResponse:
+ case <-b.clientBlock:
return nil
}
}
+type rpcMetadata struct {
+ wantRequest interface{}
+ retResponse interface{}
+ retErr error
+ barrier *Barrier
+}
+
+// wait blocks until the barrier is released by the test, or a timeout occurs.
+// Returns immediately if there is no barrier.
+func (r *rpcMetadata) wait() error {
+ if r.barrier == nil {
+ return nil
+ }
+ return r.barrier.serverWait()
+}
+
// RPCVerifier stores an queue of requests expected from the client, and the
// corresponding response or error to return.
type RPCVerifier struct {
@@ -71,6 +116,15 @@
}
// Push appends a new {request, response, error} tuple.
+//
+// Valid combinations for unary and streaming RPCs:
+// - {request, response, nil}
+// - {request, nil, error}
+//
+// Additional combinations for streams only:
+// - {nil, response, nil}: send a response without a request (e.g. messages).
+// - {nil, nil, error}: break the stream without a request.
+// - {request, nil, nil}: expect a request, but don't send any response.
func (v *RPCVerifier) Push(wantRequest interface{}, retResponse interface{}, retErr error) {
v.mu.Lock()
defer v.mu.Unlock()
@@ -82,21 +136,21 @@
})
}
-// PushWithBlock is like Push, but returns a channel that the test should close
-// when it would like the response to be sent to the client. This is useful for
-// synchronizing with work that needs to be done on the client.
-func (v *RPCVerifier) PushWithBlock(wantRequest interface{}, retResponse interface{}, retErr error) chan struct{} {
+// PushWithBarrier is like Push, but returns a barrier on which the test should
+// call Release when it would like the response to be sent to the client. This is
+// useful for synchronizing with work that needs to be done on the client.
+func (v *RPCVerifier) PushWithBarrier(wantRequest interface{}, retResponse interface{}, retErr error) *Barrier {
v.mu.Lock()
defer v.mu.Unlock()
- block := make(chan struct{})
+ barrier := newBarrier()
v.rpcs.PushBack(&rpcMetadata{
- wantRequest: wantRequest,
- retResponse: retResponse,
- retErr: retErr,
- blockResponse: block,
+ wantRequest: wantRequest,
+ retResponse: retResponse,
+ retErr: retErr,
+ barrier: barrier,
})
- return block
+ return barrier
}
// Pop validates the received request with the next {request, response, error}
@@ -158,7 +212,11 @@
for elem := v.rpcs.Front(); elem != nil; elem = elem.Next() {
v.numCalls++
rpc, _ := elem.Value.(*rpcMetadata)
- v.t.Errorf("call(%d): did not receive expected request:\n%v", v.numCalls, rpc.wantRequest)
+ if rpc.wantRequest != nil {
+ v.t.Errorf("call(%d): did not receive expected request:\n%v", v.numCalls, rpc.wantRequest)
+ } else {
+ v.t.Errorf("call(%d): unsent response:\n%v, err = (%v)", v.numCalls, rpc.retResponse, rpc.retErr)
+ }
}
v.rpcs.Init()
}
@@ -195,7 +253,15 @@
return v, nil
}
-// keyedStreamVerifiers stores indexed streamVerifiers.
+func (sv *streamVerifiers) Flush() {
+ for elem := sv.verifiers.Front(); elem != nil; elem = elem.Next() {
+ v, _ := elem.Value.(*RPCVerifier)
+ v.Flush()
+ }
+}
+
+// keyedStreamVerifiers stores streamVerifiers indexed by key. Key formats:
+// "streamType:path:partition", or "streamType:path" for assignment streams.
type keyedStreamVerifiers struct {
verifiers map[string]*streamVerifiers
}
@@ -204,10 +270,6 @@
return &keyedStreamVerifiers{verifiers: make(map[string]*streamVerifiers)}
}
-func (kv *keyedStreamVerifiers) Reset() {
- kv.verifiers = make(map[string]*streamVerifiers)
-}
-
func (kv *keyedStreamVerifiers) Push(key string, v *RPCVerifier) {
sv, ok := kv.verifiers[key]
if !ok {
@@ -224,3 +286,97 @@
}
return sv.Pop()
}
+
+func (kv *keyedStreamVerifiers) Flush() {
+ for _, sv := range kv.verifiers {
+ sv.Flush()
+ }
+}
+
+// Verifiers contains RPCVerifiers for unary RPCs and streaming RPCs.
+type Verifiers struct {
+ t *testing.T
+ mu sync.Mutex
+
+ // Global list of verifiers for all unary RPCs.
+ GlobalVerifier *RPCVerifier
+ // Stream verifiers by key.
+ streamVerifiers *keyedStreamVerifiers
+ activeStreamVerifiers []*RPCVerifier
+}
+
+// NewVerifiers creates a new instance of Verifiers for a test.
+func NewVerifiers(t *testing.T) *Verifiers {
+ return &Verifiers{
+ t: t,
+ GlobalVerifier: NewRPCVerifier(t),
+ streamVerifiers: newKeyedStreamVerifiers(),
+ }
+}
+
+// streamType is used as a key prefix for keyedStreamVerifiers.
+type streamType string
+
+const (
+ publishStreamType streamType = "publish"
+ subscribeStreamType streamType = "subscribe"
+ commitStreamType streamType = "commit"
+ assignmentStreamType streamType = "assignment"
+)
+
+func keyPartition(st streamType, path string, partition int) string {
+ return fmt.Sprintf("%s:%s:%d", st, path, partition)
+}
+
+func key(st streamType, path string) string {
+ return fmt.Sprintf("%s:%s", st, path)
+}
+
+// AddPublishStream adds a verifier for a publish stream of a topic partition.
+func (tv *Verifiers) AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier) {
+ tv.mu.Lock()
+ defer tv.mu.Unlock()
+ tv.streamVerifiers.Push(keyPartition(publishStreamType, topic, partition), streamVerifier)
+}
+
+// AddSubscribeStream adds a verifier for a subscribe stream of a partition.
+func (tv *Verifiers) AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier) {
+ tv.mu.Lock()
+ defer tv.mu.Unlock()
+ tv.streamVerifiers.Push(keyPartition(subscribeStreamType, subscription, partition), streamVerifier)
+}
+
+// AddCommitStream adds a verifier for a commit stream of a partition.
+func (tv *Verifiers) AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier) {
+ tv.mu.Lock()
+ defer tv.mu.Unlock()
+ tv.streamVerifiers.Push(keyPartition(commitStreamType, subscription, partition), streamVerifier)
+}
+
+// AddAssignmentStream adds a verifier for a subscription's assignment stream.
+func (tv *Verifiers) AddAssignmentStream(subscription string, streamVerifier *RPCVerifier) {
+ tv.mu.Lock()
+ defer tv.mu.Unlock()
+ tv.streamVerifiers.Push(key(assignmentStreamType, subscription), streamVerifier)
+}
+
+func (tv *Verifiers) popStreamVerifier(key string) (*RPCVerifier, error) {
+ tv.mu.Lock()
+ defer tv.mu.Unlock()
+ v, err := tv.streamVerifiers.Pop(key)
+ if v != nil {
+ tv.activeStreamVerifiers = append(tv.activeStreamVerifiers, v)
+ }
+ return v, err
+}
+
+func (tv *Verifiers) flush() {
+ tv.mu.Lock()
+ defer tv.mu.Unlock()
+
+ tv.GlobalVerifier.Flush()
+ tv.streamVerifiers.Flush()
+ for _, v := range tv.activeStreamVerifiers {
+ v.Flush()
+ }
+}
diff --git a/pubsublite/internal/wire/main_test.go b/pubsublite/internal/wire/main_test.go
new file mode 100644
index 0000000..0f9b401
--- /dev/null
+++ b/pubsublite/internal/wire/main_test.go
@@ -0,0 +1,50 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+
+package wire
+
+import (
+ "flag"
+ "log"
+ "os"
+ "testing"
+
+ "cloud.google.com/go/pubsublite/internal/test"
+ "google.golang.org/api/option"
+ "google.golang.org/grpc"
+)
+
+var (
+ // Initialized in TestMain.
+ mockServer test.MockServer
+ testClientOpts []option.ClientOption
+)
+
+func TestMain(m *testing.M) {
+ flag.Parse()
+
+ testServer, err := test.NewServer()
+ if err != nil {
+ log.Fatal(err)
+ }
+ mockServer = testServer.LiteServer
+ conn, err := grpc.Dial(testServer.Addr(), grpc.WithInsecure())
+ if err != nil {
+ log.Fatal(err)
+ }
+ testClientOpts = []option.ClientOption{option.WithGRPCConn(conn)}
+
+ exit := m.Run()
+ testServer.Close()
+ os.Exit(exit)
+}
diff --git a/pubsublite/internal/wire/streams.go b/pubsublite/internal/wire/streams.go
index 1df7ed3..9bbc1ed 100644
--- a/pubsublite/internal/wire/streams.go
+++ b/pubsublite/internal/wire/streams.go
@@ -51,14 +51,21 @@
// newStream implementations must create the client stream with the given
// (cancellable) context.
newStream(context.Context) (grpc.ClientStream, error)
- initialRequest() interface{}
+ // initialRequest should return the initial request and whether an initial
+ // response is expected.
+ initialRequest() (interface{}, bool)
validateInitialResponse(interface{}) error
// onStreamStatusChange is used to notify stream handlers when the stream has
- // changed state. In particular, the `streamTerminated` state must be handled.
- // retryableStream.Error() returns the error that caused the stream to
- // terminate. Stream handlers should perform any necessary reset of state upon
- // `streamConnected`.
+ // changed state. A `streamReconnecting` status change is fired before
+ // attempting to connect a new stream. A `streamConnected` status change is
+ // fired when the stream is successfully connected. These are followed by
+ // onResponse() calls when responses are received from the server. These
+ // events are guaranteed to occur in this order.
+ //
+ // A final `streamTerminated` status change is fired when a permanent error
+ // occurs. retryableStream.Error() returns the error that caused the stream to
+ // terminate.
onStreamStatusChange(streamStatus)
// onResponse forwards a response received on the stream to the stream
// handler.
@@ -68,10 +75,9 @@
// retryableStream is a wrapper around a bidirectional gRPC client stream to
// handle automatic reconnection when the stream breaks.
//
-// A retryableStream cycles between the following goroutines:
-// Start() --> reconnect() <--> listen()
-// terminate() can be called at any time, either by the client to force stream
-// closure, or as a result of an unretryable error.
+// The connectStream() goroutine handles each stream connection. terminate() can
+// be called at any time, either by the client to force stream closure, or as a
+// result of an unretryable error.
//
// Safe to call capitalized methods from multiple goroutines. All other methods
// are private implementation.
@@ -114,7 +120,7 @@
if rs.status != streamUninitialized {
return
}
- go rs.reconnect()
+ go rs.connectStream()
}
// Stop gracefully closes the stream without error.
@@ -139,7 +145,7 @@
// stream. Nothing to do here.
break
case isRetryableSendError(err):
- go rs.reconnect()
+ go rs.connectStream()
default:
rs.mu.Unlock() // terminate acquires the mutex.
rs.terminate(err)
@@ -190,14 +196,13 @@
rs.cancelStream = cancel
}
-// reconnect attempts to establish a valid connection with the server. Due to
-// the potential high latency, initNewStream() should not be done while holding
-// retryableStream.mu. Hence we need to handle the stream being force terminated
-// during reconnection.
+// connectStream attempts to establish a valid connection with the server. Due
+// to the potential high latency, initNewStream() should not be done while
+// holding retryableStream.mu. Hence we need to handle the stream being force
+// terminated during reconnection.
//
-// Intended to be called in a goroutine. It ends once the connection has been
-// established or the stream terminated.
-func (rs *retryableStream) reconnect() {
+// Intended to be called in a goroutine. It ends once the client stream closes.
+func (rs *retryableStream) connectStream() {
canReconnect := func() bool {
rs.mu.Lock()
defer rs.mu.Unlock()
@@ -235,13 +240,14 @@
rs.status = streamConnected
rs.stream = newStream
rs.cancelStream = cancelFunc
- go rs.listen(newStream)
return true
}
if !connected() {
return
}
+
rs.handler.onStreamStatusChange(streamConnected)
+ rs.listen(newStream)
}
func (rs *retryableStream) initNewStream() (newStream grpc.ClientStream, cancelFunc context.CancelFunc, err error) {
@@ -266,16 +272,19 @@
if err != nil {
return r.RetryRecv(err)
}
- if err = newStream.SendMsg(rs.handler.initialRequest()); err != nil {
+ initReq, needsResponse := rs.handler.initialRequest()
+ if err = newStream.SendMsg(initReq); err != nil {
return r.RetrySend(err)
}
- response := reflect.New(rs.responseType).Interface()
- if err = newStream.RecvMsg(response); err != nil {
- return r.RetryRecv(err)
- }
- if err = rs.handler.validateInitialResponse(response); err != nil {
- // An unexpected initial response from the server is a permanent error.
- return 0, false
+ if needsResponse {
+ response := reflect.New(rs.responseType).Interface()
+ if err = newStream.RecvMsg(response); err != nil {
+ return r.RetryRecv(err)
+ }
+ if err = rs.handler.validateInitialResponse(response); err != nil {
+ // An unexpected initial response from the server is a permanent error.
+ return 0, false
+ }
}
// We have a valid connection and should break from the outer loop.
@@ -285,6 +294,9 @@
if !shouldRetry {
break
}
+ if rs.Status() == streamTerminated {
+ break
+ }
if err = gax.Sleep(rs.ctx, backoff); err != nil {
break
}
@@ -294,8 +306,6 @@
// listen receives responses from the current stream. It initiates reconnection
// upon retryable errors or terminates the stream upon permanent error.
-//
-// Intended to be called in a goroutine. It ends when recvStream has closed.
func (rs *retryableStream) listen(recvStream grpc.ClientStream) {
for {
response := reflect.New(rs.responseType).Interface()
@@ -309,7 +319,7 @@
}
if err != nil {
if isRetryableRecvError(err) {
- go rs.reconnect()
+ go rs.connectStream()
} else {
rs.terminate(err)
}
diff --git a/pubsublite/internal/wire/streams_test.go b/pubsublite/internal/wire/streams_test.go
new file mode 100644
index 0000000..29b0bb5
--- /dev/null
+++ b/pubsublite/internal/wire/streams_test.go
@@ -0,0 +1,354 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+
+package wire
+
+import (
+ "context"
+ "errors"
+ "reflect"
+ "testing"
+ "time"
+
+ "cloud.google.com/go/internal/testutil"
+ "cloud.google.com/go/pubsublite/internal/test"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ vkit "cloud.google.com/go/pubsublite/apiv1"
+ pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
+)
+
+const defaultStreamTimeout = 30 * time.Second
+
+var errInvalidInitialResponse = errors.New("invalid initial response")
+
+// testStreamHandler is a simplified publisher service that owns a
+// retryableStream.
+type testStreamHandler struct {
+ Topic topicPartition
+ InitialReq *pb.PublishRequest
+ Stream *retryableStream
+
+ t *testing.T
+ statuses chan streamStatus
+ responses chan interface{}
+ pubClient *vkit.PublisherClient
+}
+
+func newTestStreamHandler(t *testing.T, timeout time.Duration) *testStreamHandler {
+ ctx := context.Background()
+ pubClient, err := newPublisherClient(ctx, "ignored", testClientOpts...)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ topic := topicPartition{Path: "path/to/topic", Partition: 1}
+ sh := &testStreamHandler{
+ Topic: topic,
+ InitialReq: initPubReq(topic),
+ t: t,
+ statuses: make(chan streamStatus, 3),
+ responses: make(chan interface{}, 1),
+ pubClient: pubClient,
+ }
+ sh.Stream = newRetryableStream(ctx, sh, timeout, reflect.TypeOf(pb.PublishResponse{}))
+ return sh
+}
+
+func (sh *testStreamHandler) NextStatus() streamStatus {
+ select {
+ case status := <-sh.statuses:
+ return status
+ case <-time.After(defaultStreamTimeout):
+ sh.t.Errorf("Stream did not change state within %v", defaultStreamTimeout)
+ return streamUninitialized
+ }
+}
+
+func (sh *testStreamHandler) NextResponse() interface{} {
+ select {
+ case response := <-sh.responses:
+ return response
+ case <-time.After(defaultStreamTimeout):
+ sh.t.Errorf("Stream did not receive response within %v", defaultStreamTimeout)
+ return nil
+ }
+}
+
+func (sh *testStreamHandler) newStream(ctx context.Context) (grpc.ClientStream, error) {
+ return sh.pubClient.Publish(ctx)
+}
+
+func (sh *testStreamHandler) validateInitialResponse(response interface{}) error {
+ pubResponse, _ := response.(*pb.PublishResponse)
+ if pubResponse.GetInitialResponse() == nil {
+ return errInvalidInitialResponse
+ }
+ return nil
+}
+
+func (sh *testStreamHandler) initialRequest() (interface{}, bool) {
+ return sh.InitialReq, true
+}
+
+func (sh *testStreamHandler) onStreamStatusChange(status streamStatus) {
+ sh.statuses <- status
+}
+
+func (sh *testStreamHandler) onResponse(response interface{}) {
+ sh.responses <- response
+}
+
+func TestRetryableStreamStartOnce(t *testing.T) {
+ pub := newTestStreamHandler(t, defaultStreamTimeout)
+
+ verifiers := test.NewVerifiers(t)
+ stream := test.NewRPCVerifier(t)
+ stream.Push(pub.InitialReq, initPubResp(), nil)
+ verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
+
+ mockServer.OnTestStart(verifiers)
+ defer mockServer.OnTestEnd()
+
+ // Ensure that new streams are not opened if Start() is called multiple times
+ // (note: only 1 stream verifier was added to the mock server above).
+ pub.Stream.Start()
+ pub.Stream.Start()
+ pub.Stream.Start()
+ if got, want := pub.NextStatus(), streamReconnecting; got != want {
+ t.Errorf("Stream status change: got %d, want %d", got, want)
+ }
+ if got, want := pub.NextStatus(), streamConnected; got != want {
+ t.Errorf("Stream status change: got %d, want %d", got, want)
+ }
+
+ pub.Stream.Stop()
+ if got, want := pub.NextStatus(), streamTerminated; got != want {
+ t.Errorf("Stream status change: got %d, want %d", got, want)
+ }
+ if gotErr := pub.Stream.Error(); gotErr != nil {
+ t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
+ }
+}
+
+func TestRetryableStreamStopWhileConnecting(t *testing.T) {
+ pub := newTestStreamHandler(t, defaultStreamTimeout)
+
+ verifiers := test.NewVerifiers(t)
+ stream := test.NewRPCVerifier(t)
+ barrier := stream.PushWithBarrier(pub.InitialReq, initPubResp(), nil)
+ verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
+
+ mockServer.OnTestStart(verifiers)
+ defer mockServer.OnTestEnd()
+
+ pub.Stream.Start()
+ if got, want := pub.NextStatus(), streamReconnecting; got != want {
+ t.Errorf("Stream status change: got %d, want %d", got, want)
+ }
+
+ barrier.Release()
+ pub.Stream.Stop()
+
+ // The stream should transition to terminated and the client stream should be
+ // discarded.
+ if got, want := pub.NextStatus(), streamTerminated; got != want {
+ t.Errorf("Stream status change: got %d, want %d", got, want)
+ }
+ if pub.Stream.currentStream() != nil {
+ t.Error("Client stream should be nil")
+ }
+ if gotErr := pub.Stream.Error(); gotErr != nil {
+ t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
+ }
+}
+
+func TestRetryableStreamStopAbortsRetries(t *testing.T) {
+ pub := newTestStreamHandler(t, defaultStreamTimeout)
+
+ verifiers := test.NewVerifiers(t)
+ stream := test.NewRPCVerifier(t)
+ // Aborted is a retryable error, but the stream should not be retried because
+ // the publisher is stopped.
+ barrier := stream.PushWithBarrier(pub.InitialReq, nil, status.Error(codes.Aborted, "abort retry"))
+ verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
+
+ mockServer.OnTestStart(verifiers)
+ defer mockServer.OnTestEnd()
+
+ pub.Stream.Start()
+ if got, want := pub.NextStatus(), streamReconnecting; got != want {
+ t.Errorf("Stream status change: got %d, want %d", got, want)
+ }
+
+ barrier.Release()
+ pub.Stream.Stop()
+
+ // The stream should transition to terminated and the client stream should be
+ // discarded.
+ if got, want := pub.NextStatus(), streamTerminated; got != want {
+ t.Errorf("Stream status change: got %d, want %d", got, want)
+ }
+ if pub.Stream.currentStream() != nil {
+ t.Error("Client stream should be nil")
+ }
+ if gotErr := pub.Stream.Error(); gotErr != nil {
+ t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
+ }
+}
+
+// TestRetryableStreamConnectRetries checks that the stream reports the
+// connected status when the first two connection attempts fail with retryable
+// errors and the third attempt succeeds.
+func TestRetryableStreamConnectRetries(t *testing.T) {
+	h := newTestStreamHandler(t, defaultStreamTimeout)
+
+	vs := test.NewVerifiers(t)
+
+	// The first two attempts each fail with a retryable error.
+	retryable := []error{
+		status.Error(codes.Unavailable, "server unavailable"),
+		status.Error(codes.Internal, "internal"),
+	}
+	for _, failure := range retryable {
+		failing := test.NewRPCVerifier(t)
+		failing.Push(h.InitialReq, nil, failure)
+		vs.AddPublishStream(h.Topic.Path, h.Topic.Partition, failing)
+	}
+
+	// The third attempt receives a successful initial response.
+	ok := test.NewRPCVerifier(t)
+	ok.Push(h.InitialReq, initPubResp(), nil)
+	vs.AddPublishStream(h.Topic.Path, h.Topic.Partition, ok)
+
+	mockServer.OnTestStart(vs)
+	defer mockServer.OnTestEnd()
+
+	h.Stream.Start()
+	if got := h.NextStatus(); got != streamReconnecting {
+		t.Errorf("Stream status change: got %d, want %d", got, streamReconnecting)
+	}
+	if got := h.NextStatus(); got != streamConnected {
+		t.Errorf("Stream status change: got %d, want %d", got, streamConnected)
+	}
+
+	h.Stream.Stop()
+	if got := h.NextStatus(); got != streamTerminated {
+		t.Errorf("Stream status change: got %d, want %d", got, streamTerminated)
+	}
+}
+
+// TestRetryableStreamConnectPermanentFailure checks that a non-retryable error
+// returned for the initial request terminates the stream and is reported as
+// its final error.
+func TestRetryableStreamConnectPermanentFailure(t *testing.T) {
+	h := newTestStreamHandler(t, defaultStreamTimeout)
+	wantErr := status.Error(codes.PermissionDenied, "denied")
+
+	vs := test.NewVerifiers(t)
+	// The single connection attempt is answered with a non-retryable error, so
+	// the stream is expected to terminate without ever connecting.
+	rpcs := test.NewRPCVerifier(t)
+	rpcs.Push(h.InitialReq, nil, wantErr)
+	vs.AddPublishStream(h.Topic.Path, h.Topic.Partition, rpcs)
+
+	mockServer.OnTestStart(vs)
+	defer mockServer.OnTestEnd()
+
+	h.Stream.Start()
+	if got := h.NextStatus(); got != streamReconnecting {
+		t.Errorf("Stream status change: got %d, want %d", got, streamReconnecting)
+	}
+	if got := h.NextStatus(); got != streamTerminated {
+		t.Errorf("Stream status change: got %d, want %d", got, streamTerminated)
+	}
+	if cs := h.Stream.currentStream(); cs != nil {
+		t.Error("Client stream should be nil")
+	}
+	finalErr := h.Stream.Error()
+	if !test.ErrorEqual(finalErr, wantErr) {
+		t.Errorf("Stream final err: got (%v), want (%v)", finalErr, wantErr)
+	}
+}
+
+func TestRetryableStreamConnectTimeout(t *testing.T) {
+ // Set a very low timeout to ensure no retries.
+ timeout := time.Millisecond
+ pub := newTestStreamHandler(t, timeout)
+ wantErr := status.Error(codes.DeadlineExceeded, "timeout")
+
+ verifiers := test.NewVerifiers(t)
+ stream := test.NewRPCVerifier(t)
+ barrier := stream.PushWithBarrier(pub.InitialReq, nil, wantErr)
+ verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
+
+ mockServer.OnTestStart(verifiers)
+ defer mockServer.OnTestEnd()
+
+ pub.Stream.Start()
+ if got, want := pub.NextStatus(), streamReconnecting; got != want {
+ t.Errorf("Stream status change: got %d, want %d", got, want)
+ }
+
+ // Send the initial server response well after the timeout setting.
+ time.Sleep(10 * timeout)
+ barrier.Release()
+
+ if got, want := pub.NextStatus(), streamTerminated; got != want {
+ t.Errorf("Stream status change: got %d, want %d", got, want)
+ }
+ if pub.Stream.currentStream() != nil {
+ t.Error("Client stream should be nil")
+ }
+ if gotErr := pub.Stream.Error(); !test.ErrorEqual(gotErr, wantErr) {
+ t.Errorf("Stream final err: got (%v), want (%v)", gotErr, wantErr)
+ }
+}
+
+// TestRetryableStreamSendReceive checks that Send reports false while the
+// stream is still reconnecting, reports true once it is connected, and that
+// the server's response to the sent request is delivered to the handler.
+func TestRetryableStreamSendReceive(t *testing.T) {
+	h := newTestStreamHandler(t, defaultStreamTimeout)
+	msgReq := msgPubReq(&pb.PubSubMessage{Data: []byte("msg")})
+	wantResp := msgPubResp(5)
+
+	vs := test.NewVerifiers(t)
+	rpcs := test.NewRPCVerifier(t)
+	// The initial response is gated so the test can observe the stream while it
+	// is still reconnecting.
+	gate := rpcs.PushWithBarrier(h.InitialReq, initPubResp(), nil)
+	rpcs.Push(msgReq, wantResp, nil)
+	vs.AddPublishStream(h.Topic.Path, h.Topic.Partition, rpcs)
+
+	mockServer.OnTestStart(vs)
+	defer mockServer.OnTestEnd()
+
+	h.Stream.Start()
+	if got := h.NextStatus(); got != streamReconnecting {
+		t.Errorf("Stream status change: got %d, want %d", got, streamReconnecting)
+	}
+
+	// A request sent before the stream has connected must be rejected.
+	if sent := h.Stream.Send(msgReq); sent {
+		t.Errorf("Stream send: got %v, want %v", sent, false)
+	}
+
+	gate.Release()
+	if got := h.NextStatus(); got != streamConnected {
+		t.Errorf("Stream status change: got %d, want %d", got, streamConnected)
+	}
+
+	// Once connected, the same request is accepted and answered.
+	if sent := h.Stream.Send(msgReq); !sent {
+		t.Errorf("Stream send: got %v, want %v", sent, true)
+	}
+	resp := h.NextResponse()
+	if !testutil.Equal(resp, wantResp) {
+		t.Errorf("Stream response: got %v, want %v", resp, wantResp)
+	}
+
+	h.Stream.Stop()
+	if got := h.NextStatus(); got != streamTerminated {
+		t.Errorf("Stream status change: got %d, want %d", got, streamTerminated)
+	}
+	finalErr := h.Stream.Error()
+	if finalErr != nil {
+		t.Errorf("Stream final err: got (%v), want <nil>", finalErr)
+	}
+}