From 75aac9f47fe0246016e6133cd3cfa35b63c8904e Mon Sep 17 00:00:00 2001
From: Anthony Ramine <nox@nox.paris>
Date: Mon, 21 Mar 2022 15:22:07 +0100
Subject: [PATCH] fix(client): send an error back to client when dispatch
 misbehaves (fixes #2649)

---
 src/client/dispatch.rs | 56 +++++++++++++++++++++++++++++++-----------
 src/error.rs           | 11 +++++++++
 tests/client.rs        | 42 +++++++++++++++++++++++++++++++
 3 files changed, 94 insertions(+), 15 deletions(-)

diff --git a/src/client/dispatch.rs b/src/client/dispatch.rs
index 63bb0256d5..970fb243e4 100644
--- a/src/client/dispatch.rs
+++ b/src/client/dispatch.rs
@@ -90,7 +90,7 @@ impl<T, U> Sender<T, U> {
         }
         let (tx, rx) = oneshot::channel();
         self.inner
-            .send(Envelope(Some((val, Callback::Retry(tx)))))
+            .send(Envelope(Some((val, Callback::Retry(Some(tx))))))
             .map(move |_| rx)
             .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
     }
@@ -101,7 +101,7 @@ impl<T, U> Sender<T, U> {
         }
         let (tx, rx) = oneshot::channel();
         self.inner
-            .send(Envelope(Some((val, Callback::NoRetry(tx)))))
+            .send(Envelope(Some((val, Callback::NoRetry(Some(tx))))))
             .map(move |_| rx)
             .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
     }
@@ -131,7 +131,7 @@ impl<T, U> UnboundedSender<T, U> {
     pub(crate) fn try_send(&mut self, val: T) -> Result<RetryPromise<T, U>, T> {
         let (tx, rx) = oneshot::channel();
         self.inner
-            .send(Envelope(Some((val, Callback::Retry(tx)))))
+            .send(Envelope(Some((val, Callback::Retry(Some(tx))))))
             .map(move |_| rx)
             .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
     }
@@ -139,7 +139,7 @@ impl<T, U> UnboundedSender<T, U> {
     pub(crate) fn send(&mut self, val: T) -> Result<Promise<U>, T> {
         let (tx, rx) = oneshot::channel();
         self.inner
-            .send(Envelope(Some((val, Callback::NoRetry(tx)))))
+            .send(Envelope(Some((val, Callback::NoRetry(Some(tx))))))
             .map(move |_| rx)
             .map_err(|mut e| (e.0).0.take().expect("envelope not dropped").0)
     }
@@ -215,33 +215,59 @@ impl<T, U> Drop for Envelope<T, U> {
 
 pub(crate) enum Callback<T, U> {
     #[allow(unused)]
-    Retry(oneshot::Sender<Result<U, (crate::Error, Option<T>)>>),
-    NoRetry(oneshot::Sender<Result<U, crate::Error>>),
+    Retry(Option<oneshot::Sender<Result<U, (crate::Error, Option<T>)>>>),
+    NoRetry(Option<oneshot::Sender<Result<U, crate::Error>>>),
+}
+
+impl<T, U> Drop for Callback<T, U> {
+    fn drop(&mut self) {
+        // FIXME(nox): What errors do we want here?
+        let error = crate::Error::new_user_dispatch_gone().with(if std::thread::panicking() {
+            "user code panicked"
+        } else {
+            "runtime dropped the dispatch task"
+        });
+
+        match self {
+            Callback::Retry(tx) => {
+                if let Some(tx) = tx.take() {
+                    let _ = tx.send(Err((error, None)));
+                }
+            }
+            Callback::NoRetry(tx) => {
+                if let Some(tx) = tx.take() {
+                    let _ = tx.send(Err(error));
+                }
+            }
+        }
+    }
 }
 
 impl<T, U> Callback<T, U> {
     #[cfg(feature = "http2")]
     pub(crate) fn is_canceled(&self) -> bool {
         match *self {
-            Callback::Retry(ref tx) => tx.is_closed(),
-            Callback::NoRetry(ref tx) => tx.is_closed(),
+            Callback::Retry(Some(ref tx)) => tx.is_closed(),
+            Callback::NoRetry(Some(ref tx)) => tx.is_closed(),
+            _ => unreachable!(),
         }
     }
 
     pub(crate) fn poll_canceled(&mut self, cx: &mut task::Context<'_>) -> Poll<()> {
         match *self {
-            Callback::Retry(ref mut tx) => tx.poll_closed(cx),
-            Callback::NoRetry(ref mut tx) => tx.poll_closed(cx),
+            Callback::Retry(Some(ref mut tx)) => tx.poll_closed(cx),
+            Callback::NoRetry(Some(ref mut tx)) => tx.poll_closed(cx),
+            _ => unreachable!(),
         }
     }
 
-    pub(crate) fn send(self, val: Result<U, (crate::Error, Option<T>)>) {
+    pub(crate) fn send(mut self, val: Result<U, (crate::Error, Option<T>)>) {
         match self {
-            Callback::Retry(tx) => {
-                let _ = tx.send(val);
+            Callback::Retry(ref mut tx) => {
+                let _ = tx.take().unwrap().send(val);
             }
-            Callback::NoRetry(tx) => {
-                let _ = tx.send(val.map_err(|e| e.0));
+            Callback::NoRetry(ref mut tx) => {
+                let _ = tx.take().unwrap().send(val.map_err(|e| e.0));
             }
         }
     }
diff --git a/src/error.rs b/src/error.rs
index bc4414ab78..b07a22c409 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -110,6 +110,10 @@ pub(super) enum User {
     #[cfg(feature = "http1")]
     ManualUpgrade,
 
+    /// The dispatch task is gone.
+    #[cfg(feature = "client")]
+    DispatchGone,
+
     /// User aborted in an FFI callback.
     #[cfg(feature = "ffi")]
     AbortedByCallback,
@@ -314,6 +318,11 @@ impl Error {
         Error::new_user(User::AbortedByCallback)
     }
 
+    #[cfg(feature = "client")]
+    pub(super) fn new_user_dispatch_gone() -> Error {
+        Error::new(Kind::User(User::DispatchGone))
+    }
+
     #[cfg(feature = "http2")]
     pub(super) fn new_h2(cause: ::h2::Error) -> Error {
         if cause.is_io() {
@@ -390,6 +399,8 @@ impl Error {
             Kind::User(User::NoUpgrade) => "no upgrade available",
             #[cfg(feature = "http1")]
             Kind::User(User::ManualUpgrade) => "upgrade expected but low level API in use",
+            #[cfg(feature = "client")]
+            Kind::User(User::DispatchGone) => "dispatch task is gone",
             #[cfg(feature = "ffi")]
             Kind::User(User::AbortedByCallback) => "operation aborted by an application callback",
         }
diff --git a/tests/client.rs b/tests/client.rs
index 8968433885..4ade440439 100644
--- a/tests/client.rs
+++ b/tests/client.rs
@@ -2267,6 +2267,48 @@ mod conn {
         done_tx.send(()).unwrap();
     }
 
+    #[tokio::test]
+    async fn test_body_panics() {
+        let _ = pretty_env_logger::try_init();
+
+        let listener = TkTcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
+            .await
+            .unwrap();
+        let addr = listener.local_addr().unwrap();
+
+        // spawn a server that reads but doesn't write
+        tokio::spawn(async move {
+            let sock = listener.accept().await.unwrap().0;
+            drain_til_eof(sock).await.expect("server read");
+        });
+
+        let io = tcp_connect(&addr).await.expect("tcp connect");
+
+        let (mut client, conn) = conn::http1::Builder::new()
+            .handshake(io)
+            .await
+            .expect("handshake");
+
+        tokio::spawn(async move {
+            conn.await.expect("client conn shouldn't error");
+        });
+
+        let req = Request::post("/a")
+            .body(http_body_util::BodyExt::map_frame::<_, bytes::Bytes>(
+                http_body_util::Full::<bytes::Bytes>::from("baguette"),
+                |_| panic!("oopsie"),
+            ))
+            .unwrap();
+
+        let error = client.send_request(req).await.unwrap_err();
+
+        assert!(error.is_user());
+        assert_eq!(
+            error.to_string(),
+            "dispatch task is gone: user code panicked"
+        );
+    }
+
     async fn drain_til_eof<T: AsyncRead + Unpin>(mut sock: T) -> io::Result<()> {
         let mut buf = [0u8; 1024];
         loop {