From 7a219ce3fa7c07420ec8f0eac639bc9334ce8e3f Mon Sep 17 00:00:00 2001 From: Justin Bradfield Date: Fri, 24 Oct 2025 22:44:23 -0500 Subject: [PATCH 1/2] enable more sni tests mzcompose --- Cargo.lock | 104 ++++++++++++++-- src/balancerd/Cargo.toml | 1 + src/balancerd/src/bin/balancerd.rs | 47 +++++-- src/balancerd/src/lib.rs | 191 +++++++++++++++++++++-------- src/balancerd/tests/server.rs | 31 +++-- test/balancerd/mzcompose.py | 11 +- test/dnsmasq/entrypoint.sh | 1 + 7 files changed, 304 insertions(+), 82 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6cf6e9c85dff5..4a5b787816123 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2115,9 +2115,9 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.3" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "core_affinity" @@ -2200,6 +2200,12 @@ dependencies = [ "itertools 0.10.5", ] +[[package]] +name = "critical-section" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" + [[package]] name = "crossbeam" version = "0.8.4" @@ -3735,6 +3741,52 @@ dependencies = [ "rayon", ] +[[package]] +name = "hickory-proto" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8a6fe56c0038198998a6f217ca4e7ef3a5e51f46163bd6dd60b5c71ca6c6502" +dependencies = [ + "async-trait", + "cfg-if", + "data-encoding", + "enum-as-inner", + "futures-channel", + "futures-io", + "futures-util", + "idna", + "ipnet", + "once_cell", + "rand 0.9.2", + "ring", + "thiserror 2.0.12", + "tinyvec", + "tokio", + "tracing", + "url", +] + +[[package]] +name = "hickory-resolver" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc62a9a99b0bfb44d2ab95a7208ac952d31060efc16241c87eaf36406fecf87a" +dependencies = [ + "cfg-if", + "futures-util", + "hickory-proto", + "ipconfig", + "moka", + "once_cell", + "parking_lot", + "rand 0.9.2", + "resolv-conf", + "smallvec", + "thiserror 2.0.12", + "tokio", + "tracing", +] + [[package]] name = "hmac" version = "0.12.1" @@ -4276,6 +4328,18 @@ dependencies = [ "libc", ] +[[package]] +name = "ipconfig" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" +dependencies = [ + "socket2 0.5.10", + "widestring", + "windows-sys 0.48.0", + "winreg", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -5480,6 +5544,7 @@ dependencies = [ "clap", "domain", "futures", + "hickory-resolver", "humantime", "hyper 1.6.0", "hyper-openssl", @@ -8598,6 +8663,10 @@ name = "once_cell" version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +dependencies = [ + "critical-section", + "portable-atomic", +] [[package]] name = "oorandom" @@ -10338,6 +10407,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "resolv-conf" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b3789b30bd25ba102de4beabd95d21ac45b69b1be7d14522bab988c526d6799" + [[package]] name = "retain_mut" version = "0.1.9" @@ -10565,9 +10640,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.23" +version = "0.23.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" +checksum = "cd3c25631629d034ce7cd9940adc9d45762d46de2b0f57193c4443b92c6d4d40" dependencies = [ "once_cell", "rustls-pki-types", @@ -10587,15 +10662,18 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.1" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +dependencies = [ + "zeroize", +] [[package]] name = "rustls-webpki" -version = "0.102.8" +version = "0.103.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +checksum = "8572f3c2cb9934231157b45499fc41e1f58c589fdfb81a844ba873265e80f8eb" dependencies = [ "ring", "rustls-pki-types", @@ -10739,9 +10817,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.11.0" +version = "2.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" dependencies = [ "core-foundation-sys", "libc", @@ -13052,6 +13130,12 @@ dependencies = [ "web-sys", ] +[[package]] +name = "widestring" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd7cf3379ca1aac9eea11fba24fd7e315d621f8dfe35c8d7d2be8b793726e07d" + [[package]] name = "winapi" version = "0.3.9" diff --git a/src/balancerd/Cargo.toml b/src/balancerd/Cargo.toml index f45d9f1f71c71..d918f3406af74 100644 --- a/src/balancerd/Cargo.toml +++ b/src/balancerd/Cargo.toml @@ -19,6 +19,7 @@ chrono = { version = "0.4.39", default-features = false, features = ["std"] } clap = { version = "4.5.23", features = ["derive", "env"] } domain = { version = "0.11.1", default-features = false, features = ["resolv"] } futures = "0.3.31" +hickory-resolver = "0.25.2" humantime = "2.2.0" hyper = { version = "1.4.1", features = ["http1", "server"] } hyper-openssl = "0.10.2" diff --git a/src/balancerd/src/bin/balancerd.rs b/src/balancerd/src/bin/balancerd.rs index a2d173029fe92..ab7a8d579f454 100644 --- a/src/balancerd/src/bin/balancerd.rs +++ b/src/balancerd/src/bin/balancerd.rs @@ -18,11 +18,13 @@ use std::path::PathBuf; use std::time::Duration; use anyhow::Context; -use domain::resolv::StubResolver; +use hickory_resolver::{ + Resolver, config::*, name_server::TokioConnectionProvider, system_conf::read_system_conf, +}; use jsonwebtoken::DecodingKey; use mz_balancerd::{ - BUILD_INFO, BalancerConfig, BalancerService, CancellationResolver, FronteggResolver, Resolver, - SniResolver, + BUILD_INFO, BalancerConfig, BalancerResolver, BalancerService, CancellationResolver, + FronteggResolver, SniResolver, create_default_resolver, }; use mz_frontegg_auth::{ Authenticator, AuthenticatorConfig, DEFAULT_REFRESH_DROP_FACTOR, @@ -241,8 +243,10 @@ pub async fn run(args: ServiceArgs, tracing_handle: TracingHandle) -> Result<(), if !cancellation_resolver_dir.is_dir() { anyhow::bail!("{cancellation_resolver_dir:?} is not a directory"); } + ( - Resolver::MultiTenant( + BalancerResolver::MultiTenant( + create_default_resolver(), FronteggResolver { auth, addr_template, @@ -261,11 +265,7 @@ pub async fn run(args: ServiceArgs, tracing_handle: TracingHandle) -> Result<(), ) }) .expect("invalid port for pgwire_sni_resolver_template"); - Some(SniResolver { - resolver: StubResolver::new(), - template, - port, - }) + Some(SniResolver { template, port }) } }, ), @@ -284,8 +284,35 @@ pub async fn run(args: ServiceArgs, tracing_handle: TracingHandle) -> Result<(), }; drop(addrs); + // Create a resolver for static addresses with the same caching configuration + let mut resolver_opts = ResolverOpts::default(); + resolver_opts.cache_size = 10000; + resolver_opts.positive_max_ttl = Some(Duration::from_secs(10)); + resolver_opts.positive_min_ttl = Some(Duration::from_secs(9)); + resolver_opts.negative_min_ttl = Some(Duration::from_secs(1)); + + // Read system DNS configuration or fall back to defaults + let (config, opts) = read_system_conf() + .map(|(config, mut opts)| { + // Override specific options while keeping system DNS servers + opts.cache_size = resolver_opts.cache_size; + opts.positive_max_ttl = resolver_opts.positive_max_ttl; + opts.positive_min_ttl = resolver_opts.positive_min_ttl; + opts.negative_min_ttl = resolver_opts.negative_min_ttl; + (config, opts) + }) + .unwrap_or_else(|err| { + eprintln!("Failed to read system DNS configuration for static resolver, using defaults: {}", err); + (ResolverConfig::default(), resolver_opts) + }); + ( - Resolver::Static(addr.clone()), + BalancerResolver::Static( + Resolver::builder_with_config(config, TokioConnectionProvider::default()) + .with_options(opts) + .build(), + addr.clone(), + ), CancellationResolver::Static(addr), ) } diff --git a/src/balancerd/src/lib.rs b/src/balancerd/src/lib.rs index 491d8a628de52..529836ac22432 100644 --- a/src/balancerd/src/lib.rs +++ b/src/balancerd/src/lib.rs @@ -23,7 +23,6 @@ use std::collections::BTreeMap; use std::net::SocketAddr; use std::path::PathBuf; use std::pin::Pin; -use std::str::FromStr; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -31,11 +30,12 @@ use anyhow::Context; use axum::response::IntoResponse; use axum::{Router, routing}; use bytes::BytesMut; -use domain::base::{Name, Rtype}; -use domain::rdata::AllRecordData; -use domain::resolv::StubResolver; use futures::TryFutureExt; use futures::stream::BoxStream; +use hickory_resolver::{ + Resolver, TokioResolver, config::*, name_server::TokioConnectionProvider, + proto::rr::RecordType, system_conf::read_system_conf, +}; use hyper::StatusCode; use hyper_util::rt::TokioIo; use launchdarkly_server_sdk as ld; @@ -94,7 +94,7 @@ pub struct BalancerConfig { /// DNS resolver for pgwire cancellation requests cancellation_resolver: CancellationResolver, /// DNS resolver. - resolver: Resolver, + resolver: BalancerResolver, https_sni_addr_template: String, tls: Option, internal_tls: bool, @@ -117,7 +117,7 @@ impl BalancerConfig { pgwire_listen_addr: SocketAddr, https_listen_addr: SocketAddr, cancellation_resolver: CancellationResolver, - resolver: Resolver, + resolver: BalancerResolver, https_sni_addr_template: String, tls: Option, internal_tls: bool, @@ -354,9 +354,9 @@ impl BalancerService { panic!("expected port in https_addr_template"); }; let port: u16 = port.parse().expect("unexpected port"); - let resolver = StubResolver::new(); + let https = HttpsBalancer { - resolver: Arc::from(resolver), + resolver: Arc::from(create_default_resolver()), tls: https_tls, resolve_template: Arc::from(addr), port, @@ -613,7 +613,7 @@ struct PgwireBalancer { tls: Option, internal_tls: bool, cancellation_resolver: Arc, - resolver: Arc, + resolver: Arc, metrics: ServerMetrics, now: NowFn, } @@ -624,7 +624,7 @@ impl PgwireBalancer { conn: &'a mut FramedConn, version: i32, params: BTreeMap, - resolver: &Resolver, + resolver: &BalancerResolver, tls_mode: Option, internal_tls: bool, metrics: &ServerMetrics, @@ -1049,7 +1049,7 @@ async fn cancel_request( } struct HttpsBalancer { - resolver: Arc, + resolver: Arc, tls: Option, resolve_template: Arc, port: u16, @@ -1060,7 +1060,7 @@ struct HttpsBalancer { impl HttpsBalancer { async fn resolve( - resolver: &StubResolver, + resolver: &TokioResolver, resolve_template: &str, port: u16, servername: Option<&str>, @@ -1086,7 +1086,7 @@ impl HttpsBalancer { let tenant = resolver.tenant(&addr).await; // Now do the regular ip lookup, regardless of if there was a CNAME. - let envd_addr = lookup(&format!("{addr}:{port}")).await?; + let envd_addr = lookup_with_resolver(resolver, &format!("{addr}:{port}")).await?; Ok(ResolvedAddr { addr: envd_addr, @@ -1096,31 +1096,20 @@ impl HttpsBalancer { } } -trait StubResolverExt { +trait ResolverExt { async fn tenant(&self, addr: &str) -> Option; } -impl StubResolverExt for StubResolver { +impl ResolverExt for TokioResolver { /// Finds the tenant of a DNS address. Errors or lack of cname resolution here are ok, because /// this is only used for metrics. async fn tenant(&self, addr: &str) -> Option { - let Ok(dname) = Name::>::from_str(addr) else { - return None; - }; debug!("resolving tenant for {:?}", addr); // Lookup the CNAME. If there's a CNAME, find the tenant. - let lookup = self.query((dname, Rtype::CNAME)).await; + let lookup = self.lookup(addr, RecordType::CNAME).await; if let Ok(lookup) = lookup { - if let Ok(answer) = lookup.answer() { - let res = answer.limit_to::>(); - for record in res { - let Ok(record) = record else { - continue; - }; - if record.rtype() != Rtype::CNAME { - continue; - } - let cname = record.data(); + for record in lookup.iter() { + if let Some(cname) = record.as_cname() { let cname = cname.to_string(); debug!("cname: {cname}"); return extract_tenant_from_cname(&cname); @@ -1264,7 +1253,6 @@ impl mz_server_core::Server for HttpsBalancer { #[derive(Debug)] pub struct SniResolver { - pub resolver: StubResolver, pub template: String, pub port: u16, } @@ -1273,12 +1261,12 @@ trait ClientStream: AsyncRead + AsyncWrite + Unpin + Send {} impl ClientStream for T {} #[derive(Debug)] -pub enum Resolver { - Static(String), - MultiTenant(FronteggResolver, Option), +pub enum BalancerResolver { + Static(TokioResolver, String), + MultiTenant(TokioResolver, FronteggResolver, Option), } -impl Resolver { +impl BalancerResolver { async fn resolve( &self, conn: &mut FramedConn, @@ -1289,7 +1277,8 @@ impl Resolver { A: AsyncRead + AsyncWrite + Unpin, { match self { - Resolver::MultiTenant( + BalancerResolver::MultiTenant( + dns_resolver, FronteggResolver { auth, addr_template, @@ -1313,15 +1302,14 @@ impl Resolver { ( Some(servername), Some(SniResolver { - resolver: stub_resolver, template: sni_addr_template, port, }), ) => { let sni_addr = sni_addr_template.replace("{}", servername); - let tenant = stub_resolver.tenant(&sni_addr).await; + let tenant = dns_resolver.tenant(&sni_addr).await; let sni_addr = format!("{sni_addr}:{port}"); - let addr = lookup(&sni_addr).await?; + let addr = lookup_with_resolver(dns_resolver, &sni_addr).await?; if tenant.is_some() { debug!("SNI header found for tenant {:?}", tenant); } @@ -1352,7 +1340,7 @@ impl Resolver { let addr = addr_template.replace("{}", &auth_session.tenant_id().to_string()); - let addr = lookup(&addr).await?; + let addr = lookup_with_resolver(dns_resolver, &addr).await?; let tenant = Some(auth_session.tenant_id().to_string()); if tenant.is_some() { debug!("SNI header NOT found for tenant {:?}", tenant); @@ -1373,8 +1361,9 @@ impl Resolver { Ok(resolved_addr) } - Resolver::Static(addr) => { - let addr = lookup(addr).await?; + BalancerResolver::Static(resolver, addr) => { + // Use the shared resolver instance for caching benefits + let addr = lookup_with_resolver(resolver, addr).await?; Ok(ResolvedAddr { addr, password: None, @@ -1385,16 +1374,120 @@ impl Resolver { } } -/// Returns the first IP address resolved from the provided hostname. -async fn lookup(name: &str) -> Result { - let mut addrs = tokio::net::lookup_host(name).await?; - match addrs.next() { - Some(addr) => Ok(addr), - None => { - error!("{name} did not resolve to any addresses"); - anyhow::bail!("internal error") +/// Creates a default resolver with caching enabled for DNS lookups. +pub fn create_default_resolver() -> TokioResolver { + let mut resolver_opts = ResolverOpts::default(); + resolver_opts.cache_size = 10000; + resolver_opts.positive_max_ttl = Some(Duration::from_secs(10)); + resolver_opts.positive_min_ttl = Some(Duration::from_secs(9)); + resolver_opts.negative_min_ttl = Some(Duration::from_secs(1)); + resolver_opts.negative_max_ttl = Some(Duration::from_secs(1)); + resolver_opts.ip_strategy = LookupIpStrategy::Ipv4thenIpv6; + + // Read system DNS configuration or fall back to defaults + let (config, opts) = read_system_conf() + .map(|(config, mut opts)| { + // Override specific options while keeping system DNS servers + opts.cache_size = resolver_opts.cache_size; + opts.positive_max_ttl = resolver_opts.positive_max_ttl; + opts.positive_min_ttl = resolver_opts.positive_min_ttl; + opts.negative_min_ttl = resolver_opts.negative_min_ttl; + opts.negative_max_ttl = resolver_opts.negative_max_ttl; + (config, opts) + }) + .unwrap_or_else(|err| { + warn!( + "Failed to read system DNS configuration, using defaults: {}", + err + ); + (ResolverConfig::default(), resolver_opts) + }); + + Resolver::builder_with_config(config, TokioConnectionProvider::default()) + .with_options(opts) + .build() +} + +/// Follows CNAME chain to resolve a hostname to IP addresses, ensuring each step is cached. +/// +/// hickory's `lookup_ip` only caches the final step when CNAMEs are involved. This function +/// explicitly queries for CNAME records at each step, following the chain until we reach a +/// hostname that resolves to A records. Each query is cached individually by hickory. +/// +/// Returns the final hostname and its A record lookup response. +async fn follow_cname_chain( + resolver: &TokioResolver, + initial_host: &str, +) -> Result { + let mut current_host = initial_host.to_string(); + let max_hops = 10; // Prevent infinite loops in case of misconfigured DNS + + for hop in 0..max_hops { + // Explicitly query for CNAME records to ensure this step is cached + match resolver.lookup(¤t_host, RecordType::CNAME).await { + Ok(cname_response) => { + // Check if we got a CNAME record back + if let Some(cname_record) = cname_response.iter().next() { + if let Some(cname_data) = cname_record.as_cname() { + // Follow the CNAME to the next hostname + let next_host = cname_data.to_string(); + tracing::debug!( + hop = hop, + from = %current_host, + to = %next_host, + "Following CNAME" + ); + current_host = next_host; + continue; + } + } + // No CNAME found, fall through to A record lookup + break; + } + Err(_) => { + // No CNAME record exists, this is likely the final hostname + break; + } } } + + // Now lookup A records for the final hostname in the chain + // We don't want to cache this + let response = resolver + .lookup_ip(¤t_host) + .await + .with_context(|| format!("Failed to resolve A record for hostname: {}", current_host))?; + + Ok(response) + // Ok((current_host, a_response)) +} + +/// Returns the first IP address resolved from the provided hostname using the hickory resolver for caching. +async fn lookup_with_resolver( + resolver: &TokioResolver, + name: &str, +) -> Result { + // Parse out the hostname and port from name (format: "hostname:port") + let (host, port_str) = name + .rsplit_once(':') + .ok_or_else(|| anyhow::anyhow!("Invalid address format, expected 'hostname:port'"))?; + let port: u16 = port_str + .parse() + .with_context(|| format!("Invalid port in address: {}", name))?; + + // Follow the CNAME chain and get A records, ensuring each step is cached + let ip_response = follow_cname_chain(resolver, host).await?; + + // Extract IP address from A records + let maybe_socket_addr = ip_response + .iter() + .find_map(|r| Some(SocketAddr::new(r, port))); + + if let Some(addr) = maybe_socket_addr { + Ok(addr) + } else { + anyhow::bail!("No A records found in DNS response") + } } #[derive(Debug)] diff --git a/src/balancerd/tests/server.rs b/src/balancerd/tests/server.rs index 90acb6aeef064..86e7cb8403ac2 100644 --- a/src/balancerd/tests/server.rs +++ b/src/balancerd/tests/server.rs @@ -16,12 +16,12 @@ use std::sync::Arc; use std::time::Duration; use chrono::Utc; -use domain::resolv::StubResolver; use futures::StreamExt; +use hickory_resolver::{Resolver, config::*, name_server::TokioConnectionProvider, system_conf::read_system_conf}; use jsonwebtoken::{DecodingKey, EncodingKey}; use mz_balancerd::{ - BUILD_INFO, BalancerConfig, BalancerService, CancellationResolver, FronteggResolver, Resolver, - SniResolver, + BUILD_INFO, BalancerConfig, BalancerResolver, BalancerService, CancellationResolver, + FronteggResolver, SniResolver, }; use mz_environmentd::test_util::{self, Ca, make_pg_tls}; use mz_frontegg_auth::{ @@ -140,17 +140,34 @@ async fn test_balancer() { let resolvers = vec![ ( - Resolver::Static(envd_server.sql_local_addr().to_string()), + BalancerResolver::Static( + { + // Use system DNS configuration for static resolver in tests + let (config, opts) = read_system_conf() + .unwrap_or_else(|_| (ResolverConfig::default(), ResolverOpts::default())); + Resolver::builder_with_config(config, TokioConnectionProvider::default()) + .with_options(opts) + .build() + }, + envd_server.sql_local_addr().to_string(), + ), CancellationResolver::Static(envd_server.sql_local_addr().to_string()), ), ( - Resolver::MultiTenant( + BalancerResolver::MultiTenant( + { + // Use system DNS configuration in tests + let (config, opts) = read_system_conf() + .unwrap_or_else(|_| (ResolverConfig::default(), ResolverOpts::default())); + Resolver::builder_with_config(config, TokioConnectionProvider::default()) + .with_options(opts) + .build() + }, FronteggResolver { auth: frontegg_auth, addr_template: envd_server.sql_local_addr().to_string(), }, Some(SniResolver { - resolver: StubResolver::new(), template: envd_server.sql_local_addr().ip().to_string(), port: envd_server.sql_local_addr().port(), }), @@ -176,7 +193,7 @@ async fn test_balancer() { for (resolver, cancellation_resolver) in resolvers { let (mut reload_tx, reload_rx) = futures::channel::mpsc::channel(1); let ticker = Box::pin(reload_rx); - let is_multi_tenant_resolver = matches!(resolver, Resolver::MultiTenant(_, _)); + let is_multi_tenant_resolver = matches!(resolver, BalancerResolver::MultiTenant(_, _, _)); let balancer_cfg = BalancerConfig::new( &BUILD_INFO, SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), diff --git a/test/balancerd/mzcompose.py b/test/balancerd/mzcompose.py index dc4f4e4ea1e62..3595f699ce783 100644 --- a/test/balancerd/mzcompose.py +++ b/test/balancerd/mzcompose.py @@ -111,7 +111,7 @@ def app_password(email: str) -> str: # balancerd. DnsmasqEntry( type="cname", - key="materialized", + key="sni.test", value="environmentd.environment-58cd23ff-a4d7-4bd0-ad85-a6ff29cc86c3-0.svc.cluster.local", ), DnsmasqEntry( @@ -133,8 +133,8 @@ def app_password(email: str) -> str: "--frontegg-jwk-file=/secrets/frontegg-mock.crt", f"--frontegg-api-token-url={FRONTEGG_URL}/identity/resources/auth/v1/api-token", f"--frontegg-admin-role={ADMIN_ROLE}", - "--https-sni-resolver-template=materialized:6876", - "--pgwire-sni-resolver-template=materialized:6875", + "--https-sni-resolver-template=sni.test:6876", + "--pgwire-sni-resolver-template=sni.test:6875", "--tls-key=/secrets/balancerd.key", "--tls-cert=/secrets/balancerd.crt", "--internal-tls", @@ -659,9 +659,8 @@ def workflow_user(c: Composition) -> None: "--frontegg-jwk-file=/secrets/frontegg-mock.crt", f"--frontegg-api-token-url={FRONTEGG_URL}/identity/resources/auth/v1/api-token", f"--frontegg-admin-role={ADMIN_ROLE}", - "--https-sni-resolver-template=materialized:6876", - # We want to use the frontegg resolver in this - "--pgwire-sni-resolver-template=materialized:6875", + "--https-sni-resolver-template=sni.test:6876", + "--pgwire-sni-resolver-template=sni.test:6875", "--tls-key=/secrets/balancerd.key", "--tls-cert=/secrets/balancerd.crt", "--internal-tls", diff --git a/test/dnsmasq/entrypoint.sh b/test/dnsmasq/entrypoint.sh index 708fa34e9a293..90f2c1a3aebbb 100644 --- a/test/dnsmasq/entrypoint.sh +++ b/test/dnsmasq/entrypoint.sh @@ -57,5 +57,6 @@ done --no-resolv \ --keep-in-foreground \ --log-queries \ + --no-daemon \ --log-facility=- \ --conf-file="$OVR" From ce26213716881468ae1879ef9e6eff083a0c4a15 Mon Sep 17 00:00:00 2001 From: Justin Bradfield Date: Mon, 27 Oct 2025 10:30:00 -0500 Subject: [PATCH 2/2] introduce tenant dns caching for CNAMES only --- Cargo.lock | 6 +- src/balancerd/src/bin/balancerd.rs | 36 +---- src/balancerd/src/lib.rs | 210 +++++++++++++++++------------ src/balancerd/tests/server.rs | 22 +-- src/workspace-hack/Cargo.toml | 4 + 5 files changed, 134 insertions(+), 144 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4a5b787816123..9f357f5e9dfc6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3759,7 +3759,7 @@ dependencies = [ "once_cell", "rand 0.9.2", "ring", - "thiserror 2.0.12", + "thiserror 2.0.17", "tinyvec", "tokio", "tracing", @@ -3782,7 +3782,7 @@ dependencies = [ "rand 0.9.2", "resolv-conf", "smallvec", - "thiserror 2.0.12", + "thiserror 2.0.17", "tokio", "tracing", ] @@ -13476,6 +13476,7 @@ dependencies = [ "hyper-util", "idna", "insta", + "ipnet", "itertools 0.13.0", "libc", "libz-sys", @@ -13498,6 +13499,7 @@ dependencies = [ "num-bigint", "num-integer", "num-traits", + "once_cell", "openssl", "openssl-sys", "parking_lot", diff --git a/src/balancerd/src/bin/balancerd.rs b/src/balancerd/src/bin/balancerd.rs index ab7a8d579f454..04144cabad607 100644 --- a/src/balancerd/src/bin/balancerd.rs +++ b/src/balancerd/src/bin/balancerd.rs @@ -18,13 +18,10 @@ use std::path::PathBuf; use std::time::Duration; use anyhow::Context; -use hickory_resolver::{ - Resolver, config::*, name_server::TokioConnectionProvider, system_conf::read_system_conf, -}; use jsonwebtoken::DecodingKey; use mz_balancerd::{ BUILD_INFO, BalancerConfig, BalancerResolver, BalancerService, CancellationResolver, - FronteggResolver, SniResolver, create_default_resolver, + FronteggResolver, SniResolver, }; use mz_frontegg_auth::{ Authenticator, AuthenticatorConfig, DEFAULT_REFRESH_DROP_FACTOR, @@ -246,7 +243,7 @@ pub async fn run(args: ServiceArgs, tracing_handle: TracingHandle) -> Result<(), ( BalancerResolver::MultiTenant( - create_default_resolver(), + mz_balancerd::TenantDnsResolver::new(), FronteggResolver { auth, addr_template, @@ -284,35 +281,8 @@ pub async fn run(args: ServiceArgs, tracing_handle: TracingHandle) -> Result<(), }; drop(addrs); - // Create a resolver for static addresses with the same caching configuration - let mut resolver_opts = ResolverOpts::default(); - resolver_opts.cache_size = 10000; - resolver_opts.positive_max_ttl = Some(Duration::from_secs(10)); - resolver_opts.positive_min_ttl = Some(Duration::from_secs(9)); - resolver_opts.negative_min_ttl = Some(Duration::from_secs(1)); - - // Read system DNS configuration or fall back to defaults - let (config, opts) = read_system_conf() - .map(|(config, mut opts)| { - // Override specific options while keeping system DNS servers - opts.cache_size = resolver_opts.cache_size; - opts.positive_max_ttl = resolver_opts.positive_max_ttl; - opts.positive_min_ttl = resolver_opts.positive_min_ttl; - opts.negative_min_ttl = resolver_opts.negative_min_ttl; - (config, opts) - }) - .unwrap_or_else(|err| { - eprintln!("Failed to read system DNS configuration for static resolver, using defaults: {}", err); - (ResolverConfig::default(), resolver_opts) - }); - ( - BalancerResolver::Static( - Resolver::builder_with_config(config, TokioConnectionProvider::default()) - .with_options(opts) - .build(), - addr.clone(), - ), + BalancerResolver::Static(addr.clone()), CancellationResolver::Static(addr), ) } diff --git a/src/balancerd/src/lib.rs b/src/balancerd/src/lib.rs index 529836ac22432..26d436ed31b18 100644 --- a/src/balancerd/src/lib.rs +++ b/src/balancerd/src/lib.rs @@ -356,7 +356,7 @@ impl BalancerService { let port: u16 = port.parse().expect("unexpected port"); let https = HttpsBalancer { - resolver: Arc::from(create_default_resolver()), + resolver: Arc::from(TenantDnsResolver::new()), tls: https_tls, resolve_template: Arc::from(addr), port, @@ -1049,7 +1049,7 @@ async fn cancel_request( } struct HttpsBalancer { - resolver: Arc, + resolver: Arc, tls: Option, resolve_template: Arc, port: u16, @@ -1060,7 +1060,7 @@ struct HttpsBalancer { impl HttpsBalancer { async fn resolve( - resolver: &TokioResolver, + resolver: &TenantDnsResolver, resolve_template: &str, port: u16, servername: Option<&str>, @@ -1086,7 +1086,9 @@ impl HttpsBalancer { let tenant = resolver.tenant(&addr).await; // Now do the regular ip lookup, regardless of if there was a CNAME. - let envd_addr = lookup_with_resolver(resolver, &format!("{addr}:{port}")).await?; + let envd_addr = resolver + .resolve_tenant_from_sni_host(&format!("{addr}:{port}")) + .await?; Ok(ResolvedAddr { addr: envd_addr, @@ -1100,21 +1102,15 @@ trait ResolverExt { async fn tenant(&self, addr: &str) -> Option; } -impl ResolverExt for TokioResolver { +impl ResolverExt for TenantDnsResolver { /// Finds the tenant of a DNS address. Errors or lack of cname resolution here are ok, because /// this is only used for metrics. async fn tenant(&self, addr: &str) -> Option { debug!("resolving tenant for {:?}", addr); - // Lookup the CNAME. If there's a CNAME, find the tenant. - let lookup = self.lookup(addr, RecordType::CNAME).await; - if let Ok(lookup) = lookup { - for record in lookup.iter() { - if let Some(cname) = record.as_cname() { - let cname = cname.to_string(); - debug!("cname: {cname}"); - return extract_tenant_from_cname(&cname); - } - } + // Lookup the CNAME using the caching resolver + if let Ok(Some(cname)) = self.resolve_cname(addr).await { + debug!("cname: {cname}"); + return extract_tenant_from_cname(&cname); } None } @@ -1262,8 +1258,8 @@ impl ClientStream for T {} #[derive(Debug)] pub enum BalancerResolver { - Static(TokioResolver, String), - MultiTenant(TokioResolver, FronteggResolver, Option), + Static(String), + MultiTenant(TenantDnsResolver, FronteggResolver, Option), } impl BalancerResolver { @@ -1309,7 +1305,7 @@ impl BalancerResolver { let sni_addr = sni_addr_template.replace("{}", servername); let tenant = dns_resolver.tenant(&sni_addr).await; let sni_addr = format!("{sni_addr}:{port}"); - let addr = lookup_with_resolver(dns_resolver, &sni_addr).await?; + let addr = dns_resolver.resolve_tenant_from_sni_host(&sni_addr).await?; if tenant.is_some() { debug!("SNI header found for tenant {:?}", tenant); } @@ -1340,7 +1336,7 @@ impl BalancerResolver { let addr = addr_template.replace("{}", &auth_session.tenant_id().to_string()); - let addr = lookup_with_resolver(dns_resolver, &addr).await?; + let addr = dns_resolver.resolve_tenant_from_sni_host(&addr).await?; let tenant = Some(auth_session.tenant_id().to_string()); if tenant.is_some() { debug!("SNI header NOT found for tenant {:?}", tenant); @@ -1361,9 +1357,16 @@ impl BalancerResolver { Ok(resolved_addr) } - BalancerResolver::Static(resolver, addr) => { - // Use the shared resolver instance for caching benefits - let addr = lookup_with_resolver(resolver, addr).await?; + BalancerResolver::Static(addr) => { + // We don't want any caching here so we just use the standard + // tokio resolver. + let addr = if let Some(a) = tokio::net::lookup_host(addr).await?.next() { + a + } else { + error!("{addr} did not resolve to any addresses"); + anyhow::bail!("internal error"); + }; + Ok(ResolvedAddr { addr, password: None, @@ -1408,85 +1411,114 @@ pub fn create_default_resolver() -> TokioResolver { .build() } -/// Follows CNAME chain to resolve a hostname to IP addresses, ensuring each step is cached. -/// -/// hickory's `lookup_ip` only caches the final step when CNAMEs are involved. This function -/// explicitly queries for CNAME records at each step, following the chain until we reach a -/// hostname that resolves to A records. Each query is cached individually by hickory. -/// -/// Returns the final hostname and its A record lookup response. -async fn follow_cname_chain( - resolver: &TokioResolver, - initial_host: &str, -) -> Result { - let mut current_host = initial_host.to_string(); - let max_hops = 10; // Prevent infinite loops in case of misconfigured DNS - - for hop in 0..max_hops { - // Explicitly query for CNAME records to ensure this step is cached - match resolver.lookup(¤t_host, RecordType::CNAME).await { +/// Creates a resolver with caching disabled for DNS lookups. +pub fn create_non_caching_resolver() -> TokioResolver { + let mut resolver_opts = ResolverOpts::default(); + resolver_opts.cache_size = 0; // Disable caching + resolver_opts.ip_strategy = LookupIpStrategy::Ipv4thenIpv6; + + // Read system DNS configuration or fall back to defaults + let (config, opts) = read_system_conf() + .map(|(config, mut opts)| { + // Override specific options while keeping system DNS servers + opts.cache_size = resolver_opts.cache_size; + (config, opts) + }) + .unwrap_or_else(|err| { + warn!( + "Failed to read system DNS configuration, using defaults: {}", + err + ); + (ResolverConfig::default(), resolver_opts) + }); + + Resolver::builder_with_config(config, TokioConnectionProvider::default()) + .with_options(opts) + .build() +} + +/// A resolver that uses separate caching and non-caching resolvers for different record types. +#[derive(Debug)] +pub struct TenantDnsResolver { + caching_resolver: TokioResolver, + non_caching_resolver: TokioResolver, +} + +impl TenantDnsResolver { + /// Creates a new TenantDnsResolver with default caching and non-caching resolvers. + pub fn new() -> Self { + Self { + caching_resolver: create_default_resolver(), + non_caching_resolver: create_non_caching_resolver(), + } + } + + /// Resolves a CNAME record using the caching resolver. + pub async fn resolve_cname(&self, hostname: &str) -> Result, anyhow::Error> { + match self + .caching_resolver + .lookup(hostname, RecordType::CNAME) + .await + { Ok(cname_response) => { // Check if we got a CNAME record back if let Some(cname_record) = cname_response.iter().next() { if let Some(cname_data) = cname_record.as_cname() { - // Follow the CNAME to the next hostname - let next_host = cname_data.to_string(); - tracing::debug!( - hop = hop, - from = %current_host, - to = %next_host, - "Following CNAME" - ); - current_host = next_host; - continue; + return Ok(Some(cname_data.to_string())); } } - // No CNAME found, fall through to A record lookup - break; - } - Err(_) => { - // No CNAME record exists, this is likely the final hostname - break; + Ok(None) } + Err(_) => Ok(None), } } - // Now lookup A records for the final hostname in the chain - // We don't want to cache this - let response = resolver - .lookup_ip(¤t_host) - .await - .with_context(|| format!("Failed to resolve A record for hostname: {}", current_host))?; + /// Resolves A records without caching. + /// Returns a LookupIp containing all A records for the hostname. + pub async fn resolve_arec_no_cache( + &self, + hostname: &str, + ) -> Result { + self.non_caching_resolver + .lookup_ip(hostname) + .await + .with_context(|| format!("Failed to resolve A record for hostname: {}", hostname)) + } + + /// Returns the first IP address resolved from the provided hostname using the DualResolver. + /// + /// CNAMEs are resolved with caching, while A records are resolved without caching. + async fn resolve_tenant_from_sni_host( + self: &TenantDnsResolver, + name: &str, + ) -> Result { + // Parse out the hostname and port from name (format: "hostname:port") + let (host, port_str) = name + .rsplit_once(':') + .ok_or_else(|| anyhow::anyhow!("Invalid address format, expected 'hostname:port'"))?; + let port: u16 = port_str + .parse() + .with_context(|| format!("Invalid port in address: {}", name))?; + + // Resolve initial CNAME with caching if it exists... + // these are generally static and we can generally ignore TTLS + // + // Resolve arec from host without caching... these can change at anytime + // in generally we should ignore TTLs. + let ip = if let Some(resolved_host) = self.resolve_cname(host).await? { + self.resolve_arec_no_cache(&resolved_host).await? + } else { + self.resolve_arec_no_cache(host).await? + }; - Ok(response) - // Ok((current_host, a_response)) -} + // Extract IP address from A records + let maybe_socket_addr = ip.iter().find_map(|r| Some(SocketAddr::new(r, port))); -/// Returns the first IP address resolved from the provided hostname using the hickory resolver for caching. -async fn lookup_with_resolver( - resolver: &TokioResolver, - name: &str, -) -> Result { - // Parse out the hostname and port from name (format: "hostname:port") - let (host, port_str) = name - .rsplit_once(':') - .ok_or_else(|| anyhow::anyhow!("Invalid address format, expected 'hostname:port'"))?; - let port: u16 = port_str - .parse() - .with_context(|| format!("Invalid port in address: {}", name))?; - - // Follow the CNAME chain and get A records, ensuring each step is cached - let ip_response = follow_cname_chain(resolver, host).await?; - - // Extract IP address from A records - let maybe_socket_addr = ip_response - .iter() - .find_map(|r| Some(SocketAddr::new(r, port))); - - if let Some(addr) = maybe_socket_addr { - Ok(addr) - } else { - anyhow::bail!("No A records found in DNS response") + if let Some(addr) = maybe_socket_addr { + Ok(addr) + } else { + anyhow::bail!("No A records found in DNS response") + } } } diff --git a/src/balancerd/tests/server.rs b/src/balancerd/tests/server.rs index 86e7cb8403ac2..d1a99c41e3c90 100644 --- a/src/balancerd/tests/server.rs +++ b/src/balancerd/tests/server.rs @@ -17,7 +17,6 @@ use std::time::Duration; use chrono::Utc; use futures::StreamExt; -use hickory_resolver::{Resolver, config::*, name_server::TokioConnectionProvider, system_conf::read_system_conf}; use jsonwebtoken::{DecodingKey, EncodingKey}; use mz_balancerd::{ BUILD_INFO, BalancerConfig, BalancerResolver, BalancerService, CancellationResolver, @@ -140,29 +139,12 @@ async fn test_balancer() { let resolvers = vec![ ( - BalancerResolver::Static( - { - // Use system DNS configuration for static resolver in tests - let (config, opts) = read_system_conf() - .unwrap_or_else(|_| (ResolverConfig::default(), ResolverOpts::default())); - Resolver::builder_with_config(config, TokioConnectionProvider::default()) - .with_options(opts) - .build() - }, - envd_server.sql_local_addr().to_string(), - ), + BalancerResolver::Static(envd_server.sql_local_addr().to_string()), CancellationResolver::Static(envd_server.sql_local_addr().to_string()), ), ( BalancerResolver::MultiTenant( - { - // Use system DNS configuration in tests - let (config, opts) = read_system_conf() - .unwrap_or_else(|_| (ResolverConfig::default(), ResolverOpts::default())); - Resolver::builder_with_config(config, TokioConnectionProvider::default()) - .with_options(opts) - .build() - }, + mz_balancerd::TenantDnsResolver::new(), FronteggResolver { auth: frontegg_auth, addr_template: envd_server.sql_local_addr().to_string(), diff --git a/src/workspace-hack/Cargo.toml b/src/workspace-hack/Cargo.toml index 2db17242d042a..da82747738035 100644 --- a/src/workspace-hack/Cargo.toml +++ b/src/workspace-hack/Cargo.toml @@ -76,6 +76,7 @@ hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14.5", defaul hyper = { version = "0.14.27", features = ["client", "http1", "http2", "stream", "tcp"] } hyper-util = { version = "0.1.17", features = ["client-legacy", "server-auto", "service", "tracing"] } insta = { version = "1.43.2", features = ["json"] } +ipnet = { version = "2.11.0" } itertools = { version = "0.13.0" } libc = { version = "0.2.177", features = ["extra_traits", "use_std"] } libz-sys = { version = "1.1.22", features = ["static"] } @@ -95,6 +96,7 @@ num = { version = "0.4.3" } num-bigint = { version = "0.4.6", features = ["serde"] } num-integer = { version = "0.1.46", features = ["i128"] } num-traits = { version = "0.2.19", features = ["i128", "libm"] } +once_cell = { version = "1.21.3", features = ["critical-section"] } parking_lot = { version = "0.12.5", features = ["serde"] } parquet = { version = "55.2.0", features = ["async"] } percent-encoding = { version = "2.3.2" } @@ -223,6 +225,7 @@ hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14.5", defaul hyper = { version = "0.14.27", features = ["client", "http1", "http2", "stream", "tcp"] } hyper-util = { version = "0.1.17", features = ["client-legacy", "server-auto", "service", "tracing"] } insta = { version = "1.43.2", features = ["json"] } +ipnet = { version = "2.11.0" } itertools = { version = "0.13.0" } libc = { version = "0.2.177", features = ["extra_traits", "use_std"] } libz-sys = { version = "1.1.22", features = ["static"] } @@ -242,6 +245,7 @@ num = { version = "0.4.3" } num-bigint = { version = "0.4.6", features = ["serde"] } num-integer = { version = "0.1.46", features = ["i128"] } num-traits = { version = "0.2.19", features = ["i128", "libm"] } +once_cell = { version = "1.21.3", features = ["critical-section"] } parking_lot = { version = "0.12.5", features = ["serde"] } parquet = { version = "55.2.0", features = ["async"] } percent-encoding = { version = "2.3.2" }