Skip to content

Commit

Permalink
Make the context directly accessible from chainlit.context.context (C…
Browse files Browse the repository at this point in the history
  • Loading branch information
ramnes authored Aug 18, 2023
1 parent b078e80 commit 211bc27
Show file tree
Hide file tree
Showing 14 changed files with 118 additions and 120 deletions.
13 changes: 5 additions & 8 deletions cypress/e2e/sdk_availability/main.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,26 @@
import chainlit as cl
from chainlit.context import get_emitter
from chainlit.context import context
from chainlit.sync import make_async, run_sync


async def async_function_from_sync():
await cl.sleep(2)
emitter = get_emitter()
return emitter
return context.emitter


def sync_function():
emitter_from_make_async = get_emitter()
emitter_from_make_async = context.emitter
emitter_from_async_from_sync = run_sync(async_function_from_sync())
return (emitter_from_make_async, emitter_from_async_from_sync)


async def async_function():
emitter = await another_async_function()
return emitter
return await another_async_function()


async def another_async_function():
await cl.sleep(2)
emitter = get_emitter()
return emitter
return context.emitter


@cl.on_chat_start
Expand Down
7 changes: 3 additions & 4 deletions src/chainlit/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses_json import DataClassJsonMixin
from pydantic.dataclasses import Field, dataclass

from chainlit.context import get_emitter
from chainlit.context import context
from chainlit.telemetry import trace_event


Expand All @@ -25,13 +25,12 @@ class Action(DataClassJsonMixin):

def __post_init__(self) -> None:
trace_event(f"init {self.__class__.__name__}")
self.emit = get_emitter().emit

async def send(self, for_id: str):
trace_event(f"send {self.__class__.__name__}")
self.forId = for_id
await self.emit("action", self.to_dict())
await context.emitter.emit("action", self.to_dict())

async def remove(self):
trace_event(f"remove {self.__class__.__name__}")
await self.emit("remove_action", self.to_dict())
await context.emitter.emit("remove_action", self.to_dict())
7 changes: 3 additions & 4 deletions src/chainlit/chat_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic.dataclasses import Field, dataclass

from chainlit.context import get_emitter
from chainlit.context import context
from chainlit.input_widget import InputWidget


Expand All @@ -17,7 +17,6 @@ def __init__(
inputs: List[InputWidget],
) -> None:
self.inputs = inputs
self.emitter = get_emitter()

def settings(self):
return dict(
Expand All @@ -26,9 +25,9 @@ def settings(self):

async def send(self):
settings = self.settings()
self.emitter.set_chat_settings(settings)
context.emitter.set_chat_settings(settings)

inputs_content = [input_widget.to_dict() for input_widget in self.inputs]
await self.emitter.emit("chat_settings", inputs_content)
await context.emitter.emit("chat_settings", inputs_content)

return settings
41 changes: 31 additions & 10 deletions src/chainlit/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from asyncio import AbstractEventLoop
import asyncio
from contextvars import ContextVar
from typing import TYPE_CHECKING

from lazify import LazyProxy

from chainlit.session import Session

if TYPE_CHECKING:
from chainlit.emitter import ChainlitEmitter

Expand All @@ -11,19 +15,36 @@ def __init__(self, msg="Chainlit context not found", *args, **kwargs):
super().__init__(msg, *args, **kwargs)


emitter_var: ContextVar["ChainlitEmitter"] = ContextVar("emitter")
loop_var: ContextVar[AbstractEventLoop] = ContextVar("loop")
class ChainlitContext:
loop: asyncio.AbstractEventLoop
emitter: "ChainlitEmitter"
session: Session

def __init__(self, session: Session):
from chainlit.emitter import ChainlitEmitter

self.loop = asyncio.get_event_loop()
self.session = session
self.emitter = ChainlitEmitter(session)

def get_emitter() -> "ChainlitEmitter":
try:
return emitter_var.get()
except LookupError:
raise ChainlitContextException()

context_var: ContextVar[ChainlitContext] = ContextVar("chainlit")

def get_loop() -> AbstractEventLoop:

def init_context(session_or_sid) -> ChainlitContext:
if not isinstance(session_or_sid, Session):
session_or_sid = Session.require(session_or_sid)

context = ChainlitContext(session_or_sid)
context_var.set(context)
return context


def get_context() -> ChainlitContext:
try:
return loop_var.get()
return context_var.get()
except LookupError:
raise ChainlitContextException()


context = LazyProxy(get_context, enable_cache=False)
19 changes: 9 additions & 10 deletions src/chainlit/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pydantic.dataclasses import Field, dataclass

from chainlit.client.base import BaseDBClient, ElementDict
from chainlit.context import get_emitter
from chainlit.context import context
from chainlit.telemetry import trace_event
from chainlit.types import ElementDisplay, ElementSize, ElementType

Expand Down Expand Up @@ -45,7 +45,6 @@ class Element:

def __post_init__(self) -> None:
trace_event(f"init {self.__class__.__name__}")
self.emitter = get_emitter()
self.persisted = False

if not self.url and not self.path and not self.content:
Expand Down Expand Up @@ -99,7 +98,7 @@ async def before_emit(self, element: Dict) -> Dict:

async def remove(self):
trace_event(f"remove {self.__class__.__name__}")
await self.emitter.emit("remove_element", {"id": self.id})
await context.emitter.emit("remove_element", {"id": self.id})

async def send(self, for_id: Optional[str] = None):
if not self.content and not self.url and self.path:
Expand All @@ -111,8 +110,8 @@ async def send(self, for_id: Optional[str] = None):
self.for_ids.append(for_id)

# We have a client, persist the element
if self.emitter.db_client:
element_dict = await self.persist(self.emitter.db_client)
if context.emitter.db_client:
element_dict = await self.persist(context.emitter.db_client)
self.id = element_dict["id"]

elif not self.url and not self.content:
Expand All @@ -123,18 +122,18 @@ async def send(self, for_id: Optional[str] = None):
# Adding this out of to_dict since the dict will be persisted in the DB
emit_dict["content"] = self.content

if self.emitter.emit:
if context.emitter.emit:
# Element was already sent
if len(self.for_ids) > 1:
trace_event(f"update {self.__class__.__name__}")
await self.emitter.emit(
await context.emitter.emit(
"update_element",
{"id": self.id, "forIds": self.for_ids},
)
else:
trace_event(f"send {self.__class__.__name__}")
emit_dict = await self.before_emit(emit_dict)
await self.emitter.emit("element", emit_dict)
await context.emitter.emit("element", emit_dict)


ElementBased = TypeVar("ElementBased", bound=Element)
Expand Down Expand Up @@ -165,10 +164,10 @@ async def send(self):
# Adding this out of to_dict since the dict will be persisted in the DB
element["content"] = self.content

if self.emitter.emit and element:
if context.emitter.emit and element:
trace_event(f"send {self.__class__.__name__}")
element = await self.before_emit(element)
await self.emitter.emit("element", element)
await context.emitter.emit("element", element)


@dataclass
Expand Down
10 changes: 3 additions & 7 deletions src/chainlit/haystack/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

import chainlit as cl
from chainlit.config import config
from chainlit.context import get_emitter
from chainlit.emitter import ChainlitEmitter
from chainlit.context import context

T = TypeVar("T")

Expand All @@ -33,7 +32,6 @@ def clear(self) -> None:

class HaystackAgentCallbackHandler:
stack: Stack[cl.Message]
emitter: ChainlitEmitter
latest_agent_message: Optional[cl.Message]

def __init__(self, agent: Agent):
Expand All @@ -47,13 +45,11 @@ def __init__(self, agent: Agent):
agent.tm.callback_manager.on_tool_error += self.on_tool_error

def get_root_message(self):
self.emitter = self.emitter if hasattr(self, "emitter") else get_emitter()

if not self.emitter.session.root_message:
if not context.session.root_message:
root_message = cl.Message(author=config.ui.name, content="")
cl.run_sync(root_message.send())

return self.emitter.session.root_message
return context.session.root_message

def on_agent_start(self, **kwargs: Any) -> None:
# Prepare agent step message for streaming
Expand Down
13 changes: 6 additions & 7 deletions src/chainlit/langchain/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult

from chainlit.config import config
from chainlit.context import get_emitter
from chainlit.emitter import ChainlitEmitter
from chainlit.context import context
from chainlit.message import ErrorMessage, Message
from chainlit.sync import run_sync
from chainlit.types import LLMSettings
Expand Down Expand Up @@ -37,7 +36,6 @@ def get_llm_settings(invocation_params: Union[Dict, None]):


class BaseLangchainCallbackHandler(BaseCallbackHandler):
emitter: ChainlitEmitter
# Keep track of the formatted prompts to display them in the prompt playground.
prompts: List[str]
# Keep track of the LLM settings for the last prompt
Expand Down Expand Up @@ -75,7 +73,6 @@ def __init__(
stream_final_answer: bool = False,
root_message: Optional[Message] = None,
) -> None:
self.emitter = get_emitter()
self.prompts = []
self.llm_settings = None
self.sequence = []
Expand All @@ -84,7 +81,7 @@ def __init__(

if root_message:
self.root_message = root_message
elif root_message := self.emitter.session.root_message:
elif root_message := context.session.root_message:
self.root_message = root_message
else:
self.root_message = Message(author=config.ui.name, content="")
Expand Down Expand Up @@ -287,7 +284,7 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
token_usage = response.llm_output["token_usage"]
if "total_tokens" in token_usage:
run_sync(
self.emitter.update_token_count(token_usage["total_tokens"])
context.emitter.update_token_count(token_usage["total_tokens"])
)
if self.final_stream:
run_sync(self.final_stream.send())
Expand Down Expand Up @@ -411,7 +408,9 @@ async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
if "token_usage" in response.llm_output:
token_usage = response.llm_output["token_usage"]
if "total_tokens" in token_usage:
await self.emitter.update_token_count(token_usage["total_tokens"])
await context.emitter.update_token_count(
token_usage["total_tokens"]
)
if self.final_stream:
await self.final_stream.send()

Expand Down
5 changes: 2 additions & 3 deletions src/chainlit/llama_index/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from llama_index.callbacks.base import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload

from chainlit.context import get_emitter
from chainlit.context import context
from chainlit.element import Text
from chainlit.message import Message
from chainlit.sync import run_sync
Expand Down Expand Up @@ -96,8 +96,7 @@ def on_event_end(

def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""
emitter = get_emitter()
self.root_message = emitter.session.root_message
self.root_message = context.session.root_message

def end_trace(
self,
Expand Down
Loading

0 comments on commit 211bc27

Please sign in to comment.