blob: 580218b64485d1bc7a15b966a9540f418968dd19 [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! General-purpose socket utilities common to device layer and IP layer
//! sockets.
use core::{
convert::Infallible as Never, fmt::Debug, hash::Hash, marker::PhantomData, num::NonZeroU16,
};
use derivative::Derivative;
use net_types::{
ip::{GenericOverIp, Ip, IpAddress, Ipv4, Ipv6},
AddrAndZone, ScopeableAddress, SpecifiedAddr, Witness, ZonedAddr,
};
use crate::{
data_structures::socketmap::{
Entry, IterShadows, OccupiedEntry as SocketMapOccupiedEntry, SocketMap, Tagged,
},
device,
error::{ExistsError, NotFoundError},
ip::{device::state::IpDeviceStateIpExt, socket::SocketIpExt, IpLayerIpExt},
socket::address::{
AddrVecIter, ConnAddr, ConnIpAddr, ListenerAddr, ListenerIpAddr, SocketIpAddr,
},
};
/// A dual-stack IP extension trait that provides the `OtherVersion` associated
/// type.
///
/// Implemented in this file for [`Ipv4`] and [`Ipv6`], each naming the other
/// as its `OtherVersion`.
pub trait DualStackIpExt: IpLayerIpExt + IpDeviceStateIpExt {
/// The "other" IP version, e.g. [`Ipv4`] for [`Ipv6`] and vice-versa.
///
/// The bounds require that the datagram and TCP socket `DualStackIpExt`
/// traits, implemented for the other version, name `Self` as *their*
/// `OtherVersion`.
type OtherVersion: IpLayerIpExt
+ IpDeviceStateIpExt
+ crate::socket::datagram::DualStackIpExt<OtherVersion = Self>
+ crate::transport::tcp::socket::DualStackIpExt<OtherVersion = Self>;
}
// The other stack for IPv4 is IPv6.
impl DualStackIpExt for Ipv4 {
type OtherVersion = Ipv6;
}
// The other stack for IPv6 is IPv4.
impl DualStackIpExt for Ipv6 {
type OtherVersion = Ipv4;
}
/// State belonging to either IP stack.
///
/// Like [`either::Either`], but with more helpful variant names.
///
/// Note that this type is not optimally type-safe, because `T` and `O` are not
/// bound by `IP` and `IP::OtherVersion`, respectively. In many cases it may be
/// more appropriate to define a one-off enum parameterized over `I: Ip`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum EitherStack<T, O> {
/// In the current stack version.
ThisStack(T),
/// In the other version of the stack.
OtherStack(O),
}
/// Control flow type containing either a dual-stack or non-dual-stack context.
///
/// This type exists to provide nice names to the result of
/// [`BoundStateContext::dual_stack_context`], and to allow generic code to
/// match on when checking whether a socket protocol and IP version support
/// dual-stack operation. If dual-stack operation is supported, a
/// [`MaybeDualStack::DualStack`] value will be held, otherwise a
/// [`MaybeDualStack::NotDualStack`] value.
///
/// Note that the templated types do not have trait bounds; those are provided
/// by the trait with the `dual_stack_context` function.
///
/// In monomorphized code, this type frequently has exactly one template
/// parameter that is uninstantiable (it contains an instance of
/// [`core::convert::Infallible`] or some other empty enum, or a reference to
/// the same)! That lets the compiler optimize it out completely, creating no
/// actual runtime overhead.
#[derive(Debug)]
pub enum MaybeDualStack<DS, NDS> {
/// The context for the case where dual-stack operation is supported.
DualStack(DS),
/// The context for the case where dual-stack operation is not supported.
NotDualStack(NDS),
}
/// An error encountered while enabling or disabling dual-stack operation.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub enum SetDualStackEnabledError {
/// A socket can only have dual stack enabled or disabled while unbound.
SocketIsBound,
/// Similar to [`NotDualStackCapableError`]; the socket's protocol is not
/// dual-stack capable.
NotCapable,
}
/// An error encountered when attempting to perform dual-stack operations on a
/// socket whose protocol is not dual-stack capable.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub struct NotDualStackCapableError;
/// Describes which direction(s) of the data path should be shut down.
///
/// The derived [`Default`] leaves both flags `false`, i.e. neither direction
/// is shut down.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub(crate) struct Shutdown {
/// True if the send path is shut down for the owning socket.
///
/// If this is true, the socket should not be able to send packets.
pub send: bool,
/// True if the receive path is shut down for the owning socket.
///
/// If this is true, the socket should not be able to receive packets.
pub receive: bool,
}
/// Which direction(s) to shut down for a socket.
///
/// Unlike a pair of booleans, this type cannot express "shut down nothing";
/// see [`ShutdownType::from_send_receive`], which yields `None` for that case.
#[derive(Copy, Clone, Debug, Eq, GenericOverIp, PartialEq)]
#[generic_over_ip()]
pub enum ShutdownType {
/// Prevent sending packets on the socket.
Send,
/// Prevent receiving packets on the socket.
Receive,
/// Prevent sending and receiving packets on the socket.
SendAndReceive,
}
impl ShutdownType {
    /// Splits this value into a `(send, receive)` pair of flags.
    ///
    /// Each flag is `true` when the corresponding direction is covered by
    /// this shutdown type.
    pub(crate) fn to_send_receive(&self) -> (bool, bool) {
        let send = matches!(self, Self::Send | Self::SendAndReceive);
        let receive = matches!(self, Self::Receive | Self::SendAndReceive);
        (send, receive)
    }

    /// Creates a [`ShutdownType`] from a pair of bools for send and receive.
    ///
    /// Returns `None` when neither direction is requested, since no variant
    /// represents shutting down nothing.
    pub fn from_send_receive(send: bool, receive: bool) -> Option<Self> {
        if send && receive {
            Some(Self::SendAndReceive)
        } else if send {
            Some(Self::Send)
        } else if receive {
            Some(Self::Receive)
        } else {
            None
        }
    }
}
/// Determines whether the provided address is underspecified by itself.
///
/// Some addresses are ambiguous and so must have a zone identifier in order
/// to be used in a socket address. This function returns true for IPv6
/// link-local addresses and false for all others.
///
/// The answer is exactly "does [`try_into_null_zoned`] produce a value for
/// this address".
pub(crate) fn must_have_zone<A: IpAddress>(addr: &SpecifiedAddr<A>) -> bool {
    let null_zoned = try_into_null_zoned(addr);
    null_zoned.is_some()
}
/// Determines whether a change in device is allowed given the local and
/// remote addresses.
///
/// If neither address requires a zone (see [`must_have_zone`]), any change is
/// allowed. Otherwise the change is allowed only when `new_device` is present
/// and compares equal to `old_device`.
///
/// # Panics
///
/// Panics if one of the addresses requires a zone but `old_device` is `None`.
pub(crate) fn can_device_change<
    A: IpAddress,
    W: device::WeakId<Strong = S>,
    S: device::StrongId,
>(
    local_ip: Option<&SpecifiedAddr<A>>,
    remote_ip: Option<&SpecifiedAddr<A>>,
    old_device: Option<&W>,
    new_device: Option<&S>,
) -> bool {
    // A zone is required as soon as either of the present addresses needs one.
    let zone_required = local_ip.into_iter().chain(remote_ip).any(must_have_zone);
    if !zone_required {
        return true;
    }

    let Some(old_device) = old_device else {
        panic!("local_ip={:?} or remote_ip={:?} must have zone", local_ip, remote_ip)
    };

    match new_device {
        Some(new_device) => old_device == new_device,
        None => false,
    }
}
/// Converts into a [`AddrAndZone<A, ()>`] if the address requires a zone.
///
/// Otherwise returns `None`. Loopback addresses always yield `None`; for every
/// other address the decision is delegated to [`AddrAndZone::new`].
pub(crate) fn try_into_null_zoned<A: IpAddress, W: Witness<A> + ScopeableAddress + Copy>(
    addr: &W,
) -> Option<AddrAndZone<W, ()>> {
    if addr.get().is_loopback() {
        None
    } else {
        AddrAndZone::new(*addr, ())
    }
}
/// Provides a specified IP address to use in-place of an unspecified remote.
///
/// Concretely, this method is called during `connect()` and `send_to()` socket
/// operations to transform an unspecified remote IP address to the loopback
/// address. This ensures conformance with Linux and BSD.
///
/// A present `addr` is returned unchanged; an absent one becomes the unzoned
/// loopback address of `I`.
pub(crate) fn specify_unspecified_remote<I: SocketIpExt, A: From<SocketIpAddr<I::Addr>>, Z>(
    addr: Option<ZonedAddr<A, Z>>,
) -> ZonedAddr<A, Z> {
    match addr {
        Some(specified) => specified,
        None => ZonedAddr::Unzoned(I::LOOPBACK_ADDRESS_AS_SOCKET_IP_ADDR.into()),
    }
}
/// Specification for the identifiers in an [`AddrVec`].
///
/// This is a convenience trait for bundling together the local and remote
/// identifiers for a protocol.
pub trait SocketMapAddrSpec {
/// The local identifier portion of a socket address.
///
/// Must be convertible into a [`NonZeroU16`].
type LocalIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq + Into<NonZeroU16>;
/// The remote identifier portion of a socket address.
type RemoteIdentifier: Copy + Clone + Debug + Send + Sync + Hash + Eq;
}
/// Summary of which optional parts of a listener address are present.
///
/// Produced by the `info` method on [`ListenerAddr`] and passed to
/// [`SocketMapStateSpec::listener_tag`].
pub struct ListenerAddrInfo {
/// True if the listener address has a device.
pub(crate) has_device: bool,
/// True if the listener address has an IP address.
pub(crate) specified_addr: bool,
}
impl<A: IpAddress, D: device::Id, LI> ListenerAddr<ListenerIpAddr<A, LI>, D> {
    /// Summarizes which optional parts of this listener address are present.
    ///
    /// The local identifier does not contribute to the result.
    pub(crate) fn info(&self) -> ListenerAddrInfo {
        let has_device = self.device.is_some();
        let specified_addr = self.ip.addr.is_some();
        ListenerAddrInfo { has_device, specified_addr }
    }
}
/// Specifies the type parameters for [`BoundSocketMap`] state as a single bundle.
pub trait SocketMapStateSpec {
/// The tag value of a socket address vector entry.
///
/// These values are derived from [`Self::ListenerAddrState`] and
/// [`Self::ConnAddrState`].
type AddrVecTag: Eq + Copy + Debug + 'static;
/// Computes the tag for a listener address entry from a summary of the
/// address (`info`) and the state stored at that address.
fn listener_tag(info: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag;
/// Computes the tag for a connected address entry from whether the address
/// has a device and the state stored at that address.
fn connected_tag(has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag;
/// An identifier for a listening socket.
type ListenerId: Clone + Debug;
/// An identifier for a connected socket.
type ConnId: Clone + Debug;
/// The state stored for a listening socket that is used to determine
/// whether sockets can share an address.
type ListenerSharingState: Clone + Debug;
/// The state stored for a connected socket that is used to determine
/// whether sockets can share an address.
type ConnSharingState: Clone + Debug;
/// The state stored for a listener socket address.
type ListenerAddrState: SocketMapAddrStateSpec<Id = Self::ListenerId, SharingState = Self::ListenerSharingState>
+ Debug;
/// The state stored for a connected socket address.
type ConnAddrState: SocketMapAddrStateSpec<Id = Self::ConnId, SharingState = Self::ConnSharingState>
+ Debug;
}
/// Error returned by [`SocketMapAddrStateSpec`] operations when a sharing
/// state is incompatible with the socket(s) already present at an address.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(crate) struct IncompatibleError;
/// A one-shot handle that allows inserting exactly one item of type `T`.
///
/// Values of this type are handed out by
/// [`SocketMapAddrStateSpec::try_get_inserter`].
pub(crate) trait Inserter<T> {
/// Inserts the provided item and consumes `self`.
///
/// Inserts a single item and consumes the inserter (thus preventing
/// additional insertions).
fn insert(self, item: T);
}
/// Any mutable reference to an [`Extend`]-able collection can act as an
/// inserter: inserting extends the collection by the single item.
impl<'a, T, E: Extend<T>> Inserter<T> for &'a mut E {
    fn insert(self, item: T) {
        // `Option<T>` is a one-element iterator, so this adds exactly `item`.
        self.extend(Some(item))
    }
}
/// `Never` is uninhabited, so this inserter can never actually be invoked.
impl<T> Inserter<T> for Never {
fn insert(self, _: T) {
// No value of `Never` exists, so the empty match is exhaustive.
match self {}
}
}
/// Describes an entry in a [`SocketMap`] for a listener or connection address.
pub trait SocketMapAddrStateSpec {
/// The type of ID that can be present at the address.
type Id;
/// The sharing state for the address.
///
/// This can be used to determine whether a socket can be inserted at the
/// address. Every socket has its own sharing state associated with it,
/// though the sharing state is not necessarily stored in the address
/// entry.
type SharingState;
/// The type of inserter returned by [`SocketMapAddrStateSpec::try_get_inserter`].
type Inserter<'a>: Inserter<Self::Id> + 'a
where
Self: 'a,
Self::Id: 'a;
/// Creates a new `Self` holding the provided socket with the given new
/// sharing state at the specified address.
fn new(new_sharing_state: &Self::SharingState, id: Self::Id) -> Self;
/// Looks up the ID in self, returning `true` if it is present.
fn contains_id(&self, id: &Self::Id) -> bool;
/// Enables insertion in `self` for a new socket with the provided sharing
/// state.
///
/// If the new state is incompatible with the existing socket(s),
/// implementations of this function should return `Err(IncompatibleError)`.
/// If `Ok(x)` is returned, calling `x.insert(y)` will insert `y` into
/// `self`.
fn try_get_inserter<'a, 'b>(
&'b mut self,
new_sharing_state: &'a Self::SharingState,
) -> Result<Self::Inserter<'b>, IncompatibleError>;
/// Returns `Ok` if an entry with the given sharing state could be added
/// to `self`.
///
/// If this returns `Ok`, [`SocketMapAddrStateSpec::try_get_inserter`]
/// should succeed.
fn could_insert(&self, new_sharing_state: &Self::SharingState)
-> Result<(), IncompatibleError>;
/// Removes the given socket from the existing state.
///
/// Implementations should assume that `id` is contained in `self`.
///
/// The returned [`RemoveResult`] tells the caller whether the whole entry
/// should now be removed because `id` was its last socket.
fn remove_by_id(&mut self, id: Self::Id) -> RemoveResult;
}
/// Extension of [`SocketMapAddrStateSpec`] for address states that support
/// changing the sharing state of a socket that is already present.
pub(crate) trait SocketMapAddrStateUpdateSharingSpec: SocketMapAddrStateSpec {
/// Attempts to update the sharing state of the socket `id` to
/// `new_sharing_state`.
///
/// Returns `Err(IncompatibleError)` if the update is not possible.
// NOTE(review): presumably `id` must already be contained in `self`;
// confirm against implementations.
fn try_update_sharing(
&mut self,
id: Self::Id,
new_sharing_state: &Self::SharingState,
) -> Result<(), IncompatibleError>;
}
/// Policy deciding whether a socket may be inserted into a [`SocketMap`].
///
/// Parameterized over the address type `Addr` and the `SharingState` of the
/// socket being inserted, so that a single [`SocketMapStateSpec`] can provide
/// separate policies for different kinds of addresses.
pub trait SocketMapConflictPolicy<Addr, SharingState, I: Ip, D: device::Id, A: SocketMapAddrSpec>:
SocketMapStateSpec
{
/// Checks whether a new socket with the provided state can be inserted at
/// the given address in the existing socket map, returning an error
/// otherwise.
///
/// Implementations of this function should check for any potential
/// conflicts that would arise when inserting a socket with state
/// `new_sharing_state` into a new or existing entry at `addr` in
/// `socketmap`.
fn check_insert_conflicts(
new_sharing_state: &SharingState,
addr: &Addr,
socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
) -> Result<(), InsertError>;
}
/// Policy deciding whether the sharing state of a socket that is already in a
/// [`SocketMap`] may be changed.
pub(crate) trait SocketMapUpdateSharingPolicy<Addr, SharingState, I: Ip, D: device::Id, A>:
SocketMapConflictPolicy<Addr, SharingState, I, D, A>
where
A: SocketMapAddrSpec,
{
/// Checks whether the socket at `addr` in `socketmap` may move from
/// `old_sharing` to `new_sharing_state`.
///
/// Returns `Err(UpdateSharingError)` if the change is not allowed.
fn allows_sharing_update(
socketmap: &SocketMap<AddrVec<I, D, A>, Bound<Self>>,
addr: &Addr,
old_sharing: &SharingState,
new_sharing_state: &SharingState,
) -> Result<(), UpdateSharingError>;
}
/// The state stored for a single address in a [`BoundSocketMap`]: either
/// listener-address state or connected-address state.
#[derive(Derivative)]
#[derivative(Debug(bound = "S::ListenerAddrState: Debug, S::ConnAddrState: Debug"))]
pub enum Bound<S: SocketMapStateSpec + ?Sized> {
/// State for a listener address.
Listen(S::ListenerAddrState),
/// State for a connected address.
Conn(S::ConnAddrState),
}
/// An "address vector" type that can hold any address in a [`SocketMap`].
///
/// This is a "vector" in the mathematical sense, in that it denotes an address
/// in a space. Here, the space is the possible addresses to which a socket
/// receiving IP packets can be bound.
///
/// `AddrVec`s are used as keys for the `SocketMap` type. Since an incoming
/// packet can match more than one address, for each incoming packet there is a
/// set of possible `AddrVec` keys whose entries (sockets) in a `SocketMap`
/// might receive the packet.
///
/// This set of keys can be ordered by precedence as described in the
/// documentation for [`AddrVecIter`]. Calling [`IterShadows::iter_shadows`] on
/// an instance will produce the sequence of addresses it has precedence over.
#[derive(Derivative)]
#[derivative(
Debug(bound = "D: Debug"),
Clone(bound = "D: Clone"),
Eq(bound = "D: Eq"),
PartialEq(bound = "D: PartialEq"),
Hash(bound = "D: Hash")
)]
pub enum AddrVec<I: Ip, D, A: SocketMapAddrSpec + ?Sized> {
/// A listener address.
Listen(ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>),
/// A connected address.
Conn(ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>),
}
/// Tags are computed by delegating to the [`SocketMapStateSpec`] tag
/// functions for the kind of state stored.
impl<I: Ip, D: device::Id, A: SocketMapAddrSpec, S: SocketMapStateSpec + ?Sized>
    Tagged<AddrVec<I, D, A>> for Bound<S>
{
    type Tag = S::AddrVecTag;

    /// Returns the tag for this state when stored at `address`.
    ///
    /// # Panics
    ///
    /// Panics if the kind of state does not match the kind of address, since
    /// listener state is only ever stored at listener addresses and connected
    /// state at connected addresses.
    fn tag(&self, address: &AddrVec<I, D, A>) -> Self::Tag {
        match self {
            Bound::Listen(listener_state) => match address {
                AddrVec::Listen(addr) => S::listener_tag(addr.info(), listener_state),
                AddrVec::Conn(_) => unreachable!("found listen state for conn addr"),
            },
            Bound::Conn(conn_state) => match address {
                AddrVec::Conn(ConnAddr { device, ip: _ }) => {
                    S::connected_tag(device.is_some(), conn_state)
                }
                AddrVec::Listen(_) => unreachable!("found conn state for listen addr"),
            },
        }
    }
}
impl<I: Ip, D: device::Id, A: SocketMapAddrSpec> IterShadows for AddrVec<I, D, A> {
    type IterShadows = AddrVecIter<I, D, A>;

    /// Returns an iterator over the addresses that `self` shadows.
    ///
    /// The iterator is built from an [`AddrVecIter`] for this address, with
    /// or without a device as appropriate, after discarding its first item.
    ///
    /// # Panics
    ///
    /// Panics if the first item produced by the [`AddrVecIter`] is not equal
    /// to `self`.
    fn iter_shadows(&self) -> Self::IterShadows {
        let (socket_ip_addr, device) = match self.clone() {
            AddrVec::Listen(ListenerAddr { ip, device }) => (ip.into(), device),
            AddrVec::Conn(ConnAddr { ip, device }) => (ip.into(), device),
        };
        let mut shadows = if let Some(device) = device {
            AddrVecIter::with_device(socket_ip_addr, device)
        } else {
            AddrVecIter::without_device(socket_ip_addr)
        };
        // Skip the first element, which is always `*self`.
        let first = shadows.next();
        assert_eq!(first.as_ref(), Some(self));
        shadows
    }
}
/// The kind of address a socket is bound to, as recorded in a
/// [`SocketAddrTypeTag`].
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub(crate) enum SocketAddrType {
/// A listener address without an IP address.
AnyListener,
/// A listener address with an IP address.
SpecificListener,
/// A connected address.
Connected,
}
/// A tag describing a socket address entry: whether it has a device, what
/// kind of address it is, and a caller-supplied sharing value.
///
/// Constructible via `From` from a `(&ListenerAddr, S)` or `(&ConnAddr, S)`
/// pair.
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub(crate) struct SocketAddrTypeTag<S> {
/// True if the address has a device.
pub(crate) has_device: bool,
/// The kind of address.
pub(crate) addr_type: SocketAddrType,
/// The sharing value supplied when the tag was constructed.
pub(crate) sharing: S,
}
/// Builds the tag for a listener address paired with a sharing value.
impl<'a, A: IpAddress, D, LI, S> From<(&'a ListenerAddr<ListenerIpAddr<A, LI>, D>, S)>
    for SocketAddrTypeTag<S>
{
    fn from((addr, sharing): (&'a ListenerAddr<ListenerIpAddr<A, LI>, D>, S)) -> Self {
        let ListenerAddr { ip: ListenerIpAddr { addr: ip_addr, identifier: _ }, device } = addr;
        // A listener with an IP address is "specific"; without one it matches
        // any address.
        let addr_type = match ip_addr.is_some() {
            true => SocketAddrType::SpecificListener,
            false => SocketAddrType::AnyListener,
        };
        let has_device = device.is_some();
        SocketAddrTypeTag { has_device, addr_type, sharing }
    }
}
/// Builds the tag for a connected address paired with a sharing value.
impl<'a, A, D, S> From<(&'a ConnAddr<A, D>, S)> for SocketAddrTypeTag<S> {
    fn from((addr, sharing): (&'a ConnAddr<A, D>, S)) -> Self {
        let ConnAddr { ip: _, device } = addr;
        let has_device = device.is_some();
        // Connected addresses always carry the `Connected` address type.
        SocketAddrTypeTag { has_device, addr_type: SocketAddrType::Connected, sharing }
    }
}
/// The result of attempting to remove a socket from a collection of sockets.
///
/// Returned by [`SocketMapAddrStateSpec::remove_by_id`].
pub(crate) enum RemoveResult {
/// The value was removed successfully.
Success,
/// The value is the last value in the collection so the entire collection
/// should be removed.
IsLast,
}
/// The identifier of a socket in a [`BoundSocketMap`]: either a listener ID
/// or a connection ID, as defined by the [`SocketMapStateSpec`].
#[derive(Derivative)]
#[derivative(Clone(bound = "S::ListenerId: Clone, S::ConnId: Clone"), Debug(bound = ""))]
pub(crate) enum SocketId<S: SocketMapStateSpec> {
/// The ID of a listening socket.
Listener(S::ListenerId),
/// The ID of a connected socket.
Connection(S::ConnId),
}
/// A map from socket addresses to sockets.
///
/// The types of keys and IDs are determined by the [`SocketMapStateSpec`]
/// parameter. Each listener and connected socket stores additional state.
/// Listener and connected sockets are keyed independently, but share the same
/// address vector space. Conflicts are detected on attempted insertion of new
/// sockets.
///
/// Listener addresses map to listener-address-specific state, and likewise
/// with connected addresses. Depending on protocol (determined by the
/// `SocketMapStateSpec` parameter), these address states can hold one or more
/// socket identifiers (e.g. UDP sockets with `SO_REUSEPORT` set can share an
/// address).
#[derive(Derivative)]
#[derivative(Default(bound = ""))]
pub struct BoundSocketMap<I: Ip, D: device::Id, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
// The underlying map from address vectors to per-address state.
addr_to_state: SocketMap<AddrVec<I, D, A>, Bound<S>>,
}
impl<I: Ip, D: device::Id, A: SocketMapAddrSpec, S: SocketMapStateSpec> BoundSocketMap<I, D, A, S> {
    /// Returns the length reported by the underlying [`SocketMap`].
    ///
    /// Only available in tests.
    #[cfg(test)]
    pub(crate) fn len(&self) -> usize {
        let Self { addr_to_state } = self;
        addr_to_state.len()
    }
}
/// Uninstantiable tag type for denoting listening sockets.
///
/// Used as the `SocketType` parameter of [`Sockets`]; its
/// [`ConvertSocketMapState`] impl selects the listener variants of
/// [`AddrVec`], [`Bound`] and [`SocketId`].
pub(crate) enum Listener {}
/// Uninstantiable tag type for denoting connected sockets.
///
/// Used as the `SocketType` parameter of [`Sockets`]; its
/// [`ConvertSocketMapState`] impl selects the connection variants of
/// [`AddrVec`], [`Bound`] and [`SocketId`].
pub(crate) enum Connection {}
/// View struct over one type of sockets in a [`BoundSocketMap`].
///
/// `AddrToStateMap` is a shared or mutable reference to the underlying
/// [`SocketMap`]; `SocketType` is a tag ([`Listener`] or [`Connection`])
/// selecting which kind of socket the view operates on.
pub(crate) struct Sockets<AddrToStateMap, SocketType>(AddrToStateMap, PhantomData<SocketType>);
impl<
        'a,
        I: Ip,
        D: device::Id,
        SocketType: ConvertSocketMapState<I, D, A, S>,
        A: SocketMapAddrSpec,
        S: SocketMapStateSpec,
    > Sockets<&'a SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
where
    S: SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
{
    /// Returns the state at an address, if there is any.
    ///
    /// # Panics
    ///
    /// Panics if the state stored at `addr` is not of the kind selected by
    /// `SocketType`.
    pub(crate) fn get_by_addr(self, addr: &SocketType::Addr) -> Option<&'a SocketType::AddrState> {
        let Self(addr_to_state, _marker) = self;
        let key = SocketType::to_addr_vec(addr);
        let bound = addr_to_state.get(&key)?;
        match SocketType::from_bound_ref(bound) {
            Some(addr_state) => Some(addr_state),
            None => unreachable!("found {:?} for address {:?}", bound, addr),
        }
    }

    /// Returns `Ok(())` if a socket could be inserted, otherwise an error.
    ///
    /// Goes through a dry run of inserting a socket at the given address and
    /// with the given sharing state, returning `Ok(())` if the insertion would
    /// succeed, otherwise the error that would be returned.
    ///
    /// When there is already state at `addr`, that state decides (an
    /// incompatibility is reported as [`InsertError::Exists`]); otherwise the
    /// conflict policy of `S` is consulted.
    pub(crate) fn could_insert(
        self,
        addr: &SocketType::Addr,
        sharing: &SocketType::SharingState,
    ) -> Result<(), InsertError> {
        // The map reference is `Copy`, so `self` stays usable for the lookup.
        let Self(addr_to_state, _) = self;
        if let Some(existing) = self.get_by_addr(addr) {
            return existing.could_insert(sharing).map_err(|IncompatibleError| InsertError::Exists);
        }
        S::check_insert_conflicts(sharing, addr, addr_to_state)
    }
}
/// A handle to a single socket's presence in a [`BoundSocketMap`].
///
/// Pairs the socket's ID with the occupied map entry for the address at which
/// the socket is stored.
#[derive(Derivative)]
#[derivative(Debug(bound = ""))]
pub(crate) struct SocketStateEntry<
'a,
I: Ip,
D: device::Id,
A: SocketMapAddrSpec,
S: SocketMapStateSpec,
SocketType,
> {
// The ID of the socket this entry refers to.
id: SocketId<S>,
// The occupied entry in the underlying map for the socket's address.
addr_entry: SocketMapOccupiedEntry<'a, AddrVec<I, D, A>, Bound<S>>,
// Records which kind of socket (listener or connection) this is.
_marker: PhantomData<SocketType>,
}
impl<
'a,
I: Ip,
D: device::Id,
SocketType: ConvertSocketMapState<I, D, A, S>,
A: SocketMapAddrSpec,
S: SocketMapStateSpec
+ SocketMapConflictPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
> Sockets<&'a mut SocketMap<AddrVec<I, D, A>, Bound<S>>, SocketType>
where
SocketType::SharingState: Clone,
SocketType::Id: Clone,
{
/// Attempts to insert the socket `id` at `socket_addr` with sharing state
/// `tag_state`.
///
/// This is [`Self::try_insert_with`] with an ID that is already known. On
/// failure the error is returned together with the unconsumed sharing
/// state.
pub(crate) fn try_insert(
self,
socket_addr: SocketType::Addr,
tag_state: SocketType::SharingState,
id: SocketType::Id,
) -> Result<SocketStateEntry<'a, I, D, A, S, SocketType>, (InsertError, SocketType::SharingState)>
{
self.try_insert_with(socket_addr, tag_state, |_addr, _sharing| (id, ()))
.map(|(entry, ())| entry)
}
/// Attempts to insert a socket at `socket_addr` with sharing state
/// `tag_state`, creating the socket's ID only once insertion is known to
/// succeed.
///
/// `make_id` is called at most once, with the address and sharing state,
/// and returns the ID to insert along with an arbitrary value `R` that is
/// passed back to the caller next to the resulting entry.
///
/// # Errors
///
/// Returns the error together with the unconsumed sharing state if the
/// conflict policy of `S` rejects the insertion, or if there is existing
/// state at the address that is incompatible with `tag_state` (reported as
/// [`InsertError::Exists`]). `make_id` is not called in either case.
///
/// # Panics
///
/// Panics if the state already stored at the address is not of the kind
/// selected by `SocketType`.
pub(crate) fn try_insert_with<R>(
self,
socket_addr: SocketType::Addr,
tag_state: SocketType::SharingState,
make_id: impl FnOnce(SocketType::Addr, SocketType::SharingState) -> (SocketType::Id, R),
) -> Result<
(SocketStateEntry<'a, I, D, A, S, SocketType>, R),
(InsertError, SocketType::SharingState),
> {
let Self(addr_to_state, _) = self;
// Consult the conflict policy before touching the map.
match S::check_insert_conflicts(&tag_state, &socket_addr, &addr_to_state) {
Err(e) => return Err((e, tag_state)),
Ok(()) => (),
};
let addr = SocketType::to_addr_vec(&socket_addr);
match addr_to_state.entry(addr) {
Entry::Occupied(mut o) => {
// There is already state at this address: the new socket can only
// be added if the existing state hands out an inserter.
let (id, ret) = o.map_mut(|bound| {
let bound = match SocketType::from_bound_mut(bound) {
Some(bound) => bound,
None => unreachable!("found {:?} for address {:?}", bound, socket_addr),
};
match <SocketType::AddrState as SocketMapAddrStateSpec>::try_get_inserter(
bound, &tag_state,
) {
Ok(v) => {
let (id, ret) = make_id(socket_addr, tag_state);
v.insert(id.clone());
Ok((SocketType::to_socket_id(id), ret))
}
Err(IncompatibleError) => Err((InsertError::Exists, tag_state)),
}
})?;
Ok((SocketStateEntry { id, addr_entry: o, _marker: Default::default() }, ret))
}
Entry::Vacant(v) => {
// Nothing at this address yet: create fresh state holding the ID.
let (id, ret) = make_id(socket_addr, tag_state.clone());
let addr_entry = v.insert(SocketType::to_bound(SocketType::AddrState::new(
&tag_state,
id.clone(),
)));
let id = SocketType::to_socket_id(id);
Ok((SocketStateEntry { id, addr_entry, _marker: Default::default() }, ret))
}
}
}
/// Returns the entry for the socket `id` at `addr`.
///
/// Returns `None` if there is no state at `addr`, if the state there is
/// not of the kind selected by `SocketType`, or if that state does not
/// contain `id`.
pub(crate) fn entry(
self,
id: &SocketType::Id,
addr: &SocketType::Addr,
) -> Option<SocketStateEntry<'a, I, D, A, S, SocketType>> {
let Self(addr_to_state, _) = self;
let addr_entry = match addr_to_state.entry(SocketType::to_addr_vec(addr)) {
Entry::Vacant(_) => return None,
Entry::Occupied(o) => o,
};
let state = SocketType::from_bound_ref(addr_entry.get())?;
state.contains_id(id).then_some(SocketStateEntry {
id: SocketType::to_socket_id(id.clone()),
addr_entry,
_marker: PhantomData::default(),
})
}
/// Removes the socket `id` from `addr`.
///
/// Returns `Err(NotFoundError)` if [`Self::entry`] finds no such socket.
pub(crate) fn remove(
self,
id: &SocketType::Id,
addr: &SocketType::Addr,
) -> Result<(), NotFoundError> {
self.entry(id, addr)
.map(|entry| {
entry.remove();
})
.ok_or(NotFoundError)
}
}
/// Error returned when the sharing state of an existing socket cannot be
/// updated.
#[derive(Debug)]
pub(crate) struct UpdateSharingError;
impl<
        'a,
        I: Ip,
        D: device::Id,
        SocketType: ConvertSocketMapState<I, D, A, S>,
        A: SocketMapAddrSpec,
        S: SocketMapStateSpec,
    > SocketStateEntry<'a, I, D, A, S, SocketType>
where
    SocketType::Id: Clone,
{
    /// Returns the address at which the socket is stored.
    pub(crate) fn get_addr(&self) -> &SocketType::Addr {
        SocketType::from_addr_vec_ref(self.addr_entry.key())
    }

    /// Returns the ID of the socket this entry refers to.
    #[cfg(test)]
    pub(crate) fn id(&self) -> &SocketType::Id {
        SocketType::from_socket_id_ref(&self.id)
    }

    /// Attempts to move the whole address state to `new_addr`.
    ///
    /// The state is first removed from its current address. It is inserted at
    /// `new_addr` only if that address is vacant and reports no descendants;
    /// otherwise the state is restored at its original address and
    /// `ExistsError` is returned along with an entry for the original
    /// address.
    ///
    /// # Panics
    ///
    /// Panics if the original address is found occupied while restoring.
    pub(crate) fn try_update_addr(
        self,
        new_addr: SocketType::Addr,
    ) -> Result<Self, (ExistsError, Self)> {
        let Self { id, addr_entry, _marker } = self;

        let target = SocketType::to_addr_vec(&new_addr);
        let original = addr_entry.key().clone();
        let (detached_state, map) = addr_entry.remove_from_map();

        let map = match map.entry(target) {
            Entry::Vacant(vacant) => {
                if vacant.descendant_counts().len() == 0 {
                    let addr_entry = vacant.insert(detached_state);
                    return Ok(SocketStateEntry { id, addr_entry, _marker });
                }
                vacant.into_map()
            }
            Entry::Occupied(occupied) => occupied.into_map(),
        };

        // Restore the old state before returning an error.
        let addr_entry = match map.entry(original) {
            Entry::Vacant(vacant) => vacant.insert(detached_state),
            Entry::Occupied(_) => unreachable!("just-removed-from entry is occupied"),
        };
        Err((ExistsError, SocketStateEntry { id, addr_entry, _marker }))
    }

    /// Removes the socket from the address state.
    ///
    /// If the address state reports that this was its last socket, the whole
    /// address entry is removed from the map as well.
    ///
    /// # Panics
    ///
    /// Panics if the state stored at the address is not of the kind selected
    /// by `SocketType`.
    pub(crate) fn remove(self) {
        let Self { id, mut addr_entry, _marker } = self;
        // Only needed for the diagnostic message below.
        let addr = addr_entry.key().clone();

        let outcome = addr_entry.map_mut(|bound| {
            let addr_state = match SocketType::from_bound_mut(bound) {
                Some(addr_state) => addr_state,
                None => unreachable!("found {:?} for address {:?}", bound, addr),
            };
            addr_state.remove_by_id(SocketType::from_socket_id_ref(&id).clone())
        });

        match outcome {
            RemoveResult::IsLast => {
                let _: Bound<S> = addr_entry.remove();
            }
            RemoveResult::Success => (),
        }
    }

    /// Attempts to change the socket's sharing state from `old_sharing_state`
    /// to `new_sharing_state`.
    ///
    /// The update policy of `S` is consulted first; if it allows the change,
    /// the address state is asked to apply it.
    ///
    /// # Errors
    ///
    /// Returns `UpdateSharingError` if the policy rejects the change or the
    /// address state reports the new sharing state as incompatible.
    ///
    /// # Panics
    ///
    /// Panics if the state stored at the address is not of the kind selected
    /// by `SocketType`.
    pub(crate) fn try_update_sharing(
        &mut self,
        old_sharing_state: &SocketType::SharingState,
        new_sharing_state: SocketType::SharingState,
    ) -> Result<(), UpdateSharingError>
    where
        SocketType::AddrState: SocketMapAddrStateUpdateSharingSpec,
        S: SocketMapUpdateSharingPolicy<SocketType::Addr, SocketType::SharingState, I, D, A>,
    {
        let Self { id, addr_entry, _marker } = self;

        let current_addr = SocketType::from_addr_vec_ref(addr_entry.key());
        S::allows_sharing_update(
            addr_entry.get_map(),
            current_addr,
            old_sharing_state,
            &new_sharing_state,
        )?;

        let applied = addr_entry.map_mut(|bound| {
            let addr_state = match SocketType::from_bound_mut(bound) {
                Some(addr_state) => addr_state,
                // We shouldn't ever be storing listener state in a bound
                // address, or bound state in a listener address. Doing so means
                // we've got a serious bug.
                None => unreachable!("found invalid state {:?}", bound),
            };
            addr_state
                .try_update_sharing(SocketType::from_socket_id_ref(id).clone(), &new_sharing_state)
        });
        applied.map_err(|IncompatibleError| UpdateSharingError)
    }
}
impl<I: Ip, D: device::Id, A: SocketMapAddrSpec, S> BoundSocketMap<I, D, A, S>
where
AddrVec<I, D, A>: IterShadows,
S: SocketMapStateSpec,
{
pub(crate) fn listeners(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
where
S: SocketMapConflictPolicy<
ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
<S as SocketMapStateSpec>::ListenerSharingState,
I,
D,
A,
>,
S::ListenerAddrState:
SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
{
let Self { addr_to_state } = self;
Sockets(addr_to_state, Default::default())
}
pub(crate) fn listeners_mut(
&mut self,
) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Listener>
where
S: SocketMapConflictPolicy<
ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>,
<S as SocketMapStateSpec>::ListenerSharingState,
I,
D,
A,
>,
S::ListenerAddrState:
SocketMapAddrStateSpec<Id = S::ListenerId, SharingState = S::ListenerSharingState>,
{
let Self { addr_to_state } = self;
Sockets(addr_to_state, Default::default())
}
pub(crate) fn conns(&self) -> Sockets<&SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
where
S: SocketMapConflictPolicy<
ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
<S as SocketMapStateSpec>::ConnSharingState,
I,
D,
A,
>,
S::ConnAddrState:
SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
{
let Self { addr_to_state } = self;
Sockets(addr_to_state, Default::default())
}
pub(crate) fn conns_mut(
&mut self,
) -> Sockets<&mut SocketMap<AddrVec<I, D, A>, Bound<S>>, Connection>
where
S: SocketMapConflictPolicy<
ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>,
<S as SocketMapStateSpec>::ConnSharingState,
I,
D,
A,
>,
S::ConnAddrState:
SocketMapAddrStateSpec<Id = S::ConnId, SharingState = S::ConnSharingState>,
{
let Self { addr_to_state } = self;
Sockets(addr_to_state, Default::default())
}
#[cfg(test)]
pub(crate) fn iter_addrs(&self) -> impl Iterator<Item = &AddrVec<I, D, A>> {
let Self { addr_to_state } = self;
addr_to_state.iter().map(|(a, _v): (_, &Bound<S>)| a)
}
pub(crate) fn get_shadower_counts(&self, addr: &AddrVec<I, D, A>) -> usize {
let Self { addr_to_state } = self;
addr_to_state.descendant_counts(&addr).map(|(_sharing, size)| size.get()).sum()
}
}
/// An error encountered while (dry-run or actually) inserting a socket into a
/// [`BoundSocketMap`].
///
/// Apart from [`InsertError::Exists`], the variants are produced by
/// [`SocketMapConflictPolicy::check_insert_conflicts`] implementations.
#[derive(Debug, Eq, PartialEq)]
pub enum InsertError {
// NOTE(review): presumably an address shadowed by the new one already
// exists; confirm against the conflict policy implementations.
ShadowAddrExists,
/// The state already present at the address is incompatible with the new
/// socket's sharing state.
Exists,
// NOTE(review): presumably an existing address shadows the new one;
// confirm against the conflict policy implementations.
ShadowerExists,
// NOTE(review): presumably a conflict with a socket at a different,
// related address; confirm against the conflict policy implementations.
IndirectConflict,
}
/// Helper trait for converting between [`AddrVec`] and [`Bound`] and their
/// variants.
///
/// Implemented for the [`Listener`] and [`Connection`] tag types, each of
/// which selects the corresponding variants.
pub(crate) trait ConvertSocketMapState<I: Ip, D, A: SocketMapAddrSpec, S: SocketMapStateSpec> {
/// The socket ID type for this kind of socket.
type Id;
/// The sharing state type for this kind of socket.
type SharingState;
/// The address type for this kind of socket.
type Addr: Debug;
/// The per-address state type for this kind of socket.
type AddrState: SocketMapAddrStateSpec<Id = Self::Id, SharingState = Self::SharingState>;
/// Wraps (a clone of) `addr` in the matching [`AddrVec`] variant.
fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A>;
/// Extracts the address from the matching [`AddrVec`] variant.
///
/// The implementations in this file panic on the other variant.
fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr;
/// Extracts the address state from the matching [`Bound`] variant, or
/// returns `None` for the other variant.
fn from_bound_ref(bound: &Bound<S>) -> Option<&Self::AddrState>;
/// Mutable counterpart of [`Self::from_bound_ref`].
fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut Self::AddrState>;
/// Wraps `state` in the matching [`Bound`] variant.
fn to_bound(state: Self::AddrState) -> Bound<S>;
/// Wraps `id` in the matching [`SocketId`] variant.
fn to_socket_id(id: Self::Id) -> SocketId<S>;
/// Extracts the ID from the matching [`SocketId`] variant.
///
/// The implementations in this file panic on the other variant.
fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id;
}
/// Conversions selecting the listener variants of [`AddrVec`], [`Bound`] and
/// [`SocketId`].
impl<I: Ip, D: device::Id, A: SocketMapAddrSpec, S: SocketMapStateSpec>
    ConvertSocketMapState<I, D, A, S> for Listener
{
    type Id = S::ListenerId;
    type SharingState = S::ListenerSharingState;
    type Addr = ListenerAddr<ListenerIpAddr<I::Addr, A::LocalIdentifier>, D>;
    type AddrState = S::ListenerAddrState;

    /// Wraps a clone of `addr` in [`AddrVec::Listen`].
    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
        let listener_addr = addr.clone();
        AddrVec::Listen(listener_addr)
    }

    /// Returns the listener address; panics on a connected address.
    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
        match addr {
            AddrVec::Conn(c) => unreachable!("conn addr for listener: {c:?}"),
            AddrVec::Listen(listener_addr) => listener_addr,
        }
    }

    /// Returns the listener state, or `None` for connected state.
    fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ListenerAddrState> {
        match bound {
            Bound::Conn(_) => None,
            Bound::Listen(listener_state) => Some(listener_state),
        }
    }

    /// Returns the listener state mutably, or `None` for connected state.
    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ListenerAddrState> {
        match bound {
            Bound::Conn(_) => None,
            Bound::Listen(listener_state) => Some(listener_state),
        }
    }

    /// Wraps `state` in [`Bound::Listen`].
    fn to_bound(state: S::ListenerAddrState) -> Bound<S> {
        Bound::Listen(state)
    }

    /// Returns the listener ID; panics on a connection ID.
    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
        match id {
            SocketId::Connection(_) => unreachable!("connection ID for listener"),
            SocketId::Listener(listener_id) => listener_id,
        }
    }

    /// Wraps `id` in [`SocketId::Listener`].
    fn to_socket_id(id: Self::Id) -> SocketId<S> {
        SocketId::Listener(id)
    }
}
/// Conversions selecting the connection variants of [`AddrVec`], [`Bound`]
/// and [`SocketId`].
impl<I: Ip, D: device::Id, A: SocketMapAddrSpec, S: SocketMapStateSpec>
    ConvertSocketMapState<I, D, A, S> for Connection
{
    type Id = S::ConnId;
    type SharingState = S::ConnSharingState;
    type Addr = ConnAddr<ConnIpAddr<I::Addr, A::LocalIdentifier, A::RemoteIdentifier>, D>;
    type AddrState = S::ConnAddrState;

    /// Wraps a clone of `addr` in [`AddrVec::Conn`].
    fn to_addr_vec(addr: &Self::Addr) -> AddrVec<I, D, A> {
        let conn_addr = addr.clone();
        AddrVec::Conn(conn_addr)
    }

    /// Returns the connected address; panics on a listener address.
    fn from_addr_vec_ref(addr: &AddrVec<I, D, A>) -> &Self::Addr {
        match addr {
            AddrVec::Listen(l) => unreachable!("listener addr for conn: {l:?}"),
            AddrVec::Conn(conn_addr) => conn_addr,
        }
    }

    /// Returns the connected state, or `None` for listener state.
    fn from_bound_ref(bound: &Bound<S>) -> Option<&S::ConnAddrState> {
        match bound {
            Bound::Conn(conn_state) => Some(conn_state),
            Bound::Listen(_) => None,
        }
    }

    /// Returns the connected state mutably, or `None` for listener state.
    fn from_bound_mut(bound: &mut Bound<S>) -> Option<&mut S::ConnAddrState> {
        match bound {
            Bound::Conn(conn_state) => Some(conn_state),
            Bound::Listen(_) => None,
        }
    }

    /// Wraps `state` in [`Bound::Conn`].
    fn to_bound(state: S::ConnAddrState) -> Bound<S> {
        Bound::Conn(state)
    }

    /// Returns the connection ID; panics on a listener ID.
    fn from_socket_id_ref(id: &SocketId<S>) -> &Self::Id {
        match id {
            SocketId::Listener(_) => unreachable!("listener ID for connection"),
            SocketId::Connection(conn_id) => conn_id,
        }
    }

    /// Wraps `id` in [`SocketId::Connection`].
    fn to_socket_id(id: Self::Id) -> SocketId<S> {
        SocketId::Connection(id)
    }
}
#[cfg(test)]
mod tests {
use alloc::{collections::HashSet, vec, vec::Vec};
use assert_matches::assert_matches;
use const_unwrap::const_unwrap_option;
use net_declare::{net_ip_v4, net_ip_v6};
use net_types::ip::{Ipv4Addr, Ipv6, Ipv6Addr};
use test_case::test_case;
use crate::{
device::testutil::{FakeDeviceId, FakeWeakDeviceId},
testutil::set_logger_for_test,
};
use super::*;
#[test_case(net_ip_v4!("8.8.8.8"))]
#[test_case(net_ip_v4!("127.0.0.1"))]
#[test_case(net_ip_v4!("127.0.8.9"))]
#[test_case(net_ip_v4!("224.1.2.3"))]
fn must_never_have_zone_ipv4(addr: Ipv4Addr) {
// No IPv4 addresses are allowed to have a zone.
let addr = SpecifiedAddr::new(addr).unwrap();
assert_eq!(must_have_zone(&addr), false);
}
#[test_case(net_ip_v6!("1::2:3"), false)]
#[test_case(net_ip_v6!("::1"), false; "localhost")]
#[test_case(net_ip_v6!("1::"), false)]
#[test_case(net_ip_v6!("ff03:1:2:3::1"), false)]
#[test_case(net_ip_v6!("ff02:1:2:3::1"), true)]
#[test_case(Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.get(), true)]
#[test_case(net_ip_v6!("fe80::1"), true)]
fn must_have_zone_ipv6(addr: Ipv6Addr, must_have: bool) {
    // Of the cases above, only the link-local unicast and link-local
    // multicast addresses are expected to require a zone.
    let specified = SpecifiedAddr::new(addr).unwrap();
    let requires_zone = must_have_zone(&specified);
    assert_eq!(requires_zone, must_have);
}
/// Checks that `try_into_null_zoned` yields nothing for the loopback address
/// and a unit-zoned address for a link-local multicast address.
#[test]
fn try_into_null_zoned_ipv6() {
    const ZONE: u32 = 5;

    let loopback = try_into_null_zoned(&Ipv6::LOOPBACK_ADDRESS);
    assert_eq!(loopback, None);

    let zoned = Ipv6::ALL_NODES_LINK_LOCAL_MULTICAST_ADDRESS.into_specified();
    // Swap the unit zone for a concrete one so the result can be compared
    // with an explicitly constructed `AddrAndZone`.
    let actual = try_into_null_zoned(&zoned).map(|null_zoned| null_zoned.map_zone(|()| ZONE));
    let expected = Some(AddrAndZone::new(zoned, ZONE).unwrap());
    assert_eq!(actual, expected);
}
/// Uninhabited marker type on which the fake [`SocketMapStateSpec`] and the
/// fake socket-map policies used by these tests are implemented.
enum FakeSpec {}
/// ID of a fake listener socket, wrapping a plain `usize`.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
struct Listener(usize);
/// A `char` tag paired with the list of IDs stored under it.
#[derive(PartialEq, Eq, Debug)]
struct Multiple<T>(char, Vec<T>);
impl<T> Multiple<T> {
    /// Returns the `char` tag held in the first field.
    fn tag(&self) -> char {
        self.0
    }
}
/// ID of a fake connected socket, wrapping a plain `usize`.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
struct Conn(usize);
/// Uninhabited marker type carrying the fake [`SocketMapAddrSpec`] used by
/// these tests.
enum FakeAddrSpec {}
impl SocketMapAddrSpec for FakeAddrSpec {
// Local identifiers are non-zero 16-bit values.
type LocalIdentifier = NonZeroU16;
// Remote identifiers carry no information: the unit type.
type RemoteIdentifier = ();
}
/// Fake state spec: listeners and connections both keep a [`Multiple`] per
/// address, and `char` serves as both the sharing state and the address tag.
impl SocketMapStateSpec for FakeSpec {
    type AddrVecTag = char;

    type ListenerId = Listener;
    type ListenerSharingState = char;
    type ListenerAddrState = Multiple<Listener>;

    type ConnId = Conn;
    type ConnSharingState = char;
    type ConnAddrState = Multiple<Conn>;

    /// The tag of a listener address is the `char` stored in its state.
    fn listener_tag(_: ListenerAddrInfo, state: &Self::ListenerAddrState) -> Self::AddrVecTag {
        state.0
    }

    /// The tag of a connection address is the `char` stored in its state.
    fn connected_tag(_has_device: bool, state: &Self::ConnAddrState) -> Self::AddrVecTag {
        state.0
    }
}
/// The socket map under test: IPv4, fake weak device IDs, and the fake
/// address and state specs defined in this module.
type FakeBoundSocketMap =
BoundSocketMap<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec, FakeSpec>;
/// Hands out stateless socket IDs, each one distinct from the last.
///
/// Every call to [`FakeSocketIdGen::next`] yields an ID not returned before.
#[derive(Default)]
struct FakeSocketIdGen {
    /// The value the next call to `next` will return.
    next_id: usize,
}
impl FakeSocketIdGen {
    /// Returns the current counter value and advances the counter by one.
    fn next(&mut self) -> usize {
        let current = self.next_id;
        self.next_id = current + 1;
        current
    }
}
impl<I: Eq> SocketMapAddrStateSpec for Multiple<I> {
type Id = I;
type SharingState = char;
type Inserter<'a> = &'a mut Vec<I> where I: 'a;
fn new(new_sharing_state: &char, id: I) -> Self {
Self(*new_sharing_state, vec![id])
}
fn contains_id(&self, id: &Self::Id) -> bool {
self.1.contains(id)
}
fn try_get_inserter<'a, 'b>(
&'b mut self,
new_state: &'a char,
) -> Result<Self::Inserter<'b>, IncompatibleError> {
let Self(c, v) = self;
(new_state == c).then_some(v).ok_or(IncompatibleError)
}
fn could_insert(
&self,
new_sharing_state: &Self::SharingState,
) -> Result<(), IncompatibleError> {
let Self(c, _) = self;
(new_sharing_state == c).then_some(()).ok_or(IncompatibleError)
}
fn remove_by_id(&mut self, id: I) -> RemoveResult {
let Self(_, v) = self;
let index = v.iter().position(|i| i == &id).expect("did not find id");
let _: I = v.swap_remove(index);
if v.is_empty() {
RemoveResult::IsLast
} else {
RemoveResult::Success
}
}
}
/// Conflict policy for the fake spec.
impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
    SocketMapConflictPolicy<A, char, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
    for FakeSpec
{
    /// Decides whether a socket with sharing state `new_state` may be
    /// inserted at `addr`.
    ///
    /// The checks run in this order:
    /// 1. any address yielded by `iter_shadows` for `addr` is occupied:
    ///    `ShadowAddrExists`;
    /// 2. `addr` itself is occupied with a different sharing state: `Exists`;
    /// 3. `descendant_counts` for `addr` is non-empty: `ShadowerExists`.
    fn check_insert_conflicts(
        new_state: &char,
        addr: &A,
        socketmap: &SocketMap<
            AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
            Bound<FakeSpec>,
        >,
    ) -> Result<(), InsertError> {
        let dest: AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec> =
            addr.clone().into();

        let shadow_occupied =
            dest.iter_shadows().any(|shadow| socketmap.get(&shadow).is_some());
        if shadow_occupied {
            return Err(InsertError::ShadowAddrExists);
        }

        // Require that all sockets inserted in a `Multiple` entry have the
        // same sharing state.
        if let Some(
            Bound::Listen(Multiple(sharing, _)) | Bound::Conn(Multiple(sharing, _)),
        ) = socketmap.get(&dest)
        {
            if sharing != new_state {
                return Err(InsertError::Exists);
            }
        }

        match socketmap.descendant_counts(&dest).len() {
            0 => Ok(()),
            _ => Err(InsertError::ShadowerExists),
        }
    }
}
impl<I: Eq> SocketMapAddrStateUpdateSharingSpec for Multiple<I> {
    /// Switches the stored sharing state to `new_sharing_state`.
    ///
    /// Succeeds without changes when the state is already
    /// `new_sharing_state`. Otherwise it succeeds only when exactly one ID is
    /// stored, and panics if that ID is not `id`.
    fn try_update_sharing(
        &mut self,
        id: Self::Id,
        new_sharing_state: &Self::SharingState,
    ) -> Result<(), IncompatibleError> {
        let Self(sharing, v) = self;
        let already_current = *sharing == *new_sharing_state;
        if !already_current {
            // All sockets in a `Multiple` entry must keep the same sharing
            // state, so the shared state can only be changed while exactly
            // one socket is present.
            if v.len() != 1 {
                return Err(IncompatibleError);
            }
            assert!(v.contains(&id));
            *sharing = *new_sharing_state;
        }
        Ok(())
    }
}
/// Sharing-update policy for the fake spec: every update is permitted.
impl<A: Into<AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>> + Clone>
SocketMapUpdateSharingPolicy<A, char, Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>
for FakeSpec
{
/// Ignores all of its arguments and always returns `Ok(())`.
fn allows_sharing_update(
_socketmap: &SocketMap<
AddrVec<Ipv4, FakeWeakDeviceId<FakeDeviceId>, FakeAddrSpec>,
Bound<Self>,
>,
_addr: &A,
_old_sharing: &char,
_new_sharing_state: &char,
) -> Result<(), UpdateSharingError> {
Ok(())
}
}
/// Listener address shared by the tests: IP 1.2.3.4, local identifier 1, and
/// no device.
const LISTENER_ADDR: ListenerAddr<
ListenerIpAddr<Ipv4Addr, NonZeroU16>,
FakeWeakDeviceId<FakeDeviceId>,
> = ListenerAddr {
ip: ListenerIpAddr {
// SAFETY: NOTE(review): assumes 1.2.3.4 satisfies whatever
// `SocketIpAddr::new_unchecked` requires of its argument; confirm against
// that function's documented contract.
addr: Some(unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("1.2.3.4")) }),
identifier: const_unwrap_option(NonZeroU16::new(1)),
},
device: None,
};
/// Connection address shared by the tests: local 5.6.7.8 with identifier 1,
/// remote 8.7.6.5 with the unit identifier, and no device.
const CONN_ADDR: ConnAddr<
ConnIpAddr<Ipv4Addr, NonZeroU16, ()>,
FakeWeakDeviceId<FakeDeviceId>,
> = ConnAddr {
ip: ConnIpAddr {
local: (
// SAFETY: NOTE(review): assumes 5.6.7.8 satisfies whatever
// `SocketIpAddr::new_unchecked` requires of its argument; confirm against
// that function's documented contract.
unsafe { SocketIpAddr::new_unchecked(net_ip_v4!("5.6.7.8")) },
const_unwrap_option(NonZeroU16::new(1)),
),
// SAFETY: NOTE(review): same assumption as above, for 8.7.6.5.
remote: unsafe { (SocketIpAddr::new_unchecked(net_ip_v4!("8.7.6.5")), ()) },
},
device: None,
};
/// A listener can be inserted, looked up by address, and removed again.
#[test]
fn bound_insert_get_remove_listener() {
    set_logger_for_test();
    let mut socket_map = FakeBoundSocketMap::default();
    let mut id_gen = FakeSocketIdGen::default();
    let listener_addr = LISTENER_ADDR;

    // The entry borrows the map mutably, so keep it in its own scope.
    let listener_id = {
        let new_listener = Listener(id_gen.next());
        let inserted =
            socket_map.listeners_mut().try_insert(listener_addr, 'v', new_listener).unwrap();
        assert_eq!(inserted.get_addr(), &listener_addr);
        inserted.id().clone()
    };

    let expected_state = Multiple('v', vec![listener_id]);
    assert_eq!(socket_map.listeners().get_by_addr(&listener_addr), Some(&expected_state));

    assert_eq!(socket_map.listeners_mut().remove(&listener_id, &listener_addr), Ok(()));
    assert_eq!(socket_map.listeners().get_by_addr(&listener_addr), None);
}
/// A connection can be inserted, looked up by address, and removed again.
#[test]
fn bound_insert_get_remove_conn() {
    set_logger_for_test();
    let mut socket_map = FakeBoundSocketMap::default();
    let mut id_gen = FakeSocketIdGen::default();
    let conn_addr = CONN_ADDR;

    // The entry borrows the map mutably, so keep it in its own scope.
    let conn_id = {
        let new_conn = Conn(id_gen.next());
        let inserted = socket_map.conns_mut().try_insert(conn_addr, 'v', new_conn).unwrap();
        assert_eq!(inserted.get_addr(), &conn_addr);
        inserted.id().clone()
    };

    let expected_state = Multiple('v', vec![conn_id]);
    assert_eq!(socket_map.conns().get_by_addr(&conn_addr), Some(&expected_state));

    assert_eq!(socket_map.conns_mut().remove(&conn_id, &conn_addr), Ok(()));
    assert_eq!(socket_map.conns().get_by_addr(&conn_addr), None);
}
/// `iter_addrs` yields exactly the addresses that were inserted, listeners
/// and connections alike.
#[test]
fn bound_iter_addrs() {
    set_logger_for_test();
    let mut bound = FakeBoundSocketMap::default();
    let mut fake_id_gen = FakeSocketIdGen::default();

    // Builds a device-less listener address from an optional IP and a port.
    let listener_at = |ip: Option<Ipv4Addr>, identifier: u16| ListenerAddr {
        device: None,
        ip: ListenerIpAddr {
            addr: ip.map(|x| SocketIpAddr::new(x).unwrap()),
            identifier: NonZeroU16::new(identifier).unwrap(),
        },
    };
    // Builds a device-less connection address from its local and remote ends.
    let conn_between = |local_ip: Ipv4Addr, local_identifier: u16, remote_ip: Ipv4Addr| {
        ConnAddr {
            ip: ConnIpAddr {
                local: (
                    SocketIpAddr::new(local_ip).unwrap(),
                    NonZeroU16::new(local_identifier).unwrap(),
                ),
                remote: (SocketIpAddr::new(remote_ip).unwrap(), ()),
            },
            device: None,
        }
    };

    let listener_addrs = [
        listener_at(Some(net_ip_v4!("1.1.1.1")), 1),
        listener_at(Some(net_ip_v4!("2.2.2.2")), 2),
        listener_at(Some(net_ip_v4!("1.1.1.1")), 3),
        listener_at(None, 4),
    ];
    let conn_addrs = [
        conn_between(net_ip_v4!("3.3.3.3"), 3, net_ip_v4!("4.4.4.4")),
        conn_between(net_ip_v4!("4.4.4.4"), 3, net_ip_v4!("3.3.3.3")),
    ];

    for &addr in &listener_addrs {
        let id = Listener(fake_id_gen.next());
        let _entry = bound.listeners_mut().try_insert(addr, 'a', id).unwrap();
    }
    for &addr in &conn_addrs {
        let id = Conn(fake_id_gen.next());
        let _entry = bound.conns_mut().try_insert(addr, 'a', id).unwrap();
    }

    let actual_addrs: HashSet<_> = bound.iter_addrs().cloned().collect();
    let expected_addrs: HashSet<_> = listener_addrs
        .into_iter()
        .map(Into::into)
        .chain(conn_addrs.into_iter().map(Into::into))
        .collect();
    assert_eq!(expected_addrs, actual_addrs);
}
/// `try_insert_with` must not invoke its ID-producing callback when the
/// insertion is rejected.
#[test]
fn try_insert_with_callback_not_called_on_error() {
    // TODO(https://fxbug.dev/42076891): remove this test along with
    // try_insert_with.
    set_logger_for_test();
    let mut socket_map = FakeBoundSocketMap::default();
    let occupied_addr = LISTENER_ADDR;

    // Stand-in for the `make_id` callback; it is only supposed to run once
    // success is certain, so running it here is a test failure.
    fn is_never_called<A, B, T>(_: A, _: B) -> (T, ()) {
        panic!("should never be called");
    }

    // The same listener address, but bound to a device.
    let device_bound_addr =
        ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..occupied_addr };
    // A connection whose local end matches the occupied listener address.
    let conn_on_listener_addr = ConnAddr {
        device: None,
        ip: ConnIpAddr {
            local: (occupied_addr.ip.addr.unwrap(), occupied_addr.ip.identifier),
            remote: (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ()),
        },
    };

    // Occupy the listener address so that each insertion below conflicts.
    let _: &Listener =
        socket_map.listeners_mut().try_insert(occupied_addr, 'a', Listener(0)).unwrap().id();

    assert_matches!(
        socket_map.listeners_mut().try_insert_with(occupied_addr, 'b', is_never_called),
        Err((InsertError::Exists, _))
    );
    assert_matches!(
        socket_map.listeners_mut().try_insert_with(device_bound_addr, 'b', is_never_called),
        Err((InsertError::ShadowAddrExists, _))
    );
    assert_matches!(
        socket_map.conns_mut().try_insert_with(conn_on_listener_addr, 'b', is_never_called),
        Err((InsertError::ShadowAddrExists, _))
    );
}
/// A second listener at an occupied address is rejected with `Exists` when
/// its sharing state differs from the existing one.
#[test]
fn insert_listener_conflict_with_listener() {
    set_logger_for_test();
    let mut socket_map = FakeBoundSocketMap::default();
    let mut id_gen = FakeSocketIdGen::default();
    let listener_addr = LISTENER_ADDR;
    let (first_id, second_id) = (Listener(id_gen.next()), Listener(id_gen.next()));

    let _: &Listener =
        socket_map.listeners_mut().try_insert(listener_addr, 'a', first_id).unwrap().id();

    // The rejected sharing state is handed back alongside the error.
    let conflicting = socket_map.listeners_mut().try_insert(listener_addr, 'b', second_id);
    assert_matches!(conflicting, Err((InsertError::Exists, 'b')));
}
/// Inserting a device-bound listener fails with `ShadowAddrExists` when the
/// same address without a device is already occupied.
#[test]
fn insert_listener_conflict_with_shadower() {
    set_logger_for_test();
    let mut socket_map = FakeBoundSocketMap::default();
    let mut id_gen = FakeSocketIdGen::default();

    let deviceless_addr = LISTENER_ADDR;
    // The test relies on the base address having no device.
    assert_eq!(deviceless_addr.device, None);
    let device_bound_addr =
        ListenerAddr { device: Some(FakeWeakDeviceId(FakeDeviceId)), ..deviceless_addr };

    let (first_id, second_id) = (Listener(id_gen.next()), Listener(id_gen.next()));
    let _: &Listener =
        socket_map.listeners_mut().try_insert(deviceless_addr, 'a', first_id).unwrap().id();

    let conflicting = socket_map.listeners_mut().try_insert(device_bound_addr, 'b', second_id);
    assert_matches!(conflicting, Err((InsertError::ShadowAddrExists, 'b')));
}
/// Inserting a connection fails with `ShadowAddrExists` when its local
/// address and identifier are already occupied by a listener.
#[test]
fn insert_conn_conflict_with_listener() {
    set_logger_for_test();
    let mut socket_map = FakeBoundSocketMap::default();
    let mut id_gen = FakeSocketIdGen::default();

    let listener_addr = LISTENER_ADDR;
    let local = (listener_addr.ip.addr.unwrap(), listener_addr.ip.identifier);
    let remote = (SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap(), ());
    let conn_addr = ConnAddr { device: None, ip: ConnIpAddr { local, remote } };

    let listener_id = Listener(id_gen.next());
    let _: &Listener =
        socket_map.listeners_mut().try_insert(listener_addr, 'a', listener_id).unwrap().id();

    let conn_id = Conn(id_gen.next());
    let conflicting = socket_map.conns_mut().try_insert(conn_addr, 'b', conn_id);
    assert_matches!(conflicting, Err((InsertError::ShadowAddrExists, 'b')));
}
/// Two listeners with the same sharing state can share an address, and
/// removing one leaves the other in place.
#[test]
fn insert_and_remove_listener() {
    set_logger_for_test();
    let mut socket_map = FakeBoundSocketMap::default();
    let mut id_gen = FakeSocketIdGen::default();
    let addr = LISTENER_ADDR;

    // Inserts a fresh listener with sharing state 'x' and returns its ID.
    let mut insert_shared = || {
        let id = Listener(id_gen.next());
        socket_map.listeners_mut().try_insert(addr, 'x', id).unwrap().id().clone()
    };
    let first = insert_shared();
    let second = insert_shared();
    assert_ne!(first, second);

    assert_eq!(socket_map.listeners_mut().remove(&first, &addr), Ok(()));
    let remaining = Multiple('x', vec![second]);
    assert_eq!(socket_map.listeners().get_by_addr(&addr), Some(&remaining));
}
/// Two connections with the same sharing state can share an address, and
/// removing one leaves the other in place.
#[test]
fn insert_and_remove_conn() {
    set_logger_for_test();
    let mut socket_map = FakeBoundSocketMap::default();
    let mut id_gen = FakeSocketIdGen::default();
    let conn_addr = CONN_ADDR;

    let first = {
        let new_conn = Conn(id_gen.next());
        let entry = socket_map.conns_mut().try_insert(conn_addr, 'x', new_conn).unwrap();
        entry.id().clone()
    };
    let second = {
        let new_conn = Conn(id_gen.next());
        let entry = socket_map.conns_mut().try_insert(conn_addr, 'x', new_conn).unwrap();
        entry.id().clone()
    };
    assert_ne!(first, second);

    assert_eq!(socket_map.conns_mut().remove(&first, &conn_addr), Ok(()));
    let remaining = Multiple('x', vec![second]);
    assert_eq!(socket_map.conns().get_by_addr(&conn_addr), Some(&remaining));
}
/// Moving a listener to the wildcard-IP address fails while another listener
/// occupies a specific IP with the same identifier.
#[test]
fn update_listener_to_shadowed_addr_fails() {
    let mut socket_map = FakeBoundSocketMap::default();
    let mut id_gen = FakeSocketIdGen::default();

    let first_addr = LISTENER_ADDR;
    // Same identifier and device as `first_addr`, different IP.
    let second_addr = ListenerAddr {
        ip: ListenerIpAddr {
            addr: Some(SocketIpAddr::new(net_ip_v4!("1.1.1.1")).unwrap()),
            identifier: LISTENER_ADDR.ip.identifier,
        },
        device: LISTENER_ADDR.device,
    };
    // Same identifier with no IP and no device.
    let wildcard_addr = ListenerAddr {
        ip: ListenerIpAddr { addr: None, identifier: first_addr.ip.identifier },
        device: None,
    };

    let first = {
        let new_listener = Listener(id_gen.next());
        let entry = socket_map.listeners_mut().try_insert(first_addr, 'a', new_listener).unwrap();
        entry.id().clone()
    };
    let second = {
        let new_listener = Listener(id_gen.next());
        let entry = socket_map.listeners_mut().try_insert(second_addr, 'b', new_listener).unwrap();
        entry.id().clone()
    };

    // Moving `second` from 1.1.1.1 to the wildcard address should fail while
    // `first` still occupies 1.2.3.4 with the same identifier, and vice
    // versa.
    {
        let (ExistsError, entry) = socket_map
            .listeners_mut()
            .entry(&second, &second_addr)
            .unwrap()
            .try_update_addr(wildcard_addr)
            .expect_err("update should fail");
        // The entry handed back should still correspond to `second`.
        assert_eq!(entry.id(), &second);
    }

    let (ExistsError, entry) = socket_map
        .listeners_mut()
        .entry(&first, &first_addr)
        .unwrap()
        .try_update_addr(wildcard_addr)
        .expect_err("update should fail");
    // The entry handed back should still be at its original address.
    assert_eq!(entry.get_addr(), &first_addr);
}
/// Looking up the entry of a connection that has been removed yields `None`.
#[test]
fn nonexistent_conn_entry() {
    let mut map = FakeBoundSocketMap::default();
    let mut fake_id_gen = FakeSocketIdGen::default();
    let addr = CONN_ADDR;

    let conn_id = {
        let new_conn = Conn(fake_id_gen.next());
        let inserted =
            map.conns_mut().try_insert(addr, 'a', new_conn).expect("failed to insert");
        inserted.id().clone()
    };

    let removal = map.conns_mut().remove(&conn_id, &addr);
    assert_matches!(removal, Ok(()));
    // With the connection gone, there is no entry left to hand out.
    assert!(map.conns_mut().entry(&conn_id, &addr).is_none());
}
/// A connection's sharing state can be changed while it is alone at its
/// address, but not once a second connection shares that address.
#[test]
fn update_conn_sharing() {
    let mut map = FakeBoundSocketMap::default();
    let mut fake_id_gen = FakeSocketIdGen::default();
    // The address is `Copy`, so it is passed by value without cloning.
    let addr = CONN_ADDR;
    let mut entry = map
        .conns_mut()
        .try_insert(addr, 'a', Conn(fake_id_gen.next()))
        .expect("failed to insert");
    // The only occupant of the address may change its sharing state.
    entry.try_update_sharing(&'a', 'd').expect("failed to update sharing");
    // Updating sharing is only allowed if there are no other occupants at
    // the address.
    let mut second_conn =
        map.conns_mut().try_insert(addr, 'd', Conn(fake_id_gen.next())).expect("can insert");
    assert_matches!(second_conn.try_update_sharing(&'d', 'e'), Err(UpdateSharingError));
}
}