diff --git a/training_pipeline/predict.py b/training_pipeline/predict.py index 27886f7..37e34cd 100644 --- a/training_pipeline/predict.py +++ b/training_pipeline/predict.py @@ -168,7 +168,7 @@ def load_trained_model(model_path, device="cuda"): Loaded U-Net model """ model = UNet(n_channels=3, n_classes=1) - model.load_state_dict(torch.load(model_path, map_location=device)) + model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True)) model.to(device) return model