Skip to content

Commit

Permalink
Feat: Allow CustomKind subclasses for custom materializations (#3863)
Browse files Browse the repository at this point in the history
  • Loading branch information
erindru authored Feb 19, 2025
1 parent d379826 commit 503d13f
Show file tree
Hide file tree
Showing 13 changed files with 392 additions and 28 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ install-dev:
pip3 install -e ".[dev,web,slack,dlt]"

install-cicd-test:
pip3 install -e ".[dev,web,slack,cicdtest,dlt]"
pip3 install -e ".[dev,web,slack,cicdtest,dlt]" ./examples/custom_materializations

install-doc:
pip3 install -r ./docs/requirements.txt
Expand Down Expand Up @@ -153,7 +153,7 @@ guard-%:
fi

engine-%-install:
pip3 install -e ".[dev,web,slack,${*}]"
pip3 install -e ".[dev,web,slack,${*}]" ./examples/custom_materializations

engine-docker-%-up:
docker compose -f ./tests/core/engine_adapter/integration/docker/compose.${*}.yaml up -d
Expand Down
97 changes: 97 additions & 0 deletions docs/guides/custom_materializations.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,103 @@ class CustomFullMaterialization(CustomMaterialization):
# Example existing materialization for look and feel: https://github.com/TobikoData/sqlmesh/blob/main/sqlmesh/core/snapshot/evaluator.py
```

## Extending `CustomKind`

!!! warning
This is even lower level usage that contains a bunch of extra complexity and relies on knowledge of the SQLMesh internals.
If you dont need this level of complexity, stick with the method described above.

In many cases, the above usage of a custom materialization will suffice.

However, you may still want tighter integration with SQLMesh's internals:

- You may want more control over what is considered a metadata change vs a data change
- You may want to validate custom properties are correct before any database connections are made
- You may want to leverage existing functionality of SQLMesh that relies on specific properties being present

In this case, you can provide a subclass of `CustomKind` for SQLMesh to use instead of `CustomKind` itself.
During project load, SQLMesh will instantiate your *subclass* instead of `CustomKind`.

This allows you to run custom validators at load time rather than having to perform extra validation when `insert()` is invoked on your `CustomMaterialization`.

This approach also allows you set "top-level" properties directly in the `kind (...)` block rather than nesting them under `materialization_properties`.

To extend `CustomKind`, first you define a subclass like so:

```python linenums="1" hl_lines="7"
from sqlmesh import CustomKind
from pydantic import field_validator, ValidationInfo
from sqlmesh.utils.pydantic import list_of_fields_validator

class MyCustomKind(CustomKind):

primary_key: t.List[exp.Expression]

@field_validator("primary_key", mode="before")
@classmethod
def _validate_primary_key(cls, value: t.Any, info: ValidationInfo) -> t.Any:
return list_of_fields_validator(value, info.data)

```

In this example, we define a field called `primary_key` that takes a list of fields. Notice that the field validation is just a simple Pydantic `@field_validator` with the [exact same usage](https://github.com/TobikoData/sqlmesh/blob/ade5f7245950822f3cfe5a68a0c243f91ceca600/sqlmesh/core/model/kind.py#L470) as the standard SQLMesh model kinds.

To use it within a model, we can do something like:

```sql linenums="1" hl_lines="5"
MODEL (
name my_db.my_model,
kind CUSTOM (
materialization 'my_custom_full',
primary_key (col1, col2)
)
);
```

Notice that the `primary_key` field we declared is top-level within the `kind` block instead of being nested under `materialization_properties`.

To indicate to SQLMesh that it should use this subclass, specify it as a generic type parameter on your custom materialization class like so:

```python linenums="1" hl_lines="1 16"
class CustomFullMaterialization(CustomMaterialization[MyCustomKind]):
NAME = "my_custom_full"

def insert(
self,
table_name: str,
query_or_df: QueryOrDF,
model: Model,
is_first_insert: bool,
**kwargs: t.Any,
) -> None:
assert isinstance(model.kind, MyCustomKind)

self.adapter.merge(
...,
unique_key=model.kind.primary_key
)
```

When SQLMesh loads your custom materialization, it will inspect the Python type signature for generic parameters that are subclasses of `CustomKind`. If it finds one, it will instantiate your subclass when building `model.kind` instead of using the default `CustomKind` class.

In this example, this means that:

- Validation for `primary_key` happens at load time instead of evaluation time.
- When your custom materialization is called to load data into tables, `model.kind` will resolve to your custom kind object so you can access the extra properties you defined without first needing to validate them / coerce them to a usable type.

### Data vs Metadata changes

Subclasses of `CustomKind` that add extra properties can also decide if they are data properties (changes may trigger the creation of new snapshots) or metadata properties (changes just update metadata about the model).

They can also decide if they are relevant for text diffing when SQLMesh detects changes to a model.

You can opt in to SQLMesh's change tracking by overriding the following methods:

- If changing the property should change the data fingerprint, add it to [data_hash_values()](https://github.com/TobikoData/sqlmesh/blob/ade5f7245950822f3cfe5a68a0c243f91ceca600/sqlmesh/core/model/kind.py#L858)
- If changing the property should change the metadata fingerprint, add it to [metadata_hash_values()](https://github.com/TobikoData/sqlmesh/blob/ade5f7245950822f3cfe5a68a0c243f91ceca600/sqlmesh/core/model/kind.py#L867)
- If the property should show up in context diffs, add it to [to_expression()](https://github.com/TobikoData/sqlmesh/blob/ade5f7245950822f3cfe5a68a0c243f91ceca600/sqlmesh/core/model/kind.py#L880)


## Sharing custom materializations

### Copying files
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

import typing as t

from sqlmesh import CustomMaterialization, CustomKind, Model
from sqlmesh.utils.pydantic import validate_string
from pydantic import field_validator

if t.TYPE_CHECKING:
from sqlmesh import QueryOrDF


class ExtendedCustomKind(CustomKind):
custom_property: t.Optional[str] = None

@field_validator("custom_property", mode="before")
@classmethod
def _validate_custom_property(cls, v: t.Any) -> str:
return validate_string(v)


class CustomFullWithCustomKindMaterialization(CustomMaterialization[ExtendedCustomKind]):
NAME = "custom_full_with_custom_kind"

def insert(
self,
table_name: str,
query_or_df: QueryOrDF,
model: Model,
is_first_insert: bool,
**kwargs: t.Any,
) -> None:
assert type(model.kind).__name__ == "ExtendedCustomKind"

self._replace_query_for_model(model, table_name, query_or_df)
1 change: 1 addition & 0 deletions examples/custom_materializations/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
entry_points={
"sqlmesh.materializations": [
"custom_full_materialization = custom_materializations.full:CustomFullMaterialization",
"custom_full_with_custom_kind = custom_materializations.custom_kind:CustomFullWithCustomKindMaterialization",
],
},
install_requires=[
Expand Down
13 changes: 13 additions & 0 deletions examples/sushi/models/latest_order.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
MODEL (
name sushi.latest_order,
kind CUSTOM (
materialization 'custom_full_with_custom_kind',
custom_property 'sushi!!!'
),
cron '@daily'
);

SELECT id, customer_id, start_ts, end_ts, event_date
FROM sushi.orders
ORDER BY event_date DESC LIMIT 1

1 change: 1 addition & 0 deletions sqlmesh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sqlmesh.core.snapshot.evaluator import (
CustomMaterialization as CustomMaterialization,
)
from sqlmesh.core.model.kind import CustomKind as CustomKind
from sqlmesh.utils import (
debug_mode_enabled as debug_mode_enabled,
enable_debug_mode as enable_debug_mode,
Expand Down
15 changes: 15 additions & 0 deletions sqlmesh/core/model/kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,21 @@ def create_model_kind(v: t.Any, dialect: str, defaults: t.Dict[str, t.Any]) -> M
):
props["on_destructive_change"] = defaults.get("on_destructive_change")

if kind_type == CustomKind:
# load the custom materialization class and check if it uses a custom kind type
from sqlmesh.core.snapshot.evaluator import get_custom_materialization_type

if "materialization" not in props:
raise ConfigError(
"The 'materialization' property is required for models of the CUSTOM kind"
)

actual_kind_type, _ = get_custom_materialization_type(
validate_string(props.get("materialization"))
)

return actual_kind_type(**props)

return kind_type(**props)

name = (v.name if isinstance(v, exp.Expression) else str(v)).upper()
Expand Down
50 changes: 41 additions & 9 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
SCDType2ByColumnKind,
SCDType2ByTimeKind,
ViewKind,
CustomKind,
)
from sqlmesh.core.schema_diff import has_drop_alteration, get_dropped_column_names
from sqlmesh.core.snapshot import (
Expand Down Expand Up @@ -1130,7 +1131,7 @@ def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) ->
raise SQLMeshError(
f"Missing the name of a custom evaluation strategy in model '{snapshot.name}'."
)
klass = get_custom_materialization_type(snapshot.custom_materialization)
_, klass = get_custom_materialization_type(snapshot.custom_materialization)
return klass(adapter)
elif snapshot.is_managed:
klass = EngineManagedStrategy
Expand Down Expand Up @@ -1897,7 +1898,10 @@ def _is_materialized_view(self, model: Model) -> bool:
return isinstance(model.kind, ViewKind) and model.kind.materialized


class CustomMaterialization(MaterializableStrategy):
C = t.TypeVar("C", bound=CustomKind)


class CustomMaterialization(MaterializableStrategy, t.Generic[C]):
"""Base class for custom materializations."""

def insert(
Expand All @@ -1924,14 +1928,36 @@ def insert(
)


_custom_materialization_type_cache: t.Optional[t.Dict[str, t.Type[CustomMaterialization]]] = None
_custom_materialization_type_cache: t.Optional[
t.Dict[str, t.Tuple[t.Type[CustomKind], t.Type[CustomMaterialization]]]
] = None


def get_custom_materialization_kind_type(st: t.Type[CustomMaterialization]) -> t.Type[CustomKind]:
# try to read if there is a custom 'kind' type in use by inspecting the type signature
# eg try to read 'MyCustomKind' from:
# >>>> class MyCustomMaterialization(CustomMaterialization[MyCustomKind])
# and fall back to base CustomKind if there is no generic type declared
if hasattr(st, "__orig_bases__"):
for base in st.__orig_bases__:
if hasattr(base, "__origin__") and base.__origin__ == CustomMaterialization:
for generic_arg in t.get_args(base):
if not issubclass(generic_arg, CustomKind):
raise SQLMeshError(
f"Custom materialization kind '{generic_arg.__name__}' must be a subclass of CustomKind"
)

return generic_arg

return CustomKind

def get_custom_materialization_type(name: str) -> t.Type[CustomMaterialization]:

def get_custom_materialization_type(
name: str,
) -> t.Tuple[t.Type[CustomKind], t.Type[CustomMaterialization]]:
global _custom_materialization_type_cache

strategy_key = name.lower()

if (
_custom_materialization_type_cache is None
or strategy_key not in _custom_materialization_type_cache
Expand All @@ -1948,16 +1974,22 @@ def get_custom_materialization_type(name: str) -> t.Type[CustomMaterialization]:
strategy_types.append(strategy_type)

_custom_materialization_type_cache = {
getattr(strategy_type, "NAME", strategy_type.__name__).lower(): strategy_type
getattr(strategy_type, "NAME", strategy_type.__name__).lower(): (
get_custom_materialization_kind_type(strategy_type),
strategy_type,
)
for strategy_type in strategy_types
}

if strategy_key not in _custom_materialization_type_cache:
raise ConfigError(f"Materialization strategy with name '{name}' was not found.")

strategy_type = _custom_materialization_type_cache[strategy_key]
logger.debug("Resolved custom materialization '%s' to '%s'", name, strategy_type)
return strategy_type
strategy_kind_type, strategy_type = _custom_materialization_type_cache[strategy_key]
logger.debug(
"Resolved custom materialization '%s' to '%s' (%s)", name, strategy_type, strategy_kind_type
)

return strategy_kind_type, strategy_type


class EngineManagedStrategy(MaterializableStrategy):
Expand Down
2 changes: 1 addition & 1 deletion tests/core/analytics/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def test_on_plan_apply(
{
"seq_num": 0,
"event_type": "PLAN_APPLY_START",
"event": f'{{"plan_id": "{plan_id}", "engine_type": "bigquery", "state_sync_type": "mysql", "scheduler_type": "builtin", "is_dev": false, "skip_backfill": false, "no_gaps": false, "forward_only": false, "ensure_finalized_snapshots": false, "has_restatements": false, "directly_modified_count": 18, "indirectly_modified_count": 0, "environment_name_hash": "d6e4a9b6646c62fc48baa6dd6150d1f7"}}',
"event": f'{{"plan_id": "{plan_id}", "engine_type": "bigquery", "state_sync_type": "mysql", "scheduler_type": "builtin", "is_dev": false, "skip_backfill": false, "no_gaps": false, "forward_only": false, "ensure_finalized_snapshots": false, "has_restatements": false, "directly_modified_count": 19, "indirectly_modified_count": 0, "environment_name_hash": "d6e4a9b6646c62fc48baa6dd6150d1f7"}}',
**common_fields,
}
),
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ def test_janitor(sushi_context, mocker: MockerFixture) -> None:
)
# Assert that the views are dropped for each snapshot just once and make sure that the name used is the
# view name with the environment as a suffix
assert adapter_mock.drop_view.call_count == 13
assert adapter_mock.drop_view.call_count == 14
adapter_mock.drop_view.assert_has_calls(
[
call(
Expand Down
Loading

0 comments on commit 503d13f

Please sign in to comment.