SNOW-1952124 decimal coercion #3159

Open · wants to merge 7 commits into base: main
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@
#### Improvements

- Improved query generation for `Dataframe.stat.sample_by` to generate a single flat query that scales well with a large `fractions` dictionary, compared to the older method of creating a UNION ALL subquery for each key in `fractions`. To enable this feature, set `session.conf.set("use_simplified_query_generation", True)`.
- `DataFrame.fillna` and `DataFrame.replace` now both support fitting `int` and `float` into `Decimal` columns when `include_decimal` is set to `True`.

### Snowpark Local Testing Updates

11 changes: 8 additions & 3 deletions src/snowflake/snowpark/_internal/type_utils.py
@@ -94,6 +94,8 @@
)

if TYPE_CHECKING:
import snowflake.snowpark.column

Reviewer:
Why do we have this import?

Collaborator (author):
Type checking. My editor was telling me that the type hint in some part of this file was broken.


try:
from snowflake.connector.cursor import ResultMetadataV2
except ImportError:
@@ -164,9 +166,11 @@ def convert_metadata_to_sp_type(
return StructType(
[
StructField(
field.name
if context._should_use_structured_type_semantics()
else quote_name(field.name, keep_case=True),
(

Reviewer:
Adding for better readability?

Collaborator (author):
My editor has black formatting built in on save; this change was made automatically. I can revert it if people would like me to!
field.name
if context._should_use_structured_type_semantics()
else quote_name(field.name, keep_case=True)
),
convert_metadata_to_sp_type(field, max_string_size),
nullable=field.is_nullable,
_is_column=False,
@@ -358,6 +362,7 @@ def convert_sp_to_sf_type(datatype: DataType, nullable_override=None) -> str:
)


# TODO: these tuples of types can be used with isinstance, but not as type hints
VALID_PYTHON_TYPES_FOR_LITERAL_VALUE = (
*PYTHON_TO_SNOW_TYPE_MAPPINGS.keys(),
list,
41 changes: 34 additions & 7 deletions src/snowflake/snowpark/dataframe_na_functions.py
@@ -26,6 +26,7 @@
from snowflake.snowpark.functions import iff, lit, when
from snowflake.snowpark.types import (
DataType,
DecimalType,
DoubleType,
FloatType,
IntegerType,
@@ -44,20 +45,29 @@


def _is_value_type_matching_for_na_function(
value: LiteralType, datatype: DataType
value: LiteralType,
datatype: DataType,
include_decimal: bool = False,
) -> bool:
# Python `int` can match into FloatType/DoubleType,
# but Python `float` can't match IntegerType/LongType.
# None should be compatible with any Snowpark type.
int_types = (IntegerType, LongType, FloatType, DoubleType)
float_types = (FloatType, DoubleType)
# Python `int` and `float` can also match DecimalType,
# but for now this is gated behind the `include_decimal` argument
if include_decimal:
int_types = (int_types, DecimalType)
float_types = (float_types, DecimalType)
Comment on lines +60 to +61

Copilot AI (Mar 14, 2025):
Consider flattening the tuples by concatenating DecimalType instead of nesting tuples (e.g., float_types = float_types + (DecimalType,)), so the isinstance checks include DecimalType when include_decimal is True.

Suggested change:
-    int_types = (int_types, DecimalType)
-    float_types = (float_types, DecimalType)
+    int_types = int_types + (DecimalType,)
+    float_types = float_types + (DecimalType,)

Collaborator (author):
From https://docs.python.org/3/library/functions.html#isinstance:

"If classinfo is a tuple of type objects (or recursively, other such tuples)"

So the code should work as expected. Unless a maintainer would like me to change this, I won't.
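The author's point is easy to verify in plain Python. A minimal check, using builtin types as stand-ins for Snowpark's type classes (an illustrative assumption):

```python
from decimal import Decimal

# isinstance accepts recursively nested tuples of classes, so the
# nested spelling in the PR and Copilot's flattened spelling are
# equivalent for membership checks.
base_types = (int, float)
nested = (base_types, Decimal)   # as written in the PR
flat = base_types + (Decimal,)   # as Copilot suggested

assert isinstance(Decimal("1.5"), nested)
assert isinstance(Decimal("1.5"), flat)
assert isinstance(3, nested) and isinstance(3, flat)
assert not isinstance("x", nested)
```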

return (
value is None
or (
isinstance(value, int)
# bool is a subclass of int, but we don't want to consider it numeric
and not isinstance(value, bool)
and isinstance(datatype, (IntegerType, LongType, FloatType, DoubleType))
and isinstance(datatype, int_types)
)
or (isinstance(value, float) and isinstance(datatype, (FloatType, DoubleType)))
or (isinstance(value, float) and isinstance(datatype, float_types))
or isinstance(datatype, type(python_type_to_snow_type(type(value))[0]))
)
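Stripped of Snowpark internals, the matcher above reduces to the following sketch. The stand-in type classes and the omission of the final `python_type_to_snow_type` fallback clause are simplifications for illustration:

```python
# Stand-ins for Snowpark's DataType subclasses (assumption: the real
# code uses snowflake.snowpark.types.IntegerType, DecimalType, etc.)
class IntegerType: pass
class LongType: pass
class FloatType: pass
class DoubleType: pass
class DecimalType: pass

def is_value_type_matching(value, datatype, include_decimal=False):
    # Python `int` can match FloatType/DoubleType, but Python `float`
    # cannot match IntegerType/LongType; None matches any type.
    int_types = (IntegerType, LongType, FloatType, DoubleType)
    float_types = (FloatType, DoubleType)
    if include_decimal:
        # Nested tuples are valid classinfo for isinstance.
        int_types = (int_types, DecimalType)
        float_types = (float_types, DecimalType)
    return (
        value is None
        or (
            isinstance(value, int)
            and not isinstance(value, bool)  # bool is a subclass of int
            and isinstance(datatype, int_types)
        )
        or (isinstance(value, float) and isinstance(datatype, float_types))
    )
```

With `include_decimal=False` the Decimal checks fall through to `False`, preserving the old behavior.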

@@ -273,6 +283,9 @@ def fill(
value: Union[LiteralType, Dict[str, LiteralType]],
subset: Optional[Union[str, Iterable[str]]] = None,
_emit_ast: bool = True,
*,
# keyword only arguments
include_decimal: bool = False,
) -> "snowflake.snowpark.DataFrame":
"""
Returns a new DataFrame that replaces all null and NaN values in the specified
@@ -289,6 +302,8 @@
* If ``subset`` is not provided or ``None``, all columns will be included.

* If ``subset`` is empty, the method returns the original DataFrame.
include_decimal: Whether to allow ``int`` and ``float`` values to fill
``DecimalType`` columns.

Examples::

@@ -445,7 +460,9 @@ def fill(
col = self._dataframe.col(col_name)
if col_name in normalized_value_dict:
value = normalized_value_dict[col_name]
if _is_value_type_matching_for_na_function(value, datatype):
if _is_value_type_matching_for_na_function(
value, datatype, include_decimal=include_decimal
):
if isinstance(datatype, (FloatType, DoubleType)):
# iff(col = 'NaN' or col is null, value, col)
res_columns.append(
@@ -489,6 +506,9 @@ def replace(
value: Optional[Union[LiteralType, Iterable[LiteralType]]] = None,
subset: Optional[Union[str, Iterable[str]]] = None,
_emit_ast: bool = True,
*,
# keyword only arguments
include_decimal: bool = False,
) -> "snowflake.snowpark.DataFrame":
"""
Returns a new DataFrame that replaces values in the specified columns.
@@ -508,7 +528,8 @@
replaced. If ``cols`` is not provided or ``None``, the replacement
will be applied to all columns. If ``cols`` is empty, the method
returns the original DataFrame.

include_decimal: Whether to allow ``int`` and ``float`` values to match and
replace values in ``DecimalType`` columns.

Examples::

>>> df = session.create_dataframe([[1, 1.0, "1.0"], [2, 2.0, "2.0"]], schema=["a", "b", "c"])
@@ -678,8 +699,14 @@ def replace(
case_when = None
for key, value in replacement.items():
if _is_value_type_matching_for_na_function(
key, datatype
) and _is_value_type_matching_for_na_function(value, datatype):
key,
datatype,
include_decimal=include_decimal,
) and _is_value_type_matching_for_na_function(
value,
datatype,
include_decimal=include_decimal,
):
cond = col.is_null() if key is None else (col == lit(key))
replace_value = lit(None) if value is None else lit(value)
case_when = (
81 changes: 80 additions & 1 deletion tests/integ/test_dataframe.py
@@ -4,6 +4,7 @@
#
import copy
import datetime
import decimal
import json
import logging
import math
@@ -2562,8 +2563,44 @@ def test_fillna(session, local_testing_mode):
df.fillna(1, subset={1: "a"})
assert _SUBSET_CHECK_ERROR_MESSAGE in str(ex_info)

# Fill Decimal columns with int
Utils.check_answer(
TestData.null_data4(session).fillna(123, include_decimal=True),
[
Row(decimal.Decimal(1), decimal.Decimal(123)),
Row(decimal.Decimal(123), 2),
],
sort=False,
)
# Fill Decimal columns with float
Utils.check_answer(
TestData.null_data4(session).fillna(123.0, include_decimal=True),
[
Row(decimal.Decimal(1), decimal.Decimal(123)),
Row(decimal.Decimal(123), 2),
],
sort=False,
)
Comment on lines +2566 to +2583

Contributor:
It just occurred to me that, for symmetry, we should support filling int/long/float/double columns with decimals... But for some reason Spark does not allow that, so I don't know what the right answer is.

Collaborator (author):
I think that's because int/float can be safely fit into decimal, but not the other way around.

Contributor:
Agree with the type conversion. And Spark explicitly mentions that a decimal can't be the input for the replace API.
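The asymmetry the author describes can be illustrated with the stdlib `decimal` module:

```python
from decimal import Decimal

# An int converts to Decimal exactly.
assert Decimal(123) == 123

# A float converts to a Decimal that is exact for that float's binary
# value, so no information is lost going float -> Decimal.
assert Decimal(2.5) == 2.5

# Going the other way, Decimal -> float, can silently lose precision:
# a float carries only about 17 significant decimal digits.
precise = Decimal("0.12345678901234567890123456789")
assert Decimal(str(float(precise))) != precise
```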

# Making sure default still reflects old behavior
Utils.check_answer(
TestData.null_data4(session).fillna(123),
[
Row(decimal.Decimal(1), None),
Row(None, 2),
],
sort=False,
)
Utils.check_answer(
TestData.null_data4(session).fillna(123.0),
[
Row(decimal.Decimal(1), None),
Row(None, 2),
],
sort=False,
)

def test_replace_with_coercion(session):

def test_replace_with_coercion(session, local_testing_mode):
df = session.create_dataframe(
[[1, 1.0, "1.0"], [2, 2.0, "2.0"]], schema=["a", "b", "c"]
)
@@ -2634,6 +2671,48 @@ def test_replace_with_coercion(session):
with pytest.raises(ValueError) as ex_info:
df.replace([1], [2, 3])
assert "to_replace and value lists should be of the same length" in str(ex_info)
if local_testing_mode:
# SNOW-1989698: local test gap
return
# Replace Decimal value with int
Utils.check_answer(
Comment on lines +2677 to +2678

Contributor:
It seems related to coercion in local testing; you can temporarily skip the local testing by:

def test_replace_with_coercion(session, local_testing_mode):
    # existing code
    # ...
    if local_testing_mode:
        # SNOW-xxxx local test gap
        return
    # your new test code

TestData.null_data4(session).replace(
decimal.Decimal(1), 123, include_decimal=True
),
[
Row(decimal.Decimal(123), None),
Row(None, 2),
],
sort=False,
)
# Replace Decimal value with float
Utils.check_answer(
TestData.null_data4(session).replace(
decimal.Decimal(1), 123.0, include_decimal=True
),
[
Row(decimal.Decimal(123.0), None),
Row(None, 2),
],
sort=False,
)
# Make sure old behavior is untouched
Utils.check_answer(
TestData.null_data4(session).replace(decimal.Decimal(1), 123),
[
Row(decimal.Decimal(1), None),
Row(None, 2),
],
sort=False,
)
Utils.check_answer(
TestData.null_data4(session).replace(decimal.Decimal(1), 123.0),
[
Row(decimal.Decimal(1), None),
Row(None, 2),
],
sort=False,
)


@pytest.mark.skipif(
9 changes: 9 additions & 0 deletions tests/utils.py
@@ -630,6 +630,15 @@ def null_data3(cls, session: "Session", local_testing_mode=False) -> DataFrame:
)
)

@classmethod
def null_data4(cls, session: "Session") -> DataFrame:
return session.create_dataframe(
[
[Decimal(1), None],
[None, Decimal(2)],
]
).to_df(["a", "b"])

@classmethod
def integer1(cls, session: "Session") -> DataFrame:
return session.create_dataframe([[1], [2], [3]]).to_df(["a"])