From 835293683922d6b1c99ca25bf5415a053afc5bc0 Mon Sep 17 00:00:00 2001 From: Lauri Kajan Date: Tue, 2 Jun 2026 12:06:27 +0300 Subject: [PATCH] Don't change modification info when copying plans Disable triggers temporarily --- database/base.py | 20 +++++++++- database/triggers.py | 24 ++++-------- .../ryhti_client/database_client.py | 38 +++++++++++++++++-- test/ryhti_client/test_copy_plan.py | 34 ++++++++++++++++- 4 files changed, 95 insertions(+), 21 deletions(-) diff --git a/database/base.py b/database/base.py index 75f35b5..cf4c080 100644 --- a/database/base.py +++ b/database/base.py @@ -3,7 +3,7 @@ import os import uuid from datetime import datetime -from typing import Annotated, Any, ClassVar +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, cast from sqlalchemy import FetchedValue from sqlalchemy.dialects import postgresql @@ -13,6 +13,9 @@ from database.enums import AttributeValueDataType +if TYPE_CHECKING: + from sqlalchemy import Table + PROJECT_SRID = int(os.environ.get("PROJECT_SRID", "3067")) @@ -46,6 +49,21 @@ class VersionedBase(Base): __abstract__ = True __table_args__: Any = {"schema": "hame"} # noqa: RUF012 # No can do, sqlalchemy has Any annotation for this + subclasses: ClassVar[list[type[VersionedBase]]] = [] + + def __init_subclass__(cls, **kwargs: object) -> None: + super().__init_subclass__(**kwargs) + # SQLAlchemy sets __table__ only on mapped (non-abstract) classes. + if not cls.__dict__.get("__abstract__", False) and hasattr(cls, "__table__"): + VersionedBase.subclasses.append(cls) + + @classmethod + def subclass_names(cls) -> list[tuple[str, str]]: + return [ + (cast("Table", cls.__table__).schema or "", cls.__tablename__) + for cls in VersionedBase.subclasses + ] + # Go figure. We have to *explicitly state* id is a mapped column, because id will # have to be defined inside all the subclasses for relationship remote_side # definition to work. So even if there is an id field in all the classes, diff --git a/database/triggers.py b/database/triggers.py index 429e9e1..23c9f44 100644 --- a/database/triggers.py +++ b/database/triggers.py @@ -7,22 +7,14 @@ from database import models from database.base import VersionedBase - -def all_subclasses(cls: type) -> set[type]: - """Recursively find all subclasses of a class.""" - return set(cls.__subclasses__()).union( - [s for c in cls.__subclasses__() for s in all_subclasses(c)] - ) - - -all_versioned_tables = [ - (cls.__table__.schema, cls.__table__.name) - for cls in all_subclasses(VersionedBase) - if hasattr(cls, "__table__") - # If new table are added a new revision must be created in two steps. - # First to create the table, second to add triggers to it. - # To skip triggers for a new table, uncomment the next line and fill the table name. - # and cls.__table__.name != "" +# If new tables are added a new migration must be created in two steps. +# First to create the table, second to add triggers to it. +# To skip triggers for a new table, add a guard on the table name below. +all_versioned_tables: list[tuple[str, str]] = [ + (schema, table) + for (schema, table) in VersionedBase.subclass_names() + if table + != "new_table_to_skip_triggers" # Replace with actual table name to skip triggers for ] diff --git a/lambdas/ryhti_client/ryhti_client/database_client.py b/lambdas/ryhti_client/ryhti_client/database_client.py index 21b6ebc..de50721 100644 --- a/lambdas/ryhti_client/ryhti_client/database_client.py +++ b/lambdas/ryhti_client/ryhti_client/database_client.py @@ -2,6 +2,8 @@ import datetime import logging +from contextlib import contextmanager +from string import Template from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, cast from uuid import UUID, uuid4 from zoneinfo import ZoneInfo @@ -34,6 +36,7 @@ from collections.abc import Generator from geoalchemy2 import WKBElement + from sqlalchemy.orm import Session from sqlalchemy.sql import FromClause from database.base import DbId @@ -997,6 +1000,32 @@ def import_plan( return plan.id + @contextmanager + def _disable_edit_triggers(self, session: Session) -> Generator[None]: + """Temporarily disable all triggers on all tables in hame schema.""" + triggers_to_disable = [ + Template("trg_${table}_created_at"), + Template("trg_${table}_001_no_created_at_update"), + Template("trg_${table}_modified_at"), + ] + + def _alter_triggers(state: str) -> None: + for schema, table in base.VersionedBase.subclass_names(): + if schema != "hame": + continue + for trigger_template in triggers_to_disable: + trigger_name = trigger_template.substitute(table=table) + disable_sql = text( + f"ALTER TABLE {schema}.{table} {state} TRIGGER {trigger_name}" + ) + session.execute(disable_sql) + + _alter_triggers("DISABLE") + try: + yield + finally: + _alter_triggers("ENABLE") + def copy_plan( self, plan_id: str, @@ -1051,9 +1080,12 @@ def copy_plan( approval_date=approval_date, period_of_validity_start=period_of_validity_start, ) - - copied_plan = plan_copier.copy_plan() - session.add(copied_plan) + with self._disable_edit_triggers(session): + copied_plan = plan_copier.copy_plan() + session.add(copied_plan) + # do the actual insert while the triggers are still disabled to avoid + # created_at and modified_at updates. + session.flush() session.commit() return copied_plan.id diff --git a/test/ryhti_client/test_copy_plan.py b/test/ryhti_client/test_copy_plan.py index 159afd0..28abbbe 100644 --- a/test/ryhti_client/test_copy_plan.py +++ b/test/ryhti_client/test_copy_plan.py @@ -104,6 +104,11 @@ def test_copy_plan( assert copied_plan is not None assert copied_plan.id != original_plan_id + assert copied_plan.creator == original_plan.creator + assert copied_plan.created_at == original_plan.created_at + assert copied_plan.modifier == original_plan.modifier + assert copied_plan.modified_at == original_plan.modified_at + assert copied_plan.plan_matter == original_plan.plan_matter assert len(copied_plan.documents) == len(original_plan.documents) @@ -130,6 +135,11 @@ def test_copy_plan( assert copied_land_use_area_1.lifecycle_status.value == valid_status_instance.value assert copied_land_use_area_1.period_of_validity_start == period_of_validity_start + assert copied_land_use_area_1.creator == land_use_area_1.creator + assert copied_land_use_area_1.created_at == land_use_area_1.created_at + assert copied_land_use_area_1.modifier == land_use_area_1.modifier + assert copied_land_use_area_1.modified_at == land_use_area_1.modified_at + assert ( land_use_area_1.type_of_underground == copied_land_use_area_1.type_of_underground @@ -149,6 +159,11 @@ def test_copy_plan( assert copied_other_area_1 is not None assert copied_other_area_1.lifecycle_status.value == valid_status_instance.value + assert copied_other_area_1.creator == other_area_1.creator + assert copied_other_area_1.created_at == other_area_1.created_at + assert copied_other_area_1.modifier == other_area_1.modifier + assert copied_other_area_1.modified_at == other_area_1.modified_at + assert other_area_1.type_of_underground == copied_other_area_1.type_of_underground # complete_test_plan fixture has no lines @@ -162,7 +177,10 @@ def test_copy_plan( ) assert copied_point_1 is not None assert copied_point_1.lifecycle_status.value == valid_status_instance.value - + assert copied_point_1.creator == point_1.creator + assert copied_point_1.created_at == point_1.created_at + assert copied_point_1.modifier == point_1.modifier + assert copied_point_1.modified_at == point_1.modified_at assert point_1.type_of_underground == copied_point_1.type_of_underground # General regulation groups @@ -180,6 +198,16 @@ def test_copy_plan( ) assert copied_general_regulation_group is not None assert copied_general_regulation_group.plan_id == copied_plan.id + assert copied_general_regulation_group.creator == general_regulation_group.creator + assert ( + copied_general_regulation_group.created_at + == general_regulation_group.created_at + ) + assert copied_general_regulation_group.modifier == general_regulation_group.modifier + assert ( + copied_general_regulation_group.modified_at + == general_regulation_group.modified_at + ) assert ( general_regulation_group.type_of_plan_regulation_group == copied_general_regulation_group.type_of_plan_regulation_group @@ -200,6 +228,10 @@ def test_copy_plan( None, ) assert copied_general_regulation is not None + assert copied_general_regulation.creator == general_regulation.creator + assert copied_general_regulation.created_at == general_regulation.created_at + assert copied_general_regulation.modifier == general_regulation.modifier + assert copied_general_regulation.modified_at == general_regulation.modified_at assert ( copied_general_regulation.period_of_validity_start == period_of_validity_start )