This repository has been archived by the owner on Oct 21, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0229c91
commit 4fe3cdf
Showing
12 changed files
with
352 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
from typing import List | ||
|
||
import streamlit as st | ||
from phi.assistant import Assistant | ||
from phi.document import Document | ||
from phi.document.reader.pdf import PDFReader | ||
from phi.tools.streamlit.components import ( | ||
get_openai_key_sidebar, | ||
get_username_sidebar, | ||
) | ||
|
||
from pdf_ai.assistant import get_pdf_assistant | ||
from utils.log import logger | ||
|
||
|
||
# Page chrome: browser-tab title and icon, then the on-page heading and
# the attribution line rendered under it.
st.set_page_config(page_title="PDF AI", page_icon=":blue_heart:")
st.title("PDF AI")
st.markdown("##### :blue_heart: built using [phidata](https://github.com/phidatahq/phidata)")
|
||
|
||
def restart_assistant():
    """Drop the current assistant/run from the session and rerun the script.

    Clears the cached assistant and its run id so that the next script run
    builds a new assistant, bumps the counter used as the file-uploader
    widget key, and then triggers ``st.rerun()``.
    """
    st.session_state["pdf_assistant"] = None
    st.session_state["pdf_assistant_run_id"] = None
    # The uploader widget is keyed by this counter; presumably changing the key
    # makes Streamlit render a fresh, empty uploader. The counter is only
    # initialised when the assistant has a knowledge base, so default to 0
    # instead of raising KeyError when it has never been set.
    st.session_state["file_uploader_key"] = st.session_state.get("file_uploader_key", 0) + 1
    st.rerun()
|
||
|
||
def main() -> None:
    """Render the PDF AI chat page.

    Flow, top to bottom on every script run:

    1. Collect the OpenAI key and the username from the sidebar; stop early
       with a hint if no username has been entered yet.
    2. Reuse the assistant cached in ``st.session_state`` or create a new one,
       then create/record its run id.
    3. Seed ``st.session_state["messages"]`` from the assistant's chat history
       (or a greeting), append any new user prompt, render the messages and,
       if the last message is from the user, stream the assistant's reply.
    4. Offer a PDF upload that is loaded into the assistant's knowledge base.
    5. Render sidebar controls: new run, auto rename, and run selection.
    """
    # Get OpenAI key from environment variable or user input
    get_openai_key_sidebar()

    # Get username
    username = get_username_sidebar()
    if username:
        st.sidebar.info(f":technologist: User: {username}")
    else:
        st.markdown("---")
        st.markdown("#### :technologist: Enter a username, upload a PDF and start chatting")
        return

    # Get the assistant
    pdf_assistant: Assistant
    if "pdf_assistant" not in st.session_state or st.session_state["pdf_assistant"] is None:
        logger.info("---*--- Creating PDF Assistant ---*---")
        pdf_assistant = get_pdf_assistant(
            user_id=username,
            debug_mode=True,
        )
        st.session_state["pdf_assistant"] = pdf_assistant
    else:
        pdf_assistant = st.session_state["pdf_assistant"]

    # Create assistant run (i.e. log to database) and save run_id in session state
    st.session_state["pdf_assistant_run_id"] = pdf_assistant.create_run()

    # Load messages for existing assistant
    assistant_chat_history = pdf_assistant.memory.get_chat_history()
    if assistant_chat_history:
        logger.debug("Loading chat history")
        st.session_state["messages"] = assistant_chat_history
    else:
        logger.debug("No chat history found")
        st.session_state["messages"] = [{"role": "assistant", "content": "Ask me questions from the PDF"}]

    # Prompt for user input
    if prompt := st.chat_input():
        st.session_state["messages"].append({"role": "user", "content": prompt})

    # Display existing chat messages (system messages are not shown to the user)
    for message in st.session_state["messages"]:
        if message["role"] == "system":
            continue
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # If last message is from a user, generate a new response
    last_message = st.session_state["messages"][-1]
    if last_message.get("role") == "user":
        question = last_message["content"]
        with st.chat_message("assistant"):
            with st.spinner("Working..."):
                response = ""
                resp_container = st.empty()
                # Re-render the accumulated text on every streamed chunk.
                for delta in pdf_assistant.run(question):
                    response += delta  # type: ignore
                    resp_container.markdown(response)

                st.session_state["messages"].append({"role": "assistant", "content": response})

    # Upload PDF
    if pdf_assistant.knowledge_base:
        if "file_uploader_key" not in st.session_state:
            st.session_state["file_uploader_key"] = 0

        uploaded_file = st.sidebar.file_uploader(
            "Upload a PDF :page_facing_up:",
            type="pdf",
            key=st.session_state["file_uploader_key"],
        )
        if uploaded_file is not None:
            alert = st.sidebar.info("Processing PDF...", icon="ℹ️")
            # Strip only the final extension: splitting on the first dot would map
            # "report.v1.pdf" and "report.v2.pdf" to the same flag and silently
            # skip loading the second file.
            pdf_name = uploaded_file.name.rsplit(".", 1)[0]
            # The flag keeps reruns of this script from re-reading the same upload.
            if f"{pdf_name}_uploaded" not in st.session_state:
                reader = PDFReader()
                pdf_documents: List[Document] = reader.read(uploaded_file)
                if pdf_documents:
                    pdf_assistant.knowledge_base.load_documents(documents=pdf_documents, upsert=True)
                else:
                    st.sidebar.error("Could not read PDF")
                st.session_state[f"{pdf_name}_uploaded"] = True
            alert.empty()

    st.sidebar.markdown("---")

    if st.sidebar.button("New Run"):
        restart_assistant()

    if st.sidebar.button("Auto Rename"):
        pdf_assistant.auto_rename_run()

    if pdf_assistant.storage:
        pdf_assistant_run_ids: List[str] = pdf_assistant.storage.get_all_run_ids(user_id=username)
        new_pdf_assistant_run_id = st.sidebar.selectbox("Run ID", options=pdf_assistant_run_ids)
        # A different selection means the user wants another stored run: rebuild
        # the assistant for that run id and rerun the script to display it.
        if st.session_state["pdf_assistant_run_id"] != new_pdf_assistant_run_id:
            logger.debug(f"Loading run {new_pdf_assistant_run_id}")
            logger.info("---*--- Loading PDF Assistant ---*---")
            st.session_state["pdf_assistant"] = get_pdf_assistant(
                user_id=username,
                run_id=new_pdf_assistant_run_id,
                debug_mode=True,
            )
            st.rerun()

    pdf_assistant_run_name = pdf_assistant.run_name
    if pdf_assistant_run_name:
        st.sidebar.write(f":thread: {pdf_assistant_run_name}")


main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from typing import Optional | ||
|
||
from phi.assistant import Assistant | ||
from phi.llm.openai import OpenAIChat | ||
|
||
from ai.settings import ai_settings | ||
from hn_ai.search import search_web | ||
from pdf_ai.storage import pdf_assistant_storage | ||
from pdf_ai.tools import PDFTools | ||
from pdf_ai.knowledge import get_pdf_knowledge_base_for_user | ||
|
||
|
||
def get_pdf_assistant(
    user_id: str,
    run_id: Optional[str] = None,
    debug_mode: bool = False,
) -> Assistant:
    """Build the PDF question-answering Assistant for a user.

    Args:
        user_id: Username the assistant is created for. It is used in the
            assistant name, in the instructions, and to select the user's
            PDF knowledge base.
        run_id: Existing run to attach to; ``None`` leaves the choice of run
            to the Assistant.
        debug_mode: Forwarded to the Assistant unchanged.

    Returns:
        An Assistant wired to GPT-4 settings from ``ai_settings``, the shared
        ``pdf_assistant_storage`` and the user's PDF knowledge base.
    """
    return Assistant(
        # The fallback used to read "hn_assistant", a leftover from the
        # HackerNews assistant this module was derived from.
        name=f"pdf_assistant_{user_id}" if user_id else "pdf_assistant",
        run_id=run_id,
        user_id=user_id,
        llm=OpenAIChat(
            model=ai_settings.gpt_4,
            max_tokens=ai_settings.default_max_tokens,
            temperature=ai_settings.default_temperature,
        ),
        storage=pdf_assistant_storage,
        monitoring=True,
        use_tools=True,
        # Custom tools are currently disabled. To enable them, pass e.g.
        #   tools=[search_web, PDFTools(user_id=user_id)]
        # and restore the tool instructions noted below. PDFTools is not
        # instantiated while disabled because doing so builds a second,
        # unused knowledge base on every call.
        knowledge_base=get_pdf_knowledge_base_for_user(user_id),
        show_tool_calls=True,
        debug_mode=debug_mode,
        description="Your name is PDF AI and you are a chatbot that answers questions from a knowledge base of PDFs.",
        add_datetime_to_instructions=True,
        instructions=[
            "You are made by phidata: https://github.com/phidatahq/phidata",
            f"You are interacting with the user: {user_id}",
            # Instructions to restore together with the tools above:
            # "If the user asks a question, first determine if you should search the web or your knowledge base for the answer.",
            # "If you need to search the web, use the `search_web` tool to search the web for the answer.",
            # "If the user asks a question but the document is not clear, use the `search_latest_document` tool to search the latest document for the answer.",
            # "If the user asks to summarize a document, use the `get_latest_document_contents` tool to get the contents of the latest document.",
            "If the user compliments you, ask them to star phidata on GitHub: https://github.com/phidatahq/phidata",
        ],
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from typing import Optional | ||
|
||
from phi.knowledge import AssistantKnowledge | ||
from phi.embedder.openai import OpenAIEmbedder | ||
from phi.vectordb.pgvector import PgVector2 | ||
|
||
from db.session import db_url | ||
from utils.log import logger | ||
|
||
|
||
def get_pdf_knowledge_base_for_user(user_id: Optional[str] = None) -> AssistantKnowledge:
    """Return the PDF knowledge base backed by the given user's vector collection.

    Args:
        user_id: Username whose collection should be used. When empty or
            ``None`` the shared ``pdf_documents`` collection is used instead.

    Returns:
        An AssistantKnowledge over a PgVector2 collection in the ``ai`` schema,
        configured with ``num_documents=5``.
    """
    # NOTE(review): user_id is interpolated into the collection name as-is;
    # confirm that callers only pass values that are valid in that position.
    if user_id:
        collection_name = f"pdf_documents_{user_id}"
    else:
        collection_name = "pdf_documents"

    vector_store = PgVector2(
        schema="ai",
        db_url=db_url,
        collection=collection_name,
        embedder=OpenAIEmbedder(model="text-embedding-3-small"),
    )
    return AssistantKnowledge(vector_db=vector_store, num_documents=5)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from phi.storage.assistant.postgres import PgAssistantStorage | ||
|
||
from db.session import db_url | ||
|
||
# Module-level storage handle for PDF assistant runs: the "pdf_assistant"
# table in the "ai" schema of the database at db_url.
pdf_assistant_storage = PgAssistantStorage(table_name="pdf_assistant", schema="ai", db_url=db_url)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from pdf_ai.assistant import get_pdf_assistant | ||
|
||
# Manual smoke test: build the assistant for the sample user "ab" with debug
# output enabled and print its answer to a trivial prompt.
pdf_assistant = get_pdf_assistant(debug_mode=True, user_id="ab")

# Alternative prompt, kept for manual experiments:
# pdf_assistant.print_response("Who is the agreement between?")
pdf_assistant.print_response("hello?")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from pdf_ai.tools import PDFTools | ||
|
||
# Manual smoke test: print the latest stored document for the sample user "ab".
pdf_tools = PDFTools(user_id="ab")

# PDFTools defines get_latest_document_contents; the previous call to
# get_latest_document did not match any method PDFTools defines.
latest_document = pdf_tools.get_latest_document_contents()
print(latest_document)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import json | ||
from typing import List, Optional | ||
|
||
from phi.document import Document | ||
from phi.tools import ToolRegistry | ||
from phi.knowledge import AssistantKnowledge | ||
from phi.vectordb import VectorDb | ||
from phi.vectordb.pgvector import PgVector2 | ||
|
||
from pdf_ai.knowledge import get_pdf_knowledge_base_for_user | ||
from utils.log import logger | ||
|
||
|
||
class PDFTools(ToolRegistry):
    """Tools that read the most recently stored PDF from a user's knowledge base.

    Both tools find the "latest" document by taking the newest row (ordered by
    ``created_at``) of the knowledge base's PgVector2 table and using that
    row's ``name``.
    """

    def __init__(self, user_id: str):
        """Create the registry for ``user_id`` and register both tools."""
        super().__init__(name="pdf_tools")

        self.user_id = user_id
        self.knowledge_base: AssistantKnowledge = get_pdf_knowledge_base_for_user(user_id=user_id)
        self.register(self.get_latest_document_contents)
        self.register(self.search_latest_document)

    def _latest_document_name(self, session, table) -> Optional[str]:
        """Return the ``name`` of the newest row in ``table``, or None if the table is empty."""
        newest_row_query = session.query(table).order_by(table.c.created_at.desc()).limit(1)
        row = session.execute(newest_row_query).fetchone()
        if row is None:
            return None

        latest_document_name = row.name
        logger.debug(f"Latest document name: {latest_document_name}")
        return latest_document_name

    # NOTE(review): the docstrings of the two tool methods below are presumably
    # surfaced to the LLM as tool descriptions - confirm before rewording them.
    def get_latest_document_contents(self, limit: int = 5000) -> Optional[str]:
        """Use this function to get the content of the latest document uploaded by the user.

        Args:
            limit (int, optional): Maximum number of characters to return. Defaults to 5000.

        Returns:
            str: Text content of the latest document, truncated to `limit` characters
        """

        logger.debug(f"Getting latest document for user {self.user_id}")
        vector_db = self.knowledge_base.vector_db
        if vector_db is None or not isinstance(vector_db, PgVector2):
            return "Sorry could not find latest document"

        table = vector_db.table
        with vector_db.Session() as session, session.begin():
            latest_document_name = self._latest_document_name(session, table)
            if latest_document_name is None:
                return "Sorry could not find latest document"

            # A document is stored as several rows sharing one name; concatenate
            # the content of all of them in the order the database returns them.
            document_query = session.query(table).filter(table.c.name == latest_document_name)
            document_rows = session.execute(document_query).fetchall()
            latest_document_content = "".join(document_row.content for document_row in document_rows)

            return latest_document_content[:limit]

    def search_latest_document(self, query: str, num_documents: Optional[int] = None) -> Optional[str]:
        """Use this function to search the latest document uploaded by the user for a query.

        Args:
            query (str): Query to search for
            num_documents (Optional[int], optional): Number of documents to return. Defaults to None.

        Returns:
            str: JSON string of the search results
        """

        logger.debug(f"Searching latest document for query: {query}")
        vector_db = self.knowledge_base.vector_db
        if vector_db is None or not isinstance(vector_db, PgVector2):
            return "Sorry could not search latest document"

        # The lookup lives in a helper so that no local variable reuses the name
        # `query`: previously the SQL query object overwrote the user's search
        # text and was then handed to vector_db.search() in its place.
        with vector_db.Session() as session, session.begin():
            latest_document_name = self._latest_document_name(session, vector_db.table)

        if latest_document_name is None:
            return "Sorry could not find latest document"

        # NOTE(review): num_documents is forwarded even when it is None; confirm
        # that vector_db.search() accepts limit=None.
        search_results: List[Document] = vector_db.search(
            query=query, limit=num_documents, filters={"name": latest_document_name}
        )
        logger.debug(f"Search result: {search_results}")

        if len(search_results) == 0:
            return "Sorry could not find any results from latest document"

        return json.dumps([doc.to_dict() for doc in search_results])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.