// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can
// found in the LICENSE file.

package main

import (
	"context"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"os/signal"
	"path/filepath"
	"strconv"
	"strings"
	"syscall"
	"time"

	"cloud.google.com/go/storage"
	devicePkg "go.fuchsia.dev/infra/devices"
)

const (
	// failureExitCode is the exit code reported for failures: setup errors,
	// failed bootservers, or a subprocess that is not run because the context
	// is already done.
	failureExitCode = 1

	// portEnvVar names the environment variable holding the NUC server port.
	portEnvVar = "NUC_SERVER_PORT"

	// chainloaderExitStr is the iPXE script served to NUCs; it makes iPXE exit
	// so that the NUC boots from disk.
	chainloaderExitStr = "#!ipxe\necho Catalyst chain loader\nexit 0\n"

	// rebootDuration is how long to wait after the bootservers finish before
	// running the test subcommand.
	rebootDuration = time.Minute

	// zedbootPath is the local file name of the zedboot image; it is also the
	// object name looked up next to the images manifest in GCS.
	zedbootPath = "zedboot.zbi"
)

// Values of the -images, -config and -bootserver command-line flags
// (registered in init, populated by flag.Parse in main). All are empty
// unless set on the command line. imagesManifest may also be a gs:// URL
// (see downloadZedboot); an empty deviceConfigPath means no devices are set up.
var imagesManifest, deviceConfigPath, bootserverPath string

// init registers the string command-line flags; each defaults to "".
func init() {
	stringFlags := []struct {
		dest  *string
		name  string
		usage string
	}{
		{&imagesManifest, "images", "Path to the images manifest json"},
		{&deviceConfigPath, "config", "Path to the device config file"},
		{&bootserverPath, "bootserver", "Path to the bootserver binary"},
	}
	for _, f := range stringFlags {
		flag.StringVar(f.dest, f.name, "", f.usage)
	}
}

// runSubprocess runs command (argv form: command[0] is the program) with the
// current process's stdout and stderr attached, and returns its exit code.
//
// An empty command is a no-op returning 0. If ctx is already done, nothing is
// started and failureExitCode is returned; a command that cannot be started
// also yields failureExitCode. If ctx is canceled while the subprocess runs,
// the subprocess is sent SIGTERM. Per os.ProcessState.ExitCode, the result is
// -1 when the subprocess was terminated by a signal.
func runSubprocess(ctx context.Context, command []string) int {
	if len(command) == 0 {
		return 0
	}
	// Ensure that the context still exists before running the subprocess.
	if ctx.Err() != nil {
		return failureExitCode
	}

	cmd := exec.Command(command[0], command[1:]...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	// Start before spawning the signal forwarder: cmd.Process is nil until the
	// process exists, so signalling any earlier would dereference nil.
	if err := cmd.Start(); err != nil {
		log.Printf("failed to start %q: %v", command[0], err)
		return failureExitCode
	}

	// Forward context cancellation to the subprocess as SIGTERM. Closing
	// processDone after Wait returns lets this goroutine exit.
	processDone := make(chan struct{})
	go func() {
		select {
		case <-processDone:
		case <-ctx.Done():
			if err := cmd.Process.Signal(syscall.SIGTERM); err != nil {
				log.Printf("exited cmd with error %v", err)
			}
		}
	}()

	// A non-zero exit is also reported through Wait's error; callers receive
	// it via ProcessState.ExitCode below, so the error value itself is unused.
	_ = cmd.Wait()
	close(processDone)
	return cmd.ProcessState.ExitCode()
}

func runNUCServer(ctx context.Context, devices []*devicePkg.DeviceTarget) (*http.Server, error) {
	// Parse the port from the environment variable.
	portStr := os.Getenv(portEnvVar)
	if portStr == "" {
		return nil, nil
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		return nil, err
	}

	mux := http.NewServeMux()
	for _, device := range devices {
		// Add endpoint for exit chainloader. This allows NUCs to reboot from disk.
		mux.HandleFunc(fmt.Sprintf("/%s.ipxe", device.Mac()), func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s.ipxe", device.Mac()))
			chainloader := strings.NewReader(chainloaderExitStr)
			io.Copy(w, chainloader)
		})

		// Add single-use endpoint to deliver build's version of zedboot.
		zedbootFile, err := os.Open(zedbootPath)
		if err != nil {
			zedbootFile = nil
		}
		mux.HandleFunc(fmt.Sprintf("/zedboot/%s", device.Mac()), func(w http.ResponseWriter, r *http.Request) {
			if zedbootFile == nil {
				w.WriteHeader(http.StatusNotFound)
				return
			}
			w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", device.Mac()))
			io.Copy(w, zedbootFile)
			zedbootFile.Close()
			zedbootFile = nil
		})
	}

	srv := &http.Server{
		Addr:    fmt.Sprintf(":%d", port),
		Handler: mux,
	}
	go srv.ListenAndServe()
	return srv, nil
}

// runBootservers prepares every device for the new build.
//
// Devices whose Type() is not "nuc" get their BootserverCmd run concurrently
// via runSubprocess. NUCs are instead power-cycled in place, because rebooting
// lets iPXE retrieve the new version of zedboot. After all bootservers finish,
// it returns 0 if every one exited 0. Otherwise it logs each non-zero code and
// returns failureExitCode.
func runBootservers(ctx context.Context, devices []*devicePkg.DeviceTarget) int {
	results := make(chan int)
	pending := 0
	for _, d := range devices {
		if d.Type() == "nuc" {
			d.Powercycle(ctx)
			continue
		}
		pending++
		go func(target *devicePkg.DeviceTarget) {
			results <- runSubprocess(ctx, target.BootserverCmd)
		}(d)
	}

	// Drain one result per launched bootserver so no goroutine is left blocked.
	failed := false
	for ; pending > 0; pending-- {
		code := <-results
		if code == 0 {
			continue
		}
		log.Printf("bootserver exited with exit code: %d\n", code)
		failed = true
	}

	if failed {
		return failureExitCode
	}
	return 0
}

// hasNUC reports whether at least one of devices has Type() "nuc". Devices
// are examined in order, and the scan stops at the first match.
func hasNUC(devices []*devicePkg.DeviceTarget) bool {
	found := false
	for i := 0; !found && i < len(devices); i++ {
		found = devices[i].Type() == "nuc"
	}
	return found
}

// downloadZedboot copies zedboot.zbi from GCS into zedbootPath in the working
// directory, where runNUCServer serves it to NUCs.
//
// It is a no-op returning nil when devices contains no NUC or when imageURL is
// not a gs:// URL. Otherwise it reads the object zedbootPath located in the
// same GCS directory as the images manifest (bucket imageURL.Host). On any
// failure after the local file is created, the partial file is removed.
func downloadZedboot(ctx context.Context, imageURL *url.URL, devices []*devicePkg.DeviceTarget) error {
	// Ensure that we have a NUC that needs a local zedboot.zbi.
	if !hasNUC(devices) {
		return nil
	}
	// If we are not getting images from GCS, then this is a no-op.
	if imageURL.Scheme != "gs" {
		return nil
	}

	client, err := storage.NewClient(ctx)
	if err != nil {
		return err
	}
	// Release the client's resources once the download is finished.
	defer client.Close()
	bkt := client.Bucket(imageURL.Host)

	// Construct the object name next to the images manifest. Object names
	// carry no leading slash. Join also cleans the result, e.g. an empty URL
	// path yields "zedboot.zbi" rather than "./zedboot.zbi". NOTE: filepath
	// assumes the host uses '/' as its separator.
	gcsZedbootPath := strings.TrimLeft(filepath.Join(filepath.Dir(imageURL.Path), zedbootPath), "/")

	r, err := bkt.Object(gcsZedbootPath).NewReader(ctx)
	if err != nil {
		return err
	}
	defer r.Close()

	file, err := os.Create(zedbootPath)
	if err != nil {
		return err
	}
	_, copyErr := io.Copy(file, r)
	// Close explicitly and check it: some write errors are reported only at Close.
	closeErr := file.Close()
	if copyErr == nil {
		copyErr = closeErr
	}
	if copyErr != nil {
		// Best effort: don't leave a truncated image for the NUC server to serve.
		os.Remove(zedbootPath)
		return copyErr
	}
	return nil
}

// execute runs the test subcommand, first preparing hardware devices when a
// device config (-config) is given.
//
// Without a device config (a QEMU test bench) it only runs subcommandArgs.
// Otherwise it:
//   - creates device targets from the config;
//   - downloads zedboot if needed;
//   - starts the NUC server if NUC_SERVER_PORT is set;
//   - runs bootservers / power-cycles NUCs;
//   - waits rebootDuration;
//   - runs the subcommand.
//
// The wait ends early if ctx is canceled.
//
// It returns the exit code to report. Setup failures return failureExitCode
// together with the error. The NUC server, if started, is shut down before
// returning.
func execute(ctx context.Context, subcommandArgs []string) (int, error) {
	// If this is a QEMU test bench, skip the device setup and just run the subprocess.
	if deviceConfigPath == "" {
		return runSubprocess(ctx, subcommandArgs), nil
	}
	// Contains all necessary bootserver flags except device nodename.
	bootserverCmdStub := []string{
		bootserverPath,
		"--images", imagesManifest,
		"--mode", "pave-zedboot",
		"-n",
	}

	// Create devicePkg.DeviceTargets for each of the devices in the config file
	devices, err := devicePkg.CreateDeviceTargets(ctx, deviceConfigPath, bootserverCmdStub)
	if err != nil {
		return failureExitCode, err
	}

	// If there is a NUC among the device targets, download zedboot. Is a no-op if no NUCs
	// exist on this testbed.
	imageURL, err := url.Parse(imagesManifest)
	if err != nil {
		return failureExitCode, err
	}
	if err := downloadZedboot(ctx, imageURL, devices); err != nil {
		return failureExitCode, err
	}

	// Set up a NUC server. If the port environment variable is not set, this is a no-op.
	srv, err := runNUCServer(ctx, devices)
	if err != nil {
		return failureExitCode, err
	}

	// Clean up after tests
	defer func() {
		if srv != nil {
			if err := srv.Shutdown(ctx); err != nil {
				log.Printf("NUC server shutdown failed: %v", err)
			}
		}
	}()

	if exitCode := runBootservers(ctx, devices); exitCode != 0 {
		return exitCode, nil
	}

	// Give the devices time to reboot, but stop waiting as soon as ctx is
	// canceled (e.g. SIGTERM). runSubprocess then sees the canceled context
	// and returns the same result it would have after a full wait.
	wait := time.NewTimer(rebootDuration)
	select {
	case <-wait.C:
	case <-ctx.Done():
		wait.Stop()
	}
	// Execute the passed in subcommand
	return runSubprocess(ctx, subcommandArgs), nil
}

// main parses flags and arranges for SIGTERM/SIGINT to cancel the run. It then
// runs the (environment-expanded) subcommand through execute and exits with
// its exit code.
func main() {
	// Initialize
	flag.Parse()

	// Handle SIGTERM
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// signal.Notify never blocks when delivering a signal. An unbuffered
	// channel can therefore drop a signal that arrives before the receiver is
	// ready, so give it room for one.
	signals := make(chan os.Signal, 1)
	defer func() {
		signal.Stop(signals)
		close(signals)
	}()
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)

	go func() {
		select {
		case <-signals:
			cancel()
		case <-ctx.Done():
		}
	}()

	// Resolve environment variables within the subcommand's command line, as
	// Swarming does not do it automatically, and best to have this done once and
	// uniformly across all possible test task commands.
	var subcmdArgs []string
	for _, arg := range flag.Args() {
		subcmdArgs = append(subcmdArgs, os.ExpandEnv(arg))
	}
	exitCode, err := execute(ctx, subcmdArgs)

	if err != nil || exitCode != 0 {
		// %v rather than %s: err is nil when only the subcommand failed, and
		// %s would render that as "%!s(<nil>)".
		log.Printf("Exit code: %d, Err: %v\n", exitCode, err)
	}
	// NOTE: os.Exit does not run the deferred calls above; the process is
	// exiting, so the cancel/signal cleanup is moot.
	os.Exit(exitCode)
}
