Skip to content

Commit

Permalink
keep the message id intact if persistence fails (Chainlit#200)
Browse files Browse the repository at this point in the history
* keep the message id intact if persistence fails

* fix test

* fix typo
  • Loading branch information
willydouhard authored Jul 20, 2023
1 parent 9f0c542 commit 4ac6a11
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
4 changes: 3 additions & 1 deletion cypress/e2e/client_factory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,6 @@ async def client_factory(user_infos):

@cl.on_chat_start
async def on_chat_start():
await cl.Message("Hello").send()
msg = cl.Message(content="Hello")
msg.fail_on_persist_error = True
await msg.send()
8 changes: 5 additions & 3 deletions src/chainlit/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ async def process_user_message(self, message: Dict):
if self.db_client:
# We have to update the UI with the actual DB ID
ui_message_update = message.copy()
message["id"] = await self.db_client.create_message(message)
ui_message_update["newId"] = message["id"]
await self.update_message(ui_message_update)
persisted_id = await self.db_client.create_message(message)
if persisted_id:
message["id"] = persisted_id
ui_message_update["newId"] = message["id"]
await self.update_message(ui_message_update)

self.session.root_message = Message.from_dict(message)

Expand Down
10 changes: 6 additions & 4 deletions src/chainlit/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class MessageBase(ABC):
id: str = None
streaming = False
created_at: int = None
fail_on_persist_error: bool = True
fail_on_persist_error: bool = False
persisted = False

def __post_init__(self) -> None:
Expand All @@ -42,9 +42,11 @@ async def _create(self):
msg_dict = self.to_dict()
if self.emitter.db_client and not self.persisted:
try:
self.id = await self.emitter.db_client.create_message(msg_dict)
msg_dict["id"] = self.id
self.persisted = True
persisted_id = await self.emitter.db_client.create_message(msg_dict)
if persisted_id:
msg_dict["id"] = persisted_id
self.id = persisted_id
self.persisted = True
except Exception as e:
if self.fail_on_persist_error:
raise e
Expand Down

0 comments on commit 4ac6a11

Please sign in to comment.