use proc_macro2::TokenStream;
use quote::quote;
use quote::ToTokens;
use syn::parse::{Parse, ParseStream, Result};
use syn::punctuated::Punctuated;
use syn::{
parenthesized, parse_quote, Attribute, GenericArgument, Generics, Ident, Meta, MetaNameValue,
PathArguments, Token, Type,
/// Expands a parsed `sql_function!` declaration into generated Rust code.
///
/// Recognised attributes on the declaration:
/// - `#[sql_name = "..."]` overrides the SQL name (defaults to `fn_name`),
/// - `#[aggregate]` marks the function as an aggregate.
/// Both are consumed here and removed from `attributes`.
///
/// Generated items accumulate in `tokens`:
/// - a struct `#fn_name` with one field per argument plus a `PhantomData`
///   marker per user type parameter;
/// - a `HelperType` alias built from `AsExpression::Expression` of each argument;
/// - `Expression`, `SelectableExpression`, `AppearsOnTable`, `QueryFragment`
///   and `ValidGrouping` impls;
/// - when `is_supported_on_sqlite` holds, `register_impl` (and, for scalar
///   functions, `register_nondeterministic_impl`) helpers for `SqliteConnection`.
///
/// The final `quote!` emits a public helper fn named `#fn_name` that converts
/// each argument with `AsExpression`, plus a `pub(crate) mod #fn_name`.
/// NOTE(review): that module presumably wraps `tokens`; confirm.
pub(crate) fn expand(input: SqlFunctionDecl) -> TokenStream {
let SqlFunctionDecl {
mut attributes,
mut generics,
} = input;
// SQL name: taken from a `#[sql_name = "..."]` string-literal attribute,
// otherwise the Rust function name is used verbatim.
let sql_name = attributes
.find(|attr| attr.meta.path().is_ident("sql_name"))
.and_then(|attr| {
if let Meta::NameValue(MetaNameValue {
syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(ref lit),
}) = attr.meta
} else {
.unwrap_or_else(|| fn_name.to_string());
// `#[aggregate]` selects the aggregate code path further down.
let is_aggregate = attributes
.any(|attr| attr.meta.path().is_ident("aggregate"));
// Remove the two macro-specific attributes consumed above.
attributes.retain(|attr| {
!attr.meta.path().is_ident("sql_name") && !attr.meta.path().is_ident("aggregate")
let args = &args;
// Split the arguments into parallel lists of names and (SQL) types.
let (ref arg_name, ref arg_type): (Vec<_>, Vec<_>) = args
.map(|StrictFnArg { name, ty, .. }| (name, ty))
// Field initialisers of the form `name: name.as_expression()`.
let arg_struct_assign = args.iter().map(
|StrictFnArg {
name, colon_token, ..
}| {
let name2 = name.clone();
quote!(#name #colon_token #name2.as_expression())
// Identifiers of the user's generic type parameters; each becomes a
// `PhantomData` field of the generated struct.
let type_args = &generics
.map(|type_param| type_param.ident.clone())
// Argument names double as generic type parameters of the generated struct
// (see the `#arg_name: ...` bounds below).
// NOTE(review): presumably this loop pushes each name into `generics`; confirm.
for StrictFnArg { name, .. } in args {
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
// Even if we force an empty where clause, it still won't print the where
// token with no bounds.
let where_clause = where_clause
.map(|w| quote!(#w))
.unwrap_or_else(|| quote!(where));
// Generics for the impls that take the extra `__DieselInternal` parameter
// (query source / backend, per the comments inside the generated code).
// NOTE(review): confirm `__DieselInternal` is pushed into this clone.
let mut generics_with_internal = generics.clone();
let (impl_generics_internal, _, _) = generics_with_internal.split_for_impl();
// With arguments: require the tuple of argument types to be `Expression` and
// derive `DieselNumericOps`. Without arguments: emit neither.
let sql_type;
let numeric_derive;
if arg_name.is_empty() {
sql_type = None;
// FIXME: We can always derive once trivial bounds are stable
numeric_derive = None;
} else {
sql_type = Some(quote!((#(#arg_name),*): Expression,));
numeric_derive = Some(quote!(#[derive(diesel::sql_types::DieselNumericOps)]));
let args_iter = args.iter();
// Items emitted for every function: the expression struct, its `HelperType`
// alias and the query-building trait impls.
let mut tokens = quote! {
use diesel::{self, QueryResult};
use diesel::expression::{AsExpression, Expression, SelectableExpression, AppearsOnTable, ValidGrouping};
use diesel::query_builder::{QueryFragment, AstPass};
use diesel::sql_types::*;
use super::*;
// One field per argument plus a `PhantomData` marker per user type parameter.
#[derive(Debug, Clone, Copy, diesel::query_builder::QueryId)]
pub struct #fn_name #ty_generics {
#(pub(in super) #args_iter,)*
#(pub(in super) #type_args: ::std::marker::PhantomData<#type_args>,)*
pub type HelperType #ty_generics = #fn_name <
#(<#arg_name as AsExpression<#arg_type>>::Expression,)*
impl #impl_generics Expression for #fn_name #ty_generics
type SqlType = #return_type;
// __DieselInternal is what we call QS normally
impl #impl_generics_internal SelectableExpression<__DieselInternal>
for #fn_name #ty_generics
#(#arg_name: SelectableExpression<__DieselInternal>,)*
Self: AppearsOnTable<__DieselInternal>,
// __DieselInternal is what we call QS normally
impl #impl_generics_internal AppearsOnTable<__DieselInternal>
for #fn_name #ty_generics
#(#arg_name: AppearsOnTable<__DieselInternal>,)*
Self: Expression,
// __DieselInternal is what we call DB normally
impl #impl_generics_internal QueryFragment<__DieselInternal>
for #fn_name #ty_generics
__DieselInternal: diesel::backend::Backend,
#(#arg_name: QueryFragment<__DieselInternal>,)*
// Emits `#sql_name(` followed by the arguments, comma-separated, skipping
// any argument whose `is_noop` returns true.
fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()>{
out.push_sql(concat!(#sql_name, "("));
// we unroll the arguments manually here, to prevent borrow check issues
let mut needs_comma = false;
if !self.#arg_name.is_noop(out.backend())? {
if needs_comma {
out.push_sql(", ");
needs_comma = true;
// SQLite registration helpers are generated only for non-generic functions
// whose return and argument types all pass `is_sqlite_type`.
// `cfg!` runs inside this proc macro, so it reflects the *macro crate's*
// `sqlite` feature, not the feature set of the crate invoking the macro.
let is_supported_on_sqlite = cfg!(feature = "sqlite")
&& type_args.is_empty()
&& is_sqlite_type(&return_type)
&& arg_type.iter().all(|a| is_sqlite_type(a));
// Aggregate functions always report `IsAggregate = Yes`.
if is_aggregate {
tokens = quote! {
impl #impl_generics_internal ValidGrouping<__DieselInternal>
for #fn_name #ty_generics
type IsAggregate = diesel::expression::is_aggregate::Yes;
if is_supported_on_sqlite {
tokens = quote! {
use diesel::sqlite::{Sqlite, SqliteConnection};
use diesel::serialize::ToSql;
use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
use diesel::sqlite::SqliteAggregateFunction;
use diesel::sql_types::IntoNullable;
// Aggregate registration: several arguments are passed as a tuple, a single
// argument directly. Zero-argument aggregates get no `register_impl`.
match arg_name.len() {
x if x > 1 => {
tokens = quote! {
/// Registers an implementation for this aggregate function on the given connection
/// This function must be called for every `SqliteConnection` before
/// this SQL function can be used on SQLite. The implementation must be
/// deterministic (returns the same result given the same arguments).
pub fn register_impl<A, #(#arg_name,)*>(
conn: &mut SqliteConnection
) -> QueryResult<()>
A: SqliteAggregateFunction<(#(#arg_name,)*)>
+ Send
+ 'static
// NOTE(review): `::std` here but `std` in the single-argument variant;
// the absolute path is the more hygienic choice.
+ ::std::panic::UnwindSafe
+ ::std::panic::RefUnwindSafe,
A::Output: ToSql<#return_type, Sqlite>,
(#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
StaticallySizedRow<(#(#arg_type,)*), Sqlite> +
conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name)
1 => {
let arg_name = arg_name[0];
let arg_type = arg_type[0];
tokens = quote! {
/// Registers an implementation for this aggregate function on the given connection
/// This function must be called for every `SqliteConnection` before
/// this SQL function can be used on SQLite. The implementation must be
/// deterministic (returns the same result given the same arguments).
pub fn register_impl<A, #arg_name>(
conn: &mut SqliteConnection
) -> QueryResult<()>
A: SqliteAggregateFunction<#arg_name>
+ Send
+ 'static
+ std::panic::UnwindSafe
+ std::panic::RefUnwindSafe,
A::Output: ToSql<#return_type, Sqlite>,
#arg_name: FromSqlRow<#arg_type, Sqlite> +
StaticallySizedRow<#arg_type, Sqlite> +
conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name)
_ => (),
} else {
// Scalar functions: aggregate-ness is delegated to the `ValidGrouping` impl
// of a helper tuple struct `__Derived` over the argument types.
tokens = quote! {
pub struct __Derived<#(#arg_name,)*>(#(#arg_name,)*);
impl #impl_generics_internal ValidGrouping<__DieselInternal>
for #fn_name #ty_generics
__Derived<#(#arg_name,)*>: ValidGrouping<__DieselInternal>,
type IsAggregate = <__Derived<#(#arg_name,)*> as ValidGrouping<__DieselInternal>>::IsAggregate;
// Scalar registration with at least one argument: SQLite passes the
// arguments as a tuple, which is unpacked before calling `f`.
if is_supported_on_sqlite && !arg_name.is_empty() {
tokens = quote! {
use diesel::sqlite::{Sqlite, SqliteConnection};
use diesel::serialize::ToSql;
use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
/// Registers an implementation for this function on the given connection
/// This function must be called for every `SqliteConnection` before
/// this SQL function can be used on SQLite. The implementation must be
/// deterministic (returns the same result given the same arguments). If
/// the function is nondeterministic, call
/// `register_nondeterministic_impl` instead.
pub fn register_impl<F, Ret, #(#arg_name,)*>(
conn: &mut SqliteConnection,
f: F,
) -> QueryResult<()>
F: Fn(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
(#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
Ret: ToSql<#return_type, Sqlite>,
conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
move |(#(#arg_name,)*)| f(#(#arg_name,)*),
/// Registers an implementation for this function on the given connection
/// This function must be called for every `SqliteConnection` before
/// this SQL function can be used on SQLite.
/// `register_nondeterministic_impl` should only be used if your
/// function can return different results with the same arguments (e.g.
/// `random`). If your function is deterministic, you should call
/// `register_impl` instead.
pub fn register_nondeterministic_impl<F, Ret, #(#arg_name,)*>(
conn: &mut SqliteConnection,
mut f: F,
) -> QueryResult<()>
F: FnMut(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
(#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
Ret: ToSql<#return_type, Sqlite>,
conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
move |(#(#arg_name,)*)| f(#(#arg_name,)*),
// Zero-argument scalar functions use `register_noarg_sql_function`.
// NOTE(review): these helpers take `&SqliteConnection`, while every other
// `register_*` helper takes `&mut SqliteConnection`. Switching to `&mut`
// would break callers passing `&conn`, so check the connection API
// before unifying.
if is_supported_on_sqlite && arg_name.is_empty() {
tokens = quote! {
use diesel::sqlite::{Sqlite, SqliteConnection};
use diesel::serialize::ToSql;
/// Registers an implementation for this function on the given connection
/// This function must be called for every `SqliteConnection` before
/// this SQL function can be used on SQLite. The implementation must be
/// deterministic (returns the same result given the same arguments). If
/// the function is nondeterministic, call
/// `register_nondeterministic_impl` instead.
pub fn register_impl<F, Ret>(
conn: &SqliteConnection,
f: F,
) -> QueryResult<()>
F: Fn() -> Ret + std::panic::UnwindSafe + Send + 'static,
Ret: ToSql<#return_type, Sqlite>,
conn.register_noarg_sql_function::<#return_type, _, _>(
/// Registers an implementation for this function on the given connection
/// This function must be called for every `SqliteConnection` before
/// this SQL function can be used on SQLite.
/// `register_nondeterministic_impl` should only be used if your
/// function can return different results with the same arguments (e.g.
/// `random`). If your function is deterministic, you should call
/// `register_impl` instead.
pub fn register_nondeterministic_impl<F, Ret>(
conn: &SqliteConnection,
mut f: F,
) -> QueryResult<()>
F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static,
Ret: ToSql<#return_type, Sqlite>,
conn.register_noarg_sql_function::<#return_type, _, _>(
// Public entry point: a plain fn that converts each argument with
// `AsExpression` and builds the struct, followed by the module of the same name.
let args_iter = args.iter();
quote! {
pub #fn_token #fn_name #impl_generics (#(#args_iter,)*)
-> #fn_name::HelperType #ty_generics
#(#arg_name: ::diesel::expression::AsExpression<#arg_type>,)*
#fn_name::#fn_name {
#(#type_args: ::std::marker::PhantomData,)*
#[allow(non_camel_case_types, non_snake_case, unused_imports)]
pub(crate) mod #fn_name {
/// Parsed input of the `sql_function!` macro: a function signature made of
/// outer attributes, `fn`, a name, generics, arguments and a return type.
pub(crate) struct SqlFunctionDecl {
/// Outer attributes, including the macro-specific `sql_name` / `aggregate`.
attributes: Vec<Attribute>,
fn_token: Token![fn],
/// Rust-side name; also the default SQL name and the generated module name.
fn_name: Ident,
/// Generic parameters declared on the function.
generics: Generics,
/// Comma-separated arguments, each restricted to `name: Type`.
args: Punctuated<StrictFnArg, Token![,]>,
/// SQL type produced by the function (the generated `Expression::SqlType`).
return_type: Type,
impl Parse for SqlFunctionDecl {
/// Parses `#[attrs] fn name<generics>(arg: Type, ...) [-> Type] [;]`.
fn parse(input: ParseStream) -> Result<Self> {
let attributes = Attribute::parse_outer(input)?;
let fn_token: Token![fn] = input.parse()?;
let fn_name = Ident::parse(input)?;
let generics = Generics::parse(input)?;
let args;
let _paren = parenthesized!(args in input);
// `parse_terminated` also accepts a trailing comma after the last argument.
let args = args.parse_terminated(StrictFnArg::parse, Token![,])?;
// Both the `-> Type` annotation and a trailing `;` are optional.
let return_type = if Option::<Token![->]>::parse(input)?.is_some() {
} else {
let _semi = Option::<Token![;]>::parse(input)?;
Ok(Self {
/// A function argument restricted to the form `name: Type`.
///
/// Essentially the same as syn's `PatType` (formerly `ArgCaptured`), but only
/// allowing ident patterns, so each argument name can be reused as an
/// identifier in the generated code.
struct StrictFnArg {
/// Argument name.
name: Ident,
colon_token: Token![:],
/// SQL type the argument is converted to via `AsExpression`.
ty: Type,
impl Parse for StrictFnArg {
/// Parses `name: Type`, where `name` must be a plain identifier rather than
/// an arbitrary pattern.
fn parse(input: ParseStream) -> Result<Self> {
let name = input.parse()?;
let colon_token = input.parse()?;
let ty = input.parse()?;
Ok(Self {
impl ToTokens for StrictFnArg {
/// Emits the argument. Used both as a struct field (`pub(in super) #arg`) and
/// as a fn parameter. NOTE(review): presumably writes `name: Type`; confirm.
fn to_tokens(&self, tokens: &mut TokenStream) {;
/// Decides whether `ty` is a SQL type for which SQLite registration helpers
/// may be generated. `expand` uses it to compute `is_supported_on_sqlite`.
///
/// Only the last path segment is inspected, so a fully-qualified path and its
/// bare final segment are classified the same way. Non-path types and paths
/// with no segments yield `false`. For `Nullable<T>`, `T` is checked
/// recursively; a `Nullable` without a type argument presumably yields `false`.
fn is_sqlite_type(ty: &Type) -> bool {
let last_segment = if let Type::Path(tp) = ty {
if let Some(segment) = tp.path.segments.last() {
} else {
return false;
} else {
return false;
let ident = last_segment.ident.to_string();
// `Nullable<T>` is supported exactly when `T` is.
if ident == "Nullable" {
if let PathArguments::AngleBracketed(ref ab) = last_segment.arguments {
if let Some(GenericArgument::Type(ty)) = ab.args.first() {
return is_sqlite_type(ty);
return false;