@@ -95,6 +95,13 @@ const int num_colors = sizeof(colors) / sizeof(uint32_t);
95
95
#ifdef SOLUTION
96
96
#include < nccl.h>
97
97
#endif
98
+ #ifdef NCCL_VERSION
99
+ #define NCCL_VERSION_UB NCCL_VERSION (2 ,19 ,1 )
100
+ #define NCCL_UB_SUPPORT (NCCL_VERSION_CODE >= NCCL_VERSION_UB)  // 1 when the NCCL headers are at least NCCL_VERSION_UB; parenthesized so the macro stays a single operand inside larger #if expressions
101
+ #else
102
+ #define NCCL_UB_SUPPORT 0
103
+ #endif
104
+
98
105
99
106
#define NCCL_CALL (call ) \
100
107
{ \
@@ -172,6 +179,13 @@ int main(int argc, char* argv[]) {
172
179
const int nx = get_argval<int >(argv, argv + argc, " -nx" , 16384 );
173
180
const int ny = get_argval<int >(argv, argv + argc, " -ny" , 16384 );
174
181
const bool csv = get_arg (argv, argv + argc, " -csv" );
182
+ bool user_buffer_reg = get_arg (argv, argv + argc, " -user_buffer_reg" );
183
+ #if NCCL_UB_SUPPORT == 0
184
+ if (user_buffer_reg) {
185
+ fprintf (stderr," WARNING: Ignoring -user_buffer_reg, required NCCL APIs are provided by NCCL 2.19.1 or later.\n " );
186
+ user_buffer_reg = false ;
187
+ }
188
+ #endif // NCCL_UB_SUPPORT == 0
175
189
176
190
int local_rank = -1 ;
177
191
{
@@ -226,10 +240,30 @@ int main(int argc, char* argv[]) {
226
240
chunk_size = chunk_size_high;
227
241
228
242
real* a;
229
- CUDA_RT_CALL (cudaMalloc (&a, nx * (chunk_size + 2 ) * sizeof (real)));
230
243
real* a_new;
231
- CUDA_RT_CALL (cudaMalloc (&a_new, nx * (chunk_size + 2 ) * sizeof (real)));
244
+ #if NCCL_UB_SUPPORT
245
+ void * a_reg_handle;
246
+ void * a_new_reg_handle;
247
+ if (user_buffer_reg) {
248
+ // TODO: Allocate the memory with ncclMemAlloc and register it for the communicator
249
+ #ifdef SOLUTION
250
+
251
+ NCCL_CALL (ncclMemAlloc ( (void **) &a , nx * (chunk_size + 2 ) * sizeof (real)));
252
+ NCCL_CALL (ncclMemAlloc ( (void **) &a_new, nx * (chunk_size + 2 ) * sizeof (real)));
253
+ NCCL_CALL (ncclCommRegister (nccl_comm, a , nx * (chunk_size + 2 ) * sizeof (real), &a_reg_handle));
254
+ NCCL_CALL (ncclCommRegister (nccl_comm, a_new, nx * (chunk_size + 2 ) * sizeof (real), &a_new_reg_handle));
255
+ #endif
256
+ if ( nccl_version < 22304 ) {
257
+ fprintf (stderr," WARNING: -user_buffer_reg available, but Jacobi communication pattern needs NCCL 2.23.4 or later.\n " );
258
+ }
259
+ }
260
+ else
261
+ #endif // NCCL_UB_SUPPORT
232
262
263
+ {
264
+ CUDA_RT_CALL (cudaMalloc (&a, nx * (chunk_size + 2 ) * sizeof (real)));
265
+ CUDA_RT_CALL (cudaMalloc (&a_new, nx * (chunk_size + 2 ) * sizeof (real)));
266
+ }
233
267
CUDA_RT_CALL (cudaMemset (a, 0 , nx * (chunk_size + 2 ) * sizeof (real)));
234
268
CUDA_RT_CALL (cudaMemset (a_new, 0 , nx * (chunk_size + 2 ) * sizeof (real)));
235
269
@@ -434,10 +468,22 @@ int main(int argc, char* argv[]) {
434
468
435
469
CUDA_RT_CALL (cudaFreeHost (l2_norm_h));
436
470
CUDA_RT_CALL (cudaFree (l2_norm_d));
437
-
471
+ #if NCCL_UB_SUPPORT
472
+ if (user_buffer_reg) {
473
+ // TODO: Deregister and Free the Buffer
474
+ #ifdef SOLUTION
475
+ NCCL_CALL (ncclCommDeregister (nccl_comm, a_new_reg_handle));
476
+ NCCL_CALL (ncclCommDeregister (nccl_comm, a_reg_handle));
477
+ NCCL_CALL (ncclMemFree (a_new));
478
+ NCCL_CALL (ncclMemFree (a));
479
+ #endif
480
+ }
481
+ else
482
+ #endif // NCCL_UB_SUPPORT
483
+ {
438
484
CUDA_RT_CALL (cudaFree (a_new));
439
485
CUDA_RT_CALL (cudaFree (a));
440
-
486
+ }
441
487
CUDA_RT_CALL (cudaFreeHost (a_h));
442
488
CUDA_RT_CALL (cudaFreeHost (a_ref_h));
443
489
0 commit comments