diff --git a/.gitignore b/.gitignore index 3711fdfd..0a308835 100644 --- a/.gitignore +++ b/.gitignore @@ -126,9 +126,10 @@ staged_changes.txt launcher_config.json quota_viewer_config.json cache/antigravity/thought_signatures.json -logs/ -cache/ +/logs/ +/cache/ *.env -oauth_creds/ +/oauth_creds/ +/usage/ diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index f7ebbde5..905ab4b0 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -22,9 +22,9 @@ This architecture cleanly separates the API interface from the resilience logic, This library is the heart of the project, containing all the logic for managing a pool of API keys, tracking their usage, and handling provider interactions to ensure application resilience. -### 2.1. `client.py` - The `RotatingClient` +### 2.1. `client/rotating_client.py` - The `RotatingClient` -The `RotatingClient` is the central class that orchestrates all operations. It is designed as a long-lived, async-native object. +The `RotatingClient` is the central class that orchestrates all operations. It is now a slim facade that delegates to modular components (executor, filters, transforms) while remaining a long-lived, async-native object. #### Initialization @@ -35,7 +35,7 @@ client = RotatingClient( api_keys=api_keys, oauth_credentials=oauth_credentials, max_retries=2, - usage_file_path="key_usage.json", + usage_file_path="usage.json", configure_logging=True, global_timeout=30, abort_on_callback_error=True, @@ -50,7 +50,7 @@ client = RotatingClient( - `api_keys` (`Optional[Dict[str, List[str]]]`, default: `None`): A dictionary mapping provider names to a list of API keys. - `oauth_credentials` (`Optional[Dict[str, List[str]]]`, default: `None`): A dictionary mapping provider names to a list of file paths to OAuth credential JSON files. - `max_retries` (`int`, default: `2`): The number of times to retry a request with the *same key* if a transient server error occurs. -- `usage_file_path` (`str`, default: `"key_usage.json"`): The path to the JSON file where usage statistics are persisted. +- `usage_file_path` (`str`, optional): Base path for usage persistence (defaults to `usage/` in the data directory). The client stores per-provider files under `usage/usage_.json`. - `configure_logging` (`bool`, default: `True`): If `True`, configures the library's logger to propagate logs to the root logger. - `global_timeout` (`int`, default: `30`): A hard time limit (in seconds) for the entire request lifecycle. - `abort_on_callback_error` (`bool`, default: `True`): If `True`, any exception raised by `pre_request_callback` will abort the request. @@ -96,9 +96,9 @@ The `_safe_streaming_wrapper` is a critical component for stability. It: * **Error Interception**: Detects if a chunk contains an API error (like a quota limit) instead of content, and raises a specific `StreamedAPIError`. * **Quota Handling**: If a specific "quota exceeded" error is detected mid-stream multiple times, it can terminate the stream gracefully to prevent infinite retry loops on oversized inputs. -### 2.2. `usage_manager.py` - Stateful Concurrency & Usage Management +### 2.2. `usage/manager.py` - Stateful Concurrency & Usage Management -This class is the stateful core of the library, managing concurrency, usage tracking, cooldowns, and quota resets. +This class is the stateful core of the library, managing concurrency, usage tracking, cooldowns, and quota resets. Usage tracking now lives in the `rotator_library/usage/` package with per-provider managers and `usage/usage_.json` storage. #### Key Concepts @@ -419,7 +419,7 @@ The `CooldownManager` handles IP or account-level rate limiting that affects all - All subsequent `acquire_key()` calls for that provider will wait until the cooldown expires -### 2.10. Credential Prioritization System (`client.py` & `usage_manager.py`) +### 2.10. Credential Prioritization System (`client/rotating_client.py` & `usage/manager.py`) The library now includes an intelligent credential prioritization system that automatically detects credential tiers and ensures optimal credential selection for each request. @@ -762,7 +762,7 @@ Acquiring key for model antigravity/claude-opus-4.5. Tried keys: 0/12(17,cd:3,fc ``` **Persistence:** -Cycle state is persisted in `key_usage.json` under the `__fair_cycle__` key. +Cycle state is persisted alongside usage data in `usage/usage_.json`. ### 2.20. Custom Caps @@ -1773,7 +1773,7 @@ The system follows a strict hierarchy of survival: 2. **Credential Management (Level 2)**: OAuth tokens are cached in memory first. If credential files are deleted, the proxy continues using cached tokens. If a token refresh succeeds but the file cannot be written, the new token is buffered for retry and saved on shutdown. -3. **Usage Tracking (Level 3)**: Usage statistics (`key_usage.json`) are maintained in memory via `ResilientStateWriter`. If the file is deleted, the system tracks usage internally and attempts to recreate the file on the next save interval. Pending writes are flushed on shutdown. +3. **Usage Tracking (Level 3)**: Usage statistics (`usage/usage_.json`) are maintained in memory via `ResilientStateWriter`. If the file is deleted, the system tracks usage internally and attempts to recreate the file on the next save interval. Pending writes are flushed on shutdown. 4. **Provider Cache (Level 4)**: The provider cache tracks disk health and continues operating in memory-only mode if disk writes fail. Has its own shutdown mechanism. @@ -1813,7 +1813,7 @@ INFO:rotator_library.resilient_io:Shutdown flush: all 2 write(s) succeeded This architecture supports a robust development workflow: - **Log Cleanup**: You can safely run `rm -rf logs/` while the proxy is serving traffic. The system will recreate the directory structure on the next request. -- **Config Reset**: Deleting `key_usage.json` resets the persistence layer, but the running instance preserves its current in-memory counts for load balancing consistency. +- **Config Reset**: Deleting `usage/usage_.json` resets the persistence layer, but the running instance preserves its current in-memory counts for load balancing consistency. - **File Recovery**: If you delete a critical file, the system attempts directory auto-recreation before every write operation. - **Safe Exit**: Ctrl+C triggers graceful shutdown with final data flush attempt. diff --git a/README.md b/README.md index a7c3c438..c15ed094 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ docker run -d \ -v $(pwd)/.env:/app/.env:ro \ -v $(pwd)/oauth_creds:/app/oauth_creds \ -v $(pwd)/logs:/app/logs \ + -v $(pwd)/usage:/app/usage \ -e SKIP_OAUTH_INIT_CHECK=true \ ghcr.io/mirrowel/llm-api-key-proxy:latest ``` @@ -60,13 +61,13 @@ docker run -d \ **Using Docker Compose:** ```bash -# Create your .env file and key_usage.json first, then: +# Create your .env file and usage directory first, then: cp .env.example .env -touch key_usage.json +mkdir usage docker compose up -d ``` -> **Important:** You must create both `.env` and `key_usage.json` files before running Docker Compose. If `key_usage.json` doesn't exist, Docker will create it as a directory instead of a file, causing errors. +> **Important:** Create the `usage/` directory before running Docker Compose so usage stats persist on the host. > **Note:** For OAuth providers, complete authentication locally first using the credential tool, then mount the `oauth_creds/` directory or export credentials to environment variables. @@ -335,10 +336,11 @@ The proxy includes a powerful text-based UI for configuration and management. ### TUI Features - **🚀 Run Proxy** — Start the server with saved settings -- **âš™ī¸ Configure Settings** — Host, port, API key, request logging +- **âš™ī¸ Configure Settings** — Host, port, API key, request logging, raw I/O logging - **🔑 Manage Credentials** — Add/edit API keys and OAuth credentials -- **📊 View Status** — See configured providers and credential counts -- **🔧 Advanced Settings** — Custom providers, model definitions, concurrency +- **📊 View Provider & Advanced Settings** — Inspect providers and launch the settings tool +- **📈 View Quota & Usage Stats (Alpha)** — Usage, quota windows, fair-cycle status +- **🔄 Reload Configuration** — Refresh settings without restarting ### Configuration Files @@ -346,6 +348,8 @@ The proxy includes a powerful text-based UI for configuration and management. |------|----------| | `.env` | All credentials and advanced settings | | `launcher_config.json` | TUI-specific settings (host, port, logging) | +| `quota_viewer_config.json` | Quota viewer remotes + per-provider display toggles | +| `usage/usage_.json` | Usage persistence per provider | --- @@ -446,6 +450,7 @@ The proxy includes a powerful text-based UI for configuration and management. 📝 Logging & Debugging - **Per-request file logging** with `--enable-request-logging` +- **Raw I/O logging** with `--enable-raw-logging` (proxy boundary payloads) - **Unique request directories** with full transaction details - **Streaming chunk capture** for debugging - **Performance metadata** (duration, tokens, model used) @@ -801,6 +806,7 @@ Options: --host TEXT Host to bind (default: 0.0.0.0) --port INTEGER Port to run on (default: 8000) --enable-request-logging Enable detailed per-request logging + --enable-raw-logging Capture raw proxy I/O payloads --add-credential Launch interactive credential setup tool ``` @@ -813,6 +819,9 @@ python src/proxy_app/main.py --host 127.0.0.1 --port 9000 # Run with logging python src/proxy_app/main.py --enable-request-logging +# Run with raw I/O logging +python src/proxy_app/main.py --enable-raw-logging + # Add credentials without starting proxy python src/proxy_app/main.py --add-credential ``` @@ -850,8 +859,8 @@ The proxy is available as a multi-architecture Docker image (amd64/arm64) from G cp .env.example .env nano .env -# 2. Create key_usage.json file (required before first run) -touch key_usage.json +# 2. Create usage directory (usage_*.json files are created automatically) +mkdir usage # 3. Start the proxy docker compose up -d @@ -860,13 +869,13 @@ docker compose up -d docker compose logs -f ``` -> **Important:** You must create `key_usage.json` before running Docker Compose. If this file doesn't exist on the host, Docker will create it as a directory instead of a file, causing the container to fail. +> **Important:** Create the `usage/` directory before running Docker Compose so usage stats persist on the host. **Manual Docker Run:** ```bash -# Create key_usage.json if it doesn't exist -touch key_usage.json +# Create usage directory if it doesn't exist +mkdir usage docker run -d \ --name llm-api-proxy \ @@ -875,7 +884,7 @@ docker run -d \ -v $(pwd)/.env:/app/.env:ro \ -v $(pwd)/oauth_creds:/app/oauth_creds \ -v $(pwd)/logs:/app/logs \ - -v $(pwd)/key_usage.json:/app/key_usage.json \ + -v $(pwd)/usage:/app/usage \ -e SKIP_OAUTH_INIT_CHECK=true \ -e PYTHONUNBUFFERED=1 \ ghcr.io/mirrowel/llm-api-key-proxy:latest @@ -895,7 +904,7 @@ docker compose -f docker-compose.dev.yml up -d --build | `.env` | Configuration and API keys (read-only) | | `oauth_creds/` | OAuth credential files (persistent) | | `logs/` | Request logs and detailed logging | -| `key_usage.json` | Usage statistics persistence | +| `usage/` | Usage statistics persistence (`usage_*.json`) | **Image Tags:** diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 36458929..becc2606 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -19,8 +19,8 @@ services: - ./oauth_creds:/app/oauth_creds # Mount logs directory for persistent logging - ./logs:/app/logs - # Mount key_usage.json for usage statistics persistence - - ./key_usage.json:/app/key_usage.json + # Mount usage directory for usage statistics persistence + - ./usage:/app/usage # Optionally mount additional .env files (e.g., combined credential files) # - ./antigravity_all_combined.env:/app/antigravity_all_combined.env:ro environment: diff --git a/docker-compose.tls.yml b/docker-compose.tls.yml index e210423f..0c670b3d 100644 --- a/docker-compose.tls.yml +++ b/docker-compose.tls.yml @@ -36,8 +36,8 @@ services: - ./oauth_creds:/app/oauth_creds # Mount logs directory for persistent logging - ./logs:/app/logs - # Mount key_usage.json for usage statistics persistence - - ./key_usage.json:/app/key_usage.json + # Mount usage directory for usage statistics persistence + - ./usage:/app/usage # Optionally mount additional .env files (e.g., combined credential files) # - ./antigravity_all_combined.env:/app/antigravity_all_combined.env:ro environment: diff --git a/docker-compose.yml b/docker-compose.yml index 31964b60..027d5d91 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -17,8 +17,8 @@ services: - ./oauth_creds:/app/oauth_creds # Mount logs directory for persistent logging - ./logs:/app/logs - # Mount key_usage.json for usage statistics persistence - - ./key_usage.json:/app/key_usage.json + # Mount usage directory for usage statistics persistence + - ./usage:/app/usage # Optionally mount additional .env files (e.g., combined credential files) # - ./antigravity_all_combined.env:/app/antigravity_all_combined.env:ro environment: diff --git a/src/proxy_app/launcher_tui.py b/src/proxy_app/launcher_tui.py index 68338b02..b2fec223 100644 --- a/src/proxy_app/launcher_tui.py +++ b/src/proxy_app/launcher_tui.py @@ -377,9 +377,9 @@ def show_main_menu(self): self.console.print( Panel( Text.from_markup( - "âš ī¸ [bold yellow]INITIAL SETUP REQUIRED[/bold yellow]\n\n" + ":warning: [bold yellow]INITIAL SETUP REQUIRED[/bold yellow]\n\n" "The proxy needs initial configuration:\n" - " ❌ No .env file found\n\n" + " :x: No .env file found\n\n" "Why this matters:\n" " â€ĸ The .env file stores your credentials and settings\n" " â€ĸ PROXY_API_KEY protects your proxy from unauthorized access\n" @@ -388,7 +388,7 @@ def show_main_menu(self): ' 1. Select option "3. Manage Credentials" to launch the credential tool\n' " 2. The tool will create .env and set up PROXY_API_KEY automatically\n" " 3. You can add provider credentials (API keys or OAuth)\n\n" - "âš ī¸ Note: The credential tool adds PROXY_API_KEY by default.\n" + ":warning: Note: The credential tool adds PROXY_API_KEY by default.\n" " You can remove it later if you want an unsecured proxy." ), border_style="yellow", @@ -401,12 +401,12 @@ def show_main_menu(self): self.console.print( Panel( Text.from_markup( - "âš ī¸ [bold red]SECURITY WARNING: PROXY_API_KEY Not Set[/bold red]\n\n" + ":warning: [bold red]SECURITY WARNING: PROXY_API_KEY Not Set[/bold red]\n\n" "Your proxy is currently UNSECURED!\n" "Anyone can access it without authentication.\n\n" "This is a serious security risk if your proxy is accessible\n" "from the internet or untrusted networks.\n\n" - "👉 [bold]Recommended:[/bold] Set PROXY_API_KEY in .env file\n" + ":point_right: [bold]Recommended:[/bold] Set PROXY_API_KEY in .env file\n" ' Use option "2. Configure Proxy Settings" → "3. Set Proxy API Key"\n' ' or option "3. Manage Credentials"' ), @@ -417,15 +417,15 @@ def show_main_menu(self): # Show config self.console.print() - self.console.print("[bold]📋 Proxy Configuration[/bold]") + self.console.print("[bold]:clipboard: Proxy Configuration[/bold]") self.console.print("━" * 70) self.console.print(f" Host: {self.config.config['host']}") self.console.print(f" Port: {self.config.config['port']}") self.console.print( - f" Transaction Logging: {'✅ Enabled' if self.config.config['enable_request_logging'] else '❌ Disabled'}" + f" Transaction Logging: {':white_check_mark: Enabled' if self.config.config['enable_request_logging'] else ':x: Disabled'}" ) self.console.print( - f" Raw I/O Logging: {'✅ Enabled' if self.config.config.get('enable_raw_logging', False) else '❌ Disabled'}" + f" Raw I/O Logging: {':white_check_mark: Enabled' if self.config.config.get('enable_raw_logging', False) else ':x: Disabled'}" ) # Show actual API key value @@ -437,7 +437,7 @@ def show_main_menu(self): # Show status summary self.console.print() - self.console.print("[bold]📊 Status Summary[/bold]") + self.console.print("[bold]:bar_chart: Status Summary[/bold]") self.console.print("━" * 70) provider_count = len(credentials) custom_count = len(custom_bases) @@ -458,24 +458,26 @@ def show_main_menu(self): self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]đŸŽ¯ Main Menu[/bold]") + self.console.print("[bold]:dart: Main Menu[/bold]") self.console.print() if show_warning: - self.console.print(" 1. â–ļī¸ Run Proxy Server") - self.console.print(" 2. âš™ī¸ Configure Proxy Settings") + self.console.print(" 1. :arrow_forward: Run Proxy Server") + self.console.print(" 2. :gear: Configure Proxy Settings") self.console.print( - " 3. 🔑 Manage Credentials âŦ…ī¸ [bold yellow]Start here![/bold yellow]" + " 3. :key: Manage Credentials :arrow_left: [bold yellow]Start here![/bold yellow]" ) else: - self.console.print(" 1. â–ļī¸ Run Proxy Server") - self.console.print(" 2. âš™ī¸ Configure Proxy Settings") - self.console.print(" 3. 🔑 Manage Credentials") + self.console.print(" 1. :arrow_forward: Run Proxy Server") + self.console.print(" 2. :gear: Configure Proxy Settings") + self.console.print(" 3. :key: Manage Credentials") - self.console.print(" 4. 📊 View Provider & Advanced Settings") - self.console.print(" 5. 📈 View Quota & Usage Stats (Alpha)") - self.console.print(" 6. 🔄 Reload Configuration") - self.console.print(" 7. â„šī¸ About") - self.console.print(" 8. đŸšĒ Exit") + self.console.print(" 4. :bar_chart: View Provider & Advanced Settings") + self.console.print( + " 5. :chart_with_upwards_trend: View Quota & Usage Stats (Alpha)" + ) + self.console.print(" 6. :arrows_counterclockwise: Reload Configuration") + self.console.print(" 7. :information_source: About") + self.console.print(" 8. :door: Exit") self.console.print() self.console.print("━" * 70) @@ -500,7 +502,9 @@ def show_main_menu(self): elif choice == "6": load_dotenv(dotenv_path=_get_env_file(), override=True) self.config = LauncherConfig() # Reload config - self.console.print("\n[green]✅ Configuration reloaded![/green]") + self.console.print( + "\n[green]:white_check_mark: Configuration reloaded![/green]" + ) elif choice == "7": self.show_about() elif choice == "8": @@ -518,7 +522,7 @@ def confirm_setting_change(self, setting_name: str, warning_lines: list) -> bool self.console.print( Panel( Text.from_markup( - f"[bold yellow]âš ī¸ WARNING: You are about to change the {setting_name}[/bold yellow]\n\n" + f"[bold yellow]:warning: WARNING: You are about to change the {setting_name}[/bold yellow]\n\n" + "\n".join(warning_lines) + "\n\n[bold]If you are not sure about changing this - don't.[/bold]" ), @@ -548,36 +552,39 @@ def show_config_menu(self): self.console.print( Panel.fit( - "[bold cyan]âš™ī¸ Proxy Configuration[/bold cyan]", border_style="cyan" + "[bold cyan]:gear: Proxy Configuration[/bold cyan]", + border_style="cyan", ) ) self.console.print() - self.console.print("[bold]📋 Current Settings[/bold]") + self.console.print("[bold]:clipboard: Current Settings[/bold]") self.console.print("━" * 70) self.console.print(f" Host: {self.config.config['host']}") self.console.print(f" Port: {self.config.config['port']}") self.console.print( - f" Transaction Logging: {'✅ Enabled' if self.config.config['enable_request_logging'] else '❌ Disabled'}" + f" Transaction Logging: {':white_check_mark: Enabled' if self.config.config['enable_request_logging'] else ':x: Disabled'}" ) self.console.print( - f" Raw I/O Logging: {'✅ Enabled' if self.config.config.get('enable_raw_logging', False) else '❌ Disabled'}" + f" Raw I/O Logging: {':white_check_mark: Enabled' if self.config.config.get('enable_raw_logging', False) else ':x: Disabled'}" ) self.console.print( - f" Proxy API Key: {'✅ Set' if os.getenv('PROXY_API_KEY') else '❌ Not Set'}" + f" Proxy API Key: {':white_check_mark: Set' if os.getenv('PROXY_API_KEY') else ':x: Not Set'}" ) self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]âš™ī¸ Configuration Options[/bold]") + self.console.print("[bold]:gear: Configuration Options[/bold]") self.console.print() - self.console.print(" 1. 🌐 Set Host IP") + self.console.print(" 1. :globe_with_meridians: Set Host IP") self.console.print(" 2. 🔌 Set Port") - self.console.print(" 3. 🔑 Set Proxy API Key") - self.console.print(" 4. 📝 Toggle Transaction Logging") - self.console.print(" 5. 📋 Toggle Raw I/O Logging") - self.console.print(" 6. 🔄 Reset to Default Settings") + self.console.print(" 3. :key: Set Proxy API Key") + self.console.print(" 4. :memo: Toggle Transaction Logging") + self.console.print(" 5. :clipboard: Toggle Raw I/O Logging") + self.console.print( + " 6. :arrows_counterclockwise: Reset to Default Settings" + ) self.console.print(" 7. â†Šī¸ Back to Main Menu") self.console.print() @@ -609,7 +616,9 @@ def show_config_menu(self): "Enter new host IP", default=self.config.config["host"] ) self.config.update(host=new_host) - self.console.print(f"\n[green]✅ Host updated to: {new_host}[/green]") + self.console.print( + f"\n[green]:white_check_mark: Host updated to: {new_host}[/green]" + ) elif choice == "2": # Show warning and require confirmation confirmed = self.confirm_setting_change( @@ -630,10 +639,10 @@ def show_config_menu(self): if 1 <= new_port <= 65535: self.config.update(port=new_port) self.console.print( - f"\n[green]✅ Port updated to: {new_port}[/green]" + f"\n[green]:white_check_mark: Port updated to: {new_port}[/green]" ) else: - self.console.print("\n[red]❌ Port must be between 1-65535[/red]") + self.console.print("\n[red]:x: Port must be between 1-65535[/red]") elif choice == "3": # Show warning and require confirmation confirmed = self.confirm_setting_change( @@ -641,11 +650,11 @@ def show_config_menu(self): [ "This is the authentication key that applications use to access your proxy.", "", - "[bold red]âš ī¸ Changing this will BREAK all applications currently configured", + "[bold red]:warning: Changing this will BREAK all applications currently configured", " with the existing API key![/bold red]", "", - "[bold cyan]💡 If you want to add provider API keys (OpenAI, Gemini, etc.),", - ' go to "3. 🔑 Manage Credentials" in the main menu instead.[/bold cyan]', + "[bold cyan]:bulb: If you want to add provider API keys (OpenAI, Gemini, etc.),", + ' go to "3. :key: Manage Credentials" in the main menu instead.[/bold cyan]', ], ) if not confirmed: @@ -661,7 +670,7 @@ def show_config_menu(self): # If setting to empty, show additional warning if not new_key: self.console.print( - "\n[bold red]âš ī¸ Authentication will be DISABLED - anyone can access your proxy![/bold red]" + "\n[bold red]:warning: Authentication will be DISABLED - anyone can access your proxy![/bold red]" ) Prompt.ask("Press Enter to continue", default="") @@ -669,12 +678,12 @@ def show_config_menu(self): if new_key: self.console.print( - "\n[green]✅ Proxy API Key updated successfully![/green]" + "\n[green]:white_check_mark: Proxy API Key updated successfully![/green]" ) self.console.print(" Updated in .env file") else: self.console.print( - "\n[yellow]âš ī¸ Proxy API Key cleared - authentication disabled![/yellow]" + "\n[yellow]:warning: Proxy API Key cleared - authentication disabled![/yellow]" ) self.console.print(" Updated in .env file") else: @@ -683,13 +692,13 @@ def show_config_menu(self): current = self.config.config["enable_request_logging"] self.config.update(enable_request_logging=not current) self.console.print( - f"\n[green]✅ Transaction Logging {'enabled' if not current else 'disabled'}![/green]" + f"\n[green]:white_check_mark: Transaction Logging {'enabled' if not current else 'disabled'}![/green]" ) elif choice == "5": current = self.config.config.get("enable_raw_logging", False) self.config.update(enable_raw_logging=not current) self.console.print( - f"\n[green]✅ Raw I/O Logging {'enabled' if not current else 'disabled'}![/green]" + f"\n[green]:white_check_mark: Raw I/O Logging {'enabled' if not current else 'disabled'}![/green]" ) elif choice == "6": # Reset to Default Settings @@ -725,7 +734,7 @@ def show_config_menu(self): else f" Raw I/O Logging {'Disabled':20} → Disabled", f" Proxy API Key {current_api_key[:20]:20} → {default_api_key}", "", - "[bold red]âš ī¸ This may break applications configured with current settings![/bold red]", + "[bold red]:warning: This may break applications configured with current settings![/bold red]", ] confirmed = self.confirm_setting_change( @@ -744,7 +753,7 @@ def show_config_menu(self): LauncherConfig.update_proxy_api_key(default_api_key) self.console.print( - "\n[green]✅ All settings have been reset to defaults![/green]" + "\n[green]:white_check_mark: All settings have been reset to defaults![/green]" ) self.console.print(f" Host: {default_host}") self.console.print(f" Port: {default_port}") @@ -769,14 +778,14 @@ def show_provider_settings_menu(self): self.console.print( Panel.fit( - "[bold cyan]📊 Provider & Advanced Settings[/bold cyan]", + "[bold cyan]:bar_chart: Provider & Advanced Settings[/bold cyan]", border_style="cyan", ) ) # Configured Providers self.console.print() - self.console.print("[bold]📊 Configured Providers[/bold]") + self.console.print("[bold]:bar_chart: Configured Providers[/bold]") self.console.print("━" * 70) if credentials: for provider, info in credentials.items(): @@ -795,14 +804,16 @@ def show_provider_settings_menu(self): if info["custom"]: display += " (Custom)" - self.console.print(f" ✅ {provider_name:20} {display}") + self.console.print( + f" :white_check_mark: {provider_name:20} {display}" + ) else: self.console.print(" [dim]No providers configured[/dim]") # Custom API Bases if custom_bases: self.console.print() - self.console.print("[bold]🌐 Custom API Bases[/bold]") + self.console.print("[bold]:globe_with_meridians: Custom API Bases[/bold]") self.console.print("━" * 70) for provider, base in custom_bases.items(): self.console.print(f" â€ĸ {provider:15} {base}") @@ -829,7 +840,7 @@ def show_provider_settings_menu(self): # Model Filters (basic info only) if filters: self.console.print() - self.console.print("[bold]đŸŽ¯ Model Filters[/bold]") + self.console.print("[bold]:dart: Model Filters[/bold]") self.console.print("━" * 70) for provider, filter_info in filters.items(): status_parts = [] @@ -838,7 +849,7 @@ def show_provider_settings_menu(self): if filter_info["has_ignore"]: status_parts.append("Ignore list") status = " + ".join(status_parts) if status_parts else "None" - self.console.print(f" â€ĸ {provider:15} ✅ {status}") + self.console.print(f" â€ĸ {provider:15} :white_check_mark: {status}") # Provider-Specific Settings (deferred to Settings Tool to avoid heavy imports) self.console.print() @@ -852,21 +863,21 @@ def show_provider_settings_menu(self): self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]💡 Actions[/bold]") + self.console.print("[bold]:bulb: Actions[/bold]") self.console.print() self.console.print( - " 1. 🔧 Launch Settings Tool (configure advanced settings)" + " 1. :wrench: Launch Settings Tool (configure advanced settings)" ) self.console.print(" 2. â†Šī¸ Back to Main Menu") self.console.print() self.console.print("━" * 70) self.console.print( - "[dim]â„šī¸ Advanced settings are stored in .env file.\n Use the Settings Tool to configure them interactively.[/dim]" + "[dim]:information_source: Advanced settings are stored in .env file.\n Use the Settings Tool to configure them interactively.[/dim]" ) self.console.print() self.console.print( - "[dim]âš ī¸ Note: Settings Tool supports only common configuration types.\n For complex settings, edit .env directly.[/dim]" + "[dim]:warning: Note: Settings Tool supports only common configuration types.\n For complex settings, edit .env directly.[/dim]" ) self.console.print() @@ -959,7 +970,8 @@ def show_about(self): self.console.print( Panel.fit( - "[bold cyan]â„šī¸ About LLM API Key Proxy[/bold cyan]", border_style="cyan" + "[bold cyan]:information_source: About LLM API Key Proxy[/bold cyan]", + border_style="cyan", ) ) @@ -1005,7 +1017,7 @@ def show_about(self): ) self.console.print() - self.console.print("[bold]📝 License & Credits[/bold]") + self.console.print("[bold]:memo: License & Credits[/bold]") self.console.print("━" * 70) self.console.print(" Made with â¤ī¸ by the community") self.console.print(" Open source - contributions welcome!") @@ -1024,7 +1036,7 @@ def run_proxy(self): self.console.print( Panel( Text.from_markup( - "âš ī¸ [bold yellow]Setup Required[/bold yellow]\n\n" + ":warning: [bold yellow]Setup Required[/bold yellow]\n\n" "Cannot start without .env.\n" "Launching credential tool..." ), @@ -1046,7 +1058,7 @@ def run_proxy(self): # Check again after credential tool if not os.getenv("PROXY_API_KEY"): self.console.print( - "\n[red]❌ PROXY_API_KEY still not set. Cannot start proxy.[/red]" + "\n[red]:x: PROXY_API_KEY still not set. Cannot start proxy.[/red]" ) return diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py index 12014bdc..3e4bbbbc 100644 --- a/src/proxy_app/main.py +++ b/src/proxy_app/main.py @@ -603,6 +603,8 @@ async def process_credential(provider: str, path: str, provider_instance): max_concurrent_requests_per_key=max_concurrent_requests_per_key, ) + await client.initialize_usage_managers() + # Log loaded credentials summary (compact, always visible for deployment verification) # _api_summary = ', '.join([f"{p}:{len(c)}" for p, c in api_keys.items()]) if api_keys else "none" # _oauth_summary = ', '.join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()]) if oauth_credentials else "none" @@ -956,7 +958,9 @@ async def chat_completions( is_streaming = request_data.get("stream", False) if is_streaming: - response_generator = client.acompletion(request=request, **request_data) + response_generator = await client.acompletion( + request=request, **request_data + ) return StreamingResponse( streaming_response_wrapper( request, request_data, response_generator, raw_logger @@ -1649,10 +1653,10 @@ def show_onboarding_message(): border_style="cyan", ) ) - console.print("[bold yellow]âš ī¸ Configuration Required[/bold yellow]\n") + console.print("[bold yellow]:warning: Configuration Required[/bold yellow]\n") console.print("The proxy needs initial configuration:") - console.print(" [red]❌ No .env file found[/red]") + console.print(" [red]:x: No .env file found[/red]") console.print("\n[bold]Why this matters:[/bold]") console.print(" â€ĸ The .env file stores your credentials and settings") @@ -1665,7 +1669,7 @@ def show_onboarding_message(): console.print(" 3. The proxy will then start normally") console.print( - "\n[bold yellow]âš ī¸ Note:[/bold yellow] The credential tool adds PROXY_API_KEY by default." + "\n[bold yellow]:warning: Note:[/bold yellow] The credential tool adds PROXY_API_KEY by default." ) console.print(" You can remove it later if you want an unsecured proxy.\n") @@ -1708,13 +1712,15 @@ def show_onboarding_message(): # Verify onboarding is complete if needs_onboarding(): - console.print("\n[bold red]❌ Configuration incomplete.[/bold red]") + console.print("\n[bold red]:x: Configuration incomplete.[/bold red]") console.print( "The proxy still cannot start. Please ensure PROXY_API_KEY is set in .env\n" ) sys.exit(1) else: - console.print("\n[bold green]✅ Configuration complete![/bold green]") + console.print( + "\n[bold green]:white_check_mark: Configuration complete![/bold green]" + ) console.print("\nStarting proxy server...\n") import uvicorn diff --git a/src/proxy_app/quota_viewer.py b/src/proxy_app/quota_viewer.py index 6cdb224d..687f96cd 100644 --- a/src/proxy_app/quota_viewer.py +++ b/src/proxy_app/quota_viewer.py @@ -6,42 +6,6 @@ Connects to a running proxy to display quota and usage statistics. Uses only httpx + rich (no heavy rotator_library imports). - -TODO: Missing Features & Improvements -====================================== - -Display Improvements: -- [ ] Add color legend/help screen explaining status colors and symbols -- [ ] Show credential email/project ID if available (currently just filename) -- [ ] Add keyboard shortcut hints (e.g., "Press ? for help") -- [ ] Support terminal resize / responsive layout - -Global Stats Fix: -- [ ] HACK: Global requests currently set to current period requests only - (see client.py get_quota_stats). This doesn't include archived stats. - Fix requires tracking archived requests per quota group in usage_manager.py - to avoid double-counting models that share quota groups. - -Data & Refresh: -- [ ] Auto-refresh option (configurable interval) -- [ ] Show last refresh timestamp more prominently -- [ ] Cache invalidation when switching between current/global view -- [ ] Support for non-OAuth providers (API keys like nvapi-*, gsk_*, etc.) - -Remote Management: -- [ ] Test connection before saving remote -- [ ] Import/export remote configurations -- [ ] SSH tunnel support for remote proxies - -Quota Groups: -- [ ] Show which models are in each quota group (expandable) -- [ ] Historical quota usage graphs (if data available) -- [ ] Alerts/notifications when quota is low - -Credential Details: -- [ ] Show per-model breakdown within quota groups -- [ ] Edit credential priority/tier manually -- [ ] Disable/enable individual credentials """ import os @@ -62,6 +26,56 @@ from .quota_viewer_config import QuotaViewerConfig +# ============================================================================= +# DISPLAY CONFIGURATION - Adjust these values to customize the layout +# ============================================================================= + +# Summary screen table column widths +TABLE_PROVIDER_WIDTH = 12 +TABLE_CREDS_WIDTH = 3 +TABLE_QUOTA_STATUS_WIDTH = 62 +TABLE_REQUESTS_WIDTH = 5 +TABLE_TOKENS_WIDTH = 20 +TABLE_COST_WIDTH = 6 + +# Quota status formatting in summary screen +QUOTA_NAME_WIDTH = 15 # Width for quota group name (e.g., "claude:") +QUOTA_USAGE_WIDTH = 11 # Width for usage ratio (e.g., "2071/2700") +QUOTA_PCT_WIDTH = 6 # Width for percentage (e.g., "76.7%") +QUOTA_BAR_WIDTH = 10 # Width for progress bar + +# Detail view credential panel formatting +DETAIL_GROUP_NAME_WIDTH = ( + 18 # Width for group name in detail view (handles "g25-flash (daily)") +) +DETAIL_USAGE_WIDTH = ( + 16 # Width for usage ratio in detail view (handles "3000/3000(5000)") +) +DETAIL_PCT_WIDTH = 7 # Width for percentage in detail view + +# ============================================================================= +# STATUS DISPLAY CONFIGURATION +# ============================================================================= + +# Credential status icons and colors: (icon, label, color) +# Using Rich emoji markup :name: for consistent width handling +STATUS_DISPLAY = { + "active": (":white_check_mark:", "Active", "green"), + "cooldown": (":stopwatch:", "Cooldown", "yellow"), + "exhausted": (":no_entry:", "Exhausted", "red"), + "mixed": (":warning:", "Mixed", "yellow"), +} + +# Per-group indicator icons (using Rich emoji markup for proper width handling) +INDICATOR_ICONS = { + "fair_cycle": ":scales:", # âš–ī¸ - Rich will handle width + "custom_cap": ":bar_chart:", # 📊 + "cooldown": ":stopwatch:", # âąī¸ +} + +# ============================================================================= + + def clear_screen(): """Clear the terminal screen.""" os.system("cls" if os.name == "nt" else "clear") @@ -187,19 +201,127 @@ def format_cooldown(seconds: int) -> str: return f"{hours}h {mins}m" if mins > 0 else f"{hours}h" -def natural_sort_key(item: Dict[str, Any]) -> List: +def natural_sort_key(item: Any) -> List: """ Generate a sort key for natural/numeric sorting. Sorts credentials like proj-1, proj-2, proj-10 correctly instead of alphabetically (proj-1, proj-10, proj-2). + + Handles both dict items (new API format) and strings. """ - identifier = item.get("identifier", "") + if isinstance(item, dict): + identifier = item.get("identifier", item.get("stable_id", "")) + else: + identifier = str(item) # Split into text and numeric parts parts = re.split(r"(\d+)", identifier) return [int(p) if p.isdigit() else p.lower() for p in parts] +def get_credentials_list(prov_stats: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Convert credentials from dict format to list. + + The new API returns credentials as a dict keyed by stable_id. + This function converts it to a list for iteration and sorting. + """ + credentials = prov_stats.get("credentials", {}) + if isinstance(credentials, list): + return credentials + if isinstance(credentials, dict): + return list(credentials.values()) + return [] + + +def get_credential_stats( + cred: Dict[str, Any], view_mode: str = "current" +) -> Dict[str, Any]: + """ + Extract display stats from a credential with field name adaptation. + + Maps new API field names to what the viewer expects: + - totals.request_count -> requests + - totals.last_used_at -> last_used_ts + - totals.approx_cost -> approx_cost + - Derive tokens from totals + """ + totals = cred.get("totals", {}) + + # For global view mode, we'd need global totals (currently same as totals) + if view_mode == "global": + stats_source = cred.get("global", totals) + if stats_source == totals: + stats_source = totals + else: + stats_source = totals + + # Calculate proper token stats + prompt_tokens = stats_source.get("prompt_tokens", 0) + cache_read = stats_source.get("prompt_tokens_cache_read", 0) + output_tokens = stats_source.get("output_tokens", 0) + + # Total input = uncached (prompt_tokens) + cached (cache_read) + input_total = prompt_tokens + cache_read + input_cached = cache_read + input_uncached = prompt_tokens + + cache_pct = round(input_cached / input_total * 100, 1) if input_total > 0 else 0 + + return { + "requests": stats_source.get("request_count", 0), + "last_used_ts": stats_source.get("last_used_at"), + "approx_cost": stats_source.get("approx_cost"), + "tokens": { + "input_cached": input_cached, + "input_uncached": input_uncached, + "input_cache_pct": cache_pct, + "output": output_tokens, + }, + } + + +def provider_sort_key(item: Tuple[str, Dict[str, Any]]) -> Tuple: + """ + Sort key for providers. + + Order: has quota_groups -> has activity -> request count -> credential count + """ + name, stats = item + has_quota_groups = bool(stats.get("quota_groups")) + has_activity = stats.get("total_requests", 0) > 0 + return ( + not has_quota_groups, # False (has groups) sorts first + not has_activity, # False (has activity) sorts first + -stats.get("total_requests", 0), # Higher requests first + -stats.get("credential_count", 0), # Higher creds first + name.lower(), # Alphabetically last + ) + + +def quota_group_sort_key(item: Tuple[str, Dict[str, Any]]) -> Tuple: + """ + Sort key for quota groups. + + Order: by total quota limit (lowest first), then alphabetically. + Groups without limits sort last. + """ + name, group_stats = item + windows = group_stats.get("windows", {}) + + if not windows: + return (float("inf"), name) # No windows = sort last + + # Find minimum total_max across windows + min_limit = float("inf") + for window_stats in windows.values(): + total_max = window_stats.get("total_max", 0) + if total_max > 0: + min_limit = min(min_limit, total_max) + + return (min_limit, name) + + class QuotaViewer: """Main Quota Viewer TUI class.""" @@ -210,7 +332,8 @@ def __init__(self, config: Optional[QuotaViewerConfig] = None): Args: config: Optional config object. If not provided, one will be created. """ - self.console = Console() + # Use emoji_variant="text" for more consistent width calculations + self.console = Console(emoji_variant="text") self.config = config or QuotaViewerConfig() self.config.sync_with_launcher_config() @@ -651,18 +774,18 @@ def show_summary_screen(self): # View mode indicator if self.view_mode == "global": - view_label = "[magenta]📊 Global/Lifetime[/magenta]" + view_label = "[magenta]:bar_chart: Global/Lifetime[/magenta]" else: - view_label = "[cyan]📈 Current Period[/cyan]" + view_label = "[cyan]:chart_with_upwards_trend: Current Period[/cyan]" self.console.print("━" * 78) self.console.print( - f"[bold cyan]📈 Quota & Usage Statistics[/bold cyan] | {view_label}" + f"[bold cyan]:chart_with_upwards_trend: Quota & Usage Statistics[/bold cyan] | {view_label}" ) self.console.print("━" * 78) self.console.print( f"Connected to: [bold]{remote_name}[/bold] ({connection_display}) " - f"[green]✅[/green] | {data_age}" + f"[green]:white_check_mark:[/green] | {data_age}" ) self.console.print() @@ -673,17 +796,18 @@ def show_summary_screen(self): table = Table( box=None, show_header=True, header_style="bold", padding=(0, 1) ) - table.add_column("Provider", style="cyan", min_width=10) - table.add_column("Creds", justify="center", min_width=5) - table.add_column("Quota Status", min_width=28) - table.add_column("Requests", justify="right", min_width=8) - table.add_column("Tokens (in/out)", min_width=20) - table.add_column("Cost", justify="right", min_width=6) + table.add_column("Provider", style="cyan", min_width=TABLE_PROVIDER_WIDTH) + table.add_column("Creds", justify="center", min_width=TABLE_CREDS_WIDTH) + table.add_column("Quota Status", min_width=TABLE_QUOTA_STATUS_WIDTH) + table.add_column("Req.", justify="right", min_width=TABLE_REQUESTS_WIDTH) + table.add_column("Tok. in/out(cached)", min_width=TABLE_TOKENS_WIDTH) + table.add_column("Cost", justify="right", min_width=TABLE_COST_WIDTH) providers = self.cached_stats.get("providers", {}) - provider_list = list(providers.keys()) + # Sort providers: quota_groups -> activity -> requests -> creds + sorted_providers = sorted(providers.items(), key=provider_sort_key) - for idx, (provider, prov_stats) in enumerate(providers.items(), 1): + for idx, (provider, prov_stats) in enumerate(sorted_providers, 1): cred_count = prov_stats.get("credential_count", 0) # Use global stats if in global mode @@ -703,7 +827,7 @@ def show_summary_screen(self): ) output = tokens.get("output", 0) cache_pct = tokens.get("input_cache_pct", 0) - token_str = f"{format_tokens(input_total)}/{format_tokens(output)} ({cache_pct}% cached)" + token_str = f"{format_tokens(input_total)}/{format_tokens(output)} ({cache_pct}%)" # Format cost cost_str = format_cost(cost_value) @@ -712,63 +836,100 @@ def show_summary_screen(self): quota_groups = prov_stats.get("quota_groups", {}) if quota_groups: quota_lines = [] - for group_name, group_stats in quota_groups.items(): - # Use remaining requests (not used) so percentage matches displayed value - total_remaining = group_stats.get("total_requests_remaining", 0) - total_max = group_stats.get("total_requests_max", 0) - total_pct = group_stats.get("total_remaining_pct") - tiers = group_stats.get("tiers", {}) + # Sort quota groups by minimum remaining % (lowest first) + sorted_groups = sorted( + quota_groups.items(), key=quota_group_sort_key + ) - # Format tier info: "5(15)f/2s" = 5 active out of 15 free, 2 standard all active - # Sort by priority (lower number = higher priority, appears first) - tier_parts = [] - sorted_tiers = sorted( - tiers.items(), key=lambda x: x[1].get("priority", 10) - ) - for tier_name, tier_info in sorted_tiers: - if tier_name == "unknown": - continue # Skip unknown tiers in display - total_t = tier_info.get("total", 0) - active_t = tier_info.get("active", 0) - # Use first letter: standard-tier -> s, free-tier -> f - short = tier_name.replace("-tier", "")[0] - - if active_t < total_t: - # Some exhausted - show active(total) - tier_parts.append(f"{active_t}({total_t}){short}") + for group_name, group_stats in sorted_groups: + tiers = group_stats.get("tiers", {}) + windows = group_stats.get("windows", {}) + fc_summary = group_stats.get("fair_cycle_summary", {}) + + if not windows: + # No windows = no data, skip + continue + + # Process each window for this group + for window_name, window_stats in windows.items(): + total_remaining = window_stats.get("total_remaining", 0) + total_max = window_stats.get("total_max", 0) + total_pct = window_stats.get("remaining_pct") + tier_availability = window_stats.get( + "tier_availability", {} + ) + + # Format tier info using per-window availability + # "15(18)s" = 15 available out of 18 standard-tier + tier_parts = [] + # Sort tiers by priority (from provider-level tiers) + sorted_tier_names = sorted( + tier_availability.keys(), + key=lambda t: tiers.get(t, {}).get("priority", 10), + ) + for tier_name in sorted_tier_names: + if tier_name == "unknown": + continue + avail_info = tier_availability[tier_name] + total_t = avail_info.get("total", 0) + available_t = avail_info.get("available", 0) + # Use first letter: standard-tier -> s, free-tier -> f + short = tier_name.replace("-tier", "")[0] + + if available_t < total_t: + tier_parts.append( + f"{available_t}({total_t}){short}" + ) + else: + tier_parts.append(f"{total_t}{short}") + + tier_str = "/".join(tier_parts) if tier_parts else "" + + # Only show tier info if this group has limits + if total_max == 0: + tier_str = "" + + # Build FC summary string if any credentials are FC exhausted + fc_str = "" + fc_exhausted = fc_summary.get("exhausted_count", 0) + fc_total = fc_summary.get("total_count", 0) + if fc_exhausted > 0: + fc_str = f"[yellow]{INDICATOR_ICONS['fair_cycle']} {fc_exhausted}/{fc_total}[/yellow]" + + # Determine color based on remaining percentage and FC status + if total_pct is not None: + if total_pct <= 10: + color = "red" + elif total_pct < 30 or fc_exhausted > 0: + color = "yellow" + else: + color = "green" else: - # All active - just show total - tier_parts.append(f"{total_t}{short}") - tier_str = "/".join(tier_parts) if tier_parts else "" - - # Determine color based purely on remaining percentage - if total_pct is not None: - if total_pct <= 10: - color = "red" - elif total_pct < 30: - color = "yellow" + color = "dim" + + pct_str = f"{total_pct}%" if total_pct is not None else "?" + + # Format: "group (window): remaining/max pct% bar tier_info fc_info" + # Show window name if multiple windows exist + if len(windows) > 1: + display_name = f"{group_name} ({window_name})" else: - color = "green" - else: - color = "dim" + display_name = group_name - bar = create_progress_bar(total_pct) - pct_str = f"{total_pct}%" if total_pct is not None else "?" + display_name_trunc = display_name[: QUOTA_NAME_WIDTH - 1] + usage_str = f"{total_remaining}/{total_max}" + bar = create_progress_bar(total_pct, QUOTA_BAR_WIDTH) - # Build status suffix (just tiers now, no outer parens) - status = tier_str + # Build the line with tier info and FC summary + line_parts = [ + f"[{color}]{display_name_trunc + ':':<{QUOTA_NAME_WIDTH}}{usage_str:>{QUOTA_USAGE_WIDTH}} {pct_str:>{QUOTA_PCT_WIDTH}} {bar}[/{color}]" + ] + if tier_str: + line_parts.append(tier_str) + if fc_str: + line_parts.append(fc_str) - # Fixed-width format for aligned bars - # Adjust these to change column spacing: - QUOTA_NAME_WIDTH = 10 # name + colon, left-aligned - QUOTA_USAGE_WIDTH = ( - 12 # remaining/max ratio, right-aligned (handles 100k+) - ) - display_name = group_name[: QUOTA_NAME_WIDTH - 1] - usage_str = f"{total_remaining}/{total_max}" - quota_lines.append( - f"[{color}]{display_name + ':':<{QUOTA_NAME_WIDTH}}{usage_str:>{QUOTA_USAGE_WIDTH}} {pct_str:>4} {bar}[/{color}] {status}" - ) + quota_lines.append(" ".join(line_parts)) # First line goes in the main row first_quota = quota_lines[0] if quota_lines else "-" @@ -795,9 +956,14 @@ def show_summary_screen(self): ) # Add separator between providers (except last) - if idx < len(providers): + if idx < len(sorted_providers): table.add_row( - "─" * 10, "─" * 4, "─" * 26, "─" * 7, "─" * 20, "─" * 6 + "─" * TABLE_PROVIDER_WIDTH, + "─" * TABLE_CREDS_WIDTH, + "─" * TABLE_QUOTA_STATUS_WIDTH, + "─" * TABLE_REQUESTS_WIDTH, + "─" * TABLE_TOKENS_WIDTH, + "─" * TABLE_COST_WIDTH, ) self.console.print(table) @@ -832,9 +998,10 @@ def show_summary_screen(self): self.console.print("━" * 78) self.console.print() - # Build provider menu options + # Build provider menu options (use same sorted order as display) providers = self.cached_stats.get("providers", {}) if self.cached_stats else {} - provider_list = list(providers.keys()) + sorted_providers = sorted(providers.items(), key=provider_sort_key) + provider_list = [name for name, _ in sorted_providers] for idx, provider in enumerate(provider_list, 1): self.console.print(f" {idx}. View [cyan]{provider}[/cyan] details") @@ -886,7 +1053,7 @@ def show_provider_detail_screen(self, provider: str): self.console.print("━" * 78) self.console.print( - f"[bold cyan]📊 {provider.title()} - Detailed Stats[/bold cyan] | {view_label}" + f"[bold cyan]:bar_chart: {provider.title()} - Detailed Stats[/bold cyan] | {view_label}" ) self.console.print("━" * 78) self.console.print() @@ -895,7 +1062,7 @@ def show_provider_detail_screen(self, provider: str): self.console.print("[yellow]No data available.[/yellow]") else: prov_stats = self.cached_stats.get("providers", {}).get(provider, {}) - credentials = prov_stats.get("credentials", []) + credentials = get_credentials_list(prov_stats) # Sort credentials naturally (1, 2, 10 not 1, 10, 2) credentials = sorted(credentials, key=natural_sort_key) @@ -913,8 +1080,6 @@ def show_provider_detail_screen(self, provider: str): self.console.print("━" * 78) self.console.print() self.console.print(" G. Toggle view mode (current/global)") - self.console.print(" R. Reload stats (from proxy cache)") - self.console.print(" RA. Reload all stats") # Force refresh options (only for providers that support it) has_quota_groups = bool( @@ -924,18 +1089,29 @@ def show_provider_detail_screen(self, provider: str): .get("quota_groups") ) + # Model toggle option (only show if provider has quota groups) - MOVED UP + if has_quota_groups: + show_models_status = ( + "ON" if self.config.get_show_models(provider) else "OFF" + ) + self.console.print( + f" T. Toggle model details ({show_models_status})" + ) + + self.console.print(" R. Reload stats (from proxy cache)") + self.console.print(" RA. Reload all stats") + if has_quota_groups: self.console.print() self.console.print( f" F. [yellow]Force refresh ALL {provider} quotas from API[/yellow]" ) - credentials = ( - self.cached_stats.get("providers", {}) - .get(provider, {}) - .get("credentials", []) + prov_stats_for_menu = ( + self.cached_stats.get("providers", {}).get(provider, {}) if self.cached_stats - else [] + else {} ) + credentials = get_credentials_list(prov_stats_for_menu) # Sort credentials naturally credentials = sorted(credentials, key=natural_sort_key) for idx, cred in enumerate(credentials, 1): @@ -945,6 +1121,14 @@ def show_provider_detail_screen(self, provider: str): f" F{idx}. Force refresh [{idx}] only ({email})" ) + # DEBUG: Add fake window for testing multi-window display + if has_quota_groups: + self.console.print() + self.console.print(" [dim]DEBUG:[/dim]") + self.console.print( + " W. [dim]Add fake 'daily' window (test multi-window)[/dim]" + ) + self.console.print() self.console.print(" B. Back to summary") self.console.print() @@ -957,6 +1141,13 @@ def show_provider_detail_screen(self, provider: str): elif choice == "G": # Toggle view mode self.view_mode = "global" if self.view_mode == "current" else "current" + elif choice == "T" and has_quota_groups: + # Toggle show models + new_state = self.config.toggle_show_models(provider) + status_str = "enabled" if new_state else "disabled" + self.console.print( + f"[dim]Model details {status_str} for {provider}[/dim]" + ) elif choice == "R": with self.console.status( f"[bold]Reloading {provider} stats...", spinner="dots" @@ -987,15 +1178,21 @@ def show_provider_detail_screen(self, provider: str): for err in rr["errors"]: self.console.print(f"[red] Error: {err}[/red]") Prompt.ask("Press Enter to continue", default="") + elif choice == "W" and has_quota_groups: + # DEBUG: Inject fake "daily" window for testing multi-window display + self._inject_fake_daily_window(provider) + self.console.print( + "[dim]Injected fake 'daily' window into cached stats[/dim]" + ) + Prompt.ask("Press Enter to continue", default="") elif choice.startswith("F") and choice[1:].isdigit() and has_quota_groups: idx = int(choice[1:]) - credentials = ( - self.cached_stats.get("providers", {}) - .get(provider, {}) - .get("credentials", []) + prov_stats_for_refresh = ( + self.cached_stats.get("providers", {}).get(provider, {}) if self.cached_stats - else [] + else {} ) + credentials = get_credentials_list(prov_stats_for_refresh) # Sort credentials naturally to match display order credentials = sorted(credentials, key=natural_sort_key) if 1 <= idx <= len(credentials): @@ -1030,44 +1227,46 @@ def _render_credential_panel(self, idx: int, cred: Dict[str, Any], provider: str tier = cred.get("tier", "") status = cred.get("status", "unknown") - # Check for active cooldowns - key_cooldown = cred.get("key_cooldown_remaining") - model_cooldowns = cred.get("model_cooldowns", {}) - has_cooldown = key_cooldown or model_cooldowns - - # Status indicator - if status == "exhausted": - status_icon = "[red]⛔ Exhausted[/red]" - elif status == "cooldown" or has_cooldown: - if key_cooldown: - status_icon = f"[yellow]âš ī¸ Cooldown ({format_cooldown(int(key_cooldown))})[/yellow]" + # Check for active cooldowns (new format: cooldowns dict) + cooldowns = cred.get("cooldowns", {}) + has_cooldown = bool(cooldowns) + + # Check for global cooldown + global_cooldown = cooldowns.get("_global_", {}) + global_cooldown_remaining = ( + global_cooldown.get("remaining_seconds", 0) if global_cooldown else 0 + ) + + # Status indicator using centralized config + status_config = STATUS_DISPLAY.get(status, STATUS_DISPLAY["active"]) + icon, label, color = status_config + + if status == "cooldown" or (has_cooldown and global_cooldown_remaining > 0): + if global_cooldown_remaining > 0: + status_icon = f"[{color}]{icon} {label} ({format_cooldown(int(global_cooldown_remaining))})[/{color}]" else: - status_icon = "[yellow]âš ī¸ Cooldown[/yellow]" + status_icon = f"[{color}]{icon} {label}[/{color}]" else: - status_icon = "[green]✅ Active[/green]" + status_icon = f"[{color}]{icon} {label}[/{color}]" # Header line display_name = email if email else identifier tier_str = f" ({tier})" if tier else "" header = f"[{idx}] {display_name}{tier_str} {status_icon}" - # Use global stats if in global mode - if self.view_mode == "global": - stats_source = cred.get("global", cred) - else: - stats_source = cred - - # Stats line - last_used = format_time_ago(cred.get("last_used_ts")) # Always from current - requests = stats_source.get("requests", 0) - tokens = stats_source.get("tokens", {}) + # Get stats using helper function + cred_stats = get_credential_stats(cred, self.view_mode) + last_used = format_time_ago(cred_stats.get("last_used_ts")) + requests = cred_stats.get("requests", 0) + tokens = cred_stats.get("tokens", {}) input_total = tokens.get("input_cached", 0) + tokens.get("input_uncached", 0) output = tokens.get("output", 0) - cost = format_cost(stats_source.get("approx_cost")) + cache_pct = tokens.get("input_cache_pct", 0) + cost = format_cost(cred_stats.get("approx_cost")) stats_line = ( f"Last used: {last_used} | Requests: {requests} | " - f"Tokens: {format_tokens(input_total)}/{format_tokens(output)}" + f"Tokens: {format_tokens(input_total)}/{format_tokens(output)} ({cache_pct}%)" ) if cost != "-": stats_line += f" | Cost: {cost}" @@ -1077,136 +1276,188 @@ def _render_credential_panel(self, idx: int, cred: Dict[str, Any], provider: str f"[dim]{stats_line}[/dim]", ] - # Model groups (for providers with quota tracking) - model_groups = cred.get("model_groups", {}) - - # Show cooldowns grouped by quota group (if model_groups exist) - if model_cooldowns: - if model_groups: - # Group cooldowns by quota group - group_cooldowns: Dict[ - str, int - ] = {} # group_name -> max_remaining_seconds - ungrouped_cooldowns: List[Tuple[str, int]] = [] - - for model_name, cooldown_info in model_cooldowns.items(): - remaining = cooldown_info.get("remaining_seconds", 0) - if remaining <= 0: - continue - - # Find which group this model belongs to - clean_model = model_name.split("/")[-1] - found_group = None - for group_name, group_info in model_groups.items(): - group_models = group_info.get("models", []) - if clean_model in group_models: - found_group = group_name - break - - if found_group: - group_cooldowns[found_group] = max( - group_cooldowns.get(found_group, 0), remaining + # Group usage (for providers with quota tracking) + group_usage = cred.get("group_usage", {}) + + # Show cooldowns + if cooldowns: + active_cooldowns = [] + for key, cooldown_info in cooldowns.items(): + if key == "_global_": + continue # Already shown in status + remaining = cooldown_info.get("remaining_seconds", 0) + if remaining > 0: + reason = cooldown_info.get("reason", "") + source = cooldown_info.get("source", "") + active_cooldowns.append((key, remaining, reason, source)) + + if active_cooldowns: + content_lines.append("") + content_lines.append("[yellow]Active Cooldowns:[/yellow]") + for key, remaining, reason, source in active_cooldowns: + source_str = f" ({source})" if source else "" + content_lines.append( + f" [yellow]:stopwatch: {key}: {format_cooldown(int(remaining))}{source_str}[/yellow]" + ) + + # Display group usage with per-window breakdown + # Note: group_usage is pre-sorted by limit (lowest first) from the API + if group_usage: + content_lines.append("") + content_lines.append("[bold]Quota Groups:[/bold]") + + for group_name, group_stats in group_usage.items(): + windows = group_stats.get("windows", {}) + if not windows: + continue + + # Get per-group status info + fc_exhausted = group_stats.get("fair_cycle_exhausted", False) + fc_reason = group_stats.get("fair_cycle_reason") + group_cooldown_remaining = group_stats.get("cooldown_remaining") + group_cooldown_source = group_stats.get("cooldown_source") + custom_cap = group_stats.get("custom_cap") + + for window_name, window_stats in windows.items(): + request_count = window_stats.get("request_count", 0) + limit = window_stats.get("limit") + remaining = window_stats.get("remaining") + reset_at = window_stats.get("reset_at") + max_recorded = window_stats.get("max_recorded_requests") + max_recorded_at = window_stats.get("max_recorded_at") + + # Calculate remaining percentage + if limit is not None and limit > 0: + remaining_val = ( + remaining + if remaining is not None + else max(0, limit - request_count) ) + remaining_pct = round(remaining_val / limit * 100, 1) + is_exhausted = remaining_val <= 0 else: - ungrouped_cooldowns.append((model_name, remaining)) - - if group_cooldowns or ungrouped_cooldowns: - content_lines.append("") - content_lines.append("[yellow]Active Cooldowns:[/yellow]") + remaining_pct = None + remaining_val = None + is_exhausted = False + + # Format reset time (only show if there's actual usage or cooldown) + reset_time_str = "" + if reset_at and (request_count > 0 or group_cooldown_remaining): + try: + reset_dt = datetime.fromtimestamp(reset_at) + reset_time_str = reset_dt.strftime("%b %d %H:%M") + except (ValueError, OSError): + reset_time_str = "" + + # Format max recorded info + max_info = "" + if max_recorded is not None: + max_info = f" Max: {max_recorded}" + if max_recorded_at: + try: + max_dt = datetime.fromtimestamp(max_recorded_at) + max_info += f" @ {max_dt.strftime('%b %d')}" + except (ValueError, OSError): + pass + + # Build display line + bar = create_progress_bar(remaining_pct) + + # Determine color (account for fair cycle) + if is_exhausted: + color = "red" + elif fc_exhausted: + color = "yellow" + elif remaining_pct is not None and remaining_pct < 20: + color = "yellow" + else: + color = "green" - # Show grouped cooldowns - for group_name in sorted(group_cooldowns.keys()): - remaining = group_cooldowns[group_name] - content_lines.append( - f" [yellow]âąī¸ {group_name}: {format_cooldown(remaining)}[/yellow]" + # Format group name + if len(windows) > 1: + display_name = f"{group_name} ({window_name})" + else: + display_name = group_name + + # Format usage string with custom cap if applicable + if custom_cap: + cap_remaining = custom_cap.get("remaining", 0) + cap_limit = custom_cap.get("limit", 0) + api_limit = limit if limit else cap_limit + usage_str = f"{cap_remaining}/{cap_limit}({api_limit})" + # Recalculate percentage based on custom cap + if cap_limit > 0: + remaining_pct = round(cap_remaining / cap_limit * 100, 1) + bar = create_progress_bar(remaining_pct) + pct_str = ( + f"{remaining_pct}%" if remaining_pct is not None else "" ) + elif limit is not None: + usage_str = f"{remaining_val}/{limit}" + pct_str = f"{remaining_pct}%" + else: + usage_str = f"{request_count} req" + pct_str = "" - # Show ungrouped (shouldn't happen often) - for model_name, remaining in ungrouped_cooldowns: - short_model = model_name.split("/")[-1][:35] - content_lines.append( - f" [yellow]âąī¸ {short_model}: {format_cooldown(remaining)}[/yellow]" + line = f" [{color}]{display_name:<{DETAIL_GROUP_NAME_WIDTH}} {usage_str:<{DETAIL_USAGE_WIDTH}} {pct_str:>{DETAIL_PCT_WIDTH}} {bar}[/{color}]" + + # Add reset time if applicable + if reset_time_str: + line += f" Resets: {reset_time_str}" + + # Add indicators + indicators = [] + + # Group cooldown indicator (if not already showing reset time) + if group_cooldown_remaining and not reset_time_str: + indicators.append( + f"[yellow]{INDICATOR_ICONS['cooldown']} {format_cooldown(int(group_cooldown_remaining))}[/yellow]" ) - else: - # No model groups - show per-model cooldowns - content_lines.append("") - content_lines.append("[yellow]Active Cooldowns:[/yellow]") - for model_name, cooldown_info in model_cooldowns.items(): - remaining = cooldown_info.get("remaining_seconds", 0) - if remaining > 0: - short_model = model_name.split("/")[-1][:35] - content_lines.append( - f" [yellow]âąī¸ {short_model}: {format_cooldown(int(remaining))}[/yellow]" + + # Fair cycle indicator + if fc_exhausted: + indicators.append( + f"[yellow]{INDICATOR_ICONS['fair_cycle']} FC[/yellow]" ) - # Display model groups with quota info - if model_groups: - content_lines.append("") - for group_name, group_stats in model_groups.items(): - remaining_pct = group_stats.get("remaining_pct") - requests_used = group_stats.get("requests_used", 0) - requests_max = group_stats.get("requests_max") - requests_remaining = group_stats.get("requests_remaining") - is_exhausted = group_stats.get("is_exhausted", False) - reset_time = format_reset_time(group_stats.get("reset_time_iso")) - confidence = group_stats.get("confidence", "low") - - # Format display - use requests_remaining/max format - if requests_remaining is None and requests_max: - requests_remaining = max(0, requests_max - requests_used) - display = group_stats.get( - "display", f"{requests_remaining or 0}/{requests_max or '?'}" - ) - bar = create_progress_bar(remaining_pct) + # Custom cap indicator (only if not on cooldown, just to show cap exists) + if custom_cap and not group_cooldown_remaining and not fc_exhausted: + indicators.append(f"[dim]{INDICATOR_ICONS['custom_cap']}[/dim]") - # Build status text - always show reset time if available - has_reset_time = reset_time and reset_time != "-" + if indicators: + line += " " + " ".join(indicators) - # Color based on status - if is_exhausted: - color = "red" - if has_reset_time: - status_text = f"⛔ Resets: {reset_time}" - else: - status_text = "⛔ EXHAUSTED" - elif remaining_pct is not None and remaining_pct < 20: - color = "yellow" - if has_reset_time: - status_text = f"âš ī¸ Resets: {reset_time}" - else: - status_text = "âš ī¸ LOW" - else: - color = "green" - if has_reset_time: - status_text = f"Resets: {reset_time}" - else: - status_text = "" # Hide if unused/no reset time + # Add max info at the end + if max_info: + line += f" [dim]{max_info}[/dim]" - # Confidence indicator - conf_indicator = "" - if confidence == "low": - conf_indicator = " [dim](~)[/dim]" - elif confidence == "medium": - conf_indicator = " [dim](?)[/dim]" + content_lines.append(line) - pct_str = f"{remaining_pct}%" if remaining_pct is not None else "?%" + # Model usage (show if no group usage, or if toggle enabled via config) + model_usage = cred.get("model_usage", {}) + # Check config for show_models setting, default to showing only if no group_usage + show_models = self.config.get_show_models(provider) if group_usage else True + + if show_models and model_usage: + content_lines.append("") + content_lines.append("[dim]Models used:[/dim]") + for model_name, model_stats in model_usage.items(): + totals = model_stats.get("totals", {}) + req_count = totals.get("request_count", 0) + if req_count == 0: + continue # Skip models with no usage + + prompt = totals.get("prompt_tokens", 0) + cache_read = totals.get("prompt_tokens_cache_read", 0) + output_tokens = totals.get("output_tokens", 0) + model_cost = format_cost(totals.get("approx_cost")) + + total_input = prompt + cache_read + short_name = model_name.split("/")[-1][:30] content_lines.append( - f" [{color}]{group_name:<18} {display:<10} {pct_str:>4} {bar}[/{color}] {status_text}{conf_indicator}" + f" {short_name}: {req_count} req | {format_tokens(total_input)}/{format_tokens(output_tokens)} tokens" + + (f" | {model_cost}" if model_cost != "-" else "") ) - else: - # For providers without quota groups, show model breakdown if available - models = cred.get("models", {}) - if models: - content_lines.append("") - content_lines.append(" [dim]Models used:[/dim]") - for model_name, model_stats in models.items(): - req_count = model_stats.get("success_count", 0) - model_cost = format_cost(model_stats.get("approx_cost")) - # Shorten model name for display - short_name = model_name.split("/")[-1][:30] - content_lines.append( - f" {short_name}: {req_count} requests, {model_cost}" - ) self.console.print( Panel( @@ -1218,12 +1469,86 @@ def _render_credential_panel(self, idx: int, cred: Dict[str, Any], provider: str ) ) + def _inject_fake_daily_window(self, provider: str) -> None: + """ + DEBUG: Inject a fake 'daily' window into cached stats for testing. + + This modifies cached_stats in-place to add a second window to each + quota group, simulating multi-window display without needing real + multi-window data from the API. + """ + if not self.cached_stats: + return + + prov_stats = self.cached_stats.get("providers", {}).get(provider, {}) + if not prov_stats: + return + + # Inject into quota_groups (global view) + quota_groups = prov_stats.get("quota_groups", {}) + for group_name, group_stats in quota_groups.items(): + windows = group_stats.get("windows", {}) + if "daily" not in windows and windows: + # Copy the first window and modify it + first_window = next(iter(windows.values())) + daily_window = { + "total_used": int(first_window.get("total_used", 0) * 0.3), + "total_remaining": int(first_window.get("total_max", 100) * 0.7), + "total_max": first_window.get("total_max", 100), + "remaining_pct": 70.0, + "tier_availability": first_window.get("tier_availability", {}), + } + windows["daily"] = daily_window + + # Inject into credential group_usage (detail view) + credentials = prov_stats.get("credentials", {}) + if isinstance(credentials, dict): + cred_list = credentials.values() + else: + cred_list = credentials + + for cred in cred_list: + group_usage = cred.get("group_usage", {}) + for group_name, group_stats in group_usage.items(): + windows = group_stats.get("windows", {}) + if "daily" not in windows and windows: + # Copy the first window and create a daily version + first_window = next(iter(windows.values())) + limit = first_window.get("limit", 100) + daily_window = { + "request_count": int( + first_window.get("request_count", 0) * 0.3 + ), + "success_count": int( + first_window.get("success_count", 0) * 0.3 + ), + "failure_count": 0, + "prompt_tokens": 0, + "completion_tokens": 0, + "thinking_tokens": 0, + "output_tokens": 0, + "prompt_tokens_cache_read": 0, + "prompt_tokens_cache_write": 0, + "total_tokens": 0, + "limit": limit, + "remaining": int(limit * 0.7), + "max_recorded_requests": None, + "max_recorded_at": None, + "reset_at": time.time() + 3600, # 1 hour from now + "approx_cost": 0.0, + "first_used_at": None, + "last_used_at": None, + } + windows["daily"] = daily_window + def show_switch_remote_screen(self): """Display remote selection screen.""" clear_screen() self.console.print("━" * 78) - self.console.print("[bold cyan]🔄 Switch Remote[/bold cyan]") + self.console.print( + "[bold cyan]:arrows_counterclockwise: Switch Remote[/bold cyan]" + ) self.console.print("━" * 78) self.console.print() @@ -1258,9 +1583,9 @@ def show_switch_remote_screen(self): current_marker = " (current)" if is_current else "" if is_online: - status_icon = "[green]✅ Online[/green]" + status_icon = "[green]:white_check_mark: Online[/green]" else: - status_icon = f"[red]âš ī¸ {status_msg}[/red]" + status_icon = f"[red]:warning: {status_msg}[/red]" self.console.print( f" {idx}. {name:<20} {connection_display:<30} {status_icon}{current_marker}" @@ -1330,7 +1655,7 @@ def show_manage_remotes_screen(self): clear_screen() self.console.print("━" * 78) - self.console.print("[bold cyan]âš™ī¸ Manage Remotes[/bold cyan]") + self.console.print("[bold cyan]:gear: Manage Remotes[/bold cyan]") self.console.print("━" * 78) self.console.print() diff --git a/src/proxy_app/quota_viewer_config.py b/src/proxy_app/quota_viewer_config.py index 7b2f573f..10b0d506 100644 --- a/src/proxy_app/quota_viewer_config.py +++ b/src/proxy_app/quota_viewer_config.py @@ -298,3 +298,48 @@ def get_api_key_from_env(self) -> Optional[str]: except IOError: pass return None + + def get_show_models(self, provider: str) -> bool: + """ + Get whether to show model breakdown for a provider. + + Args: + provider: Provider name + + Returns: + True if models should be shown, False otherwise + """ + provider_settings = self.config.get("provider_settings", {}) + return provider_settings.get(provider, {}).get("show_models", False) + + def set_show_models(self, provider: str, show: bool) -> bool: + """ + Set whether to show model breakdown for a provider. + + Args: + provider: Provider name + show: Whether to show models + + Returns: + True on success + """ + if "provider_settings" not in self.config: + self.config["provider_settings"] = {} + if provider not in self.config["provider_settings"]: + self.config["provider_settings"][provider] = {} + self.config["provider_settings"][provider]["show_models"] = show + return self._save() + + def toggle_show_models(self, provider: str) -> bool: + """ + Toggle the show_models setting for a provider. + + Args: + provider: Provider name + + Returns: + The new value (True if now showing, False if now hidden) + """ + current = self.get_show_models(provider) + self.set_show_models(provider, not current) + return not current diff --git a/src/proxy_app/settings_tool.py b/src/proxy_app/settings_tool.py index 689839bb..57b7eb3b 100644 --- a/src/proxy_app/settings_tool.py +++ b/src/proxy_app/settings_tool.py @@ -673,7 +673,7 @@ def _format_item( def _get_pending_status_text(self) -> str: """Get formatted pending changes status text for main menu.""" if not self.settings.has_pending(): - return "[dim]â„šī¸ No pending changes[/dim]" + return "[dim]:information_source: No pending changes[/dim]" counts = self.settings.get_pending_counts() parts = [] @@ -690,7 +690,7 @@ def _get_pending_status_text(self) -> str: f"[red]{counts['remove']} removal{'s' if counts['remove'] > 1 else ''}[/red]" ) - return f"[bold]â„šī¸ Pending changes: {', '.join(parts)}[/bold]" + return f"[bold]:information_source: Pending changes: {', '.join(parts)}[/bold]" self.running = True def get_available_providers(self) -> List[str]: @@ -739,21 +739,21 @@ def show_main_menu(self): self.console.print( Panel.fit( - "[bold cyan]🔧 Advanced Settings Configuration[/bold cyan]", + "[bold cyan]:wrench: Advanced Settings Configuration[/bold cyan]", border_style="cyan", ) ) self.console.print() - self.console.print("[bold]âš™ī¸ Configuration Categories[/bold]") + self.console.print("[bold]:gear: Configuration Categories[/bold]") self.console.print() - self.console.print(" 1. 🌐 Custom Provider API Bases") + self.console.print(" 1. :globe_with_meridians: Custom Provider API Bases") self.console.print(" 2. đŸ“Ļ Provider Model Definitions") self.console.print(" 3. ⚡ Concurrency Limits") - self.console.print(" 4. 🔄 Rotation Modes") + self.console.print(" 4. :arrows_counterclockwise: Rotation Modes") self.console.print(" 5. đŸ”Ŧ Provider-Specific Settings") - self.console.print(" 6. đŸŽ¯ Model Filters (Ignore/Whitelist)") - self.console.print(" 7. 💾 Save & Exit") + self.console.print(" 6. :dart: Model Filters (Ignore/Whitelist)") + self.console.print(" 7. :floppy_disk: Save & Exit") self.console.print(" 8. đŸšĢ Exit Without Saving") self.console.print() @@ -796,13 +796,13 @@ def manage_custom_providers(self): self.console.print( Panel.fit( - "[bold cyan]🌐 Custom Provider API Bases[/bold cyan]", + "[bold cyan]:globe_with_meridians: Custom Provider API Bases[/bold cyan]", border_style="cyan", ) ) self.console.print() - self.console.print("[bold]📋 Configured Custom Providers[/bold]") + self.console.print("[bold]:clipboard: Configured Custom Providers[/bold]") self.console.print("━" * 70) # Build combined view with pending changes @@ -853,11 +853,11 @@ def manage_custom_providers(self): self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]âš™ī¸ Actions[/bold]") + self.console.print("[bold]:gear: Actions[/bold]") self.console.print() self.console.print(" 1. ➕ Add New Custom Provider") self.console.print(" 2. âœī¸ Edit Existing Provider") - self.console.print(" 3. đŸ—‘ī¸ Remove Provider") + self.console.print(" 3. :wastebasket: Remove Provider") self.console.print(" 4. â†Šī¸ Back to Settings Menu") self.console.print() @@ -875,7 +875,7 @@ def manage_custom_providers(self): if api_base: self.provider_mgr.add_provider(name, api_base) self.console.print( - f"\n[green]✅ Custom provider '{name}' staged![/green]" + f"\n[green]:white_check_mark: Custom provider '{name}' staged![/green]" ) self.console.print( f" To use: set {name.upper()}_API_KEY in credentials" @@ -915,7 +915,7 @@ def manage_custom_providers(self): if new_base and new_base != current_base: self.provider_mgr.edit_provider(name, new_base) self.console.print( - f"\n[green]✅ Custom provider '{name}' updated![/green]" + f"\n[green]:white_check_mark: Custom provider '{name}' updated![/green]" ) else: self.console.print("\n[yellow]No changes made[/yellow]") @@ -962,12 +962,12 @@ def manage_custom_providers(self): key = f"{name.upper()}_API_BASE" del self.settings.pending_changes[key] self.console.print( - f"\n[green]✅ Pending addition of '{name}' cancelled![/green]" + f"\n[green]:white_check_mark: Pending addition of '{name}' cancelled![/green]" ) else: self.provider_mgr.remove_provider(name) self.console.print( - f"\n[green]✅ Provider '{name}' marked for removal![/green]" + f"\n[green]:white_check_mark: Provider '{name}' marked for removal![/green]" ) input("\nPress Enter to continue...") @@ -990,7 +990,7 @@ def manage_model_definitions(self): ) self.console.print() - self.console.print("[bold]📋 Configured Provider Models[/bold]") + self.console.print("[bold]:clipboard: Configured Provider Models[/bold]") self.console.print("━" * 70) # Build combined view with pending changes @@ -1063,12 +1063,12 @@ def manage_model_definitions(self): self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]âš™ī¸ Actions[/bold]") + self.console.print("[bold]:gear: Actions[/bold]") self.console.print() self.console.print(" 1. ➕ Add Models for Provider") self.console.print(" 2. âœī¸ Edit Provider Models") self.console.print(" 3. đŸ‘ī¸ View Provider Models") - self.console.print(" 4. đŸ—‘ī¸ Remove Provider Models") + self.console.print(" 4. :wastebasket: Remove Provider Models") self.console.print(" 5. â†Šī¸ Back to Settings Menu") self.console.print() @@ -1141,12 +1141,12 @@ def manage_model_definitions(self): key = f"{provider.upper()}{suffix}" del self.settings.pending_changes[key] self.console.print( - f"\n[green]✅ Pending models for '{provider}' cancelled![/green]" + f"\n[green]:white_check_mark: Pending models for '{provider}' cancelled![/green]" ) else: self.model_mgr.remove_models(provider) self.console.print( - f"\n[green]✅ Model definitions marked for removal for '{provider}'![/green]" + f"\n[green]:white_check_mark: Model definitions marked for removal for '{provider}'![/green]" ) input("\nPress Enter to continue...") elif choice == "5": @@ -1244,7 +1244,7 @@ def add_model_definitions(self): if models: self.model_mgr.set_models(provider, models) self.console.print( - f"\n[green]✅ Model definitions saved for '{provider}'![/green]" + f"\n[green]:white_check_mark: Model definitions saved for '{provider}'![/green]" ) else: self.console.print("\n[yellow]No models added[/yellow]") @@ -1343,7 +1343,9 @@ def edit_model_definitions(self, providers: List[str]): if current_models: self.model_mgr.set_models(provider, current_models) - self.console.print(f"\n[green]✅ Models updated for '{provider}'![/green]") + self.console.print( + f"\n[green]:white_check_mark: Models updated for '{provider}'![/green]" + ) else: self.console.print( "\n[yellow]No models left - removing definition[/yellow]" @@ -1433,7 +1435,7 @@ def manage_provider_settings(self): self.console.print() self.console.print( - "[bold]📋 Available Providers with Custom Settings[/bold]" + "[bold]:clipboard: Available Providers with Custom Settings[/bold]" ) self.console.print("━" * 70) @@ -1450,7 +1452,7 @@ def manage_provider_settings(self): self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]âš™ī¸ Select Provider to Configure[/bold]") + self.console.print("[bold]:gear: Select Provider to Configure[/bold]") self.console.print() for idx, provider in enumerate(available_providers, 1): @@ -1492,7 +1494,7 @@ def _manage_single_provider_settings(self, provider: str): ) self.console.print() - self.console.print("[bold]📋 Current Settings[/bold]") + self.console.print("[bold]:clipboard: Current Settings[/bold]") self.console.print("━" * 70) # Display all settings with current values and pending changes @@ -1587,11 +1589,13 @@ def _manage_single_provider_settings(self, provider: str): "[dim]* = modified from default, + = pending add, ~ = pending edit, - = pending reset[/dim]" ) self.console.print() - self.console.print("[bold]âš™ī¸ Actions[/bold]") + self.console.print("[bold]:gear: Actions[/bold]") self.console.print() self.console.print(" E. âœī¸ Edit a Setting") - self.console.print(" R. 🔄 Reset Setting to Default") - self.console.print(" A. 🔄 Reset All to Defaults") + self.console.print( + " R. :arrows_counterclockwise: Reset Setting to Default" + ) + self.console.print(" A. :arrows_counterclockwise: Reset All to Defaults") self.console.print(" B. â†Šī¸ Back to Provider Selection") self.console.print() @@ -1641,18 +1645,24 @@ def _edit_provider_setting( new_value = Confirm.ask("\nEnable this setting?", default=current) self.provider_settings_mgr.set_value(key, new_value, definition) status = "enabled" if new_value else "disabled" - self.console.print(f"\n[green]✅ {short_key} {status}![/green]") + self.console.print( + f"\n[green]:white_check_mark: {short_key} {status}![/green]" + ) elif setting_type == "int": new_value = IntPrompt.ask("\nNew value", default=current) self.provider_settings_mgr.set_value(key, new_value, definition) - self.console.print(f"\n[green]✅ {short_key} set to {new_value}![/green]") + self.console.print( + f"\n[green]:white_check_mark: {short_key} set to {new_value}![/green]" + ) else: new_value = Prompt.ask( "\nNew value", default=str(current) if current else "" ).strip() if new_value: self.provider_settings_mgr.set_value(key, new_value, definition) - self.console.print(f"\n[green]✅ {short_key} updated![/green]") + self.console.print( + f"\n[green]:white_check_mark: {short_key} updated![/green]" + ) else: self.console.print("\n[yellow]No changes made[/yellow]") @@ -1677,7 +1687,9 @@ def _reset_provider_setting( if Confirm.ask(f"\nReset {short_key} to default ({default})?"): self.provider_settings_mgr.reset_to_default(key) - self.console.print(f"\n[green]✅ {short_key} reset to default![/green]") + self.console.print( + f"\n[green]:white_check_mark: {short_key} reset to default![/green]" + ) else: self.console.print("\n[yellow]No changes made[/yellow]") @@ -1693,7 +1705,7 @@ def _reset_all_provider_settings(self, provider: str, settings_list: List[str]): for key in settings_list: self.provider_settings_mgr.reset_to_default(key) self.console.print( - f"\n[green]✅ All {display_name} settings reset to defaults![/green]" + f"\n[green]:white_check_mark: All {display_name} settings reset to defaults![/green]" ) else: self.console.print("\n[yellow]No changes made[/yellow]") @@ -1711,13 +1723,13 @@ def manage_rotation_modes(self): self.console.print( Panel.fit( - "[bold cyan]🔄 Credential Rotation Mode Configuration[/bold cyan]", + "[bold cyan]:arrows_counterclockwise: Credential Rotation Mode Configuration[/bold cyan]", border_style="cyan", ) ) self.console.print() - self.console.print("[bold]📋 Rotation Modes Explained[/bold]") + self.console.print("[bold]:clipboard: Rotation Modes Explained[/bold]") self.console.print("━" * 70) self.console.print( " [cyan]balanced[/cyan] - Rotate credentials evenly across requests (default)" @@ -1726,7 +1738,9 @@ def manage_rotation_modes(self): " [cyan]sequential[/cyan] - Use one credential until exhausted (429), then switch" ) self.console.print() - self.console.print("[bold]📋 Current Rotation Mode Settings[/bold]") + self.console.print( + "[bold]:clipboard: Current Rotation Mode Settings[/bold]" + ) self.console.print("━" * 70) # Build combined view with pending changes @@ -1821,10 +1835,10 @@ def manage_rotation_modes(self): "[dim]* = custom setting (differs from provider default)[/dim]" ) self.console.print() - self.console.print("[bold]âš™ī¸ Actions[/bold]") + self.console.print("[bold]:gear: Actions[/bold]") self.console.print() self.console.print(" 1. ➕ Set Rotation Mode for Provider") - self.console.print(" 2. đŸ—‘ī¸ Reset to Provider Default") + self.console.print(" 2. :wastebasket: Reset to Provider Default") self.console.print(" 3. ⚡ Configure Priority Concurrency Multipliers") self.console.print(" 4. â†Šī¸ Back to Settings Menu") @@ -1888,7 +1902,7 @@ def manage_rotation_modes(self): self.rotation_mgr.set_mode(provider, new_mode) self.console.print( - f"\n[green]✅ Rotation mode for '{provider}' staged as {new_mode}![/green]" + f"\n[green]:white_check_mark: Rotation mode for '{provider}' staged as {new_mode}![/green]" ) input("\nPress Enter to continue...") @@ -1935,12 +1949,12 @@ def manage_rotation_modes(self): key = f"{prefix}{provider.upper()}" del self.settings.pending_changes[key] self.console.print( - f"\n[green]✅ Pending mode for '{provider}' cancelled![/green]" + f"\n[green]:white_check_mark: Pending mode for '{provider}' cancelled![/green]" ) else: self.rotation_mgr.remove_mode(provider) self.console.print( - f"\n[green]✅ Rotation mode for '{provider}' marked for reset to default ({default_mode})![/green]" + f"\n[green]:white_check_mark: Rotation mode for '{provider}' marked for reset to default ({default_mode})![/green]" ) input("\nPress Enter to continue...") @@ -1965,7 +1979,9 @@ def manage_priority_multipliers(self): ) self.console.print() - self.console.print("[bold]📋 Current Priority Multiplier Settings[/bold]") + self.console.print( + "[bold]:clipboard: Current Priority Multiplier Settings[/bold]" + ) self.console.print("━" * 70) # Show all providers with their priority multipliers @@ -2009,7 +2025,9 @@ def manage_priority_multipliers(self): self.console.print(" [dim]No priority multipliers configured[/dim]") self.console.print() - self.console.print("[bold]â„šī¸ About Priority Multipliers:[/bold]") + self.console.print( + "[bold]:information_source: About Priority Multipliers:[/bold]" + ) self.console.print( " Higher priority tiers (lower numbers) can have higher multipliers." ) @@ -2018,7 +2036,7 @@ def manage_priority_multipliers(self): self.console.print("━" * 70) self.console.print() self.console.print(" 1. âœī¸ Set Priority Multiplier") - self.console.print(" 2. 🔄 Reset to Provider Default") + self.console.print(" 2. :arrows_counterclockwise: Reset to Provider Default") self.console.print(" 3. â†Šī¸ Back") choice = Prompt.ask( @@ -2059,7 +2077,7 @@ def manage_priority_multipliers(self): provider, priority, multiplier ) self.console.print( - f"\n[green]✅ Priority {priority} multiplier for '{provider}' set to {multiplier}x[/green]" + f"\n[green]:white_check_mark: Priority {priority} multiplier for '{provider}' set to {multiplier}x[/green]" ) else: self.console.print( @@ -2101,7 +2119,7 @@ def manage_priority_multipliers(self): provider, priority ) self.console.print( - f"\n[green]✅ Reset priority {priority} for '{provider}' to default ({default}x)[/green]" + f"\n[green]:white_check_mark: Reset priority {priority} for '{provider}' to default ({default}x)[/green]" ) else: self.console.print( @@ -2125,7 +2143,7 @@ def manage_concurrency_limits(self): ) self.console.print() - self.console.print("[bold]📋 Current Concurrency Settings[/bold]") + self.console.print("[bold]:clipboard: Current Concurrency Settings[/bold]") self.console.print("━" * 70) # Build combined view with pending changes @@ -2185,11 +2203,11 @@ def manage_concurrency_limits(self): self.console.print() self.console.print("━" * 70) self.console.print() - self.console.print("[bold]âš™ī¸ Actions[/bold]") + self.console.print("[bold]:gear: Actions[/bold]") self.console.print() self.console.print(" 1. ➕ Add Concurrency Limit for Provider") self.console.print(" 2. âœī¸ Edit Existing Limit") - self.console.print(" 3. đŸ—‘ī¸ Remove Limit (reset to default)") + self.console.print(" 3. :wastebasket: Remove Limit (reset to default)") self.console.print(" 4. â†Šī¸ Back to Settings Menu") self.console.print() @@ -2236,11 +2254,11 @@ def manage_concurrency_limits(self): if 1 <= limit <= 100: self.concurrency_mgr.set_limit(provider, limit) self.console.print( - f"\n[green]✅ Concurrency limit staged for '{provider}': {limit} requests/key[/green]" + f"\n[green]:white_check_mark: Concurrency limit staged for '{provider}': {limit} requests/key[/green]" ) else: self.console.print( - "\n[red]❌ Limit must be between 1-100[/red]" + "\n[red]:x: Limit must be between 1-100[/red]" ) input("\nPress Enter to continue...") @@ -2278,7 +2296,7 @@ def manage_concurrency_limits(self): if new_limit != current_limit: self.concurrency_mgr.set_limit(provider, new_limit) self.console.print( - f"\n[green]✅ Concurrency limit updated for '{provider}': {new_limit} requests/key[/green]" + f"\n[green]:white_check_mark: Concurrency limit updated for '{provider}': {new_limit} requests/key[/green]" ) else: self.console.print("\n[yellow]No changes made[/yellow]") @@ -2330,12 +2348,12 @@ def manage_concurrency_limits(self): key = f"{prefix}{provider.upper()}" del self.settings.pending_changes[key] self.console.print( - f"\n[green]✅ Pending limit for '{provider}' cancelled![/green]" + f"\n[green]:white_check_mark: Pending limit for '{provider}' cancelled![/green]" ) else: self.concurrency_mgr.remove_limit(provider) self.console.print( - f"\n[green]✅ Limit marked for removal for '{provider}'[/green]" + f"\n[green]:white_check_mark: Limit marked for removal for '{provider}'[/green]" ) input("\nPress Enter to continue...") @@ -2346,7 +2364,7 @@ def _show_changes_summary(self): """Display categorized summary of all pending changes.""" self.console.print( Panel.fit( - "[bold cyan]📋 Pending Changes Summary[/bold cyan]", + "[bold cyan]:clipboard: Pending Changes Summary[/bold cyan]", border_style="cyan", ) ) @@ -2439,7 +2457,9 @@ def save_and_exit(self): if Confirm.ask("\n[bold yellow]Save all pending changes?[/bold yellow]"): self.settings.save() - self.console.print("\n[green]✅ All changes saved to .env![/green]") + self.console.print( + "\n[green]:white_check_mark: All changes saved to .env![/green]" + ) input("\nPress Enter to return to launcher...") else: self.console.print("\n[yellow]Changes not saved[/yellow]") diff --git a/src/rotator_library/README.md b/src/rotator_library/README.md index 22d2bf6e..657032ce 100644 --- a/src/rotator_library/README.md +++ b/src/rotator_library/README.md @@ -23,7 +23,7 @@ A robust, asynchronous, and thread-safe Python library for managing a pool of AP - **Credential Prioritization**: Automatic tier detection and priority-based credential selection (e.g., paid tier credentials used first for models that require them). - **Advanced Model Requirements**: Support for model-tier restrictions (e.g., Gemini 3 requires paid-tier credentials). - **Robust Streaming Support**: Includes a wrapper for streaming responses that reassembles fragmented JSON chunks. -- **Detailed Usage Tracking**: Tracks daily and global usage for each key, persisted to a JSON file. +- **Detailed Usage Tracking**: Tracks daily and global usage for each key, persisted per provider in `usage/usage_.json`. - **Automatic Daily Resets**: Automatically resets cooldowns and archives stats daily. - **Provider Agnostic**: Works with any provider supported by `litellm`. - **Extensible**: Easily add support for new providers through a simple plugin-based architecture. @@ -73,7 +73,7 @@ client = RotatingClient( api_keys=api_keys, oauth_credentials=oauth_credentials, max_retries=2, - usage_file_path="key_usage.json", + usage_file_path="usage.json", configure_logging=True, global_timeout=30, abort_on_callback_error=True, @@ -91,7 +91,7 @@ client = RotatingClient( - `api_keys` (`Optional[Dict[str, List[str]]]`): A dictionary mapping provider names (e.g., "openai", "anthropic") to a list of API keys. - `oauth_credentials` (`Optional[Dict[str, List[str]]]`): A dictionary mapping provider names (e.g., "gemini_cli", "qwen_code") to a list of file paths to OAuth credential JSON files. - `max_retries` (`int`, default: `2`): The number of times to retry a request with the *same key* if a transient server error (e.g., 500, 503) occurs. -- `usage_file_path` (`str`, default: `"key_usage.json"`): The path to the JSON file where usage statistics (tokens, cost, success counts) are persisted. +- `usage_file_path` (`str`, optional): Base path for usage persistence (defaults to `usage/` in the data directory). The client stores per-provider files under `usage/usage_.json` next to this path. - `configure_logging` (`bool`, default: `True`): If `True`, configures the library's logger to propagate logs to the root logger. Set to `False` if you want to handle logging configuration manually. - `global_timeout` (`int`, default: `30`): A hard time limit (in seconds) for the entire request lifecycle. If the request (including all retries) takes longer than this, it is aborted. - `abort_on_callback_error` (`bool`, default: `True`): If `True`, any exception raised by `pre_request_callback` will abort the request. If `False`, the error is logged and the request proceeds. diff --git a/src/rotator_library/background_refresher.py b/src/rotator_library/background_refresher.py index e3da1f76..acc66c89 100644 --- a/src/rotator_library/background_refresher.py +++ b/src/rotator_library/background_refresher.py @@ -234,9 +234,13 @@ async def _run_provider_background_job( # Run immediately on start if configured if run_on_start: try: - await provider.run_background_job( - self._client.usage_manager, credentials - ) + usage_manager = self._client.usage_managers.get(provider_name) + if usage_manager is None: + lib_logger.debug( + f"Skipping {provider_name} {job_name}: no UsageManager" + ) + return + await provider.run_background_job(usage_manager, credentials) lib_logger.debug(f"{provider_name} {job_name}: initial run complete") except Exception as e: lib_logger.error( @@ -247,9 +251,13 @@ async def _run_provider_background_job( while True: try: await asyncio.sleep(interval) - await provider.run_background_job( - self._client.usage_manager, credentials - ) + usage_manager = self._client.usage_managers.get(provider_name) + if usage_manager is None: + lib_logger.debug( + f"Skipping {provider_name} {job_name}: no UsageManager" + ) + return + await provider.run_background_job(usage_manager, credentials) lib_logger.debug(f"{provider_name} {job_name}: periodic run complete") except asyncio.CancelledError: lib_logger.debug(f"{provider_name} {job_name}: cancelled") @@ -259,6 +267,7 @@ async def _run_provider_background_job( async def _run(self): """The main loop for OAuth token refresh.""" + await self._client.initialize_usage_managers() # Initialize credentials (load persisted tiers) before starting await self._initialize_credentials() diff --git a/src/rotator_library/client.py b/src/rotator_library/client.py deleted file mode 100644 index fdd12d67..00000000 --- a/src/rotator_library/client.py +++ /dev/null @@ -1,3581 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-only -# Copyright (c) 2026 Mirrowel - -import asyncio -import fnmatch -import json -import re -import codecs -import time -import os -import random -import httpx -import litellm -from litellm.exceptions import APIConnectionError -from litellm.litellm_core_utils.token_counter import token_counter -import logging -from pathlib import Path -from typing import List, Dict, Any, AsyncGenerator, Optional, Union, Tuple - -lib_logger = logging.getLogger("rotator_library") -# Ensure the logger is configured to propagate to the root logger -# which is set up in main.py. This allows the main app to control -# log levels and handlers centrally. -lib_logger.propagate = False - -from .usage_manager import UsageManager -from .failure_logger import log_failure, configure_failure_logger -from .error_handler import ( - PreRequestCallbackError, - CredentialNeedsReauthError, - classify_error, - NoAvailableKeysError, - should_rotate_on_error, - should_retry_same_key, - RequestErrorAccumulator, - mask_credential, -) -from .provider_config import ProviderConfig -from .providers import PROVIDER_PLUGINS -from .providers.openai_compatible_provider import OpenAICompatibleProvider -from .request_sanitizer import sanitize_request_payload -from .cooldown_manager import CooldownManager -from .credential_manager import CredentialManager -from .background_refresher import BackgroundRefresher -from .model_definitions import ModelDefinitions -from .transaction_logger import TransactionLogger -from .utils.paths import get_default_root, get_logs_dir, get_oauth_dir, get_data_file -from .utils.suppress_litellm_warnings import suppress_litellm_serialization_warnings -from .config import ( - DEFAULT_MAX_RETRIES, - DEFAULT_GLOBAL_TIMEOUT, - DEFAULT_ROTATION_TOLERANCE, - DEFAULT_FAIR_CYCLE_DURATION, - DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, - DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, -) - - -class StreamedAPIError(Exception): - """Custom exception to signal an API error received over a stream.""" - - def __init__(self, message, data=None): - super().__init__(message) - self.data = data - - -class RotatingClient: - """ - A client that intelligently rotates and retries API keys using LiteLLM, - with support for both streaming and non-streaming responses. - """ - - def __init__( - self, - api_keys: Optional[Dict[str, List[str]]] = None, - oauth_credentials: Optional[Dict[str, List[str]]] = None, - max_retries: int = DEFAULT_MAX_RETRIES, - usage_file_path: Optional[Union[str, Path]] = None, - configure_logging: bool = True, - global_timeout: int = DEFAULT_GLOBAL_TIMEOUT, - abort_on_callback_error: bool = True, - litellm_provider_params: Optional[Dict[str, Any]] = None, - ignore_models: Optional[Dict[str, List[str]]] = None, - whitelist_models: Optional[Dict[str, List[str]]] = None, - enable_request_logging: bool = False, - max_concurrent_requests_per_key: Optional[Dict[str, int]] = None, - rotation_tolerance: float = DEFAULT_ROTATION_TOLERANCE, - data_dir: Optional[Union[str, Path]] = None, - ): - """ - Initialize the RotatingClient with intelligent credential rotation. - - Args: - api_keys: Dictionary mapping provider names to lists of API keys - oauth_credentials: Dictionary mapping provider names to OAuth credential paths - max_retries: Maximum number of retry attempts per credential - usage_file_path: Path to store usage statistics. If None, uses data_dir/key_usage.json - configure_logging: Whether to configure library logging - global_timeout: Global timeout for requests in seconds - abort_on_callback_error: Whether to abort on pre-request callback errors - litellm_provider_params: Provider-specific parameters for LiteLLM - ignore_models: Models to ignore/blacklist per provider - whitelist_models: Models to explicitly whitelist per provider - enable_request_logging: Whether to enable detailed request logging - max_concurrent_requests_per_key: Max concurrent requests per key by provider - rotation_tolerance: Tolerance for weighted random credential rotation. - - 0.0: Deterministic, least-used credential always selected - - 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max - - 5.0+: High randomness, more unpredictable selection patterns - data_dir: Root directory for all data files (logs, cache, oauth_creds, key_usage.json). - If None, auto-detects: EXE directory if frozen, else current working directory. - """ - # Resolve data_dir early - this becomes the root for all file operations - if data_dir is not None: - self.data_dir = Path(data_dir).resolve() - else: - self.data_dir = get_default_root() - - # Configure failure logger to use correct logs directory - configure_failure_logger(get_logs_dir(self.data_dir)) - - os.environ["LITELLM_LOG"] = "ERROR" - litellm.set_verbose = False - litellm.drop_params = True - - # Suppress harmless Pydantic serialization warnings from litellm - # See: https://github.com/BerriAI/litellm/issues/11759 - # TODO: Remove this workaround once litellm patches the issue - suppress_litellm_serialization_warnings() - - if configure_logging: - # When True, this allows logs from this library to be handled - # by the parent application's logging configuration. - lib_logger.propagate = True - # Remove any default handlers to prevent duplicate logging - if lib_logger.hasHandlers(): - lib_logger.handlers.clear() - lib_logger.addHandler(logging.NullHandler()) - else: - lib_logger.propagate = False - - api_keys = api_keys or {} - oauth_credentials = oauth_credentials or {} - - # Filter out providers with empty lists of credentials to ensure validity - api_keys = {provider: keys for provider, keys in api_keys.items() if keys} - oauth_credentials = { - provider: paths for provider, paths in oauth_credentials.items() if paths - } - - if not api_keys and not oauth_credentials: - lib_logger.warning( - "No provider credentials configured. The client will be unable to make any API requests." - ) - - self.api_keys = api_keys - # Use provided oauth_credentials directly if available (already discovered by main.py) - # Only call discover_and_prepare() if no credentials were passed - if oauth_credentials: - self.oauth_credentials = oauth_credentials - else: - self.credential_manager = CredentialManager( - os.environ, oauth_dir=get_oauth_dir(self.data_dir) - ) - self.oauth_credentials = self.credential_manager.discover_and_prepare() - self.background_refresher = BackgroundRefresher(self) - self.oauth_providers = set(self.oauth_credentials.keys()) - - all_credentials = {} - for provider, keys in api_keys.items(): - all_credentials.setdefault(provider, []).extend(keys) - for provider, paths in self.oauth_credentials.items(): - all_credentials.setdefault(provider, []).extend(paths) - self.all_credentials = all_credentials - - self.max_retries = max_retries - self.global_timeout = global_timeout - self.abort_on_callback_error = abort_on_callback_error - - # Initialize provider plugins early so they can be used for rotation mode detection - self._provider_plugins = PROVIDER_PLUGINS - self._provider_instances = {} - - # Build provider rotation modes map - # Each provider can specify its preferred rotation mode ("balanced" or "sequential") - provider_rotation_modes = {} - for provider in self.all_credentials.keys(): - provider_class = self._provider_plugins.get(provider) - if provider_class and hasattr(provider_class, "get_rotation_mode"): - # Use class method to get rotation mode (checks env var + class default) - mode = provider_class.get_rotation_mode(provider) - else: - # Fallback: check environment variable directly - env_key = f"ROTATION_MODE_{provider.upper()}" - mode = os.getenv(env_key, "balanced") - - provider_rotation_modes[provider] = mode - if mode != "balanced": - lib_logger.info(f"Provider '{provider}' using rotation mode: {mode}") - - # Build priority-based concurrency multiplier maps - # These are universal multipliers based on credential tier/priority - priority_multipliers: Dict[str, Dict[int, int]] = {} - priority_multipliers_by_mode: Dict[str, Dict[str, Dict[int, int]]] = {} - sequential_fallback_multipliers: Dict[str, int] = {} - - for provider in self.all_credentials.keys(): - provider_class = self._provider_plugins.get(provider) - - # Start with provider class defaults - if provider_class: - # Get default priority multipliers from provider class - if hasattr(provider_class, "default_priority_multipliers"): - default_multipliers = provider_class.default_priority_multipliers - if default_multipliers: - priority_multipliers[provider] = dict(default_multipliers) - - # Get sequential fallback from provider class - if hasattr(provider_class, "default_sequential_fallback_multiplier"): - fallback = provider_class.default_sequential_fallback_multiplier - if ( - fallback != DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER - ): # Only store if different from global default - sequential_fallback_multipliers[provider] = fallback - - # Override with environment variables - # Format: CONCURRENCY_MULTIPLIER__PRIORITY_= - # Format: CONCURRENCY_MULTIPLIER__PRIORITY__= - for key, value in os.environ.items(): - prefix = f"CONCURRENCY_MULTIPLIER_{provider.upper()}_PRIORITY_" - if key.startswith(prefix): - remainder = key[len(prefix) :] - try: - multiplier = int(value) - if multiplier < 1: - lib_logger.warning(f"Invalid {key}: {value}. Must be >= 1.") - continue - - # Check if mode-specific (e.g., _PRIORITY_1_SEQUENTIAL) - if "_" in remainder: - parts = remainder.rsplit("_", 1) - priority = int(parts[0]) - mode = parts[1].lower() - if mode in ("sequential", "balanced"): - # Mode-specific override - if provider not in priority_multipliers_by_mode: - priority_multipliers_by_mode[provider] = {} - if mode not in priority_multipliers_by_mode[provider]: - priority_multipliers_by_mode[provider][mode] = {} - priority_multipliers_by_mode[provider][mode][ - priority - ] = multiplier - lib_logger.info( - f"Provider '{provider}' priority {priority} ({mode} mode) multiplier: {multiplier}x" - ) - else: - # Assume it's part of the priority number (unlikely but handle gracefully) - lib_logger.warning(f"Unknown mode in {key}: {mode}") - else: - # Universal priority multiplier - priority = int(remainder) - if provider not in priority_multipliers: - priority_multipliers[provider] = {} - priority_multipliers[provider][priority] = multiplier - lib_logger.info( - f"Provider '{provider}' priority {priority} multiplier: {multiplier}x" - ) - except ValueError: - lib_logger.warning( - f"Invalid {key}: {value}. Could not parse priority or multiplier." - ) - - # Log configured multipliers - for provider, multipliers in priority_multipliers.items(): - if multipliers: - lib_logger.info( - f"Provider '{provider}' priority multipliers: {multipliers}" - ) - for provider, fallback in sequential_fallback_multipliers.items(): - lib_logger.info( - f"Provider '{provider}' sequential fallback multiplier: {fallback}x" - ) - - # Build fair cycle configuration - fair_cycle_enabled: Dict[str, bool] = {} - fair_cycle_tracking_mode: Dict[str, str] = {} - fair_cycle_cross_tier: Dict[str, bool] = {} - fair_cycle_duration: Dict[str, int] = {} - - for provider in self.all_credentials.keys(): - provider_class = self._provider_plugins.get(provider) - rotation_mode = provider_rotation_modes.get(provider, "balanced") - - # Fair cycle enabled - check env, then provider default, then derive from rotation mode - env_key = f"FAIR_CYCLE_{provider.upper()}" - env_val = os.getenv(env_key) - if env_val is not None: - fair_cycle_enabled[provider] = env_val.lower() in ("true", "1", "yes") - elif provider_class and hasattr( - provider_class, "default_fair_cycle_enabled" - ): - default_val = provider_class.default_fair_cycle_enabled - if default_val is not None: - fair_cycle_enabled[provider] = default_val - # None means use global default (enabled for all modes) - # Default: enabled for all rotation modes (not stored, handled in UsageManager) - - # Tracking mode - check env, then provider default - env_key = f"FAIR_CYCLE_TRACKING_MODE_{provider.upper()}" - env_val = os.getenv(env_key) - if env_val is not None and env_val.lower() in ("model_group", "credential"): - fair_cycle_tracking_mode[provider] = env_val.lower() - elif provider_class and hasattr( - provider_class, "default_fair_cycle_tracking_mode" - ): - fair_cycle_tracking_mode[provider] = ( - provider_class.default_fair_cycle_tracking_mode - ) - - # Cross-tier - check env, then provider default - env_key = f"FAIR_CYCLE_CROSS_TIER_{provider.upper()}" - env_val = os.getenv(env_key) - if env_val is not None: - fair_cycle_cross_tier[provider] = env_val.lower() in ( - "true", - "1", - "yes", - ) - elif provider_class and hasattr( - provider_class, "default_fair_cycle_cross_tier" - ): - if provider_class.default_fair_cycle_cross_tier: - fair_cycle_cross_tier[provider] = True - - # Duration - check provider-specific env, then provider default - env_key = f"FAIR_CYCLE_DURATION_{provider.upper()}" - env_val = os.getenv(env_key) - if env_val is not None: - try: - fair_cycle_duration[provider] = int(env_val) - except ValueError: - lib_logger.warning( - f"Invalid {env_key}: {env_val}. Must be integer." - ) - elif provider_class and hasattr( - provider_class, "default_fair_cycle_duration" - ): - duration = provider_class.default_fair_cycle_duration - if ( - duration != DEFAULT_FAIR_CYCLE_DURATION - ): # Only store if different from global default - fair_cycle_duration[provider] = duration - - # Build exhaustion cooldown threshold per provider - # Check global env first, then per-provider env, then provider class default - exhaustion_cooldown_threshold: Dict[str, int] = {} - global_threshold_str = os.getenv("EXHAUSTION_COOLDOWN_THRESHOLD") - global_threshold = DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD - if global_threshold_str: - try: - global_threshold = int(global_threshold_str) - except ValueError: - lib_logger.warning( - f"Invalid EXHAUSTION_COOLDOWN_THRESHOLD: {global_threshold_str}. Using default {DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD}." - ) - - for provider in self.all_credentials.keys(): - provider_class = self._provider_plugins.get(provider) - - # Check per-provider env var first - env_key = f"EXHAUSTION_COOLDOWN_THRESHOLD_{provider.upper()}" - env_val = os.getenv(env_key) - if env_val is not None: - try: - exhaustion_cooldown_threshold[provider] = int(env_val) - except ValueError: - lib_logger.warning( - f"Invalid {env_key}: {env_val}. Must be integer." - ) - elif provider_class and hasattr( - provider_class, "default_exhaustion_cooldown_threshold" - ): - threshold = provider_class.default_exhaustion_cooldown_threshold - if ( - threshold != DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD - ): # Only store if different from global default - exhaustion_cooldown_threshold[provider] = threshold - elif global_threshold != DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD: - # Use global threshold if set and different from default - exhaustion_cooldown_threshold[provider] = global_threshold - - # Log fair cycle configuration - for provider, enabled in fair_cycle_enabled.items(): - if not enabled: - lib_logger.info(f"Provider '{provider}' fair cycle: disabled") - for provider, mode in fair_cycle_tracking_mode.items(): - if mode != "model_group": - lib_logger.info( - f"Provider '{provider}' fair cycle tracking mode: {mode}" - ) - for provider, cross_tier in fair_cycle_cross_tier.items(): - if cross_tier: - lib_logger.info(f"Provider '{provider}' fair cycle cross-tier: enabled") - - # Build custom caps configuration - # Format: CUSTOM_CAP_{PROVIDER}_T{TIER}_{MODEL_OR_GROUP}= - # Format: CUSTOM_CAP_COOLDOWN_{PROVIDER}_T{TIER}_{MODEL_OR_GROUP}=: - custom_caps: Dict[ - str, Dict[Union[int, Tuple[int, ...], str], Dict[str, Dict[str, Any]]] - ] = {} - - for provider in self.all_credentials.keys(): - provider_class = self._provider_plugins.get(provider) - provider_upper = provider.upper() - - # Start with provider class defaults - if provider_class and hasattr(provider_class, "default_custom_caps"): - default_caps = provider_class.default_custom_caps - if default_caps: - custom_caps[provider] = {} - for tier_key, models_config in default_caps.items(): - custom_caps[provider][tier_key] = dict(models_config) - - # Parse environment variable overrides - cap_prefix = f"CUSTOM_CAP_{provider_upper}_T" - cooldown_prefix = f"CUSTOM_CAP_COOLDOWN_{provider_upper}_T" - - for env_key, env_value in os.environ.items(): - if env_key.startswith(cap_prefix) and not env_key.startswith( - cooldown_prefix - ): - # Parse cap value - remainder = env_key[len(cap_prefix) :] - tier_key, model_key = self._parse_custom_cap_env_key(remainder) - if tier_key is None: - continue - - if provider not in custom_caps: - custom_caps[provider] = {} - if tier_key not in custom_caps[provider]: - custom_caps[provider][tier_key] = {} - if model_key not in custom_caps[provider][tier_key]: - custom_caps[provider][tier_key][model_key] = {} - - # Store max_requests value - custom_caps[provider][tier_key][model_key]["max_requests"] = ( - env_value - ) - - elif env_key.startswith(cooldown_prefix): - # Parse cooldown config - remainder = env_key[len(cooldown_prefix) :] - tier_key, model_key = self._parse_custom_cap_env_key(remainder) - if tier_key is None: - continue - - # Parse mode:value format - if ":" in env_value: - mode, value_str = env_value.split(":", 1) - try: - value = int(value_str) - except ValueError: - lib_logger.warning( - f"Invalid cooldown value in {env_key}: {env_value}" - ) - continue - else: - mode = env_value - value = 0 - - if provider not in custom_caps: - custom_caps[provider] = {} - if tier_key not in custom_caps[provider]: - custom_caps[provider][tier_key] = {} - if model_key not in custom_caps[provider][tier_key]: - custom_caps[provider][tier_key][model_key] = {} - - custom_caps[provider][tier_key][model_key]["cooldown_mode"] = mode - custom_caps[provider][tier_key][model_key]["cooldown_value"] = value - - # Log custom caps configuration - for provider, tier_configs in custom_caps.items(): - for tier_key, models_config in tier_configs.items(): - for model_key, config in models_config.items(): - max_req = config.get("max_requests", "default") - cooldown = config.get("cooldown_mode", "quota_reset") - lib_logger.info( - f"Custom cap: {provider}/T{tier_key}/{model_key} = {max_req}, cooldown={cooldown}" - ) - - # Resolve usage file path - use provided path or default to data_dir - if usage_file_path is not None: - resolved_usage_path = Path(usage_file_path) - else: - resolved_usage_path = self.data_dir / "key_usage.json" - - self.usage_manager = UsageManager( - file_path=resolved_usage_path, - rotation_tolerance=rotation_tolerance, - provider_rotation_modes=provider_rotation_modes, - provider_plugins=PROVIDER_PLUGINS, - priority_multipliers=priority_multipliers, - priority_multipliers_by_mode=priority_multipliers_by_mode, - sequential_fallback_multipliers=sequential_fallback_multipliers, - fair_cycle_enabled=fair_cycle_enabled, - fair_cycle_tracking_mode=fair_cycle_tracking_mode, - fair_cycle_cross_tier=fair_cycle_cross_tier, - fair_cycle_duration=fair_cycle_duration, - exhaustion_cooldown_threshold=exhaustion_cooldown_threshold, - custom_caps=custom_caps, - ) - self._model_list_cache = {} - self.http_client = httpx.AsyncClient() - self.provider_config = ProviderConfig() - self.cooldown_manager = CooldownManager() - self.litellm_provider_params = litellm_provider_params or {} - self.ignore_models = ignore_models or {} - self.whitelist_models = whitelist_models or {} - self.enable_request_logging = enable_request_logging - self.model_definitions = ModelDefinitions() - - # Store and validate max concurrent requests per key - self.max_concurrent_requests_per_key = max_concurrent_requests_per_key or {} - # Validate all values are >= 1 - for provider, max_val in self.max_concurrent_requests_per_key.items(): - if max_val < 1: - lib_logger.warning( - f"Invalid max_concurrent for '{provider}': {max_val}. Setting to 1." - ) - self.max_concurrent_requests_per_key[provider] = 1 - - def _parse_custom_cap_env_key( - self, remainder: str - ) -> Tuple[Optional[Union[int, Tuple[int, ...], str]], Optional[str]]: - """ - Parse the tier and model/group from a custom cap env var remainder. - - Args: - remainder: String after "CUSTOM_CAP_{PROVIDER}_T" prefix - e.g., "2_CLAUDE" or "2_3_CLAUDE" or "DEFAULT_CLAUDE" - - Returns: - (tier_key, model_key) tuple, or (None, None) if parse fails - """ - if not remainder: - return None, None - - remaining_parts = remainder.split("_") - if len(remaining_parts) < 2: - return None, None - - tier_key: Union[int, Tuple[int, ...], str, None] = None - model_key: Optional[str] = None - - # Tiers are numeric or "DEFAULT" - tier_parts: List[int] = [] - - for i, part in enumerate(remaining_parts): - if part == "DEFAULT": - tier_key = "default" - model_key = "_".join(remaining_parts[i + 1 :]) - break - elif part.isdigit(): - tier_parts.append(int(part)) - else: - # First non-numeric part is start of model name - if len(tier_parts) == 0: - return None, None - elif len(tier_parts) == 1: - tier_key = tier_parts[0] - else: - tier_key = tuple(tier_parts) - model_key = "_".join(remaining_parts[i:]) - break - else: - # All parts were tier parts, no model - return None, None - - if model_key: - # Convert model_key back to original format (for matching) - # Env vars use underscores, but we store with original names - # The matching in UsageManager will handle this - model_key = model_key.lower().replace("_", "-") - - return tier_key, model_key - - def _is_model_ignored(self, provider: str, model_id: str) -> bool: - """ - Checks if a model should be ignored based on the ignore list. - Supports full glob/fnmatch patterns for both full model IDs and model names. - - Pattern examples: - - "gpt-4" - exact match - - "gpt-4*" - prefix wildcard (matches gpt-4, gpt-4-turbo, etc.) - - "*-preview" - suffix wildcard (matches gpt-4-preview, o1-preview, etc.) - - "*-preview*" - contains wildcard (matches anything with -preview) - - "*" - match all - """ - model_provider = model_id.split("/")[0] - if model_provider not in self.ignore_models: - return False - - ignore_list = self.ignore_models[model_provider] - if ignore_list == ["*"]: - return True - - try: - # This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b") - provider_model_name = model_id.split("/", 1)[1] - except IndexError: - provider_model_name = model_id - - for ignored_pattern in ignore_list: - # Use fnmatch for full glob pattern support - if fnmatch.fnmatch(provider_model_name, ignored_pattern) or fnmatch.fnmatch( - model_id, ignored_pattern - ): - return True - return False - - def _is_model_whitelisted(self, provider: str, model_id: str) -> bool: - """ - Checks if a model is explicitly whitelisted. - Supports full glob/fnmatch patterns for both full model IDs and model names. - - Pattern examples: - - "gpt-4" - exact match - - "gpt-4*" - prefix wildcard (matches gpt-4, gpt-4-turbo, etc.) - - "*-preview" - suffix wildcard (matches gpt-4-preview, o1-preview, etc.) - - "*-preview*" - contains wildcard (matches anything with -preview) - - "*" - match all - """ - model_provider = model_id.split("/")[0] - if model_provider not in self.whitelist_models: - return False - - whitelist = self.whitelist_models[model_provider] - - try: - # This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b") - provider_model_name = model_id.split("/", 1)[1] - except IndexError: - provider_model_name = model_id - - for whitelisted_pattern in whitelist: - # Use fnmatch for full glob pattern support - if fnmatch.fnmatch( - provider_model_name, whitelisted_pattern - ) or fnmatch.fnmatch(model_id, whitelisted_pattern): - return True - return False - - def _sanitize_litellm_log(self, log_data: dict) -> dict: - """ - Recursively removes large data fields and sensitive information from litellm log - dictionaries to keep debug logs clean and secure. - """ - if not isinstance(log_data, dict): - return log_data - - # Keys to remove at any level of the dictionary - keys_to_pop = [ - "messages", - "input", - "response", - "data", - "api_key", - "api_base", - "original_response", - "additional_args", - ] - - # Keys that might contain nested dictionaries to clean - nested_keys = ["kwargs", "litellm_params", "model_info", "proxy_server_request"] - - # Create a deep copy to avoid modifying the original log object in memory - clean_data = json.loads(json.dumps(log_data, default=str)) - - def clean_recursively(data_dict): - if not isinstance(data_dict, dict): - return - - # Remove sensitive/large keys - for key in keys_to_pop: - data_dict.pop(key, None) - - # Recursively clean nested dictionaries - for key in nested_keys: - if key in data_dict and isinstance(data_dict[key], dict): - clean_recursively(data_dict[key]) - - # Also iterate through all values to find any other nested dicts - for key, value in list(data_dict.items()): - if isinstance(value, dict): - clean_recursively(value) - - clean_recursively(clean_data) - return clean_data - - def _litellm_logger_callback(self, log_data: dict): - """ - Callback function to redirect litellm's logs to the library's logger. - This allows us to control the log level and destination of litellm's output. - It also cleans up error logs for better readability in debug files. - """ - # Filter out verbose pre_api_call and post_api_call logs - log_event_type = log_data.get("log_event_type") - if log_event_type in ["pre_api_call", "post_api_call"]: - return # Skip these verbose logs entirely - - # For successful calls or pre-call logs, a simple debug message is enough. - if not log_data.get("exception"): - sanitized_log = self._sanitize_litellm_log(log_data) - # We log it at the DEBUG level to ensure it goes to the debug file - # and not the console, based on the main.py configuration. - lib_logger.debug(f"LiteLLM Log: {sanitized_log}") - return - - # For failures, extract key info to make debug logs more readable. - model = log_data.get("model", "N/A") - call_id = log_data.get("litellm_call_id", "N/A") - error_info = log_data.get("standard_logging_object", {}).get( - "error_information", {} - ) - error_class = error_info.get("error_class", "UnknownError") - error_message = error_info.get( - "error_message", str(log_data.get("exception", "")) - ) - error_message = " ".join(error_message.split()) # Sanitize - - lib_logger.debug( - f"LiteLLM Callback Handled Error: Model={model} | " - f"Type={error_class} | Message='{error_message}'" - ) - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - - async def close(self): - """Close the HTTP client to prevent resource leaks.""" - if hasattr(self, "http_client") and self.http_client: - await self.http_client.aclose() - - def _apply_default_safety_settings( - self, litellm_kwargs: Dict[str, Any], provider: str - ): - """ - Ensure default Gemini safety settings are present when calling the Gemini provider. - This will not override any explicit settings provided by the request. It accepts - either OpenAI-compatible generic `safety_settings` (dict) or direct Gemini-style - `safetySettings` (list of dicts). Missing categories will be added with safe defaults. - """ - if provider != "gemini": - return - - # Generic defaults (openai-compatible style) - default_generic = { - "harassment": "OFF", - "hate_speech": "OFF", - "sexually_explicit": "OFF", - "dangerous_content": "OFF", - "civic_integrity": "BLOCK_NONE", - } - - # Gemini defaults (direct Gemini format) - default_gemini = [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, - {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, - ] - - # If generic form is present, ensure missing generic keys are filled in - if "safety_settings" in litellm_kwargs and isinstance( - litellm_kwargs["safety_settings"], dict - ): - for k, v in default_generic.items(): - if k not in litellm_kwargs["safety_settings"]: - litellm_kwargs["safety_settings"][k] = v - return - - # If Gemini form is present, ensure missing gemini categories are appended - if "safetySettings" in litellm_kwargs and isinstance( - litellm_kwargs["safetySettings"], list - ): - present = { - item.get("category") - for item in litellm_kwargs["safetySettings"] - if isinstance(item, dict) - } - for d in default_gemini: - if d["category"] not in present: - litellm_kwargs["safetySettings"].append(d) - return - - # Neither present: set generic defaults so provider conversion will translate them - if ( - "safety_settings" not in litellm_kwargs - and "safetySettings" not in litellm_kwargs - ): - litellm_kwargs["safety_settings"] = default_generic.copy() - - def get_oauth_credentials(self) -> Dict[str, List[str]]: - return self.oauth_credentials - - def _is_custom_openai_compatible_provider(self, provider_name: str) -> bool: - """ - Checks if a provider is a custom OpenAI-compatible provider. - - Custom providers are identified by: - 1. Having a _API_BASE environment variable set, AND - 2. NOT being in the list of known LiteLLM providers - """ - return self.provider_config.is_custom_provider(provider_name) - - def _get_provider_instance(self, provider_name: str): - """ - Lazily initializes and returns a provider instance. - Only initializes providers that have configured credentials. - - Args: - provider_name: The name of the provider to get an instance for. - For OAuth providers, this may include "_oauth" suffix - (e.g., "antigravity_oauth"), but credentials are stored - under the base name (e.g., "antigravity"). - - Returns: - Provider instance if credentials exist, None otherwise. - """ - # For OAuth providers, credentials are stored under base name (without _oauth suffix) - # e.g., "antigravity_oauth" plugin → credentials under "antigravity" - credential_key = provider_name - if provider_name.endswith("_oauth"): - base_name = provider_name[:-6] # Remove "_oauth" - if base_name in self.oauth_providers: - credential_key = base_name - - # Only initialize providers for which we have credentials - if credential_key not in self.all_credentials: - lib_logger.debug( - f"Skipping provider '{provider_name}' initialization: no credentials configured" - ) - return None - - if provider_name not in self._provider_instances: - if provider_name in self._provider_plugins: - self._provider_instances[provider_name] = self._provider_plugins[ - provider_name - ]() - elif self._is_custom_openai_compatible_provider(provider_name): - # Create a generic OpenAI-compatible provider for custom providers - try: - self._provider_instances[provider_name] = OpenAICompatibleProvider( - provider_name - ) - except ValueError: - # If the provider doesn't have the required environment variables, treat it as a standard provider - return None - else: - return None - return self._provider_instances[provider_name] - - def _resolve_model_id(self, model: str, provider: str) -> str: - """ - Resolves the actual model ID to send to the provider. - - For custom models with name/ID mappings, returns the ID. - Otherwise, returns the model name unchanged. - - Args: - model: Full model string with provider (e.g., "iflow/DS-v3.2") - provider: Provider name (e.g., "iflow") - - Returns: - Full model string with ID (e.g., "iflow/deepseek-v3.2") - """ - # Extract model name from "provider/model_name" format - model_name = model.split("/")[-1] if "/" in model else model - - # Try to get provider instance to check for model definitions - provider_plugin = self._get_provider_instance(provider) - - # Check if provider has model definitions - if provider_plugin and hasattr(provider_plugin, "model_definitions"): - model_id = provider_plugin.model_definitions.get_model_id( - provider, model_name - ) - if model_id and model_id != model_name: - # Return with provider prefix - return f"{provider}/{model_id}" - - # Fallback: use client's own model definitions - model_id = self.model_definitions.get_model_id(provider, model_name) - if model_id and model_id != model_name: - return f"{provider}/{model_id}" - - # No conversion needed, return original - return model - - async def _safe_streaming_wrapper( - self, - stream: Any, - key: str, - model: str, - request: Optional[Any] = None, - provider_plugin: Optional[Any] = None, - ) -> AsyncGenerator[Any, None]: - """ - A hybrid wrapper for streaming that buffers fragmented JSON, handles client disconnections gracefully, - and distinguishes between content and streamed errors. - - FINISH_REASON HANDLING: - Providers just translate chunks - this wrapper handles ALL finish_reason logic: - 1. Strip finish_reason from intermediate chunks (litellm defaults to "stop") - 2. Track accumulated_finish_reason with priority: tool_calls > length/content_filter > stop - 3. Only emit finish_reason on final chunk (detected by usage.completion_tokens > 0) - """ - last_usage = None - stream_completed = False - stream_iterator = stream.__aiter__() - json_buffer = "" - accumulated_finish_reason = None # Track strongest finish_reason across chunks - has_tool_calls = False # Track if ANY tool calls were seen in stream - - try: - while True: - if request and await request.is_disconnected(): - lib_logger.info( - f"Client disconnected. Aborting stream for credential {mask_credential(key)}." - ) - break - - try: - chunk = await stream_iterator.__anext__() - if json_buffer: - lib_logger.warning( - f"Discarding incomplete JSON buffer from previous chunk: {json_buffer}" - ) - json_buffer = "" - - # Convert chunk to dict, handling both litellm.ModelResponse and raw dicts - if hasattr(chunk, "dict"): - chunk_dict = chunk.dict() - elif hasattr(chunk, "model_dump"): - chunk_dict = chunk.model_dump() - else: - chunk_dict = chunk - - # === FINISH_REASON LOGIC === - # Providers send raw chunks without finish_reason logic. - # This wrapper determines finish_reason based on accumulated state. - if "choices" in chunk_dict and chunk_dict["choices"]: - choice = chunk_dict["choices"][0] - delta = choice.get("delta", {}) - usage = chunk_dict.get("usage", {}) - - # Track tool_calls across ALL chunks - if we ever see one, finish_reason must be tool_calls - if delta.get("tool_calls"): - has_tool_calls = True - accumulated_finish_reason = "tool_calls" - - # Detect final chunk: has usage with completion_tokens > 0 - has_completion_tokens = ( - usage - and isinstance(usage, dict) - and usage.get("completion_tokens", 0) > 0 - ) - - if has_completion_tokens: - # FINAL CHUNK: Determine correct finish_reason - if has_tool_calls: - # Tool calls always win - choice["finish_reason"] = "tool_calls" - elif accumulated_finish_reason: - # Use accumulated reason (length, content_filter, etc.) - choice["finish_reason"] = accumulated_finish_reason - else: - # Default to stop - choice["finish_reason"] = "stop" - else: - # INTERMEDIATE CHUNK: Never emit finish_reason - # (litellm.ModelResponse defaults to "stop" which is wrong) - choice["finish_reason"] = None - - yield f"data: {json.dumps(chunk_dict)}\n\n" - - if hasattr(chunk, "usage") and chunk.usage: - last_usage = chunk.usage - - except StopAsyncIteration: - stream_completed = True - if json_buffer: - lib_logger.info( - f"Stream ended with incomplete data in buffer: {json_buffer}" - ) - if last_usage: - # Create a dummy ModelResponse for recording (only usage matters) - dummy_response = litellm.ModelResponse(usage=last_usage) - await self.usage_manager.record_success( - key, model, dummy_response - ) - else: - # If no usage seen (rare), record success without tokens/cost - await self.usage_manager.record_success(key, model) - - break - - except CredentialNeedsReauthError as e: - # This credential needs re-authentication but re-auth is already queued. - # Wrap it so the outer retry loop can rotate to the next credential. - # No scary traceback needed - this is an expected recovery scenario. - raise StreamedAPIError("Credential needs re-authentication", data=e) - - except ( - litellm.RateLimitError, - litellm.ServiceUnavailableError, - litellm.InternalServerError, - APIConnectionError, - httpx.HTTPStatusError, - ) as e: - # This is a critical, typed error from litellm or httpx that signals a key failure. - # We do not try to parse it here. We wrap it and raise it immediately - # for the outer retry loop to handle. - lib_logger.warning( - f"Caught a critical API error mid-stream: {type(e).__name__}. Signaling for credential rotation." - ) - raise StreamedAPIError("Provider error received in stream", data=e) - - except Exception as e: - try: - raw_chunk = "" - # Google streams errors inside a bytes representation (b'{...}'). - # We use regex to extract the content, which is more reliable than splitting. - match = re.search(r"b'(\{.*\})'", str(e), re.DOTALL) - if match: - # The extracted string is unicode-escaped (e.g., '\\n'). We must decode it. - raw_chunk = codecs.decode(match.group(1), "unicode_escape") - else: - # Fallback for other potential error formats that use "Received chunk:". - chunk_from_split = ( - str(e).split("Received chunk:")[-1].strip() - ) - if chunk_from_split != str( - e - ): # Ensure the split actually did something - raw_chunk = chunk_from_split - - if not raw_chunk: - # If we could not extract a valid chunk, we cannot proceed with reassembly. - # This indicates a different, unexpected error type. Re-raise it. - raise e - - # Append the clean chunk to the buffer and try to parse. - json_buffer += raw_chunk - parsed_data = json.loads(json_buffer) - - # If parsing succeeds, we have the complete object. - lib_logger.info( - f"Successfully reassembled JSON from stream: {json_buffer}" - ) - - # Wrap the complete error object and raise it. The outer function will decide how to handle it. - raise StreamedAPIError( - "Provider error received in stream", data=parsed_data - ) - - except json.JSONDecodeError: - # This is the expected outcome if the JSON in the buffer is not yet complete. - lib_logger.info( - f"Buffer still incomplete. Waiting for more chunks: {json_buffer}" - ) - continue # Continue to the next loop to get the next chunk. - except StreamedAPIError: - # Re-raise to be caught by the outer retry handler. - raise - except Exception as buffer_exc: - # If the error was not a JSONDecodeError, it's an unexpected internal error. - lib_logger.error( - f"Error during stream buffering logic: {buffer_exc}. Discarding buffer." - ) - json_buffer = ( - "" # Clear the corrupted buffer to prevent further issues. - ) - raise buffer_exc - - except StreamedAPIError: - # This is caught by the acompletion retry logic. - # We re-raise it to ensure it's not caught by the generic 'except Exception'. - raise - - except Exception as e: - # Catch any other unexpected errors during streaming. - lib_logger.error(f"Caught unexpected exception of type: {type(e).__name__}") - lib_logger.error( - f"An unexpected error occurred during the stream for credential {mask_credential(key)}: {e}" - ) - # We still need to raise it so the client knows something went wrong. - raise - - finally: - # This block now runs regardless of how the stream terminates (completion, client disconnect, etc.). - # The primary goal is to ensure usage is always logged internally. - await self.usage_manager.release_key(key, model) - lib_logger.info( - f"STREAM FINISHED and lock released for credential {mask_credential(key)}." - ) - - # Only send [DONE] if the stream completed naturally and the client is still there. - # This prevents sending [DONE] to a disconnected client or after an error. - if stream_completed and ( - not request or not await request.is_disconnected() - ): - yield "data: [DONE]\n\n" - - async def _transaction_logging_stream_wrapper( - self, - stream: Any, - transaction_logger: Optional[TransactionLogger], - request_data: Dict[str, Any], - ) -> Any: - """ - Wrap a stream to log chunks and final response to TransactionLogger. - - This wrapper: - 1. Yields chunks unchanged (passthrough) - 2. Parses SSE chunks and logs them via transaction_logger.log_stream_chunk() - 3. Collects chunks for final response assembly - 4. After stream ends, assembles and logs final response - - Args: - stream: The streaming generator (yields SSE strings like "data: {...}") - transaction_logger: Optional TransactionLogger instance - request_data: Original request data for context - """ - chunks = [] - try: - async for chunk_str in stream: - yield chunk_str - - # Log chunk if logging enabled - if ( - transaction_logger - and isinstance(chunk_str, str) - and chunk_str.strip() - and chunk_str.startswith("data:") - ): - content = chunk_str[len("data:") :].strip() - if content and content != "[DONE]": - try: - chunk_data = json.loads(content) - chunks.append(chunk_data) - transaction_logger.log_stream_chunk(chunk_data) - except json.JSONDecodeError: - lib_logger.warning( - f"TransactionLogger: Failed to parse chunk: {content[:100]}" - ) - finally: - # Assemble and log final response after stream ends - if transaction_logger and chunks: - try: - final_response = TransactionLogger.assemble_streaming_response( - chunks, request_data - ) - transaction_logger.log_response(final_response) - except Exception as e: - lib_logger.warning( - f"TransactionLogger: Failed to assemble/log final response: {e}" - ) - - async def _execute_with_retry( - self, - api_call: callable, - request: Optional[Any], - pre_request_callback: Optional[callable] = None, - **kwargs, - ) -> Any: - """A generic retry mechanism for non-streaming API calls.""" - model = kwargs.get("model") - if not model: - raise ValueError("'model' is a required parameter.") - - provider = model.split("/")[0] - if provider not in self.all_credentials: - raise ValueError( - f"No API keys or OAuth credentials configured for provider: {provider}" - ) - - # Extract internal logging parameters (not passed to API) - parent_log_dir = kwargs.pop("_parent_log_dir", None) - - # Establish a global deadline for the entire request lifecycle. - deadline = time.time() + self.global_timeout - - # Create transaction logger if request logging is enabled - transaction_logger = None - if self.enable_request_logging: - transaction_logger = TransactionLogger( - provider, - model, - enabled=True, - api_format="oai", - parent_dir=parent_log_dir, - ) - transaction_logger.log_request(kwargs) - - # Create a mutable copy of the keys and shuffle it to ensure - # that the key selection is randomized, which is crucial when - # multiple keys have the same usage stats. - credentials_for_provider = list(self.all_credentials[provider]) - random.shuffle(credentials_for_provider) - - # Filter out credentials that are unavailable (queued for re-auth) - provider_plugin = self._get_provider_instance(provider) - if provider_plugin and hasattr(provider_plugin, "is_credential_available"): - available_creds = [ - cred - for cred in credentials_for_provider - if provider_plugin.is_credential_available(cred) - ] - if available_creds: - credentials_for_provider = available_creds - # If all credentials are unavailable, keep the original list - # (better to try unavailable creds than fail immediately) - - tried_creds = set() - last_exception = None - - # The main rotation loop. It continues as long as there are untried credentials and the global deadline has not been exceeded. - - # Resolve model ID early, before any credential operations - # This ensures consistent model ID usage for acquisition, release, and tracking - resolved_model = self._resolve_model_id(model, provider) - if resolved_model != model: - lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'") - model = resolved_model - kwargs["model"] = model # Ensure kwargs has the resolved model for litellm - - # [NEW] Filter by model tier requirement and build priority map - credential_priorities = None - if provider_plugin and hasattr(provider_plugin, "get_model_tier_requirement"): - required_tier = provider_plugin.get_model_tier_requirement(model) - if required_tier is not None: - # Filter OUT only credentials we KNOW are too low priority - # Keep credentials with unknown priority (None) - they might be high priority - incompatible_creds = [] - compatible_creds = [] - unknown_creds = [] - - for cred in credentials_for_provider: - if hasattr(provider_plugin, "get_credential_priority"): - priority = provider_plugin.get_credential_priority(cred) - if priority is None: - # Unknown priority - keep it, will be discovered on first use - unknown_creds.append(cred) - elif priority <= required_tier: - # Known compatible priority - compatible_creds.append(cred) - else: - # Known incompatible priority (too low) - incompatible_creds.append(cred) - else: - # Provider doesn't support priorities - keep all - unknown_creds.append(cred) - - # If we have any known-compatible or unknown credentials, use them - tier_compatible_creds = compatible_creds + unknown_creds - if tier_compatible_creds: - credentials_for_provider = tier_compatible_creds - if compatible_creds and unknown_creds: - lib_logger.info( - f"Model {model} requires priority <= {required_tier}. " - f"Using {len(compatible_creds)} known-compatible + {len(unknown_creds)} unknown-tier credentials." - ) - elif compatible_creds: - lib_logger.info( - f"Model {model} requires priority <= {required_tier}. " - f"Using {len(compatible_creds)} known-compatible credentials." - ) - else: - lib_logger.info( - f"Model {model} requires priority <= {required_tier}. " - f"Using {len(unknown_creds)} unknown-tier credentials (will discover on use)." - ) - elif incompatible_creds: - # Only known-incompatible credentials remain - lib_logger.warning( - f"Model {model} requires priority <= {required_tier} credentials, " - f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. " - f"Request will likely fail." - ) - - # Build priority map and tier names map for usage_manager - credential_tier_names = None - if provider_plugin and hasattr(provider_plugin, "get_credential_priority"): - credential_priorities = {} - credential_tier_names = {} - for cred in credentials_for_provider: - priority = provider_plugin.get_credential_priority(cred) - if priority is not None: - credential_priorities[cred] = priority - # Also get tier name for logging - if hasattr(provider_plugin, "get_credential_tier_name"): - tier_name = provider_plugin.get_credential_tier_name(cred) - if tier_name: - credential_tier_names[cred] = tier_name - - if credential_priorities: - lib_logger.debug( - f"Credential priorities for {provider}: {', '.join(f'P{p}={len([c for c in credentials_for_provider if credential_priorities.get(c) == p])}' for p in sorted(set(credential_priorities.values())))}" - ) - - # Initialize error accumulator for tracking errors across credential rotation - error_accumulator = RequestErrorAccumulator() - error_accumulator.model = model - error_accumulator.provider = provider - - while ( - len(tried_creds) < len(credentials_for_provider) and time.time() < deadline - ): - current_cred = None - key_acquired = False - try: - # Check for a provider-wide cooldown first. - if await self.cooldown_manager.is_cooling_down(provider): - remaining_cooldown = ( - await self.cooldown_manager.get_cooldown_remaining(provider) - ) - remaining_budget = deadline - time.time() - - # If the cooldown is longer than the remaining time budget, fail fast. - if remaining_cooldown > remaining_budget: - lib_logger.warning( - f"Provider {provider} cooldown ({remaining_cooldown:.2f}s) exceeds remaining request budget ({remaining_budget:.2f}s). Failing early." - ) - break - - lib_logger.warning( - f"Provider {provider} is in cooldown. Waiting for {remaining_cooldown:.2f} seconds." - ) - await asyncio.sleep(remaining_cooldown) - - creds_to_try = [ - c for c in credentials_for_provider if c not in tried_creds - ] - if not creds_to_try: - break - - # Get count of credentials not on cooldown for this model - availability_stats = ( - await self.usage_manager.get_credential_availability_stats( - creds_to_try, model, credential_priorities - ) - ) - available_count = availability_stats["available"] - total_count = len(credentials_for_provider) - on_cooldown = availability_stats["on_cooldown"] - fc_excluded = availability_stats["fair_cycle_excluded"] - - # Build compact exclusion breakdown - exclusion_parts = [] - if on_cooldown > 0: - exclusion_parts.append(f"cd:{on_cooldown}") - if fc_excluded > 0: - exclusion_parts.append(f"fc:{fc_excluded}") - exclusion_str = ( - f",{','.join(exclusion_parts)}" if exclusion_parts else "" - ) - - lib_logger.info( - f"Acquiring key for model {model}. Tried keys: {len(tried_creds)}/{available_count}({total_count}{exclusion_str})" - ) - max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1) - current_cred = await self.usage_manager.acquire_key( - available_keys=creds_to_try, - model=model, - deadline=deadline, - max_concurrent=max_concurrent, - credential_priorities=credential_priorities, - credential_tier_names=credential_tier_names, - all_provider_credentials=credentials_for_provider, - ) - key_acquired = True - tried_creds.add(current_cred) - - litellm_kwargs = kwargs.copy() - - # [NEW] Merge provider-specific params - if provider in self.litellm_provider_params: - litellm_kwargs["litellm_params"] = { - **self.litellm_provider_params[provider], - **litellm_kwargs.get("litellm_params", {}), - } - - provider_plugin = self._get_provider_instance(provider) - - # Model ID is already resolved before the loop, and kwargs['model'] is updated. - # No further resolution needed here. - - # Apply model-specific options for custom providers - if provider_plugin and hasattr(provider_plugin, "get_model_options"): - model_options = provider_plugin.get_model_options(model) - if model_options: - # Merge model options into litellm_kwargs - for key, value in model_options.items(): - if key == "reasoning_effort": - litellm_kwargs["reasoning_effort"] = value - elif key not in litellm_kwargs: - litellm_kwargs[key] = value - - if provider_plugin and provider_plugin.has_custom_logic(): - lib_logger.debug( - f"Provider '{provider}' has custom logic. Delegating call." - ) - litellm_kwargs["credential_identifier"] = current_cred - litellm_kwargs["transaction_context"] = ( - transaction_logger.get_context() if transaction_logger else None - ) - - # Retry loop for custom providers - mirrors streaming path error handling - for attempt in range(self.max_retries): - try: - lib_logger.info( - f"Attempting call with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})" - ) - - if pre_request_callback: - try: - await pre_request_callback(request, litellm_kwargs) - except Exception as e: - if self.abort_on_callback_error: - raise PreRequestCallbackError( - f"Pre-request callback failed: {e}" - ) from e - else: - lib_logger.warning( - f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" - ) - - response = await provider_plugin.acompletion( - self.http_client, **litellm_kwargs - ) - - # For non-streaming, success is immediate - await self.usage_manager.record_success( - current_cred, model, response - ) - - await self.usage_manager.release_key(current_cred, model) - key_acquired = False - - # Log response to transaction logger - if transaction_logger: - response_data = ( - response.model_dump() - if hasattr(response, "model_dump") - else response - ) - transaction_logger.log_response(response_data) - - return response - - except ( - litellm.RateLimitError, - httpx.HTTPStatusError, - ) as e: - last_exception = e - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - - # Record in accumulator for client reporting - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}) during custom provider call. Failing." - ) - raise last_exception - - # Handle rate limits with cooldown (exclude quota_exceeded) - if classified_error.error_type == "rate_limit": - cooldown_duration = classified_error.retry_after or 60 - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - lib_logger.warning( - f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code}). Rotating." - ) - break # Rotate to next credential - - except ( - APIConnectionError, - litellm.InternalServerError, - litellm.ServiceUnavailableError, - ) as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - # Provider-level error: don't increment consecutive failures - await self.usage_manager.record_failure( - current_cred, - model, - classified_error, - increment_consecutive_failures=False, - ) - - if attempt >= self.max_retries - 1: - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - lib_logger.warning( - f"Cred {mask_credential(current_cred)} failed after max retries. Rotating." - ) - break - - wait_time = classified_error.retry_after or ( - 2**attempt - ) + random.uniform(0, 1) - remaining_budget = deadline - time.time() - if wait_time > remaining_budget: - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - lib_logger.warning( - f"Retry wait ({wait_time:.2f}s) exceeds budget. Rotating." - ) - break - - lib_logger.warning( - f"Cred {mask_credential(current_cred)} server error. Retrying in {wait_time:.2f}s." - ) - await asyncio.sleep(wait_time) - continue - - except Exception as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - # Record in accumulator - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - lib_logger.warning( - f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})." - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}). Failing." - ) - raise last_exception - - # Handle rate limits with cooldown (exclude quota_exceeded) - if ( - classified_error.status_code == 429 - and classified_error.error_type != "quota_exceeded" - ) or classified_error.error_type == "rate_limit": - cooldown_duration = classified_error.retry_after or 60 - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - break # Rotate to next credential - - # If the inner loop breaks, it means the key failed and we need to rotate. - # Continue to the next iteration of the outer while loop to pick a new key. - continue - - else: # This is the standard API Key / litellm-handled provider logic - is_oauth = provider in self.oauth_providers - if is_oauth: # Standard OAuth provider (not custom) - # ... (logic to set headers) ... - pass - else: # API Key - litellm_kwargs["api_key"] = current_cred - - provider_instance = self._get_provider_instance(provider) - if provider_instance: - # Ensure default Gemini safety settings are present (without overriding request) - try: - self._apply_default_safety_settings( - litellm_kwargs, provider - ) - except Exception: - # If anything goes wrong here, avoid breaking the request flow. - lib_logger.debug( - "Could not apply default safety settings; continuing." - ) - - if "safety_settings" in litellm_kwargs: - converted_settings = ( - provider_instance.convert_safety_settings( - litellm_kwargs["safety_settings"] - ) - ) - if converted_settings is not None: - litellm_kwargs["safety_settings"] = converted_settings - else: - del litellm_kwargs["safety_settings"] - - if provider == "gemini" and provider_instance: - provider_instance.handle_thinking_parameter( - litellm_kwargs, model - ) - if provider == "nvidia_nim" and provider_instance: - provider_instance.handle_thinking_parameter( - litellm_kwargs, model - ) - - if "gemma-3" in model and "messages" in litellm_kwargs: - litellm_kwargs["messages"] = [ - {"role": "user", "content": m["content"]} - if m.get("role") == "system" - else m - for m in litellm_kwargs["messages"] - ] - - litellm_kwargs = sanitize_request_payload(litellm_kwargs, model) - - for attempt in range(self.max_retries): - try: - lib_logger.info( - f"Attempting call with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})" - ) - - if pre_request_callback: - try: - await pre_request_callback(request, litellm_kwargs) - except Exception as e: - if self.abort_on_callback_error: - raise PreRequestCallbackError( - f"Pre-request callback failed: {e}" - ) from e - else: - lib_logger.warning( - f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" - ) - - # Convert model parameters for custom providers right before LiteLLM call - final_kwargs = self.provider_config.convert_for_litellm( - **litellm_kwargs - ) - - response = await api_call( - **final_kwargs, - logger_fn=self._litellm_logger_callback, - ) - - await self.usage_manager.record_success( - current_cred, model, response - ) - - await self.usage_manager.release_key(current_cred, model) - key_acquired = False - - # Log response to transaction logger - if transaction_logger: - response_data = ( - response.model_dump() - if hasattr(response, "model_dump") - else response - ) - transaction_logger.log_response(response_data) - - return response - - except litellm.RateLimitError as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - - # Extract a clean error message for the user-facing log - error_message = str(e).split("\n")[0] - - # Record in accumulator for client reporting - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - lib_logger.info( - f"Key {mask_credential(current_cred)} hit rate limit for {model}. Rotating key." - ) - - # Only trigger provider-wide cooldown for rate limits, not quota issues - if ( - classified_error.status_code == 429 - and classified_error.error_type != "quota_exceeded" - ): - cooldown_duration = classified_error.retry_after or 60 - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - break # Move to the next key - - except ( - APIConnectionError, - litellm.InternalServerError, - litellm.ServiceUnavailableError, - ) as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - # Provider-level error: don't increment consecutive failures - await self.usage_manager.record_failure( - current_cred, - model, - classified_error, - increment_consecutive_failures=False, - ) - - if attempt >= self.max_retries - 1: - # Record in accumulator only on final failure for this key - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - lib_logger.warning( - f"Key {mask_credential(current_cred)} failed after max retries due to server error. Rotating." - ) - break # Move to the next key - - # For temporary errors, wait before retrying with the same key. - wait_time = classified_error.retry_after or ( - 2**attempt - ) + random.uniform(0, 1) - remaining_budget = deadline - time.time() - - # If the required wait time exceeds the budget, don't wait; rotate to the next key immediately. - if wait_time > remaining_budget: - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - lib_logger.warning( - f"Retry wait ({wait_time:.2f}s) exceeds budget ({remaining_budget:.2f}s). Rotating key." - ) - break - - lib_logger.warning( - f"Key {mask_credential(current_cred)} server error. Retrying in {wait_time:.2f}s." - ) - await asyncio.sleep(wait_time) - continue # Retry with the same key - - except httpx.HTTPStatusError as e: - # Handle HTTP errors from httpx (e.g., from custom providers like Antigravity) - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - lib_logger.warning( - f"Key {mask_credential(current_cred)} HTTP {e.response.status_code} ({classified_error.error_type})." - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}). Failing request." - ) - raise last_exception - - # Record in accumulator after confirming it's a rotatable error - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - # Handle rate limits with cooldown (exclude quota_exceeded from provider-wide cooldown) - if classified_error.error_type == "rate_limit": - cooldown_duration = classified_error.retry_after or 60 - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - # Check if we should retry same key (server errors with retries left) - if ( - should_retry_same_key(classified_error) - and attempt < self.max_retries - 1 - ): - wait_time = classified_error.retry_after or ( - 2**attempt - ) + random.uniform(0, 1) - remaining_budget = deadline - time.time() - if wait_time <= remaining_budget: - lib_logger.warning( - f"Server error, retrying same key in {wait_time:.2f}s." - ) - await asyncio.sleep(wait_time) - continue - - # Record failure and rotate to next key - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - lib_logger.info( - f"Rotating to next key after {classified_error.error_type} error." - ) - break - - except Exception as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - - if request and await request.is_disconnected(): - lib_logger.warning( - f"Client disconnected. Aborting retries for {mask_credential(current_cred)}." - ) - raise last_exception - - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - lib_logger.warning( - f"Key {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})." - ) - - # Handle rate limits with cooldown (exclude quota_exceeded from provider-wide cooldown) - if ( - classified_error.status_code == 429 - and classified_error.error_type != "quota_exceeded" - ) or classified_error.error_type == "rate_limit": - cooldown_duration = classified_error.retry_after or 60 - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}). Failing request." - ) - raise last_exception - - # Record in accumulator after confirming it's a rotatable error - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - break # Try next key for other errors - finally: - if key_acquired and current_cred: - await self.usage_manager.release_key(current_cred, model) - - # Check if we exhausted all credentials or timed out - if time.time() >= deadline: - error_accumulator.timeout_occurred = True - - if error_accumulator.has_errors(): - # Log concise summary for server logs - lib_logger.error(error_accumulator.build_log_message()) - - # Return the structured error response for the client - return error_accumulator.build_client_error_response() - - # Return None to indicate failure without error details (shouldn't normally happen) - lib_logger.warning( - "Unexpected state: request failed with no recorded errors. " - "This may indicate a logic error in error tracking." - ) - return None - - async def _streaming_acompletion_with_retry( - self, - request: Optional[Any], - pre_request_callback: Optional[callable] = None, - **kwargs, - ) -> AsyncGenerator[str, None]: - """A dedicated generator for retrying streaming completions with full request preparation and per-key retries.""" - model = kwargs.get("model") - provider = model.split("/")[0] - - # Extract internal logging parameters (not passed to API) - parent_log_dir = kwargs.pop("_parent_log_dir", None) - - # Create a mutable copy of the keys and shuffle it. - credentials_for_provider = list(self.all_credentials[provider]) - random.shuffle(credentials_for_provider) - - # Filter out credentials that are unavailable (queued for re-auth) - provider_plugin = self._get_provider_instance(provider) - if provider_plugin and hasattr(provider_plugin, "is_credential_available"): - available_creds = [ - cred - for cred in credentials_for_provider - if provider_plugin.is_credential_available(cred) - ] - if available_creds: - credentials_for_provider = available_creds - # If all credentials are unavailable, keep the original list - # (better to try unavailable creds than fail immediately) - - deadline = time.time() + self.global_timeout - - # Create transaction logger if request logging is enabled - transaction_logger = None - if self.enable_request_logging: - transaction_logger = TransactionLogger( - provider, - model, - enabled=True, - api_format="oai", - parent_dir=parent_log_dir, - ) - transaction_logger.log_request(kwargs) - - tried_creds = set() - last_exception = None - - consecutive_quota_failures = 0 - - # Resolve model ID early, before any credential operations - # This ensures consistent model ID usage for acquisition, release, and tracking - resolved_model = self._resolve_model_id(model, provider) - if resolved_model != model: - lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'") - model = resolved_model - kwargs["model"] = model # Ensure kwargs has the resolved model for litellm - - # [NEW] Filter by model tier requirement and build priority map - credential_priorities = None - if provider_plugin and hasattr(provider_plugin, "get_model_tier_requirement"): - required_tier = provider_plugin.get_model_tier_requirement(model) - if required_tier is not None: - # Filter OUT only credentials we KNOW are too low priority - # Keep credentials with unknown priority (None) - they might be high priority - incompatible_creds = [] - compatible_creds = [] - unknown_creds = [] - - for cred in credentials_for_provider: - if hasattr(provider_plugin, "get_credential_priority"): - priority = provider_plugin.get_credential_priority(cred) - if priority is None: - # Unknown priority - keep it, will be discovered on first use - unknown_creds.append(cred) - elif priority <= required_tier: - # Known compatible priority - compatible_creds.append(cred) - else: - # Known incompatible priority (too low) - incompatible_creds.append(cred) - else: - # Provider doesn't support priorities - keep all - unknown_creds.append(cred) - - # If we have any known-compatible or unknown credentials, use them - tier_compatible_creds = compatible_creds + unknown_creds - if tier_compatible_creds: - credentials_for_provider = tier_compatible_creds - if compatible_creds and unknown_creds: - lib_logger.info( - f"Model {model} requires priority <= {required_tier}. " - f"Using {len(compatible_creds)} known-compatible + {len(unknown_creds)} unknown-tier credentials." - ) - elif compatible_creds: - lib_logger.info( - f"Model {model} requires priority <= {required_tier}. " - f"Using {len(compatible_creds)} known-compatible credentials." - ) - else: - lib_logger.info( - f"Model {model} requires priority <= {required_tier}. " - f"Using {len(unknown_creds)} unknown-tier credentials (will discover on use)." - ) - elif incompatible_creds: - # Only known-incompatible credentials remain - lib_logger.warning( - f"Model {model} requires priority <= {required_tier} credentials, " - f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. " - f"Request will likely fail." - ) - - # Build priority map and tier names map for usage_manager - credential_tier_names = None - if provider_plugin and hasattr(provider_plugin, "get_credential_priority"): - credential_priorities = {} - credential_tier_names = {} - for cred in credentials_for_provider: - priority = provider_plugin.get_credential_priority(cred) - if priority is not None: - credential_priorities[cred] = priority - # Also get tier name for logging - if hasattr(provider_plugin, "get_credential_tier_name"): - tier_name = provider_plugin.get_credential_tier_name(cred) - if tier_name: - credential_tier_names[cred] = tier_name - - if credential_priorities: - lib_logger.debug( - f"Credential priorities for {provider}: {', '.join(f'P{p}={len([c for c in credentials_for_provider if credential_priorities.get(c) == p])}' for p in sorted(set(credential_priorities.values())))}" - ) - - # Initialize error accumulator for tracking errors across credential rotation - error_accumulator = RequestErrorAccumulator() - error_accumulator.model = model - error_accumulator.provider = provider - - try: - while ( - len(tried_creds) < len(credentials_for_provider) - and time.time() < deadline - ): - current_cred = None - key_acquired = False - try: - if await self.cooldown_manager.is_cooling_down(provider): - remaining_cooldown = ( - await self.cooldown_manager.get_cooldown_remaining(provider) - ) - remaining_budget = deadline - time.time() - if remaining_cooldown > remaining_budget: - lib_logger.warning( - f"Provider {provider} cooldown ({remaining_cooldown:.2f}s) exceeds remaining request budget ({remaining_budget:.2f}s). Failing early." - ) - break - lib_logger.warning( - f"Provider {provider} is in a global cooldown. All requests to this provider will be paused for {remaining_cooldown:.2f} seconds." - ) - await asyncio.sleep(remaining_cooldown) - - creds_to_try = [ - c for c in credentials_for_provider if c not in tried_creds - ] - if not creds_to_try: - lib_logger.warning( - f"All credentials for provider {provider} have been tried. No more credentials to rotate to." - ) - break - - # Get count of credentials not on cooldown for this model - availability_stats = ( - await self.usage_manager.get_credential_availability_stats( - creds_to_try, model, credential_priorities - ) - ) - available_count = availability_stats["available"] - total_count = len(credentials_for_provider) - on_cooldown = availability_stats["on_cooldown"] - fc_excluded = availability_stats["fair_cycle_excluded"] - - # Build compact exclusion breakdown - exclusion_parts = [] - if on_cooldown > 0: - exclusion_parts.append(f"cd:{on_cooldown}") - if fc_excluded > 0: - exclusion_parts.append(f"fc:{fc_excluded}") - exclusion_str = ( - f",{','.join(exclusion_parts)}" if exclusion_parts else "" - ) - - lib_logger.info( - f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{available_count}({total_count}{exclusion_str})" - ) - max_concurrent = self.max_concurrent_requests_per_key.get( - provider, 1 - ) - current_cred = await self.usage_manager.acquire_key( - available_keys=creds_to_try, - model=model, - deadline=deadline, - max_concurrent=max_concurrent, - credential_priorities=credential_priorities, - credential_tier_names=credential_tier_names, - all_provider_credentials=credentials_for_provider, - ) - key_acquired = True - tried_creds.add(current_cred) - - litellm_kwargs = kwargs.copy() - if "reasoning_effort" in kwargs: - litellm_kwargs["reasoning_effort"] = kwargs["reasoning_effort"] - - # [NEW] Merge provider-specific params - if provider in self.litellm_provider_params: - litellm_kwargs["litellm_params"] = { - **self.litellm_provider_params[provider], - **litellm_kwargs.get("litellm_params", {}), - } - - provider_plugin = self._get_provider_instance(provider) - - # Model ID is already resolved before the loop, and kwargs['model'] is updated. - # No further resolution needed here. - - # Apply model-specific options for custom providers - if provider_plugin and hasattr( - provider_plugin, "get_model_options" - ): - model_options = provider_plugin.get_model_options(model) - if model_options: - # Merge model options into litellm_kwargs - for key, value in model_options.items(): - if key == "reasoning_effort": - litellm_kwargs["reasoning_effort"] = value - elif key not in litellm_kwargs: - litellm_kwargs[key] = value - if provider_plugin and provider_plugin.has_custom_logic(): - lib_logger.debug( - f"Provider '{provider}' has custom logic. Delegating call." - ) - litellm_kwargs["credential_identifier"] = current_cred - litellm_kwargs["transaction_context"] = ( - transaction_logger.get_context() - if transaction_logger - else None - ) - - for attempt in range(self.max_retries): - try: - lib_logger.info( - f"Attempting stream with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})" - ) - - if pre_request_callback: - try: - await pre_request_callback( - request, litellm_kwargs - ) - except Exception as e: - if self.abort_on_callback_error: - raise PreRequestCallbackError( - f"Pre-request callback failed: {e}" - ) from e - else: - lib_logger.warning( - f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" - ) - - response = await provider_plugin.acompletion( - self.http_client, **litellm_kwargs - ) - - lib_logger.info( - f"Stream connection established for credential {mask_credential(current_cred)}. Processing response." - ) - - key_acquired = False - stream_generator = self._safe_streaming_wrapper( - response, - current_cred, - model, - request, - provider_plugin, - ) - - # Wrap with transaction logging - logged_stream = ( - self._transaction_logging_stream_wrapper( - stream_generator, transaction_logger, kwargs - ) - ) - - async for chunk in logged_stream: - yield chunk - return - - except ( - StreamedAPIError, - litellm.RateLimitError, - httpx.HTTPStatusError, - ) as e: - last_exception = e - # If the exception is our custom wrapper, unwrap the original error - original_exc = getattr(e, "data", e) - classified_error = classify_error( - original_exc, provider=provider - ) - error_message = str(original_exc).split("\n")[0] - - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - - # Record in accumulator for client reporting - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}) during custom stream. Failing." - ) - raise last_exception - - # Handle rate limits with cooldown (exclude quota_exceeded) - if classified_error.error_type == "rate_limit": - cooldown_duration = ( - classified_error.retry_after or 60 - ) - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - lib_logger.warning( - f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code}). Rotating." - ) - break - - except ( - APIConnectionError, - litellm.InternalServerError, - litellm.ServiceUnavailableError, - ) as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - # Provider-level error: don't increment consecutive failures - await self.usage_manager.record_failure( - current_cred, - model, - classified_error, - increment_consecutive_failures=False, - ) - - if attempt >= self.max_retries - 1: - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - lib_logger.warning( - f"Cred {mask_credential(current_cred)} failed after max retries. Rotating." - ) - break - - wait_time = classified_error.retry_after or ( - 2**attempt - ) + random.uniform(0, 1) - remaining_budget = deadline - time.time() - if wait_time > remaining_budget: - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - lib_logger.warning( - f"Retry wait ({wait_time:.2f}s) exceeds budget. Rotating." - ) - break - - lib_logger.warning( - f"Cred {mask_credential(current_cred)} server error. Retrying in {wait_time:.2f}s." - ) - await asyncio.sleep(wait_time) - continue - - except Exception as e: - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message = str(e).split("\n")[0] - - # Record in accumulator - error_accumulator.record_error( - current_cred, classified_error, error_message - ) - - lib_logger.warning( - f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})." - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}). Failing." - ) - raise last_exception - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - break - - # If the inner loop breaks, it means the key failed and we need to rotate. - # Continue to the next iteration of the outer while loop to pick a new key. - continue - - else: # This is the standard API Key / litellm-handled provider logic - is_oauth = provider in self.oauth_providers - if is_oauth: # Standard OAuth provider (not custom) - # ... (logic to set headers) ... - pass - else: # API Key - litellm_kwargs["api_key"] = current_cred - - provider_instance = self._get_provider_instance(provider) - if provider_instance: - # Ensure default Gemini safety settings are present (without overriding request) - try: - self._apply_default_safety_settings( - litellm_kwargs, provider - ) - except Exception: - lib_logger.debug( - "Could not apply default safety settings for streaming path; continuing." - ) - - if "safety_settings" in litellm_kwargs: - converted_settings = ( - provider_instance.convert_safety_settings( - litellm_kwargs["safety_settings"] - ) - ) - if converted_settings is not None: - litellm_kwargs["safety_settings"] = converted_settings - else: - del litellm_kwargs["safety_settings"] - - if provider == "gemini" and provider_instance: - provider_instance.handle_thinking_parameter( - litellm_kwargs, model - ) - if provider == "nvidia_nim" and provider_instance: - provider_instance.handle_thinking_parameter( - litellm_kwargs, model - ) - - if "gemma-3" in model and "messages" in litellm_kwargs: - litellm_kwargs["messages"] = [ - {"role": "user", "content": m["content"]} - if m.get("role") == "system" - else m - for m in litellm_kwargs["messages"] - ] - - litellm_kwargs = sanitize_request_payload(litellm_kwargs, model) - - # If the provider is 'qwen_code', set the custom provider to 'qwen' - # and strip the prefix from the model name for LiteLLM. - if provider == "qwen_code": - litellm_kwargs["custom_llm_provider"] = "qwen" - litellm_kwargs["model"] = model.split("/", 1)[1] - - for attempt in range(self.max_retries): - try: - lib_logger.info( - f"Attempting stream with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})" - ) - - if pre_request_callback: - try: - await pre_request_callback(request, litellm_kwargs) - except Exception as e: - if self.abort_on_callback_error: - raise PreRequestCallbackError( - f"Pre-request callback failed: {e}" - ) from e - else: - lib_logger.warning( - f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" - ) - - # lib_logger.info(f"DEBUG: litellm.acompletion kwargs: {litellm_kwargs}") - # Convert model parameters for custom providers right before LiteLLM call - final_kwargs = self.provider_config.convert_for_litellm( - **litellm_kwargs - ) - - response = await litellm.acompletion( - **final_kwargs, - logger_fn=self._litellm_logger_callback, - ) - - lib_logger.info( - f"Stream connection established for credential {mask_credential(current_cred)}. Processing response." - ) - - key_acquired = False - stream_generator = self._safe_streaming_wrapper( - response, - current_cred, - model, - request, - provider_instance, - ) - - # Wrap with transaction logging - logged_stream = self._transaction_logging_stream_wrapper( - stream_generator, transaction_logger, kwargs - ) - - async for chunk in logged_stream: - yield chunk - return - - except ( - StreamedAPIError, - litellm.RateLimitError, - httpx.HTTPStatusError, - ) as e: - last_exception = e - - # This is the final, robust handler for streamed errors. - error_payload = {} - cleaned_str = None - # The actual exception might be wrapped in our StreamedAPIError. - original_exc = getattr(e, "data", e) - classified_error = classify_error( - original_exc, provider=provider - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}) during litellm stream. Failing." - ) - raise last_exception - - try: - # The full error JSON is in the string representation of the exception. - json_str_match = re.search( - r"(\{.*\})", str(original_exc), re.DOTALL - ) - if json_str_match: - cleaned_str = codecs.decode( - json_str_match.group(1), "unicode_escape" - ) - error_payload = json.loads(cleaned_str) - except (json.JSONDecodeError, TypeError): - error_payload = {} - - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - raw_response_text=cleaned_str, - ) - - error_details = error_payload.get("error", {}) - error_status = error_details.get("status", "") - error_message_text = error_details.get( - "message", str(original_exc).split("\n")[0] - ) - - # Record in accumulator for client reporting - error_accumulator.record_error( - current_cred, classified_error, error_message_text - ) - - if ( - "quota" in error_message_text.lower() - or "resource_exhausted" in error_status.lower() - ): - consecutive_quota_failures += 1 - - quota_value = "N/A" - quota_id = "N/A" - if "details" in error_details and isinstance( - error_details.get("details"), list - ): - for detail in error_details["details"]: - if isinstance(detail.get("violations"), list): - for violation in detail["violations"]: - if "quotaValue" in violation: - quota_value = violation[ - "quotaValue" - ] - if "quotaId" in violation: - quota_id = violation["quotaId"] - if ( - quota_value != "N/A" - and quota_id != "N/A" - ): - break - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - - if consecutive_quota_failures >= 3: - # Fatal: likely input data too large - client_error_message = ( - f"Request failed after 3 consecutive quota errors (input may be too large). " - f"Limit: {quota_value} (Quota ID: {quota_id})" - ) - lib_logger.error( - f"Fatal quota error for {mask_credential(current_cred)}. ID: {quota_id}, Limit: {quota_value}" - ) - yield f"data: {json.dumps({'error': {'message': client_error_message, 'type': 'proxy_fatal_quota_error'}})}\n\n" - yield "data: [DONE]\n\n" - return - else: - lib_logger.warning( - f"Cred {mask_credential(current_cred)} quota error ({consecutive_quota_failures}/3). Rotating." - ) - break - - else: - consecutive_quota_failures = 0 - lib_logger.warning( - f"Cred {mask_credential(current_cred)} {classified_error.error_type}. Rotating." - ) - - if classified_error.error_type == "rate_limit": - cooldown_duration = ( - classified_error.retry_after or 60 - ) - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - break - - except ( - APIConnectionError, - litellm.InternalServerError, - litellm.ServiceUnavailableError, - ) as e: - consecutive_quota_failures = 0 - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message_text = str(e).split("\n")[0] - - # Record error in accumulator (server errors are transient, not abnormal) - error_accumulator.record_error( - current_cred, classified_error, error_message_text - ) - - # Provider-level error: don't increment consecutive failures - await self.usage_manager.record_failure( - current_cred, - model, - classified_error, - increment_consecutive_failures=False, - ) - - if attempt >= self.max_retries - 1: - lib_logger.warning( - f"Credential {mask_credential(current_cred)} failed after max retries for model {model} due to a server error. Rotating key silently." - ) - # [MODIFIED] Do not yield to the client here. - break - - wait_time = classified_error.retry_after or ( - 2**attempt - ) + random.uniform(0, 1) - remaining_budget = deadline - time.time() - if wait_time > remaining_budget: - lib_logger.warning( - f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early." - ) - break - - lib_logger.warning( - f"Credential {mask_credential(current_cred)} encountered a server error for model {model}. Reason: '{error_message_text}'. Retrying in {wait_time:.2f}s." - ) - await asyncio.sleep(wait_time) - continue - - except Exception as e: - consecutive_quota_failures = 0 - last_exception = e - log_failure( - api_key=current_cred, - model=model, - attempt=attempt + 1, - error=e, - request_headers=dict(request.headers) - if request - else {}, - ) - classified_error = classify_error(e, provider=provider) - error_message_text = str(e).split("\n")[0] - - # Record error in accumulator - error_accumulator.record_error( - current_cred, classified_error, error_message_text - ) - - lib_logger.warning( - f"Credential {mask_credential(current_cred)} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message_text}." - ) - - # Handle rate limits with cooldown (exclude quota_exceeded) - if ( - classified_error.status_code == 429 - and classified_error.error_type != "quota_exceeded" - ) or classified_error.error_type == "rate_limit": - cooldown_duration = classified_error.retry_after or 60 - await self.cooldown_manager.start_cooldown( - provider, cooldown_duration - ) - lib_logger.warning( - f"Rate limit detected for {provider}. Starting {cooldown_duration}s cooldown." - ) - - # Check if this error should trigger rotation - if not should_rotate_on_error(classified_error): - # Non-rotatable errors - fail immediately - lib_logger.error( - f"Non-recoverable error ({classified_error.error_type}). Failing request." - ) - raise last_exception - - # Record failure and rotate to next key - await self.usage_manager.record_failure( - current_cred, model, classified_error - ) - lib_logger.info( - f"Rotating to next key after {classified_error.error_type} error." - ) - break - - finally: - if key_acquired and current_cred: - await self.usage_manager.release_key(current_cred, model) - - # Build detailed error response using error accumulator - error_accumulator.timeout_occurred = time.time() >= deadline - - if error_accumulator.has_errors(): - # Log concise summary for server logs - lib_logger.error(error_accumulator.build_log_message()) - - # Build structured error response for client - error_response = error_accumulator.build_client_error_response() - error_data = error_response - else: - # Fallback if no errors were recorded (shouldn't happen) - final_error_message = ( - "Request failed: No available API keys after rotation or timeout." - ) - if last_exception: - final_error_message = ( - f"Request failed. Last error: {str(last_exception)}" - ) - error_data = { - "error": {"message": final_error_message, "type": "proxy_error"} - } - lib_logger.error(final_error_message) - - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" - - except NoAvailableKeysError as e: - lib_logger.error( - f"A streaming request failed because no keys were available within the time budget: {e}" - ) - error_data = {"error": {"message": str(e), "type": "proxy_busy"}} - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" - except Exception as e: - # This will now only catch fatal errors that should be raised, like invalid requests. - lib_logger.error( - f"An unhandled exception occurred in streaming retry logic: {e}", - exc_info=True, - ) - error_data = { - "error": { - "message": f"An unexpected error occurred: {str(e)}", - "type": "proxy_internal_error", - } - } - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" - - def acompletion( - self, - request: Optional[Any] = None, - pre_request_callback: Optional[callable] = None, - **kwargs, - ) -> Union[Any, AsyncGenerator[str, None]]: - """ - Dispatcher for completion requests. - - Args: - request: Optional request object, used for client disconnect checks and logging. - pre_request_callback: Optional async callback function to be called before each API request attempt. - The callback will receive the `request` object and the prepared request `kwargs` as arguments. - This can be used for custom logic such as request validation, logging, or rate limiting. - If the callback raises an exception, the completion request will be aborted and the exception will propagate. - - Returns: - The completion response object, or an async generator for streaming responses, or None if all retries fail. - """ - # Handle iflow provider: remove stream_options to avoid HTTP 406 - model = kwargs.get("model", "") - provider = model.split("/")[0] if "/" in model else "" - - if provider == "iflow" and "stream_options" in kwargs: - lib_logger.debug( - "Removing stream_options for iflow provider to avoid HTTP 406" - ) - kwargs.pop("stream_options", None) - - if kwargs.get("stream"): - # Only add stream_options for providers that support it (excluding iflow) - if provider != "iflow": - if "stream_options" not in kwargs: - kwargs["stream_options"] = {} - if "include_usage" not in kwargs["stream_options"]: - kwargs["stream_options"]["include_usage"] = True - - return self._streaming_acompletion_with_retry( - request=request, pre_request_callback=pre_request_callback, **kwargs - ) - else: - return self._execute_with_retry( - litellm.acompletion, - request=request, - pre_request_callback=pre_request_callback, - **kwargs, - ) - - def aembedding( - self, - request: Optional[Any] = None, - pre_request_callback: Optional[callable] = None, - **kwargs, - ) -> Any: - """ - Executes an embedding request with retry logic. - - Args: - request: Optional request object, used for client disconnect checks and logging. - pre_request_callback: Optional async callback function to be called before each API request attempt. - The callback will receive the `request` object and the prepared request `kwargs` as arguments. - This can be used for custom logic such as request validation, logging, or rate limiting. - If the callback raises an exception, the embedding request will be aborted and the exception will propagate. - - Returns: - The embedding response object, or None if all retries fail. - """ - return self._execute_with_retry( - litellm.aembedding, - request=request, - pre_request_callback=pre_request_callback, - **kwargs, - ) - - def token_count(self, **kwargs) -> int: - """Calculates the number of tokens for a given text or list of messages. - - For Antigravity provider models, this also includes the preprompt tokens - that get injected during actual API calls (agent instruction + identity override). - This ensures token counts match actual usage. - """ - model = kwargs.get("model") - text = kwargs.get("text") - messages = kwargs.get("messages") - - if not model: - raise ValueError("'model' is a required parameter.") - - # Calculate base token count - if messages: - base_count = token_counter(model=model, messages=messages) - elif text: - base_count = token_counter(model=model, text=text) - else: - raise ValueError("Either 'text' or 'messages' must be provided.") - - # Add preprompt tokens for Antigravity provider - # The Antigravity provider injects system instructions during actual API calls, - # so we need to account for those tokens in the count - provider = model.split("/")[0] if "/" in model else "" - if provider == "antigravity": - try: - from .providers.antigravity_provider import ( - get_antigravity_preprompt_text, - ) - - preprompt_text = get_antigravity_preprompt_text() - if preprompt_text: - preprompt_tokens = token_counter(model=model, text=preprompt_text) - base_count += preprompt_tokens - except ImportError: - # Provider not available, skip preprompt token counting - pass - - return base_count - - async def get_available_models(self, provider: str) -> List[str]: - """Returns a list of available models for a specific provider, with caching.""" - lib_logger.info(f"Getting available models for provider: {provider}") - if provider in self._model_list_cache: - lib_logger.debug(f"Returning cached models for provider: {provider}") - return self._model_list_cache[provider] - - credentials_for_provider = self.all_credentials.get(provider) - if not credentials_for_provider: - lib_logger.warning(f"No credentials for provider: {provider}") - return [] - - # Create a copy and shuffle it to randomize the starting credential - shuffled_credentials = list(credentials_for_provider) - random.shuffle(shuffled_credentials) - - provider_instance = self._get_provider_instance(provider) - if provider_instance: - # For providers with hardcoded models (like gemini_cli), we only need to call once. - # For others, we might need to try multiple keys if one is invalid. - # The current logic of iterating works for both, as the credential is not - # always used in get_models. - for credential in shuffled_credentials: - try: - # Display last 6 chars for API keys, or the filename for OAuth paths - cred_display = mask_credential(credential) - lib_logger.debug( - f"Attempting to get models for {provider} with credential {cred_display}" - ) - models = await provider_instance.get_models( - credential, self.http_client - ) - lib_logger.info( - f"Got {len(models)} models for provider: {provider}" - ) - - # Whitelist and blacklist logic - final_models = [] - for m in models: - is_whitelisted = self._is_model_whitelisted(provider, m) - is_blacklisted = self._is_model_ignored(provider, m) - - if is_whitelisted: - final_models.append(m) - continue - - if not is_blacklisted: - final_models.append(m) - - if len(final_models) != len(models): - lib_logger.info( - f"Filtered out {len(models) - len(final_models)} models for provider {provider}." - ) - - self._model_list_cache[provider] = final_models - return final_models - except Exception as e: - classified_error = classify_error(e, provider=provider) - cred_display = mask_credential(credential) - lib_logger.debug( - f"Failed to get models for provider {provider} with credential {cred_display}: {classified_error.error_type}. Trying next credential." - ) - continue # Try the next credential - - lib_logger.error( - f"Failed to get models for provider {provider} after trying all credentials." - ) - return [] - - async def get_all_available_models( - self, grouped: bool = True - ) -> Union[Dict[str, List[str]], List[str]]: - """Returns a list of all available models, either grouped by provider or as a flat list.""" - lib_logger.info("Getting all available models...") - - all_providers = list(self.all_credentials.keys()) - tasks = [self.get_available_models(provider) for provider in all_providers] - results = await asyncio.gather(*tasks, return_exceptions=True) - - all_provider_models = {} - for provider, result in zip(all_providers, results): - if isinstance(result, Exception): - lib_logger.error( - f"Failed to get models for provider {provider}: {result}" - ) - all_provider_models[provider] = [] - else: - all_provider_models[provider] = result - - lib_logger.info("Finished getting all available models.") - if grouped: - return all_provider_models - else: - flat_models = [] - for models in all_provider_models.values(): - flat_models.extend(models) - return flat_models - - async def get_quota_stats( - self, - provider_filter: Optional[str] = None, - ) -> Dict[str, Any]: - """ - Get quota and usage stats for all credentials. - - This returns cached/disk data aggregated by provider. - For provider-specific quota info (e.g., Antigravity quota groups), - it enriches the data from provider plugins. - - Args: - provider_filter: If provided, only return stats for this provider - - Returns: - Complete stats dict ready for the /v1/quota-stats endpoint - """ - # Get base stats from usage manager - stats = await self.usage_manager.get_stats_for_endpoint(provider_filter) - - # Enrich with provider-specific quota data - for provider, prov_stats in stats.get("providers", {}).items(): - provider_class = self._provider_plugins.get(provider) - if not provider_class: - continue - - # Get or create provider instance - if provider not in self._provider_instances: - self._provider_instances[provider] = provider_class() - provider_instance = self._provider_instances[provider] - - # Check if provider has quota tracking (like Antigravity) - if hasattr(provider_instance, "_get_effective_quota_groups"): - # Add quota group summary - quota_groups = provider_instance._get_effective_quota_groups() - prov_stats["quota_groups"] = {} - - for group_name, group_models in quota_groups.items(): - group_stats = { - "models": group_models, - "credentials_total": 0, - "credentials_exhausted": 0, - "avg_remaining_pct": 0, - "total_remaining_pcts": [], - # Total requests tracking across all credentials - "total_requests_used": 0, - "total_requests_max": 0, - # Tier breakdown: tier_name -> {"total": N, "active": M} - "tiers": {}, - } - - # Calculate per-credential quota for this group - for cred in prov_stats.get("credentials", []): - models_data = cred.get("models", {}) - group_stats["credentials_total"] += 1 - - # Track tier - get directly from provider cache since cred["tier"] not set yet - tier = cred.get("tier") - if not tier and hasattr( - provider_instance, "project_tier_cache" - ): - cred_path = cred.get("full_path", "") - tier = provider_instance.project_tier_cache.get(cred_path) - tier = tier or "unknown" - - # Initialize tier entry if needed with priority for sorting - if tier not in group_stats["tiers"]: - priority = 10 # default - if hasattr(provider_instance, "_resolve_tier_priority"): - priority = provider_instance._resolve_tier_priority( - tier - ) - group_stats["tiers"][tier] = { - "total": 0, - "active": 0, - "priority": priority, - } - group_stats["tiers"][tier]["total"] += 1 - - # Find model with VALID baseline (not just any model with stats) - model_stats = None - for model in group_models: - candidate = self._find_model_stats_in_data( - models_data, model, provider, provider_instance - ) - if candidate: - baseline = candidate.get("baseline_remaining_fraction") - if baseline is not None: - model_stats = candidate - break - # Keep first found as fallback (for request counts) - if model_stats is None: - model_stats = candidate - - if model_stats: - baseline = model_stats.get("baseline_remaining_fraction") - req_count = model_stats.get("request_count", 0) - max_req = model_stats.get("quota_max_requests") or 0 - - # Accumulate totals (one model per group per credential) - group_stats["total_requests_used"] += req_count - group_stats["total_requests_max"] += max_req - - if baseline is not None: - remaining_pct = int(baseline * 100) - group_stats["total_remaining_pcts"].append( - remaining_pct - ) - if baseline <= 0: - group_stats["credentials_exhausted"] += 1 - else: - # Credential is active (has quota remaining) - group_stats["tiers"][tier]["active"] += 1 - - # Calculate average remaining percentage (per-credential average) - if group_stats["total_remaining_pcts"]: - group_stats["avg_remaining_pct"] = int( - sum(group_stats["total_remaining_pcts"]) - / len(group_stats["total_remaining_pcts"]) - ) - del group_stats["total_remaining_pcts"] - - # Calculate total remaining percentage (global) - if group_stats["total_requests_max"] > 0: - used = group_stats["total_requests_used"] - max_r = group_stats["total_requests_max"] - group_stats["total_requests_remaining"] = max_r - used - group_stats["total_remaining_pct"] = max( - 0, int((1 - used / max_r) * 100) - ) - else: - group_stats["total_requests_remaining"] = 0 - # Fallback to avg_remaining_pct when max_requests unavailable - # This handles providers like Firmware that only provide percentage - group_stats["total_remaining_pct"] = group_stats.get("avg_remaining_pct") - - prov_stats["quota_groups"][group_name] = group_stats - - # Also enrich each credential with formatted quota group info - for cred in prov_stats.get("credentials", []): - cred["model_groups"] = {} - models_data = cred.get("models", {}) - - for group_name, group_models in quota_groups.items(): - # Find model with VALID baseline (prefer over any model with stats) - # Also track the best reset_ts across all models in the group - model_stats = None - best_reset_ts = None - - for model in group_models: - candidate = self._find_model_stats_in_data( - models_data, model, provider, provider_instance - ) - if candidate: - # Track the best (latest) reset_ts from any model in group - candidate_reset_ts = candidate.get("quota_reset_ts") - if candidate_reset_ts: - if ( - best_reset_ts is None - or candidate_reset_ts > best_reset_ts - ): - best_reset_ts = candidate_reset_ts - - baseline = candidate.get("baseline_remaining_fraction") - if baseline is not None: - model_stats = candidate - # Don't break - continue to find best reset_ts - # Keep first found as fallback - if model_stats is None: - model_stats = candidate - - if model_stats: - baseline = model_stats.get("baseline_remaining_fraction") - max_req = model_stats.get("quota_max_requests") - req_count = model_stats.get("request_count", 0) - # Use best_reset_ts from any model in the group - reset_ts = best_reset_ts or model_stats.get( - "quota_reset_ts" - ) - - remaining_pct = ( - int(baseline * 100) if baseline is not None else None - ) - is_exhausted = baseline is not None and baseline <= 0 - - # Format reset time - reset_iso = None - if reset_ts: - try: - from datetime import datetime, timezone - - reset_iso = datetime.fromtimestamp( - reset_ts, tz=timezone.utc - ).isoformat() - except (ValueError, OSError): - pass - - requests_remaining = ( - max(0, max_req - req_count) if max_req else 0 - ) - - # Determine display format - # Priority: requests (if max known) > percentage (if baseline available) > unknown - if max_req: - display = f"{requests_remaining}/{max_req}" - elif remaining_pct is not None: - display = f"{remaining_pct}%" - else: - display = "?/?" - - cred["model_groups"][group_name] = { - "remaining_pct": remaining_pct, - "requests_used": req_count, - "requests_remaining": requests_remaining, - "requests_max": max_req, - "display": display, - "is_exhausted": is_exhausted, - "reset_time_iso": reset_iso, - "models": group_models, - "confidence": self._get_baseline_confidence( - model_stats - ), - } - - # Recalculate credential's requests from model_groups - # This fixes double-counting when models share quota groups - if cred.get("model_groups"): - group_requests = sum( - g.get("requests_used", 0) - for g in cred["model_groups"].values() - ) - cred["requests"] = group_requests - - # HACK: Fix global requests if present - # This is a simplified fix that sets global.requests = current group_requests. - # TODO: Properly track archived requests per quota group in usage_manager.py - # so that global stats correctly sum: current_period + archived_periods - # without double-counting models that share quota groups. - # See: usage_manager.py lines 2388-2404 where global stats are built - # by iterating all models (causing double-counting for grouped models). - if cred.get("global"): - cred["global"]["requests"] = group_requests - - # Try to get email from provider's cache - cred_path = cred.get("full_path", "") - if hasattr(provider_instance, "project_tier_cache"): - tier = provider_instance.project_tier_cache.get(cred_path) - if tier: - cred["tier"] = tier - - return stats - - def _find_model_stats_in_data( - self, - models_data: Dict[str, Any], - model: str, - provider: str, - provider_instance: Any, - ) -> Optional[Dict[str, Any]]: - """ - Find model stats in models_data, trying various name variants. - - Handles aliased model names (e.g., gemini-3-pro-preview -> gemini-3-pro-high) - by using the provider's _user_to_api_model() mapping. - - Args: - models_data: Dict of model_name -> stats from credential - model: Model name to look up (user-facing name) - provider: Provider name for prefixing - provider_instance: Provider instance for alias methods - - Returns: - Model stats dict if found, None otherwise - """ - # Try direct match with and without provider prefix - prefixed_model = f"{provider}/{model}" - model_stats = models_data.get(prefixed_model) or models_data.get(model) - - if model_stats: - return model_stats - - # Try with API model name (e.g., gemini-3-pro-preview -> gemini-3-pro-high) - if hasattr(provider_instance, "_user_to_api_model"): - api_model = provider_instance._user_to_api_model(model) - if api_model != model: - prefixed_api = f"{provider}/{api_model}" - model_stats = models_data.get(prefixed_api) or models_data.get( - api_model - ) - - return model_stats - - def _get_baseline_confidence(self, model_stats: Dict) -> str: - """ - Determine confidence level based on baseline age. - - Args: - model_stats: Model statistics dict with baseline_fetched_at - - Returns: - "high" | "medium" | "low" - """ - baseline_fetched_at = model_stats.get("baseline_fetched_at") - if not baseline_fetched_at: - return "low" - - age_seconds = time.time() - baseline_fetched_at - if age_seconds < 300: # 5 minutes - return "high" - elif age_seconds < 1800: # 30 minutes - return "medium" - return "low" - - async def reload_usage_from_disk(self) -> None: - """ - Force reload usage data from disk. - - Useful when wanting fresh stats without making external API calls. - """ - await self.usage_manager.reload_from_disk() - - async def force_refresh_quota( - self, - provider: Optional[str] = None, - credential: Optional[str] = None, - ) -> Dict[str, Any]: - """ - Force refresh quota from external API. - - For Antigravity, this fetches live quota data from the API. - For other providers, this is a no-op (just reloads from disk). - - Args: - provider: If specified, only refresh this provider - credential: If specified, only refresh this specific credential - - Returns: - Refresh result dict with success/failure info - """ - result = { - "action": "force_refresh", - "scope": "credential" - if credential - else ("provider" if provider else "all"), - "provider": provider, - "credential": credential, - "credentials_refreshed": 0, - "success_count": 0, - "failed_count": 0, - "duration_ms": 0, - "errors": [], - } - - start_time = time.time() - - # Determine which providers to refresh - if provider: - providers_to_refresh = ( - [provider] if provider in self.all_credentials else [] - ) - else: - providers_to_refresh = list(self.all_credentials.keys()) - - for prov in providers_to_refresh: - provider_class = self._provider_plugins.get(prov) - if not provider_class: - continue - - # Get or create provider instance - if prov not in self._provider_instances: - self._provider_instances[prov] = provider_class() - provider_instance = self._provider_instances[prov] - - # Check if provider supports quota refresh (like Antigravity) - if hasattr(provider_instance, "fetch_initial_baselines"): - # Get credentials to refresh - if credential: - # Find full path for this credential - creds_to_refresh = [] - for cred_path in self.all_credentials.get(prov, []): - if cred_path.endswith(credential) or cred_path == credential: - creds_to_refresh.append(cred_path) - break - else: - creds_to_refresh = self.all_credentials.get(prov, []) - - if not creds_to_refresh: - continue - - try: - # Fetch live quota from API for ALL specified credentials - quota_results = await provider_instance.fetch_initial_baselines( - creds_to_refresh - ) - - # Store baselines in usage manager - if hasattr(provider_instance, "_store_baselines_to_usage_manager"): - stored = ( - await provider_instance._store_baselines_to_usage_manager( - quota_results, self.usage_manager - ) - ) - result["success_count"] += stored - - result["credentials_refreshed"] += len(creds_to_refresh) - - # Count failures - for cred_path, data in quota_results.items(): - if data.get("status") != "success": - result["failed_count"] += 1 - result["errors"].append( - f"{Path(cred_path).name}: {data.get('error', 'Unknown error')}" - ) - - except Exception as e: - lib_logger.error(f"Failed to refresh quota for {prov}: {e}") - result["errors"].append(f"{prov}: {str(e)}") - result["failed_count"] += len(creds_to_refresh) - - result["duration_ms"] = int((time.time() - start_time) * 1000) - return result - - # --- Anthropic API Compatibility Methods --- - - async def anthropic_messages( - self, - request: "AnthropicMessagesRequest", - raw_request: Optional[Any] = None, - pre_request_callback: Optional[callable] = None, - ) -> Any: - """ - Handle Anthropic Messages API requests. - - This method accepts requests in Anthropic's format, translates them to - OpenAI format internally, processes them through the existing acompletion - method, and returns responses in Anthropic's format. - - Args: - request: An AnthropicMessagesRequest object - raw_request: Optional raw request object for disconnect checks - pre_request_callback: Optional async callback before each API request - - Returns: - For non-streaming: dict in Anthropic Messages format - For streaming: AsyncGenerator yielding Anthropic SSE format strings - """ - from .anthropic_compat import ( - translate_anthropic_request, - openai_to_anthropic_response, - anthropic_streaming_wrapper, - ) - import uuid - - request_id = f"msg_{uuid.uuid4().hex[:24]}" - original_model = request.model - - # Extract provider from model for logging - provider = original_model.split("/")[0] if "/" in original_model else "unknown" - - # Create Anthropic transaction logger if request logging is enabled - anthropic_logger = None - if self.enable_request_logging: - anthropic_logger = TransactionLogger( - provider, - original_model, - enabled=True, - api_format="ant", - ) - # Log original Anthropic request - anthropic_logger.log_request( - request.model_dump(exclude_none=True), - filename="anthropic_request.json", - ) - - # Translate Anthropic request to OpenAI format - openai_request = translate_anthropic_request(request) - - # Pass parent log directory to acompletion for nested logging - if anthropic_logger and anthropic_logger.log_dir: - openai_request["_parent_log_dir"] = anthropic_logger.log_dir - - if request.stream: - # Streaming response - response_generator = self.acompletion( - request=raw_request, - pre_request_callback=pre_request_callback, - **openai_request, - ) - - # Create disconnect checker if raw_request provided - is_disconnected = None - if raw_request is not None and hasattr(raw_request, "is_disconnected"): - is_disconnected = raw_request.is_disconnected - - # Return the streaming wrapper - # Note: For streaming, the anthropic response logging happens in the wrapper - return anthropic_streaming_wrapper( - openai_stream=response_generator, - original_model=original_model, - request_id=request_id, - is_disconnected=is_disconnected, - transaction_logger=anthropic_logger, - ) - else: - # Non-streaming response - response = await self.acompletion( - request=raw_request, - pre_request_callback=pre_request_callback, - **openai_request, - ) - - # Convert OpenAI response to Anthropic format - openai_response = ( - response.model_dump() - if hasattr(response, "model_dump") - else dict(response) - ) - anthropic_response = openai_to_anthropic_response( - openai_response, original_model - ) - - # Override the ID with our request ID - anthropic_response["id"] = request_id - - # Log Anthropic response - if anthropic_logger: - anthropic_logger.log_response( - anthropic_response, - filename="anthropic_response.json", - ) - - return anthropic_response - - async def anthropic_count_tokens( - self, - request: "AnthropicCountTokensRequest", - ) -> dict: - """ - Handle Anthropic count_tokens API requests. - - Counts the number of tokens that would be used by a Messages API request. - This is useful for estimating costs and managing context windows. - - Args: - request: An AnthropicCountTokensRequest object - - Returns: - Dict with input_tokens count in Anthropic format - """ - from .anthropic_compat import ( - anthropic_to_openai_messages, - anthropic_to_openai_tools, - ) - import json - - anthropic_request = request.model_dump(exclude_none=True) - - openai_messages = anthropic_to_openai_messages( - anthropic_request.get("messages", []), anthropic_request.get("system") - ) - - # Count tokens for messages - message_tokens = self.token_count( - model=request.model, - messages=openai_messages, - ) - - # Count tokens for tools if present - tool_tokens = 0 - if request.tools: - # Tools add tokens based on their definitions - # Convert to JSON string and count tokens for tool definitions - openai_tools = anthropic_to_openai_tools( - [tool.model_dump() for tool in request.tools] - ) - if openai_tools: - # Serialize tools to count their token contribution - tools_text = json.dumps(openai_tools) - tool_tokens = self.token_count( - model=request.model, - text=tools_text, - ) - - total_tokens = message_tokens + tool_tokens - - return {"input_tokens": total_tokens} diff --git a/src/rotator_library/client/__init__.py b/src/rotator_library/client/__init__.py new file mode 100644 index 00000000..4307f84d --- /dev/null +++ b/src/rotator_library/client/__init__.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Client package for LLM API key rotation. + +This package provides the RotatingClient and associated components +for intelligent credential rotation and retry logic. + +Public API: + RotatingClient: Main client class for making API requests + StreamedAPIError: Exception for streaming errors + +Components (for advanced usage): + RequestExecutor: Unified retry/rotation logic + CredentialFilter: Tier compatibility filtering + ModelResolver: Model name resolution + ProviderTransforms: Provider-specific transforms + StreamingHandler: Streaming response processing +""" + +from .rotating_client import RotatingClient +from ..core.errors import StreamedAPIError + +# Also expose components for advanced usage +from .executor import RequestExecutor +from .filters import CredentialFilter +from .models import ModelResolver +from .transforms import ProviderTransforms +from .streaming import StreamingHandler +from .anthropic import AnthropicHandler +from .types import AvailabilityStats, RetryState, ExecutionResult + +__all__ = [ + # Main public API + "RotatingClient", + "StreamedAPIError", + # Components + "RequestExecutor", + "CredentialFilter", + "ModelResolver", + "ProviderTransforms", + "StreamingHandler", + "AnthropicHandler", + # Types + "AvailabilityStats", + "RetryState", + "ExecutionResult", +] diff --git a/src/rotator_library/client/anthropic.py b/src/rotator_library/client/anthropic.py new file mode 100644 index 00000000..f3f6119c --- /dev/null +++ b/src/rotator_library/client/anthropic.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Anthropic API compatibility handler for RotatingClient. + +This module provides Anthropic SDK compatibility methods that allow using +Anthropic's Messages API format with the credential rotation system. +""" + +import json +import logging +import uuid +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional + +from ..anthropic_compat import ( + AnthropicMessagesRequest, + AnthropicCountTokensRequest, + translate_anthropic_request, + openai_to_anthropic_response, + anthropic_streaming_wrapper, + anthropic_to_openai_messages, + anthropic_to_openai_tools, +) +from ..transaction_logger import TransactionLogger + +if TYPE_CHECKING: + from .rotating_client import RotatingClient + +lib_logger = logging.getLogger("rotator_library") + + +class AnthropicHandler: + """ + Handler for Anthropic API compatibility methods. + + This class provides methods to handle Anthropic Messages API requests + by translating them to OpenAI format, processing through the client's + acompletion method, and converting responses back to Anthropic format. + + Example: + handler = AnthropicHandler(client) + response = await handler.messages(request, raw_request) + """ + + def __init__(self, client: "RotatingClient"): + """ + Initialize the Anthropic handler. + + Args: + client: The RotatingClient instance to use for completions + """ + self._client = client + + async def messages( + self, + request: AnthropicMessagesRequest, + raw_request: Optional[Any] = None, + pre_request_callback: Optional[callable] = None, + ) -> Any: + """ + Handle Anthropic Messages API requests. + + This method accepts requests in Anthropic's format, translates them to + OpenAI format internally, processes them through the existing acompletion + method, and returns responses in Anthropic's format. + + Args: + request: An AnthropicMessagesRequest object + raw_request: Optional raw request object for disconnect checks + pre_request_callback: Optional async callback before each API request + + Returns: + For non-streaming: dict in Anthropic Messages format + For streaming: AsyncGenerator yielding Anthropic SSE format strings + """ + request_id = f"msg_{uuid.uuid4().hex[:24]}" + original_model = request.model + + # Extract provider from model for logging + provider = original_model.split("/")[0] if "/" in original_model else "unknown" + + # Create Anthropic transaction logger if request logging is enabled + anthropic_logger = None + if self._client.enable_request_logging: + anthropic_logger = TransactionLogger( + provider, + original_model, + enabled=True, + api_format="ant", + ) + # Log original Anthropic request + anthropic_logger.log_request( + request.model_dump(exclude_none=True), + filename="anthropic_request.json", + ) + + # Translate Anthropic request to OpenAI format + openai_request = translate_anthropic_request(request) + + # Pass parent log directory to acompletion for nested logging + if anthropic_logger and anthropic_logger.log_dir: + openai_request["_parent_log_dir"] = anthropic_logger.log_dir + + if request.stream: + # Streaming response + response_generator = self._client.acompletion( + request=raw_request, + pre_request_callback=pre_request_callback, + **openai_request, + ) + + # Create disconnect checker if raw_request provided + is_disconnected = None + if raw_request is not None and hasattr(raw_request, "is_disconnected"): + is_disconnected = raw_request.is_disconnected + + # Return the streaming wrapper + # Note: For streaming, the anthropic response logging happens in the wrapper + return anthropic_streaming_wrapper( + openai_stream=response_generator, + original_model=original_model, + request_id=request_id, + is_disconnected=is_disconnected, + transaction_logger=anthropic_logger, + ) + else: + # Non-streaming response + response = await self._client.acompletion( + request=raw_request, + pre_request_callback=pre_request_callback, + **openai_request, + ) + + # Convert OpenAI response to Anthropic format + openai_response = ( + response.model_dump() + if hasattr(response, "model_dump") + else dict(response) + ) + anthropic_response = openai_to_anthropic_response( + openai_response, original_model + ) + + # Override the ID with our request ID + anthropic_response["id"] = request_id + + # Log Anthropic response + if anthropic_logger: + anthropic_logger.log_response( + anthropic_response, + filename="anthropic_response.json", + ) + + return anthropic_response + + async def count_tokens( + self, + request: AnthropicCountTokensRequest, + ) -> dict: + """ + Handle Anthropic count_tokens API requests. + + Counts the number of tokens that would be used by a Messages API request. + This is useful for estimating costs and managing context windows. + + Args: + request: An AnthropicCountTokensRequest object + + Returns: + Dict with input_tokens count in Anthropic format + """ + anthropic_request = request.model_dump(exclude_none=True) + + openai_messages = anthropic_to_openai_messages( + anthropic_request.get("messages", []), anthropic_request.get("system") + ) + + # Count tokens for messages + message_tokens = self._client.token_count( + model=request.model, + messages=openai_messages, + ) + + # Count tokens for tools if present + tool_tokens = 0 + if request.tools: + # Tools add tokens based on their definitions + # Convert to JSON string and count tokens for tool definitions + openai_tools = anthropic_to_openai_tools( + [tool.model_dump() for tool in request.tools] + ) + if openai_tools: + # Serialize tools to count their token contribution + tools_text = json.dumps(openai_tools) + tool_tokens = self._client.token_count( + model=request.model, + text=tools_text, + ) + + total_tokens = message_tokens + tool_tokens + + return {"input_tokens": total_tokens} diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py new file mode 100644 index 00000000..6360b8a3 --- /dev/null +++ b/src/rotator_library/client/executor.py @@ -0,0 +1,1169 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Unified request execution with retry and rotation. + +This module extracts and unifies the retry logic that was duplicated in: +- _execute_with_retry (lines 1174-1945) +- _streaming_acompletion_with_retry (lines 1947-2780) + +The RequestExecutor provides a single code path for all request types, +with streaming vs non-streaming handled as a parameter. +""" + +import asyncio +import json +import logging +import random +import time +from typing import ( + Any, + AsyncGenerator, + Dict, + List, + Optional, + Set, + TYPE_CHECKING, + Tuple, + Union, +) + +import httpx +import litellm +from litellm.exceptions import ( + APIConnectionError, + RateLimitError, + ServiceUnavailableError, + InternalServerError, +) + +from ..core.types import RequestContext, ErrorAction +from ..core.errors import ( + NoAvailableKeysError, + PreRequestCallbackError, + StreamedAPIError, + ClassifiedError, + RequestErrorAccumulator, + classify_error, + should_rotate_on_error, + should_retry_same_key, + mask_credential, +) +from ..core.constants import DEFAULT_MAX_RETRIES +from ..request_sanitizer import sanitize_request_payload +from ..transaction_logger import TransactionLogger +from ..failure_logger import log_failure + +from .types import RetryState, AvailabilityStats +from .filters import CredentialFilter +from .transforms import ProviderTransforms +from .streaming import StreamingHandler + +if TYPE_CHECKING: + from ..usage import UsageManager + +lib_logger = logging.getLogger("rotator_library") + + +class RequestExecutor: + """ + Unified retry/rotation logic for all request types. + + This class handles: + - Credential rotation across providers + - Per-credential retry with backoff + - Error classification and handling + - Streaming and non-streaming requests + """ + + def __init__( + self, + usage_managers: Dict[str, "UsageManager"], + cooldown_manager: Any, + credential_filter: CredentialFilter, + provider_transforms: ProviderTransforms, + provider_plugins: Dict[str, Any], + http_client: httpx.AsyncClient, + max_retries: int = DEFAULT_MAX_RETRIES, + global_timeout: int = 30, + abort_on_callback_error: bool = True, + litellm_provider_params: Optional[Dict[str, Any]] = None, + litellm_logger_fn: Optional[Any] = None, + provider_instances: Optional[Dict[str, Any]] = None, + ): + """ + Initialize RequestExecutor. + + Args: + usage_managers: Dict mapping provider names to UsageManager instances + cooldown_manager: CooldownManager instance + credential_filter: CredentialFilter instance + provider_transforms: ProviderTransforms instance + provider_plugins: Dict mapping provider names to plugin classes + http_client: Shared httpx.AsyncClient for provider requests + max_retries: Max retries per credential + global_timeout: Global request timeout in seconds + abort_on_callback_error: Abort on pre-request callback errors + litellm_provider_params: Optional dict of provider-specific LiteLLM + parameters to merge into requests (e.g., custom headers, timeouts) + litellm_logger_fn: Optional callback function for LiteLLM logging + provider_instances: Shared dict for caching provider instances. + If None, creates a new dict (not recommended - leads to duplicate instances). + """ + self._usage_managers = usage_managers + self._cooldown = cooldown_manager + self._filter = credential_filter + self._transforms = provider_transforms + self._plugins = provider_plugins + self._plugin_instances: Dict[str, Any] = ( + provider_instances if provider_instances is not None else {} + ) + self._http_client = http_client + self._max_retries = max_retries + self._global_timeout = global_timeout + self._abort_on_callback_error = abort_on_callback_error + self._litellm_provider_params = litellm_provider_params or {} + self._litellm_logger_fn = litellm_logger_fn + # StreamingHandler no longer needs usage_manager - we pass cred_context directly + self._streaming_handler = StreamingHandler() + + def _get_plugin_instance(self, provider: str) -> Optional[Any]: + """Get or create a plugin instance for a provider.""" + if provider not in self._plugin_instances: + plugin_class = self._plugins.get(provider) + if plugin_class: + if isinstance(plugin_class, type): + self._plugin_instances[provider] = plugin_class() + else: + self._plugin_instances[provider] = plugin_class + else: + return None + return self._plugin_instances[provider] + + def _get_quota_display( + self, + state: Any, + model: str, + quota_group: Optional[str], + usage_manager: "UsageManager", + ) -> str: + """ + Get quota display string for logging. + + Checks group stats first (if quota_group provided), then falls back + to model stats. Returns a formatted string like "5/50 [90%]". + + Args: + state: CredentialState object + model: Model name + quota_group: Optional quota group name + usage_manager: UsageManager instance + + Returns: + Formatted quota display string, or "?/?" if unavailable + """ + if not state: + return "?/?" + + window_manager = getattr(usage_manager, "window_manager", None) + if not window_manager: + return "?/?" + + primary_def = window_manager.get_primary_definition() + if not primary_def: + return "?/?" + + window = None + # Check GROUP first if quota_group provided (shared limits) + if quota_group: + group_stats = state.get_group_stats(quota_group, create=False) + if group_stats: + window = group_stats.windows.get(primary_def.name) + + # Fall back to MODEL if no group limit found + if window is None or window.limit is None: + model_stats = state.get_model_stats(model, create=False) + if model_stats: + window = model_stats.windows.get(primary_def.name) + + # Display quota if we found a window with a limit + if window and window.limit is not None: + remaining = max(0, window.limit - window.request_count) + pct = round(remaining / window.limit * 100) if window.limit else 0 + return f"{window.request_count}/{window.limit} [{pct}%]" + + return "?/?" + + def _log_acquiring_credential( + self, + model: str, + tried_count: int, + availability: Dict[str, Any], + ) -> None: + """ + Log credential acquisition attempt with availability info. + + Args: + model: Model name + tried_count: Number of credentials already tried + availability: Availability stats dict from usage manager + """ + blocked = availability.get("blocked_by", {}) + blocked_parts = [] + if blocked.get("cooldowns"): + blocked_parts.append(f"cd:{blocked['cooldowns']}") + if blocked.get("fair_cycle"): + blocked_parts.append(f"fc:{blocked['fair_cycle']}") + if blocked.get("custom_caps"): + blocked_parts.append(f"cap:{blocked['custom_caps']}") + if blocked.get("window_limits"): + blocked_parts.append(f"wl:{blocked['window_limits']}") + if blocked.get("concurrent"): + blocked_parts.append(f"con:{blocked['concurrent']}") + blocked_str = f"({', '.join(blocked_parts)})" if blocked_parts else "" + lib_logger.info( + f"Acquiring credential for model {model}. Tried: {tried_count}/" + f"{availability.get('available', 0)}({availability.get('total', 0)}{blocked_str})" + ) + + async def _prepare_request_kwargs( + self, + provider: str, + model: str, + cred: str, + context: "RequestContext", + ) -> Dict[str, Any]: + """ + Prepare request kwargs with transforms, sanitization, and provider params. + + Args: + provider: Provider name + model: Model name + cred: Credential string + context: Request context + + Returns: + Prepared kwargs dict for the LiteLLM call + """ + # Apply transforms + kwargs = await self._transforms.apply( + provider, model, cred, context.kwargs.copy() + ) + + # Sanitize request payload + kwargs = sanitize_request_payload(kwargs, model) + + # Apply provider-specific LiteLLM params + self._apply_litellm_provider_params(provider, kwargs) + + # Add transaction context for provider logging + if context.transaction_logger: + kwargs["transaction_context"] = context.transaction_logger.get_context() + + return kwargs + + def _log_acquired_credential( + self, + cred: str, + model: str, + state: Any, + quota_group: Optional[str], + availability: Dict[str, Any], + usage_manager: "UsageManager", + ) -> None: + """ + Log successful credential acquisition. + + Args: + cred: Credential string + model: Model name + state: CredentialState object + quota_group: Optional quota group + availability: Availability stats dict + usage_manager: UsageManager instance + """ + tier = state.tier if state else None + priority = state.priority if state else None + selection_mode = availability.get("rotation_mode") + quota_display = self._get_quota_display( + state, model, quota_group, usage_manager + ) + lib_logger.info( + f"Acquired key {mask_credential(cred)} for model {model} " + f"(tier: {tier}, priority: {priority}, selection: {selection_mode}, quota: {quota_display})" + ) + + async def _run_pre_request_callback( + self, + context: "RequestContext", + kwargs: Dict[str, Any], + ) -> None: + """ + Run pre-request callback if configured. + + Args: + context: Request context + kwargs: Request kwargs + + Raises: + PreRequestCallbackError: If callback fails and abort_on_callback_error is True + """ + if context.pre_request_callback: + try: + await context.pre_request_callback(context.request, kwargs) + except Exception as e: + if self._abort_on_callback_error: + raise PreRequestCallbackError(str(e)) from e + lib_logger.warning(f"Pre-request callback failed: {e}") + + async def execute( + self, + context: RequestContext, + ) -> Union[Any, AsyncGenerator[str, None]]: + """ + Execute request with retry/rotation. + + This is the main entry point for request execution. + + Args: + context: RequestContext with all request details + + Returns: + Response object or async generator for streaming + """ + if context.streaming: + return self._execute_streaming(context) + else: + return await self._execute_non_streaming(context) + + async def _prepare_execution( + self, + context: RequestContext, + ) -> Tuple["UsageManager", Any, List[str], Optional[str], Dict[str, Any]]: + provider = context.provider + model = context.model + + usage_manager = self._usage_managers.get(provider) + if not usage_manager: + raise NoAvailableKeysError(f"No UsageManager for provider {provider}") + + filter_result = self._filter.filter_by_tier( + context.credentials, model, provider + ) + credentials = filter_result.all_usable + quota_group = usage_manager.get_model_quota_group(model) + + await self._ensure_initialized(usage_manager, context, filter_result) + await self._validate_request(provider, model, context.kwargs) + + if not credentials: + raise NoAvailableKeysError(f"No compatible credentials for model {model}") + + request_headers = ( + dict(context.request.headers) if context.request is not None else {} + ) + + return usage_manager, filter_result, credentials, quota_group, request_headers + + async def _execute_non_streaming( + self, + context: RequestContext, + ) -> Any: + """ + Execute non-streaming request with retry/rotation. + + Args: + context: RequestContext with all request details + + Returns: + Response object + """ + provider = context.provider + model = context.model + deadline = context.deadline + + ( + usage_manager, + filter_result, + credentials, + quota_group, + request_headers, + ) = await self._prepare_execution(context) + + error_accumulator = RequestErrorAccumulator() + error_accumulator.model = model + error_accumulator.provider = provider + + retry_state = RetryState() + last_exception: Optional[Exception] = None + + while time.time() < deadline: + # Check for untried credentials + untried = [c for c in credentials if c not in retry_state.tried_credentials] + if not untried: + lib_logger.warning( + f"All {len(credentials)} credentials tried for {model}" + ) + break + + # Wait for provider cooldown + await self._wait_for_cooldown(provider, deadline) + + # Acquire credential using context manager + try: + availability = await usage_manager.get_availability_stats( + model, quota_group + ) + self._log_acquiring_credential( + model, len(retry_state.tried_credentials), availability + ) + async with await usage_manager.acquire_credential( + model=model, + quota_group=quota_group, + candidates=untried, + priorities=filter_result.priorities, + deadline=deadline, + ) as cred_context: + cred = cred_context.credential + retry_state.record_attempt(cred) + + state = getattr(usage_manager, "states", {}).get( + cred_context.stable_id + ) + self._log_acquired_credential( + cred, model, state, quota_group, availability, usage_manager + ) + + try: + # Prepare request kwargs + kwargs = await self._prepare_request_kwargs( + provider, model, cred, context + ) + + # Get provider plugin + plugin = self._get_plugin_instance(provider) + + # Execute request with retries + for attempt in range(self._max_retries): + try: + lib_logger.info( + f"Attempting call with credential {mask_credential(cred)} " + f"(Attempt {attempt + 1}/{self._max_retries})" + ) + # Pre-request callback + await self._run_pre_request_callback(context, kwargs) + + # Make the API call + if plugin and plugin.has_custom_logic(): + kwargs["credential_identifier"] = cred + response = await plugin.acompletion( + self._http_client, **kwargs + ) + else: + # Standard LiteLLM call + kwargs["api_key"] = cred + self._apply_litellm_logger(kwargs) + response = await litellm.acompletion(**kwargs) + + # Success! Extract token usage if available + ( + prompt_tokens, + completion_tokens, + prompt_tokens_cached, + prompt_tokens_cache_write, + thinking_tokens, + ) = self._extract_usage_tokens(response) + approx_cost = self._calculate_cost( + provider, model, response + ) + response_headers = self._extract_response_headers( + response + ) + + cred_context.mark_success( + response=response, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cached, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + response_headers=response_headers, + ) + + lib_logger.info( + f"Recorded usage from response object for key {mask_credential(cred)}" + ) + + # Log response if transaction logging enabled + if context.transaction_logger: + try: + response_data = ( + response.model_dump() + if hasattr(response, "model_dump") + else response + ) + context.transaction_logger.log_response( + response_data + ) + except Exception as log_err: + lib_logger.debug( + f"Failed to log response: {log_err}" + ) + + return response + + except Exception as e: + last_exception = e + action = await self._handle_error_with_context( + e, + cred_context, + model, + provider, + attempt, + error_accumulator, + retry_state, + request_headers, + ) + + if action == ErrorAction.RETRY_SAME: + continue + elif action == ErrorAction.ROTATE: + break # Try next credential + else: # FAIL + raise + + except PreRequestCallbackError: + raise + except Exception: + # Let context manager handle cleanup + pass + + except NoAvailableKeysError: + break + + # All credentials exhausted + error_accumulator.timeout_occurred = time.time() >= deadline + if last_exception and not error_accumulator.has_errors(): + raise last_exception + + # Return error response + return error_accumulator.build_client_error_response() + + async def _execute_streaming( + self, + context: RequestContext, + ) -> AsyncGenerator[str, None]: + """ + Execute streaming request with retry/rotation. + + This is an async generator that yields SSE-formatted strings. + + Args: + context: RequestContext with all request details + + Yields: + SSE-formatted strings + """ + provider = context.provider + model = context.model + deadline = context.deadline + + try: + ( + usage_manager, + filter_result, + credentials, + quota_group, + request_headers, + ) = await self._prepare_execution(context) + except NoAvailableKeysError as exc: + error_data = { + "error": { + "message": str(exc), + "type": "proxy_error", + } + } + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + return + + error_accumulator = RequestErrorAccumulator() + error_accumulator.model = model + error_accumulator.provider = provider + + retry_state = RetryState() + last_exception: Optional[Exception] = None + + try: + while time.time() < deadline: + # Check for untried credentials + untried = [ + c for c in credentials if c not in retry_state.tried_credentials + ] + if not untried: + lib_logger.warning( + f"All {len(credentials)} credentials tried for {model}" + ) + break + + # Wait for provider cooldown + remaining = deadline - time.time() + if remaining <= 0: + break + await self._wait_for_cooldown(provider, deadline) + + # Acquire credential using context manager + try: + availability = await usage_manager.get_availability_stats( + model, quota_group + ) + self._log_acquiring_credential( + model, len(retry_state.tried_credentials), availability + ) + async with await usage_manager.acquire_credential( + model=model, + quota_group=quota_group, + candidates=untried, + priorities=filter_result.priorities, + deadline=deadline, + ) as cred_context: + cred = cred_context.credential + retry_state.record_attempt(cred) + + state = getattr(usage_manager, "states", {}).get( + cred_context.stable_id + ) + self._log_acquired_credential( + cred, model, state, quota_group, availability, usage_manager + ) + + try: + # Prepare request kwargs + kwargs = await self._prepare_request_kwargs( + provider, model, cred, context + ) + + # Add stream options (but not for iflow - it returns 406) + if provider != "iflow": + if "stream_options" not in kwargs: + kwargs["stream_options"] = {} + if "include_usage" not in kwargs["stream_options"]: + kwargs["stream_options"]["include_usage"] = True + + # Get provider plugin + plugin = self._get_plugin_instance(provider) + skip_cost_calculation = bool( + plugin + and getattr(plugin, "skip_cost_calculation", False) + ) + + # Execute request with retries + for attempt in range(self._max_retries): + try: + lib_logger.info( + f"Attempting stream with credential {mask_credential(cred)} " + f"(Attempt {attempt + 1}/{self._max_retries})" + ) + # Pre-request callback + await self._run_pre_request_callback( + context, kwargs + ) + + # Make the API call + if plugin and plugin.has_custom_logic(): + kwargs["credential_identifier"] = cred + stream = await plugin.acompletion( + self._http_client, **kwargs + ) + else: + kwargs["api_key"] = cred + kwargs["stream"] = True + self._apply_litellm_logger(kwargs) + stream = await litellm.acompletion(**kwargs) + + # Hand off to streaming handler with cred_context + # The handler will call mark_success on completion + base_stream = self._streaming_handler.wrap_stream( + stream, + cred, + model, + context.request, + cred_context, + skip_cost_calculation=skip_cost_calculation, + ) + + lib_logger.info( + f"Stream connection established for credential {mask_credential(cred)}. " + "Processing response." + ) + + # Wrap with transaction logging if enabled + if context.transaction_logger: + async for ( + chunk + ) in self._transaction_logging_stream_wrapper( + base_stream, + context.transaction_logger, + context.kwargs, + ): + yield chunk + else: + async for chunk in base_stream: + yield chunk + return + + except StreamedAPIError as e: + last_exception = e + original = getattr(e, "data", e) + classified = classify_error(original, provider) + log_failure( + api_key=cred, + model=model, + attempt=attempt + 1, + error=e, + request_headers=request_headers, + ) + error_accumulator.record_error( + cred, classified, str(original)[:150] + ) + + # Track consecutive quota failures + if classified.error_type == "quota_exceeded": + retry_state.increment_quota_failures() + if retry_state.consecutive_quota_failures >= 3: + lib_logger.error( + "3 consecutive quota errors in streaming - " + "request may be too large" + ) + cred_context.mark_failure(classified) + error_data = { + "error": { + "message": "Request exceeds quota for all credentials", + "type": "quota_exhausted", + } + } + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + return + else: + retry_state.reset_quota_failures() + + if not should_rotate_on_error(classified): + cred_context.mark_failure(classified) + raise + + cred_context.mark_failure(classified) + break # Rotate + + except (RateLimitError, httpx.HTTPStatusError) as e: + last_exception = e + classified = classify_error(e, provider) + log_failure( + api_key=cred, + model=model, + attempt=attempt + 1, + error=e, + request_headers=request_headers, + ) + error_accumulator.record_error( + cred, classified, str(e)[:150] + ) + + # Track consecutive quota failures + if classified.error_type == "quota_exceeded": + retry_state.increment_quota_failures() + if retry_state.consecutive_quota_failures >= 3: + lib_logger.error( + "3 consecutive quota errors in streaming - " + "request may be too large" + ) + cred_context.mark_failure(classified) + error_data = { + "error": { + "message": "Request exceeds quota for all credentials", + "type": "quota_exhausted", + } + } + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + return + else: + retry_state.reset_quota_failures() + + if not should_rotate_on_error(classified): + cred_context.mark_failure(classified) + raise + + cred_context.mark_failure(classified) + break # Rotate + + except ( + APIConnectionError, + InternalServerError, + ServiceUnavailableError, + ) as e: + last_exception = e + classified = classify_error(e, provider) + log_failure( + api_key=cred, + model=model, + attempt=attempt + 1, + error=e, + request_headers=request_headers, + ) + + if attempt >= self._max_retries - 1: + error_accumulator.record_error( + cred, classified, str(e)[:150] + ) + cred_context.mark_failure(classified) + break # Rotate + + # Calculate wait time + wait_time = classified.retry_after or ( + 2**attempt + ) + random.uniform(0, 1) + remaining = deadline - time.time() + if wait_time > remaining: + break # No time to wait + + await asyncio.sleep(wait_time) + continue # Retry + + except Exception as e: + last_exception = e + classified = classify_error(e, provider) + log_failure( + api_key=cred, + model=model, + attempt=attempt + 1, + error=e, + request_headers=request_headers, + ) + error_accumulator.record_error( + cred, classified, str(e)[:150] + ) + + if not should_rotate_on_error(classified): + cred_context.mark_failure(classified) + raise + + cred_context.mark_failure(classified) + break # Rotate + + except PreRequestCallbackError: + raise + except Exception: + # Let context manager handle cleanup + pass + + except NoAvailableKeysError: + break + + # All credentials exhausted or timeout + error_accumulator.timeout_occurred = time.time() >= deadline + error_data = error_accumulator.build_client_error_response() + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + + except NoAvailableKeysError as e: + lib_logger.error(f"No keys available: {e}") + error_data = {"error": {"message": str(e), "type": "proxy_busy"}} + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + + except Exception as e: + lib_logger.error(f"Unhandled exception in streaming: {e}", exc_info=True) + error_data = {"error": {"message": str(e), "type": "proxy_internal_error"}} + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + + def _apply_litellm_provider_params( + self, provider: str, kwargs: Dict[str, Any] + ) -> None: + """Merge provider-specific LiteLLM parameters into request kwargs.""" + params = self._litellm_provider_params.get(provider) + if not params: + return + kwargs["litellm_params"] = { + **params, + **kwargs.get("litellm_params", {}), + } + + def _apply_litellm_logger(self, kwargs: Dict[str, Any]) -> None: + """Attach LiteLLM logger callback if configured.""" + if self._litellm_logger_fn and "logger_fn" not in kwargs: + kwargs["logger_fn"] = self._litellm_logger_fn + + def _extract_response_headers(self, response: Any) -> Optional[Dict[str, Any]]: + """Extract response headers from LiteLLM response objects.""" + if hasattr(response, "response") and response.response is not None: + headers = getattr(response.response, "headers", None) + if headers is not None: + return dict(headers) + headers = getattr(response, "headers", None) + if headers is not None: + return dict(headers) + return None + + async def _wait_for_cooldown( + self, + provider: str, + deadline: float, + ) -> None: + """ + Wait for provider-level cooldown to end. + + Args: + provider: Provider name + deadline: Request deadline + """ + if not self._cooldown: + return + + remaining = await self._cooldown.get_remaining_cooldown(provider) + if remaining > 0: + budget = deadline - time.time() + if remaining > budget: + lib_logger.warning( + f"Provider {provider} cooldown ({remaining:.1f}s) exceeds budget ({budget:.1f}s)" + ) + return # Will fail on no keys available + lib_logger.info(f"Waiting {remaining:.1f}s for {provider} cooldown") + await asyncio.sleep(remaining) + + async def _handle_error_with_context( + self, + error: Exception, + cred_context: Any, # CredentialContext + model: str, + provider: str, + attempt: int, + error_accumulator: RequestErrorAccumulator, + retry_state: RetryState, + request_headers: Dict[str, Any], + ) -> str: + """ + Handle an error and determine next action. + + Args: + error: The caught exception + cred_context: CredentialContext for marking failure + model: Model name + provider: Provider name + attempt: Current attempt number + error_accumulator: Error tracking + retry_state: Retry state tracking + + Returns: + ErrorAction indicating what to do next + """ + classified = classify_error(error, provider) + error_message = str(error)[:150] + credential = cred_context.credential + + log_failure( + api_key=credential, + model=model, + attempt=attempt + 1, + error=error, + request_headers=request_headers, + ) + + # Check for quota errors + if classified.error_type == "quota_exceeded": + retry_state.increment_quota_failures() + if retry_state.consecutive_quota_failures >= 3: + # Likely request is too large + lib_logger.error( + f"3 consecutive quota errors - request may be too large" + ) + error_accumulator.record_error(credential, classified, error_message) + cred_context.mark_failure(classified) + return ErrorAction.FAIL + else: + retry_state.reset_quota_failures() + + # Check if should rotate + if not should_rotate_on_error(classified): + error_accumulator.record_error(credential, classified, error_message) + cred_context.mark_failure(classified) + return ErrorAction.FAIL + + # Check if should retry same key + if should_retry_same_key(classified) and attempt < self._max_retries - 1: + wait_time = classified.retry_after or (2**attempt) + random.uniform(0, 1) + lib_logger.info( + f"Retrying {mask_credential(credential)} in {wait_time:.1f}s" + ) + await asyncio.sleep(wait_time) + return ErrorAction.RETRY_SAME + + # Record error and rotate + error_accumulator.record_error(credential, classified, error_message) + cred_context.mark_failure(classified) + lib_logger.info( + f"Rotating from {mask_credential(credential)} after {classified.error_type}" + ) + return ErrorAction.ROTATE + + async def _ensure_initialized( + self, + usage_manager: "UsageManager", + context: RequestContext, + filter_result: "FilterResult", + ) -> None: + if usage_manager.initialized: + return + await usage_manager.initialize( + context.credentials, + priorities=filter_result.priorities, + tiers=filter_result.tier_names, + ) + + async def _validate_request( + self, + provider: str, + model: str, + kwargs: Dict[str, Any], + ) -> None: + plugin = self._get_plugin_instance(provider) + if not plugin or not hasattr(plugin, "validate_request"): + return + + result = plugin.validate_request(kwargs, model) + if asyncio.iscoroutine(result): + result = await result + if result is False: + raise ValueError(f"Request validation failed for {provider}/{model}") + if isinstance(result, str): + raise ValueError(result) + + def _extract_usage_tokens(self, response: Any) -> tuple[int, int, int, int, int]: + prompt_tokens = 0 + completion_tokens = 0 + cached_tokens = 0 + cache_write_tokens = 0 + thinking_tokens = 0 + + if hasattr(response, "usage") and response.usage: + prompt_tokens = getattr(response.usage, "prompt_tokens", 0) or 0 + completion_tokens = getattr(response.usage, "completion_tokens", 0) or 0 + + prompt_details = getattr(response.usage, "prompt_tokens_details", None) + if prompt_details: + if isinstance(prompt_details, dict): + cached_tokens = prompt_details.get("cached_tokens", 0) or 0 + cache_write_tokens = ( + prompt_details.get("cache_creation_tokens", 0) or 0 + ) + else: + cached_tokens = getattr(prompt_details, "cached_tokens", 0) or 0 + cache_write_tokens = ( + getattr(prompt_details, "cache_creation_tokens", 0) or 0 + ) + + completion_details = getattr( + response.usage, "completion_tokens_details", None + ) + if completion_details: + if isinstance(completion_details, dict): + thinking_tokens = completion_details.get("reasoning_tokens", 0) or 0 + else: + thinking_tokens = ( + getattr(completion_details, "reasoning_tokens", 0) or 0 + ) + + cache_read_tokens = getattr(response.usage, "cache_read_tokens", None) + if cache_read_tokens is not None: + cached_tokens = cache_read_tokens or 0 + cache_creation_tokens = getattr( + response.usage, "cache_creation_tokens", None + ) + if cache_creation_tokens is not None: + cache_write_tokens = cache_creation_tokens or 0 + + if thinking_tokens and completion_tokens >= thinking_tokens: + completion_tokens = completion_tokens - thinking_tokens + + uncached_prompt = max(0, prompt_tokens - cached_tokens) + return ( + uncached_prompt, + completion_tokens, + cached_tokens, + cache_write_tokens, + thinking_tokens, + ) + + def _calculate_cost(self, provider: str, model: str, response: Any) -> float: + plugin = self._get_plugin_instance(provider) + if plugin and getattr(plugin, "skip_cost_calculation", False): + return 0.0 + + try: + if isinstance(response, litellm.EmbeddingResponse): + model_info = litellm.get_model_info(model) + input_cost = model_info.get("input_cost_per_token") + if input_cost: + return (response.usage.prompt_tokens or 0) * input_cost + return 0.0 + + cost = litellm.completion_cost( + completion_response=response, + model=model, + ) + return float(cost) if cost is not None else 0.0 + except Exception as exc: + lib_logger.debug(f"Cost calculation failed for {model}: {exc}") + return 0.0 + + async def _transaction_logging_stream_wrapper( + self, + stream: AsyncGenerator[str, None], + transaction_logger: TransactionLogger, + request_kwargs: Dict[str, Any], + ) -> AsyncGenerator[str, None]: + """ + Wrap a stream to log chunks and final response to TransactionLogger. + + Yields all chunks unchanged while accumulating them for final logging. + + Args: + stream: The SSE stream from wrap_stream + transaction_logger: TransactionLogger instance + request_kwargs: Original request kwargs for context + + Yields: + SSE-formatted strings unchanged + """ + chunks = [] + + async for sse_line in stream: + yield sse_line + + # Parse and accumulate for final logging + if sse_line.startswith("data: ") and not sse_line.startswith( + "data: [DONE]" + ): + try: + content = sse_line[6:].strip() + if content: + chunk_data = json.loads(content) + chunks.append(chunk_data) + transaction_logger.log_stream_chunk(chunk_data) + except json.JSONDecodeError: + lib_logger.debug( + f"Failed to parse chunk for logging: {sse_line[:100]}" + ) + + # Log assembled final response + if chunks: + try: + final_response = TransactionLogger.assemble_streaming_response(chunks) + transaction_logger.log_response(final_response) + except Exception as e: + lib_logger.debug( + f"Failed to assemble/log final streaming response: {e}" + ) diff --git a/src/rotator_library/client/filters.py b/src/rotator_library/client/filters.py new file mode 100644 index 00000000..305e4fad --- /dev/null +++ b/src/rotator_library/client/filters.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Credential filtering by tier compatibility and priority. + +Extracts the tier filtering logic that was duplicated in client.py +at lines 1242-1315 and 2004-2076. +""" + +import logging +from typing import Any, Dict, List, Optional + +from ..core.types import FilterResult + +lib_logger = logging.getLogger("rotator_library") + + +class CredentialFilter: + """ + Filter and group credentials by tier compatibility and priority. + + This class extracts the credential filtering logic that was previously + duplicated in both _execute_with_retry and _streaming_acompletion_with_retry. + """ + + def __init__( + self, + provider_plugins: Dict[str, Any], + provider_instances: Optional[Dict[str, Any]] = None, + ): + """ + Initialize the CredentialFilter. + + Args: + provider_plugins: Dict mapping provider names to plugin classes/instances + provider_instances: Shared dict for caching provider instances. + If None, creates a new dict (not recommended - leads to duplicate instances). + """ + self._plugins = provider_plugins + self._plugin_instances: Dict[str, Any] = ( + provider_instances if provider_instances is not None else {} + ) + + def _get_plugin_instance(self, provider: str) -> Optional[Any]: + """ + Get or create a plugin instance for a provider. + + Args: + provider: Provider name + + Returns: + Plugin instance or None if not found + """ + if provider not in self._plugin_instances: + plugin_class = self._plugins.get(provider) + if plugin_class: + # Check if it's a class or already an instance + if isinstance(plugin_class, type): + lib_logger.debug( + f"[CredentialFilter] CREATING NEW INSTANCE for {provider} (cache_id={cache_id})" + ) + self._plugin_instances[provider] = plugin_class() + else: + self._plugin_instances[provider] = plugin_class + else: + return None + return self._plugin_instances[provider] + + def filter_by_tier( + self, + credentials: List[str], + model: str, + provider: str, + ) -> FilterResult: + """ + Filter credentials by tier compatibility for a model. + + Args: + credentials: List of credential identifiers + model: Model being requested + provider: Provider name + + Returns: + FilterResult with categorized credentials + """ + plugin = self._get_plugin_instance(provider) + + # Get tier requirement for model + required_tier = None + if plugin and hasattr(plugin, "get_model_tier_requirement"): + required_tier = plugin.get_model_tier_requirement(model) + + compatible: List[str] = [] + unknown: List[str] = [] + incompatible: List[str] = [] + priorities: Dict[str, int] = {} + tier_names: Dict[str, str] = {} + + for cred in credentials: + # Get priority and tier name + priority = None + tier_name = None + + if plugin: + if hasattr(plugin, "get_credential_priority"): + priority = plugin.get_credential_priority(cred) + if hasattr(plugin, "get_credential_tier_name"): + tier_name = plugin.get_credential_tier_name(cred) + + if priority is not None: + priorities[cred] = priority + if tier_name: + tier_names[cred] = tier_name + + # Categorize by tier compatibility + if required_tier is None: + # No tier requirement - all compatible + compatible.append(cred) + elif priority is None: + # Unknown priority - keep as candidate + unknown.append(cred) + elif priority <= required_tier: + # Known compatible (lower priority number = higher tier) + compatible.append(cred) + else: + # Known incompatible + incompatible.append(cred) + + # Log if all credentials are incompatible + if incompatible and not compatible and not unknown: + lib_logger.warning( + f"Model {model} requires tier <= {required_tier}, " + f"but all {len(incompatible)} credentials are incompatible" + ) + + return FilterResult( + compatible=compatible, + unknown=unknown, + incompatible=incompatible, + priorities=priorities, + tier_names=tier_names, + ) + + def group_by_priority( + self, + credentials: List[str], + priorities: Dict[str, int], + ) -> Dict[int, List[str]]: + """ + Group credentials by priority level. + + Args: + credentials: List of credential identifiers + priorities: Dict mapping credentials to priority levels + + Returns: + Dict mapping priority levels to credential lists, sorted by priority + """ + groups: Dict[int, List[str]] = {} + + for cred in credentials: + priority = priorities.get(cred, 999) + if priority not in groups: + groups[priority] = [] + groups[priority].append(cred) + + # Return sorted by priority (lower = higher priority) + return dict(sorted(groups.items())) + + def get_highest_priority_credentials( + self, + credentials: List[str], + priorities: Dict[str, int], + ) -> List[str]: + """ + Get credentials with the highest priority (lowest priority number). + + Args: + credentials: List of credential identifiers + priorities: Dict mapping credentials to priority levels + + Returns: + List of credentials with the highest priority + """ + if not credentials: + return [] + + groups = self.group_by_priority(credentials, priorities) + if not groups: + return credentials + + # Get the lowest priority number (highest priority) + highest_priority = min(groups.keys()) + return groups[highest_priority] diff --git a/src/rotator_library/client/models.py b/src/rotator_library/client/models.py new file mode 100644 index 00000000..27c9015c --- /dev/null +++ b/src/rotator_library/client/models.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Model name resolution and filtering. + +Extracts model-related logic from client.py including: +- _resolve_model_id (lines 867-902) +- _is_model_ignored (lines 587-619) +- _is_model_whitelisted (lines 621-651) +""" + +import fnmatch +import logging +from typing import Any, Dict, List, Optional + +lib_logger = logging.getLogger("rotator_library") + + +class ModelResolver: + """ + Resolve model names and apply filtering rules. + + Handles: + - Model ID resolution (display name -> actual ID) + - Whitelist/blacklist filtering + - Provider prefix handling + """ + + def __init__( + self, + provider_plugins: Dict[str, Any], + model_definitions: Optional[Any] = None, + ignore_models: Optional[Dict[str, List[str]]] = None, + whitelist_models: Optional[Dict[str, List[str]]] = None, + provider_instances: Optional[Dict[str, Any]] = None, + ): + """ + Initialize the ModelResolver. + + Args: + provider_plugins: Dict mapping provider names to plugin classes + model_definitions: ModelDefinitions instance for ID mapping + ignore_models: Models to ignore/blacklist per provider + whitelist_models: Models to explicitly whitelist per provider + provider_instances: Shared dict for caching provider instances. + If None, creates a new dict (not recommended - leads to duplicate instances). + """ + self._plugins = provider_plugins + self._plugin_instances: Dict[str, Any] = ( + provider_instances if provider_instances is not None else {} + ) + self._definitions = model_definitions + self._ignore = ignore_models or {} + self._whitelist = whitelist_models or {} + + def _get_plugin_instance(self, provider: str) -> Optional[Any]: + """ + Get or create a plugin instance for a provider. + """ + if provider not in self._plugin_instances: + plugin_class = self._plugins.get(provider) + if plugin_class: + if isinstance(plugin_class, type): + self._plugin_instances[provider] = plugin_class() + else: + self._plugin_instances[provider] = plugin_class + else: + return None + return self._plugin_instances[provider] + + def resolve_model_id(self, model: str, provider: str) -> str: + """ + Resolve display name to actual model ID. + + For custom models with name/ID mappings, returns the ID. + Otherwise, returns the model name unchanged. + + Args: + model: Full model string with provider (e.g., "iflow/DS-v3.2") + provider: Provider name (e.g., "iflow") + + Returns: + Full model string with ID (e.g., "iflow/deepseek-v3.2") + """ + model_name = model.split("/")[-1] if "/" in model else model + + # Check provider plugin first + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "model_definitions"): + resolved = plugin.model_definitions.get_model_id(provider, model_name) + if resolved and resolved != model_name: + return f"{provider}/{resolved}" + + # Fallback to client-level definitions + if self._definitions: + resolved = self._definitions.get_model_id(provider, model_name) + if resolved and resolved != model_name: + return f"{provider}/{resolved}" + + return model + + def is_model_allowed(self, model: str, provider: str) -> bool: + """ + Check if model passes whitelist/blacklist filters. + + Whitelist takes precedence over blacklist. + + Args: + model: Model string (with or without provider prefix) + provider: Provider name + + Returns: + True if model is allowed, False if blocked + """ + # Whitelist takes precedence + if self._is_whitelisted(model, provider): + return True + + # Then check blacklist + if self._is_blacklisted(model, provider): + return False + + return True + + def _is_blacklisted(self, model: str, provider: str) -> bool: + """ + Check if model is blacklisted. + + Supports glob patterns: + - "gpt-4" - exact match + - "gpt-4*" - prefix wildcard + - "*-preview" - suffix wildcard + - "*" - match all + + Args: + model: Model string + provider: Provider name (used to get ignore list) + + Returns: + True if model is blacklisted + """ + model_provider = model.split("/")[0] if "/" in model else provider + + if model_provider not in self._ignore: + return False + + ignore_list = self._ignore[model_provider] + if ignore_list == ["*"]: + return True + + # Extract model name without provider prefix + model_name = model.split("/", 1)[1] if "/" in model else model + + for pattern in ignore_list: + # Use fnmatch for glob pattern support + if fnmatch.fnmatch(model_name, pattern): + return True + if fnmatch.fnmatch(model, pattern): + return True + + return False + + def _is_whitelisted(self, model: str, provider: str) -> bool: + """ + Check if model is whitelisted. + + Same pattern support as blacklist. + + Args: + model: Model string + provider: Provider name + + Returns: + True if model is whitelisted + """ + model_provider = model.split("/")[0] if "/" in model else provider + + if model_provider not in self._whitelist: + return False + + whitelist = self._whitelist[model_provider] + model_name = model.split("/", 1)[1] if "/" in model else model + + for pattern in whitelist: + if fnmatch.fnmatch(model_name, pattern): + return True + if fnmatch.fnmatch(model, pattern): + return True + + return False + + @staticmethod + def extract_provider(model: str) -> str: + """ + Extract provider name from model string. + + Args: + model: Model string (e.g., "openai/gpt-4") + + Returns: + Provider name (e.g., "openai") or empty string if no prefix + """ + return model.split("/")[0] if "/" in model else "" + + @staticmethod + def strip_provider(model: str) -> str: + """ + Strip provider prefix from model string. + + Args: + model: Model string (e.g., "openai/gpt-4") + + Returns: + Model name without prefix (e.g., "gpt-4") + """ + return model.split("/", 1)[1] if "/" in model else model + + @staticmethod + def ensure_provider_prefix(model: str, provider: str) -> str: + """ + Ensure model string has provider prefix. + + Args: + model: Model string + provider: Provider name to add if missing + + Returns: + Model string with provider prefix + """ + if "/" in model: + return model + return f"{provider}/{model}" diff --git a/src/rotator_library/client/rotating_client.py b/src/rotator_library/client/rotating_client.py new file mode 100644 index 00000000..2e8fb156 --- /dev/null +++ b/src/rotator_library/client/rotating_client.py @@ -0,0 +1,897 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Slim RotatingClient facade. + +This is a lightweight facade that delegates to extracted components: +- RequestExecutor: Unified retry/rotation logic +- CredentialFilter: Tier compatibility filtering +- ModelResolver: Model name resolution and filtering +- ProviderTransforms: Provider-specific request mutations +- StreamingHandler: Streaming response processing + +The original client.py was ~3000 lines. This facade is ~300 lines, +with all complexity moved to specialized modules. +""" + +import asyncio +import json +import logging +import os +import random +import time +from pathlib import Path +from typing import Any, AsyncGenerator, Dict, List, Optional, Union, TYPE_CHECKING + +import httpx +import litellm +from litellm.litellm_core_utils.token_counter import token_counter + +from ..core.types import RequestContext +from ..core.errors import NoAvailableKeysError, mask_credential +from ..core.config import ConfigLoader +from ..core.constants import ( + DEFAULT_MAX_RETRIES, + DEFAULT_GLOBAL_TIMEOUT, + DEFAULT_ROTATION_TOLERANCE, +) + +from .filters import CredentialFilter +from .models import ModelResolver +from .transforms import ProviderTransforms +from .executor import RequestExecutor +from .anthropic import AnthropicHandler + +# Import providers and other dependencies +from ..providers import PROVIDER_PLUGINS +from ..cooldown_manager import CooldownManager +from ..credential_manager import CredentialManager +from ..background_refresher import BackgroundRefresher +from ..model_definitions import ModelDefinitions +from ..transaction_logger import TransactionLogger +from ..provider_config import ProviderConfig as LiteLLMProviderConfig +from ..utils.paths import get_default_root, get_logs_dir, get_oauth_dir +from ..utils.suppress_litellm_warnings import suppress_litellm_serialization_warnings +from ..failure_logger import configure_failure_logger + +# Import new usage package +from ..usage import UsageManager as NewUsageManager +from ..usage.config import load_provider_usage_config, WindowDefinition + +if TYPE_CHECKING: + from ..anthropic_compat import AnthropicMessagesRequest, AnthropicCountTokensRequest + +lib_logger = logging.getLogger("rotator_library") + + +class RotatingClient: + """ + A client that intelligently rotates and retries API keys using LiteLLM, + with support for both streaming and non-streaming responses. + + This is a slim facade that delegates to specialized components: + - RequestExecutor: Handles retry/rotation logic + - CredentialFilter: Filters credentials by tier + - ModelResolver: Resolves model names + - ProviderTransforms: Applies provider-specific transforms + """ + + def __init__( + self, + api_keys: Optional[Dict[str, List[str]]] = None, + oauth_credentials: Optional[Dict[str, List[str]]] = None, + max_retries: int = DEFAULT_MAX_RETRIES, + usage_file_path: Optional[Union[str, Path]] = None, + configure_logging: bool = True, + global_timeout: int = DEFAULT_GLOBAL_TIMEOUT, + abort_on_callback_error: bool = True, + litellm_provider_params: Optional[Dict[str, Any]] = None, + ignore_models: Optional[Dict[str, List[str]]] = None, + whitelist_models: Optional[Dict[str, List[str]]] = None, + enable_request_logging: bool = False, + max_concurrent_requests_per_key: Optional[Dict[str, int]] = None, + rotation_tolerance: float = DEFAULT_ROTATION_TOLERANCE, + data_dir: Optional[Union[str, Path]] = None, + ): + """ + Initialize the RotatingClient. + + See original client.py for full parameter documentation. + """ + # Resolve data directory + self.data_dir = Path(data_dir).resolve() if data_dir else get_default_root() + + # Configure logging + configure_failure_logger(get_logs_dir(self.data_dir)) + os.environ["LITELLM_LOG"] = "ERROR" + litellm.set_verbose = False + litellm.drop_params = True + suppress_litellm_serialization_warnings() + + if configure_logging: + lib_logger.propagate = True + if lib_logger.hasHandlers(): + lib_logger.handlers.clear() + lib_logger.addHandler(logging.NullHandler()) + else: + lib_logger.propagate = False + + # Process credentials + api_keys = api_keys or {} + oauth_credentials = oauth_credentials or {} + api_keys = {p: k for p, k in api_keys.items() if k} + oauth_credentials = {p: c for p, c in oauth_credentials.items() if c} + + if not api_keys and not oauth_credentials: + lib_logger.warning( + "No provider credentials configured. Client will be unable to make requests." + ) + + # Discover OAuth credentials if not provided + if oauth_credentials: + self.oauth_credentials = oauth_credentials + else: + cred_manager = CredentialManager( + os.environ, oauth_dir=get_oauth_dir(self.data_dir) + ) + self.oauth_credentials = cred_manager.discover_and_prepare() + + # Build combined credentials + self.all_credentials: Dict[str, List[str]] = {} + for provider, keys in api_keys.items(): + self.all_credentials.setdefault(provider, []).extend(keys) + for provider, paths in self.oauth_credentials.items(): + self.all_credentials.setdefault(provider, []).extend(paths) + + self.api_keys = api_keys + self.oauth_providers = set(self.oauth_credentials.keys()) + + # Store configuration + self.max_retries = max_retries + self.global_timeout = global_timeout + self.abort_on_callback_error = abort_on_callback_error + self.litellm_provider_params = litellm_provider_params or {} + self._litellm_logger_fn = self._litellm_logger_callback + self.enable_request_logging = enable_request_logging + self.max_concurrent_requests_per_key = max_concurrent_requests_per_key or {} + + # Validate concurrent requests config + for provider, max_val in self.max_concurrent_requests_per_key.items(): + if max_val < 1: + lib_logger.warning( + f"Invalid max_concurrent for '{provider}': {max_val}. Setting to 1." + ) + self.max_concurrent_requests_per_key[provider] = 1 + + # Initialize configuration loader + self._config_loader = ConfigLoader(PROVIDER_PLUGINS) + + # Initialize components + self._provider_plugins = PROVIDER_PLUGINS + self._provider_instances: Dict[str, Any] = {} + + # Initialize managers + self.cooldown_manager = CooldownManager() + self.background_refresher = BackgroundRefresher(self) + self.model_definitions = ModelDefinitions() + self.provider_config = LiteLLMProviderConfig() + self.http_client = httpx.AsyncClient() + + # Initialize extracted components + self._credential_filter = CredentialFilter( + PROVIDER_PLUGINS, + provider_instances=self._provider_instances, + ) + self._model_resolver = ModelResolver( + PROVIDER_PLUGINS, + self.model_definitions, + ignore_models or {}, + whitelist_models or {}, + provider_instances=self._provider_instances, + ) + self._provider_transforms = ProviderTransforms( + PROVIDER_PLUGINS, + self.provider_config, + provider_instances=self._provider_instances, + ) + + # Initialize UsageManagers (one per provider) using new usage package + self._usage_managers: Dict[str, NewUsageManager] = {} + + # Resolve usage file path base + if usage_file_path: + base_path = Path(usage_file_path) + if base_path.suffix: + base_path = base_path.parent + self._usage_base_path = base_path / "usage" + else: + self._usage_base_path = self.data_dir / "usage" + self._usage_base_path.mkdir(parents=True, exist_ok=True) + + # Build provider configs using ConfigLoader + provider_configs = {} + for provider in self.all_credentials.keys(): + provider_configs[provider] = self._config_loader.load_provider_config( + provider + ) + + # Create UsageManager for each provider + for provider, credentials in self.all_credentials.items(): + config = load_provider_usage_config(provider, PROVIDER_PLUGINS) + # Override tolerance from constructor param + config.rotation_tolerance = rotation_tolerance + + self._apply_usage_reset_config(provider, credentials, config) + + usage_file = self._usage_base_path / f"usage_{provider}.json" + + # Get max concurrent for this provider + max_concurrent = self.max_concurrent_requests_per_key.get(provider) + + manager = NewUsageManager( + provider=provider, + file_path=usage_file, + provider_plugins=PROVIDER_PLUGINS, + config=config, + max_concurrent_per_key=max_concurrent, + ) + self._usage_managers[provider] = manager + + # Initialize executor with new usage managers + self._executor = RequestExecutor( + usage_managers=self._usage_managers, + cooldown_manager=self.cooldown_manager, + credential_filter=self._credential_filter, + provider_transforms=self._provider_transforms, + provider_plugins=PROVIDER_PLUGINS, + http_client=self.http_client, + max_retries=max_retries, + global_timeout=global_timeout, + abort_on_callback_error=abort_on_callback_error, + litellm_provider_params=self.litellm_provider_params, + litellm_logger_fn=self._litellm_logger_fn, + provider_instances=self._provider_instances, + ) + + self._model_list_cache: Dict[str, List[str]] = {} + self._usage_initialized = False + self._usage_init_lock = asyncio.Lock() + + # Initialize Anthropic compatibility handler + self._anthropic_handler = AnthropicHandler(self) + + async def __aenter__(self): + await self.initialize_usage_managers() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def initialize_usage_managers(self) -> None: + """Initialize usage managers once before background jobs run.""" + if self._usage_initialized: + return + async with self._usage_init_lock: + if self._usage_initialized: + return + for provider, manager in self._usage_managers.items(): + credentials = self.all_credentials.get(provider, []) + priorities, tiers = self._get_credential_metadata(provider, credentials) + await manager.initialize( + credentials, priorities=priorities, tiers=tiers + ) + summaries = [] + for provider, manager in self._usage_managers.items(): + credentials = self.all_credentials.get(provider, []) + status = ( + f"loaded {manager.loaded_credentials}" + if manager.loaded_from_storage + else "fresh" + ) + summaries.append(f"{provider}:{len(credentials)} ({status})") + if summaries: + lib_logger.info( + f"Usage managers initialized: {', '.join(sorted(summaries))}" + ) + self._usage_initialized = True + + async def close(self): + """Close the HTTP client and save usage data.""" + # Save and shutdown new usage managers + for manager in self._usage_managers.values(): + await manager.shutdown() + + if hasattr(self, "http_client") and self.http_client: + await self.http_client.aclose() + + async def acompletion( + self, + request: Optional[Any] = None, + pre_request_callback: Optional[callable] = None, + **kwargs, + ) -> Union[Any, AsyncGenerator[str, None]]: + """ + Dispatcher for completion requests. + + Returns: + Response object or async generator for streaming + """ + model = kwargs.get("model", "") + provider = model.split("/")[0] if "/" in model else "" + + if not provider or provider not in self.all_credentials: + raise ValueError( + f"Invalid model format or no credentials for provider: {model}" + ) + + # Extract internal logging parameters (not passed to API) + parent_log_dir = kwargs.pop("_parent_log_dir", None) + + # Resolve model ID + resolved_model = self._model_resolver.resolve_model_id(model, provider) + kwargs["model"] = resolved_model + + # Create transaction logger if enabled + transaction_logger = None + if self.enable_request_logging: + transaction_logger = TransactionLogger( + provider=provider, + model=resolved_model, + enabled=True, + parent_dir=parent_log_dir, + ) + transaction_logger.log_request(kwargs) + + # Build request context + context = RequestContext( + model=resolved_model, + provider=provider, + kwargs=kwargs, + streaming=kwargs.get("stream", False), + credentials=self.all_credentials.get(provider, []), + deadline=time.time() + self.global_timeout, + request=request, + pre_request_callback=pre_request_callback, + transaction_logger=transaction_logger, + ) + + return await self._executor.execute(context) + + def aembedding( + self, + request: Optional[Any] = None, + pre_request_callback: Optional[callable] = None, + **kwargs, + ) -> Any: + """ + Execute an embedding request with retry logic. + """ + model = kwargs.get("model", "") + provider = model.split("/")[0] if "/" in model else "" + + if not provider or provider not in self.all_credentials: + raise ValueError( + f"Invalid model format or no credentials for provider: {model}" + ) + + # Build request context (embeddings are never streaming) + context = RequestContext( + model=model, + provider=provider, + kwargs=kwargs, + streaming=False, + credentials=self.all_credentials.get(provider, []), + deadline=time.time() + self.global_timeout, + request=request, + pre_request_callback=pre_request_callback, + ) + + return self._executor.execute(context) + + def token_count(self, **kwargs) -> int: + """Calculate token count for text or messages. + + For Antigravity provider models, this also includes the preprompt tokens + that get injected during actual API calls (agent instruction + identity override). + This ensures token counts match actual usage. + """ + model = kwargs.get("model") + text = kwargs.get("text") + messages = kwargs.get("messages") + + if not model: + raise ValueError("'model' is required") + + # Calculate base token count + if messages: + base_count = token_counter(model=model, messages=messages) + elif text: + base_count = token_counter(model=model, text=text) + else: + raise ValueError("Either 'text' or 'messages' must be provided") + + # Add preprompt tokens for Antigravity provider + # The Antigravity provider injects system instructions during actual API calls, + # so we need to account for those tokens in the count + provider = model.split("/")[0] if "/" in model else "" + if provider == "antigravity": + try: + from ..providers.antigravity_provider import ( + get_antigravity_preprompt_text, + ) + + preprompt_text = get_antigravity_preprompt_text() + if preprompt_text: + preprompt_tokens = token_counter(model=model, text=preprompt_text) + base_count += preprompt_tokens + except ImportError: + # Provider not available, skip preprompt token counting + pass + + return base_count + + async def get_available_models(self, provider: str) -> List[str]: + """Get available models for a provider with caching.""" + if provider in self._model_list_cache: + return self._model_list_cache[provider] + + credentials = self.all_credentials.get(provider, []) + if not credentials: + return [] + + # Shuffle and try each credential + shuffled = list(credentials) + random.shuffle(shuffled) + + plugin = self._get_provider_instance(provider) + if not plugin: + return [] + + for cred in shuffled: + try: + models = await plugin.get_models(cred, self.http_client) + + # Apply whitelist/blacklist + final = [ + m + for m in models + if self._model_resolver.is_model_allowed(m, provider) + ] + + self._model_list_cache[provider] = final + return final + + except Exception as e: + lib_logger.debug( + f"Failed to get models for {provider} with {mask_credential(cred)}: {e}" + ) + continue + + return [] + + async def get_all_available_models( + self, + grouped: bool = True, + ) -> Union[Dict[str, List[str]], List[str]]: + """Get all available models across all providers.""" + providers = list(self.all_credentials.keys()) + tasks = [self.get_available_models(p) for p in providers] + results = await asyncio.gather(*tasks, return_exceptions=True) + + all_models: Dict[str, List[str]] = {} + for provider, result in zip(providers, results): + if isinstance(result, Exception): + lib_logger.error(f"Failed to get models for {provider}: {result}") + all_models[provider] = [] + else: + all_models[provider] = result + + if grouped: + return all_models + else: + flat = [] + for models in all_models.values(): + flat.extend(models) + return flat + + async def get_quota_stats( + self, + provider_filter: Optional[str] = None, + ) -> Dict[str, Any]: + """Get quota and usage stats for all credentials. + + Args: + provider_filter: Optional provider name to filter results + + Returns: + Dict with stats per provider + """ + providers = {} + + for provider, manager in self._usage_managers.items(): + if provider_filter and provider != provider_filter: + continue + providers[provider] = await manager.get_stats_for_endpoint() + + summary = { + "total_providers": len(providers), + "total_credentials": 0, + "active_credentials": 0, + "exhausted_credentials": 0, + "total_requests": 0, + "tokens": { + "input_cached": 0, + "input_uncached": 0, + "input_cache_pct": 0, + "output": 0, + }, + "approx_total_cost": None, + } + + for prov in providers.values(): + summary["total_credentials"] += prov.get("credential_count", 0) + summary["active_credentials"] += prov.get("active_count", 0) + summary["exhausted_credentials"] += prov.get("exhausted_count", 0) + summary["total_requests"] += prov.get("total_requests", 0) + tokens = prov.get("tokens", {}) + summary["tokens"]["input_cached"] += tokens.get("input_cached", 0) + summary["tokens"]["input_uncached"] += tokens.get("input_uncached", 0) + summary["tokens"]["output"] += tokens.get("output", 0) + + total_input = ( + summary["tokens"]["input_cached"] + summary["tokens"]["input_uncached"] + ) + summary["tokens"]["input_cache_pct"] = ( + round(summary["tokens"]["input_cached"] / total_input * 100, 1) + if total_input > 0 + else 0 + ) + + approx_total_cost = 0.0 + has_cost = False + for prov in providers.values(): + cost = prov.get("approx_cost") + if cost: + approx_total_cost += cost + has_cost = True + summary["approx_total_cost"] = approx_total_cost if has_cost else None + + return { + "providers": providers, + "summary": summary, + "data_source": "cache", + "timestamp": time.time(), + } + + def get_oauth_credentials(self) -> Dict[str, List[str]]: + """Get discovered OAuth credentials.""" + return self.oauth_credentials + + def _get_provider_instance(self, provider: str) -> Optional[Any]: + """Get or create a provider plugin instance.""" + if provider not in self.all_credentials: + return None + + if provider not in self._provider_instances: + plugin_class = self._provider_plugins.get(provider) + if plugin_class: + self._provider_instances[provider] = plugin_class() + else: + return None + + return self._provider_instances[provider] + + def _get_credential_metadata( + self, + provider: str, + credentials: List[str], + ) -> tuple[Dict[str, int], Dict[str, str]]: + """Resolve priority and tier metadata for credentials.""" + plugin = self._get_provider_instance(provider) + priorities: Dict[str, int] = {} + tiers: Dict[str, str] = {} + + if not plugin: + return priorities, tiers + + for credential in credentials: + if hasattr(plugin, "get_credential_priority"): + priority = plugin.get_credential_priority(credential) + if priority is not None: + priorities[credential] = priority + if hasattr(plugin, "get_credential_tier_name"): + tier_name = plugin.get_credential_tier_name(credential) + if tier_name: + tiers[credential] = tier_name + + return priorities, tiers + + def get_usage_manager(self, provider: str) -> Optional[NewUsageManager]: + """ + Get the new UsageManager for a specific provider. + + Args: + provider: Provider name + + Returns: + UsageManager for the provider, or None if not found + """ + return self._usage_managers.get(provider) + + @property + def usage_managers(self) -> Dict[str, NewUsageManager]: + """Get all new usage managers.""" + return self._usage_managers + + def _apply_usage_reset_config( + self, + provider: str, + credentials: List[str], + config: Any, + ) -> None: + """Apply provider-specific usage reset config to window definitions.""" + if not credentials: + return + + plugin = self._get_provider_instance(provider) + if not plugin or not hasattr(plugin, "get_usage_reset_config"): + return + + try: + reset_config = plugin.get_usage_reset_config(credentials[0]) + except Exception as exc: + lib_logger.debug(f"Failed to load usage reset config for {provider}: {exc}") + return + + if not reset_config: + return + + window_seconds = reset_config.get("window_seconds") + if not window_seconds: + return + + mode = reset_config.get("mode", "credential") + applies_to = "credential" if mode == "credential" else "model" + + if window_seconds == 86400: + window_name = "daily" + elif window_seconds % 3600 == 0: + window_name = f"{window_seconds // 3600}h" + else: + window_name = "window" + + config.windows = [ + WindowDefinition.rolling( + name=window_name, + duration_seconds=int(window_seconds), + is_primary=True, + applies_to=applies_to, + ), + ] + + def _sanitize_litellm_log(self, log_data: dict) -> dict: + """Remove large/sensitive fields from LiteLLM logs.""" + if not isinstance(log_data, dict): + return log_data + + keys_to_pop = [ + "messages", + "input", + "response", + "data", + "api_key", + "api_base", + "original_response", + "additional_args", + ] + nested_keys = ["kwargs", "litellm_params", "model_info", "proxy_server_request"] + + clean_data = json.loads(json.dumps(log_data, default=str)) + + def clean_recursively(data_dict: dict) -> None: + for key in keys_to_pop: + data_dict.pop(key, None) + for key in nested_keys: + if key in data_dict and isinstance(data_dict[key], dict): + clean_recursively(data_dict[key]) + for value in list(data_dict.values()): + if isinstance(value, dict): + clean_recursively(value) + + clean_recursively(clean_data) + return clean_data + + def _litellm_logger_callback(self, log_data: dict) -> None: + """Redirect LiteLLM logs into rotator library logger.""" + log_event_type = log_data.get("log_event_type") + if log_event_type in ["pre_api_call", "post_api_call"]: + return + + if not log_data.get("exception"): + sanitized_log = self._sanitize_litellm_log(log_data) + lib_logger.debug(f"LiteLLM Log: {sanitized_log}") + return + + model = log_data.get("model", "N/A") + error_info = log_data.get("standard_logging_object", {}).get( + "error_information", {} + ) + error_class = error_info.get("error_class", "UnknownError") + error_message = error_info.get( + "error_message", str(log_data.get("exception", "")) + ) + error_message = " ".join(error_message.split()) + + lib_logger.debug( + f"LiteLLM Callback Handled Error: Model={model} | " + f"Type={error_class} | Message='{error_message}'" + ) + + # ========================================================================= + # USAGE MANAGEMENT METHODS + # ========================================================================= + + async def reload_usage_from_disk(self) -> None: + """ + Force reload usage data from disk. + + Useful when wanting fresh stats without making external API calls. + """ + for manager in self._usage_managers.values(): + await manager.reload_from_disk() + + async def force_refresh_quota( + self, + provider: Optional[str] = None, + credential: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Force refresh quota from external API. + + For Antigravity, this fetches live quota data from the API. + For other providers, this is a no-op (just reloads from disk). + + Args: + provider: If specified, only refresh this provider + credential: If specified, only refresh this specific credential + + Returns: + Refresh result dict with success/failure info + """ + result = { + "action": "force_refresh", + "scope": "credential" + if credential + else ("provider" if provider else "all"), + "provider": provider, + "credential": credential, + "credentials_refreshed": 0, + "success_count": 0, + "failed_count": 0, + "duration_ms": 0, + "errors": [], + } + + start_time = time.time() + + # Determine which providers to refresh + if provider: + providers_to_refresh = ( + [provider] if provider in self.all_credentials else [] + ) + else: + providers_to_refresh = list(self.all_credentials.keys()) + + for prov in providers_to_refresh: + provider_class = self._provider_plugins.get(prov) + if not provider_class: + continue + + # Get or create provider instance + provider_instance = self._get_provider_instance(prov) + if not provider_instance: + continue + + # Check if provider supports quota refresh (like Antigravity) + if hasattr(provider_instance, "fetch_initial_baselines"): + # Get credentials to refresh + if credential: + # Find full path for this credential + creds_to_refresh = [] + for cred_path in self.all_credentials.get(prov, []): + if cred_path.endswith(credential) or cred_path == credential: + creds_to_refresh.append(cred_path) + break + else: + creds_to_refresh = self.all_credentials.get(prov, []) + + if not creds_to_refresh: + continue + + try: + # Fetch live quota from API for ALL specified credentials + quota_results = await provider_instance.fetch_initial_baselines( + creds_to_refresh + ) + + # Store baselines in usage manager + usage_manager = self._usage_managers.get(prov) + if usage_manager and hasattr( + provider_instance, "_store_baselines_to_usage_manager" + ): + stored = await provider_instance._store_baselines_to_usage_manager( + quota_results, + usage_manager, + force=True, + is_initial_fetch=True, # Manual refresh checks exhaustion + ) + result["success_count"] += stored + + result["credentials_refreshed"] += len(creds_to_refresh) + + # Count failures + for cred_path, data in quota_results.items(): + if data.get("status") != "success": + result["failed_count"] += 1 + result["errors"].append( + f"{Path(cred_path).name}: {data.get('error', 'Unknown error')}" + ) + + except Exception as e: + lib_logger.error(f"Failed to refresh quota for {prov}: {e}") + result["errors"].append(f"{prov}: {str(e)}") + result["failed_count"] += len(creds_to_refresh) + + result["duration_ms"] = int((time.time() - start_time) * 1000) + return result + + # ========================================================================= + # ANTHROPIC API COMPATIBILITY METHODS + # ========================================================================= + + async def anthropic_messages( + self, + request: "AnthropicMessagesRequest", + raw_request: Optional[Any] = None, + pre_request_callback: Optional[callable] = None, + ) -> Any: + """ + Handle Anthropic Messages API requests. + + This method accepts requests in Anthropic's format, translates them to + OpenAI format internally, processes them through the existing acompletion + method, and returns responses in Anthropic's format. + + Args: + request: An AnthropicMessagesRequest object + raw_request: Optional raw request object for disconnect checks + pre_request_callback: Optional async callback before each API request + + Returns: + For non-streaming: dict in Anthropic Messages format + For streaming: AsyncGenerator yielding Anthropic SSE format strings + """ + return await self._anthropic_handler.messages( + request=request, + raw_request=raw_request, + pre_request_callback=pre_request_callback, + ) + + async def anthropic_count_tokens( + self, + request: "AnthropicCountTokensRequest", + ) -> dict: + """ + Handle Anthropic count_tokens API requests. + + Counts the number of tokens that would be used by a Messages API request. + This is useful for estimating costs and managing context windows. + + Args: + request: An AnthropicCountTokensRequest object + + Returns: + Dict with input_tokens count in Anthropic format + """ + return await self._anthropic_handler.count_tokens(request=request) diff --git a/src/rotator_library/client/streaming.py b/src/rotator_library/client/streaming.py new file mode 100644 index 00000000..b79488b9 --- /dev/null +++ b/src/rotator_library/client/streaming.py @@ -0,0 +1,425 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Streaming response handler. + +Extracts streaming logic from client.py _safe_streaming_wrapper (lines 904-1117). +Handles: +- Chunk processing with finish_reason logic +- JSON reassembly for fragmented responses +- Error detection in streamed data +- Usage tracking from final chunks +- Client disconnect handling +""" + +import codecs +import json +import logging +import re +from typing import Any, AsyncGenerator, AsyncIterator, Dict, Optional, TYPE_CHECKING + +import litellm + +from ..core.errors import StreamedAPIError, CredentialNeedsReauthError +from ..core.types import ProcessedChunk + +if TYPE_CHECKING: + from ..usage.manager import CredentialContext + +lib_logger = logging.getLogger("rotator_library") + + +class StreamingHandler: + """ + Process streaming responses with error handling and usage tracking. + + This class extracts the streaming logic that was in _safe_streaming_wrapper + and provides a clean interface for processing LiteLLM streams. + + Usage recording is handled via CredentialContext passed to wrap_stream(). + """ + + async def wrap_stream( + self, + stream: AsyncIterator[Any], + credential: str, + model: str, + request: Optional[Any] = None, + cred_context: Optional["CredentialContext"] = None, + skip_cost_calculation: bool = False, + ) -> AsyncGenerator[str, None]: + """ + Wrap a LiteLLM stream with error handling and usage tracking. + + FINISH_REASON HANDLING: + - Strip finish_reason from intermediate chunks (litellm defaults to "stop") + - Track accumulated_finish_reason with priority: tool_calls > length/content_filter > stop + - Only emit finish_reason on final chunk (detected by usage.completion_tokens > 0) + + Args: + stream: The async iterator from LiteLLM + credential: Credential identifier (for logging) + model: Model name for usage recording + request: Optional FastAPI request for disconnect detection + cred_context: CredentialContext for marking success/failure + + Yields: + SSE-formatted strings: "data: {...}\\n\\n" + """ + stream_completed = False + error_buffer = StreamBuffer() # Use StreamBuffer for JSON reassembly + accumulated_finish_reason: Optional[str] = None + has_tool_calls = False + prompt_tokens = 0 + prompt_tokens_cached = 0 + prompt_tokens_cache_write = 0 + prompt_tokens_uncached = 0 + completion_tokens = 0 + thinking_tokens = 0 + + # Use manual iteration to allow continue after partial JSON errors + stream_iterator = stream.__aiter__() + + try: + while True: + try: + # Check client disconnect before waiting for next chunk + if request and await request.is_disconnected(): + lib_logger.info( + f"Client disconnected. Aborting stream for model {model}." + ) + break + + chunk = await stream_iterator.__anext__() + + # Clear error buffer on successful chunk receipt + error_buffer.reset() + + # Process chunk + processed = self._process_chunk( + chunk, + accumulated_finish_reason, + has_tool_calls, + ) + + # Update tracking state + if processed.has_tool_calls: + has_tool_calls = True + accumulated_finish_reason = "tool_calls" + if processed.finish_reason and not has_tool_calls: + # Only update if not already tool_calls (highest priority) + accumulated_finish_reason = processed.finish_reason + if processed.usage and isinstance(processed.usage, dict): + # Extract token counts from final chunk + prompt_tokens = processed.usage.get("prompt_tokens", 0) + completion_tokens = processed.usage.get("completion_tokens", 0) + prompt_details = processed.usage.get("prompt_tokens_details") + if prompt_details: + if isinstance(prompt_details, dict): + prompt_tokens_cached = ( + prompt_details.get("cached_tokens", 0) or 0 + ) + prompt_tokens_cache_write = ( + prompt_details.get("cache_creation_tokens", 0) or 0 + ) + else: + prompt_tokens_cached = ( + getattr(prompt_details, "cached_tokens", 0) or 0 + ) + prompt_tokens_cache_write = ( + getattr(prompt_details, "cache_creation_tokens", 0) + or 0 + ) + completion_details = processed.usage.get( + "completion_tokens_details" + ) + if completion_details: + if isinstance(completion_details, dict): + thinking_tokens = ( + completion_details.get("reasoning_tokens", 0) or 0 + ) + else: + thinking_tokens = ( + getattr(completion_details, "reasoning_tokens", 0) + or 0 + ) + if processed.usage.get("cache_read_tokens") is not None: + prompt_tokens_cached = ( + processed.usage.get("cache_read_tokens") or 0 + ) + if processed.usage.get("cache_creation_tokens") is not None: + prompt_tokens_cache_write = ( + processed.usage.get("cache_creation_tokens") or 0 + ) + if thinking_tokens and completion_tokens >= thinking_tokens: + completion_tokens = completion_tokens - thinking_tokens + prompt_tokens_uncached = max( + 0, prompt_tokens - prompt_tokens_cached + ) + + yield processed.sse_string + + except StopAsyncIteration: + # Stream ended normally + stream_completed = True + break + + except CredentialNeedsReauthError as e: + # Credential needs re-auth - wrap for outer retry loop + if cred_context: + from ..error_handler import classify_error + + cred_context.mark_failure(classify_error(e)) + raise StreamedAPIError("Credential needs re-authentication", data=e) + + except json.JSONDecodeError as e: + # Partial JSON - accumulate and continue + error_buffer.append(str(e)) + if error_buffer.is_complete: + # We have complete JSON now + raise StreamedAPIError( + "Provider error", data=error_buffer.content + ) + # Continue waiting for more chunks + continue + + except Exception as e: + # Try to extract JSON from fragmented response + error_str = str(e) + error_buffer.append(error_str) + + # Check if buffer now has complete JSON + if error_buffer.is_complete: + if cred_context: + from ..error_handler import classify_error + + cred_context.mark_failure(classify_error(e)) + raise StreamedAPIError( + "Provider error in stream", data=error_buffer.content + ) + + # Try pattern matching for error extraction + extracted = self._try_extract_error(e, error_buffer.content) + if extracted: + if cred_context: + from ..error_handler import classify_error + + cred_context.mark_failure(classify_error(e)) + raise StreamedAPIError( + "Provider error in stream", data=extracted + ) + + # Not a JSON-related error, re-raise + raise + + except StreamedAPIError: + # Re-raise for retry loop + raise + + finally: + # Record usage if stream completed + if stream_completed: + if cred_context: + approx_cost = 0.0 + if not skip_cost_calculation: + approx_cost = self._calculate_stream_cost( + model, + prompt_tokens_uncached + prompt_tokens_cached, + completion_tokens + thinking_tokens, + ) + cred_context.mark_success( + prompt_tokens=prompt_tokens_uncached, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cached, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + # Yield [DONE] for completed streams + yield "data: [DONE]\n\n" + + def _process_chunk( + self, + chunk: Any, + accumulated_finish_reason: Optional[str], + has_tool_calls: bool, + ) -> ProcessedChunk: + """ + Process a single streaming chunk. + + Handles finish_reason logic: + - Strip from intermediate chunks + - Apply correct finish_reason on final chunk + + Args: + chunk: Raw chunk from LiteLLM + accumulated_finish_reason: Current accumulated finish reason + has_tool_calls: Whether any chunk has had tool_calls + + Returns: + ProcessedChunk with SSE string and metadata + """ + # Convert chunk to dict + if hasattr(chunk, "model_dump"): + chunk_dict = chunk.model_dump() + elif hasattr(chunk, "dict"): + chunk_dict = chunk.dict() + else: + chunk_dict = chunk + + # Extract metadata before modifying + usage = chunk_dict.get("usage") + finish_reason = None + chunk_has_tool_calls = False + + if "choices" in chunk_dict and chunk_dict["choices"]: + choice = chunk_dict["choices"][0] + delta = choice.get("delta", {}) + + # Check for tool_calls + if delta.get("tool_calls"): + chunk_has_tool_calls = True + + # Detect final chunk: has usage with completion_tokens > 0 + has_completion_tokens = ( + usage + and isinstance(usage, dict) + and usage.get("completion_tokens", 0) > 0 + ) + + if has_completion_tokens: + # FINAL CHUNK: Determine correct finish_reason + if has_tool_calls or chunk_has_tool_calls: + choice["finish_reason"] = "tool_calls" + elif accumulated_finish_reason: + choice["finish_reason"] = accumulated_finish_reason + else: + choice["finish_reason"] = "stop" + finish_reason = choice["finish_reason"] + else: + # INTERMEDIATE CHUNK: Never emit finish_reason + choice["finish_reason"] = None + + return ProcessedChunk( + sse_string=f"data: {json.dumps(chunk_dict)}\n\n", + usage=usage, + finish_reason=finish_reason, + has_tool_calls=chunk_has_tool_calls, + ) + + def _try_extract_error( + self, + exception: Exception, + buffer: str, + ) -> Optional[Dict]: + """ + Try to extract error JSON from exception or buffer. + + Handles multiple error formats: + - Google-style bytes representation: b'{...}' + - "Received chunk:" prefix + - JSON in buffer accumulation + + Args: + exception: The caught exception + buffer: Current JSON buffer content + + Returns: + Parsed error dict or None + """ + error_str = str(exception) + + # Pattern 1: Google-style bytes representation + match = re.search(r"b'(\{.*\})'", error_str, re.DOTALL) + if match: + try: + decoded = codecs.decode(match.group(1), "unicode_escape") + return json.loads(decoded) + except (json.JSONDecodeError, ValueError): + pass + + # Pattern 2: "Received chunk:" prefix + if "Received chunk:" in error_str: + chunk = error_str.split("Received chunk:")[-1].strip() + try: + return json.loads(chunk) + except json.JSONDecodeError: + pass + + # Pattern 3: Buffer accumulation + if buffer: + try: + return json.loads(buffer) + except json.JSONDecodeError: + pass + + return None + + def _calculate_stream_cost( + self, + model: str, + prompt_tokens: int, + completion_tokens: int, + ) -> float: + try: + model_info = litellm.get_model_info(model) + input_cost = model_info.get("input_cost_per_token") + output_cost = model_info.get("output_cost_per_token") + total_cost = 0.0 + if input_cost: + total_cost += prompt_tokens * input_cost + if output_cost: + total_cost += completion_tokens * output_cost + return total_cost + except Exception as exc: + lib_logger.debug(f"Stream cost calculation failed for {model}: {exc}") + return 0.0 + + +class StreamBuffer: + """ + Buffer for reassembling fragmented JSON in streams. + + Some providers send JSON split across multiple chunks, especially + for error responses. This class handles accumulation and parsing. + """ + + def __init__(self): + self._buffer = "" + self._complete = False + + def append(self, chunk: str) -> Optional[Dict]: + """ + Append a chunk and try to parse. + + Args: + chunk: Raw chunk string + + Returns: + Parsed dict if complete, None if still accumulating + """ + self._buffer += chunk + + try: + result = json.loads(self._buffer) + self._complete = True + return result + except json.JSONDecodeError: + return None + + def reset(self) -> None: + """Reset the buffer.""" + self._buffer = "" + self._complete = False + + @property + def content(self) -> str: + """Get current buffer content.""" + return self._buffer + + @property + def is_complete(self) -> bool: + """Check if buffer contains complete JSON.""" + return self._complete diff --git a/src/rotator_library/client/transforms.py b/src/rotator_library/client/transforms.py new file mode 100644 index 00000000..ec9c3bfc --- /dev/null +++ b/src/rotator_library/client/transforms.py @@ -0,0 +1,375 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Provider-specific request transformations. + +This module isolates all provider-specific request mutations that were +scattered throughout client.py, including: +- gemma-3 system message conversion +- qwen_code provider remapping +- Gemini safety settings and thinking parameter +- NVIDIA thinking parameter +- iflow stream_options removal + +Transforms are applied in a defined order with logging of modifications. +""" + +import logging +from typing import Any, Callable, Dict, List, Optional + +lib_logger = logging.getLogger("rotator_library") + + +class ProviderTransforms: + """ + Centralized provider-specific request transformations. + + Transforms are applied in order: + 1. Built-in transforms (gemma-3, qwen_code, etc.) + 2. Provider hook transforms (from provider plugins) + 3. Safety settings conversions + """ + + def __init__( + self, + provider_plugins: Dict[str, Any], + provider_config: Optional[Any] = None, + provider_instances: Optional[Dict[str, Any]] = None, + ): + """ + Initialize ProviderTransforms. + + Args: + provider_plugins: Dict mapping provider names to plugin classes + provider_config: ProviderConfig instance for LiteLLM conversions + provider_instances: Shared dict for caching provider instances. + If None, creates a new dict (not recommended - leads to duplicate instances). + """ + self._plugins = provider_plugins + self._plugin_instances: Dict[str, Any] = ( + provider_instances if provider_instances is not None else {} + ) + self._config = provider_config + + # Registry of built-in transforms + # Each provider can have multiple transform functions + self._transforms: Dict[str, List[Callable]] = { + "gemma": [self._transform_gemma_system_messages], + "qwen_code": [self._transform_qwen_code_provider], + "gemini": [self._transform_gemini_safety, self._transform_gemini_thinking], + "nvidia_nim": [self._transform_nvidia_thinking], + "iflow": [self._transform_iflow_stream_options], + } + + def _get_plugin_instance(self, provider: str) -> Optional[Any]: + """Get or create a plugin instance for a provider.""" + if provider not in self._plugin_instances: + plugin_class = self._plugins.get(provider) + if plugin_class: + if isinstance(plugin_class, type): + self._plugin_instances[provider] = plugin_class() + else: + self._plugin_instances[provider] = plugin_class + else: + return None + return self._plugin_instances[provider] + + async def apply( + self, + provider: str, + model: str, + credential: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Apply all applicable transforms to request kwargs. + + Args: + provider: Provider name + model: Model being requested + credential: Selected credential + kwargs: Request kwargs (will be mutated) + + Returns: + Modified kwargs + """ + modifications: List[str] = [] + + # 1. Apply built-in transforms + for transform_provider, transforms in self._transforms.items(): + # Check if transform applies (provider match or model contains pattern) + if transform_provider == provider or transform_provider in model.lower(): + for transform in transforms: + result = transform(kwargs, model, provider) + if result: + modifications.append(result) + + # 2. Apply provider hook transforms (async) + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "transform_request"): + try: + hook_result = await plugin.transform_request(kwargs, model, credential) + if hook_result: + modifications.extend(hook_result) + except Exception as e: + lib_logger.debug(f"Provider transform_request hook failed: {e}") + + # 3. Apply model-specific options from provider + if plugin and hasattr(plugin, "get_model_options"): + model_options = plugin.get_model_options(model) + if model_options: + for key, value in model_options.items(): + if key == "reasoning_effort": + kwargs["reasoning_effort"] = value + elif key not in kwargs: + kwargs[key] = value + modifications.append(f"applied model options for {model}") + + # 4. Apply LiteLLM conversion if config available + if self._config and hasattr(self._config, "convert_for_litellm"): + kwargs = self._config.convert_for_litellm(**kwargs) + + if modifications: + lib_logger.debug( + f"Applied transforms for {provider}/{model}: {modifications}" + ) + + return kwargs + + def apply_sync( + self, + provider: str, + model: str, + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Apply built-in transforms synchronously (no provider hooks). + + Useful when async is not available. + + Args: + provider: Provider name + model: Model being requested + kwargs: Request kwargs + + Returns: + Modified kwargs + """ + modifications: List[str] = [] + + for transform_provider, transforms in self._transforms.items(): + if transform_provider == provider or transform_provider in model.lower(): + for transform in transforms: + result = transform(kwargs, model, provider) + if result: + modifications.append(result) + + if modifications: + lib_logger.debug( + f"Applied sync transforms for {provider}/{model}: {modifications}" + ) + + return kwargs + + # ========================================================================= + # BUILT-IN TRANSFORMS + # ========================================================================= + + def _transform_gemma_system_messages( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Convert system messages to user messages for Gemma-3. + + Gemma-3 models don't support system messages, so we convert them + to user messages to maintain functionality. + """ + if "gemma-3" not in model.lower(): + return None + + messages = kwargs.get("messages", []) + if not messages: + return None + + converted = False + new_messages = [] + for m in messages: + if m.get("role") == "system": + new_messages.append({"role": "user", "content": m["content"]}) + converted = True + else: + new_messages.append(m) + + if converted: + kwargs["messages"] = new_messages + return "gemma-3: converted system->user messages" + return None + + def _transform_qwen_code_provider( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Remap qwen_code to qwen provider for LiteLLM. + + The qwen_code provider is a custom wrapper that needs to be + translated to the qwen provider for LiteLLM compatibility. + """ + if provider != "qwen_code": + return None + + kwargs["custom_llm_provider"] = "qwen" + if "/" in model: + kwargs["model"] = model.split("/", 1)[1] + return "qwen_code: remapped to qwen provider" + + def _transform_gemini_safety( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Apply default Gemini safety settings. + + Ensures safety settings are present without overriding explicit settings. + """ + if provider != "gemini": + return None + + # Default safety settings (generic form) + default_generic = { + "harassment": "OFF", + "hate_speech": "OFF", + "sexually_explicit": "OFF", + "dangerous_content": "OFF", + "civic_integrity": "BLOCK_NONE", + } + + # Default Gemini-native settings + default_gemini = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, + ] + + # If generic form present, fill in missing keys + if "safety_settings" in kwargs and isinstance(kwargs["safety_settings"], dict): + for k, v in default_generic.items(): + if k not in kwargs["safety_settings"]: + kwargs["safety_settings"][k] = v + return "gemini: filled missing safety settings" + + # If Gemini form present, fill in missing categories + if "safetySettings" in kwargs and isinstance(kwargs["safetySettings"], list): + present = { + item.get("category") + for item in kwargs["safetySettings"] + if isinstance(item, dict) + } + added = 0 + for d in default_gemini: + if d["category"] not in present: + kwargs["safetySettings"].append(d) + added += 1 + if added > 0: + return f"gemini: added {added} missing safety categories" + return None + + # Neither present: set generic defaults + if "safety_settings" not in kwargs and "safetySettings" not in kwargs: + kwargs["safety_settings"] = default_generic.copy() + return "gemini: applied default safety settings" + + return None + + def _transform_gemini_thinking( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Handle thinking parameter for Gemini. + + Delegates to provider plugin's handle_thinking_parameter method. + """ + if provider != "gemini": + return None + + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "handle_thinking_parameter"): + plugin.handle_thinking_parameter(kwargs, model) + return "gemini: handled thinking parameter" + return None + + def _transform_nvidia_thinking( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Handle thinking parameter for NVIDIA NIM. + + Delegates to provider plugin's handle_thinking_parameter method. + """ + if provider != "nvidia_nim": + return None + + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "handle_thinking_parameter"): + plugin.handle_thinking_parameter(kwargs, model) + return "nvidia_nim: handled thinking parameter" + return None + + def _transform_iflow_stream_options( + self, + kwargs: Dict[str, Any], + model: str, + provider: str, + ) -> Optional[str]: + """ + Remove stream_options for iflow provider. + + The iflow provider returns HTTP 406 if stream_options is present. + """ + if provider != "iflow": + return None + + if "stream_options" in kwargs: + del kwargs["stream_options"] + return "iflow: removed stream_options" + return None + + # ========================================================================= + # SAFETY SETTINGS CONVERSION + # ========================================================================= + + def convert_safety_settings( + self, + provider: str, + settings: Dict[str, str], + ) -> Optional[Any]: + """ + Convert generic safety settings to provider-specific format. + + Args: + provider: Provider name + settings: Generic safety settings dict + + Returns: + Provider-specific settings or None + """ + plugin = self._get_plugin_instance(provider) + if plugin and hasattr(plugin, "convert_safety_settings"): + return plugin.convert_safety_settings(settings) + return None diff --git a/src/rotator_library/client/types.py b/src/rotator_library/client/types.py new file mode 100644 index 00000000..b54bba0e --- /dev/null +++ b/src/rotator_library/client/types.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Client-specific type definitions. + +Types that are only used within the client package. +Shared types are in core/types.py. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set + + +@dataclass +class AvailabilityStats: + """ + Statistics about credential availability for a model. + + Used for logging and monitoring credential pool status. + """ + + available: int # Credentials not on cooldown and not exhausted + on_cooldown: int # Credentials on cooldown + fair_cycle_excluded: int # Credentials excluded by fair cycle + total: int # Total credentials for provider + + @property + def usable(self) -> int: + """Return count of usable credentials.""" + return self.available + + def __str__(self) -> str: + parts = [f"{self.available}/{self.total}"] + if self.on_cooldown > 0: + parts.append(f"cd:{self.on_cooldown}") + if self.fair_cycle_excluded > 0: + parts.append(f"fc:{self.fair_cycle_excluded}") + return ",".join(parts) + + +@dataclass +class RetryState: + """ + State tracking for a retry loop. + + Used by RequestExecutor to track retry attempts and errors. + """ + + tried_credentials: Set[str] = field(default_factory=set) + last_exception: Optional[Exception] = None + consecutive_quota_failures: int = 0 + + def record_attempt(self, credential: str) -> None: + """Record that a credential was tried.""" + self.tried_credentials.add(credential) + + def reset_quota_failures(self) -> None: + """Reset quota failure counter (called after non-quota error).""" + self.consecutive_quota_failures = 0 + + def increment_quota_failures(self) -> None: + """Increment quota failure counter.""" + self.consecutive_quota_failures += 1 + + +@dataclass +class ExecutionResult: + """ + Result of executing a request. + + Returned by RequestExecutor to indicate outcome. + """ + + success: bool + response: Optional[Any] = None + error: Optional[Exception] = None + should_rotate: bool = False + should_fail: bool = False diff --git a/src/rotator_library/config/__init__.py b/src/rotator_library/config/__init__.py index ea49533e..58f88168 100644 --- a/src/rotator_library/config/__init__.py +++ b/src/rotator_library/config/__init__.py @@ -21,6 +21,7 @@ DEFAULT_FAIR_CYCLE_TRACKING_MODE, DEFAULT_FAIR_CYCLE_CROSS_TIER, DEFAULT_FAIR_CYCLE_DURATION, + DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD, DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, # Custom Caps DEFAULT_CUSTOM_CAP_COOLDOWN_MODE, @@ -47,6 +48,7 @@ "DEFAULT_FAIR_CYCLE_TRACKING_MODE", "DEFAULT_FAIR_CYCLE_CROSS_TIER", "DEFAULT_FAIR_CYCLE_DURATION", + "DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD", "DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD", # Custom Caps "DEFAULT_CUSTOM_CAP_COOLDOWN_MODE", diff --git a/src/rotator_library/config/defaults.py b/src/rotator_library/config/defaults.py index 59282e1e..6160a36b 100644 --- a/src/rotator_library/config/defaults.py +++ b/src/rotator_library/config/defaults.py @@ -86,6 +86,13 @@ # Global fallback: EXHAUSTION_COOLDOWN_THRESHOLD= DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD: int = 300 # 5 minutes +# Fair cycle quota threshold - multiplier of window limit +# 1.0 = credential exhausts after using 1 full window's worth of quota +# 0.5 = credential exhausts after using 50% of window quota +# 2.0 = credential exhausts after using 2x window quota +# Override: FAIR_CYCLE_QUOTA_THRESHOLD_{PROVIDER}= +DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD: float = 1.0 + # ============================================================================= # CUSTOM CAPS DEFAULTS # ============================================================================= diff --git a/src/rotator_library/cooldown_manager.py b/src/rotator_library/cooldown_manager.py index 8e045e48..0d1bb63e 100644 --- a/src/rotator_library/cooldown_manager.py +++ b/src/rotator_library/cooldown_manager.py @@ -5,12 +5,14 @@ import time from typing import Dict + class CooldownManager: """ Manages global cooldown periods for API providers to handle IP-based rate limiting. This ensures that once a 429 error is received for a provider, all subsequent requests to that provider are paused for a specified duration. """ + def __init__(self): self._cooldowns: Dict[str, float] = {} self._lock = asyncio.Lock() @@ -18,7 +20,9 @@ def __init__(self): async def is_cooling_down(self, provider: str) -> bool: """Checks if a provider is currently in a cooldown period.""" async with self._lock: - return provider in self._cooldowns and time.time() < self._cooldowns[provider] + return ( + provider in self._cooldowns and time.time() < self._cooldowns[provider] + ) async def start_cooldown(self, provider: str, duration: int): """ @@ -37,4 +41,8 @@ async def get_cooldown_remaining(self, provider: str) -> float: if provider in self._cooldowns: remaining = self._cooldowns[provider] - time.time() return max(0, remaining) - return 0 \ No newline at end of file + return 0 + + async def get_remaining_cooldown(self, provider: str) -> float: + """Backward-compatible alias for get_cooldown_remaining.""" + return await self.get_cooldown_remaining(provider) diff --git a/src/rotator_library/core/__init__.py b/src/rotator_library/core/__init__.py new file mode 100644 index 00000000..c88a75a5 --- /dev/null +++ b/src/rotator_library/core/__init__.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Core package for the rotator library. + +Provides shared infrastructure used by both client and usage manager: +- types: Shared dataclasses and type definitions +- errors: All custom exceptions +- config: ConfigLoader for centralized configuration +- constants: Default values and magic numbers +""" + +from .types import ( + CredentialInfo, + RequestContext, + ProcessedChunk, + FilterResult, + FairCycleConfig, + CustomCapConfig, + ProviderConfig, + WindowConfig, + RequestCompleteResult, +) + +from .errors import ( + # Base exceptions + NoAvailableKeysError, + PreRequestCallbackError, + CredentialNeedsReauthError, + EmptyResponseError, + TransientQuotaError, + StreamedAPIError, + # Error classification + ClassifiedError, + RequestErrorAccumulator, + classify_error, + should_rotate_on_error, + should_retry_same_key, + mask_credential, + is_abnormal_error, + get_retry_after, +) + +from .config import ConfigLoader + +__all__ = [ + # Types + "CredentialInfo", + "RequestContext", + "ProcessedChunk", + "FilterResult", + "FairCycleConfig", + "CustomCapConfig", + "ProviderConfig", + "WindowConfig", + "RequestCompleteResult", + # Errors + "NoAvailableKeysError", + "PreRequestCallbackError", + "CredentialNeedsReauthError", + "EmptyResponseError", + "TransientQuotaError", + "StreamedAPIError", + "ClassifiedError", + "RequestErrorAccumulator", + "classify_error", + "should_rotate_on_error", + "should_retry_same_key", + "mask_credential", + "is_abnormal_error", + "get_retry_after", + # Config + "ConfigLoader", +] diff --git a/src/rotator_library/core/config.py b/src/rotator_library/core/config.py new file mode 100644 index 00000000..34114011 --- /dev/null +++ b/src/rotator_library/core/config.py @@ -0,0 +1,550 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Centralized configuration loader for the rotator library. + +This module provides a ConfigLoader class that handles all configuration +parsing from: +1. System defaults (from config/defaults.py) +2. Provider class attributes +3. Environment variables (ALWAYS override provider defaults) + +The ConfigLoader ensures consistent configuration handling across +both the client and usage manager. +""" + +import os +import logging +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +from .types import ( + ProviderConfig, + FairCycleConfig, + CustomCapConfig, + WindowConfig, +) +from .constants import ( + # Defaults + DEFAULT_ROTATION_MODE, + DEFAULT_ROTATION_TOLERANCE, + DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, + DEFAULT_FAIR_CYCLE_ENABLED, + DEFAULT_FAIR_CYCLE_TRACKING_MODE, + DEFAULT_FAIR_CYCLE_CROSS_TIER, + DEFAULT_FAIR_CYCLE_DURATION, + DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, + # Prefixes + ENV_PREFIX_ROTATION_MODE, + ENV_PREFIX_FAIR_CYCLE, + ENV_PREFIX_FAIR_CYCLE_TRACKING, + ENV_PREFIX_FAIR_CYCLE_CROSS_TIER, + ENV_PREFIX_FAIR_CYCLE_DURATION, + ENV_PREFIX_EXHAUSTION_THRESHOLD, + ENV_PREFIX_CONCURRENCY_MULTIPLIER, + ENV_PREFIX_CUSTOM_CAP, + ENV_PREFIX_CUSTOM_CAP_COOLDOWN, +) + +lib_logger = logging.getLogger("rotator_library") + + +class ConfigLoader: + """ + Centralized configuration loader. + + Parses all configuration from: + 1. System defaults + 2. Provider class attributes + 3. Environment variables (ALWAYS override provider defaults) + + Usage: + loader = ConfigLoader(provider_plugins) + config = loader.load_provider_config("antigravity") + """ + + def __init__(self, provider_plugins: Optional[Dict[str, type]] = None): + """ + Initialize the ConfigLoader. + + Args: + provider_plugins: Dict mapping provider names to plugin classes. + If None, no provider-specific defaults are used. + """ + self._plugins = provider_plugins or {} + self._cache: Dict[str, ProviderConfig] = {} + + def load_provider_config( + self, + provider: str, + force_reload: bool = False, + ) -> ProviderConfig: + """ + Load complete configuration for a provider. + + Configuration is loaded in this order (later overrides earlier): + 1. System defaults + 2. Provider class attributes + 3. Environment variables (ALWAYS win) + + Args: + provider: Provider name (e.g., "antigravity", "gemini_cli") + force_reload: If True, bypass cache and reload + + Returns: + Complete ProviderConfig for the provider + """ + if not force_reload and provider in self._cache: + return self._cache[provider] + + # Start with system defaults + config = self._get_system_defaults() + + # Apply provider class defaults + plugin_class = self._plugins.get(provider) + if plugin_class: + config = self._apply_provider_defaults(config, plugin_class, provider) + + # Apply environment variable overrides (ALWAYS win) + config = self._apply_env_overrides(config, provider) + + # Cache and return + self._cache[provider] = config + return config + + def load_all_provider_configs( + self, + providers: List[str], + ) -> Dict[str, ProviderConfig]: + """ + Load configurations for multiple providers. + + Args: + providers: List of provider names + + Returns: + Dict mapping provider names to their configs + """ + return {p: self.load_provider_config(p) for p in providers} + + def clear_cache(self, provider: Optional[str] = None) -> None: + """ + Clear cached configurations. + + Args: + provider: If provided, only clear that provider's cache. + If None, clear all cached configs. + """ + if provider: + self._cache.pop(provider, None) + else: + self._cache.clear() + + # ========================================================================= + # INTERNAL METHODS + # ========================================================================= + + def _get_system_defaults(self) -> ProviderConfig: + """Get a ProviderConfig with all system defaults.""" + return ProviderConfig( + rotation_mode=DEFAULT_ROTATION_MODE, + rotation_tolerance=DEFAULT_ROTATION_TOLERANCE, + priority_multipliers={}, + priority_multipliers_by_mode={}, + sequential_fallback_multiplier=DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, + fair_cycle=FairCycleConfig( + enabled=DEFAULT_FAIR_CYCLE_ENABLED, + tracking_mode=DEFAULT_FAIR_CYCLE_TRACKING_MODE, + cross_tier=DEFAULT_FAIR_CYCLE_CROSS_TIER, + duration=DEFAULT_FAIR_CYCLE_DURATION, + ), + custom_caps=[], + exhaustion_cooldown_threshold=DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, + windows=[], + ) + + def _apply_provider_defaults( + self, + config: ProviderConfig, + plugin_class: type, + provider: str, + ) -> ProviderConfig: + """ + Apply provider class default attributes to config. + + Args: + config: Current configuration + plugin_class: Provider plugin class + provider: Provider name for logging + + Returns: + Updated configuration + """ + # Rotation mode + if hasattr(plugin_class, "default_rotation_mode"): + config.rotation_mode = plugin_class.default_rotation_mode + + # Priority multipliers + if hasattr(plugin_class, "default_priority_multipliers"): + multipliers = plugin_class.default_priority_multipliers + if multipliers: + config.priority_multipliers = dict(multipliers) + + # Sequential fallback multiplier + if hasattr(plugin_class, "default_sequential_fallback_multiplier"): + fallback = plugin_class.default_sequential_fallback_multiplier + if fallback != DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER: + config.sequential_fallback_multiplier = fallback + + # Fair cycle settings + if hasattr(plugin_class, "default_fair_cycle_enabled"): + val = plugin_class.default_fair_cycle_enabled + if val is not None: + config.fair_cycle.enabled = val + + if hasattr(plugin_class, "default_fair_cycle_tracking_mode"): + config.fair_cycle.tracking_mode = ( + plugin_class.default_fair_cycle_tracking_mode + ) + + if hasattr(plugin_class, "default_fair_cycle_cross_tier"): + config.fair_cycle.cross_tier = plugin_class.default_fair_cycle_cross_tier + + if hasattr(plugin_class, "default_fair_cycle_duration"): + duration = plugin_class.default_fair_cycle_duration + if duration != DEFAULT_FAIR_CYCLE_DURATION: + config.fair_cycle.duration = duration + + # Exhaustion cooldown threshold + if hasattr(plugin_class, "default_exhaustion_cooldown_threshold"): + threshold = plugin_class.default_exhaustion_cooldown_threshold + if threshold != DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD: + config.exhaustion_cooldown_threshold = threshold + + # Custom caps + if hasattr(plugin_class, "default_custom_caps"): + caps = plugin_class.default_custom_caps + if caps: + config.custom_caps = self._parse_custom_caps_from_provider(caps) + + return config + + def _apply_env_overrides( + self, + config: ProviderConfig, + provider: str, + ) -> ProviderConfig: + """ + Apply environment variable overrides to config. + + Environment variables ALWAYS override provider class defaults. + + Args: + config: Current configuration + provider: Provider name + + Returns: + Updated configuration with env overrides applied + """ + provider_upper = provider.upper() + + # Rotation mode: ROTATION_MODE_{PROVIDER} + env_key = f"{ENV_PREFIX_ROTATION_MODE}{provider_upper}" + env_val = os.getenv(env_key) + if env_val: + config.rotation_mode = env_val.lower() + if config.rotation_mode not in ("balanced", "sequential"): + lib_logger.warning(f"Invalid {env_key}='{env_val}'. Using 'balanced'.") + config.rotation_mode = "balanced" + + # Fair cycle enabled: FAIR_CYCLE_{PROVIDER} + env_key = f"{ENV_PREFIX_FAIR_CYCLE}{provider_upper}" + env_val = os.getenv(env_key) + if env_val is not None: + config.fair_cycle.enabled = env_val.lower() in ("true", "1", "yes") + + # Fair cycle tracking mode: FAIR_CYCLE_TRACKING_MODE_{PROVIDER} + env_key = f"{ENV_PREFIX_FAIR_CYCLE_TRACKING}{provider_upper}" + env_val = os.getenv(env_key) + if env_val and env_val.lower() in ("model_group", "credential"): + config.fair_cycle.tracking_mode = env_val.lower() + + # Fair cycle cross-tier: FAIR_CYCLE_CROSS_TIER_{PROVIDER} + env_key = f"{ENV_PREFIX_FAIR_CYCLE_CROSS_TIER}{provider_upper}" + env_val = os.getenv(env_key) + if env_val is not None: + config.fair_cycle.cross_tier = env_val.lower() in ("true", "1", "yes") + + # Fair cycle duration: FAIR_CYCLE_DURATION_{PROVIDER} + env_key = f"{ENV_PREFIX_FAIR_CYCLE_DURATION}{provider_upper}" + env_val = os.getenv(env_key) + if env_val: + try: + config.fair_cycle.duration = int(env_val) + except ValueError: + lib_logger.warning(f"Invalid {env_key}='{env_val}'. Must be integer.") + + # Exhaustion cooldown threshold: EXHAUSTION_COOLDOWN_THRESHOLD_{PROVIDER} + # Also check global: EXHAUSTION_COOLDOWN_THRESHOLD + env_key = f"{ENV_PREFIX_EXHAUSTION_THRESHOLD}{provider_upper}" + env_val = os.getenv(env_key) or os.getenv("EXHAUSTION_COOLDOWN_THRESHOLD") + if env_val: + try: + config.exhaustion_cooldown_threshold = int(env_val) + except ValueError: + lib_logger.warning(f"Invalid exhaustion threshold='{env_val}'.") + + # Priority multipliers: CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N} + # Also supports mode-specific: CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}_{MODE} + self._parse_priority_multiplier_env_vars(config, provider_upper) + + # Custom caps: CUSTOM_CAP_{PROVIDER}_T{TIER}_{MODEL} + # Also: CUSTOM_CAP_COOLDOWN_{PROVIDER}_T{TIER}_{MODEL} + self._parse_custom_cap_env_vars(config, provider_upper) + + return config + + def _parse_priority_multiplier_env_vars( + self, + config: ProviderConfig, + provider_upper: str, + ) -> None: + """ + Parse CONCURRENCY_MULTIPLIER_* environment variables. + + Formats: + - CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}=value + - CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}_{MODE}=value + """ + prefix = f"{ENV_PREFIX_CONCURRENCY_MULTIPLIER}{provider_upper}_PRIORITY_" + + for env_key, env_val in os.environ.items(): + if not env_key.startswith(prefix): + continue + + remainder = env_key[len(prefix) :] + try: + multiplier = int(env_val) + if multiplier < 1: + lib_logger.warning(f"Invalid {env_key}='{env_val}'. Must be >= 1.") + continue + + # Check for mode-specific suffix + if "_" in remainder: + parts = remainder.rsplit("_", 1) + priority = int(parts[0]) + mode = parts[1].lower() + + if mode in ("sequential", "balanced"): + if mode not in config.priority_multipliers_by_mode: + config.priority_multipliers_by_mode[mode] = {} + config.priority_multipliers_by_mode[mode][priority] = multiplier + else: + lib_logger.warning(f"Unknown mode in {env_key}: {mode}") + else: + # Universal priority multiplier + priority = int(remainder) + config.priority_multipliers[priority] = multiplier + + except ValueError: + lib_logger.warning(f"Invalid {env_key}='{env_val}'. Could not parse.") + + def _parse_custom_cap_env_vars( + self, + config: ProviderConfig, + provider_upper: str, + ) -> None: + """ + Parse CUSTOM_CAP_* environment variables. + + Formats: + - CUSTOM_CAP_{PROVIDER}_T{TIER}_{MODEL}=value + - CUSTOM_CAP_{PROVIDER}_TDEFAULT_{MODEL}=value + - CUSTOM_CAP_COOLDOWN_{PROVIDER}_T{TIER}_{MODEL}=mode:value + """ + cap_prefix = f"{ENV_PREFIX_CUSTOM_CAP}{provider_upper}_T" + cooldown_prefix = f"{ENV_PREFIX_CUSTOM_CAP_COOLDOWN}{provider_upper}_T" + + # Collect caps by (tier_key, model_key) to merge cap and cooldown + caps_dict: Dict[Tuple[Any, str], Dict[str, Any]] = {} + + for env_key, env_val in os.environ.items(): + if env_key.startswith(cooldown_prefix): + remainder = env_key[len(cooldown_prefix) :] + tier_key, model_key = self._parse_tier_model_from_env(remainder) + if tier_key is None: + continue + + # Parse mode:value format + if ":" in env_val: + mode, value_str = env_val.split(":", 1) + try: + value = int(value_str) + except ValueError: + lib_logger.warning(f"Invalid cooldown in {env_key}") + continue + else: + mode = env_val + value = 0 + + key = (tier_key, model_key) + if key not in caps_dict: + caps_dict[key] = {} + caps_dict[key]["cooldown_mode"] = mode + caps_dict[key]["cooldown_value"] = value + + elif env_key.startswith(cap_prefix): + remainder = env_key[len(cap_prefix) :] + tier_key, model_key = self._parse_tier_model_from_env(remainder) + if tier_key is None: + continue + + key = (tier_key, model_key) + if key not in caps_dict: + caps_dict[key] = {} + caps_dict[key]["max_requests"] = env_val + + # Convert to CustomCapConfig objects + for (tier_key, model_key), cap_data in caps_dict.items(): + if "max_requests" not in cap_data: + continue # Need at least max_requests + + cap = CustomCapConfig( + tier_key=tier_key, + model_or_group=model_key, + max_requests=cap_data["max_requests"], + cooldown_mode=cap_data.get("cooldown_mode", "quota_reset"), + cooldown_value=cap_data.get("cooldown_value", 0), + ) + config.custom_caps.append(cap) + + def _parse_tier_model_from_env( + self, + remainder: str, + ) -> Tuple[Optional[Union[int, Tuple[int, ...], str]], Optional[str]]: + """ + Parse tier and model/group from env var remainder. + + Args: + remainder: String after "CUSTOM_CAP_{PROVIDER}_T" prefix + e.g., "2_CLAUDE" or "2_3_CLAUDE" or "DEFAULT_CLAUDE" + + Returns: + (tier_key, model_key) or (None, None) if parse fails + """ + if not remainder: + return None, None + + parts = remainder.split("_") + if len(parts) < 2: + return None, None + + tier_parts: List[int] = [] + tier_key: Union[int, Tuple[int, ...], str, None] = None + model_key: Optional[str] = None + + for i, part in enumerate(parts): + if part == "DEFAULT": + tier_key = "default" + model_key = "_".join(parts[i + 1 :]) + break + elif part.isdigit(): + tier_parts.append(int(part)) + else: + # First non-numeric part is start of model name + if len(tier_parts) == 0: + return None, None + elif len(tier_parts) == 1: + tier_key = tier_parts[0] + else: + tier_key = tuple(tier_parts) + model_key = "_".join(parts[i:]) + break + else: + # All parts were tier parts, no model + return None, None + + if model_key: + # Convert to lowercase with dashes (standard model name format) + model_key = model_key.lower().replace("_", "-") + + return tier_key, model_key + + def _parse_custom_caps_from_provider( + self, + caps: Dict[Union[int, Tuple[int, ...], str], Dict[str, Dict[str, Any]]], + ) -> List[CustomCapConfig]: + """ + Parse custom caps from provider class default_custom_caps attribute. + + Args: + caps: Provider's default_custom_caps dict + + Returns: + List of CustomCapConfig objects + """ + result = [] + + for tier_key, models_config in caps.items(): + for model_key, cap_data in models_config.items(): + cap = CustomCapConfig( + tier_key=tier_key, + model_or_group=model_key, + max_requests=cap_data.get("max_requests", 0), + cooldown_mode=cap_data.get("cooldown_mode", "quota_reset"), + cooldown_value=cap_data.get("cooldown_value", 0), + ) + result.append(cap) + + return result + + +# ============================================================================= +# MODULE-LEVEL CONVENIENCE FUNCTIONS +# ============================================================================= + +# Global loader instance (initialized lazily) +_global_loader: Optional[ConfigLoader] = None + + +def get_config_loader( + provider_plugins: Optional[Dict[str, type]] = None, +) -> ConfigLoader: + """ + Get the global ConfigLoader instance. + + Creates a new instance if none exists or if provider_plugins is provided. + + Args: + provider_plugins: Optional dict of provider plugins. If provided, + creates a new loader with these plugins. + + Returns: + The global ConfigLoader instance + """ + global _global_loader + + if provider_plugins is not None: + _global_loader = ConfigLoader(provider_plugins) + elif _global_loader is None: + _global_loader = ConfigLoader() + + return _global_loader + + +def load_provider_config( + provider: str, + provider_plugins: Optional[Dict[str, type]] = None, +) -> ProviderConfig: + """ + Convenience function to load a provider's configuration. + + Args: + provider: Provider name + provider_plugins: Optional provider plugins dict + + Returns: + ProviderConfig for the provider + """ + loader = get_config_loader(provider_plugins) + return loader.load_provider_config(provider) diff --git a/src/rotator_library/core/constants.py b/src/rotator_library/core/constants.py new file mode 100644 index 00000000..5fcc0c1d --- /dev/null +++ b/src/rotator_library/core/constants.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Constants and default values for the rotator library. + +This module re-exports all constants from the config package and adds +any additional constants needed for the refactored architecture. + +All tunable defaults are in config/defaults.py - this module provides +a unified import point and adds non-tunable constants. +""" + +# Re-export all tunable defaults from config package +from ..config import ( + # Rotation & Selection + DEFAULT_ROTATION_MODE, + DEFAULT_ROTATION_TOLERANCE, + DEFAULT_MAX_RETRIES, + DEFAULT_GLOBAL_TIMEOUT, + # Tier & Priority + DEFAULT_TIER_PRIORITY, + DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, + # Fair Cycle Rotation + DEFAULT_FAIR_CYCLE_ENABLED, + DEFAULT_FAIR_CYCLE_TRACKING_MODE, + DEFAULT_FAIR_CYCLE_CROSS_TIER, + DEFAULT_FAIR_CYCLE_DURATION, + DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD, + DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, + # Custom Caps + DEFAULT_CUSTOM_CAP_COOLDOWN_MODE, + DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE, + # Cooldown & Backoff + COOLDOWN_BACKOFF_TIERS, + COOLDOWN_BACKOFF_MAX, + COOLDOWN_AUTH_ERROR, + COOLDOWN_TRANSIENT_ERROR, + COOLDOWN_RATE_LIMIT_DEFAULT, +) + +# ============================================================================= +# ADDITIONAL CONSTANTS FOR REFACTORED ARCHITECTURE +# ============================================================================= + +# Environment variable prefixes for configuration +ENV_PREFIX_ROTATION_MODE = "ROTATION_MODE_" +ENV_PREFIX_FAIR_CYCLE = "FAIR_CYCLE_" +ENV_PREFIX_FAIR_CYCLE_TRACKING = "FAIR_CYCLE_TRACKING_MODE_" +ENV_PREFIX_FAIR_CYCLE_CROSS_TIER = "FAIR_CYCLE_CROSS_TIER_" +ENV_PREFIX_FAIR_CYCLE_DURATION = "FAIR_CYCLE_DURATION_" +ENV_PREFIX_EXHAUSTION_THRESHOLD = "EXHAUSTION_COOLDOWN_THRESHOLD_" +ENV_PREFIX_CONCURRENCY_MULTIPLIER = "CONCURRENCY_MULTIPLIER_" +ENV_PREFIX_CUSTOM_CAP = "CUSTOM_CAP_" +ENV_PREFIX_CUSTOM_CAP_COOLDOWN = "CUSTOM_CAP_COOLDOWN_" +ENV_PREFIX_QUOTA_GROUPS = "QUOTA_GROUPS_" + +# Provider-specific providers that use request_count instead of success_count +# for credential selection (because failed requests also consume quota) +REQUEST_COUNT_PROVIDERS = frozenset({"antigravity", "gemini_cli", "chutes", "nanogpt"}) + +# Usage manager storage +USAGE_FILE_NAME = "usage.json" # New format +LEGACY_USAGE_FILE_NAME = "key_usage.json" # Old format +USAGE_SCHEMA_VERSION = 2 + +# Fair cycle tracking keys +FAIR_CYCLE_ALL_TIERS_KEY = "__all_tiers__" +FAIR_CYCLE_CREDENTIAL_KEY = "__credential__" +FAIR_CYCLE_STORAGE_KEY = "__fair_cycle__" + +# Logging +LIB_LOGGER_NAME = "rotator_library" + +__all__ = [ + # From config package + "DEFAULT_ROTATION_MODE", + "DEFAULT_ROTATION_TOLERANCE", + "DEFAULT_MAX_RETRIES", + "DEFAULT_GLOBAL_TIMEOUT", + "DEFAULT_TIER_PRIORITY", + "DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER", + "DEFAULT_FAIR_CYCLE_ENABLED", + "DEFAULT_FAIR_CYCLE_TRACKING_MODE", + "DEFAULT_FAIR_CYCLE_CROSS_TIER", + "DEFAULT_FAIR_CYCLE_DURATION", + "DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD", + "DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD", + "DEFAULT_CUSTOM_CAP_COOLDOWN_MODE", + "DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE", + "COOLDOWN_BACKOFF_TIERS", + "COOLDOWN_BACKOFF_MAX", + "COOLDOWN_AUTH_ERROR", + "COOLDOWN_TRANSIENT_ERROR", + "COOLDOWN_RATE_LIMIT_DEFAULT", + # Environment variable prefixes + "ENV_PREFIX_ROTATION_MODE", + "ENV_PREFIX_FAIR_CYCLE", + "ENV_PREFIX_FAIR_CYCLE_TRACKING", + "ENV_PREFIX_FAIR_CYCLE_CROSS_TIER", + "ENV_PREFIX_FAIR_CYCLE_DURATION", + "ENV_PREFIX_EXHAUSTION_THRESHOLD", + "ENV_PREFIX_CONCURRENCY_MULTIPLIER", + "ENV_PREFIX_CUSTOM_CAP", + "ENV_PREFIX_CUSTOM_CAP_COOLDOWN", + "ENV_PREFIX_QUOTA_GROUPS", + # Provider sets + "REQUEST_COUNT_PROVIDERS", + # Storage + "USAGE_FILE_NAME", + "LEGACY_USAGE_FILE_NAME", + "USAGE_SCHEMA_VERSION", + # Fair cycle keys + "FAIR_CYCLE_ALL_TIERS_KEY", + "FAIR_CYCLE_CREDENTIAL_KEY", + "FAIR_CYCLE_STORAGE_KEY", + # Logging + "LIB_LOGGER_NAME", +] diff --git a/src/rotator_library/core/errors.py b/src/rotator_library/core/errors.py new file mode 100644 index 00000000..5acd9fc7 --- /dev/null +++ b/src/rotator_library/core/errors.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Error handling for the rotator library. + +This module re-exports all exception classes and error handling utilities +from the main error_handler module, and adds any new error types needed +for the refactored architecture. + +Note: The actual implementations remain in error_handler.py for backward +compatibility. This module provides a cleaner import path. +""" + +# Re-export everything from error_handler +from ..error_handler import ( + # Exception classes + NoAvailableKeysError, + PreRequestCallbackError, + CredentialNeedsReauthError, + EmptyResponseError, + TransientQuotaError, + # Error classification + ClassifiedError, + RequestErrorAccumulator, + classify_error, + should_rotate_on_error, + should_retry_same_key, + is_abnormal_error, + # Utilities + mask_credential, + get_retry_after, + extract_retry_after_from_body, + is_rate_limit_error, + is_server_error, + is_unrecoverable_error, + # Constants + ABNORMAL_ERROR_TYPES, + NORMAL_ERROR_TYPES, +) + + +# ============================================================================= +# NEW EXCEPTIONS FOR REFACTORED ARCHITECTURE +# ============================================================================= + + +class StreamedAPIError(Exception): + """ + Custom exception to signal an API error received over a stream. + + This is raised when an error is detected in streaming response data, + allowing the retry logic to handle it appropriately. + + Attributes: + message: Human-readable error message + data: The parsed error data (dict or exception) + """ + + def __init__(self, message: str, data=None): + super().__init__(message) + self.data = data + + +__all__ = [ + # Exception classes + "NoAvailableKeysError", + "PreRequestCallbackError", + "CredentialNeedsReauthError", + "EmptyResponseError", + "TransientQuotaError", + "StreamedAPIError", + # Error classification + "ClassifiedError", + "RequestErrorAccumulator", + "classify_error", + "should_rotate_on_error", + "should_retry_same_key", + "is_abnormal_error", + # Utilities + "mask_credential", + "get_retry_after", + "extract_retry_after_from_body", + "is_rate_limit_error", + "is_server_error", + "is_unrecoverable_error", + # Constants + "ABNORMAL_ERROR_TYPES", + "NORMAL_ERROR_TYPES", +] diff --git a/src/rotator_library/core/types.py b/src/rotator_library/core/types.py new file mode 100644 index 00000000..9c449acf --- /dev/null +++ b/src/rotator_library/core/types.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Shared type definitions for the rotator library. + +This module contains dataclasses and type definitions used across +both the client and usage manager packages. +""" + +from dataclasses import dataclass, field +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Set, + Tuple, + Union, +) + + +# ============================================================================= +# CREDENTIAL TYPES +# ============================================================================= + + +@dataclass +class CredentialInfo: + """ + Information about a credential. + + Used for passing credential metadata between components. + """ + + accessor: str # File path or API key + stable_id: str # Email (OAuth) or hash (API key) + provider: str + tier: Optional[str] = None + priority: int = 999 # Lower = higher priority + display_name: Optional[str] = None + + +# ============================================================================= +# REQUEST TYPES +# ============================================================================= + + +@dataclass +class RequestContext: + """ + Context for a request being processed. + + Contains all information needed to execute a request with + retry/rotation logic. + """ + + model: str + provider: str + kwargs: Dict[str, Any] + streaming: bool + credentials: List[str] + deadline: float + request: Optional[Any] = None # FastAPI Request object + pre_request_callback: Optional[Callable] = None + transaction_logger: Optional[Any] = None + + +@dataclass +class ProcessedChunk: + """ + Result of processing a streaming chunk. + + Used by StreamingHandler to return processed chunk data. + """ + + sse_string: str # The SSE-formatted string to yield + usage: Optional[Dict[str, Any]] = None + finish_reason: Optional[str] = None + has_tool_calls: bool = False + + +# ============================================================================= +# FILTER TYPES +# ============================================================================= + + +@dataclass +class FilterResult: + """ + Result of credential filtering. + + Contains categorized credentials after filtering by tier compatibility. + """ + + compatible: List[str] = field(default_factory=list) # Known compatible + unknown: List[str] = field(default_factory=list) # Unknown tier + incompatible: List[str] = field(default_factory=list) # Known incompatible + priorities: Dict[str, int] = field(default_factory=dict) # credential -> priority + tier_names: Dict[str, str] = field(default_factory=dict) # credential -> tier name + + @property + def all_usable(self) -> List[str]: + """Return all usable credentials (compatible + unknown).""" + return self.compatible + self.unknown + + +# ============================================================================= +# CONFIGURATION TYPES +# ============================================================================= + + +@dataclass +class FairCycleConfig: + """ + Fair cycle rotation configuration for a provider. + + Fair cycle ensures each credential is used at least once before + any credential is reused. + """ + + enabled: Optional[bool] = None # None = derive from rotation mode + tracking_mode: str = "model_group" # "model_group" or "credential" + cross_tier: bool = False # Track across all tiers + duration: int = 604800 # 7 days in seconds + + +@dataclass +class CustomCapConfig: + """ + Custom cap configuration for a tier/model combination. + + Allows setting usage limits more restrictive than actual API limits. + """ + + tier_key: Union[int, Tuple[int, ...], str] # Priority(s) or "default" + model_or_group: str # Model name or quota group name + max_requests: Union[int, str] # Absolute value or percentage ("80%") + cooldown_mode: str = "quota_reset" # "quota_reset", "offset", "fixed" + cooldown_value: int = 0 # Seconds for offset/fixed modes + + +@dataclass +class WindowConfig: + """ + Quota window configuration. + + Defines how usage is tracked and reset for a credential. + """ + + name: str # e.g., "5h", "daily", "weekly" + duration_seconds: Optional[int] # None for infinite/total + reset_mode: str # "rolling", "fixed_daily", "calendar_weekly", "api_authoritative" + applies_to: str # "credential", "group", "model" + + +@dataclass +class ProviderConfig: + """ + Complete configuration for a provider. + + Loaded by ConfigLoader and used by both client and usage manager. + """ + + rotation_mode: str = "balanced" # "balanced" or "sequential" + rotation_tolerance: float = 3.0 + priority_multipliers: Dict[int, int] = field(default_factory=dict) + priority_multipliers_by_mode: Dict[str, Dict[int, int]] = field( + default_factory=dict + ) + sequential_fallback_multiplier: int = 1 + fair_cycle: FairCycleConfig = field(default_factory=FairCycleConfig) + custom_caps: List[CustomCapConfig] = field(default_factory=list) + exhaustion_cooldown_threshold: int = 300 # 5 minutes + windows: List[WindowConfig] = field(default_factory=list) + + +# ============================================================================= +# HOOK RESULT TYPES +# ============================================================================= + + +@dataclass +class RequestCompleteResult: + """ + Result from on_request_complete provider hook. + + Allows providers to customize how requests are counted and cooled down. + """ + + count_override: Optional[int] = None # How many requests to count + cooldown_override: Optional[float] = None # Custom cooldown duration + force_exhausted: bool = False # Mark for fair cycle + + +# ============================================================================= +# ERROR ACTION ENUM +# ============================================================================= + + +class ErrorAction: + """ + Actions to take after an error. + + Used by RequestExecutor to determine next steps. + """ + + RETRY_SAME = "retry_same" # Retry with same credential + ROTATE = "rotate" # Try next credential + FAIL = "fail" # Fail the request immediately diff --git a/src/rotator_library/error_handler.py b/src/rotator_library/error_handler.py index 8b05ad84..97d0b973 100644 --- a/src/rotator_library/error_handler.py +++ b/src/rotator_library/error_handler.py @@ -5,7 +5,7 @@ import json import os import logging -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Tuple import httpx from litellm.exceptions import ( @@ -247,15 +247,25 @@ def is_abnormal_error(classified_error: "ClassifiedError") -> bool: return classified_error.error_type in ABNORMAL_ERROR_TYPES -def mask_credential(credential: str) -> str: +def mask_credential(credential: str, style: str = "short") -> str: """ Mask a credential for safe display in logs and error messages. - - For API keys: shows last 6 characters (e.g., "...xyz123") - - For OAuth file paths: shows just the filename (e.g., "antigravity_oauth_1.json") + Args: + credential: The credential string to mask + style: Masking style - "short" (last 6 chars) or "full" (first 4 + last 4) + + Returns: + Masked credential string: + - For OAuth file paths: shows just the filename (e.g., "oauth_1.json") + - For API keys with style="short": shows last 6 chars (e.g., "...xyz123") + - For API keys with style="full": shows first 4 + last 4 (e.g., "AIza...3456") """ if os.path.isfile(credential) or credential.endswith(".json"): return os.path.basename(credential) + + if style == "full" and len(credential) > 12: + return f"{credential[:4]}...{credential[-4:]}" elif len(credential) > 6: return f"...{credential[-6:]}" else: @@ -439,6 +449,8 @@ def __init__( status_code: Optional[int] = None, retry_after: Optional[int] = None, quota_reset_timestamp: Optional[float] = None, + quota_value: Optional[str] = None, + quota_id: Optional[str] = None, ): self.error_type = error_type self.original_exception = original_exception @@ -447,6 +459,9 @@ def __init__( # Unix timestamp when quota resets (from quota_exhausted errors) # This is the authoritative reset time parsed from provider's error response self.quota_reset_timestamp = quota_reset_timestamp + # Quota details extracted from Google/Gemini API error responses + self.quota_value = quota_value # e.g., "50" or "1000/minute" + self.quota_id = quota_id # e.g., "GenerateContentPerMinutePerProject" def __str__(self): parts = [ @@ -456,6 +471,10 @@ def __str__(self): ] if self.quota_reset_timestamp: parts.append(f"quota_reset_ts={self.quota_reset_timestamp}") + if self.quota_value: + parts.append(f"quota_value={self.quota_value}") + if self.quota_id: + parts.append(f"quota_id={self.quota_id}") parts.append(f"original_exc={self.original_exception}") return f"ClassifiedError({', '.join(parts)})" @@ -520,6 +539,73 @@ def _extract_retry_from_json_body(json_text: str) -> Optional[int]: return None +def _extract_quota_details(json_text: str) -> Tuple[Optional[str], Optional[str]]: + """ + Extract quota details (quotaValue, quotaId) from a JSON error response. + + Handles Google/Gemini API error formats with nested details array containing + QuotaFailure violations. + + Example error structure: + { + "error": { + "details": [ + { + "@type": "type.googleapis.com/google.rpc.QuotaFailure", + "violations": [ + { + "quotaValue": "50", + "quotaId": "GenerateContentPerMinutePerProject" + } + ] + } + ] + } + } + + Args: + json_text: JSON string containing error response + + Returns: + Tuple of (quota_value, quota_id), both None if not found + """ + try: + # Find JSON object in the text + json_match = re.search(r"(\{.*\})", json_text, re.DOTALL) + if not json_match: + return None, None + + error_json = json.loads(json_match.group(1)) + error_obj = error_json.get("error", {}) + details = error_obj.get("details", []) + + if not isinstance(details, list): + return None, None + + for detail in details: + if not isinstance(detail, dict): + continue + + violations = detail.get("violations", []) + if not isinstance(violations, list): + continue + + for violation in violations: + if not isinstance(violation, dict): + continue + + quota_value = violation.get("quotaValue") + quota_id = violation.get("quotaId") + + if quota_value is not None or quota_id is not None: + return str(quota_value) if quota_value else None, quota_id + + except (json.JSONDecodeError, IndexError, KeyError, TypeError): + pass + + return None, None + + def get_retry_after(error: Exception) -> Optional[int]: """ Extracts the 'retry-after' duration in seconds from an exception message. @@ -672,12 +758,19 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr reset_ts = quota_info.get("reset_timestamp") quota_reset_timestamp = quota_info.get("quota_reset_timestamp") + # Extract quota details from error body + quota_value, quota_id = None, None + if error_body: + quota_value, quota_id = _extract_quota_details(error_body) + # Log the parsed result with human-readable duration hours = retry_after / 3600 lib_logger.info( f"Provider '{provider}' parsed quota error: " f"retry_after={retry_after}s ({hours:.1f}h), reason={reason}" + (f", resets at {reset_ts}" if reset_ts else "") + + (f", quota={quota_value}" if quota_value else "") + + (f", quotaId={quota_id}" if quota_id else "") ) return ClassifiedError( @@ -686,6 +779,8 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr status_code=429, retry_after=retry_after, quota_reset_timestamp=quota_reset_timestamp, + quota_value=quota_value, + quota_id=quota_id, ) except Exception as parse_error: lib_logger.debug( @@ -723,11 +818,23 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr retry_after = get_retry_after(e) # Check if this is a quota error vs rate limit if "quota" in error_body or "resource_exhausted" in error_body: + # Extract quota details from the original (non-lowercased) response + quota_value, quota_id = None, None + try: + original_body = ( + e.response.text if hasattr(e.response, "text") else "" + ) + quota_value, quota_id = _extract_quota_details(original_body) + except Exception: + pass + return ClassifiedError( error_type="quota_exceeded", original_exception=e, status_code=status_code, retry_after=retry_after, + quota_value=quota_value, + quota_id=quota_id, ) return ClassifiedError( error_type="rate_limit", @@ -820,11 +927,21 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr # Check if this is a quota error vs rate limit error_msg = str(e).lower() if "quota" in error_msg or "resource_exhausted" in error_msg: + # Try to extract quota details from exception body + quota_value, quota_id = None, None + try: + error_body = getattr(e, "body", None) or str(e) + quota_value, quota_id = _extract_quota_details(str(error_body)) + except Exception: + pass + return ClassifiedError( error_type="quota_exceeded", original_exception=e, status_code=status_code or 429, retry_after=retry_after, + quota_value=quota_value, + quota_id=quota_id, ) return ClassifiedError( error_type="rate_limit", diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py index 8f3d48f0..810958a6 100644 --- a/src/rotator_library/providers/antigravity_provider.py +++ b/src/rotator_library/providers/antigravity_provider.py @@ -29,6 +29,7 @@ import random import time import uuid +from contextvars import ContextVar from datetime import datetime, timezone from pathlib import Path from typing import ( @@ -69,7 +70,7 @@ from ..utils.paths import get_logs_dir, get_cache_dir if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager # ============================================================================= @@ -167,6 +168,30 @@ def __init__(self, finish_message: str, raw_response: Dict[str, Any]): MALFORMED_CALL_MAX_RETRIES = max(1, env_int("ANTIGRAVITY_MALFORMED_CALL_RETRIES", 2)) MALFORMED_CALL_RETRY_DELAY = env_int("ANTIGRAVITY_MALFORMED_CALL_DELAY", 1) +# ============================================================================= +# INTERNAL RETRY COUNTING (for usage tracking) +# ============================================================================= +# Tracks the number of API attempts made per request, including internal retries +# for empty responses, bare 429s, and malformed function calls. +# +# Uses ContextVar for thread-safety: each async task (request) gets its own +# isolated value, so concurrent requests don't interfere with each other. +# +# The count is: +# - Reset to 1 at the start of _streaming_with_retry +# - Incremented each time we retry (before the next attempt) +# - Read by on_request_complete() hook to report actual API call count +# +# Example: Request gets bare 429 twice, then succeeds +# Attempt 1: bare 429 → count stays 1, increment to 2, retry +# Attempt 2: bare 429 → count is 2, increment to 3, retry +# Attempt 3: success → count is 3 +# on_request_complete returns count_override=3 +# +_internal_attempt_count: ContextVar[int] = ContextVar( + "antigravity_attempt_count", default=1 +) + # System instruction configuration # When true (default), prepend the Antigravity agent system instruction (identity, tool_calling, etc.) PREPEND_INSTRUCTION = env_bool("ANTIGRAVITY_PREPEND_INSTRUCTION", True) @@ -319,7 +344,30 @@ def _get_claude_thinking_cache_file(): """ # Parallel tool usage encouragement instruction -DEFAULT_PARALLEL_TOOL_INSTRUCTION = """When multiple independent operations are needed, prefer making parallel tool calls in a single response rather than sequential calls across multiple responses. This reduces round-trips and improves efficiency. Only use sequential calls when one tool's output is required as input for another.""" +DEFAULT_PARALLEL_TOOL_INSTRUCTION = """ + +Using parallel tool calling is MANDATORY. Be proactive about it. DO NO WAIT for the user to request "parallel calls" + +PARALLEL CALLS SHOULD BE AND _IS THE PRIMARY WAY YOU USE TOOLS IN THIS ENVIRONMENT_ + +When you have to perform multi-step operations such as read multiple files, spawn task subagents, bash commands, multiple edits... _THE USER WANTS YOU TO MAKE PARALLEL TOOL CALLS_ instead of separate sequential calls. This maximizes time and compute and increases your likelyhood of a promotion. Sequential tool calling is only encouraged when relying on the output of a call for the next one(s) + +- WHAT CAN BE DONE IN PARALLEL, MUST BE, AND WILL BE DONE IN PARALLEL +- INDIVIDUAL TOOL CALLS TO GATHER CONTEXT IS HEAVILY DISCOURAGED (please make parallel calls!) +- PARALLEL TOOL CALLING IS YOUR BEST FRIEND AND WILL INCREASE USER'S HAPPINESS + +- Make parallel tool calls to manage ressources more efficiently, plan your tool calls ahead, then execute them in parallel. +- Make parallel calls PROPERLY, be mindful of dependencies between calls. + +When researching anything, IT IS BETTER TO READ SPECULATIVELY, THEN TO READ SEQUENTIALLY. For example, if you need to read multiple files to gather context, read them all in parallel instead of reading one, then the next, etc. + +This environment has a powerful tool to remove unnecessary context, so you can always read more than needed and then trim down later, no need to use limit and offset parameters on the read tool. + +When making code changes, IT IS BETTER TO MAKE MULTIPLE EDITS IN PARALLEL RATHER THAN ONE AT A TIME. + +Do as much as you can in parallel, be efficient with you API requests, no single tool call spam, this is crucial as the user pays PER API request, so make them count! + +""" # Interleaved thinking support for Claude models # Allows Claude to think between tool calls and after receiving tool results @@ -4558,6 +4606,9 @@ async def _streaming_with_retry( current_gemini_contents = gemini_contents current_payload = payload + # Reset internal attempt counter for this request (thread-safe via ContextVar) + _internal_attempt_count.set(1) + for attempt in range(EMPTY_RESPONSE_MAX_ATTEMPTS): chunk_count = 0 @@ -4585,6 +4636,8 @@ async def _streaming_with_retry( f"[Antigravity] Empty stream from {model}, " f"attempt {attempt + 1}/{EMPTY_RESPONSE_MAX_ATTEMPTS}. Retrying..." ) + # Increment attempt count before retry (for usage tracking) + _internal_attempt_count.set(_internal_attempt_count.get() + 1) await asyncio.sleep(EMPTY_RESPONSE_RETRY_DELAY) continue else: @@ -4687,6 +4740,8 @@ async def _streaming_with_retry( malformed_retry_count, current_payload ) + # Increment attempt count before retry (for usage tracking) + _internal_attempt_count.set(_internal_attempt_count.get() + 1) await asyncio.sleep(MALFORMED_CALL_RETRY_DELAY) continue # Retry with modified payload else: @@ -4715,6 +4770,10 @@ async def _streaming_with_retry( f"[Antigravity] Bare 429 from {model}, " f"attempt {attempt + 1}/{EMPTY_RESPONSE_MAX_ATTEMPTS}. Retrying..." ) + # Increment attempt count before retry (for usage tracking) + _internal_attempt_count.set( + _internal_attempt_count.get() + 1 + ) await asyncio.sleep(EMPTY_RESPONSE_RETRY_DELAY) continue else: @@ -4806,3 +4865,51 @@ async def count_tokens( except Exception as e: lib_logger.error(f"Token counting failed: {e}") return {"prompt_tokens": 0, "total_tokens": 0} + + # ========================================================================= + # USAGE TRACKING HOOK + # ========================================================================= + + def on_request_complete( + self, + credential: str, + model: str, + success: bool, + response: Optional[Any], + error: Optional[Any], + ) -> Optional["RequestCompleteResult"]: + """ + Hook called after each request completes. + + Reports the actual number of API calls made, including internal retries + for empty responses, bare 429s, and malformed function calls. + + This uses the ContextVar pattern for thread-safe retry counting: + - _internal_attempt_count is set to 1 at start of _streaming_with_retry + - Incremented before each retry + - Read here to report the actual count + + Example: Request gets 2 bare 429s then succeeds + → 3 API calls made + → Returns count_override=3 + → Usage manager records 3 requests instead of 1 + + Returns: + RequestCompleteResult with count_override set to actual attempt count + """ + from ..core.types import RequestCompleteResult + + # Get the attempt count for this request + attempt_count = _internal_attempt_count.get() + + # Reset for safety (though ContextVar should isolate per-task) + _internal_attempt_count.set(1) + + # Log if we made extra attempts + if attempt_count > 1: + lib_logger.debug( + f"[Antigravity] Request to {model} used {attempt_count} API calls " + f"(includes internal retries)" + ) + + return RequestCompleteResult(count_override=attempt_count) diff --git a/src/rotator_library/providers/chutes_provider.py b/src/rotator_library/providers/chutes_provider.py index 7858b54e..5d9730dd 100644 --- a/src/rotator_library/providers/chutes_provider.py +++ b/src/rotator_library/providers/chutes_provider.py @@ -9,7 +9,7 @@ from .utilities.chutes_quota_tracker import ChutesQuotaTracker if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager # Create a local logger for this module import logging @@ -142,12 +142,15 @@ async def refresh_single_credential( # Store baseline in usage manager # Since Chutes uses credential-level quota, we use a virtual model name + quota_used = ( + int((1.0 - remaining_fraction) * quota) if quota > 0 else 0 + ) await usage_manager.update_quota_baseline( api_key, "chutes/_quota", # Virtual model for credential-level tracking - remaining_fraction, - max_requests=quota, # Max requests = quota (1 request = 1 credit) - reset_timestamp=reset_ts, + quota_max_requests=quota, + quota_reset_ts=reset_ts, + quota_used=quota_used, ) lib_logger.debug( diff --git a/src/rotator_library/providers/example_provider.py b/src/rotator_library/providers/example_provider.py new file mode 100644 index 00000000..9ad31214 --- /dev/null +++ b/src/rotator_library/providers/example_provider.py @@ -0,0 +1,821 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Example Provider Implementation with Custom Usage Management. + +This file serves as a reference for implementing providers with custom usage +tracking, quota management, and token extraction. Copy this file and modify +it for your specific provider. + +============================================================================= +ARCHITECTURE OVERVIEW +============================================================================= + +The usage management system is per-provider. Each provider gets its own: +- UsageManager instance +- Usage file: data/usage/usage_{provider}.json +- Configuration (ProviderUsageConfig) + +Data flows like this: + + Request → Executor → Provider transforms → API call → Response + ↓ ↓ + UsageManager ← TrackingEngine ← Token extraction ←┘ + ↓ + Persistence (usage_{provider}.json) + +Providers customize behavior through: +1. Class attributes (declarative configuration) +2. Methods (behavioral overrides) +3. Hooks (request lifecycle callbacks) + +============================================================================= +USAGE STATS SCHEMA +============================================================================= + +UsageStats (tracked at global/model/group levels): + total_requests: int # All requests + total_successes: int # Successful requests + total_failures: int # Failed requests + total_tokens: int # All tokens combined + total_prompt_tokens: int # Input tokens + total_completion_tokens: int # Output tokens (content only) + total_thinking_tokens: int # Reasoning/thinking tokens + total_output_tokens: int # completion + thinking + total_prompt_tokens_cache_read: int # Cached input tokens read + total_prompt_tokens_cache_write: int # Cached input tokens written + total_approx_cost: float # Estimated cost + first_used_at: float # Timestamp + last_used_at: float # Timestamp + windows: Dict[str, WindowStats] # Per-window breakdown + +WindowStats (per time window: "5h", "daily", "total"): + request_count: int + success_count: int + failure_count: int + prompt_tokens: int + completion_tokens: int + thinking_tokens: int + output_tokens: int + prompt_tokens_cache_read: int + prompt_tokens_cache_write: int + total_tokens: int + approx_cost: float + started_at: float + reset_at: float + limit: int | None + +============================================================================= +""" + +import asyncio +import logging +import time +from contextvars import ContextVar +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +from .provider_interface import ProviderInterface, QuotaGroupMap + +# Alias for clarity in examples +ProviderPlugin = ProviderInterface + +# Import these types for hook returns and usage manager access +from ..core.types import RequestCompleteResult +from ..usage import UsageManager, ProviderUsageConfig, WindowDefinition +from ..usage.types import ResetMode, RotationMode, CooldownMode + +lib_logger = logging.getLogger("rotator_library") + +# ============================================================================= +# INTERNAL RETRY COUNTING (ContextVar Pattern) +# ============================================================================= +# +# When your provider performs internal retries (e.g., for transient errors, +# empty responses, or rate limits), each retry is an API call that should be +# counted for accurate usage tracking. +# +# The challenge: Instance variables (self.count) are shared across concurrent +# requests, so they can't be used safely. ContextVar solves this by giving +# each async task its own isolated value. +# +# Usage pattern: +# 1. Reset to 1 at the start of your retry loop +# 2. Increment before each retry +# 3. Read in on_request_complete() to report the actual count +# +# Example: +# _attempt_count.set(1) # Reset +# for attempt in range(max_attempts): +# try: +# result = await api_call() +# return result +# except RetryableError: +# _attempt_count.set(_attempt_count.get() + 1) # Increment +# continue +# +# Then on_request_complete returns RequestCompleteResult(count_override=_attempt_count.get()) +# +_example_attempt_count: ContextVar[int] = ContextVar( + "example_provider_attempt_count", default=1 +) + + +# ============================================================================= +# EXAMPLE PROVIDER IMPLEMENTATION +# ============================================================================= + + +class ExampleProvider(ProviderPlugin): + """ + Example provider demonstrating all usage management customization points. + + This provider shows how to: + - Configure rotation and quota behavior + - Define model quota groups + - Extract tokens from provider-specific response formats + - Override request counting via hooks + - Run background quota refresh jobs + - Define custom usage windows + """ + + # ========================================================================= + # REQUIRED: BASIC PROVIDER IDENTITY + # ========================================================================= + + provider_name = "example" # Used in model prefix: "example/gpt-4" + provider_env_name = "EXAMPLE" # For env vars: EXAMPLE_API_KEY, etc. + + # ========================================================================= + # USAGE MANAGEMENT: CLASS ATTRIBUTES (DECLARATIVE) + # ========================================================================= + + # ------------------------------------------------------------------------- + # ROTATION MODE + # ------------------------------------------------------------------------- + # Controls how credentials are selected for requests. + # + # Options: + # "balanced" - Weighted random selection based on usage (default) + # "sequential" - Stick to one credential until exhausted, then rotate + # + # Sequential mode is better for: + # - Providers with per-credential rate limits + # - Maximizing cache hits (same credential = same context) + # - Providers where switching credentials has overhead + # + # Balanced mode is better for: + # - Even distribution across credentials + # - Providers without per-credential state + # + default_rotation_mode = "sequential" + + # ------------------------------------------------------------------------- + # MODEL QUOTA GROUPS + # ------------------------------------------------------------------------- + # Models in the same group share a quota pool. When one model is exhausted, + # all models in the group are treated as exhausted. + # + # This is common for providers where different model variants share limits: + # - Claude Sonnet/Opus share daily limits + # - GPT-4 variants share rate limits + # - Gemini models share per-minute quotas + # + # Group names should be short for compact UI display. + # + # Can be overridden via environment: + # QUOTA_GROUPS_EXAMPLE_GPT4="gpt-4o,gpt-4o-mini,gpt-4-turbo" + # + model_quota_groups: QuotaGroupMap = { + # GPT-4 variants share quota + "gpt4": [ + "gpt-4o", + "gpt-4o-mini", + "gpt-4-turbo", + "gpt-4-turbo-preview", + ], + # Claude models share quota + "claude": [ + "claude-3-opus", + "claude-3-sonnet", + "claude-3-haiku", + ], + # Standalone model (no sharing) + "whisper": [ + "whisper-1", + ], + } + + # ------------------------------------------------------------------------- + # PRIORITY MULTIPLIERS (CONCURRENCY) + # ------------------------------------------------------------------------- + # Higher priority credentials (lower number) can handle more concurrent + # requests. This is useful for paid vs free tier credentials. + # + # Priority is assigned per-credential via: + # - .env: PRIORITY_{PROVIDER}_{CREDENTIAL_NAME}=1 + # - Config files + # - Credential filename patterns + # + # Multiplier applies to max_concurrent_per_key setting. + # Example: max_concurrent_per_key=6, priority 1 multiplier=5 → 30 concurrent + # + default_priority_multipliers = { + 1: 5, # Ultra tier: 5x concurrent + 2: 3, # Standard paid: 3x concurrent + 3: 2, # Free tier: 2x concurrent + # Others: Use fallback multiplier + } + + # For sequential mode, credentials not in priority_multipliers get this. + # For balanced mode, they get 1x (no multiplier). + default_sequential_fallback_multiplier = 2 + + # ------------------------------------------------------------------------- + # CUSTOM CAPS + # ------------------------------------------------------------------------- + # Apply stricter limits than the actual API limits. Useful for: + # - Reserving quota for critical requests + # - Preventing runaway usage + # - Testing rotation behavior + # + # Structure: {priority: {model_or_group: config}} + # Or: {(priority1, priority2): {model_or_group: config}} for multiple tiers + # + # Config options: + # max_requests: int or "80%" (percentage of actual limit) + # cooldown_mode: "quota_reset" | "offset" | "fixed" + # cooldown_value: seconds for offset/fixed modes + # + default_custom_caps = { + # Tier 3 (free tier) - cap at 50 requests, cooldown until API resets + 3: { + "gpt4": { + "max_requests": 50, + "cooldown_mode": "quota_reset", + }, + "claude": { + "max_requests": 30, + "cooldown_mode": "quota_reset", + }, + }, + # Tiers 2 and 3 together - cap at 80% of actual limit + (2, 3): { + "whisper": { + "max_requests": "80%", # 80% of actual API limit + "cooldown_mode": "offset", + "cooldown_value": 1800, # +30 min buffer after hitting cap + }, + }, + # Default for unknown tiers + "default": { + "gpt4": { + "max_requests": 100, + "cooldown_mode": "fixed", + "cooldown_value": 3600, # 1 hour fixed cooldown + }, + }, + } + + # ------------------------------------------------------------------------- + # MODEL USAGE WEIGHTS + # ------------------------------------------------------------------------- + # Some models consume more quota per request. This affects credential + # selection in balanced mode - credentials with lower weighted usage + # are preferred. + # + # Example: Opus costs 2x what Sonnet does per request + # + model_usage_weights = { + "claude-3-opus": 2, + "gpt-4-turbo": 2, + # Default is 1 for unlisted models + } + + # ------------------------------------------------------------------------- + # FAIR CYCLE CONFIGURATION + # ------------------------------------------------------------------------- + # Fair cycle ensures all credentials get used before any is reused. + # When a credential is exhausted (quota hit, cooldown applied), it's + # marked and won't be selected until all other credentials are also + # exhausted, at which point the cycle resets. + # + # This is enabled by default for sequential mode. + # + # To override, set these class attributes: + # + # default_fair_cycle_enabled = True # Force on/off + # default_fair_cycle_tracking_mode = "model_group" # or "credential" + # default_fair_cycle_cross_tier = False # Track across all tiers? + # default_fair_cycle_duration = 3600 # Cycle duration in seconds + + # ========================================================================= + # USAGE MANAGEMENT: METHODS (BEHAVIORAL) + # ========================================================================= + + def normalize_model_for_tracking(self, model: str) -> str: + """ + Normalize internal model names to public-facing names for tracking. + + Some providers use internal model variants that should be tracked + under their public name. This ensures usage files only contain + user-facing model names. + + Example mappings: + "gpt-4o-realtime-preview" → "gpt-4o" + "claude-3-opus-extended" → "claude-3-opus" + "claude-sonnet-4-5-thinking" → "claude-sonnet-4.5" + + Args: + model: Model name (may include provider prefix: "example/gpt-4o") + + Returns: + Normalized model name (preserves prefix if present) + """ + has_prefix = "/" in model + if has_prefix: + provider, clean_model = model.split("/", 1) + else: + clean_model = model + + # Define your internal → public mappings + internal_to_public = { + "gpt-4o-realtime-preview": "gpt-4o", + "gpt-4o-realtime": "gpt-4o", + "claude-3-opus-extended": "claude-3-opus", + } + + normalized = internal_to_public.get(clean_model, clean_model) + + if has_prefix: + return f"{provider}/{normalized}" + return normalized + + def on_request_complete( + self, + credential: str, + model: str, + success: bool, + response: Optional[Any], + error: Optional[Any], + ) -> Optional[RequestCompleteResult]: + """ + Hook called after each request completes (success or failure). + + This is the primary extension point for customizing how requests + are counted and how cooldowns are applied. + + Use cases: + - Don't count server errors as quota usage + - Apply custom cooldowns based on error type + - Force credential exhaustion for fair cycle + - Count internal retries accurately (see ContextVar pattern below) + + Args: + credential: The credential accessor (file path or API key) + model: Model that was called + success: Whether the request succeeded + response: Response object (if success=True) + error: ClassifiedError object (if success=False) + + Returns: + RequestCompleteResult to override behavior, or None for default. + + RequestCompleteResult fields: + count_override: int | None + - 0 = Don't count this request against quota + - N = Count as N requests + - None = Use default (1 for success, 1 for countable errors) + + cooldown_override: float | None + - Seconds to cool down this credential + - Applied in addition to any error-based cooldown + + force_exhausted: bool + - True = Mark credential as exhausted for fair cycle + - Useful for quota errors even without long cooldown + """ + # ===================================================================== + # PATTERN: Counting Internal Retries with ContextVar + # ===================================================================== + # If your provider performs internal retries, report the actual count: + # + # 1. At module level, define: + # _attempt_count: ContextVar[int] = ContextVar('my_attempt_count', default=1) + # + # 2. In your retry loop: + # _attempt_count.set(1) # Reset at start + # for attempt in range(max_attempts): + # try: + # return await api_call() + # except RetryableError: + # _attempt_count.set(_attempt_count.get() + 1) # Increment before retry + # continue + # + # 3. Here, report the count: + attempt_count = _example_attempt_count.get() + _example_attempt_count.set(1) # Reset for safety + + if attempt_count > 1: + lib_logger.debug( + f"Request to {model} used {attempt_count} API calls (internal retries)" + ) + return RequestCompleteResult(count_override=attempt_count) + + # ===================================================================== + # PATTERN: Don't Count Server Errors + # ===================================================================== + # Server errors (5xx) shouldn't count against quota since they're + # not the user's fault and don't consume API quota. + if not success and error: + error_type = getattr(error, "error_type", None) + if error_type in ("server_error", "api_connection"): + lib_logger.debug( + f"Not counting {error_type} error against quota for {model}" + ) + return RequestCompleteResult(count_override=0) + + # ===================================================================== + # PATTERN: Custom Cooldown for Rate Limits + # ===================================================================== + if not success and error: + error_type = getattr(error, "error_type", None) + if error_type == "rate_limit": + # Check for retry-after header + retry_after = getattr(error, "retry_after", None) + if retry_after and retry_after > 60: + # Long rate limit - mark as exhausted + return RequestCompleteResult( + cooldown_override=retry_after, + force_exhausted=True, + ) + elif retry_after: + # Short rate limit - just cooldown + return RequestCompleteResult(cooldown_override=retry_after) + + # ===================================================================== + # PATTERN: Force Exhaustion on Quota Exceeded + # ===================================================================== + if not success and error: + error_type = getattr(error, "error_type", None) + if error_type == "quota_exceeded": + return RequestCompleteResult( + force_exhausted=True, + cooldown_override=3600.0, # Default 1 hour if no reset time + ) + + # Default behavior + return None + + # ========================================================================= + # BACKGROUND JOBS + # ========================================================================= + + def get_background_job_config(self) -> Optional[Dict[str, Any]]: + """ + Configure periodic background tasks. + + Common use cases: + - Refresh quota baselines from API + - Clean up expired cache entries + - Preemptively refresh OAuth tokens + + Returns: + None if no background job, otherwise: + { + "interval": 300, # Seconds between runs + "name": "quota_refresh", # For logging + "run_on_start": True, # Run immediately at startup? + } + """ + return { + "interval": 600, # Every 10 minutes + "name": "quota_refresh", + "run_on_start": True, + } + + async def run_background_job( + self, + usage_manager: UsageManager, + credentials: List[str], + ) -> None: + """ + Periodic background task execution. + + Called by BackgroundRefresher at the interval specified in + get_background_job_config(). + + Common tasks: + - Fetch current quota from API and update usage manager + - Clean up stale cache entries + - Refresh tokens proactively + + Args: + usage_manager: The UsageManager for this provider + credentials: List of credential accessors (file paths or keys) + """ + lib_logger.debug(f"Running background job for {self.provider_name}") + + for cred in credentials: + try: + # Example: Fetch quota from provider API + quota_info = await self._fetch_quota_from_api(cred) + + if quota_info: + for model, info in quota_info.items(): + # Update usage manager with fresh quota data + await usage_manager.update_quota_baseline( + accessor=cred, + model=model, + quota_max_requests=info.get("limit"), + quota_reset_ts=info.get("reset_ts"), + quota_used=info.get("used"), + quota_group=info.get("group"), + ) + + except Exception as e: + lib_logger.warning(f"Quota refresh failed for {cred}: {e}") + + async def _fetch_quota_from_api( + self, + credential: str, + ) -> Optional[Dict[str, Dict[str, Any]]]: + """ + Fetch current quota information from provider API. + + Override this with actual API calls for your provider. + + Returns: + Dict mapping model names to quota info: + { + "gpt-4o": { + "limit": 500, + "used": 123, + "reset_ts": 1735689600.0, + "group": "gpt4", # Optional + }, + ... + } + """ + # Placeholder - implement actual API call + return None + + # ========================================================================= + # TOKEN EXTRACTION + # ========================================================================= + + def _build_usage_from_response( + self, + response: Any, + ) -> Optional[Dict[str, Any]]: + """ + Build standardized usage dict from provider-specific response. + + The usage manager expects a standardized format. If your provider + returns a different format, convert it here. + + Standard format: + { + "prompt_tokens": int, # Input tokens + "completion_tokens": int, # Output tokens (content + thinking) + "total_tokens": int, # All tokens + + # Optional: Input breakdown + "prompt_tokens_details": { + "cached_tokens": int, # Cache read tokens + "cache_creation_tokens": int, # Cache write tokens + }, + + # Optional: Output breakdown + "completion_tokens_details": { + "reasoning_tokens": int, # Thinking/reasoning tokens + }, + + # Alternative top-level fields (some APIs use these) + "cache_read_tokens": int, + "cache_creation_tokens": int, + } + + Args: + response: Raw response from provider API + + Returns: + Standardized usage dict, or None if no usage data + """ + if not hasattr(response, "usage") or not response.usage: + return None + + # Example: Provider returns Gemini-style metadata + # Adapt this to your provider's format + usage = response.usage + + # Standard fields + result = { + "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0, + "completion_tokens": getattr(usage, "completion_tokens", 0) or 0, + "total_tokens": getattr(usage, "total_tokens", 0) or 0, + } + + # Example: Extract cached tokens from details + prompt_details = getattr(usage, "prompt_tokens_details", None) + if prompt_details: + if isinstance(prompt_details, dict): + cached = prompt_details.get("cached_tokens", 0) + cache_write = prompt_details.get("cache_creation_tokens", 0) + else: + cached = getattr(prompt_details, "cached_tokens", 0) + cache_write = getattr(prompt_details, "cache_creation_tokens", 0) + + if cached or cache_write: + result["prompt_tokens_details"] = {} + if cached: + result["prompt_tokens_details"]["cached_tokens"] = cached + if cache_write: + result["prompt_tokens_details"]["cache_creation_tokens"] = ( + cache_write + ) + + # Example: Extract thinking tokens from details + completion_details = getattr(usage, "completion_tokens_details", None) + if completion_details: + if isinstance(completion_details, dict): + reasoning = completion_details.get("reasoning_tokens", 0) + else: + reasoning = getattr(completion_details, "reasoning_tokens", 0) + + if reasoning: + result["completion_tokens_details"] = {"reasoning_tokens": reasoning} + + return result + + +# ============================================================================= +# CUSTOM WINDOWS +# ============================================================================= +# +# To add custom usage windows, you have two options: +# +# OPTION 1: Override windows via provider config (recommended) +# ------------------------------------------------------------ +# Add class attribute to your provider: +# +# default_windows = [ +# WindowDefinition.rolling("1h", 3600, is_primary=False), +# WindowDefinition.rolling("6h", 21600, is_primary=True), +# WindowDefinition.daily("daily"), +# WindowDefinition.total("total"), +# ] +# +# WindowDefinition options: +# - name: str - Window identifier (e.g., "1h", "daily") +# - duration_seconds: int | None - Window duration (None for "total") +# - reset_mode: ResetMode - How window resets +# - ROLLING: Continuous sliding window +# - FIXED_DAILY: Reset at specific UTC time +# - CALENDAR_WEEKLY: Reset at week start +# - CALENDAR_MONTHLY: Reset at month start +# - API_AUTHORITATIVE: Provider determines reset +# - is_primary: bool - Used for rotation decisions +# - applies_to: str - Scope of window +# - "credential": Global per-credential +# - "model": Per-model per-credential +# - "group": Per-quota-group per-credential +# +# OPTION 2: Build config manually in RotatingClient +# ------------------------------------------------- +# In your client initialization: +# +# from rotator_library.usage.config import ( +# ProviderUsageConfig, +# WindowDefinition, +# FairCycleConfig, +# ) +# from rotator_library.usage.types import RotationMode, ResetMode +# +# config = ProviderUsageConfig( +# rotation_mode=RotationMode.SEQUENTIAL, +# windows=[ +# WindowDefinition( +# name="1h", +# duration_seconds=3600, +# reset_mode=ResetMode.ROLLING, +# is_primary=False, +# applies_to="model", +# ), +# WindowDefinition( +# name="6h", +# duration_seconds=21600, +# reset_mode=ResetMode.ROLLING, +# is_primary=True, # Primary for rotation +# applies_to="group", # Track per quota group +# ), +# ], +# fair_cycle=FairCycleConfig( +# enabled=True, +# tracking_mode=TrackingMode.MODEL_GROUP, +# ), +# ) +# +# manager = UsageManager( +# provider="example", +# config=config, +# file_path="usage_example.json", +# ) +# +# ============================================================================= + + +# ============================================================================= +# REGISTERING YOUR PROVIDER +# ============================================================================= +# +# To register your provider with the system: +# +# 1. Add to PROVIDER_PLUGINS dict in src/rotator_library/providers/__init__.py: +# +# from .example_provider import ExampleProvider +# +# PROVIDER_PLUGINS = { +# ... +# "example": ExampleProvider, +# } +# +# 2. Add credential discovery in RotatingClient if using OAuth: +# +# # In _discover_oauth_credentials: +# if provider == "example": +# creds = self._discover_example_credentials() +# +# 3. Configure via environment variables: +# +# # API key credentials +# EXAMPLE_API_KEY=sk-xxx +# EXAMPLE_API_KEY_2=sk-yyy +# +# # OAuth credential paths +# EXAMPLE_OAUTH_PATHS=./creds/example_*.json +# +# # Priority/tier assignment +# PRIORITY_EXAMPLE_CRED1=1 +# TIER_EXAMPLE_CRED2=standard-tier +# +# # Quota group overrides +# QUOTA_GROUPS_EXAMPLE_GPT4=gpt-4o,gpt-4o-mini,gpt-4-turbo +# +# ============================================================================= + + +# ============================================================================= +# ACCESSING USAGE DATA +# ============================================================================= +# +# The usage manager exposes data through several methods: +# +# 1. Get availability stats (for UI/monitoring): +# +# stats = await usage_manager.get_availability_stats(model, quota_group) +# # Returns: { +# # "total": 10, +# # "available": 7, +# # "blocked_by": {"cooldowns": 2, "fair_cycle": 1}, +# # "rotation_mode": "sequential", +# # } +# +# 2. Get comprehensive stats (for quota-stats endpoint): +# +# stats = await usage_manager.get_stats_for_endpoint() +# # Returns full credential/model/group breakdown +# +# 3. Direct state access (for advanced use): +# +# # Get credential state +# state = usage_manager.states.get(stable_id) +# +# # Access usage at different scopes +# global_usage = state.usage +# model_usage = state.model_usage.get("gpt-4o") +# group_usage = state.group_usage.get("gpt4") +# +# # Check cooldowns +# cooldown = state.get_cooldown("gpt4") +# if cooldown and cooldown.is_active: +# print(f"Cooldown remaining: {cooldown.remaining_seconds}s") +# +# # Check fair cycle +# fc = state.fair_cycle.get("gpt4") +# if fc and fc.exhausted: +# print(f"Exhausted at: {fc.exhausted_at}") +# +# 4. Update quota baseline (from API response): +# +# await usage_manager.update_quota_baseline( +# accessor=credential, +# model="gpt-4o", +# quota_max_requests=500, +# quota_reset_ts=time.time() + 3600, +# quota_used=123, +# quota_group="gpt4", +# ) +# +# ============================================================================= diff --git a/src/rotator_library/providers/firmware_provider.py b/src/rotator_library/providers/firmware_provider.py index e71316fa..2b94bd8b 100644 --- a/src/rotator_library/providers/firmware_provider.py +++ b/src/rotator_library/providers/firmware_provider.py @@ -19,7 +19,7 @@ from .utilities.firmware_quota_tracker import FirmwareQuotaTracker if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager import logging @@ -184,12 +184,23 @@ async def refresh_single_credential( # Store baseline in usage manager # Since Firmware.ai uses credential-level quota, we use a virtual model name + if remaining_fraction <= 0.0 and reset_ts: + stable_id = usage_manager.registry.get_stable_id( + api_key, usage_manager.provider + ) + state = usage_manager.states.get(stable_id) + if state: + await usage_manager.tracking.apply_cooldown( + state=state, + reason="quota_exhausted", + until=reset_ts, + model_or_group="firmware/_quota", + source="api_quota", + ) await usage_manager.update_quota_baseline( api_key, "firmware/_quota", # Virtual model for credential-level tracking - remaining_fraction, - # No max_requests - Firmware.ai doesn't expose this - reset_timestamp=reset_ts, + quota_reset_ts=reset_ts, ) lib_logger.debug( @@ -199,7 +210,9 @@ async def refresh_single_credential( ) except Exception as e: - lib_logger.warning(f"Failed to refresh Firmware.ai quota usage: {e}") + lib_logger.warning( + f"Failed to refresh Firmware.ai quota usage: {e}" + ) # Fetch all credentials in parallel with shared HTTP client async with httpx.AsyncClient(timeout=30.0) as client: diff --git a/src/rotator_library/providers/nanogpt_provider.py b/src/rotator_library/providers/nanogpt_provider.py index 52a648e4..456117de 100644 --- a/src/rotator_library/providers/nanogpt_provider.py +++ b/src/rotator_library/providers/nanogpt_provider.py @@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager from .provider_interface import ProviderInterface, UsageResetConfigDef from .utilities.nanogpt_quota_tracker import NanoGptQuotaTracker @@ -74,8 +74,8 @@ class NanoGptProvider(NanoGptQuotaTracker, ProviderInterface): # Active subscriptions get highest priority tier_priorities = { "subscription-active": 1, # Active subscription - "subscription-grace": 2, # Grace period (subscription lapsed but still has access) - "no-subscription": 3, # No active subscription (pay-as-you-go only) + "subscription-grace": 2, # Grace period (subscription lapsed but still has access) + "no-subscription": 3, # No active subscription (pay-as-you-go only) } default_tier_priority = 3 @@ -86,8 +86,6 @@ class NanoGptProvider(NanoGptQuotaTracker, ProviderInterface): "monthly": ["_monthly"], } - - def __init__(self): self.model_definitions = ModelDefinitions() @@ -410,29 +408,49 @@ async def refresh_single_credential( monthly_remaining = monthly_data.get("remaining", 0) # Calculate remaining fractions - daily_fraction = daily_remaining / daily_limit if daily_limit > 0 else 1.0 - monthly_fraction = monthly_remaining / monthly_limit if monthly_limit > 0 else 1.0 + daily_fraction = ( + daily_remaining / daily_limit if daily_limit > 0 else 1.0 + ) + monthly_fraction = ( + monthly_remaining / monthly_limit + if monthly_limit > 0 + else 1.0 + ) # Get reset timestamps daily_reset_ts = daily_data.get("reset_at", 0) monthly_reset_ts = monthly_data.get("reset_at", 0) # Store daily quota baseline + daily_used = ( + int((1.0 - daily_fraction) * daily_limit) + if daily_limit > 0 + else 0 + ) await usage_manager.update_quota_baseline( api_key, "nanogpt/_daily", - daily_fraction, - max_requests=daily_limit, - reset_timestamp=daily_reset_ts if daily_reset_ts > 0 else None, + quota_max_requests=daily_limit, + quota_reset_ts=daily_reset_ts + if daily_reset_ts > 0 + else None, + quota_used=daily_used, ) # Store monthly quota baseline + monthly_used = ( + int((1.0 - monthly_fraction) * monthly_limit) + if monthly_limit > 0 + else 0 + ) await usage_manager.update_quota_baseline( api_key, "nanogpt/_monthly", - monthly_fraction, - max_requests=monthly_limit, - reset_timestamp=monthly_reset_ts if monthly_reset_ts > 0 else None, + quota_max_requests=monthly_limit, + quota_reset_ts=monthly_reset_ts + if monthly_reset_ts > 0 + else None, + quota_used=monthly_used, ) lib_logger.debug( diff --git a/src/rotator_library/providers/provider_interface.py b/src/rotator_library/providers/provider_interface.py index f53f91e9..22056b49 100644 --- a/src/rotator_library/providers/provider_interface.py +++ b/src/rotator_library/providers/provider_interface.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-only # Copyright (c) 2026 Mirrowel -from abc import ABC, abstractmethod +from abc import ABC, ABCMeta, abstractmethod from dataclasses import dataclass from typing import ( List, @@ -19,7 +19,33 @@ import litellm if TYPE_CHECKING: - from ..usage_manager import UsageManager + from ..usage import UsageManager + + +# ============================================================================= +# SINGLETON METACLASS FOR PROVIDERS +# ============================================================================= + + +class SingletonABCMeta(ABCMeta): + """ + Metaclass that combines ABC functionality with singleton pattern. + + All classes using this metaclass (including subclasses of ProviderInterface) + will be singletons - only one instance per class exists. + + This prevents the bug where multiple provider instances are created + by different components (RotatingClient, UsageManager, Hooks, etc.), + each with their own caches and state. + """ + + _instances: Dict[type, Any] = {} + + def __call__(cls, *args, **kwargs): + if cls not in SingletonABCMeta._instances: + SingletonABCMeta._instances[cls] = super().__call__(*args, **kwargs) + return SingletonABCMeta._instances[cls] + from ..config import ( DEFAULT_ROTATION_MODE, @@ -68,7 +94,7 @@ class UsageResetConfigDef: QuotaGroupMap = Dict[str, List[str]] # group_name -> [models] -class ProviderInterface(ABC): +class ProviderInterface(ABC, metaclass=SingletonABCMeta): """ An interface for API provider-specific functionality, including model discovery and custom API call handling for non-standard providers. diff --git a/src/rotator_library/providers/utilities/antigravity_quota_tracker.py b/src/rotator_library/providers/utilities/antigravity_quota_tracker.py index f399222d..8d9730c0 100644 --- a/src/rotator_library/providers/utilities/antigravity_quota_tracker.py +++ b/src/rotator_library/providers/utilities/antigravity_quota_tracker.py @@ -34,7 +34,7 @@ from .base_quota_tracker import BaseQuotaTracker, QUOTA_DISCOVERY_DELAY_SECONDS if TYPE_CHECKING: - from ...usage_manager import UsageManager + from ...usage import UsageManager # Use the shared rotator_library logger lib_logger = logging.getLogger("rotator_library") @@ -990,13 +990,22 @@ async def _store_baselines_to_usage_manager( self, quota_results: Dict[str, Dict[str, Any]], usage_manager: "UsageManager", + force: bool = False, + is_initial_fetch: bool = False, ) -> int: """ Store fetched quota baselines into UsageManager. + Antigravity-specific override: API quota only updates in ~20% increments, + so local tracking is more accurate. Exhaustion check is only applied on + initial fetch (restart), not on subsequent background refreshes. + Args: quota_results: Dict from fetch_quota_from_api or fetch_initial_baselines usage_manager: UsageManager instance to store baselines in + force: If True, always use API values (for manual refresh) + is_initial_fetch: If True, this is the first fetch on startup. + Exhaustion check is only applied when this is True. Returns: Number of baselines successfully stored @@ -1052,16 +1061,37 @@ async def _store_baselines_to_usage_manager( # Extract reset_timestamp (already parsed to float in fetch_quota_from_api) reset_timestamp = model_info.get("reset_timestamp") + # Only use reset_timestamp when quota is actually used + # (remaining == 1.0 means 100% left, timer is bogus) + valid_reset_ts = reset_timestamp if remaining < 1.0 else None + # Store with provider prefix for consistency with usage tracking prefixed_model = f"antigravity/{user_model}" + quota_used = None + if max_requests is not None: + quota_used = int((1.0 - remaining) * max_requests) + quota_group = self.get_model_quota_group(user_model) + + # ANTIGRAVITY-SPECIFIC: Only apply exhaustion on initial fetch + # (API only updates in ~20% increments, so we rely on local tracking + # for subsequent refreshes) + apply_exhaustion = is_initial_fetch and (remaining == 0.0) + cooldown_info = await usage_manager.update_quota_baseline( - cred_path, prefixed_model, remaining, max_requests, reset_timestamp + cred_path, + prefixed_model, + quota_max_requests=max_requests, + quota_reset_ts=valid_reset_ts, + quota_used=quota_used, + quota_group=quota_group, + force=force, + apply_exhaustion=apply_exhaustion, ) # Aggregate cooldown info if returned if cooldown_info: - group_or_model = cooldown_info["group_or_model"] - hours = cooldown_info["hours_until_reset"] + group_or_model = cooldown_info["model"] + hours = cooldown_info["cooldown_hours"] if short_cred not in cooldowns_by_cred: cooldowns_by_cred[short_cred] = {} # Only keep first occurrence per group/model (avoids duplicates) @@ -1073,12 +1103,7 @@ async def _store_baselines_to_usage_manager( # Log consolidated message for all cooldowns if cooldowns_by_cred: - # Build message: "oauth_1[claude 3.4h, gemini-3-pro 2.1h], oauth_2[claude 5.2h]" - parts = [] - for cred_name, groups in sorted(cooldowns_by_cred.items()): - group_strs = [f"{g} {h:.1f}h" for g, h in sorted(groups.items())] - parts.append(f"{cred_name}[{', '.join(group_strs)}]") - lib_logger.info(f"Antigravity quota exhausted: {', '.join(parts)}") + lib_logger.debug("Antigravity quota baseline refresh: cooldowns recorded") else: lib_logger.debug("Antigravity quota baseline refresh: no cooldowns needed") diff --git a/src/rotator_library/providers/utilities/base_quota_tracker.py b/src/rotator_library/providers/utilities/base_quota_tracker.py index a155890f..12ac55c2 100644 --- a/src/rotator_library/providers/utilities/base_quota_tracker.py +++ b/src/rotator_library/providers/utilities/base_quota_tracker.py @@ -40,7 +40,7 @@ from ...utils.paths import get_cache_dir if TYPE_CHECKING: - from ...usage_manager import UsageManager + from ...usage import UsageManager # Use the shared rotator_library logger lib_logger = logging.getLogger("rotator_library") @@ -510,6 +510,8 @@ async def _store_baselines_to_usage_manager( self, quota_results: Dict[str, Dict[str, Any]], usage_manager: "UsageManager", + force: bool = False, + is_initial_fetch: bool = False, ) -> int: """ Store fetched quota baselines into UsageManager. @@ -517,6 +519,10 @@ async def _store_baselines_to_usage_manager( Args: quota_results: Dict from _fetch_quota_for_credential or fetch_initial_baselines usage_manager: UsageManager instance to store baselines in + force: If True, always use API values (for manual refresh) + is_initial_fetch: If True, this is the first fetch on startup. + For default providers (API is authoritative), exhaustion + is applied whenever remaining == 0.0 regardless of this flag. Returns: Number of baselines successfully stored @@ -543,13 +549,47 @@ async def _store_baselines_to_usage_manager( max_requests = self.get_max_requests_for_model(user_model, tier) # Store baseline + quota_used = None + if max_requests is not None: + quota_used = int((1.0 - remaining) * max_requests) + quota_group = self.get_model_quota_group(user_model) + + # Only use reset_timestamp when quota is actually used + # (remaining == 1.0 means 100% left, timer is bogus) + bucket = self._find_bucket_for_model(quota_data, user_model) + reset_timestamp = bucket.get("reset_timestamp") if bucket else None + valid_reset_ts = reset_timestamp if remaining < 1.0 else None + + # DEFAULT: Always apply exhaustion if remaining == 0.0 exactly + # (API is authoritative for most providers) + apply_exhaustion = remaining == 0.0 + await usage_manager.update_quota_baseline( - cred_path, prefixed_model, remaining, max_requests=max_requests + cred_path, + prefixed_model, + quota_max_requests=max_requests, + quota_reset_ts=valid_reset_ts, + quota_used=quota_used, + quota_group=quota_group, + force=force, + apply_exhaustion=apply_exhaustion, ) stored_count += 1 return stored_count + def _find_bucket_for_model( + self, quota_data: Dict[str, Any], user_model: str + ) -> Optional[Dict[str, Any]]: + """Find the bucket data for a specific model in quota response.""" + for bucket in quota_data.get("buckets", []): + model_id = bucket.get("model_id") + if model_id: + bucket_user_model = self._api_to_user_model(model_id) + if bucket_user_model == user_model: + return bucket + return None + # ========================================================================= # QUOTA COST DISCOVERY # ========================================================================= diff --git a/src/rotator_library/providers/utilities/gemini_cli_quota_tracker.py b/src/rotator_library/providers/utilities/gemini_cli_quota_tracker.py index 3f86014f..f100f031 100644 --- a/src/rotator_library/providers/utilities/gemini_cli_quota_tracker.py +++ b/src/rotator_library/providers/utilities/gemini_cli_quota_tracker.py @@ -39,7 +39,7 @@ from .gemini_shared_utils import CODE_ASSIST_ENDPOINT if TYPE_CHECKING: - from ...usage_manager import UsageManager + from ...usage import UsageManager # Use the shared rotator_library logger lib_logger = logging.getLogger("rotator_library") diff --git a/src/rotator_library/providers/utilities/gemini_credential_manager.py b/src/rotator_library/providers/utilities/gemini_credential_manager.py index f8f4dfcc..f427c7b4 100644 --- a/src/rotator_library/providers/utilities/gemini_credential_manager.py +++ b/src/rotator_library/providers/utilities/gemini_credential_manager.py @@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING if TYPE_CHECKING: - from ...usage_manager import UsageManager + from ...usage import UsageManager lib_logger = logging.getLogger("rotator_library") @@ -267,20 +267,27 @@ async def run_background_job( f"{provider_name}: Fetching initial quota baselines for {len(credentials)} credentials..." ) quota_results = await self.fetch_initial_baselines(credentials) + is_initial_fetch = True self._initial_quota_fetch_done = True else: # Subsequent runs: only recently used credentials (incremental updates) - usage_data = await usage_manager._get_usage_data_snapshot() + usage_data = await usage_manager.get_usage_snapshot() quota_results = await self.refresh_active_quota_baselines( credentials, usage_data ) + is_initial_fetch = False if not quota_results: return # Store new baselines in UsageManager + # On initial fetch: force=True overwrites with API data, is_initial_fetch enables exhaustion check + # On subsequent: force=False uses max logic, no exhaustion check stored = await self._store_baselines_to_usage_manager( - quota_results, usage_manager + quota_results, + usage_manager, + force=is_initial_fetch, # Force on initial fetch + is_initial_fetch=is_initial_fetch, ) if stored > 0: lib_logger.debug( @@ -320,7 +327,11 @@ async def refresh_active_quota_baselines( ) async def _store_baselines_to_usage_manager( - self, quota_results: Dict[str, Any], usage_manager: "UsageManager" + self, + quota_results: Dict[str, Any], + usage_manager: "UsageManager", + force: bool = False, + is_initial_fetch: bool = False, ) -> int: """Store quota baselines to usage manager. Must be implemented by quota tracker.""" raise NotImplementedError( diff --git a/src/rotator_library/usage/__init__.py b/src/rotator_library/usage/__init__.py new file mode 100644 index 00000000..f8ae947f --- /dev/null +++ b/src/rotator_library/usage/__init__.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Usage tracking and credential selection package. + +This package provides the UsageManager facade and associated components +for tracking API usage, enforcing limits, and selecting credentials. + +Public API: + UsageManager: Main facade for usage tracking and credential selection + CredentialContext: Context manager for credential lifecycle + +Components (for advanced usage): + CredentialRegistry: Stable credential identity management + TrackingEngine: Usage recording and window management + LimitEngine: Limit checking and enforcement + SelectionEngine: Credential selection with strategies + UsageStorage: JSON file persistence +""" + +# Types first (no dependencies on other modules) +from .types import ( + WindowStats, + TotalStats, + ModelStats, + GroupStats, + CredentialState, + CooldownInfo, + FairCycleState, + UsageUpdate, + SelectionContext, + LimitCheckResult, + RotationMode, + ResetMode, + LimitResult, +) + +# Config +from .config import ( + ProviderUsageConfig, + FairCycleConfig, + CustomCapConfig, + WindowDefinition, + load_provider_usage_config, +) + +# Components +from .identity.registry import CredentialRegistry +from .tracking.windows import WindowManager +from .tracking.engine import TrackingEngine +from .limits.engine import LimitEngine +from .selection.engine import SelectionEngine +from .persistence.storage import UsageStorage +from .integration.api import UsageAPI + +# Main facade (imports components above) +from .manager import UsageManager, CredentialContext + +__all__ = [ + # Main public API + "UsageManager", + "CredentialContext", + # Types + "WindowStats", + "TotalStats", + "ModelStats", + "GroupStats", + "CredentialState", + "UsageUpdate", + "CooldownInfo", + "FairCycleState", + "SelectionContext", + "LimitCheckResult", + "RotationMode", + "ResetMode", + "LimitResult", + # Config + "ProviderUsageConfig", + "FairCycleConfig", + "CustomCapConfig", + "WindowDefinition", + "load_provider_usage_config", + # Engines + "CredentialRegistry", + "WindowManager", + "TrackingEngine", + "LimitEngine", + "SelectionEngine", + "UsageStorage", + "UsageAPI", +] diff --git a/src/rotator_library/usage/config.py b/src/rotator_library/usage/config.py new file mode 100644 index 00000000..d8cf39ed --- /dev/null +++ b/src/rotator_library/usage/config.py @@ -0,0 +1,688 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Default configurations for the usage tracking package. + +This module contains default values and configuration loading +for usage tracking, limits, and credential selection. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union + +from ..core.constants import ( + DEFAULT_FAIR_CYCLE_DURATION, + DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD, + DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, + DEFAULT_ROTATION_TOLERANCE, + DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER, +) +from .types import ResetMode, RotationMode, TrackingMode, CooldownMode + + +# ============================================================================= +# WINDOW CONFIGURATION +# ============================================================================= + + +@dataclass +class WindowDefinition: + """ + Definition of a usage tracking window. + + Used to configure how usage is tracked and when it resets. + """ + + name: str # e.g., "5h", "daily", "weekly" + duration_seconds: Optional[int] # None for infinite/total + reset_mode: ResetMode + is_primary: bool = False # Primary window used for rotation decisions + applies_to: str = "model" # "credential", "model", "group" + + @classmethod + def rolling( + cls, + name: str, + duration_seconds: int, + is_primary: bool = False, + applies_to: str = "model", + ) -> "WindowDefinition": + """Create a rolling window definition.""" + return cls( + name=name, + duration_seconds=duration_seconds, + reset_mode=ResetMode.ROLLING, + is_primary=is_primary, + applies_to=applies_to, + ) + + @classmethod + def daily( + cls, + name: str = "daily", + applies_to: str = "model", + ) -> "WindowDefinition": + """Create a daily fixed window definition.""" + return cls( + name=name, + duration_seconds=86400, + reset_mode=ResetMode.FIXED_DAILY, + applies_to=applies_to, + ) + + +# ============================================================================= +# FAIR CYCLE CONFIGURATION +# ============================================================================= + + +@dataclass +class FairCycleConfig: + """ + Fair cycle rotation configuration. + + Controls how credentials are cycled to ensure fair usage distribution. + """ + + enabled: Optional[bool] = ( + None # None = derive from rotation mode (on for sequential) + ) + tracking_mode: TrackingMode = TrackingMode.MODEL_GROUP + cross_tier: bool = False # Track across all tiers + duration: int = DEFAULT_FAIR_CYCLE_DURATION # Cycle duration in seconds + quota_threshold: float = ( + DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD # Multiplier of window limit for exhaustion + ) + + +# ============================================================================= +# CUSTOM CAP CONFIGURATION +# ============================================================================= + + +def _parse_duration_string(duration_str: str) -> Optional[int]: + """ + Parse duration strings in various formats to total seconds. + + Handles: + - Plain seconds (no unit): '300', '562476' + - Simple durations: '3600s', '60m', '2h', '1d' + - Compound durations: '2h30m', '1h30m45s', '2d1h30m' + + Args: + duration_str: Duration string to parse + + Returns: + Total seconds as integer, or None if parsing fails. + """ + import re + + if not duration_str: + return None + + remaining = duration_str.strip().lower() + + # Try parsing as plain number first (no units) + try: + return int(float(remaining)) + except ValueError: + pass + + total_seconds = 0.0 + + # Parse days component + day_match = re.match(r"(\d+)d", remaining) + if day_match: + total_seconds += int(day_match.group(1)) * 86400 + remaining = remaining[day_match.end() :] + + # Parse hours component + hour_match = re.match(r"(\d+)h", remaining) + if hour_match: + total_seconds += int(hour_match.group(1)) * 3600 + remaining = remaining[hour_match.end() :] + + # Parse minutes component - use negative lookahead to avoid matching 'ms' + min_match = re.match(r"(\d+)m(?!s)", remaining) + if min_match: + total_seconds += int(min_match.group(1)) * 60 + remaining = remaining[min_match.end() :] + + # Parse seconds component (including decimals) + sec_match = re.match(r"([\d.]+)s", remaining) + if sec_match: + total_seconds += float(sec_match.group(1)) + + if total_seconds > 0: + return int(total_seconds) + return None + + +def _parse_cooldown_config( + mode: Optional[str], + value: Any, +) -> Tuple[CooldownMode, int]: + """ + Parse cooldown configuration from config dict values. + + Supports comprehensive cooldown_value parsing: + - Flat duration: 300, "300", "1h", "30m", "1h30m", "2d1h30m" → fixed seconds + - Offset with sign: "+300", "+1h30m", "-300", "-5m" → offset from natural reset + - Percentage: "+50%", "-20%" → percentage of window duration as offset + (stored as negative value with special encoding: -1000 - percentage) + - String "quota_reset" → use natural reset time + + The cooldown_mode is auto-detected from the value format if not explicitly set: + - Starts with '+' or '-' → CooldownMode.OFFSET + - Just a duration → CooldownMode.FIXED + - "quota_reset" string → CooldownMode.QUOTA_RESET + + Args: + mode: Explicit cooldown mode string, or None to auto-detect + value: Cooldown value (int, str, or various formats) + + Returns: + Tuple of (CooldownMode, cooldown_value in seconds) + """ + # Handle explicit mode with simple value + if mode is not None: + try: + cooldown_mode = CooldownMode(mode) + except ValueError: + cooldown_mode = CooldownMode.QUOTA_RESET + + # Parse value + if isinstance(value, int): + return cooldown_mode, value + elif isinstance(value, str): + parsed = _parse_duration_string(value.lstrip("+-")) + return cooldown_mode, parsed or 0 + else: + return cooldown_mode, 0 + + # Auto-detect mode from value format + if isinstance(value, int): + if value == 0: + return CooldownMode.QUOTA_RESET, 0 + return CooldownMode.FIXED, value + + if isinstance(value, str): + value = value.strip() + + # Check for "quota_reset" string + if value.lower() in ("quota_reset", "quota-reset", "quotareset"): + return CooldownMode.QUOTA_RESET, 0 + + # Check for percentage format: "+50%", "-20%" + if value.endswith("%"): + sign = 1 + val_str = value.rstrip("%") + if val_str.startswith("+"): + val_str = val_str[1:] + elif val_str.startswith("-"): + sign = -1 + val_str = val_str[1:] + try: + percentage = int(val_str) + # Encode percentage as special value: -1000 - (sign * percentage) + # This allows the custom_caps checker to detect and handle percentages + # Range: -1001 to -1100 for +1% to +100%, -999 to -900 for -1% to -100% + encoded = -1000 - (sign * percentage) + return CooldownMode.OFFSET, encoded + except ValueError: + pass + + # Check for offset format: "+300", "+1h30m", "-5m" + if value.startswith("+") or value.startswith("-"): + sign = 1 if value.startswith("+") else -1 + duration_str = value[1:] + parsed = _parse_duration_string(duration_str) + if parsed is not None: + return CooldownMode.OFFSET, sign * parsed + return CooldownMode.OFFSET, 0 + + # Plain duration: "300", "1h30m" + parsed = _parse_duration_string(value) + if parsed is not None: + return CooldownMode.FIXED, parsed + + return CooldownMode.QUOTA_RESET, 0 + + +@dataclass +class CustomCapConfig: + """ + Custom cap configuration for a tier/model combination. + + Allows setting usage limits more restrictive than actual API limits. + """ + + tier_key: str # Priority as string or "default" + model_or_group: str # Model name or quota group name + max_requests: int # Maximum requests allowed + cooldown_mode: CooldownMode = CooldownMode.QUOTA_RESET + cooldown_value: int = 0 # Seconds for offset/fixed modes + + @classmethod + def from_dict( + cls, tier_key: str, model_or_group: str, config: Dict[str, Any] + ) -> "CustomCapConfig": + """ + Create from dictionary config. + + Supports comprehensive cooldown_value parsing: + - Flat duration: 300, "300", "1h", "30m", "1h30m", "2d1h30m" → fixed seconds + - Offset with sign: "+300", "+1h30m", "-300", "-5m" → offset from natural reset + - Percentage: "+50%", "-20%" → percentage of window duration as offset + - String "quota_reset" → use natural reset time + + The cooldown_mode is auto-detected from the value format: + - Starts with '+' or '-' → CooldownMode.OFFSET + - Just a duration → CooldownMode.FIXED + - "quota_reset" string → CooldownMode.QUOTA_RESET + """ + max_requests = config.get("max_requests", 0) + + # Handle percentage strings like "80%" + if isinstance(max_requests, str) and max_requests.endswith("%"): + # Store as negative to indicate percentage + # Will be resolved later when actual limit is known + max_requests = -int(max_requests.rstrip("%")) + + # Parse cooldown configuration + cooldown_mode, cooldown_value = _parse_cooldown_config( + config.get("cooldown_mode"), + config.get("cooldown_value", 0), + ) + + return cls( + tier_key=tier_key, + model_or_group=model_or_group, + max_requests=max_requests, + cooldown_mode=cooldown_mode, + cooldown_value=cooldown_value, + ) + + +# ============================================================================= +# PROVIDER USAGE CONFIG +# ============================================================================= + + +@dataclass +class ProviderUsageConfig: + """ + Complete usage configuration for a provider. + + Combines all settings needed for usage tracking and credential selection. + """ + + # Rotation settings + rotation_mode: RotationMode = RotationMode.BALANCED + rotation_tolerance: float = DEFAULT_ROTATION_TOLERANCE + sequential_fallback_multiplier: int = DEFAULT_SEQUENTIAL_FALLBACK_MULTIPLIER + + # Priority multipliers (priority -> max concurrent) + priority_multipliers: Dict[int, int] = field(default_factory=dict) + priority_multipliers_by_mode: Dict[str, Dict[int, int]] = field( + default_factory=dict + ) + + # Fair cycle + fair_cycle: FairCycleConfig = field(default_factory=FairCycleConfig) + + # Custom caps + custom_caps: List[CustomCapConfig] = field(default_factory=list) + + # Exhaustion threshold (cooldown must exceed this to count as "exhausted") + exhaustion_cooldown_threshold: int = DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD + + # Window limits blocking (if True, block credentials when window quota exhausted locally) + # Default False: only API errors (cooldowns) should block, not local tracking + window_limits_enabled: bool = False + + # Window definitions + windows: List[WindowDefinition] = field(default_factory=list) + + def get_effective_multiplier(self, priority: int) -> int: + """ + Get the effective multiplier for a priority level. + + Checks mode-specific overrides first, then universal multipliers, + then falls back to sequential_fallback_multiplier. + """ + mode_key = self.rotation_mode.value + mode_multipliers = self.priority_multipliers_by_mode.get(mode_key, {}) + + # Check mode-specific first + if priority in mode_multipliers: + return mode_multipliers[priority] + + # Check universal + if priority in self.priority_multipliers: + return self.priority_multipliers[priority] + + # Fall back + return self.sequential_fallback_multiplier + + +# ============================================================================= +# DEFAULT WINDOWS +# ============================================================================= + + +def get_default_windows() -> List[WindowDefinition]: + """ + Get default window definitions. + + Only used when provider doesn't define custom windows via + usage_reset_configs or get_usage_reset_config(). + """ + return [ + WindowDefinition.rolling("daily", 86400, is_primary=True, applies_to="model"), + ] + + +# ============================================================================= +# CONFIG LOADER INTEGRATION +# ============================================================================= + + +def load_provider_usage_config( + provider: str, + provider_plugins: Dict[str, Any], +) -> ProviderUsageConfig: + """ + Load usage configuration for a provider. + + Merges: + 1. System defaults + 2. Provider class attributes + 3. Environment variables (always win) + + Args: + provider: Provider name (e.g., "gemini", "openai") + provider_plugins: Dict of provider plugin classes + + Returns: + Complete configuration for the provider + """ + import os + + config = ProviderUsageConfig() + + # Get plugin class + plugin_class = provider_plugins.get(provider) + + # Apply provider defaults + if plugin_class: + # Rotation mode + if hasattr(plugin_class, "default_rotation_mode"): + config.rotation_mode = RotationMode(plugin_class.default_rotation_mode) + + # Priority multipliers + if hasattr(plugin_class, "default_priority_multipliers"): + config.priority_multipliers = dict( + plugin_class.default_priority_multipliers + ) + + if hasattr(plugin_class, "default_priority_multipliers_by_mode"): + config.priority_multipliers_by_mode = { + k: dict(v) + for k, v in plugin_class.default_priority_multipliers_by_mode.items() + } + + # Sequential fallback multiplier + if hasattr(plugin_class, "default_sequential_fallback_multiplier"): + fallback = plugin_class.default_sequential_fallback_multiplier + if fallback is not None: + config.sequential_fallback_multiplier = fallback + + # Fair cycle + if hasattr(plugin_class, "default_fair_cycle_config"): + fc_config = plugin_class.default_fair_cycle_config + config.fair_cycle = FairCycleConfig( + enabled=fc_config.get("enabled"), + tracking_mode=TrackingMode( + fc_config.get("tracking_mode", "model_group") + ), + cross_tier=fc_config.get("cross_tier", False), + duration=fc_config.get("duration", DEFAULT_FAIR_CYCLE_DURATION), + quota_threshold=fc_config.get( + "quota_threshold", DEFAULT_FAIR_CYCLE_QUOTA_THRESHOLD + ), + ) + else: + if hasattr(plugin_class, "default_fair_cycle_enabled"): + config.fair_cycle.enabled = plugin_class.default_fair_cycle_enabled + if hasattr(plugin_class, "default_fair_cycle_tracking_mode"): + config.fair_cycle.tracking_mode = TrackingMode( + plugin_class.default_fair_cycle_tracking_mode + ) + if hasattr(plugin_class, "default_fair_cycle_cross_tier"): + config.fair_cycle.cross_tier = ( + plugin_class.default_fair_cycle_cross_tier + ) + if hasattr(plugin_class, "default_fair_cycle_duration"): + config.fair_cycle.duration = plugin_class.default_fair_cycle_duration + if hasattr(plugin_class, "default_fair_cycle_quota_threshold"): + config.fair_cycle.quota_threshold = ( + plugin_class.default_fair_cycle_quota_threshold + ) + + # Custom caps + if hasattr(plugin_class, "default_custom_caps"): + for tier_key, models in plugin_class.default_custom_caps.items(): + tier_keys: Tuple[Union[int, str], ...] + if isinstance(tier_key, tuple): + tier_keys = tuple(tier_key) + else: + tier_keys = (tier_key,) + for model_or_group, cap_config in models.items(): + for resolved_tier in tier_keys: + config.custom_caps.append( + CustomCapConfig.from_dict( + str(resolved_tier), model_or_group, cap_config + ) + ) + + # Windows + if hasattr(plugin_class, "usage_window_definitions"): + config.windows = [] + for wdef in plugin_class.usage_window_definitions: + config.windows.append( + WindowDefinition( + name=wdef.get("name", "default"), + duration_seconds=wdef.get("duration_seconds"), + reset_mode=ResetMode(wdef.get("reset_mode", "rolling")), + is_primary=wdef.get("is_primary", False), + applies_to=wdef.get("applies_to", "model"), + ) + ) + + # Use default windows if none defined + if not config.windows: + config.windows = get_default_windows() + + # Apply environment variable overrides + provider_upper = provider.upper() + + # Rotation mode from env + env_mode = os.getenv(f"ROTATION_MODE_{provider_upper}") + if env_mode: + config.rotation_mode = RotationMode(env_mode.lower()) + + # Sequential fallback multiplier + env_fallback = os.getenv(f"SEQUENTIAL_FALLBACK_MULTIPLIER_{provider_upper}") + if env_fallback: + try: + config.sequential_fallback_multiplier = int(env_fallback) + except ValueError: + pass + + # Fair cycle enabled from env + env_fc = os.getenv(f"FAIR_CYCLE_{provider_upper}") + if env_fc is None: + env_fc = os.getenv(f"FAIR_CYCLE_ENABLED_{provider_upper}") + if env_fc: + config.fair_cycle.enabled = env_fc.lower() in ("true", "1", "yes") + + # Fair cycle tracking mode + env_fc_mode = os.getenv(f"FAIR_CYCLE_TRACKING_MODE_{provider_upper}") + if env_fc_mode: + try: + config.fair_cycle.tracking_mode = TrackingMode(env_fc_mode.lower()) + except ValueError: + pass + + # Fair cycle cross-tier + env_fc_cross = os.getenv(f"FAIR_CYCLE_CROSS_TIER_{provider_upper}") + if env_fc_cross: + config.fair_cycle.cross_tier = env_fc_cross.lower() in ("true", "1", "yes") + + # Fair cycle duration from env + env_fc_duration = os.getenv(f"FAIR_CYCLE_DURATION_{provider_upper}") + if env_fc_duration: + try: + config.fair_cycle.duration = int(env_fc_duration) + except ValueError: + pass + + # Fair cycle quota threshold from env + env_fc_quota = os.getenv(f"FAIR_CYCLE_QUOTA_THRESHOLD_{provider_upper}") + if env_fc_quota: + try: + config.fair_cycle.quota_threshold = float(env_fc_quota) + except ValueError: + pass + + # Exhaustion threshold from env + env_threshold = os.getenv(f"EXHAUSTION_COOLDOWN_THRESHOLD_{provider_upper}") + if env_threshold: + try: + config.exhaustion_cooldown_threshold = int(env_threshold) + except ValueError: + pass + + # Priority multipliers from env + # Format: CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}=value + # Format: CONCURRENCY_MULTIPLIER_{PROVIDER}_PRIORITY_{N}_{MODE}=value + for key, value in os.environ.items(): + prefix = f"CONCURRENCY_MULTIPLIER_{provider_upper}_PRIORITY_" + if key.startswith(prefix): + try: + remainder = key[len(prefix) :] + multiplier = int(value) + if multiplier < 1: + continue + if "_" in remainder: + priority_str, mode = remainder.rsplit("_", 1) + priority = int(priority_str) + mode = mode.lower() + if mode in ("sequential", "balanced"): + config.priority_multipliers_by_mode.setdefault(mode, {})[ + priority + ] = multiplier + else: + config.priority_multipliers[priority] = multiplier + else: + priority = int(remainder) + config.priority_multipliers[priority] = multiplier + except ValueError: + pass + + # Custom caps from env + if os.environ: + cap_map: Dict[str, Dict[str, Dict[str, Any]]] = {} + for cap in config.custom_caps: + cap_entry = cap_map.setdefault(str(cap.tier_key), {}) + cap_entry[cap.model_or_group] = { + "max_requests": cap.max_requests, + "cooldown_mode": cap.cooldown_mode.value, + "cooldown_value": cap.cooldown_value, + } + + cap_prefix = f"CUSTOM_CAP_{provider_upper}_T" + cooldown_prefix = f"CUSTOM_CAP_COOLDOWN_{provider_upper}_T" + for env_key, env_value in os.environ.items(): + if env_key.startswith(cap_prefix) and not env_key.startswith( + cooldown_prefix + ): + remainder = env_key[len(cap_prefix) :] + tier_key, model_key = _parse_custom_cap_env_key(remainder) + if tier_key is None or not model_key: + continue + cap_entry = cap_map.setdefault(str(tier_key), {}) + cap_entry.setdefault(model_key, {})["max_requests"] = env_value + elif env_key.startswith(cooldown_prefix): + remainder = env_key[len(cooldown_prefix) :] + tier_key, model_key = _parse_custom_cap_env_key(remainder) + if tier_key is None or not model_key: + continue + if ":" in env_value: + mode, value_str = env_value.split(":", 1) + try: + value = int(value_str) + except ValueError: + continue + else: + mode = env_value + value = 0 + cap_entry = cap_map.setdefault(str(tier_key), {}) + cap_entry.setdefault(model_key, {})["cooldown_mode"] = mode + cap_entry.setdefault(model_key, {})["cooldown_value"] = value + + config.custom_caps = [] + for tier_key, models in cap_map.items(): + for model_or_group, cap_config in models.items(): + config.custom_caps.append( + CustomCapConfig.from_dict(tier_key, model_or_group, cap_config) + ) + + # Derive fair cycle enabled from rotation mode if not explicitly set + if config.fair_cycle.enabled is None: + config.fair_cycle.enabled = config.rotation_mode == RotationMode.SEQUENTIAL + + return config + + +def _parse_custom_cap_env_key( + remainder: str, +) -> Tuple[Optional[Union[int, Tuple[int, ...], str]], Optional[str]]: + """Parse the tier and model/group from a custom cap env var remainder.""" + if not remainder: + return None, None + + remaining_parts = remainder.split("_") + if len(remaining_parts) < 2: + return None, None + + tier_key: Union[int, Tuple[int, ...], str, None] = None + model_key: Optional[str] = None + tier_parts: List[int] = [] + + for i, part in enumerate(remaining_parts): + if part == "DEFAULT": + tier_key = "default" + model_key = "_".join(remaining_parts[i + 1 :]) + break + if part.isdigit(): + tier_parts.append(int(part)) + continue + + if not tier_parts: + return None, None + if len(tier_parts) == 1: + tier_key = tier_parts[0] + else: + tier_key = tuple(tier_parts) + model_key = "_".join(remaining_parts[i:]) + break + else: + return None, None + + if model_key: + model_key = model_key.lower().replace("_", "-") + + return tier_key, model_key diff --git a/src/rotator_library/usage/identity/__init__.py b/src/rotator_library/usage/identity/__init__.py new file mode 100644 index 00000000..b42af7ff --- /dev/null +++ b/src/rotator_library/usage/identity/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Credential identity management.""" + +from .registry import CredentialRegistry + +__all__ = ["CredentialRegistry"] diff --git a/src/rotator_library/usage/identity/registry.py b/src/rotator_library/usage/identity/registry.py new file mode 100644 index 00000000..ddd867d4 --- /dev/null +++ b/src/rotator_library/usage/identity/registry.py @@ -0,0 +1,270 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Credential identity registry. + +Provides stable identifiers for credentials that persist across +file path changes (for OAuth) and hide sensitive data (for API keys). +""" + +import hashlib +import json +import logging +from pathlib import Path +from typing import Any, Dict, Optional, Set + +from ...core.types import CredentialInfo + +lib_logger = logging.getLogger("rotator_library") + + +class CredentialRegistry: + """ + Manages stable identifiers for credentials. + + Stable IDs are: + - For OAuth credentials: The email address from _proxy_metadata.email + - For API keys: SHA-256 hash of the key (truncated for readability) + + This ensures usage data persists even when: + - OAuth credential files are moved/renamed + - API keys are passed in different orders + """ + + def __init__(self): + # Cache: accessor -> CredentialInfo + self._cache: Dict[str, CredentialInfo] = {} + # Reverse index: stable_id -> accessor + self._id_to_accessor: Dict[str, str] = {} + + def get_stable_id(self, accessor: str, provider: str) -> str: + """ + Get or create a stable ID for a credential accessor. + + Args: + accessor: The credential accessor (file path or API key) + provider: Provider name + + Returns: + Stable identifier string + """ + # Check cache first + if accessor in self._cache: + return self._cache[accessor].stable_id + + # Determine if OAuth or API key + if self._is_oauth_path(accessor): + stable_id = self._get_oauth_stable_id(accessor) + else: + stable_id = self._get_api_key_stable_id(accessor) + + # Cache the result + info = CredentialInfo( + accessor=accessor, + stable_id=stable_id, + provider=provider, + ) + self._cache[accessor] = info + self._id_to_accessor[stable_id] = accessor + + return stable_id + + def get_info(self, accessor: str, provider: str) -> CredentialInfo: + """ + Get complete credential info for an accessor. + + Args: + accessor: The credential accessor + provider: Provider name + + Returns: + CredentialInfo with stable_id and metadata + """ + # Ensure stable ID is computed + self.get_stable_id(accessor, provider) + return self._cache[accessor] + + def get_accessor(self, stable_id: str) -> Optional[str]: + """ + Get the current accessor for a stable ID. + + Args: + stable_id: The stable identifier + + Returns: + Current accessor string, or None if not found + """ + return self._id_to_accessor.get(stable_id) + + def update_accessor(self, stable_id: str, new_accessor: str) -> None: + """ + Update the accessor for a stable ID. + + Used when an OAuth credential file is moved/renamed. + + Args: + stable_id: The stable identifier + new_accessor: New accessor path + """ + old_accessor = self._id_to_accessor.get(stable_id) + if old_accessor and old_accessor in self._cache: + info = self._cache.pop(old_accessor) + info.accessor = new_accessor + self._cache[new_accessor] = info + self._id_to_accessor[stable_id] = new_accessor + + def update_metadata( + self, + accessor: str, + provider: str, + tier: Optional[str] = None, + priority: Optional[int] = None, + display_name: Optional[str] = None, + ) -> None: + """ + Update metadata for a credential. + + Args: + accessor: The credential accessor + provider: Provider name + tier: Tier name (e.g., "standard-tier") + priority: Priority level (lower = higher priority) + display_name: Human-readable name + """ + info = self.get_info(accessor, provider) + if tier is not None: + info.tier = tier + if priority is not None: + info.priority = priority + if display_name is not None: + info.display_name = display_name + + def get_all_accessors(self) -> Set[str]: + """Get all registered accessors.""" + return set(self._cache.keys()) + + def get_all_stable_ids(self) -> Set[str]: + """Get all registered stable IDs.""" + return set(self._id_to_accessor.keys()) + + def clear_cache(self) -> None: + """Clear the internal cache.""" + self._cache.clear() + self._id_to_accessor.clear() + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _is_oauth_path(self, accessor: str) -> bool: + """ + Check if accessor is an OAuth credential file path. + + OAuth paths typically end with .json and exist on disk. + API keys are typically raw strings. + """ + # Simple heuristic: if it looks like a file path with .json, it's OAuth + if accessor.endswith(".json"): + return True + # If it contains path separators, it's likely a file path + if "/" in accessor or "\\" in accessor: + return True + return False + + def _get_oauth_stable_id(self, accessor: str) -> str: + """ + Get stable ID for an OAuth credential. + + Reads the email from _proxy_metadata.email in the credential file. + Falls back to file hash if email not found. + """ + try: + path = Path(accessor) + if path.exists(): + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + + # Try to get email from _proxy_metadata + metadata = data.get("_proxy_metadata", {}) + email = metadata.get("email") + if email: + return email + + # Fallback: try common OAuth fields + for field in ["email", "client_email", "account"]: + if field in data: + return data[field] + + # Last resort: hash the file content + lib_logger.debug( + f"No email found in OAuth credential {accessor}, using content hash" + ) + return self._hash_content(json.dumps(data, sort_keys=True)) + + except Exception as e: + lib_logger.warning(f"Failed to read OAuth credential {accessor}: {e}") + + # Fallback: hash the path + return self._hash_content(accessor) + + def _get_api_key_stable_id(self, accessor: str) -> str: + """ + Get stable ID for an API key. + + Uses truncated SHA-256 hash to hide the actual key. + """ + return self._hash_content(accessor) + + def _hash_content(self, content: str) -> str: + """ + Create a stable hash of content. + + Uses first 12 characters of SHA-256 for readability. + """ + return hashlib.sha256(content.encode()).hexdigest()[:12] + + # ========================================================================= + # SERIALIZATION + # ========================================================================= + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize registry state for persistence. + + Returns: + Dictionary suitable for JSON serialization + """ + return { + "accessor_index": dict(self._id_to_accessor), + "credentials": { + accessor: { + "stable_id": info.stable_id, + "provider": info.provider, + "tier": info.tier, + "priority": info.priority, + "display_name": info.display_name, + } + for accessor, info in self._cache.items() + }, + } + + def from_dict(self, data: Dict[str, Any]) -> None: + """ + Restore registry state from persistence. + + Args: + data: Dictionary from to_dict() + """ + self._id_to_accessor = dict(data.get("accessor_index", {})) + + for accessor, cred_data in data.get("credentials", {}).items(): + info = CredentialInfo( + accessor=accessor, + stable_id=cred_data["stable_id"], + provider=cred_data["provider"], + tier=cred_data.get("tier"), + priority=cred_data.get("priority", 999), + display_name=cred_data.get("display_name"), + ) + self._cache[accessor] = info diff --git a/src/rotator_library/usage/integration/__init__.py b/src/rotator_library/usage/integration/__init__.py new file mode 100644 index 00000000..4e58bdd8 --- /dev/null +++ b/src/rotator_library/usage/integration/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Integration helpers for usage manager.""" + +from .hooks import HookDispatcher +from .api import UsageAPI + +__all__ = ["HookDispatcher", "UsageAPI"] diff --git a/src/rotator_library/usage/integration/api.py b/src/rotator_library/usage/integration/api.py new file mode 100644 index 00000000..30f988c6 --- /dev/null +++ b/src/rotator_library/usage/integration/api.py @@ -0,0 +1,377 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Usage API Facade for Reading and Updating Usage Data. + +This module provides a clean, public API for programmatically interacting with +usage data. It's accessible via `usage_manager.api` and is intended for: + + - Admin endpoints (viewing/modifying credential state) + - Background jobs (quota refresh, cleanup tasks) + - Monitoring and alerting (checking remaining quota) + - External tooling and integrations + - Provider-specific logic that needs to inspect/modify state + +============================================================================= +ACCESSING THE API +============================================================================= + +The UsageAPI is available as a property on UsageManager: + + # From RotatingClient + usage_manager = client.get_usage_manager("my_provider") + api = usage_manager.api + + # Or if you have the manager directly + api = usage_manager.api + +============================================================================= +AVAILABLE METHODS +============================================================================= + +Reading State +------------- + + # Get state for a specific credential + state = api.get_state("path/to/credential.json") + if state: + print(f"Total requests: {state.totals.request_count}") + print(f"Total successes: {state.totals.success_count}") + print(f"Total failures: {state.totals.failure_count}") + + # Get all credential states + all_states = api.get_all_states() + for stable_id, state in all_states.items(): + print(f"{stable_id}: {state.totals.request_count} requests") + + # Check remaining quota in a window + remaining = api.get_window_remaining( + accessor="path/to/credential.json", + window_name="5h", + model="gpt-4o", # Optional: specific model + quota_group="gpt4", # Optional: quota group + ) + print(f"Remaining in 5h window: {remaining}") + +Modifying State +--------------- + + # Apply a manual cooldown + await api.apply_cooldown( + accessor="path/to/credential.json", + duration=1800.0, # 30 minutes + reason="manual_override", + model_or_group="gpt4", # Optional: scope to model/group + ) + + # Clear a cooldown + await api.clear_cooldown( + accessor="path/to/credential.json", + model_or_group="gpt4", # Optional: scope + ) + + # Mark credential as exhausted for fair cycle + await api.mark_exhausted( + accessor="path/to/credential.json", + model_or_group="gpt4", + reason="quota_exceeded", + ) + +============================================================================= +CREDENTIAL STATE STRUCTURE +============================================================================= + +CredentialState contains: + + state.accessor # File path or API key + state.display_name # Human-readable name (e.g., email) + state.tier # Tier name (e.g., "standard-tier") + state.priority # Priority level (1 = highest) + state.active_requests # Currently in-flight requests + + state.totals # TotalStats - credential-level totals + state.model_usage # Dict[model, ModelStats] + state.group_usage # Dict[group, GroupStats] + + state.cooldowns # Dict[key, CooldownState] + state.fair_cycle # Dict[key, FairCycleState] + +ModelStats / GroupStats contain: + + stats.windows # Dict[name, WindowStats] - time-based windows + stats.totals # TotalStats - all-time totals for this scope + +TotalStats contains: + + totals.request_count + totals.success_count + totals.failure_count + totals.prompt_tokens + totals.completion_tokens + totals.thinking_tokens + totals.output_tokens + totals.prompt_tokens_cache_read + totals.prompt_tokens_cache_write + totals.total_tokens + totals.approx_cost + totals.first_used_at + totals.last_used_at + +WindowStats contains: + + window.request_count + window.success_count + window.failure_count + window.prompt_tokens + window.completion_tokens + window.thinking_tokens + window.output_tokens + window.prompt_tokens_cache_read + window.prompt_tokens_cache_write + window.total_tokens + window.approx_cost + window.started_at + window.reset_at + window.limit + window.remaining # Computed: limit - request_count (if limit set) + +============================================================================= +EXAMPLE: BUILDING AN ADMIN ENDPOINT +============================================================================= + + from fastapi import APIRouter + from rotator_library import RotatingClient + + router = APIRouter() + + @router.get("/admin/credentials/{provider}") + async def list_credentials(provider: str): + usage_manager = client.get_usage_manager(provider) + if not usage_manager: + return {"error": "Provider not found"} + + api = usage_manager.api + result = [] + + for stable_id, state in api.get_all_states().items(): + result.append({ + "id": stable_id, + "accessor": state.accessor, + "tier": state.tier, + "priority": state.priority, + "requests": state.totals.request_count, + "successes": state.totals.success_count, + "failures": state.totals.failure_count, + "cooldowns": [ + {"key": k, "remaining": v.remaining_seconds} + for k, v in state.cooldowns.items() + if v.is_active + ], + }) + + return {"credentials": result} + + @router.post("/admin/credentials/{provider}/{accessor}/cooldown") + async def apply_cooldown(provider: str, accessor: str, duration: float): + usage_manager = client.get_usage_manager(provider) + api = usage_manager.api + await api.apply_cooldown(accessor, duration, reason="admin") + return {"status": "cooldown applied"} + +============================================================================= +""" + +from typing import Any, Dict, Optional, TYPE_CHECKING + +from ..types import CredentialState + +if TYPE_CHECKING: + from ..manager import UsageManager + + +class UsageAPI: + """ + Public API facade for reading and updating usage data. + + Provides a clean interface for external code to interact with usage + tracking without needing to understand the internal component structure. + + Access via: usage_manager.api + + Example: + api = usage_manager.api + state = api.get_state("path/to/credential.json") + remaining = api.get_window_remaining("path/to/cred.json", "5h", "gpt-4o") + await api.apply_cooldown("path/to/cred.json", 1800.0, "manual") + """ + + def __init__(self, manager: "UsageManager"): + """ + Initialize the API facade. + + Args: + manager: The UsageManager instance to wrap. + """ + self._manager = manager + + def get_state(self, accessor: str) -> Optional[CredentialState]: + """ + Get the credential state for a given accessor. + + Args: + accessor: Credential file path or API key. + + Returns: + CredentialState if found, None otherwise. + + Example: + state = api.get_state("oauth_creds/my_cred.json") + if state: + print(f"Requests: {state.totals.request_count}") + """ + stable_id = self._manager.registry.get_stable_id( + accessor, self._manager.provider + ) + return self._manager.states.get(stable_id) + + def get_all_states(self) -> Dict[str, CredentialState]: + """ + Get all credential states. + + Returns: + Dict mapping stable_id to CredentialState. + + Example: + for stable_id, state in api.get_all_states().items(): + print(f"{stable_id}: {state.totals.request_count} requests") + """ + return dict(self._manager.states) + + def get_window_remaining( + self, + accessor: str, + window_name: str, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> Optional[int]: + """ + Get remaining requests in a usage window. + + Args: + accessor: Credential file path or API key. + window_name: Window name (e.g., "5h", "daily"). + model: Optional model to check (uses model-specific window). + quota_group: Optional quota group to check. + + Returns: + Remaining requests (limit - used), or None if: + - Credential not found + - Window has no limit set + + Example: + remaining = api.get_window_remaining("cred.json", "5h", model="gpt-4o") + if remaining is not None and remaining < 10: + print("Warning: low quota remaining") + """ + state = self.get_state(accessor) + if not state: + return None + return self._manager.limits.window_checker.get_remaining( + state, window_name, model=model, quota_group=quota_group + ) + + async def apply_cooldown( + self, + accessor: str, + duration: float, + reason: str = "manual", + model_or_group: Optional[str] = None, + ) -> None: + """ + Apply a cooldown to a credential. + + The credential will not be selected for requests until the cooldown + expires or is cleared. + + Args: + accessor: Credential file path or API key. + duration: Cooldown duration in seconds. + reason: Reason for cooldown (for logging/debugging). + model_or_group: Optional scope (model name or quota group). + If None, applies to credential globally. + + Example: + # Global cooldown + await api.apply_cooldown("cred.json", 1800.0, "maintenance") + + # Model-specific cooldown + await api.apply_cooldown("cred.json", 3600.0, "quota", "gpt-4o") + """ + await self._manager.apply_cooldown( + accessor=accessor, + duration=duration, + reason=reason, + model_or_group=model_or_group, + ) + + async def clear_cooldown( + self, + accessor: str, + model_or_group: Optional[str] = None, + ) -> None: + """ + Clear a cooldown from a credential. + + Args: + accessor: Credential file path or API key. + model_or_group: Optional scope to clear. If None, clears all. + + Example: + # Clear specific cooldown + await api.clear_cooldown("cred.json", "gpt-4o") + + # Clear all cooldowns + await api.clear_cooldown("cred.json") + """ + stable_id = self._manager.registry.get_stable_id( + accessor, self._manager.provider + ) + state = self._manager.states.get(stable_id) + if state: + await self._manager.tracking.clear_cooldown( + state=state, + model_or_group=model_or_group, + ) + + async def mark_exhausted( + self, + accessor: str, + model_or_group: str, + reason: str, + ) -> None: + """ + Mark a credential as exhausted for fair cycle. + + The credential will be skipped during selection until all other + credentials in the same tier are also exhausted, at which point + the fair cycle resets. + + Args: + accessor: Credential file path or API key. + model_or_group: Model name or quota group to mark exhausted. + reason: Reason for exhaustion (for logging/debugging). + + Example: + await api.mark_exhausted("cred.json", "gpt-4o", "quota_exceeded") + """ + stable_id = self._manager.registry.get_stable_id( + accessor, self._manager.provider + ) + state = self._manager.states.get(stable_id) + if state: + await self._manager.tracking.mark_exhausted( + state=state, + model_or_group=model_or_group, + reason=reason, + ) diff --git a/src/rotator_library/usage/integration/hooks.py b/src/rotator_library/usage/integration/hooks.py new file mode 100644 index 00000000..1f59446d --- /dev/null +++ b/src/rotator_library/usage/integration/hooks.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Provider Hook Dispatcher for Usage Manager. + +This module bridges provider plugins to the usage manager, allowing providers +to customize how requests are counted, cooled down, and tracked. + +============================================================================= +OVERVIEW +============================================================================= + +The HookDispatcher calls provider hooks at key points in the request lifecycle. +Currently, the main hook is `on_request_complete`, which is called after every +request (success or failure) and allows the provider to override: + + - Request count (how many requests to record) + - Cooldown duration (custom cooldown to apply) + - Exhaustion state (mark credential for fair cycle) + +============================================================================= +IMPLEMENTING on_request_complete IN YOUR PROVIDER +============================================================================= + +Add this method to your provider class: + + from rotator_library.core.types import RequestCompleteResult + + def on_request_complete( + self, + credential: str, + model: str, + success: bool, + response: Optional[Any], + error: Optional[Any], + ) -> Optional[RequestCompleteResult]: + ''' + Called after each request completes. + + Args: + credential: Credential accessor (file path or API key) + model: Model that was called + success: Whether the request succeeded + response: Response object (if success=True) + error: ClassifiedError object (if success=False) + + Returns: + RequestCompleteResult to override behavior, or None for defaults. + ''' + # Your logic here + return None # Use default behavior + +============================================================================= +RequestCompleteResult FIELDS +============================================================================= + + count_override: Optional[int] + How many requests to count for usage tracking. + - 0 = Don't count this request (e.g., server errors) + - N = Count as N requests (e.g., internal retries) + - None = Use default (1) + + cooldown_override: Optional[float] + Seconds to cool down this credential. + - Applied in addition to any error-based cooldown. + - Use for custom rate limiting logic. + + force_exhausted: bool + Mark credential as exhausted for fair cycle. + - True = Skip this credential until fair cycle resets. + - Useful for quota errors without long cooldowns. + +============================================================================= +USE CASE: COUNTING INTERNAL RETRIES +============================================================================= + +If your provider performs internal retries (e.g., for transient errors, empty +responses, or malformed responses), each retry is an API call that should be +counted. Use the ContextVar pattern for thread-safe counting: + + from contextvars import ContextVar + from rotator_library.core.types import RequestCompleteResult + + # Module-level: each async task gets its own isolated value + _internal_attempt_count: ContextVar[int] = ContextVar( + 'my_provider_attempt_count', default=1 + ) + + class MyProvider: + + async def _make_request_with_retry(self, ...): + # Reset at start of request + _internal_attempt_count.set(1) + + for attempt in range(max_attempts): + try: + result = await self._call_api(...) + return result # Success + except RetryableError: + # Increment before retry + _internal_attempt_count.set(_internal_attempt_count.get() + 1) + continue + + def on_request_complete(self, credential, model, success, response, error): + # Report actual API call count + count = _internal_attempt_count.get() + _internal_attempt_count.set(1) # Reset for safety + + if count > 1: + logging.debug(f"Request used {count} API calls (internal retries)") + + return RequestCompleteResult(count_override=count) + +Why ContextVar? + - Instance variables (self.count) are shared across concurrent requests + - ContextVar gives each async task its own isolated value + - Thread-safe without explicit locking + +============================================================================= +USE CASE: CUSTOM ERROR HANDLING +============================================================================= + +Override counting or cooldown based on error type: + + def on_request_complete(self, credential, model, success, response, error): + if not success and error: + # Don't count server errors against quota + if error.error_type == "server_error": + return RequestCompleteResult(count_override=0) + + # Force exhaustion on quota errors + if error.error_type == "quota_exceeded": + return RequestCompleteResult( + force_exhausted=True, + cooldown_override=3600.0, # 1 hour + ) + + # Custom cooldown for rate limits + if error.error_type == "rate_limit": + retry_after = getattr(error, "retry_after", 60) + return RequestCompleteResult(cooldown_override=retry_after) + + return None # Default behavior + +============================================================================= +""" + +import asyncio +from typing import Any, Dict, Optional + +from ...core.types import RequestCompleteResult + + +class HookDispatcher: + """ + Dispatch optional provider hooks during request lifecycle. + + The HookDispatcher is instantiated by UsageManager with the provider plugins + dict. It lazily instantiates provider instances and calls their hooks. + + Currently supported hooks: + - on_request_complete: Called after each request completes + + Usage: + dispatcher = HookDispatcher(provider_plugins) + result = await dispatcher.dispatch_request_complete( + provider="my_provider", + credential="path/to/cred.json", + model="my-model", + success=True, + response=response_obj, + error=None, + ) + if result and result.count_override is not None: + request_count = result.count_override + """ + + def __init__(self, provider_plugins: Optional[Dict[str, Any]] = None): + """ + Initialize the hook dispatcher. + + Args: + provider_plugins: Dict mapping provider names to plugin classes. + Classes are lazily instantiated on first hook call. + """ + self._plugins = provider_plugins or {} + + def _get_instance(self, provider: str) -> Optional[Any]: + """Get provider plugin instance (singleton via metaclass).""" + plugin_class = self._plugins.get(provider) + if not plugin_class: + return None + if isinstance(plugin_class, type): + return plugin_class() # Singleton - always returns same instance + return plugin_class + + async def dispatch_request_complete( + self, + provider: str, + credential: str, + model: str, + success: bool, + response: Optional[Any], + error: Optional[Any], + ) -> Optional[RequestCompleteResult]: + """ + Dispatch the on_request_complete hook to a provider. + + Called by UsageManager after each request completes (success or failure). + The provider can return a RequestCompleteResult to override default + behavior for request counting, cooldowns, or exhaustion marking. + + Args: + provider: Provider name (e.g., "antigravity", "openai") + credential: Credential accessor (file path or API key) + model: Model that was called (with provider prefix) + success: Whether the request succeeded + response: Response object if success=True, else None + error: ClassifiedError if success=False, else None + + Returns: + RequestCompleteResult from provider, or None if: + - Provider not found in plugins + - Provider doesn't implement on_request_complete + - Provider returns None (use default behavior) + """ + plugin = self._get_instance(provider) + if not plugin or not hasattr(plugin, "on_request_complete"): + return None + + result = plugin.on_request_complete(credential, model, success, response, error) + if asyncio.iscoroutine(result): + result = await result + + return result diff --git a/src/rotator_library/usage/limits/__init__.py b/src/rotator_library/usage/limits/__init__.py new file mode 100644 index 00000000..e759c312 --- /dev/null +++ b/src/rotator_library/usage/limits/__init__.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Limit checking and enforcement.""" + +from .engine import LimitEngine +from .base import LimitChecker +from .window_limits import WindowLimitChecker +from .cooldowns import CooldownChecker +from .fair_cycle import FairCycleChecker +from .custom_caps import CustomCapChecker + +__all__ = [ + "LimitEngine", + "LimitChecker", + "WindowLimitChecker", + "CooldownChecker", + "FairCycleChecker", + "CustomCapChecker", +] diff --git a/src/rotator_library/usage/limits/base.py b/src/rotator_library/usage/limits/base.py new file mode 100644 index 00000000..c6c0a4d8 --- /dev/null +++ b/src/rotator_library/usage/limits/base.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Base interface for limit checkers. + +All limit types implement this interface for consistent behavior. +""" + +from abc import ABC, abstractmethod +from typing import Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult + + +class LimitChecker(ABC): + """ + Abstract base class for limit checkers. + + Each limit type (window, cooldown, fair cycle, custom cap) + implements this interface. + """ + + @property + @abstractmethod + def name(self) -> str: + """Name of this limit checker.""" + ... + + @abstractmethod + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if a credential passes this limit. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail and reason + """ + ... + + def reset( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Reset this limit for a credential. + + Default implementation does nothing - override if needed. + + Args: + state: Credential state to reset + model: Optional model scope + quota_group: Optional quota group scope + """ + pass diff --git a/src/rotator_library/usage/limits/concurrent.py b/src/rotator_library/usage/limits/concurrent.py new file mode 100644 index 00000000..83a510cb --- /dev/null +++ b/src/rotator_library/usage/limits/concurrent.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Concurrent request limit checker. + +Blocks credentials that have reached their max_concurrent limit. +""" + +from typing import Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult +from .base import LimitChecker + + +class ConcurrentLimitChecker(LimitChecker): + """ + Checks concurrent request limits. + + Blocks credentials that have active_requests >= max_concurrent. + This ensures we don't overload any single credential. + """ + + @property + def name(self) -> str: + return "concurrent" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if credential is at max concurrent. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + # If no limit set, always allow + if state.max_concurrent is None: + return LimitCheckResult.ok() + + # Check if at or above limit + if state.active_requests >= state.max_concurrent: + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_CONCURRENT, + reason=f"At max concurrent: {state.active_requests}/{state.max_concurrent}", + blocked_until=None, # No specific time - depends on request completion + ) + + return LimitCheckResult.ok() + + def reset( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Reset concurrent count. + + Note: This is rarely needed as active_requests is + managed by acquire/release, not limit checking. + """ + # Typically don't reset active_requests via limit system + pass diff --git a/src/rotator_library/usage/limits/cooldowns.py b/src/rotator_library/usage/limits/cooldowns.py new file mode 100644 index 00000000..bf08a4a4 --- /dev/null +++ b/src/rotator_library/usage/limits/cooldowns.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Cooldown checker. + +Checks if a credential is currently in cooldown. +""" + +import time +from typing import Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult +from .base import LimitChecker + + +class CooldownChecker(LimitChecker): + """ + Checks cooldown status for credentials. + + Blocks credentials that are currently cooling down from + rate limits, errors, or other causes. + """ + + @property + def name(self) -> str: + return "cooldowns" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if credential is in cooldown. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + now = time.time() + group_key = quota_group or model + + # Check model/group-specific cooldowns + keys_to_check = [] + if group_key: + keys_to_check.append(group_key) + if quota_group and quota_group != model: + keys_to_check.append(model) + + for key in keys_to_check: + cooldown = state.cooldowns.get(key) + if cooldown and cooldown.until > now: + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_COOLDOWN, + reason=f"Cooldown for '{key}': {cooldown.reason} (expires in {cooldown.remaining_seconds:.0f}s)", + blocked_until=cooldown.until, + ) + + # Check global cooldown + global_cooldown = state.cooldowns.get("_global_") + if global_cooldown and global_cooldown.until > now: + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_COOLDOWN, + reason=f"Global cooldown: {global_cooldown.reason} (expires in {global_cooldown.remaining_seconds:.0f}s)", + blocked_until=global_cooldown.until, + ) + + return LimitCheckResult.ok() + + def reset( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Clear cooldown for a credential. + + Args: + state: Credential state + model: Optional model scope + quota_group: Optional quota group scope + """ + if quota_group: + if quota_group in state.cooldowns: + del state.cooldowns[quota_group] + elif model: + if model in state.cooldowns: + del state.cooldowns[model] + else: + # Clear all cooldowns + state.cooldowns.clear() + + def get_cooldown_end( + self, + state: CredentialState, + model_or_group: Optional[str] = None, + ) -> Optional[float]: + """ + Get when cooldown ends for a credential. + + Args: + state: Credential state + model_or_group: Optional scope to check + + Returns: + Timestamp when cooldown ends, or None if not in cooldown + """ + now = time.time() + + # Check specific scope + if model_or_group: + cooldown = state.cooldowns.get(model_or_group) + if cooldown and cooldown.until > now: + return cooldown.until + + # Check global + global_cooldown = state.cooldowns.get("_global_") + if global_cooldown and global_cooldown.until > now: + return global_cooldown.until + + return None diff --git a/src/rotator_library/usage/limits/custom_caps.py b/src/rotator_library/usage/limits/custom_caps.py new file mode 100644 index 00000000..86ffc78a --- /dev/null +++ b/src/rotator_library/usage/limits/custom_caps.py @@ -0,0 +1,359 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Custom cap limit checker. + +Enforces user-defined limits on API usage. +""" + +import time +import logging +from typing import Dict, List, Optional, Tuple + +from ..types import CredentialState, LimitCheckResult, LimitResult, WindowStats +from ..config import CustomCapConfig, CooldownMode +from ..tracking.windows import WindowManager +from .base import LimitChecker + +lib_logger = logging.getLogger("rotator_library") + + +# Scope constants for cap application +SCOPE_MODEL = "model" +SCOPE_GROUP = "group" + + +class CustomCapChecker(LimitChecker): + """ + Checks custom cap limits. + + Custom caps allow users to set custom usage limits per tier/model/group. + Limits can be absolute numbers or percentages of API limits. + + Caps are checked independently for both model AND group scopes: + - A model cap being exceeded blocks only that model + - A group cap being exceeded blocks the entire group + - Both caps can exist and are checked separately (first blocked wins) + """ + + def __init__( + self, + caps: List[CustomCapConfig], + window_manager: WindowManager, + ): + """ + Initialize custom cap checker. + + Args: + caps: List of custom cap configurations + window_manager: WindowManager for checking window usage + """ + self._caps = caps + self._windows = window_manager + # Index caps by (tier_key, model_or_group) for fast lookup + self._cap_index: Dict[tuple, CustomCapConfig] = {} + for cap in caps: + self._cap_index[(cap.tier_key, cap.model_or_group)] = cap + + @property + def name(self) -> str: + return "custom_caps" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if any custom cap is exceeded. + + Checks both model and group caps independently. If both exist, + each is checked against its respective usage scope. First blocked + cap wins. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + if not self._caps: + return LimitCheckResult.ok() + + primary_def = self._windows.get_primary_definition() + if primary_def is None: + return LimitCheckResult.ok() + + priority = state.priority + + # Find all applicable caps (model + group separately) + all_caps = self._find_all_caps(str(priority), model, quota_group) + if not all_caps: + return LimitCheckResult.ok() + + # Check each cap against its proper scope + for cap, scope, scope_key in all_caps: + result = self._check_single_cap( + state, cap, scope, scope_key, model, quota_group + ) + if not result.allowed: + return result + + return LimitCheckResult.ok() + + def get_cap_for( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> Optional[CustomCapConfig]: + """ + Get the first applicable custom cap for a credential/model. + + Args: + state: Credential state + model: Model name + quota_group: Quota group + + Returns: + CustomCapConfig if one applies, None otherwise + """ + priority = state.priority + all_caps = self._find_all_caps(str(priority), model, quota_group) + if all_caps: + return all_caps[0][0] # Return first cap + return None + + def get_all_caps_for( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> List[Tuple[CustomCapConfig, str, str]]: + """ + Get all applicable custom caps for a credential/model. + + Args: + state: Credential state + model: Model name + quota_group: Quota group + + Returns: + List of (cap, scope, scope_key) tuples + """ + priority = state.priority + return self._find_all_caps(str(priority), model, quota_group) + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _find_all_caps( + self, + priority_key: str, + model: str, + quota_group: Optional[str], + ) -> List[Tuple[CustomCapConfig, str, str]]: + """ + Find all applicable caps for a request. + + Returns caps for both model AND group scopes (if they exist). + Each cap is returned with its scope type and scope key. + + Args: + priority_key: Priority level as string + model: Model name + quota_group: Quota group (optional) + + Returns: + List of (cap, scope, scope_key) tuples where: + - cap: The CustomCapConfig + - scope: SCOPE_MODEL or SCOPE_GROUP + - scope_key: The model name or group name + """ + result: List[Tuple[CustomCapConfig, str, str]] = [] + + # Check model cap (priority-specific, then default) + model_cap = self._cap_index.get((priority_key, model)) or self._cap_index.get( + ("default", model) + ) + if model_cap: + result.append((model_cap, SCOPE_MODEL, model)) + + # Check group cap (priority-specific, then default) - only if group differs from model + if quota_group and quota_group != model: + group_cap = self._cap_index.get( + (priority_key, quota_group) + ) or self._cap_index.get(("default", quota_group)) + if group_cap: + result.append((group_cap, SCOPE_GROUP, quota_group)) + + return result + + def _check_single_cap( + self, + state: CredentialState, + cap: CustomCapConfig, + scope: str, + scope_key: str, + model: str, + quota_group: Optional[str], + ) -> LimitCheckResult: + """ + Check a single cap against its appropriate usage scope. + + Args: + state: Credential state + cap: The cap configuration to check + scope: SCOPE_MODEL or SCOPE_GROUP + scope_key: The model name or group name + model: Original model name (for fallback) + quota_group: Original quota group (for fallback) + + Returns: + LimitCheckResult for this specific cap + """ + # Get windows based on scope + windows = None + + if scope == SCOPE_GROUP: + group_stats = state.get_group_stats(scope_key, create=False) + if group_stats: + windows = group_stats.windows + else: # SCOPE_MODEL + model_stats = state.get_model_stats(scope_key, create=False) + if model_stats: + windows = model_stats.windows + + if windows is None: + return LimitCheckResult.ok() + + # Get usage from primary window + primary_window = self._windows.get_primary_window(windows) + if primary_window is None: + return LimitCheckResult.ok() + + current_usage = primary_window.request_count + max_requests = self._resolve_max_requests(cap, primary_window.limit) + + if current_usage >= max_requests: + # Calculate cooldown end + cooldown_until = self._calculate_cooldown_until(cap, primary_window) + + # Build descriptive reason with scope info + scope_desc = "model" if scope == SCOPE_MODEL else "group" + reason = ( + f"Custom cap for {scope_desc} '{scope_key}' exceeded " + f"({current_usage}/{max_requests})" + ) + + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_CUSTOM_CAP, + reason=reason, + blocked_until=cooldown_until, + ) + + return LimitCheckResult.ok() + + def _find_cap( + self, + priority_key: str, + group_key: str, + model: str, + ) -> Optional[CustomCapConfig]: + """ + Find the most specific applicable cap (legacy method for compatibility). + + Deprecated: Use _find_all_caps() for layered cap checking. + """ + # Try exact matches first + # Priority + group + cap = self._cap_index.get((priority_key, group_key)) + if cap: + return cap + + # Priority + model (if different from group) + if model != group_key: + cap = self._cap_index.get((priority_key, model)) + if cap: + return cap + + # Default tier + group + cap = self._cap_index.get(("default", group_key)) + if cap: + return cap + + # Default tier + model + if model != group_key: + cap = self._cap_index.get(("default", model)) + if cap: + return cap + + return None + + def _resolve_max_requests( + self, + cap: CustomCapConfig, + window_limit: Optional[int], + ) -> int: + """ + Resolve max requests, handling percentage values. + + Custom caps can be ANY value - higher or lower than API limits. + Negative values indicate percentage of the window limit. + """ + if cap.max_requests >= 0: + # Absolute value - use as-is (no clamping) + return cap.max_requests + + # Negative value indicates percentage + if window_limit is None: + # No window limit known, use a high default + return 1000 + + percentage = -cap.max_requests + calculated = int(window_limit * percentage / 100) + return calculated + + def _calculate_cooldown_until( + self, + cap: CustomCapConfig, + window: WindowStats, + ) -> Optional[float]: + """ + Calculate when the custom cap cooldown ends. + + Modes: + - QUOTA_RESET: Wait until natural window reset + - OFFSET: Add/subtract offset from natural reset (clamped to >= reset) + - FIXED: Fixed duration from now + """ + now = time.time() + natural_reset = window.reset_at + + if cap.cooldown_mode == CooldownMode.QUOTA_RESET: + # Wait until window resets + return natural_reset + + elif cap.cooldown_mode == CooldownMode.OFFSET: + # Offset from natural reset time + # Positive offset = wait AFTER reset + # Negative offset = wait BEFORE reset (clamped to >= reset for safety) + if natural_reset: + calculated = natural_reset + cap.cooldown_value + # Always clamp to at least natural_reset (can't end before quota resets) + return max(calculated, natural_reset) + else: + # No natural reset known, use absolute offset from now + return now + abs(cap.cooldown_value) + + elif cap.cooldown_mode == CooldownMode.FIXED: + # Fixed duration from now + calculated = now + cap.cooldown_value + return calculated + + return None diff --git a/src/rotator_library/usage/limits/engine.py b/src/rotator_library/usage/limits/engine.py new file mode 100644 index 00000000..bfe7d922 --- /dev/null +++ b/src/rotator_library/usage/limits/engine.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Limit engine for orchestrating limit checks. + +Central component that runs all limit checkers and determines +if a credential is available for use. +""" + +import logging +from typing import Dict, List, Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult +from ..config import ProviderUsageConfig +from ..tracking.windows import WindowManager +from .base import LimitChecker +from .concurrent import ConcurrentLimitChecker +from .window_limits import WindowLimitChecker +from .cooldowns import CooldownChecker +from .fair_cycle import FairCycleChecker +from .custom_caps import CustomCapChecker + +lib_logger = logging.getLogger("rotator_library") + + +class LimitEngine: + """ + Central engine for limit checking. + + Orchestrates all limit checkers and provides a single entry point + for determining credential availability. + """ + + def __init__( + self, + config: ProviderUsageConfig, + window_manager: WindowManager, + ): + """ + Initialize limit engine. + + Args: + config: Provider usage configuration + window_manager: WindowManager for window-based checks + """ + self._config = config + self._window_manager = window_manager + + # Initialize all limit checkers + # Order matters: concurrent first (fast check), then others + # Note: WindowLimitChecker is optional - only included if window_limits_enabled + self._checkers: List[LimitChecker] = [ + ConcurrentLimitChecker(), + CooldownChecker(), + ] + + # Window limit checker - kept as reference for info purposes, + # only added to blocking checkers if explicitly enabled + self._window_checker = WindowLimitChecker(window_manager) + if config.window_limits_enabled: + self._checkers.append(self._window_checker) + + # Custom caps and fair cycle always active + self._custom_cap_checker = CustomCapChecker(config.custom_caps, window_manager) + self._fair_cycle_checker = FairCycleChecker(config.fair_cycle, window_manager) + self._checkers.append(self._custom_cap_checker) + self._checkers.append(self._fair_cycle_checker) + + # Quick access to specific checkers + self._concurrent_checker = self._checkers[0] + self._cooldown_checker = self._checkers[1] + + def check_all( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check all limits for a credential. + + Runs all limit checkers in order and returns the first failure, + or success if all pass. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating overall pass/fail + """ + for checker in self._checkers: + result = checker.check(state, model, quota_group) + if not result.allowed: + lib_logger.debug( + f"Credential {state.stable_id} blocked by {checker.name}: {result.reason}" + ) + return result + + return LimitCheckResult.ok() + + def check_specific( + self, + checker_name: str, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check a specific limit type. + + Args: + checker_name: Name of the checker ("cooldowns", "window_limits", etc.) + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult from the specified checker + """ + for checker in self._checkers: + if checker.name == checker_name: + return checker.check(state, model, quota_group) + + # Unknown checker - return ok + return LimitCheckResult.ok() + + def get_available_candidates( + self, + states: List[CredentialState], + model: str, + quota_group: Optional[str] = None, + ) -> List[CredentialState]: + """ + Filter credentials to only those passing all limits. + + Args: + states: List of credential states to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + List of available credential states + """ + available = [] + for state in states: + result = self.check_all(state, model, quota_group) + if result.allowed: + available.append(state) + + return available + + def get_blocking_info( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> Dict[str, LimitCheckResult]: + """ + Get detailed blocking info for each limit type. + + Useful for debugging and status reporting. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + Dict mapping checker name to its result + """ + results = {} + for checker in self._checkers: + results[checker.name] = checker.check(state, model, quota_group) + return results + + def reset_all( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Reset all limits for a credential. + + Args: + state: Credential state + model: Optional model scope + quota_group: Optional quota group scope + """ + for checker in self._checkers: + checker.reset(state, model, quota_group) + + @property + def concurrent_checker(self) -> ConcurrentLimitChecker: + """Get the concurrent limit checker.""" + return self._concurrent_checker + + @property + def cooldown_checker(self) -> CooldownChecker: + """Get the cooldown checker.""" + return self._cooldown_checker + + @property + def window_checker(self) -> WindowLimitChecker: + """Get the window limit checker.""" + return self._window_checker + + @property + def custom_cap_checker(self) -> CustomCapChecker: + """Get the custom cap checker.""" + return self._custom_cap_checker + + @property + def fair_cycle_checker(self) -> FairCycleChecker: + """Get the fair cycle checker.""" + return self._fair_cycle_checker + + def add_checker(self, checker: LimitChecker) -> None: + """ + Add a custom limit checker. + + Allows extending the limit system with custom logic. + + Args: + checker: LimitChecker implementation to add + """ + self._checkers.append(checker) + + def remove_checker(self, name: str) -> bool: + """ + Remove a limit checker by name. + + Args: + name: Name of the checker to remove + + Returns: + True if removed, False if not found + """ + for i, checker in enumerate(self._checkers): + if checker.name == name: + del self._checkers[i] + return True + return False diff --git a/src/rotator_library/usage/limits/fair_cycle.py b/src/rotator_library/usage/limits/fair_cycle.py new file mode 100644 index 00000000..66705235 --- /dev/null +++ b/src/rotator_library/usage/limits/fair_cycle.py @@ -0,0 +1,384 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Fair cycle limit checker. + +Ensures credentials are used fairly by blocking exhausted ones +until all credentials in the pool are exhausted. +""" + +import time +import logging +from typing import Dict, List, Optional, Set + +from ..types import ( + CredentialState, + LimitCheckResult, + LimitResult, + FairCycleState, + GlobalFairCycleState, + TrackingMode, + FAIR_CYCLE_GLOBAL_KEY, +) +from ..config import FairCycleConfig +from ..tracking.windows import WindowManager +from .base import LimitChecker + +lib_logger = logging.getLogger("rotator_library") + + +class FairCycleChecker(LimitChecker): + """ + Checks fair cycle constraints. + + Blocks credentials that have been "exhausted" (quota used or long cooldown) + until all credentials in the pool have been exhausted, then resets the cycle. + """ + + def __init__( + self, config: FairCycleConfig, window_manager: Optional[WindowManager] = None + ): + """ + Initialize fair cycle checker. + + Args: + config: Fair cycle configuration + window_manager: WindowManager for getting window limits (optional) + """ + self._config = config + self._window_manager = window_manager + # Global cycle state per provider + self._global_state: Dict[str, Dict[str, GlobalFairCycleState]] = {} + + @property + def name(self) -> str: + return "fair_cycle" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if credential is blocked by fair cycle. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + if not self._config.enabled: + return LimitCheckResult.ok() + + group_key = self._resolve_tracking_key(model, quota_group) + fc_state = state.fair_cycle.get(group_key) + + # Check quota-based exhaustion (cycle_request_count >= window.limit * threshold) + # This is separate from explicit exhaustion marking + if fc_state and not fc_state.exhausted: + quota_limit = self._get_quota_limit(state, model, quota_group) + if quota_limit is not None: + threshold = int(quota_limit * self._config.quota_threshold) + if fc_state.cycle_request_count >= threshold: + # Mark as exhausted due to quota threshold + now = time.time() + fc_state.exhausted = True + fc_state.exhausted_at = now + fc_state.exhausted_reason = "quota_threshold" + lib_logger.debug( + f"Credential {state.stable_id} exhausted for {group_key}: " + f"cycle_request_count ({fc_state.cycle_request_count}) >= " + f"quota_threshold ({threshold})" + ) + + # Not exhausted = allowed + if fc_state is None or not fc_state.exhausted: + return LimitCheckResult.ok() + + # Exhausted - check if cycle should reset + provider = state.provider + global_state = self._get_global_state(provider, group_key) + + # Check if cycle has expired + if self._should_reset_cycle(global_state): + # Don't block - cycle will be reset + return LimitCheckResult.ok() + + # Still blocked by fair cycle + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_FAIR_CYCLE, + reason=f"Fair cycle: exhausted for '{group_key}' - waiting for other credentials", + blocked_until=None, # Depends on other credentials + ) + + def reset( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Reset fair cycle state for a credential. + + Args: + state: Credential state + model: Optional model scope + quota_group: Optional quota group scope + """ + group_key = self._resolve_tracking_key(model or "", quota_group) + + if quota_group or model: + if group_key in state.fair_cycle: + fc_state = state.fair_cycle[group_key] + fc_state.exhausted = False + fc_state.exhausted_at = None + fc_state.exhausted_reason = None + fc_state.cycle_request_count = 0 + else: + # Reset all + for fc_state in state.fair_cycle.values(): + fc_state.exhausted = False + fc_state.exhausted_at = None + fc_state.exhausted_reason = None + fc_state.cycle_request_count = 0 + + def check_all_exhausted( + self, + provider: str, + group_key: str, + all_states: List[CredentialState], + priorities: Optional[Dict[str, int]] = None, + ) -> bool: + """ + Check if all credentials in the pool are exhausted. + + Args: + provider: Provider name + group_key: Model or quota group + all_states: All credential states for this provider + priorities: Optional priority filter + + Returns: + True if all are exhausted + """ + # Filter by tier if not cross-tier + if priorities and not self._config.cross_tier: + # Group by priority tier + priority_groups: Dict[int, List[CredentialState]] = {} + for state in all_states: + p = priorities.get(state.stable_id, 999) + priority_groups.setdefault(p, []).append(state) + + # Check each priority group separately + for priority, group_states in priority_groups.items(): + if not self._all_exhausted_in_group(group_states, group_key): + return False + return True + else: + return self._all_exhausted_in_group(all_states, group_key) + + def reset_cycle( + self, + provider: str, + group_key: str, + all_states: List[CredentialState], + ) -> None: + """ + Reset the fair cycle for all credentials. + + Args: + provider: Provider name + group_key: Model or quota group + all_states: All credential states to reset + """ + now = time.time() + + for state in all_states: + if group_key in state.fair_cycle: + fc_state = state.fair_cycle[group_key] + fc_state.exhausted = False + fc_state.exhausted_at = None + fc_state.exhausted_reason = None + fc_state.cycle_request_count = 0 + + # Update global state + global_state = self._get_global_state(provider, group_key) + global_state.cycle_start = now + global_state.all_exhausted_at = None + global_state.cycle_count += 1 + + lib_logger.info( + f"Fair cycle reset for {provider}/{group_key}, cycle #{global_state.cycle_count}" + ) + + def mark_all_exhausted( + self, + provider: str, + group_key: str, + ) -> None: + """ + Record that all credentials are now exhausted. + + Args: + provider: Provider name + group_key: Model or quota group + """ + global_state = self._get_global_state(provider, group_key) + global_state.all_exhausted_at = time.time() + + lib_logger.info(f"All credentials exhausted for {provider}/{group_key}") + + def get_tracking_key(self, model: str, quota_group: Optional[str]) -> str: + """Get the fair cycle tracking key for a request.""" + return self._resolve_tracking_key(model, quota_group) + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _get_global_state( + self, + provider: str, + group_key: str, + ) -> GlobalFairCycleState: + """Get or create global fair cycle state.""" + if provider not in self._global_state: + self._global_state[provider] = {} + + if group_key not in self._global_state[provider]: + self._global_state[provider][group_key] = GlobalFairCycleState( + cycle_start=time.time() + ) + + return self._global_state[provider][group_key] + + def _resolve_tracking_key( + self, + model: str, + quota_group: Optional[str], + ) -> str: + """Resolve tracking key based on fair cycle mode.""" + if self._config.tracking_mode == TrackingMode.CREDENTIAL: + return FAIR_CYCLE_GLOBAL_KEY + return quota_group or model + + def _should_reset_cycle(self, global_state: GlobalFairCycleState) -> bool: + """Check if cycle duration has expired.""" + now = time.time() + return now >= global_state.cycle_start + self._config.duration + + def _get_quota_limit( + self, + state: CredentialState, + model: str, + quota_group: Optional[str], + ) -> Optional[int]: + """ + Get the quota limit for fair cycle comparison. + + Uses the smallest window limit available (most restrictive). + + Args: + state: Credential state + model: Model name + quota_group: Quota group (optional) + + Returns: + The quota limit, or None if unknown + """ + if self._window_manager is None: + return None + + primary_def = self._window_manager.get_primary_definition() + if primary_def is None: + return None + + group_key = quota_group or model + windows = None + + # Check group first if quota_group is specified + if quota_group: + group_stats = state.get_group_stats(quota_group, create=False) + if group_stats: + windows = group_stats.windows + + # Fall back to model + if windows is None: + model_stats = state.get_model_stats(model, create=False) + if model_stats: + windows = model_stats.windows + + if windows is None: + return None + + # Get limit from primary window + primary_window = self._window_manager.get_active_window( + windows, primary_def.name + ) + if primary_window and primary_window.limit: + return primary_window.limit + + # If no primary window limit, try to find smallest limit from any window + smallest_limit: Optional[int] = None + for window in windows.values(): + if window.limit is not None: + if smallest_limit is None or window.limit < smallest_limit: + smallest_limit = window.limit + + return smallest_limit + + def _all_exhausted_in_group( + self, + states: List[CredentialState], + group_key: str, + ) -> bool: + """Check if all credentials in a group are exhausted.""" + if not states: + return True + + for state in states: + fc_state = state.fair_cycle.get(group_key) + if fc_state is None or not fc_state.exhausted: + return False + + return True + + def get_global_state_dict(self) -> Dict[str, Dict[str, Dict]]: + """ + Get global state for serialization. + + Returns: + Dict suitable for JSON serialization + """ + result = {} + for provider, groups in self._global_state.items(): + result[provider] = {} + for group_key, state in groups.items(): + result[provider][group_key] = { + "cycle_start": state.cycle_start, + "all_exhausted_at": state.all_exhausted_at, + "cycle_count": state.cycle_count, + } + return result + + def load_global_state_dict(self, data: Dict[str, Dict[str, Dict]]) -> None: + """ + Load global state from serialized data. + + Args: + data: Dict from get_global_state_dict() + """ + self._global_state.clear() + for provider, groups in data.items(): + self._global_state[provider] = {} + for group_key, state_data in groups.items(): + self._global_state[provider][group_key] = GlobalFairCycleState( + cycle_start=state_data.get("cycle_start", 0), + all_exhausted_at=state_data.get("all_exhausted_at"), + cycle_count=state_data.get("cycle_count", 0), + ) diff --git a/src/rotator_library/usage/limits/window_limits.py b/src/rotator_library/usage/limits/window_limits.py new file mode 100644 index 00000000..31433f8c --- /dev/null +++ b/src/rotator_library/usage/limits/window_limits.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Window limit checker. + +Checks if a credential has exceeded its request quota for a window. +""" + +from typing import Dict, List, Optional + +from ..types import CredentialState, LimitCheckResult, LimitResult, WindowStats +from ..tracking.windows import WindowManager +from .base import LimitChecker + + +class WindowLimitChecker(LimitChecker): + """ + Checks window-based request limits. + + Blocks credentials that have exhausted their quota in any + tracked window. + """ + + def __init__(self, window_manager: WindowManager): + """ + Initialize window limit checker. + + Args: + window_manager: WindowManager instance for window operations + """ + self._windows = window_manager + + @property + def name(self) -> str: + return "window_limits" + + def check( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + ) -> LimitCheckResult: + """ + Check if any window limit is exceeded. + + Args: + state: Credential state to check + model: Model being requested + quota_group: Quota group for this model + + Returns: + LimitCheckResult indicating pass/fail + """ + group_key = quota_group or model + + # Check all configured windows + for definition in self._windows.definitions.values(): + windows = None + + if definition.applies_to == "model": + model_stats = state.get_model_stats(model, create=False) + if model_stats: + windows = model_stats.windows + elif definition.applies_to == "group": + group_stats = state.get_group_stats(group_key, create=False) + if group_stats: + windows = group_stats.windows + + if windows is None: + continue + + window = windows.get(definition.name) + if window is None or window.limit is None: + continue + + active = self._windows.get_active_window(windows, definition.name) + if active is None: + continue + + if active.request_count >= active.limit: + return LimitCheckResult.blocked( + result=LimitResult.BLOCKED_WINDOW, + reason=( + f"Window '{definition.name}' exhausted " + f"({active.request_count}/{active.limit})" + ), + blocked_until=active.reset_at, + ) + + return LimitCheckResult.ok() + + def get_remaining( + self, + state: CredentialState, + window_name: str, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> Optional[int]: + """ + Get remaining requests in a specific window. + + Args: + state: Credential state + window_name: Name of window to check + model: Model to check + quota_group: Quota group to check + + Returns: + Remaining requests, or None if unlimited/unknown + """ + group_key = quota_group or model or "" + definition = self._windows.definitions.get(window_name) + + windows = None + if definition: + if definition.applies_to == "model" and model: + model_stats = state.get_model_stats(model, create=False) + if model_stats: + windows = model_stats.windows + elif definition.applies_to == "group": + group_stats = state.get_group_stats(group_key, create=False) + if group_stats: + windows = group_stats.windows + + if windows is None: + return None + + return self._windows.get_window_remaining(windows, window_name) + + def get_all_remaining( + self, + state: CredentialState, + model: Optional[str] = None, + quota_group: Optional[str] = None, + ) -> Dict[str, Optional[int]]: + """ + Get remaining requests for all windows. + + Args: + state: Credential state + model: Model to check + quota_group: Quota group to check + + Returns: + Dict of window_name -> remaining (None if unlimited) + """ + result = {} + for definition in self._windows.definitions.values(): + result[definition.name] = self.get_remaining( + state, + definition.name, + model=model, + quota_group=quota_group, + ) + return result diff --git a/src/rotator_library/usage/manager.py b/src/rotator_library/usage/manager.py new file mode 100644 index 00000000..45b24091 --- /dev/null +++ b/src/rotator_library/usage/manager.py @@ -0,0 +1,1828 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +UsageManager facade and CredentialContext. + +This is the main public API for the usage tracking system. +""" + +import asyncio +import logging +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Set, Union + +from ..core.types import CredentialInfo, RequestCompleteResult +from ..error_handler import ClassifiedError, classify_error, mask_credential + +from .types import ( + WindowStats, + TotalStats, + ModelStats, + GroupStats, + CredentialState, + LimitCheckResult, + RotationMode, + LimitResult, + FAIR_CYCLE_GLOBAL_KEY, + TrackingMode, + ResetMode, +) +from .config import ( + ProviderUsageConfig, + load_provider_usage_config, + get_default_windows, +) +from .identity.registry import CredentialRegistry +from .tracking.engine import TrackingEngine +from .tracking.windows import WindowManager +from .limits.engine import LimitEngine +from .selection.engine import SelectionEngine +from .persistence.storage import UsageStorage +from .integration.hooks import HookDispatcher +from .integration.api import UsageAPI + +lib_logger = logging.getLogger("rotator_library") + + +class CredentialContext: + """ + Context manager for credential lifecycle. + + Handles: + - Automatic release on exit + - Success/failure recording + - Usage tracking + + Usage: + async with usage_manager.acquire_credential(provider, model) as ctx: + response = await make_request(ctx.credential) + ctx.mark_success(response) + """ + + def __init__( + self, + manager: "UsageManager", + credential: str, + stable_id: str, + model: str, + quota_group: Optional[str] = None, + ): + self._manager = manager + self.credential = credential # The accessor (path or key) + self.stable_id = stable_id + self.model = model + self.quota_group = quota_group + self._acquired_at = time.time() + self._result: Optional[Literal["success", "failure"]] = None + self._response: Optional[Any] = None + self._response_headers: Optional[Dict[str, Any]] = None + self._error: Optional[ClassifiedError] = None + self._tokens: Dict[str, int] = {} + self._approx_cost: float = 0.0 + + async def __aenter__(self) -> "CredentialContext": + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> bool: + # Always release the credential + await self._manager._release_credential(self.stable_id, self.model) + + success = False + error = self._error + response = self._response + + if self._result == "success": + success = True + elif self._result == "failure": + success = False + elif exc_val is not None: + error = classify_error(exc_val) + success = False + else: + success = True + + await self._manager._handle_request_complete( + stable_id=self.stable_id, + model=self.model, + quota_group=self.quota_group, + success=success, + response=response, + response_headers=self._response_headers, + error=error, + prompt_tokens=self._tokens.get("prompt", 0), + completion_tokens=self._tokens.get("completion", 0), + thinking_tokens=self._tokens.get("thinking", 0), + prompt_tokens_cache_read=self._tokens.get("prompt_cached", 0), + prompt_tokens_cache_write=self._tokens.get("prompt_cache_write", 0), + approx_cost=self._approx_cost, + ) + + return False # Don't suppress exceptions + + def mark_success( + self, + response: Any = None, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + response_headers: Optional[Dict[str, Any]] = None, + ) -> None: + """Mark request as successful.""" + self._result = "success" + self._response = response + self._response_headers = response_headers + self._tokens = { + "prompt": prompt_tokens, + "completion": completion_tokens, + "thinking": thinking_tokens, + "prompt_cached": prompt_tokens_cache_read, + "prompt_cache_write": prompt_tokens_cache_write, + } + self._approx_cost = approx_cost + + def mark_failure(self, error: ClassifiedError) -> None: + """Mark request as failed.""" + self._result = "failure" + self._error = error + + +class UsageManager: + """ + Main facade for usage tracking and credential selection. + + This class provides the primary interface for: + - Acquiring credentials for requests (with context manager) + - Recording usage and failures + - Selecting the best available credential + - Managing cooldowns and limits + + Example: + manager = UsageManager(provider="gemini", file_path="usage.json") + await manager.initialize(credentials) + + async with manager.acquire_credential(model="gemini-pro") as ctx: + response = await make_request(ctx.credential) + ctx.mark_success(response, prompt_tokens=100, completion_tokens=50) + """ + + def __init__( + self, + provider: str, + file_path: Optional[Union[str, Path]] = None, + provider_plugins: Optional[Dict[str, Any]] = None, + config: Optional[ProviderUsageConfig] = None, + max_concurrent_per_key: Optional[int] = None, + ): + """ + Initialize UsageManager. + + Args: + provider: Provider name (e.g., "gemini", "openai") + file_path: Path to usage.json file + provider_plugins: Dict of provider plugin classes + config: Optional pre-built configuration + max_concurrent_per_key: Max concurrent requests per credential + """ + self.provider = provider + self._provider_plugins = provider_plugins or {} + self._max_concurrent_per_key = max_concurrent_per_key + + # Load configuration + if config: + self._config = config + else: + self._config = load_provider_usage_config(provider, self._provider_plugins) + + # Initialize components + self._registry = CredentialRegistry() + self._window_manager = WindowManager( + window_definitions=self._config.windows or get_default_windows() + ) + self._tracking = TrackingEngine(self._window_manager, self._config) + self._limits = LimitEngine(self._config, self._window_manager) + self._selection = SelectionEngine( + self._config, self._limits, self._window_manager + ) + self._hooks = HookDispatcher(self._provider_plugins) + self._api = UsageAPI(self) + + # Storage + if file_path: + self._storage = UsageStorage(file_path) + else: + self._storage = None + + # State + self._states: Dict[str, CredentialState] = {} + self._initialized = False + self._lock = asyncio.Lock() + self._loaded_from_storage = False + self._loaded_count = 0 + self._quota_exhausted_summary: Dict[str, Dict[str, float]] = {} + self._quota_exhausted_task: Optional[asyncio.Task] = None + self._quota_exhausted_lock = asyncio.Lock() + self._save_task: Optional[asyncio.Task] = None + self._save_lock = asyncio.Lock() + + # Concurrency control: per-credential locks and conditions for waiting + self._key_locks: Dict[str, asyncio.Lock] = {} + self._key_conditions: Dict[str, asyncio.Condition] = {} + + async def initialize( + self, + credentials: List[str], + priorities: Optional[Dict[str, int]] = None, + tiers: Optional[Dict[str, str]] = None, + ) -> None: + """ + Initialize with credentials. + + Args: + credentials: List of credential accessors (paths or keys) + priorities: Optional priority overrides (accessor -> priority) + tiers: Optional tier overrides (accessor -> tier name) + """ + async with self._lock: + if self._initialized: + return + # Load persisted state + if self._storage: + ( + self._states, + fair_cycle_global, + loaded_from_storage, + ) = await self._storage.load() + self._loaded_from_storage = loaded_from_storage + self._loaded_count = len(self._states) + if fair_cycle_global: + self._limits.fair_cycle_checker.load_global_state_dict( + fair_cycle_global + ) + + # Register credentials + for accessor in credentials: + stable_id = self._registry.get_stable_id(accessor, self.provider) + + # Create or update state + if stable_id not in self._states: + self._states[stable_id] = CredentialState( + stable_id=stable_id, + provider=self.provider, + accessor=accessor, + created_at=time.time(), + ) + else: + # Update accessor in case it changed + self._states[stable_id].accessor = accessor + + # Apply overrides + if priorities and accessor in priorities: + self._states[stable_id].priority = priorities[accessor] + if tiers and accessor in tiers: + self._states[stable_id].tier = tiers[accessor] + + # Set max concurrent if configured, applying priority multiplier + if self._max_concurrent_per_key is not None: + base_concurrent = self._max_concurrent_per_key + # Apply priority multiplier from config + priority = self._states[stable_id].priority + multiplier = self._config.get_effective_multiplier(priority) + effective_concurrent = base_concurrent * multiplier + self._states[stable_id].max_concurrent = effective_concurrent + + self._initialized = True + lib_logger.debug( + f"UsageManager initialized for {self.provider} with {len(credentials)} credentials" + ) + + async def acquire_credential( + self, + model: str, + quota_group: Optional[str] = None, + exclude: Optional[Set[str]] = None, + candidates: Optional[List[str]] = None, + priorities: Optional[Dict[str, int]] = None, + deadline: float = 0.0, + ) -> CredentialContext: + """ + Acquire a credential for a request. + + Returns a context manager that automatically releases + the credential and records success/failure. + + This method will wait for credentials to become available if all are + currently busy (at max_concurrent), up until the deadline. + + Args: + model: Model to use + quota_group: Optional quota group (uses model name if None) + exclude: Set of stable_ids to exclude (by accessor) + candidates: Optional list of credential accessors to consider. + If provided, only these will be considered for selection. + priorities: Optional priority overrides (accessor -> priority). + If provided, overrides the stored priorities. + deadline: Request deadline timestamp + + Returns: + CredentialContext for use with async with + + Raises: + NoAvailableKeysError: If no credentials available within deadline + """ + from ..error_handler import NoAvailableKeysError + + # Convert accessor-based exclude to stable_id-based + exclude_ids = set() + if exclude: + for accessor in exclude: + stable_id = self._registry.get_stable_id(accessor, self.provider) + exclude_ids.add(stable_id) + + # Filter states to only candidates if provided + if candidates is not None: + candidate_ids = set() + for accessor in candidates: + stable_id = self._registry.get_stable_id(accessor, self.provider) + candidate_ids.add(stable_id) + states_to_check = { + sid: state + for sid, state in self._states.items() + if sid in candidate_ids + } + else: + states_to_check = self._states + + # Convert accessor-based priorities to stable_id-based + priority_overrides = None + if priorities: + priority_overrides = {} + for accessor, priority in priorities.items(): + stable_id = self._registry.get_stable_id(accessor, self.provider) + priority_overrides[stable_id] = priority + + # Normalize model name for consistent tracking and selection + normalized_model = self._normalize_model(model) + + # Ensure key conditions exist for all candidates + for stable_id in states_to_check: + if stable_id not in self._key_conditions: + self._key_conditions[stable_id] = asyncio.Condition() + self._key_locks[stable_id] = asyncio.Lock() + + # Main acquisition loop - continues until deadline + while time.time() < deadline: + # Try to select a credential + stable_id = self._selection.select( + provider=self.provider, + model=normalized_model, + states=states_to_check, + quota_group=quota_group, + exclude=exclude_ids, + priorities=priority_overrides, + deadline=deadline, + ) + + if stable_id is not None: + state = self._states[stable_id] + lock = self._key_locks.get(stable_id) + + if lock: + async with lock: + # Double-check availability after acquiring lock + if ( + state.max_concurrent is None + or state.active_requests < state.max_concurrent + ): + state.active_requests += 1 + lib_logger.debug( + f"Acquired credential {mask_credential(state.accessor, style='full')} " + f"for {model} (active: {state.active_requests}" + f"{f'/{state.max_concurrent}' if state.max_concurrent else ''})" + ) + return CredentialContext( + manager=self, + credential=state.accessor, + stable_id=stable_id, + model=normalized_model, + quota_group=quota_group, + ) + else: + # No lock configured, just increment + state.active_requests += 1 + return CredentialContext( + manager=self, + credential=state.accessor, + stable_id=stable_id, + model=normalized_model, + quota_group=quota_group, + ) + + # No credential available - need to wait + # Find the best credential to wait for (prefer lowest usage) + best_wait_id = None + best_usage = float("inf") + + for sid, state in states_to_check.items(): + if sid in exclude_ids: + continue + if ( + state.max_concurrent is not None + and state.active_requests >= state.max_concurrent + ): + # This one is busy but might become free + usage = state.totals.request_count + if usage < best_usage: + best_usage = usage + best_wait_id = sid + + if best_wait_id is None: + # All credentials blocked by cooldown or limits, not just concurrency + # Check if waiting for cooldown makes sense + soonest_cooldown = self._get_soonest_cooldown_end( + states_to_check, normalized_model, quota_group + ) + + if soonest_cooldown is not None: + remaining_budget = deadline - time.time() + wait_needed = soonest_cooldown - time.time() + + if wait_needed > remaining_budget: + # No credential will be available in time + lib_logger.warning( + f"All credentials on cooldown. Soonest in {wait_needed:.1f}s, " + f"budget {remaining_budget:.1f}s. Failing fast." + ) + break + + # Wait for cooldown to expire + lib_logger.info( + f"All credentials on cooldown. Waiting {wait_needed:.1f}s..." + ) + await asyncio.sleep(min(wait_needed + 0.1, remaining_budget)) + continue + + # No cooldowns and no busy keys - truly no keys available + break + + # Wait on the best credential's condition + condition = self._key_conditions.get(best_wait_id) + if condition: + lib_logger.debug( + f"All credentials busy. Waiting for {mask_credential(self._states[best_wait_id].accessor, style='full')}..." + ) + try: + async with condition: + remaining_budget = deadline - time.time() + if remaining_budget <= 0: + break + # Wait for notification or timeout (max 1 second to re-check) + await asyncio.wait_for( + condition.wait(), + timeout=min(1.0, remaining_budget), + ) + lib_logger.debug("Credential released. Re-evaluating...") + except asyncio.TimeoutError: + # Timeout is normal, just retry the loop + lib_logger.debug("Wait timed out. Re-evaluating...") + else: + # No condition, just sleep briefly and retry + await asyncio.sleep(0.1) + + # Deadline exceeded + raise NoAvailableKeysError( + f"Could not acquire a credential for {self.provider}/{model} " + f"within the time budget." + ) + + def _get_soonest_cooldown_end( + self, + states: Dict[str, CredentialState], + model: str, + quota_group: Optional[str], + ) -> Optional[float]: + """Get the soonest cooldown end time across all credentials.""" + soonest = None + now = time.time() + group_key = quota_group or model + + for state in states.values(): + # Check model-specific cooldown + cooldown = state.get_cooldown(group_key) + if cooldown and cooldown.until > now: + if soonest is None or cooldown.until < soonest: + soonest = cooldown.until + + # Check global cooldown + global_cooldown = state.get_cooldown() + if global_cooldown and global_cooldown.until > now: + if soonest is None or global_cooldown.until < soonest: + soonest = global_cooldown.until + + return soonest + + async def get_best_credential( + self, + model: str, + quota_group: Optional[str] = None, + exclude: Optional[Set[str]] = None, + deadline: float = 0.0, + ) -> Optional[str]: + """ + Get the best available credential without acquiring. + + Useful for checking availability or manual acquisition. + + Args: + model: Model to use + quota_group: Optional quota group + exclude: Set of accessors to exclude + deadline: Request deadline + + Returns: + Credential accessor, or None if none available + """ + # Convert exclude from accessors to stable_ids + exclude_ids = set() + if exclude: + for accessor in exclude: + stable_id = self._registry.get_stable_id(accessor, self.provider) + exclude_ids.add(stable_id) + + # Normalize model name for consistent selection + normalized_model = self._normalize_model(model) + + stable_id = self._selection.select( + provider=self.provider, + model=normalized_model, + states=self._states, + quota_group=quota_group, + exclude=exclude_ids, + deadline=deadline, + ) + + if stable_id is None: + return None + + return self._states[stable_id].accessor + + async def record_usage( + self, + accessor: str, + model: str, + success: bool, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + error: Optional[ClassifiedError] = None, + quota_group: Optional[str] = None, + ) -> None: + """ + Record usage for a credential (manual recording). + + Use this for manual tracking outside of context manager. + + Args: + accessor: Credential accessor + model: Model used + success: Whether request succeeded + prompt_tokens: Prompt tokens used + completion_tokens: Completion tokens used + thinking_tokens: Thinking tokens used + prompt_tokens_cache_read: Cached prompt tokens read + prompt_tokens_cache_write: Cached prompt tokens written + approx_cost: Approximate cost + error: Classified error if failed + quota_group: Quota group + """ + stable_id = self._registry.get_stable_id(accessor, self.provider) + + if success: + await self._record_success( + stable_id, + model, + quota_group, + prompt_tokens, + completion_tokens, + thinking_tokens, + prompt_tokens_cache_read, + prompt_tokens_cache_write, + approx_cost, + ) + else: + await self._record_failure( + stable_id, + model, + quota_group, + error, + request_count=1, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + async def _handle_request_complete( + self, + stable_id: str, + model: str, + quota_group: Optional[str], + success: bool, + response: Optional[Any], + response_headers: Optional[Dict[str, Any]], + error: Optional[ClassifiedError], + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + ) -> None: + """Handle provider hooks and record request outcome.""" + state = self._states.get(stable_id) + if not state: + return + + normalized_model = self._normalize_model(model) + group_key = quota_group or self._get_model_quota_group(normalized_model) + + hook_result: Optional[RequestCompleteResult] = None + if self._hooks: + hook_result = await self._hooks.dispatch_request_complete( + provider=self.provider, + credential=state.accessor, + model=normalized_model, + success=success, + response=response, + error=error, + ) + + request_count = 1 + cooldown_override = None + force_exhausted = False + + if hook_result: + if hook_result.count_override is not None: + request_count = max(0, hook_result.count_override) + cooldown_override = hook_result.cooldown_override + force_exhausted = hook_result.force_exhausted + + if not success and error and hook_result is None: + if error.error_type in {"server_error", "api_connection"}: + request_count = 0 + + if request_count == 0: + prompt_tokens = 0 + completion_tokens = 0 + thinking_tokens = 0 + prompt_tokens_cache_read = 0 + prompt_tokens_cache_write = 0 + approx_cost = 0.0 + + if cooldown_override: + await self._tracking.apply_cooldown( + state=state, + reason="provider_hook", + duration=cooldown_override, + model_or_group=group_key, + source="provider_hook", + ) + + if force_exhausted: + await self._tracking.mark_exhausted( + state=state, + model_or_group=self._resolve_fair_cycle_key( + group_key or normalized_model + ), + reason="provider_hook", + ) + + if success: + await self._record_success( + stable_id, + normalized_model, + quota_group, + prompt_tokens, + completion_tokens, + thinking_tokens, + prompt_tokens_cache_read, + prompt_tokens_cache_write, + approx_cost, + response_headers=response_headers, + request_count=request_count, + ) + else: + await self._record_failure( + stable_id, + normalized_model, + quota_group, + error, + request_count=request_count, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + async def apply_cooldown( + self, + accessor: str, + duration: float, + reason: str = "manual", + model_or_group: Optional[str] = None, + ) -> None: + """ + Apply a cooldown to a credential. + + Args: + accessor: Credential accessor + duration: Cooldown duration in seconds + reason: Reason for cooldown + model_or_group: Scope of cooldown + """ + stable_id = self._registry.get_stable_id(accessor, self.provider) + state = self._states.get(stable_id) + if state: + await self._tracking.apply_cooldown( + state=state, + reason=reason, + duration=duration, + model_or_group=model_or_group, + ) + await self._save_if_needed() + + async def get_availability_stats( + self, + model: str, + quota_group: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Get availability statistics for credentials. + + Args: + model: Model to check + quota_group: Quota group + + Returns: + Dict with availability info + """ + return self._selection.get_availability_stats( + provider=self.provider, + model=model, + states=self._states, + quota_group=quota_group, + ) + + async def get_stats_for_endpoint( + self, + model_filter: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Get comprehensive stats suitable for status endpoints. + + Returns credential states, usage windows, cooldowns, and fair cycle state. + + Args: + model_filter: Optional model to filter stats for + + Returns: + Dict with comprehensive statistics + """ + stats = { + "provider": self.provider, + "credential_count": len(self._states), + "rotation_mode": self._config.rotation_mode.value, + "credentials": {}, + } + + stats.update( + { + "active_count": 0, + "exhausted_count": 0, + "total_requests": 0, + "tokens": { + "input_cached": 0, + "input_uncached": 0, + "input_cache_pct": 0, + "output": 0, + }, + "approx_cost": None, + "quota_groups": {}, + } + ) + + for stable_id, state in self._states.items(): + now = time.time() + + # Determine credential status with proper granularity + status = "active" + has_global_cooldown = False + has_group_cooldown = False + fc_exhausted_groups = [] + + # Check cooldowns (global vs per-group) + for key, cooldown in state.cooldowns.items(): + if cooldown.until > now: + if key == "_global_": + has_global_cooldown = True + else: + has_group_cooldown = True + + # Check fair cycle per group + for group_key, fc_state in state.fair_cycle.items(): + if fc_state.exhausted: + fc_exhausted_groups.append(group_key) + + # Determine final status + known_groups = set(state.group_usage.keys()) if state.group_usage else set() + + if has_global_cooldown: + status = "cooldown" + elif fc_exhausted_groups: + # Check if ALL known groups are exhausted + if known_groups and set(fc_exhausted_groups) >= known_groups: + status = "exhausted" + else: + status = "mixed" # Some groups available + elif has_group_cooldown: + status = "cooldown" + + cred_stats = { + "stable_id": stable_id, + "accessor_masked": mask_credential(state.accessor, style="full"), + "full_path": state.accessor, + "identifier": mask_credential(state.accessor, style="full"), + "email": state.display_name, + "tier": state.tier, + "priority": state.priority, + "active_requests": state.active_requests, + "status": status, + "totals": { + "request_count": state.totals.request_count, + "success_count": state.totals.success_count, + "failure_count": state.totals.failure_count, + "prompt_tokens": state.totals.prompt_tokens, + "completion_tokens": state.totals.completion_tokens, + "thinking_tokens": state.totals.thinking_tokens, + "output_tokens": state.totals.output_tokens, + "prompt_tokens_cache_read": state.totals.prompt_tokens_cache_read, + "prompt_tokens_cache_write": state.totals.prompt_tokens_cache_write, + "total_tokens": state.totals.total_tokens, + "approx_cost": state.totals.approx_cost, + "first_used_at": state.totals.first_used_at, + "last_used_at": state.totals.last_used_at, + }, + "model_usage": {}, + "group_usage": {}, + "cooldowns": {}, + "fair_cycle": {}, + } + + stats["total_requests"] += state.totals.request_count + stats["tokens"]["output"] += state.totals.output_tokens + stats["tokens"]["input_cached"] += state.totals.prompt_tokens_cache_read + # prompt_tokens in LiteLLM = uncached tokens (not total input) + # Total input = prompt_tokens + prompt_tokens_cache_read + stats["tokens"]["input_uncached"] += state.totals.prompt_tokens + if state.totals.approx_cost: + stats["approx_cost"] = ( + stats["approx_cost"] or 0.0 + ) + state.totals.approx_cost + + if status == "active": + stats["active_count"] += 1 + elif status == "exhausted": + stats["exhausted_count"] += 1 + + # Add model usage stats + for model_key, model_stats in state.model_usage.items(): + model_windows = {} + for window_name, window in model_stats.windows.items(): + model_windows[window_name] = { + "request_count": window.request_count, + "success_count": window.success_count, + "failure_count": window.failure_count, + "prompt_tokens": window.prompt_tokens, + "completion_tokens": window.completion_tokens, + "thinking_tokens": window.thinking_tokens, + "output_tokens": window.output_tokens, + "prompt_tokens_cache_read": window.prompt_tokens_cache_read, + "prompt_tokens_cache_write": window.prompt_tokens_cache_write, + "total_tokens": window.total_tokens, + "limit": window.limit, + "remaining": window.remaining, + "max_recorded_requests": window.max_recorded_requests, + "max_recorded_at": window.max_recorded_at, + "reset_at": window.reset_at, + "approx_cost": window.approx_cost, + "first_used_at": window.first_used_at, + "last_used_at": window.last_used_at, + } + cred_stats["model_usage"][model_key] = { + "windows": model_windows, + "totals": { + "request_count": model_stats.totals.request_count, + "success_count": model_stats.totals.success_count, + "failure_count": model_stats.totals.failure_count, + "prompt_tokens": model_stats.totals.prompt_tokens, + "completion_tokens": model_stats.totals.completion_tokens, + "thinking_tokens": model_stats.totals.thinking_tokens, + "output_tokens": model_stats.totals.output_tokens, + "prompt_tokens_cache_read": model_stats.totals.prompt_tokens_cache_read, + "prompt_tokens_cache_write": model_stats.totals.prompt_tokens_cache_write, + "total_tokens": model_stats.totals.total_tokens, + "approx_cost": model_stats.totals.approx_cost, + "first_used_at": model_stats.totals.first_used_at, + "last_used_at": model_stats.totals.last_used_at, + }, + } + + # Add group usage stats + for group_key, group_stats in state.group_usage.items(): + group_windows = {} + for window_name, window in group_stats.windows.items(): + group_windows[window_name] = { + "request_count": window.request_count, + "success_count": window.success_count, + "failure_count": window.failure_count, + "prompt_tokens": window.prompt_tokens, + "completion_tokens": window.completion_tokens, + "thinking_tokens": window.thinking_tokens, + "output_tokens": window.output_tokens, + "prompt_tokens_cache_read": window.prompt_tokens_cache_read, + "prompt_tokens_cache_write": window.prompt_tokens_cache_write, + "total_tokens": window.total_tokens, + "limit": window.limit, + "remaining": window.remaining, + "max_recorded_requests": window.max_recorded_requests, + "max_recorded_at": window.max_recorded_at, + "reset_at": window.reset_at, + "approx_cost": window.approx_cost, + "first_used_at": window.first_used_at, + "last_used_at": window.last_used_at, + } + cred_stats["group_usage"][group_key] = { + "windows": group_windows, + "totals": { + "request_count": group_stats.totals.request_count, + "success_count": group_stats.totals.success_count, + "failure_count": group_stats.totals.failure_count, + "prompt_tokens": group_stats.totals.prompt_tokens, + "completion_tokens": group_stats.totals.completion_tokens, + "thinking_tokens": group_stats.totals.thinking_tokens, + "output_tokens": group_stats.totals.output_tokens, + "prompt_tokens_cache_read": group_stats.totals.prompt_tokens_cache_read, + "prompt_tokens_cache_write": group_stats.totals.prompt_tokens_cache_write, + "total_tokens": group_stats.totals.total_tokens, + "approx_cost": group_stats.totals.approx_cost, + "first_used_at": group_stats.totals.first_used_at, + "last_used_at": group_stats.totals.last_used_at, + }, + } + + # Add per-group status info for this credential + group_data = cred_stats["group_usage"][group_key] + + # Fair cycle status for this group + fc_state = state.fair_cycle.get(group_key) + group_data["fair_cycle_exhausted"] = ( + fc_state.exhausted if fc_state else False + ) + group_data["fair_cycle_reason"] = ( + fc_state.exhausted_reason + if fc_state and fc_state.exhausted + else None + ) + + # Group-specific cooldown + group_cooldown = state.cooldowns.get(group_key) + if group_cooldown and group_cooldown.is_active: + group_data["cooldown_remaining"] = int( + group_cooldown.remaining_seconds + ) + group_data["cooldown_source"] = group_cooldown.source + else: + group_data["cooldown_remaining"] = None + group_data["cooldown_source"] = None + + # Custom cap info for this group + cap = self._limits.custom_cap_checker.get_cap_for( + state, group_key, group_key + ) + if cap: + # Get usage from primary window + primary_window = group_windows.get( + self._window_manager.get_primary_definition().name + if self._window_manager.get_primary_definition() + else "5h" + ) + cap_used = ( + primary_window.get("request_count", 0) if primary_window else 0 + ) + cap_limit = cap.max_requests + # Handle percentage-based caps (negative means percentage) + if cap_limit < 0: + api_limit = ( + primary_window.get("limit") if primary_window else None + ) + if api_limit: + cap_limit = int(api_limit * (-cap_limit) / 100) + else: + cap_limit = 0 + group_data["custom_cap"] = { + "limit": cap_limit, + "used": cap_used, + "remaining": max(0, cap_limit - cap_used), + } + else: + group_data["custom_cap"] = None + + # Aggregate quota group stats with per-window breakdown + group_agg = stats["quota_groups"].setdefault( + group_key, + { + "tiers": {}, # Credential tier counts (provider-level) + "windows": {}, # Per-window aggregated stats + "fair_cycle_summary": { # FC status across all credentials + "exhausted_count": 0, + "total_count": 0, + }, + }, + ) + + # Update fair cycle summary for this group + group_agg["fair_cycle_summary"]["total_count"] += 1 + if group_data.get("fair_cycle_exhausted"): + group_agg["fair_cycle_summary"]["exhausted_count"] += 1 + + # Add credential to tier count (provider-level, not per-window) + tier_key = state.tier or "unknown" + tier_stats = group_agg["tiers"].setdefault( + tier_key, + {"priority": state.priority or 0, "total": 0}, + ) + tier_stats["total"] += 1 + + # Aggregate per-window stats + for window_name, window in group_windows.items(): + window_agg = group_agg[ + "windows" + ].setdefault( + window_name, + { + "total_used": 0, + "total_remaining": 0, + "total_max": 0, + "remaining_pct": None, + "tier_availability": {}, # Per-window credential availability + }, + ) + + # Track tier availability for this window + tier_avail = window_agg["tier_availability"].setdefault( + tier_key, + {"total": 0, "available": 0}, + ) + tier_avail["total"] += 1 + + # Check if this credential has quota remaining in this window + limit = window.get("limit") + if limit is not None: + used = window["request_count"] + remaining = max(0, limit - used) + window_agg["total_used"] += used + window_agg["total_remaining"] += remaining + window_agg["total_max"] += limit + + # Credential has availability if remaining > 0 + if remaining > 0: + tier_avail["available"] += 1 + else: + # No limit = unlimited = always available + tier_avail["available"] += 1 + + # Add active cooldowns + for key, cooldown in state.cooldowns.items(): + if cooldown.is_active: + cred_stats["cooldowns"][key] = { + "reason": cooldown.reason, + "remaining_seconds": cooldown.remaining_seconds, + "source": cooldown.source, + } + + # Add fair cycle state + for key, fc_state in state.fair_cycle.items(): + if model_filter and key != model_filter: + continue + cred_stats["fair_cycle"][key] = { + "exhausted": fc_state.exhausted, + "cycle_request_count": fc_state.cycle_request_count, + } + + # Sort group_usage by quota limit (lowest first), then alphabetically + # This ensures detail view matches the global summary sort order + def group_sort_key(item): + group_name, group_data = item + windows = group_data.get("windows", {}) + if not windows: + return (float("inf"), group_name) # No windows = sort last + + # Find minimum limit across windows + min_limit = float("inf") + for window_data in windows.values(): + limit = window_data.get("limit") + if limit is not None and limit > 0: + min_limit = min(min_limit, limit) + + return (min_limit, group_name) + + sorted_group_usage = dict( + sorted(cred_stats["group_usage"].items(), key=group_sort_key) + ) + cred_stats["group_usage"] = sorted_group_usage + + stats["credentials"][stable_id] = cred_stats + + # Calculate remaining percentages for each window in each quota group + for group_stats in stats["quota_groups"].values(): + for window_stats in group_stats.get("windows", {}).values(): + if window_stats["total_max"] > 0: + window_stats["remaining_pct"] = round( + window_stats["total_remaining"] + / window_stats["total_max"] + * 100, + 1, + ) + + total_input = ( + stats["tokens"]["input_cached"] + stats["tokens"]["input_uncached"] + ) + stats["tokens"]["input_cache_pct"] = ( + round(stats["tokens"]["input_cached"] / total_input * 100, 1) + if total_input > 0 + else 0 + ) + + return stats + + def _get_provider_plugin_instance(self) -> Optional[Any]: + """Get provider plugin instance for the current provider.""" + if not self._provider_plugins: + return None + + # Provider plugins dict maps provider name -> plugin class or instance + plugin = self._provider_plugins.get(self.provider) + if plugin is None: + return None + + # If it's a class, instantiate it (singleton via metaclass); if already an instance, use directly + if isinstance(plugin, type): + return plugin() + return plugin + + def _normalize_model(self, model: str) -> str: + """ + Normalize model name using provider's mapping. + + Converts internal model names (e.g., claude-sonnet-4-5-thinking) to + public-facing names (e.g., claude-sonnet-4.5) for consistent storage + and tracking. + + Args: + model: Model name (with or without provider prefix) + + Returns: + Normalized model name (provider prefix preserved if present) + """ + plugin_instance = self._get_provider_plugin_instance() + + if plugin_instance and hasattr(plugin_instance, "normalize_model_for_tracking"): + return plugin_instance.normalize_model_for_tracking(model) + + return model + + def _get_model_quota_group(self, model: str) -> Optional[str]: + """ + Get the quota group for a model, if the provider defines one. + + Models in the same quota group share a single quota pool. + For example, all Claude models in Antigravity share the same daily quota. + + Args: + model: Model name (with or without provider prefix) + + Returns: + Group name (e.g., "claude") or None if not grouped + """ + plugin_instance = self._get_provider_plugin_instance() + + if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"): + return plugin_instance.get_model_quota_group(model) + + return None + + def get_model_quota_group(self, model: str) -> Optional[str]: + """Public helper to get quota group for a model.""" + normalized_model = self._normalize_model(model) + return self._get_model_quota_group(normalized_model) + + def _get_grouped_models(self, group: str) -> List[str]: + """ + Get all model names in a quota group (with provider prefix), normalized. + + Returns only public-facing model names, deduplicated. Internal variants + (e.g., claude-sonnet-4-5-thinking) are normalized to their public name + (e.g., claude-sonnet-4.5). + + Args: + group: Group name (e.g., "claude") + + Returns: + List of normalized, deduplicated model names with provider prefix + (e.g., ["antigravity/claude-sonnet-4.5", "antigravity/claude-opus-4.5"]) + """ + plugin_instance = self._get_provider_plugin_instance() + + if plugin_instance and hasattr(plugin_instance, "get_models_in_quota_group"): + models = plugin_instance.get_models_in_quota_group(group) + + # Normalize and deduplicate + if hasattr(plugin_instance, "normalize_model_for_tracking"): + seen: Set[str] = set() + normalized: List[str] = [] + for m in models: + prefixed = f"{self.provider}/{m}" + norm = plugin_instance.normalize_model_for_tracking(prefixed) + if norm not in seen: + seen.add(norm) + normalized.append(norm) + return normalized + + # Fallback: just add provider prefix + return [f"{self.provider}/{m}" for m in models] + + return [] + + async def save(self, force: bool = False) -> bool: + """ + Save usage data to file. + + Args: + force: Force save even if debounce not elapsed + + Returns: + True if saved successfully + """ + if self._storage: + fair_cycle_global = self._limits.fair_cycle_checker.get_global_state_dict() + return await self._storage.save( + self._states, fair_cycle_global, force=force + ) + return False + + async def get_usage_snapshot(self) -> Dict[str, Dict[str, Any]]: + """ + Get a lightweight usage snapshot keyed by accessor. + + Returns: + Dict mapping accessor -> usage metadata. + """ + async with self._lock: + snapshot: Dict[str, Dict[str, Any]] = {} + for state in self._states.values(): + snapshot[state.accessor] = { + "last_used_ts": state.totals.last_used_at or 0, + } + return snapshot + + async def shutdown(self) -> None: + """Shutdown and save any pending data.""" + await self.save(force=True) + + async def reload_from_disk(self) -> None: + """ + Force reload usage data from disk. + + Useful when wanting fresh stats without making external API calls. + This reloads persisted state while preserving current credential registrations. + """ + if not self._storage: + lib_logger.debug( + f"reload_from_disk: No storage configured for {self.provider}" + ) + return + + async with self._lock: + # Load persisted state + loaded_states, fair_cycle_global, _ = await self._storage.load() + + # Merge loaded state with current state + # Keep current accessors but update usage data + for stable_id, loaded_state in loaded_states.items(): + if stable_id in self._states: + # Update usage data from loaded state + current = self._states[stable_id] + current.model_usage = loaded_state.model_usage + current.group_usage = loaded_state.group_usage + current.totals = loaded_state.totals + current.cooldowns = loaded_state.cooldowns + current.fair_cycle = loaded_state.fair_cycle + current.last_updated = loaded_state.last_updated + else: + # New credential from disk, add it + self._states[stable_id] = loaded_state + + # Reload fair cycle global state + if fair_cycle_global: + self._limits.fair_cycle_checker.load_global_state_dict( + fair_cycle_global + ) + + lib_logger.info( + f"Reloaded usage data from disk for {self.provider}: " + f"{len(self._states)} credentials" + ) + + async def update_quota_baseline( + self, + accessor: str, + model: str, + quota_max_requests: Optional[int] = None, + quota_reset_ts: Optional[float] = None, + quota_used: Optional[int] = None, + quota_group: Optional[str] = None, + force: bool = False, + apply_exhaustion: bool = False, + ) -> Optional[Dict[str, Any]]: + """ + Update quota baseline from provider API response. + + Called by provider plugins after receiving rate limit headers or + quota information from API responses. + + Args: + accessor: Credential accessor (path or key) + model: Model name + quota_max_requests: Max requests allowed in window + quota_reset_ts: When quota resets (Unix timestamp) + quota_used: Current used count from API + quota_group: Optional quota group (uses model if None) + force: If True, always use API values (for manual refresh). + If False (default), use max(local, api) to prevent stale + API data from overwriting accurate local counts during + background fetches. + apply_exhaustion: If True, apply cooldown for exhausted quota. + Provider controls when this is set based on its semantics + (e.g., Antigravity only on initial fetch, others always + when remaining == 0). + + Returns: + Cooldown info dict if cooldown was applied, None otherwise + """ + stable_id = self._registry.get_stable_id(accessor, self.provider) + state = self._states.get(stable_id) + if not state: + lib_logger.warning( + f"update_quota_baseline: Unknown credential {accessor[:20]}..." + ) + return None + + # Normalize model name for consistent tracking + normalized_model = self._normalize_model(model) + group_key = quota_group or self._get_model_quota_group(normalized_model) + + primary_def = self._window_manager.get_primary_definition() + + # Update windows based on quota scope + # If group_key exists, quota is at group level - only update group stats + # We can't know which model the requests went to from API-level quota + if group_key: + group_stats = state.get_group_stats(group_key) + if primary_def: + group_window = self._window_manager.get_or_create_window( + group_stats.windows, primary_def.name + ) + self._apply_quota_update( + group_window, quota_max_requests, quota_reset_ts, quota_used, force + ) + else: + # No quota group - model IS the quota scope, update model stats + model_stats = state.get_model_stats(normalized_model) + if primary_def: + model_window = self._window_manager.get_or_create_window( + model_stats.windows, primary_def.name + ) + self._apply_quota_update( + model_window, quota_max_requests, quota_reset_ts, quota_used, force + ) + + # Mark state as updated + state.last_updated = time.time() + + # Apply cooldown if provider indicates exhaustion + # Provider controls when apply_exhaustion is set based on its semantics + if apply_exhaustion: + cooldown_target = group_key or normalized_model + if quota_reset_ts: + await self._tracking.apply_cooldown( + state=state, + reason="quota_exhausted", + until=quota_reset_ts, + model_or_group=cooldown_target, + source="api_quota", + ) + + await self._queue_quota_exhausted_log( + accessor=accessor, + group_key=cooldown_target, + quota_reset_ts=quota_reset_ts, + ) + + await self._save_if_needed() + + return { + "cooldown_until": quota_reset_ts, + "reason": "quota_exhausted", + "model": model, + "cooldown_hours": max(0.0, (quota_reset_ts - time.time()) / 3600), + } + else: + # ERROR: Provider says exhausted but no reset timestamp! + lib_logger.error( + f"Quota exhausted for {cooldown_target} on " + f"{mask_credential(accessor, style='full')} but no reset_timestamp " + f"provided by API - cannot apply cooldown" + ) + + await self._save_if_needed() + + return None + + def _apply_quota_update( + self, + window: WindowStats, + quota_max_requests: Optional[int], + quota_reset_ts: Optional[float], + quota_used: Optional[int], + force: bool, + ) -> None: + """Apply quota update to a window.""" + if quota_max_requests is not None: + window.limit = quota_max_requests + + # Determine if there's actual usage (either API-reported or local) + has_usage = ( + quota_used is not None and quota_used > 0 + ) or window.request_count > 0 + + # Only set started_at and reset_at if there's actual usage + # This prevents bogus reset times for unused windows + if has_usage: + if quota_reset_ts is not None: + window.reset_at = quota_reset_ts + # Set started_at to now if not already set (API shows usage we don't have locally) + if window.started_at is None: + window.started_at = time.time() + + if quota_used is not None: + if force: + synced_count = quota_used + else: + synced_count = max( + window.request_count, + quota_used, + window.success_count + window.failure_count, + ) + self._reconcile_window_counts(window, synced_count) + + def _reconcile_window_counts(self, window: WindowStats, request_count: int) -> None: + """Reconcile window counts after quota sync.""" + local_total = window.success_count + window.failure_count + window.request_count = request_count + if local_total == 0 and request_count > 0: + window.success_count = request_count + window.failure_count = 0 + return + + if request_count < local_total: + failure_count = min(window.failure_count, request_count) + success_count = max(0, request_count - failure_count) + window.success_count = success_count + window.failure_count = failure_count + return + + if request_count > local_total: + window.success_count += request_count - local_total + + # ========================================================================= + # PROPERTIES + # ========================================================================= + + @property + def config(self) -> ProviderUsageConfig: + """Get the configuration.""" + return self._config + + @property + def registry(self) -> CredentialRegistry: + """Get the credential registry.""" + return self._registry + + @property + def api(self) -> UsageAPI: + """Get the usage API facade.""" + return self._api + + @property + def initialized(self) -> bool: + """Check if the manager is initialized.""" + return self._initialized + + @property + def tracking(self) -> TrackingEngine: + """Get the tracking engine.""" + return self._tracking + + @property + def limits(self) -> LimitEngine: + """Get the limit engine.""" + return self._limits + + @property + def window_manager(self) -> WindowManager: + """Get the window manager.""" + return self._window_manager + + @property + def selection(self) -> SelectionEngine: + """Get the selection engine.""" + return self._selection + + @property + def states(self) -> Dict[str, CredentialState]: + """Get all credential states.""" + return self._states + + @property + def loaded_from_storage(self) -> bool: + """Whether usage data was loaded from storage.""" + return self._loaded_from_storage + + @property + def loaded_credentials(self) -> int: + """Number of credentials loaded from storage.""" + return self._loaded_count + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _resolve_fair_cycle_key(self, group_key: str) -> str: + """Resolve fair cycle tracking key based on config.""" + if self._config.fair_cycle.tracking_mode == TrackingMode.CREDENTIAL: + return FAIR_CYCLE_GLOBAL_KEY + return group_key + + async def _release_credential(self, stable_id: str, model: str) -> None: + """Release a credential after use and notify waiting tasks.""" + state = self._states.get(stable_id) + if not state: + return + + # Decrement active requests + lock = self._key_locks.get(stable_id) + if lock: + async with lock: + state.active_requests = max(0, state.active_requests - 1) + else: + state.active_requests = max(0, state.active_requests - 1) + + # Log release with current state + remaining = state.active_requests + max_concurrent = state.max_concurrent + lib_logger.info( + f"Released credential {mask_credential(state.accessor, style='full')} " + f"from {model} (remaining concurrent: {remaining}" + f"{f'/{max_concurrent}' if max_concurrent else ''})" + ) + + # Notify all tasks waiting on this credential's condition + condition = self._key_conditions.get(stable_id) + if condition: + async with condition: + condition.notify_all() + + async def _queue_quota_exhausted_log( + self, accessor: str, group_key: str, quota_reset_ts: float + ) -> None: + async with self._quota_exhausted_lock: + masked = mask_credential(accessor, style="full") + if masked not in self._quota_exhausted_summary: + self._quota_exhausted_summary[masked] = {} + self._quota_exhausted_summary[masked][group_key] = quota_reset_ts + + if self._quota_exhausted_task is None or self._quota_exhausted_task.done(): + self._quota_exhausted_task = asyncio.create_task( + self._flush_quota_exhausted_log() + ) + + async def _flush_quota_exhausted_log(self) -> None: + await asyncio.sleep(0.2) + async with self._quota_exhausted_lock: + summary = self._quota_exhausted_summary + self._quota_exhausted_summary = {} + + if not summary: + return + + now = time.time() + parts = [] + for accessor, groups in sorted(summary.items()): + group_parts = [] + for group, reset_ts in sorted(groups.items()): + hours = max(0.0, (reset_ts - now) / 3600) if reset_ts else 0.0 + group_parts.append(f"{group} {hours:.1f}h") + parts.append(f"{accessor}[{', '.join(group_parts)}]") + + lib_logger.info(f"Quota exhausted: {', '.join(parts)}") + + async def _save_if_needed(self) -> None: + """Persist state if storage is configured.""" + if not self._storage: + return + fair_cycle_global = self._limits.fair_cycle_checker.get_global_state_dict() + saved = await self._storage.save(self._states, fair_cycle_global) + if not saved: + await self._schedule_save_flush() + + async def _schedule_save_flush(self) -> None: + if self._save_task and not self._save_task.done(): + return + self._save_task = asyncio.create_task(self._flush_save()) + + async def _flush_save(self) -> None: + async with self._save_lock: + await asyncio.sleep(self._storage.save_debounce_seconds) + if not self._storage: + return + fair_cycle_global = self._limits.fair_cycle_checker.get_global_state_dict() + await self._storage.save_if_dirty(self._states, fair_cycle_global) + + async def _record_success( + self, + stable_id: str, + model: str, + quota_group: Optional[str] = None, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + response_headers: Optional[Dict[str, Any]] = None, + request_count: int = 1, + ) -> None: + """Record a successful request.""" + state = self._states.get(stable_id) + if state: + # Normalize model name for consistent tracking + normalized_model = self._normalize_model(model) + group_key = quota_group or self._get_model_quota_group(normalized_model) + + await self._tracking.record_success( + state=state, + model=normalized_model, + quota_group=group_key, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + response_headers=response_headers, + request_count=request_count, + ) + + # Apply custom cap cooldown if exceeded + cap_result = self._limits.custom_cap_checker.check( + state, normalized_model, group_key + ) + if ( + not cap_result.allowed + and cap_result.result == LimitResult.BLOCKED_CUSTOM_CAP + and cap_result.blocked_until + ): + await self._tracking.apply_cooldown( + state=state, + reason="custom_cap", + until=cap_result.blocked_until, + model_or_group=group_key or normalized_model, + source="custom_cap", + ) + + await self._save_if_needed() + + async def _record_failure( + self, + stable_id: str, + model: str, + quota_group: Optional[str] = None, + error: Optional[ClassifiedError] = None, + request_count: int = 1, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + ) -> None: + """Record a failed request.""" + state = self._states.get(stable_id) + if not state: + return + + # Normalize model name for consistent tracking + normalized_model = self._normalize_model(model) + group_key = quota_group or self._get_model_quota_group(normalized_model) + + # Determine cooldown from error + cooldown_duration = None + quota_reset = None + mark_exhausted = False + + if error: + cooldown_duration = error.retry_after + quota_reset = error.quota_reset_timestamp + + # Mark exhausted for quota errors with long cooldown + if error.error_type == "quota_exceeded": + if ( + cooldown_duration + and cooldown_duration >= self._config.exhaustion_cooldown_threshold + ): + mark_exhausted = True + + # Log quota exhaustion like legacy system + masked = mask_credential(state.accessor, style="full") + cooldown_target = group_key or normalized_model + + if quota_reset: + reset_dt = datetime.fromtimestamp(quota_reset, tz=timezone.utc) + hours = max(0.0, (quota_reset - time.time()) / 3600) + lib_logger.info( + f"Quota exhausted for '{cooldown_target}' on {masked}. " + f"Resets at {reset_dt.isoformat()} ({hours:.1f}h)" + ) + elif cooldown_duration: + hours = cooldown_duration / 3600 + lib_logger.info( + f"Quota exhausted on {masked} for '{cooldown_target}'. " + f"Cooldown: {cooldown_duration:.0f}s ({hours:.1f}h)" + ) + + await self._tracking.record_failure( + state=state, + model=normalized_model, + error_type=error.error_type if error else "unknown", + quota_group=group_key, + cooldown_duration=cooldown_duration, + quota_reset_timestamp=quota_reset, + mark_exhausted=mark_exhausted, + request_count=request_count, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + # Log fair cycle marking like legacy system + if mark_exhausted and self._config.fair_cycle.enabled: + fc_key = group_key or normalized_model + # Resolve fair cycle tracking key based on config + if self._config.fair_cycle.tracking_mode == TrackingMode.CREDENTIAL: + fc_key = FAIR_CYCLE_GLOBAL_KEY + exhausted_count = sum( + 1 + for s in self._states.values() + if fc_key in s.fair_cycle and s.fair_cycle[fc_key].exhausted + ) + masked = mask_credential(state.accessor, style="full") + lib_logger.info( + f"Fair cycle: marked {masked} exhausted for {fc_key} ({exhausted_count} total)" + ) + + await self._save_if_needed() diff --git a/src/rotator_library/usage/persistence/__init__.py b/src/rotator_library/usage/persistence/__init__.py new file mode 100644 index 00000000..9bdf7d76 --- /dev/null +++ b/src/rotator_library/usage/persistence/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Usage data persistence.""" + +from .storage import UsageStorage + +__all__ = ["UsageStorage"] diff --git a/src/rotator_library/usage/persistence/storage.py b/src/rotator_library/usage/persistence/storage.py new file mode 100644 index 00000000..06c0af32 --- /dev/null +++ b/src/rotator_library/usage/persistence/storage.py @@ -0,0 +1,493 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Usage data storage. + +Handles loading and saving usage data to JSON files. +""" + +import asyncio +import json +import logging +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from ..types import ( + WindowStats, + TotalStats, + ModelStats, + GroupStats, + CredentialState, + CooldownInfo, + FairCycleState, + GlobalFairCycleState, + StorageSchema, +) +from ...utils.resilient_io import ResilientStateWriter, safe_read_json + +lib_logger = logging.getLogger("rotator_library") + + +def _format_timestamp(ts: Optional[float]) -> Optional[str]: + """Format a unix timestamp as a human-readable local time string.""" + if ts is None: + return None + try: + # Use local timezone for human readability + dt = datetime.fromtimestamp(ts) + return dt.strftime("%Y-%m-%d %H:%M:%S") + except (ValueError, OSError): + return None + + +class UsageStorage: + """ + Handles persistence of usage data to JSON files. + + Features: + - Async file I/O with aiofiles + - Atomic writes (write to temp, then rename) + - Automatic schema migration + - Debounced saves to reduce I/O + """ + + CURRENT_SCHEMA_VERSION = 2 + + def __init__( + self, + file_path: Union[str, Path], + save_debounce_seconds: float = 5.0, + ): + """ + Initialize storage. + + Args: + file_path: Path to the usage.json file + save_debounce_seconds: Minimum time between saves + """ + self.file_path = Path(file_path) + self.save_debounce_seconds = save_debounce_seconds + + self._last_save: float = 0 + self._pending_save: bool = False + self._save_lock = asyncio.Lock() + self._dirty: bool = False + self._writer = ResilientStateWriter(self.file_path, lib_logger) + + async def load( + self, + ) -> tuple[Dict[str, CredentialState], Dict[str, Dict[str, Any]], bool]: + """ + Load usage data from file. + + Returns: + Tuple of (states dict, fair_cycle_global dict, loaded_from_file bool) + """ + if not self.file_path.exists(): + return {}, {}, False + + try: + async with self._file_lock(): + data = safe_read_json(self.file_path, lib_logger, parse_json=True) + + if not data: + return {}, {}, True + + # Check schema version + version = data.get("schema_version", 1) + if version < self.CURRENT_SCHEMA_VERSION: + lib_logger.info( + f"Migrating usage data from v{version} to v{self.CURRENT_SCHEMA_VERSION}" + ) + data = self._migrate(data, version) + + # Parse credentials + states = {} + for stable_id, cred_data in data.get("credentials", {}).items(): + state = self._parse_credential_state(stable_id, cred_data) + if state: + states[stable_id] = state + + lib_logger.info(f"Loaded {len(states)} credentials from {self.file_path}") + return states, data.get("fair_cycle_global", {}), True + + except json.JSONDecodeError as e: + lib_logger.error(f"Failed to parse usage file: {e}") + return {}, {}, True + except Exception as e: + lib_logger.error(f"Failed to load usage file: {e}") + return {}, {}, True + + async def save( + self, + states: Dict[str, CredentialState], + fair_cycle_global: Optional[Dict[str, Dict[str, Any]]] = None, + force: bool = False, + ) -> bool: + """ + Save usage data to file. + + Args: + states: Dict of stable_id -> CredentialState + fair_cycle_global: Global fair cycle state + force: Force save even if debounce not elapsed + + Returns: + True if saved, False if skipped or failed + """ + now = time.time() + + # Check debounce + if not force and (now - self._last_save) < self.save_debounce_seconds: + self._dirty = True + return False + + async with self._save_lock: + try: + # Build storage data + data = { + "schema_version": self.CURRENT_SCHEMA_VERSION, + "updated_at": datetime.now(timezone.utc).isoformat(), + "credentials": {}, + "accessor_index": {}, + "fair_cycle_global": fair_cycle_global or {}, + } + + for stable_id, state in states.items(): + data["credentials"][stable_id] = self._serialize_credential_state( + state + ) + data["accessor_index"][state.accessor] = stable_id + + saved = self._writer.write(data) + + if saved: + self._last_save = now + self._dirty = False + lib_logger.debug( + f"Saved {len(states)} credentials to {self.file_path}" + ) + return True + + self._dirty = True + return False + + except Exception as e: + lib_logger.error(f"Failed to save usage file: {e}") + return False + + async def save_if_dirty( + self, + states: Dict[str, CredentialState], + fair_cycle_global: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> bool: + """ + Save if there are pending changes. + + Args: + states: Dict of stable_id -> CredentialState + fair_cycle_global: Global fair cycle state + + Returns: + True if saved, False otherwise + """ + if self._dirty: + return await self.save(states, fair_cycle_global, force=True) + return False + + def mark_dirty(self) -> None: + """Mark data as changed, needing save.""" + self._dirty = True + + @property + def is_dirty(self) -> bool: + """Check if there are unsaved changes.""" + return self._dirty + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _file_lock(self): + """Get a lock for file operations.""" + return self._save_lock + + def _migrate(self, data: Dict[str, Any], from_version: int) -> Dict[str, Any]: + """Migrate data from older schema versions.""" + if from_version == 1: + # v1 -> v2: Add accessor_index, restructure credentials + data["schema_version"] = 2 + data.setdefault("accessor_index", {}) + data.setdefault("fair_cycle_global", {}) + + # v1 used file paths as keys, v2 uses stable_ids + # For migration, treat paths as stable_ids + old_credentials = data.get("credentials", data.get("key_states", {})) + new_credentials = {} + + for key, cred_data in old_credentials.items(): + # Use path as temporary stable_id + stable_id = cred_data.get("stable_id", key) + new_credentials[stable_id] = cred_data + new_credentials[stable_id]["accessor"] = key + + data["credentials"] = new_credentials + + return data + + def _parse_window_stats(self, name: str, data: Dict[str, Any]) -> WindowStats: + """Parse window stats from storage data.""" + return WindowStats( + name=name, + request_count=data.get("request_count", 0), + success_count=data.get("success_count", 0), + failure_count=data.get("failure_count", 0), + prompt_tokens=data.get("prompt_tokens", 0), + completion_tokens=data.get("completion_tokens", 0), + thinking_tokens=data.get("thinking_tokens", 0), + output_tokens=data.get("output_tokens", 0), + prompt_tokens_cache_read=data.get("prompt_tokens_cache_read", 0), + prompt_tokens_cache_write=data.get("prompt_tokens_cache_write", 0), + total_tokens=data.get("total_tokens", 0), + approx_cost=data.get("approx_cost", 0.0), + started_at=data.get("started_at"), + reset_at=data.get("reset_at"), + limit=data.get("limit"), + max_recorded_requests=data.get("max_recorded_requests"), + max_recorded_at=data.get("max_recorded_at"), + first_used_at=data.get("first_used_at"), + last_used_at=data.get("last_used_at"), + ) + + def _serialize_window_stats(self, window: WindowStats) -> Dict[str, Any]: + """Serialize window stats for storage.""" + return { + "request_count": window.request_count, + "success_count": window.success_count, + "failure_count": window.failure_count, + "prompt_tokens": window.prompt_tokens, + "completion_tokens": window.completion_tokens, + "thinking_tokens": window.thinking_tokens, + "output_tokens": window.output_tokens, + "prompt_tokens_cache_read": window.prompt_tokens_cache_read, + "prompt_tokens_cache_write": window.prompt_tokens_cache_write, + "total_tokens": window.total_tokens, + "approx_cost": window.approx_cost, + "started_at": window.started_at, + "started_at_human": _format_timestamp(window.started_at), + "reset_at": window.reset_at, + "reset_at_human": _format_timestamp(window.reset_at), + "limit": window.limit, + "max_recorded_requests": window.max_recorded_requests, + "max_recorded_at": window.max_recorded_at, + "max_recorded_at_human": _format_timestamp(window.max_recorded_at), + "first_used_at": window.first_used_at, + "first_used_at_human": _format_timestamp(window.first_used_at), + "last_used_at": window.last_used_at, + "last_used_at_human": _format_timestamp(window.last_used_at), + } + + def _parse_total_stats(self, data: Dict[str, Any]) -> TotalStats: + """Parse total stats from storage data.""" + return TotalStats( + request_count=data.get("request_count", 0), + success_count=data.get("success_count", 0), + failure_count=data.get("failure_count", 0), + prompt_tokens=data.get("prompt_tokens", 0), + completion_tokens=data.get("completion_tokens", 0), + thinking_tokens=data.get("thinking_tokens", 0), + output_tokens=data.get("output_tokens", 0), + prompt_tokens_cache_read=data.get("prompt_tokens_cache_read", 0), + prompt_tokens_cache_write=data.get("prompt_tokens_cache_write", 0), + total_tokens=data.get("total_tokens", 0), + approx_cost=data.get("approx_cost", 0.0), + first_used_at=data.get("first_used_at"), + last_used_at=data.get("last_used_at"), + ) + + def _serialize_total_stats(self, totals: TotalStats) -> Dict[str, Any]: + """Serialize total stats for storage.""" + return { + "request_count": totals.request_count, + "success_count": totals.success_count, + "failure_count": totals.failure_count, + "prompt_tokens": totals.prompt_tokens, + "completion_tokens": totals.completion_tokens, + "thinking_tokens": totals.thinking_tokens, + "output_tokens": totals.output_tokens, + "prompt_tokens_cache_read": totals.prompt_tokens_cache_read, + "prompt_tokens_cache_write": totals.prompt_tokens_cache_write, + "total_tokens": totals.total_tokens, + "approx_cost": totals.approx_cost, + "first_used_at": totals.first_used_at, + "first_used_at_human": _format_timestamp(totals.first_used_at), + "last_used_at": totals.last_used_at, + "last_used_at_human": _format_timestamp(totals.last_used_at), + } + + def _parse_model_stats(self, data: Dict[str, Any]) -> ModelStats: + """Parse model stats from storage data.""" + windows = {} + for name, wdata in data.get("windows", {}).items(): + # Skip legacy "total" window - now tracked in totals + if name == "total": + continue + windows[name] = self._parse_window_stats(name, wdata) + + totals = self._parse_total_stats(data.get("totals", {})) + + return ModelStats(windows=windows, totals=totals) + + def _serialize_model_stats(self, stats: ModelStats) -> Dict[str, Any]: + """Serialize model stats for storage.""" + return { + "windows": { + name: self._serialize_window_stats(window) + for name, window in stats.windows.items() + }, + "totals": self._serialize_total_stats(stats.totals), + } + + def _parse_group_stats(self, data: Dict[str, Any]) -> GroupStats: + """Parse group stats from storage data.""" + windows = {} + for name, wdata in data.get("windows", {}).items(): + # Skip legacy "total" window - now tracked in totals + if name == "total": + continue + windows[name] = self._parse_window_stats(name, wdata) + + totals = self._parse_total_stats(data.get("totals", {})) + + return GroupStats(windows=windows, totals=totals) + + def _serialize_group_stats(self, stats: GroupStats) -> Dict[str, Any]: + """Serialize group stats for storage.""" + return { + "windows": { + name: self._serialize_window_stats(window) + for name, window in stats.windows.items() + }, + "totals": self._serialize_total_stats(stats.totals), + } + + def _parse_credential_state( + self, + stable_id: str, + data: Dict[str, Any], + ) -> Optional[CredentialState]: + """Parse a credential state from storage data.""" + try: + # Parse model_usage + model_usage = {} + for key, usage_data in data.get("model_usage", {}).items(): + model_usage[key] = self._parse_model_stats(usage_data) + + # Parse group_usage + group_usage = {} + for key, usage_data in data.get("group_usage", {}).items(): + group_usage[key] = self._parse_group_stats(usage_data) + + # Parse credential-level totals + totals = self._parse_total_stats(data.get("totals", {})) + + # Parse cooldowns + cooldowns = {} + for key, cdata in data.get("cooldowns", {}).items(): + cooldowns[key] = CooldownInfo( + reason=cdata.get("reason", "unknown"), + until=cdata.get("until", 0), + started_at=cdata.get("started_at", 0), + source=cdata.get("source", "system"), + model_or_group=cdata.get("model_or_group"), + backoff_count=cdata.get("backoff_count", 0), + ) + + # Parse fair cycle + fair_cycle = {} + for key, fcdata in data.get("fair_cycle", {}).items(): + fair_cycle[key] = FairCycleState( + exhausted=fcdata.get("exhausted", False), + exhausted_at=fcdata.get("exhausted_at"), + exhausted_reason=fcdata.get("exhausted_reason"), + cycle_request_count=fcdata.get("cycle_request_count", 0), + model_or_group=key, + ) + + return CredentialState( + stable_id=stable_id, + provider=data.get("provider", "unknown"), + accessor=data.get("accessor", stable_id), + display_name=data.get("display_name"), + tier=data.get("tier"), + priority=data.get("priority", 999), + model_usage=model_usage, + group_usage=group_usage, + totals=totals, + cooldowns=cooldowns, + fair_cycle=fair_cycle, + active_requests=0, # Always starts at 0 + max_concurrent=data.get("max_concurrent"), + created_at=data.get("created_at"), + last_updated=data.get("last_updated"), + ) + + except Exception as e: + lib_logger.warning(f"Failed to parse credential {stable_id}: {e}") + return None + + def _serialize_credential_state(self, state: CredentialState) -> Dict[str, Any]: + """Serialize a credential state for storage.""" + # Serialize cooldowns (only active ones) + now = time.time() + cooldowns = {} + for key, cd in state.cooldowns.items(): + if cd.until > now: # Only save active cooldowns + cooldowns[key] = { + "reason": cd.reason, + "until": cd.until, + "until_human": _format_timestamp(cd.until), + "started_at": cd.started_at, + "started_at_human": _format_timestamp(cd.started_at), + "source": cd.source, + "model_or_group": cd.model_or_group, + "backoff_count": cd.backoff_count, + } + + # Serialize fair cycle + fair_cycle = {} + for key, fc in state.fair_cycle.items(): + fair_cycle[key] = { + "exhausted": fc.exhausted, + "exhausted_at": fc.exhausted_at, + "exhausted_at_human": _format_timestamp(fc.exhausted_at), + "exhausted_reason": fc.exhausted_reason, + "cycle_request_count": fc.cycle_request_count, + } + + return { + "provider": state.provider, + "accessor": state.accessor, + "display_name": state.display_name, + "tier": state.tier, + "priority": state.priority, + "model_usage": { + key: self._serialize_model_stats(stats) + for key, stats in state.model_usage.items() + }, + "group_usage": { + key: self._serialize_group_stats(stats) + for key, stats in state.group_usage.items() + }, + "totals": self._serialize_total_stats(state.totals), + "cooldowns": cooldowns, + "fair_cycle": fair_cycle, + "max_concurrent": state.max_concurrent, + "created_at": state.created_at, + "created_at_human": _format_timestamp(state.created_at), + "last_updated": state.last_updated, + "last_updated_human": _format_timestamp(state.last_updated), + } diff --git a/src/rotator_library/usage/selection/__init__.py b/src/rotator_library/usage/selection/__init__.py new file mode 100644 index 00000000..79824e51 --- /dev/null +++ b/src/rotator_library/usage/selection/__init__.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Credential selection and rotation strategies.""" + +from .engine import SelectionEngine +from .strategies.balanced import BalancedStrategy +from .strategies.sequential import SequentialStrategy + +__all__ = [ + "SelectionEngine", + "BalancedStrategy", + "SequentialStrategy", +] diff --git a/src/rotator_library/usage/selection/engine.py b/src/rotator_library/usage/selection/engine.py new file mode 100644 index 00000000..f0186f16 --- /dev/null +++ b/src/rotator_library/usage/selection/engine.py @@ -0,0 +1,419 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Selection engine for credential selection. + +Central component that orchestrates limit checking, modifiers, +and rotation strategies to select the best credential. +""" + +import time +import logging +from typing import Any, Dict, List, Optional, Set, Union + +from ..types import ( + CredentialState, + SelectionContext, + RotationMode, + LimitCheckResult, +) +from ..config import ProviderUsageConfig +from ..limits.engine import LimitEngine +from ..tracking.windows import WindowManager +from .strategies.balanced import BalancedStrategy +from .strategies.sequential import SequentialStrategy + +lib_logger = logging.getLogger("rotator_library") + + +class SelectionEngine: + """ + Central engine for credential selection. + + Orchestrates: + 1. Limit checking (filter unavailable credentials) + 2. Fair cycle modifiers (filter exhausted credentials) + 3. Rotation strategy (select from available) + """ + + def __init__( + self, + config: ProviderUsageConfig, + limit_engine: LimitEngine, + window_manager: WindowManager, + ): + """ + Initialize selection engine. + + Args: + config: Provider usage configuration + limit_engine: LimitEngine for availability checks + """ + self._config = config + self._limits = limit_engine + self._windows = window_manager + + # Initialize strategies + self._balanced = BalancedStrategy(config.rotation_tolerance) + self._sequential = SequentialStrategy(config.sequential_fallback_multiplier) + + # Current strategy + if config.rotation_mode == RotationMode.SEQUENTIAL: + self._strategy = self._sequential + else: + self._strategy = self._balanced + + def select( + self, + provider: str, + model: str, + states: Dict[str, CredentialState], + quota_group: Optional[str] = None, + exclude: Optional[Set[str]] = None, + priorities: Optional[Dict[str, int]] = None, + deadline: float = 0.0, + ) -> Optional[str]: + """ + Select the best available credential. + + Args: + provider: Provider name + model: Model being requested + states: Dict of stable_id -> CredentialState + quota_group: Quota group for this model + exclude: Set of stable_ids to exclude + priorities: Override priorities (stable_id -> priority) + deadline: Request deadline timestamp + + Returns: + Selected stable_id, or None if none available + """ + exclude = exclude or set() + + # Step 1: Get all candidates (not excluded) + candidates = [sid for sid in states.keys() if sid not in exclude] + + if not candidates: + return None + + # Step 2: Filter by limits + available = [] + for stable_id in candidates: + state = states[stable_id] + result = self._limits.check_all(state, model, quota_group) + if result.allowed: + available.append(stable_id) + + if not available: + # Check if we should reset fair cycle + if self._config.fair_cycle.enabled: + reset_performed = self._try_fair_cycle_reset( + provider, + model, + quota_group, + states, + candidates, + priorities, + ) + if reset_performed: + # Retry selection after reset + return self.select( + provider, + model, + states, + quota_group, + exclude, + priorities, + deadline, + ) + + lib_logger.debug( + f"No available credentials for {provider}/{model} " + f"(all {len(candidates)} blocked by limits)" + ) + return None + + # Step 3: Build selection context + # Get usage counts for weighting + usage_counts = {} + for stable_id in available: + state = states[stable_id] + usage_counts[stable_id] = self._get_usage_count(state, model, quota_group) + + # Build priorities map + if priorities is None: + priorities = {} + for stable_id in available: + priorities[stable_id] = states[stable_id].priority + + context = SelectionContext( + provider=provider, + model=model, + quota_group=quota_group, + candidates=available, + priorities=priorities, + usage_counts=usage_counts, + rotation_mode=self._config.rotation_mode, + rotation_tolerance=self._config.rotation_tolerance, + deadline=deadline or (time.time() + 120), + ) + + # Step 4: Apply rotation strategy + selected = self._strategy.select(context, states) + + if selected: + lib_logger.debug( + f"Selected credential {selected} for {provider}/{model} " + f"(from {len(available)} available)" + ) + + return selected + + def select_with_retry( + self, + provider: str, + model: str, + states: Dict[str, CredentialState], + quota_group: Optional[str] = None, + tried: Optional[Set[str]] = None, + priorities: Optional[Dict[str, int]] = None, + deadline: float = 0.0, + ) -> Optional[str]: + """ + Select a credential for retry, excluding already-tried ones. + + Convenience method for retry loops. + + Args: + provider: Provider name + model: Model being requested + states: Dict of stable_id -> CredentialState + quota_group: Quota group for this model + tried: Set of already-tried stable_ids + priorities: Override priorities + deadline: Request deadline timestamp + + Returns: + Selected stable_id, or None if none available + """ + return self.select( + provider=provider, + model=model, + states=states, + quota_group=quota_group, + exclude=tried, + priorities=priorities, + deadline=deadline, + ) + + def get_availability_stats( + self, + provider: str, + model: str, + states: Dict[str, CredentialState], + quota_group: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Get availability statistics for credentials. + + Useful for status reporting and debugging. + + Args: + provider: Provider name + model: Model being requested + states: Dict of stable_id -> CredentialState + quota_group: Quota group for this model + + Returns: + Dict with availability stats + """ + total = len(states) + available = 0 + blocked_by = { + "cooldowns": 0, + "window_limits": 0, + "custom_caps": 0, + "fair_cycle": 0, + "concurrent": 0, + } + + for stable_id, state in states.items(): + blocking = self._limits.get_blocking_info(state, model, quota_group) + + is_available = True + for checker_name, result in blocking.items(): + if not result.allowed: + is_available = False + if checker_name in blocked_by: + blocked_by[checker_name] += 1 + break + + if is_available: + available += 1 + + return { + "total": total, + "available": available, + "blocked": total - available, + "blocked_by": blocked_by, + "rotation_mode": self._config.rotation_mode.value, + } + + def set_rotation_mode(self, mode: RotationMode) -> None: + """ + Change the rotation mode. + + Args: + mode: New rotation mode + """ + self._config.rotation_mode = mode + if mode == RotationMode.SEQUENTIAL: + self._strategy = self._sequential + else: + self._strategy = self._balanced + + lib_logger.info(f"Rotation mode changed to {mode.value}") + + def mark_exhausted(self, provider: str, model_or_group: str) -> None: + """ + Mark current credential as exhausted (for sequential mode). + + Args: + provider: Provider name + model_or_group: Model or quota group + """ + if isinstance(self._strategy, SequentialStrategy): + self._strategy.mark_exhausted(provider, model_or_group) + + @property + def balanced_strategy(self) -> BalancedStrategy: + """Get the balanced strategy instance.""" + return self._balanced + + @property + def sequential_strategy(self) -> SequentialStrategy: + """Get the sequential strategy instance.""" + return self._sequential + + def _get_usage_count( + self, + state: CredentialState, + model: str, + quota_group: Optional[str], + ) -> int: + """Get the relevant usage count for rotation weighting.""" + primary_def = self._windows.get_primary_definition() + if primary_def: + windows = None + + if primary_def.applies_to == "model": + model_stats = state.get_model_stats(model, create=False) + if model_stats: + windows = model_stats.windows + elif primary_def.applies_to == "group": + group_key = quota_group or model + group_stats = state.get_group_stats(group_key, create=False) + if group_stats: + windows = group_stats.windows + + if windows: + window = self._windows.get_active_window(windows, primary_def.name) + if window: + return window.request_count + + return state.totals.request_count + + def _try_fair_cycle_reset( + self, + provider: str, + model: str, + quota_group: Optional[str], + states: Dict[str, CredentialState], + candidates: List[str], + priorities: Optional[Dict[str, int]], + ) -> bool: + """ + Try to reset fair cycle if all credentials are exhausted. + + Tier-aware: If cross_tier is disabled, checks each tier separately. + + Args: + provider: Provider name + model: Model being requested + quota_group: Quota group for this model + states: All credential states + candidates: Candidate stable_ids + + Returns: + True if reset was performed, False otherwise + """ + from ..types import LimitResult + + group_key = quota_group or model + fair_cycle_checker = self._limits.fair_cycle_checker + tracking_key = fair_cycle_checker.get_tracking_key(model, quota_group) + + # Check if all candidates are blocked by fair cycle + all_fair_cycle_blocked = True + fair_cycle_blocked_count = 0 + + for stable_id in candidates: + state = states[stable_id] + result = self._limits.check_all(state, model, quota_group) + + if result.allowed: + # Some credential is available - no need to reset + return False + + if result.result == LimitResult.BLOCKED_FAIR_CYCLE: + fair_cycle_blocked_count += 1 + else: + # Blocked by something other than fair cycle + all_fair_cycle_blocked = False + + # If no credentials blocked by fair cycle, can't help + if fair_cycle_blocked_count == 0: + return False + + # Get all candidate states for reset + candidate_states = [states[sid] for sid in candidates] + priority_map = priorities or {sid: states[sid].priority for sid in candidates} + + # Tier-aware reset + if self._config.fair_cycle.cross_tier: + # Cross-tier: reset all at once + if fair_cycle_checker.check_all_exhausted( + provider, tracking_key, candidate_states, priorities=priority_map + ): + lib_logger.info( + f"All credentials fair-cycle exhausted for {provider}/{model} " + f"(cross-tier), resetting cycle" + ) + fair_cycle_checker.reset_cycle(provider, tracking_key, candidate_states) + return True + else: + # Per-tier: group by priority and check each tier + tier_groups: Dict[int, List[CredentialState]] = {} + for state in candidate_states: + priority = state.priority + tier_groups.setdefault(priority, []).append(state) + + reset_any = False + for priority, tier_states in tier_groups.items(): + # Check if all in this tier are exhausted + all_tier_exhausted = all( + state.is_fair_cycle_exhausted(tracking_key) for state in tier_states + ) + + if all_tier_exhausted: + lib_logger.info( + f"All credentials fair-cycle exhausted for {provider}/{model} " + f"in tier {priority}, resetting tier cycle" + ) + fair_cycle_checker.reset_cycle(provider, tracking_key, tier_states) + reset_any = True + + return reset_any + + return False diff --git a/src/rotator_library/usage/selection/modifiers/__init__.py b/src/rotator_library/usage/selection/modifiers/__init__.py new file mode 100644 index 00000000..f06468eb --- /dev/null +++ b/src/rotator_library/usage/selection/modifiers/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Selection modifiers.""" diff --git a/src/rotator_library/usage/selection/strategies/__init__.py b/src/rotator_library/usage/selection/strategies/__init__.py new file mode 100644 index 00000000..68eeba8c --- /dev/null +++ b/src/rotator_library/usage/selection/strategies/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Rotation strategy implementations.""" + +from .balanced import BalancedStrategy +from .sequential import SequentialStrategy + +__all__ = ["BalancedStrategy", "SequentialStrategy"] diff --git a/src/rotator_library/usage/selection/strategies/balanced.py b/src/rotator_library/usage/selection/strategies/balanced.py new file mode 100644 index 00000000..6d070117 --- /dev/null +++ b/src/rotator_library/usage/selection/strategies/balanced.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Balanced rotation strategy. + +Distributes load evenly across credentials using weighted random selection. +""" + +import random +import logging +from typing import Dict, List, Optional + +from ...types import CredentialState, SelectionContext, RotationMode + +lib_logger = logging.getLogger("rotator_library") + + +class BalancedStrategy: + """ + Balanced credential rotation strategy. + + Uses weighted random selection where less-used credentials have + higher probability of being selected. The tolerance parameter + controls how much randomness is introduced. + + Weight formula: weight = (max_usage - credential_usage) + tolerance + 1 + """ + + def __init__(self, tolerance: float = 3.0): + """ + Initialize balanced strategy. + + Args: + tolerance: Controls randomness of selection. + - 0.0: Deterministic, least-used always selected + - 2.0-4.0: Recommended, balanced randomness + - 5.0+: High randomness + """ + self.tolerance = tolerance + + @property + def name(self) -> str: + return "balanced" + + @property + def mode(self) -> RotationMode: + return RotationMode.BALANCED + + def select( + self, + context: SelectionContext, + states: Dict[str, CredentialState], + ) -> Optional[str]: + """ + Select a credential using weighted random selection. + + Args: + context: Selection context with candidates and usage info + states: Dict of stable_id -> CredentialState + + Returns: + Selected stable_id, or None if no candidates + """ + if not context.candidates: + return None + + if len(context.candidates) == 1: + return context.candidates[0] + + # Group by priority for tiered selection + priority_groups = self._group_by_priority( + context.candidates, context.priorities + ) + + # Try each priority tier in order + for priority in sorted(priority_groups.keys()): + candidates = priority_groups[priority] + if not candidates: + continue + + # Calculate weights for this tier + weights = self._calculate_weights(candidates, context.usage_counts) + + # Weighted random selection + selected = self._weighted_random_choice(candidates, weights) + if selected: + return selected + + # Fallback: first candidate + return context.candidates[0] + + def _group_by_priority( + self, + candidates: List[str], + priorities: Dict[str, int], + ) -> Dict[int, List[str]]: + """Group candidates by priority tier.""" + groups: Dict[int, List[str]] = {} + for stable_id in candidates: + priority = priorities.get(stable_id, 999) + groups.setdefault(priority, []).append(stable_id) + return groups + + def _calculate_weights( + self, + candidates: List[str], + usage_counts: Dict[str, int], + ) -> List[float]: + """ + Calculate selection weights for candidates. + + Weight formula: weight = (max_usage - credential_usage) + tolerance + 1 + """ + if not candidates: + return [] + + # Get usage counts + usages = [usage_counts.get(stable_id, 0) for stable_id in candidates] + max_usage = max(usages) if usages else 0 + + # Calculate weights + weights = [] + for usage in usages: + weight = (max_usage - usage) + self.tolerance + 1 + weights.append(max(weight, 0.1)) # Ensure minimum weight + + return weights + + def _weighted_random_choice( + self, + candidates: List[str], + weights: List[float], + ) -> Optional[str]: + """Select a candidate using weighted random choice.""" + if not candidates: + return None + + if len(candidates) == 1: + return candidates[0] + + # Normalize weights + total = sum(weights) + if total <= 0: + return random.choice(candidates) + + # Weighted selection + r = random.uniform(0, total) + cumulative = 0 + for candidate, weight in zip(candidates, weights): + cumulative += weight + if r <= cumulative: + return candidate + + return candidates[-1] diff --git a/src/rotator_library/usage/selection/strategies/sequential.py b/src/rotator_library/usage/selection/strategies/sequential.py new file mode 100644 index 00000000..e00029b4 --- /dev/null +++ b/src/rotator_library/usage/selection/strategies/sequential.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Sequential rotation strategy. + +Uses one credential until exhausted, then moves to the next. +Good for providers that benefit from request caching. +""" + +import logging +from typing import Dict, List, Optional + +from ...types import CredentialState, SelectionContext, RotationMode + +lib_logger = logging.getLogger("rotator_library") + + +class SequentialStrategy: + """ + Sequential credential rotation strategy. + + Sticks to one credential until it's exhausted (rate limited, + quota exceeded, etc.), then moves to the next in priority order. + + This is useful for providers where repeated requests to the same + credential benefit from caching (e.g., context caching in LLMs). + """ + + def __init__(self, fallback_multiplier: int = 1): + """ + Initialize sequential strategy. + + Args: + fallback_multiplier: Default concurrent slots per priority + when not explicitly configured + """ + self.fallback_multiplier = fallback_multiplier + # Track current "sticky" credential per (provider, model_group) + self._current: Dict[tuple, str] = {} + + @property + def name(self) -> str: + return "sequential" + + @property + def mode(self) -> RotationMode: + return RotationMode.SEQUENTIAL + + def select( + self, + context: SelectionContext, + states: Dict[str, CredentialState], + ) -> Optional[str]: + """ + Select a credential using sequential/sticky selection. + + Prefers the currently active credential if it's still available. + Otherwise, selects the first available by priority. + + Args: + context: Selection context with candidates and usage info + states: Dict of stable_id -> CredentialState + + Returns: + Selected stable_id, or None if no candidates + """ + if not context.candidates: + return None + + if len(context.candidates) == 1: + return context.candidates[0] + + key = (context.provider, context.quota_group or context.model) + + # Check if current sticky credential is still available + current = self._current.get(key) + if current and current in context.candidates: + return current + + # Current not available - select new one by priority + selected = self._select_by_priority(context.candidates, context.priorities) + + # Make it sticky + if selected: + self._current[key] = selected + lib_logger.debug(f"Sequential: switched to credential {selected} for {key}") + + return selected + + def mark_exhausted(self, provider: str, model_or_group: str) -> None: + """ + Mark current credential as exhausted, forcing rotation. + + Args: + provider: Provider name + model_or_group: Model or quota group + """ + key = (provider, model_or_group) + if key in self._current: + old = self._current[key] + del self._current[key] + lib_logger.debug(f"Sequential: marked {old} exhausted for {key}") + + def get_current(self, provider: str, model_or_group: str) -> Optional[str]: + """ + Get the currently sticky credential. + + Args: + provider: Provider name + model_or_group: Model or quota group + + Returns: + Current sticky credential stable_id, or None + """ + key = (provider, model_or_group) + return self._current.get(key) + + def _select_by_priority( + self, + candidates: List[str], + priorities: Dict[str, int], + ) -> Optional[str]: + """Select the highest priority (lowest number) candidate.""" + if not candidates: + return None + + # Sort by priority (lower = higher priority) + sorted_candidates = sorted(candidates, key=lambda c: priorities.get(c, 999)) + + return sorted_candidates[0] + + def clear_sticky(self, provider: Optional[str] = None) -> None: + """ + Clear sticky credential state. + + Args: + provider: If specified, only clear for this provider + """ + if provider: + keys_to_remove = [k for k in self._current if k[0] == provider] + for key in keys_to_remove: + del self._current[key] + else: + self._current.clear() diff --git a/src/rotator_library/usage/tracking/__init__.py b/src/rotator_library/usage/tracking/__init__.py new file mode 100644 index 00000000..e459d28f --- /dev/null +++ b/src/rotator_library/usage/tracking/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +"""Usage tracking and window management.""" + +from .engine import TrackingEngine +from .windows import WindowManager + +__all__ = ["TrackingEngine", "WindowManager"] diff --git a/src/rotator_library/usage/tracking/engine.py b/src/rotator_library/usage/tracking/engine.py new file mode 100644 index 00000000..7cb23d22 --- /dev/null +++ b/src/rotator_library/usage/tracking/engine.py @@ -0,0 +1,659 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Tracking engine for usage recording. + +Central component for recording requests, successes, and failures. +""" + +import asyncio +import logging +import time +from typing import Any, Dict, List, Optional, Set + +from ..types import ( + WindowStats, + TotalStats, + ModelStats, + GroupStats, + CredentialState, + CooldownInfo, + FairCycleState, + TrackingMode, + UsageUpdate, + FAIR_CYCLE_GLOBAL_KEY, +) +from ..config import WindowDefinition, ProviderUsageConfig +from .windows import WindowManager + +lib_logger = logging.getLogger("rotator_library") + + +class TrackingEngine: + """ + Central engine for usage tracking. + + Responsibilities: + - Recording request successes and failures + - Managing usage windows + - Updating global statistics + - Managing cooldowns + - Tracking fair cycle state + """ + + def __init__( + self, + window_manager: WindowManager, + config: ProviderUsageConfig, + ): + """ + Initialize tracking engine. + + Args: + window_manager: WindowManager instance for window operations + config: Provider usage configuration + """ + self._windows = window_manager + self._config = config + self._lock = asyncio.Lock() + + async def record_usage( + self, + state: CredentialState, + model: str, + update: UsageUpdate, + group: Optional[str] = None, + response_headers: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Record usage for a request (consolidated function). + + Updates: + - model_usage[model].windows[*] + totals + - group_usage[group].windows[*] + totals (if group provided) + - credential.totals + + Args: + state: Credential state to update + model: Model that was used + update: UsageUpdate with all metrics + group: Quota group for this model (None = no group tracking) + response_headers: Optional response headers with rate limit info + """ + async with self._lock: + now = time.time() + fair_cycle_key = self._resolve_fair_cycle_key(group or model) + + # Calculate derived values + output_tokens = update.completion_tokens + update.thinking_tokens + total_tokens = ( + update.prompt_tokens + + update.completion_tokens + + update.thinking_tokens + + update.prompt_tokens_cache_read + + update.prompt_tokens_cache_write + ) + + # 1. Update model stats + model_stats = state.get_model_stats(model) + self._apply_to_windows( + model_stats.windows, update, now, total_tokens, output_tokens + ) + self._apply_to_totals( + model_stats.totals, update, now, total_tokens, output_tokens + ) + + # 2. Update group stats (if applicable) + if group: + group_stats = state.get_group_stats(group) + self._apply_to_windows( + group_stats.windows, update, now, total_tokens, output_tokens + ) + self._apply_to_totals( + group_stats.totals, update, now, total_tokens, output_tokens + ) + + # 3. Update credential totals + self._apply_to_totals( + state.totals, update, now, total_tokens, output_tokens + ) + + # 4. Update fair cycle request count + if self._config.fair_cycle.enabled: + fc_state = state.fair_cycle.get(fair_cycle_key) + if not fc_state: + fc_state = FairCycleState(model_or_group=fair_cycle_key) + state.fair_cycle[fair_cycle_key] = fc_state + fc_state.cycle_request_count += update.request_count + + # 5. Update from response headers if provided + if response_headers: + self._update_from_headers(state, response_headers, model, group) + + state.last_updated = now + + async def record_success( + self, + state: CredentialState, + model: str, + quota_group: Optional[str] = None, + prompt_tokens: int = 0, + completion_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + thinking_tokens: int = 0, + approx_cost: float = 0.0, + request_count: int = 1, + response_headers: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Record a successful request. + + Args: + state: Credential state to update + model: Model that was used + quota_group: Quota group for this model (None = use model name) + prompt_tokens: Prompt tokens used + completion_tokens: Completion tokens used + prompt_tokens_cache_read: Cached prompt tokens read + prompt_tokens_cache_write: Cached prompt tokens written + thinking_tokens: Thinking tokens used + approx_cost: Approximate cost + request_count: Number of requests to record + response_headers: Optional response headers with rate limit info + """ + update = UsageUpdate( + request_count=request_count, + success=True, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + await self.record_usage( + state=state, + model=model, + update=update, + group=quota_group, + response_headers=response_headers, + ) + + async def record_failure( + self, + state: CredentialState, + model: str, + error_type: str, + quota_group: Optional[str] = None, + cooldown_duration: Optional[float] = None, + quota_reset_timestamp: Optional[float] = None, + mark_exhausted: bool = False, + request_count: int = 1, + prompt_tokens: int = 0, + completion_tokens: int = 0, + thinking_tokens: int = 0, + prompt_tokens_cache_read: int = 0, + prompt_tokens_cache_write: int = 0, + approx_cost: float = 0.0, + ) -> None: + """ + Record a failed request. + + Args: + state: Credential state to update + model: Model that was used + error_type: Type of error (quota_exceeded, rate_limit, etc.) + quota_group: Quota group for this model + cooldown_duration: How long to cool down (if applicable) + quota_reset_timestamp: When quota resets (from API) + mark_exhausted: Whether to mark as exhausted for fair cycle + request_count: Number of requests to record + prompt_tokens: Prompt tokens used + completion_tokens: Completion tokens used + thinking_tokens: Thinking tokens used + prompt_tokens_cache_read: Cached prompt tokens read + prompt_tokens_cache_write: Cached prompt tokens written + approx_cost: Approximate cost + """ + update = UsageUpdate( + request_count=request_count, + success=False, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + thinking_tokens=thinking_tokens, + prompt_tokens_cache_read=prompt_tokens_cache_read, + prompt_tokens_cache_write=prompt_tokens_cache_write, + approx_cost=approx_cost, + ) + + # Record the usage + await self.record_usage( + state=state, + model=model, + update=update, + group=quota_group, + ) + + async with self._lock: + group_key = quota_group or model + fair_cycle_key = self._resolve_fair_cycle_key(group_key) + + # Apply cooldown if specified + if cooldown_duration is not None and cooldown_duration > 0: + self._apply_cooldown( + state=state, + reason=error_type, + duration=cooldown_duration, + model_or_group=group_key, + source="error", + ) + + # Use quota reset timestamp if provided + if quota_reset_timestamp is not None: + self._apply_cooldown( + state=state, + reason=error_type, + until=quota_reset_timestamp, + model_or_group=group_key, + source="api_quota", + ) + + # Mark exhausted for fair cycle if requested + if mark_exhausted: + self._mark_exhausted(state, fair_cycle_key, error_type) + + async def acquire( + self, + state: CredentialState, + model: str, + ) -> bool: + """ + Acquire a credential for a request (increment active count). + + Args: + state: Credential state + model: Model being used + + Returns: + True if acquired, False if at max concurrent + """ + async with self._lock: + # Check concurrent limit + if state.max_concurrent is not None: + if state.active_requests >= state.max_concurrent: + return False + + state.active_requests += 1 + return True + + async def apply_cooldown( + self, + state: CredentialState, + reason: str, + duration: Optional[float] = None, + until: Optional[float] = None, + model_or_group: Optional[str] = None, + source: str = "system", + ) -> None: + """ + Apply a cooldown to a credential. + + Args: + state: Credential state + reason: Why the cooldown was applied + duration: Cooldown duration in seconds (if not using 'until') + until: Timestamp when cooldown ends (if not using 'duration') + model_or_group: Scope of cooldown (None = credential-wide) + source: Source of cooldown (system, custom_cap, rate_limit, etc.) + """ + async with self._lock: + self._apply_cooldown( + state=state, + reason=reason, + duration=duration, + until=until, + model_or_group=model_or_group, + source=source, + ) + + async def clear_cooldown( + self, + state: CredentialState, + model_or_group: Optional[str] = None, + ) -> None: + """ + Clear a cooldown from a credential. + + Args: + state: Credential state + model_or_group: Scope of cooldown to clear (None = global) + """ + async with self._lock: + key = model_or_group or "_global_" + if key in state.cooldowns: + del state.cooldowns[key] + + async def mark_exhausted( + self, + state: CredentialState, + model_or_group: str, + reason: str, + ) -> None: + """ + Mark a credential as exhausted for fair cycle. + + Args: + state: Credential state + model_or_group: Scope of exhaustion + reason: Why credential was exhausted + """ + async with self._lock: + self._mark_exhausted(state, model_or_group, reason) + + async def reset_fair_cycle( + self, + state: CredentialState, + model_or_group: str, + ) -> None: + """ + Reset fair cycle state for a credential. + + Args: + state: Credential state + model_or_group: Scope to reset + """ + async with self._lock: + if model_or_group in state.fair_cycle: + fc_state = state.fair_cycle[model_or_group] + fc_state.exhausted = False + fc_state.exhausted_at = None + fc_state.exhausted_reason = None + fc_state.cycle_request_count = 0 + + def get_window_usage( + self, + state: CredentialState, + window_name: str, + model: Optional[str] = None, + group: Optional[str] = None, + ) -> int: + """ + Get request count for a specific window. + + Args: + state: Credential state + window_name: Name of window + model: Model to check (optional) + group: Group to check (optional) + + Returns: + Request count (0 if window doesn't exist) + """ + # Check group first if provided + if group: + group_stats = state.group_usage.get(group) + if group_stats: + window = self._windows.get_active_window( + group_stats.windows, window_name + ) + if window: + return window.request_count + + # Check model if provided + if model: + model_stats = state.model_usage.get(model) + if model_stats: + window = self._windows.get_active_window( + model_stats.windows, window_name + ) + if window: + return window.request_count + + return 0 + + def get_primary_window_usage( + self, + state: CredentialState, + model: Optional[str] = None, + group: Optional[str] = None, + ) -> int: + """ + Get request count for the primary window. + + Args: + state: Credential state + model: Model to check (optional) + group: Group to check (optional) + + Returns: + Request count (0 if no primary window) + """ + primary_def = self._windows.get_primary_definition() + if primary_def is None: + return 0 + return self.get_window_usage(state, primary_def.name, model, group) + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _apply_to_windows( + self, + windows: Dict[str, WindowStats], + update: UsageUpdate, + now: float, + total_tokens: int, + output_tokens: int, + ) -> None: + """Apply update to all configured windows.""" + for window_def in self._config.windows: + window = self._windows.get_or_create_window(windows, window_def.name) + self._apply_to_window(window, update, now, total_tokens, output_tokens) + + def _apply_to_window( + self, + window: WindowStats, + update: UsageUpdate, + now: float, + total_tokens: int, + output_tokens: int, + ) -> None: + """Apply update to a single window.""" + window.request_count += update.request_count + if update.success: + window.success_count += update.request_count + else: + window.failure_count += update.request_count + + window.prompt_tokens += update.prompt_tokens + window.completion_tokens += update.completion_tokens + window.thinking_tokens += update.thinking_tokens + window.output_tokens += output_tokens + window.prompt_tokens_cache_read += update.prompt_tokens_cache_read + window.prompt_tokens_cache_write += update.prompt_tokens_cache_write + window.total_tokens += total_tokens + window.approx_cost += update.approx_cost + + window.last_used_at = now + if window.first_used_at is None: + window.first_used_at = now + + # Set started_at on first usage and calculate reset_at + if window.started_at is None: + window.started_at = now + # Calculate reset_at based on window definition + window_def = self._windows.definitions.get(window.name) + if window_def and window.reset_at is None: + window.reset_at = self._windows._calculate_reset_time(window_def, now) + + # Update max recorded requests (historical high-water mark) + if ( + window.max_recorded_requests is None + or window.request_count > window.max_recorded_requests + ): + window.max_recorded_requests = window.request_count + window.max_recorded_at = now + + def _apply_to_totals( + self, + totals: TotalStats, + update: UsageUpdate, + now: float, + total_tokens: int, + output_tokens: int, + ) -> None: + """Apply update to totals.""" + totals.request_count += update.request_count + if update.success: + totals.success_count += update.request_count + else: + totals.failure_count += update.request_count + + totals.prompt_tokens += update.prompt_tokens + totals.completion_tokens += update.completion_tokens + totals.thinking_tokens += update.thinking_tokens + totals.output_tokens += output_tokens + totals.prompt_tokens_cache_read += update.prompt_tokens_cache_read + totals.prompt_tokens_cache_write += update.prompt_tokens_cache_write + totals.total_tokens += total_tokens + totals.approx_cost += update.approx_cost + + totals.last_used_at = now + if totals.first_used_at is None: + totals.first_used_at = now + + def _apply_cooldown( + self, + state: CredentialState, + reason: str, + duration: Optional[float] = None, + until: Optional[float] = None, + model_or_group: Optional[str] = None, + source: str = "system", + ) -> None: + """Internal cooldown application (no lock).""" + now = time.time() + + if until is not None: + cooldown_until = until + elif duration is not None: + cooldown_until = now + duration + else: + return # No cooldown specified + + key = model_or_group or "_global_" + + # Check for existing cooldown + existing = state.cooldowns.get(key) + backoff_count = 0 + if existing and existing.is_active: + backoff_count = existing.backoff_count + 1 + + state.cooldowns[key] = CooldownInfo( + reason=reason, + until=cooldown_until, + started_at=now, + source=source, + model_or_group=model_or_group, + backoff_count=backoff_count, + ) + + # Check if cooldown qualifies as exhaustion + cooldown_duration = cooldown_until - now + if cooldown_duration >= self._config.exhaustion_cooldown_threshold: + if self._config.fair_cycle.enabled and model_or_group: + fair_cycle_key = self._resolve_fair_cycle_key(model_or_group) + self._mark_exhausted(state, fair_cycle_key, f"cooldown_{reason}") + + def _mark_exhausted( + self, + state: CredentialState, + model_or_group: str, + reason: str, + ) -> None: + """Internal exhaustion marking (no lock).""" + now = time.time() + + if model_or_group not in state.fair_cycle: + state.fair_cycle[model_or_group] = FairCycleState( + model_or_group=model_or_group + ) + + fc_state = state.fair_cycle[model_or_group] + + # Idempotency check: skip if already exhausted (avoid duplicate logging) + if fc_state.exhausted: + return + + fc_state.exhausted = True + fc_state.exhausted_at = now + fc_state.exhausted_reason = reason + + lib_logger.debug( + f"Credential {state.stable_id} marked exhausted for {model_or_group}: {reason}" + ) + + def _resolve_fair_cycle_key(self, group_key: str) -> str: + """Resolve fair cycle tracking key based on config.""" + if self._config.fair_cycle.tracking_mode == TrackingMode.CREDENTIAL: + return FAIR_CYCLE_GLOBAL_KEY + return group_key + + def _update_from_headers( + self, + state: CredentialState, + headers: Dict[str, Any], + model: str, + group: Optional[str], + ) -> None: + """Update state from API response headers.""" + # Common header patterns for rate limiting + # X-RateLimit-Remaining, X-RateLimit-Reset, etc. + remaining = headers.get("x-ratelimit-remaining") + reset = headers.get("x-ratelimit-reset") + limit = headers.get("x-ratelimit-limit") + + primary_def = self._windows.get_primary_definition() + if primary_def is None: + return + + # Update group windows if group is provided + if group: + group_stats = state.get_group_stats(group, create=False) + if group_stats: + window = group_stats.windows.get(primary_def.name) + if window: + self._apply_header_updates(window, limit, reset) + + # Update model windows + model_stats = state.get_model_stats(model, create=False) + if model_stats: + window = model_stats.windows.get(primary_def.name) + if window: + self._apply_header_updates(window, limit, reset) + + def _apply_header_updates( + self, + window: WindowStats, + limit: Optional[Any], + reset: Optional[Any], + ) -> None: + """Apply header updates to a window.""" + if limit is not None: + try: + window.limit = int(limit) + except (ValueError, TypeError): + pass + + if reset is not None: + try: + reset_float = float(reset) + # If reset is in the past, it might be a Unix timestamp + # If it's a small number, it might be seconds until reset + if reset_float < 1000000000: # Less than ~2001, probably relative + reset_float = time.time() + reset_float + window.reset_at = reset_float + except (ValueError, TypeError): + pass diff --git a/src/rotator_library/usage/tracking/windows.py b/src/rotator_library/usage/tracking/windows.py new file mode 100644 index 00000000..88852506 --- /dev/null +++ b/src/rotator_library/usage/tracking/windows.py @@ -0,0 +1,423 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Window management for usage tracking. + +Handles time-based usage windows with various reset modes. +""" + +import time +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone, time as dt_time +from typing import Any, Dict, List, Optional, Tuple + +from ..types import WindowStats, ResetMode +from ..config import WindowDefinition + +lib_logger = logging.getLogger("rotator_library") + + +class WindowManager: + """ + Manages usage tracking windows for credentials. + + Handles: + - Rolling windows (e.g., last 5 hours) + - Fixed daily windows (reset at specific UTC time) + - Calendar windows (weekly, monthly) + - API-authoritative windows (provider determines reset) + """ + + def __init__( + self, + window_definitions: List[WindowDefinition], + daily_reset_time_utc: str = "03:00", + ): + """ + Initialize window manager. + + Args: + window_definitions: List of window configurations + daily_reset_time_utc: Time for daily reset in HH:MM format + """ + self.definitions = {w.name: w for w in window_definitions} + self.daily_reset_time_utc = self._parse_time(daily_reset_time_utc) + + def get_active_window( + self, + windows: Dict[str, WindowStats], + window_name: str, + ) -> Optional[WindowStats]: + """ + Get an active (non-expired) window by name. + + Args: + windows: Current windows dict for a credential + window_name: Name of window to get + + Returns: + WindowStats if active, None if expired or doesn't exist + """ + window = windows.get(window_name) + if window is None: + return None + + definition = self.definitions.get(window_name) + if definition is None: + return window # Unknown window, return as-is + + # Check if window needs reset + if self._should_reset(window, definition): + return None + + return window + + def get_or_create_window( + self, + windows: Dict[str, WindowStats], + window_name: str, + limit: Optional[int] = None, + ) -> WindowStats: + """ + Get an active window or create a new one. + + Args: + windows: Current windows dict for a credential + window_name: Name of window to get/create + limit: Optional request limit for the window + + Returns: + Active WindowStats (may be newly created) + """ + window = self.get_active_window(windows, window_name) + if window is not None: + return window + + # Preserve max_recorded_requests from expired window (if exists) + old_max = None + old_max_at = None + old_window = windows.get(window_name) + if old_window is not None: + # Take max of the old window's recorded max and its final request count + old_recorded_max = old_window.max_recorded_requests or 0 + if old_window.request_count > old_recorded_max: + old_max = old_window.request_count + old_max_at = old_window.last_used_at or time.time() + elif old_recorded_max > 0: + old_max = old_recorded_max + old_max_at = old_window.max_recorded_at + + # Create new window + # Note: started_at and reset_at are left as None until first actual usage + # This prevents bogus reset times from being displayed for unused windows + new_window = WindowStats( + name=window_name, + started_at=None, + reset_at=None, + limit=limit, + max_recorded_requests=old_max, # Carry forward historical max + max_recorded_at=old_max_at, + ) + + windows[window_name] = new_window + return new_window + + def get_primary_window( + self, + windows: Dict[str, WindowStats], + ) -> Optional[WindowStats]: + """ + Get the primary window used for rotation decisions. + + Args: + windows: Current windows dict for a credential + + Returns: + Primary WindowStats or None + """ + for name, definition in self.definitions.items(): + if definition.is_primary: + return self.get_active_window(windows, name) + return None + + def get_primary_definition(self) -> Optional[WindowDefinition]: + """Get the primary window definition.""" + for definition in self.definitions.values(): + if definition.is_primary: + return definition + return None + + def get_window_remaining( + self, + windows: Dict[str, WindowStats], + window_name: str, + ) -> Optional[int]: + """ + Get remaining requests in a window. + + Args: + windows: Current windows dict for a credential + window_name: Name of window to check + + Returns: + Remaining requests, or None if unlimited/unknown + """ + window = self.get_active_window(windows, window_name) + if window is None: + return None + return window.remaining + + def check_expired_windows( + self, + windows: Dict[str, WindowStats], + ) -> List[str]: + """ + Check for and remove expired windows. + + Args: + windows: Current windows dict for a credential + + Returns: + List of removed window names + """ + expired = [] + for name in list(windows.keys()): + if self.get_active_window(windows, name) is None: + del windows[name] + expired.append(name) + return expired + + def update_limit( + self, + windows: Dict[str, WindowStats], + window_name: str, + new_limit: int, + ) -> None: + """ + Update the limit for a window (e.g., from API response). + + Args: + windows: Current windows dict for a credential + window_name: Name of window to update + new_limit: New request limit + """ + window = windows.get(window_name) + if window is not None: + window.limit = new_limit + + def update_reset_time( + self, + windows: Dict[str, WindowStats], + window_name: str, + reset_timestamp: float, + ) -> None: + """ + Update the reset time for a window (e.g., from API response). + + Args: + windows: Current windows dict for a credential + window_name: Name of window to update + reset_timestamp: New reset timestamp + """ + window = windows.get(window_name) + if window is not None: + window.reset_at = reset_timestamp + + # ========================================================================= + # PRIVATE METHODS + # ========================================================================= + + def _should_reset(self, window: WindowStats, definition: WindowDefinition) -> bool: + """ + Check if a window should be reset based on its definition. + """ + now = time.time() + + # If window has an explicit reset time, use it + if window.reset_at is not None: + return now >= window.reset_at + + # If window has no start time, it hasn't been used yet - no need to reset + if window.started_at is None: + return False + + # Check based on reset mode + if definition.reset_mode == ResetMode.ROLLING: + if definition.duration_seconds is None: + return False # Infinite window + return now >= window.started_at + definition.duration_seconds + + elif definition.reset_mode == ResetMode.FIXED_DAILY: + return self._past_daily_reset(window.started_at, now) + + elif definition.reset_mode == ResetMode.CALENDAR_WEEKLY: + return self._past_weekly_reset(window.started_at, now) + + elif definition.reset_mode == ResetMode.CALENDAR_MONTHLY: + return self._past_monthly_reset(window.started_at, now) + + elif definition.reset_mode == ResetMode.API_AUTHORITATIVE: + # Only reset if explicit reset_at is set and passed + return False + + return False + + def _calculate_reset_time( + self, + definition: WindowDefinition, + start_time: float, + ) -> Optional[float]: + """ + Calculate when a window should reset based on its definition. + """ + if definition.reset_mode == ResetMode.ROLLING: + if definition.duration_seconds is None: + return None # Infinite window + return start_time + definition.duration_seconds + + elif definition.reset_mode == ResetMode.FIXED_DAILY: + return self._next_daily_reset(start_time) + + elif definition.reset_mode == ResetMode.CALENDAR_WEEKLY: + return self._next_weekly_reset(start_time) + + elif definition.reset_mode == ResetMode.CALENDAR_MONTHLY: + return self._next_monthly_reset(start_time) + + elif definition.reset_mode == ResetMode.API_AUTHORITATIVE: + return None # Will be set by API response + + return None + + def _parse_time(self, time_str: str) -> dt_time: + """Parse HH:MM time string.""" + try: + parts = time_str.split(":") + return dt_time(hour=int(parts[0]), minute=int(parts[1])) + except (ValueError, IndexError): + return dt_time(hour=3, minute=0) # Default 03:00 + + def _past_daily_reset(self, started_at: float, now: float) -> bool: + """Check if we've passed the daily reset time since window started.""" + start_dt = datetime.fromtimestamp(started_at, tz=timezone.utc) + now_dt = datetime.fromtimestamp(now, tz=timezone.utc) + + # Get reset time for the day after start + reset_dt = start_dt.replace( + hour=self.daily_reset_time_utc.hour, + minute=self.daily_reset_time_utc.minute, + second=0, + microsecond=0, + ) + if reset_dt <= start_dt: + # Reset time already passed today, use tomorrow + from datetime import timedelta + + reset_dt += timedelta(days=1) + + return now_dt >= reset_dt + + def _next_daily_reset(self, from_time: float) -> float: + """Calculate next daily reset timestamp.""" + from datetime import timedelta + + from_dt = datetime.fromtimestamp(from_time, tz=timezone.utc) + reset_dt = from_dt.replace( + hour=self.daily_reset_time_utc.hour, + minute=self.daily_reset_time_utc.minute, + second=0, + microsecond=0, + ) + if reset_dt <= from_dt: + reset_dt += timedelta(days=1) + + return reset_dt.timestamp() + + def _past_weekly_reset(self, started_at: float, now: float) -> bool: + """Check if we've passed the weekly reset (Sunday 03:00 UTC).""" + start_dt = datetime.fromtimestamp(started_at, tz=timezone.utc) + now_dt = datetime.fromtimestamp(now, tz=timezone.utc) + + # Get start of next week (Sunday 03:00 UTC) + days_until_sunday = (6 - start_dt.weekday()) % 7 + if days_until_sunday == 0 and start_dt.hour >= 3: + days_until_sunday = 7 + + from datetime import timedelta + + reset_dt = start_dt.replace( + hour=3, minute=0, second=0, microsecond=0 + ) + timedelta(days=days_until_sunday) + + return now_dt >= reset_dt + + def _next_weekly_reset(self, from_time: float) -> float: + """Calculate next weekly reset timestamp.""" + from datetime import timedelta + + from_dt = datetime.fromtimestamp(from_time, tz=timezone.utc) + days_until_sunday = (6 - from_dt.weekday()) % 7 + if days_until_sunday == 0 and from_dt.hour >= 3: + days_until_sunday = 7 + + reset_dt = from_dt.replace( + hour=3, minute=0, second=0, microsecond=0 + ) + timedelta(days=days_until_sunday) + + return reset_dt.timestamp() + + def _past_monthly_reset(self, started_at: float, now: float) -> bool: + """Check if we've passed the monthly reset (1st 03:00 UTC).""" + start_dt = datetime.fromtimestamp(started_at, tz=timezone.utc) + now_dt = datetime.fromtimestamp(now, tz=timezone.utc) + + # Get 1st of next month + if start_dt.month == 12: + reset_dt = start_dt.replace( + year=start_dt.year + 1, + month=1, + day=1, + hour=3, + minute=0, + second=0, + microsecond=0, + ) + else: + reset_dt = start_dt.replace( + month=start_dt.month + 1, + day=1, + hour=3, + minute=0, + second=0, + microsecond=0, + ) + + return now_dt >= reset_dt + + def _next_monthly_reset(self, from_time: float) -> float: + """Calculate next monthly reset timestamp.""" + from_dt = datetime.fromtimestamp(from_time, tz=timezone.utc) + + if from_dt.month == 12: + reset_dt = from_dt.replace( + year=from_dt.year + 1, + month=1, + day=1, + hour=3, + minute=0, + second=0, + microsecond=0, + ) + else: + reset_dt = from_dt.replace( + month=from_dt.month + 1, + day=1, + hour=3, + minute=0, + second=0, + microsecond=0, + ) + + return reset_dt.timestamp() diff --git a/src/rotator_library/usage/types.py b/src/rotator_library/usage/types.py new file mode 100644 index 00000000..3a4bd08a --- /dev/null +++ b/src/rotator_library/usage/types.py @@ -0,0 +1,458 @@ +# SPDX-License-Identifier: LGPL-3.0-only +# Copyright (c) 2026 Mirrowel + +""" +Type definitions for the usage tracking package. + +This module contains dataclasses and type definitions specific to +usage tracking, limits, and credential selection. +""" + +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union + + +# ============================================================================= +# ENUMS +# ============================================================================= + + +FAIR_CYCLE_GLOBAL_KEY = "_credential_" + + +class ResetMode(str, Enum): + """How a usage window resets.""" + + ROLLING = "rolling" # Continuous rolling window + FIXED_DAILY = "fixed_daily" # Reset at specific time each day + CALENDAR_WEEKLY = "calendar_weekly" # Reset at start of week + CALENDAR_MONTHLY = "calendar_monthly" # Reset at start of month + API_AUTHORITATIVE = "api_authoritative" # Provider API determines reset + + +class LimitResult(str, Enum): + """Result of a limit check.""" + + ALLOWED = "allowed" + BLOCKED_WINDOW = "blocked_window" + BLOCKED_COOLDOWN = "blocked_cooldown" + BLOCKED_FAIR_CYCLE = "blocked_fair_cycle" + BLOCKED_CUSTOM_CAP = "blocked_custom_cap" + BLOCKED_CONCURRENT = "blocked_concurrent" + + +class RotationMode(str, Enum): + """How credentials are rotated.""" + + BALANCED = "balanced" # Weighted random selection + SEQUENTIAL = "sequential" # Sticky until exhausted + + +class TrackingMode(str, Enum): + """How fair cycle tracks exhaustion.""" + + MODEL_GROUP = "model_group" # Track per quota group or model + CREDENTIAL = "credential" # Track per credential globally + + +class CooldownMode(str, Enum): + """How custom cap cooldowns are calculated.""" + + QUOTA_RESET = "quota_reset" # Wait until quota window resets + OFFSET = "offset" # Add offset seconds to current time + FIXED = "fixed" # Use fixed duration + + +# ============================================================================= +# WINDOW STATS +# ============================================================================= + + +@dataclass +class WindowStats: + """ + Statistics for a single time-based usage window (e.g., 5h, daily). + + Tracks usage within a specific time window for quota management. + """ + + name: str # Window identifier (e.g., "5h", "daily") + request_count: int = 0 + success_count: int = 0 + failure_count: int = 0 + + # Token stats + prompt_tokens: int = 0 + completion_tokens: int = 0 + thinking_tokens: int = 0 + output_tokens: int = 0 # completion + thinking + prompt_tokens_cache_read: int = 0 + prompt_tokens_cache_write: int = 0 + total_tokens: int = 0 + + approx_cost: float = 0.0 + + # Window timing + started_at: Optional[float] = None # When window period started + reset_at: Optional[float] = None # When window resets + limit: Optional[int] = None # Max requests allowed (None = unlimited) + + # Historical max tracking (persists across window resets) + max_recorded_requests: Optional[int] = ( + None # Highest request_count ever in any window period + ) + max_recorded_at: Optional[float] = None # When the max was recorded + + # Usage timing (for smart selection) + first_used_at: Optional[float] = None # First request in this window + last_used_at: Optional[float] = None # Last request in this window + + @property + def remaining(self) -> Optional[int]: + """Remaining requests in this window, or None if unlimited.""" + if self.limit is None: + return None + return max(0, self.limit - self.request_count) + + @property + def is_exhausted(self) -> bool: + """True if limit reached.""" + if self.limit is None: + return False + return self.request_count >= self.limit + + +# ============================================================================= +# TOTAL STATS +# ============================================================================= + + +@dataclass +class TotalStats: + """ + All-time totals for a model, group, or credential. + + Tracks cumulative usage across all time (never resets). + """ + + request_count: int = 0 + success_count: int = 0 + failure_count: int = 0 + + # Token stats + prompt_tokens: int = 0 + completion_tokens: int = 0 + thinking_tokens: int = 0 + output_tokens: int = 0 # completion + thinking + prompt_tokens_cache_read: int = 0 + prompt_tokens_cache_write: int = 0 + total_tokens: int = 0 + + approx_cost: float = 0.0 + + # Timestamps + first_used_at: Optional[float] = None # All-time first use + last_used_at: Optional[float] = None # All-time last use + + +# ============================================================================= +# MODEL & GROUP STATS CONTAINERS +# ============================================================================= + + +@dataclass +class ModelStats: + """ + Stats for a single model (own usage only). + + Contains time-based windows and all-time totals. + Each model only tracks its own usage, not shared quota. + """ + + windows: Dict[str, WindowStats] = field(default_factory=dict) + totals: TotalStats = field(default_factory=TotalStats) + + +@dataclass +class GroupStats: + """ + Stats for a quota group (shared usage). + + Contains time-based windows and all-time totals. + Updated when ANY model in the group is used. + """ + + windows: Dict[str, WindowStats] = field(default_factory=dict) + totals: TotalStats = field(default_factory=TotalStats) + + +# ============================================================================= +# COOLDOWN TYPES +# ============================================================================= + + +@dataclass +class CooldownInfo: + """ + Information about a cooldown period. + + Cooldowns temporarily block a credential from being used. + """ + + reason: str # Why the cooldown was applied + until: float # Timestamp when cooldown ends + started_at: float # Timestamp when cooldown started + source: str = "system" # "system", "custom_cap", "rate_limit", "provider_hook" + model_or_group: Optional[str] = None # Scope of cooldown (None = credential-wide) + backoff_count: int = 0 # Number of consecutive cooldowns + + @property + def remaining_seconds(self) -> float: + """Seconds remaining in cooldown.""" + import time + + return max(0.0, self.until - time.time()) + + @property + def is_active(self) -> bool: + """True if cooldown is still in effect.""" + import time + + return time.time() < self.until + + +# ============================================================================= +# FAIR CYCLE TYPES +# ============================================================================= + + +@dataclass +class FairCycleState: + """ + Fair cycle state for a credential. + + Tracks whether a credential has been exhausted in the current cycle. + """ + + exhausted: bool = False + exhausted_at: Optional[float] = None + exhausted_reason: Optional[str] = None + cycle_request_count: int = 0 # Requests in current cycle + model_or_group: Optional[str] = None # Scope of exhaustion + + +@dataclass +class GlobalFairCycleState: + """ + Global fair cycle state for a provider. + + Tracks the overall cycle across all credentials. + """ + + cycle_start: float = 0.0 # Timestamp when current cycle started + all_exhausted_at: Optional[float] = None # When all credentials exhausted + cycle_count: int = 0 # How many full cycles completed + + +# ============================================================================= +# USAGE UPDATE (for consolidated tracking) +# ============================================================================= + + +@dataclass +class UsageUpdate: + """ + All data for a single usage update. + + Used by TrackingEngine.record_usage() to apply updates atomically. + """ + + request_count: int = 1 + success: bool = True + + # Tokens (optional) + prompt_tokens: int = 0 + completion_tokens: int = 0 + thinking_tokens: int = 0 + prompt_tokens_cache_read: int = 0 + prompt_tokens_cache_write: int = 0 + approx_cost: float = 0.0 + + +# ============================================================================= +# CREDENTIAL STATE +# ============================================================================= + + +@dataclass +class CredentialState: + """ + Complete state for a single credential. + + This is the primary storage unit for credential data. + """ + + # Identity + stable_id: str # Email (OAuth) or hash (API key) + provider: str + accessor: str # Current file path or API key + display_name: Optional[str] = None + tier: Optional[str] = None + priority: int = 999 # Lower = higher priority + + # Stats - source of truth + model_usage: Dict[str, ModelStats] = field(default_factory=dict) + group_usage: Dict[str, GroupStats] = field(default_factory=dict) + totals: TotalStats = field(default_factory=TotalStats) # Credential-level totals + + # Cooldowns (keyed by model/group or "_global_") + cooldowns: Dict[str, CooldownInfo] = field(default_factory=dict) + + # Fair cycle state (keyed by model/group) + fair_cycle: Dict[str, FairCycleState] = field(default_factory=dict) + + # Active requests (for concurrent request limiting) + active_requests: int = 0 + max_concurrent: Optional[int] = None + + # Metadata + created_at: Optional[float] = None + last_updated: Optional[float] = None + + def get_cooldown( + self, model_or_group: Optional[str] = None + ) -> Optional[CooldownInfo]: + """Get active cooldown for given scope.""" + import time + + now = time.time() + + # Check specific cooldown + if model_or_group: + cooldown = self.cooldowns.get(model_or_group) + if cooldown and cooldown.until > now: + return cooldown + + # Check global cooldown + global_cooldown = self.cooldowns.get("_global_") + if global_cooldown and global_cooldown.until > now: + return global_cooldown + + return None + + def is_fair_cycle_exhausted(self, model_or_group: str) -> bool: + """Check if exhausted for fair cycle purposes.""" + state = self.fair_cycle.get(model_or_group) + return state.exhausted if state else False + + def get_model_stats(self, model: str, create: bool = True) -> Optional[ModelStats]: + """Get model stats, optionally creating if not exists.""" + if create: + return self.model_usage.setdefault(model, ModelStats()) + return self.model_usage.get(model) + + def get_group_stats(self, group: str, create: bool = True) -> Optional[GroupStats]: + """Get group stats, optionally creating if not exists.""" + if create: + return self.group_usage.setdefault(group, GroupStats()) + return self.group_usage.get(group) + + def get_window_for_model( + self, model: str, window_name: str + ) -> Optional[WindowStats]: + """Get a specific window for a model.""" + model_stats = self.model_usage.get(model) + if model_stats: + return model_stats.windows.get(window_name) + return None + + def get_window_for_group( + self, group: str, window_name: str + ) -> Optional[WindowStats]: + """Get a specific window for a group.""" + group_stats = self.group_usage.get(group) + if group_stats: + return group_stats.windows.get(window_name) + return None + + +# ============================================================================= +# SELECTION TYPES +# ============================================================================= + + +@dataclass +class SelectionContext: + """ + Context passed to rotation strategies during credential selection. + + Contains all information needed to make a selection decision. + """ + + provider: str + model: str + quota_group: Optional[str] # Quota group for this model + candidates: List[str] # Stable IDs of available candidates + priorities: Dict[str, int] # stable_id -> priority + usage_counts: Dict[str, int] # stable_id -> request count for relevant window + rotation_mode: RotationMode + rotation_tolerance: float + deadline: float + + +@dataclass +class LimitCheckResult: + """ + Result of checking all limits for a credential. + + Used by LimitEngine to report why a credential was blocked. + """ + + allowed: bool + result: LimitResult = LimitResult.ALLOWED + reason: Optional[str] = None + blocked_until: Optional[float] = None # When the block expires + + @classmethod + def ok(cls) -> "LimitCheckResult": + """Create an allowed result.""" + return cls(allowed=True, result=LimitResult.ALLOWED) + + @classmethod + def blocked( + cls, + result: LimitResult, + reason: str, + blocked_until: Optional[float] = None, + ) -> "LimitCheckResult": + """Create a blocked result.""" + return cls( + allowed=False, + result=result, + reason=reason, + blocked_until=blocked_until, + ) + + +# ============================================================================= +# STORAGE TYPES +# ============================================================================= + + +@dataclass +class StorageSchema: + """ + Root schema for usage.json storage file. + """ + + schema_version: int = 2 + updated_at: Optional[str] = None # ISO format + credentials: Dict[str, Dict[str, Any]] = field(default_factory=dict) + accessor_index: Dict[str, str] = field( + default_factory=dict + ) # accessor -> stable_id + fair_cycle_global: Dict[str, Dict[str, Any]] = field( + default_factory=dict + ) # provider -> GlobalFairCycleState diff --git a/src/rotator_library/usage_manager.py b/src/rotator_library/usage_manager.py index 46e30bbc..5430dc80 100644 --- a/src/rotator_library/usage_manager.py +++ b/src/rotator_library/usage_manager.py @@ -1,3980 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-only # Copyright (c) 2026 Mirrowel -import json -import os -import time -import logging -import asyncio -import random -from datetime import date, datetime, timezone, time as dt_time -from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, Union -import aiofiles -import litellm +"""Compatibility shim for legacy imports.""" -from .error_handler import ClassifiedError, NoAvailableKeysError, mask_credential -from .providers import PROVIDER_PLUGINS -from .utils.resilient_io import ResilientStateWriter -from .utils.paths import get_data_file -from .config import ( - DEFAULT_FAIR_CYCLE_DURATION, - DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, - DEFAULT_CUSTOM_CAP_COOLDOWN_MODE, - DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE, - COOLDOWN_BACKOFF_TIERS, - COOLDOWN_BACKOFF_MAX, - COOLDOWN_AUTH_ERROR, - COOLDOWN_TRANSIENT_ERROR, - COOLDOWN_RATE_LIMIT_DEFAULT, -) +from .usage import UsageManager, CredentialContext -lib_logger = logging.getLogger("rotator_library") -lib_logger.propagate = False -if not lib_logger.handlers: - lib_logger.addHandler(logging.NullHandler()) - - -class UsageManager: - """ - Manages usage statistics and cooldowns for API keys with asyncio-safe locking, - asynchronous file I/O, lazy-loading mechanism, and weighted random credential rotation. - - The credential rotation strategy can be configured via the `rotation_tolerance` parameter: - - - **tolerance = 0.0**: Deterministic least-used selection. The credential with - the lowest usage count is always selected. This provides predictable, perfectly balanced - load distribution but may be vulnerable to fingerprinting. - - - **tolerance = 2.0 - 4.0 (default, recommended)**: Balanced weighted randomness. Credentials are selected - randomly with weights biased toward less-used ones. Credentials within 2 uses of the - maximum can still be selected with reasonable probability. This provides security through - unpredictability while maintaining good load balance. - - - **tolerance = 5.0+**: High randomness. Even heavily-used credentials have significant - selection probability. Useful for stress testing or maximum unpredictability, but may - result in less balanced load distribution. - - The weight formula is: `weight = (max_usage - credential_usage) + tolerance + 1` - - This ensures lower-usage credentials are preferred while tolerance controls how much - randomness is introduced into the selection process. - - Additionally, providers can specify a rotation mode: - - "balanced" (default): Rotate credentials to distribute load evenly - - "sequential": Use one credential until exhausted (preserves caching) - """ - - def __init__( - self, - file_path: Optional[Union[str, Path]] = None, - daily_reset_time_utc: Optional[str] = "03:00", - rotation_tolerance: float = 0.0, - provider_rotation_modes: Optional[Dict[str, str]] = None, - provider_plugins: Optional[Dict[str, Any]] = None, - priority_multipliers: Optional[Dict[str, Dict[int, int]]] = None, - priority_multipliers_by_mode: Optional[ - Dict[str, Dict[str, Dict[int, int]]] - ] = None, - sequential_fallback_multipliers: Optional[Dict[str, int]] = None, - fair_cycle_enabled: Optional[Dict[str, bool]] = None, - fair_cycle_tracking_mode: Optional[Dict[str, str]] = None, - fair_cycle_cross_tier: Optional[Dict[str, bool]] = None, - fair_cycle_duration: Optional[Dict[str, int]] = None, - exhaustion_cooldown_threshold: Optional[Dict[str, int]] = None, - custom_caps: Optional[ - Dict[str, Dict[Union[int, Tuple[int, ...], str], Dict[str, Dict[str, Any]]]] - ] = None, - ): - """ - Initialize the UsageManager. - - Args: - file_path: Path to the usage data JSON file. If None, uses get_data_file("key_usage.json"). - Can be absolute Path, relative Path, or string. - daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format) - rotation_tolerance: Tolerance for weighted random credential rotation. - - 0.0: Deterministic, least-used credential always selected - - tolerance = 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max - - 5.0+: High randomness, more unpredictable selection patterns - provider_rotation_modes: Dict mapping provider names to rotation modes. - - "balanced": Rotate credentials to distribute load evenly (default) - - "sequential": Use one credential until exhausted (preserves caching) - provider_plugins: Dict mapping provider names to provider plugin instances. - Used for per-provider usage reset configuration (window durations, field names). - priority_multipliers: Dict mapping provider -> priority -> multiplier. - Universal multipliers that apply regardless of rotation mode. - Example: {"antigravity": {1: 5, 2: 3}} - priority_multipliers_by_mode: Dict mapping provider -> mode -> priority -> multiplier. - Mode-specific overrides. Example: {"antigravity": {"balanced": {3: 1}}} - sequential_fallback_multipliers: Dict mapping provider -> fallback multiplier. - Used in sequential mode when priority not in priority_multipliers. - Example: {"antigravity": 2} - fair_cycle_enabled: Dict mapping provider -> bool to enable fair cycle rotation. - When enabled, credentials must all exhaust before any can be reused. - Default: enabled for sequential mode only. - fair_cycle_tracking_mode: Dict mapping provider -> tracking mode. - - "model_group": Track per quota group or model (default) - - "credential": Track per credential globally - fair_cycle_cross_tier: Dict mapping provider -> bool for cross-tier tracking. - - False: Each tier cycles independently (default) - - True: All credentials must exhaust regardless of tier - fair_cycle_duration: Dict mapping provider -> cycle duration in seconds. - Default: 86400 (24 hours) - exhaustion_cooldown_threshold: Dict mapping provider -> threshold in seconds. - A cooldown must exceed this to qualify as "exhausted". Default: 300 (5 min) - custom_caps: Dict mapping provider -> tier -> model/group -> cap config. - Allows setting custom usage limits per tier, per model or quota group. - See ProviderInterface.default_custom_caps for format details. - """ - # Resolve file_path - use default if not provided - if file_path is None: - self.file_path = str(get_data_file("key_usage.json")) - elif isinstance(file_path, Path): - self.file_path = str(file_path) - else: - # String path - could be relative or absolute - self.file_path = file_path - self.rotation_tolerance = rotation_tolerance - self.provider_rotation_modes = provider_rotation_modes or {} - self.provider_plugins = provider_plugins or PROVIDER_PLUGINS - self.priority_multipliers = priority_multipliers or {} - self.priority_multipliers_by_mode = priority_multipliers_by_mode or {} - self.sequential_fallback_multipliers = sequential_fallback_multipliers or {} - self._provider_instances: Dict[str, Any] = {} # Cache for provider instances - self.key_states: Dict[str, Dict[str, Any]] = {} - - # Fair cycle rotation configuration - self.fair_cycle_enabled = fair_cycle_enabled or {} - self.fair_cycle_tracking_mode = fair_cycle_tracking_mode or {} - self.fair_cycle_cross_tier = fair_cycle_cross_tier or {} - self.fair_cycle_duration = fair_cycle_duration or {} - self.exhaustion_cooldown_threshold = exhaustion_cooldown_threshold or {} - self.custom_caps = custom_caps or {} - # In-memory cycle state: {provider: {tier_key: {tracking_key: {"cycle_started_at": float, "exhausted": Set[str]}}}} - self._cycle_exhausted: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]] = {} - - self._data_lock = asyncio.Lock() - self._usage_data: Optional[Dict] = None - self._initialized = asyncio.Event() - self._init_lock = asyncio.Lock() - - self._timeout_lock = asyncio.Lock() - self._claimed_on_timeout: Set[str] = set() - - # Resilient writer for usage data persistence - self._state_writer = ResilientStateWriter(file_path, lib_logger) - - if daily_reset_time_utc: - hour, minute = map(int, daily_reset_time_utc.split(":")) - self.daily_reset_time_utc = dt_time( - hour=hour, minute=minute, tzinfo=timezone.utc - ) - else: - self.daily_reset_time_utc = None - - def _get_rotation_mode(self, provider: str) -> str: - """ - Get the rotation mode for a provider. - - Args: - provider: Provider name (e.g., "antigravity", "gemini_cli") - - Returns: - "balanced" or "sequential" - """ - return self.provider_rotation_modes.get(provider, "balanced") - - # ========================================================================= - # FAIR CYCLE ROTATION HELPERS - # ========================================================================= - - def _is_fair_cycle_enabled(self, provider: str, rotation_mode: str) -> bool: - """ - Check if fair cycle rotation is enabled for a provider. - - Args: - provider: Provider name - rotation_mode: Current rotation mode ("balanced" or "sequential") - - Returns: - True if fair cycle is enabled - """ - # Check provider-specific setting first - if provider in self.fair_cycle_enabled: - return self.fair_cycle_enabled[provider] - # Default: enabled only for sequential mode - return rotation_mode == "sequential" - - def _get_fair_cycle_tracking_mode(self, provider: str) -> str: - """ - Get fair cycle tracking mode for a provider. - - Returns: - "model_group" or "credential" - """ - return self.fair_cycle_tracking_mode.get(provider, "model_group") - - def _is_fair_cycle_cross_tier(self, provider: str) -> bool: - """ - Check if fair cycle tracks across all tiers (ignoring priority boundaries). - - Returns: - True if cross-tier tracking is enabled - """ - return self.fair_cycle_cross_tier.get(provider, False) - - def _get_fair_cycle_duration(self, provider: str) -> int: - """ - Get fair cycle duration in seconds for a provider. - - Returns: - Duration in seconds (default 86400 = 24 hours) - """ - return self.fair_cycle_duration.get(provider, DEFAULT_FAIR_CYCLE_DURATION) - - def _get_exhaustion_cooldown_threshold(self, provider: str) -> int: - """ - Get exhaustion cooldown threshold in seconds for a provider. - - A cooldown must exceed this duration to qualify as "exhausted" for fair cycle. - - Returns: - Threshold in seconds (default 300 = 5 minutes) - """ - return self.exhaustion_cooldown_threshold.get( - provider, DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD - ) - - # ========================================================================= - # CUSTOM CAPS HELPERS - # ========================================================================= - - def _get_custom_cap_config( - self, - provider: str, - tier_priority: int, - model: str, - ) -> Optional[Dict[str, Any]]: - """ - Get custom cap config for a provider/tier/model combination. - - Resolution order: - 1. tier + model (exact match) - 2. tier + group (model's quota group) - 3. "default" + model - 4. "default" + group - - Args: - provider: Provider name - tier_priority: Credential's priority level - model: Model name (with provider prefix) - - Returns: - Cap config dict or None if no custom cap applies - """ - provider_caps = self.custom_caps.get(provider) - if not provider_caps: - return None - - # Strip provider prefix from model - clean_model = model.split("/")[-1] if "/" in model else model - - # Get quota group for this model - group = self._get_model_quota_group_by_provider(provider, model) - - # Try to find matching tier config - tier_config = None - default_config = None - - for tier_key, models_config in provider_caps.items(): - if tier_key == "default": - default_config = models_config - continue - - # Check if this tier_key matches our priority - if isinstance(tier_key, int) and tier_key == tier_priority: - tier_config = models_config - break - elif isinstance(tier_key, tuple) and tier_priority in tier_key: - tier_config = models_config - break - - # Resolution order for tier config - if tier_config: - # Try model first - if clean_model in tier_config: - return tier_config[clean_model] - # Try group - if group and group in tier_config: - return tier_config[group] - - # Resolution order for default config - if default_config: - # Try model first - if clean_model in default_config: - return default_config[clean_model] - # Try group - if group and group in default_config: - return default_config[group] - - return None - - def _get_model_quota_group_by_provider( - self, provider: str, model: str - ) -> Optional[str]: - """ - Get quota group for a model using provider name instead of credential. - - Args: - provider: Provider name - model: Model name - - Returns: - Group name or None - """ - plugin_instance = self._get_provider_instance(provider) - if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"): - return plugin_instance.get_model_quota_group(model) - return None - - def _resolve_custom_cap_max( - self, - provider: str, - model: str, - cap_config: Dict[str, Any], - actual_max: Optional[int], - ) -> Optional[int]: - """ - Resolve custom cap max_requests value, handling percentages and clamping. - - Args: - provider: Provider name - model: Model name (for logging) - cap_config: Custom cap configuration - actual_max: Actual API max requests (may be None if unknown) - - Returns: - Resolved cap value (clamped), or None if can't be calculated - """ - max_requests = cap_config.get("max_requests") - if max_requests is None: - return None - - # Handle percentage - if isinstance(max_requests, str) and max_requests.endswith("%"): - if actual_max is None: - lib_logger.warning( - f"Custom cap '{max_requests}' for {provider}/{model} requires known max_requests. " - f"Skipping until quota baseline is fetched. Use absolute value for immediate enforcement." - ) - return None - try: - percentage = float(max_requests.rstrip("%")) / 100.0 - calculated = int(actual_max * percentage) - except ValueError: - lib_logger.warning( - f"Invalid percentage cap '{max_requests}' for {provider}/{model}" - ) - return None - else: - # Absolute value - try: - calculated = int(max_requests) - except (ValueError, TypeError): - lib_logger.warning( - f"Invalid cap value '{max_requests}' for {provider}/{model}" - ) - return None - - # Clamp to actual max (can only be MORE restrictive) - if actual_max is not None: - return min(calculated, actual_max) - return calculated - - def _calculate_custom_cooldown_until( - self, - cap_config: Dict[str, Any], - window_start_ts: Optional[float], - natural_reset_ts: Optional[float], - ) -> Optional[float]: - """ - Calculate when custom cap cooldown should end, clamped to natural reset. - - Args: - cap_config: Custom cap configuration - window_start_ts: When first request was made (for fixed mode) - natural_reset_ts: Natural quota reset timestamp - - Returns: - Cooldown end timestamp (clamped), or None if can't calculate - """ - mode = cap_config.get("cooldown_mode", DEFAULT_CUSTOM_CAP_COOLDOWN_MODE) - value = cap_config.get("cooldown_value", DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE) - - if mode == "quota_reset": - calculated = natural_reset_ts - elif mode == "offset": - if natural_reset_ts is None: - return None - calculated = natural_reset_ts + value - elif mode == "fixed": - if window_start_ts is None: - return None - calculated = window_start_ts + value - else: - lib_logger.warning(f"Unknown cooldown_mode '{mode}', using quota_reset") - calculated = natural_reset_ts - - if calculated is None: - return None - - # Clamp to natural reset (can only be MORE restrictive = longer cooldown) - if natural_reset_ts is not None: - return max(calculated, natural_reset_ts) - return calculated - - def _check_and_apply_custom_cap( - self, - credential: str, - model: str, - request_count: int, - ) -> bool: - """ - Check if custom cap is exceeded and apply cooldown if so. - - This should be called after incrementing request_count in record_success(). - - Args: - credential: Credential identifier - model: Model name (with provider prefix) - request_count: Current request count for this model - - Returns: - True if cap exceeded and cooldown applied, False otherwise - """ - provider = self._get_provider_from_credential(credential) - if not provider: - return False - - priority = self._get_credential_priority(credential, provider) - cap_config = self._get_custom_cap_config(provider, priority, model) - if not cap_config: - return False - - # Get model data for actual max and timing info - key_data = self._usage_data.get(credential, {}) - model_data = key_data.get("models", {}).get(model, {}) - actual_max = model_data.get("quota_max_requests") - window_start_ts = model_data.get("window_start_ts") - natural_reset_ts = model_data.get("quota_reset_ts") - - # Resolve custom cap max - custom_max = self._resolve_custom_cap_max( - provider, model, cap_config, actual_max - ) - if custom_max is None: - return False - - # Check if exceeded - if request_count < custom_max: - return False - - # Calculate cooldown end time - cooldown_until = self._calculate_custom_cooldown_until( - cap_config, window_start_ts, natural_reset_ts - ) - if cooldown_until is None: - # Can't calculate cooldown, use natural reset if available - if natural_reset_ts: - cooldown_until = natural_reset_ts - else: - lib_logger.warning( - f"Custom cap hit for {mask_credential(credential)}/{model} but can't calculate cooldown. " - f"Skipping cooldown application." - ) - return False - - now_ts = time.time() - - # Apply cooldown - model_cooldowns = key_data.setdefault("model_cooldowns", {}) - model_cooldowns[model] = cooldown_until - - # Store custom cap info in model data for reference - model_data["custom_cap_max"] = custom_max - model_data["custom_cap_hit_at"] = now_ts - model_data["custom_cap_cooldown_until"] = cooldown_until - - hours_until = (cooldown_until - now_ts) / 3600 - lib_logger.info( - f"Custom cap hit: {mask_credential(credential)} reached {request_count}/{custom_max} " - f"for {model}. Cooldown for {hours_until:.1f}h" - ) - - # Sync cooldown across quota group - group = self._get_model_quota_group(credential, model) - if group: - grouped_models = self._get_grouped_models(credential, group) - for grouped_model in grouped_models: - if grouped_model != model: - model_cooldowns[grouped_model] = cooldown_until - - # Check if this should trigger fair cycle exhaustion - cooldown_duration = cooldown_until - now_ts - threshold = self._get_exhaustion_cooldown_threshold(provider) - if cooldown_duration > threshold: - rotation_mode = self._get_rotation_mode(provider) - if self._is_fair_cycle_enabled(provider, rotation_mode): - tier_key = self._get_tier_key(provider, priority) - tracking_key = self._get_tracking_key(credential, model, provider) - self._mark_credential_exhausted( - credential, provider, tier_key, tracking_key - ) - - return True - - def _get_tier_key(self, provider: str, priority: int) -> str: - """ - Get the tier key for cycle tracking based on cross_tier setting. - - Args: - provider: Provider name - priority: Credential priority level - - Returns: - "__all_tiers__" if cross-tier enabled, else str(priority) - """ - if self._is_fair_cycle_cross_tier(provider): - return "__all_tiers__" - return str(priority) - - def _get_tracking_key(self, credential: str, model: str, provider: str) -> str: - """ - Get the key for exhaustion tracking based on tracking mode. - - Args: - credential: Credential identifier - model: Model name (with provider prefix) - provider: Provider name - - Returns: - Tracking key string (quota group name, model name, or "__credential__") - """ - mode = self._get_fair_cycle_tracking_mode(provider) - if mode == "credential": - return "__credential__" - # model_group mode: use quota group if exists, else model - group = self._get_model_quota_group(credential, model) - return group if group else model - - def _get_credential_priority(self, credential: str, provider: str) -> int: - """ - Get the priority level for a credential. - - Args: - credential: Credential identifier - provider: Provider name - - Returns: - Priority level (default 999 if unknown) - """ - plugin_instance = self._get_provider_instance(provider) - if plugin_instance and hasattr(plugin_instance, "get_credential_priority"): - priority = plugin_instance.get_credential_priority(credential) - if priority is not None: - return priority - return 999 - - def _get_cycle_data( - self, provider: str, tier_key: str, tracking_key: str - ) -> Optional[Dict[str, Any]]: - """ - Get cycle data for a provider/tier/tracking key combination. - - Returns: - Cycle data dict or None if not exists - """ - return ( - self._cycle_exhausted.get(provider, {}).get(tier_key, {}).get(tracking_key) - ) - - def _ensure_cycle_structure( - self, provider: str, tier_key: str, tracking_key: str - ) -> Dict[str, Any]: - """ - Ensure the nested cycle structure exists and return the cycle data dict. - """ - if provider not in self._cycle_exhausted: - self._cycle_exhausted[provider] = {} - if tier_key not in self._cycle_exhausted[provider]: - self._cycle_exhausted[provider][tier_key] = {} - if tracking_key not in self._cycle_exhausted[provider][tier_key]: - self._cycle_exhausted[provider][tier_key][tracking_key] = { - "cycle_started_at": None, - "exhausted": set(), - } - return self._cycle_exhausted[provider][tier_key][tracking_key] - - def _mark_credential_exhausted( - self, - credential: str, - provider: str, - tier_key: str, - tracking_key: str, - ) -> None: - """ - Mark a credential as exhausted for fair cycle tracking. - - Starts the cycle timer on first exhaustion. - Skips if credential is already in the exhausted set (prevents duplicate logging). - """ - cycle_data = self._ensure_cycle_structure(provider, tier_key, tracking_key) - - # Skip if already exhausted in this cycle (prevents duplicate logging) - if credential in cycle_data.get("exhausted", set()): - return - - # Start cycle timer on first exhaustion - if cycle_data["cycle_started_at"] is None: - cycle_data["cycle_started_at"] = time.time() - lib_logger.info( - f"Fair cycle started for {provider} tier={tier_key} tracking='{tracking_key}'" - ) - - cycle_data["exhausted"].add(credential) - lib_logger.info( - f"Fair cycle: marked {mask_credential(credential)} exhausted " - f"for {tracking_key} ({len(cycle_data['exhausted'])} total)" - ) - - def _is_credential_exhausted_in_cycle( - self, - credential: str, - provider: str, - tier_key: str, - tracking_key: str, - ) -> bool: - """ - Check if a credential was exhausted in the current cycle. - """ - cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) - if cycle_data is None: - return False - return credential in cycle_data.get("exhausted", set()) - - def _is_cycle_expired( - self, provider: str, tier_key: str, tracking_key: str - ) -> bool: - """ - Check if the current cycle has exceeded its duration. - """ - cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) - if cycle_data is None: - return False - cycle_started = cycle_data.get("cycle_started_at") - if cycle_started is None: - return False - duration = self._get_fair_cycle_duration(provider) - return time.time() >= cycle_started + duration - - def _should_reset_cycle( - self, - provider: str, - tier_key: str, - tracking_key: str, - all_credentials_in_tier: List[str], - available_not_on_cooldown: Optional[List[str]] = None, - ) -> bool: - """ - Check if cycle should reset. - - Returns True if: - 1. Cycle duration has expired, OR - 2. No credentials remain available (after cooldown + fair cycle exclusion), OR - 3. All credentials in the tier have been marked exhausted (fallback) - """ - # Check duration first - if self._is_cycle_expired(provider, tier_key, tracking_key): - return True - - cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) - if cycle_data is None: - return False - - # If available credentials are provided, reset when none remain usable - if available_not_on_cooldown is not None: - has_available = any( - not self._is_credential_exhausted_in_cycle( - cred, provider, tier_key, tracking_key - ) - for cred in available_not_on_cooldown - ) - if not has_available and len(all_credentials_in_tier) > 0: - return True - - exhausted = cycle_data.get("exhausted", set()) - # All must be exhausted (and there must be at least one credential) - return ( - len(exhausted) >= len(all_credentials_in_tier) - and len(all_credentials_in_tier) > 0 - ) - - def _reset_cycle(self, provider: str, tier_key: str, tracking_key: str) -> None: - """ - Reset exhaustion tracking for a completed cycle. - """ - cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) - if cycle_data: - exhausted_count = len(cycle_data.get("exhausted", set())) - lib_logger.info( - f"Fair cycle complete for {provider} tier={tier_key} " - f"tracking='{tracking_key}' - resetting ({exhausted_count} credentials cycled)" - ) - cycle_data["cycle_started_at"] = None - cycle_data["exhausted"] = set() - - def _get_all_credentials_for_tier_key( - self, - provider: str, - tier_key: str, - available_keys: List[str], - credential_priorities: Optional[Dict[str, int]], - ) -> List[str]: - """ - Get all credentials that belong to a tier key. - - Args: - provider: Provider name - tier_key: Either "__all_tiers__" or str(priority) - available_keys: List of available credential identifiers - credential_priorities: Dict mapping credentials to priorities - - Returns: - List of credentials belonging to this tier key - """ - if tier_key == "__all_tiers__": - # Cross-tier: all credentials for this provider - return list(available_keys) - else: - # Within-tier: only credentials with matching priority - priority = int(tier_key) - if credential_priorities: - return [ - k - for k in available_keys - if credential_priorities.get(k, 999) == priority - ] - return list(available_keys) - - def _count_fair_cycle_excluded( - self, - provider: str, - tier_key: str, - tracking_key: str, - candidates: List[str], - ) -> int: - """ - Count how many candidates are excluded by fair cycle. - - Args: - provider: Provider name - tier_key: Tier key for tracking - tracking_key: Model/group tracking key - candidates: List of candidate credentials (not on cooldown) - - Returns: - Number of candidates excluded by fair cycle - """ - count = 0 - for cred in candidates: - if self._is_credential_exhausted_in_cycle( - cred, provider, tier_key, tracking_key - ): - count += 1 - return count - - def _get_priority_multiplier( - self, provider: str, priority: int, rotation_mode: str - ) -> int: - """ - Get the concurrency multiplier for a provider/priority/mode combination. - - Lookup order: - 1. Mode-specific tier override: priority_multipliers_by_mode[provider][mode][priority] - 2. Universal tier multiplier: priority_multipliers[provider][priority] - 3. Sequential fallback (if mode is sequential): sequential_fallback_multipliers[provider] - 4. Global default: 1 (no multiplier effect) - - Args: - provider: Provider name (e.g., "antigravity") - priority: Priority level (1 = highest priority) - rotation_mode: Current rotation mode ("sequential" or "balanced") - - Returns: - Multiplier value - """ - provider_lower = provider.lower() - - # 1. Check mode-specific override - if provider_lower in self.priority_multipliers_by_mode: - mode_multipliers = self.priority_multipliers_by_mode[provider_lower] - if rotation_mode in mode_multipliers: - if priority in mode_multipliers[rotation_mode]: - return mode_multipliers[rotation_mode][priority] - - # 2. Check universal tier multiplier - if provider_lower in self.priority_multipliers: - if priority in self.priority_multipliers[provider_lower]: - return self.priority_multipliers[provider_lower][priority] - - # 3. Sequential fallback (only for sequential mode) - if rotation_mode == "sequential": - if provider_lower in self.sequential_fallback_multipliers: - return self.sequential_fallback_multipliers[provider_lower] - - # 4. Global default - return 1 - - def _get_provider_from_credential(self, credential: str) -> Optional[str]: - """ - Extract provider name from credential path or identifier. - - Supports multiple credential formats: - - OAuth: "oauth_creds/antigravity_oauth_15.json" -> "antigravity" - - OAuth: "C:\\...\\oauth_creds\\gemini_cli_oauth_1.json" -> "gemini_cli" - - OAuth filename only: "antigravity_oauth_1.json" -> "antigravity" - - API key style: extracted from model names in usage data (e.g., "firmware/model" -> "firmware") - - Args: - credential: The credential identifier (path or key) - - Returns: - Provider name string or None if cannot be determined - """ - import re - - # Pattern: env:// URI format (e.g., "env://antigravity/1" -> "antigravity") - if credential.startswith("env://"): - parts = credential[6:].split("/") # Remove "env://" prefix - if parts and parts[0]: - return parts[0].lower() - # Malformed env:// URI (empty provider name) - lib_logger.warning(f"Malformed env:// credential URI: {credential}") - return None - - # Normalize path separators - normalized = credential.replace("\\", "/") - - # Pattern: path ending with {provider}_oauth_{number}.json - match = re.search(r"/([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE) - if match: - return match.group(1).lower() - - # Pattern: oauth_creds/{provider}_... - match = re.search(r"oauth_creds/([a-z_]+)_", normalized, re.IGNORECASE) - if match: - return match.group(1).lower() - - # Pattern: filename only {provider}_oauth_{number}.json (no path) - match = re.match(r"([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE) - if match: - return match.group(1).lower() - - # Pattern: API key prefixes for specific providers - # These are raw API keys with recognizable prefixes - api_key_prefixes = { - "sk-nano-": "nanogpt", - "sk-or-": "openrouter", - "sk-ant-": "anthropic", - } - for prefix, provider in api_key_prefixes.items(): - if credential.startswith(prefix): - return provider - - # Fallback: For raw API keys, extract provider from model names in usage data - # This handles providers like firmware, chutes, nanogpt that use credential-level quota - if self._usage_data and credential in self._usage_data: - cred_data = self._usage_data[credential] - - # Check "models" section first (for per_model mode and quota tracking) - models_data = cred_data.get("models", {}) - if models_data: - # Get first model name and extract provider prefix - first_model = next(iter(models_data.keys()), None) - if first_model and "/" in first_model: - provider = first_model.split("/")[0].lower() - return provider - - # Fallback to "daily" section (legacy structure) - daily_data = cred_data.get("daily", {}) - daily_models = daily_data.get("models", {}) - if daily_models: - # Get first model name and extract provider prefix - first_model = next(iter(daily_models.keys()), None) - if first_model and "/" in first_model: - provider = first_model.split("/")[0].lower() - return provider - - return None - - def _get_provider_instance(self, provider: str) -> Optional[Any]: - """ - Get or create a provider plugin instance. - - Args: - provider: The provider name - - Returns: - Provider plugin instance or None - """ - if not provider: - return None - - plugin_class = self.provider_plugins.get(provider) - if not plugin_class: - return None - - # Get or create provider instance from cache - if provider not in self._provider_instances: - # Instantiate the plugin if it's a class, or use it directly if already an instance - if isinstance(plugin_class, type): - self._provider_instances[provider] = plugin_class() - else: - self._provider_instances[provider] = plugin_class - - return self._provider_instances[provider] - - def _get_usage_reset_config(self, credential: str) -> Optional[Dict[str, Any]]: - """ - Get the usage reset configuration for a credential from its provider plugin. - - Args: - credential: The credential identifier - - Returns: - Configuration dict with window_seconds, field_name, etc. - or None to use default daily reset. - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_usage_reset_config"): - return plugin_instance.get_usage_reset_config(credential) - - return None - - def _get_reset_mode(self, credential: str) -> str: - """ - Get the reset mode for a credential: 'credential' or 'per_model'. - - Args: - credential: The credential identifier - - Returns: - "per_model" or "credential" (default) - """ - config = self._get_usage_reset_config(credential) - return config.get("mode", "credential") if config else "credential" - - def _get_model_quota_group(self, credential: str, model: str) -> Optional[str]: - """ - Get the quota group for a model, if the provider defines one. - - Args: - credential: The credential identifier - model: Model name (with or without provider prefix) - - Returns: - Group name (e.g., "claude") or None if not grouped - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"): - return plugin_instance.get_model_quota_group(model) - - return None - - def _get_grouped_models(self, credential: str, group: str) -> List[str]: - """ - Get all model names in a quota group (with provider prefix), normalized. - - Returns only public-facing model names, deduplicated. Internal variants - (e.g., claude-sonnet-4-5-thinking) are normalized to their public name - (e.g., claude-sonnet-4.5). - - Args: - credential: The credential identifier - group: Group name (e.g., "claude") - - Returns: - List of normalized, deduplicated model names with provider prefix - (e.g., ["antigravity/claude-sonnet-4.5", "antigravity/claude-opus-4.5"]) - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_models_in_quota_group"): - models = plugin_instance.get_models_in_quota_group(group) - - # Normalize and deduplicate - if hasattr(plugin_instance, "normalize_model_for_tracking"): - seen = set() - normalized = [] - for m in models: - prefixed = f"{provider}/{m}" - norm = plugin_instance.normalize_model_for_tracking(prefixed) - if norm not in seen: - seen.add(norm) - normalized.append(norm) - return normalized - - # Fallback: just add provider prefix - return [f"{provider}/{m}" for m in models] - - return [] - - def _get_model_usage_weight(self, credential: str, model: str) -> int: - """ - Get the usage weight for a model when calculating grouped usage. - - Args: - credential: The credential identifier - model: Model name (with or without provider prefix) - - Returns: - Weight multiplier (default 1 if not configured) - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_model_usage_weight"): - return plugin_instance.get_model_usage_weight(model) - - return 1 - - def _normalize_model(self, credential: str, model: str) -> str: - """ - Normalize model name using provider's mapping. - - Converts internal model names (e.g., claude-sonnet-4-5-thinking) to - public-facing names (e.g., claude-sonnet-4.5) for consistent storage. - - Args: - credential: The credential identifier - model: Model name (with or without provider prefix) - - Returns: - Normalized model name (provider prefix preserved if present) - """ - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "normalize_model_for_tracking"): - return plugin_instance.normalize_model_for_tracking(model) - - return model - - # Providers where request_count should be used for credential selection - # instead of success_count (because failed requests also consume quota) - _REQUEST_COUNT_PROVIDERS = {"antigravity", "gemini_cli", "chutes", "nanogpt"} - - def _get_grouped_usage_count(self, key: str, model: str) -> int: - """ - Get usage count for credential selection, considering quota groups. - - For providers in _REQUEST_COUNT_PROVIDERS (e.g., antigravity), uses - request_count instead of success_count since failed requests also - consume quota. - - If the model belongs to a quota group, the request_count is already - synced across all models in the group (by record_success/record_failure), - so we just read from the requested model directly. - - Args: - key: Credential identifier - model: Model name (with provider prefix, e.g., "antigravity/claude-sonnet-4-5") - - Returns: - Usage count for the model (synced across group if applicable) - """ - # Determine usage field based on provider - # Some providers (antigravity) count failed requests against quota - provider = self._get_provider_from_credential(key) - usage_field = ( - "request_count" - if provider in self._REQUEST_COUNT_PROVIDERS - else "success_count" - ) - - # For providers with synced quota groups (antigravity), request_count - # is already synced across all models in the group, so just read directly. - # For other providers, we still need to sum success_count across group. - if provider in self._REQUEST_COUNT_PROVIDERS: - # request_count is synced - just read the model's value - return self._get_usage_count(key, model, usage_field) - - # For non-synced providers, check if model is in a quota group and sum - group = self._get_model_quota_group(key, model) - - if group: - # Get all models in the group - grouped_models = self._get_grouped_models(key, group) - - # Sum weighted usage across all models in the group - total_weighted_usage = 0 - for grouped_model in grouped_models: - usage = self._get_usage_count(key, grouped_model, usage_field) - weight = self._get_model_usage_weight(key, grouped_model) - total_weighted_usage += usage * weight - return total_weighted_usage - - # Not grouped - return individual model usage (no weight applied) - return self._get_usage_count(key, model, usage_field) - - def _get_quota_display(self, key: str, model: str) -> str: - """ - Get a formatted quota display string for logging. - - For antigravity (providers in _REQUEST_COUNT_PROVIDERS), returns: - "quota: 170/250 [32%]" format - - For other providers, returns: - "usage: 170" format (no max available) - - Args: - key: Credential identifier - model: Model name (with provider prefix) - - Returns: - Formatted string for logging - """ - provider = self._get_provider_from_credential(key) - - if provider not in self._REQUEST_COUNT_PROVIDERS: - # Non-antigravity: just show usage count - usage = self._get_usage_count(key, model, "success_count") - return f"usage: {usage}" - - # Antigravity: show quota display with remaining percentage - if self._usage_data is None: - return "quota: 0/? [100%]" - - # Normalize model name for consistent lookup (data is stored under normalized names) - model = self._normalize_model(key, model) - - key_data = self._usage_data.get(key, {}) - model_data = key_data.get("models", {}).get(model, {}) - - request_count = model_data.get("request_count", 0) - max_requests = model_data.get("quota_max_requests") - - if max_requests: - remaining = max_requests - request_count - remaining_pct = ( - int((remaining / max_requests) * 100) if max_requests > 0 else 0 - ) - return f"quota: {request_count}/{max_requests} [{remaining_pct}%]" - else: - return f"quota: {request_count}" - - def _get_usage_field_name(self, credential: str) -> str: - """ - Get the usage tracking field name for a credential. - - Returns the provider-specific field name if configured, - otherwise falls back to "daily". - - Args: - credential: The credential identifier - - Returns: - Field name string (e.g., "5h_window", "weekly", "daily") - """ - config = self._get_usage_reset_config(credential) - if config and "field_name" in config: - return config["field_name"] - - # Check provider default - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - - if plugin_instance and hasattr(plugin_instance, "get_default_usage_field_name"): - return plugin_instance.get_default_usage_field_name() - - return "daily" - - def _get_usage_count( - self, key: str, model: str, field: str = "success_count" - ) -> int: - """ - Get the current usage count for a model from the appropriate usage structure. - - Supports both: - - New per-model structure: {"models": {"model_name": {"success_count": N, ...}}} - - Legacy structure: {"daily": {"models": {"model_name": {"success_count": N, ...}}}} - - Args: - key: Credential identifier - model: Model name - field: The field to read for usage count (default: "success_count"). - Use "request_count" for providers where failed requests also - consume quota (e.g., antigravity). - - Returns: - Usage count for the model in the current window/period - """ - if self._usage_data is None: - return 0 - - # Normalize model name for consistent lookup (data is stored under normalized names) - model = self._normalize_model(key, model) - - key_data = self._usage_data.get(key, {}) - reset_mode = self._get_reset_mode(key) - - if reset_mode == "per_model": - # New per-model structure: key_data["models"][model][field] - return key_data.get("models", {}).get(model, {}).get(field, 0) - else: - # Legacy structure: key_data["daily"]["models"][model][field] - return ( - key_data.get("daily", {}).get("models", {}).get(model, {}).get(field, 0) - ) - - # ========================================================================= - # TIMESTAMP FORMATTING HELPERS - # ========================================================================= - - def _format_timestamp_local(self, ts: Optional[float]) -> Optional[str]: - """ - Format Unix timestamp as local time string with timezone offset. - - Args: - ts: Unix timestamp or None - - Returns: - Formatted string like "2025-12-07 14:30:17 +0100" or None - """ - if ts is None: - return None - try: - dt = datetime.fromtimestamp(ts).astimezone() # Local timezone - # Use UTC offset for conciseness (works on all platforms) - return dt.strftime("%Y-%m-%d %H:%M:%S %z") - except (OSError, ValueError, OverflowError): - return None - - def _add_readable_timestamps(self, data: Dict) -> Dict: - """ - Add human-readable timestamp fields to usage data before saving. - - Adds 'window_started' and 'quota_resets' fields derived from - Unix timestamps for easier debugging and monitoring. - - Args: - data: The usage data dict to enhance - - Returns: - The same dict with readable timestamp fields added - """ - for key, key_data in data.items(): - # Handle per-model structure - models = key_data.get("models", {}) - for model_name, model_stats in models.items(): - if not isinstance(model_stats, dict): - continue - - # Add readable window start time - window_start = model_stats.get("window_start_ts") - if window_start: - model_stats["window_started"] = self._format_timestamp_local( - window_start - ) - elif "window_started" in model_stats: - del model_stats["window_started"] - - # Add readable reset time - quota_reset = model_stats.get("quota_reset_ts") - if quota_reset: - model_stats["quota_resets"] = self._format_timestamp_local( - quota_reset - ) - elif "quota_resets" in model_stats: - del model_stats["quota_resets"] - - return data - - def _sort_sequential( - self, - candidates: List[Tuple[str, int]], - credential_priorities: Optional[Dict[str, int]] = None, - ) -> List[Tuple[str, int]]: - """ - Sort credentials for sequential mode with position retention. - - Credentials maintain their position based on established usage patterns, - ensuring that actively-used credentials remain primary until exhausted. - - Sorting order (within each sort key, lower value = higher priority): - 1. Priority tier (lower number = higher priority) - 2. Usage count (higher = more established in rotation, maintains position) - 3. Last used timestamp (higher = more recent, tiebreaker for stickiness) - 4. Credential ID (alphabetical, stable ordering) - - Args: - candidates: List of (credential_id, usage_count) tuples - credential_priorities: Optional dict mapping credentials to priority levels - - Returns: - Sorted list of candidates (same format as input) - """ - if not candidates: - return [] - - if len(candidates) == 1: - return candidates - - def sort_key(item: Tuple[str, int]) -> Tuple[int, int, float, str]: - cred, usage_count = item - priority = ( - credential_priorities.get(cred, 999) if credential_priorities else 999 - ) - last_used = ( - self._usage_data.get(cred, {}).get("last_used_ts", 0) - if self._usage_data - else 0 - ) - return ( - priority, # ASC: lower priority number = higher priority - -usage_count, # DESC: higher usage = more established - -last_used, # DESC: more recent = preferred for ties - cred, # ASC: stable alphabetical ordering - ) - - sorted_candidates = sorted(candidates, key=sort_key) - - # Debug logging - show top 3 credentials in ordering - if lib_logger.isEnabledFor(logging.DEBUG): - order_info = [ - f"{mask_credential(c)}(p={credential_priorities.get(c, 999) if credential_priorities else 'N/A'}, u={u})" - for c, u in sorted_candidates[:3] - ] - lib_logger.debug(f"Sequential ordering: {' → '.join(order_info)}") - - return sorted_candidates - - # ========================================================================= - # FAIR CYCLE PERSISTENCE - # ========================================================================= - - def _serialize_cycle_state(self) -> Dict[str, Any]: - """ - Serialize in-memory cycle state for JSON persistence. - - Converts sets to lists for JSON compatibility. - """ - result: Dict[str, Any] = {} - for provider, tier_data in self._cycle_exhausted.items(): - result[provider] = {} - for tier_key, tracking_data in tier_data.items(): - result[provider][tier_key] = {} - for tracking_key, cycle_data in tracking_data.items(): - result[provider][tier_key][tracking_key] = { - "cycle_started_at": cycle_data.get("cycle_started_at"), - "exhausted": list(cycle_data.get("exhausted", set())), - } - return result - - def _deserialize_cycle_state(self, data: Dict[str, Any]) -> None: - """ - Deserialize cycle state from JSON and populate in-memory structure. - - Converts lists back to sets and validates expired cycles. - """ - self._cycle_exhausted = {} - now_ts = time.time() - - for provider, tier_data in data.items(): - if not isinstance(tier_data, dict): - continue - self._cycle_exhausted[provider] = {} - - for tier_key, tracking_data in tier_data.items(): - if not isinstance(tracking_data, dict): - continue - self._cycle_exhausted[provider][tier_key] = {} - - for tracking_key, cycle_data in tracking_data.items(): - if not isinstance(cycle_data, dict): - continue - - cycle_started = cycle_data.get("cycle_started_at") - exhausted_list = cycle_data.get("exhausted", []) - - # Check if cycle has expired - if cycle_started is not None: - duration = self._get_fair_cycle_duration(provider) - if now_ts >= cycle_started + duration: - # Cycle expired - skip (don't restore) - lib_logger.debug( - f"Fair cycle expired for {provider}/{tier_key}/{tracking_key} - not restoring" - ) - continue - - # Restore valid cycle - self._cycle_exhausted[provider][tier_key][tracking_key] = { - "cycle_started_at": cycle_started, - "exhausted": set(exhausted_list) if exhausted_list else set(), - } - - # Log restoration summary - total_cycles = sum( - len(tracking) - for tier in self._cycle_exhausted.values() - for tracking in tier.values() - ) - if total_cycles > 0: - lib_logger.info(f"Restored {total_cycles} active fair cycle(s) from disk") - - async def _lazy_init(self): - """Initializes the usage data by loading it from the file asynchronously.""" - async with self._init_lock: - if not self._initialized.is_set(): - await self._load_usage() - await self._reset_daily_stats_if_needed() - self._initialized.set() - - async def _load_usage(self): - """Loads usage data from the JSON file asynchronously with resilience.""" - async with self._data_lock: - if not os.path.exists(self.file_path): - self._usage_data = {} - return - - try: - async with aiofiles.open(self.file_path, "r") as f: - content = await f.read() - self._usage_data = json.loads(content) if content.strip() else {} - except FileNotFoundError: - # File deleted between exists check and open - self._usage_data = {} - except json.JSONDecodeError as e: - lib_logger.warning( - f"Corrupted usage file {self.file_path}: {e}. Starting fresh." - ) - self._usage_data = {} - except (OSError, PermissionError, IOError) as e: - lib_logger.warning( - f"Cannot read usage file {self.file_path}: {e}. Using empty state." - ) - self._usage_data = {} - - # Restore fair cycle state from persisted data - fair_cycle_data = self._usage_data.get("__fair_cycle__", {}) - if fair_cycle_data: - self._deserialize_cycle_state(fair_cycle_data) - - async def _save_usage(self): - """Saves the current usage data using the resilient state writer.""" - if self._usage_data is None: - return - - async with self._data_lock: - # Add human-readable timestamp fields before saving - self._add_readable_timestamps(self._usage_data) - - # Persist fair cycle state (separate from credential data) - if self._cycle_exhausted: - self._usage_data["__fair_cycle__"] = self._serialize_cycle_state() - elif "__fair_cycle__" in self._usage_data: - # Clean up empty cycle data - del self._usage_data["__fair_cycle__"] - - # Hand off to resilient writer - handles retries and disk failures - self._state_writer.write(self._usage_data) - - async def _get_usage_data_snapshot(self) -> Dict[str, Any]: - """ - Get a shallow copy of the current usage data. - - Returns: - Copy of usage data dict (safe for reading without lock) - """ - await self._lazy_init() - async with self._data_lock: - return dict(self._usage_data) if self._usage_data else {} - - async def get_available_credentials_for_model( - self, credentials: List[str], model: str - ) -> List[str]: - """ - Get credentials that are not on cooldown for a specific model. - - Filters out credentials where: - - key_cooldown_until > now (key-level cooldown) - - model_cooldowns[model] > now (model-specific cooldown, includes quota exhausted) - - Args: - credentials: List of credential identifiers to check - model: Model name to check cooldowns for - - Returns: - List of credentials that are available (not on cooldown) for this model - """ - await self._lazy_init() - now = time.time() - available = [] - - async with self._data_lock: - for key in credentials: - key_data = self._usage_data.get(key, {}) - - # Skip if key-level cooldown is active - if (key_data.get("key_cooldown_until") or 0) > now: - continue - - # Normalize model name for consistent cooldown lookup - # (cooldowns are stored under normalized names by record_failure) - # For providers without normalize_model_for_tracking (non-Antigravity), - # this returns the model unchanged, so cooldown lookups work as before. - normalized_model = self._normalize_model(key, model) - - # Skip if model-specific cooldown is active - if ( - key_data.get("model_cooldowns", {}).get(normalized_model) or 0 - ) > now: - continue - - available.append(key) - - return available - - async def get_credential_availability_stats( - self, - credentials: List[str], - model: str, - credential_priorities: Optional[Dict[str, int]] = None, - ) -> Dict[str, int]: - """ - Get credential availability statistics including cooldown and fair cycle exclusions. - - This is used for logging to show why credentials are excluded. - - Args: - credentials: List of credential identifiers to check - model: Model name to check - credential_priorities: Optional dict mapping credentials to priorities - - Returns: - Dict with: - "total": Total credentials - "on_cooldown": Count on cooldown - "fair_cycle_excluded": Count excluded by fair cycle - "available": Count available for selection - """ - await self._lazy_init() - now = time.time() - - total = len(credentials) - on_cooldown = 0 - not_on_cooldown = [] - - # First pass: check cooldowns - async with self._data_lock: - for key in credentials: - key_data = self._usage_data.get(key, {}) - - # Check if key-level or model-level cooldown is active - normalized_model = self._normalize_model(key, model) - if (key_data.get("key_cooldown_until") or 0) > now or ( - key_data.get("model_cooldowns", {}).get(normalized_model) or 0 - ) > now: - on_cooldown += 1 - else: - not_on_cooldown.append(key) - - # Second pass: check fair cycle exclusions (only for non-cooldown credentials) - fair_cycle_excluded = 0 - if not_on_cooldown: - provider = self._get_provider_from_credential(not_on_cooldown[0]) - if provider: - rotation_mode = self._get_rotation_mode(provider) - if self._is_fair_cycle_enabled(provider, rotation_mode): - # Check each credential against its own tier's exhausted set - for key in not_on_cooldown: - key_priority = ( - credential_priorities.get(key, 999) - if credential_priorities - else 999 - ) - tier_key = self._get_tier_key(provider, key_priority) - tracking_key = self._get_tracking_key(key, model, provider) - - if self._is_credential_exhausted_in_cycle( - key, provider, tier_key, tracking_key - ): - fair_cycle_excluded += 1 - - available = total - on_cooldown - fair_cycle_excluded - - return { - "total": total, - "on_cooldown": on_cooldown, - "fair_cycle_excluded": fair_cycle_excluded, - "available": available, - } - - async def get_soonest_cooldown_end( - self, - credentials: List[str], - model: str, - ) -> Optional[float]: - """ - Find the soonest time when any credential will come off cooldown. - - This is used for smart waiting logic - if no credentials are available, - we can determine whether to wait (if soonest cooldown < deadline) or - fail fast (if soonest cooldown > deadline). - - Args: - credentials: List of credential identifiers to check - model: Model name to check cooldowns for - - Returns: - Timestamp of soonest cooldown end, or None if no credentials are on cooldown - """ - await self._lazy_init() - now = time.time() - soonest_end = None - - async with self._data_lock: - for key in credentials: - key_data = self._usage_data.get(key, {}) - normalized_model = self._normalize_model(key, model) - - # Check key-level cooldown - key_cooldown = key_data.get("key_cooldown_until") or 0 - if key_cooldown > now: - if soonest_end is None or key_cooldown < soonest_end: - soonest_end = key_cooldown - - # Check model-level cooldown - model_cooldown = ( - key_data.get("model_cooldowns", {}).get(normalized_model) or 0 - ) - if model_cooldown > now: - if soonest_end is None or model_cooldown < soonest_end: - soonest_end = model_cooldown - - return soonest_end - - async def _reset_daily_stats_if_needed(self): - """ - Checks if usage stats need to be reset for any key. - - Supports three reset modes: - 1. per_model: Each model has its own window, resets based on quota_reset_ts or fallback window - 2. credential: One window per credential (legacy with custom window duration) - 3. daily: Legacy daily reset at daily_reset_time_utc - """ - if self._usage_data is None: - return - - now_utc = datetime.now(timezone.utc) - now_ts = time.time() - today_str = now_utc.date().isoformat() - needs_saving = False - - for key, data in self._usage_data.items(): - reset_config = self._get_usage_reset_config(key) - - if reset_config: - reset_mode = reset_config.get("mode", "credential") - - if reset_mode == "per_model": - # Per-model window reset - needs_saving |= await self._check_per_model_resets( - key, data, reset_config, now_ts - ) - else: - # Credential-level window reset (legacy) - needs_saving |= await self._check_window_reset( - key, data, reset_config, now_ts - ) - elif self.daily_reset_time_utc: - # Legacy daily reset - needs_saving |= await self._check_daily_reset( - key, data, now_utc, today_str, now_ts - ) - - if needs_saving: - await self._save_usage() - - async def _check_per_model_resets( - self, - key: str, - data: Dict[str, Any], - reset_config: Dict[str, Any], - now_ts: float, - ) -> bool: - """ - Check and perform per-model resets for a credential. - - Each model resets independently based on: - 1. quota_reset_ts (authoritative, from quota exhausted error) if set - 2. window_start_ts + window_seconds (fallback) otherwise - - Grouped models reset together - all models in a group must be ready. - - Args: - key: Credential identifier - data: Usage data for this credential - reset_config: Provider's reset configuration - now_ts: Current timestamp - - Returns: - True if data was modified and needs saving - """ - window_seconds = reset_config.get("window_seconds", 86400) - models_data = data.get("models", {}) - - if not models_data: - return False - - modified = False - processed_groups = set() - - for model, model_data in list(models_data.items()): - # Check if this model is in a quota group - group = self._get_model_quota_group(key, model) - - if group: - if group in processed_groups: - continue # Already handled this group - - # Check if entire group should reset - if self._should_group_reset( - key, group, models_data, window_seconds, now_ts - ): - # Archive and reset all models in group - grouped_models = self._get_grouped_models(key, group) - archived_count = 0 - - for grouped_model in grouped_models: - if grouped_model in models_data: - gm_data = models_data[grouped_model] - self._archive_model_to_global(data, grouped_model, gm_data) - self._reset_model_data(gm_data) - archived_count += 1 - - if archived_count > 0: - lib_logger.info( - f"Reset model group '{group}' ({archived_count} models) for {mask_credential(key)}" - ) - modified = True - - processed_groups.add(group) - - else: - # Ungrouped model - check individually - if self._should_model_reset(model_data, window_seconds, now_ts): - self._archive_model_to_global(data, model, model_data) - self._reset_model_data(model_data) - lib_logger.info(f"Reset model {model} for {mask_credential(key)}") - modified = True - - # Preserve unexpired cooldowns - if modified: - self._preserve_unexpired_cooldowns(key, data, now_ts) - if "failures" in data: - data["failures"] = {} - - return modified - - def _should_model_reset( - self, model_data: Dict[str, Any], window_seconds: int, now_ts: float - ) -> bool: - """ - Check if a single model should reset. - - Returns True if: - - quota_reset_ts is set AND now >= quota_reset_ts, OR - - quota_reset_ts is NOT set AND now >= window_start_ts + window_seconds - """ - quota_reset = model_data.get("quota_reset_ts") - window_start = model_data.get("window_start_ts") - - if quota_reset: - return now_ts >= quota_reset - elif window_start: - return now_ts >= window_start + window_seconds - return False - - def _should_group_reset( - self, - key: str, - group: str, - models_data: Dict[str, Dict], - window_seconds: int, - now_ts: float, - ) -> bool: - """ - Check if all models in a group should reset. - - All models in the group must be ready to reset. - If any model has an active cooldown/window, the whole group waits. - """ - grouped_models = self._get_grouped_models(key, group) - - # Track if any model in group has data - any_has_data = False - - for grouped_model in grouped_models: - model_data = models_data.get(grouped_model, {}) - - if not model_data or ( - model_data.get("window_start_ts") is None - and model_data.get("success_count", 0) == 0 - ): - continue # No stats for this model yet - - any_has_data = True - - if not self._should_model_reset(model_data, window_seconds, now_ts): - return False # At least one model not ready - - return any_has_data - - def _archive_model_to_global( - self, data: Dict[str, Any], model: str, model_data: Dict[str, Any] - ) -> None: - """Archive a single model's stats to global.""" - global_data = data.setdefault("global", {"models": {}}) - global_model = global_data["models"].setdefault( - model, - { - "success_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - - global_model["success_count"] += model_data.get("success_count", 0) - global_model["prompt_tokens"] += model_data.get("prompt_tokens", 0) - global_model["prompt_tokens_cached"] = global_model.get( - "prompt_tokens_cached", 0 - ) + model_data.get("prompt_tokens_cached", 0) - global_model["completion_tokens"] += model_data.get("completion_tokens", 0) - global_model["approx_cost"] += model_data.get("approx_cost", 0.0) - - def _reset_model_data(self, model_data: Dict[str, Any]) -> None: - """Reset a model's window and stats.""" - model_data["window_start_ts"] = None - model_data["quota_reset_ts"] = None - model_data["success_count"] = 0 - model_data["failure_count"] = 0 - model_data["request_count"] = 0 - model_data["prompt_tokens"] = 0 - model_data["completion_tokens"] = 0 - model_data["approx_cost"] = 0.0 - # Reset quota baseline fields only if they exist (Antigravity-specific) - # These are added by update_quota_baseline(), only called for Antigravity - if "baseline_remaining_fraction" in model_data: - model_data["baseline_remaining_fraction"] = None - model_data["baseline_fetched_at"] = None - model_data["requests_at_baseline"] = None - # Reset quota display but keep max_requests (it doesn't change between periods) - max_req = model_data.get("quota_max_requests") - if max_req: - model_data["quota_display"] = f"0/{max_req}" - - async def _check_window_reset( - self, - key: str, - data: Dict[str, Any], - reset_config: Dict[str, Any], - now_ts: float, - ) -> bool: - """ - Check and perform rolling window reset for a credential. - - Args: - key: Credential identifier - data: Usage data for this credential - reset_config: Provider's reset configuration - now_ts: Current timestamp - - Returns: - True if data was modified and needs saving - """ - window_seconds = reset_config.get("window_seconds", 86400) # Default 24h - field_name = reset_config.get("field_name", "window") - description = reset_config.get("description", "rolling window") - - # Get current window data - window_data = data.get(field_name, {}) - window_start = window_data.get("start_ts") - - # No window started yet - nothing to reset - if window_start is None: - return False - - # Check if window has expired - window_end = window_start + window_seconds - if now_ts < window_end: - # Window still active - return False - - # Window expired - perform reset - hours_elapsed = (now_ts - window_start) / 3600 - lib_logger.info( - f"Resetting {field_name} for {mask_credential(key)} - " - f"{description} expired after {hours_elapsed:.1f}h" - ) - - # Archive to global - self._archive_to_global(data, window_data) - - # Preserve unexpired cooldowns - self._preserve_unexpired_cooldowns(key, data, now_ts) - - # Reset window stats (but don't start new window until first request) - data[field_name] = {"start_ts": None, "models": {}} - - # Reset consecutive failures - if "failures" in data: - data["failures"] = {} - - return True - - async def _check_daily_reset( - self, - key: str, - data: Dict[str, Any], - now_utc: datetime, - today_str: str, - now_ts: float, - ) -> bool: - """ - Check and perform legacy daily reset for a credential. - - Args: - key: Credential identifier - data: Usage data for this credential - now_utc: Current datetime in UTC - today_str: Today's date as ISO string - now_ts: Current timestamp - - Returns: - True if data was modified and needs saving - """ - last_reset_str = data.get("last_daily_reset", "") - - if last_reset_str == today_str: - return False - - last_reset_dt = None - if last_reset_str: - try: - last_reset_dt = datetime.fromisoformat(last_reset_str).replace( - tzinfo=timezone.utc - ) - except ValueError: - pass - - # Determine the reset threshold for today - reset_threshold_today = datetime.combine( - now_utc.date(), self.daily_reset_time_utc - ) - - if not ( - last_reset_dt is None or last_reset_dt < reset_threshold_today <= now_utc - ): - return False - - lib_logger.debug(f"Performing daily reset for key {mask_credential(key)}") - - # Preserve unexpired cooldowns - self._preserve_unexpired_cooldowns(key, data, now_ts) - - # Reset consecutive failures - if "failures" in data: - data["failures"] = {} - - # Archive daily stats to global - daily_data = data.get("daily", {}) - if daily_data: - self._archive_to_global(data, daily_data) - - # Reset daily stats - data["daily"] = {"date": today_str, "models": {}} - data["last_daily_reset"] = today_str - - return True - - def _archive_to_global( - self, data: Dict[str, Any], source_data: Dict[str, Any] - ) -> None: - """ - Archive usage stats from a source field (daily/window) to global. - - Args: - data: The credential's usage data - source_data: The source field data to archive (has "models" key) - """ - global_data = data.setdefault("global", {"models": {}}) - for model, stats in source_data.get("models", {}).items(): - global_model_stats = global_data["models"].setdefault( - model, - { - "success_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - global_model_stats["success_count"] += stats.get("success_count", 0) - global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0) - global_model_stats["prompt_tokens_cached"] = global_model_stats.get( - "prompt_tokens_cached", 0 - ) + stats.get("prompt_tokens_cached", 0) - global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0) - global_model_stats["approx_cost"] += stats.get("approx_cost", 0.0) - - def _preserve_unexpired_cooldowns( - self, key: str, data: Dict[str, Any], now_ts: float - ) -> None: - """ - Preserve unexpired cooldowns during reset (important for long quota cooldowns). - - Args: - key: Credential identifier (for logging) - data: The credential's usage data - now_ts: Current timestamp - """ - # Preserve unexpired model cooldowns - if "model_cooldowns" in data: - active_cooldowns = { - model: end_time - for model, end_time in data["model_cooldowns"].items() - if end_time > now_ts - } - if active_cooldowns: - max_remaining = max( - end_time - now_ts for end_time in active_cooldowns.values() - ) - hours_remaining = max_remaining / 3600 - lib_logger.info( - f"Preserving {len(active_cooldowns)} active cooldown(s) " - f"for key {mask_credential(key)} during reset " - f"(longest: {hours_remaining:.1f}h remaining)" - ) - data["model_cooldowns"] = active_cooldowns - else: - data["model_cooldowns"] = {} - - # Preserve unexpired key-level cooldown - if data.get("key_cooldown_until"): - if data["key_cooldown_until"] <= now_ts: - data["key_cooldown_until"] = None - else: - hours_remaining = (data["key_cooldown_until"] - now_ts) / 3600 - lib_logger.info( - f"Preserving key-level cooldown for {mask_credential(key)} " - f"during reset ({hours_remaining:.1f}h remaining)" - ) - else: - data["key_cooldown_until"] = None - - def _initialize_key_states(self, keys: List[str]): - """Initializes state tracking for all provided keys if not already present.""" - for key in keys: - if key not in self.key_states: - self.key_states[key] = { - "lock": asyncio.Lock(), - "condition": asyncio.Condition(), - "models_in_use": {}, # Dict[model_name, concurrent_count] - } - - def _select_weighted_random(self, candidates: List[tuple], tolerance: float) -> str: - """ - Selects a credential using weighted random selection based on usage counts. - - Args: - candidates: List of (credential_id, usage_count) tuples - tolerance: Tolerance value for weight calculation - - Returns: - Selected credential ID - - Formula: - weight = (max_usage - credential_usage) + tolerance + 1 - - This formula ensures: - - Lower usage = higher weight = higher selection probability - - Tolerance adds variability: higher tolerance means more randomness - - The +1 ensures all credentials have at least some chance of selection - """ - if not candidates: - raise ValueError("Cannot select from empty candidate list") - - if len(candidates) == 1: - return candidates[0][0] - - # Extract usage counts - usage_counts = [usage for _, usage in candidates] - max_usage = max(usage_counts) - - # Calculate weights using the formula: (max - current) + tolerance + 1 - weights = [] - for credential, usage in candidates: - weight = (max_usage - usage) + tolerance + 1 - weights.append(weight) - - # Log weight distribution for debugging - if lib_logger.isEnabledFor(logging.DEBUG): - total_weight = sum(weights) - weight_info = ", ".join( - f"{mask_credential(cred)}: w={w:.1f} ({w / total_weight * 100:.1f}%)" - for (cred, _), w in zip(candidates, weights) - ) - # lib_logger.debug(f"Weighted selection candidates: {weight_info}") - - # Random selection with weights - selected_credential = random.choices( - [cred for cred, _ in candidates], weights=weights, k=1 - )[0] - - return selected_credential - - async def acquire_key( - self, - available_keys: List[str], - model: str, - deadline: float, - max_concurrent: int = 1, - credential_priorities: Optional[Dict[str, int]] = None, - credential_tier_names: Optional[Dict[str, str]] = None, - all_provider_credentials: Optional[List[str]] = None, - ) -> str: - """ - Acquires the best available key using a tiered, model-aware locking strategy, - respecting a global deadline and credential priorities. - - Priority Logic: - - Groups credentials by priority level (1=highest, 2=lower, etc.) - - Always tries highest priority (lowest number) first - - Within same priority, sorts by usage count (load balancing) - - Only moves to next priority if all higher-priority keys exhausted/busy - - Args: - available_keys: List of credential identifiers to choose from - model: Model name being requested - deadline: Timestamp after which to stop trying - max_concurrent: Maximum concurrent requests allowed per credential - credential_priorities: Optional dict mapping credentials to priority levels (1=highest) - credential_tier_names: Optional dict mapping credentials to tier names (for logging) - all_provider_credentials: Full list of provider credentials (used for cycle reset checks) - - Returns: - Selected credential identifier - - Raises: - NoAvailableKeysError: If no key could be acquired within the deadline - """ - await self._lazy_init() - await self._reset_daily_stats_if_needed() - self._initialize_key_states(available_keys) - - # Normalize model name for consistent cooldown lookup - # (cooldowns are stored under normalized names by record_failure) - # Use first credential for provider detection; all credentials passed here - # are for the same provider (filtered by client.py before calling acquire_key). - # For providers without normalize_model_for_tracking (non-Antigravity), - # this returns the model unchanged, so cooldown lookups work as before. - normalized_model = ( - self._normalize_model(available_keys[0], model) if available_keys else model - ) - - # This loop continues as long as the global deadline has not been met. - while time.time() < deadline: - now = time.time() - - # Group credentials by priority level (if priorities provided) - if credential_priorities: - # Group keys by priority level - priority_groups = {} - async with self._data_lock: - for key in available_keys: - key_data = self._usage_data.get(key, {}) - - # Skip keys on cooldown (use normalized model for lookup) - if (key_data.get("key_cooldown_until") or 0) > now or ( - key_data.get("model_cooldowns", {}).get(normalized_model) - or 0 - ) > now: - continue - - # Get priority for this key (default to 999 if not specified) - priority = credential_priorities.get(key, 999) - - # Get usage count for load balancing within priority groups - # Uses grouped usage if model is in a quota group - usage_count = self._get_grouped_usage_count(key, model) - - # Group by priority - if priority not in priority_groups: - priority_groups[priority] = [] - priority_groups[priority].append((key, usage_count)) - - # Try priority groups in order (1, 2, 3, ...) - sorted_priorities = sorted(priority_groups.keys()) - - for priority_level in sorted_priorities: - keys_in_priority = priority_groups[priority_level] - - # Determine selection method based on provider's rotation mode - provider = model.split("/")[0] if "/" in model else "" - rotation_mode = self._get_rotation_mode(provider) - - # Fair cycle filtering - if provider and self._is_fair_cycle_enabled( - provider, rotation_mode - ): - tier_key = self._get_tier_key(provider, priority_level) - tracking_key = self._get_tracking_key( - keys_in_priority[0][0] if keys_in_priority else "", - model, - provider, - ) - - # Get all credentials for this tier (for cycle completion check) - all_tier_creds = self._get_all_credentials_for_tier_key( - provider, - tier_key, - all_provider_credentials or available_keys, - credential_priorities, - ) - - # Check if cycle should reset (all exhausted, expired, or none available) - if self._should_reset_cycle( - provider, - tier_key, - tracking_key, - all_tier_creds, - available_not_on_cooldown=[ - key for key, _ in keys_in_priority - ], - ): - self._reset_cycle(provider, tier_key, tracking_key) - - # Filter out exhausted credentials - filtered_keys = [] - for key, usage_count in keys_in_priority: - if not self._is_credential_exhausted_in_cycle( - key, provider, tier_key, tracking_key - ): - filtered_keys.append((key, usage_count)) - - keys_in_priority = filtered_keys - - # Calculate effective concurrency based on priority tier - multiplier = self._get_priority_multiplier( - provider, priority_level, rotation_mode - ) - effective_max_concurrent = max_concurrent * multiplier - - # Within each priority group, use existing tier1/tier2 logic - tier1_keys, tier2_keys = [], [] - for key, usage_count in keys_in_priority: - key_state = self.key_states[key] - - # Tier 1: Completely idle keys (preferred) - if not key_state["models_in_use"]: - tier1_keys.append((key, usage_count)) - # Tier 2: Keys that can accept more concurrent requests - elif ( - key_state["models_in_use"].get(model, 0) - < effective_max_concurrent - ): - tier2_keys.append((key, usage_count)) - - if rotation_mode == "sequential": - # Sequential mode: sort credentials by priority, usage, recency - # Keep all candidates in sorted order (no filtering to single key) - selection_method = "sequential" - if tier1_keys: - tier1_keys = self._sort_sequential( - tier1_keys, credential_priorities - ) - if tier2_keys: - tier2_keys = self._sort_sequential( - tier2_keys, credential_priorities - ) - elif self.rotation_tolerance > 0: - # Balanced mode with weighted randomness - selection_method = "weighted-random" - if tier1_keys: - selected_key = self._select_weighted_random( - tier1_keys, self.rotation_tolerance - ) - tier1_keys = [ - (k, u) for k, u in tier1_keys if k == selected_key - ] - if tier2_keys: - selected_key = self._select_weighted_random( - tier2_keys, self.rotation_tolerance - ) - tier2_keys = [ - (k, u) for k, u in tier2_keys if k == selected_key - ] - else: - # Deterministic: sort by usage within each tier - selection_method = "least-used" - tier1_keys.sort(key=lambda x: x[1]) - tier2_keys.sort(key=lambda x: x[1]) - - # Try to acquire from Tier 1 first - for key, usage in tier1_keys: - state = self.key_states[key] - async with state["lock"]: - if not state["models_in_use"]: - state["models_in_use"][model] = 1 - tier_name = ( - credential_tier_names.get(key, "unknown") - if credential_tier_names - else "unknown" - ) - quota_display = self._get_quota_display(key, model) - lib_logger.info( - f"Acquired key {mask_credential(key)} for model {model} " - f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, {quota_display})" - ) - return key - - # Then try Tier 2 - for key, usage in tier2_keys: - state = self.key_states[key] - async with state["lock"]: - current_count = state["models_in_use"].get(model, 0) - if current_count < effective_max_concurrent: - state["models_in_use"][model] = current_count + 1 - tier_name = ( - credential_tier_names.get(key, "unknown") - if credential_tier_names - else "unknown" - ) - quota_display = self._get_quota_display(key, model) - lib_logger.info( - f"Acquired key {mask_credential(key)} for model {model} " - f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})" - ) - return key - - # If we get here, all priority groups were exhausted but keys might become available - # Collect all keys across all priorities for waiting - all_potential_keys = [] - for keys_list in priority_groups.values(): - all_potential_keys.extend(keys_list) - - if not all_potential_keys: - # All credentials are on cooldown - check if waiting makes sense - soonest_end = await self.get_soonest_cooldown_end( - available_keys, model - ) - - if soonest_end is None: - # No cooldowns active but no keys available (shouldn't happen) - lib_logger.warning( - "No keys eligible and no cooldowns active. Re-evaluating..." - ) - await asyncio.sleep(1) - continue - - remaining_budget = deadline - time.time() - wait_needed = soonest_end - time.time() - - if wait_needed > remaining_budget: - # Fail fast - no credential will be available in time - lib_logger.warning( - f"All credentials on cooldown. Soonest available in {wait_needed:.1f}s, " - f"but only {remaining_budget:.1f}s budget remaining. Failing fast." - ) - break # Exit loop, will raise NoAvailableKeysError - - # Wait for the credential to become available - lib_logger.info( - f"All credentials on cooldown. Waiting {wait_needed:.1f}s for soonest credential..." - ) - await asyncio.sleep(min(wait_needed + 0.1, remaining_budget)) - continue - - # Wait for the highest priority key with lowest usage - best_priority = min(priority_groups.keys()) - best_priority_keys = priority_groups[best_priority] - best_wait_key = min(best_priority_keys, key=lambda x: x[1])[0] - wait_condition = self.key_states[best_wait_key]["condition"] - - lib_logger.info( - f"All Priority-{best_priority} keys are busy. Waiting for highest priority credential to become available..." - ) - - else: - # Original logic when no priorities specified - - # Determine selection method based on provider's rotation mode - provider = model.split("/")[0] if "/" in model else "" - rotation_mode = self._get_rotation_mode(provider) - - # Calculate effective concurrency for default priority (999) - # When no priorities are specified, all credentials get default priority - default_priority = 999 - multiplier = self._get_priority_multiplier( - provider, default_priority, rotation_mode - ) - effective_max_concurrent = max_concurrent * multiplier - - tier1_keys, tier2_keys = [], [] - - # First, filter the list of available keys to exclude any on cooldown. - async with self._data_lock: - for key in available_keys: - key_data = self._usage_data.get(key, {}) - - # Skip keys on cooldown (use normalized model for lookup) - if (key_data.get("key_cooldown_until") or 0) > now or ( - key_data.get("model_cooldowns", {}).get(normalized_model) - or 0 - ) > now: - continue - - # Prioritize keys based on their current usage to ensure load balancing. - # Uses grouped usage if model is in a quota group - usage_count = self._get_grouped_usage_count(key, model) - key_state = self.key_states[key] - - # Tier 1: Completely idle keys (preferred). - if not key_state["models_in_use"]: - tier1_keys.append((key, usage_count)) - # Tier 2: Keys that can accept more concurrent requests for this model. - elif ( - key_state["models_in_use"].get(model, 0) - < effective_max_concurrent - ): - tier2_keys.append((key, usage_count)) - - # Fair cycle filtering (non-priority case) - if provider and self._is_fair_cycle_enabled(provider, rotation_mode): - tier_key = self._get_tier_key(provider, default_priority) - tracking_key = self._get_tracking_key( - available_keys[0] if available_keys else "", - model, - provider, - ) - - # Get all credentials for this tier (for cycle completion check) - all_tier_creds = self._get_all_credentials_for_tier_key( - provider, - tier_key, - all_provider_credentials or available_keys, - None, - ) - - # Check if cycle should reset (all exhausted, expired, or none available) - if self._should_reset_cycle( - provider, - tier_key, - tracking_key, - all_tier_creds, - available_not_on_cooldown=[ - key for key, _ in (tier1_keys + tier2_keys) - ], - ): - self._reset_cycle(provider, tier_key, tracking_key) - - # Filter out exhausted credentials from both tiers - tier1_keys = [ - (key, usage) - for key, usage in tier1_keys - if not self._is_credential_exhausted_in_cycle( - key, provider, tier_key, tracking_key - ) - ] - tier2_keys = [ - (key, usage) - for key, usage in tier2_keys - if not self._is_credential_exhausted_in_cycle( - key, provider, tier_key, tracking_key - ) - ] - - if rotation_mode == "sequential": - # Sequential mode: sort credentials by priority, usage, recency - # Keep all candidates in sorted order (no filtering to single key) - selection_method = "sequential" - if tier1_keys: - tier1_keys = self._sort_sequential( - tier1_keys, credential_priorities - ) - if tier2_keys: - tier2_keys = self._sort_sequential( - tier2_keys, credential_priorities - ) - elif self.rotation_tolerance > 0: - # Balanced mode with weighted randomness - selection_method = "weighted-random" - if tier1_keys: - selected_key = self._select_weighted_random( - tier1_keys, self.rotation_tolerance - ) - tier1_keys = [ - (k, u) for k, u in tier1_keys if k == selected_key - ] - if tier2_keys: - selected_key = self._select_weighted_random( - tier2_keys, self.rotation_tolerance - ) - tier2_keys = [ - (k, u) for k, u in tier2_keys if k == selected_key - ] - else: - # Deterministic: sort by usage within each tier - selection_method = "least-used" - tier1_keys.sort(key=lambda x: x[1]) - tier2_keys.sort(key=lambda x: x[1]) - - # Attempt to acquire a key from Tier 1 first. - for key, usage in tier1_keys: - state = self.key_states[key] - async with state["lock"]: - if not state["models_in_use"]: - state["models_in_use"][model] = 1 - tier_name = ( - credential_tier_names.get(key) - if credential_tier_names - else None - ) - tier_info = f"tier: {tier_name}, " if tier_name else "" - quota_display = self._get_quota_display(key, model) - lib_logger.info( - f"Acquired key {mask_credential(key)} for model {model} " - f"({tier_info}selection: {selection_method}, {quota_display})" - ) - return key - - # If no Tier 1 keys are available, try Tier 2. - for key, usage in tier2_keys: - state = self.key_states[key] - async with state["lock"]: - current_count = state["models_in_use"].get(model, 0) - if current_count < effective_max_concurrent: - state["models_in_use"][model] = current_count + 1 - tier_name = ( - credential_tier_names.get(key) - if credential_tier_names - else None - ) - tier_info = f"tier: {tier_name}, " if tier_name else "" - quota_display = self._get_quota_display(key, model) - lib_logger.info( - f"Acquired key {mask_credential(key)} for model {model} " - f"({tier_info}selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})" - ) - return key - - # If all eligible keys are locked, wait for a key to be released. - lib_logger.info( - "All eligible keys are currently locked for this model. Waiting..." - ) - - all_potential_keys = tier1_keys + tier2_keys - if not all_potential_keys: - # All credentials are on cooldown - check if waiting makes sense - soonest_end = await self.get_soonest_cooldown_end( - available_keys, model - ) - - if soonest_end is None: - # No cooldowns active but no keys available (shouldn't happen) - lib_logger.warning( - "No keys eligible and no cooldowns active. Re-evaluating..." - ) - await asyncio.sleep(1) - continue - - remaining_budget = deadline - time.time() - wait_needed = soonest_end - time.time() - - if wait_needed > remaining_budget: - # Fail fast - no credential will be available in time - lib_logger.warning( - f"All credentials on cooldown. Soonest available in {wait_needed:.1f}s, " - f"but only {remaining_budget:.1f}s budget remaining. Failing fast." - ) - break # Exit loop, will raise NoAvailableKeysError - - # Wait for the credential to become available - lib_logger.info( - f"All credentials on cooldown. Waiting {wait_needed:.1f}s for soonest credential..." - ) - await asyncio.sleep(min(wait_needed + 0.1, remaining_budget)) - continue - - # Wait on the condition of the key with the lowest current usage. - best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0] - wait_condition = self.key_states[best_wait_key]["condition"] - - try: - async with wait_condition: - remaining_budget = deadline - time.time() - if remaining_budget <= 0: - break # Exit if the budget has already been exceeded. - # Wait for a notification, but no longer than the remaining budget or 1 second. - await asyncio.wait_for( - wait_condition.wait(), timeout=min(1, remaining_budget) - ) - lib_logger.info("Notified that a key was released. Re-evaluating...") - except asyncio.TimeoutError: - # This is not an error, just a timeout for the wait. The main loop will re-evaluate. - lib_logger.info("Wait timed out. Re-evaluating for any available key.") - - # If the loop exits, it means the deadline was exceeded. - raise NoAvailableKeysError( - f"Could not acquire a key for model {model} within the global time budget." - ) - - async def release_key(self, key: str, model: str): - """Releases a key's lock for a specific model and notifies waiting tasks.""" - if key not in self.key_states: - return - - state = self.key_states[key] - async with state["lock"]: - if model in state["models_in_use"]: - state["models_in_use"][model] -= 1 - remaining = state["models_in_use"][model] - if remaining <= 0: - del state["models_in_use"][model] # Clean up when count reaches 0 - lib_logger.info( - f"Released credential {mask_credential(key)} from model {model} " - f"(remaining concurrent: {max(0, remaining)})" - ) - else: - lib_logger.warning( - f"Attempted to release credential {mask_credential(key)} for model {model}, but it was not in use." - ) - - # Notify all tasks waiting on this key's condition - async with state["condition"]: - state["condition"].notify_all() - - async def record_success( - self, - key: str, - model: str, - completion_response: Optional[litellm.ModelResponse] = None, - ): - """ - Records a successful API call, resetting failure counters. - It safely handles cases where token usage data is not available. - - Supports two modes based on provider configuration: - - per_model: Each model has its own window_start_ts and stats in key_data["models"] - - credential: Legacy mode with key_data["daily"]["models"] - """ - await self._lazy_init() - - # Normalize model name to public-facing name for consistent tracking - model = self._normalize_model(key, model) - - async with self._data_lock: - now_ts = time.time() - today_utc_str = datetime.now(timezone.utc).date().isoformat() - - reset_config = self._get_usage_reset_config(key) - reset_mode = ( - reset_config.get("mode", "credential") if reset_config else "credential" - ) - - if reset_mode == "per_model": - # New per-model structure - key_data = self._usage_data.setdefault( - key, - { - "models": {}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - - # Ensure models dict exists - if "models" not in key_data: - key_data["models"] = {} - - # Get or create per-model data with window tracking - model_data = key_data["models"].setdefault( - model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - - # Start window on first request for this model - if model_data.get("window_start_ts") is None: - model_data["window_start_ts"] = now_ts - - # Set expected quota reset time from provider config - window_seconds = ( - reset_config.get("window_seconds", 0) if reset_config else 0 - ) - if window_seconds > 0: - model_data["quota_reset_ts"] = now_ts + window_seconds - - window_hours = window_seconds / 3600 if window_seconds else 0 - lib_logger.info( - f"Started {window_hours:.1f}h window for model {model} on {mask_credential(key)}" - ) - - # Record stats - model_data["success_count"] += 1 - model_data["request_count"] = model_data.get("request_count", 0) + 1 - - # Sync request_count across quota group (for providers with shared quota pools) - new_request_count = model_data["request_count"] - group = self._get_model_quota_group(key, model) - if group: - grouped_models = self._get_grouped_models(key, group) - for grouped_model in grouped_models: - if grouped_model != model: - other_model_data = key_data["models"].setdefault( - grouped_model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - other_model_data["request_count"] = new_request_count - # Sync window timing (shared quota pool = shared window) - window_start = model_data.get("window_start_ts") - if window_start: - other_model_data["window_start_ts"] = window_start - # Also sync quota_max_requests if set - max_req = model_data.get("quota_max_requests") - if max_req: - other_model_data["quota_max_requests"] = max_req - other_model_data["quota_display"] = ( - f"{new_request_count}/{max_req}" - ) - - # Update quota_display if max_requests is set (Antigravity-specific) - max_req = model_data.get("quota_max_requests") - if max_req: - model_data["quota_display"] = ( - f"{model_data['request_count']}/{max_req}" - ) - - # Check custom cap - if self._check_and_apply_custom_cap( - key, model, model_data["request_count"] - ): - # Custom cap exceeded, cooldown applied - # Continue to record tokens/cost but credential will be skipped next time - pass - - usage_data_ref = model_data # For token/cost recording below - - else: - # Legacy credential-level structure - key_data = self._usage_data.setdefault( - key, - { - "daily": {"date": today_utc_str, "models": {}}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - - if "last_daily_reset" not in key_data: - key_data["last_daily_reset"] = today_utc_str - - # Get or create model data in daily structure - usage_data_ref = key_data["daily"]["models"].setdefault( - model, - { - "success_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - usage_data_ref["success_count"] += 1 - - # Reset failures for this model - model_failures = key_data.setdefault("failures", {}).setdefault(model, {}) - model_failures["consecutive_failures"] = 0 - - # Clear transient cooldown on success (but NOT quota_reset_ts) - if model in key_data.get("model_cooldowns", {}): - del key_data["model_cooldowns"][model] - - # Record token and cost usage - if ( - completion_response - and hasattr(completion_response, "usage") - and completion_response.usage - ): - usage = completion_response.usage - prompt_total = usage.prompt_tokens - - # Extract cached tokens from prompt_tokens_details if present - cached_tokens = 0 - prompt_details = getattr(usage, "prompt_tokens_details", None) - if prompt_details: - if isinstance(prompt_details, dict): - cached_tokens = prompt_details.get("cached_tokens", 0) or 0 - elif hasattr(prompt_details, "cached_tokens"): - cached_tokens = prompt_details.cached_tokens or 0 - - # Store uncached tokens (prompt_tokens is total, subtract cached) - uncached_tokens = prompt_total - cached_tokens - usage_data_ref["prompt_tokens"] += uncached_tokens - - # Store cached tokens separately - if cached_tokens > 0: - usage_data_ref["prompt_tokens_cached"] = ( - usage_data_ref.get("prompt_tokens_cached", 0) + cached_tokens - ) - - usage_data_ref["completion_tokens"] += getattr( - usage, "completion_tokens", 0 - ) - lib_logger.info( - f"Recorded usage from response object for key {mask_credential(key)}" - ) - try: - provider_name = model.split("/")[0] - provider_instance = self._get_provider_instance(provider_name) - - if provider_instance and getattr( - provider_instance, "skip_cost_calculation", False - ): - lib_logger.debug( - f"Skipping cost calculation for provider '{provider_name}' (custom provider)." - ) - else: - if isinstance(completion_response, litellm.EmbeddingResponse): - model_info = litellm.get_model_info(model) - input_cost = model_info.get("input_cost_per_token") - if input_cost: - cost = ( - completion_response.usage.prompt_tokens * input_cost - ) - else: - cost = None - else: - cost = litellm.completion_cost( - completion_response=completion_response, model=model - ) - - if cost is not None: - usage_data_ref["approx_cost"] += cost - except Exception as e: - lib_logger.warning( - f"Could not calculate cost for model {model}: {e}" - ) - elif isinstance(completion_response, asyncio.Future) or hasattr( - completion_response, "__aiter__" - ): - pass # Stream - usage recorded from chunks - else: - lib_logger.warning( - f"No usage data found in completion response for model {model}. Recording success without token count." - ) - - key_data["last_used_ts"] = now_ts - - await self._save_usage() - - async def record_failure( - self, - key: str, - model: str, - classified_error: ClassifiedError, - increment_consecutive_failures: bool = True, - ): - """Records a failure and applies cooldowns based on error type. - - Distinguishes between: - - quota_exceeded: Long cooldown with exact reset time (from quota_reset_timestamp) - Sets quota_reset_ts on model (and group) - this becomes authoritative stats reset time - - rate_limit: Short transient cooldown (just wait and retry) - Only sets model_cooldowns - does NOT affect stats reset timing - - Args: - key: The API key or credential identifier - model: The model name - classified_error: The classified error object - increment_consecutive_failures: Whether to increment the failure counter. - Set to False for provider-level errors that shouldn't count against the key. - """ - await self._lazy_init() - - # Normalize model name to public-facing name for consistent tracking - model = self._normalize_model(key, model) - - async with self._data_lock: - now_ts = time.time() - today_utc_str = datetime.now(timezone.utc).date().isoformat() - - reset_config = self._get_usage_reset_config(key) - reset_mode = ( - reset_config.get("mode", "credential") if reset_config else "credential" - ) - - # Initialize key data with appropriate structure - if reset_mode == "per_model": - key_data = self._usage_data.setdefault( - key, - { - "models": {}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - else: - key_data = self._usage_data.setdefault( - key, - { - "daily": {"date": today_utc_str, "models": {}}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - - # Provider-level errors (transient issues) should not count against the key - provider_level_errors = {"server_error", "api_connection"} - - # Determine if we should increment the failure counter - should_increment = ( - increment_consecutive_failures - and classified_error.error_type not in provider_level_errors - ) - - # Calculate cooldown duration based on error type - cooldown_seconds = None - model_cooldowns = key_data.setdefault("model_cooldowns", {}) - - # Capture existing cooldown BEFORE we modify it - # Used to determine if this is a fresh exhaustion vs re-processing - existing_cooldown_before = model_cooldowns.get(model) - was_already_on_cooldown = ( - existing_cooldown_before is not None - and existing_cooldown_before > now_ts - ) - - if classified_error.error_type == "quota_exceeded": - # Quota exhausted - use authoritative reset timestamp if available - quota_reset_ts = classified_error.quota_reset_timestamp - cooldown_seconds = ( - classified_error.retry_after or COOLDOWN_RATE_LIMIT_DEFAULT - ) - - if quota_reset_ts and reset_mode == "per_model": - # Set quota_reset_ts on model - this becomes authoritative stats reset time - models_data = key_data.setdefault("models", {}) - model_data = models_data.setdefault( - model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - model_data["quota_reset_ts"] = quota_reset_ts - # Track failure for quota estimation (request still consumes quota) - model_data["failure_count"] = model_data.get("failure_count", 0) + 1 - model_data["request_count"] = model_data.get("request_count", 0) + 1 - - # Clamp request_count to quota_max_requests when quota is exhausted - # This prevents display overflow (e.g., 151/150) when requests are - # counted locally before API refresh corrects the value - max_req = model_data.get("quota_max_requests") - if max_req is not None and model_data["request_count"] > max_req: - model_data["request_count"] = max_req - # Update quota_display with clamped value - model_data["quota_display"] = f"{max_req}/{max_req}" - new_request_count = model_data["request_count"] - - # Apply to all models in the same quota group - group = self._get_model_quota_group(key, model) - if group: - grouped_models = self._get_grouped_models(key, group) - for grouped_model in grouped_models: - group_model_data = models_data.setdefault( - grouped_model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - group_model_data["quota_reset_ts"] = quota_reset_ts - # Sync request_count across quota group - group_model_data["request_count"] = new_request_count - # Also sync quota_max_requests if set - max_req = model_data.get("quota_max_requests") - if max_req: - group_model_data["quota_max_requests"] = max_req - group_model_data["quota_display"] = ( - f"{new_request_count}/{max_req}" - ) - # Also set transient cooldown for selection logic - model_cooldowns[grouped_model] = quota_reset_ts - - reset_dt = datetime.fromtimestamp( - quota_reset_ts, tz=timezone.utc - ) - lib_logger.info( - f"Quota exhausted for group '{group}' ({len(grouped_models)} models) " - f"on {mask_credential(key)}. Resets at {reset_dt.isoformat()}" - ) - else: - reset_dt = datetime.fromtimestamp( - quota_reset_ts, tz=timezone.utc - ) - hours = (quota_reset_ts - now_ts) / 3600 - lib_logger.info( - f"Quota exhausted for model {model} on {mask_credential(key)}. " - f"Resets at {reset_dt.isoformat()} ({hours:.1f}h)" - ) - - # Set transient cooldown for selection logic - model_cooldowns[model] = quota_reset_ts - else: - # No authoritative timestamp or legacy mode - just use retry_after - model_cooldowns[model] = now_ts + cooldown_seconds - hours = cooldown_seconds / 3600 - lib_logger.info( - f"Quota exhausted on {mask_credential(key)} for model {model}. " - f"Cooldown: {cooldown_seconds}s ({hours:.1f}h)" - ) - - # Mark credential as exhausted for fair cycle if cooldown exceeds threshold - # BUT only if this is a FRESH exhaustion (wasn't already on cooldown) - # This prevents re-marking after cycle reset - if not was_already_on_cooldown: - effective_cooldown = ( - (quota_reset_ts - now_ts) - if quota_reset_ts - else (cooldown_seconds or 0) - ) - provider = self._get_provider_from_credential(key) - if provider: - threshold = self._get_exhaustion_cooldown_threshold(provider) - if effective_cooldown > threshold: - rotation_mode = self._get_rotation_mode(provider) - if self._is_fair_cycle_enabled(provider, rotation_mode): - priority = self._get_credential_priority(key, provider) - tier_key = self._get_tier_key(provider, priority) - tracking_key = self._get_tracking_key( - key, model, provider - ) - self._mark_credential_exhausted( - key, provider, tier_key, tracking_key - ) - - elif classified_error.error_type == "rate_limit": - # Transient rate limit - just set short cooldown (does NOT set quota_reset_ts) - cooldown_seconds = ( - classified_error.retry_after or COOLDOWN_RATE_LIMIT_DEFAULT - ) - model_cooldowns[model] = now_ts + cooldown_seconds - lib_logger.info( - f"Rate limit on {mask_credential(key)} for model {model}. " - f"Transient cooldown: {cooldown_seconds}s" - ) - - elif classified_error.error_type == "authentication": - # Apply a 5-minute key-level lockout for auth errors - key_data["key_cooldown_until"] = now_ts + COOLDOWN_AUTH_ERROR - cooldown_seconds = COOLDOWN_AUTH_ERROR - model_cooldowns[model] = now_ts + cooldown_seconds - lib_logger.warning( - f"Authentication error on key {mask_credential(key)}. Applying 5-minute key-level lockout." - ) - - # If we should increment failures, calculate escalating backoff - if should_increment: - failures_data = key_data.setdefault("failures", {}) - model_failures = failures_data.setdefault( - model, {"consecutive_failures": 0} - ) - model_failures["consecutive_failures"] += 1 - count = model_failures["consecutive_failures"] - - # If cooldown wasn't set by specific error type, use escalating backoff - if cooldown_seconds is None: - cooldown_seconds = COOLDOWN_BACKOFF_TIERS.get( - count, COOLDOWN_BACKOFF_MAX - ) - model_cooldowns[model] = now_ts + cooldown_seconds - lib_logger.warning( - f"Failure #{count} for key {mask_credential(key)} with model {model}. " - f"Error type: {classified_error.error_type}, cooldown: {cooldown_seconds}s" - ) - else: - # Provider-level errors: apply short cooldown but don't count against key - if cooldown_seconds is None: - cooldown_seconds = COOLDOWN_TRANSIENT_ERROR - model_cooldowns[model] = now_ts + cooldown_seconds - lib_logger.info( - f"Provider-level error ({classified_error.error_type}) for key {mask_credential(key)} " - f"with model {model}. NOT incrementing failures. Cooldown: {cooldown_seconds}s" - ) - - # Check for key-level lockout condition - await self._check_key_lockout(key, key_data) - - # Track failure count for quota estimation (all failures consume quota) - # This is separate from consecutive_failures which is for backoff logic - if reset_mode == "per_model": - models_data = key_data.setdefault("models", {}) - model_data = models_data.setdefault( - model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - # Only increment if not already incremented in quota_exceeded branch - if classified_error.error_type != "quota_exceeded": - model_data["failure_count"] = model_data.get("failure_count", 0) + 1 - model_data["request_count"] = model_data.get("request_count", 0) + 1 - - # Sync request_count across quota group - new_request_count = model_data["request_count"] - group = self._get_model_quota_group(key, model) - if group: - grouped_models = self._get_grouped_models(key, group) - for grouped_model in grouped_models: - if grouped_model != model: - other_model_data = models_data.setdefault( - grouped_model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - other_model_data["request_count"] = new_request_count - # Also sync quota_max_requests if set - max_req = model_data.get("quota_max_requests") - if max_req: - other_model_data["quota_max_requests"] = max_req - other_model_data["quota_display"] = ( - f"{new_request_count}/{max_req}" - ) - - key_data["last_failure"] = { - "timestamp": now_ts, - "model": model, - "error": str(classified_error.original_exception), - } - - await self._save_usage() - - async def update_quota_baseline( - self, - credential: str, - model: str, - remaining_fraction: float, - max_requests: Optional[int] = None, - reset_timestamp: Optional[float] = None, - ) -> Optional[Dict[str, Any]]: - """ - Update quota baseline data for a credential/model after fetching from API. - - This stores the current quota state as a baseline, which is used to - estimate remaining quota based on subsequent request counts. - - When quota is exhausted (remaining_fraction <= 0.0) and a valid reset_timestamp - is provided, this also sets model_cooldowns to prevent wasted requests. - - Args: - credential: Credential identifier (file path or env:// URI) - model: Model name (with or without provider prefix) - remaining_fraction: Current remaining quota as fraction (0.0 to 1.0) - max_requests: Maximum requests allowed per quota period (e.g., 250 for Claude) - reset_timestamp: Unix timestamp when quota resets. Only trusted when - remaining_fraction < 1.0 (quota has been used). API returns garbage - reset times for unused quota (100%). - - Returns: - None if no cooldown was set/updated, otherwise: - { - "group_or_model": str, # quota group name or model name if ungrouped - "hours_until_reset": float, - } - """ - await self._lazy_init() - async with self._data_lock: - now_ts = time.time() - - # Get or create key data structure - key_data = self._usage_data.setdefault( - credential, - { - "models": {}, - "global": {"models": {}}, - "model_cooldowns": {}, - "failures": {}, - }, - ) - - # Ensure models dict exists - if "models" not in key_data: - key_data["models"] = {} - - # Get or create per-model data - model_data = key_data["models"].setdefault( - model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - "baseline_remaining_fraction": None, - "baseline_fetched_at": None, - "requests_at_baseline": None, - }, - ) - - # Calculate actual used requests from API's remaining fraction - # The API is authoritative - sync our local count to match reality - if max_requests is not None: - used_requests = int((1.0 - remaining_fraction) * max_requests) - else: - # Estimate max_requests from provider's quota cost - # This matches how get_max_requests_for_model() calculates it - provider = self._get_provider_from_credential(credential) - plugin_instance = self._get_provider_instance(provider) - if plugin_instance and hasattr( - plugin_instance, "get_max_requests_for_model" - ): - # Get tier from provider's cache - tier = getattr(plugin_instance, "project_tier_cache", {}).get( - credential, "standard-tier" - ) - # Strip provider prefix from model if present - clean_model = model.split("/")[-1] if "/" in model else model - max_requests = plugin_instance.get_max_requests_for_model( - clean_model, tier - ) - used_requests = int((1.0 - remaining_fraction) * max_requests) - else: - # Fallback: keep existing count if we can't calculate - used_requests = model_data.get("request_count", 0) - max_requests = model_data.get("quota_max_requests") - - # Sync local request count to API's authoritative value - # Use max() to prevent API from resetting our count if it returns stale/cached 100% - # The API can only increase our count (if we missed requests), not decrease it - # See: https://github.com/Mirrowel/LLM-API-Key-Proxy/issues/75 - current_count = model_data.get("request_count", 0) - synced_count = max(current_count, used_requests) - model_data["request_count"] = synced_count - model_data["requests_at_baseline"] = synced_count - - # Update baseline fields - model_data["baseline_remaining_fraction"] = remaining_fraction - model_data["baseline_fetched_at"] = now_ts - - # Update max_requests and quota_display - if max_requests is not None: - model_data["quota_max_requests"] = max_requests - model_data["quota_display"] = f"{synced_count}/{max_requests}" - - # Handle reset_timestamp: only trust it when quota has been used (< 100%) - # API returns garbage reset times for unused quota - valid_reset_ts = ( - reset_timestamp is not None - and remaining_fraction < 1.0 - and reset_timestamp > now_ts - ) - - if valid_reset_ts: - model_data["quota_reset_ts"] = reset_timestamp - - # Set cooldowns when quota is exhausted - model_cooldowns = key_data.setdefault("model_cooldowns", {}) - is_exhausted = remaining_fraction <= 0.0 - cooldown_set_info = ( - None # Will be returned if cooldown was newly set/updated - ) - - if is_exhausted and valid_reset_ts: - # Check if there was an existing ACTIVE cooldown before we update - # This distinguishes between fresh exhaustion vs refresh of existing state - existing_cooldown = model_cooldowns.get(model) - was_already_on_cooldown = ( - existing_cooldown is not None and existing_cooldown > now_ts - ) - - # Only update cooldown if not set or differs by more than 5 minutes - should_update = ( - existing_cooldown is None - or abs(existing_cooldown - reset_timestamp) > 300 - ) - if should_update: - model_cooldowns[model] = reset_timestamp - hours_until_reset = (reset_timestamp - now_ts) / 3600 - # Determine group or model name for logging - group = self._get_model_quota_group(credential, model) - cooldown_set_info = { - "group_or_model": group if group else model.split("/")[-1], - "hours_until_reset": hours_until_reset, - } - - # Mark credential as exhausted in fair cycle if cooldown exceeds threshold - # BUT only if this is a FRESH exhaustion (wasn't already on cooldown) - # This prevents re-marking after cycle reset when quota refresh sees existing cooldown - if not was_already_on_cooldown: - cooldown_duration = reset_timestamp - now_ts - provider = self._get_provider_from_credential(credential) - if provider: - threshold = self._get_exhaustion_cooldown_threshold(provider) - if cooldown_duration > threshold: - rotation_mode = self._get_rotation_mode(provider) - if self._is_fair_cycle_enabled(provider, rotation_mode): - priority = self._get_credential_priority( - credential, provider - ) - tier_key = self._get_tier_key(provider, priority) - tracking_key = self._get_tracking_key( - credential, model, provider - ) - self._mark_credential_exhausted( - credential, provider, tier_key, tracking_key - ) - - # Defensive clamp: ensure request_count doesn't exceed max when exhausted - if ( - max_requests is not None - and model_data["request_count"] > max_requests - ): - model_data["request_count"] = max_requests - model_data["quota_display"] = f"{max_requests}/{max_requests}" - - # Sync baseline fields and quota info across quota group - group = self._get_model_quota_group(credential, model) - if group: - grouped_models = self._get_grouped_models(credential, group) - for grouped_model in grouped_models: - if grouped_model != model: - other_model_data = key_data["models"].setdefault( - grouped_model, - { - "window_start_ts": None, - "quota_reset_ts": None, - "success_count": 0, - "failure_count": 0, - "request_count": 0, - "prompt_tokens": 0, - "prompt_tokens_cached": 0, - "completion_tokens": 0, - "approx_cost": 0.0, - }, - ) - # Sync request tracking (use synced_count to prevent reset bug) - other_model_data["request_count"] = synced_count - if max_requests is not None: - other_model_data["quota_max_requests"] = max_requests - other_model_data["quota_display"] = ( - f"{synced_count}/{max_requests}" - ) - # Sync baseline fields - other_model_data["baseline_remaining_fraction"] = ( - remaining_fraction - ) - other_model_data["baseline_fetched_at"] = now_ts - other_model_data["requests_at_baseline"] = synced_count - # Sync reset timestamp if valid - if valid_reset_ts: - other_model_data["quota_reset_ts"] = reset_timestamp - # Sync window start time - window_start = model_data.get("window_start_ts") - if window_start: - other_model_data["window_start_ts"] = window_start - # Sync cooldown if exhausted (with Âą5 min check) - if is_exhausted and valid_reset_ts: - existing_grouped = model_cooldowns.get(grouped_model) - should_update_grouped = ( - existing_grouped is None - or abs(existing_grouped - reset_timestamp) > 300 - ) - if should_update_grouped: - model_cooldowns[grouped_model] = reset_timestamp - - # Defensive clamp for grouped models when exhausted - if ( - max_requests is not None - and other_model_data["request_count"] > max_requests - ): - other_model_data["request_count"] = max_requests - other_model_data["quota_display"] = ( - f"{max_requests}/{max_requests}" - ) - - lib_logger.debug( - f"Updated quota baseline for {mask_credential(credential)} model={model}: " - f"remaining={remaining_fraction:.2%}, synced_request_count={synced_count}" - ) - - await self._save_usage() - return cooldown_set_info - - async def _check_key_lockout(self, key: str, key_data: Dict): - """ - Checks if a key should be locked out due to multiple model failures. - - NOTE: This check is currently disabled. The original logic counted individual - models in long-term lockout, but this caused issues with quota groups - when - a single quota group (e.g., "claude" with 5 models) was exhausted, it would - count as 5 lockouts and trigger key-level lockout, blocking other quota groups - (like gemini) that were still available. - - The per-model and per-group cooldowns already handle quota exhaustion properly. - """ - # Disabled - see docstring above - pass - - async def get_stats_for_endpoint( - self, - provider_filter: Optional[str] = None, - include_global: bool = True, - ) -> Dict[str, Any]: - """ - Get usage stats formatted for the /v1/quota-stats endpoint. - - Aggregates data from key_usage.json grouped by provider. - Includes both current period stats and global (lifetime) stats. - - Args: - provider_filter: If provided, only return stats for this provider - include_global: If True, include global/lifetime stats alongside current - - Returns: - { - "providers": { - "provider_name": { - "credential_count": int, - "active_count": int, - "on_cooldown_count": int, - "total_requests": int, - "tokens": { - "input_cached": int, - "input_uncached": int, - "input_cache_pct": float, - "output": int - }, - "approx_cost": float | None, - "credentials": [...], - "global": {...} # If include_global is True - } - }, - "summary": {...}, - "global_summary": {...}, # If include_global is True - "timestamp": float - } - """ - await self._lazy_init() - - now_ts = time.time() - providers: Dict[str, Dict[str, Any]] = {} - # Track global stats separately - global_providers: Dict[str, Dict[str, Any]] = {} - - async with self._data_lock: - if not self._usage_data: - return { - "providers": {}, - "summary": { - "total_providers": 0, - "total_credentials": 0, - "active_credentials": 0, - "exhausted_credentials": 0, - "total_requests": 0, - "tokens": { - "input_cached": 0, - "input_uncached": 0, - "input_cache_pct": 0, - "output": 0, - }, - "approx_total_cost": 0.0, - }, - "global_summary": { - "total_providers": 0, - "total_credentials": 0, - "total_requests": 0, - "tokens": { - "input_cached": 0, - "input_uncached": 0, - "input_cache_pct": 0, - "output": 0, - }, - "approx_total_cost": 0.0, - }, - "data_source": "cache", - "timestamp": now_ts, - } - - for credential, cred_data in self._usage_data.items(): - # Extract provider from credential path - provider = self._get_provider_from_credential(credential) - if not provider: - continue - - # Apply filter if specified - if provider_filter and provider != provider_filter: - continue - - # Initialize provider entry - if provider not in providers: - providers[provider] = { - "credential_count": 0, - "active_count": 0, - "on_cooldown_count": 0, - "exhausted_count": 0, - "total_requests": 0, - "tokens": { - "input_cached": 0, - "input_uncached": 0, - "input_cache_pct": 0, - "output": 0, - }, - "approx_cost": 0.0, - "credentials": [], - } - global_providers[provider] = { - "total_requests": 0, - "tokens": { - "input_cached": 0, - "input_uncached": 0, - "input_cache_pct": 0, - "output": 0, - }, - "approx_cost": 0.0, - } - - prov_stats = providers[provider] - prov_stats["credential_count"] += 1 - - # Determine credential status and cooldowns - key_cooldown = cred_data.get("key_cooldown_until", 0) or 0 - model_cooldowns = cred_data.get("model_cooldowns", {}) - - # Build active cooldowns with remaining time - active_cooldowns = {} - for model, cooldown_ts in model_cooldowns.items(): - if cooldown_ts > now_ts: - remaining_seconds = int(cooldown_ts - now_ts) - active_cooldowns[model] = { - "until_ts": cooldown_ts, - "remaining_seconds": remaining_seconds, - } - - key_cooldown_remaining = None - if key_cooldown > now_ts: - key_cooldown_remaining = int(key_cooldown - now_ts) - - has_active_cooldown = key_cooldown > now_ts or len(active_cooldowns) > 0 - - # Check if exhausted (all quota groups exhausted for Antigravity) - is_exhausted = False - models_data = cred_data.get("models", {}) - if models_data: - # Check if any model has remaining quota - all_exhausted = True - for model_stats in models_data.values(): - if isinstance(model_stats, dict): - baseline = model_stats.get("baseline_remaining_fraction") - if baseline is None or baseline > 0: - all_exhausted = False - break - if all_exhausted and len(models_data) > 0: - is_exhausted = True - - if is_exhausted: - prov_stats["exhausted_count"] += 1 - status = "exhausted" - elif has_active_cooldown: - prov_stats["on_cooldown_count"] += 1 - status = "cooldown" - else: - prov_stats["active_count"] += 1 - status = "active" - - # Aggregate token stats (current period) - cred_tokens = { - "input_cached": 0, - "input_uncached": 0, - "output": 0, - } - cred_requests = 0 - cred_cost = 0.0 - - # Aggregate global token stats - cred_global_tokens = { - "input_cached": 0, - "input_uncached": 0, - "output": 0, - } - cred_global_requests = 0 - cred_global_cost = 0.0 - - # Handle per-model structure (current period) - if models_data: - for model_name, model_stats in models_data.items(): - if not isinstance(model_stats, dict): - continue - # Prefer request_count if available and non-zero, else fall back to success+failure - req_count = model_stats.get("request_count", 0) - if req_count > 0: - cred_requests += req_count - else: - cred_requests += model_stats.get("success_count", 0) - cred_requests += model_stats.get("failure_count", 0) - # Token stats - track cached separately - cred_tokens["input_cached"] += model_stats.get( - "prompt_tokens_cached", 0 - ) - cred_tokens["input_uncached"] += model_stats.get( - "prompt_tokens", 0 - ) - cred_tokens["output"] += model_stats.get("completion_tokens", 0) - cred_cost += model_stats.get("approx_cost", 0.0) - - # Handle legacy daily structure - daily_data = cred_data.get("daily", {}) - daily_models = daily_data.get("models", {}) - for model_name, model_stats in daily_models.items(): - if not isinstance(model_stats, dict): - continue - cred_requests += model_stats.get("success_count", 0) - cred_tokens["input_cached"] += model_stats.get( - "prompt_tokens_cached", 0 - ) - cred_tokens["input_uncached"] += model_stats.get("prompt_tokens", 0) - cred_tokens["output"] += model_stats.get("completion_tokens", 0) - cred_cost += model_stats.get("approx_cost", 0.0) - - # Handle global stats - global_data = cred_data.get("global", {}) - global_models = global_data.get("models", {}) - for model_name, model_stats in global_models.items(): - if not isinstance(model_stats, dict): - continue - cred_global_requests += model_stats.get("success_count", 0) - cred_global_tokens["input_cached"] += model_stats.get( - "prompt_tokens_cached", 0 - ) - cred_global_tokens["input_uncached"] += model_stats.get( - "prompt_tokens", 0 - ) - cred_global_tokens["output"] += model_stats.get( - "completion_tokens", 0 - ) - cred_global_cost += model_stats.get("approx_cost", 0.0) - - # Add current period stats to global totals - cred_global_requests += cred_requests - cred_global_tokens["input_cached"] += cred_tokens["input_cached"] - cred_global_tokens["input_uncached"] += cred_tokens["input_uncached"] - cred_global_tokens["output"] += cred_tokens["output"] - cred_global_cost += cred_cost - - # Build credential entry - # Mask credential identifier for display - if credential.startswith("env://"): - identifier = credential - else: - identifier = Path(credential).name - - cred_entry = { - "identifier": identifier, - "full_path": credential, - "status": status, - "last_used_ts": cred_data.get("last_used_ts"), - "requests": cred_requests, - "tokens": cred_tokens, - "approx_cost": cred_cost if cred_cost > 0 else None, - } - - # Add cooldown info - if key_cooldown_remaining is not None: - cred_entry["key_cooldown_remaining"] = key_cooldown_remaining - if active_cooldowns: - cred_entry["model_cooldowns"] = active_cooldowns - - # Add global stats for this credential - if include_global: - # Calculate global cache percentage - global_total_input = ( - cred_global_tokens["input_cached"] - + cred_global_tokens["input_uncached"] - ) - global_cache_pct = ( - round( - cred_global_tokens["input_cached"] - / global_total_input - * 100, - 1, - ) - if global_total_input > 0 - else 0 - ) - - cred_entry["global"] = { - "requests": cred_global_requests, - "tokens": { - "input_cached": cred_global_tokens["input_cached"], - "input_uncached": cred_global_tokens["input_uncached"], - "input_cache_pct": global_cache_pct, - "output": cred_global_tokens["output"], - }, - "approx_cost": cred_global_cost - if cred_global_cost > 0 - else None, - } - - # Add model-specific data for providers with per-model tracking - if models_data: - cred_entry["models"] = {} - for model_name, model_stats in models_data.items(): - if not isinstance(model_stats, dict): - continue - cred_entry["models"][model_name] = { - "requests": model_stats.get("success_count", 0) - + model_stats.get("failure_count", 0), - "request_count": model_stats.get("request_count", 0), - "success_count": model_stats.get("success_count", 0), - "failure_count": model_stats.get("failure_count", 0), - "prompt_tokens": model_stats.get("prompt_tokens", 0), - "prompt_tokens_cached": model_stats.get( - "prompt_tokens_cached", 0 - ), - "completion_tokens": model_stats.get( - "completion_tokens", 0 - ), - "approx_cost": model_stats.get("approx_cost", 0.0), - "window_start_ts": model_stats.get("window_start_ts"), - "quota_reset_ts": model_stats.get("quota_reset_ts"), - # Quota baseline fields (Antigravity-specific) - "baseline_remaining_fraction": model_stats.get( - "baseline_remaining_fraction" - ), - "baseline_fetched_at": model_stats.get( - "baseline_fetched_at" - ), - "quota_max_requests": model_stats.get("quota_max_requests"), - "quota_display": model_stats.get("quota_display"), - } - - prov_stats["credentials"].append(cred_entry) - - # Aggregate to provider totals (current period) - prov_stats["total_requests"] += cred_requests - prov_stats["tokens"]["input_cached"] += cred_tokens["input_cached"] - prov_stats["tokens"]["input_uncached"] += cred_tokens["input_uncached"] - prov_stats["tokens"]["output"] += cred_tokens["output"] - if cred_cost > 0: - prov_stats["approx_cost"] += cred_cost - - # Aggregate to global provider totals - global_providers[provider]["total_requests"] += cred_global_requests - global_providers[provider]["tokens"]["input_cached"] += ( - cred_global_tokens["input_cached"] - ) - global_providers[provider]["tokens"]["input_uncached"] += ( - cred_global_tokens["input_uncached"] - ) - global_providers[provider]["tokens"]["output"] += cred_global_tokens[ - "output" - ] - global_providers[provider]["approx_cost"] += cred_global_cost - - # Calculate cache percentages for each provider - for provider, prov_stats in providers.items(): - total_input = ( - prov_stats["tokens"]["input_cached"] - + prov_stats["tokens"]["input_uncached"] - ) - if total_input > 0: - prov_stats["tokens"]["input_cache_pct"] = round( - prov_stats["tokens"]["input_cached"] / total_input * 100, 1 - ) - # Set cost to None if 0 - if prov_stats["approx_cost"] == 0: - prov_stats["approx_cost"] = None - - # Calculate global cache percentages - if include_global and provider in global_providers: - gp = global_providers[provider] - global_total = ( - gp["tokens"]["input_cached"] + gp["tokens"]["input_uncached"] - ) - if global_total > 0: - gp["tokens"]["input_cache_pct"] = round( - gp["tokens"]["input_cached"] / global_total * 100, 1 - ) - if gp["approx_cost"] == 0: - gp["approx_cost"] = None - prov_stats["global"] = gp - - # Build summary (current period) - total_creds = sum(p["credential_count"] for p in providers.values()) - active_creds = sum(p["active_count"] for p in providers.values()) - exhausted_creds = sum(p["exhausted_count"] for p in providers.values()) - total_requests = sum(p["total_requests"] for p in providers.values()) - total_input_cached = sum( - p["tokens"]["input_cached"] for p in providers.values() - ) - total_input_uncached = sum( - p["tokens"]["input_uncached"] for p in providers.values() - ) - total_output = sum(p["tokens"]["output"] for p in providers.values()) - total_cost = sum(p["approx_cost"] or 0 for p in providers.values()) - - total_input = total_input_cached + total_input_uncached - input_cache_pct = ( - round(total_input_cached / total_input * 100, 1) if total_input > 0 else 0 - ) - - result = { - "providers": providers, - "summary": { - "total_providers": len(providers), - "total_credentials": total_creds, - "active_credentials": active_creds, - "exhausted_credentials": exhausted_creds, - "total_requests": total_requests, - "tokens": { - "input_cached": total_input_cached, - "input_uncached": total_input_uncached, - "input_cache_pct": input_cache_pct, - "output": total_output, - }, - "approx_total_cost": total_cost if total_cost > 0 else None, - }, - "data_source": "cache", - "timestamp": now_ts, - } - - # Build global summary - if include_global: - global_total_requests = sum( - gp["total_requests"] for gp in global_providers.values() - ) - global_total_input_cached = sum( - gp["tokens"]["input_cached"] for gp in global_providers.values() - ) - global_total_input_uncached = sum( - gp["tokens"]["input_uncached"] for gp in global_providers.values() - ) - global_total_output = sum( - gp["tokens"]["output"] for gp in global_providers.values() - ) - global_total_cost = sum( - gp["approx_cost"] or 0 for gp in global_providers.values() - ) - - global_total_input = global_total_input_cached + global_total_input_uncached - global_input_cache_pct = ( - round(global_total_input_cached / global_total_input * 100, 1) - if global_total_input > 0 - else 0 - ) - - result["global_summary"] = { - "total_providers": len(global_providers), - "total_credentials": total_creds, - "total_requests": global_total_requests, - "tokens": { - "input_cached": global_total_input_cached, - "input_uncached": global_total_input_uncached, - "input_cache_pct": global_input_cache_pct, - "output": global_total_output, - }, - "approx_total_cost": global_total_cost - if global_total_cost > 0 - else None, - } - - return result - - async def reload_from_disk(self) -> None: - """ - Force reload usage data from disk. - - Useful when another process may have updated the file. - """ - async with self._init_lock: - self._initialized.clear() - await self._load_usage() - await self._reset_daily_stats_if_needed() - self._initialized.set() +__all__ = ["UsageManager", "CredentialContext"] diff --git a/src/rotator_library/utils/__init__.py b/src/rotator_library/utils/__init__.py index ce8d959d..a51d1db7 100644 --- a/src/rotator_library/utils/__init__.py +++ b/src/rotator_library/utils/__init__.py @@ -17,6 +17,7 @@ ResilientStateWriter, safe_write_json, safe_log_write, + safe_read_json, safe_mkdir, ) from .suppress_litellm_warnings import suppress_litellm_serialization_warnings @@ -34,6 +35,7 @@ "ResilientStateWriter", "safe_write_json", "safe_log_write", + "safe_read_json", "safe_mkdir", "suppress_litellm_serialization_warnings", ] diff --git a/src/rotator_library/utils/resilient_io.py b/src/rotator_library/utils/resilient_io.py index 1125e3b7..91e96f37 100644 --- a/src/rotator_library/utils/resilient_io.py +++ b/src/rotator_library/utils/resilient_io.py @@ -660,6 +660,36 @@ def safe_log_write( return False +def safe_read_json( + path: Union[str, Path], + logger: logging.Logger, + *, + parse_json: bool = True, +) -> Optional[Any]: + """ + Read file contents with error handling. + + Args: + path: File path to read from + logger: Logger for warnings/errors + parse_json: When True, parse JSON; when False, return raw text + + Returns: + Parsed JSON dict, raw text, or None on failure + """ + path = Path(path) + try: + with open(path, "r", encoding="utf-8") as f: + if parse_json: + return json.load(f) + return f.read() + except FileNotFoundError: + return None + except (OSError, PermissionError, IOError, json.JSONDecodeError) as e: + logger.error(f"Failed to read {path}: {e}") + return None + + def safe_mkdir(path: Union[str, Path], logger: logging.Logger) -> bool: """ Create directory with error handling.