-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathauth.py
More file actions
185 lines (151 loc) · 6.22 KB
/
auth.py
File metadata and controls
185 lines (151 loc) · 6.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt

from database import check_user_org_access, get_user_by_id, get_user_organizations
from schemas import TokenData, UserInDB
logger = logging.getLogger(__name__)

# JWT settings
# Placeholder used only when SECRET_KEY is absent (or empty) in the environment.
# Anyone who has read this source can forge tokens signed with it, so its use is
# reported loudly below instead of happening silently.
_INSECURE_DEFAULT_SECRET_KEY = "your-secret-key-change-this-in-production"
SECRET_KEY = os.getenv("SECRET_KEY") or _INSECURE_DEFAULT_SECRET_KEY
if SECRET_KEY == _INSECURE_DEFAULT_SECRET_KEY:
    logger.warning(
        "SECRET_KEY environment variable is not set; falling back to an insecure "
        "placeholder key. Set SECRET_KEY before deploying."
    )
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

# Bearer-token extractor injected via Depends(...) by the auth dependencies in
# this module; tokenUrl names the login route that issues tokens.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/login")
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (e.g. ``sub``, ``user_id``, ``org_uuid``, ``role``).
            The dict is copied, never mutated.
        expires_delta: Token lifetime. ``None`` (the default) means
            ``ACCESS_TOKEN_EXPIRE_MINUTES``. An explicit ``timedelta(0)`` is
            honored instead of being mistaken for "not given".

    Returns:
        The encoded JWT string. Its ``exp`` and ``type="access"`` claims
        override any same-named keys in ``data``.
    """
    to_encode = data.copy()
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware "now": datetime.utcnow() returns a naive value and is
    # deprecated since Python 3.12. jwt.encode converts datetime claims to UTC
    # epoch seconds either way.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire, "type": "access"})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_refresh_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT refresh token.

    Args:
        data: Claims to embed (e.g. ``sub``, ``user_id``). The dict is copied,
            never mutated.
        expires_delta: Optional token lifetime, mirroring ``create_access_token``.
            ``None`` (the default) means ``REFRESH_TOKEN_EXPIRE_DAYS``.

    Returns:
        The encoded JWT string. Its ``exp`` and ``type="refresh"`` claims
        override any same-named keys in ``data``.
    """
    to_encode = data.copy()
    if expires_delta is None:
        expires_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    # Aware UTC timestamp instead of the deprecated, naive datetime.utcnow().
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire, "type": "refresh"})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def verify_token(token: str, token_type: str = "access") -> Optional[TokenData]:
    """Decode and validate a JWT, returning its claims as TokenData.

    ``jwt.decode`` checks the signature (and, by default, expiry). On top of
    that, the ``type`` claim must equal ``token_type``, so a refresh token is
    rejected where an access token is expected and vice versa. Both ``sub`` and
    ``user_id`` must also be present.

    Args:
        token: Encoded JWT string.
        token_type: Expected value of the ``type`` claim ("access" or "refresh").

    Returns:
        A TokenData on success, otherwise ``None``. Only decode failures are
        logged; the other rejections return ``None`` silently.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as e:
        # Expired or tampered tokens are routine client errors, not server
        # faults, so they are logged below ERROR. Lazy %-args avoid building
        # the message when the level is disabled.
        logger.warning("JWT verification failed: %s", e)
        return None

    if payload.get("type") != token_type:
        return None

    username = payload.get("sub")
    user_id = payload.get("user_id")
    if username is None or user_id is None:
        return None

    return TokenData(
        username=username,
        user_id=user_id,
        org_uuid=payload.get("org_uuid"),
        role=payload.get("role"),
    )
async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInDB:
    """Resolve the bearer token to an active user record.

    Args:
        token: Bearer token supplied by ``oauth2_scheme``.

    Returns:
        The user loaded by ``get_user_by_id`` for the token's ``user_id`` claim.

    Raises:
        HTTPException(401): the token is invalid, expired, or of the wrong type;
            the user does not exist; or the user is inactive. Every 401 carries
            a ``WWW-Authenticate: Bearer`` challenge, as RFC 7235 requires.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    token_data = verify_token(token)
    if token_data is None:
        raise credentials_exception

    # NOTE(review): get_user_by_id is called synchronously from an async
    # function. If it performs blocking I/O, it stalls the event loop while it
    # runs. Confirm against database.py.
    user = get_user_by_id(token_data.user_id)
    if user is None:
        raise credentials_exception
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Inactive user",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user
async def get_current_active_user(current_user: UserInDB = Depends(get_current_user)) -> UserInDB:
    """Return the authenticated user, rejecting inactive accounts.

    ``get_current_user`` already rejects inactive users, so this check is a
    defensive duplicate. It is kept for callers that depend on this name.

    Raises:
        HTTPException(401): the user is inactive. The response includes the
            ``WWW-Authenticate: Bearer`` challenge that RFC 7235 requires on 401s.
    """
    if not current_user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Inactive user",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return current_user
def require_role(required_role: str):
    """Build a dependency that admits only users holding ``required_role``.

    The returned coroutine authenticates the bearer token with
    ``get_current_user`` and lets superusers through unconditionally. For
    everyone else, it compares the ``role`` claim stored in the token against
    ``required_role``.

    Raises (from the returned dependency):
        HTTPException(401): propagated from ``get_current_user``.
        HTTPException(403): the token's role does not match.
    """
    async def role_dependency(token: str = Depends(oauth2_scheme)) -> UserInDB:
        authenticated = await get_current_user(token)
        if authenticated.is_superuser:
            return authenticated  # superusers bypass role checks entirely

        # NOTE(review): the role comes from the token claim set at issuance,
        # not from the database, so a role change only takes effect once a new
        # token is issued.
        claims = verify_token(token)
        if not (claims and claims.role == required_role):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Access denied. Required role: {required_role}"
            )
        return authenticated

    return role_dependency
def require_org_access(org_uuid: str):
    """Build a dependency that restricts a route to members of ``org_uuid``.

    Membership is checked on every request via ``check_user_org_access``.
    Superusers are always admitted, and the membership lookup is skipped for
    them.

    Raises (from the returned dependency):
        HTTPException(401): propagated from ``get_current_user``.
        HTTPException(403): the user has no access to the organization.
    """
    async def org_access_dependency(current_user: UserInDB = Depends(get_current_user)) -> UserInDB:
        # Short-circuit: the membership lookup only runs for non-superusers.
        is_allowed = current_user.is_superuser or check_user_org_access(current_user.id, org_uuid)
        if not is_allowed:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to organization"
            )
        return current_user

    return org_access_dependency
class OrganizationAccess:
    """Dependency helpers that query organization membership without raising."""

    @staticmethod
    async def verify_access(org_uuid: str, current_user: UserInDB = Depends(get_current_user)) -> bool:
        """Report whether ``current_user`` may access ``org_uuid``.

        Superusers always may. Everyone else is checked with
        ``check_user_org_access``. Unlike ``require_org_access``, this returns
        the result instead of raising 403.
        """
        return True if current_user.is_superuser else check_user_org_access(current_user.id, org_uuid)

    @staticmethod
    async def get_user_orgs(current_user: UserInDB = Depends(get_current_user)) -> list:
        """Return the organizations ``get_user_organizations`` lists for the user."""
        organizations = get_user_organizations(current_user.id)
        return organizations
def create_user_token(user: UserInDB, org_uuid: Optional[str] = None) -> Dict[str, str]:
    """Issue an access/refresh token pair for ``user``.

    Args:
        user: The user being logged in.
        org_uuid: Organization to scope the access token to. When falsy, the
            first entry returned by ``get_user_organizations`` is used, if any.

    Returns:
        A dict with ``access_token``, ``refresh_token`` and ``token_type``
        ("bearer"). Only the access token carries the ``org_uuid`` and ``role``
        claims.

    Role resolution: superusers get "admin". Otherwise the ``role_name`` of the
    membership whose ``uuid`` matches ``org_uuid`` is used, falling back to "user".
    """
    memberships = get_user_organizations(user.id)
    if not org_uuid and memberships:
        org_uuid = memberships[0]['uuid']  # default to the first listed org

    # NOTE(review): an explicitly passed org_uuid is not checked against the
    # user's memberships. A non-member still receives a token carrying that
    # org_uuid, with role "user". Confirm that no downstream code trusts the
    # org_uuid claim on its own.
    if user.is_superuser:
        role = "admin"
    else:
        role = next(
            (membership['role_name'] for membership in memberships if membership['uuid'] == org_uuid),
            "user",
        )

    subject = user.username or user.email
    access_claims = {
        "sub": subject,
        "user_id": user.id,
        "org_uuid": org_uuid,
        "role": role,
    }
    return {
        "access_token": create_access_token(access_claims),
        "refresh_token": create_refresh_token({"sub": subject, "user_id": user.id}),
        "token_type": "bearer",
    }