Skip to content

Commit a5a4053

Browse files
committed
fix(pool): spawn task for before_acquire
1 parent cb4c975 commit a5a4053

File tree

1 file changed

+65
-49
lines changed

1 file changed

+65
-49
lines changed

sqlx-core/src/pool/inner.rs

+65-49
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ use std::task::ready;
1313
use crate::logger::private_level_filter_to_trace_level;
1414
use crate::pool::connect::{ConnectPermit, ConnectionCounter, DynConnector};
1515
use crate::pool::idle::IdleQueue;
16-
use crate::private_tracing_dynamic_event;
16+
use crate::rt::JoinHandle;
17+
use crate::{private_tracing_dynamic_event, rt};
18+
use either::Either;
1719
use futures_util::future::{self, OptionFuture};
1820
use futures_util::FutureExt;
1921
use std::time::{Duration, Instant};
@@ -116,7 +118,7 @@ impl<DB: Database> PoolInner<DB> {
116118
let mut close_event = pin!(self.close_event());
117119
let mut deadline = pin!(crate::rt::sleep(self.options.acquire_timeout));
118120
let mut acquire_idle = pin!(self.idle.acquire(self).fuse());
119-
let mut check_idle = pin!(OptionFuture::from(None));
121+
let mut before_acquire = OptionFuture::from(None);
120122
let mut acquire_connect_permit = pin!(OptionFuture::from(Some(
121123
self.counter.acquire_permit(self).fuse()
122124
)));
@@ -144,21 +146,26 @@ impl<DB: Database> PoolInner<DB> {
144146

145147
// Attempt to acquire a connection from the idle queue.
146148
if let Ready(idle) = acquire_idle.poll_unpin(cx) {
147-
check_idle.set(Some(check_idle_conn(idle, &self.options)).into());
149+
// If we acquired an idle connection, run any checks that need to be done.
150+
//
151+
// Includes `test_on_acquire` and the `before_acquire` callback, if set.
152+
match finish_acquire(idle) {
153+
// There are checks needed to be done, so they're spawned as a task
154+
// to be cancellation-safe.
155+
Either::Left(check_task) => {
156+
before_acquire = Some(check_task).into();
157+
}
158+
// The connection is ready to go.
159+
Either::Right(conn) => {
160+
return Ready(Ok(conn));
161+
}
162+
}
148163
}
149164

150-
// If we acquired an idle connection, run any checks that need to be done.
151-
//
152-
// Includes `test_on_acquire` and the `before_acquire` callback, if set.
153-
//
154-
// We don't want to race this step if it's already running because canceling it
155-
// will result in the potentially unnecessary closure of a connection.
156-
//
157-
// Instead, we just wait and see what happens. If we already started connecting,
158-
// that'll happen concurrently.
159-
match ready!(check_idle.poll_unpin(cx)) {
165+
// Poll the task returned by `finish_acquire`
166+
match ready!(before_acquire.poll_unpin(cx)) {
160167
// The `.reattach()` call errors with "type annotations needed" if not qualified.
161-
Some(Ok(live)) => return Ready(Ok(Floating::reattach(live))),
168+
Some(Ok(conn)) => return Ready(Ok(conn)),
162169
Some(Err(permit)) => {
163170
// We don't strictly need to poll `connect` here; all we really want to do
164171
// is to check if it is `None`. But since currently there's no getter for that,
@@ -178,7 +185,7 @@ impl<DB: Database> PoolInner<DB> {
178185
// Attempt to acquire another idle connection concurrently to opening a new one.
179186
acquire_idle.set(self.idle.acquire(self).fuse());
180187
// Annoyingly, `OptionFuture` doesn't fuse to `None` on its own
181-
check_idle.set(None.into());
188+
before_acquire = None.into();
182189
}
183190
None => (),
184191
}
@@ -289,42 +296,51 @@ fn is_beyond_idle_timeout<DB: Database>(idle: &Idle<DB>, options: &PoolOptions<D
289296
.map_or(false, |timeout| idle.idle_since.elapsed() > timeout)
290297
}
291298

292-
async fn check_idle_conn<DB: Database>(
293-
mut conn: Floating<DB, Idle<DB>>,
294-
options: &PoolOptions<DB>,
295-
) -> Result<Floating<DB, Live<DB>>, ConnectPermit<DB>> {
296-
if options.test_before_acquire {
297-
// Check that the connection is still live
298-
if let Err(error) = conn.ping().await {
299-
// an error here means the other end has hung up or we lost connectivity
300-
// either way we're fine to just discard the connection
301-
// the error itself here isn't necessarily unexpected so WARN is too strong
302-
tracing::info!(%error, "ping on idle connection returned error");
303-
// connection is broken so don't try to close nicely
304-
return Err(conn.close_hard().await);
305-
}
306-
}
307-
308-
if let Some(test) = &options.before_acquire {
309-
let meta = conn.metadata();
310-
match test(&mut conn.live.raw, meta).await {
311-
Ok(false) => {
312-
// connection was rejected by user-defined hook, close nicely
313-
return Err(conn.close().await);
314-
}
315-
316-
Err(error) => {
317-
tracing::warn!(%error, "error from `before_acquire`");
299+
/// Execute `test_before_acquire` and/or `before_acquire` in a background task, if applicable.
300+
///
301+
/// Otherwise, immediately returns the connection.
302+
fn finish_acquire<DB: Database>(
303+
mut conn: Floating<DB, Idle<DB>>
304+
) -> Either<JoinHandle<Result<PoolConnection<DB>, ConnectPermit<DB>>>, PoolConnection<DB>> {
305+
let pool = conn.permit.pool();
306+
307+
if pool.options.test_before_acquire || pool.options.before_acquire.is_some() {
308+
// Spawn a task so the call may complete even if `acquire()` is cancelled.
309+
return Either::Left(rt::spawn(async move {
310+
// Check that the connection is still live
311+
if let Err(error) = conn.ping().await {
312+
// an error here means the other end has hung up or we lost connectivity
313+
// either way we're fine to just discard the connection
314+
// the error itself here isn't necessarily unexpected so WARN is too strong
315+
tracing::info!(%error, "ping on idle connection returned error");
318316
// connection is broken so don't try to close nicely
319317
return Err(conn.close_hard().await);
320318
}
321319

322-
Ok(true) => {}
323-
}
324-
}
320+
if let Some(test) = &conn.permit.pool().options.before_acquire {
321+
let meta = conn.metadata();
322+
match test(&mut conn.inner.live.raw, meta).await {
323+
Ok(false) => {
324+
// connection was rejected by user-defined hook, close nicely
325+
return Err(conn.close().await);
326+
}
325327

326-
// No need to re-connect; connection is alive or we don't care
327-
Ok(conn.into_live())
328+
Err(error) => {
329+
tracing::warn!(%error, "error from `before_acquire`");
330+
// connection is broken so don't try to close nicely
331+
return Err(conn.close_hard().await);
332+
}
333+
334+
Ok(true) => {}
335+
}
336+
}
337+
338+
Ok(conn.into_live().reattach())
339+
}));
340+
}
341+
342+
// No checks are configured, return immediately.
343+
Either::Right(conn.into_live().reattach())
328344
}
329345

330346
fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
@@ -353,7 +369,7 @@ fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
353369
// Immediately cancel this task if the pool is closed.
354370
let mut close_event = pool.close_event();
355371

356-
crate::rt::spawn(async move {
372+
rt::spawn(async move {
357373
let _ = close_event
358374
.do_until(async {
359375
// If the last handle to the pool was dropped while we were sleeping
@@ -386,10 +402,10 @@ fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
386402

387403
if let Some(duration) = next_run.checked_duration_since(Instant::now()) {
388404
// `async-std` doesn't have a `sleep_until()`
389-
crate::rt::sleep(duration).await;
405+
rt::sleep(duration).await;
390406
} else {
391407
// `next_run` is in the past, just yield.
392-
crate::rt::yield_now().await;
408+
rt::yield_now().await;
393409
}
394410
}
395411
})

0 commit comments

Comments
 (0)