// Copyright 2016 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package tcpip provides the interfaces and related types that users of the
// tcpip stack will use in order to create endpoints used to send and receive
// data over the network stack.
//
// The starting point is the creation and configuration of a stack. A stack can
// be created by calling the New() function of the tcpip/stack package;
// configuring a stack involves creating NICs (via calls to Stack.CreateNIC()),
// adding network addresses (via calls to Stack.AddAddress()), and
// setting a route table (via a call to Stack.SetRouteTable()).
//
// Once a stack is configured, endpoints can be created by calling
// Stack.NewEndpoint(). Such endpoints can be used to send/receive data, connect
// to peers, listen for connections, accept connections, etc., depending on the
// transport protocol selected.
package tcpip

import (
	"errors"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/google/netstack/tcpip/buffer"
	"github.com/google/netstack/waiter"
)

// Error represents an error in the netstack error space. Using a dedicated
// type ensures that errors from outside this space are not accidentally
// introduced.
type Error struct {
	// msg is the human-readable description returned by String.
	msg string

	// IgnoreStats, when true, marks an error that should be left out of the
	// failure counts kept in tcpip.Stats structs.
	IgnoreStats bool
}

// String implements fmt.Stringer.String. It may be called on a nil *Error,
// which is rendered as "<nil>".
func (e *Error) String() string {
	if e != nil {
		return e.msg
	}
	return "<nil>"
}

// Errors that can be returned by the network stack.
//
// Each one is a distinct *Error value, so an error can be identified by
// comparing it against these variables. The entries created with IgnoreStats
// set (ErrAlreadyBound, ErrAlreadyConnecting, ErrAlreadyConnected,
// ErrWouldBlock, ErrConnectStarted) are flagged as not to be included in the
// failure counts of tcpip.Stats.
var (
	ErrUnknownProtocol       = &Error{msg: "unknown protocol"}
	ErrUnknownNICID          = &Error{msg: "unknown nic id"}
	ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
	ErrDuplicateNICID        = &Error{msg: "duplicate nic id"}
	ErrDuplicateAddress      = &Error{msg: "duplicate address"}
	ErrNoRoute               = &Error{msg: "no route"}
	ErrBadLinkEndpoint       = &Error{msg: "bad link layer endpoint"}
	ErrAlreadyBound          = &Error{msg: "endpoint already bound", IgnoreStats: true}
	ErrInvalidEndpointState  = &Error{msg: "endpoint is in invalid state"}
	ErrAlreadyConnecting     = &Error{msg: "endpoint is already connecting", IgnoreStats: true}
	ErrAlreadyConnected      = &Error{msg: "endpoint is already connected", IgnoreStats: true}
	ErrNoPortAvailable       = &Error{msg: "no ports are available"}
	ErrPortInUse             = &Error{msg: "port is in use"}
	ErrBadLocalAddress       = &Error{msg: "bad local address"}
	ErrClosedForSend         = &Error{msg: "endpoint is closed for send"}
	ErrClosedForReceive      = &Error{msg: "endpoint is closed for receive"}
	ErrWouldBlock            = &Error{msg: "operation would block", IgnoreStats: true}
	ErrConnectionRefused     = &Error{msg: "connection was refused"}
	ErrTimeout               = &Error{msg: "operation timed out"}
	ErrAborted               = &Error{msg: "operation aborted"}
	ErrConnectStarted        = &Error{msg: "connection attempt started", IgnoreStats: true}
	ErrDestinationRequired   = &Error{msg: "destination address is required"}
	ErrNotSupported          = &Error{msg: "operation not supported"}
	ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
	ErrNotConnected          = &Error{msg: "endpoint not connected"}
	ErrConnectionReset       = &Error{msg: "connection reset by peer"}
	ErrConnectionAborted     = &Error{msg: "connection aborted"}
	ErrNoLinkAddress         = &Error{msg: "no remote link address"}
	ErrNoSuchFile            = &Error{msg: "no such file"}
	ErrInvalidOptionValue    = &Error{msg: "invalid option value specified"}
)

// Errors related to Subnet construction and parsing. NewSubnet reports the
// first two (address/mask length mismatch, host bits set outside the mask);
// ParseCIDR reports errInvalidCIDRNotation for input that is not of the form
// "address/prefix".
var (
	errSubnetLengthMismatch = errors.New("subnet length of address and mask differ")
	errSubnetAddressMasked  = errors.New("subnet address has bits set outside the mask")
	errInvalidCIDRNotation  = errors.New("CIDR notation invalid")
)

// Address is a byte slice cast as a string that represents the address of a
// network node. Or, in the case of unix endpoints, it may represent a path.
type Address string

// AddressMask is a bitmask for an address. Byte i of the mask applies to
// byte i of the Address it is combined with.
type AddressMask string

// Subnet is a subnet defined by its address and mask.
//
// Construct one with NewSubnet (or ParseCIDR), which guarantees that address
// and mask have the same length and that address has no bits set outside
// the mask; the methods on Subnet index both fields in lockstep.
type Subnet struct {
	// address is the subnet's network address (see ID).
	address Address
	// mask selects the network bits of an address (see Mask).
	mask AddressMask
}

// Mask returns a copy of a in which every byte is ANDed with the matching
// byte of m, i.e. a with every bit outside the mask cleared.
//
// Bytes of m beyond len(a) are ignored. If m is shorter than a there is no
// mask byte for the tail of a, so Mask returns the empty Address instead of
// panicking.
func (a Address) Mask(m AddressMask) Address {
	if len(m) < len(a) {
		return ""
	}
	out := make([]byte, len(a))
	// Index byte by byte: ranging over the string would step by UTF-8 rune
	// and skip the continuation bytes of any valid multi-byte sequence (for
	// example 0xc3 0xa9), leaving those address bytes unmasked.
	for i := 0; i < len(a); i++ {
		out[i] = a[i] & m[i]
	}
	return Address(out)
}

// NewSubnet creates a new Subnet from an address and a mask.
//
// It returns errSubnetLengthMismatch if a and m differ in length, and
// errSubnetAddressMasked if a has any bit set that m does not cover.
func NewSubnet(a Address, m AddressMask) (Subnet, error) {
	if len(a) != len(m) {
		return Subnet{}, errSubnetLengthMismatch
	}
	for i := range []byte(a) {
		// a has a bit outside the mask exactly when masking changes it.
		if a[i]&m[i] != a[i] {
			return Subnet{}, errSubnetAddressMasked
		}
	}
	return Subnet{address: a, mask: m}, nil
}

// Contains reports whether a belongs to the subnet: a must have the same
// length as the subnet address, and masking a with the subnet mask must
// yield the subnet address.
func (s *Subnet) Contains(a Address) bool {
	match := len(a) == len(s.address)
	for i := 0; match && i < len(a); i++ {
		match = a[i]&s.mask[i] == s.address[i]
	}
	return match
}

// ID returns the subnet ID, i.e. the network address the Subnet was created
// with.
func (s *Subnet) ID() Address {
	return s.address
}

// Bits returns the number of ones (network bits) and zeros (host bits) in the
// subnet mask. The counts always add up to 8*len(mask); bit positions are not
// considered, so a non-contiguous mask is counted the same way.
func (s *Subnet) Bits() (ones int, zeros int) {
	for k := 0; k < len(s.mask); k++ {
		for v := s.mask[k]; v != 0; v >>= 1 {
			ones += int(v & 1)
		}
	}
	zeros = 8*len(s.mask) - ones
	return ones, zeros
}

// Prefix returns the number of bits before the first host bit, scanning the
// mask from its most significant bit. A mask with no zero bit yields
// 8*len(mask).
func (s *Subnet) Prefix() int {
	n := 0
	for k := 0; k < len(s.mask); k++ {
		octet := s.mask[k]
		for bit := byte(0x80); bit != 0; bit >>= 1 {
			if octet&bit == 0 {
				return n
			}
			n++
		}
	}
	return n
}

// Mask returns the subnet mask, as passed to NewSubnet.
// Getter instead of exported field to avoid rename before gVisor merge.
func (s *Subnet) Mask() AddressMask {
	return s.mask
}

// String implements fmt.Stringer.String. Both the address and the mask are
// rendered through Address's own String method.
func (s Subnet) String() string {
	addr, mask := s.address, Address(s.mask)
	return fmt.Sprintf("{ address=%s, mask=%s }", addr, mask)
}

// NICID is a number that uniquely identifies a NIC.
type NICID int32

// ShutdownFlags represents flags that can be passed to the Shutdown() method
// of the Endpoint interface.
type ShutdownFlags int

// Values of the flags that can be passed to the Shutdown() method. They can
// be OR'ed together.
const (
	// ShutdownRead (1) requests closing the read end of the connection.
	ShutdownRead ShutdownFlags = 1 << iota
	// ShutdownWrite (2) requests closing the write end of the connection.
	ShutdownWrite
)

// FullAddress represents a full transport node address, as required by the
// Connect() and Bind() methods. It is also the value returned by
// GetLocalAddress() and GetRemoteAddress().
type FullAddress struct {
	// NIC is the ID of the NIC this address refers to.
	//
	// This may not be used by all endpoint types.
	NIC NICID

	// Addr is the network address.
	Addr Address

	// Port is the transport port.
	//
	// This may not be used by all endpoint types.
	Port uint16
}

// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
// that exposes functionality like read, write, connect, etc. to users of the
// networking stack.
//
// Methods report failures as *Error values from the netstack error space
// (for example ErrWouldBlock) rather than as Go error values.
type Endpoint interface {
	// Close puts the endpoint in a closed state and frees all resources
	// associated with it.
	Close()

	// Read reads data from the endpoint and optionally returns the sender.
	// This method does not block if there is no data pending.
	// It will also either return an error or data, never both.
	Read(*FullAddress) (buffer.View, *Error)

	// Write writes data to the endpoint's peer, or the provided address if
	// one is specified. This method does not block if the data cannot be
	// written.
	//
	// Note that unlike io.Writer.Write, it is not an error for Write to
	// perform a partial write.
	Write(buffer.View, *FullAddress) (uintptr, *Error)

	// Peek reads data from the endpoint without consuming it.
	//
	// This method does not block if there is no data pending.
	Peek([][]byte) (uintptr, *Error)

	// Connect connects the endpoint to its peer. Specifying a NIC is
	// optional.
	//
	// There are three classes of return values:
	//	nil -- the attempt to connect succeeded.
	//	ErrConnectStarted -- the connect attempt started but hasn't
	//		completed yet. In this case, the actual result will
	//		become available via GetSockOpt(ErrorOption) when
	//		the endpoint becomes writable. (This mimics the
	//		connect(2) syscall behavior.)
	//	Anything else -- the attempt to connect failed.
	Connect(address FullAddress) *Error

	// Shutdown closes the read and/or write end of the endpoint connection
	// to its peer, as selected by flags (ShutdownRead, ShutdownWrite).
	Shutdown(flags ShutdownFlags) *Error

	// Listen puts the endpoint in "listen" mode, which allows it to accept
	// new connections.
	Listen(backlog int) *Error

	// Accept returns a new endpoint if a peer has established a connection
	// to an endpoint previously set to listen mode. This method does not
	// block if no new connections are available.
	//
	// The returned Queue is the wait queue for the newly created endpoint.
	Accept() (Endpoint, *waiter.Queue, *Error)

	// Bind binds the endpoint to a specific local address and port.
	// Specifying a NIC is optional.
	//
	// An optional commit function will be executed atomically with respect
	// to binding the endpoint. If this returns an error, the bind will not
	// occur and the error will be propagated back to the caller.
	Bind(address FullAddress, commit func() *Error) *Error

	// GetLocalAddress returns the address to which the endpoint is bound.
	GetLocalAddress() (FullAddress, *Error)

	// GetRemoteAddress returns the address to which the endpoint is
	// connected.
	GetRemoteAddress() (FullAddress, *Error)

	// Readiness returns the current readiness of the endpoint. For example,
	// if waiter.EventIn is set, the endpoint is immediately readable.
	Readiness(mask waiter.EventMask) waiter.EventMask

	// SetSockOpt sets a socket option. opt should be one of the *Option types.
	SetSockOpt(opt interface{}) *Error

	// GetSockOpt gets a socket option. opt should be a pointer to one of the
	// *Option types.
	GetSockOpt(opt interface{}) *Error
}

// ErrorOption is used in GetSockOpt to specify that the last error reported by
// the endpoint should be cleared and returned.
type ErrorOption struct{}

// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send
// buffer size option.
type SendBufferSizeOption int

// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the
// receive buffer size option.
type ReceiveBufferSizeOption int

// SendQueueSizeOption is used in GetSockOpt to specify that the number of
// unread bytes in the output buffer should be returned.
type SendQueueSizeOption int

// ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of
// unread bytes in the input buffer should be returned.
type ReceiveQueueSizeOption int

// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
type V6OnlyOption int

// NoDelayOption is used by SetSockOpt/GetSockOpt to specify if data should be
// sent out immediately by the transport protocol. For TCP, it determines if the
// Nagle algorithm is on or off.
type NoDelayOption int

// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind()
// should allow reuse of local address.
type ReuseAddressOption int

// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether
// SCM_CREDENTIALS socket control messages are enabled.
//
// Only supported on Unix sockets.
type PasscredOption int

// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether
// TCP keepalive is enabled for this socket.
type KeepaliveEnabledOption int

// KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time
// a connection must remain idle before the first TCP keepalive packet is sent.
// Once this time is reached, KeepaliveIntervalOption is used instead.
type KeepaliveIdleOption time.Duration

// KeepaliveIntervalOption is used by SetSockOpt/GetSockOpt to specify the
// interval between sending TCP keepalive packets.
type KeepaliveIntervalOption time.Duration

// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the
// number of un-ACKed TCP keepalives that will be sent before the connection
// is closed.
type KeepaliveCountOption int

// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
// TTL value for multicast messages. The default is 1.
type MulticastTTLOption uint8

// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to
// AddMembershipOption and RemoveMembershipOption.
type MembershipOption struct {
	// NIC is the ID of the NIC to use. NOTE(review): presumably optional when
	// InterfaceAddr is given; confirm against the endpoint implementations.
	NIC NICID
	// InterfaceAddr is the address identifying the local interface on which
	// the group membership applies.
	InterfaceAddr Address
	// MulticastAddr is the multicast group address.
	MulticastAddr Address
}

// AddMembershipOption is used by SetSockOpt/GetSockOpt to join a multicast
// group identified by the given multicast address, on the interface matching
// the given interface address.
type AddMembershipOption MembershipOption

// RemoveMembershipOption is used by SetSockOpt/GetSockOpt to leave a multicast
// group identified by the given multicast address, on the interface matching
// the given interface address.
type RemoveMembershipOption MembershipOption

// InfoOption is used by GetSockOpt to query various metrics about the socket.
type InfoOption struct {
	// Rtt and Rttvar are round-trip-time metrics. NOTE(review): presumably
	// the smoothed RTT and RTT variance of a TCP connection; confirm.
	Rtt    time.Duration
	Rttvar time.Duration
}

// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination address in the row.
type Route struct {
	// Destination is the address that must be matched against the masked
	// target address to check if this row is viable.
	Destination Address

	// Mask specifies which bits of the Destination and the target address
	// must match for this row to be viable. It is expected to be the same
	// length as Destination.
	Mask Address

	// Gateway is the gateway to be used if this row is viable.
	Gateway Address

	// NIC is the id of the nic to be used if this row is viable.
	NIC NICID
}

// Match determines if r is viable for the given destination address: addr
// must have the same length as r.Destination and, once masked with r.Mask,
// must equal r.Destination byte for byte.
//
// A route whose Mask is shorter than its Destination is malformed and never
// matches, instead of panicking with an out-of-range index.
func (r *Route) Match(addr Address) bool {
	if len(addr) != len(r.Destination) || len(r.Mask) < len(r.Destination) {
		return false
	}

	for i := 0; i < len(r.Destination); i++ {
		if addr[i]&r.Mask[i] != r.Destination[i] {
			return false
		}
	}

	return true
}

// LinkEndpointID represents a data link layer endpoint.
type LinkEndpointID uint64

// TransportProtocolNumber is the number of a transport protocol (e.g. tcp,
// udp). NOTE(review): presumably the IANA protocol number; confirm.
type TransportProtocolNumber uint32

// NetworkProtocolNumber is the number of a network protocol (e.g. IPv4,
// IPv6). NOTE(review): presumably the link-layer EtherType; confirm.
type NetworkProtocolNumber uint32

// IPStats collects IP-specific stats (both v4 and v6).
type IPStats struct {
	// PacketsReceived is the total number of IP packets received from the link
	// layer in nic.DeliverNetworkPacket.
	PacketsReceived uint64
	// InvalidAddressesReceived is the total number of IP packets received
	// with an unknown or invalid destination address.
	InvalidAddressesReceived uint64
	// PacketsDiscarded is the total number of IP packets received from the link
	// layer but not delivered to the transport layer.
	PacketsDiscarded uint64
	// PacketsDelivered is the total number of incoming IP packets that
	// are successfully delivered to the transport layer via HandlePacket.
	PacketsDelivered uint64
	// PacketsSent is the total number of IP packets sent via WritePacket.
	PacketsSent uint64
	// OutgoingPacketErrors is the total number of IP packets that failed
	// to be written to a link-layer endpoint.
	OutgoingPacketErrors uint64
}

// TCPStats collects TCP-specific stats.
type TCPStats struct {
	// ActiveConnectionOpenings is the number of connections opened successfully
	// via Connect.
	ActiveConnectionOpenings uint64
	// PassiveConnectionOpenings is the number of connections opened
	// successfully via Listen.
	PassiveConnectionOpenings uint64
	// FailedConnectionAttempts is the number of calls to Connect or Listen
	// (active and passive openings, respectively) that end in an error.
	FailedConnectionAttempts uint64
	// ValidSegmentsReceived is the number of TCP segments received that the
	// transport layer successfully parsed.
	ValidSegmentsReceived uint64
	// InvalidSegmentsReceived is the number of TCP segments received that
	// the transport layer could not parse.
	InvalidSegmentsReceived uint64
	// SegmentsSent is the number of TCP segments sent.
	SegmentsSent uint64
	// ResetsSent is the number of TCP resets (RST segments) sent.
	ResetsSent uint64
	// ResetsReceived is the number of TCP resets (RST segments) received.
	ResetsReceived uint64
}

// UDPStats collects UDP-specific stats.
type UDPStats struct {
	// PacketsReceived is the number of UDP datagrams received via HandlePacket.
	PacketsReceived uint64
	// UnknownPortErrors is the number of incoming UDP datagrams dropped because
	// they did not have a known destination port.
	UnknownPortErrors uint64
	// ReceiveBufferErrors is the number of incoming UDP datagrams dropped due
	// to the receiving buffer being in an invalid state.
	ReceiveBufferErrors uint64
	// MalformedPacketsReceived is the number of incoming UDP datagrams dropped
	// due to the UDP header being in a malformed state.
	MalformedPacketsReceived uint64
	// PacketsSent is the number of UDP datagrams sent via sendUDP.
	PacketsSent uint64
}

// Stats holds statistics about the networking stack.
type Stats struct {
	// UnknownProtocolRcvdPackets is the number of packets received by the
	// stack that were for an unknown or unsupported protocol.
	UnknownProtocolRcvdPackets uint64

	// MalformedRcvdPackets is the number of packets received by the stack
	// that were deemed malformed.
	MalformedRcvdPackets uint64

	// DroppedPackets is the number of packets dropped due to full queues.
	DroppedPackets uint64

	// IP breaks out IP-specific stats (both v4 and v6).
	IP IPStats

	// TCP breaks out TCP-specific stats.
	TCP TCPStats

	// UDP breaks out UDP-specific stats.
	UDP UDPStats
}

// String implements the fmt.Stringer interface.
//
// A 4-byte address is rendered in dotted-decimal form. A 16-byte address is
// rendered as eight ':'-separated hex groups, with the longest run of two or
// more all-zero groups (the earliest one on a tie) collapsed to "::". Any
// other length is rendered as plain lowercase hex.
func (a Address) String() string {
	switch len(a) {
	case 4:
		return fmt.Sprintf("%d.%d.%d.%d", int(a[0]), int(a[1]), int(a[2]), int(a[3]))
	case 16:
		var groups [8]uint16
		for g := range groups {
			groups[g] = uint16(a[2*g])<<8 | uint16(a[2*g+1])
		}

		// Locate the longest run of at least two zero groups; a strictly
		// longer run is required to replace an earlier one.
		runStart, runLen := -1, 0
		for g := range groups {
			n := 0
			for g+n < len(groups) && groups[g+n] == 0 {
				n++
			}
			if n >= 2 && n > runLen {
				runStart, runLen = g, n
			}
		}

		// joinGroups appends gs to dst as ':'-separated hex groups.
		joinGroups := func(dst []byte, gs []uint16) []byte {
			for k, v := range gs {
				if k > 0 {
					dst = append(dst, ':')
				}
				dst = appendHex(dst, v)
			}
			return dst
		}

		if runStart < 0 {
			return string(joinGroups(nil, groups[:]))
		}
		out := joinGroups(nil, groups[:runStart])
		out = append(out, "::"...)
		out = joinGroups(out, groups[runStart+runLen:])
		return string(out)
	default:
		return fmt.Sprintf("%x", []byte(a))
	}
}

// appendHex appends v to b in lowercase hexadecimal without leading zeros;
// zero is written as "0". It returns the extended slice.
func appendHex(b []byte, v uint16) []byte {
	const digits = "0123456789abcdef"
	// Fill a scratch buffer from the least significant digit backwards and
	// stop once no higher non-zero digits remain.
	var tmp [4]byte
	n := len(tmp)
	for {
		n--
		tmp[n] = digits[v&0xf]
		v >>= 4
		if v == 0 {
			break
		}
	}
	return append(b, tmp[n:]...)
}

// To4 converts the IPv4 address to a 4-byte representation.
//
// A 4-byte address is returned as is. A 16-byte IPv4-mapped address (ten
// zero bytes followed by 0xff 0xff) yields its last four bytes. Any other
// address is not an IPv4 address, and To4 returns "".
func (a Address) To4() Address {
	const (
		ipv4len = 4
		ipv6len = 16
	)
	switch len(a) {
	case ipv4len:
		return a
	case ipv6len:
		if isZeros(a[:10]) && a[10] == 0xff && a[11] == 0xff {
			return a[12:]
		}
	}
	return ""
}

// isZeros reports whether every byte of a is zero; the empty Address counts
// as all zeros.
func isZeros(a Address) bool {
	for _, c := range []byte(a) {
		if c != 0 {
			return false
		}
	}
	return true
}

// CIDRMask returns an AddressMask made of ones 1 bits followed by 0 bits, for
// a total of bits bits. bits must be 32 (IPv4) or 128 (IPv6), and ones must be
// in [0, bits]; otherwise the empty mask is returned.
//
// Copied from pkg Net (net.CIDRMask) to avoid taking a dependency.
func CIDRMask(ones, bits int) AddressMask {
	const (
		ipv4Bits = 8 * 4  // header.IPv4AddressSize
		ipv6Bits = 8 * 16 // header.IPv6AddressSize
	)
	if (bits != ipv4Bits && bits != ipv6Bits) || ones < 0 || ones > bits {
		return AddressMask("")
	}
	m := make([]byte, bits/8)
	full, rem := ones/8, uint(ones%8)
	for i := 0; i < full; i++ {
		m[i] = 0xff
	}
	// A partial byte exists only when ones is not a multiple of 8, so
	// m[full] is always in range here.
	if rem != 0 {
		m[full] = ^byte(0xff >> rem)
	}
	return AddressMask(m)
}

// ParseCIDR parses subnet as an IP address and prefix length in CIDR
// notation, such as "192.0.2.1/24" or "2001:db8::1/128".
//
// It returns the parsed address and the Subnet it belongs to. That Subnet is
// the address with its host bits cleared, together with the mask for the
// prefix length.
//
// It returns errInvalidCIDRNotation in the following cases:
//   - subnet is not of the form "address/prefix";
//   - the address does not parse;
//   - the prefix length is negative or longer than the address.
//
// A prefix length that is not a decimal integer yields the strconv error. On
// any error the returned Address and Subnet are empty.
func ParseCIDR(subnet string) (Address, Subnet, error) {
	split := strings.Split(subnet, "/")
	if len(split) != 2 {
		return Address(""), Subnet{}, errInvalidCIDRNotation
	}
	addr := Parse(split[0])
	if len(addr) == 0 {
		// Otherwise an unparseable address would produce an empty Address
		// and an empty Subnet with a nil error.
		return Address(""), Subnet{}, errInvalidCIDRNotation
	}
	// Atoi rather than ParseInt(..., 8): an 8-bit result cannot represent
	// the IPv6 prefix length 128.
	ones, err := strconv.Atoi(split[1])
	if err != nil {
		return Address(""), Subnet{}, err
	}
	mask := CIDRMask(ones, 8*len(addr))
	if len(mask) != len(addr) {
		// CIDRMask returns an empty mask when ones is outside [0, bits];
		// masking with it would index past its end.
		return Address(""), Subnet{}, errInvalidCIDRNotation
	}
	sn, err := NewSubnet(addr.Mask(mask), mask)
	if err != nil {
		return Address(""), Subnet{}, err
	}
	return addr, sn, nil
}

// Parse parses the string representation of an IPv4 or IPv6 address.
//
// The first '.' or ':' in src picks the family: '.' selects IPv4 and ':'
// selects IPv6. If src contains neither, or the chosen parser rejects it,
// Parse returns the empty Address.
func Parse(src string) Address {
	sep := strings.IndexAny(src, ".:")
	if sep < 0 {
		return ""
	}
	if src[sep] == '.' {
		return parseIP4(src)
	}
	return parseIP6(src)
}

// parseIP4 parses a dotted-decimal IPv4 address such as "192.0.2.1".
//
// src must be exactly four decimal octets, each in [0, 255], separated by
// single dots, with nothing before or after them. The following are
// rejected:
//   - signs and whitespace;
//   - empty octets;
//   - trailing input such as a port ("1.2.3.4:80") or a fifth octet
//     ("1.2.3.4.5").
//
// Leading zeros are accepted and read as decimal ("010" is 10).
//
// It returns the 4-byte address, or "" if src is malformed.
func parseIP4(src string) Address {
	var addr [4]byte
	for i := range addr {
		if i > 0 {
			if len(src) == 0 || src[0] != '.' {
				return ""
			}
			src = src[1:]
		}
		n, digits := 0, 0
		for digits < len(src) && '0' <= src[digits] && src[digits] <= '9' {
			n = n*10 + int(src[digits]-'0')
			// Checking per digit keeps n small even for long digit runs.
			if n > 0xff {
				return ""
			}
			digits++
		}
		if digits == 0 {
			return ""
		}
		addr[i] = byte(n)
		src = src[digits:]
	}
	if len(src) != 0 {
		return ""
	}
	return Address(addr[:])
}

// parseIP6 parses the textual form of an IPv6 address. The address is eight
// groups of one to four hexadecimal digits separated by ':'. A single "::"
// may stand for one or more consecutive all-zero groups (RFC 4291, section
// 2.2). Embedded IPv4 notation ("::ffff:192.0.2.1") is not supported.
//
// It returns the 16-byte address, or "" if src is malformed.
func parseIP6(src string) (res Address) {
	a := make([]byte, 0, 16) // cap(a) is constant throughout
	expansion := -1          // index of '::' expansion in a

	if len(src) >= 2 && src[:2] == "::" {
		if len(src) == 2 {
			return Address(a[:cap(a)])
		}
		expansion = 0
		src = src[2:]
	}

	for len(a) < cap(a) && len(src) > 0 {
		x, rest, ok := parseHex(src)
		// A group must have 1-4 hex digits. Zero digits is an empty group
		// (":::" or a leading single ':'); more than four would overflow
		// 16 bits.
		if digits := len(src) - len(rest); !ok || digits == 0 || digits > 4 {
			return ""
		}
		src = rest
		a = append(a, uint8(x>>8), uint8(x))

		if len(src) == 0 {
			break
		}

		// Next is either ":..." or "::[...]".
		if src[0] != ':' || len(src) == 1 {
			return ""
		}
		src = src[1:]
		if src[0] == ':' {
			if expansion >= 0 {
				return "" // only one expansion allowed
			}
			expansion = len(a)
			src = src[1:]
		}
	}
	if len(src) != 0 {
		return ""
	}

	missing := cap(a) - len(a)
	if expansion < 0 {
		if missing > 0 {
			return "" // too few groups and no "::" to fill the gap
		}
		return Address(a)
	}
	if missing == 0 {
		return "" // "::" must stand for at least one zero group
	}
	// Shift the groups after the expansion point to the end and zero-fill
	// the gap.
	a = a[:cap(a)]
	copy(a[expansion+missing:], a[expansion:])
	for i := 0; i < missing; i++ {
		a[i+expansion] = 0
	}
	return Address(a)
}

// parseHex reads hexadecimal digits (either case) from the start of src up to
// the first ':' or the end of the string.
//
// It returns the value read, the unconsumed remainder (which starts at the
// ':' or is empty), and ok=true. An empty src, or a non-hex character other
// than ':', gives ok=false with x=0 and the remainder starting at that
// character.
//
// The number of digits is not limited: x is a uint16 and excess high digits
// wrap away, so callers must bound the digit count themselves. A src that
// begins with ':' yields x=0, zero digits consumed and ok=true.
func parseHex(src string) (x uint16, remaining string, ok bool) {
	if src == "" {
		return 0, src, false
	}
	i := 0
	for ; i < len(src) && src[i] != ':'; i++ {
		var d byte
		switch c := src[i]; {
		case '0' <= c && c <= '9':
			d = c - '0'
		case 'a' <= c && c <= 'f':
			d = c - 'a' + 10
		case 'A' <= c && c <= 'F':
			d = c - 'A' + 10
		default:
			return 0, src[i:], false
		}
		x = x<<4 | uint16(d)
	}
	return x, src[i:], true
}

// LinkAddress is a byte slice cast as a string that represents a link address.
// It is typically a 6-byte MAC address.
type LinkAddress string

// String implements the fmt.Stringer interface. A 6-byte address is written
// as colon-separated two-digit hex bytes; any other length is written as
// plain lowercase hex.
func (a LinkAddress) String() string {
	if len(a) != 6 {
		return fmt.Sprintf("%x", []byte(a))
	}
	return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", a[0], a[1], a[2], a[3], a[4], a[5])
}
