diff --git a/src/jabs/pose_estimation/pose_est.py b/src/jabs/pose_estimation/pose_est.py index 47ca9d89..e3560155 100644 --- a/src/jabs/pose_estimation/pose_est.py +++ b/src/jabs/pose_estimation/pose_est.py @@ -1,4 +1,5 @@ import enum +import logging from abc import ABC, abstractmethod from pathlib import Path @@ -112,6 +113,8 @@ class KeypointIndex(enum.IntEnum): ), ) + _CACHE_FILE_VERSION = 1 + def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30): """initialize new object from h5 file @@ -137,6 +140,16 @@ def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30 self._static_objects = {} + # check cache version, if it doesn't match, clear the cache file for this pose file + if self._cache_dir is not None and not self.check_cache_version(): + cache_file = self._cache_file_path() + if cache_file and cache_file.exists(): + try: + cache_file.unlink() + except Exception: + logging.warning("Unable to delete old cache file %s", cache_file) + pass + @property def num_frames(self) -> int: """return the number of frames in the pose_est file""" @@ -377,3 +390,27 @@ def identity_index_to_display(self, identity_index: int) -> str: if self.external_identities and 0 <= identity_index < len(self.external_identities): return self.external_identities[identity_index] return str(identity_index) + + def check_cache_version(self) -> bool: + """Check if the cache version matches the expected version. + + Returns: + bool: True if the cache version matches, False otherwise. + """ + try: + with h5py.File(self._cache_file_path(), "r") as cache_h5: + cache_version = cache_h5.attrs.get("cache_file_version", None) + return cache_version == self._CACHE_FILE_VERSION + except Exception: + return False + + def _cache_file_path(self) -> Path | None: + """Get the path to the cache file for this pose file. + + Returns: + Path | None: The path to the cache file, or None if no cache directory is set. + """ + if self._cache_dir is None: + return None + filename = self._path.name.replace(".h5", "_cache.h5") + return self._cache_dir / filename diff --git a/src/jabs/pose_estimation/pose_est_v3.py b/src/jabs/pose_estimation/pose_est_v3.py index 1155c2da..2fecd8cd 100644 --- a/src/jabs/pose_estimation/pose_est_v3.py +++ b/src/jabs/pose_estimation/pose_est_v3.py @@ -34,7 +34,10 @@ class PoseEstimationV3(PoseEstimation): get_identity_point_mask(identity): Get the point mask array for a given identity. """ - __CACHE_FILE_VERSION = 2 + # super class handles validating cache file version and will delete + # if it doesn't match expected version so it will get regenerated + # bumping this version to force v3 pose file cache regeneration only + _CACHE_FILE_VERSION = 2 def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30): super().__init__(file_path, cache_dir, fps) @@ -47,18 +50,11 @@ def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30 # to speedup reopening the pose file later, we'll cache the transformed # pose file in the project dir if cache_dir is not None: - filename = self._path.name.replace(".h5", "_cache.h5") - cache_file_path = self._cache_dir / filename + cache_file_path = self._cache_file_path() use_cache = True try: with h5py.File(cache_file_path, "r") as cache_h5: - if cache_h5.attrs["version"] != self.__CACHE_FILE_VERSION: - # cache file version is not what we expect, raise - # exception so we will revert to reading source pose - # file - raise _CacheFileVersion - if cache_h5.attrs["source_pose_hash"] != self._hash: raise PoseHashException @@ -149,7 +145,7 @@ def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30 if self._cache_dir is not None: with h5py.File(cache_file_path, "w") as cache_h5: - cache_h5.attrs["version"] = self.__CACHE_FILE_VERSION + cache_h5.attrs["cache_file_version"] = self._CACHE_FILE_VERSION cache_h5.attrs["source_pose_hash"] = self.hash group = cache_h5.create_group("poseest") if self._cm_per_pixel is not None: diff --git a/src/jabs/pose_estimation/pose_est_v4.py b/src/jabs/pose_estimation/pose_est_v4.py index e5eca252..9a3e91ec 100644 --- a/src/jabs/pose_estimation/pose_est_v4.py +++ b/src/jabs/pose_estimation/pose_est_v4.py @@ -37,7 +37,8 @@ class PoseEstimationV4(PoseEstimation): get_identity_point_mask(identity): Get the point mask array for a given identity. """ - __CACHE_FILE_VERSION = 4 + # bump to force regeneration of pose cache files for v4 or any subclass + _CACHE_FILE_VERSION = 4 def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30): super().__init__(file_path, cache_dir, fps) @@ -51,7 +52,7 @@ def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30 try: self._load_from_cache() use_cache = True - except (OSError, KeyError, _CacheFileVersion, PoseHashException): + except (OSError, KeyError, PoseHashException): # if load_from_cache() raises an exception, we'll read from # the source pose file below because use_cache will still be # set to false, just ignore the exceptions here @@ -250,16 +251,9 @@ def _load_from_cache(self): Returns: None """ - filename = self._path.name.replace(".h5", "_cache.h5") - cache_file_path = self._cache_dir / filename + cache_file_path = self._cache_file_path() with h5py.File(cache_file_path, "r") as cache_h5: - if cache_h5.attrs["version"] != self.__CACHE_FILE_VERSION: - # cache file version is not what we expect, raise - # exception so we will revert to reading source pose - # file - raise _CacheFileVersion - if cache_h5.attrs["source_pose_hash"] != self._hash: raise PoseHashException @@ -297,7 +291,7 @@ def _cache_poses(self): cache_file_path = self._cache_dir / filename with h5py.File(cache_file_path, "w") as cache_h5: - cache_h5.attrs["version"] = self.__CACHE_FILE_VERSION + cache_h5.attrs["cache_file_version"] = self._CACHE_FILE_VERSION cache_h5.attrs["source_pose_hash"] = self.hash cache_h5.attrs["num_identities"] = self._num_identities cache_h5.attrs["num_frames"] = self._num_frames diff --git a/src/jabs/pose_estimation/pose_est_v6.py b/src/jabs/pose_estimation/pose_est_v6.py index e1dbbc14..61f4e780 100644 --- a/src/jabs/pose_estimation/pose_est_v6.py +++ b/src/jabs/pose_estimation/pose_est_v6.py @@ -40,7 +40,7 @@ def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30 "seg_data": None, } - # open the hdf5 pose file and extract segmentation data. + # open the hdf5 pose file and extract segmentation data, this is not cached with h5py.File(self._path, "r") as pose_h5: for seg_key in set(pose_h5["poseest"].keys()) & set(self._segmentation_dict.keys()): self._segmentation_dict[seg_key] = pose_h5[f"poseest/{seg_key}"][:] diff --git a/src/jabs/pose_estimation/pose_est_v8.py b/src/jabs/pose_estimation/pose_est_v8.py index 5b8d138c..5dcdc4ee 100644 --- a/src/jabs/pose_estimation/pose_est_v8.py +++ b/src/jabs/pose_estimation/pose_est_v8.py @@ -14,6 +14,9 @@ class PoseEstimationV8(PoseEstimationV7): Adds bounding box support. """ + # force a bump in cache file version if either parent class or this class changes + _CACHE_FILE_VERSION = PoseEstimationV7._CACHE_FILE_VERSION + 1 + def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30) -> None: super().__init__(file_path, cache_dir, fps) self._has_bounding_boxes = False @@ -112,6 +115,7 @@ def _load_from_h5(self, cache_dir: Path | None) -> None: filename = self._path.name.replace(".h5", "_cache.h5") cache_file_path = self._cache_dir / filename with h5py.File(cache_file_path, "a") as cache_h5: + cache_h5.attrs["cache_file_version"] = self._CACHE_FILE_VERSION grp = cache_h5.require_group("poseest") if "bboxes" in grp: del grp["bboxes"] diff --git a/src/jabs/project/project.py b/src/jabs/project/project.py index 9300617e..dfaf164d 100644 --- a/src/jabs/project/project.py +++ b/src/jabs/project/project.py @@ -617,6 +617,18 @@ def get_labeled_features( "groups": np.concatenate(all_groups), }, group_mapping + def clear_cache(self): + """clear the cache directory for this project""" + if self._paths.cache_dir is not None: + for f in self._paths.cache_dir.glob("*"): + try: + if f.is_dir(): + shutil.rmtree(f) + else: + f.unlink() + except OSError: + pass + def __has_pose(self, vid: str): """check to see if a video has a corresponding pose file""" path = self._paths.project_dir / vid diff --git a/src/jabs/ui/central_widget.py b/src/jabs/ui/central_widget.py index 920bd66a..f1f0fc64 100644 --- a/src/jabs/ui/central_widget.py +++ b/src/jabs/ui/central_widget.py @@ -151,6 +151,11 @@ def update_behavior_search_query(self, search_query) -> None: """Update the search query for the search bar widget""" self._search_bar_widget.update_search(search_query) + @property + def loaded_video(self) -> Path | None: + """get the currently loaded video path""" + return self._loaded_video + @property def overlay_annotations_enabled(self) -> bool: """get the annotation overlay enabled status from player widget.""" diff --git a/src/jabs/ui/main_window.py b/src/jabs/ui/main_window.py index e013cfb1..c081a581 100644 --- a/src/jabs/ui/main_window.py +++ b/src/jabs/ui/main_window.py @@ -106,6 +106,13 @@ def __init__(self, app_name: str, app_name_long: str, *args, **kwargs) -> None: ) app_menu.addAction(session_tracking_action) + # clear cache action + self._clear_cache = QtGui.QAction("Clear Project Cache", self) + self._clear_cache.setStatusTip("Clear Project Cache") + self._clear_cache.setEnabled(False) + self._clear_cache.triggered.connect(self._clear_cache_action) + app_menu.addAction(self._clear_cache) + # exit action exit_action = QtGui.QAction(f" &Quit {self._app_name}", self) exit_action.setShortcut(QtGui.QKeySequence("Ctrl+Q")) @@ -149,6 +156,7 @@ def __init__(self, app_name: str, app_name_long: str, *args, **kwargs) -> None: file_menu.addAction(self._prune_action) # Setup View Menu + # video playlist menu item self.view_playlist = QtGui.QAction("View Playlist", self) self.view_playlist.setCheckable(True) @@ -710,6 +718,7 @@ def _project_loaded_callback(self) -> None: self.enable_segmentation_features.setEnabled( self._project.feature_manager.can_use_segmentation_features ) + self._clear_cache.setEnabled(True) available_objects = self._project.feature_manager.static_objects for static_object, menu_item in self.enable_landmark_features.items(): if static_object in available_objects: @@ -751,6 +760,38 @@ def show_license_dialog(self) -> QtWidgets.QDialog.DialogCode: return QtWidgets.QDialog.DialogCode(result) + # show dialog + + def _clear_cache_action(self): + """Clear the cache for the current project. Opens a dialog to get user confirmation first.""" + app = QtWidgets.QApplication.instance() + dont_use_native_dialogs = QtWidgets.QApplication.instance().testAttribute( + Qt.ApplicationAttribute.AA_DontUseNativeDialogs + ) + + # if app is currently set to use native dialogs, we will temporarily set it to use Qt dialogs + # the native style, at least on macOS, is not ideal so we'll force the Qt dialog instead + if not dont_use_native_dialogs: + app.setAttribute(Qt.ApplicationAttribute.AA_DontUseNativeDialogs, True) + + result = QtWidgets.QMessageBox.warning( + self, + "Clear Cache", + "Are you sure you want to clear the project cache?", + QtWidgets.QMessageBox.StandardButton.Yes | QtWidgets.QMessageBox.StandardButton.No, + QtWidgets.QMessageBox.StandardButton.No, + ) + + # restore the original setting + app.setAttribute(Qt.ApplicationAttribute.AA_DontUseNativeDialogs, dont_use_native_dialogs) + + if result == QtWidgets.QMessageBox.StandardButton.Yes: + self._project.clear_cache() + # need to reload the current video to force the pose file to reload + if self._central_widget.loaded_video: + self._central_widget.load_video(self._central_widget.loaded_video) + self.display_status_message("Cache cleared", 3000) + def _update_recent_projects(self) -> None: """update the contents of the Recent Projects menu""" self._open_recent_menu.clear()