diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py
index b2322e1..1dfef70 100644
--- a/src/jabs_postprocess/compare_gt.py
+++ b/src/jabs_postprocess/compare_gt.py
@@ -41,13 +41,16 @@ def evaluate_ground_truth(
         filter_scan: List of filter (minimum duration in frames to consider real) values to test
         iou_thresholds: List of intersection over union thresholds to scan
         filter_ground_truth: Apply filters to ground truth data (default is only to filter predictions)
+        trim_time: Limit the duration in frames of videos for performance
+
+    Returns:
+        None, but saves the following files to results_folder:
+        framewise_output: Output file to save the frame-level performance plot
         scan_output: Output file to save the filter scan performance plot
         bout_output: Output file to save the resulting bout performance plot
-        trim_time: Limit the duration in frames of videos for performance
         ethogram_output: Output file to save the ethogram plot comparing GT and predictions
         scan_csv_output: Output file to save the scan performance data as CSV
     """
-    ouput_paths = generate_output_paths(results_folder)
+    output_paths = generate_output_paths(results_folder)
 
     # Set default values if not provided
     stitch_scan = stitch_scan or np.arange(5, 46, 5).tolist()
@@ -122,6 +125,14 @@ def evaluate_ground_truth(
     pred_df["is_gt"] = False
     all_annotations = pd.concat([gt_df, pred_df])
 
+    # Generate frame-level performance plot
+    framewise_plot = generate_framewise_performance_plot(gt_df, pred_df)
+    if output_paths["framewise_plot"] is not None:
+        framewise_plot.save(output_paths["framewise_plot"], height=6, width=12, dpi=300)
+        logging.info(
+            f"Frame-level performance plot saved to {output_paths['framewise_plot']}"
+        )
+
     # We only want the positive examples for performance evaluation
     # (but for ethogram plotting later, we'll use the full all_annotations)
     performance_annotations = all_annotations[
@@ -150,9 +161,9 @@ def evaluate_ground_truth(
             "No performance data to analyze. Ensure that the ground truth and predictions are correctly formatted and contain valid bouts."
         )
 
-    if ouput_paths["scan_csv"] is not None:
-        performance_df.to_csv(ouput_paths["scan_csv"], index=False)
-        logging.info(f"Scan performance data saved to {ouput_paths['scan_csv']}")
+    if output_paths["scan_csv"] is not None:
+        performance_df.to_csv(output_paths["scan_csv"], index=False)
+        logging.info(f"Scan performance data saved to {output_paths['scan_csv']}")
 
     _melted_df = pd.melt(performance_df, id_vars=["threshold", "stitch", "filter"])
@@ -172,8 +183,8 @@ def evaluate_ground_truth(
             + p9.theme_bw()
             + p9.labs(title=f"No performance data for {middle_threshold} IoU")
         )
-        if ouput_paths["scan_plot"]:
-            plot.save(ouput_paths["scan_plot"], height=6, width=12, dpi=300)
+        if output_paths["scan_plot"]:
+            plot.save(output_paths["scan_plot"], height=6, width=12, dpi=300)
         # Create default winning filters with first values from scan parameters
         winning_filters = pd.DataFrame(
             {
@@ -231,8 +242,8 @@ def evaluate_ground_truth(
             + p9.scale_fill_continuous(na_value=0)
         )
 
-        if ouput_paths["scan_plot"]:
-            plot.save(ouput_paths["scan_plot"], height=6, width=12, dpi=300)
+        if output_paths["scan_plot"]:
+            plot.save(output_paths["scan_plot"], height=6, width=12, dpi=300)
 
         # Handle case where all f1_plot values are NaN or empty
        if subset_df["f1_plot"].isna().all() or len(subset_df) == 0:
@@ -254,9 +265,9 @@ def evaluate_ground_truth(
     ).T.reset_index(drop=True)[["stitch", "filter"]]
 
     winning_bout_df = pd.merge(performance_df, winning_filters, on=["stitch", "filter"])
 
-    if ouput_paths["bout_csv"] is not None:
-        winning_bout_df.to_csv(ouput_paths["bout_csv"], index=False)
-        logging.info(f"Bout performance data saved to {ouput_paths['bout_csv']}")
+    if output_paths["bout_csv"] is not None:
+        winning_bout_df.to_csv(output_paths["bout_csv"], index=False)
+        logging.info(f"Bout performance data saved to {output_paths['bout_csv']}")
 
     melted_winning = pd.melt(winning_bout_df, id_vars=["threshold", "stitch", "filter"])
@@ -268,9 +279,9 @@ def evaluate_ground_truth(
         + p9.geom_line()
         + p9.theme_bw()
         + p9.scale_y_continuous(limits=(0, 1))
-    ).save(ouput_paths["bout_plot"], height=6, width=12, dpi=300)
+    ).save(output_paths["bout_plot"], height=6, width=12, dpi=300)
 
-    if ouput_paths["ethogram"] is not None:
+    if output_paths["ethogram"] is not None:
         # Prepare data for ethogram plot
         # Use all_annotations to include both behavior (1) and not-behavior (0) states
         plot_df = all_annotations.copy()
@@ -327,14 +338,14 @@ def evaluate_ground_truth(
             )
             # Adjust height based on the number of unique animal-video combinations
             ethogram_plot.save(
-                ouput_paths["ethogram"],
+                output_paths["ethogram"],
                 height=1.5 * num_unique_combos + 2,
                 width=12,
                 dpi=300,
                 limitsize=False,
                 verbose=False,
             )
-            logging.info(f"Ethogram plot saved to {ouput_paths['ethogram']}")
+            logging.info(f"Ethogram plot saved to {output_paths['ethogram']}")
         else:
             logger.warning(
                 f"No behavior instances found for behavior {behavior} after filtering for ethogram."
@@ -480,4 +491,201 @@ def generate_output_paths(results_folder: Path):
         "ethogram": results_folder / "ethogram.png",
         "scan_plot": results_folder / "scan_performance.png",
         "bout_plot": results_folder / "bout_performance.png",
+        "framewise_plot": results_folder / "framewise_performance.png",
     }
+
+
+def _expand_intervals_to_frames(df):
+    """Expand behavior intervals into per-frame rows."""
+    expanded = df.copy()
+    expanded["frame"] = expanded.apply(
+        lambda row: range(row["start"], row["start"] + row["duration"]), axis=1
+    )
+    expanded = expanded.explode("frame")
+    expanded = expanded.sort_values(by=["animal_idx", "frame"])
+    return expanded
+
+
+def _compute_framewise_confusion(gt_df, pred_df):
+    """Compute frame-level confusion counts (TP, TN, FP, FN) per video.
+
+    Args:
+        gt_df (pd.DataFrame): Ground truth intervals with columns
+            ['video_name', 'animal_idx', 'start', 'duration', 'is_behavior'].
+        pred_df (pd.DataFrame): Prediction intervals with the same structure.
+
+    Returns:
+        pd.DataFrame: Confusion matrix counts per video with columns
+            ['video_name', 'TP', 'TN', 'FP', 'FN'].
+    """
+
+    # Expand ground truth and predictions into frame-level data
+    gt_frames = _expand_intervals_to_frames(gt_df)
+    pred_frames = _expand_intervals_to_frames(pred_df)
+
+    # Merge to align predictions and ground truth per frame
+    framewise = pd.merge(
+        gt_frames,
+        pred_frames,
+        on=["video_name", "animal_idx", "frame"],
+        how="left",
+        suffixes=("_gt", "_pred"),
+    )
+
+    # Compute confusion counts per video
+    confusion_counts = (
+        framewise.groupby("video_name")
+        .apply(
+            lambda x: pd.Series(
+                {
+                    "TP": (
+                        (x["is_behavior_gt"] == 1) & (x["is_behavior_pred"] == 1)
+                    ).sum(),
+                    "TN": (
+                        (x["is_behavior_gt"] == 0) & (x["is_behavior_pred"] == 0)
+                    ).sum(),
+                    "FP": (
+                        (x["is_behavior_gt"] == 0) & (x["is_behavior_pred"] == 1)
+                    ).sum(),
+                    "FN": (
+                        (x["is_behavior_gt"] == 1) & (x["is_behavior_pred"] == 0)
+                    ).sum(),
+                }
+            ),
+            include_groups=False,
+        )
+        .reset_index()
+    )
+
+    return confusion_counts
+
+
+def _find_outliers(melted_df: pd.DataFrame):
+    """
+    Return rows flagged as outliers per metric using the IQR rule.
+
+    Args:
+        melted_df: long-form DataFrame with at least 'metric' and 'value' columns.
+
+    Returns:
+        DataFrame containing the outlier rows from the input DataFrame.
+        Returns an empty DataFrame with the same columns if no outliers are found.
+    """
+    outliers = []
+    for metric in melted_df["metric"].unique():
+        values = melted_df.loc[melted_df["metric"] == metric, "value"]
+        q1 = values.quantile(0.25)
+        q3 = values.quantile(0.75)
+        iqr = q3 - q1
+        lower_bound = q1 - 1.5 * iqr
+        upper_bound = q3 + 1.5 * iqr
+        outliers_df = melted_df[
+            (melted_df["metric"] == metric)
+            & ((melted_df["value"] < lower_bound) | (melted_df["value"] > upper_bound))
+        ]
+        outliers.append(outliers_df)
+
+    outliers = (
+        pd.concat(outliers) if outliers else pd.DataFrame(columns=melted_df.columns)
+    )
+
+    return outliers
+
+
+def generate_framewise_performance_plot(gt_df: pd.DataFrame, pred_df: pd.DataFrame):
+    """
+    Generate a frame-level performance plot comparing ground truth and predicted behavior intervals.
+
+    This function:
+    1. Expands each interval in `gt_df` and `pred_df` to per-frame annotations.
+    2. Computes per-video confusion counts (TP, TN, FP, FN).
+    3. Calculates precision, recall, F1 score, and accuracy for each video.
+    4. Produces a boxplot with jitter showing the distribution of these metrics.
+    5. Adds an overall summary in the plot subtitle.
+
+    Args:
+        gt_df (pd.DataFrame): Ground truth intervals with columns
+            ['video_name', 'animal_idx', 'start', 'duration', 'is_behavior'].
+        pred_df (pd.DataFrame): Prediction intervals with the same structure.
+
+    Returns:
+        plotnine.ggplot: A ggplot object containing the frame-level performance visualization.
+    """
+    # Compute framewise confusion counts
+    confusion_counts = _compute_framewise_confusion(gt_df, pred_df)
+    confusion_counts["frame_total"] = (
+        confusion_counts["TP"]
+        + confusion_counts["TN"]
+        + confusion_counts["FP"]
+        + confusion_counts["FN"]
+    )
+
+    # Compute per-video metrics
+    confusion_counts["precision"] = confusion_counts["TP"] / (
+        confusion_counts["TP"] + confusion_counts["FP"]
+    )
+    confusion_counts["recall"] = confusion_counts["TP"] / (
+        confusion_counts["TP"] + confusion_counts["FN"]
+    )
+    confusion_counts["f1_score"] = (
+        2
+        * (confusion_counts["precision"] * confusion_counts["recall"])
+        / (confusion_counts["precision"] + confusion_counts["recall"])
+    )
+    confusion_counts["accuracy"] = (
+        confusion_counts["TP"] + confusion_counts["TN"]
+    ) / confusion_counts["frame_total"]
+
+    # Compute overall (global) metrics
+    totals = confusion_counts[["TP", "TN", "FP", "FN"]].sum()
+    overall_metrics = {
+        "precision": totals["TP"] / (totals["TP"] + totals["FP"]),
+        "recall": totals["TP"] / (totals["TP"] + totals["FN"]),
+        "accuracy": (totals["TP"] + totals["TN"])
+        / (totals["TP"] + totals["TN"] + totals["FP"] + totals["FN"]),
+    }
+    overall_metrics["f1_score"] = (
+        2
+        * (overall_metrics["precision"] * overall_metrics["recall"])
+        / (overall_metrics["precision"] + overall_metrics["recall"])
+    )
+
+    # Melt into long format for plotting
+    melted_df = pd.melt(
+        confusion_counts,
+        id_vars=["video_name", "frame_total"],
+        value_vars=["precision", "recall", "f1_score", "accuracy"],
+        var_name="metric",
+        value_name="value",
+    )
+
+    outliers = _find_outliers(melted_df)
+    # Generate plot
+    subtitle_text = (
+        f"Precision: {overall_metrics['precision']:.2f}, "
+        f"Recall: {overall_metrics['recall']:.2f}, "
+        f"F1: {overall_metrics['f1_score']:.2f}, "
+        f"Accuracy: {overall_metrics['accuracy']:.2f}"
+    )
+
+    plot = (
+        p9.ggplot(melted_df, p9.aes(x="metric", y="value"))
+        + p9.geom_boxplot(outlier_shape=None, fill="lightblue", alpha=0.7)
+        + p9.geom_jitter(p9.aes(color="frame_total"), width=0.05, height=0)
+        + p9.geom_text(
+            p9.aes(label="video_name"), data=outliers, ha="left", nudge_x=0.1
+        )
+        + p9.labs(
+            title="Frame-level Performance Metrics",
+            y="Score",
+            x="Metric",
+            subtitle=subtitle_text,
+        )
+        + p9.theme_bw()
+        + p9.theme(
+            plot_title=p9.element_text(ha="center"),  # Center the main title
+            plot_subtitle=p9.element_text(ha="center"),  # Center the subtitle too
+        )
+    )
+
+    return plot
diff --git a/tests/test_compare_gt.py b/tests/test_compare_gt.py
index 56fdbfc..f92a3ac 100644
--- a/tests/test_compare_gt.py
+++ b/tests/test_compare_gt.py
@@ -32,9 +32,14 @@
 import numpy as np
 import pandas as pd
+import plotnine as p9
 import pytest
 
-from jabs_postprocess.compare_gt import evaluate_ground_truth, generate_iou_scan
+from jabs_postprocess.compare_gt import (
+    evaluate_ground_truth,
+    generate_iou_scan,
+    generate_framewise_performance_plot,
+)
 from jabs_postprocess.utils.project_utils import (
     Bouts,
 )
@@ -672,3 +677,88 @@ def test_generate_iou_scan_metrics_calculation(mock_metrics, expected_result):
             assert np.isnan(row[metric])
         else:
             assert round(row[metric], 3) == round(expected, 3)
+
+
+@pytest.fixture
+def sample_data():
+    """Create small sample GT and prediction DataFrames for testing."""
+    gt_df = pd.DataFrame(
+        {
+            "video_name": ["video1", "video1", "video2"],
+            "animal_idx": [0, 1, 0],
+            "start": [0, 5, 0],
+            "duration": [5, 5, 10],
+            "is_behavior": [1, 0, 1],
+        }
+    )
+
+    pred_df = pd.DataFrame(
+        {
+            "video_name": ["video1", "video1", "video2"],
+            "animal_idx": [0, 1, 0],
+            "start": [0, 5, 0],
+            "duration": [5, 5, 10],
+            "is_behavior": [1, 0, 0],
+        }
+    )
+
+    return gt_df, pred_df
+
+
+def test_generate_plot_runs(sample_data):
+    """Test that the plot function runs and returns a ggplot object."""
+    gt_df, pred_df = sample_data
+    plot = generate_framewise_performance_plot(gt_df, pred_df)
+    # Check that the returned object is a ggplot
+    assert isinstance(plot, p9.ggplot)
+
+
+def test_plot_metrics(sample_data):
+    """Test that per-video metrics are computed correctly, including NaN cases."""
+    gt_df, pred_df = sample_data
+
+    plot = generate_framewise_performance_plot(gt_df, pred_df)
+    df = plot.data.sort_values(["video_name", "metric"]).reset_index(drop=True)
+
+    # Manually compute expected metrics
+    expected = []
+    # Video 1: Perfect prediction
+    expected.append(
+        {
+            "video_name": "video1",
+            "precision": 1.0,
+            "recall": 1.0,
+            "f1_score": 1.0,
+            "accuracy": 1.0,
+        }
+    )
+    # Video 2: All wrong
+    expected.append(
+        {
+            "video_name": "video2",
+            "precision": float("nan"),
+            "recall": 0.0,
+            "f1_score": float("nan"),
+            "accuracy": 0.0,
+        }
+    )
+
+    expected_df = pd.DataFrame(expected)
+    expected_melted = (
+        pd.melt(
+            expected_df,
+            id_vars=["video_name"],
+            value_vars=["precision", "recall", "f1_score", "accuracy"],
+            var_name="metric",
+            value_name="value",
+        )
+        .sort_values(["video_name", "metric"])
+        .reset_index(drop=True)
+    )
+
+    # Compare numeric values, treating NaNs as equal
+    for a, b in zip(df["value"], expected_melted["value"]):
+        if pd.isna(a) and pd.isna(b):
+            continue
+        else:
+            assert abs(a - b) < 1e-6