Skip to content

Commit 9d899dc

Browse files
committed
polish reindex memory optimize
1 parent 0b13930 commit 9d899dc

File tree

5 files changed

+15
-2
lines changed

5 files changed

+15
-2
lines changed

python/jittor/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# file 'LICENSE.txt', which is part of this source code package.
1010
# ***************************************************************
1111

12-
__version__ = '1.3.3.8'
12+
__version__ = '1.3.3.9'
1313
from jittor_utils import lock
1414
with lock.lock_scope():
1515
ori_int = int

python/jittor/src/executor.cc

-1
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,6 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
547547
root = fuse_ops[rr-1];
548548
load_fused_op(fused_op, fuse_ops, ops, ll, rr, tt);
549549
}
550-
LOGvvv << "Run" << op;
551550
for (auto* var : op->outputs()) {
552551
var->alloc(allocator);
553552
}

python/jittor/src/ops/reindex_op.cc

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ ReindexOp::ReindexOp(Var* x, NanoVector shape, vector<string>&& indexes, float64
2828
flags.set(NodeFlags::_cuda);
2929
set_type(OpType::broadcast);
3030
flags.set(NodeFlags::_manual_set_vnbb);
31+
for (auto& v : extras) v->flags.set(NodeFlags::_needed_by_backward);
3132
y = create_output(nullptr, x->dtype());
3233
}
3334

@@ -64,6 +65,7 @@ ReindexOp::ReindexOp(Var* x, vector<Var*>&& indexes, float64 overflow_value, vec
6465
extras = indexes;
6566
for (uint i = 0; i < indexes.size(); ++i) {
6667
indexes[i]->flags.set(NodeFlags::_force_fuse);
68+
indexes[i]->flags.set(NodeFlags::_needed_by_backward);
6769
}
6870
}
6971

python/jittor/src/ops/reindex_reduce_op.cc

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ ReindexReduceOp::ReindexReduceOp(Var* y, NanoString op, NanoVector shape, vector
3737
if (e->shape != y->shape) {
3838
e->flags.set(NodeFlags::_stop_fuse);
3939
}
40+
if (op.get(NanoString::_no_need_back_in))
41+
e->flags.set(NodeFlags::_needed_by_backward);
4042
}
4143
}
4244

python/jittor/test/test_reindex_op.py

+10
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,16 @@ def test_reindex_wrong_op(self):
306306
b = jt.array([1])
307307
c = a.reindex([8,8], ["@e0(0) // 1", "@e0(0)"], extras=[b, b])
308308
expect_error(lambda: c.sync())
309+
310+
def test_reindex_memopt(self):
311+
a = jt.zeros([10,10])
312+
b = jt.array([1,2,3]).name("b")
313+
c = a.reindex([8,8], ["@e0(0) / 1", "@e0(0)"], extras=[b, b])
314+
del b
315+
c.sync()
316+
da = jt.grad(c, a)
317+
da.sync()
318+
309319

310320

311321
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")

0 commit comments

Comments
 (0)