Skip to content

Commit 0f97f26

Browse files
committed
improve import speed
1 parent 8d64d98 commit 0f97f26

File tree

3 files changed

+21
-10
lines changed

3 files changed

+21
-10
lines changed

python/jittor/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# This file is subject to the terms and conditions defined in
88
# file 'LICENSE.txt', which is part of this source code package.
99
# ***************************************************************
10-
__version__ = '1.2.0.1'
10+
__version__ = '1.2.0.2'
1111
from . import lock
1212
with lock.lock_scope():
1313
from . import compiler

python/jittor/compiler.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from jittor_utils import LOG, run_cmd, cache_path, find_exe, cc_path, cc_type, cache_path
1919
from . import pyjt_compiler
2020
from . import lock
21+
from jittor import __version__
2122

2223
def find_jittor_path():
2324
return os.path.dirname(__file__)
@@ -615,7 +616,7 @@ def compile_custom_ops(
615616
if len(gen_name) > 100:
616617
gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))
617618

618-
includes = set(includes)
619+
includes = sorted(list(set(includes)))
619620
includes = "".join(map(lambda x: f" -I'{x}' ", includes))
620621
LOG.vvvv(f"Include flags:{includes}")
621622

@@ -828,6 +829,8 @@ def check_debug_flags():
828829
check_debug_flags()
829830

830831
sys.path.append(cache_path)
832+
LOG.i(f"Jittor({__version__}) src: {jittor_path}")
833+
LOG.i(f"cache_path: {cache_path}")
831834

832835
with jit_utils.import_scope(import_flags):
833836
jit_utils.try_import_jit_utils_core()

python/jittor_utils/__init__.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -149,22 +149,30 @@ def do_compile(args):
149149

150150
pool_size = 0
151151

152+
def pool_cleanup():
153+
global p
154+
p.__exit__(None, None, None)
155+
del p
156+
152157
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
153-
global pool_size
158+
global pool_size, p
159+
bk = mp.current_process()._config.get('daemon')
160+
mp.current_process()._config['daemon'] = False
154161
if pool_size == 0:
155162
mem_bytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
156163
mem_gib = mem_bytes/(1024.**3)
157164
pool_size = min(16,max(int(mem_gib // 3), 1))
158165
LOG.i(f"Total mem: {mem_gib:.2f}GB, using {pool_size} procs for compiling.")
166+
p = Pool(pool_size)
167+
p.__enter__()
168+
import atexit
169+
atexit.register(pool_cleanup)
159170
cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ]
160-
bk = mp.current_process()._config.get('daemon')
161-
mp.current_process()._config['daemon'] = False
162171
try:
163-
with Pool(pool_size) as p:
164-
n = len(cmds)
165-
dp = DelayProgress(msg, n)
166-
for i,_ in enumerate(p.imap_unordered(do_compile, cmds)):
167-
dp.update(i)
172+
n = len(cmds)
173+
dp = DelayProgress(msg, n)
174+
for i,_ in enumerate(p.imap_unordered(do_compile, cmds)):
175+
dp.update(i)
168176
finally:
169177
mp.current_process()._config['daemon'] = bk
170178

0 commit comments

Comments
 (0)