diff --git a/examples/python/.env.example b/examples/python/.env.example new file mode 100644 index 00000000..6841df1f --- /dev/null +++ b/examples/python/.env.example @@ -0,0 +1,3 @@ +MOSS_PROJECT_ID=your_project_id +MOSS_PROJECT_KEY=your_project_key +COHERE_API_KEY=your_cohere_api_key diff --git a/examples/python/requirements.txt b/examples/python/requirements.txt index b8470339..cd9589b3 100644 --- a/examples/python/requirements.txt +++ b/examples/python/requirements.txt @@ -1,3 +1,4 @@ -moss>=1.0.0 +-e sdks/python/sdk +cohere>=5.0.0 python-dotenv openai \ No newline at end of file diff --git a/examples/python/reranking_sample.py b/examples/python/reranking_sample.py new file mode 100644 index 00000000..ab92d809 --- /dev/null +++ b/examples/python/reranking_sample.py @@ -0,0 +1,75 @@ +import asyncio +import json +import os + +from dotenv import load_dotenv +from moss import MossClient, DocumentInfo, QueryOptions, RerankOptions + +load_dotenv() + +INDEX_NAME = "rerank-demo-full" + + +async def setup_index(client): + """Delete old index, create fresh one with all FAQ data.""" + faqs_path = os.path.join(os.path.dirname(__file__), "faqs.json") + with open(faqs_path, "r") as f: + faqs = json.load(f) + + docs = [ + DocumentInfo( + id=faq["id"], + text=faq["text"], + metadata={k: str(v) for k, v in faq.get("metadata", {}).items()}, + ) + for faq in faqs + ] + + print(f"Setting up index '{INDEX_NAME}'") + try: + await client.delete_index(INDEX_NAME) + print("Deleted old index.") + except Exception: + pass + + await client.create_index(INDEX_NAME, docs) + print("Index created.") + + await client.load_index(INDEX_NAME) + print("Index loaded.\n") + + +async def main(): + client = MossClient(os.getenv("MOSS_PROJECT_ID"), os.getenv("MOSS_PROJECT_KEY")) + + await setup_index(client) + + print("Without Reranking") + results = await client.query( + INDEX_NAME, + "How to get discount?", + QueryOptions(top_k=5, alpha=0.8), + ) + for i, doc in enumerate(results.docs): + print(f" 
{i + 1}. [{doc.score:.3f}] {doc.text[:100]}...") + + print("\nWith Cohere Reranking") + results = await client.query( + INDEX_NAME, + "How to get discount?", + QueryOptions( + top_k=10, + alpha=0.8, + rerank=RerankOptions( + provider="cohere", + api_key=os.getenv("COHERE_API_KEY"), + top_n=5, + ), + ), + ) + for i, doc in enumerate(results.docs): + print(f" {i + 1}. [{doc.score:.3f}] {doc.text[:100]}...") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sdks/python/sdk/pyproject.toml b/sdks/python/sdk/pyproject.toml index a656fcc7..dcc3550b 100644 --- a/sdks/python/sdk/pyproject.toml +++ b/sdks/python/sdk/pyproject.toml @@ -32,7 +32,11 @@ dependencies = [ ] [project.optional-dependencies] +cohere = [ + "cohere>=5.0.0", +] dev = [ + "cohere>=5.0.0", "pytest>=8.4.2", "pytest-asyncio>=1.2.0", "tox>=4.0.0", diff --git a/sdks/python/sdk/src/moss/__init__.py b/sdks/python/sdk/src/moss/__init__.py index 13e3ef2d..3bc72043 100644 --- a/sdks/python/sdk/src/moss/__init__.py +++ b/sdks/python/sdk/src/moss/__init__.py @@ -1,62 +1,64 @@ -""" -Moss Semantic Search SDK - -Powerful Python SDK for semantic search using state-of-the-art embedding models. 
- -Example: - ```python - from moss import MossClient, DocumentInfo - - client = MossClient('your-project-id', 'your-project-key') - - docs = [DocumentInfo(id="1", text="Example document")] - - result = await client.create_index('my-index', docs, 'moss-minilm') - - await client.load_index('my-index') - results = await client.query('my-index', 'search query') - ``` -""" - -from moss_core import ( - DocumentInfo, - GetDocumentsOptions, - IndexInfo, - IndexStatus, - IndexStatusValues, - ModelRef, - MutationOptions, - MutationResult, - JobStatus, - JobPhase, - JobProgress, - JobStatusResponse, - QueryOptions, - QueryResultDocumentInfo, - SearchResult, -) - -from .client.moss_client import MossClient - -__version__ = "1.0.0b19" - -__all__ = [ - "MossClient", - # Core data types - "DocumentInfo", - "GetDocumentsOptions", - "IndexInfo", - "SearchResult", - "QueryResultDocumentInfo", - "ModelRef", - "IndexStatus", - "IndexStatusValues", - "QueryOptions", - # Mutation types - "MutationResult", - "MutationOptions", - "JobStatus", - "JobPhase", - "JobProgress", - "JobStatusResponse", -] +""" +Moss Semantic Search SDK + +Powerful Python SDK for semantic search using state-of-the-art embedding models. 
+ +Example: + ```python + from moss import MossClient, DocumentInfo + + client = MossClient('your-project-id', 'your-project-key') + + docs = [DocumentInfo(id="1", text="Example document")] + + result = await client.create_index('my-index', docs, 'moss-minilm') + + await client.load_index('my-index') + results = await client.query('my-index', 'search query') + ``` +""" + +from moss_core import ( + DocumentInfo, + GetDocumentsOptions, + IndexInfo, + IndexStatus, + IndexStatusValues, + ModelRef, + MutationOptions, + MutationResult, + JobStatus, + JobPhase, + JobProgress, + JobStatusResponse, + QueryResultDocumentInfo, + SearchResult, +) + +from .client.models import QueryOptions, RerankOptions +from .client.moss_client import MossClient + +__version__ = "1.0.0b19" + +__all__ = [ + "MossClient", + # Core data types + "DocumentInfo", + "GetDocumentsOptions", + "IndexInfo", + "SearchResult", + "QueryResultDocumentInfo", + "ModelRef", + "IndexStatus", + "IndexStatusValues", + "QueryOptions", + # Mutation types + "MutationResult", + "MutationOptions", + "JobStatus", + "JobPhase", + "JobProgress", + "JobStatusResponse", + # Reranking + "RerankOptions", +] diff --git a/sdks/python/sdk/src/moss/__init__.pyi b/sdks/python/sdk/src/moss/__init__.pyi index 6aa41699..e86f183c 100644 --- a/sdks/python/sdk/src/moss/__init__.pyi +++ b/sdks/python/sdk/src/moss/__init__.pyi @@ -1,248 +1,266 @@ -from __future__ import annotations - -from typing import ClassVar, Dict, List, Optional, Sequence - - -class MossClient: - """Semantic search client for vector similarity operations.""" - - DEFAULT_MODEL_ID: ClassVar[str] - - def __init__(self, project_id: str, project_key: str) -> None: ... - - async def create_index( - self, - name: str, - docs: List[DocumentInfo], - model_id: Optional[str] = ..., - ) -> MutationResult: ... - - async def add_docs( - self, - name: str, - docs: List[DocumentInfo], - options: Optional[MutationOptions] = None, - ) -> MutationResult: ... 
- - async def delete_docs( - self, - name: str, - doc_ids: List[str], - ) -> MutationResult: ... - - async def get_job_status(self, job_id: str) -> JobStatusResponse: ... - - async def get_index(self, name: str) -> IndexInfo: ... - - async def list_indexes(self) -> List[IndexInfo]: ... - - async def delete_index(self, name: str) -> bool: ... - - async def get_docs( - self, - name: str, - options: Optional[GetDocumentsOptions] = None, - ) -> List[DocumentInfo]: ... - - async def load_index( - self, - name: str, - auto_refresh: bool = False, - polling_interval_in_seconds: int = 600, - ) -> str: ... - - async def unload_index(self, name: str) -> None: ... - - async def query( - self, - name: str, - query: str, - options: Optional[QueryOptions] = None, - ) -> SearchResult: ... - - -class MutationResult: - """Return value from create_index/add_docs/delete_docs.""" - - job_id: str - index_name: str - doc_count: int - - -class MutationOptions: - """Options for add_docs (e.g. upsert behavior).""" - - upsert: Optional[bool] - - def __init__(self, upsert: Optional[bool] = None) -> None: ... - - -class GetDocumentsOptions: - """Options for get_docs (e.g. filter by document IDs).""" - - doc_ids: Optional[List[str]] - - def __init__(self, doc_ids: Optional[List[str]] = None) -> None: ... 
- - -class JobStatus: - """Enum-like class for job status values.""" - - PENDING_UPLOAD: ClassVar[str] - UPLOADING: ClassVar[str] - BUILDING: ClassVar[str] - COMPLETED: ClassVar[str] - FAILED: ClassVar[str] - - value: str - - -class JobPhase: - """Enum-like class for job phase values.""" - - DOWNLOADING: ClassVar[str] - DESERIALIZING: ClassVar[str] - GENERATING_EMBEDDINGS: ClassVar[str] - BUILDING_INDEX: ClassVar[str] - UPLOADING: ClassVar[str] - CLEANUP: ClassVar[str] - - value: str - - -class JobProgress: - """Progress update for a job.""" - - job_id: str - status: JobStatus - progress: float - current_phase: Optional[JobPhase] - - -class JobStatusResponse: - """Full status response from get_job_status.""" - - job_id: str - status: JobStatus - progress: float - current_phase: Optional[JobPhase] - error: Optional[str] - created_at: str - updated_at: str - completed_at: Optional[str] - - -class ModelRef: - id: str - version: str - def __init__(self, id: str, version: str) -> None: ... - - -class QueryResultDocumentInfo: - id: str - text: str - metadata: Optional[Dict[str, str]] - score: float - def __init__( - self, - id: str, - text: str, - metadata: Optional[Dict[str, str]] = ..., - score: float = ..., - ) -> None: ... - - -class DocumentInfo: - id: str - text: str - metadata: Optional[Dict[str, str]] - embedding: Optional[Sequence[float]] - def __init__( - self, - id: str, - text: str, - metadata: Optional[Dict[str, str]] = ..., - embedding: Optional[Sequence[float]] = ..., - ) -> None: ... - - -class QueryOptions: - embedding: Optional[Sequence[float]] - top_k: Optional[int] - alpha: Optional[float] - filter: Optional[dict] - def __init__( - self, - embedding: Optional[Sequence[float]] = ..., - top_k: Optional[int] = ..., - alpha: Optional[float] = ..., - filter: Optional[dict] = ..., - ) -> None: ... 
- - -class IndexInfo: - id: str - name: str - version: str - status: str - doc_count: int - created_at: str - updated_at: str - model: ModelRef - def __init__( - self, - id: str, - name: str, - version: str, - status: str, - doc_count: int, - created_at: str, - updated_at: str, - model: ModelRef, - ) -> None: ... - - -class SearchResult: - docs: List[QueryResultDocumentInfo] - query: str - index_name: Optional[str] - time_taken_ms: Optional[int] - def __init__( - self, - docs: List[QueryResultDocumentInfo], - query: str, - index_name: Optional[str] = None, - time_taken_ms: Optional[int] = None, - ) -> None: ... - - -class IndexStatus: - NotStarted: ClassVar[str] - Building: ClassVar[str] - Ready: ClassVar[str] - Failed: ClassVar[str] - def __init__(self, value: str) -> None: ... - - -IndexStatusValues: Dict[str, str] - -__version__: str - -__all__ = [ - "MossClient", - "DocumentInfo", - "GetDocumentsOptions", - "IndexInfo", - "SearchResult", - "QueryResultDocumentInfo", - "ModelRef", - "IndexStatus", - "IndexStatusValues", - "QueryOptions", - "MutationResult", - "MutationOptions", - "JobStatus", - "JobPhase", - "JobProgress", - "JobStatusResponse", -] +from __future__ import annotations + +from typing import Any, ClassVar, Dict, List, Optional, Sequence + + +class RerankOptions: + """Configuration for reranking passed to query().""" + + provider: str + top_n: Optional[int] + init_kwargs: Dict[str, Any] + + def __init__( + self, + provider: str, + top_n: Optional[int] = None, + **kwargs: Any, + ) -> None: ... + + +class MossClient: + """Semantic search client for vector similarity operations.""" + + DEFAULT_MODEL_ID: ClassVar[str] + + def __init__(self, project_id: str, project_key: str) -> None: ... + + async def create_index( + self, + name: str, + docs: List[DocumentInfo], + model_id: Optional[str] = ..., + ) -> MutationResult: ... 
+ + async def add_docs( + self, + name: str, + docs: List[DocumentInfo], + options: Optional[MutationOptions] = None, + ) -> MutationResult: ... + + async def delete_docs( + self, + name: str, + doc_ids: List[str], + ) -> MutationResult: ... + + async def get_job_status(self, job_id: str) -> JobStatusResponse: ... + + async def get_index(self, name: str) -> IndexInfo: ... + + async def list_indexes(self) -> List[IndexInfo]: ... + + async def delete_index(self, name: str) -> bool: ... + + async def get_docs( + self, + name: str, + options: Optional[GetDocumentsOptions] = None, + ) -> List[DocumentInfo]: ... + + async def load_index( + self, + name: str, + auto_refresh: bool = False, + polling_interval_in_seconds: int = 600, + ) -> str: ... + + async def unload_index(self, name: str) -> None: ... + + async def query( + self, + name: str, + query: str, + options: Optional[QueryOptions] = None, + ) -> SearchResult: ... + + +class MutationResult: + """Return value from create_index/add_docs/delete_docs.""" + + job_id: str + index_name: str + doc_count: int + + +class MutationOptions: + """Options for add_docs (e.g. upsert behavior).""" + + upsert: Optional[bool] + + def __init__(self, upsert: Optional[bool] = None) -> None: ... + + +class GetDocumentsOptions: + """Options for get_docs (e.g. filter by document IDs).""" + + doc_ids: Optional[List[str]] + + def __init__(self, doc_ids: Optional[List[str]] = None) -> None: ... 
+ + +class JobStatus: + """Enum-like class for job status values.""" + + PENDING_UPLOAD: ClassVar[str] + UPLOADING: ClassVar[str] + BUILDING: ClassVar[str] + COMPLETED: ClassVar[str] + FAILED: ClassVar[str] + + value: str + + +class JobPhase: + """Enum-like class for job phase values.""" + + DOWNLOADING: ClassVar[str] + DESERIALIZING: ClassVar[str] + GENERATING_EMBEDDINGS: ClassVar[str] + BUILDING_INDEX: ClassVar[str] + UPLOADING: ClassVar[str] + CLEANUP: ClassVar[str] + + value: str + + +class JobProgress: + """Progress update for a job.""" + + job_id: str + status: JobStatus + progress: float + current_phase: Optional[JobPhase] + + +class JobStatusResponse: + """Full status response from get_job_status.""" + + job_id: str + status: JobStatus + progress: float + current_phase: Optional[JobPhase] + error: Optional[str] + created_at: str + updated_at: str + completed_at: Optional[str] + + +class ModelRef: + id: str + version: str + def __init__(self, id: str, version: str) -> None: ... + + +class QueryResultDocumentInfo: + id: str + text: str + metadata: Optional[Dict[str, str]] + score: float + def __init__( + self, + id: str, + text: str, + metadata: Optional[Dict[str, str]] = ..., + score: float = ..., + ) -> None: ... + + +class DocumentInfo: + id: str + text: str + metadata: Optional[Dict[str, str]] + embedding: Optional[Sequence[float]] + def __init__( + self, + id: str, + text: str, + metadata: Optional[Dict[str, str]] = ..., + embedding: Optional[Sequence[float]] = ..., + ) -> None: ... + + +class QueryOptions: + embedding: Optional[Sequence[float]] + top_k: Optional[int] + alpha: Optional[float] + filter: Optional[dict] + rerank: Optional[RerankOptions] + def __init__( + self, + embedding: Optional[Sequence[float]] = ..., + top_k: Optional[int] = ..., + alpha: Optional[float] = ..., + filter: Optional[dict] = ..., + rerank: Optional[RerankOptions] = ..., + ) -> None: ... 
+ + +class IndexInfo: + id: str + name: str + version: str + status: str + doc_count: int + created_at: str + updated_at: str + model: ModelRef + def __init__( + self, + id: str, + name: str, + version: str, + status: str, + doc_count: int, + created_at: str, + updated_at: str, + model: ModelRef, + ) -> None: ... + + +class SearchResult: + docs: List[QueryResultDocumentInfo] + query: str + index_name: Optional[str] + time_taken_ms: Optional[int] + def __init__( + self, + docs: List[QueryResultDocumentInfo], + query: str, + index_name: Optional[str] = None, + time_taken_ms: Optional[int] = None, + ) -> None: ... + + +class IndexStatus: + NotStarted: ClassVar[str] + Building: ClassVar[str] + Ready: ClassVar[str] + Failed: ClassVar[str] + def __init__(self, value: str) -> None: ... + + +IndexStatusValues: Dict[str, str] + +__version__: str + +__all__ = [ + "MossClient", + "RerankOptions", + "DocumentInfo", + "GetDocumentsOptions", + "IndexInfo", + "SearchResult", + "QueryResultDocumentInfo", + "ModelRef", + "IndexStatus", + "IndexStatusValues", + "QueryOptions", + "MutationResult", + "MutationOptions", + "JobStatus", + "JobPhase", + "JobProgress", + "JobStatusResponse", +] diff --git a/sdks/python/sdk/src/moss/client/models.py b/sdks/python/sdk/src/moss/client/models.py new file mode 100644 index 00000000..a785a5fc --- /dev/null +++ b/sdks/python/sdk/src/moss/client/models.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + + +@dataclass(init=False) +class RerankOptions: + """Configuration for reranking, passed inside QueryOptions. + + Example: + QueryOptions( + top_k=20, + rerank=RerankOptions(provider="cohere", api_key="...", top_n=5), + ) + + The provider name maps to a reranker class via the registry in + moss.rerankers. Any additional kwargs are passed to the reranker + constructor (e.g. api_key, model). 
+ """ + + provider: str + top_n: Optional[int] + init_kwargs: Dict[str, Any] + + def __init__( + self, + provider: str, + top_n: Optional[int] = None, + **kwargs: Any, + ) -> None: + self.provider = provider + self.top_n = top_n + self.init_kwargs = kwargs + self._instance: Optional[Any] = None + + +@dataclass +class QueryOptions: + """Query options for semantic search, including optional reranking. + + Example: + QueryOptions( + top_k=20, + alpha=0.8, + rerank=RerankOptions(provider="cohere", api_key="...", top_n=5), + ) + """ + + embedding: Optional[List[float]] = None + top_k: Optional[int] = None + alpha: Optional[float] = None + filter: Optional[dict] = None + rerank: Optional[RerankOptions] = None diff --git a/sdks/python/sdk/src/moss/client/moss_client.py b/sdks/python/sdk/src/moss/client/moss_client.py index 51bd2833..b9553d0c 100644 --- a/sdks/python/sdk/src/moss/client/moss_client.py +++ b/sdks/python/sdk/src/moss/client/moss_client.py @@ -17,11 +17,13 @@ MutationOptions, MutationResult, JobStatusResponse, - QueryOptions, QueryResultDocumentInfo, SearchResult, ) +from ..rerankers import get_reranker +from .models import QueryOptions, RerankOptions + logger = logging.getLogger(__name__) @@ -189,24 +191,35 @@ async def query( Otherwise, falls back to the cloud query API. Args: - options: Query options (top_k, alpha, embedding, filter). Example filter: - QueryOptions(filter={"$and": [ - {"field": "city", "condition": {"$eq": "NYC"}}, - {"field": "price", "condition": {"$lt": "50"}}, - ]}) + options: Query options (top_k, alpha, embedding, filter, rerank). + Reranking is applied client-side after retrieval and works on + both the local and cloud paths. 
Example filter: + QueryOptions(filter={"$and": [ + {"field": "city", "condition": {"$eq": "NYC"}}, + {"field": "price", "condition": {"$lt": "50"}}, + ]}) """ is_loaded = await asyncio.to_thread(self._manager.has_index, name) if is_loaded: - return await self._query_local(name, query, options) + result = await self._query_local(name, query, options) + else: + if getattr(options, "filter", None) is not None: + logger.warning( + "Metadata filter ignored: filtering is only supported for locally loaded indexes. " + "Call load_index('%s') first.", + name, + ) + result = await self._query_cloud(name, query, options) - if getattr(options, "filter", None) is not None: - logger.warning( - "Metadata filter ignored: filtering is only supported for locally loaded indexes. " - "Call load_index('%s') first.", - name, - ) - return await self._query_cloud(name, query, options) + rerank = getattr(options, "rerank", None) + if rerank: + top_k = getattr(options, "top_k", None) + if top_k is None: + top_k = 5 + result = await self._apply_rerank(query, result, rerank, top_k) + + return result # -- Internal --------------------------------------------------- @@ -224,33 +237,59 @@ async def _query_local( alpha = 0.8 query_embedding = getattr(options, "embedding", None) filter = getattr(options, "filter", None) + rerank = getattr(options, "rerank", None) - if query_embedding is None: - try: - return await asyncio.to_thread( - self._manager.query_text, - name, - query, - top_k, - alpha, - filter, - ) - except RuntimeError as e: - if "requires explicit query embeddings" in str(e): - raise ValueError( - "This index uses custom embeddings. " - "Query embeddings must be provided via QueryOptions.embedding." 
- ) from e - raise + fetch_k = top_k * 4 if rerank else top_k - return await asyncio.to_thread( - self._manager.query, - name, - query, - list(query_embedding), - top_k, - alpha, - filter, + if query_embedding is not None: + return await asyncio.to_thread( + self._manager.query, + name, + query, + list(query_embedding), + fetch_k, + alpha, + filter, + ) + + try: + return await asyncio.to_thread( + self._manager.query_text, + name, + query, + fetch_k, + alpha, + filter, + ) + except RuntimeError as e: + if "requires explicit query embeddings" in str(e): + raise ValueError( + "This index uses custom embeddings. " + "Query embeddings must be provided via QueryOptions.embedding." + ) from e + raise + + @staticmethod + async def _apply_rerank( + query: str, + result: SearchResult, + rerank_opts: RerankOptions, + default_top_k: Optional[int], + ) -> SearchResult: + """Rerank search results. Works on both local and cloud paths.""" + if rerank_opts._instance is None: + rerank_opts._instance = get_reranker( + rerank_opts.provider, **rerank_opts.init_kwargs + ) + final_n = rerank_opts.top_n or default_top_k + reranked_docs = await rerank_opts._instance.rerank( + query, result.docs, top_k=final_n + ) + return SearchResult( + docs=reranked_docs, + query=result.query, + index_name=result.index_name, + time_taken_ms=result.time_taken_ms, ) async def _query_cloud( @@ -260,7 +299,11 @@ async def _query_cloud( options: Optional[QueryOptions], ) -> SearchResult: """Fallback: query via the cloud API when the index is not loaded locally.""" - top_k = getattr(options, "top_k", None) or 10 + top_k = getattr(options, "top_k", None) + if top_k is None: + top_k = 5 + rerank = getattr(options, "rerank", None) + fetch_k = top_k * 4 if rerank else top_k query_embedding = getattr(options, "embedding", None) request_body: Dict[str, Any] = { @@ -268,7 +311,7 @@ async def _query_cloud( "indexName": name, "projectId": self._project_id, "projectKey": self._project_key, - "topK": top_k, + "topK": 
fetch_k, } if query_embedding is not None: request_body["queryEmbedding"] = list(query_embedding) diff --git a/sdks/python/sdk/src/moss/rerankers/__init__.py b/sdks/python/sdk/src/moss/rerankers/__init__.py new file mode 100644 index 00000000..e432a728 --- /dev/null +++ b/sdks/python/sdk/src/moss/rerankers/__init__.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import Any, Dict, Type + +from .base import Reranker + +_REGISTRY: Dict[str, Type[Reranker]] = {} + +_MISSING_PROVIDERS: Dict[str, str] = {} + + +def register_reranker(name: str, cls: Type[Reranker]) -> None: + """Register a reranker class under a provider name. + + Example: + class MyReranker: + async def rerank(self, query, documents, top_k=None, **kwargs): + ... + + register_reranker("my-reranker", MyReranker) + """ + _REGISTRY[name] = cls + + +def get_reranker(name: str, **kwargs: Any) -> Reranker: + """Instantiate a reranker by provider name. + + Raises: + ImportError: If the provider is built-in but its optional dependency + isn't installed (e.g. "cohere" without `pip install cohere`). + ValueError: If the provider name is not registered and not a known + built-in. + """ + if name in _MISSING_PROVIDERS: + package = _MISSING_PROVIDERS[name] + raise ImportError( + f"The '{name}' reranker requires the '{package}' package. " + f"Install it with: pip install {package}" + ) + if name not in _REGISTRY: + available = list(_REGISTRY) or ["(none registered)"] + raise ValueError( + f"Unknown reranker provider: '{name}'. " + f"Available: {available}. " + f"Register custom rerankers with register_reranker(name, cls)." + ) + return _REGISTRY[name](**kwargs) + + +__all__ = ["Reranker", "register_reranker", "get_reranker"] + +try: + from .cohere import CohereReranker + + register_reranker("cohere", CohereReranker) +except ImportError: + # `pip install cohere` to enable the Cohere reranker. 
+ _MISSING_PROVIDERS["cohere"] = "cohere" diff --git a/sdks/python/sdk/src/moss/rerankers/base.py b/sdks/python/sdk/src/moss/rerankers/base.py new file mode 100644 index 00000000..4f40f55a --- /dev/null +++ b/sdks/python/sdk/src/moss/rerankers/base.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import Any, List, Optional, Protocol, runtime_checkable + +from moss_core import QueryResultDocumentInfo + + +@runtime_checkable +class Reranker(Protocol): + """Protocol for reranking search results using a cross-encoder or API. + + Implement this protocol to create custom rerankers. The SDK ships with + built-in implementations (e.g., CohereReranker). + + Example: + class MyReranker: + async def rerank(self, query, documents, top_k=None): + # Your reranking logic here + return sorted(documents, key=lambda d: d.score, reverse=True) + """ + + async def rerank( + self, + query: str, + documents: List[QueryResultDocumentInfo], + top_k: Optional[int] = None, + **kwargs: Any, + ) -> List[QueryResultDocumentInfo]: + """Rerank documents by relevance to the query. + + Args: + query: The original search query. + documents: Retrieved documents with initial scores. + top_k: If set, return only the top-k reranked results. + **kwargs: Additional provider-specific options. + + Returns: + Documents reordered by relevance, with updated scores. + """ + ... diff --git a/sdks/python/sdk/src/moss/rerankers/cohere.py b/sdks/python/sdk/src/moss/rerankers/cohere.py new file mode 100644 index 00000000..12aa316e --- /dev/null +++ b/sdks/python/sdk/src/moss/rerankers/cohere.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import os +from typing import Any, List, Optional + +import cohere +from moss_core import QueryResultDocumentInfo + +from .base import Reranker + + +class CohereReranker(Reranker): + """Reranker using the Cohere Python SDK. 
+ + Requires the `cohere` package: pip install cohere + + Example: + results = await client.query("index", "query", + rerank=RerankOptions(provider="cohere", api_key="your-cohere-key", top_n=5) + ) + """ + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "rerank-v3.5", + **kwargs: Any, + ): + """Initialize the Cohere reranker. + + Args: + api_key: Cohere API key. Falls back to COHERE_API_KEY env var. + model: Cohere rerank model name. + **kwargs: Reserved for future provider-specific options. + """ + self.api_key = api_key or os.getenv("COHERE_API_KEY") + if not self.api_key: + raise ValueError( + "Cohere API key is required. Pass api_key or set COHERE_API_KEY env var." + ) + self.model = model + self.extra_options = kwargs + self._client = cohere.AsyncClientV2(api_key=self.api_key) + + async def rerank( + self, + query: str, + documents: List[QueryResultDocumentInfo], + top_k: Optional[int] = None, + **kwargs: Any, + ) -> List[QueryResultDocumentInfo]: + """Rerank documents using the Cohere Rerank API. + + Args: + query: The search query. + documents: Documents to rerank. + top_k: Number of top results to return. Defaults to all documents. + + Returns: + Documents reordered by Cohere relevance score. 
+ """ + if not documents: + return [] + + doc_texts = [doc.text for doc in documents] + resolved_top_k = top_k or len(documents) + + response = await self._client.rerank( + model=self.model, + query=query, + documents=doc_texts, + top_n=resolved_top_k, + ) + + reranked = [] + for result in response.results: + original_doc = documents[result.index] + reranked.append( + QueryResultDocumentInfo( + id=original_doc.id, + text=original_doc.text, + metadata=original_doc.metadata, + score=result.relevance_score, + ) + ) + + return reranked diff --git a/sdks/python/sdk/tests/test_client.py b/sdks/python/sdk/tests/test_client.py index 86c1f2fc..4f099a50 100644 --- a/sdks/python/sdk/tests/test_client.py +++ b/sdks/python/sdk/tests/test_client.py @@ -354,6 +354,7 @@ async def test_query_with_custom_embedding(self, client): opts.alpha = 0.9 opts.embedding = [0.1, 0.2, 0.3] opts.filter = None + opts.rerank = None result = await client.query("idx", "search text", opts) @@ -377,6 +378,7 @@ async def test_query_defaults_top_k_and_alpha(self, client): opts.alpha = None opts.embedding = [0.5] opts.filter = None + opts.rerank = None await client.query("idx", "q", opts) @@ -394,6 +396,8 @@ async def test_query_raises_for_custom_model_without_embedding(self, client): opts.embedding = None opts.top_k = 5 opts.alpha = 0.8 + opts.filter = None + opts.rerank = None with pytest.raises(ValueError, match="custom embeddings"): await client.query("idx", "q", opts) @@ -418,6 +422,7 @@ async def test_query_passes_filter_to_manager(self, client): opts.top_k = 5 opts.alpha = 0.8 opts.embedding = [0.1] + opts.rerank = None metadata_filter = {"field": "city", "condition": {"$eq": "NYC"}} opts.filter = metadata_filter @@ -443,6 +448,7 @@ async def test_query_passes_none_filter_when_omitted(self, client): opts.alpha = 0.8 opts.embedding = [0.1] opts.filter = None + opts.rerank = None await client.query("idx", "q", opts) @@ -464,6 +470,7 @@ async def test_query_passes_complex_and_filter(self, client): 
opts.top_k = 10 opts.alpha = 0.8 opts.embedding = [0.5] + opts.rerank = None metadata_filter = { "$and": [ @@ -561,6 +568,8 @@ async def test_uses_local_when_index_loaded(self, client): opts.top_k = 5 opts.alpha = 0.8 opts.embedding = [0.1] + opts.filter = None + opts.rerank = None result = await client.query("idx", "q", opts) diff --git a/sdks/python/sdk/tests/test_client_extended.py b/sdks/python/sdk/tests/test_client_extended.py index 2b53f916..286435e5 100644 --- a/sdks/python/sdk/tests/test_client_extended.py +++ b/sdks/python/sdk/tests/test_client_extended.py @@ -100,6 +100,7 @@ async def test_cloud_fallback_with_custom_embedding(self, unloaded_client): opts.top_k = 5 opts.embedding = [0.1, 0.2, 0.3] opts.filter = None + opts.rerank = None result = await unloaded_client.query("idx", "test query", opts) @@ -263,6 +264,7 @@ async def test_query_with_only_top_k(self, client): opts.alpha = None opts.embedding = None opts.filter = None + opts.rerank = None await client.query("idx", "test", opts) @@ -278,6 +280,7 @@ async def test_query_with_only_alpha(self, client): opts.alpha = 0.5 opts.embedding = None opts.filter = None + opts.rerank = None await client.query("idx", "test", opts) @@ -293,6 +296,7 @@ async def test_query_alpha_zero_keyword_only(self, client): opts.alpha = 0 opts.embedding = None opts.filter = None + opts.rerank = None await client.query("idx", "test", opts) @@ -308,6 +312,7 @@ async def test_query_alpha_one_semantic_only(self, client): opts.alpha = 1 opts.embedding = None opts.filter = None + opts.rerank = None await client.query("idx", "test", opts) @@ -333,6 +338,7 @@ async def test_filter_warning_logged_when_unloaded(self, unloaded_client, caplog opts = MagicMock() opts.filter = {"field": "city", "condition": {"$eq": "NYC"}} + opts.rerank = None with caplog.at_level("WARNING"): await unloaded_client.query("idx", "test", opts) diff --git a/sdks/python/sdk/tests/test_rerankers.py b/sdks/python/sdk/tests/test_rerankers.py new file mode 100644 
"""Unit tests for the reranker registry, ``RerankOptions`` and ``CohereReranker``.

No test here talks to the Cohere service: wherever ``rerank`` is exercised,
the reranker's ``_client`` attribute is replaced with an ``AsyncMock``.
"""

import unittest
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

from moss_core import QueryResultDocumentInfo
from moss import RerankOptions
from moss.rerankers import (
    CohereReranker,
    _REGISTRY,
    get_reranker,
    register_reranker,
)
from moss.rerankers.base import Reranker as RerankerProtocol


class _PassthroughReranker:
    """Minimal reranker stub: records constructor kwargs, returns documents as-is."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    async def rerank(self, query, documents, top_k=None, **kwargs):
        return documents


def _fake_cohere_response(*ranked):
    """Build a stand-in for the object returned by the Cohere client's ``rerank``.

    Args:
        *ranked: ``(index, relevance_score)`` pairs, in the order the fake
            service "returns" them.

    Returns:
        An object exposing ``results``, a list of items that each expose
        ``index`` and ``relevance_score``.
    """
    return SimpleNamespace(
        results=[
            SimpleNamespace(index=index, relevance_score=score)
            for index, score in ranked
        ]
    )


def _reranker_with_mock_client(api_key="test-key", **rerank_mock_kwargs):
    """Return a ``CohereReranker`` whose ``_client`` is an ``AsyncMock``.

    Args:
        api_key: Key handed to the ``CohereReranker`` constructor.
        **rerank_mock_kwargs: Forwarded to the ``AsyncMock`` installed as
            ``_client.rerank`` (e.g. ``return_value=...`` or ``side_effect=...``).
    """
    reranker = CohereReranker(api_key=api_key)
    reranker._client = AsyncMock()
    reranker._client.rerank = AsyncMock(**rerank_mock_kwargs)
    return reranker


class TestRerankerRegistry(unittest.TestCase):
    """Tests for the reranker registry."""

    def test_cohere_registered_by_default(self):
        """The "cohere" provider maps to ``CohereReranker`` without any setup."""
        self.assertIn("cohere", _REGISTRY)
        self.assertIs(_REGISTRY["cohere"], CohereReranker)

    def test_register_custom_reranker(self):
        """``register_reranker`` adds the provider name to the registry."""
        # patch.dict snapshots the registry and restores it on exit, so the
        # global state is left exactly as found even if an assertion fails.
        with patch.dict(_REGISTRY):
            register_reranker("my-custom", _PassthroughReranker)
            self.assertIn("my-custom", _REGISTRY)

    def test_get_reranker_instantiates_with_kwargs(self):
        """``get_reranker`` builds the registered class with the given kwargs."""
        with patch.dict(_REGISTRY):
            register_reranker("dummy", _PassthroughReranker)
            instance = get_reranker("dummy", api_key="abc", model="test-model")

            self.assertIsInstance(instance, _PassthroughReranker)
            self.assertEqual(
                instance.kwargs, {"api_key": "abc", "model": "test-model"}
            )

    def test_get_unknown_reranker_raises(self):
        """Asking for an unregistered provider raises ``ValueError``."""
        with self.assertRaisesRegex(ValueError, "Unknown reranker provider"):
            get_reranker("nonexistent")


class TestRerankOptions(unittest.TestCase):
    """Tests for RerankOptions."""

    def test_stores_provider_and_kwargs(self):
        """``provider`` and ``top_n`` are attributes; other kwargs go to ``init_kwargs``."""
        opts = RerankOptions(provider="cohere", api_key="key", top_n=5)
        self.assertEqual(opts.provider, "cohere")
        self.assertEqual(opts.top_n, 5)
        self.assertEqual(opts.init_kwargs, {"api_key": "key"})

    def test_default_top_n(self):
        """``top_n`` is ``None`` when not supplied."""
        opts = RerankOptions(provider="cohere", api_key="key")
        self.assertIsNone(opts.top_n)

    def test_instance_cache_starts_empty(self):
        """A fresh ``RerankOptions`` has ``_instance`` set to ``None``."""
        opts = RerankOptions(provider="cohere", api_key="key")
        self.assertIsNone(opts._instance)

    def test_multiple_kwargs_forwarded(self):
        """Every non-``top_n`` keyword ends up in ``init_kwargs``."""
        opts = RerankOptions(
            provider="cohere", api_key="k", model="rerank-v3.5", top_n=3
        )
        self.assertEqual(opts.init_kwargs, {"api_key": "k", "model": "rerank-v3.5"})


class TestCohereRerankerProtocol(unittest.TestCase):
    """CohereReranker satisfies the Reranker protocol."""

    def test_protocol_satisfied(self):
        """A ``CohereReranker`` built from the environment passes the isinstance check."""
        with patch.dict("os.environ", {"COHERE_API_KEY": "test-key"}):
            reranker = CohereReranker()
        self.assertIsInstance(reranker, RerankerProtocol)

    def test_custom_reranker_satisfies_protocol(self):
        """Any object with a matching async ``rerank`` method passes the check."""
        self.assertIsInstance(_PassthroughReranker(), RerankerProtocol)


class TestCohereReranker(unittest.IsolatedAsyncioTestCase):
    """Tests for CohereReranker instantiation and behavior."""

    def test_init_with_api_key(self):
        """An explicit key is stored and the default model is "rerank-v3.5"."""
        reranker = CohereReranker(api_key="test-key")
        self.assertEqual(reranker.api_key, "test-key")
        self.assertEqual(reranker.model, "rerank-v3.5")

    def test_init_from_env(self):
        """Without an explicit key, ``COHERE_API_KEY`` from the environment is used."""
        with patch.dict("os.environ", {"COHERE_API_KEY": "env-key"}):
            reranker = CohereReranker()
        self.assertEqual(reranker.api_key, "env-key")

    def test_init_no_key_raises(self):
        """No explicit key and no environment variable raises ``ValueError``."""
        # clear=True empties os.environ for the duration of the block, so a
        # key exported in the developer's shell cannot mask the failure.
        with patch.dict("os.environ", {}, clear=True):
            with self.assertRaisesRegex(ValueError, "Cohere API key is required"):
                CohereReranker()

    def test_init_custom_model(self):
        """The ``model`` keyword overrides the default model name."""
        reranker = CohereReranker(api_key="key", model="rerank-english-v2.0")
        self.assertEqual(reranker.model, "rerank-english-v2.0")

    def test_init_accepts_kwargs(self):
        """Unrecognised keywords are kept in ``extra_options``."""
        reranker = CohereReranker(api_key="key", custom_option="value")
        self.assertEqual(reranker.extra_options, {"custom_option": "value"})

    async def test_rerank_empty_documents(self):
        """An empty document list yields ``[]`` without reaching the client."""
        reranker = _reranker_with_mock_client()

        result = await reranker.rerank("query", [])

        self.assertEqual(result, [])
        # Guard against an accidental network round-trip for empty input.
        reranker._client.rerank.assert_not_awaited()

    async def test_rerank_calls_cohere_sdk(self):
        """Results are reordered and rescored according to the client response."""
        # The fake service ranks the document at index 1 above the one at index 0.
        reranker = _reranker_with_mock_client(
            return_value=_fake_cohere_response((1, 0.95), (0, 0.72))
        )
        docs = [
            QueryResultDocumentInfo(id="d1", text="first doc", score=0.8),
            QueryResultDocumentInfo(id="d2", text="second doc", score=0.6),
        ]

        result = await reranker.rerank("test query", docs)

        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].id, "d2")
        self.assertAlmostEqual(result[0].score, 0.95, places=2)
        self.assertEqual(result[1].id, "d1")
        self.assertAlmostEqual(result[1].score, 0.72, places=2)

        reranker._client.rerank.assert_awaited_once_with(
            model="rerank-v3.5",
            query="test query",
            documents=["first doc", "second doc"],
            top_n=2,
        )

    async def test_rerank_with_top_k(self):
        """``top_k`` is passed to the client as ``top_n``."""
        reranker = _reranker_with_mock_client(
            return_value=_fake_cohere_response((0, 0.9))
        )
        docs = [
            QueryResultDocumentInfo(id="d1", text="doc1", score=0.5),
            QueryResultDocumentInfo(id="d2", text="doc2", score=0.4),
            QueryResultDocumentInfo(id="d3", text="doc3", score=0.3),
        ]

        await reranker.rerank("query", docs, top_k=1)

        call_kwargs = reranker._client.rerank.call_args.kwargs
        self.assertEqual(call_kwargs["top_n"], 1)

    async def test_rerank_sdk_error(self):
        """An exception raised by the client surfaces to the caller."""
        reranker = _reranker_with_mock_client(
            api_key="bad-key", side_effect=Exception("Unauthorized")
        )
        docs = [QueryResultDocumentInfo(id="d1", text="doc", score=0.5)]

        with self.assertRaisesRegex(Exception, "Unauthorized"):
            await reranker.rerank("query", docs)


if __name__ == "__main__":
    unittest.main()