| // Copyright 2018 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| package filter |
| |
| import ( |
| "testing" |
| |
| "github.com/google/netstack/tcpip" |
| "github.com/google/netstack/tcpip/buffer" |
| "github.com/google/netstack/tcpip/header" |
| ) |
| |
| func TestRewritePacketICMPv4(t *testing.T) { |
| var tests = []struct { |
| packet func() (buffer.Prependable, buffer.VectorisedView) |
| newAddr tcpip.Address |
| isSource bool |
| }{ |
| { |
| func() (buffer.Prependable, buffer.VectorisedView) { |
| return icmpV4Packet([]byte("payload."), &icmpV4Params{ |
| srcAddr: "\x0a\x00\x00\x00", |
| dstAddr: "\x0a\x00\x00\x02", |
| icmpV4Type: header.ICMPv4EchoReply, |
| code: 0, |
| }) |
| }, |
| "\x0b\x00\x00\x00", |
| true, |
| }, |
| } |
| |
| for _, test := range tests { |
| hdr, payload := test.packet() |
| ipv4 := header.IPv4(hdr.View()) |
| transportHeader := ipv4[ipv4.HeaderLength():] |
| icmpv4 := header.ICMPv4(transportHeader) |
| |
| // Make sure the checksum in the original packet is correct. |
| iCksum := ipv4.CalculateChecksum() |
| if got, want := iCksum, uint16(0xffff); got != want { |
| t.Errorf("ipv4 checksum=%x, want=%x", got, want) |
| } |
| |
| if got, want := header.Checksum(icmpv4, header.Checksum(payload.ToView(), 0)), uint16(0xffff); got != want { |
| t.Errorf("icmpv4 checksum=%x, want=%x", got, want) |
| } |
| |
| rewritePacketICMPv4(test.newAddr, test.isSource, hdr, transportHeader) |
| |
| if test.isSource { |
| if got, want := ipv4.SourceAddress(), test.newAddr; got != want { |
| t.Errorf("ipv4.SourceAddress()=%v, want=%v", got, want) |
| } |
| } else { |
| if got, want := ipv4.DestinationAddress(), test.newAddr; got != want { |
| t.Errorf("ipv4.DestinationAddress()=%v, want=%v", got, want) |
| } |
| } |
| |
| // Check if the checksum in the rewritten packet is correct. |
| iCksum = ipv4.CalculateChecksum() |
| if got, want := iCksum, uint16(0xffff); got != want { |
| t.Errorf("ipv4 checksum=%x, want=%x", got, want) |
| } |
| |
| if got, want := header.Checksum(icmpv4, header.Checksum(payload.ToView(), 0)), uint16(0xffff); got != want { |
| t.Errorf("icmpv4 checksum=%x, want=%x", got, want) |
| } |
| } |
| } |
| |
| func TestRewritePacketUDPv4(t *testing.T) { |
| var tests = []struct { |
| packet func() (buffer.Prependable, buffer.VectorisedView) |
| newAddr tcpip.Address |
| newPort uint16 |
| isSource bool |
| }{ |
| { |
| func() (buffer.Prependable, buffer.VectorisedView) { |
| return udpV4Packet([]byte("payload"), &udpParams{ |
| srcAddr: "\x0a\x00\x00\x00", |
| srcPort: 100, |
| dstAddr: "\x0a\x00\x00\x02", |
| dstPort: 200, |
| }) |
| }, |
| "\x0b\x00\x00\x00", |
| 101, |
| true, |
| }, |
| { |
| func() (buffer.Prependable, buffer.VectorisedView) { |
| return udpV4Packet([]byte("payload"), &udpParams{ |
| srcAddr: "\x0a\x00\x00\x00", |
| srcPort: 100, |
| dstAddr: "\x0a\x00\x00\x02", |
| dstPort: 200, |
| noUDPChecksum: true, |
| }) |
| }, |
| "\x0b\x00\x00\x00", |
| 101, |
| false, |
| }, |
| } |
| |
| for _, test := range tests { |
| hdr, payload := test.packet() |
| ipv4 := header.IPv4(hdr.View()) |
| transportHeader := ipv4[ipv4.HeaderLength():] |
| udp := header.UDP(transportHeader) |
| |
| // Make sure the checksum in the original packet is correct. |
| iCksum := ipv4.CalculateChecksum() |
| if got, want := iCksum, uint16(0xffff); got != want { |
| t.Errorf("ipv4 checksum=%x, want=%x", got, want) |
| } |
| |
| noUDPChecksum := false |
| if udp.Checksum() == 0 { |
| noUDPChecksum = true |
| } else { |
| tCksum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, |
| ipv4.SourceAddress(), ipv4.DestinationAddress()) |
| tCksum = header.Checksum(payload.ToView(), tCksum) |
| tCksum = udp.CalculateChecksum(tCksum, uint16(len(transportHeader)+payload.Size())) |
| if got, want := tCksum, uint16(0xffff); got != want { |
| t.Errorf("udp checksum=%x, want=%x", got, want) |
| } |
| } |
| |
| rewritePacketUDPv4(test.newAddr, test.newPort, test.isSource, hdr, transportHeader) |
| |
| if test.isSource { |
| if got, want := ipv4.SourceAddress(), test.newAddr; got != want { |
| t.Errorf("ipv4.SourceAddress()=%v, want=%v", got, want) |
| } |
| if got, want := udp.SourcePort(), test.newPort; got != want { |
| t.Errorf("ipv4.SourcePort()=%v, want=%v", got, want) |
| } |
| } else { |
| if got, want := ipv4.DestinationAddress(), test.newAddr; got != want { |
| t.Errorf("ipv4.DestinationAddress()=%v, want=%v", got, want) |
| } |
| if got, want := udp.DestinationPort(), test.newPort; got != want { |
| t.Errorf("ipv4.DestinationPort()=%v, want=%v", got, want) |
| } |
| } |
| |
| // Check if the checksum in the rewritten packet is correct. |
| iCksum = ipv4.CalculateChecksum() |
| if got, want := iCksum, uint16(0xffff); got != want { |
| t.Errorf("ipv4 checksum=%x, want=%x", got, want) |
| } |
| |
| if noUDPChecksum { |
| if got, want := udp.Checksum(), uint16(0); got != want { |
| t.Errorf("udp checksum=%x, want=%x", got, want) |
| } |
| } else { |
| tCksum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, |
| ipv4.SourceAddress(), ipv4.DestinationAddress()) |
| tCksum = header.Checksum(payload.ToView(), tCksum) |
| tCksum = udp.CalculateChecksum(tCksum, uint16(len(transportHeader)+payload.Size())) |
| if got, want := tCksum, uint16(0xffff); got != want { |
| t.Errorf("udp checksum=%x, want=%x", got, want) |
| } |
| } |
| } |
| } |
| |
| func TestRewritePacketTCPv4(t *testing.T) { |
| var tests = []struct { |
| packet func() (buffer.Prependable, buffer.VectorisedView) |
| newAddr tcpip.Address |
| newPort uint16 |
| isSource bool |
| }{ |
| { |
| func() (buffer.Prependable, buffer.VectorisedView) { |
| return tcpV4Packet([]byte("payload"), &tcpParams{ |
| srcAddr: "\x0a\x00\x00\x00", |
| srcPort: 100, |
| dstAddr: "\x0a\x00\x00\x02", |
| dstPort: 200, |
| }) |
| }, |
| "\x0b\x00\x00\x00", |
| 101, |
| true, |
| }, |
| { |
| func() (buffer.Prependable, buffer.VectorisedView) { |
| return tcpV4Packet([]byte("payload"), &tcpParams{ |
| srcAddr: "\x0a\x00\x00\x00", |
| srcPort: 100, |
| dstAddr: "\x0a\x00\x00\x02", |
| dstPort: 200, |
| }) |
| }, |
| "\x0b\x00\x00\x00", |
| 101, |
| false, |
| }, |
| } |
| |
| for _, test := range tests { |
| hdr, payload := test.packet() |
| ipv4 := header.IPv4(hdr.View()) |
| transportHeader := ipv4[ipv4.HeaderLength():] |
| tcp := header.TCP(transportHeader) |
| |
| // Make sure the checksum in the original packet is correct. |
| iCksum := ipv4.CalculateChecksum() |
| if got, want := iCksum, uint16(0xffff); got != want { |
| t.Errorf("ipv4 checksum=%x, want=%x", got, want) |
| } |
| |
| tCksum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, |
| ipv4.SourceAddress(), ipv4.DestinationAddress()) |
| tCksum = header.Checksum(payload.ToView(), tCksum) |
| tCksum = tcp.CalculateChecksum(tCksum, uint16(len(transportHeader)+payload.Size())) |
| if got, want := tCksum, uint16(0xffff); got != want { |
| t.Errorf("tcp checksum=%x, want=%x", got, want) |
| } |
| |
| rewritePacketTCPv4(test.newAddr, test.newPort, test.isSource, hdr, transportHeader) |
| |
| if test.isSource { |
| if got, want := ipv4.SourceAddress(), test.newAddr; got != want { |
| t.Errorf("ipv4.SourceAddress()=%v, want=%v", got, want) |
| } |
| if got, want := tcp.SourcePort(), test.newPort; got != want { |
| t.Errorf("ipv4.SourcePort()=%v, want=%v", got, want) |
| } |
| } else { |
| if got, want := ipv4.DestinationAddress(), test.newAddr; got != want { |
| t.Errorf("ipv4.DestinationAddress()=%v, want=%v", got, want) |
| } |
| if got, want := tcp.DestinationPort(), test.newPort; got != want { |
| t.Errorf("ipv4.DestinationPort()=%v, want=%v", got, want) |
| } |
| } |
| |
| // Check if the checksum in the rewritten packet is correct. |
| iCksum = ipv4.CalculateChecksum() |
| if got, want := iCksum, uint16(0xffff); got != want { |
| t.Errorf("ipv4 checksum=%x, want=%x", got, want) |
| } |
| |
| tCksum = header.PseudoHeaderChecksum(header.TCPProtocolNumber, |
| ipv4.SourceAddress(), ipv4.DestinationAddress()) |
| tCksum = header.Checksum(payload.ToView(), tCksum) |
| tCksum = tcp.CalculateChecksum(tCksum, uint16(len(transportHeader)+payload.Size())) |
| if got, want := tCksum, uint16(0xffff); got != want { |
| t.Errorf("tcp checksum=%x, want=%x", got, want) |
| } |
| } |
| } |
| |
| func TestRewritePacketUDPv6(t *testing.T) { |
| var tests = []struct { |
| packet func() (buffer.Prependable, buffer.VectorisedView) |
| newAddr tcpip.Address |
| newPort uint16 |
| isSource bool |
| }{ |
| { |
| func() (buffer.Prependable, buffer.VectorisedView) { |
| return udpV6Packet([]byte("payload"), &udpParams{ |
| srcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", |
| srcPort: 100, |
| dstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02", |
| dstPort: 200, |
| }) |
| }, |
| "\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", |
| 101, |
| true, |
| }, |
| } |
| |
| for _, test := range tests { |
| hdr, payload := test.packet() |
| ipv6 := header.IPv6(hdr.View()) |
| transportHeader := ipv6[header.IPv6MinimumSize:] |
| udp := header.UDP(transportHeader) |
| |
| // Make sure the checksum in the original packet is correct. |
| noUDPChecksum := false |
| if udp.Checksum() == 0 { |
| noUDPChecksum = true |
| } else { |
| tCksum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, |
| ipv6.SourceAddress(), ipv6.DestinationAddress()) |
| tCksum = header.Checksum(payload.ToView(), tCksum) |
| tCksum = udp.CalculateChecksum(tCksum, uint16(len(transportHeader)+payload.Size())) |
| if got, want := tCksum, uint16(0xffff); got != want { |
| t.Errorf("udp checksum=%x, want=%x", got, want) |
| } |
| } |
| |
| rewritePacketUDPv6(test.newAddr, test.newPort, test.isSource, hdr, transportHeader) |
| |
| if test.isSource { |
| if got, want := ipv6.SourceAddress(), test.newAddr; got != want { |
| t.Errorf("ipv6.SourceAddress()=%v, want=%v", got, want) |
| } |
| if got, want := udp.SourcePort(), test.newPort; got != want { |
| t.Errorf("ipv6.SourcePort()=%v, want=%v", got, want) |
| } |
| } else { |
| if got, want := ipv6.DestinationAddress(), test.newAddr; got != want { |
| t.Errorf("ipv6.DestinationAddress()=%v, want=%v", got, want) |
| } |
| if got, want := udp.DestinationPort(), test.newPort; got != want { |
| t.Errorf("ipv6.DestinationPort()=%v, want=%v", got, want) |
| } |
| } |
| |
| // Check if the checksum in the rewritten packet is correct. |
| if noUDPChecksum { |
| if got, want := udp.Checksum(), uint16(0); got != want { |
| t.Errorf("udp checksum=%x, want=%x", got, want) |
| } |
| } else { |
| tCksum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, |
| ipv6.SourceAddress(), ipv6.DestinationAddress()) |
| tCksum = header.Checksum(payload.ToView(), tCksum) |
| tCksum = udp.CalculateChecksum(tCksum, uint16(len(transportHeader)+payload.Size())) |
| if got, want := tCksum, uint16(0xffff); got != want { |
| t.Errorf("udp checksum=%x, want=%x", got, want) |
| } |
| } |
| } |
| } |
| |
| func TestRewritePacketTCPv6(t *testing.T) { |
| var tests = []struct { |
| packet func() (buffer.Prependable, buffer.VectorisedView) |
| newAddr tcpip.Address |
| newPort uint16 |
| isSource bool |
| }{ |
| { |
| func() (buffer.Prependable, buffer.VectorisedView) { |
| return tcpV6Packet([]byte("payload"), &tcpParams{ |
| srcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", |
| srcPort: 100, |
| dstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02", |
| dstPort: 200, |
| }) |
| }, |
| "\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", |
| 101, |
| true, |
| }, |
| } |
| |
| for _, test := range tests { |
| hdr, payload := test.packet() |
| ipv6 := header.IPv6(hdr.View()) |
| transportHeader := ipv6[header.IPv6MinimumSize:] |
| tcp := header.TCP(transportHeader) |
| |
| // Make sure the checksum in the original packet is correct. |
| tCksum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, |
| ipv6.SourceAddress(), ipv6.DestinationAddress()) |
| tCksum = header.Checksum(payload.ToView(), tCksum) |
| tCksum = tcp.CalculateChecksum(tCksum, uint16(len(transportHeader)+payload.Size())) |
| if got, want := tCksum, uint16(0xffff); got != want { |
| t.Errorf("tcp checksum=%x, want=%x", got, want) |
| } |
| |
| rewritePacketTCPv6(test.newAddr, test.newPort, test.isSource, hdr, transportHeader) |
| |
| if test.isSource { |
| if got, want := ipv6.SourceAddress(), test.newAddr; got != want { |
| t.Errorf("ipv6.SourceAddress()=%v, want=%v", got, want) |
| } |
| if got, want := tcp.SourcePort(), test.newPort; got != want { |
| t.Errorf("ipv6.SourcePort()=%v, want=%v", got, want) |
| } |
| } else { |
| if got, want := ipv6.DestinationAddress(), test.newAddr; got != want { |
| t.Errorf("ipv6.DestinationAddress()=%v, want=%v", got, want) |
| } |
| if got, want := tcp.DestinationPort(), test.newPort; got != want { |
| t.Errorf("ipv6.DestinationPort()=%v, want=%v", got, want) |
| } |
| } |
| |
| // Check if the checksum in the rewritten packet is correct. |
| tCksum = header.PseudoHeaderChecksum(header.TCPProtocolNumber, |
| ipv6.SourceAddress(), ipv6.DestinationAddress()) |
| tCksum = header.Checksum(payload.ToView(), tCksum) |
| tCksum = tcp.CalculateChecksum(tCksum, uint16(len(transportHeader)+payload.Size())) |
| if got, want := tCksum, uint16(0xffff); got != want { |
| t.Errorf("tcp checksum=%x, want=%x", got, want) |
| } |
| } |
| } |