240 changes: 224 additions & 16 deletions src/jabs_postprocess/compare_gt.py
@@ -41,13 +41,16 @@ def evaluate_ground_truth(
filter_scan: List of filter (minimum duration in frames to consider real) values to test
iou_thresholds: List of intersection over union thresholds to scan
filter_ground_truth: Apply filters to ground truth data (default is only to filter predictions)
trim_time: Limit the duration in frames of videos for performance
Returns:
None, but saves the following files to results_folder:
framewise_output: Output file to save the frame-level performance plot
scan_output: Output file to save the filter scan performance plot
bout_output: Output file to save the resulting bout performance plot
ethogram_output: Output file to save the ethogram plot comparing GT and predictions
scan_csv_output: Output file to save the scan performance data as CSV
"""
ouput_paths = generate_output_paths(results_folder)
output_paths = generate_output_paths(results_folder)

# Set default values if not provided
stitch_scan = stitch_scan or np.arange(5, 46, 5).tolist()
@@ -122,6 +125,14 @@ def evaluate_ground_truth(
pred_df["is_gt"] = False
all_annotations = pd.concat([gt_df, pred_df])

# Generate frame-level performance plot
framewise_plot = generate_framewise_performance_plot(gt_df, pred_df)
if output_paths["framewise_plot"] is not None:
framewise_plot.save(output_paths["framewise_plot"], height=6, width=12, dpi=300)
logging.info(
f"Frame-level performance plot saved to {output_paths['framewise_plot']}"
)

# We only want the positive examples for performance evaluation
# (but for ethogram plotting later, we'll use the full all_annotations)
performance_annotations = all_annotations[
@@ -150,9 +161,9 @@ def evaluate_ground_truth(
"No performance data to analyze. Ensure that the ground truth and predictions are correctly formatted and contain valid bouts."
)

if ouput_paths["scan_csv"] is not None:
performance_df.to_csv(ouput_paths["scan_csv"], index=False)
logging.info(f"Scan performance data saved to {ouput_paths['scan_csv']}")
if output_paths["scan_csv"] is not None:
performance_df.to_csv(output_paths["scan_csv"], index=False)
logging.info(f"Scan performance data saved to {output_paths['scan_csv']}")

_melted_df = pd.melt(performance_df, id_vars=["threshold", "stitch", "filter"])

@@ -172,8 +183,8 @@ def evaluate_ground_truth(
+ p9.theme_bw()
+ p9.labs(title=f"No performance data for {middle_threshold} IoU")
)
if ouput_paths["scan_plot"]:
plot.save(ouput_paths["scan_plot"], height=6, width=12, dpi=300)
if output_paths["scan_plot"]:
plot.save(output_paths["scan_plot"], height=6, width=12, dpi=300)
# Create default winning filters with first values from scan parameters
winning_filters = pd.DataFrame(
{
@@ -231,8 +242,8 @@ def evaluate_ground_truth(
+ p9.scale_fill_continuous(na_value=0)
)

if ouput_paths["scan_plot"]:
plot.save(ouput_paths["scan_plot"], height=6, width=12, dpi=300)
if output_paths["scan_plot"]:
plot.save(output_paths["scan_plot"], height=6, width=12, dpi=300)

# Handle case where all f1_plot values are NaN or empty
if subset_df["f1_plot"].isna().all() or len(subset_df) == 0:
@@ -254,9 +265,9 @@ def evaluate_ground_truth(
).T.reset_index(drop=True)[["stitch", "filter"]]

winning_bout_df = pd.merge(performance_df, winning_filters, on=["stitch", "filter"])
if ouput_paths["bout_csv"] is not None:
winning_bout_df.to_csv(ouput_paths["bout_csv"], index=False)
logging.info(f"Bout performance data saved to {ouput_paths['bout_csv']}")
if output_paths["bout_csv"] is not None:
winning_bout_df.to_csv(output_paths["bout_csv"], index=False)
logging.info(f"Bout performance data saved to {output_paths['bout_csv']}")

melted_winning = pd.melt(winning_bout_df, id_vars=["threshold", "stitch", "filter"])

@@ -268,9 +279,9 @@ def evaluate_ground_truth(
+ p9.geom_line()
+ p9.theme_bw()
+ p9.scale_y_continuous(limits=(0, 1))
).save(ouput_paths["bout_plot"], height=6, width=12, dpi=300)
).save(output_paths["bout_plot"], height=6, width=12, dpi=300)

if ouput_paths["ethogram"] is not None:
if output_paths["ethogram"] is not None:
# Prepare data for ethogram plot
# Use all_annotations to include both behavior (1) and not-behavior (0) states
plot_df = all_annotations.copy()
@@ -327,14 +338,14 @@ def evaluate_ground_truth(
)
# Adjust height based on the number of unique animal-video combinations
ethogram_plot.save(
ouput_paths["ethogram"],
output_paths["ethogram"],
height=1.5 * num_unique_combos + 2,
width=12,
dpi=300,
limitsize=False,
verbose=False,
)
logging.info(f"Ethogram plot saved to {ouput_paths['ethogram']}")
logging.info(f"Ethogram plot saved to {output_paths['ethogram']}")
else:
logger.warning(
f"No behavior instances found for behavior {behavior} after filtering for ethogram."
@@ -480,4 +491,201 @@ def generate_output_paths(results_folder: Path):
"ethogram": results_folder / "ethogram.png",
"scan_plot": results_folder / "scan_performance.png",
"bout_plot": results_folder / "bout_performance.png",
"framewise_plot": results_folder / "framewise_performance.png",
}


def _expand_intervals_to_frames(df):
"""Expand behavior intervals into per-frame rows."""
expanded = df.copy()
expanded["frame"] = expanded.apply(
lambda row: range(row["start"], row["start"] + row["duration"]), axis=1
)
expanded = expanded.explode("frame")
expanded = expanded.sort_values(by=["animal_idx", "frame"])
return expanded

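# --- Illustrative sketch (not part of the original diff) -------------------
# A hypothetical, minimal demo of _expand_intervals_to_frames, assuming the
# interval schema used elsewhere in this module (video_name / animal_idx /
# start / duration / is_behavior). The values below are invented purely for
# illustration.
def _demo_expand_intervals_to_frames():
    intervals = pd.DataFrame(
        {
            "video_name": ["vid_a", "vid_a"],
            "animal_idx": [0, 0],
            "start": [10, 30],
            "duration": [3, 2],
            "is_behavior": [1, 0],
        }
    )
    frames = _expand_intervals_to_frames(intervals)
    # frames["frame"] now holds 10, 11, 12, 30, 31: one row per frame,
    # sorted by animal_idx and frame, with the interval columns repeated.
    return frames
# ---------------------------------------------------------------------------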

def _compute_framewise_confusion(gt_df, pred_df):
"""Compute frame-level confusion counts (TP, TN, FP, FN) per video.

Args:
gt_df (pd.DataFrame): Ground truth intervals with columns
['video_name', 'animal_idx', 'start', 'duration', 'is_behavior'].
pred_df (pd.DataFrame): Prediction intervals with the same structure.

Returns:
pd.DataFrame: Confusion matrix counts per video with columns
['video_name', 'TP', 'TN', 'FP', 'FN'].
"""

# Expand ground truth and predictions into frame-level data
gt_frames = _expand_intervals_to_frames(gt_df)
pred_frames = _expand_intervals_to_frames(pred_df)

# Merge to align predictions and ground truth per frame
framewise = pd.merge(
gt_frames,
pred_frames,
on=["video_name", "animal_idx", "frame"],
how="left",
suffixes=("_gt", "_pred"),
)

# Compute confusion counts per video
confusion_counts = (
framewise.groupby("video_name")
.apply(
lambda x: pd.Series(
{
"TP": (
(x["is_behavior_gt"] == 1) & (x["is_behavior_pred"] == 1)
).sum(),
"TN": (
(x["is_behavior_gt"] == 0) & (x["is_behavior_pred"] == 0)
).sum(),
"FP": (
(x["is_behavior_gt"] == 0) & (x["is_behavior_pred"] == 1)
).sum(),
"FN": (
(x["is_behavior_gt"] == 1) & (x["is_behavior_pred"] == 0)
).sum(),
}
),
include_groups=False,
)
.reset_index()
)

return confusion_counts

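# --- Illustrative sketch (not part of the original diff) -------------------
# A hypothetical demo of _compute_framewise_confusion. Both inputs are expected
# to carry behavior (is_behavior == 1) and not-behavior (is_behavior == 0)
# intervals; because the merge is a left join on the ground-truth frames, any
# GT frame with no matching prediction row falls outside all four buckets.
def _demo_compute_framewise_confusion():
    gt = pd.DataFrame(
        {
            "video_name": ["vid_a", "vid_a"],
            "animal_idx": [0, 0],
            "start": [0, 5],
            "duration": [5, 5],
            "is_behavior": [1, 0],
        }
    )
    pred = pd.DataFrame(
        {
            "video_name": ["vid_a", "vid_a"],
            "animal_idx": [0, 0],
            "start": [0, 3],
            "duration": [3, 7],
            "is_behavior": [1, 0],
        }
    )
    counts = _compute_framewise_confusion(gt, pred)
    # Expected single row for vid_a: TP=3, FN=2, TN=5, FP=0.
    return counts
# ---------------------------------------------------------------------------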

def _find_outliers(melted_df: pd.DataFrame):
"""
Return rows flagged as outliers per metric using the IQR rule.

Args:
melted_df: long-form DataFrame with at least 'metric' and 'value' columns.

Returns:
DataFrame containing the outlier rows from the input DataFrame.
Returns an empty DataFrame with the same columns if no outliers found.
"""
outliers = []
for metric in melted_df["metric"].unique():
values = melted_df.loc[melted_df["metric"] == metric, "value"]
q1 = values.quantile(0.25)
q3 = values.quantile(0.75)
iqr = q3 - q1
lower_bound = q1 - 1.5 * iqr
upper_bound = q3 + 1.5 * iqr
outliers_df = melted_df[
(melted_df["metric"] == metric)
& ((melted_df["value"] < lower_bound) | (melted_df["value"] > upper_bound))
]
outliers.append(outliers_df)

outliers = (
pd.concat(outliers) if outliers else pd.DataFrame(columns=melted_df.columns)
)

return outliers

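# --- Illustrative sketch (not part of the original diff) -------------------
# A hypothetical demo of the IQR rule applied by _find_outliers: per metric,
# values below Q1 - 1.5*IQR or above Q3 + 1.5*IQR are flagged, and the flagged
# rows are later labelled with their video_name on the framewise plot.
def _demo_find_outliers():
    melted = pd.DataFrame(
        {
            "video_name": [f"vid_{i}" for i in range(6)],
            "metric": ["f1_score"] * 6,
            "value": [0.90, 0.91, 0.92, 0.93, 0.94, 0.20],
        }
    )
    outliers = _find_outliers(melted)
    # Expected: only the vid_5 row (value 0.20) is returned; an empty DataFrame
    # with the same columns comes back when nothing is flagged.
    return outliers
# ---------------------------------------------------------------------------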

def generate_framewise_performance_plot(gt_df: pd.DataFrame, pred_df: pd.DataFrame):
"""
Generate a frame-level performance plot comparing ground truth and predicted behavior intervals.

This function:
1. Expands each interval in `gt_df` and `pred_df` to per-frame annotations.
2. Computes per-video confusion counts (TP, TN, FP, FN).
3. Calculates precision, recall, F1 score, and accuracy for each video.
4. Produces a boxplot with jitter showing the distribution of these metrics.
5. Adds an overall summary in the plot subtitle.

Args:
gt_df (pd.DataFrame): Ground truth intervals with columns
['video_name', 'animal_idx', 'start', 'duration', 'is_behavior'].
pred_df (pd.DataFrame): Prediction intervals with the same structure.

Returns:
plotnine.ggplot: A ggplot object containing the frame-level performance visualization.
"""
# Compute framewise confusion counts
confusion_counts = _compute_framewise_confusion(gt_df, pred_df)
confusion_counts["frame_total"] = (
confusion_counts["TP"]
+ confusion_counts["TN"]
+ confusion_counts["FP"]
+ confusion_counts["FN"]
)

# Compute per-video metrics
confusion_counts["precision"] = confusion_counts["TP"] / (
confusion_counts["TP"] + confusion_counts["FP"]
)
confusion_counts["recall"] = confusion_counts["TP"] / (
confusion_counts["TP"] + confusion_counts["FN"]
)
confusion_counts["f1_score"] = (
2
* (confusion_counts["precision"] * confusion_counts["recall"])
/ (confusion_counts["precision"] + confusion_counts["recall"])
)
confusion_counts["accuracy"] = (
confusion_counts["TP"] + confusion_counts["TN"]
) / confusion_counts["frame_total"]

# Compute overall (global) metrics
totals = confusion_counts[["TP", "TN", "FP", "FN"]].sum()
overall_metrics = {
"precision": totals["TP"] / (totals["TP"] + totals["FP"]),
"recall": totals["TP"] / (totals["TP"] + totals["FN"]),
"accuracy": (totals["TP"] + totals["TN"])
/ (totals["TP"] + totals["TN"] + totals["FP"] + totals["FN"]),
}
overall_metrics["f1_score"] = (
2
* (overall_metrics["precision"] * overall_metrics["recall"])
/ (overall_metrics["precision"] + overall_metrics["recall"])
)

# Melt into long format for plotting
melted_df = pd.melt(
confusion_counts,
id_vars=["video_name", "frame_total"],
value_vars=["precision", "recall", "f1_score", "accuracy"],
var_name="metric",
value_name="value",
)

outliers = _find_outliers(melted_df)
# Generate plot
subtitle_text = (
f"Precision: {overall_metrics['precision']:.2f}, "
f"Recall: {overall_metrics['recall']:.2f}, "
f"F1: {overall_metrics['f1_score']:.2f}, "
f"Accuracy: {overall_metrics['accuracy']:.2f}"
)

plot = (
p9.ggplot(melted_df, p9.aes(x="metric", y="value"))
+ p9.geom_boxplot(outlier_shape=None, fill="lightblue", alpha=0.7)
+ p9.geom_jitter(p9.aes(color="frame_total"), width=0.05, height=0)
+ p9.geom_text(
p9.aes(label="video_name"), data=outliers, ha="left", nudge_x=0.1
)
+ p9.labs(
title="Frame-level Performance Metrics",
y="Score",
x="Metric",
subtitle=subtitle_text,
)
+ p9.theme_bw()
+ p9.theme(
plot_title=p9.element_text(ha="center"), # Center the main title
plot_subtitle=p9.element_text(ha="center"), # Center the subtitle too
)
)

return plot
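
# --- Illustrative sketch (not part of the original diff) -------------------
# A hypothetical end-to-end usage of generate_framewise_performance_plot,
# mirroring the save call evaluate_ground_truth makes. The file name here is a
# placeholder; per-video metrics become the boxplot/jitter points and the
# pooled metrics appear in the subtitle.
def _demo_framewise_plot(gt_df: pd.DataFrame, pred_df: pd.DataFrame):
    plot = generate_framewise_performance_plot(gt_df, pred_df)
    plot.save("framewise_performance.png", height=6, width=12, dpi=300)
    return plot
# ---------------------------------------------------------------------------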