Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1887104: Update OBJECT semantics #2971

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ def convert_sf_to_sp_type(
return ArrayType(semi_structured_fill)
if column_type_name == "VARIANT":
return VariantType()
if context._should_use_structured_type_semantics() and column_type_name == "OBJECT":
return StructType()
if column_type_name in {"OBJECT", "MAP"}:
return MapType(semi_structured_fill, semi_structured_fill)
if column_type_name == "GEOGRAPHY":
Expand Down Expand Up @@ -678,6 +680,10 @@ def python_type_to_snow_type(
if tp_args
else None
)
if (
key_type is None or value_type is None
) and context._should_use_structured_type_semantics():
return StructType(), False
return MapType(key_type, value_type), False

if installed_pandas:
Expand Down
6 changes: 5 additions & 1 deletion src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,11 @@ def create_python_udf_or_sp(

if replace and if_not_exists:
raise ValueError("options replace and if_not_exists are incompatible")
if isinstance(return_type, StructType) and not return_type.structured:
if (
isinstance(return_type, StructType)
and not return_type.structured
and object_type != TempObjectType.AGGREGATE_FUNCTION
):
return_sql = f'RETURNS TABLE ({",".join(f"{field.name} {convert_sp_to_sf_type(field.datatype)}" for field in return_type.fields)})'
elif installed_pandas and isinstance(return_type, PandasDataFrameType):
return_sql = f'RETURNS TABLE ({",".join(f"{name} {convert_sp_to_sf_type(datatype)}" for name, datatype in zip(return_type.col_names, return_type.col_types))})'
Expand Down
16 changes: 3 additions & 13 deletions src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,15 +403,9 @@ def __init__(
value_contains_null: bool = True,
) -> None:
if context._should_use_structured_type_semantics():
if (key_type is None and value_type is not None) or (
key_type is not None and value_type is None
):
raise ValueError(
"Must either set both key_type and value_type or leave both unset."
)
self.structured = (
structured if structured is not None else key_type is not None
)
if key_type is None or value_type is None:
raise ValueError("MapType requires key and value type be set.")
self.structured = True
self.key_type = key_type
self.value_type = value_type
else:
Expand Down Expand Up @@ -476,10 +470,6 @@ def valueType(self):

def _fill_ast(self, ast: proto.SpDataType) -> None:
ast.sp_map_type.structured = self.structured
if self.key_type is None or self.value_type is None:
raise NotImplementedError(
"SNOW-1862700: AST does not support empty key or value type."
)
self.key_type._fill_ast(ast.sp_map_type.key_ty)
self.value_type._fill_ast(ast.sp_map_type.value_ty)

Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
warning,
)
from snowflake.snowpark.column import Column
from snowflake.snowpark.types import DataType, MapType
from snowflake.snowpark.types import DataType, MapType, StructType

# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
# Python 3.9 can use both
Expand Down Expand Up @@ -717,7 +717,7 @@ def _do_register_udaf(
"_do_register_udaf",
"Snowflake does not support structured maps as return type for UDAFs. Downcasting to semi-structured object.",
)
return_type = MapType()
return_type = StructType()

# Capture original parameters.
if _emit_ast:
Expand Down
25 changes: 8 additions & 17 deletions tests/integ/scala/test_datatype_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

# Many of the tests have been moved to unit/scala/test_datattype_suite.py
from decimal import Decimal
from unittest import mock

import logging
import pytest
Expand Down Expand Up @@ -533,18 +532,10 @@ def test_structured_dtypes_negative(structured_type_session, structured_type_sup
if not structured_type_support:
pytest.skip("Test requires structured type support.")

# SNOW-1862700: Map Type missing element or value fails to generate AST.
with pytest.raises(
NotImplementedError, match="AST does not support empty key or value type."
):
x = MapType()
x._fill_ast(mock.Mock())
with pytest.raises(ValueError, match="MapType requires key and value type be set."):
MapType()

# Maptype requires both key and value type be set if either is set
with pytest.raises(
ValueError,
match="Must either set both key_type and value_type or leave both unset.",
):
with pytest.raises(ValueError, match="MapType requires key and value type be set."):
MapType(StringType())


Expand Down Expand Up @@ -586,7 +577,7 @@ def finish(self) -> dict:
"Snowflake does not support structured maps as return type for UDAFs. Downcasting to semi-structured object."
in caplog.text
)
assert MapCollector._return_type == MapType()
assert MapCollector._return_type == StructType()


@pytest.mark.skipif(
Expand Down Expand Up @@ -978,8 +969,8 @@ def test_structured_dtypes_cast(structured_type_session, structured_type_support
expected_semi_schema = StructType(
[
StructField("ARR", ArrayType(), nullable=True),
StructField("MAP", MapType(), nullable=True),
StructField("OBJ", MapType(), nullable=True),
StructField("MAP", StructType(), nullable=True),
StructField("OBJ", StructType(), nullable=True),
]
)
expected_structured_schema = StructType(
Expand Down Expand Up @@ -1008,8 +999,8 @@ def test_structured_dtypes_cast(structured_type_session, structured_type_support
schema=StructType(
[
StructField("arr", ArrayType()),
StructField("map", MapType()),
StructField("obj", MapType()),
StructField("map", StructType()),
StructField("obj", StructType()),
]
),
)
Expand Down
Loading