Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions litellm/proxy/google_endpoints/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,23 @@ async def google_generate_content(
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
from litellm.proxy.proxy_server import llm_router
from litellm.proxy.proxy_server import llm_router, general_settings, proxy_config, version
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

data = await _read_request_body(request=request)
if "model" not in data:
data["model"] = model_name

# Add user authentication metadata for cost tracking
data = await add_litellm_data_to_request(
data=data,
request=request,
user_api_key_dict=user_api_key_dict,
proxy_config=proxy_config,
general_settings=general_settings,
version=version,
)

# call router
if llm_router is None:
raise HTTPException(status_code=500, detail="Router not initialized")
Expand All @@ -51,7 +63,8 @@ async def google_stream_generate_content(
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
from litellm.proxy.proxy_server import llm_router
from litellm.proxy.proxy_server import llm_router, general_settings, proxy_config, version
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

data = await _read_request_body(request=request)

Expand All @@ -60,10 +73,20 @@ async def google_stream_generate_content(

data["stream"] = True # enforce streaming for this endpoint

# Add user authentication metadata for cost tracking
data = await add_litellm_data_to_request(
data=data,
request=request,
user_api_key_dict=user_api_key_dict,
proxy_config=proxy_config,
general_settings=general_settings,
version=version,
)

# call router
if llm_router is None:
raise HTTPException(status_code=500, detail="Router not initialized")
response = await llm_router.agenerate_content(**data)
response = await llm_router.agenerate_content_stream(**data)

# Check if response is an async iterator (streaming response)
if hasattr(response, "__aiter__"):
Expand Down
143 changes: 134 additions & 9 deletions tests/test_litellm/proxy/google_endpoints/test_google_api_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,14 @@ def test_google_stream_generate_content_endpoint():
# Create a test client
client = TestClient(google_router)

# Mock the router's agenerate_content method to return a stream
mock_stream = AsyncMock()
mock_stream.__aiter__ = lambda self: mock_stream
mock_stream.__anext__.side_effect = StopAsyncIteration
# Mock the router's agenerate_content_stream method to return a stream
async def mock_stream_generator():
yield 'data: {"test": "stream_chunk_1"}\n\n'
yield 'data: {"test": "stream_chunk_2"}\n\n'
yield "data: [DONE]\n\n"

with patch("litellm.proxy.proxy_server.llm_router") as mock_router:
mock_router.agenerate_content = AsyncMock(return_value=mock_stream)
mock_router.agenerate_content_stream = AsyncMock(return_value=mock_stream_generator())

# Send a request to the endpoint
response = client.post(
Expand All @@ -79,9 +80,133 @@ def test_google_stream_generate_content_endpoint():
# Verify the response
assert response.status_code == 200

# Verify that agenerate_content was called with correct parameters
mock_router.agenerate_content.assert_called_once()
call_args = mock_router.agenerate_content.call_args
# Verify that agenerate_content_stream was called with correct parameters
mock_router.agenerate_content_stream.assert_called_once()
call_args = mock_router.agenerate_content_stream.call_args
assert call_args[1]["stream"] is True
assert call_args[1]["model"] == "test-model"
assert call_args[1]["contents"] == [{"role": "user", "parts": [{"text": "Hello"}]}]
assert call_args[1]["contents"] == [{"role": "user", "parts": [{"text": "Hello"}]}]


def test_google_generate_content_with_cost_tracking_metadata():
    """Test that the google_generate_content endpoint includes user metadata for cost tracking.

    Verifies that the endpoint routes the request through
    ``add_litellm_data_to_request`` and forwards the resulting
    ``litellm_metadata`` (user / team / key identifiers) to
    ``llm_router.agenerate_content``.
    """
    try:
        from fastapi.testclient import TestClient
        from litellm.proxy.google_endpoints.endpoints import router as google_router
        from litellm.proxy._types import UserAPIKeyAuth
    except ImportError as e:
        pytest.skip(f"Skipping test due to missing dependency: {e}")

    # Create a test client
    client = TestClient(google_router)

    # Mock all required proxy server dependencies.
    # NOTE: proxy_config only needs to exist for the endpoint's import; the
    # test never inspects it, so no `as` binding is kept (avoids an unused var).
    with patch("litellm.proxy.proxy_server.llm_router") as mock_router, \
            patch("litellm.proxy.proxy_server.general_settings", {}), \
            patch("litellm.proxy.proxy_server.proxy_config"), \
            patch("litellm.proxy.proxy_server.version", "1.0.0"), \
            patch("litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request") as mock_add_data:

        mock_router.agenerate_content = AsyncMock(return_value={"test": "response"})

        # Simulate add_litellm_data_to_request attaching user metadata to the
        # request payload, as the real implementation does for cost tracking.
        async def mock_add_litellm_data(data, request, user_api_key_dict, proxy_config, general_settings, version):
            data["litellm_metadata"] = {
                "user_api_key_user_id": "test-user-id",
                "user_api_key_team_id": "test-team-id",
                "user_api_key": "hashed-key",
            }
            return data

        mock_add_data.side_effect = mock_add_litellm_data

        # Send a request to the endpoint
        response = client.post(
            "/v1beta/models/test-model:generateContent",
            json={
                "contents": [{"role": "user", "parts": [{"text": "Hello"}]}]
            },
            headers={"Authorization": "Bearer sk-test-key"}
        )

        # Verify the response
        assert response.status_code == 200

        # Verify that add_litellm_data_to_request was called
        mock_add_data.assert_called_once()

        # Verify that agenerate_content was called with metadata
        mock_router.agenerate_content.assert_called_once()
        called_data = mock_router.agenerate_content.call_args[1]

        # Verify that litellm_metadata exists and contains user information
        assert "litellm_metadata" in called_data
        assert called_data["litellm_metadata"]["user_api_key_user_id"] == "test-user-id"
        assert called_data["litellm_metadata"]["user_api_key_team_id"] == "test-team-id"


def test_google_stream_generate_content_with_cost_tracking_metadata():
    """Test that the google_stream_generate_content endpoint includes user metadata for cost tracking.

    Verifies that the streaming endpoint routes the request through
    ``add_litellm_data_to_request``, forwards ``litellm_metadata`` to
    ``llm_router.agenerate_content_stream``, and enforces ``stream=True``.
    """
    try:
        from fastapi.testclient import TestClient
        from litellm.proxy.google_endpoints.endpoints import router as google_router
    except ImportError as e:
        pytest.skip(f"Skipping test due to missing dependency: {e}")

    # Create a test client
    client = TestClient(google_router)

    # Mock the router's agenerate_content_stream method with an async
    # generator, matching the streaming-mock style used elsewhere in this file.
    async def mock_stream_generator():
        yield 'data: {"test": "stream_chunk_1"}\n\n'
        yield "data: [DONE]\n\n"

    # Mock all required proxy server dependencies.
    # NOTE: proxy_config only needs to exist for the endpoint's import; the
    # test never inspects it, so no `as` binding is kept (avoids an unused var).
    with patch("litellm.proxy.proxy_server.llm_router") as mock_router, \
            patch("litellm.proxy.proxy_server.general_settings", {}), \
            patch("litellm.proxy.proxy_server.proxy_config"), \
            patch("litellm.proxy.proxy_server.version", "1.0.0"), \
            patch("litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request") as mock_add_data:

        mock_router.agenerate_content_stream = AsyncMock(return_value=mock_stream_generator())

        # Simulate add_litellm_data_to_request attaching user metadata to the
        # request payload, as the real implementation does for cost tracking.
        async def mock_add_litellm_data(data, request, user_api_key_dict, proxy_config, general_settings, version):
            data["litellm_metadata"] = {
                "user_api_key_user_id": "test-user-id",
                "user_api_key_team_id": "test-team-id",
                "user_api_key": "hashed-key",
            }
            return data

        mock_add_data.side_effect = mock_add_litellm_data

        # Send a request to the endpoint
        response = client.post(
            "/v1beta/models/test-model:streamGenerateContent",
            json={
                "contents": [{"role": "user", "parts": [{"text": "Hello"}]}]
            },
            headers={"Authorization": "Bearer sk-test-key"}
        )

        # Verify the response
        assert response.status_code == 200

        # Verify that add_litellm_data_to_request was called
        mock_add_data.assert_called_once()

        # Verify that agenerate_content_stream was called with metadata
        mock_router.agenerate_content_stream.assert_called_once()
        called_data = mock_router.agenerate_content_stream.call_args[1]

        # Verify that litellm_metadata exists and contains user information
        assert "litellm_metadata" in called_data
        assert called_data["litellm_metadata"]["user_api_key_user_id"] == "test-user-id"
        assert called_data["litellm_metadata"]["user_api_key_team_id"] == "test-team-id"
        # Verify stream is set to True
        assert called_data["stream"] is True
Loading