blob: 9af6a6366884219fb3342b77ea914f059b9942b1 [file] [edit]
// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use crate::{Context, StreamHandler};
use anyhow::Result;
use fidl::server::ServeInner;
use fuchsia_async::Task;
use std::cell::RefCell;
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Weak};
use thiserror::Error;
pub type NameToStreamHandlerMap = HashMap<String, Box<dyn StreamHandler>>;
type ProtocolHandleMap = HashMap<usize, ProtocolHandle>;
type NameToProtocolHandleMap = HashMap<String, ProtocolHandleMap>;
#[derive(Default)]
struct ProtocolRegisterInner {
protocol_map: NameToStreamHandlerMap,
handles: RefCell<NameToProtocolHandleMap>,
next_task_id: AtomicUsize,
stopping: AtomicBool,
}
impl ProtocolRegisterInner {
fn drain_handles(&self) -> Vec<LabeledProtocolHandle> {
let mut res = Vec::new();
for (name, mut map) in self.handles.borrow_mut().drain() {
for (id, t) in map.drain() {
let name = name.clone();
res.push(LabeledProtocolHandle { name, id, inner: t });
}
}
res
}
fn remove_handle(&self, s: &String, id: &usize) -> Option<ProtocolHandle> {
self.handles.borrow_mut().get_mut(s).and_then(|e| e.remove(id))
}
}
struct ProtocolHandle {
// This is just the handle to the stream that is being handled inside
// `self.task`. Ideally the task will complete in a reasonable amount of
// time after this is called.
handle: Weak<ServeInner>,
task: Task<Result<()>>,
}
/// Protocol handle with more debug information.
struct LabeledProtocolHandle {
name: String,
id: usize,
inner: ProtocolHandle,
}
impl LabeledProtocolHandle {
pub(crate) async fn shutdown(self) -> Result<()> {
self.inner.shutdown().await
}
}
impl ProtocolHandle {
pub(crate) async fn shutdown(self) -> Result<()> {
if let Some(handle) = self.handle.upgrade() {
handle.shutdown();
self.task.await
} else {
Ok(())
}
}
}
#[derive(Error, Debug)]
pub enum ProtocolError {
#[error("protocol error: {0:?}")]
StreamOpenError(#[from] anyhow::Error),
#[error("bad protocol register state: {0:?}")]
BadRegisterState(String),
#[error("could not find protocol under the name: {0}")]
NoProtocolFound(String),
#[error("duplicate task id found under protocol {0}: {1}")]
DuplicateTaskId(String, usize),
}
#[derive(Default, Clone)]
pub struct ProtocolRegister {
inner: Rc<ProtocolRegisterInner>,
}
impl ProtocolRegister {
pub fn new(map: NameToStreamHandlerMap) -> Self {
// TODO(awdavies): Start the static protocols. Probably need the daemon
// to just do this on its own.
Self { inner: Rc::new(ProtocolRegisterInner { protocol_map: map, ..Default::default() }) }
}
/// Returns an error if `self.stopping` has been set to true, otherwise
/// returns `Ok(())`.
fn invariant_check(&self) -> Result<(), ProtocolError> {
if self.inner.stopping.load(Ordering::SeqCst) {
return Err(ProtocolError::BadRegisterState(
"Cannot start any protocols. Shutting down".to_string(),
));
}
Ok(())
}
pub async fn start(&self, name: String, cx: Context) -> Result<(), ProtocolError> {
self.invariant_check()?;
let svc = self
.inner
.protocol_map
.get(&name)
.ok_or(ProtocolError::NoProtocolFound(name.clone()))?;
svc.start(cx).await.map_err(Into::into)
}
pub async fn open(
&self,
name: String,
cx: Context,
server_channel: fidl::AsyncChannel,
) -> Result<(), ProtocolError> {
self.invariant_check()?;
let task_id = self.inner.next_task_id.fetch_add(1, Ordering::SeqCst);
let svc = self
.inner
.protocol_map
.get(&name)
.ok_or(ProtocolError::NoProtocolFound(name.clone()))?;
let weak_inner = Rc::downgrade(&self.inner);
let server = Arc::new(ServeInner::new(server_channel));
let weak_server = Arc::downgrade(&server);
let name_copy = name.clone();
let fut = svc.open(cx, server).await?;
let new_task = async move {
fut.await.unwrap_or_else(|e| log::warn!("running protocol stream handler: {:#?}", e));
if let Some(inner) = weak_inner.upgrade() {
if let Some(handle) = inner.remove_handle(&name_copy, &task_id) {
// Closes the stream's handle to make sure the task
// completes cleanly.
let r = handle.shutdown().await;
log::debug!(
"protocol stream for {}-{} finished with result: {:?}",
name_copy,
task_id,
r
);
}
}
Ok(())
};
match self.inner.handles.borrow_mut().entry(name.clone()) {
Entry::Occupied(mut e) => {
if let Some(_s) = e.get_mut().insert(
task_id,
ProtocolHandle { task: Task::local(new_task), handle: weak_server },
) {
return Err(ProtocolError::DuplicateTaskId(name, task_id));
}
}
Entry::Vacant(e) => {
let mut new_map = HashMap::new();
new_map.insert(
task_id,
ProtocolHandle { task: Task::local(new_task), handle: weak_server },
);
e.insert(new_map);
}
}
Ok(())
}
pub async fn shutdown(&self, cx: Context) -> Result<(), ProtocolError> {
if self
.inner
.stopping
.compare_exchange(false, true, Ordering::SeqCst, Ordering::Acquire)
.is_err()
{
return Err(ProtocolError::BadRegisterState(
"already shutting down ProtocolRegister".to_string(),
));
}
let handler_futs = self
.inner
.drain_handles()
.drain(..)
.map(|h| async move {
let name = h.name.clone();
let id = h.id;
log::debug!("shutting down handle {}-{}", name, id);
h.shutdown()
.await
.unwrap_or_else(|e| log::warn!("shutdown for handle {}-{}: {:?}", name, id, e));
})
.collect::<Vec<_>>();
futures::future::join_all(handler_futs).await;
let mut protocol_futs = Vec::new();
for (name, svc) in self.inner.protocol_map.iter() {
let name = name.clone();
let cx = &cx;
let fut = async move {
log::debug!("shutting down stream handler for {}", name);
svc.shutdown(cx)
.await
.unwrap_or_else(|e| log::warn!("closing stream handler for {}: {:?}", name, e));
};
protocol_futs.push(fut);
}
futures::future::join_all(protocol_futs).await;
Ok(())
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::{DaemonProtocolProvider, FidlProtocol, FidlStreamHandler};
use async_trait::async_trait;
use ffx::DaemonError;
use ffx_config::EnvironmentContext;
use fidl::endpoints::DiscoverableProtocolMarker;
use {fidl_fuchsia_developer_ffx as ffx, fidl_fuchsia_ffx_test as ffx_test};
#[derive(Default, Clone)]
struct TestDaemon;
#[async_trait(?Send)]
impl DaemonProtocolProvider for TestDaemon {
async fn open_protocol(&self, _name: String) -> Result<fidl::Channel> {
unimplemented!()
}
async fn open_target_proxy(
&self,
_target_identifier: Option<String>,
_moniker: &str,
_capability_name: &str,
) -> Result<fidl::Channel> {
unimplemented!()
}
async fn open_target_proxy_with_info(
&self,
_target_identifier: Option<String>,
_moniker: &str,
_capability_name: &str,
) -> Result<(ffx::TargetInfo, fidl::Channel)> {
unimplemented!()
}
async fn get_target_info(
&self,
_target_identifier: Option<String>,
) -> Result<ffx::TargetInfo, DaemonError> {
unimplemented!()
}
}
#[derive(Default, Clone)]
struct NoopProtocol;
#[async_trait(?Send)]
impl FidlProtocol for NoopProtocol {
type Protocol = ffx_test::NoopMarker;
type StreamHandler = FidlStreamHandler<Self>;
async fn handle(&self, _cx: &Context, req: ffx_test::NoopRequest) -> Result<()> {
match req {
ffx_test::NoopRequest::DoNoop { responder } => responder.send().map_err(Into::into),
}
}
}
fn create_noop_register() -> ProtocolRegister {
let protocol_string =
<<NoopProtocol as FidlProtocol>::Protocol as DiscoverableProtocolMarker>::PROTOCOL_NAME
.to_owned();
let mut map = NameToStreamHandlerMap::new();
map.insert(protocol_string.clone(), Box::new(FidlStreamHandler::<NoopProtocol>::default()));
ProtocolRegister::new(map)
}
async fn create_noop_proxy(
context: &EnvironmentContext,
) -> Result<(ffx_test::NoopProxy, ProtocolRegister)> {
let register = create_noop_register();
let (noop_proxy, server) = fidl::endpoints::create_endpoints::<ffx_test::NoopMarker>();
register
.open(
ffx_test::NoopMarker::PROTOCOL_NAME.to_owned(),
Context::new(TestDaemon::default(), context.clone()),
fidl::AsyncChannel::from_channel(server.into_channel()),
)
.await?;
Ok((noop_proxy.into_proxy(), register))
}
#[fuchsia::test]
async fn test_start_stop() -> Result<()> {
let env = ffx_config::test_init().unwrap();
let (noop_proxy, register) = create_noop_proxy(&env.context).await?;
noop_proxy.do_noop().await?;
register.shutdown(Context::new(TestDaemon::default(), env.context.clone())).await?;
assert!(noop_proxy.do_noop().await.is_err());
Ok(())
}
#[fuchsia::test]
async fn test_err_on_open_after_shutdown() -> Result<()> {
let env = ffx_config::test_init().unwrap();
let register = create_noop_register();
let (noop_proxy, server) = fidl::endpoints::create_endpoints::<ffx_test::NoopMarker>();
register.shutdown(Context::new(TestDaemon::default(), env.context.clone())).await?;
let res = register
.open(
ffx_test::NoopMarker::PROTOCOL_NAME.to_owned(),
Context::new(TestDaemon::default(), env.context.clone()),
fidl::AsyncChannel::from_channel(server.into_channel()),
)
.await;
let noop_proxy = noop_proxy.into_proxy();
assert!(res.is_err());
assert!(noop_proxy.do_noop().await.is_err());
Ok(())
}
#[fuchsia::test]
async fn test_err_double_shutdown() -> Result<()> {
let env = ffx_config::test_init().unwrap();
let register = create_noop_register();
register.shutdown(Context::new(TestDaemon::default(), env.context.clone())).await?;
assert!(
register
.shutdown(Context::new(TestDaemon::default(), env.context.clone()))
.await
.is_err()
);
Ok(())
}
}