diff --git a/python/jittor/src/ops/getitem_op.cc b/python/jittor/src/ops/getitem_op.cc index d279f4ed..7f2802ca 100644 --- a/python/jittor/src/ops/getitem_op.cc +++ b/python/jittor/src/ops/getitem_op.cc @@ -572,6 +572,8 @@ void GetitemOp::jit_run() { index_t(vp@d[0 @for(j,0,VD,@if((VS@d>>j)&1, + i@{j+FOV} * vs@d@@s@j,))]) , ??? )))))); ) + @for(d, 0, IDIM, if (iid@d < 0) iid@d += ishape@d; + ) auto iid = 0 @for(d, 0, IDIM, + iid@d * istride@d); op[oid] = ip[iid]; } diff --git a/python/jittor/test/test_getitem_simple.py b/python/jittor/test/test_getitem_simple.py new file mode 100644 index 00000000..095dc65b --- /dev/null +++ b/python/jittor/test/test_getitem_simple.py @@ -0,0 +1,48 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + +class TestGetItemSimple(unittest.TestCase): + def test_get_by_pos_int(self): + a = jt.array([-2,3,4,-5,-6]) + b = a[3] + b.sync() + assert b.item() == -5 + def test_get_by_neg_int(self): + a = jt.array([-2,3,4,-5,-6]) + b = a[-3] + b.sync() + assert b.item() == 4 + def test_get_slice(self): + a = jt.array([-2,3,4,-5,-6]) + b = a[-1:-3:-1].numpy().tolist() + assert len(b) == 2 + assert b[0] == -6 + assert b[1] == -5 + def test_get_by_list(self): + a = jt.array([-2,3,4,-5,-6]) + b = a[[-1, -3, 1]].numpy().tolist() + assert len(b) == 3 + assert b[0] == -6 + assert b[1] == 4 + assert b[2] == 3 + def test_multidim_by_points(self): + a = jt.arange(24).reshape(2, 3, 4) + b = jt.array([0, 1, 0]) + c = jt.array([0, -1, 1]) + d = jt.array([-2, 0, 3]) + e = a[(b, c, d)].numpy().tolist() + assert len(e) == 3 + assert e[0] == 2 + assert e[1] == 20 + assert e[2] == 7 + +if __name__ == "__main__": + jt.flags.use_cuda = True + unittest.main() \ No newline at end of file