diff --git a/hindsight-api-slim/hindsight_api/config.py b/hindsight-api-slim/hindsight_api/config.py index c9ed8f17e..e219199d7 100644 --- a/hindsight-api-slim/hindsight_api/config.py +++ b/hindsight-api-slim/hindsight_api/config.py @@ -200,6 +200,7 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]: ENV_EMBEDDINGS_OPENAI_MODEL = "HINDSIGHT_API_EMBEDDINGS_OPENAI_MODEL" ENV_EMBEDDINGS_OPENAI_BASE_URL = "HINDSIGHT_API_EMBEDDINGS_OPENAI_BASE_URL" ENV_EMBEDDINGS_OPENAI_BATCH_SIZE = "HINDSIGHT_API_EMBEDDINGS_OPENAI_BATCH_SIZE" +ENV_EMBEDDINGS_OPENAI_DIMENSIONS = "HINDSIGHT_API_EMBEDDINGS_OPENAI_DIMENSIONS" # Gemini/Vertex AI embeddings configuration ENV_EMBEDDINGS_GEMINI_API_KEY = "HINDSIGHT_API_EMBEDDINGS_GEMINI_API_KEY" @@ -792,6 +793,13 @@ def _parse_positive_int(name: str, raw: str | None, default: int) -> int: return parsed +def _parse_optional_positive_int(name: str, raw: str | None) -> int | None: + """Parse an optional env var that must be a positive integer when set.""" + if raw is None or raw == "": + return None + return _parse_positive_int(name, raw, 1) + + def _validate_extraction_mode(mode: str) -> str: """Validate and normalize extraction mode.""" mode_lower = mode.lower() @@ -1179,6 +1187,7 @@ class HindsightConfig: # Defaulted fields (source-compatible additions — existing direct constructor callers keep working). # Keep at the end of the dataclass; Python forbids non-default fields after default fields. embeddings_openai_batch_size: int = DEFAULT_EMBEDDINGS_OPENAI_BATCH_SIZE + embeddings_openai_dimensions: int | None = None # Class-level sets for configuration categorization @@ -1538,6 +1547,10 @@ def from_env(cls) -> "HindsightConfig": os.getenv(ENV_EMBEDDINGS_OPENAI_BATCH_SIZE), DEFAULT_EMBEDDINGS_OPENAI_BATCH_SIZE, ), + embeddings_openai_dimensions=_parse_optional_positive_int( + ENV_EMBEDDINGS_OPENAI_DIMENSIONS, + os.getenv(ENV_EMBEDDINGS_OPENAI_DIMENSIONS), + ), # Cohere embeddings (with backward-compatible fallback to shared API key) embeddings_cohere_api_key=os.getenv(ENV_EMBEDDINGS_COHERE_API_KEY) or os.getenv(ENV_COHERE_API_KEY), embeddings_cohere_model=os.getenv(ENV_EMBEDDINGS_COHERE_MODEL, DEFAULT_EMBEDDINGS_COHERE_MODEL), diff --git a/hindsight-api-slim/hindsight_api/engine/embeddings.py b/hindsight-api-slim/hindsight_api/engine/embeddings.py index f6ce6bcf7..32e52f92e 100644 --- a/hindsight-api-slim/hindsight_api/engine/embeddings.py +++ b/hindsight-api-slim/hindsight_api/engine/embeddings.py @@ -385,6 +385,7 @@ def __init__( model: str = DEFAULT_EMBEDDINGS_OPENAI_MODEL, base_url: str | None = None, batch_size: int = 100, + dimensions: int | None = None, max_retries: int = 3, ): """ @@ -395,12 +396,14 @@ def __init__( model: OpenAI embedding model name (default: text-embedding-3-small) base_url: Custom base URL for OpenAI-compatible API (e.g., Azure OpenAI endpoint) batch_size: Maximum batch size for embedding requests (default: 100) + dimensions: Optional requested output dimensions for OpenAI text-embedding-3 models max_retries: Maximum number of retries for failed requests (default: 3) """ self.api_key = api_key self.model = model self.base_url = base_url self.batch_size = batch_size + self.dimensions = dimensions self.max_retries = max_retries self._client = None self._dimension: int | None = None @@ -445,7 +448,9 @@ async def initialize(self) -> None: self._client = OpenAI(**client_kwargs) # Try to get dimension from known models, otherwise do a test embedding - if self.model in self.MODEL_DIMENSIONS: + if self.dimensions is not None: + self._dimension = self.dimensions + elif self.model in self.MODEL_DIMENSIONS: self._dimension = self.MODEL_DIMENSIONS[self.model] else: # Do a test embedding to detect dimension @@ -480,10 +485,14 @@ def encode(self, texts: list[str]) -> list[list[float]]: for i in range(0, len(texts), self.batch_size): batch = texts[i : i + self.batch_size] - response = self._client.embeddings.create( - model=self.model, - input=batch, - ) + request = { + "model": self.model, + "input": batch, + } + if self.dimensions is not None: + request["dimensions"] = self.dimensions + + response = self._client.embeddings.create(**request) # Sort by index to ensure correct order batch_embeddings = sorted(response.data, key=lambda x: x.index) @@ -492,6 +501,75 @@ def encode(self, texts: list[str]) -> list[list[float]]: return all_embeddings +class CodexOAuthEmbeddings(OpenAIEmbeddings): + """ + OpenAI embeddings using the Codex/ChatGPT OAuth token from ``~/.codex/auth.json``. + + Codex OAuth is an LLM-provider auth path in Hindsight, but the same bearer token + can also authenticate against the standard OpenAI embeddings endpoint. This keeps + embeddings on the user's existing Codex subscription/OAuth path without requiring + a separate OpenAI/OpenRouter/Gemini/Cohere API key. + + Token refresh is handled automatically: the manager proactively refreshes the + access_token before it expires and reactively refreshes on 401 responses from + the embeddings API. + """ + + def __init__( + self, + model: str = DEFAULT_EMBEDDINGS_OPENAI_MODEL, + batch_size: int = 100, + dimensions: int | None = None, + max_retries: int = 3, + ): + from .providers.codex_auth import CodexAuthManager + + self._auth_manager = CodexAuthManager.from_file() + super().__init__( + api_key=self._auth_manager.access_token, + model=model, + base_url="https://api.openai.com/v1", + batch_size=batch_size, + dimensions=dimensions, + max_retries=max_retries, + ) + + @property + def provider_name(self) -> str: + return "openai-codex" + + def encode(self, texts: list[str]) -> list[list[float]]: + """Generate embeddings, refreshing the OAuth token if needed. + + Proactively refreshes before the call when the token is near expiry, + and reactively refreshes once on a 401 from the OpenAI embeddings API. + """ + from openai import AuthenticationError + + from .providers.codex_auth import CodexRefreshExpiredError + + # Proactive refresh — cheap when fresh (JWT exp decode + compare). + self._auth_manager.ensure_fresh_token() + if self._auth_manager.access_token != self.api_key: + self.api_key = self._auth_manager.access_token + if self._client is not None: + self._client.api_key = self._auth_manager.access_token + + try: + return super().encode(texts) + except AuthenticationError: + # Reactive refresh — token was valid by the JWT clock but the + # server rejected it (rotated server-side, race, etc.). + self._auth_manager.refresh_tokens( + reason="reactive (401 from embeddings API)", + force=True, + ) + self.api_key = self._auth_manager.access_token + if self._client is not None: + self._client.api_key = self._auth_manager.access_token + return super().encode(texts) + + class CohereEmbeddings(Embeddings): """ Cohere embeddings implementation using the Cohere API. @@ -1140,6 +1218,14 @@ def create_embeddings_from_env() -> Embeddings: model=model, base_url=base_url, batch_size=config.embeddings_openai_batch_size, + dimensions=config.embeddings_openai_dimensions, + ) + elif provider == "openai-codex": + model = os.environ.get(ENV_EMBEDDINGS_OPENAI_MODEL, DEFAULT_EMBEDDINGS_OPENAI_MODEL) + return CodexOAuthEmbeddings( + model=model, + batch_size=config.embeddings_openai_batch_size, + dimensions=config.embeddings_openai_dimensions, ) elif provider == "openrouter": api_key = config.embeddings_openrouter_api_key @@ -1153,6 +1239,7 @@ def create_embeddings_from_env() -> Embeddings: model=config.embeddings_openrouter_model, base_url="https://openrouter.ai/api/v1", batch_size=config.embeddings_openai_batch_size, + dimensions=config.embeddings_openai_dimensions, ) elif provider == "cohere": api_key = config.embeddings_cohere_api_key @@ -1206,5 +1293,6 @@ def create_embeddings_from_env() -> Embeddings: else: raise ValueError( f"Unknown embeddings provider: {provider}. " - f"Supported: 'local', 'tei', 'openai', 'cohere', 'google', 'litellm', 'litellm-sdk'" + f"Supported: 'local', 'tei', 'openai', 'openai-codex', 'openrouter', 'cohere', 'google', " + f"'litellm', 'litellm-sdk'" ) diff --git a/hindsight-api-slim/hindsight_api/engine/providers/codex_auth.py b/hindsight-api-slim/hindsight_api/engine/providers/codex_auth.py new file mode 100644 index 000000000..4397d7fcf --- /dev/null +++ b/hindsight-api-slim/hindsight_api/engine/providers/codex_auth.py @@ -0,0 +1,408 @@ +""" +Shared Codex OAuth authentication manager. + +Extracted from ``CodexLLM`` so that both ``CodexLLM`` and +``CodexOAuthEmbeddings`` can share JWT-expiry detection, single-flight +token refresh, and atomic file persistence without duplicating the logic. + +Usage +----- +Create a manager from the auth file:: + + mgr = CodexAuthManager.from_file() + +Then call ``ensure_fresh_token()`` before each outbound request and +``refresh_tokens(reason=..., force=...)`` on a reactive 401. +""" + +from __future__ import annotations + +import base64 +import binascii +import json +import logging +import os +import tempfile +import threading +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import httpx + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Module-level constants (shared with codex_llm.py via re-export there) +# --------------------------------------------------------------------------- + +# OAuth refresh endpoint and client id, mirrored from the canonical +# ``@openai/codex`` CLI (codex-rs/login/src/auth/manager.rs on +# github.com/openai/codex). The endpoint is overridable via env var so that +# future Codex changes or staging environments can be pointed at without a +# code change — same env var name the upstream CLI uses. +_CODEX_REFRESH_TOKEN_URL = os.environ.get("CODEX_REFRESH_TOKEN_URL_OVERRIDE", "https://auth.openai.com/oauth/token") +_CODEX_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" + +# Proactively refresh this many seconds before the JWT ``exp`` claim. The +# upstream Codex CLI uses no skew (it refreshes at ``exp <= now``); the +# extra window reduces races where a request leaves the client with a token +# that the server has already declared expired by the time it arrives. +_CODEX_TOKEN_REFRESH_SKEW_SECONDS = 60 + +# OAuth error codes that the refresh endpoint returns when the refresh_token +# itself is no longer usable. These are terminal — retrying refresh will not +# succeed; the user must re-run ``codex auth login``. +_CODEX_TERMINAL_REFRESH_ERROR_CODES = frozenset( + {"refresh_token_expired", "refresh_token_reused", "refresh_token_invalidated"} +) + + +class CodexRefreshExpiredError(RuntimeError): + """Raised when the Codex refresh_token itself is no longer valid. + + The user must re-run ``codex auth login`` to obtain new credentials. + Callers should surface a clear remediation message and stop retrying. + """ + + +class CodexAuthManager: + """Sync Codex OAuth credential manager. + + Holds the access_token, refresh_token, and account_id in memory and + handles proactive/reactive refresh using a ``threading.Lock`` for + single-flight semantics (safe to use from multiple threads or via + ``asyncio.to_thread``). + + Parameters + ---------- + access_token: + The current bearer token. + account_id: + The OpenAI account ID embedded in the Codex request headers. + refresh_token: + The OAuth refresh token. May be ``None`` when the auth file omits it; + the provider still works as a one-shot loader in that case. + auth_file: + Path to ``~/.codex/auth.json``. Used for re-reading the refresh token + on demand and for atomic persistence of rotated credentials. + """ + + def __init__( + self, + access_token: str, + account_id: str, + refresh_token: str | None, + auth_file: Path, + ) -> None: + self.access_token = access_token + self.account_id = account_id + self.refresh_token = refresh_token + self._auth_file = auth_file + self._lock = threading.Lock() + self._http_client = httpx.Client(timeout=30.0) + + # ------------------------------------------------------------------ + # Construction helpers + # ------------------------------------------------------------------ + + @classmethod + def from_file(cls, auth_file: Path | None = None) -> "CodexAuthManager": + """Build a manager by reading credentials from ``auth_file``. + + Parameters + ---------- + auth_file: + Defaults to ``~/.codex/auth.json``. + + Raises + ------ + FileNotFoundError: + If the auth file does not exist. + ValueError: + If the auth file is missing ``access_token`` or has an unexpected + ``auth_mode``. + """ + if auth_file is None: + auth_file = Path.home() / ".codex" / "auth.json" + + if not auth_file.exists(): + raise FileNotFoundError(f"Codex auth file not found: {auth_file}. Run 'codex auth login' to authenticate.") + + with open(auth_file) as f: + data = json.load(f) + + auth_mode = data.get("auth_mode") + if auth_mode != "chatgpt": + raise ValueError(f"Expected Codex auth_mode='chatgpt', got: {auth_mode}") + + tokens = data.get("tokens") or {} + access_token = tokens.get("access_token") + if not access_token: + raise ValueError("No access_token found in Codex auth file. Run 'codex auth login' again.") + + account_id = tokens.get("account_id") or "" + refresh_token = tokens.get("refresh_token") + + return cls( + access_token=access_token, + account_id=account_id, + refresh_token=refresh_token, + auth_file=auth_file, + ) + + # ------------------------------------------------------------------ + # Token state helpers + # ------------------------------------------------------------------ + + def load_refresh_token_from_file(self) -> str | None: + """Re-read ``tokens.refresh_token`` from ``_auth_file``. + + Returns ``None`` when the file is unreadable or omits the field. + Does not raise — the provider degrades to one-shot mode. + """ + try: + with open(self._auth_file) as f: + data = json.load(f) + except (OSError, json.JSONDecodeError) as e: + logger.warning( + f"Codex auth file unreadable when loading refresh_token: {type(e).__name__}. " + "Token refresh will not be available; the access_token in memory will be used until it expires." + ) + return None + return data.get("tokens", {}).get("refresh_token") + + @staticmethod + def _decode_jwt_exp_unixtime(token: str) -> int | None: + """Return the JWT ``exp`` claim as a unix timestamp, or None on parse failure. + + We do not verify the signature — the server is the source of truth + on whether the token is actually accepted. This is only used to + schedule proactive refresh. + """ + try: + parts = token.split(".") + if len(parts) < 2: + return None + payload_b64 = parts[1] + padding = "=" * (-len(payload_b64) % 4) + payload_bytes = base64.urlsafe_b64decode(payload_b64 + padding) + payload = json.loads(payload_bytes.decode("utf-8")) + exp = payload.get("exp") + return int(exp) if exp is not None else None + except (ValueError, TypeError, json.JSONDecodeError, binascii.Error): + return None + + def _token_is_stale(self, skew_seconds: int = _CODEX_TOKEN_REFRESH_SKEW_SECONDS) -> bool: + """True when the cached access_token is past expiry (with skew). + + Returns False when expiry cannot be determined — we'd rather use a + possibly-expired token and recover via the reactive 401 path than + refresh aggressively on every request when ``exp`` parsing fails. + """ + exp = self._decode_jwt_exp_unixtime(self.access_token) + if exp is None: + return False + return exp <= int(time.time()) + skew_seconds + + # ------------------------------------------------------------------ + # Persistence + # ------------------------------------------------------------------ + + def _persist_auth_atomic(self, updated_tokens: dict[str, Any]) -> None: + """Write rotated tokens back to ``_auth_file`` atomically. + + Re-reads the on-disk file first to avoid clobbering fields written + by another process, patches ``tokens.*`` and ``last_refresh``, then + writes to a sibling tempfile and calls ``os.replace`` (atomic on + POSIX and Windows within the same filesystem). + """ + current: dict[str, Any] + try: + with open(self._auth_file) as f: + loaded = json.load(f) + current = loaded if isinstance(loaded, dict) else {"auth_mode": "chatgpt", "tokens": {}} + except (OSError, json.JSONDecodeError): + current = {"auth_mode": "chatgpt", "tokens": {}} + + existing_tokens = current.get("tokens") + tokens: dict[str, Any] = existing_tokens if isinstance(existing_tokens, dict) else {} + for key in ("access_token", "refresh_token", "id_token", "account_id"): + if key in updated_tokens and updated_tokens[key] is not None: + tokens[key] = updated_tokens[key] + current["tokens"] = tokens + current["last_refresh"] = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + + parent = self._auth_file.parent + parent.mkdir(parents=True, exist_ok=True) + fd, tmp_path = tempfile.mkstemp(prefix=".auth.", suffix=".json.tmp", dir=str(parent)) + try: + with os.fdopen(fd, "w") as f: + json.dump(current, f, indent=2) + f.flush() + os.fsync(f.fileno()) + try: + os.chmod(tmp_path, 0o600) + except OSError: + pass + os.replace(tmp_path, self._auth_file) + except Exception: + try: + os.unlink(tmp_path) + except OSError: + pass + raise + + # ------------------------------------------------------------------ + # Error extraction + # ------------------------------------------------------------------ + + @staticmethod + def _extract_oauth_error_code(response: httpx.Response) -> str | None: + """Pull the OAuth error code out of a 4xx response body, if present. + + The refresh endpoint returns shapes like + ``{"error": "...", "error_code": "..."}`` or + ``{"error": {"code": "..."}}``. + """ + try: + body = response.json() + except (json.JSONDecodeError, ValueError): + return None + if not isinstance(body, dict): + return None + err = body.get("error") + if isinstance(err, dict): + code = err.get("code") + if isinstance(code, str): + return code + code = body.get("error_code") + if isinstance(code, str): + return code + if isinstance(err, str): + return err + return None + + # ------------------------------------------------------------------ + # Refresh + # ------------------------------------------------------------------ + + def refresh_tokens(self, reason: str = "", *, force: bool = False) -> None: + """Synchronous single-flight OAuth token refresh. + + Serialized through ``self._lock`` so concurrent threads produce one + network request. The first caller refreshes; the rest wake up and + skip if the token is no longer stale (proactive) or if the token + has already changed (reactive / force). + + Parameters + ---------- + reason: + Free-form string included in log lines for diagnostics. + force: + When True, refresh even if the JWT exp claim looks fresh. + Used by the reactive 401 path. + + Raises + ------ + CodexRefreshExpiredError: + When the server returns a terminal error code or any 401. + RuntimeError: + For other refresh failures (network, 5xx, etc.). + """ + token_before_lock = self.access_token + with self._lock: + if force: + if self.access_token != token_before_lock: + return + else: + if not self._token_is_stale(): + return + + if not self.refresh_token: + raise RuntimeError( + "Codex access_token is expired but no refresh_token is available. " + "Run 'codex auth login' to re-authenticate." + ) + + log_reason = f" ({reason})" if reason else "" + logger.info(f"Refreshing Codex OAuth access_token{log_reason}") + + request_body = { + "client_id": _CODEX_CLIENT_ID, + "grant_type": "refresh_token", + "refresh_token": self.refresh_token, + } + try: + response = self._http_client.post( + _CODEX_REFRESH_TOKEN_URL, + json=request_body, + headers={"Content-Type": "application/json"}, + timeout=30.0, + ) + except httpx.RequestError as e: + raise RuntimeError(f"Codex OAuth refresh network error: {type(e).__name__}") from e + + if response.status_code == 401: + error_code = self._extract_oauth_error_code(response) + if error_code in _CODEX_TERMINAL_REFRESH_ERROR_CODES: + raise CodexRefreshExpiredError( + f"Codex refresh_token is permanently invalid (error.code={error_code}). " + "Run 'codex auth login' to re-authenticate." + ) + raise CodexRefreshExpiredError( + f"Codex OAuth refresh returned 401 with unrecognized error code " + f"({error_code or 'none'}). Run 'codex auth login' to re-authenticate." + ) + + if response.status_code >= 400: + raise RuntimeError(f"Codex OAuth refresh failed with HTTP {response.status_code}") + + try: + body = response.json() + except json.JSONDecodeError as e: + raise RuntimeError(f"Codex OAuth refresh returned non-JSON body: {e}") from e + + new_access = body.get("access_token") + if not new_access: + raise RuntimeError("Codex OAuth refresh returned no access_token") + + new_refresh = body.get("refresh_token") or self.refresh_token + new_id_token = body.get("id_token") + + # Update in-memory state first so waiters see fresh credentials + # immediately, even if disk write fails. + self.access_token = new_access + self.refresh_token = new_refresh + + persisted: dict[str, Any] = { + "access_token": new_access, + "refresh_token": new_refresh, + } + if new_id_token: + persisted["id_token"] = new_id_token + + try: + self._persist_auth_atomic(persisted) + except OSError as e: + logger.warning( + f"Codex OAuth refresh succeeded but persisting auth.json failed: {type(e).__name__}. " + "In-memory credentials are up to date; on-disk file is stale." + ) + + logger.info("Codex OAuth access_token refreshed successfully") + + def ensure_fresh_token(self) -> None: + """Proactively refresh the access_token if it is near or past expiry. + + Cheap when the token is fresh (just decodes the JWT exp claim and + returns). + """ + if self._token_is_stale(): + self.refresh_tokens(reason="proactive (token near expiry)") + + def close(self) -> None: + """Close the underlying HTTP client.""" + self._http_client.close() diff --git a/hindsight-api-slim/hindsight_api/engine/providers/codex_llm.py b/hindsight-api-slim/hindsight_api/engine/providers/codex_llm.py index 148baa748..f4365bb70 100644 --- a/hindsight-api-slim/hindsight_api/engine/providers/codex_llm.py +++ b/hindsight-api-slim/hindsight_api/engine/providers/codex_llm.py @@ -15,15 +15,10 @@ """ import asyncio -import base64 -import binascii import json import logging -import os -import tempfile import time import uuid -from datetime import datetime, timezone from pathlib import Path from typing import Any @@ -33,37 +28,27 @@ from hindsight_api.engine.response_models import LLMToolCall, LLMToolCallResult, TokenUsage from hindsight_api.metrics import get_metrics_collector -logger = logging.getLogger(__name__) - - -# OAuth refresh endpoint and client id, mirrored from the canonical -# ``@openai/codex`` CLI (codex-rs/login/src/auth/manager.rs on -# github.com/openai/codex). The endpoint is overridable via env var so that -# future Codex changes or staging environments can be pointed at without a -# code change — same env var name the upstream CLI uses. -_CODEX_REFRESH_TOKEN_URL = os.environ.get("CODEX_REFRESH_TOKEN_URL_OVERRIDE", "https://auth.openai.com/oauth/token") -_CODEX_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" - -# Proactively refresh this many seconds before the JWT ``exp`` claim. The -# upstream Codex CLI uses no skew (it refreshes at ``exp <= now``); the -# extra window reduces races where a request leaves the client with a token -# that the server has already declared expired by the time it arrives. -_CODEX_TOKEN_REFRESH_SKEW_SECONDS = 60 - -# OAuth error codes that the refresh endpoint returns when the refresh_token -# itself is no longer usable. These are terminal — retrying refresh will not -# succeed; the user must re-run ``codex auth login``. -_CODEX_TERMINAL_REFRESH_ERROR_CODES = frozenset( - {"refresh_token_expired", "refresh_token_reused", "refresh_token_invalidated"} +from .codex_auth import ( + _CODEX_CLIENT_ID, + _CODEX_REFRESH_TOKEN_URL, + _CODEX_TERMINAL_REFRESH_ERROR_CODES, + _CODEX_TOKEN_REFRESH_SKEW_SECONDS, + CodexAuthManager, + CodexRefreshExpiredError, ) +# Re-export for backward compatibility (tests import from this module). +__all__ = [ + "CodexLLM", + "CodexRefreshExpiredError", + "CodexAuthManager", + "_CODEX_REFRESH_TOKEN_URL", + "_CODEX_CLIENT_ID", + "_CODEX_TOKEN_REFRESH_SKEW_SECONDS", + "_CODEX_TERMINAL_REFRESH_ERROR_CODES", +] -class CodexRefreshExpiredError(RuntimeError): - """Raised when the Codex refresh_token itself is no longer valid. - - The user must re-run ``codex auth login`` to obtain new credentials. - Callers should surface a clear remediation message and stop retrying. - """ +logger = logging.getLogger(__name__) class CodexLLM(LLMInterface): @@ -86,20 +71,15 @@ def __init__( """Initialize Codex LLM provider.""" super().__init__(provider, api_key, base_url, model, reasoning_effort, **kwargs) - # Path is fixed at ~/.codex/auth.json — matches the upstream CLI. - # Storing it on self lets the refresh path re-read after another - # process (e.g. a sidecar) rotates the file out from under us. - self._auth_file = Path.home() / ".codex" / "auth.json" - - # Single-flight refresh lock. Multiple concurrent requests racing - # toward an expired token should produce one network refresh, not N. + # Single-flight async refresh lock. Multiple concurrent coroutines + # racing toward an expired token should produce one network refresh. self._auth_lock = asyncio.Lock() - # Load Codex OAuth credentials + # Load Codex OAuth credentials (keep these methods for test patching). try: - self.access_token, self.account_id = self._load_codex_auth() - self.refresh_token = self._load_codex_refresh_token() - logger.info(f"Loaded Codex OAuth credentials for account: {self.account_id}") + access_token, account_id = self._load_codex_auth() + refresh_token = self._load_codex_refresh_token() + logger.info(f"Loaded Codex OAuth credentials for account: {account_id}") except Exception as e: raise RuntimeError( f"Failed to load Codex OAuth credentials from ~/.codex/auth.json: {e}\n\n" @@ -110,6 +90,13 @@ def __init__( "Or use a different provider (openai, anthropic, gemini) with API keys." ) from e + self._auth_manager = CodexAuthManager( + access_token=access_token, + account_id=account_id, + refresh_token=refresh_token, + auth_file=Path.home() / ".codex" / "auth.json", + ) + # Use ChatGPT backend API endpoint if not self.base_url: self.base_url = "https://chatgpt.com/backend-api" @@ -125,6 +112,42 @@ def __init__( # HTTP client for SSE streaming self._client = httpx.AsyncClient(timeout=120.0) + # ------------------------------------------------------------------ + # Properties — delegate to _auth_manager (preserves test-visible API) + # ------------------------------------------------------------------ + + @property + def access_token(self) -> str: + return self._auth_manager.access_token + + @access_token.setter + def access_token(self, v: str) -> None: + self._auth_manager.access_token = v + + @property + def account_id(self) -> str: + return self._auth_manager.account_id + + @property + def refresh_token(self) -> str | None: + return self._auth_manager.refresh_token + + @refresh_token.setter + def refresh_token(self, v: str | None) -> None: + self._auth_manager.refresh_token = v + + @property + def _auth_file(self) -> Path: + return self._auth_manager._auth_file + + @_auth_file.setter + def _auth_file(self, v: Path) -> None: + self._auth_manager._auth_file = v + + # ------------------------------------------------------------------ + # Forwarding methods (keep surface area for tests / subclasses) + # ------------------------------------------------------------------ + def _load_codex_auth(self) -> tuple[str, str]: """ Load OAuth credentials from ~/.codex/auth.json. @@ -161,273 +184,66 @@ def _load_codex_auth(self) -> tuple[str, str]: return access_token, account_id def _load_codex_refresh_token(self) -> str | None: - """Load ``tokens.refresh_token`` from ``~/.codex/auth.json``. + """Delegate to ``_auth_manager.load_refresh_token_from_file()``. - Returns None when the auth file is unreadable or omits the field — - the provider still functions as a one-shot loader in that case, it - just can't refresh when the access_token expires. This deliberately - does not raise so that ``__init__`` keeps the existing failure mode - of raising only on missing ``access_token``. + Kept as an instance method so existing tests that patch + ``CodexLLM._load_codex_refresh_token`` continue to work. """ - try: - with open(self._auth_file) as f: - data = json.load(f) - except (OSError, json.JSONDecodeError) as e: - logger.warning( - f"Codex auth file unreadable when loading refresh_token: {type(e).__name__}. " - "Token refresh will not be available; the access_token in memory will be used until it expires." - ) - return None - return data.get("tokens", {}).get("refresh_token") + # During __init__ the manager doesn't exist yet; fall back to reading + # from the default auth file path directly. + if not hasattr(self, "_auth_manager"): + auth_file = Path.home() / ".codex" / "auth.json" + try: + with open(auth_file) as f: + data = json.load(f) + except (OSError, json.JSONDecodeError) as e: + logger.warning( + f"Codex auth file unreadable when loading refresh_token: {type(e).__name__}. " + "Token refresh will not be available; the access_token in memory will be used until it expires." + ) + return None + return data.get("tokens", {}).get("refresh_token") + return self._auth_manager.load_refresh_token_from_file() @staticmethod def _decode_jwt_exp_unixtime(token: str) -> int | None: - """Return the JWT ``exp`` claim as a unix timestamp, or None on parse failure. - - ChatGPT/Codex access_tokens are JWTs whose payload includes ``exp`` - (RFC 7519). We need the expiry to schedule proactive refresh — the - ``auth.json`` file does not persist a separate ``expires_at`` field - in the upstream CLI's shape, so decoding the JWT itself is the - canonical way to know when the token is stale. - - We do not verify the signature — the server is the source of truth - on whether the token is actually accepted, and the only thing this - method affects is the *timing* of refresh, not whether to trust the - token contents. - """ - try: - parts = token.split(".") - if len(parts) < 2: - return None - payload_b64 = parts[1] - # JWT uses base64url without padding. Re-pad before decoding. - padding = "=" * (-len(payload_b64) % 4) - payload_bytes = base64.urlsafe_b64decode(payload_b64 + padding) - payload = json.loads(payload_bytes.decode("utf-8")) - exp = payload.get("exp") - return int(exp) if exp is not None else None - except (ValueError, TypeError, json.JSONDecodeError, binascii.Error): - return None + """Delegate to ``CodexAuthManager._decode_jwt_exp_unixtime``.""" + return CodexAuthManager._decode_jwt_exp_unixtime(token) def _token_is_stale(self, skew_seconds: int = _CODEX_TOKEN_REFRESH_SKEW_SECONDS) -> bool: - """True when the cached access_token is past expiry (with skew). - - Returns False when expiry cannot be determined — we'd rather use a - possibly-expired token and recover via the reactive 401 path than - refresh aggressively on every request when ``exp`` parsing fails. - """ - exp = self._decode_jwt_exp_unixtime(self.access_token) - if exp is None: - return False - return exp <= int(time.time()) + skew_seconds + """Delegate to ``_auth_manager._token_is_stale``.""" + return self._auth_manager._token_is_stale(skew_seconds) def _persist_auth_atomic(self, updated_tokens: dict[str, Any]) -> None: - """Write the rotated tokens back to ``~/.codex/auth.json`` atomically. - - Strategy: re-read the on-disk auth.json (so we don't clobber fields - another process may have added), patch ``tokens.*`` and - ``last_refresh``, write to a tempfile in the same directory with - mode 0600, then ``os.replace`` onto the target. ``os.replace`` is - atomic within the same filesystem on POSIX and Windows, so a - concurrent reader will see either the old file or the fully-written - new file — never a partial truncate, which is the upstream CLI's - worst-case race. - - On non-Unix platforms the chmod is a best-effort no-op; the parent - directory permissions still bound access. - """ - current: dict[str, Any] - try: - with open(self._auth_file) as f: - loaded = json.load(f) - # auth.json should always be a JSON object at the top level; if - # someone has hand-edited it into a non-object shape, fall back - # to the minimal default rather than crashing the refresh path. - current = loaded if isinstance(loaded, dict) else {"auth_mode": "chatgpt", "tokens": {}} - except (OSError, json.JSONDecodeError): - # If the file became unreadable between our last read and now, - # construct a minimal shape rather than refusing to persist. - current = {"auth_mode": "chatgpt", "tokens": {}} - - existing_tokens = current.get("tokens") - tokens: dict[str, Any] = existing_tokens if isinstance(existing_tokens, dict) else {} - for key in ("access_token", "refresh_token", "id_token", "account_id"): - if key in updated_tokens and updated_tokens[key] is not None: - tokens[key] = updated_tokens[key] - current["tokens"] = tokens - current["last_refresh"] = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") - - # Write to a sibling tempfile so the rename is same-filesystem. - parent = self._auth_file.parent - parent.mkdir(parents=True, exist_ok=True) - fd, tmp_path = tempfile.mkstemp(prefix=".auth.", suffix=".json.tmp", dir=str(parent)) - try: - with os.fdopen(fd, "w") as f: - json.dump(current, f, indent=2) - f.flush() - os.fsync(f.fileno()) - try: - os.chmod(tmp_path, 0o600) - except OSError: - pass # best-effort on platforms that don't support chmod - os.replace(tmp_path, self._auth_file) - except Exception: - # Clean up the orphaned tempfile if rename fails. - try: - os.unlink(tmp_path) - except OSError: - pass - raise + """Delegate to ``_auth_manager._persist_auth_atomic``.""" + return self._auth_manager._persist_auth_atomic(updated_tokens) async def _refresh_oauth_tokens(self, reason: str = "", *, force: bool = False) -> None: - """Refresh the OAuth access_token using the stored refresh_token. + """Async single-flight OAuth token refresh. - Single-flight: serialized through ``self._auth_lock`` so concurrent - callers produce one network request. The first caller refreshes; the - rest wake up and observe that either (a) the in-memory token is no - longer stale (proactive case) or (b) the in-memory token has changed - since they entered (reactive case), and return without re-refreshing. + Outer asyncio.Lock preserves single-flight semantics for concurrent + coroutines; the actual network call is offloaded to a thread via + ``asyncio.to_thread`` so the event loop stays unblocked. Args: reason: Free-form string included in log lines for diagnostics. force: When True, refresh even if the JWT exp claim looks fresh. - Used by the reactive 401 path — the server rejected the - token, so we cannot trust the JWT's self-reported expiry. + Used by the reactive 401 path. Raises: CodexRefreshExpiredError: when the server returns a terminal - error code (refresh_token_expired/reused/invalidated) or any - 401 on the refresh endpoint itself. + error code or any 401 on the refresh endpoint. RuntimeError: for other refresh failures (network, 5xx, etc.). """ - # Capture the token we'd be refreshing BEFORE acquiring the lock so - # that we can detect mid-wait rotation by another coroutine. token_before_lock = self.access_token async with self._auth_lock: if force: - # Reactive: skip only if another coroutine already rotated - # the token while we were waiting on the lock. if self.access_token != token_before_lock: return else: - # Proactive: skip if the token is no longer stale (the - # canonical "another coroutine refreshed first" check). - if not self._token_is_stale(): + if not self._auth_manager._token_is_stale(): return - - if not self.refresh_token: - raise RuntimeError( - "Codex access_token is expired but no refresh_token is available. " - "Run 'codex auth login' to re-authenticate." - ) - - log_reason = f" ({reason})" if reason else "" - logger.info(f"Refreshing Codex OAuth access_token{log_reason}") - - request_body = { - "client_id": _CODEX_CLIENT_ID, - "grant_type": "refresh_token", - "refresh_token": self.refresh_token, - } - try: - response = await self._client.post( - _CODEX_REFRESH_TOKEN_URL, - json=request_body, - headers={"Content-Type": "application/json"}, - timeout=30.0, - ) - except httpx.RequestError as e: - raise RuntimeError(f"Codex OAuth refresh network error: {type(e).__name__}") from e - - if response.status_code == 401: - # Classify by ``error.code`` (or top-level ``error`` string) — same - # mapping as the upstream Rust CLI's request_chatgpt_token_refresh. - error_code = self._extract_oauth_error_code(response) - if error_code in _CODEX_TERMINAL_REFRESH_ERROR_CODES: - raise CodexRefreshExpiredError( - f"Codex refresh_token is permanently invalid (error.code={error_code}). " - "Run 'codex auth login' to re-authenticate." - ) - # Unknown 401 — treat as terminal too, matching the upstream classification. - raise CodexRefreshExpiredError( - f"Codex OAuth refresh returned 401 with unrecognized error code " - f"({error_code or 'none'}). Run 'codex auth login' to re-authenticate." - ) - - if response.status_code >= 400: - # 5xx and other 4xx are transient/retryable from the caller's - # perspective; surface as RuntimeError without leaking the - # request body in logs. - raise RuntimeError(f"Codex OAuth refresh failed with HTTP {response.status_code}") - - try: - body = response.json() - except json.JSONDecodeError as e: - raise RuntimeError(f"Codex OAuth refresh returned non-JSON body: {e}") from e - - new_access = body.get("access_token") - if not new_access: - raise RuntimeError("Codex OAuth refresh returned no access_token") - - # The refresh_token may rotate on each refresh — adopt the new - # one if the server sent it, otherwise keep the existing. - new_refresh = body.get("refresh_token") or self.refresh_token - new_id_token = body.get("id_token") - - # Update in-memory state first so callers waiting on the lock - # see fresh credentials immediately, even if disk write fails. - self.access_token = new_access - self.refresh_token = new_refresh - - persisted = { - "access_token": new_access, - "refresh_token": new_refresh, - } - if new_id_token: - persisted["id_token"] = new_id_token - - try: - self._persist_auth_atomic(persisted) - except OSError as e: - # In-memory creds are valid; warn but don't fail the request - # path. Future process starts will fall back to the stale - # on-disk auth.json and immediately refresh. - logger.warning( - f"Codex OAuth refresh succeeded but persisting auth.json failed: {type(e).__name__}. " - "In-memory credentials are up to date; on-disk file is stale." - ) - - logger.info("Codex OAuth access_token refreshed successfully") - - @staticmethod - def _extract_oauth_error_code(response: "httpx.Response") -> str | None: - """Pull the OAuth error code out of a 4xx response body, if present. - - The refresh endpoint returns shapes like - ``{"error": "...", "error_code": "..."}`` or - ``{"error": {"code": "..."}}``. We don't fail the call if the body - is unparseable — the caller falls back to a generic "unknown" error. - """ - try: - body = response.json() - except (json.JSONDecodeError, ValueError): - return None - if not isinstance(body, dict): - return None - # Shape 1: error is a nested object with "code" - err = body.get("error") - if isinstance(err, dict): - code = err.get("code") - if isinstance(code, str): - return code - # Shape 2: top-level error_code string - code = body.get("error_code") - if isinstance(code, str): - return code - # Shape 3: error is itself a string code - if isinstance(err, str): - return err - return None + await asyncio.to_thread(lambda: self._auth_manager.refresh_tokens(reason, force=force)) async def _ensure_fresh_token(self) -> None: """Refresh the access_token proactively if it is near or past expiry. @@ -435,13 +251,10 @@ async def _ensure_fresh_token(self) -> None: Called at the top of every API-bound method. Cheap when the token is fresh (just decodes the JWT exp claim and returns). """ - if self._token_is_stale(): + if self._auth_manager._token_is_stale(): try: await self._refresh_oauth_tokens(reason="proactive (token near expiry)") except CodexRefreshExpiredError: - # Surface to the caller as the same RuntimeError shape the - # request loop has historically raised, so existing error - # handling paths keep working. raise def _map_reasoning_effort(self, effort: str) -> str: @@ -1073,5 +886,6 @@ async def _parse_sse_tool_stream(self, response: httpx.Response) -> tuple[str | return content if content else None, tool_calls async def cleanup(self) -> None: - """Clean up HTTP client.""" + """Clean up HTTP clients.""" await self._client.aclose() + self._auth_manager.close() diff --git a/hindsight-api-slim/hindsight_api/engine/providers/openai_compatible_llm.py b/hindsight-api-slim/hindsight_api/engine/providers/openai_compatible_llm.py index 9bb6fdba0..0f02949d5 100644 --- a/hindsight-api-slim/hindsight_api/engine/providers/openai_compatible_llm.py +++ b/hindsight-api-slim/hindsight_api/engine/providers/openai_compatible_llm.py @@ -306,15 +306,19 @@ def __init__( self.api_key = "local" # Validate API key for cloud providers - if self.provider in ( - "openai", - "groq", - "minimax", - "deepseek", - "openrouter", - "zai", - "opencode-go", - ) and not self.api_key: + if ( + self.provider + in ( + "openai", + "groq", + "minimax", + "deepseek", + "openrouter", + "zai", + "opencode-go", + ) + and not self.api_key + ): raise ValueError(f"API key is required for {self.provider}") # Service tier configuration (from config, not env vars) diff --git a/hindsight-api-slim/tests/test_codex_oauth_refresh.py b/hindsight-api-slim/tests/test_codex_oauth_refresh.py index ce437a5e9..c89ea339a 100644 --- a/hindsight-api-slim/tests/test_codex_oauth_refresh.py +++ b/hindsight-api-slim/tests/test_codex_oauth_refresh.py @@ -29,6 +29,7 @@ import sys import time from pathlib import Path +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import httpx @@ -254,7 +255,7 @@ async def test_refresh_sends_canonical_request_shape(tmp_path: Path): fresh_access = _make_jwt(int(time.time()) + 3600) refresh_resp = _refresh_response(200, {"access_token": fresh_access, "refresh_token": "rt-rotated"}) - with patch.object(llm._client, "post", new_callable=AsyncMock, return_value=refresh_resp) as mock_post: + with patch.object(llm._auth_manager._http_client, "post", return_value=refresh_resp) as mock_post: await llm._refresh_oauth_tokens() call_args = mock_post.call_args @@ -277,7 +278,7 @@ async def test_refresh_updates_in_memory_credentials(tmp_path: Path): new_access = _make_jwt(int(time.time()) + 3600) refresh_resp = _refresh_response(200, {"access_token": new_access, "refresh_token": "rt-new"}) - with patch.object(llm._client, "post", new_callable=AsyncMock, return_value=refresh_resp): + with patch.object(llm._auth_manager._http_client, "post", return_value=refresh_resp): await llm._refresh_oauth_tokens() assert llm.access_token == new_access @@ -295,7 +296,7 @@ async def test_refresh_keeps_existing_refresh_token_when_server_omits_one(tmp_pa new_access = _make_jwt(int(time.time()) + 3600) refresh_resp = _refresh_response(200, {"access_token": new_access}) - with patch.object(llm._client, "post", new_callable=AsyncMock, return_value=refresh_resp): + with patch.object(llm._auth_manager._http_client, "post", return_value=refresh_resp): await llm._refresh_oauth_tokens() assert llm.refresh_token == "rt-keep" @@ -310,7 +311,7 @@ async def test_refresh_raises_permanent_error_on_terminal_oauth_code(tmp_path: P bad_resp = _refresh_response(401, {"error": {"code": "refresh_token_expired"}}) - with patch.object(llm._client, "post", new_callable=AsyncMock, return_value=bad_resp): + with patch.object(llm._auth_manager._http_client, "post", return_value=bad_resp): with pytest.raises(CodexRefreshExpiredError): await llm._refresh_oauth_tokens() @@ -325,7 +326,7 @@ async def test_refresh_raises_permanent_error_on_unknown_401(tmp_path: Path): bad_resp = _refresh_response(401, {"error": "something_else"}) - with patch.object(llm._client, "post", new_callable=AsyncMock, return_value=bad_resp): + with patch.object(llm._auth_manager._http_client, "post", return_value=bad_resp): with pytest.raises(CodexRefreshExpiredError): await llm._refresh_oauth_tokens() @@ -340,7 +341,7 @@ async def test_refresh_raises_runtime_error_on_5xx(tmp_path: Path): bad_resp = _refresh_response(503, "service unavailable") - with patch.object(llm._client, "post", new_callable=AsyncMock, return_value=bad_resp): + with patch.object(llm._auth_manager._http_client, "post", return_value=bad_resp): with pytest.raises(RuntimeError) as exc_info: await llm._refresh_oauth_tokens() assert not isinstance(exc_info.value, CodexRefreshExpiredError) @@ -357,7 +358,7 @@ async def test_refresh_does_not_log_token_values(tmp_path: Path, caplog): new_access = _make_jwt(int(time.time()) + 3600) refresh_resp = _refresh_response(200, {"access_token": new_access, "refresh_token": "rt-also-secret"}) - with patch.object(llm._client, "post", new_callable=AsyncMock, return_value=refresh_resp): + with patch.object(llm._auth_manager._http_client, "post", return_value=refresh_resp): with caplog.at_level("DEBUG"): await llm._refresh_oauth_tokens() @@ -382,14 +383,14 @@ async def test_concurrent_ensure_fresh_token_calls_produce_one_refresh(tmp_path: new_access = _make_jwt(int(time.time()) + 3600) call_count = 0 - async def fake_post(*args, **kwargs): + def fake_post(*args, **kwargs): nonlocal call_count call_count += 1 # Simulate non-zero refresh latency so concurrent callers actually queue. - await asyncio.sleep(0.01) + time.sleep(0.01) return _refresh_response(200, {"access_token": new_access, "refresh_token": "rt-new"}) - with patch.object(llm._client, "post", new=fake_post): + with patch.object(llm._auth_manager._http_client, "post", new=fake_post): await asyncio.gather(*(llm._ensure_fresh_token() for _ in range(10))) assert call_count == 1, f"expected 1 network refresh under contention, got {call_count}" @@ -410,42 +411,33 @@ async def test_call_reactively_refreshes_on_401_and_retries(tmp_path: Path): new_access = _make_jwt(int(time.time()) + 3600) - # First post → 401 (backend rejects the token). After refresh, second post → 200. success_resp = MagicMock() success_resp.status_code = 200 - success_resp.raise_for_status.return_value = None + success_resp.raise_for_status = MagicMock(return_value=None) fail_response = MagicMock() fail_response.status_code = 401 fail_response.text = "unauthorized" - fail_exc = httpx.HTTPStatusError("401", request=MagicMock(), response=fail_response) - success_resp.raise_for_status = MagicMock(return_value=None) - - post_responses = [fail_exc, success_resp] - - async def fake_post(*args, **kwargs): - item = post_responses.pop(0) - if isinstance(item, Exception): - raise item - return item refresh_resp = _refresh_response(200, {"access_token": new_access, "refresh_token": "rt-new"}) call_count = {"refresh": 0, "post": 0} - async def counting_post(url, **kwargs): - if url == _CODEX_REFRESH_TOKEN_URL: - call_count["refresh"] += 1 - return refresh_resp + # Sync mock for the auth manager's HTTP client (used for token refresh). + def fake_refresh_post(*args, **kwargs): + call_count["refresh"] += 1 + return refresh_resp + + # Async mock for the LLM's HTTP client (used for backend calls). + async def fake_backend_post(url, **kwargs): call_count["post"] += 1 - # First backend call fails with 401 wrapped in an HTTPStatusError-style response, - # second succeeds. if call_count["post"] == 1: raise httpx.HTTPStatusError("401", request=MagicMock(), response=fail_response) return success_resp with ( - patch.object(llm._client, "post", new=counting_post), + patch.object(llm._auth_manager._http_client, "post", new=fake_refresh_post), + patch.object(llm._client, "post", new=fake_backend_post), patch.object(llm, "_parse_sse_stream", new_callable=AsyncMock, return_value="ok"), ): result = await llm.call( @@ -478,17 +470,19 @@ async def test_call_proactively_refreshes_when_token_is_stale(tmp_path: Path): call_order: list[str] = [] - async def fake_post(url, **kwargs): - if url == _CODEX_REFRESH_TOKEN_URL: - call_order.append("refresh") - return refresh_resp + def fake_refresh_post(*args, **kwargs): + call_order.append("refresh") + return refresh_resp + + async def fake_backend_post(url, **kwargs): call_order.append("backend") # Assert that by the time the backend is called, the new token is in use. assert kwargs["headers"]["Authorization"] == f"Bearer {new_access}" return success_resp with ( - patch.object(llm._client, "post", new=fake_post), + patch.object(llm._auth_manager._http_client, "post", new=fake_refresh_post), + patch.object(llm._client, "post", new=fake_backend_post), patch.object(llm, "_parse_sse_stream", new_callable=AsyncMock, return_value="ok"), ): await llm.call( @@ -512,17 +506,14 @@ async def test_call_does_not_refresh_when_token_is_fresh(tmp_path: Path): success_resp.status_code = 200 success_resp.raise_for_status.return_value = None - call_count = {"refresh": 0, "backend": 0} + call_count = {"backend": 0} - async def fake_post(url, **kwargs): - if url == _CODEX_REFRESH_TOKEN_URL: - call_count["refresh"] += 1 - raise AssertionError("refresh endpoint should not be hit for a fresh token") + async def fake_backend_post(url, **kwargs): call_count["backend"] += 1 return success_resp with ( - patch.object(llm._client, "post", new=fake_post), + patch.object(llm._client, "post", new=fake_backend_post), patch.object(llm, "_parse_sse_stream", new_callable=AsyncMock, return_value="ok"), ): await llm.call( @@ -532,4 +523,104 @@ async def fake_post(url, **kwargs): max_backoff=0.0, ) - assert call_count == {"refresh": 0, "backend": 1} + assert call_count == {"backend": 1} + + +# --------------------------------------------------------------------------- +# CodexOAuthEmbeddings — proactive + reactive token refresh +# --------------------------------------------------------------------------- + + +def _make_codex_auth_file(tmp_path: Path, access_token: str, refresh_token: str = "rt-initial") -> Path: + """Write a minimal ~/.codex/auth.json in tmp_path and return its path.""" + codex_dir = tmp_path / ".codex" + codex_dir.mkdir(parents=True, exist_ok=True) + auth_file = codex_dir / "auth.json" + auth_file.write_text( + json.dumps( + { + "auth_mode": "chatgpt", + "tokens": { + "access_token": access_token, + "refresh_token": refresh_token, + "account_id": "acct-test", + }, + } + ) + ) + return auth_file + + +def test_codex_oauth_embeddings_picks_up_refreshed_token_on_encode(tmp_path: Path, monkeypatch): + """encode() calls ensure_fresh_token() and updates api_key when the token rotated.""" + from hindsight_api.engine.embeddings import CodexOAuthEmbeddings + from hindsight_api.engine.providers.codex_auth import CodexAuthManager + + expired = _make_jwt(int(time.time()) - 60) + new_access = _make_jwt(int(time.time()) + 3600) + + auth_file = _make_codex_auth_file(tmp_path, expired, refresh_token="rt-embed") + monkeypatch.setenv("HOME", str(tmp_path)) + + emb = CodexOAuthEmbeddings(model="text-embedding-3-small", batch_size=10) + # Manually point to our tmp auth file after construction. + emb._auth_manager._auth_file = auth_file + + refresh_resp = _refresh_response(200, {"access_token": new_access, "refresh_token": "rt-new"}) + + fake_embeddings = [SimpleNamespace(index=0, embedding=[0.1] * 1536)] + fake_create_resp = SimpleNamespace(data=fake_embeddings) + + with patch.object(emb._auth_manager._http_client, "post", return_value=refresh_resp): + emb._client = SimpleNamespace(embeddings=SimpleNamespace(create=lambda **kw: fake_create_resp)) + emb._dimension = 1536 + result = emb.encode(["hello"]) + + assert result == [[0.1] * 1536] + # After proactive refresh the manager's token should be the new one. + assert emb._auth_manager.access_token == new_access + # api_key on the embeddings object should also be updated. + assert emb.api_key == new_access + + +def test_codex_oauth_embeddings_reactive_refresh_on_401(tmp_path: Path, monkeypatch): + """On AuthenticationError from OpenAI, encode() refreshes and retries once.""" + from openai import AuthenticationError as OAIAuthError + + from hindsight_api.engine.embeddings import CodexOAuthEmbeddings + + fresh = _make_jwt(int(time.time()) + 3600) + new_access = _make_jwt(int(time.time()) + 7200) + + auth_file = _make_codex_auth_file(tmp_path, fresh, refresh_token="rt-embed") + monkeypatch.setenv("HOME", str(tmp_path)) + + emb = CodexOAuthEmbeddings(model="text-embedding-3-small", batch_size=10) + emb._auth_manager._auth_file = auth_file + emb._dimension = 1536 + + refresh_resp = _refresh_response(200, {"access_token": new_access, "refresh_token": "rt-rotated"}) + + call_count = {"create": 0} + + def fake_create(**kwargs): + call_count["create"] += 1 + if call_count["create"] == 1: + # Simulate OpenAI returning 401. + mock_response = MagicMock() + mock_response.status_code = 401 + raise OAIAuthError( + message="invalid api key", + response=mock_response, + body={"error": {"message": "invalid api key"}}, + ) + return SimpleNamespace(data=[SimpleNamespace(index=0, embedding=[0.2] * 1536)]) + + with patch.object(emb._auth_manager._http_client, "post", return_value=refresh_resp): + emb._client = SimpleNamespace(embeddings=SimpleNamespace(create=fake_create)) + result = emb.encode(["world"]) + + assert result == [[0.2] * 1536] + assert call_count["create"] == 2 # first failed with 401, second succeeded + assert emb._auth_manager.access_token == new_access + assert emb.api_key == new_access diff --git a/hindsight-api-slim/tests/test_embeddings_openai_batch_size.py b/hindsight-api-slim/tests/test_embeddings_openai_batch_size.py index 820f1bf3f..5c520dab3 100644 --- a/hindsight-api-slim/tests/test_embeddings_openai_batch_size.py +++ b/hindsight-api-slim/tests/test_embeddings_openai_batch_size.py @@ -7,6 +7,7 @@ the batch size via env var so `encode()` splits into smaller chunks. """ +import json import os import pytest @@ -22,6 +23,7 @@ def setup_test_env(): "HINDSIGHT_API_EMBEDDINGS_OPENAI_API_KEY", "HINDSIGHT_API_EMBEDDINGS_OPENAI_MODEL", "HINDSIGHT_API_EMBEDDINGS_OPENAI_BATCH_SIZE", + "HINDSIGHT_API_EMBEDDINGS_OPENAI_DIMENSIONS", "HINDSIGHT_API_EMBEDDINGS_OPENROUTER_API_KEY", "HINDSIGHT_API_LLM_API_KEY", "HINDSIGHT_API_LLM_PROVIDER", @@ -64,6 +66,17 @@ def test_openai_batch_size_env_var_is_read(): assert config.embeddings_openai_batch_size == 10 +def test_openai_dimensions_env_var_is_read(): + """HINDSIGHT_API_EMBEDDINGS_OPENAI_DIMENSIONS requests reduced OpenAI output dims.""" + from hindsight_api.config import HindsightConfig + + os.environ["HINDSIGHT_API_LLM_PROVIDER"] = "mock" + os.environ["HINDSIGHT_API_EMBEDDINGS_OPENAI_DIMENSIONS"] = "384" + + config = HindsightConfig.from_env() + assert config.embeddings_openai_dimensions == 384 + + def test_openai_embeddings_provider_uses_configured_batch_size(): """create_embeddings_from_env() propagates config to OpenAIEmbeddings for 'openai' provider.""" from hindsight_api.engine.embeddings import OpenAIEmbeddings, create_embeddings_from_env @@ -92,6 +105,41 @@ def test_openrouter_provider_uses_configured_batch_size(): assert embeddings.batch_size == 8 +def test_openai_codex_provider_uses_codex_oauth_token_and_configured_batch_size(tmp_path, monkeypatch): + """'openai-codex' embeddings reuse Codex OAuth auth without a separate API key.""" + from hindsight_api.engine.embeddings import CodexOAuthEmbeddings, create_embeddings_from_env + + codex_dir = tmp_path / ".codex" + codex_dir.mkdir() + (codex_dir / "auth.json").write_text( + json.dumps( + { + "auth_mode": "chatgpt", + "tokens": { + "access_token": "codex-oauth-token-test", + "account_id": "acct-test", + }, + } + ) + ) + + monkeypatch.setenv("HOME", str(tmp_path)) + os.environ["HINDSIGHT_API_LLM_PROVIDER"] = "mock" + os.environ["HINDSIGHT_API_EMBEDDINGS_PROVIDER"] = "openai-codex" + os.environ["HINDSIGHT_API_EMBEDDINGS_OPENAI_MODEL"] = "text-embedding-3-small" + os.environ["HINDSIGHT_API_EMBEDDINGS_OPENAI_BATCH_SIZE"] = "7" + os.environ["HINDSIGHT_API_EMBEDDINGS_OPENAI_DIMENSIONS"] = "384" + + embeddings = create_embeddings_from_env() + assert isinstance(embeddings, CodexOAuthEmbeddings) + assert embeddings.provider_name == "openai-codex" + assert embeddings.model == "text-embedding-3-small" + assert embeddings.base_url == "https://api.openai.com/v1" + assert embeddings.api_key == "codex-oauth-token-test" + assert embeddings.batch_size == 7 + assert embeddings.dimensions == 384 + + def test_zero_batch_size_is_rejected(): """Zero would cause `range(0, N, 0)` to crash at runtime — fail fast at config load.""" from hindsight_api.config import HindsightConfig @@ -144,7 +192,34 @@ def fake_create(*, model, input): vectors = emb.encode(["x"] * 25) + assert calls == [10, 10, 5] assert len(vectors) == 25 - assert calls == [10, 10, 5], ( - f"Expected upstream calls of size 10, 10, 5 when batch_size=10 and 25 inputs, got {calls}" + + +def test_openai_encode_passes_configured_dimensions(): + """OpenAI embeddings requests include the optional dimensions parameter when configured.""" + from types import SimpleNamespace + + from hindsight_api.engine.embeddings import OpenAIEmbeddings + + emb = OpenAIEmbeddings( + api_key="sk-test", + model="text-embedding-3-small", + batch_size=10, + dimensions=384, ) + + calls: list[int | None] = [] + + def fake_create(*, model, input, dimensions=None): + calls.append(dimensions) + return SimpleNamespace(data=[SimpleNamespace(index=i, embedding=[0.0] * 384) for i in range(len(input))]) + + emb._client = SimpleNamespace(embeddings=SimpleNamespace(create=fake_create)) + emb._dimension = 384 + + vectors = emb.encode(["x"] * 2) + + assert calls == [384] + assert len(vectors) == 2 + assert len(vectors[0]) == 384 diff --git a/hindsight-docs/blog/2026-03-23-claude-code-telegram.md b/hindsight-docs/blog/2026-03-23-claude-code-telegram.md index 07455dccc..0a9bb41eb 100644 --- a/hindsight-docs/blog/2026-03-23-claude-code-telegram.md +++ b/hindsight-docs/blog/2026-03-23-claude-code-telegram.md @@ -46,7 +46,7 @@ Open Telegram and start a chat with [@BotFather][3]. Send `/newbot` and follow the prompts — pick a name and a username (must end in `bot`). BotFather will give you a token like: ``` -1234567890:ABCDefghIJKlmnOPQrstUVwxyz +[REDACTED_TELEGRAM_BOT_TOKEN] ``` Keep this token for Step 3. diff --git a/hindsight-docs/docs/developer/configuration.md b/hindsight-docs/docs/developer/configuration.md index 1fb66c44e..8b5b29c52 100644 --- a/hindsight-docs/docs/developer/configuration.md +++ b/hindsight-docs/docs/developer/configuration.md @@ -455,7 +455,7 @@ export HINDSIGHT_API_RETAIN_LLM_MAX_BACKOFF=120.0 # Cap at 2min instead of 1m | Variable | Description | Default | |----------|-------------|---------| -| `HINDSIGHT_API_EMBEDDINGS_PROVIDER` | Provider: `local`, `tei`, `openai`, `openrouter`, `cohere`, `google`, `litellm`, or `litellm-sdk` | `local` | +| `HINDSIGHT_API_EMBEDDINGS_PROVIDER` | Provider: `local`, `tei`, `openai`, `openai-codex`, `openrouter`, `cohere`, `google`, `litellm`, or `litellm-sdk` | `local` | | `HINDSIGHT_API_EMBEDDINGS_LOCAL_MODEL` | Model for local provider | `BAAI/bge-small-en-v1.5` | | `HINDSIGHT_API_EMBEDDINGS_LOCAL_TRUST_REMOTE_CODE` | Allow loading models with custom code (security risk, disabled by default) | `false` | | `HINDSIGHT_API_EMBEDDINGS_LOCAL_FORCE_CPU` | Force CPU mode for local embeddings (avoids MPS/XPC issues on macOS) | `false` | @@ -464,6 +464,7 @@ export HINDSIGHT_API_RETAIN_LLM_MAX_BACKOFF=120.0 # Cap at 2min instead of 1m | `HINDSIGHT_API_EMBEDDINGS_OPENAI_MODEL` | OpenAI embedding model | `text-embedding-3-small` | | `HINDSIGHT_API_EMBEDDINGS_OPENAI_BASE_URL` | Custom base URL for OpenAI-compatible API (e.g., Azure OpenAI) | - | | `HINDSIGHT_API_EMBEDDINGS_OPENAI_BATCH_SIZE` | Max inputs per `embeddings.create` call for `openai`/`openrouter` providers — lower this when the upstream endpoint enforces stricter limits (e.g. DashScope caps at 10) | `100` | +| `HINDSIGHT_API_EMBEDDINGS_OPENAI_DIMENSIONS` | Optional requested output dimensions for OpenAI `text-embedding-3` models (e.g., `384` to match an existing pgvector schema) | - | | `HINDSIGHT_API_EMBEDDINGS_OPENROUTER_API_KEY` | OpenRouter API key for embeddings (falls back to `HINDSIGHT_API_OPENROUTER_API_KEY`, then `HINDSIGHT_API_LLM_API_KEY`) | - | | `HINDSIGHT_API_EMBEDDINGS_OPENROUTER_MODEL` | OpenRouter embedding model | `perplexity/pplx-embed-v1-0.6b` | | `HINDSIGHT_API_EMBEDDINGS_COHERE_API_KEY` | Cohere API key for embeddings | - | @@ -522,8 +523,14 @@ export HINDSIGHT_API_EMBEDDINGS_LOCAL_MODEL=BAAI/bge-small-en-v1.5 # OpenAI - cloud-based embeddings export HINDSIGHT_API_EMBEDDINGS_PROVIDER=openai -export HINDSIGHT_API_EMBEDDINGS_OPENAI_API_KEY=sk-xxxxxxxxxxxx # or reuses HINDSIGHT_API_LLM_API_KEY -export HINDSIGHT_API_EMBEDDINGS_OPENAI_MODEL=text-embedding-3-small # 1536 dimensions +export HINDSIGHT_API_EMBEDDINGS_OPENAI_API_KEY=*** # or reuses HINDSIGHT_API_LLM_API_KEY +export HINDSIGHT_API_EMBEDDINGS_OPENAI_MODEL=text-embedding-3-small # 1536 dimensions by default +# export HINDSIGHT_API_EMBEDDINGS_OPENAI_DIMENSIONS=384 # optional reduced output size + +# OpenAI Codex OAuth - uses existing ChatGPT/Codex login, no API key needed +export HINDSIGHT_API_EMBEDDINGS_PROVIDER=openai-codex +export HINDSIGHT_API_EMBEDDINGS_OPENAI_MODEL=text-embedding-3-small # 1536 dimensions by default +# export HINDSIGHT_API_EMBEDDINGS_OPENAI_DIMENSIONS=384 # optional reduced output size # Azure OpenAI - embeddings via Azure endpoint export HINDSIGHT_API_EMBEDDINGS_PROVIDER=openai