comms/torchcomms/ncclx/CMakeLists.txt (6 changes: 5 additions & 1 deletion)
@@ -1,6 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # Extension: torchcomms._comms_ncclx
-file(GLOB TORCHCOMMS_NCCLX_SOURCES "comms/torchcomms/ncclx/*.cpp")
+file(GLOB TORCHCOMMS_NCCLX_SOURCES
+  "comms/torchcomms/ncclx/*.cpp"
+  "comms/torchcomms/transport/*.cc"
+)
 file(GLOB TORCHCOMMS_CUDA_API_SOURCE "comms/torchcomms/device/CudaApi.cpp")
 
 find_package(CUDA)
@@ -46,6 +49,7 @@ add_library(torchcomms_comms_ncclx MODULE
   ${TORCHCOMMS_NCCLX_SOURCES}
   ${TORCHCOMMS_CUDA_API_SOURCE}
 )
+target_compile_definitions(torchcomms_comms_ncclx PRIVATE MOCK_SCUBA_DATA CTRAN_DISABLE_TCPDM)
 set_target_properties(torchcomms_comms_ncclx PROPERTIES
   PREFIX ""
   OUTPUT_NAME "_comms_ncclx"
comms/torchcomms/ncclx/TorchCommNCCLXPy.cpp (21 changes: 21 additions & 0 deletions)
@@ -1,19 +1,40 @@
 // Copyright (c) Meta Platforms, Inc. and affiliates.
 
+#include <folly/io/async/EventBase.h>
+#include <folly/io/async/ScopedEventBaseThread.h>
 #include <pybind11/chrono.h>
 #include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #include <torch/csrc/utils/pybind.h>
 
 #include "comms/torchcomms/ncclx/TorchCommNCCLX.hpp"
+#include "comms/torchcomms/transport/RdmaTransport.h"
 
 namespace py = pybind11;
 using namespace torch::comms;
 
+namespace {
+folly::ScopedEventBaseThread& getScopedEventBaseThread() {
+  // This intentionally creates and leaks a global event base thread to be used
+  // for all Transports on first use.
+  static folly::ScopedEventBaseThread scopedEventBaseThread{"torchcomms_evb"};
+  return scopedEventBaseThread;
+}
+} // namespace
+
 PYBIND11_MODULE(_comms_ncclx, m) {
   m.doc() = "NCCLX specific python bindings for TorchComm";
 
   py::class_<TorchCommNCCLX, std::shared_ptr<TorchCommNCCLX>>(
       m, "TorchCommNCCLX");
 
+  py::class_<RdmaTransport, std::shared_ptr<RdmaTransport>>(m, "RdmaTransport")
+      // initialize a new RDMATransport using a custom init fn
+      .def(py::init([](at::Device device) {
+        TORCH_INTERNAL_ASSERT(device.is_cuda());
+        int cuda_device = device.index();
+        return std::make_shared<RdmaTransport>(
+            cuda_device, getScopedEventBaseThread().getEventBase());
+      }));
 }
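
For orientation, a minimal sketch of how the new binding could be exercised from Python once the extension is built, assuming the module is importable as torchcomms._comms_ncclx (per the CMake OUTPUT_NAME above); the device index is illustrative:

import torch
from torchcomms._comms_ncclx import RdmaTransport

# The binding only accepts CUDA devices; a non-CUDA device trips the
# TORCH_INTERNAL_ASSERT in the init lambda above.
device = torch.device("cuda", 0)  # illustrative device index
transport = RdmaTransport(device)

All transports constructed this way share the single leaked "torchcomms_evb" event base thread created on first use.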
comms/torchcomms/tests/integration/py/TransportTest.py (31 changes: 31 additions & 0 deletions)
@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+# pyre-unsafe
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import itertools
+import os
+import unittest
+
+import torch
+from torchcomms._comms_ncclx import RdmaTransport
+from torchcomms.tests.integration.py.TorchCommTestHelpers import (
+    get_dtype_name,
+    TorchCommTestWrapper,
+)
+
+
+class TransportTest(unittest.TestCase):
+    def setUp(self):
+        """Set up test environment before each test."""
+
+    def tearDown(self):
+        """Clean up after each test."""
+
+    def test_basic(self) -> None:
+        device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+
+        transport = RdmaTransport(device)
+
+
+if __name__ == "__main__" and os.environ["TEST_BACKEND"] == "ncclx":
+    unittest.main()
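
Note on running the new test: it gates on TEST_BACKEND == "ncclx" and takes its CUDA device index from LOCAL_RANK, so it is intended to run under the integration launcher rather than directly. A hypothetical manual invocation, assuming torchrun as the launcher (torchrun exports LOCAL_RANK; the process count is illustrative):

TEST_BACKEND=ncclx torchrun --nproc_per_node=1 comms/torchcomms/tests/integration/py/TransportTest.py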