// Copyright 2017 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package tftp

import (
	"bytes"
	"crypto/rand"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"
	"time"
)

const (
	// ServerPort is the default server port.
	ServerPort = 33340
	// ClientPort is the default client port.
	ClientPort = 33341

	// WriteTimeout is the default duration to wait for the server to
	// acknowledge a packet before sending it again.
	WriteTimeout = 2 * time.Second
	// WriteAttempts is the default number of times a packet is transmitted
	// while waiting for an acknowledgement, including the first transmission.
	WriteAttempts = 5
	// BlockSize is the default block size, in bytes, requested for file
	// transfers.
	BlockSize = 1024
	// WindowSize is the default number of blocks per window requested for
	// file transfers.
	WindowSize = 64
)

// Client sends files to a TFTP server. Use NewClient to obtain a Client
// populated with the package defaults.
type Client struct {
	// WriteTimeout is how long to wait for the server to acknowledge a
	// packet before sending it again. It is also proposed to the server as
	// the "timeout" option.
	WriteTimeout time.Duration
	// WriteAttempts is the number of times a packet is transmitted while
	// waiting for an acknowledgement, including the first transmission.
	WriteAttempts int
	// BlockSize is the block size, in bytes, requested from the server.
	BlockSize int64
	// WindowSize is the number of blocks per window requested from the
	// server.
	WindowSize int64
}

// NewClient returns a Client whose settings are the package-level defaults
// WriteTimeout, WriteAttempts, BlockSize and WindowSize.
func NewClient() *Client {
	c := new(Client)
	c.WriteTimeout = WriteTimeout
	c.WriteAttempts = WriteAttempts
	c.BlockSize = BlockSize
	c.WindowSize = WindowSize
	return c
}

// options holds the TFTP transfer options a client proposes in a write
// request or a server accepts in its option acknowledgement. A zero Timeout,
// BlockSize or WindowSize means that option is not present.
type options struct {
	Timeout      time.Duration // retransmission timeout ("timeout")
	BlockSize    int64         // data block size in bytes ("blksize")
	WindowSize   int64         // blocks sent per window ("windowsize")
	TransferSize int64         // total transfer size in bytes ("tsize")
}

// Send transfers size bytes read from file to the TFTP server at addr,
// storing them under filename.
//
// It issues a write request (WRQ) proposing the client's timeout, block size,
// window size and the transfer size, waits for the server's option
// acknowledgement, and then streams the data using the options the server
// accepted. Options the server does not acknowledge fall back to the protocol
// defaults (512-byte blocks, a window of one block) and to c.WriteTimeout.
func (c *Client) Send(addr *net.UDPAddr, filename string, file io.Reader, size int64) error {
	// NOTE(review): the local socket is bound to ServerPort rather than
	// ClientPort or an ephemeral port; confirm this is what the peer expects.
	conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: ServerPort})
	if err != nil {
		return fmt.Errorf("creating socket: %s", err)
	}
	defer conn.Close()

	var b bytes.Buffer

	// Opcode: a random first byte followed by 2 (WRQ).
	if _, err := io.CopyN(&b, rand.Reader, 1); err != nil {
		return fmt.Errorf("reading rand: %s", err)
	}
	b.WriteByte(2)

	b.WriteString(filename)
	b.WriteByte(0)
	// Only support octet mode, because in practice that's the
	// only remaining sensible use of TFTP.
	b.WriteString("octet")
	b.WriteByte(0)

	// We're the sender, so we announce the transfer size.
	requested := &options{
		Timeout:      c.WriteTimeout,
		BlockSize:    c.BlockSize,
		WindowSize:   c.WindowSize,
		TransferSize: size,
	}
	c.writeOPT(&b, requested)

	negotiated, addr, err := c.sendWRQ(conn, addr, b.Bytes())
	if err != nil {
		return fmt.Errorf("sending WRQ: %s", err)
	}
	if negotiated == nil {
		negotiated = &options{}
	}

	// An option missing from the OACK means the server did not accept it,
	// so the protocol default applies. Zero values would otherwise stall the
	// transfer (no blocks per window) or expire every read deadline at once.
	const defaultBlockSize = 512
	if negotiated.BlockSize <= 0 {
		negotiated.BlockSize = defaultBlockSize
	}
	if negotiated.WindowSize <= 0 {
		negotiated.WindowSize = 1
	}
	if negotiated.Timeout <= 0 {
		negotiated.Timeout = c.WriteTimeout
	}

	// The ring must hold a full window of blocks of the size actually in use,
	// otherwise unacknowledged blocks cannot be resent (or a block larger
	// than a slot overflows the buffer).
	ring := NewWindowReader(file, int(negotiated.WindowSize), int(negotiated.BlockSize))

	if err := c.sendData(conn, addr, negotiated, ring); err != nil {
		return fmt.Errorf("sending data: %s", err)
	}

	return nil
}

// writeOPT appends the options in o to b in TFTP option format: each option
// is a null-terminated name followed by a null-terminated decimal value.
//
// tsize is always written; blksize, timeout and windowsize are written only
// when non-zero. The timeout is sent in whole seconds and clamped to 1..255,
// the range the timeout option permits, so sub-second or very long durations
// still yield a valid request instead of "0" or an out-of-range value.
func (c *Client) writeOPT(b *bytes.Buffer, o *options) {
	put := func(name string, value int64) {
		b.WriteString(name)
		b.WriteByte(0)
		b.WriteString(strconv.FormatInt(value, 10))
		b.WriteByte(0)
	}

	put("tsize", o.TransferSize)

	if o.BlockSize != 0 {
		put("blksize", o.BlockSize)
	}

	if o.Timeout != 0 {
		seconds := int64(o.Timeout / time.Second)
		switch {
		case seconds < 1:
			seconds = 1
		case seconds > 255:
			seconds = 255
		}
		put("timeout", seconds)
	}

	if o.WindowSize != 0 {
		put("windowsize", o.WindowSize)
	}
}

// sendWRQ sends the write request b to addr and waits for the server's option
// acknowledgement (OACK). The request is transmitted up to c.WriteAttempts
// times, waiting c.WriteTimeout for a reply after each transmission.
//
// On success it returns the options parsed from the OACK together with the
// address the OACK came from; the rest of the transfer is sent there. An
// ERROR packet from the server aborts the request with its message.
// Packets shorter than 4 bytes and packets with other opcodes are ignored.
func (c *Client) sendWRQ(conn *net.UDPConn, addr *net.UDPAddr, b []byte) (*options, *net.UDPAddr, error) {
Attempt:
	for attempt := 0; attempt < c.WriteAttempts; attempt++ {
		if _, err := conn.WriteToUDP(b, addr); err != nil {
			return nil, nil, err
		}

		if err := conn.SetReadDeadline(time.Now().Add(c.WriteTimeout)); err != nil {
			return nil, nil, err
		}

		var recv [256]byte
		for {
			n, from, err := conn.ReadFromUDP(recv[:])
			if err != nil {
				if t, ok := err.(net.Error); ok && t.Timeout() {
					continue Attempt
				}
				return nil, nil, err
			}

			if n < 4 { // packet too small
				continue
			}
			// Only the second opcode byte is examined; outgoing packets in
			// this file put a random value in the first one.
			switch recv[1] {
			case 5: // ERROR: 2b opcode, 2b error code, message
				// Slice to n so stale bytes from an earlier, longer packet
				// in the reused buffer never leak into the message.
				msg, _, _ := tftpStr(recv[4:n])
				return nil, from, fmt.Errorf("server aborted transfer: %s", msg)
			case 6: // OACK
				opts, err := parseOACK(recv[:n])
				return opts, from, err
			}
		}
	}

	return nil, nil, errors.New("timeout waiting for ACK")
}

// sendData sends the contents of reader to addr as DATA packets using the
// windowed scheme described by o.
//
// Up to o.WindowSize blocks of at most o.BlockSize bytes are sent back to
// back, then sendData waits up to o.Timeout for an ACK:
//   - an ACK for the last block sent moves on to the next window, or ends the
//     transfer if that block was shorter than o.BlockSize;
//   - an ACK for an earlier block of the window (or for the block just before
//     it) rewinds reader and the block number so the following blocks are
//     sent again;
//   - an ACK outside the current window is ignored as stale;
//   - if no usable ACK arrives, exactly the blocks of the window are resent,
//     up to c.WriteAttempts transmissions in total.
//
// Block numbers are uint16 and wrap around; all window arithmetic is done in
// uint16 so the wrap is handled. reader is expected to retain at least
// o.WindowSize blocks and to yield one block per Read call.
func (c *Client) sendData(conn *net.UDPConn, addr *net.UDPAddr, o *options, reader *WindowReader) error {
	var b bytes.Buffer
	b.Grow(int(o.BlockSize + 4))
	seq := uint16(0)

Loop:
	for {
		// sent is the number of DATA packets sent in the current window and
		// n the payload size of the last of them.
		var sent uint16
		var n int64
	Attempt:
		for attempt := 0; attempt < c.WriteAttempts; attempt++ {
			sent = 0
			for sent < uint16(o.WindowSize) {
				b.Reset()
				// Opcode: a random first byte followed by 3 (DATA).
				if _, err := io.CopyN(&b, rand.Reader, 1); err != nil {
					return fmt.Errorf("writing rand: %s", err)
				}
				b.WriteByte(3)
				seq++
				sent++
				if err := binary.Write(&b, binary.BigEndian, seq); err != nil {
					return fmt.Errorf("writing seqnum: %s", err)
				}
				var err error
				n, err = io.CopyN(&b, reader, o.BlockSize)
				if err != nil && err != io.EOF {
					return fmt.Errorf("reading bytes for block %d: %s", seq, err)
				}
				if _, err := conn.WriteToUDP(b.Bytes(), addr); err != nil {
					return fmt.Errorf("sending block %d: %s", seq, err)
				}
				if n < o.BlockSize {
					// A short block is the last one of the transfer.
					break
				}
			}

			if err := conn.SetReadDeadline(time.Now().Add(o.Timeout)); err != nil {
				return err
			}

			for {
				var recv [256]byte
				m, _, err := conn.ReadFromUDP(recv[:])
				if err != nil {
					if t, ok := err.(net.Error); ok && t.Timeout() {
						// Rewind exactly the blocks sent in this window (it may
						// be shorter than o.WindowSize) and send them again.
						reader.Unread(int(sent))
						seq -= sent
						continue Attempt
					}
					return err
				}
				if m < 4 { // packet too small
					continue
				}
				switch recv[1] {
				case 4: // ACK
					num := binary.BigEndian.Uint16(recv[2:4])
					// behind is how many blocks of this window the server has
					// not acknowledged; uint16 subtraction stays correct when
					// block numbers wrap.
					behind := seq - num
					if behind > sent {
						// Duplicate ACK from an earlier window, or one for a
						// block not sent yet: nothing to act on.
						continue
					}
					if behind == 0 && n < o.BlockSize {
						return nil
					}
					reader.Unread(int(behind))
					seq = num
					continue Loop
				case 5: // ERROR
					msg, _, _ := tftpStr(recv[4:m])
					return fmt.Errorf("server aborted transfer: %s", msg)
				}
			}
		}
		return errors.New("timeout waiting for ACK")
	}
}

// parseOACK parses an option acknowledgement (OACK) packet and returns the
// options it carries. Option names are matched case-insensitively,
// unrecognized options are ignored, and options absent from the packet are
// left zero.
//
// Only the second opcode byte is checked. Packets sent by this file carry a
// random value in the first opcode byte, and the receive loops dispatch on
// the second byte alone, so a non-zero first byte is accepted here too.
func parseOACK(bs []byte) (*options, error) {
	// The smallest OACK carrying an option is 6 bytes: 2b opcode, 1b option
	// name, 1b null, 1b value, 1b null.
	if len(bs) < 6 || bs[1] != 6 {
		return nil, errors.New("not an OACK packet")
	}
	bs = bs[2:]

	o := &options{}

	for len(bs) > 0 {
		opt, rest, err := tftpStr(bs)
		if err != nil {
			return nil, fmt.Errorf("reading option name: %s", err)
		}
		bs = rest
		val, rest, err := tftpStr(bs)
		if err != nil {
			return nil, fmt.Errorf("reading option %q value: %s", opt, err)
		}
		bs = rest
		switch strings.ToLower(opt) {
		case "blksize":
			size, err := strconv.ParseUint(val, 10, 16)
			if err != nil {
				return nil, fmt.Errorf("unsupported block size %q", val)
			}
			o.BlockSize = int64(size)
		case "timeout":
			seconds, err := strconv.ParseUint(val, 10, 8)
			if err != nil {
				return nil, fmt.Errorf("unsupported timeout %q", val)
			}
			o.Timeout = time.Second * time.Duration(seconds)
		case "tsize":
			size, err := strconv.ParseInt(val, 10, 64)
			if err != nil {
				return nil, fmt.Errorf("unsupported transfer size %q", val)
			}
			o.TransferSize = size
		case "windowsize":
			size, err := strconv.ParseUint(val, 10, 16)
			if err != nil {
				return nil, fmt.Errorf("unsupported window size %q", val)
			}
			o.WindowSize = int64(size)
		}
	}

	return o, nil
}

// tftpStr reads a null-terminated string from the start of bs and returns it
// along with the bytes that follow the terminator.
//
// Every byte before the terminator must be printable ASCII (0x20 through
// 0x7E), a practical subset of netascii; the first byte outside that range
// produces an error naming it and its offset. An error is also returned when
// bs contains no terminator at all.
func tftpStr(bs []byte) (str string, remaining []byte, err error) {
	for idx := 0; idx < len(bs); idx++ {
		c := bs[idx]
		switch {
		case c == 0:
			return string(bs[:idx]), bs[idx+1:], nil
		case c < 0x20, c > 0x7E:
			return "", nil, fmt.Errorf("invalid netascii byte %q at offset %d", c, idx)
		}
	}
	return "", nil, errors.New("no null terminated string found")
}

// WindowReader reads an underlying stream in fixed-size blocks and keeps the
// most recent blocks buffered, so a caller can rewind with Unread by up to
// slots blocks and read them again.
//
// Each call to Read yields exactly one block, so the number of Read calls
// equals the number of blocks consumed and Unread can be expressed in blocks.
type WindowReader struct {
	buf     []byte // backing storage for all slots, slots*size bytes
	len     []int  // number of valid bytes in each slot
	eof     []bool // whether the stream ended while filling each slot
	rd      io.Reader
	current int // index of the next block Read returns
	head    int // index of the next block to be read from rd
	slots   int // number of blocks retained for replay
	size    int // capacity of each slot in bytes
}

// NewWindowReader returns a WindowReader that reads rd in blocks of at most
// size bytes and retains the last slots blocks for replay.
func NewWindowReader(rd io.Reader, slots int, size int) *WindowReader {
	return &WindowReader{
		buf:   make([]byte, slots*size),
		len:   make([]int, slots),
		eof:   make([]bool, slots),
		rd:    rd,
		slots: slots,
		size:  size,
	}
}

// Read copies the next block into p and returns its length.
//
// A block previously rewound with Unread is replayed from the buffer.
// Otherwise up to min(len(p), slot size) bytes are read from the underlying
// reader, retrying short reads until the block is full or the stream ends,
// and the block is remembered for replay. A block cut short by the end of
// the stream is returned with io.EOF, both when first read and when
// replayed. An empty p returns 0, nil without consuming a block.
func (r *WindowReader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}

	slot := r.current % r.slots
	offset := slot * r.size

	if r.current != r.head {
		// Replay a block rewound by Unread.
		n := copy(p, r.buf[offset:offset+r.len[slot]])
		r.current++
		if r.eof[slot] {
			return n, io.EOF
		}
		return n, nil
	}

	// Never take more than one slot's worth, so a large p cannot spill into
	// the next slot or past the end of buf.
	want := r.size
	if len(p) < want {
		want = len(p)
	}
	block := r.buf[offset : offset+want]

	// Fill the whole block unless the stream ends, so a short read from rd
	// does not split one block across two slots.
	n, err := io.ReadFull(r.rd, block)
	if err == io.ErrUnexpectedEOF {
		err = io.EOF
	}
	copy(p, block[:n])
	r.len[slot] = n
	r.eof[slot] = err == io.EOF

	r.current++
	r.head = r.current
	return n, err
}

// Unread rewinds the reader by n blocks, so that the next n calls to Read
// return those blocks again from the buffer. n must not exceed the number of
// blocks still buffered (at most slots); this is not checked.
func (r *WindowReader) Unread(n int) {
	r.current -= n
}
