counts = Tuple3.of(0, 0, 0);
+ summaries.forEach(
+ (item, summary) -> {
+ counts.f0 += 1;
+ counts.f1 += summary.count;
+ Node p = summary.head;
+ while (p != null) {
+ counts.f2 += 1;
+ p = p.successor;
+ }
+ });
+ LOG.info("fptree_profile {}", counts);
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/fpgrowth/FPGrowth.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/fpgrowth/FPGrowth.java
new file mode 100644
index 000000000..ea39d2224
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/fpgrowth/FPGrowth.java
@@ -0,0 +1,907 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation.fpgrowth;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichFilterFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.fpgrowth.AssociationRule;
+import org.apache.flink.ml.common.fpgrowth.FPTree;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.StringUtils;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.Set;
+
+/**
+ * An implementation of the parallel FP-growth algorithm to mine frequent itemsets.
+ *
+ * <p>For detailed descriptions, please refer to: Han et al., "Mining frequent patterns without
+ * candidate generation", and Li et al., "PFP: Parallel FP-growth for query recommendation".
+ */
+public class FPGrowth implements AlgoOperator<FPGrowth>, FPGrowthParams<FPGrowth> {
+ private static final Logger LOG = LoggerFactory.getLogger(FPGrowth.class);
+ private static final String ITEM_INDEX = "ITEM_INDEX";
+ private static final String[] FREQ_PATTERN_OUTPUT_COLS = {
+ "items", "support_count", "item_count"
+ };
+ private static final String[] RULES_OUTPUT_COLS = {
+ "rule", "item_count", "lift", "support_percent", "confidence_percent", "transaction_count"
+ };
+
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public FPGrowth() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ Table data = inputs[0];
+ final String itemColName = getItemsCol();
+ final double minLift = getMinLift();
+ final double minConfidence = getMinConfidence();
+ final int maxPatternLen = getMaxPatternLength();
+ final String fieldDelimiter = getFieldDelimiter();
+ StreamTableEnvironment tenv =
+ (StreamTableEnvironment) ((TableImpl) data).getTableEnvironment();
+
+ DataStream<String[]> itemTokens =
+ tenv.toDataStream(data)
+ .map(
+ new MapFunction<Row, String[]>() {
+
+ @Override
+ public String[] map(Row value) throws Exception {
+ Set<String> itemset = new HashSet<>();
+ String itemsetStr = (String) value.getField(itemColName);
+ if (!StringUtils.isNullOrWhitespaceOnly(itemsetStr)) {
+ String[] split = itemsetStr.split(fieldDelimiter);
+ itemset.addAll(Arrays.asList(split));
+ }
+ return itemset.toArray(new String[0]);
+ }
+ })
+ .name("scan_transaction");
+
+ DataStream<String> items =
+ itemTokens
+ .flatMap(
+ new FlatMapFunction<String[], String>() {
+ @Override
+ public void flatMap(
+ String[] strings, Collector<String> collector)
+ throws Exception {
+ for (String s : strings) {
+ collector.collect(s);
+ }
+ }
+ })
+ .returns(Types.STRING);
+
+ // Count the total number of transactions.
+ DataStream<Tuple2<Integer, Integer>> transactionCount = countRecords(itemTokens);
+ // Generate a DataStream holding the minimal support count.
+ final Double minSupport = getMinSupport();
+ final int minSupportThreshold = getMinSupportCount();
+ DataStream<Double> minSupportStream =
+ calculateMinSupport(tenv, transactionCount, minSupport, minSupportThreshold);
+ // Count the number of occurrences of each item.
+ DataStream<Tuple2<String, Integer>> itemCounts = countItems(items);
+ // Drop items whose support is smaller than the required minimum.
+ final String minSupportCountDouble = "MIN_SUPPORT_COUNT_DOUBLE";
+ itemCounts =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(itemCounts),
+ Collections.singletonMap(minSupportCountDouble, minSupportStream),
+ inputList -> {
+ DataStream<Tuple2<String, Integer>> input = inputList.get(0);
+ return input.filter(
+ new RichFilterFunction<Tuple2<String, Integer>>() {
+ Double minSupportCount = null;
+
+ @Override
+ public boolean filter(Tuple2<String, Integer> o)
+ throws Exception {
+ if (null == minSupportCount) {
+ minSupportCount =
+ (double)
+ getRuntimeContext()
+ .getBroadcastVariable(
+ minSupportCountDouble)
+ .get(0);
+ }
+ // Keep only items whose count reaches the broadcast minimal support count.
+ return o.f1 >= minSupportCount;
+ }
+ });
+ });
+
+ // Assign an index to each item, ordered by support from high to low.
+ DataStream<Tuple3<String, Integer, Integer>> itemCountIndex =
+ assignItemIndex(tenv, itemCounts);
+ // Assign a partition id to each item.
+ DataStream<Tuple2<Integer, Integer>> itemPid = partitionItems(itemCountIndex);
+
+ DataStream<Tuple2<String, Integer>> itemIndex =
+ itemCountIndex
+ .map(t3 -> Tuple2.of(t3.f0, t3.f2))
+ .returns(Types.TUPLE(Types.STRING, Types.INT));
+ DataStream<Tuple2<Integer, Integer>> itemCount =
+ itemCountIndex
+ .map(t3 -> Tuple2.of(t3.f2, t3.f1))
+ .returns(Types.TUPLE(Types.INT, Types.INT));
+
+ DataStream<int[]> transactions = tokensToIndices(itemTokens, itemIndex);
+ DataStream<Tuple2<Integer, int[]>> transactionGroups =
+ genCondTransactions(transactions, itemPid);
+
+ // Extract all frequent patterns.
+ DataStream<Tuple2<Integer, int[]>> indexPatterns =
+ mineFreqPattern(transactionGroups, itemPid, maxPatternLen, minSupportStream);
+ DataStream<Row> tokenPatterns = patternIndicesToTokens(indexPatterns, itemIndex);
+
+ // Extract consequent rules from frequent patterns.
+ DataStream<Tuple4<int[], int[], Integer, double[]>> rules =
+ AssociationRule.extractSingleConsequentRules(
+ indexPatterns,
+ itemCount,
+ transactionCount,
+ minConfidence,
+ minLift,
+ maxPatternLen);
+ DataStream<Row> rulesToken = ruleIndexToToken(rules, itemIndex);
+
+ Table patternTable = tenv.fromDataStream(tokenPatterns);
+ Table rulesTable = tenv.fromDataStream(rulesToken);
+
+ return new Table[] {patternTable, rulesTable};
+ }
+
+ /**
+ * Generate the item partitioning. To achieve load balance, each item is assigned a score that
+ * estimates the number of nodes it contributes to the FP-tree, and items are then greedily
+ * assigned to partitions so that the sum of scores in each partition is balanced.
+ *
+ * @param itemCountIndex A DataStream of tuples of item token, count and id.
+ * @return A DataStream of tuples of item id and partition id.
+ */
+ private static DataStream<Tuple2<Integer, Integer>> partitionItems(
+ DataStream<Tuple3<String, Integer, Integer>> itemCountIndex) {
+ DataStream<Tuple2<Integer, Integer>> partition =
+ itemCountIndex.transform(
+ "ComputingPartitionCost",
+ Types.TUPLE(Types.INT, Types.INT),
+ new ComputingPartitionCost());
+ return partition;
+ }
+
+ private static class ComputingPartitionCost
+ extends AbstractStreamOperator<Tuple2<Integer, Integer>>
+ implements OneInputStreamOperator<
+ Tuple3<String, Integer, Integer>, Tuple2<Integer, Integer>>,
+ BoundedOneInput {
+ List<Tuple2<Integer, Integer>> itemCounts;
+
+ private ComputingPartitionCost() {
+ itemCounts = new ArrayList<>();
+ }
+
+ @Override
+ public void endInput() throws Exception {
+ int numPartitions = getRuntimeContext().getNumberOfParallelSubtasks();
+
+ PriorityQueue<Tuple2<Integer, Double>> queue =
+ new PriorityQueue<>(numPartitions, Comparator.comparingDouble(o -> o.f1));
+
+ for (int i = 0; i < numPartitions; i++) {
+ queue.add(Tuple2.of(i, 0.0));
+ }
+
+ List<Double> scaledItemCount = new ArrayList<>(itemCounts.size());
+ for (int i = 0; i < itemCounts.size(); i++) {
+ Tuple2<Integer, Integer> item = itemCounts.get(i);
+ double pos = (double) (item.f0 + 1) / ((double) itemCounts.size());
+ double score = pos * item.f1.doubleValue();
+ scaledItemCount.add(score);
+ }
+
+ List<Integer> order = new ArrayList<>(itemCounts.size());
+ for (int i = 0; i < itemCounts.size(); i++) {
+ order.add(i);
+ }
+
+ order.sort(
+ (o1, o2) -> {
+ double s1 = scaledItemCount.get(o1);
+ double s2 = scaledItemCount.get(o2);
+ return Double.compare(s2, s1);
+ });
+
+ // greedily assign partition number to each item
+ for (int i = 0; i < itemCounts.size(); i++) {
+ Tuple2<Integer, Integer> item = itemCounts.get(order.get(i));
+ double score = scaledItemCount.get(order.get(i));
+ Tuple2<Integer, Double> target = queue.poll();
+ int targetPartition = target.f0;
+ target.f1 += score;
+ queue.add(target);
+ output.collect(new StreamRecord<>(Tuple2.of(item.f0, targetPartition)));
+ }
+ }
+
+ @Override
+ public void processElement(StreamRecord<Tuple3<String, Integer, Integer>> streamRecord)
+ throws Exception {
+ Tuple3<String, Integer, Integer> t3 = streamRecord.getValue();
+ itemCounts.add(Tuple2.of(t3.f2, t3.f1));
+ }
+ }
+
+ /**
+ * Generate conditional transactions for each partition.
+ *
+ * <p>Scanning from the longest prefix of a transaction, each prefix is assigned to the
+ * partition that its last item belongs to. If that partition has already received a longer
+ * prefix of the transaction, the shorter prefix is skipped.
+ *
+ * @param transactions A DataStream of transactions.
+ * @param targetPartition A DataStream of tuples of item id and partition id.
+ * @return A DataStream of tuples of partition id and conditional transaction.
+ */
+ private static DataStream<Tuple2<Integer, int[]>> genCondTransactions(
+ DataStream<int[]> transactions, DataStream<Tuple2<Integer, Integer>> targetPartition) {
+ final String itemPartition = "ITEM_PARTITION";
+ Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+ broadcastMap.put(itemPartition, targetPartition);
+ return BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(transactions),
+ broadcastMap,
+ inputLists -> {
+ DataStream<int[]> transactionStream = inputLists.get(0);
+ return transactionStream.flatMap(
+ new RichFlatMapFunction<int[], Tuple2<Integer, int[]>>() {
+ transient Map<Integer, Integer> partitioner;
+ // flags marking partitions that already received a prefix of the current transaction
+ transient int[] flags;
+
+ @Override
+ public void flatMap(
+ int[] transaction, Collector<Tuple2<Integer, int[]>> out)
+ throws Exception {
+ if (null == flags) {
+ int numPartition =
+ getRuntimeContext().getNumberOfParallelSubtasks();
+ this.flags = new int[numPartition];
+ }
+ if (null == partitioner) {
+ List<Tuple2<Integer, Integer>> bc =
+ getRuntimeContext()
+ .getBroadcastVariable(itemPartition);
+ partitioner = new HashMap<>();
+ for (Tuple2<Integer, Integer> t2 : bc) {
+ partitioner.put(t2.f0, t2.f1);
+ }
+ }
+ Arrays.fill(flags, 0);
+ int cnt = transaction.length;
+ // start from the longest prefix
+ for (; cnt > 0; cnt--) {
+ int lastPos = cnt - 1;
+ int partition = this.partitioner.get(transaction[lastPos]);
+ if (flags[partition] == 0) {
+ List<Integer> condTransaction = new ArrayList<>(cnt);
+ for (int j = 0; j < cnt; j++) {
+ condTransaction.add(transaction[j]);
+ }
+ int[] tr = new int[condTransaction.size()];
+ for (int j = 0; j < tr.length; j++) {
+ tr[j] = condTransaction.get(j);
+ }
+ out.collect(Tuple2.of(partition, tr));
+ flags[partition] = 1;
+ }
+ }
+ }
+ });
+ });
+ }
+
+ /**
+ * Mine frequent patterns locally in each partition.
+ *
+ * @param condTransactions The conditional transactions with partition id.
+ * @param partitioner A DataStream of tuples of item id and partition id.
+ * @param maxPatternLength Maximum pattern length.
+ * @param minSupport A DataStream of one record, recording the minimal support count.
+ * @return A DataStream of tuples of count and frequent patterns.
+ */
+ private static DataStream<Tuple2<Integer, int[]>> mineFreqPattern(
+ DataStream<Tuple2<Integer, int[]>> condTransactions,
+ DataStream<Tuple2<Integer, Integer>> partitioner,
+ int maxPatternLength,
+ DataStream<Double> minSupport) {
+ condTransactions =
+ condTransactions.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f0);
+ partitioner = partitioner.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f1);
+
+ List<DataStream<?>> inputList = new ArrayList<>();
+ inputList.add(condTransactions);
+ inputList.add(partitioner);
+ return BroadcastUtils.withBroadcastStream(
+ inputList,
+ Collections.singletonMap("MIN_COUNT", minSupport),
+ inputLists -> {
+ DataStream<Tuple2<Integer, int[]>> condStream = inputLists.get(0);
+ DataStream<Tuple2<Integer, Integer>> partitionStream = inputLists.get(1);
+ return condStream
+ .connect(partitionStream)
+ .transform(
+ "mine-freq-pattern",
+ Types.TUPLE(
+ Types.INT,
+ PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO),
+ new FpTreeConstructor(maxPatternLength));
+ });
+ }
+
+ /* A no-op rich function, used only so that the FP-tree operator can access broadcast
+ variables through the UDF's runtime context. */
+ private static class MyRichFunction extends AbstractRichFunction {}
+
+ /* The operator that constructs an FP-tree on each worker and extracts frequent patterns. */
+ private static class FpTreeConstructor
+ extends AbstractUdfStreamOperator<Tuple2<Integer, int[]>, RichFunction>
+ implements TwoInputStreamOperator<
+ Tuple2<Integer, int[]>,
+ Tuple2<Integer, Integer>,
+ Tuple2<Integer, int[]>>,
+ BoundedMultiInput {
+
+ private final FPTree tree = new FPTree();
+ private boolean input1Ends;
+ private boolean input2Ends;
+ private int minSupportCnt;
+ private final int maxPatternLen;
+ private Set<Integer> itemList = new HashSet<>();
+
+ FpTreeConstructor(int maxPatternLength) {
+ super(new MyRichFunction());
+ maxPatternLen = maxPatternLength;
+ }
+
+ @Override
+ public void open() throws Exception {
+ super.open();
+ tree.createTree();
+ }
+
+ @Override
+ public void endInput(int i) throws Exception {
+ if (1 == i) {
+ LOG.info("Finished adding transactions.");
+ input1Ends = true;
+ } else {
+ LOG.info("Finished adding items.");
+ input2Ends = true;
+ }
+
+ if (input1Ends && input2Ends) {
+ LOG.info("Start to extract fptrees.");
+ endInputs();
+ }
+ }
+
+ public void endInputs() {
+ tree.initialize();
+ tree.printProfile();
+ minSupportCnt =
+ (int)
+ Math.ceil(
+ ((Double)
+ userFunction
+ .getRuntimeContext()
+ .getBroadcastVariable("MIN_COUNT")
+ .get(0)));
+ int[] suffixes = new int[itemList.size()];
+ int i = 0;
+ for (Integer item : itemList) {
+ suffixes[i++] = item;
+ }
+ tree.extractAll(suffixes, minSupportCnt, maxPatternLen, output);
+ tree.destroyTree();
+ LOG.info("itemList size {}.", itemList.size());
+ LOG.info("Finished extracting fptrees.");
+ }
+
+ @Override
+ public void processElement1(StreamRecord<Tuple2<Integer, int[]>> streamRecord)
+ throws Exception {
+ tree.addTransaction(streamRecord.getValue().f1);
+ }
+
+ @Override
+ public void processElement2(StreamRecord<Tuple2<Integer, Integer>> streamRecord)
+ throws Exception {
+ itemList.add(streamRecord.getValue().f0);
+ }
+ }
+
+ /**
+ * Map item indices back to string tokens in frequent patterns.
+ *
+ * @param patterns A DataStream of tuples of frequent patterns (represented as int array) and
+ * support count.
+ * @param itemIndex A DataStream of tuples of item token and id.
+ * @return A DataStream of frequent patterns, support count and length of the pattern.
+ */
+ private static DataStream<Row> patternIndicesToTokens(
+ DataStream<Tuple2<Integer, int[]>> patterns,
+ DataStream<Tuple2<String, Integer>> itemIndex) {
+
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ new TypeInformation[] {Types.STRING, Types.LONG, Types.LONG},
+ FREQ_PATTERN_OUTPUT_COLS);
+
+ Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+ broadcastMap.put(ITEM_INDEX, itemIndex);
+ return BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(patterns),
+ broadcastMap,
+ inputList -> {
+ DataStream<Tuple2<Integer, int[]>> freqPatterns = inputList.get(0);
+ return freqPatterns
+ .map(
+ new RichMapFunction<Tuple2<Integer, int[]>, Row>() {
+ Map<Integer, String> tokenToId;
+
+ @Override
+ public Row map(Tuple2<Integer, int[]> pattern)
+ throws Exception {
+ if (null == tokenToId) {
+ tokenToId = new HashMap<>();
+ List<Tuple2<String, Integer>> itemIndexList =
+ getRuntimeContext()
+ .getBroadcastVariable("ITEM_INDEX");
+ for (Tuple2<String, Integer> t2 : itemIndexList) {
+ tokenToId.put(t2.f1, t2.f0);
+ }
+ }
+ int len = pattern.f1.length;
+ if (len == 0) {
+ return null;
+ }
+ StringBuilder sbd = new StringBuilder();
+ sbd.append(tokenToId.get(pattern.f1[0]));
+ for (int i = 1; i < len; i++) {
+ sbd.append(",")
+ .append(tokenToId.get(pattern.f1[i]));
+ }
+ return Row.of(
+ sbd.toString(), (long) pattern.f0, (long) len);
+ }
+ })
+ .name("flatMap_id_to_token")
+ .returns(outputTypeInfo);
+ });
+ }
+
+ /**
+ * Count the total number of rows of the input DataStream.
+ *
+ * @param itemTokens A DataStream of input transactions.
+ * @return A DataStream of one record, recording the number of input rows.
+ */
+ private static DataStream<Tuple2<Integer, Integer>> countRecords(
+ DataStream<String[]> itemTokens) {
+ return DataStreamUtils.aggregate(
+ itemTokens,
+ new AggregateFunction<String[], Integer, Integer>() {
+
+ @Override
+ public Integer createAccumulator() {
+ return 0;
+ }
+
+ @Override
+ public Integer add(String[] strings, Integer count) {
+ if (strings.length > 0) {
+ return count + 1;
+ }
+ return count;
+ }
+
+ @Override
+ public Integer getResult(Integer count) {
+ return count;
+ }
+
+ @Override
+ public Integer merge(Integer count, Integer acc1) {
+ return count + acc1;
+ }
+ })
+ .map(
+ new MapFunction<Integer, Tuple2<Integer, Integer>>() {
+ @Override
+ public Tuple2<Integer, Integer> map(Integer count) throws Exception {
+ return Tuple2.of(-1, count);
+ }
+ })
+ .name("count_transaction")
+ .returns(Types.TUPLE(Types.INT, Types.INT));
+ }
+
+ /**
+ * Count the number of occurrences of each item.
+ *
+ * @param items A DataStream of items.
+ * @return A DataStream of tuples of item string and count.
+ */
+ private static DataStream<Tuple2<String, Integer>> countItems(DataStream<String> items) {
+ return DataStreamUtils.keyedAggregate(
+ items.keyBy(s -> s),
+ new AggregateFunction<String, Tuple2<String, Integer>, Tuple2<String, Integer>>() {
+
+ @Override
+ public Tuple2<String, Integer> createAccumulator() {
+ return Tuple2.of(null, 0);
+ }
+
+ @Override
+ public Tuple2<String, Integer> add(String item, Tuple2<String, Integer> acc) {
+ if (null == acc.f0) {
+ acc.f0 = item;
+ }
+ acc.f1++;
+ return acc;
+ }
+
+ @Override
+ public Tuple2<String, Integer> getResult(Tuple2<String, Integer> t2) {
+ return t2;
+ }
+
+ @Override
+ public Tuple2<String, Integer> merge(
+ Tuple2<String, Integer> acc1, Tuple2<String, Integer> acc2) {
+ acc2.f1 += acc1.f1;
+ return acc2;
+ }
+ },
+ Types.TUPLE(Types.STRING, Types.INT),
+ Types.TUPLE(Types.STRING, Types.INT));
+ }
+
+ /**
+ * Calculate the minimal support count of frequent patterns. If minSupportCount is positive it
+ * is used directly; otherwise the result is minSupportRate * transactionCount.
+ *
+ * @param tenv The StreamTableEnvironment of the execution environment.
+ * @param transactionCount A DataStream of one record, recording the number of transactions.
+ * @param minSupportRate The minimal support as a fraction of the number of transactions.
+ * @param minSupportCount The minimal support count.
+ * @return A DataStream of one record, recording the minimal support count.
+ */
+ private static DataStream<Double> calculateMinSupport(
+ StreamTableEnvironment tenv,
+ DataStream<Tuple2<Integer, Integer>> transactionCount,
+ final Double minSupportRate,
+ final int minSupportCount) {
+ final String supportCount = "MIN_SUPPORT_COUNT";
+ final String supportRate = "MIN_SUPPORT_RATE";
+ Map<String, DataStream<?>> bc = new HashMap<>(2);
+ DataStream<Row> minSupportCountStream = tenv.toDataStream(tenv.fromValues(minSupportCount));
+ bc.put(supportCount, minSupportCountStream);
+ DataStream<Row> minSupportRateStream = tenv.toDataStream(tenv.fromValues(minSupportRate));
+ bc.put(supportRate, minSupportRateStream);
+ return BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(transactionCount),
+ bc,
+ inputLists -> {
+ DataStream<Tuple2<Integer, Integer>> transactionCountStream = inputLists.get(0);
+ return transactionCountStream.map(
+ new RichMapFunction<Tuple2<Integer, Integer>, Double>() {
+ @Override
+ public Double map(Tuple2<Integer, Integer> tuple2)
+ throws Exception {
+ Double bcSupport =
+ ((Row)
+ getRuntimeContext()
+ .getBroadcastVariable(
+ supportRate)
+ .get(0))
+ .getFieldAs(0);
+ int bcCount =
+ ((Row)
+ getRuntimeContext()
+ .getBroadcastVariable(
+ supportCount)
+ .get(0))
+ .getFieldAs(0);
+ if (bcCount > 0) {
+ return (double) bcCount;
+ }
+ return tuple2.f1 * bcSupport;
+ }
+ });
+ });
+ }
+
+ /**
+ * Map each item token to an index based on the descending order of item counts.
+ *
+ * @param tenv the StreamTableEnvironment of execution environment.
+ * @param itemCounts A DataStream of tuples of item and count.
+ * @return A DataStream of tuples of item and count and index.
+ */
+ private static DataStream<Tuple3<String, Integer, Integer>> assignItemIndex(
+ StreamTableEnvironment tenv, DataStream<Tuple2<String, Integer>> itemCounts) {
+ final String itemSupports = "ITEM_SUPPORTS";
+
+ return BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(
+ tenv.toDataStream(tenv.fromValues(Collections.singletonMap(-1, -1)))),
+ Collections.singletonMap(itemSupports, itemCounts),
+ inputList -> {
+ DataStream<Row> input = inputList.get(0);
+ return input.flatMap(
+ new RichFlatMapFunction<
+ Row, Tuple3<String, Integer, Integer>>() {
+ List<Tuple2<String, Integer>> supportCount;
+
+ @Override
+ public void flatMap(
+ Row o,
+ Collector<Tuple3<String, Integer, Integer>>
+ collector)
+ throws Exception {
+ if (null == supportCount) {
+ supportCount =
+ getRuntimeContext()
+ .getBroadcastVariable(itemSupports);
+ }
+ Integer[] order = new Integer[supportCount.size()];
+ for (int i = 0; i < order.length; i++) {
+ order[i] = i;
+ }
+ Arrays.sort(
+ order,
+ new Comparator<Integer>() {
+ @Override
+ public int compare(Integer o1, Integer o2) {
+ Integer cnt1 = supportCount.get(o1).f1;
+ Integer cnt2 = supportCount.get(o2).f1;
+ if (cnt1.equals(cnt2)) {
+ return supportCount
+ .get(o1)
+ .f0
+ .compareTo(
+ supportCount.get(o2)
+ .f0);
+ }
+ return Integer.compare(cnt2, cnt1);
+ }
+ });
+ for (int i = 0; i < order.length; i++) {
+ collector.collect(
+ Tuple3.of(
+ supportCount.get(order[i]).f0,
+ supportCount.get(order[i]).f1,
+ i));
+ }
+ }
+ })
+ .name("item_indexer");
+ });
+ }
+
+ /**
+ * Map string tokens to indices in transactions.
+ *
+ * @param itemSets A DataStream of transactions.
+ * @param itemIndex A DataStream of tuples of item token and id.
+ * @return A DataStream of transactions represented as int arrays.
+ */
+ private static DataStream<int[]> tokensToIndices(
+ DataStream<String[]> itemSets, DataStream<Tuple2<String, Integer>> itemIndex) {
+ Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+ broadcastMap.put(ITEM_INDEX, itemIndex);
+ return BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(itemSets),
+ broadcastMap,
+ inputList -> {
+ DataStream<String[]> transactions = inputList.get(0);
+ return transactions
+ .map(
+ new RichMapFunction<String[], int[]>() {
+ Map<String, Integer> tokenToId;
+
+ @Override
+ public int[] map(String[] transaction) throws Exception {
+ if (null == tokenToId) {
+ LOG.info("Trying to get ITEM_INDEX.");
+ tokenToId = new HashMap<>();
+ List<Tuple2<String, Integer>> itemIndexList =
+ getRuntimeContext()
+ .getBroadcastVariable("ITEM_INDEX");
+ for (Tuple2<String, Integer> t2 : itemIndexList) {
+ tokenToId.put(t2.f0, t2.f1);
+ }
+ LOG.info(
+ "Size of tokenToId is {}.",
+ tokenToId.size());
+ }
+ int[] items = new int[transaction.length];
+ int len = 0;
+ for (String item : transaction) {
+ Integer id = tokenToId.get(item);
+ if (id != null) {
+ items[len++] = id;
+ }
+ }
+ if (len > 0) {
+ int[] qualified = Arrays.copyOfRange(items, 0, len);
+ Arrays.sort(qualified);
+ return qualified;
+ } else {
+ return new int[0];
+ }
+ }
+ })
+ .name("map_token_to_index")
+ .returns(PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO);
+ });
+ }
+
+ /**
+ * Map item indices back to string tokens in association rules.
+ *
+ * @param rules A DataStream of tuples of antecedent, consequent, support count, [lift, support,
+ * confidence].
+ * @param itemIndex A DataStream of tuples of item token and id.
+ * @return A DataStream of rows containing the rule, the length of the rule, lift, support,
+ * confidence, and support count.
+ */
+ private static DataStream<Row> ruleIndexToToken(
+ DataStream<Tuple4<int[], int[], Integer, double[]>> rules,
+ DataStream<Tuple2<String, Integer>> itemIndex) {
+
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ new TypeInformation[] {
+ Types.STRING,
+ Types.LONG,
+ Types.DOUBLE,
+ Types.DOUBLE,
+ Types.DOUBLE,
+ Types.LONG
+ },
+ RULES_OUTPUT_COLS);
+
+ Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+ broadcastMap.put(ITEM_INDEX, itemIndex);
+ return BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(rules),
+ broadcastMap,
+ inputList -> {
+ DataStream<Tuple4<int[], int[], Integer, double[]>> freqPatterns = inputList.get(0);
+ return freqPatterns
+ .map(
+ new RichMapFunction<
+ Tuple4<int[], int[], Integer, double[]>, Row>() {
+ Map<Integer, String> tokenToId;
+
+ @Override
+ public Row map(Tuple4<int[], int[], Integer, double[]> rule)
+ throws Exception {
+ if (null == tokenToId) {
+ tokenToId = new HashMap<>();
+ List<Tuple2<String, Integer>> itemIndexList =
+ getRuntimeContext()
+ .getBroadcastVariable("ITEM_INDEX");
+ for (Tuple2<String, Integer> t2 : itemIndexList) {
+ tokenToId.put(t2.f1, t2.f0);
+ }
+ }
+ StringBuilder sbd = new StringBuilder();
+ int[] ascent = rule.f0;
+ int[] consq = rule.f1;
+ sbd.append(tokenToId.get(ascent[0]));
+ for (int i = 1; i < ascent.length; i++) {
+ sbd.append(",").append(tokenToId.get(ascent[i]));
+ }
+ sbd.append("=>");
+ sbd.append(tokenToId.get(consq[0]));
+ for (int i = 1; i < consq.length; i++) {
+ sbd.append(",").append(tokenToId.get(consq[i]));
+ }
+ return Row.of(
+ sbd.toString(),
+ (long) (ascent.length + consq.length),
+ rule.f3[0],
+ rule.f3[1],
+ rule.f3[2],
+ (long) rule.f2);
+ }
+ })
+ .returns(outputTypeInfo);
+ });
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static FPGrowth load(StreamTableEnvironment tEnv, String path) throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/fpgrowth/FPGrowthParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/fpgrowth/FPGrowthParams.java
new file mode 100644
index 000000000..801a4cbae
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/fpgrowth/FPGrowthParams.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation.fpgrowth;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Params of {@link FPGrowth}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FPGrowthParams<T> extends WithParams<T> {
+
+ Param<String> ITEMS_COL =
+ new StringParam(
+ "itemsCol", "Item sequence column name.", "items", ParamValidators.notNull());
+ Param<String> FIELD_DELIMITER =
+ new StringParam(
+ "fieldDelimiter",
+ "Field delimiter of item sequence, default delimiter is ','.",
+ ",",
+ ParamValidators.notNull());
+ Param<Double> MIN_LIFT =
+ new DoubleParam(
+ "minLift",
+ "Minimal lift level for association rules.",
+ 1.0,
+ ParamValidators.gtEq(0));
+ Param<Double> MIN_CONFIDENCE =
+ new DoubleParam(
+ "minConfidence",
+ "Minimal confidence level for association rules.",
+ 0.6,
+ ParamValidators.gtEq(0));
+ Param<Double> MIN_SUPPORT =
+ new DoubleParam(
+ "minSupport",
+ "Minimal support percent level. The default value of MIN_SUPPORT is 0.02.",
+ 0.02);
+ Param<Integer> MIN_SUPPORT_COUNT =
+ new IntParam(
+ "minSupportCount",
+ "Minimal support count. MIN_SUPPORT_COUNT has no effect when it is less than or equal to 0. The default value is -1.",
+ -1);
+
+ Param<Integer> MAX_PATTERN_LENGTH =
+ new IntParam(
+ "maxPatternLength", "Max frequent pattern length.", 10, ParamValidators.gt(0));
+
+ default String getItemsCol() {
+ return get(ITEMS_COL);
+ }
+
+ default T setItemsCol(String value) {
+ return set(ITEMS_COL, value);
+ }
+
+ default String getFieldDelimiter() {
+ return get(FIELD_DELIMITER);
+ }
+
+ default T setFieldDelimiter(String value) {
+ return set(FIELD_DELIMITER, value);
+ }
+
+ default double getMinLift() {
+ return get(MIN_LIFT);
+ }
+
+ default T setMinLift(Double value) {
+ return set(MIN_LIFT, value);
+ }
+
+ default Double getMinSupport() {
+ return get(MIN_SUPPORT);
+ }
+
+ default T setMinSupport(double value) {
+ return set(MIN_SUPPORT, value);
+ }
+
+ default double getMinConfidence() {
+ return get(MIN_CONFIDENCE);
+ }
+
+ default T setMinConfidence(Double value) {
+ return set(MIN_CONFIDENCE, value);
+ }
+
+ default int getMinSupportCount() {
+ return get(MIN_SUPPORT_COUNT);
+ }
+
+ default T setMinSupportCount(Integer value) {
+ return set(MIN_SUPPORT_COUNT, value);
+ }
+
+ default int getMaxPatternLength() {
+ return get(MAX_PATTERN_LENGTH);
+ }
+
+ default T setMaxPatternLength(Integer value) {
+ return set(MAX_PATTERN_LENGTH, value);
+ }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/FPGrowthTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/FPGrowthTest.java
new file mode 100644
index 000000000..610926c40
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/FPGrowthTest.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation;
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.recommendation.fpgrowth.FPGrowth;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link FPGrowth}. */
+public class FPGrowthTest {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private static final int defaultParallelism = 4;
+ private static StreamExecutionEnvironment env;
+ private static StreamTableEnvironment tEnv;
+ private Table inputTable;
+
+ List<Row> expectedPatterns =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of("B", 5L, 1L),
+ Row.of("C", 4L, 1L),
+ Row.of("B,C", 4L, 2L),
+ Row.of("A", 3L, 1L),
+ Row.of("B,A", 3L, 2L),
+ Row.of("C,A", 3L, 2L),
+ Row.of("B,C,A", 3L, 3L),
+ Row.of("E", 3L, 1L),
+ Row.of("B,E", 3L, 2L),
+ Row.of("D", 3L, 1L),
+ Row.of("B,D", 3L, 2L)));
+ List<Row> expectedRules =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of("B=>E", 2L, 1.0, 0.6, 0.6, 3L),
+ Row.of("B=>C", 2L, 1.0, 0.8, 0.8, 4L),
+ Row.of("A=>B", 2L, 1.0, 0.6, 1.0, 3L),
+ Row.of("B=>A", 2L, 1.0, 0.6, 0.6, 3L),
+ Row.of("A=>C", 2L, 1.25, 0.6, 1.0, 3L),
+ Row.of("D=>B", 2L, 1.0, 0.6, 1.0, 3L),
+ Row.of("B=>D", 2L, 1.0, 0.6, 0.6, 3L),
+ Row.of("C,A=>B", 3L, 1.0, 0.6, 1.0, 3L),
+ Row.of("B,A=>C", 3L, 1.25, 0.6, 1.0, 3L),
+ Row.of("E=>B", 2L, 1.0, 0.6, 1.0, 3L),
+ Row.of("C=>B", 2L, 1.0, 0.8, 1.0, 4L),
+ Row.of("C=>A", 2L, 1.25, 0.6, 0.75, 3L),
+ Row.of("B,C=>A", 3L, 1.25, 0.6, 0.75, 3L)));
+
+ public void checkResult(List<Row> expected, CloseableIterator<Row> result) {
+ List<Row> actual = new ArrayList<>();
+ while (result.hasNext()) {
+ Row row = result.next();
+ actual.add(row);
+ }
+
+ expected.sort(
+ (o1, o2) -> {
+ String s1 = o1.getFieldAs(0);
+ String s2 = o2.getFieldAs(0);
+ return s1.compareTo(s2);
+ });
+
+ actual.sort(
+ (o1, o2) -> {
+ String s1 = o1.getFieldAs(0);
+ String s2 = o2.getFieldAs(0);
+ return s1.compareTo(s2);
+ });
+
+ Assert.assertArrayEquals(expected.toArray(), actual.toArray());
+ }
+
+ @Before
+ public void before() {
+ env = TestUtils.getExecutionEnvironment();
+ env.getConfig().setParallelism(defaultParallelism);
+ tEnv = StreamTableEnvironment.create(env);
+ List<Row> inputRows =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(""),
+ Row.of("A,B,C,D"),
+ Row.of("B,C,E"),
+ Row.of("A,B,C,E"),
+ Row.of("B,D,E"),
+ Row.of("A,B,C,D,A")));
+
+ inputTable =
+ tEnv.fromDataStream(
+ env.fromCollection(
+ inputRows,
+ new RowTypeInfo(
+ new TypeInformation[] {BasicTypeInfo.STRING_TYPE_INFO},
+ new String[] {"transactions"})));
+ }
+
+ @Test
+ public void testParam() {
+ FPGrowth fpGrowth = new FPGrowth();
+ assertEquals("items", fpGrowth.getItemsCol());
+ assertEquals(",", fpGrowth.getFieldDelimiter());
+ assertEquals(10, fpGrowth.getMaxPatternLength());
+ assertEquals(0.6, fpGrowth.getMinConfidence(), 1e-9);
+ assertEquals(0.02, fpGrowth.getMinSupport(), 1e-9);
+ assertEquals(-1, fpGrowth.getMinSupportCount());
+ assertEquals(1.0, fpGrowth.getMinLift(), 1e-9);
+
+ fpGrowth.setItemsCol("transactions")
+ .setFieldDelimiter(" ")
+ .setMaxPatternLength(100)
+ .setMinLift(0.5)
+ .setMinConfidence(0.3)
+ .setMinSupport(0.3)
+ .setMinSupportCount(10);
+
+ assertEquals("transactions", fpGrowth.getItemsCol());
+ assertEquals(" ", fpGrowth.getFieldDelimiter());
+ assertEquals(100, fpGrowth.getMaxPatternLength());
+ assertEquals(0.3, fpGrowth.getMinConfidence(), 1e-9);
+ assertEquals(0.3, fpGrowth.getMinSupport(), 1e-9);
+ assertEquals(10, fpGrowth.getMinSupportCount());
+ assertEquals(0.5, fpGrowth.getMinLift(), 1e-9);
+ }
+
+ @Test
+ public void testTransform() {
+ FPGrowth fpGrowth = new FPGrowth().setItemsCol("transactions").setMinSupport(0.6);
+ Table[] results = fpGrowth.transform(inputTable);
+ CloseableIterator<Row> patterns = results[0].execute().collect();
+ checkResult(expectedPatterns, patterns);
+ CloseableIterator<Row> rules = results[1].execute().collect();
+ checkResult(expectedRules, rules);
+ }
+
+ @Test
+ public void testOutputSchema() {
+ FPGrowth fpGrowth = new FPGrowth().setItemsCol("transactions").setMinSupportCount(3);
+ Table[] results = fpGrowth.transform(inputTable);
+ assertEquals(
+ Arrays.asList("items", "support_count", "item_count"),
+ results[0].getResolvedSchema().getColumnNames());
+ assertEquals(
+ Arrays.asList(
+ "rule",
+ "item_count",
+ "lift",
+ "support_percent",
+ "confidence_percent",
+ "transaction_count"),
+ results[1].getResolvedSchema().getColumnNames());
+ }
+
+ @Test
+ public void testSaveLoadAndTransform() throws Exception {
+ FPGrowth fpGrowth = new FPGrowth().setItemsCol("transactions").setMinSupportCount(3);
+ FPGrowth loadedFPGrowth =
+ TestUtils.saveAndReload(
+ tEnv, fpGrowth, tempFolder.newFolder().getAbsolutePath(), FPGrowth::load);
+ Table[] results = loadedFPGrowth.transform(inputTable);
+ CloseableIterator<Row> patterns = results[0].execute().collect();
+ checkResult(expectedPatterns, patterns);
+ CloseableIterator<Row> rules = results[1].execute().collect();
+ checkResult(expectedRules, rules);
+ }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/recommendation/fpgrowth_example.py b/flink-ml-python/pyflink/examples/ml/recommendation/fpgrowth_example.py
new file mode 100644
index 000000000..5808ad4a5
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/recommendation/fpgrowth_example.py
@@ -0,0 +1,66 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates an FPGrowth instance and mines frequent itemsets
+# and association rules.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.table import StreamTableEnvironment
+
+from pyflink.ml.recommendation.fpgrowth import FPGrowth
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input data.
+input_table = t_env.from_data_stream(
+ env.from_collection([
+ ("A,B,C,D",),
+ ("B,C,E",),
+ ("A,B,C,E",),
+ ("B,D,E",),
+ ("A,B,C,D",)
+ ],
+ type_info=Types.ROW_NAMED(
+ ['items'],
+ [Types.STRING()])
+ ))
+
+# Creates an FPGrowth object and initializes its parameters.
+fpg = FPGrowth().set_min_support(0.6)
+
+# Transforms the data using the FPGrowth algorithm.
+output_table = fpg.transform(input_table)
+
+# Extracts and displays the results.
+pattern_result_names = output_table[0].get_schema().get_field_names()
+rule_result_names = output_table[1].get_schema().get_field_names()
+
+patterns = t_env.to_data_stream(output_table[0]).execute_and_collect()
+rules = t_env.to_data_stream(output_table[1]).execute_and_collect()
+
+print("|\t"+"\t|\t".join(pattern_result_names)+"\t|")
+for result in patterns:
+ print(f'|\t{result[0]}\t|\t{result[1]}\t|\t{result[2]}\t|')
+print("|\t"+" | ".join(rule_result_names)+"\t|")
+for result in rules:
+ print(f'|\t{result[0]}\t|\t{result[1]}\t|\t{result[2]}\t|\t{result[3]}'
+ + f'\t|\t{result[4]}\t|\t{result[5]}\t|')
diff --git a/flink-ml-python/pyflink/ml/recommendation/fpgrowth.py b/flink-ml-python/pyflink/ml/recommendation/fpgrowth.py
new file mode 100644
index 000000000..7cae08ed5
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/recommendation/fpgrowth.py
@@ -0,0 +1,166 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+
+from pyflink.ml.param import Param, StringParam, IntParam, FloatParam, ParamValidators
+from pyflink.ml.recommendation.common import JavaRecommendationAlgoOperator
+from pyflink.ml.wrapper import JavaWithParams
+
+
+class _FPGrowthParams(
+ JavaWithParams
+):
+ """
+ Params for :class:`FPGrowth`.
+ """
+
+ ITEMS_COL: Param[str] = StringParam(
+ "items_col",
+ "Item sequence column name.",
+ "items",
+ ParamValidators.not_null())
+
+ FIELD_DELIMITER: Param[str] = StringParam(
+ "field_delimiter",
+ "Field delimiter of item sequence, default delimiter is ','.",
+ ",",
+ ParamValidators.not_null())
+
+ MIN_LIFT: Param[float] = FloatParam(
+ "min_lift",
+ "Minimal lift level for association rules.",
+ 1.0,
+ ParamValidators.gt_eq(0))
+
+ MIN_CONFIDENCE: Param[float] = FloatParam(
+ "min_confidence",
+ "Minimal confidence level for association rules.",
+ 0.6,
+ ParamValidators.gt_eq(0))
+
+ MIN_SUPPORT: Param[float] = FloatParam(
+ "min_support",
+ "Minimal support percent. The default value of MIN_SUPPORT is 0.02",
+ 0.02)
+
+ MIN_SUPPORT_COUNT: Param[int] = IntParam(
+ "min_support_count",
+ "Minimal support count. MIN_ITEM_COUNT has no "
+ + "effect when less than or equal to 0, The default value is -1.",
+ -1)
+
+ MAX_PATTERN_LENGTH: Param[int] = IntParam(
+ "max_pattern_length",
+ "Max frequent pattern length.",
+ 10,
+ ParamValidators.gt(0))
+
+ def __init__(self, java_params):
+ super(_FPGrowthParams, self).__init__(java_params)
+
+ def set_items_col(self, value: str):
+ return typing.cast(_FPGrowthParams, self.set(self.ITEMS_COL, value))
+
+ def get_items_col(self) -> str:
+ return self.get(self.ITEMS_COL)
+
+ def set_field_delimiter(self, value: str):
+ return typing.cast(_FPGrowthParams, self.set(self.FIELD_DELIMITER, value))
+
+ def get_field_delimiter(self) -> str:
+ return self.get(self.FIELD_DELIMITER)
+
+ def set_min_lift(self, value: float):
+ return typing.cast(_FPGrowthParams, self.set(self.MIN_LIFT, value))
+
+ def get_min_lift(self) -> float:
+ return self.get(self.MIN_LIFT)
+
+ def set_min_confidence(self, value: float):
+ return typing.cast(_FPGrowthParams, self.set(self.MIN_CONFIDENCE, value))
+
+ def get_min_confidence(self) -> float:
+ return self.get(self.MIN_CONFIDENCE)
+
+ def set_min_support(self, value: float):
+ return typing.cast(_FPGrowthParams, self.set(self.MIN_SUPPORT, value))
+
+ def get_min_support(self) -> float:
+ return self.get(self.MIN_SUPPORT)
+
+ def set_min_support_count(self, value: int):
+ return typing.cast(_FPGrowthParams, self.set(self.MIN_SUPPORT_COUNT, value))
+
+ def get_min_support_count(self) -> int:
+ return self.get(self.MIN_SUPPORT_COUNT)
+
+ def set_max_pattern_length(self, value: int):
+ return typing.cast(_FPGrowthParams, self.set(self.MAX_PATTERN_LENGTH, value))
+
+ def get_max_pattern_length(self) -> int:
+ return self.get(self.MAX_PATTERN_LENGTH)
+
+ @property
+ def items_col(self) -> str:
+ return self.get_items_col()
+
+ @property
+ def field_delimiter(self) -> str:
+ return self.get_field_delimiter()
+
+ @property
+ def min_lift(self) -> float:
+ return self.get_min_lift()
+
+ @property
+ def min_confidence(self) -> float:
+ return self.get_min_confidence()
+
+ @property
+ def min_support(self) -> float:
+ return self.get_min_support()
+
+ @property
+ def min_support_count(self) -> int:
+ return self.get_min_support_count()
+
+ @property
+ def max_pattern_length(self) -> int:
+ return self.get_max_pattern_length()
+
+
+class FPGrowth(JavaRecommendationAlgoOperator, _FPGrowthParams):
+ """
+ An implementation of the parallel FP-growth algorithm to mine frequent itemsets.
+
+ For detailed descriptions, please refer to: Han et al., "Mining frequent patterns without
+ candidate generation", and Li et al., "PFP: Parallel FP-growth for query recommendation".
+ """
+
+ def __init__(self, java_algo_operator=None):
+ super(FPGrowth, self).__init__(java_algo_operator)
+
+ @classmethod
+ def _java_algo_operator_package_name(cls) -> str:
+ return "fpgrowth"
+
+ @classmethod
+ def _java_algo_operator_class_name(cls) -> str:
+ return "FPGrowth"
diff --git a/flink-ml-python/pyflink/ml/recommendation/tests/__init__.py b/flink-ml-python/pyflink/ml/recommendation/tests/__init__.py
index 66981919e..d91d4836e 100644
--- a/flink-ml-python/pyflink/ml/recommendation/tests/__init__.py
+++ b/flink-ml-python/pyflink/ml/recommendation/tests/__init__.py
@@ -22,7 +22,7 @@
# Because the project and the dependent `pyflink` project have the same directory structure,
# we need to manually add `flink-ml-python` path to `sys.path` in the test of this project to change
# the order of package search.
-flink_ml_python_dir = Path(__file__).parents[5]
+flink_ml_python_dir = Path(__file__).parents[4]
sys.path.append(str(flink_ml_python_dir))
import pyflink
diff --git a/flink-ml-python/pyflink/ml/recommendation/tests/test_fpgrowth.py b/flink-ml-python/pyflink/ml/recommendation/tests/test_fpgrowth.py
new file mode 100644
index 000000000..ea682d29d
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/recommendation/tests/test_fpgrowth.py
@@ -0,0 +1,133 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from pyflink.common import Types
+from pyflink.table import Table
+from typing import List
+
+from pyflink.ml.recommendation.fpgrowth import FPGrowth
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+# Tests FPGrowth.
+class FPGrowthTest(PyFlinkMLTestCase):
+ def setUp(self):
+ super(FPGrowthTest, self).setUp()
+ self.input_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+ ("A,B,C,D",),
+ ("B,C,E",),
+ ("A,B,C,E",),
+ ("B,D,E",),
+ ("A,B,C,D",)
+ ],
+ type_info=Types.ROW_NAMED(
+ ['items'],
+ [Types.STRING()])
+ ))
+
+ self.expected_patterns = [
+ ["A", 3, 1],
+ ["B", 5, 1],
+ ["B,A", 3, 2],
+ ["B,C", 4, 2],
+ ["B,C,A", 3, 3],
+ ["B,D", 3, 2],
+ ["B,E", 3, 2],
+ ["C", 4, 1],
+ ["C,A", 3, 2],
+ ["D", 3, 1],
+ ["E", 3, 1]
+ ]
+
+ self.expected_rules = [
+ ["A=>B", 2, 1.0, 0.6, 1.0, 3],
+ ["A=>C", 2, 1.25, 0.6, 1.0, 3],
+ ["B=>A", 2, 1.0, 0.6, 0.6, 3],
+ ["B=>C", 2, 1.0, 0.8, 0.8, 4],
+ ["B=>D", 2, 1.0, 0.6, 0.6, 3],
+ ["B=>E", 2, 1.0, 0.6, 0.6, 3],
+ ["B,A=>C", 3, 1.25, 0.6, 1.0, 3],
+ ["B,C=>A", 3, 1.25, 0.6, 0.75, 3],
+ ["C=>A", 2, 1.25, 0.6, 0.75, 3],
+ ["C=>B", 2, 1.0, 0.8, 1.0, 4],
+ ["C,A=>B", 3, 1.0, 0.6, 1.0, 3],
+ ["D=>B", 2, 1.0, 0.6, 1.0, 3],
+ ["E=>B", 2, 1.0, 0.6, 1.0, 3]
+ ]
+
+ def test_param(self):
+ fpg = FPGrowth()
+ self.assertEqual("items", fpg.items_col)
+ self.assertEqual(",", fpg.field_delimiter)
+ self.assertAlmostEqual(1.0, fpg.min_lift, delta=1e-9)
+ self.assertAlmostEqual(0.6, fpg.min_confidence, delta=1e-9)
+ self.assertAlmostEqual(0.02, fpg.min_support, delta=1e-9)
+ self.assertEqual(-1, fpg.min_support_count)
+ self.assertEqual(10, fpg.max_pattern_length)
+
+ fpg.set_items_col("values") \
+ .set_field_delimiter(" ") \
+ .set_min_lift(1.2) \
+ .set_min_confidence(0.7) \
+ .set_min_support(0.01) \
+ .set_min_support_count(50) \
+ .set_max_pattern_length(5)
+
+ self.assertEqual("values", fpg.items_col)
+ self.assertEqual(" ", fpg.field_delimiter)
+ self.assertAlmostEqual(1.2, fpg.min_lift, delta=1e-9)
+ self.assertAlmostEqual(0.7, fpg.min_confidence, delta=1e-9)
+ self.assertAlmostEqual(0.01, fpg.min_support, delta=1e-9)
+ self.assertEqual(50, fpg.min_support_count)
+ self.assertEqual(5, fpg.max_pattern_length)
+
+ def test_output_schema(self):
+ fpg = FPGrowth()
+ output_tables = fpg.transform(self.input_table)
+ self.assertEqual(
+ ["items", "support_count", "item_count"],
+ output_tables[0].get_schema().get_field_names())
+ self.assertEqual(
+ ["rule", "item_count", "lift", "support_percent",
+ "confidence_percent", "transaction_count"],
+ output_tables[1].get_schema().get_field_names())
+
+ def test_transform(self):
+ fpg = FPGrowth().set_min_support(0.6)
+ output_tables = fpg.transform(self.input_table)
+ self.verify_output_result(output_tables[0], self.expected_patterns)
+ self.verify_output_result(output_tables[1], self.expected_rules)
+
+ def test_save_load_and_transform(self):
+ fpg = FPGrowth().set_min_support_count(3)
+ reloaded_fpg = self.save_and_reload(fpg)
+ output_tables = reloaded_fpg.transform(self.input_table)
+ self.verify_output_result(output_tables[0], self.expected_patterns)
+ self.verify_output_result(output_tables[1], self.expected_rules)
+
+ def verify_output_result(
+ self, output: Table,
+ expected_result: List):
+ collected_results = [result for result in
+ self.t_env.to_data_stream(output).execute_and_collect()]
+ results = []
+ for result in collected_results:
+ results.append([item for item in result])
+ results.sort(key=lambda x: x[0])
+ expected_result.sort(key=lambda x: x[0])
+ self.assertEqual(expected_result, results)