[checkpointing] import async checkpoint with pinned memory only when needed

ghstack-source-id: e460a8d6458f191f7f589fc908974f896b514690
Pull Request resolved: pytorch#333
tianyu-l committed May 15, 2024
1 parent ac94484 commit 41d69d2
Showing 1 changed file with 9 additions and 1 deletion.
torchtitan/checkpoint.py
@@ -15,7 +15,6 @@
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     get_optimizer_state_dict,
@@ -267,6 +266,15 @@ def _async_wait(self) -> None:
         self.async_future.result()
 
     def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
+        try:
+            from torch.distributed._state_dict_utils import (
+                _copy_state_dict,
+                _create_cpu_state_dict,
+            )
+        except ImportError as e:
+            raise ImportError(
+                "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
+            ) from e
         state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
         if self.cpu_offload_state_dict is None:
logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
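Note: below is a minimal sketch of the pattern this commit applies, i.e. deferring the import of PyTorch's private pinned-memory helpers into the one method that needs them, so the module still imports on stable PyTorch builds that lack them. The wrapper name stage_to_pinned_cpu, its signature, and the calls below the import guard are illustrative assumptions modeled on this method's cpu_offload_state_dict caching; only the try/except import itself is taken from the diff.

import torch


def stage_to_pinned_cpu(state_dict, cached_cpu_state_dict=None):
    # Import lazily: these private helpers only exist in recent PyTorch
    # nightlies, and no other code path in the module depends on them.
    try:
        from torch.distributed._state_dict_utils import (
            _copy_state_dict,
            _create_cpu_state_dict,
        )
    except ImportError as e:
        raise ImportError(
            "Please install the latest PyTorch nightly to use async "
            "checkpointing with pinned memory."
        ) from e

    if cached_cpu_state_dict is None:
        # Allocate page-locked (pinned) CPU buffers mirroring state_dict;
        # pinned memory is what allows the device-to-host copy to be
        # non-blocking. Callers can reuse the returned dict across
        # checkpoints instead of reallocating (assumed to mirror the
        # cpu_offload_state_dict cache in this file).
        cached_cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True)
    # Copy GPU tensors into the pinned buffers without blocking the host.
    _copy_state_dict(state_dict, cached_cpu_state_dict, non_blocking=True)
    torch.cuda.synchronize()  # ensure the copies finish before saving
    return cached_cpu_state_dict

Scoping the hard nightly dependency to this single code path is the point of the change: users who never enable async checkpointing with pinned memory can keep running on stable PyTorch.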
