Commit 0b91490

deep copy state dict for checkpoint

Summary: Deep copy the state dict when sending a checkpoint: if the replica moves on to the next step, the state dict can change before the checkpoint has been sent.

1 parent fef4abc commit 0b91490

4 files changed, +27 −1 lines changed

torchft/checkpointing/http_transport.py

Lines changed: 9 additions & 1 deletion

@@ -16,6 +16,7 @@
 from typing import Generator, List, Optional, TypeVar, cast

 import torch
+from torch.distributed.tensor import DTensor, distribute_tensor
 from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten

 from torchft.checkpointing._rwlock import RWLock
@@ -265,6 +266,13 @@ def recv_checkpoint(

         return tree_unflatten(values, spec)

+def _clone_cpu_tensor(tensor: torch.Tensor) -> torch.Tensor:
+    if isinstance(tensor, DTensor):
+        return distribute_tensor(
+            tensor.to_local().clone(), tensor.device_mesh, tensor.placements
+        )
+    else:
+        return tensor.clone()

 def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
     out = []
@@ -278,7 +286,7 @@ def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
                 else:
                     out.append(v.cpu())
             else:
-                out.append(v)
+                out.append(_clone_cpu_tensor(v))
         else:
             out.append(v)
     return out
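
Why the clone matters: on the already-on-CPU branch, _to_cpu previously appended the tensor itself, so the value queued for transfer aliased the live parameter; _clone_cpu_tensor snapshots it instead, cloning a DTensor's local shard and redistributing it over the same mesh and placements. The snippet below is a standalone illustration of the aliasing race the clone closes (plain torch, not torchft code):

import torch

# Standalone illustration (not torchft code) of the race this commit closes.
# A CPU tensor appended without .clone() aliases the live parameter, so an
# in-place update from the next training step rewrites the pending snapshot.
param = torch.ones(3)           # already on CPU: the old code appended it as-is
snapshot_alias = param          # old behavior: out.append(v)
snapshot_clone = param.clone()  # new behavior: out.append(_clone_cpu_tensor(v))

param.add_(1.0)                 # the replica moves on to the next step

print(snapshot_alias)  # tensor([2., 2., 2.]) -- the in-flight checkpoint changed
print(snapshot_clone)  # tensor([1., 1., 1.]) -- a stable copy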

torchft/checkpointing/pg_transport.py

Lines changed: 3 additions & 0 deletions

@@ -194,6 +194,9 @@ def metadata(self) -> str:
     def disallow_checkpoint(self) -> None:
         pass

+    def allow_checkpoint(self) -> None:
+        pass
+
     def send_checkpoint(
         self, dst_ranks: list[int], step: int, state_dict: T, timeout: timedelta
     ) -> None:
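
Note: allow_checkpoint is added here as a no-op, mirroring the pre-existing no-op disallow_checkpoint above it, which keeps PGTransport conforming to the expanded CheckpointTransport interface shown in the next file.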

torchft/checkpointing/transport.py

Lines changed: 6 additions & 0 deletions

@@ -44,6 +44,12 @@ def disallow_checkpoint(self) -> None:
         """
         ...

+    def allow_checkpoint(self) -> None:
+        """
+        Called when the checkpoint is allowed to be sent again, to make sure access to the state_dict is safe.
+        """
+        ...
+
     @abstractmethod
     def recv_checkpoint(
         self, src_rank: int, metadata: str, step: int, timeout: timedelta
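
The interface only documents the contract; how the disallow/allow pair is enforced is up to each transport (http_transport already imports an RWLock, per the first file; pg_transport makes both calls no-ops). Below is a minimal sketch, using a hypothetical GuardedTransport built on threading.Event rather than torchft's actual locking, of how a transport could honor the pair:

import threading

# Hypothetical sketch, not the torchft implementation: block senders while
# the optimizer mutates the state dict, wake them once it is safe again.
class GuardedTransport:
    def __init__(self) -> None:
        self._allowed = threading.Event()
        self._allowed.set()  # sending is allowed until a step begins

    def disallow_checkpoint(self) -> None:
        self._allowed.clear()  # called from the optimizer step pre-hook

    def allow_checkpoint(self) -> None:
        self._allowed.set()  # called from the optimizer step post-hook

    def send_checkpoint(self, state_dict: dict) -> None:  # simplified signature
        self._allowed.wait()  # never read the state dict mid-step
        # ... serialize and transmit state_dict ...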

torchft/local_sgd.py

Lines changed: 9 additions & 0 deletions

@@ -85,6 +85,9 @@ def __init__(
         self._hooks: List[RemovableHandle] = []

     def __enter__(self) -> "LocalSGD":
+        self._hooks.append(
+            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
+        )
         # Add optimizer hook which increments the local step counter and syncs if necessary
         self._hooks.append(
             self._local_optimizer.register_step_post_hook(self._step_post_hook)
@@ -105,12 +108,18 @@ def __exit__(

         return False  # Propagate exceptions

+    def _step_pre_hook(self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]) -> None:
+        # The checkpoint may transfer model parameters, so we need to make access to them thread safe
+        self._manager._checkpoint_transport.disallow_checkpoint()
+
     def _step_post_hook(
         self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
     ) -> None:
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
+        self._manager._checkpoint_transport.allow_checkpoint()
+
         self._local_step += 1
         if self._local_step >= self._sync_every:
             self.sync()
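
Taken together, every optimizer.step() is now bracketed: the pre-hook calls disallow_checkpoint() before the parameters change and the post-hook calls allow_checkpoint() once the update is complete, so a concurrent checkpoint send never reads a half-updated state dict. A self-contained demo of the PyTorch hook ordering this relies on (plain SGD; an event list stands in for the transport calls):

import torch
from torch import nn, optim

# Demo of the hook ordering the patch relies on: the pre-hook fires before
# each optimizer.step() and the post-hook after it, which is what lets
# LocalSGD bracket the parameter update with disallow/allow_checkpoint.
model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.1)

events = []
opt.register_step_pre_hook(lambda o, args, kwargs: events.append("disallow_checkpoint"))
opt.register_step_post_hook(lambda o, args, kwargs: events.append("allow_checkpoint"))

model(torch.randn(1, 4)).sum().backward()
opt.step()
print(events)  # ['disallow_checkpoint', 'allow_checkpoint']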
