@@ -322,14 +322,16 @@ def _make_graphed_callables(
 
     fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
     bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
+    bwd_dw_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
     graph_callables = [None for _ in range(len(flatten_sample_args))]
 
     # For cases with multiple active RNG states, e.g. TP.
     if graph_safe_rng_available():
         for _, state in get_all_rng_states().items():
-            for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs):
+            for fwd_graph, bwd_graph, bwd_dw_graph in zip(fwd_graphs, bwd_graphs, bwd_dw_graphs):
                 fwd_graph.register_generator_state(state)
                 bwd_graph.register_generator_state(state)
+                bwd_dw_graph.register_generator_state(state)
 
     mempool = graph_pool_handle() if pool is None else pool
 
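Note: the per-callable graph objects above follow PyTorch's standard CUDA-graph capture/replay pattern. A minimal plain-PyTorch sketch, independent of TE (names here are illustrative only):

```python
import torch

g = torch.cuda.CUDAGraph()
static_x = torch.randn(8, 8, device="cuda")

# Warm up on a side stream before capture, as the warmup loop below does
# with torch.cuda.stream(torch.cuda.Stream()).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    torch.relu(static_x @ static_x)
torch.cuda.current_stream().wait_stream(s)

# Capture once, then replay with new data copied into the static input buffer.
with torch.cuda.graph(g):
    static_y = torch.relu(static_x @ static_x)

static_x.copy_(torch.randn(8, 8, device="cuda"))
g.replay()  # static_y now reflects the new contents of static_x
```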
@@ -366,28 +368,40 @@ def _make_graphed_callables(
     ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."
 
     # Filter the TE modules that cudagraph can access.
-    visited_te_modules = set()
-
-    def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
-        if isinstance(module, TransformerEngineBaseModule):
-            visited_te_modules.add(module)
-        # If forward is called on a BasicOperation directly the hook will run
-        elif isinstance(module, BasicOperation):
-            visited_te_modules.add(module)
-        # If forward is called on a te.ops.Sequential it is not called on its constituent ops
-        elif isinstance(module, Sequential):
-            assert module._module_groups is not None, "Should have been initialized by warmup"
-            for module_group in module._module_groups:
-                if isinstance(module_group, OperationFuser):
-                    for basic_op in module_group._basic_ops:
-                        visited_te_modules.add(basic_op)
+    visited_te_modules = {}
+    need_bwd_dw_graph = {}
 
     # Run warmup and do the above filtering.
     with torch.cuda.stream(torch.cuda.Stream()):
         for func_idx, func in zip(warmup_func_idx, warmup_func):
             args = sample_args[func_idx]
             kwargs = sample_kwargs[func_idx]
             static_input_surface = per_callable_static_input_surfaces[func_idx]
+
+            def hook_fn(
+                module, inputs, outputs, func_idx=func_idx
+            ):  # pylint: disable=unused-argument
+                modules = set()
+                if isinstance(module, TransformerEngineBaseModule):
+                    modules.add(module)
+                # If forward is called on a BasicOperation directly the hook will run
+                elif isinstance(module, BasicOperation):
+                    modules.add(module)
+                # If forward is called on a te.ops.Sequential it is not called on its constituent ops
+                elif isinstance(module, Sequential):
+                    assert (
+                        module._module_groups is not None
+                    ), "Should have been initialized by warmup"
+                    for module_group in module._module_groups:
+                        if isinstance(module_group, OperationFuser):
+                            for basic_op in module_group._basic_ops:
+                                modules.add(basic_op)
+                if modules:
+                    if func_idx not in visited_te_modules:
+                        visited_te_modules[func_idx] = modules
+                    else:
+                        visited_te_modules[func_idx].update(modules)
+
             for warmup_iter in range(num_warmup_iters):
                 hooks = []
                 for module in func.modules():
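Note: moving `hook_fn` inside the per-callable loop and binding `func_idx=func_idx` as a default argument is what makes the per-callable bookkeeping work; a plain closure would read the loop variable lazily, so every hook would see its final value. A two-line illustration of that pitfall:

```python
# Closures read loop variables late; default arguments capture them at definition time.
late = [lambda: i for i in range(3)]
bound = [lambda i=i: i for i in range(3)]
print([f() for f in late])   # [2, 2, 2]
print([f() for f in bound])  # [0, 1, 2]
```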
@@ -432,6 +446,15 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
                             module_params_with_grad
                         )
                         per_callable_static_input_surfaces[func_idx] = static_input_surface
+
+                    # Run wgrad. This is essential for some TE modules when they have
+                    # delay_wgrad_compute enabled.
+                    need_backward_dw = False
+                    for module in visited_te_modules.get(func_idx, set()):
+                        if hasattr(module, "need_backward_dw") and module.need_backward_dw():
+                            need_backward_dw = True
+                            module.backward_dw()
+                    need_bwd_dw_graph[func_idx] = need_backward_dw
                 else:
                     grad_inputs = None
                 del outputs, grad_inputs
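Note: the flush above targets modules that defer weight-gradient (wgrad) work; such a module reports pending work via `need_backward_dw()` and performs it in `backward_dw()`. A hedged sketch of that flow outside of CUDA graphs; whether `delay_wgrad_compute` is a constructor argument of `te.Linear` in your TE version is an assumption here:

```python
import torch
import transformer_engine.pytorch as te

# Assumption: this kwarg defers wgrad until backward_dw() is called explicitly.
layer = te.Linear(1024, 1024, delay_wgrad_compute=True)

x = torch.randn(32, 1024, device="cuda", requires_grad=True)
layer(x).sum().backward()  # dgrad runs here; wgrad is deferred

# The same duck-typed check the warmup code performs:
if hasattr(layer, "need_backward_dw") and layer.need_backward_dw():
    layer.backward_dw()  # compute the deferred weight gradients
```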
@@ -514,6 +537,17 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
                         allow_unused=allow_unused_input,
                         retain_graph=retain_graph_in_backward,
                     )
+                # If no module needs backward_dw, the bwd_dw_graph would be empty,
+                # so skip capturing it.
+                if need_bwd_dw_graph[per_callable_bwd_idx]:
+                    bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
+                    with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
+                        for module in visited_te_modules[per_callable_bwd_idx]:
+                            if (
+                                hasattr(module, "need_backward_dw")
+                                and module.need_backward_dw()
+                            ):
+                                module.backward_dw()
                 # Constructs a tuple suitable for returning from Graphed.backward:
                 # Pads out the actually-needed grads with Nones in gradient slots for inputs
                 # that don't require grad. I couldn't think of a one-liner for this pattern.
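Note: the wgrad work is recorded into its own graph, captured into the same memory pool as the data-gradient graph, so that `backward_dw()` can later replay just that piece. `_graph_context_wrapper` is TE-internal; a sketch of the underlying PyTorch mechanism it is assumed to wrap (two graphs sharing one pool, replayed independently; warmup omitted for brevity):

```python
import torch

pool = torch.cuda.graph_pool_handle()  # shared pool, analogous to `mempool` above
g_dgrad, g_wgrad = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
x = torch.randn(8, 8, device="cuda")

with torch.cuda.graph(g_dgrad, pool=pool):
    y = torch.relu(x)  # stands in for the data-gradient capture

with torch.cuda.graph(g_wgrad, pool=pool):
    z = x * 2.0        # stands in for the deferred wgrad capture

g_dgrad.replay()       # replayed on every backward()
g_wgrad.replay()       # replayed only when backward_dw() is invoked
```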
@@ -582,10 +616,12 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
     # Capture backward graphs in reverse order
     per_callable_static_grad_outputs = []
     per_callable_static_grad_inputs = []
-    for static_input_surface, static_outputs, bwd_graph in zip(
+    for static_input_surface, static_outputs, bwd_graph, bwd_dw_graph, bwd_idx in zip(
         reversed(per_callable_static_input_surfaces),
         reversed(per_callable_static_outputs),
         reversed(bwd_graphs),
+        reversed(bwd_dw_graphs),
+        reversed(range(len(per_callable_static_input_surfaces))),
     ):
         # For now, assumes all static_outputs require grad
         static_grad_outputs = tuple(
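Note: `bwd_idx` comes from zipping in `reversed(range(len(...)))`, which keeps the original callable index aligned with the other reversed sequences while iterating backwards:

```python
xs = ["a", "b", "c"]
list(zip(reversed(xs), reversed(range(len(xs)))))  # [('c', 2), ('b', 1), ('a', 0)]
```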
@@ -601,6 +637,11 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
                 allow_unused=allow_unused_input,
                 retain_graph=retain_graph_in_backward,
             )
+        if need_bwd_dw_graph[bwd_idx]:
+            with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
+                for module in visited_te_modules[bwd_idx]:
+                    if hasattr(module, "need_backward_dw") and module.need_backward_dw():
+                        module.backward_dw()
         # Constructs a tuple suitable for returning from Graphed.backward:
         # Pads out the actually-needed grads with Nones in gradient slots for inputs that
         # don't require grad. I couldn't think of a slick one-liner for this pattern.
@@ -732,9 +773,10 @@ def functionalized(*user_args, **user_kwargs):
         )
 
         func = graph_callables[i]
+        te_modules = visited_te_modules.get(i, set())
         if isinstance(func, torch.nn.Module):
 
-            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
+            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd, te_modules):
                 def new_fwd(*user_args, **user_kwargs):
                     # If the module's training-or-eval state matches what we graphed,
                     # run the graph, otherwise run the original forward method
@@ -743,7 +785,7 @@ def new_fwd(*user_args, **user_kwargs):
                         if FP8GlobalStateManager.is_fp8_enabled():
                             fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
                             for m in func.modules():
-                                if m not in visited_te_modules:
+                                if m not in te_modules:
                                     # Only Set the FP8 meta for the modules included by forward
                                     continue
                                 if isinstance(m, TransformerEngineBaseModule):
@@ -780,7 +822,7 @@ def new_fwd(*user_args, **user_kwargs):
 
                 return new_fwd
 
-            forward = make_graphed_forward(func, func.training, graphed, func.forward)
+            forward = make_graphed_forward(func, func.training, graphed, func.forward, te_modules)
             if _order is None:
                 func.forward = forward
                 ret.append(func)
@@ -789,6 +831,16 @@ def new_fwd(*user_args, **user_kwargs):
         else:
             ret.append(graphed)
 
+        # Attach backward_dw as an attribute to the graphed callable.
+        def backward_dw(
+            need_backward_dw=need_bwd_dw_graph.get(i, False),
+            bwd_dw_graph=bwd_dw_graphs[i],
+        ):
+            if need_backward_dw:
+                bwd_dw_graph.replay()
+
+        setattr(ret[-1], "backward_dw", backward_dw)
+
     if just_one_callable:
         return ret[0]
 
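Note: with these changes each graphed callable carries a `backward_dw()` attribute that replays the captured wgrad graph (a no-op when no module deferred wgrad). A hedged end-to-end sketch; the exact `make_graphed_callables` arguments and the `delay_wgrad_compute` kwarg depend on the TE version:

```python
import torch
import transformer_engine.pytorch as te

model = te.Linear(1024, 1024, delay_wgrad_compute=True)  # assumed kwarg, see earlier note
sample_args = (torch.randn(32, 1024, device="cuda", requires_grad=True),)

graphed = te.make_graphed_callables(model, sample_args)

out = graphed(torch.randn(32, 1024, device="cuda", requires_grad=True))
out.sum().backward()   # replays the captured forward/backward graphs
graphed.backward_dw()  # replays the wgrad graph added by this change
```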