[netstack] correctly handle link-local addresses
This is copied from gvisor:
- https://github.com/google/gvisor/blob/5f08f8f/pkg/sentry/socket/epsocket/epsocket.go#L181-L251
- https://github.com/google/gvisor/blob/5f08f8f/pkg/sentry/socket/epsocket/epsocket.go#L1300-L1345
Test: None
Change-Id: Ia802af8f7fe530161bad93afc7bc03bd50481de4
diff --git a/go/src/netstack/socket_conv.go b/go/src/netstack/socket_conv.go
index 3a7f87d..a494367 100644
--- a/go/src/netstack/socket_conv.go
+++ b/go/src/netstack/socket_conv.go
@@ -6,13 +6,11 @@
import (
"encoding/binary"
- "fmt"
"log"
"math"
"time"
"github.com/google/netstack/tcpip"
- "github.com/google/netstack/tcpip/header"
"github.com/google/netstack/tcpip/transport/tcp"
"github.com/google/netstack/tcpip/transport/udp"
)
@@ -24,53 +22,6 @@
// #include <netinet/udp.h>
import "C"
-func isZeros(b []byte) bool {
- for _, b := range b {
- if b != 0 {
- return false
- }
- }
- return true
-}
-
-func (v *C.struct_sockaddr_in) Decode() tcpip.FullAddress {
- out := tcpip.FullAddress{
- Port: binary.BigEndian.Uint16(v.sin_port.Bytes()),
- }
- if b := v.sin_addr.Bytes(); !isZeros(b) {
- out.Addr = tcpip.Address(b)
- }
- return out
-}
-
-func (v *C.struct_sockaddr_in) Encode(addr tcpip.FullAddress) error {
- v.sin_family = C.AF_INET
- if n := copy(v.sin_addr.Bytes(), addr.Addr); n < header.IPv4AddressSize {
- return fmt.Errorf("short %T: %d/%d", v, n, header.IPv4AddressSize)
- }
- binary.BigEndian.PutUint16(v.sin_port.Bytes(), addr.Port)
- return nil
-}
-
-func (v *C.struct_sockaddr_in6) Decode() tcpip.FullAddress {
- out := tcpip.FullAddress{
- Port: binary.BigEndian.Uint16(v.sin6_port.Bytes()),
- }
- if b := v.sin6_addr.Bytes(); !isZeros(b) {
- out.Addr = tcpip.Address(b)
- }
- return out
-}
-
-func (v *C.struct_sockaddr_in6) Encode(addr tcpip.FullAddress) error {
- v.sin6_family = C.AF_INET6
- if n := copy(v.sin6_addr.Bytes(), addr.Addr); n < header.IPv6AddressSize {
- return fmt.Errorf("short %T: %d/%d", v, n, header.IPv6AddressSize)
- }
- binary.BigEndian.PutUint16(v.sin6_port.Bytes(), addr.Port)
- return nil
-}
-
// Functions below are adapted from
// github.com/google/gvisor/pkg/sentry/socket/epsocket/epsocket.go.
//
@@ -626,3 +577,10 @@
}
return ep.SetSockOpt(struct{}{})
}
+
+// isLinkLocal determines if the given IPv6 address is link-local. This is the
+// case when it has the fe80::/10 prefix. This check is used to determine when
+// the NICID is relevant for a given IPv6 address.
+func isLinkLocal(addr tcpip.Address) bool {
+ return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
+}
diff --git a/go/src/netstack/socket_encode.go b/go/src/netstack/socket_encode.go
index c1cf72e..34b01da 100644
--- a/go/src/netstack/socket_encode.go
+++ b/go/src/netstack/socket_encode.go
@@ -8,11 +8,14 @@
package netstack
import (
+ "encoding/binary"
"fmt"
"unsafe"
"github.com/google/netstack/tcpip"
"github.com/google/netstack/tcpip/header"
+ "github.com/google/netstack/tcpip/network/ipv4"
+ "github.com/google/netstack/tcpip/network/ipv6"
)
// #cgo CFLAGS: -D_GNU_SOURCE
@@ -136,24 +139,68 @@
return nil
}
+func isZeros(b []byte) bool {
+ for _, b := range b {
+ if b != 0 {
+ return false
+ }
+ }
+ return true
+}
+
func (v *C.struct_sockaddr_storage) Decode() (tcpip.FullAddress, error) {
switch v.ss_family {
case C.AF_INET:
- return (*C.struct_sockaddr_in)(unsafe.Pointer(v)).Decode(), nil
+ v := (*C.struct_sockaddr_in)(unsafe.Pointer(v))
+ out := tcpip.FullAddress{
+ Port: binary.BigEndian.Uint16(v.sin_port.Bytes()),
+ }
+ if b := v.sin_addr.Bytes(); !isZeros(b) {
+ out.Addr = tcpip.Address(b)
+ }
+ return out, nil
case C.AF_INET6:
- return (*C.struct_sockaddr_in6)(unsafe.Pointer(v)).Decode(), nil
+ v := (*C.struct_sockaddr_in6)(unsafe.Pointer(v))
+ out := tcpip.FullAddress{
+ Port: binary.BigEndian.Uint16(v.sin6_port.Bytes()),
+ }
+ if b := v.sin6_addr.Bytes(); !isZeros(b) {
+ out.Addr = tcpip.Address(b)
+ }
+ if isLinkLocal(out.Addr) {
+ out.NIC = tcpip.NICID(v.sin6_scope_id)
+ }
+ return out, nil
default:
return tcpip.FullAddress{}, fmt.Errorf("unknown sockaddr_storage.ss_family: %d", v.ss_family)
}
}
-func (v *C.struct_sockaddr_storage) Encode(addr tcpip.FullAddress) (int, error) {
- switch len(addr.Addr) {
- case header.IPv4AddressSize:
- return C.sizeof_struct_sockaddr_in, (*C.struct_sockaddr_in)(unsafe.Pointer(v)).Encode(addr)
- case header.IPv6AddressSize:
- return C.sizeof_struct_sockaddr_in6, (*C.struct_sockaddr_in6)(unsafe.Pointer(v)).Encode(addr)
+func (v *C.struct_sockaddr_storage) Encode(netProto tcpip.NetworkProtocolNumber, addr tcpip.FullAddress) int {
+ switch netProto {
+ case ipv4.ProtocolNumber:
+ v := (*C.struct_sockaddr_in)(unsafe.Pointer(v))
+ copy(v.sin_addr.Bytes(), addr.Addr)
+ v.sin_family = C.AF_INET
+ binary.BigEndian.PutUint16(v.sin_port.Bytes(), addr.Port)
+ return C.sizeof_struct_sockaddr_in
+ case ipv6.ProtocolNumber:
+ v := (*C.struct_sockaddr_in6)(unsafe.Pointer(v))
+ if len(addr.Addr) == header.IPv4AddressSize {
+ // Copy address in v4-mapped format.
+ copy(v.sin6_addr.Bytes()[header.IPv6AddressSize-header.IPv4AddressSize:], addr.Addr)
+ v.sin6_addr.Bytes()[header.IPv6AddressSize-header.IPv4AddressSize-1] = 0xff
+ v.sin6_addr.Bytes()[header.IPv6AddressSize-header.IPv4AddressSize-2] = 0xff
+ } else {
+ copy(v.sin6_addr.Bytes(), addr.Addr)
+ }
+ v.sin6_family = C.AF_INET6
+ binary.BigEndian.PutUint16(v.sin6_port.Bytes(), addr.Port)
+ if isLinkLocal(addr.Addr) {
+ v.sin6_scope_id = C.uint32_t(addr.NIC)
+ }
+ return C.sizeof_struct_sockaddr_in6
default:
- return 0, fmt.Errorf("unknown address family %+v", addr)
+ panic(fmt.Sprintf("unknown network protocol number: %v", netProto))
}
}
diff --git a/go/src/netstack/socket_server.go b/go/src/netstack/socket_server.go
index 3edc1e9..6af8edf 100644
--- a/go/src/netstack/socket_server.go
+++ b/go/src/netstack/socket_server.go
@@ -255,11 +255,7 @@
out := make([]byte, C.FDIO_SOCKET_MSG_HEADER_SIZE+len(v))
if err := func() error {
var fdioSocketMsg C.struct_fdio_socket_msg
- n, err := fdioSocketMsg.addr.Encode(sender)
- if err != nil {
- return err
- }
- fdioSocketMsg.addrlen = C.socklen_t(n)
+ fdioSocketMsg.addrlen = C.socklen_t(fdioSocketMsg.addr.Encode(ios.netProto, sender))
if _, err := fdioSocketMsg.MarshalTo(out[:C.FDIO_SOCKET_MSG_HEADER_SIZE]); err != nil {
return err
}
@@ -610,21 +606,15 @@
}
rep.info[index].index = C.ushort(index + 1)
rep.info[index].flags |= C.NETC_IFF_UP
- if _, err := rep.info[index].addr.Encode(tcpip.FullAddress{NIC: nicid, Addr: ifs.nic.Addr}); err != nil {
- log.Printf("encoding addr failed: %v", err)
- }
- if _, err := rep.info[index].netmask.Encode(tcpip.FullAddress{NIC: nicid, Addr: tcpip.Address(ifs.nic.Netmask)}); err != nil {
- log.Printf("encoding netmask failed: %v", err)
- }
+ rep.info[index].addr.Encode(ipv4.ProtocolNumber, tcpip.FullAddress{NIC: nicid, Addr: ifs.nic.Addr})
+ rep.info[index].netmask.Encode(ipv4.ProtocolNumber, tcpip.FullAddress{NIC: nicid, Addr: tcpip.Address(ifs.nic.Netmask)})
// Long-hand for: broadaddr = ifs.nic.Addr | ^ifs.nic.Netmask
broadaddr := []byte(ifs.nic.Addr)
for i := range broadaddr {
broadaddr[i] |= ^ifs.nic.Netmask[i]
}
- if _, err := rep.info[index].broadaddr.Encode(tcpip.FullAddress{NIC: nicid, Addr: tcpip.Address(broadaddr)}); err != nil {
- log.Printf("encoding broadaddr failed: %v", err)
- }
+ rep.info[index].broadaddr.Encode(ipv4.ProtocolNumber, tcpip.FullAddress{NIC: nicid, Addr: tcpip.Address(broadaddr)})
index++
}
rep.n_info = index
@@ -694,22 +684,14 @@
return zx.ErrInvalidArgs
}
-func fdioSockAddrReply(addr tcpip.FullAddress, msg *zxsocket.Msg) zx.Status {
+func fdioSockAddrReply(netProto tcpip.NetworkProtocolNumber, addr tcpip.FullAddress, msg *zxsocket.Msg) zx.Status {
var rep C.struct_zxrio_sockaddr_reply
- {
- n, err := rep.addr.Encode(addr)
- if err != nil {
- return errStatus(err)
- }
- rep.len = C.socklen_t(n)
+ rep.len = C.socklen_t(rep.addr.Encode(netProto, addr))
+ n, err := rep.MarshalTo(msg.Data[:])
+ if err != nil {
+ return errStatus(err)
}
- {
- n, err := rep.MarshalTo(msg.Data[:])
- if err != nil {
- return errStatus(err)
- }
- msg.Datalen = uint32(n)
- }
+ msg.Datalen = uint32(n)
msg.SetOff(0)
return zx.ErrOk
}
@@ -731,7 +713,7 @@
if debug {
log.Printf("getsockname(): %+v", a)
}
- return fdioSockAddrReply(a, msg)
+ return fdioSockAddrReply(ios.netProto, a, msg)
}
func (ios *iostate) opGetPeerName(msg *zxsocket.Msg) (status zx.Status) {
@@ -739,7 +721,7 @@
if err != nil {
return zxNetError(err)
}
- return fdioSockAddrReply(a, msg)
+ return fdioSockAddrReply(ios.netProto, a, msg)
}
func (ios *iostate) loopListen(inCh chan struct{}) error {