[PyTorch] Normalization ops (#1033)
* Add layer norm op. Signed-off-by: Tim Moon <[email protected]>
* Add FP8 cast op. Signed-off-by: Tim Moon <[email protected]>
* Add tests for linear and layernorm with FP8 output. Signed-off-by: Tim Moon <[email protected]>
* RMSNorm op. Signed-off-by: Tim Moon <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Fix linter warnings. Signed-off-by: Tim Moon <[email protected]>
* Replace LayerNorm module with LayerNorm op. Signed-off-by: Tim Moon <[email protected]>
* Replace RMSNorm module with RMSNorm op. Signed-off-by: Tim Moon <[email protected]>
* Add AMP support. Signed-off-by: Tim Moon <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Do not save autograd context if grad mode is disabled. Debugging ONNX export tests. Signed-off-by: Tim Moon <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Forward args in pre_forward func to base op class. Signed-off-by: Tim Moon <[email protected]>
* Update to use QuantizedTensor class. Signed-off-by: Tim Moon <[email protected]>
* Apply suggestions from code review. Co-authored-by: Przemyslaw Tredak <[email protected]>. Signed-off-by: Tim Moon <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Review suggestions from @ptrendx: rename "CastFloat8" op to "Quantize", add more fine-grained control for SM margin, add docs for legacy sequence_parallel kwarg. Signed-off-by: Tim Moon <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Fix linter warnings. Signed-off-by: Tim Moon <[email protected]>
* Use weight dtype as default compute dtype. Signed-off-by: Tim Moon <[email protected]>
* Fix linter warnings. Signed-off-by: Tim Moon <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Commit 77c37d4 (1 parent: f20d3dd)
Showing 12 changed files with 1,416 additions and 511 deletions.
@@ -0,0 +1,317 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusable operation for Layer Normalization."""

from __future__ import annotations
from collections.abc import Iterable
import math
import os
from typing import Optional

import torch

from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...cpp_extensions import (
    layernorm_fwd_fp8,
    layernorm_fwd_fp8_inf,
    layernorm_fwd_inf,
)
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape


class LayerNorm(BasicOperation):
    r"""Layer Normalization

    Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform
    parameters that match the inner-most dimensions of the input
    tensor.

    Parameters
    ----------
    normalized_shape: int or iterable of int
        Inner dimensions of input tensor
    eps : float, default = 1e-5
        A value added to the denominator of layer normalization for
        numerical stability
    device: torch.device, default = default CUDA device
        Tensor device
    dtype: torch.dtype, default = default dtype
        Tensor datatype
    zero_centered_gamma : bool, default = 'False'
        If `True`, the :math:`\gamma` parameter is initialized to zero
        and the calculation changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta

    sm_margin: int or dict, default = 0
        Number of SMs to exclude when launching CUDA kernels. This
        helps overlap with other kernels, e.g. communication kernels.
        For more fine-grained control, provide a dict with the SM
        margin at each compute stage ("forward", "backward",
        "inference").

    """

    def __init__(
        self,
        normalized_shape: Iterable[int] | int,
        *,
        eps: float = 1e-5,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
        zero_centered_gamma: bool = False,
        sm_margin: int | dict[str, int] = 0,
    ) -> None:
        super().__init__()
        self.eps: float = eps
        self.zero_centered_gamma: bool = zero_centered_gamma

        # Parameter shape
        if not isinstance(normalized_shape, Iterable):
            normalized_shape = (normalized_shape,)
        else:
            normalized_shape = tuple(normalized_shape)
        self._shape: tuple[int, ...] = normalized_shape

        # Parameter device
        defer_param_init = False
        device = canonicalize_device(device)
        if device.type == "meta":
            defer_param_init = True
            device = canonicalize_device(None)
        if device.type != "cuda":
            raise ValueError(f"Only CUDA devices are supported (got {device})")
        self.device: torch.device = device

        # Initialize parameters if needed
        dtype = canonicalize_dtype(dtype)
        weight = torch.empty(
            self._shape,
            device="meta",
            dtype=dtype,
        )
        bias = torch.empty(
            self._shape,
            device="meta",
            dtype=dtype,
        )
        weight = torch.nn.Parameter(weight)
        bias = torch.nn.Parameter(bias)
        self.weight: torch.nn.Parameter
        self.bias: torch.nn.Parameter
        self.register_parameter("weight", weight)
        self.register_parameter("bias", bias)
        if not defer_param_init:
            self.reset_parameters()

        # Number of SMs to exclude when launching CUDA kernels
        self._sm_margins: dict[str, int]
        if isinstance(sm_margin, dict):

            def getenv(name: str) -> int:
                return int(os.getenv(name, "0"))

            self._sm_margins = {
                "forward": sm_margin.get("forward", getenv("NVTE_FWD_LAYERNORM_SM_MARGIN")),
                "backward": sm_margin.get("backward", getenv("NVTE_BWD_LAYERNORM_SM_MARGIN")),
                "inference": sm_margin.get("inference", getenv("NVTE_INF_LAYERNORM_SM_MARGIN")),
            }
        else:

            def getenv(name: str) -> int:
                return int(os.getenv(name, str(sm_margin)))

            self._sm_margins = {
                "forward": getenv("NVTE_FWD_LAYERNORM_SM_MARGIN"),
                "backward": getenv("NVTE_BWD_LAYERNORM_SM_MARGIN"),
                "inference": getenv("NVTE_INF_LAYERNORM_SM_MARGIN"),
            }

    def reset_parameters(self) -> None:
        """Initialize parameter buffers and values"""

        # Make sure parameter is initialized
        weight = self.weight
        bias = self.bias
        if weight.device.type != "cuda":
            weight = torch.empty_like(weight, device=self.device)
        else:
            weight = weight.to(device=self.device)
        if bias.device.type != "cuda":
            bias = torch.empty_like(bias, device=self.device)
        else:
            bias = bias.to(device=self.device)

        # Initialize values
        if self.zero_centered_gamma:
            torch.nn.init.zeros_(weight)
        else:
            torch.nn.init.ones_(weight)
        torch.nn.init.zeros_(bias)

        # Save updated parameter
        if not isinstance(weight, torch.nn.Parameter):
            weight = torch.nn.Parameter(weight)
        if not isinstance(bias, torch.nn.Parameter):
            bias = torch.nn.Parameter(bias)
        self.weight = weight
        self.bias = bias

    def pre_forward(self, *args, **kwargs) -> None:
        super().pre_forward(*args, **kwargs)
        if self.weight.device.type == "meta" or self.bias.device.type == "meta":
            self.reset_parameters()

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op: Optional[BasicOperation] = None,
        next_op: Optional[BasicOperation] = None,
    ) -> torch.Tensor:

        # Check tensor dims
        input_dims = tuple(input_.size())
        if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
            raise ValueError(
                f"Input tensor (shape={input_dims}) "
                f"and weight tensor (shape={self._shape}) are not compatible"
            )

        # Check input tensors
        inner_dim = math.prod(self._shape)
        device = self.device
        dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
        x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
        w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
        b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype)
        if isinstance(x, QuantizedTensor):
            x = x.dequantize()
        if isinstance(w, QuantizedTensor):
            w = w.dequantize()
        if isinstance(b, QuantizedTensor):
            b = b.dequantize()

        # Check if backward pass is needed
        requires_grad = ctx.requires_grad

        # Check if FP8 is enabled
        with_fp8_output = (
            FP8GlobalStateManager.is_fp8_enabled()
            and next_op is not None
            and next_op.num_fp8_scales("input") > 0
        )
        output_fp8_meta = None
        if with_fp8_output:
            output_fp8_meta = next_op.get_fp8_meta("input")

        # Compute layer norm
        y = None
        means = None
        rstdevs = None
        sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
        if with_fp8_output:
            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
            fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True)
            args = (
                x,
                w,
                b,
                self.eps,
                output_fp8_meta[fp8_meta_key],
                0,  # fp8_meta_index
                fp8_dtype,
                sm_margin,
                self.zero_centered_gamma,
            )
            if requires_grad:
                data, means, rstdevs = layernorm_fwd_fp8(*args)
            else:
                data = layernorm_fwd_fp8_inf(*args)
            y = Float8Tensor(
                data=data,
                fp8_meta=output_fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=0,
                fp8_dtype=fp8_dtype,
                dtype=dtype,
            )
        else:
            args = (
                x,
                w,
                b,
                self.eps,
                sm_margin,
                self.zero_centered_gamma,
            )
            if requires_grad:
                y, means, rstdevs = layernorm_fwd(*args)
            else:
                y = layernorm_fwd_inf(*args)

        # Save state for backward pass
        if requires_grad:
            ctx.save_for_backward(x, means, rstdevs)
            ctx.dtype = dtype
            ctx.has_prev_op = prev_op is not None

        # Reshape output tensor
        out = reshape(y, input_dims)
        return out

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:

        # Saved tensors from forward pass
        x, means, rstdevs = ctx.saved_tensors

        # Check input tensors
        inner_dim = x.size(-1)
        device = self.device
        dtype = ctx.dtype
        dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
        w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
        if isinstance(w, QuantizedTensor):
            w = w.dequantize()
        if isinstance(dy, QuantizedTensor):
            dy = dy.dequantize()

        # Compute layer norm backward pass
        dx, dw, db = layernorm_bwd(
            dy,
            x,
            means,
            rstdevs,
            w,
            self._sm_margins["backward"],
            self.zero_centered_gamma,
        )

        # Clear saved tensors if possible
        if ctx.has_prev_op:
            clear_tensor_data(x)
            clear_tensor_data(means)
            clear_tensor_data(rstdevs)

        # Reshape results
        grad_input = reshape(dx, grad_output.size())
        grad_weight = reshape(dw, self._shape)
        grad_bias = reshape(db, self._shape)
        return grad_input, (grad_weight, grad_bias)
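
For context, a minimal usage sketch of the op above (illustrative only, not part of this diff). It assumes the op is re-exported as transformer_engine.pytorch.ops.LayerNorm, which is inferred from the relative imports in the file; the hidden size and SM-margin values are arbitrary.

# Illustrative sketch only -- not part of this commit. The import path and
# the numbers below are assumptions, not taken from the diff.
import torch
from transformer_engine.pytorch.ops import LayerNorm  # assumed re-export

# Normalize the innermost dimension of size 1024, reserving 8 SMs in the
# forward and backward kernels so they can overlap with communication kernels.
layer_norm = LayerNorm(
    1024,
    eps=1e-5,
    dtype=torch.bfloat16,
    zero_centered_gamma=True,
    sm_margin={"forward": 8, "backward": 8, "inference": 0},
)

x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = layer_norm(x)   # same shape as x
y.sum().backward()  # gradients flow to x, layer_norm.weight, and layer_norm.bias

When sm_margin is given as a plain int instead of a dict, the same margin is applied to every stage unless the NVTE_*_LAYERNORM_SM_MARGIN environment variables override it, per the __init__ logic above.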
@@ -0,0 +1,93 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for quantization."""

from __future__ import annotations
from typing import Optional

import torch

from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
from ..op import BasicOperation, OperationContext


class Quantize(BasicOperation):
    """Quantize tensor data

    Uses FP8 recipe from `fp8_autocast` context. When called outside
    of an `fp8_autocast` context, this is an identity operation.

    Parameters
    ----------
    forward: bool, default = `True`
        Perform quantization in forward pass
    backward: bool, default = `False`
        Perform quantization in backward pass

    """

    def __init__(
        self,
        forward: bool = True,
        backward: bool = False,
    ) -> None:
        super().__init__()
        self._quantize_forward = forward
        self._quantize_backward = backward

    def num_fp8_scales(self, mode: str) -> int:
        if mode == "input" and self._quantize_forward:
            return 1
        if mode == "grad_output" and self._quantize_backward:
            return 1
        return 0

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op: Optional[BasicOperation] = None,
        next_op: Optional[BasicOperation] = None,
    ) -> torch.Tensor:

        # Check if FP8 is enabled
        fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
        quantize_forward = fp8_enabled and self._quantize_forward
        quantize_backward = fp8_enabled and self._quantize_backward

        # Quantize if needed
        out = input_
        if quantize_forward and not isinstance(out, QuantizedTensor):
            fp8_meta = self.get_fp8_meta("input")
            fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            out = Float8Tensor.to_float8(
                out,
                fp8_meta=fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=0,
                fp8_dtype=fp8_dtype,
            )

        ctx.quantize_backward = quantize_backward
        return out

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:
        grad_input = grad_output
        if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor):
            fp8_meta = self.get_fp8_meta("grad_output")
            fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
            grad_input = Float8Tensor.to_float8(
                grad_input,
                fp8_meta=fp8_meta,
                fp8_meta_forward=False,
                fp8_meta_index=0,
                fp8_dtype=fp8_dtype,
            )
        return grad_input, ()
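
A hedged sketch of how the Quantize op behaves (illustrative only, not part of this diff; the import paths are assumptions). Inside an fp8_autocast region the forward pass casts a high-precision tensor to a Float8Tensor using the active recipe's scaling metadata; outside the region the op is an identity, as the docstring above states.

# Illustrative sketch only -- not part of this commit.
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.ops import Quantize  # assumed re-export

quantize = Quantize(forward=True, backward=False)
x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True):
    y = quantize(x)  # FP8 path: returns a Float8Tensor (FP8 data plus scale metadata)

z = quantize(x)      # outside fp8_autocast: identity, x is returned unchanged

Setting backward=True instead applies the same cast to the incoming gradient in op_backward, using the "grad_output" FP8 metadata rather than the "input" metadata.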
@@ -0,0 +1,289 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusable operation for RMSNorm."""

from __future__ import annotations
from collections.abc import Iterable
import math
import os
from typing import Optional

import torch

from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...cpp_extensions import (
    rmsnorm_fwd_fp8,
    rmsnorm_fwd_fp8_inf,
    rmsnorm_fwd_inf,
)
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape


class RMSNorm(BasicOperation):
    r"""Root Mean Square Layer Normalization

    Applies Root Mean Square Layer Normalization over a mini-batch of
    inputs as described in the paper
    `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma

    :math:`\gamma` is a learnable affine transform parameter that
    matches the inner-most dimensions of the input tensor.

    Parameters
    ----------
    normalized_shape: int or iterable of int
        Inner dimensions of input tensor
    eps : float, default = 1e-5
        A value added to the denominator for numerical stability
    device: torch.device, default = default CUDA device
        Tensor device
    dtype: torch.dtype, default = default dtype
        Tensor datatype
    zero_centered_gamma : bool, default = 'False'
        If `True`, the :math:`\gamma` parameter is initialized to zero
        and the calculation changes to

        .. math::
            y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)

    sm_margin: int or dict, default = 0
        Number of SMs to exclude when launching CUDA kernels. This
        helps overlap with other kernels, e.g. communication kernels.
        For more fine-grained control, provide a dict with the SM
        margin at each compute stage ("forward", "backward",
        "inference").

    """

    def __init__(
        self,
        normalized_shape: Iterable[int] | int,
        *,
        eps: float = 1e-5,
        device: Optional[torch.device | str] = None,
        dtype: Optional[torch.dtype] = None,
        zero_centered_gamma: bool = False,
        sm_margin: int | dict[str, int] = 0,
    ) -> None:
        super().__init__()
        self.eps: float = eps
        self.zero_centered_gamma: bool = zero_centered_gamma

        # Parameter shape
        if not isinstance(normalized_shape, Iterable):
            normalized_shape = (normalized_shape,)
        else:
            normalized_shape = tuple(normalized_shape)
        self._shape: tuple[int, ...] = normalized_shape

        # Parameter device
        defer_param_init = False
        device = canonicalize_device(device)
        if device.type == "meta":
            defer_param_init = True
            device = canonicalize_device(None)
        if device.type != "cuda":
            raise ValueError(f"Only CUDA devices are supported (got {device})")
        self.device: torch.device = device

        # Initialize parameters if needed
        weight = torch.empty(
            self._shape,
            device="meta",
            dtype=canonicalize_dtype(dtype),
        )
        weight = torch.nn.Parameter(weight)
        self.weight: torch.nn.Parameter
        self.register_parameter("weight", weight)
        if not defer_param_init:
            self.reset_parameters()

        # Number of SMs to exclude when launching CUDA kernels
        self._sm_margins: dict[str, int]
        if isinstance(sm_margin, dict):

            def getenv(name: str) -> int:
                return int(os.getenv(name, "0"))

            self._sm_margins = {
                "forward": sm_margin.get("forward", getenv("NVTE_FWD_LAYERNORM_SM_MARGIN")),
                "backward": sm_margin.get("backward", getenv("NVTE_BWD_LAYERNORM_SM_MARGIN")),
                "inference": sm_margin.get("inference", getenv("NVTE_INF_LAYERNORM_SM_MARGIN")),
            }
        else:

            def getenv(name: str) -> int:
                return int(os.getenv(name, str(sm_margin)))

            self._sm_margins = {
                "forward": getenv("NVTE_FWD_LAYERNORM_SM_MARGIN"),
                "backward": getenv("NVTE_BWD_LAYERNORM_SM_MARGIN"),
                "inference": getenv("NVTE_INF_LAYERNORM_SM_MARGIN"),
            }

    def reset_parameters(self) -> None:
        """Initialize parameter buffers and values"""

        # Make sure parameter is initialized
        weight = self.weight
        if weight.device.type != "cuda":
            weight = torch.empty_like(weight, device=self.device)
        else:
            weight = weight.to(device=self.device)

        # Initialize values
        if self.zero_centered_gamma:
            torch.nn.init.zeros_(weight)
        else:
            torch.nn.init.ones_(weight)

        # Save updated parameter
        if not isinstance(weight, torch.nn.Parameter):
            weight = torch.nn.Parameter(weight)
        self.weight = weight

    def pre_forward(self, *args, **kwargs) -> None:
        super().pre_forward(*args, **kwargs)
        if self.weight.device.type == "meta":
            self.reset_parameters()

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op: Optional[BasicOperation] = None,
        next_op: Optional[BasicOperation] = None,
    ) -> torch.Tensor:

        # Check tensor dims
        input_dims = tuple(input_.size())
        if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
            raise ValueError(
                f"Input tensor (shape={input_dims}) "
                f"and weight tensor (shape={self._shape}) are not compatible"
            )

        # Check input tensors
        inner_dim = math.prod(self._shape)
        device = self.device
        dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
        x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
        w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
        if isinstance(x, QuantizedTensor):
            x = x.dequantize()
        if isinstance(w, QuantizedTensor):
            w = w.dequantize()

        # Check if backward pass is needed
        requires_grad = ctx.requires_grad

        # Check if FP8 is enabled
        with_fp8_output = (
            FP8GlobalStateManager.is_fp8_enabled()
            and next_op is not None
            and next_op.num_fp8_scales("input") > 0
        )
        output_fp8_meta = None
        if with_fp8_output:
            output_fp8_meta = next_op.get_fp8_meta("input")

        # Compute RMSNorm
        y = None
        rstdevs = None
        sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
        if with_fp8_output:
            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
            fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True)
            args = (
                x,
                w,
                self.eps,
                output_fp8_meta[fp8_meta_key],
                0,  # fp8_meta_index
                fp8_dtype,
                sm_margin,
                self.zero_centered_gamma,
            )
            if requires_grad:
                data, rstdevs = rmsnorm_fwd_fp8(*args)
            else:
                data = rmsnorm_fwd_fp8_inf(*args)
            y = Float8Tensor(
                data=data,
                fp8_meta=output_fp8_meta,
                fp8_meta_forward=True,
                fp8_meta_index=0,
                fp8_dtype=fp8_dtype,
                dtype=dtype,
            )
        else:
            args = (
                x,
                w,
                self.eps,
                sm_margin,
                self.zero_centered_gamma,
            )
            if requires_grad:
                y, rstdevs = rmsnorm_fwd(*args)
            else:
                y = rmsnorm_fwd_inf(*args)

        # Save state for backward pass
        if requires_grad:
            ctx.save_for_backward(x, rstdevs)
            ctx.dtype = dtype
            ctx.has_prev_op = prev_op is not None

        # Reshape output tensor
        out = reshape(y, input_dims)
        return out

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:

        # Saved tensors from forward pass
        x, rstdevs = ctx.saved_tensors

        # Check input tensors
        inner_dim = x.size(-1)
        device = self.device
        dtype = ctx.dtype
        dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
        w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
        if isinstance(w, QuantizedTensor):
            w = w.dequantize()
        if isinstance(dy, QuantizedTensor):
            dy = dy.dequantize()

        # Compute RMSNorm backward pass
        dx, dw = rmsnorm_bwd(
            dy,
            x,
            rstdevs,
            w,
            self._sm_margins["backward"],
            self.zero_centered_gamma,
        )

        # Clear saved tensors if possible
        if ctx.has_prev_op:
            clear_tensor_data(x)
            clear_tensor_data(rstdevs)

        # Reshape results
        grad_input = reshape(dx, grad_output.size())
        grad_weight = reshape(dw, self._shape)
        return grad_input, (grad_weight,)
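
To tie the pieces together, a speculative sketch of composing these ops so that the norm feeds an FP8 consumer (illustrative only, not part of this diff). When the next op reports FP8 input scales (next_op.num_fp8_scales("input") > 0), the norm's op_forward emits a Float8Tensor directly instead of a high-precision output, which is the fused path exercised by the new tests in this PR. The Sequential container and the import paths are assumptions about the surrounding ops API.

# Illustrative sketch only -- not part of this commit. `Sequential` and the
# import paths below are assumed, not confirmed by the diff.
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.ops import Quantize, RMSNorm, Sequential  # assumed re-exports

# RMSNorm followed by an op that requests FP8 input: under fp8_autocast the
# norm writes its output in FP8, avoiding a separate cast kernel.
model = Sequential(
    RMSNorm(1024, dtype=torch.bfloat16, zero_centered_gamma=True),
    Quantize(forward=True),
)

x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True):
    y = model(x)  # y holds FP8 data when the recipe is active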