diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index 419e2b84f72e..db188c43a7cd 100755 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -195,7 +195,7 @@ class MegatronStrategy(DDPStrategy, io.IOMixin): save_ckpt_format (str): Distributed checkpoint format to use for checkpoint saving. Should be one of 'torch_dist' or 'zarr'. Defaults to 'torch_dist'. ckpt_async_save (bool): Whether to save checkpoints asynchronously to reduce checkpointing overhead. - Defaults to True. + Defaults to False. ckpt_torch_dist_multiproc (int): Number of extra processes per rank used during ckpt save with PyTorch distributed format. Defaults to None. ckpt_assume_constant_structure (bool): Allows caching some computation across checkpoint saves. @@ -292,7 +292,7 @@ def __init__( use_te_rng_tracker: bool = False, use_sharp: bool = False, save_ckpt_format: str = "torch_dist", - ckpt_async_save: bool = True, + ckpt_async_save: bool = False, ckpt_torch_dist_multiproc: int = None, ## TODO(ashors): put elsewhere? 
ckpt_assume_constant_structure: bool = False, ckpt_parallel_save: bool = True, diff --git a/nemo/lightning/pytorch/strategies/utils.py b/nemo/lightning/pytorch/strategies/utils.py index 3a3297414bbb..50b4f805433e 100755 --- a/nemo/lightning/pytorch/strategies/utils.py +++ b/nemo/lightning/pytorch/strategies/utils.py @@ -210,7 +210,9 @@ def create_checkpoint_io(wrapping_ckpt_io=None, **kwargs): if wrapping_ckpt_io: checkpoint_io = wrapping_ckpt_io(checkpoint_io) - if kwargs.get("async_save", False): + + async_save = kwargs.get("async_save", False) + if async_save: checkpoint_io = AsyncFinalizableCheckpointIO(checkpoint_io) return checkpoint_io diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index c7c74e98a655..411a45603c37 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -92,12 +92,16 @@ class AsyncFinalizableCheckpointIO(_WrappingCheckpointIO): AsyncCheckpointIO does). Allows to perform a (synchronous) finalization function after all ranks finish checkpoint saving. + This wrapper always creates the AsyncCallsQueue with persistent workers. This is known + to increase memory usage, and sometimes leads to out of memory errors. + NOTE: for correctness, this plugin must be used together with the AsyncFinalizerCallback callback which performs the finalization checks. Args: checkpoint_io (CheckpointIO): wrapped checkpoint_io object. Must be of type AsyncCompatibleCheckpointIO. + Persistent workers are always used for checkpoint writing; this is not configurable via the constructor. Requires the underlying checkpoint_io.save_checkpoint to return save_fn, save_args, finalize_fn. 
""" @@ -108,7 +112,7 @@ def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO) -> None: raise ValueError(f'Incompatible wrapped checkpoint_io type: {type(checkpoint_io)}') super().__init__(checkpoint_io) - self.async_calls_queue = AsyncCallsQueue() + self.async_calls_queue = AsyncCallsQueue(persistent=True) def save_checkpoint( self, @@ -166,6 +170,7 @@ def teardown(self) -> None: if self.async_calls_queue.get_num_unfinalized_calls() > 0: # Can't do finalization now because some ranks might be lost logging.warning('Some async checkpoint saves might be not finalized properly.') + self.async_calls_queue.close() class AsyncFinalizerCallback(Callback):