Skip to content

Commit d37d7dc

Browse files
committed
feat: WITHIN GROUP expression support
1 parent 6a54d27 commit d37d7dc

File tree

4 files changed

+119
-47
lines changed

4 files changed

+119
-47
lines changed

.github/workflows/rust.yml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,7 @@ jobs:
5656
with:
5757
rust-version: ${{ matrix.rust }}
5858
- name: Install Tarpaulin
59-
uses: actions-rs/[email protected]
60-
with:
61-
crate: cargo-tarpaulin
62-
version: 0.14.2
63-
use-tool-cache: true
59+
run: cargo install --version 0.14.2 --features vendored-openssl cargo-tarpaulin
6460
- name: Checkout
6561
uses: actions/checkout@v2
6662
- name: Test

src/ast/mod.rs

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,8 @@ pub enum Expr {
348348
ListAgg(ListAgg),
349349
/// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)`
350350
ArrayAgg(ArrayAgg),
351+
/// The `WITHIN GROUP` expr `... WITHIN GROUP (ORDER BY ...)`
352+
WithinGroup(WithinGroup),
351353
/// The `GROUPING SETS` expr.
352354
GroupingSets(Vec<Vec<Expr>>),
353355
/// The `CUBE` expr.
@@ -549,6 +551,7 @@ impl fmt::Display for Expr {
549551
Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s),
550552
Expr::ListAgg(listagg) => write!(f, "{}", listagg),
551553
Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg),
554+
Expr::WithinGroup(withingroup) => write!(f, "{}", withingroup),
552555
Expr::GroupingSets(sets) => {
553556
write!(f, "GROUPING SETS (")?;
554557
let mut sep = "";
@@ -2420,7 +2423,6 @@ pub struct ListAgg {
24202423
pub expr: Box<Expr>,
24212424
pub separator: Option<Box<Expr>>,
24222425
pub on_overflow: Option<ListAggOnOverflow>,
2423-
pub within_group: Vec<OrderByExpr>,
24242426
}
24252427

24262428
impl fmt::Display for ListAgg {
@@ -2438,13 +2440,6 @@ impl fmt::Display for ListAgg {
24382440
write!(f, "{}", on_overflow)?;
24392441
}
24402442
write!(f, ")")?;
2441-
if !self.within_group.is_empty() {
2442-
write!(
2443-
f,
2444-
" WITHIN GROUP (ORDER BY {})",
2445-
display_comma_separated(&self.within_group)
2446-
)?;
2447-
}
24482443
Ok(())
24492444
}
24502445
}
@@ -2494,7 +2489,6 @@ pub struct ArrayAgg {
24942489
pub expr: Box<Expr>,
24952490
pub order_by: Option<Box<OrderByExpr>>,
24962491
pub limit: Option<Box<Expr>>,
2497-
pub within_group: bool, // order by is used inside a within group or not
24982492
}
24992493

25002494
impl fmt::Display for ArrayAgg {
@@ -2505,20 +2499,33 @@ impl fmt::Display for ArrayAgg {
25052499
if self.distinct { "DISTINCT " } else { "" },
25062500
self.expr
25072501
)?;
2508-
if !self.within_group {
2509-
if let Some(order_by) = &self.order_by {
2510-
write!(f, " ORDER BY {}", order_by)?;
2511-
}
2512-
if let Some(limit) = &self.limit {
2513-
write!(f, " LIMIT {}", limit)?;
2514-
}
2502+
if let Some(order_by) = &self.order_by {
2503+
write!(f, " ORDER BY {}", order_by)?;
25152504
}
2516-
write!(f, ")")?;
2517-
if self.within_group {
2518-
if let Some(order_by) = &self.order_by {
2519-
write!(f, " WITHIN GROUP (ORDER BY {})", order_by)?;
2520-
}
2505+
if let Some(limit) = &self.limit {
2506+
write!(f, " LIMIT {}", limit)?;
25212507
}
2508+
write!(f, ")")?;
2509+
Ok(())
2510+
}
2511+
}
2512+
2513+
/// A `WITHIN GROUP` invocation `<expr> WITHIN GROUP (ORDER BY <sort_expr> )`
2514+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
2515+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2516+
pub struct WithinGroup {
2517+
pub expr: Box<Expr>,
2518+
pub order_by: Vec<OrderByExpr>,
2519+
}
2520+
2521+
impl fmt::Display for WithinGroup {
2522+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2523+
write!(
2524+
f,
2525+
"{} WITHIN GROUP (ORDER BY {})",
2526+
self.expr,
2527+
display_comma_separated(&self.order_by),
2528+
)?;
25222529
Ok(())
25232530
}
25242531
}

src/parser.rs

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -627,14 +627,33 @@ impl<'a> Parser<'a> {
627627
None
628628
};
629629

630-
Ok(Expr::Function(Function {
630+
let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) {
631+
self.expect_token(&Token::LParen)?;
632+
self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?;
633+
let order_by_expr = self.parse_comma_separated(Parser::parse_order_by_expr)?;
634+
self.expect_token(&Token::RParen)?;
635+
Some(order_by_expr)
636+
} else {
637+
None
638+
};
639+
640+
let function = Expr::Function(Function {
631641
name,
632642
args,
633643
over,
634644
distinct,
635645
special: false,
636646
approximate: false,
637-
}))
647+
});
648+
649+
Ok(if let Some(within_group) = within_group {
650+
Expr::WithinGroup(WithinGroup {
651+
expr: Box::new(function),
652+
order_by: within_group,
653+
})
654+
} else {
655+
function
656+
})
638657
}
639658

640659
pub fn parse_time_functions(&mut self, name: ObjectName) -> Result<Expr, ParserError> {
@@ -995,17 +1014,24 @@ impl<'a> Parser<'a> {
9951014
self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?;
9961015
let order_by_expr = self.parse_comma_separated(Parser::parse_order_by_expr)?;
9971016
self.expect_token(&Token::RParen)?;
998-
order_by_expr
1017+
Some(order_by_expr)
9991018
} else {
1000-
vec![]
1019+
None
10011020
};
1002-
Ok(Expr::ListAgg(ListAgg {
1021+
let list_agg = Expr::ListAgg(ListAgg {
10031022
distinct,
10041023
expr,
10051024
separator,
10061025
on_overflow,
1007-
within_group,
1008-
}))
1026+
});
1027+
Ok(if let Some(within_group) = within_group {
1028+
Expr::WithinGroup(WithinGroup {
1029+
expr: Box::new(list_agg),
1030+
order_by: within_group,
1031+
})
1032+
} else {
1033+
list_agg
1034+
})
10091035
}
10101036

10111037
pub fn parse_array_agg_expr(&mut self) -> Result<Expr, ParserError> {
@@ -1031,7 +1057,6 @@ impl<'a> Parser<'a> {
10311057
expr,
10321058
order_by,
10331059
limit,
1034-
within_group: false,
10351060
}));
10361061
}
10371062
// Snowflake defines ORDERY BY in within group instead of inside the function like
@@ -1042,18 +1067,25 @@ impl<'a> Parser<'a> {
10421067
self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?;
10431068
let order_by_expr = self.parse_order_by_expr()?;
10441069
self.expect_token(&Token::RParen)?;
1045-
Some(Box::new(order_by_expr))
1070+
Some(order_by_expr)
10461071
} else {
10471072
None
10481073
};
10491074

1050-
Ok(Expr::ArrayAgg(ArrayAgg {
1075+
let array_agg = Expr::ArrayAgg(ArrayAgg {
10511076
distinct,
10521077
expr,
1053-
order_by: within_group,
1078+
order_by: None,
10541079
limit: None,
1055-
within_group: true,
1056-
}))
1080+
});
1081+
Ok(if let Some(within_group) = within_group {
1082+
Expr::WithinGroup(WithinGroup {
1083+
expr: Box::new(array_agg),
1084+
order_by: vec![within_group],
1085+
})
1086+
} else {
1087+
array_agg
1088+
})
10571089
}
10581090

10591091
// This function parses date/time fields for both the EXTRACT function-like

tests/sqlparser_common.rs

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,14 +1702,16 @@ fn parse_listagg() {
17021702
},
17031703
];
17041704
assert_eq!(
1705-
&Expr::ListAgg(ListAgg {
1706-
distinct: true,
1707-
expr,
1708-
separator: Some(Box::new(Expr::Value(Value::SingleQuotedString(
1709-
", ".to_string()
1710-
)))),
1711-
on_overflow,
1712-
within_group
1705+
&Expr::WithinGroup(WithinGroup {
1706+
expr: Box::new(Expr::ListAgg(ListAgg {
1707+
distinct: true,
1708+
expr,
1709+
separator: Some(Box::new(Expr::Value(Value::SingleQuotedString(
1710+
", ".to_string()
1711+
)))),
1712+
on_overflow,
1713+
})),
1714+
order_by: within_group
17131715
}),
17141716
expr_from_projection(only(&select.projection))
17151717
);
@@ -1736,6 +1738,41 @@ fn parse_array_agg_func() {
17361738
}
17371739
}
17381740

1741+
#[test]
1742+
fn parse_within_group() {
1743+
let sql = "SELECT PERCENTILE_CONT(0.0) WITHIN GROUP (ORDER BY name ASC NULLS FIRST)";
1744+
let select = verified_only_select(sql);
1745+
1746+
#[cfg(feature = "bigdecimal")]
1747+
let value = bigdecimal::BigDecimal::from(0);
1748+
#[cfg(not(feature = "bigdecimal"))]
1749+
let value = "0.0".to_string();
1750+
let expr = Expr::Value(Value::Number(value, false));
1751+
let function = Expr::Function(Function {
1752+
name: ObjectName(vec![Ident::new("PERCENTILE_CONT")]),
1753+
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(expr))],
1754+
over: None,
1755+
distinct: false,
1756+
special: false,
1757+
approximate: false,
1758+
});
1759+
let within_group = vec![OrderByExpr {
1760+
expr: Expr::Identifier(Ident {
1761+
value: "name".to_string(),
1762+
quote_style: None,
1763+
}),
1764+
asc: Some(true),
1765+
nulls_first: Some(true),
1766+
}];
1767+
assert_eq!(
1768+
&Expr::WithinGroup(WithinGroup {
1769+
expr: Box::new(function),
1770+
order_by: within_group
1771+
}),
1772+
expr_from_projection(only(&select.projection))
1773+
);
1774+
}
1775+
17391776
#[test]
17401777
fn parse_create_table() {
17411778
let sql = "CREATE TABLE uk_cities (\

0 commit comments

Comments
 (0)