Skip to content

Commit

Permalink
take _some_ headers into account when computing the cache key
Browse files Browse the repository at this point in the history
  • Loading branch information
fasterthanlime committed Oct 13, 2024
1 parent 758627b commit 7fa7deb
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 22 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ hyper = { version = "1.4.1", default-features = false, features = [
"server",
] }
hyper-util = { version = "0.1.9", features = ["tokio"] }
md5 = "0.7.0"
postcard = { version = "1.0.10", features = ["use-std"] }
pretty-hex = "0.4.1"
rcgen = { version = "0.13.1" }
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ An proof-of-concept(TM) caching HTTP forward proxy

## Limitations

* Will only accept to negotiate http/2 over TLS (via CONNECT) right now
* Very naive rules to decide if something is cachable (see sources)
specifically, **fopro DOES NOT RESPECT `cache-control`, `vary`, ETC**.
* The cache is boundless (both in memory and on disk)
Expand Down
122 changes: 101 additions & 21 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use argh::FromArgs;
use color_eyre::eyre::{self, Context};
use futures_util::future::BoxFuture;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use http_serde::http::uri::Scheme;
use hyper::{
body::{Body, Bytes},
server::conn,
Expand Down Expand Up @@ -78,7 +79,7 @@ impl CertAuth {
}

// just the output of the 'date' on macOS Sequoia
static CACHE_VERSION: &str = "Sun Oct 13 22:09:06 CEST 2024";
static CACHE_VERSION: &str = "Sun Oct 13 23:40:34 CEST 2024";

#[derive(FromArgs)]
/// A caching HTTP forward proxy
Expand Down Expand Up @@ -198,6 +199,7 @@ async fn main() -> eyre::Result<()> {
.with_no_client_auth()
.with_cert_resolver(cert_cache);
server_conf.alpn_protocols.push(b"h2".to_vec());
server_conf.alpn_protocols.push(b"http/1.1".to_vec());
server_conf.max_early_data_size = 4 * 1024;
server_conf.send_half_rtt_data = true;
server_conf.session_storage = ServerSessionMemoryCache::new(16 * 1024);
Expand Down Expand Up @@ -289,8 +291,19 @@ where
}
};

let scheme = if req.uri().port_u16().unwrap_or_default() == 443 {
Scheme::HTTPS
} else {
Scheme::HTTP
};

if req.method() != Method::CONNECT {
let service = ProxyService { host, settings };
let service = ProxyService {
host,
settings,
scheme,
};

return match service.proxy_request(req).await {
Ok(resp) => Ok(resp),
Err(e) => {
Expand All @@ -305,7 +318,7 @@ where

let on_upgrade = hyper::upgrade::on(req);
tokio::spawn(async move {
if let Err(e) = handle_upgraded_conn(on_upgrade, host, settings).await {
if let Err(e) = handle_upgraded_conn(on_upgrade, host, scheme, settings).await {
tracing::error!("Error handling upgraded conn: {e:?}");
}
});
Expand All @@ -321,6 +334,7 @@ where
async fn handle_upgraded_conn(
on_upgrade: OnUpgrade,
host: String,
scheme: Scheme,
settings: ProxySettings,
) -> eyre::Result<()> {
let c = on_upgrade.await.unwrap();
Expand All @@ -330,19 +344,47 @@ async fn handle_upgraded_conn(
let acceptor = tokio_rustls::TlsAcceptor::from(settings.server_conf.clone());
let tls_stream = acceptor.accept(c).await?;

{
enum Mode {
H1,
H2,
}

let mode = {
let (_stream, server_conn) = tls_stream.get_ref();

tracing::trace!(
"Negotiated TLS session in {:?}, ALPN proto:\n{}",
before_accept.elapsed(),
pretty_hex::pretty_hex(&server_conn.alpn_protocol().unwrap_or_default())
);
}

let service = ProxyService { host, settings };
let conn = conn::http2::Builder::new(TokioExecutor::new())
.serve_connection(TokioIo::new(tls_stream), service);
if server_conn.alpn_protocol().unwrap_or_default() == b"h2" {
Mode::H2
} else {
Mode::H1
}
};

let service = ProxyService {
host,
settings,
scheme,
};

let conn = tokio::spawn(async move {
match mode {
Mode::H1 => {
conn::http1::Builder::new()
.serve_connection(TokioIo::new(tls_stream), service)
.await
}
Mode::H2 => {
conn::http2::Builder::new(TokioExecutor::new())
.serve_connection(TokioIo::new(tls_stream), service)
.await
}
}
});
match conn.await {
Ok(_) => (),
Err(e) => {
Expand Down Expand Up @@ -449,8 +491,8 @@ impl std::fmt::Debug for ProxySettings {
#[derive(Debug, Clone)]
struct ProxyService {
host: String,

settings: ProxySettings,
scheme: Scheme,
}

impl<ReqBody> Service<Request<ReqBody>> for ProxyService
Expand Down Expand Up @@ -494,9 +536,13 @@ impl ProxyService {
let uri = req.uri().clone();
tracing::trace!(settings = ?self.settings, %uri, "Should proxy request");

let method = req.method().clone();
let (part, body) = req.into_parts();

let uri_host = uri
.host()
.ok_or_else(|| eyre::eyre!("expected host in CONNECT request"))?;
.or_else(|| part.headers.get("host").and_then(|h| h.to_str().ok()))
.ok_or_else(|| eyre::eyre!("expected host in URI or host header"))?;

if uri_host != self.host {
return Ok(Response::builder()
Expand All @@ -507,28 +553,53 @@ impl ProxyService {

let before_req = Instant::now();

let method = req.method().clone();
let (part, body) = req.into_parts();

let mut cachable = true;

let cache_key = format!(
let mut cache_key = format!(
"k/{}{}",
uri.authority().map(|a| a.as_str()).unwrap_or_default(),
uri.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or_default()
);
let cache_key = cache_key.replace(':', "_COLON_");
let cache_key = cache_key.replace("//", "_SLASHSLASH_");
cache_key = cache_key.replace(':', "_COLON_");
cache_key = cache_key.replace("//", "_SLASHSLASH_");
if cache_key.contains("..") {
cachable = false;
}
let cache_key = if cache_key.ends_with('/') {
format!("{cache_key}_INDEX_")
} else {
cache_key.to_string()

if cache_key.ends_with('/') {
cache_key = format!("{cache_key}_INDEX_");
};

if let Some(authorization) = part.headers.get(hyper::header::AUTHORIZATION) {
let authorization = authorization.to_str().unwrap();
let hash = md5::compute(authorization);
let hash = format!("{:x}", hash);
cache_key = format!("{cache_key}_AUTH_{hash}");
}

if let Some(accept) = part.headers.get(hyper::header::ACCEPT) {
let accept = accept.to_str().unwrap();
let hash = md5::compute(accept);
let hash = format!("{:x}", hash);
cache_key = format!("{cache_key}_ACCEPT_{hash}");
}

if let Some(accept_encoding) = part.headers.get(hyper::header::ACCEPT_ENCODING) {
let accept_encoding = accept_encoding.to_str().unwrap();
let hash = md5::compute(accept_encoding);
let hash = format!("{:x}", hash);
cache_key = format!("{cache_key}_ACCEPT_ENCODING_{hash}");
}

if let Some(accept_language) = part.headers.get(hyper::header::ACCEPT_LANGUAGE) {
let accept_language = accept_language.to_str().unwrap();
let hash = md5::compute(accept_language);
let hash = format!("{:x}", hash);
cache_key = format!("{cache_key}_ACCEPT_LANGUAGE_{hash}");
}

tracing::debug!("Cache key: {}", cache_key);

if let Some(host) = uri.host() {
Expand Down Expand Up @@ -640,7 +711,16 @@ impl ProxyService {
}
}

tracing::debug!("Proxying {method} {uri}");
tracing::debug!("Proxying {method} {uri}: {part:#?}");

let uri = if uri.host().is_none() {
let mut parts = uri.into_parts();
parts.scheme = Some(self.scheme.clone());
parts.authority = Some(format!("{}", self.host).parse().unwrap());
hyper::Uri::from_parts(parts).unwrap()
} else {
uri
};

let upstream_res = match self
.settings
Expand Down

0 comments on commit 7fa7deb

Please sign in to comment.