// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// +build fuchsia

package fidl_test

import (
	"math/rand"
	"sync"
	"syscall/zx"
	"syscall/zx/fidl"
	"syscall/zx/fidl/bindingstest"
	"syscall/zx/zxwait"
	"testing"
)

// message has no data fields; its only member is a blank field carrying the
// FIDL layout tags. The tests below use it as the request, response and event
// payload for every Call, Send and Recv.
type message struct {
	_ struct{} `fidl:"s" fidl_size_v1:"1" fidl_alignment_v1:"0"`
}

var (
	// Compile-time check that *message satisfies fidl.Message.
	_ fidl.Message = (*message)(nil)

	// _mMessage is created once, when the package is initialized, and is
	// shared by every *message.
	_mMessage = fidl.MustCreateMarshaler(message{})
)

// Marshaler returns the shared package-level marshaler for message.
func (*message) Marshaler() fidl.Marshaler {
	return _mMessage
}

// TestProxySimple issues a single Call from one end of a fresh channel pair
// in a background goroutine and answers it with serve on the other end.
func TestProxySimple(t *testing.T) {
	client, server, err := zx.NewChannel(0)
	if err != nil {
		t.Fatal(err)
	}
	defer client.Close()
	defer server.Close()

	var wg sync.WaitGroup
	defer wg.Wait()

	// The caller runs off the test goroutine, so it reports with t.Error.
	call := func() {
		defer wg.Done()
		proxy := &fidl.ChannelProxy{Channel: client}
		callErr := proxy.Call(1, &message{}, &message{})
		if callErr != nil {
			t.Error(callErr)
		}
	}
	wg.Add(1)
	go call()

	if err := serve(server); err != nil {
		t.Fatal(err)
	}
}

// TestCallReturnsError checks that Call on a zero-value channel fails with a
// *zx.Error whose status is ErrBadHandle.
func TestCallReturnsError(t *testing.T) {
	var invalid zx.Channel // Zero value: not a valid handle.
	proxy := &fidl.ChannelProxy{Channel: invalid}
	err := proxy.Call(1, &message{}, &message{})
	switch zerr, ok := err.(*zx.Error); {
	case err == nil:
		t.Fatalf("expected ErrBadHandle error")
	case !ok || zerr.Status != zx.ErrBadHandle:
		t.Fatalf("unexpected error: %v", err)
	}
}

// TestCall starts concurrency goroutines that share a single ChannelProxy.
// Each goroutine issues numMessages Calls with random ordinals, while the test
// goroutine echoes every request back with serve.
func TestCall(t *testing.T) {
	const concurrency = 10
	const numMessages = 100

	c0, c1, err := zx.NewChannel(0)
	if err != nil {
		t.Fatal(err)
	}
	defer c0.Close()
	defer c1.Close()

	var wg sync.WaitGroup
	defer wg.Wait()

	// One proxy is shared by every sender to exercise concurrent Calls.
	sender := &fidl.ChannelProxy{Channel: c0}
	// Spin up |concurrency| senders, each sending |numMessages|.
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < numMessages; j++ {
				ordinal := rand.Uint64()
				if err := sender.Call(ordinal, &message{}, &message{}); err != nil {
					// Not the test goroutine: report without FailNow.
					t.Error(err)
				}
			}
		}()
	}

	// Echo the |concurrency*numMessages| messages sent by the senders above.
	for i := 0; i < concurrency*numMessages; i++ {
		if err := serve(c1); err != nil {
			t.Fatal(err)
		}
	}
}

// TestRecv parks n goroutines in Recv on c1, each with its own proxy. It then
// sends n messages with the same ordinal from c0 so that every receiver can
// complete.
func TestRecv(t *testing.T) {
	const n = 100
	const ordinal uint64 = 1

	c0, c1, err := zx.NewChannel(0)
	if err != nil {
		t.Fatal(err)
	}
	defer c0.Close()
	defer c1.Close()

	var wg sync.WaitGroup
	defer wg.Wait()

	// receive retries Recv via zxwait.WithRetry, waiting on the readable and
	// peer-closed signals of c1 between attempts (presumably until a message
	// is available; see zxwait for exact retry semantics).
	receive := func() {
		defer wg.Done()
		receiver := &fidl.ChannelProxy{Channel: c1}
		recvOnce := func() error {
			return receiver.Recv(ordinal, &message{})
		}
		recvErr := zxwait.WithRetry(recvOnce, *receiver.Channel.Handle(), zx.SignalChannelReadable, zx.SignalChannelPeerClosed)
		if recvErr != nil {
			t.Error(recvErr)
		}
	}
	for started := 0; started < n; started++ {
		wg.Add(1)
		go receive()
	}

	for sent := 0; sent < n; sent++ {
		proxy := &fidl.ChannelProxy{Channel: c0}
		if err := proxy.Send(ordinal, &message{}); err != nil {
			t.Fatal(err)
		}
	}
}

// TestRecvReturnsError checks that Recv on a zero-value channel fails with a
// *zx.Error whose status is ErrBadHandle.
func TestRecvReturnsError(t *testing.T) {
	var invalid zx.Channel // Zero value: not a valid handle.
	receiver := &fidl.ChannelProxy{Channel: invalid}
	err := receiver.Recv(1, &message{})
	switch zerr, ok := err.(*zx.Error); {
	case err == nil:
		t.Fatalf("expected ErrBadHandle error")
	case !ok || zerr.Status != zx.ErrBadHandle:
		t.Fatalf("unexpected error: %v", err)
	}
}

// TestMagicNumberSend checks that a message sent through ChannelProxy.Send
// carries fidl.FidlWireFormatMagicNumberInitial in its header. It reads the
// raw bytes off the peer end and decodes them.
func TestMagicNumberSend(t *testing.T) {
	ch, sh, err := zx.NewChannel(0)
	if err != nil {
		t.Fatal(err)
	}
	defer ch.Close()
	defer sh.Close()

	msg := bindingstest.TestSimple{X: 3}
	client := fidl.ChannelProxy{Channel: ch}
	if err := client.Send(0, &msg); err != nil {
		t.Fatal(err)
	}

	respb := make([]byte, zx.ChannelMaxMessageBytes)
	resphi := make([]zx.HandleInfo, zx.ChannelMaxMessageHandles)
	nb, nh, err := sh.ReadEtc(respb, resphi, 0)
	if err != nil {
		t.Fatal(err)
	}

	var header fidl.MessageHeader
	if err := fidl.UnmarshalHeaderThenMessage(respb[:nb], resphi[:nh], &header, &msg); err != nil {
		t.Fatal(err)
	}

	// err is nil here, so report the mismatching values rather than err.
	if header.Magic != fidl.FidlWireFormatMagicNumberInitial {
		t.Fatalf("expected client to send initial magic number %#x, got %#x",
			fidl.FidlWireFormatMagicNumberInitial, header.Magic)
	}
}

// TestMagicNumberCheck writes a raw event whose header carries magic number 0
// and performs two checks. First, ChannelProxy.Recv must reject the event with
// fidl.ErrUnknownMagic. Second, the proxy must have closed its end, which is
// observed as ErrPeerClosed from the server side.
func TestMagicNumberCheck(t *testing.T) {
	ch, sh, err := zx.NewChannel(0)
	if err != nil {
		t.Fatal(err)
	}
	defer ch.Close()
	defer sh.Close()
	client := fidl.ChannelProxy{Channel: ch}

	// Send an event with an unknown magic number.
	event := []byte{
		0, 0, 0, 0, // txid
		0, 0, 0, // flags
		0,                      // magic number (not recognized by the bindings)
		0, 0, 0, 0, 0, 0, 0, 0, // method ordinal
		0, 0, 0, 0, 0, 0, 0, 0, // empty struct data
	}
	if err := sh.Write(event, []zx.Handle{}, 0); err != nil {
		t.Fatal(err)
	}

	var eventMsg bindingstest.EmptyStruct
	if err := client.Recv(0, &eventMsg); err != fidl.ErrUnknownMagic {
		t.Fatalf("expected unknown magic error, got %v", err)
	}

	// Read directly from the channel to ensure the client side is closed.
	// Anything other than a *zx.Error with ErrPeerClosed (including nil or a
	// different error type) is a failure.
	_, _, err = sh.Read([]byte{}, []zx.Handle{}, 0)
	if zerr, ok := err.(*zx.Error); !ok || zerr.Status != zx.ErrPeerClosed {
		t.Fatalf("expected client side to be closed (ErrPeerClosed), got %v", err)
	}
}

// TestConcurrentUseOfChannelProxy drives a single client ChannelProxy from
// many goroutines at once, mixing nMessages concurrent Calls with nEvents
// concurrent Recvs. Meanwhile the server end answers the calls via serve and,
// from a separate goroutine, pushes nEvents events with ordinal 1.
//
// t.Fatal/FailNow may only be called from the goroutine running the test, so
// every worker goroutine below reports failures with t.Error instead.
func TestConcurrentUseOfChannelProxy(t *testing.T) {
	const nEvents = 500
	const nMessages = 500

	clientChan, serverChan, err := zx.NewChannel(0)
	if err != nil {
		t.Fatal(err)
	}
	defer clientChan.Close()
	defer serverChan.Close()

	var wg sync.WaitGroup

	// Server: echo one response per Call.
	wg.Add(1)
	go func() {
		defer wg.Done()
		for i := 0; i < nMessages; i++ {
			if err := serve(serverChan); err != nil {
				t.Error(err)
				return
			}
		}
	}()

	// Server: emit events concurrently with the responses above.
	wg.Add(1)
	go func() {
		defer wg.Done()
		sender := &fidl.ChannelProxy{Channel: serverChan}
		for i := 0; i < nEvents; i++ {
			if err := sender.Send(1, &message{}); err != nil {
				t.Error(err)
				return
			}
		}
	}()

	client := &fidl.ChannelProxy{Channel: clientChan}

	for i := 0; i < nMessages; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := client.Call(1, &message{}, &message{}); err != nil {
				t.Error(err)
			}
		}()
	}

	for i := 0; i < nEvents; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := client.Recv(1, &message{}); err != nil {
				t.Error(err)
			}
		}()
	}

	wg.Wait()
}

// serve handles exactly one request on ch and returns the first error it
// encounters. It reads a message, retrying via zxwait.WithRetry on the
// readable/peer-closed signals, and decodes it as a header followed by an
// empty message. It then re-encodes that same header and message and writes
// the result back on ch. Handles are neither read nor written.
func serve(ch zx.Channel) error {
	var (
		buf  [zx.ChannelMaxMessageBytes]byte
		hdr  fidl.MessageHeader
		body message
	)

	readRequest := func() error {
		n, _, err := ch.Read(buf[:], nil, 0)
		if err != nil {
			return err
		}
		return fidl.UnmarshalHeaderThenMessage(buf[:n], nil, &hdr, &body)
	}
	err := zxwait.WithRetry(readRequest, *ch.Handle(), zx.SignalChannelReadable, zx.SignalChannelPeerClosed)
	if err != nil {
		return err
	}

	// Reuse the same buffer for the outgoing bytes.
	n, _, err := fidl.MarshalHeaderThenMessage(&hdr, &body, buf[:], nil)
	if err != nil {
		return err
	}
	return ch.Write(buf[:n], nil, 0)
}
