1 //! Create or redefine SQL functions.
2 //!
3 //! # Example
4 //!
//! Adding a `regexp` function to a connection. The compiled regular
//! expression is cached with SQLite's
//! [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html)
//! interface (via [`Context::get_or_create_aux`]), so a constant pattern is
//! not recompiled for every row a query evaluates. The unit tests for this
//! module contain a similar implementation.
10 //!
11 //! ```rust
12 //! use regex::Regex;
13 //! use rusqlite::functions::FunctionFlags;
14 //! use rusqlite::{Connection, Error, Result};
15 //! use std::sync::Arc;
16 //! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
17 //!
18 //! fn add_regexp_function(db: &Connection) -> Result<()> {
19 //! db.create_scalar_function(
20 //! "regexp",
21 //! 2,
22 //! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
23 //! move |ctx| {
24 //! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
25 //! let regexp: Arc<Regex> = ctx.get_or_create_aux(0, |vr| -> Result<_, BoxError> {
26 //! Ok(Regex::new(vr.as_str()?)?)
27 //! })?;
28 //! let is_match = {
29 //! let text = ctx
30 //! .get_raw(1)
31 //! .as_str()
32 //! .map_err(|e| Error::UserFunctionError(e.into()))?;
33 //!
34 //! regexp.is_match(text)
35 //! };
36 //!
37 //! Ok(is_match)
38 //! },
39 //! )
40 //! }
41 //!
42 //! fn main() -> Result<()> {
43 //! let db = Connection::open_in_memory()?;
44 //! add_regexp_function(&db)?;
45 //!
46 //! let is_match: bool =
47 //! db.query_row("SELECT regexp('[aeiou]*', 'aaaaeeeiii')", [], |row| {
48 //! row.get(0)
49 //! })?;
50 //!
51 //! assert!(is_match);
52 //! Ok(())
53 //! }
54 //! ```
55 use std::any::Any;
56 use std::marker::PhantomData;
57 use std::ops::Deref;
58 use std::os::raw::{c_int, c_void};
59 use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60 use std::ptr;
61 use std::slice;
62 use std::sync::Arc;
63
64 use crate::ffi;
65 use crate::ffi::sqlite3_context;
66 use crate::ffi::sqlite3_value;
67
68 use crate::context::set_result;
69 use crate::types::{FromSql, FromSqlError, ToSql, ValueRef};
70
71 use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};
72
report_error(ctx: *mut sqlite3_context, err: &Error)73 unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74 // Extended constraint error codes were added in SQLite 3.7.16. We don't have
75 // an explicit feature check for that, and this doesn't really warrant one.
76 // We'll use the extended code if we're on the bundled version (since it's
77 // at least 3.17.0) and the normal constraint error code if not.
78 fn constraint_error_code() -> i32 {
79 ffi::SQLITE_CONSTRAINT_FUNCTION
80 }
81
82 if let Error::SqliteFailure(ref err, ref s) = *err {
83 ffi::sqlite3_result_error_code(ctx, err.extended_code);
84 if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
85 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
86 }
87 } else {
88 ffi::sqlite3_result_error_code(ctx, constraint_error_code());
89 if let Ok(cstr) = str_to_cstring(&err.to_string()) {
90 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
91 }
92 }
93 }
94
/// Destructor callback handed to SQLite: reclaims a value that was leaked
/// with `Box::<T>::into_raw` and drops it.
///
/// # Safety
///
/// `p` must come from `Box::<T>::into_raw` and must not be used again after
/// this call.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    let owned: Box<T> = Box::from_raw(p as *mut T);
    std::mem::drop(owned);
}
98
/// Context is a wrapper for the SQLite function
/// evaluation context.
///
/// It pairs the raw `sqlite3_context` pointer that SQLite passes to a
/// callback with that callback's argument values. The `'a` lifetime bounds
/// the borrowed argument slice.
pub struct Context<'a> {
    // Raw context pointer handed to the callback by SQLite.
    ctx: *mut sqlite3_context,
    // The callback's `argv` array viewed as a slice of `argc` values
    // (an empty slice when built for the finalize callback).
    args: &'a [*mut sqlite3_value],
}
105
106 impl Context<'_> {
107 /// Returns the number of arguments to the function.
108 #[inline]
109 #[must_use]
len(&self) -> usize110 pub fn len(&self) -> usize {
111 self.args.len()
112 }
113
114 /// Returns `true` when there is no argument.
115 #[inline]
116 #[must_use]
is_empty(&self) -> bool117 pub fn is_empty(&self) -> bool {
118 self.args.is_empty()
119 }
120
121 /// Returns the `idx`th argument as a `T`.
122 ///
123 /// # Failure
124 ///
125 /// Will panic if `idx` is greater than or equal to
126 /// [`self.len()`](Context::len).
127 ///
128 /// Will return Err if the underlying SQLite type cannot be converted to a
129 /// `T`.
get<T: FromSql>(&self, idx: usize) -> Result<T>130 pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
131 let arg = self.args[idx];
132 let value = unsafe { ValueRef::from_value(arg) };
133 FromSql::column_result(value).map_err(|err| match err {
134 FromSqlError::InvalidType => {
135 Error::InvalidFunctionParameterType(idx, value.data_type())
136 }
137 FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
138 FromSqlError::Other(err) => {
139 Error::FromSqlConversionFailure(idx, value.data_type(), err)
140 }
141 FromSqlError::InvalidBlobSize { .. } => {
142 Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
143 }
144 })
145 }
146
147 /// Returns the `idx`th argument as a `ValueRef`.
148 ///
149 /// # Failure
150 ///
151 /// Will panic if `idx` is greater than or equal to
152 /// [`self.len()`](Context::len).
153 #[inline]
154 #[must_use]
get_raw(&self, idx: usize) -> ValueRef<'_>155 pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
156 let arg = self.args[idx];
157 unsafe { ValueRef::from_value(arg) }
158 }
159
160 /// Returns the subtype of `idx`th argument.
161 ///
162 /// # Failure
163 ///
164 /// Will panic if `idx` is greater than or equal to
165 /// [`self.len()`](Context::len).
get_subtype(&self, idx: usize) -> std::os::raw::c_uint166 pub fn get_subtype(&self, idx: usize) -> std::os::raw::c_uint {
167 let arg = self.args[idx];
168 unsafe { ffi::sqlite3_value_subtype(arg) }
169 }
170
171 /// Fetch or insert the auxiliary data associated with a particular
172 /// parameter. This is intended to be an easier-to-use way of fetching it
173 /// compared to calling [`get_aux`](Context::get_aux) and
174 /// [`set_aux`](Context::set_aux) separately.
175 ///
176 /// See `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
177 /// this feature, or the unit tests of this module for an example.
get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>> where T: Send + Sync + 'static, E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>, F: FnOnce(ValueRef<'_>) -> Result<T, E>,178 pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
179 where
180 T: Send + Sync + 'static,
181 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
182 F: FnOnce(ValueRef<'_>) -> Result<T, E>,
183 {
184 if let Some(v) = self.get_aux(arg)? {
185 Ok(v)
186 } else {
187 let vr = self.get_raw(arg as usize);
188 self.set_aux(
189 arg,
190 func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
191 )
192 }
193 }
194
195 /// Sets the auxiliary data associated with a particular parameter. See
196 /// `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
197 /// this feature, or the unit tests of this module for an example.
set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>>198 pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
199 let orig: Arc<T> = Arc::new(value);
200 let inner: AuxInner = orig.clone();
201 let outer = Box::new(inner);
202 let raw: *mut AuxInner = Box::into_raw(outer);
203 unsafe {
204 ffi::sqlite3_set_auxdata(
205 self.ctx,
206 arg,
207 raw.cast(),
208 Some(free_boxed_value::<AuxInner>),
209 );
210 };
211 Ok(orig)
212 }
213
214 /// Gets the auxiliary data that was associated with a given parameter via
215 /// [`set_aux`](Context::set_aux). Returns `Ok(None)` if no data has been
216 /// associated, and Ok(Some(v)) if it has. Returns an error if the
217 /// requested type does not match.
get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>>218 pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
219 let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
220 if p.is_null() {
221 Ok(None)
222 } else {
223 let v: AuxInner = AuxInner::clone(unsafe { &*p });
224 v.downcast::<T>()
225 .map(Some)
226 .map_err(|_| Error::GetAuxWrongType)
227 }
228 }
229
230 /// Get the db connection handle via [sqlite3_context_db_handle](https://www.sqlite.org/c3ref/context_db_handle.html)
231 ///
232 /// # Safety
233 ///
234 /// This function is marked unsafe because there is a potential for other
235 /// references to the connection to be sent across threads, [see this comment](https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213).
get_connection(&self) -> Result<ConnectionRef<'_>>236 pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> {
237 let handle = ffi::sqlite3_context_db_handle(self.ctx);
238 Ok(ConnectionRef {
239 conn: Connection::from_handle(handle)?,
240 phantom: PhantomData,
241 })
242 }
243
244 /// Set the Subtype of an SQL function
set_result_subtype(&self, sub_type: std::os::raw::c_uint)245 pub fn set_result_subtype(&self, sub_type: std::os::raw::c_uint) {
246 unsafe { ffi::sqlite3_result_subtype(self.ctx, sub_type) };
247 }
248 }
249
/// A reference to a connection handle with a lifetime bound to something.
///
/// Returned by [`Context::get_connection`] and dereferences to
/// [`Connection`]. The `'ctx` lifetime borrows from the originating
/// `Context`, so the handle cannot outlive that borrow.
pub struct ConnectionRef<'ctx> {
    // comes from Connection::from_handle(sqlite3_context_db_handle(...))
    // and is non-owning
    conn: Connection,
    // Ties this value to the lifetime of the `Context` it was obtained from.
    phantom: PhantomData<&'ctx Context<'ctx>>,
}
257
258 impl Deref for ConnectionRef<'_> {
259 type Target = Connection;
260
261 #[inline]
deref(&self) -> &Connection262 fn deref(&self) -> &Connection {
263 &self.conn
264 }
265 }
266
/// Type-erased auxiliary data as stored with SQLite by [`Context::set_aux`];
/// [`Context::get_aux`] downcasts it back to the requested concrete type.
type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
268
/// Aggregate is the callback interface for user-defined
/// aggregate function.
///
/// `A` is the type of the aggregation context and `T` is the type of the final
/// result. Implementations should be stateless.
///
/// The callbacks are invoked inside `catch_unwind`. A panic is reported to
/// SQLite as `Error::UnwindingPanic`, and an `Err` return is reported as the
/// function's error. `A` is accordingly required to be unwind-safe.
pub trait Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: ToSql,
{
    /// Initializes the aggregation context. Will be called prior to the first
    /// call to [`step()`](Aggregate::step) to set up the context for an
    /// invocation of the function. (Note: `init()` will not be called if
    /// there are no rows.)
    ///
    /// The context passed in carries the arguments of the first row. If
    /// `init()` returns an error, `step()` is not called for that row and the
    /// error is reported to SQLite.
    fn init(&self, _: &mut Context<'_>) -> Result<A>;

    /// "step" function called once for each row in an aggregate group. May be
    /// called 0 times if there are no rows.
    fn step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;

    /// Computes and returns the final result. Will be called exactly once for
    /// each invocation of the function. If [`step()`](Aggregate::step) was
    /// called at least once, will be given `Some(A)` (the same `A` as was
    /// created by [`init`](Aggregate::init) and given to
    /// [`step`](Aggregate::step)); if [`step()`](Aggregate::step) was not
    /// called (because the function is running against 0 rows), will be
    /// given `None`.
    ///
    /// The passed context will have no arguments.
    fn finalize(&self, _: &mut Context<'_>, _: Option<A>) -> Result<T>;
}
300
/// `WindowAggregate` is the callback interface for
/// user-defined aggregate window function.
///
/// See `https://sqlite.org/windowfunctions.html#udfwinfunc` for the SQLite
/// semantics of the additional `value` and `inverse` callbacks.
#[cfg(feature = "window")]
#[cfg_attr(docsrs, doc(cfg(feature = "window")))]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: ToSql,
{
    /// Returns the current value of the aggregate. Unlike xFinal, the
    /// implementation should not delete any context.
    ///
    /// Receives `None` when no aggregation state has been created yet.
    fn value(&self, _: Option<&A>) -> Result<T>;

    /// Removes a row from the current window.
    ///
    /// Receives the state previously created by [`init`](Aggregate::init)
    /// and updated by [`step`](Aggregate::step).
    fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
}
317
318 bitflags::bitflags! {
319 /// Function Flags.
320 /// See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html)
321 /// and [Function Flags](https://sqlite.org/c3ref/c_deterministic.html) for details.
322 #[repr(C)]
323 pub struct FunctionFlags: ::std::os::raw::c_int {
324 /// Specifies UTF-8 as the text encoding this SQL function prefers for its parameters.
325 const SQLITE_UTF8 = ffi::SQLITE_UTF8;
326 /// Specifies UTF-16 using little-endian byte order as the text encoding this SQL function prefers for its parameters.
327 const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE;
328 /// Specifies UTF-16 using big-endian byte order as the text encoding this SQL function prefers for its parameters.
329 const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE;
330 /// Specifies UTF-16 using native byte order as the text encoding this SQL function prefers for its parameters.
331 const SQLITE_UTF16 = ffi::SQLITE_UTF16;
332 /// Means that the function always gives the same output when the input parameters are the same.
333 const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC; // 3.8.3
334 /// Means that the function may only be invoked from top-level SQL.
335 const SQLITE_DIRECTONLY = 0x0000_0008_0000; // 3.30.0
336 /// Indicates to SQLite that a function may call `sqlite3_value_subtype()` to inspect the sub-types of its arguments.
337 const SQLITE_SUBTYPE = 0x0000_0010_0000; // 3.30.0
338 /// Means that the function is unlikely to cause problems even if misused.
339 const SQLITE_INNOCUOUS = 0x0000_0020_0000; // 3.31.0
340 }
341 }
342
343 impl Default for FunctionFlags {
344 #[inline]
default() -> FunctionFlags345 fn default() -> FunctionFlags {
346 FunctionFlags::SQLITE_UTF8
347 }
348 }
349
350 impl Connection {
351 /// Attach a user-defined scalar function to
352 /// this database connection.
353 ///
354 /// `fn_name` is the name the function will be accessible from SQL.
355 /// `n_arg` is the number of arguments to the function. Use `-1` for a
356 /// variable number. If the function always returns the same value
357 /// given the same input, `deterministic` should be `true`.
358 ///
359 /// The function will remain available until the connection is closed or
360 /// until it is explicitly removed via
361 /// [`remove_function`](Connection::remove_function).
362 ///
363 /// # Example
364 ///
365 /// ```rust
366 /// # use rusqlite::{Connection, Result};
367 /// # use rusqlite::functions::FunctionFlags;
368 /// fn scalar_function_example(db: Connection) -> Result<()> {
369 /// db.create_scalar_function(
370 /// "halve",
371 /// 1,
372 /// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
373 /// |ctx| {
374 /// let value = ctx.get::<f64>(0)?;
375 /// Ok(value / 2f64)
376 /// },
377 /// )?;
378 ///
379 /// let six_halved: f64 = db.query_row("SELECT halve(6)", [], |r| r.get(0))?;
380 /// assert_eq!(six_halved, 3f64);
381 /// Ok(())
382 /// }
383 /// ```
384 ///
385 /// # Failure
386 ///
387 /// Will return Err if the function could not be attached to the connection.
388 #[inline]
create_scalar_function<F, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, x_func: F, ) -> Result<()> where F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, T: ToSql,389 pub fn create_scalar_function<F, T>(
390 &self,
391 fn_name: &str,
392 n_arg: c_int,
393 flags: FunctionFlags,
394 x_func: F,
395 ) -> Result<()>
396 where
397 F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
398 T: ToSql,
399 {
400 self.db
401 .borrow_mut()
402 .create_scalar_function(fn_name, n_arg, flags, x_func)
403 }
404
405 /// Attach a user-defined aggregate function to this
406 /// database connection.
407 ///
408 /// # Failure
409 ///
410 /// Will return Err if the function could not be attached to the connection.
411 #[inline]
create_aggregate_function<A, D, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: D, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T> + 'static, T: ToSql,412 pub fn create_aggregate_function<A, D, T>(
413 &self,
414 fn_name: &str,
415 n_arg: c_int,
416 flags: FunctionFlags,
417 aggr: D,
418 ) -> Result<()>
419 where
420 A: RefUnwindSafe + UnwindSafe,
421 D: Aggregate<A, T> + 'static,
422 T: ToSql,
423 {
424 self.db
425 .borrow_mut()
426 .create_aggregate_function(fn_name, n_arg, flags, aggr)
427 }
428
429 /// Attach a user-defined aggregate window function to
430 /// this database connection.
431 ///
432 /// See `https://sqlite.org/windowfunctions.html#udfwinfunc` for more
433 /// information.
434 #[cfg(feature = "window")]
435 #[cfg_attr(docsrs, doc(cfg(feature = "window")))]
436 #[inline]
create_window_function<A, W, T>( &self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: W, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate<A, T> + 'static, T: ToSql,437 pub fn create_window_function<A, W, T>(
438 &self,
439 fn_name: &str,
440 n_arg: c_int,
441 flags: FunctionFlags,
442 aggr: W,
443 ) -> Result<()>
444 where
445 A: RefUnwindSafe + UnwindSafe,
446 W: WindowAggregate<A, T> + 'static,
447 T: ToSql,
448 {
449 self.db
450 .borrow_mut()
451 .create_window_function(fn_name, n_arg, flags, aggr)
452 }
453
454 /// Removes a user-defined function from this
455 /// database connection.
456 ///
457 /// `fn_name` and `n_arg` should match the name and number of arguments
458 /// given to [`create_scalar_function`](Connection::create_scalar_function)
459 /// or [`create_aggregate_function`](Connection::create_aggregate_function).
460 ///
461 /// # Failure
462 ///
463 /// Will return Err if the function could not be removed.
464 #[inline]
remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()>465 pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> {
466 self.db.borrow_mut().remove_function(fn_name, n_arg)
467 }
468 }
469
470 impl InnerConnection {
create_scalar_function<F, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, x_func: F, ) -> Result<()> where F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, T: ToSql,471 fn create_scalar_function<F, T>(
472 &mut self,
473 fn_name: &str,
474 n_arg: c_int,
475 flags: FunctionFlags,
476 x_func: F,
477 ) -> Result<()>
478 where
479 F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
480 T: ToSql,
481 {
482 unsafe extern "C" fn call_boxed_closure<F, T>(
483 ctx: *mut sqlite3_context,
484 argc: c_int,
485 argv: *mut *mut sqlite3_value,
486 ) where
487 F: FnMut(&Context<'_>) -> Result<T>,
488 T: ToSql,
489 {
490 let r = catch_unwind(|| {
491 let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>();
492 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
493 let ctx = Context {
494 ctx,
495 args: slice::from_raw_parts(argv, argc as usize),
496 };
497 (*boxed_f)(&ctx)
498 });
499 let t = match r {
500 Err(_) => {
501 report_error(ctx, &Error::UnwindingPanic);
502 return;
503 }
504 Ok(r) => r,
505 };
506 let t = t.as_ref().map(|t| ToSql::to_sql(t));
507
508 match t {
509 Ok(Ok(ref value)) => set_result(ctx, value),
510 Ok(Err(err)) => report_error(ctx, &err),
511 Err(err) => report_error(ctx, err),
512 }
513 }
514
515 let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
516 let c_name = str_to_cstring(fn_name)?;
517 let r = unsafe {
518 ffi::sqlite3_create_function_v2(
519 self.db(),
520 c_name.as_ptr(),
521 n_arg,
522 flags.bits(),
523 boxed_f.cast::<c_void>(),
524 Some(call_boxed_closure::<F, T>),
525 None,
526 None,
527 Some(free_boxed_value::<F>),
528 )
529 };
530 self.decode_result(r)
531 }
532
create_aggregate_function<A, D, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: D, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T> + 'static, T: ToSql,533 fn create_aggregate_function<A, D, T>(
534 &mut self,
535 fn_name: &str,
536 n_arg: c_int,
537 flags: FunctionFlags,
538 aggr: D,
539 ) -> Result<()>
540 where
541 A: RefUnwindSafe + UnwindSafe,
542 D: Aggregate<A, T> + 'static,
543 T: ToSql,
544 {
545 let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
546 let c_name = str_to_cstring(fn_name)?;
547 let r = unsafe {
548 ffi::sqlite3_create_function_v2(
549 self.db(),
550 c_name.as_ptr(),
551 n_arg,
552 flags.bits(),
553 boxed_aggr.cast::<c_void>(),
554 None,
555 Some(call_boxed_step::<A, D, T>),
556 Some(call_boxed_final::<A, D, T>),
557 Some(free_boxed_value::<D>),
558 )
559 };
560 self.decode_result(r)
561 }
562
563 #[cfg(feature = "window")]
create_window_function<A, W, T>( &mut self, fn_name: &str, n_arg: c_int, flags: FunctionFlags, aggr: W, ) -> Result<()> where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate<A, T> + 'static, T: ToSql,564 fn create_window_function<A, W, T>(
565 &mut self,
566 fn_name: &str,
567 n_arg: c_int,
568 flags: FunctionFlags,
569 aggr: W,
570 ) -> Result<()>
571 where
572 A: RefUnwindSafe + UnwindSafe,
573 W: WindowAggregate<A, T> + 'static,
574 T: ToSql,
575 {
576 let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
577 let c_name = str_to_cstring(fn_name)?;
578 let r = unsafe {
579 ffi::sqlite3_create_window_function(
580 self.db(),
581 c_name.as_ptr(),
582 n_arg,
583 flags.bits(),
584 boxed_aggr.cast::<c_void>(),
585 Some(call_boxed_step::<A, W, T>),
586 Some(call_boxed_final::<A, W, T>),
587 Some(call_boxed_value::<A, W, T>),
588 Some(call_boxed_inverse::<A, W, T>),
589 Some(free_boxed_value::<W>),
590 )
591 };
592 self.decode_result(r)
593 }
594
remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()>595 fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
596 let c_name = str_to_cstring(fn_name)?;
597 let r = unsafe {
598 ffi::sqlite3_create_function_v2(
599 self.db(),
600 c_name.as_ptr(),
601 n_arg,
602 ffi::SQLITE_UTF8,
603 ptr::null_mut(),
604 None,
605 None,
606 None,
607 None,
608 )
609 };
610 self.decode_result(r)
611 }
612 }
613
aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A>614 unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
615 let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
616 if pac.is_null() {
617 return None;
618 }
619 Some(pac)
620 }
621
call_boxed_step<A, D, T>( ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value, ) where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,622 unsafe extern "C" fn call_boxed_step<A, D, T>(
623 ctx: *mut sqlite3_context,
624 argc: c_int,
625 argv: *mut *mut sqlite3_value,
626 ) where
627 A: RefUnwindSafe + UnwindSafe,
628 D: Aggregate<A, T>,
629 T: ToSql,
630 {
631 let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
632 pac
633 } else {
634 ffi::sqlite3_result_error_nomem(ctx);
635 return;
636 };
637
638 let r = catch_unwind(|| {
639 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
640 assert!(
641 !boxed_aggr.is_null(),
642 "Internal error - null aggregate pointer"
643 );
644 let mut ctx = Context {
645 ctx,
646 args: slice::from_raw_parts(argv, argc as usize),
647 };
648
649 if (*pac as *mut A).is_null() {
650 *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?));
651 }
652
653 (*boxed_aggr).step(&mut ctx, &mut **pac)
654 });
655 let r = match r {
656 Err(_) => {
657 report_error(ctx, &Error::UnwindingPanic);
658 return;
659 }
660 Ok(r) => r,
661 };
662 match r {
663 Ok(_) => {}
664 Err(err) => report_error(ctx, &err),
665 };
666 }
667
/// `xInverse` trampoline for window functions: hands the current row and the
/// aggregation state to [`WindowAggregate::inverse`], reporting errors and
/// panics to SQLite.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
    ctx: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    let slot = match aggregate_context::<A>(ctx, std::mem::size_of::<*mut A>()) {
        Some(slot) => slot,
        None => {
            ffi::sqlite3_result_error_nomem(ctx);
            return;
        }
    };

    let outcome = catch_unwind(|| {
        let aggr = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        let mut fn_ctx = Context {
            ctx,
            args: slice::from_raw_parts(argv, argc as usize),
        };
        // NOTE(review): assumes `call_boxed_step` has already stored a state
        // in the slot; a null slot would be dereferenced here.
        (*aggr).inverse(&mut fn_ctx, &mut **slot)
    });

    match outcome {
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
        Ok(Err(err)) => report_error(ctx, &err),
        Ok(Ok(())) => {}
    }
}
709
call_boxed_final<A, D, T>(ctx: *mut sqlite3_context) where A: RefUnwindSafe + UnwindSafe, D: Aggregate<A, T>, T: ToSql,710 unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
711 where
712 A: RefUnwindSafe + UnwindSafe,
713 D: Aggregate<A, T>,
714 T: ToSql,
715 {
716 // Within the xFinal callback, it is customary to set N=0 in calls to
717 // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
718 let a: Option<A> = match aggregate_context(ctx, 0) {
719 Some(pac) => {
720 if (*pac as *mut A).is_null() {
721 None
722 } else {
723 let a = Box::from_raw(*pac);
724 Some(*a)
725 }
726 }
727 None => None,
728 };
729
730 let r = catch_unwind(|| {
731 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
732 assert!(
733 !boxed_aggr.is_null(),
734 "Internal error - null aggregate pointer"
735 );
736 let mut ctx = Context { ctx, args: &mut [] };
737 (*boxed_aggr).finalize(&mut ctx, a)
738 });
739 let t = match r {
740 Err(_) => {
741 report_error(ctx, &Error::UnwindingPanic);
742 return;
743 }
744 Ok(r) => r,
745 };
746 let t = t.as_ref().map(|t| ToSql::to_sql(t));
747 match t {
748 Ok(Ok(ref value)) => set_result(ctx, value),
749 Ok(Err(err)) => report_error(ctx, &err),
750 Err(err) => report_error(ctx, err),
751 }
752 }
753
/// `xValue` trampoline for window functions: lends the current aggregation
/// state (if any) to [`WindowAggregate::value`] without consuming it, and
/// publishes the result or error to SQLite.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    // Requesting 0 bytes avoids allocating a buffer just to find it empty;
    // `as_ref` yields `None` for a still-null slot.
    let state: Option<&A> = aggregate_context::<A>(ctx, 0).and_then(|slot| (*slot).as_ref());

    let outcome = catch_unwind(|| {
        let aggr = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        (*aggr).value(state)
    });

    let produced = match outcome {
        Ok(produced) => produced,
        Err(_) => {
            report_error(ctx, &Error::UnwindingPanic);
            return;
        }
    };

    match produced.as_ref().map(|value| value.to_sql()) {
        Ok(Ok(ref out)) => set_result(ctx, out),
        Ok(Err(err)) => report_error(ctx, &err),
        Err(err) => report_error(ctx, err),
    }
}
797
#[cfg(test)]
mod test {
    use regex::Regex;
    use std::os::raw::c_double;

    #[cfg(feature = "window")]
    use crate::functions::WindowAggregate;
    use crate::functions::{Aggregate, Context, FunctionFlags};
    use crate::{Connection, Error, Result};

    fn half(ctx: &Context<'_>) -> Result<c_double> {
        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
        let value = ctx.get::<c_double>(0)?;
        Ok(value / 2f64)
    }

    #[test]
    fn test_function_half() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        let result: f64 = db.one_column("SELECT half(6)")?;

        assert!((3f64 - result).abs() < f64::EPSILON);
        Ok(())
    }

    #[test]
    fn test_remove_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        let result: f64 = db.one_column("SELECT half(6)")?;
        assert!((3f64 - result).abs() < f64::EPSILON);

        db.remove_function("half", 1)?;
        let result: Result<f64> = db.one_column("SELECT half(6)");
        result.unwrap_err();
        Ok(())
    }

    // This implementation of a regexp scalar function uses SQLite's auxiliary data
    // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular
    // expression multiple times within one query.
    fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<bool> {
        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
        let regexp: std::sync::Arc<Regex> = ctx
            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
                Ok(Regex::new(vr.as_str()?)?)
            })?;

        let is_match = {
            let text = ctx
                .get_raw(1)
                .as_str()
                .map_err(|e| Error::UserFunctionError(e.into()))?;

            regexp.is_match(text)
        };

        Ok(is_match)
    }

    #[test]
    fn test_function_regexp_with_auxiliary() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            "BEGIN;
             CREATE TABLE foo (x string);
             INSERT INTO foo VALUES ('lisa');
             INSERT INTO foo VALUES ('lXsi');
             INSERT INTO foo VALUES ('lisX');
             END;",
        )?;
        db.create_scalar_function(
            "regexp",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            regexp_with_auxiliary,
        )?;

        let result: bool = db.one_column("SELECT regexp('l.s[aeiouy]', 'lisa')")?;

        assert!(result);

        let result: i64 =
            db.one_column("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1")?;

        assert_eq!(2, result);
        Ok(())
    }

    #[test]
    fn test_varargs_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "my_concat",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            |ctx| {
                let mut ret = String::new();

                for idx in 0..ctx.len() {
                    let s = ctx.get::<String>(idx)?;
                    ret.push_str(&s);
                }

                Ok(ret)
            },
        )?;

        for &(expected, query) in &[
            ("", "SELECT my_concat()"),
            ("onetwo", "SELECT my_concat('one', 'two')"),
            ("abc", "SELECT my_concat('a', 'b', 'c')"),
        ] {
            let result: String = db.one_column(query)?;
            assert_eq!(expected, result);
        }
        Ok(())
    }

    #[test]
    fn test_get_aux_type_checking() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
            if !ctx.get::<bool>(1)? {
                ctx.set_aux::<i64>(0, 100)?;
            } else {
                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
                assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
            }
            Ok(true)
        })?;

        // Step through *every* row. Only the second row (i = 1) exercises the
        // `get_aux` assertions, and `one_column` would stop after the first
        // row. A panic inside the function surfaces here as an error.
        let mut stmt =
            db.prepare("SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)")?;
        let results = stmt
            .query_map([], |row| row.get::<_, bool>(0))?
            .collect::<Result<Vec<bool>>>()?;
        assert_eq!(results, vec![true, true]);
        Ok(())
    }

    struct Sum;
    struct Count;

    impl Aggregate<i64, Option<i64>> for Sum {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += ctx.get::<i64>(0)?;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> {
            Ok(sum)
        }
    }

    impl Aggregate<i64, i64> for Count {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += 1;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> {
            Ok(sum.unwrap_or(0))
        }
    }

    #[test]
    fn test_sum() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            "my_sum",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;

        // sum should return NULL when given no columns (contrast with count below)
        let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: Option<i64> = db.one_column(no_result)?;
        assert!(result.is_none());

        let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.one_column(single_sum)?;
        assert_eq!(4, result);

        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
                        2, 1)";
        let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?;
        assert_eq!((4, 2), result);
        Ok(())
    }

    #[test]
    fn test_count() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            "my_count",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Count,
        )?;

        // count should return 0 when given no columns (contrast with sum above)
        let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: i64 = db.one_column(no_result)?;
        assert_eq!(result, 0);

        let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.one_column(single_sum)?;
        assert_eq!(2, result);
        Ok(())
    }

    #[cfg(feature = "window")]
    impl WindowAggregate<i64, Option<i64>> for Sum {
        fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum -= ctx.get::<i64>(0)?;
            Ok(())
        }

        fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> {
            Ok(sum.copied())
        }
    }

    #[test]
    #[cfg(feature = "window")]
    fn test_window() -> Result<()> {
        use fallible_iterator::FallibleIterator;

        let db = Connection::open_in_memory()?;
        db.create_window_function(
            "sumint",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;
        db.execute_batch(
            "CREATE TABLE t3(x, y);
             INSERT INTO t3 VALUES('a', 4),
                     ('b', 5),
                     ('c', 3),
                     ('d', 8),
                     ('e', 1);",
        )?;

        let mut stmt = db.prepare(
            "SELECT x, sumint(y) OVER (
                   ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
                 ) AS sum_y
                 FROM t3 ORDER BY x;",
        )?;

        let results: Vec<(String, i64)> = stmt
            .query([])?
            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
            .collect()?;
        let expected = vec![
            ("a".to_owned(), 9),
            ("b".to_owned(), 12),
            ("c".to_owned(), 16),
            ("d".to_owned(), 12),
            ("e".to_owned(), 9),
        ];
        assert_eq!(expected, results);
        Ok(())
    }
}
1084