| // Copyright 2020 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| use { |
| crate::rest::{error::RestError, visualizer::*}, |
| anyhow::{Error, Result}, |
| log::{info, warn}, |
| rouille::{Request, Response, ResponseBody}, |
| scrutiny::{ |
| engine::dispatcher::{ControllerDispatcher, DispatcherError}, |
| model::controller::ConnectionMode, |
| }, |
| serde_json::json, |
| std::collections::HashMap, |
| std::io::{self, ErrorKind, Read}, |
| std::net::TcpStream, |
| std::str, |
| std::sync::{Arc, RwLock}, |
| std::thread, |
| }; |
| |
/// Namespace type for the REST service entry points.
///
/// The struct has no fields and is never instantiated by the code in this
/// file; it only groups the associated functions (`spawn`, `run`, and the
/// request helpers). Note that `spawn` does not keep the `JoinHandle` of the
/// server thread, so this type does not own or track that thread.
pub struct RestService {}
| |
| impl RestService { |
| /// Spawns the RestService on a new thread. |
| pub fn spawn( |
| dispatcher: Arc<RwLock<ControllerDispatcher>>, |
| visualizer: Arc<RwLock<Visualizer>>, |
| port: u16, |
| ) -> Result<()> { |
| let addr = format!("127.0.0.1:{}", port); |
| println!("• Server: http://{}\n", addr); |
| |
| if TcpStream::connect(("127.0.0.1", port)).is_ok() { |
| Err(Error::new(RestError::port_in_use(addr))) |
| } else { |
| thread::spawn(move || RestService::run(dispatcher, visualizer, addr)); |
| Ok(()) |
| } |
| } |
| |
| /// Runs the core REST service loop, parsing URLs and queries to their |
| /// respective controllers via the ControllerDispatcher. This function does |
| /// not exit. |
| fn run( |
| dispatcher: Arc<RwLock<ControllerDispatcher>>, |
| visualizer: Arc<RwLock<Visualizer>>, |
| addr: String, |
| ) { |
| info!("Server starting: http://{}", addr); |
| rouille::start_server(addr, move |request| { |
| info!("Request: {} {}", request.method(), request.url()); |
| // TODO: Change to allow each plugin to define its own visualizers. |
| if request.url().starts_with("/api") { |
| RestService::handle_controller_request(dispatcher.clone(), request) |
| } else { |
| visualizer.read().unwrap().serve_path_or_index(request) |
| } |
| }); |
| } |
| |
/// Parses a raw query string (`a=1&b=2`) into a key/value map.
///
/// The string is split on `&`, and each piece is split at its first `=`.
/// Pieces with no `=`, an empty key, or an empty value are skipped. Anything
/// after the first `=` (including further `=` characters) belongs to the
/// value. When a key repeats, the last occurrence wins. No percent-decoding
/// is performed.
fn parse_get_params(query_str: &str) -> HashMap<String, String> {
    // TODO: Sanitize these values.
    // TODO(arkay) Make this work for non string-string kv pairs.
    query_str
        .split('&')
        .filter_map(|pair| {
            let mut halves = pair.splitn(2, '=');
            match (halves.next(), halves.next()) {
                (Some(key), Some(value)) if !key.is_empty() && !value.is_empty() => {
                    Some((key.to_string(), value.to_string()))
                }
                // Missing `=`, or nothing on one side of it: ignore the piece.
                _ => None,
            }
        })
        .collect()
}
| |
| /// Converts a rust error into a JSON error response. |
| fn error_response(error: Error, status_code: Option<u16>) -> Response { |
| let result = json!({ |
| "status": "error", |
| "description": error.to_string(), |
| }); |
| |
| Response { |
| status_code: { |
| if let Some(code) = status_code { |
| code |
| } else { |
| 500 |
| } |
| }, |
| headers: vec![("Content-Type".into(), "application/json".into())], |
| data: ResponseBody::from_string(result.to_string()), |
| upgrade: None, |
| } |
| } |
| |
| fn handle_controller_request( |
| dispatcher: Arc<RwLock<ControllerDispatcher>>, |
| request: &Request, |
| ) -> Response { |
| let method = request.method(); |
| let mut body = request.data().expect("RequestBody already retrieved"); |
| |
| let query_val = match method { |
| "GET" => { |
| // TODO: Looking at the source for `get_param(&self, param_name: &str)` seems like it's not great to |
| // rely on that function since it doesn't match against the entire parameter name... |
| let query = request.raw_query_string(); |
| let params = RestService::parse_get_params(query); |
| Ok(json!(params)) |
| } |
| "POST" => { |
| let mut query = String::new(); |
| if let Err(e) = body.read_to_string(&mut query) { |
| warn!("Failed to read request body."); |
| return RestService::error_response(Error::new(e), Some(400)); |
| } |
| if query.is_empty() { |
| // If there is no body, return a null value, since from_str will error. |
| Ok(json!(null)) |
| } else { |
| serde_json::from_str(&query) |
| } |
| } |
| _ => { |
| // TODO: Should always serve HEAD requests. |
| warn!("Expected GET or POST method, received {}.", method); |
| return RestService::error_response( |
| Error::new(io::Error::new(ErrorKind::ConnectionRefused, "Unsupported method.")), |
| Some(405), |
| ); |
| } |
| }; |
| |
| let dispatch = dispatcher.read().unwrap(); |
| if let Ok(json_val) = query_val { |
| match dispatch.query(ConnectionMode::Remote, request.url(), json_val) { |
| Ok(result) => Response { |
| status_code: 200, |
| headers: vec![("Content-Type".into(), "application/json".into())], |
| data: ResponseBody::from_string(serde_json::to_string_pretty(&result).unwrap()), |
| upgrade: None, |
| }, |
| Err(e) => { |
| if let Some(dispatch_error) = e.downcast_ref::<DispatcherError>() { |
| if let DispatcherError::NamespaceDoesNotExist(_) = dispatch_error { |
| warn!("Address not found."); |
| return Response::empty_404(); |
| } |
| } |
| RestService::error_response(e, None) |
| } |
| } |
| } else { |
| return Response::empty_400(); |
| } |
| } |
| } |
| |
#[cfg(test)]
mod tests {
    use {
        super::*,
        anyhow::Result,
        scrutiny::{model::controller::DataController, model::model::DataModel},
        scrutiny_testing::fake::*,
        serde_json::value::Value,
        std::io,
        uuid::Uuid,
    };

    /// Test controller that hands the incoming query straight back.
    #[derive(Default)]
    struct EchoController {}

    impl DataController for EchoController {
        fn query(&self, _: Arc<DataModel>, query: Value) -> Result<Value> {
            Ok(query)
        }
    }

    /// Test controller whose query always fails.
    #[derive(Default)]
    struct ErrorController {}

    impl DataController for ErrorController {
        fn query(&self, _: Arc<DataModel>, _: Value) -> Result<Value> {
            let cause = io::Error::new(io::ErrorKind::Other, "It's always an error!");
            Err(Error::new(cause))
        }
    }

    /// Builds a dispatcher with the echo controller at `/api/foo/bar` and the
    /// failing controller at `/api/foo/baz`.
    fn setup_dispatcher() -> Arc<RwLock<ControllerDispatcher>> {
        let mut dispatcher = ControllerDispatcher::new(fake_data_model());
        dispatcher
            .add(Uuid::new_v4(), "/api/foo/bar".to_string(), Arc::new(EchoController::default()))
            .unwrap();
        dispatcher
            .add(Uuid::new_v4(), "/api/foo/baz".to_string(), Arc::new(ErrorController::default()))
            .unwrap();
        Arc::new(RwLock::new(dispatcher))
    }

    /// Sends a fake request with no headers through the controller handler.
    fn send(method: &str, url: &str, body: Vec<u8>) -> Response {
        let request = Request::fake_http(method, url, vec![], body);
        RestService::handle_controller_request(setup_dispatcher(), &request)
    }

    /// Drains a response body into a UTF-8 string.
    fn body_as_string(response: Response) -> String {
        let (mut reader, _) = response.data.into_reader_and_size();
        let mut bytes = Vec::new();
        reader.read_to_end(&mut bytes).unwrap();
        String::from_utf8(bytes).unwrap()
    }

    #[test]
    fn handle_controller_request_fails_non_get_or_post_request() {
        let response = send("HEAD", "/api/foo/bar", vec![]);
        assert_eq!(response.status_code, 405);
    }

    #[test]
    fn handle_controller_request_returns_500_on_controller_error() {
        let response = send("GET", "/api/foo/baz", vec![]);
        assert_eq!(response.status_code, 500);
    }

    #[test]
    fn handle_controller_request_returns_404_on_non_matching_dispatcher() {
        let response = send("GET", "/api/foo/bin", vec![]);
        assert_eq!(response.status_code, 404);
    }

    #[test]
    fn handle_controller_request_serves_get_request() {
        let response = send("GET", "/api/foo/bar?hello=world", vec![]);
        assert_eq!(response.status_code, 200);
        let text = body_as_string(response);
        assert!(text.contains("hello"));
        assert!(text.contains("world"));
    }

    #[test]
    fn handle_controller_request_serves_post_request() {
        let payload = b"{\"hello\":\"world\"}".to_vec();
        let response = send("POST", "/api/foo/bar", payload);
        assert_eq!(response.status_code, 200);
        let text = body_as_string(response);
        assert!(text.contains("hello"));
        assert!(text.contains("world"));
    }

    #[test]
    fn parse_get_params_returns_empty_vec_on_empty_query() {
        assert!(RestService::parse_get_params("").is_empty());
    }

    #[test]
    fn parse_get_params_skips_invalid_key_value_pairs() {
        // Pieces in order: valid, empty, empty value, empty key, valid.
        let query = "foo=bar&&baz=&=aries&hello=world";
        let params = RestService::parse_get_params(query);
        assert_eq!(params.len(), 2);
        assert_eq!(params.get("foo").unwrap(), "bar");
        assert_eq!(params.get("hello").unwrap(), "world");
    }
}