diff --git a/.env.example b/.env.example index 8df49676..bb33262e 100644 --- a/.env.example +++ b/.env.example @@ -30,6 +30,7 @@ POSTGRES_PORT=5432 POSTGRES_DB=ai_platform POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres +POSTGRES_DB_TEST=ai_platform_test SENTRY_DSN= diff --git a/.github/workflows/continuous_integration.yml b/.github/workflows/continuous_integration.yml index 5b40cdd3..e6373ac4 100644 --- a/.github/workflows/continuous_integration.yml +++ b/.github/workflows/continuous_integration.yml @@ -18,7 +18,11 @@ jobs: POSTGRES_DB: ai_platform ports: - 5432:5432 - options: --health-cmd "pg_isready -U postgres" --health-interval 10s --health-timeout 5s --health-retries 5 + options: >- + --health-cmd "pg_isready -U postgres" + --health-interval 10s + --health-timeout 5s + --health-retries 5 strategy: matrix: diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 24779bf3..32009e0d 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -47,6 +47,7 @@ class Settings(BaseSettings): POSTGRES_USER: str POSTGRES_PASSWORD: str = "" POSTGRES_DB: str = "" + POSTGRES_DB_TEST: str | None = None @computed_field # type: ignore[prop-decorator] @property @@ -60,6 +61,22 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: path=self.POSTGRES_DB, ) + @computed_field # type: ignore[prop-decorator] + @property + def SQLALCHEMY_TEST_DATABASE_URI(self) -> PostgresDsn: + if not self.POSTGRES_DB_TEST: + raise ValueError( + "POSTGRES_DB_TEST is not set but is required for test configuration." + ) + return MultiHostUrl.build( + scheme="postgresql+psycopg", + username=self.POSTGRES_USER, + password=self.POSTGRES_PASSWORD, + host=self.POSTGRES_SERVER, + port=self.POSTGRES_PORT, + path=self.POSTGRES_DB_TEST, + ) + EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48 EMAIL_TEST_USER: EmailStr diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index e2a6464d..f33ae916 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -2,42 +2,59 @@ import pytest from fastapi.testclient import TestClient -from sqlmodel import Session, delete +from sqlmodel import Session, SQLModel, create_engine, text from app.core.config import settings from app.core.db import engine, init_db from app.main import app -from app.models import ( - APIKey, - Assistant, - Organization, - Project, - ProjectUser, - User, - OpenAI_Thread, - Credential, - Collection, -) +from app.api.deps import get_db from app.tests.utils.user import authentication_token_from_email from app.tests.utils.utils import get_superuser_token_headers +def recreate_test_db(): + test_db_name = settings.POSTGRES_DB_TEST + if test_db_name is None: + raise ValueError( + "POSTGRES_DB_TEST is not set but is required for test configuration." + ) + with engine.connect() as conn: + conn.execution_options(isolation_level="AUTOCOMMIT") + # Disconnect other connections to the test DB + conn.execute( + text( + f""" + SELECT pg_terminate_backend(pid) + FROM pg_stat_activity + WHERE datname = '{test_db_name}' AND pid <> pg_backend_pid() + """ + ) + ) + conn.execute(text(f"DROP DATABASE IF EXISTS {test_db_name}")) + conn.execute(text(f"CREATE DATABASE {test_db_name}")) + + +recreate_test_db() +test_engine = create_engine(str(settings.SQLALCHEMY_TEST_DATABASE_URI)) + + @pytest.fixture(scope="session", autouse=True) def db() -> Generator[Session, None, None]: - with Session(engine) as session: + with Session(test_engine) as session: + SQLModel.metadata.create_all(test_engine) init_db(session) yield session - # Delete data in reverse dependency order - session.execute(delete(ProjectUser)) # Many-to-many relationship - session.execute(delete(Assistant)) - session.execute(delete(Credential)) - session.execute(delete(Project)) - session.execute(delete(Organization)) - session.execute(delete(APIKey)) - session.execute(delete(User)) - session.execute(delete(OpenAI_Thread)) - session.execute(delete(Collection)) - session.commit() + + +# Override the get_db dependency to use test session +@pytest.fixture(scope="session", autouse=True) +def override_get_db(db: Session): + def _get_test_db(): + yield db + + app.dependency_overrides[get_db] = _get_test_db + yield + app.dependency_overrides.clear() @pytest.fixture(scope="module") diff --git a/docker-compose.yml b/docker-compose.yml index ec6e9deb..abe7e21d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -69,6 +69,7 @@ services: - POSTGRES_SERVER=db - POSTGRES_PORT=${POSTGRES_PORT} - POSTGRES_DB=${POSTGRES_DB} + - POSTGRES_DB_TEST=${POSTGRES_DB_TEST} - POSTGRES_USER=${POSTGRES_USER?Variable not set} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD?Variable not set} - SENTRY_DSN=${SENTRY_DSN} @@ -108,6 +109,7 @@ services: - POSTGRES_SERVER=db - POSTGRES_PORT=${POSTGRES_PORT} - POSTGRES_DB=${POSTGRES_DB} + - POSTGRES_DB_TEST=${POSTGRES_DB_TEST} - POSTGRES_USER=${POSTGRES_USER?Variable not set} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD?Variable not set} - SENTRY_DSN=${SENTRY_DSN}