@@ -15,11 +15,10 @@ use std::time::Duration;
1515use mz_ore:: future:: { InTask , OreFutureExt } ;
1616use mz_ore:: netio:: DUMMY_DNS_PORT ;
1717use mz_ore:: option:: OptionExt ;
18- use mz_ore:: task;
18+ use mz_ore:: task:: { self , AbortOnDropHandle } ;
1919use mz_repr:: CatalogItemId ;
2020use mz_ssh_util:: tunnel:: { SshTimeoutConfig , SshTunnelConfig } ;
2121use mz_ssh_util:: tunnel_manager:: SshTunnelManager ;
22- use tokio:: io:: { AsyncRead , AsyncWrite } ;
2322use tokio:: net:: TcpStream as TokioTcpStream ;
2423use tokio_postgres:: config:: { Host , ReplicationMode } ;
2524use tokio_postgres:: tls:: MakeTlsConnect ;
@@ -66,26 +65,13 @@ pub const DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT: Duration = Duration::ZERO;
6665pub struct Client {
6766 inner : tokio_postgres:: Client ,
6867 server_version : Option < String > ,
68+ // Holds a handle to the task with the connection to ensure that when
69+ // the client is dropped, the task can be aborted to close the connection.
70+ // This is also useful for maintaining the lifetimes of dependent object (e.g. ssh tunnel).
71+ _connection_handle : AbortOnDropHandle < ( ) > ,
6972}
7073
7174impl Client {
72- fn new < S , T > (
73- client : tokio_postgres:: Client ,
74- connection : & tokio_postgres:: Connection < S , T > ,
75- ) -> Client
76- where
77- S : AsyncRead + AsyncWrite + Unpin ,
78- T : AsyncRead + AsyncWrite + Unpin ,
79- {
80- let server_version = connection
81- . parameter ( "server_version" )
82- . map ( |v| v. to_string ( ) ) ;
83- Client {
84- inner : client,
85- server_version,
86- }
87- }
88-
8975 /// Reports the value of the `server_version` parameter reported by the
9076 /// server.
9177 pub fn server_version ( & self ) -> Option < & str > {
@@ -247,8 +233,19 @@ impl Config {
247233 let ( client, connection) = async move { postgres_config. connect ( tls) . await }
248234 . run_in_task_if ( self . in_task , || "pg_connect" . to_string ( ) )
249235 . await ?;
250- let client = Client :: new ( client, & connection) ;
251- task:: spawn ( || task_name, connection) ;
236+
237+ let client = Client {
238+ inner : client,
239+ server_version : connection
240+ . parameter ( "server_version" )
241+ . map ( |v| v. to_string ( ) ) ,
242+ _connection_handle : task:: spawn ( || task_name, async {
243+ if let Err ( e) = connection. await {
244+ warn ! ( "postgres direct connection failed: {e}" ) ;
245+ }
246+ } )
247+ . abort_on_drop ( ) ,
248+ } ;
252249 Ok ( client)
253250 }
254251 TunnelConfig :: Ssh { config } => {
@@ -279,14 +276,20 @@ impl Config {
279276 async move { postgres_config. connect_raw ( tcp_stream, tls) . await }
280277 . run_in_task_if ( self . in_task , || "pg_connect" . to_string ( ) )
281278 . await ?;
282- let client = Client :: new ( client, & connection) ;
283- task:: spawn ( || task_name, async {
284- let _tunnel = tunnel; // Keep SSH tunnel alive for duration of connection.
285279
286- if let Err ( e) = connection. await {
287- warn ! ( "postgres connection failed: {e}" ) ;
288- }
289- } ) ;
280+ let client = Client {
281+ inner : client,
282+ server_version : connection
283+ . parameter ( "server_version" )
284+ . map ( |v| v. to_string ( ) ) ,
285+ _connection_handle : task:: spawn ( || task_name, async {
286+ let _tunnel = tunnel; // Keep SSH tunnel alive for duration of connection.
287+ if let Err ( e) = connection. await {
288+ warn ! ( "postgres via SSH tunnel connection failed: {e}" ) ;
289+ }
290+ } )
291+ . abort_on_drop ( ) ,
292+ } ;
290293 Ok ( client)
291294 }
292295 TunnelConfig :: AwsPrivatelink { connection_id } => {
@@ -322,8 +325,19 @@ impl Config {
322325 let ( client, connection) = async move { postgres_config. connect ( tls) . await }
323326 . run_in_task_if ( self . in_task , || "pg_connect" . to_string ( ) )
324327 . await ?;
325- let client = Client :: new ( client, & connection) ;
326- task:: spawn ( || task_name, connection) ;
328+
329+ let client = Client {
330+ inner : client,
331+ server_version : connection
332+ . parameter ( "server_version" )
333+ . map ( |v| v. to_string ( ) ) ,
334+ _connection_handle : task:: spawn ( || task_name, async {
335+ if let Err ( e) = connection. await {
336+ warn ! ( "postgres AWS link connection failed: {e}" ) ;
337+ }
338+ } )
339+ . abort_on_drop ( ) ,
340+ } ;
327341 Ok ( client)
328342 }
329343 }
0 commit comments