@@ -105,43 +105,34 @@ function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1)
105105 return matrix[:, index]
106106end
107107
108- function _sort_matrix (matrix:: AbstractArray , k:: Int ; rev:: Bool = true , sortby = nothing )
108+ function _topk_matrix (matrix:: AbstractArray , k:: Int ; rev:: Bool = true , sortby = nothing )
109109 if sortby === nothing
110110 return sort (matrix, dims = 2 ; rev)[:, 1 : k]
111111 else
112112 return _sort_col (matrix; rev, sortby)[:, 1 : k]
113113 end
114114end
115115
116- function _sort_batch (matrices, k:: Int ; rev:: Bool = true , sortby = nothing )
117- return map (x -> _sort_matrix (x, k; rev, sortby), matrices)
118- end
119-
120- function _topk_batch (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
116+ function _topk_batch (matrices:: AbstractArray , k:: Int ; rev:: Bool = true ,
121117 sortby = nothing )
122- tensor_matrix = reshape (matrix, size (matrix, 1 ), size (matrix, 2 ) ÷ number_graphs,
123- number_graphs)
124- sorted_matrix = _sort_batch (eachslice (tensor_matrix, dims = 3 ), k; rev, sortby)
118+ sorted_matrix = map (x -> _topk_matrix (x, k; rev, sortby), matrices)
125119 return reduce (hcat, sorted_matrix)
126120end
127121
128- function _topk (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
129- sortby = nothing )
130- if number_graphs == 1
131- return _sort_matrix (matrix, k; rev, sortby)
132- else
133- return _topk_batch (matrix, number_graphs, k; rev, sortby)
134- end
135- end
136-
137122"""
138123 topk_nodes(g, feat, k; rev = true, sortby = nothing)
139124
140125Graph-wise top-k on node features `feat` according to the `sortby` feature index.
141126"""
142127function topk_nodes (g:: GNNGraph , feat:: Symbol , k:: Int ; rev = true , sortby = nothing )
143- matrix = getproperty (g. ndata, feat)
144- return _topk (matrix, g. num_graphs, k; rev, sortby)
128+ if g. num_graphs == 1
129+ matrix = getproperty (g. ndata, feat)
130+ return _topk_matrix (matrix, k; rev, sortby)
131+ else
132+ graphs = [getgraph (g, i) for i in 1 : (g. num_graphs)]
133+ matrices = map (graph -> getproperty (graph. ndata, feat), graphs)
134+ return _topk_batch (matrices, k; rev, sortby)
135+ end
145136end
146137
147138"""
150141Graph-wise top-k on edge features `feat` according to the `sortby` feature index.
151142"""
152143function topk_edges (g:: GNNGraph , feat:: Symbol , k:: Int ; rev = true , sortby = nothing )
153- matrix = getproperty (g. edata, feat)
154- return _topk (matrix, g. num_graphs, k; rev, sortby)
144+ if g. num_graphs == 1
145+ matrix = getproperty (g. edata, feat)
146+ return _topk_matrix (matrix, k; rev, sortby)
147+ else
148+ graphs = [getgraph (g, i) for i in 1 : (g. num_graphs)]
149+ matrices = map (graph -> getproperty (graph. edata, feat), graphs)
150+ return _topk_batch (matrices, k; rev, sortby)
151+ end
155152end
0 commit comments