Skip to content

Commit 213ef2b

Browse files
author
ferrol aderholdt
committed
REVIEW: address feedback
1 parent 92ce5a9 commit 213ef2b

File tree

5 files changed

+88
-117
lines changed

5 files changed

+88
-117
lines changed

src/components/tl/ucp/alltoall/alltoall_onesided.c

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,35 +24,21 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
2424
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
2525
ucc_rank_t start = (grank + 1) % gsize;
2626
long *pSync = TASK_ARGS(task).global_work_buffer;
27-
ucc_mem_map_mem_h *src_memh = TASK_ARGS(task).src_memh.global_memh;
27+
ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh;
2828
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh;
2929
ucc_rank_t peer;
3030

3131
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
3232
/* TODO: change when support for library-based work buffers is complete */
3333
nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
3434
dest = dest + grank * nelems;
35-
36-
if (ucc_likely(!(src_memh && dst_memh))) {
37-
for (peer = start; task->onesided.put_posted < gsize; peer = (peer + 1) % gsize) {
38-
UCPCHECK_GOTO(ucc_tl_ucp_put_nb(
39-
(void *)(src + peer * nelems), (void *)dest, nelems,
40-
peer, NULL, NULL, team, task),
41-
task, out);
42-
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, NULL,
43-
NULL, team),
44-
task, out);
45-
}
46-
} else {
47-
for (peer = start; task->onesided.put_posted < gsize; peer = (peer + 1) % gsize) {
48-
UCPCHECK_GOTO(ucc_tl_ucp_put_nb(
49-
(void *)(src + peer * nelems), (void *)dest, nelems,
50-
peer, *src_memh, dst_memh[peer], team, task),
51-
task, out);
52-
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, *src_memh,
53-
dst_memh[peer], team),
54-
task, out);
55-
}
35+
for (peer = start; task->onesided.put_posted < gsize; peer = (peer + 1) % gsize) {
36+
UCPCHECK_GOTO(ucc_tl_ucp_put_nb(
37+
(void *)(src + peer * nelems), (void *)dest, nelems,
38+
peer, src_memh, dst_memh, team, task),
39+
task, out);
40+
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, dst_memh, team),
41+
task, out);
5642
}
5743
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
5844
out:

src/components/tl/ucp/alltoallv/alltoallv_onesided.c

Lines changed: 21 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask)
2424
ucc_aint_t *d_disp = TASK_ARGS(task).dst.info_v.displacements;
2525
size_t sdt_size = ucc_dt_size(TASK_ARGS(task).src.info_v.datatype);
2626
size_t rdt_size = ucc_dt_size(TASK_ARGS(task).dst.info_v.datatype);
27-
ucc_mem_map_mem_h *src_memh = TASK_ARGS(task).src_memh.global_memh;
27+
ucc_mem_map_mem_h src_memh = TASK_ARGS(task).src_memh.local_memh;
2828
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh;
2929
ucc_rank_t peer;
3030
size_t sd_disp, dd_disp, data_size;
@@ -33,51 +33,27 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask)
3333

3434
/* perform a put to each member peer using the peer's index in the
3535
* destination displacement. */
36-
if (ucc_likely(!(src_memh && dst_memh))) {
37-
for (peer = (grank + 1) % gsize; task->onesided.put_posted < gsize;
38-
peer = (peer + 1) % gsize) {
39-
sd_disp =
40-
ucc_coll_args_get_displacement(&TASK_ARGS(task), s_disp, peer) *
41-
sdt_size;
42-
dd_disp =
43-
ucc_coll_args_get_displacement(&TASK_ARGS(task), d_disp, peer) *
44-
rdt_size;
45-
data_size =
46-
ucc_coll_args_get_count(
47-
&TASK_ARGS(task), TASK_ARGS(task).src.info_v.counts, peer) *
48-
sdt_size;
36+
for (peer = (grank + 1) % gsize; task->onesided.put_posted < gsize;
37+
peer = (peer + 1) % gsize) {
38+
sd_disp =
39+
ucc_coll_args_get_displacement(&TASK_ARGS(task), s_disp, peer) *
40+
sdt_size;
41+
dd_disp =
42+
ucc_coll_args_get_displacement(&TASK_ARGS(task), d_disp, peer) *
43+
rdt_size;
44+
data_size =
45+
ucc_coll_args_get_count(
46+
&TASK_ARGS(task), TASK_ARGS(task).src.info_v.counts, peer) *
47+
sdt_size;
4948

50-
UCPCHECK_GOTO(ucc_tl_ucp_put_nb(PTR_OFFSET(src, sd_disp),
51-
PTR_OFFSET(dest, dd_disp),
52-
data_size, peer, NULL, NULL, team,
53-
task),
54-
task, out);
55-
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, NULL, NULL, team),
56-
task, out);
57-
}
58-
} else {
59-
for (peer = (grank + 1) % gsize; task->onesided.put_posted < gsize;
60-
peer = (peer + 1) % gsize) {
61-
sd_disp =
62-
ucc_coll_args_get_displacement(&TASK_ARGS(task), s_disp, peer) *
63-
sdt_size;
64-
dd_disp =
65-
ucc_coll_args_get_displacement(&TASK_ARGS(task), d_disp, peer) *
66-
rdt_size;
67-
data_size =
68-
ucc_coll_args_get_count(
69-
&TASK_ARGS(task), TASK_ARGS(task).src.info_v.counts, peer) *
70-
sdt_size;
71-
72-
UCPCHECK_GOTO(ucc_tl_ucp_put_nb(PTR_OFFSET(src, sd_disp),
73-
PTR_OFFSET(dest, dd_disp),
74-
data_size, peer, *src_memh,
75-
dst_memh[peer], team, task),
76-
task, out);
77-
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, *src_memh,
78-
dst_memh[peer], team),
79-
task, out);
80-
}
49+
UCPCHECK_GOTO(ucc_tl_ucp_put_nb(PTR_OFFSET(src, sd_disp),
50+
PTR_OFFSET(dest, dd_disp),
51+
data_size, peer, src_memh,
52+
dst_memh, team, task),
53+
task, out);
54+
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer,
55+
dst_memh, team),
56+
task, out);
8157
}
8258
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
8359
out:

src/components/tl/ucp/tl_ucp_context.c

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -651,20 +651,28 @@ ucc_status_t ucc_tl_ucp_mem_map_offload_import(ucc_tl_ucp_context_t *ctx,
651651
ucc_mem_map_memh_t *l_memh,
652652
void *tl_h)
653653
{
654-
size_t offset = 0;
655-
ucp_mem_h mh;
656-
ucc_status_t ucc_status;
657-
658-
for (int i = 0; i < l_memh->num_tls; i++) {
659-
size_t *p = (size_t *)PTR_OFFSET(l_memh->pack_buffer, offset);
660-
661-
if (tl_h == (void *)&l_memh->tl_h[i]) {
654+
int i = 0;
655+
ucc_mem_map_tl_t *tl = (ucc_mem_map_tl_t *)tl_h;
656+
size_t offset = 0;
657+
ucp_mem_h mh;
658+
ucc_status_t ucc_status;
659+
char *name;
660+
size_t *packed_size;
661+
662+
for (; i < l_memh->num_tls; i++) {
663+
name = (char *)PTR_OFFSET(l_memh->pack_buffer, offset);
664+
packed_size = (size_t *)PTR_OFFSET(l_memh->pack_buffer,
665+
offset + UCC_MEM_MAP_TL_NAME_LEN);
666+
667+
if (strncmp(tl->tl_name, name, UCC_MEM_MAP_TL_NAME_LEN - 1) == 0) {
662668
break;
663669
}
664670
/* this is not the index, skip this section of buffer if exists */
665-
if (p[0] == i) {
666-
offset += p[1];
667-
}
671+
offset += *packed_size;
672+
}
673+
if (i == l_memh->num_tls) {
674+
tl_error(ctx->super.super.lib, "unable to find TL UCP in memory handle");
675+
return UCC_ERR_NOT_FOUND;
668676
}
669677

670678
ucc_status = ucc_tl_ucp_mem_map_memhbuf(

src/components/tl/ucp/tl_ucp_sendrecv.h

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -260,42 +260,45 @@ static inline ucc_status_t ucc_tl_ucp_get_memh(ucc_tl_ucp_team_t *team,
260260
return UCC_OK;
261261
}
262262

263-
static inline ucc_status_t ucc_tl_ucp_check_memh(ucp_ep_h *ep, void *va,
264-
uint64_t *rva,
265-
ucp_rkey_h *rkey, int tl_index,
266-
ucc_mem_map_mem_h src_map_memh,
267-
ucc_mem_map_mem_h dst_map_memh)
263+
static inline ucc_status_t ucc_tl_ucp_check_memh(ucp_ep_h *ep, ucc_rank_t me,
264+
ucc_rank_t peer, void *va,
265+
uint64_t *rva, ucp_rkey_h *rkey,
266+
int tl_index,
267+
ucc_mem_map_mem_h *dst_map_memh)
268268
{
269-
// check if src_memh or dest_memh have segment
270-
ucc_mem_map_memh_t *src_memh = src_map_memh;
271-
ucc_mem_map_memh_t *dst_memh = dst_map_memh;
272-
ucc_tl_ucp_memh_data_t *dst_tl_data =
273-
(ucc_tl_ucp_memh_data_t *)dst_memh->tl_h[tl_index].tl_data;
274-
uint64_t base, end;
275-
ucs_status_t ucs_status;
276-
int i;
277-
size_t offset;
278-
279-
base = (uint64_t)src_memh->address;
280-
end = base + src_memh->len;
269+
ucc_mem_map_memh_t **dst_memh = (ucc_mem_map_memh_t **)dst_map_memh;
270+
ucc_tl_ucp_memh_data_t *dst_tl_data = NULL;
271+
uint64_t base;
272+
uint64_t end;
273+
ucs_status_t ucs_status;
274+
int i;
275+
size_t offset;
276+
277+
base = (uint64_t)dst_memh[me]->address;
278+
end = base + dst_memh[me]->len;
281279

282280
if (!((uint64_t)va >= base && (uint64_t)va < end)) {
283281
return UCC_ERR_NOT_FOUND;
284282
}
285-
*rva = (uint64_t)PTR_OFFSET(dst_memh->address,
286-
((uint64_t)va - (uint64_t)src_memh->address));
283+
*rva = (uint64_t)PTR_OFFSET(dst_memh[peer]->address,
284+
((uint64_t)va - (uint64_t)dst_memh[me]->address));
285+
dst_tl_data = (ucc_tl_ucp_memh_data_t *)dst_memh[peer]->tl_h[tl_index].tl_data;
287286
if (NULL == dst_tl_data->rkey) {
288287
offset = 0;
289288
/* find pack location for tl */
290289
for (i = 0; i < tl_index; i++) {
291-
size_t *p = PTR_OFFSET(dst_memh->pack_buffer, offset);
292-
if (p[0] == tl_index) {
290+
char *name = PTR_OFFSET(dst_memh[peer]->pack_buffer, offset);
291+
size_t *packed_size = PTR_OFFSET(dst_memh[peer]->pack_buffer, offset
292+
+ UCC_MEM_MAP_TL_NAME_LEN);
293+
if (strncmp(name, "ucp", 3) == 0) {
293294
break;
294295
}
295-
offset += p[1];
296+
offset += *packed_size;
296297
}
297298
ucs_status = ucp_ep_rkey_unpack(
298-
*ep, PTR_OFFSET(dst_memh->pack_buffer, offset + sizeof(size_t) * 4),
299+
*ep, PTR_OFFSET(dst_memh[peer]->pack_buffer, offset +
300+
sizeof(size_t) *
301+
(UCC_TL_UCP_MEMH_TL_HEADERS + UCC_TL_UCP_MEMH_TL_PACKED_HEADERS)),
299302
&dst_tl_data->rkey);
300303
if (UCS_OK != ucs_status) {
301304
return ucs_status_to_ucc_status(ucs_status);
@@ -309,9 +312,10 @@ static inline ucc_status_t ucc_tl_ucp_check_memh(ucp_ep_h *ep, void *va,
309312
static inline ucc_status_t
310313
ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, ucp_ep_h *ep,
311314
ucc_rank_t peer, uint64_t *rva, ucp_rkey_h *rkey,
312-
int *segment, ucc_mem_map_mem_h src_memh, ucc_mem_map_mem_h dst_memh)
315+
int *segment, ucc_mem_map_mem_h *dst_memh)
313316
{
314317
ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team);
318+
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
315319
ptrdiff_t key_offset = 0;
316320
const size_t section_offset = sizeof(uint64_t) * ctx->n_rinfo_segs;
317321
int tl_index = 0;
@@ -347,16 +351,16 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, ucp_ep_h *ep,
347351
key_offset += key_sizes[i];
348352
}
349353
if (ucc_unlikely(0 > *segment)) {
350-
if (src_memh && dst_memh) {
354+
if (dst_memh) {
351355
/* check if segment is in src/dst memh */
352-
status = find_tl_index(src_memh, &tl_index);
356+
status = find_tl_index(dst_memh[grank], &tl_index);
353357
if (status == UCC_ERR_NOT_FOUND) {
354358
tl_error(UCC_TL_TEAM_LIB(team),
355359
"attempt to perform one-sided operation with malformed mem map handle");
356360
return status;
357361
}
358362

359-
status = ucc_tl_ucp_check_memh(ep, va, rva, rkey, tl_index, src_memh, dst_memh);
363+
status = ucc_tl_ucp_check_memh(ep, grank, peer, va, rva, rkey, tl_index, dst_memh);
360364
if (status == UCC_OK) {
361365
return UCC_OK;
362366
}
@@ -421,7 +425,7 @@ static inline ucc_status_t ucc_tl_ucp_put_nb(void *buffer, void *target,
421425
size_t msglen,
422426
ucc_rank_t dest_group_rank,
423427
ucc_mem_map_mem_h src_memh,
424-
ucc_mem_map_mem_h dest_memh,
428+
ucc_mem_map_mem_h *dest_memh,
425429
ucc_tl_ucp_team_t *team,
426430
ucc_tl_ucp_task_t *task)
427431
{
@@ -447,7 +451,7 @@ static inline ucc_status_t ucc_tl_ucp_put_nb(void *buffer, void *target,
447451
}
448452

449453
status = ucc_tl_ucp_resolve_p2p_by_va(team, target, &ep, dest_group_rank,
450-
&rva, &rkey, &segment, src_memh, dest_memh);
454+
&rva, &rkey, &segment, dest_memh);
451455
if (ucc_unlikely(UCC_OK != status)) {
452456
return status;
453457
}
@@ -478,7 +482,7 @@ static inline ucc_status_t ucc_tl_ucp_get_nb(void *buffer, void *target,
478482
size_t msglen,
479483
ucc_rank_t dest_group_rank,
480484
ucc_mem_map_mem_h src_memh,
481-
ucc_mem_map_mem_h dest_memh,
485+
ucc_mem_map_mem_h *dest_memh,
482486
ucc_tl_ucp_team_t *team,
483487
ucc_tl_ucp_task_t *task)
484488
{
@@ -504,7 +508,7 @@ static inline ucc_status_t ucc_tl_ucp_get_nb(void *buffer, void *target,
504508
}
505509

506510
status = ucc_tl_ucp_resolve_p2p_by_va(team, target, &ep, dest_group_rank,
507-
&rva, &rkey, &segment, src_memh, dest_memh);
511+
&rva, &rkey, &segment, dest_memh);
508512
if (ucc_unlikely(UCC_OK != status)) {
509513
return status;
510514
}
@@ -534,8 +538,7 @@ static inline ucc_status_t ucc_tl_ucp_get_nb(void *buffer, void *target,
534538

535539
static inline ucc_status_t ucc_tl_ucp_atomic_inc(void * target,
536540
ucc_rank_t dest_group_rank,
537-
ucc_mem_map_mem_h src_memh,
538-
ucc_mem_map_mem_h dest_memh,
541+
ucc_mem_map_mem_h *dest_memh,
539542
ucc_tl_ucp_team_t *team)
540543
{
541544
ucp_request_param_t req_param = {0};
@@ -553,7 +556,7 @@ static inline ucc_status_t ucc_tl_ucp_atomic_inc(void * target,
553556
}
554557

555558
status = ucc_tl_ucp_resolve_p2p_by_va(team, target, &ep, dest_group_rank,
556-
&rva, &rkey, &segment, src_memh, dest_memh);
559+
&rva, &rkey, &segment, dest_memh);
557560
if (ucc_unlikely(UCC_OK != status)) {
558561
return status;
559562
}

src/core/ucc_context.c

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,6 +1163,7 @@ ucc_status_t ucc_mem_map_import(ucc_context_h context,
11631163
ctx->n_tl_ctx, sizeof(ucc_mem_map_tl_t), "tl memh");
11641164
for (i = 0; i < ctx->n_tl_ctx; i++) {
11651165
tl_lib = ucc_derived_of(ctx->tl_ctx[i]->super.lib, ucc_tl_lib_t);
1166+
strncpy(local_memh->tl_h[i].tl_name, tls->names[i], UCC_MEM_MAP_TL_NAME_LEN - 1);
11661167
status = tl_lib->iface->context.mem_map(
11671168
(const ucc_base_context_t *)ctx->tl_ctx[i], type,
11681169
params->segments[0].address, params->segments[0].len, local_memh,
@@ -1171,7 +1172,6 @@ ucc_status_t ucc_mem_map_import(ucc_context_h context,
11711172
ucc_error("failed to import mem map memh %d", status);
11721173
return status;
11731174
}
1174-
strncpy(local_memh->tl_h[i].tl_name, tls->names[i], 7);
11751175
}
11761176
local_memh->type = type;
11771177
/* fix context as it will be incorrect on a different system */
@@ -1272,13 +1272,11 @@ ucc_status_t ucc_mem_map_export(ucc_context_h context,
12721272
/* copying */
12731273
exported_memh->tl_h = local_memh->tl_h;
12741274
for (i = 0, offset = 0; i < ctx->n_tl_ctx; i++) {
1275-
uint64_t tl_index = i;
12761275
if (local_memh->tl_h[i].packed_size == 0) {
12771276
continue;
12781277
}
1279-
memcpy(PTR_OFFSET(exported_memh->pack_buffer, offset), &tl_index,
1280-
sizeof(size_t));
1281-
offset += sizeof(size_t);
1278+
strncpy(PTR_OFFSET(exported_memh->pack_buffer, offset), tls->names[i], UCC_MEM_MAP_TL_NAME_LEN - 1);
1279+
offset += UCC_MEM_MAP_TL_NAME_LEN;
12821280
memcpy(PTR_OFFSET(exported_memh->pack_buffer, offset),
12831281
&local_memh->tl_h[i].packed_size, sizeof(size_t));
12841282
offset += sizeof(size_t);

0 commit comments

Comments
 (0)