Commit d2945c6

Autumn1998 and yaox12 authored
[PyTorch] Use dummy wgrad in GroupedLinear (#2305)
dummy wgrad

Signed-off-by: tongliu <[email protected]>
Signed-off-by: Xin Yao <[email protected]>
Co-authored-by: Xin Yao <[email protected]>
1 parent 87cb26c commit d2945c6

File tree

1 file changed (+8, -10 lines)

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 8 additions & 10 deletions
@@ -13,6 +13,7 @@

 from transformer_engine.common.recipe import Recipe
 from .base import (
+    get_dummy_wgrad,
     get_multi_stream_cublas_workspace,
     TransformerEngineBaseModule,
     _2X_ACC_FPROP,
@@ -447,18 +448,15 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
                 ):
                     weight.grad_added_to_main_grad = True
                     if getattr(weight, "zero_out_wgrad", False):
-                        wgrad = torch.zeros(
-                            weight.main_grad.shape,
-                            dtype=weight.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(weight.main_grad.shape),
+                            weight.dtype,
+                            zero=True,
                         )
                     else:
-                        wgrad = torch.empty(
-                            weight.main_grad.shape,
-                            dtype=weight.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(weight.main_grad.shape),
+                            weight.dtype,
                         )
                 elif ctx.fuse_wgrad_accumulation:
                     wgrad = None

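For context, the change swaps the per-call torch.zeros / torch.empty allocations for get_dummy_wgrad, a helper imported from .base that returns a placeholder weight-gradient tensor on the fused-wgrad-accumulation path, where the real gradient has already been accumulated into weight.main_grad and the returned tensor's contents are never read. The sketch below illustrates that caching pattern in isolation; the name get_dummy_wgrad_sketch, the module-level cache, and the CPU fallback are assumptions for the example, not the actual Transformer Engine implementation.

import torch

# Illustrative cache of dummy wgrad buffers, keyed by (shape, dtype). When
# wgrad accumulation is fused into weight.main_grad, autograd still expects a
# gradient tensor for the weight, but its values are never consumed, so one
# reusable buffer per shape/dtype avoids a fresh allocation every backward.
_DUMMY_WGRAD_CACHE = {}


def get_dummy_wgrad_sketch(shape, dtype, zero=False):
    """Return a cached placeholder gradient tensor (hypothetical helper)."""
    key = (tuple(shape), dtype)
    if key not in _DUMMY_WGRAD_CACHE:
        # Assumption: fall back to CPU so the sketch also runs without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        _DUMMY_WGRAD_CACHE[key] = torch.empty(
            shape, dtype=dtype, device=device, requires_grad=False
        )
    buf = _DUMMY_WGRAD_CACHE[key]
    if zero:
        # Callers that set zero_out_wgrad expect zeros, mirroring the old
        # torch.zeros branch in the diff above.
        buf.zero_()
    return buf


# Example call shaped like the new code path in the diff:
# wgrad = get_dummy_wgrad_sketch(list(weight.main_grad.shape), weight.dtype, zero=True)

Compared with the previous code, the zero_out_wgrad branch no longer needs its own torch.zeros allocation: a shared buffer can be zeroed on demand, which appears to be the purpose of the zero=True argument in the new call.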