Skip to content

Commit 1bb7f47

Browse files
authored
Avoid large device allocation in UMAP with nndescent (#6292)
Currently `NNDescent` returns two arrays: - `graph.graph()`: (n x graph_degree) on host - `graph.distances()`: (n x graph_degree) on device Downstream, the rest of UMAP wants both of these to be device arrays of shape (n x n_neighbors). Currently we copy `graph.graph()` to a temporary device array, then slice and and copy it to the output array `out.knn_indices`. Ideally we'd force `graph_degree = n_neighbors` to avoid the slicing entirely (and reduce the size of the intermediate results). However, it seems like currently there's a bug in `NNDescent` where reducing `graph_degree` to `n_neighbors` causes a significant decrease in result quality. So for now we need to keep the slicing around. We can avoid allocating the temporary device array though, instead doing the slicing on host. Doing this avoids allocating a (n x graph_degree) device array entirely; for large `n` this can be a significant savings (47 GiB on one test problem I was trying). We still should fix the `graph_degree` issue, but for now this should help unblock running UMAP on very large datasets. Authors: - Jim Crist-Harif (https://github.com/jcrist) Approvers: - Divye Gala (https://github.com/divyegala) - William Hicks (https://github.com/wphicks) URL: #6292
1 parent def265e commit 1bb7f47

File tree

1 file changed

+28
-17
lines changed

1 file changed

+28
-17
lines changed

cpp/src/umap/knn_graph/algo.cuh

+28-17
Original file line numberDiff line numberDiff line change
@@ -118,31 +118,42 @@ inline void launcher(const raft::handle_t& handle,
118118
// TODO: use nndescent from cuvs
119119
RAFT_EXPECTS(static_cast<size_t>(n_neighbors) <= params->nn_descent_params.graph_degree,
120120
"n_neighbors should be smaller than the graph degree computed by nn descent");
121+
RAFT_EXPECTS(params->nn_descent_params.return_distances,
122+
"return_distances for nn descent should be set to true to be used for UMAP");
121123

122124
auto graph = get_graph_nnd(handle, inputsA, params);
123125

124-
auto indices_d = raft::make_device_matrix<int64_t, int64_t>(
125-
handle, inputsA.n, params->nn_descent_params.graph_degree);
126-
127-
raft::copy(indices_d.data_handle(),
128-
graph.graph().data_handle(),
129-
inputsA.n * params->nn_descent_params.graph_degree,
130-
stream);
131-
126+
// `graph.graph()` is a host array (n x graph_degree).
127+
// Slice and copy to a temporary host array (n x n_neighbors), then copy
128+
// that to the output device array `out.knn_indices` (n x n_neighbors).
129+
// TODO: force graph_degree = n_neighbors so the temporary host array and
130+
// slice isn't necessary.
131+
auto temp_indices_h = raft::make_host_matrix<int64_t, int64_t>(inputsA.n, n_neighbors);
132+
size_t graph_degree = params->nn_descent_params.graph_degree;
133+
#pragma omp parallel for
134+
for (size_t i = 0; i < static_cast<size_t>(inputsA.n); i++) {
135+
for (int j = 0; j < n_neighbors; j++) {
136+
auto target = temp_indices_h.data_handle();
137+
auto source = graph.graph().data_handle();
138+
target[i * n_neighbors + j] = source[i * graph_degree + j];
139+
}
140+
}
141+
raft::copy(handle,
142+
raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors),
143+
temp_indices_h.view());
144+
145+
// `graph.distances()` is a device array (n x graph_degree).
146+
// Slice and copy to the output device array `out.knn_dists` (n x n_neighbors).
147+
// TODO: force graph_degree = n_neighbors so this slice isn't necessary.
132148
raft::matrix::slice_coordinates coords{static_cast<int64_t>(0),
133149
static_cast<int64_t>(0),
134150
static_cast<int64_t>(inputsA.n),
135151
static_cast<int64_t>(n_neighbors)};
136-
137-
RAFT_EXPECTS(graph.distances().has_value(),
138-
"return_distances for nn descent should be set to true to be used for UMAP");
139-
auto out_knn_dists_view = raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors);
140152
raft::matrix::slice<float, int64_t, raft::row_major>(
141-
handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords);
142-
auto out_knn_indices_view =
143-
raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors);
144-
raft::matrix::slice<int64_t, int64_t, raft::row_major>(
145-
handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords);
153+
handle,
154+
raft::make_const_mdspan(graph.distances().value()),
155+
raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors),
156+
coords);
146157
}
147158
}
148159

0 commit comments

Comments
 (0)