Skip to content

Commit

Permalink
finished simplest UI
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaohanzai committed Sep 19, 2023
1 parent 8df6c8d commit 678ebee
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 7 deletions.
125 changes: 125 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import pandas as pd
import streamlit as st
import sys
sys.path.append('code/')
from retriever_functions import *

@st.cache_resource
def load_data():
citation_data = pd.read_csv('data/FRB_citations_23.csv')
abstract_data = pd.read_csv('data/FRB_abstracts.csv')
return citation_data, abstract_data

@st.cache_resource
def load_retrievers(key):
retriever_abstract = gen_retriever('VectorStoreIndex', 'abstract', openai_api_key=key, path_to_db_folder='./data/')
retriever_citation = gen_retriever('VectorStoreIndex', 'citation', openai_api_key=key, path_to_db_folder='./data/')
return retriever_citation, retriever_abstract

def get_results_above_threshold(results_citation, results_abstract, score_threshold_percentile):
if query and retriever_citation is not None and citation_data is not None:
results_abstract = rearrange_query_results(results_abstract, score_threshold=0.8)
scores = get_all_scores(results_citation, show_hist=False)
score_threshold = np.percentile(scores, score_threshold_percentile)
results_citation = rearrange_query_results(results_citation, score_threshold=score_threshold, sort=True)
return results_citation, results_abstract

@st.cache_data
def query_retreivers(query, score_threshold_percentile):
if query and retriever_abstract is not None and abstract_data is not None:
results_citation = retriever_citation.retrieve(query)
results_abstract = retriever_abstract.retrieve(query)
results_citation, results_abstract = get_results_above_threshold(results_citation, results_abstract, score_threshold_percentile)
return results_citation, results_abstract

st.title("LLM Citation Tool")
st.write(
"Have you ever felt it quite a headache to find out which papers to cite when writing\
a paper? This tool is designed to help you with that.")
st.write("This app is a prototype built from arXiv papers on **fast radio bursts (FRBs)**, which means\
it can only do FRB literature search for now.\
For each FRB paper from July 2022 to Aug 2023, we extracted the citations in the introduction\
section of the paper and the reasons for citing these papers using OpenAI's API.\
The motivation of this project is that we find that humans are highly biased towards which papers\
they want to cite, and that finding references from embedding abstracts often gives interestingly\
unuseful results (e.g. those papers are indeed very related to a topic but just not what people\
normally would like to cite).\
Therefore, why not search for the reasons why other people cite papers and follow them?")
st.write("On the left column of this page, you will see query results from performing a similarity search\
on the reasons for citation. On the right column, you will see query results from embedding the\
title and abstracts of the papers. You can adjust the similarity score threshold and the number\
of results to display using the sliders below.")
st.write("This app uses OpenAI Embedding so **please input your own API key** below.\
It's not going to cost you more than a few cents, but I'm poor so I don't want to share my own key.")

st.write("This is just a prototype and the citation query is very likely to return unsatisfying results.\
For example, some papers (e.g. Lorimer et al.) that are identified by gpt3.5 as\
'cited to provide background knowleges of FRBs'\
seem to be matched to whatever query you make. Try raising the similarity score threshold and see if\
anything changes.\
Including abbreviations in the query, e.g. fast radio burst (frb), circumgalactic medium (cgm),\
may also help getting better query results.\
There are, however, failure cases like searching for FRB scintillation. Perhaps OpenAI Embedding doesn't\
process scintillation as an astrophysical term?")

st.write("There may be cases where the abstract search gives better results, so don't rely too much on the reasons for citation\
for now.")

api_key = st.text_input("Enter your OpenAI API key:")
query = st.text_input("Enter your query:")
st.write("Similarity score threshold for the citation search is determined by the percentile of the all scores.\
Default is set to 99.7, which roughly returns the top 20 most similar reasons of citation to the query.\
We then group these reasons with the same arXiv id and count the number of times each one appears.\
The final results are sorted by the number of times each arXiv id appears.\
This does not affect the abstract search because that one simply returns the top results above a\
similarity score threshold of 0.8.")
score_threshold = st.slider("Similarity score threshold (percentage):", min_value=99.0, max_value=99.8, value=99.7)
search_clicked = st.button("Search")

results_citation = results_abstract = None
if search_clicked and query and api_key:
citation_data, abstract_data = load_data()
retriever_citation, retriever_abstract = load_retrievers(api_key)

results_citation, results_abstract = query_retreivers(query, score_threshold)

col_l, col_r = st.columns(2)
n_results = 5

with col_l:
st.header("**Results**: from embedding reasons for citation")
if results_citation is not None:
for i in range(min(n_results, len(results_citation))):
row = results_citation.iloc[i]
doc_id = row['doc_id']
reasons = row['reasons']
if len(reasons) > 500:
reasons = reasons[:500] + '...'
arxiv_id = citation_data.iloc[doc_id]['arxiv_id']
txt_ref = citation_data.iloc[doc_id]['txt_ref']
st.subheader(f"{i+1}: {txt_ref} (arXiv id: {arxiv_id})")
st.write("***Reasons that people cited it:***")
for reason in reasons.split(';'):
st.write(reason)
st.divider()
# else:
# st.write('nothing mached your search')

with col_r:
st.header("**Results**: from embedding title and abstracts")
if results_abstract is not None:
for i in range(min(n_results, len(results_abstract))):
row = results_abstract.iloc[i]
doc_id = row['doc_id']
arxiv_id = abstract_data.iloc[doc_id]['arxiv_id']
authors = abstract_data.iloc[doc_id]['authors']
title = abstract_data.iloc[doc_id]['title']
abstract = abstract_data.iloc[doc_id]['abstract']
if len(abstract) > 800:
abstract = abstract[:800] + '...'
st.subheader(f"{i+1}: {authors} (arXiv id: {arxiv_id})")
st.write(f"***Title:*** {title}")
st.write(f"***Abstract:*** {abstract}")
st.divider()
# else:
# st.write('nothing mached your search')
18 changes: 11 additions & 7 deletions code/retriever_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,27 @@
import matplotlib.pyplot as plt
import numpy as np

def get_all_score(query_results, show_hist=True):
def get_all_scores(query_results, show_hist=True):
'''
query_results is the list of nodes returned by retriever.retrieve()
'''
scores = [node.score for node in query_results]
if show_hist:
vals = np.percentile(scores, [50, 99, 99.5])
vals = np.percentile(scores, [50, 99, 99.7])
plt.hist(scores, bins=50)
for val in vals:
plt.axvline(val, color='r', linestyle='--')
plt.show()
return scores

def gen_retriever(index_name, data_name, openai_api_key=None):
def gen_retriever(index_name, data_name, openai_api_key=None, path_to_db_folder='../data/'):
'''
index_name: 'VectorStoreIndex', 'SimpleKeywordTableIndex', 'RAKEKeywordTableIndex'
data_name: 'citation', 'abstract'
Input your own openai_api_key or it's going to detect environment variable OPENAI_API_KEY.
'''
index = load_index_from_storage(
StorageContext.from_defaults(persist_dir=f'../data/llamaindex{index_name}_openaiEmbed_{data_name}_db/'),
StorageContext.from_defaults(persist_dir=path_to_db_folder+f'/llamaindex{index_name}_openaiEmbed_{data_name}_db/'),
service_context=set_service_context(openai_api_key)
)
retriever = index.as_retriever()
Expand All @@ -32,15 +35,15 @@ def gen_retriever(index_name, data_name, openai_api_key=None):
retriever.num_chunks_per_query = n # there won't be a score in the case of keyword table index though
return retriever

def rearrange_query_results(query_results, score_threshold=0.9):
def rearrange_query_results(query_results, score_threshold=0.9, sort=True):
'''
If using VectorStoreIndex, keep the query results with similarity score > score_threshold.
Then count the number of times each doc appears in the query results and combine the reasons of citation.
Sort the query results by the number of times each doc appears.
Return the sorted docs.
If using keyword table index, just rearrange the query results into a pd.DataFrame.
'''
if query_results[0].score is None:
if query_results[0].score is None: # keyword table index
rst = pd.DataFrame(columns=['doc_id', 'reasons'])
for node in query_results:
doc_id = int(node.metadata['doc_id'])
Expand All @@ -60,7 +63,8 @@ def rearrange_query_results(query_results, score_threshold=0.9):
rst.loc[len(rst)] = {'doc_id': doc_id, 'n': 1, 'reasons': node.text}

# sort results
rst = rst.sort_values(by='n', ascending=False)
if sort:
rst = rst.sort_values(by='n', ascending=False, kind='mergesort')
return rst

def print_citation_query_results(data_citation, query_results, i_start=0, i_end=5, data_abstract=None):
Expand Down

0 comments on commit 678ebee

Please sign in to comment.