blob: 019bf384326b1aa6ffde7148a269b46722708124 [file] [log] [blame]
package bridge_test
import (
"bytes"
"fmt"
"testing"
"time"
"netstack/link/bridge"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/buffer"
"github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/link/channel"
"github.com/google/netstack/tcpip/network/arp"
"github.com/google/netstack/tcpip/network/ipv4"
"github.com/google/netstack/tcpip/stack"
"github.com/google/netstack/tcpip/transport/tcp"
"github.com/google/netstack/waiter"
)
// Sentinel errors returned by connect and write when an awaited waiter event
// does not arrive before the one-second deadline. Callers compare against
// them directly.
var (
	// timeoutReceiveReady: the listening endpoint never became readable.
	timeoutReceiveReady = fmt.Errorf("receiveready")
	// timeoutSendReady: the connecting endpoint never became writable.
	timeoutSendReady = fmt.Errorf("sendready")
	// timeoutPayloadReceived: the written payload never became readable.
	timeoutPayloadReceived = fmt.Errorf("payloadreceived")
)
// TestEndpointAttributes checks how bridge.CombineCapabilities merges link
// capabilities, and that a bridge over two channel links reports the smaller
// of their MTUs and a locally administered link address.
func TestEndpointAttributes(t *testing.T) {
	// A link that requires resolution keeps that capability when combined
	// with one that does not.
	resolutionRequired := stack.LinkEndpointCapabilities(stack.CapabilityResolutionRequired)
	var resolutionNotRequired stack.LinkEndpointCapabilities
	if got := bridge.CombineCapabilities(resolutionRequired, resolutionNotRequired); got != resolutionRequired {
		t.Errorf("got bridge.CombineCapabilities(%#v, %#v) == %#v, want == %#v", resolutionRequired, resolutionNotRequired, got, resolutionRequired)
	}

	// Checksum offload and loopback are dropped when combined with a link
	// that has neither.
	checksumAndLoopback := stack.LinkEndpointCapabilities(stack.CapabilityChecksumOffload | stack.CapabilityLoopback)
	var noChecksumAndLoopback stack.LinkEndpointCapabilities
	if got := bridge.CombineCapabilities(checksumAndLoopback, noChecksumAndLoopback); got != noChecksumAndLoopback {
		t.Errorf("got bridge.CombineCapabilities(%#v, %#v) == %#v, want == %#v", checksumAndLoopback, noChecksumAndLoopback, got, noChecksumAndLoopback)
	}

	linkID1, _ := channel.New(1, 101, "")
	linkID2, _ := channel.New(1, 100, "")
	// Only the bridgeable endpoints are needed here; the wrapped link IDs are unused.
	_, ep1 := bridge.NewEndpoint(linkID1)
	_, ep2 := bridge.NewEndpoint(linkID2)
	// Named b (not bridge) so the bridge package is not shadowed.
	b := bridge.New([]*bridge.BridgeableEndpoint{ep1, ep2})
	if got := b.MTU(); got != 100 {
		t.Errorf("got bridge.MTU() == %d but want 100", got)
	}
	// Bit 0x2 of the first octet is the locally administered bit of a MAC address.
	if linkAddr := b.LinkAddress(); len(linkAddr) == 0 || linkAddr[0]&0x2 == 0 {
		t.Errorf("bridge.LinkAddress() expected to be locally administered MAC address, got: %s", linkAddr)
	}
}
// TestBridge builds three stacks, s1 <-> sb <-> s2, where sb bridges two of
// its links. It checks that s1 can open TCP connections to s2, to the bridge's
// own address, and to an address on one of the bridge's constituent links. It
// also checks that no unexpected error counters increment and that the payload
// is never written back toward s1. After the bridge is closed, it checks that
// only the constituent-link address is still reachable.
func TestBridge(t *testing.T) {
	// payload should be unique enough that it won't accidentally appear
	// in TCP/IP packets.
	const payload = "hello"
	/* Connection diagram (variable names):
	   s1 (ep1) <----> (ep2) sb (ep3) <----> (ep4) s2
	                   ^-------bridge-------^
	*/
	ep1, ep2 := pipe(tcpip.LinkAddress(bytes.Repeat([]byte{1}, header.EthernetAddressSize)), tcpip.LinkAddress(bytes.Repeat([]byte{2}, header.EthernetAddressSize)))
	ep3, ep4 := pipe(tcpip.LinkAddress(bytes.Repeat([]byte{3}, header.EthernetAddressSize)), tcpip.LinkAddress(bytes.Repeat([]byte{4}, header.EthernetAddressSize)))
	s1addr := tcpip.Address([]byte{1, 1, 1, 1})
	s1, err := makeStackWithEndpoint(ep1, s1addr)
	if err != nil {
		t.Fatal(err)
	}
	baddr := tcpip.Address([]byte{2, 2, 2, 2})
	sb, b, err := makeStackWithBridgedEndpoints(ep2, ep3, baddr)
	if err != nil {
		t.Fatal(err)
	}
	if err := b.Up(); err != nil {
		t.Fatal(err)
	}
	s2addr := tcpip.Address([]byte{3, 3, 3, 3})
	s2, err := makeStackWithEndpoint(ep4, s2addr)
	if err != nil {
		t.Fatal(err)
	}
	// Add an address to one of the constituent links of the bridge (in addition
	// to the address on the virtual NIC representing the bridge itself), to test
	// that constituent links are still routable.
	bcaddr := tcpip.Address([]byte{4, 4, 4, 4})
	if err := sb.AddAddress(1, header.ARPProtocolNumber, arp.ProtocolAddress); err != nil {
		t.Fatalf("AddAddress failed: %s", err)
	}
	if err := sb.AddAddress(1, header.IPv4ProtocolNumber, bcaddr); err != nil {
		t.Fatalf("AddAddress failed: %s", err)
	}
	s1.SetRouteTable([]tcpip.Route{
		{
			Destination: s2addr,
			Mask:        "\xff\xff\xff\xff",
			NIC:         1,
		},
		{
			Destination: baddr,
			Mask:        "\xff\xff\xff\xff",
			NIC:         1,
		},
		{
			Destination: bcaddr,
			Mask:        "\xff\xff\xff\xff",
			NIC:         1,
		},
	})
	sb.SetRouteTable([]tcpip.Route{
		{
			Destination: s1addr,
			Mask:        "\xff\xff\xff\xff",
			NIC:         1,
		},
	})
	// makeStackWithEndpoint gives s2 a single NIC with ID 1.
	s2.SetRouteTable([]tcpip.Route{
		{
			Destination: s1addr,
			Mask:        "\xff\xff\xff\xff",
			NIC:         1,
		},
	})
	addrs := map[tcpip.Address]*stack.Stack{
		s2addr: s2,
		baddr:  sb,
		bcaddr: sb,
	}
	stacks := map[string]*stack.Stack{
		"s1": s1, "s2": s2, "sb": sb,
	}
	// Everything ep2 writes goes to ep1 (s1); the payload only travels from
	// s1 outward, so it must never appear on this path.
	ep2.onWritePacket = func(vv buffer.VectorisedView) {
		if bytes.Contains(vv.ToView(), []byte(payload)) {
			t.Errorf("did not expect payload %q to be sent back to ep1 in vv: %v", payload, vv)
		}
	}
	for ipaddr, toStack := range addrs {
		t.Run(fmt.Sprintf("ConnectAndWrite_%s", ipaddr), func(t *testing.T) {
			recvd, err := connectAndWrite(s1, toStack, ipaddr, payload)
			if err != nil {
				t.Fatal(err)
			}
			if got := string(recvd); got != payload {
				t.Errorf("got Read(...) = %v, want = %v", got, payload)
			}
			for name, s := range stacks {
				stats := s.Stats()
				if n := stats.UnknownProtocolRcvdPackets.Value(); n != 0 {
					t.Errorf("stack %s received %d UnknownProtocolRcvdPackets", name, n)
				}
				if n := stats.MalformedRcvdPackets.Value(); n != 0 {
					t.Errorf("stack %s received %d MalformedRcvdPackets", name, n)
				}
				if n := stats.DroppedPackets.Value(); n != 0 {
					t.Errorf("stack %s received %d DroppedPackets", name, n)
				}
				// The invalid address counter counts packets that have been received
				// by a stack correctly addressed at the link layer but incorrectly
				// addressed at the network layer (e.g. no network interface has the
				// address listed in the packet). This usually happens because
				// the stack is being sent packets for an IP address that it used to
				// have but doesn't have anymore. In this case, the bridge will
				// forward a packet to all constituent links when the link address that
				// the packet is addressed to isn't found on the bridge.
				//
				// TODO(NET-690): When we implement learning, we should be able to
				// modify this test setup to get to zero invalid addresses received.
				// With the current test setup, once learning is implemented, the
				// bridge would indiscriminately forward the first packet addressed to
				// a link address to all constituent links (causing #links - 1 invalid
				// addresses received), observe which link the response packet came
				// from, and then remember which link to forward to when the next
				// packet addressed to that link address was received. We might be able
				// to get to zero invalid addresses received by learning which links a
				// given address is on via the broadcast packets sent during ARP.
				// if n := stats.IP.InvalidAddressesReceived.Value(); n != 0 {
				// 	t.Errorf("stack %s received %d InvalidAddressesReceived", name, n)
				// }
				if n := stats.IP.OutgoingPacketErrors.Value(); n != 0 {
					t.Errorf("stack %s received %d OutgoingPacketErrors", name, n)
				}
				if n := stats.TCP.FailedConnectionAttempts.Value(); n != 0 {
					t.Errorf("stack %s received %d FailedConnectionAttempts", name, n)
				}
				if n := stats.TCP.InvalidSegmentsReceived.Value(); n != 0 {
					t.Errorf("stack %s received %d InvalidSegmentsReceived", name, n)
				}
				if n := stats.TCP.ResetsSent.Value(); n != 0 {
					t.Errorf("stack %s received %d ResetsSent", name, n)
				}
				if n := stats.TCP.ResetsReceived.Value(); n != 0 {
					t.Errorf("stack %s received %d ResetsReceived", name, n)
				}
			}
		})
	}
	if err := b.Close(); err != nil {
		t.Fatal(err)
	}
	// verify that the endpoint from the constituent link on sb is still accessible
	// and the bridge endpoint and endpoint on s2 are no longer accessible from s1
	noLongerConnectable := map[tcpip.Address]*stack.Stack{
		s2addr: s2,
		baddr:  sb,
	}
	stillConnectable := map[tcpip.Address]*stack.Stack{
		bcaddr: sb,
	}
	for ipaddr, toStack := range noLongerConnectable {
		// Run each attempt in its own function so that its deferred Closes fire
		// at the end of the iteration. Otherwise the endpoints would stay open
		// until the whole test returns.
		func() {
			senderWaitQueue := new(waiter.Queue)
			sender, err := s1.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, senderWaitQueue)
			if err != nil {
				t.Fatalf("NewEndpoint failed: %s", err)
			}
			defer sender.Close()
			receiverWaitQueue := new(waiter.Queue)
			receiver, err := toStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, receiverWaitQueue)
			if err != nil {
				t.Fatalf("NewEndpoint failed: %s", err)
			}
			defer receiver.Close()
			fulladdr := tcpip.FullAddress{Addr: ipaddr, Port: 2}
			if err := receiver.Bind(fulladdr, nil); err != nil {
				t.Fatalf("bind failed: %s", err)
			}
			if err := receiver.Listen(1); err != nil {
				t.Fatalf("listen failed: %s", err)
			}
			if err := connect(sender, fulladdr, senderWaitQueue, receiverWaitQueue); err != timeoutSendReady {
				t.Errorf("got connect(...) to addr %s = %v, want %v", ipaddr, err, timeoutSendReady)
			}
		}()
	}
	for ipaddr, toStack := range stillConnectable {
		recvd, err := connectAndWrite(s1, toStack, ipaddr, payload)
		if err != nil {
			t.Fatal(err)
		}
		if got := string(recvd); got != payload {
			t.Errorf("got Read(...) = %v, want = %v", got, payload)
		}
	}
}
// pipe returns a pair of endpoints carrying link addresses addr1 and addr2.
// Each one's linked field points at the other, so WritePacket on one
// delivers to the other's dispatcher.
func pipe(addr1, addr2 tcpip.LinkAddress) (*endpoint, *endpoint) {
	first := &endpoint{linkAddr: addr1}
	second := &endpoint{linkAddr: addr2, linked: first}
	first.linked = second
	return first, second
}
// endpoint is a fake link endpoint. We use our own fake because we'd like to
// report CapabilityResolutionRequired and trigger link address resolution.
//
// `endpoint` cannot be copied.
//
// Make endpoints using `pipe()`, not using endpoint literals.
type endpoint struct {
	// linkAddr is the address returned by LinkAddress.
	linkAddr tcpip.LinkAddress
	// dispatcher is set by Attach. It receives the packets that the linked
	// endpoint writes.
	dispatcher stack.NetworkDispatcher
	// linked is the peer endpoint that WritePacket delivers to.
	linked *endpoint
	// onWritePacket, if non-nil, is called with every packet this endpoint writes.
	onWritePacket func(buffer.VectorisedView)
}

// Compile-time check that *endpoint satisfies stack.LinkEndpoint.
var _ stack.LinkEndpoint = (*endpoint)(nil)
// WritePacket joins hdr and payload into a single VectorisedView and hands it
// to the linked endpoint's dispatcher, swapping the route's local and remote
// link addresses. It panics if the endpoint was not created with pipe, or if
// the linked endpoint has not been attached. It always returns nil.
func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
	if e.linked == nil {
		panic(fmt.Sprintf("ep %+v has not been linked to another endpoint; create endpoints with `pipe()`", e))
	}
	if !e.linked.IsAttached() {
		panic(fmt.Sprintf("ep: %+v linked endpoint: %+v has not been `Attach`ed; call stack.CreateNIC to attach it", e, e.linked))
	}
	vv := buffer.NewVectorisedView(hdr.UsedLength()+payload.Size(), append([]buffer.View{hdr.View()}, payload.Views()...))
	// Observe the packet before handing it off. The VectorisedView given to
	// the dispatcher shares its []View backing array with vv, so the receiving
	// stack may consume views in place. Delivery may also lead to nested
	// writes (replies), so inspecting afterwards could see a modified packet
	// or observe packets out of send order.
	if e.onWritePacket != nil {
		e.onWritePacket(vv)
	}
	// The linked endpoint's remote link address is our local one and vice versa.
	e.linked.dispatcher.DeliverNetworkPacket(e.linked, r.LocalLinkAddress, r.RemoteLinkAddress, protocol, vv)
	return nil
}
// Attach stores d; WritePacket on the linked endpoint delivers packets to it.
func (e *endpoint) Attach(d stack.NetworkDispatcher) { e.dispatcher = d }
// IsAttached reports whether a non-nil dispatcher has been stored by Attach.
func (e *endpoint) IsAttached() bool { return e.dispatcher != nil }
// MTU always returns 0 for this fake; WritePacket performs no size check.
func (*endpoint) MTU() uint32 { return 0 }
// Capabilities reports only CapabilityResolutionRequired, so stacks using this
// fake perform link address resolution.
func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityResolutionRequired }
// MaxHeaderLength returns 0: WritePacket prepends no link-layer header.
func (*endpoint) MaxHeaderLength() uint16 { return 0 }
// LinkAddress returns the link address the endpoint was created with.
func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr }
// makeStackWithEndpoint returns a new IPv4/ARP/TCP stack with one NIC (ID 1).
// The NIC is backed by ep wrapped as a bridgeable endpoint, and it is
// assigned the ARP protocol address followed by addr.
func makeStackWithEndpoint(ep *endpoint, addr tcpip.Address) (*stack.Stack, error) {
	s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
	linkID := stack.RegisterLinkEndpoint(ep)
	// Only the wrapped link ID is needed; the BridgeableEndpoint is discarded.
	linkID, _ = bridge.NewEndpoint(linkID)
	if err := s.CreateNIC(1, linkID); err != nil {
		return nil, fmt.Errorf("CreateNIC failed: %s", err)
	}
	for _, pa := range []struct {
		proto tcpip.NetworkProtocolNumber
		addr  tcpip.Address
	}{
		{header.ARPProtocolNumber, arp.ProtocolAddress},
		{header.IPv4ProtocolNumber, addr},
	} {
		if err := s.AddAddress(1, pa.proto, pa.addr); err != nil {
			return nil, fmt.Errorf("AddAddress failed: %s", err)
		}
	}
	return s, nil
}
// makeStackWithBridgedEndpoints returns a stack with three NICs:
//   - NIC 1 is backed by ep1.
//   - NIC 2 is backed by ep2.
//   - NIC 3 is backed by a bridge over both.
//
// baddr and the ARP protocol address are assigned to the bridge NIC. The bridge
// is returned without Up having been called on it.
func makeStackWithBridgedEndpoints(ep1, ep2 *endpoint, baddr tcpip.Address) (*stack.Stack, *bridge.Endpoint, error) {
	linkID1 := stack.RegisterLinkEndpoint(ep1)
	linkID2 := stack.RegisterLinkEndpoint(ep2)
	// Wrap each link so it can join the bridge; NICs 1 and 2 use the wrapped IDs.
	linkID1, bep1 := bridge.NewEndpoint(linkID1)
	linkID2, bep2 := bridge.NewEndpoint(linkID2)
	stk := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
	if err := stk.CreateNIC(1, linkID1); err != nil {
		return nil, nil, fmt.Errorf("CreateNIC failed: %s", err)
	}
	if err := stk.CreateNIC(2, linkID2); err != nil {
		return nil, nil, fmt.Errorf("CreateNIC failed: %s", err)
	}
	// Named b (not bridge) so the bridge package is not shadowed.
	b := bridge.New([]*bridge.BridgeableEndpoint{bep1, bep2})
	const bID = tcpip.NICID(3)
	if err := stk.CreateNIC(bID, stack.RegisterLinkEndpoint(b)); err != nil {
		return nil, nil, fmt.Errorf("CreateNIC failed: %s", err)
	}
	if err := stk.AddAddress(bID, header.IPv4ProtocolNumber, baddr); err != nil {
		return nil, nil, fmt.Errorf("AddAddress failed: %s", err)
	}
	if err := stk.AddAddress(bID, header.ARPProtocolNumber, arp.ProtocolAddress); err != nil {
		return nil, nil, fmt.Errorf("AddAddress failed: %s", err)
	}
	return stk, b, nil
}
// connectAndWrite listens on addr port 1 in toStack and connects to it from
// fromStack. It writes payload from the connecting side and returns the result
// of a single Read on the accepted endpoint. Every endpoint it creates,
// including the accepted one, is closed before it returns.
func connectAndWrite(fromStack *stack.Stack, toStack *stack.Stack, addr tcpip.Address, payload string) ([]byte, error) {
	senderWaitQueue := new(waiter.Queue)
	sender, err := fromStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, senderWaitQueue)
	if err != nil {
		return nil, fmt.Errorf("NewEndpoint failed: %s", err)
	}
	defer sender.Close()
	receiverWaitQueue := new(waiter.Queue)
	receiver, err := toStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, receiverWaitQueue)
	if err != nil {
		return nil, fmt.Errorf("NewEndpoint failed: %s", err)
	}
	defer receiver.Close()
	fulladdr := tcpip.FullAddress{Addr: addr, Port: 1}
	if err := receiver.Bind(fulladdr, nil); err != nil {
		return nil, fmt.Errorf("bind failed: %s", err)
	}
	if err := receiver.Listen(1); err != nil {
		return nil, fmt.Errorf("listen failed: %s", err)
	}
	if err := connect(sender, fulladdr, senderWaitQueue, receiverWaitQueue); err != nil {
		return nil, fmt.Errorf("connect failed: %s\n\n%+v\n\n%+v", err, fromStack.Stats(), toStack.Stats())
	}
	ep, wq, err := receiver.Accept()
	if err != nil {
		return nil, fmt.Errorf("accept failed: %s", err)
	}
	// The accepted endpoint is separate from the listener and needs its own
	// Close. Without it, the connection is left open after this helper returns.
	defer ep.Close()
	if err := write(sender, fulladdr, payload, wq); err != nil {
		return nil, err
	}
	recvd, _, err := ep.Read(nil)
	if err != nil {
		return nil, fmt.Errorf("read failed: %s", err)
	}
	return recvd, nil
}
// write sends payload on sender with s2fulladdr as the write destination. It
// then waits up to one second for wq (the receiving endpoint's queue) to
// signal readability. It returns timeoutPayloadReceived if that signal never
// arrives.
func write(sender tcpip.Endpoint, s2fulladdr tcpip.FullAddress, payload string, wq *waiter.Queue) error {
	// Subscribe before writing so a fast delivery is not missed.
	entry, readable := waiter.NewChannelEntry(nil)
	wq.EventRegister(&entry, waiter.EventIn)
	defer wq.EventUnregister(&entry)

	opts := tcpip.WriteOptions{To: &s2fulladdr}
	if _, _, err := sender.Write(tcpip.SlicePayload(payload), opts); err != nil {
		return fmt.Errorf("write failed: %s", err)
	}

	select {
	case <-time.After(time.Second):
		return timeoutPayloadReceived
	case <-readable:
		return nil
	}
}
// connect starts a TCP connection from sender to addr and waits for the
// handshake to finish. It returns:
//   - nil once sender becomes writable (EventOut on senderWaitQueue) and the
//     listener becomes readable (EventIn on receiverWaitQueue).
//   - timeoutSendReady or timeoutReceiveReady if the corresponding event does
//     not arrive within one second of the preceding step.
//   - a "connect failed" error if Connect does not return ErrConnectStarted,
//     or if the endpoint reports a pending error once it is writable.
func connect(sender tcpip.Endpoint, addr tcpip.FullAddress, senderWaitQueue, receiverWaitQueue *waiter.Queue) error {
	// Register for both events before calling Connect so that notifications
	// raised during the handshake are not missed.
	sendReadyWaitEntry, sendReadyNotifyCh := waiter.NewChannelEntry(nil)
	senderWaitQueue.EventRegister(&sendReadyWaitEntry, waiter.EventOut)
	defer senderWaitQueue.EventUnregister(&sendReadyWaitEntry)
	receiveReadyWaitEntry, receiveReadyNotifyCh := waiter.NewChannelEntry(nil)
	receiverWaitQueue.EventRegister(&receiveReadyWaitEntry, waiter.EventIn)
	defer receiverWaitQueue.EventUnregister(&receiveReadyWaitEntry)
	if err := sender.Connect(addr); err != tcpip.ErrConnectStarted {
		return fmt.Errorf("connect failed: %s", err)
	}
	select {
	case <-sendReadyNotifyCh:
	case <-time.After(1 * time.Second):
		return timeoutSendReady
	}
	// Writability may also be signalled when the handshake fails (e.g. the
	// peer resets). Report the endpoint's pending error instead of waiting for
	// a misleading receive-side timeout.
	if err := sender.GetSockOpt(tcpip.ErrorOption{}); err != nil {
		return fmt.Errorf("connect failed: %s", err)
	}
	select {
	case <-receiveReadyNotifyCh:
	case <-time.After(1 * time.Second):
		return timeoutReceiveReady
	}
	return nil
}