Skip to content

Commit ddd36d5

Browse files
authored
Have hvd.join() return the last rank that joined (horovod#3097)
* Update flatbuffers to v2.0.0 flatc built from the previously included flatbuffers source (ca. 2019) would generate weird code, which was fixed by this PR in the mean time: google/flatbuffers#5258 Signed-off-by: Max H. Gerlach <[email protected]> * Return last rank that joined from hvd.join() in PyTorch and TensorFlow (as determined on the coordinator) Signed-off-by: Max H. Gerlach <[email protected]> * Restore copyright/license header Signed-off-by: Max H. Gerlach <[email protected]> * Add TF API doc string Signed-off-by: Max H. Gerlach <[email protected]>
1 parent f4d519c commit ddd36d5

17 files changed

+175
-55
lines changed

horovod/common/controller.cc

+6-1
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ ResponseList Controller::ComputeResponseList(bool this_process_requested_shutdow
261261

262262
if (message.request_type() == Request::JOIN) {
263263
process_set.joined_size++;
264+
process_set.last_joined_rank = global_ranks_[rank_];
264265
continue;
265266
}
266267

@@ -285,6 +286,7 @@ ResponseList Controller::ComputeResponseList(bool this_process_requested_shutdow
285286

286287
if (received_message.request_type() == Request::JOIN) {
287288
process_set.joined_size++;
289+
process_set.last_joined_rank = global_ranks_[i];
288290
continue;
289291
}
290292

@@ -401,12 +403,15 @@ ResponseList Controller::ComputeResponseList(bool this_process_requested_shutdow
401403
responses.push_back(std::move(response));
402404
}
403405
if (process_set.joined_size == size_) {
404-
// All ranks did Join(). Send the response, reset joined size.
406+
// All ranks did Join(). Send the response, reset joined_size and
407+
// last_joined_rank.
405408
Response join_response;
406409
join_response.set_response_type(Response::JOIN);
407410
join_response.add_tensor_name(JOIN_TENSOR_NAME);
411+
join_response.set_last_joined_rank(process_set.last_joined_rank);
408412
responses.push_back(std::move(join_response));
409413
process_set.joined_size = 0;
414+
process_set.last_joined_rank = -1;
410415
}
411416
FuseResponses(responses, state, response_list);
412417
response_list.set_shutdown(should_shut_down);

horovod/common/message.cc

+16-2
Original file line numberDiff line numberDiff line change
@@ -391,9 +391,21 @@ double Response::prescale_factor() const { return prescale_factor_; };
391391

392392
double Response::postscale_factor() const { return postscale_factor_; };
393393

394-
void Response::set_prescale_factor(const double prescale_factor) { prescale_factor_ = prescale_factor; };
394+
void Response::set_prescale_factor(const double prescale_factor) {
395+
prescale_factor_ = prescale_factor;
396+
};
395397

396-
void Response::set_postscale_factor(const double postscale_factor) { postscale_factor_ = postscale_factor; };
398+
void Response::set_postscale_factor(const double postscale_factor) {
399+
postscale_factor_ = postscale_factor;
400+
};
401+
402+
int Response::last_joined_rank() const {
403+
return last_joined_rank_;
404+
}
405+
406+
void Response::set_last_joined_rank(int value) {
407+
last_joined_rank_ = value;
408+
}
397409

398410
void Response_ParseFromWire(Response& response,
399411
const wire::Response* obj) {
@@ -409,6 +421,7 @@ void Response_ParseFromWire(Response& response,
409421
obj->tensor_sizes()->end()));
410422
response.set_prescale_factor(obj->prescale_factor());
411423
response.set_postscale_factor(obj->postscale_factor());
424+
response.set_last_joined_rank(obj->last_joined_rank());
412425
}
413426

414427
void Response::ParseFromBytes(Response& response, const uint8_t* input) {
@@ -437,6 +450,7 @@ void Response_SerializeToWire(const Response& response,
437450
response_builder.add_tensor_sizes(tensor_sizes_wire);
438451
response_builder.add_prescale_factor(response.prescale_factor());
439452
response_builder.add_postscale_factor(response.postscale_factor());
453+
response_builder.add_last_joined_rank(response.last_joined_rank());
440454
obj = response_builder.Finish();
441455
}
442456

horovod/common/message.h

+9-4
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,10 @@ class RequestList {
146146
};
147147

148148
// A Response is a message sent from the coordinator (rank zero) to a rank
149-
// greater than zero, informing the rank of an operation should be performed
150-
// now. If the operation requested would result in an error (for example, due
151-
// to a type or shape mismatch), then the Response can contain an error and
152-
// an error message instead.
149+
// greater than zero, informing the rank of an operation that should be
150+
// performed now. If the requested operation would result in an error (for
151+
// example, due to a type or shape mismatch), then the Response can contain an
152+
// error and an error message instead.
153153
class Response {
154154
public:
155155
enum ResponseType {
@@ -208,6 +208,10 @@ class Response {
208208

209209
void set_postscale_factor(double postscale_factor);
210210

211+
int last_joined_rank() const;
212+
213+
void set_last_joined_rank(int value);
214+
211215
static void ParseFromBytes(Response& response, const uint8_t* input);
212216

213217
static void SerializeToString(const Response& response,
@@ -222,6 +226,7 @@ class Response {
222226
std::vector<int64_t> tensor_sizes_;
223227
double prescale_factor_ = 1.0;
224228
double postscale_factor_ = 1.0;
229+
int last_joined_rank_ = -1;
225230
};
226231

227232
class ResponseList {

horovod/common/operations.cc

+4-3
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,9 @@ OperationManager* CreateOperationManager(HorovodGlobalState& state) {
257257
void PerformOperation(Response response, ProcessSet& process_set) {
258258
std::vector<TensorTableEntry> entries;
259259
auto& timeline = horovod_global.timeline;
260+
process_set.tensor_queue.GetTensorEntriesFromResponse(response, entries,
261+
process_set.joined);
260262
if (response.response_type() != Response::JOIN) {
261-
process_set.tensor_queue.GetTensorEntriesFromResponse(response, entries,
262-
process_set.joined);
263-
264263
for (auto& e : entries) {
265264
timeline.Start(e.tensor_name, response.response_type(), e.tensor->size());
266265
}
@@ -1725,6 +1724,7 @@ Status EnqueueTensorAlltoall(std::shared_ptr<OpContext> context,
17251724
// Contexts and controller must be initialized and the background thread
17261725
// must be running before this function is called.
17271726
Status EnqueueJoin(std::shared_ptr<OpContext> context,
1727+
std::shared_ptr<Tensor> output_last_joined_rank,
17281728
ReadyEventList ready_event_list,
17291729
const std::string& name, const int device,
17301730
StatusCallback callback,
@@ -1739,6 +1739,7 @@ Status EnqueueJoin(std::shared_ptr<OpContext> context,
17391739
TensorTableEntry e;
17401740
e.tensor_name = name;
17411741
e.context = context;
1742+
e.output = output_last_joined_rank;
17421743
e.process_set_id = process_set_id;
17431744
e.ready_event_list = ready_event_list;
17441745
e.device = device;

horovod/common/operations.h

+1
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ Status EnqueueTensorAlltoall(std::shared_ptr<OpContext> context,
227227
int32_t process_set_id = 0);
228228

229229
Status EnqueueJoin(std::shared_ptr<OpContext> context,
230+
std::shared_ptr<Tensor> output_last_joined_rank,
230231
ReadyEventList ready_event_list,
231232
const std::string& name, int device,
232233
StatusCallback callback,

horovod/common/ops/collective_operations.cc

+5-1
Original file line numberDiff line numberDiff line change
@@ -300,10 +300,14 @@ Status JoinOp::Execute(std::vector<TensorTableEntry>& entries,
300300
const Response& response, ProcessSet& process_set) {
301301
WaitForData(entries);
302302

303-
assert(entries.empty());
303+
assert(entries.size() == 1);
304+
auto e = entries[0];
305+
auto output_ptr = (int*) e.output->data();
306+
*output_ptr = response.last_joined_rank();
304307
if (process_set.joined) {
305308
process_set.tensor_queue.RemoveJoinTensor();
306309
process_set.joined = false;
310+
process_set.last_joined_rank = -1;
307311
}
308312
return Status::OK();
309313
}

horovod/common/process_set.h

+3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ struct ProcessSet {
4444
// Number of ranks that did Join()
4545
int joined_size = 0;
4646

47+
// Last global rank that did Join()
48+
int32_t last_joined_rank = -1;
49+
4750
// If a rank is Joined, AllReduce uses temporary 0 tensors for it.
4851
bool joined = false;
4952

horovod/common/tensor_queue.cc

+12
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,18 @@ void TensorQueue::GetTensorEntriesFromResponse(
9090
{
9191
// Lock on the tensor table.
9292
std::lock_guard<std::mutex> guard(mutex_);
93+
if (response.response_type() == Response::JOIN) {
94+
assert(response.tensor_names().size() == 1);
95+
assert(response.tensor_names()[0] == JOIN_TENSOR_NAME);
96+
auto iter = tensor_table_.find(JOIN_TENSOR_NAME);
97+
assert(iter != tensor_table_.end());
98+
99+
entries.push_back(std::move(iter->second));
100+
101+
// The tensor table will be cleared of the join tensor later in
102+
// RemoveJoinTensor().
103+
return;
104+
}
93105
int64_t i = 0;
94106
for (auto& name : response.tensor_names()) {
95107
assert(response.response_type() == Response::ALLREDUCE ||

horovod/common/wire/message.fbs

+2-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ table Request {
6262
// Prescale and postscale factors
6363
prescale_factor:double;
6464
postscale_factor:double;
65-
6665
}
6766
table RequestList {
6867
requests:[Request];
@@ -110,6 +109,8 @@ table Response {
110109
// Prescale and postscale factors
111110
prescale_factor:double;
112111
postscale_factor:double;
112+
113+
last_joined_rank:int;
113114
}
114115
table ResponseList {
115116
responses:[Response];

0 commit comments

Comments
 (0)