
Commit 3ec5cc2

Author: florian
Commit message: more improvements
1 parent f36ea6e

File tree: 6 files changed (+160 -48 lines)

frontend/app/layout.tsx (+2 -2)

@@ -5,8 +5,8 @@ import "./globals.css";
 const inter = Inter({ subsets: ["latin"] });
 
 export const metadata: Metadata = {
-  title: "Create Next App",
-  description: "Generated by create next app",
+  title: "PyPi LLM Search",
+  description: "Find PyPi packages with natural language using LLM's",
 };
 
 export default function RootLayout({

frontend/app/page.tsx (+36 -13)

@@ -12,6 +12,7 @@ export default function Home() {
   const [sortDirection, setSortDirection] = useState("desc");
   const [loading, setLoading] = useState(false);
   const [error, setError] = useState("");
+  const [infoBoxVisible, setInfoBoxVisible] = useState(false);
 
   const handleSearch = async () => {
     setLoading(true);
@@ -28,10 +29,8 @@
         },
        },
      );
-      const sortedResults = response.data.matches.sort(
-        (a, b) => b.weekly_downloads - a.weekly_downloads,
-      );
-      setResults(sortedResults);
+      const fetchedResults = response.data.matches;
+      setResults(sortResults(fetchedResults, sortField, sortDirection));
     } catch (error) {
       setError("Error fetching search results.");
       console.error("Error fetching search results:", error);
@@ -40,17 +39,20 @@
     }
   };
 
-  const sortResults = (field) => {
-    const direction =
-      sortField === field && sortDirection === "asc" ? "desc" : "asc";
-    const sorted = [...results].sort((a, b) => {
+  const sortResults = (data, field, direction) => {
+    return [...data].sort((a, b) => {
       if (a[field] < b[field]) return direction === "asc" ? -1 : 1;
       if (a[field] > b[field]) return direction === "asc" ? 1 : -1;
       return 0;
     });
-    setResults(sorted);
+  };
+
+  const handleSort = (field) => {
+    const direction =
+      sortField === field && sortDirection === "asc" ? "desc" : "asc";
     setSortField(field);
     setSortDirection(direction);
+    setResults(sortResults(results, field, direction));
   };
 
   return (
@@ -80,17 +82,38 @@
       {error && <p className="text-red-500">{error}</p>}
     </div>
 
+    <div className="w-full flex justify-center mt-6">
+      <button
+        className="w-[250px] p-2 border rounded bg-gray-300 hover:bg-gray-400 focus:outline-none focus:ring-2 focus:ring-gray-500"
+        onClick={() => setInfoBoxVisible(!infoBoxVisible)}
+      >
+        {infoBoxVisible ? "Hide Info" : "How does this work?"}
+      </button>
+    </div>
+
+    {infoBoxVisible && (
+      <div className="w-3/5 bg-white p-6 rounded-lg shadow-lg mt-4">
+        <h2 className="text-2xl font-bold mb-2">How does this work?</h2>
+        <p className="text-gray-700">
+          This application allows you to search for Python packages on PyPi
+          using natural language. So an example query would be "a package that
+          creates plots and beautiful visualizations". Once you click search,
+          your query will be matched against the summary and the first part of
+          the description of all PyPi packages with more than 50 weekly
+          downloads, and the 50 most similar results will be displayed in a
+          table below.
+        </p>
+      </div>
+    )}
+
     {results.length > 0 && (
       <div className="w-full flex justify-center mt-6">
         <div className="w-11/12 bg-white p-6 rounded-lg shadow-lg flex flex-col items-center">
-          <p className="mb-4 text-gray-700">
-            Displaying the {results.length} most similar results:
-          </p>
           <SearchResultsTable
             results={results}
             sortField={sortField}
             sortDirection={sortDirection}
-            onSort={sortResults}
+            onSort={handleSort}
           />
         </div>
       </div>

notebooks/main.ipynb (+78 -23)

@@ -34,14 +34,23 @@
     "from pypi_llm.config import Config\n",
     "from pypi_llm.data.description_cleaner import DescriptionCleaner, CLEANING_FAILED\n",
     "from pypi_llm.data.reader import DataReader\n",
+    "from sentence_transformers import SentenceTransformer\n",
+    "from pypi_llm.vector_database import VectorDatabaseInterface\n",
     "\n",
     "load_dotenv()\n",
     "config = Config()\n",
     "\n",
-    "df = DataReader(config.DATA_DIR).read()\n",
-    "df = DescriptionCleaner().clean(df, \"description\", \"description_cleaned\")\n",
-    "df = df.filter(~pl.col(\"description_cleaned\").is_null())\n",
-    "df = df.filter(pl.col(\"description_cleaned\")!=CLEANING_FAILED)"
+    "# Load dataset and model\n",
+    "df = pl.read_csv(config.DATA_DIR / config.PROCESSED_DATASET_CSV_NAME)\n",
+    "model = SentenceTransformer(config.EMBEDDINGS_MODEL_NAME)\n",
+    "\n",
+    "# Initialize vector database interface\n",
+    "vector_database_interface = VectorDatabaseInterface(\n",
+    "    pinecone_token=config.PINECONE_TOKEN,\n",
+    "    pinecone_index_name=config.PINECONE_INDEX_NAME,\n",
+    "    embeddings_model=model,\n",
+    "    pinecone_namespace=config.PINECONE_NAMESPACE,\n",
+    ")"
@@ -51,39 +60,85 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "with pl.Config(fmt_str_lengths=1000):\n",
+    "with pl.Config(fmt_str_lengths=100):\n",
     "    display(df.head(10))"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "053c9cf1-9f79-4b98-bcc9-85b6b676da84",
+   "id": "bf393f0c-92c6-4d4a-bd97-d3ea7ebf2b80",
    "metadata": {},
    "outputs": [],
    "source": [
-    "from sentence_transformers import SentenceTransformer\n",
-    "model = SentenceTransformer(config.EMBEDDINGS_MODEL)\n",
-    "embeddings = model.encode(query)\n",
-    "\n",
-    "from pinecone import Pinecone, Index\n",
-    "pc = Pinecone(api_key=config.PINECONE_TOKEN)\n",
-    "index = pc.Index(config.PINECONE_INDEX_NAME)\n",
-    "\n",
-    "matches = index.query(\n",
-    "    namespace=\"ns1\",\n",
-    "    vector=embeddings.tolist(),\n",
-    "    top_k=50,\n",
-    "    include_values=False\n",
+    "query = \"find unused packages\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "07ebd4fd-a0b9-4958-8325-bdff4be45a66",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_matches = vector_database_interface.find_similar(query, top_k=100)\n",
+    "df_matches = df_matches.join(df, how=\"left\", on=\"name\")\n",
+    "df_matches = df_matches.sort(\"similarity\", descending=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fa071203-a3cd-4e80-a7b7-0ac7562bef8d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_matches"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8b7b28e7-495c-44db-a939-dfa3e2c45159",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Rank the columns\n",
+    "df_matches = df_matches.with_columns(\n",
+    "    rank_similarity=pl.col(\"similarity\").rank(\"dense\", descending=False),\n",
+    "    rank_weekly_downloads=pl.col(\"weekly_downloads\").rank(\"dense\", descending=False)\n",
     ")\n",
     "\n",
-    "df_matches = pl.from_dicts([{'name' : x['id'], 'similarity': x['score']} for x in matches['matches']])\n",
+    "df_matches = df_matches.with_columns(\n",
+    "    normalized_similarity=(pl.col(\"rank_similarity\") - 1) / (df_matches['rank_similarity'].max() - 1),\n",
+    "    normalized_weekly_downloads=(pl.col(\"rank_weekly_downloads\") - 1) / (df_matches['rank_weekly_downloads'].max() - 1)\n",
+    ")\n",
     "\n",
-    "df_matches = df_matches.join(df, how = 'left', on = 'name')\n",
+    "df_matches = df_matches.with_columns(\n",
+    "    score=0.5 * pl.col(\"normalized_similarity\") + 0.5 * pl.col(\"normalized_weekly_downloads\")\n",
+    ")\n",
     "\n",
-    "df_matches.sort('weekly_downloads', descending=True)\n",
-    "\n"
+    "# Sort the DataFrame by the combined score in descending order\n",
+    "df_matches = df_matches.sort(\"score\", descending=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d5465cec-c717-4fc5-aa55-c4c7dc9e79cf",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_matches.sort(\"score\", descending=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4384f057-8eaf-431d-a31a-f4f7e203ed35",
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
 "metadata": {

pypi_llm/api/main.py (+9 -9)

@@ -6,18 +6,16 @@
 from sentence_transformers import SentenceTransformer
 
 from pypi_llm.config import Config
+from pypi_llm.utils.score_calculator import calculate_score
 from pypi_llm.vector_database import VectorDatabaseInterface
 
 app = FastAPI()
 
-# Load environment variables
 load_dotenv()
 config = Config()
 
-# Setup CORS
 origins = [
     "http://localhost:3000",
-    # Add other origins if needed
 ]
 
 app.add_middleware(
@@ -28,11 +26,9 @@
     allow_headers=["*"],
 )
 
-# Load dataset and model
 df = pl.read_csv(config.DATA_DIR / config.PROCESSED_DATASET_CSV_NAME)
 model = SentenceTransformer(config.EMBEDDINGS_MODEL_NAME)
 
-# Initialize vector database interface
 vector_database_interface = VectorDatabaseInterface(
     pinecone_token=config.PINECONE_TOKEN,
     pinecone_index_name=config.PINECONE_INDEX_NAME,
@@ -41,9 +37,9 @@
 )
 
 
-# Define request and response models
 class QueryModel(BaseModel):
     query: str
+    top_k: int = 30
 
 
 class Match(BaseModel):
@@ -57,10 +53,14 @@ class SearchResponse(BaseModel):
     matches: list[Match]
 
 
-# Define search endpoint
@app.post("/search/", response_model=SearchResponse)
 async def search(query: QueryModel):
-    df_matches = vector_database_interface.find_similar(query.query, top_k=50)
+    df_matches = vector_database_interface.find_similar(query.query, top_k=query.top_k * 2)
     df_matches = df_matches.join(df, how="left", on="name")
-    df_matches = df_matches.sort("similarity", descending=True)
+
+    df_matches = calculate_score(df_matches)
+    df_matches = df_matches.sort("score", descending=True)
+    df_matches = df_matches.head(query.top_k)
+
+    print("sending")
     return SearchResponse(matches=df_matches.to_dicts())
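
The endpoint now over-fetches (top_k * 2 candidates) before re-scoring, so packages that rank high on downloads but slightly lower on raw similarity can still make the final top_k cut. A minimal way to exercise it, assuming the app is served locally on the default uvicorn port (the host, port, and printed Match fields are assumptions; name and similarity are the columns the diff shows being built):

import requests

# Hypothetical local setup: uvicorn pypi_llm.api.main:app --port 8000
response = requests.post(
    "http://localhost:8000/search/",
    json={"query": "a package that creates plots and beautiful visualizations", "top_k": 30},
)
response.raise_for_status()
for match in response.json()["matches"]:
    # Fields assumed from the dataframe feeding SearchResponse.
    print(match["name"], match["similarity"])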

pypi_llm/scripts/upsert_data.py (+1 -1)

@@ -27,6 +27,6 @@
 )
 
 df = df.with_columns(
-    summary_and_description_cleaned=pl.concat_str(pl.col("summary"), pl.lit(" "), pl.col("description_cleaned"))
+    summary_and_description_cleaned=pl.concat_str(pl.col("summary"), pl.lit(" - "), pl.col("description_cleaned"))
 )
 vector_database_interface.upsert_polars(df, key_column="name", text_column="summary_and_description_cleaned")
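
The only change here is the separator: the summary and the cleaned description are now joined with " - " instead of a bare space, which keeps the two fields visually distinct in the text that gets embedded. A toy illustration on made-up data:

import polars as pl

df = pl.DataFrame({
    "summary": ["Fast plotting library"],
    "description_cleaned": ["Create beautiful interactive visualizations."],
})
df = df.with_columns(
    summary_and_description_cleaned=pl.concat_str(
        pl.col("summary"), pl.lit(" - "), pl.col("description_cleaned")
    )
)
print(df["summary_and_description_cleaned"][0])
# Fast plotting library - Create beautiful interactive visualizations.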

pypi_llm/utils/score_calculator.py (+34, new file)

@@ -0,0 +1,34 @@
+import polars as pl
+
+
+def calculate_score(df: pl.DataFrame, weight_similarity=0.5, weight_weekly_downloads=0.5) -> pl.DataFrame:
+    """
+    Calculate a combined score based on similarity and weekly downloads.
+
+    The function ranks the similarity and weekly downloads, normalizes these ranks to a [0, 1] scale,
+    and then computes a combined score based on the provided weights for similarity and weekly downloads.
+    The DataFrame is sorted by the combined score in descending order.
+
+    Args:
+        df (pl.DataFrame): DataFrame containing 'similarity' and 'weekly_downloads' columns.
+        weight_similarity (float): Weight for the similarity score in the combined score calculation. Default is 0.5.
+        weight_weekly_downloads (float): Weight for the weekly downloads score in the combined score calculation. Default is 0.5.
+
+    """
+    df = df.with_columns(
+        rank_similarity=pl.col("similarity").rank("dense", descending=False),
+        rank_weekly_downloads=pl.col("weekly_downloads").rank("dense", descending=False),
+    )
+
+    df = df.with_columns(
+        normalized_similarity=(pl.col("rank_similarity") - 1) / (df["rank_similarity"].max() - 1),
+        normalized_weekly_downloads=(pl.col("rank_weekly_downloads") - 1) / (df["rank_weekly_downloads"].max() - 1),
+    )
+
+    df = df.with_columns(
+        score=weight_similarity * pl.col("normalized_similarity")
+        + weight_weekly_downloads * pl.col("normalized_weekly_downloads")
+    )
+
+    df = df.sort("score", descending=True)
+    return df
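
A quick worked example of the new helper on made-up data shows why rank-based normalization is used: dense ranks flatten the heavy-tailed download counts, so one hugely popular package cannot drown out similarity entirely:

import polars as pl
from pypi_llm.utils.score_calculator import calculate_score

df = pl.DataFrame({
    "name": ["a", "b", "c"],
    "similarity": [0.9, 0.8, 0.1],
    "weekly_downloads": [10, 1000, 500],
})
# Dense ranks: similarity -> a:3, b:2, c:1; downloads -> a:1, b:3, c:2.
# Normalized to [0, 1]: a -> (1.0, 0.0), b -> (0.5, 1.0), c -> (0.0, 0.5).
# Equal-weight scores: a = 0.50, b = 0.75, c = 0.25, so "b" ranks first.
print(calculate_score(df))

Note that the normalization divides by max(rank) - 1, so it assumes at least two distinct values per column; with the candidate pool the API fetches (top_k * 2 rows), that holds in practice.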
