Skip to content

Auth: Implementation of org access control #250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
19 changes: 18 additions & 1 deletion backend/app/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Annotated, Optional

import jwt
from fastapi import Depends, HTTPException, status, Request, Header, Security
import logging
from fastapi import Depends, HTTPException, status, Request, Path, Query
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
from jwt.exceptions import InvalidTokenError
Expand All @@ -23,8 +24,12 @@
ProjectUser,
Project,
Organization,
APIKey,
)


logger = logging.getLogger(__name__)

reusable_oauth2 = OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/login/access-token", auto_error=False
)
Expand Down Expand Up @@ -147,6 +152,18 @@ def get_current_active_superuser_org(current_user: CurrentUserOrg) -> User:
return current_user


def check_org_access(_current_user: CurrentUserOrg, org_id: int = Path):
"""Helper function to check organization access."""
if not _current_user.is_superuser and org_id != _current_user.organization_id:
logger.warning(
f"[check_org_access] Access violation | user_id={_current_user.id}, attempted_org_id={org_id}, "
f"current_org_id={_current_user.organization_id}"
)
raise HTTPException(
status_code=403, detail="Access to this organization is forbidden"
)


def verify_user_project_organization(
db: SessionDep,
current_user: CurrentUserOrg,
Expand Down
70 changes: 48 additions & 22 deletions backend/app/api/routes/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
OrganizationPublic,
)
from app.api.deps import (
CurrentUser,
CurrentUserOrg,
SessionDep,
get_current_active_superuser,
check_org_access,
)
from app.crud.organization import (
create_organization,
get_organization_by_id,
validate_organization,
)
from app.crud.organization import create_organization, get_organization_by_id
from app.utils import APIResponse

router = APIRouter(prefix="/organizations", tags=["organizations"])
Expand All @@ -24,14 +29,28 @@
# Retrieve organizations
@router.get(
"/",
dependencies=[Depends(get_current_active_superuser)],
response_model=APIResponse[List[OrganizationPublic]],
)
def read_organizations(session: SessionDep, skip: int = 0, limit: int = 100):
count_statement = select(func.count()).select_from(Organization)
count = session.exec(count_statement).one()
def read_organizations(
session: SessionDep, current_user: CurrentUserOrg, skip: int = 0, limit: int = 100
):
"""
Return all organizations for superuser,
or only the one associated with the current user.
"""

if current_user.is_superuser:
statement = select(Organization).where(Organization.is_active == True)
else:
if not current_user.organization_id:
return APIResponse.success_response([])

Check warning on line 46 in backend/app/api/routes/organization.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/organization.py#L45-L46

Added lines #L45 - L46 were not covered by tests

statement = select(Organization).offset(skip).limit(limit)
statement = select(Organization).where(

Check warning on line 48 in backend/app/api/routes/organization.py

View check run for this annotation

Codecov / codecov/patch

backend/app/api/routes/organization.py#L48

Added line #L48 was not covered by tests
Organization.id == current_user.organization_id,
Organization.is_deleted == False,
)

statement = statement.offset(skip).limit(limit)
organizations = session.exec(statement).all()

return APIResponse.success_response(organizations)
Expand All @@ -44,37 +63,44 @@
response_model=APIResponse[OrganizationPublic],
)
def create_new_organization(*, session: SessionDep, org_in: OrganizationCreate):
"""
Creates a new organization, note that only a superuser can create an organization.
"""
new_org = create_organization(session=session, org_create=org_in)
return APIResponse.success_response(new_org)


# Retrieve an organization by ID
@router.get(
"/{org_id}",
dependencies=[Depends(get_current_active_superuser)],
response_model=APIResponse[OrganizationPublic],
)
def read_organization(*, session: SessionDep, org_id: int):
def read_organization(*, session: SessionDep, org_id: int, _=Depends(check_org_access)):
"""
Retrieve an organization by ID.
Retrieves an organization by ID,
a normal user can only retrieve the organization(s) associated with their user ID.
"""
org = get_organization_by_id(session=session, org_id=org_id)
if org is None:
raise HTTPException(status_code=404, detail="Organization not found")
org = validate_organization(session=session, org_id=org_id)
return APIResponse.success_response(org)


# Update an organization
@router.patch(
"/{org_id}",
dependencies=[Depends(get_current_active_superuser)],
response_model=APIResponse[OrganizationPublic],
)
def update_organization(
*, session: SessionDep, org_id: int, org_in: OrganizationUpdate
*,
session: SessionDep,
org_id: int,
org_in: OrganizationUpdate,
_=Depends(check_org_access),
):
org = get_organization_by_id(session=session, org_id=org_id)
if org is None:
raise HTTPException(status_code=404, detail="Organization not found")
"""
Updates an organization by ID,
a normal user can only update the organization(s) associated with their user ID.
"""
org = validate_organization(session=session, org_id=org_id)

org_data = org_in.model_dump(exclude_unset=True)
org = org.model_copy(update=org_data)
Expand All @@ -94,10 +120,10 @@
include_in_schema=False,
)
def delete_organization(session: SessionDep, org_id: int):
org = get_organization_by_id(session=session, org_id=org_id)
if org is None:
raise HTTPException(status_code=404, detail="Organization not found")

"""
Deletes an existing organization, note that only a superuser can delete an organization.
"""
org = validate_organization(session=session, org_id=org_id)
session.delete(org)
session.commit()

Expand Down
75 changes: 64 additions & 11 deletions backend/app/tests/api/routes/test_org.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,43 +24,56 @@ def test_organization(db: Session, superuser_token_headers: dict[str, str]):
return organization


# Test retrieving organizations
def test_read_organizations(db: Session, superuser_token_headers: dict[str, str]):
@pytest.fixture
def other_organization(db: Session, superuser_token_headers: dict[str, str]):
unique_name = f"OtherOrg-{random_lower_string()}"
org_data = OrganizationCreate(name=unique_name, is_active=True)
organization = create_organization(session=db, org_create=org_data)
db.commit()
return organization


def test_read_organizations_as_superuser(
db: Session,
superuser_token_headers: dict[str, str],
test_organization: Organization,
):
response = client.get(
f"{settings.API_V1_STR}/organizations/", headers=superuser_token_headers
)
assert response.status_code == 200
response_data = response.json()
assert "data" in response_data
assert isinstance(response_data["data"], list)
assert len(response_data["data"]) > 0


# Test creating an organization
def test_create_organization(db: Session, superuser_token_headers: dict[str, str]):
def test_create_organization_as_superuser(
db: Session, superuser_token_headers: dict[str, str]
):
unique_name = f"Org-{random_lower_string()}"
org_data = {"name": unique_name, "is_active": True}
response = client.post(
f"{settings.API_V1_STR}/organizations/",
json=org_data,
headers=superuser_token_headers,
)

assert 200 <= response.status_code < 300
created_org = response.json()
assert "data" in created_org # Make sure there's a 'data' field
assert "data" in created_org
created_org_data = created_org["data"]
org = get_organization_by_id(session=db, org_id=created_org_data["id"])
assert org is not None # The organization should be found in the DB
assert org is not None
assert org.name == created_org_data["name"]
assert org.is_active == created_org_data["is_active"]


def test_update_organization(
def test_update_organization_as_superuser(
db: Session,
test_organization: Organization,
superuser_token_headers: dict[str, str],
):
unique_name = f"UpdatedOrg-{random_lower_string()}" # Ensure a unique name
unique_name = f"UpdatedOrg-{random_lower_string()}"
update_data = {"name": unique_name, "is_active": False}

response = client.patch(
Expand All @@ -77,8 +90,23 @@ def test_update_organization(
assert updated_org["is_active"] == update_data["is_active"]


# Test deleting an organization
def test_delete_organization(
def test_update_organization_as_regular_user(
db: Session,
test_organization: Organization,
normal_user_token_headers: dict[str, str],
):
unique_name = f"UpdatedOrg-{random_lower_string()}"
update_data = {"name": unique_name, "is_active": False}

response = client.patch(
f"{settings.API_V1_STR}/organizations/{test_organization.id}",
json=update_data,
headers=normal_user_token_headers,
)
assert response.status_code == 403


def test_delete_organization_as_superuser(
db: Session,
test_organization: Organization,
superuser_token_headers: dict[str, str],
Expand All @@ -88,8 +116,33 @@ def test_delete_organization(
headers=superuser_token_headers,
)
assert response.status_code == 200

response = client.get(
f"{settings.API_V1_STR}/organizations/{test_organization.id}",
headers=superuser_token_headers,
)
assert response.status_code == 404


def test_delete_organization_as_regular_user(
db: Session,
test_organization: Organization,
normal_user_token_headers: dict[str, str],
):
response = client.delete(
f"{settings.API_V1_STR}/organizations/{test_organization.id}",
headers=normal_user_token_headers,
)
assert response.status_code == 403


def test_read_organization_as_regular_user_without_access(
db: Session,
normal_user_token_headers: dict[str, str],
other_organization: Organization,
):
response = client.get(
f"{settings.API_V1_STR}/organizations/{other_organization.id}",
headers=normal_user_token_headers,
)
assert response.status_code == 403 # Forbidden, as the user doesn't have access