diff --git a/api-test/src/alpn.rs b/api-test/src/alpn.rs index 286bad4..8178a7d 100644 --- a/api-test/src/alpn.rs +++ b/api-test/src/alpn.rs @@ -73,6 +73,7 @@ where assert_eq!(&buf, b"hello"); t!(socket.write_all(b"world").await); + t!(socket.shutdown().await); }; block_on(f); }); diff --git a/api-test/src/client_server.rs b/api-test/src/client_server.rs index 6f7fa4e..b8b7b95 100644 --- a/api-test/src/client_server.rs +++ b/api-test/src/client_server.rs @@ -58,6 +58,7 @@ where assert_eq!(&buf, b"hello"); t!(socket.write_all(b"world").await); + t!(socket.shutdown().await); }; block_on(future); }) diff --git a/api-test/src/client_server_dyn.rs b/api-test/src/client_server_dyn.rs index 3f9c961..9c9142a 100644 --- a/api-test/src/client_server_dyn.rs +++ b/api-test/src/client_server_dyn.rs @@ -51,6 +51,7 @@ async fn test_client_server_dyn_impl( assert_eq!(&buf, b"hello"); t!(socket.write_all(b"world").await); + t!(socket.shutdown().await); }; block_on(future); }) diff --git a/api-test/src/google.rs b/api-test/src/google.rs index e3e6201..5f8fed5 100644 --- a/api-test/src/google.rs +++ b/api-test/src/google.rs @@ -31,7 +31,18 @@ async fn test_google_impl() { t!(tls_stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await); let mut result = vec![]; - t!(tls_stream.read_to_end(&mut result).await); + let res = tls_stream.read_to_end(&mut result).await; + + // Google will not send close_notify and just close the connection. + // This means that they are not confirming to TLS exactly, that connections to google.com + // are vulnerable to truncation attacks and that we need to suppress error about this here. + match res { + Ok(_) => {} + Err(e) + if e.to_string() + .contains("peer closed connection without sending TLS close_notify") => {} + Err(e) => panic!("{}", e), + } println!("{}", String::from_utf8_lossy(&result)); assert!( diff --git a/api/src/async_as_sync.rs b/api/src/async_as_sync.rs index 7c8b9fa..66b681b 100644 --- a/api/src/async_as_sync.rs +++ b/api/src/async_as_sync.rs @@ -119,7 +119,7 @@ where /// API-implementation of wrapper stream. /// /// Wrapped object is always [`AsyncIoAsSyncIo`]. - type SyncWrapper: Read + Write + Unpin + Send + 'static; + type SyncWrapper: Read + Write + WriteShutdown + Unpin + Send + 'static; /// Which crates imlpements this? fn impl_info() -> ImplInfo; @@ -137,6 +137,47 @@ where fn get_alpn_protocol(w: &Self::SyncWrapper) -> anyhow::Result>>; } +/// Notify the writer that there will be no more data written. +/// In context of TLS providers, this is great time to send notify_close message. +pub trait WriteShutdown: Write { + /// Initiates or attempts to shut down this writer, returning when + /// the I/O connection has completely shut down. + /// + /// For example this is suitable for implementing shutdown of a + /// TLS connection or calling `TcpStream::shutdown` on a proxied connection. + /// Protocols sometimes need to flush out final pieces of data or otherwise + /// perform a graceful shutdown handshake, reading/writing more data as + /// appropriate. This method is the hook for such protocols to implement the + /// graceful shutdown logic. + /// + /// This `shutdown` method is required by implementers of the + /// `AsyncWrite` trait. Wrappers typically just want to proxy this call + /// through to the wrapped type, and base types will typically implement + /// shutdown logic here or just return `Ok(().into())`. Note that if you're + /// wrapping an underlying `AsyncWrite` a call to `shutdown` implies that + /// transitively the entire stream has been shut down. After your wrapper's + /// shutdown logic has been executed you should shut down the underlying + /// stream. + /// + /// Invocation of a `shutdown` implies an invocation of `flush`. Once this + /// method returns it implies that a flush successfully happened + /// before the shutdown happened. That is, callers don't need to call + /// `flush` before calling `shutdown`. They can rely that by calling + /// `shutdown` any pending buffered data will be written out. + /// + /// # Errors + /// + /// This function can return normal I/O errors through `Err`, described + /// above. Additionally this method may also render the underlying + /// `Write::write` method no longer usable (e.g. will return errors in the + /// future). It's recommended that once `shutdown` is called the + /// `write` method is no longer called. + fn shutdown(&mut self) -> Result<(), io::Error> { + self.flush()?; + Ok(()) + } +} + /// Implementation of `TlsStreamImpl` for APIs using synchronous I/O. pub struct TlsStreamOverSyncIo where @@ -270,7 +311,7 @@ where #[cfg(feature = "runtime-tokio")] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.get_mut() - .with_context_sync_to_async(cx, |stream| stream.stream.flush()) + .with_context_sync_to_async(cx, |stream| stream.stream.shutdown()) } #[cfg(feature = "runtime-async-std")] diff --git a/impl-native-tls/src/handshake.rs b/impl-native-tls/src/handshake.rs index e585a73..c4dde43 100644 --- a/impl-native-tls/src/handshake.rs +++ b/impl-native-tls/src/handshake.rs @@ -11,6 +11,8 @@ use tls_api::async_as_sync::AsyncIoAsSyncIo; use tls_api::spi::save_context; use tls_api::AsyncSocket; +use crate::stream::NativeTlsStream; + pub(crate) enum HandshakeFuture { Initial(F, AsyncIoAsSyncIo), MidHandshake(native_tls::MidHandshakeTlsStream>), @@ -36,7 +38,7 @@ where match mem::replace(self_mut, HandshakeFuture::Done) { HandshakeFuture::Initial(f, stream) => match f(stream) { Ok(stream) => { - return Poll::Ready(Ok(crate::TlsStream::new(stream))); + return Poll::Ready(Ok(crate::TlsStream::new(NativeTlsStream(stream)))); } Err(native_tls::HandshakeError::WouldBlock(mid)) => { *self_mut = HandshakeFuture::MidHandshake(mid); @@ -48,7 +50,7 @@ where }, HandshakeFuture::MidHandshake(stream) => match stream.handshake() { Ok(stream) => { - return Poll::Ready(Ok(crate::TlsStream::new(stream))); + return Poll::Ready(Ok(crate::TlsStream::new(NativeTlsStream(stream)))); } Err(native_tls::HandshakeError::WouldBlock(mid)) => { *self_mut = HandshakeFuture::MidHandshake(mid); diff --git a/impl-native-tls/src/stream.rs b/impl-native-tls/src/stream.rs index d89faf0..eb6ad3e 100644 --- a/impl-native-tls/src/stream.rs +++ b/impl-native-tls/src/stream.rs @@ -1,15 +1,18 @@ -use native_tls::TlsStream as native_tls_TlsStream; use std::fmt; +use std::io; +use std::io::Read; +use std::io::Write; use std::marker::PhantomData; use tls_api::async_as_sync::AsyncIoAsSyncIo; use tls_api::async_as_sync::AsyncWrapperOps; use tls_api::async_as_sync::TlsStreamOverSyncIo; +use tls_api::async_as_sync::WriteShutdown; use tls_api::spi_async_socket_impl_delegate; use tls_api::spi_tls_stream_over_sync_io_wrapper; use tls_api::AsyncSocket; use tls_api::ImplInfo; -spi_tls_stream_over_sync_io_wrapper!(TlsStream, native_tls_TlsStream); +spi_tls_stream_over_sync_io_wrapper!(TlsStream, NativeTlsStream); #[derive(Debug)] pub(crate) struct AsyncWrapperOpsImpl(PhantomData<(S, A)>) @@ -22,25 +25,79 @@ where S: fmt::Debug + Unpin + Send + 'static, A: AsyncSocket, { - type SyncWrapper = native_tls::TlsStream>; + type SyncWrapper = NativeTlsStream>; fn impl_info() -> ImplInfo { crate::info() } fn debug(w: &Self::SyncWrapper) -> &dyn fmt::Debug { - w + &w.0 } fn get_mut(w: &mut Self::SyncWrapper) -> &mut AsyncIoAsSyncIo { - w.get_mut() + w.0.get_mut() } fn get_ref(w: &Self::SyncWrapper) -> &AsyncIoAsSyncIo { - w.get_ref() + w.0.get_ref() } fn get_alpn_protocol(w: &Self::SyncWrapper) -> anyhow::Result>> { - w.negotiated_alpn().map_err(anyhow::Error::new) + w.0.negotiated_alpn().map_err(anyhow::Error::new) + } +} + +pub(crate) struct NativeTlsStream(pub(crate) native_tls::TlsStream); + +impl Write for NativeTlsStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } + + fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { + self.0.write_vectored(bufs) + } + + fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { + self.0.write_all(buf) + } + + fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> { + self.0.write_fmt(fmt) + } +} + +impl WriteShutdown for NativeTlsStream { + fn shutdown(&mut self) -> Result<(), io::Error> { + self.flush()?; + self.0.shutdown()?; + Ok(()) + } +} + +impl Read for NativeTlsStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } + + fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result { + self.0.read_vectored(bufs) + } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + self.0.read_to_end(buf) + } + + fn read_to_string(&mut self, buf: &mut String) -> io::Result { + self.0.read_to_string(buf) + } + + fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { + self.0.read_exact(buf) } } diff --git a/impl-openssl/src/handshake.rs b/impl-openssl/src/handshake.rs index fa70224..f91a9ef 100644 --- a/impl-openssl/src/handshake.rs +++ b/impl-openssl/src/handshake.rs @@ -11,6 +11,8 @@ use tls_api::async_as_sync::AsyncIoAsSyncIo; use tls_api::spi::save_context; use tls_api::AsyncSocket; +use crate::stream::OpenSSLStream; + pub(crate) enum HandshakeFuture { Initial(F, AsyncIoAsSyncIo), MidHandshake(openssl::ssl::MidHandshakeSslStream>), @@ -36,7 +38,7 @@ where match mem::replace(self_mut, HandshakeFuture::Done) { HandshakeFuture::Initial(f, stream) => match f(stream) { Ok(stream) => { - return Poll::Ready(Ok(crate::TlsStream::new(stream))); + return Poll::Ready(Ok(crate::TlsStream::new(OpenSSLStream(stream)))); } Err(openssl::ssl::HandshakeError::WouldBlock(mid)) => { *self_mut = HandshakeFuture::MidHandshake(mid); @@ -51,7 +53,7 @@ where }, HandshakeFuture::MidHandshake(stream) => match stream.handshake() { Ok(stream) => { - return Poll::Ready(Ok(crate::TlsStream::new(stream))); + return Poll::Ready(Ok(crate::TlsStream::new(OpenSSLStream(stream)))); } Err(openssl::ssl::HandshakeError::WouldBlock(mid)) => { *self_mut = HandshakeFuture::MidHandshake(mid); diff --git a/impl-openssl/src/stream.rs b/impl-openssl/src/stream.rs index c22390a..8e4d9fe 100644 --- a/impl-openssl/src/stream.rs +++ b/impl-openssl/src/stream.rs @@ -1,4 +1,7 @@ use std::fmt; +use std::io; +use std::io::Read; +use std::io::Write; use std::marker::PhantomData; use openssl::ssl::SslRef; @@ -6,17 +9,18 @@ use openssl::ssl::SslStream; use tls_api::async_as_sync::AsyncIoAsSyncIo; use tls_api::async_as_sync::AsyncWrapperOps; use tls_api::async_as_sync::TlsStreamOverSyncIo; +use tls_api::async_as_sync::WriteShutdown; use tls_api::spi_async_socket_impl_delegate; use tls_api::spi_tls_stream_over_sync_io_wrapper; use tls_api::AsyncSocket; use tls_api::ImplInfo; -spi_tls_stream_over_sync_io_wrapper!(TlsStream, SslStream); +spi_tls_stream_over_sync_io_wrapper!(TlsStream, OpenSSLStream); impl TlsStream { /// Get the [`SslRef`] object for the stream. pub fn get_ssl_ref(&self) -> &SslRef { - self.0.stream.ssl() + self.0.stream.0.ssl() } } @@ -31,25 +35,82 @@ where S: fmt::Debug + Unpin + Send + 'static, A: AsyncSocket, { - type SyncWrapper = openssl::ssl::SslStream>; + type SyncWrapper = OpenSSLStream>; fn debug(w: &Self::SyncWrapper) -> &dyn fmt::Debug { - w + &w.0 } fn get_mut(w: &mut Self::SyncWrapper) -> &mut AsyncIoAsSyncIo { - w.get_mut() + w.0.get_mut() } fn get_ref(w: &Self::SyncWrapper) -> &AsyncIoAsSyncIo { - w.get_ref() + w.0.get_ref() } fn get_alpn_protocol(w: &Self::SyncWrapper) -> anyhow::Result>> { - Ok(w.ssl().selected_alpn_protocol().map(Vec::from)) + Ok(w.0.ssl().selected_alpn_protocol().map(Vec::from)) } fn impl_info() -> ImplInfo { crate::into() } } + +pub(crate) struct OpenSSLStream(pub(crate) SslStream); + +impl Write for OpenSSLStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } + + fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { + self.0.write_vectored(bufs) + } + + fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { + self.0.write_all(buf) + } + + fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> { + self.0.write_fmt(fmt) + } +} + +impl WriteShutdown for OpenSSLStream { + fn shutdown(&mut self) -> Result<(), io::Error> { + self.flush()?; + self.0.shutdown().map_err(|e| { + e.into_io_error() + .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)) + })?; + Ok(()) + } +} + +impl Read for OpenSSLStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } + + fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result { + self.0.read_vectored(bufs) + } + + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + self.0.read_to_end(buf) + } + + fn read_to_string(&mut self, buf: &mut String) -> io::Result { + self.0.read_to_string(buf) + } + + fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { + self.0.read_exact(buf) + } +} diff --git a/impl-rustls/src/rustls_utils.rs b/impl-rustls/src/rustls_utils.rs index 305a0e6..a935a41 100644 --- a/impl-rustls/src/rustls_utils.rs +++ b/impl-rustls/src/rustls_utils.rs @@ -7,6 +7,7 @@ use std::io::IoSlice; use std::io::IoSliceMut; use std::io::Read; use std::io::Write; +use tls_api::async_as_sync::WriteShutdown; pub enum RustlsSessionRef<'a> { Client(&'a ClientConnection), @@ -102,6 +103,17 @@ impl Write for RustlsStream { } } +impl WriteShutdown for RustlsStream { + fn shutdown(&mut self) -> Result<(), io::Error> { + match self { + RustlsStream::Server(s) => s.conn.send_close_notify(), + RustlsStream::Client(s) => s.conn.send_close_notify(), + } + self.flush()?; + Ok(()) + } +} + impl Read for RustlsStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self {