Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit e26d75b

Browse files
authored
3.4 louvain backward (#737)
* intermediate communities must be explicitly requested on streaming side * write side
1 parent d245f99 commit e26d75b

File tree

6 files changed

+149
-24
lines changed

6 files changed

+149
-24
lines changed

algo/src/main/java/org/neo4j/graphalgo/LouvainProc.java

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@
4343
*/
4444
public class LouvainProc {
4545

46-
public static final String CONFIG_CLUSTER_PROPERTY = "writeProperty";
47-
public static final String DEFAULT_CLUSTER_PROPERTY = "communities";
46+
public static final String INTERMEDIATE_COMMUNITIES_WRITE_PROPERTY = "intermediateCommunitiesWriteProperty";
47+
public static final String DEFAULT_CLUSTER_PROPERTY = "community";
48+
public static final String INCLUDE_INTERMEDIATE_COMMUNITIES = "includeIntermediateCommunities";
4849

4950
@Context
5051
public GraphDatabaseAPI api;
@@ -93,7 +94,7 @@ public Stream<LouvainResult> louvain(
9394
}
9495

9596
if (configuration.isWriteFlag()) {
96-
builder.timeWrite(() -> write(graph, louvain.getDendrogram(), configuration));
97+
builder.timeWrite(() -> write(graph, louvain.getDendrogram(), louvain.getCommunityIds(), configuration));
9798
}
9899

99100
return Stream.of(builder.build());
@@ -125,7 +126,7 @@ public Stream<Louvain.StreamingResult> louvainStream(
125126
return Stream.empty();
126127
}
127128

128-
return louvain.dendrogramStream();
129+
return louvain.dendrogramStream(configuration.get(INCLUDE_INTERMEDIATE_COMMUNITIES, false));
129130
}
130131

131132
public Graph graph(ProcedureConfiguration config) {
@@ -138,16 +139,18 @@ public Graph graph(ProcedureConfiguration config) {
138139
.load(config.getGraphImpl());
139140
}
140141

141-
private void write(Graph graph, int[][] communities, ProcedureConfiguration configuration) {
142+
private void write(Graph graph, int[][] allCommunities, int[] finalCommunities, ProcedureConfiguration configuration) {
142143
log.debug("Writing results");
144+
boolean includeIntermediateCommunities = configuration.get(INCLUDE_INTERMEDIATE_COMMUNITIES, false);
143145

144146
new LouvainCommunityExporter(
145147
api,
146148
Pools.DEFAULT,
147149
configuration.getConcurrency(),
148150
graph,
149-
communities[0].length,
150-
configuration.getWriteProperty(DEFAULT_CLUSTER_PROPERTY))
151-
.export(communities);
151+
allCommunities[0].length,
152+
configuration.getWriteProperty(DEFAULT_CLUSTER_PROPERTY),
153+
configuration.get(INTERMEDIATE_COMMUNITIES_WRITE_PROPERTY, "communities"))
154+
.export(allCommunities, finalCommunities, includeIntermediateCommunities);
152155
}
153156
}

algo/src/main/java/org/neo4j/graphalgo/impl/louvain/Louvain.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,18 @@ public Stream<Result> resultStream() {
209209
.mapToObj(i -> new Result(i, communities[i]));
210210
}
211211

212-
public Stream<StreamingResult> dendrogramStream() {
212+
public Stream<StreamingResult> dendrogramStream(boolean includeIntermediateCommunities) {
213213
return IntStream.range(0, rootNodeCount)
214214
.mapToObj(i -> {
215-
List<Long> communities = new ArrayList<>(dendrogram.length);
216-
for (int[] community : dendrogram) {
217-
communities.add((long) community[i]);
215+
List<Long> communitiesList = null;
216+
if (includeIntermediateCommunities) {
217+
communitiesList = new ArrayList<>(dendrogram.length);
218+
for (int[] community : dendrogram) {
219+
communitiesList.add((long) community[i]);
220+
}
218221
}
219222

220-
return new StreamingResult(root.toOriginalNodeId(i), communities);
223+
return new StreamingResult(root.toOriginalNodeId(i), communitiesList, communities[i]);
221224
});
222225
}
223226

@@ -272,11 +275,13 @@ public Result(long id, long community) {
272275
public static final class StreamingResult {
273276
public final long nodeId;
274277
public final List<Long> communities;
278+
public final long community;
275279

276-
public StreamingResult(long nodeId, List<Long> communities) {
280+
public StreamingResult(long nodeId, List<Long> communities, long community) {
277281

278282
this.nodeId = nodeId;
279283
this.communities = communities;
284+
this.community = community;
280285
}
281286
}
282287
}

algo/src/main/java/org/neo4j/graphalgo/impl/louvain/LouvainCommunityExporter.java

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,62 +40,96 @@ public class LouvainCommunityExporter extends StatementApi {
4040
private final int concurrency;
4141
private final IdMapping mapping;
4242
private final int nodeCount;
43+
private final int intermediateCommunitiesPropertyId;
4344
private Integer propertyId;
4445

4546
public LouvainCommunityExporter(GraphDatabaseAPI api,
4647
ExecutorService pool,
4748
int concurrency,
4849
IdMapping mapping,
4950
int nodeCount,
50-
String propertyName) {
51+
String propertyName,
52+
String intermediateCommunitiesPropertyName) {
5153
super(api);
5254
this.pool = pool;
5355
this.concurrency = concurrency;
5456
this.mapping = mapping;
5557
this.nodeCount = nodeCount;
5658

57-
propertyId = applyInTransaction(statement -> statement.tokenWrite().propertyKeyGetOrCreateForName(propertyName));
58-
59+
propertyId = applyInTransaction(statement -> statement.tokenWrite().propertyKeyGetOrCreateForName(propertyName));
60+
intermediateCommunitiesPropertyId = applyInTransaction(statement -> statement.tokenWrite().propertyKeyGetOrCreateForName(intermediateCommunitiesPropertyName));
5961
}
6062

61-
public void export(int[][] communities) {
63+
public void export(int[][] communities, int[] finalCommunities, boolean includeIntermediateCommunities) {
6264
final Collection<PrimitiveIntIterable> batchIterables = ParallelUtil.batchIterables(concurrency, nodeCount);
6365
final ArrayList<Runnable> tasks = new ArrayList<>();
64-
batchIterables.forEach(it -> tasks.add(new NodeBatchExporter(it, communities)));
66+
batchIterables.forEach(it -> tasks.add(new NodeBatchExporter(it, communities, finalCommunities, includeIntermediateCommunities)));
6567
ParallelUtil.run(tasks, pool);
6668
}
6769

6870
private class NodeBatchExporter implements Runnable {
6971

7072
private final PrimitiveIntIterable iterable;
71-
private final int[][] communities;
73+
private final int[][] allCommunities;
74+
private final int[] finalCommunities;
75+
private final boolean includeIntermediateCommunities;
7276

73-
private NodeBatchExporter(PrimitiveIntIterable iterable, int[][] communities) {
77+
private NodeBatchExporter(PrimitiveIntIterable iterable, int[][] allCommunities, int[] finalCommunities, boolean includeIntermediateCommunities) {
7478
this.iterable = iterable;
75-
this.communities = communities;
79+
this.allCommunities = allCommunities;
80+
this.finalCommunities = finalCommunities;
81+
this.includeIntermediateCommunities = includeIntermediateCommunities;
7682
}
7783

7884
@Override
7985
public void run() {
86+
if (includeIntermediateCommunities) {
87+
writeEverything();
88+
} else {
89+
onlyWriteFinalCommunities();
90+
}
91+
}
92+
93+
private void writeEverything() {
8094
acceptInTransaction(statement -> {
8195
final Write dataWriteOperations = statement.dataWrite();
8296
for(PrimitiveIntIterator it = iterable.iterator(); it.hasNext(); ) {
8397
final int id = it.next();
8498
// build int array
85-
final int[] data = new int[communities.length];
99+
final int[] data = new int[allCommunities.length];
86100
for (int i = 0; i < data.length; i++) {
87101
try {
88-
data[i] = communities[i][id];
102+
data[i] = allCommunities[i][id];
89103
} catch (Exception e) {
90104
throw e; // TODO
91105
}
92106
}
107+
93108
dataWriteOperations.nodeSetProperty(
94109
mapping.toOriginalNodeId(id),
95110
propertyId,
111+
Values.intValue(finalCommunities[id]));
112+
113+
dataWriteOperations.nodeSetProperty(
114+
mapping.toOriginalNodeId(id),
115+
intermediateCommunitiesPropertyId,
96116
Values.intArray(data));
97117
}
98118
});
99119
}
120+
121+
private void onlyWriteFinalCommunities() {
122+
acceptInTransaction(statement -> {
123+
final Write dataWriteOperations = statement.dataWrite();
124+
for(PrimitiveIntIterator it = iterable.iterator(); it.hasNext(); ) {
125+
final int id = it.next();
126+
127+
dataWriteOperations.nodeSetProperty(
128+
mapping.toOriginalNodeId(id),
129+
propertyId,
130+
Values.intValue(finalCommunities[id]));
131+
}
132+
});
133+
}
100134
}
101135
}

doc/asciidoc/louvain.adoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ YIELD nodes, communityCount, iterations, loadMillis, computeMillis, writeMillis
151151
| write | boolean | true | yes | Specifies if the result should be written back as a node property
152152
| writeProperty | string | 'community' | yes | The property name written back to the ID of the community that particular node belongs to
153153
| defaultValue | float | null | yes | The default value of the weight in case it is missing or invalid
154+
| includeIntermediateCommunities | boolean | false | yes | Specifies whether an array of intermediate communities should be returned
155+
| intermediateCommunitiesWriteProperty | string | 'communities' | yes | The property name written back to the ID of the intermediate communities that particular node belongs to
154156
| concurrency | int | available CPUs | yes | The number of concurrent threads
155157
| graph | string | 'heavy' | yes | Use 'heavy' when describing the subset of the graph with label and relationship-type parameter. Use 'cypher' for describing the subset with cypher node-statement and relationship-statement
156158
|===

tests/src/test/java/org/neo4j/graphalgo/algo/LouvainClusteringIntegrationTest.java

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@
2828

2929
import java.util.Arrays;
3030
import java.util.List;
31+
import java.util.concurrent.atomic.AtomicInteger;
32+
import java.util.concurrent.atomic.AtomicLong;
3133

34+
import static junit.framework.TestCase.assertNull;
3235
import static org.junit.Assert.assertEquals;
3336
import static org.junit.Assert.assertTrue;
3437

@@ -114,6 +117,32 @@ public void test() {
114117
@Test
115118
public void testStream() {
116119
final String cypher = "CALL algo.louvain.stream('', '', {concurrency:1}) " +
120+
"YIELD nodeId, community, communities";
121+
final IntIntScatterMap testMap = new IntIntScatterMap();
122+
DB.execute(cypher).accept(row -> {
123+
final long community = (long) row.get("community");
124+
System.out.println(community);
125+
testMap.addTo((int) community, 1);
126+
return false;
127+
});
128+
assertEquals(3, testMap.size());
129+
}
130+
131+
@Test
132+
public void testStreamNoIntermediateCommunitiesByDefault() {
133+
final String cypher = "CALL algo.louvain.stream('', '', {concurrency:1}) " +
134+
"YIELD nodeId, community, communities";
135+
final IntIntScatterMap testMap = new IntIntScatterMap();
136+
DB.execute(cypher).accept(row -> {
137+
Object communities = row.get("communities");
138+
assertNull(communities);
139+
return false;
140+
});
141+
}
142+
143+
@Test
144+
public void testStreamIncludingIntermediateCommunities() {
145+
final String cypher = "CALL algo.louvain.stream('', '', {concurrency:1, includeIntermediateCommunities: true}) " +
117146
"YIELD nodeId, communities";
118147
final IntIntScatterMap testMap = new IntIntScatterMap();
119148
DB.execute(cypher).accept(row -> {
@@ -125,6 +154,56 @@ public void testStream() {
125154
assertEquals(3, testMap.size());
126155
}
127156

157+
@Test
158+
public void testWrite() {
159+
final String cypher = "CALL algo.louvain('', '', {concurrency:1})";
160+
final IntIntScatterMap testMap = new IntIntScatterMap();
161+
DB.execute(cypher);
162+
163+
String readQuery = "MATCH (n) RETURN n.community AS community";
164+
165+
DB.execute(readQuery).accept(row -> {
166+
final int community = (int) row.get("community");
167+
testMap.addTo(community, 1);
168+
return true;
169+
});
170+
171+
assertEquals(3, testMap.size());
172+
}
173+
174+
@Test
175+
public void testWriteIncludingIntermediateCommunities() {
176+
final String cypher = "CALL algo.louvain('', '', {concurrency:1, includeIntermediateCommunities: true})";
177+
final IntIntScatterMap testMap = new IntIntScatterMap();
178+
DB.execute(cypher);
179+
180+
String readQuery = "MATCH (n) RETURN n.communities AS communities";
181+
182+
DB.execute(readQuery).accept(row -> {
183+
final long community = ((int[]) row.get("communities"))[0];
184+
testMap.addTo((int) community, 1);
185+
return true;
186+
});
187+
188+
assertEquals(3, testMap.size());
189+
}
190+
191+
@Test
192+
public void testWriteNoIntermediateCommunitiesByDefault() {
193+
final String cypher = "CALL algo.louvain('', '', {concurrency:1})";
194+
DB.execute(cypher);
195+
196+
final AtomicLong testInteger = new AtomicLong(0);
197+
String readQuery = "MATCH (n) WHERE not(exists(n.communities)) RETURN count(*) AS count";
198+
DB.execute(readQuery).accept(row -> {
199+
long count = (long) row.get("count");
200+
testInteger.set(count);
201+
return false;
202+
});
203+
204+
assertEquals(9, testInteger.get());
205+
}
206+
128207
@Test
129208
public void testWithLabelRel() {
130209
final String cypher = "CALL algo.louvain('Node', 'TYPE', {concurrency:1}) " +

tests/src/test/java/org/neo4j/graphalgo/impl/LouvainMultiLevelTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ public void testComplex() throws Exception {
137137
}
138138
System.out.println("level " + i + ": " + Arrays.toString(dendogram[i - 1]));
139139
}
140+
140141
assertArrayEquals(new int[]{0, 0, 0, 1, 1, 1, 2, 2, 2}, dendogram[0]);
142+
assertArrayEquals(new int[]{0, 0, 0, 1, 1, 1, 2, 2, 2}, algorithm.getCommunityIds());
141143
}
142144
}

0 commit comments

Comments
 (0)