Merge pull request Chainlit#290 from Chainlit/gg/fix-llama-index

Fix LlamaIndexCallbackHandler
ramnes authored Aug 18, 2023
2 parents 211bc27 + e1a797f commit fb8853a
Showing 5 changed files with 45 additions and 35 deletions.
4 changes: 3 additions & 1 deletion cypress/e2e/llama_index_cb/main.py
@@ -1,4 +1,5 @@
 from llama_index.callbacks.schema import CBEventType, EventPayload
+from llama_index.llms.base import ChatMessage, ChatResponse
 from llama_index.schema import NodeWithScore, TextNode

 import chainlit as cl
@@ -25,10 +26,11 @@ async def start():

     cb.on_event_start(CBEventType.LLM)

+    response = ChatResponse(message=ChatMessage(content="This is the LLM response"))
     cb.on_event_end(
         CBEventType.LLM,
         payload={
-            EventPayload.RESPONSE: "This is the LLM response",
+            EventPayload.RESPONSE: response,
             EventPayload.PROMPT: "This is the LLM prompt",
         },
     )
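Context for this test change: from llama_index 0.8 on, the LLM event payload carries a ChatResponse object instead of a raw string, so the test now builds one explicitly. A minimal sketch of how the handler side unpacks it, using only names visible in this diff:

from llama_index.llms.base import ChatMessage, ChatResponse

# Build a payload value the way the updated test does
response = ChatResponse(message=ChatMessage(content="This is the LLM response"))

# The handler reads the text off the wrapped message, guarding against a missing payload
content = response.message.content if response else ""
assert content == "This is the LLM response"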
2 changes: 1 addition & 1 deletion src/chainlit/context.py
@@ -23,7 +23,7 @@ class ChainlitContext:
     def __init__(self, session: Session):
         from chainlit.emitter import ChainlitEmitter

-        self.loop = asyncio.get_event_loop()
+        self.loop = asyncio.get_running_loop()
         self.session = session
         self.emitter = ChainlitEmitter(session)

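The one-line change above is behavior-relevant: asyncio.get_event_loop() silently creates a new loop when called with none running (and is deprecated for that use in recent Python), while asyncio.get_running_loop() raises RuntimeError outside a running loop, so the context can only be built where a loop actually exists. A quick standalone illustration:

import asyncio

async def main():
    # Inside a coroutine a loop is guaranteed to be running
    loop = asyncio.get_running_loop()
    print(loop.is_running())  # True

asyncio.run(main())

# At module level, with no loop running, this would raise RuntimeError:
# asyncio.get_running_loop()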
4 changes: 2 additions & 2 deletions src/chainlit/llama_index/__init__.py
@@ -1,9 +1,9 @@
 try:
     import llama_index

-    if llama_index.__version__ < "0.6.27":
+    if llama_index.__version__ < "0.8.3":
         raise ValueError(
-            "LlamaIndex version is too old, expected >= 0.6.27. Run `pip install llama_index --upgrade`"
+            "LlamaIndex version is too old, expected >= 0.8.3. Run `pip install llama_index --upgrade`"
         )

     LLAMA_INDEX_INSTALLED = True
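One caveat with the check above: comparing version strings lexicographically works for these particular versions but misorders double-digit components (as strings, "0.10.0" < "0.8.3"). A more robust variant, as a sketch assuming the third-party packaging library is available:

import llama_index
from packaging import version

# version.parse compares release components numerically, so "0.10.0" > "0.8.3"
if version.parse(llama_index.__version__) < version.parse("0.8.3"):
    raise ValueError(
        "LlamaIndex version is too old, expected >= 0.8.3. Run `pip install llama_index --upgrade`"
    )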
66 changes: 37 additions & 29 deletions src/chainlit/llama_index/callbacks.py
@@ -2,8 +2,9 @@

 from llama_index.callbacks.base import BaseCallbackHandler
 from llama_index.callbacks.schema import CBEventType, EventPayload
+from llama_index.llms.base import ChatResponse

-from chainlit.context import context
+from chainlit.context import context_var
 from chainlit.element import Text
 from chainlit.message import Message
 from chainlit.sync import run_sync
@@ -21,34 +22,45 @@
 class LlamaIndexCallbackHandler(BaseCallbackHandler):
     """Base callback handler that can be used to track event starts and ends."""

-    # Message at the root of the chat we should attach child messages to
-    root_message: Optional[Message] = None
-
     def __init__(
         self,
         event_starts_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
         event_ends_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
     ) -> None:
         """Initialize the base callback handler."""
+        self.context = context_var.get()
         self.event_starts_to_ignore = tuple(event_starts_to_ignore)
         self.event_ends_to_ignore = tuple(event_ends_to_ignore)

-    def on_event_start(
-        self,
-        event_type: CBEventType,
-        payload: Optional[Dict[str, Any]] = None,
-        event_id: str = "",
-        **kwargs: Any,
-    ) -> str:
+    def _restore_context(self) -> None:
+        """Restore Chainlit context in the current thread
+
+        Chainlit context is local to the main thread, and LlamaIndex
+        runs the callbacks in its own threads, so they don't have a
+        Chainlit context by default.
+
+        This method restores the context in which the callback handler
+        has been created (it's always created in the main thread), so
+        that we can actually send messages.
+        """
+        context_var.set(self.context)
+
+    def _get_parent_id(self) -> Optional[str]:
+        """Get the parent message id"""
+        if root_message := self.context.session.root_message:
+            return root_message.id
+        return None
+
+    def start_trace(self, trace_id: Optional[str] = None) -> None:
         """Run when an event starts and return id of event."""
+        self._restore_context()
         run_sync(
             Message(
-                author=event_type,
-                indent=1,
+                author=trace_id or "llama_index",
+                parent_id=self._get_parent_id(),
                 content="",
             ).send()
         )
-        return ""

     def on_event_end(
         self,
@@ -61,7 +73,7 @@ def on_event_end(
         if payload is None:
             return

-        parent_id = self.root_message.id if self.root_message else None
+        self._restore_context()

         if event_type == CBEventType.RETRIEVE:
             sources = payload.get(EventPayload.NODES)
@@ -80,29 +92,25 @@ def on_event_end(
                         content=content,
                         author=event_type,
                         elements=elements,
-                        parent_id=parent_id,
+                        parent_id=self._get_parent_id(),
                     ).send()
                 )

         if event_type == CBEventType.LLM:
+            response = payload.get(EventPayload.RESPONSE)
+            content = response.message.content if response else ""
+
             run_sync(
                 Message(
-                    content=payload.get(EventPayload.RESPONSE, ""),
+                    content=content,
                     author=event_type,
-                    parent_id=parent_id,
+                    parent_id=self._get_parent_id(),
                     prompt=payload.get(EventPayload.PROMPT),
                 ).send()
             )

-    def start_trace(self, trace_id: Optional[str] = None) -> None:
-        """Run when an overall trace is launched."""
-        self.root_message = context.session.root_message
-
-    def end_trace(
-        self,
-        trace_id: Optional[str] = None,
-        trace_map: Optional[Dict[str, List[str]]] = None,
-    ) -> None:
-        """Run when an overall trace is exited."""
+    def _noop(self, *args, **kwargs):
+        pass

-        self.root_message = None
+    on_event_start = _noop
+    end_trace = _noop
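The _restore_context machinery added above works because Chainlit keeps its per-session context in a contextvars.ContextVar, and context variables do not automatically propagate to the worker threads LlamaIndex runs callbacks in. Capturing the value when the handler is constructed (on the main thread) and re-setting it in the callback thread restores it. A self-contained sketch of the same pattern; the names here are illustrative, not Chainlit's actual API:

import contextvars
import threading

context_var: contextvars.ContextVar[str] = contextvars.ContextVar("context")

class Handler:
    def __init__(self) -> None:
        # Capture the context of the thread that creates the handler
        self.context = context_var.get()

    def _restore_context(self) -> None:
        # Re-install the captured value in whatever thread runs the callback
        context_var.set(self.context)

    def callback(self) -> None:
        self._restore_context()
        print(context_var.get())  # prints "main-thread value"

context_var.set("main-thread value")
handler = Handler()

# A fresh thread starts with an empty context; without _restore_context
# the get() in callback() would raise LookupError
thread = threading.Thread(target=handler.callback)
thread.start()
thread.join()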
4 changes: 2 additions & 2 deletions src/pyproject.toml
@@ -46,8 +46,8 @@ lazify = "^0.4.0"
 optional = true

 [tool.poetry.group.tests.dependencies]
-langchain = "^0.0.229"
-llama_index = "^0.7.4"
+langchain = "^0.0.262"
+llama_index = "^0.8.3"
 transformers = "^4.30.1"
 responses = "0.23.1"
 aioresponses = "0.7.4"
