// Copyright 2016 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <endian.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#include "base/xalloc.h"
#include "net/ipv6/inet6.h"
#include "net/ipv6/ipv6.h"

// Fixed on-the-wire header sizes, in bytes, for the packets built here.
enum {
	Udp_HdrLen = 8,		// UDP header.
	Ipv6_HdrLen = 40,	// IPv6 fixed header, no extension headers.
};

// Convert a MAC address to its IPv6 link-local address, using the modified
// EUI-64 interface identifier:
//   aa:bb:cc:dd:ee:ff => fe80::(aa^02)bb:ccff:fedd:eeff
// i.e. the fe80::/64 prefix followed by the MAC split in half around an
// inserted ff:fe, with the universal/local bit (0x02 of the first octet)
// inverted.
static void inet6_lladdr_from_mac(Ipv6Address *_ip, const MacAddress *_mac)
{
	uint8_t *ip = _ip->x;
	const uint8_t *mac = _mac->octet;

	// Prefix fe80::/64. The memset also supplies the zero bytes 2-7.
	memset(ip, 0, Ipv6_AddrLen);
	ip[0] = 0xFE;
	ip[1] = 0x80;

	// Interface identifier.
	ip[8] = mac[0] ^ 0x02;
	ip[9] = mac[1];
	ip[10] = mac[2];
	ip[11] = 0xFF;
	ip[12] = 0xFE;
	ip[13] = mac[3];
	ip[14] = mac[4];
	ip[15] = mac[5];
}

// Convert a MAC address to the solicited-node multicast address of the
// link-local address derived from it:
//   aa:bb:cc:dd:ee:ff -> ff02::1:ffdd:eeff
// i.e. the fixed 104-bit prefix ff02::1:ff00:0/104 followed by the last three
// octets of the MAC.
static void inet6_snmaddr_from_mac(Ipv6Address *_ip, const MacAddress *_mac)
{
	static const uint8_t snm_prefix[13] = {
		0xFF, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00, 0x01, 0xFF,
	};
	uint8_t *out = _ip->x;

	memcpy(out, snm_prefix, sizeof(snm_prefix));
	memcpy(out + sizeof(snm_prefix), &_mac->octet[3], 3);
}

// Map an IPv6 multicast address to its Ethernet multicast address: the two
// octets 33:33 followed by the last four octets of the IPv6 address.
static void inet6_multicast_from_ipv6(MacAddress *_mac,
				      const Ipv6Address *_ipv6)
{
	uint8_t *dst = _mac->octet;
	const uint8_t *src = _ipv6->x;

	dst[0] = dst[1] = 0x33;
	memcpy(dst + 2, src + 12, 4);
}

// Source addresses of the most recent frame that eth_recv() accepted as
// addressed to us: its Ethernet source MAC and its IPv6 source address.
// resolve_ipv6() uses them as a one-entry neighbor cache, so replies to that
// peer can be addressed without neighbor discovery. Both are zero (static
// storage) until the first such frame arrives.
static MacAddress inet6_rx_mac_addr;
static Ipv6Address inet6_rx_ip_addr;

// Print the 16-byte IPv6 address at @ipv6addr in compressed text form:
// 16-bit groups in lowercase hex without leading zeros, with the first
// longest run of two or more all-zero groups replaced by "::". No newline
// is printed.
static void ipv6_print_addr(void *ipv6addr)
{
	const uint8_t *bytes = ipv6addr;
	const int ngroups = Ipv6_AddrLen / 2;

	// Find the first longest run of zero groups. best_len starts at 1 so a
	// lone zero group is never compressed. The scan goes one step past the
	// end so a run touching the last group is closed like any other.
	int best_start = -1;
	int best_len = 1;
	int run_start = 0;
	int run_len = 0;
	for (int g = 0; g <= ngroups; g++) {
		int is_zero = g < ngroups && bytes[2 * g] == 0 &&
			      bytes[2 * g + 1] == 0;
		if (is_zero) {
			if (run_len == 0)
				run_start = g;
			run_len++;
			continue;
		}
		if (run_len > best_len) {
			best_start = run_start;
			best_len = run_len;
		}
		run_len = 0;
	}

	// Emit the groups, substituting "::" for the chosen run.
	int need_colon = 0;
	int pos = 0;
	while (pos < ngroups) {
		if (pos == best_start) {
			printf("::");
			pos += best_len;
			need_colon = 0;
			continue;
		}
		printf("%s%x", need_colon ? ":" : "",
		       (bytes[2 * pos] << 8) | bytes[2 * pos + 1]);
		need_colon = 1;
		pos++;
	}
}

// Bring up IPv6 on @dev. This registers two Ethernet multicast groups with
// the device's multicast filter: the one for our solicited-node address and
// the one for the link-local all-nodes address. It then prints the MAC, the
// derived link-local address and the solicited-node address.
void ipv6_init(NetDevice *dev)
{
	const MacAddress *hw = dev->get_mac(dev);
	const uint8_t *oct = hw->octet;
	Ipv6Address link_local;
	Ipv6Address solicited;
	MacAddress snm_group;
	MacAddress all_nodes_group;

	inet6_lladdr_from_mac(&link_local, hw);
	inet6_snmaddr_from_mac(&solicited, hw);

	inet6_multicast_from_ipv6(&snm_group, &solicited);
	eth_add_mcast_filter(dev, &snm_group);

	inet6_multicast_from_ipv6(&all_nodes_group, &Ipv6LlAllNodes);
	eth_add_mcast_filter(dev, &all_nodes_group);

	printf("macaddr: %02x:%02x:%02x:%02x:%02x:%02x\n",
	       oct[0], oct[1], oct[2], oct[3], oct[4], oct[5]);
	printf("ipv6addr: ");
	ipv6_print_addr(&link_local);
	printf("\nsnmaddr: ");
	ipv6_print_addr(&solicited);
	printf("\n");
}

// Find the Ethernet destination for IPv6 address @_ip and store it in @_mac.
//
// Returns 0 on success, -1 when the address cannot be resolved. Only two
// cases are supported:
// - multicast addresses (first byte 0xFF), which map directly to an
//   Ethernet group address;
// - the sender of the most recently accepted packet, whose MAC is cached.
static int resolve_ipv6(MacAddress *_mac, const Ipv6Address *_ip)
{
	if (_ip->x[0] == 0xFF) {
		inet6_multicast_from_ipv6(_mac, _ip);
		return 0;
	}

	// No neighbor discovery or routing yet: anything other than the last
	// peer we heard from is unreachable.
	if (memcmp(_ip, &inet6_rx_ip_addr, sizeof(inet6_rx_ip_addr)) != 0)
		return -1;

	// Assume the peer's MAC has not changed since it last sent to us.
	memcpy(_mac, &inet6_rx_mac_addr, sizeof(inet6_rx_mac_addr));
	return 0;
}

// Internet one's-complement sum (RFC 1071) of @len bytes at @_data.
//
// The sum continues from the partial sum @_sum and is folded to 16 bits.
// Words are loaded in host order, which yields the host-order sum on either
// endianness. An odd final byte is padded with a zero byte *after* it, as if
// the buffer were one byte longer. @_data need not be 2-byte aligned.
//
// The result is NOT complemented. Callers complement it for transmission,
// or compare it against 0xFFFF when verifying.
static uint16_t checksum(const void *_data, size_t len, uint16_t _sum)
{
	const uint8_t *bytes = _data;
	uint32_t sum = _sum;

	while (len > 1) {
		uint16_t word;

		// memcpy instead of a uint16_t* dereference: no alignment or
		// aliasing requirements on the caller's buffer.
		memcpy(&word, bytes, sizeof(word));
		sum += word;
		bytes += sizeof(word);
		len -= sizeof(word);

		// Fold before the 32-bit accumulator can overflow, so buffers
		// of any length are summed correctly.
		if (sum & 0x80000000u)
			sum = (sum & 0xFFFF) + (sum >> 16);
	}
	if (len) {
		// Copy only the last byte into the low-addressed half of a zeroed
		// word. This is correct on both little- and big-endian hosts and
		// never reads past the end of the buffer.
		uint16_t word = 0;

		memcpy(&word, bytes, 1);
		sum += word;
	}
	while (sum > 0xFFFF)
		sum = (sum & 0xFFFF) + (sum >> 16);
	return sum;
}

// IPv6 fixed header exactly as it appears on the wire (Ipv6_HdrLen bytes).
// The struct is packed, so the compiler inserts no padding between fields.
typedef struct __attribute__((packed)) {
	// Version / traffic class / flow label. The IP version is the high
	// nibble of the first byte.
	uint32_t ver_tc_flow;
	uint16_t length;	// Payload length, network byte order.
	uint8_t next_header;	// Protocol of the payload (a Hdr* value).
	uint8_t hop_limit;
	uint8_t src[Ipv6_AddrLen];
	uint8_t dst[Ipv6_AddrLen];
} Ipv6Hdr;

// Buffer for building an outgoing IPv6 packet.
//
// eth[0..1] are unused. eth[2..15] hold the Ethernet header: destination MAC
// at +2, source MAC at +8, EtherType at +14. The frame handed to dev->send()
// therefore starts at eth + 2. NOTE(review): the two leading bytes presumably
// exist to keep the IPv6 header 4-byte aligned; confirm.
typedef struct {
	uint8_t eth[16];
	Ipv6Hdr ipv6;
	uint8_t data[0];	// Upper-layer payload (GNU zero-length array).
} Ipv6Pkt;

// UDP header. Ports and length are in network byte order.
typedef struct __attribute__((packed)) {
	uint16_t src_port;
	uint16_t dst_port;
	uint16_t length;	// Header plus payload, in bytes.
	uint16_t checksum;	// Covers the IPv6 pseudo-header and datagram.
} UdpHdr;

// Buffer for building an outgoing UDP/IPv6 packet. eth[] uses the same
// layout as in Ipv6Pkt, and the UDP header directly follows the IPv6 header.
typedef struct {
	uint8_t eth[16];
	Ipv6Hdr ipv6;
	UdpHdr udp;
	uint8_t data[0];	// UDP payload.
} UdpPkt;

// Compute the transport checksum for the upper-layer message that directly
// follows the IPv6 header @ip in memory, including the IPv6 pseudo-header.
//
// @type:   next-header value placed in the pseudo-header.
// @length: length in bytes of the message after the IPv6 header.
//
// Returns the value to store in the message's checksum field. A
// complemented result of 0 is returned as 0xffff instead, because 0 is not
// a legal UDP checksum.
static unsigned ipv6_checksum(Ipv6Hdr *ip, unsigned type, size_t length)
{
	// Pseudo-header length and next-header words. ip->length already holds
	// the payload length in network order.
	uint16_t acc = checksum(&ip->length, 2, htonw(type));

	// src and dst are the last header fields and are contiguous with the
	// payload, so one pass covers the rest of the pseudo-header and the
	// message itself.
	acc = checksum(ip->src, sizeof(ip->src) + sizeof(ip->dst) + length,
		       acc);

	if (acc == 0xffff)
		return acc;
	return ~acc;
}

// Fill in the Ethernet and IPv6 headers of the outgoing packet @p.
//
// @dev:    device whose MAC becomes the Ethernet source. The IPv6 source is
//          the link-local address derived from that MAC.
// @daddr:  IPv6 destination, mapped to a MAC with resolve_ipv6().
// @length: IPv6 payload length in bytes (host order).
// @type:   next-header value for the payload (one of the Hdr* constants).
//
// Returns 0 on success. Returns -1 if @daddr cannot be resolved to a MAC;
// @p is then left untouched.
static int ipv6_setup(Ipv6Pkt *p, NetDevice *dev, const Ipv6Address *daddr,
		     size_t length, uint8_t type)
{
	// The first four header bytes exactly as sent: version 6, traffic
	// class 0, flow label 0. Storing the integer 0x60 would place the
	// version in the first byte only on little-endian hosts, so copy bytes.
	static const uint8_t vtf_wire[4] = { 0x60, 0x00, 0x00, 0x00 };
	MacAddress dmac;

	if (resolve_ipv6(&dmac, daddr))
		return -1;

	// Ethernet header, starting after the two leading bytes of p->eth.
	const MacAddress *ll_mac = dev->get_mac(dev);
	memcpy(p->eth + 2, &dmac, sizeof(dmac));
	memcpy(p->eth + 8, ll_mac, sizeof(*ll_mac));
	p->eth[14] = (EthType_Ipv6 >> 8) & 0xFF;
	p->eth[15] = EthType_Ipv6 & 0xFF;

	// IPv6 header.
	memcpy(&p->ipv6.ver_tc_flow, vtf_wire, sizeof(vtf_wire));
	p->ipv6.length = htonw(length);
	p->ipv6.next_header = type;
	// NDP messages must be sent with a hop limit of 255 (RFC 4861).
	p->ipv6.hop_limit = 255;
	Ipv6Address saddr;
	inet6_lladdr_from_mac(&saddr, ll_mac);
	memcpy(p->ipv6.src, &saddr, sizeof(saddr));
	memcpy(p->ipv6.dst, daddr, sizeof(*daddr));

	return 0;
}

// Largest UDP payload that fits in a frame of CONFIG_NET_LINK_MTU bytes
// after the Ethernet, IPv6 and UDP headers.
enum {
	Udpv6_MaxPayload = CONFIG_NET_LINK_MTU - sizeof(EtherHdr) -
			   (Ipv6_HdrLen + Udp_HdrLen),
};

// IPv6 "next header" values (IANA protocol numbers).
enum {
	// Upper-layer protocols.
	HdrTcp = 6,
	HdrUdp = 17,
	HdrIcmpv6 = 58,

	// Extension headers.
	HdrHnhOpt = 0,		// Hop-by-hop options.
	HdrRouting = 43,
	HdrFragment = 44,
	HdrDstOpt = 60,		// Destination options.

	HdrNone = 59,		// No next header.
};

// Send @dlen bytes of @data as one UDP datagram over connection @con.
//
// The datagram goes from con->source_port to con->dest_ip:con->dest_port
// through con->dev.
//
// Returns -1 if the payload does not fit in one frame or the destination
// cannot be resolved. Otherwise returns the result of con->dev->send().
int udpv6_send(Ipv6UdpCon *con, const void *data, size_t dlen)
{
	if (dlen > Udpv6_MaxPayload) {
		printf("Data is too big.\n");
		return -1;
	}

	// The UDP length covers the UDP header plus the payload.
	const size_t udp_len = Udp_HdrLen + dlen;
	UdpPkt *pkt = xmalloc(CONFIG_NET_LINK_MTU + 2);
	int ret;

	if (ipv6_setup((void *)pkt, con->dev, &con->dest_ip, udp_len,
		       HdrUdp) != 0) {
		printf("ipv6_setup failed.\n");
		ret = -1;
	} else {
		pkt->udp.src_port = htonw(con->source_port);
		pkt->udp.dst_port = htonw(con->dest_port);
		pkt->udp.length = htonw(udp_len);
		// The checksum covers its own field, which must be zero while
		// the checksum is computed.
		pkt->udp.checksum = 0;
		memcpy(pkt->data, data, dlen);
		pkt->udp.checksum = ipv6_checksum(&pkt->ipv6, HdrUdp, udp_len);

		// The frame starts after the two leading bytes of pkt->eth.
		ret = con->dev->send(con->dev, pkt->eth + 2,
				     sizeof(EtherHdr) + Ipv6_HdrLen + udp_len);
	}

	free(pkt);
	return ret;
}

// ICMPv6 message types.
enum {
	// Neighbor Discovery.
	Icmpv6_NdpNSolicit = 135,
	Icmpv6_NdpNAdvertise = 136,

	// Echo (ping).
	Icmpv6_EchoRequest = 128,
	Icmpv6_EchoReply = 129,

	// Error messages.
	Icmpv6_DestUnreachable = 1,
	Icmpv6_PacketTooBig = 2,
	Icmpv6_TimeExceeded = 3,
	Icmpv6_ParameterProblem = 4,
};

// Largest ICMPv6 message that fits in a frame of CONFIG_NET_LINK_MTU bytes
// after the Ethernet and IPv6 headers.
enum {
	Icmpv6_MaxPayload = CONFIG_NET_LINK_MTU -
			    (sizeof(EtherHdr) + Ipv6_HdrLen),
};

// Leading fields shared by every ICMPv6 message handled in this file.
typedef struct __attribute__((packed)) {
	uint8_t type;		// One of the Icmpv6_* values.
	uint8_t code;
	uint16_t checksum;	// Covers the IPv6 pseudo-header and message.
} Icmp6Hdr;

// Send the complete ICMPv6 message @data (header plus body, @length bytes)
// to @dest_ip through @dev.
//
// The checksum field in @data is ignored; it is recomputed here.
//
// Returns -1 in these cases:
// - the message is shorter than an ICMPv6 header;
// - the message is too long for one frame;
// - the destination cannot be resolved.
// Otherwise returns the result of dev->send().
static int icmpv6_send(NetDevice *dev, const Ipv6Address *dest_ip,
		       const void *data, size_t length)
{
	if (length < sizeof(Icmp6Hdr) || length > Icmpv6_MaxPayload)
		return -1;

	Ipv6Pkt *p = xmalloc(CONFIG_NET_LINK_MTU + 2);

	if (ipv6_setup((void *)p, dev, dest_ip, length, HdrIcmpv6)) {
		free(p);
		return -1;
	}

	Icmp6Hdr *icmp = (void *)p->data;
	memcpy(icmp, data, length);
	// The checksum covers its own field. Zero it so the result does not
	// depend on whatever the caller left there.
	icmp->checksum = 0;
	icmp->checksum = ipv6_checksum(&p->ipv6, HdrIcmpv6, length);

	// The frame starts after the two leading bytes of p->eth.
	int ret = dev->send(dev, p->eth + 2,
			    sizeof(EtherHdr) + Ipv6_HdrLen + length);
	free(p);
	return ret;
}

// Validate a received UDP datagram and pass its payload to udpv6_recv().
//
// @ip:    IPv6 header of the packet. The datagram must directly follow it in
//         memory, because the pseudo-header checksum spans both.
// @_data: start of the UDP header. The checksum field may be rewritten.
// @len:   bytes available from @_data (the IPv6 payload length).
//
// A malformed datagram is reported and dropped without being delivered.
// Every check returns on failure, because later steps rely on the earlier
// ones: the checksum reads @len bytes, and "n - Udp_HdrLen" would underflow
// for a bogus UDP length.
static void _udpv6_recv(Ipv6Hdr *ip, void *_data, size_t len)
{
	UdpHdr *udp = _data;
	uint16_t sum, n;

	if (len < Udp_HdrLen) {
		printf("error: Bogus Header Len\n");
		return;
	}
	// A zero checksum ("not computed") is not allowed for UDP over IPv6.
	if (udp->checksum == 0) {
		printf("error: Checksum Invalid\n");
		return;
	}
	// A computed checksum of 0 is transmitted as 0xFFFF (as ipv6_checksum
	// does); normalize it before verifying.
	if (udp->checksum == 0xFFFF)
		udp->checksum = 0;

	sum = checksum(&ip->length, 2, htonw(HdrUdp));
	sum = checksum(ip->src, 32 + len, sum);
	if (sum != 0xFFFF) {
		printf("error: Checksum Incorrect\n");
		return;
	}

	n = ntohw(udp->length);
	if (n < Udp_HdrLen) {
		printf("error: Bogus Header Len\n");
		return;
	}
	if (n > len) {
		printf("error: Packet Too Short\n");
		return;
	}
	// Ignore anything past the UDP length.
	len = n - Udp_HdrLen;

	udpv6_recv((uint8_t *)_data + Udp_HdrLen, len,
		  (void *)ip->dst, ntohw(udp->dst_port),
		  (void *)ip->src, ntohw(udp->src_port));
}

// Neighbor Discovery option types. The type is the first byte of each
// option.
enum {
	NdpNSrcLlAddr = 1,	// Source link-layer address.
	NdpNTgtLlAddr = 2,	// Target link-layer address.
	NdpNPrefixInfo = 3,	// Prefix information.
	NdpNRedirectedHdr = 4,	// Redirected header.
	NdpNMtu = 5,		// Link MTU.
};

// Neighbor Solicitation/Advertisement message. It has the common ICMPv6
// fields, then a flags word, then the target address, followed by
// variable-length options.
typedef struct __attribute__((packed)) {
	uint8_t type;
	uint8_t code;
	uint16_t checksum;
	uint32_t flags;		// Advertisement flags (e.g. Solicited, Override).
	uint8_t target[Ipv6_AddrLen];	// Address being solicited/advertised.
	uint8_t options[0];	// Options; each starts with an NdpN* type byte.
} NdpNHdr;

// Handle an ICMPv6 message addressed to us.
//
// @dev:   device the message arrived on; replies go out through it.
// @ip:    IPv6 header of the packet. The message must directly follow it in
//         memory (as in eth_recv()), because the checksum spans both.
// @_data: start of the ICMPv6 message. It may be modified in place.
// @len:   length of the ICMPv6 message in bytes.
//
// A Neighbor Solicitation for our link-local address is answered with a
// Neighbor Advertisement. An Echo Request is answered with an Echo Reply.
// Anything malformed or unhandled is reported and dropped. Each check
// returns on failure so no invalid message is ever answered.
static void icmpv6_recv(NetDevice *dev, Ipv6Hdr *ip, void *_data, size_t len)
{
	Icmp6Hdr *icmp = _data;

	if (len < sizeof(Icmp6Hdr)) {
		printf("error: Bogus Header Len\n");
		return;
	}

	// Unlike UDP, ICMPv6 (RFC 4443) has no "checksum absent" value. A
	// transmitted checksum of 0 is legitimate, so it is left to the
	// verification below rather than rejected outright. 0xFFFF is the other
	// representation of one's-complement zero; normalize it.
	if (icmp->checksum == 0xFFFF)
		icmp->checksum = 0;

	uint16_t sum = checksum(&ip->length, 2, htonw(HdrIcmpv6));
	sum = checksum(ip->src, 32 + len, sum);
	if (sum != 0xFFFF) {
		printf("error: Checksum Incorrect\n");
		return;
	}

	if (icmp->type == Icmpv6_NdpNSolicit) {
		NdpNHdr *ndp = _data;
		struct {
			NdpNHdr hdr;
			uint8_t opt[8];
		} msg;
		// Solicited (S) and Override (O) flags: 0x60 in the first wire
		// byte of the flags word. Bytes are copied so the result does not
		// depend on host byte order.
		static const uint8_t na_flags[4] = { 0x60, 0x00, 0x00, 0x00 };

		if (len < sizeof(NdpNHdr)) {
			printf("error: Bogus NDP Message\n");
			return;
		}
		if (ndp->code != 0) {
			printf("error: Bogus NDP Code\n");
			return;
		}

		const MacAddress *mac = dev->get_mac(dev);

		Ipv6Address ll_ip_addr;
		inet6_lladdr_from_mac(&ll_ip_addr, mac);
		if (memcmp(ndp->target, &ll_ip_addr, sizeof(ll_ip_addr))) {
			printf("error: NDP Not For Me\n");
			return;
		}

		msg.hdr.type = Icmpv6_NdpNAdvertise;
		msg.hdr.code = 0;
		msg.hdr.checksum = 0;
		memcpy(&msg.hdr.flags, na_flags, sizeof(na_flags));
		memcpy(msg.hdr.target, &ll_ip_addr, sizeof(ll_ip_addr));
		// Target link-layer address option: type, length in 8-byte
		// units, then our MAC.
		msg.opt[0] = NdpNTgtLlAddr;
		msg.opt[1] = 1;
		memcpy(msg.opt + 2, mac, sizeof(MacAddress));

		icmpv6_send(dev, (void *)ip->src, &msg, sizeof(msg));
		return;
	}

	if (icmp->type == Icmpv6_EchoRequest) {
		// Turn the request into a reply in place. icmpv6_send()
		// recomputes the checksum, which must start out as zero.
		icmp->checksum = 0;
		icmp->type = Icmpv6_EchoReply;
		icmpv6_send(dev, (void *)ip->src, _data, len);
		return;
	}

	printf("error: ICMP6 Unhandled\n");
}

// Entry point for a received Ethernet frame.
//
// @dev:   device the frame arrived on.
// @_data: the complete frame, starting with the Ethernet header. The ICMPv6
//         and UDP handlers may modify parts of it in place.
// @len:   frame length in bytes.
//
// Returns 0 if the packet was handed to the ICMPv6 or UDP handler. Returns 1
// if it was dropped because it was:
// - malformed;
// - not IPv6;
// - not addressed to our link-local or solicited-node address;
// - carrying an unhandled next header.
int eth_recv(NetDevice *dev, void *_data, size_t len)
{
	uint8_t *data = _data;

	// Every later step reads the frame, so each check stops at the first
	// failure instead of reading past what was received.
	if (len < sizeof(EtherHdr) + Ipv6_HdrLen) {
		printf("error: Bogus Header Len\n");
		return 1;
	}

	// The EtherType is frame bytes 12-13, big-endian (the layout
	// ipv6_setup() writes). Other protocols are not handled here.
	if (data[12] != ((EthType_Ipv6 >> 8) & 0xFF) ||
	    data[13] != (EthType_Ipv6 & 0xFF))
		return 1;

	Ipv6Hdr *ip = (void *)(data + sizeof(EtherHdr));
	data += (sizeof(EtherHdr) + Ipv6_HdrLen);
	len -= (sizeof(EtherHdr) + Ipv6_HdrLen);

	// Require v6. The version is the high nibble of the header's first
	// byte; reading that byte directly keeps the test independent of host
	// byte order.
	if ((((const uint8_t *)ip)[0] & 0xF0) != 0x60) {
		printf("error: Unknown ipv6 version.\n");
		return 1;
	}

	// Ensure the payload length fits in what was received.
	uint32_t n = ntohw(ip->length);
	if (n > len) {
		printf("error: IPv6 Length Mismatch\n");
		return 1;
	}

	// Ignore any trailing data in the ethernet frame.
	len = n;

	// Require that we are the destination.
	const MacAddress *mac = dev->get_mac(dev);
	Ipv6Address ll_ip_addr;
	inet6_lladdr_from_mac(&ll_ip_addr, mac);
	Ipv6Address snm_ip_addr;
	inet6_snmaddr_from_mac(&snm_ip_addr, mac);

	if (memcmp(&ll_ip_addr, ip->dst, sizeof(ll_ip_addr)) &&
	    memcmp(&snm_ip_addr, ip->dst, sizeof(snm_ip_addr))) {
		return 1;
	}

	// Stash the sender's info to simplify replies. The Ethernet source
	// MAC sits at frame offset 6.
	memcpy(&inet6_rx_mac_addr, (uint8_t *)_data + 6,
	       sizeof(inet6_rx_mac_addr));
	memcpy(&inet6_rx_ip_addr, ip->src, Ipv6_AddrLen);

	if (ip->next_header == HdrIcmpv6) {
		icmpv6_recv(dev, ip, data, len);
		return 0;
	}

	if (ip->next_header == HdrUdp) {
		_udpv6_recv(ip, data, len);
		return 0;
	}

	printf("error: Unhandled IPv6\n");
	return 1;
}
