12
12
from litellm .litellm_core_utils .asyncify import asyncify
13
13
from litellm .llms .base import BaseLLM
14
14
from litellm .llms .custom_httpx .http_handler import AsyncHTTPHandler
15
+ from litellm .types .llms .vertex_ai import VERTEX_CREDENTIALS_TYPES
15
16
16
17
from .common_utils import _get_gemini_url , _get_vertex_url , all_gemini_url_modes
17
18
@@ -34,37 +35,44 @@ def get_vertex_region(self, vertex_region: Optional[str]) -> str:
34
35
return vertex_region or "us-central1"
35
36
36
37
def load_auth (
37
- self , credentials : Optional [str ], project_id : Optional [str ]
38
+ self , credentials : Optional [VERTEX_CREDENTIALS_TYPES ], project_id : Optional [str ]
38
39
) -> Tuple [Any , str ]:
39
40
import google .auth as google_auth
40
41
from google .auth import identity_pool
41
42
from google .auth .transport .requests import (
42
43
Request , # type: ignore[import-untyped]
43
44
)
44
45
45
- if credentials is not None and isinstance ( credentials , str ) :
46
+ if credentials is not None :
46
47
import google .oauth2 .service_account
47
48
48
- verbose_logger .debug (
49
- "Vertex: Loading vertex credentials from %s" , credentials
50
- )
51
- verbose_logger .debug (
52
- "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s" ,
53
- credentials ,
54
- os .path .exists (credentials ),
55
- os .getcwd (),
56
- )
49
+ if isinstance (credentials , str ):
50
+ verbose_logger .debug (
51
+ "Vertex: Loading vertex credentials from %s" , credentials
52
+ )
53
+ verbose_logger .debug (
54
+ "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s" ,
55
+ credentials ,
56
+ os .path .exists (credentials ),
57
+ os .getcwd (),
58
+ )
57
59
58
- try :
59
- if os .path .exists (credentials ):
60
- json_obj = json .load (open (credentials ))
61
- else :
62
- json_obj = json .loads (credentials )
63
- except Exception :
64
- raise Exception (
65
- "Unable to load vertex credentials from environment. Got={}" .format (
66
- credentials
60
+ try :
61
+ if os .path .exists (credentials ):
62
+ json_obj = json .load (open (credentials ))
63
+ else :
64
+ json_obj = json .loads (credentials )
65
+ except Exception :
66
+ raise Exception (
67
+ "Unable to load vertex credentials from environment. Got={}" .format (
68
+ credentials
69
+ )
67
70
)
71
+ elif isinstance (credentials , dict ):
72
+ json_obj = credentials
73
+ else :
74
+ raise ValueError (
75
+ "Invalid credentials type: {}" .format (type (credentials ))
68
76
)
69
77
70
78
# Check if the JSON object contains Workload Identity Federation configuration
@@ -109,7 +117,7 @@ def refresh_auth(self, credentials: Any) -> None:
109
117
110
118
def _ensure_access_token (
111
119
self ,
112
- credentials : Optional [str ],
120
+ credentials : Optional [VERTEX_CREDENTIALS_TYPES ],
113
121
project_id : Optional [str ],
114
122
custom_llm_provider : Literal [
115
123
"vertex_ai" , "vertex_ai_beta" , "gemini"
@@ -202,7 +210,7 @@ def _get_token_and_url(
202
210
gemini_api_key : Optional [str ],
203
211
vertex_project : Optional [str ],
204
212
vertex_location : Optional [str ],
205
- vertex_credentials : Optional [str ],
213
+ vertex_credentials : Optional [VERTEX_CREDENTIALS_TYPES ],
206
214
stream : Optional [bool ],
207
215
custom_llm_provider : Literal ["vertex_ai" , "vertex_ai_beta" , "gemini" ],
208
216
api_base : Optional [str ],
@@ -253,7 +261,7 @@ def _get_token_and_url(
253
261
254
262
async def _ensure_access_token_async (
255
263
self ,
256
- credentials : Optional [str ],
264
+ credentials : Optional [VERTEX_CREDENTIALS_TYPES ],
257
265
project_id : Optional [str ],
258
266
custom_llm_provider : Literal [
259
267
"vertex_ai" , "vertex_ai_beta" , "gemini"
0 commit comments