diff --git a/litellm/proxy/google_endpoints/endpoints.py b/litellm/proxy/google_endpoints/endpoints.py
index 51c6d5ab6344..b6c0046a8910 100644
--- a/litellm/proxy/google_endpoints/endpoints.py
+++ b/litellm/proxy/google_endpoints/endpoints.py
@@ -25,11 +25,23 @@ async def google_generate_content(
     fastapi_response: Response,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
-    from litellm.proxy.proxy_server import llm_router
+    from litellm.proxy.proxy_server import llm_router, general_settings, proxy_config, version
+    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
 
     data = await _read_request_body(request=request)
     if "model" not in data:
         data["model"] = model_name
+
+    # Add user authentication metadata for cost tracking
+    data = await add_litellm_data_to_request(
+        data=data,
+        request=request,
+        user_api_key_dict=user_api_key_dict,
+        proxy_config=proxy_config,
+        general_settings=general_settings,
+        version=version,
+    )
+
     # call router
     if llm_router is None:
         raise HTTPException(status_code=500, detail="Router not initialized")
@@ -51,7 +63,8 @@ async def google_stream_generate_content(
     fastapi_response: Response,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
-    from litellm.proxy.proxy_server import llm_router
+    from litellm.proxy.proxy_server import llm_router, general_settings, proxy_config, version
+    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
 
     data = await _read_request_body(request=request)
 
@@ -60,10 +73,20 @@ async def google_stream_generate_content(
 
     data["stream"] = True  # enforce streaming for this endpoint
 
+    # Add user authentication metadata for cost tracking
+    data = await add_litellm_data_to_request(
+        data=data,
+        request=request,
+        user_api_key_dict=user_api_key_dict,
+        proxy_config=proxy_config,
+        general_settings=general_settings,
+        version=version,
+    )
+
     # call router
     if llm_router is None:
         raise HTTPException(status_code=500, detail="Router not initialized")
-    response = await llm_router.agenerate_content(**data)
+    response = await llm_router.agenerate_content_stream(**data)
 
     # Check if response is an async iterator (streaming response)
     if hasattr(response, "__aiter__"):
diff --git a/tests/test_litellm/proxy/google_endpoints/test_google_api_endpoints.py b/tests/test_litellm/proxy/google_endpoints/test_google_api_endpoints.py
index 62e8aaf2794a..6e1c6e11be0b 100644
--- a/tests/test_litellm/proxy/google_endpoints/test_google_api_endpoints.py
+++ b/tests/test_litellm/proxy/google_endpoints/test_google_api_endpoints.py
@@ -60,13 +60,14 @@ def test_google_stream_generate_content_endpoint():
     # Create a test client
     client = TestClient(google_router)
 
-    # Mock the router's agenerate_content method to return a stream
-    mock_stream = AsyncMock()
-    mock_stream.__aiter__ = lambda self: mock_stream
-    mock_stream.__anext__.side_effect = StopAsyncIteration
+    # Mock the router's agenerate_content_stream method to return a stream
+    async def mock_stream_generator():
+        yield 'data: {"test": "stream_chunk_1"}\n\n'
+        yield 'data: {"test": "stream_chunk_2"}\n\n'
+        yield "data: [DONE]\n\n"
 
     with patch("litellm.proxy.proxy_server.llm_router") as mock_router:
-        mock_router.agenerate_content = AsyncMock(return_value=mock_stream)
+        mock_router.agenerate_content_stream = AsyncMock(return_value=mock_stream_generator())
 
         # Send a request to the endpoint
         response = client.post(
@@ -79,9 +80,133 @@ def test_google_stream_generate_content_endpoint():
 
         # Verify the response
         assert response.status_code == 200
-        # Verify that agenerate_content was called with correct parameters
-        mock_router.agenerate_content.assert_called_once()
-        call_args = mock_router.agenerate_content.call_args
+        # Verify that agenerate_content_stream was called with correct parameters
+        mock_router.agenerate_content_stream.assert_called_once()
+        call_args = mock_router.agenerate_content_stream.call_args
         assert call_args[1]["stream"] is True
         assert call_args[1]["model"] == "test-model"
-        assert call_args[1]["contents"] == [{"role": "user", "parts": [{"text": "Hello"}]}]
\ No newline at end of file
+        assert call_args[1]["contents"] == [{"role": "user", "parts": [{"text": "Hello"}]}]
+
+
+def test_google_generate_content_with_cost_tracking_metadata():
+    """Test that the google_generate_content endpoint includes user metadata for cost tracking"""
+    try:
+        from fastapi.testclient import TestClient
+        from litellm.proxy.google_endpoints.endpoints import router as google_router
+        from litellm.proxy._types import UserAPIKeyAuth
+    except ImportError as e:
+        pytest.skip(f"Skipping test due to missing dependency: {e}")
+
+    # Create a test client
+    client = TestClient(google_router)
+
+    # Mock all required proxy server dependencies
+    with patch("litellm.proxy.proxy_server.llm_router") as mock_router, \
+         patch("litellm.proxy.proxy_server.general_settings", {}), \
+         patch("litellm.proxy.proxy_server.proxy_config") as mock_proxy_config, \
+         patch("litellm.proxy.proxy_server.version", "1.0.0"), \
+         patch("litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request") as mock_add_data:
+
+        mock_router.agenerate_content = AsyncMock(return_value={"test": "response"})
+
+        # Mock add_litellm_data_to_request to return data with metadata
+        async def mock_add_litellm_data(data, request, user_api_key_dict, proxy_config, general_settings, version):
+            # Simulate adding user metadata
+            data["litellm_metadata"] = {
+                "user_api_key_user_id": "test-user-id",
+                "user_api_key_team_id": "test-team-id",
+                "user_api_key": "hashed-key",
+            }
+            return data
+
+        mock_add_data.side_effect = mock_add_litellm_data
+
+        # Send a request to the endpoint
+        response = client.post(
+            "/v1beta/models/test-model:generateContent",
+            json={
+                "contents": [{"role": "user", "parts": [{"text": "Hello"}]}]
+            },
+            headers={"Authorization": "Bearer sk-test-key"}
+        )
+
+        # Verify the response
+        assert response.status_code == 200
+
+        # Verify that add_litellm_data_to_request was called
+        mock_add_data.assert_called_once()
+
+        # Verify that agenerate_content was called with metadata
+        mock_router.agenerate_content.assert_called_once()
+        call_args = mock_router.agenerate_content.call_args
+        called_data = call_args[1]
+
+        # Verify that litellm_metadata exists and contains user information
+        assert "litellm_metadata" in called_data
+        assert called_data["litellm_metadata"]["user_api_key_user_id"] == "test-user-id"
+        assert called_data["litellm_metadata"]["user_api_key_team_id"] == "test-team-id"
+
+
+def test_google_stream_generate_content_with_cost_tracking_metadata():
+    """Test that the google_stream_generate_content endpoint includes user metadata for cost tracking"""
+    try:
+        from fastapi.testclient import TestClient
+        from litellm.proxy.google_endpoints.endpoints import router as google_router
+    except ImportError as e:
+        pytest.skip(f"Skipping test due to missing dependency: {e}")
+
+    # Create a test client
+    client = TestClient(google_router)
+
+    # Mock the router's agenerate_content_stream method to return a stream
+    mock_stream = AsyncMock()
+    mock_stream.__aiter__ = lambda self: mock_stream
+    mock_stream.__anext__.side_effect = StopAsyncIteration
+
+    # Mock all required proxy server dependencies
+    with patch("litellm.proxy.proxy_server.llm_router") as mock_router, \
+         patch("litellm.proxy.proxy_server.general_settings", {}), \
+         patch("litellm.proxy.proxy_server.proxy_config") as mock_proxy_config, \
+         patch("litellm.proxy.proxy_server.version", "1.0.0"), \
+         patch("litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request") as mock_add_data:
+
+        mock_router.agenerate_content_stream = AsyncMock(return_value=mock_stream)
+
+        # Mock add_litellm_data_to_request to return data with metadata
+        async def mock_add_litellm_data(data, request, user_api_key_dict, proxy_config, general_settings, version):
+            # Simulate adding user metadata
+            data["litellm_metadata"] = {
+                "user_api_key_user_id": "test-user-id",
+                "user_api_key_team_id": "test-team-id",
+                "user_api_key": "hashed-key",
+            }
+            return data
+
+        mock_add_data.side_effect = mock_add_litellm_data
+
+        # Send a request to the endpoint
+        response = client.post(
+            "/v1beta/models/test-model:streamGenerateContent",
+            json={
+                "contents": [{"role": "user", "parts": [{"text": "Hello"}]}]
+            },
+            headers={"Authorization": "Bearer sk-test-key"}
+        )
+
+        # Verify the response
+        assert response.status_code == 200
+
+        # Verify that add_litellm_data_to_request was called
+        mock_add_data.assert_called_once()
+
+        # Verify that agenerate_content_stream was called with metadata
+        mock_router.agenerate_content_stream.assert_called_once()
+        call_args = mock_router.agenerate_content_stream.call_args
+        called_data = call_args[1]
+
+        # Verify that litellm_metadata exists and contains user information
+        assert "litellm_metadata" in called_data
+        assert called_data["litellm_metadata"]["user_api_key_user_id"] == "test-user-id"
+        assert called_data["litellm_metadata"]["user_api_key_team_id"] == "test-team-id"
+        # Verify stream is set to True
+        assert called_data["stream"] is True
\ No newline at end of file
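
A minimal manual smoke test for the patched streaming route, as a sketch rather than part of the diff: the port (`localhost:4000`), model alias (`gemini-pro`), and API key are assumptions about a local proxy deployment, not values from this PR — substitute your own.

```python
# Hypothetical end-to-end check against a running LiteLLM proxy.
import requests

resp = requests.post(
    # streamGenerateContent is the route exercised by the tests above
    "http://localhost:4000/v1beta/models/gemini-pro:streamGenerateContent",
    headers={"Authorization": "Bearer sk-test-key"},
    json={"contents": [{"role": "user", "parts": [{"text": "Hello"}]}]},
    stream=True,  # read the SSE body incrementally instead of buffering it
)
resp.raise_for_status()
for line in resp.iter_lines(decode_unicode=True):
    if line:  # skip blank separator lines between SSE events
        print(line)  # each chunk arrives as a "data: {...}" line
```

If the metadata wiring works, such a request should be attributed to the key's user and team in the proxy's spend tracking — the same behavior the new tests assert at the router boundary via `litellm_metadata`.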