
Commit 3b65b52

tianfengfrank authored and meta-codesync[bot] committed
add reduce_scatter_v support
Summary: tp_overlapping needs to work with the uneven_split support introduced by D84788079. To support that, we need reduce_scatter_v in torchcomm:
- enable reduce_scatter_v to accept input_tensor lists with varying tensor sizes
- add both cpp/py integration UTs

Reviewed By: d4l3k

Differential Revision: D85297838

fbshipit-source-id: 210969573cbec89341825939016a3826ac850331
1 parent 73a225d commit 3b65b52

20 files changed: +547 −0 lines
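
For context, the collective added here behaves like reduce_scatter, except each rank's output size may differ. Below is a minimal single-process sketch of those semantics in plain PyTorch (illustrative only, not part of the diff; it assumes a sum reduction and simulates all ranks locally):

import torch

# Illustrative reference for reduce_scatter_v semantics (sum reduction),
# simulated on a single process. inputs_by_rank[r][i] is the tensor that
# rank r would pass as input_list[i]; sizes may differ across indices i,
# but every rank must use the same size for a given index i.
world_size = 4
split_sizes = [3, 1, 2, 4]  # uneven split: rank i receives split_sizes[i] elements

inputs_by_rank = [
    [torch.full((split_sizes[i],), float(r)) for i in range(world_size)]
    for r in range(world_size)
]

# Rank i's output is the elementwise sum of input_list[i] across all ranks,
# so its size equals split_sizes[i].
outputs = [
    torch.stack([inputs_by_rank[r][i] for r in range(world_size)]).sum(dim=0)
    for i in range(world_size)
]

for i, out in enumerate(outputs):
    assert out.numel() == split_sizes[i]
    print(f"rank {i} output:", out)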

comms/torchcomms/TorchComm.cpp

Lines changed: 9 additions & 0 deletions
@@ -114,6 +114,15 @@ std::shared_ptr<TorchWork> TorchComm::reduce_scatter(
   return impl_->reduce_scatter(output, input_list, op, async_op, options);
 }
 
+std::shared_ptr<TorchWork> TorchComm::reduce_scatter_v(
+    at::Tensor& output,
+    const std::vector<at::Tensor>& input_list,
+    ReduceOp op,
+    bool async_op,
+    const ReduceScatterOptions& options) {
+  return impl_->reduce_scatter_v(output, input_list, op, async_op, options);
+}
+
 std::shared_ptr<TorchWork> TorchComm::reduce_scatter_single(
     at::Tensor& output,
     const at::Tensor& input,

comms/torchcomms/TorchComm.hpp

Lines changed: 6 additions & 0 deletions
@@ -81,6 +81,12 @@ class TorchComm {
       ReduceOp op,
       bool async_op,
       const ReduceScatterOptions& options = {});
+  std::shared_ptr<TorchWork> reduce_scatter_v(
+      at::Tensor& output,
+      const std::vector<at::Tensor>& input_list,
+      ReduceOp op,
+      bool async_op,
+      const ReduceScatterOptions& options = {});
   std::shared_ptr<TorchWork> reduce_scatter_single(
       at::Tensor& output,
       const at::Tensor& input,

comms/torchcomms/TorchCommBackend.hpp

Lines changed: 6 additions & 0 deletions
@@ -94,6 +94,12 @@ class TorchCommBackend {
       ReduceOp op,
       bool async_op,
       const ReduceScatterOptions& options = {}) = 0;
+  virtual std::shared_ptr<TorchWork> reduce_scatter_v(
+      at::Tensor& output,
+      const std::vector<at::Tensor>& input_list,
+      ReduceOp op,
+      bool async_op,
+      const ReduceScatterOptions& options = {}) = 0;
   virtual std::shared_ptr<TorchWork> reduce_scatter_single(
       at::Tensor& output,
       const at::Tensor& input,

comms/torchcomms/TorchCommPy.cpp

Lines changed: 37 additions & 0 deletions
@@ -745,6 +745,43 @@ Reduce, then scatter a list of tensors to all ranks.
     op: Reduction operation.
     async_op: Whether to perform the operation asynchronously.
     hints: Dictionary of string hints for backend-specific options.
+    timeout: Timeout for the operation.
+)",
+          py::arg("output"),
+          py::arg("input_list"),
+          py::arg("op"),
+          py::arg("async_op"),
+          py::arg("hints") = std::nullopt,
+          py::arg("timeout") = std::nullopt,
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "reduce_scatter_v",
+          [](TorchComm& self,
+             at::Tensor& output,
+             const std::vector<at::Tensor>& input_list,
+             ReduceOp op,
+             bool async_op,
+             std::optional<std::unordered_map<std::string, std::string>> hints,
+             std::optional<std::chrono::milliseconds> timeout) {
+            ReduceScatterOptions opts;
+            if (hints) {
+              opts.hints = *hints;
+            }
+            if (timeout) {
+              opts.timeout = *timeout;
+            }
+            return self.reduce_scatter_v(
+                output, input_list, op, async_op, opts);
+          },
+          R"(
+Reduce, then scatter a list of tensors to all ranks, supporting variable tensor sizes per rank.
+
+Args:
+    output: Output tensor on each rank; size may differ per rank.
+    input_list: List of tensors to reduce and scatter; the list is the same on all ranks, but tensor sizes may differ between indices.
+    op: Reduction operation.
+    async_op: Whether to perform the operation asynchronously.
+    hints: Dictionary of string hints for backend-specific options.
     timeout: Timeout for the operation.
 )",
           py::arg("output"),
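
A minimal usage sketch of the new Python binding, based on the signature added to _comms.pyi below (the communicator `comm`, the `op` value, and `work.wait()` are assumptions about surrounding torchcomms code that is not part of this diff):

from datetime import timedelta
from typing import List

import torch


def uneven_reduce_scatter(comm, op, rank: int, split_sizes: List[int]) -> torch.Tensor:
    # Hypothetical helper: `comm` is assumed to be an initialized torchcomms
    # communicator exposing the reduce_scatter_v binding added in this commit,
    # and `op` a ReduceOp value (e.g. sum).
    # Every rank passes the same list layout; tensor sizes differ per index.
    input_list = [torch.ones(n, device="cuda") for n in split_sizes]
    # This rank's output must have the same number of elements as input_list[rank].
    output = torch.empty(split_sizes[rank], device="cuda")
    work = comm.reduce_scatter_v(
        output,
        input_list,
        op,
        async_op=True,
        hints={},                       # backend-specific string hints
        timeout=timedelta(seconds=30),
    )
    work.wait()  # assumed TorchWork completion call; not defined in this diff
    return output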

comms/torchcomms/_comms.pyi

Lines changed: 9 additions & 0 deletions
@@ -261,6 +261,15 @@ class TorchComm:
         hints: Dict[str, str] | None = None,
         timeout: timedelta | None = None,
     ) -> TorchWork: ...
+    def reduce_scatter_v(
+        self,
+        output: Any,
+        input_list: List[Any],
+        op: ReduceOp,
+        async_op: bool,
+        hints: Dict[str, str] | None = None,
+        timeout: timedelta | None = None,
+    ) -> TorchWork: ...
     def reduce_scatter_single(
         self,
         output: Any,

comms/torchcomms/gloo/TorchCommGloo.cpp

Lines changed: 10 additions & 0 deletions
@@ -815,6 +815,16 @@ std::shared_ptr<TorchWork> TorchCommGloo::reduce_scatter(
   return reduce_scatter_single(output, input, op, async_op, singleOptions);
 }
 
+std::shared_ptr<TorchWork> TorchCommGloo::reduce_scatter_v(
+    at::Tensor& output,
+    const std::vector<at::Tensor>& input_list,
+    ReduceOp op,
+    bool async_op,
+    const ReduceScatterOptions& options) {
+  throw std::runtime_error(
+      "reduce_scatter_v is not supported in GLOO backend yet");
+}
+
 std::shared_ptr<TorchWork> TorchCommGloo::reduce_scatter_single(
     at::Tensor& output,
     const at::Tensor& input,

comms/torchcomms/gloo/TorchCommGloo.hpp

Lines changed: 6 additions & 0 deletions
@@ -106,6 +106,12 @@ class TorchCommGloo : public TorchCommBackend,
       ReduceOp op,
       bool async_op,
       const ReduceScatterOptions& options = {}) override;
+  std::shared_ptr<TorchWork> reduce_scatter_v(
+      at::Tensor& output,
+      const std::vector<at::Tensor>& input_list,
+      ReduceOp op,
+      bool async_op,
+      const ReduceScatterOptions& options = {}) override;
   std::shared_ptr<TorchWork> reduce_scatter_single(
       at::Tensor& output,
       const at::Tensor& input,

comms/torchcomms/nccl/TorchCommNCCL.cpp

Lines changed: 9 additions & 0 deletions
@@ -818,6 +818,15 @@ std::shared_ptr<TorchWork> TorchCommNCCL::reduce_scatter(
   return work;
 }
 
+std::shared_ptr<TorchWork> TorchCommNCCL::reduce_scatter_v(
+    at::Tensor& output,
+    const std::vector<at::Tensor>& input_list,
+    ReduceOp op,
+    bool async_op,
+    const ReduceScatterOptions& options) {
+  throw std::runtime_error("reduce_scatter_v is not supported in NCCL backend");
+}
+
 std::shared_ptr<TorchWork> TorchCommNCCL::reduce_scatter_single(
     at::Tensor& output,
     const at::Tensor& input,

comms/torchcomms/nccl/TorchCommNCCL.hpp

Lines changed: 6 additions & 0 deletions
@@ -123,6 +123,12 @@ class TorchCommNCCL : public TorchCommBackend,
       ReduceOp op,
       bool async_op,
       const ReduceScatterOptions& options = {}) override;
+  std::shared_ptr<TorchWork> reduce_scatter_v(
+      at::Tensor& output,
+      const std::vector<at::Tensor>& input_list,
+      ReduceOp op,
+      bool async_op,
+      const ReduceScatterOptions& options = {}) override;
   std::shared_ptr<TorchWork> reduce_scatter_single(
       at::Tensor& output,
       const at::Tensor& input,

comms/torchcomms/ncclx/TorchCommNCCLX.cpp

Lines changed: 80 additions & 0 deletions
@@ -888,6 +888,86 @@ std::shared_ptr<TorchWork> TorchCommNCCLX::reduce_scatter(
   return work;
 }
 
+std::shared_ptr<TorchWork> TorchCommNCCLX::reduce_scatter_v(
+    at::Tensor& output,
+    const std::vector<at::Tensor>& input_list,
+    ReduceOp op,
+    bool async_op,
+    const ReduceScatterOptions& options) {
+  checkInitialized();
+  checkAndAbortIfTimedOutOrError();
+  ensureTensorContiguous(output);
+
+  if (input_list.size() != static_cast<size_t>(comm_size_)) {
+    throw std::runtime_error(
+        "input_list size must equal comm_size for reduce_scatter_v");
+  }
+
+  // Check that all input tensors are contiguous
+  for (const auto& t : input_list) {
+    ensureTensorContiguous(t);
+  }
+
+  TorchCommTracingGuard tracingGuard(
+      name_, comm_size_, "reduce_scatter", rank_, input_list, {output});
+
+  cudaStream_t stream = getOperationStream(async_op);
+  auto work = createWork(
+      stream,
+      getOperationTimeout(options.timeout, options_.timeout),
+      input_list);
+
+  work->recordStart();
+
+  // Use multiple reduce operations for reduce_scatter
+  nccl_api_->groupStart();
+
+  for (int i = 0; i < comm_size_; ++i) {
+    const auto dataType = getNcclDataType(input_list[i]);
+    if (i == rank_) {
+      // This rank receives the reduced result.
+      // Assign input/output tensors to support vector reduce_scatter
+      // (reduce_scatter_v), where inputs are reduced and scattered unevenly
+      // among participating ranks.
+      auto& input_tensor = input_list[i];
+      auto& output_tensor = output;
+      if (input_tensor.numel() != output_tensor.numel()) {
+        throw std::runtime_error(
+            "Output tensor size must equal input tensor size for reduce_scatter_v");
+      }
+      nccl_api_->reduce(
+          input_tensor.data_ptr(),
+          output_tensor.data_ptr(),
+          output_tensor.numel(),
+          dataType,
+          getNcclReduceOp(op, nccl_comm_, dataType),
+          i,
+          nccl_comm_,
+          stream);
+    } else {
+      // Other ranks contribute to the reduction
+      nccl_api_->reduce(
+          input_list[i].data_ptr(),
+          nullptr, // Non-root ranks don't receive
+          input_list[i].numel(),
+          dataType,
+          getNcclReduceOp(op, nccl_comm_, dataType),
+          i,
+          nccl_comm_,
+          stream);
+    }
+  }
+
+  nccl_api_->groupEnd();
+
+  work->recordEnd();
+
+  // Enqueue the work after events have been recorded
+  enqueueWork(work, stream);
+
+  return work;
+}
+
 std::shared_ptr<TorchWork> TorchCommNCCLX::reduce_scatter_single(
     at::Tensor& output,
     const at::Tensor& input,
