Skip to content

Commit 56dbd1f

Browse files
committed
progress
1 parent e1ba582 commit 56dbd1f

File tree

6 files changed

+220
-24
lines changed

6 files changed

+220
-24
lines changed

crates/pgt_lexer/src/lexer.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ impl<'a> Lexer<'a> {
111111
pgt_tokenizer::TokenKind::Tilde => SyntaxKind::TILDE,
112112
pgt_tokenizer::TokenKind::Question => SyntaxKind::QUESTION,
113113
pgt_tokenizer::TokenKind::Colon => SyntaxKind::COLON,
114+
pgt_tokenizer::TokenKind::DoubleColon => SyntaxKind::DOUBLE_COLON,
114115
pgt_tokenizer::TokenKind::Eq => SyntaxKind::EQ,
115116
pgt_tokenizer::TokenKind::Bang => SyntaxKind::BANG,
116117
pgt_tokenizer::TokenKind::Lt => SyntaxKind::L_ANGLE,

crates/pgt_lexer_codegen/src/syntax_kind.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ const PUNCT: &[(&str, &str)] = &[
3737
("_", "UNDERSCORE"),
3838
(".", "DOT"),
3939
(":", "COLON"),
40+
("::", "DOUBLE_COLON"),
4041
("=", "EQ"),
4142
("!", "BANG"),
4243
("-", "MINUS"),

crates/pgt_tokenizer/src/lib.rs

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -144,32 +144,37 @@ impl Cursor<'_> {
144144
}
145145
}
146146
':' => {
147-
// Named parameters in psql with different substitution styles.
148-
//
149-
// https://www.postgresql.org/docs/current/app-psql.html#APP-PSQL-INTERPOLATION
150-
match self.first() {
151-
'\'' => {
152-
// Named parameter with colon prefix and single quotes.
153-
self.bump();
154-
let terminated = self.single_quoted_string();
155-
let kind = NamedParamKind::ColonString { terminated };
156-
TokenKind::NamedParam { kind }
157-
}
158-
'"' => {
159-
// Named parameter with colon prefix and double quotes.
160-
self.bump();
161-
let terminated = self.double_quoted_string();
162-
let kind = NamedParamKind::ColonIdentifier { terminated };
163-
TokenKind::NamedParam { kind }
164-
}
165-
c if is_ident_start(c) => {
166-
// Named parameter with colon prefix.
167-
self.eat_while(is_ident_cont);
168-
TokenKind::NamedParam {
169-
kind: NamedParamKind::ColonRaw,
147+
if self.first() == ':' {
148+
self.bump();
149+
TokenKind::DoubleColon
150+
} else {
151+
// Named parameters in psql with different substitution styles.
152+
//
153+
// https://www.postgresql.org/docs/current/app-psql.html#APP-PSQL-INTERPOLATION
154+
match self.first() {
155+
'\'' => {
156+
// Named parameter with colon prefix and single quotes.
157+
self.bump();
158+
let terminated = self.single_quoted_string();
159+
let kind = NamedParamKind::ColonString { terminated };
160+
TokenKind::NamedParam { kind }
170161
}
162+
'"' => {
163+
// Named parameter with colon prefix and double quotes.
164+
self.bump();
165+
let terminated = self.double_quoted_string();
166+
let kind = NamedParamKind::ColonIdentifier { terminated };
167+
TokenKind::NamedParam { kind }
168+
}
169+
c if is_ident_start(c) => {
170+
// Named parameter with colon prefix.
171+
self.eat_while(is_ident_cont);
172+
TokenKind::NamedParam {
173+
kind: NamedParamKind::ColonRaw,
174+
}
175+
}
176+
_ => TokenKind::Colon,
171177
}
172-
_ => TokenKind::Colon,
173178
}
174179
}
175180
// One-symbol tokens.
@@ -664,6 +669,23 @@ mod tests {
664669
assert_debug_snapshot!(result);
665670
}
666671

672+
#[test]
673+
fn debug_simple_cast() {
674+
let result = lex("::test");
675+
assert_debug_snapshot!(result, @r###"
676+
[
677+
"::" @ DoubleColon,
678+
"test" @ Ident,
679+
]
680+
"###);
681+
}
682+
683+
#[test]
684+
fn named_param_colon_raw_vs_cast() {
685+
let result = lex("select 1 from c where id::test = :id;");
686+
assert_debug_snapshot!(result);
687+
}
688+
667689
#[test]
668690
fn named_param_colon_string() {
669691
let result = lex("select 1 from c where id = :'id';");
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
---
2+
source: crates/pgt_tokenizer/src/lib.rs
3+
expression: result
4+
snapshot_kind: text
5+
---
6+
[
7+
"select" @ Ident,
8+
" " @ Space,
9+
"1" @ Literal { kind: Int { base: Decimal, empty_int: false } },
10+
" " @ Space,
11+
"from" @ Ident,
12+
" " @ Space,
13+
"c" @ Ident,
14+
" " @ Space,
15+
"where" @ Ident,
16+
" " @ Space,
17+
"id" @ Ident,
18+
"::" @ DoubleColon,
19+
"test" @ Ident,
20+
" " @ Space,
21+
"=" @ Eq,
22+
" " @ Space,
23+
":id" @ NamedParam { kind: ColonRaw },
24+
";" @ Semi,
25+
]

crates/pgt_tokenizer/src/token.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ pub enum TokenKind {
4646
Minus,
4747
/// `:`
4848
Colon,
49+
/// `::`
50+
DoubleColon,
4951
/// `.`
5052
Dot,
5153
/// `=`

crates/pgt_workspace/src/workspace/server/pg_query.rs

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,151 @@ mod tests {
142142
);
143143
}
144144

145+
#[test]
146+
fn test_convert_to_positional_params_complex() {
147+
let input = r#"CREATE OR REPLACE FUNCTION private.create_contacts_query(
148+
v_and_filter jsonb DEFAULT NULL::jsonb,
149+
v_include_filter jsonb DEFAULT NULL::jsonb,
150+
v_exclude_filter jsonb DEFAULT NULL::jsonb,
151+
v_require_marketing_opt_in boolean DEFAULT false,
152+
v_require_transactional_opt_in boolean DEFAULT false,
153+
v_include_contacts jsonb DEFAULT NULL::jsonb,
154+
v_exclude_contacts jsonb DEFAULT NULL::jsonb,
155+
v_channel_types public.channel_type[] DEFAULT NULL::public.channel_type[],
156+
v_organisation_id uuid DEFAULT NULL::uuid,
157+
v_include_blocked boolean DEFAULT false,
158+
v_include_segment_ids uuid[] DEFAULT NULL::uuid[],
159+
v_exclude_segment_ids uuid[] DEFAULT NULL::uuid[],
160+
v_include_contact_list_ids uuid[] DEFAULT NULL::uuid[],
161+
v_exclude_contact_list_ids uuid[] DEFAULT NULL::uuid[],
162+
v_columns text[] DEFAULT NULL::text[],
163+
-- expects an array of objects with `{"name": "column_name", "type": "column_type"}` format
164+
-- used to include fields not present on the contact table such as `placeholder_values`
165+
v_extra_fields jsonb DEFAULT NULL,
166+
v_count_only boolean DEFAULT false,
167+
-- below are fields that are only used in the UI
168+
-- search is pushed down to the include subqueries
169+
v_search text DEFAULT NULL,
170+
-- order by is only allowed if v_limit is set
171+
"v_order_by" "public"."column_sort"[] DEFAULT NULL::"public"."column_sort"[],
172+
v_limit integer DEFAULT NULL,
173+
v_offset integer DEFAULT NULL
174+
)
175+
RETURNS text
176+
LANGUAGE plpgsql
177+
STABLE
178+
SET search_path TO ''
179+
AS $function$
180+
declare
181+
v_channel_types_specified public.channel_type[] := (
182+
select array_agg(channel_type)
183+
from unnest(coalesce(nullif(v_channel_types, '{}'), enum_range(null::public.channel_type))) channel_type
184+
);
185+
-- we need to include columns we filter on top-level
186+
v_include_subqueries text[];
187+
v_exclude_subqueries text[];
188+
v_where_clauses text[];
189+
begin
190+
if cardinality(v_columns) = 0 and v_count_only is not true then
191+
raise exception using
192+
message = 'No columns provided',
193+
hint = 'Please pass v_columns',
194+
errcode = 'INVIN';
195+
end if;
196+
197+
if v_order_by is not null and v_limit is null then
198+
raise exception using
199+
message = 'v_order_by is only allowed if v_limit is set',
200+
hint = 'Please pass v_limit',
201+
errcode = 'INVIN';
202+
end if;
203+
204+
if v_limit is not null and v_limit > 50 then
205+
raise exception using
206+
message = 'v_limit is too high',
207+
hint = 'Please pass a v_limit of 50 or lower',
208+
errcode = 'INVIN';
209+
end if;
210+
211+
v_where_clauses := array_remove(
212+
array[
213+
-- is_blocked filter
214+
(case when v_include_blocked is true then null else 'is_blocked is not true' end),
215+
-- opt in filter
216+
(
217+
case
218+
when v_require_marketing_opt_in is true then
219+
'(' || (select string_agg(format('(%s is not null and %s_marketing_opt_in is true)', ct.type, ct.type), ' or ') from unnest(coalesce(v_channel_types, enum_range(null::public.channel_type))) ct(type)) || ')'
220+
when v_require_transactional_opt_in is true then
221+
'(' || (select string_agg(format('(%s is not null and %s_transactional_opt_in is not null)', channel_type, channel_type), ' or ') from unnest(coalesce(v_channel_types, enum_range(null::public.channel_type))) channel_type) || ')'
222+
when v_channel_types is not null and cardinality(v_channel_types) > 0 then
223+
'(' || (select string_agg(format('(%s is not null)', channel_type), ' or ') from unnest(v_channel_types) channel_type) || ')'
224+
else null
225+
end
226+
),
227+
-- organisation_id filter
228+
format('"c"."organisation_id" = %L', v_organisation_id),
229+
-- search filter
230+
(case when nullif(v_search, '') is not null then
231+
format('"c"."fts" @@ to_tsquery(''simple'', %L)', v_search)
232+
end)
233+
234+
],
235+
null
236+
);
237+
238+
-- select cols from public.contact
239+
-- left join with contacts include / exclude contacts
240+
--
241+
-- where <where clauses>
242+
243+
return format(
244+
$sql$
245+
select %s
246+
from public.contact
247+
-- joins
248+
-- where clause
249+
%s
250+
-- order by clause
251+
%s
252+
-- limit + offset
253+
%s %s
254+
$sql$,
255+
(select string_agg(col, ', ') from unnest(v_columns) as col),
256+
-- joins
257+
-- where clause
258+
(case when cardinality(v_where_clauses) > 0 then
259+
'where ' || array_to_string(v_where_clauses, ' and ')
260+
else
261+
''
262+
end),
263+
-- order by
264+
case when v_order_by is not null and cardinality(v_order_by) > 0 then
265+
format('order by %s', (
266+
select string_agg(format('%I.%I %s', 'contact', (x).column_name, case when (x).descending then 'desc' else 'asc' end), ', ')
267+
from unnest(v_order_by) as x
268+
where (x).column_name in (select column_name from information_schema.columns where table_schema = 'public' and table_name = 'contact')
269+
)
270+
)
271+
when v_limit is not null then
272+
'order by c.id' -- we need to order by something if we have a limit and offset, otherwise the results are not deterministic
273+
else null
274+
end,
275+
-- limit
276+
(case when v_limit is not null then format('limit %L', v_limit) end),
277+
-- offset
278+
(case when v_offset is not null then format('offset %L', v_offset) end)
279+
);
280+
end
281+
$function$
282+
; "#;
283+
let result = convert_to_positional_params(input);
284+
assert_eq!(
285+
result,
286+
"select * from users where id = $1 and name = $2 and email = $3 ;"
287+
);
288+
}
289+
145290
#[test]
146291
fn test_plpgsql_syntax_error() {
147292
let input = "

0 commit comments

Comments
 (0)