Skip to content

Commit 0df9188

Browse files
committed
progress
1 parent 79f8dfa commit 0df9188

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

crates/pgt_workspace/src/workspace/server/pg_query.rs

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::collections::HashMap;
12
use std::num::NonZeroUsize;
23
use std::sync::{Arc, Mutex};
34

@@ -97,7 +98,8 @@ impl PgQueryStore {
9798
///
9899
/// Useful for preparing SQL queries for parsing or execution where named paramters are not supported.
99100
pub fn convert_to_positional_params(text: &str) -> String {
100-
let mut result = String::new();
101+
let mut result = String::with_capacity(text.len());
102+
let mut param_mapping: HashMap<&str, usize> = HashMap::new();
101103
let mut param_index = 1;
102104
let mut position = 0;
103105

@@ -106,7 +108,17 @@ pub fn convert_to_positional_params(text: &str) -> String {
106108
let token_text = &text[position..position + token_len];
107109

108110
if matches!(token.kind, pgt_tokenizer::TokenKind::NamedParam { .. }) {
109-
let replacement = format!("${}", param_index);
111+
let idx = match param_mapping.get(token_text) {
112+
Some(&index) => index,
113+
None => {
114+
let index = param_index;
115+
param_mapping.insert(token_text, index);
116+
param_index += 1;
117+
index
118+
}
119+
};
120+
121+
let replacement = format!("${}", idx);
110122
let original_len = token_text.len();
111123
let replacement_len = replacement.len();
112124

@@ -116,8 +128,6 @@ pub fn convert_to_positional_params(text: &str) -> String {
116128
if replacement_len < original_len {
117129
result.push_str(&" ".repeat(original_len - replacement_len));
118130
}
119-
120-
param_index += 1;
121131
} else {
122132
result.push_str(token_text);
123133
}
@@ -142,6 +152,16 @@ mod tests {
142152
);
143153
}
144154

155+
#[test]
156+
fn test_convert_to_positional_params_with_duplicates() {
157+
let input = "select * from users where first_name = @one and starts_with(email, @one) and created_at > @two;";
158+
let result = convert_to_positional_params(input);
159+
assert_eq!(
160+
result,
161+
"select * from users where first_name = $1 and starts_with(email, $1 ) and created_at > $2 ;"
162+
);
163+
}
164+
145165
#[test]
146166
fn test_plpgsql_syntax_error() {
147167
let input = "

0 commit comments

Comments
 (0)