blob: cee59d201b677ae9b82d1eebfd5759399be08ebd [file] [log] [blame]
// Copyright 2017 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package tftp
import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"time"
)
// Default protocol parameters. NewClient copies WriteTimeout, WriteAttempts,
// BlockSize and WindowSize into the Client it returns.
const (
// ServerPort is the default server port.
ServerPort = 33340
// ClientPort is the default client port.
ClientPort = 33341
// WriteTimeout is the duration to wait for the client to ack a packet.
WriteTimeout = 2 * time.Second
// WriteAttempts is the maximum number of times a packet will be resent.
WriteAttempts = 5
// BlockSize is the maximum block size used for file transfers.
BlockSize = 1024
// WindowSize is the window size used for file transfers.
WindowSize = 64
)
// Client sends files to a TFTP server using the Send method. Its fields
// control retransmission and the transfer options requested from the server;
// NewClient returns a Client populated with the package defaults.
type Client struct {
// WriteTimeout sets the duration to wait for the client to ack a packet.
WriteTimeout time.Duration
// WriteAttempts controls how many times a packet will be resent.
WriteAttempts int
// BlockSize sets the maximum block size used for file transfers.
BlockSize int64
// WindowSize sets the window size used for file transfers.
WindowSize int64
}
// NewClient returns a Client whose timeout, attempt count, block size and
// window size are set to the package-level defaults.
func NewClient() *Client {
	c := new(Client)
	c.WriteTimeout = WriteTimeout
	c.WriteAttempts = WriteAttempts
	c.BlockSize = BlockSize
	c.WindowSize = WindowSize
	return c
}
// options holds the TFTP transfer options exchanged with the server:
// writeOPT serializes them into a request and parseOACK fills them in from
// the server's reply.
type options struct {
// Timeout is carried by the "timeout" option, in whole seconds.
Timeout time.Duration
// BlockSize is carried by the "blksize" option.
BlockSize int64
// WindowSize is carried by the "windowsize" option.
WindowSize int64
// TransferSize is carried by the "tsize" option.
TransferSize int64
}
// Send uploads the contents of file to the TFTP server at addr under the
// given filename.
//
// It binds a UDP socket on the local port ServerPort, sends a write request
// (opcode 2) in octet mode that carries the tsize, blksize, timeout and
// windowsize options, waits for the server's option acknowledgement and then
// streams the data using the acknowledged values. size is advertised to the
// server as the tsize option.
//
// The returned error wraps the underlying failure and is prefixed with the
// phase that failed ("creating socket", "sending WRQ", "sending data").
func (c *Client) Send(addr *net.UDPAddr, filename string, file io.Reader, size int64) error {
	conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: ServerPort})
	if err != nil {
		return fmt.Errorf("creating socket: %w", err)
	}
	defer conn.Close()

	// Request layout: one random byte, the opcode byte, the null-terminated
	// filename and mode, then the option name/value pairs.
	// NOTE(review): the random leading byte presumably tags the session for
	// the peer - confirm against the server implementation.
	var b bytes.Buffer
	if _, err := io.CopyN(&b, rand.Reader, 1); err != nil {
		return fmt.Errorf("reading rand: %w", err)
	}
	b.WriteByte(2)
	b.WriteString(filename)
	b.WriteByte(0)
	// Only support octet mode, because in practice that's the
	// only remaining sensible use of TFTP.
	b.WriteString("octet")
	b.WriteByte(0)
	c.writeOPT(&b, &options{
		Timeout:      c.WriteTimeout,
		BlockSize:    c.BlockSize,
		WindowSize:   c.WindowSize,
		TransferSize: size,
	})

	// The reply's source address replaces addr for the data phase.
	negotiated, addr, err := c.sendWRQ(conn, addr, b.Bytes())
	if err != nil {
		return fmt.Errorf("sending WRQ: %w", err)
	}

	// An option missing from the OACK was not accepted by the server, so the
	// protocol default applies: 512-byte blocks (RFC 2348) and a window of a
	// single block (RFC 7440). Without this the zero values would produce
	// empty blocks, an empty window or an already-expired read deadline.
	if negotiated.BlockSize <= 0 {
		negotiated.BlockSize = 512
	}
	if negotiated.WindowSize <= 0 {
		negotiated.WindowSize = 1
	}
	if negotiated.Timeout <= 0 {
		negotiated.Timeout = c.WriteTimeout
	}

	// Size the rewind buffer from the values that will actually be used on
	// the wire, not from the values that were merely requested.
	ring := NewWindowReader(file, int(negotiated.WindowSize), int(negotiated.BlockSize))
	if err := c.sendData(conn, addr, negotiated, ring); err != nil {
		return fmt.Errorf("sending data: %w", err)
	}
	return nil
}
// writeOPT appends the transfer options in o to b as null-terminated
// name/value string pairs. "tsize" is always written; "blksize", "timeout"
// (in whole seconds) and "windowsize" are written only when non-zero.
func (c *Client) writeOPT(b *bytes.Buffer, o *options) {
	// put emits a single "name\0decimal-value\0" pair.
	put := func(name string, value int64) {
		b.WriteString(name)
		b.WriteByte(0)
		b.WriteString(strconv.FormatInt(value, 10))
		b.WriteByte(0)
	}

	put("tsize", o.TransferSize)
	if o.BlockSize != 0 {
		put("blksize", o.BlockSize)
	}
	if o.Timeout != 0 {
		put("timeout", int64(o.Timeout/time.Second))
	}
	if o.WindowSize != 0 {
		put("windowsize", o.WindowSize)
	}
}
// sendWRQ sends the write request b to addr and waits for the server's
// answer, resending the request up to c.WriteAttempts times when no answer
// arrives within c.WriteTimeout.
//
// On an option acknowledgement (opcode 6) it returns the parsed options and
// the address the reply came from. On an error packet (opcode 5) it returns
// an error carrying the server's message together with the sender's address.
// Datagrams shorter than four bytes or with any other opcode are ignored.
func (c *Client) sendWRQ(conn *net.UDPConn, addr *net.UDPAddr, b []byte) (*options, *net.UDPAddr, error) {
Attempt:
	for attempt := 0; attempt < c.WriteAttempts; attempt++ {
		if _, err := conn.WriteToUDP(b, addr); err != nil {
			return nil, nil, err
		}
		if err := conn.SetReadDeadline(time.Now().Add(c.WriteTimeout)); err != nil {
			return nil, nil, err
		}
		for {
			// A fresh, zeroed buffer per datagram: the error message below is
			// parsed up to the first null byte, so bytes left over from an
			// earlier, longer datagram must never be visible here.
			var recv [256]byte
			n, from, err := conn.ReadFromUDP(recv[:])
			if err != nil {
				var ne net.Error
				if errors.As(err, &ne) && ne.Timeout() {
					continue Attempt
				}
				return nil, nil, err
			}
			if n < 4 { // packet too small
				continue
			}
			switch recv[1] {
			case 5:
				// Best effort: a malformed message just yields an empty string.
				msg, _, _ := tftpStr(recv[4:])
				return nil, from, fmt.Errorf("server aborted transfer: %s", msg)
			case 6:
				opts, err := parseOACK(recv[:n])
				return opts, from, err
			}
		}
	}
	return nil, nil, errors.New("timeout waiting for ACK")
}
// sendData streams the contents of reader to addr as DATA packets (opcode 3)
// of up to o.BlockSize payload bytes, sending up to o.WindowSize packets
// before waiting for an ACK (opcode 4). Sequence numbers are 16-bit, start at
// 1 and wrap around.
//
// A window that is not acknowledged within the timeout is resent, up to
// c.WriteAttempts times. An ACK for an earlier block of the current window
// rewinds the reader to just after that block and sending resumes from
// there. The transfer is complete once a block shorter than o.BlockSize has
// been acknowledged.
//
// An error is returned when the options are unusable, when the server sends
// an error packet (opcode 5), when an ACK lies outside the window just sent,
// when all attempts time out, or when reading or sending fails.
func (c *Client) sendData(conn *net.UDPConn, addr *net.UDPAddr, o *options, reader *WindowReader) error {
	// Zero-sized blocks would never terminate the transfer and an empty
	// window would never send anything, so reject them up front. The window
	// must also fit the 16-bit sequence arithmetic below.
	if o.BlockSize <= 0 || o.WindowSize <= 0 || o.WindowSize > 0xFFFF {
		return fmt.Errorf("invalid transfer options: block size %d, window size %d", o.BlockSize, o.WindowSize)
	}
	timeout := o.Timeout
	if timeout <= 0 {
		// A zero deadline offset would make every read time out at once.
		timeout = c.WriteTimeout
	}

	var b bytes.Buffer
	b.Grow(int(o.BlockSize + 4))
	window := int(o.WindowSize)
	seq := uint16(0)

	// marks[i] is the reader position before the i-th block of the window
	// currently in flight; len(marks) is the number of blocks actually sent.
	marks := make([]int, 0, window)

	// rewind discards every block of the current window after the first keep
	// blocks, restoring both the reader position and the sequence number.
	rewind := func(keep int) {
		reader.Unread(reader.current - marks[keep])
		seq -= uint16(len(marks) - keep)
	}

Loop:
	for {
		var n int64
	Attempt:
		for attempt := 0; attempt < c.WriteAttempts; attempt++ {
			marks = marks[:0]
			for len(marks) < window {
				marks = append(marks, reader.current)
				b.Reset()
				if _, err := io.CopyN(&b, rand.Reader, 1); err != nil {
					return fmt.Errorf("writing rand: %w", err)
				}
				b.WriteByte(3)
				seq++
				// Sequence number, big endian.
				b.WriteByte(byte(seq >> 8))
				b.WriteByte(byte(seq))
				var err error
				n, err = io.CopyN(&b, reader, o.BlockSize)
				if err != nil && err != io.EOF {
					return fmt.Errorf("reading bytes for block %d: %w", seq, err)
				}
				if _, err := conn.WriteToUDP(b.Bytes(), addr); err != nil {
					return fmt.Errorf("sending block %d: %w", seq, err)
				}
				if n < o.BlockSize {
					// A short block is the last block of the transfer.
					break
				}
			}
			if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
				return fmt.Errorf("setting read deadline: %w", err)
			}
			for {
				var recv [256]byte
				m, _, err := conn.ReadFromUDP(recv[:])
				if err != nil {
					var ne net.Error
					if errors.As(err, &ne) && ne.Timeout() {
						// Nothing acknowledged: resend exactly the blocks
						// that were sent, which may be fewer than a full
						// window at the end of the transfer.
						rewind(0)
						continue Attempt
					}
					return err
				}
				if m < 4 { // packet too small
					continue
				}
				switch recv[1] {
				case 4:
					num := binary.BigEndian.Uint16(recv[2:4])
					// Distance from the acknowledged block to the last block
					// sent, computed modulo 2^16 so it stays correct across
					// sequence wrap-around. An ACK ahead of seq shows up as
					// a huge distance and is rejected as well.
					behind := int(seq - num)
					if behind > len(marks) {
						return fmt.Errorf("invalid ACK: %d", num)
					}
					if behind > 0 {
						// Partial window: resume right after block num.
						rewind(len(marks) - behind)
						continue Loop
					}
					if n < o.BlockSize {
						// The final, short block has been acknowledged.
						return nil
					}
					continue Loop
				case 5:
					// Best effort: a malformed message just yields an empty string.
					msg, _, _ := tftpStr(recv[4:])
					return fmt.Errorf("server aborted transfer: %s", msg)
				}
			}
		}
		return errors.New("timeout waiting for ACK")
	}
}
// parseOACK decodes an option acknowledgement packet: a two-byte big-endian
// opcode equal to 6 followed by null-terminated option name/value pairs.
//
// The recognized options (matched case-insensitively) are "blksize",
// "timeout", "tsize" and "windowsize"; any other option is skipped. Options
// absent from the packet are left at their zero value in the result.
func parseOACK(bs []byte) (*options, error) {
	// Packets shorter than six bytes are rejected outright, as is anything
	// whose opcode is not 6.
	if len(bs) < 6 {
		return nil, errors.New("not an OACK packet")
	}
	if opcode := binary.BigEndian.Uint16(bs[:2]); opcode != 6 {
		return nil, errors.New("not an OACK packet")
	}

	parsed := new(options)
	for rest := bs[2:]; len(rest) > 0; {
		name, afterName, err := tftpStr(rest)
		if err != nil {
			return nil, fmt.Errorf("reading option name: %s", err)
		}
		value, afterValue, err := tftpStr(afterName)
		if err != nil {
			return nil, fmt.Errorf("reading option %q value: %s", name, err)
		}
		rest = afterValue

		switch key := strings.ToLower(name); key {
		case "blksize":
			v, err := strconv.ParseUint(value, 10, 16)
			if err != nil {
				return nil, fmt.Errorf("unsupported block size %q", value)
			}
			parsed.BlockSize = int64(v)
		case "timeout":
			v, err := strconv.ParseUint(value, 10, 8)
			if err != nil {
				return nil, fmt.Errorf("unsupported timeout %q", value)
			}
			parsed.Timeout = time.Second * time.Duration(v)
		case "tsize":
			v, err := strconv.ParseInt(value, 10, 64)
			if err != nil {
				return nil, fmt.Errorf("unsupported transfer size %q", value)
			}
			parsed.TransferSize = v
		case "windowsize":
			v, err := strconv.ParseUint(value, 10, 16)
			if err != nil {
				return nil, fmt.Errorf("unsupported window size %q", value)
			}
			parsed.WindowSize = int64(v)
		}
	}
	return parsed, nil
}
// tftpStr splits bs at its first null byte, returning the text before it as
// a string and the bytes after it as the remainder.
//
// Only printable ASCII (0x20 through 0x7E inclusive) is accepted before the
// terminator; the first byte outside that range yields an error naming the
// byte and its offset. An error is also returned when bs contains no null
// byte at all.
func tftpStr(bs []byte) (str string, remaining []byte, err error) {
	for i := 0; i < len(bs); i++ {
		switch c := bs[i]; {
		case c == 0:
			return string(bs[:i]), bs[i+1:], nil
		case c < 0x20, c > 0x7E:
			return "", nil, fmt.Errorf("invalid netascii byte %q at offset %d", c, i)
		}
	}
	return "", nil, errors.New("no null terminated string found")
}
// WindowReader reads an underlying stream in fixed-size blocks and keeps the
// most recent blocks in a ring buffer so that they can be replayed.
//
// Each call to Read consumes exactly one slot of the ring, and Unread moves
// the read position back by a number of slots, so a caller that issues one
// Read per block can rewind by counting blocks. At most slots blocks are
// retained.
type WindowReader struct {
	buf     []byte  // buffer space, slots*size bytes
	len     []int   // len of data written to each slot
	errs    []error // error returned together with the data of each slot
	rd      io.Reader
	current int // index of the next block to be read or replayed
	head    int // index of the next block to be fetched from rd
	slots   int
	size    int
}

// NewWindowReader returns a WindowReader over rd that retains up to slots
// blocks of up to size bytes each. Values below 1 are raised to 1 so that the
// ring always has at least one usable slot.
func NewWindowReader(rd io.Reader, slots int, size int) *WindowReader {
	if slots < 1 {
		slots = 1
	}
	if size < 1 {
		size = 1
	}
	return &WindowReader{
		buf:   make([]byte, slots*size),
		len:   make([]int, slots),
		errs:  make([]error, slots),
		rd:    rd,
		slots: slots,
		size:  size,
	}
}

// Read returns the next block.
//
// When the read position is behind the head (after Unread) the stored block
// is replayed together with the error that accompanied it originally.
// Otherwise a new block is fetched: p, capped at the slot size, is filled
// completely unless the underlying reader ends or fails first, so a short
// block always comes with a non-nil error (io.EOF at the end of the stream).
// A call with an empty p returns immediately without consuming a slot.
func (r *WindowReader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	slot := r.current % r.slots
	start := slot * r.size

	if r.current != r.head {
		// Replay a block that was already fetched.
		n := copy(p, r.buf[start:start+r.len[slot]])
		r.current++
		return n, r.errs[slot]
	}

	if len(p) > r.size {
		// Never hand out more than one slot can store for a later replay.
		p = p[:r.size]
	}
	// Gather short reads into one block so that one Read is one slot.
	n, err := io.ReadFull(r.rd, p)
	if err == io.ErrUnexpectedEOF {
		// A partially filled final block is an ordinary end of stream.
		err = io.EOF
	}
	copy(r.buf[start:start+n], p[:n])
	r.len[slot] = n
	r.errs[slot] = err
	r.current++
	r.head = r.current
	return n, err
}

// Unread moves the read position back by n blocks so that subsequent Reads
// replay them. The position never moves back past the oldest block still
// held in the ring (or past the start of the stream); a non-positive n is
// ignored.
func (r *WindowReader) Unread(n int) {
	if n <= 0 {
		return
	}
	oldest := r.head - r.slots
	if oldest < 0 {
		oldest = 0
	}
	r.current -= n
	if r.current < oldest {
		r.current = oldest
	}
}