| // Copyright 2016 The Netstack Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // Package checker provides helper functions to check networking packets for |
| // validity. |
| package checker |
| |
| import ( |
| "testing" |
| |
| "github.com/google/netstack/tcpip" |
| "github.com/google/netstack/tcpip/header" |
| ) |
| |
| // NetworkChecker is a function to check a property of a network packet. |
| type NetworkChecker func(*testing.T, header.Network) |
| |
| // TransportChecker is a function to check a property of a transport packet. |
| type TransportChecker func(*testing.T, header.Transport) |
| |
// IPv4 validates b as an IPv4 packet (header sanity via IsValid plus the header
// checksum) and then runs each supplied network checker against it, in order.
// It is meant to be combined with property checkers; for example, to verify
// both endpoints of a packet:
//
//	checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
	pkt := header.IPv4(b)
	if valid := pkt.IsValid(len(b)); !valid {
		t.Fatalf("Not a valid IPv4 packet")
	}

	// Only a ones'-complement sum of 0 or 0xffff over the header is accepted.
	switch sum := pkt.CalculateChecksum(); sum {
	case 0, 0xffff:
	default:
		t.Fatalf("Bad checksum: 0x%x, checksum in packet: 0x%x", sum, pkt.Checksum())
	}

	for i := range checkers {
		checkers[i](t, pkt)
	}
}
| |
// IPv6 checks the validity and properties of the given IPv6 packet. The usage
// is similar to IPv4. Unlike IPv4, the IPv6 header carries no checksum, so
// only IsValid is consulted before the checkers run, in order.
func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
	ipv6 := header.IPv6(b)
	if !ipv6.IsValid(len(b)) {
		t.Fatalf("Not a valid IPv6 packet")
	}

	for _, f := range checkers {
		f(t, ipv6)
	}
}
| |
// SrcAddr returns a NetworkChecker that fails the test unless the packet's
// source address equals addr.
func SrcAddr(addr tcpip.Address) NetworkChecker {
	check := func(t *testing.T, h header.Network) {
		got := h.SourceAddress()
		if got == addr {
			return
		}
		t.Fatalf("Bad source address, got %v, want %v", got, addr)
	}
	return check
}
| |
// DstAddr returns a NetworkChecker that fails the test unless the packet's
// destination address equals addr.
func DstAddr(addr tcpip.Address) NetworkChecker {
	check := func(t *testing.T, h header.Network) {
		got := h.DestinationAddress()
		if got == addr {
			return
		}
		t.Fatalf("Bad destination address, got %v, want %v", got, addr)
	}
	return check
}
| |
// PayloadLen returns a NetworkChecker that requires the network-layer payload
// to be exactly plen bytes long.
func PayloadLen(plen int) NetworkChecker {
	return func(t *testing.T, h header.Network) {
		n := len(h.Payload())
		if n != plen {
			t.Fatalf("Bad payload length, got %v, want %v", n, plen)
		}
	}
}
| |
| // FragmentOffset creates a checker that checks the FragmentOffset field. |
| func FragmentOffset(offset uint16) NetworkChecker { |
| return func(t *testing.T, h header.Network) { |
| // We only do this of IPv4 for now. |
| switch ip := h.(type) { |
| case header.IPv4: |
| if v := ip.FragmentOffset(); v != offset { |
| t.Fatalf("Bad fragment offset, got %v, want %v", v, offset) |
| } |
| } |
| } |
| } |
| |
// FragmentFlags creates a checker that checks the fragment flags field.
// Only IPv4 packets are inspected; any other network header passes unchecked.
func FragmentFlags(flags uint8) NetworkChecker {
	return func(t *testing.T, h header.Network) {
		// We only do this for IPv4 for now.
		switch ip := h.(type) {
		case header.IPv4:
			if v := ip.Flags(); v != flags {
				// Report the field actually being compared (flags, not offset).
				t.Fatalf("Bad fragment flags, got %v, want %v", v, flags)
			}
		}
	}
}
| |
// TOS returns a NetworkChecker asserting that the two values reported by
// h.TOS() equal tos and label respectively.
func TOS(tos uint8, label uint32) NetworkChecker {
	return func(t *testing.T, h header.Network) {
		gotTOS, gotLabel := h.TOS()
		if gotTOS == tos && gotLabel == label {
			return
		}
		t.Fatalf("Bad TOS, got (%v, %v), want (%v,%v)", gotTOS, gotLabel, tos, label)
	}
}
| |
// TCP returns a NetworkChecker that requires the transport protocol to be TCP
// and verifies the segment checksum, computed over the pseudo-header (source
// address, destination address, zero byte + protocol, 16-bit segment length)
// followed by the whole segment. Afterwards each supplied transport checker is
// run on the segment, in order.
func TCP(checkers ...TransportChecker) NetworkChecker {
	return func(t *testing.T, h header.Network) {
		proto := h.TransportProtocol()
		if proto != header.TCPProtocolNumber {
			t.Fatalf("Bad protocol, got %v, want %v", proto, header.TCPProtocolNumber)
		}

		seg := header.TCP(h.Payload())
		segLen := uint16(len(seg))

		// Fold every pseudo-header piece, then the segment, into one sum.
		sum := header.Checksum([]byte(h.SourceAddress()), 0)
		for _, chunk := range [][]byte{
			[]byte(h.DestinationAddress()),
			{0, byte(proto)},
			{byte(segLen >> 8), byte(segLen)},
			seg,
		} {
			sum = header.Checksum(chunk, sum)
		}

		switch sum {
		case 0, 0xffff:
		default:
			t.Fatalf("Bad checksum: 0x%x, checksum in segment: 0x%x", sum, seg.Checksum())
		}

		for i := range checkers {
			checkers[i](t, seg)
		}
	}
}
| |
// UDP returns a NetworkChecker that requires the transport protocol to be UDP
// and then hands the payload, viewed as a UDP header, to each supplied
// transport checker in order. No UDP checksum verification is performed.
func UDP(checkers ...TransportChecker) NetworkChecker {
	return func(t *testing.T, h header.Network) {
		proto := h.TransportProtocol()
		if proto != header.UDPProtocolNumber {
			t.Fatalf("Bad protocol, got %v, want %v", proto, header.UDPProtocolNumber)
		}

		dgram := header.UDP(h.Payload())
		for i := range checkers {
			checkers[i](t, dgram)
		}
	}
}
| |
// SrcPort returns a TransportChecker that fails the test unless the header's
// source port equals port.
func SrcPort(port uint16) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		got := h.SourcePort()
		if got == port {
			return
		}
		t.Fatalf("Bad source port, got %v, want %v", got, port)
	}
}
| |
// DstPort returns a TransportChecker that fails the test unless the header's
// destination port equals port.
func DstPort(port uint16) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		got := h.DestinationPort()
		if got == port {
			return
		}
		t.Fatalf("Bad destination port, got %v, want %v", got, port)
	}
}
| |
// SeqNum returns a TransportChecker that, for TCP segments, requires the
// sequence number to equal seq. Non-TCP transport headers are not checked.
func SeqNum(seq uint32) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		if tcp, isTCP := h.(header.TCP); isTCP {
			if got := tcp.SequenceNumber(); got != seq {
				t.Fatalf("Bad sequence number, got %v, want %v", got, seq)
			}
		}
	}
}
| |
// AckNum returns a TransportChecker that, for TCP segments, requires the
// acknowledgement number to equal seq. Non-TCP transport headers are not
// checked.
func AckNum(seq uint32) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		if tcp, isTCP := h.(header.TCP); isTCP {
			if got := tcp.AckNumber(); got != seq {
				t.Fatalf("Bad ack number, got %v, want %v", got, seq)
			}
		}
	}
}
| |
// Window returns a TransportChecker that, for TCP segments, requires the
// window field (as returned by WindowSize) to equal window. Non-TCP transport
// headers are not checked.
func Window(window uint16) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		if tcp, isTCP := h.(header.TCP); isTCP {
			if got := tcp.WindowSize(); got != window {
				t.Fatalf("Bad window, got 0x%x, want 0x%x", got, window)
			}
		}
	}
}
| |
// TCPFlags returns a TransportChecker that, for TCP segments, requires the
// flags byte to equal flags exactly. Non-TCP transport headers are not checked.
func TCPFlags(flags uint8) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		if tcp, isTCP := h.(header.TCP); isTCP {
			if got := tcp.Flags(); got != flags {
				t.Fatalf("Bad flags, got 0x%x, want 0x%x", got, flags)
			}
		}
	}
}
| |
// TCPFlagsMatch returns a TransportChecker that, for TCP segments, compares
// only the flag bits selected by mask: the segment's flags and the wanted
// flags must agree on every bit set in mask. Non-TCP headers are not checked.
func TCPFlagsMatch(flags, mask uint8) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		if tcp, isTCP := h.(header.TCP); isTCP {
			got := tcp.Flags()
			if want := flags & mask; (got & mask) != want {
				t.Fatalf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", got, flags, mask)
			}
		}
	}
}
| |
// TCPSynOptions creates a checker that checks the presence of TCP options in
// SYN segments. Non-TCP transport headers are not checked.
//
// The MSS option must be present and carry mss. If wndscale is negative, the
// window scale option must not be present; otherwise it must be present and
// carry wndscale as its shift count.
//
// Malformed option areas are reported as test failures rather than causing an
// index-out-of-range panic or an endless loop. This covers a data offset
// outside the segment and an option that runs past the end of the options. It
// also covers an unknown option whose length byte is smaller than 2.
func TCPSynOptions(mss uint16, wndscale int) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		tcp, ok := h.(header.TCP)
		if !ok {
			return
		}

		// The options live between the fixed header and the data offset.
		offset := int(tcp.DataOffset())
		if offset < header.TCPMinimumSize || offset > len(tcp) {
			t.Fatalf("Bad data offset: got %v, segment length %v", offset, len(tcp))
		}
		opts := []byte(tcp[header.TCPMinimumSize:offset])
		limit := len(opts)
		foundMSS := false
		foundWS := false
		for i := 0; i < limit; {
			switch opts[i] {
			case header.TCPOptionEOL:
				// End of option list; anything after it is padding.
				i = limit
			case header.TCPOptionNOP:
				i++
			case header.TCPOptionMSS:
				// kind(1) + length(1) + 16-bit big-endian MSS value.
				if i+4 > limit {
					t.Fatalf("Truncated MSS option at offset %v. Options: %x", i, opts)
				}
				v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
				if mss != v {
					t.Fatalf("Bad MSS: got %v, want %v", v, mss)
				}
				foundMSS = true
				i += 4
			case header.TCPOptionWS:
				if wndscale < 0 {
					t.Fatalf("WS present when it shouldn't be")
				}
				// kind(1) + length(1) + shift count(1).
				if i+3 > limit {
					t.Fatalf("Truncated WS option at offset %v. Options: %x", i, opts)
				}
				v := int(opts[i+2])
				if v != wndscale {
					t.Fatalf("Bad WS: got %v, want %v", v, wndscale)
				}
				foundWS = true
				i += 3
			default:
				// Unknown option: skip it using its length byte, which counts
				// the kind and length bytes themselves, so it must be >= 2.
				if i+2 > limit {
					t.Fatalf("Truncated option of kind %v at offset %v. Options: %x", opts[i], i, opts)
				}
				l := int(opts[i+1])
				if l < 2 || i+l > limit {
					t.Fatalf("Bad length %v for option of kind %v at offset %v. Options: %x", l, opts[i], i, opts)
				}
				i += l
			}
		}

		if !foundMSS {
			t.Fatalf("MSS option not found. Options: %x", opts)
		}

		if !foundWS && wndscale >= 0 {
			t.Fatalf("WS option not found. Options: %x", opts)
		}
	}
}