| // Copyright 2016 The Netstack Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package ipv4_test |
| |
import (
	"context"
	"encoding/binary"
	"fmt"
	"testing"
	"time"

	"github.com/google/netstack/tcpip"
	"github.com/google/netstack/tcpip/buffer"
	"github.com/google/netstack/tcpip/header"
	"github.com/google/netstack/tcpip/link/channel"
	"github.com/google/netstack/tcpip/link/sniffer"
	"github.com/google/netstack/tcpip/network/ipv4"
	"github.com/google/netstack/tcpip/stack"
	"github.com/google/netstack/waiter"
)
| |
// stackAddr is the IPv4 address 10.0.0.1 (raw 4-byte form) that
// newTestContext assigns to NIC 1 and that TestEcho/TestEchoSequence ping.
const stackAddr = "\x0a\x00\x00\x01"
| |
// testContext holds the single-NIC stack used by the loopback ping tests.
type testContext struct {
	t *testing.T

	// linkEP is the channel endpoint backing NIC 1. Packets the stack writes
	// appear on linkEP.C; loopback feeds them back into the same endpoint.
	linkEP *channel.Endpoint

	// s is the stack under test.
	s *stack.Stack
}
| |
| func newTestContext(t *testing.T) *testContext { |
| s := stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName}) |
| |
| const defaultMTU = 65536 |
| id, linkEP := channel.New(256, defaultMTU, "") |
| if testing.Verbose() { |
| id = sniffer.New(id) |
| } |
| if err := s.CreateNIC(1, id); err != nil { |
| t.Fatalf("CreateNIC failed: %v", err) |
| } |
| |
| if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { |
| t.Fatalf("AddAddress failed: %v", err) |
| } |
| |
| s.SetRouteTable([]tcpip.Route{{ |
| Destination: "\x00\x00\x00\x00", |
| Mask: "\x00\x00\x00\x00", |
| Gateway: "", |
| NIC: 1, |
| }}) |
| |
| return &testContext{ |
| t: t, |
| s: s, |
| linkEP: linkEP, |
| } |
| } |
| |
| func (c *testContext) cleanup() { |
| close(c.linkEP.C) |
| } |
| |
| func (c *testContext) loopback() { |
| go func() { |
| for pkt := range c.linkEP.C { |
| v := make(buffer.View, len(pkt.Header)+len(pkt.Payload)) |
| copy(v, pkt.Header) |
| copy(v[len(pkt.Header):], pkt.Payload) |
| vv := v.ToVectorisedView([1]buffer.View{}) |
| c.linkEP.Inject(pkt.Proto, &vv) |
| } |
| }() |
| } |
| |
| func TestEcho(t *testing.T) { |
| c := newTestContext(t) |
| defer c.cleanup() |
| c.loopback() |
| |
| ch := make(chan ipv4.PingReply, 1) |
| p := ipv4.Pinger{ |
| Stack: c.s, |
| NICID: 1, |
| Addr: stackAddr, |
| Wait: 10 * time.Millisecond, |
| Count: 1, // one ping only |
| } |
| if err := p.Ping(context.Background(), ch); err != nil { |
| t.Fatalf("icmp.Ping failed: %v", err) |
| } |
| |
| ping := <-ch |
| if ping.Error != nil { |
| t.Errorf("bad ping response: %v", ping.Error) |
| } |
| } |
| |
| func TestEchoSequence(t *testing.T) { |
| c := newTestContext(t) |
| defer c.cleanup() |
| c.loopback() |
| |
| const numPings = 3 |
| ch := make(chan ipv4.PingReply, numPings) |
| p := ipv4.Pinger{ |
| Stack: c.s, |
| NICID: 1, |
| Addr: stackAddr, |
| Wait: 10 * time.Millisecond, |
| Count: numPings, |
| } |
| if err := p.Ping(context.Background(), ch); err != nil { |
| t.Fatalf("icmp.Ping failed: %v", err) |
| } |
| |
| for i := uint16(0); i < numPings; i++ { |
| ping := <-ch |
| if ping.Error != nil { |
| t.Errorf("i=%d bad ping response: %v", i, ping.Error) |
| } |
| if ping.SeqNumber != i { |
| t.Errorf("SeqNumber=%d, want %d", ping.SeqNumber, i) |
| } |
| } |
| } |
| |
// Addresses used by the two-NIC endpoint test (see newTestEndpointContext).
const (
	// stackAddr0 is 10.0.0.2; newTestEndpointContext assigns it to NIC 2.
	stackAddr0 = "\x0a\x00\x00\x02"
	// stackAddr1 is 10.0.0.3; newTestEndpointContext assigns it to NIC 1.
	stackAddr1 = "\x0a\x00\x00\x03"
	// linkAddr0 and linkAddr1 are the MAC addresses of the channel endpoints
	// backing NIC 1 and NIC 2 respectively.
	linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
	linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
)
| |
// testEndpointContext holds a two-NIC stack whose link endpoints are wired to
// each other by routePackets goroutines. It also holds a channel on which the
// ICMPv4 portion of every forwarded packet is published for counting.
type testEndpointContext struct {
	t *testing.T
	s *stack.Stack

	// linkEP0 and linkEP1 back NIC 1 and NIC 2 respectively.
	linkEP0 *channel.Endpoint
	linkEP1 *channel.Endpoint

	// icmpCh receives the ICMPv4 portion (bytes after the IPv4 header) of
	// each packet routePackets forwards; countPacket is the sender.
	icmpCh chan header.ICMPv4
}
| |
| func (c *testEndpointContext) cleanup() { |
| close(c.linkEP0.C) |
| close(c.linkEP1.C) |
| } |
| |
| func newTestEndpointContext(t *testing.T) *testEndpointContext { |
| c := &testEndpointContext{ |
| t: t, |
| s: stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName}), |
| icmpCh: make(chan header.ICMPv4, 10), |
| } |
| |
| const defaultMTU = 65536 |
| id0, linkEP := channel.New(256, defaultMTU, linkAddr0) |
| c.linkEP0 = linkEP |
| if testing.Verbose() { |
| id0 = sniffer.New(id0) |
| } |
| if err := c.s.CreateNIC(1, id0); err != nil { |
| t.Fatalf("CreateNIC s: %v", err) |
| } |
| id1, linkEP := channel.New(256, defaultMTU, linkAddr1) |
| c.linkEP1 = linkEP |
| if testing.Verbose() { |
| id1 = sniffer.New(id1) |
| } |
| if err := c.s.CreateNIC(2, id1); err != nil { |
| t.Fatalf("CreateNIC s: %v", err) |
| } |
| if err := c.s.AddAddress(2, ipv4.ProtocolNumber, stackAddr0); err != nil { |
| t.Fatalf("AddAddress failed: %v", err) |
| } |
| if err := c.s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil { |
| t.Fatalf("AddAddress failed: %v", err) |
| } |
| c.s.SetRouteTable([]tcpip.Route{ |
| { |
| Destination: stackAddr0, |
| Mask: "\xFF\xFF\xFF\xFF", |
| Gateway: "", |
| NIC: 1, |
| }, |
| { |
| Destination: stackAddr1, |
| Mask: "\xFF\xFF\xFF\xFF", |
| Gateway: "", |
| NIC: 2, |
| }, |
| }) |
| |
| go c.routePackets(c.linkEP0.C, c.linkEP1) |
| go c.routePackets(c.linkEP1.C, c.linkEP0) |
| return c |
| } |
| |
| func (c *testEndpointContext) countPacket(pkt channel.PacketInfo) { |
| if pkt.Proto != header.IPv4ProtocolNumber { |
| c.t.Fatalf("Received non IPV4 packet: 0x%x", pkt.Proto) |
| } |
| ipv4 := header.IPv4(pkt.Header) |
| c.icmpCh <- header.ICMPv4(append(pkt.Header[ipv4.HeaderLength():], pkt.Payload...)) |
| } |
| |
| func (c *testEndpointContext) routePackets(ch <-chan channel.PacketInfo, ep *channel.Endpoint) { |
| for pkt := range ch { |
| c.countPacket(pkt) |
| v := buffer.View(append(pkt.Header, pkt.Payload...)) |
| vs := []buffer.View{v} |
| vv := buffer.NewVectorisedView(len(v), vs) |
| ep.InjectLinkAddr(pkt.Proto, ep.LinkAddress(), &vv) |
| } |
| } |
| |
| type callbackStub struct { |
| f func(e *waiter.Entry) |
| } |
| |
| func (c *callbackStub) Callback(e *waiter.Entry) { |
| c.f(e) |
| } |
| |
| func TestEndpoints(t *testing.T) { |
| c := newTestEndpointContext(t) |
| defer c.cleanup() |
| |
| wq0 := &waiter.Queue{} |
| ep0, err := c.s.NewEndpoint(ipv4.PingProtocolNumber, ipv4.ProtocolNumber, wq0) |
| if err != nil { |
| c.t.Fatalf("NewEndpoint failed: %v", err) |
| } |
| defer ep0.Close() |
| wq1 := &waiter.Queue{} |
| ep1, err := c.s.NewEndpoint(ipv4.PingProtocolNumber, ipv4.ProtocolNumber, wq1) |
| if err != nil { |
| c.t.Fatalf("NewEndpoint failed: %v", err) |
| } |
| defer ep1.Close() |
| |
| if err := ep0.Bind(tcpip.FullAddress{NIC: 1, Addr: stackAddr0}, nil); err != nil { |
| c.t.Fatalf("Bind failed: %v", err) |
| } |
| if err := ep1.Bind(tcpip.FullAddress{NIC: 2, Addr: stackAddr1}, nil); err != nil { |
| c.t.Fatalf("Bind failed: %v", err) |
| } |
| |
| echos := 64 |
| |
| ping := func(wq *waiter.Queue, ep tcpip.Endpoint, data []byte) { |
| outPkt := make([]byte, header.ICMPv4MinimumSize+4+len(data)) |
| icmpv4 := header.ICMPv4(outPkt[:header.ICMPv4MinimumSize]) |
| icmpv4.SetType(header.ICMPv4Echo) |
| copy(outPkt[header.ICMPv4MinimumSize+4:], data) |
| binary.BigEndian.PutUint16(outPkt[header.ICMPv4MinimumSize:], 0) |
| |
| for seqno := uint16(1); seqno <= uint16(echos); seqno++ { |
| binary.BigEndian.PutUint16(outPkt[header.ICMPv4MinimumSize+2:], seqno) |
| |
| // We need to register with the waiter queue before we try writing, since |
| // the notification that the endpoint received a response may arrive immediately. |
| ready := make(chan struct{}) |
| e := waiter.Entry{Callback: &callbackStub{func(*waiter.Entry) { close(ready) }}} |
| wq.EventRegister(&e, waiter.EventIn) |
| n, err := ep.Write(buffer.View(outPkt), nil) |
| if err != nil { |
| c.t.Fatalf("Write failed: %v\n", err) |
| } else if n != uintptr(len(outPkt)) { |
| c.t.Fatalf("Write was short: %v\n", n) |
| } |
| |
| // Avoid reading until we have something to read |
| select { |
| case <-time.After(1 * time.Second): |
| c.t.Fatalf("Timed out waiting for socket to be readable") |
| case <-ready: |
| } |
| wq.EventUnregister(&e) |
| inPkt, err := ep.Read(nil) |
| if err != nil { |
| c.t.Fatalf("Read failed: %v\n", err) |
| } |
| |
| // Verify the contents of the packet we just read. |
| var icmp header.ICMPv4 = []byte(inPkt) |
| if icmp.Type() != header.ICMPv4EchoReply { |
| c.t.Fatalf("Unexpected packet type: %d", icmp.Type()) |
| } |
| inSeqno := binary.BigEndian.Uint16(inPkt[header.ICMPv4MinimumSize+2 : header.ICMPv4MinimumSize+4]) |
| if inSeqno != seqno { |
| c.t.Fatalf("Unexpected sequence number: %d", inSeqno) |
| } |
| outData := outPkt[header.ICMPv4EchoMinimumSize:] |
| inData := inPkt[header.ICMPv4EchoMinimumSize:] |
| if len(outData) != len(inData) { |
| c.t.Fatalf("Read packet of unexpected length: %d\n", len(inData)) |
| } |
| for i := range inData { |
| if inData[i] != outData[i] { |
| c.t.Fatalf("Data mismatch") |
| } |
| } |
| } |
| } |
| |
| data := []byte{0xaa, 0xab, 0xac} |
| go ping(wq0, ep0, data) |
| data = []byte{0xad, 0xae, 0xaf} |
| go ping(wq1, ep1, data) |
| |
| ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) |
| defer cancel() |
| |
| stats := make(map[header.ICMPv4Type]int) |
| for { |
| select { |
| case <-ctx.Done(): |
| t.Errorf("Timeout waiting for ICMP, got: %#+v", stats) |
| return |
| case icmp := <-c.icmpCh: |
| if icmp.Type() != header.ICMPv4Echo && icmp.Type() != header.ICMPv4EchoReply { |
| c.t.Fatalf("Unexpected type: %d", icmp.Type()) |
| } |
| stats[icmp.Type()]++ |
| if stats[icmp.Type()] > echos*2 { |
| c.t.Fatalf("Too many (%d) packets of type %d", stats[icmp.Type()], icmp.Type()) |
| } |
| if len(stats) == 2 && stats[header.ICMPv4Echo] == echos*2 && stats[header.ICMPv4EchoReply] == echos*2 { |
| return |
| } |
| } |
| } |
| } |