diff --git a/cypress/e2e/llama_index_cb/main.py b/cypress/e2e/llama_index_cb/main.py
index 23af1ccca8..f6bd520d04 100644
--- a/cypress/e2e/llama_index_cb/main.py
+++ b/cypress/e2e/llama_index_cb/main.py
@@ -1,4 +1,5 @@
 from llama_index.callbacks.schema import CBEventType, EventPayload
+from llama_index.llms.base import ChatMessage, ChatResponse
 from llama_index.schema import NodeWithScore, TextNode
 
 import chainlit as cl
@@ -25,10 +26,11 @@ async def start():
 
     cb.on_event_start(CBEventType.LLM)
 
+    response = ChatResponse(message=ChatMessage(content="This is the LLM response"))
     cb.on_event_end(
         CBEventType.LLM,
         payload={
-            EventPayload.RESPONSE: "This is the LLM response",
+            EventPayload.RESPONSE: response,
            EventPayload.PROMPT: "This is the LLM prompt",
         },
     )
diff --git a/src/chainlit/context.py b/src/chainlit/context.py
index 91d8e5daa1..0c40990226 100644
--- a/src/chainlit/context.py
+++ b/src/chainlit/context.py
@@ -23,7 +23,7 @@ class ChainlitContext:
     def __init__(self, session: Session):
         from chainlit.emitter import ChainlitEmitter
 
-        self.loop = asyncio.get_event_loop()
+        self.loop = asyncio.get_running_loop()
         self.session = session
         self.emitter = ChainlitEmitter(session)
 
diff --git a/src/chainlit/llama_index/__init__.py b/src/chainlit/llama_index/__init__.py
index d140c242e0..11210c4f6f 100644
--- a/src/chainlit/llama_index/__init__.py
+++ b/src/chainlit/llama_index/__init__.py
@@ -1,9 +1,9 @@
 try:
     import llama_index
 
-    if llama_index.__version__ < "0.6.27":
+    if llama_index.__version__ < "0.8.3":
         raise ValueError(
-            "LlamaIndex version is too old, expected >= 0.6.27. Run `pip install llama_index --upgrade`"
+            "LlamaIndex version is too old, expected >= 0.8.3. Run `pip install llama_index --upgrade`"
         )
 
     LLAMA_INDEX_INSTALLED = True
diff --git a/src/chainlit/llama_index/callbacks.py b/src/chainlit/llama_index/callbacks.py
index dd11447699..8256ba8170 100644
--- a/src/chainlit/llama_index/callbacks.py
+++ b/src/chainlit/llama_index/callbacks.py
@@ -2,8 +2,9 @@
 
 from llama_index.callbacks.base import BaseCallbackHandler
 from llama_index.callbacks.schema import CBEventType, EventPayload
+from llama_index.llms.base import ChatResponse
 
-from chainlit.context import context
+from chainlit.context import context_var
 from chainlit.element import Text
 from chainlit.message import Message
 from chainlit.sync import run_sync
@@ -21,34 +22,45 @@
 class LlamaIndexCallbackHandler(BaseCallbackHandler):
     """Base callback handler that can be used to track event starts and ends."""
 
-    # Message at the root of the chat we should attach child messages to
-    root_message: Optional[Message] = None
-
     def __init__(
         self,
         event_starts_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
         event_ends_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
     ) -> None:
         """Initialize the base callback handler."""
+        self.context = context_var.get()
         self.event_starts_to_ignore = tuple(event_starts_to_ignore)
         self.event_ends_to_ignore = tuple(event_ends_to_ignore)
 
-    def on_event_start(
-        self,
-        event_type: CBEventType,
-        payload: Optional[Dict[str, Any]] = None,
-        event_id: str = "",
-        **kwargs: Any,
-    ) -> str:
+    def _restore_context(self) -> None:
+        """Restore Chainlit context in the current thread
+
+        Chainlit context is local to the main thread, and LlamaIndex
+        runs the callbacks in its own threads, so they don't have a
+        Chainlit context by default.
+
+        This method restores the context in which the callback handler
+        has been created (it's always created in the main thread), so
+        that we can actually send messages.
+        """
+        context_var.set(self.context)
+
+    def _get_parent_id(self) -> Optional[str]:
+        """Get the parent message id"""
+        if root_message := self.context.session.root_message:
+            return root_message.id
+        return None
+
+    def start_trace(self, trace_id: Optional[str] = None) -> None:
         """Run when an event starts and return id of event."""
+        self._restore_context()
         run_sync(
             Message(
-                author=event_type,
-                indent=1,
+                author=trace_id or "llama_index",
+                parent_id=self._get_parent_id(),
                 content="",
             ).send()
         )
-        return ""
 
     def on_event_end(
         self,
@@ -61,7 +73,7 @@
         if payload is None:
             return
 
-        parent_id = self.root_message.id if self.root_message else None
+        self._restore_context()
 
         if event_type == CBEventType.RETRIEVE:
             sources = payload.get(EventPayload.NODES)
@@ -80,29 +92,25 @@
                     content=content,
                     author=event_type,
                     elements=elements,
-                    parent_id=parent_id,
+                    parent_id=self._get_parent_id(),
                 ).send()
             )
 
         if event_type == CBEventType.LLM:
+            response = payload.get(EventPayload.RESPONSE)
+            content = response.message.content if response else ""
+
             run_sync(
                 Message(
-                    content=payload.get(EventPayload.RESPONSE, ""),
+                    content=content,
                     author=event_type,
-                    parent_id=parent_id,
+                    parent_id=self._get_parent_id(),
                     prompt=payload.get(EventPayload.PROMPT),
                 ).send()
             )
 
-    def start_trace(self, trace_id: Optional[str] = None) -> None:
-        """Run when an overall trace is launched."""
-        self.root_message = context.session.root_message
-
-    def end_trace(
-        self,
-        trace_id: Optional[str] = None,
-        trace_map: Optional[Dict[str, List[str]]] = None,
-    ) -> None:
-        """Run when an overall trace is exited."""
+    def _noop(self, *args, **kwargs):
+        pass
 
-        self.root_message = None
+    on_event_start = _noop
+    end_trace = _noop
diff --git a/src/pyproject.toml b/src/pyproject.toml
index 55a35d2bc7..224b2e815e 100644
--- a/src/pyproject.toml
+++ b/src/pyproject.toml
@@ -46,8 +46,8 @@ lazify = "^0.4.0"
 optional = true
 
 [tool.poetry.group.tests.dependencies]
-langchain = "^0.0.229"
-llama_index = "^0.7.4"
+langchain = "^0.0.262"
+llama_index = "^0.8.3"
 transformers = "^4.30.1"
 responses = "0.23.1"
 aioresponses = "0.7.4"