4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -71,7 +71,7 @@ psycopg2-binary = "^2.9.10"
boto3 = "*"
tenacity = "^9.1.2"
allure-pytest = "^2.15.0"
jubilant = "^1.3.0"
jubilant = "^1.4.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
70 changes: 69 additions & 1 deletion src/relations/logical_replication.py
@@ -43,6 +43,10 @@
class PostgreSQLLogicalReplication(Object):
"""Defines the logical-replication logic."""

def _identity(self) -> str:
"""Return unique identity of this application in the model."""
return f"{self.model.uuid}:{self.model.app.name}"

def __init__(self, charm: "PostgresqlOperatorCharm"):
super().__init__(charm, "postgresql_logical_replication")
self.charm = charm
@@ -193,6 +197,8 @@ def _on_relation_joined(self, event: RelationJoinedEvent) -> None:
event.relation.data[self.model.app]["subscription-request"] = (
self.charm.config.logical_replication_subscription_request or ""
)
# Share our identity with the publisher to prevent cyclic replication
event.relation.data[self.model.app]["requester-id"] = self._identity()

def _on_relation_changed(self, event: RelationChangedEvent) -> None:
if not self._relation_changed_checks(event):
@@ -247,6 +253,8 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None:
self.charm.app_peer_data["logical-replication-subscriptions"] = json.dumps({
str(event.relation.id): subscriptions
})
# Rebuild subscribed upstream provenance after any changes
self._rebuild_subscribed_upstream()

def _on_relation_departed(self, event: RelationDepartedEvent) -> None:
if event.departing_unit == self.charm.unit and self.charm._peers is not None:
@@ -274,9 +282,38 @@ def _on_relation_broken(self, event: RelationBrokenEvent) -> None:
f"Dropped subscription {subscription} from database {database} due to relation break"
)
self.charm.app_peer_data["logical-replication-subscriptions"] = ""
# Clear provenance as subscriptions are gone
self._rebuild_subscribed_upstream()

# endregion

def _rebuild_subscribed_upstream(self) -> None:
"""Aggregate upstream provenance for all subscribed tables.

Store the mapping in app peer data under "logical-replication-subscribed-upstream",
with keys formatted as "<database>:<schema>.<table>" and values equal to
the upstream identity "<model_uuid>:<app_name>".
"""
mapping: dict[str, str] = {}
for relation in self.model.relations.get(LOGICAL_REPLICATION_RELATION, ()):
pubs = json.loads(relation.data[relation.app].get("publications", "{}"))
for database, pub in pubs.items():
upstream_by_table = pub.get("upstream", {}) if isinstance(pub, dict) else {}
publisher_id = pub.get("publisher-id", "") if isinstance(pub, dict) else ""
for schematable in pub.get("tables", []):
upstream = upstream_by_table.get(schematable) or publisher_id
if upstream:
mapping[f"{database}:{schematable}"] = upstream
self.charm.app_peer_data["logical-replication-subscribed-upstream"] = json.dumps(mapping)
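# Illustrative shape of the stored value (hypothetical identifiers): if this app
# subscribed "public.orders" in database "db1" from another PostgreSQL app, the
# peer-data entry would look like
#   {"db1:public.orders": "1a2b3c4d-0000-0000-0000-000000000000:postgresql-a"}
# i.e. "<database>:<schema>.<table>" mapped to the upstream "<model_uuid>:<app_name>".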

def _get_subscribed_upstream(self) -> dict[str, str]:
try:
return json.loads(
self.charm.app_peer_data.get("logical-replication-subscribed-upstream", "{}")
)
except json.JSONDecodeError:
return {}

# region Events

def _on_secret_changed(self, event: SecretChangedEvent) -> None:
@@ -465,18 +502,20 @@ def _relation_changed_checks(self, event: RelationChangedEvent) -> bool:
return False
return True

def _process_offer(self, relation: Relation) -> None:
def _process_offer(self, relation: Relation) -> None: # noqa: C901
logger.debug(
f"Started processing offer for {LOGICAL_REPLICATION_OFFER_RELATION} #{relation.id}"
)

subscriptions_request = json.loads(
relation.data[relation.app].get("subscription-request", "{}")
)
requester_id = relation.data[relation.app].get("requester-id", "")
publications = json.loads(relation.data[self.model.app].get("publications", "{}"))
secret = self._get_secret(relation.id)
user = secret.peek_content()["username"]
errors = []
subscribed_upstream = self._get_subscribed_upstream()

for database, publication in publications.copy().items():
if database in subscriptions_request:
@@ -494,6 +533,27 @@ def _process_offer(self, relation: Relation) -> None:
)

for database, tables in subscriptions_request.items():
# Cycle detection: if the locally recorded upstream for any requested table equals the requester, reject the request
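# Hypothetical example: this app (A) earlier subscribed "public.orders" in database
# "db1" from app B, so subscribed_upstream["db1:public.orders"] holds B's identity;
# if B now requests that same table from A, requester_id matches the recorded
# upstream and the offer is rejected instead of creating a B -> A -> B loop.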
cycle_detected = False
upstream_map: dict[str, str] = {}
for schematable in tables:
key = f"{database}:{schematable}"
local_upstream = subscribed_upstream.get(key) or self._identity()
upstream_map[schematable] = local_upstream
if requester_id and requester_id == local_upstream:
cycle_detected = True
if cycle_detected:
error = (
f"cyclic logical replication detected for database {database}: "
f"requested tables would replicate back to their upstream ({requester_id})"
)
errors.append(error)
logger.error(
f"Cannot create/alter publication for {LOGICAL_REPLICATION_OFFER_RELATION} #{relation.id}: {error}"
)
# Skip creating/altering this publication to avoid loop
continue

if database not in publications:
if validation_error := self._validate_new_publication(database, tables):
errors.append(validation_error)
@@ -521,6 +581,8 @@ def _process_offer(self, relation: Relation) -> None:
"publication-name": publication_name,
"replication-slot-name": self._replication_slot_name(relation.id, database),
"tables": tables,
"publisher-id": self._identity(),
"upstream": upstream_map,
}
elif sorted(publication_tables := publications[database]["tables"]) != sorted(tables):
publication_name = publications[database]["publication-name"]
@@ -551,6 +613,12 @@ def _process_offer(self, relation: Relation) -> None:
)
self.charm.postgresql.alter_publication(database, publication_name, tables)
publications[database]["tables"] = tables
publications[database]["publisher-id"] = self._identity()
publications[database]["upstream"] = upstream_map
else:
# Tables unchanged; still update provenance and publisher id to propagate upstream
publications[database]["publisher-id"] = self._identity()
publications[database]["upstream"] = upstream_map
self._save_published_resources_info(str(relation.id), secret.id, publications) # type: ignore
relation.data[self.model.app]["publications"] = json.dumps(publications)
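# Illustrative shape of the shared "publications" value (placeholder values; the
# actual publication and slot names come from the existing helpers):
#   {"db1": {"publication-name": "<publication_name>",
#            "replication-slot-name": "<slot_name>",
#            "tables": ["public.orders"],
#            "publisher-id": "<our_model_uuid>:<our_app_name>",
#            "upstream": {"public.orders": "<model_uuid>:<app_name> of the original upstream"}}}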

36 changes: 33 additions & 3 deletions tests/integration/conftest.py
@@ -5,8 +5,8 @@
import uuid

import boto3
import jubilant
import pytest
from pytest_operator.plugin import OpsTest

from . import architecture
from .helpers import construct_endpoint
@@ -17,6 +17,36 @@
logger = logging.getLogger(__name__)


@pytest.fixture(scope="module")
def juju(request: pytest.FixtureRequest):
"""Pytest fixture that wraps :meth:`jubilant.with_model`.

This adds command line parameter ``--keep-models`` (see help for details).
"""
controller = request.config.getoption("--controller")
model = request.config.getoption("--model")
controller_and_model = None
if controller and model:
controller_and_model = f"{controller}:{model}"
elif controller:
controller_and_model = controller
elif model:
controller_and_model = model
keep_models = bool(request.config.getoption("--keep-models"))

if controller_and_model:
juju = jubilant.Juju(model=controller_and_model) # type: ignore
yield juju
log = juju.debug_log(limit=1000)
else:
with jubilant.temp_model(keep=keep_models) as juju:
yield juju
log = juju.debug_log(limit=1000)

if request.session.testsfailed:
print(log, end="")
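
# Illustrative usage in a test module (the "postgresql" app name is an assumption,
# not part of this change):
#
#     def test_deploy(juju: jubilant.Juju, charm: str):
#         juju.deploy(charm, app="postgresql")
#         juju.wait(jubilant.all_active)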


@pytest.fixture(scope="session")
def charm():
# Return str instead of pathlib.Path since python-libjuju's model.deploy(), juju deploy, and
@@ -67,7 +97,7 @@ def cleanup_cloud(config: dict[str, str], credentials: dict[str, str]) -> None:


@pytest.fixture(scope="module")
async def aws_cloud_configs(ops_test: OpsTest) -> None:
async def aws_cloud_configs():
if (
not os.environ.get("AWS_ACCESS_KEY", "").strip()
or not os.environ.get("AWS_SECRET_KEY", "").strip()
@@ -82,7 +112,7 @@ async def aws_cloud_configs(ops_test: OpsTest) -> None:


@pytest.fixture(scope="module")
async def gcp_cloud_configs(ops_test: OpsTest) -> None:
async def gcp_cloud_configs():
if (
not os.environ.get("GCP_ACCESS_KEY", "").strip()
or not os.environ.get("GCP_SECRET_KEY", "").strip()