Merge pull request #35 from redhat-et/feature/add-streaming-to-llm
Added streaming functionality to the LLM.
ilya-kolchinsky authored Jan 31, 2025
2 parents a6133e9 + 9618a9f commit fdfc83d
Showing 5 changed files with 139 additions and 16 deletions.
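
The end-to-end effect of this PR: execute_rag_query can now return either a plain string (streaming disabled) or a generator of response chunks (streaming enabled). The sketch below, which is not part of the diff, shows how a caller might consume either form; it assumes the pragmatic.api import path and assumes keyword arguments are merged into the settings dict the same way main.py and sanity_test.py pass them.

# Minimal sketch (not part of the diff): consuming a possibly-streamed RAG response.
# Assumes an indexed Milvus store and a reachable LLM endpoint; the import path and
# the keyword-argument passthrough are assumptions based on how main.py and
# sanity_test.py call execute_rag_query.
from pragmatic.api import execute_rag_query

result = execute_rag_query(
    "How to install OpenShift CLI on macOS?",
    enable_response_streaming=True,  # setting introduced in this PR
    streaming_timeout=30,            # seconds to wait for the next chunk
)

if isinstance(result, str):
    # Streaming disabled: the full reply arrives as a single string.
    print(result)
else:
    # Streaming enabled: the reply arrives as a generator of chunks.
    for chunk in result:
        print(chunk, end="", flush=True)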
17 changes: 15 additions & 2 deletions pragmatic/main.py
@@ -1,3 +1,8 @@
# addressing the issue where the project structure causes pragmatic to not be on the path
import sys
import os
sys.path.insert(0, os.path.abspath(os.getcwd()))

import argparse

from api import index_path_for_rag, execute_rag_query, evaluate_rag_pipeline
@@ -26,7 +31,7 @@ def main():

args = parser.parse_args()

if sum([args.indexing, args.rag, args.evaluation, args.server]) != 1:
if sum([args.indexing, args.rag, args.evaluation]) != 1:
print("Wrong usage: exactly one of the supported operation modes (indexing, query) must be specified.")
return

@@ -57,7 +62,15 @@ def main():
if args.query is None:
print("Please specify the query.")
return
print(execute_rag_query(args.query, **custom_settings))

# Execute query and get response
result = execute_rag_query(args.query, **custom_settings)

if isinstance(result, str): # If a string, print directly
print(result)
else: # If a generator, process each chunk
for chunk in result:
print(chunk, end="", flush=True)

if args.evaluation:
print(evaluate_rag_pipeline(**custom_settings))
54 changes: 47 additions & 7 deletions pragmatic/pipelines/rag.py
@@ -10,6 +10,10 @@
from openai import OpenAI

from pragmatic.pipelines.pipeline import CommonPipelineWrapper
from pragmatic.pipelines.streaming import RagStreamHandler

from functools import cached_property


BASE_RAG_PROMPT = """You are an assistant for question-answering tasks.
@@ -30,14 +34,17 @@
Answer:
"""


class RagPipelineWrapper(CommonPipelineWrapper):
def __init__(self, settings, query=None, evaluation_mode=False):
super().__init__(settings)
self._query = query

self._evaluation_mode = evaluation_mode

@cached_property
def _streaming_handler(self):
"""Lazily creates and caches a RagStreamHandler on first access."""
return RagStreamHandler(self._settings)

def _add_embedder(self, query):
embedder = SentenceTransformersTextEmbedder(model=self._settings["embedding_model_path"])
self._add_component("embedder", embedder, component_args={"text": query})
@@ -103,12 +110,16 @@ def _add_prompt_builder(self):
prompt_builder = PromptBuilder(template=BASE_RAG_PROMPT)
self._add_component("prompt_builder", prompt_builder, component_args={"query": self._query},
component_to_connect_point="prompt_builder.documents")

def _add_llm(self):
if "generator_object" in self._settings and self._settings["generator_object"] is not None:
# an object to use for communicating with the model was explicitly specified and we should use it
self._add_component("llm", self._settings["generator_object"])
return

custom_callback = None
if self._settings.get("enable_response_streaming"):
custom_callback = self._streaming_handler._streaming_callback

llm = OpenAIGenerator(
api_key=self._settings["llm_api_key"],
@@ -118,6 +129,7 @@ def _add_llm(self):
max_retries=self._settings["llm_connection_max_retries"],
system_prompt=self._settings["llm_system_prompt"],
organization=self._settings["llm_organization"],
streaming_callback=custom_callback,
generation_kwargs={
"max_tokens": self._settings["llm_response_max_tokens"],
"temperature": self._settings["llm_temperature"],
@@ -164,16 +176,44 @@ def set_evaluation_mode(self, new_evaluation_mode):
self._rebuild_pipeline()

def run(self, query=None):
if query is not None:
# replace the query in the pipeline args
"""
Executes a query against the pipeline.
If enable_response_streaming is enabled, returns a generator that yields chunks of the response.
Otherwise, returns a string with the final response.
If evaluation mode is enabled, the response is retrieved from the answer builder instead of the standard pipeline output.
Parameters:
query (str, optional): The input query string to be processed.
Returns:
str | generator: A string when enable_response_streaming is False or a generator when enable_response_streaming is True.
"""
# If a query is provided, update the relevant keys in the pipeline arguments
if query:
for config_dict in self._args.values():
for key in ["text", "query"]:
if key in config_dict:
config_dict[key] = query

# Handle incompatible settings
if self._settings.get("enable_response_streaming", False) and self._evaluation_mode:
raise ValueError("Evaluation mode does not support streaming replies.")

# If streaming is enabled, use the streaming_handler property
if self._settings.get("enable_response_streaming", False):
self._streaming_handler.start_stream(lambda: super(RagPipelineWrapper, self).run())
return self._streaming_handler.stream_chunks()


# Otherwise, execute a normal pipeline run
result = super().run()

# In evaluation mode, return the answer from the answer builder in string format
if self._evaluation_mode:
return result["answer_builder"]["answers"][0]
return result.get("answer_builder", {}).get("answers", [""])[0]

# Return the final LLM reply in string format
return result.get("llm", {}).get("replies", [""])[0]

return result["llm"]["replies"][0]
60 changes: 60 additions & 0 deletions pragmatic/pipelines/streaming.py
@@ -0,0 +1,60 @@
from threading import Thread, Event
from queue import Queue

class RagStreamHandler:
    def __init__(self, settings):
        self._timeout = settings.get("streaming_timeout", 60)  # seconds
        self._stream_queue = None
        self._stream_thread = None
        self._stop_event = Event()

    def __del__(self):
        """
        Finalizer for cleanup. Ensures threads and resources are released.
        """
        self.stop_stream()

    def _streaming_callback(self, chunk):
        """
        Callback to be passed to the LLM, which places streamed tokens into the queue.
        """
        if self._stream_queue:
            self._stream_queue.put(chunk.content)

    def _run_streaming_in_thread(self, run_callable):
        """
        Invokes the provided callable (the pipeline's run method),
        then signals the end of streaming by putting None in the queue.
        """
        run_callable()
        if self._stream_queue:
            self._stream_queue.put(None)

    def start_stream(self, run_callable):
        """
        Initializes the queue and the background thread that executes `run_callable`.
        """
        self._stream_queue = Queue()
        self._stream_thread = Thread(target=self._run_streaming_in_thread, args=(run_callable,))
        self._stream_thread.start()

    def stop_stream(self):
        """
        Gracefully stops the streaming thread if it is running.
        """
        if self._stream_thread and self._stream_thread.is_alive():
            self._stop_event.set()  # Signal the thread to stop
            self._stream_thread.join()  # Wait for the thread to finish
            self._stream_queue = None  # Clean up the queue
            self._stop_event.clear()  # Reset the flag for future use

    def stream_chunks(self):
        """
        Yields streamed chunks from the queue.
        Terminates when `None` is retrieved from the queue.
        """
        try:
            for chunk in iter(lambda: self._stream_queue.get(timeout=self._timeout), None):
                yield chunk
        finally:
            self.stop_stream()
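
A quick way to see the handler in isolation, without a pipeline, is to feed its callback from a dummy run callable. The sketch below is illustrative only: SimpleNamespace stands in for a Haystack streaming chunk, which exposes a .content attribute, and the callback is wired up the same way rag.py passes it to OpenAIGenerator.

# Sketch: exercising RagStreamHandler without a real pipeline. SimpleNamespace stands
# in for a streaming chunk object exposing .content, as _streaming_callback expects.
from types import SimpleNamespace
from pragmatic.pipelines.streaming import RagStreamHandler

handler = RagStreamHandler({"streaming_timeout": 5})

def dummy_run():
    # Plays the role of the pipeline run: the LLM would call the callback per token.
    for token in ["stream", "ing ", "works"]:
        handler._streaming_callback(SimpleNamespace(content=token))

handler.start_stream(dummy_run)
for chunk in handler.stream_chunks():
    print(chunk, end="", flush=True)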
5 changes: 4 additions & 1 deletion pragmatic/settings.py
@@ -6,7 +6,7 @@
# basic settings
"vector_db_type": "milvus",
"retriever_type": "dense",
"embedding_model_path": "./cache/finetuned_embedding-final",
"embedding_model_path": "sentence-transformers/all-MiniLM-L12-v2", #"./cache/finetuned_embedding-final",

# document conversion-related settings
"apply_docling": False,
@@ -56,6 +56,9 @@
"llm_http_client": None,
"generator_object": None, # an instance of a Haystack-compatible generator object

"enable_response_streaming": True,
'streaming_timeout': 30, # Default timeout is 60 seconds if not specified

# advanced RAG options
"top_k": 1,
"cleaner_enabled": False,
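
Note that the new settings turn streaming on by default ("enable_response_streaming": True), so callers that previously received a plain string from execute_rag_query will now get a generator unless they opt out. A hedged sketch of opting out per call, again assuming per-call keyword arguments override the defaults in settings.py:

# Sketch: opting out of streaming for callers that expect the pre-PR string return.
# The import path and the kwargs-override behavior are assumptions, as in the
# earlier sketch.
from pragmatic.api import execute_rag_query

answer = execute_rag_query(
    "What is OpenShift AI?",
    enable_response_streaming=False,  # get the final reply as a single string
)
print(answer)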
19 changes: 13 additions & 6 deletions test/sanity_test.py
@@ -42,6 +42,13 @@ def docling_convert(docs):
result.document.save_as_json(Path(base_output_path + ".json"))
print("Successfully converted the documents and saved as JSON in local directory specified \n")

def print_stream_data(result):
if isinstance(result, str): # If a string, print directly
print(result)
else: # If a generator, process each chunk
for chunk in result:
print(chunk, end="", flush=True)


def main():
# ensure the documents directory exists before the script is executed
@@ -65,28 +72,28 @@ def main():
result1 = execute_rag_query("How to install OpenShift CLI on macOS?",
milvus_file_path="./milvus.db",
embedding_model_path="sentence-transformers/all-MiniLM-L12-v2",
llm_base_url="http://vllm-service:8000/v1",
llm_base_url="http://127.0.0.1:8000/v1",
top_k=3)
print("Response generated:")
print(f"\n{result1}")
print_stream_data(result1)
print("\n")
print("Question: What are the two deployment options in OpenShift AI?")
result2 = execute_rag_query("What are the two deployment options in OpenShift AI?",
milvus_file_path="./milvus.db",
embedding_model_path="sentence-transformers/all-MiniLM-L12-v2",
llm_base_url="http://vllm-service:8000/v1",
llm_base_url="http://127.0.0.1:8000/v1",
top_k=3)
print("Response generated:")
print(f"\n{result2}")
print_stream_data(result2)
print("\n")
print("Question: What is OpenShift AI?")
result3 = execute_rag_query("What is OpenShift AI?",
milvus_file_path="./milvus.db",
embedding_model_path="sentence-transformers/all-MiniLM-L12-v2",
llm_base_url="http://vllm-service:8000/v1",
llm_base_url="http://127.0.0.1:8000/v1",
top_k=3)
print("Response generated:")
print(f"\n{result3}")
print_stream_data(result3)


if __name__ == '__main__':
