Skip to content

Commit 6b0a9f8

Browse files
authored
fix: positional params (#473)
Pre-processes the query string to replace named params like `@name` with positional params like `$1`, because only the latter are supported by the Postgres parser. Not super happy with the way it is used for the type checker, but I also did not want to put that function in another crate. Also removed a `println!` statement left over from a previous PR of mine, and fixed a bug in the lexer that tokenized a cast as a named parameter. Fixes #405, #454.
1 parent e2bbe42 commit 6b0a9f8

File tree

10 files changed

+205
-26
lines changed

10 files changed

+205
-26
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/pgt_lexer/src/lexer.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ impl<'a> Lexer<'a> {
111111
pgt_tokenizer::TokenKind::Tilde => SyntaxKind::TILDE,
112112
pgt_tokenizer::TokenKind::Question => SyntaxKind::QUESTION,
113113
pgt_tokenizer::TokenKind::Colon => SyntaxKind::COLON,
114+
pgt_tokenizer::TokenKind::DoubleColon => SyntaxKind::DOUBLE_COLON,
114115
pgt_tokenizer::TokenKind::Eq => SyntaxKind::EQ,
115116
pgt_tokenizer::TokenKind::Bang => SyntaxKind::BANG,
116117
pgt_tokenizer::TokenKind::Lt => SyntaxKind::L_ANGLE,

crates/pgt_lexer_codegen/src/syntax_kind.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ const PUNCT: &[(&str, &str)] = &[
3737
("_", "UNDERSCORE"),
3838
(".", "DOT"),
3939
(":", "COLON"),
40+
("::", "DOUBLE_COLON"),
4041
("=", "EQ"),
4142
("!", "BANG"),
4243
("-", "MINUS"),

crates/pgt_tokenizer/src/lib.rs

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -144,32 +144,37 @@ impl Cursor<'_> {
144144
}
145145
}
146146
':' => {
147-
// Named parameters in psql with different substitution styles.
148-
//
149-
// https://www.postgresql.org/docs/current/app-psql.html#APP-PSQL-INTERPOLATION
150-
match self.first() {
151-
'\'' => {
152-
// Named parameter with colon prefix and single quotes.
153-
self.bump();
154-
let terminated = self.single_quoted_string();
155-
let kind = NamedParamKind::ColonString { terminated };
156-
TokenKind::NamedParam { kind }
157-
}
158-
'"' => {
159-
// Named parameter with colon prefix and double quotes.
160-
self.bump();
161-
let terminated = self.double_quoted_string();
162-
let kind = NamedParamKind::ColonIdentifier { terminated };
163-
TokenKind::NamedParam { kind }
164-
}
165-
c if is_ident_start(c) => {
166-
// Named parameter with colon prefix.
167-
self.eat_while(is_ident_cont);
168-
TokenKind::NamedParam {
169-
kind: NamedParamKind::ColonRaw,
147+
if self.first() == ':' {
148+
self.bump();
149+
TokenKind::DoubleColon
150+
} else {
151+
// Named parameters in psql with different substitution styles.
152+
//
153+
// https://www.postgresql.org/docs/current/app-psql.html#APP-PSQL-INTERPOLATION
154+
match self.first() {
155+
'\'' => {
156+
// Named parameter with colon prefix and single quotes.
157+
self.bump();
158+
let terminated = self.single_quoted_string();
159+
let kind = NamedParamKind::ColonString { terminated };
160+
TokenKind::NamedParam { kind }
161+
}
162+
'"' => {
163+
// Named parameter with colon prefix and double quotes.
164+
self.bump();
165+
let terminated = self.double_quoted_string();
166+
let kind = NamedParamKind::ColonIdentifier { terminated };
167+
TokenKind::NamedParam { kind }
168+
}
169+
c if is_ident_start(c) => {
170+
// Named parameter with colon prefix.
171+
self.eat_while(is_ident_cont);
172+
TokenKind::NamedParam {
173+
kind: NamedParamKind::ColonRaw,
174+
}
170175
}
176+
_ => TokenKind::Colon,
171177
}
172-
_ => TokenKind::Colon,
173178
}
174179
}
175180
// One-symbol tokens.
@@ -675,6 +680,23 @@ mod tests {
675680
assert_debug_snapshot!(result);
676681
}
677682

683+
/// A leading `::` must lex as a single cast operator followed by an identifier,
/// not as a colon-prefixed named parameter.
#[test]
fn debug_simple_cast() {
    let result = lex("::test");
    // The inline snapshot has to mirror the pretty `Debug` layout of the token list:
    // elements are nested one level (four spaces) inside the brackets.
    assert_debug_snapshot!(result, @r###"
    [
        "::" @ DoubleColon,
        "test" @ Ident,
    ]
    "###);
}
693+
694+
/// Snapshots the tokens of a statement that contains both a `::` cast and a
/// `:id` named parameter, so the two colon forms can be told apart.
#[test]
fn named_param_colon_raw_vs_cast() {
    let sql = "select 1 from c where id::test = :id;";
    // Keep the binding named `result`: the stored snapshot records it as its expression.
    let result = lex(sql);
    assert_debug_snapshot!(result);
}
699+
678700
#[test]
679701
fn named_param_colon_string() {
680702
let result = lex("select 1 from c where id = :'id';");
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
---
2+
source: crates/pgt_tokenizer/src/lib.rs
3+
expression: result
4+
snapshot_kind: text
5+
---
6+
[
7+
"select" @ Ident,
8+
" " @ Space,
9+
"1" @ Literal { kind: Int { base: Decimal, empty_int: false } },
10+
" " @ Space,
11+
"from" @ Ident,
12+
" " @ Space,
13+
"c" @ Ident,
14+
" " @ Space,
15+
"where" @ Ident,
16+
" " @ Space,
17+
"id" @ Ident,
18+
"::" @ DoubleColon,
19+
"test" @ Ident,
20+
" " @ Space,
21+
"=" @ Eq,
22+
" " @ Space,
23+
":id" @ NamedParam { kind: ColonRaw },
24+
";" @ Semi,
25+
]

crates/pgt_tokenizer/src/token.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ pub enum TokenKind {
4646
Minus,
4747
/// `:`
4848
Colon,
49+
/// `::`
50+
DoubleColon,
4951
/// `.`
5052
Dot,
5153
/// `=`

crates/pgt_workspace/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ pgt_schema_cache = { workspace = true }
3333
pgt_statement_splitter = { workspace = true }
3434
pgt_suppressions = { workspace = true }
3535
pgt_text_size.workspace = true
36+
pgt_tokenizer = { workspace = true }
3637
pgt_typecheck = { workspace = true }
3738
pgt_workspace_macros = { workspace = true }
3839
rustc-hash = { workspace = true }

crates/pgt_workspace/src/workspace/server.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use document::{
1414
TypecheckDiagnosticsMapper,
1515
};
1616
use futures::{StreamExt, stream};
17+
use pg_query::convert_to_positional_params;
1718
use pgt_analyse::{AnalyserOptions, AnalysisFilter};
1819
use pgt_analyser::{Analyser, AnalyserConfig, AnalyserParams};
1920
use pgt_diagnostics::{
@@ -468,7 +469,7 @@ impl Workspace for WorkspaceServer {
468469
// Type checking
469470
let typecheck_result = pgt_typecheck::check_sql(TypecheckParams {
470471
conn: &pool,
471-
sql: id.content(),
472+
sql: convert_to_positional_params(id.content()).as_str(),
472473
ast: &ast,
473474
tree: &cst,
474475
schema_cache: schema_cache.as_ref(),

crates/pgt_workspace/src/workspace/server.tests.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,57 @@ async fn test_dedupe_diagnostics(test_db: PgPool) {
277277
Some(TextRange::new(115.into(), 210.into()))
278278
);
279279
}
280+
281+
#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
282+
async fn test_positional_params(test_db: PgPool) {
283+
let mut conf = PartialConfiguration::init();
284+
conf.merge_with(PartialConfiguration {
285+
db: Some(PartialDatabaseConfiguration {
286+
database: Some(
287+
test_db
288+
.connect_options()
289+
.get_database()
290+
.unwrap()
291+
.to_string(),
292+
),
293+
..Default::default()
294+
}),
295+
..Default::default()
296+
});
297+
298+
let workspace = get_test_workspace(Some(conf)).expect("Unable to create test workspace");
299+
300+
let path = PgTPath::new("test.sql");
301+
302+
let setup_sql = r"
303+
create table users (
304+
id serial primary key,
305+
name text not null,
306+
email text not null
307+
);
308+
";
309+
test_db.execute(setup_sql).await.expect("setup sql failed");
310+
311+
let content = r#"select * from users where id = @one and name = :two and email = :'three';"#;
312+
313+
workspace
314+
.open_file(OpenFileParams {
315+
path: path.clone(),
316+
content: content.into(),
317+
version: 1,
318+
})
319+
.expect("Unable to open test file");
320+
321+
let diagnostics = workspace
322+
.pull_diagnostics(crate::workspace::PullDiagnosticsParams {
323+
path: path.clone(),
324+
categories: RuleCategories::all(),
325+
max_diagnostics: 100,
326+
only: vec![],
327+
skip: vec![],
328+
})
329+
.expect("Unable to pull diagnostics")
330+
.diagnostics;
331+
332+
assert_eq!(diagnostics.len(), 0, "Expected no diagnostic");
333+
}

crates/pgt_workspace/src/workspace/server/pg_query.rs

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
use std::collections::HashMap;
12
use std::num::NonZeroUsize;
23
use std::sync::{Arc, Mutex};
34

45
use lru::LruCache;
56
use pgt_query_ext::diagnostics::*;
67
use pgt_text_size::TextRange;
8+
use pgt_tokenizer::tokenize;
79

810
use super::statement_identifier::StatementId;
911

@@ -37,7 +39,7 @@ impl PgQueryStore {
3739
}
3840

3941
let r = Arc::new(
40-
pgt_query::parse(statement.content())
42+
pgt_query::parse(&convert_to_positional_params(statement.content()))
4143
.map_err(SyntaxDiagnostic::from)
4244
.and_then(|ast| {
4345
ast.into_root().ok_or_else(|| {
@@ -87,10 +89,79 @@ impl PgQueryStore {
8789
}
8890
}
8991

92+
/// Converts named parameters in a SQL query string to positional parameters.
93+
///
94+
/// This function scans the input SQL string for named parameters (e.g., `@param`, `:param`, `:'param'`)
95+
/// and replaces them with positional parameters (e.g., `$1`, `$2`, etc.).
96+
///
97+
/// It maintains the original spacing of the named parameters in the output string.
98+
///
99+
/// Useful for preparing SQL queries for parsing or execution where named paramters are not supported.
100+
pub fn convert_to_positional_params(text: &str) -> String {
101+
let mut result = String::with_capacity(text.len());
102+
let mut param_mapping: HashMap<&str, usize> = HashMap::new();
103+
let mut param_index = 1;
104+
let mut position = 0;
105+
106+
for token in tokenize(text) {
107+
let token_len = token.len as usize;
108+
let token_text = &text[position..position + token_len];
109+
110+
if matches!(token.kind, pgt_tokenizer::TokenKind::NamedParam { .. }) {
111+
let idx = match param_mapping.get(token_text) {
112+
Some(&index) => index,
113+
None => {
114+
let index = param_index;
115+
param_mapping.insert(token_text, index);
116+
param_index += 1;
117+
index
118+
}
119+
};
120+
121+
let replacement = format!("${}", idx);
122+
let original_len = token_text.len();
123+
let replacement_len = replacement.len();
124+
125+
result.push_str(&replacement);
126+
127+
// maintain original spacing
128+
if replacement_len < original_len {
129+
result.push_str(&" ".repeat(original_len - replacement_len));
130+
}
131+
} else {
132+
result.push_str(token_text);
133+
}
134+
135+
position += token_len;
136+
}
137+
138+
result
139+
}
140+
90141
#[cfg(test)]
91142
mod tests {
92143
use super::*;
93144

145+
/// Every distinct named parameter gets the next positional index, and each placeholder
/// is space-padded to the width of the token it replaced.
#[test]
fn test_convert_to_positional_params() {
    let input = "select * from users where id = @one and name = :two and email = :'three';";
    let result = convert_to_positional_params(input);

    // The widths are the byte lengths of `@one`, `:two` and `:'three'`. Spelling the
    // padding as a format width keeps it from being lost to whitespace normalisation.
    let expected = format!(
        "select * from users where id = {:<4} and name = {:<4} and email = {:<8};",
        "$1", "$2", "$3"
    );
    assert_eq!(result, expected);

    // Padding must leave the overall statement length untouched.
    assert_eq!(result.len(), input.len());
}
154+
155+
/// A named parameter that occurs more than once maps to the same positional index
/// each time; only new names advance the counter.
#[test]
fn test_convert_to_positional_params_with_duplicates() {
    let input = "select * from users where first_name = @one and starts_with(email, @one) and created_at > @two;";
    let result = convert_to_positional_params(input);

    // Each placeholder is padded to the four bytes of the `@one` / `@two` token it
    // replaces. Spelling the padding as a format width keeps it from being lost to
    // whitespace normalisation.
    let expected = format!(
        "select * from users where first_name = {:<4} and starts_with(email, {:<4}) and created_at > {:<4};",
        "$1", "$1", "$2"
    );
    assert_eq!(result, expected);

    // Padding must leave the overall statement length untouched.
    assert_eq!(result.len(), input.len());
}
164+
94165
#[test]
95166
fn test_plpgsql_syntax_error() {
96167
let input = "

0 commit comments

Comments
 (0)