Skip to content

Commit 5784733

Browse files
committed
fix(pool): spawn task for before_acquire
1 parent cb4c975 commit 5784733

File tree

1 file changed

+66
-51
lines changed

1 file changed

+66
-51
lines changed

sqlx-core/src/pool/inner.rs

+66-51
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ use std::task::ready;
1313
use crate::logger::private_level_filter_to_trace_level;
1414
use crate::pool::connect::{ConnectPermit, ConnectionCounter, DynConnector};
1515
use crate::pool::idle::IdleQueue;
16-
use crate::private_tracing_dynamic_event;
16+
use crate::rt::JoinHandle;
17+
use crate::{private_tracing_dynamic_event, rt};
18+
use either::Either;
1719
use futures_util::future::{self, OptionFuture};
1820
use futures_util::FutureExt;
1921
use std::time::{Duration, Instant};
@@ -116,7 +118,7 @@ impl<DB: Database> PoolInner<DB> {
116118
let mut close_event = pin!(self.close_event());
117119
let mut deadline = pin!(crate::rt::sleep(self.options.acquire_timeout));
118120
let mut acquire_idle = pin!(self.idle.acquire(self).fuse());
119-
let mut check_idle = pin!(OptionFuture::from(None));
121+
let mut before_acquire = OptionFuture::from(None);
120122
let mut acquire_connect_permit = pin!(OptionFuture::from(Some(
121123
self.counter.acquire_permit(self).fuse()
122124
)));
@@ -144,21 +146,25 @@ impl<DB: Database> PoolInner<DB> {
144146

145147
// Attempt to acquire a connection from the idle queue.
146148
if let Ready(idle) = acquire_idle.poll_unpin(cx) {
147-
check_idle.set(Some(check_idle_conn(idle, &self.options)).into());
149+
// If we acquired an idle connection, run any checks that need to be done.
150+
//
151+
// Includes `test_on_acquire` and the `before_acquire` callback, if set.
152+
match finish_acquire(idle) {
153+
// There are checks needed to be done, so they're spawned as a task
154+
// to be cancellation-safe.
155+
Either::Left(check_task) => {
156+
before_acquire = Some(check_task).into();
157+
}
158+
// The connection is ready to go.
159+
Either::Right(conn) => {
160+
return Ready(Ok(conn));
161+
}
162+
}
148163
}
149164

150-
// If we acquired an idle connection, run any checks that need to be done.
151-
//
152-
// Includes `test_on_acquire` and the `before_acquire` callback, if set.
153-
//
154-
// We don't want to race this step if it's already running because canceling it
155-
// will result in the potentially unnecessary closure of a connection.
156-
//
157-
// Instead, we just wait and see what happens. If we already started connecting,
158-
// that'll happen concurrently.
159-
match ready!(check_idle.poll_unpin(cx)) {
160-
// The `.reattach()` call errors with "type annotations needed" if not qualified.
161-
Some(Ok(live)) => return Ready(Ok(Floating::reattach(live))),
165+
// Poll the task returned by `finish_acquire`
166+
match ready!(before_acquire.poll_unpin(cx)) {
167+
Some(Ok(conn)) => return Ready(Ok(conn)),
162168
Some(Err(permit)) => {
163169
// We don't strictly need to poll `connect` here; all we really want to do
164170
// is to check if it is `None`. But since currently there's no getter for that,
@@ -178,7 +184,7 @@ impl<DB: Database> PoolInner<DB> {
178184
// Attempt to acquire another idle connection concurrently to opening a new one.
179185
acquire_idle.set(self.idle.acquire(self).fuse());
180186
// Annoyingly, `OptionFuture` doesn't fuse to `None` on its own
181-
check_idle.set(None.into());
187+
before_acquire = None.into();
182188
}
183189
None => (),
184190
}
@@ -289,42 +295,51 @@ fn is_beyond_idle_timeout<DB: Database>(idle: &Idle<DB>, options: &PoolOptions<D
289295
.map_or(false, |timeout| idle.idle_since.elapsed() > timeout)
290296
}
291297

292-
async fn check_idle_conn<DB: Database>(
293-
mut conn: Floating<DB, Idle<DB>>,
294-
options: &PoolOptions<DB>,
295-
) -> Result<Floating<DB, Live<DB>>, ConnectPermit<DB>> {
296-
if options.test_before_acquire {
297-
// Check that the connection is still live
298-
if let Err(error) = conn.ping().await {
299-
// an error here means the other end has hung up or we lost connectivity
300-
// either way we're fine to just discard the connection
301-
// the error itself here isn't necessarily unexpected so WARN is too strong
302-
tracing::info!(%error, "ping on idle connection returned error");
303-
// connection is broken so don't try to close nicely
304-
return Err(conn.close_hard().await);
305-
}
306-
}
307-
308-
if let Some(test) = &options.before_acquire {
309-
let meta = conn.metadata();
310-
match test(&mut conn.live.raw, meta).await {
311-
Ok(false) => {
312-
// connection was rejected by user-defined hook, close nicely
313-
return Err(conn.close().await);
314-
}
315-
316-
Err(error) => {
317-
tracing::warn!(%error, "error from `before_acquire`");
298+
/// Execute `test_before_acquire` and/or `before_acquire` in a background task, if applicable.
299+
///
300+
/// Otherwise, immediately returns the connection.
301+
fn finish_acquire<DB: Database>(
302+
mut conn: Floating<DB, Idle<DB>>
303+
) -> Either<JoinHandle<Result<PoolConnection<DB>, ConnectPermit<DB>>>, PoolConnection<DB>> {
304+
let pool = conn.permit.pool();
305+
306+
if pool.options.test_before_acquire || pool.options.before_acquire.is_some() {
307+
// Spawn a task so the call may complete even if `acquire()` is cancelled.
308+
return Either::Left(rt::spawn(async move {
309+
// Check that the connection is still live
310+
if let Err(error) = conn.ping().await {
311+
// an error here means the other end has hung up or we lost connectivity
312+
// either way we're fine to just discard the connection
313+
// the error itself here isn't necessarily unexpected so WARN is too strong
314+
tracing::info!(%error, "ping on idle connection returned error");
318315
// connection is broken so don't try to close nicely
319316
return Err(conn.close_hard().await);
320317
}
321318

322-
Ok(true) => {}
323-
}
324-
}
319+
if let Some(test) = &conn.permit.pool().options.before_acquire {
320+
let meta = conn.metadata();
321+
match test(&mut conn.inner.live.raw, meta).await {
322+
Ok(false) => {
323+
// connection was rejected by user-defined hook, close nicely
324+
return Err(conn.close().await);
325+
}
325326

326-
// No need to re-connect; connection is alive or we don't care
327-
Ok(conn.into_live())
327+
Err(error) => {
328+
tracing::warn!(%error, "error from `before_acquire`");
329+
// connection is broken so don't try to close nicely
330+
return Err(conn.close_hard().await);
331+
}
332+
333+
Ok(true) => {}
334+
}
335+
}
336+
337+
Ok(conn.into_live().reattach())
338+
}));
339+
}
340+
341+
// No checks are configured, return immediately.
342+
Either::Right(conn.into_live().reattach())
328343
}
329344

330345
fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
@@ -339,7 +354,7 @@ fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
339354

340355
(None, None) => {
341356
if pool.options.min_connections > 0 {
342-
crate::rt::spawn(async move {
357+
rt::spawn(async move {
343358
if let Some(pool) = pool_weak.upgrade() {
344359
pool.min_connections_maintenance(None).await;
345360
}
@@ -353,7 +368,7 @@ fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
353368
// Immediately cancel this task if the pool is closed.
354369
let mut close_event = pool.close_event();
355370

356-
crate::rt::spawn(async move {
371+
rt::spawn(async move {
357372
let _ = close_event
358373
.do_until(async {
359374
// If the last handle to the pool was dropped while we were sleeping
@@ -386,10 +401,10 @@ fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
386401

387402
if let Some(duration) = next_run.checked_duration_since(Instant::now()) {
388403
// `async-std` doesn't have a `sleep_until()`
389-
crate::rt::sleep(duration).await;
404+
rt::sleep(duration).await;
390405
} else {
391406
// `next_run` is in the past, just yield.
392-
crate::rt::yield_now().await;
407+
rt::yield_now().await;
393408
}
394409
}
395410
})

0 commit comments

Comments
 (0)