diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 43810c38..d19f0653 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -2,7 +2,8 @@ from typing import Annotated, Optional import jwt -from fastapi import Depends, HTTPException, status, Request, Header, Security +import logging +from fastapi import Depends, HTTPException, status, Request, Path, Query from fastapi.responses import JSONResponse from fastapi.security import OAuth2PasswordBearer, APIKeyHeader from jwt.exceptions import InvalidTokenError @@ -23,8 +24,12 @@ ProjectUser, Project, Organization, + APIKey, ) + +logger = logging.getLogger(__name__) + reusable_oauth2 = OAuth2PasswordBearer( tokenUrl=f"{settings.API_V1_STR}/login/access-token", auto_error=False ) @@ -147,6 +152,18 @@ def get_current_active_superuser_org(current_user: CurrentUserOrg) -> User: return current_user +def check_org_access(_current_user: CurrentUserOrg, org_id: int = Path): + """Helper function to check organization access.""" + if not _current_user.is_superuser and org_id != _current_user.organization_id: + logger.warning( + f"[check_org_access] Access violation | user_id={_current_user.id}, attempted_org_id={org_id}, " + f"current_org_id={_current_user.organization_id}" + ) + raise HTTPException( + status_code=403, detail="Access to this organization is forbidden" + ) + + def verify_user_project_organization( db: SessionDep, current_user: CurrentUserOrg, diff --git a/backend/app/api/routes/organization.py b/backend/app/api/routes/organization.py index 099494ee..b71ad7a4 100644 --- a/backend/app/api/routes/organization.py +++ b/backend/app/api/routes/organization.py @@ -11,11 +11,16 @@ OrganizationPublic, ) from app.api.deps import ( - CurrentUser, + CurrentUserOrg, SessionDep, get_current_active_superuser, + check_org_access, +) +from app.crud.organization import ( + create_organization, + get_organization_by_id, + validate_organization, ) -from app.crud.organization import create_organization, get_organization_by_id from app.utils import APIResponse router = APIRouter(prefix="/organizations", tags=["organizations"]) @@ -24,14 +29,28 @@ # Retrieve organizations @router.get( "/", - dependencies=[Depends(get_current_active_superuser)], response_model=APIResponse[List[OrganizationPublic]], ) -def read_organizations(session: SessionDep, skip: int = 0, limit: int = 100): - count_statement = select(func.count()).select_from(Organization) - count = session.exec(count_statement).one() +def read_organizations( + session: SessionDep, current_user: CurrentUserOrg, skip: int = 0, limit: int = 100 +): + """ + Return all organizations for superuser, + or only the one associated with the current user. + """ + + if current_user.is_superuser: + statement = select(Organization).where(Organization.is_active == True) + else: + if not current_user.organization_id: + return APIResponse.success_response([]) - statement = select(Organization).offset(skip).limit(limit) + statement = select(Organization).where( + Organization.id == current_user.organization_id, + Organization.is_deleted == False, + ) + + statement = statement.offset(skip).limit(limit) organizations = session.exec(statement).all() return APIResponse.success_response(organizations) @@ -44,37 +63,44 @@ def read_organizations(session: SessionDep, skip: int = 0, limit: int = 100): response_model=APIResponse[OrganizationPublic], ) def create_new_organization(*, session: SessionDep, org_in: OrganizationCreate): + """ + Creates a new organization, note that only a superuser can create an organization. + """ new_org = create_organization(session=session, org_create=org_in) return APIResponse.success_response(new_org) +# Retrieve an organization by ID @router.get( "/{org_id}", - dependencies=[Depends(get_current_active_superuser)], response_model=APIResponse[OrganizationPublic], ) -def read_organization(*, session: SessionDep, org_id: int): +def read_organization(*, session: SessionDep, org_id: int, _=Depends(check_org_access)): """ - Retrieve an organization by ID. + Retrieves an organization by ID, + a normal user can only retrieve the organization(s) associated with their user ID. """ - org = get_organization_by_id(session=session, org_id=org_id) - if org is None: - raise HTTPException(status_code=404, detail="Organization not found") + org = validate_organization(session=session, org_id=org_id) return APIResponse.success_response(org) # Update an organization @router.patch( "/{org_id}", - dependencies=[Depends(get_current_active_superuser)], response_model=APIResponse[OrganizationPublic], ) def update_organization( - *, session: SessionDep, org_id: int, org_in: OrganizationUpdate + *, + session: SessionDep, + org_id: int, + org_in: OrganizationUpdate, + _=Depends(check_org_access), ): - org = get_organization_by_id(session=session, org_id=org_id) - if org is None: - raise HTTPException(status_code=404, detail="Organization not found") + """ + Updates an organization by ID, + a normal user can only update the organization(s) associated with their user ID. + """ + org = validate_organization(session=session, org_id=org_id) org_data = org_in.model_dump(exclude_unset=True) org = org.model_copy(update=org_data) @@ -94,10 +120,10 @@ def update_organization( include_in_schema=False, ) def delete_organization(session: SessionDep, org_id: int): - org = get_organization_by_id(session=session, org_id=org_id) - if org is None: - raise HTTPException(status_code=404, detail="Organization not found") - + """ + Deletes an existing organization, note that only a superuser can delete an organization. + """ + org = validate_organization(session=session, org_id=org_id) session.delete(org) session.commit() diff --git a/backend/app/tests/api/routes/test_org.py b/backend/app/tests/api/routes/test_org.py index 709bb7f5..c43b6588 100644 --- a/backend/app/tests/api/routes/test_org.py +++ b/backend/app/tests/api/routes/test_org.py @@ -24,8 +24,20 @@ def test_organization(db: Session, superuser_token_headers: dict[str, str]): return organization -# Test retrieving organizations -def test_read_organizations(db: Session, superuser_token_headers: dict[str, str]): +@pytest.fixture +def other_organization(db: Session, superuser_token_headers: dict[str, str]): + unique_name = f"OtherOrg-{random_lower_string()}" + org_data = OrganizationCreate(name=unique_name, is_active=True) + organization = create_organization(session=db, org_create=org_data) + db.commit() + return organization + + +def test_read_organizations_as_superuser( + db: Session, + superuser_token_headers: dict[str, str], + test_organization: Organization, +): response = client.get( f"{settings.API_V1_STR}/organizations/", headers=superuser_token_headers ) @@ -33,10 +45,12 @@ def test_read_organizations(db: Session, superuser_token_headers: dict[str, str] response_data = response.json() assert "data" in response_data assert isinstance(response_data["data"], list) + assert len(response_data["data"]) > 0 -# Test creating an organization -def test_create_organization(db: Session, superuser_token_headers: dict[str, str]): +def test_create_organization_as_superuser( + db: Session, superuser_token_headers: dict[str, str] +): unique_name = f"Org-{random_lower_string()}" org_data = {"name": unique_name, "is_active": True} response = client.post( @@ -44,23 +58,22 @@ def test_create_organization(db: Session, superuser_token_headers: dict[str, str json=org_data, headers=superuser_token_headers, ) - assert 200 <= response.status_code < 300 created_org = response.json() - assert "data" in created_org # Make sure there's a 'data' field + assert "data" in created_org created_org_data = created_org["data"] org = get_organization_by_id(session=db, org_id=created_org_data["id"]) - assert org is not None # The organization should be found in the DB + assert org is not None assert org.name == created_org_data["name"] assert org.is_active == created_org_data["is_active"] -def test_update_organization( +def test_update_organization_as_superuser( db: Session, test_organization: Organization, superuser_token_headers: dict[str, str], ): - unique_name = f"UpdatedOrg-{random_lower_string()}" # Ensure a unique name + unique_name = f"UpdatedOrg-{random_lower_string()}" update_data = {"name": unique_name, "is_active": False} response = client.patch( @@ -77,8 +90,23 @@ def test_update_organization( assert updated_org["is_active"] == update_data["is_active"] -# Test deleting an organization -def test_delete_organization( +def test_update_organization_as_regular_user( + db: Session, + test_organization: Organization, + normal_user_token_headers: dict[str, str], +): + unique_name = f"UpdatedOrg-{random_lower_string()}" + update_data = {"name": unique_name, "is_active": False} + + response = client.patch( + f"{settings.API_V1_STR}/organizations/{test_organization.id}", + json=update_data, + headers=normal_user_token_headers, + ) + assert response.status_code == 403 + + +def test_delete_organization_as_superuser( db: Session, test_organization: Organization, superuser_token_headers: dict[str, str], @@ -88,8 +116,33 @@ def test_delete_organization( headers=superuser_token_headers, ) assert response.status_code == 200 + response = client.get( f"{settings.API_V1_STR}/organizations/{test_organization.id}", headers=superuser_token_headers, ) assert response.status_code == 404 + + +def test_delete_organization_as_regular_user( + db: Session, + test_organization: Organization, + normal_user_token_headers: dict[str, str], +): + response = client.delete( + f"{settings.API_V1_STR}/organizations/{test_organization.id}", + headers=normal_user_token_headers, + ) + assert response.status_code == 403 + + +def test_read_organization_as_regular_user_without_access( + db: Session, + normal_user_token_headers: dict[str, str], + other_organization: Organization, +): + response = client.get( + f"{settings.API_V1_STR}/organizations/{other_organization.id}", + headers=normal_user_token_headers, + ) + assert response.status_code == 403 # Forbidden, as the user doesn't have access