Skip to content

Commit e489273

Browse files
authored
fix gemini cli by actually streaming the response (#15264)
* fix gemini cli by actually streaming the response
* fix cost tracking
* fix test
1 parent 73f9671 commit e489273

File tree

2 files changed

+160
-12
lines changed

2 files changed

+160
-12
lines changed

litellm/proxy/google_endpoints/endpoints.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,23 @@ async def google_generate_content(
2525
fastapi_response: Response,
2626
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
2727
):
28-
from litellm.proxy.proxy_server import llm_router
28+
from litellm.proxy.proxy_server import llm_router, general_settings, proxy_config, version
29+
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
2930

3031
data = await _read_request_body(request=request)
3132
if "model" not in data:
3233
data["model"] = model_name
34+
35+
# Add user authentication metadata for cost tracking
36+
data = await add_litellm_data_to_request(
37+
data=data,
38+
request=request,
39+
user_api_key_dict=user_api_key_dict,
40+
proxy_config=proxy_config,
41+
general_settings=general_settings,
42+
version=version,
43+
)
44+
3345
# call router
3446
if llm_router is None:
3547
raise HTTPException(status_code=500, detail="Router not initialized")
@@ -51,7 +63,8 @@ async def google_stream_generate_content(
5163
fastapi_response: Response,
5264
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
5365
):
54-
from litellm.proxy.proxy_server import llm_router
66+
from litellm.proxy.proxy_server import llm_router, general_settings, proxy_config, version
67+
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
5568

5669
data = await _read_request_body(request=request)
5770

@@ -60,10 +73,20 @@ async def google_stream_generate_content(
6073

6174
data["stream"] = True # enforce streaming for this endpoint
6275

76+
# Add user authentication metadata for cost tracking
77+
data = await add_litellm_data_to_request(
78+
data=data,
79+
request=request,
80+
user_api_key_dict=user_api_key_dict,
81+
proxy_config=proxy_config,
82+
general_settings=general_settings,
83+
version=version,
84+
)
85+
6386
# call router
6487
if llm_router is None:
6588
raise HTTPException(status_code=500, detail="Router not initialized")
66-
response = await llm_router.agenerate_content(**data)
89+
response = await llm_router.agenerate_content_stream(**data)
6790

6891
# Check if response is an async iterator (streaming response)
6992
if hasattr(response, "__aiter__"):

tests/test_litellm/proxy/google_endpoints/test_google_api_endpoints.py

Lines changed: 134 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,14 @@ def test_google_stream_generate_content_endpoint():
6060
# Create a test client
6161
client = TestClient(google_router)
6262

63-
# Mock the router's agenerate_content method to return a stream
64-
mock_stream = AsyncMock()
65-
mock_stream.__aiter__ = lambda self: mock_stream
66-
mock_stream.__anext__.side_effect = StopAsyncIteration
63+
# Mock the router's agenerate_content_stream method to return a stream
64+
async def mock_stream_generator():
65+
yield 'data: {"test": "stream_chunk_1"}\n\n'
66+
yield 'data: {"test": "stream_chunk_2"}\n\n'
67+
yield "data: [DONE]\n\n"
6768

6869
with patch("litellm.proxy.proxy_server.llm_router") as mock_router:
69-
mock_router.agenerate_content = AsyncMock(return_value=mock_stream)
70+
mock_router.agenerate_content_stream = AsyncMock(return_value=mock_stream_generator())
7071

7172
# Send a request to the endpoint
7273
response = client.post(
@@ -79,9 +80,133 @@ def test_google_stream_generate_content_endpoint():
7980
# Verify the response
8081
assert response.status_code == 200
8182

82-
# Verify that agenerate_content was called with correct parameters
83-
mock_router.agenerate_content.assert_called_once()
84-
call_args = mock_router.agenerate_content.call_args
83+
# Verify that agenerate_content_stream was called with correct parameters
84+
mock_router.agenerate_content_stream.assert_called_once()
85+
call_args = mock_router.agenerate_content_stream.call_args
8586
assert call_args[1]["stream"] is True
8687
assert call_args[1]["model"] == "test-model"
87-
assert call_args[1]["contents"] == [{"role": "user", "parts": [{"text": "Hello"}]}]
88+
assert call_args[1]["contents"] == [{"role": "user", "parts": [{"text": "Hello"}]}]
89+
90+
91+
def test_google_generate_content_with_cost_tracking_metadata():
92+
"""Test that the google_generate_content endpoint includes user metadata for cost tracking"""
93+
try:
94+
from fastapi.testclient import TestClient
95+
from litellm.proxy.google_endpoints.endpoints import router as google_router
96+
from litellm.proxy._types import UserAPIKeyAuth
97+
except ImportError as e:
98+
pytest.skip(f"Skipping test due to missing dependency: {e}")
99+
100+
# Create a test client
101+
client = TestClient(google_router)
102+
103+
# Mock all required proxy server dependencies
104+
with patch("litellm.proxy.proxy_server.llm_router") as mock_router, \
105+
patch("litellm.proxy.proxy_server.general_settings", {}), \
106+
patch("litellm.proxy.proxy_server.proxy_config") as mock_proxy_config, \
107+
patch("litellm.proxy.proxy_server.version", "1.0.0"), \
108+
patch("litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request") as mock_add_data:
109+
110+
mock_router.agenerate_content = AsyncMock(return_value={"test": "response"})
111+
112+
# Mock add_litellm_data_to_request to return data with metadata
113+
async def mock_add_litellm_data(data, request, user_api_key_dict, proxy_config, general_settings, version):
114+
# Simulate adding user metadata
115+
data["litellm_metadata"] = {
116+
"user_api_key_user_id": "test-user-id",
117+
"user_api_key_team_id": "test-team-id",
118+
"user_api_key": "hashed-key",
119+
}
120+
return data
121+
122+
mock_add_data.side_effect = mock_add_litellm_data
123+
124+
# Send a request to the endpoint
125+
response = client.post(
126+
"/v1beta/models/test-model:generateContent",
127+
json={
128+
"contents": [{"role": "user", "parts": [{"text": "Hello"}]}]
129+
},
130+
headers={"Authorization": "Bearer sk-test-key"}
131+
)
132+
133+
# Verify the response
134+
assert response.status_code == 200
135+
136+
# Verify that add_litellm_data_to_request was called
137+
mock_add_data.assert_called_once()
138+
139+
# Verify that agenerate_content was called with metadata
140+
mock_router.agenerate_content.assert_called_once()
141+
call_args = mock_router.agenerate_content.call_args
142+
called_data = call_args[1]
143+
144+
# Verify that litellm_metadata exists and contains user information
145+
assert "litellm_metadata" in called_data
146+
assert called_data["litellm_metadata"]["user_api_key_user_id"] == "test-user-id"
147+
assert called_data["litellm_metadata"]["user_api_key_team_id"] == "test-team-id"
148+
149+
150+
def test_google_stream_generate_content_with_cost_tracking_metadata():
151+
"""Test that the google_stream_generate_content endpoint includes user metadata for cost tracking"""
152+
try:
153+
from fastapi.testclient import TestClient
154+
from litellm.proxy.google_endpoints.endpoints import router as google_router
155+
except ImportError as e:
156+
pytest.skip(f"Skipping test due to missing dependency: {e}")
157+
158+
# Create a test client
159+
client = TestClient(google_router)
160+
161+
# Mock the router's agenerate_content_stream method to return a stream
162+
mock_stream = AsyncMock()
163+
mock_stream.__aiter__ = lambda self: mock_stream
164+
mock_stream.__anext__.side_effect = StopAsyncIteration
165+
166+
# Mock all required proxy server dependencies
167+
with patch("litellm.proxy.proxy_server.llm_router") as mock_router, \
168+
patch("litellm.proxy.proxy_server.general_settings", {}), \
169+
patch("litellm.proxy.proxy_server.proxy_config") as mock_proxy_config, \
170+
patch("litellm.proxy.proxy_server.version", "1.0.0"), \
171+
patch("litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request") as mock_add_data:
172+
173+
mock_router.agenerate_content_stream = AsyncMock(return_value=mock_stream)
174+
175+
# Mock add_litellm_data_to_request to return data with metadata
176+
async def mock_add_litellm_data(data, request, user_api_key_dict, proxy_config, general_settings, version):
177+
# Simulate adding user metadata
178+
data["litellm_metadata"] = {
179+
"user_api_key_user_id": "test-user-id",
180+
"user_api_key_team_id": "test-team-id",
181+
"user_api_key": "hashed-key",
182+
}
183+
return data
184+
185+
mock_add_data.side_effect = mock_add_litellm_data
186+
187+
# Send a request to the endpoint
188+
response = client.post(
189+
"/v1beta/models/test-model:streamGenerateContent",
190+
json={
191+
"contents": [{"role": "user", "parts": [{"text": "Hello"}]}]
192+
},
193+
headers={"Authorization": "Bearer sk-test-key"}
194+
)
195+
196+
# Verify the response
197+
assert response.status_code == 200
198+
199+
# Verify that add_litellm_data_to_request was called
200+
mock_add_data.assert_called_once()
201+
202+
# Verify that agenerate_content_stream was called with metadata
203+
mock_router.agenerate_content_stream.assert_called_once()
204+
call_args = mock_router.agenerate_content_stream.call_args
205+
called_data = call_args[1]
206+
207+
# Verify that litellm_metadata exists and contains user information
208+
assert "litellm_metadata" in called_data
209+
assert called_data["litellm_metadata"]["user_api_key_user_id"] == "test-user-id"
210+
assert called_data["litellm_metadata"]["user_api_key_team_id"] == "test-team-id"
211+
# Verify stream is set to True
212+
assert called_data["stream"] is True

0 commit comments

Comments (0)