Skip to content

Commit 0b13930

Browse files
committed
add slice broadcast
1 parent ab30a15 commit 0b13930

File tree

6 files changed

+39
-12
lines changed

6 files changed

+39
-12
lines changed

python/jittor/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# file 'LICENSE.txt', which is part of this source code package.
1010
# ***************************************************************
1111

12-
__version__ = '1.3.3.7'
12+
__version__ = '1.3.3.8'
1313
from jittor_utils import lock
1414
with lock.lock_scope():
1515
ori_int = int

python/jittor/src/ops/broadcast_to_op.cc

+1-3
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
9898
bool BroadcastToOp::need_broadcast(const Var* x, const NanoVector& shape) {
9999
if (x->shape.size() < shape.size()) return true;
100100
for (uint i=shape.size()-1, j=x->shape.size()-1; i<shape.size(); i--,j--)
101-
if (x->shape[j]< 0 || (x->shape[j] != shape[i] && shape[i] != 1)) return true;
101+
if ((x->shape[j] != shape[i] && shape[i] != 1)) return true;
102102
return false;
103103
}
104104

@@ -154,8 +154,6 @@ void BroadcastToOp::infer_shape() {
154154
int64 zs;
155155
if ((xshape == 1 || yshape == 1) && (xshape != yshape)) {
156156
zs = xshape * yshape;
157-
} else if (xshape < 0 || yshape < 0) {
158-
zs = std::min(xshape, yshape);
159157
} else {
160158
CHECKop(xshape,==,yshape) << "Shape not match" << x->shape << yshapes << bcast_mask;
161159
zs = xshape;

python/jittor/src/ops/getitem_op.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,11 @@ void GetitemOp::infer_slices(
108108
for (int j=0; j<niv; j++) {
109109
auto iv_shape_j = iv_shape[niv-j-1];
110110
auto& out_shape_j = out_shape[first_oid_of_var+var_dim-j-1];
111+
CHECK(out_shape_j == iv_shape_j || out_shape_j == 1 || iv_shape_j == 1) << "Shape not match " >> out_shape_j >> "!="
112+
>> iv_shape_j << "data shape:" << in_shape <<
113+
"slice shape:" << iv_shape;
111114
if (out_shape_j == 1)
112115
out_shape_j = iv_shape_j;
113-
else
114-
ASSERT(out_shape_j == iv_shape_j || out_shape_j < 0 || iv_shape_j < 0)
115-
<< out_shape_j << iv_shape_j << out_shape;
116116
}
117117
} else
118118
if (s.is_ellipsis()) {

python/jittor/src/ops/reshape_op.cc

+1-3
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@ void ReshapeOp::infer_shape() {
4141
CHECK(uncertain_dim <= 1) << "max number of -1 is 1, but get" << uncertain_dim << ".";
4242
int64_t x_items = x->num;
4343
auto yshape = shape;
44-
if (x_items < 0) {
45-
// pass if input is uncertain
46-
} else if (uncertain_dim == 0) {
44+
if (uncertain_dim == 0) {
4745
CHECKop(x_items,==,y_items) << "reshape shape is invalid for input of size";
4846
} else {
4947
CHECK(y_items != 0 && x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items;

python/jittor/src/ops/ternary_op.cc

-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ void TernaryOp::infer_shape() {
5858
auto shape = std::min(xshape, std::min(yshape, cshape));
5959
auto shape2 = std::max(xshape, std::max(yshape, cshape));
6060
zshape.push_back(shape2);
61-
if (shape < 0) continue;
6261
CHECK(shape==shape2) << "Shape not match" << x->shape << y->shape << cond->shape;
6362
}
6463
z->set_shape(zshape);

python/jittor/test/test_setitem.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,39 @@ def test_dfs_memopt(self):
394394
jt.get_max_memory_treemap()
395395

396396

397-
397+
def test_setitem_bc(self):
398+
a = jt.random([10,11,12])
399+
b = a[jt.arange(3)[:,None],
400+
jt.arange(4)[None,:]]
401+
b.sync()
402+
assert (a[:3, :4] == b).all()
403+
404+
a = jt.random([10,11,12])
405+
b = a[jt.arange(3)[:,None],
406+
jt.arange(4)[None,:],
407+
jt.arange(4)[None,:]]
408+
nb = a.data[np.arange(3)[:,None],
409+
np.arange(4)[None,:],
410+
np.arange(4)[None,:]]
411+
np.testing.assert_allclose(nb, b.data)
412+
413+
a = jt.random([10,11,12])
414+
b = a[jt.arange(3)[::-1,None],
415+
jt.arange(4)[None,:],
416+
jt.arange(4)[None,:]]
417+
nb = a.data[np.arange(3)[::-1,None],
418+
np.arange(4)[None,:],
419+
np.arange(4)[None,:]]
420+
np.testing.assert_allclose(nb, b.data)
421+
422+
a = jt.random([10,11,12])
423+
b = a[jt.arange(3)[::-1,None],
424+
jt.arange(4)[None,:],
425+
jt.arange(4)[None,::-1]]
426+
nb = a.data[np.arange(3)[::-1,None],
427+
np.arange(4)[None,:],
428+
np.arange(4)[None,::-1]]
429+
np.testing.assert_allclose(nb, b.data)
398430

399431

400432

0 commit comments

Comments
 (0)