Skip to content

Commit dcc9178

Browse files
authored
Merge pull request #701 from haahh/pr_sort_by_key_custom_struct2
radix sort by key with custom value type, fixes #162
2 parents e794833 + c21f705 commit dcc9178

File tree

3 files changed

+86
-5
lines changed

3 files changed

+86
-5
lines changed

include/boost/compute/algorithm/detail/radix_sort.hpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
#include <boost/type_traits/is_signed.hpp>
1818
#include <boost/type_traits/is_floating_point.hpp>
1919

20+
#include <boost/mpl/and.hpp>
21+
#include <boost/mpl/not.hpp>
22+
2023
#include <boost/compute/kernel.hpp>
2124
#include <boost/compute/program.hpp>
2225
#include <boost/compute/command_queue.hpp>
@@ -305,9 +308,12 @@ inline void radix_sort_impl(const buffer_iterator<T> first,
305308
options << " -DASC";
306309
}
307310

311+
// get type definition if it is a custom struct
312+
std::string custom_type_def = boost::compute::type_definition<T2>() + "\n";
313+
308314
// load radix sort program
309315
program radix_sort_program = cache->get_or_build(
310-
cache_key, options.str(), radix_sort_source, context
316+
cache_key, options.str(), custom_type_def + radix_sort_source, context
311317
);
312318

313319
kernel count_kernel(radix_sort_program, "count");

include/boost/compute/type_traits/type_definition.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ namespace compute {
1818
namespace detail {
1919

2020
template<class T>
21-
struct type_definition_trait;
21+
struct type_definition_trait
22+
{
23+
static std::string value() { return std::string(); }
24+
};
2225

2326
} // end detail namespace
2427

test/test_sort_by_key.cpp

+75-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,16 @@
1515
#include <boost/compute/algorithm/sort_by_key.hpp>
1616
#include <boost/compute/algorithm/is_sorted.hpp>
1717
#include <boost/compute/container/vector.hpp>
18+
#include <boost/compute/types/struct.hpp>
19+
20+
struct custom_struct
21+
{
22+
boost::compute::int_ x;
23+
boost::compute::int_ y;
24+
boost::compute::float2_ zw;
25+
};
26+
27+
BOOST_COMPUTE_ADAPT_STRUCT(custom_struct, custom_struct, (x, y, zw))
1828

1929
#include "check_macros.hpp"
2030
#include "context_setup.hpp"
@@ -69,15 +79,15 @@ BOOST_AUTO_TEST_CASE(sort_int_2)
6979
BOOST_AUTO_TEST_CASE(sort_char_by_int)
7080
{
7181
int keys_data[] = { 6, 2, 1, 3, 4, 7, 5, 0 };
72-
char values_data[] = { 'g', 'c', 'b', 'd', 'e', 'h', 'f', 'a' };
82+
compute::char_ values_data[] = { 'g', 'c', 'b', 'd', 'e', 'h', 'f', 'a' };
7383

7484
compute::vector<int> keys(keys_data, keys_data + 8, queue);
75-
compute::vector<char> values(values_data, values_data + 8, queue);
85+
compute::vector<compute::char_> values(values_data, values_data + 8, queue);
7686

7787
compute::sort_by_key(keys.begin(), keys.end(), values.begin(), queue);
7888

7989
CHECK_RANGE_EQUAL(int, 8, keys, (0, 1, 2, 3, 4, 5, 6, 7));
80-
CHECK_RANGE_EQUAL(char, 8, values, ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'));
90+
CHECK_RANGE_EQUAL(compute::char_, 8, values, ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'));
8191
}
8292

8393
BOOST_AUTO_TEST_CASE(sort_int_and_float)
@@ -132,4 +142,66 @@ BOOST_AUTO_TEST_CASE(sort_int_and_float_custom_comparison_func)
132142
BOOST_CHECK(compute::is_sorted(values.begin(), values.end(), queue) == true);
133143
}
134144

145+
BOOST_AUTO_TEST_CASE(sort_int_and_float2)
146+
{
147+
using boost::compute::int_;
148+
using boost::compute::float2_;
149+
150+
int n = 1024;
151+
std::vector<int_> host_keys(n);
152+
std::vector<float2_> host_values(n);
153+
for(int i = 0; i < n; i++){
154+
host_keys[i] = n - i;
155+
host_values[i] = float2_((n - i) / 2.f);
156+
}
157+
158+
BOOST_COMPUTE_FUNCTION(bool, sort_float2, (float2_ a, float2_ b),
159+
{
160+
return a.x < b.x;
161+
});
162+
163+
compute::vector<int_> keys(host_keys.begin(), host_keys.end(), queue);
164+
compute::vector<float2_> values(host_values.begin(), host_values.end(), queue);
165+
166+
BOOST_CHECK(compute::is_sorted(keys.begin(), keys.end(), queue) == false);
167+
BOOST_CHECK(compute::is_sorted(values.begin(), values.end(), sort_float2, queue) == false);
168+
169+
compute::sort_by_key(keys.begin(), keys.end(), values.begin(), queue);
170+
171+
BOOST_CHECK(compute::is_sorted(keys.begin(), keys.end(), queue) == true);
172+
BOOST_CHECK(compute::is_sorted(values.begin(), values.end(), sort_float2, queue) == true);
173+
}
174+
175+
BOOST_AUTO_TEST_CASE(sort_custom_struct_by_int)
176+
{
177+
using boost::compute::int_;
178+
using boost::compute::float2_;
179+
180+
int_ n = 1024;
181+
std::vector<int_> host_keys(n);
182+
std::vector<custom_struct> host_values(n);
183+
for(int_ i = 0; i < n; i++){
184+
host_keys[i] = n - i;
185+
host_values[i].x = n - i;
186+
host_values[i].y = n - i;
187+
host_values[i].zw = float2_((n - i) / 0.5f);
188+
}
189+
190+
BOOST_COMPUTE_FUNCTION(bool, sort_custom_struct, (custom_struct a, custom_struct b),
191+
{
192+
return a.x < b.x;
193+
});
194+
195+
compute::vector<int_> keys(host_keys.begin(), host_keys.end(), queue);
196+
compute::vector<custom_struct> values(host_values.begin(), host_values.end(), queue);
197+
198+
BOOST_CHECK(compute::is_sorted(keys.begin(), keys.end(), queue) == false);
199+
BOOST_CHECK(compute::is_sorted(values.begin(), values.end(), sort_custom_struct, queue) == false);
200+
201+
compute::sort_by_key(keys.begin(), keys.end(), values.begin(), queue);
202+
203+
BOOST_CHECK(compute::is_sorted(keys.begin(), keys.end(), queue) == true);
204+
BOOST_CHECK(compute::is_sorted(values.begin(), values.end(), sort_custom_struct, queue) == true);
205+
}
206+
135207
BOOST_AUTO_TEST_SUITE_END()

0 commit comments

Comments
 (0)