// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package syslog

import (
	"context"
	"errors"
	"io"
	"os"
	"runtime/pprof"
	"sync"
	"testing"
	"time"

	"go.fuchsia.dev/fuchsia/tools/lib/logger"
	"go.fuchsia.dev/fuchsia/tools/lib/retry"
	"go.fuchsia.dev/fuchsia/tools/net/sshutil"
)

// testTimeout bounds how long any single blocking step awaited via
// runWithTimeout or assertClosed may take before the test is failed.
const testTimeout time.Duration = 1 * time.Second

// fakeSSHClient is a test double for the client used by Syslogger. Tests
// drive it through its channels to control what Run() writes and returns,
// and observe it via the reconnected channel and listenForRun().
type fakeSSHClient struct {
	// Run() will listen on this channel and return any error that it receives.
	mockRunErrs chan error
	// Run() will listen on this channel and send any received data to the
	// `stdout` writer.
	mockRunStdout chan string
	// ReconnectWithBackoff() will close this channel.
	reconnected chan struct{}

	// mu controls access to the `runListener` field.
	mu sync.Mutex
	// Channel that will be closed (and the field reset to nil) the next time
	// Run() is called; nil when nobody has called listenForRun().
	runListener chan struct{}
}

// Run fakes running a remote command. It first closes (and clears) any
// channel registered via listenForRun, then loops: every string received on
// mockRunStdout is written to stdout, and the method returns as soon as an
// error (possibly nil) arrives on mockRunErrs or ctx is done. The command and
// stderr arguments are ignored.
func (c *fakeSSHClient) Run(ctx context.Context, _ []string, stdout, _ io.Writer) error {
	c.mu.Lock()
	if listener := c.runListener; listener != nil {
		c.runListener = nil
		close(listener)
	}
	c.mu.Unlock()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case err := <-c.mockRunErrs:
			return err
		case out := <-c.mockRunStdout:
			// The write result is deliberately ignored; the fake only forwards data.
			stdout.Write([]byte(out))
		}
	}
}

// listenForRun returns a channel that will be closed the next time
// client.Run() is called. Callers that register before that Run() call all
// receive the same channel.
func (c *fakeSSHClient) listenForRun() <-chan struct{} {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.runListener == nil {
		c.runListener = make(chan struct{})
	}
	// Read the field while still holding the lock. A concurrent Run() could
	// otherwise close it and reset the field to nil right after we unlock,
	// and returning a nil channel would make the caller block forever.
	return c.runListener
}

// ReconnectWithBackoff records that a reconnect was attempted by closing
// c.reconnected, and always reports success. The context and backoff are
// ignored. Because closing an already-closed channel panics, a second call on
// the same client panics.
func (c *fakeSSHClient) ReconnectWithBackoff(_ context.Context, _ retry.Backoff) error {
	signal := c.reconnected
	close(signal)
	return nil
}

// startStream builds a Syslogger around a fresh fakeSSHClient and begins
// streaming from it. It returns:
//   - the fake client, so the test can feed it stdout and errors;
//   - an io.Reader carrying whatever Stream() writes;
//   - the channel on which Stream() reports its (possibly nil) result.
//
// Stream() is invoked exactly once here with no outer retries; any retrying
// is up to Stream() itself. The function blocks until the fake client's Run()
// has been entered, so callers always get a client in a known state.
func startStream(ctx, loggerCtx context.Context) (*fakeSSHClient, io.Reader, <-chan error) {
	fake := &fakeSSHClient{
		mockRunErrs:   make(chan error),
		mockRunStdout: make(chan string),
		reconnected:   make(chan struct{}),
	}
	// Register before streaming starts so the first Run() cannot be missed.
	started := fake.listenForRun()

	reader, writer := io.Pipe()

	// Stream() runs in the background for the remainder of the calling test.
	errs := (&Syslogger{client: fake}).Stream(ctx, loggerCtx, writer)

	<-started
	return fake, reader, errs
}

// runWithTimeout runs f in a separate goroutine and returns its error. If f
// does not return within testTimeout, it dumps the stacks of all goroutines
// to stderr (to help diagnose what is stuck) and fails the test immediately
// with failureMessage.
//
// It must be called from the goroutine running the test, because it may call
// t.Fatal. On timeout, the goroutine running f is left behind.
func runWithTimeout(t *testing.T, f func() error, failureMessage string) error {
	t.Helper()
	// Buffered so the goroutine can complete its send even if we time out.
	errs := make(chan error, 1)
	go func() {
		errs <- f()
	}()

	select {
	case err := <-errs:
		return err
	case <-time.After(testTimeout):
		pprof.Lookup("goroutine").WriteTo(os.Stderr, 1)
		// failureMessage is plain text, not a format string: t.Fatal avoids
		// mangling any '%' it contains (and vet's non-constant format check).
		t.Fatal(failureMessage)
		return nil // Unreachable: t.Fatal stops the test goroutine.
	}
}

// assertClosed fails the test with failureMessage unless a receive from
// channel completes (normally because it was closed) within testTimeout.
func assertClosed(t *testing.T, channel <-chan struct{}, failureMessage string) {
	t.Helper()
	runWithTimeout(t, func() error {
		<-channel
		return nil
	}, failureMessage)
}

// TestStream exercises Syslogger.Stream against a fakeSSHClient: forwarding
// stdout, stopping on cancellation or non-connection errors, reconnecting on
// connection errors, and re-running log_listener after a clean exit.
func TestStream(t *testing.T) {
	// NoLogLevel may be changed for verbosity while debugging.
	l := logger.NewLogger(logger.NoLogLevel, nil, nil, nil, "")
	ctx := logger.WithLogger(context.Background(), l)

	t.Run("streams stdout until context canceled", func(t *testing.T) {
		syslogCtx, cancel := context.WithCancel(ctx)
		defer cancel()
		client, streamOutput, streamErrs := startStream(syslogCtx, ctx)

		stdout := "ABCDE"
		client.mockRunStdout <- stdout
		buf := make([]byte, len(stdout))
		if _, err := io.ReadAtLeast(streamOutput, buf, len(stdout)); err != nil {
			t.Fatalf("reading stream output: %v", err)
		}
		if got := string(buf); got != stdout {
			t.Errorf("unexpected bytes. wanted: %q, got: %q", stdout, got)
		}

		runningAgain := client.listenForRun()

		cancel()

		err := runWithTimeout(t, func() error {
			return <-streamErrs
		}, "expected context cancellation to stop streaming")
		if !errors.Is(err, context.Canceled) {
			t.Fatalf("unexpected streaming error: %v", err)
		}

		select {
		case <-runningAgain:
			t.Errorf("expected client.Run() to only be called once")
		default:
		}

		select {
		case <-client.reconnected:
			t.Errorf("client unexpectedly reconnected")
		default:
		}
	})

	// Errors not of type sshutil.ConnectionError should cause the syslogger to
	// exit, rather than reconnecting and resuming streaming.
	t.Run("non-connection error interrupts the stream", func(t *testing.T) {
		client, _, streamErrs := startStream(ctx, ctx)

		// We'll be notified in case Run() is called again (it shouldn't be
		// called again).
		runningAgain := client.listenForRun()

		// Run() will return a non-connection error, which should cause Stream()
		// to exit.
		nonConnectionError := errors.New("foo")
		client.mockRunErrs <- nonConnectionError

		err := runWithTimeout(t, func() error {
			return <-streamErrs
		}, "expected a Run() error to stop streaming")
		if !errors.Is(err, nonConnectionError) {
			t.Fatalf("unexpected streaming error: %v, expected: %v", err, nonConnectionError)
		}

		// Stream() should have exited immediately after receiving the error
		// from client.Run(), so it should not have called Run() again.
		select {
		case <-runningAgain:
			t.Errorf("expected client.Run() to be called once")
		default:
		}

		// Likewise, we shouldn't have tried to reconnect after receiving the
		// error from client.Run(), because the error was not SSH-related.
		select {
		case <-client.reconnected:
			t.Errorf("runner unexpectedly reconnected")
		default:
		}
	})

	// If we come across a connection error we should reconnect and re-run log_listener.
	t.Run("stream should recover from a connection error", func(t *testing.T) {
		// Cancel on exit so the background Stream() goroutine does not keep
		// running after this subtest finishes.
		streamCtx, cancel := context.WithCancel(ctx)
		defer cancel()
		client, streamOutput, _ := startStream(streamCtx, ctx)

		runningAgain := client.listenForRun()

		go func() {
			// Read output so that the syslogger doesn't get blocked on writing
			// the "syslog stream interrupted" message.
			io.ReadAll(streamOutput)
		}()

		client.mockRunErrs <- sshutil.ConnectionError{}

		assertClosed(t, client.reconnected,
			"expected syslogger to reconnect after connection error, but it didn't")

		assertClosed(t, runningAgain,
			"expected syslogger to re-run log_listener after reconnection")
	})

	// If log_listener exits successfully, we should just keep running it.
	t.Run("stream should rerun log_listener if it exits successfully", func(t *testing.T) {
		// Cancel on exit so the background Stream() goroutine does not keep
		// running after this subtest finishes.
		streamCtx, cancel := context.WithCancel(ctx)
		defer cancel()
		client, _, _ := startStream(streamCtx, ctx)

		runningAgain := client.listenForRun()

		client.mockRunErrs <- nil

		assertClosed(t, runningAgain,
			"expected syslogger to re-run log_listener again after a successful exit")

		select {
		case <-client.reconnected:
			t.Errorf("runner unexpectedly reconnected")
		default:
		}
	})
}
