Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NO-SNOW: Support SEED argument in Dataframe.sample #3004

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,7 @@ def do_resolve_with_resolved_children(
logical_plan,
logical_plan.probability_fraction,
logical_plan.row_count,
logical_plan.seed,
)

if isinstance(logical_plan, Join):
Expand Down
5 changes: 5 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
INFER_SCHEMA = " INFER_SCHEMA "
SAMPLE = " SAMPLE "
ROWS = " ROWS "
SEED = " SEED "
CASE = " CASE "
WHEN = " WHEN "
THEN = " THEN "
Expand Down Expand Up @@ -465,15 +466,18 @@ def sample_statement(
child: str,
probability_fraction: Optional[float] = None,
row_count: Optional[int] = None,
seed: Optional[int] = None,
):
"""Generates the sql text for the sample part of the plan being executed"""
seed_clause = f"{SEED}({seed})" if seed is not None else EMPTY_STRING
if probability_fraction is not None:
return (
project_statement([], child)
+ SAMPLE
+ LEFT_PARENTHESIS
+ str(probability_fraction * 100)
+ RIGHT_PARENTHESIS
+ seed_clause
)
elif row_count is not None:
return (
Expand All @@ -483,6 +487,7 @@ def sample_statement(
+ str(row_count)
+ ROWS
+ RIGHT_PARENTHESIS
+ seed_clause
)
# this shouldn't happen because upstream code will validate either probability_fraction or row_count will have a value.
else: # pragma: no cover
Expand Down
54 changes: 51 additions & 3 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,26 +473,38 @@ def __init__(
assert isinstance(entity, SnowflakeTable)
super().__init__(analyzer)
self.entity = entity
self._sample_clause: Optional[str] = None

def __deepcopy__(self, memodict={}) -> "SelectableEntity":  # noqa: B006
    """Return a deep copy of this entity, including its sample clause."""
    # Deep-copy the underlying table; the analyzer is shared by reference.
    entity_copy = deepcopy(self.entity, memodict)
    clone = SelectableEntity(entity_copy, analyzer=self.analyzer)
    _deepcopy_selectable_fields(from_selectable=self, to_selectable=clone)
    # _sample_clause is specific to this subclass, so it is copied explicitly
    # in addition to the shared Selectable fields handled above.
    clone._sample_clause = self._sample_clause
    return clone

@property
def sql_query(self) -> str:
def select_clause(self) -> str:
return f"{analyzer_utils.SELECT}{analyzer_utils.STAR}{analyzer_utils.FROM}{self.entity.name}"

@property
def sample_clause(self) -> str:
    """The SQL SAMPLE clause for this entity, or the empty string if unsampled."""
    # _sample_clause stays None until sample() is applied; fall back to "".
    return self._sample_clause or analyzer_utils.EMPTY_STRING

@property
def sql_query(self) -> str:
    """Full SQL text for this entity: the SELECT projection plus any SAMPLE clause."""
    return self.select_clause + self.sample_clause

@property
def sql_in_subquery(self) -> str:
return self.entity.name
return f"{self.entity.name}{self.sample_clause}"

@property
def schema_query(self) -> str:
return self.sql_query
return self.select_clause

@property
def plan_node_category(self) -> PlanNodeCategory:
Expand All @@ -509,6 +521,42 @@ def referenced_ctes(self) -> Dict[WithQueryBlock, int]:
# CTE table will be referred.
return dict()

def sample(
    self,
    probability_fraction: Optional[float],
    row_count: Optional[int],
    seed: Optional[int],
) -> "Selectable":
    """Return a new ``SelectableEntity`` with a SQL SAMPLE clause attached.

    Exactly one of ``probability_fraction`` or ``row_count`` must be given;
    ``seed`` optionally appends a SEED clause. Raises ``ValueError`` when the
    argument combination is invalid or when this entity was already sampled.

    NOTE(review): looks like this assumes Snowflake accepts SEED together
    with fixed-size (``n ROWS``) sampling — confirm against the Snowflake
    SAMPLE documentation.
    """
    # Chained sampling on the same entity is not supported.
    if self._sample_clause is not None:
        # to prevent .sample().sample() being applied
        raise ValueError(
            "The sample method has already been applied to this Selectable."
        )
    has_fraction = probability_fraction is not None
    has_rows = row_count is not None
    if has_fraction and has_rows:
        raise ValueError(
            "Only one of probability_fraction and row_count can be specified."
        )
    if not has_fraction and not has_rows:
        raise ValueError(
            "Either 'probability_fraction' or 'row_count' must not be None."
        )

    # Build a fresh entity node; the sample clause is stored on the copy so
    # the original selectable remains unsampled.
    sampled = SelectableEntity(self.entity, analyzer=self.analyzer)
    if seed is not None:
        seed_clause = f"{analyzer_utils.SEED}({seed})"
    else:
        seed_clause = analyzer_utils.EMPTY_STRING
    if has_rows:
        sample_spec = f"{row_count} ROWS"
    else:
        # Snowflake SAMPLE takes a percentage, hence the *100 scaling.
        sample_spec = f"{probability_fraction*100.0}"
    sampled._sample_clause = (
        f"{analyzer_utils.SAMPLE}({sample_spec}){seed_clause}"
    )
    return sampled


class SelectSQL(Selectable):
"""Query from a SQL. Mainly used by session.sql()"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -779,11 +779,15 @@ def sample(
source_plan: Optional[LogicalPlan],
probability_fraction: Optional[float] = None,
row_count: Optional[int] = None,
seed: Optional[int] = None,
) -> SnowflakePlan:
"""Builds the sample part of the resultant sql statement"""
return self.build(
lambda x: sample_statement(
x, probability_fraction=probability_fraction, row_count=row_count
x,
probability_fraction=probability_fraction,
row_count=row_count,
seed=seed,
),
child,
source_plan,
Expand Down
27 changes: 14 additions & 13 deletions src/snowflake/snowpark/_internal/proto/ast.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1268,17 +1268,18 @@ message DataframeSample {
DataframeExpr df = 1;
google.protobuf.Int64Value num = 2;
google.protobuf.DoubleValue probability_fraction = 3;
SrcPosition src = 4;
google.protobuf.Int64Value seed = 4;
SrcPosition src = 5;
}

// dataframe.ir:311
// dataframe.ir:312
// Presumably encodes DataFrame.select(*cols) with column expressions —
// verify against dataframe.ir:312.
message DataframeSelect_Columns {
ExprArgList cols = 1;  // column expressions passed to select
DataframeExpr df = 2;  // the dataframe being selected from
SrcPosition src = 3;  // client-side source location of the call
}

// dataframe.ir:316
// dataframe.ir:317
message DataframeSelect_Exprs {
DataframeExpr df = 1;
ExprArgList exprs = 2;
Expand All @@ -1292,7 +1293,7 @@ message DataframeShow {
SrcPosition src = 3;
}

// dataframe.ir:321
// dataframe.ir:322
message DataframeSort {
Expr ascending = 1;
ExprArgList cols = 2;
Expand Down Expand Up @@ -1376,28 +1377,28 @@ message DataframeToPandasBatches {
repeated Tuple_String_String statement_params = 4;
}

// dataframe.ir:328
// dataframe.ir:329
// Presumably encodes DataFrame.union(other) — verify against dataframe.ir:329.
message DataframeUnion {
DataframeExpr df = 1;  // left-hand dataframe
DataframeExpr other = 2;  // right-hand dataframe
SrcPosition src = 3;  // client-side source location of the call
}

// dataframe.ir:333
// dataframe.ir:334
// Presumably encodes DataFrame.union_all(other) — verify against dataframe.ir:334.
message DataframeUnionAll {
DataframeExpr df = 1;  // left-hand dataframe
DataframeExpr other = 2;  // right-hand dataframe
SrcPosition src = 3;  // client-side source location of the call
}

// dataframe.ir:338
// dataframe.ir:339
// Presumably encodes DataFrame.union_all_by_name(other) — verify against
// dataframe.ir:339.
message DataframeUnionAllByName {
DataframeExpr df = 1;  // left-hand dataframe
DataframeExpr other = 2;  // right-hand dataframe
SrcPosition src = 3;  // client-side source location of the call
}

// dataframe.ir:343
// dataframe.ir:344
message DataframeUnionByName {
DataframeExpr df = 1;
DataframeExpr other = 2;
Expand All @@ -1414,23 +1415,23 @@ message DataframeUnpivot {
string value_column = 6;
}

// dataframe.ir:348
// dataframe.ir:349
// Presumably encodes DataFrame.with_column(col_name, col) — verify against
// dataframe.ir:349.
message DataframeWithColumn {
Expr col = 1;  // expression producing the new/replaced column
string col_name = 2;  // name of the column being added
DataframeExpr df = 3;  // the dataframe being extended
SrcPosition src = 4;  // client-side source location of the call
}

// dataframe.ir:354
// dataframe.ir:355
// Presumably encodes DataFrame.with_column_renamed(col, new_name) — verify
// against dataframe.ir:355.
message DataframeWithColumnRenamed {
Expr col = 1;  // the existing column (expression or name) to rename
DataframeExpr df = 2;  // the dataframe containing the column
string new_name = 3;  // new name for the column
SrcPosition src = 4;  // client-side source location of the call
}

// dataframe.ir:360
// dataframe.ir:361
message DataframeWithColumns {
repeated string col_names = 1;
DataframeExpr df = 2;
Expand Down Expand Up @@ -1835,7 +1836,7 @@ message Geq {
SrcPosition src = 3;
}

// dataframe.ir:366
// dataframe.ir:367
message GroupingSets {
ExprArgList sets = 1;
SrcPosition src = 2;
Expand Down Expand Up @@ -2526,7 +2527,7 @@ message TableUpdate {
repeated Tuple_String_String statement_params = 7;
}

// dataframe.ir:370
// dataframe.ir:371
message ToSnowparkPandas {
List_String columns = 1;
DataframeExpr df = 2;
Expand Down
20 changes: 16 additions & 4 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
SelectSnowflakePlan,
SelectStatement,
SelectTableFunction,
SelectableEntity,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan import PlanQueryType
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
Expand Down Expand Up @@ -5349,6 +5350,7 @@ def sample(
self,
frac: Optional[float] = None,
n: Optional[int] = None,
seed: Optional[int] = None,
_emit_ast: bool = True,
) -> "DataFrame":
"""Samples rows based on either the number of rows to be returned or a
Expand All @@ -5371,15 +5373,25 @@ def sample(
ast.probability_fraction.value = frac
if n:
ast.num.value = n
if seed:
ast.seed.value = seed
self._set_ast_ref(ast.df)

sample_plan = Sample(self._plan, probability_fraction=frac, row_count=n)
sample_plan = Sample(
self._plan, probability_fraction=frac, row_count=n, seed=seed
)
if self._select_statement:
from_ = self._session._analyzer.create_select_snowflake_plan(
sample_plan, analyzer=self._session._analyzer
)
if isinstance(self._select_statement.from_, SelectableEntity):
from_ = self._select_statement.from_.sample(
probability_fraction=frac, row_count=n, seed=seed
)

return self._with_plan(
self._session._analyzer.create_select_statement(
from_=self._session._analyzer.create_select_snowflake_plan(
sample_plan, analyzer=self._session._analyzer
),
from_=from_,
analyzer=self._session._analyzer,
),
_ast_stmt=stmt,
Expand Down
43 changes: 41 additions & 2 deletions tests/ast/data/df_sample.test
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@ df = df.sample(n=3)

df = df.sample(frac=0.5)

df = df.sample(frac=0.7, seed=42)

## EXPECTED UNPARSER OUTPUT

df = session.table("table1")

df = df.sample(None, 3)
df = df.sample(None, 3, None)

df = df.sample(0.5, None, None)

df = df.sample(0.5, None)
df = df.sample(0.7, None, 42)

## EXPECTED ENCODED AST

Expand Down Expand Up @@ -121,6 +125,41 @@ body {
}
}
}
body {
assign {
expr {
dataframe_sample {
df {
dataframe_ref {
id {
bitfield1: 3
}
}
}
probability_fraction {
value: 0.7
}
seed {
value: 42
}
src {
end_column: 41
end_line: 31
file: 2
start_column: 13
start_line: 31
}
}
}
symbol {
value: "df"
}
uid: 4
var_id {
bitfield1: 4
}
}
}
client_ast_version: 1
client_language {
python_language {
Expand Down
Loading