@@ -20,13 +20,15 @@ void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task)
     size_t             count      = TASK_ARGS(task).dst.info.count;
     ucc_datatype_t     dt         = TASK_ARGS(task).dst.info.datatype;
     size_t             data_size  = count * ucc_dt_size(dt);
+    int                num_chunks = tsize; // Number of chunks equals number of ranks
     size_t             chunk_size, offset, remaining;
     ucc_rank_t         sendto, recvfrom;
     void              *recv_buf, *send_buf, *reduce_buf;
     ucc_status_t       status;
+    int                step, chunk;
 
-    int num_chunks = tsize; // Use the number of ranks as the number of chunks (this is dynamic)
-    chunk_size = (data_size + num_chunks - 1) / num_chunks; // Ensure chunks fit into data evenly
+    // Divide data into chunks, rounding up to ensure we cover all data
+    chunk_size = ucc_div_round_up(data_size, num_chunks);
 
     if (UCC_IS_INPLACE(TASK_ARGS(task))) {
         sbuf = rbuf;
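
The ceil-division above guarantees num_chunks * chunk_size >= data_size, with the final chunk absorbing the remainder. Worth noting: the (chunk == num_chunks - 1) test in the loop below assumes only the last chunk can be short, which holds only while (num_chunks - 1) * chunk_size <= data_size. A standalone sketch of the arithmetic (plain C, no UCC dependencies; div_round_up is a local stand-in for ucc_div_round_up), using a guarded form of remaining that also covers very small messages:

/*
 * Sketch of the chunking arithmetic: ceil-division over-covers the
 * buffer, and the guarded `remaining` clamps the tail so every byte
 * is counted exactly once, even when (num_chunks - 1) * chunk_size
 * exceeds data_size (e.g. 10 bytes split into 7 chunks).
 */
#include <assert.h>
#include <stddef.h>

static size_t div_round_up(size_t a, size_t b)
{
    return (a + b - 1) / b;
}

int main(void)
{
    size_t sizes[]    = {1000, 10, 7};  /* arbitrary message sizes */
    int    num_chunks = 7;              /* stands in for tsize     */

    for (int i = 0; i < 3; i++) {
        size_t data_size  = sizes[i];
        size_t chunk_size = div_round_up(data_size, num_chunks);
        size_t covered    = 0;

        for (int chunk = 0; chunk < num_chunks; chunk++) {
            size_t offset    = chunk * chunk_size;
            size_t remaining = (offset >= data_size) ? 0
                             : (offset + chunk_size > data_size)
                                   ? data_size - offset
                                   : chunk_size;
            covered += remaining;
        }
        assert(covered == data_size); /* each byte sent once per step */
    }
    return 0;
}
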
@@ -39,16 +41,22 @@ void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task)
     sendto   = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize);
     recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize);
 
-    while (task->tagged.send_posted < tsize - 1) {
-        int step = task->tagged.send_posted;
-
-        for (int chunk = 0; chunk < num_chunks; chunk++) {
+    /*
+     * In the ring algorithm, each process sends/receives tsize - 1 times:
+     * after tsize - 1 steps, each piece of data has traversed the entire
+     * ring and its reduction is complete.
+     */
+    while (task->allreduce_ring.step < tsize - 1) {
+        step = task->allreduce_ring.step;
+
+        /* Resume from the last processed chunk */
+        for (chunk = task->allreduce_ring.chunk; chunk < num_chunks; chunk++) {
             offset    = chunk * chunk_size;
             remaining = (chunk == num_chunks - 1) ? data_size - offset : chunk_size;
 
-            send_buf   = (step == 0) ? sbuf + offset : rbuf + offset;
-            recv_buf   = task->allreduce_ring.scratch + offset;
-            reduce_buf = rbuf + offset;
+            send_buf   = (step == 0) ? PTR_OFFSET(sbuf, offset) : PTR_OFFSET(rbuf, offset);
+            recv_buf   = PTR_OFFSET(task->allreduce_ring.scratch, offset);
+            reduce_buf = PTR_OFFSET(rbuf, offset);
 
             UCPCHECK_GOTO(
                 ucc_tl_ucp_send_nb(send_buf, remaining, mem_type, sendto, team, task),
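
The neighbor arithmetic pins the ring direction: each rank sends to (trank + 1) % tsize and receives from (trank - 1 + tsize) % tsize, with ucc_ep_map_eval translating subset ranks to team ranks. A toy check of the claim in the comment above, that after tsize - 1 steps every block has been held (and reduced) by all ranks (plain C, n stands in for tsize):

/*
 * With one hop per step, a block starting on any rank has been held
 * by all n ranks after n - 1 forwarding steps, which is why the
 * while loop runs tsize - 1 times.
 */
#include <assert.h>

int main(void)
{
    int n = 5;
    for (int start = 0; start < n; start++) {
        unsigned seen   = 1u << start;  /* block begins on its home rank */
        int      holder = start;
        for (int step = 0; step < n - 1; step++) {
            holder = (holder + 1) % n;  /* one hop around the ring */
            seen  |= 1u << holder;
        }
        assert(seen == (1u << n) - 1);  /* all n ranks saw this block */
    }
    return 0;
}
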
@@ -57,7 +65,11 @@ void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task)
                 ucc_tl_ucp_recv_nb(recv_buf, remaining, mem_type, recvfrom, team, task),
                 task, out);
 
+            /* Save current chunk position before testing progress */
+            task->allreduce_ring.chunk = chunk;
+
             if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
+                /* Return and resume from this chunk next time */
                 return;
             }
 
@@ -73,7 +85,9 @@ void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task)
             }
         }
 
-        task->tagged.send_posted++;
+        task->allreduce_ring.step++;
+        /* Reset chunk counter for the next step */
+        task->allreduce_ring.chunk = 0;
     }
 
     ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
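
The saved step/chunk pair turns the progress function into a resumable state machine: it can return on UCC_INPROGRESS mid-loop and re-enter at exactly the same chunk on the next progress call, with the chunk counter reset to zero once a step completes. A generic sketch of this pattern (illustrative names only: my_state_t, post_chunk, and poll_transfers are stand-ins, not UCC API):

/* Generic shape of the re-entrant progress pattern used above. */
typedef enum { IN_PROGRESS, DONE } poll_t;

typedef struct {
    int step;   /* outer-loop position, survives across calls */
    int chunk;  /* inner-loop position, survives across calls */
} my_state_t;

static poll_t poll_transfers(void)          /* stand-in for ucc_tl_ucp_test */
{
    return DONE;                            /* pretend transfers completed  */
}

static void post_chunk(int step, int chunk) /* stand-in for send_nb/recv_nb */
{
    (void)step;
    (void)chunk;
}

static void progress(my_state_t *s, int num_steps, int num_chunks)
{
    while (s->step < num_steps) {
        /* resume the inner loop where the last call left off */
        for (; s->chunk < num_chunks; s->chunk++) {
            post_chunk(s->step, s->chunk);
            if (poll_transfers() == IN_PROGRESS) {
                return;                     /* re-entered later at this chunk */
            }
        }
        s->step++;
        s->chunk = 0;                       /* next step scans chunks from 0 */
    }
}

int main(void)
{
    my_state_t s = {0, 0};
    progress(&s, 4, 8);                     /* runs to completion in one call */
    return s.step == 4 ? 0 : 1;
}
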
@@ -84,48 +98,50 @@ void ucc_tl_ucp_allreduce_ring_progress(ucc_coll_task_t *coll_task)
 
 ucc_status_t ucc_tl_ucp_allreduce_ring_start(ucc_coll_task_t *coll_task)
 {
-    ucc_tl_ucp_task_t *task      = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
-    ucc_tl_ucp_team_t *team      = TASK_TEAM(task);
-    size_t             count     = TASK_ARGS(task).dst.info.count;
-    ucc_datatype_t     dt        = TASK_ARGS(task).dst.info.datatype;
-    size_t             data_size = count * ucc_dt_size(dt);
-    ucc_status_t       status;
+    ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
+    ucc_tl_ucp_team_t *team = TASK_TEAM(task);
 
     UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allreduce_ring_start", 0);
     ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
 
-    /* Allocate scratch space for the receive buffer */
-    status = ucc_mc_alloc(&task->allreduce_ring.scratch_mc_header,
-                          data_size, TASK_ARGS(task).dst.info.mem_type);
-    task->allreduce_ring.scratch = task->allreduce_ring.scratch_mc_header->addr;
-    if (ucc_unlikely(status != UCC_OK)) {
-        tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer");
-        return status;
-    }
-
     return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
 }
 
 ucc_status_t ucc_tl_ucp_allreduce_ring_init_common(ucc_tl_ucp_task_t *task)
 {
     ucc_tl_ucp_team_t *team = TASK_TEAM(task);
     ucc_sbgp_t        *sbgp;
+    size_t             count     = TASK_ARGS(task).dst.info.count;
+    ucc_datatype_t     dt        = TASK_ARGS(task).dst.info.datatype;
+    size_t             data_size = count * ucc_dt_size(dt);
+    ucc_status_t       status;
 
     if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) {
         tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported");
         return UCC_ERR_NOT_SUPPORTED;
     }
 
-    if (!(task->flags & UCC_TL_UCP_TASK_FLAG_SUBSET)) {
-        if (team->cfg.use_reordering) {
-            sbgp = ucc_topo_get_sbgp(team->topo, UCC_SBGP_FULL_HOST_ORDERED);
-            task->subset.myrank = sbgp->group_rank;
-            task->subset.map    = sbgp->map;
-        }
+    if (!(task->flags & UCC_TL_UCP_TASK_FLAG_SUBSET) && team->cfg.use_reordering) {
+        sbgp = ucc_topo_get_sbgp(team->topo, UCC_SBGP_FULL_HOST_ORDERED);
+        task->subset.myrank = sbgp->group_rank;
+        task->subset.map    = sbgp->map;
+    }
+
+    /* Allocate scratch space for the receive buffer */
+    status = ucc_mc_alloc(&task->allreduce_ring.scratch_mc_header,
+                          data_size, TASK_ARGS(task).dst.info.mem_type);
+    if (ucc_unlikely(status != UCC_OK)) {
+        tl_error(UCC_TASK_LIB(task), "failed to allocate scratch buffer");
+        return status;
     }
+    task->allreduce_ring.scratch = task->allreduce_ring.scratch_mc_header->addr;
 
+    task->allreduce_ring.step  = 0; /* Initialize step counter */
+    task->allreduce_ring.chunk = 0; /* Initialize chunk counter */
+
     task->super.post     = ucc_tl_ucp_allreduce_ring_start;
     task->super.progress = ucc_tl_ucp_allreduce_ring_progress;
+    task->super.finalize = ucc_tl_ucp_allreduce_ring_finalize;
 
     return UCC_OK;
 }
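
The diff reads and writes task->allreduce_ring fields whose declaration is not shown, and registers a finalize callback that is likewise not part of this hunk. Inferred from usage, the assumed per-task state and a plausible finalize look roughly like the following (a sketch following the usual UCC pattern of ucc_mc_free on the stored header, not the verbatim implementation; assumes the usual tl_ucp_coll.h context):

/*
 * Per-task state this patch assumes, inferred from usage in the diff
 * (not the verbatim UCC header):
 */
typedef struct {
    ucc_mc_buffer_header_t *scratch_mc_header; /* owns the scratch allocation */
    void                   *scratch;           /* staging buffer for receives */
    int                     step;              /* ring step, 0 .. tsize - 2   */
    int                     chunk;             /* chunk index to resume from  */
} allreduce_ring_state_t;

/*
 * The registered finalize hook is not shown in the diff; assuming the
 * usual UCC pattern, it would release the scratch buffer, e.g.:
 */
ucc_status_t ucc_tl_ucp_allreduce_ring_finalize(ucc_coll_task_t *coll_task)
{
    ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);

    ucc_mc_free(task->allreduce_ring.scratch_mc_header);
    return ucc_tl_ucp_coll_finalize(coll_task);
}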