Skip to content

Commit

Permalink
Fix: Streamline execution of pre- / post- statements when creating a physical table (#3837)
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman authored Feb 14, 2025
1 parent beaebe3 commit ec5e085
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 51 deletions.
1 change: 1 addition & 0 deletions sqlmesh/core/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def _render(
self._model_fqn,
snapshots={self._model_fqn: this_snapshot} if this_snapshot else None,
deployability_index=deployability_index,
table_mapping=table_mapping,
)

expressions = [self._expression]
Expand Down
101 changes: 59 additions & 42 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,25 +757,14 @@ def _create_snapshot(
deployability_index = deployability_index or DeployabilityIndex.all_deployable()

adapter = self._get_adapter(snapshot.model.gateway)
common_render_kwargs: t.Dict[str, t.Any] = dict(
create_render_kwargs: t.Dict[str, t.Any] = dict(
engine_adapter=adapter,
snapshots=parent_snapshots_by_name(snapshot, snapshots),
runtime_stage=RuntimeStage.CREATING,
deployability_index=deployability_index,
)
pre_post_render_kwargs = dict(
**common_render_kwargs,
deployability_index=deployability_index.with_deployable(snapshot),
)
create_render_kwargs = dict(**common_render_kwargs, deployability_index=deployability_index)

# It can still be useful for some strategies to know if the snapshot was actually deployable
is_snapshot_deployable = deployability_index.is_deployable(snapshot)
is_snapshot_representative = deployability_index.is_representative(snapshot)

evaluation_strategy = _evaluation_strategy(snapshot, adapter)

with adapter.transaction(), adapter.session(snapshot.model.session_properties):
adapter.execute(snapshot.model.render_pre_statements(**pre_post_render_kwargs))
rendered_physical_properties = snapshot.model.render_physical_properties(
**create_render_kwargs
)
Expand All @@ -796,18 +785,16 @@ def _create_snapshot(

logger.info(f"Cloning table '{source_table_name}' into '{target_table_name}'")

evaluation_strategy.create(
self._execute_create(
snapshot=snapshot,
table_name=tmp_table_name,
model=snapshot.model,
is_table_deployable=False,
render_kwargs=dict(
table_mapping={snapshot.name: tmp_table_name},
**create_render_kwargs,
),
is_snapshot_deployable=is_snapshot_deployable,
is_snapshot_representative=is_snapshot_representative,
physical_properties=rendered_physical_properties,
deployability_index=deployability_index,
create_render_kwargs=create_render_kwargs,
rendered_physical_properties=rendered_physical_properties,
dry_run=True,
)

try:
adapter.clone_table(target_table_name, snapshot.table_name(), replace=True)
alter_expressions = adapter.get_alter_expressions(
Expand All @@ -828,7 +815,7 @@ def _create_snapshot(
if (
is_table_deployable
and snapshot.model.forward_only
and not is_snapshot_representative
and not deployability_index.is_representative(snapshot)
):
logger.info(
"Skipping creation of the deployable table '%s' for the forward-only model %s. "
Expand All @@ -838,19 +825,16 @@ def _create_snapshot(
)
continue

evaluation_strategy.create(
self._execute_create(
snapshot=snapshot,
table_name=snapshot.table_name(is_deployable=is_table_deployable),
model=snapshot.model,
is_table_deployable=is_table_deployable,
render_kwargs=create_render_kwargs,
is_snapshot_deployable=is_snapshot_deployable,
is_snapshot_representative=is_snapshot_representative,
deployability_index=deployability_index,
create_render_kwargs=create_render_kwargs,
rendered_physical_properties=rendered_physical_properties,
dry_run=dry_run,
physical_properties=rendered_physical_properties,
)

adapter.execute(snapshot.model.render_post_statements(**pre_post_render_kwargs))

if on_complete is not None:
on_complete(snapshot)

Expand All @@ -871,10 +855,9 @@ def _migrate_snapshot(
if not needs_migration:
return

evaluation_strategy = _evaluation_strategy(snapshot, adapter)

target_table_name = snapshot.table_name()
if adapter.table_exists(target_table_name):
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
tmp_table_name = snapshot.table_name(is_deployable=False)
logger.info(
"Migrating table schema from '%s' to '%s'",
Expand All @@ -894,25 +877,25 @@ def _migrate_snapshot(
target_table_name,
snapshot.snapshot_id,
)
deployability_index = DeployabilityIndex.all_deployable()
render_kwargs: t.Dict[str, t.Any] = dict(
engine_adapter=adapter,
snapshots=parent_snapshots_by_name(snapshot, snapshots),
runtime_stage=RuntimeStage.CREATING,
deployability_index=DeployabilityIndex.all_deployable(),
deployability_index=deployability_index,
)
with adapter.transaction(), adapter.session(snapshot.model.session_properties):
adapter.execute(snapshot.model.render_pre_statements(**render_kwargs))
evaluation_strategy.create(
self._execute_create(
snapshot=snapshot,
table_name=target_table_name,
model=snapshot.model,
is_table_deployable=True,
render_kwargs=render_kwargs,
is_snapshot_deployable=True,
is_snapshot_representative=True,
deployability_index=deployability_index,
create_render_kwargs=render_kwargs,
rendered_physical_properties=snapshot.model.render_physical_properties(
**render_kwargs
),
dry_run=False,
physical_properties=snapshot.model.render_physical_properties(**render_kwargs),
)
adapter.execute(snapshot.model.render_post_statements(**render_kwargs))

def _promote_snapshot(
self,
Expand Down Expand Up @@ -1085,6 +1068,40 @@ def _get_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.")
return self.adapter

def _execute_create(
    self,
    snapshot: Snapshot,
    table_name: str,
    is_table_deployable: bool,
    deployability_index: DeployabilityIndex,
    create_render_kwargs: t.Dict[str, t.Any],
    rendered_physical_properties: t.Dict[str, exp.Expression],
    dry_run: bool,
) -> None:
    """Create one physical table, wrapped in the model's pre- and post-statements.

    The pre-statements, the strategy's ``create`` call and the post-statements
    are all rendered with the same keyword arguments, in which the snapshot's
    name is mapped to ``table_name``.

    Args:
        snapshot: The snapshot whose model is being created.
        table_name: The name of the table to create; also the mapping target
            for the snapshot's own name during rendering.
        is_table_deployable: Forwarded unchanged to the evaluation strategy.
        deployability_index: Used to look up whether the snapshot itself is
            deployable and representative.
        create_render_kwargs: Base render keyword arguments. The caller's
            dict is not mutated; a copy with ``table_mapping`` set is used.
        rendered_physical_properties: Forwarded to the evaluation strategy as
            ``physical_properties``.
        dry_run: Forwarded unchanged to the evaluation strategy.
    """
    model = snapshot.model
    engine_adapter = self._get_adapter(model.gateway)
    strategy = _evaluation_strategy(snapshot, engine_adapter)

    # Independently of the table being created, a strategy may still want to
    # know whether the snapshot itself is deployable / representative.
    snapshot_is_deployable = deployability_index.is_deployable(snapshot)
    snapshot_is_representative = deployability_index.is_representative(snapshot)

    # Build a fresh dict so the caller's kwargs stay untouched; any existing
    # "table_mapping" entry is replaced by the mapping for this table.
    render_kwargs = dict(
        create_render_kwargs,
        table_mapping={snapshot.name: table_name},
    )

    pre_statements = model.render_pre_statements(**render_kwargs)
    engine_adapter.execute(pre_statements)

    strategy.create(
        table_name=table_name,
        model=model,
        is_table_deployable=is_table_deployable,
        render_kwargs=render_kwargs,
        is_snapshot_deployable=snapshot_is_deployable,
        is_snapshot_representative=snapshot_is_representative,
        dry_run=dry_run,
        physical_properties=rendered_physical_properties,
    )

    post_statements = model.render_post_statements(**render_kwargs)
    engine_adapter.execute(post_statements)


def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) -> EvaluationStrategy:
klass: t.Type
Expand Down
34 changes: 25 additions & 9 deletions tests/core/test_snapshot_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,22 +1616,26 @@ def test_drop_clone_in_dev_when_migration_fails(mocker: MockerFixture, adapter_m
)


def test_create_clone_in_dev_self_referencing(mocker: MockerFixture, adapter_mock, make_snapshot):
@pytest.mark.parametrize("use_this_model", [True, False])
def test_create_clone_in_dev_self_referencing(
mocker: MockerFixture, adapter_mock, make_snapshot, use_this_model: bool
):
adapter_mock.SUPPORTS_CLONING = True
adapter_mock.get_alter_expressions.return_value = []
evaluator = SnapshotEvaluator(adapter_mock)

from_table = "test_schema.test_model" if not use_this_model else "@this_model"
model = load_sql_based_model(
parse( # type: ignore
"""
f"""
MODEL (
name test_schema.test_model,
kind INCREMENTAL_BY_TIME_RANGE (
time_column ds
)
);
SELECT 1::INT as a, ds::DATE FROM test_schema.test_model;
SELECT 1::INT as a, ds::DATE FROM {from_table};
"""
),
)
Expand Down Expand Up @@ -1664,10 +1668,15 @@ def test_create_clone_in_dev_self_referencing(mocker: MockerFixture, adapter_moc
)

# Make sure the dry run references the correct ("...__schema_migration_source") table.
table_alias = (
"test_model"
if not use_this_model
else f"test_schema__test_model__{snapshot.version}__dev__schema_migration_source"
)
dry_run_query = adapter_mock.fetchall.call_args[0][0].sql()
assert (
dry_run_query
== f'SELECT CAST(1 AS INT) AS "a", CAST("ds" AS DATE) AS "ds" FROM "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__dev__schema_migration_source" AS "test_model" /* test_schema.test_model */ WHERE FALSE LIMIT 0'
== f'SELECT CAST(1 AS INT) AS "a", CAST("ds" AS DATE) AS "ds" FROM "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__dev__schema_migration_source" AS "{table_alias}" /* test_schema.test_model */ WHERE FALSE LIMIT 0'
)


Expand Down Expand Up @@ -2800,7 +2809,7 @@ def blocking_value(evaluator):
assert results[0].blocking


def test_create_post_statements_use_deployable_table(
def test_create_post_statements_use_non_deployable_table(
mocker: MockerFixture, adapter_mock, make_snapshot
):
evaluator = SnapshotEvaluator(adapter_mock)
Expand All @@ -2826,7 +2835,7 @@ def test_create_post_statements_use_deployable_table(
snapshot = make_snapshot(model)
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)

expected_call = f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" /* test_schema.test_model */("a" NULLS FIRST)'
expected_call = f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__dev" /* test_schema.test_model */("a" NULLS FIRST)'

evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable())

Expand Down Expand Up @@ -2890,7 +2899,7 @@ def model_with_statements(context, **kwargs):
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)

evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable())
expected_call = f'CREATE INDEX IF NOT EXISTS "idx" ON "sqlmesh__db"."db__test_model__{snapshot.version}" /* db.test_model */("id")'
expected_call = f'CREATE INDEX IF NOT EXISTS "idx" ON "sqlmesh__db"."db__test_model__{snapshot.version}__dev" /* db.test_model */("id")'

call_args = adapter_mock.execute.call_args_list
pre_calls = call_args[0][0][0]
Expand Down Expand Up @@ -2952,12 +2961,19 @@ def test_on_virtual_update_statements(mocker: MockerFixture, adapter_mock, make_
call_args = adapter_mock.execute.call_args_list
post_calls = call_args[1][0][0]
assert len(post_calls) == 1
assert (
post_calls[0].sql(dialect="postgres")
== f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__dev" /* test_schema.test_model */("a")'
)

post_calls = call_args[3][0][0]
assert len(post_calls) == 1
assert (
post_calls[0].sql(dialect="postgres")
== f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" /* test_schema.test_model */("a")'
)

on_virtual_update_calls = call_args[2][0][0]
on_virtual_update_calls = call_args[4][0][0]
assert (
on_virtual_update_calls[0].sql(dialect="postgres")
== 'GRANT SELECT ON VIEW "test_schema__test_env"."test_model" /* test_schema.test_model */ TO ROLE "admin"'
Expand Down Expand Up @@ -3029,7 +3045,7 @@ def model_with_statements(context, **kwargs):
)

call_args = adapter_mock.execute.call_args_list
on_virtual_update_call = call_args[2][0][0][0]
on_virtual_update_call = call_args[4][0][0][0]
assert (
on_virtual_update_call.sql(dialect="postgres")
== 'CREATE INDEX IF NOT EXISTS "idx" ON "db"."test_model_3" /* db.test_model_3 */("id")'
Expand Down

0 comments on commit ec5e085

Please sign in to comment.