diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/CosineProc.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/CosineProc.java index a5c049f15..05bd9b6c6 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/CosineProc.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/CosineProc.java @@ -79,7 +79,9 @@ public Stream cosine( Stream stream = generateWeightedStream(configuration, inputs, similarityCutoff, topN, topK, computer); boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0; - return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty); + boolean writeParallel = configuration.get("writeParallel", false); + + return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel); } private SimilarityComputer similarityComputer(Double skipValue) { diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/EuclideanProc.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/EuclideanProc.java index 215e3c97a..26b3b6095 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/EuclideanProc.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/EuclideanProc.java @@ -80,7 +80,8 @@ public Stream euclidean( Stream stream = generateWeightedStream(configuration, inputs, similarityCutoff, topN, topK, computer); boolean write = configuration.isWriteFlag(false); // && similarityCutoff != 0.0; - return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty ); + boolean writeParallel = configuration.get("writeParallel", false); + return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel ); } Stream generateWeightedStream(ProcedureConfiguration configuration, WeightedInput[] inputs, diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/JaccardProc.java 
b/algo/src/main/java/org/neo4j/graphalgo/similarity/JaccardProc.java index 2658c708e..d3a4b5480 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/JaccardProc.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/JaccardProc.java @@ -67,7 +67,9 @@ public Stream jaccard( similarityCutoff, getTopK(configuration)), getTopN(configuration)); boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0; - return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty ); + boolean writeParallel = configuration.get("writeParallel", false); + + return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel); } private SimilarityComputer similarityComputer() { diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/OverlapProc.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/OverlapProc.java index 427988627..c9c148a6b 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/OverlapProc.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/OverlapProc.java @@ -73,7 +73,9 @@ public Stream overlap( Stream stream = topN(similarityStream(inputs, computer, configuration, () -> null, similarityCutoff, getTopK(configuration)), getTopN(configuration)); boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0; - return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty); + boolean writeParallel = configuration.get("writeParallel", false); + + return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel); } diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/ParallelSimilarityExporter.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/ParallelSimilarityExporter.java new file mode 100644 index 000000000..b38fd9880 --- /dev/null +++ 
b/algo/src/main/java/org/neo4j/graphalgo/similarity/ParallelSimilarityExporter.java
@@ -0,0 +1,181 @@
+/**
+ * Copyright (c) 2017 "Neo4j, Inc." <http://neo4j.com>
+ *
+ * This file is part of Neo4j Graph Algorithms <http://github.com/neo4j-contrib/neo4j-graph-algorithms>.
+ *
+ * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.neo4j.graphalgo.similarity;
+
+import com.carrotsearch.hppc.IntHashSet;
+import com.carrotsearch.hppc.IntSet;
+import org.neo4j.graphalgo.core.IdMap;
+import org.neo4j.graphalgo.core.WeightMap;
+import org.neo4j.graphalgo.core.heavyweight.AdjacencyMatrix;
+import org.neo4j.graphalgo.core.heavyweight.HeavyGraph;
+import org.neo4j.graphalgo.core.utils.*;
+import org.neo4j.graphalgo.core.utils.dss.DisjointSetStruct;
+import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
+import org.neo4j.graphalgo.impl.DSSResult;
+import org.neo4j.graphalgo.impl.GraphUnionFind;
+import org.neo4j.graphdb.Direction;
+import org.neo4j.internal.kernel.api.exceptions.EntityNotFoundException;
+import org.neo4j.internal.kernel.api.exceptions.InvalidTransactionTypeKernelException;
+import org.neo4j.internal.kernel.api.exceptions.KernelException;
+import org.neo4j.internal.kernel.api.exceptions.explicitindex.AutoIndexingKernelException;
+import org.neo4j.kernel.api.KernelTransaction;
+import org.neo4j.kernel.internal.GraphDatabaseAPI;
+import org.neo4j.logging.Log;
+import org.neo4j.values.storable.Values;
+
+import java.util.*;
+import java.util.concurrent.*;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * Exporter that groups the similarity pairs into union-find partitions and writes each
+ * partition's internal relationships from a parallel stream; relationships whose target lies
+ * outside the source's partition are queued and written sequentially afterwards.
+ */
+public class ParallelSimilarityExporter extends SimilarityExporter {
+
+    private final int nodeCount;
+
+    public ParallelSimilarityExporter(GraphDatabaseAPI api,
+                                      Log log,
+                                      String relationshipType,
+                                      String propertyName, int nodeCount) {
+        super(api, log, propertyName, relationshipType);
+        this.nodeCount = nodeCount;
+    }
+
+    /**
+     * Builds an in-memory graph from the pairs, partitions it with union-find and writes the
+     * relationships of every partition in batches.
+     *
+     * @param similarityPairs relationships to create; the stream is consumed by this call
+     * @param batchSize       batch size handed on to {@code writeSequential}
+     * @return number of batches written, in parallel plus sequentially
+     * @throws IllegalStateException if a thread is interrupted while queueing relationships
+     */
+    @Override
+    public int export(Stream<SimilarityResult> similarityPairs, long batchSize) {
+        IdMap idMap = new IdMap(this.nodeCount);
+        AdjacencyMatrix adjacencyMatrix = new AdjacencyMatrix(this.nodeCount, false, AllocationTracker.EMPTY);
+        WeightMap weightMap = new WeightMap(nodeCount, 0, propertyId);
+
+        int[] numberOfRelationships = {0};
+
+        similarityPairs.forEach(pair -> {
+            int id1 = idMap.mapOrGet(pair.item1);
+            int id2 = idMap.mapOrGet(pair.item2);
+            adjacencyMatrix.addOutgoing(id1, id2);
+            weightMap.put(RawValues.combineIntInt(id1, id2), pair.similarity);
+            numberOfRelationships[0]++;
+        });
+
+        idMap.buildMappedIds();
+        HeavyGraph graph = new HeavyGraph(idMap, adjacencyMatrix, weightMap, Collections.emptyMap());
+
+        DisjointSetStruct dssResult = computePartitions(graph);
+
+        Stream<List<DisjointSetStruct.InternalResult>> partitions = groupPartitions(graph, dssResult);
+
+        int numberOfPartitions = dssResult.getSetCount();
+        if (numberOfPartitions == 0) {
+            return 0;
+        }
+
+        log.info("ParallelSimilarityExporter: Relationships to be created: %d, Partitions found: %d", numberOfRelationships[0], numberOfPartitions);
+        // One slot per partition and at most one put() per partition: put() never waits for space.
+        ArrayBlockingQueue<List<SimilarityResult>> outQueue = new ArrayBlockingQueue<>(numberOfPartitions);
+
+        AtomicInteger inQueueBatchCount = new AtomicInteger(0);
+        partitions.parallel().forEach(partition -> {
+            IntSet nodesInPartition = new IntHashSet();
+            for (DisjointSetStruct.InternalResult internalResult : partition) {
+                nodesInPartition.add(internalResult.internalNodeId);
+            }
+
+            List<SimilarityResult> inPartition = new ArrayList<>();
+            List<SimilarityResult> outPartition = new ArrayList<>();
+
+            for (DisjointSetStruct.InternalResult result : partition) {
+                int nodeId = result.internalNodeId;
+                graph.forEachRelationship(nodeId, Direction.OUTGOING, (sourceNodeId, targetNodeId, relationId, weight) -> {
+                    SimilarityResult similarityRelationship = new SimilarityResult(idMap.toOriginalNodeId(sourceNodeId),
+                            idMap.toOriginalNodeId(targetNodeId), -1, -1, -1, weight);
+
+                    if (nodesInPartition.contains(targetNodeId)) {
+                        inPartition.add(similarityRelationship);
+                    } else {
+                        outPartition.add(similarityRelationship);
+                    }
+
+                    // NOTE(review): consumer contract is "true = continue iterating"; false would ask
+                    // the graph to stop after the node's first relationship - confirm against the API.
+                    return true;
+                });
+            }
+
+            if (!inPartition.isEmpty()) {
+                int inQueueBatches = writeSequential(inPartition.stream(), batchSize);
+                inQueueBatchCount.addAndGet(inQueueBatches);
+            }
+
+            if (!outPartition.isEmpty()) {
+                put(outQueue, outPartition);
+            }
+        });
+
+        int inQueueBatches = inQueueBatchCount.get();
+        int outQueueBatches = writeSequential(outQueue.stream().flatMap(Collection::stream), batchSize);
+
+        log.info("ParallelSimilarityExporter: Batch Size: %d, Batches written - in parallel: %d, sequentially: %d", batchSize, inQueueBatches, outQueueBatches);
+
+        return inQueueBatches + outQueueBatches;
+    }
+
+    /** Groups the per-node union-find results by set id; every group is one partition. */
+    private Stream<List<DisjointSetStruct.InternalResult>> groupPartitions(HeavyGraph graph, DisjointSetStruct dssResult) {
+        return dssResult.internalResultStream(graph)
+                .collect(Collectors.groupingBy(item -> item.setId))
+                .values()
+                .stream();
+    }
+
+    /**
+     * Adds {@code items} to {@code queue}. An interrupt is not swallowed: the flag is restored
+     * and the export fails, because dropping the items would silently lose relationships.
+     */
+    private static <T> void put(BlockingQueue<T> queue, T items) {
+        try {
+            queue.put(items);
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            throw new IllegalStateException("Interrupted while queueing relationships for sequential write", e);
+        }
+    }
+
+    /** Runs union-find over the graph and calls {@code release()} on the algorithm before returning. */
+    private DisjointSetStruct computePartitions(HeavyGraph graph) {
+        GraphUnionFind algo = new GraphUnionFind(graph);
+        DisjointSetStruct struct = algo.compute();
+        algo.release();
+        return struct;
+    }
+
+}
diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/PearsonProc.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/PearsonProc.java index 91a27784c..35a62d44e 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/PearsonProc.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/PearsonProc.java @@ -80,7 +80,9 @@ public Stream pearson( Stream stream = generateWeightedStream(configuration, inputs, similarityCutoff, topN, topK, computer); boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0; - return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty); + boolean writeParallel = configuration.get("writeParallel", false); + + return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel); } private SimilarityComputer similarityComputer(Double skipValue) { diff --git 
a/algo/src/main/java/org/neo4j/graphalgo/similarity/SequentialSimilarityExporter.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/SequentialSimilarityExporter.java
new file mode 100644
index 000000000..c50eaff2a
--- /dev/null
+++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/SequentialSimilarityExporter.java
@@ -0,0 +1,52 @@
+/**
+ * Copyright (c) 2017 "Neo4j, Inc." <http://neo4j.com>
+ *
+ * This file is part of Neo4j Graph Algorithms <http://github.com/neo4j-contrib/neo4j-graph-algorithms>.
+ *
+ * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.neo4j.graphalgo.similarity;
+
+import org.neo4j.kernel.internal.GraphDatabaseAPI;
+import org.neo4j.logging.Log;
+
+import java.util.stream.Stream;
+
+/**
+ * Exporter that hands the whole stream of similarity pairs to {@code writeSequential}.
+ */
+public class SequentialSimilarityExporter extends SimilarityExporter {
+
+    /**
+     * @param nodeCount not used by this exporter; the parameter is accepted but never read
+     *                  (presumably to mirror ParallelSimilarityExporter's constructor)
+     */
+    public SequentialSimilarityExporter(GraphDatabaseAPI api,
+                                        Log log, String relationshipType,
+                                        String propertyName, int nodeCount) {
+        super(api, log, propertyName, relationshipType);
+    }
+
+    /**
+     * Writes the pairs in batches and logs the batch size and the number of batches.
+     *
+     * @return number of batches reported by {@code writeSequential}
+     */
+    @Override
+    public int export(Stream<SimilarityResult> similarityPairs, long batchSize) {
+        int batches = writeSequential(similarityPairs, batchSize);
+        log.info("SequentialSimilarityExporter: Batch Size: %d, Batches written - sequentially: %d", batchSize, batches);
+        return batches;
+    }
+}
diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityExporter.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityExporter.java index c8b9deee7..aa74f6ca8 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityExporter.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityExporter.java @@ -1,3 +1,4 @@ +package org.neo4j.graphalgo.similarity; /** * Copyright (c) 2017 "Neo4j, Inc." * @@ -16,7 +17,6 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . 
*/ -package org.neo4j.graphalgo.similarity; import org.neo4j.graphalgo.core.utils.ExceptionUtil; import org.neo4j.graphalgo.core.utils.StatementApi; @@ -26,6 +26,7 @@ import org.neo4j.internal.kernel.api.exceptions.explicitindex.AutoIndexingKernelException; import org.neo4j.kernel.api.KernelTransaction; import org.neo4j.kernel.internal.GraphDatabaseAPI; +import org.neo4j.logging.Log; import org.neo4j.values.storable.Values; import java.util.ArrayList; @@ -33,21 +34,37 @@ import java.util.List; import java.util.stream.Stream; -public class SimilarityExporter extends StatementApi { - - private final int propertyId; +public abstract class SimilarityExporter extends StatementApi { + final Log log; + final int propertyId; private final int relationshipTypeId; - public SimilarityExporter(GraphDatabaseAPI api, - String relationshipType, - String propertyName) { + SimilarityExporter(GraphDatabaseAPI api, Log log, String propertyName, String relationshipType) { super(api); + this.log = log; propertyId = getOrCreatePropertyId(propertyName); - relationshipTypeId = getOrCreateRelationshipId(relationshipType); + relationshipTypeId = getOrCreateRelationshipTypeId(relationshipType); + } + + private int getOrCreateRelationshipTypeId(String relationshipType) { + return applyInTransaction(stmt -> stmt + .tokenWrite() + .relationshipTypeGetOrCreateForName(relationshipType)); + } + + private int getOrCreatePropertyId(String propertyName) { + return applyInTransaction(stmt -> stmt + .tokenWrite() + .propertyKeyGetOrCreateForName(propertyName)); } - public void export(Stream similarityPairs, long batchSize) { - writeSequential(similarityPairs, batchSize); + private void createRelationship(SimilarityResult similarityResult, KernelTransaction statement) throws EntityNotFoundException, InvalidTransactionTypeKernelException, AutoIndexingKernelException { + long node1 = similarityResult.item1; + long node2 = similarityResult.item2; + long relationshipId = 
statement.dataWrite().relationshipCreate(node1, relationshipTypeId, node2); + + statement.dataWrite().relationshipSetProperty( + relationshipId, propertyId, Values.doubleValue(similarityResult.similarity)); } private void export(SimilarityResult similarityResult) {
@@ -76,36 +93,30 @@ private void export(List<SimilarityResult> similarityResults) {
 
     }
 
-    private void createRelationship(SimilarityResult similarityResult, KernelTransaction statement) throws EntityNotFoundException, InvalidTransactionTypeKernelException, AutoIndexingKernelException {
-        long node1 = similarityResult.item1;
-        long node2 = similarityResult.item2;
-        long relationshipId = statement.dataWrite().relationshipCreate(node1, relationshipTypeId, node2);
-
-        statement.dataWrite().relationshipSetProperty(
-                relationshipId, propertyId, Values.doubleValue(similarityResult.similarity));
-    }
-
-    private int getOrCreateRelationshipId(String relationshipType) {
-        return applyInTransaction(stmt -> stmt
-                .tokenWrite()
-                .relationshipTypeGetOrCreateForName(relationshipType));
-    }
-
-    private int getOrCreatePropertyId(String propertyName) {
-        return applyInTransaction(stmt -> stmt
-                .tokenWrite()
-                .propertyKeyGetOrCreateForName(propertyName));
-    }
-
-    private void writeSequential(Stream<SimilarityResult> similarityPairs, long batchSize) {
+    /**
+     * Writes the pairs through {@code export}: one call per pair when {@code batchSize} is 1,
+     * otherwise one call per batch taken from the stream's iterator.
+     *
+     * @return number of {@code export} calls that were made
+     */
+    int writeSequential(Stream<SimilarityResult> similarityPairs, long batchSize) {
+        int[] counter = {0};
         if (batchSize == 1) {
-            similarityPairs.forEach(this::export);
+            similarityPairs.forEach(similarityResult -> {
+                export(similarityResult);
+                counter[0]++;
+            });
         } else {
             Iterator<SimilarityResult> iterator = similarityPairs.iterator();
-            do {
-                export(take(iterator, Math.toIntExact(batchSize)));
-            } while (iterator.hasNext());
+            int maxBatchLength = Math.toIntExact(batchSize);
+            // Check before taking: an empty stream must not hand an empty batch to export.
+            while (iterator.hasNext()) {
+                export(take(iterator, maxBatchLength));
+                counter[0]++;
+            }
         }
+
+        return counter[0];
     }
 
     private static List<SimilarityResult> take(Iterator<SimilarityResult> iterator, int batchSize) {
@@ -116,5 +127,5 @@ private static List take(Iterator iterator, return result; } - + abstract int export(Stream 
similarityPairs, long batchSize); } diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java index 4420419e8..6ac6eac0f 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java @@ -25,16 +25,13 @@ import org.HdrHistogram.DoubleHistogram; import org.neo4j.graphalgo.core.ProcedureConfiguration; import org.neo4j.graphalgo.core.ProcedureConstants; -import org.neo4j.graphalgo.core.utils.ParallelUtil; -import org.neo4j.graphalgo.core.utils.Pools; -import org.neo4j.graphalgo.core.utils.QueueBasedSpliterator; -import org.neo4j.graphalgo.core.utils.TerminationFlag; +import org.neo4j.graphalgo.core.utils.*; import org.neo4j.graphalgo.impl.util.TopKConsumer; import org.neo4j.graphdb.Result; import org.neo4j.kernel.api.KernelTransaction; import org.neo4j.kernel.internal.GraphDatabaseAPI; import org.neo4j.logging.Log; -import org.neo4j.procedure.*; +import org.neo4j.procedure.Context; import java.util.*; import java.util.concurrent.ArrayBlockingQueue; @@ -48,7 +45,6 @@ import static org.neo4j.graphalgo.impl.util.TopKConsumer.topK; import static org.neo4j.graphalgo.similarity.Weights.REPEAT_CUTOFF; -import static org.neo4j.helpers.collection.MapUtil.map; public class SimilarityProc { @Context
@@ -96,7 +92,81 @@ Long getWriteBatchSize(ProcedureConfiguration configuration) {
         return configuration.get("writeBatchSize", 10000L);
     }
 
-    Stream<SimilaritySummaryResult> writeAndAggregateResults(Stream<SimilarityResult> stream, int length, ProcedureConfiguration configuration, boolean write, String writeRelationshipType, String writeProperty) {
+
+    /**
+     * Collects the values that make up a {@link SimilaritySummaryResult}; every setter returns
+     * {@code this} so calls can be chained. Static: it reads nothing from the enclosing class.
+     */
+    public static class SimilarityResultBuilder {
+        protected long writeDuration = -1;
+        protected boolean write = false;
+        private int nodes;
+        private String writeRelationshipType;
+        private String writeProperty;
+        private AtomicLong similarityPairs;
+        private DoubleHistogram histogram;
+
+        public SimilarityResultBuilder withWriteDuration(long writeDuration) {
+            this.writeDuration = writeDuration;
+            return this;
+        }
+
+        /** Alias of {@link #write(boolean)}. */
+        public SimilarityResultBuilder withWrite(boolean write) {
+            this.write = write;
+            return this;
+        }
+
+        /**
+         * Starts an AutoCloseable timer that measures the time until it gets closed and
+         * passes the measured duration to {@link #withWriteDuration(long)}.
+         *
+         * @return the started timer
+         */
+        public ProgressTimer timeWrite() {
+            return ProgressTimer.start(this::withWriteDuration);
+        }
+
+        /** Creates the summary from the values collected so far. */
+        public SimilaritySummaryResult build() {
+            return SimilaritySummaryResult.from(nodes, similarityPairs, writeRelationshipType, writeProperty, write, histogram, writeDuration);
+        }
+
+        public SimilarityResultBuilder nodes(int nodes) {
+            this.nodes = nodes;
+            return this;
+        }
+
+        public SimilarityResultBuilder write(boolean write) {
+            this.write = write;
+            return this;
+        }
+
+        public SimilarityResultBuilder writeRelationshipType(String writeRelationshipType) {
+            this.writeRelationshipType = writeRelationshipType;
+            return this;
+        }
+
+        public SimilarityResultBuilder writeProperty(String writeProperty) {
+            this.writeProperty = writeProperty;
+            return this;
+        }
+
+        public SimilarityResultBuilder similarityPairs(AtomicLong similarityPairs) {
+            this.similarityPairs = similarityPairs;
+            return this;
+        }
+
+        public SimilarityResultBuilder histogram(DoubleHistogram histogram) {
+            this.histogram = histogram;
+            return this;
+        }
+    }
+
+    Stream<SimilaritySummaryResult> writeAndAggregateResults(Stream<SimilarityResult> stream, int length, ProcedureConfiguration configuration, boolean write, String writeRelationshipType, String writeProperty, boolean writeParallel) {
+        SimilarityResultBuilder builder = new SimilarityResultBuilder();
+        builder.nodes(length).write(write).writeRelationshipType(writeRelationshipType).writeProperty(writeProperty);
+
         long writeBatchSize = getWriteBatchSize(configuration);
         AtomicLong similarityPairs = new AtomicLong();
         DoubleHistogram histogram = new DoubleHistogram(5);
@@ -106,18 +176,32 @@ Stream writeAndAggregateResults(Stream 
emptyStream(String writeRelationshipType, String writeProperty) { return Stream.of(SimilaritySummaryResult.from(0, new AtomicLong(0), writeRelationshipType, - writeProperty, false, new DoubleHistogram(5))); + writeProperty, false, new DoubleHistogram(5), -1)); } Double getSimilarityCutoff(ProcedureConfiguration configuration) { @@ -246,7 +325,7 @@ CategoricalInput[] prepareCategories(List> data, long degree WeightedInput[] prepareWeights(Object rawData, ProcedureConfiguration configuration, Double skipValue) throws Exception { if (ProcedureConstants.CYPHER_QUERY.equals(configuration.getGraphName("dense"))) { - return prepareSparseWeights(api, (String) rawData, skipValue, configuration); + return prepareSparseWeights(api, (String) rawData, skipValue, configuration); } else { List> data = (List>) rawData; return preparseDenseWeights(data, getDegreeCutoff(configuration), skipValue); @@ -354,7 +433,7 @@ int getTopN(ProcedureConfiguration configuration) { } private Supplier createDecoderFactory(String graphType, int size) { - if(ProcedureConstants.CYPHER_QUERY.equals(graphType)) { + if (ProcedureConstants.CYPHER_QUERY.equals(graphType)) { return () -> new RleDecoder(size); } diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilaritySummaryResult.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilaritySummaryResult.java index 560673d8e..276f8142a 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilaritySummaryResult.java +++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/SimilaritySummaryResult.java @@ -24,6 +24,7 @@ public class SimilaritySummaryResult { + public final long writeMillis; public final long nodes; public final long similarityPairs; public final boolean write; @@ -46,7 +47,7 @@ public SimilaritySummaryResult(long nodes, long similarityPairs, boolean write, String writeRelationshipType, String writeProperty, double min, double max, double mean, double stdDev, double p25, double p50, double p75, double p90, 
double p95, - double p99, double p999, double p100) { + double p99, double p999, double p100, long writeMillis) { this.nodes = nodes; this.similarityPairs = similarityPairs; this.write = write; @@ -64,9 +65,10 @@ public SimilaritySummaryResult(long nodes, long similarityPairs, this.p99 = p99; this.p999 = p999; this.p100 = p100; + this.writeMillis = writeMillis; } - static SimilaritySummaryResult from(long length, AtomicLong similarityPairs, String writeRelationshipType, String writeProperty, boolean write, DoubleHistogram histogram) { + static SimilaritySummaryResult from(long length, AtomicLong similarityPairs, String writeRelationshipType, String writeProperty, boolean write, DoubleHistogram histogram, long writeMillis) { return new SimilaritySummaryResult( length, similarityPairs.get(), @@ -84,7 +86,8 @@ static SimilaritySummaryResult from(long length, AtomicLong similarityPairs, Str histogram.getValueAtPercentile(95D), histogram.getValueAtPercentile(99D), histogram.getValueAtPercentile(99.9D), - histogram.getValueAtPercentile(100D) + histogram.getValueAtPercentile(100D), + writeMillis ); } } diff --git a/core/src/main/java/org/neo4j/graphalgo/core/utils/dss/DisjointSetStruct.java b/core/src/main/java/org/neo4j/graphalgo/core/utils/dss/DisjointSetStruct.java index 8bca215b4..c6758b599 100644 --- a/core/src/main/java/org/neo4j/graphalgo/core/utils/dss/DisjointSetStruct.java +++ b/core/src/main/java/org/neo4j/graphalgo/core/utils/dss/DisjointSetStruct.java @@ -121,6 +121,12 @@ public Stream resultStream(IdMapping idMapping) { find(mappedId))); } + + public Stream internalResultStream(IdMapping idMapping) { + + return IntStream.range(IdMapping.START_NODE_ID, (int) idMapping.nodeCount()) + .mapToObj(mappedId -> new InternalResult(mappedId, find(mappedId))); + } /** * element (node) count * @@ -353,6 +359,29 @@ public Result(long nodeId, long setId) { } } + public static class InternalResult { + + /** + * the mapped node id + */ + public final int internalNodeId; + + 
/** + * set id + */ + public final long setId; + + public InternalResult(int internalNodeId, int setId) { + this.internalNodeId = internalNodeId; + this.setId = (long) setId; + } + + public InternalResult(int internalNodeId, long setId) { + this.internalNodeId = internalNodeId; + this.setId = setId; + } + } + public final static class Translator implements PropertyTranslator.OfInt { public static final PropertyTranslator INSTANCE = new Translator(); diff --git a/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/JaccardTest.java b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/JaccardTest.java index 000d15ac0..e85c4a70e 100644 --- a/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/JaccardTest.java +++ b/tests/src/test/java/org/neo4j/graphalgo/algo/similarity/JaccardTest.java @@ -269,7 +269,7 @@ public void simpleJaccardFromEmbeddingTest() { @Test public void simpleJaccardWriteTest() { - Map params = map("config", map( "write",true, "similarityCutoff", 0.1)); + Map params = map("config", map( "write",true, "similarityCutoff", 0.1, "writeParallel", true)); db.execute(STATEMENT,params).close(); diff --git a/tests/src/test/java/org/neo4j/graphalgo/similarity/SimilarityExporterTest.java b/tests/src/test/java/org/neo4j/graphalgo/similarity/SimilarityExporterTest.java new file mode 100644 index 000000000..1102cc0c5 --- /dev/null +++ b/tests/src/test/java/org/neo4j/graphalgo/similarity/SimilarityExporterTest.java
@@ -0,0 +1,233 @@
+package org.neo4j.graphalgo.similarity;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.neo4j.graphdb.Transaction;
+import org.neo4j.kernel.internal.GraphDatabaseAPI;
+import org.neo4j.logging.Log;
+import org.neo4j.logging.NullLog;
+import org.neo4j.test.rule.ImpermanentDatabaseRule;
+
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.MethodType;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.hamcrest.collection.IsCollectionWithSize.hasSize;
+import static org.hamcrest.core.IsCollectionContaining.hasItems;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+
+/**
+ * Runs the same write scenarios against every {@link SimilarityExporter} implementation.
+ */
+@RunWith(Parameterized.class)
+public class SimilarityExporterTest {
+    @Rule
+    public final ImpermanentDatabaseRule DB = new ImpermanentDatabaseRule();
+
+    private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
+    private static final MethodType CTOR_METHOD = MethodType.methodType(
+            void.class,
+            GraphDatabaseAPI.class,
+            Log.class,
+            String.class,
+            String.class,
+            int.class);
+
+    private static final String RELATIONSHIP_TYPE = "SIMILAR";
+    private static final String PROPERTY_NAME = "score";
+    private SimilarityExporter exporter;
+    private GraphDatabaseAPI api;
+    private Class<? extends SimilarityExporter> similarityExporterFactory;
+
+    @Parameterized.Parameters(name = "{1}")
+    public static Collection<Object[]> data() {
+        return Arrays.asList(
+                new Object[]{SequentialSimilarityExporter.class, "Sequential"},
+                new Object[]{ParallelSimilarityExporter.class, "Parallel"}
+        );
+    }
+
+    @Before
+    public void setup() {
+        api = DB.getGraphDatabaseAPI();
+    }
+
+    public SimilarityExporterTest(Class<? extends SimilarityExporter> similarityExporterFactory,
+                                  String ignoreParamOnlyForTestNaming) throws Throwable {
+        this.similarityExporterFactory = similarityExporterFactory;
+    }
+
+    public SimilarityExporter load(Class<? extends SimilarityExporter> factoryType, int nodeCount) throws Throwable {
+        final MethodHandle constructor = findConstructor(factoryType);
+        return (SimilarityExporter) constructor.invoke(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME, nodeCount);
+    }
+
+    private MethodHandle findConstructor(Class<?> factoryType) {
+        try {
+            return LOOKUP.findConstructor(factoryType, CTOR_METHOD);
+        } catch (NoSuchMethodException | IllegalAccessException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Test
+    public void createNothing() throws Throwable {
+        int nodeCount = 2;
+        createNodes(api, nodeCount);
+        exporter = load(similarityExporterFactory, nodeCount);
+
+        Stream<SimilarityResult> similarityPairs = Stream.empty();
+
+        int batches = exporter.export(similarityPairs, 1);
+        assertEquals(0, batches);
+
+        try (Transaction tx = api.beginTx()) {
+            List<SimilarityRelationship> allRelationships = getSimilarityRelationships(api);
+            assertThat(allRelationships, hasSize(0));
+        }
+    }
+
+    @Test
+    public void createOneRelationship() throws Throwable {
+        int nodeCount = 2;
+        createNodes(api, nodeCount);
+        exporter = load(similarityExporterFactory, nodeCount);
+
+        Stream<SimilarityResult> similarityPairs = Stream.of(new SimilarityResult(0, 1, -1, -1, -1, 0.5));
+
+        int batches = exporter.export(similarityPairs, 1);
+        assertEquals(1, batches);
+
+        try (Transaction tx = api.beginTx()) {
+            List<SimilarityRelationship> allRelationships = getSimilarityRelationships(api);
+            assertThat(allRelationships, hasSize(1));
+            assertThat(allRelationships, hasItems(new SimilarityRelationship(0, 1, 0.5)));
+        }
+    }
+
+    @Test
+    public void multipleBatches() throws Throwable {
+        int nodeCount = 4;
+        createNodes(api, nodeCount);
+        exporter = load(similarityExporterFactory, nodeCount);
+
+        Stream<SimilarityResult> similarityPairs = Stream.of(
+                new SimilarityResult(0, 1, -1, -1, -1, 0.5),
+                new SimilarityResult(2, 3, -1, -1, -1, 0.7)
+        );
+
+        int batches = exporter.export(similarityPairs, 1);
+        assertEquals(2, batches);
+
+        try (Transaction tx = api.beginTx()) {
+            List<SimilarityRelationship> allRelationships = getSimilarityRelationships(api);
+
+            assertThat(allRelationships, hasSize(2));
+            assertThat(allRelationships, hasItems(new SimilarityRelationship(0, 1, 0.5)));
+            assertThat(allRelationships, hasItems(new SimilarityRelationship(2, 3, 0.7)));
+        }
+    }
+
+    @Test
+    public void smallerThanBatchSize() throws Throwable {
+        int nodeCount = 5;
+        createNodes(api, nodeCount);
+        exporter = load(similarityExporterFactory, nodeCount);
+
+        Stream<SimilarityResult> similarityPairs = Stream.of(
+                new SimilarityResult(0, 1, -1, -1, -1, 0.5),
+                new SimilarityResult(1, 2, -1, -1, -1, 0.6),
+                new SimilarityResult(2, 3, -1, -1, -1, 0.7),
+                new SimilarityResult(3, 4, -1, -1, -1, 0.8)
+        );
+
+        int batches = exporter.export(similarityPairs, 10);
+        assertEquals(1, batches);
+
+        try (Transaction tx = api.beginTx()) {
+            List<SimilarityRelationship> allRelationships = getSimilarityRelationships(api);
+
+            assertThat(allRelationships, hasSize(4));
+            assertThat(allRelationships, hasItems(new SimilarityRelationship(0, 1, 0.5)));
+            assertThat(allRelationships, hasItems(new SimilarityRelationship(1, 2, 0.6)));
+            assertThat(allRelationships, hasItems(new SimilarityRelationship(2, 3, 0.7)));
+            assertThat(allRelationships, hasItems(new SimilarityRelationship(3, 4, 0.8)));
+        }
+    }
+
+    @Test
+    public void disconnectedUpdates() throws Throwable {
+        int nodeCount = 6;
+        createNodes(api, nodeCount);
+        exporter = load(similarityExporterFactory, nodeCount);
+
+        Stream<SimilarityResult> similarityPairs = Stream.of(
+                new SimilarityResult(0, 1, -1, -1, -1, 0.5),
+                new SimilarityResult(2, 3, -1, -1, -1, 0.7),
+                new SimilarityResult(4, 5, -1, -1, -1, 0.9)
+        );
+
+        exporter.export(similarityPairs, 10);
+
+        try (Transaction tx = api.beginTx()) {
+            List<SimilarityRelationship> allRelationships = getSimilarityRelationships(api);
+
+            assertThat(allRelationships, hasSize(3));
+            assertThat(allRelationships, hasItems(new SimilarityRelationship(0, 1, 0.5)));
+            assertThat(allRelationships, hasItems(new SimilarityRelationship(2, 3, 0.7)));
+            assertThat(allRelationships, hasItems(new SimilarityRelationship(4, 5, 0.9)));
+        }
+    }
+
+    private List<SimilarityRelationship> getSimilarityRelationships(GraphDatabaseAPI api) {
+        return api.getAllRelationships().stream()
+                .map(relationship -> new SimilarityRelationship(relationship.getStartNodeId(), relationship.getEndNodeId(), (double)relationship.getProperty(PROPERTY_NAME)))
+                .collect(Collectors.toList());
+    }
+
+    static class SimilarityRelationship {
+        private final long startNodeId;
+        private final long endNodeId;
+        private final double property;
+
+        SimilarityRelationship(long startNodeId, long endNodeId, double property) {
+            this.startNodeId = startNodeId;
+            this.endNodeId = endNodeId;
+            this.property = property;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            SimilarityRelationship that = (SimilarityRelationship) o;
+            return startNodeId == that.startNodeId &&
+                    endNodeId == that.endNodeId &&
+                    Double.compare(that.property, property) == 0;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(startNodeId, endNodeId, property);
+        }
+    }
+
+    private void createNodes(GraphDatabaseAPI api, int nodeCount) {
+        try (Transaction tx = api.beginTx()) {
+            for(int i = 0; i < nodeCount; i++) {
+                api.createNode();
+            }
+            tx.success();
+        }
+    }
+}