Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1899126: Fix malformed sql for functions that create internal alias #3136

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@

- Improved query generation for `Dataframe.stat.sample_by` to generate a single flat query that scales well with large `fractions` dictionary compared to older method of creating a UNION ALL subquery for each key in `fractions`. To enable this feature, set `session.conf.set("use_simplified_query_generation", True)`.

#### Bug Fixes

- Fixed a bug for the following functions that raised errors when `.cast()` is applied to their output
- `from_json`
- `size`

### Snowpark Local Testing Updates

#### New Features
Expand Down
76 changes: 54 additions & 22 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
UpdateMergeExpression,
)
from snowflake.snowpark._internal.analyzer.unary_expression import (
_InternalAlias,
Alias,
Cast,
UnaryExpression,
Expand Down Expand Up @@ -213,7 +214,9 @@ def analyze(
if isinstance(expr, Like):
return like_expression(
self.analyze(
expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
self.analyze(
expr.pattern, df_aliased_col_name_to_real_col_name, parse_local_name
Expand All @@ -223,7 +226,9 @@ def analyze(
if isinstance(expr, RegExp):
return regexp_expression(
self.analyze(
expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
self.analyze(
expr.pattern, df_aliased_col_name_to_real_col_name, parse_local_name
Expand All @@ -243,7 +248,9 @@ def analyze(
)
return collate_expression(
self.analyze(
expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
collation_spec,
)
Expand All @@ -254,7 +261,9 @@ def analyze(
field = field.upper()
return subfield_expression(
self.analyze(
expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
field,
)
Expand All @@ -264,7 +273,7 @@ def analyze(
[
(
self.analyze(
condition,
self.internal_alias_extractor(condition),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
Expand All @@ -277,7 +286,7 @@ def analyze(
for condition, value in expr.branches
],
self.analyze(
expr.else_value,
self.internal_alias_extractor(expr.else_value),
df_aliased_col_name_to_real_col_name,
parse_local_name,
)
Expand Down Expand Up @@ -309,21 +318,23 @@ def analyze(
for expression in expr.values:
if self.session.eliminate_numeric_sql_value_cast_enabled:
in_value = self.to_sql_try_avoid_cast(
expression,
self.internal_alias_extractor(expression),
df_aliased_col_name_to_real_col_name,
parse_local_name,
)
else:
in_value = self.analyze(
expression,
self.internal_alias_extractor(expression),
df_aliased_col_name_to_real_col_name,
parse_local_name,
)

in_values.append(in_value)
return in_expression(
self.analyze(
expr.columns, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.columns),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
in_values,
)
Expand Down Expand Up @@ -449,7 +460,9 @@ def analyze(
func_name,
[
self.analyze(
x, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(x),
df_aliased_col_name_to_real_col_name,
parse_local_name,
)
for x in expr.children
],
Expand Down Expand Up @@ -494,7 +507,9 @@ def analyze(
if isinstance(expr, SortOrder):
return order_expression(
self.analyze(
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.child),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
expr.direction.sql,
expr.null_ordering.sql,
Expand All @@ -507,7 +522,9 @@ def analyze(
if isinstance(expr, WithinGroup):
return within_group_expression(
self.analyze(
expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
[
self.analyze(e, df_aliased_col_name_to_real_col_name)
Expand Down Expand Up @@ -558,7 +575,9 @@ def analyze(
if isinstance(expr, ListAgg):
return list_agg(
self.analyze(
expr.col, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.col),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
str_to_sql(expr.delimiter),
expr.is_distinct,
Expand All @@ -578,7 +597,9 @@ def analyze(
return rank_related_function_expression(
expr.sql,
self.analyze(
expr.expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr.expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
),
expr.offset,
self.analyze(
Expand All @@ -593,6 +614,11 @@ def analyze(
str(expr)
) # pragma: no cover

def internal_alias_extractor(self, expr: Expression) -> Expression:
    """Unwrap an ``_InternalAlias`` node and return its child expression.

    Some functions (e.g. ``from_json``, ``size``, ``window``) alias their
    result internally via ``Column._alias``. When such a column is then used
    as a sub-expression (cast, binary operation, subfield access, LIKE,
    IN, sort key, ...), emitting the internal ``AS "name"`` clause inline
    would produce malformed SQL. The analyzer therefore calls this helper on
    every child expression before analyzing it, stripping only the internal
    alias. User-created ``Alias`` nodes and all other expressions are
    returned unchanged.

    Args:
        expr: The expression about to be analyzed as a sub-expression.

    Returns:
        ``expr.child`` if ``expr`` is an ``_InternalAlias``; otherwise
        ``expr`` itself.
    """
    if isinstance(expr, _InternalAlias):
        return expr.child
    return expr

def table_function_expression_extractor(
self,
expr: TableFunctionExpression,
Expand Down Expand Up @@ -682,17 +708,19 @@ def unary_expression_extractor(
),
quoted_name,
)

child = self.internal_alias_extractor(expr.child)
if isinstance(expr, UnresolvedAlias):
expr_str = self.analyze(
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
child, df_aliased_col_name_to_real_col_name, parse_local_name
)
if parse_local_name:
expr_str = expr_str.upper()
return expr_str
elif isinstance(expr, Cast):
return cast_expression(
self.analyze(
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
child, df_aliased_col_name_to_real_col_name, parse_local_name
),
expr.to,
expr.try_,
Expand All @@ -702,7 +730,7 @@ def unary_expression_extractor(
else:
return unary_expression(
self.analyze(
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
child, df_aliased_col_name_to_real_col_name, parse_local_name
),
expr.sql_operator,
expr.operator_first,
Expand All @@ -714,21 +742,23 @@ def binary_operator_extractor(
df_aliased_col_name_to_real_col_name,
parse_local_name=False,
) -> str:
left = self.internal_alias_extractor(expr.left)
right = self.internal_alias_extractor(expr.right)
if self.session.eliminate_numeric_sql_value_cast_enabled:
left_sql_expr = self.to_sql_try_avoid_cast(
expr.left, df_aliased_col_name_to_real_col_name, parse_local_name
left, df_aliased_col_name_to_real_col_name, parse_local_name
)
right_sql_expr = self.to_sql_try_avoid_cast(
expr.right,
right,
df_aliased_col_name_to_real_col_name,
parse_local_name,
)
else:
left_sql_expr = self.analyze(
expr.left, df_aliased_col_name_to_real_col_name, parse_local_name
left, df_aliased_col_name_to_real_col_name, parse_local_name
)
right_sql_expr = self.analyze(
expr.right, df_aliased_col_name_to_real_col_name, parse_local_name
right, df_aliased_col_name_to_real_col_name, parse_local_name
)
if isinstance(expr, BinaryArithmeticExpression):
return binary_arithmetic_expression(
Expand Down Expand Up @@ -809,7 +839,9 @@ def to_sql_try_avoid_cast(
return str(expr.value).upper()
else:
return self.analyze(
expr, df_aliased_col_name_to_real_col_name, parse_local_name
self.internal_alias_extractor(expr),
df_aliased_col_name_to_real_col_name,
parse_local_name,
)

def resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return {}


class _InternalAlias(Alias):
    """Marker subclass of ``Alias`` for aliases created internally by
    Snowpark functions (via ``Column._alias``) rather than by the user.

    The analyzer's ``internal_alias_extractor`` unwraps these nodes when the
    expression appears as a sub-expression, so no stray ``AS "name"`` clause
    is emitted inside a larger SQL expression.
    """

    pass


class UnresolvedAlias(UnaryExpression, NamedExpression):
sql_operator = "AS"
operator_first = False
Expand Down
13 changes: 12 additions & 1 deletion src/snowflake/snowpark/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
Not,
UnaryMinus,
UnresolvedAlias,
_InternalAlias,
)
from snowflake.snowpark._internal.ast.utils import (
build_expr_from_python_val,
Expand Down Expand Up @@ -667,7 +668,7 @@ def in_(
[Row(IS_A_IN_B=True), Row(IS_A_IN_B=False), Row(IS_A_IN_B=False)]

Args:
vals: The literal values, the columns in the same DataFrame, or a :class:`DataFrame` instance to use
vals: The literal values, the columns in the same DataFrame, or a :class:`DataFrame` instance to use
to check for membership against this column.
"""

Expand Down Expand Up @@ -1313,6 +1314,10 @@ def alias(self, alias: str, _emit_ast: bool = True) -> "Column":
"""Returns a new renamed Column. Alias of :func:`name`."""
return self.name(alias, variant="alias", _emit_ast=_emit_ast)

    def _alias(self, alias: str) -> "Column":
        """Returns a new renamed Column called by functions that internally alias the result.

        Unlike the public :meth:`alias`, this uses the ``"_alias"`` variant of
        :meth:`name`, which wraps the expression in ``_InternalAlias`` (an
        alias the analyzer may strip when the column is used as a
        sub-expression) and emits no AST (``_emit_ast=False``).

        Args:
            alias: The internal name to give the resulting column.
        """
        return self.name(alias, variant="_alias", _emit_ast=False)

@publicapi
def name(
self,
Expand All @@ -1337,6 +1342,12 @@ def name(
elif variant == "name":
ast.fn.column_alias_fn_name = True

if variant == "_alias":
return Column(
_InternalAlias(expr, quote_name(alias)),
_ast=ast_expr,
_emit_ast=_emit_ast,
)
return Column(
Alias(expr, quote_name(alias)), _ast=ast_expr, _emit_ast=_emit_ast
)
Expand Down
8 changes: 4 additions & 4 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3831,7 +3831,7 @@ def array_remove_nulls(col: Column) -> Column:
),
separator=lit(sep, _emit_ast=False),
_emit_ast=False,
).alias(f"CONCAT_WS_IGNORE_NULLS('{sep}', {names})", _emit_ast=False)
)._alias(f"CONCAT_WS_IGNORE_NULLS('{sep}', {names})")


@publicapi
Expand Down Expand Up @@ -5964,7 +5964,7 @@ def window(
window_start,
lit("end", _emit_ast=False),
dateadd(window_unit, window_duration, window_start, _emit_ast=False),
).alias("window", _emit_ast=False)
)._alias("window")


@publicapi
Expand Down Expand Up @@ -6739,7 +6739,7 @@ def from_json(
return (
parse_json(e, _emit_ast=False)
.cast(schema, _emit_ast=False)
.alias(f"from_json({c.get_name()})", _emit_ast=False)
._alias(f"from_json({c.get_name()})")
)


Expand Down Expand Up @@ -7439,7 +7439,7 @@ def size(col: ColumnOrName, _emit_ast: bool = True) -> Column:
_emit_ast=False,
)
.otherwise(lit(None), _emit_ast=False)
.alias(f"SIZE({c.get_name()})", _emit_ast=False)
._alias(f"SIZE({c.get_name()})")
)
result._ast = ast
return result
Expand Down
40 changes: 39 additions & 1 deletion tests/integ/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,20 @@

from snowflake.snowpark import Row
from snowflake.snowpark.exceptions import SnowparkColumnException, SnowparkSQLException
from snowflake.snowpark.functions import col, lit, parse_json, when, hour, minute
from snowflake.snowpark.functions import (
col,
lit,
parse_json,
second,
to_timestamp,
when,
hour,
minute,
window,
)
from snowflake.snowpark.types import (
IntegerType,
StringType,
StructField,
StructType,
TimestampTimeZone,
Expand Down Expand Up @@ -144,6 +155,33 @@ def test_contains(session):
)


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="window function is not supported in Local Testing",
)
def test_internal_alias(session):
df = session.create_dataframe(
[[datetime.datetime(1970, 1, 1, 0, 0, 0)]], schema=["ts"]
)
Utils.check_answer(
df.select(window(df.ts, "10 seconds").alias("my_alias")),
[
Row(
MY_ALIAS='{\n "end": "1970-01-01 00:00:10.000",\n "start": "1970-01-01 00:00:00.000"\n}'
)
],
)

Utils.check_answer(
df.select(window(df.ts, "10 seconds").cast(StringType())),
[Row('{"end":"1970-01-01 00:00:10.000","start":"1970-01-01 00:00:00.000"}')],
)
Utils.check_answer(
df.select(second(to_timestamp(window(df.ts, "10 seconds")["end"]))),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this second example what does the final column name end up being?

[Row(10)],
)


def test_when_accept_literal_value(session):
assert TestData.null_data1(session).select(
when(col("a").is_null(), 5).when(col("a") == 1, 6).otherwise(7).as_("a")
Expand Down
Loading