From a81a70aa1604d4fd12ae62aa10575839e541c5ae Mon Sep 17 00:00:00 2001 From: Sergii Nemesh Date: Sat, 28 Mar 2026 02:31:29 +0100 Subject: [PATCH 01/10] feat(retry): add unified transient retry module (#922) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements `openviking/models/retry.py` with `is_transient_error`, `transient_retry`, and `transient_retry_async` — a single config-driven retry layer replacing scattered per-backend implementations. Adds 50 unit tests covering classification, backoff, jitter, exhaustion, and custom predicates. --- openviking/models/retry.py | 285 ++++++++++++++++++++++++ tests/unit/test_retry.py | 433 +++++++++++++++++++++++++++++++++++++ 2 files changed, 718 insertions(+) create mode 100644 openviking/models/retry.py create mode 100644 tests/unit/test_retry.py diff --git a/openviking/models/retry.py b/openviking/models/retry.py new file mode 100644 index 000000000..40f6a4088 --- /dev/null +++ b/openviking/models/retry.py @@ -0,0 +1,285 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +"""Unified retry logic for VLM backends and embedding providers. + +Provides three public helpers: + +- ``is_transient_error`` — classifies an exception as transient (retryable) + or permanent (should propagate immediately). +- ``transient_retry`` — synchronous retry loop with exponential backoff. +- ``transient_retry_async`` — asynchronous counterpart using ``asyncio.sleep``. + +Transient errors are those that may resolve on their own (rate-limits, temporary +server errors, network resets). Permanent errors indicate a caller mistake +(bad auth, invalid input) and should never be retried. + +Usage example:: + + result = transient_retry(lambda: client.chat(...), max_retries=3) + result = await transient_retry_async(lambda: client.chat_async(...), max_retries=3) +""" + +from __future__ import annotations + +import asyncio +import logging +import random +import time +from collections.abc import Callable +from typing import Optional, TypeVar + +logger = logging.getLogger("openviking.models.retry") + +T = TypeVar("T") + +# --------------------------------------------------------------------------- +# Status code helpers +# --------------------------------------------------------------------------- + +_TRANSIENT_STATUS_CODES: frozenset[int] = frozenset({429, 500, 502, 503, 504}) +_PERMANENT_STATUS_CODES: frozenset[int] = frozenset({400, 401, 403, 404, 422}) + +# String patterns — permanent check runs first (more specific) +_PERMANENT_STR_PATTERNS: tuple[str, ...] = ( + "InvalidRequestError", + "AuthenticationError", +) +_TRANSIENT_STR_PATTERNS: tuple[str, ...] = ( + "TooManyRequests", + "RateLimit", + "RequestBurstTooFast", + "timed out", + "timeout", +) + + +def _extract_status_code(exc: Exception) -> int | None: + """Return numeric HTTP status from common status-bearing attributes. + + Checks ``.status_code``, ``.code``, and ``.http_status`` in that order. + Returns ``None`` if none of the attributes exist or hold an integer. + """ + for attr in ("status_code", "code", "http_status"): + value = getattr(exc, attr, None) + if isinstance(value, int): + return value + return None + + +# --------------------------------------------------------------------------- +# is_transient_error +# --------------------------------------------------------------------------- + + +def is_transient_error(exc: Exception) -> bool: + """Classify an exception as transient (retryable) or permanent. + + Evaluation order: + 1. Extract numeric status code from the exception attributes; check + permanent codes first, then transient codes. + 2. Check the exception type directly (built-in connection / timeout types). + 3. Scan ``str(exc)`` for known permanent string patterns, then transient + ones. + 4. Attempt to import ``openai`` and check against its error hierarchy. + 5. Default to ``False`` (conservative — unknown errors are not retried). + + Args: + exc: The exception to classify. + + Returns: + ``True`` if the error is likely transient and worth retrying. + ``False`` for permanent errors or any unrecognised exception. + """ + # ── 1. Numeric status code ──────────────────────────────────────────── + status = _extract_status_code(exc) + if status is not None: + if status in _PERMANENT_STATUS_CODES: + return False + if status in _TRANSIENT_STATUS_CODES: + return True + + # ── 2. Exception type ───────────────────────────────────────────────── + # asyncio.TimeoutError is a subclass of TimeoutError on 3.11+, but treat + # both explicitly for clarity on 3.10. + if isinstance(exc, (ConnectionError, ConnectionResetError, ConnectionRefusedError)): + return True + if isinstance(exc, (TimeoutError, asyncio.TimeoutError)): + return True + + # ── 3. String patterns ──────────────────────────────────────────────── + message = str(exc) + + for pattern in _PERMANENT_STR_PATTERNS: + if pattern in message: + return False + + for pattern in _TRANSIENT_STR_PATTERNS: + if pattern in message: + return True + + # ── 4. openai error types (optional dependency) ─────────────────────── + try: + import openai # type: ignore[import-untyped] + + # Permanent openai errors — check before transient + if isinstance(exc, openai.AuthenticationError): + return False + + # Transient openai errors + if isinstance(exc, (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError)): + return True + except ImportError: + pass + + # ── 5. Default: do not retry unknown errors ─────────────────────────── + return False + + +# --------------------------------------------------------------------------- +# transient_retry (sync) +# --------------------------------------------------------------------------- + + +def transient_retry( + func: Callable[[], T], + max_retries: int = 3, + base_delay: float = 0.5, + max_delay: float = 8.0, + jitter: bool = True, + is_retryable: Optional[Callable[[Exception], bool]] = None, +) -> T: + """Call *func* and retry on transient failures with exponential backoff. + + The delay between attempts follows the formula:: + + delay = min(base_delay * 2^attempt, max_delay) + + When ``jitter=True`` the delay is multiplied by a random factor in + ``[0.5, 1.5)`` to spread concurrent retries. + + Args: + func: Zero-argument callable to invoke. + max_retries: Maximum number of *additional* attempts after the first + failure. ``0`` disables retrying entirely. + base_delay: Initial delay in seconds before the first retry. + max_delay: Upper bound on the computed delay (seconds). + jitter: Whether to apply random jitter to the delay. + is_retryable: Optional predicate that decides whether an exception + should be retried. Defaults to ``is_transient_error``. + + Returns: + The return value of *func* on success. + + Raises: + Exception: The last exception raised by *func* after all retries are + exhausted, or immediately if the error is not retryable. + """ + _check = is_retryable if is_retryable is not None else is_transient_error + + last_exc: Exception + for attempt in range(max_retries + 1): + try: + return func() + except Exception as exc: + last_exc = exc + + if not _check(exc): + # Permanent — propagate immediately + raise + + if attempt >= max_retries: + # Retries exhausted + logger.warning( + "transient_retry: all %d retries exhausted; last error: %s", + max_retries, + exc, + ) + raise + + delay = min(base_delay * (2**attempt), max_delay) + if jitter: + delay *= 0.5 + random.random() # [0.5, 1.5) + + logger.info( + "transient_retry: attempt %d/%d failed (%s); retrying in %.2fs", + attempt + 1, + max_retries, + exc, + delay, + ) + time.sleep(delay) + + # Unreachable, but satisfies the type checker + raise last_exc # type: ignore[possibly-undefined] + + +# --------------------------------------------------------------------------- +# transient_retry_async +# --------------------------------------------------------------------------- + + +async def transient_retry_async( + coro_func: Callable[[], "asyncio.Coroutine[object, object, T]"], + max_retries: int = 3, + base_delay: float = 0.5, + max_delay: float = 8.0, + jitter: bool = True, + is_retryable: Optional[Callable[[Exception], bool]] = None, +) -> T: + """Async version of :func:`transient_retry`. + + Identical semantics to the sync variant but uses ``asyncio.sleep`` + so it does not block the event loop during backoff. + + Args: + coro_func: Zero-argument async callable (coroutine factory) to invoke. + max_retries: Maximum number of *additional* attempts after the first + failure. ``0`` disables retrying entirely. + base_delay: Initial delay in seconds before the first retry. + max_delay: Upper bound on the computed delay (seconds). + jitter: Whether to apply random jitter to the delay. + is_retryable: Optional predicate that decides whether an exception + should be retried. Defaults to ``is_transient_error``. + + Returns: + The return value of *coro_func()* on success. + + Raises: + Exception: The last exception raised by *coro_func* after all retries + are exhausted, or immediately if the error is not retryable. + """ + _check = is_retryable if is_retryable is not None else is_transient_error + + last_exc: Exception + for attempt in range(max_retries + 1): + try: + return await coro_func() + except Exception as exc: + last_exc = exc + + if not _check(exc): + raise + + if attempt >= max_retries: + logger.warning( + "transient_retry_async: all %d retries exhausted; last error: %s", + max_retries, + exc, + ) + raise + + delay = min(base_delay * (2**attempt), max_delay) + if jitter: + delay *= 0.5 + random.random() + + logger.info( + "transient_retry_async: attempt %d/%d failed (%s); retrying in %.2fs", + attempt + 1, + max_retries, + exc, + delay, + ) + await asyncio.sleep(delay) + + raise last_exc # type: ignore[possibly-undefined] diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py new file mode 100644 index 000000000..18dd26947 --- /dev/null +++ b/tests/unit/test_retry.py @@ -0,0 +1,433 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +"""Comprehensive tests for the core retry module (openviking.models.retry). + +Tests cover: +- is_transient_error: ~28 parametrized cases (14 transient, 14 permanent) +- transient_retry (sync): 8 behavioral tests +- transient_retry_async (async): 8 mirrored behavioral tests +""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, patch + +import pytest + +from openviking.models.retry import is_transient_error, transient_retry, transient_retry_async + +# --------------------------------------------------------------------------- +# Helper fake HTTP error with status_code attribute +# --------------------------------------------------------------------------- + +class _HttpError(Exception): + """Fake HTTP error carrying a numeric status code for testing.""" + + def __init__(self, status_code: int, message: str = ""): + super().__init__(message or f"HTTP {status_code}") + self.status_code = status_code + + +# --------------------------------------------------------------------------- +# is_transient_error — parametrized cases +# --------------------------------------------------------------------------- + +_TRANSIENT_CASES = [ + # HTTP status codes via _HttpError.status_code + pytest.param(_HttpError(429), True, id="http_429"), + pytest.param(_HttpError(500), True, id="http_500"), + pytest.param(_HttpError(502), True, id="http_502"), + pytest.param(_HttpError(503), True, id="http_503"), + pytest.param(_HttpError(504), True, id="http_504"), + # Built-in connection exceptions + pytest.param(ConnectionError("connection failed"), True, id="ConnectionError"), + pytest.param(ConnectionResetError("reset"), True, id="ConnectionResetError"), + pytest.param(ConnectionRefusedError("refused"), True, id="ConnectionRefusedError"), + pytest.param(TimeoutError("timed out"), True, id="TimeoutError"), + pytest.param(asyncio.TimeoutError(), True, id="asyncio_TimeoutError"), + # String-pattern transient errors + pytest.param(Exception("TooManyRequests from server"), True, id="str_TooManyRequests"), + pytest.param(Exception("RateLimit exceeded"), True, id="str_RateLimit"), + pytest.param(Exception("RequestBurstTooFast"), True, id="str_RequestBurstTooFast"), + pytest.param(Exception("request timed out after 30s"), True, id="str_timed_out"), +] + +_PERMANENT_CASES = [ + # HTTP status codes via _HttpError.status_code + pytest.param(_HttpError(400), False, id="http_400"), + pytest.param(_HttpError(401), False, id="http_401"), + pytest.param(_HttpError(403), False, id="http_403"), + pytest.param(_HttpError(404), False, id="http_404"), + pytest.param(_HttpError(422), False, id="http_422"), + # Built-in value/type errors + pytest.param(ValueError("bad value"), False, id="ValueError"), + pytest.param(TypeError("wrong type"), False, id="TypeError"), + # String-pattern permanent errors + pytest.param(Exception("InvalidRequestError: field missing"), False, id="str_InvalidRequestError"), + pytest.param(Exception("AuthenticationError: invalid key"), False, id="str_AuthenticationError"), + # Unknown errors — conservative default False + pytest.param(Exception("some unknown error"), False, id="unknown_generic"), + pytest.param(RuntimeError("unexpected state"), False, id="RuntimeError_unknown"), + pytest.param(KeyError("missing key"), False, id="KeyError"), + pytest.param(AttributeError("no attr"), False, id="AttributeError"), + pytest.param(Exception("config_value_out_of_range"), False, id="str_unknown_no_transient_keyword"), +] + + +@pytest.mark.parametrize("exc,expected", _TRANSIENT_CASES) +def test_is_transient_error_transient(exc, expected): + """Transient errors should be classified as retryable (True).""" + assert is_transient_error(exc) is expected + + +@pytest.mark.parametrize("exc,expected", _PERMANENT_CASES) +def test_is_transient_error_permanent(exc, expected): + """Permanent / unknown errors should not be retried (False).""" + assert is_transient_error(exc) is expected + + +# --------------------------------------------------------------------------- +# transient_retry (sync) +# --------------------------------------------------------------------------- + +class TestTransientRetrySync: + """Sync retry behaviour tests.""" + + def test_success_first_try(self): + """Function succeeds on first attempt — call_count == 1.""" + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + return "ok" + + result = transient_retry(fn, max_retries=3) + assert result == "ok" + assert call_count == 1 + + def test_retry_then_success(self): + """Two transient failures then success — call_count == 3.""" + errors = [_HttpError(429), _HttpError(503)] + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return "ok" + + with patch("time.sleep"): + result = transient_retry(fn, max_retries=3) + + assert result == "ok" + assert call_count == 3 + + def test_permanent_error_no_retry(self): + """Permanent error (401) should not be retried — call_count == 1 and raises.""" + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + raise _HttpError(401) + + with patch("time.sleep"): + with pytest.raises(_HttpError): + transient_retry(fn, max_retries=3) + + assert call_count == 1 + + def test_max_retries_exhausted(self): + """4 consecutive 429 errors with max_retries=3 → raises after 4 calls.""" + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with patch("time.sleep"): + with pytest.raises(_HttpError): + transient_retry(fn, max_retries=3) + + assert call_count == 4 # 1 initial + 3 retries + + def test_max_retries_zero_raises_immediately(self): + """max_retries=0 disables retrying — call_count == 1.""" + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with patch("time.sleep"): + with pytest.raises(_HttpError): + transient_retry(fn, max_retries=0) + + assert call_count == 1 + + def test_max_retries_one(self): + """max_retries=1: one failure then success → call_count == 2.""" + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _HttpError(429) + return "done" + + with patch("time.sleep"): + result = transient_retry(fn, max_retries=1) + + assert result == "done" + assert call_count == 2 + + def test_backoff_delays_exponential(self): + """Verify exponential backoff: base_delay=1.0, jitter=False → 1.0, 2.0, 4.0.""" + call_count = 0 + sleep_calls = [] + + def fn(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with patch("time.sleep", side_effect=lambda d: sleep_calls.append(d)): + with pytest.raises(_HttpError): + transient_retry(fn, max_retries=3, base_delay=1.0, max_delay=100.0, jitter=False) + + assert len(sleep_calls) == 3 + assert sleep_calls[0] == pytest.approx(1.0) + assert sleep_calls[1] == pytest.approx(2.0) + assert sleep_calls[2] == pytest.approx(4.0) + + def test_delay_capped_at_max_delay(self): + """Delays must not exceed max_delay even with many retries.""" + sleep_calls = [] + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + raise _HttpError(503) + + with patch("time.sleep", side_effect=lambda d: sleep_calls.append(d)): + with pytest.raises(_HttpError): + transient_retry( + fn, max_retries=10, base_delay=1.0, max_delay=8.0, jitter=False + ) + + assert all(d <= 8.0 for d in sleep_calls), f"Some delays exceed max_delay: {sleep_calls}" + + +# --------------------------------------------------------------------------- +# transient_retry_async (async) +# --------------------------------------------------------------------------- + +class TestTransientRetryAsync: + """Async retry behaviour tests — mirrors sync suite.""" + + async def test_success_first_try(self): + """Async function succeeds on first attempt — call_count == 1.""" + call_count = 0 + + async def coro(): + nonlocal call_count + call_count += 1 + return "ok" + + result = await transient_retry_async(coro, max_retries=3) + assert result == "ok" + assert call_count == 1 + + async def test_retry_then_success(self): + """Two transient failures then success — call_count == 3.""" + errors = [_HttpError(429), _HttpError(503)] + call_count = 0 + + async def coro(): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return "ok" + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await transient_retry_async(coro, max_retries=3) + + assert result == "ok" + assert call_count == 3 + + async def test_permanent_error_no_retry(self): + """Permanent error (401) should not be retried — call_count == 1 and raises.""" + call_count = 0 + + async def coro(): + nonlocal call_count + call_count += 1 + raise _HttpError(401) + + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(_HttpError): + await transient_retry_async(coro, max_retries=3) + + assert call_count == 1 + + async def test_max_retries_exhausted(self): + """4 consecutive 429 errors with max_retries=3 → raises after 4 calls.""" + call_count = 0 + + async def coro(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(_HttpError): + await transient_retry_async(coro, max_retries=3) + + assert call_count == 4 + + async def test_max_retries_zero_raises_immediately(self): + """max_retries=0 disables retrying — call_count == 1.""" + call_count = 0 + + async def coro(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(_HttpError): + await transient_retry_async(coro, max_retries=0) + + assert call_count == 1 + + async def test_max_retries_one(self): + """max_retries=1: one failure then success → call_count == 2.""" + call_count = 0 + + async def coro(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _HttpError(429) + return "done" + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await transient_retry_async(coro, max_retries=1) + + assert result == "done" + assert call_count == 2 + + async def test_backoff_delays_exponential(self): + """Verify exponential backoff: base_delay=1.0, jitter=False → 1.0, 2.0, 4.0.""" + call_count = 0 + sleep_calls = [] + + async def fake_sleep(d): + sleep_calls.append(d) + + async def coro(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with patch("asyncio.sleep", side_effect=fake_sleep): + with pytest.raises(_HttpError): + await transient_retry_async( + coro, max_retries=3, base_delay=1.0, max_delay=100.0, jitter=False + ) + + assert len(sleep_calls) == 3 + assert sleep_calls[0] == pytest.approx(1.0) + assert sleep_calls[1] == pytest.approx(2.0) + assert sleep_calls[2] == pytest.approx(4.0) + + async def test_delay_capped_at_max_delay(self): + """Async delays must not exceed max_delay even with many retries.""" + sleep_calls = [] + call_count = 0 + + async def fake_sleep(d): + sleep_calls.append(d) + + async def coro(): + nonlocal call_count + call_count += 1 + raise _HttpError(503) + + with patch("asyncio.sleep", side_effect=fake_sleep): + with pytest.raises(_HttpError): + await transient_retry_async( + coro, max_retries=10, base_delay=1.0, max_delay=8.0, jitter=False + ) + + assert all(d <= 8.0 for d in sleep_calls), f"Some delays exceed max_delay: {sleep_calls}" + + +# --------------------------------------------------------------------------- +# Additional edge-case tests +# --------------------------------------------------------------------------- + +class TestIsTransientErrorEdgeCases: + """Edge cases for is_transient_error.""" + + def test_timeout_substring_in_message(self): + """'timeout' substring in message → transient.""" + err = Exception("connection timeout after 10s") + assert is_transient_error(err) is True + + def test_status_code_attribute_takes_priority(self): + """status_code=503 → transient, even if message says 'bad request'.""" + err = _HttpError(503, "bad request") + assert is_transient_error(err) is True + + def test_status_code_401_permanent_priority(self): + """status_code=401 → permanent, even if message contains 'timeout'.""" + err = _HttpError(401, "timeout auth failure") + assert is_transient_error(err) is False + + def test_custom_is_retryable_overrides(self): + """Custom is_retryable callback overrides default classification.""" + # 429 is normally transient but we pass a custom fn that returns False + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + raise _HttpError(429) + + with patch("time.sleep"): + with pytest.raises(_HttpError): + transient_retry(fn, max_retries=3, is_retryable=lambda e: False) + + assert call_count == 1 # no retries because custom fn says not retryable + + def test_http_status_attribute_variant(self): + """Objects with .http_status should be checked for transient status.""" + + class AltHttpError(Exception): + def __init__(self, http_status: int): + super().__init__(f"HTTP {http_status}") + self.http_status = http_status + + assert is_transient_error(AltHttpError(503)) is True + assert is_transient_error(AltHttpError(401)) is False + + def test_code_attribute_variant(self): + """Objects with .code should be checked for transient status.""" + + class CodeError(Exception): + def __init__(self, code: int): + super().__init__(f"Error code {code}") + self.code = code + + assert is_transient_error(CodeError(429)) is True + assert is_transient_error(CodeError(403)) is False From 5221436ed10ceb6bf140eef4727d031335e84b68 Mon Sep 17 00:00:00 2001 From: Sergii Nemesh Date: Sat, 28 Mar 2026 02:46:55 +0100 Subject: [PATCH 02/10] feat(vlm): integrate unified retry into VLM backends (#922) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - VLMBase: change max_retries default 2→3, remove max_retries param from get_completion_async abstract signature - OpenAI backend: wrap all 4 methods with transient_retry/transient_retry_async, disable SDK retry (max_retries=0 in client constructors), remove manual for-loop retry - VolcEngine backend: same pattern — transient_retry for all methods, remove manual for-loop retry - LiteLLM backend: same pattern — transient_retry for all methods, remove manual for-loop retry --- openviking/models/vlm/backends/litellm_vlm.py | 48 ++++++------- openviking/models/vlm/backends/openai_vlm.py | 67 ++++++++++--------- .../models/vlm/backends/volcengine_vlm.py | 47 +++++++------ openviking/models/vlm/base.py | 4 +- 4 files changed, 82 insertions(+), 84 deletions(-) diff --git a/openviking/models/vlm/backends/litellm_vlm.py b/openviking/models/vlm/backends/litellm_vlm.py index 4ee8c8921..45135fac5 100644 --- a/openviking/models/vlm/backends/litellm_vlm.py +++ b/openviking/models/vlm/backends/litellm_vlm.py @@ -8,7 +8,6 @@ os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" -import asyncio import base64 import time from pathlib import Path @@ -18,6 +17,7 @@ from litellm import acompletion, completion from ..base import ToolCall, VLMBase, VLMResponse +from openviking.models.retry import transient_retry, transient_retry_async logger = logging.getLogger(__name__) @@ -294,8 +294,11 @@ def get_completion( kwargs = self._build_kwargs(model, kwargs_messages, tools, tool_choice, thinking=thinking) + def _call(): + return completion(**kwargs) + t0 = time.perf_counter() - response = completion(**kwargs) + response = transient_retry(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 self._update_token_usage_from_response(response, duration_seconds=elapsed) return self._build_vlm_response(response, has_tools=bool(tools)) @@ -304,7 +307,6 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, @@ -318,25 +320,17 @@ async def get_completion_async( kwargs = self._build_kwargs(model, kwargs_messages, tools, tool_choice, thinking=thinking) - last_error = None - for attempt in range(max_retries + 1): - try: - t0 = time.perf_counter() - response = await acompletion(**kwargs) - elapsed = time.perf_counter() - t0 - self._update_token_usage_from_response( - response, - duration_seconds=elapsed, - ) - return self._build_vlm_response(response, has_tools=bool(tools)) - except Exception as e: - last_error = e - if attempt < max_retries: - await asyncio.sleep(2**attempt) - - if last_error: - raise last_error - raise RuntimeError("Unknown error in async completion") + async def _call(): + return await acompletion(**kwargs) + + t0 = time.perf_counter() + response = await transient_retry_async(_call, max_retries=self.max_retries) + elapsed = time.perf_counter() - t0 + self._update_token_usage_from_response( + response, + duration_seconds=elapsed, + ) + return self._build_vlm_response(response, has_tools=bool(tools)) def get_vision_completion( self, @@ -362,8 +356,11 @@ def get_vision_completion( kwargs = self._build_kwargs(model, kwargs_messages, tools, thinking=thinking) + def _call(): + return completion(**kwargs) + t0 = time.perf_counter() - response = completion(**kwargs) + response = transient_retry(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 self._update_token_usage_from_response(response, duration_seconds=elapsed) return self._build_vlm_response(response, has_tools=bool(tools)) @@ -392,8 +389,11 @@ async def get_vision_completion_async( kwargs = self._build_kwargs(model, kwargs_messages, tools, thinking=thinking) + async def _call(): + return await acompletion(**kwargs) + t0 = time.perf_counter() - response = await acompletion(**kwargs) + response = await transient_retry_async(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 self._update_token_usage_from_response(response, duration_seconds=elapsed) return self._build_vlm_response(response, has_tools=bool(tools)) diff --git a/openviking/models/vlm/backends/openai_vlm.py b/openviking/models/vlm/backends/openai_vlm.py index 31f871315..4e0bb231d 100644 --- a/openviking/models/vlm/backends/openai_vlm.py +++ b/openviking/models/vlm/backends/openai_vlm.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 """OpenAI VLM backend implementation""" -import asyncio import base64 import json import logging @@ -13,6 +12,7 @@ from ..base import VLMBase, VLMResponse, ToolCall from ..registry import DEFAULT_AZURE_API_VERSION +from openviking.models.retry import transient_retry, transient_retry_async logger = logging.getLogger(__name__) @@ -65,6 +65,7 @@ def get_client(self): self.provider, self.api_key, self.api_base, self.api_version, self.extra_headers, ) + kwargs["max_retries"] = 0 # Disable SDK retry; we use transient_retry if self.provider == "azure": self._sync_client = openai.AzureOpenAI(**kwargs) else: @@ -82,6 +83,7 @@ def get_async_client(self): self.provider, self.api_key, self.api_base, self.api_version, self.extra_headers, ) + kwargs["max_retries"] = 0 # Disable SDK retry; we use transient_retry_async if self.provider == "azure": self._async_client = openai.AsyncAzureOpenAI(**kwargs) else: @@ -286,8 +288,11 @@ def get_completion( kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" + def _call(): + return client.chat.completions.create(**kwargs) + t0 = time.perf_counter() - response = client.chat.completions.create(**kwargs) + response = transient_retry(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 if tools: @@ -306,7 +311,6 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, @@ -332,35 +336,26 @@ async def get_completion_async( kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" - last_error = None - for attempt in range(max_retries + 1): - try: - t0 = time.perf_counter() - response = await client.chat.completions.create(**kwargs) - elapsed = time.perf_counter() - t0 - - if tools: - self._update_token_usage_from_response(response) - return self._build_vlm_response(response, has_tools=bool(tools)) - - if self.stream: - content = await self._process_streaming_response_async(response) - else: - self._update_token_usage_from_response( - response, duration_seconds=elapsed, - ) - content = self._extract_content_from_response(response) - - return self._clean_response(content) - except Exception as e: - last_error = e - if attempt < max_retries: - await asyncio.sleep(2**attempt) - - if last_error: - raise last_error + async def _call(): + return await client.chat.completions.create(**kwargs) + + t0 = time.perf_counter() + response = await transient_retry_async(_call, max_retries=self.max_retries) + elapsed = time.perf_counter() - t0 + + if tools: + self._update_token_usage_from_response(response) + return self._build_vlm_response(response, has_tools=bool(tools)) + + if self.stream: + content = await self._process_streaming_response_async(response) else: - raise RuntimeError("Unknown error in async completion") + self._update_token_usage_from_response( + response, duration_seconds=elapsed, + ) + content = self._extract_content_from_response(response) + + return self._clean_response(content) def _detect_image_format(self, data: bytes) -> str: """Detect image format from magic bytes. @@ -450,8 +445,11 @@ def get_vision_completion( kwargs["tools"] = tools kwargs["tool_choice"] = "auto" + def _call(): + return client.chat.completions.create(**kwargs) + t0 = time.perf_counter() - response = client.chat.completions.create(**kwargs) + response = transient_retry(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 if tools: @@ -502,8 +500,11 @@ async def get_vision_completion_async( kwargs["tools"] = tools kwargs["tool_choice"] = "auto" + async def _call(): + return await client.chat.completions.create(**kwargs) + t0 = time.perf_counter() - response = await client.chat.completions.create(**kwargs) + response = await transient_retry_async(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 if tools: diff --git a/openviking/models/vlm/backends/volcengine_vlm.py b/openviking/models/vlm/backends/volcengine_vlm.py index 63c8abfc5..36cd83ebb 100644 --- a/openviking/models/vlm/backends/volcengine_vlm.py +++ b/openviking/models/vlm/backends/volcengine_vlm.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 """VolcEngine VLM backend implementation""" -import asyncio import base64 import json import logging @@ -12,6 +11,7 @@ from .openai_vlm import OpenAIVLM from ..base import VLMResponse, ToolCall +from openviking.models.retry import transient_retry, transient_retry_async logger = logging.getLogger(__name__) @@ -132,8 +132,11 @@ def get_completion( kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" + def _call(): + return client.chat.completions.create(**kwargs) + t0 = time.perf_counter() - response = client.chat.completions.create(**kwargs) + response = transient_retry(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 self._update_token_usage_from_response(response, duration_seconds=elapsed) return self._build_vlm_response(response, has_tools=bool(tools)) @@ -142,7 +145,6 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, @@ -167,25 +169,16 @@ async def get_completion_async( kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" - last_error = None - for attempt in range(max_retries + 1): - try: - t0 = time.perf_counter() - response = await client.chat.completions.create(**kwargs) - elapsed = time.perf_counter() - t0 - self._update_token_usage_from_response( - response, duration_seconds=elapsed, - ) - return self._build_vlm_response(response, has_tools=bool(tools)) - except Exception as e: - last_error = e - if attempt < max_retries: - await asyncio.sleep(2**attempt) - - if last_error: - raise last_error - else: - raise RuntimeError("Unknown error in async completion") + async def _call(): + return await client.chat.completions.create(**kwargs) + + t0 = time.perf_counter() + response = await transient_retry_async(_call, max_retries=self.max_retries) + elapsed = time.perf_counter() - t0 + self._update_token_usage_from_response( + response, duration_seconds=elapsed, + ) + return self._build_vlm_response(response, has_tools=bool(tools)) def _detect_image_format(self, data: bytes) -> str: """Detect image format from magic bytes. @@ -336,8 +329,11 @@ def get_vision_completion( kwargs["tools"] = tools kwargs["tool_choice"] = "auto" + def _call(): + return client.chat.completions.create(**kwargs) + t0 = time.perf_counter() - response = client.chat.completions.create(**kwargs) + response = transient_retry(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 self._update_token_usage_from_response(response, duration_seconds=elapsed) return self._build_vlm_response(response, has_tools=bool(tools)) @@ -377,8 +373,11 @@ async def get_vision_completion_async( kwargs["tools"] = tools kwargs["tool_choice"] = "auto" + async def _call(): + return await client.chat.completions.create(**kwargs) + t0 = time.perf_counter() - response = await client.chat.completions.create(**kwargs) + response = await transient_retry_async(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 self._update_token_usage_from_response(response, duration_seconds=elapsed) return self._build_vlm_response(response, has_tools=bool(tools)) diff --git a/openviking/models/vlm/base.py b/openviking/models/vlm/base.py index 7ff039a09..b9184359d 100644 --- a/openviking/models/vlm/base.py +++ b/openviking/models/vlm/base.py @@ -58,7 +58,7 @@ def __init__(self, config: Dict[str, Any]): self.api_key = config.get("api_key") self.api_base = config.get("api_base") self.temperature = config.get("temperature", 0.0) - self.max_retries = config.get("max_retries", 2) + self.max_retries = config.get("max_retries", 3) self.max_tokens = config.get("max_tokens") self.extra_headers = config.get("extra_headers") self.stream = config.get("stream", False) @@ -94,7 +94,6 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, @@ -104,7 +103,6 @@ async def get_completion_async( Args: prompt: Text prompt (used if messages not provided) thinking: Whether to enable thinking mode - max_retries: Maximum number of retries tools: Optional list of tool definitions in OpenAI function format tool_choice: Optional tool choice mode ("auto", "none", or specific tool name) messages: Optional list of message dicts (takes precedence over prompt) From de1b97dc030cd9eee07e8eb767e6e6729acfa6d8 Mon Sep 17 00:00:00 2001 From: Sergii Nemesh Date: Sat, 28 Mar 2026 02:47:11 +0100 Subject: [PATCH 03/10] refactor(vlm): migrate call sites to kwargs, remove max_retries params (#922) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - VLMConfig: default max_retries 2→3, remove max_retries from get_completion_async signature, switch all wrappers to kwargs - StructuredVLM (llm.py): remove max_retries from complete_json_async and complete_model_async, switch all internal calls to kwargs - memory_react.py: remove max_retries=self.vlm.max_retries (now handled internally by backend) - Update test stubs to match new signatures (remove max_retries=0) --- openviking/models/vlm/llm.py | 20 +++++++++++++------- openviking/session/memory/memory_react.py | 1 - openviking_cli/utils/config/vlm_config.py | 21 ++++++++++++++------- tests/models/test_vlm_strip_think_tags.py | 2 +- tests/unit/test_extra_headers_vlm.py | 4 ++-- tests/unit/test_stream_config_vlm.py | 4 ++-- 6 files changed, 32 insertions(+), 20 deletions(-) diff --git a/openviking/models/vlm/llm.py b/openviking/models/vlm/llm.py index f2fc70dd4..314e6614c 100644 --- a/openviking/models/vlm/llm.py +++ b/openviking/models/vlm/llm.py @@ -183,7 +183,9 @@ def complete_json( if schema and not messages: prompt = f"{prompt}\n\n{get_json_schema_prompt(schema)}" - response = self._get_vlm().get_completion(prompt, thinking, tools, messages) + response = self._get_vlm().get_completion( + prompt=prompt, thinking=thinking, tools=tools, messages=messages, + ) return parse_json_from_response(response) async def complete_json_async( @@ -191,7 +193,6 @@ async def complete_json_async( prompt: str = "", schema: Optional[Dict[str, Any]] = None, thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Optional[Dict[str, Any]]: @@ -199,7 +200,9 @@ async def complete_json_async( if schema and not messages: prompt = f"{prompt}\n\n{get_json_schema_prompt(schema)}" - response = await self._get_vlm().get_completion_async(prompt, thinking, max_retries, tools, messages) + response = await self._get_vlm().get_completion_async( + prompt=prompt, thinking=thinking, tools=tools, messages=messages, + ) return parse_json_from_response(response) def complete_model( @@ -225,12 +228,11 @@ async def complete_model_async( prompt: str, model_class: Type[T], thinking: bool = False, - max_retries: int = 0, ) -> Optional[T]: """Async version of complete_model.""" schema = model_class.model_json_schema() response = await self.complete_json_async( - prompt, schema=schema, thinking=thinking, max_retries=max_retries + prompt=prompt, schema=schema, thinking=thinking, ) if response is None: return None @@ -250,7 +252,9 @@ def get_vision_completion( messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: """Get vision completion.""" - return self._get_vlm().get_vision_completion(prompt, images, thinking, tools, messages) + return self._get_vlm().get_vision_completion( + prompt=prompt, images=images, thinking=thinking, tools=tools, messages=messages, + ) async def get_vision_completion_async( self, @@ -261,4 +265,6 @@ async def get_vision_completion_async( messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: """Async vision completion.""" - return await self._get_vlm().get_vision_completion_async(prompt, images, thinking, tools, messages) + return await self._get_vlm().get_vision_completion_async( + prompt=prompt, images=images, thinking=thinking, tools=tools, messages=messages, + ) diff --git a/openviking/session/memory/memory_react.py b/openviking/session/memory/memory_react.py index 9f85838cc..e3b7d7257 100644 --- a/openviking/session/memory/memory_react.py +++ b/openviking/session/memory/memory_react.py @@ -512,7 +512,6 @@ async def _call_llm( messages=messages, tools=get_tool_schemas(), tool_choice=tool_choice, - max_retries=self.vlm.max_retries, ) # Log cache hit info diff --git a/openviking_cli/utils/config/vlm_config.py b/openviking_cli/utils/config/vlm_config.py index 9e8260d29..d404ed1ef 100644 --- a/openviking_cli/utils/config/vlm_config.py +++ b/openviking_cli/utils/config/vlm_config.py @@ -12,7 +12,7 @@ class VLMConfig(BaseModel): api_key: Optional[str] = Field(default=None, description="API key") api_base: Optional[str] = Field(default=None, description="API base URL") temperature: float = Field(default=0.0, description="Generation temperature") - max_retries: int = Field(default=2, description="Maximum retry attempts") + max_retries: int = Field(default=3, description="Maximum retry attempts") provider: Optional[str] = Field(default=None, description="Provider type") backend: Optional[str] = Field( @@ -181,18 +181,21 @@ def get_completion( messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: """Get LLM completion.""" - return self.get_vlm_instance().get_completion(prompt, thinking, tools, messages) + return self.get_vlm_instance().get_completion( + prompt=prompt, thinking=thinking, tools=tools, messages=messages, + ) async def get_completion_async( self, prompt: str = "", thinking: bool = False, - max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: - """Get LLM completion asynchronously, max_retries=0 means no retry.""" - return await self.get_vlm_instance().get_completion_async(prompt, thinking, max_retries, tools, messages) + """Get LLM completion asynchronously.""" + return await self.get_vlm_instance().get_completion_async( + prompt=prompt, thinking=thinking, tools=tools, messages=messages, + ) def is_available(self) -> bool: """Check if LLM is configured.""" @@ -207,7 +210,9 @@ def get_vision_completion( messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: """Get LLM completion with images.""" - return self.get_vlm_instance().get_vision_completion(prompt, images, thinking, tools, messages) + return self.get_vlm_instance().get_vision_completion( + prompt=prompt, images=images, thinking=thinking, tools=tools, messages=messages, + ) async def get_vision_completion_async( self, @@ -218,4 +223,6 @@ async def get_vision_completion_async( messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: """Get LLM completion with images asynchronously.""" - return await self.get_vlm_instance().get_vision_completion_async(prompt, images, thinking, tools, messages) + return await self.get_vlm_instance().get_vision_completion_async( + prompt=prompt, images=images, thinking=thinking, tools=tools, messages=messages, + ) diff --git a/tests/models/test_vlm_strip_think_tags.py b/tests/models/test_vlm_strip_think_tags.py index fd47fa1c4..da6114851 100644 --- a/tests/models/test_vlm_strip_think_tags.py +++ b/tests/models/test_vlm_strip_think_tags.py @@ -18,7 +18,7 @@ class _Stub(VLMBase): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False, max_retries=0): + async def get_completion_async(self, prompt, thinking=False): return "" def get_vision_completion(self, prompt, images, thinking=False): diff --git a/tests/unit/test_extra_headers_vlm.py b/tests/unit/test_extra_headers_vlm.py index 1d5bdafbd..29d580a29 100644 --- a/tests/unit/test_extra_headers_vlm.py +++ b/tests/unit/test_extra_headers_vlm.py @@ -210,7 +210,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False, max_retries=0): + async def get_completion_async(self, prompt, thinking=False): return "" def get_vision_completion(self, prompt, images, thinking=False): @@ -236,7 +236,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False, max_retries=0): + async def get_completion_async(self, prompt, thinking=False): return "" def get_vision_completion(self, prompt, images, thinking=False): diff --git a/tests/unit/test_stream_config_vlm.py b/tests/unit/test_stream_config_vlm.py index faf5b6e25..c28eabea2 100644 --- a/tests/unit/test_stream_config_vlm.py +++ b/tests/unit/test_stream_config_vlm.py @@ -253,7 +253,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False, max_retries=0): + async def get_completion_async(self, prompt, thinking=False): return "" def get_vision_completion(self, prompt, images, thinking=False): @@ -277,7 +277,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False, max_retries=0): + async def get_completion_async(self, prompt, thinking=False): return "" def get_vision_completion(self, prompt, images, thinking=False): From 57a0a7cf0879d3ab6532ca8f9341313d05dcce09 Mon Sep 17 00:00:00 2001 From: Sergii Nemesh Date: Sat, 28 Mar 2026 02:47:23 +0100 Subject: [PATCH 04/10] test(vlm): add VLM retry integration tests (#922) Tests cover OpenAI backend as representative: - Completion retries on 429, does NOT retry on 401 - Vision completion now retries (was zero before) - Config max_retries is used (default=3) - max_retries removed from get_completion_async signature (all backends) - OpenAI SDK retry disabled (max_retries=0 in client constructors) --- tests/unit/test_vlm_retry_integration.py | 299 +++++++++++++++++++++++ 1 file changed, 299 insertions(+) create mode 100644 tests/unit/test_vlm_retry_integration.py diff --git a/tests/unit/test_vlm_retry_integration.py b/tests/unit/test_vlm_retry_integration.py new file mode 100644 index 000000000..647a95820 --- /dev/null +++ b/tests/unit/test_vlm_retry_integration.py @@ -0,0 +1,299 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +"""Integration tests for VLM backends with unified retry logic. + +Tests cover (using OpenAI backend as representative): +- completion retries on 429 (transient) +- completion does NOT retry on 401 (permanent) +- vision completion now retries (was zero before) +- uses config max_retries +- max_retries parameter removed from get_completion_async signature +""" + +from __future__ import annotations + +import inspect +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class _HttpError(Exception): + """Fake HTTP error carrying a numeric status code.""" + + def __init__(self, status_code: int, message: str = ""): + super().__init__(message or f"HTTP {status_code}") + self.status_code = status_code + + +def _make_fake_response(content: str = "ok") -> SimpleNamespace: + """Build a minimal fake OpenAI ChatCompletion response.""" + message = SimpleNamespace(content=content, tool_calls=None) + choice = SimpleNamespace(message=message, finish_reason="stop") + usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15) + return SimpleNamespace(choices=[choice], usage=usage) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture() +def openai_vlm(): + """Create an OpenAIVLM instance with mocked clients.""" + from openviking.models.vlm.backends.openai_vlm import OpenAIVLM + + vlm = OpenAIVLM({ + "api_key": "sk-test", + "model": "gpt-4o-mini", + "provider": "openai", + "max_retries": 2, + }) + + # Mock sync client + mock_sync = MagicMock() + vlm._sync_client = mock_sync + + # Mock async client + mock_async = MagicMock() + vlm._async_client = mock_async + + return vlm + + +# --------------------------------------------------------------------------- +# Tests: get_completion_async retries on 429 +# --------------------------------------------------------------------------- + +class TestCompletionAsyncRetries: + + async def test_retries_on_429(self, openai_vlm): + """get_completion_async should retry on 429 (transient) and succeed.""" + errors = [_HttpError(429), _HttpError(429)] + call_count = 0 + + async def fake_create(**kwargs): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return _make_fake_response("success") + + openai_vlm._async_client.chat.completions.create = fake_create + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await openai_vlm.get_completion_async(prompt="test") + + assert result == "success" + assert call_count == 3 # 2 failures + 1 success + + async def test_no_retry_on_401(self, openai_vlm): + """get_completion_async should NOT retry on 401 (permanent).""" + call_count = 0 + + async def fake_create(**kwargs): + nonlocal call_count + call_count += 1 + raise _HttpError(401, "Unauthorized") + + openai_vlm._async_client.chat.completions.create = fake_create + + with patch("asyncio.sleep", new_callable=AsyncMock): + with pytest.raises(_HttpError): + await openai_vlm.get_completion_async(prompt="test") + + assert call_count == 1 # no retries + + async def test_uses_config_max_retries(self): + """Backend should use self.max_retries from config, not a param.""" + from openviking.models.vlm.backends.openai_vlm import OpenAIVLM + + vlm = OpenAIVLM({ + "api_key": "sk-test", + "model": "gpt-4o-mini", + "provider": "openai", + "max_retries": 5, + }) + assert vlm.max_retries == 5 + + # Config default is now 3 + vlm2 = OpenAIVLM({ + "api_key": "sk-test", + "model": "gpt-4o-mini", + "provider": "openai", + }) + assert vlm2.max_retries == 3 + + +# --------------------------------------------------------------------------- +# Tests: get_vision_completion_async now retries +# --------------------------------------------------------------------------- + +class TestVisionCompletionAsyncRetries: + + async def test_vision_retries_on_429(self, openai_vlm): + """get_vision_completion_async should retry on 429 (was zero retry before).""" + errors = [_HttpError(429)] + call_count = 0 + + async def fake_create(**kwargs): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return _make_fake_response("vision ok") + + openai_vlm._async_client.chat.completions.create = fake_create + + with patch("asyncio.sleep", new_callable=AsyncMock): + result = await openai_vlm.get_vision_completion_async( + prompt="describe", images=["http://example.com/img.png"], + ) + + assert result == "vision ok" + assert call_count == 2 # 1 failure + 1 success + + +# --------------------------------------------------------------------------- +# Tests: sync completion retries +# --------------------------------------------------------------------------- + +class TestCompletionSyncRetries: + + def test_sync_retries_on_429(self, openai_vlm): + """get_completion should retry on 429.""" + errors = [_HttpError(429)] + call_count = 0 + + def fake_create(**kwargs): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return _make_fake_response("sync ok") + + openai_vlm._sync_client.chat.completions.create = fake_create + + with patch("time.sleep"): + result = openai_vlm.get_completion(prompt="test") + + assert result == "sync ok" + assert call_count == 2 + + def test_sync_vision_retries_on_503(self, openai_vlm): + """get_vision_completion should retry on 503.""" + errors = [_HttpError(503)] + call_count = 0 + + def fake_create(**kwargs): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return _make_fake_response("vision sync ok") + + openai_vlm._sync_client.chat.completions.create = fake_create + + with patch("time.sleep"): + result = openai_vlm.get_vision_completion( + prompt="describe", images=["http://example.com/img.png"], + ) + + assert result == "vision sync ok" + assert call_count == 2 + + +# --------------------------------------------------------------------------- +# Tests: signature change verification +# --------------------------------------------------------------------------- + +class TestSignatureChange: + + def test_no_max_retries_in_get_completion_async(self): + """get_completion_async should no longer accept max_retries parameter.""" + from openviking.models.vlm.backends.openai_vlm import OpenAIVLM + + sig = inspect.signature(OpenAIVLM.get_completion_async) + param_names = list(sig.parameters.keys()) + + assert "max_retries" not in param_names, ( + f"max_retries should be removed from get_completion_async, got params: {param_names}" + ) + + def test_no_max_retries_in_base_get_completion_async(self): + """VLMBase.get_completion_async should no longer accept max_retries parameter.""" + from openviking.models.vlm.base import VLMBase + + sig = inspect.signature(VLMBase.get_completion_async) + param_names = list(sig.parameters.keys()) + + assert "max_retries" not in param_names, ( + f"max_retries should be removed from VLMBase.get_completion_async, got params: {param_names}" + ) + + def test_no_max_retries_in_litellm_get_completion_async(self): + """LiteLLMVLMProvider.get_completion_async should no longer accept max_retries.""" + from openviking.models.vlm.backends.litellm_vlm import LiteLLMVLMProvider + + sig = inspect.signature(LiteLLMVLMProvider.get_completion_async) + param_names = list(sig.parameters.keys()) + + assert "max_retries" not in param_names + + def test_no_max_retries_in_volcengine_get_completion_async(self): + """VolcEngineVLM.get_completion_async should no longer accept max_retries.""" + from openviking.models.vlm.backends.volcengine_vlm import VolcEngineVLM + + sig = inspect.signature(VolcEngineVLM.get_completion_async) + param_names = list(sig.parameters.keys()) + + assert "max_retries" not in param_names + + +# --------------------------------------------------------------------------- +# Tests: OpenAI SDK retry disabled +# --------------------------------------------------------------------------- + +class TestOpenAISDKRetryDisabled: + + def test_sync_client_max_retries_zero(self): + """OpenAI sync client should be created with max_retries=0.""" + from openviking.models.vlm.backends.openai_vlm import OpenAIVLM + + vlm = OpenAIVLM({ + "api_key": "sk-test", + "model": "gpt-4o-mini", + "provider": "openai", + }) + + with patch("openai.OpenAI") as mock_openai: + mock_openai.return_value = MagicMock() + vlm._sync_client = None # force re-creation + vlm.get_client() + call_kwargs = mock_openai.call_args + assert call_kwargs[1].get("max_retries") == 0 or \ + (len(call_kwargs[0]) == 0 and call_kwargs.kwargs.get("max_retries") == 0) + + def test_async_client_max_retries_zero(self): + """OpenAI async client should be created with max_retries=0.""" + from openviking.models.vlm.backends.openai_vlm import OpenAIVLM + + vlm = OpenAIVLM({ + "api_key": "sk-test", + "model": "gpt-4o-mini", + "provider": "openai", + }) + + with patch("openai.AsyncOpenAI") as mock_async_openai: + mock_async_openai.return_value = MagicMock() + vlm._async_client = None # force re-creation + vlm.get_async_client() + call_kwargs = mock_async_openai.call_args + assert call_kwargs[1].get("max_retries") == 0 or \ + (len(call_kwargs[0]) == 0 and call_kwargs.kwargs.get("max_retries") == 0) From 591755999a96c977b2795f4b04410fcb18d064a1 Mon Sep 17 00:00:00 2001 From: Sergii Nemesh Date: Sat, 28 Mar 2026 02:57:17 +0100 Subject: [PATCH 05/10] =?UTF-8?q?feat(embedding):=20=D0=B4=D0=BE=D0=B1?= =?UTF-8?q?=D0=B0=D0=B2=D0=B8=D1=82=D1=8C=20max=5Fretries=20=D0=B2=20Embed?= =?UTF-8?q?dingConfig=20=D0=B8=20EmbedderBase?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - EmbeddingConfig: новое поле max_retries (default=3) для конфигурации retry - EmbeddingConfig._create_embedder(): инжектирует max_retries в params["config"] - EmbedderBase.__init__(): извлекает max_retries из config dict --- openviking/models/embedder/base.py | 1 + openviking_cli/utils/config/embedding_config.py | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/openviking/models/embedder/base.py b/openviking/models/embedder/base.py index a8c23a5f7..8c2f00d01 100644 --- a/openviking/models/embedder/base.py +++ b/openviking/models/embedder/base.py @@ -74,6 +74,7 @@ def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): """ self.model_name = model_name self.config = config or {} + self.max_retries = self.config.get("max_retries", 3) if self.config else 3 @abstractmethod def embed(self, text: str, is_query: bool = False) -> EmbedResult: diff --git a/openviking_cli/utils/config/embedding_config.py b/openviking_cli/utils/config/embedding_config.py index 52025ecd2..2d3029ce8 100644 --- a/openviking_cli/utils/config/embedding_config.py +++ b/openviking_cli/utils/config/embedding_config.py @@ -250,6 +250,9 @@ class EmbeddingConfig(BaseModel): sparse: Optional[EmbeddingModelConfig] = Field(default=None) hybrid: Optional[EmbeddingModelConfig] = Field(default=None) + max_retries: int = Field( + default=3, description="Maximum retry attempts for transient errors" + ) max_concurrent: int = Field( default=10, description="Maximum number of concurrent embedding requests" ) @@ -487,6 +490,13 @@ def _create_embedder( embedder_class, param_builder = factory_registry[key] params = param_builder(config) + + # Inject max_retries into the config dict so embedders pick it up + existing_config = params.get("config") or {} + if isinstance(existing_config, dict): + existing_config["max_retries"] = self.max_retries + params["config"] = existing_config + return embedder_class(**params) def get_embedder(self): From ced4ce85c6fef3b0f77f697b0b910ebfe773480c Mon Sep 17 00:00:00 2001 From: Sergii Nemesh Date: Sat, 28 Mar 2026 02:57:33 +0100 Subject: [PATCH 06/10] =?UTF-8?q?feat(embedding):=20=D0=BF=D0=B5=D1=80?= =?UTF-8?q?=D0=B5=D0=B2=D0=B5=D1=81=D1=82=D0=B8=20=D0=B2=D1=81=D0=B5=20emb?= =?UTF-8?q?edding=20=D0=BF=D1=80=D0=BE=D0=B2=D0=B0=D0=B9=D0=B4=D0=B5=D1=80?= =?UTF-8?q?=D1=8B=20=D0=BD=D0=B0=20transient=5Fretry?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - OpenAI: отключить SDK retry (max_retries=0), обернуть embed/embed_batch - Volcengine: заменить exponential_backoff_retry на transient_retry, убрать is_429_error - VikingDB: добавить transient_retry (ранее retry отсутствовал) - Gemini: отключить SDK HttpRetryOptions (attempts=1), обернуть embed/embed_batch - MiniMax: отключить urllib3 Retry (total=0), обернуть embed/embed_batch - Jina: отключить SDK retry (max_retries=0), обернуть embed/embed_batch - Voyage: отключить SDK retry (max_retries=0), обернуть embed/embed_batch - LiteLLM: обернуть litellm.embedding() вызовы Все провайдеры теперь используют единый transient_retry с is_transient_error для классификации ошибок. Wrapper размещён ВНУТРИ метода вокруг raw API call, ДО try/except который конвертирует в RuntimeError. --- .../models/embedder/gemini_embedders.py | 36 ++++++++----- openviking/models/embedder/jina_embedders.py | 13 ++++- .../models/embedder/litellm_embedders.py | 13 ++++- .../models/embedder/minimax_embedders.py | 21 ++++---- .../models/embedder/openai_embedders.py | 16 ++++-- .../models/embedder/vikingdb_embedders.py | 35 +++++++++--- .../models/embedder/volcengine_embedders.py | 53 +++---------------- .../models/embedder/voyage_embedders.py | 13 ++++- 8 files changed, 114 insertions(+), 86 deletions(-) diff --git a/openviking/models/embedder/gemini_embedders.py b/openviking/models/embedder/gemini_embedders.py index 878fa80f5..a3265afa3 100644 --- a/openviking/models/embedder/gemini_embedders.py +++ b/openviking/models/embedder/gemini_embedders.py @@ -29,6 +29,7 @@ EmbedResult, truncate_and_normalize, ) +from openviking.models.retry import transient_retry logger = logging.getLogger("gemini_embedders") @@ -146,15 +147,13 @@ def __init__( ) if dimension is not None and not (1 <= dimension <= 3072): raise ValueError(f"dimension must be between 1 and 3072, got {dimension}") + # Disable SDK-level retry; we use transient_retry for unified retry logic if _HTTP_RETRY_AVAILABLE: self.client = genai.Client( api_key=api_key, http_options=HttpOptions( retry_options=HttpRetryOptions( - attempts=3, - initial_delay=1.0, - max_delay=30.0, - exp_base=2.0, + attempts=1, ) ), ) @@ -209,11 +208,16 @@ def embed( task_type = self.document_param # SDK accepts plain str; converts to REST Parts format internally. try: - result = self.client.models.embed_content( - model=self.model_name, - contents=text, - config=self._build_config(task_type=task_type, title=title), - ) + embed_config = self._build_config(task_type=task_type, title=title) + + def _call(): + return self.client.models.embed_content( + model=self.model_name, + contents=text, + config=embed_config, + ) + + result = transient_retry(_call, max_retries=self.max_retries) vector = truncate_and_normalize(list(result.embeddings[0].values), self._dimension) return EmbedResult(dense_vector=vector) except (APIError, ClientError) as e: @@ -254,11 +258,15 @@ def embed_batch( non_empty_texts = [batch[j] for j in non_empty_indices] try: - response = self.client.models.embed_content( - model=self.model_name, - contents=non_empty_texts, - config=config, - ) + + def _batch_call(texts=non_empty_texts, cfg=config): + return self.client.models.embed_content( + model=self.model_name, + contents=texts, + config=cfg, + ) + + response = transient_retry(_batch_call, max_retries=self.max_retries) batch_results = [None] * len(batch) for j, emb in zip(non_empty_indices, response.embeddings): batch_results[j] = EmbedResult( diff --git a/openviking/models/embedder/jina_embedders.py b/openviking/models/embedder/jina_embedders.py index 84713ba21..9aa928e44 100644 --- a/openviking/models/embedder/jina_embedders.py +++ b/openviking/models/embedder/jina_embedders.py @@ -10,6 +10,7 @@ DenseEmbedderBase, EmbedResult, ) +from openviking.models.retry import transient_retry # Default dimensions for Jina embedding models JINA_MODEL_DIMENSIONS = { @@ -113,9 +114,11 @@ def __init__( raise ValueError("api_key is required") # Initialize OpenAI-compatible client with Jina base URL + # Disable SDK retry; we use transient_retry for unified retry logic self.client = openai.OpenAI( api_key=self.api_key, base_url=self.api_base, + max_retries=0, ) # Determine dimension @@ -174,7 +177,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: if extra_body: kwargs["extra_body"] = extra_body - response = self.client.embeddings.create(**kwargs) + def _call(): + return self.client.embeddings.create(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) vector = response.data[0].embedding return EmbedResult(dense_vector=vector) @@ -209,7 +215,10 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if extra_body: kwargs["extra_body"] = extra_body - response = self.client.embeddings.create(**kwargs) + def _call(): + return self.client.embeddings.create(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) return [EmbedResult(dense_vector=item.embedding) for item in response.data] except openai.APIError as e: diff --git a/openviking/models/embedder/litellm_embedders.py b/openviking/models/embedder/litellm_embedders.py index 4f10f99c0..a3e80c52d 100644 --- a/openviking/models/embedder/litellm_embedders.py +++ b/openviking/models/embedder/litellm_embedders.py @@ -13,6 +13,7 @@ import litellm from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult +from openviking.models.retry import transient_retry from openviking.telemetry import get_current_telemetry logger = logging.getLogger(__name__) @@ -157,7 +158,11 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: try: kwargs = self._build_kwargs(is_query=is_query) kwargs["input"] = [text] - response = litellm.embedding(**kwargs) + + def _call(): + return litellm.embedding(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) self._update_telemetry_token_usage(response) vector = response.data[0]["embedding"] return EmbedResult(dense_vector=vector) @@ -183,7 +188,11 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes try: kwargs = self._build_kwargs(is_query=is_query) kwargs["input"] = texts - response = litellm.embedding(**kwargs) + + def _call(): + return litellm.embedding(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) self._update_telemetry_token_usage(response) return [EmbedResult(dense_vector=item["embedding"]) for item in response.data] except Exception as e: diff --git a/openviking/models/embedder/minimax_embedders.py b/openviking/models/embedder/minimax_embedders.py index faec9b913..d986d8a63 100644 --- a/openviking/models/embedder/minimax_embedders.py +++ b/openviking/models/embedder/minimax_embedders.py @@ -9,6 +9,7 @@ from urllib3.util.retry import Retry from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult +from openviking.models.retry import transient_retry from openviking_cli.utils.logger import default_logger as logger @@ -89,12 +90,8 @@ def __init__( def _create_session(self) -> requests.Session: """Create a requests session with retry logic""" session = requests.Session() - retry_strategy = Retry( - total=6, - backoff_factor=1, # 1s, 2s, 4s, 8s, 16s, 32s - status_forcelist=[429, 500, 502, 503, 504], - allowed_methods=["POST"], - ) + # Disable transport-level retry; we use transient_retry for unified retry logic + retry_strategy = Retry(total=0) adapter = HTTPAdapter(max_retries=retry_strategy) session.mount("https://", adapter) session.mount("http://", adapter) @@ -163,7 +160,10 @@ def _call_api(self, texts: List[str], is_query: bool = False) -> List[List[float def embed(self, text: str, is_query: bool = False) -> EmbedResult: """Perform dense embedding on text""" - vectors = self._call_api([text], is_query=is_query) + vectors = transient_retry( + lambda: self._call_api([text], is_query=is_query), + max_retries=self.max_retries, + ) return EmbedResult(dense_vector=vectors[0]) def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: @@ -171,9 +171,10 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if not texts: return [] - # MiniMax might have batch size limits, but let's assume the caller handles batching or use safe defaults - # For now, we pass through. If needed, we can implement internal chunking. - vectors = self._call_api(texts, is_query=is_query) + vectors = transient_retry( + lambda: self._call_api(texts, is_query=is_query), + max_retries=self.max_retries, + ) return [EmbedResult(dense_vector=v) for v in vectors] def get_dimension(self) -> int: diff --git a/openviking/models/embedder/openai_embedders.py b/openviking/models/embedder/openai_embedders.py index c57ec9ff3..408ccfe7a 100644 --- a/openviking/models/embedder/openai_embedders.py +++ b/openviking/models/embedder/openai_embedders.py @@ -13,6 +13,7 @@ HybridEmbedderBase, SparseEmbedderBase, ) +from openviking.models.retry import transient_retry from openviking.telemetry import get_current_telemetry @@ -118,7 +119,10 @@ def __init__( if not self.api_key and not self.api_base: raise ValueError("api_key is required") - client_kwargs: Dict[str, Any] = {"api_key": self.api_key or "no-key"} + client_kwargs: Dict[str, Any] = { + "api_key": self.api_key or "no-key", + "max_retries": 0, # Disable SDK retry; we use transient_retry + } if self._provider == "azure": if not self.api_base: raise ValueError("api_base (Azure endpoint) is required for Azure provider") @@ -242,7 +246,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: if extra_body: kwargs["extra_body"] = extra_body - response = self.client.embeddings.create(**kwargs) + def _call(): + return self.client.embeddings.create(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) self._update_telemetry_token_usage(response) vector = response.data[0].embedding @@ -277,7 +284,10 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if extra_body: kwargs["extra_body"] = extra_body - response = self.client.embeddings.create(**kwargs) + def _call(): + return self.client.embeddings.create(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) self._update_telemetry_token_usage(response) return [EmbedResult(dense_vector=item.embedding) for item in response.data] diff --git a/openviking/models/embedder/vikingdb_embedders.py b/openviking/models/embedder/vikingdb_embedders.py index d28aac492..f0af1185b 100644 --- a/openviking/models/embedder/vikingdb_embedders.py +++ b/openviking/models/embedder/vikingdb_embedders.py @@ -10,6 +10,7 @@ HybridEmbedderBase, SparseEmbedderBase, ) +from openviking.models.retry import transient_retry from openviking.storage.vectordb.collection.volcengine_clients import ClientForDataApi from openviking_cli.utils.logger import default_logger as logger @@ -124,7 +125,10 @@ def __init__( self.dense_model = {"name": model_name, "version": model_version, "dim": dimension} def embed(self, text: str, is_query: bool = False) -> EmbedResult: - results = self._call_api([text], dense_model=self.dense_model) + results = transient_retry( + lambda: self._call_api([text], dense_model=self.dense_model), + max_retries=self.max_retries, + ) if not results: return EmbedResult(dense_vector=[]) @@ -138,7 +142,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: if not texts: return [] - raw_results = self._call_api(texts, dense_model=self.dense_model) + raw_results = transient_retry( + lambda: self._call_api(texts, dense_model=self.dense_model), + max_retries=self.max_retries, + ) return [ EmbedResult( dense_vector=self._truncate_and_normalize( @@ -174,7 +181,10 @@ def __init__( } def embed(self, text: str, is_query: bool = False) -> EmbedResult: - results = self._call_api([text], sparse_model=self.sparse_model) + results = transient_retry( + lambda: self._call_api([text], sparse_model=self.sparse_model), + max_retries=self.max_retries, + ) if not results: return EmbedResult(sparse_vector={}) @@ -188,7 +198,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: if not texts: return [] - raw_results = self._call_api(texts, sparse_model=self.sparse_model) + raw_results = transient_retry( + lambda: self._call_api(texts, sparse_model=self.sparse_model), + max_retries=self.max_retries, + ) return [ EmbedResult( sparse_vector=self._process_sparse_embedding(item.get("sparse_embedding", {})) @@ -224,8 +237,11 @@ def __init__( } def embed(self, text: str, is_query: bool = False) -> EmbedResult: - results = self._call_api( - [text], dense_model=self.dense_model, sparse_model=self.sparse_model + results = transient_retry( + lambda: self._call_api( + [text], dense_model=self.dense_model, sparse_model=self.sparse_model + ), + max_retries=self.max_retries, ) if not results: return EmbedResult(dense_vector=[], sparse_vector={}) @@ -244,8 +260,11 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: if not texts: return [] - raw_results = self._call_api( - texts, dense_model=self.dense_model, sparse_model=self.sparse_model + raw_results = transient_retry( + lambda: self._call_api( + texts, dense_model=self.dense_model, sparse_model=self.sparse_model + ), + max_retries=self.max_retries, ) results = [] for item in raw_results: diff --git a/openviking/models/embedder/volcengine_embedders.py b/openviking/models/embedder/volcengine_embedders.py index c2384ec88..7075f7583 100644 --- a/openviking/models/embedder/volcengine_embedders.py +++ b/openviking/models/embedder/volcengine_embedders.py @@ -11,29 +11,13 @@ EmbedResult, HybridEmbedderBase, SparseEmbedderBase, - exponential_backoff_retry, truncate_and_normalize, ) +from openviking.models.retry import transient_retry from openviking.telemetry import get_current_telemetry from openviking_cli.utils.logger import default_logger as logger -def is_429_error(exception: Exception) -> bool: - """ - 判断异常是否为 429 限流错误 - - Args: - exception: 要检查的异常 - - Returns: - 如果是 429 错误则返回 True,否则返回 False - """ - exception_str = str(exception) - return ( - "429" in exception_str or "TooManyRequests" in exception_str or "RateLimit" in exception_str - ) - - def process_sparse_embedding(sparse_data: Any) -> Dict[str, float]: """Process sparse embedding data from SDK response""" if not sparse_data: @@ -177,15 +161,7 @@ def _embed_call(): return EmbedResult(dense_vector=vector) try: - return exponential_backoff_retry( - _embed_call, - max_wait=10.0, - base_delay=0.5, - max_delay=2.0, - jitter=True, - is_retryable=is_429_error, - logger=logger, - ) + return transient_retry(_embed_call, max_retries=self.max_retries) except Exception as e: raise RuntimeError(f"Volcengine embedding failed: {str(e)}") from e @@ -205,7 +181,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if not texts: return [] - try: + def _batch_call(): if self.input_type == "multimodal": multimodal_inputs = [{"type": "text", "text": text} for text in texts] response = self.client.multimodal_embeddings.create( @@ -222,6 +198,9 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes EmbedResult(dense_vector=truncate_and_normalize(item.embedding, self.dimension)) for item in data ] + + try: + return transient_retry(_batch_call, max_retries=self.max_retries) except Exception as e: logger.error( f"Volcengine batch embedding failed, texts length: {len(texts)}, input_type: {self.input_type}, model_name: {self.model_name}" @@ -295,15 +274,7 @@ def _embed_call(): return EmbedResult(sparse_vector=process_sparse_embedding(sparse_vector)) try: - return exponential_backoff_retry( - _embed_call, - max_wait=10.0, - base_delay=0.5, - max_delay=2.0, - jitter=True, - is_retryable=is_429_error, - logger=logger, - ) + return transient_retry(_embed_call, max_retries=self.max_retries) except Exception as e: raise RuntimeError(f"Volcengine sparse embedding failed: {str(e)}") from e @@ -400,15 +371,7 @@ def _embed_call(): ) try: - return exponential_backoff_retry( - _embed_call, - max_wait=10.0, - base_delay=0.5, - max_delay=2.0, - jitter=True, - is_retryable=is_429_error, - logger=logger, - ) + return transient_retry(_embed_call, max_retries=self.max_retries) except Exception as e: raise RuntimeError(f"Volcengine hybrid embedding failed: {str(e)}") from e diff --git a/openviking/models/embedder/voyage_embedders.py b/openviking/models/embedder/voyage_embedders.py index e866f0a18..2478e1515 100644 --- a/openviking/models/embedder/voyage_embedders.py +++ b/openviking/models/embedder/voyage_embedders.py @@ -7,6 +7,7 @@ import openai from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult +from openviking.models.retry import transient_retry VOYAGE_MODEL_DIMENSIONS = { "voyage-3": 1024, @@ -74,9 +75,11 @@ def __init__( f"Supported dimensions: {supported}." ) + # Disable SDK retry; we use transient_retry for unified retry logic self.client = openai.OpenAI( api_key=self.api_key, base_url=self.api_base, + max_retries=0, ) self._dimension = dimension or get_voyage_model_default_dimension(normalized_model_name) @@ -88,7 +91,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: if self.dimension is not None: kwargs["extra_body"] = {"output_dimension": self.dimension} - response = self.client.embeddings.create(**kwargs) + def _call(): + return self.client.embeddings.create(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) vector = response.data[0].embedding return EmbedResult(dense_vector=vector) except openai.APIError as e: @@ -106,7 +112,10 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if self.dimension is not None: kwargs["extra_body"] = {"output_dimension": self.dimension} - response = self.client.embeddings.create(**kwargs) + def _call(): + return self.client.embeddings.create(**kwargs) + + response = transient_retry(_call, max_retries=self.max_retries) return [EmbedResult(dense_vector=item.embedding) for item in response.data] except openai.APIError as e: raise RuntimeError(f"Voyage API error: {e.message}") from e From b59fd97409ba2a4ed4aba119662cc731fb92df90 Mon Sep 17 00:00:00 2001 From: Sergii Nemesh Date: Sat, 28 Mar 2026 02:57:45 +0100 Subject: [PATCH 07/10] =?UTF-8?q?test(embedding):=20=D1=82=D0=B5=D1=81?= =?UTF-8?q?=D1=82=D1=8B=20retry=20=D0=B4=D0=BB=D1=8F=20embedding=20=D0=BF?= =?UTF-8?q?=D1=80=D0=BE=D0=B2=D0=B0=D0=B9=D0=B4=D0=B5=D1=80=D0=BE=D0=B2=20?= =?UTF-8?q?=D0=B8=20backward=20compat?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - test_embedding_retry_integration: OpenAI и VikingDB retry на transient/permanent ошибки - test_retry_config: VLMConfig и EmbeddingConfig max_retries поля и defaults - test_backward_compat: exponential_backoff_retry importable, signature unchanged, time-based --- tests/unit/test_backward_compat.py | 172 +++++++++++++ .../unit/test_embedding_retry_integration.py | 236 ++++++++++++++++++ tests/unit/test_retry_config.py | 85 +++++++ 3 files changed, 493 insertions(+) create mode 100644 tests/unit/test_backward_compat.py create mode 100644 tests/unit/test_embedding_retry_integration.py create mode 100644 tests/unit/test_retry_config.py diff --git a/tests/unit/test_backward_compat.py b/tests/unit/test_backward_compat.py new file mode 100644 index 000000000..cb9db9a46 --- /dev/null +++ b/tests/unit/test_backward_compat.py @@ -0,0 +1,172 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +"""Backward compatibility tests for the retry migration. + +Verifies that: +- exponential_backoff_retry is still importable from the old location (base.py) +- exponential_backoff_retry signature is unchanged +- exponential_backoff_retry behaviour still works (time-based) +- transient_retry is count-based (different semantics) +""" + +from __future__ import annotations + +import inspect +import time +from unittest.mock import patch + +import pytest + + +class _HttpError(Exception): + """Fake HTTP error carrying a numeric status code.""" + + def __init__(self, status_code: int, message: str = ""): + super().__init__(message or f"HTTP {status_code}") + self.status_code = status_code + + +class TestExponentialBackoffRetryImportable: + + def test_importable_from_old_location(self): + """exponential_backoff_retry should still be importable from base.py.""" + from openviking.models.embedder.base import exponential_backoff_retry + + assert callable(exponential_backoff_retry) + + +class TestExponentialBackoffRetrySignature: + + def test_signature_unchanged(self): + """exponential_backoff_retry should retain its original signature.""" + from openviking.models.embedder.base import exponential_backoff_retry + + sig = inspect.signature(exponential_backoff_retry) + param_names = list(sig.parameters.keys()) + + expected_params = [ + "func", + "max_wait", + "base_delay", + "max_delay", + "jitter", + "is_retryable", + "logger", + ] + + assert param_names == expected_params, ( + f"exponential_backoff_retry signature changed.\n" + f"Expected: {expected_params}\n" + f"Got: {param_names}" + ) + + def test_defaults_unchanged(self): + """Default parameter values should be preserved.""" + from openviking.models.embedder.base import exponential_backoff_retry + + sig = inspect.signature(exponential_backoff_retry) + params = sig.parameters + + assert params["max_wait"].default == 10.0 + assert params["base_delay"].default == 0.5 + assert params["max_delay"].default == 2.0 + assert params["jitter"].default is True + assert params["is_retryable"].default is None + assert params["logger"].default is None + + +class TestExponentialBackoffRetryBehavior: + + def test_success_first_try(self): + """Function succeeds on first attempt.""" + from openviking.models.embedder.base import exponential_backoff_retry + + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + return "ok" + + result = exponential_backoff_retry(fn) + assert result == "ok" + assert call_count == 1 + + def test_retries_on_failure(self): + """Function retries on failure until success.""" + from openviking.models.embedder.base import exponential_backoff_retry + + errors = [Exception("fail"), Exception("fail")] + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return "ok" + + with patch("time.sleep"): + result = exponential_backoff_retry(fn, max_wait=10.0) + + assert result == "ok" + assert call_count == 3 + + def test_is_time_based(self): + """exponential_backoff_retry should be time-based (uses max_wait, not count).""" + from openviking.models.embedder.base import exponential_backoff_retry + + sig = inspect.signature(exponential_backoff_retry) + param_names = list(sig.parameters.keys()) + + # Time-based: has max_wait, no max_retries + assert "max_wait" in param_names + assert "max_retries" not in param_names + + def test_respects_is_retryable(self): + """exponential_backoff_retry should respect is_retryable callback.""" + from openviking.models.embedder.base import exponential_backoff_retry + + call_count = 0 + + def fn(): + nonlocal call_count + call_count += 1 + raise ValueError("permanent") + + # is_retryable returns False => no retry + with patch("time.sleep"): + with pytest.raises(ValueError): + exponential_backoff_retry(fn, is_retryable=lambda e: False) + + assert call_count == 1 + + +class TestTransientRetryIsCountBased: + + def test_is_count_based(self): + """transient_retry should be count-based (uses max_retries, not max_wait).""" + from openviking.models.retry import transient_retry + + sig = inspect.signature(transient_retry) + param_names = list(sig.parameters.keys()) + + # Count-based: has max_retries, no max_wait + assert "max_retries" in param_names + assert "max_wait" not in param_names + + def test_different_from_backoff_retry(self): + """transient_retry and exponential_backoff_retry should have different signatures.""" + from openviking.models.embedder.base import exponential_backoff_retry + from openviking.models.retry import transient_retry + + backoff_params = set(inspect.signature(exponential_backoff_retry).parameters.keys()) + retry_params = set(inspect.signature(transient_retry).parameters.keys()) + + # They share 'func', 'base_delay', 'max_delay', 'jitter', 'is_retryable' + # but differ on time vs count control params + assert "max_wait" in backoff_params + assert "max_wait" not in retry_params + assert "max_retries" in retry_params + assert "max_retries" not in backoff_params diff --git a/tests/unit/test_embedding_retry_integration.py b/tests/unit/test_embedding_retry_integration.py new file mode 100644 index 000000000..17b381f18 --- /dev/null +++ b/tests/unit/test_embedding_retry_integration.py @@ -0,0 +1,236 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +"""Integration tests for embedding providers with unified retry logic. + +Tests cover (using OpenAI and VikingDB as representatives): +- embed retries on transient error (mock API client) +- embed does NOT retry on permanent error +- uses config max_retries +- VikingDB now has retry +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class _HttpError(Exception): + """Fake HTTP error carrying a numeric status code.""" + + def __init__(self, status_code: int, message: str = ""): + super().__init__(message or f"HTTP {status_code}") + self.status_code = status_code + + +def _make_fake_embedding_response(vector=None): + """Build a minimal fake OpenAI embeddings response.""" + if vector is None: + vector = [0.1] * 10 + item = SimpleNamespace(embedding=vector) + usage = SimpleNamespace(prompt_tokens=5, total_tokens=5) + return SimpleNamespace(data=[item], usage=usage) + + +# --------------------------------------------------------------------------- +# OpenAI Embedder Tests +# --------------------------------------------------------------------------- + +class TestOpenAIEmbedderRetry: + + @pytest.fixture() + def openai_embedder(self): + """Create an OpenAIDenseEmbedder with mocked client.""" + from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder + + embedder = OpenAIDenseEmbedder( + model_name="text-embedding-3-small", + api_key="sk-test", + dimension=10, + config={"max_retries": 2}, + ) + embedder.client = MagicMock() + return embedder + + def test_embed_retries_on_transient_error(self, openai_embedder): + """embed() should retry on 429 (transient) and succeed.""" + errors = [_HttpError(429)] + call_count = 0 + + def fake_create(**kwargs): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return _make_fake_embedding_response() + + openai_embedder.client.embeddings.create = fake_create + + with patch("time.sleep"): + result = openai_embedder.embed("test text") + + assert result.dense_vector == [0.1] * 10 + assert call_count == 2 # 1 failure + 1 success + + def test_embed_no_retry_on_permanent_error(self, openai_embedder): + """embed() should NOT retry on 401 (permanent).""" + call_count = 0 + + def fake_create(**kwargs): + nonlocal call_count + call_count += 1 + raise _HttpError(401, "Unauthorized") + + openai_embedder.client.embeddings.create = fake_create + + with patch("time.sleep"): + # 401 is permanent, transient_retry won't retry it. + # It will propagate and be caught by the except block, re-raised as RuntimeError. + with pytest.raises((RuntimeError, _HttpError)): + openai_embedder.embed("test text") + + assert call_count == 1 # no retries + + def test_uses_config_max_retries(self): + """Embedder should use self.max_retries from config.""" + from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder + + embedder = OpenAIDenseEmbedder( + model_name="text-embedding-3-small", + api_key="sk-test", + dimension=10, + config={"max_retries": 5}, + ) + assert embedder.max_retries == 5 + + # Default + embedder2 = OpenAIDenseEmbedder( + model_name="text-embedding-3-small", + api_key="sk-test", + dimension=10, + ) + assert embedder2.max_retries == 3 + + def test_openai_sdk_retry_disabled(self): + """OpenAI client should be created with max_retries=0.""" + from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder + + with patch("openai.OpenAI") as mock_openai: + mock_openai.return_value = MagicMock() + OpenAIDenseEmbedder( + model_name="text-embedding-3-small", + api_key="sk-test", + dimension=10, + ) + call_kwargs = mock_openai.call_args + assert call_kwargs.kwargs.get("max_retries") == 0 + + +# --------------------------------------------------------------------------- +# VikingDB Embedder Tests +# --------------------------------------------------------------------------- + +class TestVikingDBEmbedderRetry: + + @pytest.fixture() + def vikingdb_embedder(self): + """Create a VikingDBDenseEmbedder with mocked client.""" + from openviking.models.embedder.vikingdb_embedders import VikingDBDenseEmbedder + + with patch( + "openviking.storage.vectordb.collection.volcengine_clients.ClientForDataApi" + ): + embedder = VikingDBDenseEmbedder( + model_name="test-model", + model_version="1.0", + ak="test-ak", + sk="test-sk", + region="cn-beijing", + dimension=10, + config={"max_retries": 2}, + ) + return embedder + + def test_embed_retries_on_transient_error(self, vikingdb_embedder): + """embed() should retry on transient error and succeed.""" + errors = [_HttpError(503)] + call_count = 0 + + original_call_api = vikingdb_embedder._call_api + + def fake_call_api(*args, **kwargs): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return [{"dense_embedding": [0.1] * 10}] + + vikingdb_embedder._call_api = fake_call_api + + with patch("time.sleep"): + result = vikingdb_embedder.embed("test text") + + assert result.dense_vector == [0.1] * 10 + assert call_count == 2 # 1 failure + 1 success + + def test_embed_no_retry_on_permanent_error(self, vikingdb_embedder): + """embed() should NOT retry on 401 (permanent).""" + call_count = 0 + + def fake_call_api(*args, **kwargs): + nonlocal call_count + call_count += 1 + raise _HttpError(401, "Unauthorized") + + vikingdb_embedder._call_api = fake_call_api + + with patch("time.sleep"): + with pytest.raises(_HttpError): + vikingdb_embedder.embed("test text") + + assert call_count == 1 # no retries + + def test_uses_config_max_retries(self): + """VikingDB embedder should use self.max_retries from config.""" + from openviking.models.embedder.vikingdb_embedders import VikingDBDenseEmbedder + + with patch( + "openviking.storage.vectordb.collection.volcengine_clients.ClientForDataApi" + ): + embedder = VikingDBDenseEmbedder( + model_name="test-model", + model_version="1.0", + ak="test-ak", + sk="test-sk", + region="cn-beijing", + dimension=10, + config={"max_retries": 7}, + ) + assert embedder.max_retries == 7 + + def test_vikingdb_now_has_retry(self, vikingdb_embedder): + """VikingDB embed() should retry on 429 (was zero retry before unified retry).""" + errors = [_HttpError(429), _HttpError(429)] + call_count = 0 + + def fake_call_api(*args, **kwargs): + nonlocal call_count + call_count += 1 + if errors: + raise errors.pop(0) + return [{"dense_embedding": [0.2] * 10}] + + vikingdb_embedder._call_api = fake_call_api + + with patch("time.sleep"): + result = vikingdb_embedder.embed("test text") + + assert result.dense_vector == [0.2] * 10 + assert call_count == 3 # 2 failures + 1 success diff --git a/tests/unit/test_retry_config.py b/tests/unit/test_retry_config.py new file mode 100644 index 000000000..14b61fb4c --- /dev/null +++ b/tests/unit/test_retry_config.py @@ -0,0 +1,85 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for retry configuration fields on VLMConfig and EmbeddingConfig. + +Verifies that: +- VLMConfig default max_retries = 3 +- EmbeddingConfig has max_retries field, default = 3 +- EmbeddingConfig accepts custom max_retries +""" + +from __future__ import annotations + +import pytest + + +class TestVLMConfigMaxRetries: + + def test_default_max_retries(self): + """VLMConfig should default max_retries to 3.""" + from openviking_cli.utils.config.vlm_config import VLMConfig + + cfg = VLMConfig( + model="gpt-4o-mini", + api_key="sk-test", + provider="openai", + ) + assert cfg.max_retries == 3 + + def test_custom_max_retries(self): + """VLMConfig should accept custom max_retries.""" + from openviking_cli.utils.config.vlm_config import VLMConfig + + cfg = VLMConfig( + model="gpt-4o-mini", + api_key="sk-test", + provider="openai", + max_retries=10, + ) + assert cfg.max_retries == 10 + + +class TestEmbeddingConfigMaxRetries: + + def test_has_max_retries_field(self): + """EmbeddingConfig should have a max_retries field.""" + from openviking_cli.utils.config.embedding_config import EmbeddingConfig + + fields = EmbeddingConfig.model_fields + assert "max_retries" in fields, ( + f"EmbeddingConfig is missing 'max_retries' field. Fields: {list(fields.keys())}" + ) + + def test_default_max_retries(self): + """EmbeddingConfig should default max_retries to 3.""" + from openviking_cli.utils.config.embedding_config import ( + EmbeddingConfig, + EmbeddingModelConfig, + ) + + cfg = EmbeddingConfig( + dense=EmbeddingModelConfig( + model="text-embedding-3-small", + api_key="sk-test", + provider="openai", + ), + ) + assert cfg.max_retries == 3 + + def test_custom_max_retries(self): + """EmbeddingConfig should accept custom max_retries.""" + from openviking_cli.utils.config.embedding_config import ( + EmbeddingConfig, + EmbeddingModelConfig, + ) + + cfg = EmbeddingConfig( + dense=EmbeddingModelConfig( + model="text-embedding-3-small", + api_key="sk-test", + provider="openai", + ), + max_retries=7, + ) + assert cfg.max_retries == 7 From f10946b8988b7059603fcc27c334f23dacc1d2ee Mon Sep 17 00:00:00 2001 From: Sergii Nemesh Date: Sat, 28 Mar 2026 11:42:56 +0100 Subject: [PATCH 08/10] =?UTF-8?q?style:=20ruff=20format=20=D0=B4=D0=BB?= =?UTF-8?q?=D1=8F=20=D0=B2=D1=81=D0=B5=D1=85=20=D0=B8=D0=B7=D0=BC=D0=B5?= =?UTF-8?q?=D0=BD=D1=91=D0=BD=D0=BD=D1=8B=D1=85=20=D1=84=D0=B0=D0=B9=D0=BB?= =?UTF-8?q?=D0=BE=D0=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- openviking/models/retry.py | 4 +- openviking/models/vlm/backends/openai_vlm.py | 32 +++-- .../models/vlm/backends/volcengine_vlm.py | 9 +- openviking/models/vlm/llm.py | 28 +++- openviking/session/memory/memory_react.py | 133 +++++++++++------- .../utils/config/embedding_config.py | 4 +- openviking_cli/utils/config/vlm_config.py | 22 ++- tests/unit/test_backward_compat.py | 4 - .../unit/test_embedding_retry_integration.py | 13 +- tests/unit/test_retry.py | 20 ++- tests/unit/test_retry_config.py | 2 - tests/unit/test_vlm_retry_integration.py | 92 +++++++----- 12 files changed, 223 insertions(+), 140 deletions(-) diff --git a/openviking/models/retry.py b/openviking/models/retry.py index 40f6a4088..bf4b4e697 100644 --- a/openviking/models/retry.py +++ b/openviking/models/retry.py @@ -127,7 +127,9 @@ def is_transient_error(exc: Exception) -> bool: return False # Transient openai errors - if isinstance(exc, (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError)): + if isinstance( + exc, (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError) + ): return True except ImportError: pass diff --git a/openviking/models/vlm/backends/openai_vlm.py b/openviking/models/vlm/backends/openai_vlm.py index 4e0bb231d..22f0c8689 100644 --- a/openviking/models/vlm/backends/openai_vlm.py +++ b/openviking/models/vlm/backends/openai_vlm.py @@ -22,6 +22,7 @@ "dashscope-intl.aliyuncs.com", } + def _build_openai_client_kwargs( provider: str, api_key: str, @@ -62,8 +63,11 @@ def get_client(self): except ImportError: raise ImportError("Please install openai: pip install openai") kwargs = _build_openai_client_kwargs( - self.provider, self.api_key, self.api_base, - self.api_version, self.extra_headers, + self.provider, + self.api_key, + self.api_base, + self.api_version, + self.extra_headers, ) kwargs["max_retries"] = 0 # Disable SDK retry; we use transient_retry if self.provider == "azure": @@ -80,8 +84,11 @@ def get_async_client(self): except ImportError: raise ImportError("Please install openai: pip install openai") kwargs = _build_openai_client_kwargs( - self.provider, self.api_key, self.api_base, - self.api_version, self.extra_headers, + self.provider, + self.api_key, + self.api_base, + self.api_version, + self.extra_headers, ) kwargs["max_retries"] = 0 # Disable SDK retry; we use transient_retry_async if self.provider == "azure": @@ -108,15 +115,15 @@ def _supports_enable_thinking(self) -> bool: return host.lower() in _DASHSCOPE_HOSTS - def _apply_provider_specific_extra_body( - self, kwargs: Dict[str, Any], thinking: bool - ) -> None: + def _apply_provider_specific_extra_body(self, kwargs: Dict[str, Any], thinking: bool) -> None: """Attach provider-specific raw body parameters understood by compatible APIs.""" if self._supports_enable_thinking(): kwargs["extra_body"] = {"enable_thinking": bool(thinking)} def _update_token_usage_from_response( - self, response, duration_seconds: float = 0.0, + self, + response, + duration_seconds: float = 0.0, ): if hasattr(response, "usage") and response.usage: prompt_tokens = response.usage.prompt_tokens @@ -141,11 +148,7 @@ def _parse_tool_calls(self, message) -> List[ToolCall]: args = json.loads(args) except json.JSONDecodeError: args = {"raw": args} - tool_calls.append(ToolCall( - id=tc.id, - name=tc.function.name, - arguments=args - )) + tool_calls.append(ToolCall(id=tc.id, name=tc.function.name, arguments=args)) return tool_calls def _build_vlm_response(self, response, has_tools: bool) -> Union[str, VLMResponse]: @@ -351,7 +354,8 @@ async def _call(): content = await self._process_streaming_response_async(response) else: self._update_token_usage_from_response( - response, duration_seconds=elapsed, + response, + duration_seconds=elapsed, ) content = self._extract_content_from_response(response) diff --git a/openviking/models/vlm/backends/volcengine_vlm.py b/openviking/models/vlm/backends/volcengine_vlm.py index 36cd83ebb..60a4f4219 100644 --- a/openviking/models/vlm/backends/volcengine_vlm.py +++ b/openviking/models/vlm/backends/volcengine_vlm.py @@ -73,11 +73,7 @@ def _parse_tool_calls(self, message) -> List[ToolCall]: args = json.loads(args) except json.JSONDecodeError: args = {"raw": args} - tool_calls.append(ToolCall( - id=tc.id, - name=tc.function.name, - arguments=args - )) + tool_calls.append(ToolCall(id=tc.id, name=tc.function.name, arguments=args)) return tool_calls def _build_vlm_response(self, response, has_tools: bool) -> Union[str, VLMResponse]: @@ -176,7 +172,8 @@ async def _call(): response = await transient_retry_async(_call, max_retries=self.max_retries) elapsed = time.perf_counter() - t0 self._update_token_usage_from_response( - response, duration_seconds=elapsed, + response, + duration_seconds=elapsed, ) return self._build_vlm_response(response, has_tools=bool(tools)) diff --git a/openviking/models/vlm/llm.py b/openviking/models/vlm/llm.py index 314e6614c..3ace412c7 100644 --- a/openviking/models/vlm/llm.py +++ b/openviking/models/vlm/llm.py @@ -33,7 +33,7 @@ def parse_json_from_response(response: Union[str, Any]) -> Optional[Any]: Optional[Any]: Parsed JSON object, None if parsing fails """ # Handle VLMResponse - extract content - if hasattr(response, 'content'): + if hasattr(response, "content"): response = response.content if response is None: @@ -184,7 +184,10 @@ def complete_json( prompt = f"{prompt}\n\n{get_json_schema_prompt(schema)}" response = self._get_vlm().get_completion( - prompt=prompt, thinking=thinking, tools=tools, messages=messages, + prompt=prompt, + thinking=thinking, + tools=tools, + messages=messages, ) return parse_json_from_response(response) @@ -201,7 +204,10 @@ async def complete_json_async( prompt = f"{prompt}\n\n{get_json_schema_prompt(schema)}" response = await self._get_vlm().get_completion_async( - prompt=prompt, thinking=thinking, tools=tools, messages=messages, + prompt=prompt, + thinking=thinking, + tools=tools, + messages=messages, ) return parse_json_from_response(response) @@ -232,7 +238,9 @@ async def complete_model_async( """Async version of complete_model.""" schema = model_class.model_json_schema() response = await self.complete_json_async( - prompt=prompt, schema=schema, thinking=thinking, + prompt=prompt, + schema=schema, + thinking=thinking, ) if response is None: return None @@ -253,7 +261,11 @@ def get_vision_completion( ) -> Union[str, Any]: """Get vision completion.""" return self._get_vlm().get_vision_completion( - prompt=prompt, images=images, thinking=thinking, tools=tools, messages=messages, + prompt=prompt, + images=images, + thinking=thinking, + tools=tools, + messages=messages, ) async def get_vision_completion_async( @@ -266,5 +278,9 @@ async def get_vision_completion_async( ) -> Union[str, Any]: """Async vision completion.""" return await self._get_vlm().get_vision_completion_async( - prompt=prompt, images=images, thinking=thinking, tools=tools, messages=messages, + prompt=prompt, + images=images, + thinking=thinking, + tools=tools, + messages=messages, ) diff --git a/openviking/session/memory/memory_react.py b/openviking/session/memory/memory_react.py index e3b7d7257..f8700d33f 100644 --- a/openviking/session/memory/memory_react.py +++ b/openviking/session/memory/memory_react.py @@ -42,7 +42,6 @@ logger = get_logger(__name__) - class MemoryReAct: """ Simplified ReAct orchestrator for memory updates. @@ -85,7 +84,10 @@ def __init__( self.registry = registry else: import os - schemas_dir = os.path.join(os.path.dirname(__file__), "..", "..", "prompts", "templates", "memory") + + schemas_dir = os.path.join( + os.path.dirname(__file__), "..", "..", "prompts", "templates", "memory" + ) self.registry = MemoryTypeRegistry() self.registry.load_from_directory(schemas_dir) self.schema_model_generator = SchemaModelGenerator(self.registry) @@ -115,6 +117,7 @@ async def _pre_fetch_context(self, conversation: str) -> Dict[str, Any]: Pre-fetched context with directories, summaries, and search_results """ from openviking.session.memory.tools import get_tool + messages = [] # Step 1: Separate schemas into multi-file (ls) and single-file (direct read) @@ -126,9 +129,15 @@ async def _pre_fetch_context(self, conversation: str) -> Dict[str, Any]: continue # Replace variables in directory path with actual user/agent space - user_space = self.ctx.user.user_space_name() if self.ctx and self.ctx.user else "default" - agent_space = self.ctx.user.agent_space_name() if self.ctx and self.ctx.user else "default" - dir_path = schema.directory.replace("{user_space}", user_space).replace("{agent_space}", agent_space) + user_space = ( + self.ctx.user.user_space_name() if self.ctx and self.ctx.user else "default" + ) + agent_space = ( + self.ctx.user.agent_space_name() if self.ctx and self.ctx.user else "default" + ) + dir_path = schema.directory.replace("{user_space}", user_space).replace( + "{agent_space}", agent_space + ) # Check if filename_template has variables (contains {xxx}) has_variables = False @@ -154,24 +163,22 @@ async def _pre_fetch_context(self, conversation: str) -> Dict[str, Any]: add_tool_call_pair_to_messages( messages=messages, call_id=call_id_seq, - tool_name='ls', - params={ - "uri": dir_uri - }, - result=result_str + tool_name="ls", + params={"uri": dir_uri}, + result=result_str, ) call_id_seq += 1 - result_str = await read_tool.execute(self.viking_fs, self.ctx, uri=f'{dir_uri}/.overview.md') + result_str = await read_tool.execute( + self.viking_fs, self.ctx, uri=f"{dir_uri}/.overview.md" + ) add_tool_call_pair_to_messages( messages=messages, call_id=call_id_seq, - tool_name='read', - params={ - "uri": f'{dir_uri}/.overview.md' - }, - result=result_str + tool_name="read", + params={"uri": f"{dir_uri}/.overview.md"}, + result=result_str, ) call_id_seq += 1 @@ -186,7 +193,7 @@ async def _pre_fetch_context(self, conversation: str) -> Dict[str, Any]: user_messages = [] for line in conversation.split("\n"): if line.startswith("[user]:"): - user_messages.append(line[len("[user]:"):].strip()) + user_messages.append(line[len("[user]:") :].strip()) user_query = " ".join(user_messages) if user_query: @@ -199,9 +206,9 @@ async def _pre_fetch_context(self, conversation: str) -> Dict[str, Any]: add_tool_call_pair_to_messages( messages=messages, call_id=call_id_seq, - tool_name='search', + tool_name="search", params={"query": user_query}, - result=str(search_result) + result=str(search_result), ) call_id_seq += 1 except Exception as e: @@ -209,7 +216,6 @@ async def _pre_fetch_context(self, conversation: str) -> Dict[str, Any]: return messages - async def run( self, conversation: str, @@ -241,7 +247,9 @@ async def run( # Reset read files tracking for this run self._read_files.clear() - messages = self._build_initial_messages(conversation, tool_call_messages, self._output_language) + messages = self._build_initial_messages( + conversation, tool_call_messages, self._output_language + ) while iteration < self.max_iterations: iteration += 1 @@ -252,10 +260,12 @@ async def run( # If last iteration, add a message telling the model to return result directly if is_last_iteration: - messages.append({ - "role": "user", - "content": "You have reached the maximum number of tool call iterations. Do not call any more tools - return your final result directly now." - }) + messages.append( + { + "role": "user", + "content": "You have reached the maximum number of tool call iterations. Do not call any more tools - return your final result directly now.", + } + ) # Call LLM with tools - model decides: tool calls OR final operations tool_calls, operations = await self._call_llm(messages, force_final=is_last_iteration) @@ -278,7 +288,9 @@ async def run( # If no tool calls either, continue to next iteration (don't break!) if not tool_calls: - logger.warning(f"LLM returned neither tool calls nor operations (iteration {iteration}/{self.max_iterations})") + logger.warning( + f"LLM returned neither tool calls nor operations (iteration {iteration}/{self.max_iterations})" + ) # If it's the last iteration, use empty operations if is_last_iteration: final_operations = MemoryOperations() @@ -293,18 +305,19 @@ async def execute_single_tool_call(idx: int, tool_call): return idx, tool_call, result action_tasks = [ - execute_single_tool_call(idx, tool_call) - for idx, tool_call in enumerate(tool_calls) + execute_single_tool_call(idx, tool_call) for idx, tool_call in enumerate(tool_calls) ] results = await self._execute_in_parallel(action_tasks) # Process results and add to messages for _idx, tool_call, result in results: - tools_used.append({ - "tool_name": tool_call.name, - "params": tool_call.arguments, - "result": result, - }) + tools_used.append( + { + "tool_name": tool_call.name, + "params": tool_call.arguments, + "result": result, + } + ) # Track read tool calls for refetch detection if tool_call.name == "read" and tool_call.arguments.get("uri"): @@ -325,7 +338,7 @@ async def execute_single_tool_call(idx: int, tool_call): else: raise RuntimeError("ReAct loop completed but no operations generated") - logger.info(f'final_operations={final_operations.model_dump_json(indent=4)}') + logger.info(f"final_operations={final_operations.model_dump_json(indent=4)}") return final_operations, tools_used @@ -346,19 +359,20 @@ def _build_initial_messages( # Add pre-fetched context as tool calls messages.extend(tool_call_messages) - messages.append({ + messages.append( + { "role": "user", "content": f"""## Conversation History {conversation} After exploring, analyze the conversation and output ALL memory write/edit/delete operations in a single response. Do not output operations one at a time - gather all changes first, then return them together.""", - }) + } + ) # Print messages in a readable format pretty_print_messages(messages) return messages - def _get_allowed_directories_list(self) -> str: """Get a formatted list of allowed directories for the system prompt.""" user_space = self.ctx.user.user_space_name() if self.ctx and self.ctx.user else "default" @@ -375,6 +389,7 @@ def _get_allowed_directories_list(self) -> str: def _get_system_prompt(self, output_language: str) -> str: """Get the simplified system prompt.""" import json + schema_str = json.dumps(self._json_schema, ensure_ascii=False) allowed_dirs_list = self._get_allowed_directories_list() @@ -515,15 +530,23 @@ async def _call_llm( ) # Log cache hit info - if hasattr(response, 'usage') and response.usage: + if hasattr(response, "usage") and response.usage: usage = response.usage - prompt_tokens = usage.get('prompt_tokens', 0) - cached_tokens = usage.get('prompt_tokens_details', {}).get('cached_tokens', 0) if isinstance(usage.get('prompt_tokens_details'), dict) else 0 + prompt_tokens = usage.get("prompt_tokens", 0) + cached_tokens = ( + usage.get("prompt_tokens_details", {}).get("cached_tokens", 0) + if isinstance(usage.get("prompt_tokens_details"), dict) + else 0 + ) if prompt_tokens > 0: cache_hit_rate = (cached_tokens / prompt_tokens) * 100 - logger.info(f"[KVCache] prompt_tokens={prompt_tokens}, cached_tokens={cached_tokens}, cache_hit_rate={cache_hit_rate:.1f}%") + logger.info( + f"[KVCache] prompt_tokens={prompt_tokens}, cached_tokens={cached_tokens}, cache_hit_rate={cache_hit_rate:.1f}%" + ) else: - logger.info(f"[KVCache] prompt_tokens={prompt_tokens}, cached_tokens={cached_tokens}") + logger.info( + f"[KVCache] prompt_tokens={prompt_tokens}, cached_tokens={cached_tokens}" + ) # Case 1: LLM returned tool calls if response.has_tool_calls: @@ -545,7 +568,13 @@ async def _call_llm( operations, error = parse_json_with_stability( content=content, model_class=operations_model, - expected_fields=['reasoning', 'write_uris', 'edit_uris', 'edit_overview_uris', 'delete_uris'], + expected_fields=[ + "reasoning", + "write_uris", + "edit_uris", + "edit_overview_uris", + "delete_uris", + ], ) if error is not None: @@ -555,7 +584,13 @@ async def _call_llm( operations, error_fallback = parse_json_with_stability( content=content_no_md, model_class=MemoryOperations, - expected_fields=['reasoning', 'write_uris', 'edit_uris', 'edit_overview_uris', 'delete_uris'], + expected_fields=[ + "reasoning", + "write_uris", + "edit_uris", + "edit_overview_uris", + "delete_uris", + ], ) if error_fallback is not None: logger.warning(f"Fallback parse also failed: {error_fallback}") @@ -658,7 +693,9 @@ async def _add_refetch_results_to_messages( logger.warning(f"Failed to refetch {uri}: {e}") # Add reminder message for the model - messages.append({ - "role": "user", - "content": "Note: The files above were automatically read because they exist and you didn't read them before deciding to write. Please consider the existing content when making write decisions. You can now output updated operations." - }) + messages.append( + { + "role": "user", + "content": "Note: The files above were automatically read because they exist and you didn't read them before deciding to write. Please consider the existing content when making write decisions. You can now output updated operations.", + } + ) diff --git a/openviking_cli/utils/config/embedding_config.py b/openviking_cli/utils/config/embedding_config.py index 2d3029ce8..0532b5d41 100644 --- a/openviking_cli/utils/config/embedding_config.py +++ b/openviking_cli/utils/config/embedding_config.py @@ -250,9 +250,7 @@ class EmbeddingConfig(BaseModel): sparse: Optional[EmbeddingModelConfig] = Field(default=None) hybrid: Optional[EmbeddingModelConfig] = Field(default=None) - max_retries: int = Field( - default=3, description="Maximum retry attempts for transient errors" - ) + max_retries: int = Field(default=3, description="Maximum retry attempts for transient errors") max_concurrent: int = Field( default=10, description="Maximum number of concurrent embedding requests" ) diff --git a/openviking_cli/utils/config/vlm_config.py b/openviking_cli/utils/config/vlm_config.py index d404ed1ef..6c9f3e006 100644 --- a/openviking_cli/utils/config/vlm_config.py +++ b/openviking_cli/utils/config/vlm_config.py @@ -182,7 +182,10 @@ def get_completion( ) -> Union[str, Any]: """Get LLM completion.""" return self.get_vlm_instance().get_completion( - prompt=prompt, thinking=thinking, tools=tools, messages=messages, + prompt=prompt, + thinking=thinking, + tools=tools, + messages=messages, ) async def get_completion_async( @@ -194,7 +197,10 @@ async def get_completion_async( ) -> Union[str, Any]: """Get LLM completion asynchronously.""" return await self.get_vlm_instance().get_completion_async( - prompt=prompt, thinking=thinking, tools=tools, messages=messages, + prompt=prompt, + thinking=thinking, + tools=tools, + messages=messages, ) def is_available(self) -> bool: @@ -211,7 +217,11 @@ def get_vision_completion( ) -> Union[str, Any]: """Get LLM completion with images.""" return self.get_vlm_instance().get_vision_completion( - prompt=prompt, images=images, thinking=thinking, tools=tools, messages=messages, + prompt=prompt, + images=images, + thinking=thinking, + tools=tools, + messages=messages, ) async def get_vision_completion_async( @@ -224,5 +234,9 @@ async def get_vision_completion_async( ) -> Union[str, Any]: """Get LLM completion with images asynchronously.""" return await self.get_vlm_instance().get_vision_completion_async( - prompt=prompt, images=images, thinking=thinking, tools=tools, messages=messages, + prompt=prompt, + images=images, + thinking=thinking, + tools=tools, + messages=messages, ) diff --git a/tests/unit/test_backward_compat.py b/tests/unit/test_backward_compat.py index cb9db9a46..f005ac066 100644 --- a/tests/unit/test_backward_compat.py +++ b/tests/unit/test_backward_compat.py @@ -28,7 +28,6 @@ def __init__(self, status_code: int, message: str = ""): class TestExponentialBackoffRetryImportable: - def test_importable_from_old_location(self): """exponential_backoff_retry should still be importable from base.py.""" from openviking.models.embedder.base import exponential_backoff_retry @@ -37,7 +36,6 @@ def test_importable_from_old_location(self): class TestExponentialBackoffRetrySignature: - def test_signature_unchanged(self): """exponential_backoff_retry should retain its original signature.""" from openviking.models.embedder.base import exponential_backoff_retry @@ -77,7 +75,6 @@ def test_defaults_unchanged(self): class TestExponentialBackoffRetryBehavior: - def test_success_first_try(self): """Function succeeds on first attempt.""" from openviking.models.embedder.base import exponential_backoff_retry @@ -144,7 +141,6 @@ def fn(): class TestTransientRetryIsCountBased: - def test_is_count_based(self): """transient_retry should be count-based (uses max_retries, not max_wait).""" from openviking.models.retry import transient_retry diff --git a/tests/unit/test_embedding_retry_integration.py b/tests/unit/test_embedding_retry_integration.py index 17b381f18..bb8ab2a9f 100644 --- a/tests/unit/test_embedding_retry_integration.py +++ b/tests/unit/test_embedding_retry_integration.py @@ -22,6 +22,7 @@ # Helpers # --------------------------------------------------------------------------- + class _HttpError(Exception): """Fake HTTP error carrying a numeric status code.""" @@ -43,8 +44,8 @@ def _make_fake_embedding_response(vector=None): # OpenAI Embedder Tests # --------------------------------------------------------------------------- -class TestOpenAIEmbedderRetry: +class TestOpenAIEmbedderRetry: @pytest.fixture() def openai_embedder(self): """Create an OpenAIDenseEmbedder with mocked client.""" @@ -137,16 +138,14 @@ def test_openai_sdk_retry_disabled(self): # VikingDB Embedder Tests # --------------------------------------------------------------------------- -class TestVikingDBEmbedderRetry: +class TestVikingDBEmbedderRetry: @pytest.fixture() def vikingdb_embedder(self): """Create a VikingDBDenseEmbedder with mocked client.""" from openviking.models.embedder.vikingdb_embedders import VikingDBDenseEmbedder - with patch( - "openviking.storage.vectordb.collection.volcengine_clients.ClientForDataApi" - ): + with patch("openviking.storage.vectordb.collection.volcengine_clients.ClientForDataApi"): embedder = VikingDBDenseEmbedder( model_name="test-model", model_version="1.0", @@ -201,9 +200,7 @@ def test_uses_config_max_retries(self): """VikingDB embedder should use self.max_retries from config.""" from openviking.models.embedder.vikingdb_embedders import VikingDBDenseEmbedder - with patch( - "openviking.storage.vectordb.collection.volcengine_clients.ClientForDataApi" - ): + with patch("openviking.storage.vectordb.collection.volcengine_clients.ClientForDataApi"): embedder = VikingDBDenseEmbedder( model_name="test-model", model_version="1.0", diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py index 18dd26947..176412d48 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/test_retry.py @@ -22,6 +22,7 @@ # Helper fake HTTP error with status_code attribute # --------------------------------------------------------------------------- + class _HttpError(Exception): """Fake HTTP error carrying a numeric status code for testing.""" @@ -65,14 +66,20 @@ def __init__(self, status_code: int, message: str = ""): pytest.param(ValueError("bad value"), False, id="ValueError"), pytest.param(TypeError("wrong type"), False, id="TypeError"), # String-pattern permanent errors - pytest.param(Exception("InvalidRequestError: field missing"), False, id="str_InvalidRequestError"), - pytest.param(Exception("AuthenticationError: invalid key"), False, id="str_AuthenticationError"), + pytest.param( + Exception("InvalidRequestError: field missing"), False, id="str_InvalidRequestError" + ), + pytest.param( + Exception("AuthenticationError: invalid key"), False, id="str_AuthenticationError" + ), # Unknown errors — conservative default False pytest.param(Exception("some unknown error"), False, id="unknown_generic"), pytest.param(RuntimeError("unexpected state"), False, id="RuntimeError_unknown"), pytest.param(KeyError("missing key"), False, id="KeyError"), pytest.param(AttributeError("no attr"), False, id="AttributeError"), - pytest.param(Exception("config_value_out_of_range"), False, id="str_unknown_no_transient_keyword"), + pytest.param( + Exception("config_value_out_of_range"), False, id="str_unknown_no_transient_keyword" + ), ] @@ -92,6 +99,7 @@ def test_is_transient_error_permanent(exc, expected): # transient_retry (sync) # --------------------------------------------------------------------------- + class TestTransientRetrySync: """Sync retry behaviour tests.""" @@ -219,9 +227,7 @@ def fn(): with patch("time.sleep", side_effect=lambda d: sleep_calls.append(d)): with pytest.raises(_HttpError): - transient_retry( - fn, max_retries=10, base_delay=1.0, max_delay=8.0, jitter=False - ) + transient_retry(fn, max_retries=10, base_delay=1.0, max_delay=8.0, jitter=False) assert all(d <= 8.0 for d in sleep_calls), f"Some delays exceed max_delay: {sleep_calls}" @@ -230,6 +236,7 @@ def fn(): # transient_retry_async (async) # --------------------------------------------------------------------------- + class TestTransientRetryAsync: """Async retry behaviour tests — mirrors sync suite.""" @@ -376,6 +383,7 @@ async def coro(): # Additional edge-case tests # --------------------------------------------------------------------------- + class TestIsTransientErrorEdgeCases: """Edge cases for is_transient_error.""" diff --git a/tests/unit/test_retry_config.py b/tests/unit/test_retry_config.py index 14b61fb4c..9e8adb235 100644 --- a/tests/unit/test_retry_config.py +++ b/tests/unit/test_retry_config.py @@ -15,7 +15,6 @@ class TestVLMConfigMaxRetries: - def test_default_max_retries(self): """VLMConfig should default max_retries to 3.""" from openviking_cli.utils.config.vlm_config import VLMConfig @@ -41,7 +40,6 @@ def test_custom_max_retries(self): class TestEmbeddingConfigMaxRetries: - def test_has_max_retries_field(self): """EmbeddingConfig should have a max_retries field.""" from openviking_cli.utils.config.embedding_config import EmbeddingConfig diff --git a/tests/unit/test_vlm_retry_integration.py b/tests/unit/test_vlm_retry_integration.py index 647a95820..69c4b4b1f 100644 --- a/tests/unit/test_vlm_retry_integration.py +++ b/tests/unit/test_vlm_retry_integration.py @@ -24,6 +24,7 @@ # Helpers # --------------------------------------------------------------------------- + class _HttpError(Exception): """Fake HTTP error carrying a numeric status code.""" @@ -44,17 +45,20 @@ def _make_fake_response(content: str = "ok") -> SimpleNamespace: # Fixtures # --------------------------------------------------------------------------- + @pytest.fixture() def openai_vlm(): """Create an OpenAIVLM instance with mocked clients.""" from openviking.models.vlm.backends.openai_vlm import OpenAIVLM - vlm = OpenAIVLM({ - "api_key": "sk-test", - "model": "gpt-4o-mini", - "provider": "openai", - "max_retries": 2, - }) + vlm = OpenAIVLM( + { + "api_key": "sk-test", + "model": "gpt-4o-mini", + "provider": "openai", + "max_retries": 2, + } + ) # Mock sync client mock_sync = MagicMock() @@ -71,8 +75,8 @@ def openai_vlm(): # Tests: get_completion_async retries on 429 # --------------------------------------------------------------------------- -class TestCompletionAsyncRetries: +class TestCompletionAsyncRetries: async def test_retries_on_429(self, openai_vlm): """get_completion_async should retry on 429 (transient) and succeed.""" errors = [_HttpError(429), _HttpError(429)] @@ -114,20 +118,24 @@ async def test_uses_config_max_retries(self): """Backend should use self.max_retries from config, not a param.""" from openviking.models.vlm.backends.openai_vlm import OpenAIVLM - vlm = OpenAIVLM({ - "api_key": "sk-test", - "model": "gpt-4o-mini", - "provider": "openai", - "max_retries": 5, - }) + vlm = OpenAIVLM( + { + "api_key": "sk-test", + "model": "gpt-4o-mini", + "provider": "openai", + "max_retries": 5, + } + ) assert vlm.max_retries == 5 # Config default is now 3 - vlm2 = OpenAIVLM({ - "api_key": "sk-test", - "model": "gpt-4o-mini", - "provider": "openai", - }) + vlm2 = OpenAIVLM( + { + "api_key": "sk-test", + "model": "gpt-4o-mini", + "provider": "openai", + } + ) assert vlm2.max_retries == 3 @@ -135,8 +143,8 @@ async def test_uses_config_max_retries(self): # Tests: get_vision_completion_async now retries # --------------------------------------------------------------------------- -class TestVisionCompletionAsyncRetries: +class TestVisionCompletionAsyncRetries: async def test_vision_retries_on_429(self, openai_vlm): """get_vision_completion_async should retry on 429 (was zero retry before).""" errors = [_HttpError(429)] @@ -153,7 +161,8 @@ async def fake_create(**kwargs): with patch("asyncio.sleep", new_callable=AsyncMock): result = await openai_vlm.get_vision_completion_async( - prompt="describe", images=["http://example.com/img.png"], + prompt="describe", + images=["http://example.com/img.png"], ) assert result == "vision ok" @@ -164,8 +173,8 @@ async def fake_create(**kwargs): # Tests: sync completion retries # --------------------------------------------------------------------------- -class TestCompletionSyncRetries: +class TestCompletionSyncRetries: def test_sync_retries_on_429(self, openai_vlm): """get_completion should retry on 429.""" errors = [_HttpError(429)] @@ -202,7 +211,8 @@ def fake_create(**kwargs): with patch("time.sleep"): result = openai_vlm.get_vision_completion( - prompt="describe", images=["http://example.com/img.png"], + prompt="describe", + images=["http://example.com/img.png"], ) assert result == "vision sync ok" @@ -213,8 +223,8 @@ def fake_create(**kwargs): # Tests: signature change verification # --------------------------------------------------------------------------- -class TestSignatureChange: +class TestSignatureChange: def test_no_max_retries_in_get_completion_async(self): """get_completion_async should no longer accept max_retries parameter.""" from openviking.models.vlm.backends.openai_vlm import OpenAIVLM @@ -260,40 +270,46 @@ def test_no_max_retries_in_volcengine_get_completion_async(self): # Tests: OpenAI SDK retry disabled # --------------------------------------------------------------------------- -class TestOpenAISDKRetryDisabled: +class TestOpenAISDKRetryDisabled: def test_sync_client_max_retries_zero(self): """OpenAI sync client should be created with max_retries=0.""" from openviking.models.vlm.backends.openai_vlm import OpenAIVLM - vlm = OpenAIVLM({ - "api_key": "sk-test", - "model": "gpt-4o-mini", - "provider": "openai", - }) + vlm = OpenAIVLM( + { + "api_key": "sk-test", + "model": "gpt-4o-mini", + "provider": "openai", + } + ) with patch("openai.OpenAI") as mock_openai: mock_openai.return_value = MagicMock() vlm._sync_client = None # force re-creation vlm.get_client() call_kwargs = mock_openai.call_args - assert call_kwargs[1].get("max_retries") == 0 or \ - (len(call_kwargs[0]) == 0 and call_kwargs.kwargs.get("max_retries") == 0) + assert call_kwargs[1].get("max_retries") == 0 or ( + len(call_kwargs[0]) == 0 and call_kwargs.kwargs.get("max_retries") == 0 + ) def test_async_client_max_retries_zero(self): """OpenAI async client should be created with max_retries=0.""" from openviking.models.vlm.backends.openai_vlm import OpenAIVLM - vlm = OpenAIVLM({ - "api_key": "sk-test", - "model": "gpt-4o-mini", - "provider": "openai", - }) + vlm = OpenAIVLM( + { + "api_key": "sk-test", + "model": "gpt-4o-mini", + "provider": "openai", + } + ) with patch("openai.AsyncOpenAI") as mock_async_openai: mock_async_openai.return_value = MagicMock() vlm._async_client = None # force re-creation vlm.get_async_client() call_kwargs = mock_async_openai.call_args - assert call_kwargs[1].get("max_retries") == 0 or \ - (len(call_kwargs[0]) == 0 and call_kwargs.kwargs.get("max_retries") == 0) + assert call_kwargs[1].get("max_retries") == 0 or ( + len(call_kwargs[0]) == 0 and call_kwargs.kwargs.get("max_retries") == 0 + ) From f1c45a4fd30ea347574a119b4b4fe0cf9ed11231 Mon Sep 17 00:00:00 2001 From: Sergii Nemesh Date: Sat, 28 Mar 2026 11:54:47 +0100 Subject: [PATCH 09/10] style: fix ruff lint errors (import sorting, unused imports) --- openviking/models/embedder/base.py | 2 +- .../models/embedder/gemini_embedders.py | 4 +-- .../models/embedder/openai_embedders.py | 2 +- openviking/models/vlm/backends/litellm_vlm.py | 3 ++- openviking/models/vlm/backends/openai_vlm.py | 7 ++--- .../models/vlm/backends/volcengine_vlm.py | 5 ++-- openviking/session/memory/memory_react.py | 27 +++++++++---------- tests/unit/test_backward_compat.py | 1 - .../unit/test_embedding_retry_integration.py | 2 -- tests/unit/test_retry_config.py | 2 -- tests/unit/test_vlm_retry_integration.py | 1 - 11 files changed, 25 insertions(+), 31 deletions(-) diff --git a/openviking/models/embedder/base.py b/openviking/models/embedder/base.py index 8c2f00d01..2cf9b7a0b 100644 --- a/openviking/models/embedder/base.py +++ b/openviking/models/embedder/base.py @@ -256,7 +256,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes return [ EmbedResult(dense_vector=d.dense_vector, sparse_vector=s.sparse_vector) - for d, s in zip(dense_results, sparse_results) + for d, s in zip(dense_results, sparse_results, strict=False) ] def get_dimension(self) -> int: diff --git a/openviking/models/embedder/gemini_embedders.py b/openviking/models/embedder/gemini_embedders.py index a3265afa3..c8bae8fd9 100644 --- a/openviking/models/embedder/gemini_embedders.py +++ b/openviking/models/embedder/gemini_embedders.py @@ -237,7 +237,7 @@ def embed_batch( if titles is not None: return [ self.embed(text, is_query=is_query, task_type=task_type, title=title) - for text, title in zip(texts, titles) + for text, title in zip(texts, titles, strict=False) ] # Resolve effective task_type from is_query when no explicit override if task_type is None: @@ -268,7 +268,7 @@ def _batch_call(texts=non_empty_texts, cfg=config): response = transient_retry(_batch_call, max_retries=self.max_retries) batch_results = [None] * len(batch) - for j, emb in zip(non_empty_indices, response.embeddings): + for j, emb in zip(non_empty_indices, response.embeddings, strict=False): batch_results[j] = EmbedResult( dense_vector=truncate_and_normalize(list(emb.values), self._dimension) ) diff --git a/openviking/models/embedder/openai_embedders.py b/openviking/models/embedder/openai_embedders.py index 408ccfe7a..265c6b902 100644 --- a/openviking/models/embedder/openai_embedders.py +++ b/openviking/models/embedder/openai_embedders.py @@ -6,7 +6,6 @@ import openai -from openviking.models.vlm.registry import DEFAULT_AZURE_API_VERSION from openviking.models.embedder.base import ( DenseEmbedderBase, EmbedResult, @@ -14,6 +13,7 @@ SparseEmbedderBase, ) from openviking.models.retry import transient_retry +from openviking.models.vlm.registry import DEFAULT_AZURE_API_VERSION from openviking.telemetry import get_current_telemetry diff --git a/openviking/models/vlm/backends/litellm_vlm.py b/openviking/models/vlm/backends/litellm_vlm.py index 45135fac5..077eb760a 100644 --- a/openviking/models/vlm/backends/litellm_vlm.py +++ b/openviking/models/vlm/backends/litellm_vlm.py @@ -16,9 +16,10 @@ import litellm from litellm import acompletion, completion -from ..base import ToolCall, VLMBase, VLMResponse from openviking.models.retry import transient_retry, transient_retry_async +from ..base import ToolCall, VLMBase, VLMResponse + logger = logging.getLogger(__name__) PROVIDER_CONFIGS: Dict[str, Dict[str, Any]] = { diff --git a/openviking/models/vlm/backends/openai_vlm.py b/openviking/models/vlm/backends/openai_vlm.py index 22f0c8689..3d1e080ad 100644 --- a/openviking/models/vlm/backends/openai_vlm.py +++ b/openviking/models/vlm/backends/openai_vlm.py @@ -7,13 +7,14 @@ import logging import time from pathlib import Path -from urllib.parse import urlparse from typing import Any, Dict, List, Optional, Union +from urllib.parse import urlparse -from ..base import VLMBase, VLMResponse, ToolCall -from ..registry import DEFAULT_AZURE_API_VERSION from openviking.models.retry import transient_retry, transient_retry_async +from ..base import ToolCall, VLMBase, VLMResponse +from ..registry import DEFAULT_AZURE_API_VERSION + logger = logging.getLogger(__name__) diff --git a/openviking/models/vlm/backends/volcengine_vlm.py b/openviking/models/vlm/backends/volcengine_vlm.py index 60a4f4219..dea2f116f 100644 --- a/openviking/models/vlm/backends/volcengine_vlm.py +++ b/openviking/models/vlm/backends/volcengine_vlm.py @@ -9,10 +9,11 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union -from .openai_vlm import OpenAIVLM -from ..base import VLMResponse, ToolCall from openviking.models.retry import transient_retry, transient_retry_async +from ..base import ToolCall, VLMResponse +from .openai_vlm import OpenAIVLM + logger = logging.getLogger(__name__) diff --git a/openviking/session/memory/memory_react.py b/openviking/session/memory/memory_react.py index f8700d33f..51b692c5f 100644 --- a/openviking/session/memory/memory_react.py +++ b/openviking/session/memory/memory_react.py @@ -8,22 +8,10 @@ import asyncio import json -from enum import Enum from typing import Any, Dict, List, Optional, Set, Tuple -from pydantic import BaseModel, Field - -from openviking.models.vlm.base import VLMBase, VLMResponse +from openviking.models.vlm.base import VLMBase from openviking.server.identity import RequestContext -from openviking.session.memory.utils import ( - collect_allowed_directories, - detect_language_from_conversation, - extract_json_from_markdown, - parse_json_with_stability, - parse_memory_file_with_fields, - pretty_print_messages, - validate_operations_uris, -) from openviking.session.memory.dataclass import MemoryOperations from openviking.session.memory.memory_type_registry import MemoryTypeRegistry from openviking.session.memory.schema_model_generator import ( @@ -31,9 +19,18 @@ SchemaPromptGenerator, ) from openviking.session.memory.tools import ( + add_tool_call_pair_to_messages, get_tool, get_tool_schemas, - add_tool_call_pair_to_messages, +) +from openviking.session.memory.utils import ( + collect_allowed_directories, + detect_language_from_conversation, + extract_json_from_markdown, + parse_json_with_stability, + parse_memory_file_with_fields, + pretty_print_messages, + validate_operations_uris, ) from openviking.storage.viking_fs import VikingFS, get_viking_fs from openviking_cli.utils import get_logger @@ -391,7 +388,7 @@ def _get_system_prompt(self, output_language: str) -> str: import json schema_str = json.dumps(self._json_schema, ensure_ascii=False) - allowed_dirs_list = self._get_allowed_directories_list() + self._get_allowed_directories_list() return f"""You are a memory extraction agent. Your task is to analyze conversations and update memories. diff --git a/tests/unit/test_backward_compat.py b/tests/unit/test_backward_compat.py index f005ac066..ba4d8a762 100644 --- a/tests/unit/test_backward_compat.py +++ b/tests/unit/test_backward_compat.py @@ -13,7 +13,6 @@ from __future__ import annotations import inspect -import time from unittest.mock import patch import pytest diff --git a/tests/unit/test_embedding_retry_integration.py b/tests/unit/test_embedding_retry_integration.py index bb8ab2a9f..fbfaba1a8 100644 --- a/tests/unit/test_embedding_retry_integration.py +++ b/tests/unit/test_embedding_retry_integration.py @@ -17,7 +17,6 @@ import pytest - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -162,7 +161,6 @@ def test_embed_retries_on_transient_error(self, vikingdb_embedder): errors = [_HttpError(503)] call_count = 0 - original_call_api = vikingdb_embedder._call_api def fake_call_api(*args, **kwargs): nonlocal call_count diff --git a/tests/unit/test_retry_config.py b/tests/unit/test_retry_config.py index 9e8adb235..24f28fb5e 100644 --- a/tests/unit/test_retry_config.py +++ b/tests/unit/test_retry_config.py @@ -11,8 +11,6 @@ from __future__ import annotations -import pytest - class TestVLMConfigMaxRetries: def test_default_max_retries(self): diff --git a/tests/unit/test_vlm_retry_integration.py b/tests/unit/test_vlm_retry_integration.py index 69c4b4b1f..e65f0ef37 100644 --- a/tests/unit/test_vlm_retry_integration.py +++ b/tests/unit/test_vlm_retry_integration.py @@ -19,7 +19,6 @@ import pytest - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- From 6f9439b2fef587e8edf296cbcde957c5d12750f6 Mon Sep 17 00:00:00 2001 From: Sergii Nemesh Date: Sat, 28 Mar 2026 12:05:21 +0100 Subject: [PATCH 10/10] style: format test_embedding_retry_integration.py --- tests/unit/test_embedding_retry_integration.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_embedding_retry_integration.py b/tests/unit/test_embedding_retry_integration.py index fbfaba1a8..02011b4cb 100644 --- a/tests/unit/test_embedding_retry_integration.py +++ b/tests/unit/test_embedding_retry_integration.py @@ -161,7 +161,6 @@ def test_embed_retries_on_transient_error(self, vikingdb_embedder): errors = [_HttpError(503)] call_count = 0 - def fake_call_api(*args, **kwargs): nonlocal call_count call_count += 1