Add functionality in apply_in_pandas to support the Spark API #3162

Open · wants to merge 5 commits into base: main
3 changes: 3 additions & 0 deletions src/snowflake/snowpark/context.py
@@ -26,6 +26,9 @@
_use_structured_type_semantics = False
_use_structured_type_semantics_lock = threading.RLock()

# This is an internal-only global flag, used to determine whether the API is called from snowflake.snowpark_connect.
_is_called_from_snowpark_connect = False


def _should_use_structured_type_semantics():
global _use_structured_type_semantics
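For context, the flag is a plain module-level global. A minimal sketch of how a caller such as snowflake.snowpark_connect might set it — the exact call site is an assumption, not shown in this diff:

import snowflake.snowpark.context as context

# Assumed toggle: snowpark_connect would flip this before delegating to
# apply_in_pandas, so the Spark-compatible code path below is taken.
context._is_called_from_snowpark_connect = True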
32 changes: 31 additions & 1 deletion src/snowflake/snowpark/relational_grouped_dataframe.py
@@ -3,9 +3,12 @@
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import inspect

import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
import snowflake.snowpark.context as context
from snowflake.connector.options import pandas
from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
from snowflake.snowpark import functions
from snowflake.snowpark._internal.analyzer.expression import (
Expression,
@@ -403,8 +406,36 @@ def apply_in_pandas(
- :func:`~snowflake.snowpark.functions.pandas_udtf`
"""

partition_by = [Column(expr, _emit_ast=False) for expr in self._grouping_exprs]

# This is the case where apply_in_pandas is being called from Spark (via
# snowflake.snowpark_connect). Nested column access is not handled; the
# user function is assumed not to access columns in a nested way.
original_columns: Optional[List[str]] = None
key_columns: Optional[List[str]] = None
if context._is_called_from_snowpark_connect:
    if self._dataframe._column_map is not None:
        original_columns = [
            column.spark_name for column in self._dataframe._column_map.columns
        ]
    signature = inspect.signature(func)
    parameters = signature.parameters
    if len(parameters) == 2:
        key_columns = [
            unquote_if_quoted(col.get_name()) for col in partition_by
        ]

class _ApplyInPandas:
    def end_partition(self, pdf: pandas.DataFrame) -> pandas.DataFrame:
        # Extract the grouping-key values from the first row before any
        # rename, since key_columns hold the original column names.
        if key_columns is not None:
            import numpy as np

            key_list = [pdf[key].iloc[0] for key in key_columns]
            numpy_array = np.array(key_list)
            keys = tuple(numpy_array)
        # Restore the Spark-side column names expected by the user function.
        if original_columns is not None:
            pdf.columns = original_columns
        if key_columns is not None:
            return func(keys, pdf)
Contributor commented on lines +429 to +438:
Suggested change
-        if key_columns is not None:
-            import numpy as np
-            key_list = [pdf[key].iloc[0] for key in key_columns]
-            numpy_array = np.array(key_list)
-            keys = tuple(numpy_array)
-        if original_columns is not None:
-            pdf.columns = original_columns
-        if key_columns is not None:
-            return func(keys, pdf)
+        if original_columns is not None:
+            pdf.columns = original_columns
+        if key_columns is not None:
+            import numpy as np
+            key_list = [pdf[key].iloc[0] for key in key_columns]
+            numpy_array = np.array(key_list)
+            keys = tuple(numpy_array)
+            return func(keys, pdf)

nit: can we restructure it this way, grouping the if statements together?

Author replied:

It won't work because we change pdf in
if original_columns is not None: pdf.columns = original_columns
so the keys have to be extracted, by their original column names, before the rename.
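A minimal sketch of why the reordering fails (the column names here are hypothetical, not taken from the PR): once pdf.columns is set to the Spark names, lookups by the original names raise KeyError.

import pandas as pd

# Hypothetical partition with Snowflake-style column names.
pdf = pd.DataFrame({"ID": [1, 1], "V": [10.0, 20.0]})
key_columns = ["ID"]            # unquoted Snowflake grouping-column names
original_columns = ["id", "v"]  # Spark-side names from the column map

keys = tuple(pdf[k].iloc[0] for k in key_columns)  # works: names still match
pdf.columns = original_columns                     # rename to the Spark names
# pdf["ID"] would now raise KeyError, so the key extraction cannot be
# moved after the rename as the suggestion proposes.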

        return func(pdf)

# for vectorized UDTF
@@ -426,7 +457,6 @@ def end_partition(self, pdf: pandas.DataFrame) -> pandas.DataFrame:
_emit_ast=_emit_ast,
**kwargs,
)
partition_by = [Column(expr, _emit_ast=False) for expr in self._grouping_exprs]

df = self._dataframe.select(
    _apply_in_pandas_udtf(*self._dataframe.columns).over(
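Taken together, a minimal sketch of the Spark-style call this change enables. Hedged: the connection parameters and data are hypothetical, and the two-argument function path is only taken when snowflake.snowpark_connect has set the internal flag shown above.

import pandas as pd
from snowflake.snowpark import Session
from snowflake.snowpark.types import StructType, StructField, IntegerType, FloatType

# Hypothetical connection config.
session = Session.builder.configs(connection_parameters).create()
df = session.create_dataframe([(1, 10.0), (1, 20.0), (2, 5.0)], schema=["id", "v"])

def subtract_mean(key, pdf: pd.DataFrame) -> pd.DataFrame:
    # key is a tuple of grouping values for this partition, e.g. (1,)
    pdf["v"] = pdf["v"] - pdf["v"].mean()
    return pdf

output_schema = StructType(
    [StructField("id", IntegerType()), StructField("v", FloatType())]
)
result = df.group_by("id").apply_in_pandas(subtract_mean, output_schema=output_schema)
result.show()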