Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yzou poc eager async describe #3015

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
SnowflakeValues,
)
from snowflake.snowpark._internal.analyzer.unary_expression import UnresolvedAlias
from snowflake.snowpark._internal.analyzer.analyzer_utils import schema_value_statement

if TYPE_CHECKING:
from snowflake.snowpark._internal.analyzer.analyzer import Analyzer
Expand Down Expand Up @@ -183,6 +184,8 @@ def infer_metadata(
# as it can be retrieved later from attributes
if attributes is not None:
quoted_identifiers = None
if isinstance(source_plan, SelectStatement):
source_plan._attributes = attributes

return PlanMetadata(attributes=attributes, quoted_identifiers=quoted_identifiers)

Expand All @@ -200,6 +203,7 @@ def cache_metadata_if_select_statement(
and source_plan.analyzer.session.reduce_describe_query_enabled
):
source_plan._attributes = metadata.attributes
source_plan._schema_query = schema_value_statement(metadata.attributes)
# When source_plan doesn't have a projection, it's a simple `SELECT * from ...`,
# which means source_plan has the same metadata as its child plan,
# we should cache it on the child plan too.
Expand Down
20 changes: 19 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
import time
import traceback
from typing import TYPE_CHECKING, List, Union
from typing import TYPE_CHECKING, List, Union, Any

import snowflake.snowpark
from snowflake.connector.cursor import ResultMetadata, SnowflakeCursor
Expand Down Expand Up @@ -104,6 +104,24 @@ def analyze_attributes(
return attributes


def get_attributes_from_sync_job(async_job: "AsyncJob", session: "snowflake.snowpark.session.Session") -> List[Attribute]:
    """Fetch the result schema of a completed describe job as a list of Attributes.

    NOTE(review): despite the name, this accepts an *async* job. It re-attaches
    the job's cursor to the stored query id and converts the cursor's result
    metadata into Snowpark ``Attribute`` objects.
    """
    cursor = async_job._cursor
    # Re-bind the cursor to the async query's results so that
    # cursor.description is populated with the result metadata.
    cursor.get_results_from_sfqid(async_job.query_id)
    max_string_size = session._conn.max_string_size
    return convert_result_meta_to_attribute(cursor.description, max_string_size)


def describe_attributes_async(sql: str, session: "snowflake.snowpark.session.Session") -> "AsyncJob":
    """Submit a non-blocking describe of ``sql`` and return it as an AsyncJob.

    Args:
        sql: The SQL text whose result schema should be described.
        session: The Snowpark session whose connection submits the describe.

    Returns:
        An ``AsyncJob`` wrapping the submitted describe query; its query id can
        later be used to retrieve the result schema without having blocked on
        the describe round-trip.
    """
    # _describe_internal_async returns the raw response dict for the
    # submitted query; "queryId" identifies it for later result retrieval.
    results_cursor = session.connection.cursor()._describe_internal_async(sql)
    # Imported locally to avoid a circular import at module load time.
    from snowflake.snowpark.async_job import AsyncJob

    return AsyncJob(
        results_cursor["queryId"],
        sql,
        session,
    )


def convert_result_meta_to_attribute(
meta: Union[List[ResultMetadata], List["ResultMetadataV2"]], # pyright: ignore
max_string_size: int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,9 @@ def schema_query(self) -> str:
if not self.projection:
self._schema_query = self.from_.schema_query
return self._schema_query
if self._attributes is not None:
self._schema_query = schema_value_statement(self._attributes)
return self._schema_query
self._schema_query = f"{analyzer_utils.SELECT}{self.projection_in_str}{analyzer_utils.FROM}({self.from_.schema_query})"
return self._schema_query

Expand Down
33 changes: 31 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import re
import sys
import uuid
import logging
import time
from collections import defaultdict
from enum import Enum
from functools import cached_property
Expand Down Expand Up @@ -88,7 +90,7 @@
cache_metadata_if_select_statement,
infer_metadata,
)
from snowflake.snowpark._internal.analyzer.schema_utils import analyze_attributes
from snowflake.snowpark._internal.analyzer.schema_utils import analyze_attributes, describe_attributes_async, get_attributes_from_sync_job
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
DynamicTableCreateMode,
LogicalPlan,
Expand All @@ -111,6 +113,7 @@
)
from snowflake.snowpark.row import Row
from snowflake.snowpark.types import StructType
from snowflake.connector.cursor import SnowflakeCursor

# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
# Python 3.9 can use both
Expand All @@ -121,6 +124,9 @@
from collections.abc import Iterable


logger = logging.getLogger(__name__)


class SnowflakePlan(LogicalPlan):
class Decorator:
__wrap_exception_regex_match = re.compile(
Expand Down Expand Up @@ -266,6 +272,16 @@ def __init__(
)
self._plan_state: Optional[Dict[PlanState, Any]] = None

if (self._metadata.attributes is None) and (self.session._analyzer.session.reduce_describe_query_enabled):
start = time.time()
logger.info(f"getting plan attributes with query {self.schema_query}")
attributes = analyze_attributes(self.schema_query, self.session)
self._metadata = PlanMetadata(attributes=attributes, quoted_identifiers=None)
cache_metadata_if_select_statement(self.source_plan, self._metadata)
end = time.time()
logger.info(f"time spent on getting attributes with query {self.schema_query} is {end - start}")


@property
def uuid(self) -> str:
return self._uuid
Expand Down Expand Up @@ -347,11 +363,18 @@ def quoted_identifiers(self) -> List[str]:

@property
def attributes(self) -> List[Attribute]:
import time
start = time.time()
if self._metadata.attributes is not None:
return self._metadata.attributes
assert (
self.schema_query is not None
), "No schema query is available for the SnowflakePlan"
# if self._async_attribute_cursor is not None:
# logger.info("getting attributes from async cursor")
# attributes = get_attributes_from_sync_job(self._async_attribute_cursor, self.session)
# else:
logger.info("get attributes from regular schema")
attributes = analyze_attributes(self.schema_query, self.session)
self._metadata = PlanMetadata(attributes=attributes, quoted_identifiers=None)
# We need to cache attributes on SelectStatement too because df._plan is not
Expand All @@ -360,6 +383,8 @@ def attributes(self) -> List[Attribute]:
# No simplifier case relies on this schema_query change to update SHOW TABLES to a nested sql friendly query.
if not self.schema_query or not self.session.sql_simplifier_enabled:
self.schema_query = schema_value_statement(attributes)
end = time.time()
logger.info(f"time spent on getting attribute {end - start}")
return attributes

@cached_property
Expand Down Expand Up @@ -548,7 +573,9 @@ def build(
assert (
child.schema_query is not None
), "No schema query is available in child SnowflakePlan"
new_schema_query = schema_query or sql_generator(child.schema_query)
child_schema_query = schema_value_statement(child.attributes)
new_schema_query = schema_query or sql_generator(child_schema_query)
# new_schema_query = schema_query or sql_generator(child.schema_query)

return SnowflakePlan(
queries,
Expand Down Expand Up @@ -911,6 +938,8 @@ def save_as_table(
column_definition_with_hidden_columns,
)

current_time = time.time()
print(f"TIMESTAMP FOR START THE CREATION: {current_time}\n")
def get_create_table_as_select_plan(child: SnowflakePlan, replace, error):
return self.build(
lambda x: create_table_as_select_statement(
Expand Down
11 changes: 6 additions & 5 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,11 +606,12 @@ def __init__(
_PYTHON_SNOWPARK_AUTO_CLEAN_UP_TEMP_TABLE_ENABLED_VERSION
)
)
self._reduce_describe_query_enabled: bool = (
self._conn._get_client_side_session_parameter(
_PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED, False
)
)
# self._reduce_describe_query_enabled: bool = (
# self._conn._get_client_side_session_parameter(
# _PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED, False
# )
# )
self._reduce_describe_query_enabled = False
self._query_compilation_stage_enabled: bool = (
self._conn._get_client_side_session_parameter(
_PYTHON_SNOWPARK_ENABLE_QUERY_COMPILATION_STAGE, False
Expand Down
16 changes: 16 additions & 0 deletions tests/integ/test_deepcopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,19 @@ def traverse_plan(plan, plan_id_map):
traverse_plan(child, plan_id_map)

traverse_plan(copied_plan, {})


def test_deep_nested_select_create_table(session):
    """save_as_table on a deeply nested with_column dataframe should succeed.

    Builds a dataframe with 4 levels of nested with_column dependencies and
    appends it into a temporary table; both tables are dropped afterwards.
    """
    temp_table_name = Utils.random_table_name()
    new_table_name = Utils.random_table_name()
    try:
        df = create_df_with_deep_nested_with_column_dependencies(
            session, temp_table_name, 4
        )
        df.write.save_as_table(new_table_name, table_type="temporary", mode="append")
    finally:
        # Clean up both tables even if creation of the second one failed.
        Utils.drop_table(session, temp_table_name)
        Utils.drop_table(session, new_table_name)
Empty file.