
Commit 8a12dc9

Add Reducescatter op (NCCL, MPI, Gloo) (horovod#3299)
Signed-off-by: Max H. Gerlach <[email protected]>
Co-authored-by: Jesse Benson (AI) <[email protected]>
Co-authored-by: Jesse Benson <[email protected]>
1 parent e02bdca commit 8a12dc9


49 files changed, +2845 -111 lines changed

CHANGELOG.md

+6
@@ -8,10 +8,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added `hvd.reducescatter()` operation with implementations in NCCL, MPI, and Gloo. ([#3299](https://github.com/horovod/horovod/pull/3299))
+
 ### Changed
 
+- MXNet: Updated allreduce functions to newer `op` API. ([#3299](https://github.com/horovod/horovod/pull/3299))
+
 ### Deprecated
 
+- MXNet: Deprecated `average` argument of allreduce functions. ([#3299](https://github.com/horovod/horovod/pull/3299))
+
 ### Removed
 
 ### Fixed
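
The CHANGELOG entry above introduces the user-facing `hvd.reducescatter()` API. As a rough illustration only (not part of this commit), a call from the TensorFlow frontend might look like the sketch below; the `op=hvd.Sum` keyword follows the pattern visible in the `horovod/_keras/__init__.py` change later in this diff, and the exact signature should be checked against the released API.

    # Hypothetical usage sketch; run with e.g. `horovodrun -np 4 python reducescatter_demo.py`.
    import tensorflow as tf
    import horovod.tensorflow as hvd

    hvd.init()

    # Every rank contributes a tensor of identical shape.
    tensor = tf.ones([8, 4], dtype=tf.float32) * float(hvd.rank() + 1)

    # The inputs are summed element-wise across ranks and the result is
    # scattered, so each rank receives a slice along the first dimension.
    reduced_slice = hvd.reducescatter(tensor, op=hvd.Sum)
    print("rank", hvd.rank(), "received shape", reduced_slice.shape)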

CMakeLists.txt

+5-4
@@ -110,8 +110,9 @@ set_gpu_op(HOROVOD_GPU_ALLREDUCE "MPI;NCCL;DDL")
 set_gpu_op(HOROVOD_GPU_ALLGATHER "MPI;NCCL")
 set_gpu_op(HOROVOD_GPU_BROADCAST "MPI;NCCL")
 set_gpu_op(HOROVOD_GPU_ALLTOALL "MPI;NCCL")
+set_gpu_op(HOROVOD_GPU_REDUCESCATTER "MPI;NCCL")
 
-foreach(VAR in ITEMS HOROVOD_GPU_ALLREDUCE HOROVOD_GPU_ALLGATHER HOROVOD_GPU_BROADCAST HOROVOD_GPU_ALLTOALL)
+foreach(VAR in ITEMS HOROVOD_GPU_ALLREDUCE HOROVOD_GPU_ALLGATHER HOROVOD_GPU_BROADCAST HOROVOD_GPU_ALLTOALL HOROVOD_GPU_REDUCESCATTER)
   if(DEFINED ${VAR})
     string(SUBSTRING ${${VAR}} 0 1 ${VAR})
     convert_to_ascii_dec(ASCII_DEC ${${VAR}})
@@ -197,7 +198,7 @@ macro(ADD_CUDA)
   endif()
 endmacro()
 
-if(DEFINED HOROVOD_GPU_ALLREDUCE OR DEFINED HOROVOD_GPU_ALLGATHER OR DEFINED HOROVOD_GPU_BROADCAST OR DEFINED HOROVOD_GPU_ALLTOALL)
+if(DEFINED HOROVOD_GPU_ALLREDUCE OR DEFINED HOROVOD_GPU_ALLGATHER OR DEFINED HOROVOD_GPU_BROADCAST OR DEFINED HOROVOD_GPU_ALLTOALL OR DEFINED HOROVOD_GPU_REDUCESCATTER)
   if(NOT DEFINED HOROVOD_GPU OR HOROVOD_GPU STREQUAL "CUDA")
     add_cuda()
   elseif(HOROVOD_GPU STREQUAL "ROCM")
@@ -215,7 +216,7 @@ if(DEFINED HOROVOD_GPU_ALLREDUCE OR DEFINED HOROVOD_GPU_ALLGATHER OR DEFINED HOR
 endif()
 
 # NCCL
-if(HOROVOD_GPU_ALLREDUCE STREQUAL "N" OR HOROVOD_GPU_ALLGATHER STREQUAL "N" OR HOROVOD_GPU_BROADCAST STREQUAL "N" OR HOROVOD_GPU_ALLTOALL STREQUAL "N")
+if(HOROVOD_GPU_ALLREDUCE STREQUAL "N" OR HOROVOD_GPU_ALLGATHER STREQUAL "N" OR HOROVOD_GPU_BROADCAST STREQUAL "N" OR HOROVOD_GPU_ALLTOALL STREQUAL "N" OR HOROVOD_GPU_REDUCESCATTER STREQUAL "N")
   if(HAVE_ROCM)
     find_package(rccl REQUIRED)
     include_directories(SYSTEM ${RCCL_INCLUDE_DIRS})
@@ -256,7 +257,7 @@ if(DEFINED CCL_ROOT)
 endif()
 
 set(HOROVOD_ALLOW_MIXED_GPU_IMPL $ENV{HOROVOD_ALLOW_MIXED_GPU_IMPL})
-if(HOROVOD_GPU_ALLREDUCE STREQUAL "N" AND (HOROVOD_GPU_ALLGATHER STREQUAL "M" OR HOROVOD_GPU_BROADCAST STREQUAL "M" OR HOROVOD_GPU_ALLTOALL STREQUAL "M") AND
+if(HOROVOD_GPU_ALLREDUCE STREQUAL "N" AND (HOROVOD_GPU_ALLGATHER STREQUAL "M" OR HOROVOD_GPU_BROADCAST STREQUAL "M" OR HOROVOD_GPU_ALLTOALL STREQUAL "M" OR HOROVOD_GPU_REDUCESCATTER STREQUAL "M") AND
     NOT HOROVOD_ALLOW_MIXED_GPU_IMPL STREQUAL "1")
   message(FATAL_ERROR "You should not mix NCCL and MPI GPU due to a possible deadlock.\n"
                       "If you are sure you want to mix them, set the "

docs/concepts.rst

+4
@@ -31,6 +31,10 @@ a training script on 4 servers, each having 4 GPUs. If we launched one copy of t
 .. image:: http://mpitutorial.com/tutorials/mpi-broadcast-and-collective-communication/broadcast_pattern.png
     :alt: Broadcast Illustration
 
+* *Reducescatter* is an operation that aggregates data among multiple processes and scatters the data across them. *Reducescatter* is used to average dense tensors then split them across processes. Here's an illustration from the `Nvidia developer guide <https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#reducescatter>`__:
+
+.. image:: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/_images/reducescatter.png
+    :alt: Reducescatter Illustration
 
 * *Alltoall* is an operation to exchange data between all processes. *Alltoall* may be useful to implement neural networks with advanced architectures that span multiple devices.
 
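
To make the Reducescatter semantics described in the concepts page concrete, here is a small simulation in plain Python/NumPy (illustrative only, not part of the commit): four simulated processes each contribute a tensor of the same shape, the tensors are summed element-wise, and each process keeps one slice of the result along the first dimension.

    # Simulated Reducescatter across 4 processes (no Horovod needed).
    import numpy as np

    num_procs = 4
    inputs = [np.full((num_procs, 3), rank + 1, dtype=np.float32)
              for rank in range(num_procs)]

    # Reduce: element-wise sum over all contributions.
    reduced = np.sum(inputs, axis=0)          # every element equals 1+2+3+4 = 10

    # Scatter: split along the first dimension; process r keeps slice r.
    outputs = np.array_split(reduced, num_procs, axis=0)
    for rank, out in enumerate(outputs):
        print(f"process {rank} keeps {out.tolist()}")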

docs/gpus.rst

+1-1
@@ -82,7 +82,7 @@ use it instead:
     $ HOROVOD_GPU_ALLREDUCE=MPI pip install --no-cache-dir horovod
 
 
-Additionally, if your MPI vendor's implementation supports *allgather* and *broadcast* operations on GPU, you can
+Additionally, if your MPI vendor's implementation supports *allgather*, *broadcast*, and *reducescatter* operations on GPU, you can
 configure Horovod to use them as well:
 
 .. code-block:: bash

docs/install.rst

+3-1
@@ -245,7 +245,9 @@ Possible values are given in curly brackets: {}.
 * ``HOROVOD_GPU_ALLREDUCE`` - {NCCL, MPI}. Framework to use for GPU tensor allreduce.
 * ``HOROVOD_GPU_ALLGATHER`` - {NCCL, MPI}. Framework to use for GPU tensor allgather.
 * ``HOROVOD_GPU_BROADCAST`` - {NCCL, MPI}. Framework to use for GPU tensor broadcast.
-* ``HOROVOD_ALLOW_MIXED_GPU_IMPL`` - {1}. Allow Horovod to install with NCCL allreduce and MPI GPU allgather / broadcast. Not recommended due to a possible deadlock.
+* ``HOROVOD_GPU_ALLTOALL`` - {NCCL, MPI}. Framework to use for GPU tensor alltoall.
+* ``HOROVOD_GPU_REDUCESCATTER`` - {NCCL, MPI}. Framework to use for GPU tensor reducescatter.
+* ``HOROVOD_ALLOW_MIXED_GPU_IMPL`` - {1}. Allow Horovod to install with NCCL allreduce and MPI GPU allgather / broadcast / alltoall / reducescatter. Not recommended due to a possible deadlock.
 * ``HOROVOD_CPU_OPERATIONS`` - {MPI, GLOO, CCL}. Framework to use for CPU tensor allreduce, allgather, and broadcast.
 * ``HOROVOD_CMAKE`` - path to the CMake binary used to build Horovod.
 * ``HOROVOD_WITH_TENSORFLOW`` - {1}. Require Horovod to install with TensorFlow support enabled.

horovod/_keras/__init__.py

+4
@@ -188,6 +188,10 @@ def broadcast(backend, value, root_rank, name):
     return _eval(backend, hvd.broadcast(tf.constant(value, name=name), root_rank))
 
 
+def reducescatter(backend, value, name, op):
+    return _eval(backend, hvd.reducescatter(tf.constant(value, name=name), op=op))
+
+
 def load_model(keras, wrap_optimizer, optimizer_modules, filepath, custom_optimizers, custom_objects):
     horovod_objects = {
         subclass.__name__.lower(): wrap_optimizer(subclass)

horovod/common/common.cc

+2-2
@@ -101,7 +101,7 @@ int TensorShape::dims() const {
 
 int64_t TensorShape::dim_size(int idx) const {
   assert(idx >= 0);
-  assert(idx < shape_.size());
+  assert(idx < (int)shape_.size());
   return shape_[idx];
 }
 
@@ -165,7 +165,7 @@ void parse_and_set_affinity(const char* affinity, int local_size, int local_rank
     auto core_id_str = strsep(&tmp, ",");
     errno = 0;
     auto core_id = std::strtol(core_id_str, &endptr, 10);
-    if (errno == ERANGE && (core_id == LONG_MAX || core_id == LONG_MIN)
+    if ((errno == ERANGE && (core_id == LONG_MAX || core_id == LONG_MIN))
        || (errno != 0 && core_id == 0)){
      LOG(ERROR) << "Core ID value is invalid in " << HOROVOD_THREAD_AFFINITY
                 << "=" << affinity;

horovod/common/common.h

+2
@@ -83,6 +83,7 @@ namespace common {
 #define MEMCPY_IN_SHARED_BUFFER "MEMCPY_IN_SHARED_BUFFER"
 #define MPI_ALLREDUCE "MPI_ALLREDUCE"
 #define MPI_ADASUM_ALLREDUCE "MPI_ADASUM_ALLREDUCE"
+#define MPI_REDUCESCATTER "MPI_REDUCESCATTER"
 #define MEMCPY_OUT_HOST_BUFFER "MEMCPY_OUT_HOST_BUFFER"
 #define NCCL_ALLREDUCE "NCCL_ALLREDUCE"
 #define MEMCPY_OUT_FUSION_BUFFER "MEMCPY_OUT_FUSION_BUFFER"
@@ -102,6 +103,7 @@ namespace common {
 #define GLOO_ALLREDUCE "GLOO_ALLREDUCE"
 #define GLOO_ALLGATHER "GLOO_ALLGATHER"
 #define GLOO_BCAST "GLOO_BCAST"
+#define GLOO_REDUCESCATTER "GLOO_REDUCESCATTER"
 #define HOROVOD_ELASTIC "HOROVOD_ELASTIC"
 
 // Horovod knobs.

horovod/common/controller.cc

+25-4
@@ -536,11 +536,12 @@ Response Controller::ConstructResponse(const std::string& name, int joined_size)
     }
   }
 
-  // If we are doing an allreduce or broadcast, check that all tensor shapes are
-  // identical.
+  // If we are doing an allreduce, broadcast, or reducescatter check that all
+  // tensor shapes are identical.
   if (message_type == Request::ALLREDUCE ||
       message_type == Request::ADASUM ||
-      message_type == Request::BROADCAST) {
+      message_type == Request::BROADCAST ||
+      message_type == Request::REDUCESCATTER) {
     TensorShape tensor_shape;
     for (auto dim : requests[0].tensor_shape()) {
       tensor_shape.AddDim(dim);
@@ -673,6 +674,19 @@ Response Controller::ConstructResponse(const std::string& name, int joined_size)
     }
   }
 
+  if (message_type == Request::REDUCESCATTER) {
+    if (joined_size > 0) {
+      error = true;
+      error_message_stream << "Reducescatter is not supported with Join at this time.";
+    }
+
+    TensorShape tensor_shape;
+    for (auto dim : requests[0].tensor_shape()) {
+      tensor_shape.AddDim(dim);
+    }
+    tensor_sizes.push_back(tensor_shape.num_elements());
+  }
+
   if (message_type == Request::ALLREDUCE || message_type == Request::ADASUM) {
     TensorShape tensor_shape;
     for (auto dim : requests[0].tensor_shape()) {
@@ -756,6 +770,12 @@ Response Controller::ConstructResponse(const std::string& name, int joined_size)
     response.set_response_type(Response::BROADCAST);
   } else if (message_type == Request::ALLTOALL) {
     response.set_response_type(Response::ALLTOALL);
+  } else if (message_type == Request::REDUCESCATTER) {
+    response.set_response_type(Response::REDUCESCATTER);
+    for (auto dim : tensor_sizes) {
+      response.add_tensor_size(dim);
+    }
+    response.set_tensor_type(data_type);
   } else if (message_type == Request::ADASUM) {
     response.set_response_type(Response::ADASUM);
     for (auto dim : tensor_sizes) {
@@ -815,7 +835,8 @@ void Controller::FuseResponses(std::deque<Response>& responses,
     responses.pop_front();
     int64_t tensor_size = 0;
     if (response.response_type() == Response::ResponseType::ALLREDUCE ||
-        response.response_type() == Response::ResponseType::ADASUM) {
+        response.response_type() == Response::ResponseType::ADASUM ||
+        response.response_type() == Response::ResponseType::REDUCESCATTER) {
       // Attempt to add more responses to this fused response.
 
       tensor_size = response.tensor_sizes()[0] * GetTypeSize(response.tensor_type());
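
A practical consequence of the `ConstructResponse` change above is that every rank in the process set must submit a tensor of identical shape to a reducescatter; a rank-dependent shape would fail coordination. A minimal sketch of this constraint (TensorFlow frontend assumed, not part of this commit):

    import tensorflow as tf
    import horovod.tensorflow as hvd

    hvd.init()

    # OK: the shape [16, 8] is the same on every rank.
    same_shape = tf.random.uniform([16, 8])
    slice_of_sum = hvd.reducescatter(same_shape, op=hvd.Sum)

    # Not OK: a shape that differs per rank (e.g. [16 + hvd.rank(), 8]) would be
    # rejected by the shape check added to ConstructResponse above.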

horovod/common/message.cc

+6
@@ -102,6 +102,9 @@ const std::string& Request::RequestType_Name(RequestType value) {
     case RequestType::BROADCAST:
       static const std::string broadcast("BROADCAST");
       return broadcast;
+    case RequestType::REDUCESCATTER:
+      static const std::string reducescatter("REDUCESCATTER");
+      return reducescatter;
     case RequestType::JOIN:
       static const std::string join("JOIN");
       return join;
@@ -294,6 +297,9 @@ const std::string& Response::ResponseType_Name(ResponseType value) {
     case ResponseType::BROADCAST:
       static const std::string broadcast("BROADCAST");
       return broadcast;
+    case ResponseType::REDUCESCATTER:
+      static const std::string reducescatter("REDUCESCATTER");
+      return reducescatter;
     case ResponseType::JOIN:
       static const std::string join("JOIN");
       return join;

horovod/common/message.h

+17-3
@@ -50,10 +50,16 @@ std::size_t DataType_Size(DataType value);
 class Request {
 public:
   enum RequestType {
-    ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ADASUM = 4, ALLTOALL = 5, BARRIER = 6
+    ALLREDUCE = 0,
+    ALLGATHER = 1,
+    BROADCAST = 2,
+    JOIN = 3,
+    ADASUM = 4,
+    ALLTOALL = 5,
+    BARRIER = 6,
+    REDUCESCATTER = 7
   };
 
-
   static const std::string& RequestType_Name(RequestType value);
 
   // The request rank is necessary to create a consistent ordering of results,
@@ -153,7 +159,15 @@ class RequestList {
 class Response {
 public:
   enum ResponseType {
-    ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ADASUM = 4, ALLTOALL= 5, BARRIER=6, ERROR = 7
+    ALLREDUCE = 0,
+    ALLGATHER = 1,
+    BROADCAST = 2,
+    JOIN = 3,
+    ADASUM = 4,
+    ALLTOALL = 5,
+    BARRIER = 6,
+    REDUCESCATTER = 7,
+    ERROR = 8
   };
 
   static const std::string& ResponseType_Name(ResponseType value);

horovod/common/nvtx_op_range.h

+1
@@ -15,6 +15,7 @@ enum class RegisteredNvtxOp {
   HorovodAllgather,
   HorovodBroadcast,
   HorovodAlltoall,
+  HorovodReducescatter,
   // Insert new enum values above this line
   END,
 };

horovod/common/operations.cc

+82-1
@@ -148,6 +148,7 @@ OperationManager* CreateOperationManager(HorovodGlobalState& state) {
   std::vector<std::shared_ptr<AllreduceOp>> allreduce_ops;
   std::vector<std::shared_ptr<AllgatherOp>> allgather_ops;
   std::vector<std::shared_ptr<BroadcastOp>> broadcast_ops;
+  std::vector<std::shared_ptr<ReducescatterOp>> reducescatter_ops;
   std::vector<std::shared_ptr<AllreduceOp>> adasum_ops;
   std::vector<std::shared_ptr<AlltoallOp>> alltoall_ops;
 
@@ -180,6 +181,11 @@ OperationManager* CreateOperationManager(HorovodGlobalState& state) {
     alltoall_ops.push_back(
         std::shared_ptr<AlltoallOp>(new MPI_GPUAlltoall(&gpu_context, &state)));
 #endif
+
+#if HOROVOD_GPU_REDUCESCATTER == 'M'
+    reducescatter_ops.push_back(std::shared_ptr<ReducescatterOp>(
+        new MPI_GPUReduceScatter(&gpu_context, &state)));
+#endif
   }
 #endif
 
@@ -198,6 +204,11 @@ OperationManager* CreateOperationManager(HorovodGlobalState& state) {
       new NCCLAllgather(&nccl_context, &gpu_context, &state)));
 #endif
 
+#if HAVE_NCCL && HOROVOD_GPU_REDUCESCATTER == 'N'
+  reducescatter_ops.push_back(std::shared_ptr<ReducescatterOp>(
+      new NCCLReducescatter(&nccl_context, &gpu_context, &state)));
+#endif
+
 #if HAVE_NCCL && HOROVOD_GPU_ALLTOALL == 'N'
   alltoall_ops.push_back(std::shared_ptr<AlltoallOp>(
       new NCCLAlltoall(&nccl_context, &gpu_context, &state)));
@@ -213,6 +224,8 @@ OperationManager* CreateOperationManager(HorovodGlobalState& state) {
         std::shared_ptr<BroadcastOp>(new GlooBroadcast(&state)));
     alltoall_ops.push_back(
         std::shared_ptr<AlltoallOp>(new GlooAlltoall(&state)));
+    reducescatter_ops.push_back(
+        std::shared_ptr<ReducescatterOp>(new GlooReducescatter(&state)));
   }
 #endif
 
@@ -240,6 +253,8 @@ OperationManager* CreateOperationManager(HorovodGlobalState& state) {
         std::shared_ptr<BroadcastOp>(new MPIBroadcast(&state)));
     alltoall_ops.push_back(
         std::shared_ptr<AlltoallOp>(new MPIAlltoall(&state)));
+    reducescatter_ops.push_back(
+        std::shared_ptr<ReducescatterOp>(new MPIReducescatter(&state)));
   }
 #endif
 
@@ -249,7 +264,8 @@ OperationManager* CreateOperationManager(HorovodGlobalState& state) {
 
   return new OperationManager(&state.parameter_manager, allreduce_ops,
                               allgather_ops, broadcast_ops, alltoall_ops,
-                              join_op, adasum_ops, barrier_op, error_op);
+                              reducescatter_ops, join_op, adasum_ops,
+                              barrier_op, error_op);
 }
 
 // Process a Response by doing a reduction, a gather, a broadcast, or
@@ -1637,6 +1653,71 @@ Status EnqueueTensorBroadcast(std::shared_ptr<OpContext> context,
   return status;
 }
 
+// Contexts and controller must be initialized and the background thread
+// must be running before this function is called.
+Status EnqueueTensorReducescatter(std::shared_ptr<OpContext> context,
+                                  std::shared_ptr<Tensor> tensor,
+                                  ReadyEventList ready_event_list,
+                                  const std::string& name, const int device,
+                                  StatusCallback callback, ReduceOp reduce_op,
+                                  int32_t process_set_id) {
+  if (horovod_global.cpu_operation == LibType::CCL && device == CPU_DEVICE_ID) {
+    return Status::InvalidArgument(
+        "Reducescatter is not supported yet with oneCCL operations.");
+  }
+  if (!horovod_global.process_set_table.Contains(process_set_id)) {
+    return Status::InvalidArgument(
+        "Reducescatter: Process set provided does not "
+        "exist, or has not been registered.");
+  }
+  if (reduce_op != ReduceOp::SUM) {
+    // Note: AVERAGE is supported by enqueuing SUM and performing divide at the
+    // framework level.
+    LOG(ERROR, horovod_global.global_controller->GetRank())
+        << "Reducescatter currently only supports SUM.";
+    return Status::Aborted("Reducescatter currently only supports SUM.");
+  }
+  if (horovod_global.shut_down) {
+    return SHUT_DOWN_ERROR;
+  }
+  auto& process_set = horovod_global.process_set_table.Get(process_set_id);
+
+  if (!process_set.IsCurrentProcessIncluded()) {
+    return Status::InvalidArgument(
+        "Reducescatter: Rank " +
+        std::to_string(horovod_global.global_controller->GetRank()) +
+        " is not a member of the provided process set.");
+  }
+
+  Request message;
+  message.set_request_rank(process_set.controller->GetRank());
+  message.set_tensor_name(name);
+  message.set_tensor_type(tensor->dtype());
+  message.set_device(device);
+  message.set_request_type(Request::REDUCESCATTER);
+  for (int i = 0; i < tensor->shape().dims(); ++i) {
+    message.add_tensor_shape((int64_t)tensor->shape().dim_size(i));
+  }
+
+  TensorTableEntry e;
+  e.tensor_name = name;
+  e.context = context;
+  e.tensor = tensor;
+  e.process_set_id = process_set_id;
+  e.ready_event_list = ready_event_list;
+  e.device = device;
+  e.callback = callback;
+  e.nvtx_op_range.Start(RegisteredNvtxOp::HorovodReducescatter,
+                        e.tensor->size());
+
+  Status status = process_set.tensor_queue.AddToTensorQueue(e, message);
+  if (status.ok()) {
+    LOG(TRACE, horovod_global.global_controller->GetRank())
+        << "Enqueued " << name;
+  }
+  return status;
+}
+
 // Contexts and controller must be initialized and the background thread
 // must be running before this function is called.
 Status EnqueueTensorAlltoall(std::shared_ptr<OpContext> context,
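
The `EnqueueTensorReducescatter` function above only accepts `ReduceOp::SUM`; per its inline comment, an average is meant to be realized at the framework level by dividing the summed result. A hedged sketch of what that framework-level division might look like (TensorFlow frontend assumed, not part of this commit):

    import tensorflow as tf
    import horovod.tensorflow as hvd

    hvd.init()

    tensor = tf.random.uniform([32, 16])

    # Enqueue a SUM reducescatter, then divide by the number of participating
    # ranks to obtain the average of each rank's slice.
    summed_slice = hvd.reducescatter(tensor, op=hvd.Sum)
    averaged_slice = summed_slice / hvd.size()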
