Skip to content

fix: Improve signal handling for worker processes #216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 70 additions & 11 deletions neps/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import logging
import os
import shutil
import signal
import time
from collections.abc import Callable, Iterator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar, Literal
from typing import TYPE_CHECKING, Any, ClassVar, Literal

from pandas.core.common import contextlib
from portalocker import portalocker

from neps.env import (
Expand Down Expand Up @@ -55,6 +57,11 @@ def _default_worker_name() -> str:
return f"{os.getpid()}-{isoformat}"


SIGNALS_TO_HANDLE_IF_AVAILABLE = [
"SIGINT",
"SIGTERM",
]

_DDP_ENV_VAR_NAME = "NEPS_DDP_TRIAL_ID"


Expand Down Expand Up @@ -182,6 +189,8 @@ class DefaultWorker:
worker_cumulative_evaluation_time_seconds: float = 0.0
"""The time spent evaluating configurations by this worker."""

_PREVIOUS_SIGNAL_HANDLERS: dict[int, signal._HANDLER] = field(default_factory=dict)

_GRACE: ClassVar = FS_SYNC_GRACE_BASE

@classmethod
Expand Down Expand Up @@ -369,6 +378,16 @@ def _check_global_stopping_criterion(

return False

def _set_signal_handlers(self) -> None:
for name in SIGNALS_TO_HANDLE_IF_AVAILABLE:
if hasattr(signal.Signals, name):
sig = getattr(signal.Signals, name)
# HACK: Despite what python documentation says, the existance of a signal
# is not enough to guarantee that it can be caught.
with contextlib.suppress(ValueError):
previous_signal_handler = signal.signal(sig, self._emergency_cleanup)
self._PREVIOUS_SIGNAL_HANDLERS[sig] = previous_signal_handler

@property
def _requires_global_stopping_criterion(self) -> bool:
return (
Expand Down Expand Up @@ -491,6 +510,7 @@ def run(self) -> None: # noqa: C901, PLR0912, PLR0915
Will keep running until one of the criterion defined by the `WorkerSettings`
is met.
"""
self._set_signal_handlers()
_set_workers_neps_state(self.state)

logger.info("Launching NePS")
Expand Down Expand Up @@ -580,15 +600,21 @@ def run(self) -> None: # noqa: C901, PLR0912, PLR0915
continue

# We (this worker) has managed to set it to evaluating, now we can evaluate it
with _set_global_trial(trial_to_eval):
evaluated_trial, report = evaluate_trial(
trial=trial_to_eval,
evaluation_fn=self.evaluation_fn,
default_report_values=self.settings.default_report_values,
)
evaluation_duration = evaluated_trial.metadata.evaluation_duration
assert evaluation_duration is not None
self.worker_cumulative_evaluation_time_seconds += evaluation_duration
try:
with _set_global_trial(trial_to_eval):
evaluated_trial, report = evaluate_trial(
trial=trial_to_eval,
evaluation_fn=self.evaluation_fn,
default_report_values=self.settings.default_report_values,
)
except KeyboardInterrupt as e:
# This throws and we have stopped the worker at this point
self._emergency_cleanup(signum=signal.SIGINT, frame=None, rethrow=e)
return

evaluation_duration = evaluated_trial.metadata.evaluation_duration
assert evaluation_duration is not None
self.worker_cumulative_evaluation_time_seconds += evaluation_duration

self.worker_cumulative_eval_count += 1

Expand Down Expand Up @@ -630,6 +656,39 @@ def run(self) -> None: # noqa: C901, PLR0912, PLR0915
"Learning Curve %s: %s", evaluated_trial.id, report.learning_curve
)

def _emergency_cleanup(
self,
signum: int,
frame: Any,
rethrow: KeyboardInterrupt | None = None,
) -> None:
"""Handle signals."""
global _CURRENTLY_RUNNING_TRIAL_IN_PROCESS # noqa: PLW0603
logger.error(
f"Worker '{self.worker_id}' received signal {signum}. Stopping worker now!"
)
if _CURRENTLY_RUNNING_TRIAL_IN_PROCESS is not None:
logger.error(
"Worker '%s' was interrupted while evaluating trial: %s. Setting"
" trial to pending!",
self.worker_id,
_CURRENTLY_RUNNING_TRIAL_IN_PROCESS.id,
)
_CURRENTLY_RUNNING_TRIAL_IN_PROCESS.reset()
try:
self.state.put_updated_trial(_CURRENTLY_RUNNING_TRIAL_IN_PROCESS)
except NePSError as e:
logger.exception(e)
finally:
_CURRENTLY_RUNNING_TRIAL_IN_PROCESS = None

previous_handler = self._PREVIOUS_SIGNAL_HANDLERS.get(signum)
if previous_handler is not None and callable(previous_handler):
previous_handler(signum, frame)
if rethrow is not None:
raise rethrow
raise KeyboardInterrupt(f"Worker was interrupted by signal {signum}.")


def _launch_ddp_runtime(
*,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ dependencies = [
"ifbo>=0.3.10",
"botorch>=0.12",
"gpytorch==1.13.0",
"psutil>=7.0.0",
]

[project.urls]
Expand Down
92 changes: 91 additions & 1 deletion tests/test_runtime/test_error_handling_strategies.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from __future__ import annotations

import contextlib
import multiprocessing
import signal
import time
from dataclasses import dataclass
from pathlib import Path

import psutil
import pytest
from pytest_cases import fixture, parametrize

from neps.exceptions import WorkerRaiseError
from neps.optimizers import OptimizerInfo
from neps.optimizers.algorithms import random_search
from neps.runtime import DefaultWorker
from neps.runtime import SIGNALS_TO_HANDLE_IF_AVAILABLE, DefaultWorker
from neps.space import Float, SearchSpace
from neps.state import (
DefaultReportValues,
Expand Down Expand Up @@ -209,3 +213,89 @@ def __call__(self, *args, **kwargs) -> float: # noqa: ARG002

assert neps_state.lock_and_get_next_pending_trial() is None
assert len(neps_state.lock_and_get_errors()) == 1


def sleep_function(*args, **kwargs) -> float:
time.sleep(20)
return 10


SIGNALS: list[signal.Signals] = []
for name in SIGNALS_TO_HANDLE_IF_AVAILABLE:
if hasattr(signal.Signals, name):
sig: signal.Signals = getattr(signal.Signals, name)
SIGNALS.append(sig)


# @pytest.mark.ci_examples
@pytest.mark.parametrize("signum", SIGNALS)
def test_worker_reset_evaluating_to_pending_on_ctrl_c(
signum: signal.Signals,
neps_state: NePSState,
) -> None:
optimizer = random_search(SearchSpace({"a": Float(0, 1)}))
settings = WorkerSettings(
on_error=OnErrorPossibilities.IGNORE, # <- Highlight
default_report_values=DefaultReportValues(),
max_evaluations_total=None,
include_in_progress_evaluations_towards_maximum=False,
max_cost_total=None,
max_evaluations_for_worker=1,
max_evaluation_time_total_seconds=None,
max_wallclock_time_for_worker_seconds=None,
max_evaluation_time_for_worker_seconds=None,
max_cost_for_worker=None,
batch_size=None,
)

worker1 = DefaultWorker.new(
state=neps_state,
optimizer=optimizer,
evaluation_fn=sleep_function,
settings=settings,
)

# Use multiprocessing.Process
p = multiprocessing.Process(target=worker1.run)
p.start()

time.sleep(5)
assert p.pid is not None
assert p.is_alive()

# Should be evaluating at this stage
trials = neps_state.lock_and_read_trials()
assert len(trials) == 1
assert next(iter(trials.values())).metadata.state == Trial.State.EVALUATING

# Kill the process while it's evaluating using signals
process = psutil.Process(p.pid)

# If sending the signal fails, skip the test,
# as most likely the signal is not supported on this platform
try:
process.send_signal(signum)
except ValueError as e:
pytest.skip(f"Signal error: {e}")
else:
# If the signal is sent successfully, we can proceed with the test
pass

# If the system is windows and the signal is SIGTERM, skip the test
if (
signum == signal.SIGTERM
and multiprocessing.get_start_method() == "spawn"
and multiprocessing.current_process().name == "MainProcess"
):
pytest.skip("SIGTERM is not supported on Windows with spawn start method")

p.join(timeout=5) # Wait for the process to terminate

if p.is_alive():
p.terminate() # Force terminate if it's still alive
p.join()
pytest.fail("Worker did not terminate after receiving signal!")
else:
trials2 = neps_state.lock_and_read_trials()
assert len(trials2) == 1
assert next(iter(trials2.values())).metadata.state == Trial.State.PENDING
Loading