Skip to content

Commit 57f969f

Browse files
arttianezhumeta-codesync[bot]
authored and committed
UT for Broadcast normal execution on both FT enabled or not
Summary: followup to D86437886, closing the gap for Broadcast to test both FT enabled & disabled paths. Reviewed By: dboyda Differential Revision: D86437887 fbshipit-source-id: 2c91a600d7da71bf1a1ad399ecd73ae05908fe8f
1 parent 7dfc547 commit 57f969f

File tree

1 file changed

+173
-0
lines changed

1 file changed

+173
-0
lines changed
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
#include <memory>
4+
#include <vector>
5+
6+
#include <gmock/gmock.h>
7+
#include <gtest/gtest.h>
8+
9+
#include "comms/ctran/Ctran.h"
10+
#include "comms/ctran/tests/CtranStandaloneUTUtils.h"
11+
#include "comms/utils/cvars/nccl_cvars.h"
12+
13+
namespace ctran::testing {
14+
15+
struct TestParam {
16+
std::string name;
17+
enum NCCL_BROADCAST_ALGO algo;
18+
};
19+
20+
class CtranBroadcastTest : public CtranStandaloneMultiRankBaseTest,
21+
public ::testing::WithParamInterface<TestParam> {
22+
protected:
23+
static constexpr int kNRanks = 4;
24+
static_assert(kNRanks % 2 == 0);
25+
static constexpr commDataType_t kDataType = commFloat32;
26+
static constexpr size_t kTypeSize = sizeof(float);
27+
static constexpr size_t kBufferNElem = kBufferSize / kTypeSize;
28+
29+
void SetUp() override {
30+
CtranStandaloneMultiRankBaseTest::SetUp();
31+
}
32+
33+
void overrideEnvConfig(const TestParam& param) {
34+
NCCL_BROADCAST_ALGO = param.algo;
35+
}
36+
37+
void startWorkers(
38+
const TestParam& param,
39+
std::optional<std::vector<std::shared_ptr<::ctran::utils::Abort>>>
40+
aborts = std::nullopt) {
41+
overrideEnvConfig(param);
42+
CtranStandaloneMultiRankBaseTest::startWorkers(
43+
kNRanks,
44+
/*aborts=*/
45+
aborts.value_or(std::vector<std::shared_ptr<::ctran::utils::Abort>>{}));
46+
}
47+
48+
void runTest(const TestParam& param) {
49+
for (int rank = 0; rank < kNRanks; ++rank) {
50+
run(rank,
51+
[this](PerRankState& state) { runBroadcast(kBufferNElem, state); });
52+
}
53+
}
54+
55+
void validateConfigs(size_t nElem) {
56+
ASSERT_TRUE(nElem <= kBufferNElem);
57+
}
58+
59+
void initBufferValues(size_t nElem, PerRankState& state) {
60+
std::vector<float> hostSrc(nElem, 1.0f);
61+
std::vector<float> hostDst(nElem, 0.0f);
62+
63+
ASSERT_EQ(
64+
cudaSuccess,
65+
cudaMemcpy(
66+
state.srcBuffer,
67+
hostSrc.data(),
68+
nElem * kTypeSize,
69+
cudaMemcpyHostToDevice));
70+
71+
ASSERT_EQ(
72+
cudaSuccess,
73+
cudaMemcpy(
74+
state.dstBuffer,
75+
hostDst.data(),
76+
nElem * kTypeSize,
77+
cudaMemcpyHostToDevice));
78+
}
79+
80+
void runBroadcast(size_t nElem, PerRankState& state, int root = 0) {
81+
validateConfigs(nElem);
82+
83+
CLOGF(INFO, "rank {} broadcast with {} elems", state.rank, nElem);
84+
85+
initBufferValues(nElem, state);
86+
87+
void* srcHandle;
88+
void* dstHandle;
89+
ASSERT_EQ(
90+
commSuccess,
91+
state.ctranComm->ctran_->commRegister(
92+
state.srcBuffer, kBufferSize, &srcHandle));
93+
ASSERT_EQ(
94+
commSuccess,
95+
state.ctranComm->ctran_->commRegister(
96+
state.dstBuffer, kBufferSize, &dstHandle));
97+
SCOPE_EXIT {
98+
state.ctranComm->ctran_->commDeregister(dstHandle);
99+
state.ctranComm->ctran_->commDeregister(srcHandle);
100+
};
101+
102+
CLOGF(INFO, "rank {} broadcast completed registration", state.rank);
103+
104+
auto result = ctranBroadcast(
105+
state.srcBuffer,
106+
state.dstBuffer,
107+
nElem,
108+
kDataType,
109+
root,
110+
state.ctranComm.get(),
111+
state.stream);
112+
EXPECT_EQ(commSuccess, result);
113+
114+
CLOGF(INFO, "rank {} broadcast scheduled", state.rank);
115+
116+
EXPECT_EQ(cudaSuccess, cudaStreamSynchronize(state.stream));
117+
EXPECT_EQ(commSuccess, state.ctranComm->getAsyncResult());
118+
119+
validateBroadcastData(nElem, state, root);
120+
121+
CLOGF(INFO, "rank {} broadcast task completed", state.rank);
122+
}
123+
124+
void validateBroadcastData(size_t nElem, PerRankState& state, int root) {
125+
std::vector<float> hostDst(nElem);
126+
ASSERT_EQ(
127+
cudaSuccess,
128+
cudaMemcpy(
129+
hostDst.data(),
130+
state.dstBuffer,
131+
nElem * kTypeSize,
132+
cudaMemcpyDeviceToHost));
133+
134+
for (size_t i = 0; i < nElem; ++i) {
135+
float expected = 1.0f;
136+
EXPECT_FLOAT_EQ(hostDst[i], expected)
137+
<< "Mismatch at index " << i << " on rank " << state.rank;
138+
}
139+
}
140+
};
141+
142+
// Runs the broadcast on every rank with no abort objects supplied to the
// workers (startWorkers falls back to an empty vector).
TEST_P(CtranBroadcastTest, AbortDisabled) {
  const TestParam testCase = GetParam();
  startWorkers(testCase);
  runTest(testCase);
}
149+
150+
// Runs the same broadcast, but hands the workers one enabled abort object per
// rank.
TEST_P(CtranBroadcastTest, AbortEnabled) {
  const TestParam testCase = GetParam();

  // Build kNRanks abort objects, each created with enabled=true.
  std::vector<std::shared_ptr<::ctran::utils::Abort>> perRankAborts;
  perRankAborts.reserve(kNRanks);
  for (int remaining = kNRanks; remaining > 0; --remaining) {
    perRankAborts.push_back(ctran::utils::createAbort(/*enabled=*/true));
  }

  startWorkers(testCase, perRankAborts);
  runTest(testCase);
}
162+
163+
INSTANTIATE_TEST_SUITE_P(
164+
AllCombinations,
165+
CtranBroadcastTest,
166+
::testing::Values(
167+
TestParam{"broadcast_ctdirect", NCCL_BROADCAST_ALGO::ctdirect},
168+
TestParam{"broadcast_ctbtree", NCCL_BROADCAST_ALGO::ctbtree}),
169+
[](const ::testing::TestParamInfo<TestParam>& info) {
170+
return info.param.name;
171+
});
172+
173+
} // namespace ctran::testing

0 commit comments

Comments
 (0)