blob: e75d2cd56ff65363b2c36ecd5ce797031bbf9238 [file] [log] [blame]
/*
* Copyright 2013 Google Inc.
*
* See file CREDITS for list of people who contributed to this
* project.
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation; either version 2 of
* the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but without any warranty; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
#include <assert.h>
#include <endian.h>
#include <stdint.h>
#include <stdio.h>
#include "base/time.h"
#include "base/xalloc.h"
#include "drivers/net/net.h"
#include "net/net.h"
#include "net/netboot/tftp.h"
#include "net/ipv4/uip/udp/packet.h"
#include "net/ipv4/uip/uip.h"
// How long to wait without receiving a new data block before retransmitting
// our last packet (the read request or the most recent ACK): 100 ms.
static const uint64_t TftpReceiveTimeoutUs = 100 * 1000;
/* Outcome of processing (at most) one incoming packet during a TFTP read. */
typedef enum TftpStatus {
	TftpPending = 0,	/* nothing was accepted; keep waiting */
	TftpProgressing = 1,	/* a full-sized expected block was stored and acked */
	TftpSuccess = 2,	/* the final (short) block was stored; transfer done */
	TftpFailure = 3		/* unrecoverable error; abort the transfer */
} TftpStatus;
/* Bookkeeping for an in-progress TFTP read. */
typedef struct {
	void *dest;		/* where the next accepted data byte is written */
	size_t total_size;	/* number of data bytes accepted so far */
	size_t max_size;	/* capacity of the original destination buffer */
	int block_num;		/* data block number expected next */
	void *buf;		/* scratch buffer that incoming packets land in */
	size_t buf_size;	/* allocated size of buf, in bytes */
} TftpTransfer;
/*
 * TFTP ACK packet as sent on the wire: the opcode followed by the number of
 * the block being acknowledged, both already converted to network byte order
 * by the sender. The struct is passed verbatim to netcon_send(), so the field
 * order must not change.
 */
typedef struct TftpAckPacket {
	uint16_t opcode;	/* TftpAck, network byte order */
	uint16_t block;		/* acknowledged block number, network byte order */
} TftpAckPacket;
/*
 * Print the TFTP ERROR packet held in the uIP application buffer.
 *
 * An ERROR packet (RFC 1350) is laid out as:
 *   2-byte opcode | 2-byte error code | error message | NUL
 * The error code is printed with a human readable description, then the
 * message (if any) is copied out so it can be explicitly NUL terminated.
 *
 * NOTE(review): this reads uip_appdata/uip_datalen() rather than the buffer
 * that tftp_handle_response() filled via netcon_receive(); it assumes the uIP
 * buffer still holds the packet at this point -- confirm for connections that
 * are not backed by uIP.
 */
static void tftp_print_error_pkt(void)
{
	const uint8_t *pkt = uip_appdata;
	size_t len = uip_datalen();

	if (len >= 4) {
		uint16_t error;
		// The error code follows the 2-byte opcode.
		memcpy(&error, pkt + 2, sizeof(error));
		error = ntohw(error);
		printf("Error code %d", error);
		switch (error) {
		case TftpUndefined:
			printf(" (Undefined)\n");
			break;
		case TftpFileNotFound:
			printf(" (File not found)\n");
			break;
		case TftpAccessViolation:
			printf(" (Access violation)\n");
			break;
		case TftpNoSpace:
			printf(" (Not enough space)\n");
			break;
		case TftpIllegalOp:
			printf(" (Illegal operation)\n");
			break;
		case TftpUnknownId:
			printf(" (Unknown transfer ID)\n");
			break;
		case TftpFileExists:
			printf(" (File already exists)\n");
			break;
		case TftpNoSuchUser:
			printf(" (No such user)\n");
			break;
		default:
			printf("\n");
		}
	}
	if (len > 4) {
		// Copy out the error message so we can null terminate it. The
		// message starts after the opcode and the error code; copying
		// from offset 0 would start at the opcode's leading 0x00 byte
		// and always print an empty message.
		size_t message_len = len - 4;
		char *message = xmalloc(message_len + 1);
		memcpy(message, pkt + 4, message_len);
		message[message_len] = 0;
		printf("Error message: %s\n", message);
		free(message);
	}
}
/*
 * Process at most one incoming packet for an in-progress TFTP read.
 *
 * Data blocks are accepted strictly in order: the expected block is appended
 * at transfer->dest and acknowledged. Duplicates, out-of-order blocks,
 * non-data packets, oversized blocks and packets shorter than a TFTP header
 * are ignored.
 *
 * @con:      connection to receive from and send ACKs on
 * @transfer: transfer state, updated as blocks are accepted
 *
 * Returns TftpProgressing after storing and acking a full-sized block,
 * TftpSuccess after storing and acking the final (short) block, TftpPending
 * when nothing was accepted, and TftpFailure on a network error, a TFTP
 * ERROR packet, or data that would exceed transfer->max_size.
 */
TftpStatus tftp_handle_response(NetConOps *con, TftpTransfer *transfer)
{
	size_t incoming;
	if (netcon_incoming(con, &incoming))
		return TftpFailure;

	// If the packet is too small, ignore it.
	// NOTE(review): the packet is not received here; confirm that an
	// incoming size below 4 (e.g. 0) means "nothing queued" for netcon.
	if (incoming < 4)
		return TftpPending;

	// Expand the transfer's scratch buffer if necessary.
	if (transfer->buf_size < incoming) {
		free(transfer->buf);
		transfer->buf = xmalloc(incoming);
		transfer->buf_size = incoming;
	}

	if (netcon_receive(con, transfer->buf, &incoming, transfer->buf_size))
		return TftpFailure;

	// netcon_receive() reports the size it actually delivered; make sure
	// a full header is still present before parsing it.
	if (incoming < 4)
		return TftpPending;

	// Extract the opcode.
	uint16_t opcode;
	memcpy(&opcode, transfer->buf, sizeof(opcode));
	opcode = ntohw(opcode);

	// If there was an error, report it and stop the transfer.
	if (opcode == TftpError) {
		printf(" error!\n");
		tftp_print_error_pkt();
		return TftpFailure;
	}

	// We should only get data packets.
	if (opcode != TftpData)
		return TftpPending;

	// Get the block number.
	uint16_t block_num;
	memcpy(&block_num, (uint8_t *)transfer->buf + 2, sizeof(block_num));
	block_num = ntohw(block_num);

	// Ignore blocks which are duplicated or out of order. Block numbers
	// are only 16 bits on the wire and wrap after 65535, so compare
	// against the low 16 bits of our counter; otherwise a transfer of
	// more than 65535 blocks would wait forever for block 65536.
	if (block_num != (uint16_t)transfer->block_num)
		return TftpPending;

	void *new_data = (uint8_t *)transfer->buf + 4;
	size_t new_data_len = incoming - 4;

	// If the block is too big, reject it.
	if (new_data_len > TftpMaxBlockSize)
		return TftpPending;

	// If we're out of space give up.
	if (new_data_len > transfer->max_size - transfer->total_size) {
		printf("TFTP transfer too large.\n");
		return TftpFailure;
	}

	// If there's any data, copy it in.
	if (new_data_len) {
		memcpy(transfer->dest, new_data, new_data_len);
		transfer->dest = (uint8_t *)transfer->dest + new_data_len;
		transfer->total_size += new_data_len;
	}

	// Prepare an ack.
	TftpAckPacket ack = {
		htonw(TftpAck),
		htonw(block_num)
	};
	if (netcon_send(con, &ack, sizeof(ack)))
		return TftpFailure;

	// If this block was less than the maximum size, the transfer is done.
	if (new_data_len < TftpMaxBlockSize)
		return TftpSuccess;

	// Move on to the next block.
	transfer->block_num++;
	if (!(transfer->block_num % 10)) {
		// Give some feedback that something is happening.
		printf("#");
	}
	return TftpProgressing;
}
/*
 * Download @bootfile over TFTP into @dest.
 *
 * Sends a read request (RRQ) in octet mode, then polls @con until the
 * transfer completes or fails. When TftpReceiveTimeoutUs passes without a
 * new block arriving, the last packet sent is retransmitted: the RRQ while
 * no block has been accepted yet, otherwise the ACK for the last block.
 *
 * @con:      connection to the TFTP server
 * @dest:     destination buffer for the file contents
 * @bootfile: name of the file to request
 * @size:     if non-NULL, receives the number of bytes downloaded on success
 * @max_size: capacity of @dest in bytes
 *
 * Returns 0 on success, -1 on failure.
 */
int tftp_read(NetConOps *con, void *dest, const char *bootfile,
	      uint32_t *size, uint32_t max_size)
{
	// Build the read request packet: the opcode, then the NUL-terminated
	// file name and the NUL-terminated transfer mode.
	uint16_t opcode = htonw(TftpReadReq);
	int opcode_len = sizeof(opcode);
	int name_len = strlen(bootfile) + 1;
	const char mode[] = "Octet";
	int mode_len = sizeof(mode);
	int read_req_len = opcode_len + name_len + mode_len;
	uint8_t *read_req = xmalloc(read_req_len);
	memcpy(read_req, &opcode, opcode_len);
	memcpy(read_req + opcode_len, bootfile, name_len);
	memcpy(read_req + opcode_len + name_len, mode, mode_len);

	// Send the request.
	printf("Sending tftp read request... ");
	if (netcon_send(con, read_req, read_req_len)) {
		free(read_req);
		return -1;
	}
	uint64_t last_sent = time_us(0);
	printf("done.\n");

	// Prepare for the transfer.
	printf("Waiting for the transfer... ");
	TftpTransfer transfer;
	transfer.dest = dest;
	transfer.total_size = 0;
	transfer.max_size = max_size;
	transfer.block_num = 1;
	transfer.buf = xmalloc(1);
	transfer.buf_size = 1;

	// Poll the network driver until the transaction is done.
	int ret = -1;
	while (1) {
		TftpStatus status = tftp_handle_response(con, &transfer);
		if (status == TftpFailure)
			break;
		if (status == TftpSuccess) {
			if (size)
				*size = transfer.total_size;
			printf(" done.\n");
			ret = 0;
			break;
		}
		if (status == TftpProgressing) {
			// A block was just acked, so that ACK is now the last
			// packet we sent; restart the timeout from here. Without
			// this, once the transfer has run longer than the
			// timeout every ACK would be re-sent on the next idle
			// poll, which can make the server retransmit data.
			last_sent = time_us(0);
			continue;
		}

		// Nothing new arrived; only retransmit once we've waited long
		// enough.
		if (time_us(last_sent) <= TftpReceiveTimeoutUs)
			continue;

		if (transfer.block_num == 1) {
			// No block accepted yet: resend the read request.
			if (netcon_send(con, read_req, read_req_len))
				break;
		} else {
			// Resend the ACK for the last block we accepted. The
			// wire block number is 16 bits, so truncate explicitly.
			TftpAckPacket ack = {
				htonw(TftpAck),
				htonw((uint16_t)(transfer.block_num - 1))
			};
			if (netcon_send(con, &ack, sizeof(ack)))
				break;
		}
		last_sent = time_us(0);
	}

	free(transfer.buf);
	free(read_req);
	return ret;
}