diff --git a/src/sentry/taskworker/worker.py b/src/sentry/taskworker/worker.py
index 32d9e9ad6be4e3..0a5addb8dc3eb5 100644
--- a/src/sentry/taskworker/worker.py
+++ b/src/sentry/taskworker/worker.py
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
-import atexit
 import logging
 import multiprocessing
 import queue
+import signal
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
@@ -81,9 +81,6 @@ def __init__(
 
         self._processing_pool_name: str = processing_pool_name or "unknown"
 
-    def __del__(self) -> None:
-        self.shutdown()
-
     def do_imports(self) -> None:
         for module in settings.TASKWORKER_IMPORTS:
             __import__(module)
@@ -99,10 +96,20 @@ def start(self) -> int:
         self.start_result_thread()
         self.start_spawn_children_thread()
 
-        atexit.register(self.shutdown)
+        # Convert signals into KeyboardInterrupt.
+        # Running shutdown() within the signal handler can lead to deadlocks
+        def signal_handler(*args: Any) -> None:
+            raise KeyboardInterrupt()
 
-        while True:
-            self.run_once()
+        signal.signal(signal.SIGINT, signal_handler)
+        signal.signal(signal.SIGTERM, signal_handler)
+
+        try:
+            while True:
+                self.run_once()
+        except KeyboardInterrupt:
+            self.shutdown()
+            raise
 
     def run_once(self) -> None:
         """Access point for tests to run a single worker loop"""
@@ -113,21 +120,25 @@ def shutdown(self) -> None:
         """
         Shutdown cleanly
         Activate the shutdown event and drain results before terminating children.
         """
-        if self._shutdown_event.is_set():
-            return
-
-        logger.info("taskworker.worker.shutdown")
+        logger.info("taskworker.worker.shutdown.start")
         self._shutdown_event.set()
 
+        logger.info("taskworker.worker.shutdown.spawn_children")
+        if self._spawn_children_thread:
+            self._spawn_children_thread.join()
+
+        logger.info("taskworker.worker.shutdown.children")
         for child in self._children:
             child.terminate()
+        for child in self._children:
             child.join()
+        logger.info("taskworker.worker.shutdown.result")
         if self._result_thread:
-            self._result_thread.join()
+            # Use a timeout as sometimes this thread can deadlock on the Event.
+            self._result_thread.join(timeout=5)
 
-        # Drain remaining results synchronously, as the thread will have terminated
-        # when shutdown_event was set.
+        # Drain any remaining results synchronously
         while True:
             try:
                 result = self._processed_tasks.get_nowait()
@@ -135,8 +146,7 @@ def shutdown(self) -> None:
             except queue.Empty:
                 break
 
-        if self._spawn_children_thread:
-            self._spawn_children_thread.join()
+        logger.info("taskworker.worker.shutdown.complete")
 
     def _add_task(self) -> bool:
         """
@@ -179,7 +189,7 @@ def start_result_thread(self) -> None:
         """
 
         def result_thread() -> None:
-            logger.debug("taskworker.worker.result_thread_started")
+            logger.debug("taskworker.worker.result_thread.started")
            iopool = ThreadPoolExecutor(max_workers=self._concurrency)
            with iopool as executor:
                while not self._shutdown_event.is_set():
@@ -193,7 +203,9 @@ def result_thread() -> None:
                        )
                        continue
 
-        self._result_thread = threading.Thread(target=result_thread)
+        self._result_thread = threading.Thread(
+            name="send-result", target=result_thread, daemon=True
+        )
         self._result_thread.start()
 
     def _send_result(self, result: ProcessingResult, fetch: bool = True) -> bool:
@@ -253,6 +265,7 @@ def _send_update_task(
            )
            # Use the shutdown_event as a sleep mechanism
            self._shutdown_event.wait(self._setstatus_backoff_seconds)
+
        try:
            next_task = self.client.update_task(result, fetch_next)
            self._setstatus_backoff_seconds = 0
@@ -276,7 +289,7 @@ def _send_update_task(
 
     def start_spawn_children_thread(self) -> None:
         def spawn_children_thread() -> None:
-            logger.debug("taskworker.worker.spawn_children_thread_started")
+            logger.debug("taskworker.worker.spawn_children_thread.started")
            while not self._shutdown_event.is_set():
                self._children = [child for child in self._children if child.is_alive()]
                if len(self._children) >= self._concurrency:
@@ -284,6 +297,7 @@ def spawn_children_thread() -> None:
                    continue
                for i in range(self._concurrency - len(self._children)):
                    process = self.mp_context.Process(
+                        name=f"taskworker-child-{i}",
                        target=child_process,
                        args=(
                            self._child_tasks,
@@ -301,7 +315,9 @@ def spawn_children_thread() -> None:
                        extra={"pid": process.pid, "processing_pool": self._processing_pool_name},
                    )
 
-        self._spawn_children_thread = threading.Thread(target=spawn_children_thread)
+        self._spawn_children_thread = threading.Thread(
+            name="spawn-children", target=spawn_children_thread, daemon=True
+        )
         self._spawn_children_thread.start()
 
     def fetch_task(self) -> InflightTaskActivation | None:
diff --git a/src/sentry/taskworker/workerchild.py b/src/sentry/taskworker/workerchild.py
index 2d53ddf5c45d47..1dc17dd9188ffc 100644
--- a/src/sentry/taskworker/workerchild.py
+++ b/src/sentry/taskworker/workerchild.py
@@ -160,7 +160,7 @@ def handle_alarm(signum: int, frame: FrameType | None) -> None:
            f"execution deadline of {deadline} seconds exceeded by {taskname}"
        )
 
-    while True:
+    while not shutdown_event.is_set():
        if max_task_count and processed_task_count >= max_task_count:
            metrics.incr(
                "taskworker.worker.max_task_count_reached",
@@ -171,10 +171,6 @@ def handle_alarm(signum: int, frame: FrameType | None) -> None:
            )
            break
 
-        if shutdown_event.is_set():
-            logger.info("taskworker.worker.shutdown_event")
-            break
-
        child_tasks_get_start = time.monotonic()
        try:
            # If the queue is empty, this could block for a second.
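For reference, the shutdown pattern this patch moves to (convert SIGINT/SIGTERM into KeyboardInterrupt, run shutdown on the main thread, signal helper threads through an Event, and join them with a timeout) can be exercised in isolation. The sketch below is illustrative only: `main` and the `example-worker` thread name are hypothetical and not part of the Sentry codebase; it assumes nothing beyond the Python standard library.

```python
import signal
import threading
import time


def main() -> None:
    shutdown_event = threading.Event()

    # Convert termination signals into KeyboardInterrupt so that cleanup runs
    # on the main thread rather than inside the signal handler, where joining
    # threads or acquiring locks can deadlock.
    def signal_handler(signum, frame):
        raise KeyboardInterrupt()

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # A daemon thread that waits on the shared Event, standing in for the
    # result/spawn-children threads; as a daemon it cannot block interpreter exit.
    worker = threading.Thread(
        name="example-worker",
        target=shutdown_event.wait,
        daemon=True,
    )
    worker.start()

    try:
        while True:
            time.sleep(0.1)
    except KeyboardInterrupt:
        shutdown_event.set()
        # Bounded join so a wedged thread cannot hang shutdown forever.
        worker.join(timeout=5)
        raise


if __name__ == "__main__":
    main()
```

Sending SIGTERM to this process interrupts the main loop, the Event is set, and the worker thread is joined with a 5-second cap, mirroring the ordering the patched shutdown() follows.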