Skip to content

Commit 88eedb2

Browse files
vertex ai anthropic thinking param support (BerriAI#8853)
* fix(vertex_llm_base.py): handle credentials passed in as dictionary
* fix(router.py): support vertex credentials as json dict
* test(test_vertex.py): allows easier testing — mock anthropic thinking response for vertex ai
* test(vertex_ai_partner_models/): don't remove "@" from model — breaks anthropic cost calculation
* test: move testing
* fix: fix linting error
* fix: fix linting error
* fix(vertex_ai_partner_models/main.py): split @ for codestral model
* test: fix test
* fix: fix stripping "@" on mistral models
* fix: fix test
* test: fix test
1 parent 992e78d commit 88eedb2

File tree

15 files changed

+135
-45
lines changed

15 files changed

+135
-45
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,5 @@ litellm/proxy/_experimental/out/404.html
7777
litellm/proxy/_experimental/out/model_hub.html
7878
.mypy_cache/*
7979
litellm/proxy/application.log
80+
tests/llm_translation/vertex_test_account.json
81+
tests/llm_translation/test_vertex_key.json

litellm/llms/vertex_ai/batches/handler.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
)
1111
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
1212
from litellm.types.llms.openai import Batch, CreateBatchRequest
13-
from litellm.types.llms.vertex_ai import VertexAIBatchPredictionJob
13+
from litellm.types.llms.vertex_ai import (
14+
VERTEX_CREDENTIALS_TYPES,
15+
VertexAIBatchPredictionJob,
16+
)
1417

1518
from .transformation import VertexAIBatchTransformation
1619

@@ -25,7 +28,7 @@ def create_batch(
2528
_is_async: bool,
2629
create_batch_data: CreateBatchRequest,
2730
api_base: Optional[str],
28-
vertex_credentials: Optional[str],
31+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
2932
vertex_project: Optional[str],
3033
vertex_location: Optional[str],
3134
timeout: Union[float, httpx.Timeout],
@@ -130,7 +133,7 @@ def retrieve_batch(
130133
_is_async: bool,
131134
batch_id: str,
132135
api_base: Optional[str],
133-
vertex_credentials: Optional[str],
136+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
134137
vertex_project: Optional[str],
135138
vertex_location: Optional[str],
136139
timeout: Union[float, httpx.Timeout],

litellm/llms/vertex_ai/files/handler.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
1111
from litellm.types.llms.openai import CreateFileRequest, FileObject
12+
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
1213

1314
from .transformation import VertexAIFilesTransformation
1415

@@ -34,7 +35,7 @@ async def async_create_file(
3435
self,
3536
create_file_data: CreateFileRequest,
3637
api_base: Optional[str],
37-
vertex_credentials: Optional[str],
38+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
3839
vertex_project: Optional[str],
3940
vertex_location: Optional[str],
4041
timeout: Union[float, httpx.Timeout],
@@ -70,7 +71,7 @@ def create_file(
7071
_is_async: bool,
7172
create_file_data: CreateFileRequest,
7273
api_base: Optional[str],
73-
vertex_credentials: Optional[str],
74+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
7475
vertex_project: Optional[str],
7576
vertex_location: Optional[str],
7677
timeout: Union[float, httpx.Timeout],

litellm/llms/vertex_ai/fine_tuning/handler.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from litellm.types.fine_tuning import OpenAIFineTuningHyperparameters
1414
from litellm.types.llms.openai import FineTuningJobCreate
1515
from litellm.types.llms.vertex_ai import (
16+
VERTEX_CREDENTIALS_TYPES,
1617
FineTuneHyperparameters,
1718
FineTuneJobCreate,
1819
FineTunesupervisedTuningSpec,
@@ -222,7 +223,7 @@ def create_fine_tuning_job(
222223
create_fine_tuning_job_data: FineTuningJobCreate,
223224
vertex_project: Optional[str],
224225
vertex_location: Optional[str],
225-
vertex_credentials: Optional[str],
226+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
226227
api_base: Optional[str],
227228
timeout: Union[float, httpx.Timeout],
228229
kwargs: Optional[dict] = None,

litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
ChatCompletionUsageBlock,
4141
)
4242
from litellm.types.llms.vertex_ai import (
43+
VERTEX_CREDENTIALS_TYPES,
4344
Candidates,
4445
ContentType,
4546
FunctionCallingConfig,
@@ -930,7 +931,7 @@ async def async_streaming(
930931
client: Optional[AsyncHTTPHandler] = None,
931932
vertex_project: Optional[str] = None,
932933
vertex_location: Optional[str] = None,
933-
vertex_credentials: Optional[str] = None,
934+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
934935
gemini_api_key: Optional[str] = None,
935936
extra_headers: Optional[dict] = None,
936937
) -> CustomStreamWrapper:
@@ -1018,7 +1019,7 @@ async def async_completion(
10181019
client: Optional[AsyncHTTPHandler] = None,
10191020
vertex_project: Optional[str] = None,
10201021
vertex_location: Optional[str] = None,
1021-
vertex_credentials: Optional[str] = None,
1022+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
10221023
gemini_api_key: Optional[str] = None,
10231024
extra_headers: Optional[dict] = None,
10241025
) -> Union[ModelResponse, CustomStreamWrapper]:
@@ -1123,7 +1124,7 @@ def completion(
11231124
timeout: Optional[Union[float, httpx.Timeout]],
11241125
vertex_project: Optional[str],
11251126
vertex_location: Optional[str],
1126-
vertex_credentials: Optional[str],
1127+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
11271128
gemini_api_key: Optional[str],
11281129
litellm_params: dict,
11291130
logger_fn=None,

litellm/llms/vertex_ai/image_generation/image_generation_handler.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
get_async_httpx_client,
1212
)
1313
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
14+
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
1415
from litellm.types.utils import ImageResponse
1516

1617

@@ -44,7 +45,7 @@ def image_generation(
4445
prompt: str,
4546
vertex_project: Optional[str],
4647
vertex_location: Optional[str],
47-
vertex_credentials: Optional[str],
48+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
4849
model_response: ImageResponse,
4950
logging_obj: Any,
5051
model: Optional[
@@ -139,7 +140,7 @@ async def aimage_generation(
139140
prompt: str,
140141
vertex_project: Optional[str],
141142
vertex_location: Optional[str],
142-
vertex_credentials: Optional[str],
143+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
143144
model_response: litellm.ImageResponse,
144145
logging_obj: Any,
145146
model: Optional[

litellm/llms/vertex_ai/text_to_speech/text_to_speech_handler.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010
from litellm.llms.openai.openai import HttpxBinaryResponseContent
1111
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
12+
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
1213

1314

1415
class VertexInput(TypedDict, total=False):
@@ -45,7 +46,7 @@ def audio_speech(
4546
logging_obj,
4647
vertex_project: Optional[str],
4748
vertex_location: Optional[str],
48-
vertex_credentials: Optional[str],
49+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
4950
api_base: Optional[str],
5051
timeout: Union[float, httpx.Timeout],
5152
model: str,

litellm/llms/vertex_ai/vertex_ai_partner_models/main.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ def completion(
160160
url=default_api_base,
161161
)
162162

163-
model = model.split("@")[0]
163+
if "codestral" in model or "mistral" in model:
164+
model = model.split("@")[0]
164165

165166
if "codestral" in model and litellm_params.get("text_completion") is True:
166167
optional_params["model"] = model

litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def embedding(
4141
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
4242
vertex_project: Optional[str] = None,
4343
vertex_location: Optional[str] = None,
44-
vertex_credentials: Optional[str] = None,
44+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
4545
gemini_api_key: Optional[str] = None,
4646
extra_headers: Optional[dict] = None,
4747
) -> EmbeddingResponse:
@@ -148,7 +148,7 @@ async def async_embedding(
148148
client: Optional[AsyncHTTPHandler] = None,
149149
vertex_project: Optional[str] = None,
150150
vertex_location: Optional[str] = None,
151-
vertex_credentials: Optional[str] = None,
151+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
152152
gemini_api_key: Optional[str] = None,
153153
extra_headers: Optional[dict] = None,
154154
encoding=None,

litellm/llms/vertex_ai/vertex_llm_base.py

+31-23
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from litellm.litellm_core_utils.asyncify import asyncify
1313
from litellm.llms.base import BaseLLM
1414
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
15+
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
1516

1617
from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
1718

@@ -34,37 +35,44 @@ def get_vertex_region(self, vertex_region: Optional[str]) -> str:
3435
return vertex_region or "us-central1"
3536

3637
def load_auth(
37-
self, credentials: Optional[str], project_id: Optional[str]
38+
self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
3839
) -> Tuple[Any, str]:
3940
import google.auth as google_auth
4041
from google.auth import identity_pool
4142
from google.auth.transport.requests import (
4243
Request, # type: ignore[import-untyped]
4344
)
4445

45-
if credentials is not None and isinstance(credentials, str):
46+
if credentials is not None:
4647
import google.oauth2.service_account
4748

48-
verbose_logger.debug(
49-
"Vertex: Loading vertex credentials from %s", credentials
50-
)
51-
verbose_logger.debug(
52-
"Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
53-
credentials,
54-
os.path.exists(credentials),
55-
os.getcwd(),
56-
)
49+
if isinstance(credentials, str):
50+
verbose_logger.debug(
51+
"Vertex: Loading vertex credentials from %s", credentials
52+
)
53+
verbose_logger.debug(
54+
"Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
55+
credentials,
56+
os.path.exists(credentials),
57+
os.getcwd(),
58+
)
5759

58-
try:
59-
if os.path.exists(credentials):
60-
json_obj = json.load(open(credentials))
61-
else:
62-
json_obj = json.loads(credentials)
63-
except Exception:
64-
raise Exception(
65-
"Unable to load vertex credentials from environment. Got={}".format(
66-
credentials
60+
try:
61+
if os.path.exists(credentials):
62+
json_obj = json.load(open(credentials))
63+
else:
64+
json_obj = json.loads(credentials)
65+
except Exception:
66+
raise Exception(
67+
"Unable to load vertex credentials from environment. Got={}".format(
68+
credentials
69+
)
6770
)
71+
elif isinstance(credentials, dict):
72+
json_obj = credentials
73+
else:
74+
raise ValueError(
75+
"Invalid credentials type: {}".format(type(credentials))
6876
)
6977

7078
# Check if the JSON object contains Workload Identity Federation configuration
@@ -109,7 +117,7 @@ def refresh_auth(self, credentials: Any) -> None:
109117

110118
def _ensure_access_token(
111119
self,
112-
credentials: Optional[str],
120+
credentials: Optional[VERTEX_CREDENTIALS_TYPES],
113121
project_id: Optional[str],
114122
custom_llm_provider: Literal[
115123
"vertex_ai", "vertex_ai_beta", "gemini"
@@ -202,7 +210,7 @@ def _get_token_and_url(
202210
gemini_api_key: Optional[str],
203211
vertex_project: Optional[str],
204212
vertex_location: Optional[str],
205-
vertex_credentials: Optional[str],
213+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
206214
stream: Optional[bool],
207215
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
208216
api_base: Optional[str],
@@ -253,7 +261,7 @@ def _get_token_and_url(
253261

254262
async def _ensure_access_token_async(
255263
self,
256-
credentials: Optional[str],
264+
credentials: Optional[VERTEX_CREDENTIALS_TYPES],
257265
project_id: Optional[str],
258266
custom_llm_provider: Literal[
259267
"vertex_ai", "vertex_ai_beta", "gemini"

litellm/proxy/vertex_ai_endpoints/vertex_passthrough_router.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
77
VertexPassThroughCredentials,
88
)
9+
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
910

1011

1112
class VertexPassThroughRouter:
@@ -58,7 +59,7 @@ def add_vertex_credentials(
5859
self,
5960
project_id: str,
6061
location: str,
61-
vertex_credentials: str,
62+
vertex_credentials: VERTEX_CREDENTIALS_TYPES,
6263
):
6364
"""
6465
Add the vertex credentials for the given project-id, location

litellm/types/llms/vertex_ai.py

+3
Original file line numberDiff line numberDiff line change
@@ -481,3 +481,6 @@ class VertexBatchPredictionResponse(TypedDict, total=False):
481481
createTime: str
482482
updateTime: str
483483
modelVersionId: str
484+
485+
486+
VERTEX_CREDENTIALS_TYPES = Union[str, Dict[str, str]]

litellm/types/passthrough_endpoints/vertex_ai.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from pydantic import BaseModel
88

9+
from ..llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
10+
911

1012
class VertexPassThroughCredentials(BaseModel):
1113
# Example: vertex_project = "my-project-123"
@@ -15,4 +17,4 @@ class VertexPassThroughCredentials(BaseModel):
1517
vertex_location: Optional[str] = None
1618

1719
# Example: vertex_credentials = "/path/to/credentials.json" or "os.environ/GOOGLE_CREDS"
18-
vertex_credentials: Optional[str] = None
20+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None

litellm/types/router.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ..exceptions import RateLimitError
1919
from .completion import CompletionRequest
2020
from .embedding import EmbeddingRequest
21+
from .llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
2122
from .utils import ModelResponse, ProviderSpecificModelInfo
2223

2324

@@ -171,7 +172,7 @@ class GenericLiteLLMParams(BaseModel):
171172
## VERTEX AI ##
172173
vertex_project: Optional[str] = None
173174
vertex_location: Optional[str] = None
174-
vertex_credentials: Optional[str] = None
175+
vertex_credentials: Optional[Union[str, dict]] = None
175176
## AWS BEDROCK / SAGEMAKER ##
176177
aws_access_key_id: Optional[str] = None
177178
aws_secret_access_key: Optional[str] = None
@@ -213,7 +214,7 @@ def __init__(
213214
## VERTEX AI ##
214215
vertex_project: Optional[str] = None,
215216
vertex_location: Optional[str] = None,
216-
vertex_credentials: Optional[str] = None,
217+
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
217218
## AWS BEDROCK / SAGEMAKER ##
218219
aws_access_key_id: Optional[str] = None,
219220
aws_secret_access_key: Optional[str] = None,

0 commit comments

Comments
 (0)