diff --git a/pragmatic/main.py b/pragmatic/main.py
index b22c53e..0883b5d 100644
--- a/pragmatic/main.py
+++ b/pragmatic/main.py
@@ -1,3 +1,8 @@
+# work around the project structure leaving the pragmatic package off the path
+import sys
+import os
+sys.path.insert(0, os.path.abspath(os.getcwd()))
+
 import argparse
 
 from api import index_path_for_rag, execute_rag_query, evaluate_rag_pipeline
@@ -26,7 +31,7 @@ def main():
 
     args = parser.parse_args()
 
-    if sum([args.indexing, args.rag, args.evaluation, args.server]) != 1:
+    if sum([args.indexing, args.rag, args.evaluation]) != 1:
        print("Wrong usage: exactly one of the supported operation modes (indexing, query) must be specified.")
        return
 
@@ -57,7 +62,15 @@ def main():
        if args.query is None:
            print("Please specify the query.")
            return
-        print(execute_rag_query(args.query, **custom_settings))
+
+        # Execute query and get response
+        result = execute_rag_query(args.query, **custom_settings)
+
+        if isinstance(result, str):  # If a string, print directly
+            print(result)
+        else:  # If a generator, process each chunk
+            for chunk in result:
+                print(chunk, end="", flush=True)
 
    if args.evaluation:
        print(evaluate_rag_pipeline(**custom_settings))
diff --git a/pragmatic/pipelines/rag.py b/pragmatic/pipelines/rag.py
index 50b7abc..54e3555 100644
--- a/pragmatic/pipelines/rag.py
+++ b/pragmatic/pipelines/rag.py
@@ -10,6 +10,10 @@ from openai import OpenAI
 
 from pragmatic.pipelines.pipeline import CommonPipelineWrapper
+from pragmatic.pipelines.streaming import RagStreamHandler
+
+from functools import cached_property
+
 
 
 BASE_RAG_PROMPT = """You are an assistant for question-answering tasks.
 
@@ -30,14 +34,18 @@
 Answer:
 """
 
-
 class RagPipelineWrapper(CommonPipelineWrapper):
    def __init__(self, settings, query=None, evaluation_mode=False):
        super().__init__(settings)
        self._query = query
        self._evaluation_mode = evaluation_mode
 
+    @cached_property
+    def _streaming_handler(self):
+        """Lazily creates and caches a RagStreamHandler on first access."""
+        return RagStreamHandler(self._settings)
+
    def _add_embedder(self, query):
        embedder = SentenceTransformersTextEmbedder(model=self._settings["embedding_model_path"])
        self._add_component("embedder", embedder, component_args={"text": query})
 
@@ -103,12 +111,16 @@ def _add_prompt_builder(self):
        prompt_builder = PromptBuilder(template=BASE_RAG_PROMPT)
        self._add_component("prompt_builder", prompt_builder, component_args={"query": self._query},
                            component_to_connect_point="prompt_builder.documents")
-    
+
    def _add_llm(self):
        if "generator_object" in self._settings and self._settings["generator_object"] is not None:
            # an object to use for communicating with the model was explicitly specified and we should use it
            self._add_component("llm", self._settings["generator_object"])
            return
+
+        custom_callback = None
+        if self._settings.get("enable_response_streaming"):
+            custom_callback = self._streaming_handler._streaming_callback
 
        llm = OpenAIGenerator(
            api_key=self._settings["llm_api_key"],
@@ -118,6 +130,7 @@
            max_retries=self._settings["llm_connection_max_retries"],
            system_prompt=self._settings["llm_system_prompt"],
            organization=self._settings["llm_organization"],
+            streaming_callback=custom_callback,
            generation_kwargs={
                "max_tokens": self._settings["llm_response_max_tokens"],
                "temperature": self._settings["llm_temperature"],
@@ -164,16 +177,44 @@ def set_evaluation_mode(self, new_evaluation_mode):
        self._rebuild_pipeline()
 
    def run(self, query=None):
-        if query is not None:
-            # replace the query in the pipeline args
+ """ + Executes a query against the pipeline. + + If response enable_response_streaming is enabled, returns a generator that yields chunks of the response. + Otherwise, returns a string with the final response. + + If evaluation mode is enabled, the response is retrieved from the answer builder instead of the standard pipeline output. + + Parameters: + query (str, optional): The input query string to be processed. + + Returns: + str | generator: A string when enable_response_streaming is False or a generator when enable_response_streaming is True. + """ + # If a query is provided, update the relevant keys in the pipeline arguments + if query: for config_dict in self._args.values(): for key in ["text", "query"]: if key in config_dict: config_dict[key] = query + # Handle incompatible settings + if self._settings.get("enable_response_streaming", False) and self._evaluation_mode: + raise ValueError("Evaluation mode does not support streaming replies.") + + # If streaming is enabled, use the streaming_handler property + if self._settings.get("enable_response_streaming", False): + self._streaming_handler.start_stream(lambda: super(RagPipelineWrapper, self).run()) + return self._streaming_handler.stream_chunks() + + + # Otherwise, execute a normal pipeline run result = super().run() + # # In evaluation mode, return the answer from the answer builder in string format if self._evaluation_mode: - return result["answer_builder"]["answers"][0] + return result.get("answer_builder", {}).get("answers", [""])[0] + + # Return the final LLM reply in string format + return result.get("llm", {}).get("replies", [""])[0] - return result["llm"]["replies"][0] diff --git a/pragmatic/pipelines/streaming.py b/pragmatic/pipelines/streaming.py new file mode 100644 index 0000000..58f3a5a --- /dev/null +++ b/pragmatic/pipelines/streaming.py @@ -0,0 +1,60 @@ +from threading import Thread, Event +from queue import Queue + +class RagStreamHandler: + def __init__(self, settings): + self._timeout = settings.get("streaming_timeout", 60) # seconds + self._stream_queue = None + self._stream_thread = None + self._stop_event = Event() + + def __del__(self): + """ + Finalizer for cleanup. Ensures threads and resources are released. + """ + self.stop_stream() + + def _streaming_callback(self, chunk): + """ + Callback to be passed to the LLM, which places streamed tokens into the queue. + """ + if self._stream_queue: + self._stream_queue.put(chunk.content) + + def _run_streaming_in_thread(self, run_callable): + """ + Invokes the provided callable (the pipeline's run method), + then signals the end of streaming by putting None in the queue. + """ + run_callable() + if self._stream_queue: + self._stream_queue.put(None) + + def start_stream(self, run_callable): + """ + Initializes the queue and the background thread that executes `run_callable`. + """ + self._stream_queue = Queue() + self._stream_thread = Thread(target=self._run_streaming_in_thread, args=(run_callable,)) + self._stream_thread.start() + + def stop_stream(self): + """ + Gracefully stops the streaming thread if it is running. + """ + if self._stream_thread and self._stream_thread.is_alive(): + self._stop_event.set() # Signal the thread to stop + self._stream_thread.join() # Wait for the thread to finish + self._stream_queue = None # Clean up the queue + self._stop_event.clear() # Reset the flag for future use + + def stream_chunks(self): + """ + Yields streamed chunks from the queue. + Terminates when `None` is retrieved from the queue. 
+ """ + try: + for chunk in iter(lambda: self._stream_queue.get(timeout=self._timeout), None): + yield chunk + finally: + self.stop_stream() diff --git a/pragmatic/settings.py b/pragmatic/settings.py index fe7f7da..f20ae22 100644 --- a/pragmatic/settings.py +++ b/pragmatic/settings.py @@ -6,7 +6,7 @@ # basic settings "vector_db_type": "milvus", "retriever_type": "dense", - "embedding_model_path": "./cache/finetuned_embedding-final", + "embedding_model_path": "sentence-transformers/all-MiniLM-L12-v2", #"./cache/finetuned_embedding-final", # document conversion-related settings "apply_docling": False, @@ -56,6 +56,9 @@ "llm_http_client": None, "generator_object": None, # an instance of a Haystack-compatible generator object + "enable_response_streaming": True, + 'streaming_timeout': 30, # Default timeout is 60 seconds if not specified + # advanced RAG options "top_k": 1, "cleaner_enabled": False, diff --git a/test/sanity_test.py b/test/sanity_test.py index 6b88dd0..64c50fb 100644 --- a/test/sanity_test.py +++ b/test/sanity_test.py @@ -42,6 +42,13 @@ def docling_convert(docs): result.document.save_as_json(Path(base_output_path + ".json")) print("Successfully converted the documents and saved as JSON in local directory specified \n") +def print_stream_data(result): + if isinstance(result, str): # If a string, print directly + print(result) + else: # If a generator, process each chunk + for chunk in result: + print(chunk, end="", flush=True) + def main(): # ensure the documents directory exists before the script is executed @@ -65,28 +72,28 @@ def main(): result1 = execute_rag_query("How to install OpenShift CLI on macOS?", milvus_file_path="./milvus.db", embedding_model_path="sentence-transformers/all-MiniLM-L12-v2", - llm_base_url="http://vllm-service:8000/v1", + llm_base_url="http://127.0.0.1:8000/v1", top_k=3) print("Response generated:") - print(f"\n{result1}") + print_stream_data(result1) print("\n") print("Question: What are the two deployment options in OpenShift AI?") result2 = execute_rag_query("What are the two deployment options in OpenShift AI?", milvus_file_path="./milvus.db", embedding_model_path="sentence-transformers/all-MiniLM-L12-v2", - llm_base_url="http://vllm-service:8000/v1", + llm_base_url="http://127.0.0.1:8000/v1", top_k=3) print("Response generated:") - print(f"\n{result2}") + print_stream_data(result2) print("\n") print("Question: What is OpenShift AI?") result3 = execute_rag_query("What is OpenShift AI?", milvus_file_path="./milvus.db", embedding_model_path="sentence-transformers/all-MiniLM-L12-v2", - llm_base_url="http://vllm-service:8000/v1", + llm_base_url="http://127.0.0.1:8000/v1", top_k=3) print("Response generated:") - print(f"\n{result3}") + print_stream_data(result3) if __name__ == '__main__':