
Commit 6273ced

buptzyb and ksivaman authored
[PyTorch] Support delay_wgrad_compute cudagraph (#1948)
* support cudagraph dw

  Signed-off-by: Robin Zhang <[email protected]>

* fix lint

  Signed-off-by: Robin Zhang <[email protected]>

* fix ci

  Signed-off-by: Robin Zhang <[email protected]>

---------

Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
1 parent 021e1e6 commit 6273ced
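
For orientation, here is a minimal usage sketch of the workflow this commit enables. Only the backward_dw() attribute attached to the graphed callable comes from this commit (see the graph.py diff below); the delay_wgrad_compute constructor argument and the exact make_graphed_callables call shown here are illustrative assumptions, not taken from this diff.

# Hypothetical usage sketch -- the delay_wgrad_compute kwarg and the exact
# make_graphed_callables signature are assumptions; graphed.backward_dw() is
# the attribute this commit attaches to the returned callable.
import torch
import transformer_engine.pytorch as te

model = te.Linear(1024, 1024, delay_wgrad_compute=True)  # assumed kwarg
sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)

# Capture forward, backward (dgrad), and now a separate wgrad CUDA graph.
graphed = te.make_graphed_callables(model, (sample_input,))

x = torch.randn(32, 1024, device="cuda", requires_grad=True)
graphed(x).sum().backward()  # replays the captured fwd/bwd graphs
graphed.backward_dw()        # replays the captured wgrad graph (no-op if not needed)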


4 files changed (+85 −23 lines)

transformer_engine/pytorch/graph.py

Lines changed: 72 additions & 20 deletions
@@ -322,14 +322,16 @@ def _make_graphed_callables(

     fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
     bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
+    bwd_dw_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
     graph_callables = [None for _ in range(len(flatten_sample_args))]

     # For cases with multiple active RNG states, e.g. TP.
     if graph_safe_rng_available():
         for _, state in get_all_rng_states().items():
-            for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs):
+            for fwd_graph, bwd_graph, bwd_dw_graph in zip(fwd_graphs, bwd_graphs, bwd_dw_graphs):
                 fwd_graph.register_generator_state(state)
                 bwd_graph.register_generator_state(state)
+                bwd_dw_graph.register_generator_state(state)

     mempool = graph_pool_handle() if pool is None else pool

@@ -366,28 +368,40 @@ def _make_graphed_callables(
     ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."

     # Filter the TE modules that cudagraph can access.
-    visited_te_modules = set()
-
-    def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
-        if isinstance(module, TransformerEngineBaseModule):
-            visited_te_modules.add(module)
-        # If forward is called on a BasicOperation directly the hook will run
-        elif isinstance(module, BasicOperation):
-            visited_te_modules.add(module)
-        # If forward is called on a te.ops.Sequential it is not called on its constituent ops
-        elif isinstance(module, Sequential):
-            assert module._module_groups is not None, "Should have been initialized by warmup"
-            for module_group in module._module_groups:
-                if isinstance(module_group, OperationFuser):
-                    for basic_op in module_group._basic_ops:
-                        visited_te_modules.add(basic_op)
+    visited_te_modules = {}
+    need_bwd_dw_graph = {}

     # Run warmup and do the above filtering.
     with torch.cuda.stream(torch.cuda.Stream()):
         for func_idx, func in zip(warmup_func_idx, warmup_func):
             args = sample_args[func_idx]
             kwargs = sample_kwargs[func_idx]
             static_input_surface = per_callable_static_input_surfaces[func_idx]
+
+            def hook_fn(
+                module, inputs, outputs, func_idx=func_idx
+            ):  # pylint: disable=unused-argument
+                modules = set()
+                if isinstance(module, TransformerEngineBaseModule):
+                    modules.add(module)
+                # If forward is called on a BasicOperation directly the hook will run
+                elif isinstance(module, BasicOperation):
+                    modules.add(module)
+                # If forward is called on a te.ops.Sequential it is not called on its constituent ops
+                elif isinstance(module, Sequential):
+                    assert (
+                        module._module_groups is not None
+                    ), "Should have been initialized by warmup"
+                    for module_group in module._module_groups:
+                        if isinstance(module_group, OperationFuser):
+                            for basic_op in module_group._basic_ops:
+                                modules.add(basic_op)
+                if modules:
+                    if func_idx not in visited_te_modules:
+                        visited_te_modules[func_idx] = modules
+                    else:
+                        visited_te_modules[func_idx].update(modules)
+
             for warmup_iter in range(num_warmup_iters):
                 hooks = []
                 for module in func.modules():
@@ -432,6 +446,15 @@ def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
                         module_params_with_grad
                     )
                     per_callable_static_input_surfaces[func_idx] = static_input_surface
+
+                    # Run wgrad. This is essential for some TE modules when they have
+                    # delay_wgrad_compute enabled.
+                    need_backward_dw = False
+                    for module in visited_te_modules.get(func_idx, set()):
+                        if hasattr(module, "need_backward_dw") and module.need_backward_dw():
+                            need_backward_dw = True
+                            module.backward_dw()
+                    need_bwd_dw_graph[func_idx] = need_backward_dw
                 else:
                     grad_inputs = None
                 del outputs, grad_inputs
@@ -514,6 +537,17 @@ def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
                         allow_unused=allow_unused_input,
                         retain_graph=retain_graph_in_backward,
                     )
+                    # If no module needs backward_dw, the bwd_dw_graph will be empty,
+                    # so skip capturing it.
+                    if need_bwd_dw_graph[per_callable_bwd_idx]:
+                        bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
+                        with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
+                            for module in visited_te_modules[per_callable_bwd_idx]:
+                                if (
+                                    hasattr(module, "need_backward_dw")
+                                    and module.need_backward_dw()
+                                ):
+                                    module.backward_dw()
                     # Constructs a tuple suitable for returning from Graphed.backward:
                     # Pads out the actually-needed grads with Nones in gradient slots for inputs
                     # that don't require grad. I couldn't think of a one-liner for this pattern.
@@ -582,10 +616,12 @@ def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
     # Capture backward graphs in reverse order
     per_callable_static_grad_outputs = []
     per_callable_static_grad_inputs = []
-    for static_input_surface, static_outputs, bwd_graph in zip(
+    for static_input_surface, static_outputs, bwd_graph, bwd_dw_graph, bwd_idx in zip(
         reversed(per_callable_static_input_surfaces),
         reversed(per_callable_static_outputs),
         reversed(bwd_graphs),
+        reversed(bwd_dw_graphs),
+        reversed(range(len(per_callable_static_input_surfaces))),
     ):
         # For now, assumes all static_outputs require grad
         static_grad_outputs = tuple(
@@ -601,6 +637,11 @@ def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
             allow_unused=allow_unused_input,
             retain_graph=retain_graph_in_backward,
         )
+        if need_bwd_dw_graph[bwd_idx]:
+            with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
+                for module in visited_te_modules[bwd_idx]:
+                    if hasattr(module, "need_backward_dw") and module.need_backward_dw():
+                        module.backward_dw()
         # Constructs a tuple suitable for returning from Graphed.backward:
        # Pads out the actually-needed grads with Nones in gradient slots for inputs that
        # don't require grad. I couldn't think of a slick one-liner for this pattern.
@@ -732,9 +773,10 @@ def functionalized(*user_args, **user_kwargs):
         )

         func = graph_callables[i]
+        te_modules = visited_te_modules.get(i, set())
         if isinstance(func, torch.nn.Module):

-            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
+            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd, te_modules):
                 def new_fwd(*user_args, **user_kwargs):
                     # If the module's training-or-eval state matches what we graphed,
                     # run the graph, otherwise run the original forward method
@@ -743,7 +785,7 @@ def new_fwd(*user_args, **user_kwargs):
                     if FP8GlobalStateManager.is_fp8_enabled():
                         fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
                         for m in func.modules():
-                            if m not in visited_te_modules:
+                            if m not in te_modules:
                                 # Only Set the FP8 meta for the modules included by forward
                                 continue
                             if isinstance(m, TransformerEngineBaseModule):
@@ -780,7 +822,7 @@ def new_fwd(*user_args, **user_kwargs):

                 return new_fwd

-            forward = make_graphed_forward(func, func.training, graphed, func.forward)
+            forward = make_graphed_forward(func, func.training, graphed, func.forward, te_modules)
             if _order is None:
                 func.forward = forward
                 ret.append(func)
@@ -789,6 +831,16 @@ def new_fwd(*user_args, **user_kwargs):
         else:
             ret.append(graphed)

+        # Attach backward_dw as an attribute to the graphed callable.
+        def backward_dw(
+            need_backward_dw=need_bwd_dw_graph.get(i, False),
+            bwd_dw_graph=bwd_dw_graphs[i],
+        ):
+            if need_backward_dw:
+                bwd_dw_graph.replay()
+
+        setattr(ret[-1], "backward_dw", backward_dw)
+
     if just_one_callable:
         return ret[0]
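
The graph.py changes above follow PyTorch's usual capture-and-replay pattern, with a third CUDA graph dedicated to the deferred wgrad work and sharing the same memory pool, and backward_dw() simply replaying that graph. A stripped-down, self-contained sketch of the pattern in plain PyTorch (the matmuls are stand-ins, not TE's actual kernels):

# Standalone sketch: capture deferred work in its own graph that shares the pool,
# then replay it only on demand. The matmuls stand in for dgrad/wgrad kernels.
import torch

x = torch.randn(64, 512, device="cuda")
w = torch.randn(512, 512, device="cuda")

pool = torch.cuda.graph_pool_handle()
main_graph, wgrad_graph = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()

# Warm up on a side stream before capture, as usual for CUDA graphs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y = x @ w
    _ = x.t() @ y
torch.cuda.current_stream().wait_stream(s)

with torch.cuda.graph(main_graph, pool=pool):
    y = x @ w        # stands in for the fwd/dgrad work captured every step

with torch.cuda.graph(wgrad_graph, pool=pool):
    gw = x.t() @ y   # stands in for module.backward_dw(), captured separately

main_graph.replay()   # replayed each iteration
wgrad_graph.replay()  # replayed only when backward_dw() is called
torch.cuda.synchronize()

This mirrors how the new bwd_dw_graphs are captured alongside the backward graphs and later replayed by the backward_dw() function attached to the graphed callable.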

transformer_engine/pytorch/module/base.py

Lines changed: 11 additions & 1 deletion
@@ -662,6 +662,7 @@ def __init__(self) -> None:
         self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
         self.activation_dtype: Optional[torch.dtype] = None
         self.wgrad_accumulation_and_reduce_hooks = []
+        self.wgrad_store = None

         if not TEDebugState.debug_enabled:
             TEDebugState.initialize()
@@ -1481,12 +1482,21 @@ def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_re
         """
         self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook)

+    def need_backward_dw(self):
+        """
+        Check if this module needs to execute the delayed weight gradient computation.
+        This method should be used at the beginning of self.backward_dw() to determine
+        whether it should actually execute or simply return without doing anything.
+        Users can also call this method manually before calling into backward_dw().
+        """
+        return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute()
+
     def backward_dw(self):
         """
         Execute the delayed weight gradient computation.
         This method is called after the main backward pass to compute weight gradients.
         """
-        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
+        if not self.need_backward_dw():
             return
         with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
             (wgrad, bgrad), _ = self.wgrad_store.pop()
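
The base.py change factors the guard out of backward_dw() into a reusable need_backward_dw() predicate, so callers (such as the cudagraph capture code above) can test whether a module actually has deferred wgrad work before invoking it. A toy, pure-Python illustration of the same split; the stand-in store below is hypothetical, not TE's internal wgrad store:

# Toy illustration of the need_backward_dw()/backward_dw() split added here;
# _ToyWgradStore is a hypothetical stand-in, not TE's internal wgrad store.
class _ToyWgradStore:
    def __init__(self):
        self._pending = []

    def delay_wgrad_compute(self):
        return True  # this store always defers wgrad work

    def put(self, fn):
        self._pending.append(fn)

    def pop(self):
        return self._pending.pop(0)


class _ToyModule:
    def __init__(self, delay_wgrad_compute=False):
        # Mirrors base.py: wgrad_store stays None unless delayed wgrad is enabled.
        self.wgrad_store = _ToyWgradStore() if delay_wgrad_compute else None

    def need_backward_dw(self):
        # Same guard the commit introduces in TransformerEngineBaseModule.
        return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute()

    def backward_dw(self):
        if not self.need_backward_dw():
            return
        self.wgrad_store.pop()()  # run the deferred weight-gradient work


m = _ToyModule(delay_wgrad_compute=True)
m.wgrad_store.put(lambda: print("wgrad computed"))
if m.need_backward_dw():  # callers may check first, as the docstring suggests
    m.backward_dw()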

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 1 addition & 1 deletion
@@ -840,7 +840,7 @@ def backward_dw(self):
         Execute the delayed weight gradient computation.
         This method is called after the main backward pass to compute weight gradients.
         """
-        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
+        if not self.need_backward_dw():
             return
         with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
             (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 1 addition & 1 deletion
@@ -2211,7 +2211,7 @@ def backward_dw(self):
         Execute the delayed weight gradient computation.
         This method is called after the main backward pass to compute weight gradients.
         """
-        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
+        if not self.need_backward_dw():
             return
         with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"):
             (fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop()
