Skip to content

Commit 9a0e6a1

Browse files
authored
Merge pull request #1456 from corey-derochie-amd/CP_6.3_SWDEV-479834
Cherry-pick (Adjustment for UT Sendrecv (#1400))
2 parents eb33738 + 871e134 commit 9a0e6a1

File tree

2 files changed

+83
-56
lines changed

2 files changed

+83
-56
lines changed

test/SendRecvTests.cpp

Lines changed: 77 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ namespace RcclUnitTesting
1616
std::vector<int> const numElements = {1048576, 53327, 1024, 0};
1717
bool const inPlace = false;
1818
bool const useManagedMem = false;
19-
int const groupCallId = 0;
2019

2120
OptionalColArgs options;
2221
bool isCorrect = true;
@@ -28,7 +27,10 @@ namespace RcclUnitTesting
2827
int ranksPerGpu = rpg == 0 ? 1 : testBed.ev.maxRanksPerGpu;
2928
int totalRanks = numGpus * ranksPerGpu;
3029
int const numProcesses = isMultiProcess ? numGpus : 1;
31-
testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, numGpus, ranksPerGpu), 1);
30+
testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, numGpus, ranksPerGpu),
31+
{1,2}, //two groups; the second group (self send/recv) contains 2 collectives
32+
testBed.GetNumStreamsPerGroup(1,2),
33+
2);
3234

3335
for (int dataIdx = 0; dataIdx < dataTypes.size() && isCorrect; ++dataIdx)
3436
for (int numIdx = 0; numIdx < numElements.size() && isCorrect; ++numIdx)
@@ -37,6 +39,8 @@ namespace RcclUnitTesting
3739
for (int recvRank = 0; recvRank < totalRanks; ++recvRank)
3840
{
3941
options.root = recvRank;
42+
int groupCallId = sendRank == recvRank; //self send/recv uses group 1, which has two collectives
43+
int recvId = sendRank == recvRank; //in the self send/recv group, the recv is the second collective
4044
testBed.SetCollectiveArgs(ncclCollSend,
4145
dataTypes[dataIdx],
4246
numElements[numIdx],
@@ -47,36 +51,46 @@ namespace RcclUnitTesting
4751
sendRank);
4852
if (recvRank == 0)
4953
{
50-
testBed.AllocateMem(inPlace, useManagedMem, groupCallId, 0, sendRank);
51-
testBed.PrepareData(groupCallId, 0, sendRank);
52-
}
53-
if (recvRank != sendRank)
54-
{
55-
if (testBed.ev.showNames) // Show test names
56-
INFO("%s Datatype: %s SendReceive test Rank %d -> Rank %d for %d Elements\n",
57-
isMultiProcess ? "MP" : "SP",
58-
ncclDataTypeNames[dataTypes[dataIdx]],
59-
sendRank,
60-
recvRank,
61-
numElements[numIdx]);
62-
63-
options.root = sendRank;
64-
testBed.SetCollectiveArgs(ncclCollRecv,
54+
//set up the collArg slot so that AllocateMem is called exactly once per group, and correctly
55+
testBed.SetCollectiveArgs(ncclCollSend,
6556
dataTypes[dataIdx],
6657
numElements[numIdx],
6758
numElements[numIdx],
6859
options,
6960
0,
70-
groupCallId,
71-
recvRank);
72-
testBed.AllocateMem(inPlace, useManagedMem, groupCallId, 0, recvRank);
73-
testBed.PrepareData(groupCallId, 0, recvRank);
74-
testBed.ExecuteCollectives({sendRank, recvRank});
75-
testBed.ValidateResults(isCorrect, groupCallId, 0, recvRank);
76-
testBed.DeallocateMem(groupCallId, 0, recvRank);
61+
!groupCallId,
62+
sendRank);
63+
testBed.AllocateMem(inPlace, useManagedMem, 0, 0, sendRank);
64+
testBed.PrepareData(0, 0, sendRank);
65+
testBed.AllocateMem(inPlace, useManagedMem, 1, 0, sendRank);
66+
testBed.PrepareData(1, 0, sendRank);
7767
}
68+
69+
if (testBed.ev.showNames) // Show test names
70+
INFO("%s Datatype: %s SendReceive test Rank %d -> Rank %d for %d Elements\n",
71+
isMultiProcess ? "MP" : "SP",
72+
ncclDataTypeNames[dataTypes[dataIdx]],
73+
sendRank,
74+
recvRank,
75+
numElements[numIdx]);
76+
options.root = sendRank;
77+
78+
testBed.SetCollectiveArgs(ncclCollRecv,
79+
dataTypes[dataIdx],
80+
numElements[numIdx],
81+
numElements[numIdx],
82+
options,
83+
recvId,
84+
groupCallId,
85+
recvRank);
86+
testBed.AllocateMem(inPlace, useManagedMem, groupCallId, recvId, recvRank);
87+
testBed.PrepareData(groupCallId, recvId, recvRank);
88+
testBed.ExecuteCollectives({sendRank, recvRank}, groupCallId);
89+
testBed.ValidateResults(isCorrect, groupCallId, recvId, recvRank);
90+
testBed.DeallocateMem(groupCallId, recvId, recvRank);
7891
}
79-
testBed.DeallocateMem(groupCallId, 0, sendRank);
92+
testBed.DeallocateMem(0, 0, sendRank);
93+
testBed.DeallocateMem(1, 0, sendRank);
8094
}
8195
testBed.DestroyComms();
8296
}
@@ -94,7 +108,6 @@ namespace RcclUnitTesting
94108
bool const inPlace = false;
95109
bool const useManagedMem = false;
96110
bool const userRegistered = true;
97-
int const groupCallId = 0;
98111

99112
OptionalColArgs options;
100113
bool isCorrect = true;
@@ -106,7 +119,10 @@ namespace RcclUnitTesting
106119
int ranksPerGpu = rpg == 0 ? 1 : testBed.ev.maxRanksPerGpu;
107120
int totalRanks = numGpus * ranksPerGpu;
108121
int const numProcesses = isMultiProcess ? numGpus : 1;
109-
testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, numGpus, ranksPerGpu), 1);
122+
testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, numGpus, ranksPerGpu),
123+
{1,2}, //two groups; the second group (self send/recv) contains 2 collectives
124+
testBed.GetNumStreamsPerGroup(1,2),
125+
2);
110126

111127
for (int dataIdx = 0; dataIdx < dataTypes.size() && isCorrect; ++dataIdx)
112128
for (int numIdx = 0; numIdx < numElements.size() && isCorrect; ++numIdx)
@@ -115,6 +131,8 @@ namespace RcclUnitTesting
115131
for (int recvRank = 0; recvRank < totalRanks; ++recvRank)
116132
{
117133
options.root = recvRank;
134+
int groupCallId = sendRank == recvRank;
135+
int recvId = sendRank == recvRank;
118136
testBed.SetCollectiveArgs(ncclCollSend,
119137
dataTypes[dataIdx],
120138
numElements[numIdx],
@@ -125,36 +143,45 @@ namespace RcclUnitTesting
125143
sendRank);
126144
if (recvRank == 0)
127145
{
128-
testBed.AllocateMem(inPlace, useManagedMem, groupCallId, 0, sendRank, userRegistered);
129-
testBed.PrepareData(groupCallId, 0, sendRank);
130-
}
131-
if (recvRank != sendRank)
132-
{
133-
if (testBed.ev.showNames) // Show test names
134-
INFO("%s Datatype: %s SendReceive test Rank %d -> Rank %d for %d Elements\n",
135-
isMultiProcess ? "MP" : "SP",
136-
ncclDataTypeNames[dataTypes[dataIdx]],
137-
sendRank,
138-
recvRank,
139-
numElements[numIdx]);
140-
141-
options.root = sendRank;
142-
testBed.SetCollectiveArgs(ncclCollRecv,
146+
testBed.SetCollectiveArgs(ncclCollSend,
143147
dataTypes[dataIdx],
144148
numElements[numIdx],
145149
numElements[numIdx],
146150
options,
147151
0,
148-
groupCallId,
149-
recvRank);
150-
testBed.AllocateMem(inPlace, useManagedMem, groupCallId, 0, recvRank, userRegistered);
151-
testBed.PrepareData(groupCallId, 0, recvRank);
152-
testBed.ExecuteCollectives({sendRank, recvRank});
153-
testBed.ValidateResults(isCorrect, groupCallId, 0, recvRank);
154-
testBed.DeallocateMem(groupCallId, 0, recvRank);
152+
!groupCallId,
153+
sendRank);
154+
testBed.AllocateMem(inPlace, useManagedMem, 0, 0, sendRank, userRegistered);
155+
testBed.PrepareData(0, 0, sendRank);
156+
testBed.AllocateMem(inPlace, useManagedMem, 1, 0, sendRank, userRegistered);
157+
testBed.PrepareData(1, 0, sendRank);
155158
}
159+
160+
if (testBed.ev.showNames) // Show test names
161+
INFO("%s Datatype: %s SendReceive test Rank %d -> Rank %d for %d Elements\n",
162+
isMultiProcess ? "MP" : "SP",
163+
ncclDataTypeNames[dataTypes[dataIdx]],
164+
sendRank,
165+
recvRank,
166+
numElements[numIdx]);
167+
168+
options.root = sendRank;
169+
testBed.SetCollectiveArgs(ncclCollRecv,
170+
dataTypes[dataIdx],
171+
numElements[numIdx],
172+
numElements[numIdx],
173+
options,
174+
recvId,
175+
groupCallId,
176+
recvRank);
177+
testBed.AllocateMem(inPlace, useManagedMem, groupCallId, recvId, recvRank, userRegistered);
178+
testBed.PrepareData(groupCallId, recvId, recvRank);
179+
testBed.ExecuteCollectives({sendRank, recvRank}, groupCallId);
180+
testBed.ValidateResults(isCorrect, groupCallId, recvId, recvRank);
181+
testBed.DeallocateMem(groupCallId, recvId, recvRank);
156182
}
157-
testBed.DeallocateMem(groupCallId, 0, sendRank);
183+
testBed.DeallocateMem(0, 0, sendRank);
184+
testBed.DeallocateMem(1, 0, sendRank);
158185
}
159186
testBed.DestroyComms();
160187
}

test/common/TestBedChild.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,8 @@ namespace RcclUnitTesting
395395
{
396396
CollectiveArgs& collArg = this->collArgs[groupId][localRank][collIdx];
397397
CHECK_CALL(collArg.AllocateMem(inPlace, useManagedMem, userRegistered));
398+
if (collArg.userRegistered && (collArg.funcType == ncclCollSend || collArg.funcType == ncclCollRecv))
399+
CHILD_NCCL_CALL(ncclCommRegister(this->comms[localRank], collArg.inputGpu.ptr, collArg.numInputBytesAllocated, &(collArg.commRegHandle)),"ncclCommRegister");
398400
if (this->verbose) INFO("Rank %d on child %d allocates memory for collective %d in group %d on device %d (%s,%s,%s) Input: %p Output %p\n",
399401
globalRank, this->childId, collIdx, groupId, this->deviceIds[localRank],
400402
inPlace ? "in-place" : "out-of-place",
@@ -646,8 +648,6 @@ namespace RcclUnitTesting
646648
"ncclAllToAllv");
647649
break;
648650
case ncclCollSend:
649-
if (collArg.userRegistered)
650-
CHILD_NCCL_CALL_RANK(errCode, ncclCommRegister(this->comms[localRank], collArg.inputGpu.ptr, collArg.numInputBytesAllocated, &(collArg.commRegHandle)),"ncclCommRegister");
651651
CHILD_NCCL_CALL_RANK(errCode, ncclSend(
652652
collArg.inputGpu.ptr,
653653
collArg.numInputElements,
@@ -658,8 +658,6 @@ namespace RcclUnitTesting
658658
"ncclSend");
659659
break;
660660
case ncclCollRecv:
661-
if (collArg.userRegistered)
662-
CHILD_NCCL_CALL_RANK(errCode, ncclCommRegister(this->comms[localRank], collArg.outputGpu.ptr, collArg.numOutputBytesAllocated, &(collArg.commRegHandle)), "ncclCommRegister");
663661
CHILD_NCCL_CALL_RANK(errCode, ncclRecv(
664662
collArg.outputGpu.ptr,
665663
collArg.numOutputElements,
@@ -891,15 +889,17 @@ namespace RcclUnitTesting
891889
for (int collIdx = 0; collIdx < collArgs[groupId][localRank].size(); ++collIdx)
892890
{
893891
CollectiveArgs& collArg = this->collArgs[groupId][localRank][collIdx];
894-
if (collArg.userRegistered && (collArg.funcType == ncclCollSend || collArg.funcType == ncclCollRecv))
895-
CHILD_NCCL_CALL(ncclCommDeregister(this->comms[localRank], collArg.commRegHandle), "ncclCommDeregister");
896892
if (collId == -1 || collId == collIdx)
897893
{
898894
if (this->verbose)
899895
{
900896
INFO("Child %d release memory for collective %d in group %d (Input: %p Output %p\n",
901897
this->childId, collIdx, groupId, collArg.inputGpu.ptr, collArg.outputGpu.ptr);
902898
}
899+
if (collArg.userRegistered && (collArg.funcType == ncclCollSend || collArg.funcType == ncclCollRecv))
900+
{
901+
CHILD_NCCL_CALL(ncclCommDeregister(this->comms[localRank], collArg.commRegHandle), "ncclCommDeregister");
902+
}
903903

904904
CHECK_CALL(collArg.DeallocateMem());
905905
}

0 commit comments

Comments
 (0)