Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 46 additions & 14 deletions src/jabs/classifier/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,31 +447,63 @@ def sort_features_to_classify(self, features):
features_sorted = features[classifier_columns]
return features_sorted

def predict(self, features):
"""predict classes for a given set of features"""
def predict(self, features: dict, frame_indexes: np.ndarray | None = None) -> np.ndarray:
Copy link

Copilot AI Dec 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type hint indicates features should be a dict, but the method body treats it as a pandas DataFrame (using .replace() and .fillna() methods). The type hint should be corrected to match the actual expected type, likely pd.DataFrame.

Copilot uses AI. Check for mistakes.
"""predict classes for a given set of features

Args:
features: dictionary of feature data to classify
Copy link

Copilot AI Dec 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring describes features as a 'dictionary' but the actual parameter should be a pandas DataFrame based on the method implementation. Update the docstring to accurately reflect the expected type.

Suggested change
features: dictionary of feature data to classify
features: pandas DataFrame of feature data to classify

Copilot uses AI. Check for mistakes.
frame_indexes: frame indexes to classify (default all)

Returns:
predicted class vector
"""
if self._classifier_type == ClassifierType.XGBOOST:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=FutureWarning)
result = self._classifier.predict(
self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0))
)
return result
# Random forests and gradient boost can't handle NAs & infs, so fill them with 0s
return self._classifier.predict(
self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
)
else:
# Random forests and gradient boost can't handle NAs & infs, so fill them with 0s
result = self._classifier.predict(
self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
)

# Insert -1s into class prediction when no prediction is made
if frame_indexes is not None:
result_adjusted = np.full(result.shape, -1, dtype=np.int8)
result_adjusted[frame_indexes] = result[frame_indexes]
result = result_adjusted

return result

def predict_proba(self, features):
"""predict probabilities for a given set of features"""
def predict_proba(self, features: dict, frame_indexes: np.ndarray | None = None) -> np.ndarray:
Copy link

Copilot AI Dec 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type hint indicates features should be a dict, but the method body treats it as a pandas DataFrame (using .replace() and .fillna() methods). The type hint should be corrected to match the actual expected type, likely pd.DataFrame.

Copilot uses AI. Check for mistakes.
"""predict probabilities for a given set of features.

Args:
features: dictionary of feature data to classify
Copy link

Copilot AI Dec 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring describes features as a 'dictionary' but the actual parameter should be a pandas DataFrame based on the method implementation. Update the docstring to accurately reflect the expected type.

Copilot uses AI. Check for mistakes.
frame_indexes: frame indexes to classify (default all)

Returns:
prediction probability matrix
"""
if self._classifier_type == ClassifierType.XGBOOST:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=FutureWarning)
result = self._classifier.predict_proba(self.sort_features_to_classify(features))
return result
# Random forests and gradient boost can't handle NAs & infs, so fill them with 0s
return self._classifier.predict_proba(
self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
)
else:
# Random forests and gradient boost can't handle NAs & infs, so fill them with 0s
result = self._classifier.predict_proba(
self.sort_features_to_classify(features.replace([np.inf, -np.inf], 0).fillna(0))
)

# Insert 0 probabilities when no prediction is made
if frame_indexes is not None:
result_adjusted = np.full(result.shape, 0, dtype=np.float32)
result_adjusted[frame_indexes] = result[frame_indexes]
result = result_adjusted

return result

def save(self, path: Path):
"""save the classifier to a file
Expand Down
31 changes: 3 additions & 28 deletions src/jabs/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,6 @@ def save_predictions(
video_name: str,
predictions: dict[int, np.ndarray],
probabilities: dict[int, np.ndarray],
frame_indexes: dict[int, np.ndarray],
behavior: str,
classifier: object,
) -> None:
Expand All @@ -443,25 +442,8 @@ def save_predictions(
video_name: name of the video these predictions correspond to.
predictions: dict mapping identity to a 1D numpy array of predicted labels.
probabilities: same structure as `predictions` but with floating-point values.
frame_indexes: dict mapping identity to 1D numpy array of absolute frame indices
listing the frames where the identity has a valid pose (i.e., frames with a meaningful prediction).
behavior: string behavior name.
classifier: Classifier object used to generate the predictions.

Note:
Currently, the classifier runs on every frame for every identity -- even when pose is invalid
and features are NaN. We copy values for *only* the frames with a valid pose. This is why we
index *both* the source and destination with `indexes` (an array with the absolute frame indices
of frames with a valid pose), e.g.:

prediction_labels[identity, indexes] = predictions[video][identity][indexes]
prediction_prob[identity, indexes] = probabilities[video][identity][indexes]

This leaves the output arrays with default values (-1 for labels, 0.0 for probabilities) for frames
without pose.

In the future, if the upstream caller were to provide compact arrays of length `len(indexes)`
instead of full-length arrays, the copy logic would need to drop the indexing on the source side.
"""
# set up an output filename based on the video names
file_base = Path(video_name).with_suffix("").name + ".h5"
Expand All @@ -473,17 +455,10 @@ def save_predictions(
)
prediction_prob = np.zeros_like(prediction_labels, dtype=np.float32)

# populate numpy arrays
# stack the numpy arrays
for identity in predictions:
indexes = frame_indexes[identity]

# 'indexes' are absolute frame indices where this identity has a valid pose.
# predictions[identity] and probabilities[identity] are full-length arrays
# (len == num_frames); however, only elements at 'indexes' contain meaningful values.
# We index both source and destination with 'indexes' to copy only those valid-pose frames.
# If upstream ever provides compact arrays instead, drop the source-side indexing.
prediction_labels[identity, indexes] = predictions[identity][indexes]
prediction_prob[identity, indexes] = probabilities[identity][indexes]
prediction_labels[identity] = predictions[identity]
prediction_prob[identity] = probabilities[identity]

# write to h5 file
self._prediction_manager.write_predictions(
Expand Down
14 changes: 5 additions & 9 deletions src/jabs/scripts/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,18 @@ def classify_pose(
data = Classifier.combine_data(per_frame_features, window_features)

if data.shape[0] > 0:
pred = classifier.predict(data)
pred_prob = classifier.predict_proba(data)
pred = classifier.predict(data, features["frame_indexes"])
pred_prob = classifier.predict_proba(data, features["frame_indexes"])

# Keep the probability for the predicted class only.
# The following code uses some
# numpy magic to use the pred array as column indexes
# for each row of the pred_prob array we just computed.
pred_prob = pred_prob[np.arange(len(pred_prob)), pred]

# Only copy out predictions where there was a valid pose
prediction_labels[curr_id, features["frame_indexes"]] = pred[
features["frame_indexes"]
]
prediction_prob[curr_id, features["frame_indexes"]] = pred_prob[
features["frame_indexes"]
]
# Copy results into results matrix
prediction_labels[curr_id] = pred
prediction_prob[curr_id] = pred_prob
progress.update(task, advance=1)

print(f"Writing predictions to {out_dir}")
Expand Down
1 change: 0 additions & 1 deletion src/jabs/ui/central_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,6 @@ def _classify_thread_complete(self, output: dict) -> None:
# display the new predictions
self._predictions = output["predictions"]
self._probabilities = output["probabilities"]
self._frame_indexes = output["frame_indexes"]
self._cleanup_progress_dialog()
self._cleanup_classify_thread()
self.status_message.emit("Classification Complete", 3000)
Expand Down
20 changes: 6 additions & 14 deletions src/jabs/ui/classification_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def run(self) -> None:
self._tasks_complete = 0
current_video_predictions = {}
current_video_probabilities = {}
current_video_frame_indexes = {}

def check_termination_requested() -> None:
if self._should_terminate:
Expand All @@ -104,7 +103,6 @@ def check_termination_requested() -> None:
# collect predictions, probabilities, and frame indexes for each identity in the video
predictions = {}
probabilities = {}
frame_indexes = {}

for identity in pose_est.identities:
check_termination_requested()
Expand Down Expand Up @@ -136,31 +134,27 @@ def check_termination_requested() -> None:
check_termination_requested()
if data.shape[0] > 0:
# make predictions
# Note: this makes predictions for all frames in the video, even those without valid pose
# We will later filter these out when saving the predictions to disk
# consider changing this to only predict on frames with valid pose
predictions[identity] = self._classifier.predict(data)
predictions[identity] = self._classifier.predict(
data, feature_values["frame_indexes"]
)

# also get the probabilities
prob = self._classifier.predict_proba(data)
prob = self._classifier.predict_proba(
data, feature_values["frame_indexes"]
)
# Save the probability for the predicted class only.
# The following code uses some
# numpy magic to use the _predictions array as column indexes
# for each row of the 'prob' array we just computed.
probabilities[identity] = prob[np.arange(len(prob)), predictions[identity]]

# save the indexes for the predicted frames
frame_indexes[identity] = feature_values["frame_indexes"]
else:
predictions[identity] = np.array(0)
probabilities[identity] = np.array(0)
frame_indexes[identity] = np.array(0)

if video == self._current_video:
# keep predictions for the video currently loaded in the video player
current_video_predictions = predictions.copy()
current_video_probabilities = probabilities.copy()
current_video_frame_indexes = frame_indexes.copy()

# save predictions to disk
self.current_status.emit("Saving Predictions")
Expand All @@ -169,7 +163,6 @@ def check_termination_requested() -> None:
video,
predictions,
probabilities,
frame_indexes,
self._behavior,
self._classifier,
)
Expand All @@ -183,7 +176,6 @@ def check_termination_requested() -> None:
{
"predictions": current_video_predictions,
"probabilities": current_video_probabilities,
"frame_indexes": current_video_frame_indexes,
}
)
except Exception as e:
Expand Down