Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion database/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import uuid
from datetime import datetime
from typing import Annotated, Any, ClassVar
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, cast

from sqlalchemy import FetchedValue
from sqlalchemy.dialects import postgresql
Expand All @@ -13,6 +13,9 @@

from database.enums import AttributeValueDataType

if TYPE_CHECKING:
from sqlalchemy import Table

PROJECT_SRID = int(os.environ.get("PROJECT_SRID", "3067"))


Expand Down Expand Up @@ -46,6 +49,21 @@ class VersionedBase(Base):
__abstract__ = True
__table_args__: Any = {"schema": "hame"} # noqa: RUF012 # No can do, sqlalchemy has Any annotation for this

subclasses: ClassVar[list[type[VersionedBase]]] = []

def __init_subclass__(cls, **kwargs: object) -> None:
super().__init_subclass__(**kwargs)
# SQLAlchemy sets __table__ only on mapped (non-abstract) classes.
if not cls.__dict__.get("__abstract__", False) and hasattr(cls, "__table__"):
VersionedBase.subclasses.append(cls)

@classmethod
def subclass_names(cls) -> list[tuple[str, str]]:
return [
(cast("Table", cls.__table__).schema or "", cls.__tablename__)
for cls in VersionedBase.subclasses
]

# Go figure. We have to *explicitly state* id is a mapped column, because id will
# have to be defined inside all the subclasses for relationship remote_side
# definition to work. So even if there is an id field in all the classes,
Expand Down
24 changes: 8 additions & 16 deletions database/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,14 @@
from database import models
from database.base import VersionedBase


def all_subclasses(cls: type) -> set[type]:
"""Recursively find all subclasses of a class."""
return set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in all_subclasses(c)]
)


all_versioned_tables = [
(cls.__table__.schema, cls.__table__.name)
for cls in all_subclasses(VersionedBase)
if hasattr(cls, "__table__")
# If new table are added a new revision must be created in two steps.
# First to create the table, second to add triggers to it.
# To skip triggers for a new table, uncomment the next line and fill the table name.
# and cls.__table__.name != "<table name>"
# If new tables are added a new migration must be created in two steps.
# First to create the table, second to add triggers to it.
# To skip triggers for a new table, add a guard on the table name below.
all_versioned_tables: list[tuple[str, str]] = [
(schema, table)
for (schema, table) in VersionedBase.subclass_names()
if table
!= "new_table_to_skip_triggers" # Replace with actual table name to skip triggers for
]


Expand Down
38 changes: 35 additions & 3 deletions lambdas/ryhti_client/ryhti_client/database_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import datetime
import logging
from contextlib import contextmanager
from string import Template
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, cast
from uuid import UUID, uuid4
from zoneinfo import ZoneInfo
Expand Down Expand Up @@ -34,6 +36,7 @@
from collections.abc import Generator

from geoalchemy2 import WKBElement
from sqlalchemy.orm import Session
from sqlalchemy.sql import FromClause

from database.base import DbId
Expand Down Expand Up @@ -997,6 +1000,32 @@ def import_plan(

return plan.id

@contextmanager
def _disable_edit_triggers(self, session: Session) -> Generator[None]:
"""Temporarily disable all triggers on all tables in hame schema."""
triggers_to_disable = [
Template("trg_${table}_created_at"),
Template("trg_${table}_001_no_created_at_update"),
Template("trg_${table}_modified_at"),
]

def _alter_triggers(state: str) -> None:
for schema, table in base.VersionedBase.subclass_names():
if schema != "hame":
continue
for trigger_template in triggers_to_disable:
trigger_name = trigger_template.substitute(table=table)
disable_sql = text(
f"ALTER TABLE {schema}.{table} {state} TRIGGER {trigger_name}"
)
session.execute(disable_sql)

_alter_triggers("DISABLE")
try:
yield
finally:
_alter_triggers("ENABLE")

def copy_plan(
self,
plan_id: str,
Expand Down Expand Up @@ -1051,9 +1080,12 @@ def copy_plan(
approval_date=approval_date,
period_of_validity_start=period_of_validity_start,
)

copied_plan = plan_copier.copy_plan()
session.add(copied_plan)
with self._disable_edit_triggers(session):
copied_plan = plan_copier.copy_plan()
session.add(copied_plan)
# do the actual insert while the triggers are still disabled to avoid
# created_at and modified_at updates.
session.flush()
session.commit()

return copied_plan.id
34 changes: 33 additions & 1 deletion test/ryhti_client/test_copy_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ def test_copy_plan(
assert copied_plan is not None
assert copied_plan.id != original_plan_id

assert copied_plan.creator == original_plan.creator
assert copied_plan.created_at == original_plan.created_at
assert copied_plan.modifier == original_plan.modifier
assert copied_plan.modified_at == original_plan.modified_at

assert copied_plan.plan_matter == original_plan.plan_matter
assert len(copied_plan.documents) == len(original_plan.documents)

Expand All @@ -130,6 +135,11 @@ def test_copy_plan(
assert copied_land_use_area_1.lifecycle_status.value == valid_status_instance.value
assert copied_land_use_area_1.period_of_validity_start == period_of_validity_start

assert copied_land_use_area_1.creator == land_use_area_1.creator
assert copied_land_use_area_1.created_at == land_use_area_1.created_at
assert copied_land_use_area_1.modifier == land_use_area_1.modifier
assert copied_land_use_area_1.modified_at == land_use_area_1.modified_at

assert (
land_use_area_1.type_of_underground
== copied_land_use_area_1.type_of_underground
Expand All @@ -149,6 +159,11 @@ def test_copy_plan(
assert copied_other_area_1 is not None
assert copied_other_area_1.lifecycle_status.value == valid_status_instance.value

assert copied_other_area_1.creator == other_area_1.creator
assert copied_other_area_1.created_at == other_area_1.created_at
assert copied_other_area_1.modifier == other_area_1.modifier
assert copied_other_area_1.modified_at == other_area_1.modified_at

assert other_area_1.type_of_underground == copied_other_area_1.type_of_underground

# complete_test_plan fixture has no lines
Expand All @@ -162,7 +177,10 @@ def test_copy_plan(
)
assert copied_point_1 is not None
assert copied_point_1.lifecycle_status.value == valid_status_instance.value

assert copied_point_1.creator == point_1.creator
assert copied_point_1.created_at == point_1.created_at
assert copied_point_1.modifier == point_1.modifier
assert copied_point_1.modified_at == point_1.modified_at
assert point_1.type_of_underground == copied_point_1.type_of_underground

# General regulation groups
Expand All @@ -180,6 +198,16 @@ def test_copy_plan(
)
assert copied_general_regulation_group is not None
assert copied_general_regulation_group.plan_id == copied_plan.id
assert copied_general_regulation_group.creator == general_regulation_group.creator
assert (
copied_general_regulation_group.created_at
== general_regulation_group.created_at
)
assert copied_general_regulation_group.modifier == general_regulation_group.modifier
assert (
copied_general_regulation_group.modified_at
== general_regulation_group.modified_at
)
assert (
general_regulation_group.type_of_plan_regulation_group
== copied_general_regulation_group.type_of_plan_regulation_group
Expand All @@ -200,6 +228,10 @@ def test_copy_plan(
None,
)
assert copied_general_regulation is not None
assert copied_general_regulation.creator == general_regulation.creator
assert copied_general_regulation.created_at == general_regulation.created_at
assert copied_general_regulation.modifier == general_regulation.modifier
assert copied_general_regulation.modified_at == general_regulation.modified_at
assert (
copied_general_regulation.period_of_validity_start == period_of_validity_start
)
Expand Down
Loading