Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UCP Switch Based Multicast Support #65

Open
wants to merge 1 commit into
base: huawei
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/planc/ucx/bcast/bcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ static ucg_plan_attr_t ucg_planc_ucx_bcast_plan_attr[] = {
{ucg_planc_ucx_bcast_kntree_prepare,
10, "K-nomial tree", PLAN_DOMAIN},

{ucg_planc_ucx_bcast_multicast_prepare,
11, "switch-based multicast", PLAN_DOMAIN},

{NULL},
};
UCG_PLAN_ATTR_REGISTER_TABLE(ucg_planc_ucx, UCG_COLL_TYPE_BCAST,
Expand Down Expand Up @@ -73,6 +76,11 @@ static ucg_config_field_t bcast_config_table[] = {
ucg_offsetof(ucg_planc_ucx_bcast_config_t, root_adjust),
UCG_CONFIG_TYPE_BOOL},

{"BCAST_MCAST_ROOT_IP", "",
"enable switch-based multicast",
ucg_offsetof(ucg_planc_ucx_bcast_config_t, mcast_root_ip),
UCG_CONFIG_TYPE_STRING},

{NULL}
};
UCG_PLANC_UCX_BUILTIN_ALGO_REGISTER(UCG_COLL_TYPE_BCAST, bcast_config_table,
Expand Down Expand Up @@ -274,6 +282,10 @@ static ucg_plan_policy_t bcast_LG_LG[] = {
{10, {0, 16384}, UCG_PLAN_UCX_PLAN_SCORE_2ND},
UCG_PLAN_LAST_POLICY,
};
static ucg_plan_policy_t bcast_mcast[] = {
{11, {0, UCG_PLAN_RANGE_MAX}, UCG_PLAN_UCX_PLAN_SCORE_1ST},
UCG_PLAN_LAST_POLICY,
};

static ucg_plan_policy_t* bcast_plan_policy[] = {
bcast_4_4,
Expand Down Expand Up @@ -302,9 +314,17 @@ static ucg_plan_policy_t* bcast_plan_policy[] = {
bcast_LG_LG,
};

const ucg_plan_policy_t *ucg_planc_ucx_get_bcast_plan_policy(ucg_planc_ucx_node_level_t node_level,
const ucg_plan_policy_t *ucg_planc_ucx_get_bcast_plan_policy(ucg_planc_ucx_group_t *ucx_group,
ucg_planc_ucx_node_level_t node_level,
ucg_planc_ucx_ppn_level_t ppn_level)
{
ucg_planc_ucx_bcast_config_t *bcast_config;
bcast_config = UCG_PLANC_UCX_CONTEXT_BUILTIN_CONFIG_BUNDLE(ucx_group->context, bcast,
UCG_COLL_TYPE_BCAST);
if (strcmp("", bcast_config->mcast_root_ip)) {
return bcast_mcast;
}

int idx = node_level * PPN_LEVEL_NUMS + ppn_level;
ucg_assert(idx < NODE_LEVEL_NUMS * PPN_LEVEL_NUMS);
ucg_plan_policy_t *policy = bcast_plan_policy[idx];
Expand Down
15 changes: 14 additions & 1 deletion src/planc/ucx/bcast/bcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "util/algo/ucg_kntree.h"
#include "util/algo/ucg_ring.h"
#include "util/ucg_log.h"
#include <ucp/api/ucp.h>

typedef struct ucg_planc_ucx_bcast_config {
/* configuration of kntree bcast */
Expand All @@ -24,15 +25,23 @@ typedef struct ucg_planc_ucx_bcast_config {
/* configuration of node-aware kntree bcast */
int na_kntree_inter_degree;
int na_kntree_intra_degree;
char *mcast_root_ip;
} ucg_planc_ucx_bcast_config_t;

typedef struct ucg_algo_mcast_ctx {
int init_done;
ucp_coll_bcast_ctx_h ucp_mcast_ctx;
char *server_ip;
} ucg_algo_mcast_ctx_t;

/**
* @brief Bcast op auxiliary information
*/
typedef struct ucg_planc_ucx_bcast {
union {
ucg_algo_kntree_iter_t kntree_iter;
ucg_algo_ring_iter_t ring_iter;
ucg_algo_mcast_ctx_t mcast_ctx;
struct {
ucg_algo_kntree_iter_t kntree_iter;
ucg_algo_ring_iter_t ring_iter;
Expand All @@ -42,7 +51,8 @@ typedef struct ucg_planc_ucx_bcast {
};
} ucg_planc_ucx_bcast_t;

const ucg_plan_policy_t *ucg_planc_ucx_get_bcast_plan_policy(ucg_planc_ucx_node_level_t node_level,
const ucg_plan_policy_t *ucg_planc_ucx_get_bcast_plan_policy(ucg_planc_ucx_group_t *ucx_group,
ucg_planc_ucx_node_level_t node_level,
ucg_planc_ucx_ppn_level_t ppn_level);

/* xxx_op_new routines are provided for internal algorithm combination */
Expand Down Expand Up @@ -83,6 +93,9 @@ ucg_status_t ucg_planc_ucx_bcast_nta_kntree_prepare(ucg_vgroup_t *group,
ucg_status_t ucg_planc_ucx_bcast_van_de_geijn_prepare(ucg_vgroup_t *vgroup,
const ucg_coll_args_t *args,
ucg_plan_op_t **op);
ucg_status_t ucg_planc_ucx_bcast_multicast_prepare(ucg_vgroup_t *vgroup,
const ucg_coll_args_t *args,
ucg_plan_op_t **op);

/* helper for adding op to meta op. */
ucg_status_t ucg_planc_ucx_bcast_add_adjust_root_op(ucg_plan_meta_op_t *meta_op,
Expand Down
109 changes: 109 additions & 0 deletions src/planc/ucx/bcast/bcast_mcast.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
*/

#include "bcast.h"
#include "planc_ucx_plan.h"
#include "planc_ucx_p2p.h"
#include "core/ucg_dt.h"
#include "core/ucg_group.h"
#include "util/ucg_log.h"
#include "util/ucg_malloc.h"
#include <netdb.h>
#include <arpa/inet.h> /* inet_addr */


static int server_port = 13300;

static ucg_status_t ucg_planc_ucx_bcast_multicast_op_progress(ucg_plan_op_t *ucg_op)
{
ucg_planc_ucx_op_t *op = ucg_derived_of(ucg_op, ucg_planc_ucx_op_t);
if (UCS_OK == ucp_tag_send_bcast_progress(op->bcast.mcast_ctx.ucp_mcast_ctx)) {
ucg_op->super.status = UCG_OK;
return UCG_OK;
}
ucg_op->super.status = UCG_INPROGRESS;
return UCG_INPROGRESS;
}

static ucg_status_t ucg_planc_ucx_bcast_multicast_op_trigger(ucg_plan_op_t *ucg_op)
{
ucg_status_t status;
ucg_planc_ucx_op_t *op = ucg_derived_of(ucg_op, ucg_planc_ucx_op_t);
ucg_planc_ucx_op_reset(op);
ucg_coll_bcast_args_t *args = &ucg_op->super.args.bcast;
ucg_rank_t myrank = op->super.vgroup->myrank;
uint32_t group_size = op->super.vgroup->size;
int is_root = !!(args->root == myrank);
if (!op->bcast.mcast_ctx.init_done) {
struct sockaddr_in sock_addr = {
.sin_family = AF_INET,
.sin_port = htons(server_port),
.sin_addr = {
.s_addr = is_root ? INADDR_ANY : inet_addr(op->bcast.mcast_ctx.server_ip)
}
};
ucs_sock_addr_t server_address = {
.addr = (struct sockaddr *)&sock_addr,
.addrlen = sizeof(struct sockaddr)
};

ucg_info("ucp_bcast_init: myrank=%d is_root=%d group_size=%d", myrank, is_root, group_size);

status = ucp_tag_send_bcast_init(&op->bcast.mcast_ctx.ucp_mcast_ctx, &server_address, group_size - 1, is_root ? UCP_COLL_BCAST_FLAG_SERVER : 0, myrank, 1000);
if (status) {
ucg_info("ucp_bcast_init ERROR: myrank=%d is_root=%d group_size=%d", myrank, is_root, group_size);
return UCG_ERR_IO_ERROR;
}
op->bcast.mcast_ctx.init_done = 1;
server_port++;
}
return ucp_tag_send_bcast_start(op->bcast.mcast_ctx.ucp_mcast_ctx, args->buffer, args->count, is_root ? UCP_COLL_BCAST_FLAG_SERVER: 0);
}

static ucg_status_t ucg_planc_ucx_bcast_multicast_op_discard(ucg_plan_op_t *ucg_op)
{
ucg_planc_ucx_op_t *op = ucg_derived_of(ucg_op, ucg_planc_ucx_op_t);
ucp_tag_send_bcast_destroy(op->bcast.mcast_ctx.ucp_mcast_ctx);
UCG_CLASS_DESTRUCT(ucg_plan_op_t, ucg_op);
ucg_mpool_put(ucg_op);
return UCG_OK;
}

ucg_status_t ucg_planc_ucx_bcast_multicast_prepare(ucg_vgroup_t *vgroup,
const ucg_coll_args_t *args,
ucg_plan_op_t **op)
{
UCG_CHECK_NULL_INVALID(vgroup, args, op);
ucg_status_t status;
ucg_planc_ucx_group_t *ucx_group = ucg_derived_of(vgroup, ucg_planc_ucx_group_t);
ucg_planc_ucx_op_t *ucx_op = ucg_mpool_get(&ucx_group->context->op_mp);
ucg_planc_ucx_bcast_config_t *bcast_config;
if (ucx_op == NULL) {
return UCG_ERR_NO_MEMORY;
}

bcast_config = UCG_PLANC_UCX_CONTEXT_BUILTIN_CONFIG_BUNDLE(ucx_group->context, bcast,
UCG_COLL_TYPE_BCAST);

status = UCG_CLASS_CONSTRUCT(ucg_plan_op_t, &ucx_op->super, vgroup,
ucg_planc_ucx_bcast_multicast_op_trigger,
ucg_planc_ucx_bcast_multicast_op_progress,
ucg_planc_ucx_bcast_multicast_op_discard,
args);
if (status != UCG_OK) {
ucg_error("Failed to initialize super of ucx op");
goto err_free_op;
}
ucx_op->bcast.mcast_ctx.init_done = 0;
ucx_op->bcast.mcast_ctx.server_ip = bcast_config->mcast_root_ip;
ucg_planc_ucx_op_init(ucx_op, ucx_group);

*op = &ucx_op->super;
return UCG_OK;

err_free_op:
ucg_mpool_put(ucx_op);
return status;

}
7 changes: 4 additions & 3 deletions src/planc/ucx/planc_ucx_plan.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

UCG_PLAN_ATTR_TABLE_DEFINE(ucg_planc_ucx);

static const ucg_plan_policy_t* ucg_planc_ucx_get_plan_policy(ucg_coll_type_t coll_type,
static const ucg_plan_policy_t* ucg_planc_ucx_get_plan_policy(ucg_planc_ucx_group_t *ucx_group,
ucg_coll_type_t coll_type,
ucg_planc_ucx_node_level_t node_level,
ucg_planc_ucx_ppn_level_t ppn_level)
{
const ucg_plan_policy_t *policy = NULL;
switch (coll_type) {
case UCG_COLL_TYPE_BCAST:
policy = ucg_planc_ucx_get_bcast_plan_policy(node_level, ppn_level);
policy = ucg_planc_ucx_get_bcast_plan_policy(ucx_group, node_level, ppn_level);
break;
case UCG_COLL_TYPE_ALLREDUCE:
policy = ucg_planc_ucx_get_allreduce_plan_policy(node_level, ppn_level);
Expand Down Expand Up @@ -92,7 +93,7 @@ static ucg_status_t ucg_planc_ucx_add_default_plans(ucg_planc_ucx_group_t *ucx_g
ucg_coll_type_t coll_type = UCG_COLL_TYPE_BCAST;
for (; coll_type < UCG_COLL_TYPE_LAST; ++coll_type) {
/* get internal policy */
default_policy = ucg_planc_ucx_get_plan_policy(coll_type, node_level, ppn_level);
default_policy = ucg_planc_ucx_get_plan_policy(ucx_group, coll_type, node_level, ppn_level);
if (default_policy == NULL) {
continue;
}
Expand Down