@@ -260,42 +260,45 @@ static inline ucc_status_t ucc_tl_ucp_get_memh(ucc_tl_ucp_team_t *team,
260
260
return UCC_OK ;
261
261
}
262
262
263
- static inline ucc_status_t ucc_tl_ucp_check_memh (ucp_ep_h * ep , void * va ,
264
- uint64_t * rva ,
265
- ucp_rkey_h * rkey , int tl_index ,
266
- ucc_mem_map_mem_h src_map_memh ,
267
- ucc_mem_map_mem_h dst_map_memh )
263
+ static inline ucc_status_t ucc_tl_ucp_check_memh (ucp_ep_h * ep , ucc_rank_t me ,
264
+ ucc_rank_t peer , void * va ,
265
+ uint64_t * rva , ucp_rkey_h * rkey ,
266
+ int tl_index ,
267
+ ucc_mem_map_mem_h * dst_map_memh )
268
268
{
269
- // check if src_memh or dest_memh have segment
270
- ucc_mem_map_memh_t * src_memh = src_map_memh ;
271
- ucc_mem_map_memh_t * dst_memh = dst_map_memh ;
272
- ucc_tl_ucp_memh_data_t * dst_tl_data =
273
- (ucc_tl_ucp_memh_data_t * )dst_memh -> tl_h [tl_index ].tl_data ;
274
- uint64_t base , end ;
275
- ucs_status_t ucs_status ;
276
- int i ;
277
- size_t offset ;
278
-
279
- base = (uint64_t )src_memh -> address ;
280
- end = base + src_memh -> len ;
269
+ ucc_mem_map_memh_t * * dst_memh = (ucc_mem_map_memh_t * * )dst_map_memh ;
270
+ ucc_tl_ucp_memh_data_t * dst_tl_data = NULL ;
271
+ uint64_t base ;
272
+ uint64_t end ;
273
+ ucs_status_t ucs_status ;
274
+ int i ;
275
+ size_t offset ;
276
+
277
+ base = (uint64_t )dst_memh [me ]-> address ;
278
+ end = base + dst_memh [me ]-> len ;
281
279
282
280
if (!((uint64_t )va >= base && (uint64_t )va < end )) {
283
281
return UCC_ERR_NOT_FOUND ;
284
282
}
285
- * rva = (uint64_t )PTR_OFFSET (dst_memh -> address ,
286
- ((uint64_t )va - (uint64_t )src_memh -> address ));
283
+ * rva = (uint64_t )PTR_OFFSET (dst_memh [peer ]-> address ,
284
+ ((uint64_t )va - (uint64_t )dst_memh [me ]-> address ));
285
+ dst_tl_data = (ucc_tl_ucp_memh_data_t * )dst_memh [peer ]-> tl_h [tl_index ].tl_data ;
287
286
if (NULL == dst_tl_data -> rkey ) {
288
287
offset = 0 ;
289
288
/* find pack location for tl */
290
289
for (i = 0 ; i < tl_index ; i ++ ) {
291
- size_t * p = PTR_OFFSET (dst_memh -> pack_buffer , offset );
292
- if (p [0 ] == tl_index ) {
290
+ char * name = PTR_OFFSET (dst_memh [peer ]-> pack_buffer , offset );
291
+ size_t * packed_size = PTR_OFFSET (dst_memh [peer ]-> pack_buffer , offset
292
+ + UCC_MEM_MAP_TL_NAME_LEN );
293
+ if (strncmp (name , "ucp" , 3 ) == 0 ) {
293
294
break ;
294
295
}
295
- offset += p [ 1 ] ;
296
+ offset += * packed_size ;
296
297
}
297
298
ucs_status = ucp_ep_rkey_unpack (
298
- * ep , PTR_OFFSET (dst_memh -> pack_buffer , offset + sizeof (size_t ) * 4 ),
299
+ * ep , PTR_OFFSET (dst_memh [peer ]-> pack_buffer , offset +
300
+ sizeof (size_t ) *
301
+ (UCC_TL_UCP_MEMH_TL_HEADERS + UCC_TL_UCP_MEMH_TL_PACKED_HEADERS )),
299
302
& dst_tl_data -> rkey );
300
303
if (UCS_OK != ucs_status ) {
301
304
return ucs_status_to_ucc_status (ucs_status );
@@ -309,9 +312,10 @@ static inline ucc_status_t ucc_tl_ucp_check_memh(ucp_ep_h *ep, void *va,
309
312
static inline ucc_status_t
310
313
ucc_tl_ucp_resolve_p2p_by_va (ucc_tl_ucp_team_t * team , void * va , ucp_ep_h * ep ,
311
314
ucc_rank_t peer , uint64_t * rva , ucp_rkey_h * rkey ,
312
- int * segment , ucc_mem_map_mem_h src_memh , ucc_mem_map_mem_h dst_memh )
315
+ int * segment , ucc_mem_map_mem_h * dst_memh )
313
316
{
314
317
ucc_tl_ucp_context_t * ctx = UCC_TL_UCP_TEAM_CTX (team );
318
+ ucc_rank_t grank = UCC_TL_TEAM_RANK (team );
315
319
ptrdiff_t key_offset = 0 ;
316
320
const size_t section_offset = sizeof (uint64_t ) * ctx -> n_rinfo_segs ;
317
321
int tl_index = 0 ;
@@ -347,16 +351,16 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, ucp_ep_h *ep,
347
351
key_offset += key_sizes [i ];
348
352
}
349
353
if (ucc_unlikely (0 > * segment )) {
350
- if (src_memh && dst_memh ) {
354
+ if (dst_memh ) {
351
355
/* check if segment is in src/dst memh */
352
- status = find_tl_index (src_memh , & tl_index );
356
+ status = find_tl_index (dst_memh [ grank ] , & tl_index );
353
357
if (status == UCC_ERR_NOT_FOUND ) {
354
358
tl_error (UCC_TL_TEAM_LIB (team ),
355
359
"attempt to perform one-sided operation with malformed mem map handle" );
356
360
return status ;
357
361
}
358
362
359
- status = ucc_tl_ucp_check_memh (ep , va , rva , rkey , tl_index , src_memh , dst_memh );
363
+ status = ucc_tl_ucp_check_memh (ep , grank , peer , va , rva , rkey , tl_index , dst_memh );
360
364
if (status == UCC_OK ) {
361
365
return UCC_OK ;
362
366
}
@@ -421,7 +425,7 @@ static inline ucc_status_t ucc_tl_ucp_put_nb(void *buffer, void *target,
421
425
size_t msglen ,
422
426
ucc_rank_t dest_group_rank ,
423
427
ucc_mem_map_mem_h src_memh ,
424
- ucc_mem_map_mem_h dest_memh ,
428
+ ucc_mem_map_mem_h * dest_memh ,
425
429
ucc_tl_ucp_team_t * team ,
426
430
ucc_tl_ucp_task_t * task )
427
431
{
@@ -447,7 +451,7 @@ static inline ucc_status_t ucc_tl_ucp_put_nb(void *buffer, void *target,
447
451
}
448
452
449
453
status = ucc_tl_ucp_resolve_p2p_by_va (team , target , & ep , dest_group_rank ,
450
- & rva , & rkey , & segment , src_memh , dest_memh );
454
+ & rva , & rkey , & segment , dest_memh );
451
455
if (ucc_unlikely (UCC_OK != status )) {
452
456
return status ;
453
457
}
@@ -478,7 +482,7 @@ static inline ucc_status_t ucc_tl_ucp_get_nb(void *buffer, void *target,
478
482
size_t msglen ,
479
483
ucc_rank_t dest_group_rank ,
480
484
ucc_mem_map_mem_h src_memh ,
481
- ucc_mem_map_mem_h dest_memh ,
485
+ ucc_mem_map_mem_h * dest_memh ,
482
486
ucc_tl_ucp_team_t * team ,
483
487
ucc_tl_ucp_task_t * task )
484
488
{
@@ -504,7 +508,7 @@ static inline ucc_status_t ucc_tl_ucp_get_nb(void *buffer, void *target,
504
508
}
505
509
506
510
status = ucc_tl_ucp_resolve_p2p_by_va (team , target , & ep , dest_group_rank ,
507
- & rva , & rkey , & segment , src_memh , dest_memh );
511
+ & rva , & rkey , & segment , dest_memh );
508
512
if (ucc_unlikely (UCC_OK != status )) {
509
513
return status ;
510
514
}
@@ -534,8 +538,7 @@ static inline ucc_status_t ucc_tl_ucp_get_nb(void *buffer, void *target,
534
538
535
539
static inline ucc_status_t ucc_tl_ucp_atomic_inc (void * target ,
536
540
ucc_rank_t dest_group_rank ,
537
- ucc_mem_map_mem_h src_memh ,
538
- ucc_mem_map_mem_h dest_memh ,
541
+ ucc_mem_map_mem_h * dest_memh ,
539
542
ucc_tl_ucp_team_t * team )
540
543
{
541
544
ucp_request_param_t req_param = {0 };
@@ -553,7 +556,7 @@ static inline ucc_status_t ucc_tl_ucp_atomic_inc(void * target,
553
556
}
554
557
555
558
status = ucc_tl_ucp_resolve_p2p_by_va (team , target , & ep , dest_group_rank ,
556
- & rva , & rkey , & segment , src_memh , dest_memh );
559
+ & rva , & rkey , & segment , dest_memh );
557
560
if (ucc_unlikely (UCC_OK != status )) {
558
561
return status ;
559
562
}
0 commit comments