[PyTorch] Normalization ops (#1033)
* Add layer norm op

Signed-off-by: Tim Moon <[email protected]>

* Add FP8 cast op

Signed-off-by: Tim Moon <[email protected]>

* Add tests for linear and layernorm with FP8 output

Signed-off-by: Tim Moon <[email protected]>

* RMSNorm op

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* Replace LayerNorm module with LayerNorm op

Signed-off-by: Tim Moon <[email protected]>

* Replace RMSNorm module with RMSNorm op

Signed-off-by: Tim Moon <[email protected]>

* Add AMP support

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Do not save autograd context if grad mode is disabled

Needed while debugging ONNX export tests.

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Forward args in pre_forward func to base op class

Signed-off-by: Tim Moon <[email protected]>

* Update to use QuantizedTensor class

Signed-off-by: Tim Moon <[email protected]>

* Apply suggestions from code review

Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Review suggestions from @ptrendx

Rename "CastFloat8" op to "Quantize". Add more fine-grained control for SM margin. Add docs for legacy sequence_parallel kwarg.

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* Use weight dtype as default compute dtype

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <[email protected]>
3 people authored Nov 5, 2024

1 parent f20d3dd commit 77c37d4
Showing 12 changed files with 1,416 additions and 511 deletions.
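
This change replaces the standalone LayerNorm/RMSNorm autograd functions with fusible operations under `transformer_engine.pytorch.ops` and turns the existing modules into thin wrappers around them. A minimal usage sketch, assuming a CUDA build of Transformer Engine and that basic ops are directly callable like modules (names follow the imports added below; exact call behavior is an assumption):

import torch
import transformer_engine.pytorch as te

hidden = 1024

# New fusible-op API
norm_op = te.ops.LayerNorm(hidden, device="cuda", dtype=torch.bfloat16)
x = torch.randn(16, 128, hidden, device="cuda", dtype=torch.bfloat16)
y = norm_op(x)  # forward pass through the fusible-op machinery

# The legacy module API keeps working and now forwards to the same op
norm_module = te.LayerNorm(hidden, device="cuda", dtype=torch.bfloat16)
z = norm_module(x)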
515 changes: 411 additions & 104 deletions tests/pytorch/test_fusible_ops.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions transformer_engine/pytorch/__init__.py
@@ -82,6 +82,7 @@ def _load_library():
from transformer_engine.pytorch.distributed import checkpoint
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers

# Register custom op symbolic ONNX functions
294 changes: 112 additions & 182 deletions transformer_engine/pytorch/module/layernorm.py
@@ -3,158 +3,90 @@
# See LICENSE for license information.

"""LayerNorm API"""
import os
import warnings
from typing import Union, Tuple, Optional
from typing import Iterable, Optional, Union

import torch
from torch.nn.parameter import Parameter
from torch.nn import init

import transformer_engine_torch as tex
from ..cpp_extensions import (
layernorm_fwd_inf,
)
from ..jit import no_torch_dynamo
from ..utils import cast_if_needed
from transformer_engine.pytorch.ops import LayerNorm as _LayerNormOp

__all__ = ["LayerNorm"]


class _LayerNorm(torch.autograd.Function):
"""functional LayerNorm"""

@staticmethod
def forward(
ctx,
inp: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: torch.Tensor,
eps: float,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
inf_ln_sm_margin: int,
zero_centered_gamma: bool,
is_grad_enabled: bool,
activation_dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert inp.shape[-1] == in_features, "LayerNorm not possible"
inputmat = inp.view((-1, in_features))

# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
ln_bias = cast_if_needed(ln_bias, activation_dtype)

if is_grad_enabled:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
ctx.inp_shape = inp.shape
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
else:
ln_out, mu, rsigma = (
layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma
),
None,
None,
)
return ln_out.view_as(inp)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous()
d_ln_out = grad_output.view(inputmat.shape)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None, None

class LayerNorm(_LayerNormOp):
r"""Layer Normalization
class LayerNorm(torch.nn.Module):
r"""
Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
size :attr:`hidden_size`
:math:`\gamma` and :math:`\beta` are learnable affine transform
parameters that match the inner-most dimensions of the input
tensor.
Parameters
----------
hidden_size : int
size of each input sample.
normalized_shape: int or iterable of int
Inner dimensions of input tensor
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
params_dtype : torch.dtype, default = `torch.get_default_dtype()`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
A value added to the denominator of layer normalization for
numerical stability
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
If `True`, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
sm_margin: int or dict, default = 0
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward",
"inference").
Legacy
------
sequence_parallel: bool
Set a bool attr named `sequence_parallel` in the parameters.
This is custom logic for Megatron-LM integration.
"""

def __init__(
self,
hidden_size: int,
normalized_shape: Union[Iterable[int], int],
eps: float = 1e-5,
sequence_parallel: bool = False,
params_dtype: Optional[torch.dtype] = None,
sequence_parallel: Optional[bool] = None, # legacy
params_dtype: Optional[torch.dtype] = None, # deprecated
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
**kwargs,
) -> None:
super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.eps = eps
self.zero_centered_gamma = zero_centered_gamma
self.weight = Parameter(
torch.empty(
hidden_size,
device=device,
dtype=params_dtype,
)
)
self.bias = Parameter(
torch.empty(
hidden_size,
device=device,
dtype=params_dtype,
)
)
self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None

self.reset_parameters(defer_init=device == "meta")
# Handle deprecated options
if params_dtype is not None:
if "dtype" in kwargs:
raise RuntimeError(
"Both `dtype` and `params_dtype` (deprecated) kwargs are provided"
)
kwargs["dtype"] = params_dtype

# Initialize layer norm operation
super().__init__(
normalized_shape,
eps=eps,
zero_centered_gamma=zero_centered_gamma,
**kwargs,
)

# These many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
# Flag for sequence parallelism (custom Megatron-LM integration)
self.sequence_parallel: Optional[bool] = sequence_parallel

def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
@@ -164,64 +96,62 @@ def reset_layer_norm_parameters(self) -> None:
DeprecationWarning,
stacklevel=2,
)
if not self.zero_centered_gamma:
init.ones_(self.weight)
else:
init.zeros_(self.weight)
init.zeros_(self.bias)
self.reset_parameters()

def reset_parameters(self, defer_init=False) -> None:
def reset_parameters(self, defer_init: Optional[bool] = None) -> None:
"""Init LayerNorm parameters"""
if defer_init:
return

if self.weight.device == torch.device("meta"):
self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda"))
setattr(self.weight, "sequence_parallel", self.sequence_parallel)
init.constant_(self.weight, float(not self.zero_centered_gamma))

if self.bias.device == torch.device("meta"):
self.bias = torch.nn.Parameter(torch.empty_like(self.bias, device="cuda"))
setattr(self.bias, "sequence_parallel", self.sequence_parallel)
init.zeros_(self.bias)

@no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
# pylint: disable=missing-function-docstring

# Set the activation type for AMP.
# Note: This will soon be deprecated with
# https://github.com/NVIDIA/TransformerEngine/pull/1033
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
elif self.activation_dtype != inp.dtype:
dtype = inp.dtype
for name, param in self.named_parameters():
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype

if torch.is_grad_enabled():
fwd_fn = _LayerNorm.apply
args = []
else:
fwd_fn = _LayerNorm.forward
args = [None]

args += (
inp,
self.weight,
self.bias,
self.eps,
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.inf_ln_sm_margin,
self.zero_centered_gamma,
torch.is_grad_enabled(),
self.activation_dtype,
)

return fwd_fn(*args)
# Check whether to defer init (deprecated)
if defer_init is not None:
warnings.warn(
"defer_init argument to reset_parameters function is deprecated. Set device to"
' "meta" instead.',
DeprecationWarning,
stacklevel=2,
)
if defer_init:
return

# Reset parameters
super().reset_parameters()

# Set flag for sequence parallelism (custom Megatron-LM integration)
if getattr(self, "sequence_parallel", None) is not None:
self.weight.sequence_parallel = self.sequence_parallel
self.bias.sequence_parallel = self.sequence_parallel

@property
def fwd_ln_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["forward"]

@fwd_ln_sm_margin.setter
def fwd_ln_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["forward"] = val

@property
def bwd_ln_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["backward"]

@bwd_ln_sm_margin.setter
def bwd_ln_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["backward"] = val

@property
def inf_ln_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["inference"]

@inf_ln_sm_margin.setter
def inf_ln_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["inference"] = val
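
For reference, a sketch of the backward-compatibility surface of the rewritten `LayerNorm` module, based on the constructor above (behavior of the deprecated kwargs follows the diff; the example itself is illustrative):

import torch
from transformer_engine.pytorch import LayerNorm

# Old-style call: `params_dtype` is deprecated but still accepted and is
# forwarded to the underlying op as `dtype`.
ln_legacy = LayerNorm(1024, eps=1e-5, params_dtype=torch.float32)

# New-style call: `normalized_shape` may be an int or an iterable of ints,
# and extra kwargs (e.g. a per-stage SM margin) are passed through to the op.
ln_new = LayerNorm(
    (1024,),
    zero_centered_gamma=True,
    sm_margin={"forward": 4, "backward": 4},
)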
297 changes: 117 additions & 180 deletions transformer_engine/pytorch/module/rmsnorm.py
@@ -3,221 +3,158 @@
# See LICENSE for license information.

"""RMSNorm API"""
import os
import warnings
from typing import Union, Tuple, Optional
from typing import Iterable, Optional, Union

import torch
from torch.nn.parameter import Parameter
from torch.nn import init

from .. import cpp_extensions as tex
from ..jit import no_torch_dynamo
from ..utils import cast_if_needed

from transformer_engine.pytorch.ops import RMSNorm as _RMSNormOp

__all__ = ["RMSNorm"]


class _RMSNorm(torch.autograd.Function):
"""functional RMSNorm"""

@staticmethod
def forward(
ctx,
inp: torch.Tensor,
rmsnorm_weight: torch.Tensor,
eps: float,
fwd_rmsnorm_sm_margin: int,
bwd_rmsnorm_sm_margin: int,
inf_rmsnorm_sm_margin: int,
zero_centered_gamma: bool,
is_grad_enabled: bool,
activation_dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible
in_features = rmsnorm_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert inp.shape[-1] == in_features, "RMSNorm not possible"
inputmat = inp.view((-1, in_features))

# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
rmsnorm_weight = cast_if_needed(rmsnorm_weight, activation_dtype)

if is_grad_enabled:
rmsnorm_out, rsigma = tex.rmsnorm_fwd(
inputmat, rmsnorm_weight, eps, fwd_rmsnorm_sm_margin, zero_centered_gamma
)
ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma)
ctx.inp_shape = inp.shape
ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
else:
rmsnorm_out = tex.rmsnorm_fwd_inf(
inputmat, rmsnorm_weight, eps, inf_rmsnorm_sm_margin, zero_centered_gamma
)
return rmsnorm_out.view_as(inp)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
inputmat, rmsnorm_weight, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous()
d_rmsnorm_out = grad_output.view(inputmat.shape)
dxmat, dgamma = tex.rmsnorm_bwd(
d_rmsnorm_out,
inputmat,
rsigma,
rmsnorm_weight,
ctx.bwd_rmsnorm_sm_margin,
ctx.zero_centered_gamma,
)
return (
dxmat.view(ctx.inp_shape),
dgamma,
None,
None,
None,
None,
None,
None,
None,
)

class RMSNorm(_RMSNormOp):
r"""Root Mean Square Layer Normalization
class RMSNorm(torch.nn.Module):
r"""
Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in
the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
Applies Root Mean Square Layer Normalization over a mini-batch of
inputs as described in the paper
`Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
.. math::
y = \frac{x}{RMS_\varepsilon(x)} * \gamma
y = \frac{x}{\text{RMS}_\varepsilon(x)} * \gamma
where
.. math::
RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^nx_i^2 + \varepsilon}
\text{RMS}_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^n x_i^2 + \varepsilon}
:math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size`
:math:`\gamma` is a learnable affine transform parameter that
matches the inner-most dimensions of the input tensor.
Parameters
----------
hidden_size : int
size of each input sample.
normalized_shape: int or iterable of int
Inner dimensions of input tensor
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
params_dtype : torch.dtype, default = `torch.get_default_dtype()`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
A value added to the denominator for numerical stability
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in RMSNorm is initialized to 0 and
the RMSNorm formula changes to
.. math::
y = \frac{x}{RMS_\varepsilon(x)} * (1 + \gamma)
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
If `True`, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to
.. math::
y = \frac{x}{\text{RMS}_\varepsilon(x)} * (1 + \gamma)
sm_margin: int or dict, default = 0
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward",
"inference").
Legacy
------
sequence_parallel: bool
Set a bool attr named `sequence_parallel` in the parameters.
This is custom logic for Megatron-LM integration.
"""

def __init__(
self,
hidden_size: int,
normalized_shape: Union[Iterable[int], int],
eps: float = 1e-5,
sequence_parallel: bool = False,
params_dtype: Optional[torch.dtype] = None,
sequence_parallel: Optional[bool] = None, # legacy
params_dtype: Optional[torch.dtype] = None, # deprecated
zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda",
**kwargs,
) -> None:
super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.eps = eps
self.zero_centered_gamma = zero_centered_gamma
self.weight = Parameter(
torch.empty(
hidden_size,
device=device,
dtype=params_dtype,
)
)
self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None

self.reset_parameters(defer_init=device == "meta")
# Handle deprecated options
if params_dtype is not None:
if "dtype" in kwargs:
raise RuntimeError(
"Both `dtype` and `params_dtype` (deprecated) kwargs are provided"
)
kwargs["dtype"] = params_dtype

# Initialize RMSNorm operation
super().__init__(
normalized_shape,
eps=eps,
zero_centered_gamma=zero_centered_gamma,
**kwargs,
)

# These many SMs are subtracted from the total SM count when calling forward
# and backward RMSNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with RMSNorm.
self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_rmsnorm_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
# Flag for sequence parallelism (custom Megatron-LM integration)
self.sequence_parallel: Optional[bool] = sequence_parallel

def reset_rms_norm_parameters(self) -> None:
"""Init RMSNorm params"""
"""Deprecated"""
warnings.warn(
"This method is deprecated and will be removed in an upcoming release. "
"Update your code to use RMSNorm.reset_parameters() instead.",
DeprecationWarning,
stacklevel=2,
)
if not self.zero_centered_gamma:
init.ones_(self.weight)
else:
init.zeros_(self.weight)

def reset_parameters(self, defer_init=False) -> None:
"""Reset RMSNorm parameters"""
if defer_init:
return

if self.weight.device == torch.device("meta"):
self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda"))
init.constant_(self.weight, float(not self.zero_centered_gamma))
setattr(self.weight, "sequence_parallel", self.sequence_parallel)

@no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
# pylint: disable=missing-function-docstring

# Set the activation type for AMP.
# Note: This will soon be deprecated with
# https://github.com/NVIDIA/TransformerEngine/pull/1033
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
elif self.activation_dtype != inp.dtype:
dtype = inp.dtype
for name, param in self.named_parameters():
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype

if torch.is_grad_enabled():
fwd_fn = _RMSNorm.apply
args = []
else:
fwd_fn = _RMSNorm.forward
args = [None]

args += (
inp,
self.weight,
self.eps,
self.fwd_rmsnorm_sm_margin,
self.bwd_rmsnorm_sm_margin,
self.inf_rmsnorm_sm_margin,
self.zero_centered_gamma,
torch.is_grad_enabled(),
self.activation_dtype,
)

return fwd_fn(*args)
self.reset_parameters()

def reset_parameters(self, defer_init: Optional[bool] = None) -> None:
"""Init RMSNorm parameters"""

# Check whether to defer init (deprecated)
if defer_init is not None:
warnings.warn(
"defer_init argument to reset_parameters function is deprecated. Set device to"
' "meta" instead.',
DeprecationWarning,
stacklevel=2,
)
if defer_init:
return

# Reset parameters
super().reset_parameters()

# Flag for sequence parallelism (custom Megatron-LM integration)
if getattr(self, "sequence_parallel", None) is not None:
self.weight.sequence_parallel = self.sequence_parallel

@property
def fwd_rmsnorm_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("fwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["forward"]

@fwd_rmsnorm_sm_margin.setter
def fwd_rmsnorm_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("fwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["forward"] = val

@property
def bwd_rmsnorm_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("bwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["backward"]

@bwd_rmsnorm_sm_margin.setter
def bwd_rmsnorm_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("bwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["backward"] = val

@property
def inf_rmsnorm_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("inf_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["inference"]

@inf_rmsnorm_sm_margin.setter
def inf_rmsnorm_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("inf_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["inference"] = val
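
As a plain-PyTorch reference for the formula in the docstring above (this is not the TE kernel, just a restatement of the math for clarity):

import torch

def rmsnorm_reference(x, weight, eps=1e-5, zero_centered_gamma=False):
    # RMS_eps(x) = sqrt(mean(x^2) + eps), reduced over the inner dimension
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    gamma = 1 + weight if zero_centered_gamma else weight
    return x / rms * gamma

x = torch.randn(4, 1024)
w = torch.ones(1024)
y = rmsnorm_reference(x, w)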
12 changes: 1 addition & 11 deletions transformer_engine/pytorch/ops/__init__.py
@@ -8,17 +8,7 @@
"""

from transformer_engine.pytorch.ops.basic import (
AddInPlace,
AllGather,
AllReduce,
BasicLinear,
Bias,
Identity,
MakeExtraOutput,
ReduceScatter,
Reshape,
)
from transformer_engine.pytorch.ops.basic import *
from transformer_engine.pytorch.ops.linear import Linear
from transformer_engine.pytorch.ops.op import FusibleOperation
from transformer_engine.pytorch.ops.sequential import Sequential
22 changes: 21 additions & 1 deletion transformer_engine/pytorch/ops/_common.py
@@ -56,6 +56,8 @@ def convert_tensor(
if memory_format != torch.preserve_format and not data.is_contiguous(
memory_format=memory_format
):
# Note: torch.Tensor.to ignores memory_format kwarg (see
# https://github.com/pytorch/pytorch/issues/132020).
data = data.contiguous(memory_format=memory_format)
return Float8Tensor.make_like(
tensor,
@@ -65,7 +67,14 @@ def convert_tensor(
)

# Convert standard PyTorch tensor
return tensor.to(device=device, dtype=dtype, memory_format=memory_format)
tensor = tensor.to(device=device, dtype=dtype)
if memory_format != torch.preserve_format and not tensor.is_contiguous(
memory_format=memory_format
):
# Note: torch.Tensor.to ignores memory_format kwarg (see
# https://github.com/pytorch/pytorch/issues/132020).
tensor = tensor.contiguous(memory_format=memory_format)
return tensor


def reshape(
@@ -114,3 +123,14 @@ def reshape(

# Reshape standard PyTorch tensor
return tensor.view(shape)


def maybe_autocast_dtype(
*,
device_type: str = "cuda",
default_dtype: Optional[torch.dtype] = None,
) -> torch.dtype:
"""Get autocast dtype if enabled"""
if torch.is_autocast_enabled(device_type):
return torch.get_autocast_dtype(device_type)
return canonicalize_dtype(default_dtype)
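
The `convert_tensor` change above works around `torch.Tensor.to` ignoring the `memory_format` argument (pytorch/pytorch#132020) by calling `.contiguous(memory_format=...)` explicitly. A standalone sketch of the same pattern, independent of TE:

import torch

def to_with_memory_format(t, *, device=None, dtype=None,
                          memory_format=torch.preserve_format):
    t = t.to(device=device, dtype=dtype)  # .to() may ignore memory_format
    if memory_format != torch.preserve_format and not t.is_contiguous(
        memory_format=memory_format
    ):
        t = t.contiguous(memory_format=memory_format)
    return t

x = torch.randn(2, 3, 8, 8)
y = to_with_memory_format(x, dtype=torch.float16,
                          memory_format=torch.channels_last)
assert y.is_contiguous(memory_format=torch.channels_last)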
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/ops/basic/__init__.py
@@ -10,6 +10,9 @@
from .basic_linear import BasicLinear
from .bias import Bias
from .identity import Identity
from .layer_norm import LayerNorm
from .make_extra_output import MakeExtraOutput
from .quantize import Quantize
from .reduce_scatter import ReduceScatter
from .reshape import Reshape
from .rmsnorm import RMSNorm
317 changes: 317 additions & 0 deletions transformer_engine/pytorch/ops/basic/layer_norm.py
@@ -0,0 +1,317 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusable operation for Layer Normalization."""

from __future__ import annotations
from collections.abc import Iterable
import math
import os
from typing import Optional

import torch

from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...cpp_extensions import (
layernorm_fwd_fp8,
layernorm_fwd_fp8_inf,
layernorm_fwd_inf,
)
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape


class LayerNorm(BasicOperation):
r"""Layer Normalization
Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable affine transform
parameters that match the inner-most dimensions of the input
tensor.
Parameters
----------
normalized_shape: int or iterable of int
Inner dimensions of input tensor
eps : float, default = 1e-5
A value added to the denominator of layer normalization for
numerical stability
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
zero_centered_gamma : bool, default = 'False'
If `True`, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
sm_margin: int or dict, default = 0
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward",
"inference").
"""

def __init__(
self,
normalized_shape: Iterable[int] | int,
*,
eps: float = 1e-5,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype] = None,
zero_centered_gamma: bool = False,
sm_margin: int | dict[str, int] = 0,
) -> None:
super().__init__()
self.eps: float = eps
self.zero_centered_gamma: bool = zero_centered_gamma

# Parameter shape
if not isinstance(normalized_shape, Iterable):
normalized_shape = (normalized_shape,)
else:
normalized_shape = tuple(normalized_shape)
self._shape: tuple[int, ...] = normalized_shape

# Parameter device
defer_param_init = False
device = canonicalize_device(device)
if device.type == "meta":
defer_param_init = True
device = canonicalize_device(None)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
self.device: torch.device = device

# Initialize parameters if needed
dtype = canonicalize_dtype(dtype)
weight = torch.empty(
self._shape,
device="meta",
dtype=dtype,
)
bias = torch.empty(
self._shape,
device="meta",
dtype=dtype,
)
weight = torch.nn.Parameter(weight)
bias = torch.nn.Parameter(bias)
self.weight: torch.nn.Parameter
self.bias: torch.nn.Parameter
self.register_parameter("weight", weight)
self.register_parameter("bias", bias)
if not defer_param_init:
self.reset_parameters()

# Number of SMs to exclude when launching CUDA kernels
self._sm_margins: dict[str, int]
if isinstance(sm_margin, dict):

def getenv(name: str) -> int:
return int(os.getenv(name, "0"))

self._sm_margins = {
"forward": sm_margin.get("forward", getenv("NVTE_FWD_LAYERNORM_SM_MARGIN")),
"backward": sm_margin.get("backward", getenv("NVTE_BWD_LAYERNORM_SM_MARGIN")),
"inference": sm_margin.get("inference", getenv("NVTE_INF_LAYERNORM_SM_MARGIN")),
}
else:

def getenv(name: str) -> int:
return int(os.getenv(name, str(sm_margin)))

self._sm_margins = {
"forward": getenv("NVTE_FWD_LAYERNORM_SM_MARGIN"),
"backward": getenv("NVTE_BWD_LAYERNORM_SM_MARGIN"),
"inference": getenv("NVTE_INF_LAYERNORM_SM_MARGIN"),
}

def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""

# Make sure parameter is initialized
weight = self.weight
bias = self.bias
if weight.device.type != "cuda":
weight = torch.empty_like(weight, device=self.device)
else:
weight = weight.to(device=self.device)
if bias.device.type != "cuda":
bias = torch.empty_like(bias, device=self.device)
else:
bias = bias.to(device=self.device)

# Initialize values
if self.zero_centered_gamma:
torch.nn.init.zeros_(weight)
else:
torch.nn.init.ones_(weight)
torch.nn.init.zeros_(bias)

# Save updated parameter
if not isinstance(weight, torch.nn.Parameter):
weight = torch.nn.Parameter(weight)
if not isinstance(bias, torch.nn.Parameter):
bias = torch.nn.Parameter(bias)
self.weight = weight
self.bias = bias

def pre_forward(self, *args, **kwargs) -> None:
super().pre_forward(*args, **kwargs)
if self.weight.device.type == "meta" or self.bias.device.type == "meta":
self.reset_parameters()

def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:

# Check tensor dims
input_dims = tuple(input_.size())
if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={self._shape}) are not compatible"
)

# Check input tensors
inner_dim = math.prod(self._shape)
device = self.device
dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype)
if isinstance(x, QuantizedTensor):
x = x.dequantize()
if isinstance(w, QuantizedTensor):
w = w.dequantize()
if isinstance(b, QuantizedTensor):
b = b.dequantize()

# Check if backward pass is needed
requires_grad = ctx.requires_grad

# Check if FP8 is enabled
with_fp8_output = (
FP8GlobalStateManager.is_fp8_enabled()
and next_op is not None
and next_op.num_fp8_scales("input") > 0
)
output_fp8_meta = None
if with_fp8_output:
output_fp8_meta = next_op.get_fp8_meta("input")

# Compute layer norm
y = None
means = None
rstdevs = None
sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
if with_fp8_output:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True)
args = (
x,
w,
b,
self.eps,
output_fp8_meta[fp8_meta_key],
0, # fp8_meta_index
fp8_dtype,
sm_margin,
self.zero_centered_gamma,
)
if requires_grad:
data, means, rstdevs = layernorm_fwd_fp8(*args)
else:
data = layernorm_fwd_fp8_inf(*args)
y = Float8Tensor(
data=data,
fp8_meta=output_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
dtype=dtype,
)
else:
args = (
x,
w,
b,
self.eps,
sm_margin,
self.zero_centered_gamma,
)
if requires_grad:
y, means, rstdevs = layernorm_fwd(*args)
else:
y = layernorm_fwd_inf(*args)

# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None

# Reshape output tensor
out = reshape(y, input_dims)
return out

def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:

# Saved tensors from forward pass
x, means, rstdevs = ctx.saved_tensors

# Check input tensors
inner_dim = x.size(-1)
device = self.device
dtype = ctx.dtype
dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(w, QuantizedTensor):
w = w.dequantize()
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()

# Compute layer norm backward pass
dx, dw, db = layernorm_bwd(
dy,
x,
means,
rstdevs,
w,
self._sm_margins["backward"],
self.zero_centered_gamma,
)

# Clear saved tensors if possible
if ctx.has_prev_op:
clear_tensor_data(x)
clear_tensor_data(means)
clear_tensor_data(rstdevs)

# Reshape results
grad_input = reshape(dx, grad_output.size())
grad_weight = reshape(dw, self._shape)
grad_bias = reshape(db, self._shape)
return grad_input, (grad_weight, grad_bias)
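
A usage sketch for the basic `LayerNorm` op above. When `sm_margin` is a dict, stages that are not listed fall back to the `NVTE_FWD/BWD/INF_LAYERNORM_SM_MARGIN` environment variables (default 0); direct-call behavior is assumed to match other basic ops:

import torch
from transformer_engine.pytorch.ops import LayerNorm

op = LayerNorm(
    (4096,),
    eps=1e-5,
    # "inference" not given: falls back to NVTE_INF_LAYERNORM_SM_MARGIN
    sm_margin={"forward": 8, "backward": 8},
)

x = torch.randn(32, 4096, device="cuda")
y = op(x)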
93 changes: 93 additions & 0 deletions transformer_engine/pytorch/ops/basic/quantize.py
@@ -0,0 +1,93 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for quantization."""

from __future__ import annotations
from typing import Optional

import torch

from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
from ..op import BasicOperation, OperationContext


class Quantize(BasicOperation):
"""Quantize tensor data
Uses FP8 recipe from `fp8_autocast` context. When called outside
of an `fp8_autocast` context, this is an identity operation.
Parameters
----------
forward: bool, default = `True`
Perform quantization in forward pass
backward: bool, default = `False`
Perform quantization in backward pass
"""

def __init__(
self,
forward: bool = True,
backward: bool = False,
) -> None:
super().__init__()
self._quantize_forward = forward
self._quantize_backward = backward

def num_fp8_scales(self, mode: str) -> int:
if mode == "input" and self._quantize_forward:
return 1
if mode == "grad_output" and self._quantize_backward:
return 1
return 0

def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:

# Check if FP8 is enabled
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
quantize_forward = fp8_enabled and self._quantize_forward
quantize_backward = fp8_enabled and self._quantize_backward

# Quantize if needed
out = input_
if quantize_forward and not isinstance(out, QuantizedTensor):
fp8_meta = self.get_fp8_meta("input")
fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
out = Float8Tensor.to_float8(
out,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
)

ctx.quantize_backward = quantize_backward
return out

def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
grad_input = grad_output
if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor):
fp8_meta = self.get_fp8_meta("grad_output")
fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
grad_input = Float8Tensor.to_float8(
grad_input,
fp8_meta=fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
)
return grad_input, ()
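
A hedged sketch of the intended `Quantize` usage: inside `fp8_autocast` the op casts its input to a `Float8Tensor`, outside the context it is an identity. Assumes an FP8-capable GPU; direct-call behavior is assumed to match other basic ops:

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.ops import Quantize

quantize = Quantize(forward=True, backward=False)
x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True):
    y_fp8 = quantize(x)  # expected to be a Float8Tensor

y = quantize(x)          # outside the context: passes through unchanged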
289 changes: 289 additions & 0 deletions transformer_engine/pytorch/ops/basic/rmsnorm.py
@@ -0,0 +1,289 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusable operation for RMSNorm."""

from __future__ import annotations
from collections.abc import Iterable
import math
import os
from typing import Optional

import torch

from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...cpp_extensions import (
rmsnorm_fwd_fp8,
rmsnorm_fwd_fp8_inf,
rmsnorm_fwd_inf,
)
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape


class RMSNorm(BasicOperation):
r"""Root Mean Square Layer Normalization
Applies Root Mean Square Layer Normalization over a mini-batch of
inputs as described in the paper
`Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
.. math::
y = \frac{x}{\text{RMS}_\varepsilon(x)} * \gamma, \qquad \text{RMS}_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^n x_i^2 + \varepsilon}
:math:`\gamma` is a learnable affine transform parameter that
matches the inner-most dimensions of the input tensor.
Parameters
----------
normalized_shape: int or iterable of int
Inner dimensions of input tensor
eps : float, default = 1e-5
A value added to the denominator for numerical stability
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
zero_centered_gamma : bool, default = 'False'
If `True`, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to
.. math::
y = \frac{x}{\text{RMS}_\varepsilon(x)} * (1 + \gamma)
sm_margin: int or dict, default = 0
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward",
"inference").
"""

def __init__(
self,
normalized_shape: Iterable[int] | int,
*,
eps: float = 1e-5,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype] = None,
zero_centered_gamma: bool = False,
sm_margin: int | dict[str, int] = 0,
) -> None:
super().__init__()
self.eps: float = eps
self.zero_centered_gamma: bool = zero_centered_gamma

# Parameter shape
if not isinstance(normalized_shape, Iterable):
normalized_shape = (normalized_shape,)
else:
normalized_shape = tuple(normalized_shape)
self._shape: tuple[int, ...] = normalized_shape

# Parameter device
defer_param_init = False
device = canonicalize_device(device)
if device.type == "meta":
defer_param_init = True
device = canonicalize_device(None)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
self.device: torch.device = device

# Initialize parameters if needed
weight = torch.empty(
self._shape,
device="meta",
dtype=canonicalize_dtype(dtype),
)
weight = torch.nn.Parameter(weight)
self.weight: torch.nn.Parameter
self.register_parameter("weight", weight)
if not defer_param_init:
self.reset_parameters()

# Number of SMs to exclude when launching CUDA kernels
self._sm_margins: dict[str, int]
if isinstance(sm_margin, dict):

def getenv(name: str) -> int:
return int(os.getenv(name, "0"))

self._sm_margins = {
"forward": sm_margin.get("forward", getenv("NVTE_FWD_LAYERNORM_SM_MARGIN")),
"backward": sm_margin.get("backward", getenv("NVTE_BWD_LAYERNORM_SM_MARGIN")),
"inference": sm_margin.get("inference", getenv("NVTE_INF_LAYERNORM_SM_MARGIN")),
}
else:

def getenv(name: str) -> int:
return int(os.getenv(name, str(sm_margin)))

self._sm_margins = {
"forward": getenv("NVTE_FWD_LAYERNORM_SM_MARGIN"),
"backward": getenv("NVTE_BWD_LAYERNORM_SM_MARGIN"),
"inference": getenv("NVTE_INF_LAYERNORM_SM_MARGIN"),
}

def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""

# Make sure parameter is initialized
weight = self.weight
if weight.device.type != "cuda":
weight = torch.empty_like(weight, device=self.device)
else:
weight = weight.to(device=self.device)

# Initialize values
if self.zero_centered_gamma:
torch.nn.init.zeros_(weight)
else:
torch.nn.init.ones_(weight)

# Save updated parameter
if not isinstance(weight, torch.nn.Parameter):
weight = torch.nn.Parameter(weight)
self.weight = weight

def pre_forward(self, *args, **kwargs) -> None:
super().pre_forward(*args, **kwargs)
if self.weight.device.type == "meta":
self.reset_parameters()

def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:

# Check tensor dims
input_dims = tuple(input_.size())
if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={self._shape}) are not compatible"
)

# Check input tensors
inner_dim = math.prod(self._shape)
device = self.device
dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(x, QuantizedTensor):
x = x.dequantize()
if isinstance(w, QuantizedTensor):
w = w.dequantize()

# Check if backward pass is needed
requires_grad = ctx.requires_grad

# Check if FP8 is enabled
with_fp8_output = (
FP8GlobalStateManager.is_fp8_enabled()
and next_op is not None
and next_op.num_fp8_scales("input") > 0
)
output_fp8_meta = None
if with_fp8_output:
output_fp8_meta = next_op.get_fp8_meta("input")

# Compute RMSNorm
y = None
rstdevs = None
sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
if with_fp8_output:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True)
args = (
x,
w,
self.eps,
output_fp8_meta[fp8_meta_key],
0, # fp8_meta_index
fp8_dtype,
sm_margin,
self.zero_centered_gamma,
)
if requires_grad:
data, rstdevs = rmsnorm_fwd_fp8(*args)
else:
data = rmsnorm_fwd_fp8_inf(*args)
y = Float8Tensor(
data=data,
fp8_meta=output_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
dtype=dtype,
)
else:
args = (
x,
w,
self.eps,
sm_margin,
self.zero_centered_gamma,
)
if requires_grad:
y, rstdevs = rmsnorm_fwd(*args)
else:
y = rmsnorm_fwd_inf(*args)

# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None

# Reshape output tensor
out = reshape(y, input_dims)
return out

def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:

# Saved tensors from forward pass
x, rstdevs = ctx.saved_tensors

# Check input tensors
inner_dim = x.size(-1)
device = self.device
dtype = ctx.dtype
dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(w, QuantizedTensor):
w = w.dequantize()
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()

# Compute RMSNorm backward pass
dx, dw = rmsnorm_bwd(
dy,
x,
rstdevs,
w,
self._sm_margins["backward"],
self.zero_centered_gamma,
)

# Clear saved tensors if possible
if ctx.has_prev_op:
clear_tensor_data(x)
clear_tensor_data(rstdevs)

# Reshape results
grad_input = reshape(dx, grad_output.size())
grad_weight = reshape(dw, self._shape)
return grad_input, (grad_weight,)
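
Both norm ops emit FP8 output directly when FP8 is enabled and the next op in the pipeline consumes FP8 input (`next_op.num_fp8_scales("input") > 0`). A sketch of that composition through `te.ops.Sequential`; the `Linear` constructor arguments shown are assumptions:

import torch
import transformer_engine.pytorch as te

model = te.ops.Sequential(
    te.ops.RMSNorm(4096, device="cuda", dtype=torch.bfloat16),
    te.ops.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16),
)

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True):
    y = model(x)  # the norm's output is quantized for the linear op's input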
82 changes: 50 additions & 32 deletions transformer_engine/pytorch/ops/fuser.py
@@ -57,12 +57,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# pylint: disable=unused-argument
@staticmethod
def forward(
func_ctx: torch.autograd.function.FunctionCtx,
func_ctx: Optional[torch.autograd.function.FunctionCtx],
input_: torch.Tensor,
forward_ops: list[tuple[FusibleOperation, list[int]]],
backward_ops: list[tuple[FusibleOperation, list[int]]],
basic_ops: list[BasicOperation],
basic_op_kwargs: list[dict[str, Any]],
is_grad_enabled: bool,
num_params: int,
num_extra_inputs: int,
*params_and_extra_inputs: torch.nn.Parameter,
@@ -120,10 +121,20 @@ def forward(

# Apply forward ops
x = input_
requires_grad = x.requires_grad
requires_grad = is_grad_enabled and x.requires_grad
extra_outputs = [None for _ in range(len(basic_ops))]
for op, basic_op_idxs in forward_ops:

# Check if backward op is required
if is_grad_enabled:
if not requires_grad:
requires_grad = any(param.requires_grad for param in op.parameters())
if not requires_grad:
requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
for idx in basic_op_idxs:
basic_op_ctxs[idx].requires_grad = requires_grad
x.requires_grad_(requires_grad=requires_grad)

# Forward op
extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
@@ -138,18 +149,12 @@ def forward(
basic_op_next_ops=next_ops,
basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
)
x.requires_grad_(requires_grad=requires_grad)
for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
for y in ys:
y.requires_grad_(requires_grad=requires_grad)
extra_outputs[idx] = ys

# Check if backward op is required
if not requires_grad:
requires_grad = any(param.requires_grad for param in op.parameters())
if not requires_grad:
requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
for idx in basic_op_idxs:
basic_op_ctxs[idx]._requires_grad = requires_grad
x.requires_grad_(requires_grad=requires_grad)

# Flatten list of extra outputs
extra_outputs_flat = []
for idx, ys in enumerate(extra_outputs):
@@ -163,25 +168,28 @@ def forward(
)
extra_outputs_flat.extend(ys)

# Flatten list of saved tensors
to_save = []
for ctx in basic_op_ctxs:
range_start = len(to_save)
if ctx.to_save is not None:
to_save.extend(ctx.to_save)
range_end = len(to_save)
ctx.to_save = None
ctx._saved_tensors_range = (range_start, range_end)
func_ctx.save_for_backward(*to_save)

# Other context for backward pass
func_ctx.backward_ops = backward_ops
func_ctx.basic_ops = basic_ops
func_ctx.basic_op_ctxs = basic_op_ctxs
func_ctx.num_params = num_params
func_ctx.num_extra_inputs = num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
# Save context for backward pass
if is_grad_enabled:

# Flatten list of saved tensors
to_save = []
for ctx in basic_op_ctxs:
range_start = len(to_save)
if ctx.to_save is not None:
to_save.extend(ctx.to_save)
range_end = len(to_save)
ctx.to_save = None
ctx._saved_tensors_range = (range_start, range_end)
func_ctx.save_for_backward(*to_save)

# Other context
func_ctx.backward_ops = backward_ops
func_ctx.basic_ops = basic_ops
func_ctx.basic_op_ctxs = basic_op_ctxs
func_ctx.num_params = num_params
func_ctx.num_extra_inputs = num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()

if extra_outputs_flat:
return x, *extra_outputs_flat
@@ -224,7 +232,7 @@ def backward(
for op, basic_op_idxs in backward_ops:

# Stop if no more gradients are required
if all(not basic_op_ctxs[idx]._requires_grad for idx in basic_op_idxs):
if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs):
dx = None
break

@@ -282,6 +290,7 @@ def backward(
None, # backward_ops
None, # basic_ops
None, # basic_op_kwargs
None, # is_grad_enabled
None, # num_params
None, # num_extra_inputs
*grad_params_flat,
@@ -373,14 +382,23 @@ def __call__(
params = [param for op in self._basic_ops for param in op.parameters()]

# Fuser forward pass
return _OperationFuserAutogradFunction.apply(
is_grad_enabled = torch.is_grad_enabled()
if is_grad_enabled:
forward_func = _OperationFuserAutogradFunction.apply
args = []
else:
forward_func = _OperationFuserAutogradFunction.forward
args = [None]
args += (
input,
self._forward_ops,
self._backward_ops,
self._basic_ops,
basic_op_kwargs,
is_grad_enabled,
len(params),
self._num_extra_inputs,
*params,
*extra_inputs,
)
return forward_func(*args)
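
The fuser now bypasses `torch.autograd.Function.apply` entirely when grad mode is disabled, calling the raw `forward` with a `None` context so nothing is saved for backward. A generic, standalone sketch of that dispatch pattern (not the TE class itself):

import torch

class _ScaleByTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        if ctx is not None:          # only save state when autograd is active
            ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

def scale_by_two(x):
    if torch.is_grad_enabled():
        return _ScaleByTwo.apply(x)      # normal autograd path
    return _ScaleByTwo.forward(None, x)  # inference path: no ctx, nothing saved

y = scale_by_two(torch.randn(4))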
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/ops/op.py
@@ -43,7 +43,7 @@ class OperationContext:
_saved_tensors_range: Optional[tuple[int, int]] = None

# Whether backward pass is required
_requires_grad: bool = False
requires_grad: bool = True

def save_for_backward(self, *tensors: Optional[torch.Tensor]) -> None:
"""Register tensors to be saved for the backward function
