
Commit d4caaae

Merge pull request BerriAI#9274 from BerriAI/litellm_contributor_rebase_branch
Litellm contributor rebase branch
2 parents cd95634 + 4bc5f27 commit d4caaae

15 files changed (+467 −44 lines)


.circleci/config.yml

+3 −1

@@ -71,7 +71,7 @@ jobs:
             pip install "Pillow==10.3.0"
             pip install "jsonschema==4.22.0"
             pip install "pytest-xdist==3.6.1"
-            pip install "websockets==10.4"
+            pip install "websockets==13.1.0"
            pip uninstall posthog -y
      - save_cache:
          paths:
@@ -189,6 +189,7 @@ jobs:
             pip install "diskcache==5.6.1"
             pip install "Pillow==10.3.0"
             pip install "jsonschema==4.22.0"
+            pip install "websockets==13.1.0"
      - save_cache:
          paths:
            - ./venv
@@ -288,6 +289,7 @@ jobs:
             pip install "diskcache==5.6.1"
             pip install "Pillow==10.3.0"
             pip install "jsonschema==4.22.0"
+            pip install "websockets==13.1.0"
      - save_cache:
          paths:
            - ./venv

docs/my-website/docs/proxy/guardrails/aim_security.md

+1 −1

@@ -37,7 +37,7 @@ guardrails:
   - guardrail_name: aim-protected-app
     litellm_params:
       guardrail: aim
-      mode: pre_call # 'during_call' is also available
+      mode: [pre_call, post_call] # "During_call" is also available
       api_key: os.environ/AIM_API_KEY
       api_base: os.environ/AIM_API_BASE # Optional, use only when using a self-hosted Aim Outpost
 ```
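With `post_call` enabled, streamed responses now also pass through the Aim guardrail. A minimal sketch of exercising that path through the proxy, assuming a proxy running on `localhost:4000` with the config above loaded; the model name, API key, and the optional `x-aim-user-email` header value are placeholders (the hooks added in `aim.py` below read that header from request metadata when present):

```python
# Sketch: streaming request through the LiteLLM proxy with the Aim guardrail active.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")  # placeholder key

stream = client.chat.completions.create(
    model="gpt-4o",  # placeholder model name from the proxy's model list
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
    # Optional: forwarded to Aim by the guardrail hooks when supplied.
    extra_headers={"x-aim-user-email": "user@example.com"},
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
```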

litellm/integrations/custom_logger.py

+20 −1

@@ -1,7 +1,16 @@
 #### What this does ####
 # On success, logs events to Promptlayer
 import traceback
-from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+)

 from pydantic import BaseModel

@@ -14,6 +23,7 @@
     EmbeddingResponse,
     ImageResponse,
     ModelResponse,
+    ModelResponseStream,
     StandardCallbackDynamicParams,
     StandardLoggingPayload,
 )
@@ -251,6 +261,15 @@ async def async_post_call_streaming_hook(
     ) -> Any:
         pass

+    async def async_post_call_streaming_iterator_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: Any,
+        request_data: dict,
+    ) -> AsyncGenerator[ModelResponseStream, None]:
+        async for item in response:
+            yield item
+
     #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function

     def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
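The base-class default above is a pure pass-through, so existing loggers keep working; overriding it lets a callback rewrite or drop chunks across the whole stream. A rough sketch of a subclass using the new hook — the class name and the uppercasing transform are purely illustrative, not part of this commit:

```python
from typing import Any, AsyncGenerator

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import ModelResponseStream


class ShoutingLogger(CustomLogger):
    """Illustrative only: uppercase every text delta before it reaches the client."""

    async def async_post_call_streaming_iterator_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: Any,
        request_data: dict,
    ) -> AsyncGenerator[ModelResponseStream, None]:
        async for chunk in response:
            # Only touch chunks that actually carry text content.
            if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                chunk.choices[0].delta.content = chunk.choices[0].delta.content.upper()
            yield chunk
```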

litellm/llms/bedrock/chat/converse_handler.py

+1 −1

@@ -274,7 +274,7 @@ def completion(  # noqa: PLR0915
         if modelId is not None:
             modelId = self.encode_model_id(model_id=modelId)
         else:
-            modelId = model
+            modelId = self.encode_model_id(model_id=model)

         if stream is True and "ai21" in modelId:
             fake_stream = True

litellm/proxy/guardrails/guardrail_hooks/aim.py

+111 −4

@@ -4,11 +4,14 @@
 # https://www.aim.security/
 #
 # +-------------------------------------------------------------+
-
+import asyncio
+import json
 import os
-from typing import Literal, Optional, Union
+from typing import Any, AsyncGenerator, Literal, Optional, Union

 from fastapi import HTTPException
+from pydantic import BaseModel
+from websockets.asyncio.client import ClientConnection, connect

 from litellm import DualCache
 from litellm._logging import verbose_proxy_logger
@@ -18,6 +21,14 @@
     httpxSpecialProvider,
 )
 from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.proxy_server import StreamingCallbackError
+from litellm.types.utils import (
+    Choices,
+    EmbeddingResponse,
+    ImageResponse,
+    ModelResponse,
+    ModelResponseStream,
+)


 class AimGuardrailMissingSecrets(Exception):
@@ -41,6 +52,9 @@ def __init__(
         self.api_base = (
             api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
         )
+        self.ws_api_base = self.api_base.replace("http://", "ws://").replace(
+            "https://", "wss://"
+        )
         super().__init__(**kwargs)

     async def async_pre_call_hook(
@@ -98,8 +112,101 @@ async def call_aim_guardrail(self, data: dict, hook: str) -> None:
         detected = res["detected"]
         verbose_proxy_logger.info(
             "Aim: detected: {detected}, enabled policies: {policies}".format(
-                detected=detected, policies=list(res["details"].keys())
-            )
+                detected=detected,
+                policies=list(res["details"].keys()),
+            ),
         )
         if detected:
             raise HTTPException(status_code=400, detail=res["detection_message"])
+
+    async def call_aim_guardrail_on_output(
+        self, request_data: dict, output: str, hook: str
+    ) -> Optional[str]:
+        user_email = (
+            request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
+        )
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "x-aim-litellm-hook": hook,
+        } | ({"x-aim-user-email": user_email} if user_email else {})
+        response = await self.async_handler.post(
+            f"{self.api_base}/detect/output",
+            headers=headers,
+            json={"output": output, "messages": request_data.get("messages", [])},
+        )
+        response.raise_for_status()
+        res = response.json()
+        detected = res["detected"]
+        verbose_proxy_logger.info(
+            "Aim: detected: {detected}, enabled policies: {policies}".format(
+                detected=detected,
+                policies=list(res["details"].keys()),
+            ),
+        )
+        if detected:
+            return res["detection_message"]
+        return None
+
+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
+    ) -> Any:
+        if (
+            isinstance(response, ModelResponse)
+            and response.choices
+            and isinstance(response.choices[0], Choices)
+        ):
+            content = response.choices[0].message.content or ""
+            detection = await self.call_aim_guardrail_on_output(
+                data, content, hook="output"
+            )
+            if detection:
+                raise HTTPException(status_code=400, detail=detection)
+
+    async def async_post_call_streaming_iterator_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+        request_data: dict,
+    ) -> AsyncGenerator[ModelResponseStream, None]:
+        user_email = (
+            request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
+        )
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+        } | ({"x-aim-user-email": user_email} if user_email else {})
+        async with connect(
+            f"{self.ws_api_base}/detect/output/ws", additional_headers=headers
+        ) as websocket:
+            sender = asyncio.create_task(
+                self.forward_the_stream_to_aim(websocket, response)
+            )
+            while True:
+                result = json.loads(await websocket.recv())
+                if verified_chunk := result.get("verified_chunk"):
+                    yield ModelResponseStream.model_validate(verified_chunk)
+                else:
+                    sender.cancel()
+                    if result.get("done"):
+                        return
+                    if blocking_message := result.get("blocking_message"):
+                        raise StreamingCallbackError(blocking_message)
+                    verbose_proxy_logger.error(
+                        f"Unknown message received from AIM: {result}"
+                    )
+                    return
+
+    async def forward_the_stream_to_aim(
+        self,
+        websocket: ClientConnection,
+        response_iter,
+    ) -> None:
+        async for chunk in response_iter:
+            if isinstance(chunk, BaseModel):
+                chunk = chunk.model_dump_json()
+            if isinstance(chunk, dict):
+                chunk = json.dumps(chunk)
+            await websocket.send(chunk)
+        await websocket.send(json.dumps({"done": True}))
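The streaming hook above speaks a small WebSocket protocol with the Aim Outpost: the proxy forwards every chunk as JSON followed by a `{"done": true}` sentinel, and expects each reply to carry one of `verified_chunk`, `done`, or `blocking_message`. A sketch of the reply shapes as the handler interprets them — the keys come from the code above, the field values are illustrative:

```python
# Replies the streaming hook knows how to process (payload values are made up):
verified = {"verified_chunk": {"id": "chatcmpl-abc", "object": "chat.completion.chunk", "choices": []}}
#   -> validated into a ModelResponseStream and yielded to the client
finished = {"done": True}
#   -> sender task is cancelled and the stream ends cleanly
blocked = {"blocking_message": "Output blocked by Aim policy"}
#   -> raised as StreamingCallbackError and surfaced as a stream error by the proxy
```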

litellm/proxy/proxy_server.py

+20 −10

@@ -23,6 +23,11 @@
     get_origin,
     get_type_hints,
 )
+from litellm.types.utils import (
+    ModelResponse,
+    ModelResponseStream,
+    TextCompletionResponse,
+)

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -1377,6 +1382,10 @@ async def _run_background_health_check():
         await asyncio.sleep(health_check_interval)


+class StreamingCallbackError(Exception):
+    pass
+
+
 class ProxyConfig:
     """
     Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
@@ -3038,8 +3047,7 @@ async def async_data_generator(
 ):
     verbose_proxy_logger.debug("inside generator")
     try:
-        time.time()
-        async for chunk in response:
+        async for chunk in proxy_logging_obj.async_post_call_streaming_iterator_hook(user_api_key_dict=user_api_key_dict, response=response, request_data=request_data):
             verbose_proxy_logger.debug(
                 "async_data_generator: received streaming chunk - {}".format(chunk)
             )
@@ -3076,6 +3084,8 @@ async def async_data_generator(

         if isinstance(e, HTTPException):
             raise e
+        elif isinstance(e, StreamingCallbackError):
+            error_msg = str(e)
         else:
             error_traceback = traceback.format_exc()
             error_msg = f"{str(e)}\n\n{error_traceback}"
@@ -5421,11 +5431,11 @@ async def token_counter(request: TokenCountRequest):
 )
 async def supported_openai_params(model: str):
     """
-    Returns supported openai params for a given litellm model name 
+    Returns supported openai params for a given litellm model name

-    e.g. `gpt-4` vs `gpt-3.5-turbo` 
+    e.g. `gpt-4` vs `gpt-3.5-turbo`

-    Example curl: 
+    Example curl:
     ```
     curl -X GET --location 'http://localhost:4000/utils/supported_openai_params?model=gpt-3.5-turbo-16k' \
     --header 'Authorization: Bearer sk-1234'
@@ -6194,7 +6204,7 @@ async def model_group_info(
     - /model_group/info returns all model groups. End users of proxy should use /model_group/info since those models will be used for /chat/completions, /embeddings, etc.
     - /model_group/info?model_group=rerank-english-v3.0 returns all model groups for a specific model group (`model_name` in config.yaml)

-    
+

     Example Request (All Models):
     ```shell
@@ -6212,10 +6222,10 @@
     -H 'Authorization: Bearer sk-1234'
     ```

-    Example Request (Specific Wildcard Model Group): (e.g. `model_name: openai/*` on config.yaml) 
+    Example Request (Specific Wildcard Model Group): (e.g. `model_name: openai/*` on config.yaml)
     ```shell
     curl -X 'GET' \
-    'http://localhost:4000/model_group/info?model_group=openai/tts-1' 
+    'http://localhost:4000/model_group/info?model_group=openai/tts-1'
     -H 'accept: application/json' \
     -H 'Authorization: Bearer sk-1234'
     ```
@@ -7242,7 +7252,7 @@ async def invitation_update(
 ):
     """
     Update when invitation is accepted
-    
+
     ```
     curl -X POST 'http://localhost:4000/invitation/update' \
     -H 'Content-Type: application/json' \
@@ -7303,7 +7313,7 @@ async def invitation_delete(
 ):
     """
     Delete invitation link
-    
+
     ```
     curl -X POST 'http://localhost:4000/invitation/delete' \
     -H 'Content-Type: application/json' \

litellm/proxy/utils.py

+32 −2

@@ -18,6 +18,7 @@
     ProxyErrorTypes,
     ProxyException,
 )
+from litellm.types.guardrails import GuardrailEventHooks

 try:
     import backoff
@@ -31,7 +32,7 @@
 import litellm
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
-from litellm import EmbeddingResponse, ImageResponse, ModelResponse, Router
+from litellm import EmbeddingResponse, ImageResponse, ModelResponse, Router, ModelResponseStream
 from litellm._logging import verbose_proxy_logger
 from litellm._service_logger import ServiceLogging, ServiceTypes
 from litellm.caching.caching import DualCache, RedisCache
@@ -972,7 +973,7 @@ async def async_post_call_streaming_hook(
         1. /chat/completions
         """
         response_str: Optional[str] = None
-        if isinstance(response, ModelResponse):
+        if isinstance(response, (ModelResponse, ModelResponseStream)):
             response_str = litellm.get_response_string(response_obj=response)
         if response_str is not None:
             for callback in litellm.callbacks:
@@ -992,6 +993,35 @@ async def async_post_call_streaming_hook(
                 raise e
         return response

+    def async_post_call_streaming_iterator_hook(
+        self,
+        response,
+        user_api_key_dict: UserAPIKeyAuth,
+        request_data: dict,
+    ):
+        """
+        Allow user to modify outgoing streaming data -> Given a whole response iterator.
+        This hook is best used when you need to modify multiple chunks of the response at once.
+
+        Covers:
+        1. /chat/completions
+        """
+        for callback in litellm.callbacks:
+            _callback: Optional[CustomLogger] = None
+            if isinstance(callback, str):
+                _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(callback)
+            else:
+                _callback = callback  # type: ignore
+            if _callback is not None and isinstance(_callback, CustomLogger):
+                if not isinstance(_callback, CustomGuardrail) or _callback.should_run_guardrail(
+                    data=request_data, event_type=GuardrailEventHooks.post_call
+                ):
+                    response = _callback.async_post_call_streaming_iterator_hook(
+                        user_api_key_dict=user_api_key_dict, response=response, request_data=request_data
+                    )
+        return response
+
+
     async def post_call_streaming_hook(
         self,
         response: str,
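Note that this hook composes rather than consumes: for each registered callback it wraps the response iterator in that callback's `async_post_call_streaming_iterator_hook`, and the chunks are only pulled later by `async_data_generator`, flowing through every layer. A toy sketch of the same wrapping pattern with plain async generators — the stream contents and hook names are illustrative, not LiteLLM APIs:

```python
import asyncio
from typing import AsyncGenerator


async def raw_stream() -> AsyncGenerator[str, None]:
    # Stand-in for the LLM's chunk iterator.
    for piece in ["hel", "lo ", "world"]:
        yield piece


async def passthrough_hook(response) -> AsyncGenerator[str, None]:
    # Mirrors the CustomLogger default: re-yield every chunk unchanged.
    async for chunk in response:
        yield chunk


async def shouting_hook(response) -> AsyncGenerator[str, None]:
    # A second callback layered on top of the first.
    async for chunk in response:
        yield chunk.upper()


async def main() -> None:
    # Wrapping happens up front without awaiting, as in the loop above;
    # chunks flow raw_stream -> passthrough_hook -> shouting_hook -> caller.
    wrapped = shouting_hook(passthrough_hook(raw_stream()))
    async for chunk in wrapped:
        print(chunk, end="")  # prints "HELLO WORLD"


asyncio.run(main())
```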

litellm/realtime_api/main.py

+2 −0

@@ -151,6 +151,8 @@ async def _realtime_health_check(
         url = openai_realtime._construct_url(
             api_base=api_base or "https://api.openai.com/", model=model
         )
+    else:
+        raise ValueError(f"Unsupported model: {model}")
     async with websockets.connect(  # type: ignore
         url,
         extra_headers={
