@@ -681,13 +681,13 @@ void FlatVariables::traverse(nb::handle h, TraverseContext &ctx) {
681
681
682
682
auto full_shape = nb::borrow<nb::tuple>(shape (h));
683
683
684
- nb::list outer_shape ;
684
+ nb::list inner_shape ;
685
685
if (full_shape.size () > 0 )
686
- for (uint32_t i = 0 ; i < full_shape.size () - 1 ; i++) {
687
- outer_shape .append (full_shape[i]);
686
+ for (uint32_t i = 1 ; i < full_shape.size (); i++) {
687
+ inner_shape .append (full_shape[i]);
688
688
}
689
689
690
- layout.py_object = nb::tuple (outer_shape );
690
+ layout.py_object = nb::tuple (inner_shape );
691
691
692
692
traverse (nb::steal (array), ctx);
693
693
} else if (s.ndim != 1 ) {
@@ -825,15 +825,15 @@ nb::object FlatVariables::construct() {
825
825
if (s.is_tensor ) {
826
826
nb::object array = construct ();
827
827
828
- auto outer_shape = nb::borrow<nb::tuple>(layout.py_object );
828
+ auto inner_shape = nb::borrow<nb::tuple>(layout.py_object );
829
829
auto last_dim = prod (shape (array), nb::none ())
830
- .floor_div (prod (outer_shape , nb::none ()));
830
+ .floor_div (prod (inner_shape , nb::none ()));
831
831
832
832
nb::list full_shape;
833
- for (uint32_t i = 0 ; i < outer_shape.size (); i++) {
834
- full_shape.append (outer_shape[i]);
835
- }
836
833
full_shape.append (last_dim);
834
+ for (uint32_t i = 0 ; i < inner_shape.size (); i++) {
835
+ full_shape.append (inner_shape[i]);
836
+ }
837
837
838
838
nb::object tensor = layout.type (array, nb::tuple (full_shape));
839
839
return tensor;
0 commit comments