diff --git a/backend/chainlit/data/sql_alchemy.py b/backend/chainlit/data/sql_alchemy.py index 2902795821..8207498fa5 100644 --- a/backend/chainlit/data/sql_alchemy.py +++ b/backend/chainlit/data/sql_alchemy.py @@ -39,9 +39,11 @@ def __init__( ssl_require: bool = False, storage_provider: Optional[BaseStorageClient] = None, user_thread_limit: Optional[int] = 1000, + show_logger: Optional[bool] = False, ): self._conninfo = conninfo self.user_thread_limit = user_thread_limit + self.show_logger = show_logger ssl_args = {} if ssl_require: # Create an SSL context to require an SSL connection @@ -55,7 +57,7 @@ def __init__( self.async_session = sessionmaker(bind=self.engine, expire_on_commit=False, class_=AsyncSession) # type: ignore if storage_provider: self.storage_provider: Optional[BaseStorageClient] = storage_provider - logger.info("SQLAlchemyDataLayer storage client initialized") + if self.show_logger: logger.info("SQLAlchemyDataLayer storage client initialized") else: self.storage_provider = None logger.warn( @@ -102,7 +104,7 @@ def clean_result(self, obj): ###### User ###### async def get_user(self, identifier: str) -> Optional[PersistedUser]: - logger.info(f"SQLAlchemy: get_user, identifier={identifier}") + if self.show_logger: logger.info(f"SQLAlchemy: get_user, identifier={identifier}") query = "SELECT * FROM users WHERE identifier = :identifier" parameters = {"identifier": identifier} result = await self.execute_sql(query=query, parameters=parameters) @@ -112,20 +114,20 @@ async def get_user(self, identifier: str) -> Optional[PersistedUser]: return None async def create_user(self, user: User) -> Optional[PersistedUser]: - logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}") + if self.show_logger: logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}") existing_user: Optional["PersistedUser"] = await self.get_user(user.identifier) user_dict: Dict[str, Any] = { "identifier": str(user.identifier), "metadata": json.dumps(user.metadata) or {}, } if not existing_user: # create the user - logger.info("SQLAlchemy: create_user, creating the user") + if self.show_logger: logger.info("SQLAlchemy: create_user, creating the user") user_dict["id"] = str(uuid.uuid4()) user_dict["createdAt"] = await self.get_current_timestamp() query = """INSERT INTO users ("id", "identifier", "createdAt", "metadata") VALUES (:id, :identifier, :createdAt, :metadata)""" await self.execute_sql(query=query, parameters=user_dict) else: # update the user - logger.info("SQLAlchemy: update user metadata") + if self.show_logger: logger.info("SQLAlchemy: update user metadata") query = """UPDATE users SET "metadata" = :metadata WHERE "identifier" = :identifier""" await self.execute_sql( query=query, parameters=user_dict @@ -134,19 +136,18 @@ async def create_user(self, user: User) -> Optional[PersistedUser]: ###### Threads ###### async def get_thread_author(self, thread_id: str) -> str: - logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}") + if self.show_logger: logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}") query = """SELECT "userIdentifier" FROM threads WHERE "id" = :id""" parameters = {"id": thread_id} result = await self.execute_sql(query=query, parameters=parameters) if isinstance(result, list) and result[0]: author_identifier = result[0].get("userIdentifier") if author_identifier is not None: - print(f"Author found: {author_identifier}") return author_identifier raise ValueError(f"Author not found for thread_id {thread_id}") async def get_thread(self, thread_id: str) -> Optional[ThreadDict]: - logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}") + if self.show_logger: logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}") user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads( thread_id=thread_id ) @@ -163,7 +164,7 @@ async def update_thread( metadata: Optional[Dict] = None, tags: Optional[List[str]] = None, ): - logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}") + if self.show_logger: logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}") if context.session.user is not None: user_identifier = context.session.user.identifier else: @@ -200,7 +201,7 @@ async def update_thread( await self.execute_sql(query=query, parameters=parameters) async def delete_thread(self, thread_id: str): - logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}") + if self.show_logger: logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}") # Delete feedbacks/elements/steps/thread feedbacks_query = """DELETE FROM feedbacks WHERE "forId" IN (SELECT "id" FROM steps WHERE "threadId" = :id)""" elements_query = """DELETE FROM elements WHERE "threadId" = :id""" @@ -215,7 +216,7 @@ async def delete_thread(self, thread_id: str): async def list_threads( self, pagination: Pagination, filters: ThreadFilter ) -> PaginatedResponse: - logger.info( + if self.show_logger: logger.info( f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}" ) if not filters.userId: @@ -275,7 +276,7 @@ async def list_threads( ###### Steps ###### @queue_until_user_message() async def create_step(self, step_dict: "StepDict"): - logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}") + if self.show_logger: logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}") if not getattr(context.session.user, "id", None): raise ValueError("No authenticated user in context") step_dict["showInput"] = ( @@ -305,12 +306,12 @@ async def create_step(self, step_dict: "StepDict"): @queue_until_user_message() async def update_step(self, step_dict: "StepDict"): - logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}") + if self.show_logger: logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}") await self.create_step(step_dict) @queue_until_user_message() async def delete_step(self, step_id: str): - logger.info(f"SQLAlchemy: delete_step, step_id={step_id}") + if self.show_logger: logger.info(f"SQLAlchemy: delete_step, step_id={step_id}") # Delete feedbacks/elements/steps feedbacks_query = """DELETE FROM feedbacks WHERE "forId" = :id""" elements_query = """DELETE FROM elements WHERE "forId" = :id""" @@ -322,7 +323,7 @@ async def delete_step(self, step_id: str): ###### Feedback ###### async def upsert_feedback(self, feedback: Feedback) -> str: - logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}") + if self.show_logger: logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}") feedback.id = feedback.id or str(uuid.uuid4()) feedback_dict = asdict(feedback) parameters = { @@ -344,7 +345,7 @@ async def upsert_feedback(self, feedback: Feedback) -> str: return feedback.id async def delete_feedback(self, feedback_id: str) -> bool: - logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}") + if self.show_logger: logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}") query = """DELETE FROM feedbacks WHERE "id" = :feedback_id""" parameters = {"feedback_id": feedback_id} await self.execute_sql(query=query, parameters=parameters) @@ -353,7 +354,7 @@ async def delete_feedback(self, feedback_id: str) -> bool: ###### Elements ###### @queue_until_user_message() async def create_element(self, element: "Element"): - logger.info(f"SQLAlchemy: create_element, element_id = {element.id}") + if self.show_logger: logger.info(f"SQLAlchemy: create_element, element_id = {element.id}") if not getattr(context.session.user, "id", None): raise ValueError("No authenticated user in context") if isinstance(element, Avatar): # Skip creating elements of type avatar @@ -416,7 +417,7 @@ async def create_element(self, element: "Element"): @queue_until_user_message() async def delete_element(self, element_id: str): - logger.info(f"SQLAlchemy: delete_element, element_id={element_id}") + if self.show_logger: logger.info(f"SQLAlchemy: delete_element, element_id={element_id}") query = """DELETE FROM elements WHERE "id" = :id""" parameters = {"id": element_id} await self.execute_sql(query=query, parameters=parameters) @@ -428,7 +429,7 @@ async def get_all_user_threads( self, user_id: Optional[str] = None, thread_id: Optional[str] = None ) -> Optional[List[ThreadDict]]: """Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided.""" - logger.info(f"SQLAlchemy: get_all_user_threads") + if self.show_logger: logger.info(f"SQLAlchemy: get_all_user_threads") user_threads_query = """ SELECT "id" AS thread_id, @@ -562,8 +563,8 @@ async def get_all_user_threads( tags=step_feedback.get("step_tags"), input=( step_feedback.get("step_input", "") - if step_feedback["step_showinput"] - else "" + if step_feedback["step_showinput"] == "true" + else None ), output=step_feedback.get("step_output", ""), createdAt=step_feedback.get("step_createdat"),