28
28
29
29
30
30
def expand_braces (orig ):
31
- r = r'.*(\{.+? [^\\]\})'
31
+ r = r'.*? (\{.+[^\\]\})'
32
32
p = re .compile (r )
33
33
34
34
s = orig [:]
@@ -40,12 +40,10 @@ def expand_braces(orig):
40
40
open_brace = s .find (sub )
41
41
close_brace = open_brace + len (sub ) - 1
42
42
if sub .find (',' ) != - 1 :
43
- for pat in sub . strip ( '{}' ) .split (',' ):
43
+ for pat in sub [ 1 : - 1 ] .split (',' ):
44
44
res .extend (expand_braces (s [:open_brace ] + pat + s [close_brace + 1 :]))
45
-
46
45
else :
47
46
res .extend (expand_braces (s [:open_brace ] + sub .replace ('}' , '\\ }' ) + s [close_brace + 1 :]))
48
-
49
47
else :
50
48
res .append (s .replace ('\\ }' , '}' ))
51
49
@@ -59,6 +57,7 @@ def get_nonscalar_columns(array):
59
57
bad_names = col_names [bad_cols ]
60
58
return list (bad_names )
61
59
60
+
62
61
def get_matching_variables (branches , patterns , fail = True ):
63
62
selected = []
64
63
@@ -93,6 +92,30 @@ def filter_noexpand_columns(columns):
93
92
return other , noexpand
94
93
95
94
95
def do_flatten(arr, flatten):
    """
    Flatten array-valued columns of ``arr`` using ``stretch``.

    Parameters
    ----------
    arr:
        Structured array; its ``dtype.names`` are treated as the loaded columns.
    flatten:
        Either ``True`` (deprecated; ``stretch`` is then called without a
        ``fields`` restriction) or a collection of column names to flatten.

    Returns
    -------
    The array returned by ``stretch`` with an additional ``'__array_index'``
    field holding the indices that ``stretch`` reported.

    Raises
    ------
    ValueError
        If a requested column is scalar, or is not among ``arr.dtype.names``.
    """
    if flatten is True:
        warnings.warn(" The option flatten=True is deprecated. Please specify the branches you would like "
                      "to flatten in a list: flatten=['foo', 'bar']", FutureWarning)
        stretched, indices = stretch(arr, return_indices=True)
        return append_fields(stretched, '__array_index', indices, usemask=False, asrecarray=True)

    nonscalar = get_nonscalar_columns(arr)

    # Keep every scalar column, plus the non-scalar ones explicitly requested.
    fields = []
    for name in arr.dtype.names:
        if name not in nonscalar or name in flatten:
            fields.append(name)

    # Every requested column must be a non-scalar column; anything else is an error.
    for col in flatten:
        if col in nonscalar:
            continue
        if col in fields:
            reason = "Requested to flatten {col} but it has a scalar type"
        else:
            reason = "Requested to flatten {col} but it wasn't loaded from the input file"
        raise ValueError(reason.format(col=col))

    stretched, indices = stretch(arr, fields=fields, return_indices=True)
    return append_fields(stretched, '__array_index', indices, usemask=False, asrecarray=True)
117
+
118
+
96
119
def read_root (paths , key = None , columns = None , ignore = None , chunksize = None , where = None , flatten = False , * args , ** kwargs ):
97
120
"""
98
121
Read a ROOT file, or list of ROOT files, into a pandas DataFrame.
@@ -175,22 +198,6 @@ def read_root(paths, key=None, columns=None, ignore=None, chunksize=None, where=
175
198
for var in ignored :
176
199
all_vars .remove (var )
177
200
178
- def do_flatten (arr , flatten ):
179
- if flatten is True :
180
- warnings .warn (" The option flatten=True is deprecated. Please specify the branches you would like "
181
- "to flatten in a list: flatten=['foo', 'bar']" , FutureWarning )
182
- arr_ , idx = stretch (arr , return_indices = True )
183
- else :
184
- nonscalar = get_nonscalar_columns (arr )
185
- fields = [x for x in arr .dtype .names if (x not in nonscalar or x in flatten )]
186
- will_drop = [x for x in arr .dtype .names if x not in fields ]
187
- if will_drop :
188
- warnings .warn ("Ignored the following non-scalar branches: {bad_names}"
189
- .format (bad_names = ", " .join (will_drop )), UserWarning )
190
- arr_ , idx = stretch (arr , fields = fields , return_indices = True )
191
- arr = append_fields (arr_ , '__array_index' , idx , usemask = False , asrecarray = True )
192
- return arr
193
-
194
201
if chunksize :
195
202
tchain = ROOT .TChain (key )
196
203
for path in paths :
@@ -216,26 +223,45 @@ def genchunks():
216
223
217
224
def convert_to_dataframe(array, start_index=None):
    """
    Build a pandas DataFrame from a structured array.

    Parameters
    ----------
    array:
        Structured array. A field whose name starts with ``'__index__'`` is
        used as the DataFrame index; the remainder of that field name becomes
        the index name (an empty remainder means an unnamed index).
    start_index:
        When no ``'__index__*'`` field exists and this is not ``None``, the
        frame gets a ``RangeIndex`` starting at this value.

    Returns
    -------
    DataFrame with its columns in the same order as ``array.dtype.names``
    (minus the field used as index, if any).

    Raises
    ------
    ValueError
        If more than one ``'__index__*'`` field is present.
    """
    nonscalar_columns = get_nonscalar_columns(array)

    # Columns containing 2D arrays can't be loaded by from_records, so convert
    # them to 1D object arrays whose elements are the individual rows.
    reshaped_columns = {}
    for col in nonscalar_columns:
        values = array[col]
        if values.ndim >= 2:
            reshaped = np.zeros(len(values), dtype='O')
            for i, row in enumerate(values):
                reshaped[i] = row
            reshaped_columns[col] = reshaped
    excluded = list(reshaped_columns)

    indices = [name for name in array.dtype.names if name.startswith('__index__')]
    if not indices:
        index = None
        if start_index is not None:
            index = RangeIndex(start=start_index, stop=start_index + len(array))
        df = DataFrame.from_records(array, exclude=excluded, index=index)
    elif len(indices) == 1:
        # We store the index under the __index__* branch, where
        # * is the name of the index
        df = DataFrame.from_records(array, exclude=excluded, index=indices[0])
        # An empty suffix means the index has no name
        df.index.name = indices[0][len('__index__'):] or None
    else:
        raise ValueError("More than one index found in file")

    # Manually add the columns which were reshaped
    for key, reshaped in reshaped_columns.items():
        df[key] = reshaped

    # Re-adding the reshaped columns appends them at the end, so restore the
    # original column order.
    if reshaped_columns:
        # Filter to remove __index__ columns
        columns = [c for c in array.dtype.names if c in df.columns]
        assert len(columns) == len(df.columns), (columns, df.columns)
        # DataFrame.reindex_axis was removed in pandas 1.0; plain column
        # selection reorders the columns on every pandas version.
        df = df[columns]

    return df
240
266
241
267
0 commit comments