Added streaming functionality to the LLM. #35

Merged · 8 commits · Jan 31, 2025
pragmatic/main.py: 16 changes (14 additions, 2 deletions)
@@ -1,3 +1,8 @@
# addressing the issue where the project structure causes pragmatic to not be on the path
import sys
import os
sys.path.insert(0, os.path.abspath(os.getcwd()))

import argparse

from api import index_path_for_rag, execute_rag_query, evaluate_rag_pipeline
@@ -26,7 +31,7 @@ def main():

args = parser.parse_args()

-if sum([args.indexing, args.rag, args.evaluation, args.server]) != 1:
+if sum([args.indexing, args.rag, args.evaluation]) != 1:
print("Wrong usage: exactly one of the supported operation modes (indexing, query) must be specified.")
return

@@ -57,7 +62,14 @@ def main():
if args.query is None:
print("Please specify the query.")
return
-print(execute_rag_query(args.query, **custom_settings))
+
+# Execute query and get a generator
+result_generator = execute_rag_query(args.query, **custom_settings)
+
+# Process each chunk as it arrives
+for chunk in result_generator:
+decoded_chunk = chunk.decode("utf-8")
+print(decoded_chunk, end="", flush=True)

if args.evaluation:
print(evaluate_rag_pipeline(**custom_settings))
pragmatic/pipelines/rag.py: 91 changes (88 additions, 3 deletions)
@@ -11,6 +11,12 @@

from pragmatic.pipelines.pipeline import CommonPipelineWrapper

from threading import Thread, Event
from queue import Queue


STREAM_TIMEOUT = 10 # seconds

BASE_RAG_PROMPT = """You are an assistant for question-answering tasks.

Here is the context to use to answer the question:
@@ -31,13 +37,73 @@
"""


class RagStreamHandler:
def __init__(self):
self._stream_queue = None
self._stream_thread = None
self._stop_event = Event()

def __del__(self):
"""
Finalizer for cleanup. Ensures threads and resources are released.
"""
self.stop_stream()

def _streaming_callback(self, chunk):
"""
Callback to be passed to the LLM, which places streamed tokens into the queue.
"""
if self._stream_queue:
self._stream_queue.put(chunk.content)

def _run_streaming_in_thread(self, run_callable):
"""
Invokes the provided callable (the pipeline's run method),
then signals the end of streaming by putting None in the queue.
"""
run_callable()
if self._stream_queue:
self._stream_queue.put(None)

def start_stream(self, run_callable):
"""
Initializes the queue and the background thread that executes `run_callable`.
"""
self._stream_queue = Queue()
self._stream_thread = Thread(target=self._run_streaming_in_thread, args=(run_callable,))
self._stream_thread.start()

def stop_stream(self):
"""
Gracefully stops the streaming thread if it is running.
"""
if self._stream_thread and self._stream_thread.is_alive():
self._stop_event.set() # Signal the thread to stop
self._stream_thread.join() # Wait for the thread to finish
self._stream_queue = None # Clean up the queue
self._stop_event.clear() # Reset the flag for future use

def stream_chunks(self):
"""
Yields streamed chunks from the queue.
Terminates when `None` is retrieved from the queue.
"""
try:
for chunk in iter(lambda: self._stream_queue.get(timeout=STREAM_TIMEOUT), None):
yield chunk.encode("utf-8")
finally:
self.stop_stream()
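
For reference, here is a minimal standalone sketch of how the handler is meant to be driven: start_stream() runs the LLM call in a background thread, the streaming callback feeds tokens into the queue, and stream_chunks() yields them until the None sentinel arrives. The FakeChunk class and fake_llm_call function below are hypothetical stand-ins for the real Haystack streaming objects, used only for illustration.

from dataclasses import dataclass
from pragmatic.pipelines.rag import RagStreamHandler  # import path assumed from this module's location

@dataclass
class FakeChunk:
    content: str  # stand-in for the chunk object the LLM passes to the streaming callback

handler = RagStreamHandler()

def fake_llm_call():
    # simulate an LLM emitting three tokens through the streaming callback
    for token in ("Hello", " ", "world"):
        handler._streaming_callback(FakeChunk(token))

handler.start_stream(fake_llm_call)        # runs fake_llm_call in a background thread
for chunk in handler.stream_chunks():      # yields utf-8 encoded bytes until the None sentinel
    print(chunk.decode("utf-8"), end="", flush=True)  # prints "Hello world"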


class RagPipelineWrapper(CommonPipelineWrapper):
def __init__(self, settings, query=None, evaluation_mode=False):
super().__init__(settings)
self._query = query

self._evaluation_mode = evaluation_mode

self.streaming_handler = RagStreamHandler()
self.llm = None

def _add_embedder(self, query):
embedder = SentenceTransformersTextEmbedder(model=self._settings["embedding_model_path"])
self._add_component("embedder", embedder, component_args={"text": query})
@@ -104,11 +170,21 @@ def _add_prompt_builder(self):
self._add_component("prompt_builder", prompt_builder, component_args={"query": self._query},
component_to_connect_point="prompt_builder.documents")


def stream_query(self):
"""
Runs the RAG pipeline in a background thread and returns a generator of streamed chunks.
"""
self.streaming_handler.start_stream(lambda: super(RagPipelineWrapper, self).run())
return self.streaming_handler.stream_chunks()

def _add_llm(self):
if "generator_object" in self._settings and self._settings["generator_object"] is not None:
# an object to use for communicating with the model was explicitly specified and we should use it
self._add_component("llm", self._settings["generator_object"])
return

custom_callback = (
self.streaming_handler._streaming_callback
if self._settings.get("stream") else None
)

llm = OpenAIGenerator(
api_key=self._settings["llm_api_key"],
@@ -118,6 +194,7 @@ def _add_llm(self):
max_retries=self._settings["llm_connection_max_retries"],
system_prompt=self._settings["llm_system_prompt"],
organization=self._settings["llm_organization"],
streaming_callback=custom_callback,
generation_kwargs={
"max_tokens": self._settings["llm_response_max_tokens"],
"temperature": self._settings["llm_temperature"],
@@ -170,10 +247,18 @@ def run(self, query=None):
for key in ["text", "query"]:
if key in config_dict:
config_dict[key] = query

# If streaming is enabled and we are not in evaluation mode, return a generator of streamed chunks
if self._settings.get("stream", False) and not self._evaluation_mode:
return self.stream_query()

# Otherwise, execute a normal pipeline run
result = super().run()

# Return answer builder output in evaluation mode
if self._evaluation_mode:
return result["answer_builder"]["answers"][0]
return result.get("answer_builder").get("answers")[0]

return result["llm"]["replies"][0]
# Return the final LLM reply as utf-8 encoded bytes wrapped in a generator
llm_reply = result.get("llm", {}).get("replies", [None])[0].strip()
kevincogan marked this conversation as resolved.
Show resolved Hide resolved
return (chunk.encode("utf-8") for chunk in [llm_reply]) # Wrap in a generator
kevincogan marked this conversation as resolved.
Show resolved Hide resolved
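
Note that both non-evaluation branches above return a generator of utf-8 encoded bytes, so callers consume the result the same way whether streaming is enabled or not. A rough caller-side sketch (the import path and the stream keyword override are assumptions, based on how main.py and the sanity test forward keyword arguments into the settings):

from pragmatic.api import execute_rag_query  # import path assumed

# iterate the result identically with stream=True or stream=False
result = execute_rag_query("What is OpenShift AI?", stream=True)
for chunk in result:  # generator of utf-8 encoded bytes in both modes
    print(chunk.decode("utf-8"), end="", flush=True)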
pragmatic/settings.py: 8 changes (5 additions, 3 deletions)
@@ -6,7 +6,7 @@
# basic settings
"vector_db_type": "milvus",
"retriever_type": "dense",
"embedding_model_path": "./cache/finetuned_embedding-final",
"embedding_model_path": "sentence-transformers/all-MiniLM-L12-v2", #"./cache/finetuned_embedding-final",

# document conversion-related settings
"apply_docling": False,
@@ -38,8 +38,8 @@
"converted_docling_document_format": "json",

# LLM-related settings
"llm": "mistralai/Mistral-7B-Instruct-v0.2",
"llm_base_url": "http://127.0.0.1:8000/v1",
"llm": "instructlab/granite-7b-lab-GGUF",
"llm_base_url": "http://vllm-service:8000/v1",
"llm_api_key": Secret.from_token("VLLM-PLACEHOLDER-API-KEY"), # use Secret.from_env_var("API_KEY_ENV_VAR_NAME") to enable authentication
"llm_connection_timeout": 30,
"llm_connection_max_retries": 3,
@@ -56,6 +56,8 @@
"llm_http_client": None,
"generator_object": None, # an instance of a Haystack-compatible generator object

"stream": True,

# advanced RAG options
"top_k": 1,
"cleaner_enabled": False,
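
Because "stream" now defaults to True, callers that want the previous blocking behaviour have to opt out per call, in the same way the sanity test overrides the model settings. A hedged example (the model name and URL are placeholders, and passing stream as a keyword override is assumed to work like the other settings):

from pragmatic.api import execute_rag_query  # import path assumed

result = execute_rag_query(
    "What is OpenShift AI?",
    llm="mistralai/Mistral-7B-Instruct-v0.2",  # placeholder model name
    llm_base_url="http://127.0.0.1:8000/v1",   # placeholder endpoint
    stream=False,                              # opt out of the new streaming default
)
# even with stream=False the reply arrives as a single-item generator of bytes
print(b"".join(result).decode("utf-8"))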
test/sanity_test.py: 12 changes (9 additions, 3 deletions)
@@ -42,6 +42,12 @@ def docling_convert(docs):
result.document.save_as_json(Path(base_output_path + ".json"))
print("Successfully converted the documents and saved as JSON in local directory specified \n")

def print_stream_data(result_generator):
# Process each chunk as it arrives
for chunk in result_generator:
decoded_chunk = chunk.decode("utf-8")
print(decoded_chunk, end="", flush=True)


def main():
# ensure the documents directory exists before the script is executed
@@ -68,7 +74,7 @@ def main():
llm_base_url="http://vllm-service:8000/v1",
top_k=3)
print("Response generated:")
print(f"\n{result1}")
print_stream_data(result1)
print("\n")
print("Question: What are the two deployment options in OpenShift AI?")
result2 = execute_rag_query("What are the two deployment options in OpenShift AI?",
@@ -77,7 +83,7 @@
llm_base_url="http://vllm-service:8000/v1",
top_k=3)
print("Response generated:")
print(f"\n{result2}")
print_stream_data(result2)
print("\n")
print("Question: What is OpenShift AI?")
result3 = execute_rag_query("What is OpenShift AI?",
@@ -86,7 +92,7 @@
llm_base_url="http://vllm-service:8000/v1",
top_k=3)
print("Response generated:")
print(f"\n{result3}")
print_stream_data(result3)


if __name__ == '__main__':