Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit 824304e

Browse files
committed
Neighbors pref attachment (#817)
* added function for preferential attachment * add total neighbours
1 parent 06b4aee commit 824304e

File tree

5 files changed

+442
-32
lines changed

5 files changed

+442
-32
lines changed

algo/src/main/java/org/neo4j/graphalgo/linkprediction/LinkPrediction.java

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@
2121
import org.neo4j.graphalgo.core.ProcedureConfiguration;
2222
import org.neo4j.graphdb.Direction;
2323
import org.neo4j.graphdb.Node;
24-
import org.neo4j.graphdb.Relationship;
2524
import org.neo4j.graphdb.RelationshipType;
2625
import org.neo4j.kernel.internal.GraphDatabaseAPI;
2726
import org.neo4j.procedure.Context;
2827
import org.neo4j.procedure.Description;
2928
import org.neo4j.procedure.Name;
3029
import org.neo4j.procedure.UserFunction;
3130

32-
import java.util.*;
31+
import java.util.Map;
32+
import java.util.Set;
3333

3434
public class LinkPrediction {
3535
@Context
@@ -50,7 +50,7 @@ public double adamicAdarSimilarity(@Name("node1") Node node1, @Name("node2") Nod
5050
RelationshipType relationshipType = configuration.getRelationship();
5151
Direction direction = configuration.getDirection(Direction.BOTH);
5252

53-
Set<Node> neighbors = new CommonNeighborsFinder(api).findCommonNeighbors(node1, node2, relationshipType, direction);
53+
Set<Node> neighbors = new NeighborsFinder(api).findCommonNeighbors(node1, node2, relationshipType, direction);
5454
return neighbors.stream().mapToDouble(nb -> 1.0 / Math.log(degree(relationshipType, direction, nb))).sum();
5555
}
5656

@@ -69,7 +69,7 @@ public double resourceAllocationSimilarity(@Name("node1") Node node1, @Name("nod
6969
RelationshipType relationshipType = configuration.getRelationship();
7070
Direction direction = configuration.getDirection(Direction.BOTH);
7171

72-
Set<Node> neighbors = new CommonNeighborsFinder(api).findCommonNeighbors(node1, node2, relationshipType, direction);
72+
Set<Node> neighbors = new NeighborsFinder(api).findCommonNeighbors(node1, node2, relationshipType, direction);
7373
return neighbors.stream().mapToDouble(nb -> 1.0 / degree(relationshipType, direction, nb)).sum();
7474
}
7575

@@ -86,10 +86,47 @@ public double commonNeighbors(@Name("node1") Node node1, @Name("node2") Node nod
8686
RelationshipType relationshipType = configuration.getRelationship();
8787
Direction direction = configuration.getDirection(Direction.BOTH);
8888

89-
Set<Node> neighbors = new CommonNeighborsFinder(api).findCommonNeighbors(node1, node2, relationshipType, direction);
89+
Set<Node> neighbors = new NeighborsFinder(api).findCommonNeighbors(node1, node2, relationshipType, direction);
9090
return neighbors.size();
9191
}
9292

93+
@UserFunction("algo.linkprediction.preferentialAttachment")
94+
@Description("algo.linkprediction.preferentialAttachment(node1:Node, node2:Node, {relationshipQuery:'relationshipName', direction:'BOTH'}) " +
95+
"given two nodes, calculate Preferential Attachment")
96+
public double preferentialAttachment(@Name("node1") Node node1, @Name("node2") Node node2,
97+
@Name(value = "config", defaultValue = "{}") Map<String, Object> config) {
98+
if (node1 == null || node2 == null) {
99+
throw new RuntimeException("Nodes must not be null");
100+
}
101+
102+
ProcedureConfiguration configuration = ProcedureConfiguration.create(config);
103+
RelationshipType relationshipType = configuration.getRelationship();
104+
Direction direction = configuration.getDirection(Direction.BOTH);
105+
106+
return getDegree(node1, relationshipType, direction) * getDegree(node2, relationshipType, direction);
107+
}
108+
109+
@UserFunction("algo.linkprediction.totalNeighbors")
110+
@Description("algo.linkprediction.totalNeighbors(node1:Node, node2:Node, {relationshipQuery:'relationshipName', direction:'BOTH'}) " +
111+
"given two nodes, calculate Total Neighbors")
112+
public double totalNeighbors(@Name("node1") Node node1, @Name("node2") Node node2,
113+
@Name(value = "config", defaultValue = "{}") Map<String, Object> config) {
114+
ProcedureConfiguration configuration = ProcedureConfiguration.create(config);
115+
RelationshipType relationshipType = configuration.getRelationship();
116+
Direction direction = configuration.getDirection(Direction.BOTH);
117+
118+
NeighborsFinder neighborsFinder = new NeighborsFinder(api);
119+
120+
Set<Node> neighbors = neighborsFinder.findNeighbors(node1, relationshipType, direction);
121+
neighbors.addAll(neighborsFinder.findNeighbors(node2, relationshipType, direction));
122+
123+
return neighbors.size();
124+
}
125+
126+
private int getDegree(Node node, RelationshipType relationshipType, Direction direction) {
127+
return relationshipType == null ? node.getDegree(direction) : node.getDegree(relationshipType, direction);
128+
}
129+
93130

94131
private int degree(RelationshipType relationshipType, Direction direction, Node node) {
95132
return relationshipType == null ? node.getDegree(direction) : node.getDegree(relationshipType, direction);

algo/src/main/java/org/neo4j/graphalgo/linkprediction/CommonNeighborsFinder.java renamed to algo/src/main/java/org/neo4j/graphalgo/linkprediction/NeighborsFinder.java

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212

1313
import static org.neo4j.graphdb.Direction.*;
1414

15-
public class CommonNeighborsFinder {
15+
public class NeighborsFinder {
1616

1717
private GraphDatabaseAPI api;
1818

19-
public CommonNeighborsFinder(GraphDatabaseAPI api) {
19+
public NeighborsFinder(GraphDatabaseAPI api) {
2020
this.api = api;
2121
}
2222

@@ -25,23 +25,19 @@ public Set<Node> findCommonNeighbors(Node node1, Node node2, RelationshipType re
2525
return Collections.emptySet();
2626
}
2727

28-
Set<Node> neighbors = findPotentialNeighbors(node1, relationshipType, direction);
28+
Set<Node> neighbors = findNeighbors(node1, relationshipType, direction);
2929
neighbors.removeIf(node -> noCommonNeighbors(node, relationshipType, flipDirection(direction), node2));
3030
return neighbors;
3131
}
3232

33-
private Direction flipDirection(Direction direction) {
34-
switch(direction) {
35-
case OUTGOING:
36-
return INCOMING;
37-
case INCOMING:
38-
return OUTGOING;
39-
default:
40-
return BOTH;
41-
}
33+
public Set<Node> findNeighbors(Node node1, Node node2, RelationshipType relationshipType, Direction direction) {
34+
Set<Node> node1Neighbors = findNeighbors(node1, relationshipType, direction);
35+
Set<Node> node2Neighbors = findNeighbors(node2, relationshipType, direction);
36+
node1Neighbors.addAll(node2Neighbors);
37+
return node1Neighbors;
4238
}
4339

44-
private Set<Node> findPotentialNeighbors(Node node, RelationshipType relationshipType, Direction direction) {
40+
public Set<Node> findNeighbors(Node node, RelationshipType relationshipType, Direction direction) {
4541
Set<Node> neighbors = new HashSet<>();
4642

4743
for (Relationship rel : loadRelationships(node, relationshipType, direction)) {
@@ -54,6 +50,17 @@ private Set<Node> findPotentialNeighbors(Node node, RelationshipType relationshi
5450
return neighbors;
5551
}
5652

53+
private Direction flipDirection(Direction direction) {
54+
switch(direction) {
55+
case OUTGOING:
56+
return INCOMING;
57+
case INCOMING:
58+
return OUTGOING;
59+
default:
60+
return BOTH;
61+
}
62+
}
63+
5764
private boolean noCommonNeighbors(Node node, RelationshipType relationshipType, Direction direction, Node node2) {
5865
for (Relationship rel : loadRelationships(node, relationshipType, direction)) {
5966
if (rel.getOtherNode(node).equals(node2)) {
Lines changed: 91 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import org.junit.Before;
44
import org.junit.Rule;
55
import org.junit.Test;
6-
import org.neo4j.graphalgo.linkprediction.CommonNeighborsFinder;
6+
import org.neo4j.graphalgo.linkprediction.NeighborsFinder;
77
import org.neo4j.graphdb.Direction;
88
import org.neo4j.graphdb.Node;
99
import org.neo4j.graphdb.RelationshipType;
@@ -13,9 +13,12 @@
1313

1414
import java.util.Set;
1515

16+
import static org.hamcrest.MatcherAssert.assertThat;
17+
import static org.hamcrest.core.IsCollectionContaining.hasItem;
18+
import static org.hamcrest.core.IsCollectionContaining.hasItems;
1619
import static org.junit.Assert.assertEquals;
1720

18-
public class CommonNeighborsFinderTest {
21+
public class NeighborsFinderTest {
1922

2023
@Rule
2124
public final ImpermanentDatabaseRule DB = new ImpermanentDatabaseRule();
@@ -39,12 +42,12 @@ public void excludeDirectRelationships() throws Throwable {
3942
tx.success();
4043
}
4144

42-
CommonNeighborsFinder commonNeighborsFinder = new CommonNeighborsFinder(api);
45+
NeighborsFinder neighborsFinder = new NeighborsFinder(api);
4346

4447
try (Transaction tx = api.beginTx()) {
4548
Node node1 = api.getNodeById(0);
4649
Node node2 = api.getNodeById(1);
47-
Set<Node> neighbors = commonNeighborsFinder.findCommonNeighbors(node1, node2, null, Direction.BOTH);
50+
Set<Node> neighbors = neighborsFinder.findCommonNeighbors(node1, node2, null, Direction.BOTH);
4851

4952
assertEquals(0, neighbors.size());
5053
}
@@ -59,11 +62,11 @@ public void sameNodeHasNoCommonNeighbors() throws Throwable {
5962
tx.success();
6063
}
6164

62-
CommonNeighborsFinder commonNeighborsFinder = new CommonNeighborsFinder(api);
65+
NeighborsFinder neighborsFinder = new NeighborsFinder(api);
6366

6467
try (Transaction tx = api.beginTx()) {
6568
Node node1 = api.getNodeById(0);
66-
Set<Node> neighbors = commonNeighborsFinder.findCommonNeighbors(node1, node1, null, Direction.BOTH);
69+
Set<Node> neighbors = neighborsFinder.findCommonNeighbors(node1, node1, null, Direction.BOTH);
6770

6871
assertEquals(0, neighbors.size());
6972
}
@@ -86,12 +89,12 @@ public void findNeighborsExcludingDirection() throws Throwable {
8689
tx.success();
8790
}
8891

89-
CommonNeighborsFinder commonNeighborsFinder = new CommonNeighborsFinder(api);
92+
NeighborsFinder neighborsFinder = new NeighborsFinder(api);
9093

9194
try (Transaction tx = api.beginTx()) {
9295
Node node1 = api.getNodeById(0);
9396
Node node2 = api.getNodeById(1);
94-
Set<Node> neighbors = commonNeighborsFinder.findCommonNeighbors(node1, node2, null, Direction.BOTH);
97+
Set<Node> neighbors = neighborsFinder.findCommonNeighbors(node1, node2, null, Direction.BOTH);
9598

9699
assertEquals(2, neighbors.size());
97100
}
@@ -111,12 +114,12 @@ public void findOutgoingNeighbors() throws Throwable {
111114
tx.success();
112115
}
113116

114-
CommonNeighborsFinder commonNeighborsFinder = new CommonNeighborsFinder(api);
117+
NeighborsFinder neighborsFinder = new NeighborsFinder(api);
115118

116119
try (Transaction tx = api.beginTx()) {
117120
Node node1 = api.getNodeById(0);
118121
Node node2 = api.getNodeById(1);
119-
Set<Node> neighbors = commonNeighborsFinder.findCommonNeighbors(node1, node2, FOLLOWS, Direction.OUTGOING);
122+
Set<Node> neighbors = neighborsFinder.findCommonNeighbors(node1, node2, FOLLOWS, Direction.OUTGOING);
120123

121124
assertEquals(1, neighbors.size());
122125
}
@@ -136,12 +139,12 @@ public void findIncomingNeighbors() throws Throwable {
136139
tx.success();
137140
}
138141

139-
CommonNeighborsFinder commonNeighborsFinder = new CommonNeighborsFinder(api);
142+
NeighborsFinder neighborsFinder = new NeighborsFinder(api);
140143

141144
try (Transaction tx = api.beginTx()) {
142145
Node node1 = api.getNodeById(0);
143146
Node node2 = api.getNodeById(1);
144-
Set<Node> neighbors = commonNeighborsFinder.findCommonNeighbors(node1, node2, FOLLOWS, Direction.INCOMING);
147+
Set<Node> neighbors = neighborsFinder.findCommonNeighbors(node1, node2, FOLLOWS, Direction.INCOMING);
145148

146149
assertEquals(1, neighbors.size());
147150
}
@@ -164,18 +167,92 @@ public void findNeighborsOfSpecificRelationshipType() throws Throwable {
164167
tx.success();
165168
}
166169

167-
CommonNeighborsFinder commonNeighborsFinder = new CommonNeighborsFinder(api);
170+
NeighborsFinder neighborsFinder = new NeighborsFinder(api);
168171

169172
try (Transaction tx = api.beginTx()) {
170173
Node node1 = api.getNodeById(0);
171174
Node node2 = api.getNodeById(1);
172-
Set<Node> neighbors = commonNeighborsFinder.findCommonNeighbors(node1, node2, COLLEAGUE, Direction.BOTH);
175+
Set<Node> neighbors = neighborsFinder.findCommonNeighbors(node1, node2, COLLEAGUE, Direction.BOTH);
173176

174177
assertEquals(1, neighbors.size());
175178
}
176179
}
177180

181+
@Test
182+
public void dontCountDuplicates() throws Throwable {
183+
184+
Node node1;
185+
Node node2;
186+
Node node3;
187+
Node node4;
188+
try (Transaction tx = api.beginTx()) {
189+
node1 = api.createNode();
190+
node2 = api.createNode();
191+
node3 = api.createNode();
192+
node4 = api.createNode();
193+
194+
node1.createRelationshipTo(node3, FRIEND);
195+
node2.createRelationshipTo(node3, FRIEND);
196+
node1.createRelationshipTo(node4, COLLEAGUE);
197+
node2.createRelationshipTo(node4, COLLEAGUE);
198+
199+
tx.success();
200+
}
201+
202+
NeighborsFinder neighborsFinder = new NeighborsFinder(api);
203+
204+
try (Transaction tx = api.beginTx()) {
205+
Set<Node> neighbors = neighborsFinder.findNeighbors(node1, node2, null, Direction.BOTH);
206+
207+
assertEquals(2, neighbors.size());
208+
assertThat(neighbors, hasItems(node3, node4));
209+
}
210+
}
211+
212+
@Test
213+
public void otherNodeCountsAsNeighbor() throws Throwable {
214+
215+
Node node1;
216+
Node node2;
217+
try (Transaction tx = api.beginTx()) {
218+
node1 = api.createNode();
219+
node2 = api.createNode();
220+
node1.createRelationshipTo(node2, FRIEND);
178221

222+
tx.success();
223+
}
224+
225+
NeighborsFinder neighborsFinder = new NeighborsFinder(api);
226+
227+
try (Transaction tx = api.beginTx()) {
228+
Set<Node> neighbors = neighborsFinder.findNeighbors(node1, node2, null, Direction.BOTH);
229+
230+
assertEquals(2, neighbors.size());
231+
assertThat(neighbors, hasItems(node1, node2));
232+
}
233+
}
234+
235+
@Test
236+
public void otherNodeCountsAsOutgoingNeighbor() throws Throwable {
237+
Node node1;
238+
Node node2;
239+
try (Transaction tx = api.beginTx()) {
240+
node1 = api.createNode();
241+
node2 = api.createNode();
242+
node1.createRelationshipTo(node2, FRIEND);
243+
244+
tx.success();
245+
}
246+
247+
NeighborsFinder neighborsFinder = new NeighborsFinder(api);
248+
249+
try (Transaction tx = api.beginTx()) {
250+
Set<Node> neighbors = neighborsFinder.findNeighbors(node1, node2, null, Direction.OUTGOING);
251+
252+
assertEquals(1, neighbors.size());
253+
assertThat(neighbors, hasItems(node2));
254+
}
255+
}
179256

180257
}
181258

0 commit comments

Comments
 (0)