Commit 3a61b86

WIP
1 parent 71dea16 commit 3a61b86

7 files changed: +444 -30 lines changed

run_train.sh

Lines changed: 5 additions & 2 deletions
@@ -10,8 +10,11 @@ set -ex
 # use envs as local overwrites for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
-NGPU=${NGPU:-"8"}
-export LOG_RANK=${LOG_RANK:-0}
+# NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"4"}
+# export LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7}
+# export LOG_RANK=${LOG_RANK:-0,1,2,3}
+export LOG_RANK=${LOG_RANK:-1}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}

torchtitan/distributed/expert_parallel.py

Lines changed: 165 additions & 2 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.


-from typing import Callable, Literal
+from typing import Callable, Literal, Dict

 import torch
 import torch.nn as nn
@@ -22,6 +22,160 @@
 )
 from torch.distributed.tensor.parallel import ParallelStyle

+import threading
+import torch
+from typing import Optional
+import time
+
+class HookSequenceCoordinator:
+    """Coordinates hooks based on a predefined sequence"""
+
+    def __init__(self):
+        self._lock = threading.Lock()
+        self._condition = threading.Condition(self._lock)
+
+        # Define your desired execution sequence matching:
+        # stageB.combine() -> stageA.forward_attention() -> stageB.backward_moe() ->
+        # stageA.dispatch() -> stageB.dispatch() -> stageA.forward_moe() ->
+        # stageB.backward_attention() -> stageA.combine()
+        self._hook_sequence = [
+            "combine_D_bwd",
+            "dispatch_A_fwd",
+            "combine_C_bwd",
+            "dispatch_B_fwd",
+            "dispatch_B_bwd",
+            "combine_C_fwd",
+            "dispatch_A_bwd",
+            "combine_D_fwd",
+        ]
+        # Create a semaphore for each hook in the sequence
+        self._semaphores: Dict[str, threading.Semaphore] = {}
+        self._reset_semaphores()
+
+        # Coordination control - disabled by default
+        self._coordination_enabled = False
+        self._cycle_count = 0
+
+    def _reset_semaphores(self):
+        """Reset all semaphores - first one gets 1 permit, others get 0"""
+        self._semaphores.clear()
+        for i, hook_name in enumerate(self._hook_sequence):
+            # First semaphore starts with 1 permit, others start with 0
+            initial_permits = 1 if i == 0 else 0
+            self._semaphores[hook_name] = threading.Semaphore(initial_permits)
+
+    def enable_coordination(self):
+        """Enable hook coordination"""
+        self._coordination_enabled = True
+        self._reset_semaphores()  # Reset semaphores when enabling
+        print("[COORDINATION] Hook coordination ENABLED")
+
+    def disable_coordination(self):
+        """Disable hook coordination"""
+        self._coordination_enabled = False
+        # Release all semaphores so no threads get stuck
+        for semaphore in self._semaphores.values():
+            try:
+                semaphore.release()
+            except ValueError:
+                pass  # Semaphore was already at max value
+        print("[COORDINATION] Hook coordination DISABLED")
+
+    def is_coordination_enabled(self) -> bool:
+        """Check if coordination is currently enabled"""
+        return self._coordination_enabled
+
+    def reset_coordination(self):
+        """Reset coordination state (useful between training runs)"""
+        self._cycle_count = 0
+        self._reset_semaphores()
+        print("[COORDINATION] Hook coordination state RESET")
+
+    def acquire_execution(self, hook_name: str):
+        """Acquire execution permission using semaphores"""
+        # If coordination is disabled, just pass through
+        if not self._coordination_enabled:
+            print(f"[PASSTHROUGH] {hook_name} executing (coordination disabled)")
+            return
+
+        # Check if hook is in our sequence
+        if hook_name not in self._semaphores:
+            print(f"[WARNING] {hook_name} not in sequence, executing without coordination")
+            return
+
+        # Acquire the semaphore for this hook (blocks until available)
+        print(f"[WAITING] {hook_name} waiting for semaphore")
+        self._semaphores[hook_name].acquire()
+        print(f"[EXECUTING] {hook_name} acquired semaphore")
+
+    def release_execution(self, hook_name: str):
+        """Release execution and signal next hook"""
+        # If coordination is disabled, just pass through
+        if not self._coordination_enabled:
+            return
+
+        # Check if hook is in our sequence
+        if hook_name not in self._semaphores:
+            return
+
+        # Find the next hook in the sequence and release its semaphore
+        try:
+            current_index = self._hook_sequence.index(hook_name)
+            next_index = (current_index + 1) % len(self._hook_sequence)
+            next_hook = self._hook_sequence[next_index]
+
+            print(f"[COMPLETED] {hook_name} completed, signaling {next_hook}")
+            self._semaphores[next_hook].release()
+
+            # Check if we completed a full cycle
+            if next_index == 0:
+                self._cycle_count += 1
+                print(f"[CYCLE] Completed cycle {self._cycle_count}")
+
+        except ValueError:
+            print(f"[ERROR] {hook_name} not found in sequence")
+
+# Global coordinator
+_hook_coordinator = HookSequenceCoordinator()
+
+class SyncHook(torch.autograd.Function):
+    """Sync hook that follows a predefined execution sequence"""
+
+    @staticmethod
+    def forward(ctx, x, hook_name):
+        ctx.hook_name = hook_name
+
+        # Use forward-specific hook name
+        forward_hook_name = f"{hook_name}_fwd"
+        _hook_coordinator.acquire_execution(forward_hook_name)
+
+        try:
+            if _hook_coordinator.is_coordination_enabled():
+                print(f"[FORWARD HOOK] {forward_hook_name} (coordinated)")
+            else:
+                print(f"[FORWARD HOOK] {forward_hook_name} (uncoordinated)")
+            return x
+        finally:
+            _hook_coordinator.release_execution(forward_hook_name)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        hook_name = ctx.hook_name
+
+        # Use backward-specific hook name
+        backward_hook_name = f"{hook_name}_bwd"
+        _hook_coordinator.acquire_execution(backward_hook_name)
+
+        try:
+            if _hook_coordinator.is_coordination_enabled():
+                print(f"[BACKWARD HOOK] {backward_hook_name} (coordinated)")
+            else:
+                print(f"[BACKWARD HOOK] {backward_hook_name} (uncoordinated)")
+            return grad_output, None
+        finally:
+            _hook_coordinator.release_execution(backward_hook_name)
+
+

 TOKEN_GROUP_ALIGN_SIZE_M = 8
 ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
@@ -77,7 +231,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
             self._partition_fn,
         )

-
 class ExpertParallel(ParallelStyle):
     def __init__(self):
         super().__init__()
@@ -90,6 +243,9 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         routed_input, num_tokens_per_expert = inputs
         ep_size = device_mesh.shape[0]

+        # HOOK: signal ready for sync
+        routed_input = SyncHook.apply(routed_input, "dispatch_A")
+
         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
             num_tokens_per_expert_group = all_to_all_single(
@@ -135,6 +291,9 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
         # each expert gets locally is a multiple of ALIGN_SIZE_M.

+        # HOOK: signal ready for sync
+        routed_input = SyncHook.apply(routed_input, "dispatch_B")
+
         return routed_input, num_tokens_per_expert_group

     @staticmethod
@@ -146,12 +305,16 @@ def _partition_fn(name, mod, device_mesh):

     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
+        # HOOK: signal ready for sync
+        routed_output = SyncHook.apply(routed_output, "combine_C")
         routed_output = all_to_all_single_autograd(
             routed_output,
             self.input_splits,
             self.output_splits,
             device_mesh.get_group(),
         )
+        # HOOK: signal ready for sync
+        routed_output = SyncHook.apply(routed_output, "combine_D")
         return routed_output

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
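
For reference, here is a minimal standalone sketch of the semaphore-chaining pattern that the new HookSequenceCoordinator relies on: each named hook blocks on its own semaphore and, when done, releases the semaphore of the next hook in the list, which pins a fixed execution order across threads. The two-hook sequence and the run_hook helper below are illustrative only and are not part of this commit.

import threading

# Toy two-hook sequence (hypothetical names mirroring the commit's "*_fwd" hooks).
sequence = ["dispatch_A_fwd", "combine_C_fwd"]
# First hook starts with one permit, the rest start locked.
semaphores = {
    name: threading.Semaphore(1 if i == 0 else 0)
    for i, name in enumerate(sequence)
}

def run_hook(name: str) -> None:
    # Block until it is this hook's turn in the sequence.
    semaphores[name].acquire()
    print(f"[EXECUTING] {name}")
    # Hand the turn to the next hook in the list (wrapping around,
    # like release_execution does above).
    next_hook = sequence[(sequence.index(name) + 1) % len(sequence)]
    semaphores[next_hook].release()

# combine_C_fwd's thread starts first but cannot run until dispatch_A_fwd releases it.
t = threading.Thread(target=run_hook, args=("combine_C_fwd",))
t.start()
run_hook("dispatch_A_fwd")
t.join()
# Prints "[EXECUTING] dispatch_A_fwd" followed by "[EXECUTING] combine_C_fwd".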

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -11,7 +11,7 @@
 from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
-from torchtitan.models.llama3.infra.pipeline import pipeline_llama
+from torchtitan.models.llama3.infra.pipeline import pipeline_llama, pipeline_llama_tracer
 from torchtitan.models.moe import MoEArgs

 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
@@ -32,10 +32,11 @@
 deepseekv3_configs = {
     "debugmodel": DeepSeekV3ModelArgs(
         vocab_size=2000,
-        dim=256,
+        # needs at least dim 8?
+        dim=8,
         inter_dim=1024,
         moe_inter_dim=256,
-        n_layers=6,
+        n_layers=16,
         n_dense_layers=1,
         n_heads=16,
         moe_args=MoEArgs(

torchtitan/models/deepseek_v3/train_configs/debug_model.toml

Lines changed: 10 additions & 9 deletions
@@ -4,9 +4,9 @@ description = "DeepSeek-V3 debug training"
 print_args = false

 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
-profile_freq = 10
+profile_freq = 5
 enable_memory_snapshot = false
 save_memory_snapshot_folder = "memory_snapshot"

@@ -36,22 +36,23 @@ decay_type = "linear"
 min_lr_factor = 0.0

 [training]
-local_batch_size = 8
-seq_len = 2048
+local_batch_size = 4
+seq_len = 4
 max_norm = 1.0 # grad norm clipping
-steps = 10
+steps = 6
 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
+# dataset = "c4"

 [parallelism]
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
 tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
-pipeline_parallel_degree = 1
-pipeline_parallel_schedule = "1F1B"
+pipeline_parallel_degree = 2
+expert_parallel_degree = 2
 context_parallel_degree = 1
-expert_parallel_degree = 1
+pipeline_parallel_schedule = "DualPipeV"
 expert_tensor_parallel_degree = 1

 [checkpoint]
@@ -63,7 +64,7 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

 [activation_checkpoint]
-mode = "selective" # ["none", "selective", "full"]
+mode = "none" # ["none", "selective", "full"]
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

 [compile]

torchtitan/models/llama3/infra/pipeline.py

Lines changed: 75 additions & 0 deletions
@@ -25,6 +25,9 @@
     pipeline_module_split,
 )

+from torch.distributed.pipelining import SplitPoint, pipeline
+from torch.distributed.pipelining.stage import _PipelineStage
+
 from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
 from torchtitan.tools.logging import logger

@@ -148,3 +151,75 @@ def pipeline_llama(
             has_last_stage = True

     return pp_schedule, model_parts, has_first_stage, has_last_stage
+
+
+def pipeline_llama_tracer(
+    model: nn.Module,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+    device: torch.device,
+    model_args: BaseModelArgs,
+    parallelize_fn: ParallelizeFunction,
+    loss_fn: LossFunction,
+):
+    assert (
+        parallel_dims.pp_enabled
+    ), "can't apply pipeline parallelism if it is not enabled"
+
+    # if job_config.model.norm_type == "fused_rmsnorm":
+    #     # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
+    #     # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
+    #     raise NotImplementedError(
+    #         "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
+    #     )
+    pp_mesh = parallel_dims.world_mesh["pp"]
+    pp_rank = pp_mesh.get_local_rank()
+    stage_idx = pp_mesh.get_local_rank()
+    layers_per_rank = model_args.n_layers // parallel_dims.pp
+    split_spec = {
+        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
+        for i in range(1, parallel_dims.pp)
+    }
+    # Get example input
+    input_shape = (job_config.training.local_batch_size, job_config.training.seq_len)
+    assert hasattr(model_args, "vocab_size")
+    input_ids = torch.randint(
+        model_args.vocab_size, input_shape, dtype=torch.int64, device="meta"
+    )
+
+    # Create a pipeline representation from the model
+    pipe = pipeline(
+        model, mb_args=(input_ids,), split_spec=split_spec
+    )
+    model = pipe.get_stage_module(stage_idx)
+    stage = _PipelineStage(
+        stage_module=model,
+        stage_index=pp_rank,
+        pipe_info=pipe.pipe_info,
+        device=device,
+        group=pp_mesh.get_group(),
+    )
+
+    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
+    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
+    # optimizer, and checkpointing
+    for i, m in enumerate(model_parts):
+        # apply SPMD-style PT-D techniques
+        m = parallelize_fn(m, parallel_dims, job_config)
+        model_parts[i] = m
+        # NOTE: this is to update the model in the stage
+        # in case the model is modified e.g. by torch.compile
+        stages[i].submod = m
+
+    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+
+    # This is used in the train loop to determine whether to pass in the input_ids and labels
+    has_first_stage = False
+    has_last_stage = False
+    for stage in stages:
+        if stage.is_first:
+            has_first_stage = True
+        if stage.is_last:
+            has_last_stage = True
+
+    return pp_schedule, model_parts, has_first_stage, has_last_stage
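
Note that pipeline_llama_tracer, as committed, loops over model_parts and stages without ever assigning them. Presumably the single traced stage is meant to be wrapped in one-element lists before that loop, along the lines of the sketch below; this is assumed intent, not part of the diff.

    # Hypothetical completion, inserted after the _PipelineStage(...) construction,
    # so the existing loop over model_parts / stages and build_pipeline_schedule(...) work:
    stages = [stage]        # the one traced stage on this pp rank
    model_parts = [model]   # its corresponding stage module from pipe.get_stage_module(...)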
