Added streaming functionality to the LLM. #35

Merged
8 commits merged on Jan 31, 2025
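
With response streaming enabled, execute_rag_query now returns a generator of response chunks instead of a single string. Below is a minimal caller-side sketch of consuming it, mirroring the main.py and test/sanity_test.py changes in this diff; the import path and keyword arguments are assumptions drawn from those files, not part of the diff itself.

# Sketch only: consuming the streaming response added in this PR.
# Import path and keyword arguments are assumed from main.py / sanity_test.py.
from pragmatic.api import execute_rag_query

chunks = execute_rag_query(
    "How to install OpenShift CLI on macOS?",
    milvus_file_path="./milvus.db",
    embedding_model_path="sentence-transformers/all-MiniLM-L12-v2",
    llm_base_url="http://127.0.0.1:8000/v1",
    enable_response_streaming=True,   # default in settings.py; shown for clarity
    top_k=3,
)

for chunk in chunks:                  # chunks arrive as the LLM generates them
    print(chunk, end="", flush=True)
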
15 changes: 13 additions & 2 deletions pragmatic/main.py
@@ -1,3 +1,8 @@
# addressing the issue where the project structure causes pragmatic to not be on the path
import sys
import os
sys.path.insert(0, os.path.abspath(os.getcwd()))

import argparse

from api import index_path_for_rag, execute_rag_query, evaluate_rag_pipeline
@@ -26,7 +31,7 @@ def main():

args = parser.parse_args()

if sum([args.indexing, args.rag, args.evaluation, args.server]) != 1:
if sum([args.indexing, args.rag, args.evaluation]) != 1:
print("Wrong usage: exactly one of the supported operation modes (indexing, query) must be specified.")
return

@@ -57,7 +62,13 @@ def main():
if args.query is None:
print("Please specify the query.")
return
print(execute_rag_query(args.query, **custom_settings))

# Execute query and get a generator
result_generator = execute_rag_query(args.query, **custom_settings)

# Process each chunk as it arrives
for chunk in result_generator:
print(chunk, end="", flush=True)

if args.evaluation:
print(evaluate_rag_pipeline(**custom_settings))
109 changes: 103 additions & 6 deletions pragmatic/pipelines/rag.py
@@ -11,6 +11,10 @@

from pragmatic.pipelines.pipeline import CommonPipelineWrapper

from threading import Thread, Event
from queue import Queue


BASE_RAG_PROMPT = """You are an assistant for question-answering tasks.

Here is the context to use to answer the question:
@@ -31,13 +35,73 @@
"""


class RagStreamHandler:
def __init__(self, settings):
self._timeout = settings.get("streaming_timeout", 60) # seconds
self._stream_queue = None
self._stream_thread = None
self._stop_event = Event()

def __del__(self):
"""
Finalizer for cleanup. Ensures threads and resources are released.
"""
self.stop_stream()

def _streaming_callback(self, chunk):
"""
Callback to be passed to the LLM, which places streamed tokens into the queue.
"""
if self._stream_queue:
self._stream_queue.put(chunk.content)

def _run_streaming_in_thread(self, run_callable):
"""
Invokes the provided callable (the pipeline's run method),
then signals the end of streaming by putting None in the queue.
"""
run_callable()
if self._stream_queue:
self._stream_queue.put(None)

def start_stream(self, run_callable):
"""
Initializes the queue and the background thread that executes `run_callable`.
"""
self._stream_queue = Queue()
self._stream_thread = Thread(target=self._run_streaming_in_thread, args=(run_callable,))
self._stream_thread.start()

def stop_stream(self):
"""
Gracefully stops the streaming thread if it is running.
"""
if self._stream_thread and self._stream_thread.is_alive():
self._stop_event.set() # Signal the thread to stop
self._stream_thread.join() # Wait for the thread to finish
self._stream_queue = None # Clean up the queue
self._stop_event.clear() # Reset the flag for future use

def stream_chunks(self):
"""
Yields streamed chunks from the queue.
Terminates when `None` is retrieved from the queue.
"""
try:
for chunk in iter(lambda: self._stream_queue.get(timeout=self._timeout), None):
yield chunk
finally:
self.stop_stream()


class RagPipelineWrapper(CommonPipelineWrapper):
def __init__(self, settings, query=None, evaluation_mode=False):
super().__init__(settings)
self._query = query

self._evaluation_mode = evaluation_mode

self._streaming_handler = RagStreamHandler(settings)

def _add_embedder(self, query):
embedder = SentenceTransformersTextEmbedder(model=self._settings["embedding_model_path"])
self._add_component("embedder", embedder, component_args={"text": query})
@@ -103,12 +167,17 @@ def _add_prompt_builder(self):
prompt_builder = PromptBuilder(template=BASE_RAG_PROMPT)
self._add_component("prompt_builder", prompt_builder, component_args={"query": self._query},
component_to_connect_point="prompt_builder.documents")

def _add_llm(self):
if "generator_object" in self._settings and self._settings["generator_object"] is not None:
# an object to use for communicating with the model was explicitly specified and we should use it
self._add_component("llm", self._settings["generator_object"])
return

custom_callback = (
self._streaming_handler._streaming_callback
if self._settings.get("enable_response_streaming") else None
)

llm = OpenAIGenerator(
api_key=self._settings["llm_api_key"],
@@ -118,6 +187,7 @@ def _add_llm(self):
max_retries=self._settings["llm_connection_max_retries"],
system_prompt=self._settings["llm_system_prompt"],
organization=self._settings["llm_organization"],
streaming_callback=custom_callback,
generation_kwargs={
"max_tokens": self._settings["llm_response_max_tokens"],
"temperature": self._settings["llm_temperature"],
@@ -164,16 +234,43 @@ def set_evaluation_mode(self, new_evaluation_mode):
self._rebuild_pipeline()

def run(self, query=None):
if query is not None:
# replace the query in the pipeline args
"""
Executes a query against the pipeline.

If response streaming is enabled, returns a generator that yields chunks of the response.
Otherwise, returns a string with the final response.

If evaluation mode is enabled, the response is retrieved from the answer builder instead of the standard pipeline output.

Parameters:
query (str, optional): The input query string to be processed.

Returns:
str | generator: A string for non-streaming mode or a generator for streaming mode.
"""
# If a query is provided, update the relevant keys in the pipeline arguments
if query:
for config_dict in self._args.values():
for key in ["text", "query"]:
if key in config_dict:
config_dict[key] = query

# Handle incompatible settings
if self._settings.get("stream", False) and self._evaluation_mode:
raise ValueError("Evaluation mode does not support streaming replies.")

# Handle streaming mode if enabled
if self._settings.get("enable_response_streaming", False):
self._streaming_handler.start_stream(lambda: super(RagPipelineWrapper, self).run())
return self._streaming_handler.stream_chunks()

# Otherwise, execute a normal pipeline run
result = super().run()

        # In evaluation mode, return the answer from the answer builder in string format
if self._evaluation_mode:
return result["answer_builder"]["answers"][0]
return result.get("answer_builder", {}).get("answers", [""])[0]

# Return the final LLM reply in string format
return result.get("llm", {}).get("replies", [""])[0]

return result["llm"]["replies"][0]
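
The streaming support above is a small producer/consumer bridge: the LLM's streaming callback pushes tokens into a Queue from a background thread running the pipeline, and stream_chunks() yields them until a None sentinel (or a timeout) is reached. Below is a self-contained sketch of the same pattern, independent of Haystack and not part of this diff; all names are illustrative.

# Standalone sketch of the queue-based streaming pattern used by RagStreamHandler.
from queue import Queue
from threading import Thread

def stream_tokens(run_pipeline, timeout=60):
    """Run `run_pipeline` in a background thread and yield tokens as they arrive."""
    queue = Queue()

    def worker():
        run_pipeline(lambda token: queue.put(token))  # producer: callback enqueues tokens
        queue.put(None)                               # sentinel marks the end of the stream

    Thread(target=worker, daemon=True).start()
    # consumer: keep yielding until the sentinel (or a timeout) is hit
    yield from iter(lambda: queue.get(timeout=timeout), None)

def fake_pipeline(callback):
    # stands in for the pipeline's run() driving the LLM streaming callback
    for token in ("Streaming ", "works", "!"):
        callback(token)

for chunk in stream_tokens(fake_pipeline):
    print(chunk, end="", flush=True)
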
5 changes: 4 additions & 1 deletion pragmatic/settings.py
@@ -6,7 +6,7 @@
# basic settings
"vector_db_type": "milvus",
"retriever_type": "dense",
"embedding_model_path": "./cache/finetuned_embedding-final",
"embedding_model_path": "sentence-transformers/all-MiniLM-L12-v2", #"./cache/finetuned_embedding-final",

# document conversion-related settings
"apply_docling": False,
@@ -56,6 +56,9 @@
"llm_http_client": None,
"generator_object": None, # an instance of a Haystack-compatible generator object

"enable_response_streaming": True,
    "streaming_timeout": 30,  # the stream handler falls back to 60 seconds if this key is not set

# advanced RAG options
"top_k": 1,
"cleaner_enabled": False,
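
The two new keys can also be overridden per call. Here is a hedged sketch of toggling them and handling both possible return shapes (a plain string when streaming is off, a generator when it is on); that keyword overrides reach the settings dict is an assumption based on main.py's **custom_settings, and the import path is assumed as well.

# Sketch only: per-call overrides of the new settings (import path assumed).
import types

from pragmatic.api import execute_rag_query

result = execute_rag_query(
    "What is OpenShift AI?",
    enable_response_streaming=False,  # request a plain string reply
    streaming_timeout=30,             # only consulted when streaming is enabled
)

if isinstance(result, types.GeneratorType):  # streaming enabled: iterate chunks
    for chunk in result:
        print(chunk, end="", flush=True)
else:                                        # streaming disabled: plain string
    print(result)
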
17 changes: 11 additions & 6 deletions test/sanity_test.py
@@ -42,6 +42,11 @@ def docling_convert(docs):
result.document.save_as_json(Path(base_output_path + ".json"))
print("Successfully converted the documents and saved as JSON in local directory specified \n")

def print_stream_data(result_generator):
# Process each chunk as it arrives
for chunk in result_generator:
print(chunk, end="", flush=True)


def main():
# ensure the documents directory exists before the script is executed
Expand All @@ -65,28 +70,28 @@ def main():
result1 = execute_rag_query("How to install OpenShift CLI on macOS?",
milvus_file_path="./milvus.db",
embedding_model_path="sentence-transformers/all-MiniLM-L12-v2",
llm_base_url="http://vllm-service:8000/v1",
llm_base_url="http://127.0.0.1:8000/v1",
top_k=3)
print("Response generated:")
print(f"\n{result1}")
print_stream_data(result1)
print("\n")
print("Question: What are the two deployment options in OpenShift AI?")
result2 = execute_rag_query("What are the two deployment options in OpenShift AI?",
milvus_file_path="./milvus.db",
embedding_model_path="sentence-transformers/all-MiniLM-L12-v2",
llm_base_url="http://vllm-service:8000/v1",
llm_base_url="http://127.0.0.1:8000/v1",
top_k=3)
print("Response generated:")
print(f"\n{result2}")
print_stream_data(result2)
print("\n")
print("Question: What is OpenShift AI?")
result3 = execute_rag_query("What is OpenShift AI?",
milvus_file_path="./milvus.db",
embedding_model_path="sentence-transformers/all-MiniLM-L12-v2",
llm_base_url="http://vllm-service:8000/v1",
llm_base_url="http://127.0.0.1:8000/v1",
top_k=3)
print("Response generated:")
print(f"\n{result3}")
print_stream_data(result3)


if __name__ == '__main__':