blob: ee619f9bdca7fcb33d59f9e999cebd1b12b0dd4d [file] [log] [blame]
// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package tftp
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
"math"
"net"
"sync"
)
// Client is used to Read or Write files to/from a TFTP remote.
type Client interface {
Read(ctx context.Context, filename string) (*bytes.Reader, error)
Write(ctx context.Context, filename string, reader io.ReaderAt, size int64) error
RemoteAddr() *net.UDPAddr
}
// ClientImpl implements the Client interface; it is exported for testing.
type ClientImpl struct {
addr *net.UDPAddr
conn *net.UDPConn
mu *sync.Mutex
blockSize uint16
windowSize uint16
}
// NewClient returns a Client which can be used to Read or Write
// files to/from a TFTP remote. A blockSize and windowSize of 0 will use the
// default values defined in tftp.go.
func NewClient(addr *net.UDPAddr, blockSize, windowSize uint16) (*ClientImpl, error) {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv6zero})
if err != nil {
return nil, err
}
if blockSize == 0 {
blockSize = defaultBlockSize
}
if windowSize == 0 {
windowSize = defaultWindowSize
}
return &ClientImpl{
addr: addr,
conn: conn,
mu: &sync.Mutex{},
blockSize: blockSize,
windowSize: windowSize,
}, nil
}
func (c *ClientImpl) newTransfer(opCode uint8, filename string) *transfer {
t := &transfer{
addr: c.addr,
buffer: bytes.NewBuffer([]byte{}),
client: c,
filename: filename,
opCode: opCode,
opts: &options{
timeout: timeout,
blockSize: c.blockSize,
windowSize: c.windowSize,
},
}
return t
}
// Read requests to read a file from the TFTP remote, if the request is successful,
// the contents of the remote file are returned as a bytes.Reader, if the request
// is unsuccessful an error is returned, if the error is ErrShouldWait, the request
// can be retried at some point in the future.
func (c *ClientImpl) Read(ctx context.Context, filename string) (*bytes.Reader, error) {
c.mu.Lock()
defer c.mu.Unlock()
t := c.newTransfer(opRrq, filename)
attempts := 0
transferStarted := false
for {
if !transferStarted {
// Send the request to read the file.
err := t.request()
if err != nil {
return nil, err
}
}
recv, err := t.wait(ctx, expectAny)
transferStarted = transferStarted || recv == opOack || recv == opData
if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() {
attempts++
if attempts < retries {
if !transferStarted {
t.cancel(errorUndefined, err)
}
continue
}
return nil, err
}
if err == io.EOF {
return bytes.NewReader(t.buffer.Bytes()), nil
}
t.cancel(errorUndefined, err)
return nil, err
}
if ctx.Err() != nil {
t.cancel(errorUndefined, ctx.Err())
return nil, ctx.Err()
}
}
}
func (c *ClientImpl) RemoteAddr() *net.UDPAddr {
return c.addr
}
// Write requests to send a file to the TFTP remote, if the operation is unsuccesful
// error is returned, if the error is ErrShouldWait, the request can be retried at
// some point in the future.
func (c *ClientImpl) Write(ctx context.Context, filename string, r io.ReaderAt, size int64) error {
c.mu.Lock()
defer c.mu.Unlock()
t := c.newTransfer(opWrq, filename)
attempts := 0
t.opts.transferSize = uint64(size)
// Send the request to write the file.
for {
if ctx.Err() != nil {
return ctx.Err()
}
err := t.request()
if err != nil {
return err
}
// Wait for the receiving side to ACK options.
if _, err := t.wait(ctx, expectOack); err != nil {
t.cancel(errorUndefined, err)
if err, ok := err.(net.Error); ok && err.Timeout() {
attempts++
if attempts < retries {
continue
}
return err
}
return err
}
break
}
// Send the file.
for {
for i := uint32(0); i < uint32(t.opts.windowSize); i++ {
b := make([]byte, t.opts.blockSize+dataOffset)
block := t.seq + i + 1
b[1] = opData
binary.BigEndian.PutUint16(b[2:], uint16(math.MaxUint16&block))
off := int64(block-1) * int64(t.opts.blockSize)
n, err := r.ReadAt(b[dataOffset:], off)
if err != nil && err != io.EOF {
t.cancel(errorUndefined, err)
return fmt.Errorf("reading bytes for block %d: %s", block, err)
}
isEOF := err == io.EOF
if err := t.send(b[:n+dataOffset]); err != nil {
t.cancel(errorUndefined, err)
return fmt.Errorf("sending block %d: %s", block, err)
}
// We transfered all data. Break & wait for ACK.
if isEOF {
break
}
}
// Wait for the receiving side to ACK or possibly error out.
if _, err := t.wait(ctx, expectAck); err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() {
continue
}
return err
}
// The full transfer has been ACK'd. Finished.
if uint64(t.seq)*uint64(t.opts.blockSize) > t.opts.transferSize {
return nil
}
if ctx.Err() != nil {
t.cancel(errorUndefined, ctx.Err())
return ctx.Err()
}
}
}