diff --git a/.gitignore b/.gitignore index 3e631c8af..0abf1a8bf 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ Cargo.lock .idea/ warp.iml + +*.swp diff --git a/Cargo.toml b/Cargo.toml index 5d426c0d5..ada6c2fb9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ percent-encoding = "2.1" pin-project = "1.0" tokio-rustls = { version = "0.25", optional = true } rustls-pemfile = { version = "2.0", optional = true } +rustls-pki-types = { version = "1.9.0", optional = true } [dev-dependencies] pretty_env_logger = "0.5" @@ -56,7 +57,7 @@ listenfd = "1.0" default = ["multipart", "websocket"] multipart = ["multer"] websocket = ["tokio-tungstenite"] -tls = ["tokio-rustls", "rustls-pemfile"] +tls = ["tokio-rustls", "rustls-pemfile", "rustls-pki-types"] # Enable compression-related filters compression = ["compression-brotli", "compression-gzip"] diff --git a/src/filter/service.rs b/src/filter/service.rs index 3de12a02e..f9897ea1f 100644 --- a/src/filter/service.rs +++ b/src/filter/service.rs @@ -1,6 +1,5 @@ use std::convert::Infallible; use std::future::Future; -use std::net::SocketAddr; use std::pin::Pin; use std::task::{Context, Poll}; @@ -11,6 +10,7 @@ use pin_project::pin_project; use crate::reject::IsReject; use crate::reply::{Reply, Response}; use crate::route::{self, Route}; +use crate::transport::PeerInfo; use crate::{Filter, Request}; /// Convert a `Filter` into a `Service`. @@ -70,14 +70,14 @@ where ::Error: IsReject, { #[inline] - pub(crate) fn call_with_addr( + pub(crate) fn call_with_peer_info( &self, req: Request, - remote_addr: Option, + peer_info: PeerInfo, ) -> FilteredFuture { debug_assert!(!route::is_set(), "nested route::set calls"); - let route = Route::new(req, remote_addr); + let route = Route::new(req, peer_info); let fut = route::set(&route, || self.filter.filter(super::Internal)); FilteredFuture { future: fut, route } } @@ -99,7 +99,7 @@ where #[inline] fn call(&mut self, req: Request) -> Self::Future { - self.call_with_addr(req, None) + self.call_with_peer_info(req, Default::default()) } } diff --git a/src/filters/mod.rs b/src/filters/mod.rs index bd1c48718..64808d8c4 100644 --- a/src/filters/mod.rs +++ b/src/filters/mod.rs @@ -22,6 +22,8 @@ pub mod path; pub mod query; pub mod reply; pub mod sse; +#[cfg(feature = "tls")] +pub mod mtls; pub mod trace; #[cfg(feature = "websocket")] pub mod ws; diff --git a/src/filters/mtls.rs b/src/filters/mtls.rs new file mode 100644 index 000000000..7299c332b --- /dev/null +++ b/src/filters/mtls.rs @@ -0,0 +1,52 @@ +//! Mutual (client) TLS filters. + +use std::convert::Infallible; + +use rustls_pki_types::CertificateDer; + +use crate::{ + filter::{filter_fn_one, Filter}, + route::Route, +}; + +/// Certificates is a iterable container of Certificates. +pub type Certificates = Vec>; + +/// Creates a `Filter` to get the peer certificates for the TLS connection. +/// +/// If the underlying transport doesn't have peer certificates, this will yield +/// `None`. +/// +/// # Example +/// +/// ``` +/// use warp::mtls::Certificates; +/// use warp::Filter; +/// +/// let route = warp::mtls::peer_certificates() +/// .map(|certs: Option| { +/// println!("peer certificates = {:?}", certs.as_ref()); +/// }); +/// ``` +pub fn peer_certificates( +) -> impl Filter,), Error = Infallible> + Copy { + filter_fn_one(|route| futures_util::future::ok(from_route(route))) +} + +/// Testing +pub fn peer_certs_into_owned(certs: &Vec>) -> Vec> { + certs + .to_vec() + .iter() + .map(|cert| cert.clone().into_owned()) + .collect() +} + +fn from_route(route: &Route) -> Option { + route + .peer_certificates() + .read() + .unwrap() + .as_ref() + .map(peer_certs_into_owned) +} diff --git a/src/lib.rs b/src/lib.rs index f8d345662..dd2f3a8a4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -112,6 +112,9 @@ pub use self::filters::compression; #[cfg(feature = "multipart")] #[doc(hidden)] pub use self::filters::multipart; +#[cfg(feature = "tls")] +#[doc(hidden)] +pub use self::filters::mtls; #[cfg(feature = "websocket")] #[doc(hidden)] pub use self::filters::ws; diff --git a/src/route.rs b/src/route.rs index afbac4d8b..d63da838c 100644 --- a/src/route.rs +++ b/src/route.rs @@ -6,6 +6,7 @@ use std::net::SocketAddr; use hyper::Body; use crate::Request; +use crate::transport::PeerInfo; scoped_thread_local!(static ROUTE: RefCell); @@ -30,7 +31,7 @@ where #[derive(Debug)] pub(crate) struct Route { body: BodyState, - remote_addr: Option, + peer_info: PeerInfo, req: Request, segments_index: usize, } @@ -42,7 +43,7 @@ enum BodyState { } impl Route { - pub(crate) fn new(req: Request, remote_addr: Option) -> RefCell { + pub(crate) fn new(req: Request, peer_info: PeerInfo) -> RefCell { let segments_index = if req.uri().path().starts_with('/') { // Skip the beginning slash. 1 @@ -52,7 +53,7 @@ impl Route { RefCell::new(Route { body: BodyState::Ready, - remote_addr, + peer_info, req, segments_index, }) @@ -124,7 +125,12 @@ impl Route { } pub(crate) fn remote_addr(&self) -> Option { - self.remote_addr + self.peer_info.remote_addr + } + + #[cfg(feature = "tls")] + pub(crate) fn peer_certificates(&self) -> crate::transport::PeerCertificates { + self.peer_info.peer_certificates.clone() } pub(crate) fn take_body(&mut self) -> Option { diff --git a/src/server.rs b/src/server.rs index 929d96eb3..ed484758f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -55,9 +55,14 @@ macro_rules! into_service { let inner = crate::service($into); make_service_fn(move |transport| { let inner = inner.clone(); - let remote_addr = Transport::remote_addr(transport); + + let peer_info = crate::transport::PeerInfo { + remote_addr: Transport::remote_addr(transport), + peer_certificates: Transport::peer_certificates(transport), + }; + future::ok::<_, Infallible>(service_fn(move |req| { - inner.call_with_addr(req, remote_addr) + inner.call_with_peer_info(req, peer_info.clone()) })) }) }}; diff --git a/src/test.rs b/src/test.rs index 947c09025..c8e09c4e7 100644 --- a/src/test.rs +++ b/src/test.rs @@ -113,18 +113,19 @@ use crate::filters::ws::Message; use crate::reject::IsReject; use crate::reply::Reply; use crate::route::{self, Route}; +use crate::transport::PeerInfo; use crate::Request; #[cfg(feature = "websocket")] use crate::{Sink, Stream}; +#[cfg(feature = "tls")] +use crate::filters::mtls::Certificates; + use self::inner::OneOrTuple; /// Starts a new test `RequestBuilder`. pub fn request() -> RequestBuilder { - RequestBuilder { - remote_addr: None, - req: Request::default(), - } + Default::default() } /// Starts a new test `WsBuilder`. @@ -137,9 +138,9 @@ pub fn ws() -> WsBuilder { /// /// See [module documentation](crate::test) for an overview. #[must_use = "RequestBuilder does nothing on its own"] -#[derive(Debug)] +#[derive(Debug, Default)] pub struct RequestBuilder { - remote_addr: Option, + peer_info: PeerInfo, req: Request, } @@ -248,7 +249,21 @@ impl RequestBuilder { /// .remote_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)); /// ``` pub fn remote_addr(mut self, addr: SocketAddr) -> Self { - self.remote_addr = Some(addr); + self.peer_info.remote_addr = Some(addr); + self + } + + /// Set the peer certificates of this request. + /// Default is no peer certificates. + /// + /// # Example + /// ``` + /// let req = warp::test::request() + /// .peer_certificates([rustls_pki_types::CertificateDer::from_slice(b"FAKE CERT")]); + /// ``` + #[cfg(feature = "tls")] + pub fn peer_certificates(self, certs: impl Into) -> Self { + *self.peer_info.peer_certificates.write().unwrap() = Some(certs.into()); self } @@ -375,7 +390,7 @@ impl RequestBuilder { // TODO: de-duplicate this and apply_filter() assert!(!route::is_set(), "nested test filter calls"); - let route = Route::new(self.req, self.remote_addr); + let route = Route::new(self.req, self.peer_info); let mut fut = Box::pin( route::set(&route, move || f.filter(crate::filter::Internal)).then(|result| { let res = match result { @@ -404,7 +419,7 @@ impl RequestBuilder { { assert!(!route::is_set(), "nested test filter calls"); - let route = Route::new(self.req, self.remote_addr); + let route = Route::new(self.req, self.peer_info); let mut fut = Box::pin(route::set(&route, move || { f.filter(crate::filter::Internal) })); diff --git a/src/tls.rs b/src/tls.rs index aa7438752..df9cd87f5 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -15,7 +15,8 @@ use hyper::server::conn::{AddrIncoming, AddrStream}; use tokio_rustls::rustls::server::WebPkiClientVerifier; use tokio_rustls::rustls::{Error as TlsError, RootCertStore, ServerConfig}; -use crate::transport::Transport; +use crate::filters::mtls::peer_certs_into_owned; +use crate::transport::{PeerCertificates, Transport}; /// Represents errors that can occur building the TlsConfig #[derive(Debug)] @@ -284,6 +285,10 @@ impl Transport for TlsStream { fn remote_addr(&self) -> Option { Some(self.remote_addr) } + + fn peer_certificates(&self) -> PeerCertificates { + self.peer_certs.clone() + } } enum State { @@ -297,6 +302,7 @@ enum State { pub(crate) struct TlsStream { state: State, remote_addr: SocketAddr, + peer_certs: PeerCertificates, } impl TlsStream { @@ -306,6 +312,7 @@ impl TlsStream { TlsStream { state: State::Handshaking(accept), remote_addr, + peer_certs: Default::default(), } } } @@ -320,6 +327,11 @@ impl AsyncRead for TlsStream { match pin.state { State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { Ok(mut stream) => { + let (_, conn) = stream.get_ref(); + *pin.peer_certs.write().unwrap() = conn + .peer_certificates() + .map(|certs| peer_certs_into_owned(&certs.to_vec())); + let result = Pin::new(&mut stream).poll_read(cx, buf); pin.state = State::Streaming(stream); result diff --git a/src/transport.rs b/src/transport.rs index be553e706..22c785122 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -6,8 +6,27 @@ use std::task::{Context, Poll}; use hyper::server::conn::AddrStream; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +#[cfg(feature = "tls")] +use crate::filters::mtls::Certificates; + +#[cfg(feature = "tls")] +pub(crate) type PeerCertificates = std::sync::Arc>>; +#[cfg(not(feature = "tls"))] +pub(crate) type PeerCertificates = (); + pub trait Transport: AsyncRead + AsyncWrite { fn remote_addr(&self) -> Option; + + fn peer_certificates(&self) -> PeerCertificates { + Default::default() + } +} + +#[derive(Clone, Debug, Default)] +pub(crate) struct PeerInfo { + pub remote_addr: Option, + #[allow(dead_code)] + pub peer_certificates: PeerCertificates, } impl Transport for AddrStream { diff --git a/tests/mtls.rs b/tests/mtls.rs new file mode 100644 index 000000000..d8a51a2fe --- /dev/null +++ b/tests/mtls.rs @@ -0,0 +1,24 @@ +#![deny(warnings)] +#![cfg(feature = "tls")] + +use rustls_pki_types::CertificateDer; + +#[tokio::test] +async fn peer_certificates_missing() { + let extract_peer_certs = warp::mtls::peer_certificates(); + + let req = warp::test::request(); + let resp = req.filter(&extract_peer_certs).await.unwrap(); + assert!(resp.is_none()) +} + +#[tokio::test] +async fn peer_certificates_present() { + let extract_peer_certs = warp::mtls::peer_certificates(); + + let cert = CertificateDer::<'_>::from_slice(b"TEST CERT"); + + let req = warp::test::request().peer_certificates([cert.clone()]); + let resp = req.filter(&extract_peer_certs).await.unwrap(); + assert_eq!(resp.unwrap(), &[cert],) +}