Skip to content

Commit 22d4725

Browse files
committed
Add interrupt handle
Exposes the `duckdb_interrupt` function from the C API to allow consumers to interrupt long-running queries from another thread. Inspired by rusqlite: https://docs.rs/rusqlite/latest/rusqlite/struct.InterruptHandle.html
1 parent ef1432f commit 22d4725

File tree

2 files changed

+94
-1
lines changed

2 files changed

+94
-1
lines changed

crates/duckdb/src/inner_connection.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::{
33
mem,
44
os::raw::c_char,
55
ptr, str,
6+
sync::{Arc, Mutex}
67
};
78

89
use super::{ffi, Appender, Config, Connection, Result};
@@ -15,6 +16,7 @@ use crate::{
1516
pub struct InnerConnection {
1617
pub db: ffi::duckdb_database,
1718
pub con: ffi::duckdb_connection,
19+
interrupt: Arc<InterruptHandle>,
1820
owned: bool,
1921
}
2022

@@ -30,7 +32,9 @@ impl InnerConnection {
3032
Some("connect error".to_owned()),
3133
));
3234
}
33-
Ok(InnerConnection { db, con, owned })
35+
let interrupt = Arc::new(InterruptHandle::new(con));
36+
37+
Ok(InnerConnection { db, con, interrupt, owned })
3438
}
3539

3640
pub fn open_with_flags(c_path: &CStr, config: Config) -> Result<InnerConnection> {
@@ -57,6 +61,7 @@ impl InnerConnection {
5761
unsafe {
5862
ffi::duckdb_disconnect(&mut self.con);
5963
self.con = ptr::null_mut();
64+
self.interrupt.clear();
6065

6166
if self.owned {
6267
ffi::duckdb_close(&mut self.db);
@@ -106,6 +111,11 @@ impl InnerConnection {
106111
Ok(Appender::new(conn, c_app))
107112
}
108113

114+
#[inline]
115+
pub fn get_interrupt_handle(&self) -> Arc<InterruptHandle> {
116+
self.interrupt.clone()
117+
}
118+
109119
#[inline]
110120
pub fn is_autocommit(&self) -> bool {
111121
true
@@ -126,3 +136,34 @@ impl Drop for InnerConnection {
126136
}
127137
}
128138
}
139+
140+
pub struct InterruptHandle {
141+
conn: Mutex<ffi::duckdb_connection>,
142+
}
143+
144+
unsafe impl Send for InterruptHandle {}
145+
unsafe impl Sync for InterruptHandle {}
146+
147+
impl InterruptHandle {
148+
pub fn new(conn: ffi::duckdb_connection) -> Self {
149+
Self {
150+
conn: Mutex::new(conn),
151+
}
152+
}
153+
154+
pub fn clear(&self) {
155+
*(self.conn.lock().unwrap()) = ptr::null_mut();
156+
}
157+
158+
/// Interrupt any query currently executing on another thread. This will
159+
/// cause that query to fail with an `Error::DuckDBFailure` error.
160+
pub fn interrupt(&self) {
161+
let db_handle = self.conn.lock().unwrap();
162+
163+
if !db_handle.is_null() {
164+
unsafe {
165+
ffi::duckdb_interrupt(*db_handle);
166+
}
167+
}
168+
}
169+
}

crates/duckdb/src/lib.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,30 @@ impl Connection {
532532
self.db.borrow_mut().appender(self, table, schema)
533533
}
534534

535+
/// Get a handle to interrupt long-running queries.
536+
///
537+
/// ## Example
538+
///
539+
/// ```rust,no_run
540+
/// # use duckdb::{Connection, Result};
541+
/// fn run_query(conn: Connection) -> Result<()> {
542+
/// let interrupt = conn.interrupt_handle();
543+
/// let handle = std::thread::spawn(move || { conn.execute("expensive query", []) });
544+
///
545+
/// // Arbitrary wait for query to start
546+
/// std::thread::sleep(std::time::Duration::from_millis(100));
547+
///
548+
/// interrupt.interrupt();
549+
///
550+
/// let query_result = handle.join().unwrap();
551+
/// assert!(query_result.is_err());
552+
///
553+
/// Ok(())
554+
/// }
555+
pub fn interrupt_handle(&self) -> std::sync::Arc<inner_connection::InterruptHandle> {
556+
self.db.borrow().get_interrupt_handle()
557+
}
558+
535559
/// Close the DuckDB connection.
536560
///
537561
/// This is functionally equivalent to the `Drop` implementation for
@@ -1337,6 +1361,34 @@ mod test {
13371361
Ok(())
13381362
}
13391363

1364+
#[test]
1365+
fn test_interrupt() -> Result<()> {
1366+
let db = checked_memory_handle();
1367+
let db_interrupt = db.interrupt_handle();
1368+
1369+
let (tx, rx) = std::sync::mpsc::channel();
1370+
std::thread::spawn(move || {
1371+
let mut stmt = db.prepare("select count(*) from range(10000000) t1, range(1000000) t2").unwrap();
1372+
tx.send(stmt.execute([])).unwrap();
1373+
});
1374+
1375+
std::thread::sleep(std::time::Duration::from_millis(100));
1376+
db_interrupt.interrupt();
1377+
1378+
let result = rx.recv_timeout(std::time::Duration::from_secs(5)).unwrap();
1379+
assert!(result.is_err_and(|err| err.to_string().contains("INTERRUPT")));
1380+
Ok(())
1381+
}
1382+
1383+
#[test]
1384+
fn test_interrupt_on_dropped_db() {
1385+
let db = checked_memory_handle();
1386+
let db_interrupt = db.interrupt_handle();
1387+
1388+
drop(db);
1389+
db_interrupt.interrupt();
1390+
}
1391+
13401392
#[cfg(feature = "bundled")]
13411393
#[test]
13421394
fn test_version() -> Result<()> {

0 commit comments

Comments
 (0)