Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions openviking/models/embedder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
- Volcengine: Dense, Sparse, Hybrid
- Jina AI: Dense only
- Voyage AI: Dense only
- Cohere: Dense only
- MiniMax: Dense only
"""

from openviking.models.embedder.base import (
Expand All @@ -23,6 +25,7 @@
HybridEmbedderBase,
SparseEmbedderBase,
)
from openviking.models.embedder.cohere_embedders import CohereDenseEmbedder
from openviking.models.embedder.jina_embedders import JinaDenseEmbedder
from openviking.models.embedder.minimax_embedders import MinimaxDenseEmbedder
from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder
Expand All @@ -39,6 +42,8 @@
from openviking.models.embedder.voyage_embedders import VoyageDenseEmbedder

__all__ = [
# Cohere implementations
"CohereDenseEmbedder",
# Base classes
"EmbedResult",
"EmbedderBase",
Expand Down
145 changes: 145 additions & 0 deletions openviking/models/embedder/cohere_embedders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (c) 2026 Antigravity / Dico Angelo
# SPDX-License-Identifier: Apache-2.0
"""Cohere dense embedder implementation.

Uses Cohere's Embed API v2 (https://docs.cohere.com/reference/embed).
Supports embed-v4.0 and embed-english-v3.0 models with input_type
for asymmetric retrieval.
"""

from typing import Any, Dict, List, Optional

import httpx

from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult, truncate_and_normalize

# Native (default) output dimension for each supported Cohere embedding model.
COHERE_MODEL_DIMENSIONS = {
    "embed-v4.0": 1536,
    "embed-multilingual-v3.0": 1024,
    "embed-english-v3.0": 1024,
    "embed-multilingual-light-v3.0": 384,
    "embed-english-light-v3.0": 384,
}

# Models that accept server-side dimension reduction via the
# ``output_dimension`` request field, mapped to the values they allow.
COHERE_ALLOWED_DIMENSIONS = {
    "embed-v4.0": {256, 512, 1024, 1536},
}


def get_cohere_model_default_dimension(model_name: Optional[str]) -> int:
    """Return the native embedding dimension for ``model_name``.

    The lookup is case-insensitive; unknown or missing model names fall
    back to 1024.
    """
    if model_name:
        return COHERE_MODEL_DIMENSIONS.get(model_name.lower(), 1024)
    return 1024


class CohereDenseEmbedder(DenseEmbedderBase):
    """Cohere dense embedder.

    Cohere uses its own REST API (not OpenAI-compatible), so we call it
    directly via httpx. Supports asymmetric search via input_type
    ("search_query" for queries, "search_document" for corpus text).
    """

    # Cohere's Embed API rejects requests with more than 96 texts, so
    # embed_batch splits its input into chunks of this size.
    MAX_TEXTS_PER_REQUEST = 96

    def __init__(
        self,
        model_name: str = "embed-v4.0",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        dimension: Optional[int] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Create a Cohere dense embedder.

        Args:
            model_name: Cohere embedding model (see COHERE_MODEL_DIMENSIONS).
            api_key: Cohere API key. Required.
            api_base: API root URL; defaults to the public Cohere endpoint.
            dimension: Desired output dimension. For models with server-side
                reduction (embed-v4.0) it must be one of the allowed values;
                for v3 models a value smaller than the native dimension is
                applied by client-side truncation + renormalization.
            config: Extra provider config forwarded to the base class.

        Raises:
            ValueError: If api_key is missing, or the requested dimension
                cannot be produced for this model.
        """
        super().__init__(model_name, config)

        self.api_key = api_key
        self.api_base = (api_base or "https://api.cohere.com").rstrip("/")

        if not self.api_key:
            raise ValueError("api_key is required for Cohere provider")

        self._native_dimension = get_cohere_model_default_dimension(model_name)
        self._dimension = dimension or self._native_dimension

        # Check if server-side dimension reduction is supported
        normalized = model_name.lower()
        allowed = COHERE_ALLOWED_DIMENSIONS.get(normalized)
        if allowed and dimension is not None and dimension not in allowed:
            raise ValueError(
                f"Dimension {dimension} not supported for '{model_name}'. "
                f"Allowed: {sorted(allowed)}"
            )

        # Prefer server-side output_dimension when the model supports it
        self._use_server_dim = (
            allowed is not None
            and dimension is not None
            and dimension != self._native_dimension
        )
        # Fallback to client-side truncation for v3 models
        self._needs_truncation = (
            not self._use_server_dim
            and dimension is not None
            and dimension < self._native_dimension
        )
        # Bug fix: previously a dimension LARGER than the native size on a
        # model without server-side reduction was accepted silently, making
        # get_dimension() disagree with the vectors actually returned.
        if (
            not self._use_server_dim
            and dimension is not None
            and dimension > self._native_dimension
        ):
            raise ValueError(
                f"Dimension {dimension} exceeds native dimension "
                f"{self._native_dimension} of '{model_name}'"
            )
        # Persistent client so the HTTP connection pool is reused across
        # calls; callers should invoke close() when finished.
        self._client = httpx.Client(
            base_url=self.api_base,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            timeout=60.0,
        )

    def _call_api(self, texts: List[str], input_type: str) -> List[List[float]]:
        """POST texts to /v2/embed and return the raw float vectors.

        Raises httpx.HTTPStatusError on non-2xx responses; callers wrap it.
        """
        payload: Dict[str, Any] = {
            "model": self.model_name,
            "texts": texts,
            "input_type": input_type,
            "embedding_types": ["float"],
        }
        if self._use_server_dim:
            payload["output_dimension"] = self._dimension
        resp = self._client.post("/v2/embed", json=payload)
        resp.raise_for_status()
        data = resp.json()
        return data["embeddings"]["float"]

    def _normalize_vector(self, vector: List[float]) -> List[float]:
        """Truncate and renormalize if client-side dimension reduction was requested."""
        if self._needs_truncation:
            return truncate_and_normalize(vector, self._dimension)
        return vector

    def embed(self, text: str, is_query: bool = False) -> EmbedResult:
        """Embed a single text.

        Args:
            text: Text to embed.
            is_query: True for search queries (asymmetric retrieval).

        Returns:
            EmbedResult carrying the dense vector.

        Raises:
            RuntimeError: On any API or processing failure.
        """
        input_type = "search_query" if is_query else "search_document"
        try:
            vectors = self._call_api([text], input_type)
            return EmbedResult(dense_vector=self._normalize_vector(vectors[0]))
        except httpx.HTTPStatusError as e:
            raise RuntimeError(f"Cohere API error: {e.response.status_code} {e.response.text}") from e
        except Exception as e:
            raise RuntimeError(f"Cohere embedding failed: {e}") from e

    def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
        """Embed a list of texts, chunking to respect Cohere's per-request limit.

        Results are returned in the same order as the input; an empty input
        yields an empty list without any API call.

        Raises:
            RuntimeError: On any API or processing failure.
        """
        if not texts:
            return []
        input_type = "search_query" if is_query else "search_document"
        try:
            results: List[EmbedResult] = []
            step = self.MAX_TEXTS_PER_REQUEST
            for i in range(0, len(texts), step):
                batch = texts[i : i + step]
                vectors = self._call_api(batch, input_type)
                results.extend(
                    EmbedResult(dense_vector=self._normalize_vector(v)) for v in vectors
                )
            return results
        except httpx.HTTPStatusError as e:
            raise RuntimeError(f"Cohere API error: {e.response.status_code} {e.response.text}") from e
        except Exception as e:
            raise RuntimeError(f"Cohere batch embedding failed: {e}") from e

    def close(self):
        """Close the httpx client connection pool."""
        self._client.close()

    def get_dimension(self) -> int:
        """Return the effective output dimension of produced vectors."""
        return self._dimension
26 changes: 20 additions & 6 deletions openviking/retrieve/hierarchical_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class HierarchicalRetriever:
MAX_RELATIONS = 5 # Maximum relations per resource
SCORE_PROPAGATION_ALPHA = 0.5 # Score propagation coefficient
DIRECTORY_DOMINANCE_RATIO = 1.2 # Directory score must exceed max child score
GLOBAL_SEARCH_TOPK = 5 # Global retrieval count
GLOBAL_SEARCH_TOPK = 10 # Global retrieval count (more candidates = better rerank precision)
HOTNESS_ALPHA = 0.2 # Weight for hotness score in final ranking (0 = disabled)
LEVEL_URI_SUFFIX = {0: ".abstract.md", 1: ".overview.md"}

Expand All @@ -71,12 +71,26 @@ def __init__(
# Use rerank threshold if available, otherwise use a default
self.threshold = rerank_config.threshold if rerank_config else 0

# Initialize rerank client only if config is available
# Initialize rerank client based on provider
if rerank_config and rerank_config.is_available():
self._rerank_client = RerankClient.from_config(rerank_config)
logger.info(
f"[HierarchicalRetriever] Rerank config available, threshold={self.threshold}"
)
provider = rerank_config._effective_provider()
if provider == "cohere":
from openviking_cli.utils.cohere_rerank import CohereRerankClient

self._rerank_client = CohereRerankClient(
api_key=rerank_config.api_key,
model=rerank_config.model_name
if rerank_config.model_name != "doubao-seed-rerank"
else "rerank-v3.5",
)
logger.info(
f"[HierarchicalRetriever] Cohere rerank enabled, threshold={self.threshold}"
)
else:
self._rerank_client = RerankClient.from_config(rerank_config)
logger.info(
f"[HierarchicalRetriever] VikingDB rerank enabled, threshold={self.threshold}"
)
else:
self._rerank_client = None
logger.info(
Expand Down
87 changes: 87 additions & 0 deletions openviking_cli/utils/cohere_rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) 2026 Antigravity / Dico Angelo
# SPDX-License-Identifier: Apache-2.0
"""
Cohere Rerank API Client.

Drop-in replacement for VikingDB RerankClient, using Cohere's Rerank v3.5 API.
Same interface: rerank_batch(query, documents) -> List[float]
"""

from typing import List, Optional

import httpx

from openviking_cli.utils.logger import get_logger

logger = get_logger(__name__)


class CohereRerankClient:
    """Cohere Rerank API client — same interface as VikingDB RerankClient."""

    def __init__(
        self,
        api_key: str,
        model: str = "rerank-v3.5",
        api_base: str = "https://api.cohere.com",
    ):
        """Store credentials and open a persistent HTTP client.

        Args:
            api_key: Cohere API key used as a bearer token.
            model: Rerank model identifier.
            api_base: API root URL; a trailing slash is stripped.
        """
        self.api_key = api_key
        self.model = model
        self.api_base = api_base.rstrip("/")
        # One long-lived client so the connection pool is reused across calls.
        self._client = httpx.Client(
            base_url=self.api_base,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            timeout=30.0,
        )

    def rerank_batch(self, query: str, documents: List[str]) -> Optional[List[float]]:
        """Rerank documents against a query using Cohere Rerank API.

        Args:
            query: Query text.
            documents: List of document texts to rank.

        Returns:
            Relevance scores (0-1) aligned with the input order, an empty
            list for empty input, or None on failure (caller should fall
            back to vector scores).
        """
        if not documents:
            return []

        # Request all documents back; we only need scores, not the texts.
        payload = {
            "model": self.model,
            "query": query,
            "documents": documents,
            "top_n": len(documents),
            "return_documents": False,
        }

        try:
            response = self._client.post("/v2/rerank", json=payload)
            response.raise_for_status()
            body = response.json()

            # Cohere returns results sorted by score descending, each tagged
            # with its original index — rebuild scores in input order.
            ordered = [0.0] * len(documents)
            for entry in body.get("results", []):
                ordered[entry["index"]] = entry["relevance_score"]

            logger.debug(f"[CohereRerank] Reranked {len(documents)} documents")
            return ordered

        except httpx.HTTPStatusError as e:
            logger.error(f"[CohereRerank] API error: {e.response.status_code} {e.response.text}")
            return None
        except Exception as e:
            logger.error(f"[CohereRerank] Rerank failed: {e}")
            return None

    def close(self):
        """Close the underlying httpx client connection pool."""
        self._client.close()
24 changes: 23 additions & 1 deletion openviking_cli/utils/config/embedding_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,11 @@ def validate_config(self):
"ollama",
"voyage",
"minimax",
"cohere",
]:
raise ValueError(
f"Invalid embedding provider: '{self.provider}'. Must be one of: "
"'openai', 'volcengine', 'vikingdb', 'jina', 'ollama', 'voyage', 'minimax'"
"'openai', 'volcengine', 'vikingdb', 'jina', 'ollama', 'voyage', 'minimax', 'cohere'"
)

# Provider-specific validation
Expand Down Expand Up @@ -143,6 +144,10 @@ def validate_config(self):
if not self.api_key:
raise ValueError("MiniMax provider requires 'api_key' to be set")

elif self.provider == "cohere":
if not self.api_key:
raise ValueError("Cohere provider requires 'api_key' to be set")

return self

def get_effective_dimension(self) -> int:
Expand All @@ -158,6 +163,13 @@ def get_effective_dimension(self) -> int:

return get_voyage_model_default_dimension(self.model)

if provider == "cohere":
from openviking.models.embedder.cohere_embedders import (
get_cohere_model_default_dimension,
)

return get_cohere_model_default_dimension(self.model)

return 2048


Expand Down Expand Up @@ -212,6 +224,7 @@ def _create_embedder(
ValueError: If provider/type combination is not supported
"""
from openviking.models.embedder import (
CohereDenseEmbedder,
JinaDenseEmbedder,
MinimaxDenseEmbedder,
OpenAIDenseEmbedder,
Expand Down Expand Up @@ -347,6 +360,15 @@ def _create_embedder(
**({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}),
},
),
("cohere", "dense"): (
CohereDenseEmbedder,
lambda cfg: {
"model_name": cfg.model,
"api_key": cfg.api_key,
"api_base": cfg.api_base,
"dimension": cfg.dimension,
},
),
}

key = (provider, embedder_type)
Expand Down
Loading
Loading