diff --git a/dpnp/backend/extensions/common/ext/common.hpp b/dpnp/backend/extensions/common/ext/common.hpp index 080df62b25e3..d56ac80d5d78 100644 --- a/dpnp/backend/extensions/common/ext/common.hpp +++ b/dpnp/backend/extensions/common/ext/common.hpp @@ -70,20 +70,6 @@ struct AtomicOp } }; -template -struct Less -{ - bool operator()(const T &lhs, const T &rhs) const - { - if constexpr (type_utils::is_complex_v) { - return dpctl::tensor::math_utils::less_complex(lhs, rhs); - } - else { - return std::less{}(lhs, rhs); - } - } -}; - template struct IsNan { @@ -106,6 +92,21 @@ struct IsNan } }; +template +struct Less +{ + bool operator()(const T &lhs, const T &rhs) const + { + if constexpr (type_utils::is_complex_v) { + return IsNan::isnan(rhs) || + dpctl::tensor::math_utils::less_complex(lhs, rhs); + } + else { + return IsNan::isnan(rhs) || std::less{}(lhs, rhs); + } + } +}; + template struct value_type_of_impl; diff --git a/dpnp/backend/extensions/statistics/CMakeLists.txt b/dpnp/backend/extensions/statistics/CMakeLists.txt index 1c9027870f92..9fa56d52dcd0 100644 --- a/dpnp/backend/extensions/statistics/CMakeLists.txt +++ b/dpnp/backend/extensions/statistics/CMakeLists.txt @@ -30,6 +30,8 @@ set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/histogram.cpp ${CMAKE_CURRENT_SOURCE_DIR}/histogramdd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/histogram_common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/kth_element1d.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/partitioning.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sliding_dot_product1d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/sliding_window1d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/statistics_py.cpp diff --git a/dpnp/backend/extensions/statistics/kth_element1d.cpp b/dpnp/backend/extensions/statistics/kth_element1d.cpp new file mode 100644 index 000000000000..c94fed4f8b7d --- /dev/null +++ b/dpnp/backend/extensions/statistics/kth_element1d.cpp @@ -0,0 +1,517 @@ +//***************************************************************************** +// Copyright (c) 2024-2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#include +#include +#include +#include + +#include +#include + +// dpctl tensor headers +#include "dpctl4pybind11.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" + +#include "ext/common.hpp" +#include "kth_element1d.hpp" +#include "partitioning.hpp" + +#include +#include + +namespace sycl_exp = sycl::ext::oneapi::experimental; +namespace dpctl_td_ns = dpctl::tensor::type_dispatch; +namespace dpctl_utils = dpctl::tensor::alloc_utils; + +using dpctl::tensor::usm_ndarray; + +using namespace statistics::partitioning; +using namespace ext::common; + +namespace +{ + +template +T NextAfter(T x) +{ + if constexpr (std::is_floating_point::value) { + return sycl::nextafter(x, std::numeric_limits::infinity()); + } + else if constexpr (std::is_integral::value) { + if (x < std::numeric_limits::max()) + return x + 1; + else + return x; + } + else if constexpr (type_utils::is_complex_v) { + if (x.imag() != std::numeric_limits::infinity()) { + return T{x.real(), NextAfter(x.imag())}; + } + else if (x.real() != std::numeric_limits::infinity()) { + return T{NextAfter(x.real()), -x.imag()}; + } + else { + return x; + } + } +} + +template +struct pick_pivot_kernel; + +template +struct kth_sorter_kernel; + +template +struct KthElementF +{ + static std::tuple + run_kth_sort(sycl::queue &exec_q, + const T *in, + const size_t k, + State &state, + const std::vector &depends) + { + auto device = exec_q.get_device(); + size_t local_mem_size = get_local_mem_size_in_bytes(device); + size_t temp_memory_size = + sycl_exp::default_sorters::joint_sorter<>::memory_required( + sycl::memory_scope::work_group, state.n); + size_t loc_items_mem = sizeof(T) * state.n; + + if ((temp_memory_size + loc_items_mem) > local_mem_size) + return {false, sycl::event{}}; + + auto e = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + const uint32_t local_size = get_max_local_size(exec_q); + const uint32_t WorkPI = CeilDiv(state.n, local_size); + auto work_sz = make_ndrange(state.n, local_size, WorkPI); + auto loc_items = + sycl::local_accessor(sycl::range<1>(state.n), cgh); + auto scratch = sycl::local_accessor( + sycl::range<1>(temp_memory_size), cgh); + + cgh.parallel_for>( + work_sz, [=](sycl::nd_item<1> item) { + auto group = item.get_group(); + auto sbg = item.get_sub_group(); + + if (state.stop[0]) + return; + + auto llid = item.get_local_linear_id(); + uint32_t sbg_size = sbg.get_max_local_range()[0]; + uint32_t sbg_llid = sbg.get_local_linear_id(); + auto local_size = item.get_group_range(0); + uint32_t nan_count = 0; + + uint32_t i_base = + sbg.get_group_id() * WorkPI * sbg_size + sbg_llid; + for (uint32_t i = 0; i < WorkPI; i++) { + uint32_t idx = i_base + i * sbg_size; + if (idx < state.n) { + loc_items[idx] = in[idx]; + if (IsNan::isnan(in[idx])) { + nan_count++; + } + } + } + + nan_count = sycl::reduce_over_group(group, nan_count, + sycl::plus<>()); + sycl::group_barrier(group); + + auto gh = sycl_exp::group_with_scratchpad( + group, sycl::span{&scratch[0], temp_memory_size}); + sycl_exp::joint_sort(gh, &loc_items[0], + &loc_items[0] + state.n, Less{}); + + sycl::group_barrier(group); + + if (group.leader()) { + state.values[0] = loc_items[k]; + state.values[1] = loc_items[k + 1]; + state.target_found[0] = true; + state.counters.nan_count[0] = nan_count; + } + }); + }); + + return {true, e}; + } + + static sycl::event run_pick_pivot(sycl::queue &queue, + T *in, + T *out, + uint64_t target, + 
State &state, + uint64_t items_to_sort, + uint64_t limit, + const std::vector &deps) + { + auto e = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + constexpr uint64_t group_size = 128; + + auto work_sz = make_ndrange(group_size, group_size, 1); + + size_t temp_memory_size = + sycl_exp::default_sorters::joint_sorter<>::memory_required( + sycl::memory_scope::work_group, limit); + + auto loc_items = + sycl::local_accessor(sycl::range<1>(items_to_sort), cgh); + auto scratch = sycl::local_accessor( + sycl::range<1>(temp_memory_size), cgh); + + cgh.parallel_for>(work_sz, [=](sycl::nd_item<1> + item) { + auto group = item.get_group(); + + if (state.stop[0]) + return; + + auto llid = item.get_local_linear_id(); + auto local_size = item.get_group_range(0); + + uint64_t num_elems = 0; + bool target_found = false; + + T *_in = nullptr; + if (group.leader()) { + state.update_counters(); + auto less_count = state.counters.less_count[0]; + bool left = target < less_count; + state.left[0] = left; + + if (left) { + _in = in; + num_elems = state.iteration_counters.less_count[0]; + if (target + 1 == less_count) { + _in[num_elems] = state.pivot[0]; + state.counters.less_count[0] += 1; + num_elems += 1; + } + } + else { + num_elems = + state.iteration_counters.greater_equal_count[0]; + _in = in + state.n - num_elems; + + if (target + 1 < + less_count + + state.iteration_counters.equal_count[0]) { + state.values[0] = state.pivot[0]; + state.values[1] = state.pivot[0]; + + state.stop[0] = true; + state.target_found[0] = true; + target_found = true; + } + } + state.reset_iteration_counters(); + } + + target_found = sycl::group_broadcast(group, target_found, 0); + _in = sycl::group_broadcast(group, _in, 0); + num_elems = sycl::group_broadcast(group, num_elems, 0); + + if (target_found) { + return; + } + + if (num_elems <= limit) { + auto gh = sycl_exp::group_with_scratchpad( + group, sycl::span{&scratch[0], temp_memory_size}); + if (num_elems > 0) + sycl_exp::joint_sort(gh, &_in[0], &_in[num_elems], + Less{}); + + if (group.leader()) { + uint64_t offset = state.counters.less_count[0]; + if (state.left[0]) { + offset = state.counters.less_count[0] - num_elems; + } + + int64_t idx = target - offset; + + state.values[0] = _in[idx]; + state.values[1] = _in[idx + 1]; + + state.stop[0] = true; + state.target_found[0] = true; + } + + return; + } + + uint64_t step = num_elems / items_to_sort; + for (uint32_t i = llid; i < items_to_sort; i += local_size) { + loc_items[i] = std::numeric_limits::max(); + uint32_t idx = i * step; + if (idx < num_elems) { + loc_items[i] = _in[idx]; + } + } + + sycl::group_barrier(group); + + auto gh = sycl_exp::group_with_scratchpad( + group, sycl::span{&scratch[0], temp_memory_size}); + sycl_exp::joint_sort(gh, &loc_items[0], + &loc_items[0] + items_to_sort, Less{}); + + state.num_elems[0] = num_elems; + + T new_pivot = loc_items[items_to_sort / 2]; + if (new_pivot != state.pivot[0] && !IsNan::isnan(new_pivot)) + { + if (group.leader()) { + state.pivot[0] = new_pivot; + } + return; + } + + auto start = llid + items_to_sort / 2 + 1; + uint32_t index = start; + for (uint32_t i = start; i < items_to_sort; i += local_size) { + if (loc_items[i] != new_pivot && + !IsNan::isnan(loc_items[i])) { + index = i; + break; + } + } + + index = + sycl::reduce_over_group(group, index, sycl::minimum<>()); + if (group.leader()) { + if (loc_items[index] != new_pivot || + !IsNan::isnan(loc_items[index])) { + // if all values are Nan just use it as pivot + // to filter out all the Nans + 
state.pivot[0] = loc_items[index]; + } + else { + // we are going to filter out new_pivot + // but we need to keep at least one since it + // could be our target (but not target + 1) + out[state.n - 1] = new_pivot; + state.iteration_counters.greater_equal_count[0] += 1; + state.counters.less_count[0] -= 1; + new_pivot = NextAfter(new_pivot); + state.pivot[0] = new_pivot; + } + } + }); + }); + + return e; + } + + static sycl::event run_partition(sycl::queue &exec_q, + T *in, + T *out, + PartitionState &state, + const std::vector &deps) + { + + uint32_t group_size = 128; + constexpr uint32_t WorkPI = 4; + return run_partition_one_pivot_cpu(exec_q, in, out, state, + deps, group_size); + } + + static sycl::event run_kth_element(sycl::queue &exec_q, + const T *in, + T *partitioned, + T *temp_buff, + const size_t k, + State &state, + PartitionState &pstate, + const std::vector &depends) + { + auto [success, evt] = run_kth_sort(exec_q, in, k, state, depends); + if (success) { + return evt; + } + + uint32_t items_to_sort = 127; + uint32_t limit = 4 * (items_to_sort + 1); + + uint32_t iterations = 1; + + if (state.n > limit) { + iterations = std::ceil(-std::log(double(state.n) / limit) / + std::log(0.536)) + + 1; + + // Ensure iterations are odd so the final result is always stored in + // 'partitioned' + iterations += 1 - iterations % 2; + } + + auto prev = run_pick_pivot(exec_q, const_cast(in), partitioned, k, + state, items_to_sort, limit, depends); + prev = run_partition(exec_q, const_cast(in), partitioned, pstate, + {prev}); + + T *_in = partitioned; + T *_out = temp_buff; + for (uint32_t i = 0; i < iterations - 1; ++i) { + prev = run_pick_pivot(exec_q, _in, _out, k, state, items_to_sort, + limit, {prev}); + prev = run_partition(exec_q, _in, _out, pstate, {prev}); + std::swap(_in, _out); + } + prev = run_pick_pivot(exec_q, _in, _out, k, state, items_to_sort, limit, + {prev}); + + return prev; + } + + static KthElement1d::RetT impl(sycl::queue &exec_queue, + const void *v_ain, + void *v_partitioned, + const size_t a_size, + const size_t k, + const std::vector &depends) + { + const T *ain = static_cast(v_ain); + T *partitioned = static_cast(v_partitioned); + + State state(exec_queue, a_size, partitioned); + PartitionState pstate(state); + + exec_queue.wait(); + auto init_e = state.init(exec_queue, depends); + init_e = pstate.init(exec_queue, {init_e}); + + auto temp_buff = dpctl_utils::smart_malloc(state.n, exec_queue, + sycl::usm::alloc::device); + auto evt = run_kth_element(exec_queue, ain, partitioned, + temp_buff.get(), k, state, pstate, {init_e}); + + bool found = false; + bool left = false; + uint64_t less_count = 0; + uint64_t greater_equal_count = 0; + uint64_t num_elems = 0; + uint64_t nan_count = 0; + auto copy_evt = exec_queue.copy(state.target_found, &found, 1, evt); + copy_evt = exec_queue.copy(state.left, &left, 1, copy_evt); + copy_evt = exec_queue.copy(state.counters.less_count, &less_count, 1, + copy_evt); + copy_evt = exec_queue.copy(state.counters.greater_equal_count, + &greater_equal_count, 1, copy_evt); + copy_evt = exec_queue.copy(state.num_elems, &num_elems, 1, copy_evt); + copy_evt = + exec_queue.copy(state.counters.nan_count, &nan_count, 1, copy_evt); + + copy_evt.wait(); + + uint64_t buff_offset = 0; + uint64_t elems_offset = less_count; + + if (!found) { + if (left) { + elems_offset = less_count - num_elems; + } + else { + buff_offset = a_size - num_elems; + } + } + else { + num_elems = 2; + elems_offset = k; + } + + state.cleanup(exec_queue); + + return {found, 
buff_offset, elems_offset, num_elems, nan_count}; + } +}; + +using SupportedTypes = std::tuple, + std::complex>; +} // namespace + +KthElement1d::KthElement1d() : dispatch_table("a") +{ + dispatch_table.populate_dispatch_table(); +} + +KthElement1d::RetT KthElement1d::call(const dpctl::tensor::usm_ndarray &a, + dpctl::tensor::usm_ndarray &partitioned, + const size_t k, + const std::vector &depends) +{ + validate(a, partitioned, k); + + const int a_typenum = a.get_typenum(); + auto kth_elem_func = dispatch_table.get(a_typenum); + + auto exec_q = a.get_queue(); + auto result = kth_elem_func(exec_q, a.get_data(), partitioned.get_data(), + a.get_shape(0), k, depends); + + return result; +} + +std::unique_ptr kth; + +void statistics::partitioning::populate_kth_element1d(py::module_ m) +{ + using namespace std::placeholders; + + kth.reset(new KthElement1d()); + + auto kth_func = [kthp = kth.get()]( + const dpctl::tensor::usm_ndarray &a, + dpctl::tensor::usm_ndarray &partitioned, const size_t k, + const std::vector &depends) { + return kthp->call(a, partitioned, k, depends); + }; + + m.def("kth_element", kth_func, "finding k and k+1 elements.", py::arg("a"), + py::arg("partitioned"), py::arg("k"), + py::arg("depends") = py::list()); + + auto kth_dtypes = [kthp = kth.get()]() { + return kthp->dispatch_table.get_all_supported_types(); + }; + + m.def("kth_element_dtypes", kth_dtypes, + "Get the supported data types for kth_element."); +} diff --git a/dpnp/backend/extensions/statistics/kth_element1d.hpp b/dpnp/backend/extensions/statistics/kth_element1d.hpp new file mode 100644 index 000000000000..b181028bb1e9 --- /dev/null +++ b/dpnp/backend/extensions/statistics/kth_element1d.hpp @@ -0,0 +1,55 @@ +//***************************************************************************** +// Copyright (c) 2024-2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#pragma once + +#include "ext/dispatch_table.hpp" +#include +#include + +namespace statistics::partitioning +{ +struct KthElement1d +{ + using RetT = std::tuple; + using FnT = RetT (*)(sycl::queue &, + const void *, + void *, + const size_t, + const size_t, + const std::vector &); + + ext::common::DispatchTable dispatch_table; + + KthElement1d(); + + RetT call(const dpctl::tensor::usm_ndarray &a, + dpctl::tensor::usm_ndarray &partitioned, + uint64_t k, + const std::vector &depends); +}; + +void populate_kth_element1d(py::module_ m); +} // namespace statistics::partitioning diff --git a/dpnp/backend/extensions/statistics/partitioning.cpp b/dpnp/backend/extensions/statistics/partitioning.cpp new file mode 100644 index 000000000000..abd8bd69cefe --- /dev/null +++ b/dpnp/backend/extensions/statistics/partitioning.cpp @@ -0,0 +1,67 @@ +//***************************************************************************** +// Copyright (c) 2024-2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#include +#include + +#include "dpctl4pybind11.hpp" +#include "utils/type_dispatch.hpp" +#include + +#include "ext/common.hpp" +#include "ext/validation_utils.hpp" +#include "sliding_window1d.hpp" + +namespace dpctl_td_ns = dpctl::tensor::type_dispatch; +using namespace ext::common; +using namespace ext::validation; + +using dpctl::tensor::usm_ndarray; +using dpctl_td_ns::typenum_t; + +namespace statistics::partitioning +{ + +void validate(const usm_ndarray &a, + const usm_ndarray &partitioned, + const size_t k) +{ + array_names names = {{&a, "a"}, {&partitioned, "partitioned"}}; + + common_checks({&a}, {&partitioned}, names); + check_same_size(&a, &partitioned, names); + check_num_dims(&a, 1, names); + check_num_dims(&partitioned, 1, names); + check_same_dtype(&a, &partitioned, names); + + if (k > a.get_size() - 2) { + throw py::value_error("'k' must be from 0 to a.size() - 2, " + "but got k = " + + std::to_string(k) + " and a.size() = " + + std::to_string(a.get_size())); + } +} + +} // namespace statistics::partitioning diff --git a/dpnp/backend/extensions/statistics/partitioning.hpp b/dpnp/backend/extensions/statistics/partitioning.hpp new file mode 100644 index 000000000000..9c58ae7308c3 --- /dev/null +++ b/dpnp/backend/extensions/statistics/partitioning.hpp @@ -0,0 +1,231 @@ +//***************************************************************************** +// Copyright (c) 2024-2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#pragma once + +#include "utils/math_utils.hpp" +#include +#include + +#include + +#include "ext/common.hpp" + +using dpctl::tensor::usm_ndarray; + +using ext::common::AtomicOp; +using ext::common::IsNan; +using ext::common::Less; +using ext::common::make_ndrange; + +namespace statistics::partitioning +{ + +struct Counters +{ + uint64_t *less_count; + uint64_t *equal_count; + uint64_t *greater_equal_count; + uint64_t *nan_count; + + Counters(sycl::queue &queue) + { + less_count = sycl::malloc_device(1, queue); + equal_count = sycl::malloc_device(1, queue); + greater_equal_count = sycl::malloc_device(1, queue); + nan_count = sycl::malloc_device(1, queue); + }; + + void cleanup(sycl::queue &queue) + { + sycl::free(less_count, queue); + sycl::free(equal_count, queue); + sycl::free(greater_equal_count, queue); + sycl::free(nan_count, queue); + } +}; + +template +struct State +{ + Counters counters; + Counters iteration_counters; + + bool *stop; + bool *target_found; + bool *left; + + T *pivot; + T *values; + + size_t *num_elems; + + size_t n; + + State(sycl::queue &queue, size_t _n, T *values_buff) + : counters(queue), iteration_counters(queue) + { + stop = sycl::malloc_device(1, queue); + target_found = sycl::malloc_device(1, queue); + left = sycl::malloc_device(1, queue); + + pivot = sycl::malloc_device(1, queue); + values = values_buff; + + num_elems = sycl::malloc_device(1, queue); + + n = _n; + } + + sycl::event init(sycl::queue &queue, const std::vector &deps) + { + sycl::event fill_e = + queue.fill(counters.less_count, 0, 1, deps); + fill_e = queue.fill(counters.equal_count, 0, 1, {fill_e}); + fill_e = + queue.fill(counters.greater_equal_count, n, 1, {fill_e}); + fill_e = queue.fill(counters.nan_count, 0, 1, {fill_e}); + fill_e = queue.fill(num_elems, 0, 1, {fill_e}); + fill_e = queue.fill(stop, false, 1, {fill_e}); + fill_e = queue.fill(target_found, false, 1, {fill_e}); + fill_e = queue.fill(left, false, 1, {fill_e}); + fill_e = queue.fill(pivot, 0, 1, {fill_e}); + + return fill_e; + } + + void update_counters() const + { + if (*left) { + counters.less_count[0] -= iteration_counters.greater_equal_count[0]; + counters.greater_equal_count[0] += + iteration_counters.greater_equal_count[0]; + } + else { + counters.less_count[0] += iteration_counters.less_count[0]; + counters.greater_equal_count[0] -= iteration_counters.less_count[0]; + } + counters.equal_count[0] = iteration_counters.equal_count[0]; + counters.nan_count[0] += iteration_counters.nan_count[0]; + } + + void reset_iteration_counters() const + { + iteration_counters.less_count[0] = 0; + iteration_counters.equal_count[0] = 0; + iteration_counters.greater_equal_count[0] = 0; + iteration_counters.nan_count[0] = 0; + } + + void cleanup(sycl::queue &queue) + { + counters.cleanup(queue); + iteration_counters.cleanup(queue); + + sycl::free(stop, queue); + sycl::free(target_found, queue); + sycl::free(left, queue); + + sycl::free(num_elems, queue); + sycl::free(pivot, queue); + } +}; + +template +struct PartitionState +{ + Counters iteration_counters; + + bool *stop; + bool *left; + + T *pivot; + + size_t n; + size_t *num_elems; + + PartitionState(State &state) + : iteration_counters(state.iteration_counters) + { + stop = state.stop; + left = state.left; + + num_elems = state.num_elems; + pivot = state.pivot; + + n = state.n; + } + + sycl::event init(sycl::queue &queue, const std::vector &deps) + { + sycl::event fill_e = + 
queue.fill(iteration_counters.less_count, n, 1, deps); + fill_e = queue.fill(iteration_counters.equal_count, 0, 1, + {fill_e}); + fill_e = queue.fill(iteration_counters.greater_equal_count, 0, + 1, {fill_e}); + fill_e = + queue.fill(iteration_counters.nan_count, 0, 1, {fill_e}); + + return fill_e; + } +}; + +} // namespace statistics::partitioning + +#include "partitioning_one_pivot_kernel_cpu.hpp" +#include "partitioning_one_pivot_kernel_gpu.hpp" + +namespace statistics::partitioning +{ +template +sycl::event run_partition_one_pivot(sycl::queue &exec_q, + T *in, + T *out, + PartitionState &state, + const std::vector &deps) +{ + auto device = exec_q.get_device(); + + if (device.is_gpu()) { + constexpr uint32_t WorkPI = 8; + constexpr uint32_t group_size = 128; + + return run_partition_one_pivot_gpu(exec_q, in, out, state, deps, + group_size, WorkPI); + } + else { + constexpr uint32_t WorkPI = 4; + constexpr uint32_t group_size = 128; + + return run_partition_one_pivot_cpu(exec_q, in, out, state, + deps, group_size); + } +} + +void validate(const usm_ndarray &a, + const usm_ndarray &partitioned, + const size_t k); +} // namespace statistics::partitioning diff --git a/dpnp/backend/extensions/statistics/partitioning_one_pivot_kernel_cpu.hpp b/dpnp/backend/extensions/statistics/partitioning_one_pivot_kernel_cpu.hpp new file mode 100644 index 000000000000..f9ed9039c340 --- /dev/null +++ b/dpnp/backend/extensions/statistics/partitioning_one_pivot_kernel_cpu.hpp @@ -0,0 +1,226 @@ +//***************************************************************************** +// Copyright (c) 2024-2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#pragma once + +#include "utils/math_utils.hpp" +#include +#include + +#include + +#include "ext/common.hpp" + +#include "partitioning.hpp" + +using dpctl::tensor::usm_ndarray; + +using ext::common::AtomicOp; +using ext::common::IsNan; +using ext::common::Less; +using ext::common::make_ndrange; + +namespace statistics::partitioning +{ + +template +struct partition_one_pivot_kernel_cpu; + +template +auto partition_one_pivot_func_cpu(sycl::handler &cgh, + T *in, + T *out, + PartitionState &state) +{ + auto loc_counters = + sycl::local_accessor(sycl::range<1>(4), cgh); + + return [=](sycl::nd_item<1> item) { + if (state.stop[0]) + return; + + auto group = item.get_group(); + uint64_t items_per_group = group.get_local_range(0) * WorkPI; + uint64_t num_elems = state.num_elems[0]; + + if (group.get_group_id(0) * items_per_group >= num_elems) + return; + + T *_in = nullptr; + if (state.left[0]) { + _in = in; + } + else { + _in = in + state.n - num_elems; + } + + auto value = state.pivot[0]; + + auto sbg = item.get_sub_group(); + uint32_t sbg_size = sbg.get_max_local_range()[0]; + + uint64_t i_base = + (item.get_global_linear_id() - sbg.get_local_linear_id()) * WorkPI; + + if (group.leader()) { + loc_counters[0] = 0; + loc_counters[1] = 0; + loc_counters[2] = 0; + } + + sycl::group_barrier(group); + + uint32_t less_count = 0; + uint32_t equal_count = 0; + uint32_t greater_equal_count = 0; + uint32_t nan_count = 0; + + T values[WorkPI]; + uint32_t actual_count = 0; + uint64_t local_i_base = i_base + sbg.get_local_linear_id(); + + for (uint32_t _i = 0; _i < WorkPI; ++_i) { + auto i = local_i_base + _i * sbg_size; + if (i < num_elems) { + values[_i] = _in[i]; + auto is_nan = IsNan::isnan(values[_i]); + less_count += (Less{}(values[_i], value) && !is_nan); + equal_count += (values[_i] == value && !is_nan); + nan_count += is_nan; + actual_count++; + } + } + + greater_equal_count = actual_count - less_count - nan_count; + + auto sbg_less_equal = + sycl::reduce_over_group(sbg, less_count, sycl::plus<>()); + auto sbg_equal = + sycl::reduce_over_group(sbg, equal_count, sycl::plus<>()); + auto sbg_greater = + sycl::reduce_over_group(sbg, greater_equal_count, sycl::plus<>()); + + uint32_t local_less_offset = 0; + uint32_t local_gr_offset = 0; + if (sbg.leader()) { + sycl::atomic_ref + gr_less_eq(loc_counters[0]); + local_less_offset = gr_less_eq.fetch_add(sbg_less_equal); + + sycl::atomic_ref + gr_eq(loc_counters[1]); + gr_eq += sbg_equal; + + sycl::atomic_ref + gr_greater(loc_counters[2]); + local_gr_offset = gr_greater.fetch_add(sbg_greater); + } + + local_less_offset = sycl::select_from_group(sbg, local_less_offset, 0); + local_gr_offset = sycl::select_from_group(sbg, local_gr_offset, 0); + + sycl::group_barrier(group); + + if (group.leader()) { + sycl::atomic_ref + glbl_less_eq(state.iteration_counters.less_count[0]); + auto global_less_eq_offset = + glbl_less_eq.fetch_add(loc_counters[0]); + + sycl::atomic_ref + glbl_eq(state.iteration_counters.equal_count[0]); + glbl_eq += loc_counters[1]; + + sycl::atomic_ref + glbl_greater(state.iteration_counters.greater_equal_count[0]); + auto global_gr_offset = glbl_greater.fetch_add(loc_counters[2]); + + loc_counters[0] = global_less_eq_offset; + loc_counters[2] = global_gr_offset; + } + + sycl::group_barrier(group); + + auto sbg_less_offset = loc_counters[0] + local_less_offset; + auto sbg_gr_offset = + state.n - (loc_counters[2] + local_gr_offset + sbg_greater); + + uint32_t 
le_item_offset = 0; + uint32_t gr_item_offset = 0; + + for (uint32_t _i = 0; _i < WorkPI; ++_i) { + uint32_t is_nan = IsNan::isnan(values[_i]); + uint32_t less = (!is_nan && Less{}(values[_i], value)); + auto le_pos = + sycl::exclusive_scan_over_group(sbg, less, sycl::plus<>()); + auto ge_pos = sbg.get_local_linear_id() - le_pos; + + auto total_le = sycl::reduce_over_group(sbg, less, sycl::plus<>()); + auto total_nan = + sycl::reduce_over_group(sbg, is_nan, sycl::plus<>()); + auto total_gr = sbg_size - total_le - total_nan; + + if (_i < actual_count) { + if (less) { + out[sbg_less_offset + le_item_offset + le_pos] = values[_i]; + } + else if (!is_nan) { + out[sbg_gr_offset + gr_item_offset + ge_pos] = values[_i]; + } + le_item_offset += total_le; + gr_item_offset += total_gr; + } + } + }; +} + +template +sycl::event run_partition_one_pivot_cpu(sycl::queue &exec_q, + T *in, + T *out, + PartitionState &state, + const std::vector &deps, + uint32_t group_size) +{ + auto e = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + + auto work_range = make_ndrange(state.n, group_size, WorkPI); + + cgh.parallel_for>( + work_range, + partition_one_pivot_func_cpu(cgh, in, out, state)); + }); + + return e; +} + +} // namespace statistics::partitioning diff --git a/dpnp/backend/extensions/statistics/partitioning_one_pivot_kernel_gpu.hpp b/dpnp/backend/extensions/statistics/partitioning_one_pivot_kernel_gpu.hpp new file mode 100644 index 000000000000..cbe0ed46e4d0 --- /dev/null +++ b/dpnp/backend/extensions/statistics/partitioning_one_pivot_kernel_gpu.hpp @@ -0,0 +1,234 @@ +//***************************************************************************** +// Copyright (c) 2024-2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#pragma once + +#include "utils/math_utils.hpp" +#include +#include + +#include + +#include "ext/common.hpp" + +#include "partitioning.hpp" + +using dpctl::tensor::usm_ndarray; + +using ext::common::AtomicOp; +using ext::common::IsNan; +using ext::common::Less; +using ext::common::make_ndrange; + +namespace statistics::partitioning +{ + +template +struct partition_one_pivot_kernel_gpu; + +template +auto partition_one_pivot_func_gpu(sycl::handler &cgh, + T *in, + T *out, + PartitionState &state, + uint32_t group_size, + uint32_t WorkPI) +{ + auto loc_counters = + sycl::local_accessor(sycl::range<1>(4), cgh); + auto loc_global_counters = + sycl::local_accessor(sycl::range<1>(2), cgh); + auto loc_items = + sycl::local_accessor(sycl::range<1>(WorkPI * group_size), cgh); + + return [=](sycl::nd_item<1> item) { + if (state.stop[0]) + return; + + auto group = item.get_group(); + auto group_range = group.get_local_range(0); + auto llid = item.get_local_linear_id(); + uint64_t items_per_group = group.get_local_range(0) * WorkPI; + uint64_t num_elems = state.num_elems[0]; + + if (group.get_group_id(0) * items_per_group >= num_elems) + return; + + T *_in = nullptr; + if (state.left[0]) { + _in = in; + } + else { + _in = in + state.n - num_elems; + } + + auto value = state.pivot[0]; + + auto sbg = item.get_sub_group(); + + uint32_t sbg_size = sbg.get_max_local_range()[0]; + uint32_t sbg_work_size = sbg_size * WorkPI; + uint32_t sbg_llid = sbg.get_local_linear_id(); + uint64_t i_base = (item.get_global_linear_id() - sbg_llid) * WorkPI; + + if (group.leader()) { + loc_counters[0] = 0; + loc_counters[1] = 0; + loc_counters[2] = 0; + } + + sycl::group_barrier(group); + + for (uint32_t _i = 0; _i < WorkPI; ++_i) { + uint32_t less_count = 0; + uint32_t equal_count = 0; + uint32_t greater_equal_count = 0; + + uint32_t actual_count = 0; + auto i = i_base + _i * sbg_size + sbg_llid; + uint32_t valid = i < num_elems; + auto val = valid ? 
_in[i] : 0; + uint32_t less = (val < value) && valid; + uint32_t equal = (val == value) && valid; + + auto le_pos = + sycl::exclusive_scan_over_group(sbg, less, sycl::plus<>()); + auto ge_pos = sbg.get_local_linear_id() - le_pos; + auto sbg_less_equal = + sycl::reduce_over_group(sbg, less, sycl::plus<>()); + auto sbg_equal = + sycl::reduce_over_group(sbg, equal, sycl::plus<>()); + auto tot_valid = + sycl::reduce_over_group(sbg, valid, sycl::plus<>()); + auto sbg_greater = tot_valid - sbg_less_equal; + + uint32_t local_less_offset = 0; + uint32_t local_gr_offset = 0; + + if (sbg.leader()) { + sycl::atomic_ref + gr_less_eq(loc_counters[0]); + local_less_offset = gr_less_eq.fetch_add(sbg_less_equal); + + sycl::atomic_ref + gr_eq(loc_counters[1]); + gr_eq += sbg_equal; + + sycl::atomic_ref + gr_greater(loc_counters[2]); + local_gr_offset = gr_greater.fetch_add(sbg_greater); + } + + uint32_t local_less_offset_ = + sycl::select_from_group(sbg, local_less_offset, 0); + uint32_t local_gr_offset_ = + sycl::select_from_group(sbg, local_gr_offset, 0); + + if (valid) { + if (less) { + uint32_t ll_offset = local_less_offset_ + le_pos; + loc_items[ll_offset] = val; + } + else { + auto loc_gr_offset = group_range * WorkPI - + local_gr_offset_ - sbg_greater + + ge_pos; + loc_items[loc_gr_offset] = val; + } + } + } + + sycl::group_barrier(group); + + if (group.leader()) { + sycl::atomic_ref + glbl_less_eq(state.iteration_counters.less_count[0]); + auto global_less_eq_offset = + glbl_less_eq.fetch_add(loc_counters[0]); + + sycl::atomic_ref + glbl_eq(state.iteration_counters.equal_count[0]); + glbl_eq += loc_counters[1]; + + sycl::atomic_ref + glbl_greater(state.iteration_counters.greater_equal_count[0]); + auto global_gr_offset = glbl_greater.fetch_add(loc_counters[2]); + + loc_global_counters[0] = global_less_eq_offset; + loc_global_counters[1] = global_gr_offset + loc_counters[2]; + } + + sycl::group_barrier(group); + + auto global_less_eq_offset = loc_global_counters[0]; + auto global_gr_offset = state.n - loc_global_counters[1]; + + uint32_t sbg_id = sbg.get_group_id(); + for (uint32_t _i = 0; _i < WorkPI; ++_i) { + uint32_t i = sbg_id * sbg_size * WorkPI + _i * sbg_size + sbg_llid; + if (i < loc_counters[0]) { + out[global_less_eq_offset + i] = loc_items[i]; + } + else if (i < loc_counters[0] + loc_counters[2]) { + auto global_gr_offset_ = global_gr_offset + i - loc_counters[0]; + uint32_t local_buff_offset = WorkPI * group_range - + loc_counters[2] + i - + loc_counters[0]; + + out[global_gr_offset_] = loc_items[local_buff_offset]; + } + } + }; +} + +template +sycl::event run_partition_one_pivot_gpu(sycl::queue &exec_q, + T *in, + T *out, + PartitionState &state, + const std::vector &deps, + uint32_t group_size, + uint32_t WorkPI) +{ + auto e = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + + auto work_range = make_ndrange(state.n, group_size, WorkPI); + + cgh.parallel_for>( + work_range, partition_one_pivot_func_gpu(cgh, in, out, state, + group_size, WorkPI)); + }); + + return e; +} + +} // namespace statistics::partitioning diff --git a/dpnp/backend/extensions/statistics/statistics_py.cpp b/dpnp/backend/extensions/statistics/statistics_py.cpp index 6636d3f7d531..757ec85c6222 100644 --- a/dpnp/backend/extensions/statistics/statistics_py.cpp +++ b/dpnp/backend/extensions/statistics/statistics_py.cpp @@ -32,12 +32,14 @@ #include "bincount.hpp" #include "histogram.hpp" #include "histogramdd.hpp" +#include "kth_element1d.hpp" #include "sliding_dot_product1d.hpp" 
 PYBIND11_MODULE(_statistics_impl, m)
 {
     statistics::histogram::populate_bincount(m);
     statistics::histogram::populate_histogram(m);
+    statistics::partitioning::populate_kth_element1d(m);
     statistics::sliding_window1d::populate_sliding_dot_product1d(m);
     statistics::histogram::populate_histogramdd(m);
 }
diff --git a/dpnp/dpnp_utils/dpnp_utils_statistics.py b/dpnp/dpnp_utils/dpnp_utils_statistics.py
index 108fda7286fc..98cafefd5421 100644
--- a/dpnp/dpnp_utils/dpnp_utils_statistics.py
+++ b/dpnp/dpnp_utils/dpnp_utils_statistics.py
@@ -27,11 +27,14 @@
 import dpctl
 import dpctl.tensor as dpt
+import dpctl.utils as dpu
 from dpctl.tensor._numpy_helper import normalize_axis_tuple
 from dpctl.utils import ExecutionPlacementError
 
 import dpnp
+import dpnp.backend.extensions.statistics._statistics_impl as statistics_ext
 from dpnp.dpnp_array import dpnp_array
+from dpnp.dpnp_utils.dpnp_utils_common import to_supported_dtypes
 
 __all__ = ["dpnp_cov", "dpnp_median"]
 
@@ -191,6 +194,77 @@ def dpnp_cov(
     return c.squeeze()
 
 
+def native_median(a, ignore_nan):
+    a = dpnp.reshape(a, a.size)
+    device = a.sycl_device
+
+    result_dtype = dpnp.default_float_type()
+    if dpnp.issubdtype(a.dtype, dpnp.complexfloating):
+        result_dtype = a.dtype
+
+    if a.size == 0:
+        return dpnp.array(dpnp.nan, ndmin=1, dtype=result_dtype)
+    elif a.size == 1:
+        return dpnp.array(a[0], ndmin=1, dtype=result_dtype)
+
+    supported_types = statistics_ext.kth_element_dtypes()
+    supported_dtype = to_supported_dtypes(a.dtype, supported_types, device)
+
+    if supported_dtype is None:  # pragma: no cover
+        raise ValueError(
+            f"function does not support input type "
+            f"{a.dtype.name}, "
+            "and the input could not be coerced to any "
+            f"supported type. List of supported types: "
+            f"{[st.name for st in supported_types]}"
+        )
+
+    a_casted = dpnp.asarray(a, dtype=supported_dtype, order="C")
+
+    partitioned = dpnp.empty_like(a_casted)
+
+    a_usm = dpnp.get_usm_ndarray(a_casted)
+    partitioned_usm = dpnp.get_usm_ndarray(partitioned)
+
+    _manager = dpu.SequentialOrderManager[a.sycl_queue]
+
+    result = dpnp.empty_like(a, dtype=result_dtype, shape=1)
+
+    nans = 0
+    if ignore_nan:
+        nans = dpnp.isnan(a_usm).sum()
+    k = (a.shape[0] - 1 - nans) // 2
+
+    found, buff_offset, elems_offset, num_elems, nan_count = (
+        statistics_ext.kth_element(
+            a_usm,
+            partitioned_usm,
+            k,
+            depends=_manager.submitted_events,
+        )
+    )
+
+    if not ignore_nan and nan_count > 0:
+        return dpnp.array(dpnp.nan, ndmin=1, dtype=result_dtype)
+
+    if found:
+        if a.shape[0] % 2 == 0:
+            # even number of elements
+            result[0] = (partitioned[0] + partitioned[1]) / 2
+        else:
+            result[0] = partitioned[0]
+    else:
+        partitioned[buff_offset : buff_offset + num_elems].sort()
+        kth_idx = buff_offset + k - elems_offset
+        if a.shape[0] % 2 == 0:
+            # even number of elements
+            result[0] = (partitioned[kth_idx] + partitioned[kth_idx + 1]) / 2
+        else:
+            result[0] = partitioned[kth_idx]
+
+    return result
+
+
 def dpnp_median(
     a,
     axis=None,
@@ -201,6 +275,13 @@
 ):
     """Compute the median of an array along a specified axis."""
 
+    if axis is None or a.ndim == 1:
+        result = native_median(a, ignore_nan)
+        if not keepdims:
+            return result[0]
+
+        return result.reshape((1,) * a.ndim)
+
     a_ndim = a.ndim
     a_shape = a.shape
     _axis = range(a_ndim) if axis is None else axis
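For orientation, the selection strategy implemented by kth_element1d.cpp and the partition kernels above can be condensed into the host-side NumPy sketch below. It is an illustration under simplifying assumptions, not part of the patch: the name kth_pair, the small cutoff, and the strided pivot sampling are stand-ins for the device logic, which samples roughly items_to_sort = 127 candidates, sorts small chunks in local memory with joint_sort, and keeps less/greater-equal/NaN counters in USM. The two values returned are exactly what native_median consumes: element k for an odd count, or the mean of elements k and k + 1 for an even count.

import numpy as np


def kth_pair(a, k, small=512):
    """Return elements k and k + 1 of `a` in ascending order (NaNs dropped)."""
    chunk = np.asarray(a, dtype=float).ravel()
    chunk = chunk[~np.isnan(chunk)]  # the kernels count NaNs and keep them out of the search
    if k + 1 >= chunk.size:
        raise ValueError("need at least k + 2 non-NaN elements")
    offset = 0  # number of elements already known to rank below everything in `chunk`
    while chunk.size > small:
        # crude stand-in for the pivot the kernel picks from ~127 sampled candidates
        pivot = np.median(chunk[:: max(1, chunk.size // 127)])
        less = chunk[chunk < pivot]
        rest = chunk[chunk >= pivot]
        if less.size == 0 or rest.size == 0:
            break  # degenerate pivot; fall back to sorting what is left
        if k + 1 < offset + less.size:  # both targets on the "< pivot" side
            chunk = less
        elif k >= offset + less.size:  # both targets on the ">= pivot" side
            offset += less.size
            chunk = rest
        else:  # k is the largest "< pivot" element, k + 1 the smallest of the rest
            return less.max(), rest.min()
    chunk = np.sort(chunk)
    return chunk[k - offset], chunk[k - offset + 1]


values = np.random.rand(100_001)
lo, hi = kth_pair(values, (values.size - 1) // 2)
median = lo if values.size % 2 else (lo + hi) / 2
assert np.isclose(median, np.median(values))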
diff --git a/dpnp/tests/test_statistics.py b/dpnp/tests/test_statistics.py
index cf436087b607..f41fd6e850b8 100644
--- a/dpnp/tests/test_statistics.py
+++ b/dpnp/tests/test_statistics.py
@@ -979,25 +979,6 @@ def test_nan(self, axis, keepdims):
 
         assert_dtype_allclose(result, expected)
 
-    @pytest.mark.parametrize("axis", [None, 0, -1, (0, -2, -1)])
-    @pytest.mark.parametrize("keepdims", [True, False])
-    def test_overwrite_input(self, axis, keepdims):
-        a = generate_random_numpy_array((2, 3, 4))
-        ia = dpnp.array(a)
-
-        b = a.copy()
-        ib = ia.copy()
-        expected = numpy.median(
-            b, axis=axis, keepdims=keepdims, overwrite_input=True
-        )
-        result = dpnp.median(
-            ib, axis=axis, keepdims=keepdims, overwrite_input=True
-        )
-        assert not numpy.all(a == b)
-        assert not dpnp.all(ia == ib)
-
-        assert_dtype_allclose(result, expected)
-
     @pytest.mark.parametrize("axis", [None, 0, (-1,), [0, 1]])
     @pytest.mark.parametrize("overwrite_input", [True, False])
     def test_usm_ndarray(self, axis, overwrite_input):
@@ -1008,6 +989,9 @@ def test_usm_ndarray(self, axis, overwrite_input):
         result = dpnp.median(ia, axis=axis, overwrite_input=overwrite_input)
         assert_dtype_allclose(result, expected)
 
+        if not overwrite_input:
+            assert_dtype_allclose(ia, a)
+
 
 class TestPtp:
     @pytest.mark.parametrize("axis", [None, 0, 1])
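As a hedged host-side cross-check of the new axis=None / 1-D path (for local experimentation only, not part of the patch): the snippet below assumes a default SYCL device is available, uses numpy.testing.assert_allclose instead of the repository's assert_dtype_allclose helper, and exercises both dpnp.median and dpnp.nanmedian, the latter being the consumer of the ignore_nan branch in native_median.

import numpy
import dpnp
from numpy.testing import assert_allclose

cases = (
    [3.0, 1.0, 2.0],             # odd length
    [5.0, 1.0, 4.0, 2.0],        # even length: mean of the two middle elements
    [4.0, 1.0, numpy.nan, 2.0],  # NaN propagates for median, is skipped by nanmedian
)
for data in cases:
    a = numpy.array(data)
    ia = dpnp.array(a)
    assert_allclose(dpnp.asnumpy(dpnp.median(ia)), numpy.median(a), equal_nan=True)
    assert_allclose(dpnp.asnumpy(dpnp.nanmedian(ia)), numpy.nanmedian(a), equal_nan=True)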