Skip to content

Commit 1ad00d4

Browse files
committed
better error control && fix doc && fix free buffer
1 parent f9e2901 commit 1ad00d4

15 files changed

+112
-29
lines changed

doc/source/conf.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
author = 'Jittor'
2727

2828
# The full version, including alpha/beta/rc tags
29-
release = '1.1.3.1'
29+
release = jittor.__version__
30+
# fix AttributeError for "typing.get_type_hints(jt.Var)"
31+
jittor.Var.__module__ = "jittor_core"
3032

3133
# The language for content autogenerated by Sphinx. Refer to documentation
3234
# for a list of supported languages.

python/jittor/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# This file is subject to the terms and conditions defined in
88
# file 'LICENSE.txt', which is part of this source code package.
99
# ***************************************************************
10-
__version__ = '1.1.7.5'
10+
__version__ = '1.1.7.6'
1111
from . import lock
1212
with lock.lock_scope():
1313
from . import compiler

python/jittor/compiler.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ def check_debug_flags():
798798
global cc_flags
799799
cc_flags += " -g -DNODE_MEMCHECK "
800800

801-
cc_flags = " " + os.environ.get("cc_flags", "")
801+
cc_flags = " "
802802
# os.RTLD_NOW | os.RTLD_GLOBAL cause segfault when import torch first
803803
import_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
804804
# if cc_type=="icc":
@@ -841,6 +841,8 @@ def check_debug_flags():
841841

842842
cc_flags += " -Wall -Werror -Wno-unknown-pragmas -std=c++14 -fPIC -march=native "
843843
cc_flags += " -fdiagnostics-color=always "
844+
if "cc_flags" in os.environ:
845+
cc_flags += os.environ["cc_flags"] + ' '
844846
link_flags = " -lstdc++ -ldl -shared "
845847
core_link_flags = ""
846848
opt_flags = ""

python/jittor/misc.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -246,10 +246,11 @@ def unbind(x, dim=0):
246246
247247
Example:
248248
249-
jt.random((3,3))
249+
a = jt.random((3,3))
250+
b = jt.unbind(a, 0)
250251
251252
'''
252-
if dim < 0: dim += len(input.shape)
253+
if dim < 0: dim += len(x.shape)
253254
return [x[(slice(None),)*dim+(i,)] for i in range(x.shape[dim])]
254255

255256
def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0):
@@ -261,4 +262,4 @@ def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=Fals
261262
ncol = math.ceil(b / nrow)
262263
return x.reindex([c, h*ncol+(ncol+1)*padding, w*nrow+(nrow+1)*padding],
263264
[f"i1/{padding+h}*{nrow}+i2/{padding+w}", "i0",
264-
f"i1-i1/{padding+h}*{padding+h}-{padding}", f"i2-i2/{padding+w}*{padding+w}-{padding}"], overflow_value=pad_value)
265+
f"i1-i1/{padding+h}*{padding+h}-{padding}", f"i2-i2/{padding+w}*{padding+w}-{padding}"], overflow_value=pad_value)

python/jittor/nn.py

-2
Original file line numberDiff line numberDiff line change
@@ -734,13 +734,11 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
734734
Example:
735735
736736
>>> x = jt.array([[[[1,2],[3,4]]]])
737-
738737
>>> print(x)
739738
[[[[1 2]
740739
[3 4]]]]
741740
742741
>>> grid = jt.array([[[[0.5, 0.5]]]])
743-
744742
>>> print(x.shape, grid.shape)
745743
[1,1,2,2,], [1,1,2,2,]
746744
+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# ***************************************************************
2+
# Copyright (c) 2020 Jittor. Authors:
3+
# Meng-Hao Guo <[email protected]>
4+
# Dun Liang <[email protected]>.
5+
#
6+
# All Rights Reserved.
7+
# This file is subject to the terms and conditions defined in
8+
# file 'LICENSE.txt', which is part of this source code package.
9+
# ***************************************************************
10+
import jittor as jt
11+
import unittest
12+
import sys, os
13+
from subprocess import getoutput
14+
15+
class TestLazyExecution(unittest.TestCase):
16+
@unittest.skipIf(not jt.has_cuda, "No cuda found")
17+
def test_lazy_execution(self):
18+
code = """
19+
import jittor as jt
20+
jt.flags.use_cuda = 1
21+
22+
a = jt.zeros(1)
23+
b = jt.code([1], a.dtype, [a],
24+
cuda_header='''
25+
#include <assert.h>
26+
''',
27+
cuda_src='''
28+
__global__ void kernel(float32* a, float32* b) {
29+
b[0] = a[0];
30+
assert(a[0] == 1);
31+
}
32+
kernel<<<1,1>>>(in0_p, out0_p);
33+
''')
34+
c = a+b
35+
print(c)
36+
"""
37+
fpath = os.path.join(jt.flags.cache_path, "lazy_error.py")
38+
with open(fpath, 'w') as f:
39+
f.write(code)
40+
res = getoutput(f"{sys.executable} {fpath}")
41+
assert 'print(c)' in res
42+
res = getoutput(f"lazy_execution=0 {sys.executable} {fpath}")
43+
assert "''')" in res
44+
45+
46+
47+
if __name__ == "__main__":
48+
unittest.main()

python/jittor/test/test_numpy_code_op.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,14 @@
1010
from jittor import Function
1111
import jittor as jt
1212
import numpy
13-
import cupy
1413
import ctypes
1514
import sys
1615

16+
try:
17+
import cupy
18+
except:
19+
pass
20+
1721
class TestCodeOp(unittest.TestCase):
1822
def test_func(self):
1923
class Func(Function):

python/jittor/utils/polish.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from jittor import LOG
2121
from jittor.compiler import run_cmd
2222
from jittor_utils import translator
23+
import sys
2324

2425
jittor_path = os.path.realpath(os.path.join(jt.flags.jittor_path, "..", ".."))
2526

@@ -68,12 +69,14 @@
6869
env = f"cache_name=build/{cc_type}/{device} cc_path="
6970
cname = "g++" if cc_type=="g++" else "clang-8"
7071
env += cname
71-
env += " "
72+
# use core2 arch, avoid using avx instructions
73+
# TODO: support more archs, such as arm, or use ir(GIMPLE or LLVM)
74+
env += " cc_flags='-march=core2' "
7275
if device == "cpu":
7376
env += "nvcc_path='' "
7477
elif jt.flags.nvcc_path == "":
7578
env = "unset nvcc_path && " + env
76-
cmd = f"{env} python3.7 -c 'import jittor'"
79+
cmd = f"{env} {sys.executable} -c 'import jittor'"
7780
LOG.i("run cmd:", cmd)
7881
os.system(cmd)
7982
LOG.i("run cmd:", cmd)

python/jittor/version

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
a62b45d6caf9c1c18a9118630ec8a591c576e635
1+
f9e290160bead0d5892754da56b9ad63bc316320

src/event_queue.cc

+8-2
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,16 @@ void EventQueue::Worker::start() {
2525
}
2626

2727
void EventQueue::worker_caller() {
28-
event_queue.func();
28+
int status = OK;
29+
try {
30+
event_queue.func();
31+
} catch (const std::exception& e) {
32+
LOGe << "Catch error:\n" >> e.what();
33+
status = ERROR;
34+
}
2935
{
3036
std::lock_guard<std::mutex> l(event_queue.mtx);
31-
event_queue.run_sync_done = true;
37+
event_queue.run_sync_done = status;
3238
}
3339
}
3440

src/event_queue.h

+8-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
namespace jittor {
1313

1414
struct EventQueue {
15+
static constexpr int RUNNING = 0;
16+
static constexpr int OK = 1;
17+
static constexpr int ERROR = 2;
1518
typedef void(*Func)();
1619
struct Worker {
1720
Func todo;
@@ -39,7 +42,7 @@ struct EventQueue {
3942
std::condition_variable cv;
4043
std::mutex mtx;
4144
Func func;
42-
volatile bool run_sync_done;
45+
volatile int run_sync_done;
4346

4447
inline void flush() {
4548
list<Func> ts;
@@ -53,11 +56,11 @@ struct EventQueue {
5356

5457
static void worker_caller();
5558

56-
void run_sync(Func func) {
59+
int run_sync(Func func) {
5760
// send work to worker and do something by self
5861
std::unique_lock<std::mutex> l(mtx);
5962
this->func = func;
60-
run_sync_done = false;
63+
run_sync_done = RUNNING;
6164
// send func to worker
6265
worker.run(worker_caller);
6366
while (1) {
@@ -70,8 +73,8 @@ struct EventQueue {
7073
func();
7174
l.lock();
7275
// worker is finished
73-
if (run_sync_done)
74-
return;
76+
if (int ret = run_sync_done)
77+
return ret;
7578
}
7679
}
7780

src/executor.cc

+3-2
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
318318
}
319319

320320
// running
321+
SetupFreeBuffer setup_free_buffer;
321322
FusedOp fused_op;
322323
vector<Var*> outputs_bk;
323324
#ifdef HAS_CUDA
@@ -446,9 +447,9 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
446447
if (device_sync && use_cuda) {
447448
last_is_cuda = false;
448449
sync_times++;
449-
event_queue.run_sync([]() {
450+
CHECK(EventQueue::OK == event_queue.run_sync([]() {
450451
checkCudaErrors(cudaDeviceSynchronize());
451-
});
452+
}));
452453
}
453454
LOGvv << "cudaDeviceSynchronize times:" << sync_times << "/" <<queue.size();
454455
#endif

src/node.h

+23
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ namespace jittor {
1313

1414
extern unordered_map<void*, int64> lived_nodes;
1515
extern int64 total_node;
16+
extern int64 nt;
17+
extern vector<Node*> free_buffer;
1618

1719
struct NodeFlags {
1820
typedef uint16 nf_t;
@@ -186,6 +188,27 @@ struct Node {
186188
void set_stop_grad();
187189
};
188190

191+
struct SetupFreeBuffer {
192+
193+
bool outside;
194+
inline SetupFreeBuffer() {
195+
outside = !nt;
196+
if (outside) {
197+
nt = ++Node::tflag_count;
198+
}
199+
}
200+
201+
inline ~SetupFreeBuffer() {
202+
if (outside) {
203+
for (int i=0; i<free_buffer.size(); i++)
204+
delete free_buffer[i];
205+
free_buffer.clear();
206+
nt = 0;
207+
}
208+
}
209+
210+
};
211+
189212
std::ostream& operator<<(std::ostream& os, const Node* node);
190213

191214
} // jittor

src/var.cc

-7
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,6 @@ Var::Var(NanoVector shape, NanoString dtype)
3030
number_of_lived_vars++;
3131
numel();
3232
}
33-
Var::~Var() {
34-
if (mem_ptr != nullptr)
35-
allocator->free(mem_ptr, size, allocation);
36-
number_of_lived_vars--;
37-
if (flags.get(NodeFlags::_in_update_queue))
38-
update_queue.pop(this);
39-
}
4033

4134
string Var::to_string() {
4235
string s = dtype().to_cstring();

src/var.h

-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ struct Var : Node {
3737
inline Op* output(uint i) { return Node::output(i)->op(); }
3838

3939
Var(NanoVector shape, NanoString dtype);
40-
~Var();
4140

4241
string to_string();
4342
int64_t numel();

0 commit comments

Comments
 (0)