Skip to content

Commit 1d301f7

Browse files
authored
Merge pull request #13177 from roiedanino/shmem/1.5-support-ucx
2 parents f6fe1d4 + d0a4e07 commit 1d301f7

File tree

3 files changed

+165
-26
lines changed

3 files changed

+165
-26
lines changed

oshmem/mca/spml/ucx/spml_ucx.c

+139-14
Original file line numberDiff line numberDiff line change
@@ -1755,53 +1755,178 @@ int mca_spml_ucx_team_sync(shmem_team_t team)
17551755
return OSHMEM_ERR_NOT_IMPLEMENTED;
17561756
}
17571757

1758-
/* This routine is not implemented */
17591758
int mca_spml_ucx_team_my_pe(shmem_team_t team)
17601759
{
1761-
return OSHMEM_ERR_NOT_IMPLEMENTED;
1760+
mca_spml_ucx_team_t *ucx_team = (mca_spml_ucx_team_t *)team;
1761+
1762+
if (team == SHMEM_TEAM_WORLD) {
1763+
return shmem_my_pe();
1764+
}
1765+
1766+
return ucx_team->my_pe;
17621767
}
17631768

1764-
/* This routine is not implemented */
17651769
int mca_spml_ucx_team_n_pes(shmem_team_t team)
17661770
{
1767-
return OSHMEM_ERR_NOT_IMPLEMENTED;
1771+
mca_spml_ucx_team_t *ucx_team = (mca_spml_ucx_team_t *)team;
1772+
1773+
if (team == SHMEM_TEAM_WORLD) {
1774+
return shmem_n_pes();
1775+
}
1776+
1777+
return ucx_team->n_pes;
17681778
}
17691779

1770-
/* This routine is not implemented */
17711780
int mca_spml_ucx_team_get_config(shmem_team_t team, long config_mask,
17721781
shmem_team_config_t *config)
17731782
{
1774-
return OSHMEM_ERR_NOT_IMPLEMENTED;
1783+
mca_spml_ucx_team_t *ucx_team = (mca_spml_ucx_team_t *)team;
1784+
SPML_UCX_VALIDATE_TEAM(team);
1785+
1786+
memcpy(config, &ucx_team->config, sizeof(shmem_team_config_t));
1787+
1788+
return SHMEM_SUCCESS;
1789+
}
1790+
1791+
static inline int mca_spml_ucx_is_pe_in_strided_team(int src_pe, int start,
1792+
int stride, int size)
1793+
{
1794+
return (src_pe >= start) && (src_pe < start + size * stride)
1795+
&& ((src_pe - start) % stride == 0);
17751796
}
17761797

1777-
/* This routine is not implemented */
17781798
int mca_spml_ucx_team_translate_pe(shmem_team_t src_team, int src_pe,
1779-
shmem_team_t dest_team)
1799+
shmem_team_t dest_team)
17801800
{
1781-
return OSHMEM_ERR_NOT_IMPLEMENTED;
1801+
mca_spml_ucx_team_t *ucx_src_team = (mca_spml_ucx_team_t*) src_team;
1802+
mca_spml_ucx_team_t *ucx_dest_team = (mca_spml_ucx_team_t*) dest_team;
1803+
int global_pe;
1804+
1805+
if ((src_pe == SPML_UCX_PE_NOT_IN_TEAM) || (src_team == dest_team)) {
1806+
return src_pe;
1807+
}
1808+
1809+
global_pe = ucx_src_team->start + src_pe * ucx_src_team->stride;
1810+
1811+
if (dest_team == SHMEM_TEAM_WORLD) {
1812+
return global_pe;
1813+
}
1814+
1815+
if (!mca_spml_ucx_is_pe_in_strided_team(global_pe, ucx_dest_team->start, ucx_dest_team->stride,
1816+
ucx_dest_team->n_pes)) {
1817+
return SPML_UCX_PE_NOT_IN_TEAM;
1818+
}
1819+
1820+
return (global_pe - ucx_dest_team->start) / ucx_dest_team->stride;
17821821
}
17831822

1784-
/* This routine is not implemented */
17851823
int mca_spml_ucx_team_split_strided(shmem_team_t parent_team, int start, int
17861824
stride, int size, const shmem_team_config_t *config, long config_mask,
17871825
shmem_team_t *new_team)
17881826
{
1789-
return OSHMEM_ERR_NOT_IMPLEMENTED;
1827+
mca_spml_ucx_team_t *ucx_parent_team;
1828+
mca_spml_ucx_team_t *ucx_new_team;
1829+
int parent_pe;
1830+
int parent_start;
1831+
int parent_stride;
1832+
int my_pe;
1833+
1834+
SPML_UCX_ASSERT(((start + size * stride) <= oshmem_num_procs()) &&
1835+
(stride > 0) && (size > 0));
1836+
1837+
if (parent_team == SHMEM_TEAM_WORLD) {
1838+
parent_pe = shmem_my_pe();
1839+
parent_start = 0;
1840+
parent_stride = 1;
1841+
} else {
1842+
ucx_parent_team = (mca_spml_ucx_team_t*) parent_team;
1843+
parent_pe = ucx_parent_team->my_pe;
1844+
parent_start = ucx_parent_team->start;
1845+
parent_stride = ucx_parent_team->stride;
1846+
}
1847+
1848+
if (mca_spml_ucx_is_pe_in_strided_team(parent_pe, start, stride, size)) {
1849+
my_pe = (parent_pe - start) / stride;
1850+
} else {
1851+
/* not in team, according to spec it should be SHMEM_TEAM_INVALID but its value is NULL which
1852+
can be also interpreted as 0 (first pe), therefore -1 is used */
1853+
my_pe = SPML_UCX_PE_NOT_IN_TEAM;
1854+
}
1855+
1856+
/* In order to simplify pe translations start and stride are calculated with respect to
1857+
* world_team */
1858+
ucx_new_team = (mca_spml_ucx_team_t *)malloc(sizeof(mca_spml_ucx_team_t));
1859+
ucx_new_team->start = parent_start + (start * parent_stride);
1860+
ucx_new_team->stride = parent_stride * stride;
1861+
1862+
ucx_new_team->n_pes = size;
1863+
ucx_new_team->my_pe = my_pe;
1864+
1865+
ucx_new_team->config = calloc(1, sizeof(mca_spml_ucx_team_config_t));
1866+
1867+
if (config != NULL) {
1868+
memcpy(&ucx_new_team->config->super, config, sizeof(shmem_team_config_t));
1869+
}
1870+
1871+
ucx_new_team->parent_team = (mca_spml_ucx_team_t*)parent_team;
1872+
1873+
*new_team = (shmem_team_t)ucx_new_team;
1874+
1875+
return OSHMEM_SUCCESS;
17901876
}
17911877

1792-
/* This routine is not implemented */
17931878
int mca_spml_ucx_team_split_2d(shmem_team_t parent_team, int xrange, const
17941879
shmem_team_config_t *xaxis_config, long xaxis_mask, shmem_team_t
17951880
*xaxis_team, const shmem_team_config_t *yaxis_config, long yaxis_mask,
17961881
shmem_team_t *yaxis_team)
17971882
{
1798-
return OSHMEM_ERR_NOT_IMPLEMENTED;
1883+
mca_spml_ucx_team_t *ucx_parent_team = (mca_spml_ucx_team_t*) parent_team;
1884+
int parent_n_pes = (parent_team == SHMEM_TEAM_WORLD) ?
1885+
oshmem_num_procs() :
1886+
ucx_parent_team->n_pes;
1887+
int parent_my_pe = (parent_team == SHMEM_TEAM_WORLD) ?
1888+
shmem_my_pe() :
1889+
ucx_parent_team->my_pe;
1890+
int yrange = parent_n_pes / xrange;
1891+
int pe_x = parent_my_pe % xrange;
1892+
int pe_y = parent_my_pe / xrange;
1893+
int rc;
1894+
1895+
/* Create x-team of my_pe */
1896+
rc = mca_spml_ucx_team_split_strided(parent_team, pe_y * xrange, 1, xrange,
1897+
xaxis_config, xaxis_mask, xaxis_team);
1898+
1899+
if (rc != OSHMEM_SUCCESS) {
1900+
SPML_UCX_ERROR("mca_spml_ucx_team_split_strided failed (x-axis team creation)");
1901+
return rc;
1902+
}
1903+
1904+
/* Create y-team of my_pe */
1905+
rc = mca_spml_ucx_team_split_strided(parent_team, pe_x, xrange, yrange,
1906+
yaxis_config, yaxis_mask, yaxis_team);
1907+
if (rc != OSHMEM_SUCCESS) {
1908+
SPML_UCX_ERROR("mca_spml_ucx_team_split_strided failed (y-axis team creation)");
1909+
goto out_free_xaxis;
1910+
}
1911+
1912+
return OSHMEM_SUCCESS;
1913+
1914+
out_free_xaxis:
1915+
mca_spml_ucx_team_destroy(*xaxis_team);
1916+
return rc;
17991917
}
18001918

18011919
/* This routine is not implemented */
18021920
int mca_spml_ucx_team_destroy(shmem_team_t team)
18031921
{
1804-
return OSHMEM_ERR_NOT_IMPLEMENTED;
1922+
mca_spml_ucx_team_t *ucx_team = (mca_spml_ucx_team_t *)team;
1923+
1924+
SPML_UCX_VALIDATE_TEAM(team);
1925+
1926+
free(ucx_team->config);
1927+
free(team);
1928+
1929+
return OSHMEM_SUCCESS;
18051930
}
18061931

18071932
/* This routine is not implemented */

oshmem/mca/spml/ucx/spml_ucx.h

+23
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ BEGIN_C_DECLS
4848
#define SPML_UCX_TRANSP_IDX 0
4949
#define SPML_UCX_TRANSP_CNT 1
5050
#define SPML_UCX_SERVICE_SEG 0
51+
#define SPML_UCX_PE_NOT_IN_TEAM -1
52+
53+
#define SPML_UCX_VALIDATE_TEAM(_team) \
54+
do { \
55+
if (OPAL_UNLIKELY((_team) == SHMEM_TEAM_INVALID)) { \
56+
SPML_UCX_ERROR("Invalid team at %s", __func__); \
57+
return OSHMEM_ERROR; \
58+
} \
59+
} while (0)
5160

5261
enum {
5362
SPML_UCX_STRONG_ORDERING_NONE = 0, /* don't use strong ordering */
@@ -115,6 +124,20 @@ typedef struct mca_spml_ucx_ctx_array {
115124
mca_spml_ucx_ctx_t **ctxs;
116125
} mca_spml_ucx_ctx_array_t;
117126

127+
typedef struct mca_spml_ucx_team_config {
128+
shmem_team_config_t super;
129+
130+
} mca_spml_ucx_team_config_t;
131+
132+
typedef struct mca_spml_ucx_team {
133+
int n_pes;
134+
int my_pe;
135+
int stride;
136+
int start;
137+
mca_spml_ucx_team_config_t *config;
138+
struct mca_spml_ucx_team *parent_team;
139+
} mca_spml_ucx_team_t;
140+
118141
struct mca_spml_ucx {
119142
mca_spml_base_module_t super;
120143
ucp_context_h ucp_context;

oshmem/shmem/c/shmem_team.c

+3-12
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,9 @@ void shmem_team_sync(shmem_team_t team)
5151

5252
int shmem_team_my_pe(shmem_team_t team)
5353
{
54-
int rc = 0;
55-
5654
RUNTIME_CHECK_INIT();
5755

58-
rc = MCA_SPML_CALL(team_my_pe(team));
59-
RUNTIME_CHECK_IMPL_RC(rc);
60-
61-
return rc;
56+
return MCA_SPML_CALL(team_my_pe(team));
6257
}
6358

6459
int shmem_team_n_pes(shmem_team_t team)
@@ -85,15 +80,11 @@ int shmem_team_get_config(shmem_team_t team, long config_mask, shmem_team_config
8580
}
8681
int shmem_team_translate_pe(shmem_team_t src_team, int src_pe, shmem_team_t dest_team)
8782
{
88-
int rc = 0;
89-
9083
RUNTIME_CHECK_INIT();
9184

92-
rc = MCA_SPML_CALL(team_translate_pe(src_team, src_pe, dest_team));
93-
RUNTIME_CHECK_IMPL_RC(rc);
94-
95-
return rc;
85+
return MCA_SPML_CALL(team_translate_pe(src_team, src_pe, dest_team));
9686
}
87+
9788
int shmem_team_split_strided (shmem_team_t parent_team, int start, int stride,
9889
int size, const shmem_team_config_t *config, long config_mask,
9990
shmem_team_t *new_team)

0 commit comments

Comments
 (0)