2828
2929
3030def expand_braces (orig ):
31- r = r'.*(\{.+? [^\\]\})'
31+ r = r'.*? (\{.+[^\\]\})'
3232 p = re .compile (r )
3333
3434 s = orig [:]
@@ -40,12 +40,10 @@ def expand_braces(orig):
4040 open_brace = s .find (sub )
4141 close_brace = open_brace + len (sub ) - 1
4242 if sub .find (',' ) != - 1 :
43- for pat in sub . strip ( '{}' ) .split (',' ):
43+ for pat in sub [ 1 : - 1 ] .split (',' ):
4444 res .extend (expand_braces (s [:open_brace ] + pat + s [close_brace + 1 :]))
45-
4645 else :
4746 res .extend (expand_braces (s [:open_brace ] + sub .replace ('}' , '\\ }' ) + s [close_brace + 1 :]))
48-
4947 else :
5048 res .append (s .replace ('\\ }' , '}' ))
5149
@@ -59,6 +57,7 @@ def get_nonscalar_columns(array):
5957 bad_names = col_names [bad_cols ]
6058 return list (bad_names )
6159
60+
6261def get_matching_variables (branches , patterns , fail = True ):
6362 selected = []
6463
@@ -93,6 +92,30 @@ def filter_noexpand_columns(columns):
9392 return other , noexpand
9493
9594
95+ def do_flatten (arr , flatten ):
96+ if flatten is True :
97+ warnings .warn (" The option flatten=True is deprecated. Please specify the branches you would like "
98+ "to flatten in a list: flatten=['foo', 'bar']" , FutureWarning )
99+ arr_ , idx = stretch (arr , return_indices = True )
100+ else :
101+ nonscalar = get_nonscalar_columns (arr )
102+ fields = [x for x in arr .dtype .names if (x not in nonscalar or x in flatten )]
103+
104+ for col in flatten :
105+ if col in nonscalar :
106+ pass
107+ elif col in fields :
108+ raise ValueError ("Requested to flatten {col} but it has a scalar type"
109+ .format (col = col ))
110+ else :
111+ raise ValueError ("Requested to flatten {col} but it wasn't loaded from the input file"
112+ .format (col = col ))
113+
114+ arr_ , idx = stretch (arr , fields = fields , return_indices = True )
115+ arr = append_fields (arr_ , '__array_index' , idx , usemask = False , asrecarray = True )
116+ return arr
117+
118+
96119def read_root (paths , key = None , columns = None , ignore = None , chunksize = None , where = None , flatten = False , * args , ** kwargs ):
97120 """
98121 Read a ROOT file, or list of ROOT files, into a pandas DataFrame.
@@ -175,22 +198,6 @@ def read_root(paths, key=None, columns=None, ignore=None, chunksize=None, where=
175198 for var in ignored :
176199 all_vars .remove (var )
177200
178- def do_flatten (arr , flatten ):
179- if flatten is True :
180- warnings .warn (" The option flatten=True is deprecated. Please specify the branches you would like "
181- "to flatten in a list: flatten=['foo', 'bar']" , FutureWarning )
182- arr_ , idx = stretch (arr , return_indices = True )
183- else :
184- nonscalar = get_nonscalar_columns (arr )
185- fields = [x for x in arr .dtype .names if (x not in nonscalar or x in flatten )]
186- will_drop = [x for x in arr .dtype .names if x not in fields ]
187- if will_drop :
188- warnings .warn ("Ignored the following non-scalar branches: {bad_names}"
189- .format (bad_names = ", " .join (will_drop )), UserWarning )
190- arr_ , idx = stretch (arr , fields = fields , return_indices = True )
191- arr = append_fields (arr_ , '__array_index' , idx , usemask = False , asrecarray = True )
192- return arr
193-
194201 if chunksize :
195202 tchain = ROOT .TChain (key )
196203 for path in paths :
@@ -216,26 +223,45 @@ def genchunks():
216223
217224def convert_to_dataframe (array , start_index = None ):
218225 nonscalar_columns = get_nonscalar_columns (array )
219- if nonscalar_columns :
220- warnings .warn ("Ignored the following non-scalar branches: {bad_names}"
221- .format (bad_names = ", " .join (nonscalar_columns )), UserWarning )
222- indices = list (filter (lambda x : x .startswith ('__index__' ) and x not in nonscalar_columns , array .dtype .names ))
226+
227+ # Columns containing 2D arrays can't be loaded so convert them 1D arrays of arrays
228+ reshaped_columns = {}
229+ for col in nonscalar_columns :
230+ if array [col ].ndim >= 2 :
231+ reshaped = np .zeros (len (array [col ]), dtype = 'O' )
232+ for i , row in enumerate (array [col ]):
233+ reshaped [i ] = row
234+ reshaped_columns [col ] = reshaped
235+
236+ indices = list (filter (lambda x : x .startswith ('__index__' ), array .dtype .names ))
223237 if len (indices ) == 0 :
224238 index = None
225239 if start_index is not None :
226240 index = RangeIndex (start = start_index , stop = start_index + len (array ))
227- df = DataFrame .from_records (array , exclude = nonscalar_columns , index = index )
241+ df = DataFrame .from_records (array , exclude = reshaped_columns , index = index )
228242 elif len (indices ) == 1 :
229243 # We store the index under the __index__* branch, where
230244 # * is the name of the index
231- df = DataFrame .from_records (array , index = indices [0 ], exclude = nonscalar_columns )
245+ df = DataFrame .from_records (array , exclude = reshaped_columns , index = indices [0 ])
232246 index_name = indices [0 ][len ('__index__' ):]
233247 if not index_name :
234248 # None means the index has no name
235249 index_name = None
236250 df .index .name = index_name
237251 else :
238252 raise ValueError ("More than one index found in file" )
253+
254+ # Manually the columns which were reshaped
255+ for key , reshaped in reshaped_columns .items ():
256+ df [key ] = reshaped
257+
258+ # Reshaping can cause the order of columns to change so we have to change it back
259+ if reshaped_columns :
260+ # Filter to remove __index__ columns
261+ columns = [c for c in array .dtype .names if c in df .columns ]
262+ assert len (columns ) == len (df .columns ), (columns , df .columns )
263+ df = df .reindex_axis (columns , axis = 1 , copy = False )
264+
239265 return df
240266
241267
0 commit comments