Merged: changes from all 19 commits
- 27b342a feat: Enhance error classification and handling in UsageManager and e… (Mirrowel, Jul 2, 2025)
- 44481e5 feat: Bump version to 0.6 in pyproject.toml for release (Mirrowel, Jul 2, 2025)
- bab8bf0 feat: Update key release logic in RotatingClient and UsageManager to … (Mirrowel, Jul 3, 2025)
- 8267102 feat: Update version to 0.61 in pyproject.toml for release (Mirrowel, Jul 3, 2025)
- 9c2c89a feat: Implement Gemini stream wrapper for improved JSON chunk handlin… (Mirrowel, Jul 3, 2025)
- 50f9490 Revert "feat: Implement Gemini stream wrapper for improved JSON chunk… (Mirrowel, Jul 3, 2025)
- 5fb237f fix: Handle missing cooldown values in UsageManager to prevent key lo… (Mirrowel, Jul 3, 2025)
- 576a3ec feat: Safety settings addition(only gemini for now) (Mirrowel, Jul 3, 2025)
- 2786fc8 feat: optimize retry logic (Mirrowel, Jul 3, 2025)
- 3f1021f feat: Add async context management to RotatingClient and improve reso… (Mirrowel, Jul 3, 2025)
- 72525d5 feat: update proxy to use the new async context management (Mirrowel, Jul 3, 2025)
- 5538104 feat: Implement NvidiaProvider for fetching models from NVIDIA API (Mirrowel, Jul 3, 2025)
- 33e95e6 feat: Enhance daily reset logic to ensure timezone-aware date handling (Mirrowel, Jul 4, 2025)
- 54f0f2c feat: Implement request logging for API responses and enhance streami… (Mirrowel, Jul 4, 2025)
- c3b2e49 feat: Add default thinking parameter handling for specific Gemini models (Mirrowel, Jul 4, 2025)
- dcaf0eb feat: Add request payload sanitization to remove unsupported paramete… (Mirrowel, Jul 4, 2025)
- 2d13d90 feat: Fix thinking parameter handling for specific Gemini models (Mirrowel, Jul 4, 2025)
- 0374254 feat: Enhance request logging to ensure proper UTF-8 encoding in log … (Mirrowel, Jul 4, 2025)
- 21ffa62 feat: Bump version to 0.6.7 in pyproject.toml (Mirrowel, Jul 4, 2025)
140 changes: 125 additions & 15 deletions src/proxy_app/main.py
@@ -1,24 +1,29 @@
 import os
+from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.responses import StreamingResponse
 from fastapi.security import APIKeyHeader
 from dotenv import load_dotenv
 import logging
 from pathlib import Path
 import sys
+import json
+from typing import AsyncGenerator, Any
 
 # Add the 'src' directory to the Python path to allow importing 'rotating_api_key_client'
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 
 from rotator_library import RotatingClient, PROVIDER_PLUGINS
+from .request_logger import log_request_response
 
 # Configure logging
-logging.basicConfig(level=logging.INFO) #-> moved to the rotator_library
+logging.basicConfig(level=logging.INFO)
 
 # Load environment variables from .env file
 load_dotenv()
 
 # --- Configuration ---
+ENABLE_REQUEST_LOGGING = False # Set to False to disable request/response logging
 PROXY_API_KEY = os.getenv("PROXY_API_KEY")
 if not PROXY_API_KEY:
     raise ValueError("PROXY_API_KEY environment variable not set.")
@@ -37,51 +42,152 @@
 if not api_keys:
     raise ValueError("No provider API keys found in environment variables.")
 
-# Initialize the rotating client
-rotating_client = RotatingClient(api_keys=api_keys)
+# --- Lifespan Management ---
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Manage the RotatingClient's lifecycle with the app's lifespan."""
+    app.state.rotating_client = RotatingClient(api_keys=api_keys)
+    print("RotatingClient initialized.")
+    yield
+    await app.state.rotating_client.close()
+    print("RotatingClient closed.")
 
 # --- FastAPI App Setup ---
-app = FastAPI()
+app = FastAPI(lifespan=lifespan)
 api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 
+def get_rotating_client(request: Request) -> RotatingClient:
+    """Dependency to get the rotating client instance from the app state."""
+    return request.app.state.rotating_client
+
 async def verify_api_key(auth: str = Depends(api_key_header)):
     """Dependency to verify the proxy API key."""
     if not auth or auth != f"Bearer {PROXY_API_KEY}":
         raise HTTPException(status_code=401, detail="Invalid or missing API Key")
     return auth
 
+async def streaming_response_wrapper(
+    request_data: dict,
+    response_stream: AsyncGenerator[str, None]
+) -> AsyncGenerator[str, None]:
+    """
+    Wraps a streaming response to log the full response after completion.
+    """
+    response_chunks = []
+    full_response = {}
+    try:
+        async for chunk_str in response_stream:
+            yield chunk_str
+            # Process chunk for logging
+            if chunk_str.strip() and chunk_str.startswith("data:"):
+                content = chunk_str[len("data:"):].strip()
+                if content != "[DONE]":
+                    try:
+                        chunk_data = json.loads(content)
+                        response_chunks.append(chunk_data)
+                    except json.JSONDecodeError:
+                        # Ignore non-json chunks if any
+                        pass
+    finally:
+        # Reconstruct the full response object from chunks
+        if response_chunks:
+            full_content = "".join(
+                choice["delta"]["content"]
+                for chunk in response_chunks
+                if "choices" in chunk and chunk["choices"]
+                for choice in chunk["choices"]
+                if "delta" in choice and "content" in choice["delta"] and choice["delta"]["content"]
+            )
+
+            # Take metadata from the first chunk and construct a single choice object
+            first_chunk = response_chunks[0]
+            final_choice = {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": full_content,
+                },
+                "finish_reason": "stop", # Assuming 'stop' as stream ended
+            }
+
+            full_response = {
+                "id": first_chunk.get("id"),
+                "object": "chat.completion", # Final object is a completion, not a chunk
+                "created": first_chunk.get("created"),
+                "model": first_chunk.get("model"),
+                "choices": [final_choice],
+                "usage": None # Usage is not typically available in the stream itself
+            }
+
+        if ENABLE_REQUEST_LOGGING:
+            log_request_response(
+                request_data=request_data,
+                response_data=full_response,
+                is_streaming=True
+            )
+
 @app.post("/v1/chat/completions")
-async def chat_completions(request: Request, _=Depends(verify_api_key)):
+async def chat_completions(
+    request: Request,
+    client: RotatingClient = Depends(get_rotating_client),
+    _ = Depends(verify_api_key)
+):
     """
     OpenAI-compatible endpoint powered by the RotatingClient.
-    Handles both streaming and non-streaming responses.
+    Handles both streaming and non-streaming responses and logs them.
     """
     try:
-        data = await request.json()
-        is_streaming = data.get("stream", False)
-        response = await rotating_client.acompletion(**data)
+        request_data = await request.json()
+        is_streaming = request_data.get("stream", False)
+
+        response = await client.acompletion(**request_data)
+
         if is_streaming:
-            return StreamingResponse(response, media_type="text/event-stream")
+            # Wrap the streaming response to enable logging after it's complete
+            return StreamingResponse(
+                streaming_response_wrapper(request_data, response),
+                media_type="text/event-stream"
+            )
         else:
+            # For non-streaming, log immediately
+            if ENABLE_REQUEST_LOGGING:
+                log_request_response(
+                    request_data=request_data,
+                    response_data=response.dict(),
+                    is_streaming=False
+                )
             return response
+
     except Exception as e:
         logging.error(f"Request failed after all retries: {e}")
+        # Optionally log the failed request
+        if ENABLE_REQUEST_LOGGING:
+            try:
+                request_data = await request.json()
+            except json.JSONDecodeError:
+                request_data = {"error": "Could not parse request body"}
+            log_request_response(
+                request_data=request_data,
+                response_data={"error": str(e)},
+                is_streaming=request_data.get("stream", False)
+            )
         raise HTTPException(status_code=500, detail=str(e))
 
 @app.get("/")
 def read_root():
     return {"Status": "API Key Proxy is running"}
 
 @app.get("/v1/models")
-async def list_models(grouped: bool = False, _=Depends(verify_api_key)):
+async def list_models(
+    grouped: bool = False,
+    client: RotatingClient = Depends(get_rotating_client),
+    _=Depends(verify_api_key)
+):
     """
     Returns a list of available models from all configured providers.
     Optionally returns them as a flat list if grouped=False.
     """
-    models = await rotating_client.get_all_available_models(grouped=grouped)
+    models = await client.get_all_available_models(grouped=grouped)
    return models
 
 @app.get("/v1/providers")
@@ -92,7 +198,11 @@ async def list_providers(_=Depends(verify_api_key)):
     return list(PROVIDER_PLUGINS.keys())
 
 @app.post("/v1/token-count")
-async def token_count(request: Request, _=Depends(verify_api_key)):
+async def token_count(
+    request: Request,
+    client: RotatingClient = Depends(get_rotating_client),
+    _=Depends(verify_api_key)
+):
     """
     Calculates the token count for a given list of messages and a model.
     """
@@ -104,7 +214,7 @@ async def token_count(request: Request, _=Depends(verify_api_key)):
         if not model or not messages:
             raise HTTPException(status_code=400, detail="'model' and 'messages' are required.")
 
-        count = rotating_client.token_count(model=model, messages=messages)
+        count = client.token_count(model=model, messages=messages)
         return {"token_count": count}
 
     except Exception as e:
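For orientation, here is a minimal client-side sketch of how the endpoints in this diff can be exercised. It is not part of the PR: the base URL, port, API key value, and model name are assumptions to adapt to your deployment.

# Hypothetical usage sketch (not in the PR). Assumes the proxy runs on
# localhost:8000 and PROXY_API_KEY is set to "my-proxy-key".
import asyncio
import json
import httpx

BASE_URL = "http://localhost:8000"
HEADERS = {"Authorization": "Bearer my-proxy-key"}  # must match PROXY_API_KEY

async def demo() -> None:
    async with httpx.AsyncClient(base_url=BASE_URL, headers=HEADERS, timeout=60) as http:
        # Non-streaming request: when ENABLE_REQUEST_LOGGING is True,
        # the proxy logs it immediately via log_request_response().
        r = await http.post("/v1/chat/completions", json={
            "model": "gemini/gemini-2.5-flash",  # assumed; any model the rotator exposes
            "messages": [{"role": "user", "content": "Say hi"}],
        })
        print(r.json()["choices"][0]["message"]["content"])

        # Streaming request: chunks pass through streaming_response_wrapper,
        # which aggregates the SSE "data:" lines for a post-hoc log entry.
        async with http.stream("POST", "/v1/chat/completions", json={
            "model": "gemini/gemini-2.5-flash",
            "messages": [{"role": "user", "content": "Count to 3"}],
            "stream": True,
        }) as resp:
            async for line in resp.aiter_lines():
                payload = line[len("data:"):].strip() if line.startswith("data:") else ""
                if payload and payload != "[DONE]":
                    choices = json.loads(payload).get("choices") or [{}]
                    delta = choices[0].get("delta", {})
                    print(delta.get("content") or "", end="")

asyncio.run(demo())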
30 changes: 30 additions & 0 deletions src/proxy_app/request_logger.py
@@ -0,0 +1,30 @@
+import json
+import os
+from datetime import datetime
+from pathlib import Path
+import uuid
+
+LOGS_DIR = Path(__file__).resolve().parent.parent.parent / "logs"
+LOGS_DIR.mkdir(exist_ok=True)
+
+def log_request_response(request_data: dict, response_data: dict, is_streaming: bool):
+    """
+    Logs the request and response data to a single file in the logs directory.
+    """
+    try:
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        unique_id = uuid.uuid4()
+        filename = LOGS_DIR / f"{timestamp}_{unique_id}.json"
+
+        log_content = {
+            "request": request_data,
+            "response": response_data,
+            "is_streaming": is_streaming
+        }
+
+        with open(filename, "w", encoding="utf-8") as f:
+            json.dump(log_content, f, indent=4, ensure_ascii=False)
+
+    except Exception as e:
+        # In case of logging failure, we don't want to crash the main application
+        print(f"Error logging request/response: {e}")