Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delete extra tensor objects after restoring float8 tensors #1500

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,9 @@ def backward(
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,10 @@ def backward(
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
fc1_weight_main_grad = (
ctx.fc1_main_grad
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking
restore_from_saved(ctx.tensor_objects, saved_tensors)
)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8Tensor
"""
tensors = [self._data, self._transpose]
# self._data = None
# self._transpose = None
self._data = None
self._transpose = None
Comment on lines +108 to +109
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pggPL IIRC you removed these during a numerics debugging effort, do you remember why?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If weight is in fp8 I want to have it in save_for_backward() - for offloading. If there is forward, but backward is not invoked, it will result in removing the weight. I discussed it with @ptrendx and he proposed some solution with flag internal in prepare_for_saving - to set it True if tensor is not owned and remove tensors iff they are internal. It seems that we forgot about this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, so that would be solved by overriding this function in Float8Tensor and MXFP8Tensor to just return self and None instead.

Also, in https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/tensor/quantized_tensor.py#L30 why do we check for exactly Tensor or Param and not just isinstance(torch.Tensor)? This should solve this as well, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now logic of restoring tensor is inside the tensor object. If tensor object is None, we assume that this was standard torch.tensor. If it is QuantizedTensor, then it object is responsible for restoring itself, so we need to somehow save it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, QuantizedTensor is in a way a standard tensor - at least it can be passed whole through save_for_backward, so there is nothing to restore afterwards.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, it makes sense

return tensors, self

def restore_from_saved(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorB
"""
tensors = [self._rowwise_data, self._columnwise_data]
# self._rowwise_data = None
# self._columnwise_data = None
self._rowwise_data = None
self._columnwise_data = None
return tensors, self

def restore_from_saved(
Expand Down
9 changes: 9 additions & 0 deletions transformer_engine/pytorch/tensor/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,15 @@ def clear(self):
self._transpose = torch.Tensor() if self._transpose is not None else None
self._transpose_invalid = True

# def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]:
# """Prepare the tensor base for saving for backward

# After calling this, the tensor instance does not hold any
# data.

# """
# return [self], None

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):

Expand Down
9 changes: 9 additions & 0 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,15 @@ def clear(self):
self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None

# def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]:
# """Prepare the tensor base for saving for backward

# After calling this, the tensor instance does not hold any
# data.

# """
# return [self], None

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):

Expand Down