diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py
index 9b04919..bb28329 100644
--- a/finetrainers/trainer.py
+++ b/finetrainers/trainer.py
@@ -54,7 +54,13 @@
 )
 from .utils.file_utils import string_to_filename
 from .utils.hub_utils import save_model_card
-from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous, reset_memory_stats
+from .utils.memory_utils import (
+    free_memory,
+    get_memory_statistics,
+    make_contiguous,
+    reset_memory_stats,
+    synchornize_device,
+)
 from .utils.model_utils import resolve_vae_cls_from_ckpt_path
 from .utils.optimizer_utils import get_optimizer
 from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model
@@ -1107,8 +1113,7 @@ def _delete_components(self) -> None:
         self.vae = None
         self.scheduler = None
         free_memory()
-        if torch.cuda.is_available():
-            torch.cuda.synchronize(self.state.accelerator.device)
+        synchornize_device(self.state.accelerator.device)
 
     def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline:
         accelerator = self.state.accelerator
diff --git a/finetrainers/utils/memory_utils.py b/finetrainers/utils/memory_utils.py
index 7bffdb4..b057931 100644
--- a/finetrainers/utils/memory_utils.py
+++ b/finetrainers/utils/memory_utils.py
@@ -58,7 +58,14 @@ def reset_memory_stats(device: torch.device):
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats(device)
     else:
-        logger.warning("No CUDA, device found. Memory statistics are not available.")
+        logger.warning("No CUDA, device found. Nothing to reset memory of.")
+
+
+def synchornize_device(device: torch.device):
+    if torch.cuda.is_available():
+        torch.cuda.synchronize(device)
+    else:
+        logger.warning("No CUDA, device found. Nothing to synchronize.")
 
 
 def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
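
The patch centralizes the CUDA-availability guard inside `synchornize_device`, so call sites such as `_delete_components` no longer need to branch on `torch.cuda.is_available()` before synchronizing. Below is a minimal usage sketch, assuming the helpers are importable via the absolute path `finetrainers.utils.memory_utils` and that `free_memory` takes no arguments as in the import above; the `release_components` function itself is illustrative and not part of this patch.

```python
import logging

import torch

from finetrainers.utils.memory_utils import free_memory, synchornize_device

logger = logging.getLogger(__name__)


def release_components(pipeline, device: torch.device) -> None:
    """Illustrative teardown mirroring the `_delete_components` call site."""
    # Drop references to the large sub-models so the allocator can reclaim them.
    pipeline.vae = None
    pipeline.scheduler = None

    # Collect garbage and release cached allocator blocks.
    free_memory()

    # Block until all queued kernels on `device` have finished before measuring
    # memory or loading the next set of weights; on CPU-only machines the helper
    # just logs a warning instead of raising.
    synchornize_device(device)
```

Funnelling synchronization through one helper also means support for another backend would only need to be added in `memory_utils.py`, not at every call site.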