Skip to content

Commit 6d89bb9

Browse files
authored
Merge pull request #247 from lincc-frameworks/reduce_overhaul
disallow non-column arguments for reduce
2 parents c771a8b + 0e933fa commit 6d89bb9

File tree

2 files changed

+38
-6
lines changed

2 files changed

+38
-6
lines changed

src/nested_pandas/nestedframe/core.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -1007,8 +1007,9 @@ def reduce(self, func, *args, infer_nesting=True, **kwargs) -> NestedFrame: # t
10071007
columns to apply the function to. See the Notes for recommendations
10081008
on writing func outputs.
10091009
args : positional arguments
1010-
Positional arguments to pass to the function, the first *args should be the names of the
1011-
columns to apply the function to.
1010+
A list of string column names to pull from the NestedFrame to pass along
1011+
to the function. If the function has additional arguments, pass them as
1012+
keyword arguments (e.g. `arg_name=value`).
10121013
infer_nesting : bool, default True
10131014
If True, the function will pack output columns into nested
10141015
structures based on column names adhering to a nested naming
@@ -1083,10 +1084,20 @@ def reduce(self, func, *args, infer_nesting=True, **kwargs) -> NestedFrame: # t
10831084
# Stop when we reach an argument that is not a valid column, as we assume
10841085
# that the remaining args are extra arguments to the function
10851086
if not isinstance(arg, str):
1086-
break
1087+
raise TypeError(
1088+
f"Received an argument '{arg}' that is not a string. "
1089+
"All arguments to `reduce` must be strings corresponding to"
1090+
" column names to pass along to the function. If your function"
1091+
" has additional arguments, pass them as kwargs (arg_name=value)."
1092+
)
10871093
components = self._parse_hierarchical_components(arg)
10881094
if not self._is_known_column(components):
1089-
break
1095+
raise ValueError(
1096+
f"Received a string argument '{arg}' that was not found in the columns list. "
1097+
"All arguments to `reduce` must be strings corresponding to"
1098+
" column names to pass along to the function. If your function"
1099+
" has additional arguments, pass them as kwargs (arg_name=value)."
1100+
)
10901101
layer = "base" if len(components) < 2 else components[0]
10911102
col = components[-1]
10921103
requested_columns.append((layer, col))

tests/nested_pandas/nestedframe/test_nestedframe.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -966,7 +966,7 @@ def offset_avg(offset, col_to_avg, column_names):
966966
sum([7, 10, 7]) / 3.0,
967967
]
968968

969-
result = nf.reduce(offset_avg, "b", "packed.c", ["offset_avg"])
969+
result = nf.reduce(offset_avg, "b", "packed.c", column_names=["offset_avg"])
970970
assert len(result) == len(nf)
971971
assert isinstance(result, NestedFrame)
972972
assert result.index.name == "idx"
@@ -978,7 +978,7 @@ def offset_avg(offset, col_to_avg, column_names):
978978
def make_id(col1, prefix_str):
979979
return f"{prefix_str}{col1}"
980980

981-
result = nf.reduce(make_id, "b", "some_id_")
981+
result = nf.reduce(make_id, "b", prefix_str="some_id_")
982982
assert result[0][1] == "some_id_4"
983983

984984

@@ -1095,6 +1095,27 @@ def complex_output(flux):
10951095
assert list(result.lc.nest.fields) == ["flux_quantiles", "labels"]
10961096

10971097

1098+
def test_reduce_arg_errors():
1099+
"""Test that reduce errors based on non-column args trigger as expected"""
1100+
1101+
ndf = generate_data(10, 10, seed=1)
1102+
1103+
def func(a, flux, add):
1104+
"""a function that takes a scalar, a column, and a boolean"""
1105+
if add:
1106+
return {"nested2.flux": flux + a}
1107+
return {"nested2.flux": flux + a}
1108+
1109+
with pytest.raises(TypeError):
1110+
ndf.reduce(func, "a", "nested.flux", True)
1111+
1112+
with pytest.raises(ValueError):
1113+
ndf.reduce(func, "ab", "nested.flux", add=True)
1114+
1115+
# this should work
1116+
ndf.reduce(func, "a", "nested.flux", add=True)
1117+
1118+
10981119
def test_scientific_notation():
10991120
"""
11001121
Test that NestedFrame.query handles constants that are written in scientific notation.

0 commit comments

Comments
 (0)