Skip to content

Commit

Permalink
TL/UCP: Allow self copy in allgather using network loopback
Browse files Browse the repository at this point in the history
  • Loading branch information
yaeliyac committed Feb 12, 2025
1 parent 73651ea commit b9331ee
Show file tree
Hide file tree
Showing 22 changed files with 410 additions and 228 deletions.
12 changes: 10 additions & 2 deletions src/components/tl/mlx5/alltoall/alltoall_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ static ucc_status_t ucc_tl_mlx5_fanout_start(ucc_coll_task_t *coll_task)
tl_debug(UCC_TASK_LIB(task), "fanout start");
/* start task if completion event received */
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", 0);
if (team->a2a->node.sbgp->group_rank == team->a2a->node.asr_rank) {
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(
task, "mlx5_alltoall_wait-on-data_start", 0);
}
/* Start fanout */
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, coll_task);
return UCC_OK;
Expand All @@ -265,6 +269,8 @@ static void ucc_tl_mlx5_fanout_progress(ucc_coll_task_t *coll_task)
coll_task->status = UCC_INPROGRESS;
return;
}
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(
task, "mlx5_alltoall_wait-on-data_complete, fanout_start", 0);
}

if (UCC_OK == ucc_tl_mlx5_node_fanout(team, task)) {
Expand Down Expand Up @@ -342,12 +348,14 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task)
status = send_done(team, i);
}
if (status != UCC_OK) {
tl_error(UCC_TASK_LIB(task), "failed sending barrier notice");
tl_error(UCC_TASK_LIB(task), "failed sending barrier notice");
return status;
}
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(
task, "mlx5_alltoall_barrier_send_posted", 0);
}
coll_task->status = UCC_OK;
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barreir_done",
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_done",
0);
return ucc_task_complete(coll_task);
}
Expand Down
48 changes: 31 additions & 17 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,22 +199,25 @@ struct pp_packet {
uintptr_t buf; // buffer address, initialized once
};

// Per-multicast-group resources. struct mcast_ctx keeps an array of these
// (groups[MAX_GROUP_COUNT]) that is indexed by group id.
struct mcast_group {
struct ibv_qp *qp; // queue pair of this group; receive WRs are posted to it with ibv_post_recv()
struct ibv_ah *ah; // address handle of this group - presumably targets the group's multicast address; TODO confirm
uint16_t lid; // presumably the multicast LID assigned to this group; TODO confirm
union ibv_gid mgid; // presumably the multicast GID of this group; TODO confirm
struct sockaddr_in6 mcast_addr; // IPv6-style socket address of the group - presumably the one used for the join; TODO confirm
};

// Verbs resources used for multicast traffic; appears to be the type of the
// `mcast` member that callers reach as comm->mcast.
// NOTE(review): this listing seems to carry both the single-group members
// (qp, ah) and the multi-group containers (groups[], *_list) - confirm which
// of them the current header actually keeps.
struct mcast_ctx {
struct ibv_qp *qp; // queue pair; receive WRs are posted to it with ibv_post_recv()
struct ibv_ah *ah; // address handle paired with qp
struct ibv_send_wr swr; // send work request - presumably a reusable template; TODO confirm
struct ibv_sge ssg; // scatter/gather entry - presumably the one referenced by swr; TODO confirm

struct ibv_cq *scq; // completion queue - presumably for sends; TODO confirm
struct ibv_cq *rcq; // completion queue - presumably for receives; TODO confirm
struct ibv_srq *srq; // shared receive queue
struct mcast_group groups[MAX_GROUP_COUNT]; // per-group state, indexed by group id (e.g. groups[group_id].qp)
// RC connection info for supporting one-sided based reliability
struct ibv_qp **rc_qp; // array of RC queue pairs - presumably one per peer; TODO confirm
uint16_t *rc_lid; // array of LIDs for the RC connections
union ibv_gid *rc_gid; // array of GIDs for the RC connections

// multiple mcast group
struct ibv_qp **qp_list; // queue pairs indexed by group id (qp_list[group_id])
struct ibv_ah **ah_list; // address handles - presumably indexed by group id as well; TODO confirm
struct ibv_send_wr *swr_list; // send work requests - presumably one per group; TODO confirm
struct ibv_sge *ssg_list; // scatter/gather entries - presumably one per group; TODO confirm
};

struct packet {
Expand Down Expand Up @@ -303,15 +306,10 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
ucc_tl_mlx5_mcast_coll_comm_init_spec_t params;
ucc_tl_mlx5_mcast_p2p_interface_t p2p;
int tx;
struct ibv_cq *scq;
struct ibv_cq *rcq;
struct ibv_srq *srq;
ucc_rank_t rank;
ucc_rank_t commsize;
char *grh_buf;
struct ibv_mr *grh_mr;
uint16_t mcast_lid;
union ibv_gid mgid;
unsigned max_inline;
size_t max_eager;
int max_per_packet;
Expand All @@ -334,7 +332,6 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
int comm_id;
void *p2p_ctx;
ucc_base_lib_t *lib;
struct sockaddr_in6 mcast_addr;
int cuda_mem_enabled;
ucc_tl_mlx5_mcast_join_info_t *group_setup_info;
ucc_service_coll_req_t *group_setup_info_req;
Expand Down Expand Up @@ -441,6 +438,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_req {
ucc_service_coll_req_t *allgather_rkeys_req;
ucc_service_coll_req_t *barrier_req;
void *recv_rreg;
ucc_ee_executor_task_t *exec_task;
ucc_coll_task_t *coll_task;
} ucc_tl_mlx5_mcast_coll_req_t;

typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
Expand Down Expand Up @@ -490,7 +489,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast
}
if (i != 0) {
rwr[i-1].next = NULL;
if (ibv_post_recv(comm->mcast.qp, &rwr[0], &bad_wr)) {
if (ibv_post_recv(comm->mcast.groups[0].qp, &rwr[0], &bad_wr)) {
tl_error(comm->lib, "failed to prepost recvs: errno %d", errno);
return UCC_ERR_NO_RESOURCE;
}
Expand Down Expand Up @@ -543,7 +542,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_

if (i > 0) {
rwr[i-1].next = NULL;
if (ibv_post_recv(comm->mcast.qp_list[group_id], &rwr[0], &bad_wr)) {
if (ibv_post_recv(comm->mcast.groups[group_id].qp, &rwr[0], &bad_wr)) {
tl_error(comm->lib, "Failed to prepost recvs: errno %d buffer count %d",
errno, i);
return UCC_ERR_NO_RESOURCE;
Expand All @@ -555,6 +554,21 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_
return UCC_OK;
}

/*
 * Test an executor task once and reap it when it is no longer pending.
 *
 *   _errmsg  message handed to tl_error() when the task ended with an error
 *   _etask   lvalue holding the executor task pointer; it is evaluated several
 *            times, so it must be free of side effects
 *   _lib     library handle handed to tl_error()
 *
 * Hidden contract - read before using:
 *   - a variable named `status` must exist in the caller's scope; the macro
 *     assigns the result of ucc_ee_executor_task_test() to it;
 *   - the macro may `return status` from the ENCLOSING function:
 *       status > 0  returned as is, the task is left untouched (presumably
 *                   "still in progress" - confirm against ucc_status_t);
 *       status < 0  the task is finalized, _etask is set to NULL, the error
 *                   is logged and status is returned;
 *       otherwise   the task is finalized, _etask is set to NULL and
 *                   execution continues after the macro;
 *   - nothing is done when _etask is NULL.
 *
 * The task pointer is parenthesized wherever it is an operand so that any
 * lvalue expression can be passed safely.
 */
#define EXEC_TASK_TEST(_errmsg, _etask, _lib) do {                         \
    if ((_etask) != NULL) {                                                \
        status = ucc_ee_executor_task_test(_etask);                        \
        if (status > 0) {                                                  \
            return status;                                                 \
        }                                                                  \
        ucc_ee_executor_task_finalize(_etask);                             \
        (_etask) = NULL;                                                   \
        if (ucc_unlikely(status < 0)) {                                    \
            tl_error(_lib, _errmsg);                                       \
            return status;                                                 \
        }                                                                  \
    }                                                                      \
} while (0)

ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context,
ucc_tl_mlx5_mcast_team_t **mcast_team,
ucc_tl_mlx5_mcast_context_t *ctx,
Expand Down
10 changes: 9 additions & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_
return status;
}

while (req->exec_task != NULL) {
EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib);
}

comm->bcast_comm.n_mcast_reliable++;

for (; comm->bcast_comm.last_acked < comm->psn; comm->bcast_comm.last_acked++) {
Expand Down Expand Up @@ -267,7 +271,10 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task)
return ucc_task_complete(coll_task);
}

coll_task->status = status;
ucc_assert(task->coll_mcast.req_handle != NULL);

coll_task->status = status;
task->coll_mcast.req_handle->coll_task = coll_task;

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super);
}
Expand Down Expand Up @@ -333,6 +340,7 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
{
task->super.post = ucc_tl_mlx5_mcast_bcast_start;
task->super.progress = ucc_tl_mlx5_mcast_collective_progress;
task->super.flags = UCC_COLL_TASK_FLAG_EXECUTOR;

return UCC_OK;
}
1 change: 1 addition & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);

ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team);

#endif
Loading

0 comments on commit b9331ee

Please sign in to comment.