Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions areal/core/remote_inf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from areal.utils.launcher import wait_llm_server_addrs
from areal.utils.network import find_free_ports, gethostip
from areal.utils.perf_tracer import trace_perf
from areal.utils.proc import kill_process_tree

from .workflow_executor import WorkflowExecutor

Expand Down Expand Up @@ -925,15 +926,7 @@ def _shutdown_one_server(self, server_info: LocalInfServerInfo):
self.addresses.remove(addr)
if server_info.process.poll() is not None:
return
server_info.process.terminate()
try:
server_info.process.wait(timeout=10)
except subprocess.TimeoutExpired:
self.logger.warning(
f"Server process {server_info.process.pid} did not terminate gracefully. Killing it."
)
server_info.process.kill()
server_info.process.wait()
kill_process_tree(server_info.process.pid, graceful=True)

def teardown_server(self):
"""Teardown all locally launched servers."""
Expand Down
3 changes: 2 additions & 1 deletion areal/experimental/tests/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from areal.experimental.openai import ArealOpenAI
from areal.utils import network, seeding
from areal.utils.hf_utils import load_hf_tokenizer
from areal.utils.proc import kill_process_tree

EXPR_NAME = "test_openai"
TRIAL_NAME = "trial_0"
Expand Down Expand Up @@ -62,7 +63,7 @@ def sglang_server():
if time.time() - tik > RUN_SERVER_TIMEOUT:
raise RuntimeError("server launch failed")
yield
process.terminate()
kill_process_tree(process.pid, graceful=True)


@pytest.fixture(scope="module")
Expand Down
49 changes: 3 additions & 46 deletions areal/launcher/sglang_server.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import os
import signal
import subprocess
import sys
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy

import psutil
import requests

from areal.api.alloc_mode import AllocationMode
Expand All @@ -23,50 +20,11 @@
from areal.utils import logging, name_resolve, names
from areal.utils.launcher import TRITON_CACHE_PATH, apply_sglang_patch
from areal.utils.network import find_free_ports, gethostip
from areal.utils.proc import kill_process_tree

logger = logging.getLogger("SGLangServer Wrapper")


# Copied from SGLang
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
    """Forcefully kill a process together with all of its descendants.

    Args:
        parent_pid: PID of the root of the tree. ``None`` selects the calling
            process, in which case only its descendants are killed.
        include_parent: Whether the root process itself is killed once its
            descendants have been signalled.
        skip_pid: PID of a descendant that must be left untouched.
    """
    # Drop the SIGCHLD handler to avoid spammy logs while children die.
    # Only attempted on the main thread.
    on_main_thread = threading.current_thread() is threading.main_thread()
    if on_main_thread:
        signal.signal(signal.SIGCHLD, signal.SIG_DFL)

    if parent_pid is None:
        parent_pid, include_parent = os.getpid(), False

    try:
        root = psutil.Process(parent_pid)
    except psutil.NoSuchProcess:
        # Nothing to clean up: the root is already gone.
        return

    descendants = root.children(recursive=True)
    for proc in descendants:
        if proc.pid != skip_pid:
            try:
                proc.kill()
            except psutil.NoSuchProcess:
                pass

    if not include_parent:
        return

    try:
        root.kill()
        if parent_pid == os.getpid():
            # Killing ourselves: exit right away instead of sending more signals.
            sys.exit(0)

        # Sometimes a process cannot be killed with SIGKILL (e.g. PID=1
        # launched by kubernetes), so an additional SIGQUIT is sent.
        root.send_signal(signal.SIGQUIT)
    except psutil.NoSuchProcess:
        pass


def launch_server_cmd(command: list[str]) -> subprocess.Popen:
"""
Launch inference server in a new process and return its process handle.
Expand Down Expand Up @@ -217,8 +175,7 @@ def run(self):
if not all_alive:
for i, process in enumerate(server_processes):
if process.poll() is None:
process.terminate()
process.wait()
kill_process_tree(process.pid, graceful=True)
logger.info(
f"SGLang server process{server_addresses[i]} terminated."
)
Expand Down Expand Up @@ -262,7 +219,7 @@ def main(argv):
try:
launch_sglang_server(argv)
finally:
kill_process_tree(os.getpid())
kill_process_tree(os.getpid(), graceful=True)


if __name__ == "__main__":
Expand Down
52 changes: 2 additions & 50 deletions areal/launcher/vllm_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy

import psutil
import requests

from areal.api.cli_args import (
Expand All @@ -22,58 +21,11 @@
from areal.utils import logging, name_resolve, names
from areal.utils.launcher import TRITON_CACHE_PATH
from areal.utils.network import find_free_ports, gethostip
from areal.utils.proc import kill_process_tree

logger = logging.getLogger("vLLMServer Wrapper")


def terminate_process_tree(pid: int, timeout: int = 5) -> None:
    """Terminate a process and all its children recursively.

    Every member of the tree first receives SIGTERM. Members still alive
    after ``timeout`` seconds are force-killed. Errors are logged and never
    propagated to the caller.

    Args:
        pid: Process ID to terminate
        timeout: Seconds to wait for graceful termination before forcing kill
    """
    try:
        parent = psutil.Process(pid)
        children = parent.children(recursive=True)
        # Children are signalled before the parent, as in the original order.
        targets = children + [parent]

        # First, try graceful termination
        logger.info(f"Sending SIGTERM to process {pid} and {len(children)} children")
        for proc in targets:
            try:
                proc.terminate()
            except psutil.NoSuchProcess:
                # This member already exited. Keep going: returning here (as
                # the parent branch used to) would skip the wait/force-kill
                # below and could leave surviving children running.
                pass

        # Wait for graceful shutdown
        _, alive = psutil.wait_procs(targets, timeout=timeout)

        # Force kill any remaining processes
        if alive:
            logger.warning(
                f"Force killing {len(alive)} processes that didn't terminate gracefully"
            )
            for proc in alive:
                try:
                    proc.kill()
                except psutil.NoSuchProcess:
                    pass

            # Final wait
            psutil.wait_procs(alive, timeout=1)

        logger.info(f"Successfully cleaned up process tree for PID {pid}")
    except psutil.NoSuchProcess:
        logger.info(f"Process {pid} already terminated")
    except Exception as e:
        # Best-effort cleanup: report the failure instead of raising.
        logger.error(f"Error terminating process tree for PID {pid}: {e}")


def launch_server_cmd(
command: list[str], custom_env: dict | None = None
) -> subprocess.Popen:
Expand Down Expand Up @@ -175,7 +127,7 @@ def _cleanup_all_servers(self):
for i, process in enumerate(processes_to_clean):
if process.poll() is None: # Process is still running
logger.info(f"Terminating vLLM server process {i} (PID: {process.pid})")
terminate_process_tree(process.pid, timeout=10)
kill_process_tree(process.pid, timeout=10, graceful=True)
else:
logger.info(f"vLLM server process {i} already terminated")

Expand Down
43 changes: 2 additions & 41 deletions areal/scheduler/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import aiohttp
import orjson
import psutil
import requests

from areal.api.cli_args import BaseExperimentConfig
Expand All @@ -38,6 +37,7 @@
get_env_vars,
)
from areal.utils.network import find_free_ports, gethostip
from areal.utils.proc import kill_process_tree

logger = logging.getLogger("LocalScheduler")

Expand Down Expand Up @@ -584,7 +584,7 @@ def _cleanup_workers(self, workers: list[WorkerInfo]):
for port_str in worker_info.worker.worker_ports:
self._allocated_ports.discard(int(port_str))

self._terminate_process_tree(worker_info.process.pid)
kill_process_tree(worker_info.process.pid, timeout=3, graceful=True)

logger.debug(f"Cleaned up worker {worker_info.worker.id}")
except Exception as e:
Expand All @@ -593,45 +593,6 @@ def _cleanup_workers(self, workers: list[WorkerInfo]):
exc_info=True,
)

def _terminate_process_tree(self, pid: int):
    """Terminate the process ``pid`` and all of its descendants.

    Each member of the tree is sent a terminate request, then given up to
    3 seconds to exit before any survivor is force-killed. Failures are
    logged as warnings and never raised.

    Args:
        pid: PID of the root process of the tree.
    """
    try:
        parent = psutil.Process(pid)
        children = parent.children(recursive=True)

        # Try graceful termination first (children, then the parent)
        for proc in children + [parent]:
            try:
                proc.terminate()
            except psutil.NoSuchProcess:
                # Already gone. Do not return early when this is the parent:
                # children that ignore the terminate request must still be
                # waited on and force-killed below.
                pass

        # Wait for graceful termination
        _, alive = psutil.wait_procs([parent] + children, timeout=3)

        # Force kill remaining processes
        for proc in alive:
            try:
                proc.kill()
            except psutil.NoSuchProcess:
                pass

    except psutil.NoSuchProcess:
        # Process already gone
        pass
    except psutil.Error as e:
        logger.warning(f"Error terminating process tree {pid}: {e}", exc_info=True)
    except Exception:
        import traceback

        logger.warning(
            f"Error terminating process tree {pid}: {traceback.format_exc()}"
        )

def _read_log_tail(self, log_file: str, lines: int = 50) -> str:
try:
with open(log_file) as f:
Expand Down
9 changes: 2 additions & 7 deletions areal/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from areal.platforms import current_platform
from areal.utils import logging
from areal.utils.proc import kill_process_tree

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -758,13 +759,7 @@ def test_search_agent_deepresearch(tmp_path_factory):
f"Search Agent DeepResearch example failed, return_code={return_code}"
)
finally:
# Ensure cleanup happens even if test fails
llm_judge_proc.terminate()
try:
llm_judge_proc.wait(timeout=5)
except subprocess.TimeoutExpired:
llm_judge_proc.kill()
llm_judge_proc.wait()
kill_process_tree(llm_judge_proc.pid)


@pytest.mark.multi_gpu
Expand Down
Loading