Fix issues for MCore DDP. (#1474)
* Fix issues for MCore DDP.

Signed-off-by: Dennis Liu <[email protected]>

* Remove forced data release for CPU offloading.

Signed-off-by: Dennis Liu <[email protected]>

* Add preserved attributes.

Signed-off-by: Dennis Liu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add main_grad to preserved attributes.

Signed-off-by: Dennis Liu <[email protected]>

* Change prepare_for_saving to save the original tensor and add .data to the CPU offload hook.

Signed-off-by: Dennis Liu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update.

Signed-off-by: Dennis Liu <[email protected]>

* Fix for LayerNormLinear in FP8.

Signed-off-by: Dennis Liu <[email protected]>

---------

Signed-off-by: Dennis Liu <[email protected]>
Co-authored-by: Xin Yao <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Feb 19, 2025
1 parent 6673f16 commit 978f1d7
Showing 4 changed files with 27 additions and 25 deletions.
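For context: Megatron-Core's custom DDP does not read `param.grad`. It pre-allocates a `main_grad` buffer on each parameter, expects the backward pass to accumulate the weight gradient into that buffer, and uses `grad_added_to_main_grad` (plus the optional `zero_out_wgrad`) to track whether that happened. The diffs below read exactly these attributes; the snippet is a minimal illustrative mock of that contract, not MCore's implementation.

```python
import torch

# Illustrative mock of the MCore DDP parameter contract the diffs below rely on.
# The attribute names (main_grad, grad_added_to_main_grad, zero_out_wgrad) appear in
# the changed code; everything else here is a simplified stand-in.
weight = torch.nn.Parameter(torch.randn(16, 16))

weight.main_grad = torch.zeros_like(weight, dtype=torch.float32)  # pre-allocated grad buffer
weight.grad_added_to_main_grad = False  # set by the module once it has accumulated wgrad
weight.zero_out_wgrad = False           # if True, the module returns an all-zero dummy wgrad

def fused_wgrad_step(param: torch.nn.Parameter, wgrad: torch.Tensor) -> None:
    """Stand-in for a fused GEMM that writes the weight gradient into main_grad."""
    param.main_grad += wgrad
    param.grad_added_to_main_grad = True

fused_wgrad_step(weight, torch.ones_like(weight))
assert weight.grad_added_to_main_grad and weight.main_grad.abs().sum() > 0
```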
4 changes: 3 additions & 1 deletion transformer_engine/pytorch/cpu_offload.py
@@ -137,7 +137,9 @@ def __init__(
         super().__init__()

     def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
-        retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs)
+        retrieve_identifier = self.offload_handler.tensor_push(
+            tensor.data, **self.handler_extra_kwargs
+        )
         return retrieve_identifier

     def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
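The hook now pushes `tensor.data`, so the offload handler sees the plain storage rather than the autograd-tracked wrapper. Below is a self-contained sketch of the same idea built on PyTorch's public `torch.autograd.graph.saved_tensors_hooks`; the `pack_to_cpu`/`unpack_from_cpu` names are hypothetical and this is not the hook class in cpu_offload.py.

```python
import torch

def pack_to_cpu(tensor: torch.Tensor):
    # Copy the raw storage (tensor.data) to CPU; remember device/dtype to restore later.
    return tensor.data.to("cpu"), tensor.device, tensor.dtype

def unpack_from_cpu(packed):
    cpu_copy, device, dtype = packed
    # Move the saved activation back to its original device for the backward pass.
    return cpu_copy.to(device=device, dtype=dtype)

x = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = x.pow(2).sum()  # pow saves x for backward, so the hooks fire
y.backward()
assert torch.allclose(x.grad, 2 * x)
```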
19 changes: 12 additions & 7 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -441,7 +441,7 @@ def backward(
         (  # pylint: disable=unbalanced-tuple-unpacking
             inputmat,
             weight,
-            _,
+            origin_weight,
             bias,
             ln_weight,
             ln_out,
@@ -722,17 +722,22 @@ def backward(

         if ctx.requires_wgrad:
             # Handle custom DDP from mcore.
-            if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"):
-                weight.grad_added_to_main_grad = True
-                if getattr(weight, "zero_out_wgrad", False):
+            if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"):
+                origin_weight.grad_added_to_main_grad = True
+                if getattr(origin_weight, "zero_out_wgrad", False):
                     wgrad = torch.zeros(
-                        weight.main_grad.shape,
-                        dtype=weight.dtype,
+                        origin_weight.main_grad.shape,
+                        dtype=origin_weight.dtype,
                         device=torch.cuda.current_device(),
                         requires_grad=False,
                     )
                 else:
-                    wgrad = None
+                    wgrad = torch.empty(
+                        origin_weight.main_grad.shape,
+                        dtype=origin_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
             elif ctx.fuse_wgrad_accumulation:
                 wgrad = None
             else:
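The substantive change: when the weight gradient has already been accumulated into `origin_weight.main_grad` by the fused path, `backward` now returns a dummy `torch.empty` tensor of the right shape instead of `None`, so autograd and MCore DDP's gradient hooks still receive a tensor for the weight even though its values are never read. A standalone toy illustration of the pattern follows (not the TE kernels; the `FusedWgradLinear` name is made up).

```python
import torch

class FusedWgradLinear(torch.autograd.Function):
    """Toy linear op mimicking the fused-wgrad-accumulation pattern in this diff."""

    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp)
        ctx.weight = weight  # keep the Parameter object so its attributes stay visible
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        (inp,) = ctx.saved_tensors
        weight = ctx.weight
        dgrad = grad_output @ weight

        # "Fused" accumulation: the weight gradient goes straight into main_grad.
        weight.main_grad += grad_output.t() @ inp
        weight.grad_added_to_main_grad = True

        # Return a dummy wgrad with matching shape/dtype instead of None, so autograd
        # and any DDP grad hooks still see a tensor; its values are never used.
        dummy_wgrad = torch.empty(
            weight.main_grad.shape, dtype=weight.dtype, requires_grad=False
        )
        return dgrad, dummy_wgrad

weight = torch.nn.Parameter(torch.randn(8, 4))
weight.main_grad = torch.zeros_like(weight)
weight.grad_added_to_main_grad = False

x = torch.randn(2, 4, requires_grad=True)
FusedWgradLinear.apply(x, weight).sum().backward()
print(weight.grad_added_to_main_grad)  # True: the real gradient is in weight.main_grad
print(weight.grad.shape)               # torch.Size([8, 4]): dummy gradient, values unused
```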
7 changes: 6 additions & 1 deletion transformer_engine/pytorch/module/linear.py
@@ -606,7 +606,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                         requires_grad=False,
                     )
                 else:
-                    wgrad = None
+                    wgrad = torch.empty(
+                        weight.main_grad.shape,
+                        dtype=weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
             elif ctx.fuse_wgrad_accumulation:
                 wgrad = None
             else:
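linear.py receives the same fix as layernorm_linear.py. In a training loop built around this contract, the optimizer consumes `main_grad` rather than `.grad`; the sketch below shows one hypothetical way to bridge the two for a stock optimizer (MCore's distributed optimizer does this differently), with `apply_main_grad` being an invented helper.

```python
import torch

def apply_main_grad(params) -> None:
    """Hypothetical bridge: expose the fused-accumulated main_grad to a stock optimizer."""
    for p in params:
        if getattr(p, "grad_added_to_main_grad", False):
            p.grad = p.main_grad.clone()   # stock optimizers read .grad
        p.grad_added_to_main_grad = False  # rearm the flag for the next iteration
        p.main_grad.zero_()

weight = torch.nn.Parameter(torch.randn(8, 4))
weight.main_grad = torch.zeros_like(weight)
weight.grad_added_to_main_grad = False
opt = torch.optim.SGD([weight], lr=0.1)

# Pretend a fused backward accumulated into main_grad and left .grad as a dummy.
weight.main_grad += torch.ones_like(weight)
weight.grad_added_to_main_grad = True

apply_main_grad([weight])
opt.step()
opt.zero_grad(set_to_none=True)
```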
22 changes: 6 additions & 16 deletions transformer_engine/pytorch/tensor/quantized_tensor.py
@@ -28,7 +28,7 @@ def prepare_for_saving(
            tensor_list.append(None)
            tensor_objects_list.append(None)
        elif type(tensor) in (torch.Tensor, torch.nn.Parameter):
-            tensor_list.append(tensor.data)
+            tensor_list.append(tensor)
            tensor_objects_list.append(None)
        else:
            t, t_obj = tensor.prepare_for_saving()
@@ -116,10 +116,7 @@ def update_quantized(
         """Quantize tensor in-place"""

     def quantize(
-        self,
-        tensor: torch.Tensor,
-        *,
-        out: Optional[QuantizedTensor] = None,
+        self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None
     ) -> QuantizedTensor:
         """Quantize tensor"""
         if out is not None:
@@ -159,10 +156,7 @@ def calibrate(self, tensor: torch.Tensor) -> None:
         """

     def set_usage(
-        self,
-        *,
-        rowwise: Optional[bool] = None,
-        columnwise: Optional[bool] = None,
+        self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None
     ) -> None:
         """Set how the quantized tensor is expected to be used
@@ -194,8 +188,7 @@ def forward(

     @staticmethod
     def backward(
-        _ctx: torch.autograd.function.FunctionCtx,  # unused
-        grad: torch.Tensor,
+        _ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor  # unused
     ) -> Tuple[Optional[torch.Tensor], ...]:
         # pylint: disable=missing-function-docstring
         # Assume that we want gradients in full precision
@@ -212,9 +205,7 @@ class _IdentityFunc(torch.autograd.Function):

     @staticmethod
     def forward(
-        ctx,
-        tensor: QuantizedTensor,
-        init_kwargs: Optional[Dict[str, Any]] = None,
+        ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None
     ) -> QuantizedTensor:
         # pylint: disable=missing-function-docstring

@@ -408,8 +399,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
         return torch._C._disabled_torch_function_impl(func, types, args, kwargs)

     def contiguous(
-        self,
-        memory_format: torch.memory_format = torch.contiguous_format,
+        self, memory_format: torch.memory_format = torch.contiguous_format
     ) -> QuantizedTensor:
         # pylint: disable=missing-function-docstring
         raise NotImplementedError(
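`prepare_for_saving` now stores the tensor object itself rather than `tensor.data`. Together with the cpu_offload.py change above, the offload path still handles plain storage via `.data`, while tensors saved for backward keep Python attributes such as `main_grad` that MCore DDP attaches. The snippet below is a simplified, self-contained sketch of the save/restore round trip in the same spirit, not the actual TE helpers or their signatures.

```python
from typing import Any, List, Optional, Tuple
import torch

def prepare_for_saving(*tensors) -> Tuple[List[Optional[torch.Tensor]], List[Any]]:
    """Split tensors into a ctx.save_for_backward payload plus side objects (simplified)."""
    tensor_list, tensor_objects = [], []
    for tensor in tensors:
        if tensor is None:
            tensor_list.append(None)
            tensor_objects.append(None)
        elif type(tensor) in (torch.Tensor, torch.nn.Parameter):
            # Save the tensor itself (not tensor.data) so attributes such as
            # main_grad / grad_added_to_main_grad survive the round trip.
            tensor_list.append(tensor)
            tensor_objects.append(None)
        else:
            raise TypeError(f"unsupported type in this simplified sketch: {type(tensor)}")
    return tensor_list, tensor_objects

def restore_from_saved(tensor_objects, saved):
    """Reassemble the original arguments from the saved payload (simplified)."""
    return [obj if obj is not None else t for obj, t in zip(tensor_objects, saved)]

weight = torch.nn.Parameter(torch.randn(4, 4))
weight.main_grad = torch.zeros_like(weight)

saved, objs = prepare_for_saving(weight, None)
restored = restore_from_saved(objs, saved)
print(hasattr(restored[0], "main_grad"))  # True: the attribute is preserved
```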
