/*
 * Copyright 2013 Google Inc.
 *
 * See file CREDITS for list of people who contributed to this
 * project.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of
 * the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but without any warranty; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston,
 * MA 02111-1307 USA
 */

#include <assert.h>
#include <endian.h>
#include <stdint.h>
#include <stdio.h>

#include "base/time.h"
#include "base/xalloc.h"
#include "drivers/net/net.h"
#include "net/net.h"
#include "net/netboot/tftp.h"
#include "net/ipv4/uip/udp/packet.h"
#include "net/ipv4/uip/uip.h"

// How long to wait for a reply before resending our last packet, in
// microseconds (100 ms).
static const uint64_t TftpReceiveTimeoutUs = 100 * 1000;

// Outcome of handling one incoming packet during a TFTP read.
typedef enum TftpStatus {
	TftpPending = 0,	// Nothing usable arrived; keep waiting.
	TftpProgressing = 1,	// A full block was accepted and acked.
	TftpSuccess = 2,	// The final block was accepted and acked.
	TftpFailure = 3		// The transfer failed and must be abandoned.
} TftpStatus;

// Bookkeeping for a TFTP read in progress.
typedef struct {
	void *dest;		// Where the next data block is copied to.
	size_t total_size;	// Bytes of file data received so far.
	size_t max_size;	// Upper bound on total_size.
	int block_num;		// Block number expected next.
	void *buf;		// Scratch buffer packets are received into.
	size_t buf_size;	// Allocated size of buf, in bytes.
} TftpTransfer;

// TFTP ACK packet exactly as sent on the wire. The code that fills it in
// stores both fields in network byte order (via htonw).
typedef struct TftpAckPacket {
	uint16_t opcode;	// Always TftpAck.
	uint16_t block;		// Number of the data block being acknowledged.
} TftpAckPacket;

/*
 * Print the error code and message carried by a TFTP error packet.
 *
 * The packet is read from the uIP application buffer (uip_appdata, with
 * uip_datalen() bytes). As parsed below, it consists of a 2 byte opcode, a
 * 2 byte error code in network byte order, and then the error message text.
 *
 * NOTE(review): this reads uip_appdata rather than the buffer the caller
 * received the packet into; confirm uip_appdata still holds that packet.
 */
static void tftp_print_error_pkt(void)
{
	const uint8_t *pkt = uip_appdata;
	size_t pkt_len = uip_datalen();

	if (pkt_len >= 4) {
		uint16_t error;
		memcpy(&error, pkt + 2, sizeof(error));
		error = ntohw(error);
		printf("Error code %d", error);
		switch (error) {
		case TftpUndefined:
			printf(" (Undefined)\n");
			break;
		case TftpFileNotFound:
			printf(" (File not found)\n");
			break;
		case TftpAccessViolation:
			printf(" (Access violation)\n");
			break;
		case TftpNoSpace:
			printf(" (Not enough space)\n");
			break;
		case TftpIllegalOp:
			printf(" (Illegal operation)\n");
			break;
		case TftpUnknownId:
			printf(" (Unknown transfer ID)\n");
			break;
		case TftpFileExists:
			printf(" (File already exists)\n");
			break;
		case TftpNoSuchUser:
			printf(" (No such user)\n");
			break;
		default:
			printf("\n");
		}
	}
	if (pkt_len > 4) {
		// The message text starts after the opcode and error code.
		// Copy it out and terminate it ourselves rather than trusting
		// the sender to have included a NUL.
		size_t message_len = pkt_len - 4;
		char *message = xmalloc(message_len + 1);

		memcpy(message, pkt + 4, message_len);
		message[message_len] = '\0';
		printf("Error message: %s\n", message);
		free(message);
	}
}

/*
 * Poll @con for one TFTP packet and advance @transfer accordingly.
 *
 * @con:      connection the transfer runs over.
 * @transfer: state of the in-progress read. Its scratch buffer is grown as
 *            needed, and accepted data is appended at transfer->dest.
 *
 * Returns:
 *   TftpPending     - nothing usable arrived (no/short packet, non-data
 *                     packet, duplicate or out-of-order block, or an
 *                     oversized block).
 *   TftpProgressing - the expected full-size block was stored and acked.
 *   TftpSuccess     - the final (short) block was stored and acked.
 *   TftpFailure     - a netcon call failed, the server sent an error, or
 *                     the data would not fit in transfer->max_size.
 */
TftpStatus tftp_handle_response(NetConOps *con, TftpTransfer *transfer)
{
	size_t incoming;
	if (netcon_incoming(con, &incoming))
		return TftpFailure;

	// If the packet is too small, ignore it.
	if (incoming < 4)
		return TftpPending;

	// Expand the transfer's scratch buffer if necessary.
	if (transfer->buf_size < incoming) {
		free(transfer->buf);
		transfer->buf = xmalloc(incoming);
		transfer->buf_size = incoming;
	}
	if (netcon_receive(con, transfer->buf, &incoming, transfer->buf_size))
		return TftpFailure;

	// netcon_receive() is handed &incoming and presumably reports the
	// number of bytes actually received through it. Re-check the header
	// length so a short read can't make us parse stale bytes or underflow
	// the data length below.
	if (incoming < 4)
		return TftpPending;

	// Extract the opcode.
	uint16_t opcode;
	memcpy(&opcode, transfer->buf, sizeof(opcode));
	opcode = ntohw(opcode);

	// If there was an error, report it and stop the transfer.
	if (opcode == TftpError) {
		printf(" error!\n");
		tftp_print_error_pkt();
		return TftpFailure;
	}

	// We should only get data packets.
	if (opcode != TftpData)
		return TftpPending;

	// Get the block number.
	uint16_t block_num;
	memcpy(&block_num, (uint8_t *)transfer->buf + 2, sizeof(block_num));
	block_num = ntohw(block_num);

	// Ignore blocks which are duplicated or out of order. Block numbers
	// are only 16 bits on the wire and wrap around on long transfers, so
	// compare against the low 16 bits of the expected block number.
	if (block_num != (uint16_t)transfer->block_num)
		return TftpPending;

	const uint8_t *new_data = (const uint8_t *)transfer->buf + 4;
	size_t new_data_len = incoming - 4;

	// If the block is too big, reject it.
	if (new_data_len > TftpMaxBlockSize)
		return TftpPending;

	// If we're out of space give up. total_size never exceeds max_size,
	// so the subtraction cannot underflow.
	if (new_data_len > transfer->max_size - transfer->total_size) {
		printf("TFTP transfer too large.\n");
		return TftpFailure;
	}

	// If there's any data, copy it in.
	if (new_data_len) {
		memcpy(transfer->dest, new_data, new_data_len);
		transfer->dest = (uint8_t *)transfer->dest + new_data_len;
		transfer->total_size += new_data_len;
	}

	// Prepare an ack.
	TftpAckPacket ack = {
		htonw(TftpAck),
		htonw(block_num)
	};
	if (netcon_send(con, &ack, sizeof(ack)))
		return TftpFailure;

	// If this block was less than the maximum size, the transfer is done.
	if (new_data_len < TftpMaxBlockSize)
		return TftpSuccess;

	// Move on to the next block.
	transfer->block_num++;

	if (!(transfer->block_num % 10)) {
		// Give some feedback that something is happening.
		printf("#");
	}

	return TftpProgressing;
}

/*
 * Fetch @bootfile from the TFTP server at the other end of @con.
 *
 * @con:      connection used for the read request and the acks.
 * @dest:     buffer the file contents are written to.
 * @bootfile: name of the file to request (sent in "Octet" mode).
 * @size:     if non-NULL, receives the number of bytes transferred on
 *            success.
 * @max_size: capacity of @dest; a larger file makes the transfer fail.
 *
 * If nothing arrives within TftpReceiveTimeoutUs of the last packet we sent,
 * that packet is resent: the read request before any data has been
 * accepted, the ack for the last accepted block afterwards. There is no
 * overall retry limit.
 *
 * Returns 0 on success, -1 on failure.
 */
int tftp_read(NetConOps *con, void *dest, const char *bootfile,
	      uint32_t *size, uint32_t max_size)
{
	// Build the read request packet: opcode, file name, mode.
	uint16_t opcode = htonw(TftpReadReq);
	size_t opcode_len = sizeof(opcode);

	size_t name_len = strlen(bootfile) + 1;

	const char mode[] = "Octet";
	size_t mode_len = sizeof(mode);

	size_t read_req_len = opcode_len + name_len + mode_len;
	uint8_t *read_req = xmalloc(read_req_len);

	memcpy(read_req, &opcode, opcode_len);
	memcpy(read_req + opcode_len, bootfile, name_len);
	memcpy(read_req + opcode_len + name_len, mode, mode_len);

	// Send the request.
	printf("Sending tftp read request... ");
	if (netcon_send(con, read_req, read_req_len)) {
		free(read_req);
		return -1;
	}
	uint64_t last_sent = time_us(0);
	printf("done.\n");

	// Prepare for the transfer.
	printf("Waiting for the transfer... ");

	TftpTransfer transfer;
	transfer.dest = dest;
	transfer.total_size = 0;
	transfer.max_size = max_size;
	transfer.block_num = 1;
	transfer.buf = xmalloc(1);
	transfer.buf_size = 1;

	// Poll the network driver until the transaction is done.
	int ret = -1;
	while (1) {
		TftpStatus status = tftp_handle_response(con, &transfer);
		if (status == TftpFailure)
			break;
		if (status == TftpSuccess) {
			if (size)
				*size = transfer.total_size;
			printf(" done.\n");
			ret = 0;
			break;
		}
		if (status == TftpProgressing) {
			// tftp_handle_response() just acked a block, so that ack
			// is now our last sent packet. Restart the timeout from
			// it; otherwise we would resend duplicate acks while the
			// transfer is progressing normally.
			last_sent = time_us(0);
			continue;
		}

		// Nothing usable arrived. If the server has been quiet for too
		// long, resend our last packet.
		// NOTE(review): time_us(t) is assumed to return the microseconds
		// elapsed since t (time_us(0) being the current time).
		if (time_us(last_sent) <= TftpReceiveTimeoutUs)
			continue;

		if (transfer.block_num == 1) {
			// No data accepted yet: resend the read request.
			if (netcon_send(con, read_req, read_req_len))
				break;
		} else {
			// Resend the ack for the last block we accepted.
			TftpAckPacket ack = {
				htonw(TftpAck),
				htonw(transfer.block_num - 1)
			};
			if (netcon_send(con, &ack, sizeof(ack)))
				break;
		}
		last_sent = time_us(0);
	}
	free(transfer.buf);
	free(read_req);
	return ret;
}
