Feat: implement support for bigframes (#3620)
z3z1ma authored Jan 15, 2025
1 parent 8ce63a7 commit 579298d
Showing 8 changed files with 94 additions and 4 deletions.
53 changes: 52 additions & 1 deletion docs/concepts/models/python_models.md
@@ -33,7 +33,7 @@ The `execute` function is wrapped with the `@model` [decorator](https://wiki.pyt

Because SQLMesh creates tables before evaluating models, the schema of the output DataFrame is a required argument. The `@model` argument `columns` contains a dictionary of column names to types.

- The function takes an `ExecutionContext` that is able to run queries and to retrieve the current time interval that is being processed, along with arbitrary key-value arguments passed in at runtime. The function can either return a Pandas, PySpark, or Snowpark Dataframe instance.
+ The function takes an `ExecutionContext` that is able to run queries and to retrieve the current time interval that is being processed, along with arbitrary key-value arguments passed in at runtime. The function can return a Pandas, PySpark, Bigframe, or Snowpark DataFrame instance.

If the function output is too large, it can also be returned in chunks using Python generators.
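
For illustration, a minimal sketch of a chunked model follows; the model name, column, and chunk sizes are placeholders and are not part of this change (the Batching section below covers splitting large outputs in more detail):

```python
import typing as t
from datetime import datetime

import pandas as pd

from sqlmesh import ExecutionContext, model


@model(
    "docs_example.chunked",  # placeholder model name
    columns={"id": "int"},
)
def execute(
    context: ExecutionContext,
    start: datetime,
    end: datetime,
    execution_time: datetime,
    **kwargs: t.Any,
) -> t.Iterator[pd.DataFrame]:
    # Yield the output in chunks instead of materializing one large DataFrame.
    for chunk_start in range(0, 3_000_000, 1_000_000):
        yield pd.DataFrame({"id": range(chunk_start, chunk_start + 1_000_000)})
```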

@@ -441,6 +441,57 @@ def execute(
return df
```

### Bigframe
This example demonstrates using the [Bigframe](https://cloud.google.com/bigquery/docs/use-bigquery-dataframes#pandas-examples) DataFrame API. If you use BigQuery, the Bigframe API is preferred over Pandas because all computation is performed in BigQuery rather than locally.

```python linenums="1"
import typing as t
from datetime import datetime

from bigframes.pandas import DataFrame

from sqlmesh import ExecutionContext, model


def get_bucket(num: int) -> str:
    if not num:
        return "NA"
    boundary = 10
    return "at_or_above_10" if num >= boundary else "below_10"


@model(
    "mart.wiki",
    columns={
        "title": "text",
        "views": "int",
        "bucket": "text",
    },
)
def execute(
    context: ExecutionContext,
    start: datetime,
    end: datetime,
    execution_time: datetime,
    **kwargs: t.Any,
) -> DataFrame:
    # Create a remote function to be used on the Bigframe DataFrame
    remote_get_bucket = context.bigframe.remote_function([int], str)(get_bucket)

    # Returns the Bigframe DataFrame handle; no data is computed locally
    df = context.bigframe.read_gbq("bigquery-samples.wikipedia_pageviews.200809h")

    df = (
        # This runs entirely on the BigQuery engine, lazily
        df[df.title.str.contains(r"[Gg]oogle")]
        .groupby(["title"], as_index=False)["views"]
        .sum(numeric_only=True)
        .sort_values("views", ascending=False)
    )

    return df.assign(bucket=df["views"].apply(remote_get_bucket))
```

### Batching
If the output of a Python model is very large and you cannot use Spark, it may be helpful to split the output into multiple batches.

3 changes: 3 additions & 0 deletions setup.cfg
@@ -101,3 +101,6 @@ ignore_missing_imports = True

[mypy-dlt.*]
ignore_missing_imports = True

[mypy-bigframes.*]
ignore_missing_imports = True
1 change: 1 addition & 0 deletions setup.py
@@ -58,6 +58,7 @@
"google-cloud-bigquery[pandas]",
"google-cloud-bigquery-storage",
],
"bigframes": ["bigframes>=1.32.0"],
"clickhouse": ["clickhouse-connect"],
"databricks": ["databricks-sql-connector"],
"dev": [
6 changes: 6 additions & 0 deletions sqlmesh/core/context.py
@@ -126,6 +126,7 @@
from typing_extensions import Literal

from sqlmesh.core.engine_adapter._typing import (
    BigframeSession,
    DF,
    PySparkDataFrame,
    PySparkSession,
@@ -167,6 +168,11 @@ def snowpark(self) -> t.Optional[SnowparkSession]:
"""Returns the snowpark session if it exists."""
return self.engine_adapter.snowpark

@property
def bigframe(self) -> t.Optional[BigframeSession]:
"""Returns the bigframe session if it exists."""
return self.engine_adapter.bigframe

@property
def default_catalog(self) -> t.Optional[str]:
raise NotImplementedError
3 changes: 3 additions & 0 deletions sqlmesh/core/engine_adapter/_typing.py
@@ -8,6 +8,8 @@
if t.TYPE_CHECKING:
    import pyspark
    import pyspark.sql.connect.dataframe
    from bigframes.session import Session as BigframeSession  # noqa
    from bigframes.dataframe import DataFrame as BigframeDataFrame

snowpark = optional_import("snowflake.snowpark")

@@ -23,6 +25,7 @@
    pd.DataFrame,
    pyspark.sql.DataFrame,
    pyspark.sql.connect.dataframe.DataFrame,
    BigframeDataFrame,
    SnowparkDataFrame,
]

5 changes: 5 additions & 0 deletions sqlmesh/core/engine_adapter/base.py
@@ -49,6 +49,7 @@
if t.TYPE_CHECKING:
    from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
    from sqlmesh.core.engine_adapter._typing import (
        BigframeSession,
        DF,
        PySparkDataFrame,
        PySparkSession,
@@ -160,6 +161,10 @@ def spark(self) -> t.Optional[PySparkSession]:
    def snowpark(self) -> t.Optional[SnowparkSession]:
        return None

    @property
    def bigframe(self) -> t.Optional[BigframeSession]:
        return None

    @property
    def comments_enabled(self) -> bool:
        return self._register_comments and self.COMMENT_CREATION_TABLE.is_supported
25 changes: 23 additions & 2 deletions sqlmesh/core/engine_adapter/bigquery.py
@@ -23,6 +23,7 @@
)
from sqlmesh.core.node import IntervalUnit
from sqlmesh.core.schema_diff import SchemaDiffer
from sqlmesh.utils import optional_import
from sqlmesh.utils.date import to_datetime
from sqlmesh.utils.errors import SQLMeshError

@@ -35,11 +36,15 @@
from google.cloud.bigquery.table import Table as BigQueryTable

from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
- from sqlmesh.core.engine_adapter._typing import DF, Query
+ from sqlmesh.core.engine_adapter._typing import BigframeSession, DF, Query
from sqlmesh.core.engine_adapter.base import QueryOrDF


logger = logging.getLogger(__name__)

bigframes = optional_import("bigframes")
bigframes_pd = optional_import("bigframes.pandas")


NestedField = t.Tuple[str, str, t.List[str]]
NestedFieldsDict = t.Dict[str, t.List[NestedField]]
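
The `optional_import` calls above make the `bigframes` dependency optional at runtime. A minimal sketch of what such a helper typically does follows; this is an assumption for illustration, not the actual `sqlmesh.utils.optional_import` implementation:

```python
import importlib
import types
import typing as t


def optional_import(name: str) -> t.Optional[types.ModuleType]:
    """Import the named module if it is installed; otherwise return None so callers can guard on it."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


# The adapter's `bigframe` property below only builds a session when this is not None.
bigframes = optional_import("bigframes")
```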
@@ -105,6 +110,17 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row
    def client(self) -> BigQueryClient:
        return self.connection._client

    @property
    def bigframe(self) -> t.Optional[BigframeSession]:
        if bigframes:
            options = bigframes.BigQueryOptions(
                credentials=self.client._credentials,
                project=self.client.project,
                location=self.client.location,
            )
            return bigframes.connect(context=options)
        return None

    @property
    def _job_params(self) -> t.Dict[str, t.Any]:
        from sqlmesh.core.config.connection import BigQueryPriority
@@ -140,7 +156,12 @@ def _df_to_source_queries(
        )

        def query_factory() -> Query:
-            if not self.table_exists(temp_table):
+            if bigframes_pd and isinstance(df, bigframes_pd.DataFrame):
+                df.to_gbq(
+                    f"{temp_bq_table.project}.{temp_bq_table.dataset_id}.{temp_bq_table.table_id}",
+                    if_exists="replace",
+                )
+            elif not self.table_exists(temp_table):
                # Make mypy happy
                assert isinstance(df, pd.DataFrame)
                self._db_call(self.client.create_table, table=temp_bq_table, exists_ok=False)
2 changes: 1 addition & 1 deletion sqlmesh/core/snapshot/evaluator.py
@@ -673,7 +673,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
if isinstance(query_or_df, pd.DataFrame):
    return query_or_df.head(limit)
if not isinstance(query_or_df, exp.Expression):
-    # We assume that if this branch is reached, `query_or_df` is a pyspark / snowpark dataframe,
+    # We assume that if this branch is reached, `query_or_df` is a pyspark / snowpark / bigframe dataframe,
    # so we use `limit` instead of `head` to get back a dataframe instead of List[Row]
    # https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrame.head.html#pyspark.sql.DataFrame.head
    return query_or_df.limit(limit)
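
The `head` versus `limit` distinction that the comment relies on can be seen directly with PySpark; a small illustrative snippet, not part of this change:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.range(100)

rows = df.head(10)      # List[Row]: results are collected to the driver
preview = df.limit(10)  # DataFrame: stays lazy on the engine and can be transformed further

print(type(rows), type(preview))
```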