Skip to content

Commit 34271d4

Browse files
committed
Adjust filters after merge
1 parent aa4f40e commit 34271d4

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

src/backend/fastapi_app/api_models.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class ChatRequest(BaseModel):
4141
context: ChatRequestContext
4242
sessionState: Optional[Any] = None
4343

44-
44+
4545
class ItemPublic(BaseModel):
4646
id: int
4747
name: str
@@ -50,13 +50,16 @@ class ItemPublic(BaseModel):
5050
rating: int
5151
price_level: int
5252
review_count: int
53-
hours: int
54-
tags: str
53+
hours: str
54+
tags: list[str]
5555
description: str
5656
menu_summary: str
5757
top_reviews: str
5858
vibe: str
5959

60+
def to_str_for_rag(self):
61+
return f"Name:{self.name} Description:{self.description} Location:{self.location} Cuisine:{self.cuisine} Rating:{self.rating} Price Level:{self.price_level} Review Count:{self.review_count} Hours:{self.hours} Tags:{self.tags} Menu Summary:{self.menu_summary} Top Reviews:{self.top_reviews} Vibe:{self.vibe}" # noqa: E501
62+
6063

6164
class ItemWithDistance(ItemPublic):
6265
distance: float
@@ -110,7 +113,9 @@ class Filter(BaseModel):
110113

111114

112115
class PriceLevelFilter(Filter):
113-
column: str = Field(default="price_level", description="The column to filter on (always 'price_level' for this filter)")
116+
column: str = Field(
117+
default="price_level", description="The column to filter on (always 'price_level' for this filter)"
118+
)
114119
comparison_operator: str = Field(description="The operator for price level comparison ('>', '<', '>=', '<=', '=')")
115120
value: float = Field(description="Value to compare against, either 1, 2, 3, 4")
116121

src/backend/fastapi_app/rag_advanced.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111

1212
from fastapi_app.api_models import (
1313
AIChatRoles,
14-
BrandFilter,
1514
ChatRequestOverrides,
1615
Filter,
1716
ItemPublic,
1817
Message,
19-
PriceFilter,
18+
PriceLevelFilter,
2019
RAGContext,
20+
RatingFilter,
2121
RetrievalResponse,
2222
RetrievalResponseDelta,
2323
SearchResults,
@@ -75,8 +75,8 @@ async def search_database(
7575
self,
7676
ctx: RunContext[ChatParams],
7777
search_query: str,
78-
price_filter: Optional[PriceFilter] = None,
79-
brand_filter: Optional[BrandFilter] = None,
78+
price_filter: Optional[PriceLevelFilter] = None,
79+
brand_filter: Optional[RatingFilter] = None,
8080
) -> SearchResults:
8181
"""
8282
Search PostgreSQL database for relevant products based on user query

0 commit comments

Comments
 (0)