SNOW-1918055: Update agg error for unsupported aggregation functions #3133

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -54,12 +54,13 @@
- Fixed a bug where creating a Dataframe with large number of values raised `Unsupported feature 'SCOPED_TEMPORARY'.` error if thread-safe session was disabled.
- Fixed a bug where `df.describe` raised internal SQL execution error when the dataframe is created from reading a stage file and CTE optimization is enabled.
- Fixed a bug where `df.order_by(A).select(B).distinct()` would generate invalid SQL when simplified query generation was enabled using `session.conf.set("use_simplified_query_generation", True)`.
- Disabled simplified query generation by default.

Review comment (Contributor): Good catch!

#### Improvements

- Improved version validation warnings for `snowflake-snowpark-python` package compatibility when registering stored procedures. Now, warnings are only triggered if the major or minor version does not match, while bugfix version differences no longer generate warnings.
- Bumped cloudpickle dependency to also support `cloudpickle==3.0.0` in addition to previous versions.
- Improved error message for `groupby.agg` when the function name is not supported.

### Snowpark Local Testing Updates

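To illustrate the `groupby.agg` error-message change listed under Improvements above, here is a minimal sketch based on the tests added later in this PR. It assumes an active Snowpark session has already been created; the DataFrame contents mirror the new test in `test_groupby_basic_agg.py`.

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401  # registers the Snowpark pandas backend

# Same shape as the DataFrame used in the new test below.
df = pd.DataFrame(8 * [range(3)], columns=["a", "b", "c"])

# "COUNT" is not a valid pandas aggregation name (pandas expects "count"), so with
# this change the call surfaces the same error native pandas raises:
#   AttributeError: 'SeriesGroupBy' object has no attribute 'COUNT'
df.groupby("a").agg({"b": sum, "c": "COUNT"})
```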
@@ -840,7 +840,7 @@ def _is_supported_snowflake_agg_func(
agg_kwargs: dict[str, Any],
axis: Literal[0, 1],
_is_df_agg: bool = False,
) -> bool:
) -> tuple[bool, list]:
"""
check if the aggregation function is supported with snowflake. Current supported
aggregation functions are the functions that can be mapped to snowflake builtin function.
@@ -851,40 +851,50 @@
The value can be different for different aggregation functions.
Returns:
is_valid: bool. Whether it is valid to implement with snowflake or not.

Review comment (Contributor): I suspect this line is the result of copy-pasting; "Whether it is valid to implement with snowflake or not." is not consistent with the semantics of this function, which checks whether the aggregation function is supported with Snowflake.

unsupported_arguments: list. The list of unsupported functions used for aggregation.
"""
if isinstance(agg_func, tuple) and len(agg_func) == 2:
# For named aggregations, like `df.agg(new_col=("old_col", "sum"))`,
# take the second part of the named aggregation.
agg_func = agg_func[0]
return get_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) is not None
if get_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) is None:
return False, agg_func

Review comment (Contributor): Is agg_func guaranteed to be a list?

return True, []
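
On the question above about the return type: one minimal sketch that keeps the return value consistent with the `tuple[bool, list]` annotation is to wrap the single offending function in a list. This is an illustration, not the code in this PR; `get_snowflake_agg_func` is the module helper already used in the diff.

```python
def _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg=False):
    # Unwrap named aggregations the same way the diff above does.
    if isinstance(agg_func, tuple) and len(agg_func) == 2:
        agg_func = agg_func[0]
    if get_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) is None:
        # Wrap the single offending function in a list so callers always receive a list.
        return False, [agg_func]
    return True, []
```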


def _are_all_agg_funcs_supported_by_snowflake(
agg_funcs: list[AggFuncTypeBase],
agg_kwargs: dict[str, Any],
axis: Literal[0, 1],
_is_df_agg: bool = False,
) -> bool:
) -> tuple[bool, list]:
"""
Check if all aggregation functions in the given list are snowflake supported
aggregation functions.

Returns:
True if all functions in the list are snowflake supported aggregation functions, otherwise,
return False.
"""
return all(
_is_supported_snowflake_agg_func(func, agg_kwargs, axis, _is_df_agg)
for func in agg_funcs
)
bool
True if all functions in the list are snowflake supported aggregation functions, otherwise,
return False
list
The list of unsupported functions used for aggregation.

Review comment on lines +876 to +880 (Contributor): nit: usually I see docstrings for functions like this explicitly mention tuple[bool, list] as the return type and describe what each member of the tuple means, rather than separating out the values.

"""
is_supported_bools: list[bool] = []
unsupported_list: list[str] = []
for func in agg_funcs:
is_supported, unsupported_list = _is_supported_snowflake_agg_func(
func, agg_kwargs, axis, _is_df_agg
)

Review comment on lines +883 to +887 (@sfc-gh-jjiao, Contributor, Mar 8, 2025): Does unsupported_list need to be appended to in the for loop? It seems like it is replaced/overwritten on every iteration here.

is_supported_bools.append(is_supported)
return all(is_supported_bools), unsupported_list
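
A sketch of the accumulation the comment above asks about, so that unsupported functions from every element of `agg_funcs` are reported rather than only those from the last element checked. It assumes `_is_supported_snowflake_agg_func` returns its unsupported functions as a list; this is an illustration of the suggestion, not the code in this PR.

```python
def _are_all_agg_funcs_supported_by_snowflake(agg_funcs, agg_kwargs, axis, _is_df_agg=False):
    all_supported = True
    unsupported_list: list = []
    for func in agg_funcs:
        is_supported, unsupported = _is_supported_snowflake_agg_func(
            func, agg_kwargs, axis, _is_df_agg
        )
        all_supported = all_supported and is_supported
        # Extend rather than overwrite, so findings from earlier iterations are kept.
        unsupported_list.extend(unsupported)
    return all_supported, unsupported_list
```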


def check_is_aggregation_supported_in_snowflake(
agg_func: AggFuncType,
agg_kwargs: dict[str, Any],
axis: Literal[0, 1],
_is_df_agg: bool = False,
) -> bool:
) -> tuple[bool, list]:
"""
check if distributed implementation with snowflake is available for the aggregation
based on the input arguments.
@@ -898,27 +908,43 @@ def check_is_aggregation_supported_in_snowflake(
Returns:
bool
Whether the aggregation operation can be executed with snowflake sql engine.
list
The list of unsupported functions used for aggregation.
"""
# validate agg_func, only snowflake builtin agg function or dict of snowflake builtin agg
# function can be implemented in distributed way.
unsupported_arguments: list[str] = []
supported_flag = True

Review comment (Contributor): I prefer the name "is_supported" so there is zero ambiguity about the meaning of the flag's True/False value.

if is_dict_like(agg_func):
return all(
(
_are_all_agg_funcs_supported_by_snowflake(
for value in agg_func.values():
if is_list_like(value) and not is_named_tuple(value):
(
is_supported_func,
unsupported_arguments,
) = _are_all_agg_funcs_supported_by_snowflake(
value, agg_kwargs, axis, _is_df_agg
)
if is_list_like(value) and not is_named_tuple(value)
else _is_supported_snowflake_agg_func(
else:
(
is_supported_func,
unsupported_arguments,
) = _is_supported_snowflake_agg_func(
value, agg_kwargs, axis, _is_df_agg
)
)
for value in agg_func.values()
)
if not is_supported_func:
supported_flag = False

Review comment (Contributor): Why not just return here? Is your intent to combine the unsupported_arguments lists?

Reply (Contributor, Author): The first function that is unsupported will be returned. The case with multiple unsupported functions needs to be handled, which will require returning in this loop, so I'll make those changes as well as using repr_aggregate_function(agg_func, agg_kwargs)!

elif is_list_like(agg_func):
return _are_all_agg_funcs_supported_by_snowflake(
(
supported_flag,
unsupported_arguments,
) = _are_all_agg_funcs_supported_by_snowflake(
agg_func, agg_kwargs, axis, _is_df_agg
)
else:
supported_flag, unsupported_arguments = _is_supported_snowflake_agg_func(
agg_func, agg_kwargs, axis, _is_df_agg
)
return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg)
return supported_flag, unsupported_arguments


def _is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool:
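Following up on the return-vs-combine thread above, a sketch of how the dict-like branch of `check_is_aggregation_supported_in_snowflake` could combine unsupported functions across all columns rather than keeping only the last result. It illustrates the reviewer's suggestion under the assumption that the helpers return lists; it is not the code merged in this PR.

```python
# supported_flag and unsupported_arguments are initialized earlier in the function,
# as shown in the diff above.
if is_dict_like(agg_func):
    for value in agg_func.values():
        if is_list_like(value) and not is_named_tuple(value):
            is_supported_func, unsupported = _are_all_agg_funcs_supported_by_snowflake(
                value, agg_kwargs, axis, _is_df_agg
            )
        else:
            is_supported_func, unsupported = _is_supported_snowflake_agg_func(
                value, agg_kwargs, axis, _is_df_agg
            )
        if not is_supported_func:
            supported_flag = False
            # Collect every unsupported function instead of overwriting with the last one.
            unsupported_arguments.extend(unsupported)
```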
@@ -3732,10 +3732,13 @@ def groupby_agg(
level=level,
dropna=agg_kwargs.get("dropna", True),
)

if not check_is_aggregation_supported_in_snowflake(agg_func, agg_kwargs, axis):
ErrorMessage.not_implemented(
f"Snowpark pandas GroupBy.aggregate does not yet support the aggregation {repr_aggregate_function(agg_func, agg_kwargs)} with the given arguments."
(
is_supported,
unsupported_arguments,
) = check_is_aggregation_supported_in_snowflake(agg_func, agg_kwargs, axis)
if not is_supported:
raise AttributeError(
f"'SeriesGroupBy' object has no attribute '{unsupported_arguments}'"

Review comment (Contributor): What happens if this is an aggregation that native pandas supports but we do not?

Reply (Contributor, Author): This should return False, since we check whether there is a corresponding Snowflake aggregation function via get_snowflake_agg_func(). The overall logic for checking whether a function is supported is not changing here.

)

sort = groupby_kwargs.get("sort", True)
@@ -6103,11 +6106,15 @@ def agg(
# by snowflake engine.
# If we are using Named Aggregations, we need to do our supported check slightly differently.
uses_named_aggs = using_named_aggregations_for_func(func)
if not check_is_aggregation_supported_in_snowflake(
(
is_supported,
unsupported_arguments,
) = check_is_aggregation_supported_in_snowflake(
func, kwargs, axis, _is_df_agg=True
):
ErrorMessage.not_implemented(
f"Snowpark pandas aggregate does not yet support the aggregation {repr_aggregate_function(func, kwargs)} with the given arguments."
)
if not is_supported:
raise AttributeError(
f"'SeriesGroupBy' object has no attribute '{unsupported_arguments}'"
)

query_compiler = self
20 changes: 20 additions & 0 deletions tests/integ/modin/groupby/test_groupby_basic_agg.py
@@ -505,6 +505,26 @@ def test_groupby_agg_on_groupby_columns(
)


@pytest.mark.parametrize("as_index", [True, False])
@pytest.mark.parametrize("sort", [True, False])
@sql_count_checker(query_count=1)
def test_groupby_agg_with_incorrect_func(as_index, sort) -> None:
basic_snowpark_pandas_df = pd.DataFrame(
data=8 * [range(3)], columns=["a", "b", "c"]
)
# basic_snowpark_pandas_df = basic_snowpark_pandas_df.groupby(['a', 'b']).sum()

Review comment (Contributor): Did you mean to delete this?

native_pandas = basic_snowpark_pandas_df.to_pandas()
eval_snowpark_pandas_result(
basic_snowpark_pandas_df,
native_pandas,
lambda df: df.groupby(by="a", sort=sort, as_index=as_index).agg(
{"b": sum, "c": "COUNT"}
),
expect_exception=True,
expect_exception_match="'SeriesGroupBy' object has no attribute 'COUNT'",
)


@pytest.mark.parametrize(
"by", ["col1", ["col1", "col2", "col3"], ["col1", "col1", "col2"]]
)
37 changes: 29 additions & 8 deletions tests/integ/modin/groupby/test_groupby_named_agg.py
@@ -125,15 +125,17 @@ def test_named_agg_passed_in_via_star_kwargs(basic_df_data):
def test_named_agg_with_invalid_function_raises_not_implemented(
basic_df_data,
):
with pytest.raises(
NotImplementedError,
match=re.escape(
"Snowpark pandas GroupBy.aggregate does not yet support the aggregation new_label=(label, 'min'), new_label=(label, 'random_function')"
),
):
pd.DataFrame(basic_df_data).groupby("col1").agg(
basic_snowpark_pandas_df = pd.DataFrame(basic_df_data)
native_pandas = native_pd.DataFrame(basic_df_data)
eval_snowpark_pandas_result(
basic_snowpark_pandas_df,
native_pandas,
lambda df: df.groupby("col1").agg(
c1=("col2", "min"), c2=("col2", "random_function")
)
),
expect_exception=True,
expect_exception_match="'SeriesGroupBy' object has no attribute",
)


@sql_count_checker(query_count=1)
@@ -163,3 +165,22 @@ def test_named_agg_size_on_series(size_func):
native_series,
lambda series: series.groupby(level=0).agg(new_col=size_func),
)


@pytest.mark.parametrize("as_index", [True, False])
@pytest.mark.parametrize("sort", [True, False])
@sql_count_checker(query_count=1)
def test_named_groupby_agg_with_incorrect_func(as_index, sort) -> None:
basic_snowpark_pandas_df = pd.DataFrame(
data=8 * [range(3)], columns=["a", "b", "c"]
)
basic_snowpark_pandas_df = basic_snowpark_pandas_df.groupby(["a", "b"]).sum()
native_pandas = basic_snowpark_pandas_df.to_pandas()
eval_snowpark_pandas_result(
basic_snowpark_pandas_df,
native_pandas,
lambda df: df.groupby(by="a", sort=sort, as_index=as_index).agg(
NEW_B=("b", "sum"), ACTIVE_DAYS=("c", "COUNT")
),
expect_exception=True,
)