// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// `go mod` ignores file names for the purpose of resolving
// dependencies, and zxwait doesn't build on not-Fuchsia.
//go:build fuchsia

package fidl_test

import (
	"context"
	"math/rand"
	"sync"
	"syscall/zx"
	"syscall/zx/fidl"
	"syscall/zx/zxwait"
	"testing"
)

// Compile-time assertion that *message implements fidl.Message.
var _ fidl.Message = (*message)(nil)

// message is the payload type used by every test in this file. Its only
// field is a blank struct{} whose tags describe the wire layout to the fidl
// marshaler: `fidl:"s"` (presumably "struct") with a v2 size of 1 and a v2
// alignment of 0.
// NOTE(review): an alignment of 0 is unusual for a 1-byte struct — confirm it
// is what MustCreateMarshaler expects for an empty struct.
type message struct {
	_ struct{} `fidl:"s" fidl_size_v2:"1" fidl_alignment_v2:"0"`
}

// _mMessage is the marshaler for message. It is created once at package
// initialization and shared by every message value.
var _mMessage = fidl.MustCreateMarshaler(message{})

// Marshaler implements fidl.Message. It returns the package-level marshaler;
// the receiver is unused because all message values share that marshaler.
func (m *message) Marshaler() fidl.Marshaler {
	return _mMessage
}

// TestProxySimple checks a single request/response round trip. A goroutine
// issues Call on one end of a channel while the test goroutine echoes the
// request back from the other end with serve.
//
// NOTE(review): if serve fails, t.Fatal runs the deferred wg.Wait while both
// handles are still open, so the Call goroutine may never return — confirm
// whether the proxy has any timeout.
func TestProxySimple(t *testing.T) {
	c0, c1, err := zx.NewChannel(0)
	if err != nil {
		t.Fatal(err)
	}
	// Registered first, so it runs last: after wg.Wait below has returned.
	defer func() {
		for _, c := range []*zx.Channel{&c0, &c1} {
			if err := c.Close(); err != nil {
				t.Error(err)
			}
		}
	}()

	var wg sync.WaitGroup
	defer wg.Wait()

	wg.Add(1)
	go func() {
		defer wg.Done()
		caller := &fidl.ChannelProxy{Channel: c0}
		err := caller.Call(1, &message{}, &message{})
		if err != nil {
			t.Error(err)
		}
	}()

	err = serve(c1)
	if err != nil {
		t.Fatal(err)
	}
}

// TestCallReturnsError checks that Call on a proxy wrapping the zero-value
// (invalid) channel fails with ErrBadHandle.
func TestCallReturnsError(t *testing.T) {
	var c zx.Channel // Invalid channel.
	sender := &fidl.ChannelProxy{Channel: c}
	// Keep the raw error. If the type assertion fails, the asserted value is a
	// nil *zx.Error, and reporting that would hide what Call actually returned.
	err := sender.Call(1, &message{}, &message{})
	if zerr, ok := err.(*zx.Error); !ok || zerr.Status != zx.ErrBadHandle {
		t.Errorf("got sender.Call(...) = %v, want = ErrBadHandle", err)
	}
}

// TestCall issues concurrency*numMessages Calls with random ordinals through a
// single shared ChannelProxy. The test goroutine echoes every request with
// serve and then closes its end of the channel.
func TestCall(t *testing.T) {
	const (
		concurrency = 10
		numMessages = 100
		numWaiters  = 5
	)

	c0, c1, err := zx.NewChannel(0)
	if err != nil {
		t.Fatal(err)
	}
	// c1 is closed explicitly at the end of the test, so closing it again here
	// is expected to fail; errors are only logged.
	defer func() {
		for _, c := range []*zx.Channel{&c0, &c1} {
			if err := c.Close(); err != nil {
				t.Log(err)
			}
		}
	}()

	var wg sync.WaitGroup
	defer wg.Wait()

	sender := fidl.ChannelProxy{Channel: c0}

	// Spin up |concurrency| senders, each sending |numMessages|.
	wg.Add(concurrency)
	for g := 0; g < concurrency; g++ {
		go func() {
			defer wg.Done()
			for m := 0; m < numMessages; m++ {
				if err := sender.Call(rand.Uint64(), &message{}, &message{}); err != nil {
					t.Error(err)
				}
			}
		}()
	}

	// A few extra reads from the channel to observe peer closure.
	//
	// This is a regression test for issues where observed errors would cause:
	//
	// - all outstanding calls to fail even if their responses had already
	//   been read.
	//
	// - pending reads to deadlock because the proxy did not release the "read
	//   lease", causing "followers" to wait indefinitely.
	wg.Add(numWaiters)
	for w := 0; w < numWaiters; w++ {
		go func() {
			defer wg.Done()
			err := sender.Recv(0, &message{})
			if err == nil {
				t.Error("expected ErrPeerClosed")
				return
			}
			if zerr, ok := err.(*zx.Error); !ok || zerr.Status != zx.ErrPeerClosed {
				t.Errorf("unexpected error: %s", err)
			}
		}()
	}

	// Echo the |concurrency*numMessages| messages sent by the senders above.
	for remaining := concurrency * numMessages; remaining > 0; remaining-- {
		if err := serve(c1); err != nil {
			t.Fatal(err)
		}
	}
	// Closing the server end unblocks the Recv waiters with ErrPeerClosed.
	if err := c1.Close(); err != nil {
		t.Error(err)
	}
}

// TestRecv starts n goroutines that concurrently Recv on the same channel end,
// then sends n messages with the matching ordinal from the other end. Each
// receiver retries through zxwait until its Recv succeeds.
func TestRecv(t *testing.T) {
	const (
		n              = 100
		ordinal uint64 = 1
	)

	c0, c1, err := zx.NewChannel(0)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		for _, c := range []*zx.Channel{&c0, &c1} {
			if err := c.Close(); err != nil {
				t.Error(err)
			}
		}
	}()

	var wg sync.WaitGroup
	defer wg.Wait()

	wg.Add(n)
	for r := 0; r < n; r++ {
		go func() {
			defer wg.Done()

			receiver := &fidl.ChannelProxy{Channel: c1}
			recv := func() error {
				return receiver.Recv(ordinal, &message{})
			}
			handle := *receiver.Channel.Handle()
			if err := zxwait.WithRetryContext(context.Background(), recv, handle, zx.SignalChannelReadable, zx.SignalChannelPeerClosed); err != nil {
				t.Error(err)
			}
		}()
	}

	for s := 0; s < n; s++ {
		sender := &fidl.ChannelProxy{Channel: c0}
		if err := sender.Send(ordinal, &message{}); err != nil {
			t.Fatal(err)
		}
	}
}

// TestRecvReturnsError checks that Recv on a proxy wrapping the zero-value
// (invalid) channel fails with ErrBadHandle.
func TestRecvReturnsError(t *testing.T) {
	var c zx.Channel // Invalid channel.
	sender := &fidl.ChannelProxy{Channel: c}
	// Keep the raw error. If the type assertion fails, the asserted value is a
	// nil *zx.Error, and reporting that would hide what Recv actually returned.
	err := sender.Recv(1, &message{})
	if zerr, ok := err.(*zx.Error); !ok || zerr.Status != zx.ErrBadHandle {
		t.Errorf("got sender.Recv(...) = %v, want = ErrBadHandle", err)
	}
}

// TestMagicNumberSend reads back a message written by ChannelProxy.Send and
// checks that its header carries fidl.FidlWireFormatMagicNumberInitial.
func TestMagicNumberSend(t *testing.T) {
	ch, sh, err := zx.NewChannel(0)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		for _, c := range []*zx.Channel{&ch, &sh} {
			if err := c.Close(); err != nil {
				t.Error(err)
			}
		}
	}()

	var msg message
	client := fidl.ChannelProxy{Channel: ch}
	if err := client.Send(0, &msg); err != nil {
		t.Fatal(err)
	}

	// Read the raw bytes from the server end, sized for the largest message.
	var (
		buf   = make([]byte, zx.ChannelMaxMessageBytes)
		infos = make([]zx.HandleInfo, zx.ChannelMaxMessageHandles)
	)
	nb, nh, err := sh.ReadEtc(buf, infos, 0)
	if err != nil {
		t.Fatal(err)
	}

	var header fidl.MessageHeader
	if err := fidl.UnmarshalHeaderThenMessage(buf[:nb], infos[:nh], &header, &msg); err != nil {
		t.Fatal(err)
	}

	if header.Magic != fidl.FidlWireFormatMagicNumberInitial {
		t.Fatalf("got header.Magic = %d, want = %d", header.Magic, fidl.FidlWireFormatMagicNumberInitial)
	}
}

// TestMagicNumberCheck feeds a ChannelProxy a message whose header has an
// unknown (zero) magic number. Recv must return fidl.ErrUnknownMagic, and the
// proxy must close its end of the channel.
func TestMagicNumberCheck(t *testing.T) {
	ch, sh, err := zx.NewChannel(0)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		// The proxy closes the client handle after rejecting the message
		// (verified below), so closing ch again must fail with ErrBadHandle.
		// Report the raw error, not the possibly-nil asserted *zx.Error.
		err := ch.Close()
		if zerr, ok := err.(*zx.Error); !ok || zerr.Status != zx.ErrBadHandle {
			t.Errorf("got ch.Close() = %v, want = ErrBadHandle", err)
		}
		if err := sh.Close(); err != nil {
			t.Error(err)
		}
	}()

	client := fidl.ChannelProxy{Channel: ch}

	// Send an event with an unknown magic number.
	event := []byte{
		0, 0, 0, 0, // txid
		0, 0, 0, // flags
		0,                      // magic number
		0, 0, 0, 0, 0, 0, 0, 0, // method ordinal
		0, 0, 0, 0, 0, 0, 0, 0, // empty struct data
	}
	if err := sh.Write(event, []zx.Handle{}, 0); err != nil {
		t.Fatal(err)
	}

	var eventMsg message
	if err := client.Recv(0, &eventMsg); err != fidl.ErrUnknownMagic {
		t.Fatalf("got client.Recv(...) = %v, want = %s", err, fidl.ErrUnknownMagic)
	}

	// Read directly from the server end to confirm the client end was closed.
	_, _, err = sh.Read(nil, nil, 0)
	if zerr, ok := err.(*zx.Error); !ok || zerr.Status != zx.ErrPeerClosed {
		t.Fatalf("got sh.Read(...) = %v, want = %s", err, zx.ErrPeerClosed)
	}
}

// TestConcurrentUseOfChannelProxy runs nMessages concurrent Calls and nEvents
// concurrent Recvs (ordinal 1) on a single client ChannelProxy. On the server
// end, one goroutine echoes the calls with serve while another sends nEvents
// events with ordinal 1.
func TestConcurrentUseOfChannelProxy(t *testing.T) {
	const (
		nEvents   = 500
		nMessages = 500
	)

	clientChan, serverChan, err := zx.NewChannel(0)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		for _, c := range []*zx.Channel{&clientChan, &serverChan} {
			if err := c.Close(); err != nil {
				t.Error(err)
			}
		}
	}()

	var wg sync.WaitGroup
	defer wg.Wait()

	// spawn runs f on a new goroutine tracked by wg.
	spawn := func(f func()) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			f()
		}()
	}

	// Server side: echo every call...
	spawn(func() {
		for i := 0; i < nMessages; i++ {
			if err := serve(serverChan); err != nil {
				t.Error(err)
			}
		}
	})
	// ...and independently emit events.
	spawn(func() {
		sender := &fidl.ChannelProxy{Channel: serverChan}
		for i := 0; i < nEvents; i++ {
			if err := sender.Send(1, &message{}); err != nil {
				t.Error(err)
			}
		}
	})

	client := &fidl.ChannelProxy{Channel: clientChan}

	for i := 0; i < nMessages; i++ {
		spawn(func() {
			if err := client.Call(1, &message{}, &message{}); err != nil {
				t.Error(err)
			}
		})
	}
	for i := 0; i < nEvents; i++ {
		spawn(func() {
			if err := client.Recv(1, &message{}); err != nil {
				t.Error(err)
			}
		})
	}
}

// serve reads one request from ch and writes it back unchanged. It re-marshals
// the same header and message, so the transaction ID is echoed back.
// Reads are retried through zxwait until ch is readable or its peer closes.
// The read passes a nil handle buffer, so requests carrying handles are
// presumably rejected by ch.Read — TODO confirm.
func serve(ch zx.Channel) error {
	var (
		buf    [zx.ChannelMaxMessageBytes]byte
		header fidl.MessageHeader
		msg    message
	)

	readRequest := func() error {
		nb, _, err := ch.Read(buf[:], nil, 0)
		if err != nil {
			return err
		}
		return fidl.UnmarshalHeaderThenMessage(buf[:nb], nil, &header, &msg)
	}
	if err := zxwait.WithRetryContext(context.Background(), readRequest, *ch.Handle(), zx.SignalChannelReadable, zx.SignalChannelPeerClosed); err != nil {
		return err
	}

	nb, _, err := fidl.MarshalHeaderThenMessage(&header, &msg, buf[:], nil)
	if err != nil {
		return err
	}
	return ch.Write(buf[:nb], nil, 0)
}
