Merge pull request #2325 from thekuom/feature/2191-sqlite-custom-aggregate

Add ability to create custom aggregate functions in sqlite
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5a4c50a..b89b0f5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,9 @@
 * Added support for SQLite's `UPSERT`.
   You can use this feature above SQLite version 3.24.0.
 
+* Added the ability to create custom aggregate functions in SQLite via the
+  `#[aggregate]` attribute of `sql_function!`.
+
 * Multiple aggregate expressions can now appear together in the same select
   clause. See [the upgrade notes](#2-0-0-upgrade-non-aggregate) for details.
 
diff --git a/diesel/src/sqlite/connection/functions.rs b/diesel/src/sqlite/connection/functions.rs
index 3e6a526..e340b64 100644
--- a/diesel/src/sqlite/connection/functions.rs
+++ b/diesel/src/sqlite/connection/functions.rs
@@ -2,7 +2,7 @@
 
 use super::raw::RawConnection;
 use super::serialized_value::SerializedValue;
-use super::{Sqlite, SqliteValue};
+use super::{Sqlite, SqliteAggregateFunction, SqliteValue};
 use crate::deserialize::{FromSqlRow, Queryable};
 use crate::result::{DatabaseErrorKind, Error, QueryResult};
 use crate::row::Row;
@@ -30,29 +30,87 @@
     }
 
     conn.register_sql_function(fn_name, fields_needed, deterministic, move |conn, args| {
-        let mut row = FunctionRow { args };
-        let args_row = Args::Row::build_from_row(&mut row).map_err(Error::DeserializationError)?;
-        let args = Args::build(args_row);
+        let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
 
         let result = f(conn, args);
 
-        let mut buf = Output::new(Vec::new(), &());
-        let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
-
-        let bytes = if let IsNull::Yes = is_null {
-            None
-        } else {
-            Some(buf.into_inner())
-        };
-
-        Ok(SerializedValue {
-            ty: Sqlite::metadata(&()),
-            data: bytes,
-        })
+        process_sql_function_result::<RetSqlType, Ret>(result)
     })?;
     Ok(())
 }
 
+/// Registers the custom aggregate function `A` as `fn_name` on `conn`.
+///
+/// Fails with `DatabaseErrorKind::UnableToSendCommand` if `Args` needs
+/// more than 127 parameters, as SQLite functions cannot take more.
+pub fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
+    conn: &RawConnection,
+    fn_name: &str,
+) -> QueryResult<()>
+where
+    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
+    Args: Queryable<ArgsSqlType, Sqlite>,
+    Ret: ToSql<RetSqlType, Sqlite>,
+    Sqlite: HasSqlType<RetSqlType>,
+{
+    let fields_needed = Args::Row::FIELDS_NEEDED;
+    if fields_needed > 127 {
+        return Err(Error::DatabaseError(
+            DatabaseErrorKind::UnableToSendCommand,
+            Box::new("SQLite functions cannot take more than 127 parameters".to_string()),
+        ));
+    }
+
+    conn.register_aggregate_function::<ArgsSqlType, RetSqlType, Args, Ret, A>(
+        fn_name,
+        fields_needed,
+    )
+}
+
+/// Deserializes the raw SQLite argument values of a single function call
+/// into the Rust type `Args`.
+///
+/// Fails with `Error::DeserializationError` if the values cannot be
+/// converted into `Args`.
+pub(crate) fn build_sql_function_args<ArgsSqlType, Args>(
+    args: &[*mut ffi::sqlite3_value],
+) -> QueryResult<Args>
+where
+    Args: Queryable<ArgsSqlType, Sqlite>,
+{
+    let mut row = FunctionRow { args };
+    let args_row = Args::Row::build_from_row(&mut row).map_err(Error::DeserializationError)?;
+
+    Ok(Args::build(args_row))
+}
+
+/// Serializes the value returned by a custom SQL function into the
+/// `SerializedValue` that is handed back to SQLite.
+///
+/// A `ToSql` impl returning `IsNull::Yes` results in a value without data
+/// (SQL `NULL`); serialization failures become `Error::SerializationError`.
+pub(crate) fn process_sql_function_result<RetSqlType, Ret>(
+    result: Ret,
+) -> QueryResult<SerializedValue>
+where
+    Ret: ToSql<RetSqlType, Sqlite>,
+    Sqlite: HasSqlType<RetSqlType>,
+{
+    let mut buf = Output::new(Vec::new(), &());
+    let is_null = result.to_sql(&mut buf).map_err(Error::SerializationError)?;
+
+    let bytes = if let IsNull::Yes = is_null {
+        None
+    } else {
+        Some(buf.into_inner())
+    };
+
+    Ok(SerializedValue {
+        ty: Sqlite::metadata(&()),
+        data: bytes,
+    })
+}
+
 struct FunctionRow<'a> {
     args: &'a [*mut ffi::sqlite3_value],
 }
diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs
index 67bbbe4..b200d15 100644
--- a/diesel/src/sqlite/connection/mod.rs
+++ b/diesel/src/sqlite/connection/mod.rs
@@ -15,6 +15,7 @@
 use self::raw::RawConnection;
 use self::statement_iterator::*;
 use self::stmt::{Statement, StatementUse};
+use super::SqliteAggregateFunction;
 use crate::connection::*;
 use crate::deserialize::{Queryable, QueryableByName};
 use crate::query_builder::bind_collector::RawBytesBindCollector;
@@ -238,6 +239,26 @@
         )
     }
 
+    /// Registers the custom aggregate function `A` under `fn_name` on this
+    /// connection. This is what the `register_impl` function generated by
+    /// `sql_function!` for `#[aggregate]` functions calls; hidden from docs.
+    #[doc(hidden)]
+    pub fn register_aggregate_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
+        &self,
+        fn_name: &str,
+    ) -> QueryResult<()>
+    where
+        A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
+        Args: Queryable<ArgsSqlType, Sqlite>,
+        Ret: ToSql<RetSqlType, Sqlite>,
+        Sqlite: HasSqlType<RetSqlType>,
+    {
+        functions::register_aggregate::<ArgsSqlType, RetSqlType, Args, Ret, A>(
+            &self.raw_connection,
+            fn_name,
+        )
+    }
+
     fn register_diesel_sql_functions(&self) -> QueryResult<()> {
         use crate::sql_types::{Integer, Text};
 
@@ -370,4 +385,158 @@
             .get_result::<(i32, i32, i32)>(&connection);
         assert_eq!(Ok((2, 3, 4)), added);
     }
+
+    use crate::sqlite::SqliteAggregateFunction;
+
+    sql_function! {
+        #[aggregate]
+        fn my_sum(expr: Integer) -> Integer;
+    }
+
+    #[derive(Default)]
+    struct MySum {
+        sum: i32,
+    }
+
+    impl SqliteAggregateFunction<i32> for MySum {
+        type Output = i32;
+
+        fn step(&mut self, expr: i32) {
+            self.sum += expr;
+        }
+
+        fn finalize(aggregator: Option<Self>) -> Self::Output {
+            aggregator.map(|a| a.sum).unwrap_or_default()
+        }
+    }
+
+    table! {
+        my_sum_example {
+            id -> Integer,
+            value -> Integer,
+        }
+    }
+
+    #[test]
+    fn register_aggregate_function() {
+        use self::my_sum_example::dsl::*;
+
+        let connection = SqliteConnection::establish(":memory:").unwrap();
+        connection
+            .execute(
+                "CREATE TABLE my_sum_example (id integer primary key autoincrement, value integer)",
+            )
+            .unwrap();
+        connection
+            .execute("INSERT INTO my_sum_example (value) VALUES (1), (2), (3)")
+            .unwrap();
+
+        my_sum::register_impl::<MySum, _>(&connection).unwrap();
+
+        let result = my_sum_example
+            .select(my_sum(value))
+            .get_result::<i32>(&connection);
+        assert_eq!(Ok(6), result);
+    }
+
+    #[test]
+    fn register_aggregate_function_returns_finalize_default_on_empty_set() {
+        use self::my_sum_example::dsl::*;
+
+        let connection = SqliteConnection::establish(":memory:").unwrap();
+        connection
+            .execute(
+                "CREATE TABLE my_sum_example (id integer primary key autoincrement, value integer)",
+            )
+            .unwrap();
+
+        my_sum::register_impl::<MySum, _>(&connection).unwrap();
+
+        let result = my_sum_example
+            .select(my_sum(value))
+            .get_result::<i32>(&connection);
+        assert_eq!(Ok(0), result);
+    }
+
+    sql_function! {
+        #[aggregate]
+        fn range_max(expr1: Integer, expr2: Integer, expr3: Integer) -> Nullable<Integer>;
+    }
+
+    #[derive(Default)]
+    struct RangeMax<T> {
+        max_value: Option<T>,
+    }
+
+    impl<T: Default + Ord + Copy> SqliteAggregateFunction<(T, T, T)> for RangeMax<T> {
+        type Output = Option<T>;
+
+        fn step(&mut self, (x0, x1, x2): (T, T, T)) {
+            let row_max = x0.max(x1).max(x2);
+            self.max_value = Some(match self.max_value {
+                Some(current_max) => current_max.max(row_max),
+                None => row_max,
+            });
+        }
+
+        fn finalize(aggregator: Option<Self>) -> Self::Output {
+            aggregator?.max_value
+        }
+    }
+
+    table! {
+        range_max_example {
+            id -> Integer,
+            value1 -> Integer,
+            value2 -> Integer,
+            value3 -> Integer,
+        }
+    }
+
+    fn create_range_max_example_table(connection: &SqliteConnection) {
+        connection
+            .execute(
+                r#"CREATE TABLE range_max_example (
+                id integer primary key autoincrement,
+                value1 integer,
+                value2 integer,
+                value3 integer
+            )"#,
+            )
+            .unwrap();
+    }
+
+    #[test]
+    fn register_aggregate_multiarg_function() {
+        use self::range_max_example::dsl::*;
+
+        let connection = SqliteConnection::establish(":memory:").unwrap();
+        create_range_max_example_table(&connection);
+        connection
+            .execute(
+                "INSERT INTO range_max_example (value1, value2, value3) VALUES (3, 2, 1), (2, 2, 2)",
+            )
+            .unwrap();
+
+        range_max::register_impl::<RangeMax<i32>, _, _, _>(&connection).unwrap();
+        let result = range_max_example
+            .select(range_max(value1, value2, value3))
+            .get_result::<Option<i32>>(&connection)
+            .unwrap();
+        assert_eq!(Some(3), result);
+    }
+
+    #[test]
+    fn register_aggregate_multiarg_function_returns_none_on_empty_set() {
+        use self::range_max_example::dsl::*;
+
+        let connection = SqliteConnection::establish(":memory:").unwrap();
+        create_range_max_example_table(&connection);
+
+        range_max::register_impl::<RangeMax<i32>, _, _, _>(&connection).unwrap();
+        let result = range_max_example
+            .select(range_max(value1, value2, value3))
+            .get_result::<Option<i32>>(&connection);
+        assert_eq!(Ok(None), result);
+    }
 }
diff --git a/diesel/src/sqlite/connection/raw.rs b/diesel/src/sqlite/connection/raw.rs
index e46e4f1..6c7aee0 100644
--- a/diesel/src/sqlite/connection/raw.rs
+++ b/diesel/src/sqlite/connection/raw.rs
@@ -1,14 +1,19 @@
 extern crate libsqlite3_sys as ffi;
 
-use std::ffi::{CStr, CString};
+use std::ffi::{CStr, CString, NulError};
 use std::io::{stderr, Write};
 use std::os::raw as libc;
 use std::ptr::NonNull;
 use std::{mem, ptr, slice, str};
 
+use super::functions::{build_sql_function_args, process_sql_function_result};
 use super::serialized_value::SerializedValue;
+use super::{Sqlite, SqliteAggregateFunction};
+use crate::deserialize::Queryable;
 use crate::result::Error::DatabaseError;
 use crate::result::*;
+use crate::serialize::ToSql;
+use crate::sql_types::HasSqlType;
 
 #[allow(missing_debug_implementations, missing_copy_implementations)]
 pub struct RawConnection {
@@ -76,11 +81,8 @@
             + Send
             + 'static,
     {
-        let fn_name = CString::new(fn_name)?;
-        let mut flags = ffi::SQLITE_UTF8;
-        if deterministic {
-            flags |= ffi::SQLITE_DETERMINISTIC;
-        }
+        let fn_name = Self::get_fn_name(fn_name)?;
+        let flags = Self::get_flags(deterministic);
         let callback_fn = Box::into_raw(Box::new(f));
 
         let result = unsafe {
@@ -97,6 +99,65 @@
             )
         };
 
+        Self::process_sql_function_result(result)
+    }
+
+    /// Registers the aggregate function `A` as `fn_name`, taking `num_args`
+    /// arguments, on this connection.
+    ///
+    /// The step and final callbacks are monomorphized for `A`, so no user
+    /// data pointer is passed to SQLite. Aggregates are never registered as
+    /// `SQLITE_DETERMINISTIC`.
+    pub fn register_aggregate_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
+        &self,
+        fn_name: &str,
+        num_args: usize,
+    ) -> QueryResult<()>
+    where
+        A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
+        Args: Queryable<ArgsSqlType, Sqlite>,
+        Ret: ToSql<RetSqlType, Sqlite>,
+        Sqlite: HasSqlType<RetSqlType>,
+    {
+        let fn_name = Self::get_fn_name(fn_name)?;
+        let flags = Self::get_flags(false);
+
+        let result = unsafe {
+            ffi::sqlite3_create_function_v2(
+                self.internal_connection.as_ptr(),
+                fn_name.as_ptr(),
+                // `num_args` is not range-checked here; `functions::register_aggregate`
+                // rejects more than 127 arguments before calling this method.
+                num_args as _,
+                flags,
+                ptr::null_mut(),
+                None,
+                Some(run_aggregator_step_function::<_, _, _, _, A>),
+                Some(run_aggregator_final_function::<_, _, _, _, A>),
+                None,
+            )
+        };
+
+        Self::process_sql_function_result(result)
+    }
+
+    /// Converts `fn_name` into the NUL-terminated string SQLite expects,
+    /// failing if it contains an interior NUL byte.
+    fn get_fn_name(fn_name: &str) -> Result<CString, NulError> {
+        CString::new(fn_name)
+    }
+
+    /// Builds the `sqlite3_create_function_v2` flags: UTF-8 text encoding,
+    /// plus `SQLITE_DETERMINISTIC` when `deterministic` is set.
+    fn get_flags(deterministic: bool) -> i32 {
+        if deterministic {
+            ffi::SQLITE_UTF8 | ffi::SQLITE_DETERMINISTIC
+        } else {
+            ffi::SQLITE_UTF8
+        }
+    }
+
+    fn process_sql_function_result(result: i32) -> Result<(), Error> {
         if result == ffi::SQLITE_OK {
             Ok(())
         } else {
@@ -194,6 +243,208 @@
     }
 }
 
+// Need a custom option type here, because the std lib one does not have guarantees about the discriminant values
+// See: https://github.com/rust-lang/rfcs/blob/master/text/2195-really-tagged-unions.md#opaque-tags
+#[repr(u8)]
+enum OptionalAggregator<A> {
+    // Discriminant is 0
+    None,
+    Some(A),
+}
+
+static NULL_AG_CTX_ERR: &str = "An unknown error occurred. sqlite3_aggregate_context returned a null pointer. This should never happen.";
+static MISALIGNED_AG_CTX_ERR: &str =
+    "sqlite3_aggregate_context returned memory that is not aligned for the aggregator type.";
+static AGGREGATOR_TOO_LARGE_ERR: &str =
+    "The aggregator type is too large to be stored in an SQLite aggregate context.";
+static AGGREGATOR_NOT_WRITTEN_ERR: &str =
+    "We've written the aggregator above to that location, it must be there";
+static NO_AGGREGATOR_ERR: &str =
+    "No aggregator found in the aggregate context. Creating it in the xStep callback failed.";
+static STEP_PANIC_ERR: &str = "A custom aggregate function panicked in its xStep callback.";
+static FINAL_PANIC_ERR: &str = "A custom aggregate function panicked in its xFinal callback.";
+
+/// Returns whether `ptr` satisfies the alignment requirement of `T`.
+fn is_aligned<T>(ptr: *const T) -> bool {
+    ptr as usize % mem::align_of::<T>() == 0
+}
+
+/// `xStep` callback registered by `RawConnection::register_aggregate_function`.
+///
+/// Errors and panics are reported to SQLite via `sqlite3_result_error`. A
+/// panic must never unwind out of this `extern "C"` function into SQLite, as
+/// that is undefined behaviour.
+extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
+    ctx: *mut ffi::sqlite3_context,
+    num_args: libc::c_int,
+    value_ptr: *mut *mut ffi::sqlite3_value,
+) where
+    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
+    Args: Queryable<ArgsSqlType, Sqlite>,
+    Ret: ToSql<RetSqlType, Sqlite>,
+    Sqlite: HasSqlType<RetSqlType>,
+{
+    // `A::default()`, `A::step()` and the argument deserialization may all
+    // panic, so any panic is caught here and turned into an SQL error.
+    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
+        run_aggregator_step::<ArgsSqlType, Args, A>(ctx, num_args, value_ptr)
+    }));
+
+    match result {
+        Ok(Ok(())) => {}
+        Ok(Err(msg)) => unsafe { report_error(ctx, &msg) },
+        Err(_) => unsafe { report_error(ctx, STEP_PANIC_ERR) },
+    }
+}
+
+/// Performs one `xStep` invocation: fetches the aggregator stored in SQLite's
+/// aggregate context (creating it with `A::default()` on first use) and feeds
+/// it the deserialized arguments of the current row.
+///
+/// # Safety
+///
+/// `ctx`, `num_args` and `value_ptr` must be the values SQLite passed to the
+/// `xStep` callback of an aggregate function registered for `A`.
+unsafe fn run_aggregator_step<ArgsSqlType, Args, A>(
+    ctx: *mut ffi::sqlite3_context,
+    num_args: libc::c_int,
+    value_ptr: *mut *mut ffi::sqlite3_value,
+) -> Result<(), String>
+where
+    A: SqliteAggregateFunction<Args>,
+    Args: Queryable<ArgsSqlType, Sqlite>,
+{
+    // This function makes the following assumptions:
+    //
+    // * sqlite3_aggregate_context allocates the requested number of bytes of
+    //   zeroed memory on its first call for an aggregate and returns that same
+    //   allocation on every later call, as documented here:
+    //   https://www.sqlite.org/c3ref/aggregate_context.html
+    //   A null pointer is returned for non-positive sizes or on allocation
+    //   failure. The size is at least 1 (the discriminant), but we check anyway.
+    //
+    // * OptionalAggregator::None has a discriminant of 0 as specified by
+    //   #[repr(u8)] + RFC 2195, so zeroed memory is a valid
+    //   OptionalAggregator::None. We only read the discriminant before
+    //   overwriting it with a fully initialized value.
+    //
+    // * We did not find any guarantee that the memory is suitably aligned for
+    //   arbitrary types, so alignment is checked before any reference to it
+    //   is created (a reference to misaligned memory is undefined behaviour).
+    let size = mem::size_of::<OptionalAggregator<A>>();
+    if size > std::i32::MAX as usize {
+        return Err(AGGREGATOR_TOO_LARGE_ERR.to_string());
+    }
+    let slot = ffi::sqlite3_aggregate_context(ctx, size as libc::c_int)
+        as *mut OptionalAggregator<A>;
+    if slot.is_null() {
+        return Err(NULL_AG_CTX_ERR.to_string());
+    }
+    if !is_aligned(slot) {
+        return Err(MISALIGNED_AG_CTX_ERR.to_string());
+    }
+
+    // First row of this aggregate: replace the zeroed `None` with a fresh
+    // aggregator. `ptr::write` does not drop the old value, which is fine as
+    // `None` owns nothing.
+    if let OptionalAggregator::None = *slot {
+        ptr::write(slot, OptionalAggregator::Some(A::default()));
+    }
+    let aggregator = match &mut *slot {
+        OptionalAggregator::Some(aggregator) => aggregator,
+        OptionalAggregator::None => return Err(AGGREGATOR_NOT_WRITTEN_ERR.to_string()),
+    };
+
+    // `slice::from_raw_parts` requires a non-null pointer, even for an empty slice.
+    let args: &[*mut ffi::sqlite3_value] = if num_args <= 0 || value_ptr.is_null() {
+        &[]
+    } else {
+        slice::from_raw_parts(value_ptr, num_args as usize)
+    };
+    let args = build_sql_function_args::<ArgsSqlType, Args>(args).map_err(|e| e.to_string())?;
+    aggregator.step(args);
+    Ok(())
+}
+
+/// `xFinal` callback registered by `RawConnection::register_aggregate_function`.
+///
+/// Errors and panics are reported to SQLite via `sqlite3_result_error` instead
+/// of unwinding into SQLite.
+extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
+    ctx: *mut ffi::sqlite3_context,
+) where
+    A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
+    Args: Queryable<ArgsSqlType, Sqlite>,
+    Ret: ToSql<RetSqlType, Sqlite>,
+    Sqlite: HasSqlType<RetSqlType>,
+{
+    // `A::finalize()` and the serialization of its result may panic.
+    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
+        run_aggregator_final::<RetSqlType, Args, Ret, A>(ctx)
+    }));
+
+    match result {
+        Ok(Ok(())) => {}
+        Ok(Err(msg)) => unsafe { report_error(ctx, &msg) },
+        Err(_) => unsafe { report_error(ctx, FINAL_PANIC_ERR) },
+    }
+}
+
+/// Performs the `xFinal` invocation: moves the aggregator out of SQLite's
+/// aggregate context, passes it (or `None` if no row was aggregated) to
+/// `A::finalize()` and stores the serialized result in `ctx`.
+///
+/// # Safety
+///
+/// `ctx` must be the context SQLite passed to the `xFinal` callback of an
+/// aggregate function registered for `A`.
+unsafe fn run_aggregator_final<RetSqlType, Args, Ret, A>(
+    ctx: *mut ffi::sqlite3_context,
+) -> Result<(), String>
+where
+    A: SqliteAggregateFunction<Args, Output = Ret>,
+    Ret: ToSql<RetSqlType, Sqlite>,
+    Sqlite: HasSqlType<RetSqlType>,
+{
+    // Within the xFinal callback, it is customary to set nBytes to 0 so no pointless memory
+    // allocations occur. A null pointer is then returned if xStep never allocated
+    // the context, i.e. if no rows were aggregated.
+    // See: https://www.sqlite.org/c3ref/aggregate_context.html
+    //
+    // For the reasoning about the safety of the OptionalAggregator handling
+    // see the comment in run_aggregator_step.
+    let slot = ffi::sqlite3_aggregate_context(ctx, 0) as *mut OptionalAggregator<A>;
+    let aggregator = if slot.is_null() {
+        None
+    } else if !is_aligned(slot) {
+        return Err(MISALIGNED_AG_CTX_ERR.to_string());
+    } else {
+        // Leave `None` behind: SQLite frees this memory itself and never runs
+        // Rust destructors, so the aggregator must be moved out to be dropped.
+        match mem::replace(&mut *slot, OptionalAggregator::None) {
+            OptionalAggregator::Some(aggregator) => Some(aggregator),
+            OptionalAggregator::None => return Err(NO_AGGREGATOR_ERR.to_string()),
+        }
+    };
+
+    let value = process_sql_function_result::<RetSqlType, Ret>(A::finalize(aggregator))
+        .map_err(|e| e.to_string())?;
+    value.result_of(ctx);
+    Ok(())
+}
+
+/// Reports `msg` to SQLite as the error result of the function call `ctx`.
+///
+/// # Safety
+///
+/// `ctx` must be a valid `sqlite3_context` handed to a function callback.
+unsafe fn report_error(ctx: *mut ffi::sqlite3_context, msg: &str) {
+    // The explicit length means `msg` needs no NUL terminator; it is clamped
+    // since a negative length would make SQLite read up to a NUL byte instead.
+    let len = msg.len().min(std::i32::MAX as usize);
+    ffi::sqlite3_result_error(ctx, msg.as_ptr() as *const _, len as _);
+}
+
 extern "C" fn destroy_boxed_fn<F>(data: *mut libc::c_void)
 where
     F: FnMut(&RawConnection, &[*mut ffi::sqlite3_value]) -> QueryResult<SerializedValue>
diff --git a/diesel/src/sqlite/mod.rs b/diesel/src/sqlite/mod.rs
index 80b5299..cbd0723 100644
--- a/diesel/src/sqlite/mod.rs
+++ b/diesel/src/sqlite/mod.rs
@@ -13,3 +13,31 @@
 pub use self::backend::{Sqlite, SqliteType};
 pub use self::connection::SqliteConnection;
 pub use self::query_builder::SqliteQueryBuilder;
+
+/// Trait for the implementation of a SQLite aggregate function
+///
+/// This trait is to be used in conjunction with the `sql_function!`
+/// macro for defining a custom SQLite aggregate function. See
+/// the documentation [there](../prelude/macro.sql_function.html) for details.
+///
+/// `Args` is the Rust type the SQL arguments of each row are deserialized
+/// into: a single value for a function taking one argument, or a tuple with
+/// one element per argument for a function taking several arguments.
+///
+/// An aggregator is created with `Default::default()` before the first row
+/// is passed to `step()`, and is consumed by `finalize()` once all rows
+/// have been processed.
+pub trait SqliteAggregateFunction<Args>: Default {
+    /// The result type of the SQLite aggregate function
+    type Output;
+
+    /// The `step()` method is called once for every row aggregated by this
+    /// function, receiving that row's arguments
+    fn step(&mut self, args: Args);
+
+    /// After the last row has been processed, the `finalize()` method is
+    /// called to compute the result of the aggregate function. If no rows
+    /// were processed, `aggregator` will be `None` and `finalize()` can be
+    /// used to specify a default result
+    fn finalize(aggregator: Option<Self>) -> Self::Output;
+}
diff --git a/diesel_derives/src/lib.rs b/diesel_derives/src/lib.rs
index 1f6d0f1..6d2cdea 100644
--- a/diesel_derives/src/lib.rs
+++ b/diesel_derives/src/lib.rs
@@ -642,6 +642,154 @@
 /// #     Ok(())
 /// # }
 /// ```
+///
+/// ## Custom Aggregate Functions
+///
+/// Custom aggregate functions can be created in SQLite by adding an `#[aggregate]`
+/// attribute inside of `sql_function!`. Before use, `register_impl` must be called
+/// on the generated function, with a type implementing the
+/// [SqliteAggregateFunction](../diesel/sqlite/trait.SqliteAggregateFunction.html)
+/// trait as its first type parameter, as shown in the examples below.
+///
+/// ```rust
+/// # extern crate diesel;
+/// # use diesel::*;
+/// #
+/// # #[cfg(feature = "sqlite")]
+/// # fn main() {
+/// #   run().unwrap();
+/// # }
+/// #
+/// # #[cfg(not(feature = "sqlite"))]
+/// # fn main() {
+/// # }
+/// use diesel::sql_types::Integer;
+/// # #[cfg(feature = "sqlite")]
+/// use diesel::sqlite::SqliteAggregateFunction;
+///
+/// sql_function! {
+///     #[aggregate]
+///     fn my_sum(x: Integer) -> Integer;
+/// }
+///
+/// #[derive(Default)]
+/// struct MySum { sum: i32 }
+///
+/// # #[cfg(feature = "sqlite")]
+/// impl SqliteAggregateFunction<i32> for MySum {
+///     type Output = i32;
+///
+///     fn step(&mut self, expr: i32) {
+///         self.sum += expr;
+///     }
+///
+///     fn finalize(aggregator: Option<Self>) -> Self::Output {
+///         aggregator.map(|a| a.sum).unwrap_or_default()
+///     }
+/// }
+/// # table! {
+/// #     players {
+/// #         id -> Integer,
+/// #         score -> Integer,
+/// #     }
+/// # }
+///
+/// # #[cfg(feature = "sqlite")]
+/// fn run() -> Result<(), Box<dyn (::std::error::Error)>> {
+/// #    use self::players::dsl::*;
+///     let connection = SqliteConnection::establish(":memory:")?;
+/// #    connection.execute("create table players (id integer primary key autoincrement, score integer)").unwrap();
+/// #    connection.execute("insert into players (score) values (10), (20), (30)").unwrap();
+///
+///     my_sum::register_impl::<MySum, _>(&connection)?;
+///
+///     let total_score = players.select(my_sum(score))
+///         .get_result::<i32>(&connection)?;
+///
+///     println!("The total score of all the players is: {}", total_score);
+///
+/// #    assert_eq!(60, total_score);
+///     Ok(())
+/// }
+/// ```
+///
+/// With multiple function arguments, the arguments are passed as a tuple to `SqliteAggregateFunction`.
+///
+/// ```rust
+/// # extern crate diesel;
+/// # use diesel::*;
+/// #
+/// # #[cfg(feature = "sqlite")]
+/// # fn main() {
+/// #   run().unwrap();
+/// # }
+/// #
+/// # #[cfg(not(feature = "sqlite"))]
+/// # fn main() {
+/// # }
+/// use diesel::sql_types::{Float, Nullable};
+/// # #[cfg(feature = "sqlite")]
+/// use diesel::sqlite::SqliteAggregateFunction;
+///
+/// sql_function! {
+///     #[aggregate]
+///     fn range_max(x0: Float, x1: Float) -> Nullable<Float>;
+/// }
+///
+/// #[derive(Default)]
+/// struct RangeMax<T> { max_value: Option<T> }
+///
+/// # #[cfg(feature = "sqlite")]
+/// impl<T: Default + PartialOrd + Copy + Clone> SqliteAggregateFunction<(T, T)> for RangeMax<T> {
+///     type Output = Option<T>;
+///
+///     fn step(&mut self, (x0, x1): (T, T)) {
+/// #        let max = if x0 >= x1 {
+/// #            x0
+/// #        } else {
+/// #            x1
+/// #        };
+/// #
+/// #        self.max_value = match self.max_value {
+/// #            Some(current_max_value) if max > current_max_value => Some(max),
+/// #            None => Some(max),
+/// #            _ => self.max_value,
+/// #        };
+///         // Compare self.max_value to x0 and x1
+///     }
+///
+///     fn finalize(aggregator: Option<Self>) -> Self::Output {
+///         aggregator?.max_value
+///     }
+/// }
+/// # table! {
+/// #     student_avgs {
+/// #         id -> Integer,
+/// #         s1_avg -> Float,
+/// #         s2_avg -> Float,
+/// #     }
+/// # }
+///
+/// # #[cfg(feature = "sqlite")]
+/// fn run() -> Result<(), Box<dyn (::std::error::Error)>> {
+/// #    use self::student_avgs::dsl::*;
+///     let connection = SqliteConnection::establish(":memory:")?;
+/// #    connection.execute("create table student_avgs (id integer primary key autoincrement, s1_avg float, s2_avg float)").unwrap();
+/// #    connection.execute("insert into student_avgs (s1_avg, s2_avg) values (85.5, 90), (79.8, 80.1)").unwrap();
+///
+///     range_max::register_impl::<RangeMax<f32>, _, _>(&connection)?;
+///
+///     let result = student_avgs.select(range_max(s1_avg, s2_avg))
+///         .get_result::<Option<f32>>(&connection)?;
+///
+///     if let Some(max_semester_avg) = result {
+///         println!("The largest semester average is: {}", max_semester_avg);
+///     }
+///
+/// #    assert_eq!(Some(90f32), result);
+///     Ok(())
+/// }
+/// ```
 #[proc_macro]
 pub fn sql_function_proc(input: TokenStream) -> TokenStream {
     expand_proc_macro(input, sql_function::expand)
diff --git a/diesel_derives/src/sql_function.rs b/diesel_derives/src/sql_function.rs
index 3838d35..a41af4f 100644
--- a/diesel_derives/src/sql_function.rs
+++ b/diesel_derives/src/sql_function.rs
@@ -47,9 +47,11 @@
         .type_params()
         .map(|type_param| type_param.ident.clone())
         .collect::<Vec<_>>();
+
     for StrictFnArg { name, .. } in args {
         generics.params.push(parse_quote!(#name));
     }
+
     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
     // Even if we force an empty where clause, it still won't print the where
     // token with no bounds.
@@ -132,6 +134,68 @@
             {
                 type IsAggregate = diesel::expression::is_aggregate::Yes;
             }
+        };
+        if cfg!(feature = "sqlite") && type_args.is_empty() {
+            tokens = quote! {
+                #tokens
+
+                use diesel::sqlite::{Sqlite, SqliteConnection};
+                use diesel::serialize::ToSql;
+                use diesel::deserialize::Queryable;
+                use diesel::sqlite::SqliteAggregateFunction;
+                use diesel::sql_types::IntoNullable;
+            };
+
+            match arg_name.len() {
+                x if x > 1 => {
+                    tokens = quote! {
+                        #tokens
+
+                        #[allow(dead_code)]
+                        /// Registers an implementation for this aggregate function on the given connection
+                        ///
+                        /// This function must be called for every `SqliteConnection` before
+                        /// this SQL function can be used on SQLite. `A` is the type implementing
+                        /// `SqliteAggregateFunction` that computes the aggregate's result.
+                        pub fn register_impl<A, #(#arg_name,)*>(
+                            conn: &SqliteConnection
+                        ) -> QueryResult<()>
+                            where
+                            A: SqliteAggregateFunction<(#(#arg_name,)*)> + Send + 'static,
+                            A::Output: ToSql<#return_type, Sqlite>,
+                            (#(#arg_name,)*): Queryable<(#(#arg_type,)*), Sqlite>,
+                        {
+                            conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name)
+                        }
+                    };
+                }
+                x if x == 1 => {
+                    let arg_name = arg_name[0];
+                    let arg_type = arg_type[0];
+
+                    tokens = quote! {
+                        #tokens
+
+                        #[allow(dead_code)]
+                        /// Registers an implementation for this aggregate function on the given connection
+                        ///
+                        /// This function must be called for every `SqliteConnection` before
+                        /// this SQL function can be used on SQLite. `A` is the type implementing
+                        /// `SqliteAggregateFunction` that computes the aggregate's result.
+                        pub fn register_impl<A, #arg_name>(
+                            conn: &SqliteConnection
+                        ) -> QueryResult<()>
+                            where
+                            A: SqliteAggregateFunction<#arg_name> + Send + 'static,
+                            A::Output: ToSql<#return_type, Sqlite>,
+                            #arg_name: Queryable<#arg_type, Sqlite>,
+                            {
+                                conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name)
+                            }
+                    };
+                }
+                _ => (),
+            }
         }
     } else {
         tokens = quote! {