1
1
// Copyright 2024-2025 NVIDIA Corporation
2
2
// SPDX-License-Identifier: Apache-2.0
3
3
4
- #ifndef TSD_USE_CUDA
5
- #define TSD_USE_CUDA 1
6
- #endif
7
-
8
4
#include " tsd/algorithms/computeScalarRange.hpp"
9
5
#include " tsd/core/Context.hpp"
10
- // std
11
- #include < algorithm>
12
- #include < limits>
13
- #if TSD_USE_CUDA
14
- // thrust
15
- #include < cuda_runtime.h>
16
- #include < thrust/device_ptr.h>
17
- #include < thrust/extrema.h>
18
- #endif
19
-
20
- namespace tsd ::algorithm {
21
-
22
- namespace detail {
23
-
24
- // NOTE(jda): This is a reduced version of anari::anariTypeInvoke() to lower
25
- // Thrust/CUDA compile times
26
- template <typename R, template <int > class F , typename ... Args>
27
- inline R scalarTypeInvoke (ANARIDataType type, Args &&...args)
28
- {
29
- // clang-format off
30
- switch (type) {
31
- case ANARI_UFIXED8: return F<ANARI_UFIXED8>()(std::forward<Args>(args)...);
32
- case ANARI_UFIXED16: return F<ANARI_UFIXED16>()(std::forward<Args>(args)...);
33
- case ANARI_FIXED8: return F<ANARI_FIXED8>()(std::forward<Args>(args)...);
34
- case ANARI_FIXED16: return F<ANARI_FIXED16>()(std::forward<Args>(args)...);
35
- case ANARI_FLOAT32: return F<ANARI_FLOAT32>()(std::forward<Args>(args)...);
36
- case ANARI_FLOAT64: return F<ANARI_FLOAT64>()(std::forward<Args>(args)...);
37
- default :
38
- return F<ANARI_UNKNOWN>()(std::forward<Args>(args)...);
39
- }
40
- // clang-format off
41
- }
42
-
43
- template <int ANARI_ENUM_T>
44
- struct ComputeScalarRange
45
- {
46
- using properties_t = anari::ANARITypeProperties<ANARI_ENUM_T>;
47
- using base_t = typename properties_t ::base_type;
6
+ #include " tsd/core/Logging.hpp"
48
7
49
- tsd::float2 operator ()(const Array &a)
50
- {
51
- tsd::float4 min_out{0 .f , 0 .f , 0 .f , 0 .f };
52
- tsd::float4 max_out{0 .f , 0 .f , 0 .f , 0 .f };
8
+ #include " tsd/algorithms/detail/computeScalarRangeImpl.hpp"
53
9
54
- const auto *begin = a.dataAs <base_t >();
55
- const auto *end = begin + a.size ();
56
- #if TSD_USE_CUDA
57
- if (a.kind () == Array::MemoryKind::CUDA) {
58
- const auto minmax = thrust::minmax_element (
59
- thrust::device_pointer_cast (begin), thrust::device_pointer_cast (end));
60
- const base_t min_v = *minmax.first ;
61
- const base_t max_v = *minmax.second ;
62
- properties_t::toFloat4 (&min_out.x , &min_v);
63
- properties_t::toFloat4 (&max_out.x , &max_v);
64
- } else {
65
- #endif
66
- const auto minmax = std::minmax_element (begin, end);
67
- const auto min_v = *minmax.first ;
68
- const auto max_v = *minmax.second ;
69
- properties_t::toFloat4 (&min_out.x , &min_v);
70
- properties_t::toFloat4 (&max_out.x , &max_v);
71
- #if TSD_USE_CUDA
72
- }
73
- #endif
74
-
75
- return {min_out.x , max_out.x };
76
- }
77
- };
78
-
79
- } // namespace detail
80
-
81
- // /////////////////////////////////////////////////////////////////////////////
82
- // /////////////////////////////////////////////////////////////////////////////
83
- // /////////////////////////////////////////////////////////////////////////////
10
+ namespace tsd ::algorithm {
84
11
85
12
tsd::float2 computeScalarRange (const Array &a)
86
13
{
@@ -103,8 +30,32 @@ tsd::float2 computeScalarRange(const Array &a)
103
30
retval.y = std::max (retval.y , subRange.y );
104
31
});
105
32
} else if (elementsAreScalars) {
106
- retval = detail::scalarTypeInvoke<tsd::float2, detail::ComputeScalarRange>(
107
- type, a);
33
+ switch (type) {
34
+ case ANARI_UFIXED8:
35
+ retval = detail::computeScalarRange_ufixed8 (a);
36
+ break ;
37
+ case ANARI_UFIXED16:
38
+ retval = detail::computeScalarRange_ufixed16 (a);
39
+ break ;
40
+ case ANARI_FIXED8:
41
+ retval = detail::computeScalarRange_fixed8 (a);
42
+ break ;
43
+ case ANARI_FIXED16:
44
+ retval = detail::computeScalarRange_fixed16 (a);
45
+ break ;
46
+ case ANARI_FLOAT32:
47
+ retval = detail::computeScalarRange_float32 (a);
48
+ break ;
49
+ case ANARI_FLOAT64:
50
+ retval = detail::computeScalarRange_float64 (a);
51
+ break ;
52
+ default :
53
+ logWarning (
54
+ " computeScalarRange() called on an "
55
+ " array with incompatible element type '%s'" ,
56
+ anari::toString (type));
57
+ break ;
58
+ }
108
59
}
109
60
110
61
return retval;
0 commit comments