Skip to content

Commit 35eeb64

Browse files
WentzellThoemi09
andcommitted
Create custom mpi types for serializable custom types
- Add Serializable concept - Add tests for serializable MPI datatypes - Add doc strings Co-authored-by: Thomas Hahn <[email protected]>
1 parent a22f9a8 commit 35eeb64

File tree

5 files changed

+197
-18
lines changed

5 files changed

+197
-18
lines changed

c++/mpi/datatypes.hpp

+105-6
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,14 @@ namespace mpi {
7676
D(unsigned long long, MPI_UNSIGNED_LONG_LONG);
7777
#undef D
7878

79+
/**
80+
* @brief Specialization of mpi::mpi_type for enum types.
81+
* @tparam E C++ enum type.
82+
*/
83+
template <typename E>
84+
requires(std::is_enum_v<E>)
85+
struct mpi_type<E> : mpi_type<std::underlying_type_t<E>> {};
86+
7987
/**
8088
* @brief Specialization of mpi::mpi_type for `const` types.
8189
* @tparam T C++ type.
@@ -94,6 +102,28 @@ namespace mpi {
94102
*/
95103
template <typename T> constexpr bool has_mpi_type<T, std::void_t<decltype(mpi_type<T>::get())>> = true;
96104

105+
namespace detail {
106+
107+
// Helper struct to check if member types are mpi-serializable, i.e. have an associated mpi_type
108+
struct serialize_checker {
109+
template <typename T>
110+
void operator&(T &)
111+
requires(has_mpi_type<T>)
112+
{}
113+
};
114+
115+
} // namespace detail
116+
117+
/**
118+
* @brief A concept that checks if objects of a type can be serialized and deserialized.
119+
* @tparam T Type to check.
120+
*/
121+
template <typename T>
122+
concept Serializable = requires(const T ac, T a, detail::serialize_checker ar) {
123+
{ ac.serialize(ar) } -> std::same_as<void>;
124+
{ a.deserialize(ar) } -> std::same_as<void>;
125+
};
126+
97127
/**
98128
* @brief Create a new `MPI_Datatype` from a tuple.
99129
*
@@ -135,8 +165,11 @@ namespace mpi {
135165
* @brief Specialization of mpi::mpi_type for std::tuple.
136166
* @tparam Ts Tuple element types.
137167
*/
138-
template <typename... T> struct mpi_type<std::tuple<T...>> {
139-
[[nodiscard]] static MPI_Datatype get() noexcept { return get_mpi_type(std::tuple<T...>{}); }
168+
template <typename... Ts> struct mpi_type<std::tuple<Ts...>> {
169+
[[nodiscard]] static MPI_Datatype get() noexcept {
170+
static MPI_Datatype type = get_mpi_type(std::tuple<Ts...>{});
171+
return type;
172+
}
140173
};
141174

142175
/**
@@ -156,15 +189,81 @@ namespace mpi {
156189
* auto tie_data(foo f) {
157190
* return std::tie(f.x, f.y);
158191
* }
192+
* @endcode
193+
*
194+
* @tparam U Type to be converted to an `MPI_Datatype`.
195+
*/
196+
template <typename U>
197+
requires(not Serializable<U>) and requires(U u) { tie_data(u); }
198+
struct mpi_type<U> {
199+
[[nodiscard]] static MPI_Datatype get() noexcept {
200+
static MPI_Datatype type = get_mpi_type(tie_data(U{}));
201+
return type;
202+
}
203+
};
204+
205+
namespace detail {
206+
207+
// Archive helper class to obtain MPI custom type info using references to class members.
208+
struct mpi_archive {
209+
std::vector<int> block_lengths{};
210+
std::vector<MPI_Aint> displacements{};
211+
std::vector<MPI_Datatype> types{};
212+
MPI_Aint base_address{};
213+
214+
// Constructor sets the base address of the object.
215+
explicit mpi_archive(const void *base) { MPI_Get_address(base, &base_address); }
216+
217+
// Overloaded operator& to process members to set the block lengths, displacements and MPI types.
218+
template <typename T>
219+
requires(has_mpi_type<T>)
220+
mpi_archive &operator&(const T &member) {
221+
types.push_back(mpi_type<T>::get());
222+
MPI_Aint address{};
223+
MPI_Get_address(&member, &address);
224+
displacements.push_back(MPI_Aint_diff(address, base_address));
225+
block_lengths.push_back(1);
226+
return *this;
227+
}
228+
};
229+
230+
} // namespace detail
231+
232+
/**
233+
* @brief Create an `MPI_Datatype` from a serializable type.
234+
*
235+
* @details It is assumed that the type has a member function `serialize`
236+
* which feeds all its class members into an archive using the `operator&`.
159237
*
160-
* // provide a specialization of mpi_type
161-
* template <> struct mpi::mpi_type<foo> : mpi::mpi_type_from_tie<foo> {};
238+
* @code{.cpp}
239+
* // type to use for MPI communication
240+
* struct foo {
241+
* double x;
242+
* int y;
243+
* void serialize(auto& ar) const { ar & x & y; }
244+
* };
162245
* @endcode
163246
*
164247
* @tparam T Type to be converted to an `MPI_Datatype`.
165248
*/
166-
template <typename T> struct mpi_type_from_tie {
167-
[[nodiscard]] static MPI_Datatype get() noexcept { return get_mpi_type(tie_data(T{})); }
249+
template <Serializable T> [[nodiscard]] MPI_Datatype get_mpi_type(const T &obj) {
250+
detail::mpi_archive ar(&obj);
251+
obj.serialize(ar);
252+
MPI_Datatype mpi_type{};
253+
MPI_Type_create_struct(static_cast<int>(ar.block_lengths.size()), ar.block_lengths.data(), ar.displacements.data(), ar.types.data(), &mpi_type);
254+
MPI_Type_commit(&mpi_type);
255+
return mpi_type;
256+
}
257+
258+
/**
259+
* @brief Specialization of mpi::mpi_type for serializable types.
260+
* @tparam S Serializable type.
261+
*/
262+
template <Serializable S> struct mpi_type<S> {
263+
[[nodiscard]] static MPI_Datatype get() noexcept {
264+
static MPI_Datatype type = get_mpi_type(S{});
265+
return type;
266+
}
168267
};
169268

170269
/** @} */

c++/mpi/utils.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
#include <stdexcept>
2727
#include <string>
28+
#include <type_traits>
2829

2930
namespace mpi {
3031

doc/DoxygenLayout.xml

+8-3
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,24 @@
2929
<tab type="user" url="@ref mpi::environment" title="environment"/>
3030
</tab>
3131
<tab type="usergroup" url="@ref mpi_types_ops" title="MPI datatypes and operations">
32+
<tab type="user" url="@ref mpi::Serializable" title="Serializable"/>
3233
<tab type="usergroup" url="@ref mpi::mpi_type" title="mpi_type">
3334
<tab type="user" url="@ref mpi::mpi_type< bool >" title="mpi_type<bool>"/>
3435
<tab type="user" url="@ref mpi::mpi_type< char >" title="mpi_type<char>"/>
36+
<tab type="user" url="@ref mpi::mpi_type< const T >" title="mpi_type<const T>"/>
37+
<tab type="user" url="@ref mpi::mpi_type< double >" title="mpi_type<double>"/>
38+
<tab type="user" url="@ref mpi::mpi_type< E >" title="mpi_type<E>"/>
39+
<tab type="user" url="@ref mpi::mpi_type< float >" title="mpi_type<float>"/>
3540
<tab type="user" url="@ref mpi::mpi_type< int >" title="mpi_type<int>"/>
3641
<tab type="user" url="@ref mpi::mpi_type< long >" title="mpi_type<long>"/>
3742
<tab type="user" url="@ref mpi::mpi_type< long long >" title="mpi_type<long long>"/>
38-
<tab type="user" url="@ref mpi::mpi_type< double >" title="mpi_type<double>"/>
39-
<tab type="user" url="@ref mpi::mpi_type< float >" title="mpi_type<float>"/>
43+
<tab type="user" url="@ref mpi::mpi_type< S >" title="mpi_type<S>"/>
4044
<tab type="user" url="@ref mpi::mpi_type< std::complex< double > >" title="mpi_type<std::complex<double>>"/>
45+
<tab type="user" url="@ref mpi::mpi_type< std::tuple< Ts... > >" title="mpi_type<std::tuple>"/>
46+
<tab type="user" url="@ref mpi::mpi_type< U >" title="mpi_type<U>"/>
4147
<tab type="user" url="@ref mpi::mpi_type< unsigned int >" title="mpi_type<unsigned int>"/>
4248
<tab type="user" url="@ref mpi::mpi_type< unsigned long >" title="mpi_type<unsigned long>"/>
4349
<tab type="user" url="@ref mpi::mpi_type< unsigned long long >" title="mpi_type<unsigned long long>"/>
44-
<tab type="user" url="@ref mpi::mpi_type< std::tuple< Ts... > >" title="mpi_type<std::tuple>"/>
4550
</tab>
4651
</tab>
4752
<tab type="user" url="@ref coll_comm" title="Collective MPI communication"/>

doc/ex3.md

+3-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
[TOC]
44

5-
In this example, we show how to use mpi::mpi_type_from_tie, mpi::map_C_function and mpi::map_add to register a new MPI datatype and to define MPI operations for it.
5+
In this example, we show how to register a new MPI datatype and how to use mpi::map_C_function and mpi::map_add to
6+
define MPI operations for it.
67

78
```cpp
89
#include <mpi/mpi.hpp>
@@ -19,14 +20,11 @@ inline my_complex operator+(const my_complex& z1, const my_complex& z2) {
1920
return { z1.real + z2.real, z1.imag + z2.imag };
2021
}
2122

22-
// define a tie_data function for mpi_type_from_tie
23+
// define a tie_data function for my_complex to make it MPI compatible
2324
inline auto tie_data(const my_complex& z) {
2425
return std::tie(z.real, z.imag);
2526
}
2627

27-
// register my_complex as an MPI type
28-
template <> struct mpi::mpi_type<my_complex> : mpi::mpi_type_from_tie<my_complex> {};
29-
3028
int main(int argc, char *argv[]) {
3129
// initialize MPI environment
3230
mpi::environment env(argc, argv);

test/c++/mpi_custom.cpp

+80-4
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,6 @@ struct custom_cplx {
3939
// tie the data (used to construct the custom MPI type)
4040
inline auto tie_data(custom_cplx z) { return std::tie(z.real, z.imag); }
4141

42-
// specialize mpi_type for custom_cplx
43-
template <> struct mpi::mpi_type<custom_cplx> : mpi::mpi_type_from_tie<custom_cplx> {};
44-
4542
// stand-alone add function (the same as the operator+ above)
4643
custom_cplx add(custom_cplx const &x, custom_cplx const &y) { return x + y; }
4744

@@ -131,9 +128,88 @@ TEST(MPI, TupleMPIDatatypes) {
131128

132129
using type5 = std::tuple<int, double, char, custom_cplx, bool>;
133130
type5 tup5;
134-
if (rank == root) { tup5 = std::make_tuple(100, 3.1314, 'r', custom_cplx{1.0, 2.0}, false); }
131+
if (rank == root) { tup5 = std::make_tuple(100, 3.1314, 'r', custom_cplx{.real = 1.0, .imag = 2.0}, false); }
135132
mpi::broadcast(tup5, world, root);
136133
EXPECT_EQ(tup5, std::make_tuple(100, 3.1314, 'r', custom_cplx{1.0, 2.0}, false));
137134
}
138135

136+
// a simple struct representing a complex number that is serializable
137+
struct serializable_cplx {
138+
double real{}, imag{};
139+
140+
// add two serializable_cplx objects
141+
serializable_cplx operator+(serializable_cplx z) const {
142+
z.real += real;
143+
z.imag += imag;
144+
return z;
145+
}
146+
147+
// default equal-to operator
148+
bool operator==(const serializable_cplx &) const = default;
149+
150+
// serialize the object
151+
void serialize(auto &ar) const { ar & real & imag; }
152+
void deserialize(auto &ar) { ar & real & imag; }
153+
};
154+
155+
// a simple struct that contains a serializable type and is serializable itself
156+
struct serializable_container {
157+
serializable_cplx z1;
158+
custom_cplx z2;
159+
160+
// add two serializable_container objects
161+
serializable_container operator+(serializable_container z) const {
162+
z.z1 = z.z1 + z1;
163+
z.z2 = z.z2 + z2;
164+
return z;
165+
}
166+
167+
// default equal-to operator
168+
bool operator==(const serializable_container &) const = default;
169+
170+
// serialize the object
171+
void serialize(auto &ar) const { ar & z1 & z2; }
172+
void deserialize(auto &ar) { ar & z1 & z2; }
173+
};
174+
175+
// check Serializable concept
176+
static_assert(mpi::Serializable<serializable_cplx>);
177+
static_assert(mpi::Serializable<serializable_container>);
178+
179+
TEST(MPI, SerializableMPIDatatypes) {
180+
mpi::communicator world;
181+
int rank = world.rank();
182+
int root = 0;
183+
184+
// check broadcast
185+
auto z_exp = serializable_cplx{.real = 1.0, .imag = 2.0};
186+
auto z = (rank == root ? z_exp : serializable_cplx{});
187+
mpi::broadcast(z, world, root);
188+
EXPECT_EQ(z, z_exp);
189+
190+
// check all_reduce
191+
auto z_red = mpi::all_reduce(z, world, mpi::map_add<serializable_cplx>());
192+
EXPECT_DOUBLE_EQ(z_exp.real * world.size(), z_red.real);
193+
EXPECT_DOUBLE_EQ(z.imag * world.size(), z_red.imag);
194+
}
195+
196+
TEST(MPI, SerializableOfSerializableMPIDatatypes) {
197+
mpi::communicator world;
198+
int rank = world.rank();
199+
int root = 0;
200+
201+
// check broadcast
202+
auto c_exp = serializable_container{.z1 = {.real = 1.0, .imag = 2.0}, .z2 = {.real = 3.0, .imag = 4.0}};
203+
auto c = (rank == root ? c_exp : serializable_container{});
204+
mpi::broadcast(c, world, root);
205+
EXPECT_EQ(c, c_exp);
206+
207+
// check all_reduce
208+
auto c_red = mpi::all_reduce(c, world, mpi::map_add<serializable_container>());
209+
EXPECT_DOUBLE_EQ(c_exp.z1.real * world.size(), c_red.z1.real);
210+
EXPECT_DOUBLE_EQ(c_exp.z1.imag * world.size(), c_red.z1.imag);
211+
EXPECT_DOUBLE_EQ(c_exp.z2.real * world.size(), c_red.z2.real);
212+
EXPECT_DOUBLE_EQ(c_exp.z2.imag * world.size(), c_red.z2.imag);
213+
}
214+
139215
MPI_TEST_MAIN;

0 commit comments

Comments
 (0)