@@ -76,6 +76,14 @@ namespace mpi {
76
76
D (unsigned long long , MPI_UNSIGNED_LONG_LONG);
77
77
#undef D
78
78
79
+ /* *
80
+ * @brief Specialization of mpi::mpi_type for enum types
81
+ * @tparam T C++ enum type.
82
+ */
83
+ template <typename T>
84
+ requires (std::is_enum_v<T>)
85
+ struct mpi_type<T> : mpi_type<std::underlying_type_t<T>> {};
86
+
79
87
/* *
80
88
* @brief Specialization of mpi::mpi_type for `const` types.
81
89
* @tparam T C++ type.
@@ -136,7 +144,10 @@ namespace mpi {
136
144
* @tparam Ts Tuple element types.
137
145
*/
138
146
template <typename ... T> struct mpi_type <std::tuple<T...>> {
139
- [[nodiscard]] static MPI_Datatype get () noexcept { return get_mpi_type (std::tuple<T...>{}); }
147
+ [[nodiscard]] static MPI_Datatype get () noexcept {
148
+ static MPI_Datatype type = get_mpi_type (std::tuple<T...>{});
149
+ return type;
150
+ }
140
151
};
141
152
142
153
/* *
@@ -163,8 +174,75 @@ namespace mpi {
163
174
*
164
175
* @tparam T Type to be converted to an `MPI_Datatype`.
165
176
*/
166
- template <typename T> struct mpi_type_from_tie {
167
- [[nodiscard]] static MPI_Datatype get () noexcept { return get_mpi_type (tie_data (T{})); }
177
+ template <typename T>
178
+ requires requires (T t) { tie_data (t); }
179
+ struct mpi_type <T> {
180
+ [[nodiscard]] static MPI_Datatype get () noexcept {
181
+ static MPI_Datatype type = get_mpi_type (tie_data (T{}));
182
+ return type;
183
+ }
184
+ };
185
+
186
+ namespace detail {
187
+ // Archive helper class obtain MPI custom type info using reference to class members
188
+ struct MpiArchive {
189
+ std::vector<int > block_lengths{};
190
+ std::vector<MPI_Aint> displacements{};
191
+ std::vector<MPI_Datatype> types{};
192
+ MPI_Aint base_address{};
193
+
194
+ public:
195
+ explicit MpiArchive (const void *base) { MPI_Get_address (base, &base_address); }
196
+
197
+ // Overloaded operator& to process members
198
+ template <typename T>
199
+ requires (has_mpi_type<T>)
200
+ MpiArchive &operator &(const T &member) {
201
+ types.push_back (mpi_type<T>::get ());
202
+ MPI_Aint address{};
203
+ MPI_Get_address (&member, &address);
204
+ displacements.push_back (address - base_address);
205
+ block_lengths.push_back (1 );
206
+ return *this ;
207
+ }
208
+ };
209
+ } // namespace detail
210
+
211
+ /* *
212
+ * @brief Create an `MPI_Datatype` from a serializable type.
213
+ *
214
+ * @details It is assumed that the type has a member function `serialize`
215
+ * which feeds all its class members into an archive using the `operator&`.
216
+ *
217
+ * @code{.cpp}
218
+ * // type to use for MPI communication
219
+ * struct foo {
220
+ * double x;
221
+ * int y;
222
+ * void serialize(auto& ar) const { ar & x & y; }
223
+ * };
224
+ * @endcode
225
+ *
226
+ * @tparam T Type to be converted to an `MPI_Datatype`.
227
+ */
228
+ template <Serializable T> [[nodiscard]] MPI_Datatype get_mpi_type (const T &obj) {
229
+ detail::MpiArchive ar (&obj);
230
+ obj.serialize (ar);
231
+ MPI_Datatype mpi_type{};
232
+ MPI_Type_create_struct (ar.block_lengths .size (), ar.block_lengths .data (), ar.displacements .data (), ar.types .data (), &mpi_type);
233
+ MPI_Type_commit (&mpi_type);
234
+ return mpi_type;
235
+ }
236
+
237
+ /* *
238
+ * @brief Specialization of mpi::mpi_type for serializable types.
239
+ * @tparam T Serializable type.
240
+ */
241
+ template <Serializable T> struct mpi_type <T> {
242
+ [[nodiscard]] static MPI_Datatype get () noexcept {
243
+ static MPI_Datatype type = get_mpi_type (T{});
244
+ return type;
245
+ }
168
246
};
169
247
170
248
/* * @} */
0 commit comments