Skip to content

Commit 180beb9

Browse files
committed
Using c++17 template dispatching for kernels
Co-authored-by: Nils Wentzell <[email protected]>
1 parent c9ba3ec commit 180beb9

27 files changed

+514
-470
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ if(FINUFFT_USE_CPU)
262262
src/finufft_core.cpp
263263
src/c_interface.cpp
264264
src/finufft_utils.cpp
265+
src/utils.cpp
265266
)
266267

267268
if(FINUFFT_BUILD_FORTRAN)

include/common/common.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
2+
3+
#include <common/constants.h>
4+
#include <common/defines.h>
5+
#include <common/utils.h>

include/common/constants.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#pragma once
2+
3+
namespace finufft {
4+
namespace common {
5+
6+
// constants needed within common
7+
// upper bound on w, ie nspread, even when padded (see evaluate_kernel_vector);
8+
// also for common
9+
inline constexpr int MIN_NSPREAD = 2;
10+
inline constexpr int MAX_NSPREAD = 16;
11+
// max number of positive quadr nodes
12+
inline constexpr int MAX_NQUAD = 100;
13+
// Fraction growth cut-off in utils:arraywidcen, sets when translate in type-3
14+
inline constexpr double ARRAYWIDCEN_GROWFRAC = 0.1;
15+
// How many waays there are to evaluate the kernel it should match the avbailable options
16+
// in finufft_opts
17+
inline constexpr int KEREVAL_METHODS = 2;
18+
inline constexpr double PI = 3.141592653589793238462643383279502884;
19+
// 1 / (2 * PI)
20+
inline constexpr double INV_2PI = 0.159154943091895335768883763372514362;
21+
} // namespace common
22+
} // namespace finufft

include/common/defines.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#pragma once
2+
3+
/* IMPORTANT: for Windows compilers, you should add a line
4+
#define FINUFFT_DLL
5+
here if you are compiling/using FINUFFT as a DLL,
6+
in order to do the proper importing/exporting, or
7+
alternatively compile with -DFINUFFT_DLL or the equivalent
8+
command-line flag. This is not necessary under MinGW/Cygwin, where
9+
libtool does the imports/exports automatically.
10+
Alternatively use include(GenerateExportHeader) and
11+
generate_export_header(finufft) to auto generate an header containing
12+
these defines.The main reason is that if msvc changes the way it deals
13+
with it in the future we just need to update cmake for it to work
14+
instead of having a check on the msvc version. */
15+
#if defined(FINUFFT_DLL) && (defined(_WIN32) || defined(__WIN32__))
16+
#if defined(dll_EXPORTS)
17+
#define FINUFFT_EXPORT __declspec(dllexport)
18+
#else
19+
#define FINUFFT_EXPORT __declspec(dllimport)
20+
#endif
21+
#else
22+
#define FINUFFT_EXPORT
23+
#endif
24+
25+
/* specify calling convention (Windows only)
26+
The cdecl calling convention is actually not the default in all but a very
27+
few C/C++ compilers.
28+
If the user code changes the default compiler calling convention, may need
29+
this when generating DLL. */
30+
#if defined(_WIN32) || defined(__WIN32__)
31+
#define FINUFFT_CDECL __cdecl
32+
#else
33+
#define FINUFFT_CDECL
34+
#endif
35+
36+
// common function attributes
37+
#if defined(_MSC_VER)
38+
#define FINUFFT_ALWAYS_INLINE __forceinline
39+
#define FINUFFT_NEVER_INLINE __declspec(noinline)
40+
#define FINUFFT_RESTRICT __restrict
41+
#define FINUFFT_UNREACHABLE __assume(0)
42+
#define FINUFFT_UNLIKELY(x) (x)
43+
#define FINUFFT_LIKELY(x) (x)
44+
#elif defined(__GNUC__) || defined(__clang__)
45+
#define FINUFFT_ALWAYS_INLINE __attribute__((always_inline)) inline
46+
#define FINUFFT_NEVER_INLINE __attribute__((noinline))
47+
#define FINUFFT_RESTRICT __restrict__
48+
#define FINUFFT_UNREACHABLE __builtin_unreachable()
49+
#define FINUFFT_UNLIKELY(x) __builtin_expect(!!(x), 0)
50+
#define FINUFFT_LIKELY(x) __builtin_expect(!!(x), 1)
51+
#else
52+
#define FINUFFT_ALWAYS_INLINE inline
53+
#define FINUFFT_NEVER_INLINE
54+
#define FINUFFT_RESTRICT
55+
#define FINUFFT_UNREACHABLE
56+
#define FINUFFT_UNLIKELY(x) (x)
57+
#define FINUFFT_LIKELY(x) (x)
58+
#endif

include/common/utils.h

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
#pragma once
2+
3+
#include <array>
4+
#include <tuple>
5+
#include <type_traits>
6+
#include <utility>
7+
8+
#include "defines.h"
9+
10+
namespace finufft {
11+
namespace common {
12+
13+
FINUFFT_EXPORT void FINUFFT_CDECL gaussquad(int n, double *xgl, double *wgl);
14+
std::tuple<double, double> leg_eval(int n, double x);
15+
16+
// helper to generate the integer sequence in range [Start, End]
17+
template<int Offset, typename Seq> struct offset_seq;
18+
19+
template<int Offset, int... I>
20+
struct offset_seq<Offset, std::integer_sequence<int, I...>> {
21+
using type = std::integer_sequence<int, (Offset + I)...>;
22+
};
23+
24+
template<int Start, int End>
25+
using make_range =
26+
typename offset_seq<Start, std::make_integer_sequence<int, End - Start + 1>>::type;
27+
28+
template<typename Seq> struct DispatchParam {
29+
int runtime_val;
30+
using seq_type = Seq;
31+
};
32+
33+
// Cartesian product over integer sequences.
34+
// Invokes f.template operator()<...>() for each combination of values.
35+
// The functor F must provide a templated call operator.
36+
// Adapted upon suggestion from Nils Wentzell: godbolt.org/z/GM94xb1j4
37+
//
38+
namespace detail {
39+
40+
template<typename F, typename... Seq> struct Product;
41+
42+
// Recursive case: at least two sequences remaining
43+
template<typename F, int... I1, typename Seq2, typename... Rest>
44+
struct Product<F, std::integer_sequence<int, I1...>, Seq2, Rest...> {
45+
template<int... Prefix> static void apply(F &f) {
46+
(Product<F, Seq2, Rest...>::template apply<Prefix..., I1>(f), ...);
47+
}
48+
};
49+
50+
// Base case: single sequence left
51+
template<typename F, int... I1> struct Product<F, std::integer_sequence<int, I1...>> {
52+
template<int... Prefix> static void apply(F &f) {
53+
(f.template operator()<Prefix..., I1>(), ...);
54+
}
55+
};
56+
57+
template<typename F, typename... Seq> void product(F &f, Seq...) {
58+
Product<F, Seq...>::template apply<>(f);
59+
}
60+
61+
// Helper functor invoked for each combination to check runtime values
62+
template<typename Func, std::size_t N, typename ArgTuple, typename ResultType>
63+
struct DispatcherCaller {
64+
Func &func;
65+
const std::array<int, N> &vals;
66+
ArgTuple &args;
67+
std::conditional_t<std::is_void_v<ResultType>, char, ResultType> result{};
68+
template<int... Params> void operator()() {
69+
static constexpr std::array<int, sizeof...(Params)> p{Params...};
70+
if (p == vals) {
71+
if constexpr (std::is_void_v<ResultType>) {
72+
std::apply(
73+
[&](auto &&...a) {
74+
func.template operator()<Params...>(std::forward<decltype(a)>(a)...);
75+
},
76+
args);
77+
} else {
78+
result = std::apply(
79+
[&](auto &&...a) {
80+
return func.template operator()<Params...>(std::forward<decltype(a)>(a)...);
81+
},
82+
args);
83+
}
84+
}
85+
}
86+
};
87+
88+
template<typename Seq> struct seq_first;
89+
template<int I0, int... I>
90+
struct seq_first<std::integer_sequence<int, I0, I...>> : std::integral_constant<int, I0> {
91+
};
92+
93+
template<typename Tuple, std::size_t... I>
94+
auto extract_vals_impl(const Tuple &t, std::index_sequence<I...>) {
95+
return std::array<int, sizeof...(I)>{std::get<I>(t).runtime_val...};
96+
}
97+
template<typename Tuple> auto extract_vals(const Tuple &t) {
98+
using T = std::remove_reference_t<Tuple>;
99+
return extract_vals_impl(t, std::make_index_sequence<std::tuple_size_v<T>>{});
100+
}
101+
102+
template<typename Tuple, std::size_t... I>
103+
auto extract_seqs_impl(const Tuple &t, std::index_sequence<I...>) {
104+
using T = std::remove_reference_t<Tuple>;
105+
return std::make_tuple(typename std::tuple_element_t<I, T>::seq_type{}...);
106+
}
107+
template<typename Tuple> auto extract_seqs(const Tuple &t) {
108+
using T = std::remove_reference_t<Tuple>;
109+
return extract_seqs_impl(t, std::make_index_sequence<std::tuple_size_v<T>>{});
110+
}
111+
112+
template<typename Func, typename ArgTuple, typename... Seq>
113+
struct dispatch_result_helper {
114+
template<std::size_t... I>
115+
static auto test(std::index_sequence<I...>)
116+
-> decltype(std::declval<Func>().template operator()<seq_first<Seq>::value...>(
117+
std::get<I>(std::declval<ArgTuple>())...));
118+
using type = decltype(test(std::make_index_sequence<std::tuple_size_v<ArgTuple>>{}));
119+
};
120+
template<typename Func, typename ArgTuple, typename SeqTuple> struct dispatch_result;
121+
template<typename Func, typename ArgTuple, typename... Seq>
122+
struct dispatch_result<Func, ArgTuple, std::tuple<Seq...>> {
123+
using type = typename dispatch_result_helper<Func, ArgTuple, Seq...>::type;
124+
};
125+
template<typename Func, typename ArgTuple, typename SeqTuple>
126+
using dispatch_result_t = typename dispatch_result<Func, ArgTuple, SeqTuple>::type;
127+
128+
} // namespace detail
129+
130+
// Generic dispatcher mapping runtime ints to template parameters.
131+
// params is a tuple of DispatchParam holding runtime values and sequences.
132+
// When a match is found, the functor is invoked with those template parameters
133+
// and its result returned. Otherwise, the default-constructed result is returned.
134+
template<typename Func, typename ParamTuple, typename... Args>
135+
decltype(auto) dispatch(Func &&func, ParamTuple &&params, Args &&...args) {
136+
using tuple_t = std::remove_reference_t<ParamTuple>;
137+
constexpr std::size_t N = std::tuple_size_v<tuple_t>;
138+
auto vals = detail::extract_vals(params);
139+
auto seqs = detail::extract_seqs(params);
140+
auto arg_tuple = std::forward_as_tuple(std::forward<Args>(args)...);
141+
using result_t = detail::dispatch_result_t<Func, decltype(arg_tuple), decltype(seqs)>;
142+
detail::DispatcherCaller<Func, N, decltype(arg_tuple), result_t> caller{func, vals,
143+
arg_tuple};
144+
std::apply([&](auto &&...s) { detail::product(caller, s...); }, seqs);
145+
if constexpr (!std::is_void_v<result_t>) return caller.result;
146+
}
147+
148+
} // namespace common
149+
} // namespace finufft

include/cufinufft/defs.h

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,8 @@
11
#ifndef CUFINUFFT_DEFS_H
22
#define CUFINUFFT_DEFS_H
33

4+
#include <common/common.h>
45
#include <limits>
5-
// constants needed within common
6-
// upper bound on w, ie nspread, even when padded (see evaluate_kernel_vector); also for
7-
// common
8-
#define MAX_NSPREAD 16
9-
#define MIN_NSPREAD 2
10-
11-
// max number of positive quadr nodes
12-
#define MAX_NQUAD 100
13-
14-
// Fraction growth cut-off in utils:arraywidcen, sets when translate in type-3
15-
#define ARRAYWIDCEN_GROWFRAC 0.1
166

177
// FIXME: If cufft ever takes N > INT_MAX...
188
constexpr int32_t MAX_NF = std::numeric_limits<int32_t>::max();

include/cufinufft/impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ int cufinufft_makeplan_impl(int type, int dim, int *nmodes, int iflag, int ntran
7272
Marco Barbone 07/26/24. Using SM when shared memory available is enough.
7373
*/
7474
using namespace cufinufft::common;
75+
using namespace finufft::common;
7576
int ier;
7677
if (type < 1 || type > 3) {
7778
fprintf(stderr, "[%s] Invalid type (%d): should be 1, 2, or 3.\n", __func__, type);

include/cufinufft/utils.h

Lines changed: 18 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,21 @@
44
// octave (mkoctfile) needs this otherwise it doesn't know what int64_t is!
55
#include <complex>
66

7-
#include <cuComplex.h>
87
#include <cufinufft/types.h>
98

109
#include <cuda_runtime.h>
1110
#include <thrust/extrema.h>
11+
#include <tuple>
1212
#include <type_traits>
1313
#include <utility> // for std::forward
1414

15-
#include <finufft_errors.h>
15+
#include <common/common.h>
1616

1717
#ifndef _USE_MATH_DEFINES
1818
#define _USE_MATH_DEFINES
1919
#endif
2020
#include <cmath>
2121

22-
#ifndef M_PI
23-
#define M_PI 3.14159265358979323846
24-
#endif
25-
2622
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)
2723
#else
2824
__inline__ __device__ double atomicAdd(double *address, double val) {
@@ -72,6 +68,8 @@ template<typename T> __forceinline__ __device__ auto interval(const int ns, cons
7268
namespace cufinufft {
7369
namespace utils {
7470

71+
using namespace finufft::common;
72+
7573
class WithCudaDevice {
7674
public:
7775
explicit WithCudaDevice(const int device) : orig_device_{get_orig_device()} {
@@ -90,10 +88,8 @@ class WithCudaDevice {
9088
}
9189
};
9290

93-
// math helpers whose source is in src/cuda/utils.cpp
94-
CUFINUFFT_BIGINT next235beven(CUFINUFFT_BIGINT n, CUFINUFFT_BIGINT b);
95-
void gaussquad(int n, double *xgl, double *wgl);
96-
std::tuple<double, double> leg_eval(int n, double x);
91+
// math helpers whose source is in src/utils.cpp
92+
long next235beven(long n, long b);
9793

9894
template<typename T> T infnorm(int n, std::complex<T> *a) {
9995
T nrm = 0.0;
@@ -124,8 +120,8 @@ static __forceinline__ __device__ void atomicAddComplexShared(
124120
* on shared memory are supported so we leverage them
125121
*/
126122
template<typename T>
127-
static __forceinline__ __device__ void atomicAddComplexGlobal(
128-
cuda_complex<T> *address, cuda_complex<T> res) {
123+
static __forceinline__ __device__ void atomicAddComplexGlobal(cuda_complex<T> *address,
124+
cuda_complex<T> res) {
129125
if constexpr (
130126
std::is_same_v<cuda_complex<T>, float2> && COMPUTE_CAPABILITY_90_OR_HIGHER) {
131127
atomicAdd(address, res);
@@ -150,7 +146,7 @@ template<typename T> auto arrayrange(int n, T *a, cudaStream_t stream) {
150146

151147
// Writes out w = half-width and c = center of an interval enclosing all a[n]'s
152148
// Only chooses a nonzero center if this increases w by less than fraction
153-
// ARRAYWIDCEN_GROWFRAC defined in defs.h.
149+
// ARRAYWIDCEN_GROWFRAC defined in common/constants.h.
154150
// This prevents rephasings which don't grow nf by much. 6/8/17
155151
// If n==0, w and c are not finite.
156152
template<typename T> auto arraywidcen(int n, T *a, cudaStream_t stream) {
@@ -180,41 +176,27 @@ auto set_nhg_type3(T S, T X, const cufinufft_opts &opts,
180176
else
181177
Ssafe = std::max(Ssafe, T(1) / X);
182178
// use the safe X and S...
183-
T nfd = 2.0 * opts.upsampfac * Ssafe * Xsafe / M_PI + nss;
179+
T nfd = 2.0 * opts.upsampfac * Ssafe * Xsafe / PI + nss;
184180
if (!std::isfinite(nfd)) nfd = 0.0; // use FLT to catch inf
185181
auto nf = (int)nfd;
186182
// printf("initial nf=%lld, ns=%d\n",*nf,spopts.nspread);
187183
// catch too small nf, and nan or +-inf, otherwise spread fails...
188184
if (nf < 2 * spopts.nspread) nf = 2 * spopts.nspread;
189-
if (nf < MAX_NF) // otherwise will fail anyway
190-
nf = utils::next235beven(nf, 1); // expensive at huge nf
185+
if (nf < MAX_NF) // otherwise will fail anyway
186+
nf = next235beven(nf, 1); // expensive at huge nf
191187
// Note: b is 1 because type 3 uses a type 2 plan, so it should not need the extra
192188
// condition that seems to be used by Block Gather as type 2 are only GM-sort
193-
auto h = 2 * T(M_PI) / nf; // upsampled grid spacing
189+
auto h = 2 * T(PI) / nf; // upsampled grid spacing
194190
auto gam = T(nf) / (2.0 * opts.upsampfac * Ssafe); // x scale fac to x'
195191
return std::make_tuple(nf, h, gam);
196192
}
197193

198-
// Generalized dispatcher for any function requiring ns-based dispatch
199-
template<typename Func, typename T, int ns, typename... Args>
200-
int dispatch_ns(Func &&func, int target_ns, Args &&...args) {
201-
if constexpr (ns > MAX_NSPREAD) {
202-
return FINUFFT_ERR_METHOD_NOTVALID; // Stop recursion
203-
} else {
204-
if (target_ns == ns) {
205-
return std::forward<Func>(func).template operator()<ns>(
206-
std::forward<Args>(args)...);
207-
}
208-
return dispatch_ns<Func, T, ns + 1>(std::forward<Func>(func), target_ns,
209-
std::forward<Args>(args)...);
210-
}
211-
}
212-
213-
// Wrapper function that starts the dispatch recursion
194+
// Wrapper around the generic dispatcher for nspread-based dispatch
214195
template<typename Func, typename T, typename... Args>
215-
int launch_dispatch_ns(Func &&func, int target_ns, Args &&...args) {
216-
return dispatch_ns<Func, T, MIN_NSPREAD>(std::forward<Func>(func), target_ns,
217-
std::forward<Args>(args)...);
196+
auto launch_dispatch_ns(Func &&func, int target_ns, Args &&...args) {
197+
using NsSeq = make_range<MIN_NSPREAD, MAX_NSPREAD>;
198+
auto params = std::make_tuple(DispatchParam<NsSeq>{target_ns});
199+
return dispatch(std::forward<Func>(func), params, std::forward<Args>(args)...);
218200
}
219201

220202
/**

0 commit comments

Comments
 (0)