diff --git a/docs/api_reference/document_search/retrieval/rephrasers.md b/docs/api_reference/document_search/retrieval/rephrasers.md index 2a6acb6d2..bd9cf8a0d 100644 --- a/docs/api_reference/document_search/retrieval/rephrasers.md +++ b/docs/api_reference/document_search/retrieval/rephrasers.md @@ -1,6 +1,11 @@ # Query Rephrasers -::: ragbits.document_search.retrieval.rephrasers.QueryRephraser -::: ragbits.document_search.retrieval.rephrasers.LLMQueryRephraser -::: ragbits.document_search.retrieval.rephrasers.MultiQueryRephraser -::: ragbits.document_search.retrieval.rephrasers.NoopQueryRephraser \ No newline at end of file +::: ragbits.document_search.retrieval.rephrasers.base.QueryRephraserOptions + +::: ragbits.document_search.retrieval.rephrasers.llm.LLMQueryRephraserOptions + +::: ragbits.document_search.retrieval.rephrasers.base.QueryRephraser + +::: ragbits.document_search.retrieval.rephrasers.llm.LLMQueryRephraser + +::: ragbits.document_search.retrieval.rephrasers.noop.NoopQueryRephraser diff --git a/docs/how-to/document_search/search-documents.md b/docs/how-to/document_search/search-documents.md index 4690ba612..94c7bc2d6 100644 --- a/docs/how-to/document_search/search-documents.md +++ b/docs/how-to/document_search/search-documents.md @@ -94,10 +94,10 @@ By default, the input query is provided directly to the embedding model. However === "Multi query" ```python - from ragbits.document_search.retrieval.rephrasers import MultiQueryRephraser + from ragbits.document_search.retrieval.rephrasers import LLMQueryRephraser, LLMQueryRephraserOptions from ragbits.document_search import DocumentSearch - query_rephraser = MultiQueryRephraser(LiteLLM(model_name="gpt-3.5-turbo"), n=3) + query_rephraser = LLMQueryRephraser(LiteLLM(model_name="gpt-3.5-turbo"), default_options=LLMQueryRephraserOptions(n=3)) document_search = DocumentSearch(query_rephraser=query_rephraser, ...) elements = await document_search.search("What is the capital of Poland?") @@ -108,25 +108,28 @@ By default, the input query is provided directly to the embedding model. However To define a new rephraser, extend the the [`QueryRephraser`][ragbits.document_search.retrieval.rephrasers.base.QueryRephraser] class. ```python -from ragbits.document_search.retrieval.rephrasers import QueryRephraser +from ragbits.document_search.retrieval.rephrasers import QueryRephraser, QueryRephraserOptions -class CustomRephraser(QueryRephraser): +class CustomRephraser(QueryRephraser[QueryRephraserOptions]): """ Rephraser that uses a LLM to rephrase queries. """ - async def rephrase(self, query: str) -> list[str]: + options_cls: type[QueryRephraserOptions] = QueryRephraserOptions + + async def rephrase(self, query: str, options: QueryRephraserOptions | None = None) -> Iterable[str]: """ Rephrase a query using the LLM. Args: query: The query to be rephrased. + options: The options for rephrasing. Returns: List containing the rephrased query. """ - responses = await llm.generate(QueryRephraserPrompt(...)) + responses = await llm.generate(CustomRephraserPrompt(...)) ... return [...] ``` @@ -175,6 +178,8 @@ class CustomReranker(Reranker[RerankerOptions]): Reranker that uses a LLM to rerank elements. """ + options_cls: type[RerankerOptions] = RerankerOptions + async def rerank( self, elements: Sequence[Sequence[Element]], diff --git a/examples/document-search/configurable.py b/examples/document-search/configurable.py index 7335c5920..f686127e0 100644 --- a/examples/document-search/configurable.py +++ b/examples/document-search/configurable.py @@ -30,7 +30,7 @@ class to rephrase the query. import asyncio -from ragbits.core.audit import set_trace_handlers +from ragbits.core.audit.traces import set_trace_handlers from ragbits.document_search import DocumentSearch from ragbits.document_search.documents.document import DocumentMeta @@ -86,13 +86,12 @@ class to rephrase the query. "model": "cohere/rerank-english-v3.0", "default_options": { "top_n": 3, - "max_chunks_per_doc": None, }, }, }, "parser_router": {"txt": {"type": "TextDocumentParser"}}, "rephraser": { - "type": "LLMQueryRephraser", + "type": "ragbits.document_search.retrieval.rephrasers:LLMQueryRephraser", "config": { "llm": { "type": "ragbits.core.llms.litellm:LiteLLM", @@ -101,7 +100,13 @@ class to rephrase the query. }, }, "prompt": { - "type": "QueryRephraserPrompt", + "type": "ragbits.document_search.retrieval.rephrasers:LLMQueryRephraserPrompt", + }, + "default_options": { + "n": 2, + "llm_options": { + "temperature": 0.0, + }, }, }, }, diff --git a/packages/ragbits-document-search/CHANGELOG.md b/packages/ragbits-document-search/CHANGELOG.md index d0d80dbac..9e8ed8ddf 100644 --- a/packages/ragbits-document-search/CHANGELOG.md +++ b/packages/ragbits-document-search/CHANGELOG.md @@ -2,6 +2,7 @@ ## Unreleased +- Add query rephraser options (#560) - Rename DocumentMeta create_text_document_from_literal to from_literal (#561) - Update audit imports (#427) - BREAKING CHANGE: Adjust document search configurable interface (#554) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index 91a21ac8f..bc853f4b7 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -25,21 +25,23 @@ from ragbits.document_search.ingestion.parsers.router import DocumentParserRouter from ragbits.document_search.ingestion.strategies import IngestStrategy, SequentialIngestStrategy from ragbits.document_search.ingestion.strategies.base import IngestExecutionError, IngestExecutionResult -from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptionsT from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser from ragbits.document_search.retrieval.rerankers.base import Reranker, RerankerOptionsT from ragbits.document_search.retrieval.rerankers.noop import NoopReranker -class DocumentSearchOptions(Options, Generic[VectorStoreOptionsT, RerankerOptionsT]): +class DocumentSearchOptions(Options, Generic[QueryRephraserOptionsT, VectorStoreOptionsT, RerankerOptionsT]): """ Object representing the options for the document search. Attributes: + query_rephraser_options: The options for the query rephraser. vector_store_options: The options for the vector store. reranker_options: The options for the reranker. """ + query_rephraser_options: QueryRephraserOptionsT | None | NotGiven = NOT_GIVEN vector_store_options: VectorStoreOptionsT | None | NotGiven = NOT_GIVEN reranker_options: RerankerOptionsT | None | NotGiven = NOT_GIVEN @@ -57,7 +59,9 @@ class DocumentSearchConfig(BaseModel): enricher_router: dict[str, ObjectConstructionConfig] = {} -class DocumentSearch(ConfigurableComponent[DocumentSearchOptions[VectorStoreOptionsT, RerankerOptionsT]]): +class DocumentSearch( + ConfigurableComponent[DocumentSearchOptions[QueryRephraserOptionsT, VectorStoreOptionsT, RerankerOptionsT]] +): """ Main entrypoint to the document search functionality. It provides methods for document retrieval and ingestion. @@ -80,9 +84,14 @@ def __init__( self, vector_store: VectorStore[VectorStoreOptionsT], *, - query_rephraser: QueryRephraser | None = None, + query_rephraser: QueryRephraser[QueryRephraserOptionsT] | None = None, reranker: Reranker[RerankerOptionsT] | None = None, - default_options: DocumentSearchOptions[VectorStoreOptionsT, RerankerOptionsT] | None = None, + default_options: DocumentSearchOptions[ + QueryRephraserOptionsT, + VectorStoreOptionsT, + RerankerOptionsT, + ] + | None = None, ingest_strategy: IngestStrategy | None = None, parser_router: DocumentParserRouter | None = None, enricher_router: ElementEnricherRouter | None = None, @@ -124,9 +133,9 @@ def from_config(cls, config: dict) -> Self: """ model = DocumentSearchConfig.model_validate(config) - query_rephraser = QueryRephraser.subclass_from_config(model.rephraser) - reranker: Reranker = Reranker.subclass_from_config(model.reranker) + query_rephraser: QueryRephraser = QueryRephraser.subclass_from_config(model.rephraser) vector_store: VectorStore = VectorStore.subclass_from_config(model.vector_store) + reranker: Reranker = Reranker.subclass_from_config(model.reranker) ingest_strategy = IngestStrategy.subclass_from_config(model.ingest_strategy) parser_router = DocumentParserRouter.from_config(model.parser_router) @@ -192,7 +201,7 @@ def preferred_subclass( async def search( self, query: str, - options: DocumentSearchOptions[VectorStoreOptionsT, RerankerOptionsT] | None = None, + options: DocumentSearchOptions[QueryRephraserOptionsT, VectorStoreOptionsT, RerankerOptionsT] | None = None, ) -> Sequence[Element]: """ Search for the most relevant chunks for a query. @@ -205,11 +214,12 @@ async def search( A list of chunks. """ merged_options = (self.default_options | options) if options else self.default_options + query_rephraser_options = merged_options.query_rephraser_options or None vector_store_options = merged_options.vector_store_options or None reranker_options = merged_options.reranker_options or None with trace(query=query, options=merged_options) as outputs: - queries = await self.query_rephraser.rephrase(query) + queries = await self.query_rephraser.rephrase(query, query_rephraser_options) elements = [ [ Element.from_vector_db_entry(result.entry, result.score) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py index 46527e591..c9532b488 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py @@ -1,21 +1,18 @@ -from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser -from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser -from ragbits.document_search.retrieval.rephrasers.multi import MultiQueryRephraser -from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser -from ragbits.document_search.retrieval.rephrasers.prompts import ( - MultiQueryRephraserInput, - MultiQueryRephraserPrompt, - QueryRephraserInput, - QueryRephraserPrompt, +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions +from ragbits.document_search.retrieval.rephrasers.llm import ( + LLMQueryRephraser, + LLMQueryRephraserOptions, + LLMQueryRephraserPrompt, + LLMQueryRephraserPromptInput, ) +from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser __all__ = [ "LLMQueryRephraser", - "MultiQueryRephraser", - "MultiQueryRephraserInput", - "MultiQueryRephraserPrompt", + "LLMQueryRephraserOptions", + "LLMQueryRephraserPrompt", + "LLMQueryRephraserPromptInput", "NoopQueryRephraser", "QueryRephraser", - "QueryRephraserInput", - "QueryRephraserPrompt", + "QueryRephraserOptions", ] diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py index 0b68b54ab..95b98f40e 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py @@ -1,25 +1,38 @@ from abc import ABC, abstractmethod -from typing import ClassVar +from collections.abc import Iterable +from typing import ClassVar, TypeVar -from ragbits.core.utils.config_handling import WithConstructionConfig +from ragbits.core.options import Options +from ragbits.core.utils.config_handling import ConfigurableComponent from ragbits.document_search.retrieval import rephrasers -class QueryRephraser(WithConstructionConfig, ABC): +class QueryRephraserOptions(Options): + """ + Object representing the options for the rephraser. + """ + + +QueryRephraserOptionsT = TypeVar("QueryRephraserOptionsT", bound=QueryRephraserOptions) + + +class QueryRephraser(ConfigurableComponent[QueryRephraserOptionsT], ABC): """ Rephrases a query. Can provide multiple rephrased queries from one sentence / question. """ + options_cls: type[QueryRephraserOptionsT] default_module: ClassVar = rephrasers configuration_key: ClassVar = "rephraser" @abstractmethod - async def rephrase(self, query: str) -> list[str]: + async def rephrase(self, query: str, options: QueryRephraserOptionsT | None = None) -> Iterable[str]: """ Rephrase a query. Args: query: The query to rephrase. + options: The options for the rephraser. Returns: The rephrased queries. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index f806e4223..f60dc1c84 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -1,40 +1,108 @@ -from typing import Any +from collections.abc import Iterable +from typing import Generic + +from pydantic import BaseModel +from typing_extensions import Self from ragbits.core.audit.traces import traceable -from ragbits.core.llms.base import LLM +from ragbits.core.llms.base import LLM, LLMClientOptionsT from ragbits.core.prompt import Prompt -from ragbits.core.utils.config_handling import ObjectConstructionConfig -from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser -from ragbits.document_search.retrieval.rephrasers.prompts import ( - QueryRephraserInput, - QueryRephraserPrompt, - get_rephraser_prompt, -) +from ragbits.core.types import NOT_GIVEN, NotGiven +from ragbits.core.utils.config_handling import ObjectConstructionConfig, import_by_path +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions + + +class LLMQueryRephraserPromptInput(BaseModel): + """ + Input data for the query rephraser prompt. + """ + + query: str + n: int | None = None + + +class LLMQueryRephraserPrompt(Prompt[LLMQueryRephraserPromptInput, list]): + """ + Prompt for generating a rephrased user query. + """ + + system_prompt = """ + You are an expert in query rephrasing and clarity improvement. + {%- if n and n > 1 %} + Your task is to generate {{ n }} different versions of the given user query to retrieve relevant documents + from a vector database. They can be phrased as statements, as they will be used as a search query. + By generating multiple perspectives on the user query, your goal is to help the user overcome some of the + limitations of the distance-based similarity search. + Alternative queries should only contain information present in the original query. Do not include anything + in the alternative query, you have not seen in the original version. + It is VERY important you DO NOT ADD any comments or notes. Return ONLY alternative queries. + Provide these alternative queries separated by newlines. DO NOT ADD any enumeration. + {%- else %} + Your task is to return a single paraphrased version of a user's query, + correcting any typos, handling abbreviations and improving clarity. + Focus on making the query more precise and readable while keeping its original intent. + Just return the rephrased query. No additional explanations are needed. + {%- endif %} + """ + user_prompt = "Query: {{ query }}" + + @staticmethod + def _response_parser(value: str) -> list[str]: + return [stripped_line for line in value.strip().split("\n") if (stripped_line := line.strip())] + + response_parser = _response_parser -class LLMQueryRephraser(QueryRephraser): +class LLMQueryRephraserOptions(QueryRephraserOptions, Generic[LLMClientOptionsT]): + """ + Object representing the options for the LLM query rephraser. + + Attributes: + n: The number of rephrasings to generate. Any number below 2 will generate only one rephrasing. + llm_options: The options for the LLM. + """ + + n: int | None | NotGiven = NOT_GIVEN + llm_options: LLMClientOptionsT | None | NotGiven = NOT_GIVEN + + +class LLMQueryRephraser(QueryRephraser[LLMQueryRephraserOptions[LLMClientOptionsT]]): """ A rephraser class that uses a LLM to rephrase queries. """ - def __init__(self, llm: LLM, prompt: type[Prompt[QueryRephraserInput, Any]] | None = None): + options_cls: type[LLMQueryRephraserOptions] = LLMQueryRephraserOptions + + def __init__( + self, + llm: LLM[LLMClientOptionsT], + prompt: type[Prompt[LLMQueryRephraserPromptInput, list[str]]] | None = None, + default_options: LLMQueryRephraserOptions[LLMClientOptionsT] | None = None, + ) -> None: """ Initialize the LLMQueryRephraser with a LLM. Args: llm: A LLM instance to handle query rephrasing. prompt: The prompt to use for rephrasing queries. + default_options: The default options for the rephraser. """ + super().__init__(default_options=default_options) self._llm = llm - self._prompt = prompt or QueryRephraserPrompt + self._prompt = prompt or LLMQueryRephraserPrompt @traceable - async def rephrase(self, query: str) -> list[str]: + async def rephrase( + self, + query: str, + options: LLMQueryRephraserOptions[LLMClientOptionsT] | None = None, + ) -> Iterable[str]: """ Rephrase a given query using the LLM. Args: query: The query to be rephrased. If not provided, a custom prompt must be given. + options: The options for the rephraser. Returns: A list containing the rephrased query. @@ -44,13 +112,13 @@ async def rephrase(self, query: str) -> list[str]: LLMStatusError: If the LLM API returns an error status code. LLMResponseError: If the LLM API response is invalid. """ - input_data = self._prompt.input_type(query=query) # type: ignore - prompt = self._prompt(input_data) - response = await self._llm.generate(prompt) - return response if isinstance(response, list) else [response] + merged_options = (self.default_options | options) if options else self.default_options + llm_options = merged_options.llm_options or None + prompt = self._prompt(LLMQueryRephraserPromptInput(query=query, n=merged_options.n or None)) + return await self._llm.generate(prompt, options=llm_options) @classmethod - def from_config(cls, config: dict) -> "LLMQueryRephraser": + def from_config(cls, config: dict) -> Self: """ Create an instance of `LLMQueryRephraser` from a configuration dictionary. @@ -63,12 +131,11 @@ def from_config(cls, config: dict) -> "LLMQueryRephraser": Raises: ValidationError: If the LLM or prompt configuration doesn't follow the expected format. InvalidConfigError: If an LLM or prompt class can't be found or is not the correct type. - ValueError: If the prompt class is not a subclass of `Prompt`. - """ - llm: LLM = LLM.subclass_from_config(ObjectConstructionConfig.model_validate(config["llm"])) - prompt_cls = None - if "prompt" in config: - prompt_config = ObjectConstructionConfig.model_validate(config["prompt"]) - prompt_cls = get_rephraser_prompt(prompt_config.type) - return cls(llm=llm, prompt=prompt_cls) + config["llm"] = LLM.subclass_from_config(ObjectConstructionConfig.model_validate(config["llm"])) + config["prompt"] = ( + import_by_path(ObjectConstructionConfig.model_validate(config["prompt"]).type) + if "prompt" in config + else None + ) + return super().from_config(config) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py deleted file mode 100644 index 71f0cda7c..000000000 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Any - -from ragbits.core.audit.traces import traceable -from ragbits.core.llms.base import LLM -from ragbits.core.prompt import Prompt -from ragbits.core.utils.config_handling import ObjectConstructionConfig -from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser -from ragbits.document_search.retrieval.rephrasers.prompts import ( - MultiQueryRephraserInput, - MultiQueryRephraserPrompt, - get_rephraser_prompt, -) - - -class MultiQueryRephraser(QueryRephraser): - """ - A rephraser class that uses a LLM to generate reworded versions of input query. - """ - - def __init__( - self, llm: LLM, n: int | None = None, prompt: type[Prompt[MultiQueryRephraserInput, Any]] | None = None - ): - """ - Initialize the MultiQueryRephraser with a LLM. - - Args: - llm: A LLM instance to handle query rephrasing. - n: The number of rephrasings to generate. - prompt: The prompt to use for rephrasing queries. - """ - self._llm = llm - self._n = n if n else 5 - self._prompt = prompt or MultiQueryRephraserPrompt - - @traceable - async def rephrase(self, query: str) -> list[str]: - """ - Rephrase a given query using the LLM. - - Args: - query: The query to be rephrased. If not provided, a custom prompt must be given. - - Returns: - A list containing the reworded versions of input query. - - Raises: - LLMConnectionError: If there is a connection error with the LLM API. - LLMStatusError: If the LLM API returns an error status code. - LLMResponseError: If the LLM API response is invalid. - """ - input_data = self._prompt.input_type(query=query, n=self._n) # type: ignore - prompt = self._prompt(input_data) - response = await self._llm.generate(prompt) - return [query] + response - - @classmethod - def from_config(cls, config: dict) -> "MultiQueryRephraser": - """ - Create an instance of `MultiQueryRephraser` from a configuration dictionary. - - Args: - config: A dictionary containing configuration settings for the rephraser. - - Returns: - An instance of the rephraser class initialized with the provided configuration. - - Raises: - ValidationError: If the LLM or prompt configuration doesn't follow the expected format. - InvalidConfigError: If an LLM or prompt class can't be found or is not the correct type. - ValueError: If the prompt class is not a subclass of `Prompt`. - - """ - llm: LLM = LLM.subclass_from_config(ObjectConstructionConfig.model_validate(config["llm"])) - prompt_cls = None - if "prompt" in config: - prompt_config = ObjectConstructionConfig.model_validate(config["prompt"]) - prompt_cls = get_rephraser_prompt(prompt_config.type) - n = config.get("n", 5) - return cls(llm=llm, n=n, prompt=prompt_cls) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py index c8cccfa34..576b9beaf 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py @@ -1,19 +1,24 @@ +from collections.abc import Iterable + from ragbits.core.audit.traces import traceable -from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions -class NoopQueryRephraser(QueryRephraser): +class NoopQueryRephraser(QueryRephraser[QueryRephraserOptions]): """ A no-op query paraphraser that does not change the query. """ + options_cls: type[QueryRephraserOptions] = QueryRephraserOptions + @traceable - async def rephrase(self, query: str) -> list[str]: # noqa: PLR6301 + async def rephrase(self, query: str, options: QueryRephraserOptions | None = None) -> Iterable[str]: # noqa: PLR6301 """ Mock implementation which outputs the same query as in input. Args: query: The query to rephrase. + options: The options for the rephraser. Returns: The list with non-transformed query. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py deleted file mode 100644 index ba8332c1a..000000000 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py +++ /dev/null @@ -1,89 +0,0 @@ -import sys -from typing import Any - -from pydantic import BaseModel - -from ragbits.core.prompt.prompt import Prompt -from ragbits.core.utils.config_handling import import_by_path - -module = sys.modules[__name__] - - -class QueryRephraserInput(BaseModel): - """ - Input data for the query rephraser prompt. - """ - - query: str - - -class QueryRephraserPrompt(Prompt[QueryRephraserInput, str]): - """ - A prompt class for generating a rephrased version of a user's query using a LLM. - """ - - user_prompt = "{{ query }}" - system_prompt = ( - "You are an expert in query rephrasing and clarity improvement. " - "Your task is to return a single paraphrased version of a user's query, " - "correcting any typos, handling abbreviations and improving clarity. " - "Focus on making the query more precise and readable while keeping its original intent.\n\n" - "Just return the rephrased query. No additional explanations are needed." - ) - - -class MultiQueryRephraserInput(BaseModel): - """ - Represents the input data for the multi query rephraser prompt. - """ - - query: str - n: int - - -class MultiQueryRephraserPrompt(Prompt[MultiQueryRephraserInput, list]): - """ - A prompt template for generating multiple query rephrasings. - """ - - user_prompt = "{{ query }}" - system_prompt = ( - "You are a helpful assistant that creates short rephrased versions of a given query. " - "Your task is to generate {{ n }} different versions of the given user query to retrieve relevant documents" - " from a vector database. They can be phrased as statements, as they will be used as a search query. " - "By generating multiple perspectives on the user query, " - "your goal is to help the user overcome some of the limitations of the distance-based similarity search." - "Alternative queries should only contain information present in the original query. Do not include anything" - " in the alternative query, you have not seen in the original version.\n\n" - "It is VERY important you DO NOT ADD any comments or notes. Return ONLY alternative queries. " - "Provide these alternative queries separated by newlines. " - "DO NOT ADD any enumeration." - ) - - @staticmethod - def _list_parser(value: str) -> list[str]: - return value.split("\n") - - response_parser = _list_parser - - -def get_rephraser_prompt(prompt: str) -> type[Prompt[Any, Any]]: - """ - Initializes and returns a QueryRephraser object based on the provided configuration. - - Args: - prompt: The prompt class to use for rephrasing queries. - - Returns: - An instance of the specified QueryRephraser class, initialized with the provided config - (if any) or default arguments. - - Raises: - ValueError: If the prompt class is not a subclass of `Prompt`. - """ - prompt_cls = import_by_path(prompt, module) - - if not issubclass(prompt_cls, Prompt): - raise ValueError(f"Invalid rephraser prompt class: {prompt_cls}") - - return prompt_cls diff --git a/packages/ragbits-document-search/tests/unit/test_rephrasers.py b/packages/ragbits-document-search/tests/unit/test_rephrasers.py index 8b31464dd..de11d02d7 100644 --- a/packages/ragbits-document-search/tests/unit/test_rephrasers.py +++ b/packages/ragbits-document-search/tests/unit/test_rephrasers.py @@ -1,23 +1,21 @@ from ragbits.core.llms.litellm import LiteLLM from ragbits.core.utils.config_handling import ObjectConstructionConfig from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser -from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser -from ragbits.document_search.retrieval.rephrasers.multi import MultiQueryRephraser +from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser, LLMQueryRephraserPrompt from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser -from ragbits.document_search.retrieval.rephrasers.prompts import MultiQueryRephraserPrompt, QueryRephraserPrompt def test_subclass_from_config(): config = ObjectConstructionConfig.model_validate( {"type": "ragbits.document_search.retrieval.rephrasers:NoopQueryRephraser"} ) - rephraser = QueryRephraser.subclass_from_config(config) + rephraser: QueryRephraser = QueryRephraser.subclass_from_config(config) assert isinstance(rephraser, NoopQueryRephraser) def test_subclass_from_config_default_path(): config = ObjectConstructionConfig.model_validate({"type": "NoopQueryRephraser"}) - rephraser = QueryRephraser.subclass_from_config(config) + rephraser: QueryRephraser = QueryRephraser.subclass_from_config(config) assert isinstance(rephraser, NoopQueryRephraser) @@ -33,7 +31,7 @@ def test_subclass_from_config_llm(): }, } ) - rephraser = QueryRephraser.subclass_from_config(config) + rephraser: QueryRephraser = QueryRephraser.subclass_from_config(config) assert isinstance(rephraser, LLMQueryRephraser) assert isinstance(rephraser._llm, LiteLLM) assert rephraser._llm.model_name == "some_model" @@ -48,50 +46,15 @@ def test_subclass_from_config_llm_prompt(): "type": "ragbits.core.llms.litellm:LiteLLM", "config": {"model_name": "some_model"}, }, - "prompt": {"type": "QueryRephraserPrompt"}, - }, - } - ) - rephraser = QueryRephraser.subclass_from_config(config) - assert isinstance(rephraser, LLMQueryRephraser) - assert isinstance(rephraser._llm, LiteLLM) - assert issubclass(rephraser._prompt, QueryRephraserPrompt) - - -def test_subclass_from_config_multi(): - config = ObjectConstructionConfig.model_validate( - { - "type": "ragbits.document_search.retrieval.rephrasers.multi:MultiQueryRephraser", - "config": { - "llm": { - "type": "ragbits.core.llms.litellm:LiteLLM", - "config": {"model_name": "some_model"}, - }, - }, - } - ) - rephraser = QueryRephraser.subclass_from_config(config) - assert isinstance(rephraser, MultiQueryRephraser) - assert isinstance(rephraser._llm, LiteLLM) - assert rephraser._llm.model_name == "some_model" - - -def test_subclass_from_config_multiquery_llm_prompt(): - config = ObjectConstructionConfig.model_validate( - { - "type": "ragbits.document_search.retrieval.rephrasers.multi:MultiQueryRephraser", - "config": { - "llm": { - "type": "ragbits.core.llms.litellm:LiteLLM", - "config": {"model_name": "some_model"}, + "prompt": {"type": "ragbits.document_search.retrieval.rephrasers.llm:LLMQueryRephraserPrompt"}, + "default_options": { + "n": 4, }, - "n": 4, - "prompt": {"type": "MultiQueryRephraserPrompt"}, }, } ) - rephraser = QueryRephraser.subclass_from_config(config) - assert isinstance(rephraser, MultiQueryRephraser) + rephraser: QueryRephraser = QueryRephraser.subclass_from_config(config) + assert isinstance(rephraser, LLMQueryRephraser) assert isinstance(rephraser._llm, LiteLLM) - assert rephraser._n == 4 - assert issubclass(rephraser._prompt, MultiQueryRephraserPrompt) + assert issubclass(rephraser._prompt, LLMQueryRephraserPrompt) + assert rephraser.default_options.n == 4 diff --git a/packages/ragbits-document-search/tests/unit/testprojects/project_with_instance_factory/factories.py b/packages/ragbits-document-search/tests/unit/testprojects/project_with_instance_factory/factories.py index 900d7ea8e..2913c0096 100644 --- a/packages/ragbits-document-search/tests/unit/testprojects/project_with_instance_factory/factories.py +++ b/packages/ragbits-document-search/tests/unit/testprojects/project_with_instance_factory/factories.py @@ -10,7 +10,7 @@ def create_document_search_instance_223(): vector_store_options = VectorStoreOptions(k=223) - document_search = DocumentSearch( + document_search: DocumentSearch = DocumentSearch( reranker=NoopReranker(default_options=RerankerOptions(top_n=223)), vector_store=InMemoryVectorStore(embedder=NoopEmbedder(), default_options=vector_store_options), ) @@ -19,7 +19,7 @@ def create_document_search_instance_223(): def create_document_search_instance_825(): vector_store_options = VectorStoreOptions(k=825) - document_search = DocumentSearch( + document_search: DocumentSearch = DocumentSearch( reranker=NoopReranker(default_options=RerankerOptions(top_n=825)), vector_store=InMemoryVectorStore(embedder=NoopEmbedder(), default_options=vector_store_options), )