@@ -441,19 +441,34 @@ def grid_nodes_per_face(self, grid, out=None):
441441
442442 def grid_x (self , grid , out = None ):
443443 if out is None :
444- out = np .empty (self .grid_node_count (grid ), dtype = float )
444+ if self .grid_type (grid ) == 'rectilinear' :
445+ out = np .empty (self .grid_shape (grid )[1 ], dtype = float )
446+ else :
447+ out = np .empty (self .grid_node_count (grid ), dtype = float )
445448 self .bmi .get_grid_x (grid , out )
446449 return out
447450
448451 def grid_y (self , grid , out = None ):
449452 if out is None :
450- out = np .empty (self .grid_node_count (grid ), dtype = float )
453+ if self .grid_type (grid ) == 'rectilinear' :
454+ out = np .empty (self .grid_shape (grid )[0 ], dtype = float )
455+ else :
456+ out = np .empty (self .grid_node_count (grid ), dtype = float )
451457 self .bmi .get_grid_y (grid , out )
452458 return out
453459
454460 def grid_z (self , grid , out = None ):
455461 if out is None :
456- out = np .empty (self .node_count (grid ), dtype = float )
462+ if self .grid_type (grid ) == 'rectilinear' :
463+ shape = self .grid_shape (grid )
464+ try :
465+ zdim = shape [2 ]
466+ except IndexError :
467+ zdim = 1
468+ out = np .empty (zdim , dtype = float )
469+ else :
470+ out = np .empty (self .grid_node_count (grid ), dtype = float )
471+
457472 self .bmi .get_grid_z (grid , out )
458473 return out
459474
0 commit comments