diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs index a0b222606a..6ec872f569 100644 --- a/sqlx-postgres/src/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -1,6 +1,6 @@ use std::borrow::Cow; use std::env::var; -use std::fmt::{Display, Write}; +use std::fmt::{self, Display, Write}; use std::path::{Path, PathBuf}; pub use ssl_mode::PgSslMode; @@ -495,6 +495,9 @@ impl PgConnectOptions { /// Set additional startup options for the connection as a list of key-value pairs. /// + /// Escapes the options’ backslash and space characters as per + /// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-OPTIONS + /// /// # Example /// /// ```rust @@ -515,7 +518,8 @@ impl PgConnectOptions { options_str.push(' '); } - write!(options_str, "-c {k}={v}").expect("failed to write an option to the string"); + options_str.push_str("-c "); + write!(PgOptionsWriteEscaped(options_str), "{k}={v}").ok(); } self } @@ -669,6 +673,39 @@ fn default_host(port: u16) -> String { "localhost".to_owned() } +/// Writer that escapes passed-in PostgreSQL options. +/// +/// Escapes backslashes and spaces with an additional backslash according to +/// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-OPTIONS +#[derive(Debug)] +struct PgOptionsWriteEscaped<'a>(&'a mut String); + +impl Write for PgOptionsWriteEscaped<'_> { + fn write_str(&mut self, s: &str) -> fmt::Result { + let mut span_start = 0; + + for (span_end, matched) in s.match_indices([' ', '\\']) { + write!(self.0, r"{}\{matched}", &s[span_start..span_end])?; + span_start = span_end + matched.len(); + } + + // Write the rest of the string after the last match, or all of it if no matches + self.0.push_str(&s[span_start..]); + + Ok(()) + } + + fn write_char(&mut self, ch: char) -> fmt::Result { + if matches!(ch, ' ' | '\\') { + self.0.push('\\'); + } + + self.0.push(ch); + + Ok(()) + } +} + #[test] fn test_options_formatting() { let options = PgConnectOptions::new().options([("geqo", "off")]); @@ -683,6 +720,26 @@ fn test_options_formatting() { options.options, Some("-c geqo=off -c statement_timeout=5min".to_string()) ); + // https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-OPTIONS + let options = + PgConnectOptions::new().options([("application_name", r"/back\slash/ and\ spaces")]); + assert_eq!( + options.options, + Some(r"-c application_name=/back\\slash/\ and\\\ spaces".to_string()) + ); let options = PgConnectOptions::new(); assert_eq!(options.options, None); } + +#[test] +fn test_pg_write_escaped() { + let mut buf = String::new(); + let mut x = PgOptionsWriteEscaped(&mut buf); + x.write_str("x").unwrap(); + x.write_str("").unwrap(); + x.write_char('\\').unwrap(); + x.write_str("y \\").unwrap(); + x.write_char(' ').unwrap(); + x.write_char('z').unwrap(); + assert_eq!(buf, r"x\\y\ \\\ z"); +}