This repository has been archived by the owner on Oct 21, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
63dad8e
commit 9413ace
Showing
11 changed files
with
436 additions
and
3 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
from typing import List | ||
|
||
import streamlit as st | ||
from phi.assistant import Assistant | ||
from phi.document import Document | ||
from phi.document.reader.pdf import PDFReader | ||
from phi.tools.streamlit.components import ( | ||
get_openai_key_sidebar, | ||
get_username_sidebar, | ||
reload_button_sidebar, | ||
) | ||
|
||
from arxiv_ai.assistant import get_pdf_assistant | ||
from utils.log import logger | ||
|
||
|
||
# Configure the page (title and icon), then render the heading and attribution line.
st.set_page_config(page_title="Arxiv AI", page_icon=":orange_heart:")
st.title("Chat with Arxiv Papers")
st.markdown("##### :orange_heart: built using [phidata](https://github.com/phidatahq/phidata)")
|
||
|
||
def restart_assistant() -> None:
    """Reset the session so the next rerun starts with a fresh assistant.

    Clears the cached assistant and its run id, bumps the file-uploader
    widget key, and triggers a Streamlit rerun.
    """
    st.session_state["pdf_assistant"] = None
    st.session_state["pdf_assistant_run_id"] = None
    # main() only initialises this key when the assistant has a knowledge
    # base, so default to 0 instead of raising KeyError on `+= 1`.
    # NOTE(review): presumably a new key makes the uploader render empty - confirm.
    st.session_state["file_uploader_key"] = st.session_state.get("file_uploader_key", 0) + 1
    st.rerun()
|
||
|
||
def main() -> None:
    """Render the Arxiv AI chat page.

    Flow, in order: collect the OpenAI key and username from the sidebar
    (returning early when no username is given), get or create the assistant
    cached in ``st.session_state``, create the assistant run, load and display
    the chat history, answer the latest user message, handle PDF uploads into
    the knowledge base, and render the sidebar run controls.
    """
    # Get OpenAI key from environment variable or user input
    get_openai_key_sidebar()

    # Get username
    username = get_username_sidebar()
    if username:
        st.sidebar.info(f":technologist: User: {username}")
    else:
        st.markdown("---")
        st.markdown("#### :technologist: Enter a username, upload a PDF and start chatting")
        return

    # Get the assistant
    pdf_assistant: Assistant
    if "pdf_assistant" not in st.session_state or st.session_state["pdf_assistant"] is None:
        logger.info("---*--- Creating PDF Assistant ---*---")
        pdf_assistant = get_pdf_assistant(
            user_id=username,
            debug_mode=True,
        )
        st.session_state["pdf_assistant"] = pdf_assistant
    else:
        pdf_assistant = st.session_state["pdf_assistant"]

    # Create assistant run (i.e. log to database) and save run_id in session state
    st.session_state["pdf_assistant_run_id"] = pdf_assistant.create_run()

    # Load messages for existing assistant
    assistant_chat_history = pdf_assistant.memory.get_chat_history()
    if assistant_chat_history:
        logger.debug("Loading chat history")
        st.session_state["messages"] = assistant_chat_history
    else:
        logger.debug("No chat history found")
        st.session_state["messages"] = [{"role": "assistant", "content": "Ask me questions from the PDF"}]

    # Prompt for user input
    if prompt := st.chat_input():
        st.session_state["messages"].append({"role": "user", "content": prompt})

    # Display existing chat messages
    for message in st.session_state["messages"]:
        if message["role"] == "system":
            continue
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # If last message is from a user, generate a new response
    last_message = st.session_state["messages"][-1]
    if last_message.get("role") == "user":
        question = last_message["content"]
        with st.chat_message("assistant"):
            with st.spinner("Working..."):
                response = ""
                resp_container = st.empty()
                # Re-render the accumulated text on every streamed delta.
                for delta in pdf_assistant.run(question):
                    response += delta  # type: ignore
                    resp_container.markdown(response)

                st.session_state["messages"].append({"role": "assistant", "content": response})

    # Upload PDF
    if pdf_assistant.knowledge_base:
        if "file_uploader_key" not in st.session_state:
            st.session_state["file_uploader_key"] = 0

        uploaded_file = st.sidebar.file_uploader(
            "Upload a PDF :page_facing_up:",
            type="pdf",
            key=st.session_state["file_uploader_key"],
        )
        if uploaded_file is not None:
            alert = st.sidebar.info("Processing PDF...", icon="ℹ️")
            # Strip only the final extension: arXiv files are named like
            # "2205.14135.pdf", and splitting on the first dot would collapse
            # every paper from the same month onto one "already uploaded" key.
            pdf_name = uploaded_file.name.rsplit(".", 1)[0]
            if f"{pdf_name}_uploaded" not in st.session_state:
                reader = PDFReader()
                pdf_documents: List[Document] = reader.read(uploaded_file)
                if pdf_documents:
                    pdf_assistant.knowledge_base.load_documents(documents=pdf_documents, upsert=True)
                    # Refresh the assistant to update the instructions and document names
                    pdf_assistant = get_pdf_assistant(
                        user_id=username,
                        run_id=st.session_state["pdf_assistant_run_id"],
                        debug_mode=True,
                    )
                    st.session_state["pdf_assistant"] = pdf_assistant
                else:
                    st.sidebar.error("Could not read PDF")
                # Mark the file as handled so later reruns do not process it again.
                st.session_state[f"{pdf_name}_uploaded"] = True
            alert.empty()
            st.sidebar.success(":information_source: If the PDF throws an error, try uploading it again")

    st.sidebar.markdown("---")

    if st.sidebar.button("New Run"):
        restart_assistant()

    if st.sidebar.button("Auto Rename"):
        pdf_assistant.auto_rename_run()

    if pdf_assistant.storage:
        pdf_assistant_run_ids: List[str] = pdf_assistant.storage.get_all_run_ids(user_id=username)
        new_pdf_assistant_run_id = st.sidebar.selectbox("Run ID", options=pdf_assistant_run_ids)
        # Selecting a different run swaps the cached assistant and reruns the page.
        if st.session_state["pdf_assistant_run_id"] != new_pdf_assistant_run_id:
            logger.debug(f"Loading run {new_pdf_assistant_run_id}")
            logger.info("---*--- Loading PDF Assistant ---*---")
            st.session_state["pdf_assistant"] = get_pdf_assistant(
                user_id=username,
                run_id=new_pdf_assistant_run_id,
                debug_mode=True,
            )
            st.rerun()

    pdf_assistant_run_name = pdf_assistant.run_name
    if pdf_assistant_run_name:
        st.sidebar.write(f":thread: {pdf_assistant_run_name}")

    # Show reload button
    reload_button_sidebar()


main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import json | ||
from typing import Optional, List | ||
|
||
from phi.assistant import Assistant | ||
from phi.llm.openai import OpenAIChat | ||
|
||
from ai.settings import ai_settings | ||
from hn_ai.search import search_web | ||
from arxiv_ai.storage import arxiv_assistant_storage | ||
from arxiv_ai.tools import ArxivTools | ||
from arxiv_ai.knowledge import get_arxiv_knowledge_base_for_user | ||
from utils.log import logger | ||
|
||
|
||
def get_arxiv_assistant(
    user_id: str,
    run_id: Optional[str] = None,
    debug_mode: bool = False,
) -> Assistant:
    """Build the Arxiv AI assistant for one user.

    The introduction and instruction list are tailored to the documents that
    ``ArxivTools.get_document_names()`` reports for the user: none, exactly
    one, or several.

    Args:
        user_id: User the assistant, its tools and its knowledge base are scoped to.
        run_id: Existing run to attach to; passed through to ``Assistant``.
        debug_mode: Passed through to ``Assistant``.

    Returns:
        An ``Assistant`` configured with the OpenAI chat model, Postgres
        storage, the web-search and arXiv tools, and the user's knowledge base.
    """
    pdf_tools = ArxivTools(user_id=user_id)
    document_names_json: Optional[str] = pdf_tools.get_document_names()
    document_names: Optional[List] = json.loads(document_names_json) if document_names_json else None
    logger.info(f"Documents available: {document_names}")

    introduction = "Hi, I am Arxiv AI, built by [phidata](https://github.com/phidatahq/phidata)."

    instructions = [
        "You are made by phidata: https://github.com/phidatahq/phidata",
        f"You are interacting with the user: {user_id}",
        "You have a knowledge base of ArXiv papers that you can use to answer questions.",
        "When the user asks a question, first determine if you can answer the question from the documents in the knowledge base.",
    ]
    if not document_names:
        introduction += " Please upload a document to get started."
        instructions.append(
            "You do not have any documents in your knowledge base. Ask the user politely to upload a document and share a nice joke with them."
        )
    else:
        introduction += "\n\nAsk me about: {}".format(", ".join(document_names))
        instructions.append(f"You have the following documents in your knowledge base: {document_names}")
        if len(document_names) == 1:
            # The `search_knowledge_base` hint is added once for every case
            # below, so it is not repeated in this branch.
            instructions.extend(
                [
                    "If the user asks a specific question, use the `search_latest_document` tool to search the latest document for context.",
                    "If the user asks a summary, use the `get_latest_document_contents` tool to get the contents of the latest document.",
                ]
            )
        else:
            instructions.extend(
                [
                    "When the user asks a question, first determine if you should search a specific document or the latest document uploaded by the user.",
                    "If the user asks a specific question, use the `search_document` tool if you know the document to search OR `search_latest_document` tool to search the latest document for context.",
                    "If the user asks to summarize a document, use the `get_document_contents` if you know the document to search OR `get_latest_document_contents` tool to get the contents of the latest document.",
                ]
            )
    instructions.extend(
        [
            "You can also search the entire knowledge base using the `search_knowledge_base` tool.",
            "Keep your conversation light hearted and fun.",
            "Using information from the document, provide the user with a concise and relevant answer.",
            "If the user asks what is this? they are asking about the latest document",
            "If you cannot find the information in the knowledge base, think if you can find it on the web. If you can find the information on the web, use the `search_web` tool",
            "When searching the knowledge base, search for at least 3 documents.",
            "When getting document contents, get atleast 3000 words so you get the first few pages.",
            "Most documents have a table of contents in the beginning so if you need those, use the `get_document_contents` tool.",
            "If the user compliments you, ask them to star phidata on GitHub: https://github.com/phidatahq/phidata",
        ]
    )

    return Assistant(
        # The fallback previously read "hn_assistant"; it now matches the
        # "pdf_assistant" prefix used for named users.
        name=f"pdf_assistant_{user_id}" if user_id else "pdf_assistant",
        run_id=run_id,
        user_id=user_id,
        llm=OpenAIChat(
            model=ai_settings.gpt_4,
            max_tokens=ai_settings.default_max_tokens,
            temperature=ai_settings.default_temperature,
        ),
        storage=arxiv_assistant_storage,
        monitoring=True,
        use_tools=True,
        introduction=introduction,
        tools=[search_web, pdf_tools],
        knowledge_base=get_arxiv_knowledge_base_for_user(user_id),
        show_tool_calls=True,
        debug_mode=debug_mode,
        description="Your name is PDF AI and you are a chatbot that answers questions from a knowledge base of PDFs.",
        add_datetime_to_instructions=True,
        instructions=instructions,
        user_data={"documents": document_names},
    )


# Backward-compatible alias: the Streamlit page imports this factory from
# `arxiv_ai.assistant` under the name `get_pdf_assistant`.
get_pdf_assistant = get_arxiv_assistant
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from typing import Optional | ||
|
||
from phi.knowledge import AssistantKnowledge | ||
from phi.embedder.openai import OpenAIEmbedder | ||
from phi.vectordb.pgvector import PgVector2 | ||
|
||
from db.session import db_url | ||
|
||
|
||
def get_arxiv_knowledge_base_for_user(user_id: Optional[str] = None) -> AssistantKnowledge:
    """Return the pgvector-backed knowledge base for ``user_id``.

    The collection is ``arxiv_documents_<user_id>`` when a user id is given
    and the shared ``arxiv_documents`` collection otherwise.
    """
    collection_name = "arxiv_documents"
    if user_id:
        collection_name = f"{collection_name}_{user_id}"

    vector_store = PgVector2(
        schema="ai",
        db_url=db_url,
        collection=collection_name,
        embedder=OpenAIEmbedder(model="text-embedding-3-small"),
    )
    return AssistantKnowledge(vector_db=vector_store, num_documents=5)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from phi.storage.assistant.postgres import PgAssistantStorage | ||
|
||
from db.session import db_url | ||
|
||
# Postgres-backed storage for assistant runs: table "pdf_assistant" in schema "ai".
arxiv_assistant_storage = PgAssistantStorage(schema="ai", db_url=db_url, table_name="pdf_assistant")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from arxiv_ai.assistant import get_arxiv_assistant | ||
|
||
# Manual smoke test: build the assistant for a fixed user and ask one question.
# Other prompts to try by hand:
#   arxiv_assistant.print_response("summarize this")
#   arxiv_assistant.print_response("What the capital of India?")
question = "Who are you?"
arxiv_assistant = get_arxiv_assistant(user_id="ab", debug_mode=True)
arxiv_assistant.print_response(question)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from phi.utils.log import set_log_level_to_debug | ||
|
||
from arxiv_ai.tools import ArxivTools | ||
|
||
# Manual check of the arXiv search tool with debug logging turned on.
set_log_level_to_debug()

arxiv_tools = ArxivTools(user_id="ab")
print(arxiv_tools.search_arxiv(query="The FlashAttention Paper"))
Oops, something went wrong.