// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"os/exec"
	"time"

	devicePkg "go.fuchsia.dev/infra/devices"
	"go.fuchsia.dev/tools/netboot"
)

// usage is the tool description printed by flag.Usage, ahead of the flag
// defaults.
const usage = `usage: health_checker [options]

Checks the health of each device described by the config file. A healthy
device should be running in Zedboot: its netsvc address must be discoverable
and pingable, while its Fuchsia address must not answer a ping. NUCs with an
attached serial line must also answer an echo command over serial, and
non-NUC devices must be broadcasting netboot beacons.

With -reboot, unhealthy devices are power-cycled (falling back to a restart).
With -force-reboot, health checks are skipped and every device is restarted.
`

// Command line flag values, populated by flag.Parse from the flags registered
// in init.
var (
	// configFile is the -config path of the JSON device config file.
	configFile string

	// timeout is the -timeout value handed to netboot.NewClient.
	timeout time.Duration

	// rebootIfUnhealthy (-reboot) reboots devices that fail a health check.
	rebootIfUnhealthy bool

	// forceReboot (-force-reboot) skips health checks and restarts every device.
	forceReboot bool
)

// healthyState and unhealthyState are the values reported in
// HealthCheckResult.State. logFile is the file that log output is appended to
// once main redirects logging.
const (
	healthyState   = "healthy"
	unhealthyState = "unhealthy"
	logFile        = "/tmp/health_checker.log"
)

// HealthCheckResult is the outcome of health-checking a single device.
// printHealthCheckResults emits a slice of these as JSON.
type HealthCheckResult struct {
	// Nodename is the hostname of the device that was checked.
	Nodename string `json:"nodename"`

	// State is the health status of the device: healthyState ("healthy") or
	// unhealthyState ("unhealthy").
	State string `json:"state"`

	// ErrorMsg is the error message from the first failing health check, with
	// any reboot failures appended by main. It is empty for healthy devices.
	ErrorMsg string `json:"error_msg"`
}

// pingZedboot asks netboot for the device's netsvc address and sends it a
// single IPv6 ping. It returns an error if discovery fails or the ping exits
// unsuccessfully; the underlying cause is wrapped so callers can inspect it.
//
// The capitalized, punctuated messages are kept verbatim because they end up
// in the JSON report's error_msg field.
func pingZedboot(n *netboot.Client, nodename string) error {
	netsvcAddr, err := n.Discover(nodename, false)
	if err != nil {
		return fmt.Errorf("Failed to discover netsvc addr: %w.", err)
	}
	netsvcIPAddr := &net.IPAddr{IP: netsvcAddr.IP, Zone: netsvcAddr.Zone}
	// Only ping's exit status matters, so its output is not captured.
	cmd := exec.Command("ping", "-6", netsvcIPAddr.String(), "-c", "1")
	if err := cmd.Run(); err != nil {
		return fmt.Errorf("Failed to ping netsvc addr %s: %w.", netsvcIPAddr, err)
	}
	return nil
}

// ensureNotFuchsia verifies that the device is not answering on its Fuchsia
// address. It asks netboot for that address (Discover with fuchsia=true) and
// sends it a single IPv6 ping. A successful ping means the device is running
// Fuchsia rather than Zedboot, which is reported as an error. A discovery
// failure is also returned, with its cause wrapped.
func ensureNotFuchsia(n *netboot.Client, nodename string) error {
	fuchsiaAddr, err := n.Discover(nodename, true)
	if err != nil {
		return fmt.Errorf("Failed to discover fuchsia addr: %w.", err)
	}
	fuchsiaIPAddr := &net.IPAddr{IP: fuchsiaAddr.IP, Zone: fuchsiaAddr.Zone}
	cmd := exec.Command("ping", "-6", fuchsiaIPAddr.String(), "-c", "1")
	// A failing ping is the healthy outcome here, so its error is deliberately
	// not reported; only success is treated as a problem.
	if cmd.Run() == nil {
		return fmt.Errorf("Device is in Fuchsia, should be in Zedboot.")
	}
	return nil
}

// deviceInZedboot reports whether the device looks like it is sitting in
// Zedboot. Its netsvc address must answer a ping, and its Fuchsia address must
// not. The checks run in that order, and the first failure is returned unchanged.
func deviceInZedboot(n *netboot.Client, nodename string) error {
	steps := []func(*netboot.Client, string) error{
		pingZedboot,
		ensureNotFuchsia,
	}
	for _, step := range steps {
		if stepErr := step(n, nodename); stepErr != nil {
			return stepErr
		}
	}
	return nil
}

// checkSerial writes an echo command to the device's serial line and verifies
// that the expected echoed prompt and response are read back.
//
// It is a no-op for devices without a serial line. It is also a no-op for
// every device type other than "nuc", because the check has been flaky on
// astros/sherlocks.
//
// NOTE(review): the read blocks until len(want) bytes arrive or the serial
// reader returns an error; whether the underlying serial implementation has a
// read timeout is not visible here — confirm, or add a deadline.
func checkSerial(device *devicePkg.DeviceTarget) error {
	if device.Type() != "nuc" {
		return nil
	}
	port := device.Serial()
	if port == nil {
		return nil
	}
	const (
		cmd  = "\necho hello\n"
		want = "\r\n$ echo hello\r\nhello"
	)
	if _, err := io.WriteString(port, cmd); err != nil {
		return fmt.Errorf("writing to serial: %w", err)
	}
	got := make([]byte, len(want))
	if _, err := io.ReadFull(port, got); err != nil {
		return fmt.Errorf("reading from serial: %w", err)
	}
	if string(got) != want {
		// %q keeps the \r\n and other control bytes visible in the log.
		log.Printf("serial test got unexpected output: %q", string(got))
		return fmt.Errorf("serial test got unexpected output")
	}
	return nil
}

// checkBroadcasting ensures that the device is broadcasting netboot beacons by
// waiting for one via n.Beacon. It is a no-op for devices of type "nuc".
//
// The beacon error is wrapped with context so the JSON report shows which
// health check failed.
func checkBroadcasting(n *netboot.Client, device *devicePkg.DeviceTarget) error {
	if device.Type() == "nuc" {
		return nil
	}
	if _, err := n.Beacon(); err != nil {
		return fmt.Errorf("netboot beacon check failed: %w", err)
	}
	return nil
}

// checkHealth runs every health check against device, in order, and returns
// the result for the first check that fails. If all checks pass, it returns a
// healthy result with an empty error message.
func checkHealth(n *netboot.Client, device *devicePkg.DeviceTarget) HealthCheckResult {
	name := device.Nodename()
	log.Printf("Checking health for %s", name)

	checks := []func() error{
		// Netsvc answers a ping and the Fuchsia address does not.
		func() error { return deviceInZedboot(n, name) },
		// Serial echoes correctly; a no-op without serial or on non-NUCs.
		func() error { return checkSerial(device) },
		// Netboot beacons are being broadcast; a no-op on NUCs.
		func() error { return checkBroadcasting(n, device) },
	}
	for _, check := range checks {
		if failure := check(); failure != nil {
			return HealthCheckResult{Nodename: name, State: unhealthyState, ErrorMsg: failure.Error()}
		}
	}
	return HealthCheckResult{Nodename: name, State: healthyState}
}

// printHealthCheckResults writes checkResults to stdout as one line of JSON (an
// array of objects) followed by a newline. It returns any encoding or write
// error.
//
// A nil slice (no devices checked) is printed as "[]" rather than "null", so
// consumers always receive a JSON array.
func printHealthCheckResults(checkResults []HealthCheckResult) error {
	if checkResults == nil {
		checkResults = []HealthCheckResult{}
	}
	// Encode emits the same bytes as json.Marshal plus a trailing newline, and
	// reports write failures that fmt.Println would have dropped.
	return json.NewEncoder(os.Stdout).Encode(checkResults)
}

// init installs the custom usage printer and registers the command line flags
// that populate the package-level flag variables.
func init() {
	// The usage text precedes the auto-generated flag defaults.
	flag.Usage = func() {
		fmt.Fprint(os.Stderr, usage)
		flag.PrintDefaults()
	}

	const (
		defaultConfigPath = "/etc/catalyst/config.json"
		defaultTimeout    = 10 * time.Second
	)
	flag.StringVar(&configFile, "config", defaultConfigPath,
		"The path of the json config file that contains the nodename of the device.")
	flag.DurationVar(&timeout, "timeout", defaultTimeout,
		"The timeout for checking each device. The format should be a value acceptable to time.ParseDuration.")
	flag.BoolVar(&rebootIfUnhealthy, "reboot", false,
		"If true, attempt to reboot the device if unhealthy.")
	flag.BoolVar(&forceReboot, "force-reboot", false,
		"If true, will skip health checks and reboot the device.")
}

func main() {
	flag.Parse()
	client := netboot.NewClient(timeout)
	ctx := context.Background()
	devices, err := devicePkg.CreateDeviceTargets(ctx, configFile, nil)
	if err != nil {
		log.Fatal(err)
	}

	f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()
	log.SetOutput(f)

	if forceReboot {
		for _, device := range devices {
			log.Printf("attempting forced device restart for: %s", device.Nodename())
			if device.Serial() == nil {
				log.Printf("device does not have serial, powercycling")
				if err := device.Powercycle(ctx); err != nil {
					log.Printf("powercycle failed: %s", err.Error())
				}
			} else {
				log.Printf("device has serial, restarting")
				if err := device.Restart(ctx); err != nil {
					log.Printf("forced restart failed with error: %s", err.Error())
				}
			}
			log.Printf("forced restart for device %s is complete", device.Nodename())
		}
		return
	}

	var checkResultSlice []HealthCheckResult
	for _, device := range devices {
		checkResult := checkHealth(client, device)
		log.Printf("state=%s, error_msg=%s", checkResult.State, checkResult.ErrorMsg)
		if checkResult.State == unhealthyState && rebootIfUnhealthy {
			if err := device.Powercycle(ctx); err != nil {
				log.Printf("powercycle call failed with error: %s", err.Error())
				if err := device.Restart(ctx); err != nil {
					log.Printf("restart call failed with error: %s", err.Error())
					checkResult.ErrorMsg += "; Failed to perform powercycle and restart"
				} else {
					checkResult.ErrorMsg += "; Failed to perform powercycle; restart succeeded"
				}
			} else {
				log.Printf("powercycle call succeeded for %s", device.Nodename())
			}
		}
		checkResultSlice = append(checkResultSlice, checkResult)
	}
	if err = printHealthCheckResults(checkResultSlice); err != nil {
		log.Fatal(err)
	}
}
