| // Copyright 2018 The gVisor Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package stack_test |
| |
| import ( |
| "math" |
| "math/rand" |
| "testing" |
| |
| "github.com/google/netstack/tcpip" |
| "github.com/google/netstack/tcpip/buffer" |
| "github.com/google/netstack/tcpip/header" |
| "github.com/google/netstack/tcpip/link/channel" |
| "github.com/google/netstack/tcpip/network/ipv4" |
| "github.com/google/netstack/tcpip/network/ipv6" |
| "github.com/google/netstack/tcpip/stack" |
| "github.com/google/netstack/tcpip/transport/udp" |
| "github.com/google/netstack/waiter" |
| ) |
| |
// Addresses and ports shared by the tests in this file.
const (
	// stackV6Addr is the IPv6 address (::1) added to every NIC created by
	// newDualTestContextMultiNic.
	stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00" + "\x00\x00\x00\x00\x00\x00\x00\x01"
	// testV6Addr is the IPv6 address (::2) used as the source of injected
	// packets.
	testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00" + "\x00\x00\x00\x00\x00\x00\x00\x02"

	// stackAddr is the IPv4 address (10.0.0.1) added to every NIC created by
	// newDualTestContextMultiNic.
	stackAddr = "\x0a\x00\x00\x01"
	// stackPort is the port the endpoints under test bind to and the
	// destination port of injected packets.
	stackPort = 1234
	// testPort is the base value from which source ports of injected packets
	// are derived.
	testPort = 4096
)
| |
| type testContext struct { |
| t *testing.T |
| linkEPs map[string]*channel.Endpoint |
| s *stack.Stack |
| |
| ep tcpip.Endpoint |
| wq waiter.Queue |
| } |
| |
| func (c *testContext) cleanup() { |
| if c.ep != nil { |
| c.ep.Close() |
| } |
| } |
| |
| func (c *testContext) createV6Endpoint(v6only bool) { |
| var err *tcpip.Error |
| c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) |
| if err != nil { |
| c.t.Fatalf("NewEndpoint failed: %v", err) |
| } |
| |
| var v tcpip.V6OnlyOption |
| if v6only { |
| v = 1 |
| } |
| if err := c.ep.SetSockOpt(v); err != nil { |
| c.t.Fatalf("SetSockOpt failed: %v", err) |
| } |
| } |
| |
| // newDualTestContextMultiNic creates the testing context and also linkEpNames |
| // named NICs. |
| func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { |
| s := stack.New(stack.Options{ |
| NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, |
| TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) |
| linkEPs := make(map[string]*channel.Endpoint) |
| for i, linkEpName := range linkEpNames { |
| channelEP := channel.New(256, mtu, "") |
| nicid := tcpip.NICID(i + 1) |
| if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil { |
| t.Fatalf("CreateNIC failed: %v", err) |
| } |
| linkEPs[linkEpName] = channelEP |
| |
| if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil { |
| t.Fatalf("AddAddress IPv4 failed: %v", err) |
| } |
| |
| if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil { |
| t.Fatalf("AddAddress IPv6 failed: %v", err) |
| } |
| } |
| |
| s.SetRouteTable([]tcpip.Route{ |
| { |
| Destination: header.IPv4EmptySubnet, |
| NIC: 1, |
| }, |
| { |
| Destination: header.IPv6EmptySubnet, |
| NIC: 1, |
| }, |
| }) |
| |
| return &testContext{ |
| t: t, |
| s: s, |
| linkEPs: linkEPs, |
| } |
| } |
| |
// headers holds the UDP source and destination ports written into an injected
// test packet.
type headers struct {
	srcPort, dstPort uint16
}
| |
// newPayload returns a slice of between 30 and 129 bytes filled with
// pseudo-random values drawn from the math/rand default source.
func newPayload() []byte {
	size := 30 + rand.Intn(100)
	payload := make([]byte, size)
	for idx := 0; idx < size; idx++ {
		payload[idx] = byte(rand.Intn(256))
	}
	return payload
}
| |
| func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { |
| // Allocate a buffer for data and headers. |
| buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) |
| copy(buf[len(buf)-len(payload):], payload) |
| |
| // Initialize the IP header. |
| ip := header.IPv6(buf) |
| ip.Encode(&header.IPv6Fields{ |
| PayloadLength: uint16(header.UDPMinimumSize + len(payload)), |
| NextHeader: uint8(udp.ProtocolNumber), |
| HopLimit: 65, |
| SrcAddr: testV6Addr, |
| DstAddr: stackV6Addr, |
| }) |
| |
| // Initialize the UDP header. |
| u := header.UDP(buf[header.IPv6MinimumSize:]) |
| u.Encode(&header.UDPFields{ |
| SrcPort: h.srcPort, |
| DstPort: h.dstPort, |
| Length: uint16(header.UDPMinimumSize + len(payload)), |
| }) |
| |
| // Calculate the UDP pseudo-header checksum. |
| xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) |
| |
| // Calculate the UDP checksum and set it. |
| xsum = header.Checksum(payload, xsum) |
| u.SetChecksum(^u.CalculateChecksum(xsum)) |
| |
| // Inject packet. |
| c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) |
| } |
| |
| func TestTransportDemuxerRegister(t *testing.T) { |
| for _, test := range []struct { |
| name string |
| proto tcpip.NetworkProtocolNumber |
| want *tcpip.Error |
| }{ |
| {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, |
| {"success", ipv4.ProtocolNumber, nil}, |
| } { |
| t.Run(test.name, func(t *testing.T) { |
| s := stack.New(stack.Options{ |
| NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, |
| TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) |
| if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { |
| t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) |
| } |
| }) |
| } |
| } |
| |
| // TestReuseBindToDevice injects varied packets on input devices and checks that |
| // the distribution of packets received matches expectations. |
| func TestDistribution(t *testing.T) { |
| type endpointSockopts struct { |
| reuse int |
| bindToDevice string |
| } |
| for _, test := range []struct { |
| name string |
| // endpoints will received the inject packets. |
| endpoints []endpointSockopts |
| // wantedDistribution is the wanted ratio of packets received on each |
| // endpoint for each NIC on which packets are injected. |
| wantedDistributions map[string][]float64 |
| }{ |
| { |
| "BindPortReuse", |
| // 5 endpoints that all have reuse set. |
| []endpointSockopts{ |
| endpointSockopts{1, ""}, |
| endpointSockopts{1, ""}, |
| endpointSockopts{1, ""}, |
| endpointSockopts{1, ""}, |
| endpointSockopts{1, ""}, |
| }, |
| map[string][]float64{ |
| // Injected packets on dev0 get distributed evenly. |
| "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, |
| }, |
| }, |
| { |
| "BindToDevice", |
| // 3 endpoints with various bindings. |
| []endpointSockopts{ |
| endpointSockopts{0, "dev0"}, |
| endpointSockopts{0, "dev1"}, |
| endpointSockopts{0, "dev2"}, |
| }, |
| map[string][]float64{ |
| // Injected packets on dev0 go only to the endpoint bound to dev0. |
| "dev0": []float64{1, 0, 0}, |
| // Injected packets on dev1 go only to the endpoint bound to dev1. |
| "dev1": []float64{0, 1, 0}, |
| // Injected packets on dev2 go only to the endpoint bound to dev2. |
| "dev2": []float64{0, 0, 1}, |
| }, |
| }, |
| { |
| "ReuseAndBindToDevice", |
| // 6 endpoints with various bindings. |
| []endpointSockopts{ |
| endpointSockopts{1, "dev0"}, |
| endpointSockopts{1, "dev0"}, |
| endpointSockopts{1, "dev1"}, |
| endpointSockopts{1, "dev1"}, |
| endpointSockopts{1, "dev1"}, |
| endpointSockopts{1, ""}, |
| }, |
| map[string][]float64{ |
| // Injected packets on dev0 get distributed among endpoints bound to |
| // dev0. |
| "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, |
| // Injected packets on dev1 get distributed among endpoints bound to |
| // dev1 or unbound. |
| "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, |
| // Injected packets on dev999 go only to the unbound. |
| "dev999": []float64{0, 0, 0, 0, 0, 1}, |
| }, |
| }, |
| } { |
| t.Run(test.name, func(t *testing.T) { |
| for device, wantedDistribution := range test.wantedDistributions { |
| t.Run(device, func(t *testing.T) { |
| var devices []string |
| for d := range test.wantedDistributions { |
| devices = append(devices, d) |
| } |
| c := newDualTestContextMultiNic(t, defaultMTU, devices) |
| defer c.cleanup() |
| |
| c.createV6Endpoint(false) |
| |
| eps := make(map[tcpip.Endpoint]int) |
| |
| pollChannel := make(chan tcpip.Endpoint) |
| for i, endpoint := range test.endpoints { |
| // Try to receive the data. |
| wq := waiter.Queue{} |
| we, ch := waiter.NewChannelEntry(nil) |
| wq.EventRegister(&we, waiter.EventIn) |
| defer wq.EventUnregister(&we) |
| defer close(ch) |
| |
| var err *tcpip.Error |
| ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) |
| if err != nil { |
| c.t.Fatalf("NewEndpoint failed: %v", err) |
| } |
| eps[ep] = i |
| |
| go func(ep tcpip.Endpoint) { |
| for range ch { |
| pollChannel <- ep |
| } |
| }(ep) |
| |
| defer ep.Close() |
| reusePortOption := tcpip.ReusePortOption(endpoint.reuse) |
| if err := ep.SetSockOpt(reusePortOption); err != nil { |
| c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) |
| } |
| bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) |
| if err := ep.SetSockOpt(bindToDeviceOption); err != nil { |
| c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) |
| } |
| if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { |
| t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) |
| } |
| } |
| |
| npackets := 100000 |
| nports := 10000 |
| if got, want := len(test.endpoints), len(wantedDistribution); got != want { |
| t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) |
| } |
| ports := make(map[uint16]tcpip.Endpoint) |
| stats := make(map[tcpip.Endpoint]int) |
| for i := 0; i < npackets; i++ { |
| // Send a packet. |
| port := uint16(i % nports) |
| payload := newPayload() |
| c.sendV6Packet(payload, |
| &headers{ |
| srcPort: testPort + port, |
| dstPort: stackPort}, |
| device) |
| |
| var addr tcpip.FullAddress |
| ep := <-pollChannel |
| _, _, err := ep.Read(&addr) |
| if err != nil { |
| c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) |
| } |
| stats[ep]++ |
| if i < nports { |
| ports[uint16(i)] = ep |
| } else { |
| // Check that all packets from one client are handled by the same |
| // socket. |
| if want, got := ports[port], ep; want != got { |
| t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) |
| } |
| } |
| } |
| |
| // Check that a packet distribution is as expected. |
| for ep, i := range eps { |
| wantedRatio := wantedDistribution[i] |
| wantedRecv := wantedRatio * float64(npackets) |
| actualRecv := stats[ep] |
| actualRatio := float64(stats[ep]) / float64(npackets) |
| // The deviation is less than 10%. |
| if math.Abs(actualRatio-wantedRatio) > 0.05 { |
| t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) |
| } |
| } |
| }) |
| } |
| }) |
| } |
| } |