blob: e9656b6f5e6130161741b506950f4a853dead39d [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use {
crate::common::{extract_expected, gen_visit_str, ident_from_path},
maplit::hashset,
proc_macro2::{Ident, TokenStream as TokenStream2},
quote::{quote, ToTokens, TokenStreamExt},
std::collections::HashSet,
syn,
};
pub fn impl_derive_ref(ast: syn::DeriveInput) -> Result<TokenStream2, syn::Error> {
let attrs = parse_reference_attributes(&ast)?;
struct Display<'a> {
name: &'a Ident,
variants: &'a HashSet<String>,
}
impl ToTokens for Display<'_> {
fn to_tokens(&self, tokens: &mut TokenStream2) {
let name = self.name;
let inner = {
let mut tokens = quote!();
if self.variants.contains("Named") {
tokens.append_all(quote! {
Self::Named(name) => write!(f, "#{}", name),
});
}
if self.variants.contains("Parent") {
tokens.append_all(quote! {
Self::Parent => write!(f, "parent"),
});
}
if self.variants.contains("Framework") {
tokens.append_all(quote! {
Self::Framework => write!(f, "framework"),
});
}
if self.variants.contains("Debug") {
tokens.append_all(quote! {
Self::Debug => write!(f, "debug"),
});
}
if self.variants.contains("Self_") {
tokens.append_all(quote! {
Self::Self_ => write!(f, "self"),
});
}
if self.variants.contains("Void") {
tokens.append_all(quote! {
Self::Void => write!(f, "void"),
});
}
tokens
};
tokens.append_all(quote! {
impl<'a> std::fmt::Display for #name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#inner
}
}
}
});
}
}
struct FromStr<'a> {
name: &'a Ident,
variants: &'a HashSet<String>,
}
impl ToTokens for FromStr<'_> {
fn to_tokens(&self, tokens: &mut TokenStream2) {
let name = self.name;
let inner = {
let mut tokens = quote!();
tokens.append_all(quote! {
if value.len() == 0 || value.len() > 101 {
return Err(ParseError::InvalidLength);
}
});
if self.variants.contains("Named") {
tokens.append_all(quote! {
if value.starts_with("#") {
return value[1..]
.parse::<Name>()
.map(Self::Named)
.map_err(|_| ParseError::InvalidValue);
}
});
}
let inner = {
let mut tokens = quote!();
if self.variants.contains("Parent") {
tokens.append_all(quote! {
"parent" => Ok(Self::Parent),
});
}
if self.variants.contains("Framework") {
tokens.append_all(quote! {
"framework" => Ok(Self::Framework),
});
}
if self.variants.contains("Debug") {
tokens.append_all(quote! {
"debug" => Ok(Self::Debug),
});
}
if self.variants.contains("Self_") {
tokens.append_all(quote! {
"self" => Ok(Self::Self_),
});
}
if self.variants.contains("Void") {
tokens.append_all(quote! {
"void" => Ok(Self::Void),
});
}
tokens.append_all(quote! {_ => Err(ParseError::InvalidValue) });
tokens
};
tokens.append_all(quote! {
match value {
#inner
}
});
tokens
};
tokens.append_all(quote! {
impl std::str::FromStr for #name {
type Err = ParseError;
fn from_str(value: &str) -> Result<Self, Self::Err> {
#inner
}
}
});
}
}
struct AnyRefFrom<'a> {
name: &'a Ident,
variants: &'a HashSet<String>,
}
impl ToTokens for AnyRefFrom<'_> {
fn to_tokens(&self, tokens: &mut TokenStream2) {
let name = self.name;
let inner = {
let mut tokens = quote!();
if self.variants.contains("Parent") {
tokens.append_all(quote! {
#name::Parent => Self::Parent,
});
}
if self.variants.contains("Framework") {
tokens.append_all(quote! {
#name::Framework => Self::Framework,
});
}
if self.variants.contains("Debug") {
tokens.append_all(quote! {
#name::Debug => Self::Debug,
});
}
if self.variants.contains("Self_") {
tokens.append_all(quote! {
#name::Self_ => Self::Self_,
});
}
if self.variants.contains("Void") {
tokens.append_all(quote! {
#name::Void => Self::Void,
});
}
if self.variants.contains("Named") {
tokens.append_all(quote! {
#name::Named(ref n) => Self::Named(n),
});
}
tokens
};
tokens.append_all(quote! {
impl<'a> From<&'a #name> for AnyRef<'a> {
fn from(r: &'a #name) -> Self {
match r {
#inner
}
}
}
});
}
}
struct DeserializeAndSerialize<'a> {
name: &'a Ident,
expected: &'a syn::LitStr,
}
impl ToTokens for DeserializeAndSerialize<'_> {
fn to_tokens(&self, tokens: &mut TokenStream2) {
let name = self.name;
let expected = self.expected;
let visit_str = gen_visit_str(None, self.expected);
tokens.append_all(quote! {
impl<'de> serde::de::Deserialize<'de> for #name {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
struct RefVisitor;
impl<'de> serde::de::Visitor<'de> for RefVisitor {
type Value = #name;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str(#expected)
}
#visit_str
}
deserializer.deserialize_str(RefVisitor)
}
}
impl serde::ser::Serialize for #name {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::ser::Serializer
{
self.to_string().serialize(serializer)
}
}
});
}
}
let display = Display { name: &attrs.name, variants: &attrs.variants };
let from_str = FromStr { name: &attrs.name, variants: &attrs.variants };
let anyref_from = AnyRefFrom { name: &attrs.name, variants: &attrs.variants };
let deserialize_and_serialize =
DeserializeAndSerialize { name: &attrs.name, expected: &attrs.expected };
let tokens = quote! {
#display
#from_str
#anyref_from
#deserialize_and_serialize
};
Ok(tokens)
}
/// Attributes extracted from the input of the reference derive macro: the enum itself and its
/// `#[reference(...)]` attribute.
struct ReferenceAttributes {
/// Name of the reference enum.
name: Ident,
/// `expecting` string to return from the deserializer.
expected: syn::LitStr,
/// Identifiers of the variants declared on the enum, exactly as written in the source.
/// The parser accepts only "Named", "Parent", "Framework", "Debug", "Self_" and "Void".
variants: HashSet<String>,
}
fn parse_reference_attributes(ast: &syn::DeriveInput) -> Result<ReferenceAttributes, syn::Error> {
let variants: HashSet<_> = if let syn::Data::Enum(enum_) = &ast.data {
enum_.variants.iter().map(|v| v.ident.to_string()).collect()
} else {
return Err(syn::Error::new_spanned(ast, "Ref must be derived on an enum."));
};
let allowed_variants = hashset! {
"Named",
"Parent",
"Framework",
"Debug",
"Self_",
"Void",
};
for ty in variants.iter() {
if !allowed_variants.contains(ty as &str) {
return Err(syn::Error::new_spanned(
ast,
&format!("enum variant not supported: {}", ty),
));
}
}
let mut expected = None;
for attr in &ast.attrs {
if !attr.path.is_ident("reference") {
continue;
}
match attr
.parse_meta()
.map_err(|_| syn::Error::new_spanned(ast, "`reference` attribute is not valid"))?
{
syn::Meta::List(l) => {
for attr in l.nested {
match attr {
syn::NestedMeta::Meta(syn::Meta::NameValue(attr)) => {
let ident = ident_from_path(&attr.path);
match &ident as &str {
"expected" => extract_expected(ast, attr, &mut expected)?,
_ => {
return Err(syn::Error::new_spanned(
ast,
"`reference` attribute is not valid",
));
}
}
}
_ => {
return Err(syn::Error::new_spanned(
ast,
"`reference` attribute must contain name-value pairs",
))?
}
}
}
}
_ => {
return Err(syn::Error::new_spanned(
ast,
"`reference` attribute value must be a list",
));
}
}
}
let expected =
expected.ok_or_else(|| syn::Error::new_spanned(ast, "`expected` attribute is missing"))?;
Ok(ReferenceAttributes { name: ast.ident.clone(), expected, variants })
}