| // Copyright 2021 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| use anyhow::Result; |
| use std::io::Write; |
| |
| use crate::test_code::{CodeGenerator, TestCodeBuilder}; |
| use std::collections::BTreeSet; |
| |
| const MOCK_FUNC_TEMPLATE: &'static str = include_str!("templates/template_rust_mock_function"); |
| const TEST_FUNC_TEMPLATE: &'static str = include_str!("templates/template_rust_test_function"); |
| |
| pub struct RustTestCodeGenerator<'a> { |
| pub code: &'a RustTestCode, |
| } |
| |
| impl CodeGenerator for RustTestCodeGenerator<'_> { |
| fn write_file<W: Write>(&self, writer: &mut W) -> Result<()> { |
| let create_realm_func_start = r#"pub async fn create_realm() -> Result<RealmInstance, Error> { |
| let mut builder = RealmBuilder::new().await?; |
| builder |
| "#; |
| let mut create_realm_impl = self.code.realm_builder_snippets.join("\n"); |
| create_realm_impl.push_str(";\n\n"); |
| let create_realm_func_end = r#" |
| let instance = builder.build().create().await?; |
| Ok(instance) |
| } |
| |
| "#; |
| // Add import statements |
| let all_imports = self.code.imports.clone().into_iter().collect::<Vec<_>>(); |
| let mut imports = all_imports.join("\n"); |
| imports.push_str("\n\n"); |
| writer.write_all(&imports.as_bytes())?; |
| |
| // Add constants, these are components urls |
| let mut constants = self.code.constants.join("\n"); |
| constants.push_str("\n\n"); |
| writer.write_all(&constants.as_bytes())?; |
| |
| // Generate create_realm() function |
| writer.write_all(&create_realm_func_start.as_bytes())?; |
| writer.write_all(&create_realm_impl.as_bytes())?; |
| writer.write_all(&create_realm_func_end.as_bytes())?; |
| |
| // Add mock implementation functions, one per component |
| if self.code.mock_functions.len() > 0 { |
| let mut mock_funcs = self.code.mock_functions.join("\n\n"); |
| mock_funcs.push_str("\n\n"); |
| writer.write_all(&mock_funcs.as_bytes())?; |
| } |
| |
| // Add testcases, one per protocol |
| let mut test_cases = self.code.test_case.join("\n\n"); |
| test_cases.push_str("\n"); |
| writer.write_all(&test_cases.as_bytes())?; |
| |
| Ok(()) |
| } |
| } |
| |
/// Accumulates the pieces of a generated Rust integration test.
///
/// Filled in through the `TestCodeBuilder` methods and rendered to source text by
/// `RustTestCodeGenerator`.
pub struct RustTestCode {
    /// Library import strings (`use ...;`); the set keeps them sorted and unique.
    pub imports: BTreeSet<String>,
    /// Constant declaration strings, e.g. component URLs.
    constants: Vec<String>,
    /// RealmBuilder compatibility routing code, chained inside `create_realm()`.
    pub realm_builder_snippets: Vec<String>,
    /// Testcase functions.
    test_case: Vec<String>,
    /// Skeleton functions for implementing mocks.
    mock_functions: Vec<String>,
    /// Name used by RealmBuilder for the component-under-test.
    component_under_test: String,
}
| |
| impl TestCodeBuilder for RustTestCode { |
| fn new(component_name: &str) -> Self { |
| RustTestCode { |
| realm_builder_snippets: Vec::new(), |
| constants: Vec::new(), |
| imports: BTreeSet::new(), |
| test_case: Vec::new(), |
| mock_functions: Vec::new(), |
| component_under_test: component_name.to_string(), |
| } |
| } |
| fn add_import<'a>(&'a mut self, import_library: &str) -> &'a dyn TestCodeBuilder { |
| self.imports.insert(format!(r#"use {};"#, import_library)); |
| self |
| } |
| |
| fn add_component<'a>( |
| &'a mut self, |
| component_name: &str, |
| url: &str, |
| const_var: &str, |
| mock: bool, |
| ) -> &'a dyn TestCodeBuilder { |
| if mock { |
| let mock_function_name = format!("{}_impl", component_name); |
| self.realm_builder_snippets.push(format!( |
| r#" .add_component("{}", |
| ComponentSource::Mock(Mock::new(move |mock_handles: MockHandles| {{ |
| Box::pin({}(mock_handles)) |
| }})), |
| ) |
| .await?"#, |
| component_name, &mock_function_name, |
| )); |
| } else { |
| self.realm_builder_snippets.push(format!( |
| r#" .add_component("{}", ComponentSource::url({})).await?"#, |
| component_name, const_var |
| )); |
| self.constants.push(format!(r#"const {}: &str = "{}";"#, const_var, url).to_string()); |
| } |
| self |
| } |
| |
| fn add_mock_impl<'a>( |
| &'a mut self, |
| component_name: &str, |
| _protocol: &str, |
| ) -> &'a dyn TestCodeBuilder { |
| // Note: this function name must match the one we added in 'add_component'. |
| let mock_function_name = format!("{}_impl", component_name); |
| self.mock_functions.push(MOCK_FUNC_TEMPLATE.replace("FUNCTION_NAME", &mock_function_name)); |
| self |
| } |
| |
| fn add_protocol<'a>( |
| &'a mut self, |
| protocol: &str, |
| source: &str, |
| targets: Vec<String>, |
| ) -> &'a dyn TestCodeBuilder { |
| let source_code = match source { |
| "root" => "RouteEndpoint::above_root()".to_string(), |
| "self" => format!("RouteEndpoint::component(\"{}\")", self.component_under_test), |
| _ => format!("RouteEndpoint::component(\"{}\")", source), |
| }; |
| |
| let mut targets_code: String = "".to_string(); |
| for i in 0..targets.len() { |
| let t = &targets[i]; |
| if t == "root" { |
| targets_code.push_str("RouteEndpoint::above_root(), "); |
| } else if t == "self" { |
| targets_code.push_str( |
| format!("RouteEndpoint::component(\"{}\"), ", self.component_under_test) |
| .as_str(), |
| ); |
| } else { |
| targets_code.push_str(format!("RouteEndpoint::component(\"{}\"), ", t).as_str()); |
| } |
| } |
| self.realm_builder_snippets.push(format!( |
| r#" .add_route(CapabilityRoute {{ |
| capability: Capability::protocol("{}"), |
| source: {}, |
| targets: vec![ |
| {} |
| ], |
| }})?"#, |
| protocol, |
| source_code, |
| targets_code.trim_end() |
| )); |
| self |
| } |
| |
| fn add_directory<'a>( |
| &'a mut self, |
| dir_name: &str, |
| dir_path: &str, |
| targets: Vec<String>, |
| ) -> &'a dyn TestCodeBuilder { |
| let mut targets_code: String = "".to_string(); |
| for i in 0..targets.len() { |
| let t = &targets[i]; |
| if t == "root" { |
| targets_code.push_str("RouteEndpoint::above_root(), "); |
| } else if t == "self" { |
| targets_code.push_str( |
| format!("RouteEndpoint::component(\"{}\"), ", self.component_under_test) |
| .as_str(), |
| ); |
| } else { |
| targets_code.push_str(format!("RouteEndpoint::component(\"{}\"), ", t).as_str()); |
| } |
| } |
| self.realm_builder_snippets.push(format!( |
| r#" .add_route(CapabilityRoute {{ |
| capability: Capability::directory( |
| "{}", |
| "{}", |
| fio2::RW_STAR_DIR), |
| source: RouteEndpoint::above_root(), |
| targets: vec![ |
| {} |
| ], |
| }})?"#, |
| dir_name, |
| dir_path, |
| targets_code.trim_end() |
| )); |
| self |
| } |
| |
| fn add_storage<'a>( |
| &'a mut self, |
| storage_name: &str, |
| storage_path: &str, |
| targets: Vec<String>, |
| ) -> &'a dyn TestCodeBuilder { |
| let mut targets_code: String = "".to_string(); |
| for i in 0..targets.len() { |
| let t = &targets[i]; |
| if t == "root" { |
| targets_code.push_str("RouteEndpoint::above_root(), "); |
| } else if t == "self" { |
| targets_code.push_str( |
| format!("RouteEndpoint::component(\"{}\"), ", self.component_under_test) |
| .as_str(), |
| ); |
| } else { |
| targets_code.push_str(format!("RouteEndpoint::component(\"{}\"), ", t).as_str()); |
| } |
| } |
| self.realm_builder_snippets.push(format!( |
| r#" .add_route(CapabilityRoute {{ |
| capability: Capability::storage( |
| "{}", |
| "{}", |
| ), |
| source: RouteEndpoint::above_root(), |
| targets: vec![ |
| {} |
| ], |
| }})?"#, |
| storage_name, |
| storage_path, |
| targets_code.trim_end() |
| )); |
| self |
| } |
| |
| fn add_test_case<'a>(&'a mut self, protocol: &str) -> &'a dyn TestCodeBuilder { |
| let protocol_marker = format!("{}Marker", &protocol); |
| self.test_case.push( |
| TEST_FUNC_TEMPLATE |
| .replace("MARKER_VAR_NAME", &protocol_marker.to_ascii_lowercase()) |
| .replace("MARKER", &protocol_marker) |
| .replace("PROTOCOL", &protocol), |
| ); |
| self |
| } |
| } |
| |
#[cfg(test)]
mod test {
    use super::*;

    /// Importing the same path twice must yield a single `use` line, and the
    /// imports must be written in sorted order.
    #[test]
    fn deduplicate_imports() {
        let mut code = RustTestCode::new("test-component");
        for import in ["example::Value", "example::Value", "example::Value2"].iter() {
            code.add_import(import);
        }

        let mut output = Vec::<u8>::new();
        let generator = RustTestCodeGenerator { code: &code };
        generator.write_file(&mut output).expect("write output");

        let text = std::str::from_utf8(&output).expect("output must be UTF-8");
        let lines: Vec<&str> = text.split('\n').collect();
        assert!(lines.len() > 3);
        assert_eq!(lines[0], "use example::Value2;");
        assert_eq!(lines[1], "use example::Value;");
        assert_ne!(lines[2], "use example::Value;");
    }
}