// Copyright 2016 The Netstack Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ipv4_test

import (
	"bytes"
	"context"
	"encoding/binary"
	"sync"
	"testing"
	"time"

	"github.com/google/netstack/tcpip"
	"github.com/google/netstack/tcpip/buffer"
	"github.com/google/netstack/tcpip/header"
	"github.com/google/netstack/tcpip/link/channel"
	"github.com/google/netstack/tcpip/link/sniffer"
	"github.com/google/netstack/tcpip/network/ipv4"
	"github.com/google/netstack/tcpip/stack"
	"github.com/google/netstack/waiter"
)

// stackAddr is the IPv4 address 10.0.0.1, in 4-byte network order, that
// newTestContext assigns to NIC 1.
const stackAddr = "\x0A\x00\x00\x01"

// testContext bundles a single-NIC stack with the channel link endpoint that
// backs its only NIC. It is used by the Pinger tests.
type testContext struct {
	// t is the test that owns this context.
	t *testing.T

	// linkEP is the channel endpoint registered as NIC 1; packets the stack
	// sends out are readable from linkEP.C.
	linkEP *channel.Endpoint

	// s is the stack under test.
	s *stack.Stack
}

// newTestContext builds a stack speaking IPv4 and the IPv4 ping protocol
// with a single NIC (ID 1). The NIC is backed by a channel link endpoint,
// wrapped with sniffer.New when tests run with -v. stackAddr is assigned to
// the NIC, and a catch-all route sends every destination out through it.
// Any setup failure aborts the test.
func newTestContext(t *testing.T) *testContext {
	const (
		nicID      = 1
		defaultMTU = 65536
	)

	s := stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName})

	linkID, linkEP := channel.New(256, defaultMTU, "")
	if testing.Verbose() {
		linkID = sniffer.New(linkID)
	}
	if err := s.CreateNIC(nicID, linkID); err != nil {
		t.Fatalf("CreateNIC failed: %v", err)
	}
	if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
		t.Fatalf("AddAddress failed: %v", err)
	}

	// An all-zero destination and mask match every address, so all traffic
	// leaves through nicID.
	defaultRoute := tcpip.Route{
		Destination: "\x00\x00\x00\x00",
		Mask:        "\x00\x00\x00\x00",
		Gateway:     "",
		NIC:         nicID,
	}
	s.SetRouteTable([]tcpip.Route{defaultRoute})

	return &testContext{t: t, s: s, linkEP: linkEP}
}

// cleanup closes the link endpoint's outbound packet channel. Any goroutine
// started by loopback ranges over that channel and can then finish.
func (c *testContext) cleanup() {
	ep := c.linkEP
	close(ep.C)
}

// loopback starts a goroutine that takes every packet the stack writes to
// c.linkEP and injects it back into the same endpoint as inbound traffic.
// Each packet's header and payload are flattened into one freshly allocated
// view. The goroutine returns once c.linkEP.C is closed.
func (c *testContext) loopback() {
	ep := c.linkEP
	go func() {
		for {
			pkt, ok := <-ep.C
			if !ok {
				return
			}
			hdrLen := len(pkt.Header)
			frame := make(buffer.View, hdrLen+len(pkt.Payload))
			copy(frame[:hdrLen], pkt.Header)
			copy(frame[hdrLen:], pkt.Payload)
			vv := frame.ToVectorisedView([1]buffer.View{})
			ep.Inject(pkt.Proto, &vv)
		}
	}()
}

// TestEcho pings the stack's own address once over a looped-back link and
// expects the single reply to carry no error.
func TestEcho(t *testing.T) {
	c := newTestContext(t)
	defer c.cleanup()
	c.loopback()

	replies := make(chan ipv4.PingReply, 1)
	pinger := ipv4.Pinger{
		Stack: c.s,
		NICID: 1,
		Addr:  stackAddr,
		Wait:  10 * time.Millisecond,
		Count: 1, // a single echo request
	}
	if err := pinger.Ping(context.Background(), replies); err != nil {
		t.Fatalf("icmp.Ping failed: %v", err)
	}

	if reply := <-replies; reply.Error != nil {
		t.Errorf("bad ping response: %v", reply.Error)
	}
}

// TestEchoSequence sends numPings echo requests to the stack's own address.
// It reads the replies in order and checks that reply n has no error and
// sequence number n.
func TestEchoSequence(t *testing.T) {
	const numPings = 3

	c := newTestContext(t)
	defer c.cleanup()
	c.loopback()

	replies := make(chan ipv4.PingReply, numPings)
	pinger := ipv4.Pinger{
		Stack: c.s,
		NICID: 1,
		Addr:  stackAddr,
		Wait:  10 * time.Millisecond,
		Count: numPings,
	}
	if err := pinger.Ping(context.Background(), replies); err != nil {
		t.Fatalf("icmp.Ping failed: %v", err)
	}

	for n := 0; n < numPings; n++ {
		want := uint16(n)
		reply := <-replies
		if reply.Error != nil {
			t.Errorf("i=%d bad ping response: %v", want, reply.Error)
		}
		if reply.SeqNumber != want {
			t.Errorf("SeqNumber=%d, want %d", reply.SeqNumber, want)
		}
	}
}

// Addresses for the two-NIC endpoint tests.
const (
	// IPv4 addresses 10.0.0.2 and 10.0.0.3, in 4-byte network order.
	stackAddr0 = "\x0a\x00\x00\x02"
	stackAddr1 = "\x0a\x00\x00\x03"

	// Distinct 6-byte link addresses, one per channel link endpoint.
	linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
	linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
)

// testEndpointContext is a stack with two NICs whose channel link endpoints
// are cross-connected by routePackets goroutines. The ICMPv4 portion of every
// routed packet is reported on icmpCh.
type testEndpointContext struct {
	// t is the test that owns this context; s is the stack under test.
	t *testing.T
	s *stack.Stack

	// linkEP0 backs NIC 1 and linkEP1 backs NIC 2.
	linkEP0 *channel.Endpoint
	linkEP1 *channel.Endpoint

	// icmpCh receives the ICMPv4 bytes of each packet routed between the NICs.
	icmpCh chan header.ICMPv4
}

// cleanup closes the outbound packet channels of linkEP0 and then linkEP1.
// The routePackets loops ranging over those channels can then terminate.
func (c *testEndpointContext) cleanup() {
	for _, ep := range []*channel.Endpoint{c.linkEP0, c.linkEP1} {
		close(ep.C)
	}
}

// newTestEndpointContext builds a stack speaking IPv4 and the IPv4 ping
// protocol, with two NICs:
//   - NIC 1: link address linkAddr0, IPv4 address stackAddr1
//   - NIC 2: link address linkAddr1, IPv4 address stackAddr0
//
// Host routes send stackAddr0 out NIC 1 and stackAddr1 out NIC 2. Two
// goroutines forward every packet written on one link endpoint into the other.
// Any setup failure aborts the test.
func newTestEndpointContext(t *testing.T) *testEndpointContext {
	const defaultMTU = 65536

	c := &testEndpointContext{
		t:      t,
		s:      stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName}),
		icmpCh: make(chan header.ICMPv4, 10),
	}

	// attach creates a channel link endpoint with linkAddr and wraps it with
	// sniffer.New under -v. It registers the endpoint as NIC nicID and returns
	// the channel endpoint.
	attach := func(nicID tcpip.NICID, linkAddr tcpip.LinkAddress) *channel.Endpoint {
		linkID, ep := channel.New(256, defaultMTU, linkAddr)
		if testing.Verbose() {
			linkID = sniffer.New(linkID)
		}
		if err := c.s.CreateNIC(nicID, linkID); err != nil {
			t.Fatalf("CreateNIC s: %v", err)
		}
		return ep
	}
	c.linkEP0 = attach(1, linkAddr0)
	c.linkEP1 = attach(2, linkAddr1)

	if err := c.s.AddAddress(2, ipv4.ProtocolNumber, stackAddr0); err != nil {
		t.Fatalf("AddAddress failed: %v", err)
	}
	if err := c.s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
		t.Fatalf("AddAddress failed: %v", err)
	}

	// hostRoute matches exactly dst (all-ones mask) and sends it out nic.
	hostRoute := func(dst tcpip.Address, nic tcpip.NICID) tcpip.Route {
		return tcpip.Route{
			Destination: dst,
			Mask:        "\xFF\xFF\xFF\xFF",
			Gateway:     "",
			NIC:         nic,
		}
	}
	c.s.SetRouteTable([]tcpip.Route{
		hostRoute(stackAddr0, 1),
		hostRoute(stackAddr1, 2),
	})

	go c.routePackets(c.linkEP0.C, c.linkEP1)
	go c.routePackets(c.linkEP1.C, c.linkEP0)
	return c
}

// countPacket sends the ICMPv4 portion of pkt (everything after the IPv4
// header) to c.icmpCh so the test goroutine can tally routed echo traffic.
//
// It runs on a routePackets goroutine, not the goroutine running the test.
// t.Fatalf must only be called from the test goroutine, so problems are
// reported with Errorf and the offending packet is not counted.
func (c *testEndpointContext) countPacket(pkt channel.PacketInfo) {
	if pkt.Proto != header.IPv4ProtocolNumber {
		c.t.Errorf("Received non IPV4 packet: 0x%x", pkt.Proto)
		return
	}
	if len(pkt.Header) == 0 {
		c.t.Errorf("Received IPv4 packet with an empty header")
		return
	}
	ip := header.IPv4(pkt.Header)
	hdrLen := int(ip.HeaderLength())
	if hdrLen > len(pkt.Header) {
		c.t.Errorf("IPv4 header length %d exceeds the %d header bytes present", hdrLen, len(pkt.Header))
		return
	}

	// Build the ICMP bytes in a freshly allocated slice. Appending directly to
	// pkt.Header[hdrLen:] could write into pkt.Header's backing array, and
	// routePackets also appends to pkt.Header for the view it injects. The
	// message handed to the test goroutine could then share storage with it.
	icmp := make(header.ICMPv4, 0, len(pkt.Header)-hdrLen+len(pkt.Payload))
	icmp = append(icmp, pkt.Header[hdrLen:]...)
	icmp = append(icmp, pkt.Payload...)
	c.icmpCh <- icmp
}

// routePackets drains packets written by one link endpoint (received on ch).
// It reports each one through countPacket, then injects the header and
// payload, joined into one view, into ep via InjectLinkAddr along with
// ep.LinkAddress(). It returns when ch is closed.
func (c *testEndpointContext) routePackets(ch <-chan channel.PacketInfo, ep *channel.Endpoint) {
	for {
		pkt, ok := <-ch
		if !ok {
			return
		}
		c.countPacket(pkt)

		frame := buffer.View(append(pkt.Header, pkt.Payload...))
		vv := buffer.NewVectorisedView(len(frame), []buffer.View{frame})
		ep.InjectLinkAddr(pkt.Proto, ep.LinkAddress(), &vv)
	}
}

// callbackStub wraps a function value so it can be installed as the Callback
// of a waiter.Entry; every notification is forwarded to f.
type callbackStub struct {
	f func(e *waiter.Entry)
}

// Callback forwards the notification for entry e to the wrapped function.
func (c *callbackStub) Callback(e *waiter.Entry) {
	handler := c.f
	handler(e)
}

// TestEndpoints binds one ping endpoint to each NIC of a two-NIC stack. Each
// endpoint sends echos echo requests (sequence numbers 1..echos) across the
// simulated wire and checks every reply it reads back. Meanwhile the test
// goroutine counts the ICMP messages seen by the router. It expects 2*echos
// requests and 2*echos replies in total, within 2 seconds.
func TestEndpoints(t *testing.T) {
	c := newTestEndpointContext(t)
	defer c.cleanup()

	wq0 := &waiter.Queue{}
	ep0, err := c.s.NewEndpoint(ipv4.PingProtocolNumber, ipv4.ProtocolNumber, wq0)
	if err != nil {
		t.Fatalf("NewEndpoint failed: %v", err)
	}
	defer ep0.Close()
	wq1 := &waiter.Queue{}
	ep1, err := c.s.NewEndpoint(ipv4.PingProtocolNumber, ipv4.ProtocolNumber, wq1)
	if err != nil {
		t.Fatalf("NewEndpoint failed: %v", err)
	}
	defer ep1.Close()

	if err := ep0.Bind(tcpip.FullAddress{NIC: 1, Addr: stackAddr0}, nil); err != nil {
		t.Fatalf("Bind failed: %v", err)
	}
	if err := ep1.Bind(tcpip.FullAddress{NIC: 2, Addr: stackAddr1}, nil); err != nil {
		t.Fatalf("Bind failed: %v", err)
	}

	const echos = 64

	// The helpers below run on pinger goroutines, not the test goroutine.
	// t.Fatalf must only be called from the goroutine running the test, so
	// they report failures with t.Errorf and stop.

	// sendAndAwait writes pkt to ep and blocks until wq signals that ep is
	// readable. It returns false after reporting an error if the write fails
	// or is short, or if no notification arrives within a second.
	sendAndAwait := func(wq *waiter.Queue, ep tcpip.Endpoint, pkt []byte) bool {
		// A 1-slot channel with a non-blocking send, rather than close(), so
		// an extra notification cannot panic with a double close.
		ready := make(chan struct{}, 1)
		e := waiter.Entry{Callback: &callbackStub{func(*waiter.Entry) {
			select {
			case ready <- struct{}{}:
			default:
			}
		}}}
		// Register before writing: the notification that a reply arrived may
		// fire immediately.
		wq.EventRegister(&e, waiter.EventIn)
		defer wq.EventUnregister(&e)

		n, err := ep.Write(buffer.View(pkt), nil)
		if err != nil {
			t.Errorf("Write failed: %v", err)
			return false
		}
		if n != uintptr(len(pkt)) {
			t.Errorf("Write was short: %v", n)
			return false
		}

		// Avoid reading until we have something to read.
		select {
		case <-time.After(1 * time.Second):
			t.Errorf("Timed out waiting for socket to be readable")
			return false
		case <-ready:
			return true
		}
	}

	// ping sends echos echo requests carrying data through ep and verifies
	// each reply's type, sequence number and payload.
	ping := func(wq *waiter.Queue, ep tcpip.Endpoint, data []byte) {
		// Layout: ICMPv4 header, 2-byte identifier (always 0), 2-byte sequence
		// number, then data.
		outPkt := make([]byte, header.ICMPv4MinimumSize+4+len(data))
		icmpv4 := header.ICMPv4(outPkt[:header.ICMPv4MinimumSize])
		icmpv4.SetType(header.ICMPv4Echo)
		copy(outPkt[header.ICMPv4MinimumSize+4:], data)
		binary.BigEndian.PutUint16(outPkt[header.ICMPv4MinimumSize:], 0)

		for seqno := uint16(1); seqno <= echos; seqno++ {
			binary.BigEndian.PutUint16(outPkt[header.ICMPv4MinimumSize+2:], seqno)

			if !sendAndAwait(wq, ep, outPkt) {
				return
			}
			inPkt, err := ep.Read(nil)
			if err != nil {
				t.Errorf("Read failed: %v", err)
				return
			}
			// Guard the fixed-offset slicing below so a truncated reply is
			// reported instead of panicking the whole test binary.
			if len(inPkt) < header.ICMPv4MinimumSize+4 || len(inPkt) < header.ICMPv4EchoMinimumSize {
				t.Errorf("Read packet too short: %d bytes", len(inPkt))
				return
			}

			// Verify the contents of the packet we just read.
			icmp := header.ICMPv4(inPkt)
			if icmp.Type() != header.ICMPv4EchoReply {
				t.Errorf("Unexpected packet type: %d", icmp.Type())
				return
			}
			inSeqno := binary.BigEndian.Uint16(inPkt[header.ICMPv4MinimumSize+2 : header.ICMPv4MinimumSize+4])
			if inSeqno != seqno {
				t.Errorf("Unexpected sequence number: %d", inSeqno)
				return
			}
			outData := outPkt[header.ICMPv4EchoMinimumSize:]
			inData := inPkt[header.ICMPv4EchoMinimumSize:]
			if len(outData) != len(inData) {
				t.Errorf("Read packet of unexpected length: %d", len(inData))
				return
			}
			if !bytes.Equal(inData, outData) {
				t.Errorf("Data mismatch")
				return
			}
		}
	}

	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		ping(wq0, ep0, []byte{0xaa, 0xab, 0xac})
	}()
	go func() {
		defer wg.Done()
		ping(wq1, ep1, []byte{0xad, 0xae, 0xaf})
	}()
	// The loop below can finish as soon as the router has counted the final
	// reply, before that reply is delivered and read. Deferred after the
	// Close/cleanup defers, this Wait runs before them. The pingers are then
	// done with the endpoints and t before those are torn down, and before
	// the test completes.
	defer wg.Wait()

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	stats := make(map[header.ICMPv4Type]int, 2)
	for {
		select {
		case <-ctx.Done():
			t.Errorf("Timeout waiting for ICMP, got: %#+v", stats)
			return
		case icmp := <-c.icmpCh:
			typ := icmp.Type()
			if typ != header.ICMPv4Echo && typ != header.ICMPv4EchoReply {
				t.Fatalf("Unexpected type: %d", typ)
			}
			stats[typ]++
			if stats[typ] > echos*2 {
				t.Fatalf("Too many (%d) packets of type %d", stats[typ], typ)
			}
			if stats[header.ICMPv4Echo] == echos*2 && stats[header.ICMPv4EchoReply] == echos*2 {
				return
			}
		}
	}
}
