Commit 1f6d8fc

Merge pull request #17079 from naaa760/fix/vertex-batch-support
fix: support Vertex AI batch listing in LiteLLM proxy
2 parents e093429 + 2cf86e8 commit 1f6d8fc

4 files changed: +198 additions, −4 deletions

litellm/batches/main.py

Lines changed: 30 additions & 3 deletions

@@ -644,7 +644,7 @@ def retrieve_batch(
 async def alist_batches(
     after: Optional[str] = None,
     limit: Optional[int] = None,
-    custom_llm_provider: Literal["openai", "azure"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
     metadata: Optional[Dict[str, str]] = None,
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
@@ -687,7 +687,7 @@ async def alist_batches(
 def list_batches(
     after: Optional[str] = None,
     limit: Optional[int] = None,
-    custom_llm_provider: Literal["openai", "azure"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
@@ -784,9 +784,36 @@ def list_batches(
                 max_retries=optional_params.max_retries,
                 litellm_params=litellm_params,
             )
+        elif custom_llm_provider == "vertex_ai":
+            api_base = optional_params.api_base or ""
+            vertex_ai_project = (
+                optional_params.vertex_project
+                or litellm.vertex_project
+                or get_secret_str("VERTEXAI_PROJECT")
+            )
+            vertex_ai_location = (
+                optional_params.vertex_location
+                or litellm.vertex_location
+                or get_secret_str("VERTEXAI_LOCATION")
+            )
+            vertex_credentials = optional_params.vertex_credentials or get_secret_str(
+                "VERTEXAI_CREDENTIALS"
+            )
+
+            response = vertex_ai_batches_instance.list_batches(
+                _is_async=_is_async,
+                after=after,
+                limit=limit,
+                api_base=api_base,
+                vertex_project=vertex_ai_project,
+                vertex_location=vertex_ai_location,
+                vertex_credentials=vertex_credentials,
+                timeout=timeout,
+                max_retries=optional_params.max_retries,
+            )
         else:
             raise litellm.exceptions.BadRequestError(
-                message="LiteLLM doesn't support {} for 'list_batch'. Only 'openai' is supported.".format(
+                message="LiteLLM doesn't support {} for 'list_batch'. Supported providers: openai, azure, vertex_ai.".format(
                     custom_llm_provider
                 ),
                 model="n/a",

litellm/llms/vertex_ai/batches/handler.py

Lines changed: 96 additions & 0 deletions

@@ -213,3 +213,99 @@ async def _async_retrieve_batch(
             response=_json_response
         )
         return vertex_batch_response
+
+    def list_batches(
+        self,
+        _is_async: bool,
+        after: Optional[str],
+        limit: Optional[int],
+        api_base: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+    ):
+        sync_handler = _get_httpx_client()
+
+        access_token, project_id = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider="vertex_ai",
+        )
+
+        default_api_base = self.create_vertex_batch_url(
+            vertex_location=vertex_location or "us-central1",
+            vertex_project=vertex_project or project_id,
+        )
+
+        if len(default_api_base.split(":")) > 1:
+            endpoint = default_api_base.split(":")[-1]
+        else:
+            endpoint = ""
+
+        _, api_base = self._check_custom_proxy(
+            api_base=api_base,
+            custom_llm_provider="vertex_ai",
+            gemini_api_key=None,
+            endpoint=endpoint,
+            stream=None,
+            auth_header=None,
+            url=default_api_base,
+        )
+
+        headers = {
+            "Content-Type": "application/json; charset=utf-8",
+            "Authorization": f"Bearer {access_token}",
+        }
+
+        params: Dict[str, Any] = {}
+        if limit is not None:
+            params["pageSize"] = str(limit)
+        if after is not None:
+            params["pageToken"] = after
+
+        if _is_async is True:
+            return self._async_list_batches(
+                api_base=api_base,
+                headers=headers,
+                params=params,
+            )
+
+        response = sync_handler.get(
+            url=api_base,
+            headers=headers,
+            params=params,
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_list_response_to_openai_list_response(
+            response=_json_response
+        )
+        return vertex_batch_response
+
+    async def _async_list_batches(
+        self,
+        api_base: str,
+        headers: Dict[str, str],
+        params: Dict[str, Any],
+    ):
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.VERTEX_AI,
+        )
+        response = await client.get(
+            url=api_base,
+            headers=headers,
+            params=params,
+        )
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        _json_response = response.json()
+        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_list_response_to_openai_list_response(
+            response=_json_response
+        )
+        return vertex_batch_response
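
The handler maps the OpenAI-style pagination arguments onto Vertex AI's query parameters (limit becomes pageSize, after becomes pageToken) and issues a GET against the batchPredictionJobs URL built by create_vertex_batch_url. A rough sketch of the equivalent raw request, assuming the public regional endpoint and placeholder project and token values that are not part of this diff:

import httpx

# Illustrative only: project id, location, and access token are placeholders.
project = "my-gcp-project"
location = "us-central1"
url = (
    f"https://{location}-aiplatform.googleapis.com/v1/"
    f"projects/{project}/locations/{location}/batchPredictionJobs"
)
resp = httpx.get(
    url,
    headers={"Authorization": "Bearer <access-token>"},
    params={"pageSize": "2", "pageToken": ""},
)
print(resp.json().get("batchPredictionJobs", []))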

litellm/llms/vertex_ai/batches/transformation.py

Lines changed: 28 additions & 1 deletion

@@ -1,5 +1,5 @@
 from litellm._uuid import uuid
-from typing import Dict
+from typing import Any, Dict

 from litellm.llms.vertex_ai.common_utils import (
     _convert_vertex_datetime_to_openai_datetime,
@@ -67,6 +67,33 @@ def transform_vertex_ai_batch_response_to_openai_batch_response(
             ),
         )

+    @classmethod
+    def transform_vertex_ai_batch_list_response_to_openai_list_response(
+        cls, response: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """
+        Transforms Vertex AI batch list response into OpenAI-compatible list response.
+        """
+
+        batch_jobs = response.get("batchPredictionJobs", []) or []
+        data = [
+            cls.transform_vertex_ai_batch_response_to_openai_batch_response(job)
+            for job in batch_jobs
+        ]
+
+        first_id = data[0].id if len(data) > 0 else None
+        last_id = data[-1].id if len(data) > 0 else None
+        next_page_token = response.get("nextPageToken")
+
+        return {
+            "object": "list",
+            "data": data,
+            "first_id": first_id,
+            "last_id": last_id,
+            "has_more": bool(next_page_token),
+            "next_page_token": next_page_token,
+        }
+
     @classmethod
     def _get_batch_id_from_vertex_ai_batch_response(
         cls, response: VertexBatchPredictionResponse
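
A minimal sanity-check sketch of the new classmethod, showing how an empty Vertex list payload maps to the OpenAI-style envelope (the import path follows the file shown above):

from litellm.llms.vertex_ai.batches.transformation import VertexAIBatchTransformation

# With no jobs and an empty nextPageToken, the transform yields an empty OpenAI-style list.
openai_list = VertexAIBatchTransformation.transform_vertex_ai_batch_list_response_to_openai_list_response(
    response={"batchPredictionJobs": [], "nextPageToken": ""}
)
assert openai_list == {
    "object": "list",
    "data": [],
    "first_id": None,
    "last_id": None,
    "has_more": False,  # bool("") is False, so no further pages
    "next_page_token": "",
}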

tests/batches_tests/test_openai_batches_and_files.py

Lines changed: 44 additions & 0 deletions

@@ -447,6 +447,18 @@ async def test_async_create_batch(provider):
     "completionStats": {"successfulCount": 0, "failedCount": 0, "remainingCount": 100},
 }

+mock_vertex_list_response = {
+    "batchPredictionJobs": [
+        mock_vertex_batch_response,
+        {
+            **mock_vertex_batch_response,
+            "name": "projects/123456789/locations/us-central1/batchPredictionJobs/test-batch-id-789",
+            "state": "JOB_STATE_SUCCEEDED",
+        },
+    ],
+    "nextPageToken": "",
+}
+

 @pytest.mark.asyncio
 async def test_avertex_batch_prediction(monkeypatch):
@@ -533,3 +545,35 @@ async def mock_side_effect(*args, **kwargs):
     print("retrieved_batch=", retrieved_batch)

     assert retrieved_batch.id == "test-batch-id-456"
+
+
+@pytest.mark.asyncio
+async def test_vertex_list_batches(monkeypatch):
+    monkeypatch.setenv("GCS_BUCKET_NAME", "litellm-local")
+    monkeypatch.setenv("VERTEXAI_PROJECT", "litellm-test-project")
+    monkeypatch.setenv("VERTEXAI_LOCATION", "us-central1")
+
+    monkeypatch.setattr(
+        "litellm.llms.vertex_ai.batches.handler.VertexAIBatchPrediction._ensure_access_token",
+        lambda self, credentials, project_id, custom_llm_provider: ("mock-token", "litellm-test-project"),
+    )
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.get"
+    ) as mock_get:
+        mock_get_response = MagicMock()
+        mock_get_response.json.return_value = mock_vertex_list_response
+        mock_get_response.status_code = 200
+        mock_get_response.raise_for_status.return_value = None
+        mock_get.return_value = mock_get_response
+
+        list_response = await litellm.alist_batches(
+            custom_llm_provider="vertex_ai",
+            limit=2,
+        )
+
+        assert list_response["object"] == "list"
+        assert list_response["has_more"] is False
+        assert len(list_response["data"]) == 2
+        assert list_response["data"][0].id == "test-batch-id-456"
+        assert list_response["data"][1].id == "test-batch-id-789"
