blob: 587d5e7ca0874ebfa5c40848eea42c04a0736292 [file] [log] [blame]
// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Group Management Protocols (GMPs).
//!
//! This module provides implementations of the Internet Group Management Protocol
//! (IGMP) and the Multicast Listener Discovery (MLD) protocol. These allow
//! hosts to join IPv4 and IPv6 multicast groups respectively.
//!
//! The term "Group Management Protocol" is defined in [RFC 4604]:
//!
//! > Due to the commonality of function, the term "Group Management Protocol",
//! > or "GMP", will be used to refer to both IGMP and MLD.
//!
//! [RFC 4604]: https://tools.ietf.org/html/rfc4604
// This macro is used by tests in both the `igmp` and `mld` modules.
#[allow(unknown_lints)] // TODO(fxbug.dev/99424): remove this after toolchain roll when lint is known
#[allow(unused_macro_rules)] // TODO(fxbug.dev/100318): remove unused macro rules and re-enable
/// Assert that the GMP state machine for `$group` is in the given state.
///
/// `$ctx` is a `context::testutil::DummyCtx` whose state contains a `groups:
/// MulticastGroupSet` field.
///
/// The third argument is one of the bare identifiers `NonMember`, `Delaying`
/// or `Idle`; each public rule forwards to the internal `@inner` rule with the
/// matching `MemberState` pattern.
///
/// # Panics
///
/// Panics if `$group` is not present in the set, if the state machine's
/// `inner` field is `None`, or if the current state does not match the
/// requested one.
#[cfg(test)]
macro_rules! assert_gmp_state {
($ctx:expr, $group:expr, NonMember) => {
assert_gmp_state!(@inner $ctx, $group, MemberState::NonMember(_));
};
($ctx:expr, $group:expr, Delaying) => {
assert_gmp_state!(@inner $ctx, $group, MemberState::Delaying(_));
};
($ctx:expr, $group:expr, Idle) => {
assert_gmp_state!(@inner $ctx, $group, MemberState::Idle(_));
};
// Internal rule: looks the group up, reaches into the stored state machine
// and checks its `MemberState` against `$pattern`.
(@inner $ctx:expr, $group:expr, $pattern:pat) => {
assert!(matches!($ctx.get_ref().groups.get($group).unwrap().0.inner.as_ref().unwrap(), $pattern))
};
}
pub(crate) mod igmp;
pub(crate) mod mld;
use alloc::vec::Vec;
use core::{convert::TryFrom, fmt::Debug, time::Duration};
use crate::{
context::{RngContext, TimerContext},
data_structures::ref_counted_hash_map::{InsertResult, RefCountedHashMap, RemoveResult},
ip::IpDeviceIdContext,
Instant, InstantContext,
};
use assert_matches::assert_matches;
use net_types::{
ip::{Ip, IpAddress},
MulticastAddr,
};
use packet_formats::utils::NonZeroDuration;
use rand::Rng;
/// The outcome of an attempt to join a multicast group.
///
/// A `GroupJoinResult` is produced when a group is joined in a
/// [`MulticastGroupSet`]. The type parameter `O` is the payload carried when
/// the join actually took place.
#[cfg_attr(test, derive(Debug, Eq, PartialEq))]
pub(crate) enum GroupJoinResult<O = ()> {
    /// The group had not been joined before, so it was joined now.
    Joined(O),
    /// The group had already been joined, so only its reference count was
    /// incremented.
    AlreadyMember,
}

impl<O> GroupJoinResult<O> {
    /// Transforms the payload of a [`GroupJoinResult::Joined`] using `f`.
    ///
    /// A [`GroupJoinResult::AlreadyMember`] carries no payload and is passed
    /// through unchanged; `f` is not invoked in that case.
    pub(crate) fn map<P, F: FnOnce(O) -> P>(self, f: F) -> GroupJoinResult<P> {
        match self {
            Self::AlreadyMember => GroupJoinResult::AlreadyMember,
            Self::Joined(payload) => GroupJoinResult::Joined(f(payload)),
        }
    }
}
/// Translates the outcome of a reference-counted map insertion into the
/// corresponding group-join outcome.
impl<O> From<InsertResult<O>> for GroupJoinResult<O> {
    fn from(inserted: InsertResult<O>) -> Self {
        match inserted {
            InsertResult::AlreadyPresent => Self::AlreadyMember,
            InsertResult::Inserted(payload) => Self::Joined(payload),
        }
    }
}
/// The outcome of an attempt to leave a multicast group.
///
/// A `GroupLeaveResult` is produced when a group is left in a
/// [`MulticastGroupSet`]. The type parameter `T` is the payload carried when
/// the group was actually left.
#[cfg_attr(test, derive(Debug, Eq, PartialEq))]
pub(crate) enum GroupLeaveResult<T = ()> {
    /// The reference count dropped to 0, so the group was left.
    Left(T),
    /// The reference count is still above 0, so we remain a member of the
    /// group.
    StillMember,
    /// The group had never been joined.
    NotMember,
}

impl<T> GroupLeaveResult<T> {
    /// Transforms the payload of a [`GroupLeaveResult::Left`] using `f`.
    ///
    /// [`GroupLeaveResult::StillMember`] and [`GroupLeaveResult::NotMember`]
    /// carry no payload and are passed through unchanged; `f` is not invoked
    /// for them.
    pub(crate) fn map<U, F: FnOnce(T) -> U>(self, f: F) -> GroupLeaveResult<U> {
        match self {
            Self::NotMember => GroupLeaveResult::NotMember,
            Self::StillMember => GroupLeaveResult::StillMember,
            Self::Left(payload) => GroupLeaveResult::Left(f(payload)),
        }
    }
}
/// Translates the outcome of a reference-counted map removal into the
/// corresponding group-leave outcome.
impl<T> From<RemoveResult<T>> for GroupLeaveResult<T> {
    fn from(removed: RemoveResult<T>) -> Self {
        match removed {
            RemoveResult::NotPresent => Self::NotMember,
            RemoveResult::StillPresent => Self::StillMember,
            RemoveResult::Removed(payload) => Self::Left(payload),
        }
    }
}
/// A set of reference-counted multicast groups and associated data.
///
/// `MulticastGroupSet` is a set of multicast groups, each with associated data
/// `T`. Each group is reference-counted, only being removed once its reference
/// count reaches zero.
#[cfg_attr(test, derive(Debug))]
pub(crate) struct MulticastGroupSet<A: IpAddress, T> {
/// Maps each joined group's multicast address to its associated data `T`.
inner: RefCountedHashMap<MulticastAddr<A>, T>,
}
/// The default set contains no groups.
impl<A: IpAddress, T> Default for MulticastGroupSet<A, T> {
    fn default() -> Self {
        let inner = RefCountedHashMap::default();
        Self { inner }
    }
}
/// Actions to take as a consequence of joining a group.
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
struct JoinGroupActions<P> {
/// When `Some((protocol_specific, delay))`, a report built from
/// `protocol_specific` is to be sent and the delayed-report timer is to be
/// scheduled to fire after `delay`. When `None`, nothing is to be done.
send_report_and_schedule_timer: Option<(P, Duration)>,
}
impl<P> JoinGroupActions<P> {
/// The action set that requests nothing.
const NOOP: Self = Self { send_report_and_schedule_timer: None };
}
/// Actions to take as a consequence of leaving a group.
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
struct LeaveGroupActions {
/// Whether a leave message is to be sent for the group.
send_leave: bool,
/// Whether the group's pending delayed-report timer is to be cancelled.
stop_timer: bool,
}
impl LeaveGroupActions {
/// The action set that requests nothing.
const NOOP: Self = Self { send_leave: false, stop_timer: false };
}
/// Actions to take as a consequence of handling a received report message.
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
struct ReportReceivedActions {
/// Whether the group's pending delayed-report timer is to be cancelled.
stop_timer: bool,
}
impl ReportReceivedActions {
/// The action set that requests nothing.
const NOOP: Self = Self { stop_timer: false };
}
/// The protocol-independent action to take after receiving a query message.
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
enum QueryReceivedGenericAction<P> {
/// Schedule the group's delayed-report timer to fire after the given delay.
ScheduleTimer(Duration),
/// Send a report built from the given protocol-specific state and cancel
/// the group's delayed-report timer.
StopTimerAndSendReport(P),
}
/// Actions to take as a consequence of receiving a query message.
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
struct QueryReceivedActions<P: ProtocolSpecific> {
/// The protocol-independent action, if any.
generic: Option<QueryReceivedGenericAction<P>>,
/// The protocol-specific actions, if any, as produced by
/// `ProtocolSpecific::do_query_received_specific`.
protocol_specific: Option<P::Actions>,
}
impl<P: ProtocolSpecific> QueryReceivedActions<P> {
/// The action set that requests nothing.
const NOOP: Self = Self { generic: None, protocol_specific: None };
}
/// Actions to take as a consequence of a report timer expiring.
#[cfg_attr(test, derive(Debug, PartialEq, Eq))]
struct ReportTimerExpiredActions<P> {
/// The protocol-specific state from which the report to send is built.
send_report: P,
}
impl<A: IpAddress, T> MulticastGroupSet<A, T> {
    /// Returns an iterator over the addresses of all groups in the set.
    fn groups<'a>(&'a self) -> impl Iterator<Item = &MulticastAddr<A>> + 'a {
        self.inner.iter().map(|(g, _state)| g)
    }

    /// Joins `group`, delegating to `RefCountedHashMap::insert_with`.
    ///
    /// `f` supplies the group's associated data together with an output value
    /// that is carried in [`GroupJoinResult::Joined`].
    fn join_group_with<O, F: FnOnce() -> (T, O)>(
        &mut self,
        group: MulticastAddr<A>,
        f: F,
    ) -> GroupJoinResult<O> {
        self.inner.insert_with(group, f).into()
    }

    /// Joins a multicast group and initializes it with a GMP state machine.
    ///
    /// `join_group_gmp` joins the multicast group `group`. If the group was not
    /// already joined, then a new instance of [`GmpStateMachine`] is generated
    /// using [`GmpStateMachine::join_group`], it is inserted with a reference
    /// count of 1, and the list of actions returned by `join_group` is
    /// returned. Otherwise, if the group was already joined, its reference
    /// count is incremented.
    ///
    /// `gmp_disabled`, `rng` and `now` are forwarded to
    /// [`GmpStateMachine::join_group`].
    fn join_group_gmp<I: Instant, P: ProtocolSpecific + Default, R: Rng>(
        &mut self,
        gmp_disabled: bool,
        group: MulticastAddr<A>,
        rng: &mut R,
        now: I,
    ) -> GroupJoinResult<JoinGroupActions<P>>
    where
        T: From<GmpStateMachine<I, P>>,
        P::Config: Default,
    {
        // `join_group_with` already yields the return type, so no further
        // conversion is needed.
        self.join_group_with(group, || {
            let (state, actions) = GmpStateMachine::join_group(rng, now, gmp_disabled);
            (T::from(state), actions)
        })
    }

    /// Leaves `group`, delegating to `RefCountedHashMap::remove`.
    ///
    /// The group's associated data is carried in [`GroupLeaveResult::Left`].
    fn leave_group(&mut self, group: MulticastAddr<A>) -> GroupLeaveResult<T> {
        self.inner.remove(group).into()
    }

    /// Leaves a multicast group.
    ///
    /// `leave_group_gmp` leaves the multicast group `group` by decrementing the
    /// reference count on the group. If the reference count reaches 0, the
    /// group is left using [`GmpStateMachine::leave_group`] and the list of
    /// actions returned by `leave_group` is returned.
    fn leave_group_gmp<I: Instant, P: ProtocolSpecific>(
        &mut self,
        group: MulticastAddr<A>,
    ) -> GroupLeaveResult<LeaveGroupActions>
    where
        T: Into<GmpStateMachine<I, P>>,
    {
        // `map` already yields the return type, so no further conversion is
        // needed.
        self.leave_group(group).map(|state| state.into().leave_group())
    }

    /// Does the set contain the given group?
    pub(crate) fn contains(&self, group: &MulticastAddr<A>) -> bool {
        self.inner.contains_key(group)
    }

    /// Returns a shared reference to the data associated with `group`, if the
    /// group is in the set.
    #[cfg(test)]
    fn get(&self, group: &MulticastAddr<A>) -> Option<&T> {
        self.inner.get(group)
    }

    /// Returns a mutable reference to the data associated with `group`, if the
    /// group is in the set.
    fn get_mut(&mut self, group: &MulticastAddr<A>) -> Option<&mut T> {
        self.inner.get_mut(group)
    }

    /// Returns an iterator over every group address paired with a mutable
    /// reference to its associated data.
    fn iter_mut<'a>(&'a mut self) -> impl 'a + Iterator<Item = (&'a MulticastAddr<A>, &'a mut T)> {
        self.inner.iter_mut()
    }
}
/// An implementation of a Group Management Protocol (GMP) such as the Internet
/// Group Management Protocol, Version 2 (IGMPv2) for IPv4 or the Multicast
/// Listener Discovery (MLD) protocol for IPv6.
pub(crate) trait GmpHandler<I: Ip>: IpDeviceIdContext<I> {
/// Handles GMP potentially being enabled.
///
/// Attempts to transition memberships in the non-member state to a member
/// state. Should be called anytime a configuration change occurs which
/// results in GMP potentially being enabled. E.g. when IP or GMP
/// transitions to being enabled.
fn gmp_handle_maybe_enabled(&mut self, device: Self::DeviceId);
/// Handles GMP being disabled.
///
/// All joined groups will transition to the non-member state but still
/// remain locally joined.
fn gmp_handle_disabled(&mut self, device: Self::DeviceId);
/// Joins the given multicast group.
///
/// The returned [`GroupJoinResult`] tells whether the group was newly
/// joined or had already been joined.
fn gmp_join_group(
&mut self,
device: Self::DeviceId,
group_addr: MulticastAddr<I::Addr>,
) -> GroupJoinResult;
/// Leaves the given multicast group.
///
/// The returned [`GroupLeaveResult`] tells whether the group was left, is
/// still joined, or had never been joined.
fn gmp_leave_group(
&mut self,
device: Self::DeviceId,
group_addr: MulticastAddr<I::Addr>,
) -> GroupLeaveResult;
}
/// This trait is used to model the parts in which the two protocols differ.
///
/// Though MLD and IGMPv2 share most of their state machines, there are some
/// subtle differences between them.
trait ProtocolSpecific: Copy {
/// The type for protocol-specific actions.
type Actions;
/// The type for protocol-specific configs.
type Config: Debug;
/// The maximum delay to wait to send an unsolicited report.
fn cfg_unsolicited_report_interval(cfg: &Self::Config) -> Duration;
/// Whether the host should send a leave message even if it is not the last
/// host in the group.
fn cfg_send_leave_anyway(cfg: &Self::Config) -> bool;
/// Get the _real_ `MAX_RESP_TIME`.
///
/// `None` indicates that the maximum response time is zero and thus a
/// response should be sent immediately.
fn get_max_resp_time(resp_time: Duration) -> Option<NonZeroDuration>;
/// Respond to a query in a protocol-specific way.
///
/// When receiving a query, IGMPv2 needs to check whether the query is an
/// IGMPv1 message and, if so, set a local "IGMPv1 Router Present" flag and
/// set a timer. For MLD, this function is a no-op.
///
/// `old` is the protocol-specific state before the query. The returned
/// tuple holds the protocol-specific state to use from now on and any
/// protocol-specific actions to run.
fn do_query_received_specific(
cfg: &Self::Config,
max_resp_time: Duration,
old: Self,
) -> (Self, Option<Self::Actions>);
}
/// This is used to represent the states that are common in both MLD and IGMPv2.
/// The state machine should behave as described on [RFC 2236 page 10] and [RFC
/// 2710 page 10].
///
/// [RFC 2236 page 10]: https://tools.ietf.org/html/rfc2236#page-10
/// [RFC 2710 page 10]: https://tools.ietf.org/html/rfc2710#page-10
#[cfg_attr(test, derive(Debug))]
struct GmpHostState<State, P: ProtocolSpecific> {
/// The variables specific to the current state (e.g. [`NonMember`],
/// [`DelayingMember`] or [`IdleMember`]).
state: State,
/// `protocol_specific` holds the value(s) that users should not have a
/// chance to modify. It is supposed to be modified only by the protocol
/// itself.
protocol_specific: P,
/// `cfg` is used to store value(s) that are supposed to be modified by
/// users.
cfg: P::Config,
}
// Accessors needed by the tests in the `igmp` and `mld` modules.
#[cfg(test)]
impl<S, P: ProtocolSpecific> GmpHostState<S, P> {
    /// Returns a copy of the protocol-specific state.
    fn get_protocol_specific(&self) -> P {
        let Self { protocol_specific, .. } = self;
        *protocol_specific
    }

    /// Returns a reference to the state-specific variables.
    fn get_state(&self) -> &S {
        let Self { state, .. } = self;
        state
    }
}
/// The state for a multicast group membership.
///
/// The terms used here are biased towards [IGMPv2]. In [MLD], their names are
/// {Non, Delaying, Idle}-Listener instead.
///
/// [IGMPv2]: https://tools.ietf.org/html/rfc2236
/// [MLD]: https://tools.ietf.org/html/rfc2710
#[cfg_attr(test, derive(Debug))]
enum MemberState<I: Instant, P: ProtocolSpecific> {
/// The group is joined locally but GMP is not being performed for it.
NonMember(GmpHostState<NonMember, P>),
/// A report timer is pending for the group.
Delaying(GmpHostState<DelayingMember<I>, P>),
/// The group is a member with no report timer pending.
Idle(GmpHostState<IdleMember, P>),
}
/// The transition between one state and the next.
///
/// A `Transition` includes the next state to enter and any actions to take
/// while executing the transition.
///
/// Field `0` is the next state and field `1` is the actions.
struct Transition<S, P: ProtocolSpecific, Actions>(GmpHostState<S, P>, Actions);
/// Represents Non Member-specific state variables.
///
/// A membership may be in the non-member state when the group is joined
/// locally but GMP is not being performed for it.
#[cfg_attr(test, derive(Debug))]
struct NonMember;
/// Represents Delaying Member-specific state variables.
#[cfg_attr(test, derive(Debug))]
struct DelayingMember<I: Instant> {
/// The expiration time for the current timer. Useful to check if the timer
/// needs to be reset when a query arrives.
timer_expiration: I,
/// Used to indicate whether we need to send out a Leave message when we are
/// leaving the group. This flag becomes false once we hear a report from
/// another reporter.
last_reporter: bool,
}
/// Represents Idle Member-specific state variables.
#[cfg_attr(test, derive(Debug))]
struct IdleMember {
/// Used to indicate whether we need to send out a Leave message when we are
/// leaving the group. It is set when our own report timer expired and
/// cleared when a report from another host was received while delaying.
last_reporter: bool,
}
impl<S, P: ProtocolSpecific> GmpHostState<S, P> {
    /// Builds the [`Transition`] that moves from this state into the new state
    /// `t`, to be executed together with `actions`.
    ///
    /// The protocol-specific state and the configuration are carried over
    /// unchanged; only the state-specific variables are replaced.
    fn transition<T, A>(self, t: T, actions: A) -> Transition<T, P, A> {
        let GmpHostState { state: _, protocol_specific, cfg } = self;
        let next = GmpHostState { state: t, protocol_specific, cfg };
        Transition(next, actions)
    }
}
/// Computes the next state and the actions to take for a membership in a
/// member state (Delaying or Idle) that has received a query message.
///
/// # Arguments
/// * `rng` is used to pick the random report delay when a timer is scheduled.
/// * `last_reporter` indicates if the last report was sent by this node; it is
///   carried over into the next state unchanged.
/// * `timer_expiration` is `None` if no timer is currently installed (Idle
///   member), otherwise `Some(t)` where `t` is the instant at which the
///   currently installed timer fires (Delaying member).
/// * `max_resp_time` is the maximum response time requested by the query.
/// * `now` is the current time.
/// * `cfg` and `ps` are the membership's configuration and protocol-specific
///   state.
///
/// # Panics
///
/// Panics if adding the response time or the chosen delay to `now` overflows.
fn member_query_received<P: ProtocolSpecific, R: Rng, I: Instant>(
    rng: &mut R,
    last_reporter: bool,
    timer_expiration: Option<I>,
    max_resp_time: Duration,
    now: I,
    cfg: P::Config,
    ps: P,
) -> (MemberState<I, P>, QueryReceivedActions<P>) {
    // The protocol-specific handling always runs first and may replace the
    // protocol-specific state.
    let (protocol_specific, ps_actions) = P::do_query_received_specific(&cfg, max_resp_time, ps);

    let (next_state, generic): (MemberState<I, P>, _) =
        if let Some(max_resp_time) = P::get_max_resp_time(max_resp_time) {
            let max_resp_time = max_resp_time.get();
            let new_deadline = now.checked_add(max_resp_time).unwrap();
            let (timer_expiration, reschedule) = match timer_expiration {
                // The pending timer already fires no later than the query
                // requires, so it is left untouched.
                Some(pending) if new_deadline >= pending => (pending, None),
                // Either no timer is pending or it fires too late: pick a
                // fresh random delay.
                None | Some(_) => {
                    let delay = random_report_timeout(rng, max_resp_time);
                    let expiration = now.checked_add(delay).unwrap();
                    (expiration, Some(QueryReceivedGenericAction::ScheduleTimer(delay)))
                }
            };
            let delaying = DelayingMember { last_reporter, timer_expiration };
            (GmpHostState { state: delaying, protocol_specific, cfg }.into(), reschedule)
        } else {
            // A zero response time means the report goes out immediately and
            // the membership becomes idle.
            let idle = IdleMember { last_reporter };
            (
                GmpHostState { state: idle, protocol_specific, cfg }.into(),
                Some(QueryReceivedGenericAction::StopTimerAndSendReport(protocol_specific)),
            )
        };

    (next_state, QueryReceivedActions { generic, protocol_specific: ps_actions })
}
/// Randomly generates a timeout in (0, period].
///
/// The timeout is chosen with microsecond granularity. If `period` is shorter
/// than one microsecond (including zero), no whole number of microseconds lies
/// in (0, period]; `period` itself is returned in that case and `rng` is not
/// consulted.
///
/// # Panics
///
/// `random_report_timeout` panics if `period.as_micros()` overflows `u64`.
fn random_report_timeout<R: Rng>(rng: &mut R, period: Duration) -> Duration {
    // u64 is expected to be enough here: for query responses the period comes
    // from the `MaxRespTime` field of GMP query packets, whose representable
    // number of microseconds is bounded by 2^33.
    let max_micros = u64::try_from(period.as_micros())
        .expect("report period in microseconds overflows u64");
    if max_micros == 0 {
        // `gen_range` panics on an empty range, so a sub-microsecond period
        // must not reach it.
        return period;
    }
    // `gen_range` yields a value in [0, max_micros); adding 1 shifts it to
    // [1, max_micros], i.e. (0, period] in whole microseconds. The sum cannot
    // overflow because the sampled value is strictly below `max_micros`.
    Duration::from_micros(rng.gen_range(0..max_micros) + 1)
}
impl<P: ProtocolSpecific> GmpHostState<NonMember, P> {
    /// Performs the "join group" transition out of the non-member state.
    ///
    /// A random delay within the configured unsolicited report interval is
    /// chosen. The membership becomes a delaying member whose timer expires
    /// after that delay, and the returned actions request sending a report and
    /// scheduling the timer.
    ///
    /// # Panics
    ///
    /// Panics if adding the delay to `now` overflows.
    fn join_group<I: Instant, R: Rng>(
        self,
        rng: &mut R,
        now: I,
    ) -> Transition<DelayingMember<I>, P, JoinGroupActions<P>> {
        let delay = random_report_timeout(rng, P::cfg_unsolicited_report_interval(&self.cfg));
        let report_and_timer = (self.protocol_specific, delay);
        let timer_expiration = now.checked_add(delay).expect("timer expiration overflowed");
        self.transition(
            DelayingMember { last_reporter: true, timer_expiration },
            JoinGroupActions { send_report_and_schedule_timer: Some(report_and_timer) },
        )
    }

    /// Performs the "leave group" transition; a non-member stays a non-member
    /// and nothing needs to be done.
    fn leave_group(self) -> Transition<NonMember, P, LeaveGroupActions> {
        let nothing_to_do = LeaveGroupActions::NOOP;
        self.transition(NonMember, nothing_to_do)
    }
}
impl<I: Instant, P: ProtocolSpecific> GmpHostState<DelayingMember<I>, P> {
    /// Handles a received query while a report timer is pending.
    ///
    /// Delegates to [`member_query_received`], passing along the expiration of
    /// the pending timer.
    fn query_received<R: Rng>(
        self,
        rng: &mut R,
        max_resp_time: Duration,
        now: I,
    ) -> (MemberState<I, P>, QueryReceivedActions<P>) {
        let Self { state, protocol_specific, cfg } = self;
        let DelayingMember { last_reporter, timer_expiration } = state;
        let pending_timer = Some(timer_expiration);
        member_query_received(
            rng,
            last_reporter,
            pending_timer,
            max_resp_time,
            now,
            cfg,
            protocol_specific,
        )
    }

    /// Performs the "leave group" transition.
    ///
    /// The pending timer must always be stopped. A leave message is requested
    /// if we were the last reporter or the configuration asks for one
    /// regardless.
    fn leave_group(self) -> Transition<NonMember, P, LeaveGroupActions> {
        let send_leave = self.state.last_reporter || P::cfg_send_leave_anyway(&self.cfg);
        self.transition(NonMember, LeaveGroupActions { send_leave, stop_timer: true })
    }

    /// Handles a report received from another host: the pending timer is
    /// stopped and we become an idle member that is not the last reporter.
    fn report_received(self) -> Transition<IdleMember, P, ReportReceivedActions> {
        let idle = IdleMember { last_reporter: false };
        let actions = ReportReceivedActions { stop_timer: true };
        self.transition(idle, actions)
    }

    /// Handles expiry of the report timer: a report is requested and we become
    /// an idle member that is the last reporter.
    fn report_timer_expired(self) -> Transition<IdleMember, P, ReportTimerExpiredActions<P>> {
        let send_report = self.protocol_specific;
        self.transition(
            IdleMember { last_reporter: true },
            ReportTimerExpiredActions { send_report },
        )
    }
}
impl<P: ProtocolSpecific> GmpHostState<IdleMember, P> {
    /// Handles a received query while no report timer is pending.
    ///
    /// Delegates to [`member_query_received`] with no pending timer.
    fn query_received<I: Instant, R: Rng>(
        self,
        rng: &mut R,
        max_resp_time: Duration,
        now: I,
    ) -> (MemberState<I, P>, QueryReceivedActions<P>) {
        let Self { state, protocol_specific, cfg } = self;
        let IdleMember { last_reporter } = state;
        let no_pending_timer = None;
        member_query_received(
            rng,
            last_reporter,
            no_pending_timer,
            max_resp_time,
            now,
            cfg,
            protocol_specific,
        )
    }

    /// Performs the "leave group" transition.
    ///
    /// No timer is pending, so none needs stopping. A leave message is
    /// requested if we were the last reporter or the configuration asks for
    /// one regardless.
    fn leave_group(self) -> Transition<NonMember, P, LeaveGroupActions> {
        let send_leave = self.state.last_reporter || P::cfg_send_leave_anyway(&self.cfg);
        self.transition(NonMember, LeaveGroupActions { send_leave, stop_timer: false })
    }
}
/// Wraps a non-member host state in [`MemberState::NonMember`].
impl<I: Instant, P: ProtocolSpecific> From<GmpHostState<NonMember, P>> for MemberState<I, P> {
    fn from(non_member: GmpHostState<NonMember, P>) -> Self {
        Self::NonMember(non_member)
    }
}
/// Wraps a delaying-member host state in [`MemberState::Delaying`].
impl<I: Instant, P: ProtocolSpecific> From<GmpHostState<DelayingMember<I>, P>>
    for MemberState<I, P>
{
    fn from(delaying: GmpHostState<DelayingMember<I>, P>) -> Self {
        Self::Delaying(delaying)
    }
}
/// Wraps an idle-member host state in [`MemberState::Idle`].
impl<I: Instant, P: ProtocolSpecific> From<GmpHostState<IdleMember, P>> for MemberState<I, P> {
    fn from(idle: GmpHostState<IdleMember, P>) -> Self {
        Self::Idle(idle)
    }
}
impl<S, P: ProtocolSpecific, A> Transition<S, P, A> {
fn into_state_actions<I: Instant>(self) -> (MemberState<I, P>, A)
where
MemberState<I, P>: From<GmpHostState<S, P>>,
{
(self.0.into(), self.1)
}
}
impl<I: Instant, P: ProtocolSpecific> MemberState<I, P> {
    /// Performs the "join group" transition, producing a new `MemberState` and
    /// the set of actions to execute.
    ///
    /// When `gmp_disabled` is true the membership stays in the Non-Member
    /// state and no actions are requested.
    fn join_group<R: Rng>(
        protocol_specific: P,
        cfg: P::Config,
        rng: &mut R,
        now: I,
        gmp_disabled: bool,
    ) -> (MemberState<I, P>, JoinGroupActions<P>) {
        let initial = GmpHostState { state: NonMember, protocol_specific, cfg };
        if !gmp_disabled {
            return initial.join_group(rng, now).into_state_actions();
        }
        (initial.into(), JoinGroupActions::NOOP)
    }

    /// Performs the "leave group" transition, consuming the state by value and
    /// returning the next state together with the set of actions to execute.
    ///
    /// In the [IGMPv2] and [MLD] RFCs, the "leave group" transition moves from
    /// any state to the Non-Member state; accordingly, every variant
    /// transitions to [`MemberState::NonMember`] here.
    ///
    /// [IGMPv2]: https://tools.ietf.org/html/rfc2236
    /// [MLD]: https://tools.ietf.org/html/rfc2710
    fn leave_group(self) -> (MemberState<I, P>, LeaveGroupActions) {
        // Every arm yields a transition into the Non-Member state, so the arms
        // share a single type.
        let transition = match self {
            MemberState::Idle(idle) => idle.leave_group(),
            MemberState::Delaying(delaying) => delaying.leave_group(),
            MemberState::NonMember(non_member) => non_member.leave_group(),
        };
        transition.into_state_actions()
    }

    /// Handles a received query.
    ///
    /// Non-members ignore queries; delaying and idle members delegate to their
    /// state-specific handlers.
    fn query_received<R: Rng>(
        self,
        rng: &mut R,
        max_resp_time: Duration,
        now: I,
    ) -> (MemberState<I, P>, QueryReceivedActions<P>) {
        match self {
            MemberState::Idle(idle) => idle.query_received(rng, max_resp_time, now),
            MemberState::Delaying(delaying) => delaying.query_received(rng, max_resp_time, now),
            unchanged @ MemberState::NonMember(_) => (unchanged, QueryReceivedActions::NOOP),
        }
    }

    /// Handles a report received from another host.
    ///
    /// Only a delaying member reacts; idle members and non-members are left
    /// unchanged.
    fn report_received(self) -> (MemberState<I, P>, ReportReceivedActions) {
        match self {
            MemberState::Delaying(delaying) => delaying.report_received().into_state_actions(),
            unchanged @ (MemberState::Idle(_) | MemberState::NonMember(_)) => {
                (unchanged, ReportReceivedActions::NOOP)
            }
        }
    }

    /// Handles expiry of the report timer.
    ///
    /// # Panics
    ///
    /// Panics if the membership is not in the Delaying state, since only that
    /// state has a report timer installed.
    fn report_timer_expired(self) -> (MemberState<I, P>, ReportTimerExpiredActions<P>) {
        match self {
            MemberState::Delaying(delaying) => {
                delaying.report_timer_expired().into_state_actions()
            }
            MemberState::Idle(_) | MemberState::NonMember(_) => {
                unreachable!("got report timer in non-delaying state")
            }
        }
    }
}
/// The GMP state machine for a single multicast group membership.
///
/// Wraps a [`MemberState`] so that transitions, which consume the state by
/// value, can be performed through a mutable reference.
#[cfg_attr(test, derive(Debug))]
struct GmpStateMachine<I: Instant, P: ProtocolSpecific> {
// Invariant: `inner` is always `Some`. It is stored as an `Option` so that
// methods can `.take()` the `MemberState` in order to perform transitions
// that consume `MemberState` by value. However, a new `MemberState` is
// always put back in its place so that `inner` is `Some` by the time the
// methods return.
inner: Option<MemberState<I, P>>,
}
impl<I: Instant, P: ProtocolSpecific + Default> GmpStateMachine<I, P>
where
    P::Config: Default,
{
    /// When a "join group" command is received.
    ///
    /// `join_group` creates a new state machine with default protocol-specific
    /// state and default configuration, starting from the Non-Member state,
    /// and immediately executes the "join group" transition (which is a no-op
    /// when `gmp_disabled` is true). The new state machine is returned along
    /// with any actions to take.
    fn join_group<R: Rng>(
        rng: &mut R,
        now: I,
        gmp_disabled: bool,
    ) -> (GmpStateMachine<I, P>, JoinGroupActions<P>) {
        let protocol_specific = P::default();
        let cfg = P::Config::default();
        let (member_state, actions) =
            MemberState::join_group(protocol_specific, cfg, rng, now, gmp_disabled);
        let machine = GmpStateMachine { inner: Some(member_state) };
        (machine, actions)
    }
}
impl<I: Instant, P: ProtocolSpecific> GmpStateMachine<I, P> {
    /// Attempts to join the group if the group is currently in the non-member
    /// state.
    ///
    /// If the group is in a member state (delaying/idle), this method does
    /// nothing and returns no-op actions.
    fn join_if_non_member<R: Rng>(&mut self, rng: &mut R, now: I) -> JoinGroupActions<P> {
        self.update(|current| match current {
            member @ (MemberState::Delaying(_) | MemberState::Idle(_)) => {
                (member, JoinGroupActions::NOOP)
            }
            MemberState::NonMember(non_member) => {
                non_member.join_group(rng, now).into_state_actions()
            }
        })
    }

    /// Leaves the group if the group is in a member state, keeping the state
    /// machine around in the non-member state.
    ///
    /// Returns no-op actions if the group is already in the non-member state.
    fn leave_if_member(&mut self) -> LeaveGroupActions {
        self.update(MemberState::leave_group)
    }

    /// When a "leave group" command is received.
    ///
    /// `leave_group` consumes the state machine by value and returns only the
    /// actions to take; the resulting state is discarded.
    fn leave_group(self) -> LeaveGroupActions {
        let Self { inner } = self;
        // This `unwrap` is safe because we maintain the invariant that `inner`
        // is always `Some`.
        let (_discarded, actions) = inner.unwrap().leave_group();
        actions
    }

    /// When a query is received, and we have to respond within max_resp_time.
    fn query_received<R: Rng>(
        &mut self,
        rng: &mut R,
        max_resp_time: Duration,
        now: I,
    ) -> QueryReceivedActions<P> {
        self.update(|current| current.query_received(rng, max_resp_time, now))
    }

    /// We have received a report from another host on our local network.
    fn report_received(&mut self) -> ReportReceivedActions {
        self.update(|current| current.report_received())
    }

    /// The timer installed has expired.
    fn report_timer_expired(&mut self) -> ReportTimerExpiredActions<P> {
        self.update(|current| current.report_timer_expired())
    }

    /// Runs the by-value transition `f` on the current state, stores the state
    /// it produces, and returns the value `f` produced alongside it.
    fn update<A, F: FnOnce(MemberState<I, P>) -> (MemberState<I, P>, A)>(&mut self, f: F) -> A {
        let current = self.inner.take().unwrap();
        let (next, output) = f(current);
        self.inner = Some(next);
        output
    }

    /// Replaces the protocol-specific value with `ps`, leaving the membership
    /// state and the configuration untouched.
    fn update_with_protocol_specific(&mut self, ps: P) {
        self.update(|current| {
            let next = match current {
                MemberState::NonMember(host) => {
                    MemberState::NonMember(GmpHostState { protocol_specific: ps, ..host })
                }
                MemberState::Delaying(host) => {
                    MemberState::Delaying(GmpHostState { protocol_specific: ps, ..host })
                }
                MemberState::Idle(host) => {
                    MemberState::Idle(GmpHostState { protocol_specific: ps, ..host })
                }
            };
            (next, ())
        })
    }

    /// Returns a reference to the current membership state.
    #[cfg(test)]
    fn get_inner(&self) -> &MemberState<I, P> {
        let Self { inner } = self;
        inner.as_ref().unwrap()
    }
}
/// A timer ID for GMP to send a report.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
pub(crate) struct GmpDelayedReportTimerId<A, D> {
/// The device on which the group is joined.
pub(crate) device: D,
/// The multicast group the delayed report is for.
pub(crate) group_addr: MulticastAddr<A>,
}
/// A type of GMP message.
#[derive(Debug)]
enum GmpMessageType<P> {
/// A membership report, carrying the protocol-specific state to build it
/// from.
Report(P),
/// A leave message.
Leave,
}
/// Provides common functionality for GMP context implementations.
///
/// This trait implements portions of a group management protocol.
trait GmpContext<I: Ip, PS: ProtocolSpecific>:
IpDeviceIdContext<I> + RngContext + TimerContext<GmpDelayedReportTimerId<I::Addr, Self::DeviceId>>
{
/// The error type returned by the message handlers.
type Err;
/// The per-group state stored in the device's [`MulticastGroupSet`]; it
/// converts to and from, and gives mutable access to, a
/// [`GmpStateMachine`].
type GroupState: From<GmpStateMachine<Self::Instant, PS>>
+ Into<GmpStateMachine<Self::Instant, PS>>
+ AsMut<GmpStateMachine<Self::Instant, PS>>;
/// Returns true iff the group management protocol is currently disabled for
/// a multicast address on a device.
fn gmp_disabled(&self, device: Self::DeviceId, addr: MulticastAddr<I::Addr>) -> bool;
/// Sends a GMP message.
fn send_message(
&mut self,
device: Self::DeviceId,
group_addr: MulticastAddr<I::Addr>,
msg_type: GmpMessageType<PS>,
);
/// Runs protocol-specific actions.
fn run_actions(&mut self, device: Self::DeviceId, actions: PS::Actions);
/// Builds the error reported when a message targets the group `addr` that
/// has not been joined.
fn not_a_member_err(addr: I::Addr) -> Self::Err;
/// Returns mutable references to the device's multicast group set and to
/// the random number generator.
fn get_state_mut_and_rng(
&mut self,
device: Self::DeviceId,
) -> (&mut MulticastGroupSet<I::Addr, Self::GroupState>, &mut Self::Rng);
/// Returns a mutable reference to the device's multicast group set.
///
/// The default implementation discards the RNG half of
/// `get_state_mut_and_rng`.
fn get_state_mut(
&mut self,
device: Self::DeviceId,
) -> &mut MulticastGroupSet<I::Addr, Self::GroupState> {
let (state, _rng) = self.get_state_mut_and_rng(device);
state
}
/// Returns a shared reference to the device's multicast group set.
fn get_state(&self, device: Self::DeviceId) -> &MulticastGroupSet<I::Addr, Self::GroupState>;
}
/// Handles expiry of a group's delayed-report timer.
///
/// Runs the "report timer expired" transition on the group's state machine
/// and sends the resulting report.
///
/// # Panics
///
/// Panics if no state is stored for the group the timer belongs to, or if the
/// group is not in the Delaying state.
fn gmp_handle_timer<I, PS, C>(
    sync_ctx: &mut C,
    GmpDelayedReportTimerId { device, group_addr }: GmpDelayedReportTimerId<I::Addr, C::DeviceId>,
) where
    C: GmpContext<I, PS>,
    I: Ip,
    PS: ProtocolSpecific,
{
    let group_state = sync_ctx
        .get_state_mut(device)
        .get_mut(&group_addr)
        .expect("get state for group with expired report timer");
    let ReportTimerExpiredActions { send_report } = group_state.as_mut().report_timer_expired();
    let report = GmpMessageType::Report(send_report);
    sync_ctx.send_message(device, group_addr, report);
}
/// A GMP message that refers to a group address.
trait GmpMessage<I: Ip> {
/// Returns the group address carried by the message.
fn group_addr(&self) -> I::Addr;
}
/// Handles a report message received from another host for `group_addr`.
///
/// Runs the "report received" transition on the group's state machine and, if
/// requested, cancels the group's pending delayed-report timer.
///
/// # Errors
///
/// Returns `C::not_a_member_err` if `group_addr` has not been joined on
/// `device`.
///
/// # Panics
///
/// Panics if the transition requests stopping the timer but no timer was
/// scheduled for the group.
fn handle_report_message<I, PS, C>(
    sync_ctx: &mut C,
    device: C::DeviceId,
    group_addr: MulticastAddr<I::Addr>,
) -> Result<(), C::Err>
where
    C: GmpContext<I, PS>,
    I: Ip,
    PS: ProtocolSpecific,
{
    let ReportReceivedActions { stop_timer } = sync_ctx
        .get_state_mut(device)
        .get_mut(&group_addr)
        // Build the error lazily so nothing is constructed on the success
        // path.
        .ok_or_else(|| C::not_a_member_err(*group_addr))?
        .as_mut()
        .report_received();
    if stop_timer {
        assert_matches!(
            sync_ctx.cancel_timer(GmpDelayedReportTimerId { device, group_addr }),
            Some(_)
        );
    }
    Ok(())
}
/// The group targeted in a query message.
enum QueryTarget<A> {
/// The query names no particular group; it is applied to every joined
/// group.
Unspecified,
/// The query targets the given group only.
Specified(MulticastAddr<A>),
}
/// Handles a received query message.
///
/// For an unspecified target the "query received" transition is run on every
/// joined group; for a specified target it is run on that group only. The
/// resulting actions are then executed: scheduling the delayed-report timer,
/// or sending a report immediately and cancelling the timer, followed by any
/// protocol-specific actions.
///
/// # Errors
///
/// Returns `C::not_a_member_err` if a specific group is targeted that has not
/// been joined on `device`.
fn handle_query_message<I, PS, C>(
    sync_ctx: &mut C,
    device: C::DeviceId,
    target: QueryTarget<I::Addr>,
    max_response_time: Duration,
) -> Result<(), C::Err>
where
    C: GmpContext<I, PS>,
    I: Ip,
    PS: ProtocolSpecific,
{
    let now = sync_ctx.now();
    let (state, rng) = sync_ctx.get_state_mut_and_rng(device);
    // All transitions are computed first, while the group set is borrowed; the
    // actions are executed afterwards, once `sync_ctx` is free again.
    let addr_and_actions = match target {
        QueryTarget::Unspecified => either::Either::Left(
            state
                .iter_mut()
                .map(|(addr, state)| {
                    // Group addresses are `Copy`, so a plain dereference
                    // suffices.
                    (*addr, state.as_mut().query_received(rng, max_response_time, now))
                })
                .collect::<Vec<_>>(),
        ),
        QueryTarget::Specified(group_addr) => either::Either::Right([(
            group_addr,
            state
                .get_mut(&group_addr)
                // Build the error lazily so nothing is constructed on the
                // success path.
                .ok_or_else(|| C::not_a_member_err(*group_addr))?
                .as_mut()
                .query_received(rng, max_response_time, now),
        )]),
    };
    for (
        group_addr,
        QueryReceivedActions { generic: generic_actions, protocol_specific: ps_actions },
    ) in addr_and_actions.into_iter()
    {
        if let Some(generic_actions) = generic_actions {
            // Whatever the timer operation returns is deliberately ignored.
            let _: Option<C::Instant> = match generic_actions {
                QueryReceivedGenericAction::ScheduleTimer(delay) => {
                    sync_ctx.schedule_timer(delay, GmpDelayedReportTimerId { device, group_addr })
                }
                QueryReceivedGenericAction::StopTimerAndSendReport(protocol_specific) => {
                    sync_ctx.send_message(
                        device,
                        group_addr,
                        GmpMessageType::Report(protocol_specific),
                    );
                    sync_ctx.cancel_timer(GmpDelayedReportTimerId { device, group_addr })
                }
            };
        }
        if let Some(ps_actions) = ps_actions {
            sync_ctx.run_actions(device, ps_actions);
        }
    }
    Ok(())
}
/// Handles GMP potentially having been enabled on `device`.
///
/// Every joined group for which GMP is not disabled is asked to join if it is
/// currently a non-member. For each group that does, a report is sent and its
/// delayed-report timer is scheduled.
///
/// # Panics
///
/// Panics if a timer was already scheduled for a group that has just joined.
fn gmp_handle_maybe_enabled<C, I, PS>(sync_ctx: &mut C, device: C::DeviceId)
where
    C: GmpContext<I, PS> + InstantContext,
    PS: ProtocolSpecific + Default,
    PS::Config: Default,
    I: Ip,
{
    // Collect the groups GMP applies to before borrowing the state mutably.
    let enabled_groups: Vec<_> = sync_ctx
        .get_state(device)
        .groups()
        .copied()
        .filter(|&group| !sync_ctx.gmp_disabled(device, group))
        .collect();

    let now = sync_ctx.now();
    let (state, rng) = sync_ctx.get_state_mut_and_rng(device);
    let mut pending = Vec::with_capacity(enabled_groups.len());
    for group in enabled_groups {
        let actions = match state.get_mut(&group) {
            Some(group_state) => group_state.as_mut().join_if_non_member(rng, now),
            None => JoinGroupActions::NOOP,
        };
        pending.push((group, actions));
    }

    for (group_addr, actions) in pending {
        let JoinGroupActions { send_report_and_schedule_timer } = actions;
        if let Some((protocol_specific, delay)) = send_report_and_schedule_timer {
            sync_ctx.send_message(device, group_addr, GmpMessageType::Report(protocol_specific));
            assert_matches!(
                sync_ctx.schedule_timer(delay, GmpDelayedReportTimerId { device, group_addr }),
                None
            );
        }
    }
}
/// Handles GMP having been disabled on `device`.
///
/// Every joined group performs the "leave group" transition while staying in
/// the set. For each group the pending delayed-report timer is cancelled and a
/// leave message is sent, as requested by the transition.
///
/// # Panics
///
/// Panics if a transition requests stopping the timer but no timer was
/// scheduled for the group.
fn gmp_handle_disabled<C, I, PS>(sync_ctx: &mut C, device: C::DeviceId)
where
    C: GmpContext<I, PS> + InstantContext,
    PS: ProtocolSpecific,
    I: Ip,
{
    let (state, _rng) = sync_ctx.get_state_mut_and_rng(device);
    let joined: Vec<_> = state.groups().cloned().collect();
    let mut pending = Vec::with_capacity(joined.len());
    for group in joined {
        let actions = match state.get_mut(&group) {
            Some(group_state) => group_state.as_mut().leave_if_member(),
            None => LeaveGroupActions::NOOP,
        };
        pending.push((group, actions));
    }

    for (group_addr, actions) in pending {
        let LeaveGroupActions { send_leave, stop_timer } = actions;
        if stop_timer {
            assert_matches!(
                sync_ctx.cancel_timer(GmpDelayedReportTimerId { device, group_addr }),
                Some(_)
            );
        }
        if send_leave {
            sync_ctx.send_message(device, group_addr, GmpMessageType::Leave);
        }
    }
}
/// Joins the multicast group `group_addr` on `device`.
///
/// If the group is newly joined and the join requests it, a report is sent and
/// the group's delayed-report timer is scheduled.
///
/// # Panics
///
/// Panics if a timer was already scheduled for the newly joined group.
fn gmp_join_group<C, I, PS>(
    sync_ctx: &mut C,
    device: C::DeviceId,
    group_addr: MulticastAddr<I::Addr>,
) -> GroupJoinResult
where
    C: GmpContext<I, PS> + InstantContext,
    PS: ProtocolSpecific + Default,
    PS::Config: Default,
    I: Ip,
{
    let now = sync_ctx.now();
    let gmp_disabled = sync_ctx.gmp_disabled(device, group_addr);
    let (state, rng) = sync_ctx.get_state_mut_and_rng(device);
    let joined = state.join_group_gmp(gmp_disabled, group_addr, rng, now);
    joined.map(|actions| {
        let JoinGroupActions { send_report_and_schedule_timer } = actions;
        if let Some((protocol_specific, delay)) = send_report_and_schedule_timer {
            let report = GmpMessageType::Report(protocol_specific);
            sync_ctx.send_message(device, group_addr, report);
            assert_matches!(
                sync_ctx.schedule_timer(delay, GmpDelayedReportTimerId { device, group_addr }),
                None
            );
        }
    })
}
/// Leaves `group_addr` on `device` and performs any resulting GMP actions.
///
/// The leave is delegated to `leave_group_gmp` on the device's group set. The
/// delayed-report timer for the group is cancelled when the resulting actions
/// set `stop_timer`, and a leave message is sent when they set `send_leave`.
///
/// Returns the `GroupLeaveResult` produced by `leave_group_gmp`, with its
/// action payload consumed.
///
/// # Panics
///
/// Panics if a timer cancellation was requested but `cancel_timer` does not
/// return `Some(_)`.
fn gmp_leave_group<C, I, PS>(
    sync_ctx: &mut C,
    device: C::DeviceId,
    group_addr: MulticastAddr<I::Addr>,
) -> GroupLeaveResult
where
    C: GmpContext<I, PS> + InstantContext,
    PS: ProtocolSpecific,
    I: Ip,
{
    let leave_result = sync_ctx.get_state_mut(device).leave_group_gmp(group_addr);

    leave_result.map(|actions| {
        let LeaveGroupActions { send_leave, stop_timer } = actions;
        if stop_timer {
            assert_matches!(
                sync_ctx.cancel_timer(GmpDelayedReportTimerId { device, group_addr }),
                Some(_)
            );
        }
        if send_leave {
            sync_ctx.send_message(device, group_addr, GmpMessageType::Leave);
        }
    })
}
#[cfg(test)]
mod test {
    use core::convert::Infallible as Never;

    use assert_matches::assert_matches;

    use super::*;
    use crate::{context::testutil::DummyInstant, testutil::new_rng};

    /// Unsolicited report interval reported by `DummyProtocolSpecific`.
    const DEFAULT_UNSOLICITED_REPORT_INTERVAL: Duration = Duration::from_secs(10);

    /// Dummy `ProtocolSpecific` for test purposes.
    #[derive(PartialEq, Eq, Copy, Clone, Debug, Default)]
    struct DummyProtocolSpecific;

    impl ProtocolSpecific for DummyProtocolSpecific {
        /// Tests for the generic state machine should not know anything about
        /// protocol specific actions, so the action type is uninhabited.
        type Actions = Never;

        /// Whether to send a leave group message if our flag is not set.
        type Config = bool;

        /// Always reports `DEFAULT_UNSOLICITED_REPORT_INTERVAL`, ignoring the
        /// configuration.
        fn cfg_unsolicited_report_interval(_send_anyway: &Self::Config) -> Duration {
            DEFAULT_UNSOLICITED_REPORT_INTERVAL
        }

        /// The configuration itself is the "send leave anyway" flag.
        fn cfg_send_leave_anyway(send_anyway: &Self::Config) -> bool {
            *send_anyway
        }

        /// Passes the response time through, yielding whatever
        /// `NonZeroDuration::new` returns for it.
        fn get_max_resp_time(resp_time: Duration) -> Option<NonZeroDuration> {
            NonZeroDuration::new(resp_time)
        }

        /// Keeps the previous protocol-specific state and never produces a
        /// protocol-specific action.
        fn do_query_received_specific(
            _send_anyway: &Self::Config,
            _max_resp_time: Duration,
            previous: Self,
        ) -> (Self, Option<Never>) {
            (previous, None)
        }
    }

    impl<P: ProtocolSpecific> GmpStateMachine<DummyInstant, P> {
        /// Returns a mutable reference to the configuration stored in the
        /// machine's current state.
        ///
        /// Panics if `inner` is `None`.
        pub(crate) fn get_config_mut(&mut self) -> &mut P::Config {
            let member_state = self.inner.as_mut().unwrap();
            // Each variant wraps a different state type, so every arm has to be
            // spelled out even though they all read the same field.
            match member_state {
                MemberState::NonMember(non_member) => &mut non_member.cfg,
                MemberState::Delaying(delaying) => &mut delaying.cfg,
                MemberState::Idle(idle) => &mut idle.cfg,
            }
        }
    }

    type DummyGmpStateMachine = GmpStateMachine<DummyInstant, DummyProtocolSpecific>;

    /// After `join_group` the machine is `Delaying` with `last_reporter` set.
    #[test]
    fn test_gmp_state_non_member_to_delay_should_set_flag() {
        let (machine, _actions) =
            DummyGmpStateMachine::join_group(&mut new_rng(0), DummyInstant::default(), false);
        if let MemberState::Delaying(delaying) = machine.get_inner() {
            assert!(delaying.get_state().last_reporter);
        } else {
            panic!("Wrong State!");
        }
    }

    /// `join_group` asks for a report plus a timer whose delay does not exceed
    /// the unsolicited report interval.
    #[test]
    fn test_gmp_state_non_member_to_delay_actions() {
        let (_machine, join_actions) =
            DummyGmpStateMachine::join_group(&mut new_rng(0), DummyInstant::default(), false);
        assert_matches!(
            join_actions,
            JoinGroupActions { send_report_and_schedule_timer: Some((DummyProtocolSpecific, delay)) }
                if delay <= DEFAULT_UNSOLICITED_REPORT_INTERVAL
        );
    }

    /// A query whose max response time is longer than the unsolicited report
    /// interval produces no actions for a freshly joined machine.
    #[test]
    fn test_gmp_state_delay_no_reset_timer() {
        let mut rng = new_rng(0);
        let (mut machine, _actions) =
            DummyGmpStateMachine::join_group(&mut rng, DummyInstant::default(), false);
        let max_resp_time = DEFAULT_UNSOLICITED_REPORT_INTERVAL + Duration::from_secs(1);
        let query_actions = machine.query_received(&mut rng, max_resp_time, DummyInstant::default());
        assert_eq!(query_actions, QueryReceivedActions { generic: None, protocol_specific: None });
    }

    /// A query with a 1ms max response time makes a freshly joined machine
    /// request a new timer.
    ///
    /// NOTE(review): the expected 1us delay presumably follows from the
    /// deterministic `new_rng(0)` sequence; confirm before changing the seed.
    #[test]
    fn test_gmp_state_delay_reset_timer() {
        let mut rng = new_rng(0);
        let (mut machine, _actions) =
            DummyGmpStateMachine::join_group(&mut rng, DummyInstant::default(), false);
        let query_actions =
            machine.query_received(&mut rng, Duration::from_millis(1), DummyInstant::default());
        let expected = QueryReceivedActions {
            generic: Some(QueryReceivedGenericAction::ScheduleTimer(Duration::from_micros(1))),
            protocol_specific: None,
        };
        assert_eq!(query_actions, expected);
    }

    /// Receiving a report while delaying stops the timer and moves the machine
    /// to `Idle` with `last_reporter` cleared.
    #[test]
    fn test_gmp_state_delay_to_idle_with_report_no_flag() {
        let (mut machine, _actions) =
            DummyGmpStateMachine::join_group(&mut new_rng(0), DummyInstant::default(), false);
        assert_eq!(machine.report_received(), ReportReceivedActions { stop_timer: true });
        if let MemberState::Idle(idle) = machine.get_inner() {
            assert!(!idle.get_state().last_reporter);
        } else {
            panic!("Wrong State!");
        }
    }

    /// Timer expiry while delaying sends a report and moves the machine to
    /// `Idle` with `last_reporter` set.
    #[test]
    fn test_gmp_state_delay_to_idle_without_report_set_flag() {
        let (mut machine, _actions) =
            DummyGmpStateMachine::join_group(&mut new_rng(0), DummyInstant::default(), false);
        let expiry_actions = machine.report_timer_expired();
        assert_eq!(expiry_actions, ReportTimerExpiredActions { send_report: DummyProtocolSpecific });
        if let MemberState::Idle(idle) = machine.get_inner() {
            assert!(idle.get_state().last_reporter);
        } else {
            panic!("Wrong State!");
        }
    }

    /// Leaving sends a leave message both right after joining (which also stops
    /// the timer) and after the report timer has expired (no timer to stop).
    #[test]
    fn test_gmp_state_leave_should_send_leave() {
        let mut rng = new_rng(0);

        // Leave straight after joining.
        let (delaying, _actions) =
            DummyGmpStateMachine::join_group(&mut rng, DummyInstant::default(), false);
        assert_eq!(
            delaying.leave_group(),
            LeaveGroupActions { send_leave: true, stop_timer: true },
        );

        // Leave after the report timer has fired.
        let (mut expired, _actions) =
            DummyGmpStateMachine::join_group(&mut rng, DummyInstant::default(), false);
        assert_eq!(
            expired.report_timer_expired(),
            ReportTimerExpiredActions { send_report: DummyProtocolSpecific }
        );
        assert_eq!(expired.leave_group(), LeaveGroupActions { send_leave: true, stop_timer: false });
    }

    /// Both ways out of the freshly joined state exercised here (leaving and
    /// receiving a report) request that the timer be stopped.
    #[test]
    fn test_gmp_state_delay_to_other_states_should_stop_timer() {
        let mut rng = new_rng(0);

        let (leaver, _actions) =
            DummyGmpStateMachine::join_group(&mut rng, DummyInstant::default(), false);
        assert_eq!(
            leaver.leave_group(),
            LeaveGroupActions { send_leave: true, stop_timer: true },
        );

        let (mut listener, _actions) =
            DummyGmpStateMachine::join_group(&mut rng, DummyInstant::default(), false);
        assert_eq!(listener.report_received(), ReportReceivedActions { stop_timer: true });
    }

    /// Joining requests a timer, and so does a query received after a report
    /// has been received.
    ///
    /// NOTE(review): the expected 1us delay presumably follows from the
    /// deterministic `new_rng(0)` sequence; confirm before changing the seed.
    #[test]
    fn test_gmp_state_other_states_to_delay_should_schedule_timer() {
        let mut rng = new_rng(0);
        let (mut machine, join_actions) =
            DummyGmpStateMachine::join_group(&mut rng, DummyInstant::default(), false);
        assert_matches!(
            join_actions,
            JoinGroupActions { send_report_and_schedule_timer: Some((DummyProtocolSpecific, delay)) }
                if delay <= DEFAULT_UNSOLICITED_REPORT_INTERVAL
        );

        assert_eq!(machine.report_received(), ReportReceivedActions { stop_timer: true });

        let query_actions =
            machine.query_received(&mut rng, Duration::from_secs(1), DummyInstant::default());
        let expected = QueryReceivedActions {
            generic: Some(QueryReceivedGenericAction::ScheduleTimer(Duration::from_micros(1))),
            protocol_specific: None,
        };
        assert_eq!(query_actions, expected);
    }

    /// With the config set to `true`, leaving sends a leave message even
    /// though `last_reporter` is cleared.
    #[test]
    fn test_gmp_state_leave_send_anyway_do_send() {
        let (mut machine, _actions) =
            DummyGmpStateMachine::join_group(&mut new_rng(0), DummyInstant::default(), false);
        *machine.get_config_mut() = true;
        assert_eq!(machine.report_received(), ReportReceivedActions { stop_timer: true });
        if let MemberState::Idle(idle) = machine.get_inner() {
            assert!(!idle.get_state().last_reporter);
        } else {
            panic!("Wrong State!");
        }
        assert_eq!(machine.leave_group(), LeaveGroupActions { send_leave: true, stop_timer: false });
    }

    /// With the config left at its initial value, leaving after a report was
    /// received requests neither a leave message nor a timer stop.
    #[test]
    fn test_gmp_state_leave_not_the_last_do_nothing() {
        let (mut machine, _actions) =
            DummyGmpStateMachine::join_group(&mut new_rng(0), DummyInstant::default(), false);
        assert_eq!(machine.report_received(), ReportReceivedActions { stop_timer: true });
        assert_eq!(
            machine.leave_group(),
            LeaveGroupActions { send_leave: false, stop_timer: false }
        );
    }
}