Skip to content

Commit 48f0a03

Browse files
committed
Armadillo 15.2.2
1 parent 0828e75 commit 48f0a03

File tree

2 files changed

+74
-13
lines changed

2 files changed

+74
-13
lines changed

inst/include/current/armadillo_bits/arma_rng.hpp

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
#undef ARMA_USE_THREAD_LOCAL
2727
#define ARMA_USE_THREAD_LOCAL
2828

29+
#undef ARMA_USE_THREAD_UNIQUE_RNG_SEED
30+
#define ARMA_USE_THREAD_UNIQUE_RNG_SEED
31+
2932
#if (defined(ARMA_RNG_ALT) || defined(ARMA_DONT_USE_CXX11_RNG))
3033
#undef ARMA_USE_CXX11_RNG
3134
#endif
@@ -34,6 +37,10 @@
3437
#undef ARMA_USE_THREAD_LOCAL
3538
#endif
3639

40+
#if defined(ARMA_DONT_USE_THREAD_UNIQUE_RNG_SEED)
41+
#undef ARMA_USE_THREAD_UNIQUE_RNG_SEED
42+
#endif
43+
3744

3845
// NOTE: ARMA_WARMUP_PRODUCER enables a workaround
3946
// NOTE: for thread_local issue on macOS 11 and/or AppleClang 12.0
@@ -129,24 +136,42 @@ arma_rng::get_producer()
129136
{
130137
#if defined(ARMA_USE_THREAD_LOCAL)
131138

132-
// use a thread-safe RNG, with each thread having its own unique starting seed
133-
134-
static std::atomic<std::size_t> mt19937_64_producer_counter(0);
135-
136-
static thread_local std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed + mt19937_64_producer_counter++ );
139+
// thread-safe RNG
137140

138-
arma_rng::warmup_producer(mt19937_64_producer);
141+
#if defined(ARMA_USE_THREAD_UNIQUE_RNG_SEED)
142+
143+
// each thread has unique starting seed
144+
145+
#if defined(ARMA_USE_OPENMP)
146+
147+
static thread_local std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed + arma_rng::seed_type(omp_get_thread_num()) );
148+
149+
#else
150+
151+
static std::atomic<std::size_t> mt19937_64_producer_counter(0);
152+
153+
static thread_local std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed + mt19937_64_producer_counter++ );
154+
155+
#endif
156+
157+
#else
158+
159+
// each thread has the same starting seed
160+
161+
static thread_local std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed );
162+
163+
#endif
139164

140165
#else
141166

142-
// use a plain RNG in case we don't have thread_local
167+
// plain RNG in case we don't have thread_local
143168

144169
static std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed );
145170

146-
arma_rng::warmup_producer(mt19937_64_producer);
147-
148171
#endif
149172

173+
arma_rng::warmup_producer(mt19937_64_producer);
174+
150175
return mt19937_64_producer;
151176
}
152177

@@ -226,9 +251,45 @@ arma_rng::set_seed(const arma_rng::seed_type val)
226251
}
227252
#elif defined(ARMA_USE_CXX11_RNG)
228253
{
229-
arma_rng::lock_producer();
230-
arma_rng::get_producer().seed(val);
231-
arma_rng::unlock_producer();
254+
#if defined(ARMA_USE_OPENMP) && defined(ARMA_USE_THREAD_LOCAL)
255+
{
256+
arma_rng::lock_producer();
257+
258+
#if defined(ARMA_USE_THREAD_UNIQUE_RNG_SEED)
259+
constexpr bool thread_unique_rng_seed = true;
260+
#else
261+
constexpr bool thread_unique_rng_seed = false;
262+
#endif
263+
264+
// if we're already in a parallel region, assume the user is setting the seed for each thread
265+
266+
if( (thread_unique_rng_seed == false) || bool(omp_in_parallel()) )
267+
{
268+
arma_rng::get_producer().seed(val);
269+
}
270+
else
271+
{
272+
const int n_threads = int( (std::max)( int(1), int(omp_get_max_threads()) ) );
273+
274+
#pragma omp parallel for ordered schedule(static) num_threads(n_threads)
275+
for(int t=0; t < n_threads; ++t)
276+
{
277+
#pragma omp ordered
278+
{
279+
arma_rng::get_producer().seed(val + arma_rng::seed_type(omp_get_thread_num()));
280+
}
281+
}
282+
}
283+
284+
arma_rng::unlock_producer();
285+
}
286+
#else
287+
{
288+
arma_rng::lock_producer();
289+
arma_rng::get_producer().seed(val);
290+
arma_rng::unlock_producer();
291+
}
292+
#endif
232293
}
233294
#else
234295
{

inst/include/current/armadillo_bits/arma_version.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
#define ARMA_VERSION_MAJOR 15
2525
#define ARMA_VERSION_MINOR 2
26-
#define ARMA_VERSION_PATCH 1
26+
#define ARMA_VERSION_PATCH 2
2727
#define ARMA_VERSION_NAME "Medium Roast Deluxe"
2828

2929

0 commit comments

Comments
 (0)