diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRule.java index aa0b8b8552819..6dc56664643de 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRule.java @@ -90,6 +90,24 @@ public boolean isRemoteCall(RexNode node) { public boolean isNonRemoteCall(RexNode node) { return AsyncUtil.isNonAsyncCall(node); } + + /** Name appended to the description of rules that use this finder. */ + @Override + public String getName() { + return "Async"; + } + + /** Two finders are equal exactly when they are instances of the same class. */ + @Override + public boolean equals(Object obj) { + return obj != null && this.getClass() == obj.getClass(); + } + + /** Class-based hash, so that finders which are equal always share a hash code. */ + @Override + public int hashCode() { + return getClass().hashCode(); + } } private static boolean hasNestedCalls(List projects) { diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRule.java new file mode 100644 index 0000000000000..18561e0db8910 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRule.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc; +import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate; +import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan; +import org.apache.flink.table.planner.plan.rules.logical.AsyncCalcSplitRule.AsyncRemoteCalcCallFinder; + +import org.apache.calcite.plan.RelOptRule; + +/** + * Rule will split the Async {@link FlinkLogicalTableFunctionScan} with Java calls or the Java + * {@link FlinkLogicalTableFunctionScan} with Async calls into a {@link FlinkLogicalCalc} which will + * be the left input of the new {@link FlinkLogicalCorrelate} and a new {@link + * FlinkLogicalTableFunctionScan}. 
+ */ +public class AsyncCorrelateSplitRule { + + private static final RemoteCalcCallFinder ASYNC_CALL_FINDER = new AsyncRemoteCalcCallFinder(); + + public static final RelOptRule CORRELATE_SPLIT = + new RemoteCorrelateSplitRule(ASYNC_CALL_FINDER); +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java index 986f0fc538c68..968a599406177 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java @@ -21,33 +21,6 @@ import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan; -import org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule; -import org.apache.flink.table.planner.plan.utils.PythonUtil; -import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor; - -import org.apache.calcite.plan.RelOptRule; -import org.apache.calcite.plan.RelOptRuleCall; -import org.apache.calcite.plan.hep.HepRelVertex; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeField; -import org.apache.calcite.rex.RexBuilder; -import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexCorrelVariable; -import org.apache.calcite.rex.RexFieldAccess; -import org.apache.calcite.rex.RexInputRef; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.rex.RexProgram; -import org.apache.calcite.rex.RexProgramBuilder; -import org.apache.calcite.rex.RexUtil; 
-import org.apache.calcite.sql.validate.SqlValidatorUtil; - -import java.util.LinkedList; -import java.util.List; -import java.util.stream.Collectors; - -import scala.collection.Iterator; -import scala.collection.mutable.ArrayBuffer; /** * Rule will split the Python {@link FlinkLogicalTableFunctionScan} with Java calls or the Java @@ -55,272 +28,8 @@ * will be the left input of the new {@link FlinkLogicalCorrelate} and a new {@link * FlinkLogicalTableFunctionScan}. */ -public class PythonCorrelateSplitRule extends RelOptRule { - public static final PythonCorrelateSplitRule INSTANCE = new PythonCorrelateSplitRule(); - - private PythonCorrelateSplitRule() { - super(operand(FlinkLogicalCorrelate.class, any()), "PythonCorrelateSplitRule"); - } - - private FlinkLogicalTableFunctionScan createNewScan( - FlinkLogicalTableFunctionScan scan, ScalarFunctionSplitter splitter) { - RexCall rightRexCall = (RexCall) scan.getCall(); - // extract Java funcs from Python TableFunction or Python funcs from Java TableFunction. 
- List rightCalcProjects = - rightRexCall.getOperands().stream() - .map(x -> x.accept(splitter)) - .collect(Collectors.toList()); - - RexCall newRightRexCall = rightRexCall.clone(rightRexCall.getType(), rightCalcProjects); - return new FlinkLogicalTableFunctionScan( - scan.getCluster(), - scan.getTraitSet(), - scan.getInputs(), - newRightRexCall, - scan.getElementType(), - scan.getRowType(), - scan.getColumnMappings()); - } - - @Override - public boolean matches(RelOptRuleCall call) { - FlinkLogicalCorrelate correlate = call.rel(0); - RelNode right = ((HepRelVertex) correlate.getRight()).getCurrentRel(); - FlinkLogicalTableFunctionScan tableFunctionScan; - if (right instanceof FlinkLogicalTableFunctionScan) { - tableFunctionScan = (FlinkLogicalTableFunctionScan) right; - } else if (right instanceof FlinkLogicalCalc) { - tableFunctionScan = StreamPhysicalCorrelateRule.getTableScan((FlinkLogicalCalc) right); - } else { - return false; - } - RexNode rexNode = tableFunctionScan.getCall(); - if (rexNode instanceof RexCall) { - return PythonUtil.isPythonCall(rexNode, null) - && PythonUtil.containsNonPythonCall(rexNode) - || PythonUtil.isNonPythonCall(rexNode) - && PythonUtil.containsPythonCall(rexNode, null) - || (PythonUtil.isPythonCall(rexNode, null) - && RexUtil.containsFieldAccess(rexNode)); - } - return false; - } - - private List createNewFieldNames( - RelDataType rowType, - RexBuilder rexBuilder, - int primitiveFieldCount, - ArrayBuffer extractedRexNodes, - List calcProjects) { - for (int i = 0; i < primitiveFieldCount; i++) { - calcProjects.add(RexInputRef.of(i, rowType)); - } - // change RexCorrelVariable to RexInputRef. 
- RexDefaultVisitor visitor = - new RexDefaultVisitor() { - @Override - public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { - RexNode expr = fieldAccess.getReferenceExpr(); - if (expr instanceof RexCorrelVariable) { - RelDataTypeField field = fieldAccess.getField(); - return new RexInputRef(field.getIndex(), field.getType()); - } else { - return rexBuilder.makeFieldAccess( - expr.accept(this), fieldAccess.getField().getIndex()); - } - } - - @Override - public RexNode visitNode(RexNode rexNode) { - return rexNode; - } - }; - // add the fields of the extracted rex calls. - Iterator iterator = extractedRexNodes.iterator(); - while (iterator.hasNext()) { - RexNode rexNode = iterator.next(); - if (rexNode instanceof RexCall) { - RexCall rexCall = (RexCall) rexNode; - List newProjects = - rexCall.getOperands().stream() - .map(x -> x.accept(visitor)) - .collect(Collectors.toList()); - RexCall newRexCall = rexCall.clone(rexCall.getType(), newProjects); - calcProjects.add(newRexCall); - } else { - calcProjects.add(rexNode); - } - } - - List nameList = new LinkedList<>(); - for (int i = 0; i < primitiveFieldCount; i++) { - nameList.add(rowType.getFieldNames().get(i)); - } - Iterator indicesIterator = extractedRexNodes.indices().iterator(); - while (indicesIterator.hasNext()) { - nameList.add("f" + indicesIterator.next()); - } - return SqlValidatorUtil.uniquify( - nameList, rexBuilder.getTypeFactory().getTypeSystem().isSchemaCaseSensitive()); - } - - private FlinkLogicalCalc createNewLeftCalc( - RelNode left, - RexBuilder rexBuilder, - ArrayBuffer extractedRexNodes, - FlinkLogicalCorrelate correlate) { - // add the fields of the primitive left input. 
- List leftCalcProjects = new LinkedList<>(); - RelDataType leftRowType = left.getRowType(); - List leftCalcCalcFieldNames = - createNewFieldNames( - leftRowType, - rexBuilder, - leftRowType.getFieldCount(), - extractedRexNodes, - leftCalcProjects); - - // create a new calc - return new FlinkLogicalCalc( - correlate.getCluster(), - correlate.getTraitSet(), - left, - RexProgram.create( - leftRowType, leftCalcProjects, null, leftCalcCalcFieldNames, rexBuilder)); - } - - private FlinkLogicalCalc createTopCalc( - int primitiveLeftFieldCount, - RexBuilder rexBuilder, - ArrayBuffer extractedRexNodes, - RelDataType calcRowType, - FlinkLogicalCorrelate newCorrelate) { - RexProgram rexProgram = - new RexProgramBuilder(newCorrelate.getRowType(), rexBuilder).getProgram(); - int offset = extractedRexNodes.size() + primitiveLeftFieldCount; - - // extract correlate output RexNode. - List newTopCalcProjects = - rexProgram.getExprList().stream() - .filter(x -> x instanceof RexInputRef) - .filter( - x -> { - int index = ((RexInputRef) x).getIndex(); - return index < primitiveLeftFieldCount || index >= offset; - }) - .collect(Collectors.toList()); - - return new FlinkLogicalCalc( - newCorrelate.getCluster(), - newCorrelate.getTraitSet(), - newCorrelate, - RexProgram.create( - newCorrelate.getRowType(), - newTopCalcProjects, - null, - calcRowType, - rexBuilder)); - } - - private ScalarFunctionSplitter createScalarFunctionSplitter( - RexProgram program, - RexBuilder rexBuilder, - int primitiveLeftFieldCount, - ArrayBuffer extractedRexNodes, - RexNode tableFunctionNode) { - return new ScalarFunctionSplitter( - program, - rexBuilder, - primitiveLeftFieldCount, - extractedRexNodes, - node -> { - if (PythonUtil.isNonPythonCall(tableFunctionNode)) { - // splits the RexCalls which contain Python functions into separate node - return PythonUtil.isPythonCall(node, null); - } else if (PythonUtil.containsNonPythonCall(node)) { - // splits the RexCalls which contain non-Python functions into 
separate node - return PythonUtil.isNonPythonCall(node); - } else { - // splits the RexFieldAccesses which contain non-Python functions into - // separate node - return node instanceof RexFieldAccess; - } - }, - new PythonRemoteCalcCallFinder()); - } - - @Override - public void onMatch(RelOptRuleCall call) { - FlinkLogicalCorrelate correlate = call.rel(0); - RexBuilder rexBuilder = call.builder().getRexBuilder(); - RelNode left = ((HepRelVertex) correlate.getLeft()).getCurrentRel(); - RelNode right = ((HepRelVertex) correlate.getRight()).getCurrentRel(); - int primitiveLeftFieldCount = left.getRowType().getFieldCount(); - ArrayBuffer extractedRexNodes = new ArrayBuffer<>(); - - RelNode rightNewInput; - if (right instanceof FlinkLogicalTableFunctionScan) { - FlinkLogicalTableFunctionScan scan = (FlinkLogicalTableFunctionScan) right; - rightNewInput = - createNewScan( - scan, - createScalarFunctionSplitter( - null, - rexBuilder, - primitiveLeftFieldCount, - extractedRexNodes, - scan.getCall())); - } else { - FlinkLogicalCalc calc = (FlinkLogicalCalc) right; - FlinkLogicalTableFunctionScan scan = StreamPhysicalCorrelateRule.getTableScan(calc); - FlinkLogicalCalc mergedCalc = StreamPhysicalCorrelateRule.getMergedCalc(calc); - FlinkLogicalTableFunctionScan newScan = - createNewScan( - scan, - createScalarFunctionSplitter( - null, - rexBuilder, - primitiveLeftFieldCount, - extractedRexNodes, - scan.getCall())); - rightNewInput = - mergedCalc.copy(mergedCalc.getTraitSet(), newScan, mergedCalc.getProgram()); - } - - FlinkLogicalCorrelate newCorrelate; - if (extractedRexNodes.size() > 0) { - FlinkLogicalCalc leftCalc = - createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate); - - newCorrelate = - new FlinkLogicalCorrelate( - correlate.getCluster(), - correlate.getTraitSet(), - leftCalc, - rightNewInput, - correlate.getCorrelationId(), - correlate.getRequiredColumns(), - correlate.getJoinType()); - } else { - newCorrelate = - new FlinkLogicalCorrelate( - 
correlate.getCluster(), - correlate.getTraitSet(), - left, - rightNewInput, - correlate.getCorrelationId(), - correlate.getRequiredColumns(), - correlate.getJoinType()); - } - - FlinkLogicalCalc newTopCalc = - createTopCalc( - primitiveLeftFieldCount, - rexBuilder, - extractedRexNodes, - correlate.getRowType(), - newCorrelate); +public class PythonCorrelateSplitRule { - call.transformTo(newTopCalc); - } + public static final RemoteCorrelateSplitRule INSTANCE = + new RemoteCorrelateSplitRule(new PythonRemoteCalcCallFinder()); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/RemoteCorrelateSplitRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/RemoteCorrelateSplitRule.java new file mode 100644 index 0000000000000..744edfda201bb --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/RemoteCorrelateSplitRule.java @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc; +import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate; +import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan; +import org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule; +import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexProgram; +import org.apache.calcite.rex.RexProgramBuilder; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.sql.validate.SqlValidatorUtil; + +import java.util.LinkedList; +import java.util.List; +import java.util.stream.Collectors; + +import scala.collection.Iterator; +import scala.collection.mutable.ArrayBuffer; + +/** + * Rule will split the Remote {@link FlinkLogicalTableFunctionScan} with Java calls or the Java + * {@link FlinkLogicalTableFunctionScan} with Remote calls into a {@link FlinkLogicalCalc} which + * will be the left input of the new {@link FlinkLogicalCorrelate} and a new {@link + * FlinkLogicalTableFunctionScan}. 
+ */ +public class RemoteCorrelateSplitRule extends RelOptRule { + private final RemoteCalcCallFinder callFinder; + + RemoteCorrelateSplitRule(RemoteCalcCallFinder callFinder) { + super( + operand(FlinkLogicalCorrelate.class, any()), + "RemoteCorrelateSplitRule-" + callFinder.getName()); + this.callFinder = callFinder; + } + + private FlinkLogicalTableFunctionScan createNewScan( + FlinkLogicalTableFunctionScan scan, ScalarFunctionSplitter splitter) { + RexCall rightRexCall = (RexCall) scan.getCall(); + // extract Java funcs from Remote TableFunction or Remote funcs from Java TableFunction. + List rightCalcProjects = + rightRexCall.getOperands().stream() + .map(x -> x.accept(splitter)) + .collect(Collectors.toList()); + + RexCall newRightRexCall = rightRexCall.clone(rightRexCall.getType(), rightCalcProjects); + return new FlinkLogicalTableFunctionScan( + scan.getCluster(), + scan.getTraitSet(), + scan.getInputs(), + newRightRexCall, + scan.getElementType(), + scan.getRowType(), + scan.getColumnMappings()); + } + + @Override + public boolean matches(RelOptRuleCall call) { + FlinkLogicalCorrelate correlate = call.rel(0); + RelNode right = ((HepRelVertex) correlate.getRight()).getCurrentRel(); + FlinkLogicalTableFunctionScan tableFunctionScan; + if (right instanceof FlinkLogicalTableFunctionScan) { + tableFunctionScan = (FlinkLogicalTableFunctionScan) right; + } else if (right instanceof FlinkLogicalCalc) { + tableFunctionScan = StreamPhysicalCorrelateRule.getTableScan((FlinkLogicalCalc) right); + } else { + return false; + } + RexNode rexNode = tableFunctionScan.getCall(); + if (rexNode instanceof RexCall) { + return callFinder.isRemoteCall(rexNode) && callFinder.containsNonRemoteCall(rexNode) + || callFinder.isNonRemoteCall(rexNode) && callFinder.containsRemoteCall(rexNode) + || (callFinder.isRemoteCall(rexNode) && RexUtil.containsFieldAccess(rexNode)); + } + return false; + } + + private List createNewFieldNames( + RelDataType rowType, + RexBuilder rexBuilder, + int 
primitiveFieldCount, + ArrayBuffer extractedRexNodes, + List calcProjects) { + for (int i = 0; i < primitiveFieldCount; i++) { + calcProjects.add(RexInputRef.of(i, rowType)); + } + // change RexCorrelVariable to RexInputRef; rebuilt calls keep their original type. + RexDefaultVisitor visitor = + new RexDefaultVisitor() { + @Override + public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + RexNode expr = fieldAccess.getReferenceExpr(); + if (expr instanceof RexCorrelVariable) { + RelDataTypeField field = fieldAccess.getField(); + return new RexInputRef(field.getIndex(), field.getType()); + } else { + return rexBuilder.makeFieldAccess( + expr.accept(this), fieldAccess.getField().getIndex()); + } + } + + @Override + public RexNode visitCall(RexCall call) { + List newProjects = + call.getOperands().stream() + .map(x -> x.accept(this)) + .collect(Collectors.toList()); + return call.clone(call.getType(), newProjects); + } + + @Override + public RexNode visitNode(RexNode rexNode) { + return rexNode; + } + }; + // add the fields of the extracted rex calls. 
+ Iterator iterator = extractedRexNodes.iterator(); + while (iterator.hasNext()) { + RexNode rexNode = iterator.next(); + if (rexNode instanceof RexCall) { + RexCall rexCall = (RexCall) rexNode; + List newProjects = + rexCall.getOperands().stream() + .map(x -> x.accept(visitor)) + .collect(Collectors.toList()); + RexCall newRexCall = rexCall.clone(rexCall.getType(), newProjects); + calcProjects.add(newRexCall); + } else { + calcProjects.add(rexNode); + } + } + + List nameList = new LinkedList<>(); + for (int i = 0; i < primitiveFieldCount; i++) { + nameList.add(rowType.getFieldNames().get(i)); + } + Iterator indicesIterator = extractedRexNodes.indices().iterator(); + while (indicesIterator.hasNext()) { + nameList.add("f" + indicesIterator.next()); + } + return SqlValidatorUtil.uniquify( + nameList, rexBuilder.getTypeFactory().getTypeSystem().isSchemaCaseSensitive()); + } + + private FlinkLogicalCalc createNewLeftCalc( + RelNode left, + RexBuilder rexBuilder, + ArrayBuffer extractedRexNodes, + FlinkLogicalCorrelate correlate) { + // add the fields of the primitive left input. + List leftCalcProjects = new LinkedList<>(); + RelDataType leftRowType = left.getRowType(); + List leftCalcCalcFieldNames = + createNewFieldNames( + leftRowType, + rexBuilder, + leftRowType.getFieldCount(), + extractedRexNodes, + leftCalcProjects); + + // create a new calc + return new FlinkLogicalCalc( + correlate.getCluster(), + correlate.getTraitSet(), + left, + RexProgram.create( + leftRowType, leftCalcProjects, null, leftCalcCalcFieldNames, rexBuilder)); + } + + private FlinkLogicalCalc createTopCalc( + int primitiveLeftFieldCount, + RexBuilder rexBuilder, + ArrayBuffer extractedRexNodes, + RelDataType calcRowType, + FlinkLogicalCorrelate newCorrelate) { + RexProgram rexProgram = + new RexProgramBuilder(newCorrelate.getRowType(), rexBuilder).getProgram(); + int offset = extractedRexNodes.size() + primitiveLeftFieldCount; + + // extract correlate output RexNode. 
+ List newTopCalcProjects = + rexProgram.getExprList().stream() + .filter(x -> x instanceof RexInputRef) + .filter( + x -> { + int index = ((RexInputRef) x).getIndex(); + return index < primitiveLeftFieldCount || index >= offset; + }) + .collect(Collectors.toList()); + + return new FlinkLogicalCalc( + newCorrelate.getCluster(), + newCorrelate.getTraitSet(), + newCorrelate, + RexProgram.create( + newCorrelate.getRowType(), + newTopCalcProjects, + null, + calcRowType, + rexBuilder)); + } + + private ScalarFunctionSplitter createScalarFunctionSplitter( + RexProgram program, + RexBuilder rexBuilder, + int primitiveLeftFieldCount, + ArrayBuffer extractedRexNodes, + RexNode tableFunctionNode) { + return new ScalarFunctionSplitter( + program, + rexBuilder, + primitiveLeftFieldCount, + extractedRexNodes, + node -> { + if (callFinder.isNonRemoteCall(tableFunctionNode)) { + // splits the RexCalls which contain Remote functions into separate node + return callFinder.isRemoteCall(node); + } else if (callFinder.containsNonRemoteCall(node)) { + // splits the RexCalls which contain non-Remote functions into separate node + return callFinder.isNonRemoteCall(node); + } else { + // splits the RexFieldAccesses which contain non-Remote functions into + // separate node + return node instanceof RexFieldAccess; + } + }, + callFinder); + } + + @Override + public void onMatch(RelOptRuleCall call) { + FlinkLogicalCorrelate correlate = call.rel(0); + RexBuilder rexBuilder = call.builder().getRexBuilder(); + RelNode left = ((HepRelVertex) correlate.getLeft()).getCurrentRel(); + RelNode right = ((HepRelVertex) correlate.getRight()).getCurrentRel(); + int primitiveLeftFieldCount = left.getRowType().getFieldCount(); + ArrayBuffer extractedRexNodes = new ArrayBuffer<>(); + + RelNode rightNewInput; + if (right instanceof FlinkLogicalTableFunctionScan) { + FlinkLogicalTableFunctionScan scan = (FlinkLogicalTableFunctionScan) right; + rightNewInput = + createNewScan( + scan, + 
createScalarFunctionSplitter( + null, + rexBuilder, + primitiveLeftFieldCount, + extractedRexNodes, + scan.getCall())); + } else { + FlinkLogicalCalc calc = (FlinkLogicalCalc) right; + FlinkLogicalTableFunctionScan scan = StreamPhysicalCorrelateRule.getTableScan(calc); + FlinkLogicalCalc mergedCalc = StreamPhysicalCorrelateRule.getMergedCalc(calc); + FlinkLogicalTableFunctionScan newScan = + createNewScan( + scan, + createScalarFunctionSplitter( + null, + rexBuilder, + primitiveLeftFieldCount, + extractedRexNodes, + scan.getCall())); + rightNewInput = + mergedCalc.copy(mergedCalc.getTraitSet(), newScan, mergedCalc.getProgram()); + } + + FlinkLogicalCorrelate newCorrelate; + if (extractedRexNodes.size() > 0) { + FlinkLogicalCalc leftCalc = + createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate); + + newCorrelate = + new FlinkLogicalCorrelate( + correlate.getCluster(), + correlate.getTraitSet(), + leftCalc, + rightNewInput, + correlate.getCorrelationId(), + correlate.getRequiredColumns(), + correlate.getJoinType()); + } else { + newCorrelate = + new FlinkLogicalCorrelate( + correlate.getCluster(), + correlate.getTraitSet(), + left, + rightNewInput, + correlate.getCorrelationId(), + correlate.getRequiredColumns(), + correlate.getJoinType()); + } + + FlinkLogicalCalc newTopCalc = + createTopCalc( + primitiveLeftFieldCount, + rexBuilder, + extractedRexNodes, + correlate.getRowType(), + newCorrelate); + + call.transformTo(newTopCalc); + } + + // Consider the rules to be equal if they are the same class and their call finders are the same + // class. 
+ @Override + public boolean equals(Object object) { + if (object == null || !object.getClass().equals(RemoteCorrelateSplitRule.class)) { + return false; + } + RemoteCorrelateSplitRule rule = (RemoteCorrelateSplitRule) object; + return callFinder.equals(rule.callFinder); + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala index d029cf6e44c2a..f77c0126e2318 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala @@ -417,7 +417,9 @@ object FlinkStreamRuleSets { // Avoid async calls which call async calls. AsyncCalcSplitRule.NESTED_SPLIT, // Avoid having async calls in multiple projections in a single calc. 
- AsyncCalcSplitRule.ONE_PER_CALC_SPLIT + AsyncCalcSplitRule.ONE_PER_CALC_SPLIT, + // Split async calls from correlates + AsyncCorrelateSplitRule.CORRELATE_SPLIT ) /** RuleSet to do physical optimize for stream */ diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala index 1756b9eb5e827..169e088738abd 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala @@ -81,6 +81,16 @@ class PythonRemoteCalcCallFinder extends RemoteCalcCallFinder { override def isNonRemoteCall(node: RexNode): Boolean = { PythonUtil.isNonPythonCall(node) } + + // Any PythonRemoteCalcCallFinder is considered equal to any other. + override def equals(obj: Any): Boolean = { + obj != null && obj.isInstanceOf[PythonRemoteCalcCallFinder] + } + + // Constant per-class hash, so that finders which are equal always share a hash code. + override def hashCode(): Int = classOf[PythonRemoteCalcCallFinder].hashCode() + + override def getName: String = "Python" } object PythonCalcSplitRule { diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcCallFinder.java b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcCallFinder.java index 5209ab1b14be7..68c0e281b529a 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcCallFinder.java +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcCallFinder.java @@ -34,4 +34,7 @@ public interface RemoteCalcCallFinder { // If the node contains directly a non-remote call. 
boolean isNonRemoteCall(RexNode node); + + // A name that can be appended onto the rule + String getName(); } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.java index e97abdf1cfcee..1098c0585b2e5 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.java @@ -22,6 +22,7 @@ import org.apache.flink.table.api.TableConfig; import org.apache.flink.table.api.TableEnvironment; import org.apache.flink.table.functions.AsyncScalarFunction; +import org.apache.flink.table.functions.TableFunction; import org.apache.flink.table.planner.plan.optimize.program.BatchOptimizeContext; import org.apache.flink.table.planner.plan.optimize.program.FlinkChainedProgram; import org.apache.flink.table.planner.plan.optimize.program.FlinkHepRuleSetProgramBuilder; @@ -82,6 +83,7 @@ public void setup() { util.addTemporarySystemFunction("func4", new Func4()); util.addTemporarySystemFunction("func5", new Func5()); util.addTemporarySystemFunction("func6", new Func6()); + util.addTemporarySystemFunction("tableFunc", new RandomTableFunction()); } @Test @@ -370,4 +372,13 @@ public void eval(CompletableFuture future, Integer param, Integer param future.complete(param + param2); } } + + /** Test function. 
*/ + public static class RandomTableFunction extends TableFunction { + + public void eval(Integer i) { + collect("blah " + i); + collect("foo " + i); + } + } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.java new file mode 100644 index 0000000000000..5bbcc57fd8d32 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.table.api.TableConfig; +import org.apache.flink.table.api.TableEnvironment; +import org.apache.flink.table.planner.plan.optimize.program.BatchOptimizeContext; +import org.apache.flink.table.planner.plan.optimize.program.FlinkChainedProgram; +import org.apache.flink.table.planner.plan.optimize.program.FlinkHepRuleSetProgramBuilder; +import org.apache.flink.table.planner.plan.optimize.program.HEP_RULES_EXECUTION_TYPE; +import org.apache.flink.table.planner.plan.rules.FlinkStreamRuleSets; +import org.apache.flink.table.planner.plan.rules.logical.AsyncCalcSplitRuleTest.Func1; +import org.apache.flink.table.planner.plan.rules.logical.AsyncCalcSplitRuleTest.RandomTableFunction; +import org.apache.flink.table.planner.utils.TableTestBase; +import org.apache.flink.table.planner.utils.TableTestUtil; + +import org.apache.calcite.plan.hep.HepMatchOrder; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** Test for {@link AsyncCorrelateSplitRule}. 
*/ +public class AsyncCorrelateSplitRuleTest extends TableTestBase { + + private TableTestUtil util = streamTestUtil(TableConfig.getDefault()); + + @BeforeEach + public void setup() { + FlinkChainedProgram programs = new FlinkChainedProgram(); + programs.addLast( + "logical_rewrite", + FlinkHepRuleSetProgramBuilder.newBuilder() + .setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE()) + .setHepMatchOrder(HepMatchOrder.BOTTOM_UP) + .add(FlinkStreamRuleSets.LOGICAL_REWRITE()) + .build()); + + TableEnvironment tEnv = util.getTableEnv(); + tEnv.executeSql( + "CREATE TABLE MyTable (\n" + + " a int,\n" + + " b bigint,\n" + + " c string,\n" + + " d ARRAY\n" + + ") ;"); + + util.addTemporarySystemFunction("func1", new Func1()); + util.addTemporarySystemFunction("tableFunc", new RandomTableFunction()); + } + + @Test + public void testCorrelateImmediate() { + String sqlQuery = "select * FROM MyTable, LATERAL TABLE(tableFunc(func1(a)))"; + util.verifyRelPlan(sqlQuery); + } + + @Test + public void testCorrelateIndirect() { + String sqlQuery = "select * FROM MyTable, LATERAL TABLE(tableFunc(ABS(func1(a))))"; + util.verifyRelPlan(sqlQuery); + } + + @Test + public void testCorrelateIndirectOtherWay() { + String sqlQuery = "select * FROM MyTable, LATERAL TABLE(tableFunc(func1(ABS(a))))"; + util.verifyRelPlan(sqlQuery); + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java index b06454f9006d7..4e4125122e71b 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java @@ -28,6 +28,7 @@ import org.apache.flink.table.api.config.ExecutionConfigOptions; import 
org.apache.flink.table.functions.AsyncScalarFunction; import org.apache.flink.table.functions.FunctionContext; +import org.apache.flink.table.functions.TableFunction; import org.apache.flink.table.planner.runtime.utils.StreamingTestBase; import org.apache.flink.types.Row; @@ -238,6 +239,22 @@ public void testFailures() { assertThat(results).containsSequence(expectedRows); } + @Test + public void testTableFuncWithAsyncCalc() { + Table t1 = tEnv.fromValues(1, 2).as("f1"); + tEnv.createTemporaryView("t1", t1); + tEnv.createTemporarySystemFunction("func", new RandomTableFunction()); + tEnv.createTemporarySystemFunction("addTen", new AsyncFuncAdd10()); + final List results = executeSql("select * FROM t1, LATERAL TABLE(func(addTen(f1)))"); + final List expectedRows = + Arrays.asList( + Row.of(1, "blah 11"), + Row.of(1, "foo 11"), + Row.of(2, "blah 12"), + Row.of(2, "foo 12")); + assertThat(results).containsSequence(expectedRows); + } + private List executeSql(String sql) { TableResult result = tEnv.executeSql(sql); final List rows = new ArrayList<>(); @@ -382,4 +399,13 @@ public void eval(CompletableFuture future, Integer param1, Integer para executor.schedule(() -> future.complete(param1 + param2), 10, TimeUnit.MILLISECONDS); } } + + /** A table function. 
*/ + public static class RandomTableFunction extends TableFunction { + + public void eval(Integer i) { + collect("blah " + i); + collect("foo " + i); + } + } } diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.xml new file mode 100644 index 0000000000000..00b283f7b1cd9 --- /dev/null +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.xml @@ -0,0 +1,84 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +