// Copyright 2023 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use crate::Task;
use futures::{channel::mpsc, Future, StreamExt};
/// Errors that can be returned by this crate.
#[derive(Debug, thiserror::Error)]
enum Error {
/// Returned when a task cannot be added to a [`TaskGroup`] or [`TaskSink`].
#[error("Failed to add Task: {0}")]
GroupDropped(#[from] mpsc::TrySendError<Task<()>>),
}
/// Allows the user to spawn multiple Tasks and await them as a unit.
///
/// Tasks can be added to this group using [`TaskGroup::add`].
/// All pending tasks in the group can be awaited using [`TaskGroup::join`].
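///
/// # Examples
///
/// A minimal usage sketch (assuming this crate is available as `fuchsia_async` and the
/// code runs on an executor, e.g. via `SendExecutor::run` as in the tests below):
///
/// ```ignore
/// use fuchsia_async::TaskGroup;
///
/// let mut group = TaskGroup::new();
/// group.spawn(async {
///     // Do some asynchronous work here.
/// });
/// // Wait for every task in the group to finish.
/// group.join().await;
/// ```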
pub struct TaskGroup {
sink: TaskSink,
// A Task that waits for all tasks sent on `sink` to complete.
// `sink` writes tasks to a channel and this Task drains the channel, polling the
// received tasks with unbounded concurrency. Therefore, `sink` must be dropped to
// close the channel and allow this future to complete.
done: Task<()>,
}
impl TaskGroup {
/// Creates a new TaskGroup.
///
/// The TaskGroup can be used to await an arbitrary number of Tasks and may
/// consume an arbitrary amount of memory.
pub fn new() -> Self {
let (tx, rx) = mpsc::unbounded::<Task<()>>();
let sink = TaskSink::new(tx);
let done = Task::spawn(async move {
rx.for_each_concurrent(None, |task| task).await;
});
Self { sink, done }
}
/// Adds a Task to this TaskGroup.
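///
/// A minimal sketch (executor setup is assumed, as in the tests below):
///
/// ```ignore
/// let mut group = TaskGroup::new();
/// // Spawn the task eagerly, then hand its handle to the group.
/// group.add(Task::spawn(async { /* asynchronous work */ }));
/// group.join().await;
/// ```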
pub fn add(&mut self, task: Task<()>) {
// This only panics if the receiver is closed. TaskSink is private, so the only
// way to close the receiver is with TaskGroup::join() which consumes this TaskGroup.
// Therefore this method can't be called after the receiver is closed and should
// never panic.
self.sink.try_add(task).unwrap();
}
/// Spawns a new task in this TaskGroup.
///
/// To add a future that is not [`Send`] to this TaskGroup, use [`TaskGroup::add`].
///
/// # Panics
///
/// `spawn` may panic if not called in the context of an executor (e.g.
/// within a call to `run` or `run_singlethreaded`).
pub fn spawn(&mut self, future: impl Future<Output = ()> + Send + 'static) {
self.add(Task::spawn(future));
}
/// Waits for all Tasks in this TaskGroup to finish.
///
/// Call this only after all Tasks have been added.
pub async fn join(self) {
// Drop the sink to close the channel so the receiving stream, and thus `done`, can complete.
drop(self.sink);
self.done.await;
}
}
/// Adds tasks to a remote [`TaskGroup`].
///
/// This sink must be dropped before the corresponding done future on the original
/// [`TaskGroup`] can complete.
struct TaskSink {
sender: mpsc::UnboundedSender<Task<()>>,
}
impl TaskSink {
fn new(sender: mpsc::UnboundedSender<Task<()>>) -> Self {
Self { sender }
}
/// Adds a task to this sink.
pub fn try_add(&mut self, task: Task<()>) -> Result<(), Error> {
self.sender.unbounded_send(task).map_err(Error::GroupDropped)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::SendExecutor;
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
// Notifies a channel when dropped, signifying completion of some operation.
#[derive(Clone)]
struct DoneSignaler {
done: mpsc::UnboundedSender<()>,
}
impl Drop for DoneSignaler {
fn drop(&mut self) {
self.done.unbounded_send(()).unwrap();
self.done.disconnect();
}
}
// Waits for a group of `impl Drop` objects to signal completion by being dropped.
// Create as many `impl Drop` objects as needed with `WaitGroup::add_one` and
// call `wait` to wait for all of them to be dropped.
struct WaitGroup {
tx: mpsc::UnboundedSender<()>,
rx: mpsc::UnboundedReceiver<()>,
}
impl WaitGroup {
fn new() -> Self {
let (tx, rx) = mpsc::unbounded();
Self { tx, rx }
}
fn add_one(&self) -> impl Drop {
DoneSignaler { done: self.tx.clone() }
}
async fn wait(self) {
drop(self.tx);
self.rx.collect::<()>().await;
}
}
#[test]
fn test_task_group_join_waits_for_tasks() {
let task_count = 20;
SendExecutor::new(task_count).run(async move {
let mut task_group = TaskGroup::new();
let value = Arc::new(AtomicU64::new(0));
for _ in 0..task_count {
let value = value.clone();
let task = Task::spawn(async move {
value.fetch_add(1, Ordering::Relaxed);
});
task_group.add(task);
}
task_group.join().await;
assert_eq!(value.load(Ordering::Relaxed), task_count as u64);
});
}
#[test]
fn test_task_group_empty_join_completes() {
SendExecutor::new(1).run(async move {
TaskGroup::new().join().await;
});
}
#[test]
fn test_task_group_added_tasks_are_cancelled_on_drop() {
let wait_group = WaitGroup::new();
let task_count = 10;
SendExecutor::new(task_count).run(async move {
let mut task_group = TaskGroup::new();
for _ in 0..task_count {
let done_signaler = wait_group.add_one();
// Never completes but drops `done_signaler` when cancelled.
let task = Task::spawn(async move {
// Take ownership of done_signaler.
let _done_signaler = done_signaler;
std::future::pending::<()>().await;
});
task_group.add(task);
}
drop(task_group);
wait_group.wait().await;
// If we get here, all tasks were cancelled.
});
}
#[test]
fn test_task_group_spawn() {
let task_count = 3;
SendExecutor::new(task_count).run(async move {
let mut task_group = TaskGroup::new();
// We can spawn tasks from any `Future<Output = ()>` implementation, including...
// ... naked futures.
task_group.spawn(std::future::ready(()));
// ... futures returned from async blocks.
task_group.spawn(async move {
std::future::ready(()).await;
});
// ... and other tasks.
task_group.spawn(Task::spawn(std::future::ready(())));
task_group.join().await;
});
}
}