|
1 | 1 | @testset "add self-loops" begin |
2 | 2 | A = [1 1 0 0 |
3 | | - 0 0 1 0 |
4 | | - 0 0 0 1 |
5 | | - 1 0 0 0] |
| 3 | + 0 0 1 0 |
| 4 | + 0 0 0 1 |
| 5 | + 1 0 0 0] |
6 | 6 | A2 = [2 1 0 0 |
7 | | - 0 1 1 0 |
8 | | - 0 0 1 1 |
9 | | - 1 0 0 1] |
| 7 | + 0 1 1 0 |
| 8 | + 0 0 1 1 |
| 9 | + 1 0 0 1] |
10 | 10 |
|
11 | 11 | g = GNNGraph(A; graph_type = GRAPH_T) |
12 | 12 | fg2 = add_self_loops(g) |
|
18 | 18 |
|
19 | 19 | @testset "batch" begin |
20 | 20 | g1 = GNNGraph(random_regular_graph(10, 2), ndata = rand(16, 10), |
21 | | - graph_type = GRAPH_T) |
| 21 | + graph_type = GRAPH_T) |
22 | 22 | g2 = GNNGraph(random_regular_graph(4, 2), ndata = rand(16, 4), graph_type = GRAPH_T) |
23 | 23 | g3 = GNNGraph(random_regular_graph(7, 2), ndata = rand(16, 7), graph_type = GRAPH_T) |
24 | 24 |
|
|
44 | 44 | # Batch of batches |
45 | 45 | g123123 = Flux.batch([g123, g123]) |
46 | 46 | @test g123123.graph_indicator == |
47 | | - [fill(1, 10); fill(2, 4); fill(3, 7); fill(4, 10); fill(5, 4); fill(6, 7)] |
| 47 | + [fill(1, 10); fill(2, 4); fill(3, 7); fill(4, 10); fill(5, 4); fill(6, 7)] |
48 | 48 | @test g123123.num_graphs == 6 |
49 | 49 | end |
50 | 50 |
|
|
67 | 67 | c = 3 |
68 | 68 | ngraphs = 10 |
69 | 69 | gs = [rand_graph(n, c * n, ndata = rand(2, n), edata = rand(3, c * n), |
70 | | - graph_type = GRAPH_T) |
71 | | - for _ in 1:ngraphs] |
| 70 | + graph_type = GRAPH_T) |
| 71 | + for _ in 1:ngraphs] |
72 | 72 | gall = Flux.batch(gs) |
73 | 73 | gs2 = Flux.unbatch(gall) |
74 | 74 | @test gs2[1] == gs[1] |
|
77 | 77 |
|
78 | 78 | @testset "getgraph" begin |
79 | 79 | g1 = GNNGraph(random_regular_graph(10, 2), ndata = rand(16, 10), |
80 | | - graph_type = GRAPH_T) |
| 80 | + graph_type = GRAPH_T) |
81 | 81 | g2 = GNNGraph(random_regular_graph(4, 2), ndata = rand(16, 4), graph_type = GRAPH_T) |
82 | 82 | g3 = GNNGraph(random_regular_graph(7, 2), ndata = rand(16, 7), graph_type = GRAPH_T) |
83 | 83 | g = Flux.batch([g1, g2, g3]) |
@@ -268,3 +268,14 @@ end end |
268 | 268 | @test nv(DG) == g.num_nodes |
269 | 269 | @test ne(DG) == g.num_edges |
270 | 270 | end |
| 271 | + |
| 272 | +@testset "random_walk_pe" begin |
| 273 | + s = [1, 2, 2, 3] |
| 274 | + t = [2, 1, 3, 2] |
| 275 | + ndata = [-1, 0, 1] |
| 276 | + g = GNNGraph(s, t, graph_type = GRAPH_T, ndata = ndata) |
| 277 | + output = random_walk_pe(g, 3) |
| 278 | + @test output == [0.0 0.0 0.0 |
| 279 | + 0.5 1.0 0.5 |
| 280 | + 0.0 0.0 0.0] |
| 281 | +end |
0 commit comments