diff --git a/crates/duckdb/src/inner_connection.rs b/crates/duckdb/src/inner_connection.rs index f9d50c1c..4cf40bb9 100644 --- a/crates/duckdb/src/inner_connection.rs +++ b/crates/duckdb/src/inner_connection.rs @@ -3,6 +3,7 @@ use std::{ mem, os::raw::c_char, ptr, str, + sync::{Arc, Mutex}, }; use super::{ffi, Appender, Config, Connection, Result}; @@ -15,6 +16,7 @@ use crate::{ pub struct InnerConnection { pub db: ffi::duckdb_database, pub con: ffi::duckdb_connection, + interrupt: Arc, owned: bool, } @@ -30,7 +32,14 @@ impl InnerConnection { Some("connect error".to_owned()), )); } - Ok(InnerConnection { db, con, owned }) + let interrupt = Arc::new(InterruptHandle::new(con)); + + Ok(InnerConnection { + db, + con, + interrupt, + owned, + }) } pub fn open_with_flags(c_path: &CStr, config: Config) -> Result { @@ -57,6 +66,7 @@ impl InnerConnection { unsafe { ffi::duckdb_disconnect(&mut self.con); self.con = ptr::null_mut(); + self.interrupt.clear(); if self.owned { ffi::duckdb_close(&mut self.db); @@ -106,6 +116,10 @@ impl InnerConnection { Ok(Appender::new(conn, c_app)) } + pub fn get_interrupt_handle(&self) -> Arc { + self.interrupt.clone() + } + #[inline] pub fn is_autocommit(&self) -> bool { true @@ -126,3 +140,37 @@ impl Drop for InnerConnection { } } } + +/// A handle that allows interrupting long-running queries. +pub struct InterruptHandle { + conn: Mutex, +} + +unsafe impl Send for InterruptHandle {} +unsafe impl Sync for InterruptHandle {} + +impl InterruptHandle { + fn new(conn: ffi::duckdb_connection) -> Self { + Self { conn: Mutex::new(conn) } + } + + fn clear(&self) { + *(self.conn.lock().unwrap()) = ptr::null_mut(); + } + + /// Interrupt the query currently running on the connection this handle was + /// obtained from. The interrupt will cause that query to fail with + /// `Error::DuckDBFailure`. If the connection was dropped after obtaining + /// this interrupt handle, calling this method results in a noop. + /// + /// See [`crate::Connection::interrupt_handle`] for an example. + pub fn interrupt(&self) { + let db_handle = self.conn.lock().unwrap(); + + if !db_handle.is_null() { + unsafe { + ffi::duckdb_interrupt(*db_handle); + } + } + } +} diff --git a/crates/duckdb/src/lib.rs b/crates/duckdb/src/lib.rs index b7d796bd..46c5e770 100644 --- a/crates/duckdb/src/lib.rs +++ b/crates/duckdb/src/lib.rs @@ -79,6 +79,7 @@ pub use crate::{ config::{AccessMode, Config, DefaultNullOrder, DefaultOrder}, error::Error, ffi::ErrorCode, + inner_connection::InterruptHandle, params::{params_from_iter, Params, ParamsFromIter}, row::{AndThenRows, Map, MappedRows, Row, RowIndex, Rows}, statement::Statement, @@ -532,6 +533,30 @@ impl Connection { self.db.borrow_mut().appender(self, table, schema) } + /// Get a handle to interrupt long-running queries. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use duckdb::{Connection, Result}; + /// fn run_query(conn: Connection) -> Result<()> { + /// let interrupt_handle = conn.interrupt_handle(); + /// let join_handle = std::thread::spawn(move || { conn.execute("expensive query", []) }); + /// + /// // Arbitrary wait for query to start + /// std::thread::sleep(std::time::Duration::from_millis(100)); + /// + /// interrupt_handle.interrupt(); + /// + /// let query_result = join_handle.join().unwrap(); + /// assert!(query_result.is_err()); + /// + /// Ok(()) + /// } + pub fn interrupt_handle(&self) -> std::sync::Arc { + self.db.borrow().get_interrupt_handle() + } + /// Close the DuckDB connection. /// /// This is functionally equivalent to the `Drop` implementation for @@ -1337,6 +1362,36 @@ mod test { Ok(()) } + #[test] + fn test_interrupt() -> Result<()> { + let db = checked_memory_handle(); + let db_interrupt = db.interrupt_handle(); + + let (tx, rx) = std::sync::mpsc::channel(); + std::thread::spawn(move || { + let mut stmt = db + .prepare("select count(*) from range(10000000) t1, range(1000000) t2") + .unwrap(); + tx.send(stmt.execute([])).unwrap(); + }); + + std::thread::sleep(std::time::Duration::from_millis(100)); + db_interrupt.interrupt(); + + let result = rx.recv_timeout(std::time::Duration::from_secs(5)).unwrap(); + assert!(result.is_err_and(|err| err.to_string().contains("INTERRUPT"))); + Ok(()) + } + + #[test] + fn test_interrupt_on_dropped_db() { + let db = checked_memory_handle(); + let db_interrupt = db.interrupt_handle(); + + drop(db); + db_interrupt.interrupt(); + } + #[cfg(feature = "bundled")] #[test] fn test_version() -> Result<()> {