|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <array> |
| 4 | +#include <tuple> |
| 5 | +#include <type_traits> |
| 6 | +#include <utility> |
| 7 | + |
| 8 | +#include "defines.h" |
| 9 | + |
| 10 | +namespace finufft { |
| 11 | +namespace common { |
| 12 | + |
| 13 | +FINUFFT_EXPORT void FINUFFT_CDECL gaussquad(int n, double *xgl, double *wgl); |
| 14 | +std::tuple<double, double> leg_eval(int n, double x); |
| 15 | + |
// Adds the constant Offset to every element of a std::integer_sequence<int, ...>.
// Only the integer_sequence specialization below is defined; instantiating
// offset_seq<...>::type with any other Seq is a compile error.
template<int Offset, typename Seq> struct offset_seq;

template<int Shift, int... Values>
struct offset_seq<Shift, std::integer_sequence<int, Values...>> {
  using type = std::integer_sequence<int, (Values + Shift)...>;
};

// Compile-time integer sequence covering the closed range [Start, End], i.e.
// Start, Start + 1, ..., End. It is built by generating 0 .. End - Start and
// shifting each element by Start.
template<int Start, int End>
using make_range =
    typename offset_seq<Start, std::make_integer_sequence<int, End - Start + 1>>::type;
| 27 | + |
// Pairs one runtime integer with the compile-time candidates it is matched against.
//   Seq         - type listing the admissible compile-time values, re-exported
//                 as seq_type
//   runtime_val - the value that is only known at run time
template<typename Seq> struct DispatchParam {
  using seq_type = Seq;
  int runtime_val;
};
| 32 | + |
| 33 | +// Cartesian product over integer sequences. |
| 34 | +// Invokes f.template operator()<...>() for each combination of values. |
| 35 | +// The functor F must provide a templated call operator. |
| 36 | +namespace detail { |
| 37 | + |
// Compile-time Cartesian product of std::integer_sequence<int, ...> types.
// Product<F, Seqs...>::apply<>(f) calls f.template operator()<v1, ..., vk>() once
// for every combination that takes one value from each sequence. The first
// sequence is the outermost loop, so it varies slowest.
template<typename F, typename... Seq> struct Product;

// Recursive step (two or more sequences left): for each value of the leading
// sequence, append it to the values chosen so far and recurse on the remainder.
template<typename F, int... Head, typename Next, typename... Tail>
struct Product<F, std::integer_sequence<int, Head...>, Next, Tail...> {
  template<int... Chosen> static void apply(F &f) {
    using Remaining = Product<F, Next, Tail...>;
    (Remaining::template apply<Chosen..., Head>(f), ...);
  }
};

// Terminal step (one sequence left): each of its values completes a combination,
// which is handed to the functor's templated call operator.
template<typename F, int... Last> struct Product<F, std::integer_sequence<int, Last...>> {
  template<int... Chosen> static void apply(F &f) {
    (f.template operator()<Chosen..., Last>(), ...);
  }
};
| 54 | + |
| 55 | +template<typename F, typename... Seq> void product(F &f, Seq...) { |
| 56 | + Product<F, Seq...>::template apply<>(f); |
| 57 | +} |
| 58 | + |
// Functor called once per compile-time combination Params... . When the combination
// equals the runtime values in vals, func.template operator()<Params...> is invoked
// with the elements of args; otherwise the call is a no-op.
//   func   - callable providing a templated call operator
//   vals   - runtime values, one per template parameter
//   args   - tuple of arguments passed on to func
//   result - stores the value returned by func; value-initialized until a match is
//            found (a char placeholder when ResultType is void)
template<typename Func, std::size_t N, typename ArgTuple, typename ResultType>
struct DispatcherCaller {
  Func &func;
  const std::array<int, N> &vals;
  ArgTuple &args;
  std::conditional_t<std::is_void_v<ResultType>, char, ResultType> result{};
  template<int... Params> void operator()() {
    static constexpr std::array<int, sizeof...(Params)> candidate{Params...};
    // Not the requested combination: nothing to do.
    if (candidate != vals) return;
    if constexpr (std::is_void_v<ResultType>) {
      // Nothing to store; just run the selected instantiation.
      std::apply(
          [this](auto &&...unpacked) {
            func.template operator()<Params...>(
                std::forward<decltype(unpacked)>(unpacked)...);
          },
          args);
    } else {
      result = std::apply(
          [this](auto &&...unpacked) {
            return func.template operator()<Params...>(
                std::forward<decltype(unpacked)>(unpacked)...);
          },
          args);
    }
  }
};
| 85 | + |
// Exposes the first element of a non-empty std::integer_sequence<int, ...> as a
// std::integral_constant (so ::value is that element). There is no definition for
// an empty sequence, so using one is a compile error.
template<typename Seq> struct seq_first;
template<int Head, int... Tail>
struct seq_first<std::integer_sequence<int, Head, Tail...>>
    : std::integral_constant<int, Head> {};
| 90 | + |
// Reads the runtime_val member of the tuple elements selected by Idx... and returns
// them, in that order, as a std::array<int, sizeof...(Idx)>.
template<typename Tuple, std::size_t... Idx>
auto extract_vals_impl(const Tuple &params, std::index_sequence<Idx...>) {
  using values_t = std::array<int, sizeof...(Idx)>;
  return values_t{std::get<Idx>(params).runtime_val...};
}
// Collects runtime_val from every element of the tuple, in tuple order.
template<typename Tuple> auto extract_vals(const Tuple &params) {
  constexpr std::size_t count = std::tuple_size_v<std::remove_reference_t<Tuple>>;
  return extract_vals_impl(params, std::make_index_sequence<count>{});
}
| 99 | + |
// Builds a tuple holding one value-initialized seq_type object for each tuple
// element selected by Idx... . The tuple argument is used for its type only.
template<typename Tuple, std::size_t... Idx>
auto extract_seqs_impl(const Tuple &, std::index_sequence<Idx...>) {
  using params_t = std::remove_reference_t<Tuple>;
  return std::make_tuple(typename std::tuple_element_t<Idx, params_t>::seq_type{}...);
}
// Collects the seq_type of every element of the tuple, in tuple order.
template<typename Tuple> auto extract_seqs(const Tuple &params) {
  constexpr std::size_t count = std::tuple_size_v<std::remove_reference_t<Tuple>>;
  return extract_seqs_impl(params, std::make_index_sequence<count>{});
}
| 109 | + |
| 110 | +template<typename Func, typename ArgTuple, typename... Seq> |
| 111 | +struct dispatch_result_helper { |
| 112 | + template<std::size_t... I> |
| 113 | + static auto test(std::index_sequence<I...>) |
| 114 | + -> decltype(std::declval<Func>().template operator()<seq_first<Seq>::value...>( |
| 115 | + std::get<I>(std::declval<ArgTuple>())...)); |
| 116 | + using type = decltype(test(std::make_index_sequence<std::tuple_size_v<ArgTuple>>{})); |
| 117 | +}; |
| 118 | +template<typename Func, typename ArgTuple, typename SeqTuple> struct dispatch_result; |
| 119 | +template<typename Func, typename ArgTuple, typename... Seq> |
| 120 | +struct dispatch_result<Func, ArgTuple, std::tuple<Seq...>> { |
| 121 | + using type = typename dispatch_result_helper<Func, ArgTuple, Seq...>::type; |
| 122 | +}; |
| 123 | +template<typename Func, typename ArgTuple, typename SeqTuple> |
| 124 | +using dispatch_result_t = typename dispatch_result<Func, ArgTuple, SeqTuple>::type; |
| 125 | + |
| 126 | +} // namespace detail |
| 127 | + |
| 128 | +// Generic dispatcher mapping runtime ints to template parameters. |
| 129 | +// params is a tuple of DispatchParam holding runtime values and sequences. |
| 130 | +// When a match is found, the functor is invoked with those template parameters |
| 131 | +// and its result returned. Otherwise, the default-constructed result is returned. |
| 132 | +template<typename Func, typename ParamTuple, typename... Args> |
| 133 | +decltype(auto) dispatch(Func &&func, ParamTuple &¶ms, Args &&...args) { |
| 134 | + using tuple_t = std::remove_reference_t<ParamTuple>; |
| 135 | + constexpr std::size_t N = std::tuple_size_v<tuple_t>; |
| 136 | + auto vals = detail::extract_vals(params); |
| 137 | + auto seqs = detail::extract_seqs(params); |
| 138 | + auto arg_tuple = std::forward_as_tuple(std::forward<Args>(args)...); |
| 139 | + using result_t = detail::dispatch_result_t<Func, decltype(arg_tuple), decltype(seqs)>; |
| 140 | + detail::DispatcherCaller<Func, N, decltype(arg_tuple), result_t> caller{func, vals, |
| 141 | + arg_tuple}; |
| 142 | + std::apply([&](auto &&...s) { detail::product(caller, s...); }, seqs); |
| 143 | + if constexpr (!std::is_void_v<result_t>) return caller.result; |
| 144 | +} |
| 145 | + |
| 146 | +} // namespace common |
| 147 | +} // namespace finufft |
0 commit comments