Skip to content

Add PTv3 base code (as well as what we'll need to change for DiPTv3 #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: trailing-whitespace
exclude: "tests/testdata/"
- id: end-of-file-fixer
- id: end-of-file-fixer
exclude: "tests/testdata/"
- id: check-yaml
- id: check-added-large-files
args: ['--maxkb=5000']
- repo: https://github.com/kynan/nbstripout
- id: check-yaml
- id: check-added-large-files
args: ["--maxkb=5000"]
- repo: https://github.com/kynan/nbstripout
rev: 0.6.0
hooks:
- id: nbstripout
- repo: https://github.com/myint/autoflake
- id: nbstripout
- repo: https://github.com/myint/autoflake
rev: v2.1.1
hooks:
- id: autoflake
- id: autoflake
args:
- --in-place
- --remove-all-unused-imports
- repo: https://github.com/timothycrosley/isort
rev: 5.12.0
- --in-place
- --remove-all-unused-imports
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/psf/black
- id: isort
name: isort (python)
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
- id: black
49 changes: 46 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,71 @@ List of networks implemented:

*All of the following should be done inside your venv or conda environment.*

Strongly recommend using a conda env, because it makes installing flash attention MUCH easier.

## Prerequisites

### DiPTv3

1. Choose a CUDA version you'll want to use! Should be compatible with everything... CUDA 12.4 is good, but so is 11.8.

2. Install CUDA toolkit inside your conda env, so that nvcc is available. For instance:

```bash
conda install -c "nvidia/label/cuda-12.4.0" cuda-toolkit
```

2. Choose a torch version (including GPU) and install it. See [here](https://pytorch.org/get-started/locally/) for instructions.
For instance:

```bash
conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
```
3. Install an appropriate torch_scatter version. See [here](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)
For instance:

````bash
pip install torch_scatter -f https://data.pyg.org/whl/torch-2.5.0+cu124.html
````

4. Install spconv.

```bash
# spconv (SparseUNet)
# refer https://github.com/traveller59/spconv
pip install spconv-cu124 # choose version match your local cuda version
```

5. Install flash_attention.

```bash
pip install flash-attn --no-build-isolation
```

## Option 1: No changes needed (simplest).

If you want to use the architecutres as-is (or their building blocks), you can just do the following:

````bash

pip install git+hhttps://github.com/r-pad/nets.git#egg=nets@master
export NETS_MODULE=diptv3
pip install 'git+https://github.com/r-pad/nets.git#egg=nets[$NETS_MODULE]@main'

````

You can replace `master` with a specific branch or tag.
You can replace `main` with a specific branch or tag.

## Option 2: Need to be able to modify the code.

If you want to be able to modify the code, but don't want to contribute back, you can do the following:

````bash

export NETS_MODULE=diptv3
cd $CODE_DIRECTORY
git clone https://github.com/r-pad/nets.git
cd nets
pip install -e .
pip install -e '.[$NETS_MODULE]'

````

Expand Down
44 changes: 37 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ readme = "README.md"
requires-python = ">=3.6"
license = { file = "LICENSE.txt" }
authors = [{ email = "[email protected]", name = "Ben Eisner" }]
dependencies = []
dependencies = ["torch"]

[build-system]
requires = ["setuptools >= 62.3.2", "setuptools-scm", "wheel"]
Expand All @@ -22,29 +22,59 @@ develop = [
"pytest == 7.3.2",
"pre-commit == 3.3.3",
]
diptv3 = [
"addict",
"timm",
"torch_scatter",

# flash_attention is required, but you need special build config.

# spconv is also required, but instead of something sane it has https://github.com/traveller59/spconv
# where you have to install a per-cuda package. smh.
]
notebooks = ["jupyter"]
build_docs = ["mkdocs-material", "mkdocstrings[python]"]

# This is required to allow us to have notebooks/ at the top level.
[tool.setuptools.packages.find]
where = ["src"]
where = ["src", "third_party"]

[tool.black]
exclude = "third_party/*,src/rpad/nets/diptv3.py"

[tool.setuptools.package-data]
rpad = ["py.typed"]

[tool.isort]
profile = "black"
skip = ["third_party"]

[tool.autoflake]
# exclude third_party and src/rpad/nets/diptv3.py
exclude = "third_party/*"

[tool.pytest.ini_options]
addopts = "--ignore=third_party/"
testpaths = "tests"

[tool.mypy]
python_version = 3.8
python_version = "3.8"
warn_return_any = true
warn_unused_configs = true
mypy_path = "src"
namespace_packages = true
explicit_package_bases = true

# # Uncomment this when you have imports for mypy to ignore.
# [[tool.mypy.overrides]]
# module = [
# ]
# ignore_missing_imports = true
[[tool.mypy.overrides]]
module = [
"addict.*",
"PointTransformerV3.*",
"spconv.*",
"timm.*",
"torch_scatter.*",
]
ignore_missing_imports = true

[tool.pyright]
extraPaths = ['./third_party']
Loading
Loading