diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm.h b/transformer_engine/common/include/transformer_engine/comm_gemm.h index 14cf56a002..7c36871f90 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm.h @@ -44,6 +44,14 @@ enum NVTECommGemmAlgoType { kNVTECommGemmAlgoAtomicMulticast = 4 }; +bool nvte_built_with_cublasmp() { +#ifdef NVTE_WITH_CUBLASMP + return true; +#else + return false; +#endif +} + /*! \brief Create a comm-gemm context. * * \param[in] comm NCCL communicator. diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index cffc411a0d..df8a8d2131 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -26,6 +26,8 @@ namespace transformer_engine { */ bool ubuf_built_with_mpi(); +enum class CommOverlapMethod { BULK = 0, PIPELINE = 1, RING_EXCHANGE = 2 }; + enum class CommOverlapType { RS = 0, AG = 1 }; enum class CommOverlapAlgo { diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index bce124e705..58bc2d5a99 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -8,6 +8,7 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ #include +#include #include #include #include @@ -84,6 +85,11 @@ m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \ .value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \ .value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \ + pybind11::enum_(m, "CommOverlapMethod", \ + pybind11::module_local()) \ + .value("BULK", transformer_engine::CommOverlapMethod::BULK) \ + .value("PIPELINE", transformer_engine::CommOverlapMethod::PIPELINE) \ + .value("RING_EXCHANGE", transformer_engine::CommOverlapMethod::RING_EXCHANGE); \ pybind11::enum_(m, "CommOverlapType", \ pybind11::module_local()) \ .value("RS", transformer_engine::CommOverlapType::RS) \ @@ -135,6 +141,8 @@ }, \ py::call_guard(), py::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + py::call_guard()); \ + m.def("nvte_built_with_cublasmp", &nvte_built_with_cublasmp, \ py::call_guard()); #endif diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 978bee52dc..d76def28c9 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 79fb798422..6a7a7b8e59 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -128,7 +128,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans py::handle quantizer, std::optional out_dtype, MaybeTensor bias, DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr, + bool use_split_accumulator, CommOverlapManager *comm_overlap = nullptr, std::optional comm_type = std::nullopt, MaybeTensor extra_output = std::nullopt, bool bulk_overlap 
= false, float alpha = 1.0f, std::optional beta = std::nullopt); @@ -504,8 +504,9 @@ class CommOverlapHelper : torch::CustomClassHolder { CommOverlapHelper(); - CommOverlapHelper(c10d::ProcessGroup *world_group, - std::optional intra_node_group); + CommOverlapHelper(c10d::ProcessGroup *tp_group); + + CommOverlapHelper(c10d::ProcessGroup *world_group, c10d::ProcessGroup *intra_node_group); ~CommOverlapHelper(); @@ -513,39 +514,46 @@ class CommOverlapHelper : torch::CustomClassHolder { ExtComm comm); void ub_barrier(ExtComm comm); -}; - -class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { - public: - CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, int num_splits = 3, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, - bool set_sm_margin = true, bool atomic_gemm = false, - bool rs_overlap_first_gemm = false); - - ~CommOverlap() {} - - void copy_into_buffer(const at::Tensor &input, bool local_chunk = false); - at::Tensor get_buffer(bool local_chunk = false, - std::optional> shape = std::nullopt); - - std::pair get_communication_stream(); + int64_t get_comm_ptr(std::string group = "world") { return pgs[group]->getCommPtr(); } +}; -}; // CommOverlap +class CommOverlapManager : torch::CustomClassHolder { + private: +#ifndef NVTE_WITH_CUBLASMP + transformer_engine::CommOverlapCore *_ctx; +#else + CommGemmCtx *_ctx; +#endif + transformer_engine::CommOverlapMethod _method; + int _num_comm_sm; + bool _use_atomic_gemm; -class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { public: - CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, - transformer_engine::CommOverlapType comm_type, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3, - bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, - bool aggregate = false); - - ~CommOverlapP2P() {} + CommOverlapManager(transformer_engine::CommOverlapMethod method, + transformer_engine::CommOverlapType comm_type, + const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, int num_splits = 3, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, + bool set_sm_margin = false, bool atomic_gemm = false, + bool aggregate_ag = false, bool rs_overlap_first_gemm = false); + + ~CommOverlapManager() { +#ifdef NVTE_WITH_CUBLASMP + nvte_comm_gemm_ctx_destroy(_ctx); +#else + delete _ctx; +#endif; + } + + bool is_fp8_ubuf() { +#ifndef NVTE_WITH_CUBLASMP + return _ctx->is_fp8_ubuf(); +#else + return false; +#endif + } void copy_into_buffer(const at::Tensor &input, bool local_chunk = false); @@ -554,6 +562,11 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm std::pair get_communication_stream(); -}; // CommOverlapP2P + void execute(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, + transformer_engine::CommOverlapType comm_type, TensorWrapper &aux_out, + cudaStream_t stream); +}; // CommOverlapManager #endif // 
TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
index 38947c5a9d..179e52e626 100644
--- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
@@ -21,57 +21,65 @@ namespace te = transformer_engine;
 
 CommOverlapHelper::CommOverlapHelper() {
 #ifndef NVTE_UB_WITH_MPI
-  NVTE_ERROR("Internal TE error: Dummy CommOverlapHelper init without NVTE_UB_WITH_MPI=1!");
+  NVTE_ERROR("Internal TE error: CommOverlapHelper() requires NVTE_UB_WITH_MPI=1!");
 #endif
 }  // empty constructor for NVTE_UB_WITH_MPI=1
 
-CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group,
-                                     std::optional intra_domain_group) {
-#ifndef NVTE_UB_WITH_MPI
+CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *tp_group) {
+#ifndef NVTE_WITH_CUBLASMP
+  NVTE_ERROR("Internal TE error: CommOverlapHelper(tp_group) requires NVTE_WITH_CUBLASMP=1!");
+#endif
+  c10d::ProcessGroup::BackendType backend = tp_group->getBackendType();
+  backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL);
+  NVTE_CHECK(backend_is_nccl, "Comm+GEMM overlap with cuBlasMp requires bootstrapping with NCCL.");
+
+  myrank = tp_group->getRank();
+  numranks = tp_group->getSize();
+  pgs.insert({"tp", tp_group});
+  initialized = true;
+}  // TP group constructor for NVTE_WITH_CUBLASMP=1
+
+CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group,
+                                     c10d::ProcessGroup *intra_domain_group) {
+#if defined(NVTE_UB_WITH_MPI)
+  NVTE_ERROR("Internal TE error: CommOverlapHelper(world, intra_domain) is not supported with ",
+             "NVTE_UB_WITH_MPI=1!");
+#elif defined(NVTE_WITH_CUBLASMP)
+  NVTE_ERROR("Internal TE error: CommOverlapHelper(world, intra_domain) is not supported with ",
+             "NVTE_WITH_CUBLASMP=1!");
+#endif
   pgs.insert({"world", world_group});
   myrank = pgs["world"]->getRank();
   numranks = pgs["world"]->getSize();
   c10d::ProcessGroup::BackendType backend = pgs["world"]->getBackendType();
   backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL);
 
-  if (intra_domain_group.has_value()) {
-    // Get local rank on node and number of local ranks
-    NVTE_CHECK(intra_domain_group.value()->getBackendType() == backend,
-               "Internal TE error: Intra-node group must be on the same backend (%s) as the world ",
-               "group!", pgs["world"]->getBackendName());
-    pgs.insert({"intra", intra_domain_group.value()});
-    mylocal = pgs["intra"]->getRank();
-    numlocal = pgs["intra"]->getSize();
-
-    if (numlocal == numranks) {
-      // Intra-node group is same as the world group so there can only be 1 node
-      NVTE_CHECK(
-          mylocal == myrank,
-          "Internal TE error: Local rank must be equal to global rank when intra-node group size ",
-          "is equal to the world group size!");
-      mynode = 0;
-      numnodes = 1;
-    } else {
-      // Get node ID and number of nodes
-      mynode = myrank / numlocal;
-      numnodes = numranks / numlocal;
-    }
-  } else {
-    // Intra-node group is not set so we assume there is only 1 node
-    mylocal = myrank;
-    numlocal = numranks;
-    pgs.insert({"intra", world_group});
-
+  // Get local rank on node and number of local ranks
+  NVTE_CHECK(intra_domain_group->getBackendType() == backend,
+             "Internal TE error: Intra-node group must be on the same backend (%s) as the world ",
+             "group!", pgs["world"]->getBackendName());
+  pgs.insert({"intra", intra_domain_group});
+  mylocal = pgs["intra"]->getRank();
+  numlocal = pgs["intra"]->getSize();
+
+  if (numlocal == numranks) {
+    // Intra-node group is same as the world group so there can only be 1 node
+    NVTE_CHECK(
+        mylocal == myrank,
+        "Internal TE error: Local rank must be equal to global rank when intra-node group size ",
+        "is equal to the world group size!");
     mynode = 0;
     numnodes = 1;
+  } else {
+    // Get node ID and number of nodes
+    mynode = myrank / numlocal;
+    numnodes = numranks / numlocal;
   }
 
   initialized = true;
-#else
-  NVTE_ERROR("Internal TE error: CommOverlapHelper cannot be initialized with valid PyTorch ",
-             "distributed process groups when TE is compiled with NVTE_UB_WITH_MPI=1!");
-#endif
-}
+}  // world + intra-node constructor for Userbuffers w/ PyTorch Distributed bootstrapping
 
 CommOverlapHelper::~CommOverlapHelper() {
 #ifndef NVTE_UB_WITH_MPI
@@ -128,23 +136,43 @@ void CommOverlapHelper::ub_barrier(ExtComm group) {
  * CommOverlap
 **************************************************************************************************/
 
-CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype,
-                         CommOverlapHelper *helper, int tp_size, int num_splits,
-                         int num_max_streams, int comm_cga_size, int gemm_priority,
-                         int comm_priority, int num_comm_sm, bool set_sm_margin, bool atomic_gemm,
-                         bool rs_overlap_first_gemm)
-    : te::CommOverlapBase(buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype),
-                          helper->myrank, helper->numranks, helper->mylocal, helper->numlocal,
-                          helper->mynode, helper->numnodes, tp_size,
-                          std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5),
-                          std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits,
-                          num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm,
-                          set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {}
+CommOverlapManager::CommOverlapManager(transformer_engine::CommOverlapMethod method,
+                                       transformer_engine::CommOverlapType comm_type,
+                                       const std::vector &buffer_shape,
+                                       at::ScalarType buffer_dtype, CommOverlapHelper *helper,
+                                       int tp_size, int num_splits, int num_max_streams,
+                                       int comm_cga_size, int gemm_priority, int comm_priority,
+                                       int num_comm_sm, bool set_sm_margin, bool atomic_gemm,
+                                       bool aggregate_ag, bool rs_overlap_first_gemm) {
+  // Cache configuration used later by execute() and the buffer helpers.
+  _method = method;
+  _num_comm_sm = num_comm_sm;
+  _use_atomic_gemm = atomic_gemm;
+#ifdef NVTE_WITH_CUBLASMP
+  _ctx = nvte_comm_gemm_ctx_create(reinterpret_cast(helper->get_comm_ptr("tp")),
+                                   helper->numranks, helper->myrank, te::cuda::current_device());
+#else
+  if (method == te::CommOverlapMethod::RING_EXCHANGE) {
+    _ctx = reinterpret_cast(new te::CommOverlapP2PBase(
+        buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank,
+        helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes,
+        tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5),
+        std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams,
+        comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, /*use_ce=*/true,
+        atomic_gemm, aggregate_ag));
+  } else {
+    _ctx = reinterpret_cast(new te::CommOverlapBase(
+        buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank,
+        helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes,
+        tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5),
+        std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, num_max_streams,
+        comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, atomic_gemm,
+        rs_overlap_first_gemm));
+  }
+#endif
+}
 
 /*
 ** Helper function to copy
input to _ubuf */ -void CommOverlap::copy_into_buffer(const at::Tensor &input, bool local_chunk) { +void CommOverlapManager::copy_into_buffer(const at::Tensor &input, bool local_chunk) { +#ifndef NVTE_WITH_CUBLASMP const auto &input_ = input.contiguous(); // Check element size @@ -159,162 +187,198 @@ void CommOverlap::copy_into_buffer(const at::Tensor &input, bool local_chunk) { const void *src_ptr = input_.data_ptr(); // Userbuffers data - const size_t ubuf_size = _ubuf.numel(); - void *dst_ptr = _ubuf.dptr(); - if (local_chunk) { - NVTE_CHECK(input_size * _tp_size == ubuf_size, - "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", - "(input_size=", input_size, ", tensor_parallel_size=", _tp_size, - ", ubuf_size=", ubuf_size, ")"); - dst_ptr = (reinterpret_cast(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size); - } else { - NVTE_CHECK(input_size == ubuf_size, - "Tried to copy an invalid tensor into a Userbuffers buffer ", - "(input_size=", input_size, ", ubuf_size=", ubuf_size, ")"); - } - - // Copy data - auto stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); - NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size, - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); -} - -at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional> shape) { - // Check buffer shape - const size_t ubuf_size = _ubuf.numel(); - if (shape) { - const size_t requested_size = transformer_engine::pytorch::product(*shape); + if (_method == te::CommOverlapMethod::RING_EXCHANGE) { + void *dst_ptr; if (local_chunk) { - NVTE_CHECK(requested_size * _tp_size == ubuf_size, - "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape, - ", tensor_parallel_size=", _tp_size, ", ubuf_size=", ubuf_size, ")"); + NVTE_CHECK(_ubufs[_tp_id].numel() == input_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + dst_ptr = _ubufs[_tp_id].dptr(); } else { - NVTE_CHECK(requested_size == ubuf_size, - "Invalid shape for a Userbuffers buffer (requested shape=", *shape, - ", ubuf_size=", ubuf_size, ")"); + NVTE_CHECK(_ubuf.numel() == input_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")"); + dst_ptr = _ubuf.dptr(); } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size, + cudaMemcpyDeviceToDevice, + (cudaStream_t)at::cuda::getCurrentCUDAStream())); } else { - int64_t dim0 = _ubuf.size(0); - int64_t dim1 = _ubuf.size(1); + const size_t ubuf_size = _ubuf.numel(); + void *dst_ptr = _ubuf.dptr(); if (local_chunk) { - dim0 /= _tp_size; + NVTE_CHECK(input_size * _tp_size == ubuf_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(input_size=", input_size, ", tensor_parallel_size=", _tp_size, + ", ubuf_size=", ubuf_size, ")"); + dst_ptr = + (reinterpret_cast(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size); + } else { + NVTE_CHECK(input_size == ubuf_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(input_size=", input_size, ", ubuf_size=", ubuf_size, ")"); } - shape = {dim0, dim1}; - } - // Data pointer - void *ubuf_ptr = _ubuf.dptr(); - if (local_chunk) { - ubuf_ptr = 
(reinterpret_cast(ubuf_ptr) + - (ubuf_size / _tp_size) * _tp_id * _ubuf.element_size()); + // Copy data + auto stream_main = at::cuda::getCurrentCUDAStream(); + NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size, + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); } - - // Construct PyTorch tensor - const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype()); - return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA)); -} - -std::pair CommOverlap::get_communication_stream() { - // Return the same stream for both send and recv - return {at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device()), - at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device())}; +#endif } -/*************************************************************************************************** - * CommOverlapP2P - **************************************************************************************************/ - -CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, - te::CommOverlapType comm_type, int num_max_streams, - int comm_cga_size, int gemm_priority, int comm_priority, - int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, - bool aggregate) - : te::CommOverlapP2PBase( - buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank, - helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, - tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, - comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, - atomic_gemm, aggregate) {} - -/* -** Copy input to _ubufs[0] -*/ -void CommOverlapP2P::copy_into_buffer(const at::Tensor &input, bool local_chunk) { - const auto &input_ = input.contiguous(); - - // Check element size - const size_t element_size = input.element_size(); - NVTE_CHECK(_ubuf.element_size() == element_size, - "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", - "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), - " bytes)"); +at::Tensor CommOverlapManager::get_buffer(bool local_chunk, + std::optional> shape) { +#ifndef NVTE_WITH_CUBLASMP + at::Tensor buffer_tensor; + if (_method == te::CommOverlapMethod::RING_EXCHANGE) { + // Check buffer shape + if (shape) { + const size_t requested_size = transformer_engine::pytorch::product(*shape); + if (local_chunk) { + NVTE_CHECK(requested_size == _ubufs[_tp_id].numel(), + + "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", + *shape, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + } else { + NVTE_CHECK(requested_size == _ubuf.numel(), + "Invalid shape for a Userbuffers buffer (requested shape=", *shape, + ", ubuf_size=", _ubuf.numel(), ")"); + } + } else { + int64_t dim0 = _ubuf.size(0); + int64_t dim1 = _ubuf.size(1); + if (local_chunk) { + dim0 /= _tp_size; + } + shape = {dim0, dim1}; + } - // Input data - const size_t input_size = input_.numel(); - const void *src_ptr = input_.data_ptr(); + // Data pointer + void *ubuf_ptr = local_chunk ? 
_ubufs[_tp_id].dptr() : _ubuf.dptr(); - // Userbuffers data - void *dst_ptr; - if (local_chunk) { - NVTE_CHECK(_ubufs[_tp_id].numel() == input_size, - "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", - "(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); - dst_ptr = _ubufs[_tp_id].dptr(); + // Construct PyTorch tensor + const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype()); + buffer_tensor = torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA)); } else { - NVTE_CHECK(_ubuf.numel() == input_size, - "Tried to copy an invalid tensor into a Userbuffers buffer ", - "(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")"); - dst_ptr = _ubuf.dptr(); - } - - // Copy data - NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size, - cudaMemcpyDeviceToDevice, - (cudaStream_t)at::cuda::getCurrentCUDAStream())); -} - -at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional> shape) { - // Check buffer shape - if (shape) { - const size_t requested_size = transformer_engine::pytorch::product(*shape); - if (local_chunk) { - NVTE_CHECK(requested_size == _ubufs[_tp_id].numel(), - "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape, - ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + // Check buffer shape + const size_t ubuf_size = _ubuf.numel(); + if (shape) { + const size_t requested_size = transformer_engine::pytorch::product(*shape); + if (local_chunk) { + NVTE_CHECK(requested_size * _tp_size == ubuf_size, + "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", + *shape, ", tensor_parallel_size=", _tp_size, ", ubuf_size=", ubuf_size, ")"); + } else { + NVTE_CHECK(requested_size == ubuf_size, + "Invalid shape for a Userbuffers buffer (requested shape=", *shape, + ", ubuf_size=", ubuf_size, ")"); + } } else { - NVTE_CHECK(requested_size == _ubuf.numel(), - "Invalid shape for a Userbuffers buffer (requested shape=", *shape, - ", ubuf_size=", _ubuf.numel(), ")"); + int64_t dim0 = _ubuf.size(0); + int64_t dim1 = _ubuf.size(1); + if (local_chunk) { + dim0 /= _tp_size; + } + shape = {dim0, dim1}; } - } else { - int64_t dim0 = _ubuf.size(0); - int64_t dim1 = _ubuf.size(1); + + // Data pointer + void *ubuf_ptr = _ubuf.dptr(); if (local_chunk) { - dim0 /= _tp_size; + ubuf_ptr = (reinterpret_cast(ubuf_ptr) + + (ubuf_size / _tp_size) * _tp_id * _ubuf.element_size()); } - shape = {dim0, dim1}; - } - // Data pointer - void *ubuf_ptr = local_chunk ? 
_ubufs[_tp_id].dptr() : _ubuf.dptr(); + // Construct PyTorch tensor + const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype()); + buffer_tensor = torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA)); + } - // Construct PyTorch tensor - const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype()); - return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA)); + return buffer_tensor; +#else + // Return dummy tensor, will not be used with cuBlasMp + const auto dtype = transformer_engine::pytorch::GetATenDType(DType::kByte); + return torch::from_blob(nullptr, std::vector{0}, at::dtype(dtype).device(torch::kCUDA)); +#endif } -std::pair CommOverlapP2P::get_communication_stream() { - return {at::cuda::getStreamFromExternal(_stream_send[0], at::cuda::current_device()), - at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device())}; +at::Stream CommOverlapManager::get_communication_stream() { + return at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device()); } -void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm( - CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream) { - auto main_stream = at::cuda::getCurrentCUDAStream(); - allgather_communicator.bulk_overlap_external_ag(at::cuda::CUDAStream(send_stream), - at::cuda::CUDAStream(recv_stream), main_stream); +void CommOverlapManager::execute(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + te::CommOverlapType comm_type, TensorWrapper &aux_out, + cudaStream_t stream) { +#ifdef NVTE_WITH_CUBLASMP + if (_method == te::CommOverlapMethod::BULK) { + NVTE_ERROR("Bulk overlap is not supported with cuBlasMp."); + } else { + cublasMpMatmulAlgoType_t algo = CUBLASMP_MATMUL_ALGO_TYPE_DEFAULT; + if (_method == te::CommOverlapMethod::RING_EXCHANGE) { + algo = (_use_atomic_gemm) ? CUBLASMP_MATMUL_ALGO_ATOMIC_P2P : CUBLASMP_MATMUL_ALGO_SPLIT_P2P; + } else if (_method == te::CommOverlapMethod::PIPELINE) { + algo = (_use_atomic_gemm) ? CUBLASMP_MATMUL_ALGO_ATOMIC_MULTICAST + : CUBLASMP_MATMUL_ALGO_SPLIT_MULTICAST; + } + + // Tensor dimms in row-major order + auto A_shape = A.shape(); + const int A0 = product(A_shape, 0, A_shape.ndim - 1); + const int A1 = A_shape.data[A_shape.ndim - 1]; + auto B_shape = B.shape(); + const int B0 = product(B_shape, 0, B_shape.ndim - 1); + const int B1 = B_shape.data[B_shape.ndim - 1]; + + // GEMM dims in column-major order + const int m = (transa) ? A0 : A1; + const int n = (transb) ? B1 : B0; + const int k = (transa) ? 
A1 : A0;
+
+    if (comm_type == te::CommOverlapType::AG) {
+      // convert all-gathered dimension to global size
+      const int n_global = n * _ctx->nranks;
+      NVTE_CHECK_CUBLASMP(nvte_all_gather_gemm(_ctx, m, n_global, k, A.data(), B.data(), D.data(),
+                                               bias.data(), pre_gelu_out.data(), transa, transb,
+                                               grad, accumulate, _num_comm_sm, stream, algo));
+    } else {
+      // convert contracting dimension to global size
+      const int k_global = k * _ctx->nranks;
+      NVTE_CHECK_CUBLASMP(nvte_gemm_reduce_scatter(_ctx, m, n, k_global, A.data(), B.data(),
+                                                   D.data(), bias.data(), pre_gelu_out.data(),
+                                                   transa, transb, grad, accumulate, _num_comm_sm,
+                                                   stream, algo));
+    }
+  }
+#else
+  if (_method == te::CommOverlapMethod::BULK) {
+    _ctx->bulk_overlap(A.data(), transa, B.data(), transb, D.data(), bias.data(),
+                       pre_gelu_out.data(), workspace.data(), grad, accumulate,
+                       use_split_accumulator, comm_type, aux_out.data(), stream);
+  } else if (comm_type == te::CommOverlapType::AG) {
+    if (_use_atomic_gemm) {
+      _ctx->atomic_gemm_overlap_ag(A.data(), transa, B.data(), transb, D.data(), bias.data(),
+                                   pre_gelu_out.data(), workspace.data(), grad, accumulate,
+                                   use_split_accumulator, aux_out.data(), stream);
+    } else {
+      _ctx->split_overlap_ag(A.data(), transa, B.data(), transb, D.data(), bias.data(),
+                             pre_gelu_out.data(), workspace.data(), grad, accumulate,
+                             use_split_accumulator, aux_out.data(), stream);
+    }
+  } else {
+    if (_use_atomic_gemm) {
+      _ctx->atomic_gemm_overlap_rs(A.data(), transa, B.data(), transb, D.data(), bias.data(),
+                                   pre_gelu_out.data(), workspace.data(), grad, accumulate,
+                                   use_split_accumulator, aux_out.data(), stream);
+    } else {
+      _ctx->split_overlap_rs(A.data(), transa, B.data(), transb, D.data(), bias.data(),
+                             pre_gelu_out.data(), workspace.data(), grad, accumulate,
+                             use_split_accumulator, aux_out.data(), stream);
+    }
+  }
+#endif
 }
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
index 15404ad9a6..704cde70a0 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -90,7 +90,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans
                              py::handle quantizer, std::optional out_dtype, MaybeTensor bias,
                              DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
                              at::Tensor workspace, size_t workspaceSize, bool accumulate,
-                             bool use_split_accumulator, CommOverlapCore* comm_overlap,
+                             bool use_split_accumulator, CommOverlapManager* comm_overlap,
                              std::optional comm_type, MaybeTensor extra_output,
                              bool bulk_overlap, float alpha, std::optional beta) {
   using namespace transformer_engine::pytorch::detail;
@@ -262,47 +262,12 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans
           makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte);
     }
 
-    // Direct GEMM call to the correct overlap
-    if (bulk_overlap) {
-      NVTE_SCOPED_GIL_RELEASE({
-        comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor,
-                                   te_pre_gelu_out, te_workspace, grad, accumulate,
-                                   use_split_accumulator, comm_type.value(), extra_output_tensor,
-                                   main_stream);
-      });
-    } else if (comm_type.value() == CommOverlapType::AG) {
-      if (comm_overlap->is_atomic_gemm()) {
-        NVTE_SCOPED_GIL_RELEASE({
-          comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor,
-                                               bias_tensor, te_pre_gelu_out, te_workspace, grad,
-                                               accumulate, use_split_accumulator,
-                                               extra_output_tensor, main_stream);
-        });
-      } else {
-        NVTE_SCOPED_GIL_RELEASE({
-          comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor,
-                                         bias_tensor, te_pre_gelu_out, te_workspace, grad,
-                                         accumulate, use_split_accumulator, extra_output_tensor,
-                                         main_stream);
-        });
-      }
-    } else {
-      if (comm_overlap->is_atomic_gemm()) {
-        NVTE_SCOPED_GIL_RELEASE({
-          comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor,
-                                               bias_tensor, te_pre_gelu_out, te_workspace, grad,
-                                               accumulate, use_split_accumulator,
-                                               extra_output_tensor, main_stream);
-        });
-      } else {
-        NVTE_SCOPED_GIL_RELEASE({
-          comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor,
-                                         bias_tensor, te_pre_gelu_out, te_workspace, grad,
-                                         accumulate, use_split_accumulator, extra_output_tensor,
-                                         main_stream);
-        });
-      }
-    }
+    NVTE_SCOPED_GIL_RELEASE({
+      comm_overlap->execute(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor,
+                            te_pre_gelu_out, te_workspace, grad, accumulate,
+                            use_split_accumulator, comm_type.value(), extra_output_tensor,
+                            main_stream);
+    });
   } else {
     // Launch GEMM
     NVTE_SCOPED_GIL_RELEASE({
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 3b81393dbd..8235f1cdd0 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -463,41 +463,26 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   py::class_(m, "CommOverlapHelper")
       .def(py::init<>(), py::call_guard())
+      .def(py::init(), py::call_guard(),
+           py::arg("tp_group"))
       .def(py::init>(), py::call_guard(), py::arg("world_group"),
-           py::arg("intra_node_group") = py::none());
+           py::arg("intra_node_group"));
 
-  py::class_, transformer_engine::CommOverlapBase,
-             transformer_engine::CommOverlapCore>(m, "CommOverlap")
-      .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int,
-                    int, int, int, int, bool, bool, bool>(),
+  py::class_(m, "CommOverlapManager")
+      .def(py::init &,
+                    at::ScalarType, CommOverlapHelper *, int, int, int, int, int, int, int, bool,
+                    bool, bool>(),
           py::call_guard(), py::arg("buffer_shape"),
           py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("num_splits") = 3,
           py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 2,
          py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 16,
          py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false,
          py::arg("rs_overlap_first_gemm") = false)
-      .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
-           py::arg("local_chunk") = false)
-      .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
+      .def("is_fp8_ubuf", &CommOverlapManager::is_fp8_ubuf)
+      .def("copy_into_buffer", &CommOverlapManager::copy_into_buffer, py::arg("input"),
+           py::arg("local_chunk") = false)
+      .def("get_buffer", &CommOverlapManager::get_buffer, py::arg("local_chunk") = false,
          py::arg("shape") = std::nullopt)
-      .def("get_communication_stream", &CommOverlap::get_communication_stream);
-
-  py::class_,
-             transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(
-      m, "CommOverlapP2P")
-      .def(py::init &, at::ScalarType, CommOverlapHelper *, int,
-                    transformer_engine::CommOverlapType, int, int, int, int, int, bool, bool, bool,
-                    bool>(),
-           py::call_guard(), py::arg("buffer_shape"),
-           py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"),
-           py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1,
-           py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0,
py::arg("num_comm_sm") = 1, - py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, - py::arg("use_ce") = true, py::arg("aggregate") = false) - .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), - py::arg("local_chunk") = false) - .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, - py::arg("shape") = std::nullopt) - .def("get_communication_stream", &CommOverlapP2P::get_communication_stream); + .def("get_communication_stream", &CommOverlapManager::get_communication_stream); } diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 9b6ca9d9cd..110116629a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -212,6 +212,7 @@ def initialize_ub( assert _ub_communicators is None, "UB communicators are already initialized." _ub_communicators = {} + helper = None if tex.ubuf_built_with_mpi(): # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force # an MPI_Init() here by creating a new MPI process group... @@ -226,10 +227,15 @@ def initialize_ub( ), "torch.distributed must be initialized before Userbuffers" if bootstrap_backend is None: bootstrap_backend = "nccl" - if torch.distributed.is_mpi_available(): - bootstrap_backend = "mpi" - elif torch.distributed.is_gloo_available(): - bootstrap_backend = "gloo" + if not tex.nvte_built_with_cublasmp(): + if torch.distributed.is_mpi_available(): + bootstrap_backend = "mpi" + elif torch.distributed.is_gloo_available(): + bootstrap_backend = "gloo" + elif tex.nvte_built_with_cublasmp(): + assert ( + bootstrap_backend == "nccl" + ), 'Comm+GEMM overlap w/ cuBlasMp needs `bootstrap_backend="nccl"`.' else: assert bootstrap_backend in [ "gloo", @@ -257,14 +263,20 @@ def initialize_ub( local_rank = torch.distributed.get_rank(tp_domain_group) tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group) - helper = tex.CommOverlapHelper(world_group, tp_domain_group) + if tex.nvte_built_with_cublasmp(): + helper = tex.CommOverlapHelper(tp_domain_group) + else: + helper = tex.CommOverlapHelper(world_group, tp_domain_group) else: # TP model on single NVLink domain, no replication, no data-parallelism mydomain_idx = 0 local_rank = world_rank tp_domain_ranks = list(range(world_size)) - helper = tex.CommOverlapHelper(world_group) + if tex.nvte_built_with_cublasmp(): + helper = tex.CommOverlapHelper(world_group) + else: + helper = tex.CommOverlapHelper(world_group, world_group) if world_rank == 0: print(f"!!! 
[UB] Number of TP domains: {num_domains}\n", end="", flush=True)
@@ -410,45 +422,28 @@ def add_ub(
                     f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
                 )
 
-        buffer_dtype = (
-            torch.uint8
-            if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf)
-            else dtype
-        )
+        _ub_communicators[(name, quantization_mode)] = tex.CommOverlapManager(
+            (
+                tex.CommOverlapMethod.RING_EXCHANGE
+                if method == "ring_exchange"
+                else tex.CommOverlapMethod.PIPELINE
+            ),
+            tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG,
+            shape,
+            (
+                torch.uint8
+                if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf)
+                else dtype
+            ),
+            helper,
+            tp_size,
+            num_splits=num_splits,
+            num_max_streams=_NUM_MAX_UB_STREAMS,
+            comm_cga_size=cga_size,
+            gemm_priority=gemm_priority,
+            comm_priority=comm_priority,
+            num_comm_sm=num_sm,
+            set_sm_margin=set_sm_margin,
+            atomic_gemm=atomic_gemm,
+            aggregate_ag=aggregate,
+            rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
+        )
-        if method == "ring_exchange":
-            ub_obj = tex.CommOverlapP2P(
-                shape,  # Communication buffer shape
-                buffer_dtype,  # Communication buffer data type
-                helper,  # Helper for torch.distributed callbacks during bootstrapping
-                tp_size,  # Tensor-parallel group size (may be different than local_size)
-                tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG,
-                num_max_streams=_NUM_MAX_UB_STREAMS,
-                comm_cga_size=cga_size,
-                num_comm_sm=num_sm,
-                set_sm_margin=set_sm_margin,
-                atomic_gemm=atomic_gemm,
-                use_ce=use_ce,
-                aggregate=aggregate,
-                gemm_priority=gemm_priority,
-                comm_priority=comm_priority,
-            )
-        else:
-            ub_obj = tex.CommOverlap(
-                shape,  # Communication buffer shape
-                buffer_dtype,  # Communication buffer data type
-                helper,  # Helper for torch.distributed callbacks during bootstrapping
-                tp_size,  # Tensor-parallel group size (may be different than local_size)
-                num_splits=num_splits,
-                num_max_streams=_NUM_MAX_UB_STREAMS,
-                comm_cga_size=cga_size,
-                num_comm_sm=num_sm,
-                set_sm_margin=set_sm_margin,
-                atomic_gemm=atomic_gemm,
-                gemm_priority=gemm_priority,
-                comm_priority=comm_priority,
-                rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
-            )
-        _ub_communicators[(name, quantization_mode)] = ub_obj
 
     for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs):
         if user_ub_cfg is not None:
@@ -517,6 +512,16 @@ def fill_userbuffers_buffer_for_all_gather(
         tensor's metadata, e.g. scaling factors.
 
     """
+    # cuBlasMp handles AG for the scaling factor so we just need the local tensor here
+    if tex.nvte_built_with_cublasmp():
+        tensor_is_quantized = isinstance(local_tensor, QuantizedTensorBase)
+        if quantizer is None and tensor_is_quantized:
+            # Dequantize quantized tensor if quantizer is None
+            local_tensor = local_tensor.dequantize()
+        elif quantizer is not None and not tensor_is_quantized:
+            # Quantize unquantized tensor if quantizer is not None
+            local_tensor = quantizer(local_tensor)
+        return local_tensor, local_tensor
+
     # Tensor dimensions
     local_shape = local_tensor.size()
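
Usage sketch (illustrative, not part of the patch): the snippet below mirrors what add_ub() in transformer_engine/pytorch/module/base.py now does with the unified bindings added above (tex.CommOverlapHelper, tex.CommOverlapManager, tex.CommOverlapMethod, tex.CommOverlapType, tex.nvte_built_with_cublasmp). The process groups, buffer shape, dtype, and keyword values are made-up example values, and constructing the communicator directly like this only illustrates the call shape; modules normally reach it through initialize_ub()/get_ub().

# Illustrative only -- mirrors the add_ub() call introduced in this patch.
import torch
import torch.distributed as dist
import transformer_engine_torch as tex

dist.init_process_group(backend="nccl")
tp_group = dist.new_group(backend="nccl")  # assume a single-node TP group here

if tex.nvte_built_with_cublasmp():
    # cuBlasMp build: helper is bootstrapped from the NCCL TP group alone.
    helper = tex.CommOverlapHelper(tp_group)
else:
    # Userbuffers build: world group plus intra-node group (same group on one node).
    helper = tex.CommOverlapHelper(tp_group, tp_group)

ub = tex.CommOverlapManager(
    tex.CommOverlapMethod.PIPELINE,  # or RING_EXCHANGE / BULK
    tex.CommOverlapType.RS,          # overlap the GEMM with a reduce-scatter
    [4096, 8192],                    # communication buffer shape (example)
    torch.bfloat16,                  # communication buffer dtype
    helper,
    dist.get_world_size(tp_group),   # tensor-parallel size
    num_splits=4,
    set_sm_margin=True,
)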