// octave (mkoctfile) needs this otherwise it doesn't know what int64_t is!
#include <complex>

- #include <cuComplex.h>
#include <cufinufft/types.h>

#include <cuda_runtime.h>
#include <thrust/extrema.h>
+ #include <tuple>
#include <type_traits>
#include <utility> // for std::forward

- #include <finufft_errors.h>
+ #include <common/common.h>

#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <cmath>

- #ifndef M_PI
- #define M_PI 3.14159265358979323846
- #endif
-
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)
#else
__inline__ __device__ double atomicAdd(double *address, double val) {
@@ -72,6 +68,8 @@ template<typename T> __forceinline__ __device__ auto interval(const int ns, cons
namespace cufinufft {
namespace utils {

+ using namespace finufft::common;
+
class WithCudaDevice {
public:
  explicit WithCudaDevice(const int device) : orig_device_{get_orig_device()} {
@@ -90,10 +88,8 @@ class WithCudaDevice {
  }
};
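// Usage sketch (assumed semantics, not guaranteed by this excerpt): the guard
// switches the calling thread to `device` for its lifetime and restores the
// previously active device on scope exit, e.g.
//   { WithCudaDevice guard(opts.gpu_device_id); /* ...CUDA calls on that device... */ }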

- // math helpers whose source is in src/cuda/utils.cpp
- CUFINUFFT_BIGINT next235beven(CUFINUFFT_BIGINT n, CUFINUFFT_BIGINT b);
- void gaussquad(int n, double *xgl, double *wgl);
- std::tuple<double, double> leg_eval(int n, double x);
+ // math helpers whose source is in src/utils.cpp
+ long next235beven(long n, long b);
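// (Expected behavior, per the definition in src/utils.cpp: returns the smallest
// even integer >= n whose prime factors are all <= 5 and which is a multiple of b.)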

template<typename T> T infnorm(int n, std::complex<T> *a) {
  T nrm = 0.0;
@@ -124,8 +120,8 @@ static __forceinline__ __device__ void atomicAddComplexShared(
 * on shared memory are supported so we leverage them
 */
template<typename T>
- static __forceinline__ __device__ void atomicAddComplexGlobal(
-     cuda_complex<T> *address, cuda_complex<T> res) {
+ static __forceinline__ __device__ void atomicAddComplexGlobal(cuda_complex<T> *address,
+                                                               cuda_complex<T> res) {
  if constexpr (
      std::is_same_v<cuda_complex<T>, float2> && COMPUTE_CAPABILITY_90_OR_HIGHER) {
    atomicAdd(address, res);
@@ -150,7 +146,7 @@ template<typename T> auto arrayrange(int n, T *a, cudaStream_t stream) {

// Writes out w = half-width and c = center of an interval enclosing all a[n]'s
// Only chooses a nonzero center if this increases w by less than fraction
- // ARRAYWIDCEN_GROWFRAC defined in defs.h.
+ // ARRAYWIDCEN_GROWFRAC defined in common/constants.h.
// This prevents rephasings which don't grow nf by much. 6/8/17
// If n==0, w and c are not finite.
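// (Worked example, assuming the conventional ARRAYWIDCEN_GROWFRAC = 0.1: for a[]
// spanning [-10, 11] the naive result is w = 10.5, c = 0.5; since |c| < 0.1*w the
// center snaps to c = 0 and w grows only to 11. For a[] spanning [10, 12], c = 11
// is kept, because zeroing the center would grow w from 1 to 12.)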
template<typename T> auto arraywidcen(int n, T *a, cudaStream_t stream) {
@@ -180,41 +176,27 @@ auto set_nhg_type3(T S, T X, const cufinufft_opts &opts,
  else
    Ssafe = std::max(Ssafe, T(1) / X);
  // use the safe X and S...
-   T nfd = 2.0 * opts.upsampfac * Ssafe * Xsafe / M_PI + nss;
+   T nfd = 2.0 * opts.upsampfac * Ssafe * Xsafe / PI + nss;
  if (!std::isfinite(nfd)) nfd = 0.0; // use FLT to catch inf
  auto nf = (int)nfd;
  // printf("initial nf=%lld, ns=%d\n",*nf,spopts.nspread);
  // catch too small nf, and nan or +-inf, otherwise spread fails...
  if (nf < 2 * spopts.nspread) nf = 2 * spopts.nspread;
-   if (nf < MAX_NF)                   // otherwise will fail anyway
-     nf = utils::next235beven(nf, 1); // expensive at huge nf
+   if (nf < MAX_NF)            // otherwise will fail anyway
+     nf = next235beven(nf, 1); // expensive at huge nf
  // Note: b is 1 because type 3 uses a type 2 plan, so it should not need the extra
  // condition that seems to be used by Block Gather as type 2 are only GM-sort
-   auto h = 2 * T(M_PI) / nf; // upsampled grid spacing
+   auto h = 2 * T(PI) / nf; // upsampled grid spacing
  auto gam = T(nf) / (2.0 * opts.upsampfac * Ssafe); // x scale fac to x'
  return std::make_tuple(nf, h, gam);
}
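// (Reasoning sketch: gam rescales x' = x/gam so the sources span roughly [-pi, pi)
// on the nf-point upsampled grid; matching bandwidth S at oversampling upsampfac
// gives nf ~ 2*upsampfac*Ssafe*Xsafe/pi plus the padding nss, rounded up to an even
// 2,3,5-smooth size for FFT efficiency. Example with assumed values upsampfac = 2,
// Xsafe = 10, Ssafe = 50, nss = 9: nfd ~= 645.6, next235beven gives nf = 648,
// h = 2*pi/648, and gam = 648/200 = 3.24.)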

- // Generalized dispatcher for any function requiring ns-based dispatch
- template<typename Func, typename T, int ns, typename... Args>
- int dispatch_ns(Func &&func, int target_ns, Args &&...args) {
-   if constexpr (ns > MAX_NSPREAD) {
-     return FINUFFT_ERR_METHOD_NOTVALID; // Stop recursion
-   } else {
-     if (target_ns == ns) {
-       return std::forward<Func>(func).template operator()<ns>(
-           std::forward<Args>(args)...);
-     }
-     return dispatch_ns<Func, T, ns + 1>(std::forward<Func>(func), target_ns,
-                                         std::forward<Args>(args)...);
-   }
- }
-
- // Wrapper function that starts the dispatch recursion
+ // Wrapper around the generic dispatcher for nspread-based dispatch
template<typename Func, typename T, typename... Args>
- int launch_dispatch_ns(Func &&func, int target_ns, Args &&...args) {
-   return dispatch_ns<Func, T, MIN_NSPREAD>(std::forward<Func>(func), target_ns,
-                                             std::forward<Args>(args)...);
+ auto launch_dispatch_ns(Func &&func, int target_ns, Args &&...args) {
+   using NsSeq = make_range<MIN_NSPREAD, MAX_NSPREAD>;
+   auto params = std::make_tuple(DispatchParam<NsSeq>{target_ns});
+   return dispatch(std::forward<Func>(func), params, std::forward<Args>(args)...);
}
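// Caller sketch (hypothetical; names and arguments are illustrative only): the
// functor exposes a call operator templated on the kernel width, matching the
// convention of the previous hand-rolled dispatcher; the authoritative calling
// convention is defined by the generic dispatch() in common/common.h.
//
//   struct SpreadLauncher {
//     template<int ns> int operator()(/* runtime kernel args */) const {
//       // launch the kernel instantiation compiled for width ns
//       return 0;
//     }
//   };
//   int ier = launch_dispatch_ns<SpreadLauncher, T>(SpreadLauncher{}, spopts.nspread,
//                                                   /* runtime kernel args */);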

/**