From b2a7cd924fb5eee07466c90140bf6ab7c4c50422 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= <47692610+micpst@users.noreply.github.com> Date: Thu, 15 May 2025 03:37:24 +0200 Subject: [PATCH 01/10] add query rephraser options --- .../src/ragbits/document_search/_main.py | 28 ++++++---- .../retrieval/rephrasers/base.py | 21 ++++++-- .../retrieval/rephrasers/llm.py | 15 ++++-- .../retrieval/rephrasers/multi.py | 52 +++++++++++++------ .../retrieval/rephrasers/noop.py | 11 ++-- .../tests/unit/test_rephrasers.py | 16 +++--- .../factories.py | 4 +- 7 files changed, 100 insertions(+), 47 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index 91a21ac8f..bc853f4b7 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -25,21 +25,23 @@ from ragbits.document_search.ingestion.parsers.router import DocumentParserRouter from ragbits.document_search.ingestion.strategies import IngestStrategy, SequentialIngestStrategy from ragbits.document_search.ingestion.strategies.base import IngestExecutionError, IngestExecutionResult -from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptionsT from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser from ragbits.document_search.retrieval.rerankers.base import Reranker, RerankerOptionsT from ragbits.document_search.retrieval.rerankers.noop import NoopReranker -class DocumentSearchOptions(Options, Generic[VectorStoreOptionsT, RerankerOptionsT]): +class DocumentSearchOptions(Options, Generic[QueryRephraserOptionsT, VectorStoreOptionsT, RerankerOptionsT]): """ Object representing the options for the document search. Attributes: + query_rephraser_options: The options for the query rephraser. vector_store_options: The options for the vector store. reranker_options: The options for the reranker. """ + query_rephraser_options: QueryRephraserOptionsT | None | NotGiven = NOT_GIVEN vector_store_options: VectorStoreOptionsT | None | NotGiven = NOT_GIVEN reranker_options: RerankerOptionsT | None | NotGiven = NOT_GIVEN @@ -57,7 +59,9 @@ class DocumentSearchConfig(BaseModel): enricher_router: dict[str, ObjectConstructionConfig] = {} -class DocumentSearch(ConfigurableComponent[DocumentSearchOptions[VectorStoreOptionsT, RerankerOptionsT]]): +class DocumentSearch( + ConfigurableComponent[DocumentSearchOptions[QueryRephraserOptionsT, VectorStoreOptionsT, RerankerOptionsT]] +): """ Main entrypoint to the document search functionality. It provides methods for document retrieval and ingestion. @@ -80,9 +84,14 @@ def __init__( self, vector_store: VectorStore[VectorStoreOptionsT], *, - query_rephraser: QueryRephraser | None = None, + query_rephraser: QueryRephraser[QueryRephraserOptionsT] | None = None, reranker: Reranker[RerankerOptionsT] | None = None, - default_options: DocumentSearchOptions[VectorStoreOptionsT, RerankerOptionsT] | None = None, + default_options: DocumentSearchOptions[ + QueryRephraserOptionsT, + VectorStoreOptionsT, + RerankerOptionsT, + ] + | None = None, ingest_strategy: IngestStrategy | None = None, parser_router: DocumentParserRouter | None = None, enricher_router: ElementEnricherRouter | None = None, @@ -124,9 +133,9 @@ def from_config(cls, config: dict) -> Self: """ model = DocumentSearchConfig.model_validate(config) - query_rephraser = QueryRephraser.subclass_from_config(model.rephraser) - reranker: Reranker = Reranker.subclass_from_config(model.reranker) + query_rephraser: QueryRephraser = QueryRephraser.subclass_from_config(model.rephraser) vector_store: VectorStore = VectorStore.subclass_from_config(model.vector_store) + reranker: Reranker = Reranker.subclass_from_config(model.reranker) ingest_strategy = IngestStrategy.subclass_from_config(model.ingest_strategy) parser_router = DocumentParserRouter.from_config(model.parser_router) @@ -192,7 +201,7 @@ def preferred_subclass( async def search( self, query: str, - options: DocumentSearchOptions[VectorStoreOptionsT, RerankerOptionsT] | None = None, + options: DocumentSearchOptions[QueryRephraserOptionsT, VectorStoreOptionsT, RerankerOptionsT] | None = None, ) -> Sequence[Element]: """ Search for the most relevant chunks for a query. @@ -205,11 +214,12 @@ async def search( A list of chunks. """ merged_options = (self.default_options | options) if options else self.default_options + query_rephraser_options = merged_options.query_rephraser_options or None vector_store_options = merged_options.vector_store_options or None reranker_options = merged_options.reranker_options or None with trace(query=query, options=merged_options) as outputs: - queries = await self.query_rephraser.rephrase(query) + queries = await self.query_rephraser.rephrase(query, query_rephraser_options) elements = [ [ Element.from_vector_db_entry(result.entry, result.score) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py index 0b68b54ab..95b98f40e 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py @@ -1,25 +1,38 @@ from abc import ABC, abstractmethod -from typing import ClassVar +from collections.abc import Iterable +from typing import ClassVar, TypeVar -from ragbits.core.utils.config_handling import WithConstructionConfig +from ragbits.core.options import Options +from ragbits.core.utils.config_handling import ConfigurableComponent from ragbits.document_search.retrieval import rephrasers -class QueryRephraser(WithConstructionConfig, ABC): +class QueryRephraserOptions(Options): + """ + Object representing the options for the rephraser. + """ + + +QueryRephraserOptionsT = TypeVar("QueryRephraserOptionsT", bound=QueryRephraserOptions) + + +class QueryRephraser(ConfigurableComponent[QueryRephraserOptionsT], ABC): """ Rephrases a query. Can provide multiple rephrased queries from one sentence / question. """ + options_cls: type[QueryRephraserOptionsT] default_module: ClassVar = rephrasers configuration_key: ClassVar = "rephraser" @abstractmethod - async def rephrase(self, query: str) -> list[str]: + async def rephrase(self, query: str, options: QueryRephraserOptionsT | None = None) -> Iterable[str]: """ Rephrase a query. Args: query: The query to rephrase. + options: The options for the rephraser. Returns: The rephrased queries. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index f806e4223..a5c5eaec2 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -1,10 +1,13 @@ +from collections.abc import Iterable from typing import Any +from typing_extensions import Self + from ragbits.core.audit.traces import traceable from ragbits.core.llms.base import LLM from ragbits.core.prompt import Prompt from ragbits.core.utils.config_handling import ObjectConstructionConfig -from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions from ragbits.document_search.retrieval.rephrasers.prompts import ( QueryRephraserInput, QueryRephraserPrompt, @@ -12,11 +15,13 @@ ) -class LLMQueryRephraser(QueryRephraser): +class LLMQueryRephraser(QueryRephraser[QueryRephraserOptions]): """ A rephraser class that uses a LLM to rephrase queries. """ + options_cls = QueryRephraserOptions + def __init__(self, llm: LLM, prompt: type[Prompt[QueryRephraserInput, Any]] | None = None): """ Initialize the LLMQueryRephraser with a LLM. @@ -29,12 +34,13 @@ def __init__(self, llm: LLM, prompt: type[Prompt[QueryRephraserInput, Any]] | No self._prompt = prompt or QueryRephraserPrompt @traceable - async def rephrase(self, query: str) -> list[str]: + async def rephrase(self, query: str, options: QueryRephraserOptions | None = None) -> Iterable[str]: """ Rephrase a given query using the LLM. Args: query: The query to be rephrased. If not provided, a custom prompt must be given. + options: The options for the rephraser. Returns: A list containing the rephrased query. @@ -50,7 +56,7 @@ async def rephrase(self, query: str) -> list[str]: return response if isinstance(response, list) else [response] @classmethod - def from_config(cls, config: dict) -> "LLMQueryRephraser": + def from_config(cls, config: dict) -> Self: """ Create an instance of `LLMQueryRephraser` from a configuration dictionary. @@ -64,7 +70,6 @@ def from_config(cls, config: dict) -> "LLMQueryRephraser": ValidationError: If the LLM or prompt configuration doesn't follow the expected format. InvalidConfigError: If an LLM or prompt class can't be found or is not the correct type. ValueError: If the prompt class is not a subclass of `Prompt`. - """ llm: LLM = LLM.subclass_from_config(ObjectConstructionConfig.model_validate(config["llm"])) prompt_cls = None diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py index 71f0cda7c..ddf0bd131 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py @@ -1,10 +1,13 @@ +from collections.abc import Iterable from typing import Any +from typing_extensions import Self + from ragbits.core.audit.traces import traceable from ragbits.core.llms.base import LLM from ragbits.core.prompt import Prompt from ragbits.core.utils.config_handling import ObjectConstructionConfig -from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions from ragbits.document_search.retrieval.rephrasers.prompts import ( MultiQueryRephraserInput, MultiQueryRephraserPrompt, @@ -12,33 +15,50 @@ ) -class MultiQueryRephraser(QueryRephraser): +class MultiQueryRephraserOptions(QueryRephraserOptions): + """ + Object representing the options for the multi query rephraser. + + Attributes: + n: The number of rephrasings to generate. + """ + + n: int = 5 + + +class MultiQueryRephraser(QueryRephraser[MultiQueryRephraserOptions]): """ A rephraser class that uses a LLM to generate reworded versions of input query. """ + options_cls = MultiQueryRephraserOptions + def __init__( - self, llm: LLM, n: int | None = None, prompt: type[Prompt[MultiQueryRephraserInput, Any]] | None = None + self, + llm: LLM, + prompt: type[Prompt[MultiQueryRephraserInput, Any]] | None = None, + default_options: MultiQueryRephraserOptions | None = None, ): """ Initialize the MultiQueryRephraser with a LLM. Args: llm: A LLM instance to handle query rephrasing. - n: The number of rephrasings to generate. prompt: The prompt to use for rephrasing queries. + default_options: The default options for the rephraser. """ + super().__init__(default_options=default_options) self._llm = llm - self._n = n if n else 5 self._prompt = prompt or MultiQueryRephraserPrompt @traceable - async def rephrase(self, query: str) -> list[str]: + async def rephrase(self, query: str, options: QueryRephraserOptions | None = None) -> Iterable[str]: """ Rephrase a given query using the LLM. Args: query: The query to be rephrased. If not provided, a custom prompt must be given. + options: The options for the rephraser. Returns: A list containing the reworded versions of input query. @@ -48,13 +68,14 @@ async def rephrase(self, query: str) -> list[str]: LLMStatusError: If the LLM API returns an error status code. LLMResponseError: If the LLM API response is invalid. """ - input_data = self._prompt.input_type(query=query, n=self._n) # type: ignore + merged_options = (self.default_options | options) if options else self.default_options + input_data = self._prompt.input_type(query=query, n=merged_options.n) # type: ignore prompt = self._prompt(input_data) response = await self._llm.generate(prompt) return [query] + response @classmethod - def from_config(cls, config: dict) -> "MultiQueryRephraser": + def from_config(cls, config: dict) -> Self: """ Create an instance of `MultiQueryRephraser` from a configuration dictionary. @@ -68,12 +89,11 @@ def from_config(cls, config: dict) -> "MultiQueryRephraser": ValidationError: If the LLM or prompt configuration doesn't follow the expected format. InvalidConfigError: If an LLM or prompt class can't be found or is not the correct type. ValueError: If the prompt class is not a subclass of `Prompt`. - """ - llm: LLM = LLM.subclass_from_config(ObjectConstructionConfig.model_validate(config["llm"])) - prompt_cls = None - if "prompt" in config: - prompt_config = ObjectConstructionConfig.model_validate(config["prompt"]) - prompt_cls = get_rephraser_prompt(prompt_config.type) - n = config.get("n", 5) - return cls(llm=llm, n=n, prompt=prompt_cls) + config["llm"] = LLM.subclass_from_config(ObjectConstructionConfig.model_validate(config["llm"])) + config["prompt"] = ( + get_rephraser_prompt(ObjectConstructionConfig.model_validate(config["prompt"]).type) + if "prompt" in config + else None + ) + return super().from_config(config) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py index c8cccfa34..62782e13d 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py @@ -1,19 +1,24 @@ +from collections.abc import Iterable + from ragbits.core.audit.traces import traceable -from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions -class NoopQueryRephraser(QueryRephraser): +class NoopQueryRephraser(QueryRephraser[QueryRephraserOptions]): """ A no-op query paraphraser that does not change the query. """ + options_cls = QueryRephraserOptions + @traceable - async def rephrase(self, query: str) -> list[str]: # noqa: PLR6301 + async def rephrase(self, query: str, options: QueryRephraserOptions | None = None) -> Iterable[str]: # noqa: PLR6301 """ Mock implementation which outputs the same query as in input. Args: query: The query to rephrase. + options: The options for the rephraser. Returns: The list with non-transformed query. diff --git a/packages/ragbits-document-search/tests/unit/test_rephrasers.py b/packages/ragbits-document-search/tests/unit/test_rephrasers.py index 8b31464dd..e27406613 100644 --- a/packages/ragbits-document-search/tests/unit/test_rephrasers.py +++ b/packages/ragbits-document-search/tests/unit/test_rephrasers.py @@ -11,13 +11,13 @@ def test_subclass_from_config(): config = ObjectConstructionConfig.model_validate( {"type": "ragbits.document_search.retrieval.rephrasers:NoopQueryRephraser"} ) - rephraser = QueryRephraser.subclass_from_config(config) + rephraser: QueryRephraser = QueryRephraser.subclass_from_config(config) assert isinstance(rephraser, NoopQueryRephraser) def test_subclass_from_config_default_path(): config = ObjectConstructionConfig.model_validate({"type": "NoopQueryRephraser"}) - rephraser = QueryRephraser.subclass_from_config(config) + rephraser: QueryRephraser = QueryRephraser.subclass_from_config(config) assert isinstance(rephraser, NoopQueryRephraser) @@ -33,7 +33,7 @@ def test_subclass_from_config_llm(): }, } ) - rephraser = QueryRephraser.subclass_from_config(config) + rephraser: QueryRephraser = QueryRephraser.subclass_from_config(config) assert isinstance(rephraser, LLMQueryRephraser) assert isinstance(rephraser._llm, LiteLLM) assert rephraser._llm.model_name == "some_model" @@ -52,7 +52,7 @@ def test_subclass_from_config_llm_prompt(): }, } ) - rephraser = QueryRephraser.subclass_from_config(config) + rephraser: QueryRephraser = QueryRephraser.subclass_from_config(config) assert isinstance(rephraser, LLMQueryRephraser) assert isinstance(rephraser._llm, LiteLLM) assert issubclass(rephraser._prompt, QueryRephraserPrompt) @@ -70,7 +70,7 @@ def test_subclass_from_config_multi(): }, } ) - rephraser = QueryRephraser.subclass_from_config(config) + rephraser: QueryRephraser = QueryRephraser.subclass_from_config(config) assert isinstance(rephraser, MultiQueryRephraser) assert isinstance(rephraser._llm, LiteLLM) assert rephraser._llm.model_name == "some_model" @@ -85,13 +85,13 @@ def test_subclass_from_config_multiquery_llm_prompt(): "type": "ragbits.core.llms.litellm:LiteLLM", "config": {"model_name": "some_model"}, }, - "n": 4, "prompt": {"type": "MultiQueryRephraserPrompt"}, + "default_options": {"n": 4}, }, } ) - rephraser = QueryRephraser.subclass_from_config(config) + rephraser: QueryRephraser = QueryRephraser.subclass_from_config(config) assert isinstance(rephraser, MultiQueryRephraser) assert isinstance(rephraser._llm, LiteLLM) - assert rephraser._n == 4 + assert rephraser.default_options.n == 4 assert issubclass(rephraser._prompt, MultiQueryRephraserPrompt) diff --git a/packages/ragbits-document-search/tests/unit/testprojects/project_with_instance_factory/factories.py b/packages/ragbits-document-search/tests/unit/testprojects/project_with_instance_factory/factories.py index 5ad1c8c47..df54cfde9 100644 --- a/packages/ragbits-document-search/tests/unit/testprojects/project_with_instance_factory/factories.py +++ b/packages/ragbits-document-search/tests/unit/testprojects/project_with_instance_factory/factories.py @@ -10,7 +10,7 @@ def create_document_search_instance_223(): vector_store_options = VectorStoreOptions(k=223) - document_search = DocumentSearch( + document_search: DocumentSearch = DocumentSearch( reranker=NoopReranker(default_options=RerankerOptions(top_n=223)), vector_store=InMemoryVectorStore(embedder=NoopEmbedder(), default_options=vector_store_options), ) @@ -19,7 +19,7 @@ def create_document_search_instance_223(): def create_document_search_instance_825(): vector_store_options = VectorStoreOptions(k=825) - document_search = DocumentSearch( + document_search: DocumentSearch = DocumentSearch( reranker=NoopReranker(default_options=RerankerOptions(top_n=825)), vector_store=InMemoryVectorStore(embedder=NoopEmbedder(), default_options=vector_store_options), ) From 7e614a96ada32cb88ffda09e8aaa10dc2717288f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= <47692610+micpst@users.noreply.github.com> Date: Thu, 15 May 2025 10:42:53 +0200 Subject: [PATCH 02/10] improve typing --- .../retrieval/rephrasers/llm.py | 21 +++++++++---------- .../retrieval/rephrasers/multi.py | 8 +++---- .../retrieval/rephrasers/prompts.py | 2 +- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index a5c5eaec2..863329f60 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -1,5 +1,4 @@ from collections.abc import Iterable -from typing import Any from typing_extensions import Self @@ -22,7 +21,7 @@ class LLMQueryRephraser(QueryRephraser[QueryRephraserOptions]): options_cls = QueryRephraserOptions - def __init__(self, llm: LLM, prompt: type[Prompt[QueryRephraserInput, Any]] | None = None): + def __init__(self, llm: LLM, prompt: type[Prompt[QueryRephraserInput, str]] | None = None) -> None: """ Initialize the LLMQueryRephraser with a LLM. @@ -50,10 +49,9 @@ async def rephrase(self, query: str, options: QueryRephraserOptions | None = Non LLMStatusError: If the LLM API returns an error status code. LLMResponseError: If the LLM API response is invalid. """ - input_data = self._prompt.input_type(query=query) # type: ignore - prompt = self._prompt(input_data) + prompt = self._prompt(QueryRephraserInput(query=query)) response = await self._llm.generate(prompt) - return response if isinstance(response, list) else [response] + return [response] @classmethod def from_config(cls, config: dict) -> Self: @@ -71,9 +69,10 @@ def from_config(cls, config: dict) -> Self: InvalidConfigError: If an LLM or prompt class can't be found or is not the correct type. ValueError: If the prompt class is not a subclass of `Prompt`. """ - llm: LLM = LLM.subclass_from_config(ObjectConstructionConfig.model_validate(config["llm"])) - prompt_cls = None - if "prompt" in config: - prompt_config = ObjectConstructionConfig.model_validate(config["prompt"]) - prompt_cls = get_rephraser_prompt(prompt_config.type) - return cls(llm=llm, prompt=prompt_cls) + config["llm"] = LLM.subclass_from_config(ObjectConstructionConfig.model_validate(config["llm"])) + config["prompt"] = ( + get_rephraser_prompt(ObjectConstructionConfig.model_validate(config["prompt"]).type) + if "prompt" in config + else None + ) + return super().from_config(config) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py index ddf0bd131..d5c96d5a3 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py @@ -1,5 +1,4 @@ from collections.abc import Iterable -from typing import Any from typing_extensions import Self @@ -36,9 +35,9 @@ class MultiQueryRephraser(QueryRephraser[MultiQueryRephraserOptions]): def __init__( self, llm: LLM, - prompt: type[Prompt[MultiQueryRephraserInput, Any]] | None = None, + prompt: type[Prompt[MultiQueryRephraserInput, list[str]]] | None = None, default_options: MultiQueryRephraserOptions | None = None, - ): + ) -> None: """ Initialize the MultiQueryRephraser with a LLM. @@ -69,8 +68,7 @@ async def rephrase(self, query: str, options: QueryRephraserOptions | None = Non LLMResponseError: If the LLM API response is invalid. """ merged_options = (self.default_options | options) if options else self.default_options - input_data = self._prompt.input_type(query=query, n=merged_options.n) # type: ignore - prompt = self._prompt(input_data) + prompt = self._prompt(MultiQueryRephraserInput(query=query, n=merged_options.n)) response = await self._llm.generate(prompt) return [query] + response diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py index ba8332c1a..d95319dc1 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py @@ -41,7 +41,7 @@ class MultiQueryRephraserInput(BaseModel): n: int -class MultiQueryRephraserPrompt(Prompt[MultiQueryRephraserInput, list]): +class MultiQueryRephraserPrompt(Prompt[MultiQueryRephraserInput, list[str]]): """ A prompt template for generating multiple query rephrasings. """ From 06816a09ef21ec2981b3a94b107cb24c8654666f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= <47692610+micpst@users.noreply.github.com> Date: Thu, 15 May 2025 10:44:05 +0200 Subject: [PATCH 03/10] update changelog --- packages/ragbits-document-search/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/ragbits-document-search/CHANGELOG.md b/packages/ragbits-document-search/CHANGELOG.md index 193040c0d..ba56718e6 100644 --- a/packages/ragbits-document-search/CHANGELOG.md +++ b/packages/ragbits-document-search/CHANGELOG.md @@ -2,6 +2,7 @@ ## Unreleased +- Add query rephraser options (#560) - Update audit imports (#427) - BREAKING CHANGE: Adjust document search configurable interface (#554) - BREAKING CHANGE: Rename SearchConfig to DocumentSearchOptions (#554) From c2e3162e1499d7402ba4ad22ecccd11349c740ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= <47692610+micpst@users.noreply.github.com> Date: Thu, 15 May 2025 11:42:05 +0200 Subject: [PATCH 04/10] add llm options --- .../retrieval/rephrasers/llm.py | 36 +++++++++++++++---- .../retrieval/rephrasers/multi.py | 23 +++++++----- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index 863329f60..785e9c1cf 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -1,10 +1,12 @@ from collections.abc import Iterable +from typing import Generic from typing_extensions import Self from ragbits.core.audit.traces import traceable -from ragbits.core.llms.base import LLM +from ragbits.core.llms.base import LLM, LLMClientOptionsT from ragbits.core.prompt import Prompt +from ragbits.core.types import NOT_GIVEN, NotGiven from ragbits.core.utils.config_handling import ObjectConstructionConfig from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions from ragbits.document_search.retrieval.rephrasers.prompts import ( @@ -14,26 +16,46 @@ ) -class LLMQueryRephraser(QueryRephraser[QueryRephraserOptions]): +class LLMQueryRephraserOptions(QueryRephraserOptions, Generic[LLMClientOptionsT]): + """ + Object representing the options for the LLM query rephraser. + + Attributes: + llm_options: The options for the LLM. + """ + + llm_options: LLMClientOptionsT | None | NotGiven = NOT_GIVEN + + +class LLMQueryRephraser(QueryRephraser[LLMQueryRephraserOptions[LLMClientOptionsT]]): """ A rephraser class that uses a LLM to rephrase queries. """ - options_cls = QueryRephraserOptions + options_cls: type[LLMQueryRephraserOptions] = LLMQueryRephraserOptions - def __init__(self, llm: LLM, prompt: type[Prompt[QueryRephraserInput, str]] | None = None) -> None: + def __init__( + self, + llm: LLM[LLMClientOptionsT], + prompt: type[Prompt[QueryRephraserInput, str]] | None = None, + default_options: LLMQueryRephraserOptions[LLMClientOptionsT] | None = None, + ) -> None: """ Initialize the LLMQueryRephraser with a LLM. Args: llm: A LLM instance to handle query rephrasing. prompt: The prompt to use for rephrasing queries. + default_options: The default options for the rephraser. """ + super().__init__(default_options=default_options) self._llm = llm self._prompt = prompt or QueryRephraserPrompt @traceable - async def rephrase(self, query: str, options: QueryRephraserOptions | None = None) -> Iterable[str]: + async def rephrase( + self, query: str, options: LLMQueryRephraserOptions[LLMClientOptionsT] | None = None + ) -> Iterable[str]: """ Rephrase a given query using the LLM. @@ -49,8 +71,10 @@ async def rephrase(self, query: str, options: QueryRephraserOptions | None = Non LLMStatusError: If the LLM API returns an error status code. LLMResponseError: If the LLM API response is invalid. """ + merged_options = (self.default_options | options) if options else self.default_options + llm_options = merged_options.llm_options or None prompt = self._prompt(QueryRephraserInput(query=query)) - response = await self._llm.generate(prompt) + response = await self._llm.generate(prompt, options=llm_options) return [response] @classmethod diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py index d5c96d5a3..5a802dde5 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py @@ -3,10 +3,11 @@ from typing_extensions import Self from ragbits.core.audit.traces import traceable -from ragbits.core.llms.base import LLM +from ragbits.core.llms.base import LLM, LLMClientOptionsT from ragbits.core.prompt import Prompt from ragbits.core.utils.config_handling import ObjectConstructionConfig -from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser +from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraserOptions from ragbits.document_search.retrieval.rephrasers.prompts import ( MultiQueryRephraserInput, MultiQueryRephraserPrompt, @@ -14,29 +15,30 @@ ) -class MultiQueryRephraserOptions(QueryRephraserOptions): +class MultiQueryRephraserOptions(LLMQueryRephraserOptions[LLMClientOptionsT]): """ Object representing the options for the multi query rephraser. Attributes: + llm_options: The options for the LLM. n: The number of rephrasings to generate. """ n: int = 5 -class MultiQueryRephraser(QueryRephraser[MultiQueryRephraserOptions]): +class MultiQueryRephraser(QueryRephraser[MultiQueryRephraserOptions[LLMClientOptionsT]]): """ A rephraser class that uses a LLM to generate reworded versions of input query. """ - options_cls = MultiQueryRephraserOptions + options_cls: type[MultiQueryRephraserOptions] = MultiQueryRephraserOptions def __init__( self, - llm: LLM, + llm: LLM[LLMClientOptionsT], prompt: type[Prompt[MultiQueryRephraserInput, list[str]]] | None = None, - default_options: MultiQueryRephraserOptions | None = None, + default_options: MultiQueryRephraserOptions[LLMClientOptionsT] | None = None, ) -> None: """ Initialize the MultiQueryRephraser with a LLM. @@ -51,7 +53,9 @@ def __init__( self._prompt = prompt or MultiQueryRephraserPrompt @traceable - async def rephrase(self, query: str, options: QueryRephraserOptions | None = None) -> Iterable[str]: + async def rephrase( + self, query: str, options: MultiQueryRephraserOptions[LLMClientOptionsT] | None = None + ) -> Iterable[str]: """ Rephrase a given query using the LLM. @@ -68,8 +72,9 @@ async def rephrase(self, query: str, options: QueryRephraserOptions | None = Non LLMResponseError: If the LLM API response is invalid. """ merged_options = (self.default_options | options) if options else self.default_options + llm_options = merged_options.llm_options or None prompt = self._prompt(MultiQueryRephraserInput(query=query, n=merged_options.n)) - response = await self._llm.generate(prompt) + response = await self._llm.generate(prompt, options=llm_options) return [query] + response @classmethod From 92188a49ac5f35daeb2168972c8ae00b270f9148 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= <47692610+micpst@users.noreply.github.com> Date: Thu, 15 May 2025 11:52:00 +0200 Subject: [PATCH 05/10] fix format --- .../src/ragbits/document_search/retrieval/rephrasers/llm.py | 4 +++- .../src/ragbits/document_search/retrieval/rephrasers/multi.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index 785e9c1cf..fe1884227 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -54,7 +54,9 @@ def __init__( @traceable async def rephrase( - self, query: str, options: LLMQueryRephraserOptions[LLMClientOptionsT] | None = None + self, + query: str, + options: LLMQueryRephraserOptions[LLMClientOptionsT] | None = None, ) -> Iterable[str]: """ Rephrase a given query using the LLM. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py index 5a802dde5..91da0e94e 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py @@ -54,7 +54,9 @@ def __init__( @traceable async def rephrase( - self, query: str, options: MultiQueryRephraserOptions[LLMClientOptionsT] | None = None + self, + query: str, + options: MultiQueryRephraserOptions[LLMClientOptionsT] | None = None, ) -> Iterable[str]: """ Rephrase a given query using the LLM. From eb2c032c6aad853677b13b8aa45ce515f3acfc9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= <47692610+micpst@users.noreply.github.com> Date: Thu, 15 May 2025 15:06:35 +0200 Subject: [PATCH 06/10] update docs --- .../document_search/retrieval/rephrasers.md | 18 ++++++++++++++---- .../how-to/document_search/search-documents.md | 15 ++++++++++----- .../retrieval/rephrasers/__init__.py | 9 ++++++--- .../retrieval/rephrasers/noop.py | 2 +- 4 files changed, 31 insertions(+), 13 deletions(-) diff --git a/docs/api_reference/document_search/retrieval/rephrasers.md b/docs/api_reference/document_search/retrieval/rephrasers.md index 2a6acb6d2..c46fe5593 100644 --- a/docs/api_reference/document_search/retrieval/rephrasers.md +++ b/docs/api_reference/document_search/retrieval/rephrasers.md @@ -1,6 +1,16 @@ # Query Rephrasers -::: ragbits.document_search.retrieval.rephrasers.QueryRephraser -::: ragbits.document_search.retrieval.rephrasers.LLMQueryRephraser -::: ragbits.document_search.retrieval.rephrasers.MultiQueryRephraser -::: ragbits.document_search.retrieval.rephrasers.NoopQueryRephraser \ No newline at end of file + +::: ragbits.document_search.retrieval.rephrasers.base.QueryRephraserOptions + +::: ragbits.document_search.retrieval.rephrasers.llm.LLMQueryRephraserOptions + +::: ragbits.document_search.retrieval.rephrasers.multi.MultiQueryRephraserOptions + +::: ragbits.document_search.retrieval.rephrasers.base.QueryRephraser + +::: ragbits.document_search.retrieval.rephrasers.llm.LLMQueryRephraser + +::: ragbits.document_search.retrieval.rephrasers.multi.MultiQueryRephraser + +::: ragbits.document_search.retrieval.rephrasers.noop.NoopQueryRephraser diff --git a/docs/how-to/document_search/search-documents.md b/docs/how-to/document_search/search-documents.md index 4690ba612..bddb32a4d 100644 --- a/docs/how-to/document_search/search-documents.md +++ b/docs/how-to/document_search/search-documents.md @@ -94,10 +94,10 @@ By default, the input query is provided directly to the embedding model. However === "Multi query" ```python - from ragbits.document_search.retrieval.rephrasers import MultiQueryRephraser + from ragbits.document_search.retrieval.rephrasers import MultiQueryRephraser, MultiQueryRephraserOptions from ragbits.document_search import DocumentSearch - query_rephraser = MultiQueryRephraser(LiteLLM(model_name="gpt-3.5-turbo"), n=3) + query_rephraser = MultiQueryRephraser(LiteLLM(model_name="gpt-3.5-turbo"), default_options=MultiQueryRephraserOptions(n=3)) document_search = DocumentSearch(query_rephraser=query_rephraser, ...) elements = await document_search.search("What is the capital of Poland?") @@ -108,20 +108,23 @@ By default, the input query is provided directly to the embedding model. However To define a new rephraser, extend the the [`QueryRephraser`][ragbits.document_search.retrieval.rephrasers.base.QueryRephraser] class. ```python -from ragbits.document_search.retrieval.rephrasers import QueryRephraser +from ragbits.document_search.retrieval.rephrasers import QueryRephraser, QueryRephraserOptions -class CustomRephraser(QueryRephraser): +class CustomRephraser(QueryRephraser[QueryRephraserOptions]): """ Rephraser that uses a LLM to rephrase queries. """ - async def rephrase(self, query: str) -> list[str]: + options_cls: type[QueryRephraserOptions] = QueryRephraserOptions + + async def rephrase(self, query: str, options: QueryRephraserOptions | None = None) -> Iterable[str]: """ Rephrase a query using the LLM. Args: query: The query to be rephrased. + options: The options for rephrasing. Returns: List containing the rephrased query. @@ -175,6 +178,8 @@ class CustomReranker(Reranker[RerankerOptions]): Reranker that uses a LLM to rerank elements. """ + options_cls: type[RerankerOptions] = RerankerOptions + async def rerank( self, elements: Sequence[Sequence[Element]], diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py index 46527e591..c09923f18 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py @@ -1,6 +1,6 @@ -from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser -from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser -from ragbits.document_search.retrieval.rephrasers.multi import MultiQueryRephraser +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions +from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser, LLMQueryRephraserOptions +from ragbits.document_search.retrieval.rephrasers.multi import MultiQueryRephraser, MultiQueryRephraserOptions from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser from ragbits.document_search.retrieval.rephrasers.prompts import ( MultiQueryRephraserInput, @@ -11,11 +11,14 @@ __all__ = [ "LLMQueryRephraser", + "LLMQueryRephraserOptions", "MultiQueryRephraser", "MultiQueryRephraserInput", + "MultiQueryRephraserOptions", "MultiQueryRephraserPrompt", "NoopQueryRephraser", "QueryRephraser", "QueryRephraserInput", + "QueryRephraserOptions", "QueryRephraserPrompt", ] diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py index 62782e13d..576b9beaf 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py @@ -9,7 +9,7 @@ class NoopQueryRephraser(QueryRephraser[QueryRephraserOptions]): A no-op query paraphraser that does not change the query. """ - options_cls = QueryRephraserOptions + options_cls: type[QueryRephraserOptions] = QueryRephraserOptions @traceable async def rephrase(self, query: str, options: QueryRephraserOptions | None = None) -> Iterable[str]: # noqa: PLR6301 From d5188c791cf08f002d003e82aefb55ca9e891880 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= <47692610+micpst@users.noreply.github.com> Date: Mon, 19 May 2025 15:40:42 +0200 Subject: [PATCH 07/10] mv prompts --- .../document_search/search-documents.md | 2 +- examples/document-search/configurable.py | 12 ++- .../retrieval/rephrasers/__init__.py | 16 ++-- .../retrieval/rephrasers/llm.py | 34 +++++-- .../retrieval/rephrasers/multi.py | 46 ++++++++-- .../retrieval/rephrasers/prompts.py | 89 ------------------- .../tests/unit/test_rephrasers.py | 9 +- 7 files changed, 87 insertions(+), 121 deletions(-) delete mode 100644 packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py diff --git a/docs/how-to/document_search/search-documents.md b/docs/how-to/document_search/search-documents.md index bddb32a4d..03ac1e9d7 100644 --- a/docs/how-to/document_search/search-documents.md +++ b/docs/how-to/document_search/search-documents.md @@ -129,7 +129,7 @@ class CustomRephraser(QueryRephraser[QueryRephraserOptions]): Returns: List containing the rephrased query. """ - responses = await llm.generate(QueryRephraserPrompt(...)) + responses = await llm.generate(CustomRephraserPrompt(...)) ... return [...] ``` diff --git a/examples/document-search/configurable.py b/examples/document-search/configurable.py index 7335c5920..d265a96ad 100644 --- a/examples/document-search/configurable.py +++ b/examples/document-search/configurable.py @@ -30,7 +30,7 @@ class to rephrase the query. import asyncio -from ragbits.core.audit import set_trace_handlers +from ragbits.core.audit.traces import set_trace_handlers from ragbits.document_search import DocumentSearch from ragbits.document_search.documents.document import DocumentMeta @@ -86,13 +86,12 @@ class to rephrase the query. "model": "cohere/rerank-english-v3.0", "default_options": { "top_n": 3, - "max_chunks_per_doc": None, }, }, }, "parser_router": {"txt": {"type": "TextDocumentParser"}}, "rephraser": { - "type": "LLMQueryRephraser", + "type": "ragbits.document_search.retrieval.rephrasers:LLMQueryRephraser", "config": { "llm": { "type": "ragbits.core.llms.litellm:LiteLLM", @@ -101,7 +100,12 @@ class to rephrase the query. }, }, "prompt": { - "type": "QueryRephraserPrompt", + "type": "ragbits.document_search.retrieval.rephrasers:QueryRephraserPrompt", + }, + "default_options": { + "llm_options": { + "temperature": 0.0, + }, }, }, }, diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py index c09923f18..8bcbcbcac 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py @@ -1,13 +1,17 @@ from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions -from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser, LLMQueryRephraserOptions -from ragbits.document_search.retrieval.rephrasers.multi import MultiQueryRephraser, MultiQueryRephraserOptions -from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser -from ragbits.document_search.retrieval.rephrasers.prompts import ( - MultiQueryRephraserInput, - MultiQueryRephraserPrompt, +from ragbits.document_search.retrieval.rephrasers.llm import ( + LLMQueryRephraser, + LLMQueryRephraserOptions, QueryRephraserInput, QueryRephraserPrompt, ) +from ragbits.document_search.retrieval.rephrasers.multi import ( + MultiQueryRephraser, + MultiQueryRephraserInput, + MultiQueryRephraserOptions, + MultiQueryRephraserPrompt, +) +from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser __all__ = [ "LLMQueryRephraser", diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index fe1884227..352cc02de 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -1,19 +1,38 @@ from collections.abc import Iterable from typing import Generic +from pydantic import BaseModel from typing_extensions import Self from ragbits.core.audit.traces import traceable from ragbits.core.llms.base import LLM, LLMClientOptionsT from ragbits.core.prompt import Prompt from ragbits.core.types import NOT_GIVEN, NotGiven -from ragbits.core.utils.config_handling import ObjectConstructionConfig +from ragbits.core.utils.config_handling import ObjectConstructionConfig, import_by_path from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions -from ragbits.document_search.retrieval.rephrasers.prompts import ( - QueryRephraserInput, - QueryRephraserPrompt, - get_rephraser_prompt, -) + + +class QueryRephraserInput(BaseModel): + """ + Input data for the query rephraser prompt. + """ + + query: str + + +class QueryRephraserPrompt(Prompt[QueryRephraserInput, str]): + """ + Prompt for generating a rephrased user query. + """ + + system_prompt = """ + You are an expert in query rephrasing and clarity improvement. + Your task is to return a single paraphrased version of a user's query, + correcting any typos, handling abbreviations and improving clarity. + Focus on making the query more precise and readable while keeping its original intent. + Just return the rephrased query. No additional explanations are needed. + """ + user_prompt = "Query:{{ query }}" class LLMQueryRephraserOptions(QueryRephraserOptions, Generic[LLMClientOptionsT]): @@ -93,11 +112,10 @@ def from_config(cls, config: dict) -> Self: Raises: ValidationError: If the LLM or prompt configuration doesn't follow the expected format. InvalidConfigError: If an LLM or prompt class can't be found or is not the correct type. - ValueError: If the prompt class is not a subclass of `Prompt`. """ config["llm"] = LLM.subclass_from_config(ObjectConstructionConfig.model_validate(config["llm"])) config["prompt"] = ( - get_rephraser_prompt(ObjectConstructionConfig.model_validate(config["prompt"]).type) + import_by_path(ObjectConstructionConfig.model_validate(config["prompt"]).type) if "prompt" in config else None ) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py index 91da0e94e..6d18691e5 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py @@ -1,18 +1,49 @@ from collections.abc import Iterable +from pydantic import BaseModel from typing_extensions import Self from ragbits.core.audit.traces import traceable from ragbits.core.llms.base import LLM, LLMClientOptionsT from ragbits.core.prompt import Prompt -from ragbits.core.utils.config_handling import ObjectConstructionConfig +from ragbits.core.utils.config_handling import ObjectConstructionConfig, import_by_path from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraserOptions -from ragbits.document_search.retrieval.rephrasers.prompts import ( - MultiQueryRephraserInput, - MultiQueryRephraserPrompt, - get_rephraser_prompt, -) + + +class MultiQueryRephraserInput(BaseModel): + """ + Input data for the multi query rephraser prompt. + """ + + query: str + n: int + + +class MultiQueryRephraserPrompt(Prompt[MultiQueryRephraserInput, list]): + """ + Prompt for generating multiple query rephrasings. + """ + + system_prompt = """ + You are a helpful assistant that creates short rephrased versions of a given query. + Your task is to generate {{ n }} different versions of the given user query to retrieve relevant documents + from a vector database. They can be phrased as statements, as they will be used as a search query. + By generating multiple perspectives on the user query, your goal is to help the user overcome some of the + limitations of the distance-based similarity search. + Alternative queries should only contain information present in the original query. Do not include anything + in the alternative query, you have not seen in the original version. + It is VERY important you DO NOT ADD any comments or notes. Return ONLY alternative queries. + Provide these alternative queries separated by newlines. + DO NOT ADD any enumeration. + """ + user_prompt = "Query: {{ query }}" + + @staticmethod + def _list_parser(value: str) -> list[str]: + return value.split("\n") + + response_parser = _list_parser class MultiQueryRephraserOptions(LLMQueryRephraserOptions[LLMClientOptionsT]): @@ -93,11 +124,10 @@ def from_config(cls, config: dict) -> Self: Raises: ValidationError: If the LLM or prompt configuration doesn't follow the expected format. InvalidConfigError: If an LLM or prompt class can't be found or is not the correct type. - ValueError: If the prompt class is not a subclass of `Prompt`. """ config["llm"] = LLM.subclass_from_config(ObjectConstructionConfig.model_validate(config["llm"])) config["prompt"] = ( - get_rephraser_prompt(ObjectConstructionConfig.model_validate(config["prompt"]).type) + import_by_path(ObjectConstructionConfig.model_validate(config["prompt"]).type) if "prompt" in config else None ) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py deleted file mode 100644 index d95319dc1..000000000 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py +++ /dev/null @@ -1,89 +0,0 @@ -import sys -from typing import Any - -from pydantic import BaseModel - -from ragbits.core.prompt.prompt import Prompt -from ragbits.core.utils.config_handling import import_by_path - -module = sys.modules[__name__] - - -class QueryRephraserInput(BaseModel): - """ - Input data for the query rephraser prompt. - """ - - query: str - - -class QueryRephraserPrompt(Prompt[QueryRephraserInput, str]): - """ - A prompt class for generating a rephrased version of a user's query using a LLM. - """ - - user_prompt = "{{ query }}" - system_prompt = ( - "You are an expert in query rephrasing and clarity improvement. " - "Your task is to return a single paraphrased version of a user's query, " - "correcting any typos, handling abbreviations and improving clarity. " - "Focus on making the query more precise and readable while keeping its original intent.\n\n" - "Just return the rephrased query. No additional explanations are needed." - ) - - -class MultiQueryRephraserInput(BaseModel): - """ - Represents the input data for the multi query rephraser prompt. - """ - - query: str - n: int - - -class MultiQueryRephraserPrompt(Prompt[MultiQueryRephraserInput, list[str]]): - """ - A prompt template for generating multiple query rephrasings. - """ - - user_prompt = "{{ query }}" - system_prompt = ( - "You are a helpful assistant that creates short rephrased versions of a given query. " - "Your task is to generate {{ n }} different versions of the given user query to retrieve relevant documents" - " from a vector database. They can be phrased as statements, as they will be used as a search query. " - "By generating multiple perspectives on the user query, " - "your goal is to help the user overcome some of the limitations of the distance-based similarity search." - "Alternative queries should only contain information present in the original query. Do not include anything" - " in the alternative query, you have not seen in the original version.\n\n" - "It is VERY important you DO NOT ADD any comments or notes. Return ONLY alternative queries. " - "Provide these alternative queries separated by newlines. " - "DO NOT ADD any enumeration." - ) - - @staticmethod - def _list_parser(value: str) -> list[str]: - return value.split("\n") - - response_parser = _list_parser - - -def get_rephraser_prompt(prompt: str) -> type[Prompt[Any, Any]]: - """ - Initializes and returns a QueryRephraser object based on the provided configuration. - - Args: - prompt: The prompt class to use for rephrasing queries. - - Returns: - An instance of the specified QueryRephraser class, initialized with the provided config - (if any) or default arguments. - - Raises: - ValueError: If the prompt class is not a subclass of `Prompt`. - """ - prompt_cls = import_by_path(prompt, module) - - if not issubclass(prompt_cls, Prompt): - raise ValueError(f"Invalid rephraser prompt class: {prompt_cls}") - - return prompt_cls diff --git a/packages/ragbits-document-search/tests/unit/test_rephrasers.py b/packages/ragbits-document-search/tests/unit/test_rephrasers.py index e27406613..856b4d4f5 100644 --- a/packages/ragbits-document-search/tests/unit/test_rephrasers.py +++ b/packages/ragbits-document-search/tests/unit/test_rephrasers.py @@ -1,10 +1,9 @@ from ragbits.core.llms.litellm import LiteLLM from ragbits.core.utils.config_handling import ObjectConstructionConfig from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser -from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser -from ragbits.document_search.retrieval.rephrasers.multi import MultiQueryRephraser +from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser, QueryRephraserPrompt +from ragbits.document_search.retrieval.rephrasers.multi import MultiQueryRephraser, MultiQueryRephraserPrompt from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser -from ragbits.document_search.retrieval.rephrasers.prompts import MultiQueryRephraserPrompt, QueryRephraserPrompt def test_subclass_from_config(): @@ -48,7 +47,7 @@ def test_subclass_from_config_llm_prompt(): "type": "ragbits.core.llms.litellm:LiteLLM", "config": {"model_name": "some_model"}, }, - "prompt": {"type": "QueryRephraserPrompt"}, + "prompt": {"type": "ragbits.document_search.retrieval.rephrasers.llm:QueryRephraserPrompt"}, }, } ) @@ -85,7 +84,7 @@ def test_subclass_from_config_multiquery_llm_prompt(): "type": "ragbits.core.llms.litellm:LiteLLM", "config": {"model_name": "some_model"}, }, - "prompt": {"type": "MultiQueryRephraserPrompt"}, + "prompt": {"type": "ragbits.document_search.retrieval.rephrasers.multi:MultiQueryRephraserPrompt"}, "default_options": {"n": 4}, }, } From e687e82761638b222434874ca12992f9a2ebaf7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= <47692610+micpst@users.noreply.github.com> Date: Mon, 19 May 2025 16:18:08 +0200 Subject: [PATCH 08/10] merge multi and llm rephrasers --- .../document_search/retrieval/rephrasers.md | 5 - .../document_search/search-documents.md | 4 +- examples/document-search/configurable.py | 3 +- .../retrieval/rephrasers/__init__.py | 18 +-- .../retrieval/rephrasers/llm.py | 35 +++-- .../retrieval/rephrasers/multi.py | 134 ------------------ .../tests/unit/test_rephrasers.py | 48 +------ 7 files changed, 41 insertions(+), 206 deletions(-) delete mode 100644 packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py diff --git a/docs/api_reference/document_search/retrieval/rephrasers.md b/docs/api_reference/document_search/retrieval/rephrasers.md index c46fe5593..bd9cf8a0d 100644 --- a/docs/api_reference/document_search/retrieval/rephrasers.md +++ b/docs/api_reference/document_search/retrieval/rephrasers.md @@ -1,16 +1,11 @@ # Query Rephrasers - ::: ragbits.document_search.retrieval.rephrasers.base.QueryRephraserOptions ::: ragbits.document_search.retrieval.rephrasers.llm.LLMQueryRephraserOptions -::: ragbits.document_search.retrieval.rephrasers.multi.MultiQueryRephraserOptions - ::: ragbits.document_search.retrieval.rephrasers.base.QueryRephraser ::: ragbits.document_search.retrieval.rephrasers.llm.LLMQueryRephraser -::: ragbits.document_search.retrieval.rephrasers.multi.MultiQueryRephraser - ::: ragbits.document_search.retrieval.rephrasers.noop.NoopQueryRephraser diff --git a/docs/how-to/document_search/search-documents.md b/docs/how-to/document_search/search-documents.md index 03ac1e9d7..94c7bc2d6 100644 --- a/docs/how-to/document_search/search-documents.md +++ b/docs/how-to/document_search/search-documents.md @@ -94,10 +94,10 @@ By default, the input query is provided directly to the embedding model. However === "Multi query" ```python - from ragbits.document_search.retrieval.rephrasers import MultiQueryRephraser, MultiQueryRephraserOptions + from ragbits.document_search.retrieval.rephrasers import LLMQueryRephraser, LLMQueryRephraserOptions from ragbits.document_search import DocumentSearch - query_rephraser = MultiQueryRephraser(LiteLLM(model_name="gpt-3.5-turbo"), default_options=MultiQueryRephraserOptions(n=3)) + query_rephraser = LLMQueryRephraser(LiteLLM(model_name="gpt-3.5-turbo"), default_options=LLMQueryRephraserOptions(n=3)) document_search = DocumentSearch(query_rephraser=query_rephraser, ...) elements = await document_search.search("What is the capital of Poland?") diff --git a/examples/document-search/configurable.py b/examples/document-search/configurable.py index d265a96ad..f686127e0 100644 --- a/examples/document-search/configurable.py +++ b/examples/document-search/configurable.py @@ -100,9 +100,10 @@ class to rephrase the query. }, }, "prompt": { - "type": "ragbits.document_search.retrieval.rephrasers:QueryRephraserPrompt", + "type": "ragbits.document_search.retrieval.rephrasers:LLMQueryRephraserPrompt", }, "default_options": { + "n": 2, "llm_options": { "temperature": 0.0, }, diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py index 8bcbcbcac..6f9c22e71 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py @@ -1,28 +1,18 @@ from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions from ragbits.document_search.retrieval.rephrasers.llm import ( LLMQueryRephraser, + LLMQueryRephraserInput, LLMQueryRephraserOptions, - QueryRephraserInput, - QueryRephraserPrompt, -) -from ragbits.document_search.retrieval.rephrasers.multi import ( - MultiQueryRephraser, - MultiQueryRephraserInput, - MultiQueryRephraserOptions, - MultiQueryRephraserPrompt, + LLMQueryRephraserPrompt, ) from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser __all__ = [ "LLMQueryRephraser", + "LLMQueryRephraserInput", "LLMQueryRephraserOptions", - "MultiQueryRephraser", - "MultiQueryRephraserInput", - "MultiQueryRephraserOptions", - "MultiQueryRephraserPrompt", + "LLMQueryRephraserPrompt", "NoopQueryRephraser", "QueryRephraser", - "QueryRephraserInput", "QueryRephraserOptions", - "QueryRephraserPrompt", ] diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index 352cc02de..3c57709df 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -12,27 +12,45 @@ from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions -class QueryRephraserInput(BaseModel): +class LLMQueryRephraserInput(BaseModel): """ Input data for the query rephraser prompt. """ query: str + n: int | None = None -class QueryRephraserPrompt(Prompt[QueryRephraserInput, str]): +class LLMQueryRephraserPrompt(Prompt[LLMQueryRephraserInput, list]): """ Prompt for generating a rephrased user query. """ system_prompt = """ You are an expert in query rephrasing and clarity improvement. + {%- if n and n > 1 %} + Your task is to generate {{ n }} different versions of the given user query to retrieve relevant documents + from a vector database. They can be phrased as statements, as they will be used as a search query. + By generating multiple perspectives on the user query, your goal is to help the user overcome some of the + limitations of the distance-based similarity search. + Alternative queries should only contain information present in the original query. Do not include anything + in the alternative query, you have not seen in the original version. + It is VERY important you DO NOT ADD any comments or notes. Return ONLY alternative queries. + Provide these alternative queries separated by newlines. DO NOT ADD any enumeration. + {%- else %} Your task is to return a single paraphrased version of a user's query, correcting any typos, handling abbreviations and improving clarity. Focus on making the query more precise and readable while keeping its original intent. Just return the rephrased query. No additional explanations are needed. + {%- endif %} """ - user_prompt = "Query:{{ query }}" + user_prompt = "Query: {{ query }}" + + @staticmethod + def _response_parser(value: str) -> list[str]: + return [line.strip() for line in value.strip().split("\n") if line.strip()] + + response_parser = _response_parser class LLMQueryRephraserOptions(QueryRephraserOptions, Generic[LLMClientOptionsT]): @@ -40,9 +58,11 @@ class LLMQueryRephraserOptions(QueryRephraserOptions, Generic[LLMClientOptionsT] Object representing the options for the LLM query rephraser. Attributes: + n: The number of rephrasings to generate. Any number below 2 will generate only one rephrasing. llm_options: The options for the LLM. """ + n: int | None | NotGiven = NOT_GIVEN llm_options: LLMClientOptionsT | None | NotGiven = NOT_GIVEN @@ -56,7 +76,7 @@ class LLMQueryRephraser(QueryRephraser[LLMQueryRephraserOptions[LLMClientOptions def __init__( self, llm: LLM[LLMClientOptionsT], - prompt: type[Prompt[QueryRephraserInput, str]] | None = None, + prompt: type[Prompt[LLMQueryRephraserInput, list[str]]] | None = None, default_options: LLMQueryRephraserOptions[LLMClientOptionsT] | None = None, ) -> None: """ @@ -69,7 +89,7 @@ def __init__( """ super().__init__(default_options=default_options) self._llm = llm - self._prompt = prompt or QueryRephraserPrompt + self._prompt = prompt or LLMQueryRephraserPrompt @traceable async def rephrase( @@ -94,9 +114,8 @@ async def rephrase( """ merged_options = (self.default_options | options) if options else self.default_options llm_options = merged_options.llm_options or None - prompt = self._prompt(QueryRephraserInput(query=query)) - response = await self._llm.generate(prompt, options=llm_options) - return [response] + prompt = self._prompt(LLMQueryRephraserInput(query=query, n=merged_options.n or None)) + return await self._llm.generate(prompt, options=llm_options) @classmethod def from_config(cls, config: dict) -> Self: diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py deleted file mode 100644 index 6d18691e5..000000000 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/multi.py +++ /dev/null @@ -1,134 +0,0 @@ -from collections.abc import Iterable - -from pydantic import BaseModel -from typing_extensions import Self - -from ragbits.core.audit.traces import traceable -from ragbits.core.llms.base import LLM, LLMClientOptionsT -from ragbits.core.prompt import Prompt -from ragbits.core.utils.config_handling import ObjectConstructionConfig, import_by_path -from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser -from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraserOptions - - -class MultiQueryRephraserInput(BaseModel): - """ - Input data for the multi query rephraser prompt. - """ - - query: str - n: int - - -class MultiQueryRephraserPrompt(Prompt[MultiQueryRephraserInput, list]): - """ - Prompt for generating multiple query rephrasings. - """ - - system_prompt = """ - You are a helpful assistant that creates short rephrased versions of a given query. - Your task is to generate {{ n }} different versions of the given user query to retrieve relevant documents - from a vector database. They can be phrased as statements, as they will be used as a search query. - By generating multiple perspectives on the user query, your goal is to help the user overcome some of the - limitations of the distance-based similarity search. - Alternative queries should only contain information present in the original query. Do not include anything - in the alternative query, you have not seen in the original version. - It is VERY important you DO NOT ADD any comments or notes. Return ONLY alternative queries. - Provide these alternative queries separated by newlines. - DO NOT ADD any enumeration. - """ - user_prompt = "Query: {{ query }}" - - @staticmethod - def _list_parser(value: str) -> list[str]: - return value.split("\n") - - response_parser = _list_parser - - -class MultiQueryRephraserOptions(LLMQueryRephraserOptions[LLMClientOptionsT]): - """ - Object representing the options for the multi query rephraser. - - Attributes: - llm_options: The options for the LLM. - n: The number of rephrasings to generate. - """ - - n: int = 5 - - -class MultiQueryRephraser(QueryRephraser[MultiQueryRephraserOptions[LLMClientOptionsT]]): - """ - A rephraser class that uses a LLM to generate reworded versions of input query. - """ - - options_cls: type[MultiQueryRephraserOptions] = MultiQueryRephraserOptions - - def __init__( - self, - llm: LLM[LLMClientOptionsT], - prompt: type[Prompt[MultiQueryRephraserInput, list[str]]] | None = None, - default_options: MultiQueryRephraserOptions[LLMClientOptionsT] | None = None, - ) -> None: - """ - Initialize the MultiQueryRephraser with a LLM. - - Args: - llm: A LLM instance to handle query rephrasing. - prompt: The prompt to use for rephrasing queries. - default_options: The default options for the rephraser. - """ - super().__init__(default_options=default_options) - self._llm = llm - self._prompt = prompt or MultiQueryRephraserPrompt - - @traceable - async def rephrase( - self, - query: str, - options: MultiQueryRephraserOptions[LLMClientOptionsT] | None = None, - ) -> Iterable[str]: - """ - Rephrase a given query using the LLM. - - Args: - query: The query to be rephrased. If not provided, a custom prompt must be given. - options: The options for the rephraser. - - Returns: - A list containing the reworded versions of input query. - - Raises: - LLMConnectionError: If there is a connection error with the LLM API. - LLMStatusError: If the LLM API returns an error status code. - LLMResponseError: If the LLM API response is invalid. - """ - merged_options = (self.default_options | options) if options else self.default_options - llm_options = merged_options.llm_options or None - prompt = self._prompt(MultiQueryRephraserInput(query=query, n=merged_options.n)) - response = await self._llm.generate(prompt, options=llm_options) - return [query] + response - - @classmethod - def from_config(cls, config: dict) -> Self: - """ - Create an instance of `MultiQueryRephraser` from a configuration dictionary. - - Args: - config: A dictionary containing configuration settings for the rephraser. - - Returns: - An instance of the rephraser class initialized with the provided configuration. - - Raises: - ValidationError: If the LLM or prompt configuration doesn't follow the expected format. - InvalidConfigError: If an LLM or prompt class can't be found or is not the correct type. - """ - config["llm"] = LLM.subclass_from_config(ObjectConstructionConfig.model_validate(config["llm"])) - config["prompt"] = ( - import_by_path(ObjectConstructionConfig.model_validate(config["prompt"]).type) - if "prompt" in config - else None - ) - return super().from_config(config) diff --git a/packages/ragbits-document-search/tests/unit/test_rephrasers.py b/packages/ragbits-document-search/tests/unit/test_rephrasers.py index 856b4d4f5..de11d02d7 100644 --- a/packages/ragbits-document-search/tests/unit/test_rephrasers.py +++ b/packages/ragbits-document-search/tests/unit/test_rephrasers.py @@ -1,8 +1,7 @@ from ragbits.core.llms.litellm import LiteLLM from ragbits.core.utils.config_handling import ObjectConstructionConfig from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser -from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser, QueryRephraserPrompt -from ragbits.document_search.retrieval.rephrasers.multi import MultiQueryRephraser, MultiQueryRephraserPrompt +from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser, LLMQueryRephraserPrompt from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser @@ -47,50 +46,15 @@ def test_subclass_from_config_llm_prompt(): "type": "ragbits.core.llms.litellm:LiteLLM", "config": {"model_name": "some_model"}, }, - "prompt": {"type": "ragbits.document_search.retrieval.rephrasers.llm:QueryRephraserPrompt"}, - }, - } - ) - rephraser: QueryRephraser = QueryRephraser.subclass_from_config(config) - assert isinstance(rephraser, LLMQueryRephraser) - assert isinstance(rephraser._llm, LiteLLM) - assert issubclass(rephraser._prompt, QueryRephraserPrompt) - - -def test_subclass_from_config_multi(): - config = ObjectConstructionConfig.model_validate( - { - "type": "ragbits.document_search.retrieval.rephrasers.multi:MultiQueryRephraser", - "config": { - "llm": { - "type": "ragbits.core.llms.litellm:LiteLLM", - "config": {"model_name": "some_model"}, - }, - }, - } - ) - rephraser: QueryRephraser = QueryRephraser.subclass_from_config(config) - assert isinstance(rephraser, MultiQueryRephraser) - assert isinstance(rephraser._llm, LiteLLM) - assert rephraser._llm.model_name == "some_model" - - -def test_subclass_from_config_multiquery_llm_prompt(): - config = ObjectConstructionConfig.model_validate( - { - "type": "ragbits.document_search.retrieval.rephrasers.multi:MultiQueryRephraser", - "config": { - "llm": { - "type": "ragbits.core.llms.litellm:LiteLLM", - "config": {"model_name": "some_model"}, + "prompt": {"type": "ragbits.document_search.retrieval.rephrasers.llm:LLMQueryRephraserPrompt"}, + "default_options": { + "n": 4, }, - "prompt": {"type": "ragbits.document_search.retrieval.rephrasers.multi:MultiQueryRephraserPrompt"}, - "default_options": {"n": 4}, }, } ) rephraser: QueryRephraser = QueryRephraser.subclass_from_config(config) - assert isinstance(rephraser, MultiQueryRephraser) + assert isinstance(rephraser, LLMQueryRephraser) assert isinstance(rephraser._llm, LiteLLM) + assert issubclass(rephraser._prompt, LLMQueryRephraserPrompt) assert rephraser.default_options.n == 4 - assert issubclass(rephraser._prompt, MultiQueryRephraserPrompt) From bb186251833f7abe07bda5ff88951f201de9bec2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= <47692610+micpst@users.noreply.github.com> Date: Mon, 19 May 2025 16:25:27 +0200 Subject: [PATCH 09/10] rename prompt input --- .../document_search/retrieval/rephrasers/__init__.py | 4 ++-- .../ragbits/document_search/retrieval/rephrasers/llm.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py index 6f9c22e71..c9532b488 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py @@ -1,17 +1,17 @@ from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions from ragbits.document_search.retrieval.rephrasers.llm import ( LLMQueryRephraser, - LLMQueryRephraserInput, LLMQueryRephraserOptions, LLMQueryRephraserPrompt, + LLMQueryRephraserPromptInput, ) from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser __all__ = [ "LLMQueryRephraser", - "LLMQueryRephraserInput", "LLMQueryRephraserOptions", "LLMQueryRephraserPrompt", + "LLMQueryRephraserPromptInput", "NoopQueryRephraser", "QueryRephraser", "QueryRephraserOptions", diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index 3c57709df..01474c34d 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -12,7 +12,7 @@ from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser, QueryRephraserOptions -class LLMQueryRephraserInput(BaseModel): +class LLMQueryRephraserPromptInput(BaseModel): """ Input data for the query rephraser prompt. """ @@ -21,7 +21,7 @@ class LLMQueryRephraserInput(BaseModel): n: int | None = None -class LLMQueryRephraserPrompt(Prompt[LLMQueryRephraserInput, list]): +class LLMQueryRephraserPrompt(Prompt[LLMQueryRephraserPromptInput, list]): """ Prompt for generating a rephrased user query. """ @@ -76,7 +76,7 @@ class LLMQueryRephraser(QueryRephraser[LLMQueryRephraserOptions[LLMClientOptions def __init__( self, llm: LLM[LLMClientOptionsT], - prompt: type[Prompt[LLMQueryRephraserInput, list[str]]] | None = None, + prompt: type[Prompt[LLMQueryRephraserPromptInput, list[str]]] | None = None, default_options: LLMQueryRephraserOptions[LLMClientOptionsT] | None = None, ) -> None: """ @@ -114,7 +114,7 @@ async def rephrase( """ merged_options = (self.default_options | options) if options else self.default_options llm_options = merged_options.llm_options or None - prompt = self._prompt(LLMQueryRephraserInput(query=query, n=merged_options.n or None)) + prompt = self._prompt(LLMQueryRephraserPromptInput(query=query, n=merged_options.n or None)) return await self._llm.generate(prompt, options=llm_options) @classmethod From 53020b034a2c394c611e84e1a30ba78e26d2c49c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= <47692610+micpst@users.noreply.github.com> Date: Mon, 19 May 2025 16:31:00 +0200 Subject: [PATCH 10/10] small fix --- .../src/ragbits/document_search/retrieval/rephrasers/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index 01474c34d..f60dc1c84 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -48,7 +48,7 @@ class LLMQueryRephraserPrompt(Prompt[LLMQueryRephraserPromptInput, list]): @staticmethod def _response_parser(value: str) -> list[str]: - return [line.strip() for line in value.strip().split("\n") if line.strip()] + return [stripped_line for line in value.strip().split("\n") if (stripped_line := line.strip())] response_parser = _response_parser