From ff1ff7c78b34204087928d2109a87df6a8622b35 Mon Sep 17 00:00:00 2001
From: ramnes
Date: Thu, 17 Aug 2023 15:12:42 +0200
Subject: [PATCH 1/6] Make sure we always use a running loop

---
 src/chainlit/context.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/chainlit/context.py b/src/chainlit/context.py
index 91d8e5daa1..0c40990226 100644
--- a/src/chainlit/context.py
+++ b/src/chainlit/context.py
@@ -23,7 +23,7 @@ class ChainlitContext:
     def __init__(self, session: Session):
         from chainlit.emitter import ChainlitEmitter
 
-        self.loop = asyncio.get_event_loop()
+        self.loop = asyncio.get_running_loop()
         self.session = session
         self.emitter = ChainlitEmitter(session)

From 38fb0f794eb166af71de93b43f195d01262656cb Mon Sep 17 00:00:00 2001
From: ramnes
Date: Thu, 17 Aug 2023 16:17:26 +0200
Subject: [PATCH 2/6] llama_index v0.8.3 now returns a ChatResponse

---
 cypress/e2e/llama_index_cb/main.py    | 4 +++-
 src/chainlit/llama_index/__init__.py  | 4 ++--
 src/chainlit/llama_index/callbacks.py | 6 +++++-
 src/pyproject.toml                    | 4 ++--
 4 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/cypress/e2e/llama_index_cb/main.py b/cypress/e2e/llama_index_cb/main.py
index 23af1ccca8..f6bd520d04 100644
--- a/cypress/e2e/llama_index_cb/main.py
+++ b/cypress/e2e/llama_index_cb/main.py
@@ -1,4 +1,5 @@
 from llama_index.callbacks.schema import CBEventType, EventPayload
+from llama_index.llms.base import ChatMessage, ChatResponse
 from llama_index.schema import NodeWithScore, TextNode
 
 import chainlit as cl
@@ -25,10 +26,11 @@ async def start():
 
     cb.on_event_start(CBEventType.LLM)
 
+    response = ChatResponse(message=ChatMessage(content="This is the LLM response"))
     cb.on_event_end(
         CBEventType.LLM,
         payload={
-            EventPayload.RESPONSE: "This is the LLM response",
+            EventPayload.RESPONSE: response,
             EventPayload.PROMPT: "This is the LLM prompt",
         },
     )
diff --git a/src/chainlit/llama_index/__init__.py b/src/chainlit/llama_index/__init__.py
index d140c242e0..11210c4f6f 100644
--- a/src/chainlit/llama_index/__init__.py
+++ b/src/chainlit/llama_index/__init__.py
@@ -1,9 +1,9 @@
 try:
     import llama_index
 
-    if llama_index.__version__ < "0.6.27":
+    if llama_index.__version__ < "0.8.3":
         raise ValueError(
-            "LlamaIndex version is too old, expected >= 0.6.27. Run `pip install llama_index --upgrade`"
+            "LlamaIndex version is too old, expected >= 0.8.3. Run `pip install llama_index --upgrade`"
         )
 
     LLAMA_INDEX_INSTALLED = True
diff --git a/src/chainlit/llama_index/callbacks.py b/src/chainlit/llama_index/callbacks.py
index dd11447699..8c2365823e 100644
--- a/src/chainlit/llama_index/callbacks.py
+++ b/src/chainlit/llama_index/callbacks.py
@@ -2,6 +2,7 @@
 
 from llama_index.callbacks.base import BaseCallbackHandler
 from llama_index.callbacks.schema import CBEventType, EventPayload
+from llama_index.llms.base import ChatResponse
 
 from chainlit.context import context
 from chainlit.element import Text
@@ -85,9 +86,12 @@ def on_event_end(
             )
 
         if event_type == CBEventType.LLM:
+            response = payload.get(EventPayload.RESPONSE)
+            content = response.message.content if response else ""
+
             run_sync(
                 Message(
-                    content=payload.get(EventPayload.RESPONSE, ""),
+                    content=content,
                     author=event_type,
                     parent_id=parent_id,
                     prompt=payload.get(EventPayload.PROMPT),
diff --git a/src/pyproject.toml b/src/pyproject.toml
index 55a35d2bc7..224b2e815e 100644
--- a/src/pyproject.toml
+++ b/src/pyproject.toml
@@ -46,8 +46,8 @@ lazify = "^0.4.0"
 optional = true
 
 [tool.poetry.group.tests.dependencies]
-langchain = "^0.0.229"
-llama_index = "^0.7.4"
+langchain = "^0.0.262"
+llama_index = "^0.8.3"
 transformers = "^4.30.1"
 responses = "0.23.1"
 aioresponses = "0.7.4"
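Patch 1 replaces `asyncio.get_event_loop()` with `asyncio.get_running_loop()`, which guarantees a `ChainlitContext` can only be created while a loop is actually running. A minimal standalone sketch of the behavioral difference (illustration only, not part of the patch series):

```python
import asyncio

# With no loop running, get_running_loop() fails fast with a RuntimeError,
# whereas get_event_loop() may silently hand back a brand-new loop that
# nothing ever drives (exact behavior varies across Python versions).
try:
    asyncio.get_running_loop()
except RuntimeError as exc:
    print(exc)  # no running event loop

async def main():
    # Inside a coroutine, both calls resolve to the same running loop.
    loop = asyncio.get_running_loop()
    print(loop is asyncio.get_event_loop())  # True

asyncio.run(main())
```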
From 113e51c50fa1a7d1e59e6739d130947ede804a8e Mon Sep 17 00:00:00 2001
From: ramnes
Date: Thu, 17 Aug 2023 15:14:27 +0200
Subject: [PATCH 3/6] Restore the main thread context to the current llama_index thread

---
 src/chainlit/llama_index/callbacks.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/src/chainlit/llama_index/callbacks.py b/src/chainlit/llama_index/callbacks.py
index 8c2365823e..1eac9bba8e 100644
--- a/src/chainlit/llama_index/callbacks.py
+++ b/src/chainlit/llama_index/callbacks.py
@@ -4,7 +4,7 @@
 from llama_index.callbacks.schema import CBEventType, EventPayload
 from llama_index.llms.base import ChatResponse
 
-from chainlit.context import context
+from chainlit.context import context_var
 from chainlit.element import Text
 from chainlit.message import Message
 from chainlit.sync import run_sync
@@ -31,6 +31,7 @@ def __init__(
         event_ends_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
     ) -> None:
         """Initialize the base callback handler."""
+        self.context = context_var.get()
         self.event_starts_to_ignore = tuple(event_starts_to_ignore)
         self.event_ends_to_ignore = tuple(event_ends_to_ignore)
 
@@ -42,6 +43,8 @@ def on_event_start(
         **kwargs: Any,
     ) -> str:
         """Run when an event starts and return id of event."""
+        context_var.set(self.context)
+
         run_sync(
             Message(
                 author=event_type,
@@ -62,6 +65,16 @@ def on_event_end(
         if payload is None:
             return
 
+        # Chainlit context is local to the main thread, and LlamaIndex
+        # runs the callbacks in its own threads, so they don't have a
+        # Chainlit context by default.
+        #
+        # This line restores the context in which the callback handler
+        # has been created (it's always created in the main thread)
+        # before running the rest of the method, so that we can
+        # actually send messages.
+        context_var.set(self.context)
+
         parent_id = self.root_message.id if self.root_message else None
 
         if event_type == CBEventType.RETRIEVE:
@@ -100,7 +113,7 @@ def on_event_end(
 
     def start_trace(self, trace_id: Optional[str] = None) -> None:
         """Run when an overall trace is launched."""
-        self.root_message = context.session.root_message
+        self.root_message = self.context.session.root_message
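The fix in patch 3 leans on `contextvars` semantics: a `ContextVar` set in the main thread is invisible from a freshly spawned thread, so the handler snapshots the context at construction time (`self.context = context_var.get()`) and re-sets it inside each callback. A standalone sketch of that pattern, where `context_var` is a hypothetical stand-in for `chainlit.context.context_var`:

```python
import threading
from contextvars import ContextVar

# Hypothetical stand-in for chainlit.context.context_var.
context_var: ContextVar[str] = ContextVar("context")

context_var.set("main-thread-context")
saved = context_var.get()  # captured in the main thread, like self.context

def callback_in_worker_thread():
    # A new thread starts with an empty context: calling context_var.get()
    # here would raise LookupError before the restore below.
    context_var.set(saved)  # what the handler does before sending messages
    print(context_var.get())  # main-thread-context

t = threading.Thread(target=callback_in_worker_thread)
t.start()
t.join()
```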
From a1dd7734acd276728606d2ec25679af9a66e5cf5 Mon Sep 17 00:00:00 2001
From: ramnes
Date: Thu, 17 Aug 2023 18:16:10 +0200
Subject: [PATCH 4/6] Don't use methods we don't need

---
 src/chainlit/llama_index/callbacks.py | 21 ++++++---------------
 1 file changed, 6 insertions(+), 15 deletions(-)

diff --git a/src/chainlit/llama_index/callbacks.py b/src/chainlit/llama_index/callbacks.py
index 1eac9bba8e..cec25d6cea 100644
--- a/src/chainlit/llama_index/callbacks.py
+++ b/src/chainlit/llama_index/callbacks.py
@@ -22,9 +22,6 @@
 class LlamaIndexCallbackHandler(BaseCallbackHandler):
     """Base callback handler that can be used to track event starts and ends."""
 
-    # Message at the root of the chat we should attach child messages to
-    root_message: Optional[Message] = None
-
     def __init__(
         self,
         event_starts_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
@@ -75,7 +72,8 @@ def on_event_end(
         # actually send messages.
         context_var.set(self.context)
 
-        parent_id = self.root_message.id if self.root_message else None
+        root_message = self.context.session.root_message
+        parent_id = root_message.id if root_message else None
 
         if event_type == CBEventType.RETRIEVE:
             sources = payload.get(EventPayload.NODES)
@@ -111,15 +109,8 @@ def on_event_end(
             ).send()
         )
 
-    def start_trace(self, trace_id: Optional[str] = None) -> None:
-        """Run when an overall trace is launched."""
-        self.root_message = self.context.session.root_message
-
-    def end_trace(
-        self,
-        trace_id: Optional[str] = None,
-        trace_map: Optional[Dict[str, List[str]]] = None,
-    ) -> None:
-        """Run when an overall trace is exited."""
+    def _noop(self, *args, **kwargs):
+        pass
 
-        self.root_message = None
+    start_trace = _noop
+    end_trace = _noop

From 8c6f3d4730adab0b583bda1bb7a06f44afe22e37 Mon Sep 17 00:00:00 2001
From: ramnes
Date: Fri, 18 Aug 2023 10:15:54 +0200
Subject: [PATCH 5/6] Sending a Message in on_event_start creates a deadlock

---
 src/chainlit/llama_index/callbacks.py | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/src/chainlit/llama_index/callbacks.py b/src/chainlit/llama_index/callbacks.py
index cec25d6cea..f56802b4d2 100644
--- a/src/chainlit/llama_index/callbacks.py
+++ b/src/chainlit/llama_index/callbacks.py
@@ -32,24 +32,19 @@ def __init__(
         self.event_starts_to_ignore = tuple(event_starts_to_ignore)
         self.event_ends_to_ignore = tuple(event_ends_to_ignore)
 
-    def on_event_start(
-        self,
-        event_type: CBEventType,
-        payload: Optional[Dict[str, Any]] = None,
-        event_id: str = "",
-        **kwargs: Any,
-    ) -> str:
+    def start_trace(self, trace_id: Optional[str] = None) -> None:
         """Run when an event starts and return id of event."""
         context_var.set(self.context)
+        root_message = self.context.session.root_message
+        parent_id = root_message.id if root_message else None
 
         run_sync(
             Message(
-                author=event_type,
-                indent=1,
+                author=trace_id or "llama_index",
+                parent_id=parent_id,
                 content="",
             ).send()
         )
-        return ""
 
     def on_event_end(
         self,
@@ -112,5 +107,5 @@ def on_event_end(
     def _noop(self, *args, **kwargs):
         pass
 
-    start_trace = _noop
+    on_event_start = _noop
     end_trace = _noop
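Patch 5's subject is terse, so here is one plausible reconstruction of the failure mode (a sketch, not Chainlit's actual call path; this `run_sync` is a simplified stand-in for `chainlit.sync.run_sync`): blocking on a coroutine's result from the event loop's own thread can never complete, because the blocked thread is the only one that could run the coroutine.

```python
import asyncio
import threading

def run_sync(coro, loop):
    # Simplified stand-in: schedule coro on `loop`, block until it finishes.
    return asyncio.run_coroutine_threadsafe(coro, loop).result()

async def send_message():
    return "sent"

async def main():
    loop = asyncio.get_running_loop()

    # Safe: a worker thread blocks on result() while the loop thread stays
    # free to actually execute send_message().
    worker = threading.Thread(target=lambda: print(run_sync(send_message(), loop)))
    worker.start()
    while worker.is_alive():       # calling worker.join() here would block
        await asyncio.sleep(0.01)  # the loop thread too

    # Deadlock: calling run_sync() from the loop thread itself blocks the
    # only thread able to run the coroutine, so result() never returns.
    # run_sync(send_message(), loop)  # never returns -- don't do this

asyncio.run(main())
```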
From e1a797fd889160ef6a428165425c9ffe7f62929f Mon Sep 17 00:00:00 2001
From: ramnes
Date: Fri, 18 Aug 2023 10:21:23 +0200
Subject: [PATCH 6/6] Don't repeat yourself

---
 src/chainlit/llama_index/callbacks.py | 43 +++++++++++++++------------
 1 file changed, 24 insertions(+), 19 deletions(-)

diff --git a/src/chainlit/llama_index/callbacks.py b/src/chainlit/llama_index/callbacks.py
index f56802b4d2..8256ba8170 100644
--- a/src/chainlit/llama_index/callbacks.py
+++ b/src/chainlit/llama_index/callbacks.py
@@ -32,16 +32,32 @@ def __init__(
         self.event_starts_to_ignore = tuple(event_starts_to_ignore)
         self.event_ends_to_ignore = tuple(event_ends_to_ignore)
 
-    def start_trace(self, trace_id: Optional[str] = None) -> None:
-        """Run when an event starts and return id of event."""
+    def _restore_context(self) -> None:
+        """Restore Chainlit context in the current thread
+
+        Chainlit context is local to the main thread, and LlamaIndex
+        runs the callbacks in its own threads, so they don't have a
+        Chainlit context by default.
+
+        This method restores the context in which the callback handler
+        has been created (it's always created in the main thread), so
+        that we can actually send messages.
+        """
         context_var.set(self.context)
 
-        root_message = self.context.session.root_message
-        parent_id = root_message.id if root_message else None
+    def _get_parent_id(self) -> Optional[str]:
+        """Get the parent message id"""
+        if root_message := self.context.session.root_message:
+            return root_message.id
+        return None
+
+    def start_trace(self, trace_id: Optional[str] = None) -> None:
+        """Run when an event starts and return id of event."""
+        self._restore_context()
 
         run_sync(
             Message(
                 author=trace_id or "llama_index",
-                parent_id=parent_id,
+                parent_id=self._get_parent_id(),
                 content="",
             ).send()
         )
@@ -57,18 +73,7 @@ def on_event_end(
         if payload is None:
             return
 
-        # Chainlit context is local to the main thread, and LlamaIndex
-        # runs the callbacks in its own threads, so they don't have a
-        # Chainlit context by default.
-        #
-        # This line restores the context in which the callback handler
-        # has been created (it's always created in the main thread)
-        # before running the rest of the method, so that we can
-        # actually send messages.
-        context_var.set(self.context)
-
-        root_message = self.context.session.root_message
-        parent_id = root_message.id if root_message else None
+        self._restore_context()
 
         if event_type == CBEventType.RETRIEVE:
             sources = payload.get(EventPayload.NODES)
@@ -87,7 +92,7 @@ def on_event_end(
                     content=content,
                     author=event_type,
                     elements=elements,
-                    parent_id=parent_id,
+                    parent_id=self._get_parent_id(),
                 ).send()
             )
 
@@ -99,7 +104,7 @@ def on_event_end(
                 Message(
                     content=content,
                     author=event_type,
-                    parent_id=parent_id,
+                    parent_id=self._get_parent_id(),
                     prompt=payload.get(EventPayload.PROMPT),
                 ).send()
             )
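For context, this is roughly how the finished handler gets wired into an application. A sketch under assumptions: it presumes Chainlit exports the handler as `cl.LlamaIndexCallbackHandler` (as the cypress test above suggests) and uses the llama_index 0.8.x import paths from this series; later releases moved these around, and the index still needs a configured LLM and embedding backend to actually run.

```python
import chainlit as cl
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.callbacks import CallbackManager
from llama_index.schema import TextNode

@cl.on_chat_start
async def start():
    # Route every LlamaIndex event through the Chainlit handler so that
    # retrieval and LLM steps show up as messages in the chat UI.
    callback_manager = CallbackManager([cl.LlamaIndexCallbackHandler()])
    service_context = ServiceContext.from_defaults(callback_manager=callback_manager)

    index = VectorStoreIndex(
        [TextNode(text="Chainlit integrates with LlamaIndex callbacks.")],
        service_context=service_context,
    )
    query_engine = index.as_query_engine()

    # query() is synchronous; make_async runs it in a worker thread, which
    # is exactly the situation the context-restoring callbacks handle.
    response = await cl.make_async(query_engine.query)("What does Chainlit do?")
    await cl.Message(content=str(response)).send()
```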