Skip to content

Commit 96b79eb

Browse files
authored
[FLINK-36615] Support LEAD/LAG functions in Table API (apache#25582)
1 parent 96bb45d commit 96b79eb

File tree

22 files changed

+509
-64
lines changed

22 files changed

+509
-64
lines changed

flink-python/pyflink/table/expressions.py

+22
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,28 @@ def json_array_agg(on_null: JsonOnNull, item_expr) -> Expression:
867867
return _binary_op("jsonArrayAgg", on_null._to_j_json_on_null(), item_expr)
868868

869869

870+
def lag(expr, offset=1, default=None) -> Expression:
871+
"""
872+
A window function that provides access to a row at a specified physical offset which comes
873+
before the current row.
874+
"""
875+
if default is None:
876+
return _binary_op("lag", expr, offset)
877+
else:
878+
return _ternary_op("lag", expr, offset, default)
879+
880+
881+
def lead(expr, offset=1, default=None) -> Expression:
882+
"""
883+
A window function that provides access to a row at a specified physical offset which comes
884+
after the current row.
885+
"""
886+
if default is None:
887+
return _binary_op("lead", expr, offset)
888+
else:
889+
return _ternary_op("lead", expr, offset, default)
890+
891+
870892
def call(f: Union[str, UserDefinedFunctionWrapper], *args) -> Expression:
871893
"""
872894
The first parameter `f` could be a str or a Python user-defined function.

flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/Expressions.java

+120
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,126 @@ public static ApiExpression jsonArrayAgg(JsonOnNull onNull, Object itemExpr) {
947947
return apiCall(functionDefinition, itemExpr);
948948
}
949949

950+
/**
951+
* A window function that provides access to a row that comes directly after the current row.
952+
*
953+
* <p>Example:
954+
*
955+
* <pre>{@code
956+
* table.window(Over.partitionBy($("organisation")).orderBy($("ts")).as("w"))
957+
* .select(
958+
* $("organisation"),
959+
* $("revenue"),
960+
* lead($("revenue")).over($("w")).as("next_revenue")
961+
* )
962+
* }</pre>
963+
*/
964+
public static ApiExpression lead(Object value) {
965+
return apiCall(BuiltInFunctionDefinitions.LEAD, value);
966+
}
967+
968+
/**
969+
* A window function that provides access to a row at a specified physical offset which comes
970+
* after the current row.
971+
*
972+
* <p>Example:
973+
*
974+
* <pre>{@code
975+
* table.window(Over.partitionBy($("organisation")).orderBy($("ts")).as("w"))
976+
* .select(
977+
* $("organisation"),
978+
* $("revenue"),
979+
* lead($("revenue"), 1).over($("w")).as("next_revenue")
980+
* )
981+
* }</pre>
982+
*/
983+
public static ApiExpression lead(Object value, Object offset) {
984+
return apiCall(BuiltInFunctionDefinitions.LEAD, value, offset);
985+
}
986+
987+
/**
988+
* A window function that provides access to a row at a specified physical offset which comes
989+
* after the current row.
990+
*
991+
* <p>{@code defaultValue} is the value to return when the offset is beyond the scope of the
992+
* partition. If no default value is specified, NULL is returned. {@code defaultValue} must be
993+
* type-compatible with {@code value}.
994+
*
995+
* <p>Example:
996+
*
997+
* <pre>{@code
998+
* table.window(Over.partitionBy($("organisation")).orderBy($("ts")).as("w"))
999+
* .select(
1000+
* $("organisation"),
1001+
* $("revenue"),
1002+
* lead($("revenue"), 1, lit(0)).over($("w")).as("next_revenue")
1003+
* )
1004+
* }</pre>
1005+
*/
1006+
public static ApiExpression lead(Object value, Object offset, Object defaultValue) {
1007+
return apiCall(BuiltInFunctionDefinitions.LEAD, value, offset, defaultValue);
1008+
}
1009+
1010+
/**
1011+
* A window function that provides access to a row that comes directly before the current row.
1012+
*
1013+
* <p>Example:
1014+
*
1015+
* <pre>{@code
1016+
* table.window(Over.partitionBy($("organisation")).orderBy($("ts")).as("w"))
1017+
* .select(
1018+
* $("organisation"),
1019+
* $("revenue"),
1020+
* lag($("revenue")).over($("w")).as("prev_revenue")
1021+
* )
1022+
* }</pre>
1023+
*/
1024+
public static ApiExpression lag(Object value) {
1025+
return apiCall(BuiltInFunctionDefinitions.LAG, value);
1026+
}
1027+
1028+
/**
1029+
* A window function that provides access to a row at a specified physical offset which comes
1030+
* before the current row.
1031+
*
1032+
* <p>Example:
1033+
*
1034+
* <pre>{@code
1035+
* table.window(Over.partitionBy($("organisation")).orderBy($("ts")).as("w"))
1036+
* .select(
1037+
* $("organisation"),
1038+
* $("revenue"),
1039+
* lag($("revenue"), 1).over($("w")).as("prev_revenue")
1040+
* )
1041+
* }</pre>
1042+
*/
1043+
public static ApiExpression lag(Object value, Object offset) {
1044+
return apiCall(BuiltInFunctionDefinitions.LAG, value, offset);
1045+
}
1046+
1047+
/**
1048+
* A window function that provides access to a row at a specified physical offset which comes
1049+
* before the current row.
1050+
*
1051+
* <p>{@code defaultValue} is the value to return when the offset is beyond the scope of the
1052+
* partition. If no default value is specified, NULL is returned. {@code defaultValue} must be
1053+
* type-compatible with {@code value}.
1054+
*
1055+
* <p>Example:
1056+
*
1057+
* <pre>{@code
1058+
* table.window(Over.partitionBy($("organisation")).orderBy($("ts")).as("w"))
1059+
* .select(
1060+
* $("organisation"),
1061+
* $("revenue"),
1062+
* lag($("revenue"), 1, lit(0)).over($("w")).as("prev_revenue")
1063+
* )
1064+
* }</pre>
1065+
*/
1066+
public static ApiExpression lag(Object value, Object offset, Object defaultValue) {
1067+
return apiCall(BuiltInFunctionDefinitions.LAG, value, offset, defaultValue);
1068+
}
1069+
9501070
/**
9511071
* A call to a function that will be looked up in a catalog. There are two kinds of functions:
9521072
*

flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/OverWindow.java

+9-7
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import org.apache.flink.annotation.PublicEvolving;
2222
import org.apache.flink.table.expressions.Expression;
2323

24+
import javax.annotation.Nullable;
25+
2426
import java.util.List;
2527
import java.util.Optional;
2628

@@ -36,15 +38,15 @@ public final class OverWindow {
3638
private final Expression alias;
3739
private final List<Expression> partitioning;
3840
private final Expression order;
39-
private final Expression preceding;
40-
private final Optional<Expression> following;
41+
private final @Nullable Expression preceding;
42+
private final @Nullable Expression following;
4143

4244
OverWindow(
4345
Expression alias,
4446
List<Expression> partitionBy,
4547
Expression orderBy,
46-
Expression preceding,
47-
Optional<Expression> following) {
48+
@Nullable Expression preceding,
49+
@Nullable Expression following) {
4850
this.alias = alias;
4951
this.partitioning = partitionBy;
5052
this.order = orderBy;
@@ -64,11 +66,11 @@ public Expression getOrder() {
6466
return order;
6567
}
6668

67-
public Expression getPreceding() {
68-
return preceding;
69+
public Optional<Expression> getPreceding() {
70+
return Optional.ofNullable(preceding);
6971
}
7072

7173
public Optional<Expression> getFollowing() {
72-
return following;
74+
return Optional.ofNullable(following);
7375
}
7476
}

flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/OverWindowPartitionedOrdered.java

+1-8
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@
2222
import org.apache.flink.table.expressions.Expression;
2323

2424
import java.util.List;
25-
import java.util.Optional;
2625

2726
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
28-
import static org.apache.flink.table.expressions.ApiExpressionUtils.valueLiteral;
2927

3028
/** Partially defined over window with (optional) partitioning and order. */
3129
@PublicEvolving
@@ -66,11 +64,6 @@ public OverWindow as(String alias) {
6664
* @return the fully defined over window
6765
*/
6866
public OverWindow as(Expression alias) {
69-
return new OverWindow(
70-
alias,
71-
partitionBy,
72-
orderBy,
73-
valueLiteral(OverWindowRange.UNBOUNDED_RANGE),
74-
Optional.empty());
67+
return new OverWindow(alias, partitionBy, orderBy, null, null);
7568
}
7669
}

flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/OverWindowPartitionedOrderedPreceding.java

+5-4
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
import org.apache.flink.annotation.PublicEvolving;
2222
import org.apache.flink.table.expressions.Expression;
2323

24+
import javax.annotation.Nullable;
25+
2426
import java.util.List;
25-
import java.util.Optional;
2627

2728
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
2829

@@ -32,8 +33,8 @@ public final class OverWindowPartitionedOrderedPreceding {
3233

3334
private final List<Expression> partitionBy;
3435
private final Expression orderBy;
35-
private final Expression preceding;
36-
private Optional<Expression> optionalFollowing = Optional.empty();
36+
private final @Nullable Expression preceding;
37+
private @Nullable Expression optionalFollowing = null;
3738

3839
OverWindowPartitionedOrderedPreceding(
3940
List<Expression> partitionBy, Expression orderBy, Expression preceding) {
@@ -69,7 +70,7 @@ public OverWindow as(Expression alias) {
6970
* @return an over window with defined following
7071
*/
7172
public OverWindowPartitionedOrderedPreceding following(Expression following) {
72-
optionalFollowing = Optional.of(following);
73+
optionalFollowing = following;
7374
return this;
7475
}
7576
}

flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/ExpressionResolver.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ private LocalOverWindow resolveOverWindow(OverWindow overWindow) {
367367
overWindow.getAlias(),
368368
prepareExpressions(overWindow.getPartitioning()),
369369
resolveFieldsInSingleExpression(overWindow.getOrder()),
370-
resolveFieldsInSingleExpression(overWindow.getPreceding()),
370+
overWindow.getPreceding().map(this::resolveFieldsInSingleExpression).orElse(null),
371371
overWindow.getFollowing().map(this::resolveFieldsInSingleExpression).orElse(null));
372372
}
373373

flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/LocalOverWindow.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ public final class LocalOverWindow {
3636

3737
private Expression orderBy;
3838

39-
private Expression preceding;
39+
private @Nullable Expression preceding;
4040

4141
private @Nullable Expression following;
4242

4343
LocalOverWindow(
4444
Expression alias,
4545
List<Expression> partitionBy,
4646
Expression orderBy,
47-
Expression preceding,
47+
@Nullable Expression preceding,
4848
@Nullable Expression following) {
4949
this.alias = alias;
5050
this.partitionBy = partitionBy;
@@ -65,8 +65,8 @@ public Expression getOrderBy() {
6565
return orderBy;
6666
}
6767

68-
public Expression getPreceding() {
69-
return preceding;
68+
public Optional<Expression> getPreceding() {
69+
return Optional.ofNullable(preceding);
7070
}
7171

7272
public Optional<Expression> getFollowing() {

flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/OverWindowResolverRule.java

+27-14
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package org.apache.flink.table.expressions.resolver.rules;
2020

2121
import org.apache.flink.annotation.Internal;
22+
import org.apache.flink.table.api.DataTypes;
2223
import org.apache.flink.table.api.OverWindowRange;
2324
import org.apache.flink.table.api.ValidationException;
2425
import org.apache.flink.table.expressions.Expression;
@@ -33,7 +34,6 @@
3334
import java.util.List;
3435
import java.util.stream.Collectors;
3536

36-
import static java.util.Arrays.asList;
3737
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedCall;
3838
import static org.apache.flink.table.expressions.ApiExpressionUtils.valueLiteral;
3939
import static org.apache.flink.table.types.logical.LogicalTypeRoot.BIGINT;
@@ -76,14 +76,29 @@ public Expression visit(UnresolvedCallExpression unresolvedCall) {
7676
new ValidationException(
7777
"Could not resolve over call."));
7878

79-
Expression following = calculateOverWindowFollowing(referenceWindow);
80-
List<Expression> newArgs =
81-
new ArrayList<>(
82-
asList(
83-
children.get(0),
84-
referenceWindow.getOrderBy(),
85-
referenceWindow.getPreceding(),
86-
following));
79+
UnresolvedCallExpression agg = (UnresolvedCallExpression) children.get(0);
80+
81+
final List<Expression> newArgs = new ArrayList<>();
82+
newArgs.add(agg);
83+
newArgs.add(referenceWindow.getOrderBy());
84+
if (agg.getFunctionDefinition() == BuiltInFunctionDefinitions.LAG
85+
|| agg.getFunctionDefinition() == BuiltInFunctionDefinitions.LEAD) {
86+
if (referenceWindow.getPreceding().isPresent()
87+
|| referenceWindow.getFollowing().isPresent()) {
88+
throw new ValidationException(
89+
"LEAD/LAG functions do not support "
90+
+ "providing RANGE/ROW bounds.");
91+
}
92+
newArgs.add(valueLiteral(null, DataTypes.NULL()));
93+
newArgs.add(valueLiteral(null, DataTypes.NULL()));
94+
} else {
95+
Expression preceding =
96+
referenceWindow
97+
.getPreceding()
98+
.orElse(valueLiteral(OverWindowRange.UNBOUNDED_RANGE));
99+
newArgs.add(preceding);
100+
newArgs.add(calculateOverWindowFollowing(referenceWindow, preceding));
101+
}
87102

88103
newArgs.addAll(referenceWindow.getPartitionBy());
89104

@@ -96,15 +111,13 @@ public Expression visit(UnresolvedCallExpression unresolvedCall) {
96111
}
97112
}
98113

99-
private Expression calculateOverWindowFollowing(LocalOverWindow referenceWindow) {
114+
private Expression calculateOverWindowFollowing(
115+
LocalOverWindow referenceWindow, Expression preceding) {
100116
return referenceWindow
101117
.getFollowing()
102118
.orElseGet(
103119
() -> {
104-
WindowKind kind =
105-
referenceWindow
106-
.getPreceding()
107-
.accept(OVER_WINDOW_KIND_EXTRACTOR);
120+
WindowKind kind = preceding.accept(OVER_WINDOW_KIND_EXTRACTOR);
108121
if (kind == WindowKind.ROW) {
109122
return valueLiteral(OverWindowRange.CURRENT_ROW);
110123
} else {

flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java

+18
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,24 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL)
782782
.outputTypeStrategy(TypeStrategies.aggArg0(t -> t, true))
783783
.build();
784784

785+
public static final BuiltInFunctionDefinition LEAD =
786+
BuiltInFunctionDefinition.newBuilder()
787+
.name("lead")
788+
.kind(AGGREGATE)
789+
.inputTypeStrategy(SpecificInputTypeStrategies.LEAD_LAG)
790+
.outputTypeStrategy(SpecificTypeStrategies.LEAD_LAG)
791+
.runtimeDeferred()
792+
.build();
793+
794+
public static final BuiltInFunctionDefinition LAG =
795+
BuiltInFunctionDefinition.newBuilder()
796+
.name("lag")
797+
.kind(AGGREGATE)
798+
.inputTypeStrategy(SpecificInputTypeStrategies.LEAD_LAG)
799+
.outputTypeStrategy(SpecificTypeStrategies.LEAD_LAG)
800+
.runtimeDeferred()
801+
.build();
802+
785803
public static final BuiltInFunctionDefinition LISTAGG =
786804
BuiltInFunctionDefinition.newBuilder()
787805
.name("listAgg")

0 commit comments

Comments
 (0)