blob: 9a9ed90fe59b31e22065951067d15bd6e4e1f5b1 [file] [log] [blame]
use proc_macro2::TokenStream;
use quote::quote;
use quote::ToTokens;
use syn::parse::{Parse, ParseStream, Result};
use syn::punctuated::Punctuated;
use syn::{
parenthesized, parse_quote, Attribute, GenericArgument, Generics, Ident, Meta, MetaNameValue,
PathArguments, Token, Type,
};
/// Expands a parsed `sql_function!` declaration into a constructor function
/// plus a module of the same name holding the expression struct and its impls.
///
/// Generated items:
/// * `pub fn <name>(...)`: converts every argument with `AsExpression` and
///   builds the expression struct.
/// * `mod <name>`, containing:
///   * the struct `<name>` and a `HelperType` alias;
///   * impls of `Expression`, `SelectableExpression`, `AppearsOnTable`,
///     `QueryFragment` and `ValidGrouping`.
/// * SQLite registration helpers, generated only when all of these hold:
///   * this crate is built with the `sqlite` feature;
///   * there are no explicit type parameters;
///   * the return type and every argument type pass `is_sqlite_type`.
///
///   The helpers are `register_impl`, plus `register_nondeterministic_impl`
///   for non-aggregate functions.
///
/// Two macro-specific attributes are consumed and not forwarded:
/// * `#[sql_name = "..."]` overrides the SQL-level name. It defaults to the
///   Rust function name, and a non-string-literal value is ignored.
/// * `#[aggregate]` marks the function as an aggregate for `ValidGrouping`.
pub(crate) fn expand(input: SqlFunctionDecl) -> TokenStream {
    let SqlFunctionDecl {
        mut attributes,
        fn_token,
        fn_name,
        mut generics,
        args,
        return_type,
    } = input;

    // Only `#[sql_name = "literal"]` is honoured; any other form falls back
    // to the Rust function name.
    let sql_name = attributes
        .iter()
        .find(|attr| attr.meta.path().is_ident("sql_name"))
        .and_then(|attr| match attr.meta {
            Meta::NameValue(MetaNameValue {
                value:
                    syn::Expr::Lit(syn::ExprLit {
                        lit: syn::Lit::Str(ref lit),
                        ..
                    }),
                ..
            }) => Some(lit.value()),
            _ => None,
        })
        .unwrap_or_else(|| fn_name.to_string());

    let is_aggregate = attributes
        .iter()
        .any(|attr| attr.meta.path().is_ident("aggregate"));

    // These attributes are instructions to this macro, not to rustc, so they
    // must not be copied onto the generated function.
    attributes.retain(|attr| {
        !attr.meta.path().is_ident("sql_name") && !attr.meta.path().is_ident("aggregate")
    });

    let args = &args;
    let (ref arg_name, ref arg_type): (Vec<_>, Vec<_>) = args
        .iter()
        .map(|StrictFnArg { name, ty, .. }| (name, ty))
        .unzip();

    // `quote!` may interpolate the same variable more than once, so no clone
    // of `name` is needed here.
    let arg_struct_assign = args.iter().map(
        |StrictFnArg {
             name, colon_token, ..
         }| quote!(#name #colon_token #name.as_expression()),
    );

    let type_args = &generics
        .type_params()
        .map(|type_param| type_param.ident.clone())
        .collect::<Vec<_>>();

    // Every argument also becomes a generic type parameter named after the
    // argument itself. `StrictFnArg`'s `ToTokens` impl relies on this when it
    // emits `name: name`.
    for StrictFnArg { name, .. } in args {
        generics.params.push(parse_quote!(#name));
    }

    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    // Even if we force an empty where clause, it still won't print the where
    // token with no bounds.
    let where_clause = where_clause
        .map(|w| quote!(#w))
        .unwrap_or_else(|| quote!(where));

    let mut generics_with_internal = generics.clone();
    generics_with_internal
        .params
        .push(parse_quote!(__DieselInternal));
    let (impl_generics_internal, _, _) = generics_with_internal.split_for_impl();

    let (sql_type, numeric_derive) = if arg_name.is_empty() {
        // FIXME: We can always derive once trivial bounds are stable
        (None, None)
    } else {
        (
            Some(quote!((#(#arg_name),*): Expression,)),
            Some(quote!(#[derive(diesel::sql_types::DieselNumericOps)])),
        )
    };

    let args_iter = args.iter();
    let mut tokens = quote! {
        use diesel::{self, QueryResult};
        use diesel::expression::{AsExpression, Expression, SelectableExpression, AppearsOnTable, ValidGrouping};
        use diesel::query_builder::{QueryFragment, AstPass};
        use diesel::sql_types::*;
        use super::*;

        #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId)]
        #numeric_derive
        pub struct #fn_name #ty_generics {
            #(pub(in super) #args_iter,)*
            #(pub(in super) #type_args: ::std::marker::PhantomData<#type_args>,)*
        }

        pub type HelperType #ty_generics = #fn_name <
            #(#type_args,)*
            #(<#arg_name as AsExpression<#arg_type>>::Expression,)*
        >;

        impl #impl_generics Expression for #fn_name #ty_generics
        #where_clause
            #sql_type
        {
            type SqlType = #return_type;
        }

        // __DieselInternal is what we call QS normally
        impl #impl_generics_internal SelectableExpression<__DieselInternal>
            for #fn_name #ty_generics
        #where_clause
            #(#arg_name: SelectableExpression<__DieselInternal>,)*
            Self: AppearsOnTable<__DieselInternal>,
        {
        }

        // __DieselInternal is what we call QS normally
        impl #impl_generics_internal AppearsOnTable<__DieselInternal>
            for #fn_name #ty_generics
        #where_clause
            #(#arg_name: AppearsOnTable<__DieselInternal>,)*
            Self: Expression,
        {
        }

        // __DieselInternal is what we call DB normally
        impl #impl_generics_internal QueryFragment<__DieselInternal>
            for #fn_name #ty_generics
        where
            __DieselInternal: diesel::backend::Backend,
            #(#arg_name: QueryFragment<__DieselInternal>,)*
        {
            #[allow(unused_assignments)]
            fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()>{
                out.push_sql(concat!(#sql_name, "("));
                // we unroll the arguments manually here, to prevent borrow check issues
                let mut needs_comma = false;
                #(
                    if !self.#arg_name.is_noop(out.backend())? {
                        if needs_comma {
                            out.push_sql(", ");
                        }
                        self.#arg_name.walk_ast(out.reborrow())?;
                        needs_comma = true;
                    }
                )*
                out.push_sql(")");
                Ok(())
            }
        }
    };

    let is_supported_on_sqlite = cfg!(feature = "sqlite")
        && type_args.is_empty()
        && is_sqlite_type(&return_type)
        && arg_type.iter().all(|a| is_sqlite_type(a));

    // Generated code below uses absolute `::std::...` paths. The generated
    // module glob-imports `super::*`, so a relative `std` could be shadowed
    // by, or ambiguous with, a user item.
    if is_aggregate {
        tokens = quote! {
            #tokens

            impl #impl_generics_internal ValidGrouping<__DieselInternal>
                for #fn_name #ty_generics
            {
                type IsAggregate = diesel::expression::is_aggregate::Yes;
            }
        };

        if is_supported_on_sqlite {
            tokens = quote! {
                #tokens

                use diesel::sqlite::{Sqlite, SqliteConnection};
                use diesel::serialize::ToSql;
                use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
                use diesel::sqlite::SqliteAggregateFunction;
                use diesel::sql_types::IntoNullable;
            };

            match arg_name.len() {
                // No registration helper is generated for zero-argument aggregates.
                0 => {}
                // A single argument is passed to the aggregate un-tupled.
                1 => {
                    let arg_name = arg_name[0];
                    let arg_type = arg_type[0];

                    tokens = quote! {
                        #tokens

                        #[allow(dead_code)]
                        /// Registers an implementation for this aggregate function on the given connection
                        ///
                        /// This function must be called for every `SqliteConnection` before
                        /// this SQL function can be used on SQLite. The implementation must be
                        /// deterministic (returns the same result given the same arguments).
                        pub fn register_impl<A, #arg_name>(
                            conn: &mut SqliteConnection
                        ) -> QueryResult<()>
                            where
                            A: SqliteAggregateFunction<#arg_name>
                                + Send
                                + 'static
                                + ::std::panic::UnwindSafe
                                + ::std::panic::RefUnwindSafe,
                            A::Output: ToSql<#return_type, Sqlite>,
                            #arg_name: FromSqlRow<#arg_type, Sqlite> +
                                StaticallySizedRow<#arg_type, Sqlite> +
                                ::std::panic::UnwindSafe,
                        {
                            conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name)
                        }
                    };
                }
                // Multiple arguments are passed to the aggregate as a tuple.
                _ => {
                    tokens = quote! {
                        #tokens

                        #[allow(dead_code)]
                        /// Registers an implementation for this aggregate function on the given connection
                        ///
                        /// This function must be called for every `SqliteConnection` before
                        /// this SQL function can be used on SQLite. The implementation must be
                        /// deterministic (returns the same result given the same arguments).
                        pub fn register_impl<A, #(#arg_name,)*>(
                            conn: &mut SqliteConnection
                        ) -> QueryResult<()>
                            where
                            A: SqliteAggregateFunction<(#(#arg_name,)*)>
                                + Send
                                + 'static
                                + ::std::panic::UnwindSafe
                                + ::std::panic::RefUnwindSafe,
                            A::Output: ToSql<#return_type, Sqlite>,
                            (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
                                StaticallySizedRow<(#(#arg_type,)*), Sqlite> +
                                ::std::panic::UnwindSafe,
                        {
                            conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name)
                        }
                    };
                }
            }
        }
    } else {
        // Non-aggregate functions are aggregate iff their arguments are; the
        // helper tuple struct lets `#[derive(ValidGrouping)]` compute that.
        tokens = quote! {
            #tokens

            #[derive(ValidGrouping)]
            pub struct __Derived<#(#arg_name,)*>(#(#arg_name,)*);

            impl #impl_generics_internal ValidGrouping<__DieselInternal>
                for #fn_name #ty_generics
            where
                __Derived<#(#arg_name,)*>: ValidGrouping<__DieselInternal>,
            {
                type IsAggregate = <__Derived<#(#arg_name,)*> as ValidGrouping<__DieselInternal>>::IsAggregate;
            }
        };

        if is_supported_on_sqlite {
            if arg_name.is_empty() {
                // Zero-argument functions go through `register_noarg_sql_function`.
                // These helpers keep taking `&SqliteConnection` so existing
                // callers that pass a shared reference keep compiling.
                tokens = quote! {
                    #tokens

                    use diesel::sqlite::{Sqlite, SqliteConnection};
                    use diesel::serialize::ToSql;

                    #[allow(dead_code)]
                    /// Registers an implementation for this function on the given connection
                    ///
                    /// This function must be called for every `SqliteConnection` before
                    /// this SQL function can be used on SQLite. The implementation must be
                    /// deterministic (returns the same result given the same arguments). If
                    /// the function is nondeterministic, call
                    /// `register_nondeterministic_impl` instead.
                    pub fn register_impl<F, Ret>(
                        conn: &SqliteConnection,
                        f: F,
                    ) -> QueryResult<()>
                    where
                        F: Fn() -> Ret + ::std::panic::UnwindSafe + Send + 'static,
                        Ret: ToSql<#return_type, Sqlite>,
                    {
                        conn.register_noarg_sql_function::<#return_type, _, _>(
                            #sql_name,
                            true,
                            f,
                        )
                    }

                    #[allow(dead_code)]
                    /// Registers an implementation for this function on the given connection
                    ///
                    /// This function must be called for every `SqliteConnection` before
                    /// this SQL function can be used on SQLite.
                    /// `register_nondeterministic_impl` should only be used if your
                    /// function can return different results with the same arguments (e.g.
                    /// `random`). If your function is deterministic, you should call
                    /// `register_impl` instead.
                    pub fn register_nondeterministic_impl<F, Ret>(
                        conn: &SqliteConnection,
                        f: F,
                    ) -> QueryResult<()>
                    where
                        F: FnMut() -> Ret + ::std::panic::UnwindSafe + Send + 'static,
                        Ret: ToSql<#return_type, Sqlite>,
                    {
                        conn.register_noarg_sql_function::<#return_type, _, _>(
                            #sql_name,
                            false,
                            f,
                        )
                    }
                };
            } else {
                tokens = quote! {
                    #tokens

                    use diesel::sqlite::{Sqlite, SqliteConnection};
                    use diesel::serialize::ToSql;
                    use diesel::deserialize::{FromSqlRow, StaticallySizedRow};

                    #[allow(dead_code)]
                    /// Registers an implementation for this function on the given connection
                    ///
                    /// This function must be called for every `SqliteConnection` before
                    /// this SQL function can be used on SQLite. The implementation must be
                    /// deterministic (returns the same result given the same arguments). If
                    /// the function is nondeterministic, call
                    /// `register_nondeterministic_impl` instead.
                    pub fn register_impl<F, Ret, #(#arg_name,)*>(
                        conn: &mut SqliteConnection,
                        f: F,
                    ) -> QueryResult<()>
                    where
                        F: Fn(#(#arg_name,)*) -> Ret + ::std::panic::UnwindSafe + Send + 'static,
                        (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
                            StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
                        Ret: ToSql<#return_type, Sqlite>,
                    {
                        conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
                            #sql_name,
                            true,
                            move |(#(#arg_name,)*)| f(#(#arg_name,)*),
                        )
                    }

                    #[allow(dead_code)]
                    /// Registers an implementation for this function on the given connection
                    ///
                    /// This function must be called for every `SqliteConnection` before
                    /// this SQL function can be used on SQLite.
                    /// `register_nondeterministic_impl` should only be used if your
                    /// function can return different results with the same arguments (e.g.
                    /// `random`). If your function is deterministic, you should call
                    /// `register_impl` instead.
                    pub fn register_nondeterministic_impl<F, Ret, #(#arg_name,)*>(
                        conn: &mut SqliteConnection,
                        mut f: F,
                    ) -> QueryResult<()>
                    where
                        F: FnMut(#(#arg_name,)*) -> Ret + ::std::panic::UnwindSafe + Send + 'static,
                        (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
                            StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
                        Ret: ToSql<#return_type, Sqlite>,
                    {
                        conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
                            #sql_name,
                            false,
                            move |(#(#arg_name,)*)| f(#(#arg_name,)*),
                        )
                    }
                };
            }
        }
    }

    let args_iter = args.iter();
    quote! {
        #(#attributes)*
        #[allow(non_camel_case_types)]
        pub #fn_token #fn_name #impl_generics (#(#args_iter,)*)
            -> #fn_name::HelperType #ty_generics
        #where_clause
            #(#arg_name: ::diesel::expression::AsExpression<#arg_type>,)*
        {
            #fn_name::#fn_name {
                #(#arg_struct_assign,)*
                #(#type_args: ::std::marker::PhantomData,)*
            }
        }

        #[doc(hidden)]
        #[allow(non_camel_case_types, non_snake_case, unused_imports)]
        pub(crate) mod #fn_name {
            #tokens
        }
    }
}
/// Parsed form of the declaration passed to `sql_function!`.
///
/// Accepted shape: `#[attr]* fn name<generics>(arg: Type, ...) -> ReturnType;`.
/// The `-> ReturnType` part and the trailing `;` are both optional.
pub(crate) struct SqlFunctionDecl {
    /// Outer attributes, including the macro-specific `sql_name` / `aggregate`.
    attributes: Vec<Attribute>,
    fn_token: Token![fn],
    fn_name: Ident,
    /// Angle-bracketed generic parameters. `Generics`' parser does not consume
    /// a where clause, so `where_clause` is always `None` here.
    generics: Generics,
    args: Punctuated<StrictFnArg, Token![,]>,
    /// The declared SQL return type, or `NotSelectable` when `->` is omitted.
    return_type: Type,
}

impl Parse for SqlFunctionDecl {
    fn parse(input: ParseStream) -> Result<Self> {
        let attributes = Attribute::parse_outer(input)?;
        let fn_token = input.parse::<Token![fn]>()?;
        let fn_name = input.parse::<Ident>()?;
        let generics = input.parse::<Generics>()?;

        let arg_tokens;
        let _paren = parenthesized!(arg_tokens in input);
        let args = arg_tokens.parse_terminated(StrictFnArg::parse, Token![,])?;

        let arrow = input.parse::<Option<Token![->]>>()?;
        let return_type = match arrow {
            Some(_) => input.parse::<Type>()?,
            None => parse_quote!(diesel::expression::expression_types::NotSelectable),
        };

        // A trailing semicolon is accepted but not required.
        let _semi = input.parse::<Option<Token![;]>>()?;

        Ok(Self {
            attributes,
            fn_token,
            fn_name,
            generics,
            args,
            return_type,
        })
    }
}
/// A single `name: Type` argument of an `sql_function!` declaration.
///
/// Essentially `syn::PatType` (called `ArgCaptured` in older `syn`), but the
/// pattern may only be a plain identifier.
struct StrictFnArg {
    name: Ident,
    colon_token: Token![:],
    ty: Type,
}

impl Parse for StrictFnArg {
    fn parse(input: ParseStream) -> Result<Self> {
        // Struct-literal field initialisers run in source order, so this still
        // parses the name, then the `:`, then the type.
        Ok(Self {
            name: input.parse()?,
            colon_token: input.parse()?,
            ty: input.parse()?,
        })
    }
}

impl ToTokens for StrictFnArg {
    /// Emits `name: name`, deliberately not `name: ty`.
    ///
    /// `expand` pushes every argument name as a generic type parameter. The
    /// emitted type is therefore the generic of the same name; the declared
    /// `ty` is used separately as an `AsExpression<ty>` bound.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let Self {
            name, colon_token, ..
        } = self;
        name.to_tokens(tokens);
        colon_token.to_tokens(tokens);
        name.to_tokens(tokens);
    }
}
/// Returns whether `ty` looks like a SQL type that the generated SQLite
/// registration helpers support.
///
/// The check is purely syntactic:
/// * Only the identifier of the last path segment is compared, so
///   `diesel::sql_types::Integer` and a bare `Integer` both match.
/// * Aliases are not resolved.
/// * `Nullable<T>` is accepted exactly when its first generic argument is a
///   type that is itself accepted.
fn is_sqlite_type(ty: &Type) -> bool {
    const SQLITE_TYPES: &[&str] = &[
        "BigInt",
        "Binary",
        "Bool",
        "Date",
        "Double",
        "Float",
        "Integer",
        "Numeric",
        "SmallInt",
        "Text",
        "Time",
        "Timestamp",
    ];

    let segment = match ty {
        Type::Path(type_path) => match type_path.path.segments.last() {
            Some(segment) => segment,
            None => return false,
        },
        _ => return false,
    };

    let ident = segment.ident.to_string();
    if ident != "Nullable" {
        return SQLITE_TYPES.contains(&ident.as_str());
    }

    match segment.arguments {
        PathArguments::AngleBracketed(ref bracketed) => match bracketed.args.first() {
            Some(GenericArgument::Type(inner)) => is_sqlite_type(inner),
            _ => false,
        },
        _ => false,
    }
}