
Commit 8d64d98

committed
jit_prepare for array op
1 parent 6bc4a85 commit 8d64d98

14 files changed: +65 -64 lines changed

python/jittor/__init__.py (+1 -1)

@@ -7,7 +7,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.0.0'
+__version__ = '1.2.0.1'
 from . import lock
 with lock.lock_scope():
     from . import compiler

python/jittor/nn.py (+27 -44)

@@ -342,62 +342,45 @@ def execute(self, x):
 
 class BatchNorm1d(Module):
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
-        assert affine == None
         self.sync = sync
         self.num_features = num_features
         self.is_train = is_train
         self.eps = eps
         self.momentum = momentum
-        self.weight = init.constant((num_features,), "float32", 1.0)
-        self.bias = init.constant((num_features,), "float32", 0.0)
+        self.affine = affine
+        if affine:
+            self.weight = init.constant((num_features,), "float32", 1.0)
+            self.bias = init.constant((num_features,), "float32", 0.0)
         self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
         self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
 
     def execute(self, x):
         if len(x.shape) == 3:
-            if self.is_train:
-                xmean = jt.mean(x, dims=[0, 2], keepdims=1)
-                x2mean = jt.mean(x*x, dims=[0, 2], keepdims=1)
-
-                if self.sync and jt.in_mpi:
-                    xmean = xmean.mpi_all_reduce("mean")
-                    x2mean = x2mean.mpi_all_reduce("mean")
-
-                xvar = x2mean-xmean*xmean
-                norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-                self.running_mean.update(self.running_mean +
-                    (xmean.sum([0, 2])-self.running_mean)*self.momentum)
-                self.running_var.update(self.running_var +
-                    (xvar.sum([0, 2])-self.running_var)*self.momentum)
-            else:
-                running_mean = self.running_mean.broadcast(x, [0, 2])
-                running_var = self.running_var.broadcast(x, [0, 2])
-                norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
-            w = self.weight.broadcast(x, [0, 2])
-            b = self.bias.broadcast(x, [0, 2])
-        else:
-            if self.is_train:
-                xmean = jt.mean(x, dims=[0], keepdims=1)
-                x2mean = jt.mean(x*x, dims=[0], keepdims=1)
-
-                if self.sync and jt.in_mpi:
-                    xmean = xmean.mpi_all_reduce("mean")
-                    x2mean = x2mean.mpi_all_reduce("mean")
-
-                xvar = x2mean-xmean*xmean
-                norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-                self.running_mean.update(self.running_mean +
-                    (xmean.sum([0])-self.running_mean)*self.momentum)
-                self.running_var.update(self.running_var +
-                    (xvar.sum([0])-self.running_var)*self.momentum)
-            else:
-                running_mean = self.running_mean.broadcast(x, [0])
-                running_var = self.running_var.broadcast(x, [0])
-                norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
-            w = self.weight.broadcast(x, [0])
-            b = self.bias.broadcast(x, [0])
+            dims = [0, 2]
+        else:
+            dims = [0]
+        if self.is_train:
+            xmean = jt.mean(x, dims=dims, keepdims=1)
+            x2mean = jt.mean(x*x, dims=dims, keepdims=1)
+
+            if self.sync and jt.in_mpi:
+                xmean = xmean.mpi_all_reduce("mean")
+                x2mean = x2mean.mpi_all_reduce("mean")
+
+            xvar = x2mean-xmean*xmean
+            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
+            self.running_mean.update(self.running_mean +
+                (xmean.sum(dims)-self.running_mean)*self.momentum)
+            self.running_var.update(self.running_var +
+                (xvar.sum(dims)-self.running_var)*self.momentum)
+        else:
+            running_mean = self.running_mean.broadcast(x, dims)
+            running_var = self.running_var.broadcast(x, dims)
+            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
         if not self.affine:
             return norm_x
+        w = self.weight.broadcast(x, dims)
+        b = self.bias.broadcast(x, dims)
         return norm_x * w + b
 
 class InstanceNorm2d(Module):
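
A small usage sketch of the refactored BatchNorm1d above (an illustration written for this note, not code from the commit; the shapes and variable names are made up): with the default affine=None the layer now returns only the normalized tensor, while affine=True also creates and applies the learnable weight and bias.

# Usage sketch for the refactored BatchNorm1d (illustration only).
import jittor as jt
from jittor import nn

x2d = jt.random((16, 10))      # (N, C): normalized over dim [0]
x3d = jt.random((16, 10, 20))  # (N, C, L): normalized over dims [0, 2]

bn_plain  = nn.BatchNorm1d(10)               # affine=None: returns norm_x only
bn_affine = nn.BatchNorm1d(10, affine=True)  # affine=True: applies weight and bias

print(bn_plain(x2d).shape, bn_affine(x3d).shape)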

python/jittor/test/__main__.py (+6 -3)

@@ -6,6 +6,7 @@
 
 if __name__ == "__main__":
     import unittest, os
+    unittest.TestLoader.sortTestMethodsUsing = None
 
     suffix = "__main__.py"
     assert __file__.endswith(suffix)
@@ -22,17 +23,19 @@
     suite = unittest.TestSuite()
 
     for _, test_file in enumerate(test_files):
+        test_name = test_file.split(".")[0]
+        tests = unittest.defaultTestLoader.loadTestsFromName(
+            "jittor.test."+test_name)
+
         if not test_file.startswith("test_"):
            continue
         if _ < skip_l or _ > skip_r:
            continue
-        test_name = test_file.split(".")[0]
         if test_only and test_name not in test_only:
            continue
 
         print("Add Test", _, test_name)
-        suite.addTest(unittest.defaultTestLoader.loadTestsFromName(
-            "jittor.test."+test_name))
+        suite.addTest(tests)
 
     result = unittest.TextTestRunner(verbosity=3).run(suite)
     if len(result.errors) or len(result.failures):
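
For reference, a standalone sketch of the unittest pattern the loop above relies on (illustration only; the dotted module name below is just an example): tests are now loaded by name first, then selectively added to a suite and run.

# Minimal sketch of the unittest APIs used by __main__.py (illustration only).
import unittest

unittest.TestLoader.sortTestMethodsUsing = None  # disable the loader's sorting of method names

# Load a test module by dotted name, then add it to a suite and run it.
# "jittor.test.test_contrib" is just an example; any importable module
# containing TestCase subclasses works here.
tests = unittest.defaultTestLoader.loadTestsFromName("jittor.test.test_contrib")
suite = unittest.TestSuite()
suite.addTest(tests)

result = unittest.TextTestRunner(verbosity=3).run(suite)
print(len(result.errors), len(result.failures))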

python/jittor/test/test_affine_grid.py (+10 -4)

@@ -9,10 +9,18 @@
 from jittor.nn import affine_grid,grid_sample
 
 
+skip_this_test = False
+
+try:
+    jt.dirty_fix_pytorch_runtime_error()
+    import torch.nn.functional as F
+    import torch
+except:
+    skip_this_test = True
+
+@unittest.skipIf(skip_this_test, "No Torch found")
 class TestAffineGrid(unittest.TestCase):
     def test_affine_grid_2d(self):
-        import torch.nn.functional as F
-        import torch
         N = 8
         C = 3
         H = 256
@@ -37,8 +45,6 @@ def test_affine_grid_2d(self):
 
 
     def test_affine_grid_3d(self):
-        import torch.nn.functional as F
-        import torch
         N = 8
         C = 3
         D = 64
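
The guarded import above is a common pattern for optional test dependencies. A generic standalone sketch with placeholder names (the real test guards the torch import and calls jt.dirty_fix_pytorch_runtime_error() first):

# Skip a whole test class when an optional dependency is missing (placeholder names).
import unittest

skip_this_test = False
try:
    import numpy as np  # stand-in for an optional dependency such as torch
except ImportError:
    skip_this_test = True

@unittest.skipIf(skip_this_test, "optional dependency not found")
class TestWithOptionalDep(unittest.TestCase):
    def test_sum(self):
        self.assertEqual(int(np.arange(3).sum()), 3)

if __name__ == "__main__":
    unittest.main()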

python/jittor/test/test_concat_op.py (+3)

@@ -51,6 +51,8 @@ def check(tmp, dim=0):
         check([jt.array(np.array(range(5))).reshape((5,1)), jt.array(np.array(range(1))).reshape((1,1))])
         print('concat success...')
 
+
+    @unittest.skipIf(not jt.has_cuda, "No CUDA found")
     @jt.flag_scope(use_cuda = 1)
     def test_concat_perf(self):
         def check(dim, size, backward=False):
@@ -106,6 +108,7 @@ def check(dim, size, backward=False):
 
     '''
 
+    @unittest.skipIf(not jt.has_cuda, "No CUDA found")
     @jt.flag_scope(use_cuda = 1)
     def test_concat2_perf(self):
         def check(dim, size, backward=False):
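
The added decorator keeps the CUDA perf tests from failing on CPU-only machines. A minimal sketch of the same guard, with a placeholder test body:

# Minimal sketch of the CUDA guard used above (placeholder test body).
import unittest
import jittor as jt

class TestCudaGuard(unittest.TestCase):
    @unittest.skipIf(not jt.has_cuda, "No CUDA found")
    @jt.flag_scope(use_cuda=1)
    def test_runs_on_gpu(self):
        x = jt.random((4, 4))
        (x + 1).sync()  # executes with use_cuda forced on

if __name__ == "__main__":
    unittest.main()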

python/jittor/test/test_contrib.py (+1 -1)

@@ -20,7 +20,7 @@ def check(shape, dim, n):
                 arr2.append(jt.array(a))
             x = np.concatenate(tuple(arr1), dim)
             y = jt.contrib.concat(arr2, dim)
-            assert (x==y.data).all()
+            assert (x==y.data).all(), (x, y.data, arr1, arr2)
         check([2,3,4], 0, 2)
         check([2,3,4], 1, 3)
         check([2,3,4], 2, 4)

python/jittor/test/test_merge_single_array_op.py (+2 -2)

@@ -45,8 +45,8 @@ def test(shape, op1, op2):
             with jt.log_capture_scope(log_v=0, log_vprefix="fused_op.cc=100") as logs:
                 d__ = d.data
             logs = find_log_with_re(logs,
-                "Jit (fused )?op key (not )?found: \[opkey0:array\]\[opkey1")
-            assert(len(logs)==1)
+                "Jit (fused )?op key (not )?found: \[opkey0:array\[T:float32")
+            assert(len(logs)==1), logs
 
             a_ = a.data
             b_ = b.data

python/jittor/test/test_mkl_conv_op.py (+2 -2)

@@ -114,7 +114,7 @@ def test_backward(self):
         b = np.random.rand(o,i,h,w).astype(np.float32)
         da = np.random.rand(n,o,H,W).astype(np.float32)
         dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1).data
-        dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,1,1,1).data
+        dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1).data
         a_jt = jt.array(a)
         b_jt = jt.array(b)
 
@@ -160,7 +160,7 @@ def test_backward_nhwc_hwio(self):
         b = np.random.rand(h,w,i,o).astype(np.float32)
         da = np.random.rand(n,H,W,o).astype(np.float32)
         dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
-        dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
+        dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data
         a_jt = jt.array(a)
         b_jt = jt.array(b)
 

python/jittor/test/test_mpi_batchnorm.py (+2 -2)

@@ -26,10 +26,10 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=T
         self.is_train = is_train
         self.eps = eps
         self.momentum = momentum
-        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
-        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
         self.weight = init.constant((num_features,), "float32", 1.0)
         self.bias = init.constant((num_features,), "float32", 0.0)
+        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
+        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
 
     def execute(self, x, global_x):
         if self.is_train:

python/jittor/test/test_slice.py (+1 -1)

@@ -49,7 +49,7 @@ def check(shape, slices, i_to_vs="", i_to_o="", o_shape=""):
         # print(slices)
         x = jt.random(shape)
 
-        with jt.log_capture_scope(log_vprefix="getitem=1000") as logs:
+        with jt.log_capture_scope(log_vprefix="getitem=999") as logs:
             a = x.getitem(slices)
             a.sync()
         b = x.data[slices]

src/ops/array_op.cc (+5)

@@ -74,6 +74,11 @@ ArrayOp::ArrayOp(ArrayArgs&& args) {
     std::memcpy(allocation.ptr, args.ptr, output->size);
 }
 
+void ArrayOp::jit_prepare() {
+    if (output->flags.get(NodeFlags::_force_fuse))
+        add_jit_define("T", output->dtype());
+}
+
 void ArrayOp::run() {
 #ifdef HAS_CUDA
     if (allocation.allocator == &cuda_dual_allocator) {

src/ops/array_op.h (+1)

@@ -28,6 +28,7 @@ struct ArrayOp : Op {
 
     const char* name() const override { return "array"; }
     void run() override;
+    void jit_prepare() override;
 };
 
 } // jittor
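
ArrayOp::jit_prepare adds the dtype to the JIT key of a force-fused array op, which is why the regex in test_merge_single_array_op.py now expects "opkey0:array[T:float32". A rough way to look for the new key from Python (a sketch, not from the commit; whether the array op is actually fused is up to the executor, so the match count may be zero):

# Sketch: capture fused-op logs and look for the dtype-carrying array-op key.
import re
import numpy as np
import jittor as jt

a = jt.array(np.float32([1, 2, 3]))  # float32 array op
b = a * 2 + 1                        # element-wise ops that may be fused with it

with jt.log_capture_scope(log_v=0, log_vprefix="fused_op.cc=100") as logs:
    b.sync()                         # trigger compilation/execution inside the scope

matches = [l for l in logs if re.search(r"opkey0:array\[T:float32", str(l))]
print("dtype-carrying array-op keys:", len(matches))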

src/ops/getitem_op.cc (+1 -1)

@@ -358,7 +358,7 @@ void GetitemOp::infer_shape() {
     this->i_to_o = i_to_o.to_nano_vector();
     this->o_shape = o_shape.to_nano_vector();
 
-    LOGvvvv << "\ni_to_vs:" << i_to_vs
+    LOGV(999) << "\ni_to_vs:" << i_to_vs
         << "\ni_to_o:" << i_to_o
         << "\no_shape:" << o_shape;
 }

src/test/test_kernel_ir.cc (+3 -3)

@@ -49,11 +49,11 @@ JIT_TEST(kernel_ir) {
     })", true
     );
     string code = R"(//
-// scope: main(1),
+// scope: <cmath>(1), aaa(1), main(1),
 
-// C macro code:"#include <cmath>"
+// C macro code:"#include <cmath>" lvalue:"<cmath>"
 #include <cmath>
-// C macro code:"#define aaa bbb"
+// C macro code:"#define aaa bbb" lvalue:"aaa" rvalue:" bbb"
 #define aaa bbb
 // C code:"using namespace std;" raw:"1"
 using namespace std;
