blob: 634f30fdc2c8a214aac4d99524b10c8684579e3e [file] [log] [blame]
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Provides libs for connecting to and interacting with a FIDL protocol.
//!
//! If you have a fidl protocol like this:
//!
//! ```fidl
//! type Error = strict enum : int32 {
//! PERMANENT = 1;
//! TRANSIENT = 2;
//! };
//!
//! @discoverable
//! protocol ProtocolFactory {
//! CreateProtocol(resource struct {
//! protocol server_end:Protocol;
//! }) -> () error Error;
//! };
//!
//! protocol Protocol {
//! DoAction() -> () error Error;
//! };
//! ```
//!
//! Then you could implement ConnectedProtocol as follows:
//!
//! ```rust
//! struct ProtocolConnectedProtocol;
//! impl ConnectedProtocol for ProtocolConnectedProtocol {
//! type Protocol = ProtocolProxy;
//! type ConnectError = anyhow::Error;
//! type Message = ();
//! type SendError = anyhow::Error;
//!
//! fn get_protocol<'a>(
//! &'a mut self,
//! ) -> BoxFuture<'a, Result<Self::Protocol, Self::ConnectError>> {
//! async move {
//! let (protocol_proxy, server_end) =
//! fidl::endpoints::create_proxy().context("creating server endpoints failed")?;
//! let protocol_factory = connect_to_protocol::<ProtocolFactoryMarker>()
//! .context("Failed to connect to test.protocol.ProtocolFactory")?;
//!
//! protocol_factory
//! .create_protocol(server_end)
//! .await?
//! .map_err(|e| format_err!("Failed to create protocol: {:?}", e))?;
//!
//! Ok(protocol_proxy)
//! }
//! .boxed()
//! }
//!
//! fn send_message<'a>(
//! &'a mut self,
//! protocol: &'a Self::Protocol,
//! _msg: (),
//! ) -> BoxFuture<'a, Result<(), Self::SendError>> {
//! async move {
//! protocol.do_action().await?.map_err(|e| format_err!("Failed to do action: {:?}", e))?;
//! Ok(())
//! }
//! .boxed()
//! }
//! }
//! ```
//!
//! Then all you would have to do to connect to the service is:
//!
//! ```rust
//! let connector = ProtocolConnector::new(ProtocolConnectedProtocol);
//! let (sender, future) = connector.serve_and_log_errors();
//! let future = Task::spawn(future);
//! // Use sender to send messages to the protocol
//! ```
use fuchsia_async::{self as fasync, DurationExt};
use fuchsia_zircon as zx;
use futures::channel::mpsc;
use futures::future::BoxFuture;
use futures::{Future, StreamExt};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tracing::error;
/// A trait for implementing connecting to and sending messages to a FIDL protocol.
pub trait ConnectedProtocol {
/// The protocol that will be connected to.
type Protocol: fidl::endpoints::Proxy;
/// An error type returned for connection failures.
type ConnectError: std::fmt::Display;
/// The message type that will be forwarded to the `Protocol`.
type Message;
/// An error type returned for message send failures.
type SendError: std::fmt::Display;
/// Connects to the protocol represented by `Protocol`.
///
/// If this is a two-step process as in the case of the ServiceHub pattern,
/// both steps should be performed in this function.
fn get_protocol<'a>(&'a mut self) -> BoxFuture<'a, Result<Self::Protocol, Self::ConnectError>>;
/// Sends a message to the underlying `Protocol`.
///
/// The protocol object should be assumed to be connected.
fn send_message<'a>(
&'a mut self,
protocol: &'a Self::Protocol,
msg: Self::Message,
) -> BoxFuture<'a, Result<(), Self::SendError>>;
}
/// A ProtocolSender wraps around an `mpsc::Sender` object that is used to send
/// messages to a running ProtocolConnector instance.
#[derive(Clone, Debug)]
pub struct ProtocolSender<Msg> {
sender: mpsc::Sender<Msg>,
is_blocked: Arc<AtomicBool>,
}
/// Returned by ProtocolSender::send to notify the caller about the state of the underlying mpsc::channel.
/// None of these status codes should be considered an error state, they are purely informational.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ProtocolSenderStatus {
/// channel is accepting new messages.
Healthy,
/// channel has rejected its first message.
BackoffStarts,
/// channel is not accepting new messages.
InBackoff,
/// channel has begun accepting messages again.
BackoffEnds,
}
impl<Msg> ProtocolSender<Msg> {
/// Create a new ProtocolSender which will use `sender` to send messages.
pub fn new(sender: mpsc::Sender<Msg>) -> Self {
Self { sender, is_blocked: Arc::new(AtomicBool::new(false)) }
}
/// Send a message to the underlying channel.
///
/// When the sender enters or exits a backoff state, it will log an error,
/// but no other feedback will be provided to the caller.
pub fn send(&mut self, message: Msg) -> ProtocolSenderStatus {
if self.sender.try_send(message).is_err() {
let was_blocked =
self.is_blocked.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst);
if let Ok(false) = was_blocked {
ProtocolSenderStatus::BackoffStarts
} else {
ProtocolSenderStatus::InBackoff
}
} else {
let was_blocked =
self.is_blocked.compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst);
if let Ok(true) = was_blocked {
ProtocolSenderStatus::BackoffEnds
} else {
ProtocolSenderStatus::Healthy
}
}
}
}
struct ExponentialBackoff {
initial: zx::Duration,
current: zx::Duration,
factor: f64,
}
impl ExponentialBackoff {
fn new(initial: zx::Duration, factor: f64) -> Self {
Self { initial, current: initial, factor }
}
fn next_timer(&mut self) -> fasync::Timer {
let timer = fasync::Timer::new(self.current.after_now());
self.current =
zx::Duration::from_nanos((self.current.into_nanos() as f64 * self.factor) as i64);
timer
}
fn reset(&mut self) {
self.current = self.initial;
}
}
/// Errors encountered while connecting to or sending messages to the ConnectedProtocol implementation.
#[derive(Debug, PartialEq, Eq)]
pub enum ProtocolConnectorError<ConnectError, ProtocolError> {
/// Connecting to the protocol failed for some reason.
ConnectFailed(ConnectError),
/// Connection to the protocol was dropped. A reconnect will be triggered.
ConnectionLost,
/// The protocol returned an error while sending a message.
ProtocolError(ProtocolError),
}
/// ProtocolConnector contains the logic to use a `ConnectedProtocol` to connect
/// to and forward messages to a protocol.
pub struct ProtocolConnector<CP: ConnectedProtocol> {
/// The size of the `mpsc::channel` to use when sending event objects from the main thread to the worker thread.
pub buffer_size: usize,
protocol: CP,
}
impl<CP: ConnectedProtocol> ProtocolConnector<CP> {
/// Construct a ProtocolConnector with the default `buffer_size` (10)
pub fn new(protocol: CP) -> Self {
Self::new_with_buffer_size(protocol, 10)
}
/// Construct a ProtocolConnector with a specified `buffer_size`
pub fn new_with_buffer_size(protocol: CP, buffer_size: usize) -> Self {
Self { buffer_size, protocol }
}
/// serve_and_log_errors creates both a ProtocolSender and a future that can
/// be used to send messages to the underlying protocol. All errors from the
/// underlying protocol will be logged.
pub fn serve_and_log_errors(self) -> (ProtocolSender<CP::Message>, impl Future<Output = ()>) {
let protocol = <<<CP as ConnectedProtocol>::Protocol as fidl::endpoints::Proxy>::Protocol as fidl::endpoints::ProtocolMarker>::DEBUG_NAME;
let mut log_error = log_first_n_factory(30, move |e| error!(%protocol, "{}", e));
self.serve(move |e| match e {
ProtocolConnectorError::ConnectFailed(e) => {
log_error(format!("Error obtaining a connection to the protocol: {}", e))
}
ProtocolConnectorError::ConnectionLost => {
log_error("Protocol disconnected, starting reconnect.".into())
}
ProtocolConnectorError::ProtocolError(e) => {
log_error(format!("Protocol returned an error: {}", e))
}
})
}
/// serve creates both a ProtocolSender and a future that can be used to send
/// messages to the underlying protocol.
pub fn serve<ErrHandler: FnMut(ProtocolConnectorError<CP::ConnectError, CP::SendError>)>(
self,
h: ErrHandler,
) -> (ProtocolSender<CP::Message>, impl Future<Output = ()>) {
let (sender, receiver) = mpsc::channel(self.buffer_size);
let sender = ProtocolSender::new(sender);
(sender, self.send_events(receiver, h))
}
async fn send_events<
ErrHandler: FnMut(ProtocolConnectorError<CP::ConnectError, CP::SendError>),
>(
mut self,
mut receiver: mpsc::Receiver<<CP as ConnectedProtocol>::Message>,
mut h: ErrHandler,
) {
let mut backoff = ExponentialBackoff::new(zx::Duration::from_millis(100), 2.0);
loop {
let protocol = match self.protocol.get_protocol().await {
Ok(protocol) => protocol,
Err(e) => {
h(ProtocolConnectorError::ConnectFailed(e));
backoff.next_timer().await;
continue;
}
};
'receiving: loop {
match receiver.next().await {
Some(message) => {
let resp = self.protocol.send_message(&protocol, message).await;
match resp {
Ok(_) => {
backoff.reset();
continue;
}
Err(e) => {
if fidl::endpoints::Proxy::is_closed(&protocol) {
h(ProtocolConnectorError::ConnectionLost);
break 'receiving;
} else {
h(ProtocolConnectorError::ProtocolError(e));
}
}
}
}
None => return,
}
}
backoff.next_timer().await;
}
}
}
fn log_first_n_factory(n: u64, mut log_fn: impl FnMut(String)) -> impl FnMut(String) {
let mut count = 0;
move |message| {
if count < n {
count += 1;
log_fn(message);
}
}
}
#[cfg(test)]
mod test {
use super::*;
use anyhow::{format_err, Context};
use fidl_test_protocol_connector::{
ProtocolFactoryMarker, ProtocolFactoryRequest, ProtocolFactoryRequestStream, ProtocolProxy,
ProtocolRequest, ProtocolRequestStream,
};
use fuchsia_async as fasync;
use fuchsia_component::server as fserver;
use fuchsia_component_test::{
Capability, ChildOptions, LocalComponentHandles, RealmBuilder, RealmInstance, Ref, Route,
};
use futures::channel::mpsc::Sender;
use futures::{FutureExt, TryStreamExt};
use std::sync::atomic::AtomicU8;
struct ProtocolConnectedProtocol(RealmInstance, Sender<()>);
impl ConnectedProtocol for ProtocolConnectedProtocol {
type Protocol = ProtocolProxy;
type ConnectError = anyhow::Error;
type Message = ();
type SendError = anyhow::Error;
fn get_protocol<'a>(
&'a mut self,
) -> BoxFuture<'a, Result<Self::Protocol, Self::ConnectError>> {
async move {
let (protocol_proxy, server_end) =
fidl::endpoints::create_proxy().context("creating server endpoints failed")?;
let protocol_factory = self
.0
.root
.connect_to_protocol_at_exposed_dir::<ProtocolFactoryMarker>()
.context("Connecting to test.protocol.ProtocolFactory failed")?;
protocol_factory
.create_protocol(server_end)
.await?
.map_err(|e| format_err!("Failed to create protocol: {:?}", e))?;
Ok(protocol_proxy)
}
.boxed()
}
fn send_message<'a>(
&'a mut self,
protocol: &'a Self::Protocol,
_msg: (),
) -> BoxFuture<'a, Result<(), Self::SendError>> {
async move {
protocol
.do_action()
.await?
.map_err(|e| format_err!("Failed to do action: {:?}", e))?;
self.1.try_send(())?;
Ok(())
}
.boxed()
}
}
async fn protocol_mock(
stream: ProtocolRequestStream,
calls_made: Arc<AtomicU8>,
close_after: Option<Arc<AtomicU8>>,
) -> Result<(), anyhow::Error> {
stream
.map(|result| result.context("failed request"))
.try_for_each(|request| async {
let calls_made = calls_made.clone();
let close_after = close_after.clone();
match request {
ProtocolRequest::DoAction { responder } => {
calls_made.fetch_add(1, Ordering::SeqCst);
responder.send(Ok(()))?;
}
}
if let Some(ca) = &close_after {
if ca.fetch_sub(1, Ordering::SeqCst) == 1 {
return Err(format_err!("close_after triggered"));
}
}
Ok(())
})
.await
}
async fn protocol_factory_mock(
handles: LocalComponentHandles,
calls_made: Arc<AtomicU8>,
close_after: Option<u8>,
) -> Result<(), anyhow::Error> {
let mut fs = fserver::ServiceFs::new();
let mut tasks = vec![];
fs.dir("svc").add_fidl_service(move |mut stream: ProtocolFactoryRequestStream| {
let calls_made = calls_made.clone();
tasks.push(fasync::Task::local(async move {
while let Some(ProtocolFactoryRequest::CreateProtocol { protocol, responder }) =
stream.try_next().await.expect("ProtocolFactoryRequestStream yielded an Err(_)")
{
let close_after = close_after.map(|ca| Arc::new(AtomicU8::new(ca)));
responder.send(Ok(())).expect("Replying to CreateProtocol caller failed");
let _ = protocol_mock(
protocol.into_stream().expect("Converting ServerEnd to stream failed"),
calls_made.clone(),
close_after,
)
.await;
}
}));
});
fs.serve_connection(handles.outgoing_dir)?;
fs.collect::<()>().await;
Ok(())
}
async fn setup_realm(
calls_made: Arc<AtomicU8>,
close_after: Option<u8>,
) -> Result<RealmInstance, anyhow::Error> {
let builder = RealmBuilder::new().await?;
let protocol_factory_server = builder
.add_local_child(
"protocol_factory",
move |handles: LocalComponentHandles| {
Box::pin(protocol_factory_mock(handles, calls_made.clone(), close_after))
},
ChildOptions::new(),
)
.await?;
builder
.add_route(
Route::new()
.capability(Capability::protocol_by_name(
"test.protocol.connector.ProtocolFactory",
))
.from(&protocol_factory_server)
.to(Ref::parent()),
)
.await?;
Ok(builder.build().await?)
}
#[fuchsia::test(logging_tags = ["test_protocol_connector"])]
async fn test_protocol_connector() -> Result<(), anyhow::Error> {
let calls_made = Arc::new(AtomicU8::new(0));
let realm = setup_realm(calls_made.clone(), None).await?;
let (log_received_sender, mut log_received_receiver) = mpsc::channel(1);
let connector = ProtocolConnectedProtocol(realm, log_received_sender);
let error_count = Arc::new(AtomicU8::new(0));
let svc = ProtocolConnector::new(connector);
let (mut sender, fut) = svc.serve({
let count = error_count.clone();
move |e| {
error!("Encountered unexpected error: {:?}", e);
count.fetch_add(1, Ordering::SeqCst);
}
});
let _server = fasync::Task::local(fut);
for _ in 0..10 {
assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
log_received_receiver.next().await;
}
assert_eq!(calls_made.fetch_add(0, Ordering::SeqCst), 10);
assert_eq!(error_count.fetch_add(0, Ordering::SeqCst), 0);
Ok(())
}
#[fuchsia::test(logging_tags = ["test_protocol_reconnnect"])]
async fn test_protocol_reconnect() -> Result<(), anyhow::Error> {
let calls_made = Arc::new(AtomicU8::new(0));
// Simulate the protocol closing after each successful call.
let realm = setup_realm(calls_made.clone(), Some(1)).await?;
let (log_received_sender, mut log_received_receiver) = mpsc::channel(1);
let connector = ProtocolConnectedProtocol(realm, log_received_sender);
let svc = ProtocolConnector::new(connector);
let (mut err_send, mut err_rcv) = mpsc::channel(1);
let (mut sender, fut) = svc.serve(move |e| {
err_send.try_send(e).expect("Could not log error");
});
let _server = fasync::Task::local(fut);
for _ in 0..10 {
// This first send will successfully call the underlying protocol.
assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
log_received_receiver.next().await;
// The second send will not, because the protocol has shut down.
assert_eq!(sender.send(()), ProtocolSenderStatus::Healthy);
match err_rcv.next().await.expect("Expected err") {
ProtocolConnectorError::ConnectionLost => {}
_ => {
assert!(false, "saw unexpected error type");
}
}
}
assert_eq!(calls_made.fetch_add(0, Ordering::SeqCst), 10);
Ok(())
}
}