From dcd060d397b67187bb989b9901c874460d2fe5c8 Mon Sep 17 00:00:00 2001 From: Dov Alperin Date: Thu, 28 Aug 2025 22:55:19 -0400 Subject: [PATCH] sasl/scram --- Cargo.lock | 4 + src/adapter/Cargo.toml | 1 + src/adapter/src/catalog/open.rs | 8 + src/adapter/src/catalog/state.rs | 6 + src/adapter/src/client.rs | 37 +++ src/adapter/src/command.rs | 34 +++ src/adapter/src/coord.rs | 2 + src/adapter/src/coord/command_handler.rs | 180 +++++++++++- src/adapter/src/error.rs | 30 +- src/auth/Cargo.toml | 1 + src/auth/src/hash.rs | 188 ++++++++++++- src/authenticator/src/lib.rs | 1 + src/catalog/Cargo.toml | 2 + src/catalog/src/durable.rs | 1 + src/catalog/src/durable/initialize.rs | 21 +- src/catalog/src/durable/transaction.rs | 19 +- src/catalog/src/memory/error.rs | 2 + src/catalog/tests/debug.rs | 21 +- src/catalog/tests/open.rs | 11 +- .../tests/snapshots/debug__opened_trace.snap | 1 - src/environmentd/src/environmentd/main.rs | 12 + src/environmentd/src/http.rs | 8 + src/environmentd/src/lib.rs | 2 + src/environmentd/src/test_util.rs | 48 ++++ src/environmentd/tests/auth.rs | 101 +++++++ .../ci/listener_configs/password.json | 6 + src/pgwire-common/src/lib.rs | 5 +- src/pgwire-common/src/message.rs | 53 ++++ src/pgwire/Cargo.toml | 1 + src/pgwire/src/codec.rs | 256 +++++++++++++++++- src/pgwire/src/message.rs | 23 ++ src/pgwire/src/protocol.rs | 229 +++++++++++++++- src/server-core/src/listeners.rs | 15 + 33 files changed, 1277 insertions(+), 52 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bae290459b1c4..d76b98f49baa0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5378,6 +5378,7 @@ dependencies = [ "arrow", "async-stream", "async-trait", + "base64 0.22.1", "bytes", "bytesize", "chrono", @@ -5536,6 +5537,7 @@ name = "mz-auth" version = "0.0.0" dependencies = [ "base64 0.22.1", + "itertools 0.14.0", "mz-ore", "openssl", "proptest", @@ -5743,6 +5745,7 @@ dependencies = [ "mz-storage-client", "mz-storage-types", "mz-transform", + "openssl", "paste", "prometheus", "proptest", @@ -7293,6 +7296,7 @@ version = "0.0.0" dependencies = [ "anyhow", "async-trait", + "base64 0.22.1", "byteorder", "bytes", "bytesize", diff --git a/src/adapter/Cargo.toml b/src/adapter/Cargo.toml index bb1e301285dcf..9a13a0bee7b3f 100644 --- a/src/adapter/Cargo.toml +++ b/src/adapter/Cargo.toml @@ -14,6 +14,7 @@ anyhow = "1.0.98" arrow = { version = "55.2.0", default-features = false } async-stream = "0.3.6" async-trait = "0.1.88" +base64 = "0.22.1" bytes = "1.10.1" bytesize = "1.3.0" chrono = { version = "0.4.39", default-features = false, features = ["std"] } diff --git a/src/adapter/src/catalog/open.rs b/src/adapter/src/catalog/open.rs index 5af12a57cd2bf..045f6813c2bb7 100644 --- a/src/adapter/src/catalog/open.rs +++ b/src/adapter/src/catalog/open.rs @@ -151,6 +151,7 @@ impl Catalog { source_references: BTreeMap::new(), storage_metadata: Default::default(), temporary_schemas: BTreeMap::new(), + mock_authentication_nonce: Default::default(), config: mz_sql::catalog::CatalogConfig { start_time: to_datetime((config.now)()), start_instant: Instant::now(), @@ -401,6 +402,13 @@ impl Catalog { .unwrap_or("new") .to_string(); + let mz_authentication_mock_nonce = + txn.get_authentication_mock_nonce().ok_or_else(|| { + Error::new(ErrorKind::SettingError("authentication nonce".to_string())) + })?; + + state.mock_authentication_nonce = Some(mz_authentication_mock_nonce); + // Migrate item ASTs. let builtin_table_update = if !config.skip_migrations { let migrate_result = migrate::migrate( diff --git a/src/adapter/src/catalog/state.rs b/src/adapter/src/catalog/state.rs index e5da2ad938eb3..ec31a5c99f37c 100644 --- a/src/adapter/src/catalog/state.rs +++ b/src/adapter/src/catalog/state.rs @@ -142,6 +142,7 @@ pub struct CatalogState { #[serde(serialize_with = "mz_ore::serde::map_key_to_string")] pub(super) source_references: BTreeMap, pub(super) storage_metadata: StorageMetadata, + pub(super) mock_authentication_nonce: Option, // Mutable state not derived from the durable catalog. #[serde(skip)] @@ -316,6 +317,7 @@ impl CatalogState { source_references: Default::default(), storage_metadata: Default::default(), license_key: ValidatedLicenseKey::for_tests(), + mock_authentication_nonce: Default::default(), } } @@ -2635,6 +2637,10 @@ impl CatalogState { CommentObjectId::NetworkPolicy(id) => self.get_network_policy(&id).name.clone(), } } + + pub fn mock_authentication_nonce(&self) -> String { + self.mock_authentication_nonce.clone().unwrap_or_default() + } } impl ConnectionResolver for CatalogState { diff --git a/src/adapter/src/client.rs b/src/adapter/src/client.rs index 4441e2fd5e291..6208553ed53e0 100644 --- a/src/adapter/src/client.rs +++ b/src/adapter/src/client.rs @@ -50,6 +50,7 @@ use uuid::Uuid; use crate::catalog::Catalog; use crate::command::{ AuthResponse, CatalogDump, CatalogSnapshot, Command, ExecuteResponse, Response, + SASLChallengeResponse, SASLVerifyProofResponse, }; use crate::coord::{Coordinator, ExecuteContextExtra}; use crate::error::AdapterError; @@ -168,6 +169,40 @@ impl Client { Ok(response) } + pub async fn generate_sasl_challenge( + &self, + user: &String, + client_nonce: &String, + ) -> Result { + let (tx, rx) = oneshot::channel(); + self.send(Command::AuthenticateGetSASLChallenge { + role_name: user.to_string(), + nonce: client_nonce.to_string(), + tx, + }); + let response = rx.await.expect("sender dropped")?; + Ok(response) + } + + pub async fn verify_sasl_proof( + &self, + user: &String, + proof: &String, + nonce: &String, + mock_hash: &String, + ) -> Result { + let (tx, rx) = oneshot::channel(); + self.send(Command::AuthenticateVerifySASLProof { + role_name: user.to_string(), + proof: proof.to_string(), + auth_message: nonce.to_string(), + mock_hash: mock_hash.to_string(), + tx, + }); + let response = rx.await.expect("sender dropped")?; + Ok(response) + } + /// Upgrades this client to a session client. /// /// A session is a connection that has successfully negotiated parameters, @@ -927,6 +962,8 @@ impl SessionClient { Command::GetWebhook { .. } => typ = Some("webhook"), Command::Startup { .. } | Command::AuthenticatePassword { .. } + | Command::AuthenticateGetSASLChallenge { .. } + | Command::AuthenticateVerifySASLProof { .. } | Command::CatalogSnapshot { .. } | Command::Commit { .. } | Command::CancelRequest { .. } diff --git a/src/adapter/src/command.rs b/src/adapter/src/command.rs index d27209f6b8d1c..32c62dc7742ed 100644 --- a/src/adapter/src/command.rs +++ b/src/adapter/src/command.rs @@ -73,6 +73,20 @@ pub enum Command { password: Option, }, + AuthenticateGetSASLChallenge { + tx: oneshot::Sender>, + role_name: String, + nonce: String, + }, + + AuthenticateVerifySASLProof { + tx: oneshot::Sender>, + role_name: String, + proof: String, + auth_message: String, + mock_hash: String, + }, + Execute { portal_name: String, session: Session, @@ -148,6 +162,8 @@ impl Command { Command::CancelRequest { .. } | Command::Startup { .. } | Command::AuthenticatePassword { .. } + | Command::AuthenticateGetSASLChallenge { .. } + | Command::AuthenticateVerifySASLProof { .. } | Command::CatalogSnapshot { .. } | Command::PrivilegedCancelRequest { .. } | Command::GetWebhook { .. } @@ -166,6 +182,8 @@ impl Command { Command::CancelRequest { .. } | Command::Startup { .. } | Command::AuthenticatePassword { .. } + | Command::AuthenticateGetSASLChallenge { .. } + | Command::AuthenticateVerifySASLProof { .. } | Command::CatalogSnapshot { .. } | Command::PrivilegedCancelRequest { .. } | Command::GetWebhook { .. } @@ -210,6 +228,22 @@ pub struct AuthResponse { pub superuser: bool, } +#[derive(Derivative)] +#[derivative(Debug)] +pub struct SASLChallengeResponse { + pub iteration_count: usize, + /// Base64-encoded salt for the SASL challenge. + pub salt: String, + pub nonce: String, +} + +#[derive(Derivative)] +#[derivative(Debug)] +pub struct SASLVerifyProofResponse { + pub verifier: String, + pub auth_resp: AuthResponse, +} + // Facile implementation for `StartupResponse`, which does not use the `allowed` // feature of `ClientTransmitter`. impl Transmittable for StartupResponse { diff --git a/src/adapter/src/coord.rs b/src/adapter/src/coord.rs index 84c807435dd63..c05b2f88cd92b 100644 --- a/src/adapter/src/coord.rs +++ b/src/adapter/src/coord.rs @@ -356,6 +356,8 @@ impl Message { Command::CheckConsistency { .. } => "command-check_consistency", Command::Dump { .. } => "command-dump", Command::AuthenticatePassword { .. } => "command-auth_check", + Command::AuthenticateGetSASLChallenge { .. } => "command-auth_get_sasl_challenge", + Command::AuthenticateVerifySASLProof { .. } => "command-auth_verify_sasl_proof", }, Message::ControllerReady { controller: ControllerReadiness::Compute, diff --git a/src/adapter/src/coord/command_handler.rs b/src/adapter/src/coord/command_handler.rs index 3fac72718e41d..9ed9186e4fae6 100644 --- a/src/adapter/src/coord/command_handler.rs +++ b/src/adapter/src/coord/command_handler.rs @@ -10,6 +10,7 @@ //! Logic for processing client [`Command`]s. Each [`Command`] is initiated by a //! client via some external Materialize API (ex: HTTP and psql). +use base64::prelude::*; use differential_dataflow::lattice::Lattice; use mz_adapter_types::dyncfgs::ALLOW_USER_SESSIONS; use mz_auth::password::Password; @@ -59,13 +60,16 @@ use tokio::sync::{mpsc, oneshot}; use tracing::{Instrument, debug_span, info, warn}; use tracing_opentelemetry::OpenTelemetrySpanExt; -use crate::command::{AuthResponse, CatalogSnapshot, Command, ExecuteResponse, StartupResponse}; +use crate::command::{ + AuthResponse, CatalogSnapshot, Command, ExecuteResponse, SASLChallengeResponse, + SASLVerifyProofResponse, StartupResponse, +}; use crate::coord::appends::PendingWriteTxn; use crate::coord::{ ConnMeta, Coordinator, DeferredPlanStatement, Message, PendingTxn, PlanStatement, PlanValidity, PurifiedStatementReady, validate_ip_with_policy_rules, }; -use crate::error::AdapterError; +use crate::error::{AdapterError, AuthenticationError}; use crate::notice::AdapterNotice; use crate::session::{Session, TransactionOps, TransactionStatus}; use crate::util::{ClientTransmitter, ResultExt}; @@ -120,6 +124,31 @@ impl Coordinator { .await; } + Command::AuthenticateGetSASLChallenge { + tx, + role_name, + nonce, + } => { + self.handle_generate_sasl_challenge(tx, role_name, nonce) + .await; + } + + Command::AuthenticateVerifySASLProof { + tx, + role_name, + proof, + mock_hash, + auth_message, + } => { + self.handle_authenticate_verify_sasl_proof( + tx, + role_name, + proof, + auth_message, + mock_hash, + ); + } + Command::Execute { portal_name, session, @@ -246,6 +275,133 @@ impl Coordinator { .boxed_local() } + fn handle_authenticate_verify_sasl_proof( + &self, + tx: oneshot::Sender>, + role_name: String, + proof: String, + auth_message: String, + mock_hash: String, + ) { + let role = self.catalog().try_get_role_by_name(role_name.as_str()); + let role_auth = role.and_then(|r| self.catalog().try_get_role_auth_by_id(&r.id)); + + let login = role + .as_ref() + .map(|r| r.attributes.login.unwrap_or(false)) + .unwrap_or(false); + + let real_hash = role_auth + .as_ref() + .and_then(|auth| auth.password_hash.as_ref()); + let hash_ref = real_hash.map(|s| s.as_str()).unwrap_or(&mock_hash); + + let role_present = role.is_some(); + let make_auth_err = |role_present: bool, login: bool| { + AdapterError::AuthenticationError(if role_present && !login { + AuthenticationError::NonLogin + } else { + AuthenticationError::InvalidCredentials + }) + }; + + match mz_auth::hash::sasl_verify(hash_ref, &proof, &auth_message) { + Ok(verifier) => { + // Success only if role exists, allows login, and a real password hash was used. + if login && real_hash.is_some() { + let role = role.expect("login implies role exists"); + let _ = tx.send(Ok(SASLVerifyProofResponse { + verifier, + auth_resp: AuthResponse { + role_id: role.id, + superuser: role.attributes.superuser.unwrap_or(false), + }, + })); + } else { + let _ = tx.send(Err(make_auth_err(role_present, login))); + } + } + Err(_) => { + let _ = tx.send(Err(AdapterError::AuthenticationError( + AuthenticationError::InvalidCredentials, + ))); + } + } + } + + #[mz_ore::instrument(level = "debug")] + async fn handle_generate_sasl_challenge( + &mut self, + tx: oneshot::Sender>, + role_name: String, + client_nonce: String, + ) { + let role_auth = self + .catalog() + .try_get_role_by_name(&role_name) + .and_then(|role| self.catalog().try_get_role_auth_by_id(&role.id)); + + let nonce = match mz_auth::hash::generate_nonce(&client_nonce) { + Ok(n) => n, + Err(e) => { + let msg = format!( + "failed to generate nonce for client nonce {}: {}", + client_nonce, e + ); + let _ = tx.send(Err(AdapterError::Internal(msg.clone()))); + soft_panic_or_log!("{msg}"); + return; + } + }; + + // It's important that the mock_nonce is deterministic per role, otherwise the purpose of + // doing mock authentication is defeated. We use a catalog-wide nonce, and combine that + // with the role name to get a per-role mock nonce. + let send_mock_challenge = + |role_name: String, + mock_nonce: String, + nonce: String, + tx: oneshot::Sender>| { + let opts = mz_auth::hash::mock_sasl_challenge(&role_name, &mock_nonce); + let _ = tx.send(Ok(SASLChallengeResponse { + iteration_count: mz_ore::cast::u32_to_usize(opts.iterations.get()), + salt: BASE64_STANDARD.encode(opts.salt), + nonce, + })); + }; + + match role_auth { + Some(auth) if auth.password_hash.is_some() => { + let hash = auth.password_hash.as_ref().expect("checked above"); + match mz_auth::hash::scram256_parse_opts(hash) { + Ok(opts) => { + let _ = tx.send(Ok(SASLChallengeResponse { + iteration_count: mz_ore::cast::u32_to_usize(opts.iterations.get()), + salt: BASE64_STANDARD.encode(opts.salt), + nonce, + })); + } + Err(_) => { + send_mock_challenge( + role_name, + self.catalog().state().mock_authentication_nonce(), + nonce, + tx, + ); + } + } + } + _ => { + send_mock_challenge( + role_name, + self.catalog().state().mock_authentication_nonce(), + nonce, + tx, + ); + } + } + } + #[mz_ore::instrument(level = "debug")] async fn handle_authenticate_password( &mut self, @@ -255,14 +411,18 @@ impl Coordinator { ) { let Some(password) = password else { // The user did not provide a password. - let _ = tx.send(Err(AdapterError::AuthenticationError)); + let _ = tx.send(Err(AdapterError::AuthenticationError( + AuthenticationError::PasswordRequired, + ))); return; }; if let Some(role) = self.catalog().try_get_role_by_name(role_name.as_str()) { if !role.attributes.login.unwrap_or(false) { // The user is not allowed to login. - let _ = tx.send(Err(AdapterError::AuthenticationError)); + let _ = tx.send(Err(AdapterError::AuthenticationError( + AuthenticationError::NonLogin, + ))); return; } if let Some(auth) = self.catalog().try_get_role_auth_by_id(&role.id) { @@ -272,16 +432,22 @@ impl Coordinator { role_id: role.id, superuser: role.attributes.superuser.unwrap_or(false), })), - Err(_) => tx.send(Err(AdapterError::AuthenticationError)), + Err(_) => tx.send(Err(AdapterError::AuthenticationError( + AuthenticationError::InvalidCredentials, + ))), }; return; } } // Authentication failed due to incorrect password or missing password hash. - let _ = tx.send(Err(AdapterError::AuthenticationError)); + let _ = tx.send(Err(AdapterError::AuthenticationError( + AuthenticationError::InvalidCredentials, + ))); } else { // The user does not exist. - let _ = tx.send(Err(AdapterError::AuthenticationError)); + let _ = tx.send(Err(AdapterError::AuthenticationError( + AuthenticationError::RoleNotFound, + ))); } } diff --git a/src/adapter/src/error.rs b/src/adapter/src/error.rs index 6b5ba1123df51..30f5eef280ffd 100644 --- a/src/adapter/src/error.rs +++ b/src/adapter/src/error.rs @@ -238,15 +238,20 @@ pub enum AdapterError { /// read-only mode. ReadOnly, AlterClusterTimeout, - /// Authentication error. This is specifically for self-managed auth - /// and can generally encompass things like "incorrect password" or - /// what have you. We intentionally limit the fidelity of the error - /// we return to avoid allowing an attacker to, for example, - /// enumerate users by spraying login attempts and differentiating - /// between a "no such user" and "incorrect password" error. - AuthenticationError, - /// An ALTER CLUSTER was attempted while a graceful cluster reconfiguration was in progress. AlterClusterWhilePendingReplicas, + AuthenticationError(AuthenticationError), +} + +#[derive(Debug, thiserror::Error)] +pub enum AuthenticationError { + #[error("invalid credentials")] + InvalidCredentials, + #[error("role is not allowed to login")] + NonLogin, + #[error("role does not exist")] + RoleNotFound, + #[error("password is required")] + PasswordRequired, } impl AdapterError { @@ -588,8 +593,11 @@ impl AdapterError { // transactions. AdapterError::ReadOnly => SqlState::READ_ONLY_SQL_TRANSACTION, AdapterError::AlterClusterTimeout => SqlState::QUERY_CANCELED, - AdapterError::AuthenticationError => SqlState::INVALID_AUTHORIZATION_SPECIFICATION, AdapterError::AlterClusterWhilePendingReplicas => SqlState::OBJECT_IN_USE, + AdapterError::AuthenticationError(AuthenticationError::InvalidCredentials) => { + SqlState::INVALID_PASSWORD + } + AdapterError::AuthenticationError(_) => SqlState::INVALID_AUTHORIZATION_SPECIFICATION, } } @@ -824,8 +832,8 @@ impl fmt::Display for AdapterError { AdapterError::AlterClusterTimeout => { write!(f, "canceling statement, provided timeout lapsed") } - AdapterError::AuthenticationError => { - write!(f, "authentication error") + AdapterError::AuthenticationError(e) => { + write!(f, "authentication error {e}") } AdapterError::UnavailableFeature { feature, docs } => { write!(f, "{} is not supported in this environment.", feature)?; diff --git a/src/auth/Cargo.toml b/src/auth/Cargo.toml index 4bd43563fdd62..21a37a4dd15b2 100644 --- a/src/auth/Cargo.toml +++ b/src/auth/Cargo.toml @@ -18,6 +18,7 @@ proptest-derive = "0.5.1" proptest = "1.7.0" static_assertions = "1.1" openssl = { version = "0.10.73", features = ["vendored"] } +itertools = "0.14.0" [features] default = ["workspace-hack"] diff --git a/src/auth/src/hash.rs b/src/auth/src/hash.rs index f25baa4c35c05..368032293f0d8 100644 --- a/src/auth/src/hash.rs +++ b/src/auth/src/hash.rs @@ -14,6 +14,7 @@ use std::fmt::Display; use std::num::NonZeroU32; use base64::prelude::*; +use itertools::Itertools; use crate::password::Password; @@ -27,6 +28,7 @@ const DEFAULT_SALT_SIZE: usize = 32; const SHA256_OUTPUT_LEN: usize = 32; /// The options for hashing a password +#[derive(Debug, PartialEq)] pub struct HashOpts { /// The number of iterations to use for PBKDF2 pub iterations: NonZeroU32, @@ -58,6 +60,14 @@ pub enum HashError { Openssl(openssl::error::ErrorStack), } +impl Display for HashError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + HashError::Openssl(e) => write!(f, "OpenSSL error: {}", e), + } + } +} + /// Hashes a password using PBKDF2 with SHA256 /// and a random salt. pub fn hash_password(password: &Password) -> Result { @@ -79,6 +89,14 @@ pub fn hash_password(password: &Password) -> Result { }) } +pub fn generate_nonce(client_nonce: &str) -> Result { + let mut nonce = [0u8; 24]; + openssl::rand::rand_bytes(&mut nonce).map_err(HashError::Openssl)?; + let nonce = BASE64_STANDARD.encode(&nonce); + let new_nonce = format!("{}{}", client_nonce, nonce); + Ok(new_nonce) +} + /// Hashes a password using PBKDF2 with SHA256 /// and the given options. pub fn hash_password_with_opts( @@ -102,20 +120,111 @@ pub fn scram256_hash(password: &Password) -> Result { Ok(scram256_hash_inner(hashed_password).to_string()) } +fn constant_time_compare(a: &[u8], b: &[u8]) -> bool { + if a.len() != b.len() { + return false; + } + openssl::memcmp::eq(a, b) +} + /// Verifies a password against a SCRAM-SHA-256 hash. pub fn scram256_verify(password: &Password, hashed_password: &str) -> Result<(), VerifyError> { let opts = scram256_parse_opts(hashed_password)?; let hashed = hash_password_with_opts(&opts, password).map_err(VerifyError::Hash)?; let scram = scram256_hash_inner(hashed); - if *hashed_password == scram.to_string() { + if constant_time_compare(hashed_password.as_bytes(), scram.to_string().as_bytes()) { Ok(()) } else { Err(VerifyError::InvalidPassword) } } +pub fn sasl_verify( + hashed_password: &str, + proof: &str, + auth_message: &str, +) -> Result { + // Parse SCRAM hash: SCRAM-SHA-256$:$: + let parts: Vec<&str> = hashed_password.split('$').collect(); + if parts.len() != 3 { + return Err(VerifyError::MalformedHash); + } + let auth_info = parts[1].split(':').collect::>(); + if auth_info.len() != 2 { + return Err(VerifyError::MalformedHash); + } + let auth_value = parts[2].split(':').collect::>(); + if auth_value.len() != 2 { + return Err(VerifyError::MalformedHash); + } + + let client_key = BASE64_STANDARD + .decode(auth_value[0]) + .map_err(|_| VerifyError::MalformedHash)?; + let server_key = BASE64_STANDARD + .decode(auth_value[1]) + .map_err(|_| VerifyError::MalformedHash)?; + + // Compute stored key + let stored_key = openssl::sha::sha256(&client_key); + + // Compute client signature: HMAC(stored_key, auth_message) + let client_signature = generate_signature(&stored_key, auth_message)?; + + // Compute expected client proof: client_key XOR client_signature + let expected_client_proof: Vec = client_key + .iter() + .zip_eq(client_signature.iter()) + .map(|(a, b)| a ^ b) + .collect(); + + // Decode provided proof + let provided_client_proof = BASE64_STANDARD + .decode(proof) + .map_err(|_| VerifyError::InvalidPassword)?; + + if constant_time_compare(&expected_client_proof, &provided_client_proof) { + // Compute server verifier: HMAC(server_key, auth_message) + let verifier = generate_signature(&server_key, auth_message)?; + let verifier = BASE64_STANDARD.encode(&verifier); + Ok(verifier) + } else { + Err(VerifyError::InvalidPassword) + } +} + +fn generate_signature(key: &[u8], message: &str) -> Result, VerifyError> { + let signing_key = + openssl::pkey::PKey::hmac(key).map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?; + let mut signer = + openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &signing_key) + .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?; + signer + .update(message.as_bytes()) + .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?; + let signature = signer + .sign_to_vec() + .map_err(|e| VerifyError::Hash(HashError::Openssl(e)))?; + Ok(signature) +} + +// Generate a mock challenge based on the username and client nonce +// We do this so that we can present a deterministic challenge even for +// nonexistent users, to avoid user enumeration attacks. +pub fn mock_sasl_challenge(username: &str, mock_nonce: &str) -> HashOpts { + let mut buf = Vec::with_capacity(username.len() + mock_nonce.len()); + buf.extend_from_slice(username.as_bytes()); + buf.extend_from_slice(mock_nonce.as_bytes()); + let digest = openssl::sha::sha256(&buf); + + HashOpts { + iterations: DEFAULT_ITERATIONS, + salt: digest, + } +} + /// Parses a SCRAM-SHA-256 hash and returns the options used to create it. -fn scram256_parse_opts(hashed_password: &str) -> Result { +pub fn scram256_parse_opts(hashed_password: &str) -> Result { let parts: Vec<&str> = hashed_password.split('$').collect(); if parts.len() != 3 { return Err(VerifyError::MalformedHash); @@ -211,6 +320,8 @@ fn hash_password_inner( #[cfg(test)] mod tests { + use itertools::Itertools; + use super::*; #[mz_ore::test] @@ -248,4 +359,77 @@ mod tests { let decoded_salt = BASE64_STANDARD.decode(salt).expect("Failed to decode salt"); assert_eq!(opts.salt, decoded_salt.as_ref()); } + + #[mz_ore::test] + #[cfg_attr(miri, ignore)] + fn test_mock_sasl_challenge() { + let username = "alice"; + let mock = "cnonce"; + let opts1 = mock_sasl_challenge(username, mock); + let opts2 = mock_sasl_challenge(username, mock); + assert_eq!(opts1, opts2); + } + + #[mz_ore::test] + #[cfg_attr(miri, ignore)] + fn test_sasl_verify_success() { + let password: Password = "password".into(); + let hashed_password = scram256_hash(&password).expect("hash password"); + let auth_message = "n=user,r=clientnonce,s=somesalt"; // arbitrary auth message + + // Parse client_key and server_key from the SCRAM hash + // Format: SCRAM-SHA-256$:$: + let parts: Vec<&str> = hashed_password.split('$').collect(); + assert_eq!(parts.len(), 3); + let key_parts: Vec<&str> = parts[2].split(':').collect(); + assert_eq!(key_parts.len(), 2); + let client_key = BASE64_STANDARD + .decode(key_parts[0]) + .expect("decode client key"); + let server_key = BASE64_STANDARD + .decode(key_parts[1]) + .expect("decode server key"); + + // stored_key = SHA256(client_key) + let stored_key = openssl::sha::sha256(&client_key); + // client_signature = HMAC(stored_key, auth_message) + let client_signature = + generate_signature(&stored_key, auth_message).expect("client signature"); + // client_proof = client_key XOR client_signature + let client_proof: Vec = client_key + .iter() + .zip_eq(client_signature.iter()) + .map(|(a, b)| a ^ b) + .collect(); + let client_proof_b64 = BASE64_STANDARD.encode(&client_proof); + + let verifier = sasl_verify(&hashed_password, &client_proof_b64, auth_message) + .expect("sasl_verify should succeed"); + + // Expected verifier: HMAC(server_key, auth_message) + let expected_verifier = BASE64_STANDARD + .encode(&generate_signature(&server_key, auth_message).expect("server verifier")); + assert_eq!(verifier, expected_verifier); + } + + #[mz_ore::test] + #[cfg_attr(miri, ignore)] + fn test_sasl_verify_invalid_proof() { + let password: Password = "password".into(); + let hashed_password = scram256_hash(&password).expect("hash password"); + let auth_message = "n=user,r=clientnonce,s=somesalt"; + // Provide an obviously invalid base64 proof (different size / random) + let bad_proof = BASE64_STANDARD.encode([0u8; 32]); + let res = sasl_verify(&hashed_password, &bad_proof, auth_message); + assert!(matches!(res, Err(VerifyError::InvalidPassword))); + } + + #[mz_ore::test] + fn test_sasl_verify_malformed_hash() { + let malformed_hash = "NOT-SCRAM$bad"; // clearly malformed (wrong parts count) + let auth_message = "n=user,r=clientnonce,s=somesalt"; + let bad_proof = BASE64_STANDARD.encode([0u8; 32]); + let res = sasl_verify(malformed_hash, &bad_proof, auth_message); + assert!(matches!(res, Err(VerifyError::MalformedHash))); + } } diff --git a/src/authenticator/src/lib.rs b/src/authenticator/src/lib.rs index 4bacc55b906e8..6e1527d13b742 100644 --- a/src/authenticator/src/lib.rs +++ b/src/authenticator/src/lib.rs @@ -14,5 +14,6 @@ use mz_frontegg_auth::Authenticator as FronteggAuthenticator; pub enum Authenticator { Frontegg(FronteggAuthenticator), Password(AdapterClient), + Sasl(AdapterClient), None, } diff --git a/src/catalog/Cargo.toml b/src/catalog/Cargo.toml index 4bd2886f2903a..6db7c9645228c 100644 --- a/src/catalog/Cargo.toml +++ b/src/catalog/Cargo.toml @@ -12,6 +12,7 @@ workspace = true [dependencies] anyhow = "1.0.98" async-trait = "0.1.88" +base64 = "0.22.1" bincode = { version = "1.3.3" } bytes = { version = "1.10.1", features = ["serde"] } bytesize = "1.3.0" @@ -49,6 +50,7 @@ mz-sql-parser = { path = "../sql-parser" } mz-storage-client = { path = "../storage-client" } mz-storage-types = { path = "../storage-types" } mz-transform = { path = "../transform" } +openssl = { version = "0.10.73", features = ["vendored"] } paste = "1.0.11" prometheus = { version = "0.13.4", default-features = false } proptest = { version = "1.7.0", default-features = false, features = ["std"] } diff --git a/src/catalog/src/durable.rs b/src/catalog/src/durable.rs index 1689ccea9116a..5676ba5e3d3e6 100644 --- a/src/catalog/src/durable.rs +++ b/src/catalog/src/durable.rs @@ -72,6 +72,7 @@ pub const OID_ALLOC_KEY: &str = "oid"; pub(crate) const CATALOG_CONTENT_VERSION_KEY: &str = "catalog_content_version"; pub const BUILTIN_MIGRATION_SHARD_KEY: &str = "builtin_migration_shard"; pub const EXPRESSION_CACHE_SHARD_KEY: &str = "expression_cache_shard"; +pub const MOCK_AUTHENTICATION_NONCE_KEY: &str = "mock_authentication_nonce"; #[derive(Clone, Debug)] pub struct BootstrapArgs { diff --git a/src/catalog/src/durable/initialize.rs b/src/catalog/src/durable/initialize.rs index dbfd4184c88b1..4256bc99a9cf5 100644 --- a/src/catalog/src/durable/initialize.rs +++ b/src/catalog/src/durable/initialize.rs @@ -13,6 +13,7 @@ use std::str::FromStr; use std::sync::LazyLock; use std::time::Duration; +use base64::prelude::*; use ipnet::IpNet; use itertools::max; use mz_audit_log::{CreateOrDropClusterReplicaReasonV1, EventV1, VersionedEvent}; @@ -47,10 +48,10 @@ use crate::durable::{ AUDIT_LOG_ID_ALLOC_KEY, BUILTIN_MIGRATION_SHARD_KEY, BootstrapArgs, CATALOG_CONTENT_VERSION_KEY, CatalogError, ClusterConfig, ClusterVariant, ClusterVariantManaged, DATABASE_ID_ALLOC_KEY, DefaultPrivilege, EXPRESSION_CACHE_SHARD_KEY, - OID_ALLOC_KEY, ReplicaConfig, ReplicaLocation, Role, SCHEMA_ID_ALLOC_KEY, - STORAGE_USAGE_ID_ALLOC_KEY, SYSTEM_CLUSTER_ID_ALLOC_KEY, SYSTEM_REPLICA_ID_ALLOC_KEY, Schema, - Transaction, USER_CLUSTER_ID_ALLOC_KEY, USER_NETWORK_POLICY_ID_ALLOC_KEY, - USER_REPLICA_ID_ALLOC_KEY, USER_ROLE_ID_ALLOC_KEY, + MOCK_AUTHENTICATION_NONCE_KEY, OID_ALLOC_KEY, ReplicaConfig, ReplicaLocation, Role, + SCHEMA_ID_ALLOC_KEY, STORAGE_USAGE_ID_ALLOC_KEY, SYSTEM_CLUSTER_ID_ALLOC_KEY, + SYSTEM_REPLICA_ID_ALLOC_KEY, Schema, Transaction, USER_CLUSTER_ID_ALLOC_KEY, + USER_NETWORK_POLICY_ID_ALLOC_KEY, USER_REPLICA_ID_ALLOC_KEY, USER_ROLE_ID_ALLOC_KEY, }; /// The key within the "config" Collection that stores the version of the catalog. @@ -758,6 +759,18 @@ pub(crate) async fn initialize( tx.set_setting(name, Some(value))?; } + if tx + .get_setting(MOCK_AUTHENTICATION_NONCE_KEY.to_string()) + .is_none() + { + let mut nonce = [0u8; 24]; + openssl::rand::rand_bytes(&mut nonce).expect("random number generation failed"); + tx.set_setting( + MOCK_AUTHENTICATION_NONCE_KEY.to_string(), + Some(BASE64_STANDARD.encode(nonce)), + )?; + } + Ok(()) } diff --git a/src/catalog/src/durable/transaction.rs b/src/catalog/src/durable/transaction.rs index 18fcfef437c50..26a3d131fe7d7 100644 --- a/src/catalog/src/durable/transaction.rs +++ b/src/catalog/src/durable/transaction.rs @@ -62,10 +62,11 @@ use crate::durable::objects::{ use crate::durable::{ AUDIT_LOG_ID_ALLOC_KEY, BUILTIN_MIGRATION_SHARD_KEY, CATALOG_CONTENT_VERSION_KEY, CatalogError, DATABASE_ID_ALLOC_KEY, DefaultPrivilege, DurableCatalogError, DurableCatalogState, - EXPRESSION_CACHE_SHARD_KEY, NetworkPolicy, OID_ALLOC_KEY, SCHEMA_ID_ALLOC_KEY, - STORAGE_USAGE_ID_ALLOC_KEY, SYSTEM_CLUSTER_ID_ALLOC_KEY, SYSTEM_ITEM_ALLOC_KEY, - SYSTEM_REPLICA_ID_ALLOC_KEY, Snapshot, SystemConfiguration, USER_ITEM_ALLOC_KEY, - USER_NETWORK_POLICY_ID_ALLOC_KEY, USER_REPLICA_ID_ALLOC_KEY, USER_ROLE_ID_ALLOC_KEY, + EXPRESSION_CACHE_SHARD_KEY, MOCK_AUTHENTICATION_NONCE_KEY, NetworkPolicy, OID_ALLOC_KEY, + SCHEMA_ID_ALLOC_KEY, STORAGE_USAGE_ID_ALLOC_KEY, SYSTEM_CLUSTER_ID_ALLOC_KEY, + SYSTEM_ITEM_ALLOC_KEY, SYSTEM_REPLICA_ID_ALLOC_KEY, Snapshot, SystemConfiguration, + USER_ITEM_ALLOC_KEY, USER_NETWORK_POLICY_ID_ALLOC_KEY, USER_REPLICA_ID_ALLOC_KEY, + USER_ROLE_ID_ALLOC_KEY, }; use crate::memory::objects::{StateDiff, StateUpdate, StateUpdateKind}; @@ -1920,7 +1921,7 @@ impl<'a> Transaction<'a> { } /// Get the value of a persisted setting. - fn get_setting(&self, name: String) -> Option<&str> { + pub fn get_setting(&self, name: String) -> Option<&str> { self.settings .get(&SettingKey { name }) .map(|entry| &*entry.value) @@ -2196,6 +2197,14 @@ impl<'a> Transaction<'a> { .map(|value| &*value.value) } + pub fn get_authentication_mock_nonce(&self) -> Option { + self.settings + .get(&SettingKey { + name: MOCK_AUTHENTICATION_NONCE_KEY.to_string(), + }) + .map(|value| value.value.clone()) + } + /// Commit the current operation within the transaction. This does not cause anything to be /// written durably, but signals to the current transaction that we are moving on to the next /// operation. diff --git a/src/catalog/src/memory/error.rs b/src/catalog/src/memory/error.rs index e5248a39558fe..9aa982bd84e59 100644 --- a/src/catalog/src/memory/error.rs +++ b/src/catalog/src/memory/error.rs @@ -102,6 +102,8 @@ pub enum ErrorKind { VarError(#[from] VarError), #[error("unknown cluster replica size {size}")] InvalidClusterReplicaSize { size: String, expected: Vec }, + #[error("failed to get catalog setting: {0}")] + SettingError(String), #[error("internal error: {0}")] Internal(String), } diff --git a/src/catalog/tests/debug.rs b/src/catalog/tests/debug.rs index 893dbae57e74c..742116ec654cd 100644 --- a/src/catalog/tests/debug.rs +++ b/src/catalog/tests/debug.rs @@ -15,8 +15,8 @@ use mz_catalog::durable::initialize::USER_VERSION_KEY; use mz_catalog::durable::objects::serialization::proto; use mz_catalog::durable::{ BUILTIN_MIGRATION_SHARD_KEY, CATALOG_VERSION, CatalogError, DurableCatalogError, - DurableCatalogState, EXPRESSION_CACHE_SHARD_KEY, Epoch, FenceError, TestCatalogStateBuilder, - test_bootstrap_args, + DurableCatalogState, EXPRESSION_CACHE_SHARD_KEY, Epoch, FenceError, + MOCK_AUTHENTICATION_NONCE_KEY, TestCatalogStateBuilder, test_bootstrap_args, }; use mz_ore::now::{NOW_ZERO, SYSTEM_TIME}; use mz_ore::{assert_none, assert_ok}; @@ -75,6 +75,12 @@ impl StableTrace<'_> { ) -> bool { key.name == EXPRESSION_CACHE_SHARD_KEY } + + fn is_mock_authentication_nonce( + ((key, _), _, _): &((proto::SettingKey, proto::SettingValue), Timestamp, Diff), + ) -> bool { + key.name == MOCK_AUTHENTICATION_NONCE_KEY + } } impl Debug for StableTrace<'_> { @@ -118,6 +124,7 @@ impl Debug for StableTrace<'_> { .filter(|value| { !Self::is_builtin_migration_shard(value) && !Self::is_expression_cache_shard(value) + && !Self::is_mock_authentication_nonce(value) }) .cloned() .collect(), @@ -238,7 +245,7 @@ async fn test_debug(state_builder: TestCatalogStateBuilder) { // Check adding a new value via `edit`. let settings = unconsolidated_trace.settings.values; - assert_eq!(settings.len(), 3); + assert_eq!(settings.len(), 4); let prev = debug_state .edit::( @@ -256,7 +263,7 @@ async fn test_debug(state_builder: TestCatalogStateBuilder) { let unconsolidated_trace = openable_state_reader.trace_unconsolidated().await.unwrap(); let mut settings = unconsolidated_trace.settings.values; differential_dataflow::consolidation::consolidate_updates(&mut settings); - assert_eq!(settings.len(), 4); + assert_eq!(settings.len(), 5); let ((key, value), _ts, diff) = settings .into_iter() .find(|((key, _), _, _)| key.name == "debug-key") @@ -297,7 +304,7 @@ async fn test_debug(state_builder: TestCatalogStateBuilder) { let unconsolidated_trace = openable_state_reader.trace_unconsolidated().await.unwrap(); let mut settings = unconsolidated_trace.settings.values; differential_dataflow::consolidation::consolidate_updates(&mut settings); - assert_eq!(settings.len(), 4); + assert_eq!(settings.len(), 5); let ((key, value), _ts, diff) = settings .into_iter() .find(|((key, _), _, _)| key.name == "debug-key") @@ -327,11 +334,11 @@ async fn test_debug(state_builder: TestCatalogStateBuilder) { let unconsolidated_trace = openable_state_reader.trace_unconsolidated().await.unwrap(); let mut settings = unconsolidated_trace.settings.values; differential_dataflow::consolidation::consolidate_updates(&mut settings); - assert_eq!(settings.len(), 3); + assert_eq!(settings.len(), 4); let consolidated_trace = openable_state_reader.trace_consolidated().await.unwrap(); let settings = consolidated_trace.settings.values; - assert_eq!(settings.len(), 3); + assert_eq!(settings.len(), 4); } #[mz_ore::test(tokio::test)] diff --git a/src/catalog/tests/open.rs b/src/catalog/tests/open.rs index 81a545de7121d..4d1c5eb7ce0da 100644 --- a/src/catalog/tests/open.rs +++ b/src/catalog/tests/open.rs @@ -16,8 +16,8 @@ use mz_catalog::durable::objects::serialization::proto; use mz_catalog::durable::objects::{DurableType, Snapshot}; use mz_catalog::durable::{ BUILTIN_MIGRATION_SHARD_KEY, CATALOG_VERSION, CatalogError, Database, DurableCatalogError, - DurableCatalogState, EXPRESSION_CACHE_SHARD_KEY, Epoch, FenceError, Schema, - TestCatalogStateBuilder, test_bootstrap_args, + DurableCatalogState, EXPRESSION_CACHE_SHARD_KEY, Epoch, FenceError, + MOCK_AUTHENTICATION_NONCE_KEY, Schema, TestCatalogStateBuilder, test_bootstrap_args, }; use mz_ore::cast::usize_to_u64; use mz_ore::collections::HashSet; @@ -65,6 +65,12 @@ impl StableSnapshot<'_> { name: EXPRESSION_CACHE_SHARD_KEY.to_string(), } } + + fn mock_authentication_nonce_key() -> proto::SettingKey { + proto::SettingKey { + name: MOCK_AUTHENTICATION_NONCE_KEY.to_string(), + } + } } impl Debug for StableSnapshot<'_> { @@ -97,6 +103,7 @@ impl Debug for StableSnapshot<'_> { let mut settings: BTreeMap = settings.clone(); settings.remove(&Self::builtin_migration_shard_key()); settings.remove(&Self::expression_cache_shard_key()); + settings.remove(&Self::mock_authentication_nonce_key()); f.debug_struct("Snapshot") .field("databases", databases) .field("schemas", schemas) diff --git a/src/catalog/tests/snapshots/debug__opened_trace.snap b/src/catalog/tests/snapshots/debug__opened_trace.snap index 526473e234369..d4a351212007b 100644 --- a/src/catalog/tests/snapshots/debug__opened_trace.snap +++ b/src/catalog/tests/snapshots/debug__opened_trace.snap @@ -1,6 +1,5 @@ --- source: src/catalog/tests/debug.rs -assertion_line: 226 expression: test_trace --- Trace { diff --git a/src/environmentd/src/environmentd/main.rs b/src/environmentd/src/environmentd/main.rs index 2e498bde71524..ec97b493406d1 100644 --- a/src/environmentd/src/environmentd/main.rs +++ b/src/environmentd/src/environmentd/main.rs @@ -766,6 +766,18 @@ fn run(mut args: Args) -> Result<(), anyhow::Error> { serde_json::from_reader(f)? }; + for (_, listener) in &listeners_config.sql { + listener + .validate() + .map_err(|e| anyhow::anyhow!("invalid SQL listener: {}", e))?; + } + + for (_, listener) in &listeners_config.http { + listener + .validate() + .map_err(|e| anyhow::anyhow!("invalid HTTP listener: {}", e))?; + } + // Configure CORS. let allowed_origins = if !args.cors_allowed_origin.is_empty() { args.cors_allowed_origin diff --git a/src/environmentd/src/http.rs b/src/environmentd/src/http.rs index 21c51741d2912..6b74cfffd113a 100644 --- a/src/environmentd/src/http.rs +++ b/src/environmentd/src/http.rs @@ -1002,6 +1002,14 @@ async fn auth( }); } }, + Authenticator::Sasl(_) => { + // We shouldn't ever end up here as the configuration is validated at startup. + // If we do, it's a server misconfiguration. + // Just in case, we return a 401 rather than panic. + return Err(AuthError::MissingHttpAuthentication { + include_www_authenticate_header, + }); + } Authenticator::None => { // If no authentication, use whatever is in the HTTP auth // header (without checking the password), or fall back to the diff --git a/src/environmentd/src/lib.rs b/src/environmentd/src/lib.rs index 6fcd4bb553d1a..5bcc909b40ee4 100644 --- a/src/environmentd/src/lib.rs +++ b/src/environmentd/src/lib.rs @@ -276,6 +276,7 @@ impl Listener { frontegg.expect("Frontegg args are required with AuthenticatorKind::Frontegg"), ), AuthenticatorKind::Password => Authenticator::Password(adapter_client.clone()), + AuthenticatorKind::Sasl => Authenticator::Sasl(adapter_client.clone()), AuthenticatorKind::None => Authenticator::None, }; @@ -401,6 +402,7 @@ impl Listeners { let authenticator_rx = match authenticator_kind { AuthenticatorKind::Frontegg => authenticator_frontegg_rx.clone(), AuthenticatorKind::Password => authenticator_password_rx.clone(), + AuthenticatorKind::Sasl => authenticator_password_rx.clone(), AuthenticatorKind::None => authenticator_none_rx.clone(), }; let source: &'static str = Box::leak(name.clone().into_boxed_str()); diff --git a/src/environmentd/src/test_util.rs b/src/environmentd/src/test_util.rs index 544d5eb6ad692..d9a2fe26dd8e7 100644 --- a/src/environmentd/src/test_util.rs +++ b/src/environmentd/src/test_util.rs @@ -409,6 +409,54 @@ impl TestHarness { self } + pub fn with_sasl_scram_auth(mut self, mz_system_password: Password) -> Self { + self.external_login_password_mz_system = Some(mz_system_password); + let enable_tls = self.tls.is_some(); + self.listeners_config = ListenersConfig { + sql: btreemap! { + "external".to_owned() => SqlListenerConfig { + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + authenticator_kind: AuthenticatorKind::Sasl, + allowed_roles: AllowedRoles::NormalAndInternal, + enable_tls, + }, + }, + http: btreemap! { + "external".to_owned() => HttpListenerConfig { + base: BaseListenerConfig { + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + authenticator_kind: AuthenticatorKind::Password, + allowed_roles: AllowedRoles::NormalAndInternal, + enable_tls, + }, + routes: HttpRoutesEnabled{ + base: true, + webhook: true, + internal: true, + metrics: false, + profiling: true, + }, + }, + "metrics".to_owned() => HttpListenerConfig { + base: BaseListenerConfig { + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + authenticator_kind: AuthenticatorKind::None, + allowed_roles: AllowedRoles::NormalAndInternal, + enable_tls: false, + }, + routes: HttpRoutesEnabled{ + base: false, + webhook: false, + internal: false, + metrics: true, + profiling: false, + }, + }, + }, + }; + self + } + pub fn with_now(mut self, now: NowFn) -> Self { self.now = now; self diff --git a/src/environmentd/tests/auth.rs b/src/environmentd/tests/auth.rs index b3649e681ff7c..75b81fa88d605 100644 --- a/src/environmentd/tests/auth.rs +++ b/src/environmentd/tests/auth.rs @@ -3159,6 +3159,107 @@ async fn test_password_auth() { ); } +#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +async fn test_sasl_auth() { + let metrics_registry = MetricsRegistry::new(); + + let server = test_util::TestHarness::default() + .with_system_parameter_default( + "log_filter".to_string(), + "mz_frontegg_auth=debug,info".to_string(), + ) + .with_system_parameter_default("enable_password_auth".to_string(), "true".to_string()) + .with_sasl_scram_auth(Password("mz_system_password".to_owned())) + .with_metrics_registry(metrics_registry) + .start() + .await; + + let mz_system_client = server + .connect() + .no_tls() + .user("mz_system") + .password("mz_system_password") + .await + .unwrap(); + mz_system_client + .execute("CREATE ROLE foo WITH LOGIN PASSWORD 'bar'", &[]) + .await + .unwrap(); + + let external_client = server + .connect() + .no_tls() + .user("foo") + .password("bar") + .await + .unwrap(); + + assert_eq!( + external_client + .query_one("SELECT current_user", &[]) + .await + .unwrap() + .get::<_, String>(0), + "foo" + ); + + assert_eq!( + external_client + .query_one("SELECT mz_is_superuser()", &[]) + .await + .unwrap() + .get::<_, bool>(0), + false + ); +} + +#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +async fn test_sasl_auth_failure() { + let metrics_registry = MetricsRegistry::new(); + + let server = test_util::TestHarness::default() + .with_system_parameter_default( + "log_filter".to_string(), + "mz_frontegg_auth=debug,info".to_string(), + ) + .with_system_parameter_default("enable_password_auth".to_string(), "true".to_string()) + .with_sasl_scram_auth(Password("mz_system_password".to_owned())) + .with_metrics_registry(metrics_registry) + .start() + .await; + + let mz_system_client = server + .connect() + .no_tls() + .user("mz_system") + .password("mz_system_password") + .await + .unwrap(); + mz_system_client + .execute("CREATE ROLE foo WITH LOGIN PASSWORD 'bar'", &[]) + .await + .unwrap(); + + let external_client = server + .connect() + .no_tls() + .user("foo") + .password("wrong_password") + .await; + assert_err!(external_client); + + let external_client = server + .connect() + .no_tls() + .user("no_user") + .password("wrong_password") + .await; + + assert_err!(external_client); +} + #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` async fn test_password_auth_superuser() { diff --git a/src/materialized/ci/listener_configs/password.json b/src/materialized/ci/listener_configs/password.json index f1aa75c171db3..200654ec7549b 100644 --- a/src/materialized/ci/listener_configs/password.json +++ b/src/materialized/ci/listener_configs/password.json @@ -5,6 +5,12 @@ "authenticator_kind": "Password", "allowed_roles": "NormalAndInternal", "enable_tls": false + }, + "externalsasl": { + "addr": "0.0.0.0:6877", + "authenticator_kind": "Sasl", + "allowed_roles": "NormalAndInternal", + "enable_tls": false } }, "http": { diff --git a/src/pgwire-common/src/lib.rs b/src/pgwire-common/src/lib.rs index 5efb07ac56f5d..3b9b73f9ddb9b 100644 --- a/src/pgwire-common/src/lib.rs +++ b/src/pgwire-common/src/lib.rs @@ -28,7 +28,8 @@ pub use conn::{ }; pub use format::Format; pub use message::{ - ErrorResponse, FrontendMessage, FrontendStartupMessage, VERSION_3, VERSION_CANCEL, - VERSION_GSSENC, VERSION_SSL, VERSIONS, + ChannelBinding, ErrorResponse, FrontendMessage, FrontendStartupMessage, GS2Header, + SASLClientFinalResponse, SASLInitialResponse, VERSION_3, VERSION_CANCEL, VERSION_GSSENC, + VERSION_SSL, VERSIONS, }; pub use severity::Severity; diff --git a/src/pgwire-common/src/message.rs b/src/pgwire-common/src/message.rs index a10444f4c7730..9158a989d984a 100644 --- a/src/pgwire-common/src/message.rs +++ b/src/pgwire-common/src/message.rs @@ -179,9 +179,59 @@ pub enum FrontendMessage { CopyFail(String), + RawAuthentication(Vec), + Password { password: String, }, + + SASLInitialResponse { + gs2_header: GS2Header, + mechanism: String, + initial_response: SASLInitialResponse, + }, + + SASLResponse(SASLClientFinalResponse), +} + +#[derive(Debug, Clone)] +pub enum ChannelBinding { + /// Client doesn't support channel binding. + None, + /// Client supports channel binding but thinks server does not. + ClientSupported, + /// Client requires channel binding. + Required(String), +} + +#[derive(Debug, Clone)] +pub struct GS2Header { + pub cbind_flag: ChannelBinding, + pub authzid: Option, +} + +impl GS2Header { + pub fn channel_binding_enabled(&self) -> bool { + matches!(self.cbind_flag, ChannelBinding::Required(_)) + } +} + +#[derive(Debug)] +pub struct SASLInitialResponse { + pub gs2_header: GS2Header, + pub nonce: String, + pub extensions: Vec, + pub reserved_mext: Option, + pub client_first_message_bare_raw: String, +} + +#[derive(Debug)] +pub struct SASLClientFinalResponse { + pub channel_binding: String, + pub nonce: String, + pub extensions: Vec, + pub proof: String, + pub client_final_message_bare_raw: String, } impl FrontendMessage { @@ -201,7 +251,10 @@ impl FrontendMessage { FrontendMessage::CopyData(_) => "copy_data", FrontendMessage::CopyDone => "copy_done", FrontendMessage::CopyFail(_) => "copy_fail", + FrontendMessage::RawAuthentication(_) => "raw_authentication", FrontendMessage::Password { .. } => "password", + FrontendMessage::SASLInitialResponse { .. } => "sasl_initial_response", + FrontendMessage::SASLResponse(_) => "sasl_response", } } } diff --git a/src/pgwire/Cargo.toml b/src/pgwire/Cargo.toml index 51443b9fd14b0..a19fbbed56805 100644 --- a/src/pgwire/Cargo.toml +++ b/src/pgwire/Cargo.toml @@ -12,6 +12,7 @@ workspace = true [dependencies] anyhow = "1.0.98" async-trait = "0.1.88" +base64 = "0.22.1" byteorder = "1.4.3" bytes = "1.10.1" bytesize = "1.3.0" diff --git a/src/pgwire/src/codec.rs b/src/pgwire/src/codec.rs index 260a4ef9cd973..abe6373b8c537 100644 --- a/src/pgwire/src/codec.rs +++ b/src/pgwire/src/codec.rs @@ -26,7 +26,8 @@ use mz_ore::cast::CastFrom; use mz_ore::future::OreSinkExt; use mz_ore::netio::AsyncReady; use mz_pgwire_common::{ - Conn, Cursor, DecodeState, ErrorResponse, FrontendMessage, MAX_REQUEST_SIZE, Pgbuf, input_err, + ChannelBinding, Conn, Cursor, DecodeState, ErrorResponse, FrontendMessage, GS2Header, + MAX_REQUEST_SIZE, Pgbuf, SASLClientFinalResponse, SASLInitialResponse, input_err, parse_frame_len, }; use tokio::io::{self, AsyncRead, AsyncWrite, Interest, Ready}; @@ -34,7 +35,7 @@ use tokio::time::{self, Duration}; use tokio_util::codec::{Decoder, Encoder, Framed}; use tracing::trace; -use crate::message::{BackendMessage, BackendMessageKind}; +use crate::message::{BackendMessage, BackendMessageKind, SASLServerFinalMessageKinds}; /// A connection that manages the encoding and decoding of pgwire frames. pub struct FramedConn { @@ -232,7 +233,10 @@ impl Encoder for Codec { // Write type byte. let byte = match &msg { BackendMessage::AuthenticationOk => b'R', - BackendMessage::AuthenticationCleartextPassword => b'R', + BackendMessage::AuthenticationCleartextPassword + | BackendMessage::AuthenticationSASL + | BackendMessage::AuthenticationSASLContinue(_) + | BackendMessage::AuthenticationSASLFinal(_) => b'R', BackendMessage::RowDescription(_) => b'T', BackendMessage::DataRow(_) => b'D', BackendMessage::CommandComplete { .. } => b'C', @@ -290,6 +294,32 @@ impl Encoder for Codec { BackendMessage::AuthenticationCleartextPassword => { dst.put_u32(3); } + BackendMessage::AuthenticationSASL => { + dst.put_u32(10); + dst.put_string("SCRAM-SHA-256"); + dst.put_u8(b'\0'); + } + BackendMessage::AuthenticationSASLContinue(data) => { + dst.put_u32(11); + let data = format!( + "r={},s={},i={}", + data.nonce, data.salt, data.iteration_count + ); + dst.put_slice(data.as_bytes()); + } + BackendMessage::AuthenticationSASLFinal(data) => { + dst.put_u32(12); + let res = match data.kind { + SASLServerFinalMessageKinds::Verifier(verifier) => { + format!("v={}", verifier) + } + }; + dst.put_slice(res.as_bytes()); + if !data.extensions.is_empty() { + dst.put_slice(b","); + dst.put_slice(data.extensions.join(",").as_bytes()); + } + } BackendMessage::RowDescription(fields) => { dst.put_length_i16(fields.len())?; for f in &fields { @@ -447,7 +477,7 @@ impl Decoder for Codec { b'X' => decode_terminate(buf)?, // Authentication. - b'p' => decode_password(buf)?, + b'p' => decode_auth(buf)?, // Copy from flow. b'f' => decode_copy_fail(buf)?, @@ -476,7 +506,223 @@ fn decode_terminate(mut _buf: Cursor) -> Result { Ok(FrontendMessage::Terminate) } -fn decode_password(mut buf: Cursor) -> Result { +fn decode_auth(mut buf: Cursor) -> Result { + let mut value = Vec::new(); + while let Ok(b) = buf.read_byte() { + value.push(b); + } + Ok(FrontendMessage::RawAuthentication(value)) +} + +fn expect(buf: &mut Cursor, expected: &[u8]) -> Result<(), io::Error> { + for i in 0..expected.len() { + if buf.read_byte()? != expected[i] { + return Err(input_err(format!( + "Invalid SASL initial response: expected '{}'", + std::str::from_utf8(expected).unwrap_or("invalid UTF-8") + ))); + } + } + Ok(()) +} + +fn read_until_comma(buf: &mut Cursor) -> Result, io::Error> { + let mut v = Vec::new(); + while let Ok(b) = buf.peek_byte() { + if b == b',' { + break; + } + v.push(buf.read_byte()?); + } + Ok(v) +} + +// All SASL parsing is based on RFC 5802, [section 7](https://datatracker.ietf.org/doc/html/rfc5802#section-7) + +// extensions = attr-val *("," attr-val) +// ;; All extensions are optional, +// ;; i.e., unrecognized attributes +// ;; not defined in this document +// ;; MUST be ignored. +// reserved-mext = "m=" 1*(value-char) +// ;; Reserved for signaling mandatory extensions. +// ;; The exact syntax will be defined in +// ;; the future. +// gs2-cbind-flag = ("p=" cb-name) / "n" / "y" +// ;; "n" -> client doesn't support channel binding. +// ;; "y" -> client does support channel binding +// ;; but thinks the server does not. +// ;; "p" -> client requires channel binding. +// ;; The selected channel binding follows "p=". +// +// gs2-header = gs2-cbind-flag "," [ authzid ] "," +// ;; GS2 header for SCRAM +// ;; (the actual GS2 header includes an optional +// ;; flag to indicate that the GSS mechanism is not +// ;; "standard", but since SCRAM is "standard", we +// ;; don't include that flag). +// client-first-message-bare = +// [reserved-mext ","] +// username "," nonce ["," extensions] +// +// client-first-message = +// gs2-header client-first-message-bare +pub fn decode_sasl_client_first_message(mut buf: Cursor) -> Result { + // 1) GS2 cbind flag + let cbind_flag = match buf.read_byte()? { + b'n' => ChannelBinding::None, + b'y' => ChannelBinding::ClientSupported, + b'p' => { + // must be "p=" then cbname up to next comma + expect(&mut buf, b"=")?; + let cbname = String::from_utf8(read_until_comma(&mut buf)?) + .map_err(|_| input_err("invalid cbname utf8"))?; + ChannelBinding::Required(cbname) + } + other => { + return Err(input_err(format!( + "Invalid channel binding flag: {}", + other + ))); + } + }; + expect(&mut buf, b",")?; + + // 2) Optional authzid: either empty, or "a=" up to next comma + let mut authzid = None; + if buf.peek_byte()? == b'a' { + expect(&mut buf, b"a=")?; + let a = String::from_utf8(read_until_comma(&mut buf)?) + .map_err(|_| input_err("invalid authzid utf8"))?; + authzid = Some(a); + } + expect(&mut buf, b",")?; + + let mut client_first_message_bare_raw = String::new(); + + // 3) Optional reserved "m=" extension before n= + let mut reserved_mext = None; + if buf.peek_byte()? == b'm' { + expect(&mut buf, b"m=")?; + let mext_val = String::from_utf8(read_until_comma(&mut buf)?) + .map_err(|_| input_err("invalid m ext utf8"))?; + client_first_message_bare_raw.push_str(&format!("m={},", mext_val)); + reserved_mext = Some(mext_val); + expect(&mut buf, b",")?; + } + + // 4) Username: must be "n=" then saslname + expect(&mut buf, b"n=")?; + // Postgres doesn't use the username here, so we just consume + let username = String::from_utf8(read_until_comma(&mut buf)?) + .map_err(|_| input_err("invalid username utf8"))?; + expect(&mut buf, b",")?; + client_first_message_bare_raw.push_str(&format!("n={},", username)); + + // 5) Nonce: must be "r=" then value up to next comma or end + expect(&mut buf, b"r=")?; + let nonce = String::from_utf8(read_until_comma(&mut buf)?) + .map_err(|_| input_err("invalid nonce utf8"))?; + client_first_message_bare_raw.push_str(&format!("r={}", nonce)); + + // 6) Optional extensions: "," key=value chunks + let mut extensions = Vec::new(); + while let Ok(b',') = buf.peek_byte().map(|b| b) { + expect(&mut buf, b",")?; + let ext = String::from_utf8(read_until_comma(&mut buf)?) + .map_err(|_| input_err("invalid ext utf8"))?; + if !ext.is_empty() { + client_first_message_bare_raw.push_str(&format!(",{}", ext)); + extensions.push(ext); + } + } + + Ok(SASLInitialResponse { + gs2_header: GS2Header { + cbind_flag, + authzid, + }, + nonce, + extensions, + reserved_mext, + client_first_message_bare_raw, + }) +} + +pub fn decode_sasl_initial_response(mut buf: Cursor) -> Result { + let mechanism = buf.read_cstr()?; + let initial_resp_len = buf.read_i32()?; + if initial_resp_len < 0 { + // -1 means no response? We bail here + return Err(input_err("No initial response")); + } + + let initial_response = decode_sasl_client_first_message(buf)?; + Ok(FrontendMessage::SASLInitialResponse { + gs2_header: initial_response.gs2_header.clone(), + mechanism: mechanism.to_owned(), + initial_response, + }) +} + +// proof = "p=" base64 +// +// channel-binding = "c=" base64 +// ;; base64 encoding of cbind-input. +// client-final-message-without-proof = +// channel-binding "," nonce ["," +// extensions] +// +// client-final-message = +// client-final-message-without-proof "," proof +pub fn decode_sasl_response(mut buf: Cursor) -> Result { + // --- client-final-message-without-proof --- + let mut client_final_message_bare_raw = String::new(); + // channel-binding: "c=" , up to the next comma + expect(&mut buf, b"c=")?; + let channel_binding = String::from_utf8(read_until_comma(&mut buf)?) + .map_err(|_| input_err("invalid channel-binding utf8"))?; + expect(&mut buf, b",")?; + client_final_message_bare_raw.push_str(&format!("c={},", channel_binding)); + + // nonce: "r=" , up to the next comma + expect(&mut buf, b"r=")?; + let nonce = String::from_utf8(read_until_comma(&mut buf)?) + .map_err(|_| input_err("invalid nonce utf8"))?; + client_final_message_bare_raw.push_str(&format!("r={}", nonce)); + + // after reading channel-binding and nonce + let mut extensions = Vec::new(); + + // Keep reading "," until we see ",p=" + while buf.peek_byte()? == b',' { + expect(&mut buf, b",")?; + if buf.peek_byte()? == b'p' { + break; + } + let ext = String::from_utf8(read_until_comma(&mut buf)?) + .map_err(|_| input_err("invalid extension utf8"))?; + if !ext.is_empty() { + client_final_message_bare_raw.push_str(&format!(",{}", ext)); + extensions.push(ext); + } + } + + // Proof is mandatory and last + expect(&mut buf, b"p=")?; + let proof = String::from_utf8(read_until_comma(&mut buf)?) + .map_err(|_| input_err("invalid proof utf8"))?; + + Ok(FrontendMessage::SASLResponse(SASLClientFinalResponse { + channel_binding, + nonce, + extensions, + proof, + client_final_message_bare_raw, + })) +} + +pub fn decode_password(mut buf: Cursor) -> Result { Ok(FrontendMessage::Password { password: buf.read_cstr()?.to_owned(), }) diff --git a/src/pgwire/src/message.rs b/src/pgwire/src/message.rs index a50319975009b..d242b1ec7c9cb 100644 --- a/src/pgwire/src/message.rs +++ b/src/pgwire/src/message.rs @@ -21,6 +21,9 @@ use mz_repr::{ColumnName, RelationDesc}; pub enum BackendMessage { AuthenticationOk, AuthenticationCleartextPassword, + AuthenticationSASL, + AuthenticationSASLContinue(SASLServerFirstMessage), + AuthenticationSASLFinal(SASLServerFinalMessage), CommandComplete { tag: String, }, @@ -58,6 +61,13 @@ impl From for BackendMessage { } } +#[derive(Debug)] +pub struct SASLServerFirstMessage { + pub iteration_count: usize, + pub nonce: String, + pub salt: String, +} + #[derive(Debug)] pub struct FieldDescription { pub name: ColumnName, @@ -69,6 +79,19 @@ pub struct FieldDescription { pub format: mz_pgwire_common::Format, } +#[derive(Debug)] +pub enum SASLServerFinalMessageKinds { + Verifier(String), + // The spec specifies an Error kind here but PG just uses + // its own error handling +} + +#[derive(Debug)] +pub struct SASLServerFinalMessage { + pub kind: SASLServerFinalMessageKinds, + pub extensions: Vec, +} + pub fn encode_row_description( desc: &RelationDesc, formats: &[mz_pgwire_common::Format], diff --git a/src/pgwire/src/protocol.rs b/src/pgwire/src/protocol.rs index efa57107e0780..6902ed4646fda 100644 --- a/src/pgwire/src/protocol.rs +++ b/src/pgwire/src/protocol.rs @@ -15,6 +15,7 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use std::{iter, mem}; +use base64::prelude::*; use byteorder::{ByteOrder, NetworkEndian}; use futures::future::{BoxFuture, FutureExt, pending}; use itertools::Itertools; @@ -37,7 +38,8 @@ use mz_ore::str::StrExt; use mz_ore::{assert_none, assert_ok, instrument, soft_assert_eq_or_log}; use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams}; use mz_pgwire_common::{ - ConnectionCounter, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3, VERSIONS, + ConnectionCounter, Cursor, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3, + VERSIONS, }; use mz_repr::user::InternalUserMetadata; use mz_repr::{ @@ -62,8 +64,13 @@ use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{Instrument, debug, debug_span, warn}; use uuid::Uuid; -use crate::codec::FramedConn; -use crate::message::{self, BackendMessage}; +use crate::codec::{ + FramedConn, decode_password, decode_sasl_initial_response, decode_sasl_response, +}; +use crate::message::{ + self, BackendMessage, SASLServerFinalMessage, SASLServerFinalMessageKinds, + SASLServerFirstMessage, +}; /// Reports whether the given stream begins with a pgwire handshake. /// @@ -180,7 +187,19 @@ where .await?; conn.flush().await?; let password = match conn.recv().await? { - Some(FrontendMessage::Password { password }) => password, + Some(FrontendMessage::RawAuthentication(data)) => { + match decode_password(Cursor::new(&data)).ok() { + Some(FrontendMessage::Password { password }) => password, + _ => { + return conn + .send(ErrorResponse::fatal( + SqlState::INVALID_AUTHORIZATION_SPECIFICATION, + "expected Password message", + )) + .await; + } + } + } _ => { return conn .send(ErrorResponse::fatal( @@ -229,7 +248,19 @@ where .await?; conn.flush().await?; let password = match conn.recv().await? { - Some(FrontendMessage::Password { password }) => Password(password), + Some(FrontendMessage::RawAuthentication(data)) => { + match decode_password(Cursor::new(&data)).ok() { + Some(FrontendMessage::Password { password }) => Password(password), + _ => { + return conn + .send(ErrorResponse::fatal( + SqlState::INVALID_AUTHORIZATION_SPECIFICATION, + "expected Password message", + )) + .await; + } + } + } _ => { return conn .send(ErrorResponse::fatal( @@ -266,6 +297,189 @@ where let auth_session = pending().right_future(); (session, auth_session) } + Authenticator::Sasl(adapter_client) => { + // Start the handshake + conn.send(BackendMessage::AuthenticationSASL).await?; + conn.flush().await?; + // Get the initial response indicating chosen mechanism + let (mechanism, initial_response) = match conn.recv().await? { + Some(FrontendMessage::RawAuthentication(data)) => { + match decode_sasl_initial_response(Cursor::new(&data)).ok() { + Some(FrontendMessage::SASLInitialResponse { + gs2_header, + mechanism, + initial_response, + }) => { + // We do not support channel binding + if gs2_header.channel_binding_enabled() { + return conn + .send(ErrorResponse::fatal( + SqlState::PROTOCOL_VIOLATION, + "channel binding not supported", + )) + .await; + } + (mechanism, initial_response) + } + _ => { + return conn + .send(ErrorResponse::fatal( + SqlState::INVALID_AUTHORIZATION_SPECIFICATION, + "expected SASLInitialResponse message", + )) + .await; + } + } + } + _ => { + return conn + .send(ErrorResponse::fatal( + SqlState::INVALID_AUTHORIZATION_SPECIFICATION, + "expected SASLInitialResponse message", + )) + .await; + } + }; + + if mechanism != "SCRAM-SHA-256" { + return conn + .send(ErrorResponse::fatal( + SqlState::INVALID_AUTHORIZATION_SPECIFICATION, + "unsupported SASL mechanism", + )) + .await; + } + + if initial_response.nonce.len() > 256 { + return conn + .send(ErrorResponse::fatal( + SqlState::INVALID_AUTHORIZATION_SPECIFICATION, + "nonce too long", + )) + .await; + } + + let (server_first_message_raw, mock_hash) = match adapter_client + .generate_sasl_challenge(&user, &initial_response.nonce) + .await + { + Ok(response) => { + let server_first_message_raw = format!( + "r={},s={},i={}", + response.nonce, response.salt, response.iteration_count + ); + + let client_key = [0u8; 32]; + let server_key = [1u8; 32]; + let mock_hash = format!( + "SCRAM-SHA-256${}:{}${}:{}", + response.iteration_count, + response.salt, + BASE64_STANDARD.encode(client_key), + BASE64_STANDARD.encode(server_key) + ); + + conn.send(BackendMessage::AuthenticationSASLContinue( + SASLServerFirstMessage { + iteration_count: response.iteration_count, + nonce: response.nonce, + salt: response.salt, + }, + )) + .await?; + conn.flush().await?; + (server_first_message_raw, mock_hash) + } + Err(e) => { + return conn.send(e.into_response(Severity::Fatal)).await; + } + }; + + let auth_resp = match conn.recv().await? { + Some(FrontendMessage::RawAuthentication(data)) => { + match decode_sasl_response(Cursor::new(&data)).ok() { + Some(FrontendMessage::SASLResponse(response)) => { + let auth_message = format!( + "{},{},{}", + initial_response.client_first_message_bare_raw, + server_first_message_raw, + response.client_final_message_bare_raw + ); + if response.proof.len() > 1024 { + return conn + .send(ErrorResponse::fatal( + SqlState::INVALID_AUTHORIZATION_SPECIFICATION, + "proof too long", + )) + .await; + } + match adapter_client + .verify_sasl_proof( + &user, + &response.proof, + &auth_message, + &mock_hash, + ) + .await + { + Ok(resp) => { + conn.send(BackendMessage::AuthenticationSASLFinal( + SASLServerFinalMessage { + kind: SASLServerFinalMessageKinds::Verifier( + resp.verifier, + ), + extensions: vec![], + }, + )) + .await?; + conn.flush().await?; + resp.auth_resp + } + Err(_) => { + return conn + .send(ErrorResponse::fatal( + SqlState::INVALID_PASSWORD, + "invalid password", + )) + .await; + } + } + } + _ => { + return conn + .send(ErrorResponse::fatal( + SqlState::INVALID_AUTHORIZATION_SPECIFICATION, + "expected SASLResponse message", + )) + .await; + } + } + } + _ => { + return conn + .send(ErrorResponse::fatal( + SqlState::INVALID_AUTHORIZATION_SPECIFICATION, + "expected SASLResponse message", + )) + .await; + } + }; + + let session = adapter_client.new_session(SessionConfig { + conn_id: conn.conn_id().clone(), + uuid: conn_uuid, + user, + client_ip: conn.peer_addr().clone(), + external_metadata_rx: None, + internal_user_metadata: Some(InternalUserMetadata { + superuser: auth_resp.superuser, + }), + helm_chart_version, + }); + // No frontegg check, so auth session lasts indefinitely. + let auth_session = pending().right_future(); + (session, auth_session) + } Authenticator::None => { let session = adapter_client.new_session(SessionConfig { conn_id: conn.conn_id().clone(), @@ -664,7 +878,10 @@ where Some(FrontendMessage::CopyData(_)) | Some(FrontendMessage::CopyDone) | Some(FrontendMessage::CopyFail(_)) - | Some(FrontendMessage::Password { .. }) => State::Drain, + | Some(FrontendMessage::Password { .. }) + | Some(FrontendMessage::RawAuthentication(_)) + | Some(FrontendMessage::SASLInitialResponse { .. }) + | Some(FrontendMessage::SASLResponse(_)) => State::Drain, None => State::Done, }; diff --git a/src/server-core/src/listeners.rs b/src/server-core/src/listeners.rs index 77f0914101163..2242643ec6a56 100644 --- a/src/server-core/src/listeners.rs +++ b/src/server-core/src/listeners.rs @@ -19,6 +19,8 @@ pub enum AuthenticatorKind { Frontegg, /// Authenticate users using internally stored password hashes. Password, + /// Authenticate users using SASL. + Sasl, /// Do not authenticate users. Trust they are who they say they are without verification. #[default] None, @@ -80,6 +82,7 @@ pub trait ListenerConfig { fn authenticator_kind(&self) -> AuthenticatorKind; fn allowed_roles(&self) -> AllowedRoles; fn enable_tls(&self) -> bool; + fn validate(&self) -> Result<(), String>; } impl ListenerConfig for SqlListenerConfig { fn addr(&self) -> SocketAddr { @@ -97,6 +100,10 @@ impl ListenerConfig for SqlListenerConfig { fn enable_tls(&self) -> bool { self.enable_tls } + + fn validate(&self) -> Result<(), String> { + Ok(()) + } } impl ListenerConfig for HttpListenerConfig { fn addr(&self) -> SocketAddr { @@ -114,4 +121,12 @@ impl ListenerConfig for HttpListenerConfig { fn enable_tls(&self) -> bool { self.base.enable_tls } + + fn validate(&self) -> Result<(), String> { + if self.base.authenticator_kind == AuthenticatorKind::Sasl { + Err("SASL authentication is not supported for HTTP listeners".to_string()) + } else { + Ok(()) + } + } }