// Copyright 2017 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <lib/netboot/netboot.h>
#include <netinet/in.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <time.h>
#include <unistd.h>

#include "bootserver.h"

// NOTE(review): not referenced in the part of the file visible here; presumably the default
// value for us_between_packets, which is defined elsewhere — confirm.
#define DEFAULT_US_BETWEEN_PACKETS 20

// Value stamped into the cookie field of every outgoing message by io(); post-incremented on
// each message so successive messages carry increasing cookies.
static uint32_t cookie = 1;
// Maximum number of read() attempts io_rcv() makes; only EAGAIN failures are retried.
static const int MAX_READ_RETRIES = 10;
// Maximum number of write() attempts io_send() makes; retries only happen in __APPLE__ builds,
// and only for ENOBUFS.
static const int MAX_SEND_RETRIES = 10000;

// Reads one message from |s| into |ack| and classifies it.
//
//   s:   socket to read from.
//   hdr: the message the reply is matched against, or NULL to skip the cookie check. When
//        non-NULL, a reply whose cookie is larger than hdr->cookie is reported as "Bad cookie".
//   ack: destination buffer; up to 2048 bytes are read into it.
//
// A read that fails with EAGAIN is retried, for at most MAX_READ_RETRIES attempts in total.
//
// Returns 0 for NETBOOT_COMMAND_ACK and NETBOOT_COMMAND_FILE_RECEIVED, and also (after logging)
// for a reply with the wrong magic or a bad cookie. Returns -1 on a read error, a reply shorter
// than the header, or any other command (NETBOOT_COMMAND_ERROR_* codes and unknown values).
static int io_rcv(int s, netboot_message_header_t* hdr, netboot_message_header_t* ack) {
  int attempts_left = MAX_READ_RETRIES;
  while (attempts_left-- > 0) {
    // From here on |attempts_left| counts the attempts remaining after this one.
    ssize_t nread = read(s, ack, 2048);
    if (nread < 0) {
      if (errno == EAGAIN && attempts_left > 0) {
        continue;
      }
      fprintf(stderr, "\n%s: error: Socket read error %d\n", appname, errno);
      return -1;
    }
    if ((size_t)nread < sizeof(netboot_message_header_t)) {
      fprintf(stderr, "\n%s: error: Read too short\n", appname);
      return -1;
    }
#ifdef DEBUG
    fprintf(stdout, " < magic = %08x, cookie = %08x, cmd = %08x, arg = %08x\n", ack->magic,
            ack->cookie, ack->cmd, ack->arg);
#endif

    // Malformed or unexpected replies are logged but reported as success to the caller.
    if (ack->magic != NETBOOT_MAGIC) {
      fprintf(stderr, "\n%s: error: Bad magic\n", appname);
      return 0;
    }
    if (hdr != NULL && ack->cookie > hdr->cookie) {
      fprintf(stderr, "\n%s: error: Bad cookie\n", appname);
      return 0;
    }

    switch (ack->cmd) {
      case NETBOOT_COMMAND_ACK:
      case NETBOOT_COMMAND_FILE_RECEIVED:
        return 0;
      case NETBOOT_COMMAND_ERROR_GENERIC:
        fprintf(stderr, "\n%s: error: Generic error\n", appname);
        return -1;
      case NETBOOT_COMMAND_ERROR_BAD_CMD:
        fprintf(stderr, "\n%s: error: Bad command\n", appname);
        return -1;
      case NETBOOT_COMMAND_ERROR_BAD_PARAM:
        fprintf(stderr, "\n%s: error: Bad parameter\n", appname);
        return -1;
      case NETBOOT_COMMAND_ERROR_TOO_LARGE:
        fprintf(stderr, "\n%s: error: File too large\n", appname);
        return -1;
      case NETBOOT_COMMAND_ERROR_BAD_FILE:
        fprintf(stderr, "\n%s: error: Bad file\n", appname);
        return -1;
      default:
        fprintf(stderr, "\n%s: error: Unknown command 0x%08X\n", appname, ack->cmd);
        return -1;
    }
  }
  fprintf(stderr, "\n%s: error: Unexpected code path\n", appname);
  return -1;
}

// Writes |len| bytes starting at |msg| to socket |s|.
//
// In __APPLE__ builds a write that fails with ENOBUFS is retried after a 50 microsecond sleep,
// for at most MAX_SEND_RETRIES attempts in total; in all other cases the first failure is final.
//
// Returns 0 when write() succeeds, -1 (after logging errno) otherwise.
static int io_send(int s, netboot_message_header_t* msg, size_t len) {
  for (int i = 0; i < MAX_SEND_RETRIES; i++) {
    // write() returns ssize_t; keep the full width instead of narrowing to int.
    ssize_t written = write(s, msg, len);
    if (written >= 0) {
      return 0;
    }
#if defined(__APPLE__)
    if (i + 1 < MAX_SEND_RETRIES && errno == ENOBUFS) {
      // On Darwin we manage to overflow the ethernet driver, so retry
      struct timespec reqtime;
      reqtime.tv_sec = 0;
      reqtime.tv_nsec = 50 * 1000;
      nanosleep(&reqtime, NULL);
      continue;
    }
#endif
    fprintf(stderr, "\n%s: error: Socket write error %d\n", appname, errno);
    return -1;
  }
  fprintf(stderr, "\n%s: error: Unexpected code path\n", appname);
  return -1;
}

// Performs one send and/or receive step on socket |s|.
//
//   msg:        message to send, or NULL. When non-NULL and |len| > 0 its magic and cookie
//               fields are filled in here before sending.
//   len:        number of bytes of |msg| to send.
//   ack:        receives any reply; its cookie, cmd and arg are cleared on entry so the caller
//               can tell whether a reply was stored.
//   wait_reply: when true, the call finishes with io_rcv() after sending. When false, the socket
//               is also polled for readability and a pending reply is consumed before sending.
//
// Returns 0 on success and -1 on select failure/timeout or when io_rcv()/io_send() fail.
static int io(int s, netboot_message_header_t* msg, size_t len, netboot_message_header_t* ack,
              bool wait_reply) {
  ack->cookie = 0;
  ack->cmd = 0;
  ack->arg = 0;

  fd_set readable;
  fd_set writable;
  FD_ZERO(&readable);
  FD_ZERO(&writable);

  // Only look for incoming data up front when the caller is not going to wait for a reply.
  const bool poll_for_reply = !wait_reply;
  if (poll_for_reply) {
    FD_SET(s, &readable);
  }

  const bool have_payload = msg != NULL && len > 0;
  if (have_payload) {
    msg->magic = NETBOOT_MAGIC;
    msg->cookie = cookie++;
    FD_SET(s, &writable);
  }

  if (poll_for_reply || have_payload) {
    struct timeval timeout;
    timeout.tv_sec = 10;
    timeout.tv_usec = 500000;
    int ready = select(s + 1, poll_for_reply ? &readable : NULL, have_payload ? &writable : NULL,
                       NULL, &timeout);
    if (ready == -1) {
      fprintf(stderr, "\n%s: error: Select failed %d\n", appname, errno);
      return -1;
    }
    if (ready == 0) {
      // Timed-out
      fprintf(stderr, "\n%s: error: Select timed out\n", appname);
      return -1;
    }

    int result = 0;
    if (FD_ISSET(s, &readable)) {
      result = io_rcv(s, msg, ack);
    }

    // If we got an ack, don't bother sending anything - go handle the ack first
    const bool got_ack = ack->cookie != 0 && ack->cmd == NETBOOT_COMMAND_ACK;
    if (result == 0 && FD_ISSET(s, &writable) && !got_ack) {
      result = io_send(s, msg, len);
    }

    if (result != 0 || !wait_reply) {
      return result;
    }
  } else if (!wait_reply) {  // no-op
    return 0;
  }

  if (wait_reply) {
    return io_rcv(s, msg, ack);
  }
  fprintf(stderr, "\n%s: error: Select triggered without events\n", appname);
  return -1;
}

// Source of the bytes being transferred: either an open file or an in-memory buffer.
// xread() and xseek() pick the mode by checking whether |fp| is NULL.
typedef struct {
  FILE* fp;          // Open file to read from, or NULL to use the in-memory fields below.
  const char* data;  // Start of the in-memory buffer (used when fp is NULL).
  size_t datalen;    // Total size of |data| in bytes; xseek() rejects offsets beyond it.
  const char* ptr;   // Current read position inside |data|.
  size_t avail;      // Bytes remaining between |ptr| and the end of |data|.
} xferdata;

// Copies up to |len| bytes from the transfer source |xd| into |data| and advances the source.
//
// Returns the number of bytes copied, 0 when nothing is left, or -1 when the underlying file
// stream reports an error and no bytes were read.
static ssize_t xread(xferdata* xd, void* data, size_t len) {
  if (xd->fp != NULL) {
    // File-backed source.
    ssize_t got = fread(data, 1, len, xd->fp);
    if (got != 0) {
      return got;
    }
    // Nothing read: tell end-of-file apart from a stream error.
    return ferror(xd->fp) ? -1 : 0;
  }

  // Memory-backed source: clamp the request to what is still available.
  size_t count = len > xd->avail ? xd->avail : len;
  memcpy(data, xd->ptr, count);
  xd->ptr += count;
  xd->avail -= count;
  return count;
}

// Moves the read position of the transfer source |xd| to absolute byte offset |off|.
//
// Returns 0 on success. For a memory-backed source, returns -1 when |off| is past the end of
// the buffer; for a file-backed source, returns whatever fseek() returns.
static int xseek(xferdata* xd, size_t off) {
  if (xd->fp != NULL) {
    return fseek(xd->fp, off, SEEK_SET);
  }

  if (off > xd->datalen) {
    return -1;
  }
  xd->avail = xd->datalen - off;
  xd->ptr = xd->data + off;
  return 0;
}

// UDP6_MAX_PAYLOAD (ETH_MTU - ETH_HDR_LEN - IP6_HDR_LEN - UDP_HDR_LEN)
//      1452           1514   -     14      -     40      -    8

// Number of payload bytes netboot_xfer() requests from xread() for each data message; the
// bytes are placed directly after the message header.
// Some EFI network stacks have problems with larger packets
// 1280 is friendlier
#define PAYLOAD_SIZE 1280

// Sends one file, or an in-memory command line, to the netboot target at |addr|.
//
//   addr: IPv6 address and port of the target; a UDP socket is connect()ed to it.
//   fn:   local path of the file to send, or the literal string "(cmdline)" to send the
//         NUL-terminated string passed in |name| instead of a file.
//   name: name announced to the target in the NETBOOT_COMMAND_SEND_FILE request. In "(cmdline)"
//         mode it carries the payload and NETBOOT_CMDLINE_FILENAME is announced instead.
//
// The payload is sent in PAYLOAD_SIZE chunks. A NETBOOT_COMMAND_ACK reply with a non-zero cookie
// makes the transfer resume from the offset in the reply's arg field; the transfer completes
// when the target replies with NETBOOT_COMMAND_FILE_RECEIVED.
//
// Returns 0 on success and -1 on any error. The socket and the file are closed on every path
// once they have been opened.
int netboot_xfer(struct sockaddr_in6* addr, const char* fn, const char* name) {
  xferdata xd = {0};
  char msgbuf[2048];
  char ackbuf[2048];
  char tmp[INET6_ADDRSTRLEN];
  struct timeval tv;
  netboot_message_header_t* msg = (void*)msgbuf;
  char* msg_data = (char*)(msg + 1);
  netboot_message_header_t* ack = (void*)ackbuf;
  int s = -1;  // Stays negative until the socket exists, so cleanup can run at any point.
  int status = -1;
  size_t current_pos = 0;
  size_t sz = 0;
  size_t name_len = 0;
  const char* status_label = NULL;

  if (!strcmp(fn, "(cmdline)")) {
    xd.fp = NULL;
    xd.data = name;
    xd.datalen = strlen(name) + 1;
    xd.ptr = xd.data;
    xd.avail = xd.datalen;
    status_label = xd.data;
    name = NETBOOT_CMDLINE_FILENAME;
    sz = xd.datalen;
  } else {
    // xd.data is only meaningful in "(cmdline)" mode, so in file mode the local path is the
    // label handed to initialize_status().
    // NOTE(review): initialize_status() is declared elsewhere — confirm |fn| is the label wanted.
    status_label = fn;
    if ((xd.fp = fopen(fn, "rb")) == NULL) {
      fprintf(stderr, "%s: error: Could not open file %s\n", appname, fn);
      return -1;
    }
    if (fseek(xd.fp, 0L, SEEK_END)) {
      fprintf(stderr, "%s: error: Could not determine size of %s\n", appname, fn);
    } else {
      // ftell() reports failure as a negative long, which must be checked before the value is
      // stored in the unsigned |sz|.
      long end_offset = ftell(xd.fp);
      if (end_offset < 0) {
        fprintf(stderr, "%s: error: Could not determine size of %s\n", appname, fn);
      } else {
        sz = (size_t)end_offset;
      }
      // The stream sits at end-of-file after the seek above; always rewind, even when the size
      // could not be determined, so the data loop starts from the beginning.
      if (fseek(xd.fp, 0L, SEEK_SET)) {
        fprintf(stderr, "%s: error: Failed to rewind %s\n", appname, fn);
        goto done;
      }
    }
  }

  if (sz > 0) {
    initialize_status(status_label, sz);
  }

  if ((s = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)) < 0) {
    fprintf(stderr, "%s: error: Cannot create socket %d\n", appname, errno);
    goto done;
  }
  // Receive timeout of 250ms, so a read() in io_rcv() fails with a retryable error instead of
  // blocking indefinitely.
  tv.tv_sec = 0;
  tv.tv_usec = 250 * 1000;
  setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
  if (connect(s, (void*)addr, sizeof(*addr)) < 0) {
    fprintf(stderr, "%s: error: Cannot connect to [%s]%d\n", appname,
            inet_ntop(AF_INET6, &addr->sin6_addr, tmp, sizeof(tmp)), ntohs(addr->sin6_port));
    goto done;
  }

  // The announced name (with its terminating NUL) is placed right after the header; make sure
  // it fits in the message buffer before copying.
  name_len = strlen(name) + 1;
  if (name_len > sizeof(msgbuf) - sizeof(netboot_message_header_t)) {
    fprintf(stderr, "%s: error: Name too long\n", appname);
    goto done;
  }

  msg->cmd = NETBOOT_COMMAND_SEND_FILE;
  msg->arg = sz;
  memcpy(msg_data, name, name_len);
  if (io(s, msg, sizeof(netboot_message_header_t) + name_len, ack, true)) {
    fprintf(stderr, "%s: error: Failed to start transfer\n", appname);
    goto done;
  }

  msg->cmd = NETBOOT_COMMAND_DATA;
  msg->arg = 0;

  bool completed = false;
  do {
    struct timeval packet_start_time;
    gettimeofday(&packet_start_time, NULL);

    ssize_t r = xread(&xd, msg_data, PAYLOAD_SIZE);
    if (r < 0) {
      fprintf(stderr, "\n%s: error: Reading '%s'\n", appname, fn);
      goto done;
    }

    update_status(msg->arg);

    if (r == 0) {
      fprintf(stderr, "\n%s: Reached end of file, waiting for confirmation.\n", appname);
      // Do not send anything, but keep waiting on incoming messages
      if (io(s, NULL, 0, ack, true)) {
        goto done;
      }
    } else {
      // Mark the chunk that reaches the announced size as the last one.
      if (current_pos + (size_t)r >= sz) {
        msg->cmd = NETBOOT_COMMAND_LAST_DATA;
      } else {
        msg->cmd = NETBOOT_COMMAND_DATA;
      }

      if (io(s, msg, sizeof(netboot_message_header_t) + r, ack, false)) {
        goto done;
      }

      // Some UEFI netstacks can lose back-to-back packets at max speed
      // so throttle output.
      // At 1280 bytes per packet, we should at least have 10 microseconds
      // between packets, to be safe using 20 microseconds here.
      // 1280 bytes * (1,000,000/10) seconds = 128,000,000 bytes/seconds = 122MB/s = 976Mb/s
      // We wait as a busy wait as the context switching a sleep can cause
      // will often degrade performance significantly.
      int64_t us_since_last_packet;
      do {
        struct timeval now;
        gettimeofday(&now, NULL);
        us_since_last_packet = (int64_t)(now.tv_sec - packet_start_time.tv_sec) * 1000000 +
                               ((int64_t)now.tv_usec - (int64_t)packet_start_time.tv_usec);
      } while (us_since_last_packet < us_between_packets);
    }

    // ACKs really are NACKs
    if (ack->cookie > 0 && ack->cmd == NETBOOT_COMMAND_ACK) {
      // ACKs tend to be generated in groups, since a dropped packet will cause ACKs for all
      // outstanding packets. Therefore briefly sleep when we receive an ACK with a different
      // position, to let things settle and prevent ourselves from fighting subsequent acks.
      if (ack->arg != current_pos) {
        fprintf(stderr, "\n%s: need to reset to %" PRIu32 " from %zu\n", appname,
                (uint32_t)ack->arg, current_pos);
        current_pos = ack->arg;

        // select() with no descriptors is used as a 100ms sleep; set both fields explicitly.
        tv.tv_sec = 0;
        tv.tv_usec = 100000;
        select(0, NULL, NULL, NULL, &tv);
      }
      if (xseek(&xd, current_pos)) {
        fprintf(stderr, "\n%s: error: Failed to rewind '%s' to %zu\n", appname, fn, current_pos);
        goto done;
      }
    } else if (ack->cmd == NETBOOT_COMMAND_FILE_RECEIVED) {
      current_pos += r;
      completed = true;
    } else {
      current_pos += r;
    }

    msg->arg = current_pos;
  } while (!completed);

  status = 0;
  update_status(msg->arg);
done:
  if (s >= 0) {
    close(s);
  }
  if (xd.fp != NULL) {
    fclose(xd.fp);
  }
  return status;
}
