**Summary**

This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't work at the moment: groups hang when a new group joins.

**Issue 1:** ~~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~~ Fixed with pytorch/torchft#83.

**Issue 2:** ~~Group 0 and group 1 pass `should_commit`, but group 0 is (incorrectly) marked as needing healing, and the healing process causes another hang.~~ Fixed with pytorch/torchft#83.

**Issue 3:** ~~A byproduct of issues 1 and 2: group 1 keeps printing~~

```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:** When there are 3 groups, every group requests the state dict on every step.

***How to reproduce?*** Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command. Seems to be fixed; will need more tests.

**Issue 5:** A hang occurs when using functional collectives.

***How to reproduce?*** Pull the latest version of this PR, comment out line 41, and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Start the lighthouse
3. Run the following command in one terminal:
   ```
   TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
   ```
4. Wait 10 seconds, then run the following command in another terminal:
   ```
   TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
   ```

ghstack-source-id: 088581cceee2c523f2a4ea358f334a0b1cce3927
Pull Request resolved: #834
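For context, here is a minimal sketch of the per-step quorum/commit cycle that the issues above refer to, assuming torchft's `Manager.start_quorum()` / `Manager.should_commit()` API; `model`, `optimizer`, `loss_fn`, `batch`, and `labels` are placeholders, and this is not the actual torchtitan training loop:

```python
def train_step(manager, model, optimizer, loss_fn, batch, labels):
    # Join the quorum for this step; a replica group that just joined
    # heals here by fetching the pending state_dict from a live peer
    # (the path where Issues 1, 2, and 4 above were observed).
    manager.start_quorum()

    optimizer.zero_grad()
    loss = loss_fn(model(batch), labels)
    loss.backward()
    # Gradient averaging across replica groups is expected to go through
    # the manager's fault-tolerant process group (omitted here).

    # Only apply the update if the whole quorum agrees the step is valid;
    # Issue 1 was a hang on the first such call after a group healed.
    if manager.should_commit():
        optimizer.step()
```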
8 changed files with 269 additions and 63 deletions.
New file added in this commit (58 lines):

```python
import importlib.util
from typing import Any, Callable, Optional

from torch.distributed._state_dict_utils import (
    _copy_state_dict,
    _create_cpu_state_dict,
)

from torchtitan.config_manager import JobConfig

# torchft is an optional dependency; only import it if it is installed.
if importlib.util.find_spec("torchft") is not None:
    import torchft as ft

    has_torchft = True
else:
    has_torchft = False


def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]:
    """Initialize the TorchFT manager for the given job.

    Returns None when TorchFT is disabled in the job config.
    """
    if not job.experimental.enable_torchft:
        return None

    if not has_torchft:
        raise ImportError("torchft is not installed. Please install it.")

    # Each replica group gets its own fault-tolerant NCCL process group and a
    # unique replica_id derived from the configured group id.
    pg = ft.ProcessGroupBabyNCCL()
    manager = ft.Manager(
        pg=pg,
        min_replica_size=1,
        load_state_dict=None,
        state_dict=None,
        use_async_quorum=True,
        replica_id=f"torchtitan_ft_{job.experimental.ft_replica_group_id}",
    )

    return manager


def set_ft_state_dict_fns(manager: Optional["ft.Manager"], ckpt_manager) -> None:
    """Register the state_dict/load_state_dict functions on the given manager."""
    if manager is None:
        return

    def state_dict():
        # Only the states needed to recover training are exchanged between
        # replica groups.
        ret = {}
        for k, v in ckpt_manager.staging_results().items():
            if k in {"model", "optimizer", "lr_schedulers"}:
                ret[k] = v
        return ret

    def load_state_dict(state_dict):
        assert state_dict is not None
        for k, v in state_dict.items():
            ckpt_manager.states[k].load_state_dict(v)

    manager.set_state_dict_fns(load_state_dict, state_dict)
```
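A hypothetical wiring of the two helpers above into training setup (the actual call sites are in the other changed files of this commit, which are not shown here); `job_config` and `checkpoint_manager` stand in for the usual torchtitan objects:

```python
# Build the TorchFT manager only when --experimental.enable_torchft is set;
# init_ft_manager returns None otherwise.
ft_manager = init_ft_manager(job_config)

if ft_manager is not None:
    # Register how TorchFT should snapshot and restore training state so a
    # recovering replica group can heal from a live peer's state_dict.
    set_ft_state_dict_fns(ft_manager, checkpoint_manager)
```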