Skip to content

Commit 974a2c9

Browse files
committed
fix(pool): spawn task for before_acquire
1 parent cb4c975 commit 974a2c9

File tree

1 file changed

+63
-49
lines changed

1 file changed

+63
-49
lines changed

sqlx-core/src/pool/inner.rs

+63-49
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ use std::task::ready;
1313
use crate::logger::private_level_filter_to_trace_level;
1414
use crate::pool::connect::{ConnectPermit, ConnectionCounter, DynConnector};
1515
use crate::pool::idle::IdleQueue;
16-
use crate::private_tracing_dynamic_event;
16+
use crate::rt::JoinHandle;
17+
use crate::{private_tracing_dynamic_event, rt};
18+
use either::Either;
1719
use futures_util::future::{self, OptionFuture};
1820
use futures_util::FutureExt;
1921
use std::time::{Duration, Instant};
@@ -116,7 +118,7 @@ impl<DB: Database> PoolInner<DB> {
116118
let mut close_event = pin!(self.close_event());
117119
let mut deadline = pin!(crate::rt::sleep(self.options.acquire_timeout));
118120
let mut acquire_idle = pin!(self.idle.acquire(self).fuse());
119-
let mut check_idle = pin!(OptionFuture::from(None));
121+
let mut before_acquire = OptionFuture::from(None);
120122
let mut acquire_connect_permit = pin!(OptionFuture::from(Some(
121123
self.counter.acquire_permit(self).fuse()
122124
)));
@@ -144,21 +146,26 @@ impl<DB: Database> PoolInner<DB> {
144146

145147
// Attempt to acquire a connection from the idle queue.
146148
if let Ready(idle) = acquire_idle.poll_unpin(cx) {
147-
check_idle.set(Some(check_idle_conn(idle, &self.options)).into());
149+
// If we acquired an idle connection, run any checks that need to be done.
150+
//
151+
// Includes `test_on_acquire` and the `before_acquire` callback, if set.
152+
match finish_acquire(idle) {
153+
// There are checks needed to be done, so they're spawned as a task
154+
// to be cancellation-safe.
155+
Either::Left(check_task) => {
156+
before_acquire = Some(check_task).into();
157+
}
158+
// The connection is ready to go.
159+
Either::Right(conn) => {
160+
return Ready(Ok(conn));
161+
}
162+
}
148163
}
149164

150-
// If we acquired an idle connection, run any checks that need to be done.
151-
//
152-
// Includes `test_on_acquire` and the `before_acquire` callback, if set.
153-
//
154-
// We don't want to race this step if it's already running because canceling it
155-
// will result in the potentially unnecessary closure of a connection.
156-
//
157-
// Instead, we just wait and see what happens. If we already started connecting,
158-
// that'll happen concurrently.
159-
match ready!(check_idle.poll_unpin(cx)) {
165+
// Poll the task returned by `finish_acquire`
166+
match ready!(before_acquire.poll_unpin(cx)) {
160167
// The `.reattach()` call errors with "type annotations needed" if not qualified.
161-
Some(Ok(live)) => return Ready(Ok(Floating::reattach(live))),
168+
Some(Ok(conn)) => return Ready(Ok(conn)),
162169
Some(Err(permit)) => {
163170
// We don't strictly need to poll `connect` here; all we really want to do
164171
// is to check if it is `None`. But since currently there's no getter for that,
@@ -178,7 +185,7 @@ impl<DB: Database> PoolInner<DB> {
178185
// Attempt to acquire another idle connection concurrently to opening a new one.
179186
acquire_idle.set(self.idle.acquire(self).fuse());
180187
// Annoyingly, `OptionFuture` doesn't fuse to `None` on its own
181-
check_idle.set(None.into());
188+
before_acquire = None.into();
182189
}
183190
None => (),
184191
}
@@ -289,42 +296,49 @@ fn is_beyond_idle_timeout<DB: Database>(idle: &Idle<DB>, options: &PoolOptions<D
289296
.map_or(false, |timeout| idle.idle_since.elapsed() > timeout)
290297
}
291298

292-
async fn check_idle_conn<DB: Database>(
293-
mut conn: Floating<DB, Idle<DB>>,
294-
options: &PoolOptions<DB>,
295-
) -> Result<Floating<DB, Live<DB>>, ConnectPermit<DB>> {
296-
if options.test_before_acquire {
297-
// Check that the connection is still live
298-
if let Err(error) = conn.ping().await {
299-
// an error here means the other end has hung up or we lost connectivity
300-
// either way we're fine to just discard the connection
301-
// the error itself here isn't necessarily unexpected so WARN is too strong
302-
tracing::info!(%error, "ping on idle connection returned error");
303-
// connection is broken so don't try to close nicely
304-
return Err(conn.close_hard().await);
305-
}
306-
}
307-
308-
if let Some(test) = &options.before_acquire {
309-
let meta = conn.metadata();
310-
match test(&mut conn.live.raw, meta).await {
311-
Ok(false) => {
312-
// connection was rejected by user-defined hook, close nicely
313-
return Err(conn.close().await);
314-
}
315-
316-
Err(error) => {
317-
tracing::warn!(%error, "error from `before_acquire`");
299+
/// Execute `test_before_acquire` and/or
300+
fn finish_acquire<DB: Database>(
301+
mut conn: Floating<DB, Idle<DB>>
302+
) -> Either<JoinHandle<Result<PoolConnection<DB>, ConnectPermit<DB>>>, PoolConnection<DB>> {
303+
let pool = conn.permit.pool();
304+
305+
if pool.options.test_before_acquire || pool.options.before_acquire.is_some() {
306+
// Spawn a task so the call may complete even if `acquire()` is cancelled.
307+
return Either::Left(rt::spawn(async move {
308+
// Check that the connection is still live
309+
if let Err(error) = conn.ping().await {
310+
// an error here means the other end has hung up or we lost connectivity
311+
// either way we're fine to just discard the connection
312+
// the error itself here isn't necessarily unexpected so WARN is too strong
313+
tracing::info!(%error, "ping on idle connection returned error");
318314
// connection is broken so don't try to close nicely
319315
return Err(conn.close_hard().await);
320316
}
321317

322-
Ok(true) => {}
323-
}
324-
}
318+
if let Some(test) = &conn.permit.pool().options.before_acquire {
319+
let meta = conn.metadata();
320+
match test(&mut conn.inner.live.raw, meta).await {
321+
Ok(false) => {
322+
// connection was rejected by user-defined hook, close nicely
323+
return Err(conn.close().await);
324+
}
325325

326-
// No need to re-connect; connection is alive or we don't care
327-
Ok(conn.into_live())
326+
Err(error) => {
327+
tracing::warn!(%error, "error from `before_acquire`");
328+
// connection is broken so don't try to close nicely
329+
return Err(conn.close_hard().await);
330+
}
331+
332+
Ok(true) => {}
333+
}
334+
}
335+
336+
Ok(conn.into_live().reattach())
337+
}));
338+
}
339+
340+
// No checks are configured, return immediately.
341+
Either::Right(conn.into_live().reattach())
328342
}
329343

330344
fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
@@ -353,7 +367,7 @@ fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
353367
// Immediately cancel this task if the pool is closed.
354368
let mut close_event = pool.close_event();
355369

356-
crate::rt::spawn(async move {
370+
rt::spawn(async move {
357371
let _ = close_event
358372
.do_until(async {
359373
// If the last handle to the pool was dropped while we were sleeping
@@ -386,10 +400,10 @@ fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
386400

387401
if let Some(duration) = next_run.checked_duration_since(Instant::now()) {
388402
// `async-std` doesn't have a `sleep_until()`
389-
crate::rt::sleep(duration).await;
403+
rt::sleep(duration).await;
390404
} else {
391405
// `next_run` is in the past, just yield.
392-
crate::rt::yield_now().await;
406+
rt::yield_now().await;
393407
}
394408
}
395409
})

0 commit comments

Comments
 (0)