Skip to content

Commit 1c630c6

Browse files
add stddev udf
1 parent d87ba84 commit 1c630c6

File tree

3 files changed

+199
-0
lines changed

3 files changed

+199
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package com.airwallex.airskiff.flink.udx;
2+
3+
import org.apache.flink.table.api.dataview.ListView;
4+
5+
public class StdDevAccumulator {
6+
public ListView<Double> nums = new ListView<>();
7+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package com.airwallex.airskiff.flink.udx;
2+
3+
import org.apache.flink.table.api.dataview.ListView;
4+
import org.apache.flink.table.functions.AggregateFunction;
5+
6+
import java.util.List;
7+
8+
public class StdDevFunction extends AggregateFunction<Double, StdDevAccumulator> {
9+
@Override
10+
public Double getValue(StdDevAccumulator acc) {
11+
List<Double> list = acc.nums.getList();
12+
if(list.size() == 0 || list.size() == 1) {
13+
return null;
14+
} else {
15+
Double avg = list.stream().mapToDouble(a -> a).average().orElse(0.0);
16+
Double sum = 0.0;
17+
for(Double num : list) {
18+
sum += (num-avg)*(num-avg);
19+
}
20+
return Math.sqrt(sum/(list.size()-1));
21+
}
22+
}
23+
24+
@Override
25+
public StdDevAccumulator createAccumulator() {
26+
return new StdDevAccumulator();
27+
}
28+
29+
public void accumulate(StdDevAccumulator acc, Double value) throws Exception {
30+
acc.nums.add(value);
31+
}
32+
33+
public void retract(StdDevAccumulator acc, Double value) throws Exception {
34+
acc.nums.remove(value);
35+
}
36+
37+
public void resetAccumulator(StdDevAccumulator acc) {
38+
acc.nums.clear();
39+
}
40+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package com.airwallex.airskiff.flink.udx;
2+
3+
import org.apache.flink.api.common.RuntimeExecutionMode;
4+
import org.apache.flink.api.common.eventtime.*;
5+
import org.apache.flink.api.java.tuple.Tuple3;
6+
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
7+
import org.apache.flink.streaming.api.datastream.DataStream;
8+
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
9+
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
10+
import org.apache.flink.table.api.Table;
11+
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
12+
import org.apache.flink.test.util.MiniClusterWithClientResource;
13+
import org.apache.flink.types.Row;
14+
import org.junit.After;
15+
import org.junit.Before;
16+
import org.junit.Test;
17+
18+
import java.util.ArrayList;
19+
import java.util.Collections;
20+
import java.util.List;
21+
22+
import static org.apache.flink.table.api.Expressions.$;
23+
import static org.junit.Assert.assertEquals;
24+
import static org.junit.Assert.assertNotEquals;
25+
26+
public class StdDevFunctionTest {
27+
public static MiniClusterWithClientResource flinkCluster;
28+
public StreamExecutionEnvironment env;
29+
public StreamTableEnvironment tableEnv;
30+
31+
@Before
32+
public void setup() throws Exception {
33+
flinkCluster = new MiniClusterWithClientResource(
34+
new MiniClusterResourceConfiguration.Builder()
35+
.setNumberSlotsPerTaskManager(1)
36+
.setNumberTaskManagers(1)
37+
.build());
38+
flinkCluster.before();
39+
env = StreamExecutionEnvironment.getExecutionEnvironment();
40+
env.setRuntimeMode(RuntimeExecutionMode.STREAMING);
41+
tableEnv = StreamTableEnvironment.create(env);
42+
tableEnv.createTemporarySystemFunction("ASStddev", StdDevFunction.class);
43+
}
44+
45+
@Test
46+
public void testSQL() throws Exception {
47+
Sink.values.clear();
48+
DataStream<Tuple3<Long, String, Double>> source = env.fromElements(
49+
new Tuple3(1708403000000L, "a", 3.0), // null
50+
new Tuple3(1708403300000L, "a", 6.0), // 2.1213...
51+
new Tuple3(1708403600000L, "a", 9.0), // stddev([3.0, 6.0, 9.0]) = 3.0
52+
new Tuple3(1708403900000L, "a", 12.0) // stddev([6.0, 9.0, 12.0]) = 3.0
53+
);
54+
DataStream<Tuple3<Long, String, Double>> ds = source.assignTimestampsAndWatermarks(
55+
WatermarkStrategy.<Tuple3<Long, String, Double>>forMonotonousTimestamps().withTimestampAssigner(
56+
(t, x) -> t.f0)
57+
);
58+
tableEnv.createTemporaryView("tmp", ds, $("f0"), $("f1"), $("f2"), $("f0").rowtime().as("row_time"));
59+
60+
String sql =
61+
"SELECT f0, f1 " +
62+
" ,ASStddev(f2) OVER (PARTITION BY f1 ORDER BY row_time RANGE BETWEEN INTERVAL '10' MINUTE PRECEDING AND CURRENT ROW) " +
63+
" FROM tmp";
64+
Table t = tableEnv.sqlQuery(sql);
65+
tableEnv.toDataStream(t).addSink(new Sink());
66+
env.execute();
67+
assertEquals(4, Sink.values.size());
68+
assertEquals(1708403000000L, Sink.values.get(0).getField(0));
69+
assertEquals("a", Sink.values.get(0).getField(1));
70+
assertEquals(null, Sink.values.get(0).getField(2));
71+
assertEquals(1708403600000L, Sink.values.get(2).getField(0));
72+
assertEquals("a", Sink.values.get(2).getField(1));
73+
assertEquals(3.0, Sink.values.get(2).getField(2));
74+
assertEquals(1708403900000L, Sink.values.get(3).getField(0));
75+
assertEquals("a", Sink.values.get(3).getField(1));
76+
assertEquals(3.0, Sink.values.get(3).getField(2));
77+
}
78+
79+
@Test
80+
public void compareWithLibStddev() throws Exception {
81+
Sink.values.clear();
82+
List<Tuple3<Long, String, Double>> list = new ArrayList<>();
83+
for(int i = 0; i < 1000; i++) {
84+
list.add(new Tuple3(1708403000000L + i*60000, "a", i*1.0));
85+
}
86+
DataStream<Tuple3<Long, String, Double>> source = env.fromCollection(list);
87+
DataStream<Tuple3<Long, String, Double>> ds = source.assignTimestampsAndWatermarks(
88+
WatermarkStrategy.<Tuple3<Long, String, Double>>forMonotonousTimestamps().withTimestampAssigner(
89+
(t, x) -> t.f0)
90+
);
91+
tableEnv.createTemporaryView("tmp", ds, $("f0"), $("f1"), $("f2"), $("f0").rowtime().as("row_time"));
92+
93+
String sql =
94+
"SELECT f0, f1 " +
95+
" ,ASStddev(f2) OVER (PARTITION BY f1 ORDER BY row_time RANGE BETWEEN INTERVAL '10' MINUTE PRECEDING AND CURRENT ROW) " +
96+
" ,stddev(f2) OVER (PARTITION BY f1 ORDER BY row_time RANGE BETWEEN INTERVAL '10' MINUTE PRECEDING AND CURRENT ROW) " +
97+
" FROM tmp";
98+
Table t = tableEnv.sqlQuery(sql);
99+
tableEnv.toDataStream(t).addSink(new Sink());
100+
env.execute();
101+
for(int i = 0; i < 1000; i++) {
102+
assertEquals(Sink.values.get(i).getField(2), Sink.values.get(i).getField(3));
103+
}
104+
}
105+
106+
@Test
107+
public void diffWithLibStddev() throws Exception {
108+
Sink.values.clear();
109+
List<Tuple3<Long, String, Double>> list = new ArrayList<>();
110+
for(int i = 0; i < 1000; i++) {
111+
list.add(new Tuple3(1708403000000L + i*60000, "a", 0.001));
112+
}
113+
DataStream<Tuple3<Long, String, Double>> source = env.fromCollection(list);
114+
DataStream<Tuple3<Long, String, Double>> ds = source.assignTimestampsAndWatermarks(
115+
WatermarkStrategy.<Tuple3<Long, String, Double>>forMonotonousTimestamps().withTimestampAssigner(
116+
(t, x) -> t.f0)
117+
);
118+
tableEnv.createTemporaryView("tmp", ds, $("f0"), $("f1"), $("f2"), $("f0").rowtime().as("row_time"));
119+
120+
String sql =
121+
"SELECT f0, f1 " +
122+
" ,ASStddev(f2) OVER (PARTITION BY f1 ORDER BY row_time RANGE BETWEEN INTERVAL '10' MINUTE PRECEDING AND CURRENT ROW) " +
123+
" ,stddev(f2) OVER (PARTITION BY f1 ORDER BY row_time RANGE BETWEEN INTERVAL '10' MINUTE PRECEDING AND CURRENT ROW) " +
124+
" FROM tmp";
125+
Table t = tableEnv.sqlQuery(sql);
126+
tableEnv.toDataStream(t).addSink(new Sink());
127+
env.execute();
128+
for(int i = 1; i < 1000; i++) {
129+
assertEquals(0.0, Sink.values.get(i).getField(2));
130+
}
131+
int NaNCnt = 0;
132+
for(int i = 1; i < 1000; i++) {
133+
if(Double.isNaN((Double)Sink.values.get(i).getField(3))) {
134+
NaNCnt++;
135+
}
136+
}
137+
assertNotEquals(0, NaNCnt);
138+
}
139+
140+
@After
141+
public void tearDown() throws Exception {
142+
flinkCluster.after();
143+
}
144+
145+
private static class Sink implements SinkFunction<Row> {
146+
public static final List<Row> values = Collections.synchronizedList(new ArrayList<>());
147+
@Override
148+
public void invoke(Row value, SinkFunction.Context context) throws Exception {
149+
values.add(value);
150+
}
151+
}
152+
}

0 commit comments

Comments
 (0)