@@ -93,6 +93,12 @@ const int num_colors = sizeof(colors) / sizeof(uint32_t);
93
93
94
94
// TODO: include NCCL headers
95
95
#include < nccl.h>
96
// Compile-time detection of NCCL user buffer registration support.
//
// NCCL_UB_SUPPORT evaluates to non-zero when the NCCL headers in use are
// version 2.19.1 or later (the release that provides the APIs required for
// -user_buffer_reg), and to 0 otherwise -- including when the NCCL_VERSION
// helper macro is not available at all.
#ifdef NCCL_VERSION
// Minimum NCCL version code that provides the user buffer registration APIs.
#define NCCL_VERSION_UB NCCL_VERSION(2, 19, 1)
// Parenthesized so the macro behaves as a single operand wherever it is
// expanded (e.g. `#if !NCCL_UB_SUPPORT` or `#if NCCL_UB_SUPPORT == 0`)
// instead of depending on the precedence of the surrounding operators.
#define NCCL_UB_SUPPORT (NCCL_VERSION_CODE >= NCCL_VERSION_UB)
#else
#define NCCL_UB_SUPPORT 0
#endif
96
102
97
103
#define NCCL_CALL (call ) \
98
104
{ \
@@ -168,7 +174,13 @@ int main(int argc, char* argv[]) {
168
174
const int nx = get_argval<int >(argv, argv + argc, " -nx" , 16384 );
169
175
const int ny = get_argval<int >(argv, argv + argc, " -ny" , 16384 );
170
176
const bool csv = get_arg (argv, argv + argc, " -csv" );
171
-
177
+ bool user_buffer_reg = get_arg (argv, argv + argc, " -user_buffer_reg" );
178
+ #if NCCL_UB_SUPPORT == 0
179
+ if (user_buffer_reg) {
180
+ fprintf (stderr," WARNING: Ignoring -user_buffer_reg, required NCCL APIs are provided by NCCL 2.19.1 or later.\n " );
181
+ user_buffer_reg = false ;
182
+ }
183
+ #endif // NCCL_UB_SUPPORT == 0
172
184
int local_rank = -1 ;
173
185
{
174
186
MPI_Comm local_comm;
@@ -220,10 +232,27 @@ int main(int argc, char* argv[]) {
220
232
chunk_size = chunk_size_high;
221
233
222
234
real* a;
223
- CUDA_RT_CALL (cudaMalloc (&a, nx * (chunk_size + 2 ) * sizeof (real)));
224
235
real* a_new;
225
- CUDA_RT_CALL (cudaMalloc (&a_new, nx * (chunk_size + 2 ) * sizeof (real)));
226
236
237
+ #if NCCL_UB_SUPPORT
238
+ void * a_reg_handle;
239
+ void * a_new_reg_handle;
240
+ if (user_buffer_reg) {
241
+ // TODO: Allocate the memory with ncclMemAlloc and register it for the communicator
242
+ NCCL_CALL (ncclMemAlloc ( (void **) &a , nx * (chunk_size + 2 ) * sizeof (real)));
243
+ NCCL_CALL (ncclMemAlloc ( (void **) &a_new, nx * (chunk_size + 2 ) * sizeof (real)));
244
+ NCCL_CALL (ncclCommRegister (nccl_comm, a , nx * (chunk_size + 2 ) * sizeof (real), &a_reg_handle));
245
+ NCCL_CALL (ncclCommRegister (nccl_comm, a_new, nx * (chunk_size + 2 ) * sizeof (real), &a_new_reg_handle));
246
+ if ( nccl_version < 22304 ) {
247
+ fprintf (stderr," WARNING: -user_buffer_reg available, but Jacobi communication pattern needs NCCL 2.23.4 or later.\n " );
248
+ }
249
+ }
250
+ else
251
+ #endif // NCCL_UB_SUPPORT
252
+ {
253
+ CUDA_RT_CALL (cudaMalloc (&a, nx * (chunk_size + 2 ) * sizeof (real)));
254
+ CUDA_RT_CALL (cudaMalloc (&a_new, nx * (chunk_size + 2 ) * sizeof (real)));
255
+ }
227
256
CUDA_RT_CALL (cudaMemset (a, 0 , nx * (chunk_size + 2 ) * sizeof (real)));
228
257
CUDA_RT_CALL (cudaMemset (a_new, 0 , nx * (chunk_size + 2 ) * sizeof (real)));
229
258
@@ -403,10 +432,20 @@ int main(int argc, char* argv[]) {
403
432
404
433
CUDA_RT_CALL (cudaFreeHost (l2_norm_h));
405
434
CUDA_RT_CALL (cudaFree (l2_norm_d));
406
-
435
+ #if NCCL_UB_SUPPORT
436
+ if (user_buffer_reg) {
437
+ // TODO: Deregister and Free the Buffer
438
+ NCCL_CALL (ncclCommDeregister (nccl_comm, a_new_reg_handle));
439
+ NCCL_CALL (ncclCommDeregister (nccl_comm, a_reg_handle));
440
+ NCCL_CALL (ncclMemFree (a_new));
441
+ NCCL_CALL (ncclMemFree (a));
442
+ }
443
+ else
444
+ #endif // NCCL_UB_SUPPORT
445
+ {
407
446
CUDA_RT_CALL (cudaFree (a_new));
408
447
CUDA_RT_CALL (cudaFree (a));
409
-
448
+ }
410
449
CUDA_RT_CALL (cudaFreeHost (a_h));
411
450
CUDA_RT_CALL (cudaFreeHost (a_ref_h));
412
451
0 commit comments