// Copyright 2017 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

package amberctl

import (
	"bytes"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"syscall/zx"
	"syscall/zx/zxwait"
	"time"

	"amber/urlscope"
	"app/context"
	"fidl/fuchsia/amber"
)

const usage = `usage: amber_ctl <command> [opts]
Commands
    get_up        - get an update for a package
      Options
        -n:      name of the package
        -v:      version of the package to retrieve, if none is supplied any
                 package instance could match
        -m:      merkle root of the package to retrieve, if none is supplied
                 any package instance could match
        -nowait: exit once package installation has started, but don't wait for
                 package activation

    get_blob      - get the specified content blob
        -i: content ID of the blob

    add_src       - add a source to the list we can use
        -n: name of the update source (optional, with URL)
        -f: file path or url to a source config file
        -h: SHA256 hash of source config file (optional, with URL)
        -x: do not disable other active sources (if the provided source is enabled)

    rm_src        - remove a source, if it exists
        -n: name of the update source

    list_srcs     - list the set of sources we can use

    enable_src
        -n: name of the update source
        -x: do not disable other active sources

    disable_src
        -n: name of the update source

    system_update - check for, download, and apply any available system update

    gc - trigger a garbage collection

    print_state - print go routine state of amber process
`

var (
	fs           = flag.NewFlagSet("default", flag.ExitOnError)
	pkgFile      = fs.String("f", "", "Path to a source config file")
	hash         = fs.String("h", "", "SHA256 hash of source config file (required if -f is a URL, ignored otherwise)")
	name         = fs.String("n", "", "Name of a source or package")
	version      = fs.String("v", "", "Version of a package")
	blobID       = fs.String("i", "", "Content ID of the blob")
	noWait       = fs.Bool("nowait", false, "Return once installation has started, package will not yet be available.")
	merkle       = fs.String("m", "", "Merkle root of the desired update.")
	nonExclusive = fs.Bool("x", false, "When adding or enabling a source, do not disable other sources.")
)

type ErrGetFile string

func NewErrGetFile(str string, inner error) ErrGetFile {
	return ErrGetFile(fmt.Sprintf("%s: %v", str, inner))
}

func (e ErrGetFile) Error() string {
	return string(e)
}

func doTest(pxy *amber.ControlInterface) error {
	v := int32(42)
	resp, err := pxy.DoTest(v)
	if err != nil {
		fmt.Println(err)
		return err
	}

	fmt.Printf("Response: %s\n", resp)
	return nil
}

func connect(ctx *context.Context) (*amber.ControlInterface, amber.ControlInterfaceRequest) {
	req, pxy, err := amber.NewControlInterfaceRequest()
	if err != nil {
		panic(err)
	}
	ctx.ConnectToEnvService(req)
	return pxy, req
}

func addSource(a *amber.ControlInterface) error {
	var cfg amber.SourceConfig

	if len(*pkgFile) == 0 {
		return fmt.Errorf("a url or file path (via -f) are required")
	}

	var source io.Reader
	url, err := url.Parse(*pkgFile)
	isURL := false
	if err == nil && url.IsAbs() {
		isURL = true
		var expectedHash []byte
		hash := strings.TrimSpace(*hash)
		if len(hash) != 0 {

			var err error
			expectedHash, err = hex.DecodeString(hash)
			if err != nil {
				return fmt.Errorf("hash is not a hex encoded string: %v", err)
			}
		}

		resp, err := http.Get(*pkgFile)
		if err != nil {
			return NewErrGetFile("failed to GET file", err)
		}
		defer resp.Body.Close()
		if resp.StatusCode != 200 {
			io.Copy(ioutil.Discard, resp.Body)
			return fmt.Errorf("GET response: %v", resp.Status)
		}

		body, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			return fmt.Errorf("failed to read file body: %v", err)
		}

		if len(expectedHash) != 0 {
			hasher := sha256.New()
			hasher.Write(body)
			actualHash := hasher.Sum(nil)

			if !bytes.Equal(expectedHash, actualHash) {
				return fmt.Errorf("hash of config file does not match!")
			}
		}

		source = bytes.NewReader(body)

	} else {
		f, err := os.Open(*pkgFile)
		if err != nil {
			return fmt.Errorf("failed to open file: %v", err)
		}
		defer f.Close()

		source = f
	}

	if err := json.NewDecoder(source).Decode(&cfg); err != nil {
		return fmt.Errorf("failed to parse source config: %v", err)
	}

	if *name != "" {
		cfg.Id = *name
	}

	// Update the host segment of the URL with the original if it appears to have
	// only been de-scoped, so that link-local configurations retain ipv6 scopes.
	if isURL {
		if remote, err := url.Parse(cfg.RepoUrl); err == nil {
			if u := urlscope.Rescope(url, remote); u != nil {
				cfg.RepoUrl = u.String()
			}
		}
		if remote, err := url.Parse(cfg.BlobRepoUrl); err == nil {
			if u := urlscope.Rescope(url, remote); u != nil {
				cfg.BlobRepoUrl = u.String()
			}
		}
	}

	if cfg.BlobRepoUrl == "" {
		cfg.BlobRepoUrl = filepath.Join(cfg.RepoUrl, "blobs")
	}

	added, err := a.AddSrc(cfg)
	if !added {
		return fmt.Errorf("request arguments properly formatted, but possibly otherwise invalid")
	}
	if err != nil {
		return fmt.Errorf("IPC encountered an error: %s", err)
	}

	if isSourceConfigEnabled(&cfg) && !*nonExclusive {
		if err := disableAllSources(a, cfg.Id); err != nil {
			return err
		}
	}

	return nil
}

func rmSource(a *amber.ControlInterface) error {
	name := strings.TrimSpace(*name)
	if name == "" {
		return fmt.Errorf("no source id provided")
	}

	status, err := a.RemoveSrc(name)
	if err != nil {
		return fmt.Errorf("IPC encountered an error: %s", err)
	}
	switch status {
	case amber.StatusOk:
		return nil
	case amber.StatusErrNotFound:
		return fmt.Errorf("Source not found")
	case amber.StatusErr:
		return fmt.Errorf("Unspecified error")
	default:
		return fmt.Errorf("Unexpected status: %v", status)
	}
}

func getUp(a *amber.ControlInterface) error {
	if *name == "" {
		return fmt.Errorf("no source id provided")
	}
	if *noWait {
		c, err := a.GetUpdateComplete(*name, version, merkle)
		if err != nil {
			return fmt.Errorf("Error getting update %s\n", err)
		}
		c.Close()

		fmt.Printf("Update requested %s\n", *name)
		return nil
	}

	var err error
	for i := 0; i < 3; i++ {
		err = getUpdateComplete(a, *name, version, merkle)
		if err == nil {
			break
		}
		fmt.Printf("Update failed with error %s, retrying...\n", err)
		time.Sleep(2 * time.Second)
	}
	return err
}

func listSources(a *amber.ControlInterface) error {
	srcs, err := a.ListSrcs()
	if err != nil {
		fmt.Printf("failed to list sources: %s\n", err)
		return err
	}

	for _, src := range srcs {
		encoder := json.NewEncoder(os.Stdout)
		encoder.SetIndent("", "    ")
		if err := encoder.Encode(src); err != nil {
			fmt.Printf("failed to encode source into json: %s\n", err)
			return err
		}
	}

	return nil
}

func setSourceEnablement(a *amber.ControlInterface, id string, enabled bool) error {
	status, err := a.SetSrcEnabled(id, enabled)
	if err != nil {
		return fmt.Errorf("call failure attempting to change source status: %s", err)
	}
	if status != amber.StatusOk {
		return fmt.Errorf("failure changing source status")
	}

	return nil
}

func isSourceConfigEnabled(cfg *amber.SourceConfig) bool {
	if cfg.StatusConfig == nil {
		return true
	}
	return cfg.StatusConfig.Enabled
}

func disableAllSources(a *amber.ControlInterface, except string) error {
	errorIds := []string{}
	cfgs, err := a.ListSrcs()
	if err != nil {
		return err
	}
	for _, cfg := range cfgs {
		if cfg.Id != except && isSourceConfigEnabled(&cfg) {
			if err := setSourceEnablement(a, cfg.Id, false); err != nil {
				log.Printf("error disabling %q: %s", cfg.Id, err)
				errorIds = append(errorIds, fmt.Sprintf("%q", cfg.Id))
			} else {
				fmt.Printf("Source %q disabled\n", cfg.Id)
			}
		}
	}
	if len(errorIds) > 0 {
		return fmt.Errorf("could not disable %s", strings.Join(errorIds, ", "))
	}
	return nil
}

func printState(proxy *amber.ControlInterface) error {
	rd, wr, e := zx.NewChannel(0)
	if e != nil {
		return fmt.Errorf("channel creation failed: %s", e)
	}

	err := proxy.GetProcessState(wr)
	if err != nil {
		log.Printf("Error getting state from service")
		return err
	}

	defer rd.Close()
	b := make([]byte, 64*1024)
	for {
		sigs, err := zxwait.Wait(*rd.Handle(), zx.SignalChannelReadable|zx.SignalChannelPeerClosed,
			zx.TimensecInfinite)
		if err != nil {
			log.Printf("Unexpected error waiting on channel: %s", err)
			return NewErrDaemon(
				fmt.Sprintf("unknown error while waiting for response from channel: %s", err))
		}

		if sigs&zx.SignalChannelReadable != 0 {
			sz, _, err := rd.Read(b, []zx.Handle{}, 0)
			if err != nil {
				return NewErrDaemon(fmt.Sprintf("error reading channel: %s", err))
			}
			fmt.Printf(string(b[:sz]))
			continue
		}

		if sigs&zx.SignalChannelPeerClosed != 0 {
			break
		}
	}
	return nil
}

func do(proxy *amber.ControlInterface) int {
	switch os.Args[1] {
	case "get_up":
		if err := getUp(proxy); err != nil {
			log.Printf("error getting an update: %s", err)
			return 1
		}
	case "get_blob":
		if *blobID == "" {
			log.Printf("no blob id provided")
			return 1
		}
		if err := proxy.GetBlob(*blobID); err != nil {
			log.Printf("error requesting blob fetch: %s", err)
			return 1
		}
	case "add_src":
		if err := addSource(proxy); err != nil {
			log.Printf("error adding source: %s", err)
			if _, ok := err.(ErrGetFile); ok {
				return 2
			} else {
				return 1
			}
		}
	case "rm_src":
		if err := rmSource(proxy); err != nil {
			log.Printf("error removing source: %s", err)
			return 1
		}
	case "list_srcs":
		if err := listSources(proxy); err != nil {
			log.Printf("error listing sources: %s", err)
			return 1
		}
	case "check":
		log.Printf("%q not yet supported\n", os.Args[1])
		return 1
	case "test":
		if err := doTest(proxy); err != nil {
			log.Printf("error testing connection to amber: %s", err)
			return 1
		}
	case "system_update":
		configured, err := proxy.CheckForSystemUpdate()
		if err != nil {
			log.Printf("error checking for system update: %s", err)
			return 1
		}

		if configured {
			fmt.Printf("triggered a system update check\n")
		} else {
			fmt.Printf("system update is not configured\n")
		}
	case "login":
		device, err := proxy.Login(*name)
		if err != nil {
			log.Printf("failed to login: %s", err)
			return 1
		}
		fmt.Printf("On your computer go to:\n\n\t%v\n\nand enter\n\n\t%v\n\n", device.VerificationUrl, device.UserCode)
	case "enable_src":
		if *name == "" {
			log.Printf("Error enabling source: no source id provided")
			return 1
		}
		err := setSourceEnablement(proxy, *name, true)
		if err != nil {
			log.Printf("Error enabling source: %s", err)
			return 1
		}
		fmt.Printf("Source %q enabled\n", *name)
		if !*nonExclusive {
			if err := disableAllSources(proxy, *name); err != nil {
				log.Printf("Error disabling sources: %s", err)
				return 1
			}
		}
	case "disable_src":
		if *name == "" {
			log.Printf("Error disabling source: no source id provided")
			return 1
		}
		err := setSourceEnablement(proxy, *name, false)
		if err != nil {
			log.Printf("Error disabling source: %s", err)
			return 1
		}
		fmt.Printf("Source %q disabled\n", *name)
	case "gc":
		err := proxy.Gc()
		if err != nil {
			log.Printf("Error collecting garbage: %s", err)
			return 1
		}
		log.Printf("Started garbage collection. See logs for details")
	case "print_state":
		err := printState(proxy)
		if err != nil {
			log.Printf("Error printing process state: %s", err)
			return 1
		}
	default:

		log.Printf("Error, %q is not a recognized command\n%s",
			os.Args[1], usage)
		return -1
	}

	return 0
}

func Main() {
	if len(os.Args) < 2 {
		fmt.Printf("Error: no command provided\n%s\n", usage)
		os.Exit(-1)
	}

	fs.Parse(os.Args[2:])

	proxy, _ := connect(context.CreateFromStartupInfo())
	defer proxy.Close()

	os.Exit(do(proxy))
}

type ErrDaemon string

func NewErrDaemon(str string) ErrDaemon {
	return ErrDaemon(fmt.Sprintf("amber_ctl: daemon error: %s", str))
}

func (e ErrDaemon) Error() string {
	return string(e)
}

func getUpdateComplete(proxy *amber.ControlInterface, name string, version *string, merkle *string) error {
	c, err := proxy.GetUpdateComplete(name, version, merkle)
	if err != nil {
		return NewErrDaemon(fmt.Sprintf("error making FIDL request: %s", err))
	}

	defer c.Close()
	b := make([]byte, 64*1024)
	for {
		sigs, err := zxwait.Wait(*c.Handle(),
			zx.SignalChannelPeerClosed|zx.SignalChannelReadable,
			zx.Sys_deadline_after(zx.Duration((3 * time.Second).Nanoseconds())))

		if err != nil {
			if zerr, ok := err.(zx.Error); ok && zerr.Status == zx.ErrTimedOut {
				log.Println("Awaiting response...")
				continue
			}
			return NewErrDaemon(
				fmt.Sprintf("unknown error while waiting for response from channel: %s", err))
		}

		if sigs&zx.SignalChannelReadable != 0 {
			bs, _, err := c.Read(b, []zx.Handle{}, 0)
			if err != nil {
				return NewErrDaemon(
					fmt.Sprintf("error reading response from channel: %s", err))
			}

			if sigs&zx.SignalUser0 != 0 {
				return NewErrDaemon(string(b[:bs]))
			}

			pkgname := name
			if version != nil {
				pkgname = filepath.Join(pkgname, *version)
			}
			if merkle != nil {
				pkgname = filepath.Join(pkgname, *merkle)
			}
			log.Printf("Success %s: %s", pkgname, string(b[:bs]))
			return nil
		}

		if sigs&zx.SignalChannelPeerClosed != 0 {
			return NewErrDaemon("response channel closed unexpectedly.")
		}
	}
}
