
Commit 495ab9a

make checkpointing thread safe

Summary: the checkpointing wasn't thread safe for the HTTP transport, so we lock the model in the pre-step hook and when we want to transfer the checkpoint.

1 parent 22b8fa1 commit 495ab9a
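
For context, the locking protocol this commit introduces is roughly: the optimizer's pre-step hook takes the write lock so a checkpoint transfer can never observe half-updated parameters, the post-step hook releases it, and the HTTP transport serializes the state dict under the read lock. A minimal sketch of that protocol is below; the function names (step_pre_hook, step_post_hook, transfer_checkpoint) are illustrative, not torchft's, while RWLock's surface (w_acquire/w_release/r_lock) is taken from the diff:

    # Sketch only -- not torchft source.
    from torchft.checkpointing._rwlock import RWLock

    lock = RWLock(timeout=60.0)

    def step_pre_hook(optim, args, kwargs) -> None:
        # Writer side: the optimizer step is about to mutate parameters,
        # so block checkpoint readers until the step completes.
        lock.w_acquire()

    def step_post_hook(optim, args, kwargs) -> None:
        # Parameters are consistent again; let checkpoint readers proceed.
        lock.w_release()

    def transfer_checkpoint(state_dict_fn):
        # Reader side: serialize the state dict under the read lock,
        # so the transfer never sees a half-updated model.
        with lock.r_lock():
            return state_dict_fn()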

File tree

4 files changed: +172 -5 lines changed

torchft/local_sgd.py

Lines changed: 22 additions & 0 deletions

@@ -86,6 +86,9 @@ def __init__(
         self._hooks: List[RemovableHandle] = []

     def __enter__(self) -> "LocalSGD":
+        self._hooks.append(
+            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
+        )
         # Add optimizer hook which increments the local step counter and syncs if necessary
         self._hooks.append(
             self._local_optimizer.register_step_post_hook(self._step_post_hook)
@@ -106,12 +109,20 @@ def __exit__(

         return False  # Propagate exceptions

+    def _step_pre_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        # The checkpoint may transfer model parameters, so we need to make access to it thread safe
+        self._manager.disallow_state_dict_read()
+
     def _step_post_hook(
         self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
     ) -> None:
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
+        self._manager.allow_state_dict_read()
+
         self._local_step += 1
         if self._local_step >= self._sync_every:
             self.sync()
@@ -682,12 +693,21 @@ def _restore_parameters(self) -> None:
             fragment.restore_parameters()

     def __enter__(self) -> "DiLoCo":
+        self._hooks.append(
+            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
+        )
         # Add optimizer hook which increments the local step counter and syncs if necessary
         self._hooks.append(
             self._local_optimizer.register_step_post_hook(self._step_post_hook)
         )
         return self

+    def _step_pre_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        # The checkpoint may transfer model parameters, so we need to make access to it thread safe
+        self._manager.disallow_state_dict_read()
+
     def __exit__(
         self,
         exc_type: Optional[Type[BaseException]],
@@ -722,6 +742,8 @@ def _step_post_hook(
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
+        self._manager.allow_state_dict_read()
+
         # We need to make sure all nodes send the same fragments in order.
         # This is to avoid deadlocking e.g.
         #
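
The pre/post hooks above use the stock PyTorch optimizer hook API (Optimizer.register_step_pre_hook / register_step_post_hook). A self-contained sketch of how such a pair brackets every step() call, with print statements standing in for the lock calls:

    import torch
    from torch import nn, optim

    model = nn.Linear(4, 2)
    opt = optim.SGD(model.parameters(), lr=0.1)

    # Hooks receive (optimizer, args, kwargs), matching the signatures above.
    pre = opt.register_step_pre_hook(lambda o, a, k: print("pre: disallow reads"))
    post = opt.register_step_post_hook(lambda o, a, k: print("post: allow reads"))

    model(torch.randn(1, 4)).sum().backward()
    opt.step()  # prints "pre: disallow reads" then "post: allow reads"

    # Both calls return a RemovableHandle, which is what __exit__ uses
    # to tear the hooks down again.
    pre.remove()
    post.remove()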

torchft/local_sgd_integ_test.py

Lines changed: 8 additions & 0 deletions

@@ -36,6 +36,7 @@
 )

 logger: logging.Logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)


 def local_sgd_train_loop(
@@ -143,6 +144,7 @@ def assert_equal_global_state(
                 rep1[step]["user"][f"StreamingDiLoCoFragment_{i}"],
                 rep0[step]["user"][f"StreamingDiLoCoFragment_{i}"],
                 check_device=False,
+                msg=f"{step=} {i=}",
             )
         # Check all outer optimizers
         for i in range(
@@ -574,3 +576,9 @@ def test_streaming_diloco_commit_failure(
         self.assertEqual(
             event_injector.count[EventInjectorEvent.AllreduceFailure], 1
         )
+
+
+if __name__ == "__main__":
+    import unittest
+
+    unittest.main()
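
Aside: the new msg argument uses f-string debug syntax (Python 3.8+), which renders both the expression text and its value in the assertion failure message:

    step, i = 3, 1
    print(f"{step=} {i=}")  # prints: step=3 i=1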

torchft/manager.py

Lines changed: 27 additions & 5 deletions

@@ -55,6 +55,7 @@

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
+from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout
 from torchft.work import _DummyWork

@@ -216,6 +217,9 @@ def __init__(
         self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
         self._user_state_dicts: Dict[str, Callable[[], object]] = {}

+        # Protects state dict
+        self._state_dict_lock = RWLock(timeout=timeout.total_seconds())
+
         if load_state_dict and state_dict:
             self.register_state_dict_fn("default", load_state_dict, state_dict)

@@ -324,6 +328,21 @@ def __init__(
         # first step is 1
         self._participating_replica_rank: Optional[int] = None
         self._participating_replica_world_size: int = 0
+        self._is_state_dict_read_allowed = True
+
+    def allow_state_dict_read(self) -> None:
+        if self._is_state_dict_read_allowed:
+            return
+
+        self._is_state_dict_read_allowed = True
+        self._state_dict_lock.w_release()
+
+    def disallow_state_dict_read(self) -> None:
+        if not self._is_state_dict_read_allowed:
+            return
+
+        self._is_state_dict_read_allowed = False
+        self._state_dict_lock.w_acquire()

     def register_state_dict_fn(
         self,
@@ -806,11 +825,14 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
         self._batches_committed = state_dict["batches_committed"]

     def _manager_state_dict(self) -> Dict[str, object]:
-        assert len(self._user_state_dicts) > 0, "user state_dict is not initialized."
-        return {
-            "user": {key: value() for key, value in self._user_state_dicts.items()},
-            "torchft": self.state_dict(),
-        }
+        with self._state_dict_lock.r_lock():
+            assert (
+                len(self._user_state_dicts) > 0
+            ), "user state_dict is not initialized."
+            return {
+                "user": {key: value() for key, value in self._user_state_dicts.items()},
+                "torchft": self.state_dict(),
+            }

     def state_dict(self) -> Dict[str, int]:
         """

torchft/manager_test.py

Lines changed: 115 additions & 0 deletions

@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.

 import concurrent
+import threading
+import time
 from datetime import timedelta
 from typing import Optional
 from unittest import TestCase
@@ -14,6 +16,7 @@
 from torch.distributed import TCPStore

 from torchft._torchft import QuorumResult
+from torchft.checkpointing._rwlock import RWLock
 from torchft.checkpointing.transport import CheckpointTransport
 from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
 from torchft.process_group import ProcessGroup
@@ -778,3 +781,115 @@ def test_max_retries(self, client_mock: MagicMock) -> None:
         # This should succeed and reset the counter
         self.assertTrue(manager.should_commit())
         self.assertEqual(manager._commit_failures, 0)
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_state_dict_lock_allow_disallow(self, client_mock: MagicMock) -> None:
+        """Test that allow_state_dict_read and disallow_state_dict_read methods work correctly."""
+        manager = self._create_manager()
+
+        # Initially, state dict read should be allowed
+        self.assertTrue(manager._is_state_dict_read_allowed)
+
+        # Test disallow_state_dict_read
+        manager.disallow_state_dict_read()
+        self.assertFalse(manager._is_state_dict_read_allowed)
+        self.assertTrue(manager._state_dict_lock.w_locked())
+
+        # Calling disallow_state_dict_read again should be a no-op
+        manager.disallow_state_dict_read()
+        self.assertFalse(manager._is_state_dict_read_allowed)
+        self.assertTrue(manager._state_dict_lock.w_locked())
+
+        # Test allow_state_dict_read
+        manager.allow_state_dict_read()
+        self.assertTrue(manager._is_state_dict_read_allowed)
+        self.assertFalse(manager._state_dict_lock.w_locked())
+
+        # Calling allow_state_dict_read again should be a no-op
+        manager.allow_state_dict_read()
+        self.assertTrue(manager._is_state_dict_read_allowed)
+        self.assertFalse(manager._state_dict_lock.w_locked())
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_state_dict_lock_concurrent_access(self, client_mock: MagicMock) -> None:
+        """Test that _state_dict_lock properly protects concurrent access to the state dictionary."""
+        manager: Manager = self._create_manager()
+
+        # Create flags for thread synchronization
+        access_attempted: threading.Event = threading.Event()
+        can_proceed: threading.Event = threading.Event()
+        access_result: dict[str, bool] = {"succeeded": False}
+
+        def try_access_state_dict() -> None:
+            # Wait until the main thread signals it's ready
+            nonlocal access_attempted, can_proceed, access_result, manager
+            access_attempted.set()
+            can_proceed.wait(timeout=1.0)
+
+            # Try to access the state dict
+            if manager._is_state_dict_read_allowed:
+                access_result["succeeded"] = True
+
+        # Start a thread that will try to access the state dict
+        thread = threading.Thread(target=try_access_state_dict)
+        thread.daemon = True
+        thread.start()
+
+        # Disallow state dict read
+        manager.disallow_state_dict_read()
+        self.assertFalse(manager._is_state_dict_read_allowed)
+
+        # Wait for the thread to be ready
+        access_attempted.wait(timeout=1.0)
+
+        # Signal the thread to proceed while state dict read is disallowed
+        can_proceed.set()
+        thread.join(timeout=1.0)
+
+        # The thread should not have been able to access the state dict
+        self.assertFalse(access_result["succeeded"])
+
+        # Reset for the second part of the test
+        access_attempted.clear()
+        can_proceed.clear()
+
+        # Start another thread
+        thread = threading.Thread(target=try_access_state_dict)
+        thread.daemon = True
+        thread.start()
+
+        # Allow state dict read
+        manager.allow_state_dict_read()
+        self.assertTrue(manager._is_state_dict_read_allowed)
+
+        # Wait for the thread to be ready
+        access_attempted.wait(timeout=1.0)
+
+        # Signal the thread to proceed while state dict read is allowed
+        can_proceed.set()
+        thread.join(timeout=1.0)
+
+        # The thread should now have been able to access the state dict
+        self.assertTrue(access_result["succeeded"])
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_manager_state_dict_with_lock(self, client_mock: MagicMock) -> None:
+        """Test that _manager_state_dict properly uses the read lock."""
+        manager = self._create_manager()
+
+        # Replace the real RWLock with a mock to track lock acquisition
+        original_lock = manager._state_dict_lock
+        mock_lock = create_autospec(RWLock)
+        mock_context = MagicMock()
+        mock_lock.r_lock.return_value.__enter__ = lambda _: mock_context
+        mock_lock.r_lock.return_value.__exit__ = lambda *args: None
+        manager._state_dict_lock = mock_lock
+
+        # Call _manager_state_dict
+        result = manager._manager_state_dict()
+
+        # Verify that r_lock was called
+        mock_lock.r_lock.assert_called_once()
+
+        # Restore the original lock
+        manager._state_dict_lock = original_lock
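
One note on the mocking pattern in test_manager_state_dict_with_lock: the test assigns __enter__/__exit__ on r_lock()'s return value by hand so the with statement works against the autospecced mock. For comparison, a bare MagicMock supports the context manager protocol out of the box, as in this small sketch:

    from unittest.mock import MagicMock

    mock_lock = MagicMock()
    with mock_lock.r_lock():  # __enter__/__exit__ are auto-configured
        pass
    mock_lock.r_lock.assert_called_once()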
