From 06dd9602aaae921db031686224be92b3d6a4ac90 Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Fri, 24 Oct 2025 17:03:01 -0700
Subject: [PATCH] torchcomms/ncclx: expose Ctran transport to Python

---
 comms/torchcomms/ncclx/CMakeLists.txt       |  6 +++-
 comms/torchcomms/ncclx/TorchCommNCCLXPy.cpp | 21 +++++++++++++
 .../tests/integration/py/TransportTest.py   | 31 +++++++++++++++++++
 3 files changed, 57 insertions(+), 1 deletion(-)
 create mode 100644 comms/torchcomms/tests/integration/py/TransportTest.py

diff --git a/comms/torchcomms/ncclx/CMakeLists.txt b/comms/torchcomms/ncclx/CMakeLists.txt
index eace5d0..6e4aba2 100644
--- a/comms/torchcomms/ncclx/CMakeLists.txt
+++ b/comms/torchcomms/ncclx/CMakeLists.txt
@@ -1,6 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # Extension: torchcomms._comms_ncclx
-file(GLOB TORCHCOMMS_NCCLX_SOURCES "comms/torchcomms/ncclx/*.cpp")
+file(GLOB TORCHCOMMS_NCCLX_SOURCES
+  "comms/torchcomms/ncclx/*.cpp"
+  "comms/torchcomms/transport/*.cc"
+)
 file(GLOB TORCHCOMMS_CUDA_API_SOURCE "comms/torchcomms/device/CudaApi.cpp")
 find_package(CUDA)
 
@@ -46,6 +49,7 @@ add_library(torchcomms_comms_ncclx MODULE
   ${TORCHCOMMS_NCCLX_SOURCES}
   ${TORCHCOMMS_CUDA_API_SOURCE}
 )
+target_compile_definitions(torchcomms_comms_ncclx PRIVATE MOCK_SCUBA_DATA CTRAN_DISABLE_TCPDM)
 set_target_properties(torchcomms_comms_ncclx PROPERTIES
   PREFIX ""
   OUTPUT_NAME "_comms_ncclx"
diff --git a/comms/torchcomms/ncclx/TorchCommNCCLXPy.cpp b/comms/torchcomms/ncclx/TorchCommNCCLXPy.cpp
index a06c04e..61667a1 100644
--- a/comms/torchcomms/ncclx/TorchCommNCCLXPy.cpp
+++ b/comms/torchcomms/ncclx/TorchCommNCCLXPy.cpp
@@ -1,5 +1,7 @@
 // Copyright (c) Meta Platforms, Inc. and affiliates.
 
+#include <folly/io/async/ScopedEventBaseThread.h>
+#include
 #include
 #include
 #include
@@ -7,13 +9,32 @@
 #include
 
 #include "comms/torchcomms/ncclx/TorchCommNCCLX.hpp"
+#include "comms/torchcomms/transport/RdmaTransport.h"
 
 namespace py = pybind11;
 using namespace torch::comms;
 
+namespace {
+folly::ScopedEventBaseThread& getScopedEventBaseThread() {
+  // This intentionally creates and leaks a global event base thread to be used
+  // for all Transports on first use.
+  static folly::ScopedEventBaseThread scopedEventBaseThread{"torchcomms_evb"};
+  return scopedEventBaseThread;
+}
+} // namespace
+
 PYBIND11_MODULE(_comms_ncclx, m) {
   m.doc() = "NCCLX specific python bindings for TorchComm";
 
   py::class_<TorchCommNCCLX, std::shared_ptr<TorchCommNCCLX>>(
       m, "TorchCommNCCLX");
+
+  py::class_<RdmaTransport, std::shared_ptr<RdmaTransport>>(m, "RdmaTransport")
+      // initialize a new RdmaTransport using a custom init fn
+      .def(py::init([](at::Device device) {
+        TORCH_INTERNAL_ASSERT(device.is_cuda());
+        int cuda_device = device.index();
+        return std::make_shared<RdmaTransport>(
+            cuda_device, getScopedEventBaseThread().getEventBase());
+      }));
 }
diff --git a/comms/torchcomms/tests/integration/py/TransportTest.py b/comms/torchcomms/tests/integration/py/TransportTest.py
new file mode 100644
index 0000000..8516546
--- /dev/null
+++ b/comms/torchcomms/tests/integration/py/TransportTest.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+# pyre-unsafe
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import itertools
+import os
+import unittest
+
+import torch
+from torchcomms._comms_ncclx import RdmaTransport
+from torchcomms.tests.integration.py.TorchCommTestHelpers import (
+    get_dtype_name,
+    TorchCommTestWrapper,
+)
+
+
+class TransportTest(unittest.TestCase):
+    def setUp(self):
+        """Set up test environment before each test."""
+
+    def tearDown(self):
+        """Clean up after each test."""
+
+    def test_basic(self) -> None:
+        device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+
+        transport = RdmaTransport(device)
+
+
+if __name__ == "__main__" and os.environ["TEST_BACKEND"] == "ncclx":
+    unittest.main()
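
A minimal usage sketch for the new binding, assuming the _comms_ncclx extension
has been built as above. The patch only exposes the RdmaTransport constructor,
which pins the transport to a CUDA device and lazily shares one folly event
base thread across all transports, so the sketch just builds one transport per
local rank. The file name and the launch command are illustrative assumptions;
LOCAL_RANK is the environment variable torchrun sets for each worker.

    # rdma_transport_example.py -- hypothetical file name, not part of the patch
    import os

    import torch

    from torchcomms._comms_ncclx import RdmaTransport

    # torchrun sets LOCAL_RANK; the binding asserts the device is a CUDA device.
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank)

    # Only the constructor is bound in this patch: it takes the CUDA device and
    # uses the shared event base thread created on first use.
    transport = RdmaTransport(device)
    print(f"rank {local_rank}: created {transport}")

Launched with something like (assumed, one process per GPU):

    torchrun --nproc-per-node=2 rdma_transport_example.py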