package stack

import (
	"testing"
	"time"

	"github.com/google/netstack/tcpip"
	"github.com/google/netstack/tcpip/buffer"
	"github.com/google/netstack/tcpip/header"
	"github.com/google/netstack/tcpip/link/bufwritingchannel"
	"github.com/google/netstack/tcpip/link/channel"
	"github.com/google/netstack/tcpip/network/ipv4"
	"github.com/google/netstack/tcpip/stack"
	"github.com/google/netstack/tcpip/transport/tcp"
	"github.com/google/netstack/tcpip/transport/udp"
	"github.com/google/netstack/waiter"
)

// TestOneWayBridgeSeparateStacks sends a single UDP datagram from stack s1 to
// stack s2 through a third stack, sb, whose two NICs are bridged together.
// Packets are only forwarded s1 -> sb -> s2; no return path is linked, which
// is sufficient for connectionless UDP.
func TestOneWayBridgeSeparateStacks(t *testing.T) {
	// Sender stack: one NIC at 192.168.42.10 with a host route to 10.0.0.1.
	s1, s1eps, s1nics, err := newStack([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, 1)
	if err != nil {
		t.Fatalf("newStack error: %s", err)
	}
	s1EP := s1eps[0]
	if err := s1.EnableNIC(s1nics[0]); err != nil {
		t.Fatalf("s1.EnableNIC error: %s", err)
	}
	if err := s1.AddAddress(s1nics[0], header.IPv4ProtocolNumber, tcpip.Parse("192.168.42.10")); err != nil {
		t.Fatalf("s1.AddAddress error: %s", err)
	}
	s1.SetRouteTable([]tcpip.Route{
		{
			Destination: tcpip.Parse("10.0.0.1"),
			Mask:        tcpip.Parse("255.255.255.255"),
			NIC:         s1nics[0],
		},
	})
	var wq1 waiter.Queue
	txEP1, err := s1.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq1)
	if err != nil {
		t.Fatalf("could not create endpoint in s1: %s", err)
	}

	// Receiver stack: one NIC at 10.0.0.1 with a UDP endpoint bound to :8080.
	s2, s2eps, s2nics, err := newStack([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, 1)
	if err != nil {
		t.Fatalf("newStack error: %s", err)
	}
	s2EP := s2eps[0]
	if err := s2.EnableNIC(s2nics[0]); err != nil {
		t.Fatalf("s2.EnableNIC error: %s", err)
	}
	if err := s2.AddAddress(s2nics[0], header.IPv4ProtocolNumber, tcpip.Parse("10.0.0.1")); err != nil {
		t.Fatalf("s2.AddAddress error: %s", err)
	}
	var wq2 waiter.Queue
	txEP2, err := s2.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq2)
	if err != nil {
		t.Fatalf("could not create endpoint in s2: %s", err)
	}
	err = txEP2.Bind(tcpip.FullAddress{Addr: tcpip.Parse("10.0.0.1"), Port: 8080}, nil)
	if err != nil {
		t.Fatalf("error in bind: %s", err)
	}

	// Bridging stack: two NICs joined into a single bridge.
	sb, eps, nicIDs, err := newStack([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, 2)
	if err != nil {
		t.Fatalf("newStack error: %s", err)
	}
	bEP1 := eps[0]
	bEP2 := eps[1]

	b, tcpipErr := sb.Bridge(nicIDs)
	if tcpipErr != nil {
		// Report the bridge error itself, not the (nil) err from newStack.
		t.Fatalf("failed during bridge setup: %s", tcpipErr)
	}
	b.Enable()

	// bEP1 and bEP2 don't need to be linked, since they're bridged.
	go link(s1EP, bEP1)
	go link(bEP2, s2EP)

	addr := tcpip.FullAddress{Addr: tcpip.Parse("10.0.0.1"), Port: 8080}
	_, err = txEP1.Write(buffer.View("hello"), &addr)
	if err != nil {
		t.Fatalf("failed to write: %s \n%+v", err, s1.Stats())
	}

	// TODO(stijlist): use waitqueue from txEP2 instead of sleeping
	time.Sleep(10 * time.Millisecond)

	recvd, err := txEP2.Read(&addr)
	if err != nil {
		t.Fatalf("failed to read: %s\n%+v", err, s2.Stats())
	}

	payload := string(recvd)
	if payload != "hello" {
		t.Errorf("want hello, got %s", payload)
	}
}

// TestTwoWayBridgeSeparateStacks opens a TCP connection from stack s1 to stack
// s2 through a third stack, sb, whose two NICs are bridged, then sends "hello"
// over the connection. Each side is linked to the bridge in both directions
// because the TCP handshake requires s2's replies to reach s1.
func TestTwoWayBridgeSeparateStacks(t *testing.T) {
	// Client stack: one NIC at 192.168.42.10 with a host route to 10.0.0.1.
	s1, s1eps, s1nics, err := newStack([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, 1)
	if err != nil {
		t.Fatalf("newStack error: %s", err)
	}
	s1EP := s1eps[0]
	if err := s1.EnableNIC(s1nics[0]); err != nil {
		t.Fatalf("s1.EnableNIC error: %s", err)
	}
	if err := s1.AddAddress(s1nics[0], header.IPv4ProtocolNumber, tcpip.Parse("192.168.42.10")); err != nil {
		t.Fatalf("s1.AddAddress error: %s", err)
	}
	s1.SetRouteTable([]tcpip.Route{
		{
			Destination: tcpip.Parse("10.0.0.1"),
			Mask:        tcpip.Parse("255.255.255.255"),
			NIC:         s1nics[0],
		},
	})

	var wq1 waiter.Queue
	txEP1, err := s1.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq1)
	if err != nil {
		t.Fatalf("could not create endpoint in s1: %s", err)
	}

	// Server stack: one NIC at 10.0.0.1 with a host route back to
	// 192.168.42.10, and a TCP endpoint listening on :8080.
	s2, s2eps, s2nics, err := newStack([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, 1)
	if err != nil {
		t.Fatalf("newStack error: %s", err)
	}
	s2EP := s2eps[0]
	if err := s2.EnableNIC(s2nics[0]); err != nil {
		t.Fatalf("s2.EnableNIC error: %s", err)
	}
	if err := s2.AddAddress(s2nics[0], header.IPv4ProtocolNumber, tcpip.Parse("10.0.0.1")); err != nil {
		t.Fatalf("s2.AddAddress error: %s", err)
	}
	s2.SetRouteTable([]tcpip.Route{
		{
			Destination: tcpip.Parse("192.168.42.10"),
			Mask:        tcpip.Parse("255.255.255.255"),
			NIC:         s2nics[0],
		},
	})

	var wq2 waiter.Queue
	txEP2, err := s2.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq2)
	if err != nil {
		t.Fatalf("could not create endpoint in s2: %s", err)
	}

	err = txEP2.Bind(tcpip.FullAddress{Addr: tcpip.Parse("10.0.0.1"), Port: 8080}, nil)
	if err != nil {
		t.Fatalf("error in bind: %s", err)
	}

	err = txEP2.Listen(1)
	if err != nil {
		t.Fatalf("error in listen: %s", err)
	}

	// Bridging stack: two NICs joined into a single bridge.
	sb, eps, nicIDs, err := newStack([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, 2)
	if err != nil {
		t.Fatalf("newStack error: %s", err)
	}
	bEP1 := eps[0]
	bEP2 := eps[1]

	b, tcpipErr := sb.Bridge(nicIDs)
	if tcpipErr != nil {
		// Report the bridge error itself, not the (nil) err from newStack.
		t.Fatalf("failed during bridge setup: %s", tcpipErr)
	}
	b.Enable()

	go link(s1EP, bEP1)
	go link(bEP1, s1EP)

	go link(s2EP, bEP2)
	go link(bEP2, s2EP)

	addr := tcpip.FullAddress{Addr: tcpip.Parse("10.0.0.1"), Port: 8080}
	err = txEP1.Connect(addr)
	if err != tcpip.ErrConnectStarted {
		t.Fatalf("failed to connect: %s \n%+v", err, s1.Stats())
	}

	// TODO(stijlist): use waitqueue returned from accept instead of sleeping
	time.Sleep(10 * time.Millisecond)

	readingEP, _, err := txEP2.Accept()
	if err != nil {
		// Without this check a not-yet-established connection leaves
		// readingEP nil and the Read below panics.
		t.Fatalf("failed to accept: %s\ns2 stats: %+v", err, s2.Stats())
	}

	_, err = txEP1.Write(buffer.View("hello"), &addr)
	if err != nil {
		t.Fatalf("failed to write: %s \n%+v", err, s1.Stats())
	}

	time.Sleep(10 * time.Millisecond)

	recvd, err := readingEP.Read(&addr)
	if err != nil {
		t.Fatalf("failed to read: %s\ns1 stats: %+v\n\ns2 stats: %+v", err, s1.Stats(), s2.Stats())
	}

	payload := string(recvd)
	if payload != "hello" {
		t.Errorf("want hello, got %s", payload)
	}
}

// link forwards traffic one way: every packet taken from a.C is converted
// with unpacketInfo and injected into b. It returns only once a.C is
// exhausted; the tests run it in its own goroutine and never stop it.
func link(a, b *bufwritingchannel.Endpoint) {
	for pkt := range a.C {
		proto, vv := unpacketInfo(pkt)
		b.Inject(proto, vv)
	}
}

// ni is the next NIC ID handed out by newStack. It is shared by every stack
// this file creates, so NIC IDs never repeat across stacks.
var ni uint

// newStack creates a stack with the given network and transport protocols and
// attaches numEndpoints disabled NICs to it. Each NIC is backed by a new
// bufwritingchannel endpoint with a link address from newLinkAddress. Its NIC
// ID comes from the package-level counter ni.
//
// If CreateDisabledNIC fails, newStack returns that error at once. In that
// case eps and nicIDs hold only the NICs created before the failure.
func newStack(netProtos []string, transProtos []string, numEndpoints int) (s *stack.Stack, eps []*bufwritingchannel.Endpoint, nicIDs []tcpip.NICID, err *tcpip.Error) {
	s = stack.New(netProtos, transProtos)
	for n := 0; n < numEndpoints; n++ {
		linkID, linkEP := bufwritingchannel.New(1, 100, newLinkAddress())
		nic := tcpip.NICID(ni)
		ni++
		if err = s.CreateDisabledNIC(nic, linkID); err != nil {
			return s, eps, nicIDs, err
		}
		eps = append(eps, linkEP)
		nicIDs = append(nicIDs, nic)
	}
	return s, eps, nicIDs, nil
}

// li is the byte used to build the next link address in newLinkAddress. It
// is a byte, so it wraps after 256 addresses.
var li byte

// newLinkAddress returns a 6-byte (Ethernet MAC-sized) link address with
// every byte set to the current value of li, then advances li. Successive
// calls return distinct addresses until li wraps after 256 calls.
func newLinkAddress() tcpip.LinkAddress {
	// Previously seven bytes were used, which is not a valid MAC length.
	l := tcpip.LinkAddress([]byte{li, li, li, li, li, li})
	li++
	return l
}

// unpacketInfo converts a captured outbound packet into the arguments that
// Inject takes: the packet's network protocol and a VectorisedView over its
// bytes. The header view is included first when present, followed by the
// payload. The view's declared size is the sum of the included views' lengths.
func unpacketInfo(p channel.PacketInfo) (tcpip.NetworkProtocolNumber, *buffer.VectorisedView) {
	proto := p.Proto
	if p.Header == nil {
		payloadOnly := buffer.NewVectorisedView(len(p.Payload), []buffer.View{p.Payload})
		return proto, &payloadOnly
	}
	full := buffer.NewVectorisedView(len(p.Header)+len(p.Payload), []buffer.View{p.Header, p.Payload})
	return proto, &full
}
