-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathsetup.py
More file actions
executable file
·140 lines (119 loc) · 4.45 KB
/
setup.py
File metadata and controls
executable file
·140 lines (119 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import glob
import os
import re
import sys
import warnings
import pkg_resources
from setuptools import find_packages, setup
# Build switches. CUDA builds are disabled in this script (the FORCE_CUDA
# environment override is not read), so at most the C++ extension is enabled.
BUILD_CPP = False
BUILD_CUDA = False
TORCH_VERSION = 0

try:
    import torch

    print(f"setup.py with torch {torch.__version__}")
    from torch.utils.cpp_extension import BuildExtension, CppExtension

    # The C++ extension builder imported successfully, so C++ builds are on
    # even if one of the later steps in this try-block fails.
    BUILD_CPP = True
    from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension

    # CUDA not currently supported
    BUILD_CUDA = False

    # Fold the (major, minor, patch) release tuple into one integer,
    # i.e. major * 10000 + minor * 100 + patch.
    _pt_version = pkg_resources.parse_version(torch.__version__)._version.release
    if _pt_version is None or len(_pt_version) < 3:
        raise AssertionError("unknown torch version")
    TORCH_VERSION = sum(
        int(part) * weight for part, weight in zip(_pt_version, (10000, 100, 1))
    )
except (ImportError, TypeError, AssertionError, AttributeError) as err:
    warnings.warn(f"extension build skipped: {err}")
finally:
    # Always report the resolved configuration, whether detection succeeded or not.
    print(f"BUILD_CPP={BUILD_CPP}, BUILD_CUDA={BUILD_CUDA}, TORCH_VERSION={TORCH_VERSION}.")
def torch_parallel_backend():
    """Map torch's reported ATen parallel backend to a macro name.

    The backend is read from the ``ATen parallel backend: ...`` line of
    ``torch._C._parallel_info()``.

    Returns:
        str or None: ``"AT_PARALLEL_OPENMP"``, ``"AT_PARALLEL_NATIVE"`` or
        ``"AT_PARALLEL_NATIVE_TBB"`` for the recognised backends; ``None`` when
        the line is absent, the backend is not recognised, or torch (or its
        binaries) is unavailable. In the last case a warning is emitted.
    """
    macro_by_backend = {
        "OpenMP": "AT_PARALLEL_OPENMP",
        "native thread pool": "AT_PARALLEL_NATIVE",
        "native thread pool and TBB": "AT_PARALLEL_NATIVE_TBB",
    }

    try:
        parallel_info = torch._C._parallel_info()
    except (NameError, AttributeError):  # no torch or no binaries
        warnings.warn("Could not determine torch parallel_info.")
        return None

    found = re.search(
        "^ATen parallel backend: (?P<backend>.*)$",
        parallel_info,
        re.MULTILINE,
    )
    if found is None:
        return None
    # Unrecognised backend names fall through to None.
    return macro_by_backend.get(found.group("backend"))
def omp_flags():
    """Return the flags used to enable OpenMP on the current platform.

    Returns:
        list: ``["/openmp"]`` on Windows, an empty list on macOS, and
        ``["-fopenmp"]`` on every other platform. A new list is created on
        each call.
    """
    platform = sys.platform
    if platform == "darwin":
        # No OpenMP flag is passed on macOS, see
        # https://stackoverflow.com/questions/37362414/
        # (the "-fopenmp=libiomp5" alternative is not used).
        flags = []
    elif platform == "win32":
        flags = ["/openmp"]
    else:
        flags = ["-fopenmp"]
    return flags
def get_extensions():
    """Assemble the native extension modules to build for torchmaxflow.

    Sources are collected recursively from the ``torchmaxflow`` directory
    (``*.cpp`` always, ``*.cu`` only when ``BUILD_CUDA`` is set). The extension
    class, preprocessor macros and compiler/linker flags are chosen from the
    module-level ``BUILD_CPP`` / ``BUILD_CUDA`` switches and from torch's
    parallel backend.

    Returns:
        list: a one-element list holding the ``torchmaxflowcpp`` extension, or
        an empty list when no extension builder is enabled or no sources were
        found.
    """
    ext_dir = "torchmaxflow"
    include_dirs = [ext_dir]

    source_cpu = glob.glob(os.path.join(ext_dir, "**", "*.cpp"), recursive=True)
    source_cuda = glob.glob(os.path.join(ext_dir, "**", "*.cu"), recursive=True)
    print(source_cpu)

    # Query the backend once; it drives both the macro and the OpenMP flags.
    parallel_backend = torch_parallel_backend()
    use_openmp = parallel_backend == "AT_PARALLEL_OPENMP"

    # Only define the backend macro when the backend is known. Formatting an
    # unknown backend (None) into the macro name would pass -DNone=1 to the
    # compiler, which can clash with identifiers named `None` in C++ headers.
    define_macros = [(parallel_backend, 1)] if parallel_backend is not None else []

    extension = None
    extra_compile_args = {}
    extra_link_args = []
    # Copy so that appending the CUDA sources below does not mutate source_cpu.
    sources = list(source_cpu)

    if BUILD_CPP:
        extension = CppExtension
        extra_compile_args.setdefault("cxx", [])
        if use_openmp:
            extra_compile_args["cxx"] += omp_flags()
        # Link flags are set whenever the C++ extension is built, independent
        # of the detected backend.
        extra_link_args = omp_flags()
    if BUILD_CUDA:
        # The CUDA build replaces the C++ extension class and compile-arg dict.
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [("WITH_CUDA", None)]
        extra_compile_args = {"cxx": [], "nvcc": []}
        if use_openmp:
            extra_compile_args["cxx"] += omp_flags()

    if extension is None or not sources:
        return []  # compile nothing

    # compile release
    extra_compile_args["cxx"] += ["-g0"]

    ext_modules = [
        extension(
            name="torchmaxflowcpp",
            sources=sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
        )
    ]
    return ext_modules
# The long description is the README verbatim. Read it as UTF-8 explicitly so
# the result does not depend on the platform's default encoding.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

# Resolve the extensions once, so the list that is printed is exactly the list
# handed to setup().
ext_modules = get_extensions()
print(ext_modules)

# One requirement per line of requirements.txt.
with open("requirements.txt", "r", encoding="utf-8") as fp:
    install_requires = fp.read().splitlines()

# BuildExtension is only bound when the torch import earlier in this file
# succeeded (BUILD_CPP is set right after that import). Referencing it
# unconditionally would raise NameError on machines without torch, even though
# the extension build has already been skipped in that case.
cmdclass = {"build_ext": BuildExtension} if (BUILD_CPP or BUILD_CUDA) else {}

setup(
    name="torchmaxflow",
    version="0.0.7",
    description="torchmaxflow: Max-flow/Min-cut in PyTorch for 2D images and 3D volumes",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/masadcv/torchmaxflow",
    author="Muhammad Asad",
    # NOTE(review): this value looks like a web-scrape placeholder rather than
    # a real address - confirm against the upstream repository.
    author_email="[email protected]",
    license="BSD-3-Clause License",
    classifiers=[
        "License :: OSI Approved :: BSD License",
        "Programming Language :: Python :: 3",
    ],
    install_requires=install_requires,
    cmdclass=cmdclass,
    packages=find_packages(exclude=("data", "docs", "examples", "scripts", "tests")),
    # zip_safe=False,
    ext_modules=ext_modules,
)