Improve and tests #55

Merged: 31 commits, Jul 18, 2024
Changes from 8 commits
676a2eb
add pytest dependency
john0isaac Jun 28, 2024
296db5c
restrict workflow run and cache dependencies for faster reruns
john0isaac Jun 28, 2024
d41a6df
mypy fixes
john0isaac Jun 28, 2024
44446bb
add dataclasses for response for better typing and easier usage
john0isaac Jun 28, 2024
8c58aac
setup test client for database and without database with test coverag…
john0isaac Jun 28, 2024
a0522e6
fix tests for windows and macos
john0isaac Jun 28, 2024
06c424b
use basemodel instead of dataclass to match the other models and for …
john0isaac Jun 29, 2024
cf3bc78
fix scopes and add app fixture
john0isaac Jun 29, 2024
4d9a536
fix tests to use existing setup
john0isaac Jul 1, 2024
7573249
add mocks and use monkey patch for setting env vars
john0isaac Jul 5, 2024
9fb9e54
create database and seed data
john0isaac Jul 5, 2024
61f9b8f
add tests for items handler and similar
john0isaac Jul 5, 2024
86b7986
add search_handler tests
john0isaac Jul 5, 2024
610a418
add chat tests
john0isaac Jul 5, 2024
eac85f0
remove content length assertion to allow for flexibility in response …
john0isaac Jul 6, 2024
b36a260
add azure openai env vars and use session scoped fixtures
john0isaac Jul 6, 2024
ea722d7
Typing improvements
pamelafox Jul 11, 2024
addd3da
fix mypy module error
john0isaac Jul 12, 2024
5261ab2
typing improvements
john0isaac Jul 12, 2024
39abb0f
ignore pgvector from mypy checks
john0isaac Jul 12, 2024
121b341
follow sqlalchemy example for using columns()
john0isaac Jul 12, 2024
d0a9a02
reimplement abstract functions
john0isaac Jul 12, 2024
c5d99f5
fix typo in env var name
john0isaac Jul 12, 2024
5c17e9a
use fastapi dependency instead of global storage
john0isaac Jul 12, 2024
3a6076c
fix type and add mypy to tests
john0isaac Jul 12, 2024
da5220c
remove multiple env loading, use single azure_credentials
john0isaac Jul 12, 2024
79a8a2d
use app state to store global vars
john0isaac Jul 13, 2024
3f0286b
add pydantic types for Item table
john0isaac Jul 13, 2024
02d0b1b
add more tests and fix azure credentials mocking
john0isaac Jul 15, 2024
17fc97f
apply feedback from pr review
john0isaac Jul 17, 2024
b2bb121
add postgres searcher tests
john0isaac Jul 17, 2024
10 changes: 8 additions & 2 deletions .github/workflows/app-tests.yaml
@@ -5,9 +5,13 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:

permissions:
contents: read

jobs:
test_package:
test-package:
name: Test ${{ matrix.os }} Python ${{ matrix.python_version }}
runs-on: ${{ matrix.os }}
strategy:
@@ -65,4 +69,6 @@ jobs:
run: |
cd ./src/frontend
npm install
npm run build
npm run build
- name: Run Pytest
run: python3 -m pytest
19 changes: 15 additions & 4 deletions .github/workflows/python-code-quality.yaml
@@ -3,23 +3,34 @@ name: Python code quality
on:
push:
branches: [ main ]
paths:
- '**.py'

pull_request:
branches: [ main ]
paths:
- '**.py'

workflow_dispatch:

permissions:
contents: read

jobs:
build:
checks-format-and-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python 3
uses: actions/setup-python@v5
with:
python-version: "3.12"
cache: 'pip'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
python3 -m pip install --upgrade pip
python3 -m pip install ruff
- name: Lint with ruff
run: ruff check .
- name: Check formatting with ruff
run: ruff format --check .
run: ruff format . --check
22 changes: 16 additions & 6 deletions pyproject.toml
@@ -1,10 +1,20 @@
[tool.ruff]
line-length = 120
target-version = "py311"
target-version = "py312"
lint.select = ["E", "F", "I", "UP"]
lint.ignore = ["D203"]
lint.isort.known-first-party = ["fastapi_app"]

[tool.ruff.lint]
select = ["E", "F", "I", "UP"]
ignore = ["D203"]
[tool.mypy]
check_untyped_defs = true
python_version = 3.12
exclude = [".venv/*"]

[tool.ruff.lint.isort]
known-first-party = ["fastapi_app"]
[tool.pytest.ini_options]
addopts = "-ra --cov"
testpaths = ["tests"]
pythonpath = ['src']
filterwarnings = ["ignore::DeprecationWarning"]

[tool.coverage.report]
show_missing = true
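With the `[tool.pytest.ini_options]` table above, a plain `python3 -m pytest` run picks up `tests/`, coverage reporting, and the `src` import path automatically. The commits also mention using monkeypatch for env vars; a minimal sketch of that pattern (the variable names here are illustrative, not taken from this PR):

```python
import os

import pytest


def set_fake_env(monkeypatch: pytest.MonkeyPatch) -> None:
    # Hypothetical variable names; the real suite configures Azure OpenAI
    # and Postgres settings for the app under test.
    monkeypatch.setenv("POSTGRES_HOST", "localhost")
    monkeypatch.setenv("OPENAI_API_KEY", "test-key")


@pytest.fixture
def fake_env(monkeypatch: pytest.MonkeyPatch) -> None:
    # pytest undoes these changes automatically when each test finishes
    set_fake_env(monkeypatch)


def test_reads_env(fake_env: None) -> None:
    assert os.environ["OPENAI_API_KEY"] == "test-key"
```

Because the fixture goes through `monkeypatch`, no test can leak environment state into another, which matters once fixtures become session scoped.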
6 changes: 5 additions & 1 deletion requirements-dev.txt
@@ -2,4 +2,8 @@
ruff
pre-commit
pip-tools
pip-compile-cross-platform
pip-compile-cross-platform
pytest
pytest-cov
pytest-asyncio
psycopg2-binary
12 changes: 12 additions & 0 deletions src/fastapi_app/api_models.py
@@ -17,3 +17,15 @@ class ThoughtStep(BaseModel):
title: str
description: Any
props: dict = {}


class RAGContext(BaseModel):
data_points: dict[int, dict[str, Any]]
thoughts: list[ThoughtStep]
followup_questions: list[str] | None = None


class RetrievalResponse(BaseModel):
message: Message
context: RAGContext
session_state: Any | None = None
6 changes: 4 additions & 2 deletions src/fastapi_app/api_routes.py
@@ -9,6 +9,8 @@
from fastapi_app.rag_advanced import AdvancedRAGChat
from fastapi_app.rag_simple import SimpleRAGChat

from .api_models import RetrievalResponse

router = fastapi.APIRouter()


@@ -52,7 +54,7 @@ async def search_handler(query: str, top: int = 5, enable_vector_search: bool =
return [item.to_dict() for item in results]


@router.post("/chat")
@router.post("/chat", response_model=RetrievalResponse)
async def chat_handler(chat_request: ChatRequest):
messages = [message.model_dump() for message in chat_request.messages]
overrides = chat_request.context.get("overrides", {})
@@ -79,5 +81,5 @@ async def chat_handler(chat_request: ChatRequest):
chat_deployment=global_storage.openai_chat_deployment,
)

response = await ragchat.run(messages, overrides=overrides)
response: RetrievalResponse = await ragchat.run(messages, overrides=overrides)
return response
2 changes: 1 addition & 1 deletion src/fastapi_app/embeddings.py
@@ -4,7 +4,7 @@


async def compute_text_embedding(
q: str, openai_client, embed_model: str, embed_deployment: str = None, embedding_dimensions: int = 1536
q: str, openai_client, embed_model: str, embed_deployment: str | None = None, embedding_dimensions: int = 1536
):
SUPPORTED_DIMENSIONS_MODEL = {
"text-embedding-ada-002": False,
2 changes: 1 addition & 1 deletion src/fastapi_app/postgres_searcher.py
@@ -103,7 +103,7 @@ async def search(

async def search_and_embed(
self,
query_text: str,
query_text: str | None = None,
top: int = 5,
enable_vector_search: bool = False,
enable_text_search: bool = False,
36 changes: 17 additions & 19 deletions src/fastapi_app/rag_advanced.py
@@ -5,12 +5,10 @@
)

from openai import AsyncOpenAI
from openai.types.chat import (
ChatCompletion,
)
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
from openai_messages_token_helper import build_messages, get_token_limit

from .api_models import ThoughtStep
from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
from .postgres_searcher import PostgresSearcher
from .query_rewriter import build_search_function, extract_search_arguments

@@ -35,7 +33,7 @@ def __init__(

async def run(
self, messages: list[dict], overrides: dict[str, Any] = {}
) -> dict[str, Any] | AsyncGenerator[dict[str, Any], None]:
) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]:
text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
top = overrides.get("top", 3)
@@ -45,7 +43,7 @@ async def run(

# Generate an optimized keyword search query based on the chat history and the last question
query_response_token_limit = 500
query_messages = build_messages(
query_messages: list[ChatCompletionMessageParam] = build_messages(
model=self.chat_model,
system_prompt=self.query_prompt_template,
new_user_content=original_user_query,
@@ -55,7 +53,7 @@ )
)

chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
messages=query_messages, # type: ignore
messages=query_messages,
# Azure OpenAI takes the deployment name as the model name
model=self.chat_deployment if self.chat_deployment else self.chat_model,
temperature=0.0, # Minimize creativity for search query generation
@@ -81,7 +79,7 @@

# Generate a contextual and content specific answer using the search results and chat history
response_token_limit = 1024
messages = build_messages(
contextual_messages: list[ChatCompletionMessageParam] = build_messages(
model=self.chat_model,
system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
new_user_content=original_user_query + "\n\nSources:\n" + content,
@@ -90,21 +88,21 @@ async def run(
fallback_to_default=True,
)

chat_completion_response = await self.openai_chat_client.chat.completions.create(
chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create(
# Azure OpenAI takes the deployment name as the model name
model=self.chat_deployment if self.chat_deployment else self.chat_model,
messages=messages,
messages=contextual_messages,
temperature=overrides.get("temperature", 0.3),
max_tokens=response_token_limit,
n=1,
stream=False,
)
first_choice = chat_completion_response.model_dump()["choices"][0]
return {
"message": first_choice["message"],
"context": {
"data_points": {item.id: item.to_dict() for item in results},
"thoughts": [
first_choice = chat_completion_response.choices[0]
return RetrievalResponse(
message=Message(content=first_choice.message.content, role=first_choice.message.role),
context=RAGContext(
data_points={item.id: item.to_dict() for item in results},
thoughts=[
ThoughtStep(
title="Prompt to generate search arguments",
description=[str(message) for message in query_messages],
@@ -130,13 +128,13 @@ async def run(
),
ThoughtStep(
title="Prompt to generate answer",
description=[str(message) for message in messages],
description=[str(message) for message in contextual_messages],
props=(
{"model": self.chat_model, "deployment": self.chat_deployment}
if self.chat_deployment
else {"model": self.chat_model}
),
),
],
},
}
),
)
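The diff above swaps `model_dump()["choices"][0]` dict plumbing for attribute access on the typed `ChatCompletion`, so mypy can check every field through to the `RetrievalResponse`. The shape of that change, sketched with a stand-in object rather than a real OpenAI client:

```python
from types import SimpleNamespace

# Stand-in for an openai ChatCompletion; the real object is returned by
# chat.completions.create(...) and exposes the same attributes.
completion = SimpleNamespace(
    choices=[
        SimpleNamespace(
            message=SimpleNamespace(content="Answer text", role="assistant")
        )
    ]
)

# Before: completion.model_dump()["choices"][0]["message"] — untyped dicts.
# After: attribute access keeps static types intact all the way through.
first_choice = completion.choices[0]
message = {"content": first_choice.message.content, "role": first_choice.message.role}
```

In the actual code these fields feed `Message(...)` inside `RetrievalResponse`, rather than a plain dict.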
29 changes: 15 additions & 14 deletions src/fastapi_app/rag_simple.py
@@ -5,9 +5,10 @@
)

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
from openai_messages_token_helper import build_messages, get_token_limit

from .api_models import ThoughtStep
from .api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
from .postgres_searcher import PostgresSearcher


@@ -30,7 +31,7 @@ def __init__(

async def run(
self, messages: list[dict], overrides: dict[str, Any] = {}
) -> dict[str, Any] | AsyncGenerator[dict[str, Any], None]:
) -> RetrievalResponse | AsyncGenerator[dict[str, Any], None]:
text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
top = overrides.get("top", 3)
@@ -48,7 +49,7 @@ async def run(

# Generate a contextual and content specific answer using the search results and chat history
response_token_limit = 1024
messages = build_messages(
contextual_messages: list[ChatCompletionMessageParam] = build_messages(
model=self.chat_model,
system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
new_user_content=original_user_query + "\n\nSources:\n" + content,
@@ -57,21 +58,21 @@
fallback_to_default=True,
)

chat_completion_response = await self.openai_chat_client.chat.completions.create(
chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create(
# Azure OpenAI takes the deployment name as the model name
model=self.chat_deployment if self.chat_deployment else self.chat_model,
messages=messages,
messages=contextual_messages,
temperature=overrides.get("temperature", 0.3),
max_tokens=response_token_limit,
n=1,
stream=False,
)
first_choice = chat_completion_response.model_dump()["choices"][0]
return {
"message": first_choice["message"],
"context": {
"data_points": {item.id: item.to_dict() for item in results},
"thoughts": [
first_choice = chat_completion_response.choices[0]
return RetrievalResponse(
message=Message(content=first_choice.message.content, role=first_choice.message.role),
context=RAGContext(
data_points={item.id: item.to_dict() for item in results},
thoughts=[
ThoughtStep(
title="Search query for database",
description=original_user_query if text_search else None,
@@ -87,13 +88,13 @@
),
ThoughtStep(
title="Prompt to generate answer",
description=[str(message) for message in messages],
description=[str(message) for message in contextual_messages],
props=(
{"model": self.chat_model, "deployment": self.chat_deployment}
if self.chat_deployment
else {"model": self.chat_model}
),
),
],
},
}
),
)