From fa2e98749f13327ce94069b203b58a3b67860e9f Mon Sep 17 00:00:00 2001 From: spoons-and-mirrors <212802214+spoons-and-mirrors@users.noreply.github.com> Date: Wed, 21 Jan 2026 06:42:47 +0100 Subject: [PATCH] feat(rotation): manual account selection --- src/proxy_app/main.py | 61 ++++++++++++++++++ src/proxy_app/quota_viewer.py | 81 ++++++++++++++++++++++++ src/rotator_library/usage_manager.py | 93 ++++++++++++++++++++++++++++ 3 files changed, 235 insertions(+) diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py index 4d8dba99..216f0c82 100644 --- a/src/proxy_app/main.py +++ b/src/proxy_app/main.py @@ -1486,6 +1486,67 @@ async def refresh_quota_stats( raise HTTPException(status_code=500, detail=str(e)) +@app.post("/v1/force-credential") +async def force_credential( + request: Request, + client: RotatingClient = Depends(get_rotating_client), + _=Depends(verify_api_key), +): + """ + Force the proxy to use a specific credential for all requests. + + This overrides the normal rotation logic and always selects the specified + credential (if available and not on cooldown). Used by the TUI for manual + credential selection. + + Request body: + { + "credential": "path/to/credential.json" | null, + "provider": "antigravity" // optional, for display/logging only + } + + Set credential to null to clear the override and resume normal rotation. + + Returns: + { + "status": "forced" | "cleared", + "credential": "path/to/credential.json" | null, + "message": "..." + } + """ + try: + data = await request.json() + credential = data.get("credential") + provider = data.get("provider", "unknown") + + # Set or clear forced credential + await client.usage_manager.set_forced_credential(credential) + + if credential: + # Extract a friendly display name from the credential path + if "/" in credential or "\\" in credential: + display_name = credential.split("/")[-1].split("\\")[-1] + else: + display_name = credential + + return { + "status": "forced", + "credential": credential, + "provider": provider, + "message": f"Now forcing credential: {display_name}" + } + else: + return { + "status": "cleared", + "credential": None, + "message": "Forced credential cleared. Resuming normal rotation." + } + + except Exception as e: + logging.error(f"Failed to set forced credential: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + @app.post("/v1/token-count") async def token_count( request: Request, diff --git a/src/proxy_app/quota_viewer.py b/src/proxy_app/quota_viewer.py index f4bf59b2..c4300bd9 100644 --- a/src/proxy_app/quota_viewer.py +++ b/src/proxy_app/quota_viewer.py @@ -912,6 +912,9 @@ def show_provider_detail_screen(self, provider: str): self.console.print(" G. Toggle view mode (current/global)") self.console.print(" R. Reload stats (from proxy cache)") self.console.print(" RA. Reload all stats") + self.console.print() + self.console.print(" [cyan]1-9. Force use of credential [N] (locks rotation)[/cyan]") + self.console.print(" C. Clear forced credential (resume normal rotation)") # Force refresh options (only for providers that support it) has_quota_groups = bool( @@ -964,6 +967,84 @@ def show_provider_detail_screen(self, provider: str): "[bold]Reloading all stats...", spinner="dots" ): self.post_action("reload", scope="all") + elif choice.isdigit() and 1 <= int(choice) <= 9: + # Handle numeric selection (force credential) + idx = int(choice) + credentials = ( + self.cached_stats.get("providers", {}) + .get(provider, {}) + .get("credentials", []) + if self.cached_stats + else [] + ) + # Sort credentials naturally to match display order + credentials = sorted(credentials, key=natural_sort_key) + + if idx <= len(credentials): + cred = credentials[idx - 1] + # Use full_path for matching, fall back to identifier + cred_identifier = cred.get("full_path", cred.get("identifier", "")) + cred_email = cred.get("email", cred.get("identifier", "")) + + # Call API to force this credential + url = self._build_endpoint_url("/v1/force-credential") + payload = { + "credential": cred_identifier, + "provider": provider + } + + try: + with httpx.Client(timeout=10.0) as http_client: + response = http_client.post( + url, + headers=self._get_headers(), + json=payload + ) + + if response.status_code == 200: + result = response.json() + self.console.print( + f"\n[green]✓ Forced credential:[/green] [{idx}] {cred_email}" + ) + self.console.print( + "[dim]All requests will now use this credential (if available)[/dim]" + ) + else: + self.console.print( + f"\n[red]Failed to force credential: HTTP {response.status_code}[/red]" + ) + except Exception as e: + self.console.print(f"\n[red]Error: {e}[/red]") + + Prompt.ask("Press Enter to continue", default="") + else: + self.console.print(f"\n[red]Invalid selection. Only {len(credentials)} credentials available.[/red]") + Prompt.ask("Press Enter to continue", default="") + elif choice == "C": + # Clear forced credential + url = self._build_endpoint_url("/v1/force-credential") + payload = {"credential": None} + + try: + with httpx.Client(timeout=10.0) as http_client: + response = http_client.post( + url, + headers=self._get_headers(), + json=payload + ) + + if response.status_code == 200: + self.console.print( + "\n[green]✓ Forced credential cleared. Resuming normal rotation.[/green]" + ) + else: + self.console.print( + f"\n[red]Failed to clear forced credential: HTTP {response.status_code}[/red]" + ) + except Exception as e: + self.console.print(f"\n[red]Error: {e}[/red]") + + Prompt.ask("Press Enter to continue", default="") elif choice == "F" and has_quota_groups: result = None with self.console.status( diff --git a/src/rotator_library/usage_manager.py b/src/rotator_library/usage_manager.py index 39d9c05e..fe4e101b 100644 --- a/src/rotator_library/usage_manager.py +++ b/src/rotator_library/usage_manager.py @@ -162,6 +162,10 @@ def __init__( # Resilient writer for usage data persistence self._state_writer = ResilientStateWriter(file_path, lib_logger) + # Forced credential for manual override (TUI control) + self._forced_credential: Optional[str] = None + self._forced_credential_lock = asyncio.Lock() + if daily_reset_time_utc: hour, minute = map(int, daily_reset_time_utc.split(":")) self.daily_reset_time_utc = dt_time( @@ -182,6 +186,37 @@ def _get_rotation_mode(self, provider: str) -> str: """ return self.provider_rotation_modes.get(provider, "balanced") + # ========================================================================= + # FORCED CREDENTIAL (TUI OVERRIDE) + # ========================================================================= + + async def set_forced_credential(self, credential: Optional[str]) -> None: + """ + Force the usage manager to use a specific credential for all requests. + + This overrides the normal rotation logic and always selects the specified + credential, if it's available and not on cooldown. + + Args: + credential: Full credential path/identifier, or None to clear the override + """ + async with self._forced_credential_lock: + self._forced_credential = credential + if credential: + lib_logger.info(f"Forced credential set to: {mask_credential(credential)}") + else: + lib_logger.info("Forced credential cleared") + + async def get_forced_credential(self) -> Optional[str]: + """ + Get the currently forced credential, if any. + + Returns: + The forced credential path/identifier, or None if no override is active + """ + async with self._forced_credential_lock: + return self._forced_credential + # ========================================================================= # FAIR CYCLE ROTATION HELPERS # ========================================================================= @@ -2163,6 +2198,64 @@ async def acquire_key( self._normalize_model(available_keys[0], model) if available_keys else model ) + # Check if a specific credential is forced (TUI override) + forced_cred = await self.get_forced_credential() + if forced_cred: + # Find matching credential - support both full path and filename matching + matched_cred = None + if forced_cred in available_keys: + matched_cred = forced_cred + else: + # Try matching by filename (basename) + for key in available_keys: + if key.endswith(forced_cred) or Path(key).name == forced_cred: + matched_cred = key + break + + if matched_cred: + now = time.time() + async with self._data_lock: + key_data = self._usage_data.get(matched_cred, {}) + # Check if forced credential is available (not on cooldown) + is_on_cooldown = ( + (key_data.get("key_cooldown_until") or 0) > now or + (key_data.get("model_cooldowns", {}).get(normalized_model) or 0) > now + ) + + if not is_on_cooldown: + # Try to acquire the forced credential + state = self.key_states[matched_cred] + async with state["lock"]: + current_count = state["models_in_use"].get(model, 0) + if current_count < max_concurrent: + state["models_in_use"][model] = current_count + 1 + tier_name = ( + credential_tier_names.get(matched_cred, "unknown") + if credential_tier_names + else "unknown" + ) + quota_display = self._get_quota_display(matched_cred, model) + lib_logger.info( + f"Acquired FORCED key {mask_credential(matched_cred)} for model {model} " + f"(tier: {tier_name}, {quota_display})" + ) + return matched_cred + else: + lib_logger.warning( + f"Forced credential {mask_credential(matched_cred)} is at max concurrency " + f"({current_count}/{max_concurrent}), falling back to normal rotation" + ) + else: + lib_logger.warning( + f"Forced credential {mask_credential(matched_cred)} is on cooldown, " + f"falling back to normal rotation" + ) + else: + lib_logger.warning( + f"Forced credential {mask_credential(forced_cred)} not found in available credentials, " + f"falling back to normal rotation" + ) + # This loop continues as long as the global deadline has not been met. while time.time() < deadline: now = time.time()