implementation of python async engine, which decouples receiving requ… #2758
Open
siddvenk wants to merge 1 commit into deepjavalibrary:master from siddvenk:py-continuous-batching
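For context on the decoupling described in the title, here is a minimal, self-contained sketch of the pattern (all names below are illustrative, not code from this PR): plain threads feed an asyncio event loop with asyncio.run_coroutine_threadsafe, and a separate thread drains an asyncio.Queue, so slow handlers never block request intake and responses can complete out of order.

import asyncio
import threading
import time

loop = None          # event loop shared with the worker threads
output_queue = None  # asyncio.Queue, created inside the running loop


async def handle(request):
    # Simulate variable handler latency; responses may finish out of order.
    await asyncio.sleep(0.01 * (hash(request) % 5))
    await output_queue.put(f"response for {request}")


def receive_requests():
    # Runs on a plain thread: submit each request to the loop and move on.
    for i in range(5):
        asyncio.run_coroutine_threadsafe(handle(f"req-{i}"), loop)
        time.sleep(0.005)


def send_responses(n):
    # Runs on a plain thread: block on the async queue via the loop.
    for _ in range(n):
        future = asyncio.run_coroutine_threadsafe(output_queue.get(), loop)
        print(future.result())


async def main():
    global loop, output_queue
    loop = asyncio.get_running_loop()
    output_queue = asyncio.Queue()
    rx = threading.Thread(target=receive_requests)
    tx = threading.Thread(target=send_responses, args=(5, ))
    rx.start()
    tx.start()
    # Join the worker threads without blocking the event loop.
    await loop.run_in_executor(None, rx.join)
    await loop.run_in_executor(None, tx.join)


asyncio.run(main())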
@@ -0,0 +1,143 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import asyncio
import logging
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from threading import Thread
from queue import Queue
from asyncio.queues import Queue as AsyncQueue

from djl_python.inputs import Input
from djl_python.outputs import Output
from djl_python.python_sync_engine import PythonSyncEngine

REQUEST_TRACKING_ID_KEY = "request_tracking_id"


class PythonAsyncEngine(PythonSyncEngine):
    """
    Backend engine to run python code in decoupled/async mode.
    Requests are forwarded from the model server and submitted to the handler.
    The handler returns responses as they become available, and sends them to the frontend.
    Requests are tracked/coordinated via the request_tracking_id property.
    This is an internal property set/managed by the model server.
    """

    def __init__(self, args, service):
        super().__init__(args, service)
        self.output_queue = AsyncQueue()
        self.exception_queue = Queue()
        self.loop = None
        # Todo: for async mode we should maybe consider …

    def receive_requests(self):
        logging.info("starting receive requests thread")
        while True:
            inputs, function_name = self._prepare_inputs()
            logging.debug(
                f"received new request with tracking_id {inputs.get_property(REQUEST_TRACKING_ID_KEY)}, submitting to handler"
            )
            asyncio.run_coroutine_threadsafe(
                self.invoke_handler(function_name, inputs), self.loop)

    async def invoke_handler(self, function_name: str, inputs: Input):
        request_tracking_id = inputs.get_property(REQUEST_TRACKING_ID_KEY)
        try:
            outputs = await self.service.invoke_handler_async(
                function_name, inputs, self.cl_socket)
        except Exception as e:
            logging.exception("Failed to invoke service.invoke_handler_async()")
            if (type(e).__name__ == "OutOfMemoryError"
                    or type(e).__name__ == "MemoryError"
                    or "No available memory for the cache blocks" in str(e)
                    or "CUDA error: out of memory" in str(e)):
                logging.exception(
                    f"Memory Error encountered when invoking module {self.service.module}, function {function_name}"
                )
                outputs = Output(code=507, message=str(e))
                outputs.add_property(REQUEST_TRACKING_ID_KEY,
                                     request_tracking_id)
            else:
                logging.exception(
                    f"service.invoke_handler_async() failure. There was an error invoking module {self.service.module}, function {function_name}"
                )
                outputs = Output().error(
                    str(e), message="service.invoke_handler_async() failure")
                outputs.add_property(REQUEST_TRACKING_ID_KEY,
                                     request_tracking_id)
        if outputs is None:
            outputs = Output(code=204, message="No content")
            logging.debug(
                "empty response received from service.invoke_handler_async()")
            outputs.add_property(REQUEST_TRACKING_ID_KEY, request_tracking_id)
        elif not isinstance(outputs, Output):
            message = (
                f"Invalid type returned from {self.service.module}.{function_name}. "
                f"Received type {type(outputs)}, does not match expected type djl_python.outputs.Output"
            )
            logging.error(message)
            outputs = Output().error(message)
            outputs.add_property(REQUEST_TRACKING_ID_KEY, request_tracking_id)
        logging.info("putting result of inference to output queue")
        await self.output_queue.put(outputs)

    def send_responses(self):
        logging.info("starting send responses thread")
        while True:
            future = asyncio.run_coroutine_threadsafe(self.output_queue.get(),
                                                      self.loop)
            logging.debug("waiting for new inference response")
            output = future.result()
            output.send(self.cl_socket)

    def run_server(self):

        async def main():
            self.loop = asyncio.get_running_loop()
            self._create_cl_socket()

            def catch_all(func):
                try:
                    func()
                except Exception as e:
                    logging.error(f"{func} failed. Details {e}")
                    self.exception_queue.put(str(traceback.format_exc()))

            threads = [
                Thread(target=partial(catch_all, self.receive_requests)),
                Thread(target=partial(catch_all, self.send_responses)),
            ]

            for thread in threads:
                thread.start()

            def check_threads():
                while True:
                    if not all(t.is_alive() for t in threads):
                        return
                    time.sleep(1)

            with ThreadPoolExecutor(1) as executor:
                await asyncio.get_event_loop().run_in_executor(
                    executor, check_threads)

        asyncio.get_event_loop().run_until_complete(main())
        if not self.exception_queue.empty():
            logging.error(
                f"djl async engine terminated with error {self.exception_queue.get()}"
            )
        logging.info("djl async engine terminated")

Review thread on the asyncio.get_event_loop().run_until_complete(main()) line:
Reviewer: I think we can just do …
Author: good call, i'll update
@@ -0,0 +1,183 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import os
import socket
import signal
import logging
from typing import Tuple

from djl_python.service_loader import get_annotated_function, load_model_service, has_function_in_module
from djl_python.inputs import Input
from djl_python.outputs import Output

SOCKET_ACCEPT_TIMEOUT = 30.0


class PythonSyncEngine(object):
    """
    Backend engine to run python code
    """

    def __init__(self, args, service):
        # Support MPI environment args
        if os.getenv('OMPI_COMM_WORLD_SIZE'):
            os.environ["WORLD_SIZE"] = os.getenv('OMPI_COMM_WORLD_SIZE')
        if os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'):
            os.environ["LOCAL_RANK"] = os.getenv('OMPI_COMM_WORLD_LOCAL_RANK')
        rank = os.environ.get("OMPI_COMM_WORLD_RANK")
        if rank:
            os.environ["RANK"] = rank

        self.model_dir = args.model_dir
        self.sock_type = args.sock_type
        self.sock_name = args.sock_name
        self.port = args.port
        self.service = service
        self.device_id = args.device_id
        self.tensor_parallel_degree = args.tensor_parallel_degree
        self.pipeline_parallel_degree = args.pipeline_parallel_degree
        self.cluster_size = args.cluster_size
        self.entry_point = args.entry_point
        self.recommended_entry_point = args.recommended_entry_point
        self.output_formatter = get_annotated_function(args.model_dir,
                                                       "is_output_formatter")
        self.input_formatter = get_annotated_function(args.model_dir,
                                                      "is_input_formatter")
        self.is_entry_point_verified = False

        if self.sock_type == "unix":
            if self.sock_name is None:
                raise ValueError("Missing sock-name argument.")
            self.sock_name = f"{args.sock_name}.{rank}" if rank else args.sock_name

            self.clean_up()
        elif self.sock_type == "tcp":
            if self.sock_name is None:
                self.sock_name = "0.0.0.0"
            if self.port is None:
                raise ValueError("Missing port argument.")
            self.port = int(self.port) + int(rank) if rank else self.port
        else:
            raise ValueError(f"Invalid socket-type: {self.sock_type}.")

        socket_family = socket.AF_INET if self.sock_type == "tcp" else socket.AF_UNIX
        self.sock = socket.socket(socket_family, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.settimeout(SOCKET_ACCEPT_TIMEOUT)
        self.cl_socket = None

    def clean_up(self):
        pid_file = f"{self.sock_name}.pid"
        if os.path.exists(pid_file):
            with open(pid_file, "r") as f:
                pid = f.readline()
                if pid:
                    try:
                        os.kill(int(pid), signal.SIGKILL)
                        logging.warning(
                            f"{self.sock_name} - kill dangling process: {pid}")
                    except ProcessLookupError:
                        pass

        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        if os.path.exists(self.sock_name):
            os.remove(self.sock_name)

    def _prepare_inputs(self) -> Tuple[Input, str]:
        inputs = Input()
        inputs.read(self.cl_socket)
        prop = inputs.get_properties()
        if self.tensor_parallel_degree:
            prop["tensor_parallel_degree"] = self.tensor_parallel_degree
        if self.pipeline_parallel_degree:
            prop["pipeline_parallel_degree"] = self.pipeline_parallel_degree
        if self.cluster_size:
            prop["cluster_size"] = self.cluster_size
        prop["device_id"] = self.device_id

if "output_formatter" in prop: | ||
if hasattr(self.service, prop["output_formatter"]): | ||
# TODO: custom output_formatter in serving.properties is deprecated. Remove users are migrated. | ||
prop["output_formatter"] = getattr(self.service, | ||
prop["output_formatter"]) | ||
elif self.output_formatter: | ||
prop["output_formatter"] = self.output_formatter | ||
|
||
if self.input_formatter: | ||
prop["input_formatter"] = self.input_formatter | ||
function_name = inputs.get_function_name() | ||
if not self.is_entry_point_verified: | ||
if self.recommended_entry_point: | ||
if not has_function_in_module(self.service.module, | ||
function_name): | ||
self.service = load_model_service( | ||
self.model_dir, self.recommended_entry_point, | ||
self.device_id) | ||
logging.info( | ||
f"{self.entry_point} file has no handler function {function_name}." | ||
f"Hence choosing the LMI recommended entry point {self.recommended_entry_point}" | ||
) | ||
self.is_entry_point_verified = True | ||
return inputs, function_name | ||
|
||
    def _create_cl_socket(self):
        if self.sock_type == "unix":
            self.sock.bind(self.sock_name)
        else:
            logging.info(
                f"Socket bind on address: {self.sock_name}:{self.port}")
            self.sock.bind((self.sock_name, int(self.port)))

        self.sock.listen(128)
        logging.info("Python engine started.")

        (cl_socket, _) = self.sock.accept()
        # workaround error(35, 'Resource temporarily unavailable') on OSX
        cl_socket.setblocking(True)
        self.cl_socket = cl_socket

    def run_server(self):
        """
        Run the backend worker process and listen on a socket
        :return:
        """
        self._create_cl_socket()

        while True:
            inputs, function_name = self._prepare_inputs()
            try:
                outputs = self.service.invoke_handler(function_name, inputs)
                if outputs is None:
                    outputs = Output(code=204, message="No content")
                elif not isinstance(outputs, Output):
                    outputs = Output().error(
                        f"Invalid output type: {type(outputs)}")
            except Exception as e:
                logging.exception("Failed to invoke service.invoke_handler()")
                if (type(e).__name__ == "OutOfMemoryError"
                        or type(e).__name__ == "MemoryError"
                        or "No available memory for the cache blocks" in str(e)
                        or "CUDA error: out of memory" in str(e)):
                    outputs = Output(code=507, message=str(e))
                else:
                    outputs = Output().error(str(e))

            outputs.send(self.cl_socket)
            logging.debug("Outputs sent to DJL engine.")
            try:
                outputs.execute_finalize()
            except Exception as e:
                logging.exception(f"Failed on finalize function: {e}")
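For reference, the function_name the sync engine dispatches to resolves to a module-level handler, per DJL Serving's Python convention (typically a handle function in the entry-point module). A minimal sketch, assuming that convention; the payload handling is illustrative:

from typing import Optional

from djl_python.inputs import Input
from djl_python.outputs import Output


def handle(inputs: Input) -> Optional[Output]:
    # Sketch of a DJL Serving Python entry point (illustrative).
    if inputs.is_empty():
        # Model server warm-up/ping request: nothing to respond with.
        return None
    # A real handler would run inference on the payload here.
    return Output(code=200, message="OK")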
Review comments:
Reviewer: I'm wondering if we should add a graceful shutdown mechanism to the receive_requests and send_responses threads; right now it is unclear how these threads would behave if the main proc is abruptly shut down.
Author: yes, definitely. We need graceful shutdown in a few places: …
I think I can tackle 1 in this PR, but 2 may need some more thought.
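One possible shape for the graceful shutdown discussed above, as a self-contained sketch rather than a design for this PR: a threading.Event lets the receive loop exit, and a sentinel pushed through the output queue unblocks the send loop so neither thread is left blocking forever.

import asyncio
import threading

STOP = object()  # sentinel pushed through the output queue on shutdown
stop_event = threading.Event()


def receive_requests(loop, output_queue):
    # The real loop blocks on a socket read; a socket timeout would be
    # needed there so the stop flag is actually observed (assumption).
    while not stop_event.wait(0.1):
        pass  # ... read a request and run_coroutine_threadsafe(...) ...


def send_responses(loop, output_queue):
    while True:
        future = asyncio.run_coroutine_threadsafe(output_queue.get(), loop)
        item = future.result()
        if item is STOP:  # drain finished: exit instead of blocking forever
            return
        # ... item.send(cl_socket) ...


def shutdown(loop, output_queue):
    stop_event.set()
    asyncio.run_coroutine_threadsafe(output_queue.put(STOP), loop)


async def main():
    loop = asyncio.get_running_loop()
    queue = asyncio.Queue()
    rx = threading.Thread(target=receive_requests, args=(loop, queue))
    tx = threading.Thread(target=send_responses, args=(loop, queue))
    rx.start()
    tx.start()
    await asyncio.sleep(0.3)  # pretend the engine served traffic for a while
    shutdown(loop, queue)
    await loop.run_in_executor(None, rx.join)
    await loop.run_in_executor(None, tx.join)


asyncio.run(main())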