Added streaming functionality to the LLM. #35

Merged · 8 commits · Jan 31, 2025
17 changes: 15 additions & 2 deletions pragmatic/main.py
@@ -1,3 +1,8 @@
# addressing the issue where the project structure causes pragmatic to not be on the path
import sys
import os
sys.path.insert(0, os.path.abspath(os.getcwd()))

import argparse

from api import index_path_for_rag, execute_rag_query, evaluate_rag_pipeline
@@ -26,7 +31,7 @@ def main():

args = parser.parse_args()

if sum([args.indexing, args.rag, args.evaluation, args.server]) != 1:
if sum([args.indexing, args.rag, args.evaluation]) != 1:
print("Wrong usage: exactly one of the supported operation modes (indexing, query) must be specified.")
return

@@ -57,7 +62,15 @@ def main():
if args.query is None:
print("Please specify the query.")
return
print(execute_rag_query(args.query, **custom_settings))

# Execute query and get response
result = execute_rag_query(args.query, **custom_settings)

if isinstance(result, str): # If a string, print directly
print(result)
else: # If a generator, process each chunk
for chunk in result:
print(chunk, end="", flush=True)

if args.evaluation:
print(evaluate_rag_pipeline(**custom_settings))
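A quick usage sketch of the new contract for callers that import the API directly (as test/sanity_test.py below does): execute_rag_query now returns either a plain string when streaming is disabled, or a generator of chunks when it is enabled, which is the default in pragmatic/settings.py below. The pragmatic.api import path and the availability of an indexed store plus a reachable LLM endpoint are assumptions here, not part of this diff.

from pragmatic.api import execute_rag_query  # import path assumed; main.py reaches it as plain `api`

result = execute_rag_query("How to install OpenShift CLI on macOS?", top_k=3)

if isinstance(result, str):
    print(result)                      # streaming disabled: the complete answer
else:
    for chunk in result:               # streaming enabled: tokens as they arrive
        print(chunk, end="", flush=True)
    print()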
54 changes: 47 additions & 7 deletions pragmatic/pipelines/rag.py
@@ -10,6 +10,10 @@
from openai import OpenAI

from pragmatic.pipelines.pipeline import CommonPipelineWrapper
from pragmatic.pipelines.streaming import RagStreamHandler

from functools import cached_property


BASE_RAG_PROMPT = """You are an assistant for question-answering tasks.

@@ -30,14 +34,17 @@
Answer:
"""


class RagPipelineWrapper(CommonPipelineWrapper):
def __init__(self, settings, query=None, evaluation_mode=False):
super().__init__(settings)
self._query = query

self._evaluation_mode = evaluation_mode

@cached_property
def _streaming_handler(self):
"""Lazily creates and caches a RagStreamHandler on first access."""
return RagStreamHandler(self._settings)

def _add_embedder(self, query):
embedder = SentenceTransformersTextEmbedder(model=self._settings["embedding_model_path"])
self._add_component("embedder", embedder, component_args={"text": query})
@@ -103,12 +110,16 @@ def _add_prompt_builder(self):
prompt_builder = PromptBuilder(template=BASE_RAG_PROMPT)
self._add_component("prompt_builder", prompt_builder, component_args={"query": self._query},
component_to_connect_point="prompt_builder.documents")

def _add_llm(self):
if "generator_object" in self._settings and self._settings["generator_object"] is not None:
# an object to use for communicating with the model was explicitly specified and we should use it
self._add_component("llm", self._settings["generator_object"])
return

custom_callback = None
if self._settings.get("enable_response_streaming"):
custom_callback = self._streaming_handler._streaming_callback

llm = OpenAIGenerator(
api_key=self._settings["llm_api_key"],
@@ -118,6 +129,7 @@ def _add_llm(self):
max_retries=self._settings["llm_connection_max_retries"],
system_prompt=self._settings["llm_system_prompt"],
organization=self._settings["llm_organization"],
streaming_callback=custom_callback,
generation_kwargs={
"max_tokens": self._settings["llm_response_max_tokens"],
"temperature": self._settings["llm_temperature"],
@@ -164,16 +176,44 @@ def set_evaluation_mode(self, new_evaluation_mode):
self._rebuild_pipeline()

def run(self, query=None):
if query is not None:
# replace the query in the pipeline args
"""
Executes a query against the pipeline.

If enable_response_streaming is enabled, returns a generator that yields chunks of the response.
Otherwise, returns a string with the final response.

If evaluation mode is enabled, the response is retrieved from the answer builder instead of the standard pipeline output.

Parameters:
query (str, optional): The input query string to be processed.

Returns:
str | generator: A string when enable_response_streaming is False or a generator when enable_response_streaming is True.
"""
# If a query is provided, update the relevant keys in the pipeline arguments
if query:
for config_dict in self._args.values():
for key in ["text", "query"]:
if key in config_dict:
config_dict[key] = query

# Handle incompatible settings
if self._settings.get("enable_response_streaming", False) and self._evaluation_mode:
raise ValueError("Evaluation mode does not support streaming replies.")

# If streaming is enabled, use the streaming_handler property
if self._settings.get("enable_response_streaming", False):
self._streaming_handler.start_stream(lambda: super(RagPipelineWrapper, self).run())
return self._streaming_handler.stream_chunks()


# Otherwise, execute a normal pipeline run
result = super().run()

# In evaluation mode, return the answer from the answer builder in string format
if self._evaluation_mode:
return result["answer_builder"]["answers"][0]
return result.get("answer_builder", {}).get("answers", [""])[0]

# Return the final LLM reply in string format
return result.get("llm", {}).get("replies", [""])[0]

return result["llm"]["replies"][0]
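Worth noting for reviewers: _streaming_handler near the top of this file is a functools.cached_property, so the handler is only constructed when streaming is actually used, and the callback wired into the LLM in _add_llm and the generator returned by run() share the same instance. A minimal self-contained illustration of that behavior; the class below is a stand-in for the wrapper, not project code.

from functools import cached_property

class Example:
    @cached_property
    def handler(self):
        print("constructing handler")  # runs once per Example instance
        return object()

e = Example()
first = e.handler                      # triggers construction
second = e.handler                     # served from the per-instance cache
assert first is second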
60 changes: 60 additions & 0 deletions pragmatic/pipelines/streaming.py
@@ -0,0 +1,60 @@
from threading import Thread, Event
from queue import Queue

class RagStreamHandler:
def __init__(self, settings):
self._timeout = settings.get("streaming_timeout", 60) # seconds
self._stream_queue = None
self._stream_thread = None
self._stop_event = Event()

def __del__(self):
"""
Finalizer for cleanup. Ensures threads and resources are released.
"""
self.stop_stream()

def _streaming_callback(self, chunk):
"""
Callback to be passed to the LLM, which places streamed tokens into the queue.
"""
if self._stream_queue:
self._stream_queue.put(chunk.content)

def _run_streaming_in_thread(self, run_callable):
"""
Invokes the provided callable (the pipeline's run method),
then signals the end of streaming by putting None in the queue.
"""
run_callable()
if self._stream_queue:
self._stream_queue.put(None)

def start_stream(self, run_callable):
"""
Initializes the queue and the background thread that executes `run_callable`.
"""
self._stream_queue = Queue()
self._stream_thread = Thread(target=self._run_streaming_in_thread, args=(run_callable,))
self._stream_thread.start()

def stop_stream(self):
"""
Gracefully stops the streaming thread if it is running.
"""
if self._stream_thread and self._stream_thread.is_alive():
self._stop_event.set() # Signal the thread to stop
self._stream_thread.join() # Wait for the thread to finish
self._stream_queue = None # Clean up the queue
self._stop_event.clear() # Reset the flag for future use

def stream_chunks(self):
"""
Yields streamed chunks from the queue.
Terminates when `None` is retrieved from the queue.
"""
try:
for chunk in iter(lambda: self._stream_queue.get(timeout=self._timeout), None):
yield chunk
finally:
self.stop_stream()
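A self-contained sketch of how RagPipelineWrapper.run() drives this class when streaming is enabled: start_stream runs the pipeline on a background thread, the LLM's streaming callback pushes chunk.content into the queue, and the caller drains stream_chunks() as a generator. FakeGenerator and the sample tokens below are stand-ins for the real pipeline and OpenAIGenerator, used only to show the wiring.

from types import SimpleNamespace
from pragmatic.pipelines.streaming import RagStreamHandler

class FakeGenerator:
    """Stand-in producer: emits a few chunks through the streaming callback."""
    def __init__(self, streaming_callback):
        self._callback = streaming_callback

    def run(self):
        for token in ["OpenShift ", "AI ", "is ..."]:
            self._callback(SimpleNamespace(content=token))  # mimics chunk.content

handler = RagStreamHandler({"streaming_timeout": 5})
llm = FakeGenerator(handler._streaming_callback)

handler.start_stream(llm.run)            # plays the role of lambda: super().run()
for chunk in handler.stream_chunks():    # what run() hands back to the caller
    print(chunk, end="", flush=True)
print()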
5 changes: 4 additions & 1 deletion pragmatic/settings.py
@@ -6,7 +6,7 @@
# basic settings
"vector_db_type": "milvus",
"retriever_type": "dense",
"embedding_model_path": "./cache/finetuned_embedding-final",
"embedding_model_path": "sentence-transformers/all-MiniLM-L12-v2", #"./cache/finetuned_embedding-final",

# document conversion-related settings
"apply_docling": False,
@@ -56,6 +56,9 @@
"llm_http_client": None,
"generator_object": None, # an instance of a Haystack-compatible generator object

"enable_response_streaming": True,
"streaming_timeout": 30, # Default timeout is 60 seconds if not specified

# advanced RAG options
"top_k": 1,
"cleaner_enabled": False,
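Both new keys can be overridden per call through the same keyword mechanism the existing calls in test/sanity_test.py use for milvus_file_path and top_k; a hedged sketch, assuming execute_rag_query merges keyword overrides into these defaults.

from pragmatic.api import execute_rag_query  # import path assumed, as in the earlier sketch

# Opt out of streaming for a single call; the override lands on the
# "enable_response_streaming" key above, so a plain string comes back.
answer = execute_rag_query("What is OpenShift AI?", enable_response_streaming=False)
print(answer)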
19 changes: 13 additions & 6 deletions test/sanity_test.py
@@ -42,6 +42,13 @@ def docling_convert(docs):
result.document.save_as_json(Path(base_output_path + ".json"))
print("Successfully converted the documents and saved as JSON in local directory specified \n")

def print_stream_data(result):
if isinstance(result, str): # If a string, print directly
print(result)
else: # If a generator, process each chunk
for chunk in result:
print(chunk, end="", flush=True)


def main():
# ensure the documents directory exists before the script is executed
@@ -65,28 +72,28 @@ def main():
result1 = execute_rag_query("How to install OpenShift CLI on macOS?",
milvus_file_path="./milvus.db",
embedding_model_path="sentence-transformers/all-MiniLM-L12-v2",
llm_base_url="http://vllm-service:8000/v1",
llm_base_url="http://127.0.0.1:8000/v1",
top_k=3)
print("Response generated:")
print(f"\n{result1}")
print_stream_data(result1)
print("\n")
print("Question: What are the two deployment options in OpenShift AI?")
result2 = execute_rag_query("What are the two deployment options in OpenShift AI?",
milvus_file_path="./milvus.db",
embedding_model_path="sentence-transformers/all-MiniLM-L12-v2",
llm_base_url="http://vllm-service:8000/v1",
llm_base_url="http://127.0.0.1:8000/v1",
top_k=3)
print("Response generated:")
print(f"\n{result2}")
print_stream_data(result2)
print("\n")
print("Question: What is OpenShift AI?")
result3 = execute_rag_query("What is OpenShift AI?",
milvus_file_path="./milvus.db",
embedding_model_path="sentence-transformers/all-MiniLM-L12-v2",
llm_base_url="http://vllm-service:8000/v1",
llm_base_url="http://127.0.0.1:8000/v1",
top_k=3)
print("Response generated:")
print(f"\n{result3}")
print_stream_data(result3)


if __name__ == '__main__':