// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// +build fuchsia

package fidl

import (
	"sync"
	"syscall/zx"
	"syscall/zx/dispatch"
)

// prefix is prepended to every panic message raised by this package.
const prefix = "zx/fidl: "

// d is the single dispatcher shared by every Binding in this process. It is
// created by init and driven by Serve.
var d *dispatch.Dispatcher

// init creates the process-local dispatcher. Failing to create it is
// unrecoverable for this package, so it panics.
func init() {
	dispatcher, err := dispatch.NewDispatcher()
	if err == nil {
		d = dispatcher
		return
	}
	panic(prefix + "failed to create dispatcher: " + err.Error())
}

// Serve runs the process-local dispatcher that every Binding in this package
// registers its waits with. It blocks for as long as the dispatcher's own
// Serve method does.
func Serve() {
	d.Serve()
}

// bindingState records where a Binding is in its lifecycle. It is only read
// or written while holding the Binding's state mutex.
type bindingState int32

const (
	// idle: waiting for the next signal on the Channel.
	idle bindingState = iota
	// handling: the wait handler is currently processing messages.
	handling
	// cleanup: Close was requested while handling; the handler will close
	// the Binding once it finishes.
	cleanup
	// closed: the Channel has been closed.
	closed
)

// Binding binds the implementation of a Stub to a Channel.
//
// A Binding listens for incoming messages on the Channel, decodes them, and
// asks the Stub to dispatch to the appropriate implementation of the interface.
// If the message expects a reply, the Binding will also encode the reply and
// send it back over the Channel.
type Binding struct {
	// Stub is a wrapper around an implementation of a FIDL interface which
	// knows how to dispatch to a method by ordinal.
	Stub Stub

	// Channel is the Channel primitive to which the Stub is bound.
	Channel zx.Channel

	// id identifies this binding's outstanding wait in the process-local
	// dispatcher. It is nil until Init starts the wait, and close resets it
	// to nil once the wait has been cancelled (or was already gone).
	id *dispatch.WaitID

	// errHandler is called with connection errors, when encountered.
	errHandler func(error)

	// stateMu guards state, which tracks where the binding is in its
	// lifecycle (idle, handling, cleanup, closed).
	stateMu sync.Mutex
	state   bindingState
}

// Init initializes a Binding and begins waiting on its Channel in the
// process-local dispatcher (see Serve).
//
// Each time the wait fires, the wait handler reads and dispatches messages to
// b.Stub (at most sigs.Count per wake-up) and re-arms the wait. errHandler is
// called, without the binding's state lock held, in three cases:
//   - the wait completes with a non-OK status,
//   - reading or dispatching a message fails,
//   - the wake-up does not include the readable signal (reported as
//     ErrPeerClosed).
//
// In each of those cases the wait is finished. errHandler may be nil, in which
// case those errors just finish the wait.
//
// Init returns any error from starting the wait.
func (b *Binding) Init(errHandler func(error)) error {
	// reportErr passes err to the registered error handler, if there is one.
	reportErr := func(err error) {
		if b.errHandler != nil {
			b.errHandler(err)
		}
	}

	// h is the wait handler the dispatcher runs each time the wait on
	// b.Channel completes.
	h := func(_ *dispatch.Dispatcher, s zx.Status, sigs *zx.PacketSignal) (result dispatch.WaitResult) {
		b.stateMu.Lock()
		switch b.state {
		case closed:
			// Close ran while this signal was already in flight: its
			// CancelWait may report ErrNotFound in that window, see close.
			// The Channel is already closed, so there is nothing to handle
			// and no error to report.
			b.stateMu.Unlock()
			return dispatch.WaitFinished
		case cleanup:
			b.close()
			b.stateMu.Unlock()
			return dispatch.WaitFinished
		}
		b.state = handling
		b.stateMu.Unlock()
		defer func() {
			b.stateMu.Lock()
			defer b.stateMu.Unlock()
			if b.state == cleanup {
				// Close was requested while we were handling; finish it now.
				// close sets b.state to closed.
				b.close()
				result = dispatch.WaitFinished
			} else {
				// Otherwise, go back to idle (out of the handling state).
				b.state = idle
			}
		}()
		if s != zx.ErrOk {
			reportErr(&zx.Error{Status: s})
			return dispatch.WaitFinished
		}
		if sigs.Observed&zx.SignalChannelReadable == 0 {
			// Woken without the readable signal; the only other signal in the
			// wait mask is peer-closed.
			reportErr(&zx.Error{Status: zx.ErrPeerClosed})
			return dispatch.WaitFinished
		}
		for i := uint64(0); i < sigs.Count; i++ {
			shouldWait, err := b.dispatch()
			if err != nil {
				reportErr(err)
				return dispatch.WaitFinished
			}
			if shouldWait {
				// The Channel is drained; wait for the next signal.
				break
			}
		}
		return dispatch.WaitAgain
	}

	b.stateMu.Lock()
	// Hold the state lock until the wait id is stored, so the handler cannot
	// run (and try to cancel the wait via close) before b.id is set.
	defer b.stateMu.Unlock()
	b.state = idle

	b.errHandler = errHandler

	// Start the wait on the Channel.
	id, err := d.BeginWait(
		zx.Handle(b.Channel),
		zx.SignalChannelReadable|zx.SignalChannelPeerClosed,
		0,
		h,
	)
	if err != nil {
		return err
	}
	b.id = &id
	return nil
}

// dispatch reads a single message from the underlying Channel and dispatches
// it into the Stub. If the Stub asks for a reply, dispatch writes the reply
// back over the Channel.
//
// shouldWait is true (with a nil error) when the Channel had no message
// available, meaning the caller should wait for it to become readable again.
// It is false when a message was consumed, so more data may still be pending.
//
// Handles received with the message are closed here only while this function
// still owns them: when the message is rejected before reaching the Stub, or
// when the Stub reports an error. Once the Stub accepts the message, ownership
// has passed to it. From then on no handles are closed here, even if encoding
// or writing the reply fails.
func (b *Binding) dispatch() (shouldWait bool, err error) {
	respb := messageBytesPool.Get().([]byte)
	resph := messageHandlesPool.Get().([]zx.Handle)

	defer messageBytesPool.Put(respb)
	defer messageHandlesPool.Put(resph)

	nb, nh, err := b.Channel.Read(respb, resph, 0)
	if err != nil {
		if zxErr, ok := err.(*zx.Error); ok && zxErr.Status == zx.ErrShouldWait {
			return true, nil
		}
		return false, err
	}
	msg := respb[:nb]
	handles := resph[:nh]

	// Until the Stub accepts the message, we own the received handles and
	// must close them on any early return.
	ownsHandles := true
	defer func() {
		if ownsHandles {
			for _, h := range handles {
				h.Close()
			}
		}
	}()

	var header MessageHeader
	hnb, _, err := Unmarshal(msg, nil, &header)
	if err != nil {
		return false, err
	}

	if !header.IsSupportedVersion() {
		return false, ErrUnknownMagic
	}

	p, shouldRespond, err := b.Stub.DispatchImpl(header.Ordinal, msg[hnb:], handles)
	if err != nil {
		return false, err
	}
	// The Stub now owns the request handles. Also, resph (which `handles`
	// aliases) is reused below to encode the reply. After that, closing
	// `handles` would hit reply handles, which a failed write has already
	// consumed, and request handles that belong to the Stub.
	ownsHandles = false

	if !shouldRespond {
		return false, nil
	}
	cnb, cnh, err := MarshalHeaderThenMessage(&header, p, respb, resph)
	if err != nil {
		return false, err
	}
	if err := b.Channel.Write(respb[:cnb], resph[:cnh], 0); err != nil {
		return false, err
	}
	return false, nil
}

// Close shuts the Binding down.
//
// If no message is currently being handled, Close cancels the pending wait,
// closes the Channel immediately and returns any error from doing so. If a
// message is being handled, Close only marks the Binding for cleanup and
// returns nil; the wait handler closes the Channel once it finishes.
//
// Close panics if the Binding was already closed or marked for cleanup, or
// if its Channel handle is invalid.
func (b *Binding) Close() error {
	b.stateMu.Lock()
	defer b.stateMu.Unlock()

	invalidHandle := zx.Handle(b.Channel) == zx.HandleInvalid
	if invalidHandle || b.state == cleanup || b.state == closed {
		panic(prefix + "double binding close")
	}

	if b.state == idle {
		return b.close()
	}
	if b.state == handling {
		// The wait handler is running; let it finish and close afterwards.
		b.state = cleanup
	}
	return nil
}

// close cancels the outstanding wait (if any), marks the Binding closed, and
// closes the bound Channel, returning any error from closing it.
//
// A CancelWait failure with ErrNotFound is tolerated: the wait may already be
// gone, for example when close runs from inside the wait handler. Any other
// CancelWait error is returned after a best-effort close of the Channel. In
// that case b.state and b.id are left unchanged.
//
// If b.id is nil (Init never started a wait, e.g. because BeginWait failed),
// there is nothing to cancel and only the Channel is closed.
//
// This method is not thread-safe, and should be called with b.stateMu held.
func (b *Binding) close() error {
	if b.id != nil {
		if err := d.CancelWait(*b.id); err != nil {
			if zxErr, ok := err.(*zx.Error); !ok || zxErr.Status != zx.ErrNotFound {
				// Attempt to close the channel if we hit a more serious error.
				b.Channel.Close()
				return err
			}
		}
	}
	b.id = nil
	b.state = closed
	return b.Channel.Close()
}

// BindingKey identifies one Binding within a BindingSet.
//
// Keys are handed out by BindingSet.Add and are only meaningful to the
// BindingSet that produced them.
type BindingKey uint64

// BindingSet is a managed set of Bindings which know how to unbind and
// remove themselves in the event of a connection error.
//
// The zero value is an empty set ready for use.
type BindingSet struct {
	// mu guards nextKey and bindings.
	mu sync.Mutex

	// nextKey is the key that the next successful Add will hand out.
	nextKey BindingKey

	// bindings maps each live key to its Binding; Add allocates it lazily.
	bindings map[BindingKey]*Binding
}

// Add creates a new Binding of s to c, initializes it, and adds it to the set
// under a freshly allocated key, which is returned.
//
// onError is an optional handler that may be passed. When the Binding
// reports a connection error, it is first removed from the set (see Remove).
// If it was still present, onError is then called with the error. Errors are
// reported from inside the Binding's wait handler, so the Channel itself is
// closed only after that handler finishes, i.e. after onError returns.
//
// If initializing the Binding fails, the error is returned, nothing is added
// to the set, and c is left open.
func (b *BindingSet) Add(s Stub, c zx.Channel, onError func(error)) (BindingKey, error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	binding := &Binding{
		Stub:    s,
		Channel: c,
	}
	if b.bindings == nil {
		b.bindings = make(map[BindingKey]*Binding)
	}
	key := b.nextKey
	err := binding.Init(func(err error) {
		// Remove reports false if the key was already removed (e.g. by Close),
		// in which case the owner has already let go and is not notified.
		if b.Remove(key) && onError != nil {
			onError(err)
		}
	})
	if err != nil {
		return 0, err
	}
	b.bindings[key] = binding
	b.nextKey++
	return key, nil
}

// Size reports how many Bindings the set currently tracks.
func (b *BindingSet) Size() int {
	b.mu.Lock()
	count := len(b.bindings)
	b.mu.Unlock()
	return count
}

// BindingKeys returns a freshly allocated slice holding every key currently
// in the set, in no particular order. The result is never nil.
func (b *BindingSet) BindingKeys() []BindingKey {
	b.mu.Lock()
	defer b.mu.Unlock()
	keys := make([]BindingKey, len(b.bindings))
	i := 0
	for k := range b.bindings {
		keys[i] = k
		i++
	}
	return keys
}

// ProxyFor builds a ChannelProxy over the Channel of the Binding stored under
// key, for sending events to that client. The boolean is false, and the
// proxy nil, when no such Binding exists.
func (b *BindingSet) ProxyFor(key BindingKey) (*ChannelProxy, bool) {
	b.mu.Lock()
	defer b.mu.Unlock()
	binding, found := b.bindings[key]
	if !found {
		return nil, false
	}
	return &ChannelProxy{Channel: binding.Channel}, true
}

// Remove removes the Binding identified by key from the set and closes it.
//
// The key is removed from the set (and thus invalidated) immediately. The
// Binding's Channel is closed right away if the Binding is idle. If a message
// is currently being handled, closing is deferred until that handling
// finishes, so a Binding is never torn down in the middle of a dispatch.
//
// Returns true if a Binding was found and removed.
func (b *BindingSet) Remove(key BindingKey) bool {
	b.mu.Lock()
	defer b.mu.Unlock()
	binding, ok := b.bindings[key]
	if !ok {
		return false
	}
	delete(b.bindings, key)

	// Close the binding before the caller gets a chance to run any callback.
	if err := binding.Close(); err != nil {
		// Close only fails if cancelling the wait or closing the Channel
		// fails, which shouldn't happen for a Binding we are tracking. If it
		// does, better to fail fast.
		panic(prefix + err.Error())
	}
	return true
}

// Close closes every Binding in the set and then empties it.
//
// Any error from closing an individual Binding is treated as a broken
// invariant and causes a panic.
func (b *BindingSet) Close() {
	b.mu.Lock()
	defer b.mu.Unlock()
	for key := range b.bindings {
		if err := b.bindings[key].Close(); err != nil {
			// A tracked Binding should always close cleanly; fail fast if not.
			panic(prefix + err.Error())
		}
	}
	b.bindings = nil
}
