diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fecdbb0..cfb56db 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,21 +10,66 @@ on: env: CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 jobs: build: - runs-on: ubuntu-latest - services: clickhouse: - image: yandex/clickhouse-server + image: clickhouse/clickhouse-server ports: - 9000:9000 - + env: + CLICKHOUSE_SKIP_USER_SETUP: 1 steps: - uses: actions/checkout@v3 - name: Build run: cargo build --verbose - name: Run tests run: cargo test --verbose + + build-tls: + strategy: + fail-fast: false + matrix: + feature: + - tls-native-tls + - tls-rustls + database_url: + # for TLS we need skip_verify for self-signed certificate + - tcp://localhost:9440?skip_verify=true + # we don't need skip_verify when we pass CA cert + - tcp://localhost:9440?ca_certificate=tls/ca.pem + # mTLS + - tcp://tls@localhost:9440?ca_certificate=tls/ca.pem&client_certificate=tls/client.crt&client_private_key=tls/client.key + runs-on: ubuntu-latest + env: + # NOTE: not all tests "secure" aware, so let's define DATABASE_URL explicitly + # NOTE: sometimes for native-tls default connection_timeout (500ms) is not enough, interestingly that for rustls it is OK. + DATABASE_URL: ${{ matrix.database_url }}&compression=lz4&ping_timeout=2s&retry_timeout=3s&secure=true&connection_timeout=5s + steps: + - uses: actions/checkout@v3 + - name: Generate TLS certificates + run: | + extras/ci/generate_certs.sh tls + # NOTE: + # - we cannot use "services" because they are executed before the steps, i.e. repository checkout. + # - "job.container.network" is empty, hence "host" + # - github actions does not support YAML anchors (sigh) + - name: Run clickhouse-server + run: docker run + -v ./extras/ci/overrides.xml:/etc/clickhouse-server/config.d/overrides.xml + -v ./extras/ci/users-overrides.yaml:/etc/clickhouse-server/users.d/overrides.yaml + -v ./tls:/etc/clickhouse-server/tls + -e CLICKHOUSE_SKIP_USER_SETUP=1 + --network host + --name clickhouse + --rm + --detach + --publish 9440:9440 + clickhouse/clickhouse-server + - name: Build + run: cargo build --features ${{ matrix.feature }} --verbose + - name: Run tests + run: cargo test --features ${{ matrix.feature }} --verbose diff --git a/Cargo.toml b/Cargo.toml index cfac622..9d7de0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,10 @@ exclude = ["tests/*", "examples/*"] [features] default = ["tokio_io"] -tls = ["tokio-native-tls", "native-tls"] +_tls = [] # meta feature for the clickhouse-rs generic TLS code +tls = ["tls-native-tls"] # backward compatibility +tls-native-tls = ["tokio-native-tls", "native-tls", "_tls"] +tls-rustls = ["tokio-rustls", "rustls", "rustls-pemfile", "webpki-roots", "_tls"] async_std = ["async-std"] tokio_io = ["tokio"] @@ -67,6 +70,22 @@ optional = true version = "^0.3" optional = true +[dependencies.rustls] +version = "0.22.1" +optional = true + +[dependencies.rustls-pemfile] +version = "2.0" +optional = true + +[dependencies.tokio-rustls] +version = "0.25.0" +optional = true + +[dependencies.webpki-roots] +version = "*" +optional = true + [dependencies.chrono] version = "^0.4" default-features = false @@ -76,6 +95,7 @@ features = ["std"] env_logger = "^0.10" pretty_assertions = "1.3.0" rand = "^0.8" +uuid = { version = "^1.4", features = [ "v4" ] } [dev-dependencies.tokio] version = "^1.32" diff --git a/README.md b/README.md index eedb4e4..badbdfb 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,13 @@ for the most common use cases. The following features are available. - `tokio_io` *(enabled by default)* — I/O based on [Tokio](https://tokio.rs/). - `async_std` — I/O based on [async-std](https://async.rs/) (doesn't work together with `tokio_io`). -- `tls` — TLS support (allowed only with `tokio_io`). +- `tls` — TLS support (allowed only with `tokio_io` and one of TLS libraries, under `tls-rustls` or `tls-native-tls` features). + +### TLS + +- `skip_verify` - do not verify the server certificate (**insecure**) +- `ca_certificate` - instead of `skip_verify` it is better to pass CA certificate explicitly (in case of self-signed certificates). +- `client_certificate`/`client_private_key` - authentication using TLS certificates (mTLS) (see [ClickHouse documentation](https://clickhouse.com/docs/operations/external-authenticators/ssl-x509) for more info) ## Example diff --git a/examples/simple.rs b/examples/simple.rs index f472995..bbfa46b 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -38,7 +38,7 @@ async fn execute(database_url: String) -> Result<(), Box> { Ok(()) } -#[cfg(all(feature = "tokio_io", not(feature = "tls")))] +#[cfg(all(feature = "tokio_io", not(feature = "_tls")))] #[tokio::main] async fn main() -> Result<(), Box> { let database_url = @@ -46,7 +46,7 @@ async fn main() -> Result<(), Box> { execute(database_url).await } -#[cfg(all(feature = "tokio_io", feature = "tls"))] +#[cfg(all(feature = "tokio_io", feature = "_tls"))] #[tokio::main] async fn main() -> Result<(), Box> { let database_url = env::var("DATABASE_URL") diff --git a/extras/ci/generate_certs.sh b/extras/ci/generate_certs.sh new file mode 100755 index 0000000..ed41da3 --- /dev/null +++ b/extras/ci/generate_certs.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash + +out=$1 && shift +mkdir -p "$out" +cd "$out" + +# +# CA +# +openssl genrsa -out ca.key 4096 +openssl req -x509 -new -nodes -key ca.key -sha256 -days 3650 -out ca.pem -subj "/C=US/ST=DevState/O=DevOrg/CN=MyDevCA" + +# +# server +# +openssl genrsa -out server.key 2048 +openssl req -new -key server.key -out server.csr -subj "/C=US/ST=DevState/O=DevOrg/CN=localhost" + +cat > server.ext < client.ext < + + + /etc/clickhouse-server/tls/server.crt + /etc/clickhouse-server/tls/server.key + /etc/clickhouse-server/tls/ca.pem + relaxed + true + true + sslv2,sslv3 + true + + + 9440 + + + 1 + + diff --git a/extras/ci/users-overrides.yaml b/extras/ci/users-overrides.yaml new file mode 100644 index 0000000..16f3e4d --- /dev/null +++ b/extras/ci/users-overrides.yaml @@ -0,0 +1,6 @@ +--- +users: + tls: + ssl_certificates: + subject_alt_name: + - DNS:localhost diff --git a/src/client_info.rs b/src/client_info.rs index 899904f..88c47db 100644 --- a/src/client_info.rs +++ b/src/client_info.rs @@ -1,28 +1,26 @@ use crate::binary::Encoder; -pub static CLIENT_NAME: &str = "Rust SQLDriver"; - pub const CLICK_HOUSE_REVISION: u64 = 54429; // DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS pub const CLICK_HOUSE_DBMSVERSION_MAJOR: u64 = 1; pub const CLICK_HOUSE_DBMSVERSION_MINOR: u64 = 1; -pub fn write(encoder: &mut Encoder) { - encoder.string(CLIENT_NAME); +pub fn write(encoder: &mut Encoder, client_name: &str) { + encoder.string(client_name); encoder.uvarint(CLICK_HOUSE_DBMSVERSION_MAJOR); encoder.uvarint(CLICK_HOUSE_DBMSVERSION_MINOR); encoder.uvarint(CLICK_HOUSE_REVISION); } -pub fn description() -> String { +pub fn description(client_name: &str) -> String { format!( - "{CLIENT_NAME} {CLICK_HOUSE_DBMSVERSION_MAJOR}.{CLICK_HOUSE_DBMSVERSION_MINOR}.{CLICK_HOUSE_REVISION}", + "{client_name} {CLICK_HOUSE_DBMSVERSION_MAJOR}.{CLICK_HOUSE_DBMSVERSION_MINOR}.{CLICK_HOUSE_REVISION}", ) } #[test] fn test_description() { assert_eq!( - description(), + description("Rust SQLDriver"), format!( "Rust SQLDriver {}.{}.{}", CLICK_HOUSE_DBMSVERSION_MAJOR, CLICK_HOUSE_DBMSVERSION_MINOR, CLICK_HOUSE_REVISION diff --git a/src/connecting_stream.rs b/src/connecting_stream.rs index cab0132..de06b0b 100644 --- a/src/connecting_stream.rs +++ b/src/connecting_stream.rs @@ -6,22 +6,37 @@ use std::{ }; use futures_util::future::{select_ok, BoxFuture, SelectOk, TryFutureExt}; -#[cfg(feature = "tls")] +#[cfg(feature = "_tls")] use futures_util::FutureExt; #[cfg(feature = "async_std")] use async_std::net::TcpStream; -#[cfg(feature = "tls")] +#[cfg(feature = "tls-native-tls")] use native_tls::TlsConnector; #[cfg(feature = "tokio_io")] use tokio::net::TcpStream; +#[cfg(feature = "tls-rustls")] +use { + rustls::{ + client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, + crypto::{verify_tls12_signature, verify_tls13_signature}, + pki_types::{CertificateDer, ServerName, UnixTime}, + ClientConfig, DigitallySignedStruct, Error as TlsError, RootCertStore, + }, + std::sync::Arc, + tokio_rustls::TlsConnector, +}; use pin_project::pin_project; use url::Url; use crate::{errors::ConnectionError, io::Stream as InnerStream, Options}; -#[cfg(feature = "tls")] +#[cfg(feature = "tls-native-tls")] use tokio_native_tls::TlsStream; +#[cfg(feature = "tls-rustls")] +use tokio_rustls::client::TlsStream; +#[cfg(feature = "_tls")] +use crate::types::ClientTlsIdentity; type Result = std::result::Result; @@ -33,7 +48,7 @@ enum TcpState { Fail(Option), } -#[cfg(feature = "tls")] +#[cfg(feature = "_tls")] #[pin_project(project = TlsStateProj)] enum TlsState { Wait(#[pin] ConnectingFuture>), @@ -43,7 +58,7 @@ enum TlsState { #[pin_project(project = StateProj)] enum State { Tcp(#[pin] TcpState), - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] Tls(#[pin] TlsState), } @@ -60,7 +75,7 @@ impl TcpState { } } -#[cfg(feature = "tls")] +#[cfg(feature = "_tls")] impl TlsState { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { @@ -81,7 +96,7 @@ impl State { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { StateProj::Tcp(inner) => inner.poll(cx), - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] StateProj::Tls(inner) => inner.poll(cx), } } @@ -91,7 +106,12 @@ impl State { State::Tcp(TcpState::Fail(Some(conn_error))) } - #[cfg(feature = "tls")] + #[cfg(feature = "tls-rustls")] + fn tls_err(e: TlsError) -> Self { + State::Tls(TlsState::Fail(Some(ConnectionError::TlsError(e)))) + } + + #[cfg(feature = "_tls")] fn tls_host_err() -> Self { State::Tls(TlsState::Fail(Some(ConnectionError::TlsHostNotProvided))) } @@ -100,7 +120,7 @@ impl State { State::Tcp(TcpState::Wait(socket)) } - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] fn tls_wait(s: ConnectingFuture>) -> Self { State::Tls(TlsState::Wait(s)) } @@ -112,6 +132,58 @@ pub(crate) struct ConnectingStream { state: State, } +#[cfg(feature = "tls-rustls")] +#[derive(Debug)] +struct DummyTlsVerifier; + +#[cfg(feature = "tls-rustls")] +impl ServerCertVerifier for DummyTlsVerifier { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> std::result::Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> std::result::Result { + verify_tls12_signature( + message, + cert, + dss, + &rustls::crypto::ring::default_provider().signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> std::result::Result { + verify_tls13_signature( + message, + cert, + dss, + &rustls::crypto::ring::default_provider().signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + rustls::crypto::ring::default_provider() + .signature_verification_algorithms + .supported_schemes() + } +} + impl ConnectingStream { #[allow(unused_variables)] pub(crate) fn new(addr: &Url, options: &Options) -> Self { @@ -137,7 +209,7 @@ impl ConnectingStream { let socket = select_ok(streams); - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] { if options.secure { return ConnectingStream::new_tls_connection(addr, socket, options); @@ -154,7 +226,7 @@ impl ConnectingStream { } } - #[cfg(feature = "tls")] + #[cfg(feature = "tls-native-tls")] fn new_tls_connection( addr: &Url, socket: SelectOk>, @@ -167,10 +239,14 @@ impl ConnectingStream { Some(host) => { let mut builder = TlsConnector::builder(); builder.danger_accept_invalid_certs(options.skip_verify); - if let Some(certificate) = options.certificate.clone() { + if let Some(certificate) = options.ca_certificate.clone() { let native_cert = native_tls::Certificate::from(certificate); builder.add_root_certificate(native_cert); } + if let Some(identity) = &options.client_tls_identity { + let ClientTlsIdentity::Pkcs(pkcs) = identity; + builder.identity(pkcs.clone()); + } Self { state: State::tls_wait(Box::pin(async move { @@ -185,6 +261,77 @@ impl ConnectingStream { } } } + + #[cfg(feature = "tls-rustls")] + fn new_tls_connection( + addr: &Url, + socket: SelectOk>, + options: &Options, + ) -> Self { + match addr.host_str().map(|host| host.to_owned()) { + None => Self { + state: State::tls_host_err(), + }, + Some(host) => { + let builder = if options.skip_verify { + ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(DummyTlsVerifier)) + } else { + let mut cert_store = RootCertStore::empty(); + cert_store.extend( + webpki_roots::TLS_SERVER_ROOTS + .iter() + .cloned() + ); + if let Some(certificates) = options.ca_certificate.clone() { + for certificate in + Into::>>::into( + certificates, + ) + { + match cert_store.add(certificate) { + Ok(_) => {}, + Err(err) => { + let err = io::Error::new( + io::ErrorKind::InvalidInput, + format!("Could not load certificate: {}.", err), + ); + return Self { state: State::tcp_err(err) }; + }, + } + } + } + ClientConfig::builder() + .with_root_certificates(cert_store) + }; + let config = if let Some(identity) = &options.client_tls_identity { + let ClientTlsIdentity::Pem { key, certs } = identity; + builder.with_client_auth_cert(certs.clone().into(), key.clone_key()) + } else { + Ok(builder.with_no_client_auth()) + }; + let config = match config { + Ok(config) => config, + Err(err) => { + return Self { state: State::tls_err(err) }; + }, + }; + Self { + state: State::tls_wait(Box::pin(async move { + let (s, _) = socket.await?; + let cx = TlsConnector::from(Arc::new(config)); + let host = ServerName::try_from(host) + .map_err(|_| ConnectionError::TlsHostNotProvided)?; + Ok(cx + .connect(host, s) + .await + .map_err(|e| ConnectionError::IoError(e))?) + })), + } + } + } + } } impl Future for ConnectingStream { diff --git a/src/errors/mod.rs b/src/errors/mod.rs index 5d59fb0..4af62b5 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -5,6 +5,11 @@ use thiserror::Error; use tokio::time::error::Elapsed; use url::ParseError; +#[cfg(feature = "tls-native-tls")] +use native_tls::Error as TlsError; +#[cfg(feature = "tls-rustls")] +use rustls::Error as TlsError; + /// Clickhouse error codes pub mod codes; @@ -55,9 +60,9 @@ pub enum ConnectionError { #[error("Input/output error: `{}`", _0)] IoError(#[source] io::Error), - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] #[error("TLS connection error: `{}`", _0)] - TlsError(#[source] native_tls::Error), + TlsError(#[source] TlsError), #[error("Connection broken")] Broken, @@ -137,9 +142,9 @@ impl From for Error { } } -#[cfg(feature = "tls")] -impl From for ConnectionError { - fn from(error: native_tls::Error) -> Self { +#[cfg(feature = "_tls")] +impl From for ConnectionError { + fn from(error: TlsError) -> Self { ConnectionError::TlsError(error) } } diff --git a/src/io/stream.rs b/src/io/stream.rs index 14a4d45..0ef4ec4 100644 --- a/src/io/stream.rs +++ b/src/io/stream.rs @@ -7,8 +7,10 @@ use std::{ #[cfg(feature = "tokio_io")] use tokio::{io::ReadBuf, net::TcpStream}; -#[cfg(feature = "tls")] +#[cfg(feature = "tls-native-tls")] use tokio_native_tls::TlsStream; +#[cfg(feature = "tls-rustls")] +use tokio_rustls::client::TlsStream; #[cfg(feature = "async_std")] use async_std::io::prelude::*; @@ -18,13 +20,13 @@ use pin_project::pin_project; #[cfg(feature = "tokio_io")] use tokio::io::{AsyncRead, AsyncWrite}; -#[cfg(all(feature = "tls", feature = "tokio_io"))] +#[cfg(all(feature = "_tls", feature = "tokio_io"))] type SecureTcpStream = TlsStream; #[pin_project(project = StreamProj)] pub(crate) enum Stream { Plain(#[pin] TcpStream), - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] Secure(#[pin] SecureTcpStream), } @@ -34,7 +36,7 @@ impl From for Stream { } } -#[cfg(feature = "tls")] +#[cfg(feature = "_tls")] impl From for Stream { fn from(stream: SecureTcpStream) -> Stream { Self::Secure(stream) @@ -55,7 +57,7 @@ impl Stream { pub(crate) fn set_keepalive(&mut self, keepalive: Option) -> io::Result<()> { // match *self { // Self::Plain(ref mut stream) => stream.set_keepalive(keepalive), - // #[cfg(feature = "tls")] + // #[cfg(feature = "_tls")] // Self::Secure(ref mut stream) => stream.get_mut().set_keepalive(keepalive), // }.map_err(|err| io::Error::new(err.kind(), format!("set_keepalive error: {}", err))) if keepalive.is_some() { @@ -68,10 +70,12 @@ impl Stream { pub(crate) fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { match *self { Self::Plain(ref mut stream) => stream.set_nodelay(nodelay), - #[cfg(feature = "tls")] + #[cfg(feature = "tls-native-tls")] Self::Secure(ref mut stream) => { stream.get_mut().get_mut().get_mut().set_nodelay(nodelay) } + #[cfg(feature = "tls-rustls")] + Self::Secure(ref mut stream) => stream.get_mut().0.set_nodelay(nodelay), } .map_err(|err| io::Error::new(err.kind(), format!("set_nodelay error: {err}"))) } @@ -84,7 +88,7 @@ impl Stream { ) -> Poll> { match self.project() { StreamProj::Plain(stream) => stream.poll_read(cx, buf), - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] StreamProj::Secure(stream) => stream.poll_read(cx, buf), } } @@ -99,7 +103,7 @@ impl Stream { let result = match self.project() { StreamProj::Plain(stream) => stream.poll_read(cx, &mut read_buf), - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] StreamProj::Secure(stream) => stream.poll_read(cx, &mut read_buf), }; @@ -117,7 +121,7 @@ impl Stream { ) -> Poll> { match self.project() { StreamProj::Plain(stream) => stream.poll_write(cx, buf), - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] StreamProj::Secure(stream) => stream.poll_write(cx, buf), } } diff --git a/src/io/transport.rs b/src/io/transport.rs index 75f8f0e..e4eff04 100644 --- a/src/io/transport.rs +++ b/src/io/transport.rs @@ -119,7 +119,7 @@ impl ClickhouseTransport { } } - let mut transport = h.unwrap(); + let mut transport = h.ok_or(Error::Driver(DriverError::UnexpectedPacket))?; transport.inconsistent = false; Ok(transport) } @@ -287,6 +287,15 @@ impl Stream for ClickhouseTransport { } if *this.done { + // We still have may have something in buffer, since the client may read something + // first, and only after the server will close the connection, likely in case of + // exception, let's try to parse it here + if !this.rd.is_empty() { + if let Poll::Ready(ret) = this.try_parse_msg()? { + return Poll::Ready(ret.map(Ok)); + } + } + return Poll::Ready(None); } diff --git a/src/lib.rs b/src/lib.rs index c475868..385fea5 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -66,7 +66,8 @@ //! //! - `tokio_io` *(enabled by default)* — I/O based on [Tokio](https://tokio.rs/). //! - `async_std` — I/O based on [async-std](https://async.rs/) (doesn't work together with `tokio_io`). -//! - `tls` — TLS support (allowed only with `tokio_io`). +//! - `tls-native-tls` — TLS support with native-tls (allowed only with `tokio_io`). +//! - `tls-rustls` — TLS support with rustls (allowed only with `tokio_io`). //! //! ### Example //! @@ -108,6 +109,9 @@ #![recursion_limit = "1024"] +#[cfg(all(feature = "tls-native-tls", feature = "tls-rustls"))] +compile_error!("tls-native-tls and tls-rustls are mutually exclusive and cannot be enabled together"); + use std::{fmt, future::Future, time::Duration}; use futures_util::{ @@ -324,7 +328,7 @@ impl ClientHandle { } self.inner = h; - self.context.server_info = info.unwrap(); + self.context.server_info = info.ok_or(Error::Other("Missing Hello/Exception packet".into()))?; Ok(()) } diff --git a/src/pool/mod.rs b/src/pool/mod.rs index 5d94fa0..a4e2f1f 100644 --- a/src/pool/mod.rs +++ b/src/pool/mod.rs @@ -377,6 +377,9 @@ mod test { let spent = start.elapsed(); assert!(spent >= Duration::from_millis(2000)); + #[cfg(feature = "_tls")] + assert!(spent < Duration::from_millis(5000)); // slow connect + #[cfg(not(feature = "_tls"))] assert!(spent < Duration::from_millis(2500)); assert_eq!(pool.info().idle_len, 6); diff --git a/src/types/cmd.rs b/src/types/cmd.rs index 74045c0..0b02ea8 100644 --- a/src/types/cmd.rs +++ b/src/types/cmd.rs @@ -44,13 +44,13 @@ fn encode_command(cmd: &Cmd) -> Result> { } fn encode_hello(context: &Context) -> Result> { - trace!("[hello] -> {}", client_info::description()); + let options = context.options.get()?; + + trace!("[hello] -> {}", client_info::description(&options.client_name)); let mut encoder = Encoder::new(); encoder.uvarint(protocol::CLIENT_HELLO); - client_info::write(&mut encoder); - - let options = context.options.get()?; + client_info::write(&mut encoder, &options.client_name); encoder.string(&options.database); encoder.string(&options.username); @@ -83,6 +83,8 @@ fn encode_query(query: &Query, context: &Context) -> Result> { encoder.uvarint(protocol::CLIENT_QUERY); encoder.string(""); + let options = context.options.get()?; + { let hostname = &context.hostname; encoder.uvarint(1); @@ -93,7 +95,7 @@ fn encode_query(query: &Query, context: &Context) -> Result> { encoder.string(hostname); encoder.string(hostname); } - client_info::write(&mut encoder); + client_info::write(&mut encoder, &options.client_name); if context.server_info.revision >= protocol::DBMS_MIN_REVISION_WITH_QUOTA_KEY_IN_CLIENT_INFO { encoder.string(""); @@ -103,8 +105,6 @@ fn encode_query(query: &Query, context: &Context) -> Result> { encoder.uvarint(0); } - let options = context.options.get()?; - let settings_format = if context.server_info.revision >= protocol::DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS { diff --git a/src/types/mod.rs b/src/types/mod.rs index 83ff796..aaba44c 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -21,6 +21,8 @@ pub use self::{ value::Value, value_ref::ValueRef, }; +#[cfg(feature = "_tls")] +pub use self::options::ClientTlsIdentity; pub(crate) use self::{ cmd::Cmd, diff --git a/src/types/options.rs b/src/types/options.rs index 5f02fb2..b1cac4d 100644 --- a/src/types/options.rs +++ b/src/types/options.rs @@ -6,13 +6,16 @@ use std::{ sync::{Arc, Mutex}, time::Duration, }; +#[cfg(feature = "_tls")] +use std::fs; use crate::errors::{Error, Result, UrlError}; -#[cfg(feature = "tls")] -use native_tls; use percent_encoding::percent_decode; use url::Url; +#[cfg(feature = "tls-rustls")] +use rustls::pki_types::pem::PemObject; + const DEFAULT_MIN_CONNS: usize = 10; const DEFAULT_MAX_CONNS: usize = 20; @@ -93,12 +96,11 @@ impl IntoOptions for String { } } -/// An X509 certificate. -#[cfg(feature = "tls")] +/// An X509 certificate for native-tls. +#[cfg(feature = "tls-native-tls")] #[derive(Clone)] pub struct Certificate(Arc); - -#[cfg(feature = "tls")] +#[cfg(feature = "tls-native-tls")] impl Certificate { /// Parses a DER-formatted X509 certificate. pub fn from_der(der: &[u8]) -> Result { @@ -110,33 +112,118 @@ impl Certificate { } /// Parses a PEM-formatted X509 certificate. - pub fn from_pem(der: &[u8]) -> Result { - let inner = match native_tls::Certificate::from_pem(der) { + pub fn from_pem(pem: &[u8]) -> Result { + let inner = match native_tls::Certificate::from_pem(pem) { Ok(certificate) => certificate, Err(err) => return Err(Error::Other(err.to_string().into())), }; Ok(Certificate(Arc::new(inner))) } } +#[cfg(feature = "tls-native-tls")] +impl From for native_tls::Certificate { + fn from(value: Certificate) -> Self { + value.0.as_ref().clone() + } +} + +/// An X509 certificate for rustls. +#[cfg(feature = "tls-rustls")] +#[derive(Clone)] +pub struct Certificate(Arc>>); +#[cfg(feature = "tls-rustls")] +impl Certificate { + /// Parses a DER-formatted X509 certificate. + pub fn from_der(der: &[u8]) -> Result { + let der = der.to_vec(); + let inner = match rustls::pki_types::CertificateDer::try_from(der) { + Ok(certificate) => certificate, + Err(err) => return Err(Error::Other(err.to_string().into())), + }; + Ok(Certificate(Arc::new(vec![inner]))) + } + + /// Parses a PEM-formatted X509 certificate. + pub fn from_pem(pem: &[u8]) -> Result { + let certs = rustls_pemfile::certs(&mut pem.as_ref()) + .map(|result| result.unwrap()) + .collect(); + Ok(Certificate(Arc::new(certs))) + } +} +#[cfg(feature = "tls-rustls")] +impl From for Vec> { + fn from(value: Certificate) -> Self { + value.0.as_ref().clone() + } +} -#[cfg(feature = "tls")] +#[cfg(feature = "_tls")] impl fmt::Debug for Certificate { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "[Certificate]") } } - -#[cfg(feature = "tls")] +#[cfg(feature = "_tls")] impl PartialEq for Certificate { fn eq(&self, _other: &Self) -> bool { true } } -#[cfg(feature = "tls")] -impl From for native_tls::Certificate { - fn from(value: Certificate) -> Self { - value.0.as_ref().clone() +#[cfg(feature = "_tls")] +pub fn load_certificate(file: &str) -> Result { + let data = fs::read(file)?; + if file.ends_with(".der") || file.ends_with(".cer") { + Certificate::from_der(&data) + } else { + Certificate::from_pem(&data) + } +} + +#[cfg(feature = "_tls")] +#[derive(Clone)] +pub enum ClientTlsIdentity { + #[cfg(feature = "tls-rustls")] + Pem { + key: Arc>, + certs: Certificate, + }, + #[cfg(feature = "tls-native-tls")] + Pkcs(native_tls::Identity), +} + +#[cfg(feature = "_tls")] +impl ClientTlsIdentity { + #[cfg(feature = "tls-rustls")] + pub fn load(cert_path: &str, key_path: &str) -> Result { + let key = rustls::pki_types::PrivateKeyDer::from_pem_slice(fs::read(key_path)?.as_ref()) + .map_err(|e| format!("Cannot read private key from {}: {}", key_path, e))?; + let key = Arc::new(key); + let certs = load_certificate(cert_path)?; + return Ok(Self::Pem{ key, certs }); + } + + #[cfg(feature = "tls-native-tls")] + pub fn load(cert_path: &str, key_path: &str) -> Result { + let identity = native_tls::Identity::from_pkcs8( + fs::read(cert_path)?.as_ref(), + fs::read(key_path)?.as_ref(), + ).map_err(|e| format!("Cannot load identity from {} and {}: {}", cert_path, key_path, e))?; + return Ok(Self::Pkcs(identity)); + } +} + +#[cfg(feature = "_tls")] +impl fmt::Debug for ClientTlsIdentity { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "[Client Certificate]") + } +} +#[cfg(feature = "_tls")] +impl PartialEq for ClientTlsIdentity { + fn eq(&self, _other: &Self) -> bool { + true } } @@ -254,27 +341,35 @@ pub struct Options { pub(crate) execute_timeout: Option, /// Enable TLS encryption (defaults to `false`) - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] pub(crate) secure: bool, /// Skip certificate verification (default is `false`). - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] pub(crate) skip_verify: bool, - /// An X509 certificate. - #[cfg(feature = "tls")] - pub(crate) certificate: Option, + /// CA certificate. + #[cfg(feature = "_tls")] + pub(crate) ca_certificate: Option, + + /// Authorization with certificate (mTLS). + #[cfg(feature = "_tls")] + pub(crate) client_tls_identity: Option, /// Query settings pub(crate) settings: HashMap, /// Comma separated list of single address host for load-balancing. pub(crate) alt_hosts: Vec, + + /// Client name (defaults to `Rust SQLDriver`). + pub(crate) client_name: String } impl fmt::Debug for Options { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Options") + let mut debug = f.debug_struct("Options"); + let res = debug .field("addr", &self.addr) .field("database", &self.database) .field("compression", &self.compression) @@ -289,7 +384,15 @@ impl fmt::Debug for Options { .field("connection_timeout", &self.connection_timeout) .field("settings", &self.settings) .field("alt_hosts", &self.alt_hosts) - .finish() + .field("client_name", &self.client_name); + + #[cfg(feature = "_tls")] + res + .field("secure", &self.secure) + .field("ca_certificate", &self.ca_certificate) + .field("client_tls_identity", &self.client_tls_identity); + + res.finish() } } @@ -313,14 +416,17 @@ impl Default for Options { query_timeout: Duration::from_secs(180), insert_timeout: Some(Duration::from_secs(180)), execute_timeout: Some(Duration::from_secs(180)), - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] secure: false, - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] skip_verify: false, - #[cfg(feature = "tls")] - certificate: None, + #[cfg(feature = "_tls")] + ca_certificate: None, + #[cfg(feature = "_tls")] + client_tls_identity: None, settings: HashMap::new(), alt_hosts: Vec::new(), + client_name: "Rust SQLDriver".into(), } } } @@ -455,22 +561,28 @@ impl Options { => execute_timeout: Option } - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] property! { /// Establish secure connection (default is `false`). => secure: bool } - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] property! { /// Skip certificate verification (default is `false`). => skip_verify: bool } - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] + property! { + /// CA certificate. + => ca_certificate: Option + } + + #[cfg(feature = "_tls")] property! { - /// An X509 certificate. - => certificate: Option + /// Authorization with certificate (mTLS). + => client_tls_identity: Option } property! { @@ -482,6 +594,11 @@ impl Options { /// Comma separated list of single address host for load-balancing. => alt_hosts: Vec } + + property! { + /// Client name (defaults to `Rust SQLDriver`). + => client_name: &str + } } impl FromStr for Options { @@ -537,6 +654,11 @@ fn set_params<'a, I>(options: &mut Options, iter: I) -> std::result::Result<(), where I: Iterator, Cow<'a, str>)>, { + #[cfg(feature = "_tls")] + let mut client_certificate = None; + #[cfg(feature = "_tls")] + let mut client_private_key = None; + for (key, value) in iter { match key.as_ref() { "pool_min" => options.pool_min = parse_param(key, value, usize::from_str)?, @@ -560,11 +682,18 @@ where options.execute_timeout = parse_param(key, value, parse_opt_duration)? } "compression" => options.compression = parse_param(key, value, parse_compression)?, - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] "secure" => options.secure = parse_param(key, value, bool::from_str)?, - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] "skip_verify" => options.skip_verify = parse_param(key, value, bool::from_str)?, + #[cfg(feature = "_tls")] + "ca_certificate" => options.ca_certificate = Some(parse_param(key, value, load_certificate)?), + #[cfg(feature = "_tls")] + "client_certificate" => client_certificate = Some(value), + #[cfg(feature = "_tls")] + "client_private_key" => client_private_key = Some(value), "alt_hosts" => options.alt_hosts = parse_param(key, value, parse_hosts)?, + "client_name" => options.client_name = parse_param(key, value, String::from_str)?, _ => { let value = SettingType::String(value.to_string()); options.settings.insert( @@ -578,6 +707,17 @@ where }; } + #[cfg(feature = "_tls")] + match (client_certificate, client_private_key) { + (Some(cert), Some(key)) => { + options.client_tls_identity = Some(ClientTlsIdentity::load(&cert, &key).map_err(|_| UrlError::Invalid)?); + } + (None, None) => {} + _ => { + return Err(UrlError::Invalid); + }, + } + Ok(()) } @@ -701,7 +841,7 @@ mod test { } #[test] - #[cfg(feature = "tls")] + #[cfg(feature = "_tls")] fn test_parse_secure_options() { let url = "tcp://username:password@host1:9001/database?ping_timeout=42ms&keepalive=99s&compression=lz4&connection_timeout=10s&secure=true&skip_verify=true"; assert_eq!( diff --git a/src/types/value.rs b/src/types/value.rs index 9b38b41..e86ef9c 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -806,7 +806,7 @@ mod test { #[test] fn test_size_of() { use std::mem; - assert_eq!(56, mem::size_of::<[Value; 1]>()); + assert_eq!(64, mem::size_of::<[Value; 1]>()); } #[test] diff --git a/tests/clickhouse.rs b/tests/clickhouse.rs index e72a1ad..5951458 100644 --- a/tests/clickhouse.rs +++ b/tests/clickhouse.rs @@ -31,14 +31,14 @@ use std::{ use uuid::Uuid; use Tz::{Asia__Istanbul as IST, UTC}; -#[cfg(not(feature = "tls"))] +#[cfg(not(feature = "_tls"))] fn database_url() -> String { env::var("DATABASE_URL").unwrap_or_else(|_| { "tcp://localhost:9000?compression=lz4&ping_timeout=2s&retry_timeout=3s".into() }) } -#[cfg(feature = "tls")] +#[cfg(feature = "_tls")] fn database_url() -> String { env::var("DATABASE_URL").unwrap_or_else(|_| { "tcp://localhost:9440?compression=lz4&ping_timeout=2s&retry_timeout=3s&secure=true&skip_verify=true".into() @@ -2574,3 +2574,24 @@ async fn test_insert_big_block() -> Result<(), Error> { assert_eq!(format!("{:?}", expected.as_ref()), format!("{:?}", &actual)); Ok(()) } + +#[cfg(feature = "tokio_io")] +#[tokio::test] +async fn test_client_name() -> Result<(), Error> { + let uuid = Uuid::new_v4().to_string(); + let client_name = format!("clickhouse-rs-tests-{}", uuid); + let log_comment = format!("tests-{}", uuid); + let options = Options::from_str(&format!("{}&client_name={}", database_url(), client_name))? + .with_setting("log_comment", &*log_comment, true); + let pool = Pool::new(options); + let mut c = pool.get_handle().await?; + + c.execute("SELECT 1").await?; + c.execute("SYSTEM FLUSH LOGS").await?; + let rows = c + .query(format!("SELECT client_name FROM system.query_log WHERE Settings['log_comment'] = '{log_comment}' AND query = 'SELECT 1' LIMIT 1")) + .fetch_all() + .await?; + assert_eq!(client_name, rows.get::(0, "client_name")?); + Ok(()) +}