blob: 6bd4ac8c788abf8d91283418d9e5438b225a03a9 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pscompat
import (
"context"
"errors"
"testing"
"time"
pubsub "cloud.google.com/go/internal/pubsub"
"cloud.google.com/go/internal/testutil"
"cloud.google.com/go/pubsublite/internal/test"
"cloud.google.com/go/pubsublite/internal/wire"
"github.com/google/go-cmp/cmp/cmpopts"
"golang.org/x/sync/errgroup"
tspb "github.com/golang/protobuf/ptypes/timestamp"
pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
)
const (
	// activePartition is the single partition that subscribers created by
	// mockWireSubscriberFactory report as active.
	activePartition = 1

	// defaultSubscriberTestTimeout is the deadline applied to the contexts that
	// the subscriber tests wait on, so a stuck subscriber fails the test rather
	// than hanging it indefinitely.
	defaultSubscriberTestTimeout = 10 * time.Second
)
// mockAckConsumer is a test double for the wire.AckConsumer interface that
// records how many times it has been acked.
type mockAckConsumer struct {
	// AckCount is the number of Ack calls observed so far.
	AckCount int
}

// Ack records one acknowledgement by incrementing AckCount.
func (c *mockAckConsumer) Ack() {
	c.AckCount++
}
// mockWireSubscriber is a test double for the wire.Subscriber interface. Tests
// drive it through DeliverMessages, DeliverReassignment and SimulateFatalError,
// then inspect Stopped and Terminated to see how it was shut down.
type mockWireSubscriber struct {
	// receiver is invoked by the goroutine launched in Start for every message
	// read from msgsC.
	receiver wire.MessageReceiverFunc
	// onReassignment is invoked by DeliverReassignment.
	onReassignment wire.ReassignmentHandlerFunc
	// activePartitions is the set consulted by PartitionActive.
	activePartitions wire.PartitionSet
	// msgsC carries messages from DeliverMessages to the Start goroutine.
	msgsC chan *wire.ReceivedMessage
	// stopC is closed by Stop, Terminate or SimulateFatalError to signal
	// shutdown; WaitStopped blocks until it is closed.
	stopC chan struct{}
	// err is set by SimulateFatalError and returned by WaitStopped.
	err error

	// Stopped is set to true when Stop shuts the subscriber down.
	Stopped bool
	// Terminated is set to true when Terminate shuts the subscriber down.
	Terminated bool
}
// DeliverMessages is called by tests to simulate message delivery. The
// messages are queued on msgsC in the order given; each send blocks while the
// channel's buffer is full.
func (ms *mockWireSubscriber) DeliverMessages(msgs ...*wire.ReceivedMessage) {
	for i := range msgs {
		ms.msgsC <- msgs[i]
	}
}
// DeliverReassignment is called by tests to simulate a partition reassignment.
// It invokes the reassignment handler with the before and after partition
// sets; if the handler returns an error, that error is raised as a fatal
// subscriber error via SimulateFatalError.
func (ms *mockWireSubscriber) DeliverReassignment(before, after wire.PartitionSet) {
	handlerErr := ms.onReassignment(before, after)
	if handlerErr == nil {
		return
	}
	ms.SimulateFatalError(handlerErr)
}
// SimulateFatalError is called by tests to simulate the wire subscriber
// failing from within. It records err so that WaitStopped returns it, and then
// closes stopC to signal shutdown.
//
// Unlike Stop and Terminate, this does not consult or set the Stopped and
// Terminated flags.
func (ms *mockWireSubscriber) SimulateFatalError(err error) {
	// The error must be stored before stopC is closed, so that anything
	// unblocked by the close observes it.
	ms.err = err
	close(ms.stopC)
}
// wire.Subscriber implementation

// Start launches a goroutine that passes every message queued on msgsC to the
// receiver until stopC is closed.
func (ms *mockWireSubscriber) Start() {
	// stopped reports, without blocking, whether stopC has been closed.
	stopped := func() bool {
		select {
		case <-ms.stopC:
			return true
		default:
			return false
		}
	}

	go func() {
		// Checking stopped before each blocking select gives shutdown priority
		// over pending messages: a select with several ready cases picks one at
		// random, so the blocking select alone could keep delivering messages
		// after stopC is closed.
		for !stopped() {
			select {
			case <-ms.stopC:
				return // Exit goroutine
			case received := <-ms.msgsC:
				ms.receiver(received)
			}
		}
	}()
}
// WaitStarted always reports a successful start: it returns nil immediately
// without waiting on anything.
func (ms *mockWireSubscriber) WaitStarted() error {
	return nil
}
// Stop marks the subscriber as Stopped and closes stopC. It does nothing if
// Stop or Terminate has already shut the subscriber down, so stopC is not
// closed twice by these two methods.
func (ms *mockWireSubscriber) Stop() {
	if ms.Stopped || ms.Terminated {
		return
	}
	ms.Stopped = true
	close(ms.stopC)
}
// Terminate marks the subscriber as Terminated and closes stopC. It does
// nothing if Stop or Terminate has already shut the subscriber down, so stopC
// is not closed twice by these two methods.
func (ms *mockWireSubscriber) Terminate() {
	if ms.Stopped || ms.Terminated {
		return
	}
	ms.Terminated = true
	close(ms.stopC)
}
// WaitStopped blocks until stopC has been closed and then returns the error
// recorded by SimulateFatalError, which is nil if no fatal error was simulated.
func (ms *mockWireSubscriber) WaitStopped() error {
	// Receiving from stopC only completes once the channel is closed, because
	// nothing ever sends on it.
	<-ms.stopC
	return ms.err
}
// PartitionActive reports whether partition is a member of the subscriber's
// activePartitions set.
func (ms *mockWireSubscriber) PartitionActive(partition int) bool {
	active := ms.activePartitions.Contains(partition)
	return active
}
// mockWireSubscriberFactory creates mockWireSubscriber instances for tests.
type mockWireSubscriberFactory struct{}

// New returns a fresh mockWireSubscriber wired to the given receiver and
// reassignment handler. The subscriber treats only activePartition as active
// and can buffer up to 10 undelivered messages. ctx is unused, and the
// returned error is always nil.
func (f *mockWireSubscriberFactory) New(ctx context.Context, receiver wire.MessageReceiverFunc, onReassignment wire.ReassignmentHandlerFunc) (wire.Subscriber, error) {
	sub := &mockWireSubscriber{
		receiver:         receiver,
		onReassignment:   onReassignment,
		activePartitions: wire.NewPartitionSet([]int{activePartition}),
		msgsC:            make(chan *wire.ReceivedMessage, 10),
		stopC:            make(chan struct{}),
	}
	return sub, nil
}
// newTestSubscriberInstance creates a subscriberInstance backed by a
// mockWireSubscriberFactory. ctx is passed as the first context argument of
// newSubscriberInstance and context.Background() as the second.
//
// It panics if newSubscriberInstance returns an error, so that a setup failure
// is reported with its cause instead of surfacing later through whatever
// instance value accompanied the error.
func newTestSubscriberInstance(ctx context.Context, settings ReceiveSettings, receiver messageReceiverFunc) *subscriberInstance {
	sub, err := newSubscriberInstance(ctx, context.Background(), new(mockWireSubscriberFactory), settings, receiver)
	if err != nil {
		panic("newSubscriberInstance() failed: " + err.Error())
	}
	return sub
}
// TestSubscriberInstanceTransformMessage verifies that a received wire message
// is converted to the expected pubsub.Message, using either the default
// settings or a custom MessageTransformer, before being passed to the
// receiver. It also checks that acking the message acks the wire message
// exactly once, that a subsequent Nack is ignored, and that cancelling the
// context stops (rather than terminates) the wire subscriber.
func TestSubscriberInstanceTransformMessage(t *testing.T) {
	const partition = 3
	ctx := context.Background()
	input := &pb.SequencedMessage{
		Message: &pb.PubSubMessage{
			Data: []byte("data"),
			Key:  []byte("key"),
			Attributes: map[string]*pb.AttributeValues{
				"attr": {Values: [][]byte{[]byte("value")}},
			},
		},
		Cursor: &pb.Cursor{Offset: 123},
		PublishTime: &tspb.Timestamp{
			Seconds: 1577836800,
			Nanos:   900800700,
		},
	}

	for _, tc := range []struct {
		desc string
		// mutateSettings is passed a copy of DefaultReceiveSettings to mutate.
		mutateSettings func(settings *ReceiveSettings)
		want           *pubsub.Message
	}{
		{
			desc:           "default settings",
			mutateSettings: func(settings *ReceiveSettings) {},
			want: &pubsub.Message{
				Data:        []byte("data"),
				OrderingKey: "key",
				Attributes:  map[string]string{"attr": "value"},
				ID:          "3:123",
				PublishTime: time.Unix(1577836800, 900800700),
			},
		},
		{
			desc: "custom message transformer",
			mutateSettings: func(settings *ReceiveSettings) {
				settings.MessageTransformer = func(from *pb.SequencedMessage, to *pubsub.Message) error {
					// Swaps data and key.
					to.OrderingKey = string(from.Message.Data)
					to.Data = from.Message.Key
					return nil
				}
			},
			want: &pubsub.Message{
				Data:        []byte("key"),
				OrderingKey: "data",
				ID:          "3:123",
			},
		},
	} {
		t.Run(tc.desc, func(t *testing.T) {
			settings := DefaultReceiveSettings
			tc.mutateSettings(&settings)

			ack := &mockAckConsumer{}
			msg := &wire.ReceivedMessage{Msg: input, Ack: ack, Partition: partition}

			cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
			// Release the context's resources even if the receiver is never
			// invoked. Calling a CancelFunc more than once is harmless.
			defer stopSubscriber()

			messageReceiver := func(ctx context.Context, got *pubsub.Message) {
				if diff := testutil.Diff(got, tc.want, cmpopts.IgnoreUnexported(pubsub.Message{}), cmpopts.EquateEmpty()); diff != "" {
					t.Errorf("Received message got: -, want: +\n%s", diff)
				}
				got.Ack()
				got.Nack() // Should be ignored
				stopSubscriber()
			}
			subInstance := newTestSubscriberInstance(cctx, settings, messageReceiver)
			wireSub := subInstance.wireSub.(*mockWireSubscriber)
			wireSub.DeliverMessages(msg)

			if err := subInstance.Wait(cctx); err != nil {
				t.Errorf("subscriberInstance.Wait() got err: %v", err)
			}
			if got, want := ack.AckCount, 1; got != want {
				t.Errorf("mockAckConsumer.AckCount: got %d, want %d", got, want)
			}
			if got, want := subInstance.recvCtx.Err(), context.Canceled; !test.ErrorEqual(got, want) {
				t.Errorf("subscriberInstance.recvCtx.Err(): got (%v), want (%v)", got, want)
			}
			if got, want := wireSub.Stopped, true; got != want {
				t.Errorf("mockWireSubscriber.Stopped: got %v, want %v", got, want)
			}
		})
	}
}
// TestSubscriberInstanceTransformMessageError verifies that a failing or
// misbehaving MessageTransformer (one that returns an error, or one that sets
// the message ID) terminates the subscriber with the expected error, without
// delivering the message to the receiver or acking it.
func TestSubscriberInstanceTransformMessageError(t *testing.T) {
	transformErr := errors.New("message could not be converted")

	for _, tc := range []struct {
		desc        string
		transformer ReceiveMessageTransformerFunc
		wantErr     error
	}{
		{
			desc: "returns error",
			transformer: func(_ *pb.SequencedMessage, _ *pubsub.Message) error {
				return transformErr
			},
			wantErr: transformErr,
		},
		{
			desc: "sets message id",
			transformer: func(_ *pb.SequencedMessage, out *pubsub.Message) error {
				out.ID = "should_not_be_set"
				return nil
			},
			wantErr: errMessageIDSet,
		},
	} {
		t.Run(tc.desc, func(t *testing.T) {
			settings := DefaultReceiveSettings
			settings.MessageTransformer = tc.transformer

			ctx := context.Background()
			ack := &mockAckConsumer{}
			msg := &wire.ReceivedMessage{
				Ack: ack,
				Msg: &pb.SequencedMessage{
					Message: &pb.PubSubMessage{Data: []byte("data")},
				},
			}

			// The CancelFunc must always be called; discarding it leaks the
			// context's resources until the timeout expires.
			cctx, cancel := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
			defer cancel()

			messageReceiver := func(ctx context.Context, got *pubsub.Message) {
				t.Errorf("Received unexpected message: %v", got)
				got.Nack()
			}
			subInstance := newTestSubscriberInstance(cctx, settings, messageReceiver)
			wireSub := subInstance.wireSub.(*mockWireSubscriber)
			wireSub.DeliverMessages(msg)

			if gotErr := subInstance.Wait(cctx); !test.ErrorEqual(gotErr, tc.wantErr) {
				t.Errorf("subscriberInstance.Wait() got err: (%v), want: (%v)", gotErr, tc.wantErr)
			}
			if got, want := ack.AckCount, 0; got != want {
				t.Errorf("mockAckConsumer.AckCount: got %d, want %d", got, want)
			}
			if got, want := subInstance.recvCtx.Err(), context.Canceled; !test.ErrorEqual(got, want) {
				t.Errorf("subscriberInstance.recvCtx.Err(): got (%v), want (%v)", got, want)
			}
			if got, want := wireSub.Terminated, true; got != want {
				t.Errorf("mockWireSubscriber.Terminated: got %v, want %v", got, want)
			}
		})
	}
}
// TestSubscriberInstanceNack verifies how the subscriber reacts when the
// receiver nacks a message: with the default settings, with the message on an
// inactive partition, and with a custom NackHandler that returns nil or an
// error. Each case checks the error returned by Wait, the number of acks, and
// whether the wire subscriber was stopped or terminated.
func TestSubscriberInstanceNack(t *testing.T) {
	nackErr := errors.New("message nacked")
	ctx := context.Background()
	msg := &pb.SequencedMessage{
		Message: &pb.PubSubMessage{
			Data: []byte("data"),
			Key:  []byte("key"),
		},
	}

	for _, tc := range []struct {
		desc string
		// mutateSettings is passed a copy of DefaultReceiveSettings to mutate.
		mutateSettings func(settings *ReceiveSettings)
		msgPartition   int
		wantErr        error
		wantAckCount   int
		wantStopped    bool
		wantTerminated bool
	}{
		{
			desc:           "default settings",
			mutateSettings: func(settings *ReceiveSettings) {},
			msgPartition:   activePartition,
			wantErr:        errNackCalled,
			wantAckCount:   0,
			wantTerminated: true,
		},
		{
			desc:           "message partition inactive",
			mutateSettings: func(settings *ReceiveSettings) {},
			msgPartition:   activePartition + 1,
			wantErr:        nil,
			wantAckCount:   0,
			wantStopped:    true,
		},
		{
			desc: "nack handler returns nil",
			mutateSettings: func(settings *ReceiveSettings) {
				settings.NackHandler = func(_ *pubsub.Message) error {
					return nil
				}
			},
			msgPartition: activePartition,
			wantErr:      nil,
			wantAckCount: 1,
			wantStopped:  true,
		},
		{
			desc: "nack handler returns error",
			mutateSettings: func(settings *ReceiveSettings) {
				settings.NackHandler = func(_ *pubsub.Message) error {
					return nackErr
				}
			},
			msgPartition:   activePartition,
			wantErr:        nackErr,
			wantAckCount:   0,
			wantTerminated: true,
		},
	} {
		t.Run(tc.desc, func(t *testing.T) {
			settings := DefaultReceiveSettings
			tc.mutateSettings(&settings)

			ack := &mockAckConsumer{}
			received := &wire.ReceivedMessage{Msg: msg, Ack: ack, Partition: tc.msgPartition}

			cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
			// The receiver only cancels the context when no error is expected,
			// so make sure the CancelFunc is called in every case. Calling it
			// more than once is harmless.
			defer stopSubscriber()

			messageReceiver := func(ctx context.Context, got *pubsub.Message) {
				got.Nack()

				// Only need to stop the subscriber when the nack handler actually acks
				// the message. For other cases, the subscriber is forcibly terminated.
				if tc.wantErr == nil {
					stopSubscriber()
				}
			}
			subInstance := newTestSubscriberInstance(cctx, settings, messageReceiver)
			wireSub := subInstance.wireSub.(*mockWireSubscriber)
			wireSub.DeliverMessages(received)

			if gotErr := subInstance.Wait(cctx); !test.ErrorEqual(gotErr, tc.wantErr) {
				t.Errorf("subscriberInstance.Wait() got err: (%v), want: (%v)", gotErr, tc.wantErr)
			}
			if got, want := ack.AckCount, tc.wantAckCount; got != want {
				t.Errorf("mockAckConsumer.AckCount: got %d, want %d", got, want)
			}
			if got, want := subInstance.recvCtx.Err(), context.Canceled; !test.ErrorEqual(got, want) {
				t.Errorf("subscriberInstance.recvCtx.Err(): got (%v), want (%v)", got, want)
			}
			if got, want := wireSub.Stopped, tc.wantStopped; got != want {
				t.Errorf("mockWireSubscriber.Stopped: got %v, want %v", got, want)
			}
			if got, want := wireSub.Terminated, tc.wantTerminated; got != want {
				t.Errorf("mockWireSubscriber.Terminated: got %v, want %v", got, want)
			}
		})
	}
}
// TestSubscriberInstanceReassignmentHandler verifies the handling of partition
// reassignments: with the default settings, with a ReassignmentHandler that
// returns nil (which also checks that the partition numbers it receives are
// sorted), and with a handler that returns an error, in which case Wait is
// expected to return that error.
func TestSubscriberInstanceReassignmentHandler(t *testing.T) {
	reassignmentErr := errors.New("reassignment failure")
	before := wire.NewPartitionSet([]int{3, 2, 1})
	after := wire.NewPartitionSet([]int{4, 5, 3})
	ctx := context.Background()

	for _, tc := range []struct {
		desc string
		// mutateSettings is passed a copy of DefaultReceiveSettings to mutate.
		mutateSettings func(settings *ReceiveSettings)
		wantErr        error
	}{
		{
			desc:           "default settings",
			mutateSettings: func(settings *ReceiveSettings) {},
		},
		{
			desc: "reassignment handler returns nil",
			mutateSettings: func(settings *ReceiveSettings) {
				settings.ReassignmentHandler = func(before, after []int) error {
					if got, want := before, []int{1, 2, 3}; !testutil.Equal(got, want) {
						t.Errorf("before: got %d, want %d", got, want)
					}
					if got, want := after, []int{3, 4, 5}; !testutil.Equal(got, want) {
						t.Errorf("after: got %d, want %d", got, want)
					}
					return nil
				}
			},
		},
		{
			desc: "reassignment handler returns error",
			mutateSettings: func(settings *ReceiveSettings) {
				settings.ReassignmentHandler = func(before, after []int) error {
					return reassignmentErr
				}
			},
			wantErr: reassignmentErr,
		},
	} {
		t.Run(tc.desc, func(t *testing.T) {
			settings := DefaultReceiveSettings
			tc.mutateSettings(&settings)

			cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
			// The context is only cancelled explicitly when no error is
			// expected, so make sure the CancelFunc is called in every case.
			// Calling it more than once is harmless.
			defer stopSubscriber()

			messageReceiver := func(ctx context.Context, got *pubsub.Message) {
				t.Error("Message receiver should not be called")
			}
			subInstance := newTestSubscriberInstance(cctx, settings, messageReceiver)
			subInstance.wireSub.(*mockWireSubscriber).DeliverReassignment(before, after)

			// On success nothing else ends the subscriber, so stop it here. On
			// failure the handler's error is expected to end it.
			if tc.wantErr == nil {
				stopSubscriber()
			}

			if gotErr := subInstance.Wait(cctx); !test.ErrorEqual(gotErr, tc.wantErr) {
				t.Errorf("subscriberInstance.Wait() got err: (%v), want err: (%v)", gotErr, tc.wantErr)
			}
		})
	}
}
// TestSubscriberInstanceWireSubscriberFails verifies that when the wire
// subscriber fails from within, Wait returns the fatal error, the receive
// context is cancelled so that in-flight receivers are notified, and neither
// Stop nor Terminate takes effect on the wire subscriber.
func TestSubscriberInstanceWireSubscriberFails(t *testing.T) {
	fatalErr := errors.New("server error")
	ctx := context.Background()
	msg := &wire.ReceivedMessage{
		Ack: &mockAckConsumer{},
		Msg: &pb.SequencedMessage{
			Message: &pb.PubSubMessage{Data: []byte("data")},
		},
	}

	// The CancelFunc must always be called; discarding it leaks the context's
	// resources until the timeout expires.
	cctx, cancel := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
	defer cancel()

	messageReceiver := func(ctx context.Context, got *pubsub.Message) {
		// Verifies that receivers are notified via ctx.Done when the subscriber is
		// shutting down.
		select {
		case <-time.After(defaultSubscriberTestTimeout):
			t.Errorf("MessageReceiverFunc context not closed within %v", defaultSubscriberTestTimeout)
		case <-ctx.Done():
		}
	}
	subInstance := newTestSubscriberInstance(cctx, DefaultReceiveSettings, messageReceiver)
	wireSub := subInstance.wireSub.(*mockWireSubscriber)
	wireSub.DeliverMessages(msg)

	time.AfterFunc(100*time.Millisecond, func() {
		// Simulates a fatal server error that causes the wire subscriber to
		// terminate from within.
		wireSub.SimulateFatalError(fatalErr)
	})

	if gotErr := subInstance.Wait(cctx); !test.ErrorEqual(gotErr, fatalErr) {
		t.Errorf("subscriberInstance.Wait() got err: (%v), want: (%v)", gotErr, fatalErr)
	}
	if got, want := subInstance.recvCtx.Err(), context.Canceled; !test.ErrorEqual(got, want) {
		t.Errorf("subscriberInstance.recvCtx.Err(): got (%v), want (%v)", got, want)
	}
	if got, want := wireSub.Stopped, false; got != want {
		t.Errorf("mockWireSubscriber.Stopped: got %v, want %v", got, want)
	}
	if got, want := wireSub.Terminated, false; got != want {
		t.Errorf("mockWireSubscriber.Terminated: got %v, want %v", got, want)
	}
}
// TestSubscriberClientDuplicateReceive verifies that calling Receive several
// times concurrently on the same SubscriberClient results in
// errDuplicateReceive.
func TestSubscriberClientDuplicateReceive(t *testing.T) {
	const numReceivers = 3

	client := &SubscriberClient{
		settings:       DefaultReceiveSettings,
		wireSubFactory: new(mockWireSubscriberFactory),
	}
	unexpectedMessage := func(_ context.Context, got *pubsub.Message) {
		t.Errorf("No messages expected, got: %v", got)
	}

	group, groupCtx := errgroup.WithContext(context.Background())
	for n := 0; n < numReceivers; n++ {
		// Receive() blocks, so each call runs in its own goroutine. The group's
		// context is cancelled as soon as one call returns an error, which
		// stops the remaining subscribers.
		group.Go(func() error {
			return client.Receive(groupCtx, unexpectedMessage)
		})
	}

	gotErr := group.Wait()
	if wantErr := errDuplicateReceive; !test.ErrorEqual(gotErr, wantErr) {
		t.Errorf("SubscriberClient.Receive() got err: (%v), want: (%v)", gotErr, wantErr)
	}
}