Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api-test/src/alpn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ where
assert_eq!(&buf, b"hello");

t!(socket.write_all(b"world").await);
t!(socket.shutdown().await);
};
block_on(f);
});
Expand Down
1 change: 1 addition & 0 deletions api-test/src/client_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ where
assert_eq!(&buf, b"hello");

t!(socket.write_all(b"world").await);
t!(socket.shutdown().await);
};
block_on(future);
})
Expand Down
1 change: 1 addition & 0 deletions api-test/src/client_server_dyn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ async fn test_client_server_dyn_impl(
assert_eq!(&buf, b"hello");

t!(socket.write_all(b"world").await);
t!(socket.shutdown().await);
};
block_on(future);
})
Expand Down
13 changes: 12 additions & 1 deletion api-test/src/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,18 @@ async fn test_google_impl<C: TlsConnector>() {

t!(tls_stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await);
let mut result = vec![];
t!(tls_stream.read_to_end(&mut result).await);
let res = tls_stream.read_to_end(&mut result).await;

// Google will not send close_notify and just close the connection.
// This means that they are not confirming to TLS exactly, that connections to google.com
// are vulnerable to truncation attacks and that we need to suppress error about this here.
match res {
Ok(_) => {}
Err(e)
if e.to_string()
.contains("peer closed connection without sending TLS close_notify") => {}
Err(e) => panic!("{}", e),
}

println!("{}", String::from_utf8_lossy(&result));
assert!(
Expand Down
45 changes: 43 additions & 2 deletions api/src/async_as_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ where
/// API-implementation of wrapper stream.
///
/// Wrapped object is always [`AsyncIoAsSyncIo`].
type SyncWrapper: Read + Write + Unpin + Send + 'static;
type SyncWrapper: Read + Write + WriteShutdown + Unpin + Send + 'static;

/// Which crates imlpements this?
fn impl_info() -> ImplInfo;
Expand All @@ -137,6 +137,47 @@ where
fn get_alpn_protocol(w: &Self::SyncWrapper) -> anyhow::Result<Option<Vec<u8>>>;
}

/// Notify the writer that there will be no more data written.
/// In context of TLS providers, this is great time to send notify_close message.
pub trait WriteShutdown: Write {
/// Initiates or attempts to shut down this writer, returning when
/// the I/O connection has completely shut down.
///
/// For example this is suitable for implementing shutdown of a
/// TLS connection or calling `TcpStream::shutdown` on a proxied connection.
/// Protocols sometimes need to flush out final pieces of data or otherwise
/// perform a graceful shutdown handshake, reading/writing more data as
/// appropriate. This method is the hook for such protocols to implement the
/// graceful shutdown logic.
///
/// This `shutdown` method is required by implementers of the
/// `AsyncWrite` trait. Wrappers typically just want to proxy this call
/// through to the wrapped type, and base types will typically implement
/// shutdown logic here or just return `Ok(().into())`. Note that if you're
/// wrapping an underlying `AsyncWrite` a call to `shutdown` implies that
/// transitively the entire stream has been shut down. After your wrapper's
/// shutdown logic has been executed you should shut down the underlying
/// stream.
///
/// Invocation of a `shutdown` implies an invocation of `flush`. Once this
/// method returns it implies that a flush successfully happened
/// before the shutdown happened. That is, callers don't need to call
/// `flush` before calling `shutdown`. They can rely that by calling
/// `shutdown` any pending buffered data will be written out.
///
/// # Errors
///
/// This function can return normal I/O errors through `Err`, described
/// above. Additionally this method may also render the underlying
/// `Write::write` method no longer usable (e.g. will return errors in the
/// future). It's recommended that once `shutdown` is called the
/// `write` method is no longer called.
fn shutdown(&mut self) -> Result<(), io::Error> {
self.flush()?;
Ok(())
}
}

/// Implementation of `TlsStreamImpl` for APIs using synchronous I/O.
pub struct TlsStreamOverSyncIo<A, O>
where
Expand Down Expand Up @@ -270,7 +311,7 @@ where
#[cfg(feature = "runtime-tokio")]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.get_mut()
.with_context_sync_to_async(cx, |stream| stream.stream.flush())
.with_context_sync_to_async(cx, |stream| stream.stream.shutdown())
}

#[cfg(feature = "runtime-async-std")]
Expand Down
6 changes: 4 additions & 2 deletions impl-native-tls/src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use tls_api::async_as_sync::AsyncIoAsSyncIo;
use tls_api::spi::save_context;
use tls_api::AsyncSocket;

use crate::stream::NativeTlsStream;

pub(crate) enum HandshakeFuture<F, S: Unpin> {
Initial(F, AsyncIoAsSyncIo<S>),
MidHandshake(native_tls::MidHandshakeTlsStream<AsyncIoAsSyncIo<S>>),
Expand All @@ -36,7 +38,7 @@ where
match mem::replace(self_mut, HandshakeFuture::Done) {
HandshakeFuture::Initial(f, stream) => match f(stream) {
Ok(stream) => {
return Poll::Ready(Ok(crate::TlsStream::new(stream)));
return Poll::Ready(Ok(crate::TlsStream::new(NativeTlsStream(stream))));
}
Err(native_tls::HandshakeError::WouldBlock(mid)) => {
*self_mut = HandshakeFuture::MidHandshake(mid);
Expand All @@ -48,7 +50,7 @@ where
},
HandshakeFuture::MidHandshake(stream) => match stream.handshake() {
Ok(stream) => {
return Poll::Ready(Ok(crate::TlsStream::new(stream)));
return Poll::Ready(Ok(crate::TlsStream::new(NativeTlsStream(stream))));
}
Err(native_tls::HandshakeError::WouldBlock(mid)) => {
*self_mut = HandshakeFuture::MidHandshake(mid);
Expand Down
71 changes: 64 additions & 7 deletions impl-native-tls/src/stream.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
use native_tls::TlsStream as native_tls_TlsStream;
use std::fmt;
use std::io;
use std::io::Read;
use std::io::Write;
use std::marker::PhantomData;
use tls_api::async_as_sync::AsyncIoAsSyncIo;
use tls_api::async_as_sync::AsyncWrapperOps;
use tls_api::async_as_sync::TlsStreamOverSyncIo;
use tls_api::async_as_sync::WriteShutdown;
use tls_api::spi_async_socket_impl_delegate;
use tls_api::spi_tls_stream_over_sync_io_wrapper;
use tls_api::AsyncSocket;
use tls_api::ImplInfo;

spi_tls_stream_over_sync_io_wrapper!(TlsStream, native_tls_TlsStream);
spi_tls_stream_over_sync_io_wrapper!(TlsStream, NativeTlsStream);

#[derive(Debug)]
pub(crate) struct AsyncWrapperOpsImpl<S, A>(PhantomData<(S, A)>)
Expand All @@ -22,25 +25,79 @@ where
S: fmt::Debug + Unpin + Send + 'static,
A: AsyncSocket,
{
type SyncWrapper = native_tls::TlsStream<AsyncIoAsSyncIo<A>>;
type SyncWrapper = NativeTlsStream<AsyncIoAsSyncIo<A>>;

fn impl_info() -> ImplInfo {
crate::info()
}

fn debug(w: &Self::SyncWrapper) -> &dyn fmt::Debug {
w
&w.0
}

fn get_mut(w: &mut Self::SyncWrapper) -> &mut AsyncIoAsSyncIo<A> {
w.get_mut()
w.0.get_mut()
}

fn get_ref(w: &Self::SyncWrapper) -> &AsyncIoAsSyncIo<A> {
w.get_ref()
w.0.get_ref()
}

fn get_alpn_protocol(w: &Self::SyncWrapper) -> anyhow::Result<Option<Vec<u8>>> {
w.negotiated_alpn().map_err(anyhow::Error::new)
w.0.negotiated_alpn().map_err(anyhow::Error::new)
}
}

pub(crate) struct NativeTlsStream<A: Read + Write>(pub(crate) native_tls::TlsStream<A>);

impl<A: Read + Write> Write for NativeTlsStream<A> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}

fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}

fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
self.0.write_vectored(bufs)
}

fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
self.0.write_all(buf)
}

fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
self.0.write_fmt(fmt)
}
}

impl<A: Read + Write> WriteShutdown for NativeTlsStream<A> {
fn shutdown(&mut self) -> Result<(), io::Error> {
self.flush()?;
self.0.shutdown()?;
Ok(())
}
}

impl<A: Read + Write> Read for NativeTlsStream<A> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}

fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
self.0.read_vectored(bufs)
}

fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
self.0.read_to_end(buf)
}

fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
self.0.read_to_string(buf)
}

fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
self.0.read_exact(buf)
}
}
6 changes: 4 additions & 2 deletions impl-openssl/src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use tls_api::async_as_sync::AsyncIoAsSyncIo;
use tls_api::spi::save_context;
use tls_api::AsyncSocket;

use crate::stream::OpenSSLStream;

pub(crate) enum HandshakeFuture<F, S: Unpin> {
Initial(F, AsyncIoAsSyncIo<S>),
MidHandshake(openssl::ssl::MidHandshakeSslStream<AsyncIoAsSyncIo<S>>),
Expand All @@ -36,7 +38,7 @@ where
match mem::replace(self_mut, HandshakeFuture::Done) {
HandshakeFuture::Initial(f, stream) => match f(stream) {
Ok(stream) => {
return Poll::Ready(Ok(crate::TlsStream::new(stream)));
return Poll::Ready(Ok(crate::TlsStream::new(OpenSSLStream(stream))));
}
Err(openssl::ssl::HandshakeError::WouldBlock(mid)) => {
*self_mut = HandshakeFuture::MidHandshake(mid);
Expand All @@ -51,7 +53,7 @@ where
},
HandshakeFuture::MidHandshake(stream) => match stream.handshake() {
Ok(stream) => {
return Poll::Ready(Ok(crate::TlsStream::new(stream)));
return Poll::Ready(Ok(crate::TlsStream::new(OpenSSLStream(stream))));
}
Err(openssl::ssl::HandshakeError::WouldBlock(mid)) => {
*self_mut = HandshakeFuture::MidHandshake(mid);
Expand Down
75 changes: 68 additions & 7 deletions impl-openssl/src/stream.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
use std::fmt;
use std::io;
use std::io::Read;
use std::io::Write;
use std::marker::PhantomData;

use openssl::ssl::SslRef;
use openssl::ssl::SslStream;
use tls_api::async_as_sync::AsyncIoAsSyncIo;
use tls_api::async_as_sync::AsyncWrapperOps;
use tls_api::async_as_sync::TlsStreamOverSyncIo;
use tls_api::async_as_sync::WriteShutdown;
use tls_api::spi_async_socket_impl_delegate;
use tls_api::spi_tls_stream_over_sync_io_wrapper;
use tls_api::AsyncSocket;
use tls_api::ImplInfo;

spi_tls_stream_over_sync_io_wrapper!(TlsStream, SslStream);
spi_tls_stream_over_sync_io_wrapper!(TlsStream, OpenSSLStream);

impl<A: AsyncSocket> TlsStream<A> {
/// Get the [`SslRef`] object for the stream.
pub fn get_ssl_ref(&self) -> &SslRef {
self.0.stream.ssl()
self.0.stream.0.ssl()
}
}

Expand All @@ -31,25 +35,82 @@ where
S: fmt::Debug + Unpin + Send + 'static,
A: AsyncSocket,
{
type SyncWrapper = openssl::ssl::SslStream<AsyncIoAsSyncIo<A>>;
type SyncWrapper = OpenSSLStream<AsyncIoAsSyncIo<A>>;

fn debug(w: &Self::SyncWrapper) -> &dyn fmt::Debug {
w
&w.0
}

fn get_mut(w: &mut Self::SyncWrapper) -> &mut AsyncIoAsSyncIo<A> {
w.get_mut()
w.0.get_mut()
}

fn get_ref(w: &Self::SyncWrapper) -> &AsyncIoAsSyncIo<A> {
w.get_ref()
w.0.get_ref()
}

fn get_alpn_protocol(w: &Self::SyncWrapper) -> anyhow::Result<Option<Vec<u8>>> {
Ok(w.ssl().selected_alpn_protocol().map(Vec::from))
Ok(w.0.ssl().selected_alpn_protocol().map(Vec::from))
}

fn impl_info() -> ImplInfo {
crate::into()
}
}

pub(crate) struct OpenSSLStream<A: Read + Write>(pub(crate) SslStream<A>);

impl<A: Read + Write> Write for OpenSSLStream<A> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}

fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}

fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
self.0.write_vectored(bufs)
}

fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
self.0.write_all(buf)
}

fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
self.0.write_fmt(fmt)
}
}

impl<A: Read + Write> WriteShutdown for OpenSSLStream<A> {
fn shutdown(&mut self) -> Result<(), io::Error> {
self.flush()?;
self.0.shutdown().map_err(|e| {
e.into_io_error()
.unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))
})?;
Ok(())
}
}

impl<A: Read + Write> Read for OpenSSLStream<A> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}

fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
self.0.read_vectored(bufs)
}

fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
self.0.read_to_end(buf)
}

fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
self.0.read_to_string(buf)
}

fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
self.0.read_exact(buf)
}
}
Loading