@@ -92,6 +92,30 @@ def filter_noexpand_columns(columns):
92
92
return other , noexpand
93
93
94
94
95
def do_flatten(arr, flatten):
    """
    Flatten the requested array-valued columns of a structured array.

    Parameters
    ----------
    arr:
        Structured array whose ``dtype.names`` are the loaded branches.
    flatten:
        Either ``True`` (deprecated: every column is handed to ``stretch``)
        or a collection of column names that should be flattened.

    Returns
    -------
    The result of ``stretch`` with an extra ``__array_index`` field holding
    the indices that ``stretch`` returned alongside the stretched array.

    Raises
    ------
    ValueError
        If a requested column is scalar, or is not present in ``arr``.
    """
    if flatten is True:
        warnings.warn(" The option flatten=True is deprecated. Please specify the branches you would like "
                      "to flatten in a list: flatten=['foo', 'bar']", FutureWarning)
        stretched, indices = stretch(arr, return_indices=True)
    else:
        nonscalar = get_nonscalar_columns(arr)
        # Keep every scalar column, plus the non-scalar ones that were asked for
        fields = [name for name in arr.dtype.names
                  if (name not in nonscalar or name in flatten)]

        # Every requested column has to exist and has to be array-valued
        for requested in flatten:
            if requested in nonscalar:
                continue
            if requested in fields:
                message = "Requested to flatten {col} but it has a scalar type"
            else:
                message = "Requested to flatten {col} but it wasn't loaded from the input file"
            raise ValueError(message.format(col=requested))

        stretched, indices = stretch(arr, fields=fields, return_indices=True)

    return append_fields(stretched, '__array_index', indices, usemask=False, asrecarray=True)
117
+
118
+
95
119
def read_root (paths , key = None , columns = None , ignore = None , chunksize = None , where = None , flatten = False , * args , ** kwargs ):
96
120
"""
97
121
Read a ROOT file, or list of ROOT files, into a pandas DataFrame.
@@ -174,22 +198,6 @@ def read_root(paths, key=None, columns=None, ignore=None, chunksize=None, where=
174
198
for var in ignored :
175
199
all_vars .remove (var )
176
200
177
- def do_flatten (arr , flatten ):
178
- if flatten is True :
179
- warnings .warn (" The option flatten=True is deprecated. Please specify the branches you would like "
180
- "to flatten in a list: flatten=['foo', 'bar']" , FutureWarning )
181
- arr_ , idx = stretch (arr , return_indices = True )
182
- else :
183
- nonscalar = get_nonscalar_columns (arr )
184
- fields = [x for x in arr .dtype .names if (x not in nonscalar or x in flatten )]
185
- will_drop = [x for x in arr .dtype .names if x not in fields ]
186
- if will_drop :
187
- warnings .warn ("Ignored the following non-scalar branches: {bad_names}"
188
- .format (bad_names = ", " .join (will_drop )), UserWarning )
189
- arr_ , idx = stretch (arr , fields = fields , return_indices = True )
190
- arr = append_fields (arr_ , '__array_index' , idx , usemask = False , asrecarray = True )
191
- return arr
192
-
193
201
if chunksize :
194
202
tchain = ROOT .TChain (key )
195
203
for path in paths :
@@ -215,26 +223,45 @@ def genchunks():
215
223
216
224
def convert_to_dataframe(array, start_index=None):
    """
    Convert a structured numpy array into a pandas DataFrame.

    Parameters
    ----------
    array:
        Structured array; its ``dtype.names`` become the DataFrame columns.
        At most one field may start with ``__index__``; if present it is used
        as the index and the remainder of its name becomes the index name
        (an empty remainder means the index is unnamed).
    start_index:
        Optional integer. When the array carries no ``__index__*`` field, the
        DataFrame gets a ``RangeIndex`` starting at this value.

    Returns
    -------
    DataFrame with its columns in the same order as ``array.dtype.names``
    (minus the index field).

    Raises
    ------
    ValueError
        If more than one ``__index__*`` field is found.
    """
    nonscalar_columns = get_nonscalar_columns(array)

    # Columns containing 2D arrays can't be loaded by DataFrame.from_records,
    # so convert them to 1D object arrays whose elements are the original rows
    reshaped_columns = {}
    for col in nonscalar_columns:
        if array[col].ndim >= 2:
            reshaped = np.zeros(len(array[col]), dtype='O')
            for i, row in enumerate(array[col]):
                reshaped[i] = row
            reshaped_columns[col] = reshaped

    # from_records only needs the names of the columns to leave out
    excluded = list(reshaped_columns)

    indices = [name for name in array.dtype.names if name.startswith('__index__')]
    if len(indices) == 0:
        index = None
        if start_index is not None:
            index = RangeIndex(start=start_index, stop=start_index + len(array))
        df = DataFrame.from_records(array, exclude=excluded, index=index)
    elif len(indices) == 1:
        # We store the index under the __index__* branch, where
        # * is the name of the index
        df = DataFrame.from_records(array, exclude=excluded, index=indices[0])
        index_name = indices[0][len('__index__'):]
        if not index_name:
            # None means the index has no name
            index_name = None
        df.index.name = index_name
    else:
        raise ValueError("More than one index found in file")

    # Manually add the columns which were reshaped
    for key, reshaped in reshaped_columns.items():
        df[key] = reshaped

    # Adding the reshaped columns appends them at the end, so restore the
    # original column order
    if reshaped_columns:
        # Filter to remove __index__ columns
        columns = [c for c in array.dtype.names if c in df.columns]
        assert len(columns) == len(df.columns), (columns, df.columns)
        # DataFrame.reindex_axis was removed in pandas 1.0; reindex(columns=...)
        # is the supported equivalent
        df = df.reindex(columns=columns)

    return df
239
266
240
267
0 commit comments