Skip to content

Commit 7f21584

Browse files
Improve support for Particles with PolymorphicArenaAllocator (#4603)
## Summary This PR adds a `SetArena(Arena*)` function to ParticleContainerBase that allows setting a memory arena that is used for all the particle vectors if the allocator is PolymorphicArenaAllocator. The function has to be called before particle tiles are defined. This functionality is used to fix a bunch of places where a polymorphic vector would previously not have its arena set properly. Additionally the `RunOnGpu` logic in `AMReX_WriteBinaryParticleData.H` is extended to work with a polymorphic allocator. Uses changes extracted from #4404. ## Additional background Previously all components of all particle tiles needed their arena set individually by the user if PolymorphicArenaAllocator was used. In case this was done this PR is a braking change due to the `AMREX_ALWAYS_ASSERT_WITH_MESSAGE(a_arena != nullptr` assert in `ParticleTile::define()`. ## Checklist The proposed changes: - [ ] fix a bug or incorrect behavior in AMReX - [x] add new capabilities to AMReX - [ ] changes answers in the test suite to more than roundoff level - [ ] are likely to significantly affect the results of downstream AMReX users - [ ] include documentation in the code and/or rst files, if appropriate
1 parent c254bcc commit 7f21584

8 files changed

+240
-74
lines changed

Src/Base/AMReX_GpuAllocators.H

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ namespace amrex {
171171
template <typename T>
172172
struct IsPolymorphicArenaAllocator : std::false_type {};
173173

174+
template <typename T>
175+
struct IsPolymorphicArenaAllocator<PolymorphicArenaAllocator<T> > : std::true_type {};
176+
174177
#ifdef AMREX_USE_GPU
175178
template <typename T>
176179
struct RunOnGpu<ArenaAllocator<T> > : std::true_type {};
@@ -183,10 +186,6 @@ namespace amrex {
183186

184187
template <typename T>
185188
struct RunOnGpu<AsyncArenaAllocator<T> > : std::true_type {};
186-
187-
template <typename T>
188-
struct IsPolymorphicArenaAllocator<PolymorphicArenaAllocator<T> > : std::true_type {};
189-
190189
#endif // AMREX_USE_GPU
191190

192191
#ifdef AMREX_USE_GPU

Src/Particle/AMReX_NeighborParticlesI.H

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,8 @@ NeighborParticleContainer<NStructReal, NStructInt, NArrayReal, NArrayInt>
454454
const Vector<NeighborIndexMap>& map = local_map[lev][dst_index];
455455
const int num_ghosts = map.size();
456456
neighbors[lev][dst_index].define(this->NumRuntimeRealComps(),
457-
this->NumRuntimeIntComps());
457+
this->NumRuntimeIntComps(),
458+
nullptr, nullptr, this->arena());
458459
neighbors[lev][dst_index].resize(num_ghosts);
459460
local_neighbor_sizes[lev][dst_index] = neighbors[lev][dst_index].size();
460461
}

Src/Particle/AMReX_ParticleContainer.H

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,9 @@ public:
197197
using ParIterType = ParIter_impl<ParticleType, NArrayReal, NArrayInt, Allocator, CellAssignor>;
198198
using ParConstIterType = ParConstIter_impl<ParticleType, NArrayReal, NArrayInt, Allocator, CellAssignor>;
199199

200+
static constexpr bool has_polymorphic_allocator =
201+
IsPolymorphicArenaAllocator<Allocator<RealType>>::value;
202+
200203
//! \brief Default constructor - construct an empty particle container that has no concept
201204
//! of a level hierarchy. Must be properly initialized later.
202205
ParticleContainer_impl ()
@@ -1156,8 +1159,11 @@ public:
11561159
*/
11571160
ParticleTileType& DefineAndReturnParticleTile (int lev, int grid, int tile)
11581161
{
1159-
m_particles[lev][std::make_pair(grid, tile)].define(NumRuntimeRealComps(), NumRuntimeIntComps(), &m_soa_rdata_names, &m_soa_idata_names);
1160-
1162+
m_particles[lev][std::make_pair(grid, tile)].define(
1163+
NumRuntimeRealComps(), NumRuntimeIntComps(),
1164+
&m_soa_rdata_names, &m_soa_idata_names,
1165+
arena()
1166+
);
11611167
return ParticlesAt(lev, grid, tile);
11621168
}
11631169

@@ -1186,7 +1192,11 @@ public:
11861192
ParticleTileType& DefineAndReturnParticleTile (int lev, const Iterator& iter)
11871193
{
11881194
auto index = std::make_pair(iter.index(), iter.LocalTileIndex());
1189-
m_particles[lev][index].define(NumRuntimeRealComps(), NumRuntimeIntComps(), &m_soa_rdata_names, &m_soa_idata_names);
1195+
m_particles[lev][index].define(
1196+
NumRuntimeRealComps(), NumRuntimeIntComps(),
1197+
&m_soa_rdata_names, &m_soa_idata_names,
1198+
arena()
1199+
);
11901200
return ParticlesAt(lev, iter);
11911201
}
11921202

Src/Particle/AMReX_ParticleContainerBase.H

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,14 @@ public:
245245
template <class MF>
246246
bool OnSameGrids (int level, const MF& mf) const { return m_gdb->OnSameGrids(level, mf); }
247247

248+
[[nodiscard]] Arena* arena () const {
249+
return m_arena;
250+
}
251+
252+
void SetArena (Arena* a) {
253+
m_arena = a;
254+
}
255+
248256
static const std::string& CheckpointVersion ();
249257
static const std::string& PlotfileVersion ();
250258
static const std::string& DataPrefix ();
@@ -269,6 +277,7 @@ protected:
269277
std::unique_ptr<ParGDB> m_gdb_object = std::make_unique<ParGDB>();
270278
ParGDBBase* m_gdb{nullptr};
271279
Vector<std::unique_ptr<MultiFab> > m_dummy_mf;
280+
Arena* m_arena = nullptr;
272281

273282
mutable std::unique_ptr<iMultiFab> redistribute_mask_ptr;
274283
mutable int redistribute_mask_nghost = std::numeric_limits<int>::min();

Src/Particle/AMReX_ParticleContainerI.H

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,8 +1255,11 @@ ParticleContainer_impl<ParticleType, NArrayReal, NArrayInt, Allocator, CellAssig
12551255
}
12561256
Gpu::streamSynchronize();
12571257
} else {
1258-
typename SoA::IdCPU tmp_idcpu(np_total);
1259-
1258+
typename SoA::IdCPU tmp_idcpu;
1259+
if constexpr (has_polymorphic_allocator) {
1260+
tmp_idcpu.setArena(arena());
1261+
}
1262+
tmp_idcpu.resize(np_total);
12601263
auto src = ptile.GetStructOfArrays().GetIdCPUData().data();
12611264
uint64_t* dst = tmp_idcpu.data();
12621265
AMREX_HOST_DEVICE_FOR_1D( np_total, i,
@@ -1270,7 +1273,11 @@ ParticleContainer_impl<ParticleType, NArrayReal, NArrayInt, Allocator, CellAssig
12701273
}
12711274

12721275
{ // Create a scope for the temporary vector below
1273-
RealVector tmp_real(np_total);
1276+
RealVector tmp_real;
1277+
if constexpr (has_polymorphic_allocator) {
1278+
tmp_real.setArena(arena());
1279+
}
1280+
tmp_real.resize(np_total);
12741281
for (int comp = 0; comp < NArrayReal + m_num_runtime_real; ++comp) {
12751282
auto src = ptile.GetStructOfArrays().GetRealData(comp).data();
12761283
ParticleReal* dst = tmp_real.data();
@@ -1285,7 +1292,11 @@ ParticleContainer_impl<ParticleType, NArrayReal, NArrayInt, Allocator, CellAssig
12851292
}
12861293
}
12871294

1288-
IntVector tmp_int(np_total);
1295+
IntVector tmp_int;
1296+
if constexpr (has_polymorphic_allocator) {
1297+
tmp_int.setArena(arena());
1298+
}
1299+
tmp_int.resize(np_total);
12891300
for (int comp = 0; comp < NArrayInt + m_num_runtime_int; ++comp) {
12901301
auto src = ptile.GetStructOfArrays().GetIntData(comp).data();
12911302
int* dst = tmp_int.data();
@@ -1300,7 +1311,8 @@ ParticleContainer_impl<ParticleType, NArrayReal, NArrayInt, Allocator, CellAssig
13001311
}
13011312
} else {
13021313
ParticleTileType ptile_tmp;
1303-
ptile_tmp.define(m_num_runtime_real, m_num_runtime_int, &m_soa_rdata_names, &m_soa_idata_names);
1314+
ptile_tmp.define(m_num_runtime_real, m_num_runtime_int,
1315+
&m_soa_rdata_names, &m_soa_idata_names, arena());
13041316
ptile_tmp.resize(np_total);
13051317
// copy re-ordered particles
13061318
gatherParticles(ptile_tmp, ptile, np, permutations);
@@ -1643,7 +1655,21 @@ ParticleContainer_impl<ParticleType, NArrayReal, NArrayInt, Allocator, CellAssig
16431655
tmp_local[lev][index].resize(num_threads);
16441656
soa_local[lev][index].resize(num_threads);
16451657
for (int t = 0; t < num_threads; ++t) {
1646-
soa_local[lev][index][t].define(m_num_runtime_real, m_num_runtime_int, &m_soa_rdata_names, &m_soa_idata_names);
1658+
soa_local[lev][index][t].define(m_num_runtime_real, m_num_runtime_int,
1659+
&m_soa_rdata_names, &m_soa_idata_names);
1660+
if constexpr (has_polymorphic_allocator) {
1661+
if constexpr (ParticleType::is_soa_particle) {
1662+
soa_local[lev][index][t].GetIdCPUData().setArena(arena());
1663+
} else {
1664+
tmp_local[lev][index][t].setArena(arena());
1665+
}
1666+
for (int j = 0; j < soa_local[lev][index][t].NumRealComps(); ++j) {
1667+
soa_local[lev][index][t].GetRealData(j).setArena(arena());
1668+
}
1669+
for (int j = 0; j < soa_local[lev][index][t].NumIntComps(); ++j) {
1670+
soa_local[lev][index][t].GetIntData(j).setArena(arena());
1671+
}
1672+
}
16471673
}
16481674
}
16491675
}

Src/Particle/AMReX_ParticleIO.H

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,8 +1215,9 @@ ParticleContainer_impl<ParticleType, NArrayReal, NArrayInt, Allocator, CellAssig
12151215
auto& pmap = m_particles[lev];
12161216
for (const auto& kv : pmap) {
12171217
ParticleTile<ParticleType, NArrayReal, NArrayInt,
1218-
amrex::PinnedArenaAllocator> pinned_ptile;
1219-
pinned_ptile.define(NumRuntimeRealComps(), NumRuntimeIntComps());
1218+
amrex::PolymorphicArenaAllocator> pinned_ptile;
1219+
pinned_ptile.define(NumRuntimeRealComps(), NumRuntimeIntComps(),
1220+
nullptr, nullptr, The_Pinned_Arena());
12201221
pinned_ptile.resize(kv.second.numParticles());
12211222
amrex::copyParticles(pinned_ptile, kv.second);
12221223
const auto& host_aos = pinned_ptile.GetArrayOfStructs();

Src/Particle/AMReX_ParticleTile.H

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,9 @@ struct ParticleTile
748748
using ParticleTileDataType = ParticleTileData<StorageParticleType, NArrayReal, NArrayInt>;
749749
using ConstParticleTileDataType = ConstParticleTileData<StorageParticleType, NArrayReal, NArrayInt>;
750750

751+
static constexpr bool has_polymorphic_allocator =
752+
IsPolymorphicArenaAllocator<Allocator<RealType>>::value;
753+
751754
ParticleTile () = default;
752755

753756
#ifndef _WIN32 // workaround windows compiler bug
@@ -764,15 +767,51 @@ struct ParticleTile
764767
int a_num_runtime_real,
765768
int a_num_runtime_int,
766769
std::vector<std::string>* soa_rdata_names=nullptr,
767-
std::vector<std::string>* soa_idata_names=nullptr
770+
std::vector<std::string>* soa_idata_names=nullptr,
771+
Arena* a_arena=nullptr
768772
)
769773
{
770-
m_defined = true;
771774
GetStructOfArrays().define(a_num_runtime_real, a_num_runtime_int, soa_rdata_names, soa_idata_names);
772775
m_runtime_r_ptrs.resize(a_num_runtime_real);
773776
m_runtime_i_ptrs.resize(a_num_runtime_int);
774777
m_runtime_r_cptrs.resize(a_num_runtime_real);
775778
m_runtime_i_cptrs.resize(a_num_runtime_int);
779+
780+
if constexpr (has_polymorphic_allocator) {
781+
if (m_defined) {
782+
// it is not allowed to change the arena after the tile has been defined
783+
if constexpr (ParticleType::is_soa_particle) {
784+
AMREX_ALWAYS_ASSERT_WITH_MESSAGE(
785+
a_arena == GetStructOfArrays().GetIdCPUData().arena(),
786+
"ParticleTile with PolymorphicArenaAllocator redefined with "
787+
"different memory arena");
788+
} else {
789+
AMREX_ALWAYS_ASSERT_WITH_MESSAGE(
790+
a_arena == m_aos_tile().arena(),
791+
"ParticleTile with PolymorphicArenaAllocator redefined with "
792+
"different memory arena");
793+
}
794+
}
795+
796+
AMREX_ALWAYS_ASSERT_WITH_MESSAGE(a_arena != nullptr,
797+
"ParticleTile with PolymorphicArenaAllocator defined with no memory arena! "
798+
"Make sure to call setArena() on the ParticleContainer before initialization or "
799+
"to pass an Arena to ParticleTile::define()");
800+
801+
if constexpr (ParticleType::is_soa_particle) {
802+
GetStructOfArrays().GetIdCPUData().setArena(a_arena);
803+
} else {
804+
m_aos_tile().setArena(a_arena);
805+
}
806+
for (int j = 0; j < NumRealComps(); ++j) {
807+
GetStructOfArrays().GetRealData(j).setArena(a_arena);
808+
}
809+
for (int j = 0; j < NumIntComps(); ++j) {
810+
GetStructOfArrays().GetIntData(j).setArena(a_arena);
811+
}
812+
}
813+
814+
m_defined = true;
776815
}
777816

778817
// Get id data

0 commit comments

Comments
 (0)