// Copyright (c) 2017 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package multierr

import (
	"errors"
	"fmt"
	"io"
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// richFormatError is a test error whose rendering depends on the verb flags:
// "%+v" yields a three-line message, while any other verb (including the
// plain "%v" that Error uses) yields the single line "without plus".
type richFormatError struct{}

// Error renders r through its own Format method using the plain "%v" verb,
// so it always returns "without plus".
func (r richFormatError) Error() string {
	return fmt.Sprintf("%v", r)
}

// Format implements fmt.Formatter, switching output on the '+' flag of 'v'.
func (richFormatError) Format(f fmt.State, verb rune) {
	msg := "without plus"
	if verb == 'v' && f.Flag('+') {
		msg = "multiline\nmessage\nwith plus"
	}
	// Test fixture: a failed write has no useful recovery here.
	_, _ = io.WriteString(f, msg)
}

// appendN folds err into initial n times via Append and returns the result.
// A non-positive n returns initial unchanged.
func appendN(initial, err error, n int) error {
	acc := initial
	for remaining := n; remaining > 0; remaining-- {
		acc = Append(acc, err)
	}
	return acc
}

// newMultiErr wraps errs verbatim in a *multiError, with no flattening or nil
// filtering, so tests can spell out the exact expected value.
func newMultiErr(errs ...error) error {
	me := &multiError{errors: errs}
	return me
}

// TestCombine checks that Combine flattens nested multiErrors, drops nil
// entries, yields nil for nil/empty input, and returns a lone error as-is.
// For non-nil results it also checks both renderings: "%+v" (bulleted,
// multi-line, with continuation lines indented) and the single-line form
// produced by Error, String (when implemented) and "%v".
func TestCombine(t *testing.T) {
	tests := []struct {
		desc           string
		giveErrors     []error
		wantError      error
		wantMultiline  string
		wantSingleline string
	}{
		{
			desc:       "nil slice",
			giveErrors: nil,
			wantError:  nil,
		},
		{
			desc:       "empty slice",
			giveErrors: []error{},
			wantError:  nil,
		},
		{
			desc: "nils around nested multiError are dropped",
			giveErrors: []error{
				errors.New("foo"),
				nil,
				newMultiErr(
					errors.New("bar"),
				),
				nil,
			},
			wantError: newMultiErr(
				errors.New("foo"),
				errors.New("bar"),
			),
			wantMultiline: "the following errors occurred:\n" +
				" -  foo\n" +
				" -  bar",
			wantSingleline: "foo; bar",
		},
		{
			desc: "trailing single-element multiError is flattened",
			giveErrors: []error{
				errors.New("foo"),
				newMultiErr(
					errors.New("bar"),
				),
			},
			wantError: newMultiErr(
				errors.New("foo"),
				errors.New("bar"),
			),
			wantMultiline: "the following errors occurred:\n" +
				" -  foo\n" +
				" -  bar",
			wantSingleline: "foo; bar",
		},
		{
			desc:           "single error is returned as-is",
			giveErrors:     []error{errors.New("great sadness")},
			wantError:      errors.New("great sadness"),
			wantMultiline:  "great sadness",
			wantSingleline: "great sadness",
		},
		{
			desc: "two plain errors",
			giveErrors: []error{
				errors.New("foo"),
				errors.New("bar"),
			},
			wantError: newMultiErr(
				errors.New("foo"),
				errors.New("bar"),
			),
			wantMultiline: "the following errors occurred:\n" +
				" -  foo\n" +
				" -  bar",
			wantSingleline: "foo; bar",
		},
		{
			desc: "multi-line messages are indented in %+v",
			giveErrors: []error{
				errors.New("great sadness"),
				errors.New("multi\n  line\nerror message"),
				errors.New("single line error message"),
			},
			wantError: newMultiErr(
				errors.New("great sadness"),
				errors.New("multi\n  line\nerror message"),
				errors.New("single line error message"),
			),
			wantMultiline: "the following errors occurred:\n" +
				" -  great sadness\n" +
				" -  multi\n" +
				"      line\n" +
				"    error message\n" +
				" -  single line error message",
			wantSingleline: "great sadness; " +
				"multi\n  line\nerror message; " +
				"single line error message",
		},
		{
			desc: "nested multiError in the middle keeps order",
			giveErrors: []error{
				errors.New("foo"),
				newMultiErr(
					errors.New("bar"),
					errors.New("baz"),
				),
				errors.New("qux"),
			},
			wantError: newMultiErr(
				errors.New("foo"),
				errors.New("bar"),
				errors.New("baz"),
				errors.New("qux"),
			),
			wantMultiline: "the following errors occurred:\n" +
				" -  foo\n" +
				" -  bar\n" +
				" -  baz\n" +
				" -  qux",
			wantSingleline: "foo; bar; baz; qux",
		},
		{
			desc: "custom Formatter element honors plus flag",
			giveErrors: []error{
				errors.New("foo"),
				richFormatError{},
				errors.New("bar"),
			},
			wantError: newMultiErr(
				errors.New("foo"),
				richFormatError{},
				errors.New("bar"),
			),
			wantMultiline: "the following errors occurred:\n" +
				" -  foo\n" +
				" -  multiline\n" +
				"    message\n" +
				"    with plus\n" +
				" -  bar",
			wantSingleline: "foo; without plus; bar",
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			err := Combine(tt.giveErrors...)
			require.Equal(t, tt.wantError, err)

			if tt.wantMultiline != "" {
				assert.Equal(t, tt.wantMultiline, fmt.Sprintf("%+v", err))
			}

			if tt.wantSingleline != "" {
				assert.Equal(t, tt.wantSingleline, err.Error())
				if s, ok := err.(fmt.Stringer); ok {
					assert.Equal(t, tt.wantSingleline, s.String())
				}
				assert.Equal(t, tt.wantSingleline, fmt.Sprintf("%v", err))
			}
		})
	}
}

// TestCombineDoesNotModifySlice ensures Combine leaves the caller's slice
// untouched: the nil entry is still nil and the non-nil entries are still in
// their original positions (i.e. no in-place compaction of the input).
func TestCombineDoesNotModifySlice(t *testing.T) {
	// Named errs rather than errors to avoid shadowing the errors package.
	errs := []error{
		errors.New("foo"),
		nil,
		errors.New("bar"),
	}

	assert.NotNil(t, Combine(errs...))
	require.Len(t, errs, 3)
	assert.EqualError(t, errs[0], "foo")
	assert.Nil(t, errs[1])
	assert.EqualError(t, errs[2], "bar")
}

// TestAppend checks Append's handling of nil operands (the other side is
// returned as-is) and that plain errors and multiErrors on either side are
// merged into a single flat multiError, preserving left-then-right order.
// Each case runs as a named subtest so a failure identifies the case.
func TestAppend(t *testing.T) {
	tests := []struct {
		desc  string
		left  error
		right error
		want  error
	}{
		{
			desc:  "both nil",
			left:  nil,
			right: nil,
			want:  nil,
		},
		{
			desc:  "nil left",
			left:  nil,
			right: errors.New("great sadness"),
			want:  errors.New("great sadness"),
		},
		{
			desc:  "nil right",
			left:  errors.New("great sadness"),
			right: nil,
			want:  errors.New("great sadness"),
		},
		{
			desc:  "two plain errors",
			left:  errors.New("foo"),
			right: errors.New("bar"),
			want: newMultiErr(
				errors.New("foo"),
				errors.New("bar"),
			),
		},
		{
			desc: "multiError left",
			left: newMultiErr(
				errors.New("foo"),
				errors.New("bar"),
			),
			right: errors.New("baz"),
			want: newMultiErr(
				errors.New("foo"),
				errors.New("bar"),
				errors.New("baz"),
			),
		},
		{
			desc: "multiError right",
			left: errors.New("baz"),
			right: newMultiErr(
				errors.New("foo"),
				errors.New("bar"),
			),
			want: newMultiErr(
				errors.New("baz"),
				errors.New("foo"),
				errors.New("bar"),
			),
		},
		{
			desc: "two multiErrors",
			left: newMultiErr(
				errors.New("foo"),
			),
			right: newMultiErr(
				errors.New("bar"),
			),
			want: newMultiErr(
				errors.New("foo"),
				errors.New("bar"),
			),
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			assert.Equal(t, tt.want, Append(tt.left, tt.right))
		})
	}
}

// createMultiErrWithCapacity returns an error built by appending the same
// "append" error 50 times. Growing it one Append at a time is intended to
// leave spare capacity in the backing slice, so a later Append could write
// into an array that other values may share. That is the situation the
// append-isolation tests exercise.
func createMultiErrWithCapacity() error {
	const count = 50
	seed := errors.New("append")
	return appendN(nil, seed, count)
}

// TestAppendDoesNotModify appends two different errors to the same base
// multiError. It checks that neither append alters the base, and that neither
// sees the other's element.
func TestAppendDoesNotModify(t *testing.T) {
	base := createMultiErrWithCapacity()
	withErr1 := Append(base, errors.New("err1"))
	withErr2 := Append(base, errors.New("err2"))

	// Compare rendered messages instead of values: Append modifies the
	// copyNeeded atomic on base, so direct value equality cannot be used.
	wantBase := createMultiErrWithCapacity().Error()
	assert.EqualError(t, base, wantBase, "Initial should not be modified")

	wantErr1 := Append(createMultiErrWithCapacity(), errors.New("err1")).Error()
	assert.EqualError(t, withErr1, wantErr1)

	wantErr2 := Append(createMultiErrWithCapacity(), errors.New("err2")).Error()
	assert.EqualError(t, withErr2, wantErr2)
}

// TestAppendRace appends to one shared base error from several goroutines at
// once. It has no assertions of its own. Its value is in surfacing
// unsynchronized writes to shared state when run under the race detector
// (go test -race).
func TestAppendRace(t *testing.T) {
	const (
		workers          = 10
		appendsPerWorker = 10
	)

	shared := createMultiErrWithCapacity()

	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer wg.Done()

			acc := shared
			for k := 0; k < appendsPerWorker; k++ {
				acc = Append(acc, errors.New("err"))
			}
		}()
	}
	wg.Wait()
}
