Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions quinn/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,24 @@ impl Connection {
.clone()
}

/// Wait for the connection to be closed without keeping a strong reference to the connection
///
/// Returns a future that resolves, once the connection is closed, to a tuple of
/// ([`ConnectionError`], [`ConnectionStats`]).
///
/// Calling [`Self::closed`] keeps the connection alive until it is either closed locally via [`Connection::close`]
/// or closed by the remote peer. This function instead does not keep the connection itself alive,
/// so if all *other* clones of the connection are dropped, the connection will be closed implicitly even
/// if there are futures returned from this function still being awaited.
pub fn on_closed(&self) -> OnClosed {
let (tx, rx) = oneshot::channel();
self.0.state.lock("on_closed").on_closed.push(tx);
OnClosed {
conn: self.weak_handle(),
rx,
}
}

/// If the connection is closed, the reason why.
///
/// Returns `None` if the connection is still open.
Expand Down Expand Up @@ -1037,6 +1055,43 @@ impl Future for SendDatagram<'_> {
}
}

/// Future returned by [`Connection::on_closed`]
///
/// Resolves to a tuple of ([`ConnectionError`], [`ConnectionStats`]).
pub struct OnClosed {
rx: oneshot::Receiver<(ConnectionError, ConnectionStats)>,
conn: WeakConnectionHandle,
}

impl Drop for OnClosed {
fn drop(&mut self) {
if self.rx.is_terminated() {
return;
};
if let Some(conn) = self.conn.upgrade() {
self.rx.close();
conn.0
.state
.lock("OnClosed::drop")
.on_closed
.retain(|tx| !tx.is_closed());
}
}
}

impl Future for OnClosed {
type Output = (ConnectionError, ConnectionStats);

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
// The `expect` is safe because `State::drop` ensures that all senders are triggered
// before being dropped.
Pin::new(&mut this.rx)
.poll(cx)
.map(|x| x.expect("on_close sender is never dropped before sending"))
}
}

#[derive(Debug)]
pub(crate) struct ConnectionRef(Arc<ConnectionInner>);

Expand Down Expand Up @@ -1077,6 +1132,7 @@ impl ConnectionRef {
send_buffer: Vec::new(),
buffered_transmit: None,
observed_external_addr: watch::Sender::new(None),
on_closed: Vec::new(),
}),
shared: Shared::default(),
}))
Expand Down Expand Up @@ -1215,6 +1271,7 @@ pub(crate) struct State {
/// Our last external address reported by the peer. When multipath is enabled, this will be the
/// last report across all paths.
pub(crate) observed_external_addr: watch::Sender<Option<SocketAddr>>,
on_closed: Vec<oneshot::Sender<(ConnectionError, ConnectionStats)>>,
}

impl State {
Expand Down Expand Up @@ -1475,6 +1532,12 @@ impl State {
}
wake_all_notify(&mut self.stopped);
shared.closed.notify_waiters();

// Send to the registered on_closed futures.
let stats = self.inner.stats();
for tx in self.on_closed.drain(..) {
tx.send((reason.clone(), stats.clone())).ok();
}
}

fn close(&mut self, error_code: VarInt, reason: Bytes, shared: &Shared) {
Expand Down Expand Up @@ -1508,6 +1571,15 @@ impl Drop for State {
.endpoint_events
.send((self.handle, proto::EndpointEvent::drained()));
}

if !self.on_closed.is_empty() {
// Ensure that all on_closed oneshot senders are triggered before dropping.
let reason = self.error.as_ref().expect("closed without error reason");
let stats = self.inner.stats();
for tx in self.on_closed.drain(..) {
tx.send((reason.clone(), stats.clone())).ok();
}
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions quinn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ pub use rustls;
pub use udp;

pub use crate::connection::{
AcceptBi, AcceptUni, Connecting, Connection, OpenBi, OpenUni, ReadDatagram, SendDatagram,
SendDatagramError, WeakConnectionHandle, ZeroRttAccepted,
AcceptBi, AcceptUni, Connecting, Connection, OnClosed, OpenBi, OpenUni, ReadDatagram,
SendDatagram, SendDatagramError, WeakConnectionHandle, ZeroRttAccepted,
};
pub use crate::endpoint::{Accept, Endpoint, EndpointStats};
pub use crate::incoming::{Incoming, IncomingFuture, RetryError};
Expand Down
97 changes: 96 additions & 1 deletion quinn/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::{
use crate::runtime::TokioRuntime;
use crate::{Duration, Instant};
use bytes::Bytes;
use proto::{RandomConnectionIdGenerator, crypto::rustls::QuicClientConfig};
use proto::{ConnectionError, RandomConnectionIdGenerator, crypto::rustls::QuicClientConfig};
use rand::{RngCore, SeedableRng, rngs::StdRng};
use rustls::{
RootCertStore,
Expand Down Expand Up @@ -1023,3 +1023,98 @@ async fn test_multipath_observed_address() {

tokio::join!(server_task, client_task);
}

#[tokio::test]
async fn on_closed() {
let _guard = subscribe();
let endpoint = endpoint();
let endpoint2 = endpoint.clone();
let server_task = tokio::spawn(async move {
let conn = endpoint2
.accept()
.await
.expect("endpoint")
.await
.expect("connection");
let on_closed = conn.on_closed();
let cause = conn.closed().await;
let (cause1, _stats) = on_closed.await;
assert!(matches!(cause, ConnectionError::ApplicationClosed(_)));
assert!(matches!(cause1, ConnectionError::ApplicationClosed(_)));
});
let client_task = tokio::spawn(async move {
let conn = endpoint
.connect(endpoint.local_addr().unwrap(), "localhost")
.unwrap()
.await
.expect("connect");
let on_closed1 = conn.on_closed();
let on_closed2 = conn.on_closed();
drop(conn);

let (cause, _stats) = on_closed1.await;
assert_eq!(cause, ConnectionError::LocallyClosed);
let (cause, _stats) = on_closed2.await;
assert_eq!(cause, ConnectionError::LocallyClosed);
});
let (server_res, client_res) = tokio::join!(server_task, client_task);
server_res.expect("server task panicked");
client_res.expect("client task panicked");
}

#[tokio::test]
async fn on_closed_endpoint_drop() {
let _guard = subscribe();
let factory = EndpointFactory::new();
let client = factory.endpoint("client");
let server = factory.endpoint("server");
let server_addr = server.local_addr().unwrap();
let server_task = tokio::time::timeout(
Duration::from_millis(500),
tokio::spawn(async move {
let conn = server
.accept()
.await
.expect("endpoint")
.await
.expect("accept");
println!("accepted");
let on_closed = conn.on_closed();
drop(conn);
drop(server);
let (cause, _stats) = on_closed.await;
// Depending on timing we might have received a close frame or not.
assert!(matches!(
cause,
ConnectionError::ApplicationClosed(_) | ConnectionError::LocallyClosed
));
}),
);
let client_task = tokio::time::timeout(
Duration::from_millis(500),
tokio::spawn(async move {
let conn = client
.connect(server_addr, "localhost")
.unwrap()
.await
.expect("connect");
println!("connected");
let on_closed = conn.on_closed();
drop(conn);
drop(client);
let (cause, _stats) = on_closed.await;
// Depending on timing we might have received a close frame or not.
assert!(matches!(
cause,
ConnectionError::ApplicationClosed(_) | ConnectionError::LocallyClosed
));
}),
);
let (server_res, client_res) = tokio::join!(server_task, client_task);
server_res
.expect("server timeout")
.expect("server task panicked");
client_res
.expect("client timeout")
.expect("client task panicked");
}
Loading