Skip to content

Commit 8895875

Browse files
authored
Merge branch 'main' into dependabot/pip/src/backend/h11-0.16.0
2 parents 2ad7b4a + 2735ce2 commit 8895875

File tree

24 files changed

+576
-412
lines changed

24 files changed

+576
-412
lines changed

.devcontainer/devcontainer.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,15 @@
2929
"extensions": [
3030
"ms-python.python",
3131
"ms-python.vscode-pylance",
32+
"ms-python.vscode-python-envs",
3233
"charliermarsh.ruff",
3334
"mtxr.sqltools",
3435
"mtxr.sqltools-driver-pg",
36+
"esbenp.prettier-vscode",
37+
"mechatroner.rainbow-csv",
3538
"ms-vscode.vscode-node-azure-pack",
36-
"esbenp.prettier-vscode"
39+
"teamsdevapp.vscode-ai-foundry",
40+
"ms-windows-ai-studio.windows-ai-studio"
3741
],
3842
// Set *default* container specific settings.json values on container create.
3943
"settings": {

.github/workflows/app-tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ jobs:
123123
key: mypy${{ matrix.os }}-${{ matrix.python_version }}-${{ hashFiles('requirements-dev.txt', 'src/backend/requirements.txt', 'src/backend/pyproject.toml') }}
124124

125125
- name: Run MyPy
126-
run: python3 -m mypy .
126+
run: python3 -m mypy . --python-version ${{ matrix.python_version }}
127127

128128
- name: Run Pytest
129129
run: python3 -m pytest -s -vv --cov --cov-fail-under=85

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ lint.isort.known-first-party = ["fastapi_app"]
77

88
[tool.mypy]
99
check_untyped_defs = true
10-
python_version = 3.9
1110
exclude = [".venv/*"]
1211

1312
[tool.pytest.ini_options]

src/backend/fastapi_app/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ class State(TypedDict):
3434
@asynccontextmanager
3535
async def lifespan(app: fastapi.FastAPI) -> AsyncIterator[State]:
3636
context = await common_parameters()
37-
azure_credential = await get_azure_credential()
37+
azure_credential = None
38+
if (
39+
os.getenv("OPENAI_CHAT_HOST") == "azure"
40+
or os.getenv("OPENAI_EMBED_HOST") == "azure"
41+
or os.getenv("POSTGRES_HOST", "").endswith(".database.azure.com")
42+
):
43+
azure_credential = await get_azure_credential()
3844
engine = await create_postgres_engine_from_env(azure_credential)
3945
sessionmaker = await create_async_sessionmaker(engine)
4046
chat_client = await create_openai_chat_client(azure_credential)
@@ -53,6 +59,7 @@ def create_app(testing: bool = False):
5359
if not testing:
5460
load_dotenv(override=True)
5561
logging.basicConfig(level=logging.INFO)
62+
5663
# Turn off particularly noisy INFO level logs from Azure Core SDK:
5764
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)
5865
logging.getLogger("azure.identity").setLevel(logging.WARNING)

src/backend/fastapi_app/api_models.py

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from enum import Enum
22
from typing import Any, Optional
33

4-
from openai.types.chat import ChatCompletionMessageParam
5-
from pydantic import BaseModel
4+
from openai.types.responses import ResponseInputItemParam
5+
from pydantic import BaseModel, Field
66

77

88
class AIChatRoles(str, Enum):
@@ -36,19 +36,39 @@ class ChatRequestContext(BaseModel):
3636

3737

3838
class ChatRequest(BaseModel):
39-
messages: list[ChatCompletionMessageParam]
39+
messages: list[ResponseInputItemParam]
4040
context: ChatRequestContext
4141
sessionState: Optional[Any] = None
4242

4343

44+
class ItemPublic(BaseModel):
45+
id: int
46+
type: str
47+
brand: str
48+
name: str
49+
description: str
50+
price: float
51+
52+
def to_str_for_rag(self):
53+
return f"Name:{self.name} Description:{self.description} Price:{self.price} Brand:{self.brand} Type:{self.type}"
54+
55+
56+
class ItemWithDistance(ItemPublic):
57+
distance: float
58+
59+
def __init__(self, **data):
60+
super().__init__(**data)
61+
self.distance = round(self.distance, 2)
62+
63+
4464
class ThoughtStep(BaseModel):
4565
title: str
4666
description: Any
4767
props: dict = {}
4868

4969

5070
class RAGContext(BaseModel):
51-
data_points: dict[int, dict[str, Any]]
71+
data_points: dict[int, ItemPublic]
5272
thoughts: list[ThoughtStep]
5373
followup_questions: Optional[list[str]] = None
5474

@@ -69,27 +89,39 @@ class RetrievalResponseDelta(BaseModel):
6989
sessionState: Optional[Any] = None
7090

7191

72-
class ItemPublic(BaseModel):
73-
id: int
74-
type: str
75-
brand: str
76-
name: str
77-
description: str
78-
price: float
79-
80-
81-
class ItemWithDistance(ItemPublic):
82-
distance: float
83-
84-
def __init__(self, **data):
85-
super().__init__(**data)
86-
self.distance = round(self.distance, 2)
87-
88-
8992
class ChatParams(ChatRequestOverrides):
9093
prompt_template: str
9194
response_token_limit: int = 1024
9295
enable_text_search: bool
9396
enable_vector_search: bool
9497
original_user_query: str
95-
past_messages: list[ChatCompletionMessageParam]
98+
past_messages: list[ResponseInputItemParam]
99+
100+
101+
class Filter(BaseModel):
102+
column: str
103+
comparison_operator: str
104+
value: Any
105+
106+
107+
class PriceFilter(Filter):
108+
column: str = Field(default="price", description="The column to filter on (always 'price' for this filter)")
109+
comparison_operator: str = Field(description="The operator for price comparison ('>', '<', '>=', '<=', '=')")
110+
value: float = Field(description="The price value to compare against (e.g., 30.00)")
111+
112+
113+
class BrandFilter(Filter):
114+
column: str = Field(default="brand", description="The column to filter on (always 'brand' for this filter)")
115+
comparison_operator: str = Field(description="The operator for brand comparison ('=' or '!=')")
116+
value: str = Field(description="The brand name to compare against (e.g., 'AirStrider')")
117+
118+
119+
class SearchResults(BaseModel):
120+
query: str
121+
"""The original search query"""
122+
123+
items: list[ItemPublic]
124+
"""List of items that match the search query and filters"""
125+
126+
filters: list[Filter]
127+
"""List of filters applied to the search results"""

src/backend/fastapi_app/openai_clients.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99

1010

1111
async def create_openai_chat_client(
12-
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
12+
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
1313
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
1414
openai_chat_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
1515
OPENAI_CHAT_HOST = os.getenv("OPENAI_CHAT_HOST")
1616
if OPENAI_CHAT_HOST == "azure":
17-
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-03-01-preview"
17+
api_version = os.environ["AZURE_OPENAI_VERSION"] or "2024-10-21"
1818
azure_endpoint = os.environ["AZURE_OPENAI_ENDPOINT"]
1919
azure_deployment = os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT"]
2020
if api_key := os.getenv("AZURE_OPENAI_KEY"):
@@ -29,7 +29,7 @@ async def create_openai_chat_client(
2929
azure_deployment=azure_deployment,
3030
api_key=api_key,
3131
)
32-
else:
32+
elif azure_credential:
3333
logger.info(
3434
"Setting up Azure OpenAI client for chat completions using Azure Identity, endpoint %s, deployment %s",
3535
azure_endpoint,
@@ -44,6 +44,8 @@ async def create_openai_chat_client(
4444
azure_deployment=azure_deployment,
4545
azure_ad_token_provider=token_provider,
4646
)
47+
else:
48+
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
4749
elif OPENAI_CHAT_HOST == "ollama":
4850
logger.info("Setting up OpenAI client for chat completions using Ollama")
4951
openai_chat_client = openai.AsyncOpenAI(
@@ -67,7 +69,7 @@ async def create_openai_chat_client(
6769

6870

6971
async def create_openai_embed_client(
70-
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential],
72+
azure_credential: Union[azure.identity.AzureDeveloperCliCredential, azure.identity.ManagedIdentityCredential, None],
7173
) -> Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]:
7274
openai_embed_client: Union[openai.AsyncAzureOpenAI, openai.AsyncOpenAI]
7375
OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST")
@@ -87,7 +89,7 @@ async def create_openai_embed_client(
8789
azure_deployment=azure_deployment,
8890
api_key=api_key,
8991
)
90-
else:
92+
elif azure_credential:
9193
logger.info(
9294
"Setting up Azure OpenAI client for embeddings using Azure Identity, endpoint %s, deployment %s",
9395
azure_endpoint,
@@ -102,6 +104,8 @@ async def create_openai_embed_client(
102104
azure_deployment=azure_deployment,
103105
azure_ad_token_provider=token_provider,
104106
)
107+
else:
108+
raise ValueError("Azure OpenAI client requires either an API key or Azure Identity credential.")
105109
elif OPENAI_EMBED_HOST == "ollama":
106110
logger.info("Setting up OpenAI client for embeddings using Ollama")
107111
openai_embed_client = openai.AsyncOpenAI(

src/backend/fastapi_app/postgres_searcher.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sqlalchemy import Float, Integer, column, select, text
66
from sqlalchemy.ext.asyncio import AsyncSession
77

8+
from fastapi_app.api_models import Filter
89
from fastapi_app.embeddings import compute_text_embedding
910
from fastapi_app.postgres_models import Item
1011

@@ -26,21 +27,24 @@ def __init__(
2627
self.embed_dimensions = embed_dimensions
2728
self.embedding_column = embedding_column
2829

29-
def build_filter_clause(self, filters) -> tuple[str, str]:
30+
def build_filter_clause(self, filters: Optional[list[Filter]]) -> tuple[str, str]:
3031
if filters is None:
3132
return "", ""
3233
filter_clauses = []
3334
for filter in filters:
34-
if isinstance(filter["value"], str):
35-
filter["value"] = f"'{filter['value']}'"
36-
filter_clauses.append(f"{filter['column']} {filter['comparison_operator']} {filter['value']}")
35+
filter_value = f"'{filter.value}'" if isinstance(filter.value, str) else filter.value
36+
filter_clauses.append(f"{filter.column} {filter.comparison_operator} {filter_value}")
3737
filter_clause = " AND ".join(filter_clauses)
3838
if len(filter_clause) > 0:
3939
return f"WHERE {filter_clause}", f"AND {filter_clause}"
4040
return "", ""
4141

4242
async def search(
43-
self, query_text: Optional[str], query_vector: list[float], top: int = 5, filters: Optional[list[dict]] = None
43+
self,
44+
query_text: Optional[str],
45+
query_vector: list[float],
46+
top: int = 5,
47+
filters: Optional[list[Filter]] = None,
4448
):
4549
filter_clause_where, filter_clause_and = self.build_filter_clause(filters)
4650
table_name = Item.__tablename__
@@ -106,7 +110,7 @@ async def search_and_embed(
106110
top: int = 5,
107111
enable_vector_search: bool = False,
108112
enable_text_search: bool = False,
109-
filters: Optional[list[dict]] = None,
113+
filters: Optional[list[Filter]] = None,
110114
) -> list[Item]:
111115
"""
112116
Search rows by query text. Optionally converts the query text to a vector if enable_vector_search is True.
Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching database rows.
2-
You have access to an Azure PostgreSQL database with an items table that has columns for title, description, brand, price, and type.
3-
Generate a search query based on the conversation and the new question.
4-
If the question is not in English, translate the question to English before generating the search query.
5-
If you cannot generate a search query, return the original user question.
6-
DO NOT return anything besides the query.
1+
Your job is to find search results based off the user's question and past messages.
2+
You have access to only these tools:
3+
1. **search_database**: This tool allows you to search a table for items based on a query.
4+
You can pass in a search query and optional filters.
5+
Once you get the search results, you're done.
Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,36 @@
11
[
2-
{"role": "user", "content": "good options for climbing gear that can be used outside?"},
3-
{"role": "assistant", "tool_calls": [
4-
{
5-
"id": "call_abc123",
6-
"type": "function",
7-
"function": {
8-
"arguments": "{\"search_query\":\"climbing gear outside\"}",
9-
"name": "search_database"
10-
}
11-
}
12-
]},
13-
{
14-
"role": "tool",
15-
"tool_call_id": "call_abc123",
16-
"content": "Search results for climbing gear that can be used outside: ..."
17-
},
18-
{"role": "user", "content": "are there any shoes less than $50?"},
19-
{"role": "assistant", "tool_calls": [
20-
{
21-
"id": "call_abc456",
22-
"type": "function",
23-
"function": {
24-
"arguments": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
25-
"name": "search_database"
26-
}
27-
}
28-
]},
29-
{
30-
"role": "tool",
31-
"tool_call_id": "call_abc456",
32-
"content": "Search results for shoes cheaper than 50: ..."
33-
}
2+
{
3+
"role": "user",
4+
"content": "good options for climbing gear that can be used outside?"
5+
},
6+
{
7+
"id": "madeup",
8+
"call_id": "call_abc123",
9+
"name": "search_database",
10+
"arguments": "{\"search_query\":\"climbing gear outside\"}",
11+
"type": "function_call"
12+
},
13+
{
14+
"id": "madeupoutput",
15+
"call_id": "call_abc123",
16+
"output": "Search results for climbing gear that can be used outside: ...",
17+
"type": "function_call_output"
18+
},
19+
{
20+
"role": "user",
21+
"content": "are there any shoes less than $50?"
22+
},
23+
{
24+
"id": "madeup",
25+
"call_id": "call_abc456",
26+
"name": "search_database",
27+
"arguments": "{\"search_query\":\"shoes\",\"price_filter\":{\"comparison_operator\":\"<\",\"value\":50}}",
28+
"type": "function_call"
29+
},
30+
{
31+
"id": "madeupoutput",
32+
"call_id": "call_abc456",
33+
"output": "Search results for shoes cheaper than 50: ...",
34+
"type": "function_call_output"
35+
}
3436
]

0 commit comments

Comments
 (0)