Skip to content

Commit 70b9643

Browse files
authored
Push the aggregation node down the union node
1 parent a527b92 commit 70b9643

File tree

4 files changed

+214
-2
lines changed

4 files changed

+214
-2
lines changed

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/MappingCollectOperator.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,13 @@ public class MappingCollectOperator extends CollectOperator {
3737
// record mapping for each child
3838
private final List<List<Integer>> mappings;
3939

40+
private final int outputColumnsCount;
41+
4042
public MappingCollectOperator(
4143
OperatorContext operatorContext, List<Operator> children, List<List<Integer>> mappings) {
4244
super(operatorContext, children);
4345
this.mappings = mappings;
46+
outputColumnsCount = mappings.get(0).size();
4447
}
4548

4649
@Override
@@ -50,7 +53,7 @@ public TsBlock next() throws Exception {
5053
if (tsBlock == null) {
5154
return null;
5255
} else {
53-
Column[] columns = new Column[tsBlock.getValueColumnCount()];
56+
Column[] columns = new Column[outputColumnsCount];
5457
List<Integer> mapping = mappings.get(currentIndex);
5558
for (int i = 0; i < columns.length; i++) {
5659
columns[i] = tsBlock.getColumn(mapping.get(i));

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.apache.iotdb.db.queryengine.plan.relational.metadata.DeviceEntry;
4949
import org.apache.iotdb.db.queryengine.plan.relational.metadata.QualifiedObjectName;
5050
import org.apache.iotdb.db.queryengine.plan.relational.metadata.ResolvedFunction;
51+
import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
5152
import org.apache.iotdb.db.queryengine.plan.relational.planner.OrderingScheme;
5253
import org.apache.iotdb.db.queryengine.plan.relational.planner.SortOrder;
5354
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
@@ -1057,8 +1058,67 @@ public List<PlanNode> visitAggregation(AggregationNode node, PlanContext context
10571058
// child
10581059
}
10591060

1061+
// push down aggregation if the child of aggregation node only has the union Node
10601062
if (childrenNodes.size() == 1) {
10611063
node.setChild(childrenNodes.get(0));
1064+
1065+
if (childrenNodes.get(0) instanceof UnionNode
1066+
&& node.getAggregations().values().stream()
1067+
.noneMatch(aggregation -> aggregation.isDistinct() || aggregation.hasMask())) {
1068+
UnionNode unionNode = (UnionNode) childrenNodes.get(0);
1069+
List<PlanNode> children = unionNode.getChildren();
1070+
1071+
// 1. add the project Node above the children of the union node
1072+
List<PlanNode> newProjectNodes = new ArrayList<>();
1073+
1074+
Map<Symbol, Collection<Symbol>> symbolMapping = unionNode.getSymbolMapping().asMap();
1075+
for (int i = 0; i < children.size(); i++) {
1076+
Assignments.Builder assignmentsBuilder = Assignments.builder();
1077+
for (Map.Entry<Symbol, Collection<Symbol>> symbolEntry : symbolMapping.entrySet()) {
1078+
List<Symbol> symbolList = (ImmutableList<Symbol>) symbolEntry.getValue();
1079+
assignmentsBuilder.put(symbolEntry.getKey(), symbolList.get(i).toSymbolReference());
1080+
}
1081+
newProjectNodes.add(
1082+
new ProjectNode(
1083+
queryId.genPlanNodeId(), children.get(i), assignmentsBuilder.build()));
1084+
}
1085+
1086+
// 2. split the aggregation into partial and final
1087+
Pair<AggregationNode, AggregationNode> splitResult = split(node, symbolAllocator, queryId);
1088+
AggregationNode intermediate = splitResult.right;
1089+
1090+
// 3. add the aggregation node above the project node
1091+
List<PlanNode> aggregationNodes =
1092+
newProjectNodes.stream()
1093+
.map(
1094+
child -> {
1095+
PlanNodeId planNodeId = queryId.genPlanNodeId();
1096+
AggregationNode aggregationNode =
1097+
new AggregationNode(
1098+
planNodeId,
1099+
child,
1100+
intermediate.getAggregations(),
1101+
intermediate.getGroupingSets(),
1102+
intermediate.getPreGroupedSymbols(),
1103+
intermediate.getStep(),
1104+
intermediate.getHashSymbol(),
1105+
intermediate.getGroupIdSymbol());
1106+
if (node.isStreamable() && childOrdering != null) {
1107+
nodeOrderingMap.put(planNodeId, expectedOrderingSchema);
1108+
}
1109+
return aggregationNode;
1110+
})
1111+
.collect(Collectors.toList());
1112+
1113+
// 4. Add a Collect Node under the final Aggregation Node, and add the partial Aggregation
1114+
// nodes as its children
1115+
CollectNode collectNode =
1116+
new CollectNode(queryId.genPlanNodeId(), aggregationNodes.get(0).getOutputSymbols());
1117+
collectNode.setChildren(aggregationNodes);
1118+
splitResult.left.setChild(collectNode);
1119+
return Collections.singletonList(splitResult.left);
1120+
}
1121+
10621122
return Collections.singletonList(node);
10631123
}
10641124

@@ -1072,7 +1132,6 @@ public List<PlanNode> visitAggregation(AggregationNode node, PlanContext context
10721132
nodeOrderingMap.get(childrenNodes.get(0).getPlanNodeId()), childrenNodes));
10731133
return Collections.singletonList(node);
10741134
}
1075-
10761135
Pair<AggregationNode, AggregationNode> splitResult = split(node, symbolAllocator, queryId);
10771136
AggregationNode intermediate = splitResult.right;
10781137

@@ -1782,6 +1841,7 @@ public List<PlanNode> visitWindowFunction(WindowNode node, PlanContext context)
17821841
@Override
17831842
public List<PlanNode> visitUnion(UnionNode node, PlanContext context) {
17841843
context.clearExpectedOrderingScheme();
1844+
17851845
List<List<PlanNode>> children =
17861846
node.getChildren().stream()
17871847
.map(child -> child.accept(this, context))

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/type/CompatibleResolver.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ public class CompatibleResolver {
9999
addCondition(UNKNOWN, TEXT, TEXT);
100100
addCondition(UNKNOWN, STRING, STRING);
101101
addCondition(UNKNOWN, BLOB, BLOB);
102+
addCondition(UNKNOWN, UNKNOWN, UNKNOWN);
102103
}
103104

104105
private static void addCondition(Type condition1, Type condition2, Type result) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iotdb.db.queryengine.plan.relational.analyzer;
21+
22+
import org.apache.iotdb.db.queryengine.plan.planner.plan.DistributedQueryPlan;
23+
import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan;
24+
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
25+
import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolAllocator;
26+
import org.apache.iotdb.db.queryengine.plan.relational.planner.TableLogicalPlanner;
27+
import org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern;
28+
import org.apache.iotdb.db.queryengine.plan.relational.planner.distribute.TableDistributedPlanner;
29+
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
30+
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.OutputNode;
31+
32+
import com.google.common.collect.ImmutableList;
33+
import com.google.common.collect.ImmutableMap;
34+
import org.junit.Test;
35+
36+
import java.util.Optional;
37+
38+
import static org.apache.iotdb.db.queryengine.plan.relational.analyzer.AnalyzerTest.analyzeSQL;
39+
import static org.apache.iotdb.db.queryengine.plan.relational.analyzer.TestUtils.DEFAULT_WARNING;
40+
import static org.apache.iotdb.db.queryengine.plan.relational.analyzer.TestUtils.QUERY_CONTEXT;
41+
import static org.apache.iotdb.db.queryengine.plan.relational.analyzer.TestUtils.SESSION_INFO;
42+
import static org.apache.iotdb.db.queryengine.plan.relational.analyzer.TestUtils.TEST_MATADATA;
43+
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanAssert.assertPlan;
44+
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.aggregation;
45+
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.aggregationFunction;
46+
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.collect;
47+
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.exchange;
48+
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.output;
49+
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.project;
50+
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.singleGroupingSet;
51+
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.tableScan;
52+
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.union;
53+
import static org.junit.Assert.assertEquals;
54+
55+
public class PushAggregationThroughUnionTest {
56+
57+
@Test
58+
public void UnionAggregationTest() {
59+
60+
String sql = "select * from t1 union select * from t2";
61+
Analysis analysis = analyzeSQL(sql, TEST_MATADATA, QUERY_CONTEXT);
62+
SymbolAllocator symbolAllocator = new SymbolAllocator();
63+
LogicalQueryPlan actualLogicalQueryPlan =
64+
new TableLogicalPlanner(
65+
QUERY_CONTEXT, TEST_MATADATA, SESSION_INFO, symbolAllocator, DEFAULT_WARNING)
66+
.plan(analysis);
67+
68+
// verify the Logical plan `Output - Aggregation - union - 2*tableScan`
69+
assertPlan(
70+
actualLogicalQueryPlan,
71+
output(aggregation(union(tableScan("testdb.t1"), tableScan("testdb.t2")))));
72+
73+
// verify the Distributed plan `Output - Aggregation - collect - 6*exchange` +
74+
// 6*(aggregation-project-tableScan)`
75+
TableDistributedPlanner TableDistributedPlanner =
76+
new TableDistributedPlanner(
77+
analysis, symbolAllocator, actualLogicalQueryPlan, TEST_MATADATA, null);
78+
DistributedQueryPlan actualDistributedQueryPlan = TableDistributedPlanner.plan();
79+
assertEquals(7, actualDistributedQueryPlan.getFragments().size());
80+
81+
PlanMatchPattern expectedPattern =
82+
output(
83+
aggregation(
84+
collect(exchange(), exchange(), exchange(), exchange(), exchange(), exchange())));
85+
86+
OutputNode outputNode =
87+
(OutputNode)
88+
actualDistributedQueryPlan.getFragments().get(0).getPlanNodeTree().getChildren().get(0);
89+
assertPlan(outputNode, expectedPattern);
90+
91+
for (int i = 1; i < actualDistributedQueryPlan.getFragments().size(); i++) {
92+
PlanNode planNode =
93+
actualDistributedQueryPlan.getFragments().get(i).getPlanNodeTree().getChildren().get(0);
94+
assertPlan(planNode, aggregation(project(tableScan(i <= 3 ? "testdb.t1" : "testdb.t2"))));
95+
}
96+
}
97+
98+
@Test
99+
public void unionAllWithGroupByAggregationTest() {
100+
101+
String sql =
102+
"SELECT tag1, COUNT(s1), sum(s1) FROM (SELECT tag1, s1 FROM t1 UNION ALL SELECT tag1, s1 FROM t2) GROUP BY tag1";
103+
Analysis analysis = analyzeSQL(sql, TEST_MATADATA, QUERY_CONTEXT);
104+
SymbolAllocator symbolAllocator = new SymbolAllocator();
105+
106+
LogicalQueryPlan actualLogicalQueryPlan =
107+
new TableLogicalPlanner(
108+
QUERY_CONTEXT, TEST_MATADATA, SESSION_INFO, symbolAllocator, DEFAULT_WARNING)
109+
.plan(analysis);
110+
111+
// verify the Logical plan
112+
assertPlan(
113+
actualLogicalQueryPlan,
114+
output(aggregation(union(tableScan("testdb.t1"), tableScan("testdb.t2")))));
115+
116+
// verify the Distributed plan
117+
TableDistributedPlanner tableDistributedPlanner =
118+
new TableDistributedPlanner(
119+
analysis, symbolAllocator, actualLogicalQueryPlan, TEST_MATADATA, null);
120+
DistributedQueryPlan actualDistributedQueryPlan = tableDistributedPlanner.plan();
121+
122+
assertEquals(7, actualDistributedQueryPlan.getFragments().size());
123+
PlanMatchPattern expectedRootPattern =
124+
output(
125+
aggregation(
126+
singleGroupingSet("tag1"),
127+
ImmutableMap.of(
128+
Optional.of("count"),
129+
aggregationFunction("count", ImmutableList.of("count_12")),
130+
Optional.of("sum"),
131+
aggregationFunction("sum", ImmutableList.of("sum_13"))),
132+
Optional.empty(),
133+
AggregationNode.Step.FINAL,
134+
collect(exchange(), exchange(), exchange(), exchange(), exchange(), exchange())));
135+
OutputNode outputNode =
136+
(OutputNode)
137+
actualDistributedQueryPlan.getFragments().get(0).getPlanNodeTree().getChildren().get(0);
138+
assertPlan(outputNode, expectedRootPattern);
139+
140+
for (int i = 1; i < actualDistributedQueryPlan.getFragments().size(); i++) {
141+
PlanNode planNode =
142+
actualDistributedQueryPlan.getFragments().get(i).getPlanNodeTree().getChildren().get(0);
143+
PlanMatchPattern expectedLeafPattern =
144+
aggregation(project(tableScan(i <= 3 ? "testdb.t1" : "testdb.t2")));
145+
assertPlan(planNode, expectedLeafPattern);
146+
}
147+
}
148+
}

0 commit comments

Comments
 (0)