sync util.
sayakpaul committed Jan 24, 2025
1 parent d63fbcf commit 5529264
Showing 2 changed files with 16 additions and 4 deletions.
finetrainers/trainer.py (8 additions, 3 deletions)
@@ -54,7 +54,13 @@
 )
 from .utils.file_utils import string_to_filename
 from .utils.hub_utils import save_model_card
-from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous, reset_memory_stats
+from .utils.memory_utils import (
+    free_memory,
+    get_memory_statistics,
+    make_contiguous,
+    reset_memory_stats,
+    synchornize_device,
+)
 from .utils.model_utils import resolve_vae_cls_from_ckpt_path
 from .utils.optimizer_utils import get_optimizer
 from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model
@@ -1107,8 +1113,7 @@ def _delete_components(self) -> None:
         self.vae = None
         self.scheduler = None
         free_memory()
-        if torch.cuda.is_available():
-            torch.cuda.synchronize(self.state.accelerator.device)
+        synchornize_device(self.state.accelerator.device)
 
     def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline:
         accelerator = self.state.accelerator
finetrainers/utils/memory_utils.py (8 additions, 1 deletion)
@@ -58,7 +58,14 @@ def reset_memory_stats(device: torch.device):
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats(device)
     else:
-        logger.warning("No CUDA, device found. Memory statistics are not available.")
+        logger.warning("No CUDA, device found. Nothing to reset memory of.")
+
+
+def synchornize_device(device: torch.device):
+    if torch.cuda.is_available():
+        torch.cuda.synchronize(device)
+    else:
+        logger.warning("No CUDA, device found. Nothing to synchronize.")
 
 
 def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
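For readers skimming the diff, the new utility is pulled out below into a minimal, self-contained sketch together with an example call site. The synchornize_device body (including its spelling) is copied from the hunk above; the logging setup and the __main__ demo are illustrative stand-ins, not part of the commit.

# Sketch of the pattern this commit introduces: callers stop guarding
# torch.cuda.synchronize() themselves and call a shared utility that
# degrades to a warning on machines without CUDA.
import logging

import torch

# Stand-in for the module-level logger finetrainers already defines.
logger = logging.getLogger(__name__)


def synchornize_device(device: torch.device):
    # Block until all queued CUDA work on `device` has finished; warn and
    # return on CPU-only machines instead of raising.
    if torch.cuda.is_available():
        torch.cuda.synchronize(device)
    else:
        logger.warning("No CUDA, device found. Nothing to synchronize.")


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Before this commit, trainer.py did the check inline:
    #     if torch.cuda.is_available():
    #         torch.cuda.synchronize(self.state.accelerator.device)
    # After it, the availability check lives in one place:
    synchornize_device(device)

Keeping the availability check inside the utility mirrors the neighbouring reset_memory_stats helper, so call sites such as Trainer._delete_components shrink from a guarded two-line block to a single call.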
