Fix a bug with TTL values.
Previously, sending a multicast packet would set the default TTL value
for all packets sent via the same link endpoint. This fixes the issue.
NET-314 #done
Change-Id: I117b7a7569804c1887dfbf08d4788eea54b46d74
diff --git a/tcpip/checker/checker.go b/tcpip/checker/checker.go
index 5571846..921cbfe 100644
--- a/tcpip/checker/checker.go
+++ b/tcpip/checker/checker.go
@@ -75,6 +75,22 @@
}
}
+// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
+func TTL(ttl uint8) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ var v uint8
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ v = ip.TTL()
+ case header.IPv6:
+ v = ip.HopLimit()
+ }
+ if v != ttl {
+ t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
+ }
+ }
+}
+
// PayloadLen creates a checker that checks the payload length.
func PayloadLen(plen int) NetworkChecker {
return func(t *testing.T, h []header.Network) {
diff --git a/tcpip/header/ipv4.go b/tcpip/header/ipv4.go
index e87c044..63e0a42 100644
--- a/tcpip/header/ipv4.go
+++ b/tcpip/header/ipv4.go
@@ -84,6 +84,9 @@
// IPv4Version is the version of the ipv4 procotol.
IPv4Version = 4
+ // IPv4DefaultTTL is the default time-to-live value for sent packets.
+ IPv4DefaultTTL = 65
+
// IPv4Loopback is the loopback address of the IPv4 procotol.
IPv4Loopback tcpip.Address = "\x7f\x00\x00\x01"
diff --git a/tcpip/header/ipv6.go b/tcpip/header/ipv6.go
index 872fc3b..22a45e9 100644
--- a/tcpip/header/ipv6.go
+++ b/tcpip/header/ipv6.go
@@ -63,6 +63,10 @@
// IPv6Version is the version of the ipv6 procotol.
IPv6Version = 6
+ // IPv6DefaultHopLimit is the default hop limit (or TTL) value for
+ // sent packets.
+ IPv6DefaultHopLimit = 255
+
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
// section 5.
IPv6MinimumMTU = 1280
diff --git a/tcpip/network/arp/arp.go b/tcpip/network/arp/arp.go
index 20e5dce..a41e475 100644
--- a/tcpip/network/arp/arp.go
+++ b/tcpip/network/arp/arp.go
@@ -41,6 +41,10 @@
linkAddrCache stack.LinkAddressCache
}
+func (e *endpoint) DefaultTTL() uint8 {
+ return 0 // unused for ARP
+}
+
func (e *endpoint) MTU() uint32 {
lmtu := e.linkEP.MTU()
return lmtu - uint32(e.MaxHeaderLength())
@@ -58,12 +62,9 @@
return e.linkEP.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) SetTTL(_ uint8) {
-}
-
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
return tcpip.ErrNotSupported
}
diff --git a/tcpip/network/ipv4/icmp.go b/tcpip/network/ipv4/icmp.go
index 7e3d4dc..29e7d89 100644
--- a/tcpip/network/ipv4/icmp.go
+++ b/tcpip/network/ipv4/icmp.go
@@ -71,7 +71,7 @@
icmpv4.SetCode(code)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
- return r.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber)
+ return r.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber, r.DefaultTTL())
}
// A Pinger can send echo requests to an address.
@@ -283,7 +283,7 @@
chksum := header.ICMPv4(data).CalculateChecksum(icmpv4.CalculateChecksum(0))
icmpv4.SetChecksum(^chksum)
- if err := e.route.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber); err != nil {
+ if err := e.route.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber, e.route.DefaultTTL()); err != nil {
return 0, err
}
return uintptr(len(v)), nil
diff --git a/tcpip/network/ipv4/ipv4.go b/tcpip/network/ipv4/ipv4.go
index 29337ba..25ff0bc 100644
--- a/tcpip/network/ipv4/ipv4.go
+++ b/tcpip/network/ipv4/ipv4.go
@@ -46,7 +46,6 @@
dispatcher stack.TransportDispatcher
echoRequests chan echoRequest
fragmentation fragmentation.Fragmentation
- ttl uint8
}
func newEndpoint(nicid tcpip.NICID, addr tcpip.Address, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) *endpoint {
@@ -56,7 +55,6 @@
dispatcher: dispatcher,
echoRequests: make(chan echoRequest, 10),
fragmentation: fragmentation.NewFragmentation(fragmentation.MemoryLimit, fragmentation.DefaultReassembleTimeout),
- ttl: 65,
}
copy(e.address[:], addr)
e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])}
@@ -66,6 +64,10 @@
return e
}
+func (e *endpoint) DefaultTTL() uint8 {
+ return header.IPv4DefaultTTL
+}
+
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
@@ -92,14 +94,8 @@
return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
}
-// SetTTL sets the default time-to-live value for packets sent through
-// this endpoint.
-func (e *endpoint) SetTTL(ttl uint8) {
- e.ttl = ttl
-}
-
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
length := uint16(hdr.UsedLength() + len(payload))
id := uint32(0)
@@ -112,7 +108,7 @@
IHL: header.IPv4MinimumSize,
TotalLength: length,
ID: uint16(id),
- TTL: e.ttl,
+ TTL: ttl,
Protocol: uint8(protocol),
SrcAddr: tcpip.Address(e.address[:]),
DstAddr: r.RemoteAddress,
diff --git a/tcpip/network/ipv6/icmp.go b/tcpip/network/ipv6/icmp.go
index 0c7e5e6..dd3a348 100644
--- a/tcpip/network/ipv6/icmp.go
+++ b/tcpip/network/ipv6/icmp.go
@@ -48,7 +48,7 @@
copy(pkt[26:], r.LocalLinkAddress[:])
r.LocalAddress = targetAddr
pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, nil))
- r.WritePacket(&hdr, nil, header.ICMPv6ProtocolNumber)
+ r.WritePacket(&hdr, nil, header.ICMPv6ProtocolNumber, r.DefaultTTL())
e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
@@ -68,7 +68,7 @@
copy(pkt, h)
pkt.SetType(header.ICMPv6EchoReply)
pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, data))
- r.WritePacket(&hdr, data, header.ICMPv6ProtocolNumber)
+ r.WritePacket(&hdr, data, header.ICMPv6ProtocolNumber, r.DefaultTTL())
default:
log.Printf("got ICMPv6: type=%v, code=%v, len(v)=%d", h.Type(), h.Code(), len(v))
}
diff --git a/tcpip/network/ipv6/ipv6.go b/tcpip/network/ipv6/ipv6.go
index 80aedd4..3858c27 100644
--- a/tcpip/network/ipv6/ipv6.go
+++ b/tcpip/network/ipv6/ipv6.go
@@ -38,7 +38,10 @@
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
dispatcher stack.TransportDispatcher
- ttl uint8
+}
+
+func (e *endpoint) DefaultTTL() uint8 {
+ return header.IPv6DefaultHopLimit
}
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
@@ -67,14 +70,8 @@
return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
}
-// SetTTL sets the default time-to-live value for packets sent through
-// this endpoint.
-func (e *endpoint) SetTTL(ttl uint8) {
- e.ttl = ttl
-}
-
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
length := uint16(hdr.UsedLength())
if payload != nil {
length += uint16(len(payload))
@@ -83,7 +80,7 @@
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(protocol),
- HopLimit: e.ttl,
+ HopLimit: ttl,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
@@ -145,7 +142,6 @@
linkEP: linkEP,
linkAddrCache: linkAddrCache,
dispatcher: dispatcher,
- ttl: 255,
}
copy(e.address[:], addr)
e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])}
diff --git a/tcpip/stack/registration.go b/tcpip/stack/registration.go
index a2c58b3..da7b6d9 100644
--- a/tcpip/stack/registration.go
+++ b/tcpip/stack/registration.go
@@ -89,6 +89,10 @@
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
+ // Default TTL is the default time-to-live value (or hop limit, in ipv6)
+ // for this endpoint.
+ DefaultTTL() uint8
+
// MTU is the maximum transmission unit for this endpoint. This is
// generally calculated as the MTU of the underlying data link endpoint
// minus the network endpoint max header length.
@@ -102,7 +106,7 @@
// WritePacket writes a packet to the given destination address and
// protocol.
- WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error
+ WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
@@ -110,10 +114,6 @@
// NICID returns the id of the NIC this endpoint belongs to.
NICID() tcpip.NICID
- // SetTTL sets the default time-to-live value for packets sent through
- // this endpoint.
- SetTTL(ttl uint8)
-
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint.
HandlePacket(r *Route, vv *buffer.VectorisedView)
diff --git a/tcpip/stack/route.go b/tcpip/stack/route.go
index c3eca94..fef7daf 100644
--- a/tcpip/stack/route.go
+++ b/tcpip/stack/route.go
@@ -66,17 +66,12 @@
return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress)
}
-// SetTTL forwards the call to the network endpoint's implementation.
-func (r *Route) SetTTL(ttl uint8) {
- r.ref.ep.SetTTL(ttl)
-}
-
func isLoopback(addr tcpip.Address) bool {
return (len(addr) == 4 && addr[0] == 127) || addr == header.IPv6Loopback
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
if r.RemoteLinkAddress == "" && r.ref.linkRes != nil && !isLoopback(r.RemoteAddress) {
nextAddr := r.NextHop
if nextAddr == "" {
@@ -93,7 +88,12 @@
return tcpip.ErrNoLinkAddress
}
}
- return r.ref.ep.WritePacket(r, hdr, payload, protocol)
+ return r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl)
+}
+
+// DefaultTTL returns the default TTL of the underlying network endpoint.
+func (r *Route) DefaultTTL() uint8 {
+ return r.ref.ep.DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
diff --git a/tcpip/transport/tcp/connect.go b/tcpip/transport/tcp/connect.go
index 9582576..e47fe32 100644
--- a/tcpip/transport/tcp/connect.go
+++ b/tcpip/transport/tcp/connect.go
@@ -413,12 +413,12 @@
header.TCPOptionWS, 3, uint8(opts.WS), header.TCPOptionNOP)
}
- return sendTCPWithOptions(r, id, nil, flags, seq, ack, rcvWnd, options)
+ return sendTCPWithOptions(r, id, nil, r.DefaultTTL(), flags, seq, ack, rcvWnd, options)
}
// sendTCPWithOptions sends a TCP segment with the provided options via the
// provided network endpoint and under the provided identity.
-func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error {
+func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffer.View, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error {
optLen := len(opts)
// Allocate a buffer for the TCP header.
hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
@@ -449,12 +449,12 @@
tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length))
- return r.WritePacket(&hdr, data, ProtocolNumber)
+ return r.WritePacket(&hdr, data, ProtocolNumber, ttl)
}
// sendTCP sends a TCP segment via the provided network endpoint and under the
// provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.View, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
// Allocate a buffer for the TCP header.
hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()))
@@ -483,7 +483,7 @@
tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length))
- return r.WritePacket(&hdr, data, ProtocolNumber)
+ return r.WritePacket(&hdr, data, ProtocolNumber, ttl)
}
// sendRaw sends a TCP segment to the endpoint's peer.
@@ -507,9 +507,9 @@
//
// Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
options := header.EncodeTSOption(e.timestamp(), uint32(e.recentTS))
- return sendTCPWithOptions(&e.route, e.id, data, flags, seq, ack, rcvWnd, options[:])
+ return sendTCPWithOptions(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options[:])
}
- return sendTCP(&e.route, e.id, data, flags, seq, ack, rcvWnd)
+ return sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd)
}
func (e *endpoint) handleWrite() bool {
diff --git a/tcpip/transport/tcp/protocol.go b/tcpip/transport/tcp/protocol.go
index 6961db5..d291958 100644
--- a/tcpip/transport/tcp/protocol.go
+++ b/tcpip/transport/tcp/protocol.go
@@ -85,7 +85,7 @@
ack := s.sequenceNumber.Add(s.logicalLen())
- sendTCP(&s.route, s.id, nil, flagRst|flagAck, seq, ack, 0)
+ sendTCP(&s.route, s.id, nil, s.route.DefaultTTL(), flagRst|flagAck, seq, ack, 0)
}
// SetOption implements TransportProtocol.SetOption.
diff --git a/tcpip/transport/udp/endpoint.go b/tcpip/transport/udp/endpoint.go
index 0ad5b66..d315a52 100644
--- a/tcpip/transport/udp/endpoint.go
+++ b/tcpip/transport/udp/endpoint.go
@@ -269,11 +269,12 @@
dstPort = to.Port
}
+ ttl := route.DefaultTTL()
if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
- route.SetTTL(e.multicastTTL)
+ ttl = e.multicastTTL
}
- sendUDP(route, v, e.id.LocalPort, dstPort)
+ sendUDP(route, v, e.id.LocalPort, dstPort, ttl)
return uintptr(len(v)), nil
}
@@ -392,7 +393,7 @@
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.View, localPort, remotePort uint16) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.View, localPort, remotePort uint16, ttl uint8) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -414,7 +415,7 @@
udp.SetChecksum(^udp.CalculateChecksum(xsum, length))
- return r.WritePacket(&hdr, data, ProtocolNumber)
+ return r.WritePacket(&hdr, data, ProtocolNumber, ttl)
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
diff --git a/tcpip/transport/udp/udp_test.go b/tcpip/transport/udp/udp_test.go
index 9638232..e95c6b4 100644
--- a/tcpip/transport/udp/udp_test.go
+++ b/tcpip/transport/udp/udp_test.go
@@ -24,16 +24,19 @@
)
const (
- stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
- testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
- V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+ stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
+ testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
+ multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr
+ V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testAddr = "\x0a\x00\x00\x02"
- testPort = 4096
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testAddr = "\x0a\x00\x00\x02"
+ testPort = 4096
+ multicastAddr = "\xe8\x2b\xd3\xea"
+ multicastPort = 1234
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -158,6 +161,26 @@
return nil
}
+func (c *testContext) getMCPacket() []byte {
+ select {
+ case p := <-c.linkEP.C:
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
+ b := make([]byte, len(p.Header)+len(p.Payload))
+ copy(b, p.Header)
+ copy(b[len(p.Header):], p.Payload)
+
+ checker.IPv4(c.t, b, checker.SrcAddr(stackAddr), checker.DstAddr(multicastAddr))
+ return b
+
+ case <-time.After(2 * time.Second):
+ c.t.Fatalf("Packet wasn't written out")
+ }
+
+ return nil
+}
+
func (c *testContext) sendV6Packet(payload []byte, h *headers) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
@@ -615,3 +638,51 @@
c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
}
}
+
+func TestMulticastTTL(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+ c.ep.SetSockOpt(tcpip.MulticastTTLOption(42))
+
+ payload := buffer.View(newPayload())
+ // Write a multicast packet. Its TTL value should be the above multicast value.
+ {
+ n, err := c.ep.Write(payload, &tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: multicastPort})
+ if err != nil {
+ c.t.Fatalf("Write failed: %v", err)
+ }
+ if n != uintptr(len(payload)) {
+ c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+ }
+
+ // Check that we received the packet and that it has the multicastTTL value.
+ b := c.getMCPacket()
+ checker.IPv4(c.t, b,
+ checker.TTL(42),
+ checker.UDP(
+ checker.DstPort(multicastPort),
+ ),
+ )
+ }
+
+ // Write a regular packet. Its TTL value should be the default.
+ {
+ n, err := c.ep.Write(payload, &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort})
+ if err != nil {
+ c.t.Fatalf("Write failed: %v", err)
+ }
+ if n != uintptr(len(payload)) {
+ c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+ }
+
+ b := c.getPacket()
+ checker.IPv4(c.t, b,
+ checker.TTL(header.IPv4DefaultTTL),
+ checker.UDP(
+ checker.DstPort(testPort),
+ ),
+ )
+ }
+}