Skip to content

Commit e25a8ac

Browse files
committed
fix numpy code op test
1 parent ed65ba7 commit e25a8ac

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

python/jittor/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# This file is subject to the terms and conditions defined in
88
# file 'LICENSE.txt', which is part of this source code package.
99
# ***************************************************************
10-
__version__ = '1.1.7.4'
10+
__version__ = '1.1.7.5'
1111
from . import lock
1212
with lock.lock_scope():
1313
from . import compiler

python/jittor/test/test_numpy_code_op.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ def check():
5858
one=numpy.ones(a.shape)
5959
assert numpy.allclose(da.data,one*2.0)
6060

61-
jt.flags.use_cuda = 0
62-
check()
63-
jt.flags.use_cuda = 1
61+
if jt.has_cuda:
62+
with jt.flag_scope(use_cuda=1):
63+
check()
6464
check()
6565

6666
def test(self):
@@ -92,9 +92,9 @@ def check():
9292
one=numpy.ones(a.shape)
9393
assert numpy.allclose(da.data,one*2.0)
9494

95-
jt.flags.use_cuda = 0
96-
check()
97-
jt.flags.use_cuda = 1
95+
if jt.has_cuda:
96+
with jt.flag_scope(use_cuda=1):
97+
check()
9898
check()
9999

100100
def test_multi_input(self):
@@ -139,9 +139,9 @@ def check():
139139
assert numpy.allclose(dda.data,one)
140140
assert numpy.allclose(ddb.data,mone)
141141

142-
jt.flags.use_cuda = 0
143-
check()
144-
jt.flags.use_cuda = 1
142+
if jt.has_cuda:
143+
with jt.flag_scope(use_cuda=1):
144+
check()
145145
check()
146146

147147
@unittest.skipIf(True, "Memory leak testing is not in progress, Skip")

python/jittor/test/test_parallel_pass.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def check(self, use_int32):
3636
b = jt.random((n, n))
3737
a.data, b.data
3838
with jt.profile_scope(compile_options = {
39-
"compile_shapes":1, "parallel":1
39+
"compile_shapes":1, "parallel":2, "try_use_32bit_index":use_int32
4040
}, try_use_32bit_index = use_int32) as rep:
4141
c = a + b
4242
nc = c.data

0 commit comments

Comments
 (0)