blob: edddea41532fcca7032e62f1dae20b5cddaab0da [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package bridge_test
import (
"bytes"
"errors"
"fmt"
"testing"
"time"
"go.fuchsia.dev/fuchsia/src/connectivity/network/netstack/link/bridge"
"go.fuchsia.dev/fuchsia/src/connectivity/network/netstack/packetbuffer"
"go.fuchsia.dev/fuchsia/src/connectivity/network/netstack/util"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
// Unicast link addresses for the fake endpoints used throughout these tests.
// Every first octet is 0x02: the locally administered bit is set and the
// multicast (group) bit is clear.
const (
	linkAddr1 tcpip.LinkAddress = "\x02\x03\x04\x05\x06\x07"
	linkAddr2 tcpip.LinkAddress = "\x02\x03\x04\x05\x06\x08"
	linkAddr3 tcpip.LinkAddress = "\x02\x03\x04\x05\x06\x09"
	linkAddr4 tcpip.LinkAddress = "\x02\x03\x04\x05\x06\x0a"
	linkAddr5 tcpip.LinkAddress = "\x02\x03\x04\x05\x06\x0b"
	linkAddr6 tcpip.LinkAddress = "\x02\x03\x04\x05\x06\x0c"
)
// timeoutReceiveReady is returned by connect when the receiver never signals
// EventIn before the timeout.
var timeoutReceiveReady = errors.New("receiveready")

// timeoutSendReady is returned by connect when the sender never signals
// EventOut before the timeout.
var timeoutSendReady = errors.New("sendready")

// timeoutPayloadReceived is returned by write when the receiving side never
// signals EventIn after the payload was written.
var timeoutPayloadReceived = errors.New("payloadreceived")
// endpointWithAttributes wraps a stack.LinkEndpoint and augments the attributes
// it reports. The extra capability bits are OR-ed into the wrapped endpoint's
// capabilities, and maxHeaderLength is added on top of its MaxHeaderLength.
type endpointWithAttributes struct {
	stack.LinkEndpoint
	capabilities    stack.LinkEndpointCapabilities
	maxHeaderLength uint16
}

// Capabilities returns the wrapped endpoint's capabilities plus e.capabilities.
func (e *endpointWithAttributes) Capabilities() stack.LinkEndpointCapabilities {
	base := e.LinkEndpoint.Capabilities()
	return base | e.capabilities
}

// MaxHeaderLength returns the wrapped endpoint's header length plus
// e.maxHeaderLength.
func (e *endpointWithAttributes) MaxHeaderLength() uint16 {
	extra := e.maxHeaderLength
	return e.LinkEndpoint.MaxHeaderLength() + extra
}
// TestEndpointAttributes checks the capabilities, MaxHeaderLength, MTU and link
// address a bridge reports when its constituent endpoints advertise differing
// capabilities and header lengths.
func TestEndpointAttributes(t *testing.T) {
	ep1 := bridge.NewEndpoint(&endpointWithAttributes{
		LinkEndpoint:    loopback.New(),
		capabilities:    stack.CapabilityLoopback,
		maxHeaderLength: 5,
	})
	ep2 := bridge.NewEndpoint(&endpointWithAttributes{
		LinkEndpoint:    loopback.New(),
		capabilities:    stack.CapabilityLoopback | stack.CapabilityResolutionRequired,
		maxHeaderLength: 10,
	})
	bridgeEP := bridge.New([]*bridge.BridgeableEndpoint{ep1, ep2})

	if got, want := bridgeEP.Capabilities(), stack.CapabilityResolutionRequired; got != want {
		t.Errorf("got Capabilities = %b, want = %b", got, want)
	}
	// ep2 has the larger header requirement, so the bridge must reserve at
	// least that much.
	if got, want := bridgeEP.MaxHeaderLength(), ep2.MaxHeaderLength(); got != want {
		t.Errorf("got MaxHeaderLength = %d, want = %d", got, want)
	}
	if got, want := bridgeEP.MTU(), ep2.MTU(); got != want {
		t.Errorf("got MTU = %d, want = %d", got, want)
	}

	// Bit 0x02 of the first octet marks a locally administered MAC address.
	// Guard the index so an empty address fails the test instead of panicking.
	linkAddr := bridgeEP.LinkAddress()
	switch {
	case len(linkAddr) == 0:
		t.Error("bridge.LinkAddress() expected to be locally administered MAC address, got empty address")
	case linkAddr[0]&0x2 == 0:
		t.Errorf("bridge.LinkAddress() expected to be locally administered MAC address, got: %s", linkAddr)
	}
}
// waitingEndpoint is a stack.LinkEndpoint whose Wait blocks until a receive on
// ch succeeds (a value is sent or ch is closed). This lets tests control when
// Wait returns.
type waitingEndpoint struct {
	stack.LinkEndpoint
	ch chan struct{}
}

// Wait blocks until a receive from w.ch completes.
func (w *waitingEndpoint) Wait() {
	select {
	case <-w.ch:
	}
}
// TestEndpoint_Wait checks that a bridge's Wait does not return until Wait has
// returned on every constituent endpoint.
func TestEndpoint_Wait(t *testing.T) {
	ep := loopback.New()
	ep1 := waitingEndpoint{
		LinkEndpoint: ep,
		ch:           make(chan struct{}),
	}
	ep2 := waitingEndpoint{
		LinkEndpoint: ep,
		ch:           make(chan struct{}),
	}
	bridgeEP := bridge.New([]*bridge.BridgeableEndpoint{
		bridge.NewEndpoint(&ep1),
		bridge.NewEndpoint(&ep2),
	})

	waitChs := []chan struct{}{ep1.ch, ep2.ch}
	// released counts the constituent channels closed so far. If the test exits
	// early via t.Fatal, close the remaining ones so the goroutine blocked in
	// bridgeEP.Wait can finish instead of leaking.
	released := 0
	defer func() {
		for _, c := range waitChs[released:] {
			close(c)
		}
	}()

	done := make(chan struct{})
	go func() {
		bridgeEP.Wait()
		close(done)
	}()

	for _, c := range waitChs {
		select {
		case <-done:
			t.Fatal("bridge wait completed before constituent links")
		case <-time.After(100 * time.Millisecond):
		}
		close(c)
		released++
	}

	select {
	case <-done:
	case <-time.After(100 * time.Millisecond):
		t.Fatal("bridge wait pending after constituent links completed")
	}
}
var _ stack.NetworkDispatcher = (*testNetworkDispatcher)(nil)

// testNetworkDispatcher records inbound deliveries. count is the number of
// DeliverNetworkPacket calls seen, and pkt is the most recently delivered packet.
type testNetworkDispatcher struct {
	pkt   *stack.PacketBuffer
	count int
}

// DeliverNetworkPacket remembers pkt and bumps the delivery counter; the link
// addresses and protocol are ignored.
func (d *testNetworkDispatcher) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
	d.pkt = pkt
	d.count++
}

// DeliverOutboundPacket discards outbound packets.
func (*testNetworkDispatcher) DeliverOutboundPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {}
// channelEndpointHeaderLen is the number of link header bytes channelEndpoint
// reports via MaxHeaderLength and pushes onto every written packet.
const channelEndpointHeaderLen = 1

var _ stack.LinkEndpoint = (*channelEndpoint)(nil)

// channelEndpoint is a fake link endpoint that queues written packets on a
// buffered channel instead of transmitting them. Any method not overridden here
// is served by the embedded LinkEndpoint.
type channelEndpoint struct {
	stack.LinkEndpoint
	linkAddr tcpip.LinkAddress
	c        chan *stack.PacketBuffer
}

// MaxHeaderLength reports the fake link header size.
func (*channelEndpoint) MaxHeaderLength() uint16 { return channelEndpointHeaderLen }

// LinkAddress reports the address the endpoint was created with.
func (e *channelEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr }

// enqueue pushes the fake link header onto pkt and queues pkt without blocking.
// It reports whether the packet was queued (false means the channel is full).
func (e *channelEndpoint) enqueue(pkt *stack.PacketBuffer) bool {
	// Only the header's length matters; its contents are never inspected.
	_ = pkt.LinkHeader().Push(channelEndpointHeaderLen)
	select {
	case e.c <- pkt:
		return true
	default:
		return false
	}
}

// WritePacket queues pkt, returning tcpip.ErrWouldBlock if the queue is full.
func (e *channelEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
	if !e.enqueue(pkt) {
		return tcpip.ErrWouldBlock
	}
	return nil
}

// WritePackets queues pkts in order. It stops at the first packet that does
// not fit and returns how many were queued along with tcpip.ErrWouldBlock.
func (e *channelEndpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
	written := 0
	for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
		if !e.enqueue(pkt) {
			return written, tcpip.ErrWouldBlock
		}
		written++
	}
	return written, nil
}

// getPacket dequeues the next written packet, or returns nil if none is queued.
func (e *channelEndpoint) getPacket() *stack.PacketBuffer {
	var pkt *stack.PacketBuffer
	select {
	case pkt = <-e.c:
	default:
	}
	return pkt
}

// makeChannelEndpoint returns a channelEndpoint with the given link address
// that can hold up to size unread packets. Unoverridden methods are backed by
// a loopback endpoint.
func makeChannelEndpoint(linkAddr tcpip.LinkAddress, size int) channelEndpoint {
	var ep channelEndpoint
	ep.LinkEndpoint = loopback.New()
	ep.linkAddr = linkAddr
	ep.c = make(chan *stack.PacketBuffer, size)
	return ep
}
// TestBridgeWritePackets tests that writing to a bridge, whether through
// DeliverNetworkPacketToBridge, WritePacket or WritePackets, writes the packets
// to all bridged endpoints.
func TestBridgeWritePackets(t *testing.T) {
	data := [][]byte{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}
	ep1 := makeChannelEndpoint(linkAddr1, len(data))
	ep2 := makeChannelEndpoint(linkAddr2, len(data))
	ep3 := makeChannelEndpoint(linkAddr3, len(data))
	bridgeEP := bridge.New([]*bridge.BridgeableEndpoint{
		bridge.NewEndpoint(&ep1),
		bridge.NewEndpoint(&ep2),
		bridge.NewEndpoint(&ep3),
	})
	eps := []struct {
		name string
		ep   *channelEndpoint
	}{
		{name: "ep1", ep: &ep1},
		{name: "ep2", ep: &ep2},
		{name: "ep3", ep: &ep3},
	}

	// newPacket builds an outbound packet carrying payload, with enough header
	// room reserved for the bridge.
	newPacket := func(payload []byte) *stack.PacketBuffer {
		return stack.NewPacketBuffer(stack.PacketBufferOptions{
			ReserveHeaderBytes: int(bridgeEP.MaxHeaderLength()),
			Data:               buffer.View(payload).ToVectorisedView(),
		})
	}

	// expectPacket reports an error unless every bridged endpoint has a queued
	// packet whose payload equals want. Only pkt.Data is compared; the link
	// header pushed by channelEndpoint is not checked. prefix is prepended to
	// every failure message.
	expectPacket := func(t *testing.T, prefix string, want []byte) {
		t.Helper()
		for _, e := range eps {
			pkt := e.ep.getPacket()
			if pkt == nil {
				t.Errorf("%sexpected a packet on %s", prefix, e.name)
				continue
			}
			if got := pkt.Data.ToView(); !bytes.Equal(got, want) {
				t.Errorf("%sgot %s data = %x, want = %x", prefix, e.name, got, want)
			}
		}
	}

	t.Run("DeliverNetworkPacketToBridge", func(t *testing.T) {
		// linkAddr5 does not belong to any bridged endpoint, so the bridge floods
		// the frame to all of them. A nil rxEP means no endpoint is excluded.
		bridgeEP.DeliverNetworkPacketToBridge(nil /* rxEP */, linkAddr4, linkAddr5, 0 /* protocol */, newPacket(data[0]))
		expectPacket(t, "", data[0])
	})

	t.Run("WritePacket", func(t *testing.T) {
		// The bridge and channel endpoints do not care about the route, GSO
		// or network protocol number when writing packets.
		if err := bridgeEP.WritePacket(nil /* route */, nil /* gso */, 0 /* protocol */, newPacket(data[0])); err != nil {
			t.Errorf("bridgeEP.WritePacket(nil, nil, 0, _): %s", err)
		}
		expectPacket(t, "", data[0])
	})

	for i := 1; i <= len(data); i++ {
		t.Run(fmt.Sprintf("WritePackets(N=%d)", i), func(t *testing.T) {
			var pkts stack.PacketBufferList
			for _, payload := range data[:i] {
				pkts.PushBack(newPacket(payload))
			}
			// The bridge and channel endpoints do not care about the route, GSO
			// or network protocol number when writing packets.
			n, err := bridgeEP.WritePackets(nil /* route */, nil /* gso */, pkts, 0 /* protocol */)
			if err != nil {
				t.Errorf("bridgeEP.WritePackets(nil, nil, _, 0): %s", err)
			}
			if n != i {
				t.Errorf("got bridgeEP.WritePackets(nil, nil, _, 0) = %d, want = %d", n, i)
			}
			for j, payload := range data[:i] {
				expectPacket(t, fmt.Sprintf("(j=%d) ", j), payload)
			}
		})
	}
}
// TestBridgeRouting makes sure that frames are directed to the right unicast
// endpoint or floods all endpoints for multicast and broadcast frames.
func TestBridgeRouting(t *testing.T) {
type rxEPKind int
const (
rxEPNil rxEPKind = iota
rxEP1
rxEP2
)
data := []byte{1, 2, 3, 4}
pkt := stack.PacketBuffer{
Data: buffer.View(data).ToVectorisedView(),
}
tests := []struct {
name string
dstAddr tcpip.LinkAddress
ep1ShouldGetPacket bool
nd1ShouldGetPacket bool
ep2ShouldGetPacket bool
nd2ShouldGetPacket bool
ndbShouldGetPacket bool
}{
{
name: "ToMulticast",
dstAddr: "\x01\x03\x04\x05\x06\x07",
ep1ShouldGetPacket: true,
nd1ShouldGetPacket: true,
ep2ShouldGetPacket: true,
nd2ShouldGetPacket: true,
ndbShouldGetPacket: true,
},
{
name: "ToBroadcast",
dstAddr: "\xff\xff\xff\xff\xff\xff",
ep1ShouldGetPacket: true,
nd1ShouldGetPacket: true,
ep2ShouldGetPacket: true,
nd2ShouldGetPacket: true,
ndbShouldGetPacket: true,
},
{
name: "ToEP1",
dstAddr: linkAddr1,
nd1ShouldGetPacket: true,
},
{
name: "ToEP2",
dstAddr: linkAddr2,
nd2ShouldGetPacket: true,
},
{
name: "ToOther",
dstAddr: linkAddr4,
ep1ShouldGetPacket: true,
ep2ShouldGetPacket: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
subtests := []struct {
name string
rxEP rxEPKind
ep1ShouldGetPacket bool
ep2ShouldGetPacket bool
}{
{
name: "Delivered from nil EP",
rxEP: rxEPNil,
ep1ShouldGetPacket: test.ep1ShouldGetPacket,
ep2ShouldGetPacket: test.ep2ShouldGetPacket,
},
{
name: "Delivered from EP1",
rxEP: rxEP1,
ep1ShouldGetPacket: false,
ep2ShouldGetPacket: test.ep2ShouldGetPacket,
},
{
name: "Delivered from EP2",
rxEP: rxEP2,
ep1ShouldGetPacket: test.ep1ShouldGetPacket,
ep2ShouldGetPacket: false,
},
}
for _, subtest := range subtests {
t.Run(test.name, func(t *testing.T) {
ep1 := makeChannelEndpoint(linkAddr1, 1)
ep2 := makeChannelEndpoint(linkAddr2, 1)
bep1 := bridge.NewEndpoint(&ep1)
bep2 := bridge.NewEndpoint(&ep2)
var nd1, nd2, ndb testNetworkDispatcher
bridgeEP := bridge.New([]*bridge.BridgeableEndpoint{bep1, bep2})
bep1.Attach(&nd1)
bep2.Attach(&nd2)
bridgeEP.Attach(&ndb)
var rxEP *bridge.BridgeableEndpoint
switch subtest.rxEP {
case rxEPNil:
case rxEP1:
rxEP = bep1
case rxEP2:
rxEP = bep2
default:
t.Fatalf("unrecognized rxEPKind = %d", subtest.rxEP)
}
bridgeEP.DeliverNetworkPacketToBridge(rxEP, linkAddr3, test.dstAddr, 0, &pkt)
if pkt := ep1.getPacket(); subtest.ep1ShouldGetPacket {
if pkt == nil {
t.Error("expected a packet on ep1")
} else if got := pkt.Data.ToView(); !bytes.Equal(got, data) {
t.Errorf("got ep1 data = %x, want = %x", got, data)
}
} else if pkt != nil {
t.Errorf("ep1 unexpectedly got a packet = %+v", pkt)
}
if test.nd1ShouldGetPacket {
if nd1.count != 1 {
t.Errorf("got nd1.count = %d, want = 1", nd1.count)
}
if got := nd1.pkt.Data.ToView(); !bytes.Equal(got, data) {
t.Errorf("got nd1 data = %x, want = %x", got, data)
}
} else if nd1.count != 0 {
t.Errorf("got nd1.count = %d, want = 0", nd1.count)
}
if pkt := ep2.getPacket(); subtest.ep2ShouldGetPacket {
if pkt == nil {
t.Error("expected a packet on ep2")
} else if got := pkt.Data.ToView(); !bytes.Equal(got, data) {
t.Errorf("got ep2 data = %x, want = %x", got, data)
}
} else if pkt != nil {
t.Errorf("ep2 unexpectedly got a packet = %+v", pkt)
}
if test.nd2ShouldGetPacket {
if nd2.count != 1 {
t.Errorf("got nd2.count = %d, want = 1", nd2.count)
}
if got := nd2.pkt.Data.ToView(); !bytes.Equal(got, data) {
t.Errorf("got nd2 data = %x, want = %x", got, data)
}
} else if nd2.count != 0 {
t.Errorf("got nd2.count = %d, want = 0", nd2.count)
}
if test.ndbShouldGetPacket {
if ndb.count != 1 {
t.Errorf("got ndb.count = %d, want = 1", ndb.count)
}
if got := ndb.pkt.Data.ToView(); !bytes.Equal(got, data) {
t.Errorf("got ndb data = %x, want = %x", got, data)
}
} else if ndb.count != 0 {
t.Errorf("got ndb.count = %d, want = 0", ndb.count)
}
})
}
})
}
}
// TestBridge builds three stacks (s1, sb and s2) joined through a bridge on sb.
// It verifies that TCP connections from s1 reach s2 on the far side of the
// bridge, the bridge's own address, and an address on one of the bridge's
// constituent links. It then detaches the bridge and verifies that only the
// constituent-link address remains reachable.
func TestBridge(t *testing.T) {
	const (
		s1NICID = 1
		s2NICID = 10
		// ep2 is the second endpoint handed to makeStackWithBridgedEndpoints,
		// so it is assigned NIC ID 2 on sb.
		sbEP2NICID   = 2
		sbOtherNICID = 9000
	)
	for _, testCase := range []struct {
		name            string
		protocolFactory stack.NetworkProtocolFactory
		protocolNumber  tcpip.NetworkProtocolNumber
		addressSize     int
	}{
		{name: "ipv4", protocolFactory: ipv4.NewProtocol, protocolNumber: ipv4.ProtocolNumber, addressSize: header.IPv4AddressSize},
		{name: "ipv6", protocolFactory: ipv6.NewProtocol, protocolNumber: ipv6.ProtocolNumber, addressSize: header.IPv6AddressSize},
	} {
		t.Run(testCase.name, func(t *testing.T) {
			// payload should be unique enough that it won't accidentally appear
			// in TCP/IP packets.
			const payload = "hello"

			// Connection diagram:
			//
			// <---> ep1 <--pipe--> ep2 <--bridge--> ep3 <--pipe--> ep4
			//
			// Included are several additional endpoints to ensure bridging N > 2
			// endpoints works.
			ep1, ep2 := pipe(linkAddr1, linkAddr2)
			ep3, ep4 := pipe(linkAddr3, linkAddr4)
			ep5, ep6 := pipe(linkAddr5, linkAddr6)

			s1addr := tcpip.Address(bytes.Repeat([]byte{1}, testCase.addressSize))
			s1subnet := util.PointSubnet(s1addr)
			s1, err := makeStackWithEndpoint(s1NICID, ep1, testCase.protocolFactory, testCase.protocolNumber, s1addr)
			if err != nil {
				t.Fatal(err)
			}

			baddr := tcpip.Address(bytes.Repeat([]byte{2}, testCase.addressSize))
			bsubnet := util.PointSubnet(baddr)
			sb, b, bridgeNICID := makeStackWithBridgedEndpoints(t, testCase.protocolFactory, testCase.protocolNumber, baddr, ep5, ep2, ep3)
			if err := sb.CreateNIC(sbOtherNICID, ep6); err != nil {
				t.Fatal(err)
			}
			if err := b.Up(); err != nil {
				t.Fatal(err)
			}

			s2addr := tcpip.Address(bytes.Repeat([]byte{3}, testCase.addressSize))
			s2subnet := util.PointSubnet(s2addr)
			s2, err := makeStackWithEndpoint(s2NICID, ep4, testCase.protocolFactory, testCase.protocolNumber, s2addr)
			if err != nil {
				t.Fatal(err)
			}

			// Add an address to one of the constituent links of the bridge (in
			// addition to the address on the virtual NIC representing the bridge
			// itself), to test that constituent links are still routable.
			bcaddr := tcpip.Address(bytes.Repeat([]byte{4}, testCase.addressSize))
			bcsubnet := util.PointSubnet(bcaddr)
			if err := sb.AddAddress(sbEP2NICID, testCase.protocolNumber, bcaddr); err != nil {
				t.Fatalf("AddAddress failed: %s", err)
			}

			// Make sure s1 can communicate with all the addresses we configured
			// above.
			s1.SetRouteTable([]tcpip.Route{
				{
					Destination: s2subnet,
					NIC:         s1NICID,
				},
				{
					Destination: bsubnet,
					NIC:         s1NICID,
				},
				{
					Destination: bcsubnet,
					NIC:         s1NICID,
				},
			})
			sb.SetRouteTable([]tcpip.Route{
				{
					Destination: s1subnet,
					NIC:         sbEP2NICID,
				},
				{
					Destination: s1subnet,
					NIC:         bridgeNICID,
				},
			})
			s2.SetRouteTable([]tcpip.Route{
				{
					Destination: s1subnet,
					NIC:         s2NICID,
				},
			})

			addrs := map[tcpip.Address]*stack.Stack{
				s2addr: s2,
				baddr:  sb,
				bcaddr: sb,
			}
			stacks := map[string]*stack.Stack{
				"s1": s1, "s2": s2, "sb": sb,
			}

			// Nothing containing the payload should ever be sent back toward s1.
			ep2.onWritePacket = func(pkt *stack.PacketBuffer) {
				for i, view := range pkt.Data.Views() {
					if bytes.Contains(view, []byte(payload)) {
						t.Errorf("did not expect payload %x to be sent back to ep1 in view %d: %x", payload, i, view)
					}
				}
			}

			for addr, toStack := range addrs {
				t.Run(fmt.Sprintf("ConnectAndWrite_%s", addr), func(t *testing.T) {
					recvd, err := connectAndWrite(s1, toStack, testCase.protocolFactory, testCase.protocolNumber, addr, payload)
					if err != nil {
						t.Fatal(err)
					}
					if !bytes.Equal(recvd, []byte(payload)) {
						t.Errorf("got Read(...) = %x, want = %x", recvd, payload)
					}
					for name, s := range stacks {
						stats := s.Stats()
						if n := stats.UnknownProtocolRcvdPackets.Value(); n != 0 {
							t.Errorf("stack %s received %d UnknownProtocolRcvdPackets", name, n)
						}
						if n := stats.MalformedRcvdPackets.Value(); n != 0 {
							t.Errorf("stack %s received %d MalformedRcvdPackets", name, n)
						}
						if n := stats.DroppedPackets.Value(); n != 0 {
							t.Errorf("stack %s received %d DroppedPackets", name, n)
						}
						// The invalid address counter counts packets that have been received
						// by a stack correctly addressed at the link layer but incorrectly
						// addressed at the network layer (e.g. no network interface has the
						// address listed in the packet). This usually happens because
						// the stack is being sent packets for an IP address that it used to
						// have but doesn't have anymore. In this case, the bridge will
						// forward a packet to all constituent links when the link address that
						// the packet is addressed to isn't found on the bridge.
						//
						// TODO(fxbug.dev/20778): When we implement learning, we should be able to
						// modify this test setup to get to zero invalid addresses received.
						// With the current test setup, once learning is implemented, the
						// bridge would indiscriminately forward the first packet addressed to
						// a link address to all constituent links (causing #links - 1 invalid
						// addresses received), observe which link the response packet came
						// from, and then remember which link to forward to when the next
						// packet addressed to that link address was received. We might be able
						// to get to zero invalid addresses received by learning which links a
						// given address is on via the broadcast packets sent during ARP.
						// if n := stats.IP.InvalidAddressesReceived.Value(); n != 0 {
						// 	t.Errorf("stack %s received %d InvalidAddressesReceived", name, n)
						// }
						if n := stats.IP.OutgoingPacketErrors.Value(); n != 0 {
							t.Errorf("stack %s received %d OutgoingPacketErrors", name, n)
						}
						if n := stats.TCP.FailedConnectionAttempts.Value(); n != 0 {
							t.Errorf("stack %s received %d FailedConnectionAttempts", name, n)
						}
						if n := stats.TCP.InvalidSegmentsReceived.Value(); n != 0 {
							t.Errorf("stack %s received %d InvalidSegmentsReceived", name, n)
						}
						if n := stats.TCP.ResetsSent.Value(); n != 0 {
							t.Errorf("stack %s received %d ResetsSent", name, n)
						}
						if n := stats.TCP.ResetsReceived.Value(); n != 0 {
							t.Errorf("stack %s received %d ResetsReceived", name, n)
						}
					}
				})
			}

			// Detach the bridge. The address on the constituent link of sb should
			// still be reachable from s1, while the bridge's own address and s2
			// (only reachable through the bridge) should no longer be.
			b.Attach(nil)
			noLongerConnectable := map[tcpip.Address]*stack.Stack{
				s2addr: s2,
				baddr:  sb,
			}
			stillConnectable := map[tcpip.Address]*stack.Stack{
				bcaddr: sb,
			}

			for addr, toStack := range noLongerConnectable {
				t.Run(fmt.Sprintf("NoLongerConnectable_%s", addr), func(t *testing.T) {
					senderWaitQueue := new(waiter.Queue)
					sender, err := s1.NewEndpoint(tcp.ProtocolNumber, testCase.protocolNumber, senderWaitQueue)
					if err != nil {
						t.Fatalf("NewEndpoint failed: %s", err)
					}
					defer sender.Close()

					receiverWaitQueue := new(waiter.Queue)
					receiver, err := toStack.NewEndpoint(tcp.ProtocolNumber, testCase.protocolNumber, receiverWaitQueue)
					if err != nil {
						t.Fatalf("NewEndpoint failed: %s", err)
					}
					defer receiver.Close()

					if err := receiver.Bind(tcpip.FullAddress{Addr: addr}); err != nil {
						t.Fatalf("bind failed: %s", err)
					}
					if err := receiver.Listen(1); err != nil {
						t.Fatalf("listen failed: %s", err)
					}
					localAddr, err := receiver.GetLocalAddress()
					if err != nil {
						t.Fatalf("getlocaladdress failed: %s", err)
					}
					// Let s1 pick the outgoing NIC from its route table.
					localAddr.NIC = 0
					if err := connect(sender, localAddr, senderWaitQueue, receiverWaitQueue); !errors.Is(err, timeoutSendReady) {
						t.Errorf("expected timeout sendready, got %v connecting to addr %+v", err, localAddr)
					}
				})
			}

			for addr, toStack := range stillConnectable {
				t.Run(fmt.Sprintf("StillConnectable_%s", addr), func(t *testing.T) {
					recvd, err := connectAndWrite(s1, toStack, testCase.protocolFactory, testCase.protocolNumber, addr, payload)
					if err != nil {
						t.Fatal(err)
					}
					if !bytes.Equal(recvd, []byte(payload)) {
						t.Errorf("got Read(...) = %x, want = %x", recvd, payload)
					}
				})
			}
		})
	}
}
// TestBridgeableEndpointDetach verifies that a bridgeable endpoint tracks
// attachment together with the endpoint it wraps. It also checks that delivering
// a packet after attaching to a nil dispatcher neither panics nor reaches the
// previously attached dispatcher.
func TestBridgeableEndpointDetach(t *testing.T) {
	inner := makeChannelEndpoint(linkAddr1, 1)
	bep := bridge.NewEndpoint(&inner)
	var dispatcher testNetworkDispatcher

	deliver := func() {
		bep.DeliverNetworkPacket(linkAddr1, linkAddr2, header.IPv4ProtocolNumber, &stack.PacketBuffer{})
	}

	// Freshly created: nothing is attached yet.
	if inner.IsAttached() {
		t.Fatal("ep1.IsAttached() = true, want = false")
	}
	if bep.IsAttached() {
		t.Fatal("bep1.IsAttached() = true, want = false")
	}

	// Attaching must not itself deliver anything, and must attach both layers.
	bep.Attach(&dispatcher)
	if dispatcher.count != 0 {
		t.Fatalf("got disp.count = %d, want = 0", dispatcher.count)
	}
	if !inner.IsAttached() {
		t.Fatal("ep1.IsAttached() = false, want = true")
	}
	if !bep.IsAttached() {
		t.Fatal("bep1.IsAttached() = false, want = true")
	}

	// While attached, deliveries reach the dispatcher.
	deliver()
	if dispatcher.count != 1 {
		t.Fatalf("got disp.count = %d, want = 1", dispatcher.count)
	}

	// Detaching clears attachment on both layers.
	bep.Attach(nil)
	if inner.IsAttached() {
		t.Fatal("ep1.IsAttached() = true, want = false")
	}
	if bep.IsAttached() {
		t.Fatal("bep1.IsAttached() = true, want = false")
	}

	// After detaching, deliveries are dropped instead of panicking.
	deliver()
	if dispatcher.count != 1 {
		t.Fatalf("got disp.count = %d, want = 1", dispatcher.count)
	}
}
// pipe returns two endpoints wired to each other: a packet written to one is
// delivered to the other's dispatcher. addr1 and addr2 become their respective
// link addresses.
func pipe(addr1, addr2 tcpip.LinkAddress) (*endpoint, *endpoint) {
	a := &endpoint{linkAddr: addr1}
	b := &endpoint{linkAddr: addr2, linked: a}
	a.linked = b
	return a, b
}
var _ stack.LinkEndpoint = (*endpoint)(nil)

// endpoint is a fake point-to-point link endpoint. We use our own fake because
// we'd like to report CapabilityResolutionRequired and trigger link address
// resolution.
//
// `endpoint` cannot be copied.
//
// Make endpoints using `pipe()`, not using endpoint literals.
type endpoint struct {
	// linkAddr is the address reported by LinkAddress.
	linkAddr tcpip.LinkAddress
	// dispatcher is set by Attach; packets written by the linked endpoint are
	// delivered to it.
	dispatcher stack.NetworkDispatcher
	// linked is the peer endpoint at the other end of the pipe.
	linked *endpoint
	// onWritePacket, if non-nil, is called with every packet passed to
	// WritePacket before the packet is delivered to the peer.
	onWritePacket func(*stack.PacketBuffer)
}
// WritePacket runs the onWritePacket hook (if set), then converts pkt to an
// inbound packet and hands it to the linked endpoint's dispatcher. It panics if
// the endpoint was not created with pipe or if the peer has not been attached.
func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
	peer := e.linked
	switch {
	case peer == nil:
		panic(fmt.Sprintf("ep %+v has not been linked to another endpoint; create endpoints with `pipe()`", e))
	case !peer.IsAttached():
		panic(fmt.Sprintf("ep: %+v linked endpoint: %+v has not been `Attach`ed; call stack.CreateNIC to attach it", e, peer))
	}

	if hook := e.onWritePacket; hook != nil {
		hook(pkt)
	}

	// Seen from the peer, our local address is the remote one and vice versa.
	local, remote := r.LocalLinkAddress, r.RemoteLinkAddress()
	peer.dispatcher.DeliverNetworkPacket(local, remote, protocol, packetbuffer.OutboundToInbound(pkt))
	return nil
}
// WritePackets is not supported by this fake and panics if called.
func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
	panic("not implemented")
}

// Attach records dispatcher; a nil dispatcher detaches the endpoint.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }

// IsAttached reports whether a non-nil dispatcher is currently attached.
func (e *endpoint) IsAttached() bool { return e.dispatcher != nil }

// Wait returns immediately; the fake has no background work.
func (*endpoint) Wait() {}

// MTU returns the IPv6 minimum link MTU.
//
// RFC 8200 section 5 requires every IPv6 link to carry packets of at least 1280
// octets, and RFC 791 section 3.2 requires IPv4 modules to forward 68-octet
// datagrams without fragmentation. The IPv6 minimum is the larger of the two,
// so it satisfies both protocols.
func (*endpoint) MTU() uint32 {
	const mtu = header.IPv6MinimumMTU
	return mtu
}

// Capabilities reports CapabilityResolutionRequired so that the stack performs
// link address resolution over this endpoint.
func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
	return stack.CapabilityResolutionRequired
}

// MaxHeaderLength is zero: the fake adds no link header.
func (*endpoint) MaxHeaderLength() uint16 { return 0 }

// LinkAddress returns the address assigned by pipe.
func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr }

// ARPHardwareType reports Ethernet.
func (*endpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }

// AddHeader is a no-op because the fake has no link header.
func (*endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {}
// makeStackWithEndpoint creates a stack running ARP, the given network
// protocol and TCP. It installs ep as NIC nicID and assigns addr to that NIC.
// When tests run verbosely, ep is wrapped in a sniffer.
func makeStackWithEndpoint(nicID tcpip.NICID, ep stack.LinkEndpoint, protocolFactory stack.NetworkProtocolFactory, protocolNumber tcpip.NetworkProtocolNumber, addr tcpip.Address) (*stack.Stack, error) {
	linkEP := ep
	if testing.Verbose() {
		linkEP = sniffer.New(linkEP)
	}

	opts := stack.Options{
		NetworkProtocols:   []stack.NetworkProtocolFactory{arp.NewProtocol, protocolFactory},
		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
	}
	s := stack.New(opts)

	if err := s.CreateNIC(nicID, linkEP); err != nil {
		return nil, fmt.Errorf("CreateNIC failed: %s", err)
	}
	if err := s.AddAddress(nicID, protocolNumber, addr); err != nil {
		return nil, fmt.Errorf("AddAddress failed: %s", err)
	}
	return s, nil
}
// makeStackWithBridgedEndpoints creates a stack running ARP, the given network
// protocol and TCP. Each of eps is wrapped in a bridgeable endpoint and
// installed as NIC i+1, in argument order. A bridge over all of them is then
// installed as NIC len(eps)+1 with baddr assigned. It returns the stack, the
// bridge endpoint and the bridge's NIC ID. When tests run verbosely, every link
// endpoint is wrapped in a sniffer. The caller's eps slice is never modified.
func makeStackWithBridgedEndpoints(t *testing.T, protocolFactory stack.NetworkProtocolFactory, protocolNumber tcpip.NetworkProtocolNumber, baddr tcpip.Address, eps ...stack.LinkEndpoint) (*stack.Stack, *bridge.Endpoint, tcpip.NICID) {
	t.Helper()
	stk := stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocolFactory{
			arp.NewProtocol,
			protocolFactory,
		},
		TransportProtocols: []stack.TransportProtocolFactory{
			tcp.NewProtocol,
		},
	})

	beps := make([]*bridge.BridgeableEndpoint, 0, len(eps))
	for i, ep := range eps {
		// Wrap the loop-local copy rather than eps[i]: a caller passing an
		// existing slice with `eps...` would otherwise see its elements replaced.
		if testing.Verbose() {
			ep = sniffer.New(ep)
		}
		bep := bridge.NewEndpoint(ep)
		nicID := tcpip.NICID(i + 1)
		if err := stk.CreateNIC(nicID, bep); err != nil {
			t.Fatalf("CreateNIC(%d) failed: %s", nicID, err)
		}
		beps = append(beps, bep)
	}

	bridgeEP := bridge.New(beps)
	var bridgeLinkEP stack.LinkEndpoint = bridgeEP
	if testing.Verbose() {
		bridgeLinkEP = sniffer.New(bridgeLinkEP)
	}
	bID := tcpip.NICID(len(beps) + 1)
	if err := stk.CreateNIC(bID, bridgeLinkEP); err != nil {
		t.Fatalf("CreateNIC(%d) failed: %s", bID, err)
	}
	if err := stk.AddAddress(bID, protocolNumber, baddr); err != nil {
		t.Fatalf("AddAddress failed: %s", err)
	}
	return stk, bridgeEP, bID
}
// connectAndWrite opens a TCP listener on toStack bound to addr and connects to
// it from fromStack. It writes payload over the connection and returns the bytes
// read by the accepted endpoint. Every endpoint it creates is closed before it
// returns. protocolFactory is currently unused.
func connectAndWrite(fromStack *stack.Stack, toStack *stack.Stack, protocolFactory stack.NetworkProtocolFactory, protocolNumber tcpip.NetworkProtocolNumber, addr tcpip.Address, payload string) ([]byte, error) {
	senderWaitQueue := new(waiter.Queue)
	sender, err := fromStack.NewEndpoint(tcp.ProtocolNumber, protocolNumber, senderWaitQueue)
	if err != nil {
		return nil, fmt.Errorf("NewEndpoint failed: %s", err)
	}
	defer sender.Close()

	receiverWaitQueue := new(waiter.Queue)
	receiver, err := toStack.NewEndpoint(tcp.ProtocolNumber, protocolNumber, receiverWaitQueue)
	if err != nil {
		return nil, fmt.Errorf("NewEndpoint failed: %s", err)
	}
	defer receiver.Close()

	if err := receiver.Bind(tcpip.FullAddress{Addr: addr}); err != nil {
		return nil, fmt.Errorf("bind failed: %s", err)
	}
	if err := receiver.Listen(1); err != nil {
		return nil, fmt.Errorf("listen failed: %s", err)
	}

	// Connect to the address the listener actually bound, letting the sending
	// stack pick the outgoing NIC from its route table.
	localAddr, err := receiver.GetLocalAddress()
	if err != nil {
		return nil, fmt.Errorf("getlocaladdress failed: %s", err)
	}
	localAddr.NIC = 0
	if err := connect(sender, localAddr, senderWaitQueue, receiverWaitQueue); err != nil {
		return nil, fmt.Errorf("connect failed: %s\n\n%+v\n\n%+v", err, fromStack.Stats(), toStack.Stats())
	}

	ep, wq, err := receiver.Accept(nil)
	if err != nil {
		return nil, fmt.Errorf("accept failed: %s", err)
	}
	// The accepted endpoint belongs to us; close it so it does not outlive the call.
	defer ep.Close()

	if err := write(sender, localAddr, payload, wq); err != nil {
		return nil, err
	}
	recvd, _, err := ep.Read(nil)
	if err != nil {
		return nil, fmt.Errorf("read failed: %s", err)
	}
	return recvd, nil
}
// write sends payload from sender to s2fulladdr. It then waits up to one second
// for wq (the receiving endpoint's queue) to signal EventIn, returning
// timeoutPayloadReceived if that never happens.
func write(sender tcpip.Endpoint, s2fulladdr tcpip.FullAddress, payload string, wq *waiter.Queue) error {
	entry, readable := waiter.NewChannelEntry(nil)
	wq.EventRegister(&entry, waiter.EventIn)
	defer wq.EventUnregister(&entry)

	opts := tcpip.WriteOptions{To: &s2fulladdr}
	if _, _, err := sender.Write(tcpip.SlicePayload(payload), opts); err != nil {
		return fmt.Errorf("write failed: %s", err)
	}

	select {
	case <-readable:
		return nil
	case <-time.After(1 * time.Second):
		return timeoutPayloadReceived
	}
}
// connect starts a TCP connection from sender to addr. It then waits, one second
// at a time, first for the sender to become writable (EventOut) and then for
// the listener to have a connection ready (EventIn). It returns
// timeoutSendReady or timeoutReceiveReady if the corresponding wait expires.
func connect(sender tcpip.Endpoint, addr tcpip.FullAddress, senderWaitQueue, receiverWaitQueue *waiter.Queue) error {
	sendEntry, sendReady := waiter.NewChannelEntry(nil)
	senderWaitQueue.EventRegister(&sendEntry, waiter.EventOut)
	defer senderWaitQueue.EventUnregister(&sendEntry)

	recvEntry, recvReady := waiter.NewChannelEntry(nil)
	receiverWaitQueue.EventRegister(&recvEntry, waiter.EventIn)
	defer receiverWaitQueue.EventUnregister(&recvEntry)

	// A TCP connect is asynchronous; anything other than "started" is a failure.
	if err := sender.Connect(addr); err != tcpip.ErrConnectStarted {
		return fmt.Errorf("connect failed: %s", err)
	}

	for _, step := range []struct {
		ready      <-chan struct{}
		timeoutErr error
	}{
		{ready: sendReady, timeoutErr: timeoutSendReady},
		{ready: recvReady, timeoutErr: timeoutReceiveReady},
	} {
		select {
		case <-step.ready:
		case <-time.After(1 * time.Second):
			return step.timeoutErr
		}
	}
	return nil
}