@@ -11,7 +11,7 @@ const ADJMAT_T = AbstractMatrix
1111const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
1212
1313"""
14- GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, num_graphs, graph_indicator, dir])
14+ GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, graph_indicator, dir])
1515 GNNGraph(g::GNNGraph; [nf, ef, gf])
1616
1717A type representing a graph structure and storing also arrays
@@ -23,6 +23,11 @@ is governed by `graph_type`.
2323When constructed from another graph `g`, the internal graph representation
2424is preserved and shared.
2525
26+ A `GNNGraph` can also represent multiple graphs batched togheter
27+ (see [`Flux.batch`](@ref) or [`SparseArrays.blockdiag`](@ref)).
28+ The field `g.graph_indicator` contains the graph membership
29+ of each node.
30+
2631A `GNNGraph` is a LightGraphs' `AbstractGraph`, therefore any functionality
2732from the LightGraphs' graph library can be used on it.
2833
@@ -45,7 +50,6 @@ from the LightGraphs' graph library can be used on it.
4550- `dir`. The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
4651 Possible values are `:out` and `:in`. Default `:out`.
4752- `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default `nothing`.
48- - `num_graphs`. The number of graphs. Larger than 1 in case of batched graphs. Default `1`.
4953- `graph_indicator`. For batched graphs, a vector containeing the graph assigment of each node. Default `nothing`.
5054- `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default `nothing`.
5155- `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default `nothing`.
@@ -118,17 +122,17 @@ function GNNGraph(data;
118122
119123 @assert graph_type ∈ [:coo , :dense , :sparse ] " Invalid graph_type $graph_type requested"
120124 @assert dir ∈ [:in , :out ]
125+
121126 if graph_type == :coo
122127 g, num_nodes, num_edges = to_coo (data; num_nodes, dir)
123128 elseif graph_type == :dense
124129 g, num_nodes, num_edges = to_dense (data; dir)
125130 elseif graph_type == :sparse
126131 g, num_nodes, num_edges = to_sparse (data; dir)
127132 end
128- if num_graphs > 1
129- @assert len (graph_indicator) = num_nodes " When batching multiple graphs `graph_indicator` should be filled with the nodes' memberships."
130- end
131-
133+
134+ num_graphs = ! isnothing (graph_indicator) ? maximum (graph_indicator) : 1
135+
132136 # # Possible future implementation of feature maps.
133137 # # Currently this doesn't play well with zygote due to
134138 # # https://github.com/FluxML/Zygote.jl/issues/717
@@ -149,8 +153,8 @@ GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
149153
150154function GNNGraph (g:: AbstractGraph ; kws... )
151155 s = LightGraphs. src .(LightGraphs. edges (g))
152- t = LightGraphs. dst .(LightGraphs. edges (g))
153- GNNGraph ((s, t); kws... )
156+ t = LightGraphs. dst .(LightGraphs. edges (g))
157+ GNNGraph ((s, t); num_nodes = nv (g), kws... )
154158end
155159
156160function GNNGraph (g:: GNNGraph ;
@@ -431,19 +435,76 @@ function _catgraphs(g1::GNNGraph{<:COO_T}, g2::GNNGraph{<:COO_T})
431435 )
432436end
433437
434- # Cat public interfaces
438+ # ## Cat public interfaces #############
439+
440+ """
441+ blockdiag(xs::GNNGraph...)
442+
443+ Batch togheter multiple `GNNGraph`s into a single one
444+ containing the total number of nodes and edges of the original graphs.
445+
446+ Equivalent to [`Flux.batch`](@ref).
447+ """
435448function SparseArrays. blockdiag (g1:: GNNGraph , gothers:: GNNGraph... )
436- @assert length (gothers) >= 1
437449 g = g1
438450 for go in gothers
439451 g = _catgraphs (g, go)
440452 end
441453 return g
442454end
443455
456+ """
457+ batch(xs::Vector{<:GNNGraph})
458+
459+ Batch togheter multiple `GNNGraph`s into a single one
460+ containing the total number of nodes and edges of the original graphs.
461+
462+ Equivalent to [`SparseArrays.blockdiag`](@ref).
463+ """
444464Flux. batch (xs:: Vector{<:GNNGraph} ) = blockdiag (xs... )
445465# ########################
446466
467+ """
468+ subgraph(g::GNNGraph, i)
469+
470+ Return the subgraph of `g` induced by those nodes `v`
471+ for which `g.graph_indicator[v] ∈ i`. In other words, it
472+ extract the component graphs from a batched graph.
473+
474+ It also returns a vector `nodes` mapping the new nodes to the old ones.
475+ The node `i` in the subgraph corresponds to the node `nodes[i]` in `g`.
476+ """
477+ subgraph (g:: GNNGraph , i:: Int ) = subgraph (g:: GNNGraph{<:COO_T} , [i])
478+
479+ function subgraph (g:: GNNGraph{<:COO_T} , i:: AbstractVector )
480+ node_mask = g. graph_indicator .∈ Ref (i)
481+
482+ nodes = (1 : g. num_nodes)[node_mask]
483+ nodemap = Dict (v => vnew for (vnew, v) in enumerate (nodes))
484+
485+ graphmap = Dict (i => inew for (inew, i) in enumerate (i))
486+ graph_indicator = [graphmap[i] for i in g. graph_indicator[node_mask]]
487+
488+ s, t, w = g. graph
489+ edge_mask = s .∈ Ref (nodes)
490+ s = [nodemap[i] for i in s[edge_mask]]
491+ t = [nodemap[i] for i in t[edge_mask]]
492+ w = isnothing (w) ? nothing : w[edge_mask]
493+ nf = isnothing (g. nf) ? nothing : g. nf[:,node_mask]
494+ ef = isnothing (g. ef) ? nothing : g. ef[:,edge_mask]
495+ gf = isnothing (g. gf) ? nothing : g. gf[:,i]
496+
497+ num_nodes = length (graph_indicator)
498+ num_edges = length (s)
499+ num_graphs = length (i)
500+
501+ gnew = GNNGraph ((s,t,w),
502+ num_nodes, num_edges, num_graphs,
503+ graph_indicator,
504+ nf, ef, gf)
505+ return gnew, nodes
506+ end
507+
447508@non_differentiable normalized_laplacian (x... )
448509@non_differentiable normalized_adjacency (x... )
449510@non_differentiable scaled_laplacian (x... )
0 commit comments