Add vLLM Engine #83

Status: Open · wants to merge 21 commits into base: main
4 changes: 4 additions & 0 deletions README.md
@@ -103,6 +103,10 @@ cat docker/example.env
```bash
cd docker && docker-compose --env-file example.env -f docker-compose.yml up
```
If running on a GPU instance, use the GPU docker-compose file
```bash
cd docker && docker-compose --env-file example.env -f docker-compose-gpu.yml up
```
If you need Triton support (Keras/PyTorch/ONNX, etc.), use the Triton docker-compose file
```bash
cd docker && docker-compose --env-file example.env -f docker-compose-triton.yml up
```
2 changes: 1 addition & 1 deletion clearml_serving/engines/triton/triton_helper.py
@@ -561,7 +561,7 @@ def main():
setattr(args, args_var, type(t)(v) if t is not None else v)

# noinspection PyProtectedMember
serving_task = ModelRequestProcessor._get_control_plane_task(task_id=args.inference_task_id)
serving_task = ModelRequestProcessor._get_control_plane_task(task_id=args.serving_id)

task = Task.init(
project_name=args.project or serving_task.get_project_name() or "serving",
4 changes: 3 additions & 1 deletion clearml_serving/serving/entrypoint.sh
@@ -18,7 +18,8 @@ UVICORN_SERVE_LOOP="${UVICORN_SERVE_LOOP:-uvloop}"
UVICORN_LOG_LEVEL="${UVICORN_LOG_LEVEL:-warning}"

# set default internal serve endpoint (for request pipelining)
CLEARML_DEFAULT_BASE_SERVE_URL="${CLEARML_DEFAULT_BASE_SERVE_URL:-http://127.0.0.1:$SERVING_PORT/serve}"
CLEARML_DEFAULT_SERVE_SUFFIX="${CLEARML_DEFAULT_SERVE_SUFFIX:-serve}"
CLEARML_DEFAULT_BASE_SERVE_URL="${CLEARML_DEFAULT_BASE_SERVE_URL:-http://127.0.0.1:$SERVING_PORT/$CLEARML_DEFAULT_SERVE_SUFFIX}"
CLEARML_DEFAULT_TRITON_GRPC_ADDR="${CLEARML_DEFAULT_TRITON_GRPC_ADDR:-127.0.0.1:8001}"

# print configuration
@@ -31,6 +32,7 @@ echo GUNICORN_EXTRA_ARGS="$GUNICORN_EXTRA_ARGS"
echo UVICORN_SERVE_LOOP="$UVICORN_SERVE_LOOP"
echo UVICORN_EXTRA_ARGS="$UVICORN_EXTRA_ARGS"
echo UVICORN_LOG_LEVEL="$UVICORN_LOG_LEVEL"
echo CLEARML_DEFAULT_SERVE_SUFFIX="$CLEARML_DEFAULT_SERVE_SUFFIX"
echo CLEARML_DEFAULT_BASE_SERVE_URL="$CLEARML_DEFAULT_BASE_SERVE_URL"
echo CLEARML_DEFAULT_TRITON_GRPC_ADDR="$CLEARML_DEFAULT_TRITON_GRPC_ADDR"

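The new `CLEARML_DEFAULT_SERVE_SUFFIX` variable drives both the internal pipeline URL above and the FastAPI router prefix in `main.py`. A minimal sketch of how the default resolves, assuming no overrides and an assumed serving port of 8080:

```python
import os

# Mirrors the entrypoint defaults: unless overridden, the suffix is "serve"
# and the internal serve URL becomes http://127.0.0.1:<SERVING_PORT>/serve.
# The port value 8080 here is an assumption for illustration only.
suffix = os.environ.get("CLEARML_DEFAULT_SERVE_SUFFIX", "serve")
port = os.environ.get("SERVING_PORT", "8080")
print(f"http://127.0.0.1:{port}/{suffix}")
```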
82 changes: 67 additions & 15 deletions clearml_serving/serving/main.py
@@ -4,11 +4,15 @@
import gzip
import asyncio

from fastapi import FastAPI, Request, Response, APIRouter, HTTPException
from fastapi import FastAPI, Request, Response, APIRouter, HTTPException, Depends
from fastapi.routing import APIRoute
from fastapi.responses import PlainTextResponse
from grpc.aio import AioRpcError

from http import HTTPStatus

from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest

from starlette.background import BackgroundTask

from typing import Optional, Dict, Any, Callable, Union
@@ -110,21 +114,19 @@ async def cuda_exception_handler(request, exc):
return PlainTextResponse("CUDA out of memory. Restarting service", status_code=500, background=task)


router = APIRouter(
prefix="/serve",
tags=["models"],
responses={404: {"description": "Model Serving Endpoint Not found"}},
route_class=GzipRoute, # mark-out to remove support for GZip content encoding
)


# cover all routing options for model version `/{model_id}`, `/{model_id}/123`, `/{model_id}?version=123`
@router.post("/{model_id}/{version}")
@router.post("/{model_id}/")
@router.post("/{model_id}")
async def serve_model(model_id: str, version: Optional[str] = None, request: Union[bytes, Dict[Any, Any]] = None):
async def process_with_exceptions(
base_url: str,
version: Optional[str],
request: Union[bytes, Dict[Any, Any]],
serve_type: str
):
try:
return_value = await processor.process_request(base_url=model_id, version=version, request_body=request)
return_value = await processor.process_request(
base_url=base_url,
version=version,
request_body=request,
serve_type=serve_type
)
except EndpointNotFoundException as ex:
raise HTTPException(status_code=404, detail="Error processing request, endpoint was not found: {}".format(ex))
except (EndpointModelLoadException, EndpointBackendEngineException) as ex:
@@ -171,4 +173,54 @@ async def serve_model(model_id: str, version: Optional[str] = None, request: Uni
return return_value


router = APIRouter(
prefix=f"/{os.environ.get('CLEARML_DEFAULT_SERVE_SUFFIX', 'serve')}",
tags=["models"],
responses={404: {"description": "Model Serving Endpoint Not found"}},
route_class=GzipRoute, # mark-out to remove support for GZip content encoding
)


@router.post("/{model_id}/{version}")
@router.post("/{model_id}/")
@router.post("/{model_id}")
async def base_serve_model(
model_id: str,
version: Optional[str] = None,
request: Union[bytes, Dict[Any, Any]] = None
):
return_value = await process_with_exceptions(
base_url=model_id,
version=version,
request_body=request,
serve_type="process"
)
return return_value
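For context, a call against the default route registered above might look like the sketch below; the endpoint name `my_model`, the port, and the payload are illustrative assumptions, not part of this PR.

```python
import requests

# Hypothetical call to the default serving route (/serve/<model_id>).
# "my_model", port 8080, and the payload are assumptions for illustration.
response = requests.post(
    "http://127.0.0.1:8080/serve/my_model",
    json={"x": [[1.0, 2.0, 3.0]]},
)
print(response.status_code, response.json())
```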


async def validate_json_request(raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower()
media_type = content_type.split(";", maxsplit=1)[0]
if media_type != "application/json":
raise HTTPException(
status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
detail="Unsupported Media Type: Only 'application/json' is allowed"
)
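From the client side, the effect of this dependency can be sketched as follows; the URL, port, and payload are assumptions:

```python
import requests

# A request without an application/json content type should be rejected
# with HTTP 415 by validate_json_request (URL and port are illustrative).
r = requests.post(
    "http://127.0.0.1:8080/serve/openai/v1/chat/completions",
    data="plain text body",
    headers={"Content-Type": "text/plain"},
)
print(r.status_code)  # expected: 415 Unsupported Media Type
```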

@router.post("/openai/v1/{endpoint_type:path}", dependencies=[Depends(validate_json_request)])
@router.get("/openai/v1/{endpoint_type:path}", dependencies=[Depends(validate_json_request)])
async def openai_serve_model(
endpoint_type: str,
request: Union[CompletionRequest, ChatCompletionRequest],
raw_request: Request
):
combined_request = {"request": request, "raw_request": raw_request}
return_value = await process_with_exceptions(
base_url=request.model,
version=None,
request_body=combined_request,
serve_type=endpoint_type
)
return return_value

app.include_router(router)
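Because the new route accepts `CompletionRequest`/`ChatCompletionRequest` bodies and dispatches on `request.model`, it is shaped to work with OpenAI-style clients. A hedged usage sketch, where the endpoint name `my_vllm_endpoint`, the port, and the dummy API key are assumptions:

```python
from openai import OpenAI

# Illustrative client usage against the OpenAI-compatible route added above.
# "my_vllm_endpoint", port 8080, and the placeholder API key are assumptions.
client = OpenAI(
    base_url="http://127.0.0.1:8080/serve/openai/v1",
    api_key="not-used",
)
completion = client.chat.completions.create(
    model="my_vllm_endpoint",
    messages=[{"role": "user", "content": "Hello, world!"}],
)
print(completion.choices[0].message.content)
```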
23 changes: 17 additions & 6 deletions clearml_serving/serving/model_request_processor.py
@@ -159,7 +159,7 @@ def __init__(
self._serving_base_url = None
self._metric_log_freq = None

async def process_request(self, base_url: str, version: str, request_body: dict) -> dict:
async def process_request(self, base_url: str, version: str, request_body: dict, serve_type: str) -> dict:
"""
Process request coming in,
Raise Value error if url does not match existing endpoints
@@ -171,7 +171,12 @@ async def process_request(self, base_url: str, version: str, request_body: dict)
while self._update_lock_flag:
await asyncio.sleep(0.5+random())
# retry to process
return await self.process_request(base_url=base_url, version=version, request_body=request_body)
return await self.process_request(
base_url=base_url,
version=version,
request_body=request_body,
serve_type=serve_type
)

try:
# normalize url and version
@@ -192,7 +197,12 @@ async def process_request(self, base_url: str, version: str, request_body: dict)
processor = processor_cls(model_endpoint=ep, task=self._task)
self._engine_processor_lookup[url] = processor

return_value = await self._process_request(processor=processor, url=url, body=request_body)
return_value = await self._process_request(
processor=processor,
url=url,
body=request_body,
serve_type=serve_type
)
finally:
self._request_processing_state.dec()

@@ -1193,7 +1203,7 @@ def _deserialize_conf_dict(self, configuration: dict) -> None:
# update preprocessing classes
BasePreprocessRequest.set_server_config(self._configuration)

async def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict) -> dict:
async def _process_request(self, processor: BasePreprocessRequest, url: str, body: dict, serve_type: str) -> dict:
# collect statistics for this request
stats_collect_fn = None
collect_stats = False
@@ -1215,10 +1225,11 @@ async def _process_request(self, processor: BasePreprocessRequest, url: str, bod
preprocessed = await processor.preprocess(body, state, stats_collect_fn) \
if processor.is_preprocess_async \
else processor.preprocess(body, state, stats_collect_fn)
processed_func = getattr(processor, serve_type.replace("/", "_"))
# noinspection PyUnresolvedReferences
processed = await processor.process(preprocessed, state, stats_collect_fn) \
processed = await processed_func(preprocessed, state, stats_collect_fn) \
if processor.is_process_async \
else processor.process(preprocessed, state, stats_collect_fn)
else processed_func(preprocessed, state, stats_collect_fn)
# noinspection PyUnresolvedReferences
return_value = await processor.postprocess(processed, state, stats_collect_fn) \
if processor.is_postprocess_async \
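To make the dynamic dispatch concrete: `getattr(processor, serve_type.replace("/", "_"))` means a request arriving with `serve_type="chat/completions"` is handled by a `chat_completions` method on the processor, while the default route's `"process"` keeps resolving to `process`. A minimal sketch of an object that would satisfy this lookup; the class, method bodies, and return values are illustrative, not the PR's actual vLLM engine:

```python
class ExampleOpenAIPreprocess:
    # Illustrative only: method names mirror the OpenAI route suffixes so that
    # getattr(processor, serve_type.replace("/", "_")) can resolve them.
    is_preprocess_async = True
    is_process_async = True
    is_postprocess_async = True

    async def preprocess(self, body, state, stats_collect_fn=None):
        return body

    async def process(self, data, state, stats_collect_fn=None):
        # serve_type == "process" (default /serve/<model_id> route)
        return data

    async def completions(self, data, state, stats_collect_fn=None):
        # serve_type == "completions"
        return {"choices": [{"text": "..."}]}

    async def chat_completions(self, data, state, stats_collect_fn=None):
        # serve_type == "chat/completions"
        return {"choices": [{"message": {"role": "assistant", "content": "..."}}]}

    async def postprocess(self, data, state, stats_collect_fn=None):
        return data
```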