-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SNOW-1952124 decimal coercion #3159
base: main
Are you sure you want to change the base?
Changes from all commits
f649f7a
bafd78d
4356c51
a9f408d
3cb3b56
b11bd84
a6e8153
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -94,6 +94,8 @@ | |
) | ||
|
||
if TYPE_CHECKING: | ||
import snowflake.snowpark.column | ||
|
||
try: | ||
from snowflake.connector.cursor import ResultMetadataV2 | ||
except ImportError: | ||
|
@@ -164,9 +166,11 @@ def convert_metadata_to_sp_type( | |
return StructType( | ||
[ | ||
StructField( | ||
field.name | ||
if context._should_use_structured_type_semantics() | ||
else quote_name(field.name, keep_case=True), | ||
( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adding for better readability? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My editor has black formatting built-in on save, it was done by it automatically. I can revert this if people would like me to! |
||
field.name | ||
if context._should_use_structured_type_semantics() | ||
else quote_name(field.name, keep_case=True) | ||
), | ||
convert_metadata_to_sp_type(field, max_string_size), | ||
nullable=field.is_nullable, | ||
_is_column=False, | ||
|
@@ -358,6 +362,7 @@ def convert_sp_to_sf_type(datatype: DataType, nullable_override=None) -> str: | |
) | ||
|
||
|
||
# TODO: these tuples of types can be used with isinstance, but not as type hints | ||
VALID_PYTHON_TYPES_FOR_LITERAL_VALUE = ( | ||
*PYTHON_TO_SNOW_TYPE_MAPPINGS.keys(), | ||
list, | ||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -26,6 +26,7 @@ | |||||||||
from snowflake.snowpark.functions import iff, lit, when | ||||||||||
from snowflake.snowpark.types import ( | ||||||||||
DataType, | ||||||||||
DecimalType, | ||||||||||
DoubleType, | ||||||||||
FloatType, | ||||||||||
IntegerType, | ||||||||||
|
@@ -44,20 +45,29 @@ | |||||||||
|
||||||||||
|
||||||||||
def _is_value_type_matching_for_na_function( | ||||||||||
value: LiteralType, datatype: DataType | ||||||||||
value: LiteralType, | ||||||||||
datatype: DataType, | ||||||||||
include_decimal: bool = False, | ||||||||||
) -> bool: | ||||||||||
# Python `int` can match into FloatType/DoubleType, | ||||||||||
# but Python `float` can't match IntegerType/LongType. | ||||||||||
# None should be compatible with any Snowpark type. | ||||||||||
int_types = (IntegerType, LongType, FloatType, DoubleType) | ||||||||||
float_types = (FloatType, DoubleType) | ||||||||||
# Python `int` and `float` can also match for DecimalType, | ||||||||||
# for now this is protected by this argument | ||||||||||
if include_decimal: | ||||||||||
int_types = (int_types, DecimalType) | ||||||||||
float_types = (float_types, DecimalType) | ||||||||||
Comment on lines
+60
to
+61
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider flattening the tuple for float_types by concatenating DecimalType instead of nesting tuples (e.g., float_types = float_types + (DecimalType,)). This adjustment will allow isinstance checks to correctly include DecimalType when include_decimal is True.
Suggested change
Copilot is powered by AI, so mistakes are possible. Review output carefully before use. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From https://docs.python.org/3/library/functions.html#isinstance
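(per that documentation, the ``classinfo`` argument of ``isinstance`` may be a tuple of type objects or, recursively, other such tuples, so the nested tuple ``(int_types, DecimalType)`` is a valid argument.)
So the code should work as expected. Unless a maintainer would like me to change this, I won't. |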
||||||||||
return ( | ||||||||||
value is None | ||||||||||
or ( | ||||||||||
isinstance(value, int) | ||||||||||
# bool is a subclass of int, but we don't want to consider it numeric | ||||||||||
and not isinstance(value, bool) | ||||||||||
and isinstance(datatype, (IntegerType, LongType, FloatType, DoubleType)) | ||||||||||
and isinstance(datatype, int_types) | ||||||||||
) | ||||||||||
or (isinstance(value, float) and isinstance(datatype, (FloatType, DoubleType))) | ||||||||||
or (isinstance(value, float) and isinstance(datatype, float_types)) | ||||||||||
or isinstance(datatype, type(python_type_to_snow_type(type(value))[0])) | ||||||||||
) | ||||||||||
|
||||||||||
|
@@ -273,6 +283,9 @@ def fill( | |||||||||
value: Union[LiteralType, Dict[str, LiteralType]], | ||||||||||
subset: Optional[Union[str, Iterable[str]]] = None, | ||||||||||
_emit_ast: bool = True, | ||||||||||
*, | ||||||||||
# keyword only arguments | ||||||||||
include_decimal: bool = False, | ||||||||||
) -> "snowflake.snowpark.DataFrame": | ||||||||||
""" | ||||||||||
Returns a new DataFrame that replaces all null and NaN values in the specified | ||||||||||
|
@@ -289,6 +302,8 @@ def fill( | |||||||||
* If ``subset`` is not provided or ``None``, all columns will be included. | ||||||||||
|
||||||||||
* If ``subset`` is empty, the method returns the original DataFrame. | ||||||||||
include_decimal: Whether to allow Python ``int`` and ``float`` values to fill in | ||||||||||
``DecimalType`` columns. Defaults to ``False``. | ||||||||||
|
||||||||||
Examples:: | ||||||||||
|
||||||||||
|
@@ -445,7 +460,9 @@ def fill( | |||||||||
col = self._dataframe.col(col_name) | ||||||||||
if col_name in normalized_value_dict: | ||||||||||
value = normalized_value_dict[col_name] | ||||||||||
if _is_value_type_matching_for_na_function(value, datatype): | ||||||||||
if _is_value_type_matching_for_na_function( | ||||||||||
value, datatype, include_decimal=include_decimal | ||||||||||
): | ||||||||||
if isinstance(datatype, (FloatType, DoubleType)): | ||||||||||
# iff(col = 'NaN' or col is null, value, col) | ||||||||||
res_columns.append( | ||||||||||
|
@@ -489,6 +506,9 @@ def replace( | |||||||||
value: Optional[Union[LiteralType, Iterable[LiteralType]]] = None, | ||||||||||
subset: Optional[Union[str, Iterable[str]]] = None, | ||||||||||
_emit_ast: bool = True, | ||||||||||
*, | ||||||||||
# keyword only arguments | ||||||||||
include_decimal: bool = False, | ||||||||||
) -> "snowflake.snowpark.DataFrame": | ||||||||||
""" | ||||||||||
Returns a new DataFrame that replaces values in the specified columns. | ||||||||||
|
@@ -508,7 +528,8 @@ def replace( | |||||||||
replaced. If ``cols`` is not provided or ``None``, the replacement | ||||||||||
will be applied to all columns. If ``cols`` is empty, the method | ||||||||||
returns the original DataFrame. | ||||||||||
|
||||||||||
include_decimal: Whether to allow Python ``int`` and ``float`` values to be used as | ||||||||||
keys and replacement values for ``DecimalType`` columns. Defaults to ``False``. | ||||||||||
Examples:: | ||||||||||
|
||||||||||
>>> df = session.create_dataframe([[1, 1.0, "1.0"], [2, 2.0, "2.0"]], schema=["a", "b", "c"]) | ||||||||||
|
@@ -678,8 +699,14 @@ def replace( | |||||||||
case_when = None | ||||||||||
for key, value in replacement.items(): | ||||||||||
if _is_value_type_matching_for_na_function( | ||||||||||
key, datatype | ||||||||||
) and _is_value_type_matching_for_na_function(value, datatype): | ||||||||||
key, | ||||||||||
datatype, | ||||||||||
include_decimal=include_decimal, | ||||||||||
) and _is_value_type_matching_for_na_function( | ||||||||||
value, | ||||||||||
datatype, | ||||||||||
include_decimal=include_decimal, | ||||||||||
): | ||||||||||
cond = col.is_null() if key is None else (col == lit(key)) | ||||||||||
replace_value = lit(None) if value is None else lit(value) | ||||||||||
case_when = ( | ||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
# | ||
import copy | ||
import datetime | ||
import decimal | ||
import json | ||
import logging | ||
import math | ||
|
@@ -2562,8 +2563,44 @@ def test_fillna(session, local_testing_mode): | |
df.fillna(1, subset={1: "a"}) | ||
assert _SUBSET_CHECK_ERROR_MESSAGE in str(ex_info) | ||
|
||
# Fill Decimal columns with int | ||
Utils.check_answer( | ||
TestData.null_data4(session).fillna(123, include_decimal=True), | ||
[ | ||
Row(decimal.Decimal(1), decimal.Decimal(123)), | ||
Row(decimal.Decimal(123), 2), | ||
], | ||
sort=False, | ||
) | ||
# Fill Decimal columns with float | ||
Utils.check_answer( | ||
TestData.null_data4(session).fillna(123.0, include_decimal=True), | ||
[ | ||
Row(decimal.Decimal(1), decimal.Decimal(123)), | ||
Row(decimal.Decimal(123), 2), | ||
], | ||
sort=False, | ||
) | ||
Comment on lines
+2566
to
+2583
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It just occurred to me that, for symmetry, we should support filling int/long/float/double columns with decimals... But for some reason Spark does not allow that, so I don't know what the right answer is. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that's because int/float can be safely fit into decimal, but not the other way around There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree with the type conversion. And spark explicitly mentioned that |
||
# Making sure default still reflects old behavior | ||
Utils.check_answer( | ||
TestData.null_data4(session).fillna(123), | ||
[ | ||
Row(decimal.Decimal(1), None), | ||
Row(None, 2), | ||
], | ||
sort=False, | ||
) | ||
Utils.check_answer( | ||
TestData.null_data4(session).fillna(123.0), | ||
[ | ||
Row(decimal.Decimal(1), None), | ||
Row(None, 2), | ||
], | ||
sort=False, | ||
) | ||
|
||
def test_replace_with_coercion(session): | ||
|
||
def test_replace_with_coercion(session, local_testing_mode): | ||
df = session.create_dataframe( | ||
[[1, 1.0, "1.0"], [2, 2.0, "2.0"]], schema=["a", "b", "c"] | ||
) | ||
|
@@ -2634,6 +2671,48 @@ def test_replace_with_coercion(session): | |
with pytest.raises(ValueError) as ex_info: | ||
df.replace([1], [2, 3]) | ||
assert "to_replace and value lists should be of the same length" in str(ex_info) | ||
if local_testing_mode: | ||
# SNOW-1989698: local test gap | ||
return | ||
# Replace Decimal value with int | ||
Utils.check_answer( | ||
Comment on lines
+2677
to
+2678
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems related to coercion in local testing; you can temporarily skip local testing with: def test_replace_with_coercion(session, local_testing_mode):
# existing code
# ...
if local_testing_mode:
# SNOW-xxxx local test gap
return
# your new test code |
||
TestData.null_data4(session).replace( | ||
decimal.Decimal(1), 123, include_decimal=True | ||
), | ||
[ | ||
Row(decimal.Decimal(123), None), | ||
Row(None, 2), | ||
], | ||
sort=False, | ||
) | ||
# Replace Decimal value with float | ||
Utils.check_answer( | ||
TestData.null_data4(session).replace( | ||
decimal.Decimal(1), 123.0, include_decimal=True | ||
), | ||
[ | ||
Row(decimal.Decimal(123.0), None), | ||
Row(None, 2), | ||
], | ||
sort=False, | ||
) | ||
# Make sure old behavior is untouched | ||
Utils.check_answer( | ||
TestData.null_data4(session).replace(decimal.Decimal(1), 123), | ||
[ | ||
Row(decimal.Decimal(1), None), | ||
Row(None, 2), | ||
], | ||
sort=False, | ||
) | ||
Utils.check_answer( | ||
TestData.null_data4(session).replace(decimal.Decimal(1), 123.0), | ||
[ | ||
Row(decimal.Decimal(1), None), | ||
Row(None, 2), | ||
], | ||
sort=False, | ||
) | ||
|
||
|
||
@pytest.mark.skipif( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we have this import?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Type checking. My editor was telling me that the type-hint at some part of this file was broken