Merge remote-tracking branch 'upstream/master' into HEAD

Change-Id: I76d0405304d5fbe420a4b924a2b86dfff1ad94f5
diff --git a/tcpip/network/arp/arp.go b/tcpip/network/arp/arp.go
index 005c0b2..d245583 100644
--- a/tcpip/network/arp/arp.go
+++ b/tcpip/network/arp/arp.go
@@ -79,7 +79,7 @@
 
 func (e *endpoint) Close() {}
 
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) *tcpip.Error {
 	return tcpip.ErrNotSupported
 }
 
@@ -109,7 +109,11 @@
 		copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender())
 		copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
 		e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+		fallthrough // also fill the cache from requests
 	case header.ARPReply:
+		addr := tcpip.Address(h.ProtocolAddressSender())
+		linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+		e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
 	}
 }
 
diff --git a/tcpip/network/ip_test.go b/tcpip/network/ip_test.go
index e05e427..f1ebbf6 100644
--- a/tcpip/network/ip_test.go
+++ b/tcpip/network/ip_test.go
@@ -230,7 +230,7 @@
 	if err != nil {
 		t.Fatalf("could not find route: %v", err)
 	}
-	if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil {
+	if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil {
 		t.Fatalf("WritePacket failed: %v", err)
 	}
 }
@@ -460,7 +460,7 @@
 	if err != nil {
 		t.Fatalf("could not find route: %v", err)
 	}
-	if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil {
+	if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil {
 		t.Fatalf("WritePacket failed: %v", err)
 	}
 }
diff --git a/tcpip/network/ipv4/icmp.go b/tcpip/network/ipv4/icmp.go
index c73bf4e..ac58039 100644
--- a/tcpip/network/ipv4/icmp.go
+++ b/tcpip/network/ipv4/icmp.go
@@ -95,7 +95,7 @@
 		pkt.SetChecksum(0)
 		pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
 		sent := stats.ICMP.V4PacketsSent
-		if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, 0, true /* useDefaultTTL */); err != nil {
+		if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
 			sent.Dropped.Increment()
 			return
 		}
diff --git a/tcpip/network/ipv4/ipv4.go b/tcpip/network/ipv4/ipv4.go
index 46d07e3..7fc5b36 100644
--- a/tcpip/network/ipv4/ipv4.go
+++ b/tcpip/network/ipv4/ipv4.go
@@ -199,21 +199,22 @@
 }
 
 // WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
 	ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
 	length := uint16(hdr.UsedLength() + payload.Size())
 	id := uint32(0)
 	if length > header.IPv4MaximumHeaderSize+8 {
 		// Packets of 68 bytes or less are required by RFC 791 to not be
 		// fragmented, so we only assign ids to larger packets.
-		id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, protocol, e.protocol.hashIV)%buckets], 1)
+		id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
 	}
 	ip.Encode(&header.IPv4Fields{
 		IHL:         header.IPv4MinimumSize,
 		TotalLength: length,
 		ID:          uint16(id),
-		TTL:         ttl,
-		Protocol:    uint8(protocol),
+		TTL:         params.TTL,
+		TOS:         params.TOS,
+		Protocol:    uint8(params.Protocol),
 		SrcAddr:     r.LocalAddress,
 		DstAddr:     r.RemoteAddress,
 	})
diff --git a/tcpip/network/ipv4/ipv4_test.go b/tcpip/network/ipv4/ipv4_test.go
index 0c5d19f..7982ef5 100644
--- a/tcpip/network/ipv4/ipv4_test.go
+++ b/tcpip/network/ipv4/ipv4_test.go
@@ -302,7 +302,7 @@
 				Payload: payload.Clone([]buffer.View{}),
 			}
 			c := buildContext(t, nil, ft.mtu)
-			err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42 /* ttl */, false /* useDefaultTTL */)
+			err := c.Route.WritePacket(ft.gso, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS})
 			if err != nil {
 				t.Errorf("err got %v, want %v", err, nil)
 			}
@@ -349,7 +349,7 @@
 		t.Run(ft.description, func(t *testing.T) {
 			hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
 			c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
-			err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42 /* ttl */, false /* useDefaultTTL */)
+			err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS})
 			for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
 				if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
 					t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
diff --git a/tcpip/network/ipv6/icmp.go b/tcpip/network/ipv6/icmp.go
index ba2dc1d..14aa110 100644
--- a/tcpip/network/ipv6/icmp.go
+++ b/tcpip/network/ipv6/icmp.go
@@ -121,7 +121,6 @@
 
 	case header.ICMPv6NeighborSolicit:
 		received.NeighborSolicit.Increment()
-
 		if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
 			received.Invalid.Increment()
 			return
@@ -131,7 +130,6 @@
 			// We don't have a useful answer; the best we can do is ignore the request.
 			return
 		}
-
 		hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertSize)
 		pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
 		pkt.SetType(header.ICMPv6NeighborAdvert)
@@ -154,7 +152,22 @@
 		r.LocalAddress = targetAddr
 		pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
 
-		if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, 0, true /* useDefaultTTL */); err != nil {
+		// TODO(tamird/ghanan): there exists an explicit NDP option that is
+		// used to update the neighbor table with link addresses for a
+		// neighbor from an NS (see the Source Link Layer option RFC
+		// 4861 section 4.6.1 and section 7.2.3).
+		//
+		// Furthermore, the entirety of NDP handling here seems to be
+		// contradicted by RFC 4861.
+		e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+
+		// RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
+		//
+		// 7.1.2. Validation of Neighbor Advertisements
+		//
+		// The IP Hop Limit field has a value of 255, i.e., the packet
+		// could not possibly have been forwarded by a router.
+		if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ndpHopLimit, TOS: stack.DefaultTOS}); err != nil {
 			sent.Dropped.Increment()
 			return
 		}
@@ -178,14 +191,13 @@
 			received.Invalid.Increment()
 			return
 		}
-
 		vv.TrimFront(header.ICMPv6EchoMinimumSize)
 		hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
 		pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
 		copy(pkt, h)
 		pkt.SetType(header.ICMPv6EchoReply)
 		pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
-		if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, 0, true /* useDefaultTTL */); err != nil {
+		if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
 			sent.Dropped.Increment()
 			return
 		}
diff --git a/tcpip/network/ipv6/icmp_test.go b/tcpip/network/ipv6/icmp_test.go
index dabced5..4448102 100644
--- a/tcpip/network/ipv6/icmp_test.go
+++ b/tcpip/network/ipv6/icmp_test.go
@@ -15,7 +15,6 @@
 package ipv6
 
 import (
-	"fmt"
 	"reflect"
 	"strings"
 	"testing"
@@ -179,13 +178,10 @@
 	t := v.Type()
 	for i := 0; i < v.NumField(); i++ {
 		v := v.Field(i)
-		switch v.Kind() {
-		case reflect.Ptr:
-			f(t.Field(i).Name, v.Interface().(*tcpip.StatCounter))
-		case reflect.Struct:
+		if s, ok := v.Interface().(*tcpip.StatCounter); ok {
+			f(t.Field(i).Name, s)
+		} else {
 			visitStats(v, f)
-		default:
-			panic(fmt.Sprintf("unexpected type %s", v.Type()))
 		}
 	}
 }
diff --git a/tcpip/network/ipv6/ipv6.go b/tcpip/network/ipv6/ipv6.go
index 8a5685a..10729c8 100644
--- a/tcpip/network/ipv6/ipv6.go
+++ b/tcpip/network/ipv6/ipv6.go
@@ -98,13 +98,14 @@
 }
 
 // WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
 	length := uint16(hdr.UsedLength() + payload.Size())
 	ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
 	ip.Encode(&header.IPv6Fields{
 		PayloadLength: length,
-		NextHeader:    uint8(protocol),
-		HopLimit:      ttl,
+		NextHeader:    uint8(params.Protocol),
+		HopLimit:      params.TTL,
+		TrafficClass:  params.TOS,
 		SrcAddr:       r.LocalAddress,
 		DstAddr:       r.RemoteAddress,
 	})
diff --git a/tcpip/network/ipv6/ipv6_test.go b/tcpip/network/ipv6/ipv6_test.go
index 7131ed1..ce46ee2 100644
--- a/tcpip/network/ipv6/ipv6_test.go
+++ b/tcpip/network/ipv6/ipv6_test.go
@@ -64,7 +64,7 @@
 	}
 }
 
-// testReceiveICMP tests receiving a UDP packet from src to dst. want is the
+// testReceiveUDP tests receiving a UDP packet from src to dst. want is the
 // expected UDP received count after receiving the packet.
 func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
 	t.Helper()
diff --git a/tcpip/stack/nic.go b/tcpip/stack/nic.go
index 5615feb..5838b1c 100644
--- a/tcpip/stack/nic.go
+++ b/tcpip/stack/nic.go
@@ -632,8 +632,6 @@
 
 	src, dst := netProto.ParseAddresses(vv.First())
 
-	n.stack.AddLinkAddress(n.id, src, remote)
-
 	if ref := n.getRef(protocol, dst); ref != nil {
 		handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
 		return
diff --git a/tcpip/stack/registration.go b/tcpip/stack/registration.go
index bd29ae1..d406eca 100644
--- a/tcpip/stack/registration.go
+++ b/tcpip/stack/registration.go
@@ -146,6 +146,19 @@
 	PacketLoop
 )
 
+// NetworkHeaderParams are the header fields a transport endpoint supplies for
+// an outgoing packet; callers must resolve defaults such as Route.DefaultTTL.
+type NetworkHeaderParams struct {
+	// Protocol is the transport protocol number, written to the IPv4 Protocol or IPv6 Next Header field.
+	Protocol tcpip.TransportProtocolNumber
+
+	// TTL is written to the IPv4 Time To Live or IPv6 Hop Limit field.
+	TTL uint8
+
+	// TOS is written to the IPv4 Type of Service or IPv6 Traffic Class field.
+	TOS uint8
+}
+
 // NetworkEndpoint is the interface that needs to be implemented by endpoints
 // of network layer protocols (e.g., ipv4, ipv6).
 type NetworkEndpoint interface {
@@ -170,7 +183,7 @@
 
 	// WritePacket writes a packet to the given destination address and
 	// protocol.
-	WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error
+	WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error
 
 	// WriteHeaderIncludedPacket writes a packet that includes a network
 	// header to the given destination address.
diff --git a/tcpip/stack/route.go b/tcpip/stack/route.go
index f3f43f9..7a3dd68 100644
--- a/tcpip/stack/route.go
+++ b/tcpip/stack/route.go
@@ -154,16 +154,12 @@
 }
 
 // WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, useDefaultTTL bool) *tcpip.Error {
+func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams) *tcpip.Error {
 	if !r.ref.isValidForOutgoing() {
 		return tcpip.ErrInvalidEndpointState
 	}
 
-	if useDefaultTTL {
-		ttl = r.DefaultTTL()
-	}
-
-	err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop)
+	err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.loop)
 	if err != nil {
 		r.Stats().IP.OutgoingPacketErrors.Increment()
 	} else {
diff --git a/tcpip/stack/stack.go b/tcpip/stack/stack.go
index fe4575e..2153c0e 100644
--- a/tcpip/stack/stack.go
+++ b/tcpip/stack/stack.go
@@ -43,6 +43,9 @@
 	resolutionTimeout = 1 * time.Second
 	// resolutionAttempts is set to the same ARP retries used in Linux.
 	resolutionAttempts = 3
+
+	// DefaultTOS is the default type of service value for network endpoints.
+	DefaultTOS = 0
 )
 
 type transportProtocolState struct {
@@ -394,7 +397,7 @@
 	// portSeed is a one-time random value initialized at stack startup
 	// and is used to seed the TCP port picking on active connections
 	//
-	// TODO(gvisor.dev/issues/940): S/R this field.
+	// TODO(gvisor.dev/issue/940): S/R this field.
 	portSeed uint32
 }
 
diff --git a/tcpip/stack/stack_test.go b/tcpip/stack/stack_test.go
index 2e987ae..d9c307e 100644
--- a/tcpip/stack/stack_test.go
+++ b/tcpip/stack/stack_test.go
@@ -119,7 +119,7 @@
 	return f.ep.Capabilities()
 }
 
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
 	// Increment the sent packet count in the protocol descriptor.
 	f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
 
@@ -128,7 +128,7 @@
 	b := hdr.Prepend(fakeNetHeaderLen)
 	b[0] = r.RemoteAddress[0]
 	b[1] = f.id.LocalAddress[0]
-	b[2] = byte(protocol)
+	b[2] = byte(params.Protocol)
 
 	if loop&stack.PacketLoop != 0 {
 		views := make([]buffer.View, 1, 1+len(payload.Views()))
@@ -310,7 +310,7 @@
 
 func send(r stack.Route, payload buffer.View) *tcpip.Error {
 	hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
-	return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123 /* ttl */, false /* useDefaultTTL */)
+	return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS})
 }
 
 func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) {
diff --git a/tcpip/stack/transport_test.go b/tcpip/stack/transport_test.go
index b5f15d2..02bc791 100644
--- a/tcpip/stack/transport_test.go
+++ b/tcpip/stack/transport_test.go
@@ -82,7 +82,7 @@
 	if err != nil {
 		return 0, nil, err
 	}
-	if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123 /* ttl */, false /* useDefaultTTL */); err != nil {
+	if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}); err != nil {
 		return 0, nil, err
 	}
 
diff --git a/tcpip/tcpip.go b/tcpip/tcpip.go
index 8b51442..7d47aa8 100644
--- a/tcpip/tcpip.go
+++ b/tcpip/tcpip.go
@@ -624,6 +624,14 @@
 // a default TTL.
 type DefaultTTLOption uint8
 
+// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify the Type of
+// Service for all subsequent outgoing IPv4 packets from the endpoint.
+type IPv4TOSOption uint8
+
+// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify the
+// Traffic Class for all subsequent outgoing IPv6 packets from the endpoint.
+type IPv6TrafficClassOption uint8
+
 // Route is a row in the routing table. It specifies through which NIC (and
 // gateway) sets of packets should be routed. A row is considered viable if the
 // masked target address matches the destination address in the row.
@@ -1078,15 +1086,12 @@
 func fillIn(v reflect.Value) {
 	for i := 0; i < v.NumField(); i++ {
 		v := v.Field(i)
-		switch v.Kind() {
-		case reflect.Ptr:
-			if s := v.Addr().Interface().(**StatCounter); *s == nil {
-				*s = &StatCounter{}
+		if s, ok := v.Addr().Interface().(**StatCounter); ok {
+			if *s == nil {
+				*s = new(StatCounter)
 			}
-		case reflect.Struct:
+		} else {
 			fillIn(v)
-		default:
-			panic(fmt.Sprintf("unexpected type %s", v.Type()))
 		}
 	}
 }
diff --git a/tcpip/transport/icmp/endpoint.go b/tcpip/transport/icmp/endpoint.go
index 45cf827..c6f81dc 100644
--- a/tcpip/transport/icmp/endpoint.go
+++ b/tcpip/transport/icmp/endpoint.go
@@ -422,7 +422,10 @@
 	icmpv4.SetChecksum(0)
 	icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
 
-	return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, ttl, ttl == 0 /* useDefaultTTL */)
+	if ttl == 0 {
+		ttl = r.DefaultTTL()
+	}
+	return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS})
 }
 
 func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
@@ -445,7 +448,10 @@
 	icmpv6.SetChecksum(0)
 	icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0)))
 
-	return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, ttl, ttl == 0 /* useDefaultTTL */)
+	if ttl == 0 {
+		ttl = r.DefaultTTL()
+	}
+	return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS})
 }
 
 func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
diff --git a/tcpip/transport/raw/endpoint.go b/tcpip/transport/raw/endpoint.go
index 45617ac..0ff546b 100644
--- a/tcpip/transport/raw/endpoint.go
+++ b/tcpip/transport/raw/endpoint.go
@@ -350,7 +350,7 @@
 			break
 		}
 		hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
-		if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), e.TransProto, 0, true /* useDefaultTTL */); err != nil {
+		if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
 			return 0, nil, err
 		}
 
diff --git a/tcpip/transport/tcp/accept.go b/tcpip/transport/tcp/accept.go
index 76e02cf..c287325 100644
--- a/tcpip/transport/tcp/accept.go
+++ b/tcpip/transport/tcp/accept.go
@@ -441,7 +441,7 @@
 				TSEcr: opts.TSVal,
 				MSS:   uint16(mss),
 			}
-			e.sendSynTCP(&s.route, s.id, e.ttl, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
+			e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
 			e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
 		}
 
diff --git a/tcpip/transport/tcp/connect.go b/tcpip/transport/tcp/connect.go
index 7af0728..c3258b2 100644
--- a/tcpip/transport/tcp/connect.go
+++ b/tcpip/transport/tcp/connect.go
@@ -255,7 +255,7 @@
 	if ttl == 0 {
 		ttl = s.route.DefaultTTL()
 	}
-	h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+	h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
 	return nil
 }
 
@@ -299,7 +299,7 @@
 			SACKPermitted: h.ep.sackPermitted,
 			MSS:           h.ep.amss,
 		}
-		h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+		h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
 		return nil
 	}
 
@@ -468,7 +468,8 @@
 			synOpts.WS = -1
 		}
 	}
-	h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+	h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
 	for h.state != handshakeCompleted {
 		switch index, _ := s.Fetch(true); index {
 		case wakerForResend:
@@ -477,7 +478,7 @@
 				return tcpip.ErrTimeout
 			}
 			rt.Reset(timeOut)
-			h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+			h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
 
 		case wakerForNotification:
 			n := h.ep.fetchNotifications()
@@ -587,17 +588,18 @@
 	return options[:offset]
 }
 
-func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) {
+func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
 	options := makeSynOptions(opts)
 	// We ignore SYN send errors and let the callers re-attempt send.
-	if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, flags, seq, ack, rcvWnd, options, nil); err != nil {
+	if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil {
 		e.stats.SendErrors.SynSendToNetworkFailed.Increment()
 	}
 	putOptions(options)
+	return nil
 }
 
-func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
-	if err := sendTCP(r, id, data, ttl, flags, seq, ack, rcvWnd, opts, gso); err != nil {
+func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+	if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil {
 		e.stats.SendErrors.SegmentSendToNetworkFailed.Increment()
 		return err
 	}
@@ -607,7 +609,7 @@
 
 // sendTCP sends a TCP segment with the provided options via the provided
 // network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
 	optLen := len(opts)
 	// Allocate a buffer for the TCP header.
 	hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
@@ -643,7 +645,10 @@
 		tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
 	}
 
-	if err := r.WritePacket(gso, hdr, data, ProtocolNumber, ttl, ttl == 0 /* useDefaultTTL */); err != nil {
+	if ttl == 0 {
+		ttl = r.DefaultTTL()
+	}
+	if err := r.WritePacket(gso, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
 		r.Stats().TCP.SegmentSendErrors.Increment()
 		return err
 	}
@@ -700,7 +705,7 @@
 		sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
 	}
 	options := e.makeOptions(sackBlocks)
-	err := e.sendTCP(&e.route, e.ID, data, e.ttl, flags, seq, ack, rcvWnd, options, e.gso)
+	err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso)
 	putOptions(options)
 	return err
 }
diff --git a/tcpip/transport/tcp/dual_stack_test.go b/tcpip/transport/tcp/dual_stack_test.go
index ebd7b0b..78d3aed 100644
--- a/tcpip/transport/tcp/dual_stack_test.go
+++ b/tcpip/transport/tcp/dual_stack_test.go
@@ -42,7 +42,7 @@
 	}
 }
 
-func testV4Connect(t *testing.T, c *context.Context) {
+func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
 	// Start connection attempt.
 	we, ch := waiter.NewChannelEntry(nil)
 	c.WQ.EventRegister(&we, waiter.EventOut)
@@ -55,12 +55,11 @@
 
 	// Receive SYN packet.
 	b := c.GetPacket()
-	checker.IPv4(t, b,
-		checker.TCP(
-			checker.DstPort(context.TestPort),
-			checker.TCPFlags(header.TCPFlagSyn),
-		),
-	)
+	synCheckers := append(checkers, checker.TCP(
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagSyn),
+	))
+	checker.IPv4(t, b, synCheckers...)
 
 	tcp := header.TCP(header.IPv4(b).Payload())
 	c.IRS = seqnum.Value(tcp.SequenceNumber())
@@ -76,14 +75,13 @@
 	})
 
 	// Receive ACK packet.
-	checker.IPv4(t, c.GetPacket(),
-		checker.TCP(
-			checker.DstPort(context.TestPort),
-			checker.TCPFlags(header.TCPFlagAck),
-			checker.SeqNum(uint32(c.IRS)+1),
-			checker.AckNum(uint32(iss)+1),
-		),
-	)
+	ackCheckers := append(checkers, checker.TCP(
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagAck),
+		checker.SeqNum(uint32(c.IRS)+1),
+		checker.AckNum(uint32(iss)+1),
+	))
+	checker.IPv4(t, c.GetPacket(), ackCheckers...)
 
 	// Wait for connection to be established.
 	select {
@@ -152,7 +150,7 @@
 	testV4Connect(t, c)
 }
 
-func testV6Connect(t *testing.T, c *context.Context) {
+func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) {
 	// Start connection attempt to IPv6 address.
 	we, ch := waiter.NewChannelEntry(nil)
 	c.WQ.EventRegister(&we, waiter.EventOut)
@@ -165,12 +163,11 @@
 
 	// Receive SYN packet.
 	b := c.GetV6Packet()
-	checker.IPv6(t, b,
-		checker.TCP(
-			checker.DstPort(context.TestPort),
-			checker.TCPFlags(header.TCPFlagSyn),
-		),
-	)
+	synCheckers := append(checkers, checker.TCP(
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagSyn),
+	))
+	checker.IPv6(t, b, synCheckers...)
 
 	tcp := header.TCP(header.IPv6(b).Payload())
 	c.IRS = seqnum.Value(tcp.SequenceNumber())
@@ -186,14 +183,13 @@
 	})
 
 	// Receive ACK packet.
-	checker.IPv6(t, c.GetV6Packet(),
-		checker.TCP(
-			checker.DstPort(context.TestPort),
-			checker.TCPFlags(header.TCPFlagAck),
-			checker.SeqNum(uint32(c.IRS)+1),
-			checker.AckNum(uint32(iss)+1),
-		),
-	)
+	ackCheckers := append(checkers, checker.TCP(
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagAck),
+		checker.SeqNum(uint32(c.IRS)+1),
+		checker.AckNum(uint32(iss)+1),
+	))
+	checker.IPv6(t, c.GetV6Packet(), ackCheckers...)
 
 	// Wait for connection to be established.
 	select {
diff --git a/tcpip/transport/tcp/endpoint.go b/tcpip/transport/tcp/endpoint.go
index 55adf5b..ba897be 100644
--- a/tcpip/transport/tcp/endpoint.go
+++ b/tcpip/transport/tcp/endpoint.go
@@ -494,6 +494,10 @@
 	// amss is the advertised MSS to the peer by this endpoint.
 	amss uint16
 
+	// sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+	// applied while sending packets. Defaults to 0 as on Linux.
+	sendTOS uint8
+
 	gso *stack.GSO
 
 	// TODO(b/142022063): Add ability to save and restore per endpoint stats.
@@ -1136,6 +1140,8 @@
 
 // SetSockOpt sets a socket option.
 func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+	// Lower 2 bits represents ECN bits. RFC 3168, section 23.1
+	const inetECNMask = 3
 	switch v := opt.(type) {
 	case tcpip.DelayOption:
 		if v == 0 {
@@ -1296,6 +1302,23 @@
 		// Linux returns ENOENT when an invalid congestion
 		// control algorithm is specified.
 		return tcpip.ErrNoSuchFile
+
+	case tcpip.IPv4TOSOption:
+		e.mu.Lock()
+		// TODO(gvisor.dev/issue/995): ECN is not currently supported,
+		// ignore the bits for now.
+		e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+		e.mu.Unlock()
+		return nil
+
+	case tcpip.IPv6TrafficClassOption:
+		e.mu.Lock()
+		// TODO(gvisor.dev/issue/995): ECN is not currently supported,
+		// ignore the bits for now.
+		e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+		e.mu.Unlock()
+		return nil
+
 	default:
 		return nil
 	}
@@ -1495,6 +1518,18 @@
 		e.mu.Unlock()
 		return nil
 
+	case *tcpip.IPv4TOSOption:
+		e.mu.RLock()
+		*o = tcpip.IPv4TOSOption(e.sendTOS)
+		e.mu.RUnlock()
+		return nil
+
+	case *tcpip.IPv6TrafficClassOption:
+		e.mu.RLock()
+		*o = tcpip.IPv6TrafficClassOption(e.sendTOS)
+		e.mu.RUnlock()
+		return nil
+
 	default:
 		return tcpip.ErrUnknownProtocolOption
 	}
diff --git a/tcpip/transport/tcp/protocol.go b/tcpip/transport/tcp/protocol.go
index aee8e1b..95da42d 100644
--- a/tcpip/transport/tcp/protocol.go
+++ b/tcpip/transport/tcp/protocol.go
@@ -153,7 +153,7 @@
 
 	ack := s.sequenceNumber.Add(s.logicalLen())
 
-	sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */)
+	sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
 }
 
 // SetOption implements TransportProtocol.SetOption.
diff --git a/tcpip/transport/tcp/tcp_test.go b/tcpip/transport/tcp/tcp_test.go
index 05fa5e8..56716e2 100644
--- a/tcpip/transport/tcp/tcp_test.go
+++ b/tcpip/transport/tcp/tcp_test.go
@@ -474,6 +474,107 @@
 	)
 }
 
+func TestTOSV4(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+	c.EP = ep
+	// tos keeps the two low-order ECN bits clear, so SetSockOpt stores it unmasked.
+	const tos = 0xC0
+	if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
+		t.Fatalf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+	}
+
+	var v tcpip.IPv4TOSOption
+	if err := c.EP.GetSockOpt(&v); err != nil {
+		t.Fatalf("GetSockopt failed: %s", err)
+	}
+
+	if want := tcpip.IPv4TOSOption(tos); v != want {
+		t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+	}
+
+	testV4Connect(t, c, checker.TOS(tos, 0))
+
+	data := []byte{1, 2, 3}
+	view := buffer.NewView(len(data))
+	copy(view, data)
+
+	if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Write failed: %s", err)
+	}
+
+	// Check that data is received.
+	b := c.GetPacket()
+	checker.IPv4(t, b,
+		checker.PayloadLen(len(data)+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(790), // Acknum is initial sequence number + 1
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+		checker.TOS(tos, 0),
+	)
+
+	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+		t.Errorf("got data = %x, want = %x", p, data)
+	}
+}
+
+func TestTrafficClassV6(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+	// tos keeps the two low-order ECN bits clear, so SetSockOpt stores it unmasked.
+	const tos = 0xC0
+	if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
+		t.Fatalf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err)
+	}
+
+	var v tcpip.IPv6TrafficClassOption
+	if err := c.EP.GetSockOpt(&v); err != nil {
+		t.Fatalf("GetSockopt failed: %s", err)
+	}
+
+	if want := tcpip.IPv6TrafficClassOption(tos); v != want {
+		t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+	}
+
+	// Test the connection request.
+	testV6Connect(t, c, checker.TOS(tos, 0))
+
+	data := []byte{1, 2, 3}
+	view := buffer.NewView(len(data))
+	copy(view, data)
+
+	if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Write failed: %s", err)
+	}
+
+	// Check that data is received.
+	b := c.GetV6Packet()
+	checker.IPv6(t, b,
+		checker.PayloadLen(len(data)+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(790), // Acknum is initial sequence number + 1
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+		checker.TOS(tos, 0),
+	)
+
+	if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+		t.Errorf("got data = %x, want = %x", p, data)
+	}
+}
+
 func TestConnectBindToDevice(t *testing.T) {
 	for _, test := range []struct {
 		name   string
diff --git a/tcpip/transport/udp/endpoint.go b/tcpip/transport/udp/endpoint.go
index 8adba9b..a4aed3e 100644
--- a/tcpip/transport/udp/endpoint.go
+++ b/tcpip/transport/udp/endpoint.go
@@ -106,6 +106,10 @@
 	bindToDevice   tcpip.NICID
 	broadcast      bool
 
+	// sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+	// applied while sending packets. Defaults to 0 as on Linux.
+	sendTOS uint8
+
 	// shutdownFlags represent the current shutdown state of the endpoint.
 	shutdownFlags tcpip.ShutdownFlags
 
@@ -429,7 +433,7 @@
 		useDefaultTTL = false
 	}
 
-	if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL); err != nil {
+	if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil {
 		return 0, nil, err
 	}
 	return int64(len(v)), nil, nil
@@ -628,6 +632,18 @@
 		e.mu.Unlock()
 
 		return nil
+
+	case tcpip.IPv4TOSOption:
+		e.mu.Lock()
+		e.sendTOS = uint8(v)
+		e.mu.Unlock()
+		return nil
+
+	case tcpip.IPv6TrafficClassOption:
+		e.mu.Lock()
+		e.sendTOS = uint8(v)
+		e.mu.Unlock()
+		return nil
 	}
 	return nil
 }
@@ -748,6 +764,18 @@
 		}
 		return nil
 
+	case *tcpip.IPv4TOSOption:
+		e.mu.RLock()
+		*o = tcpip.IPv4TOSOption(e.sendTOS)
+		e.mu.RUnlock()
+		return nil
+
+	case *tcpip.IPv6TrafficClassOption:
+		e.mu.RLock()
+		*o = tcpip.IPv6TrafficClassOption(e.sendTOS)
+		e.mu.RUnlock()
+		return nil
+
 	default:
 		return tcpip.ErrUnknownProtocolOption
 	}
@@ -755,7 +783,7 @@
 
 // sendUDP sends a UDP segment via the provided network endpoint and under the
 // provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error {
 	// Allocate a buffer for the UDP header.
 	hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
 
@@ -778,7 +806,10 @@
 		udp.SetChecksum(^udp.CalculateChecksum(xsum))
 	}
 
-	if err := r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl, useDefaultTTL); err != nil {
+	if useDefaultTTL {
+		ttl = r.DefaultTTL()
+	}
+	if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
 		r.Stats().UDP.PacketSendErrors.Increment()
 		return err
 	}
diff --git a/tcpip/transport/udp/protocol.go b/tcpip/transport/udp/protocol.go
index da02493..9969fc3 100644
--- a/tcpip/transport/udp/protocol.go
+++ b/tcpip/transport/udp/protocol.go
@@ -130,7 +130,7 @@
 		pkt.SetType(header.ICMPv4DstUnreachable)
 		pkt.SetCode(header.ICMPv4PortUnreachable)
 		pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
-		r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, 0, true /* useDefaultTTL */)
+		r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
 
 	case header.IPv6AddressSize:
 		if !r.Stack().AllowICMPMessage() {
@@ -164,7 +164,7 @@
 		pkt.SetType(header.ICMPv6DstUnreachable)
 		pkt.SetCode(header.ICMPv6PortUnreachable)
 		pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
-		r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, 0, true /* useDefaultTTL */)
+		r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
 	}
 	return true
 }
diff --git a/tcpip/transport/udp/udp_test.go b/tcpip/transport/udp/udp_test.go
index 4b5862c..50e52ba 100644
--- a/tcpip/transport/udp/udp_test.go
+++ b/tcpip/transport/udp/udp_test.go
@@ -1267,6 +1267,76 @@
 	}
 }
 
+func TestTOSV4(t *testing.T) {
+	for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+			c.createEndpointForFlow(flow)
+
+			const tos = 0xC0
+			var v tcpip.IPv4TOSOption
+			if err := c.ep.GetSockOpt(&v); err != nil {
+				t.Fatalf("GetSockopt failed: %s", err)
+			}
+			// A fresh endpoint must report the default TOS of 0.
+			if want := tcpip.IPv4TOSOption(0); v != want {
+				t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+			}
+
+			// Set the option and read it back; later checks depend on it,
+			// so a failure here aborts the subtest instead of cascading.
+			if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
+				t.Fatalf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+			}
+			if err := c.ep.GetSockOpt(&v); err != nil {
+				t.Fatalf("GetSockopt failed: %s", err)
+			}
+			if want := tcpip.IPv4TOSOption(tos); v != want {
+				t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+			}
+
+			// Write on this flow and check the outgoing packet with the TOS checker.
+			testWrite(c, flow, checker.TOS(tos, 0))
+		})
+	}
+}
+
+func TestTOSV6(t *testing.T) {
+	for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
+		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+			c.createEndpointForFlow(flow)
+
+			const tos = 0xC0
+			var v tcpip.IPv6TrafficClassOption
+			if err := c.ep.GetSockOpt(&v); err != nil {
+				t.Fatalf("GetSockopt failed: %s", err)
+			}
+			// A fresh endpoint must report the default traffic class of 0.
+			if want := tcpip.IPv6TrafficClassOption(0); v != want {
+				t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+			}
+
+			// Set the option and read it back; later checks depend on it,
+			// so a failure here aborts the subtest instead of cascading.
+			if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
+				t.Fatalf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err)
+			}
+			if err := c.ep.GetSockOpt(&v); err != nil {
+				t.Fatalf("GetSockopt failed: %s", err)
+			}
+			if want := tcpip.IPv6TrafficClassOption(tos); v != want {
+				t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+			}
+
+			// Write on this flow and check the outgoing packet with the TOS checker.
+			testWrite(c, flow, checker.TOS(tos, 0))
+		})
+	}
+}
+
 func TestMulticastInterfaceOption(t *testing.T) {
 	for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
 		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {