@@ -888,6 +888,86 @@ std::shared_ptr<TorchWork> TorchCommNCCLX::reduce_scatter(
   return work;
 }

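+// Vector variant of reduce_scatter: rank i receives the reduction of every
+// rank's i-th input tensor. Unlike the plain collective, per-rank chunk
+// sizes may differ, so the operation is emulated with comm_size_ rooted
+// reduce calls batched in a single NCCL group (see below).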
+std::shared_ptr<TorchWork> TorchCommNCCLX::reduce_scatter_v(
+    at::Tensor& output,
+    const std::vector<at::Tensor>& input_list,
+    ReduceOp op,
+    bool async_op,
+    const ReduceScatterOptions& options) {
+  checkInitialized();
+  checkAndAbortIfTimedOutOrError();
+  ensureTensorContiguous(output);
+
+  if (input_list.size() != static_cast<size_t>(comm_size_)) {
+    throw std::runtime_error(
+        "input_list size must equal comm_size for reduce_scatter_v");
+  }
+
+  // Ensure every input tensor is contiguous; per-rank sizes may differ in
+  // the vector variant, so only this rank's chunk is validated against the
+  // output below.
+  for (const auto& t : input_list) {
+    ensureTensorContiguous(t);
+  }
+
+  TorchCommTracingGuard tracingGuard(
+      name_, comm_size_, "reduce_scatter_v", rank_, input_list, {output});
+
+  cudaStream_t stream = getOperationStream(async_op);
+  auto work = createWork(
+      stream,
+      getOperationTimeout(options.timeout, options_.timeout),
+      input_list);
+
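+  // recordStart/recordEnd capture events on the stream so the returned work
+  // object can track completion; the later enqueue relies on both events
+  // having been recorded (see comment below).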
+  work->recordStart();
+
+  // Emulate reduce_scatter_v with one rooted reduce per rank.
+  nccl_api_->groupStart();
+
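+  // Issuing all comm_size_ reduces inside a groupStart/groupEnd pair lets
+  // NCCL aggregate them into one operation instead of running comm_size_
+  // separate blocking collectives.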
+  for (int i = 0; i < comm_size_; ++i) {
+    const auto dataType = getNcclDataType(input_list[i]);
+    if (i == rank_) {
+      // This rank is the root of reduce i and receives the reduced result.
+      // Chunks may be sized unevenly across ranks (the "v" in
+      // reduce_scatter_v), so the output is checked against this rank's own
+      // input chunk.
+      const auto& input_tensor = input_list[i];
+      auto& output_tensor = output;
+      if (input_tensor.numel() != output_tensor.numel()) {
+        throw std::runtime_error(
+            "Output tensor size must equal this rank's input tensor size "
+            "for reduce_scatter_v");
+      }
+      nccl_api_->reduce(
+          input_tensor.data_ptr(),
+          output_tensor.data_ptr(),
+          output_tensor.numel(),
+          dataType,
+          getNcclReduceOp(op, nccl_comm_, dataType),
+          i,
+          nccl_comm_,
+          stream);
+    } else {
+      // Other ranks only contribute to the reduction rooted at rank i.
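+      // recvbuff may be null on non-root ranks: ncclReduce only writes the
+      // result at the root (assuming nccl_api_->reduce forwards to
+      // ncclReduce).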
+      nccl_api_->reduce(
+          input_list[i].data_ptr(),
+          nullptr, // Non-root ranks don't receive
+          input_list[i].numel(),
+          dataType,
+          getNcclReduceOp(op, nccl_comm_, dataType),
+          i,
+          nccl_comm_,
+          stream);
+    }
+  }
+
+  nccl_api_->groupEnd();
+
+  work->recordEnd();
+
+  // Enqueue the work after events have been recorded
+  enqueueWork(work, stream);
+
+  return work;
+}
+
 std::shared_ptr<TorchWork> TorchCommNCCLX::reduce_scatter_single(
     at::Tensor& output,
     const at::Tensor& input,