blob: ad8fa88c691167059267132b0c809bbaa191d47e [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Provides utilities to fold [`Stream`]s and [`TryStream`]s with a
//! short-circuited result.
use futures::{Future, FutureExt, Stream, StreamExt, TryStream, TryStreamExt};
/// Controls folding behavior.
///
/// Produced by the closure given to [`fold_while`] and [`try_fold_while`] to
/// either keep folding with a new state or stop the fold early with a result.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum FoldWhile<C, D> {
    /// Continue folding with state `C`.
    Continue(C),
    /// Short-circuit folding with result `D`.
    Done(D),
}
/// The result of folding a stream.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum FoldResult<F, R> {
    /// The stream ended with folded state `F`.
    StreamEnded(F),
    /// The stream was short-circuited with result `R`.
    ShortCircuited(R),
}

impl<F, R> FoldResult<F, R> {
    /// Transforms into [`Result`], mapping the [`FoldResult::StreamEnded`]
    /// variant into `Ok` and the [`FoldResult::ShortCircuited`] variant into
    /// `Err`.
    pub fn ended(self) -> Result<F, R> {
        match self {
            FoldResult::StreamEnded(r) => Ok(r),
            FoldResult::ShortCircuited(r) => Err(r),
        }
    }

    /// Transforms into [`Result`], mapping the [`FoldResult::ShortCircuited`]
    /// variant into `Ok` and the [`FoldResult::StreamEnded`] variant into
    /// `Err`.
    pub fn short_circuited(self) -> Result<R, F> {
        match self {
            FoldResult::StreamEnded(r) => Err(r),
            FoldResult::ShortCircuited(r) => Ok(r),
        }
    }
}

impl<F> FoldResult<F, F> {
    /// Unwraps this [`FoldResult`] into its inner value, discarding the variant
    /// information.
    ///
    /// Only available when both variants carry the same type.
    pub fn into_inner(self) -> F {
        match self {
            FoldResult::StreamEnded(r) | FoldResult::ShortCircuited(r) => r,
        }
    }
}
/// Folds the fallible stream `s` like [`TryStreamExt::try_fold`], except that
/// the closure `f` may stop the fold early by yielding [`FoldWhile::Done`].
///
/// The returned future resolves to:
/// - `Ok(`[`FoldResult::StreamEnded`]`)` holding the accumulated state once
///   `s` is exhausted;
/// - `Ok(`[`FoldResult::ShortCircuited`]`)` holding the payload of the first
///   [`FoldWhile::Done`] produced by `f`;
/// - `Err` holding the first error yielded by either `s` or `f`.
pub fn try_fold_while<S, T, D, F, Fut>(
    s: S,
    init: T,
    mut f: F,
) -> impl Future<Output = Result<FoldResult<T, D>, S::Error>>
where
    S: TryStream,
    F: FnMut(T, S::Ok) -> Fut,
    Fut: Future<Output = Result<FoldWhile<T, D>, S::Error>>,
{
    // `try_fold` has a single error channel, so both early-exit reasons are
    // tunnelled through it: `Err(Ok(d))` marks a short-circuit, while
    // `Err(Err(e))` marks a genuine failure from the stream or from `f`.
    let step = move |acc: T, item: S::Ok| {
        f(acc, item).map(|outcome| match outcome {
            Ok(FoldWhile::Continue(next)) => Ok(next),
            Ok(FoldWhile::Done(done)) => Err(Ok(done)),
            Err(err) => Err(Err(err)),
        })
    };
    s.map_err(Err).try_fold(init, step).map(|folded| {
        // Untangle the nested error back into the public result shape.
        folded.map(FoldResult::StreamEnded).or_else(|stop| stop.map(FoldResult::ShortCircuited))
    })
}
/// Folds the stream `s` like [`StreamExt::fold`], except that the closure `f`
/// may stop the fold early by yielding [`FoldWhile::Done`].
///
/// The returned future resolves to [`FoldResult::StreamEnded`] holding the
/// accumulated state once `s` is exhausted, or to
/// [`FoldResult::ShortCircuited`] holding the payload of the first
/// [`FoldWhile::Done`] produced by `f`.
pub fn fold_while<S, T, D, F, Fut>(
    s: S,
    init: T,
    mut f: F,
) -> impl Future<Output = FoldResult<T, D>>
where
    S: Stream,
    F: FnMut(T, S::Item) -> Fut,
    Fut: Future<Output = FoldWhile<T, D>>,
{
    // Lift the infallible stream into a `TryStream` whose "error" is the
    // short-circuit payload `D`, so `try_fold` halts at the first `Done`.
    let step = move |acc: T, item: S::Item| {
        f(acc, item).map(|control| match control {
            FoldWhile::Continue(next) => Ok(next),
            FoldWhile::Done(done) => Err(done),
        })
    };
    s.map(Ok)
        .try_fold(init, step)
        .map(|folded| folded.map_or_else(FoldResult::ShortCircuited, FoldResult::StreamEnded))
}
#[cfg(test)]
mod tests {
    use super::*;
    use fuchsia_async as fasync;
    use futures::channel::mpsc;
    use futures::future;

    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_short_circuit() {
        let (sender, stream) = mpsc::unbounded::<u32>();
        const STOP_AT: u32 = 5;
        let mut sum = 0;
        for i in 0..10 {
            if i < STOP_AT {
                sum += i;
            }
            let () = sender.unbounded_send(i).expect("failed to send item");
        }
        let (acc, stop) = try_fold_while(stream.map(Result::<_, ()>::Ok), 0, |acc, next| {
            future::ok(if next == STOP_AT {
                FoldWhile::Done((acc, next))
            } else {
                FoldWhile::Continue(next + acc)
            })
        })
        .await
        .expect("try_fold_while failed")
        .short_circuited()
        .expect("try_fold_while should've short-circuited");
        assert_eq!(stop, STOP_AT);
        assert_eq!(acc, sum);
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_stream_ended() {
        let (sender, stream) = mpsc::unbounded::<u32>();
        let mut sum = 0u32;
        for i in 0..10 {
            sum += i;
            let () = sender.unbounded_send(i).expect("failed to send item");
        }
        std::mem::drop(sender);
        let result =
            try_fold_while::<_, _, (), _, _>(stream.map(Result::<_, ()>::Ok), 0, |acc, next| {
                future::ok(FoldWhile::Continue(next + acc))
            })
            .await
            .expect("try_fold_while failed")
            .ended()
            .expect("try_fold_while should have seen the stream end");
        assert_eq!(result, sum);
    }

    /// A stream that closes before yielding anything must resolve to the
    /// initial state without ever invoking the closure.
    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_empty_stream() {
        const INIT: u32 = 42;
        let (sender, stream) = mpsc::unbounded::<u32>();
        // Closing the channel up front makes the stream end immediately.
        std::mem::drop(sender);
        let result = try_fold_while::<_, _, (), _, _>(
            stream.map(Result::<_, ()>::Ok),
            INIT,
            |_: u32, _: u32| async {
                panic!("closure must not be called for an empty stream");
            },
        )
        .await
        .expect("try_fold_while failed")
        .ended()
        .expect("try_fold_while should have seen the stream end");
        assert_eq!(result, INIT);
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_stream_error() {
        #[derive(Debug)]
        struct StreamErr;
        let (sender, stream) = mpsc::unbounded::<Result<u32, StreamErr>>();
        let () = sender.unbounded_send(Err(StreamErr {})).expect("failed to send item");
        let StreamErr {} = try_fold_while::<_, _, (), _, _>(stream, (), |(), _: u32| async {
            panic!("shouldn't receive error input");
        })
        .await
        .expect_err("try_fold_while should return error");
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_try_fold_while_closure_error() {
        #[derive(Debug)]
        struct StreamErr {
            item: u32,
        }
        const ERROR_ITEM: u32 = 1234;
        let (sender, stream) = mpsc::unbounded::<Result<u32, StreamErr>>();
        let () = sender.unbounded_send(Ok(ERROR_ITEM)).expect("failed to send item");
        let StreamErr { item } = try_fold_while::<_, _, (), _, _>(stream, (), |(), item| {
            future::err(StreamErr { item })
        })
        .await
        .expect_err("try_fold_while should return error");
        assert_eq!(item, ERROR_ITEM);
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_fold_while_short_circuit() {
        let (sender, stream) = mpsc::unbounded::<u32>();
        const STOP_AT: u32 = 5;
        let mut sum = 0;
        for i in 0..10 {
            if i < STOP_AT {
                sum += i;
            }
            let () = sender.unbounded_send(i).expect("failed to send item");
        }
        let (acc, stop) = fold_while(stream, 0, |acc, next| {
            future::ready(if next == STOP_AT {
                FoldWhile::Done((acc, next))
            } else {
                FoldWhile::Continue(next + acc)
            })
        })
        .await
        .short_circuited()
        .expect("fold_while should've short-circuited");
        assert_eq!(stop, STOP_AT);
        assert_eq!(acc, sum);
    }

    #[fasync::run_singlethreaded(test)]
    async fn test_fold_while_stream_ended() {
        let (sender, stream) = mpsc::unbounded::<u32>();
        let mut sum = 0u32;
        for i in 0..10 {
            sum += i;
            let () = sender.unbounded_send(i).expect("failed to send item");
        }
        std::mem::drop(sender);
        let result = fold_while::<_, _, (), _, _>(stream, 0, |acc, next| {
            future::ready(FoldWhile::Continue(next + acc))
        })
        .await
        .ended()
        .expect("fold_while should have seen the stream end");
        assert_eq!(result, sum);
    }

    /// A stream that closes before yielding anything must resolve to the
    /// initial state without ever invoking the closure.
    #[fasync::run_singlethreaded(test)]
    async fn test_fold_while_empty_stream() {
        const INIT: u32 = 42;
        let (sender, stream) = mpsc::unbounded::<u32>();
        // Closing the channel up front makes the stream end immediately.
        std::mem::drop(sender);
        let result = fold_while::<_, _, (), _, _>(stream, INIT, |_: u32, _: u32| async {
            panic!("closure must not be called for an empty stream");
        })
        .await
        .ended()
        .expect("fold_while should have seen the stream end");
        assert_eq!(result, INIT);
    }

    #[test]
    fn test_fold_result_into_inner() {
        let x = FoldResult::<u32, u32>::StreamEnded(1);
        let y = FoldResult::<u32, u32>::ShortCircuited(2);
        assert_eq!(x.into_inner(), 1);
        assert_eq!(y.into_inner(), 2);
    }

    #[test]
    fn test_fold_result_mapping() {
        type FoldResult = super::FoldResult<u32, bool>;
        assert_eq!(FoldResult::StreamEnded(1).ended(), Ok(1));
        assert_eq!(FoldResult::ShortCircuited(false).ended(), Err(false));
        assert_eq!(FoldResult::StreamEnded(2).short_circuited(), Err(2));
        assert_eq!(FoldResult::ShortCircuited(true).short_circuited(), Ok(true));
    }
}