diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py
index 2d7bf147..ed04289c 100644
--- a/src/proxy_app/main.py
+++ b/src/proxy_app/main.py
@@ -1,4 +1,5 @@
 import os
+from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.responses import StreamingResponse
 from fastapi.security import APIKeyHeader
@@ -6,19 +7,23 @@
 import logging
 from pathlib import Path
 import sys
+import json
+from typing import AsyncGenerator, Any
 
 # Add the 'src' directory to the Python path to allow importing 'rotating_api_key_client'
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 
 from rotator_library import RotatingClient, PROVIDER_PLUGINS
+from .request_logger import log_request_response
 
 # Configure logging
-logging.basicConfig(level=logging.INFO) #-> moved to the rotator_library
+logging.basicConfig(level=logging.INFO)
 
 # Load environment variables from .env file
 load_dotenv()
 
 # --- Configuration ---
+ENABLE_REQUEST_LOGGING = False  # Set to True to enable request/response logging
 PROXY_API_KEY = os.getenv("PROXY_API_KEY")
 if not PROXY_API_KEY:
     raise ValueError("PROXY_API_KEY environment variable not set.")
@@ -37,38 +42,135 @@
 if not api_keys:
     raise ValueError("No provider API keys found in environment variables.")
 
-# Initialize the rotating client
-rotating_client = RotatingClient(api_keys=api_keys)
+# --- Lifespan Management ---
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Manage the RotatingClient's lifecycle with the app's lifespan."""
+    app.state.rotating_client = RotatingClient(api_keys=api_keys)
+    print("RotatingClient initialized.")
+    yield
+    await app.state.rotating_client.close()
+    print("RotatingClient closed.")
 
 # --- FastAPI App Setup ---
-app = FastAPI()
+app = FastAPI(lifespan=lifespan)
 api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
 
+def get_rotating_client(request: Request) -> RotatingClient:
+    """Dependency to get the rotating client instance from the app state."""
+    return request.app.state.rotating_client
+
 async def verify_api_key(auth: str = Depends(api_key_header)):
     """Dependency to verify the proxy API key."""
    if not auth or auth != f"Bearer {PROXY_API_KEY}":
         raise HTTPException(status_code=401, detail="Invalid or missing API Key")
     return auth
 
+async def streaming_response_wrapper(
+    request_data: dict,
+    response_stream: AsyncGenerator[str, None]
+) -> AsyncGenerator[str, None]:
+    """
+    Wraps a streaming response to log the full response after completion.
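+
+    Chunks are passed through to the client untouched; in parallel they are
+    parsed and buffered so the full completion can be reassembled and logged
+    once the stream ends.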
+ """ + response_chunks = [] + full_response = {} + try: + async for chunk_str in response_stream: + yield chunk_str + # Process chunk for logging + if chunk_str.strip() and chunk_str.startswith("data:"): + content = chunk_str[len("data:"):].strip() + if content != "[DONE]": + try: + chunk_data = json.loads(content) + response_chunks.append(chunk_data) + except json.JSONDecodeError: + # Ignore non-json chunks if any + pass + finally: + # Reconstruct the full response object from chunks + if response_chunks: + full_content = "".join( + choice["delta"]["content"] + for chunk in response_chunks + if "choices" in chunk and chunk["choices"] + for choice in chunk["choices"] + if "delta" in choice and "content" in choice["delta"] and choice["delta"]["content"] + ) + + # Take metadata from the first chunk and construct a single choice object + first_chunk = response_chunks[0] + final_choice = { + "index": 0, + "message": { + "role": "assistant", + "content": full_content, + }, + "finish_reason": "stop", # Assuming 'stop' as stream ended + } + + full_response = { + "id": first_chunk.get("id"), + "object": "chat.completion", # Final object is a completion, not a chunk + "created": first_chunk.get("created"), + "model": first_chunk.get("model"), + "choices": [final_choice], + "usage": None # Usage is not typically available in the stream itself + } + + if ENABLE_REQUEST_LOGGING: + log_request_response( + request_data=request_data, + response_data=full_response, + is_streaming=True + ) + @app.post("/v1/chat/completions") -async def chat_completions(request: Request, _=Depends(verify_api_key)): +async def chat_completions( + request: Request, + client: RotatingClient = Depends(get_rotating_client), + _ = Depends(verify_api_key) +): """ OpenAI-compatible endpoint powered by the RotatingClient. - Handles both streaming and non-streaming responses. + Handles both streaming and non-streaming responses and logs them. 
""" try: - data = await request.json() - is_streaming = data.get("stream", False) - - response = await rotating_client.acompletion(**data) + request_data = await request.json() + is_streaming = request_data.get("stream", False) + + response = await client.acompletion(**request_data) if is_streaming: - return StreamingResponse(response, media_type="text/event-stream") + # Wrap the streaming response to enable logging after it's complete + return StreamingResponse( + streaming_response_wrapper(request_data, response), + media_type="text/event-stream" + ) else: + # For non-streaming, log immediately + if ENABLE_REQUEST_LOGGING: + log_request_response( + request_data=request_data, + response_data=response.dict(), + is_streaming=False + ) return response except Exception as e: logging.error(f"Request failed after all retries: {e}") + # Optionally log the failed request + if ENABLE_REQUEST_LOGGING: + try: + request_data = await request.json() + except json.JSONDecodeError: + request_data = {"error": "Could not parse request body"} + log_request_response( + request_data=request_data, + response_data={"error": str(e)}, + is_streaming=request_data.get("stream", False) + ) raise HTTPException(status_code=500, detail=str(e)) @app.get("/") @@ -76,12 +178,16 @@ def read_root(): return {"Status": "API Key Proxy is running"} @app.get("/v1/models") -async def list_models(grouped: bool = False, _=Depends(verify_api_key)): +async def list_models( + grouped: bool = False, + client: RotatingClient = Depends(get_rotating_client), + _=Depends(verify_api_key) +): """ Returns a list of available models from all configured providers. Optionally returns them as a flat list if grouped=False. """ - models = await rotating_client.get_all_available_models(grouped=grouped) + models = await client.get_all_available_models(grouped=grouped) return models @app.get("/v1/providers") @@ -92,7 +198,11 @@ async def list_providers(_=Depends(verify_api_key)): return list(PROVIDER_PLUGINS.keys()) @app.post("/v1/token-count") -async def token_count(request: Request, _=Depends(verify_api_key)): +async def token_count( + request: Request, + client: RotatingClient = Depends(get_rotating_client), + _=Depends(verify_api_key) +): """ Calculates the token count for a given list of messages and a model. """ @@ -104,7 +214,7 @@ async def token_count(request: Request, _=Depends(verify_api_key)): if not model or not messages: raise HTTPException(status_code=400, detail="'model' and 'messages' are required.") - count = rotating_client.token_count(model=model, messages=messages) + count = client.token_count(model=model, messages=messages) return {"token_count": count} except Exception as e: diff --git a/src/proxy_app/request_logger.py b/src/proxy_app/request_logger.py new file mode 100644 index 00000000..5e35441d --- /dev/null +++ b/src/proxy_app/request_logger.py @@ -0,0 +1,30 @@ +import json +import os +from datetime import datetime +from pathlib import Path +import uuid + +LOGS_DIR = Path(__file__).resolve().parent.parent.parent / "logs" +LOGS_DIR.mkdir(exist_ok=True) + +def log_request_response(request_data: dict, response_data: dict, is_streaming: bool): + """ + Logs the request and response data to a single file in the logs directory. 
+ """ + try: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + unique_id = uuid.uuid4() + filename = LOGS_DIR / f"{timestamp}_{unique_id}.json" + + log_content = { + "request": request_data, + "response": response_data, + "is_streaming": is_streaming + } + + with open(filename, "w", encoding="utf-8") as f: + json.dump(log_content, f, indent=4, ensure_ascii=False) + + except Exception as e: + # In case of logging failure, we don't want to crash the main application + print(f"Error logging request/response: {e}") diff --git a/src/rotator_library/client.py b/src/rotator_library/client.py index e7d03b0e..dcc01963 100644 --- a/src/rotator_library/client.py +++ b/src/rotator_library/client.py @@ -6,22 +6,19 @@ import litellm from litellm.litellm_core_utils.token_counter import token_counter import logging -from typing import List, Dict, Any, AsyncGenerator +from typing import List, Dict, Any, AsyncGenerator, Optional, Union -# Set up a dedicated logger for the library lib_logger = logging.getLogger('rotator_library') lib_logger.propagate = False -# You might want to add a handler if you want to see these logs specifically -# For example, a NullHandler to avoid "No handler found" warnings if the -# main app doesn't configure this logger. if not lib_logger.handlers: lib_logger.addHandler(logging.NullHandler()) from .usage_manager import UsageManager from .failure_logger import log_failure -from .error_handler import is_rate_limit_error, is_server_error, is_unrecoverable_error +from .error_handler import classify_error, AllProviders from .providers import PROVIDER_PLUGINS +from .request_sanitizer import sanitize_request_payload class RotatingClient: """ @@ -31,6 +28,7 @@ class RotatingClient: def __init__(self, api_keys: Dict[str, List[str]], max_retries: int = 2, usage_file_path: str = "key_usage.json"): os.environ["LITELLM_LOG"] = "ERROR" litellm.set_verbose = False + litellm.drop_params = True if not api_keys: raise ValueError("API keys dictionary cannot be empty.") self.api_keys = api_keys @@ -41,36 +39,56 @@ def __init__(self, api_keys: Dict[str, List[str]], max_retries: int = 2, usage_f name: plugin() for name, plugin in PROVIDER_PLUGINS.items() } self.http_client = httpx.AsyncClient() + self.all_providers = AllProviders() - async def _streaming_wrapper(self, stream: Any, key: str, model: str) -> AsyncGenerator[Any, None]: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def close(self): + """Close the HTTP client to prevent resource leaks.""" + if hasattr(self, 'http_client') and self.http_client: + await self.http_client.aclose() + + async def _safe_streaming_wrapper(self, stream: Any, key: str, model: str) -> AsyncGenerator[Any, None]: """ - A wrapper for streaming responses that formats the output as OpenAI-compatible - Server-Sent Events (SSE) and records usage. + A definitive hybrid wrapper for streaming responses that ensures usage is recorded + and the key lock is released only after the stream is fully consumed. + It exhaustively checks for usage data in all possible locations. 
""" + usage_recorded = False + stream_completed = False try: async for chunk in stream: - #lib_logger.info(f"STREAM CHUNK: {chunk}") - # Convert the litellm chunk object to a dictionary - chunk_dict = chunk.dict() - - # Format as a Server-Sent Event - yield f"data: {json.dumps(chunk_dict)}\n\n" - - # Safely check for usage data in the chunk - if hasattr(chunk, 'usage') and chunk.usage: - lib_logger.info(f"Usage found in chunk for key ...{key[-4:]}: {chunk.usage}") + yield f"data: {json.dumps(chunk.dict())}\n\n" + # 1. First, try to find usage in a chunk (for providers that send it mid-stream) + if not usage_recorded and hasattr(chunk, 'usage') and chunk.usage: await self.usage_manager.record_success(key, model, chunk) - + usage_recorded = True + lib_logger.info(f"Recorded usage from stream chunk for key ...{key[-4:]}") + stream_completed = True finally: - # Signal the end of the stream - yield "data: [DONE]\n\n" - lib_logger.info("STREAM FINISHED and [DONE] signal sent.") + # 2. If not found in a chunk, try the final stream object itself (for other providers) + if not usage_recorded: + # This call is now safe because record_success is robust + await self.usage_manager.record_success(key, model, stream) + lib_logger.info(f"Recorded usage from final stream object for key ...{key[-4:]}") + + # 3. Release the key only after all attempts to record usage are complete + await self.usage_manager.release_key(key, model) + lib_logger.info(f"STREAM FINISHED and lock released for key ...{key[-4:]}.") + + # Only yield [DONE] if the stream completed successfully + if stream_completed: + yield "data: [DONE]\n\n" - async def acompletion(self, pre_request_callback: callable = None, **kwargs) -> Any: + async def acompletion(self, pre_request_callback: Optional[callable] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]: """ Performs a completion call with smart key rotation and retry logic. - Handles both streaming and non-streaming requests with thread-safe key acquisition. + It will try each available key in sequence if the previous one fails. 
""" model = kwargs.get("model") is_streaming = kwargs.get("stream", False) @@ -82,81 +100,123 @@ async def acompletion(self, pre_request_callback: callable = None, **kwargs) -> if provider not in self.api_keys: raise ValueError(f"No API keys configured for provider: {provider}") - current_key = None - try: - while True: # Loop to acquire a key and make the call + keys_for_provider = self.api_keys[provider] + tried_keys = set() + last_exception = None + + while len(tried_keys) < len(keys_for_provider): + current_key = None + key_acquired = False + try: + keys_to_try = [k for k in keys_for_provider if k not in tried_keys] + if not keys_to_try: + break + current_key = await self.usage_manager.acquire_key( - available_keys=self.api_keys[provider], + available_keys=keys_to_try, model=model ) + key_acquired = True + tried_keys.add(current_key) + + # Prepare litellm_kwargs once per key, not on every retry + litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy()) + + if provider in self._provider_instances: + provider_instance = self._provider_instances[provider] + + # Ensure safety_settings are present, defaulting to lowest if not provided + if "safety_settings" not in litellm_kwargs: + litellm_kwargs["safety_settings"] = { + "harassment": "BLOCK_NONE", + "hate_speech": "BLOCK_NONE", + "sexually_explicit": "BLOCK_NONE", + "dangerous_content": "BLOCK_NONE", + } + + converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"]) + + if converted_settings is not None: + litellm_kwargs["safety_settings"] = converted_settings + else: + # If conversion returns None, remove it to avoid sending empty settings + del litellm_kwargs["safety_settings"] + + if provider == "gemini": + provider_instance = self._provider_instances[provider] + provider_instance.handle_thinking_parameter(litellm_kwargs, model) + + if "gemma-3" in model and "messages" in litellm_kwargs: + new_messages = [ + {"role": "user", "content": m["content"]} if m.get("role") == "system" else m + for m in litellm_kwargs["messages"] + ] + litellm_kwargs["messages"] = new_messages + + if provider == "chutes": + litellm_kwargs["model"] = f"openai/{model.split('/', 1)[1]}" + litellm_kwargs["api_base"] = "https://llm.chutes.ai/v1" + + litellm_kwargs = sanitize_request_payload(litellm_kwargs, model) for attempt in range(self.max_retries): try: lib_logger.info(f"Attempting call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})") - litellm_kwargs = kwargs.copy() - if "gemma-3" in model and "messages" in litellm_kwargs: - new_messages = [ - {"role": "user", "content": m["content"]} if m.get("role") == "system" else m - for m in litellm_kwargs["messages"] - ] - litellm_kwargs["messages"] = new_messages - - if provider == "chutes": - litellm_kwargs["model"] = f"openai/{model.split('/', 1)[1]}" - litellm_kwargs["api_base"] = "https://llm.chutes.ai/v1" - if pre_request_callback: await pre_request_callback() response = await litellm.acompletion(api_key=current_key, **litellm_kwargs) if is_streaming: - return self._streaming_wrapper(response, current_key, model) + # The wrapper is now responsible for releasing the key. + key_acquired = False # Transfer responsibility to wrapper + return self._safe_streaming_wrapper(response, current_key, model) else: + # For non-streaming, record and release here. 
await self.usage_manager.record_success(current_key, model, response) + await self.usage_manager.release_key(current_key, model) + key_acquired = False # Key has been released return response except Exception as e: + last_exception = e log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs) + + classified_error = classify_error(e) + + if classified_error.error_type in ['invalid_request', 'authentication']: + await self.usage_manager.record_failure(current_key, model, classified_error) + await self.usage_manager.release_key(current_key, model) + key_acquired = False # Key has been released + break - if is_unrecoverable_error(e): - lib_logger.error(f"Key ...{current_key[-4:]} failed with unrecoverable error: {e}. Raising exception.") - raise e - - if is_rate_limit_error(e): - lib_logger.warning(f"Key ...{current_key[-4:]} hit a rate limit for model {model}. Rotating key and setting cooldown.") - await self.usage_manager.record_rotation_error(current_key, model, e) - break # Break from retries to get a new key - - if is_server_error(e): + if classified_error.error_type == 'server_error': if attempt < self.max_retries - 1: wait_time = (2 ** attempt) + random.uniform(0, 1) lib_logger.warning(f"Key ...{current_key[-4:]} encountered a server error. Retrying in {wait_time:.2f} seconds...") await asyncio.sleep(wait_time) continue - else: - lib_logger.error(f"Key ...{current_key[-4:]} failed after max retries on a server error. Rotating key.") - await self.usage_manager.record_rotation_error(current_key, model, e) - break - # Fallback for any other unexpected errors - lib_logger.error(f"Key ...{current_key[-4:]} failed with an unexpected error: {e}. Rotating key.") - await self.usage_manager.record_rotation_error(current_key, model, e) + await self.usage_manager.record_failure(current_key, model, classified_error) + await self.usage_manager.release_key(current_key, model) + key_acquired = False # Key has been released break - - # If we exit the retry loop due to failure, release the key and try to get a new one. - await self.usage_manager.release_key(current_key) - current_key = None # Ensure key is not released again in finally + finally: + # This block is now only for handling failures where the key needs to be released + # without a successful response. The wrapper handles the success case for streams. + if key_acquired and current_key: + await self.usage_manager.release_key(current_key, model) - finally: - if current_key: - await self.usage_manager.release_key(current_key) + if last_exception: + raise last_exception + + raise Exception("Failed to complete the request: No available API keys for the provider or all keys failed.") def token_count(self, model: str, text: str = None, messages: List[Dict[str, str]] = None) -> int: - """ - Calculates the number of tokens for a given text or list of messages. - """ + """Calculates the number of tokens for a given text or list of messages.""" + if not model: + raise ValueError("'model' is a required parameter.") if messages: return token_counter(model=model, messages=messages) elif text: @@ -165,9 +225,7 @@ def token_count(self, model: str, text: str = None, messages: List[Dict[str, str raise ValueError("Either 'text' or 'messages' must be provided.") async def get_available_models(self, provider: str) -> List[str]: - """ - Returns a list of available models for a specific provider, with caching. 
- """ + """Returns a list of available models for a specific provider, with caching.""" lib_logger.info(f"Getting available models for provider: {provider}") if provider in self._model_list_cache: lib_logger.info(f"Returning cached models for provider: {provider}") @@ -180,23 +238,29 @@ async def get_available_models(self, provider: str) -> List[str]: if provider in self._provider_instances: lib_logger.info(f"Calling get_models for provider: {provider}") - models = await self._provider_instances[provider].get_models(api_key, self.http_client) - lib_logger.info(f"Got {len(models)} models for provider: {provider}") - self._model_list_cache[provider] = models - return models + try: + models = await self._provider_instances[provider].get_models(api_key, self.http_client) + lib_logger.info(f"Got {len(models)} models for provider: {provider}") + self._model_list_cache[provider] = models + return models + except Exception as e: + lib_logger.error(f"Failed to get models for provider {provider}: {e}") + return [] else: lib_logger.warning(f"Model list fetching not implemented for provider: {provider}") return [] - async def get_all_available_models(self, grouped: bool = True) -> Any: - """ - Returns a list of all available models, either grouped by provider or as a flat list. - """ + async def get_all_available_models(self, grouped: bool = True) -> Union[Dict[str, List[str]], List[str]]: + """Returns a list of all available models, either grouped by provider or as a flat list.""" lib_logger.info("Getting all available models...") all_provider_models = {} for provider in self.api_keys.keys(): lib_logger.info(f"Getting models for provider: {provider}") - all_provider_models[provider] = await self.get_available_models(provider) + try: + all_provider_models[provider] = await self.get_available_models(provider) + except Exception as e: + lib_logger.error(f"Failed to get models for provider {provider}: {e}") + all_provider_models[provider] = [] lib_logger.info("Finished getting all available models.") if grouped: @@ -205,5 +269,5 @@ async def get_all_available_models(self, grouped: bool = True) -> Any: flat_models = [] for provider, models in all_provider_models.items(): for model in models: - flat_models.append(f"{model}") + flat_models.append(f"{provider}/{model}") return flat_models diff --git a/src/rotator_library/error_handler.py b/src/rotator_library/error_handler.py index 5a35a385..3c235c7d 100644 --- a/src/rotator_library/error_handler.py +++ b/src/rotator_library/error_handler.py @@ -1,4 +1,95 @@ -from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError, AuthenticationError, InvalidRequestError +import re +from typing import Optional, Dict, Any + +from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError, AuthenticationError, InvalidRequestError, BadRequestError, OpenAIError, InternalServerError + +class ClassifiedError: + """A structured representation of a classified error.""" + def __init__(self, error_type: str, original_exception: Exception, status_code: Optional[int] = None, retry_after: Optional[int] = None): + self.error_type = error_type + self.original_exception = original_exception + self.status_code = status_code + self.retry_after = retry_after + + def __str__(self): + return f"ClassifiedError(type={self.error_type}, status={self.status_code}, retry_after={self.retry_after}, original_exc={self.original_exception})" + +def get_retry_after(error: Exception) -> Optional[int]: + """ + Extracts the 'retry-after' duration 
in seconds from an exception message. + Handles both integer and string representations of the duration. + """ + error_str = str(error).lower() + + # Common patterns for 'retry-after' + patterns = [ + r'retry after:?\s*(\d+)', + r'retry_after:?\s*(\d+)', + r'retry in\s*(\d+)\s*seconds', + r'wait for\s*(\d+)\s*seconds', + ] + + for pattern in patterns: + match = re.search(pattern, error_str) + if match: + try: + return int(match.group(1)) + except (ValueError, IndexError): + continue + + # Handle cases where the error object itself has the attribute + if hasattr(error, 'retry_after'): + value = getattr(error, 'retry_after') + if isinstance(value, int): + return value + if isinstance(value, str) and value.isdigit(): + return int(value) + + return None + +def classify_error(e: Exception) -> ClassifiedError: + """ + Classifies an exception into a structured ClassifiedError object. + """ + status_code = getattr(e, 'status_code', None) + + if isinstance(e, RateLimitError): + retry_after = get_retry_after(e) + return ClassifiedError( + error_type='rate_limit', + original_exception=e, + status_code=status_code or 429, + retry_after=retry_after + ) + + if isinstance(e, (AuthenticationError,)): + return ClassifiedError( + error_type='authentication', + original_exception=e, + status_code=status_code or 401 + ) + + if isinstance(e, (InvalidRequestError, BadRequestError)): + return ClassifiedError( + error_type='invalid_request', + original_exception=e, + status_code=status_code or 400 + ) + + if isinstance(e, (ServiceUnavailableError, APIConnectionError, OpenAIError, InternalServerError)): + # These are often temporary server-side issues + return ClassifiedError( + error_type='server_error', + original_exception=e, + status_code=status_code or 503 + ) + + # Fallback for any other unclassified errors + return ClassifiedError( + error_type='unknown', + original_exception=e, + status_code=status_code + ) def is_rate_limit_error(e: Exception) -> bool: """Checks if the exception is a rate limit error.""" @@ -6,11 +97,48 @@ def is_rate_limit_error(e: Exception) -> bool: def is_server_error(e: Exception) -> bool: """Checks if the exception is a temporary server-side error.""" - return isinstance(e, (ServiceUnavailableError, APIConnectionError)) + return isinstance(e, (ServiceUnavailableError, APIConnectionError, InternalServerError, OpenAIError)) def is_unrecoverable_error(e: Exception) -> bool: """ Checks if the exception is a non-retriable client-side error. These are errors that will not resolve on their own. """ - return isinstance(e, (InvalidRequestError, AuthenticationError)) + return isinstance(e, (InvalidRequestError, AuthenticationError, BadRequestError)) + +class AllProviders: + """ + A class to handle provider-specific settings, such as custom API bases. + """ + def __init__(self): + self.providers = { + "chutes": { + "api_base": "https://llm.chutes.ai/v1", + "model_prefix": "openai/" + } + } + + def get_provider_kwargs(self, **kwargs) -> Dict[str, Any]: + """ + Returns provider-specific kwargs for a given model. 
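+        At present this only rewrites requests for the "chutes" provider,
+        injecting its custom api_base and prefixing the model with "openai/".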
+ """ + model = kwargs.get("model") + if not model: + return kwargs + + provider = self._get_provider_from_model(model) + provider_settings = self.providers.get(provider, {}) + + if "api_base" in provider_settings: + kwargs["api_base"] = provider_settings["api_base"] + + if "model_prefix" in provider_settings: + kwargs["model"] = f"{provider_settings['model_prefix']}{model.split('/', 1)[1]}" + + return kwargs + + def _get_provider_from_model(self, model: str) -> str: + """ + Determines the provider from the model name. + """ + return model.split('/')[0] diff --git a/src/rotator_library/providers/gemini_provider.py b/src/rotator_library/providers/gemini_provider.py index 26b8297f..f7c27a35 100644 --- a/src/rotator_library/providers/gemini_provider.py +++ b/src/rotator_library/providers/gemini_provider.py @@ -1,6 +1,6 @@ import httpx import logging -from typing import List +from typing import List, Dict, Any from .provider_interface import ProviderInterface lib_logger = logging.getLogger('rotator_library') @@ -26,3 +26,34 @@ async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str] except httpx.RequestError as e: lib_logger.error(f"Failed to fetch Gemini models: {e}") return [] + + def convert_safety_settings(self, settings: Dict[str, str]) -> List[Dict[str, Any]]: + """ + Converts generic safety settings to the Gemini-specific format. + """ + if not settings: + return [] + + gemini_settings = [] + category_map = { + "harassment": "HARM_CATEGORY_HARASSMENT", + "hate_speech": "HARM_CATEGORY_HATE_SPEECH", + "sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT", + } + + for generic_category, threshold in settings.items(): + if generic_category in category_map: + gemini_settings.append({ + "category": category_map[generic_category], + "threshold": threshold.upper() + }) + + return gemini_settings + + def handle_thinking_parameter(self, payload: Dict[str, Any], model: str): + """ + Adds a default thinking parameter for specific Gemini models if not already present. + """ + if model in ["gemini/gemini-2.5-pro", "gemini/gemini-2.5-flash"] and "thinking" not in payload and "reasoning_effort" not in payload: + payload["thinking"] = {"type": "enabled", "budget_tokens": -1} diff --git a/src/rotator_library/providers/nvidia_provider.py b/src/rotator_library/providers/nvidia_provider.py new file mode 100644 index 00000000..e9718291 --- /dev/null +++ b/src/rotator_library/providers/nvidia_provider.py @@ -0,0 +1,28 @@ +import httpx +import logging +from typing import List +from .provider_interface import ProviderInterface + +lib_logger = logging.getLogger('rotator_library') +lib_logger.propagate = False # Ensure this logger doesn't propagate to root +if not lib_logger.handlers: + lib_logger.addHandler(logging.NullHandler()) + +class NvidiaProvider(ProviderInterface): + """ + Provider implementation for the NVIDIA API. + """ + async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]: + """ + Fetches the list of available models from the NVIDIA API. 
+ """ + try: + response = await client.get( + "https://integrate.api.nvidia.com/v1/models", + headers={"Authorization": f"Bearer {api_key}"} + ) + response.raise_for_status() + return [f"nvidia_nim/{model['id']}" for model in response.json().get("data", [])] + except httpx.RequestError as e: + lib_logger.error(f"Failed to fetch NVIDIA models: {e}") + return [] diff --git a/src/rotator_library/providers/provider_interface.py b/src/rotator_library/providers/provider_interface.py index 8fd4342d..d4ce2396 100644 --- a/src/rotator_library/providers/provider_interface.py +++ b/src/rotator_library/providers/provider_interface.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List +from typing import List, Dict, Any import httpx class ProviderInterface(ABC): @@ -21,3 +21,15 @@ async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str] A list of model name strings. """ pass + + def convert_safety_settings(self, settings: Dict[str, str]) -> List[Dict[str, Any]]: + """ + Converts a generic safety settings dictionary to the provider-specific format. + + Args: + settings: A dictionary with generic harm categories and thresholds. + + Returns: + A list of provider-specific safety setting objects or None. + """ + return None diff --git a/src/rotator_library/pyproject.toml b/src/rotator_library/pyproject.toml index 81ff112d..34e1d558 100644 --- a/src/rotator_library/pyproject.toml +++ b/src/rotator_library/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "rotating-api-key-client" -version = "0.5.5" +version = "0.6.7" authors = [ { name="Mirrowel", email="nuh@uh.com" }, ] diff --git a/src/rotator_library/request_sanitizer.py b/src/rotator_library/request_sanitizer.py new file mode 100644 index 00000000..9f86ebf6 --- /dev/null +++ b/src/rotator_library/request_sanitizer.py @@ -0,0 +1,11 @@ +from typing import Dict, Any + +def sanitize_request_payload(payload: Dict[str, Any], model: str) -> Dict[str, Any]: + """ + Removes unsupported parameters from the request payload based on the model. + """ + if payload.get("thinking") == {"type": "enabled", "budget_tokens": -1}: + if model not in ["gemini/gemini-2.5-pro", "gemini/gemini-2.5-flash"]: + del payload["thinking"] + + return payload diff --git a/src/rotator_library/usage_manager.py b/src/rotator_library/usage_manager.py index c928837f..a6644b28 100644 --- a/src/rotator_library/usage_manager.py +++ b/src/rotator_library/usage_manager.py @@ -3,12 +3,13 @@ import time import logging import asyncio -from datetime import date +from datetime import date, datetime, timezone, time as dt_time from typing import Dict, List, Optional, Set from filelock import FileLock import aiofiles import litellm -import re + +from .error_handler import ClassifiedError lib_logger = logging.getLogger('rotator_library') lib_logger.propagate = False @@ -20,29 +21,28 @@ class UsageManager: Manages usage statistics and cooldowns for API keys with asyncio-safe locking, asynchronous file I/O, and a lazy-loading mechanism for usage data. 
""" - def __init__(self, file_path: str = "key_usage.json", wait_timeout: int = 5): + def __init__(self, file_path: str = "key_usage.json", wait_timeout: int = 13, daily_reset_time_utc: Optional[str] = "03:00"): self.file_path = file_path self.file_lock = FileLock(f"{self.file_path}.lock") - self.key_locks: Dict[str, asyncio.Lock] = {} - self.condition = asyncio.Condition() + self.key_states: Dict[str, Dict[str, Any]] = {} self.wait_timeout = wait_timeout - # Data-related locks and state self._data_lock = asyncio.Lock() self._usage_data: Optional[Dict] = None self._initialized = asyncio.Event() self._init_lock = asyncio.Lock() - # For "fair timeout" logic self._timeout_lock = asyncio.Lock() self._claimed_on_timeout: Set[str] = set() + if daily_reset_time_utc: + hour, minute = map(int, daily_reset_time_utc.split(':')) + self.daily_reset_time_utc = dt_time(hour=hour, minute=minute, tzinfo=timezone.utc) + else: + self.daily_reset_time_utc = None + async def _lazy_init(self): - """ - Initializes the usage data by loading it from the file asynchronously. - This method is called on the first access to ensure data is loaded - before any operations are performed. - """ + """Initializes the usage data by loading it from the file asynchronously.""" async with self._init_lock: if not self._initialized.is_set(): await self._load_usage() @@ -67,163 +67,251 @@ async def _save_usage(self): if self._usage_data is None: return async with self._data_lock: - with self.file_lock: # Use filelock to prevent multi-process race conditions + with self.file_lock: async with aiofiles.open(self.file_path, 'w') as f: await f.write(json.dumps(self._usage_data, indent=2)) async def _reset_daily_stats_if_needed(self): - """Checks if daily stats need to be reset for any key (async version).""" - if self._usage_data is None: + """Checks if daily stats need to be reset for any key.""" + if self._usage_data is None or not self.daily_reset_time_utc: return - today_str = date.today().isoformat() + now_utc = datetime.now(timezone.utc) + today_str = now_utc.date().isoformat() needs_saving = False + for key, data in self._usage_data.items(): - daily_data = data.get("daily", {}) - if daily_data.get("date") != today_str: - needs_saving = True - global_data = data.setdefault("global", {"models": {}}) - for model, stats in daily_data.get("models", {}).items(): - global_model_stats = global_data["models"].setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0}) - global_model_stats["success_count"] += stats.get("success_count", 0) - global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0) - global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0) - global_model_stats["approx_cost"] += stats.get("approx_cost", 0.0) - data["daily"] = {"date": today_str, "models": {}} - + last_reset_str = data.get("last_daily_reset", "") + + if last_reset_str != today_str: + last_reset_dt = None + if last_reset_str: + # Ensure the parsed datetime is timezone-aware (UTC) + last_reset_dt = datetime.fromisoformat(last_reset_str).replace(tzinfo=timezone.utc) + + # Determine the reset threshold for today + reset_threshold_today = datetime.combine(now_utc.date(), self.daily_reset_time_utc) + + if last_reset_dt is None or last_reset_dt < reset_threshold_today <= now_utc: + lib_logger.info(f"Performing daily reset for key ...{key[-4:]}") + needs_saving = True + + # Reset cooldowns + data["model_cooldowns"] = {} + data["key_cooldown_until"] = None + + # Reset consecutive failures + if 
"failures" in data: + data["failures"] = {} + + # Archive global stats from the previous day's 'daily' + daily_data = data.get("daily", {}) + if daily_data: + global_data = data.setdefault("global", {"models": {}}) + for model, stats in daily_data.get("models", {}).items(): + global_model_stats = global_data["models"].setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0}) + global_model_stats["success_count"] += stats.get("success_count", 0) + global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0) + global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0) + global_model_stats["approx_cost"] += stats.get("approx_cost", 0.0) + + # Reset daily stats + data["daily"] = {"date": today_str, "models": {}} + data["last_daily_reset"] = today_str + if needs_saving: await self._save_usage() - def _initialize_locks(self, keys: List[str]): - """Initializes asyncio locks for all provided keys if not already present.""" + def _initialize_key_states(self, keys: List[str]): + """Initializes state tracking for all provided keys if not already present.""" for key in keys: - if key not in self.key_locks: - self.key_locks[key] = asyncio.Lock() + if key not in self.key_states: + self.key_states[key] = { + "lock": asyncio.Lock(), + "condition": asyncio.Condition(), + "models_in_use": set() + } async def acquire_key(self, available_keys: List[str], model: str) -> str: """ - Acquires the best available key with robust locking and a fair timeout mechanism. + Acquires the best available key using a tiered, model-aware locking strategy. """ await self._lazy_init() - self._initialize_locks(available_keys) - - async with self.condition: - while True: - eligible_keys = [] - async with self._data_lock: - for key in available_keys: - key_data = self._usage_data.get(key, {}) - cooldown_until = key_data.get("model_cooldowns", {}).get(model) - if not cooldown_until or time.time() > cooldown_until: - usage_count = key_data.get("daily", {}).get("models", {}).get(model, {}).get("success_count", 0) - eligible_keys.append((key, usage_count)) - - if not eligible_keys: - lib_logger.warning("All keys are on cooldown. 
Waiting...") - await asyncio.sleep(5) - continue + self._initialize_key_states(available_keys) - eligible_keys.sort(key=lambda x: x[1]) - - for key, _ in eligible_keys: - lock = self.key_locks[key] - if not lock.locked(): - await lock.acquire() - lib_logger.info(f"Acquired lock for available key: ...{key[-4:]}") + while True: + tier1_keys, tier2_keys = [], [] + async with self._data_lock: + now = time.time() + for key in available_keys: + key_data = self._usage_data.get(key, {}) + + # Skip keys on global or model-specific cooldown + if (key_data.get("key_cooldown_until") or 0) > now or \ + (key_data.get("model_cooldowns", {}).get(model) or 0) > now: + continue + + usage_count = key_data.get("daily", {}).get("models", {}).get(model, {}).get("success_count", 0) + key_state = self.key_states[key] + + if not key_state["models_in_use"]: + tier1_keys.append((key, usage_count)) + elif model not in key_state["models_in_use"]: + tier2_keys.append((key, usage_count)) + + # Sort keys by usage count (ascending) + tier1_keys.sort(key=lambda x: x[1]) + tier2_keys.sort(key=lambda x: x[1]) + + # Attempt to acquire from Tier 1 (completely free) + for key, _ in tier1_keys: + state = self.key_states[key] + async with state["lock"]: + if not state["models_in_use"]: + state["models_in_use"].add(model) + lib_logger.info(f"Acquired Tier 1 key ...{key[-4:]} for model {model}") return key - lib_logger.info("All eligible keys are locked. Waiting for a key to be released.") - - try: - await asyncio.wait_for(self.condition.wait(), timeout=self.wait_timeout) - lib_logger.info("Notified that a key was released. Re-evaluating...") - continue - except asyncio.TimeoutError: - lib_logger.warning("Wait timed out. Attempting to acquire a key via fair timeout logic.") - async with self._timeout_lock: - for key, _ in eligible_keys: - if key not in self._claimed_on_timeout: - self._claimed_on_timeout.add(key) - lib_logger.info(f"Acquired key ...{key[-4:]} via timeout claim.") - return key - lib_logger.error("Timeout occurred, but all eligible keys were already claimed by other timed-out tasks.") - # Fallback to waiting again if all keys were claimed - await asyncio.sleep(1) - - - async def release_key(self, key: str): - """Releases the lock for a given key and notifies waiting tasks.""" - async with self.condition: - # Also release from timeout claim set if it's there - async with self._timeout_lock: - if key in self._claimed_on_timeout: - self._claimed_on_timeout.remove(key) - - if key in self.key_locks and self.key_locks[key].locked(): - self.key_locks[key].release() - lib_logger.info(f"Released lock for key ...{key[-4:]}") - self.condition.notify() - - async def record_success(self, key: str, model: str, completion_response: litellm.ModelResponse): - """Records a successful API call asynchronously.""" + # Attempt to acquire from Tier 2 (in use by other models) + for key, _ in tier2_keys: + state = self.key_states[key] + async with state["lock"]: + if model not in state["models_in_use"]: + state["models_in_use"].add(model) + lib_logger.info(f"Acquired Tier 2 key ...{key[-4:]} for model {model}") + return key + + # If no key is available, wait for one to be released + lib_logger.info("All eligible keys are currently locked for this model. Waiting...") + + # Create a combined list of all potentially usable keys to wait on + all_potential_keys = tier1_keys + tier2_keys + if not all_potential_keys: + lib_logger.warning("No keys are eligible at all (all on cooldown). 
Waiting before re-evaluating.") + await asyncio.sleep(5) + continue + + # Wait on the condition of the best available key + best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0] + wait_condition = self.key_states[best_wait_key]["condition"] + + try: + async with wait_condition: + await asyncio.wait_for(wait_condition.wait(), timeout=self.wait_timeout) + lib_logger.info("Notified that a key was released. Re-evaluating...") + except asyncio.TimeoutError: + lib_logger.warning("Wait timed out. Re-evaluating for any available key.") + + + async def release_key(self, key: str, model: str): + """Releases a key's lock for a specific model and notifies waiting tasks.""" + if key not in self.key_states: + return + + state = self.key_states[key] + async with state["lock"]: + if model in state["models_in_use"]: + state["models_in_use"].remove(model) + lib_logger.info(f"Released key ...{key[-4:]} from model {model}") + else: + lib_logger.warning(f"Attempted to release key ...{key[-4:]} for model {model}, but it was not in use.") + + # Notify all tasks waiting on this key's condition + async with state["condition"]: + state["condition"].notify_all() + + async def record_success(self, key: str, model: str, completion_response: Optional[litellm.ModelResponse] = None): + """ + Records a successful API call, resetting failure counters. + It safely handles cases where token usage data is not available. + """ await self._lazy_init() async with self._data_lock: - key_data = self._usage_data.setdefault(key, {"daily": {"date": date.today().isoformat(), "models": {}}, "global": {"models": {}}, "model_cooldowns": {}}) + today_utc_str = datetime.now(timezone.utc).date().isoformat() + key_data = self._usage_data.setdefault(key, {"daily": {"date": today_utc_str, "models": {}}, "global": {"models": {}}, "model_cooldowns": {}, "failures": {}}) + # Perform a just-in-time daily reset if the date has changed. + if key_data["daily"].get("date") != today_utc_str: + key_data["daily"] = {"date": today_utc_str, "models": {}} + + # Always record a success and reset failures + model_failures = key_data.setdefault("failures", {}).setdefault(model, {}) + model_failures["consecutive_failures"] = 0 if model in key_data.get("model_cooldowns", {}): del key_data["model_cooldowns"][model] - if key_data["daily"].get("date") != date.today().isoformat(): - # This is a simplified reset for the current key. A full reset is done in _lazy_init. 
- key_data["daily"] = {"date": date.today().isoformat(), "models": {}} - daily_model_data = key_data["daily"]["models"].setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0}) - - usage = completion_response.usage daily_model_data["success_count"] += 1 - daily_model_data["prompt_tokens"] += usage.prompt_tokens - daily_model_data["completion_tokens"] += usage.completion_tokens - - try: - cost = litellm.completion_cost(completion_response=completion_response) - daily_model_data["approx_cost"] += cost - except Exception as e: - lib_logger.warning(f"Could not calculate cost for model {model}: {e}") + + # Safely attempt to record token and cost usage + if completion_response and hasattr(completion_response, 'usage') and completion_response.usage: + usage = completion_response.usage + daily_model_data["prompt_tokens"] += usage.prompt_tokens + daily_model_data["completion_tokens"] += usage.completion_tokens + + try: + cost = litellm.completion_cost(completion_response=completion_response) + daily_model_data["approx_cost"] += cost + except Exception as e: + lib_logger.warning(f"Could not calculate cost for model {model}: {e}") + else: + lib_logger.warning(f"No usage data found in completion response for model {model}. Recording success without token count.") key_data["last_used_ts"] = time.time() await self._save_usage() - async def record_rotation_error(self, key: str, model: str, error: Exception): - """Records a rotation error and sets a cooldown asynchronously.""" + async def record_failure(self, key: str, model: str, classified_error: ClassifiedError): + """Records a failure and applies cooldowns based on an escalating backoff strategy.""" await self._lazy_init() async with self._data_lock: - key_data = self._usage_data.setdefault(key, {"daily": {"date": date.today().isoformat(), "models": {}}, "global": {"models": {}}, "model_cooldowns": {}}) - - cooldown_seconds = 86400 - error_str = str(error).lower() - - patterns = [ - r'retry_delay.*?(\d+)', - r'retrydelay.*?(\d+)s', - r'wait.*?(\d+)\s*seconds?' - ] - for pattern in patterns: - match = re.search(pattern, error_str, re.IGNORECASE) - if match: - try: - cooldown_seconds = int(match.group(1)) - break - except (ValueError, IndexError): - continue + today_utc_str = datetime.now(timezone.utc).date().isoformat() + key_data = self._usage_data.setdefault(key, {"daily": {"date": today_utc_str, "models": {}}, "global": {"models": {}}, "model_cooldowns": {}, "failures": {}}) + # Handle specific error types first + if classified_error.error_type == 'rate_limit' and classified_error.retry_after: + cooldown_seconds = classified_error.retry_after + elif classified_error.error_type == 'authentication': + # Apply a 5-minute key-level lockout for auth errors + key_data["key_cooldown_until"] = time.time() + 300 + lib_logger.warning(f"Authentication error on key ...{key[-4:]}. 
Applying 5-minute key-level lockout.") + await self._save_usage() + return # No further backoff logic needed + else: + # General backoff logic for other errors + failures_data = key_data.setdefault("failures", {}) + model_failures = failures_data.setdefault(model, {"consecutive_failures": 0}) + model_failures["consecutive_failures"] += 1 + count = model_failures["consecutive_failures"] + + backoff_tiers = {1: 10, 2: 30, 3: 60, 4: 120} + cooldown_seconds = backoff_tiers.get(count, 7200) # Default to 2 hours + + # Apply the cooldown model_cooldowns = key_data.setdefault("model_cooldowns", {}) model_cooldowns[model] = time.time() + cooldown_seconds + lib_logger.warning(f"Failure recorded for key ...{key[-4:]} with model {model}. Applying {cooldown_seconds}s cooldown.") + + # Check for key-level lockout condition + await self._check_key_lockout(key, key_data) - key_data["last_rotation_error"] = { + key_data["last_failure"] = { "timestamp": time.time(), "model": model, - "error": str(error) + "error": str(classified_error.original_exception) } await self._save_usage() + + async def _check_key_lockout(self, key: str, key_data: Dict): + """Checks if a key should be locked out due to multiple model failures.""" + long_term_lockout_models = 0 + now = time.time() + + for model, cooldown_end in key_data.get("model_cooldowns", {}).items(): + if cooldown_end - now >= 7200: # Check for 2-hour lockouts + long_term_lockout_models += 1 + + if long_term_lockout_models >= 3: + key_data["key_cooldown_until"] = now + 300 # 5-minute key lockout + lib_logger.error(f"Key ...{key[-4:]} has {long_term_lockout_models} models in long-term lockout. Applying 5-minute key-level lockout.")
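+        # Three or more models in long-term cooldown usually means the key
+        # itself is exhausted or revoked, so the whole key is benched briefly
+        # rather than burning retries model by model.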