Skip to content

Escape PostgreSQL options #3800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions sqlx-postgres/src/options/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::borrow::Cow;
use std::env::var;
use std::fmt::{Display, Write};
use std::fmt::{self, Display, Write};
use std::path::{Path, PathBuf};

pub use ssl_mode::PgSslMode;
Expand Down Expand Up @@ -495,6 +495,9 @@ impl PgConnectOptions {

/// Set additional startup options for the connection as a list of key-value pairs.
///
/// Escapes the options’ backslash and space characters as per
/// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-OPTIONS
///
/// # Example
///
/// ```rust
Expand All @@ -515,7 +518,8 @@ impl PgConnectOptions {
options_str.push(' ');
}

write!(options_str, "-c {k}={v}").expect("failed to write an option to the string");
options_str.push_str("-c ");
write!(PgOptionsWriteEscaped(options_str), "{k}={v}").ok();
}
self
}
Expand Down Expand Up @@ -669,6 +673,39 @@ fn default_host(port: u16) -> String {
"localhost".to_owned()
}

/// Writer that escapes passed-in PostgreSQL options.
///
/// Escapes backslashes and spaces with an additional backslash according to
/// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-OPTIONS
#[derive(Debug)]
struct PgOptionsWriteEscaped<'a>(&'a mut String);

impl Write for PgOptionsWriteEscaped<'_> {
fn write_str(&mut self, s: &str) -> fmt::Result {
let mut span_start = 0;

for (span_end, matched) in s.match_indices([' ', '\\']) {
write!(self.0, r"{}\{matched}", &s[span_start..span_end])?;
span_start = span_end + matched.len();
}

// Write the rest of the string after the last match, or all of it if no matches
self.0.push_str(&s[span_start..]);

Ok(())
}

fn write_char(&mut self, ch: char) -> fmt::Result {
if matches!(ch, ' ' | '\\') {
self.0.push('\\');
}

self.0.push(ch);

Ok(())
}
}

#[test]
fn test_options_formatting() {
let options = PgConnectOptions::new().options([("geqo", "off")]);
Expand All @@ -683,6 +720,26 @@ fn test_options_formatting() {
options.options,
Some("-c geqo=off -c statement_timeout=5min".to_string())
);
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-OPTIONS
let options =
PgConnectOptions::new().options([("application_name", r"/back\slash/ and\ spaces")]);
assert_eq!(
options.options,
Some(r"-c application_name=/back\\slash/\ and\\\ spaces".to_string())
);
let options = PgConnectOptions::new();
assert_eq!(options.options, None);
}

#[test]
fn test_pg_write_escaped() {
let mut buf = String::new();
let mut x = PgOptionsWriteEscaped(&mut buf);
x.write_str("x").unwrap();
x.write_str("").unwrap();
x.write_char('\\').unwrap();
x.write_str("y \\").unwrap();
x.write_char(' ').unwrap();
x.write_char('z').unwrap();
assert_eq!(buf, r"x\\y\ \\\ z");
}