From 573bc0751371e727a679524c28d23c55daf51116 Mon Sep 17 00:00:00 2001 From: Kevin Cogan Date: Thu, 23 Jan 2025 15:36:31 +0000 Subject: [PATCH 1/8] Added streaming functionality to the LLM. --- pragmatic/main.py | 16 ++++++- pragmatic/pipelines/rag.py | 91 ++++++++++++++++++++++++++++++++++++-- pragmatic/settings.py | 8 ++-- test/sanity_test.py | 12 +++-- 4 files changed, 116 insertions(+), 11 deletions(-) diff --git a/pragmatic/main.py b/pragmatic/main.py index b22c53e..b8bf357 100644 --- a/pragmatic/main.py +++ b/pragmatic/main.py @@ -1,3 +1,8 @@ +# addressing the issue where the project structure causes pragmatic to not be on the path +import sys +import os +sys.path.insert(0, os.path.abspath(os.getcwd())) + import argparse from api import index_path_for_rag, execute_rag_query, evaluate_rag_pipeline @@ -26,7 +31,7 @@ def main(): args = parser.parse_args() - if sum([args.indexing, args.rag, args.evaluation, args.server]) != 1: + if sum([args.indexing, args.rag, args.evaluation]) != 1: print("Wrong usage: exactly one of the supported operation modes (indexing, query) must be specified.") return @@ -57,7 +62,14 @@ def main(): if args.query is None: print("Please specify the query.") return - print(execute_rag_query(args.query, **custom_settings)) + + # Execute query and get a generator + result_generator = execute_rag_query(args.query, **custom_settings) + + # Process each chunk as it arrives + for chunk in result_generator: + decoded_chunk = chunk.decode("utf-8") + print(decoded_chunk, end="", flush=True) if args.evaluation: print(evaluate_rag_pipeline(**custom_settings)) diff --git a/pragmatic/pipelines/rag.py b/pragmatic/pipelines/rag.py index 50b7abc..f47698c 100644 --- a/pragmatic/pipelines/rag.py +++ b/pragmatic/pipelines/rag.py @@ -11,6 +11,12 @@ from pragmatic.pipelines.pipeline import CommonPipelineWrapper +from threading import Thread, Event +from queue import Queue + + +STREAM_TIMEOUT = 10 # seconds + BASE_RAG_PROMPT = """You are an assistant for question-answering tasks. Here is the context to use to answer the question: @@ -31,13 +37,73 @@ """ +class RagStreamHandler: + def __init__(self): + self._stream_queue = None + self._stream_thread = None + self._stop_event = Event() + + def __del__(self): + """ + Finalizer for cleanup. Ensures threads and resources are released. + """ + self.stop_stream() + + def _streaming_callback(self, chunk): + """ + Callback to be passed to the LLM, which places streamed tokens into the queue. + """ + if self._stream_queue: + self._stream_queue.put(chunk.content) + + def _run_streaming_in_thread(self, run_callable): + """ + Invokes the provided callable (the pipeline's run method), + then signals the end of streaming by putting None in the queue. + """ + run_callable() + if self._stream_queue: + self._stream_queue.put(None) + + def start_stream(self, run_callable): + """ + Initializes the queue and the background thread that executes `run_callable`. + """ + self._stream_queue = Queue() + self._stream_thread = Thread(target=self._run_streaming_in_thread, args=(run_callable,)) + self._stream_thread.start() + + def stop_stream(self): + """ + Gracefully stops the streaming thread if it is running. 
+ """ + if self._stream_thread and self._stream_thread.is_alive(): + self._stop_event.set() # Signal the thread to stop + self._stream_thread.join() # Wait for the thread to finish + self._stream_queue = None # Clean up the queue + self._stop_event.clear() # Reset the flag for future use + + def stream_chunks(self): + """ + Yields streamed chunks from the queue. + Terminates when `None` is retrieved from the queue. + """ + try: + for chunk in iter(lambda: self._stream_queue.get(timeout=STREAM_TIMEOUT), None): + yield chunk.encode("utf-8") + finally: + self.stop_stream() + + class RagPipelineWrapper(CommonPipelineWrapper): def __init__(self, settings, query=None, evaluation_mode=False): super().__init__(settings) self._query = query - self._evaluation_mode = evaluation_mode + self.streaming_handler = RagStreamHandler() + self.llm = None + def _add_embedder(self, query): embedder = SentenceTransformersTextEmbedder(model=self._settings["embedding_model_path"]) self._add_component("embedder", embedder, component_args={"text": query}) @@ -104,11 +170,21 @@ def _add_prompt_builder(self): self._add_component("prompt_builder", prompt_builder, component_args={"query": self._query}, component_to_connect_point="prompt_builder.documents") + + def stream_query(self): + self.streaming_handler.start_stream(lambda: super(RagPipelineWrapper, self).run()) + return self.streaming_handler.stream_chunks() + def _add_llm(self): if "generator_object" in self._settings and self._settings["generator_object"] is not None: # an object to use for communicating with the model was explicitly specified and we should use it self._add_component("llm", self._settings["generator_object"]) return + + custom_callback = ( + self.streaming_handler._streaming_callback + if self._settings.get("stream") else None + ) llm = OpenAIGenerator( api_key=self._settings["llm_api_key"], @@ -118,6 +194,7 @@ def _add_llm(self): max_retries=self._settings["llm_connection_max_retries"], system_prompt=self._settings["llm_system_prompt"], organization=self._settings["llm_organization"], + streaming_callback=custom_callback, generation_kwargs={ "max_tokens": self._settings["llm_response_max_tokens"], "temperature": self._settings["llm_temperature"], @@ -170,10 +247,18 @@ def run(self, query=None): for key in ["text", "query"]: if key in config_dict: config_dict[key] = query + + # If streaming is enabled + if self._settings.get("stream", False) and not self._evaluation_mode: + return self.stream_query() + # Otherwise, execute a normal pipeline run result = super().run() + # Return answer builder output in evaluation mode if self._evaluation_mode: - return result["answer_builder"]["answers"][0] + return result.get("answer_builder").get("answers")[0] - return result["llm"]["replies"][0] + # Return the final LLM reply as utf-8 encoded bytes wrapped in a generator + llm_reply = result.get("llm", {}).get("replies", [None])[0].strip() + return (chunk.encode("utf-8") for chunk in [llm_reply]) # Wrap in a generator diff --git a/pragmatic/settings.py b/pragmatic/settings.py index fe7f7da..77d6f8a 100644 --- a/pragmatic/settings.py +++ b/pragmatic/settings.py @@ -6,7 +6,7 @@ # basic settings "vector_db_type": "milvus", "retriever_type": "dense", - "embedding_model_path": "./cache/finetuned_embedding-final", + "embedding_model_path": "sentence-transformers/all-MiniLM-L12-v2", #"./cache/finetuned_embedding-final", # document conversion-related settings "apply_docling": False, @@ -38,8 +38,8 @@ "converted_docling_document_format": "json", # LLM-related 
settings - "llm": "mistralai/Mistral-7B-Instruct-v0.2", - "llm_base_url": "http://127.0.0.1:8000/v1", + "llm": "instructlab/granite-7b-lab-GGUF", + "llm_base_url": "http://vllm-service:8000/v1", "llm_api_key": Secret.from_token("VLLM-PLACEHOLDER-API-KEY"), # use Secret.from_env_var("API_KEY_ENV_VAR_NAME") to enable authentication "llm_connection_timeout": 30, "llm_connection_max_retries": 3, @@ -56,6 +56,8 @@ "llm_http_client": None, "generator_object": None, # an instance of a Haystack-compatible generator object + "stream": True, + # advanced RAG options "top_k": 1, "cleaner_enabled": False, diff --git a/test/sanity_test.py b/test/sanity_test.py index 6b88dd0..939fa39 100644 --- a/test/sanity_test.py +++ b/test/sanity_test.py @@ -42,6 +42,12 @@ def docling_convert(docs): result.document.save_as_json(Path(base_output_path + ".json")) print("Successfully converted the documents and saved as JSON in local directory specified \n") +def print_stream_data(result_generator): + # Process each chunk as it arrives + for chunk in result_generator: + decoded_chunk = chunk.decode("utf-8") + print(decoded_chunk, end="", flush=True) + def main(): # ensure the documents directory exists before the script is executed @@ -68,7 +74,7 @@ def main(): llm_base_url="http://vllm-service:8000/v1", top_k=3) print("Response generated:") - print(f"\n{result1}") + print_stream_data(result1) print("\n") print("Question: What are the two deployment options in OpenShift AI?") result2 = execute_rag_query("What are the two deployment options in OpenShift AI?", @@ -77,7 +83,7 @@ def main(): llm_base_url="http://vllm-service:8000/v1", top_k=3) print("Response generated:") - print(f"\n{result2}") + print_stream_data(result2) print("\n") print("Question: What is OpenShift AI?") result3 = execute_rag_query("What is OpenShift AI?", @@ -86,7 +92,7 @@ def main(): llm_base_url="http://vllm-service:8000/v1", top_k=3) print("Response generated:") - print(f"\n{result3}") + print_stream_data(result3) if __name__ == '__main__': From 2cf69148ca141e098f5fabe7914bfb656ca168af Mon Sep 17 00:00:00 2001 From: Kevin Cogan Date: Wed, 29 Jan 2025 07:02:27 +0100 Subject: [PATCH 2/8] Resolved change request updates --- pragmatic/main.py | 3 +- pragmatic/pipelines/rag.py | 64 ++++++++++++++++++++++---------------- pragmatic/settings.py | 7 +++-- test/sanity_test.py | 9 +++--- 4 files changed, 47 insertions(+), 36 deletions(-) diff --git a/pragmatic/main.py b/pragmatic/main.py index b8bf357..0e06692 100644 --- a/pragmatic/main.py +++ b/pragmatic/main.py @@ -68,8 +68,7 @@ def main(): # Process each chunk as it arrives for chunk in result_generator: - decoded_chunk = chunk.decode("utf-8") - print(decoded_chunk, end="", flush=True) + print(chunk, end="", flush=True) if args.evaluation: print(evaluate_rag_pipeline(**custom_settings)) diff --git a/pragmatic/pipelines/rag.py b/pragmatic/pipelines/rag.py index f47698c..8f9a967 100644 --- a/pragmatic/pipelines/rag.py +++ b/pragmatic/pipelines/rag.py @@ -15,8 +15,6 @@ from queue import Queue -STREAM_TIMEOUT = 10 # seconds - BASE_RAG_PROMPT = """You are an assistant for question-answering tasks. Here is the context to use to answer the question: @@ -38,7 +36,8 @@ class RagStreamHandler: - def __init__(self): + def __init__(self, settings): + self._timeout = settings.get("streaming_timeout", 60) # seconds self._stream_queue = None self._stream_thread = None self._stop_event = Event() @@ -89,8 +88,8 @@ def stream_chunks(self): Terminates when `None` is retrieved from the queue. 
""" try: - for chunk in iter(lambda: self._stream_queue.get(timeout=STREAM_TIMEOUT), None): - yield chunk.encode("utf-8") + for chunk in iter(lambda: self._stream_queue.get(timeout=self._timeout), None): + yield chunk finally: self.stop_stream() @@ -101,8 +100,7 @@ def __init__(self, settings, query=None, evaluation_mode=False): self._query = query self._evaluation_mode = evaluation_mode - self.streaming_handler = RagStreamHandler() - self.llm = None + self._streaming_handler = RagStreamHandler(settings) def _add_embedder(self, query): embedder = SentenceTransformersTextEmbedder(model=self._settings["embedding_model_path"]) @@ -169,11 +167,6 @@ def _add_prompt_builder(self): prompt_builder = PromptBuilder(template=BASE_RAG_PROMPT) self._add_component("prompt_builder", prompt_builder, component_args={"query": self._query}, component_to_connect_point="prompt_builder.documents") - - - def stream_query(self): - self.streaming_handler.start_stream(lambda: super(RagPipelineWrapper, self).run()) - return self.streaming_handler.stream_chunks() def _add_llm(self): if "generator_object" in self._settings and self._settings["generator_object"] is not None: @@ -182,8 +175,8 @@ def _add_llm(self): return custom_callback = ( - self.streaming_handler._streaming_callback - if self._settings.get("stream") else None + self._streaming_handler._streaming_callback + if self._settings.get("enable_response_streaming") else None ) llm = OpenAIGenerator( @@ -241,24 +234,43 @@ def set_evaluation_mode(self, new_evaluation_mode): self._rebuild_pipeline() def run(self, query=None): - if query is not None: - # replace the query in the pipeline args + """ + Executes a query against the pipeline. + + If response streaming is enabled, returns a generator that yields chunks of the response. + Otherwise, returns a string with the final response. + + If evaluation mode is enabled, the response is retrieved from the answer builder instead of the standard pipeline output. + + Parameters: + query (str, optional): The input query string to be processed. + + Returns: + str | generator: A string for non-streaming mode or a generator for streaming mode. 
+ """ + # If a query is provided, update the relevant keys in the pipeline arguments + if query: for config_dict in self._args.values(): for key in ["text", "query"]: if key in config_dict: config_dict[key] = query - - # If streaming is enabled - if self._settings.get("stream", False) and not self._evaluation_mode: - return self.stream_query() - # Otherwise, execute a normal pipeline run + # Handle incompatible settings + if self._settings.get("stream", False) and self._evaluation_mode: + raise ValueError("Evaluation mode does not support streaming replies.") + + # Handle streaming mode if enabled + if self._settings.get("enable_response_streaming", False): + self._streaming_handler.start_stream(lambda: super(RagPipelineWrapper, self).run()) + return self._streaming_handler.stream_chunks() + + # Otherwise, execute a normal pipeline run result = super().run() - # Return answer builder output in evaluation mode + # # In evaluation mode, return the answer from the answer builder in string format if self._evaluation_mode: - return result.get("answer_builder").get("answers")[0] + return result.get("answer_builder", {}).get("answers", [""])[0] + + # Return the final LLM reply in string format + return result.get("llm", {}).get("replies", [""])[0] - # Return the final LLM reply as utf-8 encoded bytes wrapped in a generator - llm_reply = result.get("llm", {}).get("replies", [None])[0].strip() - return (chunk.encode("utf-8") for chunk in [llm_reply]) # Wrap in a generator diff --git a/pragmatic/settings.py b/pragmatic/settings.py index 77d6f8a..f20ae22 100644 --- a/pragmatic/settings.py +++ b/pragmatic/settings.py @@ -38,8 +38,8 @@ "converted_docling_document_format": "json", # LLM-related settings - "llm": "instructlab/granite-7b-lab-GGUF", - "llm_base_url": "http://vllm-service:8000/v1", + "llm": "mistralai/Mistral-7B-Instruct-v0.2", + "llm_base_url": "http://127.0.0.1:8000/v1", "llm_api_key": Secret.from_token("VLLM-PLACEHOLDER-API-KEY"), # use Secret.from_env_var("API_KEY_ENV_VAR_NAME") to enable authentication "llm_connection_timeout": 30, "llm_connection_max_retries": 3, @@ -56,7 +56,8 @@ "llm_http_client": None, "generator_object": None, # an instance of a Haystack-compatible generator object - "stream": True, + "enable_response_streaming": True, + 'streaming_timeout': 30, # Default timeout is 60 seconds if not specified # advanced RAG options "top_k": 1, diff --git a/test/sanity_test.py b/test/sanity_test.py index 939fa39..d27535d 100644 --- a/test/sanity_test.py +++ b/test/sanity_test.py @@ -45,8 +45,7 @@ def docling_convert(docs): def print_stream_data(result_generator): # Process each chunk as it arrives for chunk in result_generator: - decoded_chunk = chunk.decode("utf-8") - print(decoded_chunk, end="", flush=True) + print(chunk, end="", flush=True) def main(): @@ -71,7 +70,7 @@ def main(): result1 = execute_rag_query("How to install OpenShift CLI on macOS?", milvus_file_path="./milvus.db", embedding_model_path="sentence-transformers/all-MiniLM-L12-v2", - llm_base_url="http://vllm-service:8000/v1", + llm_base_url="http://127.0.0.1:8000/v1", top_k=3) print("Response generated:") print_stream_data(result1) @@ -80,7 +79,7 @@ def main(): result2 = execute_rag_query("What are the two deployment options in OpenShift AI?", milvus_file_path="./milvus.db", embedding_model_path="sentence-transformers/all-MiniLM-L12-v2", - llm_base_url="http://vllm-service:8000/v1", + llm_base_url="http://127.0.0.1:8000/v1", top_k=3) print("Response generated:") print_stream_data(result2) @@ -89,7 +88,7 @@ def 
main(): result3 = execute_rag_query("What is OpenShift AI?", milvus_file_path="./milvus.db", embedding_model_path="sentence-transformers/all-MiniLM-L12-v2", - llm_base_url="http://vllm-service:8000/v1", + llm_base_url="http://127.0.0.1:8000/v1", top_k=3) print("Response generated:") print_stream_data(result3) From 4b3bda5d62647abc4ea1f8901270da5df7d6ad3e Mon Sep 17 00:00:00 2001 From: Kevin Cogan Date: Thu, 30 Jan 2025 18:40:39 +0100 Subject: [PATCH 3/8] refactor: Extracted RagStreamHandler Class into its own file (streaming.py) and updated references --- pragmatic/pipelines/rag.py | 65 +------------------------------- pragmatic/pipelines/streaming.py | 60 +++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 64 deletions(-) create mode 100644 pragmatic/pipelines/streaming.py diff --git a/pragmatic/pipelines/rag.py b/pragmatic/pipelines/rag.py index 8f9a967..bab76fd 100644 --- a/pragmatic/pipelines/rag.py +++ b/pragmatic/pipelines/rag.py @@ -10,10 +10,7 @@ from openai import OpenAI from pragmatic.pipelines.pipeline import CommonPipelineWrapper - -from threading import Thread, Event -from queue import Queue - +from pragmatic.pipelines.streaming import RagStreamHandler BASE_RAG_PROMPT = """You are an assistant for question-answering tasks. @@ -34,66 +31,6 @@ Answer: """ - -class RagStreamHandler: - def __init__(self, settings): - self._timeout = settings.get("streaming_timeout", 60) # seconds - self._stream_queue = None - self._stream_thread = None - self._stop_event = Event() - - def __del__(self): - """ - Finalizer for cleanup. Ensures threads and resources are released. - """ - self.stop_stream() - - def _streaming_callback(self, chunk): - """ - Callback to be passed to the LLM, which places streamed tokens into the queue. - """ - if self._stream_queue: - self._stream_queue.put(chunk.content) - - def _run_streaming_in_thread(self, run_callable): - """ - Invokes the provided callable (the pipeline's run method), - then signals the end of streaming by putting None in the queue. - """ - run_callable() - if self._stream_queue: - self._stream_queue.put(None) - - def start_stream(self, run_callable): - """ - Initializes the queue and the background thread that executes `run_callable`. - """ - self._stream_queue = Queue() - self._stream_thread = Thread(target=self._run_streaming_in_thread, args=(run_callable,)) - self._stream_thread.start() - - def stop_stream(self): - """ - Gracefully stops the streaming thread if it is running. - """ - if self._stream_thread and self._stream_thread.is_alive(): - self._stop_event.set() # Signal the thread to stop - self._stream_thread.join() # Wait for the thread to finish - self._stream_queue = None # Clean up the queue - self._stop_event.clear() # Reset the flag for future use - - def stream_chunks(self): - """ - Yields streamed chunks from the queue. - Terminates when `None` is retrieved from the queue. 
- """ - try: - for chunk in iter(lambda: self._stream_queue.get(timeout=self._timeout), None): - yield chunk - finally: - self.stop_stream() - - class RagPipelineWrapper(CommonPipelineWrapper): def __init__(self, settings, query=None, evaluation_mode=False): super().__init__(settings) diff --git a/pragmatic/pipelines/streaming.py b/pragmatic/pipelines/streaming.py new file mode 100644 index 0000000..58f3a5a --- /dev/null +++ b/pragmatic/pipelines/streaming.py @@ -0,0 +1,60 @@ +from threading import Thread, Event +from queue import Queue + +class RagStreamHandler: + def __init__(self, settings): + self._timeout = settings.get("streaming_timeout", 60) # seconds + self._stream_queue = None + self._stream_thread = None + self._stop_event = Event() + + def __del__(self): + """ + Finalizer for cleanup. Ensures threads and resources are released. + """ + self.stop_stream() + + def _streaming_callback(self, chunk): + """ + Callback to be passed to the LLM, which places streamed tokens into the queue. + """ + if self._stream_queue: + self._stream_queue.put(chunk.content) + + def _run_streaming_in_thread(self, run_callable): + """ + Invokes the provided callable (the pipeline's run method), + then signals the end of streaming by putting None in the queue. + """ + run_callable() + if self._stream_queue: + self._stream_queue.put(None) + + def start_stream(self, run_callable): + """ + Initializes the queue and the background thread that executes `run_callable`. + """ + self._stream_queue = Queue() + self._stream_thread = Thread(target=self._run_streaming_in_thread, args=(run_callable,)) + self._stream_thread.start() + + def stop_stream(self): + """ + Gracefully stops the streaming thread if it is running. + """ + if self._stream_thread and self._stream_thread.is_alive(): + self._stop_event.set() # Signal the thread to stop + self._stream_thread.join() # Wait for the thread to finish + self._stream_queue = None # Clean up the queue + self._stop_event.clear() # Reset the flag for future use + + def stream_chunks(self): + """ + Yields streamed chunks from the queue. + Terminates when `None` is retrieved from the queue. + """ + try: + for chunk in iter(lambda: self._stream_queue.get(timeout=self._timeout), None): + yield chunk + finally: + self.stop_stream() From e6881419cd6559ca02b2fa380d52b7b32bede6b5 Mon Sep 17 00:00:00 2001 From: Kevin Cogan Date: Thu, 30 Jan 2025 18:46:26 +0100 Subject: [PATCH 4/8] feat: Support both string and generator responses in query execution. 
--- pragmatic/main.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pragmatic/main.py b/pragmatic/main.py index 0e06692..0883b5d 100644 --- a/pragmatic/main.py +++ b/pragmatic/main.py @@ -63,12 +63,14 @@ def main(): print("Please specify the query.") return - # Execute query and get a generator - result_generator = execute_rag_query(args.query, **custom_settings) + # Execute query and get response + result = execute_rag_query(args.query, **custom_settings) - # Process each chunk as it arrives - for chunk in result_generator: - print(chunk, end="", flush=True) + if isinstance(result, str): # If a string, print directly + print(result) + else: # If a generator, process each chunk + for chunk in result: + print(chunk, end="", flush=True) if args.evaluation: print(evaluate_rag_pipeline(**custom_settings)) From 3bc876a266247a44ed905345d82319361a7133b3 Mon Sep 17 00:00:00 2001 From: Kevin Cogan Date: Thu, 30 Jan 2025 18:49:29 +0100 Subject: [PATCH 5/8] feat: Support both string and generator responses in query execution. --- test/sanity_test.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/sanity_test.py b/test/sanity_test.py index d27535d..64c50fb 100644 --- a/test/sanity_test.py +++ b/test/sanity_test.py @@ -42,10 +42,12 @@ def docling_convert(docs): result.document.save_as_json(Path(base_output_path + ".json")) print("Successfully converted the documents and saved as JSON in local directory specified \n") -def print_stream_data(result_generator): - # Process each chunk as it arrives - for chunk in result_generator: - print(chunk, end="", flush=True) +def print_stream_data(result): + if isinstance(result, str): # If a string, print directly + print(result) + else: # If a generator, process each chunk + for chunk in result: + print(chunk, end="", flush=True) def main(): From 17c7d11567740f45735d55e7d09d0902b751cb6e Mon Sep 17 00:00:00 2001 From: Kevin Cogan Date: Thu, 30 Jan 2025 19:26:27 +0100 Subject: [PATCH 6/8] refactor: changed docstring to contain the new streaming flag name (enable_response_streaming). --- pragmatic/pipelines/rag.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pragmatic/pipelines/rag.py b/pragmatic/pipelines/rag.py index bab76fd..8eef9ca 100644 --- a/pragmatic/pipelines/rag.py +++ b/pragmatic/pipelines/rag.py @@ -174,7 +174,7 @@ def run(self, query=None): """ Executes a query against the pipeline. - If response streaming is enabled, returns a generator that yields chunks of the response. + If response enable_response_streaming is enabled, returns a generator that yields chunks of the response. Otherwise, returns a string with the final response. If evaluation mode is enabled, the response is retrieved from the answer builder instead of the standard pipeline output. @@ -183,7 +183,7 @@ def run(self, query=None): query (str, optional): The input query string to be processed. Returns: - str | generator: A string for non-streaming mode or a generator for streaming mode. + str | generator: A string when enable_response_streaming is False or a generator when enable_response_streaming is True. """ # If a query is provided, update the relevant keys in the pipeline arguments if query: From 785933a162f799982f77b1351781298272453a7f Mon Sep 17 00:00:00 2001 From: Kevin Cogan Date: Thu, 30 Jan 2025 19:32:02 +0100 Subject: [PATCH 7/8] fix: change old streaming flag (stream) to enable_response_streaming. 
--- pragmatic/pipelines/rag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pragmatic/pipelines/rag.py b/pragmatic/pipelines/rag.py index 8eef9ca..92e8348 100644 --- a/pragmatic/pipelines/rag.py +++ b/pragmatic/pipelines/rag.py @@ -193,7 +193,7 @@ def run(self, query=None): config_dict[key] = query # Handle incompatible settings - if self._settings.get("stream", False) and self._evaluation_mode: + if self._settings.get("enable_response_streaming", False) and self._evaluation_mode: raise ValueError("Evaluation mode does not support streaming replies.") # Handle streaming mode if enabled From 9618a9f53061c169d7ef101212e272a93d00107d Mon Sep 17 00:00:00 2001 From: Kevin Cogan Date: Thu, 30 Jan 2025 22:49:06 +0100 Subject: [PATCH 8/8] refactor: Implemented lazy loading of the RagStreamHandler class. --- pragmatic/pipelines/rag.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pragmatic/pipelines/rag.py b/pragmatic/pipelines/rag.py index 92e8348..54e3555 100644 --- a/pragmatic/pipelines/rag.py +++ b/pragmatic/pipelines/rag.py @@ -12,6 +12,9 @@ from pragmatic.pipelines.pipeline import CommonPipelineWrapper from pragmatic.pipelines.streaming import RagStreamHandler +from functools import cached_property + + BASE_RAG_PROMPT = """You are an assistant for question-answering tasks. Here is the context to use to answer the question: @@ -37,7 +40,10 @@ def __init__(self, settings, query=None, evaluation_mode=False): self._query = query self._evaluation_mode = evaluation_mode - self._streaming_handler = RagStreamHandler(settings) + @cached_property + def _streaming_handler(self): + """Lazily creates and caches a RagStreamHandler on first access.""" + return RagStreamHandler(self._settings) def _add_embedder(self, query): embedder = SentenceTransformersTextEmbedder(model=self._settings["embedding_model_path"]) @@ -111,10 +117,9 @@ def _add_llm(self): self._add_component("llm", self._settings["generator_object"]) return - custom_callback = ( - self._streaming_handler._streaming_callback - if self._settings.get("enable_response_streaming") else None - ) + custom_callback = None + if self._settings.get("enable_response_streaming"): + custom_callback = self._streaming_handler._streaming_callback llm = OpenAIGenerator( api_key=self._settings["llm_api_key"], @@ -196,11 +201,12 @@ def run(self, query=None): if self._settings.get("enable_response_streaming", False) and self._evaluation_mode: raise ValueError("Evaluation mode does not support streaming replies.") - # Handle streaming mode if enabled + # If streaming is enabled, use the streaming_handler property if self._settings.get("enable_response_streaming", False): self._streaming_handler.start_stream(lambda: super(RagPipelineWrapper, self).run()) return self._streaming_handler.stream_chunks() + # Otherwise, execute a normal pipeline run result = super().run()
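
The end-to-end behaviour added by this series can be exercised as sketched below. This is a minimal consumer-side sketch that mirrors the updated main.py and test/sanity_test.py; it assumes execute_rag_query is importable from pragmatic.api (main.py pulls it from the local api module after the sys.path fix) and that an OpenAI-compatible endpoint is reachable at llm_base_url.

    from pragmatic.api import execute_rag_query  # assumed import path

    # With "enable_response_streaming": True (the default added in settings.py),
    # execute_rag_query returns a generator of text chunks; with the flag set to
    # False it returns the full reply as a single string.
    result = execute_rag_query(
        "What is OpenShift AI?",
        milvus_file_path="./milvus.db",
        embedding_model_path="sentence-transformers/all-MiniLM-L12-v2",
        llm_base_url="http://127.0.0.1:8000/v1",
        enable_response_streaming=True,  # shown explicitly for clarity; True is the default
        top_k=3,
    )

    if isinstance(result, str):  # non-streaming mode
        print(result)
    else:                        # streaming mode: print tokens as they arrive
        for chunk in result:
            print(chunk, end="", flush=True)

RagStreamHandler (pragmatic/pipelines/streaming.py, added in PATCH 3/8) can also be driven without a full pipeline, which is handy for unit-testing the queue/thread hand-off. The sketch below uses a hypothetical FakeChunk stand-in for the Haystack streaming chunk, since the handler only reads the chunk's .content attribute.

    from pragmatic.pipelines.streaming import RagStreamHandler

    class FakeChunk:
        """Hypothetical stand-in for a Haystack StreamingChunk."""
        def __init__(self, content):
            self.content = content

    handler = RagStreamHandler({"streaming_timeout": 5})

    def fake_pipeline_run():
        # Stand-in for the pipeline's run(); the real LLM invokes the
        # streaming callback once per generated token.
        for token in ("Hello", ", ", "world", "!"):
            handler._streaming_callback(FakeChunk(token))

    handler.start_stream(fake_pipeline_run)
    print("".join(handler.stream_chunks()))  # prints "Hello, world!"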