Commit e42202b

[FLINK-26518][table] Support BridgingSqlFunction with SqlTableFunction for Scala implicits
This closes apache#19137.
1 parent f8cb19e

File tree

27 files changed, +246 -196 lines changed


flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/utils/CalculatedTableFactory.java

Lines changed: 4 additions & 2 deletions

@@ -28,6 +28,7 @@
 import org.apache.flink.table.expressions.ResolvedExpression;
 import org.apache.flink.table.expressions.utils.ResolvedExpressionDefaultVisitor;
 import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.functions.FunctionKind;
 import org.apache.flink.table.operations.CalculatedQueryOperation;
 import org.apache.flink.table.operations.QueryOperation;
 import org.apache.flink.table.types.DataType;
@@ -38,6 +39,7 @@
 import java.util.List;

 import static java.util.stream.Collectors.toList;
+import static org.apache.flink.table.expressions.ApiExpressionUtils.isFunctionOfKind;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AS;

 /** Utility class for creating a valid {@link CalculatedQueryOperation} operation. */
@@ -89,7 +91,7 @@ private CalculatedQueryOperation unwrapFromAlias(CallExpression call) {
                                                 + alias)))
                         .collect(toList());

-        if (!(children.get(0) instanceof CallExpression)) {
+        if (!isFunctionOfKind(children.get(0), FunctionKind.TABLE)) {
             throw fail();
         }

@@ -156,7 +158,7 @@ protected CalculatedQueryOperation defaultMethod(ResolvedExpression expression)

     private ValidationException fail() {
         return new ValidationException(
-                "A lateral join only accepts a string expression which defines a table function "
+                "A lateral join only accepts an expression which defines a table function "
                         + "call that might be followed by some alias.");
     }
 }
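
In Table API terms, the stricter check means a lateral join now requires the outermost call to be a TABLE-kind function. A minimal Scala sketch of the effect, assuming a table t with a string column name (both are illustrative, not from this commit):

// t is an assumed Table with a string column "name"; upperCase() is a scalar
// call, so the new FunctionKind.TABLE check rejects it.
t.joinLateral($"name".upperCase())
// fails with: "A lateral join only accepts an expression which defines a table
// function call that might be followed by some alias."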

flink-table/flink-table-api-scala/src/main/scala/org/apache/flink/table/api/ImplicitExpressionConversions.scala

Lines changed: 2 additions & 6 deletions

@@ -158,17 +158,13 @@ trait ImplicitExpressionConversions {
     }
   }

-  implicit class TableFunctionCall[T: TypeInformation](val t: TableFunction[T]) {
+  implicit class TableFunctionCall(val t: TableFunction[_]) {

     /**
      * Calls a table function for the given parameters.
      */
     def apply(params: Expression*): Expression = {
-      val resultTypeInfo: TypeInformation[T] = UserDefinedFunctionHelper
-        .getReturnTypeOfTableFunction(t, implicitly[TypeInformation[T]])
-      unresolvedCall(
-        new TableFunctionDefinition(t.getClass.getName, t, resultTypeInfo),
-        params.map(ApiExpressionUtils.objectToExpression): _*)
+      unresolvedCall(t, params.map(ApiExpressionUtils.objectToExpression): _*)
     }
   }

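
With the implicit class delegating to unresolvedCall(t, ...), a Scala TableFunction can be applied inline and resolved through the new BridgingSqlFunction/SqlTableFunction path instead of the legacy TableFunctionDefinition. A minimal sketch, assuming a hypothetical SplitFunction and a table named sentences with a string column line (both assumptions, not part of this commit):

import org.apache.flink.table.api._
import org.apache.flink.table.functions.TableFunction

// Hypothetical user-defined table function.
class SplitFunction extends TableFunction[String] {
  def eval(s: String): Unit = s.split(" ").foreach(word => collect(word))
}

val split = new SplitFunction

// The implicit TableFunctionCall turns split($"line") into an unresolved call,
// which is now planned via BridgingSqlFunction / SqlTableFunction.
val words = sentences
  .joinLateral(split($"line").as("word"))
  .select($"line", $"word")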

flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/ProcedureNamespace.java

Lines changed: 16 additions & 19 deletions

@@ -21,13 +21,12 @@
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.sql.SqlCall;
 import org.apache.calcite.sql.SqlCallBinding;
-import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.SqlNode;
 import org.apache.calcite.sql.SqlOperator;
 import org.apache.calcite.sql.SqlTableFunction;
-import org.apache.calcite.sql.SqlUtil;
 import org.apache.calcite.sql.type.SqlReturnTypeInference;
-import org.apache.calcite.sql.type.SqlTypeName;
+
+import static java.util.Objects.requireNonNull;

 /**
  * Namespace whose contents are defined by the result of a call to a user-defined procedure.
@@ -56,25 +55,23 @@ public final class ProcedureNamespace extends AbstractNamespace {

     public RelDataType validateImpl(RelDataType targetRowType) {
         validator.inferUnknownTypes(validator.unknownType, scope, call);
-        final RelDataType type = validator.deriveTypeImpl(scope, call);
+        // The result is ignored but the type is derived to trigger the validation
+        validator.deriveTypeImpl(scope, call);
         final SqlOperator operator = call.getOperator();
         final SqlCallBinding callBinding = new SqlCallBinding(validator, scope, call);
-        if (operator instanceof SqlTableFunction) {
-            final SqlTableFunction tableFunction = (SqlTableFunction) operator;
-            if (type.getSqlTypeName() != SqlTypeName.CURSOR) {
-                throw new IllegalArgumentException(
-                        "Table function should have CURSOR " + "type, not " + type);
-            }
-            final SqlReturnTypeInference rowTypeInference = tableFunction.getRowTypeInference();
-            RelDataType retType = rowTypeInference.inferReturnType(callBinding);
-            return validator.getTypeFactory().createTypeWithNullability(retType, false);
-        }
-
-        // special handling of collection tables TABLE(function(...))
-        if (SqlUtil.stripAs(enclosingNode).getKind() == SqlKind.COLLECTION_TABLE) {
-            return toStruct(type, getNode());
+        if (!(operator instanceof SqlTableFunction)) {
+            throw new IllegalArgumentException(
+                    "Argument must be a table function: " + operator.getNameAsId());
         }
-        return type;
+        final SqlTableFunction tableFunction = (SqlTableFunction) operator;
+        final SqlReturnTypeInference rowTypeInference = tableFunction.getRowTypeInference();
+        final RelDataType rowRelDataType =
+                requireNonNull(
+                        rowTypeInference.inferReturnType(callBinding),
+                        () -> "got null from inferReturnType for call " + callBinding.getCall());
+        // For BridgingSqlFunction the type can still be atomic
+        // and will be wrapped with a proper field alias
+        return toStruct(rowRelDataType, getNode());
     }

     /** Converts a type to a struct if it is not already. */
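
The user-visible effect of the stricter namespace check is a clearer error when a non-table function appears inside TABLE(...). A short sketch, assuming a TableEnvironment named tEnv (the message matches the updated test further down):

// MD5 is a scalar function, so ProcedureNamespace now rejects it up front
// instead of failing later during the correlate rewrite.
tEnv.explainSql("SELECT * FROM TABLE(MD5('3'))")
// fails with: "Argument must be a table function: MD5"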

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/FlinkRelBuilder.java

Lines changed: 59 additions & 0 deletions

@@ -23,6 +23,7 @@
 import org.apache.flink.table.operations.QueryOperation;
 import org.apache.flink.table.planner.calcite.FlinkRelFactories.ExpandFactory;
 import org.apache.flink.table.planner.calcite.FlinkRelFactories.RankFactory;
+import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
 import org.apache.flink.table.planner.hint.FlinkHints;
 import org.apache.flink.table.planner.plan.QueryOperationConverter;
 import org.apache.flink.table.planner.plan.logical.LogicalWindow;
@@ -33,13 +34,15 @@
 import org.apache.flink.table.runtime.groupwindow.NamedWindowProperty;
 import org.apache.flink.table.runtime.operators.rank.RankRange;
 import org.apache.flink.table.runtime.operators.rank.RankType;
+import org.apache.flink.util.CollectionUtil;
 import org.apache.flink.util.Preconditions;

 import org.apache.flink.shaded.guava30.com.google.common.collect.ImmutableList;

 import org.apache.calcite.plan.Context;
 import org.apache.calcite.plan.Contexts;
 import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptSchema;
 import org.apache.calcite.plan.RelOptTable.ToRelContext;
 import org.apache.calcite.plan.ViewExpanders;
@@ -48,18 +51,26 @@
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.hint.RelHint;
 import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlOperator;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.Util;

 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.function.UnaryOperator;
+import java.util.stream.Collectors;

 import static org.apache.flink.table.planner.plan.utils.AggregateUtil.isTableAggregate;
@@ -105,6 +116,54 @@ public static RelBuilderFactory proto(Context context) {
         };
     }

+    /**
+     * {@link RelBuilder#functionScan(SqlOperator, int, Iterable)} cannot work smoothly with
+     * aliases which is why we implement a custom one. The method is static because some {@link
+     * RelOptRule}s don't use {@link FlinkRelBuilder}.
+     */
+    public static RelBuilder pushFunctionScan(
+            RelBuilder relBuilder,
+            SqlOperator operator,
+            int inputCount,
+            Iterable<RexNode> operands,
+            List<String> aliases) {
+        Preconditions.checkArgument(
+                operator instanceof BridgingSqlFunction.WithTableFunction,
+                "Table function expected.");
+        final RexBuilder rexBuilder = relBuilder.getRexBuilder();
+        final RelDataTypeFactory typeFactory = relBuilder.getTypeFactory();

+        final List<RelNode> inputs = new LinkedList<>();
+        for (int i = 0; i < inputCount; i++) {
+            inputs.add(0, relBuilder.build());
+        }

+        final List<RexNode> operandList = CollectionUtil.iterableToList(operands);

+        final RelDataType functionRelDataType = rexBuilder.deriveReturnType(operator, operandList);
+        final List<RelDataType> fieldRelDataTypes;
+        if (functionRelDataType.isStruct()) {
+            fieldRelDataTypes =
+                    functionRelDataType.getFieldList().stream()
+                            .map(RelDataTypeField::getType)
+                            .collect(Collectors.toList());
+        } else {
+            fieldRelDataTypes = Collections.singletonList(functionRelDataType);
+        }
+        final RelDataType rowRelDataType = typeFactory.createStructType(fieldRelDataTypes, aliases);

+        final RexNode call = rexBuilder.makeCall(rowRelDataType, operator, operandList);
+        final RelNode functionScan =
+                LogicalTableFunctionScan.create(
+                        relBuilder.getCluster(),
+                        inputs,
+                        call,
+                        null,
+                        rowRelDataType,
+                        Collections.emptySet());
+        return relBuilder.push(functionScan);
+    }

     public RelBuilder expand(List<List<RexNode>> projects, int expandIdIndex) {
         final RelNode input = build();
         final RelNode expand = expandFactory.createExpand(input, projects, expandIdIndex);
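
For orientation, a sketch of how planner code is expected to call the new static helper, mirroring the call sites changed below in QueryOperationConverter and SetOpRewriteUtil; relBuilder, sqlFunction, operands, and the alias names are assumptions:

import org.apache.flink.table.planner.calcite.FlinkRelBuilder
import scala.collection.JavaConverters._

// sqlFunction must be a BridgingSqlFunction.WithTableFunction, otherwise the
// precondition inside pushFunctionScan fails.
FlinkRelBuilder.pushFunctionScan(
  relBuilder,
  sqlFunction,
  0,        // no inputs are popped from the builder stack
  operands, // the RexNode operands forming the table function call
  List("word", "frequency").asJava) // field aliases for the produced row type

// the LogicalTableFunctionScan with the aliased row type now sits on top of the stack
val scan = relBuilder.build()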

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/bridging/BridgingSqlFunction.java

Lines changed: 38 additions & 1 deletion

@@ -26,6 +26,7 @@
 import org.apache.flink.table.functions.FunctionIdentifier;
 import org.apache.flink.table.functions.FunctionKind;
 import org.apache.flink.table.planner.calcite.FlinkContext;
+import org.apache.flink.table.planner.calcite.FlinkRelBuilder;
 import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
 import org.apache.flink.table.planner.utils.ShortcutUtils;
 import org.apache.flink.table.types.DataType;
@@ -35,6 +36,10 @@
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.sql.SqlFunction;
 import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.SqlTableFunction;
+import org.apache.calcite.sql.type.SqlReturnTypeInference;
+import org.apache.calcite.tools.RelBuilder;

 import java.util.List;
@@ -52,7 +57,7 @@
  * (either a system or user-defined function).
  */
 @Internal
-public final class BridgingSqlFunction extends SqlFunction {
+public class BridgingSqlFunction extends SqlFunction {

     private final DataTypeFactory dataTypeFactory;
@@ -108,6 +113,10 @@ public static BridgingSqlFunction of(
                 functionKind == FunctionKind.SCALAR || functionKind == FunctionKind.TABLE,
                 "Scalar or table function kind expected.");

+        if (functionKind == FunctionKind.TABLE) {
+            return new BridgingSqlFunction.WithTableFunction(
+                    dataTypeFactory, typeFactory, kind, resolvedFunction, typeInference);
+        }
         return new BridgingSqlFunction(
                 dataTypeFactory, typeFactory, kind, resolvedFunction, typeInference);
     }
@@ -177,4 +186,32 @@ public List<String> getParamNames() {
     public boolean isDeterministic() {
         return resolvedFunction.getDefinition().isDeterministic();
     }

+    // --------------------------------------------------------------------------------------------
+    // Table function extension
+    // --------------------------------------------------------------------------------------------

+    /** Special flavor of {@link BridgingSqlFunction} to indicate a table function to Calcite. */
+    public static class WithTableFunction extends BridgingSqlFunction implements SqlTableFunction {

+        private WithTableFunction(
+                DataTypeFactory dataTypeFactory,
+                FlinkTypeFactory typeFactory,
+                SqlKind kind,
+                ContextResolvedFunction resolvedFunction,
+                TypeInference typeInference) {
+            super(dataTypeFactory, typeFactory, kind, resolvedFunction, typeInference);
+        }

+        /**
+         * The conversion to a row type is handled on the caller side. This allows us to perform it
+         * SQL/Table API-specific. This is in particular important to set the aliases of fields
+         * correctly (see {@link FlinkRelBuilder#pushFunctionScan(RelBuilder, SqlOperator, int,
+         * Iterable, List)}).
+         */
+        @Override
+        public SqlReturnTypeInference getRowTypeInference() {
+            return getReturnTypeInference();
+        }
+    }
 }
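
A brief sketch of the changed factory behavior, with cluster and resolvedFunction as placeholders for a RelOptCluster and a resolved TABLE-kind function: the returned operator now also implements Calcite's SqlTableFunction, which is exactly what the updated ProcedureNamespace checks for.

import org.apache.calcite.sql.SqlTableFunction
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction

val sqlFunction = BridgingSqlFunction.of(cluster, resolvedFunction)

sqlFunction match {
  case tf: SqlTableFunction =>
    // TABLE-kind functions come back as BridgingSqlFunction.WithTableFunction;
    // the row type itself is still built by the caller (see pushFunctionScan).
    tf.getRowTypeInference
  case _ =>
    // SCALAR-kind functions remain plain BridgingSqlFunction instances.
}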

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/QueryOperationConverter.java

Lines changed: 8 additions & 4 deletions

@@ -305,10 +305,14 @@ public RelNode visit(CalculatedQueryOperation calculatedTable) {
         final BridgingSqlFunction sqlFunction =
                 BridgingSqlFunction.of(relBuilder.getCluster(), resolvedFunction);

-        return relBuilder
-                .functionScan(sqlFunction, 0, parameters)
-                .rename(calculatedTable.getResolvedSchema().getColumnNames())
-                .build();
+        FlinkRelBuilder.pushFunctionScan(
+                relBuilder,
+                sqlFunction,
+                0,
+                parameters,
+                calculatedTable.getResolvedSchema().getColumnNames());

+        return relBuilder.build();
     }

     private RelNode convertLegacyTableFunction(

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/SetOpRewriteUtil.scala

Lines changed: 8 additions & 4 deletions

@@ -20,6 +20,7 @@ package org.apache.flink.table.planner.plan.utils


 import org.apache.flink.table.functions.BuiltInFunctionDefinitions
+import org.apache.flink.table.planner.calcite.FlinkRelBuilder
 import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction

 import org.apache.calcite.plan.RelOptUtil
@@ -73,12 +74,15 @@ object SetOpRewriteUtil {
     val cluster = relBuilder.getCluster

     val sqlFunction = BridgingSqlFunction.of(
-      relBuilder.getCluster,
+      cluster,
       BuiltInFunctionDefinitions.INTERNAL_REPLICATE_ROWS)

-    relBuilder
-      .functionScan(sqlFunction, 0, relBuilder.fields(Util.range(fields.size() + 1)))
-      .rename(outputRelDataType.getFieldNames)
+    FlinkRelBuilder.pushFunctionScan(
+      relBuilder,
+      sqlFunction,
+      0,
+      relBuilder.fields(Util.range(fields.size() + 1)),
+      outputRelDataType.getFieldNames)

     // correlated join
     val corSet = Collections.singleton(cluster.createCorrel())

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java

Lines changed: 2 additions & 4 deletions

@@ -846,8 +846,7 @@ public void testInvalidUseOfSystemScalarFunction() {
                         tEnv().explainSql(
                                         "INSERT INTO SinkTable "
                                                 + "SELECT * FROM TABLE(MD5('3'))"))
-                .hasMessageContaining(
-                        "Currently, only table functions can be used in a correlate operation.");
+                .hasMessageContaining("Argument must be a table function: MD5");
     }

     @Test
@@ -864,8 +863,7 @@ public void testInvalidUseOfTableFunction() {
                         tEnv().explainSql(
                                         "INSERT INTO SinkTable "
                                                 + "SELECT RowTableFunction('test')"))
-                .hasMessageContaining(
-                        "Currently, only scalar functions can be used in a projection or filter operation.");
+                .hasMessageContaining("Cannot call table function here: 'RowTableFunction'");
     }

     @Test
