From d51b9ca86835cb91af720ae080e5ddca5b884441 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Tue, 12 Aug 2025 08:28:07 +0200 Subject: [PATCH 1/5] fix: positional params --- Cargo.lock | 1 + crates/pgt_workspace/Cargo.toml | 1 + crates/pgt_workspace/src/workspace/server.rs | 5 +- .../src/workspace/server.tests.rs | 54 +++++++++++++++++++ .../src/workspace/server/pg_query.rs | 53 +++++++++++++++++- 5 files changed, 110 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 49143908b..94b591f3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3189,6 +3189,7 @@ dependencies = [ "pgt_suppressions", "pgt_test_utils", "pgt_text_size", + "pgt_tokenizer", "pgt_typecheck", "pgt_workspace_macros", "rustc-hash 2.1.0", diff --git a/crates/pgt_workspace/Cargo.toml b/crates/pgt_workspace/Cargo.toml index efded47c7..860b51331 100644 --- a/crates/pgt_workspace/Cargo.toml +++ b/crates/pgt_workspace/Cargo.toml @@ -33,6 +33,7 @@ pgt_schema_cache = { workspace = true } pgt_statement_splitter = { workspace = true } pgt_suppressions = { workspace = true } pgt_text_size.workspace = true +pgt_tokenizer = { workspace = true } pgt_typecheck = { workspace = true } pgt_workspace_macros = { workspace = true } rustc-hash = { workspace = true } diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index f4a9561f2..49c306f2b 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -14,6 +14,7 @@ use document::{ TypecheckDiagnosticsMapper, }; use futures::{StreamExt, stream}; +use pg_query::convert_to_positional_params; use pgt_analyse::{AnalyserOptions, AnalysisFilter}; use pgt_analyser::{Analyser, AnalyserConfig, AnalyserParams}; use pgt_diagnostics::{ @@ -468,7 +469,7 @@ impl Workspace for WorkspaceServer { // Type checking let typecheck_result = pgt_typecheck::check_sql(TypecheckParams { conn: &pool, - sql: id.content(), + sql: convert_to_positional_params(id.content()).as_str(), ast: &ast, tree: &cst, schema_cache: schema_cache.as_ref(), @@ -511,8 +512,6 @@ impl Workspace for WorkspaceServer { .await .unwrap_or_else(|_| vec![]); - println!("{:#?}", plpgsql_check_results); - for d in plpgsql_check_results { let r = d.span.map(|span| span + range.start()); diagnostics.push( diff --git a/crates/pgt_workspace/src/workspace/server.tests.rs b/crates/pgt_workspace/src/workspace/server.tests.rs index ef5ba2677..894d10426 100644 --- a/crates/pgt_workspace/src/workspace/server.tests.rs +++ b/crates/pgt_workspace/src/workspace/server.tests.rs @@ -277,3 +277,57 @@ async fn test_dedupe_diagnostics(test_db: PgPool) { Some(TextRange::new(115.into(), 210.into())) ); } + +#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] +async fn test_positional_params(test_db: PgPool) { + let mut conf = PartialConfiguration::init(); + conf.merge_with(PartialConfiguration { + db: Some(PartialDatabaseConfiguration { + database: Some( + test_db + .connect_options() + .get_database() + .unwrap() + .to_string(), + ), + ..Default::default() + }), + ..Default::default() + }); + + let workspace = get_test_workspace(Some(conf)).expect("Unable to create test workspace"); + + let path = PgTPath::new("test.sql"); + + let setup_sql = r" + create table users ( + id serial primary key, + name text not null, + email text not null + ); + "; + test_db.execute(setup_sql).await.expect("setup sql failed"); + + let content = r#"select * from users where id = @one and name = :two and email = :'three';"#; + + workspace + .open_file(OpenFileParams { + path: path.clone(), + content: content.into(), + version: 1, + }) + .expect("Unable to open test file"); + + let diagnostics = workspace + .pull_diagnostics(crate::workspace::PullDiagnosticsParams { + path: path.clone(), + categories: RuleCategories::all(), + max_diagnostics: 100, + only: vec![], + skip: vec![], + }) + .expect("Unable to pull diagnostics") + .diagnostics; + + assert_eq!(diagnostics.len(), 0, "Expected no diagnostic"); +} diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs index 05f1425dc..81f3bf2d1 100644 --- a/crates/pgt_workspace/src/workspace/server/pg_query.rs +++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs @@ -4,6 +4,7 @@ use std::sync::{Arc, Mutex}; use lru::LruCache; use pgt_query_ext::diagnostics::*; use pgt_text_size::TextRange; +use pgt_tokenizer::tokenize; use super::statement_identifier::StatementId; @@ -37,7 +38,7 @@ impl PgQueryStore { } let r = Arc::new( - pgt_query::parse(statement.content()) + pgt_query::parse(&convert_to_positional_params(statement.content())) .map_err(SyntaxDiagnostic::from) .and_then(|ast| { ast.into_root().ok_or_else(|| { @@ -87,10 +88,60 @@ impl PgQueryStore { } } +/// Converts named parameters in a SQL query string to positional parameters. +/// +/// This function scans the input SQL string for named parameters (e.g., `@param`, `:param`, `:'param'`) +/// and replaces them with positional parameters (e.g., `$1`, `$2`, etc.). +/// +/// It maintains the original spacing of the named parameters in the output string. +/// +/// Useful for preparing SQL queries for parsing or execution where named paramters are not supported. +pub fn convert_to_positional_params(text: &str) -> String { + let mut result = String::new(); + let mut param_index = 1; + let mut position = 0; + + for token in tokenize(text) { + let token_len = token.len as usize; + let token_text = &text[position..position + token_len]; + + if matches!(token.kind, pgt_tokenizer::TokenKind::NamedParam { .. }) { + let replacement = format!("${}", param_index); + let original_len = token_text.len(); + let replacement_len = replacement.len(); + + result.push_str(&replacement); + + // maintain original spacing + if replacement_len < original_len { + result.push_str(&" ".repeat(original_len - replacement_len)); + } + + param_index += 1; + } else { + result.push_str(token_text); + } + + position += token_len; + } + + result +} + #[cfg(test)] mod tests { use super::*; + #[test] + fn test_convert_to_positional_params() { + let input = "select * from users where id = @one and name = :two and email = :'three';"; + let result = convert_to_positional_params(input); + assert_eq!( + result, + "select * from users where id = $1 and name = $2 and email = $3;" + ); + } + #[test] fn test_plpgsql_syntax_error() { let input = " From e1ba58299578c2a488c9d0ca7f9d7cda2e561b41 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Tue, 12 Aug 2025 08:42:00 +0200 Subject: [PATCH 2/5] progress --- crates/pgt_workspace/src/workspace/server/pg_query.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs index 81f3bf2d1..872987255 100644 --- a/crates/pgt_workspace/src/workspace/server/pg_query.rs +++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs @@ -138,7 +138,7 @@ mod tests { let result = convert_to_positional_params(input); assert_eq!( result, - "select * from users where id = $1 and name = $2 and email = $3;" + "select * from users where id = $1 and name = $2 and email = $3 ;" ); } From 56dbd1f22c7fee7e32b63553413be759daac7725 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Tue, 12 Aug 2025 16:17:39 +0200 Subject: [PATCH 3/5] progress --- crates/pgt_lexer/src/lexer.rs | 1 + crates/pgt_lexer_codegen/src/syntax_kind.rs | 1 + crates/pgt_tokenizer/src/lib.rs | 70 ++++++--- ..._tests__named_param_colon_raw_vs_cast.snap | 25 +++ crates/pgt_tokenizer/src/token.rs | 2 + .../src/workspace/server/pg_query.rs | 145 ++++++++++++++++++ 6 files changed, 220 insertions(+), 24 deletions(-) create mode 100644 crates/pgt_tokenizer/src/snapshots/pgt_tokenizer__tests__named_param_colon_raw_vs_cast.snap diff --git a/crates/pgt_lexer/src/lexer.rs b/crates/pgt_lexer/src/lexer.rs index ad6db297c..3e6912295 100644 --- a/crates/pgt_lexer/src/lexer.rs +++ b/crates/pgt_lexer/src/lexer.rs @@ -111,6 +111,7 @@ impl<'a> Lexer<'a> { pgt_tokenizer::TokenKind::Tilde => SyntaxKind::TILDE, pgt_tokenizer::TokenKind::Question => SyntaxKind::QUESTION, pgt_tokenizer::TokenKind::Colon => SyntaxKind::COLON, + pgt_tokenizer::TokenKind::DoubleColon => SyntaxKind::DOUBLE_COLON, pgt_tokenizer::TokenKind::Eq => SyntaxKind::EQ, pgt_tokenizer::TokenKind::Bang => SyntaxKind::BANG, pgt_tokenizer::TokenKind::Lt => SyntaxKind::L_ANGLE, diff --git a/crates/pgt_lexer_codegen/src/syntax_kind.rs b/crates/pgt_lexer_codegen/src/syntax_kind.rs index c671e4510..3a0054374 100644 --- a/crates/pgt_lexer_codegen/src/syntax_kind.rs +++ b/crates/pgt_lexer_codegen/src/syntax_kind.rs @@ -37,6 +37,7 @@ const PUNCT: &[(&str, &str)] = &[ ("_", "UNDERSCORE"), (".", "DOT"), (":", "COLON"), + ("::", "DOUBLE_COLON"), ("=", "EQ"), ("!", "BANG"), ("-", "MINUS"), diff --git a/crates/pgt_tokenizer/src/lib.rs b/crates/pgt_tokenizer/src/lib.rs index 80b663630..4b83d435d 100644 --- a/crates/pgt_tokenizer/src/lib.rs +++ b/crates/pgt_tokenizer/src/lib.rs @@ -144,32 +144,37 @@ impl Cursor<'_> { } } ':' => { - // Named parameters in psql with different substitution styles. - // - // https://www.postgresql.org/docs/current/app-psql.html#APP-PSQL-INTERPOLATION - match self.first() { - '\'' => { - // Named parameter with colon prefix and single quotes. - self.bump(); - let terminated = self.single_quoted_string(); - let kind = NamedParamKind::ColonString { terminated }; - TokenKind::NamedParam { kind } - } - '"' => { - // Named parameter with colon prefix and double quotes. - self.bump(); - let terminated = self.double_quoted_string(); - let kind = NamedParamKind::ColonIdentifier { terminated }; - TokenKind::NamedParam { kind } - } - c if is_ident_start(c) => { - // Named parameter with colon prefix. - self.eat_while(is_ident_cont); - TokenKind::NamedParam { - kind: NamedParamKind::ColonRaw, + if self.first() == ':' { + self.bump(); + TokenKind::DoubleColon + } else { + // Named parameters in psql with different substitution styles. + // + // https://www.postgresql.org/docs/current/app-psql.html#APP-PSQL-INTERPOLATION + match self.first() { + '\'' => { + // Named parameter with colon prefix and single quotes. + self.bump(); + let terminated = self.single_quoted_string(); + let kind = NamedParamKind::ColonString { terminated }; + TokenKind::NamedParam { kind } } + '"' => { + // Named parameter with colon prefix and double quotes. + self.bump(); + let terminated = self.double_quoted_string(); + let kind = NamedParamKind::ColonIdentifier { terminated }; + TokenKind::NamedParam { kind } + } + c if is_ident_start(c) => { + // Named parameter with colon prefix. + self.eat_while(is_ident_cont); + TokenKind::NamedParam { + kind: NamedParamKind::ColonRaw, + } + } + _ => TokenKind::Colon, } - _ => TokenKind::Colon, } } // One-symbol tokens. @@ -664,6 +669,23 @@ mod tests { assert_debug_snapshot!(result); } + #[test] + fn debug_simple_cast() { + let result = lex("::test"); + assert_debug_snapshot!(result, @r###" + [ + "::" @ DoubleColon, + "test" @ Ident, + ] + "###); + } + + #[test] + fn named_param_colon_raw_vs_cast() { + let result = lex("select 1 from c where id::test = :id;"); + assert_debug_snapshot!(result); + } + #[test] fn named_param_colon_string() { let result = lex("select 1 from c where id = :'id';"); diff --git a/crates/pgt_tokenizer/src/snapshots/pgt_tokenizer__tests__named_param_colon_raw_vs_cast.snap b/crates/pgt_tokenizer/src/snapshots/pgt_tokenizer__tests__named_param_colon_raw_vs_cast.snap new file mode 100644 index 000000000..ecfd48212 --- /dev/null +++ b/crates/pgt_tokenizer/src/snapshots/pgt_tokenizer__tests__named_param_colon_raw_vs_cast.snap @@ -0,0 +1,25 @@ +--- +source: crates/pgt_tokenizer/src/lib.rs +expression: result +snapshot_kind: text +--- +[ + "select" @ Ident, + " " @ Space, + "1" @ Literal { kind: Int { base: Decimal, empty_int: false } }, + " " @ Space, + "from" @ Ident, + " " @ Space, + "c" @ Ident, + " " @ Space, + "where" @ Ident, + " " @ Space, + "id" @ Ident, + "::" @ DoubleColon, + "test" @ Ident, + " " @ Space, + "=" @ Eq, + " " @ Space, + ":id" @ NamedParam { kind: ColonRaw }, + ";" @ Semi, +] diff --git a/crates/pgt_tokenizer/src/token.rs b/crates/pgt_tokenizer/src/token.rs index e3dbaee28..4d970dc5f 100644 --- a/crates/pgt_tokenizer/src/token.rs +++ b/crates/pgt_tokenizer/src/token.rs @@ -46,6 +46,8 @@ pub enum TokenKind { Minus, /// `:` Colon, + /// `::` + DoubleColon, /// `.` Dot, /// `=` diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs index 872987255..6827864d8 100644 --- a/crates/pgt_workspace/src/workspace/server/pg_query.rs +++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs @@ -142,6 +142,151 @@ mod tests { ); } + #[test] + fn test_convert_to_positional_params_complex() { + let input = r#"CREATE OR REPLACE FUNCTION private.create_contacts_query( + v_and_filter jsonb DEFAULT NULL::jsonb, + v_include_filter jsonb DEFAULT NULL::jsonb, + v_exclude_filter jsonb DEFAULT NULL::jsonb, + v_require_marketing_opt_in boolean DEFAULT false, + v_require_transactional_opt_in boolean DEFAULT false, + v_include_contacts jsonb DEFAULT NULL::jsonb, + v_exclude_contacts jsonb DEFAULT NULL::jsonb, + v_channel_types public.channel_type[] DEFAULT NULL::public.channel_type[], + v_organisation_id uuid DEFAULT NULL::uuid, + v_include_blocked boolean DEFAULT false, + v_include_segment_ids uuid[] DEFAULT NULL::uuid[], + v_exclude_segment_ids uuid[] DEFAULT NULL::uuid[], + v_include_contact_list_ids uuid[] DEFAULT NULL::uuid[], + v_exclude_contact_list_ids uuid[] DEFAULT NULL::uuid[], + v_columns text[] DEFAULT NULL::text[], + -- expects an array of objects with `{"name": "column_name", "type": "column_type"}` format + -- used to include fields not present on the contact table such as `placeholder_values` + v_extra_fields jsonb DEFAULT NULL, + v_count_only boolean DEFAULT false, + -- below are fields that are only used in the UI + -- search is pushed down to the include subqueries + v_search text DEFAULT NULL, + -- order by is only allowed if v_limit is set + "v_order_by" "public"."column_sort"[] DEFAULT NULL::"public"."column_sort"[], + v_limit integer DEFAULT NULL, + v_offset integer DEFAULT NULL +) + RETURNS text + LANGUAGE plpgsql + STABLE + SET search_path TO '' +AS $function$ +declare + v_channel_types_specified public.channel_type[] := ( + select array_agg(channel_type) + from unnest(coalesce(nullif(v_channel_types, '{}'), enum_range(null::public.channel_type))) channel_type + ); + -- we need to include columns we filter on top-level + v_include_subqueries text[]; + v_exclude_subqueries text[]; + v_where_clauses text[]; +begin + if cardinality(v_columns) = 0 and v_count_only is not true then + raise exception using + message = 'No columns provided', + hint = 'Please pass v_columns', + errcode = 'INVIN'; + end if; + + if v_order_by is not null and v_limit is null then + raise exception using + message = 'v_order_by is only allowed if v_limit is set', + hint = 'Please pass v_limit', + errcode = 'INVIN'; + end if; + + if v_limit is not null and v_limit > 50 then + raise exception using + message = 'v_limit is too high', + hint = 'Please pass a v_limit of 50 or lower', + errcode = 'INVIN'; + end if; + + v_where_clauses := array_remove( + array[ + -- is_blocked filter + (case when v_include_blocked is true then null else 'is_blocked is not true' end), + -- opt in filter + ( + case + when v_require_marketing_opt_in is true then + '(' || (select string_agg(format('(%s is not null and %s_marketing_opt_in is true)', ct.type, ct.type), ' or ') from unnest(coalesce(v_channel_types, enum_range(null::public.channel_type))) ct(type)) || ')' + when v_require_transactional_opt_in is true then + '(' || (select string_agg(format('(%s is not null and %s_transactional_opt_in is not null)', channel_type, channel_type), ' or ') from unnest(coalesce(v_channel_types, enum_range(null::public.channel_type))) channel_type) || ')' + when v_channel_types is not null and cardinality(v_channel_types) > 0 then + '(' || (select string_agg(format('(%s is not null)', channel_type), ' or ') from unnest(v_channel_types) channel_type) || ')' + else null + end + ), + -- organisation_id filter + format('"c"."organisation_id" = %L', v_organisation_id), + -- search filter + (case when nullif(v_search, '') is not null then + format('"c"."fts" @@ to_tsquery(''simple'', %L)', v_search) + end) + + ], + null + ); + + -- select cols from public.contact + -- left join with contacts include / exclude contacts + -- + -- where + + return format( + $sql$ + select %s + from public.contact + -- joins + -- where clause + %s + -- order by clause + %s + -- limit + offset + %s %s + $sql$, + (select string_agg(col, ', ') from unnest(v_columns) as col), + -- joins + -- where clause + (case when cardinality(v_where_clauses) > 0 then + 'where ' || array_to_string(v_where_clauses, ' and ') + else + '' + end), + -- order by + case when v_order_by is not null and cardinality(v_order_by) > 0 then + format('order by %s', ( + select string_agg(format('%I.%I %s', 'contact', (x).column_name, case when (x).descending then 'desc' else 'asc' end), ', ') + from unnest(v_order_by) as x + where (x).column_name in (select column_name from information_schema.columns where table_schema = 'public' and table_name = 'contact') + ) + ) + when v_limit is not null then + 'order by c.id' -- we need to order by something if we have a limit and offset, otherwise the results are not deterministic + else null + end, + -- limit + (case when v_limit is not null then format('limit %L', v_limit) end), + -- offset + (case when v_offset is not null then format('offset %L', v_offset) end) + ); +end +$function$ +; "#; + let result = convert_to_positional_params(input); + assert_eq!( + result, + "select * from users where id = $1 and name = $2 and email = $3 ;" + ); + } + #[test] fn test_plpgsql_syntax_error() { let input = " From 79f8dfa69233b76b3da6fcdfbaf5e6ff2f82dfe7 Mon Sep 17 00:00:00 2001 From: psteinroe Date: Tue, 12 Aug 2025 23:45:06 +0200 Subject: [PATCH 4/5] progress --- .../src/workspace/server/pg_query.rs | 145 ------------------ 1 file changed, 145 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs index 6827864d8..872987255 100644 --- a/crates/pgt_workspace/src/workspace/server/pg_query.rs +++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs @@ -142,151 +142,6 @@ mod tests { ); } - #[test] - fn test_convert_to_positional_params_complex() { - let input = r#"CREATE OR REPLACE FUNCTION private.create_contacts_query( - v_and_filter jsonb DEFAULT NULL::jsonb, - v_include_filter jsonb DEFAULT NULL::jsonb, - v_exclude_filter jsonb DEFAULT NULL::jsonb, - v_require_marketing_opt_in boolean DEFAULT false, - v_require_transactional_opt_in boolean DEFAULT false, - v_include_contacts jsonb DEFAULT NULL::jsonb, - v_exclude_contacts jsonb DEFAULT NULL::jsonb, - v_channel_types public.channel_type[] DEFAULT NULL::public.channel_type[], - v_organisation_id uuid DEFAULT NULL::uuid, - v_include_blocked boolean DEFAULT false, - v_include_segment_ids uuid[] DEFAULT NULL::uuid[], - v_exclude_segment_ids uuid[] DEFAULT NULL::uuid[], - v_include_contact_list_ids uuid[] DEFAULT NULL::uuid[], - v_exclude_contact_list_ids uuid[] DEFAULT NULL::uuid[], - v_columns text[] DEFAULT NULL::text[], - -- expects an array of objects with `{"name": "column_name", "type": "column_type"}` format - -- used to include fields not present on the contact table such as `placeholder_values` - v_extra_fields jsonb DEFAULT NULL, - v_count_only boolean DEFAULT false, - -- below are fields that are only used in the UI - -- search is pushed down to the include subqueries - v_search text DEFAULT NULL, - -- order by is only allowed if v_limit is set - "v_order_by" "public"."column_sort"[] DEFAULT NULL::"public"."column_sort"[], - v_limit integer DEFAULT NULL, - v_offset integer DEFAULT NULL -) - RETURNS text - LANGUAGE plpgsql - STABLE - SET search_path TO '' -AS $function$ -declare - v_channel_types_specified public.channel_type[] := ( - select array_agg(channel_type) - from unnest(coalesce(nullif(v_channel_types, '{}'), enum_range(null::public.channel_type))) channel_type - ); - -- we need to include columns we filter on top-level - v_include_subqueries text[]; - v_exclude_subqueries text[]; - v_where_clauses text[]; -begin - if cardinality(v_columns) = 0 and v_count_only is not true then - raise exception using - message = 'No columns provided', - hint = 'Please pass v_columns', - errcode = 'INVIN'; - end if; - - if v_order_by is not null and v_limit is null then - raise exception using - message = 'v_order_by is only allowed if v_limit is set', - hint = 'Please pass v_limit', - errcode = 'INVIN'; - end if; - - if v_limit is not null and v_limit > 50 then - raise exception using - message = 'v_limit is too high', - hint = 'Please pass a v_limit of 50 or lower', - errcode = 'INVIN'; - end if; - - v_where_clauses := array_remove( - array[ - -- is_blocked filter - (case when v_include_blocked is true then null else 'is_blocked is not true' end), - -- opt in filter - ( - case - when v_require_marketing_opt_in is true then - '(' || (select string_agg(format('(%s is not null and %s_marketing_opt_in is true)', ct.type, ct.type), ' or ') from unnest(coalesce(v_channel_types, enum_range(null::public.channel_type))) ct(type)) || ')' - when v_require_transactional_opt_in is true then - '(' || (select string_agg(format('(%s is not null and %s_transactional_opt_in is not null)', channel_type, channel_type), ' or ') from unnest(coalesce(v_channel_types, enum_range(null::public.channel_type))) channel_type) || ')' - when v_channel_types is not null and cardinality(v_channel_types) > 0 then - '(' || (select string_agg(format('(%s is not null)', channel_type), ' or ') from unnest(v_channel_types) channel_type) || ')' - else null - end - ), - -- organisation_id filter - format('"c"."organisation_id" = %L', v_organisation_id), - -- search filter - (case when nullif(v_search, '') is not null then - format('"c"."fts" @@ to_tsquery(''simple'', %L)', v_search) - end) - - ], - null - ); - - -- select cols from public.contact - -- left join with contacts include / exclude contacts - -- - -- where - - return format( - $sql$ - select %s - from public.contact - -- joins - -- where clause - %s - -- order by clause - %s - -- limit + offset - %s %s - $sql$, - (select string_agg(col, ', ') from unnest(v_columns) as col), - -- joins - -- where clause - (case when cardinality(v_where_clauses) > 0 then - 'where ' || array_to_string(v_where_clauses, ' and ') - else - '' - end), - -- order by - case when v_order_by is not null and cardinality(v_order_by) > 0 then - format('order by %s', ( - select string_agg(format('%I.%I %s', 'contact', (x).column_name, case when (x).descending then 'desc' else 'asc' end), ', ') - from unnest(v_order_by) as x - where (x).column_name in (select column_name from information_schema.columns where table_schema = 'public' and table_name = 'contact') - ) - ) - when v_limit is not null then - 'order by c.id' -- we need to order by something if we have a limit and offset, otherwise the results are not deterministic - else null - end, - -- limit - (case when v_limit is not null then format('limit %L', v_limit) end), - -- offset - (case when v_offset is not null then format('offset %L', v_offset) end) - ); -end -$function$ -; "#; - let result = convert_to_positional_params(input); - assert_eq!( - result, - "select * from users where id = $1 and name = $2 and email = $3 ;" - ); - } - #[test] fn test_plpgsql_syntax_error() { let input = " From 0df918846230bb49107de2a8d2663824b672359a Mon Sep 17 00:00:00 2001 From: psteinroe Date: Thu, 14 Aug 2025 14:19:36 +0200 Subject: [PATCH 5/5] progress --- .../src/workspace/server/pg_query.rs | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs index 872987255..bd9ffdfce 100644 --- a/crates/pgt_workspace/src/workspace/server/pg_query.rs +++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::num::NonZeroUsize; use std::sync::{Arc, Mutex}; @@ -97,7 +98,8 @@ impl PgQueryStore { /// /// Useful for preparing SQL queries for parsing or execution where named paramters are not supported. pub fn convert_to_positional_params(text: &str) -> String { - let mut result = String::new(); + let mut result = String::with_capacity(text.len()); + let mut param_mapping: HashMap<&str, usize> = HashMap::new(); let mut param_index = 1; let mut position = 0; @@ -106,7 +108,17 @@ pub fn convert_to_positional_params(text: &str) -> String { let token_text = &text[position..position + token_len]; if matches!(token.kind, pgt_tokenizer::TokenKind::NamedParam { .. }) { - let replacement = format!("${}", param_index); + let idx = match param_mapping.get(token_text) { + Some(&index) => index, + None => { + let index = param_index; + param_mapping.insert(token_text, index); + param_index += 1; + index + } + }; + + let replacement = format!("${}", idx); let original_len = token_text.len(); let replacement_len = replacement.len(); @@ -116,8 +128,6 @@ pub fn convert_to_positional_params(text: &str) -> String { if replacement_len < original_len { result.push_str(&" ".repeat(original_len - replacement_len)); } - - param_index += 1; } else { result.push_str(token_text); } @@ -142,6 +152,16 @@ mod tests { ); } + #[test] + fn test_convert_to_positional_params_with_duplicates() { + let input = "select * from users where first_name = @one and starts_with(email, @one) and created_at > @two;"; + let result = convert_to_positional_params(input); + assert_eq!( + result, + "select * from users where first_name = $1 and starts_with(email, $1 ) and created_at > $2 ;" + ); + } + #[test] fn test_plpgsql_syntax_error() { let input = "