transformer_engine/pytorch/module (1 file changed: +8, -10 lines)

```diff
@@ -13,6 +13,7 @@
 
 from transformer_engine.common.recipe import Recipe
 from .base import (
+    get_dummy_wgrad,
     get_multi_stream_cublas_workspace,
     TransformerEngineBaseModule,
     _2X_ACC_FPROP,
@@ -447,18 +448,15 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
                 ):
                     weight.grad_added_to_main_grad = True
                     if getattr(weight, "zero_out_wgrad", False):
-                        wgrad = torch.zeros(
-                            weight.main_grad.shape,
-                            dtype=weight.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(weight.main_grad.shape),
+                            weight.dtype,
+                            zero=True,
                         )
                     else:
-                        wgrad = torch.empty(
-                            weight.main_grad.shape,
-                            dtype=weight.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(weight.main_grad.shape),
+                            weight.dtype,
                         )
                 elif ctx.fuse_wgrad_accumulation:
                     wgrad = None
```
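The new helper, get_dummy_wgrad, is imported from .base; its implementation lives in transformer_engine/pytorch/module/base.py and is not part of this diff. The sketch below is only a minimal illustration of the pattern the call sites imply (a reusable placeholder weight-grad buffer, optionally zero-filled). The signature matches how it is called above, but the cache, its key, and the zeroing policy are assumptions, not the library's actual code.

```python
# Minimal sketch of a cached dummy-wgrad helper, assuming the goal is to hand
# back a reusable placeholder tensor instead of allocating a fresh one on
# every backward pass. The cache and its key are illustrative assumptions.
import torch

_DUMMY_WGRAD_CACHE: dict = {}


def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero: bool = False) -> torch.Tensor:
    """Return a placeholder weight-grad tensor of the given shape and dtype.

    The values are not meaningful; when ``zero`` is True the buffer is
    zero-filled so callers that take the ``zero_out_wgrad`` path see zeros
    rather than uninitialized memory.
    """
    key = (tuple(shape), dtype)
    if key not in _DUMMY_WGRAD_CACHE:
        # Mirrors the allocation the diff removed (torch.empty on the
        # current CUDA device, no autograd tracking), but done only once
        # per (shape, dtype) instead of on every call.
        _DUMMY_WGRAD_CACHE[key] = torch.empty(
            shape,
            dtype=dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )
    buf = _DUMMY_WGRAD_CACHE[key]
    if zero:
        buf.zero_()
    return buf
```

Relative to the removed torch.zeros / torch.empty calls, reusing one buffer avoids allocating (and, on the zero_out_wgrad path, zero-filling) a main_grad-sized tensor on every backward pass, which appears to be the point of the change.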