Skip to content

Commit

Permalink
export and use only double correlationFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
DomFijan committed Feb 6, 2025
1 parent c5897c3 commit 0b8be1c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 21 deletions.
2 changes: 1 addition & 1 deletion freud/density.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class CorrelationFunction(_SpatialHistogram1D):

def __init__(self, bins, r_max):
self._bins = int(bins)
self._cpp_obj = freud._density.CorrelationFunctionComplex(self._bins, r_max)
self._cpp_obj = freud._density.CorrelationFunction(self._bins, r_max)
self.r_max = r_max
self.is_complex = False

Expand Down
27 changes: 7 additions & 20 deletions freud/density/export-CorrelationFunction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,16 @@ using nb_array = nanobind::ndarray<T, shape, nanobind::device::cpu, nanobind::c_
namespace wrap {

// Wrapper function for accumulate
template<typename T>
void accumulateCF(const std::shared_ptr<CorrelationFunction<T>>& self,
void accumulateCF(const std::shared_ptr<CorrelationFunction<std::complex<double>>>& self,
const std::shared_ptr<locality::NeighborQuery> neighbor_query,
const nb_array<T, nanobind::shape<-1>>& values,
const nb_array<std::complex<double>, nanobind::shape<-1>>& values,
const nb_array<float, nanobind::shape<-1, 3>>& query_points,
const nb_array<T, nanobind::shape<-1>>& query_values,
const nb_array<std::complex<double>, nanobind::shape<-1>>& query_values,
std::shared_ptr<locality::NeighborList> nlist, const locality::QueryArgs& qargs)
{
auto* values_data = reinterpret_cast<T*>(values.data());
auto* values_data = reinterpret_cast<std::complex<double>*>(values.data());
auto* query_points_data = reinterpret_cast<vec3<float>*>(query_points.data());
auto* query_values_data = reinterpret_cast<T*>(query_values.data());
auto* query_values_data = reinterpret_cast<std::complex<double>*>(query_values.data());

const unsigned int num_query_points = query_points.shape(0);

Expand All @@ -44,22 +43,10 @@ namespace detail {

void export_CorrelationFunction(nanobind::module_& m)
{
nanobind::class_<CorrelationFunction<double>>(m, "CorrelationFunctionDouble")
.def(nanobind::init<unsigned int, float>(), nanobind::arg("bins"), nanobind::arg("r_max"))
.def("reset", &CorrelationFunction<double>::reset)
.def("accumulate", &wrap::accumulateCF<double>, nanobind::arg("neighbor_query"),
nanobind::arg("values"), nanobind::arg("query_points"), nanobind::arg("query_values"),
nanobind::arg("nlist").none(), nanobind::arg("qargs"))
.def("getBinCenters", &CorrelationFunction<double>::getBinCenters)
.def("getBinCounts", &CorrelationFunction<double>::getBinCounts)
.def("getAxisSizes", &CorrelationFunction<double>::getAxisSizes)
.def("getBinEdges", &CorrelationFunction<double>::getBinEdges)
.def("getBox", &CorrelationFunction<double>::getBox)
.def("getCorrelation", &CorrelationFunction<double>::getCorrelation);
nanobind::class_<CorrelationFunction<std::complex<double>>>(m, "CorrelationFunctionComplex")
nanobind::class_<CorrelationFunction<std::complex<double>>>(m, "CorrelationFunction")
.def(nanobind::init<unsigned int, float>(), nanobind::arg("bins"), nanobind::arg("r_max"))
.def("reset", &CorrelationFunction<std::complex<double>>::reset)
.def("accumulate", &wrap::accumulateCF<std::complex<double>>, nanobind::arg("neighbor_query"),
.def("accumulate", &wrap::accumulateCF, nanobind::arg("neighbor_query"),
nanobind::arg("values"), nanobind::arg("query_points"), nanobind::arg("query_values"),
nanobind::arg("nlist").none(), nanobind::arg("qargs"))
.def("getBinCenters", &CorrelationFunction<std::complex<double>>::getBinCenters)
Expand Down

0 comments on commit 0b8be1c

Please sign in to comment.