| // Copyright 2020 The gVisor Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package ipv6_test |
| |
| import ( |
| "bytes" |
| "testing" |
| "time" |
| |
| "gvisor.dev/gvisor/pkg/tcpip" |
| "gvisor.dev/gvisor/pkg/tcpip/buffer" |
| "gvisor.dev/gvisor/pkg/tcpip/checker" |
| "gvisor.dev/gvisor/pkg/tcpip/faketime" |
| "gvisor.dev/gvisor/pkg/tcpip/header" |
| "gvisor.dev/gvisor/pkg/tcpip/link/channel" |
| "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" |
| "gvisor.dev/gvisor/pkg/tcpip/stack" |
| ) |
| |
// Addresses shared by the MLD tests in this file, written as raw 16-byte
// strings.
const (
	// linkLocalAddr is the link-local unicast address fe80::1.
	linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"

	// globalAddr is the unicast address a80::1; the tests use it as an
	// address that is not link-local.
	globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"

	// globalMulticastAddr is the multicast group ff05:100::2 that the tests
	// join explicitly.
	globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
)
| |
| var ( |
| linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) |
| globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) |
| ) |
| |
| func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) { |
| t.Helper() |
| |
| checker.IPv6WithExtHdr(t, p, |
| checker.IPv6ExtHdr( |
| checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), |
| ), |
| checker.SrcAddr(localAddress), |
| checker.DstAddr(remoteAddress), |
| checker.TTL(header.MLDHopLimit), |
| checker.MLD(mldType, header.MLDMinimumSize, |
| checker.MLDMaxRespDelay(0), |
| checker.MLDMulticastAddress(groupAddress), |
| ), |
| ) |
| } |
| |
| func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { |
| const nicID = 1 |
| |
| s := stack.New(stack.Options{ |
| NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ |
| MLD: ipv6.MLDOptions{ |
| Enabled: true, |
| }, |
| })}, |
| }) |
| e := channel.New(1, header.IPv6MinimumMTU, "") |
| if err := s.CreateNIC(nicID, e); err != nil { |
| t.Fatalf("CreateNIC(%d, _): %s", nicID, err) |
| } |
| |
| // The stack will join an address's solicited node multicast address when |
| // an address is added. An MLD report message should be sent for the |
| // solicited-node group. |
| if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { |
| t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) |
| } |
| if p, ok := e.Read(); !ok { |
| t.Fatal("expected a report message to be sent") |
| } else { |
| validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) |
| } |
| |
| // The stack will leave an address's solicited node multicast address when |
| // an address is removed. An MLD done message should be sent for the |
| // solicited-node group. |
| if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil { |
| t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err) |
| } |
| if p, ok := e.Read(); !ok { |
| t.Fatal("expected a done message to be sent") |
| } else { |
| validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) |
| } |
| } |
| |
| func TestSendQueuedMLDReports(t *testing.T) { |
| const ( |
| nicID = 1 |
| maxReports = 2 |
| ) |
| |
| tests := []struct { |
| name string |
| dadTransmits uint8 |
| retransmitTimer time.Duration |
| }{ |
| { |
| name: "DAD Disabled", |
| dadTransmits: 0, |
| retransmitTimer: 0, |
| }, |
| { |
| name: "DAD Enabled", |
| dadTransmits: 1, |
| retransmitTimer: time.Second, |
| }, |
| } |
| |
| nonce := [...]byte{ |
| 1, 2, 3, 4, 5, 6, |
| } |
| |
| const maxNSMessages = 2 |
| secureRNGBytes := make([]byte, len(nonce)*maxNSMessages) |
| for b := secureRNGBytes[:]; len(b) > 0; b = b[len(nonce):] { |
| if n := copy(b, nonce[:]); n != len(nonce) { |
| t.Fatalf("got copy(...) = %d, want = %d", n, len(nonce)) |
| } |
| } |
| |
| for _, test := range tests { |
| t.Run(test.name, func(t *testing.T) { |
| dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits) |
| clock := faketime.NewManualClock() |
| var secureRNG bytes.Reader |
| secureRNG.Reset(secureRNGBytes[:]) |
| s := stack.New(stack.Options{ |
| SecureRNG: &secureRNG, |
| NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ |
| DADConfigs: stack.DADConfigurations{ |
| DupAddrDetectTransmits: test.dadTransmits, |
| RetransmitTimer: test.retransmitTimer, |
| }, |
| MLD: ipv6.MLDOptions{ |
| Enabled: true, |
| }, |
| })}, |
| Clock: clock, |
| }) |
| |
| // Allow space for an extra packet so we can observe packets that were |
| // unexpectedly sent. |
| e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "") |
| if err := s.CreateNIC(nicID, e); err != nil { |
| t.Fatalf("CreateNIC(%d, _): %s", nicID, err) |
| } |
| |
| resolveDAD := func(addr, snmc tcpip.Address) { |
| clock.Advance(dadResolutionTime) |
| if p, ok := e.Read(); !ok { |
| t.Fatal("expected DAD packet") |
| } else { |
| checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), |
| checker.SrcAddr(header.IPv6Any), |
| checker.DstAddr(snmc), |
| checker.TTL(header.NDPHopLimit), |
| checker.NDPNS( |
| checker.NDPNSTargetAddress(addr), |
| checker.NDPNSOptions([]header.NDPOption{header.NDPNonceOption(nonce[:])}), |
| )) |
| } |
| } |
| |
| var reportCounter uint64 |
| reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport |
| if got := reportStat.Value(); got != reportCounter { |
| t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) |
| } |
| var doneCounter uint64 |
| doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone |
| if got := doneStat.Value(); got != doneCounter { |
| t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) |
| } |
| |
| // Joining a group without an assigned address should send an MLD report |
| // with the unspecified address. |
| if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil { |
| t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err) |
| } |
| reportCounter++ |
| if got := reportStat.Value(); got != reportCounter { |
| t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) |
| } |
| if p, ok := e.Read(); !ok { |
| t.Errorf("expected MLD report for %s", globalMulticastAddr) |
| } else { |
| validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr) |
| } |
| clock.Advance(time.Hour) |
| if p, ok := e.Read(); ok { |
| t.Errorf("got unexpected packet = %#v", p) |
| } |
| if t.Failed() { |
| t.FailNow() |
| } |
| |
| // Adding a global address should not send reports for the already joined |
| // group since we should only send queued reports when a link-local |
| // address is assigned. |
| // |
| // Note, we will still expect to send a report for the global address's |
| // solicited node address from the unspecified address as per RFC 3590 |
| // section 4. |
| if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { |
| t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) |
| } |
| reportCounter++ |
| if got := reportStat.Value(); got != reportCounter { |
| t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) |
| } |
| if p, ok := e.Read(); !ok { |
| t.Errorf("expected MLD report for %s", globalAddrSNMC) |
| } else { |
| validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC) |
| } |
| if dadResolutionTime != 0 { |
| // Reports should not be sent when the address resolves. |
| resolveDAD(globalAddr, globalAddrSNMC) |
| if got := reportStat.Value(); got != reportCounter { |
| t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) |
| } |
| } |
| // Leave the group since we don't care about the global address's |
| // solicited node multicast group membership. |
| if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil { |
| t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err) |
| } |
| if got := doneStat.Value(); got != doneCounter { |
| t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) |
| } |
| if p, ok := e.Read(); ok { |
| t.Errorf("got unexpected packet = %#v", p) |
| } |
| if t.Failed() { |
| t.FailNow() |
| } |
| |
| // Adding a link-local address should send a report for its solicited node |
| // address and globalMulticastAddr. |
| if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { |
| t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) |
| } |
| if dadResolutionTime != 0 { |
| reportCounter++ |
| if got := reportStat.Value(); got != reportCounter { |
| t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) |
| } |
| if p, ok := e.Read(); !ok { |
| t.Errorf("expected MLD report for %s", linkLocalAddrSNMC) |
| } else { |
| validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) |
| } |
| resolveDAD(linkLocalAddr, linkLocalAddrSNMC) |
| } |
| |
| // We expect two batches of reports to be sent (1 batch when the |
| // link-local address is assigned, and another after the maximum |
| // unsolicited report interval. |
| for i := 0; i < 2; i++ { |
| // We expect reports to be sent (one for globalMulticastAddr and another |
| // for linkLocalAddrSNMC). |
| reportCounter += maxReports |
| if got := reportStat.Value(); got != reportCounter { |
| t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) |
| } |
| |
| addrs := map[tcpip.Address]bool{ |
| globalMulticastAddr: false, |
| linkLocalAddrSNMC: false, |
| } |
| for range addrs { |
| p, ok := e.Read() |
| if !ok { |
| t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs) |
| } |
| |
| addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress() |
| if seen, ok := addrs[addr]; !ok { |
| t.Fatalf("got unexpected packet destined to %s", addr) |
| } else if seen { |
| t.Fatalf("got another packet destined to %s", addr) |
| } |
| |
| addrs[addr] = true |
| validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr) |
| |
| clock.Advance(ipv6.UnsolicitedReportIntervalMax) |
| } |
| } |
| |
| // Should not send any more reports. |
| clock.Advance(time.Hour) |
| if p, ok := e.Read(); ok { |
| t.Errorf("got unexpected packet = %#v", p) |
| } |
| }) |
| } |
| } |
| |
| // createAndInjectMLDPacket creates and injects an MLD packet with the |
| // specified fields. |
| func createAndInjectMLDPacket(e *channel.Endpoint, mldType header.ICMPv6Type, hopLimit uint8, srcAddress tcpip.Address, withRouterAlertOption bool, routerAlertValue header.IPv6RouterAlertValue) { |
| var extensionHeaders header.IPv6ExtHdrSerializer |
| if withRouterAlertOption { |
| extensionHeaders = header.IPv6ExtHdrSerializer{ |
| header.IPv6SerializableHopByHopExtHdr{ |
| &header.IPv6RouterAlertOption{Value: routerAlertValue}, |
| }, |
| } |
| } |
| |
| extensionHeadersLength := extensionHeaders.Length() |
| payloadLength := extensionHeadersLength + header.ICMPv6HeaderSize + header.MLDMinimumSize |
| buf := buffer.NewView(header.IPv6MinimumSize + payloadLength) |
| |
| ip := header.IPv6(buf) |
| ip.Encode(&header.IPv6Fields{ |
| PayloadLength: uint16(payloadLength), |
| HopLimit: hopLimit, |
| TransportProtocol: header.ICMPv6ProtocolNumber, |
| SrcAddr: srcAddress, |
| DstAddr: header.IPv6AllNodesMulticastAddress, |
| ExtensionHeaders: extensionHeaders, |
| }) |
| |
| icmp := header.ICMPv6(ip.Payload()[extensionHeadersLength:]) |
| icmp.SetType(mldType) |
| mld := header.MLD(icmp.MessageBody()) |
| mld.SetMaximumResponseDelay(0) |
| mld.SetMulticastAddress(header.IPv6Any) |
| icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ |
| Header: icmp, |
| Src: srcAddress, |
| Dst: header.IPv6AllNodesMulticastAddress, |
| })) |
| |
| e.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ |
| Data: buf.ToVectorisedView(), |
| })) |
| } |
| |
| func TestMLDPacketValidation(t *testing.T) { |
| const ( |
| nicID = 1 |
| linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") |
| ) |
| |
| tests := []struct { |
| name string |
| messageType header.ICMPv6Type |
| srcAddr tcpip.Address |
| includeRouterAlertOption bool |
| routerAlertValue header.IPv6RouterAlertValue |
| hopLimit uint8 |
| expectValidMLD bool |
| getMessageTypeStatValue func(tcpip.Stats) uint64 |
| }{ |
| { |
| name: "valid", |
| messageType: header.ICMPv6MulticastListenerQuery, |
| includeRouterAlertOption: true, |
| routerAlertValue: header.IPv6RouterAlertMLD, |
| srcAddr: linkLocalAddr2, |
| hopLimit: header.MLDHopLimit, |
| expectValidMLD: true, |
| getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerQuery.Value() }, |
| }, |
| { |
| name: "bad hop limit", |
| messageType: header.ICMPv6MulticastListenerReport, |
| includeRouterAlertOption: true, |
| routerAlertValue: header.IPv6RouterAlertMLD, |
| srcAddr: linkLocalAddr2, |
| hopLimit: header.MLDHopLimit + 1, |
| expectValidMLD: false, |
| getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerReport.Value() }, |
| }, |
| { |
| name: "src ip not link local", |
| messageType: header.ICMPv6MulticastListenerReport, |
| includeRouterAlertOption: true, |
| routerAlertValue: header.IPv6RouterAlertMLD, |
| srcAddr: globalAddr, |
| hopLimit: header.MLDHopLimit, |
| expectValidMLD: false, |
| getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerReport.Value() }, |
| }, |
| { |
| name: "missing router alert ip option", |
| messageType: header.ICMPv6MulticastListenerDone, |
| includeRouterAlertOption: false, |
| srcAddr: linkLocalAddr2, |
| hopLimit: header.MLDHopLimit, |
| expectValidMLD: false, |
| getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerDone.Value() }, |
| }, |
| { |
| name: "incorrect router alert value", |
| messageType: header.ICMPv6MulticastListenerDone, |
| includeRouterAlertOption: true, |
| routerAlertValue: header.IPv6RouterAlertRSVP, |
| srcAddr: linkLocalAddr2, |
| hopLimit: header.MLDHopLimit, |
| expectValidMLD: false, |
| getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.ICMP.V6.PacketsReceived.MulticastListenerDone.Value() }, |
| }, |
| } |
| |
| for _, test := range tests { |
| t.Run(test.name, func(t *testing.T) { |
| s := stack.New(stack.Options{ |
| NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ |
| MLD: ipv6.MLDOptions{ |
| Enabled: true, |
| }, |
| })}, |
| }) |
| e := channel.New(nicID, header.IPv6MinimumMTU, "") |
| if err := s.CreateNIC(nicID, e); err != nil { |
| t.Fatalf("CreateNIC(%d, _): %s", nicID, err) |
| } |
| stats := s.Stats() |
| // Verify that every relevant stats is zero'd before we send a packet. |
| if got := test.getMessageTypeStatValue(s.Stats()); got != 0 { |
| t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 0", got) |
| } |
| if got := stats.ICMP.V6.PacketsReceived.Invalid.Value(); got != 0 { |
| t.Errorf("got stats.ICMP.V6.PacketsReceived.Invalid.Value() = %d, want = 0", got) |
| } |
| if got := stats.IP.PacketsDelivered.Value(); got != 0 { |
| t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 0", got) |
| } |
| createAndInjectMLDPacket(e, test.messageType, test.hopLimit, test.srcAddr, test.includeRouterAlertOption, test.routerAlertValue) |
| // We always expect the packet to pass IP validation. |
| if got := stats.IP.PacketsDelivered.Value(); got != 1 { |
| t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 1", got) |
| } |
| // Even when the MLD-specific validation checks fail, we expect the |
| // corresponding MLD counter to be incremented. |
| if got := test.getMessageTypeStatValue(s.Stats()); got != 1 { |
| t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 1", got) |
| } |
| var expectedInvalidCount uint64 |
| if !test.expectValidMLD { |
| expectedInvalidCount = 1 |
| } |
| if got := stats.ICMP.V6.PacketsReceived.Invalid.Value(); got != expectedInvalidCount { |
| t.Errorf("got stats.ICMP.V6.PacketsReceived.Invalid.Value() = %d, want = %d", got, expectedInvalidCount) |
| } |
| }) |
| } |
| } |