// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can
// found in the LICENSE file.

package main

import (
	"context"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"os/exec"
	"os/signal"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	devicePkg "go.fuchsia.dev/infra/devices"
)

const (
	// failureExitCode is the exit code reported when setup fails or a
	// subprocess cannot be run.
	failureExitCode = 1

	// portEnvVar names the environment variable that holds the port for the
	// NUC reboot server; when it is unset the server is not started.
	portEnvVar = "NUC_SERVER_PORT"

	// chainloaderExitStr is the iPXE script served to a NUC to make it exit
	// iPXE and boot from disk.
	chainloaderExitStr = "#!ipxe\necho Catalyst chain loader\nexit 0\n"

	// rebootDuration is how long to wait after preparing the devices before
	// running the test subcommand.
	rebootDuration = 15 * time.Second

	// zedbootPath is the (working-directory-relative) path of the zedboot
	// image that the NUC server hands out.
	zedbootPath = "zedboot.zbi"
)

// Values of the command-line flags, filled in by flag.Parse.
var (
	// Paths supplied on the command line.
	imagesManifest   string // -images: images manifest JSON
	deviceConfigPath string // -config: device config file; empty skips device setup
	bootserverPath   string // -bootserver: bootserver binary
	credentialsPath  string // -credentials: service account JSON

	// Behavior toggles.
	disableRebootServer bool // -disable-reboot-server: don't start the NUC server
	enableFastboot      bool // -enable-fastboot: flash astro/sherlock via fastboot
)

func init() {
	flag.StringVar(&imagesManifest, "images", "", "Path to the images manifest json")
	flag.StringVar(&deviceConfigPath, "config", "", "Path to the device config file")
	flag.StringVar(&bootserverPath, "bootserver", "", "Path to the bootserver binary")
	flag.StringVar(&credentialsPath, "credentials", "", "Path to the service account json to use")
	flag.BoolVar(&disableRebootServer, "disable-reboot-server", false, "Disables the NUC soft reboot server")
	flag.BoolVar(&enableFastboot, "enable-fastboot", false, "Enable fastboot flashing")
}

// runSubprocess runs command (command[0] is the program, the rest are its
// arguments) with stdout and stderr connected to this process's, and blocks
// until it exits. If ctx is canceled while the subprocess is running (or while
// it is being started), the subprocess is sent SIGTERM.
//
// It returns 0 for an empty command; failureExitCode if ctx is already done or
// the subprocess cannot be started; otherwise the subprocess's exit code as
// reported by os.ProcessState.ExitCode (-1 if it was killed by a signal).
func runSubprocess(ctx context.Context, command []string) int {
	if len(command) == 0 {
		return 0
	}

	// Don't start anything if the caller has already given up.
	if ctx.Err() != nil {
		log.Print("context exited before starting subprocess")
		return failureExitCode
	}

	cmd := exec.Command(command[0], command[1:]...)
	cmd.Stderr = os.Stderr
	cmd.Stdout = os.Stdout

	// processMu guards cmd.Process (written by cmd.Start, read by the
	// cancellation goroutine) and canceled.
	var (
		processMu sync.Mutex
		canceled  bool // set once the goroutine has observed ctx.Done()
	)
	// terminate must be called with processMu held and cmd.Process non-nil.
	terminate := func() {
		if err := cmd.Process.Signal(syscall.SIGTERM); err != nil {
			log.Printf("exited cmd with error %v", err)
		}
	}

	// Forward cancellation to the subprocess as SIGTERM so it can exit
	// cleanly. processDone is closed on every return path, so this goroutine
	// never outlives the call.
	processDone := make(chan struct{})
	defer close(processDone)
	go func() {
		select {
		case <-processDone:
		case <-ctx.Done():
			processMu.Lock()
			defer processMu.Unlock()
			canceled = true
			// cmd.Process is nil if Start has not run yet (handled after Start
			// below) or failed.
			if cmd.Process != nil {
				terminate()
			}
		}
	}()

	processMu.Lock()
	err := cmd.Start()
	if err == nil && canceled {
		// ctx was canceled after the check above but before the process
		// existed, so the goroutine could not signal it; do it now.
		terminate()
	}
	processMu.Unlock()
	if err != nil {
		log.Printf("err starting subprocess: %v", err)
		return failureExitCode
	}

	if err := cmd.Wait(); err != nil {
		log.Printf("err running subprocess: %v", err)
	}
	return cmd.ProcessState.ExitCode()
}

// runNUCServer starts, in the background, an HTTP server that serves
// per-device boot resources to NUCs. For each device it registers:
//   - /<mac>.ipxe: an iPXE script (chainloaderExitStr) that makes the NUC exit
//     iPXE and boot from disk.
//   - /zedboot/<mac>: the image at zedbootPath, served at most once; later
//     requests, or any request when the image could not be opened, get 404.
//
// The port is read from the portEnvVar environment variable. If that variable
// is unset or disabled is true, no server is started and (nil, nil) is
// returned. An unparsable port is returned as an error. ctx is currently
// unused.
func runNUCServer(ctx context.Context, devices []*devicePkg.DeviceTarget, disabled bool) (*http.Server, error) {
	portStr := os.Getenv(portEnvVar)
	if portStr == "" || disabled {
		return nil, nil
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		return nil, fmt.Errorf("invalid %s value %q: %v", portEnvVar, portStr, err)
	}

	mux := http.NewServeMux()
	for _, device := range devices {
		// Resolve the MAC now. The handlers run long after the loop finishes,
		// and before Go 1.22 every closure would share the loop variable and
		// see only the last device.
		mac := device.Mac()

		// Exit-chainloader endpoint: lets a NUC reboot from disk.
		mux.HandleFunc(fmt.Sprintf("/%s.ipxe", mac), func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s.ipxe", mac))
			if _, err := io.Copy(w, strings.NewReader(chainloaderExitStr)); err != nil {
				log.Printf("failed to serve chainloader to %s: %v", mac, err)
			}
		})

		// Single-use endpoint delivering this build's zedboot image. Each
		// device gets its own handle on the file.
		zedbootFile, err := os.Open(zedbootPath)
		if err != nil {
			// The image is optional; without it the endpoint serves 404.
			zedbootFile = nil
		}
		// HTTP handlers run concurrently, so handing out the single-use file
		// must be synchronized.
		var zedbootMu sync.Mutex
		mux.HandleFunc(fmt.Sprintf("/zedboot/%s", mac), func(w http.ResponseWriter, r *http.Request) {
			// Take ownership of the file so exactly one request serves it.
			zedbootMu.Lock()
			f := zedbootFile
			zedbootFile = nil
			zedbootMu.Unlock()
			if f == nil {
				w.WriteHeader(http.StatusNotFound)
				return
			}
			defer f.Close()
			w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", mac))
			if _, err := io.Copy(w, f); err != nil {
				log.Printf("failed to serve zedboot to %s: %v", mac, err)
			}
		})
	}

	srv := &http.Server{
		Addr:    fmt.Sprintf(":%d", port),
		Handler: mux,
	}
	go func() {
		// ErrServerClosed is the normal result of Shutdown/Close; anything
		// else (e.g. the port is already in use) would otherwise go unnoticed.
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			log.Printf("NUC server exited: %v", err)
		}
	}()
	return srv, nil
}

// runFastbootFlash flashes files onto device via fastboot and returns the exit
// code of the flashing subprocess. The deadline starts when this function is
// called; once it passes, runSubprocess sends the subprocess SIGTERM.
func runFastbootFlash(ctx context.Context, device *devicePkg.DeviceTarget, files []string) int {
	const flashTimeout = time.Minute
	flashCtx, stop := context.WithTimeout(ctx, flashTimeout)
	defer stop()
	// NOTE(review): BuildFastbootCmd receives the parent ctx, not flashCtx, so
	// the timeout is applied only to the flash subprocess's context.
	return runSubprocess(flashCtx, BuildFastbootCmd(ctx, device.FastbootSernum(), files))
}

// runBootservers prepares every device for the test run:
//   - "nuc" devices are power-cycled.
//   - "astro"/"sherlock" devices are flashed via fastboot when -enable-fastboot
//     is set; a failed flash is logged but does not affect the return value.
//   - every other device has its BootserverCmd run, all concurrently.
//
// It waits for all bootservers and returns 0 if each exited with 0, or
// failureExitCode otherwise.
func runBootservers(ctx context.Context, devices []*devicePkg.DeviceTarget, files []string) int {
	results := make(chan int)
	launched := 0
	for _, d := range devices {
		switch t := d.Type(); {
		case t == "nuc":
			d.Powercycle(ctx)
		case enableFastboot && (t == "astro" || t == "sherlock"):
			code := runFastbootFlash(ctx, d, files)
			if code != 0 {
				log.Printf("fastboot flash failed for %s: exit code: %d", d.Nodename(), code)
			}
		default:
			launched++
			go func(target *devicePkg.DeviceTarget) {
				results <- runSubprocess(ctx, target.BootserverCmd)
			}(d)
		}
	}

	// Collect one result per launched bootserver.
	failed := 0
	for ; launched > 0; launched-- {
		if code := <-results; code != 0 {
			log.Printf("bootserver exited with exit code: %d\n", code)
			failed++
		}
	}

	if failed == 0 {
		return 0
	}
	return failureExitCode
}

// execute prepares the test bench, then runs subcommandArgs. It returns the
// exit code main should exit with, plus any setup error.
//
// Without -config (e.g. a QEMU test bench) it only runs the subcommand.
// Otherwise it:
//   - downloads images;
//   - creates device targets from the config;
//   - starts the NUC reboot server if enabled;
//   - prepares the devices with runBootservers;
//   - waits rebootDuration, then runs the subcommand.
//
// The NUC server, if started, is shut down before returning.
func execute(ctx context.Context, subcommandArgs []string) (int, error) {
	// Upper bound on how long to wait for in-flight NUC server requests.
	const nucServerShutdownTimeout = 5 * time.Second

	// If this is a QEMU test bench, skip the device setup and just run the subprocess.
	if deviceConfigPath == "" {
		return runSubprocess(ctx, subcommandArgs), nil
	}

	// Download any needed images.
	files, err := DownloadImages(ctx, credentialsPath)
	if err != nil {
		return failureExitCode, err
	}

	// Contains all necessary bootserver flags except device nodename.
	bootserverCmdStub := []string{
		bootserverPath,
		"--images", imagesManifest,
		"--mode", "pave-zedboot",
		"-n",
	}

	// Create devicePkg.DeviceTargets for each of the devices in the config file.
	devices, err := devicePkg.CreateDeviceTargets(ctx, deviceConfigPath, bootserverCmdStub)
	if err != nil {
		return failureExitCode, err
	}

	// Set up a NUC server. If the port environment variable is not set, or if
	// the disableRebootServer flag is set, this is a no-op.
	srv, err := runNUCServer(ctx, devices, disableRebootServer)
	if err != nil {
		return failureExitCode, err
	}

	// Clean up after tests.
	defer func() {
		if srv == nil {
			return
		}
		// ctx may already be canceled here (e.g. after SIGTERM), which would
		// make Shutdown give up immediately, so use a fresh bounded context.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), nucServerShutdownTimeout)
		defer cancel()
		if err := srv.Shutdown(shutdownCtx); err != nil {
			log.Printf("NUC server shutdown failed: %v", err)
			// Fall back to closing any connections that are still open.
			if err := srv.Close(); err != nil {
				log.Printf("NUC server close failed: %v", err)
			}
		}
	}()

	if exitCode := runBootservers(ctx, devices, files); exitCode != 0 {
		return exitCode, nil
	}

	// Give the devices time to reboot, but stop waiting as soon as ctx is
	// canceled; runSubprocess then reports the failure without running.
	select {
	case <-time.After(rebootDuration):
	case <-ctx.Done():
	}

	// Execute the passed in subcommand.
	return runSubprocess(ctx, subcommandArgs), nil
}

// main parses flags, arranges for SIGTERM/SIGINT to cancel the run, expands
// environment variables in the subcommand, and exits with the exit code
// returned by execute.
func main() {
	flag.Parse()

	// Cancel ctx on SIGTERM or SIGINT so running subprocesses are told to stop.
	ctx, cancel := context.WithCancel(context.Background())
	// signal.Notify never blocks when delivering, so the channel must be
	// buffered; otherwise a signal arriving before the receive is dropped.
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
	go func() {
		select {
		case <-signals:
			cancel()
		case <-ctx.Done():
		}
	}()

	// Resolve environment variables within the subcommand's command line, as
	// Swarming does not do it automatically, and best to have this done once and
	// uniformly across all possible test task commands.
	args := flag.Args()
	subcmdArgs := make([]string, 0, len(args))
	for _, arg := range args {
		subcmdArgs = append(subcmdArgs, os.ExpandEnv(arg))
	}
	exitCode, err := execute(ctx, subcmdArgs)

	// os.Exit does not run deferred calls, so release resources explicitly.
	// cancel also lets the signal goroutine above return.
	signal.Stop(signals)
	cancel()

	if err != nil {
		log.Printf("Exit code: %d, Err: %s\n", exitCode, err)
	}
	os.Exit(exitCode)
}
