Integrate TorchFT
**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**

This does not fully work yet: groups hang when a new group joins.

**Issue 1:**
~~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~~

Fixed with: pytorch/torchft#83

**Issue 2:**
~~Group 0 and group 1 pass `should_commit`, but group 0 is incorrectly marked as needing healing, and the healing process causes another hang.~~

Fixed with: pytorch/torchft#83

**Issue 3:**
~~A byproduct of issues 1 and 2: group 1 keeps printing~~
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN
socketProgress: Connection closed by remote peer
devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other
fixes.

**Issue 4:**
When there are 3 groups, every group requests the state dict on every step.
***How to reproduce?***
Use the `Reproduce steps` below to run 2 groups, then add a third group by
modifying the command.

This seems to be fixed but needs more testing.

**Issue 5:**
A hang occurs when using functional collectives.
***How to reproduce?***
Pull the latest version of this PR, then comment out line 41 and uncomment
line 42 in `torchtitan/utils.py`.
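
For context, "functional collective" refers to the traceable collectives in `torch.distributed._functional_collectives`, which return a new tensor instead of mutating their input in place. The sketch below is a generic, hedged illustration of the two styles; it is not the actual code at those lines in `torchtitan/utils.py`, and the function names are made up for illustration.
```
# Generic illustration of eager vs. functional collectives; NOT the exact
# code in torchtitan/utils.py.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def global_sum_inplace(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Classic c10d collective: mutates `x` in place.
    dist.all_reduce(x, op=dist.ReduceOp.SUM, group=group)
    return x


def global_sum_functional(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Functional collective: returns a new tensor (an AsyncCollectiveTensor)
    # that is synchronized lazily when first used.
    return funcol.all_reduce(x, reduceOp="sum", group=group)
```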

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Start the TorchFT lighthouse (the launch script expects it at `http://localhost:29510`)
3. Execute the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 \
NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 \
--experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4. Wait 10 seconds, then execute the following command in another terminal:
```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 \
NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 \
--experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 088581cceee2c523f2a4ea358f334a0b1cce3927
Pull Request resolved: #834
fegin committed Feb 12, 2025
1 parent 58ab4d1 commit b5e09d0
Showing 8 changed files with 269 additions and 63 deletions.
4 changes: 4 additions & 0 deletions run_llama_train.sh
@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
overrides="$*"
fi

TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}

PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
TORCHFT_LIGHTHOUSE=http://localhost:29510 \
TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $overrides
167 changes: 117 additions & 50 deletions torchtitan/checkpoint.py
@@ -19,6 +19,7 @@
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
set_model_state_dict,
@@ -144,13 +145,18 @@ def __init__(
lr_schedulers: LRSchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
ft_manager: Optional[Any] = None,
) -> None:
ckpt_config = job_config.checkpoint
self.enable_checkpoint = ckpt_config.enable_checkpoint
self.keep_latest_k = ckpt_config.keep_latest_k
self.ft_manager = ft_manager
self.enable_staging = (
self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
) or self.ft_manager

if not self.enable_checkpoint:
if not self.enable_checkpoint and self.ft_manager is None:
return

"""
Note: Pipeline Parallelism and Virtual Stages
@@ -185,6 +191,13 @@ def __init__(
}
)

async_mode = ckpt_config.async_mode.lower()
self.staging = False
self.sending_to_checkpoint_mp = False
self.staging_id = None
self.cpu_offload_state_dict = None
self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval_type = (
IntervalType.SECONDS
@@ -199,6 +212,7 @@ def __init__(
if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS:
self.pg = dist.new_group(backend="gloo")

self.keep_latest_k = ckpt_config.keep_latest_k
self.model_weights_only = ckpt_config.model_weights_only
self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
self.exclude_from_loading = ckpt_config.exclude_from_loading
@@ -223,10 +237,6 @@ def __init__(
daemon=True,
)
self.mp.start()
self.cpu_offload_state_dict = None
self.staging = False
self.staging_id = None
self.staging_stream = torch.cuda.Stream()
else:
raise ValueError(f"Unknown checkpoint async_mode {ckpt_config.async_mode}")

@@ -240,8 +250,61 @@ def __del__(self):
self.mp.join()

def reset(self) -> None:
# We need to stage the local state if another replica joins during the
# first step.
if self.ft_manager:
self.cpu_staging(None)
self.begin_time = time.monotonic()

def _initialize_states(
self,
states: Dict[str, Any],
dataloader: DataLoader,
model_parts: List[nn.Module],
optimizers: OptimizersContainer,
lr_schedulers: LRSchedulersContainer,
) -> None:
"""
Note: Pipeline Parallelism and Virtual Stages
1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
rank0's optimizer would have a param_group[0] which refers to layers.0 in the
original model. rank1's would _also_ have a param_group[0], since it's index based,
but referring to layers.1.
When saving, these collide and one of them is lost. Then when reloading, only one
stage can restore its optimizer states, others will error.
The solution to this problem is optimizer flattening: it landed in #127071
and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
kwarg to DCP functions called in the OptimizerContainer.
2. With complex PP schedules, we have multiple model chunks per pp rank. This
compounds challenge (1) by also requiring us to reason about multiple 'optim'
objects locally.
We solve this in the Model and Optimizer wrapper classes by flattening the
state dicts from each object into one state dict before saving/loading.
We rely on the individual state_dicts to not collide, which is guaranteed for
the model by correct pipeline splitting and for the optimizer by the flattening
support described in (1).
3. LR schedulers also index model states like optimizers and would need to be
flattened properly to support resharding. Unfortunately, the implementations of
different lr_schedulers do not follow a clear pattern like optimizers do, so it's
hard to write a generic 'flattener' utility.
TODO: This is currently unsolved and needs a fix.
"""
self.states = states
self.states.update(
{
"model": ModelWrapper(model_parts),
"optimizer": optimizers,
"dataloader": dataloader,
"lr_scheduler": lr_schedulers,
}
)

def _create_checkpoint_id(self, step: int) -> str:
return os.path.join(self.folder, f"step-{step}")

@@ -324,31 +387,8 @@ def _async_wait(self) -> None:
self.async_future.result()

def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
try:
from torch.distributed._state_dict_utils import (
_copy_state_dict,
_create_cpu_state_dict,
)
except ImportError as e:
raise ImportError(
"Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
) from e
state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
if self.cpu_offload_state_dict is None:
logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
self.cpu_offload_state_dict = _create_cpu_state_dict(
state_dict, pin_memory=True, share_memory=True
)

logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
with torch.cuda.stream(self.staging_stream):
self.cpu_offload_state_dict = _copy_state_dict(
state_dict,
self.cpu_offload_state_dict,
non_blocking=True,
)
self.staging = True
self.staging_id = checkpoint_id
self.cpu_staging(checkpoint_id)
self.sending_to_checkpoint_mp = True

def save(self, curr_step: int, force: bool = False) -> None:
"""
@@ -358,6 +398,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
for initial seed checkpoint.
"""
if not self._should_save(curr_step, force):
if self.ft_manager:
self.cpu_staging(None)
return

begin = time.monotonic()
@@ -381,26 +423,51 @@ def save(self, curr_step: int, force: bool = False) -> None:
f"in {time.monotonic() - begin:.2f} seconds."
)

def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
"""Offload state_dict to CPU memory"""
state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
if self.cpu_offload_state_dict is None:
logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
self.cpu_offload_state_dict = _create_cpu_state_dict(
state_dict, pin_memory=True, share_memory=True
)

logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
with torch.cuda.stream(self.staging_stream):
self.cpu_offload_state_dict = _copy_state_dict(
state_dict,
self.cpu_offload_state_dict,
non_blocking=True,
)
self.staging = True
self.staging_id = checkpoint_id

def wait_for_staging(self) -> None:
if not self.staging_stream.query():
self.staging_stream.synchronize()
self.staging = False

def staging_results(self) -> Dict[str, Any]:
self.maybe_wait_for_staging()
return self.cpu_offload_state_dict

def maybe_wait_for_staging(self) -> None:
if (
self.enable_checkpoint
and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
and self.staging
):
if not self.staging_stream.query():
self.staging_stream.synchronize()

def sync_func():
self.mp_queue_send.put_nowait(
(self.cpu_offload_state_dict, self.staging_id)
)

# This may be a faster way to do zero-overhead checkpoint staging, but
# we need more thorough investigation before switching to this method.
# self.my_thread = threading.Thread(target=func).start()
sync_func()
self.staging = False
if self.enable_staging and self.staging:
self.wait_for_staging()

if self.sending_to_checkpoint_mp:
# Copy the sync staging result to another process.
def sync_func():
self.mp_queue_send.put_nowait(
(self.cpu_offload_state_dict, self.staging_id)
)

# This may be a faster way to do zero-overhead checkpoint staging, but
# we need more thorough investigation before switching to this method.
# self.my_thread = threading.Thread(target=func).start()
sync_func()
self.sending_to_checkpoint_mp = False

def load(self, step: int = -1) -> bool:
if not self.enable_checkpoint:
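
The staging path above is the core of the FT integration: even when no on-disk checkpoint is due, the state dict is copied into pinned, shareable CPU memory so a joining replica can fetch it through TorchFT. Below is a minimal standalone sketch of that pattern; the `CpuStager` class and method names are illustrative (not the `CheckpointManager` API), and a CUDA device is assumed.
```
# Minimal sketch of the pinned-CPU staging pattern used by cpu_staging()
# above; `state_dict` is assumed to hold CUDA tensors.
import torch
from torch.distributed._state_dict_utils import (
    _copy_state_dict,
    _create_cpu_state_dict,
)


class CpuStager:
    def __init__(self) -> None:
        self.cpu_cache = None
        self.stream = torch.cuda.Stream()

    def stage(self, state_dict: dict) -> None:
        # Allocate pinned, shareable CPU buffers once, then reuse them.
        if self.cpu_cache is None:
            self.cpu_cache = _create_cpu_state_dict(
                state_dict, pin_memory=True, share_memory=True
            )
        # Copy device tensors on a side stream so training is not blocked.
        with torch.cuda.stream(self.stream):
            self.cpu_cache = _copy_state_dict(
                state_dict, self.cpu_cache, non_blocking=True
            )

    def result(self) -> dict:
        # Readers (e.g. the TorchFT state_dict callback) must wait for the
        # async copy, mirroring wait_for_staging()/staging_results() above.
        if not self.stream.query():
            self.stream.synchronize()
        return self.cpu_cache
```
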
13 changes: 13 additions & 0 deletions torchtitan/config_manager.py
@@ -631,6 +631,19 @@ def __init__(self):
action="store_true",
)

self.parser.add_argument(
"--experimental.enable_torchft",
action="store_true",
help="Enable TorchFT integration.",
)

self.parser.add_argument(
"--experimental.ft_replica_group_id",
type=int,
default=-1,
help="The FT replicate group of this run.",
)

def to_dict(self):
return self.args_dict

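
torchtitan's dotted argparse flags surface as nested attributes on the parsed `JobConfig`. The snippet below is a hypothetical helper showing how the two new options are read downstream, mirroring how `torchtitan/ft.py` (next file) consumes them; the function name `ft_replica_id` is made up for illustration.
```
# Hypothetical consumer of the new experimental flags; the attribute paths
# mirror torchtitan/ft.py below.
from typing import Optional


def ft_replica_id(job_config) -> Optional[str]:
    # Set via --experimental.enable_torchft / --experimental.ft_replica_group_id
    if not job_config.experimental.enable_torchft:
        return None
    return f"torchtitan_ft_{job_config.experimental.ft_replica_group_id}"
```
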
58 changes: 58 additions & 0 deletions torchtitan/ft.py
@@ -0,0 +1,58 @@
import importlib
from typing import Any, Callable, Optional

from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict

from torchtitan.config_manager import JobConfig

if importlib.util.find_spec("torchft") is not None:
import torchft as ft

has_torchft = True
else:
has_torchft = False


def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]:
"""
Initialize the FT manager for the given job.
"""
if not job.experimental.enable_torchft:
return None

if not has_torchft:
raise ImportError("torchft is not installed. Please install it.")

pg = ft.ProcessGroupBabyNCCL()
manager = ft.Manager(
pg=pg,
min_replica_size=1,
load_state_dict=None,
state_dict=None,
use_async_quorum=True,
replica_id=f"torchtitan_ft_{job.experimental.ft_replica_group_id}",
)

return manager


def set_ft_state_dict_fns(manager: Optional["ft.Manager"], ckpt_manager) -> None:
"""
Set the state_dict functions for the given manager.
"""
if manager is None:
return

def state_dict():
ret = {}
for k, v in ckpt_manager.staging_results().items():
if k in {"model", "optimizer", "lr_schedulers"}:
ret[k] = v
return ret

def load_state_dict(state_dict):
assert state_dict is not None
for k, v in state_dict.items():
ckpt_manager.states[k].load_state_dict(v)

manager.set_state_dict_fns(load_state_dict, state_dict)
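
The diff that wires these helpers into the training loop is not loaded in this view. Below is a hedged sketch of how they presumably fit together, based only on the signatures visible in this PR; the function name `setup_ft`, the local variable names, the injected `build_lr_schedulers_fn` helper, and the exact `CheckpointManager` keyword list are assumptions.
```
# Hypothetical wiring (not the actual train.py diff): only init_ft_manager,
# set_ft_state_dict_fns, and the ft_manager keyword arguments are taken
# from this PR.
from torchtitan.checkpoint import CheckpointManager
from torchtitan.ft import init_ft_manager, set_ft_state_dict_fns
from torchtitan.optimizer import build_optimizers


def setup_ft(job_config, data_loader, model_parts, train_state, build_lr_schedulers_fn):
    # Returns None unless --experimental.enable_torchft is passed.
    ft_manager = init_ft_manager(job_config)

    optimizers = build_optimizers(model_parts, job_config, ft_manager)
    # Stand-in for torchtitan's LR scheduler builder (exact signature not
    # shown in this diff).
    lr_schedulers = build_lr_schedulers_fn(optimizers)

    checkpoint = CheckpointManager(
        dataloader=data_loader,
        model_parts=model_parts,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        states={"train_state": train_state},
        job_config=job_config,
        ft_manager=ft_manager,
    )
    # Register callbacks so a recovering replica can pull the CPU-staged
    # state dict from this replica's CheckpointManager (no-op if ft_manager
    # is None).
    set_ft_state_dict_fns(ft_manager, checkpoint)
    return ft_manager, optimizers, lr_schedulers, checkpoint
```
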
39 changes: 33 additions & 6 deletions torchtitan/optimizer.py
@@ -177,8 +177,32 @@ def zero_grad(self) -> None:
pass


class FTOptimizersContainer(OptimizersContainer):
def __init__(
self,
model_parts: List[nn.Module],
optimizer_kwargs: Dict[str, Any],
name: str,
ft_manager: Any,
) -> None:
import torchft as ft

super().__init__(model_parts, optimizer_kwargs, name)

# Force to initialize the optimizer state so that `optim.step()`
# won't be called by state_dict() and load_state_dict().
_ = {
k: v
for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
for k, v in sd.items()
}
self.optimizers = [ft.Optimizer(ft_manager, optim) for optim in self.optimizers]


def build_optimizers(
model_parts: List[nn.Module], job_config: JobConfig
model_parts: List[nn.Module],
job_config: JobConfig,
ft_manager: Optional[Any] = None,
) -> OptimizersContainer:
"""Create a OptimizersContainer for the given model parts and job config.
@@ -213,11 +237,14 @@ def build_optimizers(
"foreach": not fused,
}

return (
OptimizersContainer(model_parts, optimizer_kwargs, name)
if not optim_in_bwd
else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
)
if optim_in_bwd and ft_manager:
raise ValueError("TorchFT is not supported with optimizers in backward.")
elif optim_in_bwd:
return OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
elif ft_manager:
return FTOptimizersContainer(model_parts, optimizer_kwargs, name, ft_manager)
else:
return OptimizersContainer(model_parts, optimizer_kwargs, name)


class LRSchedulersContainer(Stateful):
(Diffs for the remaining 3 changed files are not shown in this view.)
