diff --git a/training_pipeline/predict.py b/training_pipeline/predict.py index 27886f7..ed3e23d 100644 --- a/training_pipeline/predict.py +++ b/training_pipeline/predict.py @@ -313,7 +313,7 @@ def main(): checkpoint_path, resume) # Save predictions - for i, (image_path, pred_mask) in enumerate(zip(valid_images, predictions)): + for image_path, pred_mask in zip(processed_images, predictions): base_name = os.path.splitext(os.path.basename(image_path))[0] mask_save_path = f"{base_name}_predicted_mask.png" Image.fromarray(pred_mask * 255).save(mask_save_path)