blob: 49c15de04ddd56f925d151aceb90f30313928544 [file] [log] [blame]
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use assert_matches::assert_matches;
use proc_macro::TokenStream;
use proc_macro2::{TokenStream as TokenStream2, TokenTree};
use quote::{quote, ToTokens as _};
use syn::{
parse_quote, punctuated::Punctuated, spanned::Spanned, AngleBracketedGenericArguments,
GenericArgument, GenericParam, Generics, Ident, Type, TypeParam, TypeParamBound, TypePath,
};
/// Implements a derive macro for [`net_types::ip::GenericOverIp`].
/// Requires that #[derive(GenericOverIp)] invocations explicitly specify
/// which type parameter is the generic-over-ip one with the
/// `#[generic_over_ip]` attribute, rather than inferring it from the bounds
/// on the struct generics.
///
/// Consider the following example:
///
/// ```
/// #[derive(GenericOverIp)]
/// #[generic_over_ip(<ARGUMENTS EXPLAINED BELOW>)]
/// struct Foo<T>(T);
/// ```
///
/// `#[generic_over_ip(T, Ip)]` specifies that the GenericOverIp impl
/// should be written treating `T` as an `Ip` implementor (either `Ipv4` or
/// `Ipv6`).
///
/// `#[generic_over_ip(T, IpAddress)]` specifies that `T` is an `IpAddress`
/// implementor (`Ipv4Addr` or `Ipv6Addr`).
///
/// `#[generic_over_ip(T, GenericOverIp)]` specifies that `T` implements
/// `GenericOverIp<I>` for all `I: Ip`. (Notably, we'd like to use this case
/// for the Ip and IpAddress cases above, but we cannot due to issues
/// with conflicting blanket impls.)
///
/// `#[generic_over_ip()]` specifies that `Foo` is IP-invariant.
#[proc_macro_derive(GenericOverIp, attributes(generic_over_ip))]
pub fn derive_generic_over_ip(input: TokenStream) -> TokenStream {
let ast = syn::parse(input).unwrap();
impl_derive_generic_over_ip(&ast).into()
}
fn impl_derive_generic_over_ip(ast: &syn::DeriveInput) -> TokenStream2 {
let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
if where_clause.is_some() {
return quote! {
compile_error!("deriving GenericOverIp for types with 'where' clauses is unsupported")
}
.into();
}
let name = &ast.ident;
let specified_generic_over_ip = match find_generic_over_ip_attr(ast) {
Ok(param) => param,
Err(e) => return e.into_compile_error().into(),
};
let extra_bounds = match &specified_generic_over_ip {
None => Vec::new(),
Some((ident, _)) => match collect_bounds(ident, ast.generics.type_params()) {
Some(bounds) => bounds,
None => {
return syn::Error::new(
ast.generics.span(),
format!(
"found no type parameter named {ident:?} as specified in \
the generic_over_ip attribute"
),
)
.into_compile_error()
.into();
}
},
};
// Drop the first and last tokens, which should be '<' and '>', and the
// trailing comma if there is one.
let mut impl_generics = impl_generics.into_token_stream().into_iter();
let expect_trailing_angle_bracket = impl_generics.next().is_some_and(|first| {
assert_matches!(first, TokenTree::Punct(p) if p.as_char() == '<');
true
});
let mut impl_generics: Vec<_> = impl_generics.collect();
if expect_trailing_angle_bracket {
assert_matches!(impl_generics.pop(), Some(TokenTree::Punct(p)) if p.as_char() == '>');
}
// Add a trailing comma if `impl_generics` is non-empty and doesn't have one.
match impl_generics.last() {
Some(TokenTree::Punct(p)) if p.as_char() == ',' => {}
None => {}
Some(_) => {
impl_generics.push(parse_quote! { , });
}
}
let impl_generics = impl_generics.into_iter().collect::<TokenStream2>();
match specified_generic_over_ip.clone() {
Some((ident, param_type)) => {
// Emit an impl that substitutes the identified type parameter
// to produce the new GenericOverIp::Type.
let generic_ip_name: Ident = parse_quote!(IpType);
let extra_bounds_target: TypePath = match param_type {
IpGenericParamType::IpVersion => parse_quote!(#generic_ip_name),
IpGenericParamType::IpAddress => parse_quote!(#generic_ip_name::Addr),
IpGenericParamType::GenericOverIp => parse_quote! {
<#ident as GenericOverIp<IpType>>::Type
},
};
let bound_if_generic_over_ip = match param_type {
IpGenericParamType::IpVersion | IpGenericParamType::IpAddress => None,
IpGenericParamType::GenericOverIp => Some(quote! {
#ident: GenericOverIp<IpType>,
}),
};
let generic_bounds = with_type_param_replaced(
&ast.generics,
&ident,
parse_quote! {
#extra_bounds_target
},
);
quote! {
impl <#impl_generics #generic_ip_name: Ip>
GenericOverIp<IpType> for #name #type_generics
where #bound_if_generic_over_ip #extra_bounds_target: #(#extra_bounds)+*, {
type Type = #name #generic_bounds;
}
}
}
None => {
// The type is IP-invariant so `GenericOverIp::Type` is always Self.`
quote! {
impl <#impl_generics IpType: Ip> GenericOverIp<IpType> for #name #type_generics {
type Type = Self;
}
}
}
}
}
#[derive(Debug, Clone)]
enum IpGenericParamType {
IpVersion,
IpAddress,
GenericOverIp,
}
fn find_generic_over_ip_attr(
ast: &syn::DeriveInput,
) -> Result<Option<(Ident, IpGenericParamType)>, syn::Error> {
let mut attrs = ast
.attrs
.iter()
.filter(|attr| attr.path.get_ident().map(|i| i == "generic_over_ip").unwrap_or(false));
let attr = attrs.next().ok_or_else(|| {
syn::Error::new(
ast.ident.span(),
"derive(GenericOverIp) cannot be used without the generic_over_ip attribute",
)
})?;
match attrs.next() {
None => {}
Some(attr) => {
return Err(syn::Error::new(
attr.span(),
"derive(GenericOverIp) cannot be used with multiple generic_over_ip attributes",
))
}
}
let meta = attr.parse_meta().map_err(|e| {
syn::Error::new(attr.span(), format!("generic_over_ip attr did not parse as Meta: {e:?}"))
})?;
let list = match meta {
syn::Meta::Path(_) | syn::Meta::NameValue(_) => {
return Err(syn::Error::new(
meta.span(),
"generic_over_ip must be passed at most one\
type parameter identifier",
));
}
syn::Meta::List(list) => list,
};
if list.nested.is_empty() {
return Ok(None);
} else if list.nested.len() != 2 {
return Err(syn::Error::new(
list.span(),
"generic_over_ip must be either be passed no \
arguments, or one type parameter identifier and its \
trait bound (Ip, IpAddress, or GenericOverIp)",
));
}
let mut iter = list.nested.into_iter();
let (ident, bound) = (iter.next().unwrap(), iter.next().unwrap());
let ident = match ident {
syn::NestedMeta::Meta(meta) => match meta {
syn::Meta::Path(path) => match path.get_ident() {
None => Err(syn::Error::new(
path.span(),
"generic_over_ip must be passed a parameter identifier",
)),
Some(ident) => Ok(ident.clone()),
},
syn::Meta::List(_) | syn::Meta::NameValue(_) => Err(syn::Error::new(
meta.span(),
"generic_over_ip must be passed at most one \
type parameter identifier, not a list or name-value pair",
)),
},
syn::NestedMeta::Lit(lit) => Err(syn::Error::new(
lit.span(),
"generic_over_ip must be passed at most one \
type parameter identifier, not a literal",
)),
}?;
let bound_error_message = "the bound passed to generic_over_ip \
must be Ip, IpAddress, or GenericOverIp";
let bound = match bound {
syn::NestedMeta::Meta(meta) => match meta {
syn::Meta::Path(path) => match path.get_ident() {
None => Err(syn::Error::new(path.span(), bound_error_message)),
Some(ident) => Ok(ident.clone()),
},
syn::Meta::List(_) | syn::Meta::NameValue(_) => {
Err(syn::Error::new(meta.span(), bound_error_message))
}
},
syn::NestedMeta::Lit(lit) => Err(syn::Error::new(lit.span(), bound_error_message)),
}?;
let bound = match bound.to_string().as_str() {
"Ip" => IpGenericParamType::IpVersion,
"IpAddress" => IpGenericParamType::IpAddress,
"GenericOverIp" => IpGenericParamType::GenericOverIp,
_ => return Err(syn::Error::new(bound.span(), bound_error_message)),
};
Ok(Some((ident, bound)))
}
fn collect_bounds<'a>(
ident: &'a Ident,
mut generics: impl Iterator<Item = &'a TypeParam>,
) -> Option<Vec<&'a TypeParamBound>> {
generics.find_map(|t| if &t.ident == ident { Some(t.bounds.iter().collect()) } else { None })
}
fn with_type_param_replaced(
generics: &Generics,
to_find: &Ident,
replacement: TypePath,
) -> Option<AngleBracketedGenericArguments> {
let args: Punctuated<_, _> = generics
.params
.iter()
.map(|g| match g {
GenericParam::Const(c) => GenericArgument::Const(parse_quote!(#c.ident)),
GenericParam::Lifetime(l) => GenericArgument::Lifetime(l.lifetime.clone()),
GenericParam::Type(t) => {
if &t.ident == to_find {
GenericArgument::Type(Type::Path(replacement.clone()))
} else {
GenericArgument::Type(Type::Path(TypePath {
path: t.ident.clone().into(),
qself: None,
}))
}
}
})
.collect();
(args.len() != 0).then(|| AngleBracketedGenericArguments {
args,
colon2_token: None,
lt_token: Default::default(),
gt_token: Default::default(),
})
}