|
26 | 26 | #undef ARMA_USE_THREAD_LOCAL |
27 | 27 | #define ARMA_USE_THREAD_LOCAL |
28 | 28 |
|
| 29 | +#undef ARMA_USE_THREAD_UNIQUE_RNG_SEED |
| 30 | +#define ARMA_USE_THREAD_UNIQUE_RNG_SEED |
| 31 | + |
29 | 32 | #if (defined(ARMA_RNG_ALT) || defined(ARMA_DONT_USE_CXX11_RNG)) |
30 | 33 | #undef ARMA_USE_CXX11_RNG |
31 | 34 | #endif |
|
34 | 37 | #undef ARMA_USE_THREAD_LOCAL |
35 | 38 | #endif |
36 | 39 |
|
| 40 | +#if defined(ARMA_DONT_USE_THREAD_UNIQUE_RNG_SEED) |
| 41 | + #undef ARMA_USE_THREAD_UNIQUE_RNG_SEED |
| 42 | +#endif |
| 43 | + |
37 | 44 |
|
38 | 45 | // NOTE: ARMA_WARMUP_PRODUCER enables a workaround |
39 | 46 | // NOTE: for thread_local issue on macOS 11 and/or AppleClang 12.0 |
@@ -129,24 +136,42 @@ arma_rng::get_producer() |
129 | 136 | { |
130 | 137 | #if defined(ARMA_USE_THREAD_LOCAL) |
131 | 138 |
|
132 | | - // use a thread-safe RNG, with each thread having its own unique starting seed |
133 | | - |
134 | | - static std::atomic<std::size_t> mt19937_64_producer_counter(0); |
135 | | - |
136 | | - static thread_local std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed + mt19937_64_producer_counter++ ); |
| 139 | + // thread-safe RNG |
137 | 140 |
|
138 | | - arma_rng::warmup_producer(mt19937_64_producer); |
| 141 | + #if defined(ARMA_USE_THREAD_UNIQUE_RNG_SEED) |
| 142 | + |
| 143 | + // each thread has unique starting seed |
| 144 | + |
| 145 | + #if defined(ARMA_USE_OPENMP) |
| 146 | + |
| 147 | + static thread_local std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed + arma_rng::seed_type(omp_get_thread_num()) ); |
| 148 | + |
| 149 | + #else |
| 150 | + |
| 151 | + static std::atomic<std::size_t> mt19937_64_producer_counter(0); |
| 152 | + |
| 153 | + static thread_local std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed + mt19937_64_producer_counter++ ); |
| 154 | + |
| 155 | + #endif |
| 156 | + |
| 157 | + #else |
| 158 | + |
| 159 | + // each thread has the same starting seed |
| 160 | + |
| 161 | + static thread_local std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed ); |
| 162 | + |
| 163 | + #endif |
139 | 164 |
|
140 | 165 | #else |
141 | 166 |
|
142 | | - // use a plain RNG in case we don't have thread_local |
| 167 | + // plain RNG in case we don't have thread_local |
143 | 168 |
|
144 | 169 | static std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed ); |
145 | 170 |
|
146 | | - arma_rng::warmup_producer(mt19937_64_producer); |
147 | | - |
148 | 171 | #endif |
149 | 172 |
|
| 173 | + arma_rng::warmup_producer(mt19937_64_producer); |
| 174 | + |
150 | 175 | return mt19937_64_producer; |
151 | 176 | } |
152 | 177 |
|
@@ -226,9 +251,45 @@ arma_rng::set_seed(const arma_rng::seed_type val) |
226 | 251 | } |
227 | 252 | #elif defined(ARMA_USE_CXX11_RNG) |
228 | 253 | { |
229 | | - arma_rng::lock_producer(); |
230 | | - arma_rng::get_producer().seed(val); |
231 | | - arma_rng::unlock_producer(); |
| 254 | + #if defined(ARMA_USE_OPENMP) && defined(ARMA_USE_THREAD_LOCAL) |
| 255 | + { |
| 256 | + arma_rng::lock_producer(); |
| 257 | + |
| 258 | + #if defined(ARMA_USE_THREAD_UNIQUE_RNG_SEED) |
| 259 | + constexpr bool thread_unique_rng_seed = true; |
| 260 | + #else |
| 261 | + constexpr bool thread_unique_rng_seed = false; |
| 262 | + #endif |
| 263 | + |
| 264 | + // if we're already in a parallel region, assume the user is setting the seed for each thread |
| 265 | + |
| 266 | + if( (thread_unique_rng_seed == false) || bool(omp_in_parallel()) ) |
| 267 | + { |
| 268 | + arma_rng::get_producer().seed(val); |
| 269 | + } |
| 270 | + else |
| 271 | + { |
| 272 | + const int n_threads = int( (std::max)( int(1), int(omp_get_max_threads()) ) ); |
| 273 | + |
| 274 | + #pragma omp parallel for ordered schedule(static) num_threads(n_threads) |
| 275 | + for(int t=0; t < n_threads; ++t) |
| 276 | + { |
| 277 | + #pragma omp ordered |
| 278 | + { |
| 279 | + arma_rng::get_producer().seed(val + arma_rng::seed_type(omp_get_thread_num())); |
| 280 | + } |
| 281 | + } |
| 282 | + } |
| 283 | + |
| 284 | + arma_rng::unlock_producer(); |
| 285 | + } |
| 286 | + #else |
| 287 | + { |
| 288 | + arma_rng::lock_producer(); |
| 289 | + arma_rng::get_producer().seed(val); |
| 290 | + arma_rng::unlock_producer(); |
| 291 | + } |
| 292 | + #endif |
232 | 293 | } |
233 | 294 | #else |
234 | 295 | { |
|
0 commit comments