Skip to content

Commit

Permalink
[Fix](nereids) set all nullable aggregate function to alwaysnullable …
Browse files Browse the repository at this point in the history
…in window expression (#40693)
  • Loading branch information
feiniaofeiafei authored Sep 13, 2024
1 parent 18a374f commit 55d6d64
Show file tree
Hide file tree
Showing 3 changed files with 369 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
Expand Down Expand Up @@ -64,13 +60,11 @@ private Plan normalize(LogicalProject<Plan> project) {
if (output instanceof WindowExpression) {
WindowExpression windowExpression = (WindowExpression) output;
Expression expression = ((WindowExpression) output).getFunction();
if (expression instanceof Sum || expression instanceof Max
|| expression instanceof Min || expression instanceof Avg) {
// sum, max, min and avg in window function should be always nullable
windowExpression = ((WindowExpression) output)
.withFunction(
((NullableAggregateFunction) expression).withAlwaysNullable(true)
);
if (expression instanceof NullableAggregateFunction) {
// Any NullableAggregateFunction used as a window function must be always nullable,
// because the window frame may contain no rows, in which case the result is NULL.
windowExpression = ((WindowExpression) output).withFunction(
((NullableAggregateFunction) expression).withAlwaysNullable(true));
}

ImmutableList.Builder<Expression> nonLiteralPartitionKeys =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !max --
\N
2
3
3
3
5
5
5
5

-- !min --
\N
2
2
3
2
2
3
3
3

-- !sum --
\N
2
5
6
5
7
8
8
8

-- !avg --
\N
2.0
2.5
3.0
2.5
3.5
4.0
4.0
4.0

-- !topn --
\N
{"1":1}
{"2":1,"1":1}
{"3":1,"2":1}
{"4":1,"3":1}
{"5":1,"4":1}
{"6":1,"5":1}
{"7":1,"6":1}
{"8":1,"7":1}

-- !topn_array --
\N
[1]
[2, 1]
[3, 2]
[4, 3]
[5, 4]
[6, 5]
[7, 6]
[8, 7]

-- !topn_weighted --
\N
[1]
[2, 1]
[3, 2]
[4, 3]
[5, 4]
[6, 5]
[7, 6]
[8, 7]

-- !max_by --
\N
2
3
3
2
5
3
5
3

-- !min_by --
\N
2
2
3
3
2
5
3
5

-- !avg_weighted --
\N
2.0
2.5
3.0
2.5
3.5
4.0
4.0
5.0

-- !variance --
\N
0.0
0.25
0.0
0.25
2.25
1.0
1.0
1.0

-- !variance_samp --
\N
0.0
0.5
0.0
0.5
4.5
2.0
2.0
2.0

-- !percentile --
\N
2.0
2.5
3.0
2.5
3.5
4.0
4.0
4.0

-- !percentile_approx --
\N
2.0
3.0
3.0
3.0
5.0
5.0
5.0
5.0

-- !stddev --
\N
0.0
0.5
0.0
0.5
1.5
1.0
1.0
1.0

-- !stddev_samp --
\N
0.0
0.7071067811865476
0.0
0.7071067811865476
2.1213203435596424
1.4142135623730951
1.4142135623730951
1.4142135623730951

-- !corr --
\N
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0

-- !covar --
\N
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0

-- !covar_samp --
\N
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0

-- !group_concat --
\N
1
1,1
1,2
2,2
2,2
2,2
2,2
2

-- !retention --
\N
[1, 0]
[1, 1]
[1, 1]
[0, 0]
[0, 0]
[0, 0]
[0, 0]
[0, 0]

-- !group_bit_and --
\N
1
1
0
2
2
2
2
2

-- !group_bit_or --
\N
1
1
3
2
2
2
2
2

-- !group_bit_xor --
\N
1
0
3
0
0
0
0
2

-- !group_bitmap_xor --
\N
\N
\N
\N
\N
\N
\N
\N
\N

-- !sum_foreach --
\N
[1, 2]
[4, 4]
[4, 7]
[4, 7]
[8, 4]
[6, 4]
[2, 4]
[2, 25]

-- !sequence_match --
\N
false
false
false

Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Regression suite: runs a series of aggregate functions as window functions over the
// frame "rows between 2 preceding and 1 preceding". For the first row of the ordering
// that frame contains no rows, so each query exercises the empty-frame case.
suite("normalize_window_nullable_agg") {
// Recreate the test table from scratch so the suite can be re-run.
sql "drop table if exists normalize_window_nullable_agg"
sql """create table normalize_window_nullable_agg (a int, b int,c int,d array<int>) distributed by hash(a)
properties("replication_num"="1");
"""
// Nine rows; column c takes the distinct values 1..9 and is the window ORDER BY key below.
// The last two rows have NULL in column a.
sql """insert into normalize_window_nullable_agg values(1,2,1,[1,2]),(1,3,2,[3,2]),(2,3,3,[1,5]),(2,2,4,[3,2]),(2,5,5,[5,2])
,(2,3,6,[1,2]),(2,5,7,[1,2]),(null,3,8,[1,23]),(null,6,9,[3,2]);"""
// One tagged query per aggregate function, all using the same window specification.
qt_max "select max(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_min "select min(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_sum "select sum(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_avg "select avg(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_topn "select topn(c,3) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_topn_array "select topn_array(c,3) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_topn_weighted "select topn_weighted(c,c,3) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_max_by "select max_by(b,c) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_min_by "select min_by(b,c) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_avg_weighted "select avg_weighted(b,a) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_variance "select variance(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_variance_samp "select variance_samp(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_percentile "select PERCENTILE(b,0.5) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_percentile_approx "select PERCENTILE_approx(b,0.99) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_stddev "select stddev(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_stddev_samp "select stddev_samp(b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_corr "select corr(a,b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_covar "select covar(a,b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_covar_samp "select covar_samp(a,b) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_group_concat "select group_concat(cast(a as varchar(10)),',') over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_retention "select retention(a=1,b>2) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_group_bit_and "select group_bit_and(a) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_group_bit_or "select group_bit_or(a) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_group_bit_xor "select group_bit_xor(a) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_group_bitmap_xor "select group_bitmap_xor(to_bitmap(a)) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"
qt_sum_foreach "select sum_foreach(d) over(order by c rows between 2 preceding and 1 preceding) from normalize_window_nullable_agg"

// Second table: a small event log (xwho, xwhen, xwhat) used by the sequence-style
// aggregates below.
sql "drop table if exists windowfunnel_test_normalize_window"
sql """CREATE TABLE windowfunnel_test_normalize_window (
`xwho` varchar(50) NULL COMMENT 'xwho',
`xwhen` datetime COMMENT 'xwhen',
`xwhat` int NULL COMMENT 'xwhat'
)
DUPLICATE KEY(xwho)
DISTRIBUTED BY HASH(xwho) BUCKETS 3
PROPERTIES (
"replication_num" = "1"
);"""

// Four events for the same xwho, with xwhat taking the values 1..4.
sql """INSERT into windowfunnel_test_normalize_window (xwho, xwhen, xwhat) values ('1', '2022-03-12 10:41:00', 1),
('1', '2022-03-12 13:28:02', 2),
('1', '2022-03-12 16:15:01', 3),
('1', '2022-03-12 19:05:04', 4);"""
// This case currently causes a core dump, so it is left disabled.
// qt_window_funnel """select window_funnel(3600 * 3, 'default', t.xwhen, t.xwhat = 1, t.xwhat = 2 ) over (order by xwhat rows
// between 2 preceding and 1 preceding) AS level from windowfunnel_test_normalize_window t;"""
// sequence_match over the same "2 preceding and 1 preceding" frame, ordered by xwhat.
qt_sequence_match "SELECT sequence_match('(?1)(?2)', xwhen, xwhat = 1, xwhat = 3) over (order by xwhat rows between 2 preceding and 1 preceding) FROM windowfunnel_test_normalize_window;"
}

0 comments on commit 55d6d64

Please sign in to comment.