Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1900040: Estimate an upper bound for row counts between operations in OrderedDataFrame #3144

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 135 additions & 16 deletions src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
sum as sum_,
)
from snowflake.snowpark.modin.plugin._typing import AlignTypeLit, JoinTypeLit
from snowflake.snowpark.modin.plugin._internal.row_count_estimation import (
DataFrameOperation,
RowCountEstimator,
)
from snowflake.snowpark.row import Row
from snowflake.snowpark.session import Session
from snowflake.snowpark.table_function import TableFunctionCall
Expand Down Expand Up @@ -296,6 +300,8 @@ def __init__(
self.row_count_snowflake_quoted_identifier = (
row_count_snowflake_quoted_identifier
)
self.row_count: Optional[int] = None
self.row_count_upper_bound: Optional[int] = None

@property
def ordering_columns(self) -> list[OrderingColumn]:
Expand Down Expand Up @@ -425,6 +431,17 @@ def ensure_row_count_column(self) -> "OrderedDataFrame":
*self.projected_column_snowflake_quoted_identifiers,
count("*").over().as_(row_count_snowflake_quoted_identifier),
)

# Get the row count from the underlying Snowpark dataframe.
materialized_row_count = (
ordered_dataframe._dataframe_ref.snowpark_dataframe.select(
row_count_snowflake_quoted_identifier
).first()[row_count_snowflake_quoted_identifier.strip('"')]
)
# Set the row count and upper bound to the materialized row count.
ordered_dataframe.row_count = materialized_row_count
ordered_dataframe.row_count_upper_bound = materialized_row_count

# inplace update so dataframe_ref can be shared. Note that we keep
# the original ordering columns.
ordered_dataframe.row_count_snowflake_quoted_identifier = (
Expand Down Expand Up @@ -626,7 +643,12 @@ def select(
snowpark_dataframe = self._dataframe_ref.snowpark_dataframe.select(
*cols
)
return OrderedDataFrame(DataFrameReference(snowpark_dataframe))
new_df = OrderedDataFrame(DataFrameReference(snowpark_dataframe))
# Update the row count upper bound
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self, DataFrameOperation.SELECT, args={}
)
return new_df
elif isinstance(e, (Column, str)):
column_names = self._extract_quoted_identifiers_from_column_or_name(
e, active_columns
Expand Down Expand Up @@ -660,7 +682,7 @@ def select(
f"The Snowpark DataFrame in DataFrameReference with id={dataframe_ref._id} is updated"
)

return OrderedDataFrame(
new_df = OrderedDataFrame(
dataframe_ref,
projected_column_snowflake_quoted_identifiers=new_projected_columns,
# keep the original ordering columns and row position column
Expand All @@ -669,6 +691,12 @@ def select(
row_count_snowflake_quoted_identifier=self.row_count_snowflake_quoted_identifier,
)

# Update the row count upper bound
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self, DataFrameOperation.SELECT, args={}
)
return new_df

def dropna(
self,
how: str = "any",
Expand All @@ -693,11 +721,16 @@ def dropna(
result_column_quoted_identifiers = (
projected_dataframe_ref.snowflake_quoted_identifiers
)
return OrderedDataFrame(
new_df = OrderedDataFrame(
DataFrameReference(snowpark_dataframe, result_column_quoted_identifiers),
projected_column_snowflake_quoted_identifiers=result_column_quoted_identifiers,
ordering_columns=self.ordering_columns,
)
# Update the row count upper bound
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self, DataFrameOperation.DROPNA, args={}
)
return new_df

def union_all(self, other: "OrderedDataFrame") -> "OrderedDataFrame":
"""
Expand All @@ -720,10 +753,15 @@ def union_all(self, other: "OrderedDataFrame") -> "OrderedDataFrame":
snowpark_dataframe = self_snowpark_dataframe_ref.snowpark_dataframe.union_all(
other_snowpark_dataframe_ref.snowpark_dataframe
)
return OrderedDataFrame(
new_df = OrderedDataFrame(
DataFrameReference(snowpark_dataframe, result_column_quoted_identifiers),
projected_column_snowflake_quoted_identifiers=result_column_quoted_identifiers,
)
# Update the row count upper bound
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self, DataFrameOperation.UNION_ALL, args={"other": other}
)
return new_df

def _extract_aggregation_result_column_quoted_identifiers(
self,
Expand Down Expand Up @@ -775,14 +813,19 @@ def group_by(
self._extract_aggregation_result_column_quoted_identifiers(*agg_exprs)
)

return OrderedDataFrame(
new_df = OrderedDataFrame(
DataFrameReference(
self._dataframe_ref.snowpark_dataframe.group_by(cols).agg(*agg_exprs),
snowflake_quoted_identifiers=result_column_quoted_identifiers,
),
projected_column_snowflake_quoted_identifiers=result_column_quoted_identifiers,
ordering_columns=[OrderingColumn(identifier) for identifier in cols],
)
# Update the row count upper bound
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self, DataFrameOperation.GROUP_BY, args={}
)
return new_df

def sort(
self,
Expand All @@ -806,7 +849,7 @@ def sort(
if ordering_columns == self.ordering_columns:
return self

return OrderedDataFrame(
new_df = OrderedDataFrame(
self._to_projected_snowpark_dataframe_reference(
include_row_count_column=True, include_ordering_columns=True
),
Expand All @@ -817,6 +860,11 @@ def sort(
# No need to reset row count, since sorting should not add/drop rows.
row_count_snowflake_quoted_identifier=self.row_count_snowflake_quoted_identifier,
)
# Update the row count upper bound
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self, DataFrameOperation.SORT, args={}
)
return new_df

def pivot(
self,
Expand All @@ -834,7 +882,7 @@ def pivot(
See detailed docstring in Snowpark DataFrame's pivot.
"""
snowpark_dataframe = self.to_projected_snowpark_dataframe()
return OrderedDataFrame(
new_df = OrderedDataFrame(
# the pivot result columns for dynamic pivot are data dependent, a schema call is required
# to know all the quoted identifiers for the pivot result.
DataFrameReference(
Expand All @@ -845,6 +893,11 @@ def pivot(
).agg(*agg_exprs)
)
)
# Update the row count upper bound
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self, DataFrameOperation.PIVOT, args={}
)
return new_df

def unpivot(
self,
Expand Down Expand Up @@ -903,7 +956,7 @@ def unpivot(
]
# add the name column and value column to the result
result_column_quoted_identifiers += [name_column, value_column]
return OrderedDataFrame(
new_df = OrderedDataFrame(
DataFrameReference(
projected_dataframe_ref.snowpark_dataframe.unpivot(
value_column=value_column,
Expand All @@ -915,6 +968,11 @@ def unpivot(
),
projected_column_snowflake_quoted_identifiers=result_column_quoted_identifiers,
)
# Update the row count upper bound
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self, DataFrameOperation.UNPIVOT, args={"column_list": column_list}
)
return new_df

def agg(
self,
Expand All @@ -930,10 +988,15 @@ def agg(
result_column_quoted_identifiers = (
self._extract_aggregation_result_column_quoted_identifiers(*exprs)
)
return OrderedDataFrame(
new_df = OrderedDataFrame(
DataFrameReference(snowpark_dataframe, result_column_quoted_identifiers),
projected_column_snowflake_quoted_identifiers=result_column_quoted_identifiers,
)
# Update the row count upper bound
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self, DataFrameOperation.AGG, args={}
)
return new_df

def _deduplicate_active_column_snowflake_quoted_identifiers(
self,
Expand Down Expand Up @@ -1189,13 +1252,26 @@ def join(
# if no join is needed, we simply return the deduplicated right frame with the projected columns
# set to the left.projected_column_snowflake_quoted_identifiers and the deduplicated right
# projected_column_snowflake_quoted_identifiers.
return OrderedDataFrame(
new_df = OrderedDataFrame(
right._dataframe_ref,
projected_column_snowflake_quoted_identifiers=projected_column_snowflake_quoted_identifiers,
ordering_columns=self.ordering_columns,
row_position_snowflake_quoted_identifier=self.row_position_snowflake_quoted_identifier,
row_count_snowflake_quoted_identifier=self.row_count_snowflake_quoted_identifier,
)
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self,
DataFrameOperation.JOIN,
args={
"left_on_cols": left_on_cols,
"right_on_cols": right_on_cols,
"how": how,
"left_match_col": left_match_col,
"right_match_col": right_match_col,
"match_comparator": match_comparator,
},
)
return new_df

# reproject the snowpark dataframe with only necessary columns
left_snowpark_dataframe_ref = self._to_projected_snowpark_dataframe_reference(
Expand Down Expand Up @@ -1262,7 +1338,7 @@ def join(
else:
ordering_columns = self.ordering_columns + right.ordering_columns

return OrderedDataFrame(
new_df = OrderedDataFrame(
DataFrameReference(
snowpark_dataframe,
# the result of join retains column quoted identifier of both left + right
Expand All @@ -1272,6 +1348,19 @@ def join(
projected_column_snowflake_quoted_identifiers=projected_column_snowflake_quoted_identifiers,
ordering_columns=ordering_columns,
)
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self,
DataFrameOperation.JOIN,
args={
"left_on_cols": left_on_cols,
"right_on_cols": right_on_cols,
"how": how,
"left_match_col": left_match_col,
"right_match_col": right_match_col,
"match_comparator": match_comparator,
},
)
return new_df

def _has_same_base_ordered_dataframe(self, other: "OrderedDataFrame") -> bool:
"""
Expand Down Expand Up @@ -1421,14 +1510,20 @@ def align(
include_row_count_column=False,
)
)
return OrderedDataFrame(
new_df = OrderedDataFrame(
dataframe_ref=aligned_ordered_frame._dataframe_ref,
projected_column_snowflake_quoted_identifiers=self.projected_column_snowflake_quoted_identifiers
+ aligned_ordered_frame.projected_column_snowflake_quoted_identifiers,
ordering_columns=self.ordering_columns,
row_position_snowflake_quoted_identifier=self.row_position_snowflake_quoted_identifier,
row_count_snowflake_quoted_identifier=self.row_count_snowflake_quoted_identifier,
)
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self,
DataFrameOperation.ALIGN,
args={"right": right},
)
return new_df

from snowflake.snowpark.modin.plugin._internal.join_utils import (
JoinOrAlignOrderedDataframeResultHelper,
Expand Down Expand Up @@ -1731,7 +1826,13 @@ def align(

# call select to make sure only the result_projected_column_snowflake_quoted_identifiers are projected
# in the join result
return joined_ordered_frame.select(select_list)
new_df = joined_ordered_frame.select(select_list)
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self,
DataFrameOperation.ALIGN,
args={"right": right},
)
return new_df

def filter(self, expr: ColumnOrSqlExpr) -> "OrderedDataFrame":
"""
Expand All @@ -1744,7 +1845,7 @@ def filter(self, expr: ColumnOrSqlExpr) -> "OrderedDataFrame":
include_ordering_columns=True
)
snowpark_dataframe = projected_dataframe_ref.snowpark_dataframe.filter(expr)
return OrderedDataFrame(
new_df = OrderedDataFrame(
DataFrameReference(
snowpark_dataframe,
# same columns are retained after filtering
Expand All @@ -1753,6 +1854,12 @@ def filter(self, expr: ColumnOrSqlExpr) -> "OrderedDataFrame":
projected_column_snowflake_quoted_identifiers=projected_dataframe_ref.snowflake_quoted_identifiers,
ordering_columns=self.ordering_columns,
)
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self,
DataFrameOperation.FILTER,
args={},
)
return new_df

def limit(self, n: int, offset: int = 0, sort: bool = True) -> "OrderedDataFrame":
"""
Expand All @@ -1770,7 +1877,7 @@ def limit(self, n: int, offset: int = 0, sort: bool = True) -> "OrderedDataFrame
snowpark_dataframe = projected_dataframe_ref.snowpark_dataframe.limit(
n=n, offset=offset
)
return OrderedDataFrame(
new_df = OrderedDataFrame(
DataFrameReference(
snowpark_dataframe,
# the same columns are retained for limit
Expand All @@ -1779,6 +1886,12 @@ def limit(self, n: int, offset: int = 0, sort: bool = True) -> "OrderedDataFrame
projected_column_snowflake_quoted_identifiers=projected_dataframe_ref.snowflake_quoted_identifiers,
ordering_columns=self.ordering_columns,
)
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self,
DataFrameOperation.LIMIT,
args={"n": n},
)
return new_df

@property
def write(self) -> DataFrameWriter:
Expand Down Expand Up @@ -2009,7 +2122,7 @@ def sample(self, n: Optional[int], frac: Optional[float]) -> "OrderedDataFrame":
# df_s = df.sample(frac=0.5)
# assert df_s.index == df_s.index may fail because both the LHS and RHS will call the sample method during
# evaluation and the results won't be deterministic.
return cache_result(
new_df = cache_result(
OrderedDataFrame(
DataFrameReference(
snowpark_dataframe,
Expand All @@ -2020,3 +2133,9 @@ def sample(self, n: Optional[int], frac: Optional[float]) -> "OrderedDataFrame":
ordering_columns=self.ordering_columns,
)
)
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
self,
DataFrameOperation.SAMPLE,
args={"n": n, "frac": frac},
)
return new_df
Loading
Loading