diff --git a/openviking/models/embedder/__init__.py b/openviking/models/embedder/__init__.py index f69a74e34..bceae2a8f 100644 --- a/openviking/models/embedder/__init__.py +++ b/openviking/models/embedder/__init__.py @@ -13,6 +13,7 @@ - Volcengine: Dense, Sparse, Hybrid - Jina AI: Dense only - Voyage AI: Dense only +- Cohere: Dense only - Google Gemini: Dense only - LiteLLM: Dense only (bridges to OpenRouter, Ollama, vLLM, and many others) """ @@ -25,6 +26,7 @@ HybridEmbedderBase, SparseEmbedderBase, ) +from openviking.models.embedder.cohere_embedders import CohereDenseEmbedder try: from openviking.models.embedder.gemini_embedders import GeminiDenseEmbedder @@ -51,6 +53,8 @@ from openviking.models.embedder.voyage_embedders import VoyageDenseEmbedder __all__ = [ + # Cohere implementations + "CohereDenseEmbedder", # Base classes "EmbedResult", "EmbedderBase", diff --git a/openviking/models/embedder/cohere_embedders.py b/openviking/models/embedder/cohere_embedders.py new file mode 100644 index 000000000..3c5815d26 --- /dev/null +++ b/openviking/models/embedder/cohere_embedders.py @@ -0,0 +1,145 @@ +# Copyright (c) 2026 Antigravity / Dico Angelo +# SPDX-License-Identifier: Apache-2.0 +"""Cohere dense embedder implementation. + +Uses Cohere's Embed API v2 (https://docs.cohere.com/reference/embed). +Supports embed-v4.0 and embed-english-v3.0 models with input_type +for asymmetric retrieval. +""" + +from typing import Any, Dict, List, Optional + +import httpx + +from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult, truncate_and_normalize + +COHERE_MODEL_DIMENSIONS = { + "embed-v4.0": 1536, + "embed-multilingual-v3.0": 1024, + "embed-english-v3.0": 1024, + "embed-multilingual-light-v3.0": 384, + "embed-english-light-v3.0": 384, +} + +# embed-v4.0 supports server-side dimension reduction via output_dimension +COHERE_ALLOWED_DIMENSIONS = { + "embed-v4.0": {256, 512, 1024, 1536}, +} + + +def get_cohere_model_default_dimension(model_name: Optional[str]) -> int: + if not model_name: + return 1024 + return COHERE_MODEL_DIMENSIONS.get(model_name.lower(), 1024) + + +class CohereDenseEmbedder(DenseEmbedderBase): + """Cohere dense embedder. + + Cohere uses its own REST API (not OpenAI-compatible), so we call it + directly via httpx. Supports asymmetric search via input_type. + """ + + def __init__( + self, + model_name: str = "embed-v4.0", + api_key: Optional[str] = None, + api_base: Optional[str] = None, + dimension: Optional[int] = None, + config: Optional[Dict[str, Any]] = None, + ): + super().__init__(model_name, config) + + self.api_key = api_key + self.api_base = (api_base or "https://api.cohere.com").rstrip("/") + + if not self.api_key: + raise ValueError("api_key is required for Cohere provider") + + self._native_dimension = get_cohere_model_default_dimension(model_name) + self._dimension = dimension or self._native_dimension + + # Check if server-side dimension reduction is supported + normalized = model_name.lower() + allowed = COHERE_ALLOWED_DIMENSIONS.get(normalized) + if allowed and dimension is not None and dimension not in allowed: + raise ValueError( + f"Dimension {dimension} not supported for '{model_name}'. " + f"Allowed: {sorted(allowed)}" + ) + + # Prefer server-side output_dimension when the model supports it + self._use_server_dim = ( + allowed is not None + and dimension is not None + and dimension != self._native_dimension + ) + # Fallback to client-side truncation for v3 models + self._needs_truncation = ( + not self._use_server_dim + and dimension is not None + and dimension < self._native_dimension + ) + self._client = httpx.Client( + base_url=self.api_base, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + timeout=60.0, + ) + + def _call_api(self, texts: List[str], input_type: str) -> List[List[float]]: + payload: Dict[str, Any] = { + "model": self.model_name, + "texts": texts, + "input_type": input_type, + "embedding_types": ["float"], + } + if self._use_server_dim: + payload["output_dimension"] = self._dimension + resp = self._client.post("/v2/embed", json=payload) + resp.raise_for_status() + data = resp.json() + return data["embeddings"]["float"] + + def _normalize_vector(self, vector: List[float]) -> List[float]: + """Truncate and renormalize if dimension reduction was requested.""" + if self._needs_truncation: + return truncate_and_normalize(vector, self._dimension) + return vector + + def embed(self, text: str, is_query: bool = False) -> EmbedResult: + input_type = "search_query" if is_query else "search_document" + try: + vectors = self._call_api([text], input_type) + return EmbedResult(dense_vector=self._normalize_vector(vectors[0])) + except httpx.HTTPStatusError as e: + raise RuntimeError(f"Cohere API error: {e.response.status_code} {e.response.text}") from e + except Exception as e: + raise RuntimeError(f"Cohere embedding failed: {e}") from e + + def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: + if not texts: + return [] + input_type = "search_query" if is_query else "search_document" + try: + results: List[EmbedResult] = [] + for i in range(0, len(texts), 96): + batch = texts[i : i + 96] + vectors = self._call_api(batch, input_type) + results.extend( + EmbedResult(dense_vector=self._normalize_vector(v)) for v in vectors + ) + return results + except httpx.HTTPStatusError as e: + raise RuntimeError(f"Cohere API error: {e.response.status_code} {e.response.text}") from e + except Exception as e: + raise RuntimeError(f"Cohere batch embedding failed: {e}") from e + + def close(self): + """Close the httpx client connection pool.""" + self._client.close() + + def get_dimension(self) -> int: + return self._dimension diff --git a/openviking/retrieve/hierarchical_retriever.py b/openviking/retrieve/hierarchical_retriever.py index b60d612af..9a3cbf597 100644 --- a/openviking/retrieve/hierarchical_retriever.py +++ b/openviking/retrieve/hierarchical_retriever.py @@ -48,7 +48,7 @@ class HierarchicalRetriever: MAX_RELATIONS = 5 # Maximum relations per resource SCORE_PROPAGATION_ALPHA = 0.5 # Score propagation coefficient DIRECTORY_DOMINANCE_RATIO = 1.2 # Directory score must exceed max child score - GLOBAL_SEARCH_TOPK = 5 # Global retrieval count + GLOBAL_SEARCH_TOPK = 10 # Global retrieval count (more candidates = better rerank precision) HOTNESS_ALPHA = 0.2 # Weight for hotness score in final ranking (0 = disabled) LEVEL_URI_SUFFIX = {0: ".abstract.md", 1: ".overview.md"} @@ -72,12 +72,26 @@ def __init__( # Use rerank threshold if available, otherwise use a default self.threshold = rerank_config.threshold if rerank_config else 0 - # Initialize rerank client only if config is available + # Initialize rerank client based on provider if rerank_config and rerank_config.is_available(): - self._rerank_client = RerankClient.from_config(rerank_config) - logger.info( - f"[HierarchicalRetriever] Rerank config available, threshold={self.threshold}" - ) + provider = rerank_config._effective_provider() + if provider == "cohere": + from openviking_cli.utils.cohere_rerank import CohereRerankClient + + self._rerank_client = CohereRerankClient( + api_key=rerank_config.api_key, + model=rerank_config.model_name + if rerank_config.model_name != "doubao-seed-rerank" + else "rerank-v3.5", + ) + logger.info( + f"[HierarchicalRetriever] Cohere rerank enabled, threshold={self.threshold}" + ) + else: + self._rerank_client = RerankClient.from_config(rerank_config) + logger.info( + f"[HierarchicalRetriever] VikingDB rerank enabled, threshold={self.threshold}" + ) else: self._rerank_client = None logger.info( diff --git a/openviking_cli/utils/cohere_rerank.py b/openviking_cli/utils/cohere_rerank.py new file mode 100644 index 000000000..e339f7ba4 --- /dev/null +++ b/openviking_cli/utils/cohere_rerank.py @@ -0,0 +1,87 @@ +# Copyright (c) 2026 Antigravity / Dico Angelo +# SPDX-License-Identifier: Apache-2.0 +""" +Cohere Rerank API Client. + +Drop-in replacement for VikingDB RerankClient, using Cohere's Rerank v3.5 API. +Same interface: rerank_batch(query, documents) -> List[float] +""" + +from typing import List, Optional + +import httpx + +from openviking_cli.utils.logger import get_logger + +logger = get_logger(__name__) + + +class CohereRerankClient: + """Cohere Rerank API client — same interface as VikingDB RerankClient.""" + + def __init__( + self, + api_key: str, + model: str = "rerank-v3.5", + api_base: str = "https://api.cohere.com", + ): + self.api_key = api_key + self.model = model + self.api_base = api_base.rstrip("/") + self._client = httpx.Client( + base_url=self.api_base, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + timeout=30.0, + ) + + def rerank_batch(self, query: str, documents: List[str]) -> Optional[List[float]]: + """ + Rerank documents against a query using Cohere Rerank API. + + Args: + query: Query text + documents: List of document texts to rank + + Returns: + List of relevance scores (0-1) in same order as input documents, + or None on failure (caller should fall back to vector scores). + """ + if not documents: + return [] + + try: + resp = self._client.post( + "/v2/rerank", + json={ + "model": self.model, + "query": query, + "documents": documents, + "top_n": len(documents), + "return_documents": False, + }, + ) + resp.raise_for_status() + data = resp.json() + + # Cohere returns results sorted by score desc with index field + # We need to map back to original order + scores = [0.0] * len(documents) + for result in data.get("results", []): + idx = result["index"] + scores[idx] = result["relevance_score"] + + logger.debug(f"[CohereRerank] Reranked {len(documents)} documents") + return scores + + except httpx.HTTPStatusError as e: + logger.error(f"[CohereRerank] API error: {e.response.status_code} {e.response.text}") + return None + except Exception as e: + logger.error(f"[CohereRerank] Rerank failed: {e}") + return None + + def close(self): + self._client.close() diff --git a/openviking_cli/utils/config/embedding_config.py b/openviking_cli/utils/config/embedding_config.py index 52025ecd2..9082d7d3a 100644 --- a/openviking_cli/utils/config/embedding_config.py +++ b/openviking_cli/utils/config/embedding_config.py @@ -103,11 +103,12 @@ def validate_config(self): "gemini", "voyage", "minimax", + "cohere", "litellm", ]: raise ValueError( f"Invalid embedding provider: '{self.provider}'. Must be one of: " - "'openai', 'azure', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'minimax', 'litellm'" + "'openai', 'azure', 'volcengine', 'vikingdb', 'jina', 'ollama', 'gemini', 'voyage', 'minimax', 'cohere', 'litellm'" ) # Provider-specific validation @@ -179,6 +180,10 @@ def validate_config(self): if not self.api_key: raise ValueError("MiniMax provider requires 'api_key' to be set") + elif self.provider == "cohere": + if not self.api_key: + raise ValueError("Cohere provider requires 'api_key' to be set") + elif self.provider == "litellm": # litellm handles auth via env vars or explicit api_key; no strict requirement if not self.dimension: @@ -202,6 +207,13 @@ def get_effective_dimension(self) -> int: return get_voyage_model_default_dimension(self.model) + if provider == "cohere": + from openviking.models.embedder.cohere_embedders import ( + get_cohere_model_default_dimension, + ) + + return get_cohere_model_default_dimension(self.model) + if provider == "gemini": from openviking.models.embedder.gemini_embedders import GeminiDenseEmbedder @@ -298,6 +310,7 @@ def _create_embedder( ValueError: If provider/type combination is not supported """ from openviking.models.embedder import ( + CohereDenseEmbedder, GeminiDenseEmbedder, JinaDenseEmbedder, LiteLLMDenseEmbedder, @@ -464,6 +477,15 @@ def _create_embedder( **({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}), }, ), + ("cohere", "dense"): ( + CohereDenseEmbedder, + lambda cfg: { + "model_name": cfg.model, + "api_key": cfg.api_key, + "api_base": cfg.api_base, + "dimension": cfg.dimension, + }, + ), ("litellm", "dense"): ( LiteLLMDenseEmbedder, lambda cfg: { diff --git a/openviking_cli/utils/config/rerank_config.py b/openviking_cli/utils/config/rerank_config.py index 0a4f3928f..ce11c237b 100644 --- a/openviking_cli/utils/config/rerank_config.py +++ b/openviking_cli/utils/config/rerank_config.py @@ -6,10 +6,11 @@ class RerankConfig(BaseModel): - """Configuration for rerank API (VikingDB or OpenAI-compatible providers).""" + """Configuration for rerank API. Supports VikingDB, Cohere, OpenAI-compatible, and LiteLLM providers.""" - provider: str = Field( - default="vikingdb", description="Rerank provider: 'vikingdb', 'openai', or 'litellm'" + provider: Optional[str] = Field( + default=None, + description="Rerank provider: 'vikingdb', 'cohere', 'openai', or 'litellm'. Auto-detected from config if omitted.", ) # VikingDB fields @@ -21,13 +22,13 @@ class RerankConfig(BaseModel): model_name: str = Field(default="doubao-seed-rerank", description="Rerank model name") model_version: str = Field(default="251028", description="Rerank model version") - # OpenAI-compatible fields + # Shared / OpenAI-compatible / Cohere fields api_key: Optional[str] = Field( - default=None, description="Bearer token for OpenAI-compatible providers" + default=None, description="API key (Cohere Bearer token or OpenAI-compatible providers)" ) api_base: Optional[str] = Field(default=None, description="Custom endpoint URL") model: Optional[str] = Field( - default=None, description="Model name for OpenAI-compatible providers" + default=None, description="Model name for OpenAI-compatible or LiteLLM providers" ) threshold: float = Field( @@ -36,25 +37,44 @@ class RerankConfig(BaseModel): model_config = {"extra": "forbid"} + def _effective_provider(self) -> Optional[str]: + """Auto-detect provider from config fields when not explicitly set.""" + if self.provider: + return self.provider.lower() + if self.api_key and self.api_base: + return "openai" + if self.api_key: + return "cohere" + if self.ak and self.sk: + return "vikingdb" + return None + @model_validator(mode="after") def validate_provider_fields(self) -> "RerankConfig": - allowed = ["vikingdb", "openai", "litellm"] - if self.provider not in allowed: - raise ValueError(f"Rerank provider must be one of {allowed}, got '{self.provider}'") - if self.provider == "openai": + provider = self._effective_provider() + if provider and provider not in ["vikingdb", "cohere", "openai", "litellm"]: + raise ValueError( + f"Rerank provider must be one of ['vikingdb', 'cohere', 'openai', 'litellm'], got '{provider}'" + ) + if provider == "openai": if not self.api_key or not self.api_base: raise ValueError( "OpenAI-compatible rerank provider requires 'api_key' and 'api_base'" ) - if self.provider == "litellm": + if provider == "litellm": if not self.model: raise ValueError("LiteLLM rerank provider requires 'model'") return self def is_available(self) -> bool: """Check if rerank is configured.""" - if self.provider == "openai": + p = self._effective_provider() + if p == "cohere": + return self.api_key is not None + if p == "openai": return self.api_key is not None and self.api_base is not None - if self.provider == "litellm": + if p == "litellm": return self.model is not None - return self.ak is not None and self.sk is not None + if p == "vikingdb": + return self.ak is not None and self.sk is not None + return False diff --git a/tests/unit/test_cohere_embedder.py b/tests/unit/test_cohere_embedder.py new file mode 100644 index 000000000..05706b45c --- /dev/null +++ b/tests/unit/test_cohere_embedder.py @@ -0,0 +1,222 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for Cohere embedder support.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from openviking.models.embedder import CohereDenseEmbedder +from openviking.models.embedder.cohere_embedders import ( + COHERE_ALLOWED_DIMENSIONS, + COHERE_MODEL_DIMENSIONS, + get_cohere_model_default_dimension, +) + + +class TestCohereDenseEmbedder: + """Test cases for CohereDenseEmbedder.""" + + def test_init_requires_api_key(self): + with pytest.raises(ValueError, match="api_key is required"): + CohereDenseEmbedder(model_name="embed-v4.0") + + def test_init_with_defaults(self): + embedder = CohereDenseEmbedder( + model_name="embed-v4.0", + api_key="cohere-key", + ) + assert embedder.api_key == "cohere-key" + assert embedder.api_base == "https://api.cohere.com" + assert embedder.get_dimension() == 1536 # embed-v4.0 native + + def test_model_dimensions_constant(self): + assert COHERE_MODEL_DIMENSIONS["embed-v4.0"] == 1536 + assert COHERE_MODEL_DIMENSIONS["embed-english-v3.0"] == 1024 + assert COHERE_MODEL_DIMENSIONS["embed-multilingual-v3.0"] == 1024 + assert COHERE_MODEL_DIMENSIONS["embed-english-light-v3.0"] == 384 + + def test_allowed_dimensions_v4(self): + assert COHERE_ALLOWED_DIMENSIONS["embed-v4.0"] == {256, 512, 1024, 1536} + + def test_default_dimension_helper(self): + assert get_cohere_model_default_dimension("embed-v4.0") == 1536 + assert get_cohere_model_default_dimension("embed-english-v3.0") == 1024 + assert get_cohere_model_default_dimension(None) == 1024 + assert get_cohere_model_default_dimension("unknown-model") == 1024 + + def test_custom_dimension_v4(self): + embedder = CohereDenseEmbedder( + model_name="embed-v4.0", + api_key="cohere-key", + dimension=1024, + ) + assert embedder.get_dimension() == 1024 + + def test_invalid_dimension_for_v4(self): + with pytest.raises(ValueError, match="not supported"): + CohereDenseEmbedder( + model_name="embed-v4.0", + api_key="cohere-key", + dimension=768, + ) + + def test_v3_model_allows_truncation(self): + """v3 models don't have server-side dim reduction, so any smaller dim is OK (client truncation).""" + embedder = CohereDenseEmbedder( + model_name="embed-english-v3.0", + api_key="cohere-key", + dimension=512, + ) + assert embedder.get_dimension() == 512 + assert embedder._needs_truncation is True + + def test_server_dim_for_v4(self): + """embed-v4.0 should use server-side output_dimension when dimension differs from native.""" + embedder = CohereDenseEmbedder( + model_name="embed-v4.0", + api_key="cohere-key", + dimension=1024, + ) + assert embedder._use_server_dim is True + assert embedder._needs_truncation is False + + @patch("openviking.models.embedder.cohere_embedders.httpx.Client") + def test_embed_single_text(self, mock_client_class): + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_response = MagicMock() + mock_response.json.return_value = { + "embeddings": {"float": [[0.1] * 1024]} + } + mock_response.raise_for_status = MagicMock() + mock_client.post.return_value = mock_response + + embedder = CohereDenseEmbedder( + model_name="embed-english-v3.0", + api_key="cohere-key", + ) + result = embedder.embed("Hello world") + + assert result.dense_vector is not None + assert len(result.dense_vector) == 1024 + call_args = mock_client.post.call_args + payload = call_args[1]["json"] + assert payload["model"] == "embed-english-v3.0" + assert payload["input_type"] == "search_document" + assert payload["texts"] == ["Hello world"] + + @patch("openviking.models.embedder.cohere_embedders.httpx.Client") + def test_embed_query_uses_search_query_type(self, mock_client_class): + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_response = MagicMock() + mock_response.json.return_value = { + "embeddings": {"float": [[0.1] * 1024]} + } + mock_response.raise_for_status = MagicMock() + mock_client.post.return_value = mock_response + + embedder = CohereDenseEmbedder( + model_name="embed-english-v3.0", + api_key="cohere-key", + ) + embedder.embed("search query", is_query=True) + + payload = mock_client.post.call_args[1]["json"] + assert payload["input_type"] == "search_query" + + @patch("openviking.models.embedder.cohere_embedders.httpx.Client") + def test_embed_v4_sends_output_dimension(self, mock_client_class): + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_response = MagicMock() + mock_response.json.return_value = { + "embeddings": {"float": [[0.1] * 1024]} + } + mock_response.raise_for_status = MagicMock() + mock_client.post.return_value = mock_response + + embedder = CohereDenseEmbedder( + model_name="embed-v4.0", + api_key="cohere-key", + dimension=1024, + ) + embedder.embed("Hello world") + + payload = mock_client.post.call_args[1]["json"] + assert payload["output_dimension"] == 1024 + + @patch("openviking.models.embedder.cohere_embedders.httpx.Client") + def test_embed_batch(self, mock_client_class): + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_response = MagicMock() + mock_response.json.return_value = { + "embeddings": {"float": [[0.1] * 1024, [0.2] * 1024]} + } + mock_response.raise_for_status = MagicMock() + mock_client.post.return_value = mock_response + + embedder = CohereDenseEmbedder( + model_name="embed-english-v3.0", + api_key="cohere-key", + ) + results = embedder.embed_batch(["Hello", "World"]) + + assert len(results) == 2 + assert len(results[0].dense_vector) == 1024 + + @patch("openviking.models.embedder.cohere_embedders.httpx.Client") + def test_embed_batch_empty(self, mock_client_class): + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + embedder = CohereDenseEmbedder( + model_name="embed-english-v3.0", + api_key="cohere-key", + ) + assert embedder.embed_batch([]) == [] + mock_client.post.assert_not_called() + + @patch("openviking.models.embedder.cohere_embedders.httpx.Client") + def test_embed_api_error(self, mock_client_class): + import httpx + + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_request = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + mock_client.post.return_value = mock_response + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "401 Unauthorized", + request=mock_request, + response=mock_response, + ) + + embedder = CohereDenseEmbedder( + model_name="embed-english-v3.0", + api_key="bad-key", + ) + + with pytest.raises(RuntimeError, match="Cohere API error"): + embedder.embed("Hello world") + + @patch("openviking.models.embedder.cohere_embedders.httpx.Client") + def test_close(self, mock_client_class): + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + embedder = CohereDenseEmbedder( + model_name="embed-english-v3.0", + api_key="cohere-key", + ) + embedder.close() + mock_client.close.assert_called_once() diff --git a/tests/unit/test_cohere_rerank.py b/tests/unit/test_cohere_rerank.py new file mode 100644 index 000000000..ee250a4fd --- /dev/null +++ b/tests/unit/test_cohere_rerank.py @@ -0,0 +1,127 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for Cohere rerank client.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from openviking_cli.utils.cohere_rerank import CohereRerankClient + + +class TestCohereRerankClient: + """Test cases for CohereRerankClient.""" + + @patch("openviking_cli.utils.cohere_rerank.httpx.Client") + def test_rerank_batch_basic(self, mock_client_class): + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [ + {"index": 1, "relevance_score": 0.95}, + {"index": 0, "relevance_score": 0.42}, + {"index": 2, "relevance_score": 0.10}, + ] + } + mock_response.raise_for_status = MagicMock() + mock_client.post.return_value = mock_response + + client = CohereRerankClient(api_key="test-key") + scores = client.rerank_batch("What is UCW?", ["doc A", "doc B", "doc C"]) + + assert scores == [0.42, 0.95, 0.10] + payload = mock_client.post.call_args[1]["json"] + assert payload["model"] == "rerank-v3.5" + assert payload["query"] == "What is UCW?" + assert payload["documents"] == ["doc A", "doc B", "doc C"] + assert payload["return_documents"] is False + + @patch("openviking_cli.utils.cohere_rerank.httpx.Client") + def test_rerank_batch_empty(self, mock_client_class): + client = CohereRerankClient(api_key="test-key") + assert client.rerank_batch("query", []) == [] + + @patch("openviking_cli.utils.cohere_rerank.httpx.Client") + def test_rerank_batch_api_error_returns_none(self, mock_client_class): + import httpx + + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "401", request=MagicMock(), response=mock_response + ) + mock_client.post.return_value = mock_response + + client = CohereRerankClient(api_key="bad-key") + result = client.rerank_batch("query", ["doc"]) + assert result is None + + @patch("openviking_cli.utils.cohere_rerank.httpx.Client") + def test_rerank_preserves_original_order(self, mock_client_class): + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + # Cohere returns sorted by score desc, we must map back to original order + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [ + {"index": 2, "relevance_score": 0.99}, + {"index": 0, "relevance_score": 0.50}, + {"index": 1, "relevance_score": 0.01}, + ] + } + mock_response.raise_for_status = MagicMock() + mock_client.post.return_value = mock_response + + client = CohereRerankClient(api_key="test-key") + scores = client.rerank_batch("q", ["first", "second", "third"]) + + assert scores[0] == 0.50 # "first" was index 0 + assert scores[1] == 0.01 # "second" was index 1 + assert scores[2] == 0.99 # "third" was index 2 + + @patch("openviking_cli.utils.cohere_rerank.httpx.Client") + def test_close(self, mock_client_class): + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + client = CohereRerankClient(api_key="test-key") + client.close() + mock_client.close.assert_called_once() + + +class TestRerankConfig: + """Test RerankConfig provider detection.""" + + def test_cohere_provider_auto_detected(self): + from openviking_cli.utils.config.rerank_config import RerankConfig + + config = RerankConfig(api_key="cohere-key") + assert config._effective_provider() == "cohere" + assert config.is_available() is True + + def test_vikingdb_provider_auto_detected(self): + from openviking_cli.utils.config.rerank_config import RerankConfig + + config = RerankConfig(ak="ak", sk="sk") + assert config._effective_provider() == "vikingdb" + assert config.is_available() is True + + def test_explicit_provider_overrides(self): + from openviking_cli.utils.config.rerank_config import RerankConfig + + config = RerankConfig(provider="cohere", api_key="key", ak="ak", sk="sk") + assert config._effective_provider() == "cohere" + + def test_empty_config_not_available(self): + from openviking_cli.utils.config.rerank_config import RerankConfig + + config = RerankConfig() + assert config.is_available() is False + assert config._effective_provider() is None