Skip to content

Commit ac2d868

Browse files
feat(sqlite): add WAL hook support
1 parent 69bb595 commit ac2d868

File tree

3 files changed

+123
-1
lines changed

3 files changed

+123
-1
lines changed

sqlx-sqlite/src/connection/establish.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ impl EstablishParams {
188188
preupdate_hook_callback: None,
189189
commit_hook_callback: None,
190190
rollback_hook_callback: None,
191+
wal_hook_callback: None,
191192
})
192193
}
193194

sqlx-sqlite/src/connection/mod.rs

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use std::ptr::NonNull;
1111
use futures_intrusive::sync::MutexGuard;
1212
use libsqlite3_sys::{
1313
sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook,
14-
sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
14+
sqlite3_update_hook, sqlite3_wal_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_OK, SQLITE_UPDATE,
1515
};
1616
#[cfg(feature = "preupdate-hook")]
1717
pub use preupdate_hook::*;
@@ -96,6 +96,11 @@ pub struct UpdateHookResult<'a> {
9696
pub rowid: i64,
9797
}
9898

99+
pub struct WalHookResult<'a> {
100+
pub database: &'a str,
101+
pub page_count: i32,
102+
}
103+
99104
pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
100105
unsafe impl Send for UpdateHookHandler {}
101106

@@ -105,6 +110,9 @@ unsafe impl Send for CommitHookHandler {}
105110
pub(crate) struct RollbackHookHandler(NonNull<dyn FnMut() + Send + 'static>);
106111
unsafe impl Send for RollbackHookHandler {}
107112

113+
pub(crate) struct WalHookHandler(NonNull<dyn FnMut(WalHookResult) + Send + 'static>);
114+
unsafe impl Send for WalHookHandler {}
115+
108116
pub(crate) struct ConnectionState {
109117
pub(crate) handle: ConnectionHandle,
110118

@@ -123,6 +131,8 @@ pub(crate) struct ConnectionState {
123131
commit_hook_callback: Option<CommitHookHandler>,
124132

125133
rollback_hook_callback: Option<RollbackHookHandler>,
134+
135+
wal_hook_callback: Option<WalHookHandler>,
126136
}
127137

128138
impl ConnectionState {
@@ -172,6 +182,15 @@ impl ConnectionState {
172182
}
173183
}
174184
}
185+
186+
pub(crate) fn remove_wal_hook(&mut self) {
187+
if let Some(mut handler) = self.wal_hook_callback.take() {
188+
unsafe {
189+
sqlite3_wal_hook(self.handle.as_ptr(), None, ptr::null_mut());
190+
let _ = { Box::from_raw(handler.0.as_mut()) };
191+
}
192+
}
193+
}
175194
}
176195

177196
pub(crate) struct Statements {
@@ -353,6 +372,28 @@ where
353372
}
354373
}
355374

375+
extern "C" fn wal_hook<F>(
376+
callback: *mut c_void,
377+
_db: *mut sqlite3,
378+
database: *const c_char,
379+
page_count: c_int,
380+
) -> c_int
381+
where
382+
F: FnMut(WalHookResult) + Send + 'static,
383+
{
384+
unsafe {
385+
let _ = catch_unwind(|| {
386+
let callback: *mut F = callback.cast::<F>();
387+
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
388+
(*callback)(WalHookResult {
389+
database,
390+
page_count,
391+
})
392+
});
393+
}
394+
SQLITE_OK
395+
}
396+
356397
impl LockedSqliteHandle<'_> {
357398
/// Returns the underlying sqlite3* connection handle.
358399
///
@@ -520,6 +561,26 @@ impl LockedSqliteHandle<'_> {
520561
}
521562
}
522563

564+
/// Sets a WAL hook that is invoked whenever a commit occurs in WAL mode. Only a single WAL hook may be
565+
/// defined at one time per database connection; setting a new WAL hook overrides the old one.
566+
///
567+
/// Note that sqlite3_wal_autocheckpoint() and the wal_autocheckpoint pragma overwrite the WAL hook.
568+
pub fn set_wal_hook<F>(&mut self, callback: F)
569+
where
570+
F: FnMut(WalHookResult) + Send + 'static,
571+
{
572+
unsafe {
573+
let callback_boxed = Box::new(callback);
574+
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
575+
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
576+
let handler = callback.as_ptr() as *mut _;
577+
self.guard.remove_wal_hook();
578+
self.guard.wal_hook_callback = Some(WalHookHandler(callback));
579+
580+
sqlite3_wal_hook(self.as_raw_handle().as_mut(), Some(wal_hook::<F>), handler);
581+
}
582+
}
583+
523584
/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
524585
pub fn remove_progress_handler(&mut self) {
525586
self.guard.remove_progress_handler();
@@ -542,6 +603,10 @@ impl LockedSqliteHandle<'_> {
542603
self.guard.remove_rollback_hook();
543604
}
544605

606+
pub fn remove_wal_hook(&mut self) {
607+
self.guard.remove_wal_hook();
608+
}
609+
545610
pub fn last_error(&mut self) -> Option<SqliteError> {
546611
self.guard.handle.last_error()
547612
}

tests/sqlite/sqlite.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,62 @@ async fn test_multiple_set_rollback_hook_calls_drop_old_handler() -> anyhow::Res
978978
Ok(())
979979
}
980980

981+
#[sqlx_macros::test]
982+
async fn test_query_with_wal_hook() -> anyhow::Result<()> {
983+
let mut conn = new::<Sqlite>().await?;
984+
conn.execute("PRAGMA journal_mode = WAL").await?;
985+
986+
// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
987+
let state = "test".to_string();
988+
static CALLED: AtomicBool = AtomicBool::new(false);
989+
conn.lock_handle().await?.set_wal_hook(move |_| {
990+
assert_eq!(state, "test");
991+
CALLED.store(true, Ordering::Relaxed);
992+
});
993+
994+
let mut tx = conn.begin().await?;
995+
sqlx::query("CREATE TABLE test (id INTEGER PRIMARY KEY)")
996+
.execute(&mut *tx)
997+
.await?;
998+
assert!(!CALLED.load(Ordering::Relaxed));
999+
tx.commit().await?;
1000+
assert!(CALLED.load(Ordering::Relaxed));
1001+
Ok(())
1002+
}
1003+
1004+
#[sqlx_macros::test]
1005+
async fn test_multiple_set_wal_hook_calls_drop_old_handler() -> anyhow::Result<()> {
1006+
let ref_counted_object = Arc::new(0);
1007+
assert_eq!(1, Arc::strong_count(&ref_counted_object));
1008+
1009+
{
1010+
let mut conn = new::<Sqlite>().await?;
1011+
1012+
let o = ref_counted_object.clone();
1013+
conn.lock_handle().await?.set_wal_hook(move |_| {
1014+
println!("{o:?}");
1015+
});
1016+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
1017+
1018+
let o = ref_counted_object.clone();
1019+
conn.lock_handle().await?.set_wal_hook(move |_| {
1020+
println!("{o:?}");
1021+
});
1022+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
1023+
1024+
let o = ref_counted_object.clone();
1025+
conn.lock_handle().await?.set_wal_hook(move |_| {
1026+
println!("{o:?}");
1027+
});
1028+
assert_eq!(2, Arc::strong_count(&ref_counted_object));
1029+
1030+
conn.lock_handle().await?.remove_wal_hook();
1031+
}
1032+
1033+
assert_eq!(1, Arc::strong_count(&ref_counted_object));
1034+
Ok(())
1035+
}
1036+
9811037
#[sqlx_macros::test]
9821038
async fn issue_3150() {
9831039
// Same bounds as `tokio::spawn()`

0 commit comments

Comments
 (0)