@@ -60,13 +60,14 @@ def test_google_stream_generate_content_endpoint():
6060 # Create a test client
6161 client = TestClient (google_router )
6262
63- # Mock the router's agenerate_content method to return a stream
64- mock_stream = AsyncMock ()
65- mock_stream .__aiter__ = lambda self : mock_stream
66- mock_stream .__anext__ .side_effect = StopAsyncIteration
63+ # Mock the router's agenerate_content_stream method to return a stream
64+ async def mock_stream_generator ():
65+ yield 'data: {"test": "stream_chunk_1"}\n \n '
66+ yield 'data: {"test": "stream_chunk_2"}\n \n '
67+ yield "data: [DONE]\n \n "
6768
6869 with patch ("litellm.proxy.proxy_server.llm_router" ) as mock_router :
69- mock_router .agenerate_content = AsyncMock (return_value = mock_stream )
70+ mock_router .agenerate_content_stream = AsyncMock (return_value = mock_stream_generator () )
7071
7172 # Send a request to the endpoint
7273 response = client .post (
@@ -79,9 +80,133 @@ def test_google_stream_generate_content_endpoint():
7980 # Verify the response
8081 assert response .status_code == 200
8182
82- # Verify that agenerate_content was called with correct parameters
83- mock_router .agenerate_content .assert_called_once ()
84- call_args = mock_router .agenerate_content .call_args
83+ # Verify that agenerate_content_stream was called with correct parameters
84+ mock_router .agenerate_content_stream .assert_called_once ()
85+ call_args = mock_router .agenerate_content_stream .call_args
8586 assert call_args [1 ]["stream" ] is True
8687 assert call_args [1 ]["model" ] == "test-model"
87- assert call_args [1 ]["contents" ] == [{"role" : "user" , "parts" : [{"text" : "Hello" }]}]
88+ assert call_args [1 ]["contents" ] == [{"role" : "user" , "parts" : [{"text" : "Hello" }]}]
89+
90+
91+ def test_google_generate_content_with_cost_tracking_metadata ():
92+ """Test that the google_generate_content endpoint includes user metadata for cost tracking"""
93+ try :
94+ from fastapi .testclient import TestClient
95+ from litellm .proxy .google_endpoints .endpoints import router as google_router
96+ from litellm .proxy ._types import UserAPIKeyAuth
97+ except ImportError as e :
98+ pytest .skip (f"Skipping test due to missing dependency: { e } " )
99+
100+ # Create a test client
101+ client = TestClient (google_router )
102+
103+ # Mock all required proxy server dependencies
104+ with patch ("litellm.proxy.proxy_server.llm_router" ) as mock_router , \
105+ patch ("litellm.proxy.proxy_server.general_settings" , {}), \
106+ patch ("litellm.proxy.proxy_server.proxy_config" ) as mock_proxy_config , \
107+ patch ("litellm.proxy.proxy_server.version" , "1.0.0" ), \
108+ patch ("litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request" ) as mock_add_data :
109+
110+ mock_router .agenerate_content = AsyncMock (return_value = {"test" : "response" })
111+
112+ # Mock add_litellm_data_to_request to return data with metadata
113+ async def mock_add_litellm_data (data , request , user_api_key_dict , proxy_config , general_settings , version ):
114+ # Simulate adding user metadata
115+ data ["litellm_metadata" ] = {
116+ "user_api_key_user_id" : "test-user-id" ,
117+ "user_api_key_team_id" : "test-team-id" ,
118+ "user_api_key" : "hashed-key" ,
119+ }
120+ return data
121+
122+ mock_add_data .side_effect = mock_add_litellm_data
123+
124+ # Send a request to the endpoint
125+ response = client .post (
126+ "/v1beta/models/test-model:generateContent" ,
127+ json = {
128+ "contents" : [{"role" : "user" , "parts" : [{"text" : "Hello" }]}]
129+ },
130+ headers = {"Authorization" : "Bearer sk-test-key" }
131+ )
132+
133+ # Verify the response
134+ assert response .status_code == 200
135+
136+ # Verify that add_litellm_data_to_request was called
137+ mock_add_data .assert_called_once ()
138+
139+ # Verify that agenerate_content was called with metadata
140+ mock_router .agenerate_content .assert_called_once ()
141+ call_args = mock_router .agenerate_content .call_args
142+ called_data = call_args [1 ]
143+
144+ # Verify that litellm_metadata exists and contains user information
145+ assert "litellm_metadata" in called_data
146+ assert called_data ["litellm_metadata" ]["user_api_key_user_id" ] == "test-user-id"
147+ assert called_data ["litellm_metadata" ]["user_api_key_team_id" ] == "test-team-id"
148+
149+
150+ def test_google_stream_generate_content_with_cost_tracking_metadata ():
151+ """Test that the google_stream_generate_content endpoint includes user metadata for cost tracking"""
152+ try :
153+ from fastapi .testclient import TestClient
154+ from litellm .proxy .google_endpoints .endpoints import router as google_router
155+ except ImportError as e :
156+ pytest .skip (f"Skipping test due to missing dependency: { e } " )
157+
158+ # Create a test client
159+ client = TestClient (google_router )
160+
161+ # Mock the router's agenerate_content_stream method to return a stream
162+ mock_stream = AsyncMock ()
163+ mock_stream .__aiter__ = lambda self : mock_stream
164+ mock_stream .__anext__ .side_effect = StopAsyncIteration
165+
166+ # Mock all required proxy server dependencies
167+ with patch ("litellm.proxy.proxy_server.llm_router" ) as mock_router , \
168+ patch ("litellm.proxy.proxy_server.general_settings" , {}), \
169+ patch ("litellm.proxy.proxy_server.proxy_config" ) as mock_proxy_config , \
170+ patch ("litellm.proxy.proxy_server.version" , "1.0.0" ), \
171+ patch ("litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request" ) as mock_add_data :
172+
173+ mock_router .agenerate_content_stream = AsyncMock (return_value = mock_stream )
174+
175+ # Mock add_litellm_data_to_request to return data with metadata
176+ async def mock_add_litellm_data (data , request , user_api_key_dict , proxy_config , general_settings , version ):
177+ # Simulate adding user metadata
178+ data ["litellm_metadata" ] = {
179+ "user_api_key_user_id" : "test-user-id" ,
180+ "user_api_key_team_id" : "test-team-id" ,
181+ "user_api_key" : "hashed-key" ,
182+ }
183+ return data
184+
185+ mock_add_data .side_effect = mock_add_litellm_data
186+
187+ # Send a request to the endpoint
188+ response = client .post (
189+ "/v1beta/models/test-model:streamGenerateContent" ,
190+ json = {
191+ "contents" : [{"role" : "user" , "parts" : [{"text" : "Hello" }]}]
192+ },
193+ headers = {"Authorization" : "Bearer sk-test-key" }
194+ )
195+
196+ # Verify the response
197+ assert response .status_code == 200
198+
199+ # Verify that add_litellm_data_to_request was called
200+ mock_add_data .assert_called_once ()
201+
202+ # Verify that agenerate_content_stream was called with metadata
203+ mock_router .agenerate_content_stream .assert_called_once ()
204+ call_args = mock_router .agenerate_content_stream .call_args
205+ called_data = call_args [1 ]
206+
207+ # Verify that litellm_metadata exists and contains user information
208+ assert "litellm_metadata" in called_data
209+ assert called_data ["litellm_metadata" ]["user_api_key_user_id" ] == "test-user-id"
210+ assert called_data ["litellm_metadata" ]["user_api_key_team_id" ] == "test-team-id"
211+ # Verify stream is set to True
212+ assert called_data ["stream" ] is True
0 commit comments