-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SNOW-1918055: Update agg error for unsupported aggregation functions #3133
base: main
Are you sure you want to change the base?
Changes from all commits
ef9abd2
1855cd6
7af1aa0
31b3fb3
97692e9
36ccf27
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -840,7 +840,7 @@ def _is_supported_snowflake_agg_func( | |
agg_kwargs: dict[str, Any], | ||
axis: Literal[0, 1], | ||
_is_df_agg: bool = False, | ||
) -> bool: | ||
) -> tuple[bool, list]: | ||
""" | ||
check if the aggregation function is supported with snowflake. Current supported | ||
aggregation functions are the functions that can be mapped to snowflake builtin function. | ||
|
@@ -851,40 +851,50 @@ def _is_supported_snowflake_agg_func( | |
The value can be different for different aggregation functions. | ||
Returns: | ||
is_valid: bool. Whether it is valid to implement with snowflake or not. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspected this comment is the result of copy-pasting , |
||
unsupported_arguments: list. The list of unsupported functions used for aggregation. | ||
""" | ||
if isinstance(agg_func, tuple) and len(agg_func) == 2: | ||
# For named aggregations, like `df.agg(new_col=("old_col", "sum"))`, | ||
# take the second part of the named aggregation. | ||
agg_func = agg_func[0] | ||
return get_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) is not None | ||
if get_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) is None: | ||
return False, agg_func | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is |
||
return True, [] | ||
|
||
|
||
def _are_all_agg_funcs_supported_by_snowflake( | ||
agg_funcs: list[AggFuncTypeBase], | ||
agg_kwargs: dict[str, Any], | ||
axis: Literal[0, 1], | ||
_is_df_agg: bool = False, | ||
) -> bool: | ||
) -> tuple[bool, list]: | ||
""" | ||
Check if all aggregation functions in the given list are snowflake supported | ||
aggregation functions. | ||
|
||
Returns: | ||
True if all functions in the list are snowflake supported aggregation functions, otherwise, | ||
return False. | ||
""" | ||
return all( | ||
_is_supported_snowflake_agg_func(func, agg_kwargs, axis, _is_df_agg) | ||
for func in agg_funcs | ||
) | ||
bool | ||
True if all functions in the list are snowflake supported aggregation functions, otherwise, | ||
return False | ||
list | ||
The list of unsupported functions used for aggregation. | ||
Comment on lines
+876
to
+880
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: usually I see docstrings for functions like this explicitly mention |
||
""" | ||
is_supported_bools: list[bool] = [] | ||
unsupported_list: list[str] = [] | ||
for func in agg_funcs: | ||
is_supported, unsupported_list = _is_supported_snowflake_agg_func( | ||
func, agg_kwargs, axis, _is_df_agg | ||
) | ||
Comment on lines
+883
to
+887
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does the |
||
is_supported_bools.append(is_supported) | ||
return all(is_supported_bools), unsupported_list | ||
|
||
|
||
def check_is_aggregation_supported_in_snowflake( | ||
agg_func: AggFuncType, | ||
agg_kwargs: dict[str, Any], | ||
axis: Literal[0, 1], | ||
_is_df_agg: bool = False, | ||
) -> bool: | ||
) -> tuple[bool, list]: | ||
""" | ||
check if distributed implementation with snowflake is available for the aggregation | ||
based on the input arguments. | ||
|
@@ -898,27 +908,43 @@ def check_is_aggregation_supported_in_snowflake( | |
Returns: | ||
bool | ||
Whether the aggregation operation can be executed with snowflake sql engine. | ||
list | ||
The list of unsupported functions used for aggregation. | ||
""" | ||
# validate agg_func, only snowflake builtin agg function or dict of snowflake builtin agg | ||
# function can be implemented in distributed way. | ||
unsupported_arguments: list[str] = [] | ||
supported_flag = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I prefer the name " |
||
if is_dict_like(agg_func): | ||
return all( | ||
( | ||
_are_all_agg_funcs_supported_by_snowflake( | ||
for value in agg_func.values(): | ||
if is_list_like(value) and not is_named_tuple(value): | ||
( | ||
is_supported_func, | ||
unsupported_arguments, | ||
) = _are_all_agg_funcs_supported_by_snowflake( | ||
value, agg_kwargs, axis, _is_df_agg | ||
) | ||
if is_list_like(value) and not is_named_tuple(value) | ||
else _is_supported_snowflake_agg_func( | ||
else: | ||
( | ||
is_supported_func, | ||
unsupported_arguments, | ||
) = _is_supported_snowflake_agg_func( | ||
value, agg_kwargs, axis, _is_df_agg | ||
) | ||
) | ||
for value in agg_func.values() | ||
) | ||
if not is_supported_func: | ||
supported_flag = False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not just return here? Is your intent to combine the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The first function that's unsupported will be returned. The case with multiple unsupported functions needs to be handled which will require returning in this loop, so I'll make those changes as well as using repr_aggregate_function(agg_func, agg_kwargs)! |
||
elif is_list_like(agg_func): | ||
return _are_all_agg_funcs_supported_by_snowflake( | ||
( | ||
supported_flag, | ||
unsupported_arguments, | ||
) = _are_all_agg_funcs_supported_by_snowflake( | ||
agg_func, agg_kwargs, axis, _is_df_agg | ||
) | ||
else: | ||
supported_flag, unsupported_arguments = _is_supported_snowflake_agg_func( | ||
agg_func, agg_kwargs, axis, _is_df_agg | ||
) | ||
return _is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis, _is_df_agg) | ||
return supported_flag, unsupported_arguments | ||
|
||
|
||
def _is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3732,10 +3732,13 @@ def groupby_agg( | |
level=level, | ||
dropna=agg_kwargs.get("dropna", True), | ||
) | ||
|
||
if not check_is_aggregation_supported_in_snowflake(agg_func, agg_kwargs, axis): | ||
ErrorMessage.not_implemented( | ||
f"Snowpark pandas GroupBy.aggregate does not yet support the aggregation {repr_aggregate_function(agg_func, agg_kwargs)} with the given arguments." | ||
( | ||
is_supported, | ||
unsupported_arguments, | ||
) = check_is_aggregation_supported_in_snowflake(agg_func, agg_kwargs, axis) | ||
if not is_supported: | ||
raise AttributeError( | ||
f"'SeriesGroupBy' object has no attribute '{unsupported_arguments}'" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happens if this is an aggregation that native pandas supports but we do not? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be returning False since we check if there is a corresponding Snowflake aggregation function via get_snowflake_agg_func(). The overall checking logic for if a function is supported should not be changing here |
||
) | ||
|
||
sort = groupby_kwargs.get("sort", True) | ||
|
@@ -6103,11 +6106,15 @@ def agg( | |
# by snowflake engine. | ||
# If we are using Named Aggregations, we need to do our supported check slightly differently. | ||
uses_named_aggs = using_named_aggregations_for_func(func) | ||
if not check_is_aggregation_supported_in_snowflake( | ||
( | ||
is_supported, | ||
unsupported_arguments, | ||
) = check_is_aggregation_supported_in_snowflake( | ||
func, kwargs, axis, _is_df_agg=True | ||
): | ||
ErrorMessage.not_implemented( | ||
f"Snowpark pandas aggregate does not yet support the aggregation {repr_aggregate_function(func, kwargs)} with the given arguments." | ||
) | ||
if not is_supported: | ||
raise AttributeError( | ||
f"'SeriesGroupBy' object has no attribute '{unsupported_arguments}'" | ||
) | ||
|
||
query_compiler = self | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -505,6 +505,26 @@ def test_groupby_agg_on_groupby_columns( | |
) | ||
|
||
|
||
@pytest.mark.parametrize("as_index", [True, False]) | ||
@pytest.mark.parametrize("sort", [True, False]) | ||
@sql_count_checker(query_count=1) | ||
def test_groupby_agg_with_incorrect_func(as_index, sort) -> None: | ||
basic_snowpark_pandas_df = pd.DataFrame( | ||
data=8 * [range(3)], columns=["a", "b", "c"] | ||
) | ||
# basic_snowpark_pandas_df = basic_snowpark_pandas_df.groupby(['a', 'b']).sum() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. did you mean to delete this? |
||
native_pandas = basic_snowpark_pandas_df.to_pandas() | ||
eval_snowpark_pandas_result( | ||
basic_snowpark_pandas_df, | ||
native_pandas, | ||
lambda df: df.groupby(by="a", sort=sort, as_index=as_index).agg( | ||
{"b": sum, "c": "COUNT"} | ||
), | ||
expect_exception=True, | ||
expect_exception_match="'SeriesGroupBy' object has no attribute 'COUNT'", | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"by", ["col1", ["col1", "col2", "col3"], ["col1", "col1", "col2"]] | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch!