Skip to content

[Backends] implement kernel call for more kernel with handle #897

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions src/shambackends/include/shambackends/kernel_call.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
*/

#include "shambase/optional.hpp"
#include "shambase/type_traits.hpp"
#include "shambackends/DeviceBuffer.hpp"
#include <shambackends/sycl.hpp>
#include <functional>
#include <optional>

namespace sham {
Expand Down Expand Up @@ -295,6 +298,64 @@ namespace sham {
in.complete_event_state(e);
in_out.complete_event_state(e);
}

/// internal implementation of typed_index_kernel_call
template<class index_t, class RefIn, class RefOut, class... Targs, class Functor>
void typed_index_kernel_call_handle(
sham::DeviceQueue &q,
RefIn in,
RefOut in_out,
index_t group_size,
index_t nthreads,
Functor &&func,
Targs... args) {

if (nthreads == 0) {
shambase::throw_with_loc<std::runtime_error>("kernel call with : n == 0");
}

if (group_size == 0) {
shambase::throw_with_loc<std::runtime_error>("kernel call with : group_size == 0");
}

if (nthreads % group_size != 0) {
shambase::throw_with_loc<std::runtime_error>(
"kernel call with : nthreads % group_size != 0");
}

StackEntry stack_loc{};
sham::EventList depends_list;

auto acc_in = in.get_read_access(depends_list);
auto acc_in_out = in_out.get_write_access(depends_list);

auto e = q.submit(depends_list, [&](sycl::handler &cgh) {
auto kernel = std::apply(
[&](auto &...__acc_in) {
return std::apply(
[&](auto &...__acc_in_out) {
shambase::check_functor_signature_deduce_noreturn_add_t<
sycl::handler &>(func, __acc_in..., __acc_in_out..., args...);

return func(cgh, __acc_in..., __acc_in_out..., args...);
},
acc_in_out);
},
acc_in);

cgh.parallel_for(sycl::nd_range<1>{nthreads, group_size}, [=](sycl::nd_item<1> id) {
index_t local_id = id.get_local_id(0);
index_t group_tile_id = id.get_group_linear_id();

shambase::check_functor_signature_deduce<void>(kernel, group_tile_id, local_id);

kernel(group_tile_id, local_id);
});
});

in.complete_event_state(e);
in_out.complete_event_state(e);
}
} // namespace details

/**
Expand Down Expand Up @@ -476,4 +537,31 @@ namespace sham {
q, in, in_out, n, std::forward<Functor>(func), args...);
}

template<class RefIn, class RefOut, class... Targs, class Functor>
void kernel_call_handle(
sham::DeviceQueue &q,
RefIn in,
RefOut in_out,
u32 group_size,
u32 nthreads,
Functor &&func,
Targs... args) {
details::typed_index_kernel_call_handle<u32, RefIn, RefOut>(
q, in, in_out, group_size, nthreads, std::forward<Functor>(func), args...);
}

/// u64 indexed variant of kernel_call
template<class RefIn, class RefOut, class... Targs, class Functor>
void kernel_call_handle_u64(
sham::DeviceQueue &q,
RefIn in,
RefOut in_out,
u64 group_size,
u64 nthreads,
Functor &&func,
Targs... args) {
details::typed_index_kernel_call_handle<u64, RefIn, RefOut>(
q, in, in_out, group_size, nthreads, std::forward<Functor>(func), args...);
}

} // namespace sham
41 changes: 41 additions & 0 deletions src/shambase/include/shambase/type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,19 @@ namespace shambase {
"the signature of typed_false_v<...>.");
}

/// variant of check_functor_signature that does not check the return type
template<class... Targ, class Func>
constexpr void check_functor_signature_noreturn(Func &&func) {

using signature = void(std::remove_reference_t<Targ>...);

constexpr bool result_call = std::is_invocable_v<decltype(func), Targ...>;
static_assert(
typed_false_v<result_call, signature>,
"The lambda signature is incorrect, the correct function signature is indicated in the "
"signature of typed_false_v<...>, aside for the return type.");
}

/**
* @brief Check if a callable object has the correct deduced signature.
*
Expand All @@ -343,4 +356,32 @@ namespace shambase {
check_functor_signature<RetType, Targ...>(func);
}

/// variant of check_functor_signature_deduce that does not check the return type
template<class... Targ, class Func>
constexpr void check_functor_signature_deduce_noreturn(Func &&func, Targ...) {
check_functor_signature_noreturn<Targ...>(func);
}

/**
* @brief variant of check_functor_signature_deduce that does not check the return type and
* where some types can be specified manually
*
* For example when using a type with a reference, this is not properly deduced.
* This functions allows to specify the type manually with the reference.
*
* @code {.cpp}
* shambase::check_functor_signature_deduce_noreturn_add_t<sycl::handler&>(
* func,
* __acc_in...,
* __acc_in_out...,
* args...);
* @endcode
*
* Here the function signature will be auto(sycl::handler&, <other types>)
*/
template<class... Targ2, class... Targ, class Func>
constexpr void check_functor_signature_deduce_noreturn_add_t(Func &&func, Targ...) {
check_functor_signature_noreturn<Targ2..., Targ...>(func);
}

} // namespace shambase
73 changes: 73 additions & 0 deletions src/shammodels/sph/src/modules/ComputeEos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,72 @@
#include "shamrock/scheduler/SchedulerUtility.hpp"
#include "shamsys/legacy/log.hpp"

template<class T>
struct PatchDataFieldSpan{

using value_type = T;

PatchDataField<T> &field;
u32 nvar;

u32 begin;
u32 len;

PatchDataFieldSpan(PatchDataField<T> &field, u32 begin, u32 len) : field(field), begin(begin), len(len), nvar(field.get_nvar()) {

if(len+begin > field.get_obj_cnt()){
shambase::throw_with_loc<std::runtime_error>("span out of bounds");
}

}

struct accessed_read_only {
const T *ptr;

T operator()(u32 i, u32 loc_val,u32 nvar) const {
return ptr[i*nvar + loc_val];
}
};

struct accessed_read_write {
T *ptr;

T operator()(u32 i, u32 loc_val,u32 nvar) const {
return ptr[i*nvar + loc_val];
}
};

accessed_read_only get_read_access(sham::EventList &depends_list) {
auto ptr = field.get_buf().get_read_access(depends_list);
return accessed_read_only{ptr + begin*nvar};
}

accessed_read_write get_write_access(sham::EventList &depends_list) {
auto ptr = field.get_buf().get_read_write_access(depends_list);
return accessed_read_write{ptr + begin*nvar};
}

void complete_event_state(sycl::event e) { field.get_buf().complete_event_state(e);}

};

template<class T>
struct DistributedFieldRef{

shambase::DistributedData<std::reference_wrapper<PatchDataFieldSpan<T>>> fields;

template<class T2>
static DistributedFieldRef<T> from(std::function<std::reference_wrapper<PatchDataFieldSpan<T>>(u64 i, T2 &)> mapper, shambase::DistributedData<T2>& other){

DistributedFieldRef<T> ret;
ret.fields = other.template map<std::reference_wrapper<PatchDataFieldSpan<T>>>(mapper);
return ret;
}

};



template<class Tvec, template<class> class SPHKernel>
void shammodels::sph::modules::ComputeEos<Tvec, SPHKernel>::compute_eos() {

Expand Down Expand Up @@ -60,6 +126,13 @@ void shammodels::sph::modules::ComputeEos<Tvec, SPHKernel>::compute_eos() {

using EOS = shamphys::EOS_Isothermal<Tscal>;

DistributedFieldRef<Tscal> pressure_field = DistributedFieldRef<Tscal>::from(
[&](u64 id, EOS &eos) {
return std::cref(storage.pressure.get().get_field(id));
},
storage.pressure.get()
);

storage.merged_patchdata_ghost.get().for_each([&](u64 id, MergedPatchData &mpdat) {
sham::DeviceBuffer<Tscal> &buf_P = storage.pressure.get().get_buf_check(id);
sham::DeviceBuffer<Tscal> &buf_cs = storage.soundspeed.get().get_buf_check(id);
Expand Down
74 changes: 74 additions & 0 deletions src/tests/shambackends/kernel_call.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2024 Timothée David--Cléris <[email protected]>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

#include "shambackends/kernel_call.hpp"
#include "shambackends/DeviceBuffer.hpp"
#include "shamsys/NodeInstance.hpp"
#include "shamtest/shamtest.hpp"
#include <vector>

TestStart(Unittest, "sham::kernel_call", testkernel_call, 1) {

auto sched_ptr = shamsys::instance::get_compute_scheduler_ptr();
auto queue = sched_ptr->get_queue();

using T = f64;
sham::DeviceBuffer<T> buf{100, sched_ptr};

buf.fill(1.0);

sham::DeviceBuffer<T> buf2{100, sched_ptr};

sham::kernel_call(
queue, sham::MultiRef{buf}, sham::MultiRef{buf2}, 100, [](u32 i, const T *buf, T *buf2) {
buf2[i] = buf[i];
});

std::vector res = buf2.copy_to_stdvec();

for (u32 i = 0; i < 100; i++) {
shamtest::asserts().assert_equal("check", res[i], 1.0);
}
}

TestStart(Unittest, "sham:kernel_call_handle", testkernel_call_handle, 1) {

auto sched_ptr = shamsys::instance::get_compute_scheduler_ptr();
auto queue = sched_ptr->get_queue();

using T = f64;
sham::DeviceBuffer<T> buf{128, sched_ptr};

buf.fill(1.0);

sham::DeviceBuffer<T> buf2{128, sched_ptr};

u32 group_size = 16;
sham::kernel_call_handle(
queue,
sham::MultiRef{buf},
sham::MultiRef{buf2},
group_size,
128,
[group_size](sycl::handler &cgh, const T *buf, T *buf2) {
sycl::local_accessor<T> local_buf{16, cgh};

return [=](u32 group_id, u32 local_id) {
u32 i = group_id * group_size + local_id;
local_buf[local_id] = buf[i];
buf2[i] = local_buf[local_id];
};
});

std::vector res = buf2.copy_to_stdvec();

for (u32 i = 0; i < 128; i++) {
shamtest::asserts().assert_equal("check", res[i], 1.0);
}
}
Loading