Commit 1f5854d

make checkpointing thread safe and deterministic
Summary:
- The regression tests fail (on future changes) because they expect either no recovery to happen, or for recovery to happen at the first step.
- Since we validate the parameters at every step, we can't reliably validate them if recovery happens non-deterministically.
- To fix this, copy the state dict before transferring it.
- Checkpointing also wasn't thread safe for the HTTP transport, so lock the model in the pre-step hook and when we want to transfer the checkpoint.
1 parent d358fb4 commit 1f5854d
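To see why both fixes are needed, consider the failure in miniature: the HTTP transport serializes the state dict on a background thread while the training thread keeps stepping the optimizer in place. Without a lock the snapshot can be torn, and without a copy even a clean snapshot aliases live parameters. A minimal sketch of the idea using plain threading primitives (hypothetical names, not the torchft code, and a plain mutex standing in for torchft's reader-writer lock):

import threading

import torch

lock = threading.Lock()  # torchft uses a reader-writer lock; a plain lock shows the idea
param = torch.zeros(4)   # stands in for a live model parameter

def optimizer_step() -> None:
    with lock:  # taken by the step pre-hook, released by the post-hook
        param.add_(1.0)  # in-place parameter update

def transfer_checkpoint() -> torch.Tensor:
    with lock:  # taken while reading the state dict for transfer
        return param.clone()  # copy so later steps can't mutate the snapshot

t = threading.Thread(target=optimizer_step)
t.start()
snapshot = transfer_checkpoint()  # sees the parameters entirely before or after the step
t.join()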

File tree

4 files changed: +68 -6 lines changed

torchft/checkpointing/http_transport.py
torchft/local_sgd.py
torchft/local_sgd_integ_test.py
torchft/manager.py

torchft/checkpointing/http_transport.py

Lines changed: 11 additions & 1 deletion
@@ -16,6 +16,7 @@
 from typing import Generator, List, Optional, TypeVar, cast

 import torch
+from torch.distributed.tensor import DTensor, distribute_tensor
 from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten

 from torchft.checkpointing._rwlock import RWLock
@@ -266,6 +267,15 @@ def recv_checkpoint(
     return tree_unflatten(values, spec)


+def _clone_cpu_tensor(tensor: torch.Tensor) -> torch.Tensor:
+    if isinstance(tensor, DTensor):
+        return distribute_tensor(
+            tensor.to_local().clone(), tensor.device_mesh, tensor.placements
+        )
+    else:
+        return tensor.clone()
+
+
 def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
     out = []
     for v in values:
@@ -278,7 +288,7 @@ def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
                 else:
                     out.append(v.cpu())
             else:
-                out.append(v)
+                out.append(_clone_cpu_tensor(v))
         else:
             out.append(v)
     return out
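The `_to_cpu` change matters because a tensor that is already on CPU used to be appended by reference: the serialized checkpoint aliased the live parameter, so a later in-place optimizer step could change the checkpoint after it was "taken", making recovery non-deterministic. A small illustration of the aliasing (plain tensors; for a `DTensor`, `_clone_cpu_tensor` additionally clones the local shard and redistributes it):

import torch

param = torch.ones(3)   # a CPU parameter included in the state dict

aliased = param         # old behavior: no copy, shares storage
cloned = param.clone()  # new behavior: independent storage

param.add_(1.0)         # a later in-place optimizer step

print(aliased)  # tensor([2., 2., 2.]) -- the "checkpoint" changed underneath us
print(cloned)   # tensor([1., 1., 1.]) -- a stable snapshot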

torchft/local_sgd.py

Lines changed: 22 additions & 0 deletions
@@ -86,6 +86,9 @@ def __init__(
         self._hooks: List[RemovableHandle] = []

     def __enter__(self) -> "LocalSGD":
+        self._hooks.append(
+            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
+        )
         # Add optimizer hook which increments the local step counter and syncs if necessary
         self._hooks.append(
             self._local_optimizer.register_step_post_hook(self._step_post_hook)
@@ -106,12 +109,20 @@ def __exit__(

         return False  # Propagate exceptions

+    def _step_pre_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        # The checkpoint may transfer model parameters, so we need to make access to it thread safe
+        self._manager.disallow_state_dict_read()
+
     def _step_post_hook(
         self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
     ) -> None:
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
+        self._manager.allow_state_dict_read()
+
         self._local_step += 1
         if self._local_step >= self._sync_every:
             self.sync()
@@ -677,12 +688,21 @@ def _restore_parameters(self) -> None:
         fragment.restore_parameters()

     def __enter__(self) -> "DiLoCo":
+        self._hooks.append(
+            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
+        )
         # Add optimizer hook which increments the local step counter and syncs if necessary
         self._hooks.append(
             self._local_optimizer.register_step_post_hook(self._step_post_hook)
         )
         return self

+    def _step_pre_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        # The checkpoint may transfer model parameters, so we need to make access to it thread safe
+        self._manager.disallow_state_dict_read()
+
     def __exit__(
         self,
         exc_type: Optional[Type[BaseException]],
@@ -717,6 +737,8 @@ def _step_post_hook(
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
+        self._manager.allow_state_dict_read()
+
         # We need to make sure all nodes send the same fragments in order.
         # This is to avoid deadlocking e.g.
         #
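`register_step_pre_hook` and `register_step_post_hook` are standard `torch.optim.Optimizer` APIs: the pre-hook fires before each `optimizer.step()` and the post-hook after it, which is what lets the write lock bracket the in-place parameter update exactly. Both return a `RemovableHandle`, which is why the handles are collected in `self._hooks` and removed on `__exit__`. A standalone sketch of the bracketing (toy model and print statements in place of the real lock):

import torch
from torch import nn, optim

model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.1)

def pre_hook(optimizer, args, kwargs) -> None:
    print("before step: take the write lock here")

def post_hook(optimizer, args, kwargs) -> None:
    print("after step: release the write lock here")

handles = [
    opt.register_step_pre_hook(pre_hook),
    opt.register_step_post_hook(post_hook),
]

model(torch.randn(1, 4)).sum().backward()
opt.step()  # prints both messages, bracketing the in-place update

for h in handles:
    h.remove()  # mirrors what __exit__ does with self._hooks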

torchft/local_sgd_integ_test.py

Lines changed: 8 additions & 0 deletions
@@ -36,6 +36,7 @@
 )

 logger: logging.Logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)


 def local_sgd_train_loop(
@@ -143,6 +144,7 @@ def assert_equal_global_state(
                 rep1[step]["user"][f"StreamingDiLoCoFragment_{i}"],
                 rep0[step]["user"][f"StreamingDiLoCoFragment_{i}"],
                 check_device=False,
+                msg=f"{step=} {i=}",
             )
         # Check all outer optimizers
         for i in range(
@@ -574,3 +576,9 @@ def test_streaming_diloco_commit_failure(
         self.assertEqual(
             event_injector.count[EventInjectorEvent.AllreduceFailure], 1
         )
+
+
+if __name__ == "__main__":
+    import unittest
+
+    unittest.main()
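The added `msg=f"{step=} {i=}"` uses the f-string `=` specifier to label any mismatch with the step and fragment index that diverged; `torch.testing.assert_close` accepts a string `msg` and uses it as the failure message. For example:

import torch

step, i = 3, 1
try:
    torch.testing.assert_close(
        torch.tensor([1.0]), torch.tensor([2.0]), msg=f"{step=} {i=}"
    )
except AssertionError as err:
    print(err)  # step=3 i=1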

torchft/manager.py

Lines changed: 27 additions & 5 deletions
@@ -55,6 +55,7 @@

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
+from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout
 from torchft.work import _DummyWork

@@ -216,6 +217,9 @@ def __init__(
         self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
         self._user_state_dicts: Dict[str, Callable[[], object]] = {}

+        # Protects state dict
+        self._state_dict_lock = RWLock(timeout=timeout.total_seconds())
+
         if load_state_dict and state_dict:
             self.register_state_dict_fn("default", load_state_dict, state_dict)

@@ -324,6 +328,21 @@ def __init__(
         # first step is 1
         self._participating_replica_rank: Optional[int] = None
         self._participating_replica_world_size: int = 0
+        self._is_state_dict_read_allowed = True
+
+    def allow_state_dict_read(self) -> None:
+        if self._is_state_dict_read_allowed:
+            return
+
+        self._is_state_dict_read_allowed = True
+        self._state_dict_lock.w_release()
+
+    def disallow_state_dict_read(self) -> None:
+        if not self._is_state_dict_read_allowed:
+            return
+
+        self._is_state_dict_read_allowed = False
+        self._state_dict_lock.w_acquire()

     def register_state_dict_fn(
         self,
@@ -806,11 +825,14 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
         self._batches_committed = state_dict["batches_committed"]

     def _manager_state_dict(self) -> Dict[str, object]:
-        assert len(self._user_state_dicts) > 0, "user state_dict is not initialized."
-        return {
-            "user": {key: value() for key, value in self._user_state_dicts.items()},
-            "torchft": self.state_dict(),
-        }
+        with self._state_dict_lock.r_lock():
+            assert (
+                len(self._user_state_dicts) > 0
+            ), "user state_dict is not initialized."
+            return {
+                "user": {key: value() for key, value in self._user_state_dicts.items()},
+                "torchft": self.state_dict(),
+            }

     def state_dict(self) -> Dict[str, int]:
         """
