diff --git a/src/client/plantdb/client/api_endpoints.py b/src/client/plantdb/client/api_endpoints.py index 4624c31b..fc8c5071 100644 --- a/src/client/plantdb/client/api_endpoints.py +++ b/src/client/plantdb/client/api_endpoints.py @@ -21,9 +21,10 @@ >>> api_endpoints.scan('plant1', prefix='/api/v1') '/api/v1/scans/plant1' """ +from urllib import parse -def sanitize_name(name): +def sanitize_name(name) -> str: """Sanitizes and validates the provided name. The function ensures that the input string adheres to predefined naming rules by: @@ -58,9 +59,8 @@ def sanitize_name(name): return sanitized_name -def url_prefix(endpoint_path): - """ - Wrap an endpoint path generator with an optional URL prefix. +def url_prefix(endpoint_path) : + """Wrap an endpoint path generator with an optional URL prefix. Examples -------- @@ -282,6 +282,29 @@ def scans(**kwargs) -> str: return "/scans" +@url_prefix +def scans_info(**kwargs) -> str: + """Return the URL path to the list of scan dataset information endpoint. + + Other Parameters + ---------------- + prefix : str + An optional prefix to prepend to the URL path. + + Returns + ------- + str + The URL path to the list of scan dataset information endpoint. + + Examples + -------- + >>> from plantdb.client import api_endpoints + >>> api_endpoints.scans_info(prefix='/api/v1') + '/api/v1/scans_info' + """ + return "/scans_info" + + @url_prefix def scan(scan_id: str, **kwargs) -> str: """Return the URL path to the scan endpoint. @@ -313,7 +336,7 @@ def scan(scan_id: str, **kwargs) -> str: @url_prefix -def image(scan_id: str, fileset_id: str, file_id: str, size: str, **kwargs) -> str: +def image(scan_id: str, fileset_id: str, file_id: str, size: str, as_base64:bool, **kwargs) -> str: """Return the URL path to the image endpoint. Parameters @@ -326,6 +349,8 @@ def image(scan_id: str, fileset_id: str, file_id: str, size: str, **kwargs) -> s The name of the image. size : str or int The size parameter of the image request. + as_base64 : bool + A boolean flag indicating whether to return an image as a base64 string. Returns ------- @@ -335,13 +360,25 @@ def image(scan_id: str, fileset_id: str, file_id: str, size: str, **kwargs) -> s Examples -------- >>> from plantdb.client import api_endpoints - >>> api_endpoints.image('real_plant','images','00000_rgb', 'orig') + >>> api_endpoints.image('real_plant','images','00000_rgb', 'orig', False) '/image/real_plant/images/00000_rgb?size=orig' + >>> api_endpoints.image('real_plant','images','00000_rgb', 'thumb', True) + '/image/real_plant/images/00000_rgb?size=thumb&as_base64=true' """ scan_id = sanitize_name(scan_id) fileset_id = sanitize_name(fileset_id) file_id = sanitize_name(file_id) - return f"/image/{scan_id}/{fileset_id}/{file_id}?size={size}" + + # Assemble optional query parameters + query: dict[str, str] = {} + if size is not None: + query["size"] = str(size) + if as_base64 is not None: + # Use lower‑case JSON‑style booleans for consistency + query["as_base64"] = str(as_base64).lower() + + query_str = f"?{parse.urlencode(query)}" if query else "" + return f"/image/{scan_id}/{fileset_id}/{file_id}{query_str}" @url_prefix @@ -370,7 +407,14 @@ def sequence(scan_id: str, type: str, **kwargs) -> str: valid_types = ['all', 'angles', 'internodes', 'fruit_points', 'manual_angles', 'manual_internodes'] type = 'all' if type not in valid_types else type scan_id = sanitize_name(scan_id) - return f"/sequence/{scan_id}?type={type}" + + # Assemble optional query parameters + query: dict[str, str] = {} + if type is not None: + query["type"] = str(type) + + query_str = f"?{parse.urlencode(query)}" if query else "" + return f"/sequence/{scan_id}{query_str}" @url_prefix @@ -453,9 +497,8 @@ def scan_file(scan_id: str, file_path: str, **kwargs) -> str: return f"/files/{scan_id}/{file_path.lstrip('/')}" - @url_prefix -def create_user(**kwargs): +def create_user(**kwargs) -> str: """Create the user registration URL. Returns @@ -470,3 +513,108 @@ def create_user(**kwargs): '/register' """ return f"/register" + + +@url_prefix +def create_scan(**kwargs) -> str: + """URL to create a scan. + + Returns + ------- + str + The URL path to scan creation. + """ + return f"/api/scan" + + +@url_prefix +def create_fileset(**kwargs) -> str: + """URL to create a fileset. + + Returns + ------- + str + The URL path to fileset creation. + """ + return f"/api/fileset" + + +@url_prefix +def create_file(**kwargs) -> str: + """URL to create a file. + + Returns + ------- + str + The URL path to file creation. + """ + return f"/api/file" + + +@url_prefix +def list_scan_filesets(scan: str, **kwargs) -> str: + """URL to list the filesets associated with the given scan name. + + Returns + ------- + str + The URL path to filesets. + """ + scan_id = sanitize_name(scan) + return f"/api/scan/{scan_id}/filesets" + + +@url_prefix +def list_fileset_files(scan: str, fileset: str, **kwargs) -> str: + """URL to list the file associated with the given scan and filesets names. + + Returns + ------- + str + The URL path to filesets. + """ + scan_id = sanitize_name(scan) + fileset = sanitize_name(fileset) + return f"/api/scan/{scan_id}/{fileset}/files" + + +@url_prefix +def metadata_scan(scan: str, **kwargs) -> str: + """URL to access the metadata associated with the given scan name. + + Returns + ------- + str + The URL path to scan metadata. + """ + scan_id = sanitize_name(scan) + return f"/api/scan/{scan_id}/metadata" + + +@url_prefix +def metadata_fileset(scan: str, fileset: str, **kwargs) -> str: + """URL to access the fileset metadata associated with the given scan and fileset name. + + Returns + ------- + str + The URL path to fileset metadata. + """ + scan_id = sanitize_name(scan) + fileset = sanitize_name(fileset) + return f"/api/scan/{scan_id}/{fileset}/metadata" + + +@url_prefix +def metadata_files(scan: str, fileset: str, file: str, **kwargs) -> str: + """URL to access the file metadata associated with the given scan and fileset name. + + Returns + ------- + str + The URL path to file metadata. + """ + scan_id = sanitize_name(scan) + fileset = sanitize_name(fileset) + file = sanitize_name(file) + return f"/api/scan/{scan_id}/{fileset}/{file}/metadata" diff --git a/src/client/plantdb/client/cli/fsdb_rest_api_sync.py b/src/client/plantdb/client/cli/fsdb_rest_api_sync.py index 526af89c..2c1a65a6 100644 --- a/src/client/plantdb/client/cli/fsdb_rest_api_sync.py +++ b/src/client/plantdb/client/cli/fsdb_rest_api_sync.py @@ -197,7 +197,7 @@ def sync_scan_archives(origin_url, target_url, filter_pattern=None, log_level=DE Path(f_path).unlink() # Refresh the scan in the target to load its infos: try: - msg = request_refresh(scan_id, host=target_host, port=target_port) + success, msg = request_refresh(scan_id, host=target_host, port=target_port) except HTTPError as e: logger.error(f"Error refreshing target database for scan '{scan_id}': {e}") continue diff --git a/src/client/plantdb/client/plantdb_client.py b/src/client/plantdb/client/plantdb_client.py index ad9affbe..6e70f8e9 100644 --- a/src/client/plantdb/client/plantdb_client.py +++ b/src/client/plantdb/client/plantdb_client.py @@ -21,7 +21,7 @@ >>> # $ fsdb_rest_api --test >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url ->>> client = PlantDBClient(plantdb_url()) +>>> client = PlantDBClient(plantdb_url('localhost', port=5000)) >>> # Create a new scan >>> scan_id = client.create_scan( ... name="Plant Sample 001", @@ -35,6 +35,7 @@ ... ) ``` """ +import json import mimetypes import os @@ -92,9 +93,11 @@ class PlantDBClient: The base URL of the PlantDB REST API. session : requests.Session HTTP session that maintains cookies and connection pooling. - jwt_token : str - The JWT token to authenticate with the PlantDB REST API. - username :str + _access_token : str + The JSON Web Token to authenticate with the PlantDB REST API. + _refresh_token : str + The refresh token to obtain new access tokens. + _username :str The login username. logger : logging.Logger The logger to use. @@ -109,16 +112,16 @@ class PlantDBClient: -------- >>> from plantdb.server.test_rest_api import TestRestApiServer >>> # Start a test PlantDB REST API server first: - >>> server = TestRestApiServer(test=True, port=5555) + >>> server = TestRestApiServer(test=True, port=5000) >>> server.start() >>> # Use the client against the server >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url - >>> client = PlantDBClient(plantdb_url('localhost', port=5555)) + >>> client = PlantDBClient(plantdb_url('localhost', port=5000)) >>> client.login('admin', 'admin') - >>> print(client.jwt_token) - >>> client2 = PlantDBClient(plantdb_url('localhost', port=5555)) - >>> client2.validate_session_token(client.jwt_token) + >>> print(client._access_token) + >>> client2 = PlantDBClient(plantdb_url('localhost', port=5000)) + >>> client2.validate_token(client._access_token) >>> print(client.plantdb_url) >>> scans = client.list_scans() >>> print(scans) @@ -133,30 +136,12 @@ def __init__(self, base_url, prefix=None): prefix = api_prefix() self.base_url = f"{base_url}{prefix}" self.session = requests.Session() - self.jwt_token = None - self.username = None - self.logger = get_logger(__class__.__name__) - def validate_session_token(self, token): - """ - Sets the JWT token for the HTTP session and updates the Authorization header. + self._access_token = None + self._refresh_token = None + self._username = None - Parameters - ---------- - token : str - The JWT token to be used for authentication. - """ - url = join_url(self.base_url, api_endpoints.token_validation()) - response = self.session.post(url, headers={"Authorization": f"Bearer {token}"}) - if response.ok: - self.jwt_token = token - self.username = response.json().get('username') - # Add the JWT to the header - self.session.headers.update({'Authorization': f'Bearer {self.jwt_token}'}) - else: - self.logger.error(f"Token validation failed!") - self.logger.error(response.json()) - return + self.logger = get_logger(__class__.__name__) def login(self, username: str, password: str) -> bool: """ @@ -181,13 +166,15 @@ def login(self, username: str, password: str) -> bool: } try: - response = self.session.post(url, json=data) + # Use session.request directly for login to avoid using expired tokens in headers + response = self.session.request("POST", url, json=data) if response.ok: result = response.json() - self.jwt_token = result.get('access_token') - self.username = username + self._access_token = result.get('access_token') + self._refresh_token = result.get('refresh_token') + self._username = username # Add the JWT to the header - self.session.headers.update({'Authorization': f'Bearer {self.jwt_token}'}) + self.session.headers.update({'Authorization': f'Bearer {self._access_token}'}) return True else: error_msg = response.json().get('message', 'Login failed') @@ -198,22 +185,45 @@ def login(self, username: str, password: str) -> bool: self.logger.error(f"Login request failed: {e}") return False + def _request_with_refresh(self, method, url, **kwargs): + """Perform an HTTP request with automatic token refresh on 401.""" + response = self.session.request(method, url, **kwargs) + + if response.status_code == 401 and self._refresh_token: + self.logger.info("Access token expired, attempting to refresh...") + if self.refresh_token(): + self.logger.info("Token refresh successful, retrying request...") + # Update headers for the retry + if 'headers' in kwargs: + kwargs['headers'].update({'Authorization': f'Bearer {self._access_token}'}) + else: + # session already has the updated Authorization header + pass + return self.session.request(method, url, **kwargs) + else: + self.logger.error("Token refresh failed, user needs to re-authenticate.") + + return response + def logout(self) -> bool: - """ - Logout user from the PlantDB API. + """Logout user from the PlantDB API. Returns ------- bool - True if logout successful + ``True`` if logout successful, ``False`` otherwise. """ url = join_url(self.base_url, api_endpoints.logout()) try: - response = self.session.post(url) + # Use _request_with_refresh for logout as it requires authentication + response = self._request_with_refresh("POST", url) if response.ok: - self.username = None + self._username = None + self._access_token = None + self._refresh_token = None # Remove the Authorization with the JWT from the header - self.session.headers.pop('Authorization') + if 'Authorization' in self.session.headers: + self.session.headers.pop('Authorization') return True return False except Exception: @@ -229,7 +239,8 @@ def create_user(self, username: str, password: str, fullname: str) -> bool: } try: - response = self.session.post(url, json=data) + # create_user usually requires admin, use _request_with_refresh + response = self._request_with_refresh("POST", url, json=data) if response.ok: return True else: @@ -245,53 +256,117 @@ def refresh(self) -> bool: """Refresh the database.""" url = join_url(self.base_url, api_endpoints.refresh()) try: - response = self.session.get(url) + response = self._request_with_refresh("GET", url) if response.ok: return True return False except Exception: return False + def validate_token(self, token) -> bool: + """Validate an authentication token against the remote service. + + This method sends a ``POST`` request to the token‑validation endpoint + using the supplied ``token`` in the ``Authorization`` header. The + request is performed via :meth:`_request_with_refresh`, which will + transparently refresh the session if necessary. The response's + ``ok`` attribute determines the boolean result. + + Parameters + ---------- + token : str + The bearer token to be validated. + + Returns + ------- + True if the token is accepted by the server, otherwise ``False``. + + Examples + -------- + >>> client = MyApiClient(base_url='https://api.example.com') + >>> token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...' + >>> client.validate_token(token) + True + >>> client.validate_token('invalid') + False + + See Also + -------- + _request_with_refresh : Internal helper that handles token refresh. + api_endpoints.token_validation : Returns the relative URL for token validation. + """ + url = join_url(self.base_url, api_endpoints.token_validation()) + response = self._request_with_refresh("POST", url, headers={"Authorization": f"Bearer {token}"}) + if response.ok: + resp_username = response.json()['user']['username'] + if not self._username: + self._username = resp_username + self.refresh_token() + if self._username and resp_username != self._username: + self.logger.warning(f"Given token correspond to a different username") + return True + else: + return False + def refresh_token(self) -> bool: - """Refresh the JWT token.""" + """Refresh the JSON Web Token. + + Uses the stored refresh token to obtain a new access/refresh token pair. + """ + if not self._refresh_token: + self.logger.error("No refresh token available") + return False + url = join_url(self.base_url, api_endpoints.token_refresh()) + data = {'refresh_token': self._refresh_token} try: - response = self.session.post(url) + # Use session.request directly to avoid infinite recursion with _request_with_refresh + response = self.session.request("POST", url, json=data) if response.ok: result = response.json() - self.jwt_token = result.get('access_token') - self.username = result.get('username') + self._access_token = result.get('access_token') + self._refresh_token = result.get('refresh_token') + # Update the header with the new access token + self.session.headers.update({'Authorization': f'Bearer {self._access_token}'}) return True - return False - except Exception: + else: + error_msg = response.json().get('message', 'Token refresh failed') + self.logger.error(f"Token refresh failed: {error_msg}") + self._access_token = None + self._refresh_token = None + self._username = None + return False + except Exception as e: + self.logger.error(f"Token refresh request failed: {e}") return False def _handle_http_errors(self, response): - """ - Handles HTTP errors by raising a custom exception with an error message obtained - from the HTTP response. This function intercepts the original exception, extracts - the error message from the response JSON, and raises a new exception to the same - type with the extracted message. - - Parameters - ---------- - response : requests.Response - The HTTP response object from which the status and error message will be - assessed. The response object is expected to have a JSON body containing - a key "message" for error details. + """Handles HTTP errors by logging a message appropriate to the severity of the HTTP status code.""" + # If the response is successful, nothing to do + if response.ok: + return + + # Determine severity and log accordingly + if response.status_code >= 500: + # Server error – treat as serious + self.logger.error( + f"Server error {response.status_code}: {response.reason}" + ) + else: + # Client error – treat as a warning + self.logger.warning( + f"Client error {response.status_code}: {response.reason}" + ) - Raises - ------ - RequestException - If the HTTP response status code indicates an error. The raised exception is - of the same type as the original exception, with the message replaced by the - value of the "message" key from the response JSON body. - """ + # Try to pull a helpful message from the JSON payload try: - response.raise_for_status() - except RequestException as e: - response_data = response.json()["message"] - raise type(e)(response_data) from e + response_data = response.json().get("message", response.text) + except ValueError: + # Fallback to raw text if JSON cannot be decoded + response_data = response.text + + # Re‑raise a generic RequestException with the extracted message + raise RequestException(response_data) def list_scans(self, query=None, fuzzy=False): """List all scans in the database. @@ -319,7 +394,7 @@ def list_scans(self, query=None, fuzzy=False): >>> # $ fsdb_rest_api --test >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url - >>> client = PlantDBClient(plantdb_url()) + >>> client = PlantDBClient(plantdb_url('localhost', port=5000)) >>> response = client.list_scans() >>> print(response) ['virtual_plant', 'real_plant_analyzed', 'real_plant', 'virtual_plant_analyzed', 'arabidopsis000'] @@ -330,12 +405,55 @@ def list_scans(self, query=None, fuzzy=False): params['query'] = query if fuzzy: params['fuzzy'] = fuzzy - response = self.session.get(url, params=params) + response = self._request_with_refresh('GET', url, params=params) # Handle HTTP errors with explicit messages self._handle_http_errors(response) return response.json() + def list_scans_info(self, query=None, fuzzy=False): + """Retrieve detailed scan information dictionaries from the ScansTable resource. + + Parameters + ---------- + query : dict, optional + A dictionary that will be JSON‑encoded and sent as the ``filterQuery`` URL + parameter. Use the same structure accepted by the server, _e.g._ + ``{"object": {"species": "Arabidopsis.*"}}``. + fuzzy : bool, optional + When ``True`` the server performs fuzzy matching (default ``False``). + + Returns + ------- + list[dict] + A list where each entry is a dictionary containing the scan’s + ``metadata``, ``tasks``, ``files`` and other information as defined by `ScansTable`. + + Raises + ------ + requests.exceptions.RequestException + If the request fails or the server returns an error status. + """ + # Build the URL for the “scans info” endpoint – the server side class is ScansTable + url = join_url(self.base_url, api_endpoints.scans_info()) + + # Prepare query parameters exactly as the REST API expects + params = {} + if query is not None: + # The API expects a JSON string in the ``filterQuery`` parameter + params["filterQuery"] = json.dumps(query) + if fuzzy: + params["fuzzy"] = fuzzy + + # Perform the request; token refresh is handled automatically + response = self._request_with_refresh("GET", url, params=params) + + # Turn HTTP errors into readable exceptions + self._handle_http_errors(response) + + # Return the parsed JSON payload (list of dicts) + return response.json() + def create_scan(self, name, metadata=None): """Create a new scan in the database. @@ -362,17 +480,23 @@ def create_scan(self, name, metadata=None): >>> # $ fsdb_rest_api --test >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url - >>> client = PlantDBClient(plantdb_url()) + >>> client = PlantDBClient(plantdb_url('localhost', port=5000)) >>> metadata = {'description': 'Test plant scan'} + >>> # Scan creation requires authentication >>> response = client.create_scan('test_plant', metadata=metadata) - >>> print(response) + ERROR [PlantDBClient] Server error 500: INTERNAL SERVER ERROR + requests.exceptions.RequestException: Error creating scan: Insufficient permissions to create a scan as 'guest' user! + >>> # Log in as admin to get sufficient rights + >>> client.login('admin', 'admin') + >>> response = client.create_scan('test_plant', metadata=metadata) + >>> print(response['message']) {'message': "Scan 'test_plant' created successfully."} """ - url = f"{self.base_url}/api/scan" + url = join_url(self.base_url, api_endpoints.create_scan()) data = {'name': name} if metadata: data['metadata'] = metadata - response = self.session.post(url, json=data) + response = self._request_with_refresh("POST", url, json=data) # Handle HTTP errors with explicit messages self._handle_http_errors(response) @@ -404,19 +528,19 @@ def get_scan_metadata(self, scan_id, key=None): >>> # $ fsdb_rest_api --test >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url - >>> client = PlantDBClient(plantdb_url()) + >>> client = PlantDBClient(plantdb_url('localhost', port=5000)) >>> # Get all metadata >>> metadata = client.get_scan_metadata('test_plant') >>> print(metadata) - {'metadata': {'owner': 'anonymous', 'description': 'Test plant scan'}} - >>> # Get specific metadata key + {'metadata': {'owner': 'admin', 'created': '2026-02-04T00:25:13.869891', 'last_modified': '2026-02-04T00:25:13.871581', 'created_by': 'PlantDB Admin', 'description': 'Test plant scan'}} + >>> # Get a specific metadata key >>> value = client.get_scan_metadata('test_plant', key='description') >>> print(value) {'metadata': 'Test plant scan'} """ url = f"{self.base_url}/api/scan/{scan_id}/metadata" params = {'key': key} if key else None - response = self.session.get(url, params=params) + response = self._request_with_refresh("GET", url, params=params) # Handle HTTP errors with explicit messages self._handle_http_errors(response) @@ -451,18 +575,20 @@ def update_scan_metadata(self, scan_id, metadata, replace=False): >>> # $ fsdb_rest_api --test >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url - >>> client = PlantDBClient(plantdb_url()) + >>> client = PlantDBClient(plantdb_url('localhost', port=5000)) + >>> # Log in as admin to get sufficient rights + >>> client.login('admin', 'admin') >>> new_metadata = {'description': 'Updated scan description'} >>> response = client.update_scan_metadata('test_plant', new_metadata) - >>> print(response) - {'metadata': {'owner': 'anonymous', 'description': 'Updated scan description'}} + >>> print(response['metadata']['description']) + Updated scan description """ url = f"{self.base_url}/api/scan/{scan_id}/metadata" data = { 'metadata': metadata, 'replace': replace } - response = self.session.post(url, json=data) + response = self._request_with_refresh("POST", url, json=data) # Handle HTTP errors with explicit messages self._handle_http_errors(response) @@ -496,7 +622,7 @@ def list_scan_filesets(self, scan_id, query=None, fuzzy=False): >>> # $ fsdb_rest_api --test >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url - >>> client = PlantDBClient(plantdb_url()) + >>> client = PlantDBClient(plantdb_url('localhost', port=5000)) >>> response = client.list_scan_filesets('real_plant') >>> print(response) {'filesets': ['images']} @@ -507,7 +633,7 @@ def list_scan_filesets(self, scan_id, query=None, fuzzy=False): params['query'] = query if fuzzy: params['fuzzy'] = fuzzy - response = self.session.get(url, params=params) + response = self._request_with_refresh("GET", url, params=params) # Handle HTTP errors with explicit messages self._handle_http_errors(response) @@ -528,7 +654,7 @@ def create_fileset(self, fileset_id, scan_id, metadata=None): Returns ------- dict - Server response containing creation confirmation message + Server response containing a creation confirmation message Raises ------ @@ -541,7 +667,9 @@ def create_fileset(self, fileset_id, scan_id, metadata=None): >>> # $ fsdb_rest_api --test >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url - >>> client = PlantDBClient(plantdb_url()) + >>> client = PlantDBClient(plantdb_url('localhost', port=5000)) + >>> # Log in as admin to get sufficient rights + >>> client.login('admin', 'admin') >>> metadata = {'description': 'This is a test fileset'} >>> response = client.create_fileset('my_fileset', 'real_plant', metadata=metadata) >>> print(response) @@ -554,7 +682,7 @@ def create_fileset(self, fileset_id, scan_id, metadata=None): } if metadata: data['metadata'] = metadata - response = self.session.post(url, json=data) + response = self._request_with_refresh("POST", url, json=data) # Handle HTTP errors with explicit messages self._handle_http_errors(response) @@ -588,7 +716,7 @@ def get_fileset_metadata(self, scan_id, fileset_id, key=None): >>> # $ fsdb_rest_api --test >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url - >>> client = PlantDBClient(plantdb_url()) + >>> client = PlantDBClient(plantdb_url('localhost', port=5000)) >>> # Get all metadata >>> metadata = client.get_fileset_metadata('real_plant', 'my_fileset') >>> print(metadata) @@ -600,7 +728,7 @@ def get_fileset_metadata(self, scan_id, fileset_id, key=None): """ url = f"{self.base_url}/api/fileset/{scan_id}/{fileset_id}/metadata" params = {'key': key} if key else None - response = self.session.get(url, params=params) + response = self._request_with_refresh("GET", url, params=params) # Handle HTTP errors with explicit messages self._handle_http_errors(response) @@ -637,7 +765,9 @@ def update_fileset_metadata(self, scan_id, fileset_id, metadata, replace=False): >>> # $ fsdb_rest_api --test >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url - >>> client = PlantDBClient(plantdb_url()) + >>> client = PlantDBClient(plantdb_url('localhost', port=5000)) + >>> # Log in as admin to get sufficient rights + >>> client.login('admin', 'admin') >>> # Update metadata >>> new_metadata = {'description': 'Updated fileset description', 'author': 'John Doe'} >>> response = client.update_fileset_metadata('real_plant', 'my_fileset', new_metadata) @@ -649,7 +779,7 @@ def update_fileset_metadata(self, scan_id, fileset_id, metadata, replace=False): 'metadata': metadata, 'replace': replace } - response = self.session.post(url, json=data) + response = self._request_with_refresh("POST", url, json=data) # Handle HTTP errors with explicit messages self._handle_http_errors(response) @@ -685,7 +815,7 @@ def list_fileset_files(self, scan_id, fileset_id, query=None, fuzzy=False): >>> # $ fsdb_rest_api --test >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url - >>> client = PlantDBClient(plantdb_url()) + >>> client = PlantDBClient(plantdb_url('localhost', port=5000)) >>> response = client.list_fileset_files('real_plant', 'images') >>> print(response) {'files': ['00000_rgb', '00001_rgb', '00002_rgb', ...]} @@ -696,7 +826,7 @@ def list_fileset_files(self, scan_id, fileset_id, query=None, fuzzy=False): params['query'] = query if fuzzy: params['fuzzy'] = fuzzy - response = self.session.get(url, params=params) + response = self._request_with_refresh("GET", url, params=params) # Handle HTTP errors with explicit messages self._handle_http_errors(response) @@ -740,7 +870,9 @@ def create_file(self, file_data, file_id, ext, scan_id, fileset_id, metadata=Non >>> import yaml >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url - >>> client = PlantDBClient(plantdb_url()) + >>> client = PlantDBClient(plantdb_url('localhost', port=5000)) + >>> # Log in as admin to get sufficient rights + >>> client.login('admin', 'admin') >>> # Example 1 - Existing YAML file path as string >>> metadata = {'description': 'Test document', 'author': 'John Doe'} >>> dummy_data = {'name': 'Test Plant', 'species': 'Arabidopsis thaliana'} @@ -796,7 +928,7 @@ def create_file(self, file_data, file_id, ext, scan_id, fileset_id, metadata=Non files = { 'file': (filename, file_data, get_mime_type(ext)) } - response = self.session.post(url, files=files, data=data) + response = self._request_with_refresh("POST", url, files=files, data=data) else: # Convert to a Path object if it's a string file_path = Path(file_data) if isinstance(file_data, str) else file_data @@ -807,7 +939,7 @@ def create_file(self, file_data, file_id, ext, scan_id, fileset_id, metadata=Non files = { 'file': (filename, file_handle, 'application/octet-stream') } - response = self.session.post(url, files=files, data=data) + response = self._request_with_refresh("POST", url, files=files, data=data) # Handle HTTP errors with explicit messages self._handle_http_errors(response) @@ -843,7 +975,7 @@ def get_file_metadata(self, scan_id, fileset_id, file_id, key=None): >>> # $ fsdb_rest_api --test >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url - >>> client = PlantDBClient(plantdb_url()) + >>> client = PlantDBClient(plantdb_url('localhost', port=5000)) >>> # Get all metadata >>> metadata = client.get_file_metadata('test_plant', 'images', 'image_001') >>> print(metadata) @@ -855,7 +987,7 @@ def get_file_metadata(self, scan_id, fileset_id, file_id, key=None): """ url = f"{self.base_url}/api/file/{scan_id}/{fileset_id}/{file_id}/metadata" params = {'key': key} if key else None - response = self.session.get(url, params=params) + response = self._request_with_refresh("GET", url, params=params) # Handle HTTP errors with explicit messages self._handle_http_errors(response) @@ -894,7 +1026,9 @@ def update_file_metadata(self, scan_id, fileset_id, file_id, metadata, replace=F >>> # $ fsdb_rest_api --test >>> from plantdb.client.plantdb_client import PlantDBClient >>> from plantdb.client.rest_api import plantdb_url - >>> client = PlantDBClient(plantdb_url()) + >>> client = PlantDBClient(plantdb_url('localhost', port=5000)) + >>> # Log in as admin to get sufficient rights + >>> client.login('admin', 'admin') >>> # Update metadata >>> new_metadata = {'description': 'Updated description'} >>> response = client.update_file_metadata( @@ -911,7 +1045,7 @@ def update_file_metadata(self, scan_id, fileset_id, file_id, metadata, replace=F 'metadata': metadata, 'replace': replace } - response = self.session.post(url, json=data) + response = self._request_with_refresh("POST", url, json=data) # Handle HTTP errors with explicit messages self._handle_http_errors(response) diff --git a/src/client/plantdb/client/rest_api.py b/src/client/plantdb/client/rest_api.py index ba26e960..12771b74 100644 --- a/src/client/plantdb/client/rest_api.py +++ b/src/client/plantdb/client/rest_api.py @@ -53,6 +53,10 @@ import os from io import BytesIO from pathlib import Path +from typing import Union +from urllib.parse import splitport +from urllib.parse import urlparse +from urllib.parse import urlunparse import requests from PIL import Image @@ -61,6 +65,7 @@ from plantdb.client import api_endpoints from plantdb.client.api_endpoints import sanitize_name +from plantdb.commons.log import get_logger #: Default hostname to PlantDB REST API is 'localhost': PLANTDB_HOST = os.getenv('PLANTDB_HOST', "localhost") @@ -77,31 +82,100 @@ #: Default URL prefix for the plantdb REST API PLANTDB_PREFIX = os.getenv('PLANTDB_PREFIX', None) +logger = get_logger(__name__) + # ----------------------------------------------------------------------------- # URL construction methods # ----------------------------------------------------------------------------- def origin_url(host, port=None, ssl=False, **kwargs) -> str: - # Attempt to split the host to check for an existing scheme (http/https) - try: - scheme, host = host.split('://') - except ValueError: - pass # If no scheme is found, proceed with the default - else: - # If 's' is in the scheme, it indicates HTTPS - if 's' in scheme: - ssl = True - - # Ensure port is converted to string and has no leading colon - if port: - if isinstance(port, int): - port = str(port) - port = ':' + port.lstrip(':') + """Construct a URL string from host, optional port, and SSL flag. + + Parameters + ---------- + host : str + Hostname or URL. May optionally include a scheme (e.g., ``http://`` or + ``https://``). If a scheme is present and contains the character ``s``, + the function treats it as HTTPS and forces ``ssl`` to ``True``. + port : int or str, optional + Port number to append to the host. If an ``int`` is supplied, it is + converted to a string; a leading colon is stripped before it is added. + The default is ``None`` which results in no port being added. + ssl : bool, optional + When ``True`` the URL will use the ``https`` scheme. The value is + overridden to ``True`` if the supplied ``host`` already contains a scheme + with an ``s`` character. + + Returns + ------- + url + The fully‑qualified URL string constructed from the supplied parts. + + Raises + ------ + TypeError + If ``host`` is not a string or does not support ``split`` (e.g., ``None``). + + Notes + ----- + The function does **not** validate that the resulting URL points to a + reachable endpoint; it only assembles the string. Supplying both a scheme + in ``host`` and ``ssl=True`` will result in the scheme dictated by the + original ``host`` (HTTPS if the original scheme contains ``s``). + + Examples + -------- + >>> from plantdb.client.rest_api import origin_url + >>> origin_url('example.com') + 'http://example.com' + >>> origin_url('example.com', 8080) + 'http://example.com:8080' + >>> origin_url('https://example.com') + 'https://example.com' + >>> origin_url('https://example.com/api/v1') + 'https://example.com' + >>> origin_url('http://example.com', ssl=True) + 'https://example.com' + >>> origin_url('example.com', port='443', ssl=True) + 'https://example.com:443' + """ + if not isinstance(host, str): + raise TypeError("host must be a string") + + # Parse the incoming host value + parsed = urlparse(host) + + # If no scheme was supplied, ``urlparse`` treats the whole string as a + # path. In that case we split the first “/” to obtain the netloc. + if not parsed.scheme: + # e.g. "example.com/api/v1" -> netloc="example.com", path="/api/v1" + first_slash = parsed.path.find("/") + if first_slash == -1: + netloc, path = parsed.path, "" + else: + netloc = parsed.path[:first_slash] + path = parsed.path[first_slash:] + scheme = "" else: - port = '' + scheme = parsed.scheme + netloc = parsed.netloc + path = parsed.path + + # If the original string already contains a scheme that contains an “s” (i.e. https) it forces ``ssl=True``. + if scheme and "s" in scheme.lower(): + ssl = True + final_scheme = "https" if ssl else "http" - # Construct the final URL - return f"http{'s' if ssl else ''}://{host}{port}" + # Apply an explicit ``port`` argument (overwrites any existing one) + if port is not None: + # ``splitport`` safely separates host from any existing port. + hostname, _ = splitport(netloc) + # Ensure ``port`` is a clean string without a leading colon. + clean_port = str(port).lstrip(":") + netloc = f"{hostname}:{clean_port}" + + # Re‑assemble the URL, excluding the original path (if any) + return urlunparse((final_scheme, netloc, "", "", "", "")) def plantdb_url(host, port=PLANTDB_PORT, prefix=PLANTDB_PREFIX, ssl=False) -> str: @@ -267,6 +341,73 @@ def register_url(host, **kwargs): return join_url(url, api_endpoints.register(**kwargs)) +def token_validation_url(host, **kwargs): + """Generate the full URL for the PlantDB API token validation endpoint. + + Parameters + ---------- + host : str + The hostname or IP address of the PlantDB REST API server. + + Other Parameters + ---------------- + port : int + The PlantDB API port number, defaults to ``None``. + prefix : str + A path prefix for the PlantDB API, defaults to ``None``. + ssl : bool + A boolean flag indicating whether to use HTTPS (``True``) or HTTP (``False``). Defaults to ``False``. + + Returns + ------- + str + The fully qualified register URL as a string. + + Examples + -------- + >>> from plantdb.client.rest_api import token_validation_url + >>> # Basic usage with default configuration + >>> url = token_validation_url('localhost') + >>> print(url) + http://localhost/token-validation + """ + url = origin_url(host, **kwargs) + return join_url(url, api_endpoints.token_validation()) + + +def token_refresh_url(host, **kwargs): + """Generate the full URL for the PlantDB API token refresh endpoint. + + Parameters + ---------- + host : str + The hostname or IP address of the PlantDB REST API server. + + Other Parameters + ---------------- + port : int + The PlantDB API port number, defaults to ``None``. + prefix : str + A path prefix for the PlantDB API, defaults to ``None``. + ssl : bool + A boolean flag indicating whether to use HTTPS (``True``) or HTTP (``False``). Defaults to ``False``. + + Returns + ------- + str + The fully qualified register URL as a string. + + Examples + -------- + >>> from plantdb.client.rest_api import token_refresh_url + >>> # Basic usage with default configuration + >>> url = token_refresh_url('localhost') + >>> print(url) + http://localhost/token-refresh + """ + url = origin_url(host, **kwargs) + return join_url(url, api_endpoints.token_refresh()) + def scans_url(host, **kwargs): """Generates the URL listing the scans from the PlantDB REST API. @@ -401,7 +542,7 @@ def scan_preview_image_url(host, scan_id, size="thumb", **kwargs): return join_url(url, thumb_uri) -def scan_image_url(host, scan_id, fileset_id, file_id, size='orig', **kwargs): +def scan_image_url(host, scan_id, fileset_id, file_id, size='orig', as_base64=False, **kwargs): """Get the URL to the image for a scan dataset and task fileset served by the PlantDB REST API. Parameters @@ -415,11 +556,13 @@ def scan_image_url(host, scan_id, fileset_id, file_id, size='orig', **kwargs): file_id : str The name of the image file to be retrieved. size : {'orig', 'large', 'thumb'} or int, optional - If an integer, use it as the size of the cached image to create and return. + If an integer, use it as the size of the cached image to create and return. Else, should be a string, defaulting to ``'orig'``, and it works as follows: * ``'thumb'``: image max width and height to `150`. * ``'large'``: image max width and height to `1500`; * ``'orig'``: original image, no cache; + as_base64 : bool + A boolean flag indicating whether to return an image as a base64 string. Other Parameters ---------------- @@ -441,11 +584,13 @@ def scan_image_url(host, scan_id, fileset_id, file_id, size='orig', **kwargs): >>> from plantdb.client.rest_api import scan_image_url >>> scan_image_url('localhost', "real_plant", "images", "00000_rgb") 'http://localhost/image/real_plant/images/00000_rgb?size=orig' + >>> scan_image_url('localhost', "real_plant", "images", "00000_rgb", as_base64=True) + 'http://localhost/image/real_plant/images/00000_rgb?size=orig&as_base64=true' >>> scan_image_url('localhost', "real_plant", "images", "00000_rgb", prefix='/plantdb') 'http://localhost/plantdb/image/real_plant/images/00000_rgb?size=orig' """ url = origin_url(host, **kwargs) - return join_url(url, api_endpoints.image(scan_id, fileset_id, file_id, size, **kwargs)) + return join_url(url, api_endpoints.image(scan_id, fileset_id, file_id, size, as_base64, **kwargs)) def refresh_url(host, scan_id=None, **kwargs): @@ -629,7 +774,7 @@ def scan_reconstruction_url(host, scan_id, cfg_fname='pipeline.toml', **kwargs): return scan_file_url(host, scan_id, cfg_fname, **kwargs) -def list_task_images_uri(host, scan_id, task_name='images', size='orig', **kwargs): +def list_task_images_uri(host, scan_id, task_name='images', size='orig', as_base64=True, **kwargs): """Get the list of images URI for a given dataset and task name. Parameters @@ -646,6 +791,8 @@ def list_task_images_uri(host, scan_id, task_name='images', size='orig', **kwarg * `'thumb'`: image max width and height to `150`. * `'large'`: image max width and height to `1500`; * `'orig'`: original image, no cache; + as_base64 : bool + A boolean flag indicating whether to return an image as a base64 string. Other Parameters ---------------- @@ -672,12 +819,11 @@ def list_task_images_uri(host, scan_id, task_name='images', size='orig', **kwarg http://localhost/image/real_plant/images/00002_rgb?size=100 """ scan_info = request_scan_data(host, scan_id, **kwargs) - tasks_fileset = scan_info["tasks_fileset"] + tasks_id = scan_info["tasks_fileset"][task_name] images = scan_info["images"] url = origin_url(host, **kwargs) - return [join_url(url, api_endpoints.image(scan_id, tasks_fileset[task_name], Path(img).stem, size, **kwargs)) for - img in - images] + return [join_url(url, api_endpoints.image(scan_id, tasks_id, Path(img).stem, size, as_base64, **kwargs)) + for img in images] # ----------------------------------------------------------------------------- @@ -715,6 +861,7 @@ def make_api_request(url, method="GET", params=None, json_data=None, Flag indicating whether to stream the request. Default is False. session_token : str The PlantDB REST API session token of the user. + It should be supplied for every request that requires authentication on the server-side. Returns ------- @@ -734,6 +881,14 @@ def make_api_request(url, method="GET", params=None, json_data=None, ----- This function is designed to handle various HTTP methods (GET, POST, PUT, DELETE) and provides a unified interface for making API requests. It supports SSL verification and allows for custom parameters and JSON data to be sent with the request. It passes keyword arguments to the underlying `requests` library. + + Examples + -------- + >>> from plantdb.client.rest_api import make_api_request + >>> from plantdb.client.rest_api import login_url + >>> response = make_api_request(login_url('localhost', port=5000), "POST", json_data={'username': 'admin', 'password': 'admin'}) + >>> access_token, refresh_token = response.json()['access_token'], response.json()['refresh_token'] + >>> user = response.json()['user'] """ requests_kwargs = {} requests_kwargs['params'] = params @@ -746,12 +901,12 @@ def make_api_request(url, method="GET", params=None, json_data=None, # otherwise default to requests' built‑in verification requests_kwargs['verify'] = os.getenv('CERT_PATH', True) - # If a session token is supplied, add it to the Authorization header requests_kwargs['headers'] = kwargs.get('headers', {}) + # If a session token is supplied, add it to the Authorization header if 'session_token' in kwargs: requests_kwargs['headers'].update({'Authorization': f"Bearer {kwargs.get('session_token')}"}) - # Normalise the HTTP method name to uppercase for comparison + # Normalize the HTTP method name to uppercase for comparison method = method.upper() try: @@ -774,14 +929,14 @@ def make_api_request(url, method="GET", params=None, json_data=None, response.raise_for_status() # Raise exception for 4XX/5XX responses return response except requests.exceptions.SSLError as e: - print(f"SSL Error: {e}") - raise + logger.error(f"SSL Error: {e}") + raise e from e except requests.exceptions.RequestException as e: - print(f"Request Error: {e}") - raise + logger.error(f"Request Error: {e}") + raise e from e -def request_login(host, username, password, **kwargs): +def request_login(host, username, password, **kwargs) -> dict: """Send a login request to the authentication service. This helper function constructs a POST request to the login endpoint @@ -809,8 +964,8 @@ def request_login(host, username, password, **kwargs): Returns ------- - requests.Response - The response from the API. + dict + The login data from the response if successful. Notes ----- @@ -824,18 +979,19 @@ def request_login(host, username, password, **kwargs): >>> # Start a test PlantDB REST API server first, in a terminal: >>> # $ fsdb_rest_api --test >>> from plantdb.client.rest_api import request_login - >>> login_data = request_login('localhost', 'admin', 'admin', port=5000).json() - >>> print(login_data) + >>> login_data = request_login('localhost', 'admin', 'admin', port=5000) + >>> print(list(login_data)) + ['access_token', 'message', 'refresh_token', 'user'] """ url = login_url(host, **kwargs) data = { 'username': username, 'password': password } - return make_api_request(url, method="POST", json_data=data) + return make_api_request(url, method="POST", json_data=data).json() -def request_check_username(host, username, **kwargs): +def request_check_username(host, username, **kwargs) -> bool: """Send a username availability request to the authentication service. This helper function constructs a GET request to the login endpoint @@ -860,8 +1016,8 @@ def request_check_username(host, username, **kwargs): Returns ------- - requests.Response - The response from the API. + bool + A boolean flag indicating whether the username is valid (``True``) or not (``False``). Notes ----- @@ -876,14 +1032,14 @@ def request_check_username(host, username, **kwargs): >>> # $ fsdb_rest_api --test >>> from plantdb.client.rest_api import request_check_username >>> username_exists = request_check_username('localhost', 'admin', port=5000) - >>> print(username_exists.json()['exists']) + >>> print(username_exists) True """ url = login_url(host, **kwargs) - return make_api_request(url, method="GET", params={'username': username}) + return make_api_request(url, method="GET", params={'username': username}).json()['exists'] -def request_logout(host, **kwargs): +def request_logout(host, **kwargs) -> tuple[bool, str]: """Send a logout request to the authentication service. This helper function constructs a POST request to the logout endpoint @@ -908,8 +1064,9 @@ def request_logout(host, **kwargs): Returns ------- - requests.Response - The response from the API. + tuple[bool, str] + A boolean flag indicating whether the logout request was successful (``True``) or not (``False``). + A string with the log out message. Notes ----- @@ -922,16 +1079,95 @@ def request_logout(host, **kwargs): >>> # $ fsdb_rest_api --test >>> from plantdb.client.rest_api import request_login >>> from plantdb.client.rest_api import request_logout - >>> login_data = request_login('localhost', 'admin', 'admin', port=5000).json() - >>> logout = request_logout('localhost', port=5000, session_token=login_data['access_token']) - >>> print(logout.ok) + >>> login_data = request_login('localhost', 'admin', 'admin', port=5000) + >>> success, msg = request_logout('localhost', port=5000, session_token=login_data['access_token']) + >>> print(success) True """ url = logout_url(host, **kwargs) - return make_api_request(url, method="POST", session_token=kwargs.get('session_token', None)) + response = make_api_request(url, method="POST", session_token=kwargs.get('session_token', None)) + return response.ok, response.json()['message'] + + +def request_token_validation(host, **kwargs) -> dict: + """Validate a token by making a POST request to the token validation endpoint. + + Parameters + ---------- + host : str + The hostname or base URL used to construct the validation endpoint. + + Other Parameters + ---------------- + port : int + The PlantDB API port number, defaults to ``None``. + prefix : str + A path prefix for the PlantDB API, defaults to ``None``. + ssl : bool + A boolean flag indicating whether to use HTTPS (``True``) or HTTP (``False``). Defaults to ``False``. + session_token : str + The PlantDB REST API session token of the user. + Returns + ------- + dict + The token validation data from the response, if successful. + + Examples + -------- + >>> # Start a test PlantDB REST API server first, in a terminal: + >>> # $ fsdb_rest_api --test + >>> from plantdb.client.rest_api import request_login + >>> from plantdb.client.rest_api import request_token_validation + >>> login_data = request_login('localhost', 'admin', 'admin', port=5000) + >>> token_data = request_token_validation('localhost', port=5000, session_token=login_data['access_token']) + >>> print(token_data['user']) + {'username': 'admin', 'fullname': 'PlantDB Admin'} + """ + url = token_validation_url(host, **kwargs) + return make_api_request(url, method="POST", session_token=kwargs.get('session_token', None)).json() -def request_new_user(host, username, password, fullname, **kwargs): + +def request_token_refresh(host, **kwargs) -> dict: + """Refresh a token by making a POST request to the token refresh endpoint. + + Parameters + ---------- + host : str + The hostname or base URL used to construct the refresh endpoint. + + Other Parameters + ---------------- + port : int + The PlantDB API port number, defaults to ``None``. + prefix : str + A path prefix for the PlantDB API, defaults to ``None``. + ssl : bool + A boolean flag indicating whether to use HTTPS (``True``) or HTTP (``False``). Defaults to ``False``. + session_token : str + The PlantDB REST API session token of the user. + + Returns + ------- + dict + The token refresh data from the response, if successful. + + Examples + -------- + >>> # Start a test PlantDB REST API server first, in a terminal: + >>> # $ fsdb_rest_api --test + >>> from plantdb.client.rest_api import request_login + >>> from plantdb.client.rest_api import request_token_refresh + >>> login_data = request_login('localhost', 'admin', 'admin', port=5000) + >>> token_refresh = request_token_refresh('localhost', port=5000, refresh_token=login_data['refresh_token']) + >>> print([key for key in token_refresh.json() if 'token' in key]) + ['access_token', 'refresh_token'] + """ + url = token_refresh_url(host, **kwargs) + return make_api_request(url, method="POST", json_data={'refresh_token': kwargs.get('refresh_token', None)}).json() + + +def request_new_user(host, username, password, fullname, **kwargs) -> bool: """Send a registration request to the authentication service. This helper function constructs a POST request to the register endpoint @@ -963,8 +1199,8 @@ def request_new_user(host, username, password, fullname, **kwargs): Returns ------- - requests.Response - The response from the API. + bool + A boolean indicating whether the request was successful (``True``) or not (``False``). Notes ----- @@ -978,18 +1214,21 @@ def request_new_user(host, username, password, fullname, **kwargs): >>> from plantdb.client.rest_api import request_login >>> from plantdb.client.rest_api import request_logout >>> from plantdb.client.rest_api import request_new_user - >>> login_data = request_login('localhost', 'admin', 'admin', port=5000).json() + >>> login_data = request_login('localhost', 'admin', 'admin', port=5000) >>> user_added = request_new_user('localhost', 'testuser', 'fake_password', 'Test User', port=5000, session_token=login_data['access_token']) - >>> print(user_added.ok) + >>> print(user_added) True >>> logout = request_logout('localhost', port=5000, session_token=login_data['access_token']) + >>> login_data = request_login('localhost', 'testuser', 'fake_password', port=5000) + >>> print(login_data['user']['username']) + testuser """ url = register_url(host, **kwargs) data = {'username': username, 'fullname': fullname, 'password': password} - return make_api_request(url, method="POST", json_data=data, session_token=kwargs.get('session_token', None)) + return make_api_request(url, method="POST", json_data=data, session_token=kwargs.get('session_token', None)).ok -def request_scan_names_list(host, **kwargs): +def request_scan_names_list(host, **kwargs) -> list[str]: """Get the list of the scan datasets names served by the PlantDB REST API. Parameters @@ -1010,22 +1249,22 @@ def request_scan_names_list(host, **kwargs): Returns ------- - requests.Response - The response from the API. The list of scan dataset names should be in the JSON dictionary. + list[str] + The list of the scan datasets names from the response, if successful. Examples -------- >>> # Start a test PlantDB REST API server first, in a terminal: >>> # $ fsdb_rest_api --test >>> from plantdb.client.rest_api import request_scan_names_list - >>> print(request_scan_names_list('localhost', port=5000).json()) + >>> print(request_scan_names_list('localhost', port=5000) ['arabidopsis000', 'real_plant', 'real_plant_analyzed', 'virtual_plant', 'virtual_plant_analyzed'] """ url = scans_url(host, **kwargs) - return make_api_request(url=url, method="GET", session_token=kwargs.get('session_token', None)) + return make_api_request(url=url, method="GET", session_token=kwargs.get('session_token', None)).json() -def request_scans_info(host, **kwargs): +def request_scans_info(host, **kwargs) -> list[dict]: """Retrieve the information dictionary for all scans from the PlantDB REST API. Other Parameters @@ -1041,8 +1280,8 @@ def request_scans_info(host, **kwargs): Returns ------- - list - The list of scan information dictionaries. + list[dict] + The list of scan information dictionaries obtained from the response, if successful. Examples -------- @@ -1050,15 +1289,17 @@ def request_scans_info(host, **kwargs): >>> # $ fsdb_rest_api --test >>> from plantdb.client.rest_api import request_scans_info >>> from plantdb.client.rest_api import request_login - >>> login_data = request_login('localhost', 'admin', 'admin', port=5000).json() - >>> scans_info = request_scans_info('localhost', port=5000, session_token=login_data['access_token']).json() + >>> login_data = request_login('localhost', 'admin', 'admin', port=5000) + >>> scans_info = request_scans_info('localhost', port=5000, session_token=login_data['access_token']) + >>> print(sorted([scan['id'] for scan in scans_info])) + ['arabidopsis000', 'real_plant', 'real_plant_analyzed', 'virtual_plant', 'virtual_plant_analyzed'] """ - scan_list = request_scan_names_list(host, **kwargs).json() + scan_list = request_scan_names_list(host, **kwargs) return [make_api_request(url=scan_url(host, scan, **kwargs), session_token=kwargs.get('session_token', None)).json() for scan in scan_list] -def request_scan_data(host, scan_id, **kwargs): +def request_scan_data(host, scan_id, **kwargs) -> dict: """Retrieve the data dictionary for a given scan dataset from the PlantDB REST API. Parameters @@ -1082,14 +1323,16 @@ def request_scan_data(host, scan_id, **kwargs): Returns ------- dict - The data dictionary for the given scan dataset. + The data dictionary for the given scan dataset obtained from the response, if successful. Examples -------- >>> # Start a test PlantDB REST API server first, in a terminal: >>> # $ fsdb_rest_api --test >>> from plantdb.client.rest_api import request_scan_data - >>> scan_data = request_scan_data('localhost', 'real_plant', port=5000) + >>> from plantdb.client.rest_api import request_login + >>> login_data = request_login('localhost', 'admin', 'admin', port=5000) + >>> scan_data = request_scan_data('localhost', 'real_plant', port=5000, session_token=login_data['access_token']) >>> print(scan_data['id']) real_plant >>> print(scan_data['hasColmap']) @@ -1108,7 +1351,8 @@ def request_scan_data(host, scan_id, **kwargs): return {} -def request_scan_image(host, scan_id, fileset_id, file_id, size='orig', **kwargs): +def request_scan_image(host, scan_id, fileset_id, file_id, + size='orig', as_base64=False, **kwargs) -> tuple[str, str, Union[str, bytes]]: """Get the image for a scan dataset and task fileset served by the PlantDB REST API. Parameters @@ -1127,6 +1371,8 @@ def request_scan_image(host, scan_id, fileset_id, file_id, size='orig', **kwargs * ``'thumb'``: image max width and height to `150`. * ``'large'``: image max width and height to `1500`; * ``'orig'``: original image, no cache; + as_base64 : bool + A boolean flag indicating whether to return an image as a base64 string. Other Parameters ---------------- @@ -1141,28 +1387,46 @@ def request_scan_image(host, scan_id, fileset_id, file_id, size='orig', **kwargs Returns ------- - requests.Response - The URL to an image of a scan dataset and task fileset. + tuple[str, str, Union[str, bytes]] + If ``as_base64==True``, a dictionary with the 'image' encoded as base64 and the mimetype in 'content-type'. + Else the image data as bytes. Examples -------- >>> # Start a test PlantDB REST API server first, in a terminal: >>> # $ fsdb_rest_api --test >>> from plantdb.client.rest_api import request_scan_image - >>> response = request_scan_image('real_plant', 'images', '00000_rgb', port=5000) # download the image - >>> print(response.status_code) - 200 - >>> # Display the image + >>> import pybase64 >>> from PIL import Image >>> from io import BytesIO - >>> image = Image.open(BytesIO(response.content)) # Open the image from the bytes data + >>> # Example #1 - Get an image as binary data: + >>> db_img = ['real_plant', 'images', '00000_rgb'] + >>> _, _, img_bytes = request_scan_image('localhost', *db_img, port=5000) # download the image + >>> print(img_bytes[:10]) + b'\xff\xd8\xff\xe0\x00\x10JFIF' + >>> image = Image.open(BytesIO(img_bytes)) # Open the image from the bytes data >>> image.show() # Display the image + >>> # Example #2 - Get an image as base64 data: + >>> _, _, b64_string = request_scan_image('localhost', *db_img, port=5000, as_base64=True) + >>> print(b64_string[:50]) + /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAUEBAQEAwUEBAQGBQ + >>> image_data = pybase64.b64decode(b64_string) + >>> image = Image.open(BytesIO(image_data)) # Open the image from the base64 data + >>> image.show() """ - url = scan_image_url(host, scan_id, fileset_id, file_id, size, **kwargs) - return make_api_request(url=url, session_token=kwargs.get('session_token', None)) + url = scan_image_url(host, scan_id, fileset_id, file_id, size, as_base64, **kwargs) + response = make_api_request(url=url, session_token=kwargs.get('session_token', None)) + content_type = response.headers.get('Content-Type') + encoding = response.headers.get("X-Content-Encoding") + if as_base64: + content_type = response.json()['content-type'] + img_str = response.json()['image'] + return content_type, encoding, img_str + else: + return content_type, encoding, response.content -def request_scan_tasks_fileset(host, scan_id, **kwargs): +def request_scan_tasks_fileset(host, scan_id, **kwargs) -> dict: """Get the task name to fileset name mapping dictionary from the REST API. Parameters @@ -1210,7 +1474,7 @@ def request_scan_tasks_fileset(host, scan_id, **kwargs): return request_scan_data(host, scan_id, **kwargs).get('tasks_fileset', dict()) -def request_refresh(host, scan_id=None, **kwargs): +def request_refresh(host, scan_id=None, **kwargs) -> tuple[bool, str]: """Refreshes the database, potentialy only for a specified dataset. Parameters @@ -1237,8 +1501,8 @@ def request_refresh(host, scan_id=None, **kwargs): Returns ------- - response : requests.Response - The response object from the refresh request. + tuple[bool, str] + A boolean indicating whether the refresh request succeeded. Raises ------ @@ -1250,12 +1514,13 @@ def request_refresh(host, scan_id=None, **kwargs): >>> # Start a test PlantDB REST API server first, in a terminal: >>> # $ fsdb_rest_api --test >>> from plantdb.client.rest_api import request_refresh - >>> res = request_refresh('localhost', "arabidopsis000", port = 5000) - >>> print(res.json()["message"]) + >>> success, message = request_refresh('localhost', "arabidopsis000", port = 5000) + >>> print(message) Successfully reloaded scan 'arabidopsis000' """ url = refresh_url(host, scan_id, **kwargs) - return make_api_request(url, session_token=kwargs.get('session_token', None)) + response = make_api_request(url, session_token=kwargs.get('session_token', None)) + return response.ok, response.json()["message"] def request_archive_download(host, scan_id, out_dir=None, **kwargs): @@ -1558,7 +1823,7 @@ def parse_scans_info(host, **kwargs): return scan_dict -def parse_task_images(host, scan_id, task_name='images', size='orig', **kwargs): +def parse_task_images(host, scan_id, task_name='images', size='orig', as_base64=False, **kwargs): """Get the list of images data for a given dataset and task name. Parameters @@ -1575,6 +1840,8 @@ def parse_task_images(host, scan_id, task_name='images', size='orig', **kwargs): * `'thumb'`: image max width and height to `150`. * `'large'`: image max width and height to `1500`; * `'orig'`: original image, no chache; + as_base64 : bool + A boolean flag indicating whether to return an image as a base64 string. Other Parameters ---------------- @@ -1603,7 +1870,7 @@ def parse_task_images(host, scan_id, task_name='images', size='orig', **kwargs): (1440, 1080) """ images = [] - for img_uri in list_task_images_uri(host, scan_id, task_name, size, **kwargs): + for img_uri in list_task_images_uri(host, scan_id, task_name, size, as_base64, **kwargs): images.append( Image.open(BytesIO(make_api_request(url=img_uri, session_token=kwargs.get('session_token', None)).content))) return images diff --git a/src/client/plantdb/client/sync.py b/src/client/plantdb/client/sync.py index ab3b1ed3..cb89ed9d 100644 --- a/src/client/plantdb/client/sync.py +++ b/src/client/plantdb/client/sync.py @@ -125,7 +125,7 @@ ... print("Sync completed successfully") >>> # Use REST API endpoint to refresh scans >>> from plantdb.client.rest_api import request_refresh ->>> request_refresh(**server_cfg) +>>> success, msg = request_refresh(**server_cfg) >>> # Use REST API to list scans and verify target DB contains the new scans >>> scans_list = request_scan_names_list(**server_cfg) >>> print(scans_list) @@ -994,9 +994,8 @@ def config_from_url(url): return config -def _parse_database_spec(spec): - """ - Parse and validate a database specification, determining the appropriate synchronization strategy. +def _parse_database_spec(spec) : + """Parse and validate a database specification, determining the appropriate synchronization strategy. This function analyzes database specifications and returns a structured representation suitable for synchronization operations. It supports multiple database types including diff --git a/src/client/plantdb/client/url.py b/src/client/plantdb/client/url.py index 914df0d5..64b7e6e6 100644 --- a/src/client/plantdb/client/url.py +++ b/src/client/plantdb/client/url.py @@ -51,9 +51,8 @@ ] -def _load_whitelist_from_file() -> Optional[set[str]]: - """ - Load allowed URLs from a specified environment variable file. +def _load_whitelist_from_file() -> Optional[set[str]] : + """Load allowed URLs from a specified environment variable file. This function reads a file path from an environment variable, parses it line by line, and extracts hostnames to create a set of allowed URLs. @@ -108,9 +107,8 @@ def _load_whitelist_from_file() -> Optional[set[str]]: WHITELIST: Optional[set[str]] = _load_whitelist_from_file() -def _download_and_cache_blacklist(force: bool = False) -> Optional[Path]: - """ - Ensure the blacklist file is present locally and up‑to‑date. +def _download_and_cache_blacklist(force: bool = False) -> Optional[Path] : + """Ensure the blacklist file is present locally and up‑to‑date. Returns the path to the cached file or ``None`` on failure. Parameters @@ -210,9 +208,8 @@ def _download_and_cache_blacklist(force: bool = False) -> Optional[Path]: return cache_path if cache_path.is_file() else None -def _is_host_blacklisted(hostname: str) -> bool: - """ - Check if a given hostname is listed in the blacklist. +def _is_host_blacklisted(hostname: str) -> bool : + """Check if a given hostname is listed in the blacklist. This function reads a locally cached blacklist file and checks if the provided hostname (with or without "www." prefix) is present in it. @@ -282,9 +279,8 @@ def _is_host_blacklisted(hostname: str) -> bool: # Helper utilities # --------------------------------------------------------------------------- # -def _is_private_ip(ip: str) -> bool: - """ - Return True if *ip* is in a private / non‑routable network. +def _is_private_ip(ip: str) -> bool : + """Return True if *ip* is in a private / non‑routable network. Examples -------- @@ -304,9 +300,8 @@ def _is_private_ip(ip: str) -> bool: return any(addr in net for net in PRIVATE_NETWORKS) -def _resolve_public_ips(hostname: str) -> List[str]: - """ - Resolve the given hostname to a list of its public IPv4/IPv6 addresses. +def _resolve_public_ips(hostname: str) -> List[str] : + """Resolve the given hostname to a list of its public IPv4/IPv6 addresses. This function queries DNS records for the specified hostname and filters out any private IP addresses. @@ -400,9 +395,8 @@ def _resolve_public_ips(hostname: str) -> List[str]: return list(public_ips) -def _validate_hostname(hostname: str, allow_private_ip: bool = False) -> bool: - """ - Validate the hostname based on given criteria. +def _validate_hostname(hostname: str, allow_private_ip: bool = False) -> bool : + """Validate the hostname based on given criteria. This function checks if a hostname is valid by verifying it against a whitelist, blacklist, and ensuring it has at least one public IP unless private IPs are allowed. @@ -585,9 +579,8 @@ def _handle_redirects( max_redirects: int, allow_private_ip: bool = False, validate_host: bool = True, -) -> tuple[urllib3.HTTPResponse, str, int]: - """ - Manually follow HTTP redirects up to ``max_redirects`` and return the final +) -> tuple[urllib3.HTTPResponse, str, int] : + """Manually follow HTTP redirects up to ``max_redirects`` and return the final response, the URL that was finally fetched and the number of redirects that were performed. """ diff --git a/src/client/pyproject.toml b/src/client/pyproject.toml index 49430cb9..d8528c63 100644 --- a/src/client/pyproject.toml +++ b/src/client/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ ] description = "Client-side database library for the ROMI plant database ecosystem." readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.8, <3.13" # 3.13 upper bound for open3d dependency license = { 'text' = "LGPL-3.0-or-later" } # to replace by "LGPL-3.0-or-later" only in next release, pending deprecation authors = [ { name = "Peter Hanappe", email = "peter.hanappe@sony.com" }, @@ -63,7 +63,6 @@ sync = [ ] test = [ 'plantdb.server', - 'open3d', ] [tool.setuptools.package-data] diff --git a/src/client/tests/test_client_against_server.py b/src/client/tests/test_client_against_server.py index e430e250..4fa136c7 100644 --- a/src/client/tests/test_client_against_server.py +++ b/src/client/tests/test_client_against_server.py @@ -41,15 +41,16 @@ def test_basic_listing_and_scan_data(self): scans. It ensures that the returned data is of the expected types and contains valid information. """ - names = client.request_scan_names_list(self.server.host, **self.kw) + login_data = client.request_login(self.server.host, 'guest', 'guest', **self.kw) + names = client.request_scan_names_list(self.server.host, session_token=login_data['access_token'], **self.kw) self.assertIsInstance(names, list) self.assertGreater(len(names), 0) scan_id = names[0] - info = client.request_scan_data(self.server.host, scan_id, **self.kw) + info = client.request_scan_data(self.server.host, scan_id, session_token=login_data['access_token'], **self.kw) self.assertEqual(info.get("id"), scan_id) - scans_info = client.request_scans_info(self.server.host, **self.kw) + scans_info = client.request_scans_info(self.server.host, session_token=login_data['access_token'], **self.kw) self.assertIsInstance(scans_info, list) def test_preview_and_images_helpers(self): @@ -61,19 +62,20 @@ def test_preview_and_images_helpers(self): It ensures that URL building is correct and requests complete successfully with status codes 200 or 404 depending on the dataset. """ - names = client.request_scan_names_list(self.server.host, **self.kw) + login_data = client.request_login(self.server.host, 'guest', 'guest', **self.kw) + names = client.request_scan_names_list(self.server.host, session_token=login_data['access_token'], **self.kw) scan_id = names[0] print(f"Selected scan ID: {scan_id}") # just ensure URL builds and request completes (200 or 404 acceptable depending on dataset) - url = client.scan_preview_image_url(self.server.host, scan_id, size="thumb", **self.kw) + url = client.scan_preview_image_url(self.server.host, scan_id, session_token=login_data['access_token'], size="thumb", **self.kw) print(f"URL: {url}") self.assertIn("/image/", url) # list task images - uris = client.list_task_images_uri(self.server.host, scan_id, task_name='images', size='orig', **self.kw) + uris = client.list_task_images_uri(self.server.host, scan_id, task_name='images', size='orig', session_token=login_data['access_token'], **self.kw) self.assertIsInstance(uris, list) # download images if any - imgs = client.parse_task_images(self.server.host, scan_id, task_name='images', size='orig', **self.kw) + imgs = client.parse_task_images(self.server.host, scan_id, task_name='images', size='orig', session_token=login_data['access_token'], **self.kw) self.assertIsInstance(imgs, list) def test_refresh_and_archive(self): @@ -85,9 +87,9 @@ def test_refresh_and_archive(self): Then, it retrieves a list of scan names and selects the first one as the target for archiving. The `download_scan_archive` method is called with this scan ID and no specified output directory, expecting to receive a tuple consisting of a BytesIO object and a message string. """ - res_data = client.request_refresh(self.server.host, **self.kw).json() - self.assertIsInstance(res_data, dict) - self.assertIn("message", res_data) + login_data = client.request_login(self.server.host, 'guest', 'guest', **self.kw) + success, message = client.request_refresh(self.server.host, session_token=login_data['access_token'], **self.kw) + self.assertTrue(success) names = client.request_scan_names_list(self.server.host, **self.kw) scan_id = names[0] @@ -149,10 +151,11 @@ def test_get_task_data(self): task_names = ["PointCloud", "TriangleMesh", "CurveSkeleton", "TreeGraph"] expected_data_types = [list, dict, dict, nx.Graph] + login_data = client.request_login(self.server.host, 'guest', 'guest', **self.kw) for task_name, expected_type in zip(task_names, expected_data_types): # We're testing the API calls succeed, not necessarily that data exists try: - data = client.get_task_data(self.server.host, scan_id, task_name, **self.kw) + data = client.get_task_data(self.server.host, scan_id, task_name, session_token=login_data['access_token'], **self.kw) # If data is returned, validate its structure if data is not None: print(f"Data for task {task_name} is: {type(data)}") diff --git a/src/commons/plantdb/commons/auth/rbac.py b/src/commons/plantdb/commons/auth/rbac.py index 5b631b61..f67fe65e 100644 --- a/src/commons/plantdb/commons/auth/rbac.py +++ b/src/commons/plantdb/commons/auth/rbac.py @@ -90,9 +90,8 @@ def wrapper(self, requesting_user: User, *args, **kwargs) -> Any: return decorator -class RBACManager: - """ - Manage Role-Based Access Control (RBAC) for users and permissions. +class RBACManager : + """Manage Role-Based Access Control (RBAC) for users and permissions. This class provides methods to determine which permissions a user has, check if a user has a specific permission, and verify if a user can access diff --git a/src/commons/plantdb/commons/auth/session.py b/src/commons/plantdb/commons/auth/session.py index 04a81822..09967c65 100644 --- a/src/commons/plantdb/commons/auth/session.py +++ b/src/commons/plantdb/commons/auth/session.py @@ -31,6 +31,8 @@ import secrets from datetime import datetime from datetime import timedelta +from datetime import timezone +from threading import RLock from typing import Any from typing import Dict from typing import Optional @@ -38,9 +40,43 @@ from typing import Union import jwt +from argon2 import Type +from argon2.low_level import hash_secret_raw + from plantdb.commons.log import get_logger +# ---------------------------------------------------------------------- +# Custom exception hierarchy for session validation +# ---------------------------------------------------------------------- +class SessionValidationError(Exception): + """Base class for all session‑validation‑related errors.""" + pass + + +class AccessTokenNotFoundError(SessionValidationError): + """Raised when an access token isn’t present in the active‑session store.""" + pass + + +class RefreshTokenNotFoundError(SessionValidationError): + """Raised when a refresh token isn’t present in the active‑refresh‑store.""" + pass + + +class InvalidTokenProcessingError(SessionValidationError): + """Raised for unexpected errors while processing a token (e.g. decoding issues).""" + pass + + +class RefreshTokenReuseError(SessionValidationError): + pass + + +class WrongTokenType(SessionValidationError): + pass + + class SessionManager: """Manages user sessions with expiration and validation. @@ -66,7 +102,7 @@ class SessionManager: The logger to use for this session manager. """ - def __init__(self, session_timeout: int = 3600, max_concurrent_sessions: int = 10): # 1 hour default + def __init__(self, session_timeout: int = 3600, max_concurrent_sessions: int = 10): """Manage user sessions with timeout. Parameters @@ -159,7 +195,7 @@ def create_session(self, username: str) -> Union[str, None]: f"Reached max concurrent sessions limit ({self.max_concurrent_sessions})") return None - now = datetime.now() + now = datetime.now(timezone.utc) exp_time = now + timedelta(seconds=self.session_timeout) # Create a session token session_token = secrets.token_urlsafe(32) @@ -172,7 +208,7 @@ def create_session(self, username: str) -> Union[str, None]: } return session_token - def validate_session(self, session_id: str) -> Optional[dict]: + def validate_session(self, session_id: str) -> Union[dict, None]: """Validate a given session by checking its existence and expiration status. Parameters @@ -182,7 +218,7 @@ def validate_session(self, session_id: str) -> Optional[dict]: Returns ------- - dict or None + Union[dict, None] A dictionary with user information if valid, ``None`` if invalid/expired. Returns dictionary with: - username: The authenticated user @@ -199,7 +235,7 @@ def validate_session(self, session_id: str) -> Optional[dict]: return None session = self.sessions[session_id] - now = datetime.now() + now = datetime.now(timezone.utc) if now > session['expires_at']: username = session['username'] self.logger.warning(f"The session for user '{username}' has expired. Please log back in!") @@ -221,8 +257,8 @@ def invalidate_session(self, session_id: str) -> Tuple[bool, str | None]: Returns ------- bool - `True` if the specified session was found and removed, `False` otherwise. - str + ``True`` if the specified session was found and removed, ``False`` otherwise. + Union[str, None] The username corresponding to the invalidated session Notes @@ -247,7 +283,7 @@ def cleanup_expired_sessions(self) -> None: ----- This function modifies the `self.sessions` dictionary in-place. """ - current_time = datetime.now() + current_time = datetime.now(timezone.utc) expired_sessions = [ sid for sid, session in self.sessions.items() if current_time > session['expires_at'] @@ -311,7 +347,7 @@ def refresh_session(self, session_id: str) -> Optional[str]: Returns ------- str or None - New session token if refresh successful + New session token if refresh is successful """ session_data = self.validate_session(session_id) if not session_data: @@ -320,7 +356,7 @@ def refresh_session(self, session_id: str) -> Optional[str]: # Invalidate old session self.invalidate_session(session_id) - # Create new session + # Create a new session username = session_data['username'] return self.create_session(username) @@ -389,13 +425,107 @@ def __init__(self, session_timeout: int = 3600, **kwargs) -> None: super().__init__(session_timeout=session_timeout, max_concurrent_sessions=1) +def _derive_key_argon2(password: str) -> bytes: + """Derive a 64‑byte (512‑bit) secret using Argon2. + + Parameters + ---------- + password: str + Human‑readable pass‑phrase supplied by the caller. + + Returns + ------- + bytes + 64‑byte key suitable for HS512. + """ + # Argon2 parameters – adjust if you need stronger/higher‑memory settings + time_cost = 2 # number of iterations + memory_cost = 102_400 # KiB (≈100 MiB) + parallelism = 8 # CPU lanes + salt = secrets.token_bytes(16) # 128‑bit random salt; stored only in‑memory here + # hash_secret_raw returns raw bytes (no encoding) + return hash_secret_raw( + secret=password.encode('utf‑8'), + salt=salt, + time_cost=time_cost, + memory_cost=memory_cost, + parallelism=parallelism, + hash_len=64, # 64 bytes = 512 bits + type=Type.ID, + ) + + +def _init_secret_key(secret_key: str = None) -> bytes: + """Generate or derive a 64‑byte secret key for HS512 signing. + + Parameters + ---------- + secret_key: Union[str, bytes, None] + Optional secret key. + + Returns + ------- + bytes + A 64‑byte key suitable for HS512 HMAC operations. + + Raises + ------ + ValueError + When ``secret_key`` is a ``bytes`` object shorter than 64 bytes. + + Notes + ----- + The function always returns exactly 64 bytes. + If ``secret_key`` is ``None``, a fresh random 64‑byte key is generated using ``secrets.token_bytes``. + If a ``bytes`` object is supplied, it is returned unchanged after verifying that its length is at + least 64 bytes; otherwise a ``ValueError`` is raised. + If a ``str`` is supplied, it is interpreted as a pass‑phrase and stretched to 64 bytes with Argon2 + via `_derive_key_argon2`. + + Examples + -------- + >>> from plantdb.commons.auth.session import _init_secret_key + >>> # Generate a new random key + >>> key = _init_secret_key() + >>> len(key) + 64 + >>> # Use an existing 64‑byte key + >>> raw = b'A' * 64 + >>> _init_secret_key(raw) is raw + True + >>> # Derive a key from a pass‑phrase + >>> key2 = _init_secret_key('my secret') + >>> len(key2) + 64 + >>> # Passing a short byte string raises an error + >>> _init_secret_key(b'short') + ValueError: Binary secret_key must be at least 64 bytes for HS512 + """ + if secret_key is None: + # No key supplied → generate a fresh random 64‑byte key + secret_key = secrets.token_bytes(64) + elif isinstance(secret_key, bytes): + # Caller supplied raw bytes – just verify length + if len(secret_key) < 64: + raise ValueError("Binary `secret_key` must be at least 64 bytes for HS512") + secret_key = secret_key + else: + # Caller supplied a pass‑phrase string → stretch it with Argon2 + secret_key = _derive_key_argon2(secret_key) + return secret_key + + class JWTSessionManager(SessionManager): - """Manages user sessions with expiration and validation. + """Manage JWT-based user sessions with configurable timeouts and concurrency limits. - This class provides methods to create, validate, invalidate, - and cleanup expired sessions. Each session is associated with a - unique identifier (session_id) and has an expiry time based on the - session timeout duration specified during initialization. + This session manager extends `SessionManager` by issuing JSON Web Tokens (JWT) for authentication. + An *access* token is short‑lived and is used for authorizing API calls, while a *refresh* token + is long‑lived and can be exchanged for a new access token when the original expires. + The manager keeps track of active access tokens to enforce a maximum number of concurrent + sessions per application instance. Tokens are signed with a secret key that is either supplied by + the caller or generated automatically. All tokens conform to RFC7519 and contain the standard + registered claims (`iss`, `sub`, `aud`, `exp`, `iat`, `jti`) plus a custom ``type`` claim that + identifies the token as ``'access'`` or ``'refresh'``. Attributes ---------- @@ -408,33 +538,80 @@ class JWTSessionManager(SessionManager): - 'expires_at': datetime - Expiry time of the session. session_timeout : int Duration in seconds after which a session expires. + The default value (``900``) corresponds to 15 minutes. max_concurrent_sessions : int The maximum number of concurrent sessions to allow. - secret_key : str - The session manager secret key to use for authentication. logger : logging.Logger The logger to use for this session manager. + refresh_timeout : int + Lifetime of a refresh token in seconds. + The default value (``86400``) corresponds to 24 hours. + secret_key : str + Secret used for HS512 signing of JWTs. + If ``None`` is passed to the constructor, a cryptographically‑secure random key is generated. + refresh_tokens : dict + Mapping from refresh token identifier (``jti``) to a dictionary containing ``username``, + creation and expiration timestamps, token type and the associated access token identifier. + Used to validate and rotate refresh tokens. + _lock : threading.Lock + A locking mechanism to lock `self.session` dict for thread‑safe changes """ - def __init__(self, session_timeout: int = 3600, max_concurrent_sessions: int = 10, secret_key: str = None): + def __init__(self, session_timeout: int = 900, refresh_timeout: int = 86400, max_concurrent_sessions: int = 10, + secret_key: str = None, leeway: int = 2): """Manage user sessions with timeout. Parameters ---------- session_timeout : int, optional - The duration for which the session should be valid in seconds. - A session that exceeds this duration will be considered expired and removed. - Defaults to ``3600`` seconds. + The duration for which the access token should be valid in seconds. + Defaults to ``900`` seconds (15 minutes). + refresh_timeout : int, optional + The duration for which the refresh token should be valid in seconds. + Defaults to ``86400`` seconds (24 hours). max_concurrent_sessions : int, optional The maximum number of concurrent sessions to allow. Defaults to ``10``. - secret_key : str, optional - Secret key for JWT signing. If None, generates a random key. + secret_key : Union[bytes, str] + Secret used for HS512 signing of JWTs. + - If a ``bytes`` object is supplied, it must be ≥ 64 bytes. + - If a ``str`` (pass‑phrase) is supplied, it will be stretched with Argon2 + to produce a 64‑byte key. + - If ``None`` a fresh random 64‑byte key is generated. + leeway : int, optional + Allowed leeway, in seconds, after tokens expiration date, to accommodate for clock-skew. + Set it to `0` so that the token is considered expired immediately after its exp claim passes. + Defaults to ``2``. """ super().__init__(session_timeout, max_concurrent_sessions) - self.secret_key = secret_key or secrets.token_urlsafe(32) + self.refresh_timeout = refresh_timeout + self.leeway = leeway + self.refresh_tokens = {} # Track valid refresh tokens (jti -> session_info) + self._lock = RLock() # to lock `self.session` dict for thread‑safe changes + self.secret_key = self._init_secret_key(secret_key) - def _create_token(self, username, jti, exp_time, now): + def _init_secret_key(self, secret_key: str = None) -> bytes: + """Initialize or validate the secret key used for cryptographic operations. + + Parameters + ---------- + secret_key : Union[str, None] + Optional user‑provided secret key as a string. + When ``None`` a fresh random key is created. + + Returns + ------- + bytes + The secret key encoded as UTF‑8 bytes, or a newly generated random + key when no input is given. + + See Also + -------- + plantdb.commons.auth.session._init_secret_key + """ + return _init_secret_key(secret_key) + + def _create_token(self, username, jti, exp_time, now, token_type='access'): """Create a JSON Web Token (JWT) with registered claims. Generates and encodes a JWT using the provided username, unique identifier @@ -447,10 +624,13 @@ def _create_token(self, username, jti, exp_time, now): The subject of the JWT (user identifier). jti : str Unique identifier for the JWT. - exp_time : datetime.datetime + exp_time : datetime Expiration time of the JWT. - now : datetime.datetime + now : datetime Current time when the JWT is issued. + token_type : str, optional + The type of token to create ('access' or 'refresh'). + Defaults to 'access'. Returns ------- @@ -469,16 +649,18 @@ def _create_token(self, username, jti, exp_time, now): 'aud': 'plantdb-client', # audience 'exp': int(exp_time.timestamp()), # expiration time (Unix timestamp) 'iat': int(now.timestamp()), # issued at (Unix timestamp) - 'jti': jti # JWT ID (unique identifier) + 'jti': jti, # JWT ID (unique identifier) + 'type': token_type # Custom claim for token type } - return jwt.encode( + token_bytes = jwt.encode( payload, self.secret_key, algorithm='HS512', headers={'typ': 'JWT', 'alg': 'HS512'} ) + return token_bytes if isinstance(token_bytes, str) else token_bytes.decode('utf-8') - def create_session(self, username: str) -> Union[str, None]: + def create_session(self, username: str) -> Union[Tuple[str, str], None]: """Create a new session for a user. If the user already has an active session, it returns the existing session ID. @@ -491,12 +673,12 @@ def create_session(self, username: str) -> Union[str, None]: Returns ------- - session_id : Union[str, None] - The ID of the created or existing session. + Tuple[str, str] or None + A tuple containing (access_token, refresh_token) if successful, ``None`` otherwise. Notes ----- - Creates a JSON Web Token following RFC 7519 standards with registered claims: + Creates JSON Web Tokens following RFC 7519 standards with registered claims: - iss (issuer): Identifies the token issuer - sub (subject): The username of the authenticated user - aud (audience): Intended audience for the token @@ -506,30 +688,47 @@ def create_session(self, username: str) -> Union[str, None]: """ if self.n_active_sessions() >= self.max_concurrent_sessions: self.logger.warning( - f"Too any users currently active, reached max concurrent sessions limit ({self.max_concurrent_sessions})") + f"Too many users currently active, reached max concurrent sessions limit ({self.max_concurrent_sessions})") return None - # Create a JWT payload with registered claims - now = datetime.now() - exp_time = now + timedelta(seconds=self.session_timeout) - jti = secrets.token_urlsafe(16) # unique token ID for tracking + # Create an access token + now = datetime.now(timezone.utc) + access_exp = now + timedelta(seconds=self.session_timeout) + access_jti = secrets.token_urlsafe(16) + + # Create a refresh token + refresh_exp = now + timedelta(seconds=self.refresh_timeout) + refresh_jti = secrets.token_urlsafe(16) try: - # Generate JSON Web Token - jwt_token = self._create_token(username, jti, exp_time, now) + # Generate JSON Web Tokens + access_token = self._create_token(username, access_jti, access_exp, now, token_type='access') + refresh_token = self._create_token(username, refresh_jti, refresh_exp, now, token_type='refresh') except Exception as e: - self.logger.error(f"Failed to create JSON Web Token for {username}: {e}") + self.logger.error(f"Failed to create JSON Web Tokens for {username}: {e}") return None - # Track session for concurrent limit enforcement - self.sessions[jti] = { - 'username': username, - 'created_at': now, - 'last_accessed': now, - 'expires_at': exp_time - } - self.logger.debug(f"Created JSON Web Token for '{username}'") - return jwt_token + with self._lock: + # Track access session for concurrent‑limit enforcement + self.sessions[access_jti] = { + 'username': username, + 'created_at': now, + 'last_accessed': now, + 'expires_at': access_exp, + 'type': 'access' + } + + # Track refresh token + self.refresh_tokens[refresh_jti] = { + 'username': username, + 'created_at': now, + 'expires_at': refresh_exp, + 'type': 'refresh', + 'access_jti': access_jti + } + + self.logger.debug(f"Created session for '{username}'") + return access_token, refresh_token def _payload_from_token(self, token: str) -> dict: """Decode the payload from a JSON Web Token. @@ -562,16 +761,21 @@ def _payload_from_token(self, token: str) -> dict: self.secret_key, algorithms=['HS512'], audience='plantdb-client', # Verify audience - issuer='plantdb-api' # Verify issuer + issuer='plantdb-api', # Verify issuer, + options={"require": ["exp", "iat", "iss", "aud"]}, # force the presence of these claims + leeway=self.leeway # allowed clock skew, in seconds ) - def validate_session(self, token: str) -> Optional[Dict[str, Any]]: + def validate_session(self, token: str, token_type: str = 'access') -> Optional[Dict[str, Any]]: """Validate a JSON Web Token and return user information. Parameters ---------- token : str The JSON Web Token to validate. + token_type : str, optional + The expected token type ('access' or 'refresh'). + Defaults to 'access'. Returns ------- @@ -584,31 +788,46 @@ def validate_session(self, token: str) -> Optional[Dict[str, Any]]: - jti: Unique token identifier - issuer: Token issuer - audience: Token audience + - type: Token type """ + # Decode and verify JSON Web Token with proper validation try: - # Decode and verify JSON Web Token with proper validation payload = self._payload_from_token(token) - - except jwt.ExpiredSignatureError: - self.logger.error("JSON Web Token expired") - return None - except jwt.InvalidAudienceError: + except jwt.ExpiredSignatureError as e: + self.logger.error(f"JSON Web Token ({token_type}) expired") + raise SessionValidationError(e) from e + except jwt.InvalidAudienceError as e: self.logger.error("JSON Web Token has invalid audience") - return None - except jwt.InvalidIssuerError: + raise SessionValidationError(e) from e + except jwt.InvalidIssuerError as e: self.logger.error("JSON Web Token has invalid issuer") - return None + raise SessionValidationError(e) from e except jwt.InvalidTokenError as e: self.logger.error(f"Invalid JSON Web Token: {e}") - return None + raise SessionValidationError(e) from e except Exception as e: self.logger.error(f"Error validating JSON Web Token: {e}") - return None + raise InvalidTokenProcessingError(e) from e + + # Check token type + if payload.get('type') != token_type: + self.logger.error(f"Invalid token type: expected {token_type}, got {payload.get('type')}") + raise WrongTokenType(f"Invalid token type: {token_type}") - # Update last accessed time in session tracking jti = payload.get('jti') - if jti and jti in self.sessions: - self.sessions[jti]['last_accessed'] = datetime.now() + + # Verify it's in our tracking list + if token_type == 'access': + if jti not in self.sessions: + self.logger.error("Access token not found in active sessions") + raise AccessTokenNotFoundError(f"Access token jti={jti} not found") + # Update last accessed time + with self._lock: + self.sessions[jti]['last_accessed'] = datetime.now(timezone.utc) + elif token_type == 'refresh': + if jti not in self.refresh_tokens: + self.logger.error("Refresh token not found in active refresh tokens") + raise RefreshTokenNotFoundError(f"Refresh token jti={jti} not found") return { 'username': payload['sub'], # subject is the username @@ -616,7 +835,8 @@ def validate_session(self, token: str) -> Optional[Dict[str, Any]]: 'expires_at': payload['exp'], # expiration timestamp 'jti': jti, # JWT ID 'issuer': payload['iss'], # issuer - 'audience': payload['aud'] # audience + 'audience': payload['aud'], # audience + 'type': payload.get('type') # type of token, 'access' or 'refresh' } def invalidate_session(self, token: str = None, jti: str = None) -> Tuple[bool, str | None]: @@ -640,25 +860,57 @@ def invalidate_session(self, token: str = None, jti: str = None) -> Tuple[bool, try: payload = self._payload_from_token(token) jti = payload.get('jti') - except: + token_type = payload.get('type', 'access') + except jwt.PyJWTError as e: + self.logger.error(f"Failed to decode token for invalidation: {e}") return False, None - - if jti and jti in self.sessions: - username = self.sessions[jti]['username'] - del self.sessions[jti] - return True, username + except KeyError as e: + self.logger.error(f"Failed to access payload key: {e}") + return False, None + else: + # If jti is provided, we need to know its type or check both + token_type = None + + with self._lock: + if token_type == 'access' or token_type is None: + if jti and jti in self.sessions: + username = self.sessions[jti]['username'] + del self.sessions[jti] + # Also invalidate the linked refresh token if any + refresh_jtis = [rj for rj, rs in self.refresh_tokens.items() if rs.get('access_jti') == jti] + for rj in refresh_jtis: + del self.refresh_tokens[rj] + return True, username + + if token_type == 'refresh' or token_type is None: + if jti and jti in self.refresh_tokens: + username = self.refresh_tokens[jti]['username'] + # Optionally invalidate the linked access token? + # Usually we just invalidate the refresh token. + del self.refresh_tokens[jti] + return True, username return False, None def cleanup_expired_sessions(self) -> None: """Remove expired sessions from tracking.""" - current_time = datetime.now() - expired_sessions = [ - jti for jti, session in self.sessions.items() - if current_time > session['expires_at'] - ] - for jti in expired_sessions: - del self.sessions[jti] + current_time = datetime.now(timezone.utc) + + with self._lock: + expired_access = [ + jti for jti, session in self.sessions.items() + if current_time > session['expires_at'] + ] + for jti in expired_access: + del self.sessions[jti] + + expired_refresh = [ + jti for jti, session in self.refresh_tokens.items() + if current_time > session['expires_at'] + ] + for jti in expired_refresh: + del self.refresh_tokens[jti] + return def session_username(self, token: str) -> Optional[str]: @@ -674,30 +926,36 @@ def session_username(self, token: str) -> Optional[str]: str or None Username if token is valid. """ - session_data = self.validate_session(token) - return session_data['username'] if session_data else None + try: + session_data = self.validate_session(token) + except SessionValidationError as e: + self.logger.warning(f"Provided session does not exist: {e}") + return None + return session_data['username'] - def refresh_session(self, token: str) -> Optional[str]: - """Refresh a JSON Web Token if it's still valid. + def refresh_session(self, refresh_token: str) -> Tuple[str, str]: + """Refresh a session using a valid refresh token. Parameters ---------- - token : str - Current JSON Web Token. + refresh_token : str + The refresh token to use. Returns ------- - str or None - New JSON Web Token if refresh is successful. + Tuple[str, str] + A tuple containing (new_access_token, new_refresh_token) if successful. """ - session_data = self.validate_session(token) - if not session_data: - return None + # Validate the refresh token – will raise if the refresh token is revoked or malformed + session_data = self.validate_session(refresh_token, token_type='refresh') - # Invalidate old session - old_jti = session_data['jti'] - self.invalidate_session(jti=old_jti) - - # Create a new session username = session_data['username'] + old_refresh_jti = session_data['jti'] + old_access_jti = self.refresh_tokens[old_refresh_jti].get('access_jti') + # Invalidate old tokens (Rotation) + self.invalidate_session(jti=old_refresh_jti) + if old_access_jti: + self.invalidate_session(jti=old_access_jti) + + # Create a new session (new access + new refresh) return self.create_session(username) diff --git a/src/commons/plantdb/commons/fsdb/core.py b/src/commons/plantdb/commons/fsdb/core.py index b116c037..5fb19250 100644 --- a/src/commons/plantdb/commons/fsdb/core.py +++ b/src/commons/plantdb/commons/fsdb/core.py @@ -595,7 +595,8 @@ def get_scans(self, query=None, **kwargs) -> List: for scan_id, scan in self.scans.items(): try: - metadata = scan.get_metadata() + # Access metadata directly to avoid nested lock acquisition + metadata = _get_metadata(scan.metadata, None, {}) if self.rbac_manager.can_access_scan(current_user, metadata, Permission.READ): accessible_scans[scan_id] = scan except Exception as e: @@ -658,7 +659,8 @@ def get_scan(self, scan_id, **kwargs): raise Exception("No valid user!") scan = self.scans[scan_id] - metadata = scan.get_metadata() + # Access metadata directly to avoid nested lock acquisition + metadata = _get_metadata(scan.metadata, None, {}) if self.rbac_manager.can_access_scan(current_user, metadata, Permission.READ): # Use shared lock for read operations with self.lock_manager.acquire_lock(scan_id, LockType.SHARED, current_user.username): @@ -821,9 +823,11 @@ def delete_scan(self, scan_id, **kwargs) -> bool: if not self.scan_exists(scan_id): raise ValueError(f"Scan '{scan_id}' does not exist!") - # Check DELETE permission for this specific scan + # Check DELETE permission for this specific scan using its metadata scan = self.scans[scan_id] - if not self.rbac_manager.can_access_scan(current_user, scan.get_metadata(), Permission.DELETE): + # Access metadata directly to avoid nested lock acquisition + metadata = _get_metadata(scan.metadata, None, {}) + if not self.rbac_manager.can_access_scan(current_user, metadata, Permission.DELETE): raise PermissionError( f"Insufficient permissions to delete '{scan_id}' scan as '{current_user.username}' user!") @@ -1047,7 +1051,7 @@ def login(self, username: str, password: str, **kwargs) -> Optional[str]: return None @require_token - def logout(self, **kwargs) -> bool: + def logout(self, **kwargs) -> tuple[bool, str]: """Log out a user by invalidating its session. Examples @@ -1057,16 +1061,16 @@ def logout(self, **kwargs) -> bool: INFO [FSDB] Successfully logged in as 'admin'. >>> db.logout() INFO [FSDB] Successfully logged out from 'admin'. - True + (True, 'admin') >>> db.disconnect() """ success, username = self.session_manager.invalidate_session(kwargs.get('token', None)) if success: self.logger.info(f"Successfully logged out from '{username}'.") - return True + return success, username else: self.logger.warning(f"Failed to logout!") - return False + return success, username @require_authentication def create_user(self, new_username, fullname, password, roles=None, **kwargs) -> None: @@ -1921,7 +1925,8 @@ def set_metadata(self, data, value=None, **kwargs): raise PermissionError("No authenticated user!") # Get current metadata for validation - old_metadata = self.get_metadata() + # Access scan metadata directly to avoid nested lock acquisition + old_metadata = _get_metadata(self.metadata, None, {}) if isinstance(data, str): if value is None: @@ -1935,7 +1940,7 @@ def set_metadata(self, data, value=None, **kwargs): else: new_metadata = data - # Validate metadata changes + # Validate scan metadata accessibility to current user if not self.db.rbac_manager.validate_scan_metadata_access(current_user, old_metadata, new_metadata): raise PermissionError(f"Insufficient permissions to modify scan '{self.id}' metadata!") @@ -2014,8 +2019,10 @@ def create_fileset(self, fs_id, metadata=None, **kwargs): if not current_user: raise PermissionError("No authenticated user!") - # Check WRITE permission for this fileset - if not self.db.rbac_manager.can_access_scan(current_user, self.get_metadata(), Permission.WRITE): + # Access scan metadata directly to avoid nested lock acquisition + metadata = _get_metadata(self.metadata, None, {}) + # Check WRITE permission for this fileset using scan metadata + if not self.db.rbac_manager.can_access_scan(current_user, metadata, Permission.WRITE): raise PermissionError(f"Insufficient permissions to create a fileset in the '{self.id}' scan!") # Verify if the given `fs_id` is valid @@ -2086,8 +2093,10 @@ def delete_fileset(self, fs_id, **kwargs) -> None: if not current_user: raise PermissionError("No authenticated user!") - # Check DELETE permission for this fileset - if not self.db.rbac_manager.can_access_scan(current_user, self.get_metadata(), Permission.DELETE): + # Access scan metadata directly to avoid nested lock acquisition + metadata = _get_metadata(self.metadata, None, {}) + # Check DELETE permission for this fileset using scan metadata + if not self.db.rbac_manager.can_access_scan(current_user, metadata, Permission.DELETE): raise PermissionError( f"Insufficient permissions to delete filesets from the '{self.id}' scan as '{current_user.username}' user!") @@ -2385,8 +2394,10 @@ def set_metadata(self, data, value=None, **kwargs): if not current_user: raise PermissionError("No authenticated user!") + # Access scan metadata directly to avoid nested lock acquisition + metadata = _get_metadata(self.scan.metadata, None, {}) # Check WRITE permission for this fileset - if not self.db.rbac_manager.can_access_scan(current_user, self.scan.get_metadata(), Permission.WRITE): + if not self.db.rbac_manager.can_access_scan(current_user, metadata, Permission.WRITE): raise PermissionError(f"Insufficient permissions to edit the '{self.scan.id}/{self.id}' fileset metadata!") # Use exclusive lock for this operation @@ -2452,8 +2463,10 @@ def create_file(self, f_id, metadata=None, **kwargs): if not current_user: raise PermissionError("No authenticated user!") + # Access scan metadata directly to avoid nested lock acquisition + metadata = _get_metadata(self.scan.metadata, None, {}) # Check WRITE permission for this file - if not self.db.rbac_manager.can_access_scan(current_user, self.scan.get_metadata(), Permission.WRITE): + if not self.db.rbac_manager.can_access_scan(current_user, metadata, Permission.WRITE): raise PermissionError( f"Insufficient permissions to create a file in the '{self.scan.id}' scan as '{current_user.username}' user!") @@ -2529,14 +2542,16 @@ def delete_file(self, f_id, **kwargs): if not current_user: raise PermissionError("No authenticated user!") + # Access scan metadata directly to avoid nested lock acquisition + metadata = _get_metadata(self.scan.metadata, None, {}) # Check DELETE permission for this fileset - if not self.db.rbac_manager.can_access_scan(current_user, self.scan.get_metadata(), Permission.DELETE): + if not self.db.rbac_manager.can_access_scan(current_user, metadata, Permission.DELETE): raise PermissionError( f"Insufficient permissions to delete the files from the '{self.scan.id}' scan as '{current_user.username}' user!") # Verify if the given `fs_id` exists in the local database if not self.file_exists(f_id): - raise ValueError(f"File '{f_id}' does not exist in scan '{self.id}'") + raise ValueError(f"File '{f_id}' does not exist in '{self.scan.id}/{self.id}'") # Use exclusive lock for fileset creation self.logger.info(f"Deleting file '{f_id}' from '{self.scan.id}/{self.id}' as '{current_user.username}' user...") diff --git a/src/commons/plantdb/commons/fsdb/lock.py b/src/commons/plantdb/commons/fsdb/lock.py index ee05b819..60c73ce9 100644 --- a/src/commons/plantdb/commons/fsdb/lock.py +++ b/src/commons/plantdb/commons/fsdb/lock.py @@ -75,9 +75,8 @@ def __str__(self) -> str: return self.message -class ScanLockManager: - """ - Acquires and releases file-based locks for thread-safe resource management. +class ScanLockManager : + """Acquires and releases file-based locks for thread-safe resource management. This class provides functionality for acquiring and releasing file-based locks, ensuring thread-safe operations across multiple threads or processes. Locks are @@ -140,6 +139,11 @@ def __init__(self, base_path: str, default_timeout: float = 30.0, **kwargs): self._lock_files: Dict[str, int] = {} # File descriptors for locks self._thread_lock = threading.RLock() # Thread-safe operations + # Store the last time we emitted a warning per scan_id + self._warning_timestamps: Dict[str, float] = {} + # How long we wait before emitting the same warning again (seconds) + self._warning_debounce_interval: float = kwargs.get("warning_debounce_interval", 5.0) + # Test write capability in base_path try: test_file = os.path.join(base_path, '.write_test.tmp') @@ -324,7 +328,9 @@ def acquire_lock(self, scan_id: str, lock_type: LockType, user: str, timeout: Op try: lock_file_path = self._get_lock_file_path(scan_id) + attempt = 0 while time.time() - start_time < timeout: + attempt += 1 try: # Open lock file lock_fd = os.open(lock_file_path, os.O_CREAT | os.O_WRONLY | os.O_TRUNC) @@ -337,6 +343,8 @@ def acquire_lock(self, scan_id: str, lock_type: LockType, user: str, timeout: Op # Try to acquire the lock (non-blocking) fcntl.flock(lock_fd, lock_flags) + # SUCCESS – clear any stale warning timestamp for this scan_id + self._warning_timestamps.pop(scan_id, None) # Lock acquired successfully with self._thread_lock: @@ -350,7 +358,10 @@ def acquire_lock(self, scan_id: str, lock_type: LockType, user: str, timeout: Op self._write_lock_info(scan_id, lock_type, user) acquired = True - self.logger.debug(f"Successfully acquired {lock_type.value} lock for scan {scan_id}") + self.logger.debug( + f"Successfully acquired {lock_type.value} lock for scan {scan_id} " + f"(attempt {attempt})" + ) break except (OSError, IOError) as e: @@ -359,7 +370,20 @@ def acquire_lock(self, scan_id: str, lock_type: LockType, user: str, timeout: Op os.close(lock_fd) except: pass - self.logger.warning(f"Lock acquisition attempt failed for {scan_id}, retrying...") + + now = time.time() + last_warn = self._warning_timestamps.get(scan_id, 0.0) + if now - last_warn >= self._warning_debounce_interval: + self.logger.warning( + f"Lock acquisition attempt failed for {scan_id} (attempt {attempt}) – " + f"retrying... [pid={os.getpid()}, tid={threading.get_ident()}]" + ) + self._warning_timestamps[scan_id] = now + else: + # We’re within the debounce window – log at DEBUG instead of spamming WARN + self.logger.debug( + f"Retrying lock for {scan_id} (attempt {attempt}) – still waiting" + ) time.sleep(0.1) # Brief pause before retry if not acquired: diff --git a/src/commons/plantdb/commons/test_database.py b/src/commons/plantdb/commons/test_database.py index 5c0b56bc..96ff818c 100644 --- a/src/commons/plantdb/commons/test_database.py +++ b/src/commons/plantdb/commons/test_database.py @@ -231,9 +231,8 @@ def _test_hash(tmp_fname, hash_value, hash_method="md5"): return -def _get_archive(archive, force=False): - """ - Download and verify an archive file from a given URL. +def _get_archive(archive, force=False) : + """Download and verify an archive file from a given URL. This function retrieves an archive file from a specified URL. If the file already exists locally and `force` is not set to ``True``, it skips downloading again, diff --git a/src/commons/pyproject.toml b/src/commons/pyproject.toml index 64ade28e..7ddd7eba 100644 --- a/src/commons/pyproject.toml +++ b/src/commons/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ ] description = "Core shared library for the ROMI plant database ecosystem." readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.8, <3.13" # 3.13 upper bound for open3d dependency license = { 'text' = "LGPL-3.0-or-later" } # to replace by "LGPL-3.0-or-later" only in next release, pending deprecation authors = [ { name = "Peter Hanappe", email = "peter.hanappe@sony.com" }, @@ -77,7 +77,7 @@ doc = [ io = [ "networkx", "tifffile", - # "open3d >=0.9.0.0", # get it from conda + "open3d >=0.9.0.0", ] test = [ "nose2[coverage]", diff --git a/src/commons/tests/test_auth.py b/src/commons/tests/test_auth.py index 731bb2d9..05bf3a0f 100644 --- a/src/commons/tests/test_auth.py +++ b/src/commons/tests/test_auth.py @@ -4,7 +4,6 @@ import logging import os import tempfile -import time import unittest from datetime import datetime from datetime import timedelta @@ -17,7 +16,6 @@ from plantdb.commons.auth.models import Role from plantdb.commons.auth.models import User from plantdb.commons.auth.rbac import RBACManager -from plantdb.commons.auth.session import SessionManager class TestUserManager(unittest.TestCase): @@ -356,116 +354,6 @@ def test_create_group_without_permission(self): self.assertIsNone(group) -class TestSessionManager(unittest.TestCase): - """Test cases for SessionManager class""" - - def setUp(self): - """Set up test fixtures before each test method.""" - self.session_manager = SessionManager(session_timeout=3600) # 1 hour timeout - - def test_init_creates_empty_sessions_dict(self): - """Test that SessionManager initializes with empty sessions dictionary.""" - # Verify initial state is correct - self.assertEqual(len(self.session_manager.sessions), 0) - self.assertEqual(self.session_manager.session_timeout, 3600) - self.assertIsNotNone(self.session_manager.logger) - - def test_user_has_session_returns_consistent_id(self): - """Test that _user_has_session returns consistent session IDs for same user.""" - # Test session ID generation - session_id1 = self.session_manager._user_has_session("testuser") - session_id2 = self.session_manager._user_has_session("testuser") - - # Session IDs should be consistent for same user at same time - self.assertEqual(session_id1, session_id2) - - def test_create_session_stores_session_data(self): - """Test that create_session properly stores session with expiry time.""" - # Create session - session_id = self.session_manager.create_session("testuser") - - # Verify session was created - self.assertIn(session_id, self.session_manager.sessions) - session_data = self.session_manager.sessions[session_id] - self.assertEqual(session_data['username'], "testuser") - - # Should have created_at timestamp - self.assertIn('created_at', session_data) - self.assertIsInstance(session_data['created_at'], datetime) - - def test_validate_session_returns_true_for_valid_session(self): - """Test that validate_session returns True for non-expired sessions.""" - # Create session first - session_id = self.session_manager.create_session("testuser") - - # Validate immediately (should be valid) - is_valid = self.session_manager.validate_session(session_id) - self.assertTrue(is_valid) - - def test_validate_session_returns_false_for_expired_session(self): - """Test that validate_session returns False for expired sessions.""" - # Create a session with a very short timeout - temp_session_manager = SessionManager(session_timeout=0.1) # 0.1 second timeout - session_id = temp_session_manager.create_session("testuser") - - # Wait for session to expire - time.sleep(0.2) - - # Validate after timeout period - is_valid = temp_session_manager.validate_session(session_id) - self.assertFalse(is_valid) - - def test_validate_session_returns_false_for_invalid_session_id(self): - """Test that validate_session returns False for non-existent session IDs.""" - is_valid = self.session_manager.validate_session("invalid_session_id") - self.assertFalse(is_valid) - - def test_invalidate_session_removes_session(self): - """Test that invalidate_session removes session from storage.""" - # Create session first - session_id = self.session_manager.create_session("testuser") - self.assertIn(session_id, self.session_manager.sessions) - - # Invalidate session - self.session_manager.invalidate_session(session_id) - self.assertNotIn(session_id, self.session_manager.sessions) - - def test_cleanup_expired_sessions_removes_old_sessions(self): - """Test that cleanup_expired_sessions removes only expired sessions.""" - # Create a session manager with a very short timeout for testing - temp_session_manager = SessionManager(session_timeout=0.5) # 0.5 second timeout - - # Create first session - session1_id = temp_session_manager.create_session("user1") - - # Wait for a moment - time.sleep(0.6) # Wait for first session to expire - - # Create second session - session2_id = temp_session_manager.create_session("user2") - - # Cleanup expired sessions - temp_session_manager.cleanup_expired_sessions() - - # Only expired session should be removed - self.assertNotIn(session1_id, temp_session_manager.sessions) - self.assertIn(session2_id, temp_session_manager.sessions) - - def test_session_username_returns_correct_username(self): - """Test that session_username returns the correct username for valid session.""" - # Create session - session_id = self.session_manager.create_session("testuser") - - # Get username from session - username = self.session_manager.session_username(session_id) - self.assertEqual(username, "testuser") - - def test_session_username_returns_none_for_invalid_session(self): - """Test that session_username returns None for invalid session ID.""" - username = self.session_manager.session_username("invalid_session") - self.assertIsNone(username) - - if __name__ == '__main__': # Configure logging to suppress logs during testing logging.getLogger().setLevel(logging.CRITICAL) diff --git a/src/commons/tests/test_auth_session.py b/src/commons/tests/test_auth_session.py new file mode 100644 index 00000000..97d9fee0 --- /dev/null +++ b/src/commons/tests/test_auth_session.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import time +import unittest +from datetime import datetime +from datetime import timezone + +import jwt + +from plantdb.commons.auth.session import AccessTokenNotFoundError +from plantdb.commons.auth.session import JWTSessionManager +from plantdb.commons.auth.session import RefreshTokenNotFoundError +from plantdb.commons.auth.session import SessionManager +from plantdb.commons.auth.session import SessionValidationError +from plantdb.commons.auth.session import WrongTokenType + + +class TestSessionManager(unittest.TestCase): + """Test cases for SessionManager class""" + + def setUp(self): + """Set up test fixtures before each test method.""" + self.session_manager = SessionManager(session_timeout=3600) # 1 hour timeout + + def test_init_creates_empty_sessions_dict(self): + """Test that SessionManager initializes with empty sessions dictionary.""" + # Verify initial state is correct + self.assertEqual(len(self.session_manager.sessions), 0) + self.assertEqual(self.session_manager.session_timeout, 3600) + self.assertIsNotNone(self.session_manager.logger) + + def test_user_has_session_returns_consistent_id(self): + """Test that _user_has_session returns consistent session IDs for same user.""" + # Test session ID generation + session_id1 = self.session_manager._user_has_session("testuser") + session_id2 = self.session_manager._user_has_session("testuser") + + # Session IDs should be consistent for same user at same time + self.assertEqual(session_id1, session_id2) + + def test_create_session_stores_session_data(self): + """Test that create_session properly stores session with expiry time.""" + # Create session + session_id = self.session_manager.create_session("testuser") + + # Verify session was created + self.assertIn(session_id, self.session_manager.sessions) + session_data = self.session_manager.sessions[session_id] + self.assertEqual(session_data['username'], "testuser") + + # Should have created_at timestamp + self.assertIn('created_at', session_data) + self.assertIsInstance(session_data['created_at'], datetime) + + def test_validate_session_returns_true_for_valid_session(self): + """Test that validate_session returns True for non-expired sessions.""" + # Create session first + session_id = self.session_manager.create_session("testuser") + + # Validate immediately (should be valid) + is_valid = self.session_manager.validate_session(session_id) + self.assertTrue(is_valid) + + def test_validate_session_returns_false_for_expired_session(self): + """Test that validate_session returns False for expired sessions.""" + # Create a session with a very short timeout + temp_session_manager = SessionManager(session_timeout=0.1) # 0.1 second timeout + session_id = temp_session_manager.create_session("testuser") + + # Wait for session to expire + time.sleep(0.2) + + # Validate after timeout period + is_valid = temp_session_manager.validate_session(session_id) + self.assertFalse(is_valid) + + def test_validate_session_returns_false_for_invalid_session_id(self): + """Test that validate_session returns False for non-existent session IDs.""" + is_valid = self.session_manager.validate_session("invalid_session_id") + self.assertFalse(is_valid) + + def test_invalidate_session_removes_session(self): + """Test that invalidate_session removes session from storage.""" + # Create session first + session_id = self.session_manager.create_session("testuser") + self.assertIn(session_id, self.session_manager.sessions) + + # Invalidate session + self.session_manager.invalidate_session(session_id) + self.assertNotIn(session_id, self.session_manager.sessions) + + def test_cleanup_expired_sessions_removes_old_sessions(self): + """Test that cleanup_expired_sessions removes only expired sessions.""" + # Create a session manager with a very short timeout for testing + temp_session_manager = SessionManager(session_timeout=0.5) # 0.5 second timeout + + # Create first session + session1_id = temp_session_manager.create_session("user1") + + # Wait for a moment + time.sleep(0.6) # Wait for first session to expire + + # Create second session + session2_id = temp_session_manager.create_session("user2") + + # Cleanup expired sessions + temp_session_manager.cleanup_expired_sessions() + + # Only expired session should be removed + self.assertNotIn(session1_id, temp_session_manager.sessions) + self.assertIn(session2_id, temp_session_manager.sessions) + + def test_session_username_returns_correct_username(self): + """Test that session_username returns the correct username for valid session.""" + # Create session + session_id = self.session_manager.create_session("testuser") + + # Get username from session + username = self.session_manager.session_username(session_id) + self.assertEqual(username, "testuser") + + def test_session_username_returns_none_for_invalid_session(self): + """Test that session_username returns None for invalid session ID.""" + username = self.session_manager.session_username("invalid_session") + self.assertIsNone(username) + + +class TestJWTSessionManager(unittest.TestCase): + """Comprehensive unit tests for :class:`JWTSessionManager`.""" + + # ------------------------------------------------------------------ + # Helper methods + # ------------------------------------------------------------------ + def _create_manager(self, session_timeout=2, refresh_timeout=5, max_sessions=2): + """ + Create a ``JWTSessionManager`` with short lifetimes so tests can + trigger expiration without long sleeps. + """ + return JWTSessionManager( + session_timeout=session_timeout, + refresh_timeout=refresh_timeout, + max_concurrent_sessions=max_sessions, + secret_key="test-secret-key", # deterministic for reproducibility + leeway=0, # no leeway for token validation during tests + ) + + # ------------------------------------------------------------------ + # Construction / secret‑key handling + # ------------------------------------------------------------------ + def test_init_derives_key_from_passphrase(self): + """A string secret is stretched with Argon2 to a 64‑byte key.""" + mgr = JWTSessionManager(secret_key="my‑passphrase") + # The derived key must be exactly 64 bytes for HS512. + self.assertIsInstance(mgr.secret_key, bytes) + self.assertEqual(len(mgr.secret_key), 64) + + def test_init_fails_on_short_binary_key(self): + """Providing a binary key < 64 bytes raises ``ValueError``.""" + short_key = b"12345" + with self.assertRaises(ValueError): + JWTSessionManager(secret_key=short_key) + + # ------------------------------------------------------------------ + # Basic session creation + # ------------------------------------------------------------------ + def test_create_session_returns_two_tokens_and_registers_them(self): + """`create_session` should return an (access, refresh) tuple and store both.""" + mgr = self._create_manager() + result = mgr.create_session("alice") + + # Verify we got a 2- of strings + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + access_token, refresh_token = result + self.assertIsInstance(access_token, str) + self.assertIsInstance(refresh_token, str) + + # Decode payloads just to verify they contain the expected fields + access_payload = jwt.decode( + access_token, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + ) + refresh_payload = jwt.decode( + refresh_token, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + ) + self.assertEqual(access_payload["sub"], "alice") + self.assertEqual(refresh_payload["sub"], "alice") + self.assertEqual(access_payload["type"], "access") + self.assertEqual(refresh_payload["type"], "refresh") + + # The manager must have stored the JTI of both tokens + self.assertIn(access_payload["jti"], mgr.sessions) + self.assertIn(refresh_payload["jti"], mgr.refresh_tokens) + + # ------------------------------------------------------------------ + # Validation of access / refresh tokens + # ------------------------------------------------------------------ + def test_validate_access_token_returns_payload_and_updates_last_accessed(self): + """`validate_session` for an access token returns a dict and updates its timestamp.""" + mgr = self._create_manager() + access, _ = mgr.create_session("bob") + before = mgr.sessions[jwt.decode( + access, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + )["jti"]]["last_accessed"] + + info = mgr.validate_session(access) # defaults to token_type='access' + self.assertEqual(info["username"], "bob") + self.assertEqual(info["type"], "access") + # ``last_accessed`` must be later than the previous value + after = mgr.sessions[info["jti"]]["last_accessed"] + self.assertGreater(after, before) + + def test_validate_refresh_token_returns_payload_without_touching_access(self): + """`validate_session` for a refresh token returns a dict and does not affect access dict.""" + mgr = self._create_manager() + _, refresh = mgr.create_session("carol") + payload = jwt.decode( + refresh, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + ) + # Record the state of the related access token before validation + access_jti = mgr.refresh_tokens[payload["jti"]]["access_jti"] + access_last = mgr.sessions[access_jti]["last_accessed"] + + info = mgr.validate_session(refresh, token_type="refresh") + self.assertEqual(info["username"], "carol") + self.assertEqual(info["type"], "refresh") + # Access token's ``last_accessed`` must stay unchanged + self.assertEqual(mgr.sessions[access_jti]["last_accessed"], access_last) + + # ------------------------------------------------------------------ + # Token‑type mismatch handling + # ------------------------------------------------------------------ + def test_token_type_mismatch_raises_invalid_token_processing_error(self): + """Supplying a refresh token when ``token_type='access'`` raises ``WrongTokenType``.""" + mgr = self._create_manager() + _, refresh = mgr.create_session("dave") + with self.assertRaises(WrongTokenType): + mgr.validate_session(refresh, token_type="access") + + # ------------------------------------------------------------------ + # Expiration handling + # ------------------------------------------------------------------ + def test_access_token_expiration_raises_session_validation_error(self): + """An expired access token triggers ``SessionValidationError``.""" + mgr = self._create_manager(session_timeout=1, refresh_timeout=5) + access, _ = mgr.create_session("eve") + time.sleep(2) # exceed access token life + with self.assertRaises(SessionValidationError): + mgr.validate_session(access) + + def test_refresh_token_expiration_raises_session_validation_error(self): + """An expired refresh token triggers ``SessionValidationError``.""" + mgr = self._create_manager(session_timeout=5, refresh_timeout=1) + _, refresh = mgr.create_session("frank") + time.sleep(2) # exceed refresh token life + with self.assertRaises(SessionValidationError): + mgr.validate_session(refresh, token_type="refresh") + + # ------------------------------------------------------------------ + # Malformed / audience / issuer errors + # ------------------------------------------------------------------ + def test_malformed_token_raises_session_validation_error(self): + """A syntactically invalid JWT raises ``SessionValidationError``.""" + mgr = self._create_manager() + malformed = "not.a.valid.jwt" + with self.assertRaises(SessionValidationError): + mgr.validate_session(malformed) + + def test_invalid_audience_raises_session_validation_error(self): + """A token with a wrong audience raises ``SessionValidationError``.""" + mgr = self._create_manager() + # Build a token with a different audience + now = datetime.now(timezone.utc) + payload = { + "iss": "plantdb-api", + "sub": "george", + "aud": "wrong-audience", + "exp": int((now.timestamp() + 10)), + "iat": int(now.timestamp()), + "jti": "dummy-jti", + "type": "access", + } + bad_token = jwt.encode(payload, mgr.secret_key, algorithm="HS512") + with self.assertRaises(SessionValidationError): + mgr.validate_session(bad_token) + + def test_invalid_issuer_raises_session_validation_error(self): + """A token with a wrong issuer raises ``SessionValidationError``.""" + mgr = self._create_manager() + now = datetime.now(timezone.utc) + payload = { + "iss": "wrong-issuer", + "sub": "harry", + "aud": "plantdb-client", + "exp": int((now.timestamp() + 10)), + "iat": int(now.timestamp()), + "jti": "dummy-jti-2", + "type": "access", + } + bad_token = jwt.encode(payload, mgr.secret_key, algorithm="HS512") + with self.assertRaises(SessionValidationError): + mgr.validate_session(bad_token) + + # ------------------------------------------------------------------ + # Missing token look‑ups + # ------------------------------------------------------------------ + def test_missing_access_token_raises_access_token_not_found_error(self): + """If the access JTI is absent from ``sessions`` an ``AccessTokenNotFoundError`` is raised.""" + mgr = self._create_manager() + access, _ = mgr.create_session("gina") + payload = jwt.decode( + access, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + ) + del mgr.sessions[payload["jti"]] # simulate loss + with self.assertRaises(AccessTokenNotFoundError): + mgr.validate_session(access) + + def test_missing_refresh_token_raises_refresh_token_not_found_error(self): + """If the refresh JTI is absent from ``refresh_tokens`` an ``RefreshTokenNotFoundError`` is raised.""" + mgr = self._create_manager() + _, refresh = mgr.create_session("hank") + payload = jwt.decode( + refresh, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + ) + del mgr.refresh_tokens[payload["jti"]] # simulate loss + with self.assertRaises(RefreshTokenNotFoundError): + mgr.validate_session(refresh, token_type="refresh") + + # ------------------------------------------------------------------ + # Invalidation paths + # ------------------------------------------------------------------ + def test_invalidate_access_by_jti_removes_linked_refresh(self): + """Invalidating an access token via JTI also removes its paired refresh token.""" + mgr = self._create_manager() + access, refresh = mgr.create_session("ivy") + access_jti = jwt.decode( + access, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + )["jti"] + refresh_jti = jwt.decode( + refresh, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + )["jti"] + + success, username = mgr.invalidate_session(jti=access_jti) + self.assertTrue(success) + self.assertEqual(username, "ivy") + self.assertNotIn(access_jti, mgr.sessions) + self.assertNotIn(refresh_jti, mgr.refresh_tokens) + + def test_invalidate_refresh_by_jti_keeps_access_intact(self): + """Invalidating a refresh token via JTI does not delete the associated access token.""" + mgr = self._create_manager() + access, refresh = mgr.create_session("jack") + access_jti = jwt.decode( + access, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + )["jti"] + refresh_jti = jwt.decode( + refresh, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + )["jti"] + + success, username = mgr.invalidate_session(jti=refresh_jti) + self.assertTrue(success) + self.assertEqual(username, "jack") + self.assertIn(access_jti, mgr.sessions) + self.assertNotIn(refresh_jti, mgr.refresh_tokens) + + def test_invalidate_unknown_jti_returns_false(self): + """Calling ``invalidate_session`` with a non‑existent JTI yields ``(False, None)``.""" + mgr = self._create_manager() + result = mgr.invalidate_session(jti="non‑existent-jti") + self.assertEqual(result, (False, None)) + + # ------------------------------------------------------------------ + # Refresh flow (rotation) and reuse detection + # ------------------------------------------------------------------ + def test_refresh_session_rotates_tokens_and_invalidates_old_ones(self): + """`refresh_session` returns fresh tokens and removes the previous pair.""" + mgr = self._create_manager() + old_access, old_refresh = mgr.create_session("kate") + old_access_jti = jwt.decode( + old_access, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + )["jti"] + old_refresh_jti = jwt.decode( + old_refresh, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + )["jti"] + + new_access, new_refresh = mgr.refresh_session(old_refresh) + + # Old tokens must be gone + self.assertNotIn(old_access_jti, mgr.sessions) + self.assertNotIn(old_refresh_jti, mgr.refresh_tokens) + + # New tokens must be present + new_access_jti = jwt.decode( + new_access, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + )["jti"] + new_refresh_jti = jwt.decode( + new_refresh, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + )["jti"] + self.assertIn(new_access_jti, mgr.sessions) + self.assertIn(new_refresh_jti, mgr.refresh_tokens) + + def test_refresh_token_reuse_raises_refresh_token_reuse_error(self): + """Attempting to reuse a refresh token after rotation raises ``RefreshTokenReuseError``.""" + mgr = self._create_manager() + _, refresh = mgr.create_session("linda") + # First rotation – should succeed + new_access, new_refresh = mgr.refresh_session(refresh) + self.assertIsNotNone(new_access) + self.assertIsNotNone(new_refresh) + + # Second attempt with the *old* refresh token must fail + with self.assertRaises(RefreshTokenNotFoundError): + mgr.refresh_session(refresh) + + # ------------------------------------------------------------------ + # Automatic cleanup of expired entries + # ------------------------------------------------------------------ + def test_cleanup_expired_sessions_purges_only_expired_tokens(self): + """`cleanup_expired_sessions` removes expired items but keeps still‑valid ones.""" + mgr = self._create_manager(session_timeout=1, refresh_timeout=5) + # Create two sessions: one that will expire, one that stays alive + short_access, short_refresh = mgr.create_session("mia") + short_jti = jwt.decode( + short_access, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + )["jti"] + + time.sleep(1.5) # let the access token expire, refresh still alive + long_access, long_refresh = mgr.create_session("nora") + long_jti = jwt.decode( + long_access, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + )["jti"] + + # Run cleanup – only the first access token should disappear + mgr.cleanup_expired_sessions() + self.assertNotIn(short_jti, mgr.sessions) + # The second access token must still be present + self.assertIn(long_jti, mgr.sessions) + + # Refresh token of the first session is still valid (refresh_timeout = 3 s) + short_refresh_jti = jwt.decode( + short_refresh, + mgr.secret_key, + algorithms=["HS512"], + audience="plantdb-client", + issuer="plantdb-api", + )["jti"] + self.assertIn(short_refresh_jti, mgr.refresh_tokens) + + # ------------------------------------------------------------------ + # Concurrency limit enforcement + # ------------------------------------------------------------------ + def test_exceeding_max_concurrent_sessions_returns_none(self): + """When the max‑session limit is reached, ``create_session`` returns ``None``.""" + mgr = self._create_manager(max_sessions=2) + self.assertIsNotNone(mgr.create_session("oliver")) + self.assertIsNotNone(mgr.create_session("peter")) + # Third attempt should be rejected + self.assertIsNone(mgr.create_session("quinn")) + + # ------------------------------------------------------------------ + # session_username helper + # ------------------------------------------------------------------ + def test_session_username_returns_username_for_valid_token(self): + """`session_username` extracts the username from a valid access token.""" + mgr = self._create_manager() + access, _ = mgr.create_session("rachel") + self.assertEqual(mgr.session_username(access), "rachel") + + def test_session_username_returns_none_for_invalid_token(self): + """`session_username` returns ``None`` when token validation fails.""" + mgr = self._create_manager() + malformed = "invalid.token.parts" + self.assertIsNone(mgr.session_username(malformed)) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/commons/tests/test_fsdb_lock.py b/src/commons/tests/test_fsdb_lock.py index 59509f1c..eec273c6 100644 --- a/src/commons/tests/test_fsdb_lock.py +++ b/src/commons/tests/test_fsdb_lock.py @@ -4,16 +4,19 @@ import fcntl import json import os +import shutil import tempfile import threading import time import unittest from unittest.mock import patch +from plantdb.commons.fsdb.core import FSDB from plantdb.commons.fsdb.lock import LockError from plantdb.commons.fsdb.lock import LockTimeoutError from plantdb.commons.fsdb.lock import LockType from plantdb.commons.fsdb.lock import ScanLockManager +from plantdb.commons.test_database import setup_test_database class TestLockType(unittest.TestCase): @@ -455,5 +458,79 @@ def acquire_exclusive_lock(user): self.assertEqual(len(errors), 4) # The other 4 should have timed out +class TestFSDBLock(unittest.TestCase): + """Test suite for verifying permission handling in FSDB operations. + + The tests exercise deletion of files, filesets, and scans under different user roles. + A temporary FSDB directory is created for each test case. + + Attributes + ---------- + fsdb_dir : str + Path to the temporary FSDB directory created for the test suite. + """ + def setUp(self): + """Set up a temporary directory for testing.""" + self.fsdb_dir = setup_test_database(['real_plant', 'real_plant_analyzed'], db_path=None) + + def tearDown(self): + """Clean up the temporary directory after tests.""" + shutil.rmtree(self.fsdb_dir) + + def test_delete_file_without_permission(self): + """Test that attempting to delete a file without sufficient permissions raises ``PermissionError``.""" + db = FSDB(self.fsdb_dir) + db.connect() + db.login("guest", "guest") + scan = db.get_scan('real_plant_analyzed') + fs = scan.get_fileset('Masks_1__0__1__0____channel____rgb_5619aa428d') + + with self.assertRaises(PermissionError): + fs.delete_file("00000_rgb.png") + + def test_delete_file_with_permission(self): + """Test deletion of a file when the user has sufficient permissions.""" + db = FSDB(self.fsdb_dir) + db.connect() + db.login("admin", "admin") + scan = db.get_scan('real_plant_analyzed') + fs = scan.get_fileset('Masks_1__0__1__0____channel____rgb_5619aa428d') + fs.delete_file("00000_rgb") + + def test_delete_fileset_without_permission(self): + """Test deletion of a fileset without sufficient permissions.""" + db = FSDB(self.fsdb_dir) + db.connect() + db.login("guest", "guest") + scan = db.get_scan('real_plant_analyzed') + + with self.assertRaises(PermissionError): + scan.delete_fileset('Masks_1__0__1__0____channel____rgb_5619aa428d') + + def test_delete_fileset_with_permission(self): + """Delete a fileset from a scan when the user has sufficient permissions.""" + db = FSDB(self.fsdb_dir) + db.connect() + db.login("admin", "admin") + scan = db.get_scan('real_plant_analyzed') + scan.delete_fileset('Masks_1__0__1__0____channel____rgb_5619aa428d') + + def test_delete_scan_without_permission(self): + """Test that attempting to delete a scan without sufficient permissions raises ``PermissionError``.""" + db = FSDB(self.fsdb_dir) + db.connect() + db.login("guest", "guest") + + with self.assertRaises(PermissionError): + db.delete_scan('real_plant_analyzed') + + def test_delete_scan_with_permission(self): + """Test that an administrator can delete an existing scan.""" + db = FSDB(self.fsdb_dir) + db.connect() + db.login("admin", "admin") + db.delete_scan('real_plant_analyzed') + + if __name__ == '__main__': unittest.main() diff --git a/src/server/plantdb/server/cli/fsdb_rest_api.py b/src/server/plantdb/server/cli/fsdb_rest_api.py index 35bd8261..8b9b0eb7 100644 --- a/src/server/plantdb/server/cli/fsdb_rest_api.py +++ b/src/server/plantdb/server/cli/fsdb_rest_api.py @@ -43,8 +43,11 @@ - ``ROMI_DB``: Path to the directory containing the FSDB. Default: '/myapp/db' (container) - ``PLANTDB_API_PREFIX``: Prefix for the REST API URL. Default is empty. - ``PLANTDB_API_SSL``: Enable SSL to use an HTTPS scheme. Default is `False`. -- ``FLASK_SECRET_KEY``: The secret key to use with flask. Default to random (32 bits secret). -- ``JWT_SECRET_KEY``: The secret key to use with JSON Web Token generator. Default to random (32 bits secret). +- ``FLASK_SECRET_KEY``: The secret key to use with flask. Default to random (64 bits secret). +- ``JWT_SECRET_KEY``: The secret key to use with JSON Web Token generator. Default to random (64 bits secret). +- ``SESSION_TIMEOUT``: Session JWT validity duration in seconds. Default `900` seconds (15 min). +- ``REFRESH_TIMEOUT``: Refresh JWT validity duration in seconds. Default `86400` seconds (1 day). +- ``MAX_SESSION``: The maximum number of concurrent sessions to allow. Default `10`. Usage Examples -------------- @@ -60,13 +63,6 @@ python fsdb_rest_api.py --test --debug ``` -RESTful endpoints include: -- `/scans`: List all scans available in the database. -- `/files/`: Retrieve files from the database. -- `/image///`: Access specific images. -- `/pointcloud///`: Access specific point clouds. -- `/mesh///`: Retrieve related meshes. - For detailed command-line parameters, use the `--help` flag: ```shell python fsdb_rest_api.py --help @@ -77,7 +73,6 @@ import atexit import logging import os -import secrets import shutil import sys from pathlib import Path @@ -91,6 +86,7 @@ from werkzeug.middleware.proxy_fix import ProxyFix from plantdb.commons.auth.session import JWTSessionManager +from plantdb.commons.auth.session import _init_secret_key from plantdb.commons.fsdb.core import FSDB from plantdb.commons.log import DEFAULT_LOG_LEVEL from plantdb.commons.log import LOG_LEVELS @@ -128,8 +124,7 @@ def parsing() -> argparse.ArgumentParser: - """ - Create and configure an argument parser for a REST API server. + """Create and configure an argument parser for a REST API server. Returns ------- @@ -166,8 +161,7 @@ def parsing() -> argparse.ArgumentParser: def _get_env_secret(var_name: str, logger: logging.Logger) -> str: - """ - Retrieve a secret from the environment or generate a new one if missing. + """Retrieve a secret from the environment or generate a new one if missing. Parameters ---------- @@ -185,13 +179,12 @@ def _get_env_secret(var_name: str, logger: logging.Logger) -> str: if secret is None: logger.warning(f"No secret key was provided for {var_name}.") logger.info(f"Set one with the '{var_name}' environment variable or let the server generate a random one.") - secret = secrets.token_urlsafe(32) + secret = _init_secret_key(secret) return secret def _configure_app(secret_key: str, ssl: bool = False) -> Flask: - """ - Create and configure a Flask application instance. + """Create and configure a Flask application instance. Parameters ---------- @@ -202,7 +195,7 @@ def _configure_app(secret_key: str, ssl: bool = False) -> Flask: Returns ------- - Flask + flask.Flask The configured Flask application. """ app = Flask(__name__) @@ -221,7 +214,7 @@ def _configure_api(app: Flask, proxy: bool, url_prefix: str, logger: logging.Log Parameters ---------- - app : Flask + app : flask.Flask The Flask application to extend. proxy : bool Whether the server is behind a reverse proxy. @@ -267,11 +260,14 @@ def _setup_test_database(empty: bool, models: bool, db_path: Optional[Union[str, The path to the created test database. """ jwt_key = _get_env_secret("JWT_SECRET_KEY", logger) + session_timeout = int(os.getenv("SESSION_TIMEOUT", 3600)) + max_sessions = int(os.getenv("MAX_SESSION", 10)) if empty: logger.info("Setting up a temporary test database without any datasets or configurations...") db_path = test_database( None, db_path=db_path, - session_manager=JWTSessionManager(secret_key=jwt_key) + session_manager=JWTSessionManager(secret_key=jwt_key, session_timeout=session_timeout, + max_concurrent_sessions=max_sessions) ).path() else: logger.info("Setting up a temporary test database with sample datasets and configurations...") @@ -280,7 +276,8 @@ def _setup_test_database(empty: bool, models: bool, db_path: Optional[Union[str, db_path=db_path, with_configs=True, with_models=models, - session_manager=JWTSessionManager(secret_key=jwt_key) + session_manager=JWTSessionManager(secret_key=jwt_key, session_timeout=session_timeout, + max_concurrent_sessions=max_sessions) ).path() return Path(db_path) @@ -511,9 +508,17 @@ def _cleanup() -> None: # 4 - Database connection jwt_key = _get_env_secret("JWT_SECRET_KEY", logger) + session_timeout = int(os.getenv("SESSION_TIMEOUT", 3600)) + refresh_timeout = int(os.getenv("REFRESH_TIMEOUT", 86400)) + max_sessions = int(os.getenv("MAX_SESSION", 10)) db = FSDB( db_path, - session_manager=JWTSessionManager(secret_key=jwt_key), + session_manager=JWTSessionManager( + secret_key=jwt_key, + session_timeout=session_timeout, + refresh_timeout=refresh_timeout, + max_concurrent_sessions=max_sessions, + ), ) logger.info(f"Connecting to local plant database at '{db.path()}'.") db.connect() diff --git a/src/server/plantdb/server/cli/wsgi.py b/src/server/plantdb/server/cli/wsgi.py index 0831d78a..0d20189c 100644 --- a/src/server/plantdb/server/cli/wsgi.py +++ b/src/server/plantdb/server/cli/wsgi.py @@ -18,8 +18,11 @@ - ``ROMI_DB``: Path to the directory containing the FSDB. Default: '/myapp/db' (container) - ``PLANTDB_API_PREFIX``: Prefix for the REST API URL. Default is empty. - ``PLANTDB_API_SSL``: Enable SSL to use an HTTPS scheme. Default is `False`. -- ``FLASK_SECRET_KEY``: The secret key to use with flask. Default to random (32 bits secret). -- ``JWT_SECRET_KEY``: The secret key to use with JSON Web Token generator. Default to random (32 bits secret). +- ``FLASK_SECRET_KEY``: The secret key to use with flask. Default to random (64 bits secret). +- ``JWT_SECRET_KEY``: The secret key to use with JSON Web Token generator. Default to random (64 bits secret). +- ``SESSION_TIMEOUT``: Session JWT validity duration in seconds. Default `900` seconds (15 min). +- ``REFRESH_TIMEOUT``: Refresh JWT validity duration in seconds. Default `86400` seconds (1 day). +- ``MAX_SESSION``: The maximum number of concurrent sessions to allow. Default `10`. Usage Examples -------------- diff --git a/src/server/plantdb/server/rest_api.py b/src/server/plantdb/server/rest_api.py index dfd185ee..2c7dba79 100644 --- a/src/server/plantdb/server/rest_api.py +++ b/src/server/plantdb/server/rest_api.py @@ -24,8 +24,24 @@ # ------------------------------------------------------------------------------ """ -This module regroup the classes and methods used to serve a REST API using ``fsdb_rest_api`` CLI. +REST API for PlantDB + +This module implements a collection of Flask Resource classes that expose endpoints for handling scans, +datasets, authentication, and related resources. +It centralizes request handling, JWT validation, rate‑limiting, and file‑URI resolution, providing a +robust backend for 3‑D scan data services. + +Key Features +------------ +- **Authentication** – JWT‑based login, logout, token validation and refresh. +- **Health monitoring** – Simple health‑check endpoint exposing database status. +- **Scan lifecycle** – Create, retrieve, update, and list scans and their associated filesets. +- **File handling** – Endpoints for uploading, retrieving, and managing individual files, datasets, + point clouds, meshes, and related assets. +- **Metadata services** – Accessors for scan metadata, fileset metadata, and file‑specific metadata. +- **Utility helpers** – Functions for URI generation, name sanitization, rate limiting, and archive validation. """ + import datetime import hashlib import json @@ -39,7 +55,7 @@ from math import radians from pathlib import Path from tempfile import mkstemp -from typing import Optional +from urllib import parse from zipfile import ZipFile import pybase64 @@ -52,6 +68,7 @@ from flask import send_from_directory from flask_restful import Resource +from plantdb.commons.auth.session import SessionValidationError from plantdb.commons.fsdb.exceptions import FileNotFoundError from plantdb.commons.fsdb.exceptions import FilesetNotFoundError from plantdb.commons.fsdb.exceptions import ScanNotFoundError @@ -408,7 +425,7 @@ def get_file_uri(scan, fileset, file): return f"/files/{scan_id}/{fileset_id}/{file_id}" -def get_image_uri(scan, fileset, file, size="orig"): +def get_image_uri(scan, fileset, file, size="orig", as_base64=False): """Return the URI for the corresponding `scan/fileset/file` tree. Parameters @@ -426,6 +443,8 @@ def get_image_uri(scan, fileset, file, size="orig"): - `'thumb'`: image max width and height to `150`. - `'large'`: image max width and height to `1500`; - `'orig'`: original image, no chache; + as_base64 : bool, optional + A boolean flag indicating whether to return an image as a base64 string. Returns ------- @@ -442,8 +461,8 @@ def get_image_uri(scan, fileset, file, size="orig"): >>> scan = db.get_scan('real_plant_analyzed') >>> get_image_uri(scan, 'images', '00000_rgb.jpg', size='orig') '/image/real_plant_analyzed/images/00000_rgb.jpg?size=orig' - >>> get_image_uri(scan, 'images', '00011_rgb.jpg', size='thumb') - '/image/real_plant_analyzed/images/00011_rgb.jpg?size=thumb' + >>> get_image_uri(scan, 'images', '00011_rgb.jpg', size='thumb', as_base64=True) + '/image/real_plant_analyzed/images/00011_rgb.jpg?size=thumb&as_base64=true' """ from plantdb.commons.fsdb.core import Scan from plantdb.commons.fsdb.core import Fileset @@ -451,7 +470,17 @@ def get_image_uri(scan, fileset, file, size="orig"): scan_id = scan.id if isinstance(scan, Scan) else scan fileset_id = fileset.id if isinstance(fileset, Fileset) else fileset file_id = file.path().name if isinstance(file, File) else file - return f"/image/{scan_id}/{fileset_id}/{file_id}?size={size}" + + # Assemble optional query parameters + query: dict[str, str] = {} + if size is not None: + query["size"] = str(size) + if as_base64: + # Use lower‑case JSON‑style booleans for consistency + query["as_base64"] = str(as_base64).lower() + + query_str = f"?{parse.urlencode(query)}" if query else "" + return f"/image/{scan_id}/{fileset_id}/{file_id}{query_str}" task_filesUri_mapping = { @@ -700,7 +729,7 @@ def wrapped(*args, **kwargs): def jwt_from_header(request) -> str: - """Extracts the JWT token from the Authorization header of an HTTP request. + """Extracts the JSON Web Token from the Authorization header of an HTTP request. Parameters ---------- @@ -710,7 +739,7 @@ def jwt_from_header(request) -> str: Returns ------- str - The JWT token extracted from the ``Authorization`` header, or an + The JSON Web Token extracted from the ``Authorization`` header, or an empty string if the header is missing or empty. Notes @@ -926,6 +955,10 @@ def post(self, **kwargs): 'success': True, 'message': 'User successfully created' }, 201 + + except SessionValidationError as e: + return {'message': 'Invalid credentials'}, 401 + except Exception as e: # Return error response if user creation fails (e.g., duplicate username) return { @@ -1063,19 +1096,21 @@ def post(self): password = data['password'] # Attempt to authenticate user with provided credentials - jwt_token = self.db.login(username, password) + tokens = self.db.login(username, password) # Prepare response based on authentication result - if jwt_token: - user = self.db.get_user_data(token=jwt_token) - # Create response with user info & access token + if tokens: + access_token, refresh_token = tokens + user = self.db.get_user_data(token=access_token) + # Create response with user info, access token, and refresh token response_data = { 'message': 'Login successful', 'user': { 'username': user.username, 'fullname': user.fullname, }, - 'access_token': jwt_token + 'access_token': access_token, + 'refresh_token': refresh_token } response = make_response(jsonify(response_data), 200) return response @@ -1117,17 +1152,23 @@ def post(self, **kwargs): try: if 'token' in kwargs: # Invalidate session - self.db.logout(**kwargs) - response = {'message': 'Logout successful'}, 200 + success, username = self.db.logout(**kwargs) + if success: + response = {'message': f'Logout successful from {username}'}, 200 + else: + response = {'message': 'Logout failed!'}, 401 else: self.logger.error(f"Logout error: no active session!") - response = {'message': 'Logout failed'}, 401 + response = {'message': 'Logout failed, no active session!'}, 401 return response + except SessionValidationError as e: + return {'message': 'Invalid credentials!'}, 401 + except Exception as e: self.logger.error(f"Logout error: {str(e)}") - return {'message': 'Logout failed'}, 500 + return {'message': 'Logout failed!'}, 500 class TokenValidation(Resource): @@ -1216,39 +1257,31 @@ def __init__(self, db): """Initialize the TokenRefresh resource.""" self.db = db - @add_jwt_from_header def post(self, **kwargs): """Refresh JSON Web Token. - Examples - -------- - >>> # Start a test REST API server first: - >>> # $ fsdb_rest_api --test - >>> import requests - >>> # Start by login as admin - >>> response = requests.post('http://127.0.0.1:5000/login', json={'username': 'admin', 'password': 'admin'}) - >>> token = response.json()['access_token'] - >>> # Now refresht the token for the admin user: - >>> response = requests.post("http://127.0.0.1:5000/token-refresh", headers={'Authorization': 'Bearer ' + token}) - >>> print(response.json()['message']) - Token refreshed successfully - >>> new_token = response.json()['access_token'] - >>> # Validate this new token: - >>> response = requests.post("http://127.0.0.1:5000/token-validation", headers={'Authorization': 'Bearer ' + new_token}) - >>> print(response.json()['user']['username']) - admin + This method expects a JSON payload containing a 'refresh_token'. + It validates the refresh token and issues a new access/refresh token pair. """ - # Get token from keyword arguments (from decorator) - jwt_token = kwargs.get('token', None) + data = request.get_json() + if not data or 'refresh_token' not in data: + return {'message': 'Missing refresh_token'}, 400 - try: - new_token = self.db.session_manager.refresh_session(jwt_token) + refresh_token = data['refresh_token'] - if new_token: - response = {'message': 'Token refreshed successfully', 'access_token': new_token}, 200 + try: + tokens = self.db.session_manager.refresh_session(refresh_token) + + if tokens: + access_token, new_refresh_token = tokens + response = { + 'message': 'Token refreshed successfully', + 'access_token': access_token, + 'refresh_token': new_refresh_token + }, 200 return response else: - return {'message': 'Token refresh failed'}, 401 + return {'message': 'Invalid or expired refresh token'}, 401 except Exception as e: return {'message': f'Token refresh failed: {e}'}, 500 @@ -1417,21 +1450,24 @@ def get(self, **kwargs): ----- The method can take direct parameters in the request body with the following fields: - filter_query: JSON string representing the filter query, example: ``{"object":{"species":"Arabidopsis.*"}}``. - - fuzzy: Boolean indicating whether to perform fuzzy filtering, ``false`` by default. + - fuzzy: Boolean indicating whether to perform fuzzy filtering, ``False`` by default. Examples -------- >>> # Start a test REST API server first: >>> # $ fsdb_rest_api --test >>> import requests - >>> # Get an info dict about all dataset: + >>> # Get a list of information dictionaries about all datasets: >>> response = requests.get("http://127.0.0.1:5000/scans_info") - >>> scans_list = response.json() + >>> scans_info = response.json() >>> # List the known dataset id: - >>> print(scans_list) - ['arabidopsis000', 'virtual_plant_analyzed', 'real_plant_analyzed', 'real_plant', 'virtual_plant', 'models'] + >>> print(sorted(scan['id'] for scan in scans_info)) + ['arabidopsis000', 'real_plant', 'real_plant_analyzed', 'virtual_plant', 'virtual_plant_analyzed'] + >>> # Add a metadata filter to the query: >>> response = requests.get('http://127.0.0.1:5000/scans_info?filterQuery={"object":{"species":"Arabidopsis.*"}}&fuzzy="true"') - >>> response.content.decode() + >>> scans_info = response.json() + >>> print(sorted(scan['id'] for scan in scans_info)) + ['virtual_plant', 'virtual_plant_analyzed'] """ query = request.args.get('filterQuery', None) fuzzy = request.args.get('fuzzy', False, type=bool) @@ -1809,7 +1845,6 @@ def post(self, scan_id): def write_stream(file_path, content_length, chunk_size): bytes_received = 0 with open(file_path, 'wb') as file: - print(f"Received: {bytes_received}") while bytes_received < content_length: chunk = request.stream.read(min(chunk_size, content_length - bytes_received)) if not chunk: @@ -1995,6 +2030,14 @@ def __init__(self, db): """ self.db = db + @staticmethod + def wants_base64(request) -> bool: + """ + Return ``True`` when the query string contains ``as_base64`` with a truthy value. + """ + flag = request.args.get('as_base64', default='false', type=str).lower() + return flag in ('true', '1', 'yes') + @rate_limit(max_requests=3000, window_seconds=60) def get(self, scan_id, fileset_id, file_id): """Retrieve and serve an image from the database. @@ -2021,7 +2064,7 @@ def get(self, scan_id, fileset_id, file_id): * `'large'`: image max width and height to `1500`; * `'orig'`: original image, no chache; If an invalid string is supplied, the default 'thumb' is used. - base64 : str + as_base64 : str Query parameter indicating whether to return the image encoded in base64. Accepts 'true', '1', 'yes' (case‑insensitive) to enable. Defaults to 'false', which streams the image file. @@ -2045,7 +2088,7 @@ def get(self, scan_id, fileset_id, file_id): See Also -------- - plantdb.server.rest_api.sanitize_name : Input sanitization & validation function. + plantdb.server.rest_api.sanitize_name : Input sanitization and validation function. plantdb.server.webcache.image_path : Image path resolution function with caching and resizing options. Examples @@ -2053,18 +2096,31 @@ def get(self, scan_id, fileset_id, file_id): >>> # In a terminal, start a (test) REST API with `fsdb_rest_api --test`, then: >>> import numpy as np >>> import requests + >>> import pybase64 >>> from io import BytesIO >>> from PIL import Image - >>> # Get the first image as a thumbnail (default): + >>> # Example #1 - Get the first image as a thumbnail (default): >>> response = requests.get("http://127.0.0.1:5000/image/real_plant_analyzed/images/00000_rgb", stream=True) >>> img = Image.open(BytesIO(response.content)) + >>> image.show() >>> np.asarray(img).shape (113, 150, 3) - >>> # Get the first image in original size: + >>> # Example #2 - Get the first image in original size: >>> response = requests.get("http://127.0.0.1:5000/image/real_plant_analyzed/images/00000_rgb", stream=True, params={"size": "orig"}) >>> img = Image.open(BytesIO(response.content)) + >>> image.show() >>> np.asarray(img).shape (1080, 1440, 3) + >>> # Example #3 - Get a base64 encoded image: + >>> response = requests.get("http://127.0.0.1:5000/image/real_plant_analyzed/images/00000_rgb", stream=True, params={"size": "orig", "as_base64": 'true'}) + >>> print(response.json()['content-type']) + 'image/jpeg' + >>> b64_string = response.json()['image'] + >>> print(b64_string[:30]) # print the first 30 characters + '/9j/4AAQSkZJRgABAQAAAQABAAD/2w' + >>> image_data = pybase64.b64decode(b64_string) + >>> image = Image.open(BytesIO(image_data)) + >>> image.show() """ # Sanitize identifiers scan_id = sanitize_name(scan_id) @@ -2073,21 +2129,29 @@ def get(self, scan_id, fileset_id, file_id): # Parse the `size` flag size = request.args.get('size', default='thumb', type=str) - # Parse the base64 flag (accepting true/1/yes in any case) - base64_flag = request.args.get('base64', default='false', type=str).lower() in ('true', '1', 'yes') - # Get the path to the image resource: path = webcache.image_path(self.db, scan_id, fileset_id, file_id, size) mime_type, _ = mimetypes.guess_type(path) - # If base64_flag is set, read the file, encode it, and return JSON - if base64_flag: + if self.wants_base64(request): + # ---------- JSON (base64) ---------- with open(path, 'rb') as f: - encoded = pybase64.b64encode(f.read()).decode('ascii') - return jsonify({'image': encoded, 'content-type': mime_type}) - # Otherwise, return the file directly - return send_file(path, mimetype=mime_type) + b64_str = pybase64.b64encode(f.read()).decode('ascii') + # ``decode('ascii')`` gives us a plain string that can be JSON‑encoded. + payload = { + "image": b64_str, + "content-type": mime_type + } + # Wrap ``jsonify`` with ``make_response`` to add custom headers + resp = make_response(jsonify(payload)) + resp.headers["Content-Type"] = "application/json" + resp.headers["X-Content-Encoding"] = "base64" + else: + # ---------- Binary (streaming) ---------- + resp = make_response(send_file(path, mimetype=mime_type)) + resp.headers["X-Content-Encoding"] = "binary" + return resp class PointCloud(Resource): """RESTful resource for serving and optionally downsampling point cloud data. @@ -3107,7 +3171,10 @@ def post(self, scan_id, **kwargs): # Create the new scan dataset that will receive the files from the archive self.logger.debug(f"REST API path to fsdb is '{self.db.path()}'...") - scan_path = Path(self.db.create_scan(scan_id, **kwargs).path()) + try: + scan_path = Path(self.db.create_scan(scan_id, **kwargs).path()) + except PermissionError as e: + return {'message': f'Invalid credentials: {str(e)}'}, 401 self.logger.debug(f"Exporting archive contents to '{scan_path}'...") # Detect a lone top level dir to remove from later file extraction @@ -3271,6 +3338,9 @@ def post(self, **kwargs): scan.set_metadata(metadata, **kwargs) return {'message': f"Scan '{scan_id}' created successfully."}, 201 + except SessionValidationError as e: + return {'message': f'Invalid credentials: {str(e)}'}, 401 + except Exception as e: return {'message': f'Error creating scan: {str(e)}'}, 500 @@ -3433,6 +3503,9 @@ def post(self, scan_id, **kwargs): updated_metadata = scan.get_metadata() return {'metadata': updated_metadata}, 200 + except SessionValidationError as e: + return {'message': 'Invalid credentials'}, 401 + except Exception as e: self.logger.error(f'Error updating metadata: {str(e)}') return {'message': f'Error updating metadata: {str(e)}'}, 500 @@ -3606,6 +3679,9 @@ def post(self, **kwargs): "id": fs_id }, 201 + except SessionValidationError as e: + return {'message': 'Invalid credentials'}, 401 + except Exception as e: return {'message': f'Error creating fileset: {str(e)}'}, 500 @@ -3801,6 +3877,9 @@ def post(self, scan_id, fileset_id, **kwargs): updated_metadata = fileset.get_metadata() return {'metadata': updated_metadata}, 200 + except SessionValidationError as e: + return {'message': 'Invalid credentials'}, 401 + except Exception as e: self.logger.error(f'Error updating metadata: {str(e)}') return {'message': f'Error updating metadata: {str(e)}'}, 500 @@ -4013,6 +4092,9 @@ def post(self, **kwargs): 'id': f"{file_id}", }, 201 + except SessionValidationError as e: + return {'message': 'Invalid credentials'}, 401 + except Exception as e: self.logger.error(f"Error creating file: {str(e)}") return {'message': f'Error creating file: {str(e)}'}, 500 @@ -4060,7 +4142,7 @@ def get(self, scan_id, fileset_id, file_id): Notes ----- - In the URL, uou can use the `key` parameter to retrieve specific metadata keys. + In the URL, you can use the `key` parameter to retrieve specific metadata keys. Examples -------- @@ -4073,7 +4155,7 @@ def get(self, scan_id, fileset_id, file_id): >>> response = requests.get(url) >>> print(response.json()) {'metadata': {'description': 'Test file'}} - >>> # Get specific metadata key: + >>> # Get a specific metadata key: >>> response = requests.get(url+"?key=description") >>> print(response.json()) {'metadata': 'Test file'} @@ -4185,6 +4267,9 @@ def post(self, scan_id, fileset_id, file_id, **kwargs): updated_metadata = file.get_metadata() return {'metadata': updated_metadata}, 200 + except SessionValidationError as e: + return {'message': 'Invalid credentials'}, 401 + except Exception as e: self.logger.error(f'Error processing request: {str(e)}') return {'message': f'Error processing request: {str(e)}'}, 500 diff --git a/src/server/pyproject.toml b/src/server/pyproject.toml index 401edeb5..6785c28e 100644 --- a/src/server/pyproject.toml +++ b/src/server/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ ] description = "Server-side component of the ROMI plant database system." readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.8, <3.13" # 3.13 upper bound for open3d dependency license = { 'text' = "LGPL-3.0-or-later" } # to replace by "LGPL-3.0-or-later" only in next release, pending deprecation authors = [ { name = "Peter Hanappe", email = "peter.hanappe@sony.com" }, @@ -75,4 +75,5 @@ test = [ "coverage[toml]", "networkx", "tifffile", + "open3d>=0.9.0.0" ] \ No newline at end of file diff --git a/src/server/tests/test_rest_api_server.py b/src/server/tests/test_rest_api_server.py index a62a470c..e1de9c38 100644 --- a/src/server/tests/test_rest_api_server.py +++ b/src/server/tests/test_rest_api_server.py @@ -44,7 +44,7 @@ def test_login_logout_endpoints(self): r = requests.post(self.server.get_base_url() + '/login', json={'username': 'anonymous', 'password': 'AlanMoore'}) self.assertEqual(r.status_code, 401) - # First attempt without login should fail + # The first attempt without a login should fail r = requests.post(self.server.get_base_url() + '/logout') self.assertEqual(r.status_code, 401) @@ -85,7 +85,6 @@ def test_scan_get(self): r = requests.get(self.server.get_base_url() + f"/api/scan/{scan_id}/metadata") self.assertEqual(r.status_code, 200) info = r.json() - print(info) self.assertIn("metadata", info) self.assertIn("owner", info['metadata']) @@ -173,7 +172,7 @@ def test_archive_endpoint(self): r = requests.post(self.server.get_base_url() + f'/archive/{new_dataset}', files=files) self.assertEqual(r.status_code, 401) - # Ensure file pointer is at the beginning before second request + # Ensure the file pointer is at the beginning before the second request zip_f.seek(0) # Test POST with proper authentication