blob: 58743a16b5e188fd07677b7df515cb5760d6a298 [file] [log] [blame]
// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use {
crate::Event,
futures::{
channel::{mpsc, oneshot},
future::Future,
lock::Mutex,
stream::Stream,
task::{Context, Poll},
},
hyper::{
body::{Body, Bytes},
Response, StatusCode,
},
std::{mem::replace, ops::DerefMut, pin::Pin, sync::Arc},
};
pub struct SseResponseCreator {
buffer_size: usize,
clients: Arc<Mutex<Vec<Client>>>,
}
impl SseResponseCreator {
/// hyper `Response` `Body`s created by this `SseResponseCreator` will buffer
/// `buffer_size + 1` `Events` before the `Body` stream is closed for falling too far behind.
pub fn with_additional_buffer_size(buffer_size: usize) -> (Self, EventSender) {
let clients = Arc::new(Mutex::new(vec![]));
(Self { buffer_size, clients: Arc::clone(&clients) }, EventSender { clients })
}
/// Creates hyper `Response`s whose `Body`s receive `Event`s from the `EventSender` associated
/// with this `SseResponseCreator`.
pub async fn create(&self) -> Response<Body> {
let (abort_tx, abort_rx) = oneshot::channel();
let (chunk_tx, chunk_rx) = mpsc::channel(self.buffer_size);
self.clients.lock().await.push(Client { abort_tx, chunk_tx });
Response::builder()
.status(StatusCode::OK)
.header("content-type", "text/event-stream")
.body(Body::wrap_stream(BodyAbortStream { abort_rx, chunk_rx }))
.unwrap() // builder arguments are all statically determined, build will not fail
}
}
pub struct EventSender {
clients: Arc<Mutex<Vec<Client>>>,
}
impl EventSender {
/// Send an `Event` to each connected client. Clients that have fallen too far behind have
/// their connections closed.
pub async fn send(&self, event: &Event) {
let mut clients_guard = self.clients.lock().await;
let clients = replace(DerefMut::deref_mut(&mut clients_guard), vec![]);
let clients = clients
.into_iter()
.filter_map(|mut c| {
if c.try_send(event).is_ok() {
Some(c)
} else {
let _ = c.abort();
None
}
})
.collect();
*clients_guard = clients;
}
/// Number of clients that `send` will attempt to communicate with.
pub async fn client_count(&self) -> usize {
self.clients.lock().await.len()
}
/// Drops all connected clients. Already existing `Response<Body>`s created by the
/// `SseResponseCreator` should return error on subsequent `poll_next`.
pub async fn drop_all_clients(&self) {
self.clients.lock().await.clear();
}
}
// reimplementation of the body created by hyper::body::body::channel() b/c hyper doesn't allow
// specifying the buffer size and doesn't provide an abort channel.
struct BodyAbortStream {
abort_rx: oneshot::Receiver<()>,
chunk_rx: mpsc::Receiver<Bytes>,
}
impl Stream for BodyAbortStream {
type Item = Result<Bytes, &'static str>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if let Poll::Ready(_) = Pin::new(&mut self.abort_rx).poll(cx) {
return Poll::Ready(Some(Err("client dropped")));
}
match Pin::new(&mut self.chunk_rx).poll_next(cx) {
Poll::Ready(Some(chunk)) => Poll::Ready(Some(Ok(chunk))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
// Reimplementation of hyper::body::Sender b/c in-tree hyper doesn't allow specifying the buffer
// size and doesn't provide an abort channel.
struct Client {
abort_tx: oneshot::Sender<()>,
chunk_tx: mpsc::Sender<Bytes>,
}
impl Client {
fn try_send(&mut self, event: &Event) -> Result<(), ()> {
self.chunk_tx.try_send(event.to_vec().into()).map_err(|_| ())
}
fn abort(self) {
let _ = self.abort_tx.send(());
}
}
#[cfg(test)]
mod tests {
use {
super::*,
assert_matches::assert_matches,
fuchsia_async::{self as fasync},
futures::StreamExt,
};
#[fasync::run_singlethreaded(test)]
async fn response_headers() {
let (sse_response_creator, _) = SseResponseCreator::with_additional_buffer_size(0);
let resp = sse_response_creator.create().await;
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get("content-type").map(|h| h.as_bytes()),
Some(&b"text/event-stream"[..])
);
}
#[fasync::run_singlethreaded(test)]
async fn response_correct_body_single_event() {
let event = Event::from_type_and_data("event_type", "data_contents").unwrap();
let (sse_response_creator, event_sender) =
SseResponseCreator::with_additional_buffer_size(0);
let resp = sse_response_creator.create().await;
event_sender.send(&event).await;
let mut body_stream = resp.into_body();
let body_bytes = body_stream.next().await;
assert_eq!(body_bytes.unwrap().unwrap().to_vec(), event.to_vec());
}
#[fasync::run_singlethreaded(test)]
async fn full_client_dropped_other_clients_continue_to_receive_events() {
let event0 = Event::from_type_and_data("event_type0", "data_contents0").unwrap();
let (sse_response_creator, event_sender) =
SseResponseCreator::with_additional_buffer_size(0);
assert_eq!(event_sender.client_count().await, 0);
let mut body_stream0 = sse_response_creator.create().await.into_body();
let mut body_stream1 = sse_response_creator.create().await.into_body();
assert_eq!(event_sender.client_count().await, 2);
event_sender.send(&event0).await;
let body_bytes1 = body_stream1.next().await;
assert_matches!(body_bytes1, Some(Ok(chunk)) if chunk.to_vec() == event0.to_vec());
let event1 = Event::from_type_and_data("event_type1", "data_contents1").unwrap();
event_sender.send(&event1).await;
assert_eq!(event_sender.client_count().await, 1);
let body_bytes0 = body_stream0.next().await;
assert_matches!(body_bytes0, Some(Err(_)));
let body_bytes1 = body_stream1.next().await;
assert_eq!(body_bytes1.unwrap().unwrap().to_vec(), event1.to_vec());
}
#[fasync::run_singlethreaded(test)]
async fn drop_all_clients() {
let (sse_response_creator, event_sender) =
SseResponseCreator::with_additional_buffer_size(0);
let mut body_stream0 = sse_response_creator.create().await.into_body();
let mut body_stream1 = sse_response_creator.create().await.into_body();
assert_eq!(event_sender.client_count().await, 2);
event_sender.send(&Event::from_type_and_data("event_type", "data_contents").unwrap()).await;
event_sender.drop_all_clients().await;
assert_eq!(event_sender.client_count().await, 0);
assert_matches!(body_stream0.next().await, Some(Err(_)));
assert_matches!(body_stream1.next().await, Some(Err(_)));
}
}