diff --git a/database/base.py b/database/base.py
index 75f35b5..cf4c080 100644
--- a/database/base.py
+++ b/database/base.py
@@ -3,7 +3,7 @@
import os
import uuid
from datetime import datetime
-from typing import Annotated, Any, ClassVar
+from typing import TYPE_CHECKING, Annotated, Any, ClassVar, cast
from sqlalchemy import FetchedValue
from sqlalchemy.dialects import postgresql
@@ -13,6 +13,9 @@
from database.enums import AttributeValueDataType
+if TYPE_CHECKING:
+ from sqlalchemy import Table
+
PROJECT_SRID = int(os.environ.get("PROJECT_SRID", "3067"))
@@ -46,6 +49,21 @@ class VersionedBase(Base):
__abstract__ = True
__table_args__: Any = {"schema": "hame"} # noqa: RUF012 # No can do, sqlalchemy has Any annotation for this
+ subclasses: ClassVar[list[type[VersionedBase]]] = []
+
+ def __init_subclass__(cls, **kwargs: object) -> None:
+ super().__init_subclass__(**kwargs)
+ # SQLAlchemy sets __table__ only on mapped (non-abstract) classes.
+ if not cls.__dict__.get("__abstract__", False) and hasattr(cls, "__table__"):
+ VersionedBase.subclasses.append(cls)
+
+ @classmethod
+ def subclass_names(cls) -> list[tuple[str, str]]:
+ return [
+ (cast("Table", cls.__table__).schema or "", cls.__tablename__)
+ for cls in VersionedBase.subclasses
+ ]
+
# Go figure. We have to *explicitly state* id is a mapped column, because id will
# have to be defined inside all the subclasses for relationship remote_side
# definition to work. So even if there is an id field in all the classes,
diff --git a/database/triggers.py b/database/triggers.py
index 429e9e1..23c9f44 100644
--- a/database/triggers.py
+++ b/database/triggers.py
@@ -7,22 +7,14 @@
from database import models
from database.base import VersionedBase
-
-def all_subclasses(cls: type) -> set[type]:
- """Recursively find all subclasses of a class."""
- return set(cls.__subclasses__()).union(
- [s for c in cls.__subclasses__() for s in all_subclasses(c)]
- )
-
-
-all_versioned_tables = [
- (cls.__table__.schema, cls.__table__.name)
- for cls in all_subclasses(VersionedBase)
- if hasattr(cls, "__table__")
- # If new table are added a new revision must be created in two steps.
- # First to create the table, second to add triggers to it.
- # To skip triggers for a new table, uncomment the next line and fill the table name.
- # and cls.__table__.name != "
"
+# If new tables are added a new migration must be created in two steps.
+# First to create the table, second to add triggers to it.
+# To skip triggers for a new table, add a guard on the table name below.
+all_versioned_tables: list[tuple[str, str]] = [
+ (schema, table)
+ for (schema, table) in VersionedBase.subclass_names()
+ if table
+ != "new_table_to_skip_triggers" # Replace with actual table name to skip triggers for
]
diff --git a/lambdas/ryhti_client/ryhti_client/database_client.py b/lambdas/ryhti_client/ryhti_client/database_client.py
index 21b6ebc..de50721 100644
--- a/lambdas/ryhti_client/ryhti_client/database_client.py
+++ b/lambdas/ryhti_client/ryhti_client/database_client.py
@@ -2,6 +2,8 @@
import datetime
import logging
+from contextlib import contextmanager
+from string import Template
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, cast
from uuid import UUID, uuid4
from zoneinfo import ZoneInfo
@@ -34,6 +36,7 @@
from collections.abc import Generator
from geoalchemy2 import WKBElement
+ from sqlalchemy.orm import Session
from sqlalchemy.sql import FromClause
from database.base import DbId
@@ -997,6 +1000,32 @@ def import_plan(
return plan.id
+ @contextmanager
+ def _disable_edit_triggers(self, session: Session) -> Generator[None]:
+ """Temporarily disable all triggers on all tables in hame schema."""
+ triggers_to_disable = [
+ Template("trg_${table}_created_at"),
+ Template("trg_${table}_001_no_created_at_update"),
+ Template("trg_${table}_modified_at"),
+ ]
+
+ def _alter_triggers(state: str) -> None:
+ for schema, table in base.VersionedBase.subclass_names():
+ if schema != "hame":
+ continue
+ for trigger_template in triggers_to_disable:
+ trigger_name = trigger_template.substitute(table=table)
+ disable_sql = text(
+ f"ALTER TABLE {schema}.{table} {state} TRIGGER {trigger_name}"
+ )
+ session.execute(disable_sql)
+
+ _alter_triggers("DISABLE")
+ try:
+ yield
+ finally:
+ _alter_triggers("ENABLE")
+
def copy_plan(
self,
plan_id: str,
@@ -1051,9 +1080,12 @@ def copy_plan(
approval_date=approval_date,
period_of_validity_start=period_of_validity_start,
)
-
- copied_plan = plan_copier.copy_plan()
- session.add(copied_plan)
+ with self._disable_edit_triggers(session):
+ copied_plan = plan_copier.copy_plan()
+ session.add(copied_plan)
+ # do the actual insert while the triggers are still disabled to avoid
+ # created_at and modified_at updates.
+ session.flush()
session.commit()
return copied_plan.id
diff --git a/test/ryhti_client/test_copy_plan.py b/test/ryhti_client/test_copy_plan.py
index 159afd0..28abbbe 100644
--- a/test/ryhti_client/test_copy_plan.py
+++ b/test/ryhti_client/test_copy_plan.py
@@ -104,6 +104,11 @@ def test_copy_plan(
assert copied_plan is not None
assert copied_plan.id != original_plan_id
+ assert copied_plan.creator == original_plan.creator
+ assert copied_plan.created_at == original_plan.created_at
+ assert copied_plan.modifier == original_plan.modifier
+ assert copied_plan.modified_at == original_plan.modified_at
+
assert copied_plan.plan_matter == original_plan.plan_matter
assert len(copied_plan.documents) == len(original_plan.documents)
@@ -130,6 +135,11 @@ def test_copy_plan(
assert copied_land_use_area_1.lifecycle_status.value == valid_status_instance.value
assert copied_land_use_area_1.period_of_validity_start == period_of_validity_start
+ assert copied_land_use_area_1.creator == land_use_area_1.creator
+ assert copied_land_use_area_1.created_at == land_use_area_1.created_at
+ assert copied_land_use_area_1.modifier == land_use_area_1.modifier
+ assert copied_land_use_area_1.modified_at == land_use_area_1.modified_at
+
assert (
land_use_area_1.type_of_underground
== copied_land_use_area_1.type_of_underground
@@ -149,6 +159,11 @@ def test_copy_plan(
assert copied_other_area_1 is not None
assert copied_other_area_1.lifecycle_status.value == valid_status_instance.value
+ assert copied_other_area_1.creator == other_area_1.creator
+ assert copied_other_area_1.created_at == other_area_1.created_at
+ assert copied_other_area_1.modifier == other_area_1.modifier
+ assert copied_other_area_1.modified_at == other_area_1.modified_at
+
assert other_area_1.type_of_underground == copied_other_area_1.type_of_underground
# complete_test_plan fixture has no lines
@@ -162,7 +177,10 @@ def test_copy_plan(
)
assert copied_point_1 is not None
assert copied_point_1.lifecycle_status.value == valid_status_instance.value
-
+ assert copied_point_1.creator == point_1.creator
+ assert copied_point_1.created_at == point_1.created_at
+ assert copied_point_1.modifier == point_1.modifier
+ assert copied_point_1.modified_at == point_1.modified_at
assert point_1.type_of_underground == copied_point_1.type_of_underground
# General regulation groups
@@ -180,6 +198,16 @@ def test_copy_plan(
)
assert copied_general_regulation_group is not None
assert copied_general_regulation_group.plan_id == copied_plan.id
+ assert copied_general_regulation_group.creator == general_regulation_group.creator
+ assert (
+ copied_general_regulation_group.created_at
+ == general_regulation_group.created_at
+ )
+ assert copied_general_regulation_group.modifier == general_regulation_group.modifier
+ assert (
+ copied_general_regulation_group.modified_at
+ == general_regulation_group.modified_at
+ )
assert (
general_regulation_group.type_of_plan_regulation_group
== copied_general_regulation_group.type_of_plan_regulation_group
@@ -200,6 +228,10 @@ def test_copy_plan(
None,
)
assert copied_general_regulation is not None
+ assert copied_general_regulation.creator == general_regulation.creator
+ assert copied_general_regulation.created_at == general_regulation.created_at
+ assert copied_general_regulation.modifier == general_regulation.modifier
+ assert copied_general_regulation.modified_at == general_regulation.modified_at
assert (
copied_general_regulation.period_of_validity_start == period_of_validity_start
)