Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit 7306ccf

Browse files
mknblch authored and mneedham committed
Huge (parallel, undirected, unweighted) louvain (#550)
* adapt trianglecount to work with >2bn nodes * adapt testcases * adapt writeback methods * set writemillis * issue 542 * mv specialized translators into their datatype classes * WIP * WIP * refactoring & jdoc * keep it object * fix test * debugging * fix test * fix test
1 parent 528e9da commit 7306ccf

File tree

12 files changed

+429
-288
lines changed

12 files changed

+429
-288
lines changed

algo/src/main/java/org/neo4j/graphalgo/LouvainProc.java

Lines changed: 22 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.neo4j.graphalgo.core.utils.ProgressLogger;
2828
import org.neo4j.graphalgo.core.utils.ProgressTimer;
2929
import org.neo4j.graphalgo.core.utils.TerminationFlag;
30+
import org.neo4j.graphalgo.core.utils.paged.LongArray;
3031
import org.neo4j.graphalgo.core.write.Exporter;
3132
import org.neo4j.graphalgo.core.write.Translators;
3233
import org.neo4j.graphalgo.impl.louvain.*;
@@ -49,8 +50,6 @@ public class LouvainProc {
4950
public static final String CONFIG_CLUSTER_PROPERTY = "writeProperty";
5051
public static final String DEFAULT_CLUSTER_PROPERTY = "community";
5152

52-
public static final int DEFAULT_ITERATIONS = 5;
53-
5453
@Context
5554
public GraphDatabaseAPI api;
5655

@@ -82,7 +81,9 @@ public Stream<LouvainResult> louvain(
8281

8382
builder.withNodeCount(graph.nodeCount());
8483

85-
final LouvainAlgorithm louvain = louvain(graph, configuration);
84+
final LouvainAlgorithm louvain = LouvainAlgorithm.instance(graph, configuration)
85+
.withProgressLogger(ProgressLogger.wrap(log, "Louvain"))
86+
.withTerminationFlag(TerminationFlag.wrap(transaction));
8687

8788
// evaluation
8889
try (ProgressTimer timer = builder.timeEval()) {
@@ -114,7 +115,9 @@ public Stream<WeightedLouvain.Result> louvainStream(
114115
.overrideRelationshipTypeOrQuery(relationship);
115116

116117
// evaluation
117-
return louvain(graph(configuration), configuration)
118+
return LouvainAlgorithm.instance(graph(configuration), configuration)
119+
.withProgressLogger(ProgressLogger.wrap(log, "Louvain"))
120+
.withTerminationFlag(TerminationFlag.wrap(transaction))
118121
.compute()
119122
.resultStream();
120123

@@ -145,39 +148,23 @@ public Graph graph(ProcedureConfiguration config) {
145148
.load(graphImpl);
146149
}
147150

148-
public LouvainAlgorithm louvain(Graph graph, ProcedureConfiguration config) {
149-
150-
if (graph instanceof HugeGraph) {
151-
if (config.hasWeightProperty()) {
152-
return new WeightedLouvain(graph, Pools.DEFAULT, config.getConcurrency(), config.getIterations(DEFAULT_ITERATIONS))
153-
.withProgressLogger(ProgressLogger.wrap(log, "ModularityCommunityDetection"))
154-
.withTerminationFlag(TerminationFlag.wrap(transaction));
155-
}
156-
return new Louvain(graph, Pools.DEFAULT, config.getConcurrency(), config.getIterations(DEFAULT_ITERATIONS))
157-
.withProgressLogger(ProgressLogger.wrap(log, "Louvain"))
158-
.withTerminationFlag(TerminationFlag.wrap(transaction));
159-
}
160-
161-
return new ParallelLouvain(graph,
162-
graph,
163-
graph,
164-
Pools.DEFAULT,
165-
config.getConcurrency(),
166-
config.getIterations(DEFAULT_ITERATIONS))
167-
.withProgressLogger(ProgressLogger.wrap(log, "Louvain(deprecated)"))
168-
.withTerminationFlag(TerminationFlag.wrap(transaction));
169-
}
170-
171-
private void write(Graph graph, int[] communities, ProcedureConfiguration configuration) {
151+
private void write(Graph graph, Object communities, ProcedureConfiguration configuration) {
172152
log.debug("Writing results");
173-
Exporter.of(api, graph)
153+
final Exporter exporter = Exporter.of(api, graph)
174154
.withLog(log)
175155
.parallel(Pools.DEFAULT, configuration.getConcurrency(), TerminationFlag.wrap(transaction))
176-
.build()
177-
.write(
178-
configuration.get(CONFIG_CLUSTER_PROPERTY, DEFAULT_CLUSTER_PROPERTY),
179-
communities,
180-
Translators.INT_ARRAY_TRANSLATOR
181-
);
156+
.build();
157+
158+
if (communities instanceof int[]) {
159+
exporter.write(
160+
configuration.get(CONFIG_CLUSTER_PROPERTY, DEFAULT_CLUSTER_PROPERTY),
161+
(int[]) communities,
162+
Translators.INT_ARRAY_TRANSLATOR);
163+
} else if (communities instanceof LongArray) {
164+
exporter.write(
165+
configuration.get(CONFIG_CLUSTER_PROPERTY, DEFAULT_CLUSTER_PROPERTY),
166+
(LongArray) communities,
167+
LongArray.Translator.INSTANCE);
168+
}
182169
}
183170
}
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
/**
2+
* Copyright (c) 2017 "Neo4j, Inc." <http://neo4j.com>
3+
*
4+
* This file is part of Neo4j Graph Algorithms <http://github.com/neo4j-contrib/neo4j-graph-algorithms>.
5+
*
6+
* Neo4j Graph Algorithms is free software: you can redistribute it and/or modify
7+
* it under the terms of the GNU General Public License as published by
8+
* the Free Software Foundation, either version 3 of the License, or
9+
* (at your option) any later version.
10+
*
11+
* This program is distributed in the hope that it will be useful,
12+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14+
* GNU General Public License for more details.
15+
*
16+
* You should have received a copy of the GNU General Public License
17+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
18+
*/
19+
package org.neo4j.graphalgo.impl.louvain;
20+
21+
22+
import org.neo4j.graphalgo.api.HugeGraph;
23+
import org.neo4j.graphalgo.core.utils.ParallelUtil;
24+
import org.neo4j.graphalgo.core.utils.ProgressLogger;
25+
import org.neo4j.graphalgo.core.utils.TerminationFlag;
26+
import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
27+
import org.neo4j.graphalgo.core.utils.paged.DoubleArray;
28+
import org.neo4j.graphalgo.core.utils.paged.LongArray;
29+
import org.neo4j.graphalgo.core.utils.paged.PagedSimpleBitSet;
30+
import org.neo4j.graphalgo.impl.Algorithm;
31+
import org.neo4j.graphdb.Direction;
32+
33+
import java.util.ArrayList;
34+
import java.util.concurrent.ExecutorService;
35+
import java.util.concurrent.atomic.AtomicLong;
36+
import java.util.concurrent.atomic.LongAdder;
37+
import java.util.concurrent.locks.ReentrantReadWriteLock;
38+
import java.util.stream.LongStream;
39+
import java.util.stream.Stream;
40+
41+
/**
42+
* Parallel modularity based community detection algo
43+
*
44+
* @author mknblch
45+
*/
46+
public class HugeParallelLouvain extends Algorithm<HugeParallelLouvain> implements LouvainAlgorithm {
47+
48+
/**
49+
* pool
50+
*/
51+
private ExecutorService executorService;
52+
/**
53+
* number of threads to use
54+
*/
55+
private final int concurrency;
56+
/**
57+
* node cound
58+
*/
59+
private final long nodeCount;
60+
/**
61+
* graph
62+
*/
63+
private HugeGraph graph;
64+
/**
65+
* incrementing node counter
66+
*/
67+
private final AtomicLong queue;
68+
/**
69+
* R&W Locks
70+
*/
71+
private final ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock();
72+
private final ReentrantReadWriteLock.WriteLock writeLock = readWriteLock.writeLock();
73+
/**
74+
* task array for parallel execution
75+
*/
76+
private final ArrayList<Task> tasks = new ArrayList<>();
77+
/**
78+
* memory tracker
79+
*/
80+
private final AllocationTracker tracker;
81+
/**
82+
* community weight. Sum of degrees of nodes
83+
* within a cluster
84+
*/
85+
private DoubleArray communityWeights;
86+
/**
87+
* pre calculated values
88+
*/
89+
private double m2, mq2;
90+
/**
91+
* node to community id mapping
92+
*/
93+
private LongArray communityIds;
94+
/**
95+
* number of iterations so far
96+
*/
97+
private int iterations;
98+
/**
99+
* maximum number of iterations
100+
*/
101+
private final int maxIterations;
102+
103+
public HugeParallelLouvain(HugeGraph graph,
104+
ExecutorService executorService,
105+
AllocationTracker tracker,
106+
int concurrency,
107+
int maxIterations) {
108+
this.graph = graph;
109+
nodeCount = graph.nodeCount();
110+
this.executorService = executorService;
111+
this.concurrency = concurrency;
112+
this.maxIterations = maxIterations;
113+
communityIds = LongArray.newArray(nodeCount, tracker);
114+
communityWeights = DoubleArray.newArray(nodeCount, tracker);
115+
this.queue = new AtomicLong(0);
116+
this.tracker = tracker;
117+
118+
}
119+
120+
/**
121+
* cluster id's until either max iterations is reached or no further
122+
* changes could improve modularity
123+
*/
124+
public LouvainAlgorithm compute() {
125+
reset();
126+
for (this.iterations = 0; this.iterations < maxIterations; this.iterations++) {
127+
queue.set(0);
128+
ParallelUtil.runWithConcurrency(concurrency, tasks, getTerminationFlag(), executorService);
129+
boolean changes = false;
130+
for (Task task : tasks) {
131+
changes |= task.changes;
132+
}
133+
if (!changes) {
134+
return this;
135+
}
136+
}
137+
return this;
138+
}
139+
140+
@Override
141+
public HugeParallelLouvain release() {
142+
graph = null;
143+
executorService = null;
144+
communityIds = null;
145+
communityWeights = null;
146+
return this;
147+
}
148+
149+
private void reset() {
150+
151+
tasks.clear();
152+
for (int i = 0; i < concurrency; i++) {
153+
tasks.add(new Task());
154+
}
155+
156+
communityIds.setAll(i -> i);
157+
final LongAdder adder = new LongAdder();
158+
ParallelUtil.iterateParallelHuge(executorService, nodeCount, concurrency, node -> {
159+
final int d = graph.degree(node, Direction.OUTGOING);
160+
communityWeights.set(node, d);
161+
adder.add(d);
162+
});
163+
/**
164+
* we can iterate over outgoing rels only because
165+
* the graph should be treated as undirected and
166+
* therefore have to multiply m by 4 instead of 2
167+
*/
168+
m2 = adder.intValue() * 4.0; // 2m
169+
mq2 = 2.0 * Math.pow(adder.intValue(), 2.0); // 2m^2
170+
}
171+
172+
/**
173+
* assign node to community
174+
* @param node nodeId
175+
* @param targetCommunity communityId
176+
*/
177+
private void assign(long node, long targetCommunity) {
178+
final int d = graph.degree(node, Direction.OUTGOING);
179+
writeLock.lock();
180+
try {
181+
final long index = communityIds.get(node);
182+
communityWeights.add(index, -d);
183+
communityWeights.add(targetCommunity, d);
184+
// update communityIds
185+
communityIds.set(node, targetCommunity);
186+
} finally {
187+
writeLock.unlock();
188+
}
189+
}
190+
191+
/**
192+
* @return kiIn
193+
*/
194+
private int kIIn(long node, long targetCommunity) {
195+
int[] sum = {0}; // {ki, ki_in}
196+
graph.forEachRelationship(node, Direction.OUTGOING, (sourceNodeId, targetNodeId) -> {
197+
if (targetCommunity == communityIds.get(targetNodeId)) {
198+
sum[0]++;
199+
}
200+
return true;
201+
});
202+
203+
return sum[0];
204+
}
205+
206+
public Stream<Result> resultStream() {
207+
return LongStream.range(0, nodeCount)
208+
.mapToObj(i ->
209+
new Result(graph.toOriginalNodeId(i), communityIds.get(i)));
210+
}
211+
212+
public LongArray getCommunityIds() {
213+
return communityIds;
214+
}
215+
216+
public int getIterations() {
217+
return iterations;
218+
}
219+
220+
public long getCommunityCount() {
221+
final PagedSimpleBitSet bitSet = PagedSimpleBitSet.newBitSet(nodeCount, tracker);
222+
for (long i = 0; i < communityIds.size(); i++) {
223+
bitSet.put(communityIds.get(i));
224+
}
225+
return bitSet._size();
226+
}
227+
228+
@Override
229+
public HugeParallelLouvain me() {
230+
return this;
231+
}
232+
233+
private class Task implements Runnable {
234+
235+
private boolean changes = false;
236+
private double bestGain;
237+
private long bestCommunity;
238+
private final ReentrantReadWriteLock.ReadLock readLock = readWriteLock.readLock();
239+
private final TerminationFlag flag = getTerminationFlag();
240+
private final ProgressLogger logger = getProgressLogger();
241+
242+
@Override
243+
public void run() {
244+
changes = false;
245+
for (long node; (node = queue.getAndIncrement()) < nodeCount && flag.running(); ) {
246+
bestGain = 0.0;
247+
readLock.lock();
248+
final long sourceCommunity = bestCommunity = communityIds.get(node);
249+
final double mSource = (communityWeights.get(sourceCommunity) * graph.degree(node, Direction.OUTGOING)) / mq2;
250+
readLock.unlock();
251+
252+
graph.forEachRelationship(node, Direction.OUTGOING, (sourceNodeId, targetNodeId) -> {
253+
readLock.lock();
254+
final long targetCommunity = communityIds.get(targetNodeId);
255+
final double gain = kIIn(sourceNodeId, targetCommunity) / m2 - mSource;
256+
readLock.unlock();
257+
if (gain > bestGain) {
258+
bestCommunity = targetCommunity;
259+
bestGain = gain;
260+
}
261+
return flag.running();
262+
});
263+
if (bestCommunity != sourceCommunity) {
264+
assign(node, bestCommunity);
265+
changes = true;
266+
}
267+
logger.logProgress(node, nodeCount - 1);
268+
}
269+
}
270+
}
271+
272+
}

0 commit comments

Comments
 (0)