
Commit c792b23

version 1.2
1 parent d533f39 commit c792b23


48 files changed: +3083 −132 lines

extern/cuda/cub/ops/cub_arg_reduce_op.cc

+3 −1

@@ -49,6 +49,8 @@ void CubArgReduceOp::infer_shape() {
     if (keepdims) {
         shape.push_back(1);
     }
+    if (shape.size() == 0)
+        shape.push_back(1);
     y->set_shape(shape);
     y_key->set_shape(shape);
 }
@@ -104,4 +106,4 @@ void CubArgReduceOp::jit_run() {
 #endif // JIT_cuda
 #endif // JIT

-} // jittor
+} // jittor
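
The hunk above keeps a full arg-reduce from producing an empty output shape: if every dimension is reduced away, the result falls back to shape [1,]. A minimal sketch of the intended behaviour, assuming the generated arg_reduce binding and that it returns the (index, value) pair in that order:

    import jittor as jt

    x = jt.array([3.0, 1.0, 2.0])
    # Reducing over the only axis with keepdims=False used to leave an empty
    # shape; with the fallback both outputs are reported with shape [1,].
    index, value = x.arg_reduce("min", 0, False)
    print(index.shape, value.shape)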

extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc

+4 −4

@@ -42,8 +42,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
         shape[0], shape[1], shape[2], shape[3]));
 }

-CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
-    : x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation), groups(groups),
+CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
+    : x(x), dy(dy), kh(kh), kw(kw), stride(stride), padding(padding), dilation(dilation), groups(groups),
     xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
     flags.set(NodeFlags::_cuda, 1);
     flags.set(NodeFlags::_cpu, 0);
@@ -57,8 +57,8 @@ void CudnnConvBackwardWOp::infer_shape() {
     get_shape(x, "abcd", xformat, xn, xc, xh, xw);
     get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
     wco = yc, wci = xc / groups;
-    wh = kernel_size;
-    ww = kernel_size;
+    wh = kh;
+    ww = kw;
     set_shape(dw, "oihw", wformat, wco, wci, wh, ww);
 }

extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.h

+3 −3

@@ -13,14 +13,14 @@ namespace jittor {

 struct CudnnConvBackwardWOp : Op {
     Var* x, * dy, * dw;
-    int kernel_size, stride, padding, dilation, groups;
+    int kh, kw, stride, padding, dilation, groups;
     string xformat, wformat, yformat;

-    CudnnConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
+    CudnnConvBackwardWOp(Var* x, Var* y, int kh, int kw, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");

     const char* name() const override { return "cudnn_conv_backward_w"; }
     void infer_shape() override;
     DECLARE_jit_run;
 };

-} // jittor
+} // jittor

extern/mkl/ops/mkl_conv_backward_w_op.cc

+9 −8

@@ -45,8 +45,8 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
         shape[0], shape[1], shape[2], shape[3]));
 }

-MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
-    : x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation), groups(groups),
+MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
+    : x(x), dy(dy), kh(kh), kw(kw), stride(stride), padding(padding), dilation(dilation), groups(groups),
     xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
     dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
 }
@@ -58,8 +58,8 @@ void MklConvBackwardWOp::infer_shape() {
     get_shape(x, "abcd", xformat, xn, xc, xh, xw);
     get_shape(dy, "abcd", yformat, yn, yc, yh, yw);
     wco = yc, wci = xc / groups;
-    wh = kernel_size;
-    ww = kernel_size;
+    wh = kh;
+    ww = kw;
     set_shape(dw, "oihw", wformat, wco, wci, wh, ww);
 }

@@ -97,7 +97,8 @@ void MklConvBackwardWOp::jit_run() {
     int height = x->shape[findc("@XFORMAT",'c')];
     int width = x->shape[findc("@XFORMAT",'d')];
     int ch_out = dw->shape[findc("@WFORMAT",'o')];
-    int kernel_size = dw->shape[findc("@WFORMAT",'h')];
+    int kh = dw->shape[findc("@WFORMAT",'h')];
+    int kw = dw->shape[findc("@WFORMAT",'w')];

     auto* __restrict__ net_src = x->ptr<Txd>();
     auto* __restrict__ net_diff_dst = dy->ptr<Tyd>();
@@ -114,9 +115,9 @@ void MklConvBackwardWOp::jit_run() {

     memory::dims conv_src_tz = {batch, ch_in, height, width};
     memory::dims conv_weights_tz = groups>1
-        ? memory::dims{groups, ch_out/groups, ch_in/groups, kernel_size, kernel_size}
-        : memory::dims{ch_out, ch_in, kernel_size, kernel_size};
-    memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kernel_size*dilation+dilation-1)/stride+1, (width+padding*2-kernel_size*dilation+dilation-1)/stride+1};
+        ? memory::dims{groups, ch_out/groups, ch_in/groups, kh, kw}
+        : memory::dims{ch_out, ch_in, kh, kw};
+    memory::dims conv_dst_tz = {batch, ch_out, (height+padding*2-kh*dilation+dilation-1)/stride+1, (width+padding*2-kw*dilation+dilation-1)/stride+1};
     memory::dims conv_strides = {stride, stride};
     memory::dims conv_padding = {padding, padding};
     memory::dims conv_dilation = {dilation-1, dilation-1};

extern/mkl/ops/mkl_conv_backward_w_op.h

+3 −3

@@ -13,14 +13,14 @@ namespace jittor {

 struct MklConvBackwardWOp : Op {
     Var* x, * dy, * dw;
-    int kernel_size, stride, padding, dilation, groups;
+    int kh, kw, stride, padding, dilation, groups;
     string xformat, wformat, yformat;

-    MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");
+    MklConvBackwardWOp(Var* x, Var* y, int kh, int kw, int stride, int padding, int dilation, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd");

     const char* name() const override { return "mkl_conv_backward_w"; }
     void infer_shape() override;
     DECLARE_jit_run;
 };

-} // jittor
+} // jittor
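
Across the cuDNN and MKL backward-weight ops above, the single kernel_size argument is split into kh and kw, so rectangular kernels can carry weight gradients. A hedged, hypothetical Python-level sketch (it assumes nn.Conv accepts a (kh, kw) tuple once this change is wired through the rest of the commit):

    import jittor as jt
    from jittor import nn

    # Hypothetical example: a non-square 3x5 kernel. Its weight gradient is
    # computed by the conv_backward_w ops patched above, shaped [8, 3, 3, 5].
    conv = nn.Conv(3, 8, kernel_size=(3, 5), padding=1)
    x = jt.random((2, 3, 32, 32))
    loss = conv(x).sum()
    dw = jt.grad(loss, [conv.weight])[0]
    print(dw.shape)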

python/jittor/__init__.py

+28 −5

@@ -7,7 +7,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.1.7.20'
+__version__ = '1.2.0.0'
 from . import lock
 with lock.lock_scope():
     from . import compiler
@@ -233,11 +233,22 @@ def ones(shape, dtype="float32"):
         shape = (shape,)
     return unary(1, dtype).broadcast(shape)

+def ones_like(x):
+    return ones(x.shape,x.dtype)
+
 def zeros(shape, dtype="float32"):
     if not isinstance(shape, (NanoVector, Sequence)):
         shape = (shape,)
     return unary(0, dtype).broadcast(shape)

+def full(shape,val,dtype="float32"):
+    if not isinstance(shape, (NanoVector, Sequence)):
+        shape = (shape,)
+    return unary(val, dtype).broadcast(shape)
+
+def zeros_like(x):
+    return zeros(x.shape,x.dtype)
+
 flags = core.flags()

 def std(x):
@@ -311,9 +322,17 @@ def squeeze(x, dim):
     return x.reshape(shape[:dim] + shape[dim+1:])
 Var.squeeze = squeeze

-def clamp(x, min_v, max_v):
-    assert min_v <= max_v
-    return x.maximum(min_v).minimum(max_v)
+def clamp(x, min_v=None, max_v=None):
+    if x.shape[0]==0:
+        return x
+    if min_v is not None and max_v is not None:
+        assert min_v <= max_v
+    if min_v is not None:
+        x = x.maximum(min_v)
+    if max_v is not None:
+        x = x.minimum(max_v)
+    return x
+
 Var.clamp = clamp

 def type_as(a, b):
@@ -574,6 +593,8 @@ def load_parameters(self, params):
             else:
                 if hasattr(v, k):
                     v = getattr(v, k)
+                    assert isinstance(v, (Module, Var)), \
+                        f"expect a jittor Module or Var, but got <{v.__class__.__name__}>, key: {key}"
                 else:
                     end = 1
                     break
@@ -582,6 +603,8 @@ def load_parameters(self, params):
             n_failed += 1
             LOG.w(f'load parameter {key} failed ...')
         else:
+            assert isinstance(v, Var), \
+                f"expect a jittor Var, but got <{v.__class__.__name__}>, key: {key}"
             LOG.v(f'load parameter {key} success ...')
             if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
                 v.update(array(params[key]))
@@ -872,4 +895,4 @@ def to_bool(v):
 from . import contrib
 from . import numpy2cupy
 from .contrib import concat
-from .misc import *
+from .misc import *
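
The Python-facing additions above (ones_like, zeros_like, full, and the optional bounds on clamp) can be exercised roughly like this:

    import jittor as jt

    x = jt.array([[1.5, -2.0], [0.3, 4.0]])
    a = jt.ones_like(x)        # same shape and dtype as x, filled with ones
    b = jt.zeros_like(x)       # same shape and dtype as x, filled with zeros
    c = jt.full((2, 2), 7.0)   # constant-filled tensor of a given shape
    d = x.clamp(min_v=-1.0)    # one-sided clamp is now allowed
    e = x.clamp(-1.0, 1.0)     # both bounds, as before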

python/jittor/compiler.py

+11 −2

@@ -241,7 +241,7 @@ def add_src(
     if "multiple_outputs" not in attrs:
         jit_cc_src.append(f"""
     VarPtr make_{cc_func_name}({", ".join(cc_make_args)}) {{
-        Op* _op = new {op_name}({", ".join(op_make_args)});
+        auto _op = new {op_name}({", ".join(op_make_args)});
         if (_op->outputs_holder.size() != 1) {{
             delete _op;
             LOGf << "Wrong output size of" << \"{op_name}\";
@@ -261,7 +261,7 @@ def add_src(
     else:
         jit_cc_src.append(f"""
     vector<VarPtr> make_{cc_func_name}({", ".join(cc_make_args)}) {{
-        Op* _op = new {op_name}({", ".join(op_make_args)});
+        auto _op = new {op_name}({", ".join(op_make_args)});
         if (_op->flags.get(NodeFlags::_forwarded)) {{
             vector<VarPtr> outputs = move(_op->outputs_holder);
             delete _op;
@@ -408,6 +408,15 @@ def add_src(
                 arg_type.replace("Var", "VarHolder")+' '+arg)
             new_args.append(arg)
             more_src.append(f"_op->add_inputs({arg});")
+        elif arg_type.startswith("VarSlices"):
+            new_args_def.append(arg_def)
+            new_args.append(arg)
+            more_src.append(f"""
+                vector<Var*> svars;
+                for (int i=0; i<_op->vs.n; i++)
+                    if (_op->vs.slices[i].is_var())
+                        svars.push_back(_op->vs.slices[i].var);
+                _op->add_inputs(svars);""")
         else:
             new_args_def.append(arg_def)
             new_args.append(arg)

python/jittor/contrib.py

+70 −1

@@ -42,7 +42,7 @@ def concat(arr, dim):
             indexes[dim] = f"i{dim}-{cdim}"
         b = a.reindex(shape, indexes)
         # ugly fix for preventing large fused op
-        if len(arr)>=10:
+        if len(arr)>=100:
            b.stop_fuse()
         if s is None:
             s = b
@@ -99,6 +99,20 @@ def slice_var_index(x, slices):
     cnt_list = 0
     extras_idx = []
     extras = []
+    has_ellipse = 0
+    ellipse_index = 0
+    for s,i in zip(slices,range(len(slices))):
+        if isinstance(s,type(...)):
+            has_ellipse+=1
+            ellipse_index = i
+    if has_ellipse>1:
+        raise Exception(f"There are more than one ...")
+    elif has_ellipse==1:
+        slices = list(slices)
+        del slices[ellipse_index]
+        while len(slices)<len(shape):
+            slices.insert(ellipse_index,slice(None))
+
     for i in range(len(shape)):
         if i>=len(slices):
             s = slice(None)
@@ -119,6 +133,7 @@ def slice_var_index(x, slices):
             step = 1 if s.step is None else s.step
             if start<0: start += sp
             if stop<0: stop += sp
+            if stop>sp+1: stop = sp
             out_shape.append(1+int(max(0, (stop-start-1)//step)))
             out_index.append(f"{start}+i{j}*{step}")
         elif isinstance(s, jt.Var):
@@ -160,3 +175,57 @@ def setitem(x, slices, value):

 jt.Var.__getitem__ = jt.Var.slice_var = slice_var
 jt.Var.__setitem__ = setitem
+
+# PATCH
+def getitem(x, slices):
+    if isinstance(slices, jt.Var) and slices.dtype == "bool":
+        return getitem(x, slices.where())
+    if isinstance(slices, list):
+        slices = tuple(slices)
+    return x.getitem(slices)
+
+def setitem(x, slices, value):
+    if isinstance(slices, jt.Var) and slices.dtype == "bool":
+        mask = jt.broadcast(slices, x)
+        value = jt.broadcast(value, x)
+        return mask.ternary(value, mask)
+    if isinstance(slices, list):
+        slices = tuple(slices)
+    return x.assign(x.setitem(slices, value))
+
+jt.Var.__getitem__ = jt.Var.slice_var = getitem
+jt.Var.__setitem__ = setitem
+
+def concat(arr, dim):
+    '''Concat Operator can concat a list of jt Var at a specific dimension.
+
+    * [in] x: input var list for concat
+
+    * [in] dim: concat which dim
+
+    * [out] out: concat result
+
+    Example::
+
+        jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
+        # return [[1,2],[2,2]]
+    '''
+    # TODO: low performance when concat lots of vars
+    total_dim = 0
+    if dim < 0: dim += len(arr[0].shape)
+    for a in arr:
+        total_dim += a.shape[dim]
+    cdim = 0
+    shape = list(a.shape)
+    shape[dim] = total_dim
+    s = jt.empty(shape, a.dtype)
+    slices = [slice(None)]*len(a.shape)
+    for a in arr:
+        if a.shape[dim] == 0:
+            continue
+        slices[dim] = slice(cdim, cdim+a.shape[dim])
+        # print(slices, type(a))
+        s = s.setitem(tuple(slices), a)
+        # s = jt.setitem(s, tuple(slices), a)
+        cdim += a.shape[dim]
+    return s
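
A short, hedged sketch of the indexing behaviour introduced above: boolean-mask __getitem__ routed through where(), Ellipsis expansion in slices (assuming the Ellipsis handling also applies through the new getitem path), and the setitem-based concat:

    import jittor as jt

    x = jt.array([[1.0, 2.0], [3.0, 4.0]])
    mask = x > 2.0
    picked = x[mask]                    # bool mask is converted via where()
    y = jt.random((2, 3, 4, 5))
    z = y[..., 0]                       # ... expands to the missing dimensions
    both = jt.concat([x, x], dim=0)     # concat now fills an empty buffer via setitem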

python/jittor/init.py

+40 −1

@@ -56,6 +56,45 @@ def relu_invariant_gauss(shape, dtype, mode="fan_in"):
 def relu_invariant_gauss_(var, mode="fan_in"):
     var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode))

+def calculate_std(var,mode,nonlinearity,param=0.01):
+    mode = mode.lower()
+    assert isinstance(param,(int,float))
+    assert var.ndim>=2
+    assert mode in ['fan_in', 'fan_out']
+
+    fan = var.shape[1] if mode == 'fan_in' else var.shape[0]
+    fan *= var[0][0].numel()
+
+    gains = {
+        'linear':1,
+        'conv1d':1,
+        'conv2d':1,
+        'conv3d':1,
+        'conv_transpose1d':1,
+        'conv_transpose2d':1,
+        'conv_transpose3d':1,
+        'sigmoid':1,
+        'tanh':5.0/3,
+        'relu':math.sqrt(2.0),
+        'leaky_relu':math.sqrt(2.0 / (1 + param ** 2)),
+    }
+    gain = gains[nonlinearity]
+    std = gain/math.sqrt(fan)
+    return std
+
+
+def kaiming_uniform_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'):
+    std = calculate_std(var,mode,nonlinearity,a)
+    bound = math.sqrt(3.0) * std
+    with jt.no_grad():
+        return uniform_(var,-bound, bound)
+
+def kaiming_normal_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'):
+    std = calculate_std(var,mode,nonlinearity,a)
+    with jt.no_grad():
+        return gauss_(var,0, std)
+
+
 #TODO: bound = gain * math.sqrt(6.0/fan) ??
 def xavier_uniform(shape, dtype, gain=1.0):
     assert len(shape)>1
@@ -81,4 +120,4 @@ def xavier_gauss(shape, dtype, gain=1.0):
     return gauss(shape, dtype, 0, std)

 def xavier_gauss_(var, gain=1.0):
-    var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain))
+    var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain))
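
The new Kaiming initializers above follow the usual fan-in/fan-out convention. A minimal usage sketch, assuming jt.empty (as used by the new concat in this commit) and the in-place uniform_/gauss_ helpers already defined in jittor.init:

    import jittor as jt
    from jittor import init

    w = jt.empty((64, 32, 3, 3))
    init.kaiming_uniform_(w, a=0, mode='fan_in', nonlinearity='relu')
    init.kaiming_normal_(w, mode='fan_out', nonlinearity='leaky_relu')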
