diff --git a/training_pipeline/predict.py b/training_pipeline/predict.py index 27886f7..6961d5b 100644 --- a/training_pipeline/predict.py +++ b/training_pipeline/predict.py @@ -117,7 +117,7 @@ def predict_batch(model, image_paths, device="cuda", threshold=0.5, checkpoint_p save_checkpoint(checkpoint_path, processed_images, predictions, metadata) print(f"Final checkpoint saved: {len(processed_images)} images processed") - return predictions + return processed_images, predictions def visualize_prediction(image_path, pred_mask, save_path=None): """