1 change: 1 addition & 0 deletions ads/aqua/common/enums.py
@@ -24,6 +24,7 @@ class PredictEndpoints(ExtendedEnum):
CHAT_COMPLETIONS_ENDPOINT = "/v1/chat/completions"
TEXT_COMPLETIONS_ENDPOINT = "/v1/completions"
EMBEDDING_ENDPOINT = "/v1/embedding"
RESPONSES = "/v1/responses"


class Tags(ExtendedEnum):
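The new RESPONSES member feeds the endpoint_type dispatch in the streaming handler below. A minimal sketch of that routing, assuming ExtendedEnum members compare equal to their string values (the handler's own == checks rely on this); pick_branch is a hypothetical helper:

from ads.aqua.common.enums import PredictEndpoints

def pick_branch(endpoint_type: str) -> str:
    # Mirror the handler's dispatch on the payload's endpoint_type.
    if endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT:
        return "chat"
    if endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
        return "completions"
    if endpoint_type == PredictEndpoints.RESPONSES:
        return "responses"
    raise ValueError(f"Unsupported endpoint_type: {endpoint_type}")

print(pick_branch("/v1/responses"))  # -> responses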
301 changes: 253 additions & 48 deletions ads/aqua/extension/deployment_handler.py
@@ -2,13 +2,14 @@
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from typing import List, Optional, Union
from typing import List, Union
from urllib.parse import urlparse

from tornado.web import HTTPError

from ads.aqua.app import logger
from ads.aqua import logger
from ads.aqua.client.client import Client, ExtendedRequestError
from ads.aqua.client.openai_client import OpenAI
from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.common.enums import PredictEndpoints
from ads.aqua.extension.base_handler import AquaAPIhandler
@@ -221,12 +222,98 @@ def list_shapes(self):


class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
def _get_model_deployment_response(
self,
model_deployment_id: str,
payload: dict,
route_override_header: Optional[str],
):
def _extract_text_from_choice(self, choice):
"""
Extract text content from a single choice structure.

Handles both dictionary-based API responses and object-based SDK responses.
For dict choices, it checks delta-based streaming fields, message-based
non-streaming fields, and finally top-level text/content keys.
For object choices, it inspects `.delta`, `.message`, and top-level
`.text` or `.content` attributes.

Parameters
----------
choice : dict or object
A choice entry from a model response. It may be:
- A dict from a JSON API response (streaming or non-streaming), checked in
order: delta content/text, message content/text, then top-level text/content.
- An SDK-style object with `delta`, `message`, `text`, or `content`
attributes, checked in the same order.

Returns
-------
str or None
The extracted text if present; otherwise None.
"""
# choice may be a dict or an object
if isinstance(choice, dict):
# streaming chunk: {"delta": {"content": "..."}}
delta = choice.get("delta")
if isinstance(delta, dict):
return delta.get("content") or delta.get("text") or None
# non-streaming: {"message": {"content": "..."}}
msg = choice.get("message")
if isinstance(msg, dict):
return msg.get("content") or msg.get("text")
# fallback top-level fields
return choice.get("text") or choice.get("content")
# object-like choice
delta = getattr(choice, "delta", None)
if delta is not None:
return getattr(delta, "content", None) or getattr(delta, "text", None)
msg = getattr(choice, "message", None)
if msg is not None:
if isinstance(msg, str):
return msg
return getattr(msg, "content", None) or getattr(msg, "text", None)
return getattr(choice, "text", None) or getattr(choice, "content", None)
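Since the method never touches instance state, it can be exercised directly off the class; a quick sketch with hand-built choices (values illustrative):

# Unbound call: pass None for the unused self.
extract = AquaDeploymentStreamingInferenceHandler._extract_text_from_choice

print(extract(None, {"delta": {"content": "Hel"}}))      # streaming delta -> 'Hel'
print(extract(None, {"message": {"content": "Hello"}}))  # non-streaming message -> 'Hello'
print(extract(None, {"text": "Hello"}))                  # completions-style -> 'Hello'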

def _extract_text_from_chunk(self, chunk):
"""
Extract text content from a model response chunk.

Supports both dict-form chunks (streaming or non-streaming) and SDK-style
object chunks. When choices are present, extraction is delegated to
`_extract_text_from_choice`. If no choices exist, top-level text/content
fields or attributes are used.

Parameters
----------
chunk : dict or object
A chunk returned from a model stream or full response. It may be:
- A dict containing a `choices` list or top-level text/content fields.
- An SDK-style object with a `choices` attribute or top-level
`text`/`content` attributes.

If `choices` is present, text is extracted from the first choice via
`_extract_text_from_choice`; otherwise top-level text/content is used.

Returns
-------
str or None
The extracted text if present; otherwise None.
"""
if chunk:
if isinstance(chunk, dict):
choices = chunk.get("choices") or []
if choices:
return self._extract_text_from_choice(choices[0])
# fallback top-level
return chunk.get("text") or chunk.get("content")
# object-like chunk
choices = getattr(chunk, "choices", None)
if choices:
return self._extract_text_from_choice(choices[0])
return getattr(chunk, "text", None) or getattr(chunk, "content", None)
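The chunk-level helper also accepts the object form an OpenAI-style SDK yields; a sketch with types.SimpleNamespace standing in for SDK chunk objects:

from types import SimpleNamespace

extract_chunk = AquaDeploymentStreamingInferenceHandler._extract_text_from_chunk

sdk_chunk = SimpleNamespace(
    choices=[SimpleNamespace(delta=SimpleNamespace(content="wor", text=None))]
)
print(extract_chunk(None, sdk_chunk))                      # object form -> 'wor'
print(extract_chunk(None, {"choices": [{"text": "ld"}]}))  # dict form -> 'ld'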

def _get_model_deployment_response(self, model_deployment_id: str, payload: dict):
"""
Returns the model deployment inference response in a streaming fashion.

@@ -272,53 +359,172 @@ def _get_model_deployment_response(
"""

model_deployment = AquaDeploymentApp().get(model_deployment_id)
endpoint = model_deployment.endpoint + "/predictWithResponseStream"
endpoint_type = model_deployment.environment_variables.get(
"MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
)
aqua_client = Client(endpoint=endpoint)

if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
endpoint_type,
route_override_header,
endpoint = model_deployment.endpoint + "/predictWithResponseStream/v1"

required_keys = ["endpoint_type", "prompt", "model"]
missing = [k for k in required_keys if k not in payload]

if missing:
raise HTTPError(400, f"Missing required payload keys: {', '.join(missing)}")

endpoint_type = payload["endpoint_type"]
aqua_client = OpenAI(base_url=endpoint)

allowed = {
"max_tokens",
"temperature",
"top_p",
"stop",
"n",
"presence_penalty",
"frequency_penalty",
"logprobs",
"user",
"echo",
}
responses_allowed = {"temperature", "top_p"}

# normalize and filter
if payload.get("stop") == []:
payload["stop"] = None

encoded_image = payload.get("encoded_image", "NA")

model = payload.pop("model")
filtered = {k: v for k, v in payload.items() if k in allowed}
responses_filtered = {
k: v for k, v in payload.items() if k in responses_allowed
}
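For example, a payload mixing routing fields with sampling options narrows to only the allowed keys before the SDK call (keys and values illustrative):

allowed = {"max_tokens", "temperature", "top_p", "stop", "n", "presence_penalty",
           "frequency_penalty", "logprobs", "user", "echo"}

payload = {
    "endpoint_type": "/v1/chat/completions",  # routing only, not forwarded
    "prompt": "Say hi",                       # consumed separately as messages/input
    "temperature": 0.2,
    "max_tokens": 64,
}
filtered = {k: v for k, v in payload.items() if k in allowed}
print(filtered)  # {'temperature': 0.2, 'max_tokens': 64}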

if (
endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
and encoded_image == "NA"
):
try:
for chunk in aqua_client.chat(
messages=payload.pop("messages"),
payload=payload,
stream=True,
):
try:
if "text" in chunk["choices"][0]:
yield chunk["choices"][0]["text"]
elif "content" in chunk["choices"][0]["delta"]:
yield chunk["choices"][0]["delta"]["content"]
except Exception as e:
logger.debug(
f"Exception occurred while parsing streaming response: {e}"
)
api_kwargs = {
"model": model,
"messages": [{"role": "user", "content": payload["prompt"]}],
"stream": True,
**filtered,
}
if "chat_template" in payload:
chat_template = payload.pop("chat_template")
api_kwargs["extra_body"] = {"chat_template": chat_template}

stream = aqua_client.chat.completions.create(**api_kwargs)

for chunk in stream:
if chunk:
piece = self._extract_text_from_chunk(chunk)
if piece:
yield piece
except ExtendedRequestError as ex:
raise HTTPError(400, str(ex))
raise HTTPError(400, str(ex)) from ex
except Exception as ex:
raise HTTPError(500, str(ex))
raise HTTPError(500, str(ex)) from ex

elif (
endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
and encoded_image != "NA"
):
file_type = payload.pop("file_type")
if file_type.startswith("image"):
api_kwargs = {
"model": model,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": payload["prompt"]},
{
"type": "image_url",
"image_url": {"url": f"{encoded_image}"},
},
],
}
],
"stream": True,
**filtered,
}

# Add chat_template for image-based chat completions
if "chat_template" in payload:
chat_template = payload.pop("chat_template")
api_kwargs["extra_body"] = {"chat_template": chat_template}

response = aqua_client.chat.completions.create(**api_kwargs)

elif file_type.startswith("audio"):
api_kwargs = {
"model": model,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": payload["prompt"]},
{
"type": "audio_url",
"audio_url": {"url": f"{encoded_image}"},
},
],
}
],
"stream": True,
**filtered,
}

# Add chat_template for audio-based chat completions
if "chat_template" in payload:
chat_template = payload.pop("chat_template")
api_kwargs["extra_body"] = {"chat_template": chat_template}

response = aqua_client.chat.completions.create(**api_kwargs)
else:
raise HTTPError(400, f"Unsupported file_type: {file_type}")
try:
for chunk in response:
piece = self._extract_text_from_chunk(chunk)
if piece:
yield piece
except ExtendedRequestError as ex:
raise HTTPError(400, str(ex)) from ex
except Exception as ex:
raise HTTPError(500, str(ex)) from ex
elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
try:
for chunk in aqua_client.generate(
prompt=payload.pop("prompt"),
payload=payload,
stream=True,
for chunk in aqua_client.completions.create(
prompt=payload["prompt"], stream=True, model=model, **filtered
):
try:
yield chunk["choices"][0]["text"]
except Exception as e:
logger.debug(
f"Exception occurred while parsing streaming response: {e}"
)
if chunk:
piece = self._extract_text_from_chunk(chunk)
if piece:
yield piece
except ExtendedRequestError as ex:
raise HTTPError(400, str(ex)) from ex
except Exception as ex:
raise HTTPError(500, str(ex)) from ex

elif endpoint_type == PredictEndpoints.RESPONSES:
kwargs = {"model": model, "input": payload["prompt"], "stream": True}

if "temperature" in responses_filtered:
kwargs["temperature"] = responses_filtered["temperature"]
if "top_p" in responses_filtered:
kwargs["top_p"] = responses_filtered["top_p"]

response = aqua_client.responses.create(**kwargs)
try:
for chunk in response:
if chunk:
piece = self._extract_text_from_chunk(chunk)
if piece:
yield piece
except ExtendedRequestError as ex:
raise HTTPError(400, str(ex))
raise HTTPError(400, str(ex)) from ex
except Exception as ex:
raise HTTPError(500, str(ex))
raise HTTPError(500, str(ex)) from ex
else:
raise HTTPError(400, f"Unsupported endpoint_type: {endpoint_type}")

@handle_exceptions
def post(self, model_deployment_id):
@@ -346,18 +552,17 @@ def post(self, model_deployment_id):
)
if not input_data.get("model"):
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
route_override_header = self.request.headers.get("route", None)
self.set_header("Content-Type", "text/event-stream")
response_gen = self._get_model_deployment_response(
model_deployment_id, input_data, route_override_header
model_deployment_id, input_data
)
try:
for chunk in response_gen:
self.write(chunk)
self.flush()
self.finish()
except Exception as ex:
self.set_status(ex.status_code)
self.set_status(getattr(ex, "status_code", 500))
self.write({"message": "Error occurred", "reason": str(ex)})
self.finish()
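End to end, a client can consume the streamed text with requests; the URL is hypothetical and depends on where the extension mounts this handler:

import requests

url = "http://localhost:8888/aqua/deployments/<md_ocid>/predict"  # hypothetical route
body = {"endpoint_type": "/v1/completions", "model": "odsc-llm", "prompt": "Hi"}

with requests.post(url, json=body, stream=True, timeout=300) as resp:
    resp.raise_for_status()
    for piece in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(piece, end="", flush=True)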
