| // Copyright 2022 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| package udp_serde |
| |
| import ( |
| "errors" |
| "fmt" |
| "math" |
| "testing" |
| "time" |
| |
| "go.fuchsia.dev/fuchsia/src/connectivity/network/netstack/util" |
| "gvisor.dev/gvisor/pkg/tcpip" |
| "gvisor.dev/gvisor/pkg/tcpip/header" |
| "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" |
| ) |
| |
| const preludeOffset = 8 |
| |
| const ( |
| testPort uint16 = 42 |
| testIpTtl uint8 = 43 |
| testIpv6Hoplimit uint8 = 44 |
| testNICID tcpip.NICID = 45 |
| testIpTos uint8 = 46 |
| testIpv6Tclass uint32 = 47 |
| testTimestampNanos int64 = 48 |
| testPayloadSize int = 49 |
| invalidIpTtl uint8 = 0 |
| ) |
| |
| var ( |
| ipv4Loopback = util.Parse("127.0.0.1") |
| ipv6Loopback = util.Parse("::1") |
| ipv6LinkLocal = util.Parse("fe80::1") |
| ) |
| |
| func TestSerializeThenDeserializeSendMsgMeta(t *testing.T) { |
| for _, netProto := range []tcpip.NetworkProtocolNumber{ |
| header.IPv4ProtocolNumber, |
| header.IPv6ProtocolNumber, |
| } { |
| t.Run(fmt.Sprintf("%d", netProto), func(t *testing.T) { |
| buf := make([]byte, TxUdpPreludeSize()) |
| addr := tcpip.FullAddress{ |
| Port: testPort, |
| } |
| var cmsgSet tcpip.SendableControlMessages |
| |
| switch netProto { |
| case header.IPv4ProtocolNumber: |
| addr.Addr = ipv4Loopback |
| cmsgSet.HasTTL = true |
| cmsgSet.TTL = testIpTtl |
| case header.IPv6ProtocolNumber: |
| addr.Addr = ipv6Loopback |
| cmsgSet.HasHopLimit = true |
| cmsgSet.HopLimit = testIpv6Hoplimit |
| cmsgSet.HasIPv6PacketInfo = true |
| cmsgSet.IPv6PacketInfo = tcpip.IPv6PacketInfo{ |
| NIC: testNICID, |
| Addr: ipv6Loopback, |
| } |
| addr.NIC = testNICID |
| } |
| |
| if err := SerializeSendMsgMeta(tcpip.NetworkProtocolNumber(netProto), addr, cmsgSet, buf); err != nil { |
| t.Fatalf("got SerializeSendMsgMeta(%d, %#v, %#v, _) = (%#v), want (%#v)", netProto, addr, cmsgSet, err, nil) |
| } |
| |
| deserializedAddr, deserializedCmsgSet, err := DeserializeSendMsgMeta(buf) |
| |
| if err != nil { |
| t.Fatalf("expect DeserializeSendMsgMeta(_) succeeds, got: %s", err) |
| } |
| |
| wantAddr := tcpip.FullAddress{ |
| Port: addr.Port, |
| Addr: addr.Addr, |
| // Expect the NICID set in the IPv6 case above is not serialized because the address is non-link local. |
| NIC: 0, |
| } |
| |
| if got, want := *deserializedAddr, wantAddr; got != want { |
| t.Errorf("got address after serde = (%#v), want (%#v)", got, want) |
| } |
| |
| if got, want := deserializedCmsgSet, cmsgSet; got != want { |
| t.Errorf("got cmsg set after serde = (%#v), want (%#v)", got, want) |
| } |
| }) |
| } |
| } |
| |
| func TestSerializeThenDeserializeSendMsgMetaWithLinkLocalIPv6Addr(t *testing.T) { |
| buf := make([]byte, TxUdpPreludeSize()) |
| addr := tcpip.FullAddress{ |
| Port: testPort, |
| Addr: ipv6LinkLocal, |
| NIC: testNICID, |
| } |
| var cmsgSet tcpip.SendableControlMessages |
| |
| if err := SerializeSendMsgMeta(ipv6.ProtocolNumber, addr, cmsgSet, buf); err != nil { |
| t.Fatalf("got SerializeSendMsgMeta(%d, %#v, %#v, _) = (%#v), want (%#v)", ipv6.ProtocolNumber, addr, cmsgSet, err, nil) |
| } |
| |
| deserializedAddr, _, err := DeserializeSendMsgMeta(buf) |
| |
| if err != nil { |
| t.Fatalf("expect DeserializeSendMsgMeta(_) succeeds, got: %s", err) |
| } |
| |
| if got, want := *deserializedAddr, addr; got != want { |
| t.Errorf("got address after serde = (%#v), want (%#v)", got, want) |
| } |
| } |
| |
| func TestSerializeThenDeserializeSendMsgMetaWithUnspecifiedAddrs(t *testing.T) { |
| for _, netProto := range []tcpip.NetworkProtocolNumber{ |
| header.IPv4ProtocolNumber, |
| header.IPv6ProtocolNumber, |
| } { |
| t.Run(fmt.Sprintf("%d", netProto), func(t *testing.T) { |
| buf := make([]byte, TxUdpPreludeSize()) |
| addr := tcpip.FullAddress{ |
| Port: testPort, |
| } |
| cmsgSet := tcpip.SendableControlMessages{} |
| if netProto == header.IPv6ProtocolNumber { |
| cmsgSet.HasIPv6PacketInfo = true |
| cmsgSet.IPv6PacketInfo = tcpip.IPv6PacketInfo{ |
| NIC: testNICID, |
| } |
| } |
| |
| if err := SerializeSendMsgMeta(netProto, addr, cmsgSet, buf); err != nil { |
| t.Fatalf("got SerializeSendMsgMeta(%d, %#v, %#v, _) = (%#v), want (%#v)", netProto, addr, cmsgSet, err, nil) |
| } |
| |
| deserializedAddr, deserializedCmsg, err := DeserializeSendMsgMeta(buf) |
| |
| if err != nil { |
| t.Fatalf("expect DeserializeSendMsgMeta(_) succeeds, got: %s", err) |
| } |
| if got, want := *deserializedAddr, addr; got != want { |
| t.Errorf("got address after serde = (%#v), want (%#v)", got, want) |
| } |
| if got, want := deserializedCmsg, cmsgSet; got != want { |
| t.Errorf("got cmsg after serde = (%#v), want (%#v)", got, want) |
| } |
| }) |
| } |
| } |
| |
| func TestSerializeSendMsgMetaFailures(t *testing.T) { |
| for _, testCase := range []struct { |
| name string |
| getBuffer func([]byte) []byte |
| expectedErr error |
| }{ |
| {"nil buffer", func(buf []byte) []byte { return nil }, &InputBufferNullErr{}}, |
| {"buffer too small", func(buf []byte) []byte { return buf[:preludeOffset-1] }, &InputBufferTooSmallErr{}}, |
| } { |
| |
| t.Run(fmt.Sprintf("%s", testCase.name), func(t *testing.T) { |
| storage := make([]byte, TxUdpPreludeSize()) |
| buf := testCase.getBuffer(storage) |
| |
| addr := tcpip.FullAddress{ |
| Port: testPort, |
| Addr: ipv4Loopback, |
| } |
| cmsgSet := tcpip.SendableControlMessages{} |
| |
| err := SerializeSendMsgMeta(header.IPv4ProtocolNumber, addr, cmsgSet, buf) |
| |
| if got, want := err, testCase.expectedErr; !errors.Is(err, testCase.expectedErr) { |
| t.Errorf("got SerializeSendMsgMeta(%d, %#v, %#v, _) = (%#v), want (%#v)", header.IPv4ProtocolNumber, addr, cmsgSet, got, want) |
| } |
| }) |
| } |
| } |
| |
| func TestDeserializeSendMsgMetaFailures(t *testing.T) { |
| type DeserializeSendMsgMetaErrorCondition int |
| |
| const ( |
| DeserializeSendMsgMetaErrInputBufferNil DeserializeSendMsgMetaErrorCondition = iota |
| DeserializeSendMsgMetaErrInputBufferTooSmall |
| DeserializeSendMsgMetaErrFailedToDecode |
| ) |
| for _, testCase := range []struct { |
| name string |
| errCondition DeserializeSendMsgMetaErrorCondition |
| expectedErr error |
| }{ |
| {"nil buffer", DeserializeSendMsgMetaErrInputBufferNil, &InputBufferNullErr{}}, |
| {"buffer too small", DeserializeSendMsgMetaErrInputBufferTooSmall, &InputBufferTooSmallErr{}}, |
| {"failed to decode", DeserializeSendMsgMetaErrFailedToDecode, &FailedToDecodeErr{}}, |
| } { |
| |
| t.Run(fmt.Sprintf("%s", testCase.name), func(t *testing.T) { |
| buf := make([]byte, TxUdpPreludeSize()) |
| |
| switch DeserializeSendMsgMetaErrorCondition(testCase.errCondition) { |
| case DeserializeSendMsgMetaErrInputBufferNil: |
| buf = nil |
| case DeserializeSendMsgMetaErrInputBufferTooSmall: |
| buf = buf[:preludeOffset-1] |
| case DeserializeSendMsgMetaErrFailedToDecode: |
| } |
| |
| _, _, err := DeserializeSendMsgMeta(buf) |
| |
| if got, want := err, testCase.expectedErr; !errors.Is(err, testCase.expectedErr) { |
| t.Errorf("got DeserializeSendMsgMeta(_) = (_, _, %#v), want (_, _, %#v)", got, want) |
| } |
| }) |
| } |
| } |
| |
| func TestSerializeRecvMsgMetaFailures(t *testing.T) { |
| type SerializeRecvMsgMetaErrorCondition int |
| |
| const ( |
| SerializeRecvMsgMetaErrOutputBufferNil SerializeRecvMsgMetaErrorCondition = iota |
| SerializeRecvMsgMetaErrOutputBufferTooSmall |
| SerializeRecvMsgMetaErrPayloadTooLarge |
| ) |
| |
| const maxPayloadSize = int(math.MaxUint16) |
| const tooBigPayloadSize = maxPayloadSize + 1 |
| |
| for _, testCase := range []struct { |
| name string |
| errCondition SerializeRecvMsgMetaErrorCondition |
| expectedErr error |
| }{ |
| {"nil buffer", SerializeRecvMsgMetaErrOutputBufferNil, &InputBufferNullErr{}}, |
| {"buffer too small", SerializeRecvMsgMetaErrOutputBufferTooSmall, &InputBufferTooSmallErr{}}, |
| {"payload too large", SerializeRecvMsgMetaErrPayloadTooLarge, &PayloadSizeExceedsMaxAllowedErr{payloadSize: tooBigPayloadSize, maxAllowed: maxPayloadSize}}, |
| } { |
| for _, netProto := range []tcpip.NetworkProtocolNumber{ |
| header.IPv4ProtocolNumber, |
| header.IPv6ProtocolNumber, |
| } { |
| t.Run(fmt.Sprintf("%s %d", testCase.name, netProto), func(t *testing.T) { |
| res := tcpip.ReadResult{} |
| buf := make([]byte, TxUdpPreludeSize()) |
| |
| switch SerializeRecvMsgMetaErrorCondition(testCase.errCondition) { |
| case SerializeRecvMsgMetaErrOutputBufferNil: |
| buf = nil |
| case SerializeRecvMsgMetaErrOutputBufferTooSmall: |
| buf = buf[:preludeOffset-1] |
| case SerializeRecvMsgMetaErrPayloadTooLarge: |
| res.Count = tooBigPayloadSize |
| } |
| |
| err := SerializeRecvMsgMeta(tcpip.NetworkProtocolNumber(netProto), res, buf) |
| |
| if got, want := err, testCase.expectedErr; !errors.Is(err, testCase.expectedErr) { |
| t.Errorf("got SerializeRecvMsgMeta(%d, %#v, _) = (%#v), want (%#v)", netProto, res, got, want) |
| } |
| }) |
| } |
| } |
| } |
| |
| func TestSerializeThenDeserializeRecvMsgMeta(t *testing.T) { |
| for _, netProto := range []tcpip.NetworkProtocolNumber{ |
| header.IPv4ProtocolNumber, |
| header.IPv6ProtocolNumber, |
| } { |
| t.Run(fmt.Sprintf("%d", netProto), func(t *testing.T) { |
| addr := tcpip.FullAddress{ |
| Port: testPort, |
| } |
| cmsgSet := tcpip.ReceivableControlMessages{ |
| HasTimestamp: true, |
| Timestamp: time.Unix(0, testTimestampNanos), |
| } |
| |
| switch netProto { |
| case header.IPv4ProtocolNumber: |
| addr.Addr = ipv4Loopback |
| cmsgSet.HasTTL = true |
| cmsgSet.TTL = testIpTtl |
| cmsgSet.HasTOS = true |
| cmsgSet.TOS = testIpTos |
| case header.IPv6ProtocolNumber: |
| addr.Addr = ipv6Loopback |
| cmsgSet.HasHopLimit = true |
| cmsgSet.HopLimit = testIpv6Hoplimit |
| cmsgSet.HasIPv6PacketInfo = true |
| cmsgSet.IPv6PacketInfo = tcpip.IPv6PacketInfo{ |
| NIC: testNICID, |
| Addr: ipv6Loopback, |
| } |
| addr.NIC = testNICID |
| cmsgSet.HasTClass = true |
| cmsgSet.TClass = testIpv6Tclass |
| } |
| res := tcpip.ReadResult{ |
| ControlMessages: cmsgSet, |
| RemoteAddr: addr, |
| Count: testPayloadSize, |
| } |
| buf := make([]byte, RxUdpPreludeSize()) |
| |
| if err := SerializeRecvMsgMeta(tcpip.NetworkProtocolNumber(netProto), res, buf); err != nil { |
| t.Errorf("got SerializeRecvMsgMeta(%d, %#v, _) = (%#v), want (%#v)", netProto, res, err, nil) |
| } |
| |
| recvMeta, err := DeserializeRecvMsgMeta(buf) |
| |
| if err != nil { |
| t.Fatalf("expect DeserializeRecvMsgMeta(_) succeeds, got: %s", err) |
| } |
| |
| wantAddr := tcpip.FullAddress{ |
| Port: res.RemoteAddr.Port, |
| Addr: res.RemoteAddr.Addr, |
| // Expect the NICID set in the IPv6 case above is not serialized because the address is non-link local. |
| NIC: 0, |
| } |
| |
| if got, want := *recvMeta.addr, wantAddr; got != want { |
| t.Errorf("got address after serde = (%#v), want (%#v)", got, want) |
| } |
| |
| if got, want := recvMeta.control, res.ControlMessages; got != want { |
| t.Errorf("got cmsg set after serde = (%#v), want (%#v)", got, want) |
| } |
| |
| if got, want := recvMeta.payloadSize, uint16(res.Count); got != want { |
| t.Errorf("got payload size after serde = (%d), want (%d)", got, want) |
| } |
| }) |
| } |
| } |
| |
| func TestSerializeThenDeserializeRecvMsgMetaWithLinkLocalIPv6Addr(t *testing.T) { |
| buf := make([]byte, RxUdpPreludeSize()) |
| readResult := tcpip.ReadResult{ |
| RemoteAddr: tcpip.FullAddress{ |
| Port: testPort, |
| Addr: ipv6LinkLocal, |
| NIC: testNICID, |
| }, |
| ControlMessages: tcpip.ReceivableControlMessages{}, |
| } |
| |
| if err := SerializeRecvMsgMeta(ipv6.ProtocolNumber, readResult, buf); err != nil { |
| t.Fatalf("got SerializeRecvMsgMeta(%d, %#v, _) = (%#v), want (%#v)", ipv6.ProtocolNumber, readResult, err, nil) |
| } |
| |
| res, err := DeserializeRecvMsgMeta(buf) |
| |
| if err != nil { |
| t.Fatalf("expect DeserializeRecvMsgMeta(_) succeeds, got: %s", err) |
| } |
| |
| if got, want := *res.addr, readResult.RemoteAddr; got != want { |
| t.Errorf("got address after serde = (%#v), want (%#v)", got, want) |
| } |
| } |
| |
| func TestSerializeThenDeserializeRecvMsgMetaWithUnspecifiedAddrs(t *testing.T) { |
| for _, netProto := range []tcpip.NetworkProtocolNumber{ |
| header.IPv4ProtocolNumber, |
| header.IPv6ProtocolNumber, |
| } { |
| t.Run(fmt.Sprintf("%d", netProto), func(t *testing.T) { |
| buf := make([]byte, RxUdpPreludeSize()) |
| readResult := tcpip.ReadResult{ |
| RemoteAddr: tcpip.FullAddress{ |
| Port: testPort, |
| }, |
| ControlMessages: tcpip.ReceivableControlMessages{}, |
| } |
| |
| if netProto == header.IPv6ProtocolNumber { |
| readResult.ControlMessages.HasIPv6PacketInfo = true |
| readResult.ControlMessages.IPv6PacketInfo = tcpip.IPv6PacketInfo{ |
| NIC: testNICID, |
| } |
| } |
| |
| if err := SerializeRecvMsgMeta(netProto, readResult, buf); err != nil { |
| t.Fatalf("got SerializeRecvMsgMeta(%d, %#v, _) = (%#v), want (%#v)", netProto, readResult, err, nil) |
| } |
| |
| res, err := DeserializeRecvMsgMeta(buf) |
| |
| if err != nil { |
| t.Fatalf("expect DeserializeRecvMsgMeta(_) succeeds, got: %s", err) |
| } |
| if got, want := *res.addr, readResult.RemoteAddr; got != want { |
| t.Errorf("got address after serde = (%#v), want (%#v)", got, want) |
| } |
| if got, want := res.control, readResult.ControlMessages; got != want { |
| t.Errorf("got cmsg after serde = (%#v), want (%#v)", got, want) |
| } |
| }) |
| } |
| } |
| |
| func TestDeserializeRecvMsgMetaFailures(t *testing.T) { |
| type DeserializeRecvMsgMetaErrorCondition int |
| |
| const ( |
| DeserializeRecvMsgMetaErrInputBufferNil DeserializeRecvMsgMetaErrorCondition = iota |
| DeserializeRecvMsgMetaErrInputBufferTooSmall |
| DeserializeRecvMsgMetaErrFailedToDecode |
| ) |
| for _, testCase := range []struct { |
| name string |
| errCondition DeserializeRecvMsgMetaErrorCondition |
| expectedErr error |
| }{ |
| {"nil buffer", DeserializeRecvMsgMetaErrInputBufferNil, &InputBufferNullErr{}}, |
| {"buffer too small", DeserializeRecvMsgMetaErrInputBufferTooSmall, &UnspecifiedDecodingFailure{}}, |
| {"failed to decode", DeserializeRecvMsgMetaErrFailedToDecode, &UnspecifiedDecodingFailure{}}, |
| } { |
| |
| t.Run(fmt.Sprintf("%s", testCase.name), func(t *testing.T) { |
| buf := make([]byte, TxUdpPreludeSize()) |
| |
| switch DeserializeRecvMsgMetaErrorCondition(testCase.errCondition) { |
| case DeserializeRecvMsgMetaErrInputBufferNil: |
| buf = nil |
| case DeserializeRecvMsgMetaErrInputBufferTooSmall: |
| buf = buf[:preludeOffset-1] |
| case DeserializeRecvMsgMetaErrFailedToDecode: |
| } |
| |
| _, err := DeserializeRecvMsgMeta(buf) |
| |
| if got, want := err, testCase.expectedErr; !errors.Is(err, testCase.expectedErr) { |
| t.Errorf("got DeserializeRecvMsgMeta(_) = (_, _, %#v), want (_, _, %#v)", got, want) |
| } |
| }) |
| } |
| } |