blob: deca247888bcd615bd0e123c84c1a1b5bffca940 [file] [log] [blame]
package filter
import (
"testing"
"time"
"netstack/util"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/link/channel"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/stack"
"github.com/google/netstack/tcpip/transport/tcp"
"github.com/google/netstack/tcpip/transport/udp"
"github.com/google/netstack/waiter"
)
// TODO: make these tests table-driven.
// Shared fixtures describing the three-party topology used by the NAT tests:
// a LAN host, a two-NIC router, and a WAN host.
var (
// LAN and WAN network addresses with their masks (a /24 and a /8 mask
// respectively). util.Parse presumably converts dotted-quad text into a
// tcpip.Address; confirm against netstack/util.
testLanNet = util.Parse("192.168.42.0")
testLanNetMask tcpip.AddressMask = "\xff\xff\xff\x00"
testWanNet = util.Parse("10.0.0.0")
testWanNetMask tcpip.AddressMask = "\xff\x00\x00\x00"
// Addresses assigned to each NIC. The router has one address on the LAN
// network (NIC 1) and one on the WAN network (NIC 2).
testLanNICAddr = util.Parse("192.168.42.10")
testRouterNICAddr1 = util.Parse("192.168.42.1")
testRouterNICAddr2 = util.Parse("10.0.0.1")
testWanNICAddr = util.Parse("10.0.0.2")
// Transport-layer ports used when binding endpoints in the tests.
testLanPort = uint16(10000)
testRouterPort = uint16(8080)
testWanPort = uint16(80)
// NIC identifiers; converted to tcpip.NICID where the NICs are created.
testLanNICID = 1
testRouterNICID1 = 2
testRouterNICID2 = 3
testWanNICID = 4
// Link-layer (6-byte) addresses, one per NIC.
testLanLinkAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x01")
testRouterLinkAddress1 = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x02")
testRouterLinkAddress2 = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x03")
testWanLinkAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x04")
)
// createTestStackLan builds the stack that plays the LAN host.
//
// The stack is created with IPv4, UDP and TCP, and gets a single
// channel-backed NIC (testLanNICID) that carries testLanNICAddr. Its route
// table sends traffic for both the LAN and the WAN network out of that NIC.
//
// It returns the stack together with the channel endpoint backing the NIC so
// the caller can wire it to another endpoint. Any setup failure aborts the
// test via t.Fatalf.
func createTestStackLan(t *testing.T) (*stack.Stack, *channel.Endpoint) {
	t.Helper()

	s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})

	id, linkEP := channel.New(1, 100, testLanLinkAddress)
	nic := tcpip.NICID(testLanNICID)
	if err := s.CreateDisabledNIC(nic, id); err != nil {
		t.Fatalf("CreateDisabledNIC error: %s", err)
	}
	// A failure to enable or address the NIC would otherwise only surface
	// later as an unexplained timeout, so fail fast here.
	if err := s.EnableNIC(nic); err != nil {
		t.Fatalf("EnableNIC error: %s", err)
	}
	if err := s.AddAddress(nic, header.IPv4ProtocolNumber, testLanNICAddr); err != nil {
		t.Fatalf("AddAddress error: %s", err)
	}

	s.SetRouteTable([]tcpip.Route{
		{
			Destination: testLanNet,
			Mask:        testLanNetMask,
			NIC:         nic,
		},
		{
			Destination: testWanNet,
			Mask:        testWanNetMask,
			NIC:         nic,
		},
	})
	return s, linkEP
}
// createTestStackWan builds the stack that plays the WAN host.
//
// The stack is created with IPv4, UDP and TCP, and gets a single
// channel-backed NIC (testWanNICID) that carries testWanNICAddr. Its route
// table only knows the WAN network, reachable through that NIC.
//
// It returns the stack together with the channel endpoint backing the NIC so
// the caller can wire it to another endpoint. Any setup failure aborts the
// test via t.Fatalf.
func createTestStackWan(t *testing.T) (*stack.Stack, *channel.Endpoint) {
	t.Helper()

	s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})

	id, linkEP := channel.New(1, 100, testWanLinkAddress)
	nic := tcpip.NICID(testWanNICID)
	if err := s.CreateDisabledNIC(nic, id); err != nil {
		t.Fatalf("CreateDisabledNIC error: %s", err)
	}
	// A failure to enable or address the NIC would otherwise only surface
	// later as an unexplained timeout, so fail fast here.
	if err := s.EnableNIC(nic); err != nil {
		t.Fatalf("EnableNIC error: %s", err)
	}
	if err := s.AddAddress(nic, header.IPv4ProtocolNumber, testWanNICAddr); err != nil {
		t.Fatalf("AddAddress error: %s", err)
	}

	s.SetRouteTable([]tcpip.Route{
		{
			Destination: testWanNet,
			Mask:        testWanNetMask,
			NIC:         nic,
		},
	})
	return s, linkEP
}
// createTestStackRouterNAT builds the stack that plays the NAT router.
//
// A filter is created from the stack's port manager and loaded with two NAT
// rules, one for UDP and one for TCP, each with srcSubnet set to the LAN
// subnet and newSrcAddr set to testRouterNICAddr2. Both router NICs are
// created from link endpoints wrapped by NewEndpoint(f, ...), which
// presumably routes their traffic through the filter.
//
// NIC 1 (testRouterNICID1) carries testRouterNICAddr1 and serves the LAN
// route; NIC 2 (testRouterNICID2) carries testRouterNICAddr2 and serves the
// WAN route. Forwarding is enabled on the stack.
//
// It returns the stack and the channel endpoints backing NIC 1 and NIC 2, in
// that order. Any setup failure aborts the test via t.Fatalf.
func createTestStackRouterNAT(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) {
	t.Helper()

	s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})

	f := New(s.PortManager)
	srcSubnet, terr := tcpip.NewSubnet(testLanNet, testLanNetMask)
	if terr != nil {
		t.Fatalf("NewSubnet error: %s", terr)
	}
	// The NAT ruleset is guarded by its own lock; replace it atomically.
	f.rulesetNAT.Lock()
	f.rulesetNAT.v = []NAT{
		{
			transProto: header.UDPProtocolNumber,
			srcSubnet:  &srcSubnet,
			newSrcAddr: testRouterNICAddr2,
		},
		{
			transProto: header.TCPProtocolNumber,
			srcSubnet:  &srcSubnet,
			newSrcAddr: testRouterNICAddr2,
		},
	}
	f.rulesetNAT.Unlock()

	// LAN-facing NIC.
	id1, linkEP1 := channel.New(1, 100, testRouterLinkAddress1)
	nic1 := tcpip.NICID(testRouterNICID1)
	if err := s.CreateDisabledNIC(nic1, NewEndpoint(f, id1)); err != nil {
		t.Fatalf("CreateDisabledNIC error: %s", err)
	}
	if err := s.EnableNIC(nic1); err != nil {
		t.Fatalf("EnableNIC error: %s", err)
	}
	if err := s.AddAddress(nic1, header.IPv4ProtocolNumber, testRouterNICAddr1); err != nil {
		t.Fatalf("AddAddress error: %s", err)
	}

	// WAN-facing NIC.
	id2, linkEP2 := channel.New(1, 100, testRouterLinkAddress2)
	nic2 := tcpip.NICID(testRouterNICID2)
	if err := s.CreateDisabledNIC(nic2, NewEndpoint(f, id2)); err != nil {
		t.Fatalf("CreateDisabledNIC error: %s", err)
	}
	if err := s.EnableNIC(nic2); err != nil {
		t.Fatalf("EnableNIC error: %s", err)
	}
	if err := s.AddAddress(nic2, header.IPv4ProtocolNumber, testRouterNICAddr2); err != nil {
		t.Fatalf("AddAddress error: %s", err)
	}

	s.SetRouteTable([]tcpip.Route{
		{
			Destination: testLanNet,
			Mask:        testLanNetMask,
			NIC:         nic1,
		},
		{
			Destination: testWanNet,
			Mask:        testWanNetMask,
			NIC:         nic2,
		},
	})
	s.SetForwarding(true)
	return s, linkEP1, linkEP2
}
// TestNATOneWayLanToWanUDP sends one UDP datagram from the LAN host to the WAN
// host through the NAT router. It expects the WAN host to receive the payload
// unchanged, with testRouterNICAddr2 as the sender address.
func TestNATOneWayLanToWanUDP(t *testing.T) {
	lanStack, lanLink := createTestStackLan(t)
	wanStack, wanLink := createTestStackWan(t)
	_, routerLanLink, routerWanLink := createTestStackRouterNAT(t)

	// Only the LAN -> router -> WAN direction is wired up for this test.
	go link(lanLink, routerLanLink)
	go link(routerWanLink, wanLink)

	var lanQueue waiter.Queue
	lanEP, err := lanStack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &lanQueue)
	if err != nil {
		t.Fatalf("NewEndpoint error: %s", err)
	}

	var wanQueue waiter.Queue
	wanEP, err := wanStack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wanQueue)
	if err != nil {
		t.Fatalf("NewEndpoint error: %s", err)
	}

	wanAddr := tcpip.FullAddress{Addr: testWanNICAddr, Port: testWanPort}
	if err := wanEP.Bind(wanAddr, nil); err != nil {
		t.Fatalf("Bind error: %s", err)
	}

	// Register for readability before sending so the notification cannot be
	// missed.
	readable, readableCh := waiter.NewChannelEntry(nil)
	wanQueue.EventRegister(&readable, waiter.EventIn)

	if _, _, err := lanEP.Write(tcpip.SlicePayload("hello"), tcpip.WriteOptions{To: &wanAddr}); err != nil {
		t.Fatalf("Write error: %s", err)
	}

	select {
	case <-readableCh:
	case <-time.After(time.Second):
		t.Fatalf("Read timeout")
	}
	wanQueue.EventUnregister(&readable)

	var from tcpip.FullAddress
	payload, _, err := wanEP.Read(&from)
	if err != nil {
		t.Fatalf("Read error: %s", err)
	}

	// The sender port is random, so only the address is checked.
	if from.Addr != testRouterNICAddr2 {
		t.Errorf("sender.Addr %s, want %s", from.Addr, testRouterNICAddr2)
	}
	if text := string(payload); text != "hello" {
		t.Errorf("got %s, want %s", text, "hello")
	}
}
// TestNATRoundtripLanToWanUDP checks a full UDP exchange through the NAT
// router.
//
// The LAN host sends "hello" to the WAN host, which must see
// testRouterNICAddr2 as the sender address. The WAN host then replies "hi" to
// that sender, and the LAN host must receive the reply with the WAN host's
// address and port as the sender.
func TestNATRoundtripLanToWanUDP(t *testing.T) {
	sLan, sLanLinkEP := createTestStackLan(t)
	sWan, sWanLinkEP := createTestStackWan(t)
	_, sRouterLinkEP1, sRouterLinkEP2 := createTestStackRouterNAT(t)

	// Wire both directions: LAN <-> router NIC 1 and router NIC 2 <-> WAN.
	go link(sLanLinkEP, sRouterLinkEP1)
	go link(sRouterLinkEP2, sWanLinkEP)
	go link(sWanLinkEP, sRouterLinkEP2)
	go link(sRouterLinkEP1, sLanLinkEP)

	var wqLan waiter.Queue
	epLanUDP, err := sLan.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wqLan)
	if err != nil {
		t.Fatalf("NewEndpoint error: %s", err)
	}
	var wqWan waiter.Queue
	epWanUDP, err := sWan.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wqWan)
	if err != nil {
		t.Fatalf("NewEndpoint error: %s", err)
	}

	receiverLan := tcpip.FullAddress{Addr: testLanNICAddr, Port: testLanPort}
	receiverWan := tcpip.FullAddress{Addr: testWanNICAddr, Port: testWanPort}
	if err := epLanUDP.Bind(receiverLan, nil); err != nil {
		t.Fatalf("Bind error: %s", err)
	}
	if err := epWanUDP.Bind(receiverWan, nil); err != nil {
		t.Fatalf("Bind error: %s", err)
	}

	// Outbound leg: LAN -> WAN.
	waitEntryWan, chWan := waiter.NewChannelEntry(nil)
	wqWan.EventRegister(&waitEntryWan, waiter.EventIn)
	if _, _, err := epLanUDP.Write(tcpip.SlicePayload("hello"), tcpip.WriteOptions{To: &receiverWan}); err != nil {
		t.Fatalf("Write error: %s", err)
	}
	select {
	case <-chWan:
	case <-time.After(1 * time.Second):
		t.Fatalf("Read timeout")
	}
	wqWan.EventUnregister(&waitEntryWan)

	var sender tcpip.FullAddress
	recvd, _, err := epWanUDP.Read(&sender)
	if err != nil {
		t.Fatalf("Read error: %s", err)
	}
	if got, want := sender.Addr, testRouterNICAddr2; got != want {
		t.Errorf("sender.Addr %s, want %s", got, want)
	}
	// sender.Port is random.
	if got, want := string(recvd), "hello"; got != want {
		t.Errorf("got %s, want %s", got, want)
	}

	// Return leg: WAN replies to the address and port it saw as the sender.
	waitEntryLan, chLan := waiter.NewChannelEntry(nil)
	wqLan.EventRegister(&waitEntryLan, waiter.EventIn)
	if _, _, err := epWanUDP.Write(tcpip.SlicePayload("hi"), tcpip.WriteOptions{To: &sender}); err != nil {
		t.Fatalf("Write error: %s", err)
	}
	select {
	case <-chLan:
	case <-time.After(1 * time.Second):
		t.Fatalf("Read timeout")
	}
	wqLan.EventUnregister(&waitEntryLan)

	var sender2 tcpip.FullAddress
	recvd2, _, err := epLanUDP.Read(&sender2)
	if err != nil {
		t.Fatalf("Read error: %s", err)
	}
	if got, want := sender2.Addr, receiverWan.Addr; got != want {
		t.Errorf("sender2.Addr %s, want %s", got, want)
	}
	if got, want := sender2.Port, receiverWan.Port; got != want {
		t.Errorf("sender2.Port %d, want %d", got, want)
	}
	if got, want := string(recvd2), "hi"; got != want {
		t.Errorf("got %s, want %s", got, want)
	}
}
// TestNATLanToWanTCP checks a TCP connection established from the LAN host to
// the WAN host through the NAT router.
//
// The WAN host listens on receiverWan and the LAN host connects to it. The
// accepted connection must report testRouterNICAddr2 as its remote address.
// The test then sends "hello" LAN -> WAN and "hi" WAN -> LAN over the
// connection and checks both payloads.
func TestNATLanToWanTCP(t *testing.T) {
	sLan, sLanLinkEP := createTestStackLan(t)
	sWan, sWanLinkEP := createTestStackWan(t)
	_, sRouterLinkEP1, sRouterLinkEP2 := createTestStackRouterNAT(t)

	// Wire both directions: LAN <-> router NIC 1 and router NIC 2 <-> WAN.
	go link(sLanLinkEP, sRouterLinkEP1)
	go link(sRouterLinkEP2, sWanLinkEP)
	go link(sWanLinkEP, sRouterLinkEP2)
	go link(sRouterLinkEP1, sLanLinkEP)

	var wqLan waiter.Queue
	epLanTCP, err := sLan.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wqLan)
	if err != nil {
		t.Fatalf("NewEndpoint error: %s", err)
	}
	var wqWanMaster waiter.Queue
	epWanTCPMaster, err := sWan.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wqWanMaster)
	if err != nil {
		t.Fatalf("NewEndpoint error: %s", err)
	}

	receiverWan := tcpip.FullAddress{Addr: testWanNICAddr, Port: testWanPort}
	if err := epWanTCPMaster.Bind(receiverWan, nil); err != nil {
		t.Fatalf("Bind error: %s", err)
	}
	if err := epWanTCPMaster.Listen(10); err != nil {
		t.Fatalf("Listen error: %s", err)
	}

	// Register both waiters before connecting so neither notification can be
	// missed.
	waitEntryLan, chLan := waiter.NewChannelEntry(nil)
	wqLan.EventRegister(&waitEntryLan, waiter.EventOut)
	waitEntryWanMaster, chWanMaster := waiter.NewChannelEntry(nil)
	wqWanMaster.EventRegister(&waitEntryWanMaster, waiter.EventIn)

	// ErrConnectStarted is not treated as a failure; completion is awaited on
	// chLan below.
	if err := epLanTCP.Connect(receiverWan); err != nil && err != tcpip.ErrConnectStarted {
		t.Fatalf("Connect error: %s", err)
	}

	select {
	case <-chWanMaster:
	case <-time.After(1 * time.Second):
		t.Fatalf("Accept timeout")
	}
	epWanTCP, wqWan, err := epWanTCPMaster.Accept()
	if err != nil {
		t.Fatalf("Accept error: %s", err)
	}
	wqWanMaster.EventUnregister(&waitEntryWanMaster)

	select {
	case <-chLan:
	case <-time.After(1 * time.Second):
		t.Fatalf("Connect timeout")
	}
	wqLan.EventUnregister(&waitEntryLan)

	sender, err := epWanTCP.GetRemoteAddress()
	if err != nil {
		t.Fatalf("GetRemoteAddress error: %s", err)
	}
	if got, want := sender.Addr, testRouterNICAddr2; got != want {
		t.Errorf("sender.Addr %s, want %s", got, want)
	}
	// sender.Port is random.

	// LAN -> WAN payload.
	waitEntryWan, chWan := waiter.NewChannelEntry(nil)
	wqWan.EventRegister(&waitEntryWan, waiter.EventIn)
	if _, _, err := epLanTCP.Write(tcpip.SlicePayload("hello"), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Write error: %s", err)
	}
	select {
	case <-chWan:
	case <-time.After(1 * time.Second):
		t.Fatalf("Read timeout")
	}
	wqWan.EventUnregister(&waitEntryWan)

	recvd, _, err := epWanTCP.Read(nil)
	if err != nil {
		t.Fatalf("Read error: %s", err)
	}
	if got, want := string(recvd), "hello"; got != want {
		t.Errorf("got %s, want %s", got, want)
	}

	// WAN -> LAN payload. A fresh wait entry is used rather than reusing
	// chLan: a writability notification still pending on the old channel
	// (presumably buffered; see waiter.NewChannelEntry) could otherwise be
	// mistaken for readability and make the Read below run too early.
	waitEntryLanIn, chLanIn := waiter.NewChannelEntry(nil)
	wqLan.EventRegister(&waitEntryLanIn, waiter.EventIn)
	if _, _, err := epWanTCP.Write(tcpip.SlicePayload("hi"), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Write error: %s", err)
	}
	select {
	case <-chLanIn:
	case <-time.After(1 * time.Second):
		t.Fatalf("Read timeout")
	}
	wqLan.EventUnregister(&waitEntryLanIn)

	recvd2, _, err := epLanTCP.Read(nil)
	if err != nil {
		t.Fatalf("Read error: %s", err)
	}
	if got, want := string(recvd2), "hi"; got != want {
		t.Errorf("got %s, want %s", got, want)
	}
}
// link pumps packets one way, from a to b: every packet received on a.C is
// converted with unpacketInfo and injected into b. It returns once a.C is
// closed, so callers run it in its own goroutine.
func link(a, b *channel.Endpoint) {
	for {
		pkt, ok := <-a.C
		if !ok {
			return
		}
		proto, vv := unpacketInfo(pkt)
		b.Inject(proto, vv)
	}
}
// unpacketInfo converts a channel.PacketInfo into the arguments link passes
// to Inject: the packet's protocol number and a vectorised view holding the
// header (when non-nil) followed by the payload. The view's size is the sum
// of the lengths of the included parts.
func unpacketInfo(p channel.PacketInfo) (tcpip.NetworkProtocolNumber, buffer.VectorisedView) {
	var parts []buffer.View
	total := 0

	// The header is optional; the payload is always included, even if empty.
	if hdr := p.Header; hdr != nil {
		parts = append(parts, hdr)
		total += len(hdr)
	}
	parts = append(parts, p.Payload)
	total += len(p.Payload)

	return p.Proto, buffer.NewVectorisedView(total, parts)
}