Skip to content

Commit e49de40

Browse files
feat(sqlite): add WAL hook support
1 parent 69bb595 commit e49de40

File tree

3 files changed

+122
-1
lines changed

3 files changed

+122
-1
lines changed

sqlx-sqlite/src/connection/establish.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ impl EstablishParams {
188188
preupdate_hook_callback: None,
189189
commit_hook_callback: None,
190190
rollback_hook_callback: None,
191+
wal_hook_callback: None,
191192
})
192193
}
193194

sqlx-sqlite/src/connection/mod.rs

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use std::ptr::NonNull;
1111
use futures_intrusive::sync::MutexGuard;
1212
use libsqlite3_sys::{
1313
sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook,
14-
sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
14+
sqlite3_update_hook, sqlite3_wal_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_OK, SQLITE_UPDATE,
1515
};
1616
#[cfg(feature = "preupdate-hook")]
1717
pub use preupdate_hook::*;
@@ -96,6 +96,11 @@ pub struct UpdateHookResult<'a> {
9696
pub rowid: i64,
9797
}
9898

99+
pub struct WalHookResult<'a> {
100+
pub database: &'a str,
101+
pub page_count: i32,
102+
}
103+
99104
pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
100105
unsafe impl Send for UpdateHookHandler {}
101106

@@ -105,6 +110,9 @@ unsafe impl Send for CommitHookHandler {}
105110
pub(crate) struct RollbackHookHandler(NonNull<dyn FnMut() + Send + 'static>);
106111
unsafe impl Send for RollbackHookHandler {}
107112

113+
pub(crate) struct WalHookHandler(NonNull<dyn FnMut(WalHookResult) + Send + 'static>);
114+
unsafe impl Send for WalHookHandler {}
115+
108116
pub(crate) struct ConnectionState {
109117
pub(crate) handle: ConnectionHandle,
110118

@@ -123,6 +131,8 @@ pub(crate) struct ConnectionState {
123131
commit_hook_callback: Option<CommitHookHandler>,
124132

125133
rollback_hook_callback: Option<RollbackHookHandler>,
134+
135+
wal_hook_callback: Option<WalHookHandler>,
126136
}
127137

128138
impl ConnectionState {
@@ -172,6 +182,15 @@ impl ConnectionState {
172182
}
173183
}
174184
}
185+
186+
pub(crate) fn remove_wal_hook(&mut self) {
187+
if let Some(mut handler) = self.wal_hook_callback.take() {
188+
unsafe {
189+
sqlite3_wal_hook(self.handle.as_ptr(), None, ptr::null_mut());
190+
let _ = { Box::from_raw(handler.0.as_mut()) };
191+
}
192+
}
193+
}
175194
}
176195

177196
pub(crate) struct Statements {
@@ -353,6 +372,28 @@ where
353372
}
354373
}
355374

375+
extern "C" fn wal_hook<F>(
376+
callback: *mut c_void,
377+
_db: *mut sqlite3,
378+
database: *const c_char,
379+
page_count: c_int,
380+
) -> c_int
381+
where
382+
F: FnMut(WalHookResult) + Send + 'static,
383+
{
384+
unsafe {
385+
let _ = catch_unwind(|| {
386+
let callback: *mut F = callback.cast::<F>();
387+
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
388+
(*callback)(WalHookResult {
389+
database,
390+
page_count,
391+
})
392+
});
393+
}
394+
SQLITE_OK
395+
}
396+
356397
impl LockedSqliteHandle<'_> {
357398
/// Returns the underlying sqlite3* connection handle.
358399
///
@@ -520,6 +561,26 @@ impl LockedSqliteHandle<'_> {
520561
}
521562
}
522563

564+
/// Sets a WAL hook that is invoked whenever a commit occurs in WAL mode. Only a single WAL hook may be
565+
/// defined at one time per database connection; setting a new WAL hook overrides the old one.
566+
///
567+
/// Note that sqlite3_wal_autocheckpoint() and the wal_autocheckpoint pragma overwrite the WAL hook.
568+
pub fn set_wal_hook<F>(&mut self, callback: F)
569+
where
570+
F: FnMut(WalHookResult) + Send + 'static,
571+
{
572+
unsafe {
573+
let callback_boxed = Box::new(callback);
574+
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
575+
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
576+
let handler = callback.as_ptr() as *mut _;
577+
self.guard.remove_wal_hook();
578+
self.guard.wal_hook_callback = Some(WalHookHandler(callback));
579+
580+
sqlite3_wal_hook(self.as_raw_handle().as_mut(), Some(wal_hook::<F>), handler);
581+
}
582+
}
583+
523584
/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
524585
pub fn remove_progress_handler(&mut self) {
525586
self.guard.remove_progress_handler();
@@ -542,6 +603,10 @@ impl LockedSqliteHandle<'_> {
542603
self.guard.remove_rollback_hook();
543604
}
544605

606+
pub fn remove_wal_hook(&mut self) {
607+
self.guard.remove_wal_hook();
608+
}
609+
545610
pub fn last_error(&mut self) -> Option<SqliteError> {
546611
self.guard.handle.last_error()
547612
}

tests/sqlite/sqlite.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,61 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res
978978
Ok(())
979979
}
980980

981+
#[sqlx_macros::test]
982+
async fn test_query_with_wal_hook() -> anyhow::Result<()> {
983+
let mut conn = new::<Sqlite>().await?;
984+
985+
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
986+
let state = "test".to_string();
987+
static CALLED: AtomicBool = AtomicBool::new(false);
988+
conn.lock_handle().await?.set_wal_hook(move |_| {
989+
assert_eq!(state, "test");
990+
CALLED.store(true, Ordering::Relaxed);
991+
});
992+
993+
let mut tx = conn.begin().await?;
994+
sqlx::query("CREATE TABLE test (id INTEGER PRIMARY KEY)")
995+
.execute(&mut *tx)
996+
.await?;
997+
assert!(!CALLED.load(Ordering::Relaxed));
998+
tx.commit().await?;
999+
assert!(CALLED.load(Ordering::Relaxed));
1000+
Ok(())
1001+
}
1002+
1003+
#[sqlx_macros::test]
1004+
async fn test_multiple_set_wal_hook_calls_drop_old_handler() -> anyhow::Result<()> {
1005+
let ref_counted_object = Arc::new(0);
1006+
assert_eq!(1, Arc::strong_count(&ref_counted_object));
1007+
1008+
{
1009+
let mut conn = new::<Sqlite>().await?;
1010+
1011+
let o = ref_counted_object.clone();
1012+
conn.lock_handle().await?.set_wal_hook(move |_| {
1013+
println!("{o:?}");
1014+
});
1015+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
1016+
1017+
let o = ref_counted_object.clone();
1018+
conn.lock_handle().await?.set_wal_hook(move |_| {
1019+
println!("{o:?}");
1020+
});
1021+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
1022+
1023+
let o = ref_counted_object.clone();
1024+
conn.lock_handle().await?.set_wal_hook(move |_| {
1025+
println!("{o:?}");
1026+
});
1027+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
1028+
1029+
conn.lock_handle().await?.remove_wal_hook();
1030+
}
1031+
1032+
assert_eq!(1, Arc::strong_count(&ref_counted_object));
1033+
Ok(())
1034+
}
1035+
9811036
#[sqlx_macros::test]
9821037
async fn issue_3150() {
9831038
// Same bounds as `tokio::spawn()`

0 commit comments

Comments
 (0)