Skip to content

Commit

Permalink
Update sql_alchemy.py (Chainlit#981)
Browse files Browse the repository at this point in the history
  • Loading branch information
hayescode authored May 13, 2024
1 parent d47045f commit e5d2573
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions backend/chainlit/data/sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ def __init__(
ssl_require: bool = False,
storage_provider: Optional[BaseStorageClient] = None,
user_thread_limit: Optional[int] = 1000,
show_logger: Optional[bool] = False,
):
self._conninfo = conninfo
self.user_thread_limit = user_thread_limit
self.show_logger = show_logger
ssl_args = {}
if ssl_require:
# Create an SSL context to require an SSL connection
Expand All @@ -55,7 +57,7 @@ def __init__(
self.async_session = sessionmaker(bind=self.engine, expire_on_commit=False, class_=AsyncSession) # type: ignore
if storage_provider:
self.storage_provider: Optional[BaseStorageClient] = storage_provider
logger.info("SQLAlchemyDataLayer storage client initialized")
if self.show_logger: logger.info("SQLAlchemyDataLayer storage client initialized")
else:
self.storage_provider = None
logger.warn(
Expand Down Expand Up @@ -102,7 +104,7 @@ def clean_result(self, obj):

###### User ######
async def get_user(self, identifier: str) -> Optional[PersistedUser]:
logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
if self.show_logger: logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
query = "SELECT * FROM users WHERE identifier = :identifier"
parameters = {"identifier": identifier}
result = await self.execute_sql(query=query, parameters=parameters)
Expand All @@ -112,20 +114,20 @@ async def get_user(self, identifier: str) -> Optional[PersistedUser]:
return None

async def create_user(self, user: User) -> Optional[PersistedUser]:
logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
if self.show_logger: logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
existing_user: Optional["PersistedUser"] = await self.get_user(user.identifier)
user_dict: Dict[str, Any] = {
"identifier": str(user.identifier),
"metadata": json.dumps(user.metadata) or {},
}
if not existing_user: # create the user
logger.info("SQLAlchemy: create_user, creating the user")
if self.show_logger: logger.info("SQLAlchemy: create_user, creating the user")
user_dict["id"] = str(uuid.uuid4())
user_dict["createdAt"] = await self.get_current_timestamp()
query = """INSERT INTO users ("id", "identifier", "createdAt", "metadata") VALUES (:id, :identifier, :createdAt, :metadata)"""
await self.execute_sql(query=query, parameters=user_dict)
else: # update the user
logger.info("SQLAlchemy: update user metadata")
if self.show_logger: logger.info("SQLAlchemy: update user metadata")
query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier"""
await self.execute_sql(
query=query, parameters=user_dict
Expand All @@ -134,19 +136,18 @@ async def create_user(self, user: User) -> Optional[PersistedUser]:

###### Threads ######
async def get_thread_author(self, thread_id: str) -> str:
logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
if self.show_logger: logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
query = """SELECT "userIdentifier" FROM threads WHERE "id" = :id"""
parameters = {"id": thread_id}
result = await self.execute_sql(query=query, parameters=parameters)
if isinstance(result, list) and result[0]:
author_identifier = result[0].get("userIdentifier")
if author_identifier is not None:
print(f"Author found: {author_identifier}")
return author_identifier
raise ValueError(f"Author not found for thread_id {thread_id}")

async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
if self.show_logger: logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(
thread_id=thread_id
)
Expand All @@ -163,7 +164,7 @@ async def update_thread(
metadata: Optional[Dict] = None,
tags: Optional[List[str]] = None,
):
logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
if self.show_logger: logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
if context.session.user is not None:
user_identifier = context.session.user.identifier
else:
Expand Down Expand Up @@ -200,7 +201,7 @@ async def update_thread(
await self.execute_sql(query=query, parameters=parameters)

async def delete_thread(self, thread_id: str):
logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
if self.show_logger: logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
# Delete feedbacks/elements/steps/thread
feedbacks_query = """DELETE FROM feedbacks WHERE "forId" IN (SELECT "id" FROM steps WHERE "threadId" = :id)"""
elements_query = """DELETE FROM elements WHERE "threadId" = :id"""
Expand All @@ -215,7 +216,7 @@ async def delete_thread(self, thread_id: str):
async def list_threads(
self, pagination: Pagination, filters: ThreadFilter
) -> PaginatedResponse:
logger.info(
if self.show_logger: logger.info(
f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}"
)
if not filters.userId:
Expand Down Expand Up @@ -275,7 +276,7 @@ async def list_threads(
###### Steps ######
@queue_until_user_message()
async def create_step(self, step_dict: "StepDict"):
logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
if self.show_logger: logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
if not getattr(context.session.user, "id", None):
raise ValueError("No authenticated user in context")
step_dict["showInput"] = (
Expand Down Expand Up @@ -305,12 +306,12 @@ async def create_step(self, step_dict: "StepDict"):

@queue_until_user_message()
async def update_step(self, step_dict: "StepDict"):
logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
if self.show_logger: logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
await self.create_step(step_dict)

@queue_until_user_message()
async def delete_step(self, step_id: str):
logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
if self.show_logger: logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
# Delete feedbacks/elements/steps
feedbacks_query = """DELETE FROM feedbacks WHERE "forId" = :id"""
elements_query = """DELETE FROM elements WHERE "forId" = :id"""
Expand All @@ -322,7 +323,7 @@ async def delete_step(self, step_id: str):

###### Feedback ######
async def upsert_feedback(self, feedback: Feedback) -> str:
logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
if self.show_logger: logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
feedback.id = feedback.id or str(uuid.uuid4())
feedback_dict = asdict(feedback)
parameters = {
Expand All @@ -344,7 +345,7 @@ async def upsert_feedback(self, feedback: Feedback) -> str:
return feedback.id

async def delete_feedback(self, feedback_id: str) -> bool:
logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
if self.show_logger: logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
query = """DELETE FROM feedbacks WHERE "id" = :feedback_id"""
parameters = {"feedback_id": feedback_id}
await self.execute_sql(query=query, parameters=parameters)
Expand All @@ -353,7 +354,7 @@ async def delete_feedback(self, feedback_id: str) -> bool:
###### Elements ######
@queue_until_user_message()
async def create_element(self, element: "Element"):
logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
if self.show_logger: logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
if not getattr(context.session.user, "id", None):
raise ValueError("No authenticated user in context")
if isinstance(element, Avatar): # Skip creating elements of type avatar
Expand Down Expand Up @@ -416,7 +417,7 @@ async def create_element(self, element: "Element"):

@queue_until_user_message()
async def delete_element(self, element_id: str):
logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
if self.show_logger: logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
query = """DELETE FROM elements WHERE "id" = :id"""
parameters = {"id": element_id}
await self.execute_sql(query=query, parameters=parameters)
Expand All @@ -428,7 +429,7 @@ async def get_all_user_threads(
self, user_id: Optional[str] = None, thread_id: Optional[str] = None
) -> Optional[List[ThreadDict]]:
"""Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided."""
logger.info(f"SQLAlchemy: get_all_user_threads")
if self.show_logger: logger.info(f"SQLAlchemy: get_all_user_threads")
user_threads_query = """
SELECT
"id" AS thread_id,
Expand Down Expand Up @@ -562,8 +563,8 @@ async def get_all_user_threads(
tags=step_feedback.get("step_tags"),
input=(
step_feedback.get("step_input", "")
if step_feedback["step_showinput"]
else ""
if step_feedback["step_showinput"] == "true"
else None
),
output=step_feedback.get("step_output", ""),
createdAt=step_feedback.get("step_createdat"),
Expand Down

0 comments on commit e5d2573

Please sign in to comment.