Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/common/locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ def units_requesting_lock(self) -> list[str]:
@property
def next_unit_to_give_lock(self) -> str | None:
"""Get the next unit to give the start lock to."""
units_requesting_lock = self.units_requesting_lock
if self.state.unit_server.model[self.unit_request_lock_atr_name]:
return self.state.unit_server.unit_name
return self.units_requesting_lock[0] if self.units_requesting_lock else None
return units_requesting_lock[0] if units_requesting_lock else None

@property
def unit_with_lock(self) -> "ValkeyServer | None":
Expand Down Expand Up @@ -165,10 +166,11 @@ class StartLock(DataBagLock):
@property
def is_lock_free_to_give(self) -> bool:
"""Check if the unit with the start lock has completed its operation."""
if not self.state.cluster.model.start_member:
return True
starting_unit = self.unit_with_lock
return (
not self.state.cluster.model.start_member
or not starting_unit
not starting_unit
or starting_unit.is_started
or not starting_unit.model.request_start_lock
)
Expand All @@ -183,12 +185,10 @@ class RestartLock(DataBagLock):
@property
def is_lock_free_to_give(self) -> bool:
"""Check if the unit with the restart lock has completed its operation."""
if not self.state.cluster.model.restart_member:
return True
restarting_unit = self.unit_with_lock
return (
not self.state.cluster.model.restart_member
or not restarting_unit
or not restarting_unit.model.request_restart_lock
)
return not restarting_unit or not restarting_unit.model.request_restart_lock


class ScaleDownLock(Lockable):
Expand Down
16 changes: 8 additions & 8 deletions src/events/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,6 @@ def _on_start(self, event: ops.StartEvent) -> None:
event.defer()
return

self.charm.state.unit_server.update({"start_state": StartState.WAITING_TO_START.value})
start_lock.request_lock()

if not start_lock.is_held_by_this_unit:
logger.info("Waiting for lock to start")
event.defer()
return
try:
primary_endpoint = self.charm.sentinel_manager.get_primary_ip()
except ValkeyCannotGetPrimaryIPError:
Expand All @@ -173,10 +166,17 @@ def _on_start(self, event: ops.StartEvent) -> None:
self.charm.state.unit_server.update(
{"start_state": StartState.WAITING_FOR_PRIMARY_START.value}
)
start_lock.release_lock()
event.defer()
return

self.charm.state.unit_server.update({"start_state": StartState.WAITING_TO_START.value})
start_lock.request_lock()

if not start_lock.is_held_by_this_unit:
logger.info("Waiting for lock to start")
event.defer()
return

try:
self.charm.config_manager.configure_services(primary_endpoint)
self.charm.workload.start()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/continuous_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,10 @@ async def with_client(conf: SimpleNamespace):
):
raise WriteFailedError("LPUSH returned 0/None")
proc_logger.info("Length after write: %s", res)
await asyncio.sleep(in_between_sleep)
except Exception as e:
proc_logger.warning("Write failed at %s: %s", current_val, e)
finally:
await asyncio.sleep(in_between_sleep)
if event.is_set():
break

Expand Down
15 changes: 12 additions & 3 deletions tests/integration/cw_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def assert_continuous_writes_consistent(
hostnames: list[str],
username: str,
password: str,
ignore_count: bool = False,
) -> None:
"""Assert that the continuous writes are consistent."""
last_written_value = None
Expand All @@ -81,10 +82,18 @@ def assert_continuous_writes_consistent(
for endpoint in hostnames:
last_value = int(exec_valkey_cli(endpoint, username, password, f"LRANGE {KEY} 0 0").stdout)
count = int(exec_valkey_cli(endpoint, username, password, f"LLEN {KEY}").stdout)
logger.info(
"Endpoint: %s, last written value: %s, last value in DB: %s, count in DB: %s",
endpoint,
last_written_value,
last_value,
count,
)
assert last_written_value == last_value, (
f"endpoint: {endpoint}, expected value: {last_written_value}, current value: {last_value}"
)
assert count == last_written_value + 1, (
f"endpoint: {endpoint}, expected count: {last_written_value + 1}, current count: {count}"
)
if not ignore_count:
assert count == last_written_value + 1, (
f"endpoint: {endpoint}, expected count: {last_written_value + 1}, current count: {count}"
)
logger.info("Continuous writes are consistent on %s.", endpoint)
252 changes: 251 additions & 1 deletion tests/integration/ha/helpers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,29 @@
import os
import string
import subprocess
import tarfile
import tempfile
import time
from datetime import datetime
from logging import getLogger

import jubilant
import kubernetes as kubernetes
import urllib3
from kubernetes import client, config
from kubernetes.client.rest import ApiException
from tenacity import RetryError, Retrying, stop_after_attempt, wait_fixed
from tenacity import RetryError, Retrying, stop_after_attempt, stop_after_delay, wait_fixed

from literals import Substrate
from tests.integration.helpers import APP_NAME

logger = getLogger(__name__)

# systemd unit name of the Valkey server installed from the snap (VM substrate).
VALKEY_SNAP_SERVICE_NAME = "snap.charmed-valkey.server.service"
# Default auto-restart delays, in seconds, restored after a test:
# systemd RestartSec on VMs, pebble backoff on K8s.
VM_RESTART_DELAY_DEFAULT = 20
K8S_RESTART_DELAY_DEFAULT = 5
# Long delay (seconds) applied while a test needs the service to stay down.
RESTART_DELAY_PATCHED = 120


def lxd_cut_network_from_unit_with_ip_change(machine_name: str) -> None:
"""Cut network from a lxc container in a way the changes the IP."""
Expand Down Expand Up @@ -361,3 +368,246 @@ def get_sans_from_certificate(certificate_path: str) -> dict[str, set[str]]:
sans_ip.add(san_value)

return {"sans_ip": sans_ip, "sans_dns": sans_dns}


def send_process_control_signal(
    unit_name: str,
    model_full_name: str,
    signal: str,
    db_process: str,
    substrate: Substrate,
) -> None:
    """Send control signal to a database process running on a Juju unit.

    Best-effort: killing the process can tear down the ssh session itself,
    so command failures and timeouts are logged rather than raised.

    Args:
        unit_name: the Juju unit running the process
        model_full_name: the Juju model for the unit
        signal: the signal to issue, e.g `SIGKILL`
        db_process: the path to the database process binary
        substrate: the substrate the test is running on
    """
    if substrate == Substrate.K8S:
        # For k8s, we exec into the pod and send the signal to the process
        command = f"JUJU_MODEL={model_full_name} juju ssh --container valkey {unit_name} pkill --signal {signal} {db_process}"
    else:
        command = f"JUJU_MODEL={model_full_name} juju ssh {unit_name} sudo -i 'pkill --signal {signal} -f {db_process}'"

    try:
        subprocess.check_output(
            command, stderr=subprocess.PIPE, shell=True, universal_newlines=True, timeout=3
        )
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
        # Previously swallowed silently and then logged success regardless;
        # keep the best-effort semantics but record what actually happened.
        logger.warning("Sending signal %s to unit %s raised: %s", signal, unit_name, e)
    else:
        logger.info("Signal %s sent to database process on unit %s.", signal, unit_name)


def lxd_patch_restart_delay(juju: jubilant.Juju, unit_name: str, delay: int | None = None) -> None:
    """Update the restart delay in the snap's systemd service file."""
    # A falsy delay (None) falls back to the VM default.
    effective_delay = delay if delay else VM_RESTART_DELAY_DEFAULT
    sed_command = (
        f"sed -i 's/^RestartSec=.*/RestartSec={effective_delay}s/'"
        f" /etc/systemd/system/{VALKEY_SNAP_SERVICE_NAME}"
    )
    juju.exec(command=sed_command, unit=unit_name)

    # systemd only picks up the edited unit file after a daemon reload.
    juju.exec(command="sudo systemctl daemon-reload", unit=unit_name)


EXTEND_PEBBLE_RESTART_DELAY_YAML = """services:
valkey:
override: merge
backoff-delay: {delay}s
backoff-limit: {delay}s
"""

RESTORE_PEBBLE_RESTART_DELAY_YAML = """services:
valkey:
override: merge
backoff-delay: 500ms
backoff-limit: 30s
"""


def pebble_patch_restart_delay(
    juju: jubilant.Juju,
    unit_name: str,
    delay: int | None = None,
    ensure_replan: bool = False,
) -> None:
    """Modify the pebble restart delay of the underlying process.

    Renders a pebble layer overlay (extend when `delay` is given, restore
    otherwise), copies it into the unit's pod, merges it into the `valkey`
    service layer, and replans so the new backoff takes effect.

    Args:
        juju: An instance of Jubilant's Juju class on which to run Juju commands
        unit_name: The name of unit to extend the pebble restart delay for
        delay: The new restart delay to apply
        ensure_replan: Whether to check that the replan command succeeded
    """
    # A falsy delay selects the restore overlay with pebble's stock values.
    pebble_file_content = (
        EXTEND_PEBBLE_RESTART_DELAY_YAML.format(delay=delay)
        if delay
        else RESTORE_PEBBLE_RESTART_DELAY_YAML
    )
    kubernetes.config.load_kube_config()
    # NOTE: local `client` shadows the `kubernetes.client` submodule inside
    # this function; later k8s calls deliberately use fully-qualified
    # `kubernetes.stream` / `kubernetes.config` names.
    client = kubernetes.client.api.core_v1_api.CoreV1Api()

    # Juju unit "app/0" maps to pod "app-0"; the model name is the namespace.
    pod_name = unit_name.replace("/", "-")
    container_name = "valkey"
    service_name = "valkey"
    # Timestamp keeps the uploaded plan file unique across invocations.
    now = datetime.now().isoformat()

    with tempfile.NamedTemporaryFile() as pebble_plan_file:
        pebble_plan_file.write(str.encode(pebble_file_content))
        # Flush so the file content is on disk before it is tarred up.
        pebble_plan_file.flush()

        copy_file_into_pod(
            client,
            juju.model,
            pod_name,
            container_name,
            pebble_plan_file.name,
            f"/tmp/pebble_plan_{now}.yml",
        )

    # Merge the overlay into the existing layer (`--combine`).
    add_to_pebble_layer_commands = (
        f"/charm/bin/pebble add --combine {service_name} /tmp/pebble_plan_{now}.yml"
    )
    response = kubernetes.stream.stream(
        client.connect_get_namespaced_pod_exec,
        pod_name,
        juju.model,
        container=container_name,
        command=add_to_pebble_layer_commands.split(),
        stdin=False,
        stdout=True,
        stderr=True,
        tty=False,
        _preload_content=False,
    )
    response.run_forever(timeout=5)
    assert response.returncode == 0, (
        f"Failed to add to pebble layer, unit={unit_name}, container={container_name}, service={service_name}"
    )

    # Retry the replan for up to 60s; without `ensure_replan` a single exec
    # attempt succeeds unconditionally (the return code is not checked).
    for attempt in Retrying(stop=stop_after_delay(60), wait=wait_fixed(3)):
        with attempt:
            replan_pebble_layer_commands = "/charm/bin/pebble replan"
            response = kubernetes.stream.stream(
                client.connect_get_namespaced_pod_exec,
                pod_name,
                juju.model,
                container=container_name,
                command=replan_pebble_layer_commands.split(),
                stdin=False,
                stdout=True,
                stderr=True,
                tty=False,
                _preload_content=False,
            )
            response.run_forever(timeout=60)
            if ensure_replan:
                # A failed assert makes tenacity retry; after 60s the last
                # failure surfaces as a RetryError.
                assert response.returncode == 0, (
                    f"Failed to replan pebble layer, unit={unit_name}, container={container_name}, service={service_name}"
                )


def copy_file_into_pod(
    client: kubernetes.client.api.core_v1_api.CoreV1Api,
    namespace: str,
    pod_name: str,
    container_name: str,
    source_path: str,
    destination_path: str,
) -> None:
    """Copy file contents into pod.

    Streams an in-memory tar archive of the source file over an exec session
    running `tar xvf - -C /` in the target container, so no `kubectl cp`
    binary is required.

    Args:
        client: The kubernetes CoreV1Api client
        namespace: The namespace of the pod to copy files to
        pod_name: The name of the pod to copy files to
        container_name: The name of the pod container to copy files to
        source_path: The path of the file to copy from the local machine
        destination_path: The path to copy the file to in the pod

    Raises:
        AssertionError: if the Kubernetes exec API call fails.
    """
    try:
        # Untar from stdin relative to the pod's root; the archive member
        # carries the full destination path.
        exec_command = ["tar", "xvf", "-", "-C", "/"]

        api_response = kubernetes.stream.stream(
            client.connect_get_namespaced_pod_exec,
            pod_name,
            namespace,
            container=container_name,
            command=exec_command,
            stdin=True,
            stdout=True,
            stderr=True,
            tty=False,
            _preload_content=False,
        )

        with tempfile.TemporaryFile() as tar_buffer:
            # Archive the local file under the in-pod destination path.
            with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
                tar.add(source_path, destination_path)

            tar_buffer.seek(0)
            commands = [tar_buffer.read()]

            while api_response.is_open():
                api_response.update(timeout=1)

                if commands:
                    payload = commands.pop(0)
                    # NOTE(review): decoding the raw tar bytes assumes they
                    # survive a str round-trip through the websocket API;
                    # confirm if arbitrary binary files are ever copied.
                    api_response.write_stdin(payload.decode())
                else:
                    break

            api_response.close()
    except kubernetes.client.rest.ApiException as e:
        # Fail with context instead of a bare `assert False`, which is
        # stripped under `python -O` and discards the API error entirely.
        # Still an AssertionError, so existing callers see the same type.
        raise AssertionError(
            f"Failed to copy {source_path} to {pod_name}:{destination_path}"
        ) from e


def patch_restart_delay(
    juju: jubilant.Juju, unit_name: str, delay: int | None, substrate: Substrate
) -> None:
    """Update the restart delay for the database process based on the substrate."""
    # Dispatch to the substrate-specific helper; other substrates are a no-op.
    if substrate == Substrate.VM:
        lxd_patch_restart_delay(juju, unit_name, delay)
    elif substrate == Substrate.K8S:
        pebble_patch_restart_delay(juju, unit_name, delay=delay, ensure_replan=True)


def reboot_unit(juju: jubilant.Juju, unit_name: str, substrate: Substrate) -> None:
    """Reboot a unit."""
    if substrate != Substrate.VM:
        # On K8s, deleting the pod forces a restart of the workload.
        delete_pod(unit_name.replace("/", "-"), juju.model)
        return
    juju.exec(command="sudo reboot", unit=unit_name)


def delete_pod(pod_name: str, namespace="testing"):
    """Delete a pod via the Kubernetes API, tolerating a missing pod."""
    # Load credentials from the local kubeconfig (~/.kube/config).
    # NOTE: inside a pod, config.load_incluster_config() would be used instead.
    config.load_kube_config()

    # Disable TLS verification for the test cluster and silence the
    # resulting urllib3 warnings.
    default_configuration = client.Configuration.get_default_copy()
    default_configuration.verify_ssl = False
    client.Configuration.set_default(default_configuration)
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    # CoreV1Api covers core resources such as Pods and Services.
    core_api = client.CoreV1Api()

    try:
        logger.info("Attempting to delete pod %s in namespace '%s'...", pod_name, namespace)
        core_api.delete_namespaced_pod(name=pod_name, namespace=namespace)
        logger.info("Success! Pod deleted.")
    except ApiException as e:
        # A missing pod is only a warning; anything else is a real API error.
        if e.status == 404:
            logger.warning("Error: Pod '%s' not found in namespace '%s'.", pod_name, namespace)
        else:
            logger.error("Exception when calling CoreV1Api->delete_namespaced_pod: %s", e)
Loading
Loading