Skip to content

Commit d137ee5

Browse files
committed
Reapply drop PG connection when client is dropped
1 parent ba55965 commit d137ee5

File tree

2 files changed

+47
-30
lines changed

2 files changed

+47
-30
lines changed

src/postgres-util/src/tunnel.rs

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@ use std::time::Duration;
1515
use mz_ore::future::{InTask, OreFutureExt};
1616
use mz_ore::netio::DUMMY_DNS_PORT;
1717
use mz_ore::option::OptionExt;
18-
use mz_ore::task;
18+
use mz_ore::task::{self, AbortOnDropHandle};
1919
use mz_repr::CatalogItemId;
2020
use mz_ssh_util::tunnel::{SshTimeoutConfig, SshTunnelConfig};
2121
use mz_ssh_util::tunnel_manager::SshTunnelManager;
22-
use tokio::io::{AsyncRead, AsyncWrite};
2322
use tokio::net::TcpStream as TokioTcpStream;
2423
use tokio_postgres::config::{Host, ReplicationMode};
2524
use tokio_postgres::tls::MakeTlsConnect;
@@ -66,26 +65,13 @@ pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::ZERO;
6665
pub struct Client {
6766
inner: tokio_postgres::Client,
6867
server_version: Option<String>,
68+
// Holds a handle to the task with the connection to ensure that when
69+
// the client is dropped, the task can be aborted to close the connection.
70+
// This is also useful for maintaining the lifetimes of dependent object (e.g. ssh tunnel).
71+
_connection_handle: AbortOnDropHandle<()>,
6972
}
7073

7174
impl Client {
72-
fn new<S, T>(
73-
client: tokio_postgres::Client,
74-
connection: &tokio_postgres::Connection<S, T>,
75-
) -> Client
76-
where
77-
S: AsyncRead + AsyncWrite + Unpin,
78-
T: AsyncRead + AsyncWrite + Unpin,
79-
{
80-
let server_version = connection
81-
.parameter("server_version")
82-
.map(|v| v.to_string());
83-
Client {
84-
inner: client,
85-
server_version,
86-
}
87-
}
88-
8975
/// Reports the value of the `server_version` parameter reported by the
9076
/// server.
9177
pub fn server_version(&self) -> Option<&str> {
@@ -247,8 +233,19 @@ impl Config {
247233
let (client, connection) = async move { postgres_config.connect(tls).await }
248234
.run_in_task_if(self.in_task, || "pg_connect".to_string())
249235
.await?;
250-
let client = Client::new(client, &connection);
251-
task::spawn(|| task_name, connection);
236+
237+
let client = Client {
238+
inner: client,
239+
server_version: connection
240+
.parameter("server_version")
241+
.map(|v| v.to_string()),
242+
_connection_handle: task::spawn(|| task_name, async {
243+
if let Err(e) = connection.await {
244+
warn!("postgres direct connection failed: {e}");
245+
}
246+
})
247+
.abort_on_drop(),
248+
};
252249
Ok(client)
253250
}
254251
TunnelConfig::Ssh { config } => {
@@ -279,14 +276,20 @@ impl Config {
279276
async move { postgres_config.connect_raw(tcp_stream, tls).await }
280277
.run_in_task_if(self.in_task, || "pg_connect".to_string())
281278
.await?;
282-
let client = Client::new(client, &connection);
283-
task::spawn(|| task_name, async {
284-
let _tunnel = tunnel; // Keep SSH tunnel alive for duration of connection.
285279

286-
if let Err(e) = connection.await {
287-
warn!("postgres connection failed: {e}");
288-
}
289-
});
280+
let client = Client {
281+
inner: client,
282+
server_version: connection
283+
.parameter("server_version")
284+
.map(|v| v.to_string()),
285+
_connection_handle: task::spawn(|| task_name, async {
286+
let _tunnel = tunnel; // Keep SSH tunnel alive for duration of connection.
287+
if let Err(e) = connection.await {
288+
warn!("postgres via SSH tunnel connection failed: {e}");
289+
}
290+
})
291+
.abort_on_drop(),
292+
};
290293
Ok(client)
291294
}
292295
TunnelConfig::AwsPrivatelink { connection_id } => {
@@ -322,8 +325,19 @@ impl Config {
322325
let (client, connection) = async move { postgres_config.connect(tls).await }
323326
.run_in_task_if(self.in_task, || "pg_connect".to_string())
324327
.await?;
325-
let client = Client::new(client, &connection);
326-
task::spawn(|| task_name, connection);
328+
329+
let client = Client {
330+
inner: client,
331+
server_version: connection
332+
.parameter("server_version")
333+
.map(|v| v.to_string()),
334+
_connection_handle: task::spawn(|| task_name, async {
335+
if let Err(e) = connection.await {
336+
warn!("postgres AWS link connection failed: {e}");
337+
}
338+
})
339+
.abort_on_drop(),
340+
};
327341
Ok(client)
328342
}
329343
}

src/storage/src/source/postgres/replication.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,9 @@ async fn raw_stream<'a>(
777777
// Ensure we don't pre-drop the task
778778
let _max_lsn_task_handle = max_lsn_task_handle;
779779

780+
// ensure we don't drop the replication client!
781+
let _replication_client = replication_client;
782+
780783
let mut uppers = pin!(uppers);
781784
let mut last_committed_upper = resume_lsn;
782785

0 commit comments

Comments
 (0)