Skip to content
This repository was archived by the owner on Jan 9, 2023. It is now read-only.

Commit 5a20a52

Browse files
committed
Add support for loading columns of object arrays
1 parent d571619 commit 5a20a52

File tree

2 files changed

+69
-36
lines changed

2 files changed

+69
-36
lines changed

root_pandas/readwrite.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,30 @@ def filter_noexpand_columns(columns):
9292
return other, noexpand
9393

9494

95+
def do_flatten(arr, flatten):
96+
if flatten is True:
97+
warnings.warn(" The option flatten=True is deprecated. Please specify the branches you would like "
98+
"to flatten in a list: flatten=['foo', 'bar']", FutureWarning)
99+
arr_, idx = stretch(arr, return_indices=True)
100+
else:
101+
nonscalar = get_nonscalar_columns(arr)
102+
fields = [x for x in arr.dtype.names if (x not in nonscalar or x in flatten)]
103+
104+
for col in flatten:
105+
if col in nonscalar:
106+
pass
107+
elif col in fields:
108+
raise ValueError("Requested to flatten {col} but it has a scalar type"
109+
.format(col=col))
110+
else:
111+
raise ValueError("Requested to flatten {col} but it wasn't loaded from the input file"
112+
.format(col=col))
113+
114+
arr_, idx = stretch(arr, fields=fields, return_indices=True)
115+
arr = append_fields(arr_, '__array_index', idx, usemask=False, asrecarray=True)
116+
return arr
117+
118+
95119
def read_root(paths, key=None, columns=None, ignore=None, chunksize=None, where=None, flatten=False, *args, **kwargs):
96120
"""
97121
Read a ROOT file, or list of ROOT files, into a pandas DataFrame.
@@ -174,22 +198,6 @@ def read_root(paths, key=None, columns=None, ignore=None, chunksize=None, where=
174198
for var in ignored:
175199
all_vars.remove(var)
176200

177-
def do_flatten(arr, flatten):
178-
if flatten is True:
179-
warnings.warn(" The option flatten=True is deprecated. Please specify the branches you would like "
180-
"to flatten in a list: flatten=['foo', 'bar']", FutureWarning)
181-
arr_, idx = stretch(arr, return_indices=True)
182-
else:
183-
nonscalar = get_nonscalar_columns(arr)
184-
fields = [x for x in arr.dtype.names if (x not in nonscalar or x in flatten)]
185-
will_drop = [x for x in arr.dtype.names if x not in fields]
186-
if will_drop:
187-
warnings.warn("Ignored the following non-scalar branches: {bad_names}"
188-
.format(bad_names=", ".join(will_drop)), UserWarning)
189-
arr_, idx = stretch(arr, fields=fields, return_indices=True)
190-
arr = append_fields(arr_, '__array_index', idx, usemask=False, asrecarray=True)
191-
return arr
192-
193201
if chunksize:
194202
tchain = ROOT.TChain(key)
195203
for path in paths:
@@ -215,26 +223,45 @@ def genchunks():
215223

216224
def convert_to_dataframe(array, start_index=None):
217225
nonscalar_columns = get_nonscalar_columns(array)
218-
if nonscalar_columns:
219-
warnings.warn("Ignored the following non-scalar branches: {bad_names}"
220-
.format(bad_names=", ".join(nonscalar_columns)), UserWarning)
221-
indices = list(filter(lambda x: x.startswith('__index__') and x not in nonscalar_columns, array.dtype.names))
226+
227+
# Columns containing 2D arrays can't be loaded directly, so convert them to 1D arrays of arrays
228+
reshaped_columns = {}
229+
for col in nonscalar_columns:
230+
if array[col].ndim >= 2:
231+
reshaped = np.zeros(len(array[col]), dtype='O')
232+
for i, row in enumerate(array[col]):
233+
reshaped[i] = row
234+
reshaped_columns[col] = reshaped
235+
236+
indices = list(filter(lambda x: x.startswith('__index__'), array.dtype.names))
222237
if len(indices) == 0:
223238
index = None
224239
if start_index is not None:
225240
index = RangeIndex(start=start_index, stop=start_index + len(array))
226-
df = DataFrame.from_records(array, exclude=nonscalar_columns, index=index)
241+
df = DataFrame.from_records(array, exclude=reshaped_columns, index=index)
227242
elif len(indices) == 1:
228243
# We store the index under the __index__* branch, where
229244
# * is the name of the index
230-
df = DataFrame.from_records(array, index=indices[0], exclude=nonscalar_columns)
245+
df = DataFrame.from_records(array, exclude=reshaped_columns, index=indices[0])
231246
index_name = indices[0][len('__index__'):]
232247
if not index_name:
233248
# None means the index has no name
234249
index_name = None
235250
df.index.name = index_name
236251
else:
237252
raise ValueError("More than one index found in file")
253+
254+
# Manually add the columns which were reshaped
255+
for key, reshaped in reshaped_columns.items():
256+
df[key] = reshaped
257+
258+
# Reshaping can cause the order of columns to change so we have to change it back
259+
if reshaped_columns:
260+
# Filter to remove __index__ columns
261+
columns = [c for c in array.dtype.names if c in df.columns]
262+
assert len(columns) == len(df.columns), (columns, df.columns)
263+
df = df.reindex_axis(columns, axis=1, copy=False)
264+
238265
return df
239266

240267

tests/test.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,17 @@ def test_flatten():
191191
os.remove('tmp.root')
192192

193193

194-
def test_drop_nonscalar_columns():
195-
array = np.array([1, 2, 3])
196-
matrix = np.array([[1, 2, 3], [4, 5, 6]])
197-
bool_matrix = np.array([[True, False, True], [True, True, True]])
194+
def to_object_array(array):
195+
new_array = np.zeros(len(array), dtype='O')
196+
for i, row in enumerate(array):
197+
new_array[i] = row
198+
return new_array
199+
200+
201+
def test_nonscalar_columns():
202+
array = np.array([1, 2, 3], dtype=np.int64)
203+
matrix = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
204+
bool_matrix = np.array([[True, False, True], [True, True, True]], dtype=np.bool_)
198205

199206
dt = np.dtype([
200207
('a', 'i4'),
@@ -208,18 +215,17 @@ def test_drop_nonscalar_columns():
208215
(2, array, matrix, False, bool_matrix)],
209216
dtype=dt)
210217

218+
reference_df = pd.DataFrame()
219+
reference_df['a'] = np.array([3, 2], dtype=np.int32)
220+
reference_df['b'] = to_object_array([array, array])
221+
reference_df['c'] = to_object_array([matrix, matrix])
222+
reference_df['d'] = np.array([True, False], dtype=np.bool_)
223+
reference_df['e'] = to_object_array([bool_matrix, bool_matrix])
224+
211225
path = 'tmp.root'
212226
array2root(arr, path, 'ntuple', mode='recreate')
213-
with warnings.catch_warnings():
214-
warnings.simplefilter("ignore")
215-
df = read_root(path, flatten=False)
216-
# the above line throws an error if flatten=True because nonscalar columns
217-
# are dropped only after the flattening is applied. However, the flattening
218-
# algorithm can not deal with arrays of more than one dimension.
219-
assert(len(df.columns) == 2)
220-
assert(np.all(df.index.values == np.array([0, 1])))
221-
assert(np.all(df.a.values == np.array([3, 2])))
222-
assert(np.all(df.d.values == np.array([True, False])))
227+
df = read_root(path, flatten=False)
228+
assert_frame_equal(df, reference_df)
223229

224230
os.remove(path)
225231

0 commit comments

Comments
 (0)