Merge pull request #2325 from thekuom/feature/2191-sqlite-custom-aggregate
Add ability to create custom aggregate functions in sqlite
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5a4c50a..b89b0f5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,9 @@
* Added support for SQLite's `UPSERT`.
You can use this feature above SQLite version 3.24.0.
+
+* Added ability to create custom aggregate functions in SQLite.
+
* Multiple aggregate expressions can now appear together in the same select
clause. See [the upgrade notes](#2-0-0-upgrade-non-aggregate) for details.
@@ -36,6 +39,7 @@
functionality of `NonAggregate`. See [the upgrade
notes](#2-0-0-upgrade-non-aggregate) for details.
+
### Removed
* All previously deprecated items have been removed.
diff --git a/diesel/src/sqlite/connection/functions.rs b/diesel/src/sqlite/connection/functions.rs
index 3e6a526..e340b64 100644
--- a/diesel/src/sqlite/connection/functions.rs
+++ b/diesel/src/sqlite/connection/functions.rs
@@ -2,7 +2,7 @@
use super::raw::RawConnection;
use super::serialized_value::SerializedValue;
-use super::{Sqlite, SqliteValue};
+use super::{Sqlite, SqliteAggregateFunction, SqliteValue};
use crate::deserialize::{FromSqlRow, Queryable};
use crate::result::{DatabaseErrorKind, Error, QueryResult};
use crate::row::Row;
@@ -30,29 +30,87 @@
}
conn.register_sql_function(fn_name, fields_needed, deterministic, move |conn, args| {
- let mut row = FunctionRow { args };
- let args_row = Args::Row::build_from_row(&mut row).map_err(Error::DeserializationError)?;
- let args = Args::build(args_row);
+ let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
let result = f(conn, args);
- let mut buf = Output::new(Vec::new(), &());
- let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
-
- let bytes = if let IsNull::Yes = is_null {
- None
- } else {
- Some(buf.into_inner())
- };
-
- Ok(SerializedValue {
- ty: Sqlite::metadata(&()),
- data: bytes,
- })
+ process_sql_function_result::<RetSqlType, Ret>(result)
})?;
Ok(())
}
+/// Registers `A` as the implementation of the SQLite aggregate function `fn_name`.
+///
+/// `Args` is deserialized from the SQL arguments of every row and passed to
+/// `SqliteAggregateFunction::step`; the value returned by
+/// `SqliteAggregateFunction::finalize` is serialized as `RetSqlType`.
+///
+/// # Errors
+///
+/// Fails if the function would take more than 127 arguments, or if the
+/// registration itself fails (e.g. `fn_name` contains a NUL byte or SQLite
+/// reports an error).
+pub fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
+    conn: &RawConnection,
+    fn_name: &str,
+) -> QueryResult<()>
+where
+    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
+    Args: Queryable<ArgsSqlType, Sqlite>,
+    Ret: ToSql<RetSqlType, Sqlite>,
+    Sqlite: HasSqlType<RetSqlType>,
+{
+    let fields_needed = Args::Row::FIELDS_NEEDED;
+    // SQLite does not accept functions taking more than 127 arguments.
+    if fields_needed > 127 {
+        return Err(Error::DatabaseError(
+            DatabaseErrorKind::UnableToSendCommand,
+            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
+        ));
+    }
+
+    conn.register_aggregate_function::<ArgsSqlType, RetSqlType, Args, Ret, A>(
+        fn_name,
+        fields_needed,
+    )
+}
+
+/// Deserializes the raw argument values of one SQL function call into `Args`.
+pub(crate) fn build_sql_function_args<ArgsSqlType, Args>(
+    args: &[*mut ffi::sqlite3_value],
+) -> QueryResult<Args>
+where
+    Args: Queryable<ArgsSqlType, Sqlite>,
+{
+    let mut row = FunctionRow { args };
+    Args::Row::build_from_row(&mut row)
+        .map(Args::build)
+        .map_err(Error::DeserializationError)
+}
+
+/// Serializes the value returned by a SQL function implementation so it can be
+/// handed back to SQLite; a value serialized as `IsNull::Yes` becomes SQL `NULL`.
+pub(crate) fn process_sql_function_result<RetSqlType, Ret>(
+    result: Ret,
+) -> QueryResult<SerializedValue>
+where
+    Ret: ToSql<RetSqlType, Sqlite>,
+    Sqlite: HasSqlType<RetSqlType>,
+{
+    let mut buf = Output::new(Vec::new(), &());
+    let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
+
+    let bytes = match is_null {
+        IsNull::Yes => None,
+        IsNull::No => Some(buf.into_inner()),
+    };
+
+    Ok(SerializedValue {
+        ty: Sqlite::metadata(&()),
+        data: bytes,
+    })
+}
+
struct FunctionRow<'a> {
args: &'a [*mut ffi::sqlite3_value],
}
diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs
index 67bbbe4..b200d15 100644
--- a/diesel/src/sqlite/connection/mod.rs
+++ b/diesel/src/sqlite/connection/mod.rs
@@ -15,6 +15,7 @@
use self::raw::RawConnection;
use self::statement_iterator::*;
use self::stmt::{Statement, StatementUse};
+use super::SqliteAggregateFunction;
use crate::connection::*;
use crate::deserialize::{Queryable, QueryableByName};
use crate::query_builder::bind_collector::RawBytesBindCollector;
@@ -238,6 +239,26 @@
)
}
+    /// Registers `A` as the implementation of the aggregate SQL function `fn_name`.
+    ///
+    /// This is an implementation detail of the `register_impl` function that
+    /// `sql_function!` generates for `#[aggregate]` functions; call that function
+    /// instead of using this one directly.
+    #[doc(hidden)]
+    pub fn register_aggregate_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
+        &self,
+        fn_name: &str,
+    ) -> QueryResult<()>
+    where
+        A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
+        Args: Queryable<ArgsSqlType, Sqlite>,
+        Ret: ToSql<RetSqlType, Sqlite>,
+        Sqlite: HasSqlType<RetSqlType>,
+    {
+        let raw = &self.raw_connection;
+        functions::register_aggregate::<ArgsSqlType, RetSqlType, Args, Ret, A>(raw, fn_name)
+    }
+
fn register_diesel_sql_functions(&self) -> QueryResult<()> {
use crate::sql_types::{Integer, Text};
@@ -370,4 +385,163 @@
.get_result::<(i32, i32, i32)>(&connection);
assert_eq!(Ok((2, 3, 4)), added);
}
+
+    use crate::sqlite::SqliteAggregateFunction;
+
+    sql_function! {
+        #[aggregate]
+        fn my_sum(expr: Integer) -> Integer;
+    }
+
+    /// Sums its argument over all rows; an empty input set yields `0`.
+    #[derive(Default)]
+    struct MySum {
+        sum: i32,
+    }
+
+    impl SqliteAggregateFunction<i32> for MySum {
+        type Output = i32;
+
+        fn step(&mut self, expr: i32) {
+            self.sum += expr;
+        }
+
+        fn finalize(aggregator: Option<Self>) -> Self::Output {
+            aggregator.map(|a| a.sum).unwrap_or_default()
+        }
+    }
+
+    table! {
+        my_sum_example {
+            id -> Integer,
+            value -> Integer,
+        }
+    }
+
+    /// Opens an in-memory connection with an empty `my_sum_example` table.
+    fn my_sum_example_connection() -> SqliteConnection {
+        let connection = SqliteConnection::establish(":memory:").unwrap();
+        connection
+            .execute(
+                "CREATE TABLE my_sum_example (id integer primary key autoincrement, value integer)",
+            )
+            .unwrap();
+        connection
+    }
+
+    #[test]
+    fn register_aggregate_function() {
+        use self::my_sum_example::dsl::*;
+
+        let connection = my_sum_example_connection();
+        connection
+            .execute("INSERT INTO my_sum_example (value) VALUES (1), (2), (3)")
+            .unwrap();
+
+        my_sum::register_impl::<MySum, _>(&connection).unwrap();
+
+        let result = my_sum_example
+            .select(my_sum(value))
+            .get_result::<i32>(&connection);
+        assert_eq!(Ok(6), result);
+    }
+
+    #[test]
+    fn register_aggregate_function_returns_finalize_default_on_empty_set() {
+        use self::my_sum_example::dsl::*;
+
+        let connection = my_sum_example_connection();
+
+        my_sum::register_impl::<MySum, _>(&connection).unwrap();
+
+        let result = my_sum_example
+            .select(my_sum(value))
+            .get_result::<i32>(&connection);
+        assert_eq!(Ok(0), result);
+    }
+
+    sql_function! {
+        #[aggregate]
+        fn range_max(expr1: Integer, expr2: Integer, expr3: Integer) -> Nullable<Integer>;
+    }
+
+    /// Tracks the largest of its three arguments over all rows; an empty input
+    /// set yields `None` (SQL `NULL`).
+    #[derive(Default)]
+    struct RangeMax<T> {
+        max_value: Option<T>,
+    }
+
+    impl<T: Default + Ord + Copy> SqliteAggregateFunction<(T, T, T)> for RangeMax<T> {
+        type Output = Option<T>;
+
+        fn step(&mut self, (x0, x1, x2): (T, T, T)) {
+            let row_max = x0.max(x1).max(x2);
+            self.max_value = Some(match self.max_value {
+                Some(current_max) => current_max.max(row_max),
+                None => row_max,
+            });
+        }
+
+        fn finalize(aggregator: Option<Self>) -> Self::Output {
+            aggregator?.max_value
+        }
+    }
+
+    table! {
+        range_max_example {
+            id -> Integer,
+            value1 -> Integer,
+            value2 -> Integer,
+            value3 -> Integer,
+        }
+    }
+
+    /// Opens an in-memory connection with an empty `range_max_example` table.
+    fn range_max_example_connection() -> SqliteConnection {
+        let connection = SqliteConnection::establish(":memory:").unwrap();
+        connection
+            .execute(
+                r#"CREATE TABLE range_max_example (
+                id integer primary key autoincrement,
+                value1 integer,
+                value2 integer,
+                value3 integer
+            )"#,
+            )
+            .unwrap();
+        connection
+    }
+
+    #[test]
+    fn register_aggregate_multiarg_function() {
+        use self::range_max_example::dsl::*;
+
+        let connection = range_max_example_connection();
+        connection
+            .execute(
+                "INSERT INTO range_max_example (value1, value2, value3) VALUES (3, 2, 1), (2, 2, 2)",
+            )
+            .unwrap();
+
+        range_max::register_impl::<RangeMax<i32>, _, _, _>(&connection).unwrap();
+        let result = range_max_example
+            .select(range_max(value1, value2, value3))
+            .get_result::<Option<i32>>(&connection)
+            .unwrap();
+        assert_eq!(Some(3), result);
+    }
+
+    #[test]
+    fn register_aggregate_multiarg_function_returns_none_on_empty_set() {
+        use self::range_max_example::dsl::*;
+
+        let connection = range_max_example_connection();
+
+        range_max::register_impl::<RangeMax<i32>, _, _, _>(&connection).unwrap();
+        let result = range_max_example
+            .select(range_max(value1, value2, value3))
+            .get_result::<Option<i32>>(&connection);
+        assert_eq!(Ok(None), result);
+    }
}
diff --git a/diesel/src/sqlite/connection/raw.rs b/diesel/src/sqlite/connection/raw.rs
index e46e4f1..6c7aee0 100644
--- a/diesel/src/sqlite/connection/raw.rs
+++ b/diesel/src/sqlite/connection/raw.rs
@@ -1,14 +1,19 @@
extern crate libsqlite3_sys as ffi;
-use std::ffi::{CStr, CString};
+use std::ffi::{CStr, CString, NulError};
use std::io::{stderr, Write};
use std::os::raw as libc;
use std::ptr::NonNull;
use std::{mem, ptr, slice, str};
+use super::functions::{build_sql_function_args, process_sql_function_result};
use super::serialized_value::SerializedValue;
+use super::{Sqlite, SqliteAggregateFunction};
+use crate::deserialize::Queryable;
use crate::result::Error::DatabaseError;
use crate::result::*;
+use crate::serialize::ToSql;
+use crate::sql_types::HasSqlType;
#[allow(missing_debug_implementations, missing_copy_implementations)]
pub struct RawConnection {
@@ -76,11 +81,8 @@
+ Send
+ 'static,
{
- let fn_name = CString::new(fn_name)?;
- let mut flags = ffi::SQLITE_UTF8;
- if deterministic {
- flags |= ffi::SQLITE_DETERMINISTIC;
- }
+        let fn_name = Self::c_fn_name(fn_name)?;
+        let flags = Self::function_flags(deterministic);
let callback_fn = Box::into_raw(Box::new(f));
let result = unsafe {
@@ -97,6 +99,67 @@
)
};
+        Self::check_create_function_result(result)
+    }
+
+    /// Registers the aggregate SQL function `fn_name` taking `num_args` arguments,
+    /// backed by the `SqliteAggregateFunction` implementation `A`.
+    ///
+    /// SQLite calls `run_aggregator_step_function` for every row and
+    /// `run_aggregator_final_function` once at the end. The aggregator state lives
+    /// in memory obtained from `sqlite3_aggregate_context`; no user data pointer
+    /// and no destructor are registered.
+    pub fn register_aggregate_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
+        &self,
+        fn_name: &str,
+        num_args: usize,
+    ) -> QueryResult<()>
+    where
+        A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
+        Args: Queryable<ArgsSqlType, Sqlite>,
+        Ret: ToSql<RetSqlType, Sqlite>,
+        Sqlite: HasSqlType<RetSqlType>,
+    {
+        let fn_name = Self::c_fn_name(fn_name)?;
+        // Aggregate functions are always registered as non-deterministic.
+        let flags = Self::function_flags(false);
+
+        let result = unsafe {
+            ffi::sqlite3_create_function_v2(
+                self.internal_connection.as_ptr(),
+                fn_name.as_ptr(),
+                // `functions::register_aggregate` rejects more than 127 arguments
+                // before calling this, so this conversion to `c_int` does not truncate.
+                num_args as _,
+                flags,
+                ptr::null_mut(),
+                None,
+                Some(run_aggregator_step_function::<_, _, _, _, A>),
+                Some(run_aggregator_final_function::<_, _, _, _, A>),
+                None,
+            )
+        };
+
+        Self::check_create_function_result(result)
+    }
+
+    /// Converts `fn_name` into the NUL terminated string expected by SQLite.
+    fn c_fn_name(fn_name: &str) -> Result<CString, NulError> {
+        CString::new(fn_name)
+    }
+
+    /// Builds the text encoding / determinism flags passed to
+    /// `sqlite3_create_function_v2`.
+    fn function_flags(deterministic: bool) -> i32 {
+        if deterministic {
+            ffi::SQLITE_UTF8 | ffi::SQLITE_DETERMINISTIC
+        } else {
+            ffi::SQLITE_UTF8
+        }
+    }
+
+    /// Maps the return code of `sqlite3_create_function_v2` to a `QueryResult`.
+    fn check_create_function_result(result: i32) -> Result<(), Error> {
if result == ffi::SQLITE_OK {
Ok(())
} else {
@@ -194,6 +243,191 @@
}
}
+// Need a custom option type here, because the std lib one does not have guarantees about the discriminant values
+// See: https://github.com/rust-lang/rfcs/blob/master/text/2195-really-tagged-unions.md#opaque-tags
+#[repr(u8)]
+enum OptionalAggregator<A> {
+    // Discriminant is 0, so zeroed memory is a valid `None`
+    None,
+    Some(A),
+}
+
+/// Reported when `sqlite3_aggregate_context` unexpectedly returns a null pointer.
+static NULL_AG_CTX_ERR: &str = "An unknown error occurred. sqlite3_aggregate_context returned a null pointer. This should never happen.";
+
+/// Reported when the aggregate context is not aligned for `OptionalAggregator<A>`.
+static MISALIGNED_AG_CTX_ERR: &str =
+    "sqlite3_aggregate_context returned memory that is not aligned for the aggregate function state.";
+
+/// Reported when the `Default`, `step` or `finalize` implementation of a custom
+/// aggregate function panics.
+static PANIC_IN_AGGREGATE_ERR: &str = "A custom SQLite aggregate function panicked.";
+
+/// `xStep` callback registered by `RawConnection::register_aggregate_function`.
+///
+/// A panic must not unwind out of an `extern "C"` function into SQLite's C code,
+/// so panics raised by the user supplied `Default`/`step` implementations are
+/// caught and reported to SQLite as an error.
+extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
+    ctx: *mut ffi::sqlite3_context,
+    num_args: libc::c_int,
+    value_ptr: *mut *mut ffi::sqlite3_value,
+) where
+    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
+    Args: Queryable<ArgsSqlType, Sqlite>,
+    Ret: ToSql<RetSqlType, Sqlite>,
+    Sqlite: HasSqlType<RetSqlType>,
+{
+    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
+        aggregator_step::<ArgsSqlType, Args, A>(ctx, num_args, value_ptr)
+    }));
+    // SAFETY: `ctx` is the context SQLite passed to this callback.
+    unsafe { report_outcome(ctx, outcome) };
+}
+
+/// Feeds the arguments of one row into the aggregator of the current aggregate
+/// evaluation, creating that aggregator with `A::default()` on the first row.
+///
+/// Errors are returned as messages meant for `sqlite3_result_error`.
+///
+/// # Safety
+///
+/// The arguments must be the ones SQLite passed to an `xStep` callback of an
+/// aggregate function whose context is only ever used as `OptionalAggregator<A>`.
+unsafe fn aggregator_step<ArgsSqlType, Args, A>(
+    ctx: *mut ffi::sqlite3_context,
+    num_args: libc::c_int,
+    value_ptr: *mut *mut ffi::sqlite3_value,
+) -> Result<(), String>
+where
+    A: SqliteAggregateFunction<Args>,
+    Args: Queryable<ArgsSqlType, Sqlite>,
+{
+    // This code makes the following assumptions:
+    //
+    // * sqlite3_aggregate_context allocates size_of::<OptionalAggregator<A>>()
+    //   bytes of zeroed memory on its first call and returns that same memory on
+    //   later calls for the same evaluation, as documented here:
+    //   https://www.sqlite.org/c3ref/aggregate_context.html
+    //   It may return a null pointer (e.g. if the allocation fails), which is
+    //   checked nevertheless.
+    //
+    // * OptionalAggregator::None has a discriminant of 0 as specified by
+    //   #[repr(u8)] + RFC 2195, so all-zero memory is a valid
+    //   OptionalAggregator::None.
+    //
+    // * The alignment of that memory is not documented, so it is checked before
+    //   any reference into it is created.
+    let aggregate_context = ffi::sqlite3_aggregate_context(
+        ctx,
+        mem::size_of::<OptionalAggregator<A>>() as libc::c_int,
+    ) as *mut OptionalAggregator<A>;
+    if aggregate_context.is_null() {
+        return Err(NULL_AG_CTX_ERR.to_owned());
+    }
+    if (aggregate_context as usize) % mem::align_of::<OptionalAggregator<A>>() != 0 {
+        return Err(MISALIGNED_AG_CTX_ERR.to_owned());
+    }
+
+    let slot = &mut *aggregate_context;
+    if let OptionalAggregator::None = *slot {
+        *slot = OptionalAggregator::Some(A::default());
+    }
+    let aggregator = match *slot {
+        OptionalAggregator::Some(ref mut agg) => agg,
+        OptionalAggregator::None => unreachable!("the aggregator was stored above"),
+    };
+
+    let args: &[*mut ffi::sqlite3_value] = if num_args <= 0 || value_ptr.is_null() {
+        &[]
+    } else {
+        slice::from_raw_parts(value_ptr, num_args as usize)
+    };
+    let args = build_sql_function_args::<ArgsSqlType, Args>(args).map_err(|e| e.to_string())?;
+    aggregator.step(args);
+    Ok(())
+}
+
+/// `xFinal` callback registered by `RawConnection::register_aggregate_function`.
+///
+/// Panics raised by the user supplied `finalize` implementation are caught and
+/// reported to SQLite as an error instead of unwinding into C code.
+extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
+    ctx: *mut ffi::sqlite3_context,
+) where
+    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
+    Args: Queryable<ArgsSqlType, Sqlite>,
+    Ret: ToSql<RetSqlType, Sqlite>,
+    Sqlite: HasSqlType<RetSqlType>,
+{
+    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
+        aggregator_final::<RetSqlType, Args, Ret, A>(ctx)
+    }));
+    // SAFETY: `ctx` is the context SQLite passed to this callback.
+    unsafe { report_outcome(ctx, outcome) };
+}
+
+/// Takes the aggregator out of the aggregate context (`None` if no row was
+/// processed), finalizes it and sets the result of the SQL function.
+///
+/// Errors are returned as messages meant for `sqlite3_result_error`.
+///
+/// # Safety
+///
+/// `ctx` must be the context SQLite passed to an `xFinal` callback of an
+/// aggregate function whose context is only ever used as `OptionalAggregator<A>`.
+unsafe fn aggregator_final<RetSqlType, Args, Ret, A>(
+    ctx: *mut ffi::sqlite3_context,
+) -> Result<(), String>
+where
+    A: SqliteAggregateFunction<Args, Output = Ret>,
+    Ret: ToSql<RetSqlType, Sqlite>,
+    Sqlite: HasSqlType<RetSqlType>,
+{
+    // Within the xFinal callback, it is customary to set nBytes to 0 so no pointless memory
+    // allocations occur; a null pointer is then returned if xStep never ran.
+    // See: https://www.sqlite.org/c3ref/aggregate_context.html
+    //
+    // For the reasoning about the safety of the OptionalAggregator handling
+    // see the comment in aggregator_step.
+    let aggregate_context = ffi::sqlite3_aggregate_context(ctx, 0) as *mut OptionalAggregator<A>;
+    let aggregator = if aggregate_context.is_null() {
+        None
+    } else if (aggregate_context as usize) % mem::align_of::<OptionalAggregator<A>>() != 0 {
+        return Err(MISALIGNED_AG_CTX_ERR.to_owned());
+    } else {
+        match mem::replace(&mut *aggregate_context, OptionalAggregator::None) {
+            OptionalAggregator::Some(agg) => Some(agg),
+            // The context was allocated, but xStep failed before an aggregator
+            // was stored (e.g. because `A::default()` panicked).
+            OptionalAggregator::None => None,
+        }
+    };
+
+    let result = A::finalize(aggregator);
+    let value = process_sql_function_result::<RetSqlType, Ret>(result).map_err(|e| e.to_string())?;
+    value.result_of(ctx);
+    Ok(())
+}
+
+/// Reports a failed or panicked callback invocation as the error result of the
+/// current SQL function call; does nothing if `outcome` is a success.
+///
+/// # Safety
+///
+/// `ctx` must be the context SQLite passed to the running callback.
+unsafe fn report_outcome(
+    ctx: *mut ffi::sqlite3_context,
+    outcome: std::thread::Result<Result<(), String>>,
+) {
+    let msg = match &outcome {
+        Ok(Ok(())) => return,
+        Ok(Err(msg)) => msg.as_str(),
+        Err(_) => PANIC_IN_AGGREGATE_ERR,
+    };
+    ffi::sqlite3_result_error(ctx, msg.as_ptr() as *const _, msg.len() as _);
+}
+
extern "C" fn destroy_boxed_fn<F>(data: *mut libc::c_void)
where
F: FnMut(&RawConnection, &[*mut ffi::sqlite3_value]) -> QueryResult<SerializedValue>
diff --git a/diesel/src/sqlite/mod.rs b/diesel/src/sqlite/mod.rs
index 80b5299..cbd0723 100644
--- a/diesel/src/sqlite/mod.rs
+++ b/diesel/src/sqlite/mod.rs
@@ -13,3 +13,30 @@
pub use self::backend::{Sqlite, SqliteType};
pub use self::connection::SqliteConnection;
pub use self::query_builder::SqliteQueryBuilder;
+
+/// Trait for the implementation of a SQLite aggregate function
+///
+/// This trait is to be used in conjunction with the `sql_function!`
+/// macro for defining a custom SQLite aggregate function. See
+/// the documentation [there](../prelude/macro.sql_function.html) for details.
+///
+/// `Args` is the Rust type the SQL arguments of a single row are deserialized
+/// into. With `sql_function!`, a function with one argument receives that value
+/// directly, while a function with several arguments receives them as a tuple.
+///
+/// A value of the implementing type is created with `Default::default()`
+/// right before the first row is passed to `step()`. If no row is processed,
+/// no value is created and `finalize()` receives `None`.
+pub trait SqliteAggregateFunction<Args>: Default {
+    /// The result type of the SQLite aggregate function
+    type Output;
+
+    /// The `step()` method is called once for every record of the query
+    fn step(&mut self, args: Args);
+
+    /// After the last row has been processed, the `finalize()` method is
+    /// called to compute the result of the aggregate function. If no rows
+    /// were processed `aggregator` will be `None` and `finalize()` can be
+    /// used to specify a default result
+    fn finalize(aggregator: Option<Self>) -> Self::Output;
+}
diff --git a/diesel_derives/src/lib.rs b/diesel_derives/src/lib.rs
index 1f6d0f1..6d2cdea 100644
--- a/diesel_derives/src/lib.rs
+++ b/diesel_derives/src/lib.rs
@@ -642,6 +642,154 @@
/// # Ok(())
/// # }
/// ```
+///
+/// ## Custom Aggregate Functions
+///
+/// Custom aggregate functions can be created in SQLite by adding an `#[aggregate]`
+/// attribute inside of `sql_function!`. Before such a function can be used on a
+/// connection, `register_impl` must be called on it with a type implementing the
+/// [SqliteAggregateFunction](../diesel/sqlite/trait.SqliteAggregateFunction.html)
+/// trait as its first type parameter, as shown in the examples below.
+///
+/// ```rust
+/// # extern crate diesel;
+/// # use diesel::*;
+/// #
+/// # #[cfg(feature = "sqlite")]
+/// # fn main() {
+/// # run().unwrap();
+/// # }
+/// #
+/// # #[cfg(not(feature = "sqlite"))]
+/// # fn main() {
+/// # }
+/// use diesel::sql_types::Integer;
+/// # #[cfg(feature = "sqlite")]
+/// use diesel::sqlite::SqliteAggregateFunction;
+///
+/// sql_function! {
+/// #[aggregate]
+/// fn my_sum(x: Integer) -> Integer;
+/// }
+///
+/// #[derive(Default)]
+/// struct MySum { sum: i32 }
+///
+/// # #[cfg(feature = "sqlite")]
+/// impl SqliteAggregateFunction<i32> for MySum {
+/// type Output = i32;
+///
+/// fn step(&mut self, expr: i32) {
+/// self.sum += expr;
+/// }
+///
+/// fn finalize(aggregator: Option<Self>) -> Self::Output {
+/// aggregator.map(|a| a.sum).unwrap_or_default()
+/// }
+/// }
+/// # table! {
+/// # players {
+/// # id -> Integer,
+/// # score -> Integer,
+/// # }
+/// # }
+///
+/// # #[cfg(feature = "sqlite")]
+/// fn run() -> Result<(), Box<dyn (::std::error::Error)>> {
+/// # use self::players::dsl::*;
+/// let connection = SqliteConnection::establish(":memory:")?;
+/// # connection.execute("create table players (id integer primary key autoincrement, score integer)").unwrap();
+/// # connection.execute("insert into players (score) values (10), (20), (30)").unwrap();
+///
+/// my_sum::register_impl::<MySum, _>(&connection)?;
+///
+/// let total_score = players.select(my_sum(score))
+/// .get_result::<i32>(&connection)?;
+///
+/// println!("The total score of all the players is: {}", total_score);
+///
+/// # assert_eq!(60, total_score);
+/// Ok(())
+/// }
+/// ```
+///
+/// With multiple function arguments, the arguments are passed as a tuple to `SqliteAggregateFunction`:
+///
+/// ```rust
+/// # extern crate diesel;
+/// # use diesel::*;
+/// #
+/// # #[cfg(feature = "sqlite")]
+/// # fn main() {
+/// # run().unwrap();
+/// # }
+/// #
+/// # #[cfg(not(feature = "sqlite"))]
+/// # fn main() {
+/// # }
+/// use diesel::sql_types::{Float, Nullable};
+/// # #[cfg(feature = "sqlite")]
+/// use diesel::sqlite::SqliteAggregateFunction;
+///
+/// sql_function! {
+/// #[aggregate]
+/// fn range_max(x0: Float, x1: Float) -> Nullable<Float>;
+/// }
+///
+/// #[derive(Default)]
+/// struct RangeMax<T> { max_value: Option<T> }
+///
+/// # #[cfg(feature = "sqlite")]
+/// impl<T: Default + PartialOrd + Copy + Clone> SqliteAggregateFunction<(T, T)> for RangeMax<T> {
+/// type Output = Option<T>;
+///
+/// fn step(&mut self, (x0, x1): (T, T)) {
+/// # let max = if x0 >= x1 {
+/// # x0
+/// # } else {
+/// # x1
+/// # };
+/// #
+/// # self.max_value = match self.max_value {
+/// # Some(current_max_value) if max > current_max_value => Some(max),
+/// # None => Some(max),
+/// # _ => self.max_value,
+/// # };
+/// // Compare self.max_value to x0 and x1
+/// }
+///
+/// fn finalize(aggregator: Option<Self>) -> Self::Output {
+/// aggregator?.max_value
+/// }
+/// }
+/// # table! {
+/// # student_avgs {
+/// # id -> Integer,
+/// # s1_avg -> Float,
+/// # s2_avg -> Float,
+/// # }
+/// # }
+///
+/// # #[cfg(feature = "sqlite")]
+/// fn run() -> Result<(), Box<dyn (::std::error::Error)>> {
+/// # use self::student_avgs::dsl::*;
+/// let connection = SqliteConnection::establish(":memory:")?;
+/// # connection.execute("create table student_avgs (id integer primary key autoincrement, s1_avg float, s2_avg float)").unwrap();
+/// # connection.execute("insert into student_avgs (s1_avg, s2_avg) values (85.5, 90), (79.8, 80.1)").unwrap();
+///
+/// range_max::register_impl::<RangeMax<f32>, _, _>(&connection)?;
+///
+/// let result = student_avgs.select(range_max(s1_avg, s2_avg))
+/// .get_result::<Option<f32>>(&connection)?;
+///
+/// if let Some(max_semeseter_avg) = result {
+/// println!("The largest semester average is: {}", max_semeseter_avg);
+/// }
+///
+/// # assert_eq!(Some(90f32), result);
+/// Ok(())
+/// }
+/// ```
#[proc_macro]
pub fn sql_function_proc(input: TokenStream) -> TokenStream {
expand_proc_macro(input, sql_function::expand)
diff --git a/diesel_derives/src/sql_function.rs b/diesel_derives/src/sql_function.rs
index 3838d35..a41af4f 100644
--- a/diesel_derives/src/sql_function.rs
+++ b/diesel_derives/src/sql_function.rs
@@ -47,9 +47,11 @@
.type_params()
.map(|type_param| type_param.ident.clone())
.collect::<Vec<_>>();
+
for StrictFnArg { name, .. } in args {
generics.params.push(parse_quote!(#name));
}
+
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
// Even if we force an empty where clause, it still won't print the where
// token with no bounds.
@@ -132,6 +134,69 @@
{
type IsAggregate = diesel::expression::is_aggregate::Yes;
}
+        };
+        if cfg!(feature = "sqlite") && type_args.is_empty() {
+            tokens = quote! {
+                #tokens
+
+                use diesel::sqlite::{Sqlite, SqliteConnection};
+                use diesel::serialize::ToSql;
+                use diesel::deserialize::Queryable;
+                use diesel::sqlite::SqliteAggregateFunction;
+                use diesel::sql_types::IntoNullable;
+            };
+
+            // A single argument is passed to `SqliteAggregateFunction` as is, multiple
+            // arguments are passed as a tuple. Aggregate functions without arguments
+            // do not get a `register_impl` function.
+            match arg_name.len() {
+                0 => {}
+                1 => {
+                    let arg_name = arg_name[0];
+                    let arg_type = arg_type[0];
+
+                    tokens = quote! {
+                        #tokens
+
+                        #[allow(dead_code)]
+                        /// Registers an implementation for this aggregate function on the given connection
+                        ///
+                        /// This function must be called for every `SqliteConnection` before
+                        /// this SQL function can be used on SQLite.
+                        pub fn register_impl<A, #arg_name>(
+                            conn: &SqliteConnection
+                        ) -> QueryResult<()>
+                        where
+                            A: SqliteAggregateFunction<#arg_name> + Send + 'static,
+                            A::Output: ToSql<#return_type, Sqlite>,
+                            #arg_name: Queryable<#arg_type, Sqlite>,
+                        {
+                            conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name)
+                        }
+                    };
+                }
+                _ => {
+                    tokens = quote! {
+                        #tokens
+
+                        #[allow(dead_code)]
+                        /// Registers an implementation for this aggregate function on the given connection
+                        ///
+                        /// This function must be called for every `SqliteConnection` before
+                        /// this SQL function can be used on SQLite.
+                        pub fn register_impl<A, #(#arg_name,)*>(
+                            conn: &SqliteConnection
+                        ) -> QueryResult<()>
+                        where
+                            A: SqliteAggregateFunction<(#(#arg_name,)*)> + Send + 'static,
+                            A::Output: ToSql<#return_type, Sqlite>,
+                            (#(#arg_name,)*): Queryable<(#(#arg_type,)*), Sqlite>,
+                        {
+                            conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name)
+                        }
+                    };
+                }
+            }
}
} else {
tokens = quote! {