Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1899126: Fix malformed sql for functions that create internal alias #3136

Merged
merged 6 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
- Improved query generation for `Dataframe.stat.sample_by` to generate a single flat query that scales well with large `fractions` dictionary compared to older method of creating a UNION ALL subquery for each key in `fractions`. To enable this feature, set `session.conf.set("use_simplified_query_generation", True)`.
- `DataFrame.fillna` and `DataFrame.replace` now both support fitting `int` and `float` into `Decimal` columns if `include_decimal` is set to True.

#### Bug Fixes

- Fixed a bug for the following functions that raised errors `.cast()` is applied to their output
- `from_json`
- `size`

### Snowpark Local Testing Updates

#### New Features
Expand Down
83 changes: 61 additions & 22 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
UpdateMergeExpression,
)
from snowflake.snowpark._internal.analyzer.unary_expression import (
_InternalAlias,
Alias,
Cast,
UnaryExpression,
Expand Down Expand Up @@ -213,7 +214,9 @@ def analyze(
if isinstance(expr, Like):
return like_expression(
self.analyze(
expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
self.analyze(
expr.pattern, df_aliased_col_name_to_real_col_name, parse_local_name
Expand All @@ -223,7 +226,9 @@ def analyze(
if isinstance(expr, RegExp):
return regexp_expression(
self.analyze(
expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
self.analyze(
expr.pattern, df_aliased_col_name_to_real_col_name, parse_local_name
Expand All @@ -243,7 +248,9 @@ def analyze(
)
return collate_expression(
self.analyze(
expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
collation_spec,
)
Expand All @@ -254,7 +261,9 @@ def analyze(
field = field.upper()
return subfield_expression(
self.analyze(
expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
field,
)
Expand All @@ -264,7 +273,7 @@ def analyze(
[
(
self.analyze(
condition,
self.internal_alias_extractor(condition),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
Expand All @@ -277,7 +286,7 @@ def analyze(
for condition, value in expr.branches
],
self.analyze(
expr.else_value,
self.internal_alias_extractor(expr.else_value),
df_aliased_col_name_to_real_col_name,
parse_local_name,
)
Expand Down Expand Up @@ -309,21 +318,23 @@ def analyze(
for expression in expr.values:
if self.session.eliminate_numeric_sql_value_cast_enabled:
in_value = self.to_sql_try_avoid_cast(
expression,
self.internal_alias_extractor(expression),
df_aliased_col_name_to_real_col_name,
parse_local_name,
)
else:
in_value = self.analyze(
expression,
self.internal_alias_extractor(expression),
df_aliased_col_name_to_real_col_name,
parse_local_name,
)

in_values.append(in_value)
return in_expression(
self.analyze(
expr.columns, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.columns),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
in_values,
)
Expand Down Expand Up @@ -449,7 +460,9 @@ def analyze(
func_name,
[
self.analyze(
x, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(x),
df_aliased_col_name_to_real_col_name,
parse_local_name,
)
for x in expr.children
],
Expand Down Expand Up @@ -494,7 +507,9 @@ def analyze(
if isinstance(expr, SortOrder):
return order_expression(
self.analyze(
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.child),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
expr.direction.sql,
expr.null_ordering.sql,
Expand All @@ -507,7 +522,9 @@ def analyze(
if isinstance(expr, WithinGroup):
return within_group_expression(
self.analyze(
expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
[
self.analyze(e, df_aliased_col_name_to_real_col_name)
Expand Down Expand Up @@ -558,7 +575,9 @@ def analyze(
if isinstance(expr, ListAgg):
return list_agg(
self.analyze(
expr.col, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.col),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
str_to_sql(expr.delimiter),
expr.is_distinct,
Expand All @@ -578,7 +597,9 @@ def analyze(
return rank_related_function_expression(
expr.sql,
self.analyze(
expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
expr.offset,
self.analyze(
Expand All @@ -593,6 +614,18 @@ def analyze(
str(expr)
) # pragma: no cover

def internal_alias_extractor(self, expr: Expression) -> Expression:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one last comment: can you add comment on when this function needs to called when analyzing expression to give some context to future us?

"""
This function is used to extract the internal alias of an expression. This function
needs to be called whenever an expr is coming from a Column object. This is done because
_InternalAlias is generated for implementing a few functions on the client-side with
the final output column being aliased internally. Such internal aliases need to be
dropped when they are not applied at the top level in a sql nesting level.
"""
if isinstance(expr, _InternalAlias):
return expr.child
return expr

def table_function_expression_extractor(
self,
expr: TableFunctionExpression,
Expand Down Expand Up @@ -682,17 +715,19 @@ def unary_expression_extractor(
),
quoted_name,
)

child = self.internal_alias_extractor(expr.child)
if isinstance(expr, UnresolvedAlias):
expr_str = self.analyze(
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
child, df_aliased_col_name_to_real_col_name, parse_local_name
)
if parse_local_name:
expr_str = expr_str.upper()
return expr_str
elif isinstance(expr, Cast):
return cast_expression(
self.analyze(
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
child, df_aliased_col_name_to_real_col_name, parse_local_name
),
expr.to,
expr.try_,
Expand All @@ -702,7 +737,7 @@ def unary_expression_extractor(
else:
return unary_expression(
self.analyze(
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
child, df_aliased_col_name_to_real_col_name, parse_local_name
),
expr.sql_operator,
expr.operator_first,
Expand All @@ -714,21 +749,23 @@ def binary_operator_extractor(
df_aliased_col_name_to_real_col_name,
parse_local_name=False,
) -> str:
left = self.internal_alias_extractor(expr.left)
right = self.internal_alias_extractor(expr.right)
if self.session.eliminate_numeric_sql_value_cast_enabled:
left_sql_expr = self.to_sql_try_avoid_cast(
expr.left, df_aliased_col_name_to_real_col_name, parse_local_name
left, df_aliased_col_name_to_real_col_name, parse_local_name
)
right_sql_expr = self.to_sql_try_avoid_cast(
expr.right,
right,
df_aliased_col_name_to_real_col_name,
parse_local_name,
)
else:
left_sql_expr = self.analyze(
expr.left, df_aliased_col_name_to_real_col_name, parse_local_name
left, df_aliased_col_name_to_real_col_name, parse_local_name
)
right_sql_expr = self.analyze(
expr.right, df_aliased_col_name_to_real_col_name, parse_local_name
right, df_aliased_col_name_to_real_col_name, parse_local_name
)
if isinstance(expr, BinaryArithmeticExpression):
return binary_arithmetic_expression(
Expand Down Expand Up @@ -809,7 +846,9 @@ def to_sql_try_avoid_cast(
return str(expr.value).upper()
else:
return self.analyze(
expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
)

def resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return {}


class _InternalAlias(Alias):
pass


class UnresolvedAlias(UnaryExpression, NamedExpression):
sql_operator = "AS"
operator_first = False
Expand Down
13 changes: 12 additions & 1 deletion src/snowflake/snowpark/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
Not,
UnaryMinus,
UnresolvedAlias,
_InternalAlias,
)
from snowflake.snowpark._internal.ast.utils import (
build_expr_from_python_val,
Expand Down Expand Up @@ -667,7 +668,7 @@ def in_(
[Row(IS_A_IN_B=True), Row(IS_A_IN_B=False), Row(IS_A_IN_B=False)]

Args:
vals: The lteral values, the columns in the same DataFrame, or a :class:`DataFrame` instance to use
vals: The literal values, the columns in the same DataFrame, or a :class:`DataFrame` instance to use
to check for membership against this column.
"""

Expand Down Expand Up @@ -1313,6 +1314,10 @@ def alias(self, alias: str, _emit_ast: bool = True) -> "Column":
"""Returns a new renamed Column. Alias of :func:`name`."""
return self.name(alias, variant="alias", _emit_ast=_emit_ast)

def _alias(self, alias: str) -> "Column":
"""Returns a new renamed Column called by functions that internally alias the result."""
return self.name(alias, variant="_alias", _emit_ast=False)

@publicapi
def name(
self,
Expand All @@ -1337,6 +1342,12 @@ def name(
elif variant == "name":
ast.fn.column_alias_fn_name = True

if variant == "_alias":
return Column(
_InternalAlias(expr, quote_name(alias)),
_ast=ast_expr,
_emit_ast=_emit_ast,
)
return Column(
Alias(expr, quote_name(alias)), _ast=ast_expr, _emit_ast=_emit_ast
)
Expand Down
8 changes: 4 additions & 4 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3842,7 +3842,7 @@ def array_remove_nulls(col: Column) -> Column:
),
separator=lit(sep, _emit_ast=False),
_emit_ast=False,
).alias(f"CONCAT_WS_IGNORE_NULLS('{sep}', {names})", _emit_ast=False)
)._alias(f"CONCAT_WS_IGNORE_NULLS('{sep}', {names})")


@publicapi
Expand Down Expand Up @@ -5987,7 +5987,7 @@ def window(
lit("end", _emit_ast=False),
dateadd(window_unit, window_duration, window_start, _emit_ast=False),
_emit_ast=False,
).alias("window", _emit_ast=False)
)._alias("window")
ans._ast = ast
return ans

Expand Down Expand Up @@ -6765,7 +6765,7 @@ def from_json(
ans = (
parse_json(e, _emit_ast=False)
.cast(schema, _emit_ast=False)
.alias(f"from_json({c.get_name()})", _emit_ast=False)
._alias(f"from_json({c.get_name()})")
)
ans._ast = ast
return ans
Expand Down Expand Up @@ -7467,7 +7467,7 @@ def size(col: ColumnOrName, _emit_ast: bool = True) -> Column:
_emit_ast=False,
)
.otherwise(lit(None), _emit_ast=False)
.alias(f"SIZE({c.get_name()})", _emit_ast=False)
._alias(f"SIZE({c.get_name()})")
)
result._ast = ast
return result
Expand Down
40 changes: 39 additions & 1 deletion tests/integ/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,20 @@

from snowflake.snowpark import Row
from snowflake.snowpark.exceptions import SnowparkColumnException, SnowparkSQLException
from snowflake.snowpark.functions import col, lit, parse_json, when, hour, minute
from snowflake.snowpark.functions import (
col,
lit,
parse_json,
second,
to_timestamp,
when,
hour,
minute,
window,
)
from snowflake.snowpark.types import (
IntegerType,
StringType,
StructField,
StructType,
TimestampTimeZone,
Expand Down Expand Up @@ -144,6 +155,33 @@ def test_contains(session):
)


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="window function is not supported in Local Testing",
)
def test_internal_alias(session):
df = session.create_dataframe(
[[datetime.datetime(1970, 1, 1, 0, 0, 0)]], schema=["ts"]
)
Utils.check_answer(
df.select(window(df.ts, "10 seconds").alias("my_alias")),
[
Row(
MY_ALIAS='{\n "end": "1970-01-01 00:00:10.000",\n "start": "1970-01-01 00:00:00.000"\n}'
)
],
)

Utils.check_answer(
df.select(window(df.ts, "10 seconds").cast(StringType())),
[Row('{"end":"1970-01-01 00:00:10.000","start":"1970-01-01 00:00:00.000"}')],
)
Utils.check_answer(
df.select(second(to_timestamp(window(df.ts, "10 seconds")["end"]))),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this second example what does the final column name end up being?

[Row(10)],
)


def test_when_accept_literal_value(session):
assert TestData.null_data1(session).select(
when(col("a").is_null(), 5).when(col("a") == 1, 6).otherwise(7).as_("a")
Expand Down
Loading