blob: a6e426bc403e6bed48f266dea7f1ef0807990e8e [file] [log] [blame]
// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use crate::Never;
use futures::{
future::{Future, FutureExt, FutureObj},
ready,
task::{LocalWaker, Poll},
};
use std::{marker::Unpin, pin::Pin};
pub struct State<E>(FutureObj<'static, Result<State<E>, E>>);
pub struct StateMachine<E> {
cur_state: State<E>,
}
impl<E> Unpin for StateMachine<E> {}
impl<E> Future for StateMachine<E> {
type Output = Result<Never, E>;
fn poll(mut self: Pin<&mut Self>, lw: &LocalWaker) -> Poll<Self::Output> {
loop {
match ready!(self.cur_state.0.poll_unpin(lw)) {
Ok(next) => self.cur_state = next,
Err(e) => return Poll::Ready(Err(e)),
}
}
}
}
pub trait IntoStateExt<E>: Future<Output = Result<State<E>, E>> {
fn into_state(self) -> State<E>
where
Self: Sized + Send + 'static,
{
State(FutureObj::new(Box::new(self)))
}
fn into_state_machine(self) -> StateMachine<E>
where
Self: Sized + Send + 'static,
{
StateMachine { cur_state: self.into_state() }
}
}
impl<F, E> IntoStateExt<E> for F where F: Future<Output = Result<State<E>, E>> {}
#[cfg(test)]
mod tests {
use super::*;
use fuchsia_async as fasync;
use futures::{channel::mpsc, stream::StreamExt};
use std::mem;
#[test]
fn state_machine() {
let mut exec = fasync::Executor::new().expect("Failed to create an executor");
let (sender, receiver) = mpsc::unbounded();
let mut state_machine = sum_state(0, receiver).into_state_machine();
let r = exec.run_until_stalled(&mut state_machine);
assert_eq!(Poll::Pending, r);
sender.unbounded_send(2).unwrap();
sender.unbounded_send(3).unwrap();
mem::drop(sender);
let r = exec.run_until_stalled(&mut state_machine);
assert_eq!(Poll::Ready(Err(5)), r);
}
async fn sum_state(
current: u32,
stream: mpsc::UnboundedReceiver<u32>,
) -> Result<State<u32>, u32> {
let (number, stream) = await!(stream.into_future());
match number {
Some(number) => Ok(make_sum_state(current + number, stream)),
None => Err(current),
}
}
// A workaround for the "recursive impl Trait" problem in the compiler
fn make_sum_state(current: u32, stream: mpsc::UnboundedReceiver<u32>) -> State<u32> {
sum_state(current, stream).into_state()
}
}