Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions cpu/include/neighbors.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,18 @@
#include <cstdint>
#include <set>

using namespace std;

template <typename scalar_t>
int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
vector<int64_t>& neighbors_indices, vector<float>& dists, float radius,
int nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>& supports,
std::vector<int64_t>& neighbors_indices, std::vector<float>& dists, float radius,
int max_num, int mode, bool sorted);

template <typename scalar_t>
int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
vector<int64_t>& q_batches, vector<int64_t>& s_batches,
vector<int64_t>& neighbors_indices, vector<float>& dists,
int batch_nanoflann_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>& supports,
std::vector<int64_t>& q_batches, std::vector<int64_t>& s_batches,
std::vector<int64_t>& neighbors_indices, std::vector<float>& dists,
float radius, int max_num, int mode, bool sorted);

template <typename scalar_t>
void nanoflann_knn_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
vector<int64_t>& neighbors_indices, vector<float>& dists, int k);
void nanoflann_knn_neighbors(std::vector<scalar_t>& queries, std::vector<scalar_t>& supports,
std::vector<int64_t>& neighbors_indices, std::vector<float>& dists, int k);
4 changes: 2 additions & 2 deletions cpu/src/knn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ std::pair<at::Tensor, at::Tensor> dense_knn(at::Tensor support, at::Tensor query
CHECK_CPU(support);

int b = query.size(0);
vector<at::Tensor> batch_idx;
vector<at::Tensor> batch_dist;
std::vector<at::Tensor> batch_idx;
std::vector<at::Tensor> batch_dist;
for (int i = 0; i < b; i++)
{
auto out_pair = _single_batch_knn(support[i], query[i], k);
Expand Down
4 changes: 4 additions & 0 deletions cpu/src/neighbors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
#include <chrono>
#include <random>

using std::vector;
using std::pair;
using std::max;

template <typename scalar_t>
int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
vector<int64_t>& neighbors_indices, vector<float>& dists, float radius,
Expand Down