Skip to content

[Setup] Making tensordict pytorch-agnostic #1256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 53 commits into
base: gh/vmoens/49/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .github/scripts/linux-post-script.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
#!/bin/bash

yum update gcc
yum update libstdc++
if [ "$(uname)" != "Darwin" ]; then
yum update gcc
yum update libstdc++
else
brew update
brew upgrade gcc
fi
18 changes: 18 additions & 0 deletions .github/scripts/linux-pre-script.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#!/bin/bash

if [ "$(uname)" != "Darwin" ]; then
yum update gcc
yum update libstdc++
else
echo $(gcc --version)
echo $(clang --version)
brew update
brew upgrade gcc
brew upgrade clang

# For OSX
# export CXXFLAGS="-march=armv8-a+fp16+sha3"
export CMAKE_OSX_ARCHITECTURES=arm64
fi

${CONDA_RUN} conda install -c conda-forge pybind11 -y
8 changes: 8 additions & 0 deletions .github/scripts/version_script.bat
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
set TENSORDICT_BUILD_VERSION=0.8.0
echo TENSORDICT_BUILD_VERSION is set to %TENSORDICT_BUILD_VERSION%

if "%CONDA_RUN%"=="" (
echo CONDA_RUN is not set. Please activate your conda environment or set CONDA_RUN.
exit /b 1
)

:: Run the pip install command
%CONDA_RUN% conda install -c conda-forge pybind11 -y

@echo on

set VC_VERSION_LOWER=17
Expand Down
17 changes: 17 additions & 0 deletions .github/scripts/version_script.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
#!/bin/bash

export TENSORDICT_BUILD_VERSION=0.8.0

if [ "$(uname)" == "Darwin" ]; then
# For OSX
echo $(gcc --version)
echo $(clang --version)
brew update
brew install gcc
brew install clang-build-analyzer
brew install --cask clay
brew install llvm
# brew upgrade gcc
# brew upgrade clang
# export CXXFLAGS="-march=armv8-a+fp16+sha3"
export CMAKE_OSX_ARCHITECTURES=arm64
fi

${CONDA_RUN} conda install -c conda-forge pybind11 -y
17 changes: 17 additions & 0 deletions .github/scripts/win-pre-script.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
@echo off
:: Check if CONDA_RUN is set, if not, set it to a default value
if "%CONDA_RUN%"=="" (
echo CONDA_RUN is not set. Please activate your conda environment or set CONDA_RUN.
exit /b 1
)

:: Run the pip install command
%CONDA_RUN% conda install -c conda-forge pybind11 -y

:: Check if the installation was successful
if errorlevel 1 (
echo Failed to install cmake and pybind11.
exit /b 1
) else (
echo Successfully installed cmake and pybind11.
)
1 change: 1 addition & 0 deletions .github/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ dependencies:
- coverage
- h5py
- orjson
- ninja
- numpy<2.0.0
3 changes: 3 additions & 0 deletions .github/unittest/linux/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ pip install pip --upgrade

conda env update --file "${this_dir}/environment.yml" --prune

conda install anaconda::cmake -y
conda install -c conda-forge pybind11 -y

#if [[ $OSTYPE == 'darwin'* ]]; then
# printf "* Installing C++ for OSX\n"
# conda install -c conda-forge cxx-compiler -y
Expand Down
1 change: 1 addition & 0 deletions .github/unittest/rl_linux_optdeps/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ dependencies:
- pyyaml
- scipy
- orjson
- ninja
- numpy<2.0.0
3 changes: 3 additions & 0 deletions .github/unittest/rl_linux_optdeps/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ pip install pip --upgrade

conda env update --file "${this_dir}/environment.yml" --prune

conda install anaconda::cmake -y
conda install -c conda-forge pybind11 -y

#yum makecache
#yum -y install glfw-devel
#yum -y install libGLEW
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-wheels-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ jobs:
include:
- repository: pytorch/tensordict
smoke-test-script: test/smoke_test.py
pre-script: .github/scripts/linux-pre-script.sh
post-script: .github/scripts/linux-post-script.sh
package-name: tensordict
name: pytorch/tensordict
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/build-wheels-m1.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ jobs:
include:
- repository: pytorch/tensordict
smoke-test-script: test/smoke_test.py
env-script: .github/scripts/install-deps-smoke-test.sh
pre-script: .github/scripts/linux-pre-script.sh
post-script: .github/scripts/linux-post-script.sh
package-name: tensordict
name: pytorch/tensordict
uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@main
Expand All @@ -44,7 +45,7 @@ jobs:
test-infra-ref: main
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
package-name: ${{ matrix.package-name }}
runner-type: macos-m1-stable
runner-type: macos-m2-15
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
env-var-script: .github/scripts/version_script.sh
2 changes: 1 addition & 1 deletion .github/workflows/build-wheels-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
matrix:
include:
- repository: pytorch/tensordict
pre-script: ""
pre-script: .github/scripts/win-pre-script.bat
env-script: .github/scripts/version_script.bat
post-script: "python packaging/wheel/relocate.py"
smoke-test-script: test/smoke_test.py
Expand Down
5 changes: 5 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ possible.
Install the library as suggested in the README. For advanced features,
it is preferable to install the nightly built of pytorch.

You will need the following packages to be installed:
```bash
pip install ninja cmake pybind11 -U
```

Make sure you install tensordict in develop mode by running
```
python setup.py develop
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ include LICENSE

recursive-exclude * __pycache__
recursive-exclude * *.py[co]
recursive-include tensordict *.so
103 changes: 56 additions & 47 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import argparse
import distutils.command.clean
import glob
import logging
import os
import shutil
Expand All @@ -15,11 +14,23 @@
from pathlib import Path
from typing import List

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext

ROOT_DIR = Path(__file__).parent.resolve()


def get_python_executable():
# Check if we're running in a virtual environment
if "VIRTUAL_ENV" in os.environ:
# Get the virtual environment's Python executable
python_executable = os.path.join(os.environ["VIRTUAL_ENV"], "bin", "python")
else:
# Fall back to sys.executable
python_executable = sys.executable
return python_executable


try:
sha = (
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=ROOT_DIR)
Expand Down Expand Up @@ -69,7 +80,7 @@ def _get_pytorch_version(is_nightly, is_local):
return "torch>=2.7.0.dev"
if is_local:
return "torch"
return "torch>=2.6.0"
return "torch>=2.5.0"


def _get_packages():
Expand Down Expand Up @@ -99,51 +110,45 @@ def run(self):
shutil.rmtree(str(path), ignore_errors=True)


def get_extensions():
extension = CppExtension

extra_link_args = []
extra_compile_args = {
"cxx": [
"-O3",
"-std=c++17",
"-fdiagnostics-color=always",
class CMakeExtension(Extension):
def __init__(self, name, sourcedir=""):
super().__init__(name, sources=[])
self.sourcedir = os.path.abspath(sourcedir)


class CMakeBuild(build_ext):
def run(self):
for ext in self.extensions:
self.build_extension(ext)

def build_extension(self, ext):
# extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
is_develop = self.distribution.get_command_obj("develop").finalized
# Set the output directory based on the mode
if is_develop:
extdir = os.path.abspath(os.path.join(ROOT_DIR, "tensordict"))
else:
extdir = os.path.abspath(os.path.join(self.build_lib, "tensordict"))
cmake_args = [
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
f"-DPYTHON_EXECUTABLE={get_python_executable()}",
f"-DPython3_EXECUTABLE={get_python_executable()}",
]
}
debug_mode = os.getenv("DEBUG", "0") == "1"
if debug_mode:
logging.info("Compiling in debug mode")
extra_compile_args = {
"cxx": [
"-O0",
"-fno-inline",
"-g",
"-std=c++17",
"-fdiagnostics-color=always",
]
}
extra_link_args = ["-O0", "-g"]

this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "tensordict", "csrc")

extension_sources = {
os.path.join(extensions_dir, p)
for p in glob.glob(os.path.join(extensions_dir, "*.cpp"))
}
sources = list(extension_sources)

ext_modules = [
extension(
"tensordict._C",
sources,
include_dirs=[this_dir],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,

build_args = []
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
subprocess.check_call(
["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp
)
subprocess.check_call(
["cmake", "--build", "."] + build_args, cwd=self.build_temp
)
]

return ext_modules

def get_extensions():
extensions_dir = os.path.join(ROOT_DIR, "tensordict", "csrc")
return [CMakeExtension("tensordict._C", sourcedir=extensions_dir)]


def _main(argv):
Expand Down Expand Up @@ -181,7 +186,7 @@ def _main(argv):
),
ext_modules=get_extensions(),
cmdclass={
"build_ext": BuildExtension.with_options(),
"build_ext": CMakeBuild,
"clean": clean,
},
install_requires=[
Expand Down Expand Up @@ -212,6 +217,10 @@ def _main(argv):
"Programming Language :: Python :: 3.13",
"Development Status :: 4 - Beta",
],
# include_package_data=True,
package_data={
"tensordict": ["*.so", "*.pyd"],
},
)


Expand Down
34 changes: 34 additions & 0 deletions tensordict/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
cmake_minimum_required(VERSION 3.12)
project(tensordict)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Set the Python executable to the one from your virtual environment

find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
find_package(pybind11 2.13 REQUIRED)

file(GLOB SOURCES "*.cpp")

add_library(_C MODULE ${SOURCES})

set_target_properties(_C PROPERTIES
OUTPUT_NAME "_C"
PREFIX "" # Remove 'lib' prefix
SUFFIX ".so" # Ensure correct suffix for macOS/Linux
)
set_target_properties(_C PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}"
)

target_include_directories(_C PRIVATE ${PROJECT_SOURCE_DIR})
target_link_libraries(_C PRIVATE Python3::Python pybind11::module)

if(CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0 -fsanitize=address")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
endif()
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum OS X deployment version")
set(CMAKE_VERBOSE_MAKEFILE ON)
Loading