Skip to content

Commit faaeb1b

Browse files
committed
Create custom mpi types for serializable custom types
- Add Serializable concept
1 parent a22f9a8 commit faaeb1b

File tree

3 files changed

+107
-6
lines changed

3 files changed

+107
-6
lines changed

c++/mpi/datatypes.hpp

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,14 @@ namespace mpi {
7676
D(unsigned long long, MPI_UNSIGNED_LONG_LONG);
7777
#undef D
7878

79+
/**
80+
* @brief Specialization of mpi::mpi_type for enum types
81+
* @tparam T C++ enum type.
82+
*/
83+
template <typename T>
84+
requires(std::is_enum_v<T>)
85+
struct mpi_type<T> : mpi_type<std::underlying_type_t<T>> {};
86+
7987
/**
8088
* @brief Specialization of mpi::mpi_type for `const` types.
8189
* @tparam T C++ type.
@@ -136,7 +144,10 @@ namespace mpi {
136144
* @tparam Ts Tuple element types.
137145
*/
138146
template <typename... T> struct mpi_type<std::tuple<T...>> {
139-
[[nodiscard]] static MPI_Datatype get() noexcept { return get_mpi_type(std::tuple<T...>{}); }
147+
[[nodiscard]] static MPI_Datatype get() noexcept {
148+
static MPI_Datatype type = get_mpi_type(std::tuple<T...>{});
149+
return type;
150+
}
140151
};
141152

142153
/**
@@ -163,8 +174,75 @@ namespace mpi {
163174
*
164175
* @tparam T Type to be converted to an `MPI_Datatype`.
165176
*/
166-
template <typename T> struct mpi_type_from_tie {
167-
[[nodiscard]] static MPI_Datatype get() noexcept { return get_mpi_type(tie_data(T{})); }
177+
template <typename T>
178+
requires requires(T t) { tie_data(t); }
179+
struct mpi_type<T> {
180+
[[nodiscard]] static MPI_Datatype get() noexcept {
181+
static MPI_Datatype type = get_mpi_type(tie_data(T{}));
182+
return type;
183+
}
184+
};
185+
186+
namespace detail {
187+
// Archive helper class obtain MPI custom type info using reference to class members
188+
struct MpiArchive {
189+
std::vector<int> block_lengths{};
190+
std::vector<MPI_Aint> displacements{};
191+
std::vector<MPI_Datatype> types{};
192+
MPI_Aint base_address{};
193+
194+
public:
195+
explicit MpiArchive(const void *base) { MPI_Get_address(base, &base_address); }
196+
197+
// Overloaded operator& to process members
198+
template <typename T>
199+
requires(has_mpi_type<T>)
200+
MpiArchive &operator&(const T &member) {
201+
types.push_back(mpi_type<T>::get());
202+
MPI_Aint address{};
203+
MPI_Get_address(&member, &address);
204+
displacements.push_back(address - base_address);
205+
block_lengths.push_back(1);
206+
return *this;
207+
}
208+
};
209+
} // namespace detail
210+
211+
/**
212+
* @brief Create an `MPI_Datatype` from a serializable type.
213+
*
214+
* @details It is assumed that the type has a member function `serialize`
215+
* which feeds all its class members into an archive using the `operator&`.
216+
*
217+
* @code{.cpp}
218+
* // type to use for MPI communication
219+
* struct foo {
220+
* double x;
221+
* int y;
222+
* void serialize(auto& ar) const { ar & x & y; }
223+
* };
224+
* @endcode
225+
*
226+
* @tparam T Type to be converted to an `MPI_Datatype`.
227+
*/
228+
template <Serializable T> [[nodiscard]] MPI_Datatype get_mpi_type(const T &obj) {
229+
detail::MpiArchive ar(&obj);
230+
obj.serialize(ar);
231+
MPI_Datatype mpi_type{};
232+
MPI_Type_create_struct(ar.block_lengths.size(), ar.block_lengths.data(), ar.displacements.data(), ar.types.data(), &mpi_type);
233+
MPI_Type_commit(&mpi_type);
234+
return mpi_type;
235+
}
236+
237+
/**
238+
* @brief Specialization of mpi::mpi_type for serializable types.
239+
* @tparam T Serializable type.
240+
*/
241+
template <Serializable T> struct mpi_type<T> {
242+
[[nodiscard]] static MPI_Datatype get() noexcept {
243+
static MPI_Datatype type = get_mpi_type(T{});
244+
return type;
245+
}
168246
};
169247

170248
/** @} */

c++/mpi/utils.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
#include <stdexcept>
2727
#include <string>
28+
#include <type_traits>
2829

2930
namespace mpi {
3031

@@ -80,6 +81,31 @@ namespace mpi {
8081
template <typename R>
8182
concept contiguous_sized_range = std::ranges::contiguous_range<R> && std::ranges::sized_range<R>;
8283

84+
namespace detail {
85+
86+
// Helper struct to check if a types serialize function serializes only fundamental types or enums
87+
struct serialize_checker {
88+
89+
template <typename T>
90+
void operator&(T const &t)
91+
requires(std::is_fundamental_v<T> or std::is_enum_v<T> or requires { t.serialize(*this); })
92+
{}
93+
94+
template <typename T>
95+
void operator&(T &t)
96+
requires(std::is_fundamental_v<T> or std::is_enum_v<T> or requires { t.deserialize(*this); })
97+
{}
98+
};
99+
100+
} // namespace detail
101+
102+
/// Check if objects of the type can be serialized and deserialized
103+
template <typename T>
104+
concept Serializable = requires(T a, detail::serialize_checker ar) {
105+
a.serialize(ar);
106+
a.deserialize(ar);
107+
};
108+
83109
/** @} */
84110

85111
} // namespace mpi

test/c++/mpi_custom.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,6 @@ struct custom_cplx {
3939
// tie the data (used to construct the custom MPI type)
4040
inline auto tie_data(custom_cplx z) { return std::tie(z.real, z.imag); }
4141

42-
// specialize mpi_type for custom_cplx
43-
template <> struct mpi::mpi_type<custom_cplx> : mpi::mpi_type_from_tie<custom_cplx> {};
44-
4542
// stand-alone add function (the same as the operator+ above)
4643
custom_cplx add(custom_cplx const &x, custom_cplx const &y) { return x + y; }
4744

0 commit comments

Comments
 (0)