diff --git a/docs/content/docs/operators/recommendation/fpgrowth.md b/docs/content/docs/operators/recommendation/fpgrowth.md
new file mode 100644
index 000000000..005fa2407
--- /dev/null
+++ b/docs/content/docs/operators/recommendation/fpgrowth.md
@@ -0,0 +1,219 @@
---
title: "FPGrowth"
type: docs
aliases:
- /operators/recommendation/fpgrowth.html
---

## FPGrowth

An AlgoOperator which implements the FPGrowth algorithm.

FPGrowth is an algorithm for frequent pattern mining. It represents the transaction database as a
compact tree structure called a frequent pattern tree (FP-tree) and mines frequent patterns from it
without generating candidate sets.

NULL values and empty sequences in the items column are ignored during transform().

Only distinct elements of each sequence are used to mine frequent patterns.

See Han et al., Mining frequent patterns without candidate generation,
Li et al., PFP: Parallel FP-growth for query recommendation, and
Borgelt C., An Implementation of the FP-growth Algorithm for more information.

### Input Columns

| Param name | Type   | Default   | Description                                |
|:-----------|:-------|:----------|:-------------------------------------------|
| itemsCol   | String | `"items"` | Item sequence. (e.g. "item1,item2,item3")  |

### Structure of Output Table

#### Frequent Pattern Table

| Name          | Type   | Description                                    |
|:--------------|:-------|:-----------------------------------------------|
| items         | String | Frequent pattern.                              |
| support_count | Long   | Number of occurrences of the frequent pattern. |
| item_count    | Long   | Number of elements in the frequent pattern.    |

#### Association Rule Table

| Name               | Type   | Description                                                                 |
|:-------------------|:-------|:----------------------------------------------------------------------------|
| rule               | String | Association rule. (e.g. "item1,item2=>item3")                                |
| item_count         | Long   | Number of elements in the association rule.                                  |
| lift               | Double | Lift of the rule X=>Y: support(X,Y) / (support(X) * support(Y)).             |
| support_percent    | Double | Support: the fraction of transactions containing both antecedent and consequent. |
| confidence_percent | Double | Confidence of the rule X=>Y: support(X,Y) / support(X).                      |
| transaction_count  | Long   | Number of occurrences of the association rule.                               |

### Parameters

Below are the parameters required by `FPGrowth`.

| Key              | Default   | Type    | Required | Description                                                                     |
|:-----------------|:----------|:--------|:---------|:----------------------------------------------------------------------------------|
| itemsCol         | `"items"` | String  | no       | Item sequence column name.                                                          |
| fieldDelimiter   | `","`     | String  | no       | Field delimiter of the item sequence.                                               |
| minLift          | `1.0`     | Double  | no       | Minimal lift level for association rules.                                           |
| minConfidence    | `0.6`     | Double  | no       | Minimal confidence level for association rules.                                     |
| minSupport       | `0.02`    | Double  | no       | Minimal support level of frequent patterns, as a fraction of the number of transactions. |
| minSupportCount  | `-1`      | Integer | no       | Minimal support count. It overrides minSupport when greater than 0 and has no effect otherwise. |
| maxPatternLength | `10`      | Integer | no       | Maximum frequent pattern length.                                                    |

### Examples

{{< tabs examples >}}

{{< tab "Java">}}

```java
package org.apache.flink.ml.examples.recommendation;

import org.apache.flink.ml.recommendation.fpgrowth.FPGrowth;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/**
 * Simple program that creates an FPGrowth instance and uses it to generate frequent patterns and
 * association rules.
 */
public class FPGrowthExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(""),
                        Row.of("A,B,C,D"),
                        Row.of("B,C,E"),
                        Row.of("A,B,C,E"),
                        Row.of("B,D,E"),
                        Row.of("A,B,C,D,A"));

        Table inputTable = tEnv.fromDataStream(inputStream).as("items");

        // Creates an FPGrowth object and initializes its parameters.
        FPGrowth fpg = new FPGrowth().setMinSupportCount(3);

        // Transforms the data.
        Table[] outputTable = fpg.transform(inputTable);

        // Extracts and displays the frequent patterns.
        for (CloseableIterator<Row> it = outputTable[0].execute().collect(); it.hasNext(); ) {
            Row row = it.next();

            String pattern = row.getFieldAs(0);
            Long support = row.getFieldAs(1);
            Long itemCount = row.getFieldAs(2);

            System.out.printf(
                    "pattern: %s, support count: %d, item_count: %d\n",
                    pattern, support, itemCount);
        }

        // Extracts and displays the association rules.
        for (CloseableIterator<Row> it = outputTable[1].execute().collect(); it.hasNext(); ) {
            Row row = it.next();

            String rule = row.getFieldAs(0);
            Double lift = row.getFieldAs(2);
            Double support = row.getFieldAs(3);
            Double confidence = row.getFieldAs(4);

            System.out.printf(
                    "rule: %s, lift: %f, support: %f, confidence: %f\n",
                    rule, lift, support, confidence);
        }
    }
}
```

{{< /tab>}}

{{< tab "Python">}}

```python

# Simple program that creates an FPGrowth instance and uses it to mine
# frequent patterns and association rules.

from pyflink.common import Types
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.table import StreamTableEnvironment

from pyflink.ml.recommendation.fpgrowth import FPGrowth

# Creates a new StreamExecutionEnvironment.
env = StreamExecutionEnvironment.get_execution_environment()

# Creates a StreamTableEnvironment.
t_env = StreamTableEnvironment.create(env)

# Generates input data.
input_table = t_env.from_data_stream(
    env.from_collection([
        ("A,B,C,D",),
        ("B,C,E",),
        ("A,B,C,E",),
        ("B,D,E",),
        ("A,B,C,D",)
    ],
        type_info=Types.ROW_NAMED(
            ['items'],
            [Types.STRING()])
    ))

# Creates an FPGrowth object and initializes its parameters.
fpg = FPGrowth().set_min_support(0.6)

# Transforms the data with the FPGrowth algorithm.
output_table = fpg.transform(input_table)

# Extracts and displays the results.
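# Comments added for clarity: output_table[0] holds the frequent patterns
# (items, support_count, item_count); output_table[1] holds the association rules
# (rule, item_count, lift, support_percent, confidence_percent, transaction_count).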
+pattern_result_names = output_table[0].get_schema().get_field_names() +rule_result_names = output_table[1].get_schema().get_field_names() + +patterns = t_env.to_data_stream(output_table[0]).execute_and_collect() +rules = t_env.to_data_stream(output_table[1]).execute_and_collect() + +print("|\t"+"\t|\t".join(pattern_result_names)+"\t|") +for result in patterns: + print(f'|\t{result[0]}\t|\t{result[1]}\t|\t{result[2]}\t|') +print("|\t"+" | ".join(rule_result_names)+"\t|") +for result in rules: + print(f'|\t{result[0]}\t|\t{result[1]}\t|\t{result[2]}\t|\t{result[3]}' + + f'\t|\t{result[4]}\t|\t{result[5]}\t|') + +``` + +{{< /tab>}} + +{{< /tabs>}} diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java index e4cbcd529..7b9b43557 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java @@ -244,6 +244,59 @@ public static DataStream aggregate( return aggregate(input, func, accType, outType); } + /** + * Applies a {@link AggregateFunction} on a bounded keyed data stream. The output stream + * contains one stream record for each key. + * + * @param input The input keyed data stream. + * @param func The user defined aggredate function. + * @param accType The type information of intermediate data. + * @param outType The type information of the output. + * @return The result data stream. + * @param The key type of input. + * @param The class type of input. + * @param The type of intermediate data. + * @param The class type of output. + */ + public static DataStream keyedAggregate( + KeyedStream input, + AggregateFunction func, + TypeInformation accType, + TypeInformation outType) { + return input.transform( + "Keyed GroupReduce", + outType, + new KeyedAggregateOperator<>( + func, accType.createSerializer(input.getExecutionConfig()))) + .setParallelism(input.getParallelism()); + } + + /** + * Applies a {@link AggregateFunction} on a bounded keyed data stream. The output stream + * contains one stream record for each key. + * + * @param input The input keyed data stream. + * @param func The user defined aggredate function. + * @param accTypeSerializer The type serializer of intermediate data. + * @param outType The type information of the output. + * @return The result data stream. + * @param The key type of input. + * @param The class type of input. + * @param The type of intermediate data. + * @param The class type of output. + */ + public static DataStream keyedAggregate( + KeyedStream input, + AggregateFunction func, + TypeSerializer accTypeSerializer, + TypeInformation outType) { + return input.transform( + "Keyed GroupReduce", + outType, + new KeyedAggregateOperator<>(func, accTypeSerializer)) + .setParallelism(input.getParallelism()); + } + /** * Aggregates the elements in each partition of the input bounded stream, and then merges the * partial results of all partitions. 
The output stream contains the aggregated result and its @@ -562,6 +615,64 @@ public void snapshotState(StateSnapshotContext context) throws Exception { } } + private static class KeyedAggregateOperator + extends AbstractUdfStreamOperator> + implements OneInputStreamOperator, Triggerable { + + AggregateFunction aggregator; + + private static final String STATE_NAME = "_op_state"; + + private transient ValueState values; + + private final TypeSerializer serializer; + + private InternalTimerService timerService; + + public KeyedAggregateOperator( + AggregateFunction aggregator, TypeSerializer serializer) { + super(aggregator); + this.serializer = serializer; + } + + @Override + public void open() throws Exception { + super.open(); + ValueStateDescriptor stateId = new ValueStateDescriptor<>(STATE_NAME, serializer); + values = getPartitionedState(stateId); + timerService = + getInternalTimerService("end-key-timers", new VoidNamespaceSerializer(), this); + } + + @Override + public void processElement(StreamRecord element) throws Exception { + IN value = element.getValue(); + ACC currentValue = values.value(); + + if (currentValue == null) { + // Registers a timer for emitting the result at the end when this is the + // first input for this key. + timerService.registerEventTimeTimer(VoidNamespace.INSTANCE, Long.MAX_VALUE); + currentValue = userFunction.createAccumulator(); + } + + currentValue = userFunction.add(value, currentValue); + values.update(currentValue); + } + + @Override + public void onEventTime(InternalTimer timer) throws Exception { + ACC currentValue = values.value(); + if (currentValue != null) { + output.collect( + new StreamRecord<>(userFunction.getResult(currentValue), Long.MAX_VALUE)); + } + } + + @Override + public void onProcessingTime(InternalTimer timer) throws Exception {} + } + /** * A stream operator to apply {@link ReduceFunction} on the input bounded keyed data stream. * diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/recommendation/FPGrowthExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/recommendation/FPGrowthExample.java new file mode 100644 index 000000000..6a287c45b --- /dev/null +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/recommendation/FPGrowthExample.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.examples.recommendation; + +import org.apache.flink.ml.recommendation.fpgrowth.FPGrowth; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +/** + * Simple program that creates a FPGrowth instance and uses it to generate frequent patterns and + * association rules. + */ +public class FPGrowthExample { + public static void main(String[] args) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + // Generates input data. + DataStream inputStream = + env.fromElements( + Row.of(""), + Row.of("A,B,C,D"), + Row.of("B,C,E"), + Row.of("A,B,C,E"), + Row.of("B,D,E"), + Row.of("A,B,C,D,A")); + + Table inputTable = tEnv.fromDataStream(inputStream).as("items"); + + // Creates a FPGrowth object and initializes its parameters. + FPGrowth fpg = new FPGrowth().setMinSupportCount(3); + + // Transforms the data. + Table[] outputTable = fpg.transform(inputTable); + + // Extracts and displays the frequent patterns. + for (CloseableIterator it = outputTable[0].execute().collect(); it.hasNext(); ) { + Row row = it.next(); + + String pattern = row.getFieldAs(0); + Long support = row.getFieldAs(1); + Long itemCount = row.getFieldAs(2); + + System.out.printf( + "pattern: %s, support count: %d, item_count:%d\n", pattern, support, itemCount); + } + + // Extracts and displays the association rules. + for (CloseableIterator it = outputTable[1].execute().collect(); it.hasNext(); ) { + Row row = it.next(); + + String rule = row.getFieldAs(0); + Double lift = row.getFieldAs(2); + Double support = row.getFieldAs(3); + Double confidence_percent = row.getFieldAs(4); + + System.out.printf( + "rule: %s, list: %f, support:%f, confidence:%f\n", + rule, lift, support, confidence_percent); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/fpgrowth/AssociationRule.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/fpgrowth/AssociationRule.java new file mode 100644 index 000000000..4c08ec4b0 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/fpgrowth/AssociationRule.java @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.fpgrowth; + +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.tuple.Tuple5; +import org.apache.flink.api.java.typeutils.runtime.TupleSerializer; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.TreeMap; + +/** The class for generating association rules from frequent patterns. */ +public class AssociationRule { + private static final Logger LOG = LoggerFactory.getLogger(AssociationRule.class); + + public static DataStream> extractSingleConsequentRules( + DataStream> patterns, + DataStream> itemCounts, + DataStream> transactionCount, + final double minConfidence, + final double minLift, + final int maxPatternLen) { + + /* preprocess and group the patterns into a format suitable for extracting association rules. 
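+        In detail, each frequent pattern is emitted twice: once unchanged, keyed by its last item,
+        and once rotated (the last item moved to the front), keyed by the new last item. The
+        un-rotated copy yields the rules whose consequent is a non-last item, while the rotated copy
+        yields the rule whose consequent is the original last item; in both cases the antecedent is
+        a shorter pattern ending in the same key, so its support count is available in that
+        partition's support map.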
*/ + DataStream> processedPatterns = + groupPatternsByConseq(patterns); + itemCounts = itemCounts.union(transactionCount); + + DataStream> rules = + processedPatterns + .connect(itemCounts.broadcast()) + .transform( + "ExtractRulesOperator", + Types.TUPLE( + PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO, + PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO, + Types.INT, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO), + new ExtractRulesOperator(minLift, minConfidence, maxPatternLen)); + return rules; + } + + private static class ExtractRulesOperator + extends AbstractStreamOperator> + implements TwoInputStreamOperator< + Tuple5, + Tuple2, + Tuple4>, + BoundedMultiInput { + final double minLift; + final double minConfidence; + final int maxPatternLen; + private Map itemCounts; + private TreeMap supportMap; + private ListStateWithCache> patternsListState; + private ListStateWithCache> supportMapListState; + double transactionCount = -1; + private boolean flag1; + private boolean flag2; + + ExtractRulesOperator(double minLift, double minConfidence, int maxPatternLen) { + this.minLift = minLift; + this.minConfidence = minConfidence; + this.maxPatternLen = maxPatternLen; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + TypeInformation> type = + Types.TUPLE(PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO, Types.INT); + + supportMapListState = + new ListStateWithCache<>( + new TupleSerializer<>( + (Class>) (Class) Tuple2.class, + new TypeSerializer[] { + new IntPrimitiveArraySerializer(), new IntSerializer() + }), + getContainingTask(), + getRuntimeContext(), + context, + config.getOperatorID()); + + patternsListState = + new ListStateWithCache<>( + new TupleSerializer<>( + (Class>) + (Class) Tuple3.class, + new TypeSerializer[] { + new IntPrimitiveArraySerializer(), + new IntSerializer(), + new BooleanSerializer() + }), + getContainingTask(), + getRuntimeContext(), + context, + config.getOperatorID()); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + supportMapListState.snapshotState(context); + patternsListState.snapshotState(context); + } + + @Override + public void endInput(int i) throws Exception { + if (1 == i) { + flag2 = true; + } else { + flag1 = true; + } + int idx = getRuntimeContext().getIndexOfThisSubtask(); + + if (flag1 && flag2) { + for (int epoch = 1; epoch <= maxPatternLen; epoch++) { + boolean finished = true; + for (Tuple3 pattern : patternsListState.get()) { + int patternLen = pattern.f0.length; + if (patternLen > epoch) { + finished = false; + } else if (patternLen == epoch) { + processPattern(pattern); + } + } + fillSupportMap(); + if (finished) { + break; + } + } + } + } + + @Override + public void open() throws Exception { + super.open(); + } + + @Override + public void close() throws Exception { + patternsListState.clear(); + super.close(); + } + + @Override + public void processElement1( + StreamRecord> streamRecord) + throws Exception { + Tuple5 pattern = streamRecord.getValue(); + patternsListState.add(Tuple3.of(pattern.f0, pattern.f1, pattern.f4)); + } + + @Override + public void processElement2(StreamRecord> streamRecord) + throws Exception { + Tuple2 itemCount = streamRecord.getValue(); + if (itemCounts == null) { + itemCounts = new HashMap<>(); + } + if (itemCount.f0 == -1) { + this.transactionCount = itemCount.f1; + return; + } + itemCounts.put(itemCount.f0, 
itemCount.f1); + } + + private void processPattern(Tuple3 pattern) throws Exception { + boolean rotated = pattern.f2; + int[] items = pattern.f0; + if (rotated) { + int[] ante = Arrays.copyOfRange(items, 1, items.length); + int conseq = items[0]; + exportRule(ante, conseq, pattern.f1, output); + } else { + for (int i = 0; i < items.length - 1; i++) { + int[] ante = new int[items.length - 1]; + System.arraycopy(items, 0, ante, 0, i); + System.arraycopy(items, i + 1, ante, i, items.length - i - 1); + int conseq = items[i]; + exportRule(ante, conseq, pattern.f1, output); + } + supportMapListState.add(Tuple2.of(items, pattern.f1)); + } + } + + private void exportRule( + int[] x, + int y, + int suppXY, + Output>> collector) { + Integer suppX = supportMap.get(x); + Integer suppY = itemCounts.get(y); + assert suppX != null && suppY != null; + assert suppX >= suppXY && suppY >= suppXY; + assert transactionCount > 0; + double lift = suppXY * transactionCount / (suppX.doubleValue() * suppY.doubleValue()); + double confidence = suppXY / suppX.doubleValue(); + double support = suppXY / transactionCount; + if (lift >= minLift && confidence >= minConfidence) { + collector.collect( + new StreamRecord<>( + Tuple4.of( + x, + new int[] {y}, + suppXY, + new double[] {lift, support, confidence}))); + } + } + + private void fillSupportMap() { + if (null == supportMap) { + supportMap = new TreeMap<>(ExtractRulesOperator::comparatorFunction); + } else { + supportMap.clear(); + } + try { + for (Tuple2 itemCount : supportMapListState.get()) { + supportMap.put(itemCount.f0, itemCount.f1); + } + } catch (Exception e) { + } + supportMapListState.clear(); + } + + private static int comparatorFunction(int[] o1, int[] o2) { + if (o1.length != o2.length) { + return Integer.compare(o1.length, o2.length); + } + for (int i = 0; i < o1.length; i++) { + if (o1[i] != o2[i]) { + return Integer.compare(o1[i], o2[i]); + } + } + return 0; + } + } + + private static DataStream> + groupPatternsByConseq(DataStream> patterns) { + + DataStream> processedPatterns = + patterns.flatMap( + new RichFlatMapFunction< + Tuple2, + Tuple5>() { + + @Override + public void flatMap( + Tuple2 value, + Collector< + Tuple5< + int[], + Integer, + Integer, + Integer, + Boolean>> + out) + throws Exception { + + int[] items = value.f1; + int itemsLen = items.length; + Tuple5 nonRotatedpattern = + Tuple5.of( + value.f1, + value.f0, + items[itemsLen - 1], + itemsLen, + false); + out.collect(nonRotatedpattern); + if (items.length > 1) { + int tail = items[itemsLen - 1]; + for (int i = itemsLen - 1; i >= 1; i--) { + items[i] = items[i - 1]; + } + items[0] = tail; + Tuple5 rotatedpattern = + Tuple5.of( + items, + value.f0, + items[itemsLen - 1], + itemsLen, + true); + out.collect(rotatedpattern); + } + } + }) + .name("process_pattern_for_extract") + .keyBy(t5 -> t5.f2) + .map(t5 -> t5) + .returns( + Types.TUPLE( + PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO, + Types.INT, + Types.INT, + Types.INT, + Types.BOOLEAN)); + return processedPatterns; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/fpgrowth/FPTree.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/fpgrowth/FPTree.java new file mode 100644 index 000000000..e2f682c8b --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/fpgrowth/FPTree.java @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.fpgrowth; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Implementation of local FPGrowth algorithm. Reference: Christian Borgelt, An Implementation of + * the FP-growth Algorithm. + */ +public class FPTree implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(FPTree.class); + + /** The tree node. Notice that no reference to children are kept. */ + private static class Node implements Serializable { + private static final long serialVersionUID = -3963529487030357584L; + int itemId; + int support; + Node parent; + Node successor; + Node auxPtr; + + public Node(int itemId, int support, Node parent) { + this.itemId = itemId; + this.support = support; + this.parent = parent; + this.successor = null; + this.auxPtr = null; + } + } + + /** Summary of an item in the Fp-tree. */ + private static class Summary implements Serializable { + private static final long serialVersionUID = 7641916158660339302L; + /** Number of nodes in the tree. */ + int count; + + /** The head of the linked list of all nodes of an item. */ + Node head; + + public Summary(Node head) { + this.head = head; + } + + public void countAll() { + Node p = head; + count = 0; + while (p != null) { + count += p.support; + p = p.successor; + } + } + + @Override + public String toString() { + StringBuilder sbd = new StringBuilder(); + Node p = head; + while (p != null) { + sbd.append("->") + .append( + String.format( + "(%d,%d,%d)", + p.itemId, + p.support, + p.parent == null ? -1 : p.parent.itemId)); + p = p.successor; + } + return sbd.toString(); + } + } + + private Map summaries; // item -> summary of the item + + // transient data for building trees. 
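+    // They are populated by addTransaction() and converted into per-item node lists in
+    // `summaries` by initialize(), which then clears both buffers.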
+ private Map roots; // item -> root node of the item + private Map> itemNodes; // item -> list of nodes of the item + + public FPTree() {} + + private FPTree(Map summaries) { + this.summaries = summaries; + this.summaries.forEach( + (itemId, summary) -> { + summary.countAll(); + }); + } + + public void createTree() { + this.summaries = new HashMap<>(); + this.roots = new HashMap<>(); + this.itemNodes = new HashMap<>(); + } + + public void destroyTree() { + if (summaries != null) { + this.summaries.clear(); + } + if (roots != null) { + this.roots.clear(); + } + if (itemNodes != null) { + this.itemNodes.clear(); + } + } + + public void addTransaction(int[] transaction) { + if (transaction.length == 0) { + return; + } + int firstItem = transaction[0]; + Node curr; + if (roots.containsKey(firstItem)) { + curr = roots.get(firstItem); + curr.support += 1; + } else { + curr = new Node(firstItem, 1, null); + List list = new ArrayList<>(); + list.add(curr); + itemNodes.merge( + firstItem, + list, + (old, delta) -> { + old.addAll(delta); + return old; + }); + roots.put(firstItem, curr); + } + + for (int i = 1; i < transaction.length; i++) { + int item = transaction[i]; + Node p = curr.auxPtr; // use auxPtr as head of siblings + while (p != null && p.itemId != item) { + p = p.successor; + } + if (p != null) { // found + p.support += 1; + curr = p; + } else { // not found + Node newNode = new Node(item, 1, curr); + // insert newNode at the beginning of siblings. + newNode.successor = curr.auxPtr; + curr.auxPtr = newNode; + curr = newNode; + List list = new ArrayList<>(); + list.add(newNode); + itemNodes.merge( + item, + list, + (old, delta) -> { + old.addAll(delta); + return old; + }); + } + } + } + + public void initialize() { + this.itemNodes.forEach( + (item, nodesList) -> { + int n = nodesList.size(); + for (int i = 0; i < n; i++) { + Node curr = nodesList.get(i); + curr.auxPtr = null; + curr.successor = (i + 1) >= n ? null : nodesList.get(i + 1); + } + this.summaries.put(item, new Summary(nodesList.get(0))); + }); + + // clear data buffer + this.roots.clear(); + this.itemNodes.clear(); + + this.summaries.forEach((item, summary) -> summary.countAll()); + } + + /** Project the tree on the given item. */ + private FPTree project(int itemId, int minSupportCnt) { + if (!this.summaries.containsKey(itemId)) { + throw new RuntimeException("not contain item " + itemId); + } + Summary summary = this.summaries.get(itemId); + Map projectedSummaries = new HashMap<>(); + + Node p = summary.head; + while (p != null) { + // trace upward + // auxiliary pointer is copied and linked from its original ancestor f. 
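+                // Every ancestor f of a node holding `itemId` gets a "shadow" copy whose support
+                // starts from the support of the traversed node p; shadows of the same ancestor
+                // accumulate support, and the parent links between shadows form the conditional
+                // (projected) FP-tree for `itemId`.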
+ Node lastShadow = null; + Node f = p.parent; + while (f != null) { + if (f.auxPtr == null) { + Node shadow = new Node(f.itemId, p.support, null); + if (projectedSummaries.containsKey(shadow.itemId)) { + Summary summary0 = projectedSummaries.get(shadow.itemId); + shadow.successor = summary0.head; + summary0.head = shadow; + } else { + Summary summary0 = new Summary(shadow); + projectedSummaries.put(shadow.itemId, summary0); + } + f.auxPtr = shadow; + } else { // aux ptr already created by another branch + f.auxPtr.support += p.support; + } + if (lastShadow != null) { + // to set parent ptr of auxPtr + lastShadow.parent = f.auxPtr; + } + lastShadow = f.auxPtr; + f = f.parent; + } + p = p.successor; + } + + // prune + Set toPrune = new HashSet<>(); + projectedSummaries.forEach( + (item, s) -> { + s.countAll(); + if (s.count < minSupportCnt) { + toPrune.add(item); + } + }); + toPrune.forEach(projectedSummaries::remove); + + p = summary.head; + while (p != null) { + Node f = p.parent; + if (f != null) { + Node leaf = f.auxPtr; + while (leaf != null && toPrune.contains(leaf.itemId)) { + leaf = leaf.parent; + } + while (leaf != null) { + Node leafParent = leaf.parent; + while (leafParent != null && toPrune.contains(leafParent.itemId)) { + leafParent = leafParent.parent; + } + leaf.parent = leafParent; + leaf = leafParent; + } + } + p = p.successor; + } + + // clear auxPtr + p = summary.head; + while (p != null) { + Node f = p.parent; + while (f != null) { + f.auxPtr = null; + f = f.parent; + } + p = p.successor; + } + + return new FPTree(projectedSummaries); + } + + private void extractImpl( + int minSupportCnt, + int item, + int maxLength, + int[] suffix, + Output>> collector) { + if (maxLength < 1) { + return; + } + Summary summary = summaries.get(item); + if (summary.count < minSupportCnt) { + return; + } + int[] newSuffix = new int[suffix.length + 1]; + newSuffix[0] = item; + System.arraycopy(suffix, 0, newSuffix, 1, suffix.length); + Arrays.sort(newSuffix); + collector.collect(new StreamRecord<>(Tuple2.of(summary.count, newSuffix.clone()))); + if (maxLength == 1) { + return; + } + FPTree projectedTree = this.project(item, minSupportCnt); + projectedTree.summaries.forEach( + (pItem, pSummary) -> { + projectedTree.extractImpl( + minSupportCnt, pItem, maxLength - 1, newSuffix, collector); + }); + } + + public void extractAll( + int[] suffices, + int minSupport, + int maxPatternLength, + Output>> collector) { + for (int item : suffices) { + extractImpl(minSupport, item, maxPatternLength, new int[0], collector); + } + } + + /** Print the tree profile for debugging purpose. */ + public void printProfile() { + // tuple: + // 1) num distinct items in the tree, + // 2) sum of support of each items, + // 3) num tree nodes in the tree + Tuple3 counts = Tuple3.of(0, 0, 0); + summaries.forEach( + (item, summary) -> { + counts.f0 += 1; + counts.f1 += summary.count; + Node p = summary.head; + while (p != null) { + counts.f2 += 1; + p = p.successor; + } + }); + LOG.info("fptree_profile {}", counts); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/fpgrowth/FPGrowth.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/fpgrowth/FPGrowth.java new file mode 100644 index 000000000..ea39d2224 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/fpgrowth/FPGrowth.java @@ -0,0 +1,907 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.fpgrowth; + +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.RichFilterFunction; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.api.common.functions.RichFunction; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.fpgrowth.AssociationRule; +import org.apache.flink.ml.common.fpgrowth.FPTree; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.StringUtils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.Set; + +/** + * An implementation of parallel FP-growth algorithm to mine frequent itemset. + * + *
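<p>The implementation follows the PFP (parallel FP-growth) scheme: items are re-indexed by
+ * descending support, transactions are re-encoded and split into conditional transactions per
+ * partition, a local FP-tree is mined within each partition, and the mined patterns are finally
+ * post-processed into single-consequent association rules.
+ *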
<p>
For detail descriptions, please refer to: Han et al., Mining frequent patterns without + * candidate generation. Li et al., PFP: + * Parallel FP-growth for query recommendation + */ +public class FPGrowth implements AlgoOperator, FPGrowthParams { + private static final Logger LOG = LoggerFactory.getLogger(FPGrowth.class); + private static final String ITEM_INDEX = "ITEM_INDEX"; + private static final String[] FREQ_PATTERN_OUTPUT_COLS = { + "items", "support_count", "item_count" + }; + private static final String[] RULES_OUTPUT_COLS = { + "rule", "item_count", "lift", "support_percent", "confidence_percent", "transaction_count" + }; + + private final Map, Object> paramMap = new HashMap<>(); + + public FPGrowth() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Table data = inputs[0]; + final String itemColName = getItemsCol(); + final double minLift = getMinLift(); + final double minConfidence = getMinConfidence(); + final int maxPatternLen = getMaxPatternLength(); + final String fieldDelimiter = getFieldDelimiter(); + StreamTableEnvironment tenv = + (StreamTableEnvironment) ((TableImpl) data).getTableEnvironment(); + + DataStream itemTokens = + tenv.toDataStream(data) + .map( + new MapFunction() { + + @Override + public String[] map(Row value) throws Exception { + Set itemset = new HashSet<>(); + String itemsetStr = (String) value.getField(itemColName); + if (!StringUtils.isNullOrWhitespaceOnly(itemsetStr)) { + String[] splited = itemsetStr.split(fieldDelimiter); + itemset.addAll(Arrays.asList(splited)); + } + return itemset.toArray(new String[0]); + } + }) + .name("scan_transaction"); + + DataStream items = + itemTokens + .flatMap( + new FlatMapFunction() { + @Override + public void flatMap( + String[] strings, Collector collector) + throws Exception { + for (String s : strings) { + collector.collect(s); + } + } + }) + .returns(Types.STRING); + + // Count the total num of transactions. + DataStream> transactionCount = countRecords(itemTokens); + // Generate a Datastream of minSupport + final Double minSupport = getMinSupport(); + final int minSupportThreshold = getMinSupportCount(); + DataStream minSupportStream = + calculateMinSupport(tenv, transactionCount, minSupport, minSupportThreshold); + // Count the total number of each item. + DataStream> itemCounts = countItems(items); + // Drop items with support smaller than requirement. + final String minSuppoerCountDouble = "MIN_SUPPOER_COUNT_DOUBLE"; + itemCounts = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(itemCounts), + Collections.singletonMap(minSuppoerCountDouble, minSupportStream), + inputList -> { + DataStream input = inputList.get(0); + return input.filter( + new RichFilterFunction>() { + Double minSuppport = null; + + @Override + public boolean filter(Tuple2 o) + throws Exception { + if (null == minSuppport) { + minSuppport = + (double) + getRuntimeContext() + .getBroadcastVariable( + minSuppoerCountDouble) + .get(0); + } + if (o.f1 < minSupportThreshold) { + return false; + } else { + return true; + } + } + }); + }); + + // Assign items with indices, ordered by their support from high to low. + DataStream> itemCountIndex = + assignItemIndex(tenv, itemCounts); + // Assign items with partition id. 
+ DataStream> itemPid = partitionItems(itemCountIndex); + + DataStream> itemIndex = + itemCountIndex + .map(t3 -> Tuple2.of(t3.f0, t3.f2)) + .returns(Types.TUPLE(Types.STRING, Types.INT)); + DataStream> itemCount = + itemCountIndex + .map(t3 -> Tuple2.of(t3.f2, t3.f1)) + .returns(Types.TUPLE(Types.INT, Types.INT)); + + DataStream transactions = tokensToIndices(itemTokens, itemIndex); + DataStream> transactionGroups = + genCondTransactions(transactions, itemPid); + + // Extract all frequent patterns. + DataStream> indexPatterns = + mineFreqPattern(transactionGroups, itemPid, maxPatternLen, minSupportStream); + DataStream tokenPatterns = patternIndicesToTokens(indexPatterns, itemIndex); + + // Extract consequent rules from frequent patterns. + DataStream> rules = + AssociationRule.extractSingleConsequentRules( + indexPatterns, + itemCount, + transactionCount, + minConfidence, + minLift, + maxPatternLen); + DataStream rulesToken = ruleIndexToToken(rules, itemIndex); + + Table patternTable = tenv.fromDataStream(tokenPatterns); + Table rulesTable = tenv.fromDataStream(rulesToken); + + return new Table[] {patternTable, rulesTable}; + } + + /** + * Generate items partition. To achieve load balance, we assign to each item a score that + * represents its estimation of number of nodes in the Fp-tree. Then we greedily partition the + * items to balance the sum of scores in each partition. + * + * @param itemCountIndex A DataStream of tuples of item token, count and id. + * @return A DataStream of tuples of item id and partition id + */ + private static DataStream> partitionItems( + DataStream> itemCountIndex) { + DataStream> partition = + itemCountIndex.transform( + "ComputingPartitionCost", + Types.TUPLE(Types.INT, Types.INT), + new ComputingPartitionCost()); + return partition; + } + + private static class ComputingPartitionCost + extends AbstractStreamOperator> + implements OneInputStreamOperator< + Tuple3, Tuple2>, + BoundedOneInput { + List> itemCounts; + + private ComputingPartitionCost() { + itemCounts = new ArrayList<>(); + } + + @Override + public void endInput() throws Exception { + int numPartitions = getRuntimeContext().getNumberOfParallelSubtasks(); + + PriorityQueue> queue = + new PriorityQueue<>(numPartitions, Comparator.comparingDouble(o -> o.f1)); + + for (int i = 0; i < numPartitions; i++) { + queue.add(Tuple2.of(i, 0.0)); + } + + List scaledItemCount = new ArrayList<>(itemCounts.size()); + for (int i = 0; i < itemCounts.size(); i++) { + Tuple2 item = itemCounts.get(i); + double pos = (double) (item.f0 + 1) / ((double) itemCounts.size()); + double score = pos * item.f1.doubleValue(); + scaledItemCount.add(score); + } + + List order = new ArrayList<>(itemCounts.size()); + for (int i = 0; i < itemCounts.size(); i++) { + order.add(i); + } + + order.sort( + (o1, o2) -> { + double s1 = scaledItemCount.get(o1); + double s2 = scaledItemCount.get(o2); + return Double.compare(s2, s1); + }); + + // greedily assign partition number to each item + for (int i = 0; i < itemCounts.size(); i++) { + Tuple2 item = itemCounts.get(order.get(i)); + double score = scaledItemCount.get(order.get(i)); + Tuple2 target = queue.poll(); + int targetPartition = target.f0; + target.f1 += score; + queue.add(target); + output.collect(new StreamRecord<>(Tuple2.of(item.f0, targetPartition))); + } + } + + @Override + public void processElement(StreamRecord> streamRecord) + throws Exception { + Tuple3 t3 = streamRecord.getValue(); + itemCounts.add(Tuple2.of(t3.f2, t3.f1)); + } + } + + /** + * Generate conditional 
transactions for each partition.
+     *
+     * 
<p>
Scan from the longest substring in a transaction, partition the substring into the group + * where its last element belongs. If the partition already contains a longer substring, skip + * it. + * + * @param transactions A DataStream of transactions. + * @param targetPartition A DataStream of tuples of item and partition number. + * @return substring of transactions and partition number + */ + private static DataStream> genCondTransactions( + DataStream transactions, DataStream> targetPartition) { + final String itemPartition = "ITEM_PARTITION"; + Map> broadcastMap = new HashMap<>(1); + broadcastMap.put(itemPartition, targetPartition); + return BroadcastUtils.withBroadcastStream( + Collections.singletonList(transactions), + broadcastMap, + inputLists -> { + DataStream transactionStream = inputLists.get(0); + return transactionStream.flatMap( + new RichFlatMapFunction>() { + transient Map partitioner; + // list of flags used to skip partition that is not empty + transient int[] flags; + + @Override + public void flatMap( + int[] transaction, Collector> out) + throws Exception { + if (null == flags) { + int numPartition = + getRuntimeContext().getNumberOfParallelSubtasks(); + this.flags = new int[numPartition]; + } + if (null == partitioner) { + List> bc = + getRuntimeContext() + .getBroadcastVariable(itemPartition); + partitioner = new HashMap<>(); + for (Tuple2 t2 : bc) { + partitioner.put(t2.f0, t2.f1); + } + } + Arrays.fill(flags, 0); + int cnt = transaction.length; + // starts from the longest substring + for (; cnt > 0; cnt--) { + int lastPos = cnt - 1; + int partition = this.partitioner.get(transaction[lastPos]); + if (flags[partition] == 0) { + List condTransaction = new ArrayList<>(cnt); + for (int j = 0; j < cnt; j++) { + condTransaction.add(transaction[j]); + } + int[] tr = new int[condTransaction.size()]; + for (int j = 0; j < tr.length; j++) { + tr[j] = condTransaction.get(j); + } + out.collect(Tuple2.of(partition, tr)); + flags[partition] = 1; + } + } + } + }); + }); + } + + /** + * Mine frequent patterns locally in each partition. + * + * @param condTransactions The conditional transactions with partition id. + * @param partitioner A DataStream of tuples of item id and partition id. + * @param maxPatternLength Maximum pattern length. + * @return A DataStream of tuples of count and frequent patterns. + */ + private static DataStream> mineFreqPattern( + DataStream> condTransactions, + DataStream> partitioner, + int maxPatternLength, + DataStream minSupport) { + condTransactions = + condTransactions.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f0); + partitioner = partitioner.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f1); + + List> inputList = new ArrayList<>(); + inputList.add(condTransactions); + inputList.add(partitioner); + return BroadcastUtils.withBroadcastStream( + inputList, + Collections.singletonMap("MIN_COUNT", minSupport), + inputLists -> { + DataStream condStream = inputLists.get(0); + DataStream partitionStream = inputLists.get(1); + return condStream + .connect(partitionStream) + .transform( + "mine-freq-pattern", + Types.TUPLE( + Types.INT, + PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO), + new FpTreeConstructor(maxPatternLength)); + }); + } + + private static class MyRichFunction extends AbstractRichFunction {} + + /* The Operator to construct fp-trees of each worker. 
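+    Input 1 carries the conditional transactions (partitionId, itemIds) routed to this worker;
+    input 2 carries the ids of the items assigned to this worker. When both inputs end, the local
+    FP-tree is built and all frequent patterns ending in one of the assigned items are extracted,
+    using the broadcast MIN_COUNT value as the minimal support count.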
*/ + // private static class FpTreeConstructor extends AbstractStreamOperator> + private static class FpTreeConstructor + extends AbstractUdfStreamOperator, RichFunction> + implements TwoInputStreamOperator< + Tuple2, + Tuple2, + Tuple2>, + BoundedMultiInput { + + private final FPTree tree = new FPTree(); + private boolean input1Ends; + private boolean input2Ends; + private int minSupportCnt; + private final int maxPatternLen; + private Set itemList = new HashSet<>(); + + FpTreeConstructor(int maxPatternLength) { + super(new MyRichFunction()); + maxPatternLen = maxPatternLength; + } + + @Override + public void open() throws Exception { + super.open(); + tree.createTree(); + } + + @Override + public void endInput(int i) throws Exception { + if (1 == i) { + LOG.info("Finished adding transactions."); + input1Ends = true; + } else { + LOG.info("Finished adding items."); + input2Ends = true; + } + + if (input1Ends && input2Ends) { + LOG.info("Start to extract fptrees."); + endInputs(); + } + } + + public void endInputs() { + tree.initialize(); + tree.printProfile(); + minSupportCnt = + (int) + Math.ceil( + ((Double) + userFunction + .getRuntimeContext() + .getBroadcastVariable("MIN_COUNT") + .get(0))); + int[] suffices = new int[itemList.size()]; + int i = 0; + for (Integer item : itemList) { + suffices[i++] = item; + } + tree.extractAll(suffices, minSupportCnt, maxPatternLen, output); + tree.destroyTree(); + LOG.info("itemList size {}.", itemList.size()); + LOG.info("Finished extracting fptrees."); + } + + @Override + public void processElement1(StreamRecord> streamRecord) + throws Exception { + tree.addTransaction(streamRecord.getValue().f1); + } + + @Override + public void processElement2(StreamRecord> streamRecord) + throws Exception { + itemList.add(streamRecord.getValue().f0); + } + } + + /** + * Map indices to string in pattern. + * + * @param patterns A DataStream of tuples of frequent patterns (represented as int array) and + * support count. + * @param itemIndex A DataStream of tuples of item token and id. + * @return A DataStream of frequent patterns, support count and length of the pattern. + */ + private static DataStream patternIndicesToTokens( + DataStream> patterns, + DataStream> itemIndex) { + + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + new TypeInformation[] {Types.STRING, Types.LONG, Types.LONG}, + FREQ_PATTERN_OUTPUT_COLS); + + Map> broadcastMap = new HashMap<>(1); + broadcastMap.put(ITEM_INDEX, itemIndex); + return BroadcastUtils.withBroadcastStream( + Collections.singletonList(patterns), + broadcastMap, + inputList -> { + DataStream freqPatterns = inputList.get(0); + return freqPatterns + .map( + new RichMapFunction, Row>() { + Map tokenToId; + + @Override + public Row map(Tuple2 pattern) + throws Exception { + if (null == tokenToId) { + tokenToId = new HashMap<>(); + List> itemIndexList = + getRuntimeContext() + .getBroadcastVariable("ITEM_INDEX"); + for (Tuple2 t2 : itemIndexList) { + tokenToId.put(t2.f1, t2.f0); + } + } + int len = pattern.f1.length; + if (len == 0) { + return null; + } + StringBuilder sbd = new StringBuilder(); + sbd.append(tokenToId.get(pattern.f1[0])); + for (int i = 1; i < len; i++) { + sbd.append(",") + .append(tokenToId.get(pattern.f1[i])); + } + return Row.of( + sbd.toString(), (long) pattern.f0, (long) len); + } + }) + .name("flatMap_id_to_token") + .returns(outputTypeInfo); + }); + } + + /** + * Count the total rows of input Datastream. + * + * @param itemTokens A DataStream of input transactions. 
+ * @return A DataStream of one record, recording the number of input rows. + */ + private static DataStream> countRecords( + DataStream itemTokens) { + return DataStreamUtils.aggregate( + itemTokens, + new AggregateFunction() { + + @Override + public Integer createAccumulator() { + return 0; + } + + @Override + public Integer add(String[] strings, Integer count) { + if (strings.length > 0) { + return count + 1; + } + return count; + } + + @Override + public Integer getResult(Integer count) { + return count; + } + + @Override + public Integer merge(Integer count, Integer acc1) { + return count + acc1; + } + }) + .map( + new MapFunction>() { + @Override + public Tuple2 map(Integer count) throws Exception { + return Tuple2.of(-1, count); + } + }) + .name("count_transaction") + .returns(Types.TUPLE(Types.INT, Types.INT)); + } + + /** + * Count the number of occurence of each item. + * + * @param items A DataStream of items. + * @return A DataStream of tuples of item string and count. + */ + private static DataStream> countItems(DataStream items) { + return DataStreamUtils.keyedAggregate( + items.keyBy(s -> s), + new AggregateFunction, Tuple2>() { + + @Override + public Tuple2 createAccumulator() { + return Tuple2.of(null, 0); + } + + @Override + public Tuple2 add(String item, Tuple2 acc) { + if (null == acc.f0) { + acc.f0 = item; + } + acc.f1++; + return acc; + } + + @Override + public Tuple2 getResult(Tuple2 t2) { + return t2; + } + + @Override + public Tuple2 merge( + Tuple2 acc1, Tuple2 acc2) { + acc2.f1 += acc1.f1; + return acc2; + } + }, + Types.TUPLE(Types.STRING, Types.INT), + Types.TUPLE(Types.STRING, Types.INT)); + } + + /** + * Calculate minimal support count of the frequent pattern. If minSupportRate is not null, + * return minSupportRate * transactionCount, else return minSupportCount + * + * @param tenv + * @param transactionCount + * @param minSupportRate + * @param minSupportCount + * @return A DataStream of one record, recording the minimal support. + */ + private static DataStream calculateMinSupport( + StreamTableEnvironment tenv, + DataStream> transactionCount, + final Double minSupportRate, + final int minSupportCount) { + final String supportCount = "MIN_SUPPORT_COUNT"; + final String supportRate = "MIN_SUPPORT_RATE"; + Map> bc = new HashMap<>(2); + DataStream minSupportCountStream = tenv.toDataStream(tenv.fromValues(minSupportCount)); + bc.put(supportCount, minSupportCountStream); + DataStream minSupportRateStream = tenv.toDataStream(tenv.fromValues(minSupportRate)); + bc.put(supportRate, minSupportRateStream); + return BroadcastUtils.withBroadcastStream( + Collections.singletonList(transactionCount), + bc, + inputLists -> { + DataStream transactionCountStream = inputLists.get(0); + return transactionCountStream.map( + new RichMapFunction, Double>() { + @Override + public Double map(Tuple2 tuple2) + throws Exception { + Double bcSupport = + ((Row) + getRuntimeContext() + .getBroadcastVariable( + supportRate) + .get(0)) + .getFieldAs(0); + int bcCount = + ((Row) + getRuntimeContext() + .getBroadcastVariable( + supportCount) + .get(0)) + .getFieldAs(0); + if (bcCount > 0) { + return (double) bcCount; + } + return tuple2.f1 * bcSupport; + } + }); + }); + } + + /** + * Map item token to indice based on its descending order of count. + * + * @param tenv the StreamTableEnvironment of execution environment. + * @param itemCounts A DataStream of tuples of item and count. + * @return A DataStream of tuples of item and count and index. 
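+     *     <p>For example, item counts {a: 5, b: 3, c: 3} yield (a, 5, 0), (b, 3, 1), (c, 3, 2):
+     *     indices follow descending count, with ties broken by the natural order of the item
+     *     token (the item names here are illustrative).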
+ */ + private static DataStream> assignItemIndex( + StreamTableEnvironment tenv, DataStream> itemCounts) { + final String itemSupports = "ITEM_SUPPORTS"; + + return BroadcastUtils.withBroadcastStream( + Collections.singletonList( + tenv.toDataStream(tenv.fromValues(Collections.singletonMap(-1, -1)))), + Collections.singletonMap(itemSupports, itemCounts), + inputList -> { + DataStream input = inputList.get(0); + return input.flatMap( + new RichFlatMapFunction< + Row, Tuple3>() { + List> supportCount; + + @Override + public void flatMap( + Row o, + Collector> + collector) + throws Exception { + if (null == supportCount) { + supportCount = + getRuntimeContext() + .getBroadcastVariable(itemSupports); + } + Integer[] order = new Integer[supportCount.size()]; + for (int i = 0; i < order.length; i++) { + order[i] = i; + } + Arrays.sort( + order, + new Comparator() { + @Override + public int compare(Integer o1, Integer o2) { + Integer cnt1 = supportCount.get(o1).f1; + Integer cnt2 = supportCount.get(o2).f1; + if (cnt1.equals(cnt2)) { + return supportCount + .get(o1) + .f0 + .compareTo( + supportCount.get(o2) + .f0); + } + return Integer.compare(cnt2, cnt1); + } + }); + for (int i = 0; i < order.length; i++) { + collector.collect( + Tuple3.of( + supportCount.get(order[i]).f0, + supportCount.get(order[i]).f1, + i)); + } + } + }) + .name("item_indexer"); + }); + } + + /** + * Map string to indices in transactions. + * + * @param itemSets A DataStream of transactions. + * @param itemIndex A DataStream of tuples of item token and id. + * @return A DataStream of tuples of transactions represented as int array + */ + private static DataStream tokensToIndices( + DataStream itemSets, DataStream> itemIndex) { + Map> broadcastMap = new HashMap<>(1); + broadcastMap.put(ITEM_INDEX, itemIndex); + return BroadcastUtils.withBroadcastStream( + Collections.singletonList(itemSets), + broadcastMap, + inputList -> { + DataStream transactions = inputList.get(0); + return transactions + .map( + new RichMapFunction() { + Map tokenToId; + + @Override + public int[] map(String[] transaction) throws Exception { + if (null == tokenToId) { + LOG.info("Trying to get ITEM_INDEX."); + tokenToId = new HashMap<>(); + List> itemIndexList = + getRuntimeContext() + .getBroadcastVariable("ITEM_INDEX"); + for (Tuple2 t2 : itemIndexList) { + tokenToId.put(t2.f0, t2.f1); + } + LOG.info( + "Size of tokenToId is {}.", + tokenToId.size()); + } + int[] items = new int[transaction.length]; + int len = 0; + for (String item : transaction) { + Integer id = tokenToId.get(item); + if (id != null) { + items[len++] = id; + } + } + if (len > 0) { + int[] qualified = Arrays.copyOfRange(items, 0, len); + Arrays.sort(qualified); + return qualified; + } else { + return new int[0]; + } + } + }) + .name("map_token_to_index") + .returns(PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO); + }); + } + + /** + * Map indices to string in association rules. + * + * @param rules A DataStream of tuples of antecedent, consequent, support count, [lift, support, + * confidence]. + * @param itemIndex A DataStream of tuples of item token and id. + * @return A DataStream of tuples of row of rules, length of the rule, lift, support, confidence + * and support count. 
+ */ + private static DataStream ruleIndexToToken( + DataStream> rules, + DataStream> itemIndex) { + + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + new TypeInformation[] { + Types.STRING, + Types.LONG, + Types.DOUBLE, + Types.DOUBLE, + Types.DOUBLE, + Types.LONG + }, + RULES_OUTPUT_COLS); + + Map> broadcastMap = new HashMap<>(1); + broadcastMap.put(ITEM_INDEX, itemIndex); + return BroadcastUtils.withBroadcastStream( + Collections.singletonList(rules), + broadcastMap, + inputList -> { + DataStream freqPatterns = inputList.get(0); + return freqPatterns + .map( + new RichMapFunction< + Tuple4, Row>() { + Map tokenToId; + + @Override + public Row map(Tuple4 rule) + throws Exception { + if (null == tokenToId) { + tokenToId = new HashMap<>(); + List> itemIndexList = + getRuntimeContext() + .getBroadcastVariable("ITEM_INDEX"); + for (Tuple2 t2 : itemIndexList) { + tokenToId.put(t2.f1, t2.f0); + } + } + StringBuilder sbd = new StringBuilder(); + int[] ascent = rule.f0; + int[] consq = rule.f1; + sbd.append(tokenToId.get(ascent[0])); + for (int i = 1; i < ascent.length; i++) { + sbd.append(",").append(tokenToId.get(ascent[i])); + } + sbd.append("=>"); + sbd.append(tokenToId.get(consq[0])); + for (int i = 1; i < consq.length; i++) { + sbd.append(",").append(tokenToId.get(consq[i])); + } + return Row.of( + sbd.toString(), + (long) (ascent.length + consq.length), + rule.f3[0], + rule.f3[1], + rule.f3[2], + (long) rule.f2); + } + }) + .returns(outputTypeInfo); + }); + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static FPGrowth load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/fpgrowth/FPGrowthParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/fpgrowth/FPGrowthParams.java new file mode 100644 index 000000000..801a4cbae --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/recommendation/fpgrowth/FPGrowthParams.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.recommendation.fpgrowth; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Params of {@link FPGrowth}. + * + * @param The class type of this instance. 
+ */
+public interface FPGrowthParams<T> extends WithParams<T> {
+
+    Param<String> ITEMS_COL =
+            new StringParam(
+                    "itemsCol", "Item sequence column name.", "items", ParamValidators.notNull());
+    Param<String> FIELD_DELIMITER =
+            new StringParam(
+                    "fieldDelimiter",
+                    "Field delimiter of item sequence, default delimiter is ','.",
+                    ",",
+                    ParamValidators.notNull());
+    Param<Double> MIN_LIFT =
+            new DoubleParam(
+                    "minLift",
+                    "Minimal lift level for association rules.",
+                    1.0,
+                    ParamValidators.gtEq(0));
+    Param<Double> MIN_CONFIDENCE =
+            new DoubleParam(
+                    "minConfidence",
+                    "Minimal confidence level for association rules.",
+                    0.6,
+                    ParamValidators.gtEq(0));
+    Param<Double> MIN_SUPPORT =
+            new DoubleParam(
+                    "minSupport",
+                    "Minimal support percent level. The default value of MIN_SUPPORT is 0.02.",
+                    0.02);
+    Param<Integer> MIN_SUPPORT_COUNT =
+            new IntParam(
+                    "minSupportCount",
+                    "Minimal support count. MIN_SUPPORT_COUNT has no effect when less than or equal to 0. The default value is -1.",
+                    -1);
+
+    Param<Integer> MAX_PATTERN_LENGTH =
+            new IntParam(
+                    "maxPatternLength", "Max frequent pattern length.", 10, ParamValidators.gt(0));
+
+    default String getItemsCol() {
+        return get(ITEMS_COL);
+    }
+
+    default T setItemsCol(String value) {
+        return set(ITEMS_COL, value);
+    }
+
+    default String getFieldDelimiter() {
+        return get(FIELD_DELIMITER);
+    }
+
+    default T setFieldDelimiter(String value) {
+        return set(FIELD_DELIMITER, value);
+    }
+
+    default double getMinLift() {
+        return get(MIN_LIFT);
+    }
+
+    default T setMinLift(Double value) {
+        return set(MIN_LIFT, value);
+    }
+
+    default Double getMinSupport() {
+        return get(MIN_SUPPORT);
+    }
+
+    default T setMinSupport(double value) {
+        return set(MIN_SUPPORT, value);
+    }
+
+    default double getMinConfidence() {
+        return get(MIN_CONFIDENCE);
+    }
+
+    default T setMinConfidence(Double value) {
+        return set(MIN_CONFIDENCE, value);
+    }
+
+    default int getMinSupportCount() {
+        return get(MIN_SUPPORT_COUNT);
+    }
+
+    default T setMinSupportCount(Integer value) {
+        return set(MIN_SUPPORT_COUNT, value);
+    }
+
+    default int getMaxPatternLength() {
+        return get(MAX_PATTERN_LENGTH);
+    }
+
+    default T setMaxPatternLength(Integer value) {
+        return set(MAX_PATTERN_LENGTH, value);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/FPGrowthTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/FPGrowthTest.java new file mode 100644 index 000000000..610926c40 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/recommendation/FPGrowthTest.java @@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.flink.ml.recommendation; + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.recommendation.fpgrowth.FPGrowth; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** Tests {@link FPGrowth}. */ +public class FPGrowthTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private static final int defaultParallelism = 4; + private static StreamExecutionEnvironment env; + private static StreamTableEnvironment tEnv; + private Table inputTable; + + List expectedPatterns = + new ArrayList<>( + Arrays.asList( + Row.of("B", 5L, 1L), + Row.of("C", 4L, 1L), + Row.of("B,C", 4L, 2L), + Row.of("A", 3L, 1L), + Row.of("B,A", 3L, 2L), + Row.of("C,A", 3L, 2L), + Row.of("B,C,A", 3L, 3L), + Row.of("E", 3L, 1L), + Row.of("B,E", 3L, 2L), + Row.of("D", 3L, 1L), + Row.of("B,D", 3L, 2L))); + List expectedRules = + new ArrayList<>( + Arrays.asList( + Row.of("B=>E", 2L, 1.0, 0.6, 0.6, 3L), + Row.of("B=>C", 2L, 1.0, 0.8, 0.8, 4L), + Row.of("A=>B", 2L, 1.0, 0.6, 1.0, 3L), + Row.of("B=>A", 2L, 1.0, 0.6, 0.6, 3L), + Row.of("A=>C", 2L, 1.25, 0.6, 1.0, 3L), + Row.of("D=>B", 2L, 1.0, 0.6, 1.0, 3L), + Row.of("B=>D", 2L, 1.0, 0.6, 0.6, 3L), + Row.of("C,A=>B", 3L, 1.0, 0.6, 1.0, 3L), + Row.of("B,A=>C", 3L, 1.25, 0.6, 1.0, 3L), + Row.of("E=>B", 2L, 1.0, 0.6, 1.0, 3L), + Row.of("C=>B", 2L, 1.0, 0.8, 1.0, 4L), + Row.of("C=>A", 2L, 1.25, 0.6, 0.75, 3L), + Row.of("B,C=>A", 3L, 1.25, 0.6, 0.75, 3L))); + + public void checkResult(List expected, CloseableIterator result) { + List actual = new ArrayList<>(); + while (result.hasNext()) { + Row row = result.next(); + actual.add(row); + } + + expected.sort( + (o1, o2) -> { + String s1 = o1.getFieldAs(0); + String s2 = o2.getFieldAs(0); + return s1.compareTo(s2); + }); + + actual.sort( + (o1, o2) -> { + String s1 = o1.getFieldAs(0); + String s2 = o2.getFieldAs(0); + return s1.compareTo(s2); + }); + + Assert.assertArrayEquals(expected.toArray(), actual.toArray()); + } + + @Before + public void before() { + env = TestUtils.getExecutionEnvironment(); + env.getConfig().setParallelism(defaultParallelism); + tEnv = StreamTableEnvironment.create(env); + List inputRows = + new ArrayList<>( + Arrays.asList( + Row.of(""), + Row.of("A,B,C,D"), + Row.of("B,C,E"), + Row.of("A,B,C,E"), + Row.of("B,D,E"), + Row.of("A,B,C,D,A"))); + + inputTable = + tEnv.fromDataStream( + env.fromCollection( + inputRows, + new RowTypeInfo( + new TypeInformation[] {BasicTypeInfo.STRING_TYPE_INFO}, + new String[] {"transactions"}))); + } + + @Test + public void testParam() { + FPGrowth fpGrowth = new FPGrowth(); + assertEquals("items", fpGrowth.getItemsCol()); + assertEquals(",", fpGrowth.getFieldDelimiter()); + assertEquals(10, fpGrowth.getMaxPatternLength()); + assertEquals(0.6, fpGrowth.getMinConfidence(), 1e-9); + assertEquals(0.02, fpGrowth.getMinSupport(), 1e-9); + 
assertEquals(-1, fpGrowth.getMinSupportCount()); + assertEquals(1.0, fpGrowth.getMinLift(), 1e-9); + + fpGrowth.setItemsCol("transactions") + .setFieldDelimiter(" ") + .setMaxPatternLength(100) + .setMinLift(0.5) + .setMinConfidence(0.3) + .setMinSupport(0.3) + .setMinSupportCount(10); + + assertEquals("transactions", fpGrowth.getItemsCol()); + assertEquals(" ", fpGrowth.getFieldDelimiter()); + assertEquals(100, fpGrowth.getMaxPatternLength()); + assertEquals(0.3, fpGrowth.getMinConfidence(), 1e-9); + assertEquals(0.3, fpGrowth.getMinSupport(), 1e-9); + assertEquals(10, fpGrowth.getMinSupportCount()); + assertEquals(0.5, fpGrowth.getMinLift(), 1e-9); + } + + @Test + public void testTransform() { + FPGrowth fpGrowth = new FPGrowth().setItemsCol("transactions").setMinSupport(0.6); + Table[] results = fpGrowth.transform(inputTable); + CloseableIterator patterns = results[0].execute().collect(); + checkResult(expectedPatterns, patterns); + CloseableIterator rules = results[1].execute().collect(); + checkResult(expectedRules, rules); + } + + @Test + public void testOutputSchema() { + FPGrowth fpGrowth = new FPGrowth().setItemsCol("transactions").setMinSupportCount(3); + Table[] results = fpGrowth.transform(inputTable); + assertEquals( + Arrays.asList("items", "support_count", "item_count"), + results[0].getResolvedSchema().getColumnNames()); + assertEquals( + Arrays.asList( + "rule", + "item_count", + "lift", + "support_percent", + "confidence_percent", + "transaction_count"), + results[1].getResolvedSchema().getColumnNames()); + } + + @Test + public void testSaveLoadAndTransform() throws Exception { + FPGrowth fpGrowth = new FPGrowth().setItemsCol("transactions").setMinSupportCount(3); + FPGrowth loadedFPGrowth = + TestUtils.saveAndReload( + tEnv, fpGrowth, tempFolder.newFolder().getAbsolutePath(), FPGrowth::load); + Table[] results = loadedFPGrowth.transform(inputTable); + CloseableIterator patterns = results[0].execute().collect(); + checkResult(expectedPatterns, patterns); + CloseableIterator rules = results[1].execute().collect(); + checkResult(expectedRules, rules); + } +} diff --git a/flink-ml-python/pyflink/examples/ml/recommendation/fpgrowth_example.py b/flink-ml-python/pyflink/examples/ml/recommendation/fpgrowth_example.py new file mode 100644 index 000000000..5808ad4a5 --- /dev/null +++ b/flink-ml-python/pyflink/examples/ml/recommendation/fpgrowth_example.py @@ -0,0 +1,66 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Simple program that creates a fpgrowth instance and gives recommendations for items. 
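+# The transactions are comma-delimited item strings; the script mines frequent patterns
+# and association rules and prints both result tables.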
+ +from pyflink.common import Types +from pyflink.datastream import StreamExecutionEnvironment +from pyflink.table import StreamTableEnvironment + +from pyflink.ml.recommendation.fpgrowth import FPGrowth + +# Creates a new StreamExecutionEnvironment. +env = StreamExecutionEnvironment.get_execution_environment() + +# Creates a StreamTableEnvironment. +t_env = StreamTableEnvironment.create(env) + +# Generates input data. +input_table = t_env.from_data_stream( + env.from_collection([ + ("A,B,C,D",), + ("B,C,E",), + ("A,B,C,E",), + ("B,D,E",), + ("A,B,C,D",) + ], + type_info=Types.ROW_NAMED( + ['items'], + [Types.STRING()]) + )) + +# Creates a fpgrowth object and initialize its parameters. +fpg = FPGrowth().set_min_support(0.6) + +# Transforms the data to fpgrowth algorithm result. +output_table = fpg.transform(input_table) + +# Extracts and display the results. +pattern_result_names = output_table[0].get_schema().get_field_names() +rule_result_names = output_table[1].get_schema().get_field_names() + +patterns = t_env.to_data_stream(output_table[0]).execute_and_collect() +rules = t_env.to_data_stream(output_table[1]).execute_and_collect() + +print("|\t"+"\t|\t".join(pattern_result_names)+"\t|") +for result in patterns: + print(f'|\t{result[0]}\t|\t{result[1]}\t|\t{result[2]}\t|') +print("|\t"+" | ".join(rule_result_names)+"\t|") +for result in rules: + print(f'|\t{result[0]}\t|\t{result[1]}\t|\t{result[2]}\t|\t{result[3]}' + + f'\t|\t{result[4]}\t|\t{result[5]}\t|') diff --git a/flink-ml-python/pyflink/ml/recommendation/fpgrowth.py b/flink-ml-python/pyflink/ml/recommendation/fpgrowth.py new file mode 100644 index 000000000..7cae08ed5 --- /dev/null +++ b/flink-ml-python/pyflink/ml/recommendation/fpgrowth.py @@ -0,0 +1,166 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +import typing + +from pyflink.ml.param import Param, StringParam, IntParam, FloatParam, ParamValidators +from pyflink.ml.recommendation.common import JavaRecommendationAlgoOperator +from pyflink.ml.wrapper import JavaWithParams + + +class _FPGrowthParams( + JavaWithParams +): + """ + Params for :class:`FPGrowth`. 
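+
+ The supported parameters are items_col, field_delimiter, min_lift, min_confidence,
+ min_support, min_support_count and max_pattern_length.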
+ """ + + ITEMS_COL: Param[str] = StringParam( + "items_col", + "Item sequence column name.", + "items", + ParamValidators.not_null()) + + FIELD_DELIMITER: Param[str] = StringParam( + "field_delimiter", + "Field delimiter of item sequence, default delimiter is ','.", + ",", + ParamValidators.not_null()) + + MIN_LIFT: Param[float] = FloatParam( + "min_lift", + "Minimal lift level for association rules.", + 1.0, + ParamValidators.gt_eq(0)) + + MIN_CONFIDENCE: Param[float] = FloatParam( + "min_confidence", + "Minimal confidence level for association rules.", + 0.6, + ParamValidators.gt_eq(0)) + + MIN_SUPPORT: Param[float] = FloatParam( + "min_support", + "Minimal support percent. The default value of MIN_SUPPORT is 0.02", + 0.02) + + MIN_SUPPORT_COUNT: Param[int] = IntParam( + "min_support_count", + "Minimal support count. MIN_ITEM_COUNT has no " + + "effect when less than or equal to 0, The default value is -1.", + -1) + + MAX_PATTERN_LENGTH: Param[int] = FloatParam( + "max_pattern_length", + "Max frequent pattern length.", + 10, + ParamValidators.gt(0)) + + def __init__(self, java_params): + super(_FPGrowthParams, self).__init__(java_params) + + def set_items_col(self, value: str): + return typing.cast(_FPGrowthParams, self.set(self.ITEMS_COL, value)) + + def get_items_col(self) -> str: + return self.get(self.ITEMS_COL) + + def set_field_delimiter(self, value: str): + return typing.cast(_FPGrowthParams, self.set(self.FIELD_DELIMITER, value)) + + def get_field_delimiter(self) -> str: + return self.get(self.FIELD_DELIMITER) + + def set_min_lift(self, value: float): + return typing.cast(_FPGrowthParams, self.set(self.MIN_LIFT, value)) + + def get_min_lift(self) -> float: + return self.get(self.MIN_LIFT) + + def set_min_confidence(self, value: float): + return typing.cast(_FPGrowthParams, self.set(self.MIN_CONFIDENCE, value)) + + def get_min_confidence(self) -> float: + return self.get(self.MIN_CONFIDENCE) + + def set_min_support(self, value: float): + return typing.cast(_FPGrowthParams, self.set(self.MIN_SUPPORT, value)) + + def get_min_support(self) -> float: + return self.get(self.MIN_SUPPORT) + + def set_min_support_count(self, value: int): + return typing.cast(_FPGrowthParams, self.set(self.MIN_SUPPORT_COUNT, value)) + + def get_min_support_count(self) -> int: + return self.get(self.MIN_SUPPORT_COUNT) + + def set_max_pattern_length(self, value: int): + return typing.cast(_FPGrowthParams, self.set(self.MAX_PATTERN_LENGTH, value)) + + def get_max_pattern_length(self) -> int: + return self.get(self.MAX_PATTERN_LENGTH) + + @property + def items_col(self) -> str: + return self.get_items_col() + + @property + def field_delimiter(self) -> str: + return self.get_field_delimiter() + + @property + def min_lift(self) -> float: + return self.get_min_lift() + + @property + def min_confidence(self) -> float: + return self.get_min_confidence() + + @property + def min_support(self) -> float: + return self.get_min_support() + + @property + def min_support_count(self) -> int: + return self.get_min_support_count() + + @property + def max_pattern_length(self) -> int: + return self.get_max_pattern_length() + + +class FPGrowth(JavaRecommendationAlgoOperator, _FPGrowthParams): + """ + An implementation of parallel FP-growth algorithm to mine frequent itemset. + +

For detail descriptions, please refer to: Han et al., Mining frequent patterns without + candidate generation. Li et al., PFP: + Parallel FP-growth for query recommendation + """ + + def __init__(self, java_algo_operator=None): + super(FPGrowth, self).__init__(java_algo_operator) + + @classmethod + def _java_algo_operator_package_name(cls) -> str: + return "fpgrowth" + + @classmethod + def _java_algo_operator_class_name(cls) -> str: + return "FPGrowth" diff --git a/flink-ml-python/pyflink/ml/recommendation/tests/__init__.py b/flink-ml-python/pyflink/ml/recommendation/tests/__init__.py index 66981919e..d91d4836e 100644 --- a/flink-ml-python/pyflink/ml/recommendation/tests/__init__.py +++ b/flink-ml-python/pyflink/ml/recommendation/tests/__init__.py @@ -22,7 +22,7 @@ # Because the project and the dependent `pyflink` project have the same directory structure, # we need to manually add `flink-ml-python` path to `sys.path` in the test of this project to change # the order of package search. -flink_ml_python_dir = Path(__file__).parents[5] +flink_ml_python_dir = Path(__file__).parents[4] sys.path.append(str(flink_ml_python_dir)) import pyflink diff --git a/flink-ml-python/pyflink/ml/recommendation/tests/test_fpgrowth.py b/flink-ml-python/pyflink/ml/recommendation/tests/test_fpgrowth.py new file mode 100644 index 000000000..ea682d29d --- /dev/null +++ b/flink-ml-python/pyflink/ml/recommendation/tests/test_fpgrowth.py @@ -0,0 +1,133 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ +from pyflink.common import Types +from pyflink.table import Table +from typing import List + +from pyflink.ml.recommendation.fpgrowth import FPGrowth +from pyflink.ml.tests.test_utils import PyFlinkMLTestCase + + +# Tests Swing. 
+class SwingTest(PyFlinkMLTestCase): + def setUp(self): + super(SwingTest, self).setUp() + self.input_table = self.t_env.from_data_stream( + self.env.from_collection([ + ("A,B,C,D",), + ("B,C,E",), + ("A,B,C,E",), + ("B,D,E",), + ("A,B,C,D",) + ], + type_info=Types.ROW_NAMED( + ['items'], + [Types.STRING()]) + )) + + self.expected_patterns = [ + ["A", 3, 1], + ["B", 5, 1], + ["B,A", 3, 2], + ["B,C", 4, 2], + ["B,C,A", 3, 3], + ["B,D", 3, 2], + ["B,E", 3, 2], + ["C", 4, 1], + ["C,A", 3, 2], + ["D", 3, 1], + ["E", 3, 1] + ] + + self.expected_rules = [ + ["A=>B", 2, 1.0, 0.6, 1.0, 3], + ["A=>C", 2, 1.25, 0.6, 1.0, 3], + ["B=>A", 2, 1.0, 0.6, 0.6, 3], + ["B=>C", 2, 1.0, 0.8, 0.8, 4], + ["B=>D", 2, 1.0, 0.6, 0.6, 3], + ["B=>E", 2, 1.0, 0.6, 0.6, 3], + ["B,A=>C", 3, 1.25, 0.6, 1.0, 3], + ["B,C=>A", 3, 1.25, 0.6, 0.75, 3], + ["C=>A", 2, 1.25, 0.6, 0.75, 3], + ["C=>B", 2, 1.0, 0.8, 1.0, 4], + ["C,A=>B", 3, 1.0, 0.6, 1.0, 3], + ["D=>B", 2, 1.0, 0.6, 1.0, 3], + ["E=>B", 2, 1.0, 0.6, 1.0, 3] + ] + + def test_param(self): + fpg = FPGrowth() + self.assertEqual("items", fpg.items_col) + self.assertEqual(",", fpg.field_delimiter) + self.assertAlmostEqual(1.0, fpg.min_lift, delta=1e-9) + self.assertAlmostEqual(0.6, fpg.min_confidence, delta=1e-9) + self.assertAlmostEqual(0.02, fpg.min_support, delta=1e-9) + self.assertEqual(-1, fpg.min_support_count) + self.assertEqual(10, fpg.max_pattern_length) + + fpg.set_items_col("values") \ + .set_field_delimiter(" ") \ + .set_min_lift(1.2) \ + .set_min_confidence(0.7) \ + .set_min_support(0.01) \ + .set_min_support_count(50) \ + .set_max_pattern_length(5) + + self.assertEqual("values", fpg.items_col) + self.assertEqual(" ", fpg.field_delimiter) + self.assertAlmostEqual(1.2, fpg.min_lift, delta=1e-9) + self.assertAlmostEqual(0.7, fpg.min_confidence, delta=1e-9) + self.assertAlmostEqual(0.01, fpg.min_support, delta=1e-9) + self.assertEqual(50, fpg.min_support_count) + self.assertEqual(5, fpg.max_pattern_length) + + def test_output_schema(self): + fpg = FPGrowth() + output_tables = fpg.transform(self.input_table) + self.assertEqual( + ["items", "support_count", "item_count"], + output_tables[0].get_schema().get_field_names()) + self.assertEqual( + ["rule", "item_count", "lift", "support_percent", + "confidence_percent", "transaction_count"], + output_tables[1].get_schema().get_field_names()) + + def test_transform(self): + fpg = FPGrowth().set_min_support(0.6) + output_tables = fpg.transform(self.input_table) + self.verify_output_result(output_tables[0], self.expected_patterns) + self.verify_output_result(output_tables[1], self.expected_rules) + + def test_save_load_and_transform(self): + fpg = FPGrowth().set_min_support_count(3) + reloaded_swing = self.save_and_reload(fpg) + output_tables = reloaded_swing.transform(self.input_table) + self.verify_output_result(output_tables[0], self.expected_patterns) + self.verify_output_result(output_tables[1], self.expected_rules) + + def verify_output_result( + self, output: Table, + expected_result: List): + collected_results = [result for result in + self.t_env.to_data_stream(output).execute_and_collect()] + results = [] + for result in collected_results: + results.append([item for item in result]) + results.sort(key=lambda x: x[0]) + expected_result.sort(key=lambda x: x[0]) + self.assertEqual(expected_result, results)