From 8ff535db2fc8b697df456640f0d542bf9b923716 Mon Sep 17 00:00:00 2001 From: Wufan Shangguan Date: Thu, 13 Mar 2025 16:46:53 -0700 Subject: [PATCH 1/7] add Decimaltype --- src/snowflake/snowpark/dataframe_na_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/dataframe_na_functions.py b/src/snowflake/snowpark/dataframe_na_functions.py index 3113d7ecf..94dc08f6e 100644 --- a/src/snowflake/snowpark/dataframe_na_functions.py +++ b/src/snowflake/snowpark/dataframe_na_functions.py @@ -26,6 +26,7 @@ from snowflake.snowpark.functions import iff, lit, when from snowflake.snowpark.types import ( DataType, + DecimalType, DoubleType, FloatType, IntegerType, @@ -55,7 +56,9 @@ def _is_value_type_matching_for_na_function( isinstance(value, int) # bool is a subclass of int, but we don't want to consider it numeric and not isinstance(value, bool) - and isinstance(datatype, (IntegerType, LongType, FloatType, DoubleType)) + and isinstance( + datatype, (IntegerType, LongType, FloatType, DoubleType, DecimalType) + ) ) or (isinstance(value, float) and isinstance(datatype, (FloatType, DoubleType))) or isinstance(datatype, type(python_type_to_snow_type(type(value))[0])) From 29ca3cd31ee91d1a1c2d94506f6a149d115b449f Mon Sep 17 00:00:00 2001 From: Mark Keller Date: Fri, 14 Mar 2025 12:02:03 -0700 Subject: [PATCH 2/7] formatting and type-hint fixes --- src/snowflake/snowpark/_internal/type_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index b59ae1d82..85f9de545 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -94,6 +94,8 @@ ) if TYPE_CHECKING: + import snowflake.snowpark.column + try: from snowflake.connector.cursor import ResultMetadataV2 except ImportError: @@ -164,9 +166,11 @@ def convert_metadata_to_sp_type( return StructType( [ 
StructField( - field.name - if context._should_use_structured_type_semantics() - else quote_name(field.name, keep_case=True), + ( + field.name + if context._should_use_structured_type_semantics() + else quote_name(field.name, keep_case=True) + ), convert_metadata_to_sp_type(field, max_string_size), nullable=field.is_nullable, _is_column=False, @@ -358,6 +362,7 @@ def convert_sp_to_sf_type(datatype: DataType, nullable_override=None) -> str: ) +# TODO: these tuples of types can be used with isinstance, but not as type hints VALID_PYTHON_TYPES_FOR_LITERAL_VALUE = ( *PYTHON_TO_SNOW_TYPE_MAPPINGS.keys(), list, From 02f5ad30bd6ce499bcf9be22592beb000226280d Mon Sep 17 00:00:00 2001 From: Mark Keller Date: Fri, 14 Mar 2025 12:30:28 -0700 Subject: [PATCH 3/7] adding arguments to protect against breaking change for Decimal types --- .../snowpark/dataframe_na_functions.py | 37 +++++++++++++++---- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_na_functions.py b/src/snowflake/snowpark/dataframe_na_functions.py index 94dc08f6e..7d6b7f9b3 100644 --- a/src/snowflake/snowpark/dataframe_na_functions.py +++ b/src/snowflake/snowpark/dataframe_na_functions.py @@ -45,22 +45,29 @@ def _is_value_type_matching_for_na_function( - value: LiteralType, datatype: DataType + value: LiteralType, + datatype: DataType, + include_decimal: bool = False, ) -> bool: # Python `int` can match into FloatType/DoubleType, # but Python `float` can't match IntegerType/LongType. # None should be compatible with any Snowpark type. 
+ int_types = (IntegerType, LongType, FloatType, DoubleType) + float_types = (FloatType, DoubleType) + # Python `int` and `float` can also match for DecimalType, + # for now this is protected by this argument + if include_decimal: + int_types = (int_types, DecimalType) + float_types = (float_types, DecimalType) return ( value is None or ( isinstance(value, int) # bool is a subclass of int, but we don't want to consider it numeric and not isinstance(value, bool) - and isinstance( - datatype, (IntegerType, LongType, FloatType, DoubleType, DecimalType) - ) + and isinstance(datatype, int_types) ) - or (isinstance(value, float) and isinstance(datatype, (FloatType, DoubleType))) + or (isinstance(value, float) and isinstance(datatype, float_types)) or isinstance(datatype, type(python_type_to_snow_type(type(value))[0])) ) @@ -276,6 +283,9 @@ def fill( value: Union[LiteralType, Dict[str, LiteralType]], subset: Optional[Union[str, Iterable[str]]] = None, _emit_ast: bool = True, + *, + # keyword only arguments + include_decimal: bool = False, ) -> "snowflake.snowpark.DataFrame": """ Returns a new DataFrame that replaces all null and NaN values in the specified @@ -448,7 +458,9 @@ def fill( col = self._dataframe.col(col_name) if col_name in normalized_value_dict: value = normalized_value_dict[col_name] - if _is_value_type_matching_for_na_function(value, datatype): + if _is_value_type_matching_for_na_function( + value, datatype, include_decimal=include_decimal + ): if isinstance(datatype, (FloatType, DoubleType)): # iff(col = 'NaN' or col is null, value, col) res_columns.append( @@ -492,6 +504,9 @@ def replace( value: Optional[Union[LiteralType, Iterable[LiteralType]]] = None, subset: Optional[Union[str, Iterable[str]]] = None, _emit_ast: bool = True, + *, + # keyword only arguments + include_decimal: bool = False, ) -> "snowflake.snowpark.DataFrame": """ Returns a new DataFrame that replaces values in the specified columns. 
@@ -681,8 +696,14 @@ def replace( case_when = None for key, value in replacement.items(): if _is_value_type_matching_for_na_function( - key, datatype - ) and _is_value_type_matching_for_na_function(value, datatype): + key, + datatype, + include_decimal=include_decimal, + ) and _is_value_type_matching_for_na_function( + value, + datatype, + include_decimal=include_decimal, + ): cond = col.is_null() if key is None else (col == lit(key)) replace_value = lit(None) if value is None else lit(value) case_when = ( From 99c9e3b377750a5b110f94601d35cbc51c45efe1 Mon Sep 17 00:00:00 2001 From: Mark Keller Date: Fri, 14 Mar 2025 13:29:27 -0700 Subject: [PATCH 4/7] adding test for new behaviour --- tests/integ/test_dataframe.py | 76 +++++++++++++++++++++++++++++++++++ tests/utils.py | 9 +++++ 2 files changed, 85 insertions(+) diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index 19724d06d..81e2eae18 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -4,6 +4,7 @@ # import copy import datetime +import decimal import json import logging import math @@ -2562,6 +2563,42 @@ def test_fillna(session, local_testing_mode): df.fillna(1, subset={1: "a"}) assert _SUBSET_CHECK_ERROR_MESSAGE in str(ex_info) + # Fill Decimal columns with int + Utils.check_answer( + TestData.null_data4(session).fillna(123, include_decimal=True), + [ + Row(decimal.Decimal(1), decimal.Decimal(123)), + Row(decimal.Decimal(123), 2), + ], + sort=False, + ) + # Fill Decimal columns with float + Utils.check_answer( + TestData.null_data4(session).fillna(123.0, include_decimal=True), + [ + Row(decimal.Decimal(1), decimal.Decimal(123)), + Row(decimal.Decimal(123), 2), + ], + sort=False, + ) + # Making sure default still reflects old behavior + Utils.check_answer( + TestData.null_data4(session).fillna(123), + [ + Row(decimal.Decimal(1), None), + Row(None, 2), + ], + sort=False, + ) + Utils.check_answer( + TestData.null_data4(session).fillna(123.0), + [ + 
Row(decimal.Decimal(1), None), + Row(None, 2), + ], + sort=False, + ) + def test_replace_with_coercion(session): df = session.create_dataframe( @@ -2634,6 +2671,45 @@ def test_replace_with_coercion(session): with pytest.raises(ValueError) as ex_info: df.replace([1], [2, 3]) assert "to_replace and value lists should be of the same length" in str(ex_info) + # Replace Decimal value with int + Utils.check_answer( + TestData.null_data4(session).replace( + decimal.Decimal(1), 123, include_decimal=True + ), + [ + Row(decimal.Decimal(123), None), + Row(None, 2), + ], + sort=False, + ) + # Replace Decimal value with float + Utils.check_answer( + TestData.null_data4(session).replace( + decimal.Decimal(1), 123.0, include_decimal=True + ), + [ + Row(decimal.Decimal(123.0), None), + Row(None, 2), + ], + sort=False, + ) + # Make sure old behavior is untouched + Utils.check_answer( + TestData.null_data4(session).replace(decimal.Decimal(1), 123), + [ + Row(decimal.Decimal(1), None), + Row(None, 2), + ], + sort=False, + ) + Utils.check_answer( + TestData.null_data4(session).replace(decimal.Decimal(1), 123.0), + [ + Row(decimal.Decimal(1), None), + Row(None, 2), + ], + sort=False, + ) @pytest.mark.skipif( diff --git a/tests/utils.py b/tests/utils.py index 047334440..00538094d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -636,6 +636,15 @@ def null_data3(cls, session: "Session", local_testing_mode=False) -> DataFrame: ) ) + @classmethod + def null_data4(cls, session: "Session") -> DataFrame: + return session.create_dataframe( + [ + [Decimal(1), None], + [None, Decimal(2)], + ] + ).to_df(["a", "b"]) + @classmethod def integer1(cls, session: "Session") -> DataFrame: return session.create_dataframe([[1], [2], [3]]).to_df(["a"]) From cf255d926581a15788233b8aac8232aacfde09c2 Mon Sep 17 00:00:00 2001 From: Mark Keller Date: Fri, 14 Mar 2025 13:48:38 -0700 Subject: [PATCH 5/7] adding changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md 
b/CHANGELOG.md index 3ee552cbf..44159fb9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ #### Improvements - Improved query generation for `Dataframe.stat.sample_by` to generate a single flat query that scales well with large `fractions` dictionary compared to older method of creating a UNION ALL subquery for each key in `fractions`. To enable this feature, set `session.conf.set("use_simplified_query_generation", True)`. +- `DataFrame.fillna` and `DataFrame.replace` now both support fitting `int` and `float` into `Decimal` columns if `include_decimal` is set to True. ### Snowpark Local Testing Updates From 2cb04434c00f5136713cb834f0060d6d9273d83d Mon Sep 17 00:00:00 2001 From: Mark Keller Date: Mon, 17 Mar 2025 11:37:30 -0700 Subject: [PATCH 6/7] adding doc-strings --- src/snowflake/snowpark/dataframe_na_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/dataframe_na_functions.py b/src/snowflake/snowpark/dataframe_na_functions.py index 7d6b7f9b3..a4e2e415a 100644 --- a/src/snowflake/snowpark/dataframe_na_functions.py +++ b/src/snowflake/snowpark/dataframe_na_functions.py @@ -302,6 +302,8 @@ def fill( * If ``subset`` is not provided or ``None``, all columns will be included. * If ``subset`` is empty, the method returns the original DataFrame. + include_decimal: Whether to allow ``int`` and ``float`` values to fill in + ``DecimalType`` columns. Examples:: @@ -526,7 +528,8 @@ def replace( replaced. If ``cols`` is not provided or ``None``, the replacement will be applied to all columns. If ``cols`` is empty, the method returns the original DataFrame. - + include_decimal: Whether to allow ``int`` and ``float`` values to match and + replace values in ``DecimalType`` columns. 
Examples:: >>> df = session.create_dataframe([[1, 1.0, "1.0"], [2, 2.0, "2.0"]], schema=["a", "b", "c"]) From 862c43125e68260d13d330473261576c967bce1f Mon Sep 17 00:00:00 2001 From: Mark Keller Date: Mon, 17 Mar 2025 11:46:46 -0700 Subject: [PATCH 7/7] fix local testing tests --- tests/integ/test_dataframe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index 81e2eae18..2de12c038 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -2600,7 +2600,7 @@ def test_fillna(session, local_testing_mode): ) -def test_replace_with_coercion(session): +def test_replace_with_coercion(session, local_testing_mode): df = session.create_dataframe( [[1, 1.0, "1.0"], [2, 2.0, "2.0"]], schema=["a", "b", "c"] ) @@ -2671,6 +2671,9 @@ def test_replace_with_coercion(session): with pytest.raises(ValueError) as ex_info: df.replace([1], [2, 3]) assert "to_replace and value lists should be of the same length" in str(ex_info) + if local_testing_mode: + # SNOW-1989698: local test gap + return # Replace Decimal value with int Utils.check_answer( TestData.null_data4(session).replace(