Skip to content
37 changes: 37 additions & 0 deletions src/jabs/pose_estimation/pose_est.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import enum
import logging
from abc import ABC, abstractmethod
from pathlib import Path

Expand Down Expand Up @@ -112,6 +113,8 @@ class KeypointIndex(enum.IntEnum):
),
)

_CACHE_FILE_VERSION = 1

def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30):
"""initialize new object from h5 file

Expand All @@ -137,6 +140,16 @@ def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30

self._static_objects = {}

# check cache version, if it doesn't match, clear the cache file for this pose file
if self._cache_dir is not None and not self.check_cache_version():
cache_file = self._cache_file_path()
if cache_file and cache_file.exists():
try:
cache_file.unlink()
except Exception:
logging.warning("Unable to delete old cache file %s", cache_file)
pass

@property
def num_frames(self) -> int:
"""return the number of frames in the pose_est file"""
Expand Down Expand Up @@ -377,3 +390,27 @@ def identity_index_to_display(self, identity_index: int) -> str:
if self.external_identities and 0 <= identity_index < len(self.external_identities):
return self.external_identities[identity_index]
return str(identity_index)

def check_cache_version(self) -> bool:
"""Check if the cache version matches the expected version.

Returns:
bool: True if the cache version matches, False otherwise.
"""
try:
with h5py.File(self._cache_file_path(), "r") as cache_h5:
cache_version = cache_h5.attrs.get("cache_file_version", None)
return cache_version == self._CACHE_FILE_VERSION
except Exception:
return False

def _cache_file_path(self) -> Path | None:
"""Get the path to the cache file for this pose file.

Returns:
Path | None: The path to the cache file, or None if no cache directory is set.
"""
if self._cache_dir is None:
return None
filename = self._path.name.replace(".h5", "_cache.h5")
return self._cache_dir / filename
16 changes: 6 additions & 10 deletions src/jabs/pose_estimation/pose_est_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ class PoseEstimationV3(PoseEstimation):
get_identity_point_mask(identity): Get the point mask array for a given identity.
"""

__CACHE_FILE_VERSION = 2
# super class handles validating cache file version and will delete
# if it doesn't match expected version so it will get regenerated
# bumping this version to force v3 pose file cache regeneration only
_CACHE_FILE_VERSION = 2

def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30):
super().__init__(file_path, cache_dir, fps)
Expand All @@ -47,18 +50,11 @@ def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30
# to speedup reopening the pose file later, we'll cache the transformed
# pose file in the project dir
if cache_dir is not None:
filename = self._path.name.replace(".h5", "_cache.h5")
cache_file_path = self._cache_dir / filename
cache_file_path = self._cache_file_path()
use_cache = True

try:
with h5py.File(cache_file_path, "r") as cache_h5:
if cache_h5.attrs["version"] != self.__CACHE_FILE_VERSION:
# cache file version is not what we expect, raise
# exception so we will revert to reading source pose
# file
raise _CacheFileVersion

if cache_h5.attrs["source_pose_hash"] != self._hash:
raise PoseHashException

Expand Down Expand Up @@ -149,7 +145,7 @@ def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30

if self._cache_dir is not None:
with h5py.File(cache_file_path, "w") as cache_h5:
cache_h5.attrs["version"] = self.__CACHE_FILE_VERSION
cache_h5.attrs["cache_file_version"] = self._CACHE_FILE_VERSION
cache_h5.attrs["source_pose_hash"] = self.hash
group = cache_h5.create_group("poseest")
if self._cm_per_pixel is not None:
Expand Down
16 changes: 5 additions & 11 deletions src/jabs/pose_estimation/pose_est_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class PoseEstimationV4(PoseEstimation):
get_identity_point_mask(identity): Get the point mask array for a given identity.
"""

__CACHE_FILE_VERSION = 4
# bump to force regeneration of pose cache files for v4 or any subclass
_CACHE_FILE_VERSION = 4

def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30):
super().__init__(file_path, cache_dir, fps)
Expand All @@ -51,7 +52,7 @@ def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30
try:
self._load_from_cache()
use_cache = True
except (OSError, KeyError, _CacheFileVersion, PoseHashException):
except (OSError, KeyError, PoseHashException):
# if load_from_cache() raises an exception, we'll read from
# the source pose file below because use_cache will still be
# set to false, just ignore the exceptions here
Expand Down Expand Up @@ -250,16 +251,9 @@ def _load_from_cache(self):
Returns:
None
"""
filename = self._path.name.replace(".h5", "_cache.h5")
cache_file_path = self._cache_dir / filename
cache_file_path = self._cache_file_path()

with h5py.File(cache_file_path, "r") as cache_h5:
if cache_h5.attrs["version"] != self.__CACHE_FILE_VERSION:
# cache file version is not what we expect, raise
# exception so we will revert to reading source pose
# file
raise _CacheFileVersion

if cache_h5.attrs["source_pose_hash"] != self._hash:
raise PoseHashException

Expand Down Expand Up @@ -297,7 +291,7 @@ def _cache_poses(self):
cache_file_path = self._cache_dir / filename

with h5py.File(cache_file_path, "w") as cache_h5:
cache_h5.attrs["version"] = self.__CACHE_FILE_VERSION
cache_h5.attrs["cache_file_version"] = self._CACHE_FILE_VERSION
cache_h5.attrs["source_pose_hash"] = self.hash
cache_h5.attrs["num_identities"] = self._num_identities
cache_h5.attrs["num_frames"] = self._num_frames
Expand Down
2 changes: 1 addition & 1 deletion src/jabs/pose_estimation/pose_est_v6.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30
"seg_data": None,
}

# open the hdf5 pose file and extract segmentation data.
# open the hdf5 pose file and extract segmentation data, this is not cached
with h5py.File(self._path, "r") as pose_h5:
for seg_key in set(pose_h5["poseest"].keys()) & set(self._segmentation_dict.keys()):
self._segmentation_dict[seg_key] = pose_h5[f"poseest/{seg_key}"][:]
Expand Down
4 changes: 4 additions & 0 deletions src/jabs/pose_estimation/pose_est_v8.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class PoseEstimationV8(PoseEstimationV7):
Adds bounding box support.
"""

# force a bump in cache file version if either parent class or this class changes
_CACHE_FILE_VERSION = PoseEstimationV7._CACHE_FILE_VERSION + 1

def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30) -> None:
super().__init__(file_path, cache_dir, fps)
self._has_bounding_boxes = False
Expand Down Expand Up @@ -112,6 +115,7 @@ def _load_from_h5(self, cache_dir: Path | None) -> None:
filename = self._path.name.replace(".h5", "_cache.h5")
cache_file_path = self._cache_dir / filename
with h5py.File(cache_file_path, "a") as cache_h5:
cache_h5.attrs["cache_file_version"] = self._CACHE_FILE_VERSION
grp = cache_h5.require_group("poseest")
if "bboxes" in grp:
del grp["bboxes"]
Expand Down
12 changes: 12 additions & 0 deletions src/jabs/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,18 @@ def get_labeled_features(
"groups": np.concatenate(all_groups),
}, group_mapping

def clear_cache(self):
"""clear the cache directory for this project"""
if self._paths.cache_dir is not None:
for f in self._paths.cache_dir.glob("*"):
try:
if f.is_dir():
shutil.rmtree(f)
else:
f.unlink()
except OSError:
pass

def __has_pose(self, vid: str):
"""check to see if a video has a corresponding pose file"""
path = self._paths.project_dir / vid
Expand Down
5 changes: 5 additions & 0 deletions src/jabs/ui/central_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ def update_behavior_search_query(self, search_query) -> None:
"""Update the search query for the search bar widget"""
self._search_bar_widget.update_search(search_query)

@property
def loaded_video(self) -> Path | None:
"""get the currently loaded video path"""
return self._loaded_video

@property
def overlay_annotations_enabled(self) -> bool:
"""get the annotation overlay enabled status from player widget."""
Expand Down
41 changes: 41 additions & 0 deletions src/jabs/ui/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ def __init__(self, app_name: str, app_name_long: str, *args, **kwargs) -> None:
)
app_menu.addAction(session_tracking_action)

# clear cache action
self._clear_cache = QtGui.QAction("Clear Project Cache", self)
self._clear_cache.setStatusTip("Clear Project Cache")
self._clear_cache.setEnabled(False)
self._clear_cache.triggered.connect(self._clear_cache_action)
app_menu.addAction(self._clear_cache)

# exit action
exit_action = QtGui.QAction(f" &Quit {self._app_name}", self)
exit_action.setShortcut(QtGui.QKeySequence("Ctrl+Q"))
Expand Down Expand Up @@ -149,6 +156,7 @@ def __init__(self, app_name: str, app_name_long: str, *args, **kwargs) -> None:
file_menu.addAction(self._prune_action)

# Setup View Menu

# video playlist menu item
self.view_playlist = QtGui.QAction("View Playlist", self)
self.view_playlist.setCheckable(True)
Expand Down Expand Up @@ -710,6 +718,7 @@ def _project_loaded_callback(self) -> None:
self.enable_segmentation_features.setEnabled(
self._project.feature_manager.can_use_segmentation_features
)
self._clear_cache.setEnabled(True)
available_objects = self._project.feature_manager.static_objects
for static_object, menu_item in self.enable_landmark_features.items():
if static_object in available_objects:
Expand Down Expand Up @@ -751,6 +760,38 @@ def show_license_dialog(self) -> QtWidgets.QDialog.DialogCode:

return QtWidgets.QDialog.DialogCode(result)

# show dialog

def _clear_cache_action(self):
"""Clear the cache for the current project. Opens a dialog to get user confirmation first."""
app = QtWidgets.QApplication.instance()
dont_use_native_dialogs = QtWidgets.QApplication.instance().testAttribute(
Qt.ApplicationAttribute.AA_DontUseNativeDialogs
)

# if app is currently set to use native dialogs, we will temporarily set it to use Qt dialogs
# the native style, at least on macOS, is not ideal so we'll force the Qt dialog instead
if not dont_use_native_dialogs:
app.setAttribute(Qt.ApplicationAttribute.AA_DontUseNativeDialogs, True)

result = QtWidgets.QMessageBox.warning(
self,
"Clear Cache",
"Are you sure you want to clear the project cache?",
QtWidgets.QMessageBox.StandardButton.Yes | QtWidgets.QMessageBox.StandardButton.No,
QtWidgets.QMessageBox.StandardButton.No,
)

# restore the original setting
app.setAttribute(Qt.ApplicationAttribute.AA_DontUseNativeDialogs, dont_use_native_dialogs)

if result == QtWidgets.QMessageBox.StandardButton.Yes:
self._project.clear_cache()
# need to reload the current video to force the pose file to reload
if self._central_widget.loaded_video:
self._central_widget.load_video(self._central_widget.loaded_video)
self.display_status_message("Cache cleared", 3000)

def _update_recent_projects(self) -> None:
"""update the contents of the Recent Projects menu"""
self._open_recent_menu.clear()
Expand Down