Skip to content

Commit a067839

Browse files
Extracted Segmented Reduce kernels into NVRTC compilable header (#3727)
* Moved definition of DeviceSegmentedReduceKernel out to NVRTC compilable header file Moved the definition from cub/device/dispatch/dispatch_reduce.cuh to new file cub/device/dispatch/kernels/segmented_reduce.cuh Added compilation of the new file to cub/test/catch2_test_nvrtc.cu Needed to remove friended operator<< overload for std::ostream from arg_index_input_iterator.cuh, since this header is needed by DeviceSegmentedReduceKernel implementation, and use of std::ostream breaks compilation by nvrtc * Include cub/util_type.cuh for value_t and non_void_value_t * Include cub/detail/type_traits.cuh * Restore friend operator ostream conditionally for nvcc/host compiler
1 parent 61bda15 commit a067839

File tree

4 files changed

+188
-134
lines changed

4 files changed

+188
-134
lines changed

cub/cub/device/dispatch/dispatch_reduce.cuh

Lines changed: 3 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,18 @@
4444
# pragma system_header
4545
#endif // no system header
4646

47-
#include <cub/agent/agent_reduce.cuh>
4847
#include <cub/detail/launcher/cuda_runtime.cuh>
48+
#include <cub/detail/type_traits.cuh> // for cub::detail::invoke_result_t
4949
#include <cub/device/dispatch/kernels/reduce.cuh>
50+
#include <cub/device/dispatch/kernels/segmented_reduce.cuh>
5051
#include <cub/device/dispatch/tuning/tuning_reduce.cuh>
5152
#include <cub/grid/grid_even_share.cuh>
52-
#include <cub/iterator/arg_index_input_iterator.cuh>
5353
#include <cub/thread/thread_operators.cuh>
5454
#include <cub/thread/thread_store.cuh>
5555
#include <cub/util_debug.cuh>
5656
#include <cub/util_device.cuh>
5757
#include <cub/util_temporary_storage.cuh>
58-
59-
#include <iterator>
58+
#include <cub/util_type.cuh> // for cub::detail::non_void_value_t, cub::detail::value_t
6059

6160
_CCCL_SUPPRESS_DEPRECATED_PUSH
6261
#include <cuda/std/functional>
@@ -69,131 +68,6 @@ CUB_NAMESPACE_BEGIN
6968
namespace detail::reduce
7069
{
7170

72-
/// Normalize input iterator to segment offset
73-
template <typename T, typename OffsetT, typename IteratorT>
74-
_CCCL_DEVICE _CCCL_FORCEINLINE void NormalizeReductionOutput(T& /*val*/, OffsetT /*base_offset*/, IteratorT /*itr*/)
75-
{}
76-
77-
/// Normalize input iterator to segment offset (specialized for arg-index)
78-
template <typename KeyValuePairT, typename OffsetT, typename WrappedIteratorT, typename OutputValueT>
79-
_CCCL_DEVICE _CCCL_FORCEINLINE void NormalizeReductionOutput(
80-
KeyValuePairT& val, OffsetT base_offset, ArgIndexInputIterator<WrappedIteratorT, OffsetT, OutputValueT> /*itr*/)
81-
{
82-
val.key -= base_offset;
83-
}
84-
85-
/**
86-
* Segmented reduction (one block per segment)
87-
* @tparam ChainedPolicyT
88-
* Chained tuning policy
89-
*
90-
* @tparam InputIteratorT
91-
* Random-access input iterator type for reading input items @iterator
92-
*
93-
* @tparam OutputIteratorT
94-
* Output iterator type for recording the reduced aggregate @iterator
95-
*
96-
* @tparam BeginOffsetIteratorT
97-
* Random-access input iterator type for reading segment beginning offsets
98-
* @iterator
99-
*
100-
* @tparam EndOffsetIteratorT
101-
* Random-access input iterator type for reading segment ending offsets
102-
* @iterator
103-
*
104-
* @tparam OffsetT
105-
* Signed integer type for global offsets
106-
*
107-
* @tparam ReductionOpT
108-
* Binary reduction functor type having member
109-
* `T operator()(const T &a, const U &b)`
110-
*
111-
* @tparam InitT
112-
* Initial value type
113-
*
114-
* @param[in] d_in
115-
* Pointer to the input sequence of data items
116-
*
117-
* @param[out] d_out
118-
* Pointer to the output aggregate
119-
*
120-
* @param[in] d_begin_offsets
121-
* Random-access input iterator to the sequence of beginning offsets of
122-
* length `num_segments`, such that `d_begin_offsets[i]` is the first element
123-
* of the *i*<sup>th</sup> data segment in `d_keys_*` and `d_values_*`
124-
*
125-
* @param[in] d_end_offsets
126-
* Random-access input iterator to the sequence of ending offsets of length
127-
* `num_segments`, such that `d_end_offsets[i] - 1` is the last element of
128-
* the *i*<sup>th</sup> data segment in `d_keys_*` and `d_values_*`.
129-
* If `d_end_offsets[i] - 1 <= d_begin_offsets[i]`, the *i*<sup>th</sup> is
130-
* considered empty.
131-
*
132-
* @param[in] num_segments
133-
* The number of segments that comprise the sorting data
134-
*
135-
* @param[in] reduction_op
136-
* Binary reduction functor
137-
*
138-
* @param[in] init
139-
* The initial value of the reduction
140-
*/
141-
template <typename ChainedPolicyT,
142-
typename InputIteratorT,
143-
typename OutputIteratorT,
144-
typename BeginOffsetIteratorT,
145-
typename EndOffsetIteratorT,
146-
typename OffsetT,
147-
typename ReductionOpT,
148-
typename InitT,
149-
typename AccumT>
150-
CUB_DETAIL_KERNEL_ATTRIBUTES
151-
__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ReducePolicy::BLOCK_THREADS)) void DeviceSegmentedReduceKernel(
152-
InputIteratorT d_in,
153-
OutputIteratorT d_out,
154-
BeginOffsetIteratorT d_begin_offsets,
155-
EndOffsetIteratorT d_end_offsets,
156-
int /*num_segments*/,
157-
ReductionOpT reduction_op,
158-
InitT init)
159-
{
160-
// Thread block type for reducing input tiles
161-
using AgentReduceT =
162-
AgentReduce<typename ChainedPolicyT::ActivePolicy::ReducePolicy,
163-
InputIteratorT,
164-
OutputIteratorT,
165-
OffsetT,
166-
ReductionOpT,
167-
AccumT>;
168-
169-
// Shared memory storage
170-
__shared__ typename AgentReduceT::TempStorage temp_storage;
171-
172-
OffsetT segment_begin = d_begin_offsets[blockIdx.x];
173-
OffsetT segment_end = d_end_offsets[blockIdx.x];
174-
175-
// Check if empty problem
176-
if (segment_begin == segment_end)
177-
{
178-
if (threadIdx.x == 0)
179-
{
180-
*(d_out + blockIdx.x) = init;
181-
}
182-
return;
183-
}
184-
185-
// Consume input tiles
186-
AccumT block_aggregate = AgentReduceT(temp_storage, d_in, reduction_op).ConsumeRange(segment_begin, segment_end);
187-
188-
// Normalize as needed
189-
NormalizeReductionOutput(block_aggregate, segment_begin, d_in);
190-
191-
if (threadIdx.x == 0)
192-
{
193-
finalize_and_store_aggregate(d_out + blockIdx.x, reduction_op, init, block_aggregate);
194-
}
195-
}
196-
19771
template <typename MaxPolicyT,
19872
typename InputIteratorT,
19973
typename OutputIteratorT,
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
/******************************************************************************
2+
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Redistribution and use in source and binary forms, with or without
5+
* modification, are permitted provided that the following conditions are met:
6+
* * Redistributions of source code must retain the above copyright
7+
* notice, this list of conditions and the following disclaimer.
8+
* * Redistributions in binary form must reproduce the above copyright
9+
* notice, this list of conditions and the following disclaimer in the
10+
* documentation and/or other materials provided with the distribution.
11+
* * Neither the name of the NVIDIA CORPORATION nor the
12+
* names of its contributors may be used to endorse or promote products
13+
* derived from this software without specific prior written permission.
14+
*
15+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18+
* ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19+
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20+
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21+
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22+
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23+
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24+
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
*
26+
******************************************************************************/
27+
28+
#pragma once
29+
30+
#include <cub/config.cuh>
31+
32+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
33+
# pragma GCC system_header
34+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
35+
# pragma clang system_header
36+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
37+
# pragma system_header
38+
#endif // no system header
39+
40+
#include <cub/agent/agent_reduce.cuh>
41+
#include <cub/iterator/arg_index_input_iterator.cuh>
42+
43+
CUB_NAMESPACE_BEGIN
44+
45+
namespace detail
46+
{
47+
namespace reduce
48+
{
49+
50+
/// Normalize input iterator to segment offset
51+
template <typename T, typename OffsetT, typename IteratorT>
52+
_CCCL_DEVICE _CCCL_FORCEINLINE void NormalizeReductionOutput(T& /*val*/, OffsetT /*base_offset*/, IteratorT /*itr*/)
53+
{}
54+
55+
/// Normalize input iterator to segment offset (specialized for arg-index)
56+
template <typename KeyValuePairT, typename OffsetT, typename WrappedIteratorT, typename OutputValueT>
57+
_CCCL_DEVICE _CCCL_FORCEINLINE void NormalizeReductionOutput(
58+
KeyValuePairT& val, OffsetT base_offset, ArgIndexInputIterator<WrappedIteratorT, OffsetT, OutputValueT> /*itr*/)
59+
{
60+
val.key -= base_offset;
61+
}
62+
63+
/**
64+
* Segmented reduction (one block per segment)
65+
* @tparam ChainedPolicyT
66+
* Chained tuning policy
67+
*
68+
* @tparam InputIteratorT
69+
* Random-access input iterator type for reading input items @iterator
70+
*
71+
* @tparam OutputIteratorT
72+
* Output iterator type for recording the reduced aggregate @iterator
73+
*
74+
* @tparam BeginOffsetIteratorT
75+
* Random-access input iterator type for reading segment beginning offsets
76+
* @iterator
77+
*
78+
* @tparam EndOffsetIteratorT
79+
* Random-access input iterator type for reading segment ending offsets
80+
* @iterator
81+
*
82+
* @tparam OffsetT
83+
* Signed integer type for global offsets
84+
*
85+
* @tparam ReductionOpT
86+
* Binary reduction functor type having member
87+
* `T operator()(const T &a, const U &b)`
88+
*
89+
* @tparam InitT
90+
* Initial value type
91+
*
92+
* @param[in] d_in
93+
* Pointer to the input sequence of data items
94+
*
95+
* @param[out] d_out
96+
* Pointer to the output aggregate
97+
*
98+
* @param[in] d_begin_offsets
99+
* Random-access input iterator to the sequence of beginning offsets of
100+
* length `num_segments`, such that `d_begin_offsets[i]` is the first element
101+
* of the *i*<sup>th</sup> data segment in `d_keys_*` and `d_values_*`
102+
*
103+
* @param[in] d_end_offsets
104+
* Random-access input iterator to the sequence of ending offsets of length
105+
* `num_segments`, such that `d_end_offsets[i] - 1` is the last element of
106+
* the *i*<sup>th</sup> data segment in `d_keys_*` and `d_values_*`.
107+
* If `d_end_offsets[i] - 1 <= d_begin_offsets[i]`, the *i*<sup>th</sup> is
108+
* considered empty.
109+
*
110+
* @param[in] num_segments
111+
* The number of segments that comprise the sorting data
112+
*
113+
* @param[in] reduction_op
114+
* Binary reduction functor
115+
*
116+
* @param[in] init
117+
* The initial value of the reduction
118+
*/
119+
template <typename ChainedPolicyT,
120+
typename InputIteratorT,
121+
typename OutputIteratorT,
122+
typename BeginOffsetIteratorT,
123+
typename EndOffsetIteratorT,
124+
typename OffsetT,
125+
typename ReductionOpT,
126+
typename InitT,
127+
typename AccumT>
128+
CUB_DETAIL_KERNEL_ATTRIBUTES
129+
__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ReducePolicy::BLOCK_THREADS)) void DeviceSegmentedReduceKernel(
130+
InputIteratorT d_in,
131+
OutputIteratorT d_out,
132+
BeginOffsetIteratorT d_begin_offsets,
133+
EndOffsetIteratorT d_end_offsets,
134+
int /*num_segments*/,
135+
ReductionOpT reduction_op,
136+
InitT init)
137+
{
138+
// Thread block type for reducing input tiles
139+
using AgentReduceT =
140+
AgentReduce<typename ChainedPolicyT::ActivePolicy::ReducePolicy,
141+
InputIteratorT,
142+
OutputIteratorT,
143+
OffsetT,
144+
ReductionOpT,
145+
AccumT>;
146+
147+
// Shared memory storage
148+
__shared__ typename AgentReduceT::TempStorage temp_storage;
149+
150+
OffsetT segment_begin = d_begin_offsets[blockIdx.x];
151+
OffsetT segment_end = d_end_offsets[blockIdx.x];
152+
153+
// Check if empty problem
154+
if (segment_begin == segment_end)
155+
{
156+
if (threadIdx.x == 0)
157+
{
158+
*(d_out + blockIdx.x) = init;
159+
}
160+
return;
161+
}
162+
163+
// Consume input tiles
164+
AccumT block_aggregate = AgentReduceT(temp_storage, d_in, reduction_op).ConsumeRange(segment_begin, segment_end);
165+
166+
// Normalize as needed
167+
NormalizeReductionOutput(block_aggregate, segment_begin, d_in);
168+
169+
if (threadIdx.x == 0)
170+
{
171+
finalize_and_store_aggregate(d_out + blockIdx.x, reduction_op, init, block_aggregate);
172+
}
173+
}
174+
175+
} // namespace reduce
176+
} // namespace detail
177+
178+
CUB_NAMESPACE_END

cub/cub/iterator/arg_index_input_iterator.cuh

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@
4343
# pragma system_header
4444
#endif // no system header
4545

46-
#include <cub/thread/thread_load.cuh>
47-
#include <cub/thread/thread_store.cuh>
46+
#include <cub/util_type.cuh>
4847

4948
#include <thrust/iterator/iterator_facade.h>
5049
#include <thrust/iterator/iterator_traits.h>
51-
#include <thrust/version.h>
5250

53-
#include <ostream>
51+
#if !_CCCL_COMPILER(NVRTC)
52+
# include <ostream>
53+
#endif // !_CCCL_COMPILER(NVRTC)
5454

5555
CUB_NAMESPACE_BEGIN
5656

@@ -246,11 +246,12 @@ public:
246246
offset = 0;
247247
}
248248

249-
/// ostream operator
249+
#if !_CCCL_COMPILER(NVRTC)
250250
friend std::ostream& operator<<(std::ostream& os, const self_type& /*itr*/)
251251
{
252252
return os;
253253
}
254+
#endif // !_CCCL_COMPILER(NVRTC)
254255
};
255256

256257
CUB_NAMESPACE_END

cub/test/catch2_test_nvrtc.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ TEST_CASE("Test nvrtc", "[test][nvrtc]")
5959
#include <cub/device/dispatch/kernels/for_each.cuh>
6060
#include <cub/device/dispatch/kernels/scan.cuh>
6161
#include <cub/device/dispatch/kernels/merge_sort.cuh>
62+
#include <cub/device/dispatch/kernels/segmented_reduce.cuh>
6263
6364
#include <thrust/iterator/constant_iterator.h>
6465
#include <thrust/iterator/counting_iterator.h>

0 commit comments

Comments
 (0)