blob: 5a50e6ca732bbd826cc5c14867b971e6ca876ce7 [file] [log] [blame]
// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use anyhow::Result;
use std::io::Write;
use crate::test_code::{convert_to_camel, copyright, CodeGenerator, TestCodeBuilder};
use std::collections::BTreeSet;
/// Skeleton of a mocked component; `add_mock_impl` substitutes `FUNCTION_NAME` and
/// `PROTOCOL_REQUEST_STREAM`.
const MOCK_FUNC_TEMPLATE: &str = include_str!("templates/template_rust_mock_function");
/// One generated test case; `add_test_case` substitutes `MARKER_NAME`, `TEST_CLASS_NAME` and
/// `PROTOCOL`.
const TEST_FUNC_TEMPLATE: &str = include_str!("templates/template_rust_test_function");
/// Protocol-connect helper; `add_fidl_connect` substitutes `MARKER_NAME`, `MARKER_CONNECTOR`,
/// `MARKER_PROXY` and `PROTOCOL`.
const CONNECT_FUNC_TEMPLATE: &str = include_str!("templates/template_rust_connect_function");
/// Banner written at the top of the regenerated lib.rs (after the optional copyright header).
const DO_NOT_EDIT: &str = "// EDIT WITH CAUTION. This file is autogenerated using 'fx testgen', and will be overwritten with each run.\n";
/// Generates lib.rs. The output starts with the `DO_NOT_EDIT` banner because the file is
/// overwritten every time `fx testgen` runs.
pub struct RustLibGenerator<'a> {
/// Accumulated code fragments to render.
pub code: &'a RustTestCode,
/// When true, a `//`-style copyright header is written before everything else.
pub copyright: bool,
}
/// Generates <test_name>.rs, the test file that is generated once and then edited by hand
/// (unlike lib.rs it carries no `DO_NOT_EDIT` banner).
pub struct RustTestCodeGenerator<'a> {
/// Accumulated code fragments to render (imports, mock skeletons, test cases).
pub code: &'a RustTestCode,
/// When true, a `//`-style copyright header is written before everything else.
pub copyright: bool,
}
/// Code generators that can emit the `create_realm()` function of the generated test struct.
pub trait RustCreateRealm: CodeGenerator {
    /// Writes the source text of the generated `create_realm()` function to `writer`.
    fn write_create_realm_fn<W>(&self, writer: &mut W) -> Result<()>
    where
        W: Write;
}
impl RustCreateRealm for RustLibGenerator<'_> {
    /// Writes `create_realm()`. It creates a `RealmBuilder`, emits every entry of
    /// `realm_builder_snippets` (newline-separated, in insertion order), then builds and returns
    /// the realm instance.
    fn write_create_realm_fn<W: Write>(&self, writer: &mut W) -> Result<()> {
        let create_realm_func_start = r#" pub async fn create_realm() -> Result<RealmInstance, Error> {
let builder = RealmBuilder::new().await?;
"#;
        let mut create_realm_impl = self.code.realm_builder_snippets.join("\n");
        create_realm_impl.push('\n');
        let create_realm_func_end = r#"
let instance = builder.build().await?;
Ok(instance)
}
"#;
        // Generate create_realm() function
        writer.write_all(create_realm_func_start.as_bytes())?;
        writer.write_all(create_realm_impl.as_bytes())?;
        writer.write_all(create_realm_func_end.as_bytes())?;
        Ok(())
    }
}
impl CodeGenerator for RustTestCodeGenerator<'_> {
    /// Writes the <test_name>.rs file.
    ///
    /// Output order:
    /// - the optional copyright header;
    /// - the `use` lines and `mod lib;`;
    /// - an `impl Mocks` block, only when mock functions were registered;
    /// - every test case, separated by blank lines.
    fn write_file<W: Write>(&self, writer: &mut W) -> Result<()> {
        let class_name = &self.code.test_class_name;
        let has_mocks = !self.code.mock_functions.is_empty();

        // Add import statements. The `Mocks` trait and the names used by the mock skeletons
        // are only imported when at least one mock exists.
        let mut imports = if has_mocks {
            let mut mock_imports = format!("use crate::lib::{{{}, Mocks}};\n", class_name);
            mock_imports.push_str("use async_trait::async_trait;\n");
            mock_imports.push_str("use fuchsia_component::server::ServiceFs;\n");
            mock_imports.push_str("use fuchsia_component_test::LocalComponentHandles;\n");
            mock_imports
        } else {
            format!("use crate::lib::{};\n", class_name)
        };
        imports.push_str("mod lib;\n");

        // The mock skeletons implement the `Mocks` trait declared in lib.rs.
        let mut mock_impl = String::from("\n");
        if has_mocks {
            mock_impl.push_str(&format!("#[async_trait]\nimpl Mocks for {} {{\n", class_name));
            mock_impl.push_str(&self.code.mock_functions.join("\n\n"));
            mock_impl.push_str("\n}\n\n");
        }

        // Add testcases, one per protocol
        let mut test_cases = self.code.test_case.join("\n\n");
        test_cases.push('\n');

        if self.copyright {
            writer.write_all(copyright("//").as_bytes())?;
        }
        writer.write_all(imports.as_bytes())?;
        writer.write_all(mock_impl.as_bytes())?;
        writer.write_all(test_cases.as_bytes())?;
        Ok(())
    }
}
impl CodeGenerator for RustLibGenerator<'_> {
    /// Writes the regenerated lib.rs file.
    ///
    /// Output order:
    /// - the optional copyright header;
    /// - the `DO_NOT_EDIT` banner;
    /// - the `use` lines;
    /// - the component URL constants;
    /// - the `Mocks` trait, only when mock interfaces were registered;
    /// - the test struct declaration;
    /// - its `impl` block holding `create_realm()` and any connect functions.
    ///
    /// The output does not end with a newline.
    fn write_file<W: Write>(&self, writer: &mut W) -> Result<()> {
        if self.copyright {
            writer.write_all(copyright("//").as_bytes())?;
        }
        writer.write_all(DO_NOT_EDIT.as_bytes())?;

        // Add import statements. `imports` is a BTreeSet, so the lines are already unique and
        // sorted; borrow them instead of cloning the whole set.
        let mut imports =
            self.code.imports.iter().map(String::as_str).collect::<Vec<_>>().join("\n");
        imports.push_str("\n\n");
        writer.write_all(imports.as_bytes())?;

        // Add constants, these are component URLs.
        let mut constants = self.code.constants.join("\n");
        constants.push_str("\n\n");
        writer.write_all(constants.as_bytes())?;

        if !self.code.mock_function_interfaces.is_empty() {
            let mut mock_trait = String::from("#[async_trait]\npub trait Mocks {\n");
            mock_trait.push_str(&self.code.mock_function_interfaces.join("\n"));
            mock_trait.push_str("\n}\n\n");
            writer.write_all(mock_trait.as_bytes())?;
        }

        let class_name = self.code.test_class_name.as_str();
        let test_class_struct = format!("pub struct {class_name};\n\n", class_name = class_name);
        writer.write_all(test_class_struct.as_bytes())?;

        let test_impl_start = format!("impl {class_name} {{\n", class_name = class_name);
        let test_impl_end = "\n}";
        writer.write_all(test_impl_start.as_bytes())?;
        // Generate create_realm() function
        self.write_create_realm_fn(writer)?;
        if !self.code.connect_functions.is_empty() {
            writer.write_all(self.code.connect_functions.join("\n").as_bytes())?;
        }
        writer.write_all(test_impl_end.as_bytes())?;
        Ok(())
    }
}
/// Accumulates generated code fragments for a single component-under-test. The fragments are
/// rendered by `RustLibGenerator` (lib.rs) and `RustTestCodeGenerator` (the test file).
pub struct RustTestCode {
/// `use ...;` lines for lib.rs. Stored in a BTreeSet, so duplicates are dropped and the lines
/// are emitted in sorted order.
pub imports: BTreeSet<String>,
/// RealmBuilder statements (child declarations and capability routes), emitted in insertion
/// order inside the generated `create_realm()` function.
pub realm_builder_snippets: Vec<String>,
/// `const` declarations holding component URLs, written near the top of lib.rs.
constants: Vec<String>,
/// Test case functions written to the test file, one per `add_test_case` call.
test_case: Vec<String>,
/// Skeleton functions for implementing mocks; placed inside `impl Mocks for <test class>`.
mock_functions: Vec<String>,
/// Contains interface signatures that are included in "trait Mocks"
mock_function_interfaces: Vec<String>,
/// Variable name used in generated RealmBuilder code that refers to the
/// component-under-test; route endpoints named "self" resolve to it.
component_under_test: String,
/// Generated class name, i.e. {ComponentName}Test, e.g. EchoServerTest
test_class_name: String,
/// Generated functions that connect to a specified FIDL protocol,
/// e.g. connect_to_echomarker(...){...}
connect_functions: Vec<String>,
}
/// Renders the `.to(...)` clauses of a RealmBuilder route, one line per entry of `targets`.
///
/// Target mapping:
/// - `"root"` routes to `Ref::parent()`;
/// - `"self"` routes to `component_under_test`;
/// - any other value is treated as the variable name of a child component.
///
/// Each line is indented by 20 spaces. Lines are joined with `\n` and there is no trailing
/// newline, so the result can be spliced directly into a route snippet.
fn route_targets_code(targets: &[String], component_under_test: &str) -> String {
    let indent = " ".repeat(20);
    targets
        .iter()
        .map(|target| match target.as_str() {
            "root" => format!("{}.to(Ref::parent())", indent),
            "self" => format!("{}.to(&{})", indent, component_under_test),
            child => format!("{}.to(&{})", indent, child),
        })
        .collect::<Vec<_>>()
        .join("\n")
}

impl TestCodeBuilder for RustTestCode {
    fn new(component_name: &str) -> Self {
        RustTestCode {
            realm_builder_snippets: Vec::new(),
            constants: Vec::new(),
            imports: BTreeSet::new(),
            test_case: Vec::new(),
            mock_functions: Vec::new(),
            mock_function_interfaces: Vec::new(),
            component_under_test: component_name.to_string(),
            test_class_name: format!("{}Test", convert_to_camel(component_name)),
            connect_functions: Vec::new(),
        }
    }

    /// Records `use <import_library>;` for lib.rs; repeated imports are dropped by the BTreeSet.
    fn add_import<'a>(&'a mut self, import_library: &str) -> &'a dyn TestCodeBuilder {
        self.imports.insert(format!("use {};", import_library));
        self
    }

    /// Adds a child named `component_name` to the generated realm.
    ///
    /// With `mock == true` the child is a local component driven by the
    /// `<component_name>_impl` function of the generated test struct. Otherwise a
    /// `const <const_var>: &str = "<url>";` declaration is recorded and the child is added by URL.
    fn add_component<'a>(
        &'a mut self,
        component_name: &str,
        url: &str,
        const_var: &str,
        mock: bool,
    ) -> &'a dyn TestCodeBuilder {
        if mock {
            // Note: this function name must match the one generated in 'add_mock_impl'.
            let mock_function_name = format!("{}_impl", component_name);
            self.realm_builder_snippets.push(format!(
                r#" let {child_component} = builder.add_local_child(
"{child_component}",
move |handles: LocalComponentHandles| Box::pin({test_class_name}::{mock_function}(handles)),
ChildOptions::new()
)
.await?;"#,
                test_class_name = self.test_class_name,
                child_component = component_name,
                mock_function = mock_function_name
            ));
        } else {
            self.constants.push(format!(r#"const {}: &str = "{}";"#, const_var, url));
            self.realm_builder_snippets.push(format!(
                r#" let {child_component} = builder.add_child(
"{child_component}",
{url},
ChildOptions::new()
)
.await?;"#,
                child_component = component_name,
                url = const_var
            ));
        }
        self
    }

    /// Records a mock skeleton for `component_name` serving `protocol`, plus the matching
    /// signature for `trait Mocks`.
    fn add_mock_impl<'a>(
        &'a mut self,
        component_name: &str,
        protocol: &str,
    ) -> &'a dyn TestCodeBuilder {
        // Note: this function name must match the one we added in 'add_component'.
        let mock_function_name = format!("{}_impl", component_name);
        let request_stream = format!("{}RequestStream", protocol);
        self.mock_functions.push(
            MOCK_FUNC_TEMPLATE
                .replace("FUNCTION_NAME", &mock_function_name)
                .replace("PROTOCOL_REQUEST_STREAM", &request_stream),
        );
        self.mock_function_interfaces.push(format!(
            " async fn {}(handles: LocalComponentHandles) -> Result<(), Error>;",
            mock_function_name
        ));
        self
    }

    /// Routes protocol `protocol` from `source` to every entry of `targets`.
    ///
    /// For both `source` and `targets`, "root" means the parent and "self" means the
    /// component-under-test.
    fn add_protocol<'a>(
        &'a mut self,
        protocol: &str,
        source: &str,
        targets: Vec<String>,
    ) -> &'a dyn TestCodeBuilder {
        let source_code = match source {
            "root" => "Ref::parent()".to_string(),
            "self" => format!("&{}", self.component_under_test),
            _ => format!("&{}", source),
        };
        // A named child target must be routed to that target, not back to the source.
        let targets_code = route_targets_code(&targets, &self.component_under_test);
        self.realm_builder_snippets.push(format!(
            r#" builder
.add_route(
Route::new()
.capability(Capability::protocol_by_name("{protocol}"))
.from({from})
{to},
)
.await?;"#,
            protocol = protocol,
            from = source_code,
            to = targets_code
        ));
        self
    }

    /// Routes directory `dir_name` (mounted at `dir_path`) from the parent to every target.
    fn add_directory<'a>(
        &'a mut self,
        dir_name: &str,
        dir_path: &str,
        targets: Vec<String>,
    ) -> &'a dyn TestCodeBuilder {
        let targets_code = route_targets_code(&targets, &self.component_under_test);
        // `rights` takes a `fio::Operations` value, so the generated code must reference the
        // constant rather than a string literal. Assumes the generated file imports the
        // `fio` alias — confirm against callers of `add_import`.
        self.realm_builder_snippets.push(format!(
            r#" builder
.add_route(
Route::new()
.capability(Capability::directory("{dir}").path("{path}").rights(fio::RW_STAR_DIR))
.from(Ref::parent())
{to},
)
.await?;"#,
            dir = dir_name,
            path = dir_path,
            to = targets_code
        ));
        self
    }

    /// Routes storage `storage_name` (mounted at `storage_path`) from the parent to every target.
    fn add_storage<'a>(
        &'a mut self,
        storage_name: &str,
        storage_path: &str,
        targets: Vec<String>,
    ) -> &'a dyn TestCodeBuilder {
        let targets_code = route_targets_code(&targets, &self.component_under_test);
        self.realm_builder_snippets.push(format!(
            r#" builder
.add_route(
Route::new()
.capability(Capability::storage("{dir}").path("{path}"))
.from(Ref::parent())
{to},
)
.await?;"#,
            dir = storage_name,
            path = storage_path,
            to = targets_code
        ));
        self
    }

    /// Adds one test case for `protocol`, instantiated from `TEST_FUNC_TEMPLATE`.
    fn add_test_case<'a>(&'a mut self, protocol: &str) -> &'a dyn TestCodeBuilder {
        let marker_name = format!("{}Marker", protocol).to_ascii_lowercase();
        self.test_case.push(
            TEST_FUNC_TEMPLATE
                .replace("MARKER_NAME", &marker_name)
                .replace("TEST_CLASS_NAME", &self.test_class_name)
                .replace("PROTOCOL", protocol),
        );
        self
    }

    /// Generates a function that connects to `protocol` as exposed by the component-under-test.
    fn add_fidl_connect<'a>(&'a mut self, protocol: &str) -> &'a dyn TestCodeBuilder {
        let protocol_marker = format!("{}Marker", protocol);
        let protocol_proxy = format!("{}Proxy", protocol);
        // Replacement order matters: "PROTOCOL" is substituted last so that it cannot match
        // text inserted by the earlier replacements' placeholders.
        self.connect_functions.push(
            CONNECT_FUNC_TEMPLATE
                .replace("MARKER_NAME", &protocol_marker.to_ascii_lowercase())
                .replace("MARKER_CONNECTOR", &protocol_marker)
                .replace("MARKER_PROXY", &protocol_proxy)
                .replace("PROTOCOL", protocol),
        );
        self
    }
}
#[cfg(test)]
mod test {
use super::*;
// Duplicate imports must be emitted once, and imports must come out in BTreeSet (sorted) order.
#[test]
fn deduplicate_imports() {
let mut output: Vec<u8> = vec![];
let mut code = RustTestCode::new("test-component");
code.add_import("example::Value");
code.add_import("example::Value");
code.add_import("example::Value2");
RustLibGenerator { code: &code, copyright: false }
.write_file(&mut output)
.expect("write output");
let lines = std::str::from_utf8(&output)
.expect("output must be UTF-8")
.split("\n")
.collect::<Vec<_>>();
// lines[0] is the DO_NOT_EDIT banner (copyright is disabled). "use example::Value2;" sorts
// before "use example::Value;" because '2' < ';'.
assert_eq!(lines[1], "use example::Value2;");
assert_eq!(lines[2], "use example::Value;");
// The duplicated import must not appear a second time.
assert_ne!(lines[3], "use example::Value;");
}
// Despite its name, this checks the <test_name>.rs output of RustTestCodeGenerator, not lib.rs.
#[test]
fn generate_test_lib() {
let mut output: Vec<u8> = vec![];
let mut code = RustTestCode::new("test-component");
code.add_test_case("Echo");
code.add_mock_impl("test-component", "Echo");
RustTestCodeGenerator { code: &code, copyright: false }
.write_file(&mut output)
.expect("write output");
let lines = std::str::from_utf8(&output).expect("output must be UTF-8");
// The expected text depends on templates/template_rust_mock_function and
// templates/template_rust_test_function; update it whenever those templates change.
let expect = r#"use crate::lib::{TestComponentTest, Mocks};
use async_trait::async_trait;
use fuchsia_component::server::ServiceFs;
use fuchsia_component_test::LocalComponentHandles;
mod lib;
#[async_trait]
impl Mocks for TestComponentTest {
async fn test-component_impl(handles: LocalComponentHandles) -> Result<(), Error> {
let mut fs = ServiceFs::new();
// Implement mocked component below, ex:
// fs.dir("svc")
// .add_fidl_service(move |mut stream: EchoRequestStream| {
// // mock the fidl service
// })
// .add_fidl_service(move |mut stream: some_other_request| {
// // mock the fidl service
// });
fs.serve_connection(handles.outgoing_dir.into_channel()).unwrap();
fs.collect::<()>().await;
Ok(())
}
}
#[fuchsia::test]
async fn test_echomarker() {
let instance = TestComponentTest::create_realm().await.expect("setting up test realm");
let proxy = TestComponentTest::connect_to_echomarker(instance);
// Add test for Echo
}
"#;
assert_eq!(lines, expect);
}
}