diff --git a/app/services/cf_refresh/scheduler.py b/app/services/cf_refresh/scheduler.py index 11c398b3a..37d867ddc 100644 --- a/app/services/cf_refresh/scheduler.py +++ b/app/services/cf_refresh/scheduler.py @@ -39,6 +39,27 @@ async def _update_app_config( return False +async def _refresh_cooling_tokens_after_cf_update() -> None: + """在 cf 配置更新成功后,顺手恢复一次 cooling token。""" + try: + from app.core.config import get_config + from app.services.token.manager import get_token_manager + + max_tokens = int(get_config("token.on_demand_refresh_max_tokens", 100) or 100) + manager = await get_token_manager() + result = await manager.refresh_cooling_tokens( + trigger="cf_refresh", + max_tokens=max_tokens, + ) + logger.info( + "cf_refresh token check completed: " + f"checked={result['checked']}, refreshed={result['refreshed']}, " + f"recovered={result['recovered']}, expired={result['expired']}" + ) + except Exception as e: + logger.warning(f"cf_refresh token recovery skipped: {e}") + + async def refresh_once() -> bool: """执行一次刷新流程""" logger.info("=" * 50) @@ -58,6 +79,7 @@ async def refresh_once() -> bool: if success: logger.info("刷新完成") + await _refresh_cooling_tokens_after_cf_update() else: logger.error("刷新失败: 更新配置失败") @@ -72,10 +94,16 @@ async def _scheduler_loop(): # 周期性刷新(每次循环重新读取配置,支持面板修改实时生效) while True: - if is_enabled(): - await refresh_once() - else: - logger.debug("cf_refresh disabled, skip refresh") + try: + if is_enabled(): + await refresh_once() + else: + logger.debug("cf_refresh disabled, skip refresh") + except asyncio.CancelledError: + raise + except Exception as e: + logger.exception(f"cf_refresh loop error: {e}") + interval = get_refresh_interval() await asyncio.sleep(interval) diff --git a/app/services/grok/utils/retry.py b/app/services/grok/utils/retry.py index 459881e6e..8ace5c741 100644 --- a/app/services/grok/utils/retry.py +++ b/app/services/grok/utils/retry.py @@ -24,17 +24,6 @@ async def pick_token( if token: break - if not token and not tried: - await token_mgr.refresh_cooling_tokens_on_demand() - for pool_name in ModelService.pool_candidates_for_model(model_id): - token = token_mgr.get_token( - pool_name, - exclude=tried, - prefer_tags=prefer_tags, - ) - if token: - break - return token diff --git a/app/services/token/manager.py b/app/services/token/manager.py index 80549c9fb..28d8efb55 100644 --- a/app/services/token/manager.py +++ b/app/services/token/manager.py @@ -26,6 +26,7 @@ DEFAULT_REFRESH_CONCURRENCY = 5 DEFAULT_SUPER_REFRESH_INTERVAL_HOURS = 2 DEFAULT_REFRESH_INTERVAL_HOURS = 8 +DEFAULT_RATE_LIMIT_BACKOFF_SECONDS = 300 DEFAULT_RELOAD_INTERVAL_SEC = 30 DEFAULT_SAVE_DELAY_MS = 500 DEFAULT_USAGE_FLUSH_INTERVAL_SEC = 5 @@ -215,6 +216,81 @@ def _extract_window_size_seconds(self, result: dict) -> Optional[int]: return None return None + def _extract_remaining_quota(self, result: dict) -> tuple[Optional[int], bool]: + if not isinstance(result, dict): + return None, False + + value = result.get("remainingTokens") + authoritative = value is not None + if value is None: + value = result.get("remainingQueries") + + if value is None: + return None, authoritative + + try: + return max(0, int(value)), authoritative + except (TypeError, ValueError): + return None, authoritative + + def _apply_usage_result( + self, + token: TokenInfo, + pool_name: Optional[str], + result: dict, + *, + allow_from_expired: bool = False, + ) -> dict: + new_quota, authoritative = self._extract_remaining_quota(result) + if new_quota is None: + return { + "applied": False, + "pool_name": pool_name, + } + + old_quota = token.quota + old_status = token.status + token.quota = new_quota + + if new_quota > 0: + token.recover_active(allow_from_expired=allow_from_expired) + elif authoritative: + token.enter_cooling(reset_consumed=False) + + token.mark_synced() + + window_size = self._extract_window_size_seconds(result) + if window_size is not None and pool_name is not None: + if ( + pool_name == SUPER_POOL_NAME + and window_size >= SUPER_WINDOW_THRESHOLD_SECONDS + ): + pool_name = self._move_token_pool( + token, + SUPER_POOL_NAME, + BASIC_POOL_NAME, + reason=f"windowSizeSeconds={window_size}", + ) + elif ( + pool_name == BASIC_POOL_NAME + and window_size < SUPER_WINDOW_THRESHOLD_SECONDS + ): + pool_name = self._move_token_pool( + token, + BASIC_POOL_NAME, + SUPER_POOL_NAME, + reason=f"windowSizeSeconds={window_size}", + ) + + return { + "applied": True, + "pool_name": pool_name, + "old_quota": old_quota, + "old_status": old_status, + "new_quota": new_quota, + "authoritative": authoritative, + } + def _move_token_pool( self, token: TokenInfo, @@ -539,44 +615,19 @@ async def sync_usage( usage_service = UsageService() result = await usage_service.get(token_str) - if result and "remainingTokens" in result: - new_quota = result.get("remainingTokens") - if new_quota is None: - new_quota = result.get("remainingQueries") - if new_quota is None: - return False - old_quota = target_token.quota - old_status = target_token.status + usage_update = self._apply_usage_result( + target_token, + target_pool_name, + result, + allow_from_expired=True, + ) + if usage_update.get("applied"): + target_pool_name = usage_update.get("pool_name") + old_quota = usage_update["old_quota"] + old_status = usage_update["old_status"] + new_quota = usage_update["new_quota"] - if self._is_consumed_mode(): - target_token.update_quota_with_consumed(new_quota) - else: - target_token.update_quota(new_quota) target_token.record_success(is_usage=is_usage) - target_token.mark_synced() - - window_size = self._extract_window_size_seconds(result) - if window_size is not None: - if ( - target_pool_name == SUPER_POOL_NAME - and window_size >= SUPER_WINDOW_THRESHOLD_SECONDS - ): - target_pool_name = self._move_token_pool( - target_token, - SUPER_POOL_NAME, - BASIC_POOL_NAME, - reason=f"windowSizeSeconds={window_size}", - ) - elif ( - target_pool_name == BASIC_POOL_NAME - and window_size < SUPER_WINDOW_THRESHOLD_SECONDS - ): - target_pool_name = self._move_token_pool( - target_token, - BASIC_POOL_NAME, - SUPER_POOL_NAME, - reason=f"windowSizeSeconds={window_size}", - ) consumed = max(0, old_quota - new_quota) logger.debug( @@ -695,12 +746,22 @@ async def mark_rate_limited(self, token_str: str) -> bool: for pool in self.pools.values(): token = pool.get(raw_token) if token: - old_quota = token.quota - token.quota = 0 - token.enter_cooling() + backoff_seconds = get_config( + "token.rate_limit_backoff_seconds", + DEFAULT_RATE_LIMIT_BACKOFF_SECONDS, + ) + try: + backoff_seconds = int(backoff_seconds) + except (TypeError, ValueError): + backoff_seconds = DEFAULT_RATE_LIMIT_BACKOFF_SECONDS + + token.enter_cooling( + reset_consumed=False, + cooldown_seconds=max(0, backoff_seconds), + ) logger.warning( f"Token {raw_token[:10]}...: marked as rate limited " - f"(quota {old_quota} -> 0, status -> cooling)" + f"(status -> cooling, backoff={backoff_seconds}s, quota={token.quota})" ) self._track_token_change(token, pool.name, "state") self._schedule_save() @@ -986,44 +1047,15 @@ async def _refresh_one(item: tuple[str, TokenInfo]) -> dict: result, status, error = await _get_usage_with_retry(token_str) - if result and "remainingTokens" in result: - new_quota = result.get("remainingTokens") - if new_quota is None: - new_quota = result.get("remainingQueries") - if new_quota is None: - return {"recovered": False, "expired": False} - old_quota = token_info.quota - old_status = token_info.status - - if self._is_consumed_mode(): - token_info.update_quota_with_consumed(new_quota) - else: - token_info.update_quota(new_quota) - token_info.mark_synced() - - window_size = self._extract_window_size_seconds(result) - if window_size is not None: - current_pool = self.get_pool_name_for_token(token_info.token) - if ( - current_pool == SUPER_POOL_NAME - and window_size >= SUPER_WINDOW_THRESHOLD_SECONDS - ): - self._move_token_pool( - token_info, - SUPER_POOL_NAME, - BASIC_POOL_NAME, - reason=f"windowSizeSeconds={window_size}", - ) - elif ( - current_pool == BASIC_POOL_NAME - and window_size < SUPER_WINDOW_THRESHOLD_SECONDS - ): - self._move_token_pool( - token_info, - BASIC_POOL_NAME, - SUPER_POOL_NAME, - reason=f"windowSizeSeconds={window_size}", - ) + usage_update = self._apply_usage_result( + token_info, + self.get_pool_name_for_token(token_info.token), + result, + ) + if usage_update.get("applied"): + old_quota = usage_update["old_quota"] + old_status = usage_update["old_status"] + new_quota = usage_update["new_quota"] logger.debug( f"Token {token_info.token[:10]}...: refreshed " @@ -1031,7 +1063,7 @@ async def _refresh_one(item: tuple[str, TokenInfo]) -> dict: ) return { - "recovered": new_quota > 0 and old_quota == 0, + "recovered": old_status == TokenStatus.COOLING and token_info.status == TokenStatus.ACTIVE, "expired": False, } diff --git a/app/services/token/models.py b/app/services/token/models.py index c5c5413d3..703e61e36 100644 --- a/app/services/token/models.py +++ b/app/services/token/models.py @@ -69,6 +69,7 @@ class TokenInfo(BaseModel): # 冷却管理 last_sync_at: Optional[int] = None # 上次同步时间 + cooldown_until: Optional[int] = None # 短退避结束时间 # 扩展 tags: List[str] = Field(default_factory=list) @@ -111,52 +112,52 @@ def _normalize_token(cls, value): def is_available(self, consumed_mode: bool = False) -> bool: """检查当前模式下 token 是否可用。""" - if self.status != TokenStatus.ACTIVE: - return False - if consumed_mode: - return True - return self.quota > 0 + return self.status == TokenStatus.ACTIVE - def enter_cooling(self, reset_consumed: bool = True): - """进入冷却状态,并在新窗口开始时清空 consumed。""" + def enter_cooling( + self, + reset_consumed: bool = True, + cooldown_seconds: Optional[int] = None, + ): + """进入冷却状态,并按需设置短退避时间。""" self.status = TokenStatus.COOLING if reset_consumed: self.consumed = 0 + if cooldown_seconds is not None and cooldown_seconds > 0: + self.cooldown_until = ( + int(datetime.now().timestamp() * 1000) + int(cooldown_seconds * 1000) + ) + elif cooldown_seconds == 0: + self.cooldown_until = int(datetime.now().timestamp() * 1000) + else: + self.cooldown_until = None def recover_active(self, allow_from_expired: bool = False): """仅在允许的前提下恢复为 active。""" if self.status == TokenStatus.COOLING: self.status = TokenStatus.ACTIVE + self.cooldown_until = None elif allow_from_expired and self.status == TokenStatus.EXPIRED: self.status = TokenStatus.ACTIVE + self.cooldown_until = None def consume(self, effort: EffortType = EffortType.LOW) -> int: """ - 消耗配额(默认:扣减 quota) + 记录一次本地消耗估算,不再依据本地 quota 推导冷却。 Args: effort: LOW 计 1 次,HIGH 计 4 次 Returns: - 实际扣除的配额 + 本次计入的消耗值 """ cost = EFFORT_COST[effort] - # 默认行为:扣减 quota - actual_cost = min(cost, self.quota) - self.last_used_at = int(datetime.now().timestamp() * 1000) - self.consumed += cost # 无论是否开启消耗模式,都记录消耗 - self.use_count += actual_cost - self.quota = max(0, self.quota - actual_cost) - - # 默认行为:quota 耗尽时标记冷却,并重置消耗记录 - if self.quota == 0: - self.enter_cooling() - else: - self.recover_active() + self.consumed += cost + self.use_count += 1 - return actual_cost + return cost def consume_with_consumed(self, effort: EffortType = EffortType.LOW) -> int: """ @@ -251,14 +252,21 @@ def record_success(self, is_usage: bool = True): self.last_used_at = int(datetime.now().timestamp() * 1000) def need_refresh(self, interval_hours: int = 8) -> bool: - """检查是否需要刷新配额""" + """检查是否需要刷新配额。""" if self.status != TokenStatus.COOLING: return False + now = int(datetime.now().timestamp() * 1000) + if self.cooldown_until is not None: + if now < self.cooldown_until: + return False + if self.last_sync_at is None: + return True + return self.last_sync_at < self.cooldown_until + if self.last_sync_at is None: return True - now = int(datetime.now().timestamp() * 1000) interval_ms = interval_hours * 3600 * 1000 return (now - self.last_sync_at) >= interval_ms diff --git a/app/services/token/pool.py b/app/services/token/pool.py index b23be921d..bdfb889cc 100644 --- a/app/services/token/pool.py +++ b/app/services/token/pool.py @@ -43,9 +43,9 @@ def select( 选择一个可用 Token 默认模式(consumed_mode_enabled=false): - 1. 选择 active 状态且 quota > 0 的 token - 2. 优先选择剩余额度最多的 - 3. 如果额度相同,随机选择 + 1. 选择 active 状态的 token + 2. 优先选择本地消耗次数(consumed)最少的 + 3. 如果 consumed 相同,随机选择 Consumed 模式(consumed_mode_enabled=true): 1. 选择 active 状态的 token @@ -85,7 +85,7 @@ def select( else: - # ===== 默认模式(旧逻辑)===== + # ===== 默认模式 ===== available = [ t for t in self._tokens.values() @@ -96,7 +96,6 @@ def select( if not available: return None - # 优先选带指定标签的 token(若存在) if prefer_tags: preferred = [ t for t in available if prefer_tags.issubset(set(t.tags or [])) @@ -104,13 +103,8 @@ def select( if preferred: available = preferred - # 找到最大额度 - max_quota = max(t.quota for t in available) - - # 筛选最大额度 - candidates = [t for t in available if t.quota == max_quota] - - # 随机选择 + min_consumed = min(t.consumed for t in available) + candidates = [t for t in available if t.consumed == min_consumed] return random.choice(candidates) def count(self) -> int: diff --git a/tests/test_cf_refresh_scheduler.py b/tests/test_cf_refresh_scheduler.py new file mode 100644 index 000000000..14a30741b --- /dev/null +++ b/tests/test_cf_refresh_scheduler.py @@ -0,0 +1,82 @@ +import asyncio +import unittest +from unittest.mock import patch + +from app.services.cf_refresh import scheduler + + +class CfRefreshSchedulerTests(unittest.IsolatedAsyncioTestCase): + async def test_scheduler_loop_continues_after_refresh_exception(self): + calls = {"refresh": 0, "sleep": 0} + + async def fake_refresh_once(): + calls["refresh"] += 1 + if calls["refresh"] == 1: + raise RuntimeError("boom") + raise asyncio.CancelledError() + + async def fake_sleep(_seconds): + calls["sleep"] += 1 + + with ( + patch.object(scheduler, "is_enabled", return_value=True), + patch.object(scheduler, "get_refresh_interval", return_value=0), + patch.object(scheduler, "refresh_once", new=fake_refresh_once), + patch.object(scheduler.asyncio, "sleep", new=fake_sleep), + ): + with self.assertRaises(asyncio.CancelledError): + await scheduler._scheduler_loop() + + self.assertEqual(calls["refresh"], 2) + self.assertEqual(calls["sleep"], 1) + + async def test_refresh_once_triggers_token_refresh_after_cf_update(self): + calls = [] + + async def fake_solve(): + return { + "cookies": "cf_clearance=test-cookie", + "cf_clearance": "test-cookie", + "user_agent": "UA", + "browser": "chrome142", + } + + async def fake_update(**kwargs): + calls.append(("update", kwargs)) + return True + + async def fake_refresh_tokens(): + calls.append(("refresh_tokens", None)) + + with ( + patch.object(scheduler, "solve_cf_challenge", new=fake_solve), + patch.object(scheduler, "_update_app_config", new=fake_update), + patch.object( + scheduler, + "_refresh_cooling_tokens_after_cf_update", + new=fake_refresh_tokens, + create=True, + ), + ): + ok = await scheduler.refresh_once() + + self.assertTrue(ok) + self.assertEqual( + calls, + [ + ( + "update", + { + "cf_cookies": "cf_clearance=test-cookie", + "cf_clearance": "test-cookie", + "user_agent": "UA", + "browser": "chrome142", + }, + ), + ("refresh_tokens", None), + ], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_token_rate_limit_recovery.py b/tests/test_token_rate_limit_recovery.py new file mode 100644 index 000000000..c26a34c9a --- /dev/null +++ b/tests/test_token_rate_limit_recovery.py @@ -0,0 +1,113 @@ +import asyncio +import time +from unittest.mock import AsyncMock + +from app.services.grok.utils.retry import pick_token +from app.services.token.manager import TokenManager, BASIC_POOL_NAME +from app.services.token.models import TokenInfo, TokenStatus +from app.services.token.pool import TokenPool + + +async def _noop_save(*_args, **_kwargs): + return None + + +def _build_manager(token: TokenInfo) -> TokenManager: + manager = TokenManager() + pool = TokenPool(BASIC_POOL_NAME) + pool.add(token) + manager.pools = {BASIC_POOL_NAME: pool} + manager._schedule_save = lambda: None + manager._save = _noop_save + return manager + + +def test_mark_rate_limited_uses_short_backoff(monkeypatch): + monkeypatch.setattr( + "app.services.token.manager.get_config", + lambda key, default=None: 300 if key == "token.rate_limit_backoff_seconds" else default, + ) + + token = TokenInfo(token="tok-1", quota=40) + manager = _build_manager(token) + + asyncio.run(manager.mark_rate_limited("tok-1")) + + assert token.status == TokenStatus.COOLING + assert token.quota == 40 + assert token.cooldown_until is not None + remaining_ms = token.cooldown_until - int(time.time() * 1000) + assert 0 < remaining_ms <= 300_000 + + +def test_refresh_cooling_tokens_recovers_after_backoff(monkeypatch): + monkeypatch.setattr( + "app.services.token.manager.get_config", + lambda key, default=None: 60 if key == "token.refresh_interval_hours" else default, + ) + class DummyRetryContext: + def __init__(self): + self.attempt = 0 + self.max_retry = 0 + self.total_delay = 0.0 + self.retry_budget = 0.0 + + def record_error(self, status, error): + self.attempt += 1 + + def should_retry(self, status, error): + return False + + def calculate_delay(self, status, retry_after): + return 0.0 + + def record_delay(self, delay): + self.total_delay += delay + + monkeypatch.setattr( + "app.services.token.manager.RetryContext", + DummyRetryContext, + ) + + token = TokenInfo( + token="tok-2", + status=TokenStatus.COOLING, + quota=0, + cooldown_until=int(time.time() * 1000) - 1, + ) + manager = _build_manager(token) + + async def fake_get(_token: str): + return {"remainingQueries": 40, "windowSizeSeconds": 7200} + + monkeypatch.setattr( + "app.services.token.manager.UsageService.get", + AsyncMock(side_effect=fake_get), + ) + + result = asyncio.run(manager.refresh_cooling_tokens(trigger="test")) + + assert result == {"checked": 1, "refreshed": 1, "recovered": 1, "expired": 0} + assert token.status == TokenStatus.ACTIVE + assert token.quota == 40 + assert token.cooldown_until is None + + +def test_pick_token_does_not_trigger_on_demand_refresh_when_empty(): + class DummyManager: + def __init__(self): + self.calls = 0 + + def get_token(self, pool_name, exclude=None, prefer_tags=None): + self.calls += 1 + return None + + async def refresh_cooling_tokens_on_demand(self): + raise AssertionError("should not be called") + + manager = DummyManager() + + token = asyncio.run(pick_token(manager, "grok-4", set())) + + assert token is None + assert manager.calls >= 1