// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use assert_matches::assert_matches;
use proc_macro::TokenStream;
use proc_macro2::{TokenStream as TokenStream2, TokenTree};
use quote::{quote, ToTokens as _};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{
    parse_quote, AngleBracketedGenericArguments, GenericArgument, GenericParam, Generics, Ident,
    Type, TypeParam, TypeParamBound, TypePath,
};

/// Implements a derive macro for [`net_types::ip::GenericOverIp`].
/// Requires that `#[derive(GenericOverIp)]` invocations explicitly specify
/// which type parameter is the generic-over-ip one with the
/// `#[generic_over_ip]` attribute, rather than inferring it from the bounds
/// on the struct generics.
///
/// Consider the following example:
///
/// ```ignore
///  #[derive(GenericOverIp)]
///  #[generic_over_ip(<ARGUMENTS EXPLAINED BELOW>)]
///  struct Foo<T>(T);
/// ```
///
/// `#[generic_over_ip(T, Ip)]` specifies that the GenericOverIp impl
/// should be written treating `T` as an `Ip` implementor (either `Ipv4` or
/// `Ipv6`).
///
/// `#[generic_over_ip(T, IpAddress)]` specifies that `T` is an `IpAddress`
/// implementor (`Ipv4Addr` or `Ipv6Addr`).
///
/// `#[generic_over_ip(T, GenericOverIp)]` specifies that `T` implements
/// `GenericOverIp<I>` for all `I: Ip`. (Notably, we'd like to use this case
/// for the Ip and IpAddress cases above, but we cannot due to issues
/// with conflicting blanket impls.)
///
/// `#[generic_over_ip()]` specifies that `Foo` is IP-invariant.
///
/// Exactly one `#[generic_over_ip]` attribute must be present. Types that
/// declare a `where` clause are not supported. Misuse is reported as a
/// compile error at the derive site.
#[proc_macro_derive(GenericOverIp, attributes(generic_over_ip))]
pub fn derive_generic_over_ip(input: TokenStream) -> TokenStream {
    // `parse_macro_input!` reports unparseable input as a compiler diagnostic
    // instead of panicking inside the macro.
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);

    impl_derive_generic_over_ip(&ast).into()
}

fn impl_derive_generic_over_ip(ast: &syn::DeriveInput) -> TokenStream2 {
    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
    if where_clause.is_some() {
        return quote! {
            compile_error!("deriving GenericOverIp for types with 'where' clauses is unsupported")
        }
        .into();
    }

    let name = &ast.ident;

    let specified_generic_over_ip = match find_generic_over_ip_attr(ast) {
        Ok(param) => param,
        Err(e) => return e.into_compile_error().into(),
    };

    let extra_bounds = match &specified_generic_over_ip {
        None => Vec::new(),
        Some((ident, _)) => match collect_bounds(ident, ast.generics.type_params()) {
            Some(bounds) => bounds,
            None => {
                return syn::Error::new(
                    ast.generics.span(),
                    format!(
                        "found no type parameter named {ident:?} as specified in \
                            the generic_over_ip attribute"
                    ),
                )
                .into_compile_error()
                .into();
            }
        },
    };

    // Drop the first and last tokens, which should be '<' and '>', and the
    // trailing comma if there is one.
    let mut impl_generics = impl_generics.into_token_stream().into_iter();

    let expect_trailing_angle_bracket = impl_generics.next().is_some_and(|first| {
        assert_matches!(first, TokenTree::Punct(p) if p.as_char() == '<');
        true
    });
    let mut impl_generics: Vec<_> = impl_generics.collect();
    if expect_trailing_angle_bracket {
        assert_matches!(impl_generics.pop(), Some(TokenTree::Punct(p)) if p.as_char() == '>');
    }

    // Add a trailing comma if `impl_generics` is non-empty and doesn't have one.
    match impl_generics.last() {
        Some(TokenTree::Punct(p)) if p.as_char() == ',' => {}
        None => {}
        Some(_) => {
            impl_generics.push(parse_quote! { , });
        }
    }

    let impl_generics = impl_generics.into_iter().collect::<TokenStream2>();

    match specified_generic_over_ip.clone() {
        Some((ident, param_type)) => {
            // Emit an impl that substitutes the identified type parameter
            // to produce the new GenericOverIp::Type.
            let generic_ip_name: Ident = parse_quote!(IpType);
            let extra_bounds_target: TypePath = match param_type {
                IpGenericParamType::IpVersion => parse_quote!(#generic_ip_name),
                IpGenericParamType::IpAddress => parse_quote!(#generic_ip_name::Addr),
                IpGenericParamType::GenericOverIp => parse_quote! {
                    <#ident as GenericOverIp<IpType>>::Type
                },
            };

            let bound_if_generic_over_ip = match param_type {
                IpGenericParamType::IpVersion | IpGenericParamType::IpAddress => None,
                IpGenericParamType::GenericOverIp => Some(quote! {
                    #ident: GenericOverIp<IpType>,
                }),
            };

            let generic_bounds = with_type_param_replaced(
                &ast.generics,
                &ident,
                parse_quote! {
                    #extra_bounds_target
                },
            );

            quote! {
                impl <#impl_generics #generic_ip_name: Ip>
                GenericOverIp<IpType> for #name #type_generics
                where #bound_if_generic_over_ip #extra_bounds_target: #(#extra_bounds)+*, {
                    type Type = #name #generic_bounds;
                }
            }
        }
        None => {
            // The type is IP-invariant so `GenericOverIp::Type` is always Self.`
            quote! {
                impl <#impl_generics IpType: Ip> GenericOverIp<IpType> for #name #type_generics {
                    type Type = Self;
                }
            }
        }
    }
}

/// The trait bound that a `#[generic_over_ip(T, Bound)]` attribute declares
/// for the generic-over-IP type parameter `T`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IpGenericParamType {
    /// `Bound` is `Ip`: `T` is an IP version type.
    IpVersion,
    /// `Bound` is `IpAddress`: `T` is an IP address type.
    IpAddress,
    /// `Bound` is `GenericOverIp`: `T` implements `GenericOverIp<I>`.
    GenericOverIp,
}

fn find_generic_over_ip_attr(
    ast: &syn::DeriveInput,
) -> Result<Option<(Ident, IpGenericParamType)>, syn::Error> {
    let mut attrs = ast
        .attrs
        .iter()
        .filter(|attr| attr.path.get_ident().map(|i| i == "generic_over_ip").unwrap_or(false));
    let attr = attrs.next().ok_or_else(|| {
        syn::Error::new(
            ast.ident.span(),
            "derive(GenericOverIp) cannot be used without the generic_over_ip attribute",
        )
    })?;
    match attrs.next() {
        None => {}
        Some(attr) => {
            return Err(syn::Error::new(
                attr.span(),
                "derive(GenericOverIp) cannot be used with multiple generic_over_ip attributes",
            ))
        }
    }

    let meta = attr.parse_meta().map_err(|e| {
        syn::Error::new(attr.span(), format!("generic_over_ip attr did not parse as Meta: {e:?}"))
    })?;

    let list = match meta {
        syn::Meta::Path(_) | syn::Meta::NameValue(_) => {
            return Err(syn::Error::new(
                meta.span(),
                "generic_over_ip must be passed at most one\
                type parameter identifier",
            ));
        }
        syn::Meta::List(list) => list,
    };

    if list.nested.is_empty() {
        return Ok(None);
    } else if list.nested.len() != 2 {
        return Err(syn::Error::new(
            list.span(),
            "generic_over_ip must be either be passed no \
                         arguments, or one type parameter identifier and its \
                         trait bound (Ip, IpAddress, or GenericOverIp)",
        ));
    }

    let mut iter = list.nested.into_iter();
    let (ident, bound) = (iter.next().unwrap(), iter.next().unwrap());

    let ident = match ident {
        syn::NestedMeta::Meta(meta) => match meta {
            syn::Meta::Path(path) => match path.get_ident() {
                None => Err(syn::Error::new(
                    path.span(),
                    "generic_over_ip must be passed a parameter identifier",
                )),
                Some(ident) => Ok(ident.clone()),
            },
            syn::Meta::List(_) | syn::Meta::NameValue(_) => Err(syn::Error::new(
                meta.span(),
                "generic_over_ip must be passed at most one \
                             type parameter identifier, not a list or name-value pair",
            )),
        },
        syn::NestedMeta::Lit(lit) => Err(syn::Error::new(
            lit.span(),
            "generic_over_ip must be passed at most one \
                        type parameter identifier, not a literal",
        )),
    }?;

    let bound_error_message = "the bound passed to generic_over_ip \
                                             must be Ip, IpAddress, or GenericOverIp";

    let bound = match bound {
        syn::NestedMeta::Meta(meta) => match meta {
            syn::Meta::Path(path) => match path.get_ident() {
                None => Err(syn::Error::new(path.span(), bound_error_message)),
                Some(ident) => Ok(ident.clone()),
            },
            syn::Meta::List(_) | syn::Meta::NameValue(_) => {
                Err(syn::Error::new(meta.span(), bound_error_message))
            }
        },
        syn::NestedMeta::Lit(lit) => Err(syn::Error::new(lit.span(), bound_error_message)),
    }?;

    let bound = match bound.to_string().as_str() {
        "Ip" => IpGenericParamType::IpVersion,
        "IpAddress" => IpGenericParamType::IpAddress,
        "GenericOverIp" => IpGenericParamType::GenericOverIp,
        _ => return Err(syn::Error::new(bound.span(), bound_error_message)),
    };

    Ok(Some((ident, bound)))
}

/// Returns the bounds declared inline on the type parameter named `ident`
/// (e.g. `Ip` for `<I: Ip>`), or `None` if `generics` yields no type
/// parameter with that name.
///
/// Only the bounds in the parameter list itself are collected; `where`-clause
/// bounds are not consulted. The returned references borrow from `generics`
/// only, not from `ident`, so callers may drop `ident` while keeping them.
fn collect_bounds<'a>(
    ident: &Ident,
    mut generics: impl Iterator<Item = &'a TypeParam>,
) -> Option<Vec<&'a TypeParamBound>> {
    generics.find(|param| param.ident == *ident).map(|param| param.bounds.iter().collect())
}

/// Builds the generic argument list (`<...>`) naming every parameter of
/// `generics` in declaration order. The type parameter named `to_find` is
/// replaced by `replacement`.
///
/// Lifetime and const parameters, and all other type parameters, are passed
/// through by name. Returns `None` when `generics` has no parameters, so that
/// interpolating the result emits nothing.
fn with_type_param_replaced(
    generics: &Generics,
    to_find: &Ident,
    replacement: TypePath,
) -> Option<AngleBracketedGenericArguments> {
    let args: Punctuated<_, _> = generics
        .params
        .iter()
        .map(|param| match param {
            GenericParam::Lifetime(l) => GenericArgument::Lifetime(l.lifetime.clone()),
            GenericParam::Const(c) => {
                // `quote!` cannot interpolate a field access: `#c.ident` would
                // splice the whole `const N: T` parameter followed by the
                // literal tokens `.ident`, which does not parse as an
                // expression. Bind the identifier first instead.
                let const_ident = &c.ident;
                GenericArgument::Const(parse_quote!(#const_ident))
            }
            GenericParam::Type(t) if t.ident == *to_find => {
                GenericArgument::Type(Type::Path(replacement.clone()))
            }
            GenericParam::Type(t) => GenericArgument::Type(Type::Path(TypePath {
                qself: None,
                path: t.ident.clone().into(),
            })),
        })
        .collect();

    (!args.is_empty()).then(|| AngleBracketedGenericArguments {
        colon2_token: None,
        lt_token: Default::default(),
        args,
        gt_token: Default::default(),
    })
}
