@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.


-from typing import Callable, Literal
+from typing import Callable, Literal, Dict


import torch
import torch.nn as nn
@@ -22,6 +22,160 @@
)
from torch.distributed.tensor.parallel import ParallelStyle

+import threading
+import torch
+from typing import Optional
+import time
+
+class HookSequenceCoordinator:
+    """Coordinates hooks based on a predefined sequence"""
+
+    def __init__(self):
+        self._lock = threading.Lock()
+        self._condition = threading.Condition(self._lock)
+
+        # Define your desired execution sequence matching:
+        # stageB.combine() -> stageA.forward_attention() -> stageB.backward_moe() ->
+        # stageA.dispatch() -> stageB.dispatch() -> stageA.forward_moe() ->
+        # stageB.backward_attention() -> stageA.combine()
+        self._hook_sequence = [
+            "combine_D_bwd",
+            "dispatch_A_fwd",
+            "combine_C_bwd",
+            "dispatch_B_fwd",
+            "dispatch_B_bwd",
+            "combine_C_fwd",
+            "dispatch_A_bwd",
+            "combine_D_fwd",
+        ]
+        # Create a semaphore for each hook in the sequence
+        self._semaphores: Dict[str, threading.Semaphore] = {}
+        self._reset_semaphores()
+
+        # Coordination control - disabled by default
+        self._coordination_enabled = False
+        self._cycle_count = 0
+
+    def _reset_semaphores(self):
+        """Reset all semaphores - first one gets 1 permit, others get 0"""
+        self._semaphores.clear()
+        for i, hook_name in enumerate(self._hook_sequence):
+            # First semaphore starts with 1 permit, others start with 0
+            initial_permits = 1 if i == 0 else 0
+            self._semaphores[hook_name] = threading.Semaphore(initial_permits)
+
+    def enable_coordination(self):
+        """Enable hook coordination"""
+        self._coordination_enabled = True
+        self._reset_semaphores()  # Reset semaphores when enabling
+        print("[COORDINATION] Hook coordination ENABLED")
+
+    def disable_coordination(self):
+        """Disable hook coordination"""
+        self._coordination_enabled = False
+        # Release all semaphores so no threads get stuck
+        for semaphore in self._semaphores.values():
+            try:
+                semaphore.release()
+            except ValueError:
+                pass  # Semaphore was already at max value
+        print("[COORDINATION] Hook coordination DISABLED")
+
+    def is_coordination_enabled(self) -> bool:
+        """Check if coordination is currently enabled"""
+        return self._coordination_enabled
+
+    def reset_coordination(self):
+        """Reset coordination state (useful between training runs)"""
+        self._cycle_count = 0
+        self._reset_semaphores()
+        print("[COORDINATION] Hook coordination state RESET")
+
+    def acquire_execution(self, hook_name: str):
+        """Acquire execution permission using semaphores"""
+        # If coordination is disabled, just pass through
+        if not self._coordination_enabled:
+            print(f"[PASSTHROUGH] {hook_name} executing (coordination disabled)")
+            return
+
+        # Check if hook is in our sequence
+        if hook_name not in self._semaphores:
+            print(f"[WARNING] {hook_name} not in sequence, executing without coordination")
+            return
+
+        # Acquire the semaphore for this hook (blocks until available)
+        print(f"[WAITING] {hook_name} waiting for semaphore")
+        self._semaphores[hook_name].acquire()
+        print(f"[EXECUTING] {hook_name} acquired semaphore")
+
+    def release_execution(self, hook_name: str):
+        """Release execution and signal next hook"""
+        # If coordination is disabled, just pass through
+        if not self._coordination_enabled:
+            return
+
+        # Check if hook is in our sequence
+        if hook_name not in self._semaphores:
+            return
+
+        # Find the next hook in the sequence and release its semaphore
+        try:
+            current_index = self._hook_sequence.index(hook_name)
+            next_index = (current_index + 1) % len(self._hook_sequence)
+            next_hook = self._hook_sequence[next_index]
+
+            print(f"[COMPLETED] {hook_name} completed, signaling {next_hook}")
+            self._semaphores[next_hook].release()
+
+            # Check if we completed a full cycle
+            if next_index == 0:
+                self._cycle_count += 1
+                print(f"[CYCLE] Completed cycle {self._cycle_count}")
+
+        except ValueError:
+            print(f"[ERROR] {hook_name} not found in sequence")
+
+# Global coordinator
+_hook_coordinator = HookSequenceCoordinator()
+
+class SyncHook(torch.autograd.Function):
+    """Sync hook that follows a predefined execution sequence"""
+
+    @staticmethod
+    def forward(ctx, x, hook_name):
+        ctx.hook_name = hook_name
+
+        # Use forward-specific hook name
+        forward_hook_name = f"{hook_name}_fwd"
+        _hook_coordinator.acquire_execution(forward_hook_name)
+
+        try:
+            if _hook_coordinator.is_coordination_enabled():
+                print(f"[FORWARD HOOK] {forward_hook_name} (coordinated)")
+            else:
+                print(f"[FORWARD HOOK] {forward_hook_name} (uncoordinated)")
+            return x
+        finally:
+            _hook_coordinator.release_execution(forward_hook_name)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        hook_name = ctx.hook_name
+
+        # Use backward-specific hook name
+        backward_hook_name = f"{hook_name}_bwd"
+        _hook_coordinator.acquire_execution(backward_hook_name)
+
+        try:
+            if _hook_coordinator.is_coordination_enabled():
+                print(f"[BACKWARD HOOK] {backward_hook_name} (coordinated)")
+            else:
+                print(f"[BACKWARD HOOK] {backward_hook_name} (uncoordinated)")
+            return grad_output, None
+        finally:
+            _hook_coordinator.release_execution(backward_hook_name)
+
+

TOKEN_GROUP_ALIGN_SIZE_M = 8
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
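A minimal sketch of the hand-off the coordinator above implements, assuming the file is importable as expert_parallel (a hypothetical module name) and using an intentionally shuffled thread start order for illustration: each hook blocks on its own semaphore, and releasing a hook grants a permit to the next name in _hook_sequence, so the eight events always execute in the listed order regardless of which thread arrives first.

import random
import threading

from expert_parallel import _hook_coordinator  # hypothetical import path for the module above

_hook_coordinator.enable_coordination()

def run_hook(name):
    _hook_coordinator.acquire_execution(name)  # blocks until this hook's turn in the sequence
    _hook_coordinator.release_execution(name)  # grants a permit to the next hook in the sequence

# Start one thread per hook in shuffled order; the [EXECUTING] prints still come
# out in _hook_sequence order because each thread waits on its own semaphore.
names = list(_hook_coordinator._hook_sequence)
random.shuffle(names)
threads = [threading.Thread(target=run_hook, args=(name,)) for name in names]
for t in threads:
    t.start()
for t in threads:
    t.join()

_hook_coordinator.disable_coordination()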
@@ -77,7 +231,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
            self._partition_fn,
        )

-
class ExpertParallel(ParallelStyle):
    def __init__(self):
        super().__init__()
@@ -90,6 +243,9 @@ def _token_dispatch(self, mod, inputs, device_mesh):
        routed_input, num_tokens_per_expert = inputs
        ep_size = device_mesh.shape[0]

+        # HOOK: signal ready for sync
+        routed_input = SyncHook.apply(routed_input, "dispatch_A")
+
        # generate the input splits and output splits for all-to-all
        with torch.no_grad():
            num_tokens_per_expert_group = all_to_all_single(
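For intuition about this hook placement: SyncHook.apply is an identity on the routed tensor, so the dispatch_A_fwd event fires at this point of the forward pass, and the matching dispatch_A_bwd event fires when autograd later flows back through the same point. That is what lets the coordinator interleave one stage's forward with another stage's backward. A small standalone check, assuming the module above is importable as expert_parallel (a hypothetical path); coordination is left at its default (disabled), so the hooks only print and never block:

import torch

from expert_parallel import SyncHook  # hypothetical import path for the module above

x = torch.randn(4, 8, requires_grad=True)

# Forward: prints the [PASSTHROUGH]/[FORWARD HOOK] messages for dispatch_A_fwd
y = SyncHook.apply(x, "dispatch_A")
assert torch.equal(y, x)  # the hook does not modify the tensor

# Backward: prints the [PASSTHROUGH]/[BACKWARD HOOK] messages for dispatch_A_bwd
y.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))  # gradients pass through unchanged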
@@ -135,6 +291,9 @@ def _token_dispatch(self, mod, inputs, device_mesh):
        # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
        # each expert gets locally is a multiple of ALIGN_SIZE_M.

+        # HOOK: signal ready for sync
+        routed_input = SyncHook.apply(routed_input, "dispatch_B")
+
        return routed_input, num_tokens_per_expert_group

    @staticmethod
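Together with the combine hooks in the next hunk, this gives four fixed call sites per ExpertParallel layer: dispatch_A at the start of _token_dispatch, dispatch_B just before it returns, and combine_C/combine_D around the combine all-to-all. With the _fwd/_bwd suffixes added by SyncHook, they produce exactly the eight events that _hook_sequence orders, which can be sanity-checked as follows (again assuming the hypothetical expert_parallel import path):

from expert_parallel import _hook_coordinator  # hypothetical import path for the module above

call_sites = ["dispatch_A", "dispatch_B", "combine_C", "combine_D"]
events = {f"{site}_{phase}" for site in call_sites for phase in ("fwd", "bwd")}

# Every forward/backward event has exactly one slot in the coordinator's sequence.
assert events == set(_hook_coordinator._hook_sequence)
assert len(_hook_coordinator._hook_sequence) == 8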
@@ -146,12 +305,16 @@ def _partition_fn(name, mod, device_mesh):

    # performing all-to-all combine on the output
    def _token_combine(self, mod, routed_output, device_mesh):
+        # HOOK: signal ready for sync
+        routed_output = SyncHook.apply(routed_output, "combine_C")
        routed_output = all_to_all_single_autograd(
            routed_output,
            self.input_splits,
            self.output_splits,
            device_mesh.get_group(),
        )
+        # HOOK: signal ready for sync
+        routed_output = SyncHook.apply(routed_output, "combine_D")
        return routed_output

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
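The commit does not show where coordination is switched on, so the following is only a sketch of one plausible driver, with the run_overlapped_step helper, the stage step callables, and the expert_parallel import path all hypothetical: enable coordination around a region where two stages run their MoE forward and backward concurrently in separate threads on this rank, and always fall back to pass-through afterwards so uncoordinated hooks can never block.

import threading

from expert_parallel import _hook_coordinator  # hypothetical import path for the module above

def run_overlapped_step(stage_a_step, stage_b_step):
    """Run two stage step callables concurrently with hook coordination enabled."""
    _hook_coordinator.reset_coordination()
    _hook_coordinator.enable_coordination()
    try:
        t_a = threading.Thread(target=stage_a_step)  # e.g. the forward half of stage A
        t_b = threading.Thread(target=stage_b_step)  # e.g. the backward half of stage B
        t_a.start()
        t_b.start()
        t_a.join()
        t_b.join()
    finally:
        # Return to pass-through so any later, uncoordinated hooks never wait on a semaphore.
        _hook_coordinator.disable_coordination()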