Skip to content

Commit 65bc710

Browse files
committed
TL/UCP: Add all-reduce ring algorithm
Signed-off-by: Armen Ratner <[email protected]>
1 parent 874d705 commit 65bc710

File tree

4 files changed

+51
-35
lines changed

4 files changed

+51
-35
lines changed

src/components/tl/ucp/allreduce/allreduce.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ ucc_base_coll_alg_info_t
3232
[UCC_TL_UCP_ALLREDUCE_ALG_RING] =
3333
{.id = UCC_TL_UCP_ALLREDUCE_ALG_RING,
3434
.name = "ring",
35-
.desc = "ring-based allreduce (optimized for large messages and simple topologies)"},
35+
.desc = "ring-based allreduce (optimized for BW and simple topologies)"},
3636
[UCC_TL_UCP_ALLREDUCE_ALG_LAST] = {
3737
.id = 0, .name = NULL, .desc = NULL}};
3838

src/components/tl/ucp/allreduce/allreduce.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,6 @@ void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task);
8282

8383
ucc_status_t ucc_tl_ucp_allreduce_ring_start(ucc_coll_task_t *coll_task);
8484

85-
ucc_status_t ucc_tl_ucp_allreduce_ring_init_common(ucc_tl_ucp_task_t *task);
86-
8785
ucc_status_t ucc_tl_ucp_allreduce_ring_init(ucc_base_coll_args_t *coll_args,
8886
ucc_base_team_t *team,
8987
ucc_coll_task_t **task_h);

src/components/tl/ucp/allreduce/allreduce_ring.c

Lines changed: 47 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task)
2020
size_t count = TASK_ARGS(task).dst.info.count;
2121
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
2222
size_t data_size = count * ucc_dt_size(dt);
23+
int num_chunks = tsize; // Number of chunks equals number of ranks
2324
size_t chunk_size, offset, remaining;
2425
ucc_rank_t sendto, recvfrom;
2526
void *recv_buf, *send_buf, *reduce_buf;
2627
ucc_status_t status;
28+
int step, chunk;
2729

28-
int num_chunks = tsize; // Use the number of ranks as the number of chunks (this is dynamic)
29-
chunk_size = (data_size + num_chunks - 1) / num_chunks; // Ensure chunks fit into data evenly
30+
// Divide data into chunks, rounding up to ensure we cover all data
31+
chunk_size = ucc_div_round_up(data_size, num_chunks);
3032

3133
if (UCC_IS_INPLACE(TASK_ARGS(task))) {
3234
sbuf = rbuf;
@@ -39,16 +41,22 @@ void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task)
3941
sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize);
4042
recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize);
4143

42-
while (task->tagged.send_posted < tsize - 1) {
43-
int step = task->tagged.send_posted;
44-
45-
for (int chunk = 0; chunk < num_chunks; chunk++) {
44+
/*
45+
* In the ring algorithm, each process sends/receives tsize-1 times
46+
* This is because after tsize-1 steps, each piece of data has traversed
47+
* the entire ring and completed its reduction
48+
*/
49+
while (task->allreduce_ring.step < tsize - 1) {
50+
step = task->allreduce_ring.step;
51+
52+
/* Resume from the last processed chunk */
53+
for (chunk = task->allreduce_ring.chunk; chunk < num_chunks; chunk++) {
4654
offset = chunk * chunk_size;
4755
remaining = (chunk == num_chunks - 1) ? data_size - offset : chunk_size;
4856

49-
send_buf = (step == 0) ? sbuf + offset : rbuf + offset;
50-
recv_buf = task->allreduce_ring.scratch + offset;
51-
reduce_buf = rbuf + offset;
57+
send_buf = (step == 0) ? PTR_OFFSET(sbuf, offset) : PTR_OFFSET(rbuf, offset);
58+
recv_buf = PTR_OFFSET(task->allreduce_ring.scratch, offset);
59+
reduce_buf = PTR_OFFSET(rbuf, offset);
5260

5361
UCPCHECK_GOTO(
5462
ucc_tl_ucp_send_nb(send_buf, remaining, mem_type, sendto, team, task),
@@ -57,7 +65,11 @@ void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task)
5765
ucc_tl_ucp_recv_nb(recv_buf, remaining, mem_type, recvfrom, team, task),
5866
task, out);
5967

68+
/* Save current chunk position before testing progress */
69+
task->allreduce_ring.chunk = chunk;
70+
6071
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
72+
/* Return and resume from this chunk next time */
6173
return;
6274
}
6375

@@ -73,7 +85,9 @@ void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task)
7385
}
7486
}
7587

76-
task->tagged.send_posted++;
88+
task->allreduce_ring.step++;
89+
/* Reset chunk counter for the next step */
90+
task->allreduce_ring.chunk = 0;
7791
}
7892

7993
ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
@@ -84,48 +98,50 @@ void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task)
8498

8599
ucc_status_t ucc_tl_ucp_allreduce_ring_start(ucc_coll_task_t *coll_task)
86100
{
87-
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
88-
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
89-
size_t count = TASK_ARGS(task).dst.info.count;
90-
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
91-
size_t data_size = count * ucc_dt_size(dt);
92-
ucc_status_t status;
101+
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
102+
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
93103

94104
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_ring_start", 0);
95105
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
96106

97-
/* Allocate scratch space for the receive buffer */
98-
status = ucc_mc_alloc(&task->allreduce_ring.scratch_mc_header,
99-
data_size, TASK_ARGS(task).dst.info.mem_type);
100-
task->allreduce_ring.scratch = task->allreduce_ring.scratch_mc_header->addr;
101-
if (ucc_unlikely(status != UCC_OK)) {
102-
tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer");
103-
return status;
104-
}
105-
106107
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
107108
}
108109

109110
ucc_status_t ucc_tl_ucp_allreduce_ring_init_common(ucc_tl_ucp_task_t *task)
110111
{
111112
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
112113
ucc_sbgp_t *sbgp;
114+
size_t count = TASK_ARGS(task).dst.info.count;
115+
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
116+
size_t data_size = count * ucc_dt_size(dt);
117+
ucc_status_t status;
113118

114119
if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) {
115120
tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported");
116121
return UCC_ERR_NOT_SUPPORTED;
117122
}
118123

119-
if (!(task->flags & UCC_TL_UCP_TASK_FLAG_SUBSET)) {
120-
if (team->cfg.use_reordering) {
121-
sbgp = ucc_topo_get_sbgp(team->topo, UCC_SBGP_FULL_HOST_ORDERED);
122-
task->subset.myrank = sbgp->group_rank;
123-
task->subset.map = sbgp->map;
124-
}
124+
if (!(task->flags & UCC_TL_UCP_TASK_FLAG_SUBSET) && team->cfg.use_reordering) {
125+
sbgp = ucc_topo_get_sbgp(team->topo, UCC_SBGP_FULL_HOST_ORDERED);
126+
task->subset.myrank = sbgp->group_rank;
127+
task->subset.map = sbgp->map;
128+
}
129+
130+
/* Allocate scratch space for the receive buffer */
131+
status = ucc_mc_alloc(&task->allreduce_ring.scratch_mc_header,
132+
data_size, TASK_ARGS(task).dst.info.mem_type);
133+
if (ucc_unlikely(status != UCC_OK)) {
134+
tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer");
135+
return status;
125136
}
137+
task->allreduce_ring.scratch = task->allreduce_ring.scratch_mc_header->addr;
126138

139+
task->allreduce_ring.step = 0; /* Initialize step counter */
140+
task->allreduce_ring.chunk = 0; /* Initialize chunk counter */
141+
127142
task->super.post = ucc_tl_ucp_allreduce_ring_start;
128143
task->super.progress = ucc_tl_ucp_allreduce_ring_progress;
144+
task->super.finalize = ucc_tl_ucp_allreduce_ring_finalize;
129145

130146
return UCC_OK;
131147
}

src/components/tl/ucp/tl_ucp_coll.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,13 +268,15 @@ typedef struct ucc_tl_ucp_task {
268268
ucc_rank_t iteration;
269269
int phase;
270270
} alltoall_bruck;
271-
char plugin_data[UCC_TL_UCP_TASK_PLUGIN_MAX_DATA];
272271
struct {
273272
void *scratch;
274273
ucc_mc_buffer_header_t *scratch_mc_header;
275274
ucc_ee_executor_task_t *etask;
276275
ucc_ee_executor_t *executor;
276+
int step; /* Track algorithm steps separately */
277+
int chunk; /* Track current chunk being processed */
277278
} allreduce_ring;
279+
char plugin_data[UCC_TL_UCP_TASK_PLUGIN_MAX_DATA];
278280
};
279281
} ucc_tl_ucp_task_t;
280282

0 commit comments

Comments
 (0)