diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/AlignByDeviceOrderByLimitOffsetTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/AlignByDeviceOrderByLimitOffsetTest.java index cff70fc457ce4..8d3abde41aa2f 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/AlignByDeviceOrderByLimitOffsetTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/AlignByDeviceOrderByLimitOffsetTest.java @@ -24,6 +24,7 @@ import org.apache.iotdb.db.queryengine.common.QueryId; import org.apache.iotdb.db.queryengine.plan.analyze.Analysis; import org.apache.iotdb.db.queryengine.plan.planner.plan.DistributedQueryPlan; +import org.apache.iotdb.db.queryengine.plan.planner.plan.FragmentInstance; import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AggregationMergeSortNode; @@ -48,6 +49,7 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; public class AlignByDeviceOrderByLimitOffsetTest { @@ -880,6 +882,25 @@ public void orderByExpressionTest1() { * ├──SeriesScanNode-27:[SeriesPath: root.sg.d333.s1, DataRegion: TConsensusGroupId(type:DataRegion, id:4)] * └──SeriesScanNode-28:[SeriesPath: root.sg.d333.s2, DataRegion: TConsensusGroupId(type:DataRegion, id:4)] */ + private TopKNode findRootTopK(DistributedQueryPlan plan) { + for (FragmentInstance inst : plan.getInstances()) { + PlanNode tree = inst.getFragment().getPlanNodeTree(); + if (!tree.getChildren().isEmpty() && tree.getChildren().get(0) instanceof TopKNode) { + return (TopKNode) tree.getChildren().get(0); + } + } + return null; + } + + private TopKNode findInnerTopK(TopKNode rootTopK) { + for (PlanNode child : rootTopK.getChildren()) { + if (child instanceof TopKNode) { + return (TopKNode) child; + } + } + return null; + } + @Test public void orderByExpressionTest2() { // only order by expression, has LIMIT @@ -891,18 +912,29 @@ public void orderByExpressionTest2() { planner = new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode)); plan = planner.planFragments(); assertEquals(4, plan.getInstances().size()); - firstFiRoot = plan.getInstances().get(0).getFragment().getPlanNodeTree(); - firstFiTopNode = firstFiRoot.getChildren().get(0); - assertTrue(firstFiTopNode instanceof TopKNode); - for (PlanNode node : firstFiTopNode.getChildren().get(0).getChildren()) { - assertTrue(node instanceof DeviceViewNode); - assertTrue(node.getChildren().get(0) instanceof FullOuterTimeJoinNode); + TopKNode rootTopK = findRootTopK(plan); + assertNotNull(rootTopK); + TopKNode innerTopK = findInnerTopK(rootTopK); + assertNotNull(innerTopK); + for (PlanNode subTree : innerTopK.getChildren()) { + assertTrue( + containsNodeType(subTree, DeviceViewNode.class) + || containsNodeType(subTree, SingleDeviceViewNode.class)); } - assertTrue(firstFiTopNode.getChildren().get(1) instanceof ExchangeNode); - assertTrue(firstFiTopNode.getChildren().get(2) instanceof ExchangeNode); - assertTrue(firstFiTopNode.getChildren().get(3) instanceof ExchangeNode); - for (int i = 0; i < 4; i++) { - assertScanNodeLimitValue(plan.getInstances().get(i).getFragment().getPlanNodeTree(), 0); + boolean needJoin = + innerTopK.getChildren().stream() + .anyMatch( + st -> + st.getChildren().stream().filter(n -> n instanceof SeriesScanNode).count() > 1); + if (needJoin) { + assertTrue( + containsNodeType(innerTopK, FullOuterTimeJoinNode.class) + || containsNodeType(innerTopK, LeftOuterTimeJoinNode.class)); + } + long exCnt = rootTopK.getChildren().stream().filter(n -> n instanceof ExchangeNode).count(); + assertEquals(3, exCnt); + for (FragmentInstance inst : plan.getInstances()) { + assertScanNodeLimitValue(inst.getFragment().getPlanNodeTree(), 0); } } @@ -1121,4 +1153,16 @@ private void assertScanNodeLimitValue(PlanNode root, long limitValue) { } } } + + private static boolean containsNodeType(PlanNode root, Class clazz) { + if (clazz.isInstance(root)) { + return true; + } + for (PlanNode child : root.getChildren()) { + if (containsNodeType(child, clazz)) { + return true; + } + } + return false; + } }