4
4
# https://www.aim.security/
5
5
#
6
6
# +-------------------------------------------------------------+
7
-
7
+ import asyncio
8
+ import json
8
9
import os
9
- from typing import Literal , Optional , Union
10
+ from typing import Any , AsyncGenerator , Literal , Optional , Union
10
11
11
12
from fastapi import HTTPException
13
+ from pydantic import BaseModel
14
+ from websockets .asyncio .client import ClientConnection , connect
12
15
13
16
from litellm import DualCache
14
17
from litellm ._logging import verbose_proxy_logger
18
21
httpxSpecialProvider ,
19
22
)
20
23
from litellm .proxy ._types import UserAPIKeyAuth
24
+ from litellm .proxy .proxy_server import StreamingCallbackError
25
+ from litellm .types .utils import (
26
+ Choices ,
27
+ EmbeddingResponse ,
28
+ ImageResponse ,
29
+ ModelResponse ,
30
+ ModelResponseStream ,
31
+ )
21
32
22
33
23
34
class AimGuardrailMissingSecrets (Exception ):
@@ -41,6 +52,9 @@ def __init__(
41
52
self .api_base = (
42
53
api_base or os .environ .get ("AIM_API_BASE" ) or "https://api.aim.security"
43
54
)
55
+ self .ws_api_base = self .api_base .replace ("http://" , "ws://" ).replace (
56
+ "https://" , "wss://"
57
+ )
44
58
super ().__init__ (** kwargs )
45
59
46
60
async def async_pre_call_hook (
@@ -98,8 +112,101 @@ async def call_aim_guardrail(self, data: dict, hook: str) -> None:
98
112
detected = res ["detected" ]
99
113
verbose_proxy_logger .info (
100
114
"Aim: detected: {detected}, enabled policies: {policies}" .format (
101
- detected = detected , policies = list (res ["details" ].keys ())
102
- )
115
+ detected = detected ,
116
+ policies = list (res ["details" ].keys ()),
117
+ ),
103
118
)
104
119
if detected :
105
120
raise HTTPException (status_code = 400 , detail = res ["detection_message" ])
121
+
122
async def call_aim_guardrail_on_output(
    self, request_data: dict, output: str, hook: str
) -> Optional[str]:
    """Send a completed model output to Aim and report whether it was flagged.

    Args:
        request_data: The original request payload. Its ``messages`` entry
            (default ``[]``) is forwarded alongside the output, and
            ``metadata.headers["x-aim-user-email"]`` is forwarded as a
            header when present.
        output: The model output text to inspect.
        hook: Value sent in the ``x-aim-litellm-hook`` header.

    Returns:
        The ``detection_message`` from Aim's response when ``detected`` is
        truthy, otherwise ``None``.

    Raises:
        Whatever ``response.raise_for_status()`` raises on a non-success
        reply, and ``KeyError`` if the response JSON lacks ``detected``,
        ``details`` or (when detected) ``detection_message``.
    """
    user_email = (
        request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
    )
    # No literal whitespace may follow the interpolated values: a trailing
    # space would become part of the bearer token sent to Aim.
    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "x-aim-litellm-hook": hook,
    } | ({"x-aim-user-email": user_email} if user_email else {})
    response = await self.async_handler.post(
        # The path is appended directly to the base URL (no separating space).
        f"{self.api_base}/detect/output",
        headers=headers,
        json={"output": output, "messages": request_data.get("messages", [])},
    )
    response.raise_for_status()
    res = response.json()
    detected = res["detected"]
    verbose_proxy_logger.info(
        "Aim: detected: {detected}, enabled policies: {policies}".format(
            detected=detected,
            policies=list(res["details"].keys()),
        ),
    )
    if detected:
        return res["detection_message"]
    return None
149
+
150
async def async_post_call_success_hook(
    self,
    data: dict,
    user_api_key_dict: UserAPIKeyAuth,
    response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
) -> Any:
    """Run the Aim output check on a non-streaming completion response.

    Only a ``ModelResponse`` whose first choice is a ``Choices`` instance is
    inspected; any other response is left alone. The first choice's message
    content (or ``""`` when it is empty/None) is passed to
    ``call_aim_guardrail_on_output`` with hook ``"output"``.

    Raises:
        HTTPException: status 400 carrying Aim's detection message when the
            output check returns one.
    """
    # Guard clauses: anything that is not a chat-style completion is skipped.
    if not isinstance(response, ModelResponse):
        return
    if not response.choices:
        return
    first_choice = response.choices[0]
    if not isinstance(first_choice, Choices):
        return

    text = first_choice.message.content or ""
    blocked_reason = await self.call_aim_guardrail_on_output(
        data, text, hook="output"
    )
    if blocked_reason:
        raise HTTPException(status_code=400, detail=blocked_reason)
167
+
168
async def async_post_call_streaming_iterator_hook(
    self,
    user_api_key_dict: UserAPIKeyAuth,
    response,
    request_data: dict,
) -> AsyncGenerator[ModelResponseStream, None]:
    """Relay a streaming response through Aim's websocket output check.

    A background task (``forward_the_stream_to_aim``) pushes the upstream
    chunks to Aim over the websocket while this generator reads Aim's
    replies and yields every ``verified_chunk`` back to the caller.

    Args:
        user_api_key_dict: Auth context of the caller (not used here).
        response: Async iterable of upstream stream chunks.
        request_data: The original request payload;
            ``metadata.headers["x-aim-user-email"]`` is forwarded as a
            header when present.

    Yields:
        ``ModelResponseStream`` objects validated from each
        ``verified_chunk`` message received from Aim.

    Raises:
        StreamingCallbackError: when Aim sends a ``blocking_message``.
    """
    user_email = (
        request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
    )
    # No literal whitespace after the interpolated key: it would become part
    # of the bearer token.
    headers = {
        "Authorization": f"Bearer {self.api_key}",
    } | ({"x-aim-user-email": user_email} if user_email else {})
    async with connect(
        f"{self.ws_api_base}/detect/output/ws", additional_headers=headers
    ) as websocket:
        sender = asyncio.create_task(
            self.forward_the_stream_to_aim(websocket, response)
        )
        try:
            while True:
                result = json.loads(await websocket.recv())
                if verified_chunk := result.get("verified_chunk"):
                    yield ModelResponseStream.model_validate(verified_chunk)
                    continue
                # Any message without a verified chunk ends the stream.
                if result.get("done"):
                    return
                if blocking_message := result.get("blocking_message"):
                    raise StreamingCallbackError(blocking_message)
                verbose_proxy_logger.error(
                    f"Unknown message received from AIM: {result} "
                )
                return
        finally:
            # Always stop the forwarding task, including when recv()/parsing
            # fails or the consumer closes this generator early; otherwise
            # the task would outlive the websocket it writes to.
            sender.cancel()
200
+
201
async def forward_the_stream_to_aim(
    self,
    websocket: ClientConnection,
    response_iter,
) -> None:
    """Push every chunk of the upstream stream to Aim over the websocket.

    Pydantic models are serialized with ``model_dump_json()`` and dicts with
    ``json.dumps``; any other chunk is sent unchanged. After the iterator is
    exhausted a final ``{"done": true}`` message is sent.
    """
    async for item in response_iter:
        if isinstance(item, BaseModel):
            outgoing = item.model_dump_json()
        elif isinstance(item, dict):
            outgoing = json.dumps(item)
        else:
            outgoing = item
        await websocket.send(outgoing)

    end_marker = json.dumps({"done": True})
    await websocket.send(end_marker)
0 commit comments