| // Copyright 2017 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| package tftp |
| |
| import ( |
| "bytes" |
| "crypto/rand" |
| "encoding/binary" |
| "errors" |
| "fmt" |
| "io" |
| "net" |
| "strconv" |
| "strings" |
| "time" |
| ) |
| |
const (
	// ServerPort is the default server port.
	ServerPort = 33340
	// ClientPort is the default client port.
	ClientPort = 33341

	// WriteTimeout is the default duration to wait for the peer to
	// acknowledge a packet before resending it.
	WriteTimeout = 2 * time.Second
	// WriteAttempts is the default maximum number of times a packet is sent
	// before the transfer is abandoned.
	WriteAttempts = 5
	// BlockSize is the default block size requested for file transfers.
	BlockSize = 1024
	// WindowSize is the default window size requested for file transfers.
	WindowSize = 64
)
| |
// Client uploads files to a TFTP server with Send. NewClient returns a
// Client populated with the package defaults.
type Client struct {
	// WriteTimeout is how long to wait for the server to answer a request
	// before resending it. It is also offered to the server, in whole
	// seconds, as the "timeout" option.
	WriteTimeout time.Duration
	// WriteAttempts is the maximum number of times a request or data window
	// is sent before the transfer fails.
	WriteAttempts int
	// BlockSize is the block size requested with the "blksize" option; 0
	// omits the option. The server's OACK decides the size actually used.
	BlockSize int64
	// WindowSize is the window size requested with the "windowsize" option;
	// 0 omits the option. The server's OACK decides the size actually used.
	WindowSize int64
}
| |
| func NewClient() *Client { |
| return &Client { |
| WriteTimeout: WriteTimeout, |
| WriteAttempts: WriteAttempts, |
| BlockSize: BlockSize, |
| WindowSize: WindowSize, |
| } |
| } |
| |
// options carries the TFTP option values (RFC 2347) that a request asks for
// or that the server acknowledged in an OACK. For the block size, window size
// and timeout, a zero value means the option is absent.
type options struct {
	BlockSize, WindowSize, TransferSize int64
	Timeout                             time.Duration
}
| |
| func (c *Client) Send(addr *net.UDPAddr, filename string, file io.Reader, size int64) error { |
| conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: ServerPort}) |
| if err != nil { |
| return fmt.Errorf("creating socket: %s", err) |
| } |
| defer conn.Close() |
| |
| var b bytes.Buffer |
| |
| if _, err := io.CopyN(&b, rand.Reader, 1); err != nil { |
| return fmt.Errorf("reading rand: %s", err) |
| } |
| b.WriteByte(2) |
| |
| b.WriteString(filename) |
| b.WriteByte(0) |
| // Only support octet mode, because in practice that's the |
| // only remaining sensible use of TFTP. |
| b.WriteString("octet") |
| b.WriteByte(0) |
| |
| o := &options{ |
| Timeout: c.WriteTimeout, |
| BlockSize: c.BlockSize, |
| WindowSize: c.WindowSize, |
| TransferSize: size, |
| } |
| |
| c.writeOPT(&b, o) |
| |
| options, addr, err := c.sendWRQ(conn, addr, b.Bytes()) |
| if err != nil { |
| return fmt.Errorf("sending WRQ: %s", err) |
| } |
| b.Reset() |
| // We're the sender, we send the size |
| o.TransferSize = size |
| |
| ring := NewWindowReader(file, int(o.WindowSize), int(o.BlockSize)) |
| |
| if err := c.sendData(conn, addr, options, ring); err != nil { |
| return fmt.Errorf("sending data: %s", err) |
| } |
| |
| return nil |
| } |
| |
| func (c *Client) writeOPT(b *bytes.Buffer, o *options) { |
| b.WriteString("tsize") |
| b.WriteByte(0) |
| b.WriteString(strconv.FormatInt(o.TransferSize, 10)) |
| b.WriteByte(0) |
| |
| if o.BlockSize != 0 { |
| b.WriteString("blksize") |
| b.WriteByte(0) |
| b.WriteString(strconv.FormatInt(o.BlockSize, 10)) |
| b.WriteByte(0) |
| } |
| |
| if o.Timeout != 0 { |
| b.WriteString("timeout") |
| b.WriteByte(0) |
| b.WriteString(strconv.FormatInt(int64(o.Timeout / time.Second), 10)) |
| b.WriteByte(0) |
| } |
| |
| if o.WindowSize != 0 { |
| b.WriteString("windowsize") |
| b.WriteByte(0) |
| b.WriteString(strconv.FormatInt(o.WindowSize, 10)) |
| b.WriteByte(0) |
| } |
| } |
| |
| func (c *Client) sendWRQ(conn *net.UDPConn, addr *net.UDPAddr, b []byte) (*options, *net.UDPAddr, error) { |
| Attempt: |
| for attempt := 0; attempt < c.WriteAttempts; attempt++ { |
| if _, err := conn.WriteToUDP(b, addr); err != nil { |
| return nil, nil, err |
| } |
| |
| conn.SetReadDeadline(time.Now().Add(c.WriteTimeout)) |
| |
| var recv [256]byte |
| for { |
| n, addr, err := conn.ReadFromUDP(recv[:]) |
| if err != nil { |
| if t, ok := err.(net.Error); ok && t.Timeout() { |
| continue Attempt |
| } |
| return nil, nil, err |
| } |
| |
| if n < 4 { // packet too small |
| continue |
| } |
| switch recv[1] { |
| case 5: |
| msg, _, _ := tftpStr(recv[4:]) |
| return nil, addr, fmt.Errorf("server aborted transfer: %s", msg) |
| case 6: |
| options, err := parseOACK(recv[:n]) |
| return options, addr, err |
| } |
| } |
| } |
| |
| return nil, nil, errors.New("timeout waiting for ACK") |
| } |
| |
| func (c *Client) sendData(conn *net.UDPConn, addr *net.UDPAddr, o *options, reader *WindowReader) error { |
| var b bytes.Buffer |
| seq := uint16(0) |
| b.Grow(int(o.BlockSize + 4)) |
| |
| Loop: |
| for { |
| var n int64 |
| Attempt: |
| for attempt := 0; attempt < c.WriteAttempts; attempt++ { |
| for i := uint16(0); i < uint16(o.WindowSize); i++ { |
| b.Reset() |
| if _, err := io.CopyN(&b, rand.Reader, 1); err != nil { |
| return fmt.Errorf("writing rand: %s", err) |
| } |
| b.WriteByte(3) |
| seq++ |
| if err := binary.Write(&b, binary.BigEndian, seq); err != nil { |
| return fmt.Errorf("writing seqnum: %s", err) |
| } |
| var err error |
| n, err = io.CopyN(&b, reader, int64(o.BlockSize)) |
| if err != nil && err != io.EOF { |
| return fmt.Errorf("reading bytes for block %d: %s", seq, err) |
| } |
| if _, err := conn.WriteToUDP(b.Bytes(), addr); err != nil { |
| return fmt.Errorf("sending block %d: %s", seq, err) |
| } |
| if n < int64(o.BlockSize) { |
| break |
| } |
| } |
| |
| conn.SetReadDeadline(time.Now().Add(o.Timeout)) |
| |
| for { |
| var recv [256]byte |
| m, _, err := conn.ReadFromUDP(recv[:]) |
| if err != nil { |
| if t, ok := err.(net.Error); ok && t.Timeout() { |
| reader.Unread(int(o.WindowSize)) |
| seq -= uint16(o.WindowSize) |
| continue Attempt |
| } |
| return err |
| } |
| if m < 4 { // packet too small |
| continue |
| } |
| switch recv[1] { |
| case 4: |
| if num := binary.BigEndian.Uint16(recv[2:4]); num != seq { |
| if num < seq - uint16(o.WindowSize) { |
| return fmt.Errorf("invalid ACK: %q", num) |
| } |
| if num > seq { |
| // out-of-order ack? |
| } |
| reader.Unread(int(seq - num)) |
| seq = num |
| } else if n < int64(o.BlockSize) { |
| return nil |
| } |
| continue Loop |
| case 5: |
| msg, _, _ := tftpStr(recv[4:]) |
| return fmt.Errorf("server aborted transfer: %s", msg) |
| } |
| } |
| } |
| return errors.New("timeout waiting for ACK") |
| } |
| } |
| |
| func parseOACK(bs []byte) (*options, error) { |
| // Smallest a useful TFTP packet can be is 6 bytes: 2b opcode, 1b |
| // filename, 1b null, 1b mode, 1b null. |
| if len(bs) < 6 || binary.BigEndian.Uint16(bs[:2]) != 6 { |
| return nil, errors.New("not an OACK packet") |
| } |
| bs = bs[2:] |
| |
| o := &options{} |
| |
| for len(bs) > 0 { |
| opt, rest, err := tftpStr(bs) |
| if err != nil { |
| return nil, fmt.Errorf("reading option name: %s", err) |
| } |
| bs = rest |
| val, rest, err := tftpStr(bs) |
| if err != nil { |
| return nil, fmt.Errorf("reading option %q value: %s", opt, err) |
| } |
| bs = rest |
| switch strings.ToLower(opt) { |
| case "blksize": |
| size, err := strconv.ParseUint(val, 10, 16) |
| if err != nil { |
| return nil, fmt.Errorf("unsupported block size %q", val) |
| } |
| o.BlockSize = int64(size) |
| case "timeout": |
| seconds, err := strconv.ParseUint(val, 10, 8) |
| if err != nil { |
| return nil, fmt.Errorf("unsupported timeout %q", val) |
| } |
| o.Timeout = time.Second * time.Duration(seconds) |
| case "tsize": |
| size, err := strconv.ParseInt(val, 10, 64) |
| if err != nil { |
| return nil, fmt.Errorf("unsupported transfer size %q", val) |
| } |
| o.TransferSize = size |
| case "windowsize": |
| size, err := strconv.ParseUint(val, 10, 16) |
| if err != nil { |
| return nil, fmt.Errorf("unsupported window size %q", val) |
| } |
| o.WindowSize = int64(size) |
| } |
| } |
| |
| return o, nil |
| } |
| |
// tftpStr splits a NUL-terminated string off the front of bs and returns it
// along with the bytes that follow the terminator.
//
// Every byte before the terminator must be printable netascii, i.e. in the
// range 0x20 to 0x7E inclusive. The first offending byte is reported with
// its offset. When bs has no NUL at all, its whole content is validated
// first, and an error is returned either way.
func tftpStr(bs []byte) (str string, remaining []byte, err error) {
	end := bytes.IndexByte(bs, 0)

	body := bs
	if end >= 0 {
		body = bs[:end]
	}
	for off := 0; off < len(body); off++ {
		if c := body[off]; c < 0x20 || c > 0x7E {
			return "", nil, fmt.Errorf("invalid netascii byte %q at offset %d", c, off)
		}
	}

	if end < 0 {
		return "", nil, errors.New("no null terminated string found")
	}
	return string(body), bs[end+1:], nil
}
| |
// WindowReader reads an underlying stream in fixed-size blocks and keeps the
// most recent blocks in a ring of slots, so that a caller can rewind with
// Unread and read those blocks again (for example to retransmit a window
// that was not acknowledged).
//
// Every Read call produces exactly one block and consumes exactly one slot.
// A fresh block is filled with up to min(len(p), size) bytes, reading across
// short reads of the underlying stream. Block boundaries therefore do not
// depend on how the stream delivers its data, and Unread(n) always steps
// back over exactly n blocks.
type WindowReader struct {
	buf     []byte // slots*size bytes of block storage
	filled  []int  // number of valid bytes in each slot
	eof     []bool // whether the stream ended inside each slot's block
	rd      io.Reader
	current int // index of the next block Read returns
	head    int // number of blocks read from rd so far
	slots   int
	size    int
}

// NewWindowReader returns a WindowReader over rd that retains the last slots
// blocks of at most size bytes each. slots is raised to at least 1 and a
// negative size is treated as 0.
func NewWindowReader(rd io.Reader, slots int, size int) *WindowReader {
	if slots < 1 {
		slots = 1
	}
	if size < 0 {
		size = 0
	}
	return &WindowReader{
		buf:    make([]byte, slots*size),
		filled: make([]int, slots),
		eof:    make([]bool, slots),
		rd:     rd,
		slots:  slots,
		size:   size,
	}
}

// Read copies the next block into p and returns its length.
//
// A block that was rewound with Unread is replayed from the ring. Otherwise
// the block is read from the underlying stream and stored. A block cut short
// by the end of the stream is returned together with io.EOF, both when first
// read and when replayed. Any other error from the stream is returned as is.
// If p is shorter than a replayed block, the excess bytes are dropped.
func (r *WindowReader) Read(p []byte) (int, error) {
	slot := r.current % r.slots
	block := r.buf[slot*r.size : (slot+1)*r.size]

	if r.current != r.head {
		n := copy(p, block[:r.filled[slot]])
		r.current++
		if r.eof[slot] {
			return n, io.EOF
		}
		return n, nil
	}

	want := len(p)
	if want > r.size {
		want = r.size
	}
	// io.ReadFull keeps reading until the block is full, so a stream that
	// returns short reads never spreads one block over several slots.
	n, err := io.ReadFull(r.rd, block[:want])
	if err == io.ErrUnexpectedEOF {
		err = io.EOF
	}
	copy(p, block[:n])
	r.filled[slot] = n
	r.eof[slot] = err == io.EOF
	r.current++
	r.head = r.current
	return n, err
}

// Unread rewinds the reader by n blocks so that the next n Read calls replay
// them. The blocks must still be held in the ring. n must not exceed the
// number of blocks read so far, and the reader must not end up more than
// slots blocks behind the newest block read from the stream; older slots
// have been overwritten.
func (r *WindowReader) Unread(n int) {
	r.current -= n
}