// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pscompat

import (
	"context"
	"errors"
	"testing"
	"time"

	pubsub "cloud.google.com/go/internal/pubsub"
	"cloud.google.com/go/internal/testutil"
	"cloud.google.com/go/pubsublite/internal/test"
	"cloud.google.com/go/pubsublite/internal/wire"
	"github.com/google/go-cmp/cmp/cmpopts"
	"golang.org/x/sync/errgroup"

	tspb "github.com/golang/protobuf/ptypes/timestamp"
	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
)

// Shared fixtures for the subscriber tests in this file.
const (
	// activePartition is the single partition that subscribers created by
	// mockWireSubscriberFactory report as active.
	activePartition = 1

	// defaultSubscriberTestTimeout is the deadline applied to the contexts that
	// drive the subscribers under test.
	defaultSubscriberTestTimeout = time.Second * 10
)

// mockAckConsumer is a test double for the wire.AckConsumer interface. Rather
// than forwarding acknowledgements anywhere, it counts them.
type mockAckConsumer struct {
	// AckCount is the number of times Ack has been called.
	AckCount int
}

// Ack records one more acknowledgement.
func (c *mockAckConsumer) Ack() {
	c.AckCount++
}

// mockWireSubscriber is a mock implementation of the wire.Subscriber interface.
// Tests drive it through DeliverMessages, DeliverReassignment and
// SimulateFatalError.
type mockWireSubscriber struct {
	// receiver is invoked for each message taken off msgsC.
	receiver wire.MessageReceiverFunc

	// onReassignment is invoked by DeliverReassignment.
	onReassignment wire.ReassignmentHandlerFunc

	// activePartitions is the set consulted by PartitionActive.
	activePartitions wire.PartitionSet

	// msgsC queues messages passed to DeliverMessages until the goroutine
	// launched by Start picks them up.
	msgsC chan *wire.ReceivedMessage

	// stopC is closed by Stop, Terminate or SimulateFatalError to signal that
	// the subscriber has stopped.
	stopC chan struct{}

	// err is the error set by SimulateFatalError and returned by WaitStopped.
	err error

	// Stopped is set when Stop takes effect.
	Stopped bool

	// Terminated is set when Terminate takes effect.
	Terminated bool
}

// DeliverMessages simulates message delivery. Tests call it to queue msgs, in
// order, on msgsC for the goroutine launched by Start to hand to the receiver.
// Each send blocks while msgsC is full.
func (ms *mockWireSubscriber) DeliverMessages(msgs ...*wire.ReceivedMessage) {
	for i := range msgs {
		ms.msgsC <- msgs[i]
	}
}

// DeliverReassignment should be called from the test to simulate a partition
// reassignment. It invokes the reassignment handler with the before and after
// partition sets. If the handler returns an error, that error is reported as a
// fatal wire subscriber error via SimulateFatalError.
func (ms *mockWireSubscriber) DeliverReassignment(before, after wire.PartitionSet) {
	if err := ms.onReassignment(before, after); err != nil {
		ms.SimulateFatalError(err)
	}
}

// SimulateFatalError should be called from the test to simulate a fatal error
// occurring in the wire subscriber. The error is later returned by WaitStopped.
//
// It closes stopC without setting Stopped or Terminated, so it must be called
// at most once, and not after Stop or Terminate has already closed the
// channel: closing a closed channel panics.
func (ms *mockWireSubscriber) SimulateFatalError(err error) {
	// Record the error before closing stopC so that WaitStopped, which unblocks
	// on the close, observes it.
	ms.err = err
	close(ms.stopC)
}

// wire.Subscriber implementation

// Start launches a goroutine that passes every message queued on msgsC to the
// receiver. The goroutine exits once stopC has been closed.
func (ms *mockWireSubscriber) Start() {
	// stopRequested reports, without blocking, whether stopC has been closed.
	stopRequested := func() bool {
		select {
		case <-ms.stopC:
			return true
		default:
			return false
		}
	}

	go func() {
		// Polling stopC ahead of each blocking select gives stop priority over a
		// pending message when both are ready.
		for !stopRequested() {
			select {
			case msg := <-ms.msgsC:
				ms.receiver(msg)
			case <-ms.stopC:
				return // Exit goroutine
			}
		}
	}()
}

// WaitStarted always returns nil immediately; the mock has no startup phase
// that can fail.
func (ms *mockWireSubscriber) WaitStarted() error {
	return nil
}

// Stop marks the subscriber as Stopped and closes stopC. It does nothing if
// Stop or Terminate has already taken effect, so these two methods close stopC
// at most once between them.
func (ms *mockWireSubscriber) Stop() {
	if ms.Stopped || ms.Terminated {
		return
	}
	ms.Stopped = true
	close(ms.stopC)
}

// Terminate marks the subscriber as Terminated and closes stopC. It does
// nothing if Stop or Terminate has already taken effect, so these two methods
// close stopC at most once between them.
func (ms *mockWireSubscriber) Terminate() {
	if ms.Stopped || ms.Terminated {
		return
	}
	ms.Terminated = true
	close(ms.stopC)
}

// WaitStopped blocks until stopC is closed by Stop, Terminate or
// SimulateFatalError. It then returns the error recorded by
// SimulateFatalError, or nil if none was recorded.
func (ms *mockWireSubscriber) WaitStopped() error {
	<-ms.stopC // Blocks until stopC is closed.
	return ms.err
}

// PartitionActive reports whether partition is a member of the mock's
// activePartitions set.
func (ms *mockWireSubscriber) PartitionActive(partition int) bool {
	return ms.activePartitions.Contains(partition)
}

// mockWireSubscriberFactory creates mockWireSubscribers in place of real wire
// subscribers.
type mockWireSubscriberFactory struct{}

// New returns a mockWireSubscriber wired to the given callbacks, with
// activePartition as its only active partition and room for 10 queued
// messages. ctx is unused and the returned error is always nil.
func (f *mockWireSubscriberFactory) New(ctx context.Context, receiver wire.MessageReceiverFunc, onReassignment wire.ReassignmentHandlerFunc) (wire.Subscriber, error) {
	ms := new(mockWireSubscriber)
	ms.receiver = receiver
	ms.onReassignment = onReassignment
	ms.activePartitions = wire.NewPartitionSet([]int{activePartition})
	ms.msgsC = make(chan *wire.ReceivedMessage, 10)
	ms.stopC = make(chan struct{})
	return ms, nil
}

// newTestSubscriberInstance builds a subscriberInstance backed by a
// mockWireSubscriberFactory. ctx, settings and receiver are forwarded to
// newSubscriberInstance unchanged.
//
// It panics if newSubscriberInstance returns an error. Callers dereference the
// result immediately, so failing here with the underlying error is clearer
// than a later nil pointer dereference.
func newTestSubscriberInstance(ctx context.Context, settings ReceiveSettings, receiver messageReceiverFunc) *subscriberInstance {
	sub, err := newSubscriberInstance(ctx, context.Background(), new(mockWireSubscriberFactory), settings, receiver)
	if err != nil {
		panic("newSubscriberInstance() failed: " + err.Error())
	}
	return sub
}

// TestSubscriberInstanceTransformMessage checks how a delivered wire message is
// converted into the pubsub.Message handed to the receiver, both with the
// default settings and with a custom MessageTransformer. It also checks that
// acking the message reaches the wire ack consumer exactly once (the
// subsequent Nack is ignored) and that cancelling the context stops the wire
// subscriber.
func TestSubscriberInstanceTransformMessage(t *testing.T) {
	const partition = 3
	ctx := context.Background()

	// input is the wire-level message delivered in every test case.
	input := &pb.SequencedMessage{
		Message: &pb.PubSubMessage{
			Data: []byte("data"),
			Key:  []byte("key"),
			Attributes: map[string]*pb.AttributeValues{
				"attr": {Values: [][]byte{[]byte("value")}},
			},
		},
		Cursor: &pb.Cursor{Offset: 123},
		PublishTime: &tspb.Timestamp{
			Seconds: 1577836800,
			Nanos:   900800700,
		},
	}

	testCases := []struct {
		desc string
		// mutateSettings is passed a copy of DefaultReceiveSettings to mutate.
		mutateSettings func(settings *ReceiveSettings)
		want           *pubsub.Message
	}{
		{
			desc:           "default settings",
			mutateSettings: func(settings *ReceiveSettings) {},
			want: &pubsub.Message{
				Data:        []byte("data"),
				OrderingKey: "key",
				Attributes:  map[string]string{"attr": "value"},
				ID:          "3:123",
				PublishTime: time.Unix(1577836800, 900800700),
			},
		},
		{
			desc: "custom message transformer",
			mutateSettings: func(settings *ReceiveSettings) {
				settings.MessageTransformer = func(from *pb.SequencedMessage, to *pubsub.Message) error {
					// Swaps data and key.
					to.OrderingKey = string(from.Message.Data)
					to.Data = from.Message.Key
					return nil
				}
			},
			want: &pubsub.Message{
				Data:        []byte("key"),
				OrderingKey: "data",
				ID:          "3:123",
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.desc, func(t *testing.T) {
			recvSettings := DefaultReceiveSettings
			tc.mutateSettings(&recvSettings)

			ackConsumer := new(mockAckConsumer)
			cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
			onMessage := func(ctx context.Context, got *pubsub.Message) {
				diff := testutil.Diff(got, tc.want, cmpopts.IgnoreUnexported(pubsub.Message{}), cmpopts.EquateEmpty())
				if diff != "" {
					t.Errorf("Received message got: -, want: +\n%s", diff)
				}
				got.Ack()
				got.Nack() // Should be ignored
				stopSubscriber()
			}

			subInstance := newTestSubscriberInstance(cctx, recvSettings, onMessage)
			wireSub := subInstance.wireSub.(*mockWireSubscriber)
			wireSub.DeliverMessages(&wire.ReceivedMessage{Msg: input, Ack: ackConsumer, Partition: partition})

			if err := subInstance.Wait(cctx); err != nil {
				t.Errorf("subscriberInstance.Wait() got err: %v", err)
			}
			if got, want := ackConsumer.AckCount, 1; got != want {
				t.Errorf("mockAckConsumer.AckCount: got %d, want %d", got, want)
			}
			if got, want := subInstance.recvCtx.Err(), context.Canceled; !test.ErrorEqual(got, want) {
				t.Errorf("subscriberInstance.recvCtx.Err(): got (%v), want (%v)", got, want)
			}
			if got, want := wireSub.Stopped, true; got != want {
				t.Errorf("mockWireSubscriber.Stopped: got %v, want %v", got, want)
			}
		})
	}
}

// TestSubscriberInstanceTransformMessageError verifies that a failing message
// transformation shuts the subscriber down: Wait returns the expected error,
// the receiver is never invoked, the message is not acked, and the wire
// subscriber is terminated. A transformer fails either by returning an error
// or by setting the message ID.
func TestSubscriberInstanceTransformMessageError(t *testing.T) {
	transformErr := errors.New("message could not be converted")

	for _, tc := range []struct {
		desc        string
		transformer ReceiveMessageTransformerFunc
		wantErr     error
	}{
		{
			desc: "returns error",
			transformer: func(_ *pb.SequencedMessage, _ *pubsub.Message) error {
				return transformErr
			},
			wantErr: transformErr,
		},
		{
			desc: "sets message id",
			transformer: func(_ *pb.SequencedMessage, out *pubsub.Message) error {
				out.ID = "should_not_be_set"
				return nil
			},
			wantErr: errMessageIDSet,
		},
	} {
		t.Run(tc.desc, func(t *testing.T) {
			settings := DefaultReceiveSettings
			settings.MessageTransformer = tc.transformer

			ctx := context.Background()
			ack := &mockAckConsumer{}
			msg := &wire.ReceivedMessage{
				Ack: ack,
				Msg: &pb.SequencedMessage{
					Message: &pb.PubSubMessage{Data: []byte("data")},
				},
			}

			// Nothing else in this test cancels cctx, so release its timer when the
			// subtest returns instead of leaving it until the timeout fires.
			cctx, cancel := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
			defer cancel()

			messageReceiver := func(ctx context.Context, got *pubsub.Message) {
				t.Errorf("Received unexpected message: %v", got)
				got.Nack()
			}
			subInstance := newTestSubscriberInstance(cctx, settings, messageReceiver)
			wireSub := subInstance.wireSub.(*mockWireSubscriber)
			wireSub.DeliverMessages(msg)

			if gotErr := subInstance.Wait(cctx); !test.ErrorEqual(gotErr, tc.wantErr) {
				t.Errorf("subscriberInstance.Wait() got err: (%v), want: (%v)", gotErr, tc.wantErr)
			}
			if got, want := ack.AckCount, 0; got != want {
				t.Errorf("mockAckConsumer.AckCount: got %d, want %d", got, want)
			}
			if got, want := subInstance.recvCtx.Err(), context.Canceled; !test.ErrorEqual(got, want) {
				t.Errorf("subscriberInstance.recvCtx.Err(): got (%v), want (%v)", got, want)
			}
			if got, want := wireSub.Terminated, true; got != want {
				t.Errorf("mockWireSubscriber.Terminated: got %v, want %v", got, want)
			}
		})
	}
}

// TestSubscriberInstanceNack verifies what happens when the receiver nacks a
// message: with the default settings the subscriber terminates with
// errNackCalled; a nack for a message on an inactive partition produces no
// error and no ack; a NackHandler that returns nil causes the message to be
// acked; and a NackHandler that returns an error terminates the subscriber
// with that error.
func TestSubscriberInstanceNack(t *testing.T) {
	nackErr := errors.New("message nacked")

	ctx := context.Background()
	seqMsg := &pb.SequencedMessage{
		Message: &pb.PubSubMessage{
			Data: []byte("data"),
			Key:  []byte("key"),
		},
	}

	for _, tc := range []struct {
		desc string
		// mutateSettings is passed a copy of DefaultReceiveSettings to mutate.
		mutateSettings func(settings *ReceiveSettings)
		msgPartition   int
		wantErr        error
		wantAckCount   int
		wantStopped    bool
		wantTerminated bool
	}{
		{
			desc:           "default settings",
			mutateSettings: func(settings *ReceiveSettings) {},
			msgPartition:   activePartition,
			wantErr:        errNackCalled,
			wantAckCount:   0,
			wantTerminated: true,
		},
		{
			desc:           "message partition inactive",
			mutateSettings: func(settings *ReceiveSettings) {},
			msgPartition:   activePartition + 1,
			wantErr:        nil,
			wantAckCount:   0,
			wantStopped:    true,
		},
		{
			desc: "nack handler returns nil",
			mutateSettings: func(settings *ReceiveSettings) {
				settings.NackHandler = func(_ *pubsub.Message) error {
					return nil
				}
			},
			msgPartition: activePartition,
			wantErr:      nil,
			wantAckCount: 1,
			wantStopped:  true,
		},
		{
			desc: "nack handler returns error",
			mutateSettings: func(settings *ReceiveSettings) {
				settings.NackHandler = func(_ *pubsub.Message) error {
					return nackErr
				}
			},
			msgPartition:   activePartition,
			wantErr:        nackErr,
			wantAckCount:   0,
			wantTerminated: true,
		},
	} {
		t.Run(tc.desc, func(t *testing.T) {
			settings := DefaultReceiveSettings
			tc.mutateSettings(&settings)

			ack := &mockAckConsumer{}
			msg := &wire.ReceivedMessage{Msg: seqMsg, Ack: ack, Partition: tc.msgPartition}

			cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
			// The receiver below only cancels cctx on the success paths. Deferring
			// the cancel releases the timeout's timer on the error paths too. By
			// then the mock has been stopped or terminated, so a late Stop is a
			// no-op.
			defer stopSubscriber()

			messageReceiver := func(ctx context.Context, got *pubsub.Message) {
				got.Nack()

				// Only need to stop the subscriber when the nack handler actually acks
				// the message. For other cases, the subscriber is forcibly terminated.
				if tc.wantErr == nil {
					stopSubscriber()
				}
			}
			subInstance := newTestSubscriberInstance(cctx, settings, messageReceiver)
			wireSub := subInstance.wireSub.(*mockWireSubscriber)
			wireSub.DeliverMessages(msg)

			if gotErr := subInstance.Wait(cctx); !test.ErrorEqual(gotErr, tc.wantErr) {
				t.Errorf("subscriberInstance.Wait() got err: (%v), want: (%v)", gotErr, tc.wantErr)
			}
			if got, want := ack.AckCount, tc.wantAckCount; got != want {
				t.Errorf("mockAckConsumer.AckCount: got %d, want %d", got, want)
			}
			if got, want := subInstance.recvCtx.Err(), context.Canceled; !test.ErrorEqual(got, want) {
				t.Errorf("subscriberInstance.recvCtx.Err(): got (%v), want (%v)", got, want)
			}
			if got, want := wireSub.Stopped, tc.wantStopped; got != want {
				t.Errorf("mockWireSubscriber.Stopped: got %v, want %v", got, want)
			}
			if got, want := wireSub.Terminated, tc.wantTerminated; got != want {
				t.Errorf("mockWireSubscriber.Terminated: got %v, want %v", got, want)
			}
		})
	}
}

// TestSubscriberInstanceReassignmentHandler verifies how a partition
// reassignment is surfaced through ReceiveSettings.ReassignmentHandler: the
// handler sees the before and after partitions as ascending int slices, a nil
// result leaves the subscriber healthy, and a returned error is reported by
// Wait.
func TestSubscriberInstanceReassignmentHandler(t *testing.T) {
	ctx := context.Background()
	reassignmentErr := errors.New("reassignment failure")

	// The partition sets are built from unsorted input; the handler below
	// expects to observe them in ascending order.
	before := wire.NewPartitionSet([]int{3, 2, 1})
	after := wire.NewPartitionSet([]int{4, 5, 3})

	testCases := []struct {
		desc string
		// mutateSettings is passed a copy of DefaultReceiveSettings to mutate.
		mutateSettings func(settings *ReceiveSettings)
		wantErr        error
	}{
		{
			desc:           "default settings",
			mutateSettings: func(settings *ReceiveSettings) {},
		},
		{
			desc: "reassignment handler returns nil",
			mutateSettings: func(settings *ReceiveSettings) {
				settings.ReassignmentHandler = func(before, after []int) error {
					if got, want := before, []int{1, 2, 3}; !testutil.Equal(got, want) {
						t.Errorf("before: got %d, want %d", got, want)
					}
					if got, want := after, []int{3, 4, 5}; !testutil.Equal(got, want) {
						t.Errorf("after: got %d, want %d", got, want)
					}
					return nil
				}
			},
		},
		{
			desc: "reassignment handler returns error",
			mutateSettings: func(settings *ReceiveSettings) {
				settings.ReassignmentHandler = func(before, after []int) error {
					return reassignmentErr
				}
			},
			wantErr: reassignmentErr,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.desc, func(t *testing.T) {
			recvSettings := DefaultReceiveSettings
			tc.mutateSettings(&recvSettings)

			cctx, stopSubscriber := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
			onMessage := func(ctx context.Context, got *pubsub.Message) {
				t.Error("Message receiver should not be called")
			}
			subInstance := newTestSubscriberInstance(cctx, recvSettings, onMessage)
			wireSub := subInstance.wireSub.(*mockWireSubscriber)
			wireSub.DeliverReassignment(before, after)

			// A successful reassignment leaves the subscriber running, so it has to
			// be stopped explicitly. A failed one has already closed the mock's
			// stopC via SimulateFatalError.
			//
			// NOTE(review): stopSubscriber is not deferred here. In the failure case
			// the mock's stopC is closed while Stopped and Terminated are both
			// false, so a Stop triggered by a late cancel would close it again and
			// panic. Confirm against subscriberInstance.Wait before changing this.
			if tc.wantErr == nil {
				stopSubscriber()
			}

			gotErr := subInstance.Wait(cctx)
			if !test.ErrorEqual(gotErr, tc.wantErr) {
				t.Errorf("subscriberInstance.Wait() got err: (%v), want err: (%v)", gotErr, tc.wantErr)
			}
		})
	}
}

// TestSubscriberInstanceWireSubscriberFails verifies the shutdown path when
// the wire subscriber fails on its own: Wait returns the fatal error, a
// receiver that is still running is notified through its context, and the
// subscriber instance neither stops nor terminates the already-failed wire
// subscriber.
func TestSubscriberInstanceWireSubscriberFails(t *testing.T) {
	fatalErr := errors.New("server error")

	ctx := context.Background()
	msg := &wire.ReceivedMessage{
		Ack: &mockAckConsumer{},
		Msg: &pb.SequencedMessage{
			Message: &pb.PubSubMessage{Data: []byte("data")},
		},
	}

	// Nothing in this test cancels cctx, so release its timer when the test
	// returns instead of leaving it until the timeout fires.
	cctx, cancel := context.WithTimeout(ctx, defaultSubscriberTestTimeout)
	defer cancel()

	messageReceiver := func(ctx context.Context, got *pubsub.Message) {
		// Verifies that receivers are notified via ctx.Done when the subscriber is
		// shutting down.
		select {
		case <-time.After(defaultSubscriberTestTimeout):
			t.Errorf("MessageReceiverFunc context not closed within %v", defaultSubscriberTestTimeout)
		case <-ctx.Done():
		}
	}
	subInstance := newTestSubscriberInstance(cctx, DefaultReceiveSettings, messageReceiver)
	wireSub := subInstance.wireSub.(*mockWireSubscriber)
	wireSub.DeliverMessages(msg)
	time.AfterFunc(100*time.Millisecond, func() {
		// Simulates a fatal server error that causes the wire subscriber to
		// terminate from within.
		wireSub.SimulateFatalError(fatalErr)
	})

	if gotErr := subInstance.Wait(cctx); !test.ErrorEqual(gotErr, fatalErr) {
		t.Errorf("subscriberInstance.Wait() got err: (%v), want: (%v)", gotErr, fatalErr)
	}
	if got, want := subInstance.recvCtx.Err(), context.Canceled; !test.ErrorEqual(got, want) {
		t.Errorf("subscriberInstance.recvCtx.Err(): got (%v), want (%v)", got, want)
	}
	if got, want := wireSub.Stopped, false; got != want {
		t.Errorf("mockWireSubscriber.Stopped: got %v, want %v", got, want)
	}
	if got, want := wireSub.Terminated, false; got != want {
		t.Errorf("mockWireSubscriber.Terminated: got %v, want %v", got, want)
	}
}

// TestSubscriberClientDuplicateReceive verifies that concurrent Receive calls
// on one SubscriberClient are rejected: several calls are started at once and
// the first error reported by the group must be errDuplicateReceive.
func TestSubscriberClientDuplicateReceive(t *testing.T) {
	const numReceivers = 3

	subClient := &SubscriberClient{
		settings:       DefaultReceiveSettings,
		wireSubFactory: new(mockWireSubscriberFactory),
	}
	onMessage := func(_ context.Context, got *pubsub.Message) {
		t.Errorf("No messages expected, got: %v", got)
	}

	// Receive() is blocking, so each call runs in its own goroutine. The group
	// context is cancelled as soon as one call returns an error, which stops
	// the subscribers that are still running.
	g, gctx := errgroup.WithContext(context.Background())
	for n := 0; n < numReceivers; n++ {
		g.Go(func() error {
			return subClient.Receive(gctx, onMessage)
		})
	}

	gotErr := g.Wait()
	if wantErr := errDuplicateReceive; !test.ErrorEqual(gotErr, wantErr) {
		t.Errorf("SubscriberClient.Receive() got err: (%v), want: (%v)", gotErr, wantErr)
	}
}
