Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implementation of python async engine, which decouples receiving requests from sending responses #2758

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions engines/python/setup/djl_python/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def python_engine_args():
type=str,
default="info",
help="log level to use for djl_python logging")
parser.add_argument(
'--async-mode',
required=False,
action=argparse.BooleanOptionalAction,
help="whether to use async python engine for comms")
return parser

@staticmethod
Expand Down
8 changes: 7 additions & 1 deletion engines/python/setup/djl_python/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ def add_property(self, key, val):
self.properties[key] = val
return self

def get_properties(self):
    """Return the properties dict of this Output (live reference, not a copy)."""
    return self.properties

def get_property(self, key):
    """Look up a single property by name, returning None when absent."""
    try:
        return self.properties[key]
    except KeyError:
        return None

def add(self, value, key=None, batch_index=None):
if key is not None and type(key) is not str:
logging.warning(f"Output key should be str type, got {type(key)}")
Expand Down Expand Up @@ -182,7 +188,7 @@ def send(self, cl_socket):
msg += struct.pack('>h', len(self.properties))
for k, v in self.properties.items():
self.write_utf8(msg, k)
self.write_utf8(msg, v)
self.write_utf8(msg, str(v))

if self.stream_content is None:
size = self.content.size()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# the specific language governing permissions and limitations under the License.
import ast
import logging
from typing import Optional, Any, Dict, Tuple, Literal
from typing import Optional, Any, Dict, Tuple, Literal, Union
from pydantic import field_validator, model_validator, ConfigDict, Field
from vllm import EngineArgs
from vllm import EngineArgs, AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
from vllm.engine.arg_utils import StoreBoolean

Expand Down Expand Up @@ -219,18 +219,21 @@ def generate_vllm_engine_arg_dict(self,
vllm_engine_args.update(passthrough_vllm_engine_args)
return vllm_engine_args

def get_engine_args(self) -> EngineArgs:
def get_engine_args(self,
async_engine=False
) -> Union[EngineArgs, AsyncEngineArgs]:
additional_vllm_engine_args = self.get_additional_vllm_engine_args()
self.handle_lmi_vllm_config_conflicts(additional_vllm_engine_args)
vllm_engine_arg_dict = self.generate_vllm_engine_arg_dict(
additional_vllm_engine_args)
logging.debug(
f"Construction vLLM engine args from the following DJL configs: {vllm_engine_arg_dict}"
)
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
arg_cls = AsyncEngineArgs if async_engine else EngineArgs
parser = arg_cls.add_cli_args(FlexibleArgumentParser())
args_list = construct_vllm_args_list(vllm_engine_arg_dict, parser)
args = parser.parse_args(args=args_list)
engine_args = EngineArgs.from_cli_args(args)
engine_args = arg_cls.from_cli_args(args)
# we have to do this separately because vllm converts it into a string
engine_args.long_lora_scaling_factors = self.long_lora_scaling_factors
# These neuron configs are not implemented in the vllm arg parser
Expand Down
143 changes: 143 additions & 0 deletions engines/python/setup/djl_python/python_async_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import asyncio
import logging
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from threading import Thread
from queue import Queue
from asyncio.queues import Queue as AsyncQueue

from djl_python.inputs import Input
from djl_python.outputs import Output
from djl_python.python_sync_engine import PythonSyncEngine

REQUEST_TRACKING_ID_KEY = "request_tracking_id"


class PythonAsyncEngine(PythonSyncEngine):
"""
Backend engine to run python code in decoupled/async mode.
Requests are forwarded from the model server and submitted to the handler.
The handler returns responses as they become available, and sends them to the frontend.
Requests are tracked/coordinated via the request_tracking_id property.
This is an internal property set/managed by the model server.
"""

def __init__(self, args, service):
    """Set up async state on top of the sync engine's socket/env setup.

    :param args: parsed CLI arguments (socket type/name, parallelism, etc.)
    :param service: loaded model service whose handlers are invoked
    """
    super().__init__(args, service)
    # asyncio queue: handler coroutines put results on the event loop;
    # the send_responses thread drains it.
    self.output_queue = AsyncQueue()
    # thread-safe queue used by worker threads to surface fatal errors
    # back to run_server after the loop exits.
    self.exception_queue = Queue()
    # set to the running event loop inside run_server's main() coroutine
    self.loop = None
    # TODO(review): original trailing comment was cut off ("for async mode
    # we should maybe consider ...") — confirm the intended follow-up.

def receive_requests(self):
    """Blocking loop: read requests off the socket and schedule handlers.

    Runs on a dedicated thread. Each request read by _prepare_inputs is
    handed to invoke_handler as a coroutine on the engine's event loop.
    """
    logging.info("starting receive requests thread")
    while True:
        request, handler_name = self._prepare_inputs()
        tracking_id = request.get_property(REQUEST_TRACKING_ID_KEY)
        logging.debug(
            f"received new request with tracking_id {tracking_id}, submitting to handler"
        )
        coro = self.invoke_handler(handler_name, request)
        asyncio.run_coroutine_threadsafe(coro, self.loop)

async def invoke_handler(self, function_name: str, inputs: Input):
    """Invoke the service handler for one request and enqueue the result.

    Every code path produces an Output carrying the request tracking id,
    so the frontend always receives a response for the request.

    :param function_name: handler function to invoke on the service module
    :param inputs: the request payload read off the socket
    """
    request_tracking_id = inputs.get_property(REQUEST_TRACKING_ID_KEY)
    try:
        outputs = await self.service.invoke_handler_async(
            function_name, inputs, self.cl_socket)
    except Exception as e:
        # Log exactly once per failure (the original logged the same
        # traceback twice), classifying OOM-style errors as HTTP 507.
        if (type(e).__name__ in ("OutOfMemoryError", "MemoryError")
                or "No available memory for the cache blocks" in str(e)
                or "CUDA error: out of memory" in str(e)):
            logging.exception(
                f"Memory Error encountered when invoking module {self.service.module}, function {function_name}"
            )
            outputs = Output(code=507, message=str(e))
        else:
            logging.exception(
                f"service.invoke_handler_async() failure. There was an error invoking module {self.service.module}, function {function_name}"
            )
            outputs = Output().error(
                str(e), message="service.invoke_handler_async() failure")
        outputs.add_property(REQUEST_TRACKING_ID_KEY, request_tracking_id)
    if outputs is None:
        # Handler returned nothing: respond 204 so the request completes.
        outputs = Output(code=204, message="No content")
        logging.debug(
            "empty response received from service.invoke_handler_async()")
        outputs.add_property(REQUEST_TRACKING_ID_KEY, request_tracking_id)
    elif not isinstance(outputs, Output):
        message = (
            f"Invalid type returned from {self.service.module}.{function_name}. "
            f"Received type {type(outputs)}, does not match expected type djl_python.outputs.Output"
        )
        logging.error(message)
        outputs = Output().error(message)
        outputs.add_property(REQUEST_TRACKING_ID_KEY, request_tracking_id)
    logging.info("putting result of inference to output queue")
    await self.output_queue.put(outputs)

def send_responses(self):
    """Blocking loop: drain completed outputs and write them to the socket.

    Runs on a dedicated thread; output_queue.get() executes on the event
    loop while this thread blocks on the returned future.
    """
    logging.info("starting send responses thread")
    while True:
        pending = asyncio.run_coroutine_threadsafe(self.output_queue.get(),
                                                   self.loop)
        logging.debug("waiting for new inference response")
        pending.result().send(self.cl_socket)

def run_server(self):
    """Start the async engine and block until a worker thread dies.

    Creates the client socket, spawns the receive/send worker threads,
    and keeps the event loop alive by waiting (in an executor) until one
    of the threads exits. Any thread failure recorded in exception_queue
    is logged after shutdown.
    """

    async def main():
        self.loop = asyncio.get_running_loop()
        self._create_cl_socket()

        def catch_all(func):
            # Record any thread failure so run_server can report it.
            try:
                func()
            except Exception as e:
                logging.error(f"{func} failed. Details {e}")
                self.exception_queue.put(str(traceback.format_exc()))

        threads = [
            Thread(target=partial(catch_all, self.receive_requests)),
            Thread(target=partial(catch_all, self.send_responses)),
        ]

        for thread in threads:
            thread.start()

        def check_threads():
            # Poll until either worker thread terminates.
            while all(t.is_alive() for t in threads):
                time.sleep(1)

        with ThreadPoolExecutor(1) as executor:
            await self.loop.run_in_executor(executor, check_threads)

    # asyncio.run() is correct here: no event loop exists yet, and
    # get_event_loop().run_until_complete() is deprecated for this use
    # since Python 3.10 (per review feedback on this code).
    asyncio.run(main())

    if not self.exception_queue.empty():
        logging.error(
            f"djl async engine terminated with error {self.exception_queue.get()}"
        )
    logging.info("djl async engine terminated")
183 changes: 183 additions & 0 deletions engines/python/setup/djl_python/python_sync_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import os
import socket
import signal
import logging
from typing import Tuple

from djl_python.service_loader import get_annotated_function, load_model_service, has_function_in_module
from djl_python.inputs import Input
from djl_python.outputs import Output

SOCKET_ACCEPT_TIMEOUT = 30.0


class PythonSyncEngine(object):
"""
Backend engine to run python code
"""

def __init__(self, args, service):
    """Configure the engine from CLI args and prepare the server socket.

    Maps OpenMPI env vars to the generic WORLD_SIZE/LOCAL_RANK/RANK names,
    records parallelism/entry-point settings, and creates (but does not
    bind) the unix or tcp server socket.

    :param args: parsed CLI arguments (see arg_parser)
    :param service: loaded model service whose handlers are invoked
    :raises ValueError: on missing sock-name/port or unknown socket type
    """
    # Support MPI environment args
    if os.getenv('OMPI_COMM_WORLD_SIZE'):
        os.environ["WORLD_SIZE"] = os.getenv('OMPI_COMM_WORLD_SIZE')
    if os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'):
        os.environ["LOCAL_RANK"] = os.getenv('OMPI_COMM_WORLD_LOCAL_RANK')
    rank = os.environ.get("OMPI_COMM_WORLD_RANK")
    if rank:
        os.environ["RANK"] = rank

    self.model_dir = args.model_dir
    self.sock_type = args.sock_type
    self.sock_name = args.sock_name
    self.port = args.port
    self.service = service
    self.device_id = args.device_id
    self.tensor_parallel_degree = args.tensor_parallel_degree
    self.pipeline_parallel_degree = args.pipeline_parallel_degree
    self.cluster_size = args.cluster_size
    self.entry_point = args.entry_point
    self.recommended_entry_point = args.recommended_entry_point
    # optional user-annotated formatters discovered in the model dir
    self.output_formatter = get_annotated_function(args.model_dir,
                                                   "is_output_formatter")
    self.input_formatter = get_annotated_function(args.model_dir,
                                                  "is_input_formatter")
    # entry point is validated lazily on the first request (_prepare_inputs)
    self.is_entry_point_verified = False

    if self.sock_type == "unix":
        if self.sock_name is None:
            raise ValueError("Missing sock-name argument.")
        # per-rank socket name so MPI workers don't collide
        self.sock_name = f"{args.sock_name}.{rank}" if rank else args.sock_name

        self.clean_up()
    elif self.sock_type == "tcp":
        if self.sock_name is None:
            self.sock_name = "0.0.0.0"
        if self.port is None:
            raise ValueError("Missing port argument.")
        # per-rank port offset so MPI workers don't collide
        self.port = int(self.port) + int(rank) if rank else self.port
    else:
        raise ValueError(f"Invalid socket-type: {self.sock_type}.")

    socket_family = socket.AF_INET if self.sock_type == "tcp" else socket.AF_UNIX
    self.sock = socket.socket(socket_family, socket.SOCK_STREAM)
    self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    self.sock.settimeout(SOCKET_ACCEPT_TIMEOUT)
    # populated by _create_cl_socket once the frontend connects
    self.cl_socket = None

def clean_up(self):
    """Kill any dangling process recorded in the pid file, then claim it.

    Writes the current pid to "<sock_name>.pid" and removes a stale unix
    socket file so a later bind() does not fail with "address in use".
    A corrupt or stale pid file is tolerated (original code raised an
    uncaught ValueError when the file did not contain an integer).
    """
    pid_file = f"{self.sock_name}.pid"
    if os.path.exists(pid_file):
        with open(pid_file, "r") as f:
            pid = f.readline().strip()
        if pid:
            try:
                os.kill(int(pid), signal.SIGKILL)
                logging.warning(
                    f"{self.sock_name} - kill dangling process: {pid}")
            except (ValueError, ProcessLookupError):
                # non-numeric pid file content, or process already gone
                pass

    with open(pid_file, "w") as f:
        f.write(str(os.getpid()))

    if os.path.exists(self.sock_name):
        os.remove(self.sock_name)

def _prepare_inputs(self) -> Tuple[Input, str]:
    """Read one request off the socket and enrich it with engine config.

    Injects parallelism degrees, device id, and any custom input/output
    formatters into the request properties. On the first request, falls
    back to the LMI recommended entry point if the configured one has no
    handler with the requested function name.

    :return: tuple of (request inputs, handler function name)
    """
    inputs = Input()
    # blocks until the frontend sends a complete request
    inputs.read(self.cl_socket)
    prop = inputs.get_properties()
    if self.tensor_parallel_degree:
        prop["tensor_parallel_degree"] = self.tensor_parallel_degree
    if self.pipeline_parallel_degree:
        prop["pipeline_parallel_degree"] = self.pipeline_parallel_degree
    if self.cluster_size:
        prop["cluster_size"] = self.cluster_size
    prop["device_id"] = self.device_id

    if "output_formatter" in prop:
        if hasattr(self.service, prop["output_formatter"]):
            # TODO: custom output_formatter in serving.properties is
            # deprecated. Remove once users are migrated.
            prop["output_formatter"] = getattr(self.service,
                                               prop["output_formatter"])
    elif self.output_formatter:
        # fall back to the annotated formatter found in the model dir
        prop["output_formatter"] = self.output_formatter

    if self.input_formatter:
        prop["input_formatter"] = self.input_formatter
    function_name = inputs.get_function_name()
    if not self.is_entry_point_verified:
        if self.recommended_entry_point:
            if not has_function_in_module(self.service.module,
                                          function_name):
                # swap in the recommended entry point when the configured
                # module lacks the requested handler
                self.service = load_model_service(
                    self.model_dir, self.recommended_entry_point,
                    self.device_id)
                logging.info(
                    f"{self.entry_point} file has no handler function {function_name}."
                    f"Hence choosing the LMI recommended entry point {self.recommended_entry_point}"
                )
        self.is_entry_point_verified = True
    return inputs, function_name

def _create_cl_socket(self):
    """Bind, listen, and accept the single frontend connection.

    Blocks until the model server connects (subject to the socket accept
    timeout set in __init__), then stores the connection on
    self.cl_socket.
    """
    if self.sock_type == "unix":
        self.sock.bind(self.sock_name)
    else:
        logging.info(
            f"Socket bind on address: {self.sock_name}:{self.port}")
        self.sock.bind((self.sock_name, int(self.port)))

    self.sock.listen(128)
    logging.info("Python engine started.")

    conn, _ = self.sock.accept()
    # workaround error(35, 'Resource temporarily unavailable') on OSX
    conn.setblocking(True)
    self.cl_socket = conn

def run_server(self):
    """
    Run the backend worker process and listen on a socket
    :return:
    """
    self._create_cl_socket()

    def looks_like_oom(err):
        # Matches OOM failures by exception type name or message text.
        name = type(err).__name__
        text = str(err)
        return (name == "OutOfMemoryError" or name == "MemoryError"
                or "No available memory for the cache blocks" in text
                or "CUDA error: out of memory" in text)

    while True:
        inputs, function_name = self._prepare_inputs()
        try:
            result = self.service.invoke_handler(function_name, inputs)
            if result is None:
                result = Output(code=204, message="No content")
            elif not isinstance(result, Output):
                result = Output().error(
                    f"Invalid output type: {type(result)}")
        except Exception as e:
            logging.exception("Failed invoke service.invoke_handler()")
            if looks_like_oom(e):
                result = Output(code=507, message=str(e))
            else:
                result = Output().error(str(e))

        result.send(self.cl_socket)
        logging.debug("Outputs is sent to DJL engine.")
        try:
            # best-effort cleanup hook; a failure here must not kill the loop
            result.execute_finalize()
        except Exception as e:
            logging.exception(f"Failed on finalize function: {e}")
Loading
Loading