Skip to content

Commit 02b5351

Browse files
Fixed tensor freezing indexing issue
1 parent 1874caf commit 02b5351

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/python/freeze.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -681,13 +681,13 @@ void FlatVariables::traverse(nb::handle h, TraverseContext &ctx) {
681681

682682
auto full_shape = nb::borrow<nb::tuple>(shape(h));
683683

684-
nb::list outer_shape;
684+
nb::list inner_shape;
685685
if (full_shape.size() > 0)
686-
for (uint32_t i = 0; i < full_shape.size() - 1; i++) {
687-
outer_shape.append(full_shape[i]);
686+
for (uint32_t i = 1; i < full_shape.size(); i++) {
687+
inner_shape.append(full_shape[i]);
688688
}
689689

690-
layout.py_object = nb::tuple(outer_shape);
690+
layout.py_object = nb::tuple(inner_shape);
691691

692692
traverse(nb::steal(array), ctx);
693693
} else if (s.ndim != 1) {
@@ -825,15 +825,15 @@ nb::object FlatVariables::construct() {
825825
if (s.is_tensor) {
826826
nb::object array = construct();
827827

828-
auto outer_shape = nb::borrow<nb::tuple>(layout.py_object);
828+
auto inner_shape = nb::borrow<nb::tuple>(layout.py_object);
829829
auto last_dim = prod(shape(array), nb::none())
830-
.floor_div(prod(outer_shape, nb::none()));
830+
.floor_div(prod(inner_shape, nb::none()));
831831

832832
nb::list full_shape;
833-
for (uint32_t i = 0; i < outer_shape.size(); i++) {
834-
full_shape.append(outer_shape[i]);
835-
}
836833
full_shape.append(last_dim);
834+
for (uint32_t i = 0; i < inner_shape.size(); i++) {
835+
full_shape.append(inner_shape[i]);
836+
}
837837

838838
nb::object tensor = layout.type(array, nb::tuple(full_shape));
839839
return tensor;

0 commit comments

Comments
 (0)