From 8124d3d0e5669c80c6149fa67a2c4960d34574bc Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 24 Oct 2025 12:19:13 +0200 Subject: [PATCH 01/14] init Signed-off-by: Pawel Gadzinski --- .github/workflows/docs.yml | 2 +- docs/api/pytorch.rst | 30 ++++++++-- docs/conf.py | 24 +++++++- docs/debug.rst | 1 + docs/debug/1_getting_started.rst | 15 +++-- docs/debug/2_config_file_structure.rst | 15 +++-- docs/debug/3_api_debug_setup.rst | 7 ++- docs/debug/3_api_features.rst | 2 +- docs/debug/4_distributed.rst | 13 +++-- docs/debug/api.rst | 1 + docs/examples/attention/attention.ipynb | 6 +- .../tutorial_generation_gemma_with_te.ipynb | 2 +- .../jax/cpp_extensions/activation.py | 3 +- transformer_engine/jax/cpp_extensions/gemm.py | 2 +- transformer_engine/jax/cpp_extensions/misc.py | 3 +- .../jax/cpp_extensions/normalization.py | 3 +- .../jax/cpp_extensions/quantization.py | 2 +- transformer_engine/jax/dense.py | 2 +- transformer_engine/jax/flax/transformer.py | 19 ++++--- transformer_engine/jax/quantize/hadamard.py | 4 +- transformer_engine/jax/quantize/quantizer.py | 1 - .../dot_product_attention.py | 42 +++++++------- .../pytorch/attention/multi_head_attention.py | 21 +++---- transformer_engine/pytorch/cross_entropy.py | 55 ++++++++++++++++++- transformer_engine/pytorch/distributed.py | 2 +- transformer_engine/pytorch/jit.py | 2 +- transformer_engine/pytorch/module/base.py | 35 ++++++------ .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 7 ++- transformer_engine/pytorch/module/linear.py | 2 +- transformer_engine/pytorch/ops/_common.py | 2 +- .../pytorch/ops/basic/l2normalization.py | 2 +- transformer_engine/pytorch/quantization.py | 2 +- transformer_engine/pytorch/transformer.py | 28 +++++----- transformer_engine/pytorch/utils.py | 7 ++- 35 files changed, 241 insertions(+), 125 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 3c4229a888..f4a1d4a2e6 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -25,7 +25,7 @@ jobs: run: | doxygen docs/Doxyfile cd docs - make html + make html SPHINXOPTS="-W" - name: 'Upload docs' uses: actions/upload-artifact@v4 with: diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index c456f1a6ad..c934e89653 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -37,9 +37,6 @@ pyTorch .. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker() :members: reset, get_states, set_states, add, fork -.. autoapifunction:: transformer_engine.pytorch.fp8_autocast - -.. autoapifunction:: transformer_engine.pytorch.fp8_model_init .. autoapifunction:: transformer_engine.pytorch.autocast @@ -47,6 +44,15 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.checkpoint + +.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables + +.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context + + +Recipe availability +------------------------ + .. autoapifunction:: transformer_engine.pytorch.is_fp8_available .. autoapifunction:: transformer_engine.pytorch.is_mxfp8_available @@ -63,9 +69,8 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.get_default_recipe -.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables - -.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context +Mixture of Experts (MoE) functions +------------------------------------------ .. autoapifunction:: transformer_engine.pytorch.moe_permute @@ -79,9 +84,22 @@ pyTorch .. 
autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs + +GEMM Comm overlap +--------------------- + .. autoapifunction:: transformer_engine.pytorch.initialize_ub .. autoapifunction:: transformer_engine.pytorch.destroy_ub .. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode :members: FP8, NONE + + +Deprecated functions +--------------------- + + +.. autoapifunction:: transformer_engine.pytorch.fp8_autocast + +.. autoapifunction:: transformer_engine.pytorch.fp8_model_init \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 4083bfd242..1f5679ca1b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -61,7 +61,10 @@ ] templates_path = ["_templates"] -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] +exclude_patterns = [ + "_build", + "sphinx_rtd_theme", +] source_suffix = ".rst" @@ -101,3 +104,22 @@ autoapi_generate_api_docs = False autoapi_dirs = [root_path / "transformer_engine"] + + +# There are 2 warnings about the same namespace (transformer_engine) in two different c++ api +# docs pages. This seems to be the only way to suppress these warnings. +def setup(app): + """Custom Sphinx setup to filter warnings.""" + import logging + + # Filter out duplicate C++ declaration warnings + class DuplicateDeclarationFilter(logging.Filter): + def filter(self, record): + message = record.getMessage() + if "Duplicate C++ declaration" in message and "transformer_engine" in message: + return False + return True + + # Apply filter to Sphinx logger + logger = logging.getLogger("sphinx") + logger.addFilter(DuplicateDeclarationFilter()) diff --git a/docs/debug.rst b/docs/debug.rst index d33568ea3b..20ab69d00c 100644 --- a/docs/debug.rst +++ b/docs/debug.rst @@ -2,6 +2,7 @@ Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. + Precision debug tools ============================================== diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index 906c625567..9950915427 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -4,7 +4,7 @@ See LICENSE for license information. Getting started -============== +=============================== .. note:: @@ -38,7 +38,7 @@ To start debugging, one needs to create a configuration YAML file. This file lis one - ``UserProvidedPrecision`` - is a custom feature implemented by the user. Nvidia-DL-Framework-Inspect inserts features into the layers according to the config. Example training script ----------------------- +------------------------------ Let's look at a simple example of training a Transformer layer using Transformer Engine with FP8 precision. This example demonstrates how to set up the layer, define an optimizer, and perform a few training iterations using synthetic data. @@ -81,7 +81,7 @@ We will demonstrate two debug features on the code above: 2. Logging statistics for other GEMM operations, such as gradient statistics for data gradient GEMM within the LayerNormLinear sub-layer of the TransformerLayer. Config file ----------- +------------------------------ We need to prepare the configuration YAML file, as below @@ -114,7 +114,8 @@ We need to prepare the configuration YAML file, as below Further explanation on how to create config files is in the :doc:`next part of the documentation <2_config_file_structure>`. Adjusting Python file --------------------- +---------------------------- + .. 
code-block:: python @@ -145,7 +146,8 @@ In the modified code above, the following changes were made: 3. Added ``debug_api.step()`` after each of the forward-backward pass. Inspecting the logs ------------------- +---------------------------- + Let's look at the files with the logs. Two files will be created: @@ -213,7 +215,8 @@ The second log file (``nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank- INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000004 value=130776.7969 Logging using TensorBoard ------------------------- +---------------------------- + Precision debug tools support logging using `TensorBoard `_. To enable it, one needs to pass the argument ``tb_writer`` to the ``debug_api.initialize()``. Let's modify ``train.py`` file. diff --git a/docs/debug/2_config_file_structure.rst b/docs/debug/2_config_file_structure.rst index f1069b0c80..2d9334de48 100644 --- a/docs/debug/2_config_file_structure.rst +++ b/docs/debug/2_config_file_structure.rst @@ -4,13 +4,14 @@ See LICENSE for license information. Config File Structure -==================== +=========================== To enable debug features, create a configuration YAML file to specify the desired behavior, such as determining which GEMMs (General Matrix Multiply operations) should run in higher precision rather than FP8 and defining which statistics to log. Below, we outline how to structure the configuration YAML file. General Format -------------- +---------------------------- + A config file can have one or more sections, each containing settings for specific layers and features: @@ -55,7 +56,8 @@ Sections may have any name and must contain: 3. Additional fields describing features for those layers. Layer Specification ------------------- +---------------------------- + Debug layers can be identified by a ``name`` parameter: @@ -89,7 +91,8 @@ Examples: (...) Names in Transformer Layers --------------------------- +-------------------------------- + There are three ways to assign a name to a layer in the Transformer Engine: @@ -154,7 +157,7 @@ Below is an example ``TransformerLayer`` with four linear layers that can be inf Structured Configuration for GEMMs and Tensors ---------------------------------------------- +----------------------------------------------------- Sometimes a feature is parameterized by a list of tensors or by a list of GEMMs. There are multiple ways of describing this parameterization. @@ -216,7 +219,7 @@ We can use both structs for tensors and GEMMs. The tensors_struct should be nest gemm_feature_param1: value Enabling or Disabling Sections and Features ------------------------------------------- +------------------------------------------------- Debug features can be enabled or disabled with the ``enabled`` keyword: diff --git a/docs/debug/3_api_debug_setup.rst b/docs/debug/3_api_debug_setup.rst index bda8f096d6..ccda556342 100644 --- a/docs/debug/3_api_debug_setup.rst +++ b/docs/debug/3_api_debug_setup.rst @@ -11,7 +11,8 @@ Please refer to the Nvidia-DL-Framework-Inspect `documentation `_ for more details. @@ -61,7 +62,7 @@ If the tensor reduction group is not specified, then statistics are reduced acro # activation/gradient tensor statistics are reduced along pipeline_parallel_group set_weight_tensor_tp_group_reduce() ---------------------------------- +----------------------------------------- By default, weight tensor statistics are reduced within the tensor parallel group. 
This function allows you to disable that behavior; for more details, see `reduction group section <./4_distributed.rst#reduction-groups>`_. diff --git a/docs/debug/3_api_features.rst b/docs/debug/3_api_features.rst index b31c437b2d..ffb07c6ced 100644 --- a/docs/debug/3_api_features.rst +++ b/docs/debug/3_api_features.rst @@ -4,7 +4,7 @@ See LICENSE for license information. Debug features -========== +=========================== .. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats .. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats diff --git a/docs/debug/4_distributed.rst b/docs/debug/4_distributed.rst index 6f69f2712c..e9e3ade3d4 100644 --- a/docs/debug/4_distributed.rst +++ b/docs/debug/4_distributed.rst @@ -4,7 +4,7 @@ See LICENSE for license information. Distributed training -=================== +==================================== Nvidia-Pytorch-Inspect with Transformer Engine supports multi-GPU training. This guide describes how to run it and how the supported features work in the distributed setting. @@ -14,7 +14,8 @@ To use precision debug tools in multi-GPU training, one needs to: 2. If one wants to log stats, one may want to invoke ``debug_api.set_tensor_reduction_group`` with a proper reduction group. Behavior of the features ------------------------ +---------------------------- + In a distributed setting, **DisableFP8GEMM** and **DisableFP8Layer** function similarly to the single-GPU case, with no notable differences. @@ -28,7 +29,8 @@ In a distributed setting, **DisableFP8GEMM** and **DisableFP8Layer** function si Logging-related features are more complex and will be discussed further in the next sections. Reduction groups --------------- +---------------------------- + In setups with tensor, data, or pipeline parallelism, some tensors are distributed across multiple GPUs, requiring a reduction operation to compute statistics for these tensors. @@ -65,7 +67,8 @@ Below, we illustrate configurations for a 4-node setup with tensor parallelism s Microbatching ------------ +---------------------------- + Let's dive into how statistics collection works with microbatching. By microbatching, we mean invoking multiple ``forward()`` calls for each ``debug_api.step()``. The behavior is as follows: @@ -73,7 +76,7 @@ Let's dive into how statistics collection works with microbatching. By microbatc - For other tensors, the stats are accumulated. Logging to files and TensorBoard ------------------------------- +------------------------------------------- In a single-node setup with ``default_logging_enabled=True``, all logs are saved by default to ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log``. In multi-GPU training, each node writes its reduced statistics to its unique file, named ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-i.log`` for rank i. Because these logs contain reduced statistics, the logged values are identical for all nodes within a reduction group. diff --git a/docs/debug/api.rst b/docs/debug/api.rst index ac593d353a..4e2cf99c67 100644 --- a/docs/debug/api.rst +++ b/docs/debug/api.rst @@ -2,6 +2,7 @@ Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. 
+ API ============ diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 61a6ad949f..8591ce218f 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -16,8 +16,8 @@ "\n", "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in two frameworks, [PyTorch](https://github.com/pytorch/pytorch) and [JAX](https://github.com/google/jax). The API for each framework is\n", "\n", - "- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)\n", - "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)" + "- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.dotproductattention)\n", + "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.dotproductattention)" ] }, { @@ -606,7 +606,7 @@ "\n", "A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n", "\n", - "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v2.0. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", + "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.dotproductattention), as of v2.0. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", "\n", "- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n", "\n", diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index c31e272b25..1ce60840b6 100755 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -38,7 +38,7 @@ "\n", "For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n", "\n", - "In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. 
Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n", + "In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n", "\n", "This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n", "\n", diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index bb3c56bcf1..330db28172 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -15,7 +15,7 @@ import numpy as np import transformer_engine_jax -from transformer_engine_jax import NVTE_Activation_Type +from transformer_engine_jax import NVTE_Activation_Type, QuantizeLayout from .base import BasePrimitive, register_primitive from .misc import ( jax_dtype_to_te_dtype, @@ -32,7 +32,6 @@ from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, - QuantizeLayout, DelayedScaleQuantizer, ScalingMode, ) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index b37c4bd848..7314130ddf 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -24,6 +24,7 @@ get_device_compute_capability, initialize_cgemm_communicator, get_cgemm_num_max_streams, + QuantizeLayout, ) from .base import BasePrimitive, register_primitive @@ -40,7 +41,6 @@ GroupedQuantizer, get_quantize_config, QuantizerSet, - QuantizeLayout, noop_quantizer_set, is_fp8_gemm_with_all_layouts_supported, apply_padding_to_scale_inv, diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 572d82f18d..93ec1d00c3 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -15,9 +15,10 @@ from jax.interpreters.mlir import dtype_to_ir_type import transformer_engine_jax +from transformer_engine_jax import QuantizeLayout from ..sharding import get_padded_spec as te_get_padded_spec -from ..quantize import ScaledTensorFactory, QuantizeLayout +from ..quantize import ScaledTensorFactory TEDType = transformer_engine_jax.DType diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 90ab5fb7fe..2ed6d55164 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -16,7 +16,7 @@ from jax.sharding import PartitionSpec import transformer_engine_jax -from transformer_engine_jax import NVTE_Norm_Type +from transformer_engine_jax import NVTE_Norm_Type, QuantizeLayout from .base import BasePrimitive, register_primitive from .misc import ( @@ -35,7 +35,6 @@ from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, - QuantizeLayout, DelayedScaleQuantizer, ScalingMode, ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index b3f1e60f9a..fb22998fd7 100644 --- 
a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -15,6 +15,7 @@ from jax.sharding import PartitionSpec import transformer_engine_jax +from transformer_engine_jax import QuantizeLayout from .amax import AmaxScope, calculate_amax, calculate_post_rht_amax from .base import BasePrimitive, register_primitive @@ -40,7 +41,6 @@ GroupedScaledTensor1x, Quantizer, GroupedQuantizer, - QuantizeLayout, ScalingMode, compute_scale_from_amax, NoScaleTensor, diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 44c73a5b1e..a28252bbc3 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -15,12 +15,12 @@ import jax import jax.numpy as jnp +from transformer_engine_jax import QuantizeLayout from . import cpp_extensions as tex from .cpp_extensions.amax import AmaxScope from .quantize import ( ScaledTensorFactory, ScalingMode, - QuantizeLayout, QuantizerSet, noop_quantizer_set, with_sharding_constraint_by_logical_axes, diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 1eafed4131..e11a625249 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -457,13 +457,18 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods .. note:: THD format only supports 'padding' or 'causal_padding' mask type. - attn_mask_type mask/sequence_descriptor SWA softmax type - -------------------------------------------------------------------------------------------- - no_mask None None SCALED - causal None None SCALED_UPPER_TRIANG_MASKED - causal None Yes SCALED_MASKED - padding Required Yes/No SCALED_MASKED - padding_causal Required Yes/No SCALED_MASKED + .. table:: + :widths: auto + + ===================== ============================ ========== ================================= + attn_mask_type mask/sequence_descriptor SWA softmax type + ===================== ============================ ========== ================================= + no_mask None None SCALED + causal None None SCALED_UPPER_TRIANG_MASKED + causal None Yes SCALED_MASKED + padding Required Yes/No SCALED_MASKED + padding_causal Required Yes/No SCALED_MASKED + ===================== ============================ ========== ================================= attn_bias_type: Optional[str], default = None Type of the attention bias passed in the attention. diff --git a/transformer_engine/jax/quantize/hadamard.py b/transformer_engine/jax/quantize/hadamard.py index c0b74ef75e..a43b4da111 100644 --- a/transformer_engine/jax/quantize/hadamard.py +++ b/transformer_engine/jax/quantize/hadamard.py @@ -4,6 +4,7 @@ """Randomized Hadamard Transform (RHT) utilities for JAX.""" import jax.numpy as jnp +from transformer_engine_jax import QuantizeLayout from .scaling_modes import ScalingMode @@ -18,9 +19,6 @@ def should_use_rht(scaling_mode, is_colwise=None, q_layout=None) -> bool: Returns: bool: True if RHT should be used, False otherwise. """ - # Delayed import to avoid circular dependencies - from .quantizer import QuantizeLayout - assert (is_colwise is None) != ( q_layout is None ), "Exactly one of is_colwise or q_layout must be provided." 
diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 7bc08f834f..2dc946664d 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -36,7 +36,6 @@ from .device_utils import is_fp8_gemm_with_all_layouts_supported __all__ = [ - "QuantizeLayout", "Quantizer", "QuantizerSet", "CurrentScaleQuantizer", diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 6d9ce9a522..fa4d607197 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -278,16 +278,17 @@ class DotProductAttention(TransformerEngineBaseModule): can overlap two flash attention kernels. cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p". - "p2p": Exchange KV chunks with P2P communications in ring topology. - P2P is async and can be overlapped with attention compute. - "all_gather": All-gather to get full sequence of KV before attention. - The all-gather is not async, and cannot be overlapped. - "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP - group, and gather to get full sequence of QKV. - "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV - across each CP sub-group (e.g., via NVLink), then exchanging KV with - p2p between sub-groups (e.g., via IBLink). + Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``. + + - ``"p2p"``: Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + - ``"all_gather"``: All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + - ``"a2a"``: Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. + - ``"a2a+p2p"``: hierarchical CP implementation. First applying a2a to QKV + across each CP sub-group (e.g., via NVLink), then exchanging KV with + p2p between sub-groups (e.g., via IBLink). """ def __init__( @@ -521,16 +522,17 @@ def set_context_parallel_group( cuda stream for context parallel execution. cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a" or "a2a+p2p". - "p2p": Exchange KV chunks with P2P communications in ring topology. - P2P is async and can be overlapped with attention compute. - "all_gather": All-gather to get full sequence of KV before attention. - The all-gather is not async, and cannot be overlapped. - "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP - group, and gather to get full sequence of QKV. - "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV - across each CP sub-group (e.g., via NVLink), then exchanging KV with - p2p between sub-groups (e.g., via IBLink). + Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``. + + - ``"p2p"``: Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + - ``"all_gather"``: All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. 
+ - ``"a2a"``: Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. + - ``"a2a+p2p"``: hierarchical CP implementation. First applying a2a to QKV + across each CP sub-group (e.g., via NVLink), then exchanging KV with + p2p between sub-groups (e.g., via IBLink). """ self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index b3bda677bb..793a0f56ad 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -562,16 +562,17 @@ def set_context_parallel_group( cuda stream for context parallel execution. cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a", "a2a+p2p". - "p2p": Exchange KV chunks with P2P communications in ring topology. - P2P is async and can be overlapped with attention compute. - "all_gather": All-gather to get full sequence of KV before attention. - The all-gather is not async, and cannot be overlapped. - "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP - group, and gather to get full sequence of QKV. - "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV - across each CP sub-group (e.g., via NVLink), then exchanging KV with - p2p between sub-groups (e.g., via IBLink). + Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``. + + - ``"p2p"``: Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + - ``"all_gather"``: All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + - ``"a2a"``: Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. + - ``"a2a+p2p"``: hierarchical CP implementation. First applying a2a to QKV + across each CP sub-group (e.g., via NVLink), then exchanging KV with + p2p between sub-groups (e.g., via IBLink). """ if isinstance(cp_group, dist_group_type): self.cp_size = get_distributed_world_size(cp_group) diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py index 076dbec0dc..2de063ba47 100644 --- a/transformer_engine/pytorch/cross_entropy.py +++ b/transformer_engine/pytorch/cross_entropy.py @@ -4,6 +4,8 @@ """Cross Entropy Loss API""" +from typing import Optional + import torch import transformer_engine.pytorch.triton.cross_entropy as triton_cross_entropy @@ -87,4 +89,55 @@ def backward(ctx, grad_output): ) -parallel_cross_entropy = CrossEntropyFunction.apply +def parallel_cross_entropy( + _input: torch.Tensor, + target: torch.Tensor, + label_smoothing: float = 0.0, + reduce_loss: bool = False, + dist_process_group: Optional[torch.distributed.ProcessGroup] = None, + ignore_idx: int = -100, + is_cg_capturable: bool = False, +) -> torch.Tensor: + """ + Cross Entropy loss with optional distributed reduction. + + The input tensor can be in BF16/FP32, the loss and gradient calculation happens in + FP32 only. The returned loss is always in FP32, the input gradients are upcasted + to the datatype of the input. + + If ``dist_process_group`` is passed for distributed loss calculation, the input to each + distributed rank should be ``(*, V/world_size)``. Note that each of the ranks should + get equal shards along the V dimension. 
+ + Parameters + ---------- + _input : torch.Tensor + The input tensor of shape ``(B, SQ, V)`` or ``(SQ, B, V)`` where B is batch size, + SQ is sequence length, V is vocab size. + target : torch.Tensor + The target tensor of shape ``(B, SQ)`` or ``(SQ, B)`` where each value is in ``[0, V-1]``. + label_smoothing : float, default = 0.0 + The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduce_loss : bool, default = False + If True, returns the averaged loss across the B*SQ dimension. + dist_process_group : torch.distributed.ProcessGroup, default = None + The distributed process group the loss computation is split across, None if on 1 device. + ignore_idx : int, default = -100 + The index for which loss and gradients are made to zero. + is_cg_capturable : bool, default = False + Whether the operation is CUDA graph capturable. + + Returns + ------- + torch.Tensor + The computed loss. + """ + return CrossEntropyFunction.apply( + _input, + target, + label_smoothing, + reduce_loss, + dist_process_group, + ignore_idx, + is_cg_capturable, + ) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 5ed73f6783..864dc194a9 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -29,7 +29,7 @@ import transformer_engine_torch as tex -from . import torch_version +from .utils import torch_version from .utils import ( is_non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index f0f77621e5..ad0029623c 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -8,7 +8,7 @@ from typing import Callable, Optional, Tuple import torch -from . import torch_version +from .utils import torch_version from .export import is_in_onnx_export_mode from .utils import gpu_autocast_ctx diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d16455b5b4..e853879e53 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -19,7 +19,6 @@ import torch.nn.functional as F import transformer_engine_torch as tex -from transformer_engine.common.recipe import Recipe from ._common import _ParameterInitMeta, noop_cat from ..quantization import ( @@ -149,23 +148,23 @@ def initialize_ub( dtype : torch.dtype = torch.bfloat16 non-FP8 data type of the communication buffer when `use_fp8 = False` ub_cfgs: dict = None - Configuration dictionary with the structure - ``` - { - : { - "method": <"ring_exchange" or "pipeline">, - "is_reduce_scatter": bool, - "num_sm": int, - "cga_size": int, - "set_sm_margin": bool, - "num_splits": int, - "aggregate": bool, - "atomic_gemm": bool, - "use_ce": bool, - "fp8_buf": bool, - } - } - ``` + Configuration dictionary with the structure:: + + { + : { + "method": <"ring_exchange" or "pipeline">, + "is_reduce_scatter": bool, + "num_sm": int, + "cga_size": int, + "set_sm_margin": bool, + "num_splits": int, + "aggregate": bool, + "atomic_gemm": bool, + "use_ce": bool, + "fp8_buf": bool, + } + } + for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", "fc2_fprop", "fc2_wgrad"]`. 
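A short, single-process usage sketch of the `parallel_cross_entropy` wrapper added in `transformer_engine/pytorch/cross_entropy.py` above may help; it relies only on the signature and shapes stated in that docstring, and the sizes, values, and CUDA device below are illustrative assumptions, not part of the patch.

# Usage sketch for parallel_cross_entropy, based solely on the docstring above:
# logits are (B, SQ, V), targets are (B, SQ); dist_process_group=None means the
# vocab dimension is not sharded across ranks. Illustrative sizes; CUDA assumed.
import torch
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy

B, SQ, V = 2, 8, 128
logits = torch.randn(B, SQ, V, dtype=torch.bfloat16, device="cuda", requires_grad=True)
targets = torch.randint(0, V, (B, SQ), device="cuda")

loss = parallel_cross_entropy(
    logits,
    targets,
    label_smoothing=0.0,
    reduce_loss=True,          # return the loss averaged over the B*SQ tokens
    dist_process_group=None,   # single device; pass a process group to shard V across ranks
    ignore_idx=-100,
)
loss.backward()                # per the docstring, input gradients come back in the input dtype (BF16 here)
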
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 05f2e9cde4..691e1be1ca 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -15,7 +15,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.utils import torch_version from transformer_engine.pytorch.tensor.utils import is_experimental from .base import ( fill_userbuffers_buffer_for_all_gather, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a2ddb970af..bc76516f3d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -16,7 +16,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.utils import torch_version from transformer_engine.pytorch.tensor.utils import is_experimental from .base import ( fill_userbuffers_buffer_for_all_gather, @@ -1435,8 +1435,9 @@ class LayerNormMLP(TransformerEngineBaseModule): type of normalization applied. activation : str, default = 'gelu' activation function used. - Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', - 'silu', and 'swiglu'. + + Options: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, + ``'srelu'``, ``'sreglu'``, ``'silu'``, and ``'swiglu'``. init_method : Callable, default = `None` used for initializing FC1 weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3069c21d9f..dbd0b19666 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -13,7 +13,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.utils import torch_version from .base import ( fill_userbuffers_buffer_for_all_gather, diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 52ca84b5df..be9ffbd41a 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -10,7 +10,7 @@ import torch from transformer_engine_torch import FP8TensorMeta -from .. import torch_version +from ..utils import torch_version from ..quantization import FP8GlobalStateManager from ..tensor.float8_tensor import Float8Tensor from ..tensor.quantized_tensor import QuantizedTensorStorage diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py index 440fee34d1..e299e75759 100644 --- a/transformer_engine/pytorch/ops/basic/l2normalization.py +++ b/transformer_engine/pytorch/ops/basic/l2normalization.py @@ -10,7 +10,7 @@ import torch -from ... 
import torch_version +from ...utils import torch_version from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...jit import ( l2normalization_fused, diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 030370b9db..27b6983f4b 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -26,8 +26,8 @@ NVFP4BlockScaling, CustomRecipe, ) - from .constants import dist_group_type + from .utils import get_device_compute_capability from .jit import jit_fuser diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 8a032b2f55..cf5d50cab9 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -10,7 +10,7 @@ import torch -from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.utils import torch_version from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention @@ -175,8 +175,9 @@ class TransformerLayer(torch.nn.Module): if set to `False`, the transformer layer will not learn any additive biases. activation : str, default = 'gelu' Type of activation used in MLP block. - Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', - 'silu', and 'swiglu'. + + Options: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, + ``'srelu'``, ``'sreglu'``, ``'silu'``, and ``'swiglu'``. device : Union[torch.device, str], default = "cuda" The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the @@ -559,16 +560,17 @@ def set_context_parallel_group( cuda stream for context parallel execution. cp_comm_type : str, default = `p2p` inter-gpu communication type for context parallelism. - Can be "p2p" or "all_gather" or "a2a", or "a2a+p2p". - "p2p": Exchange KV chunks with P2P communications in ring topology. - P2P is async and can be overlapped with attention compute. - "all_gather": All-gather to get full sequence of KV before attention. - The all-gather is not async, and cannot be overlapped. - "a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP - group, and gather to get full sequence of QKV. - "a2a+p2p": hierarchical CP implementation. First applying a2a to QKV - across each CP sub-group (e.g., via NVLink), then exchanging KV with - p2p between sub-groups (e.g., via IBLink). + Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``. + + - ``"p2p"``: Exchange KV chunks with P2P communications in ring topology. + P2P is async and can be overlapped with attention compute. + - ``"all_gather"``: All-gather to get full sequence of KV before attention. + The all-gather is not async, and cannot be overlapped. + - ``"a2a"``: Like DeepSpeed Ulysses, scatter attention heads across the CP + group, and gather to get full sequence of QKV. + - ``"a2a+p2p"``: hierarchical CP implementation. First applying a2a to QKV + across each CP sub-group (e.g., via NVLink), then exchanging KV with + p2p between sub-groups (e.g., via IBLink). """ # Deep iterate but skip self to avoid infinite recursion. 
for index, child in enumerate(self.modules()): diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 2be0aed4a8..e612820559 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -10,14 +10,19 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union import numpy as np import torch +from packaging.version import Version as PkgVersion -from . import torch_version from .tensor.quantized_tensor import Quantizer from ..debug.pytorch.debug_quantization import DebugQuantizedTensor __all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"] +@functools.lru_cache(maxsize=None) +def torch_version() -> tuple[int, ...]: + """Get PyTorch version""" + return PkgVersion(str(torch.__version__)).release + def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: """Check if any of the given tensors require gradient.""" From 2413bcbedb2b957fda09130c8d73fe9d9f188608 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 24 Oct 2025 11:12:52 +0000 Subject: [PATCH 02/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/conf.py | 6 +++--- transformer_engine/pytorch/utils.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 1f5679ca1b..da736811ea 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -106,12 +106,12 @@ autoapi_dirs = [root_path / "transformer_engine"] -# There are 2 warnings about the same namespace (transformer_engine) in two different c++ api +# There are 2 warnings about the same namespace (transformer_engine) in two different c++ api # docs pages. This seems to be the only way to suppress these warnings. 
def setup(app): """Custom Sphinx setup to filter warnings.""" import logging - + # Filter out duplicate C++ declaration warnings class DuplicateDeclarationFilter(logging.Filter): def filter(self, record): @@ -119,7 +119,7 @@ def filter(self, record): if "Duplicate C++ declaration" in message and "transformer_engine" in message: return False return True - + # Apply filter to Sphinx logger logger = logging.getLogger("sphinx") logger.addFilter(DuplicateDeclarationFilter()) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 98e1a314bc..0abc10e658 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -18,6 +18,7 @@ __all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"] + @functools.lru_cache(maxsize=None) def torch_version() -> tuple[int, ...]: """Get PyTorch version""" From 8dec7180ca3a446b780d808c2553851e2b5e68fe Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 24 Oct 2025 14:19:37 +0200 Subject: [PATCH 03/14] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 933c7cde53..6540eeb6f9 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -15,7 +15,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.utils import torch_version from transformer_engine.pytorch.tensor.utils import is_custom from .base import ( fill_userbuffers_buffer_for_all_gather, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 0a32449f0c..3a114483ea 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -16,7 +16,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.utils import torch_version from transformer_engine.pytorch.tensor.utils import is_custom from .base import ( fill_userbuffers_buffer_for_all_gather, From 3ff3ca04f771317382e95a492ffed468da730eb5 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 24 Oct 2025 14:30:50 +0200 Subject: [PATCH 04/14] lines lenght Signed-off-by: Pawel Gadzinski --- docs/debug/1_getting_started.rst | 12 ++++++------ docs/debug/2_config_file_structure.rst | 12 ++++++------ docs/debug/3_api_debug_setup.rst | 6 +++--- docs/debug/3_api_features.rst | 2 +- docs/debug/4_distributed.rst | 10 +++++----- docs/debug/api.rst | 2 +- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index 9950915427..a5cdc1a6b1 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -4,7 +4,7 @@ See LICENSE for license information. Getting started -=============================== +=============== .. note:: @@ -38,7 +38,7 @@ To start debugging, one needs to create a configuration YAML file. This file lis one - ``UserProvidedPrecision`` - is a custom feature implemented by the user. 
Nvidia-DL-Framework-Inspect inserts features into the layers according to the config. Example training script ------------------------------- +----------------------- Let's look at a simple example of training a Transformer layer using Transformer Engine with FP8 precision. This example demonstrates how to set up the layer, define an optimizer, and perform a few training iterations using synthetic data. @@ -81,7 +81,7 @@ We will demonstrate two debug features on the code above: 2. Logging statistics for other GEMM operations, such as gradient statistics for data gradient GEMM within the LayerNormLinear sub-layer of the TransformerLayer. Config file ------------------------------- +----------- We need to prepare the configuration YAML file, as below @@ -114,7 +114,7 @@ We need to prepare the configuration YAML file, as below Further explanation on how to create config files is in the :doc:`next part of the documentation <2_config_file_structure>`. Adjusting Python file ----------------------------- +--------------------- .. code-block:: python @@ -146,7 +146,7 @@ In the modified code above, the following changes were made: 3. Added ``debug_api.step()`` after each of the forward-backward pass. Inspecting the logs ----------------------------- +------------------- Let's look at the files with the logs. Two files will be created: @@ -215,7 +215,7 @@ The second log file (``nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank- INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000004 value=130776.7969 Logging using TensorBoard ----------------------------- +------------------------- Precision debug tools support logging using `TensorBoard `_. To enable it, one needs to pass the argument ``tb_writer`` to the ``debug_api.initialize()``. Let's modify ``train.py`` file. diff --git a/docs/debug/2_config_file_structure.rst b/docs/debug/2_config_file_structure.rst index 2d9334de48..d795d08be5 100644 --- a/docs/debug/2_config_file_structure.rst +++ b/docs/debug/2_config_file_structure.rst @@ -4,13 +4,13 @@ See LICENSE for license information. Config File Structure -=========================== +===================== To enable debug features, create a configuration YAML file to specify the desired behavior, such as determining which GEMMs (General Matrix Multiply operations) should run in higher precision rather than FP8 and defining which statistics to log. Below, we outline how to structure the configuration YAML file. General Format ----------------------------- +-------------- A config file can have one or more sections, each containing settings for specific layers and features: @@ -56,7 +56,7 @@ Sections may have any name and must contain: 3. Additional fields describing features for those layers. Layer Specification ----------------------------- +------------------- Debug layers can be identified by a ``name`` parameter: @@ -91,7 +91,7 @@ Examples: (...) Names in Transformer Layers --------------------------------- +--------------------------- There are three ways to assign a name to a layer in the Transformer Engine: @@ -157,7 +157,7 @@ Below is an example ``TransformerLayer`` with four linear layers that can be inf Structured Configuration for GEMMs and Tensors ------------------------------------------------------ +---------------------------------------------- Sometimes a feature is parameterized by a list of tensors or by a list of GEMMs. There are multiple ways of describing this parameterization. 
@@ -219,7 +219,7 @@ We can use both structs for tensors and GEMMs. The tensors_struct should be nest gemm_feature_param1: value Enabling or Disabling Sections and Features -------------------------------------------------- +------------------------------------------- Debug features can be enabled or disabled with the ``enabled`` keyword: diff --git a/docs/debug/3_api_debug_setup.rst b/docs/debug/3_api_debug_setup.rst index ccda556342..176bc13d32 100644 --- a/docs/debug/3_api_debug_setup.rst +++ b/docs/debug/3_api_debug_setup.rst @@ -11,7 +11,7 @@ Please refer to the Nvidia-DL-Framework-Inspect `documentation `_ for more details. @@ -62,7 +62,7 @@ If the tensor reduction group is not specified, then statistics are reduced acro # activation/gradient tensor statistics are reduced along pipeline_parallel_group set_weight_tensor_tp_group_reduce() ------------------------------------------ +----------------------------------- By default, weight tensor statistics are reduced within the tensor parallel group. This function allows you to disable that behavior; for more details, see `reduction group section <./4_distributed.rst#reduction-groups>`_. diff --git a/docs/debug/3_api_features.rst b/docs/debug/3_api_features.rst index ffb07c6ced..8cdbde8edd 100644 --- a/docs/debug/3_api_features.rst +++ b/docs/debug/3_api_features.rst @@ -4,7 +4,7 @@ See LICENSE for license information. Debug features -=========================== +============== .. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats .. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats diff --git a/docs/debug/4_distributed.rst b/docs/debug/4_distributed.rst index e9e3ade3d4..764fee6541 100644 --- a/docs/debug/4_distributed.rst +++ b/docs/debug/4_distributed.rst @@ -4,7 +4,7 @@ See LICENSE for license information. Distributed training -==================================== +==================== Nvidia-Pytorch-Inspect with Transformer Engine supports multi-GPU training. This guide describes how to run it and how the supported features work in the distributed setting. @@ -14,7 +14,7 @@ To use precision debug tools in multi-GPU training, one needs to: 2. If one wants to log stats, one may want to invoke ``debug_api.set_tensor_reduction_group`` with a proper reduction group. Behavior of the features ----------------------------- +------------------------ In a distributed setting, **DisableFP8GEMM** and **DisableFP8Layer** function similarly to the single-GPU case, with no notable differences. @@ -29,7 +29,7 @@ In a distributed setting, **DisableFP8GEMM** and **DisableFP8Layer** function si Logging-related features are more complex and will be discussed further in the next sections. Reduction groups ----------------------------- +---------------- In setups with tensor, data, or pipeline parallelism, some tensors are distributed across multiple GPUs, requiring a reduction operation to compute statistics for these tensors. @@ -67,7 +67,7 @@ Below, we illustrate configurations for a 4-node setup with tensor parallelism s Microbatching ----------------------------- +------------- Let's dive into how statistics collection works with microbatching. By microbatching, we mean invoking multiple ``forward()`` calls for each ``debug_api.step()``. The behavior is as follows: @@ -76,7 +76,7 @@ Let's dive into how statistics collection works with microbatching. By microbatc - For other tensors, the stats are accumulated. 
Logging to files and TensorBoard -------------------------------------------- +-------------------------------- In a single-node setup with ``default_logging_enabled=True``, all logs are saved by default to ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log``. In multi-GPU training, each node writes its reduced statistics to its unique file, named ``log_dir/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-i.log`` for rank i. Because these logs contain reduced statistics, the logged values are identical for all nodes within a reduction group. diff --git a/docs/debug/api.rst b/docs/debug/api.rst index 4e2cf99c67..6ccb32cc8b 100644 --- a/docs/debug/api.rst +++ b/docs/debug/api.rst @@ -4,7 +4,7 @@ See LICENSE for license information. API -============ +=== .. toctree:: :caption: Precision debug tools API From 81180058eae1926d759f8b0f64f4f5d5cbaedfea Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 24 Oct 2025 15:29:18 +0200 Subject: [PATCH 05/14] fix Signed-off-by: Pawel Gadzinski --- .github/workflows/docs.yml | 4 +- docs/conf.py | 1 + transformer_engine/jax/flax/transformer.py | 41 ++++++++++++------- transformer_engine/pytorch/__init__.py | 9 +--- .../dot_product_attention.py | 15 ++++--- .../pytorch/attention/multi_head_attention.py | 15 ++++--- transformer_engine/pytorch/transformer.py | 15 ++++--- 7 files changed, 59 insertions(+), 41 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index f4a1d4a2e6..f5c20e2eee 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -21,8 +21,8 @@ jobs: pip install breathe==4.35.0 sphinx-autoapi==3.3.2 sudo apt-get install -y pandoc graphviz doxygen export GIT_SHA=$(git show-ref --hash HEAD) - - name: 'Build docs' - run: | + - name: 'Build docs' + run: | # SPHINXOPTS="-W" errors out on warnings doxygen docs/Doxyfile cd docs make html SPHINXOPTS="-W" diff --git a/docs/conf.py b/docs/conf.py index da736811ea..dc791e43d3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -63,6 +63,7 @@ templates_path = ["_templates"] exclude_patterns = [ "_build", + "Thumbs.db", "sphinx_rtd_theme", ] diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index e11a625249..fa269a05a3 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -453,22 +453,32 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods * causal_padding / padding_causal: A combination of both causal and padding masks. Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect. + | + .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. + | + .. note:: THD format only supports 'padding' or 'causal_padding' mask type. + | + .. 
table:: :widths: auto - ===================== ============================ ========== ================================= - attn_mask_type mask/sequence_descriptor SWA softmax type - ===================== ============================ ========== ================================= - no_mask None None SCALED - causal None None SCALED_UPPER_TRIANG_MASKED - causal None Yes SCALED_MASKED - padding Required Yes/No SCALED_MASKED - padding_causal Required Yes/No SCALED_MASKED - ===================== ============================ ========== ================================= + ================== ============ ========== ============================== + attn_mask_type mask/sd SWA softmax type + ================== ============ ========== ============================== + no_mask None None SCALED + causal None None SCALED_UPPER_TRIANG_MASKED + causal None Yes SCALED_MASKED + padding Required Yes/No SCALED_MASKED + padding_causal Required Yes/No SCALED_MASKED + ================== ============ ========== ============================== + + where sd stands for sequence_descriptor. + + attn_bias_type: Optional[str], default = None Type of the attention bias passed in the attention. @@ -513,11 +523,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods Sliding window size. The default value is no sliding window. max_segments_per_seq: Optional[int], default = 1 The maximum number of segments per sequence, also used for THD format (sequence packing). - context_parallel_causal_load_balanced (bool): - Indicates the sequences are ordered for causal mask load balancing when running context parallelism. - context_parallel_axis (str): The name of the context parallel axis. - context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING. - context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention. + context_parallel_causal_load_balanced: bool + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis: str + The name of the context parallel axis. + context_parallel_strategy: CPStrategy + The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING. + context_checkpoint_name: str + The name of the context checkpoint in the forward pass of fused attention. Optimization parameters ----------------------- diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 9d894a389b..f5c9e3f157 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -14,13 +14,8 @@ from transformer_engine.common import load_framework_extension -@functools.lru_cache(maxsize=None) -def torch_version() -> tuple[int, ...]: - """Get PyTorch version""" - return PkgVersion(str(torch.__version__)).release - - -assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}." +torch_version = PkgVersion(str(torch.__version__)).release +assert torch_version >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version}." 
load_framework_extension("torch") diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index fa4d607197..4d70b14246 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -245,14 +245,17 @@ class DotProductAttention(TransformerEngineBaseModule): softmax scale for the attention scores. If `None`, defaults to `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`. softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' - softmax type as described in this paper: + Softmax type as described in the paper `Efficient Streaming Language Models with Attention Sinks `_. - For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], - 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), - 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and - 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), - where alpha is a learnable parameter in shape [h]. + + For a given attention score ``S = Q x K^T``, of shape ``[b, h, s_q, s_kv]``: + + * 'vanilla': ``S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1)`` + * 'off-by-one': ``S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1))`` + * 'learnable': ``S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1))`` + + where ``alpha`` is a learnable parameter in shape ``[h]``. 'off-by-one' and 'learnable' softmax types are also called sink attention ('zero sink' and 'learnable sink'). diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 793a0f56ad..b61ed35871 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -143,14 +143,17 @@ class MultiheadAttention(torch.nn.Module): name: str, default = `None` name of the module, currently used for debugging purposes. softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' - softmax type as described in this paper: + Softmax type as described in the paper `Efficient Streaming Language Models with Attention Sinks `_. - For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], - 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), - 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and - 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), - where alpha is a learnable parameter in shape [h]. + + For a given attention score ``S = Q x K^T``, of shape ``[b, h, s_q, s_kv]``: + + * 'vanilla': ``S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1)`` + * 'off-by-one': ``S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1))`` + * 'learnable': ``S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1))`` + + where ``alpha`` is a learnable parameter in shape ``[h]``. 'off-by-one' and 'learnable' softmax types are also called sink attention ('zero sink' and 'learnable sink'). 
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index cf5d50cab9..26429f1470 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -193,14 +193,17 @@ class TransformerLayer(torch.nn.Module): name: str, default = `None` name of the module, currently used for debugging purposes. softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' - softmax type as described in this paper: + Softmax type as described in the paper `Efficient Streaming Language Models with Attention Sinks `_. - For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], - 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), - 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and - 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), - where alpha is a learnable parameter in shape [h]. + + For a given attention score ``S = Q x K^T``, of shape ``[b, h, s_q, s_kv]``: + + * 'vanilla': ``S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1)`` + * 'off-by-one': ``S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1))`` + * 'learnable': ``S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1))`` + + where ``alpha`` is a learnable parameter in shape ``[h]``. 'off-by-one' and 'learnable' softmax types are also called sink attention ('zero sink' and 'learnable sink'). From 1a9b9937afd2f0f6865f61dcc3c5f8673e93cb5f Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 30 Oct 2025 16:02:42 +0100 Subject: [PATCH 06/14] fix Signed-off-by: Pawel Gadzinski --- docs/api/pytorch.rst | 2 +- docs/examples/attention/attention.ipynb | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index c934e89653..0efa682504 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -85,7 +85,7 @@ Mixture of Experts (MoE) functions .. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs -GEMM Comm overlap +Communication-computation overlap --------------------- .. autoapifunction:: transformer_engine.pytorch.initialize_ub diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 8591ce218f..3eed9f2f7d 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -16,8 +16,8 @@ "\n", "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in two frameworks, [PyTorch](https://github.com/pytorch/pytorch) and [JAX](https://github.com/google/jax). 
The API for each framework is\n", "\n", - "- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.dotproductattention)\n", - "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.dotproductattention)" + "- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)\n", + "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)" ] }, { @@ -174,7 +174,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "50852cb5", "metadata": {}, "outputs": [ @@ -266,7 +266,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "id": "906b8cf1", "metadata": {}, "outputs": [ @@ -299,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "id": "d3637094", "metadata": {}, "outputs": [ @@ -521,7 +521,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, "id": "a1f25a9b", "metadata": {}, "outputs": [ @@ -606,7 +606,7 @@ "\n", "A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n", "\n", - "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.dotproductattention), as of v2.0. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", + "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v2.0. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", "\n", "- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n", "\n", From 0116d34f44037757301833e16ed3fb374962ac5f Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 30 Oct 2025 16:16:28 +0100 Subject: [PATCH 07/14] fix Signed-off-by: Pawel Gadzinski --- docs/api/pytorch.rst | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/transformer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 0efa682504..fd6b2665e9 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -86,7 +86,7 @@ Mixture of Experts (MoE) functions Communication-computation overlap ---------------------- +--------------------------------- .. 
autoapifunction:: transformer_engine.pytorch.initialize_ub diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index b8a0de5808..a4e8929b6e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1447,7 +1447,7 @@ class LayerNormMLP(TransformerEngineBaseModule): activation : str, default = 'gelu' activation function used. Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', - 'silu', 'swiglu', and 'clamped_swiglu'. + 'silu', 'swiglu', and 'clamped_swiglu'. activation_params : dict, default = `None` Additional parameters for the activation function. At the moment, only used for 'clamped_swiglu' activation which diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 9f49a7fc46..60fd024d71 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -176,7 +176,7 @@ class TransformerLayer(torch.nn.Module): activation : str, default = 'gelu' Type of activation used in MLP block. Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', - 'silu', 'swiglu', and 'clamped_swiglu'. + 'silu', 'swiglu', and 'clamped_swiglu'. activation_params : Optional[dict], default = `None` Additional parameters for the activation function. At the moment, only used for 'clamped_swiglu' activation which From 20ba71918b1bead0b3acaf3f750a34954a65ca47 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 4 Nov 2025 15:45:34 +0100 Subject: [PATCH 08/14] subtitle --- fix in many files: Signed-off-by: Pawel Gadzinski --- docs/api/jax.rst | 8 ++++---- docs/api/pytorch.rst | 6 +++--- docs/debug.rst | 2 +- docs/index.rst | 2 +- docs/installation.rst | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/api/jax.rst b/docs/api/jax.rst index 789b27e59c..3be15b154d 100644 --- a/docs/api/jax.rst +++ b/docs/api/jax.rst @@ -4,10 +4,10 @@ See LICENSE for license information. Jax -======= +=== Pre-defined Variable of Logical Axes ------------------------------------- +------------------------------------- Variables are available in `transformer_engine.jax.sharding`. * BATCH_AXES: The logical axis of batch dimension. It is usually sharded along DP + FSDP on Mesh. @@ -20,11 +20,11 @@ Variables are available in `transformer_engine.jax.sharding`. Checkpointing ------------------------------------- +------------- When using checkpointing with Transformer Engine JAX, please be aware of the checkpointing policy being applied to your model. Any JAX checkpointing policy using `dot`, such as `jax.checkpoint_policies.dots_with_no_batch_dims`, may not work with GEMMs provided by Transformer Engine as they do not always use the `jax.lax.dot_general` primitive. Instead, you can use `transformer_engine.jax.checkpoint_policies.dots_and_te_gemms_with_no_batch_dims` or similar policies that are designed to work with Transformer Engine's GEMMs and `jax.lax.dot_general` GEMMs. You may also use any JAX policies that do not filter by primitive, such as `jax.checkpoint_policies.save_only_these_names` or `jax.checkpoint_policies.everything_saveable`. Modules ------------------------------------- +------- .. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType .. 
autoapiclass:: transformer_engine.jax.MeshResource() diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index fd6b2665e9..7385af8a98 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -51,7 +51,7 @@ pyTorch Recipe availability ------------------------- +------------------- .. autoapifunction:: transformer_engine.pytorch.is_fp8_available @@ -70,7 +70,7 @@ Recipe availability .. autoapifunction:: transformer_engine.pytorch.get_default_recipe Mixture of Experts (MoE) functions ------------------------------------------- +---------------------------------- .. autoapifunction:: transformer_engine.pytorch.moe_permute @@ -97,7 +97,7 @@ Communication-computation overlap Deprecated functions ---------------------- +-------------------- .. autoapifunction:: transformer_engine.pytorch.fp8_autocast diff --git a/docs/debug.rst b/docs/debug.rst index 20ab69d00c..527f30ed02 100644 --- a/docs/debug.rst +++ b/docs/debug.rst @@ -4,7 +4,7 @@ See LICENSE for license information. Precision debug tools -============================================== +===================== .. toctree:: :caption: Precision debug tools diff --git a/docs/index.rst b/docs/index.rst index 2c04810f4d..c14a21e2e7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -4,7 +4,7 @@ See LICENSE for license information. Transformer Engine documentation -============================================== +================================= .. ifconfig:: "dev" in release diff --git a/docs/installation.rst b/docs/installation.rst index a8bb74fd1a..24563c456e 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -28,7 +28,7 @@ on `NVIDIA GPU Cloud `_. pip - from PyPI ------------------------ +--------------- Transformer Engine can be directly installed from `our PyPI `_, e.g. @@ -47,7 +47,7 @@ The core package from Transformer Engine (without any framework extensions) can By default, this will install the core library compiled for CUDA 12. The cuda major version can be specified by modified the extra dependency to `core_cu12` or `core_cu13`. pip - from GitHub ------------------------ +----------------- Additional Prerequisites ^^^^^^^^^^^^^^^^^^^^^^^^ From 19da61bcdfeb1b1e715b467362f855a10f8c75af Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 4 Nov 2025 15:54:04 +0100 Subject: [PATCH 09/14] cross entropy _input -> input rename Signed-off-by: Pawel Gadzinski --- docs/api/pytorch.rst | 3 +- transformer_engine/pytorch/cross_entropy.py | 37 ++++++++++++++------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 7385af8a98..1e0ec0fe01 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -49,6 +49,7 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context +.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy Recipe availability ------------------- @@ -80,8 +81,6 @@ Mixture of Experts (MoE) functions .. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index -.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy - .. 
autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py index 2de063ba47..caedc0d2b3 100644 --- a/transformer_engine/pytorch/cross_entropy.py +++ b/transformer_engine/pytorch/cross_entropy.py @@ -5,6 +5,7 @@ """Cross Entropy Loss API""" from typing import Optional +import warnings import torch @@ -25,7 +26,7 @@ class CrossEntropyFunction(torch.autograd.Function): @staticmethod def forward( ctx, - _input, + input, target, label_smoothing=0.0, reduce_loss=False, @@ -39,7 +40,7 @@ def forward( Parameters: ctx : The context object. - _input (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size. + input (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size. target (tensor): The target tensor of shape (B,SQ) or (SQ, B) where each value is in [0, V-1]. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension. @@ -49,8 +50,8 @@ def forward( Returns: tensor: The computed loss. """ - loss, _input = triton_cross_entropy.cross_entropy_forward( - _input, + loss, input = triton_cross_entropy.cross_entropy_forward( + input, target, label_smoothing, reduce_loss, @@ -58,7 +59,7 @@ def forward( ignore_idx, ) - ctx.save_for_backward(_input.detach()) + ctx.save_for_backward(input.detach()) ctx.is_cg_capturable = is_cg_capturable return loss @@ -74,12 +75,12 @@ def backward(ctx, grad_output): Returns: tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. """ - (_input,) = ctx.saved_tensors - _input = triton_cross_entropy.cross_entropy_backward( - _input, grad_output, ctx.is_cg_capturable + (input,) = ctx.saved_tensors + input = triton_cross_entropy.cross_entropy_backward( + input, grad_output, ctx.is_cg_capturable ) return ( - _input, + input, None, None, None, @@ -90,13 +91,15 @@ def backward(ctx, grad_output): def parallel_cross_entropy( - _input: torch.Tensor, + input: torch.Tensor, target: torch.Tensor, label_smoothing: float = 0.0, reduce_loss: bool = False, dist_process_group: Optional[torch.distributed.ProcessGroup] = None, ignore_idx: int = -100, is_cg_capturable: bool = False, + *, + _input: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Cross Entropy loss with optional distributed reduction. @@ -111,7 +114,7 @@ def parallel_cross_entropy( Parameters ---------- - _input : torch.Tensor + input : torch.Tensor The input tensor of shape ``(B, SQ, V)`` or ``(SQ, B, V)`` where B is batch size, SQ is sequence length, V is vocab size. target : torch.Tensor @@ -132,8 +135,18 @@ def parallel_cross_entropy( torch.Tensor The computed loss. """ + # Handle backward compatibility with _input parameter + if _input is not None: + warnings.warn( + "The '_input' parameter is deprecated and will be removed in a future version. 
" + "Please use 'input' instead.", + FutureWarning, + stacklevel=2, + ) + input = _input + return CrossEntropyFunction.apply( - _input, + input, target, label_smoothing, reduce_loss, From 15c6741d26f53be134cfd457e276e5445965113b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 4 Nov 2025 15:55:33 +0100 Subject: [PATCH 10/14] cross entropy _input -> input rename Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/cross_entropy.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py index caedc0d2b3..fac9730aaf 100644 --- a/transformer_engine/pytorch/cross_entropy.py +++ b/transformer_engine/pytorch/cross_entropy.py @@ -138,10 +138,9 @@ def parallel_cross_entropy( # Handle backward compatibility with _input parameter if _input is not None: warnings.warn( - "The '_input' parameter is deprecated and will be removed in a future version. " + "The '_input' parameter is deprecated. " "Please use 'input' instead.", FutureWarning, - stacklevel=2, ) input = _input From 556ab28e407bdfc9f340f06da7e43574657181c5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 4 Nov 2025 14:57:02 +0000 Subject: [PATCH 11/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/cross_entropy.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py index fac9730aaf..5bfec4608b 100644 --- a/transformer_engine/pytorch/cross_entropy.py +++ b/transformer_engine/pytorch/cross_entropy.py @@ -138,12 +138,11 @@ def parallel_cross_entropy( # Handle backward compatibility with _input parameter if _input is not None: warnings.warn( - "The '_input' parameter is deprecated. " - "Please use 'input' instead.", + "The '_input' parameter is deprecated. Please use 'input' instead.", FutureWarning, ) input = _input - + return CrossEntropyFunction.apply( input, target, From 35e75de9705d13ad3e13d6636fda06ee57e277b2 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 4 Nov 2025 16:29:39 +0100 Subject: [PATCH 12/14] fix Signed-off-by: Pawel Gadzinski --- docs/api/jax.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/jax.rst b/docs/api/jax.rst index 3be15b154d..99782f99c7 100644 --- a/docs/api/jax.rst +++ b/docs/api/jax.rst @@ -7,7 +7,7 @@ Jax === Pre-defined Variable of Logical Axes -------------------------------------- +------------------------------------ Variables are available in `transformer_engine.jax.sharding`. * BATCH_AXES: The logical axis of batch dimension. It is usually sharded along DP + FSDP on Mesh. 
From 0a5eafdfaa275aeee19c8d5a991fae583b4770f2 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 4 Nov 2025 22:41:13 +0100 Subject: [PATCH 13/14] a lot of small fixes Signed-off-by: Pawel Gadzinski --- docs/api/pytorch.rst | 2 +- docs/conf.py | 1 + docs/examples/advanced_optimizations.ipynb | 4 +- docs/examples/attention/attention.ipynb | 4 +- .../common/fused_attn/kv_cache.cu | 2 +- .../include/transformer_engine/fused_attn.h | 2 +- transformer_engine/common/recipe/__init__.py | 14 +- transformer_engine/jax/cpp_extensions/misc.py | 2 +- transformer_engine/jax/flax/module.py | 94 +++--- transformer_engine/jax/flax/transformer.py | 152 ++++----- .../dot_product_attention.py | 277 +++++++-------- .../attention/dot_product_attention/utils.py | 78 ++--- .../pytorch/attention/inference.py | 10 +- .../pytorch/attention/multi_head_attention.py | 267 ++++++++------- transformer_engine/pytorch/attention/rope.py | 2 +- .../pytorch/cpp_extensions/fused_attn.py | 4 +- transformer_engine/pytorch/cpu_offload.py | 8 +- transformer_engine/pytorch/cross_entropy.py | 28 +- transformer_engine/pytorch/distributed.py | 32 +- transformer_engine/pytorch/export.py | 2 +- transformer_engine/pytorch/graph.py | 22 +- transformer_engine/pytorch/module/base.py | 28 +- .../pytorch/module/grouped_linear.py | 44 +-- .../pytorch/module/layernorm.py | 11 +- .../pytorch/module/layernorm_linear.py | 58 ++-- .../pytorch/module/layernorm_mlp.py | 82 ++--- transformer_engine/pytorch/module/linear.py | 58 ++-- transformer_engine/pytorch/module/rmsnorm.py | 13 +- transformer_engine/pytorch/quantization.py | 20 +- .../pytorch/quantized_tensor.py | 4 +- transformer_engine/pytorch/transformer.py | 316 +++++++++--------- 31 files changed, 841 insertions(+), 800 deletions(-) diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 1e0ec0fe01..eb6a4eefd1 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -3,7 +3,7 @@ See LICENSE for license information. -pyTorch +PyTorch ======= .. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs) diff --git a/docs/conf.py b/docs/conf.py index dc791e43d3..beb16bf2ec 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -98,6 +98,7 @@ ("Values", "params_style"), ("Graphing parameters", "params_style"), ("FP8-related parameters", "params_style"), + ("Quantization parameters", "params_style"), ] breathe_projects = {"TransformerEngine": root_path / "docs" / "doxygen" / "xml"} diff --git a/docs/examples/advanced_optimizations.ipynb b/docs/examples/advanced_optimizations.ipynb index 5dc9cb92f9..7c08bb6586 100644 --- a/docs/examples/advanced_optimizations.ipynb +++ b/docs/examples/advanced_optimizations.ipynb @@ -100,7 +100,7 @@ "\n", "\n", "\n", - "A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\times \\text{batch_size} \\times \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. 
This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n", + "A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models, often based on different approaches to distribute their $\\text{sequence_length} \\cdot \\text{batch_size} \\cdot \\text{hidden_size}$ activation tensors. The most common approach is data parallelism, which distributes along the $\\text{batch_size}$ dimension. By storing duplicate copies of the model on each GPU, the forward and backward passes of the training step can be done independently, followed by a gradient synchronization. A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the $\\text{hidden_size}$ dimension. This allows us to scale past the limits of data parallelism (typically $\\text{hidden_size} > \\text{batch_size}$) and to reduce the per-GPU memory usage (since model parameters are also distributed), but it also incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Finally, sequence parallelism distributes along the $\\text{sequence_length}$ dimension. This can be used when tensor parallelism is enabled in order to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see [this paper](https://arxiv.org/pdf/2205.05198.pdf).\n", "\n", "To show this in action, let's first initialize NCCL with a trivial process group:" ] @@ -131,7 +131,7 @@ "id": "1f2b80d0", "metadata": {}, "source": [ - "We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\times \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n", + "We only initialize with one GPU to keep this example simple. Please consult the documentation [torch.distributed](https://pytorch.org/docs/stable/distributed.html) for guidance on running with multiple GPUs. Note that we require that each distributed process corresponds to exactly one GPU, so we treat them interchangeably. In practice, there are multiple factors that can affect the optimal parallel layout: the system hardware, the network topology, usage of other parallelism schemes like pipeline parallelism. 
A rough rule-of-thumb is to interpret the GPUs as a 2D grid with dimensions of $\\text{num_nodes} \\cdot \\text{gpus_per_node}$. The rows are tensor-parallel groups and the columns are data-parallel groups.\n", "\n", "Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group.\n", "\n", diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 3eed9f2f7d..4b2ed80497 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -509,10 +509,10 @@ "\n", "* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n", " - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n", - " - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n", + " - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor of shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors of shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n", "\n", "\n", - "* JAX: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n", + "* JAX: Users should provide the `attention_mask` tensor of shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n", "\n", "**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n", "\n", diff --git a/transformer_engine/common/fused_attn/kv_cache.cu b/transformer_engine/common/fused_attn/kv_cache.cu index 67119c323b..3b78cab239 100644 --- a/transformer_engine/common/fused_attn/kv_cache.cu +++ b/transformer_engine/common/fused_attn/kv_cache.cu @@ -278,7 +278,7 @@ void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, in /*************************************************************************************************** * KV Cache: Copy new KV tokens to the KV cache * 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format - * 2. 
cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens + * 2. cu_new_lens and cu_cached_lens are of shape [b + 1]; cu_cached_lens include the added lens * in current step * 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and * max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged. diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 518fad20de..cefbcb7ce8 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -131,7 +131,7 @@ enum NVTE_Mask_Type { * NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), * NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and * NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), - * where alpha is a learnable parameter in shape [H]. + * where alpha is a learnable parameter of shape [H]. */ enum NVTE_Softmax_Type { /*! Vanilla softmax */ diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 7bc39f0745..204dfa1829 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -50,7 +50,7 @@ class MMParams: Parameters ---------- - use_split_accumulator : bool, default = `True` + use_split_accumulator : bool, default = True Use FP8 fast accumulation on Hopper or Ada. For more details, see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul. """ @@ -159,7 +159,7 @@ def scaling_factor_compute(amax: Tensor, recipe: DelayedScaling) -> Tensor where `Tensor` is a framework tensor type. - reduce_amax: bool, default = `True` + reduce_amax: bool, default = True By default, if `torch.distributed` is initialized, the `amax` value for FP8 tensors is reduced across the `amax_reduction_group` (specified in the `autocast` call). This keeps the amaxes and scaling factors synced across the given @@ -167,13 +167,13 @@ def scaling_factor_compute(amax: Tensor, GPU maintains local amaxes and scaling factors. To ensure results are numerically identical across checkpointing boundaries in this case, all ranks must checkpoint in order to store the local tensors. - fp8_dpa: bool, default = `False` + fp8_dpa: bool, default = False Whether to enable FP8 dot product attention (DPA). When the model is placed in an `autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the inputs from higher precision to FP8, performs attention in FP8, and casts tensors back to higher precision as outputs. FP8 DPA currently is only supported in the `FusedAttention` backend. - fp8_mha: bool, default = `False` + fp8_mha: bool, default = False Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting operations mentioned above at the DPA boundaries. Currently only standard MHA modules i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When @@ -422,11 +422,11 @@ class NVFP4BlockScaling(Recipe): ---------- fp4_format : {Format.E2M1}, default = Format.E2M1 FP4 data type. - disable_rht : bool, default = `False` + disable_rht : bool, default = False If set to `True`, random Hadamard transforms are not applied to any tensor. 
- disable_stochastic_rounding : bool, default = `False` + disable_stochastic_rounding : bool, default = False If set to `True`, stochastic rounding is disabled during quantization for all tensors. - disable_2d_quantization : bool, default = `False` + disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. """ diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 93ec1d00c3..89225b1c2a 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -117,7 +117,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1): transpose. Note, transpose_axis should be greater than static_axis_boundary examples: - X in shape (dim0, dim1, dim2, dim3, dim4) + X of shape (dim0, dim1, dim2, dim3, dim4) static_axis_boundary == -1, transpose_axis == 2 Xt = (dim2, dim3, dim4, dim0, dim1) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index c54ecb236f..15864549c8 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -252,26 +252,26 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm' Indicate the type of layer normalization. zero_centered_gamma : bool, default = False - If set to `True`, the LayerNorm formula changes to + If set to ``True``, the LayerNorm formula changes to .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot (1 + \gamma) + \beta - This parameter is only applicable for 'layernorm'. - The default of `scale_init` will also be changed. See `scale_init`. + This parameter is only applicable for ``'layernorm'``. + The default of ``scale_init`` will also be changed. See ``scale_init``. scale_init : Initializer, default = None Used for initializing scale factors :math:`\gamma`. - If `None` is provided, scale_init is set according to the value of zero_centered_gamma. - If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`. - Otherwise, scale_init is `flax.linen.initializers.ones`. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma. + If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``. + Otherwise, scale_init is ``flax.linen.initializers.ones``. + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. scale_axes : Tuple[str, ...], default = ('embed', ) The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh. bias_init : Initializer, default = flax.linen.initializers.zeros Used for initializing shift factors :math:`\beta`, only used when :attr:`layernorm_type='layernorm'`. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. bias_axes : Tuple[str, ...], default = ('embed', ) The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh. only used when :attr:`layernorm_type='layernorm'`. 
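
The ``zero_centered_gamma`` formula documented in the ``LayerNorm`` hunk above can be checked numerically with plain ``jax.numpy``; the sketch below is illustrative only and does not go through Transformer Engine::

    import jax.numpy as jnp

    def layernorm_zero_centered_gamma(x, gamma, beta, eps=1e-6):
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        y = (x - mean) / jnp.sqrt(var + eps)
        # gamma is stored zero-centered, so the effective scale is (1 + gamma);
        # with the default ``flax.linen.initializers.zeros`` this starts out as an
        # identity scale, matching the ``scale_init`` note above.
        return y * (1.0 + gamma) + beta

    x = jnp.linspace(-1.0, 1.0, 32).reshape(4, 8)
    out = layernorm_zero_centered_gamma(x, gamma=jnp.zeros(8), beta=jnp.zeros(8))

With zero-initialized ``gamma`` and ``beta`` the layer initially applies plain normalization, mirroring the behaviour of a ones-initialized scale in the non-zero-centered case.
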
@@ -391,15 +391,15 @@ class DenseGeneral(TransformerEngineBase): kernel_init : Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') Used for initializing weights. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. kernel_axes : Tuple[str, ...], default = () The name of axes used to shard the weights with a corresponding mesh. use_bias: bool, default = False Indicate whether to enable bias shifting. - If set to False, the layer will not learn an additive bias. + If set to ``False``, the layer will not learn an additive bias. bias_init: Initializer, default = flax.linen.initializers.zeros Used for initializing bias, only used when :attr:`use_bias=True`. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. bias_axes: Tuple[str, ...], default = () The name of axes used to shard bias with a corresponding mesh, only used when :attr:`use_bias=True`. @@ -410,12 +410,12 @@ class DenseGeneral(TransformerEngineBase): :attr:`enable_low_rank_adaptation=True` low_rank_adaptation_alpha: float, default = None The alpha for computing the scaling factor of LoRA output. - :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. + :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling. axis: Union[Iterable[int], int], default = -1 An integer tuple with axes to apply the transformation on. input_axes: Tuple[str, ...], default = None Indicate the logical axes of sharding constraint to the input, like - (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert + ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert sharding constraint. Optimization parameters @@ -558,48 +558,48 @@ class LayerNormDenseGeneral(TransformerEngineBase): epsilon : float, default = 1e-6 A value added to the denominator of layer normalization for numerical stability. zero_centered_gamma : bool, default = False - If set to `True`, the LayerNorm formula changes to + If set to ``True``, the LayerNorm formula changes to .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot (1 + \gamma) + \beta - This parameter is only applicable for 'layernorm'. - The default of `scale_init` will also be changed. See `scale_init` + This parameter is only applicable for ``'layernorm'``. + The default of ``scale_init`` will also be changed. See ``scale_init`` scale_init : Initializer, default = None Used for initializing scale factors :math:`\gamma`. - If `None` is provided, scale_init is set according to the value of zero_centered_gamma. - If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`. - Otherwise, scale_init is `flax.linen.initializers.ones`. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma. + If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``. + Otherwise, scale_init is ``flax.linen.initializers.ones``. + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. 
scale_axes : Tuple[str, ...], default = ('embed', ) The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh, only used when :attr:`enable_layernorm=True`. ln_bias_init: Initializer, default = flax.linen.initializers.zeros Used for initializing shift factors :math:`\beta`, only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. ln_bias_axes: Tuple[str, ...], default = ('embed', ) The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh. It is only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`. kernel_init : Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') Used for initializing weights. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. kernel_axes : Tuple[str, ...], default = () The name of axes used to shard the weights with a corresponding mesh. use_bias: bool, default = False Indicate whether to enable bias shifting. - If set to False, the layer will not learn an additive bias. + If set to ``False``, the layer will not learn an additive bias. bias_init: Initializer, default = flax.linen.initializers.zeros Used for initializing bias, only used when :attr:`use_bias=True`. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. bias_axes: Tuple[str, ...], default = () The name of axes used to shard bias with a corresponding mesh, only used when :attr:`use_bias=True`. return_layernorm_output: bool, default = True Indicate whether to return the output of layer normalization. - If set False, return None as the second tensor in outputs. + If set ``False``, return ``None`` as the second tensor in outputs. enable_low_rank_adaptation: bool, default = False Indicate whether to enable low rank adaptation for each dense layer. low_rank_adaptation_dim: int, default = 32 @@ -607,16 +607,16 @@ class LayerNormDenseGeneral(TransformerEngineBase): :attr:`enable_low_rank_adaptation=True` low_rank_adaptation_alpha: float, default = None The alpha for computing the scaling factor of LoRA output. - :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. + :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling. axis: Union[Iterable[int], int], default = -1 An integer tuple with axes to apply the transformation on. layernorm_input_axes: Tuple[str, ...], default = None Indicate the logical axes of sharding constraint to the input of layernorm, like - (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert + ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert sharding constraint. dot_input_axes: Tuple[str, ...], default = None Indicate the logical axes of sharding constraint to the input of dot, like - (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert + ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert sharding constraint. 
Optimization parameters @@ -843,34 +843,34 @@ class LayerNormMLP(TransformerEngineBase): epsilon : float, default = 1e-6 A value added to the denominator of layer normalization for numerical stability. zero_centered_gamma : bool, default = False - If set to `True`, the LayerNorm formula changes to + If set to ``True``, the LayerNorm formula changes to .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot (1 + \gamma) + \beta - This parameter is only applicable for 'layernorm'. - The default of `scale_init` will also be changed. See `scale_init`. + This parameter is only applicable for ``'layernorm'``. + The default of ``scale_init`` will also be changed. See ``scale_init``. scale_init : Initializer, default = None Used for initializing scale factors :math:`\gamma`. - If `None` is provided, scale_init is set according to the value of zero_centered_gamma. - If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`. - Otherwise, scale_init is `flax.linen.initializers.ones`. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + If ``None`` is provided, scale_init is set according to the value of zero_centered_gamma. + If zero_centered_gamma is set to ``True``, then scale_init is ``flax.linen.initializers.zeros``. + Otherwise, scale_init is ``flax.linen.initializers.ones``. + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. scale_axes : Tuple[str, ...], default = ('embed', ) The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh, only used when :attr:`enable_layernorm=True`. ln_bias_init: Initializer, default = flax.linen.initializers.zeros Used for initializing shift factors :math:`\beta`, only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. ln_bias_axes: Tuple[str, ...], default = ('embed', ) The name of axes used to shard the shift factors :math:`\beta` with a corresponding mesh. Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`. kernel_init : Initializer, default = flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') Used for initializing the weights of both dense layer transformations. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp') The name of axes used to shard the weights with a corresponding mesh for the weight of the first dense layer transformation. @@ -879,10 +879,10 @@ class LayerNormMLP(TransformerEngineBase): the weight of the second dense layer transformation. use_bias: bool, default = False Indicate whether to enable bias shifting. - If set to False, the layer will not learn an additive bias. + If set to ``False``, the layer will not learn an additive bias. bias_init: Initializer, default = flax.linen.initializers.zeros Used for initializing bias, only used when :attr:`use_bias=True`. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. 
bias_axes_1: Tuple[str, ...], default = ('mlp',) The name of axes used to shard bias with a corresponding mesh for the weight of the first dense layer transformation. @@ -893,7 +893,7 @@ class LayerNormMLP(TransformerEngineBase): Only used when :attr:`use_bias=True`. return_layernorm_output: bool, default = True Indicate whether to return the output of layer normalization. - If set False, return None as the second tensor in outputs. + If set ``False``, return ``None`` as the second tensor in outputs. activations: Sequence[Union[str, Callable]], default = ('relu',) The sequence of activation functions to apply after the first dense layer transformation. Each activation has its own transformation layer. @@ -914,20 +914,20 @@ class LayerNormMLP(TransformerEngineBase): :attr:`enable_low_rank_adaptation=True`. low_rank_adaptation_alpha: float, default = None The alpha for computing the scaling factor of LoRA output. - :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. + :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling. axis: Union[Iterable[int], int], default = -1 An integer tuple with axes to apply the transformation on. layernorm_input_axes: Tuple[str, ...], default = None Indicate the logical axes of sharding constraint to the input of layernorm, like - (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert + ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert sharding constraint. dot_1_input_axes: Tuple[str, ...], default = None Indicate the logical axes of sharding constraint to the input of 1st dot, like - (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert + ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert sharding constraint. dot_2_input_axes: Tuple[str, ...], default = None Indicate the logical axes of sharding constraint to the input of 2nd dot, like - (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert + ``(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)``. Default is ``None``, which means not to insert sharding constraint. ffn1_ckpt_name: str = "ffn1" Checkpoint name for the output of the first fully-connected layer in the MLP block. diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index fa269a05a3..d1a1821b67 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -425,7 +425,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods The hidden dimension of each attention head. num_attention_heads: int The number of attention heads. - num_gqa_groups: int, default = `None` + num_gqa_groups: int, default = None Number of GQA groups. When `None` is present, it is equal to num_attention_heads. Grouped Query Attention is described in `this paper `_. @@ -442,24 +442,24 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods Each described below: - * no_mask: No attention mask is applied. This means the attention will consider the + * ``no_mask``: No attention mask is applied. This means the attention will consider the full sequence without any restrictions. - * padding: Indicates the presence of padding at the end of each sequence. - Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the + * ``padding``: Indicates the presence of padding at the end of each sequence. 
+ Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the :attr:`__call__` method to specify the padding positions. - * causal: An upper triangular mask is applied to the softmax inputs, + * ``causal``: An upper triangular mask is applied to the softmax inputs, ensuring that the prediction for a certain position is only dependent on known outputs from positions before it. - * causal_padding / padding_causal: A combination of both causal and padding masks. - Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect. + * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks. + Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect. | - .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. + .. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``. | - .. note:: THD format only supports 'padding' or 'causal_padding' mask type. + .. note:: THD format only supports ``'padding'`` or ``'causal_padding'`` mask type. | @@ -581,7 +581,7 @@ def __call__( mask: jax.numpy.ndarray, default = None Boolean tensor used to mask out the attention softmax input. :attr:`True` means to mask out the corresponding values. - Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'. + Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``. bias: jax.numpy.ndarray, default = None A tensor used to shift attention softmax input. *: @@ -763,7 +763,7 @@ def rotary_pos_emb( ): """ Rotary Positional Embedding - x should be in shape of + x should be of shape [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True. """ @@ -901,7 +901,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods The hidden dimension of each attention head. num_attention_heads: int The number of attention heads. - num_gqa_groups: int, default = `None` + num_gqa_groups: int, default = None Number of GQA groups. When `None` is present, it is equal to num_attention_heads. Grouped Query Attention is described in `this paper `_. @@ -918,24 +918,24 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods Each described below: - * no_mask: No attention mask is applied. This means the attention will consider the + * ``no_mask``: No attention mask is applied. This means the attention will consider the full sequence without any restrictions. - * padding: Indicates the presence of padding at the end of each sequence. - Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the + * ``padding``: Indicates the presence of padding at the end of each sequence. + Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the :attr:`__call__` method to specify the padding positions. - * causal: An upper triangular mask is applied to the softmax inputs, + * ``causal``: An upper triangular mask is applied to the softmax inputs, ensuring that the prediction for a certain position is only dependent on known outputs from positions before it. - * causal_padding / padding_causal: A combination of both causal and padding masks. - Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect. + * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks. 
+ Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect. - .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. + .. note:: :attr:`mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``. attn_bias_type: Optional[str], default = None Type of the attention bias passed in the attention. - Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. + Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``. When default is present, the type is automatically decided by the MHA's bias parameter. - Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used. + Where it is ``'post_scale_bias'`` if there is bias. Otherwise ``'no_bias'`` is used. dropout_rng_name: str, default = 'dropout' The key in given RNGs via flax.linen.Module.apply that is used to generate Dropout masks in the core attention. @@ -944,27 +944,27 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods layernorm_epsilon: float, default = 1e-6 A value added to the denominator of layer normalization for numerical stability. zero_centered_gamma: bool, default = False - If set to `True`, the LayerNorm formula changes to + If set to ``True``, the LayerNorm formula changes to .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} \cdot (1 + \gamma) + \beta - This parameter is only applicable for 'layernorm'. + This parameter is only applicable for ``'layernorm'``. kernel_init: Initializer, default = - flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal') + ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')`` Used for initializing the QKV and output projection weights. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. use_bias: bool, default = False Indicate whether or not to enable bias shifting for QKV and output projections. - If set to False, the layer will not learn additive biases. - bias_init: Initializer, default = flax.linen.initializers.zeros + If set to ``False``, the layer will not learn additive biases. + bias_init: Initializer, default = ``flax.linen.initializers.zeros`` Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. input_layernorm: bool, default = True - If set to False, layer normalization to the input is not applied. + If set to ``False``, layer normalization to the input is not applied. return_layernorm_output: bool, default = False - If set to True, output of layernorm is returned from the forward together with the output + If set to ``True``, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm. enable_rotary_pos_emb: bool, default = False @@ -974,17 +974,17 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods only used when :attr:`enable_rotary_pos_emb=True` rotary_pos_emb_group_method: str, default = 'consecutive' Indicate the method to coupled the coordinates. It should be one of - ['consecutive', 'alternate']. 
'alternate' is to pair index :math:`i` with :math:`i + d/2` - , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`. + ``['consecutive', 'alternate']``. ``'alternate'`` is to pair index :math:`i` with :math:`i + d/2` + , d is the hidden dimension. ``'consecutive'`` pairs index :math:`i` with :math:`i + 1`. low_rank_adaptation_scope: str, default = 'none' Indicate the scope to apply low rank adaptation. It should be one of - ['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj'] + ``['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj']`` low_rank_adaptation_dim: int, default = 32 The dimension for low rank adaptation, only used when :attr:`enable_low_rank_adaptation=True` low_rank_adaptation_alpha: float, default = None The alpha for computing the scaling factor of LoRA output. - :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. + :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling. enable_sequence_parallel: bool, default = False Whether to enable sequence parallelism to operations except dot. num_heads: int, default = None @@ -1010,8 +1010,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). scale_attn_logits: bool, default = False Indicate whether to scale attention logits. - If set to True, :math:`\frac{Q}{\sqrt{head\_dim}*K}`, - else :math:`Q*K` + If set to True, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`, + else :math:`Q \cdot K^T` scaled_query_init: bool, default = True Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}` float32_logits: bool, default = False @@ -1125,7 +1125,7 @@ def __call__( mask: jax.numpy.ndarray, default = None Boolean tensor used to mask out the attention softmax input. :attr:`True` means mask out the corresponding values. - Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'. + Ignored when :attr:`self.attn_mask_type` is either ``'no_mask'`` or ``'causal'``. bias: jax.numpy.ndarray, default = None A tensor used to shift the attention softmax input. * @@ -1610,7 +1610,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods Intermediate size to which input samples are projected. num_attention_heads: int, default = 8 Number of attention heads in the transformer layer. - num_gqa_groups: int, default = `None` + num_gqa_groups: int, default = None Number of GQA groups. When `None` is present, it is equal to num_attention_heads. Grouped Query Attention is described in `this paper `_. @@ -1644,31 +1644,31 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks in the Multi-Head Attention. mha_kernel_init: Initializer, default = - flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal') + ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')`` Used for initializing weights of QKV and Output projection weights. - It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). + It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``. mlp_kernel_init: Initializer, default = - flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') + ``flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')`` Used for initializing weights of FC1 and FC2 layers. 
- It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+ It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 mlp_activations: Sequence[str], default = ('relu', )
 The sequence of activation functions to apply after the first linear transformation.
 Each activation has its own transformation layer.
 mlp_activation_params: dict = None
- This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment
- ClampedSwiglu is the only activation that requires parameters.
+ This is only used when ``('clamped_silu', 'clamped_linear')`` is in :attr:`mlp_activations`. At the moment
+ ``ClampedSwiglu`` is the only activation that requires parameters.
 use_bias: bool, default = False
 Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2.
- If set to False, the layer will not learn additive biases.
- bias_init: Initializer, default = flax.linen.initializers.zeros
+ If set to ``False``, the layer will not learn additive biases.
+ bias_init: Initializer, default = ``flax.linen.initializers.zeros``
 Used for initializing bias of QKVO projections, FC1 and FC2.
 It is only used when :attr:`use_bias=True`.
- It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
+ It should be a callable object with three arguments ``(jax.random.PRNGKey, shape, dtype)``.
 apply_residual_connection_post_layernorm: bool, default = False
- If set to True, residual connections are taken from the output
+ If set to ``True``, residual connections are taken from the output
 of layer norm (default is taken from input of layer norm)
 output_layernorm: bool, default = False
- If set to True, layer normalization is applied on the output side,
+ If set to ``True``, layer normalization is applied on the output side,
 after the final dropout-add. default behavior is to apply layer
 normalization on the input side, before the QKV transformation.
 float32_attention_logits: bool, default = False
@@ -1676,7 +1676,7 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
 For fused attention backend, the accumulation is always float32 without the perf overhead.
 layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER
 If set to TransformerLayerType.DECODER, an additional cross-attention block
- is added after self-attention.this can be used for structures like `T5`
+ is added after self-attention. This can be used for structures like T5
 Transformer in conjunction with the TransformerLayerType.ENCODER option.
 self_attn_mask_type: str, default = 'causal'
 This parameter specifies the type of attention mask to be applied during the softmax
@@ -1685,34 +1685,34 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
 Each described below:

- * no_mask: No attention mask is applied. This means the self attention will consider the
+ * ``no_mask``: No attention mask is applied. This means the self attention will consider the
 full sequence without any restrictions.
- * padding: Indicates the presence of padding at the end of each sequence.
- Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the
+ * ``padding``: Indicates the presence of padding at the end of each sequence.
+ Users must provide a mask with the shape ``[batch, 1, max_seqlen_q, max_seqlen_kv]`` in the
 :attr:`__call__` method to specify the padding positions.
- * causal: An upper triangular mask is applied to the softmax inputs, + * ``causal``: An upper triangular mask is applied to the softmax inputs, ensuring that the prediction for a certain position is only dependent on known outputs from positions before it. - * causal_padding / padding_causal: A combination of both causal and padding masks. - Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect. + * ``causal_padding`` / ``padding_causal``: A combination of both causal and padding masks. + Both ``'causal_padding'`` and ``'padding_causal'`` are acceptable and have the same effect. - .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. + .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for ``'no_mask'`` and ``'causal'``. self_attn_bias_type: Optional[str], default = None Type of the attention bias passed into the self attention. - Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. + Available options: ``{'no_bias', 'pre_scale_bias', 'post_scale_bias'}``. When default is present, the type is automatically decided by the MHA's bias parameter. - Where it is `post_scale_bias` if there is bias. Otherwise `no_bias` is used. + Where it is ``'post_scale_bias'`` if there is bias. Otherwise ``'no_bias'`` is used. enable_relative_embedding: bool, default = True Whether to enable relative embedding as shifting of attention logits. relative_embedding: flax.linen.Module, default = None The module for relative embedding execution, only used when - :attr:`enable_relative_embedding=True`. Default is None, which will create + :attr:`enable_relative_embedding=True`. Default is ``None``, which will create an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`. - Default: RelativePositionBiases( num_buckets=32, max_distance=128, + Default: ``RelativePositionBiases( num_buckets=32, max_distance=128, num_attention_heads=self.num_attention_heads, dtype=self.dtype, embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'), - name='relpos_bias') + name='relpos_bias')`` enable_rotary_pos_emb: bool, default = False Whether to enable rotary position embedding to projected query and key in MHA. rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000) @@ -1720,19 +1720,19 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods only used when :attr:`enable_rotary_pos_emb=True` rotary_pos_emb_group_method: str, default = 'consecutive' Indicate the method to couple the coordinates. It should be one of - ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`, - where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with + ``['consecutive', 'alternate']``. ``'alternate'`` is to pair index :math:`i` with :math:`i + d/2`, + where :math:`d` is the hidden dimension. ``'consecutive'`` pairs index :math:`i` with :math:`i + 1`. low_rank_adaptation_scope: str, default = 'none' Indicate the scope to apply low rank adaptation. 
It should be one of - ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj', - 'exclude_output_proj', 'exclude_mlp'] + ``['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj', + 'exclude_output_proj', 'exclude_mlp']`` low_rank_adaptation_dim: int, default = 32 The dimension for low rank adaptation, only used when :attr:`enable_low_rank_adaptation=True` low_rank_adaptation_alpha: float, default = None The alpha for computing the scaling factor of LoRA output. - :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling. + :math:`\frac{alpha}{rank} \cdot lora\_output`. ``None`` means no scaling. enable_sequence_parallel: bool, default = False Whether to enable sequence parallelism to operations except dot. window_size: Optional[Tuple[int, int]], default = None @@ -1746,19 +1746,19 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods When > 0.0, applies stochastic depth per sample in the main path of the residual block. fuse_qkv_params: bool, default = True - If set to True, `TransformerLayer` module exposes a single fused + If set to ``True``, ``TransformerLayer`` module exposes a single fused parameter for query-key-value for self-attention and key-value for cross-attention. transpose_batch_sequence: bool, default = False Indicate whether the input tensors were switched axis of batch - and sequence length dimension. if set to True, the input tensors - should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). + and sequence length dimension. if set to ``True``, the input tensors + should be in ``(seqlen, batch, hidden)``, otherwise ``(batch, seqlen, hidden)``. scale_attn_logits: bool, default = False Indicate whether to scale attention logits. - if set to True, :math:`\frac{Q}{\sqrt{head_dim}*K}`, - else :math:`Q*K` - scaled_query_init: bool, default = `True` - Whether to scale WQ on initialization by :math:`\sqrt{head_dim}` + if set to ``True``, :math:`\frac{Q \cdot K^T}{\sqrt{head\_dim}}`, + else :math:`Q \cdot K^T` + scaled_query_init: bool, default = True + Whether to scale WQ on initialization by :math:`\sqrt{head\_dim}` """ hidden_size: int = 512 @@ -1840,7 +1840,7 @@ def __call__( attention_mask : jax.numpy.ndarray, default = None Boolean tensor used to mask out self-attention softmax input. :attr:`True` means mask out the corresponding values. - Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'. + Ignored when :attr:`self.self_attn_mask_type` is either ``'no_mask'`` or ``'causal'``. encoder_decoder_mask: jax.numpy.ndarray, default = None Boolean tensor used to mask out cross-attention softmax input when :attr:`layer_type=TransformerLayerType.DECODER`. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 15fb56db17..8a951c9bc6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -152,25 +152,25 @@ class DotProductAttention(TransformerEngineBaseModule): - """Allows the model to jointly attend to information from different + r"""Allows the model to jointly attend to information from different representation subspaces as described in the paper: `Attention Is All You Need `_. .. 
note::

- Argument :attr:`attention_mask` in the `forward` call is only used when
- :attr:`attn_mask_type` includes '"padding"' or `"arbitrary"`.
+ Argument :attr:`attention_mask` in the ``forward`` call is only used when
+ :attr:`attn_mask_type` includes ``"padding"`` or ``"arbitrary"``.

 .. warning::

 FlashAttention uses a non-deterministic algorithm for optimal performance. To observe
- deterministic behavior at the cost of performance, use FlashAttention version >= `2.4.1`
+ deterministic behavior at the cost of performance, use FlashAttention version >= ``2.4.1``
 and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
- to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
+ to disable ``flash-attn`` entirely, set :attr:`NVTE_FLASH_ATTN=0`.

 .. note::

- Transformer Engine stores the FP8 metadata under a `._extra_state` key when checkpointing.
+ Transformer Engine stores the FP8 metadata under a ``._extra_state`` key when checkpointing.
 As the FP8 attention support expands from one backend to multiple backends, the location
 of that key has also shifted (see `FP8 checkpoint compatibility
 `_).

@@ -182,110 +182,125 @@ class DotProductAttention(TransformerEngineBaseModule):
 kv_channels : Union[int, Tuple[int, int]]
 the head size in key and value tensors. If the same, :attr:`kv_channels` can be
 an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
- num_gqa_groups : Optional[int] = None
+ num_gqa_groups : Optional[int], default = None
 number of GQA groups in the transformer layer. Grouped Query Attention is described in
 `this paper `_.
 This only affects the keys and values, not the queries. GQA-1 is equivalent to
 Multi-Query Attention (`MQA `_), while GQA-H
- is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
+ is equivalent to MHA, i.e. ``num_gqa_groups = num_attention_heads``.
 attention_dropout: float, default = 0.0
 dropout probability for the dropout op during multi-head attention.
- attn_mask_type: str, default = `causal`
- type of attention mask passed into softmax operation, options are "`no_mask`",
- "`padding`", "`causal`", "`padding,causal`", "`causal,padding`",
- "`padding_causal`", "`causal_bottom_right`", "`padding_causal_bottom_right`", and
- "`arbitrary`", where "`padding,causal`", "`causal,padding`" and "`padding_causal`"
+ attn_mask_type: str, default = "causal"
+ type of attention mask passed into softmax operation, options are ``"no_mask"``,
+ ``"padding"``, ``"causal"``, ``"padding,causal"``, ``"causal,padding"``,
+ ``"padding_causal"``, ``"causal_bottom_right"``, ``"padding_causal_bottom_right"``, and
+ ``"arbitrary"``, where ``"padding,causal"``, ``"causal,padding"`` and ``"padding_causal"``
 are equivalent. This arg can be overridden by :attr:`attn_mask_type` in the
- `forward` method. It is useful for cases involving compilation/tracing, e.g.
+ :meth:`forward` method. It is useful for cases involving compilation/tracing, e.g.
 ONNX export, and the forward arg is useful for dynamically changing mask types,
 e.g. a different mask for training and inference.
- 1. For "`no_mask`", no attention mask is applied.
- 2. For "`causal`", "`causal_bottom_right`", or the causal mask in
- "`padding_causal`" and "`padding_causal_bottom_right`", Transformer Engine
- calculates and applies an upper triangular mask to the softmax input.
- No user input is needed. Causal masks without the "`bottom_right`" appendix align
- the diagonal line to the top left corner of the softmax matrix.
With
- "`bottom_right`", the causal mask is aligned to the bottom right corner, which is
- often used in inference/KV caching.
- 3. For "`padding`", or the padding mask in "`padding_causal`" and
- "`padding_causal_bottom_right`", users need to provide the locations of padded
- tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both in shape
- [batch_size + 1]), or via :attr:`attention_mask` (one tensor for self-attention
- in shape [batch_size, 1, 1, max_seqlen_q], or two tensors in a tuple for
- cross-attention in shapes [batch_size, 1, 1, max_seqlen_q] and
- [batch_size, 1, 1, max_seqlen_kv]).
- 4. For "`arbitrary`", users need to provide a mask that is broadcastable to
- the shape of softmax input [batch_size, num_heads, max_seqlen_q, max_seqlen_kv].
- window_size: Optional[Tuple[int, int]], default = `None`
+
+ 1. For ``"no_mask"``, no attention mask is applied.
+ 2. For ``"causal"``, ``"causal_bottom_right"``, or the causal mask in
+ ``"padding_causal"`` and ``"padding_causal_bottom_right"``, Transformer Engine
+ calculates and applies an upper triangular mask to the softmax input.
+ No user input is needed. Causal masks without the ``"bottom_right"`` appendix align
+ the diagonal line to the top left corner of the softmax matrix. With
+ ``"bottom_right"``, the causal mask is aligned to the bottom right corner, which is
+ often used in inference/KV caching.
+ 3. For ``"padding"``, or the padding mask in ``"padding_causal"`` and
+ ``"padding_causal_bottom_right"``, users need to provide the locations of padded
+ tokens, either via :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv` (both of shape
+ ``[batch_size + 1]``), or via :attr:`attention_mask` (one tensor for self-attention
+ of shape ``[batch_size, 1, 1, max_seqlen_q]``, or two tensors in a tuple for
+ cross-attention of shapes ``[batch_size, 1, 1, max_seqlen_q]`` and
+ ``[batch_size, 1, 1, max_seqlen_kv]``).
+ 4. For ``"arbitrary"``, users need to provide a mask that is broadcastable to
+ the shape of softmax input ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]``.
+
+ window_size: Optional[Tuple[int, int]], default = None
 sliding window size for local attention, where query at position i attends to keys
- in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
- + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
- window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
- map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
- `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can
- be overridden by :attr:`window_size` in `forward` as well.
- attention_type: str, default = `self`
- type of attention, either "`self`" and "`cross`".
- layer_number: int, default = `None`
- layer number of the current `DotProductAttention` when multiple such modules
+ in ``[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ + window_size[1]]`` inclusive. Special cases ``(-1, -1)`` and ``(-1, 0)`` mean no sliding
+ window and causal mask specifically. Both ``causal`` and ``causal_bottom_right`` masks
+ map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
+ ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
+ be overridden by :attr:`window_size` in ``forward`` as well.
+ attention_type: str, default = "self"
+ type of attention, either ``"self"`` or ``"cross"``.
+ layer_number: int, default = None + layer number of the current ``DotProductAttention`` when multiple such modules are concatenated, for instance in consecutive transformer blocks. - qkv_format: str, default = `sbhd` - dimension format for `query_layer`, `key_layer` and `value_layer`, - {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length, `b` batch size, - `h` the number of heads, `d` head size, and `t` the total number of tokens - in a batch, with `t = sum(s_i), for i = 0...b-1`. `sbhd` and `bshd` formats + qkv_format: str, default = "sbhd" + dimension format for ``query_layer``, ``key_layer`` and ``value_layer``, + {``"sbhd"``, ``"bshd"``, ``"thd"``}. ``s`` stands for the sequence length, ``b`` batch size, + ``h`` the number of heads, ``d`` head size, and ``t`` the total number of tokens + in a batch, with ``t = sum(s_i), for i = 0...b-1``. ``"sbhd"`` and ``"bshd"`` formats are used for when sequences in a batch are of equal length or padded to - equal length, and the `thd` format is used for when sequences in a batch + equal length, and the ``"thd"`` format is used for when sequences in a batch have different lengths. Please note that these formats do not reflect how - tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory. - For that, please use `get_qkv_layout` to gain the layout information. - softmax_scale: Optional[float], default = `None` - softmax scale for the attention scores. If `None`, defaults to - `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`. + tensors ``query_layer``, ``key_layer``, ``value_layer`` are laid out in memory. + For that, please use ``get_qkv_layout`` to gain the layout information. + softmax_scale: Optional[float], default = None + softmax scale for the attention scores. If ``None``, defaults to + ``1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])``. softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' Softmax type as described in the paper `Efficient Streaming Language Models with Attention Sinks `_. - For a given attention score ``S = Q x K^T``, of shape ``[b, h, s_q, s_kv]``: + For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``: + + * ``'vanilla'``: + + .. math:: + Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})} + + * ``'off-by-one'``: + + .. math:: + Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})} - * 'vanilla': ``S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1)`` - * 'off-by-one': ``S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1))`` - * 'learnable': ``S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1))`` + * ``'learnable'``: - where ``alpha`` is a learnable parameter in shape ``[h]``. - 'off-by-one' and 'learnable' softmax types are also called sink attention - ('zero sink' and 'learnable sink'). - return_max_logit: Optional[bool], default = `False` + .. math:: + Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})} + + where :math:`\alpha` is a learnable parameter of shape ``[h]``. + + ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention + (``'zero sink'`` and ``'learnable sink'``). + + return_max_logit: Optional[bool], default = False If true, returns the maximum attention score that can be used in a Muon optimizer to rescale the Q and K projection weights (see `Muon is Scalable for LLM Training `_). 
- max_logit = max(S), where S = mask(Q*K^T*softmax_scale + bias) in shape [b, h, s_q, s_kv], - and max_logit is in shape [h]. + :math:`\text{max_logit} = \max(S)`, where :math:`S = \text{mask}(Q \cdot K^T \cdot \text{softmax_scale} + \text{bias})` of shape ``[b, h, s_q, s_kv]``, + and :math:`\text{max_logit}` is of shape ``[h]``. Parallelism parameters ---------------------- - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. + sequence_parallel : bool, default = False + if set to ``True``, uses sequence parallelism. tp_size : int, default = 1 tensor parallel world size. - tp_group : ProcessGroup, default = `None` + tp_group : ProcessGroup, default = None tensor parallel process group. - cp_group : Union[ProcessGroup, List[ProcessGroup]], default = `None` + cp_group : Union[ProcessGroup, List[ProcessGroup]], default = None context parallel process group. - ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a". - List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0] - and cp_group[1] are for a2a and p2p communications respectively. - cp_global_ranks : list of global rank IDs, default = `None` - global rank IDs of GPUs that are in cp_group. - cp_stream : CUDA stream, default = `None` + ``ProcessGroup`` is for :attr:`cp_comm_type` of ``"p2p"``, ``"all_gather"``, and ``"a2a"``. + ``List[ProcessGroup]`` is for :attr:`cp_comm_type` of ``"a2a+p2p"``, where :attr:`cp_group[0]` + and :attr:`cp_group[1]` are for ``"a2a"`` and ``"p2p"`` communications respectively. + cp_global_ranks : list of global rank IDs, default = None + global rank IDs of GPUs that are in ``cp_group``. + cp_stream : CUDA stream, default = None context parallelism splits flash attention into multiple steps for compute and communication overlapping. To address the wave quantization issue of each split step, we add an additional CUDA stream so that we can overlap two flash attention kernels. - cp_comm_type : str, default = `p2p` + cp_comm_type : str, default = "p2p" inter-gpu communication type for context parallelism. Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``. @@ -472,8 +487,8 @@ def _load_from_state_dict( ): """ This function helps to load Transformer Engine 1.6 and 1.7 checkpoints, where FP8 attention - metadata is stored under the `core_attention.fused_attention._extra_state` key and not the - `core_attention._extra_state` key. Please see `FP8 checkpoint compatibility + metadata is stored under the ``core_attention.fused_attention._extra_state`` key and not the + ``core_attention._extra_state`` key. Please see `FP8 checkpoint compatibility `_ for more details. """ fused_attn_key = False @@ -526,14 +541,14 @@ def set_context_parallel_group( ---------- cp_group : Union[ProcessGroup, List[ProcessGroup]] context parallel process group. - ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a". - List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0] - and cp_group[1] are for a2a and p2p communications respectively. + ``ProcessGroup`` is for :attr:`cp_comm_type` of ``"p2p"``, ``"all_gather"``, and ``"a2a"``. + ``List[ProcessGroup]`` is for :attr:`cp_comm_type` of ``"a2a+p2p"``, where :attr:`cp_group[0]` + and :attr:`cp_group[1]` are for ``"a2a"`` and ``"p2p"`` communications respectively. cp_global_ranks : List[int] list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. 
- cp_comm_type : str, default = `p2p` + cp_comm_type : str, default = "p2p" inter-gpu communication type for context parallelism. Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``. @@ -805,13 +820,13 @@ def forward( pad_between_seqs: Optional[bool] = None, fp8_output: Optional[bool] = False, ) -> torch.Tensor: - """ + r""" Dot Product Attention Layer. .. note:: Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type` - includes '"padding"' or `"arbitrary"`. + includes ``"padding"`` or ``"arbitrary"``. .. note:: @@ -850,24 +865,24 @@ def forward( Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask` (which will be converted to :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`), to provide the real sequence length information. For example, a batch of 3 sequences - [a a a b b c c c c] can be padded to [a a a PAD b b PAD PAD c c c c], and the cumulative + ``[a a a b b c c c c]`` can be padded to ``[a a a PAD b b PAD PAD c c c c]``, and the cumulative sequence length tensors would be - :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention. + :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = ``[0, 3, 5, 9]`` for self-attention. 2. Do not perform padding on training data. Use :attr:`qkv_format` = "thd" and :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}. Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`, - as in option 1. For example, a batch of 3 sequences [a a a b b c c c c] can be processed + as in option 1. For example, a batch of 3 sequences ``[a a a b b c c c c]`` can be processed without any padding, and the sequence length tensors would be - :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention. + :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = ``[0, 3, 5, 9]`` for self-attention. In certain use cases, a varying number of identifier tokens are inserted between sequences. These tokens do not participate in the attention calculation. :attr:`cu_seqlens_q_padded` and :attr:`cu_seqlens_kv_padded` must be specified in such cases to correctly identify the start and end of each sequence in a batch. - For example, a batch of 3 sequences [a a a 1 b b 2 2 c c c c 3] would have - :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9], and - :attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = [0, 4, 8, 13] + For example, a batch of 3 sequences ``[a a a 1 b b 2 2 c c c c 3]`` would have + :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = ``[0, 3, 5, 9]``, and + :attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = ``[0, 4, 8, 13]`` for self-attention. .. note:: @@ -902,81 +917,81 @@ def forward( value_layer : torch.Tensor Value tensor. attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], - default = `None`. Boolean tensor(s) used to mask out attention softmax input. - It should be `None` for causal masks and "`no_mask`". For padding masks, it should be - a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of - two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] - for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable - to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means - the corresponding position is masked out and a `False` means that position + default = None. Boolean tensor(s) used to mask out attention softmax input. 
+ It should be ``None`` for causal masks and ``"no_mask"``. For padding masks, it should be + a single tensor of ``[batch_size, 1, 1, seqlen_q]`` for self-attention, and a tuple of + two tensors of shapes ``[batch_size, 1, 1, seqlen_q]`` and ``[batch_size, 1, 1, seqlen_kv]`` + for cross-attention. For ``"arbitrary"`` mask, it should be of a shape broadcastable + to ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]``. A ``True`` value means + the corresponding position is masked out and a ``False`` means that position is allowed to participate in attention. - qkv_format: str, default = `None` + qkv_format: str, default = None If provided, overrides :attr:`qkv_format` from initialization. - cu_seqlens_q: Optional[torch.Tensor], default = `None` - Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, + cu_seqlens_q: Optional[torch.Tensor], default = None + Cumulative sum of sequence lengths (without offset) in a batch for ``query_layer``, with shape [batch_size + 1] and dtype torch.int32. See :ref:`note` for more details. - cu_seqlens_kv: Optional[torch.Tensor], default = `None` - Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` - and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + cu_seqlens_kv: Optional[torch.Tensor], default = None + Cumulative sum of sequence lengths (without offset) in a batch for ``key_layer`` + and ``value_layer``, with shape [batch_size + 1] and dtype torch.int32. See :ref:`note` for more details. - cu_seqlens_q_padded: Optional[torch.Tensor], default = `None` + cu_seqlens_q_padded: Optional[torch.Tensor], default = None Cumulative sum of sequence lengths (with offset) in a batch for - `query_layer`, with shape [batch_size + 1] and dtype torch.int32. + ``query_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32. When there is no padding between sequences in a batch, - `cu_seqlens_q_padded = cu_seqlens_q`. + :attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_q`. See :ref:`note` for more details. - cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None` - Cumulative sum of sequence lengths (with offset) in a batch for `key_layer` - and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + cu_seqlens_kv_padded: Optional[torch.Tensor], default = None + Cumulative sum of sequence lengths (with offset) in a batch for ``key_layer`` + and ``value_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32. When there is no padding between sequences in a batch, - `cu_seqlens_kv_padded = cu_seqlens_kv`. + :attr:`cu_seqlens_kv_padded` = :attr:`cu_seqlens_kv`. See :ref:`note` for more details. - max_seqlen_q: Optional[int], default = `None` - Maximum sequence length in `query_layer`. + max_seqlen_q: Optional[int], default = None + Maximum sequence length in ``query_layer``. See :ref:`note` for more details. - max_seqlen_kv: Optional[int], default = `None` - Maximum sequence length in `key_layer` and `value_layer`. + max_seqlen_kv: Optional[int], default = None + Maximum sequence length in ``key_layer`` and ``value_layer``. See :ref:`note` for more details. attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', - 'arbitrary'}, default = `None`. Type of attention mask passed into + 'arbitrary'}, default = None. Type of attention mask passed into softmax operation. 'padding,causal', 'causal,padding' and 'padding_causal' are equivalent. 
By default, causal masks are aligned to the top left corner - of the softmax matrix. When "`bottom_right`" is specified in the mask type, + of the softmax matrix. When ``"bottom_right"`` is specified in the mask type, causal masks are aligned to the bottom right corner. - window_size: Optional[Tuple[int, int]], default = `None` + window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention. - checkpoint_core_attention : bool, default = `False` + checkpoint_core_attention : bool, default = False If true, forward activations for attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop. - core_attention_bias_type: str, default = `no_bias` - Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`} - core_attention_bias: Optional[torch.Tensor], default = `None` - Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv]. - It should be 'None' for 'no_bias' and 'alibi' bias types. - alibi_slopes: Optional[torch.Tensor], default = `None` - ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. + core_attention_bias_type: str, default = "no_bias" + Bias type, {``"no_bias"``, ``"pre_scale_bias"``, ``"post_scale_bias"``, ``"alibi"``} + core_attention_bias: Optional[torch.Tensor], default = None + Bias tensor for :math:`Q \cdot K^T`, shape ``[1, num_head, max_seqlen_q, max_seqlen_kv]``. + It should be ``None`` for ``"no_bias"`` and ``"alibi"`` bias types. + alibi_slopes: Optional[torch.Tensor], default = None + ALiBi slopes in FP32 and shape ``[nheads]`` or ``[batch_size, nheads]``. It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) to the attention score of query i and key j. - fast_zero_fill: bool, default = `True` + fast_zero_fill: bool, default = True Whether to use the fast path to set output tensors to 0 or not. - inference_params: Optional[InferenceParams], default = `None` + inference_params: Optional[InferenceParams], default = None Optimizes execution performance during inference by caching Keys and Values of the current decoding iteration. These cached values are appended to the K and V values computed in previous iterations, eliminating the need to recalculate them for the entire sequence. - Initialization of `inference_params` is required prior to use to ensure sufficient + Initialization of ``inference_params`` is required prior to use to ensure sufficient memory allocation. Adjustments of the sequence_len_offset should be done after a complete forward pass. If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand. Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient. - pad_between_seqs: Optional[bool], default = `None` - If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. - If true, there are padding tokens between individual sequences in a packed batch. - fp8_output: Optional[bool], default = `False` + pad_between_seqs: Optional[bool], default = None + If ``None``, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. + If ``True``, there are padding tokens between individual sequences in a packed batch. + fp8_output: Optional[bool], default = False Whether to enforce output to be in FP8 or not. 
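For the padding cases described above, the cumulative-sequence-length tensors and the boolean padding mask can be derived directly from the per-sequence lengths. A minimal sketch for the example batch of lengths 3, 2 and 4 (tensor names are illustrative, not part of the API):

.. code-block:: python

    import torch

    # Per-sequence lengths for the example batch [a a a | b b | c c c c].
    seqlens = torch.tensor([3, 2, 4], dtype=torch.int32)
    batch_size, max_seqlen = seqlens.numel(), 4

    # Cumulative sequence lengths, shape [batch_size + 1]: tensor([0, 3, 5, 9]).
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    # Boolean padding mask, shape [batch_size, 1, 1, max_seqlen]; True = masked out.
    positions = torch.arange(max_seqlen).view(1, 1, 1, max_seqlen)
    attention_mask = positions >= seqlens.view(batch_size, 1, 1, 1)

    # Either cu_seqlens_q/cu_seqlens_kv or attention_mask (with a "padding" mask
    # type) can then be passed to forward() for self-attention.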
""" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7d4a4f86d9..6f26ddc31d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -175,9 +175,9 @@ class AttentionParams: Parameters ---------- - qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor` + qkv_type: Union[torch.Tensor, Float8Tensor], default = torch.Tensor Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}. - qkv_dtype: torch.dtype, default = `torch.bfloat16` + qkv_dtype: torch.dtype, default = torch.bfloat16 Data type of query/key/value tensors. qkv_layout: str, default = "sbh3d" Query/key/value tensor memory layout. @@ -195,41 +195,41 @@ class AttentionParams: The size of each attention head in query and key tensors. head_dim_v: int, default = 64 The size of each attention head in the value tensor. - attn_mask_type: str, default = `no_mask` + attn_mask_type: str, default = no_mask Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`, `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} window_size: Tuple[int, int], default = None Sliding window attention size. - alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None` + alibi_slopes_shape: Optional[Union[torch.Size, List]], default = None Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. - core_attention_bias_type: str, default = `no_bias` + core_attention_bias_type: str, default = no_bias Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}. - core_attention_bias_shape: str, default = `1hss` + core_attention_bias_shape: str, default = 1hss Attention bias shape, {`1hss`, `b1ss`, `bhss`}. - core_attention_bias_requires_grad: bool, default = `True` + core_attention_bias_requires_grad: bool, default = True Whether attention bias requires gradient. - pad_between_seqs: bool, default = `False` + pad_between_seqs: bool, default = False Whether there is padding between sequences in a batch. This only applies to `qkv_format=thd`. attention_dropout: float, default = 0.0 Attention dropout. - context_parallel: bool, default = `False` + context_parallel: bool, default = False Whether context parallelism is used or not. cp_comm_type: str, default = "p2p" The communication type of context parallelism. - deterministic: bool, default = `False` + deterministic: bool, default = False Whether to run `DotProductAttention` with determinism or not. - is_training: bool, default = `True` + is_training: bool, default = True Whether in training mode (`True`) or inference mode (`False`) - fp8: bool, default = `False` + fp8: bool, default = False Whether `DotProductAttention` is in an `autocast` region. - fp8_meta: Optional[Dict[str Any]], default = `None` + fp8_meta: Optional[Dict[str Any]], default = None The FP8 metadata tensor of `DotProductAttention`. - inference_params: Optional[InferenceParams], default = `None` + inference_params: Optional[InferenceParams], default = None Inference-related parameters. See InferenceParams for details. softmax_type: str, default = "vanilla" The type of softmax operation. See DotProductAttention for details. - return_max_logit: bool, default = `False` + return_max_logit: bool, default = False Whether to output max_logit. 
""" @@ -815,8 +815,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt # ---------------------------------------------------------------------------------------- # no_mask | None | All # padding | | All - # self-attention | One tensor in shape [b, 1, 1, sq] | - # cross-attention | Tuple of two tensors in shapes | + # self-attention | One tensor of shape [b, 1, 1, sq] | + # cross-attention | Tuple of two tensors of shapes | # | [b, 1, 1, sq] and [b, 1, 1, skv] | # causal | None | # self-attention | | All @@ -826,7 +826,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt # cross-attention | | FusedAttention, UnfusedDotProductAttention # causal_bottom_right | None | All # padding_causal_bottom_right | Same as "padding" | All - # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention + # arbitrary | One tensor of shape broadcastable to | UnfusedDotProductAttention # | [b, h, sq, skv] | if attn_mask_type == "arbitrary": if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( @@ -1254,14 +1254,14 @@ def get_full_mask( Maximum sequence length for queries. max_seqlen_kv: int Maximum sequence length for keys and values. - attn_mask_type: str, default = `no_mask` - Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`", - "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"} + attn_mask_type: str, default = no_mask + Attention mask type, {``"no_mask"``, ``"padding"``, ``"causal"``, ``"padding_causal"``, + ``"causal_bottom_right"``, ``"padding_causal_bottom_right"``, ``"arbitrary"``} attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - default = `None` + default = None Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention for the requirements of `attention_mask` for different `attn_mask_type`s. - window_size: Tuple[int, int], default = `None` + window_size: Tuple[int, int], default = None Sliding window size for local attention, where query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding @@ -1270,7 +1270,7 @@ def get_full_mask( `attn_mask_type`. attention_type: str, default = "self" Attention type, {"self", "cross"} - bottom_right_alignment: bool, default = `True` + bottom_right_alignment: bool, default = True Whether to align the diagonal of the sliding window attention to the bottom right (`True`) or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly specifies "causal" or "causal_bottom_right". @@ -1282,10 +1282,10 @@ def get_full_mask( attention_mask: torch.Tensor The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size` actual_seqlens_q: torch.Tensor - For padding masks, the actual sequence lengths for queries, in shape [batch_size]. + For padding masks, the actual sequence lengths for queries, of shape [batch_size]. For other masks, `None`. - actual_seqlens_kv: Optional[torch.Tensor], default = `None` - For padding masks, the actual sequence lengths for keys and values, in shape [batch_size]. + actual_seqlens_kv: Optional[torch.Tensor], default = None + For padding masks, the actual sequence lengths for keys and values, of shape [batch_size]. For other masks, `None`. """ # perform basic checks @@ -1377,15 +1377,15 @@ def get_alibi( Maximum sequence length for queries. 
max_seqlen_kv: int Maximum sequence length for keys and values. - actual_seqlens_q: Optional[torch.Tensor], default = `None` - Actual sequence lengths for queries, in shape [batch_size]. - actual_seqlens_kv: Optional[torch.Tensor], default = `None` - Actual sequence lengths for keys and values, in shape [batch_size]. - alibi_slopes: Optional[torch.Tensor], default = `None` - Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads]. - bias_dtype: Optional[torch.dtype], default = `None` + actual_seqlens_q: Optional[torch.Tensor], default = None + Actual sequence lengths for queries, of shape [batch_size]. + actual_seqlens_kv: Optional[torch.Tensor], default = None + Actual sequence lengths for keys and values, of shape [batch_size]. + alibi_slopes: Optional[torch.Tensor], default = None + Custom ALiBi slopes, FP32, CUDA tensor, of shape [num_heads] or [batch_size, num_heads]. + bias_dtype: Optional[torch.dtype], default = None Dtype of the generated ALiBi bias. If None, use torch.float32. - bottom_right_alignment: bool, default = `True` + bottom_right_alignment: bool, default = True Whether to align the diagonal of the ALiBi bias to the bottom right corner of the matrix (`True`) or top left (`False`). @@ -1797,12 +1797,12 @@ def get_qkv_format( ---------- qkv_layout: str Memory layout of `q`, `k` and `v`. See get_qkv_layout() for more details. - inference_params: InferenceParams, default = `None` + inference_params: InferenceParams, default = None InferenceParams related to KV caching. Returns ---------- - qkv_format: str, default = `sbhd` + qkv_format: str, default = sbhd Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. q_format: str Format of the `q` tensor, {`bshd`, `sbhd`, `thd`}. @@ -1838,12 +1838,12 @@ def get_qkv_layout( Key tensor. v: torch.Tensor Value tensor. - qkv_format: str, default = `sbhd` + qkv_format: str, default = sbhd Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for the sequence length dimension, `b` batch size, `h` the number of attention heads, `d` head size, and `t` the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`. - inference_params: InferenceParams, default = `None` + inference_params: InferenceParams, default = None InferenceParams related to KV caching. 
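The format strings described above refer only to tensor shapes; a minimal sketch of the same data in each layout (shapes only, variable names are illustrative):

.. code-block:: python

    import torch

    b, s, h, d = 2, 4, 16, 64
    seqlens = [3, 4]                  # actual lengths of the two sequences

    x_bshd = torch.randn(b, s, h, d)  # "bshd": batch, sequence, heads, head size
    x_sbhd = x_bshd.transpose(0, 1)   # "sbhd": sequence-first view of the same data

    # "thd" packs the unpadded tokens of all sequences: t = sum(s_i), no batch dim.
    x_thd = torch.cat([x_bshd[i, : seqlens[i]] for i in range(b)], dim=0)
    assert x_thd.shape == (sum(seqlens), h, d)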
Returns diff --git a/transformer_engine/pytorch/attention/inference.py b/transformer_engine/pytorch/attention/inference.py index f0ef8d0bd5..1d6b30e556 100644 --- a/transformer_engine/pytorch/attention/inference.py +++ b/transformer_engine/pytorch/attention/inference.py @@ -525,9 +525,9 @@ def step( new_v: torch.Tensor New value tokens for layer_number in current inference iteration cu_new_seqlens: torch.Tensor - Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] + Cumulative sequence lengths for new_k and new_v, of shape [batch_size + 1] cu_cached_seqlens: torch.Tensor - Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] + Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), of shape [batch_size + 1] qkv_format: str Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} @@ -701,7 +701,7 @@ def get_page_list(self, seq: int): return [x.page_id for x in self.allocated_pages[seq]] def get_page_table(self, sequences: List[int]): - """Get the page table, in shape [batch_size, max_pages_per_seq]""" + """Get the page table, of shape [batch_size, max_pages_per_seq]""" page_table = torch.Tensor( [ self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq)) @@ -783,9 +783,9 @@ def step( new_v: torch.Tensor New value tokens for layer_number in current inference iteration cu_new_seqlens: torch.Tensor - Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] + Cumulative sequence lengths for new_k and new_v, of shape [batch_size + 1] cu_cached_seqlens: torch.Tensor - Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] + Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), of shape [batch_size + 1] qkv_format: str Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index b61ed35871..d82dc3519c 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -48,8 +48,8 @@ class MultiheadAttention(torch.nn.Module): .. note:: - Argument :attr:`attention_mask` in the `forward` call is only used when - :attr:`attn_mask_type` includes '"padding"' or `"arbitrary"`. + Argument :attr:`attention_mask` in the :meth:`forward() ` method is only used when + :attr:`attn_mask_type` includes ``"padding"`` or ``"arbitrary"``. Parameters ---------- @@ -57,56 +57,55 @@ class MultiheadAttention(torch.nn.Module): size of each input sample. num_attention_heads : int number of attention heads in the transformer layer. - kv_channels: int, default = `None` + kv_channels: int, default = None number of key-value channels. defaults to - :attr:`hidden_size` / :attr:`num_attention_heads` if `None`. + :attr:`hidden_size` / :attr:`num_attention_heads` if ``None``. attention_dropout: float, default = 0.1 dropout probability for the dropout op during multi-head attention. layernorm_epsilon : float, default = 1e-5 a value added to the denominator of layer normalization for numerical stability. - init_method : Callable, default = `None` + init_method : Callable, default = None used for initializing weights of QKV and FC1 weights in the following way: - `init_method(weight)`. When set to `None`, defaults to - `torch.nn.init.normal_(mean=0.0, std=0.023)`. 
- output_layer_init_method : Callable, default = `None` + ``init_method(weight)``. When set to ``None``, defaults to + ``torch.nn.init.normal_(mean=0.0, std=0.023)``. + output_layer_init_method : Callable, default = None used for initializing weights of PROJ and FC2 in the following way: - `output_layer_init_method(weight)`. When set to `None`, defaults to - `torch.nn.init.normal_(mean=0.0, std=0.023)`. - layer_number: int, default = `None` - layer number of the current `TransformerLayer` when multiple such modules are + ``output_layer_init_method(weight)``. When set to ``None``, defaults to + ``torch.nn.init.normal_(mean=0.0, std=0.023)``. + layer_number: int, default = None + layer number of the current ``TransformerLayer`` when multiple such modules are concatenated to form a transformer block. attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'}, - default = `causal` + default = "causal" type of attention mask passed into softmax operation. Overridden by - :attr:`attn_mask_type` in the `forward` method. The forward + :attr:`attn_mask_type` in the :meth:`forward` method. The :meth:`forward` arg is useful for dynamically changing mask types, e.g. a different - mask for training and inference. The init arg is useful for cases + mask for training and inference. The :meth:`__init__` arg is useful for cases involving compilation/tracing, e.g. ONNX export. - window_size: Optional[Tuple[int, int]], default = `None` + window_size: Optional[Tuple[int, int]], default = None sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. Both `causal` and `causal_bottom_right` masks - map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on - `attn_mask_type`. Similar to :attr:`attn_mask_type`, `window_size` can - be overridden by :attr:`window_size` in `forward` as well. - num_gqa_groups : int, default = `None` + in ``[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]`` inclusive. Special cases ``(-1, -1)`` and ``(-1, 0)`` mean no sliding + window and causal mask specifically. Both ``"causal"`` and ``"causal_bottom_right"`` masks + map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on + ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can + be overridden by :attr:`window_size` in :meth:`forward` as well. + num_gqa_groups : int, default = None number of GQA groups in the transformer layer. Grouped Query Attention is described in `this paper `_. This only affects the keys and values, not the querys. GQA-1 is equivalent to Multi-Query Attention (`MQA `_), while GQA-H - is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. - return_layernorm_output : bool, default = `False` - if set to `True`, output of layernorm is returned from the forward + is equivalent to MHA, i.e. ``num_gqa_groups = num_attention_heads``. + return_layernorm_output : bool, default = False + if set to ``True``, output of layernorm is returned from the :meth:`forward` method together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm. - input_layernorm: bool, default = `False` - if set to `True`, layer normalization to the input is applied. 
+ input_layernorm: bool, default = False + if set to ``True``, layer normalization to the input is applied. attention_type: { 'self', 'cross' }, default = 'self' type of attention applied. zero_centered_gamma : bool, default = 'False' @@ -118,106 +117,118 @@ class MultiheadAttention(torch.nn.Module): (1 + \gamma) + \beta normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' type of normalization applied. - qkv_weight_interleaved : bool, default = `True` - if set to `False`, the QKV weight is interpreted as a concatenation of - query, key, and value weights along the `0th` dimension. The default - interpretation is that the individual `q`, `k`, and `v` weights for each - attention head are interleaved. This parameter is set to `False` when + qkv_weight_interleaved : bool, default = True + if set to ``False``, the QKV weight is interpreted as a concatenation of + query, key, and value weights along the ``0th`` dimension. The default + interpretation is that the individual ``q``, ``k``, and ``v`` weights for each + attention head are interleaved. This parameter is set to ``False`` when using :attr:`fuse_qkv_params=False`. - rotary_pos_interleaved : bool, default = `False` + rotary_pos_interleaved : bool, default = False whether to use interleaved rotary position embeddings. - bias : bool, default = `True` - if set to `False`, the transformer layer will not learn any additive biases. + bias : bool, default = True + if set to ``False``, the transformer layer will not learn any additive biases. device : Union[torch.device, str], default = "cuda" The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. - qkv_format: str, default = `sbhd` - dimension format for `query_layer`, `key_layer` and `value_layer`, - {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size, - `h` the number of heads and `d` head size. `sbhd` and `bshd` formats + qkv_format: str, default = "sbhd" + dimension format for ``query_layer``, ``key_layer`` and ``value_layer``, + {``"sbhd"``, ``"bshd"``}. ``s`` stands for the sequence length, ``b`` batch size, + ``h`` the number of heads and ``d`` head size. ``"sbhd"`` and ``"bshd"`` formats are used for when sequences in a batch are of equal length or padded to equal length. Please note that these formats do not reflect how - tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory. - For that, please use `get_qkv_layout` to gain the layout information. - name: str, default = `None` + tensors ``query_layer``, ``key_layer``, ``value_layer`` are laid out in memory. + For that, please use ``get_qkv_layout`` to gain the layout information. + name: str, default = None name of the module, currently used for debugging purposes. softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' Softmax type as described in the paper `Efficient Streaming Language Models with Attention Sinks `_. - For a given attention score ``S = Q x K^T``, of shape ``[b, h, s_q, s_kv]``: + For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``: - * 'vanilla': ``S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1)`` - * 'off-by-one': ``S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1))`` - * 'learnable': ``S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1))`` + * ``'vanilla'``: - where ``alpha`` is a learnable parameter in shape ``[h]``. 
-        'off-by-one' and 'learnable' softmax types are also called sink attention
-        ('zero sink' and 'learnable sink').
+          .. math::
+            S_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}
+
+        * ``'off-by-one'``:
+
+          .. math::
+            S_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}
+
+        * ``'learnable'``:
+
+          .. math::
+            S_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}
+
+        where :math:`\alpha` is a learnable parameter of shape ``[h]``.
+
+        ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
+        (``'zero sink'`` and ``'learnable sink'``).

     Parallelism parameters
     ----------------------
-    set_parallel_mode : bool, default = `False`
-        if set to `True`, QKV and FC1 layers are used as Column Parallel
+    set_parallel_mode : bool, default = False
+        if set to ``True``, QKV and FC1 layers are used as Column Parallel
         whereas PROJ and FC2 is used as Row Parallel as described `here `_.
-    sequence_parallel : bool, default = `False`
-        if set to `True`, uses sequence parallelism.
-    tp_group : ProcessGroup, default = `None`
+    sequence_parallel : bool, default = False
+        if set to ``True``, uses sequence parallelism.
+    tp_group : ProcessGroup, default = None
         tensor parallel process group.
     tp_size : int, default = 1
         used as TP (tensor parallel) world size when TP groups are not formed
         during initialization. In this case, users must call the
-        `set_tensor_parallel_group(tp_group)` method on the initialized module before the
+        ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the
         forward pass to supply the tensor parallel group needed for tensor and sequence
         parallel collectives.

     Optimization parameters
     -----------------------
     fuse_wgrad_accumulation : bool, default = 'False'
-        if set to `True`, enables fusing of creation and accumulation of
+        if set to ``True``, enables fusing of creation and accumulation of
         the weight gradient. When enabled, it is assumed that the weights
-        have an additional `main_grad` attribute (used instead of the
-        regular `grad`) which is a pre-allocated buffer of the correct
+        have an additional ``main_grad`` attribute (used instead of the
+        regular ``grad``) which is a pre-allocated buffer of the correct
         size to accumulate gradients in.
-    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
+    params_dtype : torch.dtype, default = torch.get_default_dtype()
         it controls the type used to allocate the initial parameters. Useful when
         the model is trained with lower precision and the original FP32 parameters
         would not fit in GPU memory.
-    return_bias : bool, default = `False`
-        when set to `True`, this module will not apply the additive bias itself, but
-        instead return the bias value during the forward pass together with the
+    return_bias : bool, default = False
+        when set to ``True``, this module will not apply the additive bias itself, but
+        instead return the bias value during the :meth:`forward` method together with the
         output of the linear transformation :math:`y = xA^T`. This is useful when
         the bias addition can be fused to subsequent operations.
     fuse_qkv_params: bool, default = 'False'
-        if set to `True`, `TransformerLayer` module exposes a single fused
+        if set to ``True``, ``TransformerLayer`` module exposes a single fused
         parameter for query-key-value. This enables optimizations such as QKV
         fusion without concatentations/splits and also enables the argument
-        `fuse_wgrad_accumulation`.
+        ``fuse_wgrad_accumulation``.
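As a quick numerical illustration of the ``softmax_type`` variants above (a minimal plain-PyTorch sketch of the formulas, not the fused Transformer Engine kernel; shapes are illustrative)::

    import torch

    # Attention scores S = Q @ K^T with shape [b, h, s_q, s_kv].
    scores = torch.randn(2, 4, 8, 8)

    # 'vanilla': each row of the softmax sums to 1.
    vanilla = torch.softmax(scores, dim=-1)

    # 'off-by-one' ("zero sink"): an implicit zero logit enters the denominator,
    # so rows can sum to less than 1 and a head can effectively attend to nothing.
    exp_scores = torch.exp(scores)
    off_by_one = exp_scores / (1.0 + exp_scores.sum(dim=-1, keepdim=True))

    print(vanilla.sum(dim=-1)[0, 0, 0])     # ~1.0
    print(off_by_one.sum(dim=-1)[0, 0, 0])  # < 1.0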
qk_norm_type: Optional[str], default = None type of normalization to apply to query and key tensors. - Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied. - When 'L2Normalization', L2 normalization is applied to query and key tensors. - When 'RMSNorm', RMS normalization is applied to query and key tensors. - When 'LayerNorm', layer normalization is applied to query and key tensors. + Options: ``None``, ``'L2Normalization'``, ``'RMSNorm'``, ``'LayerNorm'``. When ``None``, no normalization is applied. + When ``'L2Normalization'``, L2 normalization is applied to query and key tensors. + When ``'RMSNorm'``, RMS normalization is applied to query and key tensors. + When ``'LayerNorm'``, layer normalization is applied to query and key tensors. Normalization is applied after RoPE (if applicable) but before attention computation - when `qk_norm_before_rope` is False. This follows the e.g. Llama4 approach + when ``qk_norm_before_rope`` is ``False``. This follows the e.g. Llama4 approach for QK normalization to improve training stability and model performance. qk_norm_eps: float, default = 1e-6 epsilon value for normalization of query and key tensors. - Only used when `qk_norm_type` is not None. - qk_norm_before_rope: bool, default = `False` - if set to `True`, query and key normalization is applied before rotary position - embedding. When `False` (default), normalization is applied after RoPE. + Only used when ``qk_norm_type`` is not ``None``. + qk_norm_before_rope: bool, default = False + if set to ``True``, query and key normalization is applied before rotary position + embedding. When ``False`` (default), normalization is applied after RoPE. This parameter allows supporting different architectural variants that apply QK normalization at different points. - seq_length: Optional[int], default = `None` + seq_length: Optional[int], default = None sequence length of input samples. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure same kernels are used for forward propagation and activation recompute phase. - micro_batch_size: Optional[int], default = `None` + micro_batch_size: Optional[int], default = None batch size per training step. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure same kernels are used for forward propagation and activation recompute phase. @@ -536,7 +547,7 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N Parameters ---------- - tp_group : ProcessGroup, default = `None` + tp_group : ProcessGroup, default = None tensor parallel process group. """ self.tp_group = tp_group @@ -556,14 +567,14 @@ def set_context_parallel_group( ---------- cp_group : Union[ProcessGroup, List[ProcessGroup]] context parallel process group. - ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a". - List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0] - and cp_group[1] are for a2a and p2p communications respectively. + ``ProcessGroup`` is for :attr:`cp_comm_type` of ``"p2p"``, ``"all_gather"``, and ``"a2a"``. + ``List[ProcessGroup]`` is for :attr:`cp_comm_type` of ``"a2a+p2p"``, where :attr:`cp_group[0]` + and :attr:`cp_group[1]` are for ``"a2a"`` and ``"p2p"`` communications respectively. cp_global_ranks : List[int] list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. 
- cp_comm_type : str, default = `p2p` + cp_comm_type : str, default = "p2p" inter-gpu communication type for context parallelism. Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``. @@ -624,39 +635,39 @@ def forward( fast_zero_fill: bool = True, pad_between_seqs: Optional[bool] = None, ) -> Tuple[Union[torch.Tensor, None], ...]: - """ + r""" Forward propagation for MultiheadAttention layer. .. note:: Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type` - includes `"padding"` or `"arbitrary"`. + includes ``"padding"`` or ``"arbitrary"``. Parameters ---------- hidden_states : torch.Tensor Input tensor. attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], - default = `None`. Boolean tensor(s) used to mask out attention softmax input. - It should be `None` for causal masks and "`no_mask`". For padding masks, it should be - a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of - two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] - for cross-attention. For "`arbitrary`" mask, it should be in a shape broadcastable to - [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value means - the corresponding position is masked out and a `False` means that position + default = None. Boolean tensor(s) used to mask out attention softmax input. + It should be ``None`` for causal masks and ``"no_mask"``. For padding masks, it should be + a single tensor of ``[batch_size, 1, 1, seqlen_q]`` for self-attention, and a tuple of + two tensors of shapes ``[batch_size, 1, 1, seqlen_q]`` and ``[batch_size, 1, 1, seqlen_kv]`` + for cross-attention. For ``"arbitrary"`` mask, it should be of a shape broadcastable to + ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]``. A ``True`` value means + the corresponding position is masked out and a ``False`` means that position is allowed to participate in attention. attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'}, - default = `None` + default = None type of attention mask passed into softmax operation. By default, causal masks are aligned to the top left corner of the softmax matrix. - When "`bottom_right`" is specified in the mask type, causal masks are + When ``"bottom_right"`` is specified in the mask type, causal masks are aligned to the bottom right corner. - window_size: Optional[Tuple[int, int]], default = `None` + window_size: Optional[Tuple[int, int]], default = None sliding window size for local attention. - encoder_output : Optional[torch.Tensor], default = `None` + encoder_output : Optional[torch.Tensor], default = None Output of the encoder block to be fed into the decoder block if using - `layer_type="decoder"`. + ``layer_type="decoder"``. is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split @@ -670,46 +681,46 @@ def forward( * it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced) - checkpoint_core_attention: bool, default = `False` - If true, forward activations for core attention are recomputed + checkpoint_core_attention: bool, default = False + If ``True``, forward activations for core attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop. 
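A hedged usage sketch for the constructor and mask arguments documented in this hunk (``hidden_size`` and ``num_attention_heads`` are assumed from the existing ``MultiheadAttention`` signature, ``qk_norm_type`` availability depends on the TE version, and a CUDA device is assumed)::

    import torch
    import transformer_engine.pytorch as te

    mha = te.MultiheadAttention(
        hidden_size=1024,
        num_attention_heads=16,
        num_gqa_groups=4,           # 4 key/value heads shared by 16 query heads (GQA)
        attn_mask_type="causal",
        qk_norm_type="RMSNorm",     # applied after RoPE since qk_norm_before_rope=False
    ).cuda()

    # Default qkv_format is "sbhd": [seq, batch, hidden].
    x = torch.randn(128, 2, 1024, device="cuda")

    # Padding mask of shape [batch, 1, 1, seqlen_q]; True marks masked-out positions.
    mask = torch.zeros(2, 1, 1, 128, dtype=torch.bool, device="cuda")
    mask[1, ..., 96:] = True        # last 32 tokens of the second sample are padding

    out = mha(x, attention_mask=mask, attn_mask_type="padding_causal")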
- rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None` + rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = None Embeddings for query and key tensors for applying rotary position embedding. By default no input embedding is applied. - core_attention_bias_type: str, default = `no_bias` - Bias type, {`no_bias`, `pre_scale_bias`, 'post_scale_bias`, `alibi`} - core_attention_bias: Optional[torch.Tensor], default = `None` - Bias tensor for Q * K.T, shape [1, num_head, max_seqlen_q, max_seqlen_kv]. - It should be 'None' for 'no_bias' and 'alibi' bias types. - alibi_slopes: Optional[torch.Tensor], default = `None` - ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. - It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) + core_attention_bias_type: str, default = "no_bias" + Bias type, {``"no_bias"``, ``"pre_scale_bias"``, ``"post_scale_bias"``, ``"alibi"``} + core_attention_bias: Optional[torch.Tensor], default = None + Bias tensor for :math:`Q \cdot K^T`, shape ``[1, num_head, max_seqlen_q, max_seqlen_kv]``. + It should be ``None`` for ``"no_bias"`` and ``"alibi"`` bias types. + alibi_slopes: Optional[torch.Tensor], default = None + ALiBi slopes in FP32 and shape ``[nheads]`` or ``[batch_size, nheads]``. + It adds a bias of ``(-alibi_slope * (i + seqlen_k - seqlen_q - j))`` to the attention score of query i and key j. - cu_seqlens_q: Optional[torch.Tensor], default = `None` - Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, - with shape [batch_size + 1] and dtype torch.int32. - cu_seqlens_kv: Optional[torch.Tensor], default = `None` - Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` - and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. - cu_seqlens_q_padded: Optional[torch.Tensor], default = `None` - Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`, - with shape [batch_size + 1] and dtype torch.int32. - cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None` - Cumulative sum of sequence lengths (with offset) in a batch for `key_layer` - and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. - max_seqlen_q: Optional[int], default = `None` - Maximum sequence length in `query_layer`. - Calculated from `cu_seqlens_q` if not provided. - max_seqlen_kv: Optional[int], default = `None` - Maximum sequence length in `key_layer` and `value_layer`. - Calculated from `cu_seqlens_kv` if not provided. - fast_zero_fill: bool, default = `True` + cu_seqlens_q: Optional[torch.Tensor], default = None + Cumulative sum of sequence lengths (without offset) in a batch for ``query_layer``, + with shape ``[batch_size + 1]`` and dtype torch.int32. + cu_seqlens_kv: Optional[torch.Tensor], default = None + Cumulative sum of sequence lengths (without offset) in a batch for ``key_layer`` + and ``value_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32. + cu_seqlens_q_padded: Optional[torch.Tensor], default = None + Cumulative sum of sequence lengths (with offset) in a batch for ``query_layer``, + with shape ``[batch_size + 1]`` and dtype torch.int32. + cu_seqlens_kv_padded: Optional[torch.Tensor], default = None + Cumulative sum of sequence lengths (with offset) in a batch for ``key_layer`` + and ``value_layer``, with shape ``[batch_size + 1]`` and dtype torch.int32. + max_seqlen_q: Optional[int], default = None + Maximum sequence length in ``query_layer``. 
+ Calculated from ``cu_seqlens_q`` if not provided. + max_seqlen_kv: Optional[int], default = None + Maximum sequence length in ``key_layer`` and ``value_layer``. + Calculated from ``cu_seqlens_kv`` if not provided. + fast_zero_fill: bool, default = True Whether to set output tensors to 0 or not before use. - pad_between_seqs: Optional[bool], default = `None` - If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. - If true, there are padding tokens between individual sequences in a packed batch. + pad_between_seqs: Optional[bool], default = None + If ``None``, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. + If ``True``, there are padding tokens between individual sequences in a packed batch. """ # hidden_states: [sq, b, h] diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index cc23d65a3e..4c6e6ebb4b 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -350,7 +350,7 @@ def _get_freqs_on_this_cp_rank( """Get the position embedding on the current context parallel rank. Args: - freqs: torch.Tensor. Positional embedding tensor in shape `[s2, 1, 1, d2]`. + freqs: torch.Tensor. Positional embedding tensor of shape `[s2, 1, 1, d2]`. seqlen: int. Length of the current sequence. cp_size: int. Context parallel world size. cp_rank: int. Context parallel rank. diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index eb43c75f6b..c4d912b719 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -215,7 +215,7 @@ def fused_attn_fwd( random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen softmax_offset: torch.Tensor, default = None - softmax offset tensor in shape [1, h_q, 1, 1]. + softmax offset tensor of shape [1, h_q, 1, 1]. See softmax_type in DotProductAttention for details. return_max_logit: bool, default = False whether to return the maximum attention score @@ -452,7 +452,7 @@ def fused_attn_bwd( gradient tensor of Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; same data type and shape as Bias d_softmax_offset: torch.Tensor, optional - gradient tensor of softmax offset in shape [1, h_q, 1, 1]. + gradient tensor of softmax offset of shape [1, h_q, 1, 1]. See softmax_type in DotProductAttention for details. """ if attn_scale is None: diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 6edc126200..2b1def16ec 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -680,18 +680,18 @@ def get_cpu_offload_context( Parameters ---------- - enabled: bool, default = `False` + enabled: bool, default = False When set to True, CPU Offloading functionality is enabled. num_layers: int, default = 1 Determines the number of transformer layers you want to offload activations/weights for. model_layers: int, default = 1 Number of layers in the model that will be used under this context. - offload_activations: bool, default = `True` + offload_activations: bool, default = True When set to `True`, offloads the activations for the TE layer. - offload_weights: bool, default = `True` + offload_weights: bool, default = True When set to `True`, offloads the weights for the TE layer. 
- double_buffering: bool, default = `False` + double_buffering: bool, default = False When set to `True`, uses double buffering for offloading. """ diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py index 5bfec4608b..a71b108956 100644 --- a/transformer_engine/pytorch/cross_entropy.py +++ b/transformer_engine/pytorch/cross_entropy.py @@ -26,7 +26,7 @@ class CrossEntropyFunction(torch.autograd.Function): @staticmethod def forward( ctx, - input, + inp, target, label_smoothing=0.0, reduce_loss=False, @@ -40,7 +40,7 @@ def forward( Parameters: ctx : The context object. - input (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size. + inp (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size. target (tensor): The target tensor of shape (B,SQ) or (SQ, B) where each value is in [0, V-1]. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension. @@ -50,8 +50,8 @@ def forward( Returns: tensor: The computed loss. """ - loss, input = triton_cross_entropy.cross_entropy_forward( - input, + loss, inp = triton_cross_entropy.cross_entropy_forward( + inp, target, label_smoothing, reduce_loss, @@ -59,7 +59,7 @@ def forward( ignore_idx, ) - ctx.save_for_backward(input.detach()) + ctx.save_for_backward(inp.detach()) ctx.is_cg_capturable = is_cg_capturable return loss @@ -75,12 +75,12 @@ def backward(ctx, grad_output): Returns: tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. """ - (input,) = ctx.saved_tensors - input = triton_cross_entropy.cross_entropy_backward( - input, grad_output, ctx.is_cg_capturable + (inp,) = ctx.saved_tensors + inp = triton_cross_entropy.cross_entropy_backward( + inp, grad_output, ctx.is_cg_capturable ) return ( - input, + inp, None, None, None, @@ -91,7 +91,7 @@ def backward(ctx, grad_output): def parallel_cross_entropy( - input: torch.Tensor, + inp: torch.Tensor, target: torch.Tensor, label_smoothing: float = 0.0, reduce_loss: bool = False, @@ -114,7 +114,7 @@ def parallel_cross_entropy( Parameters ---------- - input : torch.Tensor + inp : torch.Tensor The input tensor of shape ``(B, SQ, V)`` or ``(SQ, B, V)`` where B is batch size, SQ is sequence length, V is vocab size. target : torch.Tensor @@ -138,13 +138,13 @@ def parallel_cross_entropy( # Handle backward compatibility with _input parameter if _input is not None: warnings.warn( - "The '_input' parameter is deprecated. Please use 'input' instead.", + "The '_input' parameter is deprecated. Please use 'inp' instead.", FutureWarning, ) - input = _input + inp = _input return CrossEntropyFunction.apply( - input, + inp, target, label_smoothing, reduce_loss, diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 8992bbc291..3bb80a936b 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -645,14 +645,14 @@ def checkpoint( pytorch module used to run the forward and backward passes using the specified :attr:`args` and :attr:`kwargs`. distribute_saved_activations: bool, default = False - if set to `True` and `use_reentrant=True`, first tensor argument is distributed - across the specified tensor parallel group (`tp_group`) before saving it for the - backward pass. 
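For the renamed ``inp`` argument of ``parallel_cross_entropy`` (replacing the deprecated ``_input``), a minimal sketch, assuming a CUDA device and that the Triton kernels backing this path are available::

    import torch
    from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy

    vocab_size = 32000
    logits = torch.randn(4, 128, vocab_size, device="cuda", requires_grad=True)  # (B, SQ, V)
    target = torch.randint(0, vocab_size, (4, 128), device="cuda")               # (B, SQ)

    loss = parallel_cross_entropy(logits, target, label_smoothing=0.1, reduce_loss=True)
    loss.backward()  # gradients flow back into `logits`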
This has no effect when `use_reentrant=False`. - get_rng_state_tracker: `Callable`, default = None - python callable which returns an instance of :func:`CudaRNGStatesTracker`. + if set to ``True`` and ``use_reentrant=True``, first tensor argument is distributed + across the specified tensor parallel group (``tp_group``) before saving it for the + backward pass. This has no effect when ``use_reentrant=False``. + get_rng_state_tracker: Callable, default = None + python callable which returns an instance of :class:`CudaRNGStatesTracker`. tp_group : ProcessGroup, default = None - tensor parallel process group. Used only when `distribute_saved_activations=True` - and `use_reentrant=True`. If `None`, it falls back to the default group. + tensor parallel process group. Used only when ``distribute_saved_activations=True`` + and ``use_reentrant=True``. If ``None``, it falls back to the default group. use_reentrant : bool, default = True perform checkpointing in reentrant mode. args : tuple @@ -777,8 +777,8 @@ class CudaRNGStatesTracker: For model parallelism, multiple RNG states need to simultaneously exist in order to execute operations in or out of the model parallel region. This class keeps track of the various RNG states and provides utility methods to maintain them and - execute parts of the model under a given RNG setting. Using the `add` method, a - cuda rng state is initialized based on the input `seed` and is assigned to `name`. + execute parts of the model under a given RNG setting. Using the :meth:`add` method, a + cuda rng state is initialized based on the input ``seed`` and is assigned to ``name``. Later, by forking the rng state, we can perform operations and return to our starting cuda state. """ @@ -811,7 +811,9 @@ def set_states(self, states: Dict[str, torch.Tensor]) -> None: Set the rng states. For efficiency purposes, we do not check the size of seed for compatibility. - states: Dict[str, torch.Tensor] + Parameters + ---------- + states : Dict[str, torch.Tensor] A mapping from string names to RNG states. """ self.states_ = states @@ -820,9 +822,11 @@ def add(self, name: str, seed: int) -> None: """ Adds a new RNG state. - name: str + Parameters + ---------- + name : str string identifier for the RNG state. - seed: int + seed : int PyTorch seed for the RNG state. """ # Check seed is not already used. @@ -856,7 +860,9 @@ def fork(self, name: str = "model-parallel-rng"): Fork the cuda rng state, perform operations, and exit with the original state. - name: str + Parameters + ---------- + name : str string identifier for the RNG state. """ # Check if we have added the state diff --git a/transformer_engine/pytorch/export.py b/transformer_engine/pytorch/export.py index f75271e2cc..9652889bbc 100644 --- a/transformer_engine/pytorch/export.py +++ b/transformer_engine/pytorch/export.py @@ -28,7 +28,7 @@ def onnx_export(enabled: bool = False) -> Generator[None, None, None]: Parameters ---------- - enabled: bool, default = `False` + enabled: bool, default = False whether or not to enable export """ diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 9af9fb8870..e194f75aa4 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -941,38 +941,38 @@ def make_graphed_callables( Positional arguments to callable(s). num_warmup_iters: int, default = 3 Number of warmup iterations. 
- allow_unused_input: bool, default = `False` + allow_unused_input: bool, default = False Whether to handle case where callable inputs and outputs are disconnected in compute graph. sample_kwargs: (tuple of) dict, optional Keyword arguments to callable(s) - pool: (tuple of) int, default = `None`, optional + pool: (tuple of) int, default = None, optional An instance returned from function `torch.cuda.graph_pool_handle` that hints this graph may share memory with the indicated pool. - retain_graph_in_backward: bool, default = `False` + retain_graph_in_backward: bool, default = False Whether to set retain_graph=True in backward graph capture. - _reuse_graph_input_output_buffers: bool, default = `False` + _reuse_graph_input_output_buffers: bool, default = False Reduce memory usage by reusing input/output data buffers between graphs. Only supported with Mcore interleaved pipeline parallelism, i.e. when `_order` is provided. All callables in `modules` are assumed to have inputs and outputs with the same dtype and shape. - Quantization related parameters - ---------------------- - enabled: (tuple of) bool, default = `False` + Quantization parameters + ----------------------- + enabled: (tuple of) bool, default = False whether or not to enable low precision quantization (FP8/FP4). If tuple, the length must match the number of modules. - calibrating: bool, default = `False` + calibrating: bool, default = False calibration mode allows collecting statistics such as amax and scale data of quantized tensors even when executing without quantization enabled. This is useful for saving an inference ready checkpoint while training using a higher precision. - recipe: recipe.Recipe, default = `None` + recipe: recipe.Recipe, default = None recipe used for low precision quantization. - amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None` + amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = None distributed group over which amaxes for the quantized tensors are reduced at the end of each training step. - cache_quantized_params: bool, default = `False` + cache_quantized_params: bool, default = False Whether or not to cache quantized weights across microbatches. if set to `True`, the `is_first_microbatch` boolean argument must be passed into the forward method for TransformerEngine modules. When storing primary weights in low precision diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d2febde3c7..bcc26f43b3 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -129,24 +129,24 @@ def initialize_ub( ) -> None: r""" Initialize the Userbuffers communicator for overlapping tensor-parallel communications with - GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules. + GEMM compute in ``te.Linear``, ``te.LayerNormLinear`` and ``te.LayerNormMLP`` modules. 
Parameters ---------- shape : list shape of the communication buffer, typically set to be the same as the global shape of - the input tensor to a te.TransformerLayer forward pass, with the sequence and batch - dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)` + the input tensor to a ``te.TransformerLayer`` forward pass, with the sequence and batch + dimensions collapsed together -- i.e.: ``(sequence_length * batch_size, hidden_size)`` tp_size : int number of GPUs in the tensor-parallel process group use_fp8 : bool = False allocate the communication buffer for FP8 GEMM inputs/outputs. - DEPRECATED: Please use `quantization_modes` instead. + DEPRECATED: Please use ``quantization_modes`` instead. quantization_modes : List[UserBufferQuantizationMode] = None if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list. - falls back to the legacy `use_fp8` parameter if `None` is provided. + falls back to the legacy ``use_fp8`` parameter if ``None`` is provided. dtype : torch.dtype = torch.bfloat16 - non-FP8 data type of the communication buffer when `use_fp8 = False` + non-FP8 data type of the communication buffer when ``use_fp8 = False`` ub_cfgs: dict = None Configuration dictionary with the structure:: @@ -165,19 +165,19 @@ def initialize_ub( } } - for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", + for ``te.TransformerLayer`` GEMM layers in ``["qkv_fprop", "qkv_dgrad", "qkv_wgrad", "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", - "fc2_fprop", "fc2_wgrad"]`. - a list may be provided to specify different overlap configurations for different the quantization settings in `quantization_modes` + "fc2_fprop", "fc2_wgrad"]``. + a list may be provided to specify different overlap configurations for different the quantization settings in ``quantization_modes`` bootstrap_backend : str = None - `torch.distributed` communication backend for the all-gather, broadcast and + ``torch.distributed`` communication backend for the all-gather, broadcast and barrier collectives during Userbuffers initialization. Not all backends are valid for every cluster configuration and distributed launch method even if they are available in PyTorch. When left unset, the initialization prefers to use the MPI backend, falling back first on Gloo and then NCCL if MPI is - not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this + not available. Setting ``NVTE_UB_WITH_MPI=1`` when building TE overrides this option and always initializes Userbuffers with direct MPI calls in C++, - which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time. + which also requires ``MPI_HOME=/path/to/mpi/root`` to be set at compile time. """ if not tex.device_supports_multicast(): assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( @@ -986,7 +986,7 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N Parameters ---------- - tp_group : ProcessGroup, default = `None` + tp_group : ProcessGroup, default = None tensor parallel process group. """ self.tp_group = tp_group @@ -1355,7 +1355,7 @@ def get_weight_workspace( workspace is being constructed or updated. cache_name: str, optional Key for caching. - update_workspace: bool, default = `True` + update_workspace: bool, default = True Update workspace with values from `tensor`. skip_update_flag: torch.Tensor, optional GPU flag to skip updating the workspace. 
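A hedged setup sketch for the Userbuffers API described above, assuming ``torch.distributed`` and the tensor-parallel group are already initialized with one GPU per rank::

    import torch
    import transformer_engine.pytorch as te

    seq_len, micro_batch, hidden_size, tp_size = 2048, 2, 4096, 8

    te.initialize_ub(
        shape=[seq_len * micro_batch, hidden_size],  # (sequence_length * batch_size, hidden_size)
        tp_size=tp_size,
        dtype=torch.bfloat16,                        # buffer dtype for non-FP8 GEMM inputs/outputs
    )

    # ... construct TE modules with their Userbuffers overlap options enabled
    # and run training ...

    te.destroy_ub()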
Take precedence diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 4d6b2f23b9..1e18b4d77e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -517,14 +517,14 @@ class GroupedLinear(TransformerEngineBaseModule): size of each input sample. out_features : int size of each output sample. - bias : bool, default = `True` - if set to `False`, the layer will not learn an additive bias. - init_method : Callable, default = `None` - used for initializing weights in the following way: `init_method(weight)`. - When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. - get_rng_state_tracker : Callable, default = `None` + bias : bool, default = True + if set to ``False``, the layer will not learn an additive bias. + init_method : Callable, default = None + used for initializing weights in the following way: ``init_method(weight)``. + When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``. + get_rng_state_tracker : Callable, default = None used to get the random number generator state tracker for initializing weights. - rng_tracker_name : str, default = `None` + rng_tracker_name : str, default = None the param passed to get_rng_state_tracker to get the specific rng tracker. device : Union[torch.device, str], default = "cuda" The device on which the parameters of the model will be allocated. It is the user's @@ -533,34 +533,36 @@ class GroupedLinear(TransformerEngineBaseModule): Optimization parameters ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of + fuse_wgrad_accumulation : bool, default = False + if set to ``True``, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct + have an additional ``main_grad`` attribute (used instead of the + regular ``grad``) which is a pre-allocated buffer of the correct size to accumulate gradients in. This argument along with weight tensor having attribute 'overwrite_main_grad' set to True - will overwrite `main_grad` instead of accumulating. - return_bias : bool, default = `False` - when set to `True`, this module will not apply the additive bias itself, but + will overwrite ``main_grad`` instead of accumulating. + return_bias : bool, default = False + when set to ``True``, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation :math:`y = xA^T`. This is useful when the bias addition can be fused to subsequent operations. - params_dtype : torch.dtype, default = `torch.get_default_dtype()` + params_dtype : torch.dtype, default = torch.get_default_dtype() it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. 
- delay_wgrad_compute : bool, default = `False` + delay_wgrad_compute : bool, default = False Whether to delay weight gradient computation - save_original_input : bool, default = `False` - If set to `True`, always saves the original input tensor rather than the + save_original_input : bool, default = False + If set to ``True``, always saves the original input tensor rather than the cast tensor. In some scenarios, the input tensor is used by multiple modules, and saving the original input tensor may reduce the memory usage. Cannot work with FP8 DelayedScaling recipe. - Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and - `parallel_mode` are used to determine the shapes of weights and biases. - The TP communication should be handled in the dispatch and combine stages of MoE models. + Notes + ----- + GroupedLinear doesn't really handle the TP communications inside. The ``tp_size`` and + ``parallel_mode`` are used to determine the shapes of weights and biases. + The TP communication should be handled in the dispatch and combine stages of MoE models. """ def __init__( diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 6d13544e4f..4ce087ceb8 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -38,7 +38,7 @@ class LayerNorm(_LayerNormOp): dtype: torch.dtype, default = default dtype Tensor datatype zero_centered_gamma : bool, default = 'False' - If `True`, the :math:`\gamma` parameter is initialized to zero + If ``True``, the :math:`\gamma` parameter is initialized to zero and the calculation changes to .. math:: @@ -48,13 +48,10 @@ class LayerNorm(_LayerNormOp): Number of SMs to exclude when launching CUDA kernels. This helps overlap with other kernels, e.g. communication kernels. For more fine-grained control, provide a dict with the SM - margin at each compute stage ("forward", "backward", - "inference"). - - Legacy - ------ + margin at each compute stage (``"forward"``, ``"backward"``, + ``"inference"``). sequence_parallel: bool - Set a bool attr named `sequence_parallel` in the parameters. + **Legacy parameter.** Set a bool attr named ``sequence_parallel`` in the parameters. This is custom logic for Megatron-LM integration. """ diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 6540eeb6f9..c17e809f29 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1065,20 +1065,20 @@ class LayerNormLinear(TransformerEngineBaseModule): size of each output sample. eps : float, default = 1e-5 a value added to the denominator of layer normalization for numerical stability. - bias : bool, default = `True` - if set to `False`, the layer will not learn an additive bias. + bias : bool, default = True + if set to ``False``, the layer will not learn an additive bias. normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' type of normalization applied. - init_method : Callable, default = `None` - used for initializing weights in the following way: `init_method(weight)`. - When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. - return_layernorm_output : bool, default = `False` - if set to `True`, output of layernorm is returned from the forward + init_method : Callable, default = None + used for initializing weights in the following way: ``init_method(weight)``. 
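A usage sketch for the ``GroupedLinear`` module documented in the previous hunk (the leading ``num_gemms`` argument and the packed-input-plus-``m_splits`` calling convention are my reading of the API and should be checked against the current signature)::

    import torch
    import transformer_engine.pytorch as te

    num_gemms, in_features, out_features = 4, 1024, 4096
    grouped = te.GroupedLinear(num_gemms, in_features, out_features, bias=False).cuda()

    # Tokens are packed per group; MoE dispatch/combine (and any TP communication)
    # is expected to happen outside this module, as the note above explains.
    m_splits = [16, 32, 8, 24]
    x = torch.randn(sum(m_splits), in_features, device="cuda")

    y = grouped(x, m_splits)  # [sum(m_splits), out_features]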
+ When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``. + return_layernorm_output : bool, default = False + if set to ``True``, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm. - return_layernorm_output_gathered : bool, default = `False` - if set to `True`, output of layernorm is returned after the all + return_layernorm_output_gathered : bool, default = False + if set to ``True``, output of layernorm is returned after the all gather operation. Ignored if return_layernorm_output is False. Example use case: with sequence parallel, input to residual connection for transformer module (e.g. LoRA) will need to be gathered. @@ -1089,10 +1089,10 @@ class LayerNormLinear(TransformerEngineBaseModule): they are used to make the names of equally-sized parameters. If a dict (preferably an OrderedDict) is provided, the keys are used as names and values as split sizes along dim 0. The resulting parameters will have - names that end in `_weight` or `_bias`, so trailing underscores are + names that end in ``_weight`` or ``_bias``, so trailing underscores are stripped from any provided names. zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + if set to ``'True'``, gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to .. math:: @@ -1102,53 +1102,53 @@ class LayerNormLinear(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. - name: str, default = `None` + name: str, default = None name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` + sequence_parallel : bool, default = False + if set to ``True``, uses sequence parallelism. + tp_group : ProcessGroup, default = None tensor parallel process group. tp_size : int, default = 1 used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the - `set_tensor_parallel_group(tp_group)` method on the initialized module before the + ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. - parallel_mode : {None, 'column', 'row'}, default = `None` + parallel_mode : {None, 'column', 'row'}, default = None used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. - When set to `None`, no communication is performed. + When set to ``None``, no communication is performed. Optimization parameters ----------------------- fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of + if set to ``True``, enables fusing of creation and accumulation of the weight gradient. 
When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct + have an additional ``main_grad`` attribute (used instead of the + regular ``grad``) which is a pre-allocated buffer of the correct size to accumulate gradients in. This argument along with weight tensor having attribute 'overwrite_main_grad' set to True - will overwrite `main_grad` instead of accumulating. - return_bias : bool, default = `False` - when set to `True`, this module will not apply the additive bias itself, but + will overwrite ``main_grad`` instead of accumulating. + return_bias : bool, default = False + when set to ``True``, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation :math:`y = xA^T`. This is useful when the bias addition can be fused to subsequent operations. - params_dtype : torch.dtype, default = `torch.get_default_dtype()` + params_dtype : torch.dtype, default = torch.get_default_dtype() it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. - delay_wgrad_compute : bool, default = `False` - Whether or not to delay weight gradient computation. If set to `True`, - it's the user's responsibility to call `module.backward_dw` to compute + delay_wgrad_compute : bool, default = False + Whether or not to delay weight gradient computation. If set to ``True``, + it's the user's responsibility to call ``module.backward_dw`` to compute weight gradients. symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None Type of symmetric memory all-reduce to use during the forward pass. This can help in latency bound communication situations. - Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce + Requires PyTorch version 2.7.0 or higher. When set to ``None``, standard all-reduce is used. """ diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a4e8929b6e..52979599e7 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1440,38 +1440,38 @@ class LayerNormMLP(TransformerEngineBaseModule): intermediate size to which input samples are projected. eps : float, default = 1e-5 a value added to the denominator of layer normalization for numerical stability. - bias : bool, default = `True` - if set to `False`, the FC1 and FC2 layers will not learn an additive bias. + bias : bool, default = True + if set to ``False``, the FC1 and FC2 layers will not learn an additive bias. normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' type of normalization applied. activation : str, default = 'gelu' activation function used. - Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', - 'silu', 'swiglu', and 'clamped_swiglu'. - activation_params : dict, default = `None` + Options: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, + ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``. + activation_params : dict, default = None Additional parameters for the activation function. - At the moment, only used for 'clamped_swiglu' activation which - supports 'limit' and 'alpha' parameters. 
- init_method : Callable, default = `None` - used for initializing FC1 weights in the following way: `init_method(weight)`. - When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. - output_layer_init_method : Callable, default = `None` + At the moment, only used for ``'clamped_swiglu'`` activation which + supports ``'limit'`` and ``'alpha'`` parameters. + init_method : Callable, default = None + used for initializing FC1 weights in the following way: ``init_method(weight)``. + When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``. + output_layer_init_method : Callable, default = None used for initializing FC2 weights in the following way: - `output_layer_init_method(weight)`. When set to `None`, defaults to - `torch.nn.init.normal_(mean=0.0, std=0.023)`. - return_layernorm_output : bool, default = `False` - if set to `True`, output of layernorm is returned from the forward + ``output_layer_init_method(weight)``. When set to ``None``, defaults to + ``torch.nn.init.normal_(mean=0.0, std=0.023)``. + return_layernorm_output : bool, default = False + if set to ``True``, output of layernorm is returned from the :meth:`forward` method together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm. - return_layernorm_output_gathered : bool, default = `False` - if set to `True`, output of layernorm is returned after the all - gather operation. Ignored if return_layernorm_output is False. + return_layernorm_output_gathered : bool, default = False + if set to ``True``, output of layernorm is returned after the all + gather operation. Ignored if ``return_layernorm_output`` is False. Example use case: with sequence parallel, input to residual connection for transformer module (e.g. LoRA) will need to be gathered. Returning layernorm output gathered will prevent a redundant gather. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + zero_centered_gamma : bool, default = False + if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to .. math:: @@ -1481,41 +1481,41 @@ class LayerNormMLP(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. - name: str, default = `None` + name: str, default = None name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- - set_parallel_mode : bool, default = `False` - if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row + set_parallel_mode : bool, default = False + if set to ``True``, FC1 is used as Column Parallel and FC2 is used as Row Parallel as described `here `_. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` + sequence_parallel : bool, default = False + if set to ``True``, uses sequence parallelism. + tp_group : ProcessGroup, default = None tensor parallel process group. tp_size : int, default = 1 used as TP (tensor parallel) world size when TP groups are not formed during initialization. 
In this case, users must call the - `set_tensor_parallel_group(tp_group)` method on the initialized module before the + ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. Optimization parameters ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of + fuse_wgrad_accumulation : bool, default = False + if set to ``True``, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct + have an additional ``main_grad`` attribute (used instead of the + regular ``grad``) which is a pre-allocated buffer of the correct size to accumulate gradients in. This argument along with - weight tensor having attribute 'overwrite_main_grad' set to True - will overwrite `main_grad` instead of accumulating. - return_bias : bool, default = `False` - when set to `True`, this module will not apply the additive bias for FC2, but + weight tensor having attribute ``'overwrite_main_grad'`` set to True + will overwrite ``main_grad`` instead of accumulating. + return_bias : bool, default = False + when set to ``True``, this module will not apply the additive bias for FC2, but instead return the bias value during the forward pass together with the output of the linear transformation :math:`y = xA^T`. This is useful when the bias addition can be fused to subsequent operations. - params_dtype : torch.dtype, default = `torch.get_default_dtype()` + params_dtype : torch.dtype, default = torch.get_default_dtype() it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. @@ -1527,14 +1527,14 @@ class LayerNormMLP(TransformerEngineBaseModule): batch size per training step. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure same kernels are used for forward propogation and activation recompute phase. - delay_wgrad_compute : bool, default = `False` - Whether or not to delay weight gradient computation. If set to `True`, - it's the user's responsibility to call `module.backward_dw` to compute + delay_wgrad_compute : bool, default = False + Whether or not to delay weight gradient computation. If set to ``True``, + it's the user's responsibility to call :meth:`backward_dw` to compute weight gradients. symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None Type of symmetric memory all-reduce to use during the forward pass. This can help in latency bound communication situations. - Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce + Requires PyTorch version 2.7.0 or higher. When set to ``None``, standard all-reduce is used. """ @@ -2001,7 +2001,7 @@ def _get_quantizers(self, fp8_output): def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ - ONNX-compatible version of the forward function that provides numerical equivalence + ONNX-compatible version of the :meth:`forward` method that provides numerical equivalence while only using operations that have defined ONNX symbolic translations. This simplified implementation is designed specifically for inference scenarios. 
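A brief sketch of a standalone ``LayerNormMLP`` with the activation and normalization options listed above (hyperparameters and shapes are illustrative; assumes a CUDA device)::

    import torch
    import transformer_engine.pytorch as te

    mlp = te.LayerNormMLP(
        hidden_size=1024,
        ffn_hidden_size=4096,
        activation="swiglu",
        normalization="RMSNorm",
        return_layernorm_output=False,
    ).cuda()

    x = torch.randn(128, 2, 1024, device="cuda")  # [seq, batch, hidden]
    y = mlp(x)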
""" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 61886950a8..48ebe9c327 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1009,7 +1009,7 @@ def wgrad_gemm( class Linear(TransformerEngineBaseModule): """Applies a linear transformation to the incoming data :math:`y = xA^T + b` - On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`. + On NVIDIA GPUs it is a drop-in replacement for ``torch.nn.Linear``. Parameters ---------- @@ -1017,14 +1017,14 @@ class Linear(TransformerEngineBaseModule): size of each input sample. out_features : int size of each output sample. - bias : bool, default = `True` - if set to `False`, the layer will not learn an additive bias. - init_method : Callable, default = `None` - used for initializing weights in the following way: `init_method(weight)`. - When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. - get_rng_state_tracker : Callable, default = `None` + bias : bool, default = True + if set to ``False``, the layer will not learn an additive bias. + init_method : Callable, default = None + used for initializing weights in the following way: ``init_method(weight)``. + When set to ``None``, defaults to ``torch.nn.init.normal_(mean=0.0, std=0.023)``. + get_rng_state_tracker : Callable, default = None used to get the random number generator state tracker for initializing weights. - rng_tracker_name : str, default = `None` + rng_tracker_name : str, default = None the param passed to get_rng_state_tracker to get the specific rng tracker. parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None Configuration for splitting the weight and bias tensors along dim 0 into @@ -1032,62 +1032,62 @@ class Linear(TransformerEngineBaseModule): they are used to make the names of equally-sized parameters. If a dict (preferably an OrderedDict) is provided, the keys are used as names and values as split sizes along dim 0. The resulting parameters will have - names that end in `_weight` or `_bias`, so trailing underscores are + names that end in ``_weight`` or ``_bias``, so trailing underscores are stripped from any provided names. device : Union[torch.device, str], default = "cuda" The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. - name: str, default = `None` + name: str, default = None name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` + sequence_parallel : bool, default = False + if set to ``True``, uses sequence parallelism. + tp_group : ProcessGroup, default = None tensor parallel process group. tp_size : int, default = 1 used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the - `set_tensor_parallel_group(tp_group)` method on the initialized module before the + ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. 
- parallel_mode : {None, 'column', 'row'}, default = `None` + parallel_mode : {None, 'column', 'row'}, default = None used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. - When set to `None`, no communication is performed. + When set to ``None``, no communication is performed. Optimization parameters ----------------------- fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of + if set to ``True``, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct + have an additional ``main_grad`` attribute (used instead of the + regular ``grad``) which is a pre-allocated buffer of the correct size to accumulate gradients in. This argument along with weight tensor having attribute 'overwrite_main_grad' set to True - will overwrite `main_grad` instead of accumulating. - return_bias : bool, default = `False` - when set to `True`, this module will not apply the additive bias itself, but + will overwrite ``main_grad`` instead of accumulating. + return_bias : bool, default = False + when set to ``True``, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation :math:`y = xA^T`. This is useful when the bias addition can be fused to subsequent operations. - params_dtype : torch.dtype, default = `torch.get_default_dtype()` + params_dtype : torch.dtype, default = torch.get_default_dtype() it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. - delay_wgrad_compute : bool, default = `False` - Whether or not to delay weight gradient computation. If set to `True`, - it's the user's responsibility to call `module.backward_dw` to compute + delay_wgrad_compute : bool, default = False + Whether or not to delay weight gradient computation. If set to ``True``, + it's the user's responsibility to call ``module.backward_dw`` to compute weight gradients. symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None Type of symmetric memory all-reduce to use during the forward pass. This can help in latency bound communication situations. - Requires PyTorch version 2.7.0 or higher. When set to None, standard all-reduce + Requires PyTorch version 2.7.0 or higher. When set to ``None``, standard all-reduce is used. - save_original_input : bool, default = `False` - If set to `True`, always saves the original input tensor rather than the + save_original_input : bool, default = False + If set to ``True``, always saves the original input tensor rather than the cast tensor. In some scenarios, the input tensor is used by multiple modules, and saving the original input tensor may reduce the memory usage. Cannot work with FP8 DelayedScaling recipe. 
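``te.Linear`` is used like ``torch.nn.Linear``; a minimal sketch with the ``params_dtype`` option documented above (assumes a CUDA device)::

    import torch
    import transformer_engine.pytorch as te

    linear = te.Linear(1024, 4096, bias=True, params_dtype=torch.bfloat16).cuda()

    x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
    y = linear(x)  # [32, 4096]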
diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index fb267d8a9b..9ebff0c81d 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -41,8 +41,8 @@ class RMSNorm(_RMSNormOp): Tensor device dtype: torch.dtype, default = default dtype Tensor datatype - zero_centered_gamma : bool, default = 'False' - If `True`, the :math:`\gamma` parameter is initialized to zero + zero_centered_gamma : bool, default = False + If ``True``, the :math:`\gamma` parameter is initialized to zero and the calculation changes to .. math:: @@ -52,13 +52,10 @@ class RMSNorm(_RMSNormOp): Number of SMs to exclude when launching CUDA kernels. This helps overlap with other kernels, e.g. communication kernels. For more fine-grained control, provide a dict with the SM - margin at each compute stage ("forward", "backward", - "inference"). - - Legacy - ------ + margin at each compute stage (``"forward"``, ``"backward"``, + ``"inference"``). sequence_parallel: bool - Set a bool attr named `sequence_parallel` in the parameters. + **Legacy parameter.** Set a bool attr named ``sequence_parallel`` in the parameters. This is custom logic for Megatron-LM integration. """ diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 27b6983f4b..39a5aa6ab7 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -668,7 +668,7 @@ def fp8_model_init( .. warning:: fp8_model_init is deprecated and will be removed in a future release. Use - quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...) instead. + ``quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...)`` instead. """ @@ -713,7 +713,7 @@ def quantized_model_init( Parameters ---------- - enabled: bool, default = `True` + enabled: bool, default = True when enabled, Transformer Engine modules created inside this `quantized_model_init` region will hold only quantized copies of its parameters, as opposed to the default behavior where both higher precision and quantized copies are present. Setting this @@ -724,9 +724,9 @@ def quantized_model_init( precision copies of weights are already present in the optimizer. * inference, where only the quantized copies of the parameters are used. * LoRA-like fine-tuning, where the main parameters of the model do not change. - recipe: transformer_engine.common.recipe.Recipe, default = `None` + recipe: transformer_engine.common.recipe.Recipe, default = None Recipe used to create the parameters. If left to None, it uses the default recipe. - preserve_high_precision_init_val: bool, default = `False` + preserve_high_precision_init_val: bool, default = False when enabled, store the high precision tensor used to initialize quantized parameters in CPU memory, and add two function attributes named `get_high_precision_init_val()` and `clear_high_precision_init_val()` to quantized parameters to get/clear this high @@ -763,8 +763,8 @@ def fp8_autocast( """ .. warning:: - fp8_autocast is deprecated and will be removed in a future release. - Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead. + ``fp8_autocast`` is deprecated and will be removed in a future release. + Use ``autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...)`` instead. 
""" @@ -818,16 +818,16 @@ def autocast( Parameters ---------- - enabled: bool, default = `True` + enabled: bool, default = True whether or not to enable low precision quantization (FP8/FP4). - calibrating: bool, default = `False` + calibrating: bool, default = False calibration mode allows collecting statistics such as amax and scale data of quantized tensors even when executing without quantization enabled. This is useful for saving an inference ready checkpoint while training using a higher precision. - recipe: recipe.Recipe, default = `None` + recipe: recipe.Recipe, default = None recipe used for low precision quantization. - amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None` + amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = None distributed group over which amaxes for the quantized tensors are reduced at the end of each training step. """ diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 15f5b6bd5e..a73ca38a97 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -49,11 +49,11 @@ def update_usage( Parameters ---------- - rowwise_usage : Optional[bool[, default = `None` + rowwise_usage : Optional[bool[, default = None Whether to create or keep the data needed for using the tensor in rowwise fashion (e.g. as B argument in TN GEMM). Leaving it as `None` preserves the original value in the tensor. - columnwise_usage : Optional[bool], default = `None` + columnwise_usage : Optional[bool], default = None Whether to create or keep the data needed for using the tensor in columnwise fashion (e.g. as A argument in TN GEMM). Leaving it as `None` preserves the original value in the tensor. diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 60fd024d71..6f5a209da8 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -75,8 +75,8 @@ class TransformerLayer(torch.nn.Module): .. note:: - Argument :attr:`attention_mask` in the `forward` call is only used when - :attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`. + Argument :attr:`attention_mask` in the :meth:`forward` call is only used when + :attr:`self_attn_mask_type` includes ``"padding"`` or ``"arbitrary"``. Parameters ---------- @@ -86,14 +86,14 @@ class TransformerLayer(torch.nn.Module): intermediate size to which input samples are projected. num_attention_heads : int number of attention heads in the transformer layer. - num_gqa_groups : int, default = `None` + num_gqa_groups : int, default = None number of GQA groups in the transformer layer. Grouped Query Attention is described in `this paper `_. This only affects the keys and values, not the querys. GQA-1 is equivalent to Multi-Query Attention (`MQA `_), while GQA-H - is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. + is equivalent to MHA, i.e. ``num_gqa_groups = num_attention_heads``. layernorm_epsilon : float, default = 1e-5 a value added to the denominator of layer normalization for numerical stability. @@ -101,61 +101,61 @@ class TransformerLayer(torch.nn.Module): dropout probability for the dropout op after FC2 layer. attention_dropout: float, default = 0.1 dropout probability for the dropout op during multi-head attention. 
- init_method : Callable, default = `None` + init_method : Callable, default = None used for initializing weights of QKV and FC1 weights in the following way: - `init_method(weight)`. When set to `None`, defaults to - `torch.nn.init.normal_(mean=0.0, std=0.023)`. - output_layer_init_method : Callable, default = `None` + ``init_method(weight)``. When set to ``None``, defaults to + ``torch.nn.init.normal_(mean=0.0, std=0.023)``. + output_layer_init_method : Callable, default = None used for initializing weights of PROJ and FC2 in the following way: - `output_layer_init_method(weight)`. When set to `None`, defaults to - `torch.nn.init.normal_(mean=0.0, std=0.023)`. - apply_residual_connection_post_layernorm : bool, default = `False` - if set to `True`, residual connections are taken + ``output_layer_init_method(weight)``. When set to ``None``, defaults to + ``torch.nn.init.normal_(mean=0.0, std=0.023)``. + apply_residual_connection_post_layernorm : bool, default = False + if set to ``True``, residual connections are taken from the output of layer norm (default is taken from input of layer norm) - layer_number: int, default = `None` - layer number of the current `TransformerLayer` when multiple such modules are + layer_number: int, default = None + layer number of the current :class:`TransformerLayer` when multiple such modules are concatenated to form a transformer block. - output_layernorm: bool, default = `False` - if set to `True`, layer normalization is applied on the output side, + output_layernorm: bool, default = False + if set to ``True``, layer normalization is applied on the output side, after the final dropout-add. default behavior is to apply layer normalization on the input side, before the QKV transformation. - parallel_attention_mlp: bool, default = `False` - if set to `True`, self-attention and feedforward network are computed + parallel_attention_mlp: bool, default = False + if set to ``True``, self-attention and feedforward network are computed based on the same input (in parallel) instead of sequentially. Both blocks have an independent normalization. This architecture is used in `Falcon` models. - layer_type: {'encoder', 'decoder'}, default = `encoder` - if set to `decoder`, an additional cross-attn block is added after self-attn. + layer_type: {'encoder', 'decoder'}, default = "encoder" + if set to ``"decoder"``, an additional cross-attn block is added after self-attn. This can be used for structures like `T5` Transformer in conjunction with the - `encoder` option. - kv_channels: int, default = `None` + ``"encoder"`` option. + kv_channels: int, default = None number of query-key-value channels per attention head. defaults to - :attr:`hidden_size` / :attr:`num_attention_heads` if `None`. + :attr:`hidden_size` / :attr:`num_attention_heads` if ``None``. self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, - default = `causal` + default = "causal" type of attention mask passed into softmax operation for encoder. - Overridden by :attr:`self_attn_mask_type` in the `forward` method. - The forward arg is useful for dynamically changing mask types, e.g. - a different mask for training and inference. The init arg is useful + Overridden by :attr:`self_attn_mask_type` in the :meth:`forward` method. + The :meth:`forward` arg is useful for dynamically changing mask types, e.g. + a different mask for training and inference. 
The :meth:`__init__` arg is useful for cases involving compilation/tracing, e.g. ONNX export. - window_size: Optional[Tuple[int, int]], default = `None` + window_size: Optional[Tuple[int, int]], default = None sliding window size for local attention in encoder, where query at position i - attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean - no sliding window and causal mask specifically. Both `causal` and - `causal_bottom_right` masks map to `window_size = (-1, 0)` and Transformer Engine - distinguishes them based on `self_attn_mask_type` or `enc_dec_attn_mask_type`. - Similar to :attr:`self_attn_mask_type`, `window_size` can be overridden by - :attr:`window_size` in `forward` as well. + attends to keys in ``[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k + - seqlen_q + window_size[1]]`` inclusive. Special cases ``(-1, -1)`` and ``(-1, 0)`` mean + no sliding window and causal mask specifically. Both ``"causal"`` and + ``"causal_bottom_right"`` masks map to :attr:`window_size` = ``(-1, 0)`` and Transformer Engine + distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`. + Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by + :attr:`window_size` in :meth:`forward` as well. enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, - default = `no_mask` + default = "no_mask" type of attention mask passed into softmax operation for decoder. - enc_dec_window_size: Optional[Tuple[int, int]], default = `None` + enc_dec_window_size: Optional[Tuple[int, int]], default = None sliding window size for local attention in decoder. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and + zero_centered_gamma : bool, default = False + if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to .. math:: @@ -163,80 +163,92 @@ class TransformerLayer(torch.nn.Module): (1 + \gamma) + \beta normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' type of normalization applied. - qkv_weight_interleaved : bool, default = `True` - if set to `False`, the QKV weight is interpreted as a concatenation of - query, key, and value weights along the `0th` dimension. The default - interpretation is that the individual `q`, `k`, and `v` weights for each - attention head are interleaved. This parameter is set to `False` when + qkv_weight_interleaved : bool, default = True + if set to ``False``, the QKV weight is interpreted as a concatenation of + query, key, and value weights along the ``0th`` dimension. The default + interpretation is that the individual ``q``, ``k``, and ``v`` weights for each + attention head are interleaved. This parameter is set to ``False`` when using :attr:`fuse_qkv_params=False`. - rotary_pos_interleaved : bool, default = `False` + rotary_pos_interleaved : bool, default = False whether to use interleaved rotary position embeddings. - bias : bool, default = `True` - if set to `False`, the transformer layer will not learn any additive biases. + bias : bool, default = True + if set to ``False``, the transformer layer will not learn any additive biases. activation : str, default = 'gelu' Type of activation used in MLP block. - Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', - 'silu', 'swiglu', and 'clamped_swiglu'. 
- activation_params : Optional[dict], default = `None` + Options are: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, + ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``. + activation_params : Optional[dict], default = None Additional parameters for the activation function. - At the moment, only used for 'clamped_swiglu' activation which - supports 'limit' and 'alpha' parameters. You can set these as - `activation_params={'limit': 7.0, 'alpha': 1.702}`. + At the moment, only used for ``'clamped_swiglu'`` activation which + supports ``'limit'`` and ``'alpha'`` parameters. You can set these as + ``activation_params={'limit': 7.0, 'alpha': 1.702}``. device : Union[torch.device, str], default = "cuda" The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. attn_input_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd' - This controls whether the dimensions of the - intermediate hidden states is 'sequence first' ('sbhd'), 'batch first' ('bshd'), - or 'token first' ('thd'). `s` stands for the sequence length, `b` batch size, - `t` the total number of tokens, `h` the number of heads, `d` head size. - Note that these formats are very closely - related to the `qkv_format` in the `MultiHeadAttention` - and `DotProductAttention` modules. - name: str, default = `None` + This controls whether the dimensions of the + intermediate hidden states is 'sequence first' (``'sbhd'``), 'batch first' (``'bshd'``), + or 'token first' (``'thd'``). ``s`` stands for the sequence length, ``b`` batch size, + ``t`` the total number of tokens, ``h`` the number of heads, ``d`` head size. + Note that these formats are very closely + related to the :attr:`qkv_format` parameter in the :class:`MultiHeadAttention` + and :class:`DotProductAttention` modules. + name: str, default = None name of the module, currently used for debugging purposes. softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' Softmax type as described in the paper `Efficient Streaming Language Models with Attention Sinks `_. - For a given attention score ``S = Q x K^T``, of shape ``[b, h, s_q, s_kv]``: + For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``: - * 'vanilla': ``S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1)`` - * 'off-by-one': ``S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1))`` - * 'learnable': ``S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1))`` + * ``'vanilla'``: - where ``alpha`` is a learnable parameter in shape ``[h]``. - 'off-by-one' and 'learnable' softmax types are also called sink attention - ('zero sink' and 'learnable sink'). + .. math:: + Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})} + + * ``'off-by-one'``: + + .. math:: + Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})} + + * ``'learnable'``: + + .. math:: + Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})} + + where :math:`\\alpha` is a learnable parameter of shape ``[h]``. + + ``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention + (``'zero sink'`` and ``'learnable sink'``). 
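# Editorial sketch (not part of the patch): the three softmax_type variants from
# the docstring above, written out in plain PyTorch for an attention-score tensor
# S of shape [b, h, s_q, s_kv]. This is the naive, non-stabilized form of the
# formulas and is for illustration only.
import torch

def sink_softmax(S, softmax_type="vanilla", alpha=None):
    e = S.exp()
    denom = e.sum(dim=-1, keepdim=True)
    if softmax_type == "vanilla":
        return e / denom
    if softmax_type == "off-by-one":   # "zero sink": the denominator gets an extra +1
        return e / (1.0 + denom)
    if softmax_type == "learnable":    # "learnable sink": alpha is a learnable [h] tensor
        return e / (alpha.exp().view(1, -1, 1, 1) + denom)
    raise ValueError(f"unknown softmax_type: {softmax_type}")

# Example: batch 2, 4 heads, 8 query and 8 key positions. With a sink term the
# rows sum to slightly less than one.
S = torch.randn(2, 4, 8, 8)
alpha = torch.zeros(4, requires_grad=True)
p = sink_softmax(S, "learnable", alpha)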
Parallelism parameters ---------------------- - set_parallel_mode : bool, default = `False` - if set to `True`, QKV and FC1 layers are used as Column Parallel + set_parallel_mode : bool, default = False + if set to ``True``, QKV and FC1 layers are used as Column Parallel whereas PROJ and FC2 is used as Row Parallel as described `here `_. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` + sequence_parallel : bool, default = False + if set to ``True``, uses sequence parallelism. + tp_group : ProcessGroup, default = None tensor parallel process group. tp_size : int, default = 1 used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the - `set_tensor_parallel_group(tp_group)` method on the initialized module before the + :meth:`set_tensor_parallel_group` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. Optimization parameters ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of + fuse_wgrad_accumulation : bool, default = False + if set to ``True``, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct + have an additional :attr:`main_grad` attribute (used instead of the + regular :attr:`grad`) which is a pre-allocated buffer of the correct size to accumulate gradients in. - params_dtype : torch.dtype, default = `torch.get_default_dtype()` + params_dtype : torch.dtype, default = torch.get_default_dtype() it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. @@ -251,26 +263,26 @@ class TransformerLayer(torch.nn.Module): drop_path_rate: float, default = 0.0 when > 0.0, applies stochastic depth per sample in the main path of the residual block. - fuse_qkv_params: bool, default = 'False' - if set to `True`, `TransformerLayer` module exposes a single fused + fuse_qkv_params: bool, default = False + if set to ``True``, :class:`TransformerLayer` module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument - `fuse_wgrad_accumulation`. + :attr:`fuse_wgrad_accumulation`. qk_norm_type: Optional[str], default = None type of normalization to apply to query and key tensors. - Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied. - When 'L2Normalization', L2 normalization is applied to query and key tensors. - When 'RMSNorm', RMS normalization is applied to query and key tensors. - When 'LayerNorm', layer normalization is applied to query and key tensors. + Options: ``None``, ``'L2Normalization'``, ``'RMSNorm'``, ``'LayerNorm'``. When ``None``, no normalization is applied. + When ``'L2Normalization'``, L2 normalization is applied to query and key tensors. + When ``'RMSNorm'``, RMS normalization is applied to query and key tensors. + When ``'LayerNorm'``, layer normalization is applied to query and key tensors. 
Normalization is applied after RoPE (if applicable) but before attention computation - when `qk_norm_before_rope` is False. This follows the e.g. Llama4 approach for + when ``qk_norm_before_rope`` is ``False``. This follows the e.g. Llama4 approach for QK normalization to improve training stability and model performance. qk_norm_eps: float, default = 1e-6 epsilon value for normalization of query and key tensors. - Only used when `qk_norm_type` is not None. - qk_norm_before_rope: bool, default = `False` - if set to `True`, query and key normalization is applied before rotary position - embedding. When `False` (default), normalization is applied after RoPE. + Only used when ``qk_norm_type`` is not ``None``. + qk_norm_before_rope: bool, default = False + if set to ``True``, query and key normalization is applied before rotary position + embedding. When ``False`` (default), normalization is applied after RoPE. This parameter allows supporting different architectural variants that apply QK normalization at different points. """ @@ -526,7 +538,7 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N Parameters ---------- - tp_group : ProcessGroup, default = `None` + tp_group : ProcessGroup, default = None tensor parallel process group. """ # Deep iterate but skip self to avoid infinite recursion. @@ -552,7 +564,7 @@ def set_context_parallel_group( cp_stream: torch.cuda.Stream, cp_comm_type: str = "p2p", ) -> None: - """ + r""" Set the context parallel attributes for the given module before executing the forward pass. @@ -560,14 +572,14 @@ def set_context_parallel_group( ---------- cp_group : Union[ProcessGroup, List[ProcessGroup]] context parallel process group. - ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a". - List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0] - and cp_group[1] are for a2a and p2p communications respectively. + ProcessGroup is for cp_comm_type of ``"p2p"``, ``"all_gather"``, and ``"a2a"``. + List[ProcessGroup] is for cp_comm_type of ``"a2a+p2p"``, where ``cp_group[0]`` + and ``cp_group[1]`` are for a2a and p2p communications respectively. cp_global_ranks : List[int] list of global ranks in the context group. cp_stream : torch.cuda.Stream cuda stream for context parallel execution. - cp_comm_type : str, default = `p2p` + cp_comm_type : str, default = "p2p" inter-gpu communication type for context parallelism. Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``. @@ -614,49 +626,49 @@ def forward( fast_zero_fill: bool = True, pad_between_seqs: Optional[bool] = None, ) -> torch.Tensor: - """ + r""" Transformer Layer: attention block and a feedforward network (MLP) .. note:: Argument :attr:`attention_mask` is only used when :attr:`self_attn_mask_type` - includes `"padding"` or `"arbitrary"`. + includes ``"padding"`` or ``"arbitrary"``. Parameters ---------- hidden_states : torch.Tensor Input tensor. - attention_mask : Optional[torch.Tensor], default = `None` + attention_mask : Optional[torch.Tensor], default = None Boolean tensor used to mask out self-attention softmax input. It should be - in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable - to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`" - mask. It should be `None` for causal masks and "`no_mask`" type. - A `True` value means the corresponding position is masked out and - a `False` means that position is allowed to participate in attention. 
+ in ``[batch_size, 1, 1, seqlen_q]`` for padding masks, and broadcastable + to ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]`` for ``"arbitrary"`` + mask. It should be ``None`` for causal masks and ``"no_mask"`` type. + A ``True`` value means the corresponding position is masked out and + a ``False`` means that position is allowed to participate in attention. self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'}, - default = `causal` + default = "causal" Type of attention mask passed into softmax operation for encoder. By default, causal masks are aligned to the top left corner of - the softmax matrix. When "`bottom_right`" is specified in the mask type, + the softmax matrix. When ``"bottom_right"`` is specified in the mask type, causal masks are aligned to the bottom right corner. - window_size: Optional[Tuple[int, int]], default = `None` + window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention in encoder. - encoder_output : Optional[torch.Tensor], default = `None` + encoder_output : Optional[torch.Tensor], default = None Output of the encoder block to be fed into the decoder block if using - `layer_type="decoder"`. + :attr:`layer_type` = ``"decoder"``. enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], - default = `None`. Boolean tensors used to mask out inter-attention softmax input if - using `layer_type="decoder"`. It should be a tuple of two masks in - [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks. - It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] - for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`". - A `True` value means the corresponding position is masked out and a `False` + default = None. Boolean tensors used to mask out inter-attention softmax input if + using :attr:`layer_type` = ``"decoder"``. It should be a tuple of two masks in + ``[batch_size, 1, 1, seqlen_q]`` and ``[batch_size, 1, 1, seqlen_kv]`` for padding masks. + It should be broadcastable to ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]`` + for ``"arbitrary"`` mask. It should be ``None`` for causal masks and ``"no_mask"``. + A ``True`` value means the corresponding position is masked out and a ``False`` means that position is allowed to participate in attention. enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, - default = `None` + default = None Type of attention mask passed into softmax operation for decoder. - enc_dec_window_size: Optional[Tuple[int, int]], default = `None` + enc_dec_window_size: Optional[Tuple[int, int]], default = None Sliding window size for local attention in decoder. is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or @@ -671,53 +683,53 @@ def forward( * it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced) - checkpoint_core_attention: bool, default = `False` - If true, forward activations for core attention are recomputed + checkpoint_core_attention: bool, default = False + If ``True``, forward activations for core attention are recomputed during the backward pass in order to save memory that would otherwise be occupied to store the forward activations until backprop. 
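# Editorial sketch (not part of the patch): a minimal forward() call exercising a
# few of the arguments documented above. Shapes and values are made up; hidden
# states use the default 'sbhd' layout, and True entries in the boolean padding
# mask mark positions that are masked out.
import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16)
hidden = torch.randn(128, 2, 1024, device="cuda")             # [s, b, hidden] for 'sbhd'
pad_mask = torch.zeros(2, 1, 1, 128, dtype=torch.bool, device="cuda")
pad_mask[1, ..., 100:] = True                                 # pad out the tail of sample 1
out = layer(
    hidden,
    attention_mask=pad_mask,
    self_attn_mask_type="padding_causal",
    checkpoint_core_attention=False,
)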
- rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None` + rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = None Embeddings for query and key tensors for applying rotary position embedding. By default no input embedding is applied. - core_attention_bias_type: str, default = `no_bias` - Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`} - core_attention_bias: Optional[torch.Tensor], default = `None` - Bias tensor for Q * K.T - alibi_slopes: Optional[torch.Tensor], default = `None` - ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads]. - It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j)) + core_attention_bias_type: str, default = "no_bias" + Bias type, {``"no_bias"``, ``"pre_scale_bias"``, ``"post_scale_bias"``, ``"alibi"``} + core_attention_bias: Optional[torch.Tensor], default = None + Bias tensor for :math:`Q \cdot K^T` + alibi_slopes: Optional[torch.Tensor], default = None + ALiBi slopes in FP32 and shape ``[nheads]`` or ``[batch_size, nheads]``. + It adds a bias of :math:`(-\text{alibi_slope} \cdot (i + \text{seqlen_k} - \text{seqlen_q} - j))` to the attention score of query i and key j. - cu_seqlens_q: Optional[torch.Tensor], default = `None` - Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, - with shape [batch_size + 1] and dtype torch.int32. + cu_seqlens_q: Optional[torch.Tensor], default = None + Cumulative sum of sequence lengths (without offset) in a batch for query layer, + with shape ``[batch_size + 1]`` and dtype torch.int32. Used by encoders, or decoders' self-attention. - cu_seqlens_kv: Optional[torch.Tensor], default = `None` - Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` - and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + cu_seqlens_kv: Optional[torch.Tensor], default = None + Cumulative sum of sequence lengths (without offset) in a batch for key layer + and value layer, with shape ``[batch_size + 1]`` and dtype torch.int32. Used by decoders' cross-attention. - cu_seqlens_q_padded: Optional[torch.Tensor], default = `None` - Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`, - with shape [batch_size + 1] and dtype torch.int32. Set to `cu_seqlens_q` if None. + cu_seqlens_q_padded: Optional[torch.Tensor], default = None + Cumulative sum of sequence lengths (with offset) in a batch for query layer, + with shape ``[batch_size + 1]`` and dtype torch.int32. Set to :attr:`cu_seqlens_q` if ``None``. Used by encoders, or decoders' self-attention. - cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None` - Cumulative sum of sequence lengths (with offset) in a batch for `key_layer` - and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. - Set to `cu_seqlens_kv` if None. Used by decoders' cross-attention. - max_seqlen_q: Optional[int], default = `None` - Maximum sequence length in `query_layer`. - Calculated from `cu_seqlens_q_padded` if not provided. - max_seqlen_kv: Optional[int], default = `None` - Maximum sequence length in `key_layer` and `value_layer`. - Calculated from `cu_seqlens_kv_padded` if not provided. - fast_zero_fill: bool, default = `True` + cu_seqlens_kv_padded: Optional[torch.Tensor], default = None + Cumulative sum of sequence lengths (with offset) in a batch for key layer + and value layer, with shape ``[batch_size + 1]`` and dtype torch.int32. + Set to :attr:`cu_seqlens_kv` if ``None``. 
Used by decoders' cross-attention. + max_seqlen_q: Optional[int], default = None + Maximum sequence length in query layer. + Calculated from :attr:`cu_seqlens_q_padded` if not provided. + max_seqlen_kv: Optional[int], default = None + Maximum sequence length in key layer and value layer. + Calculated from :attr:`cu_seqlens_kv_padded` if not provided. + fast_zero_fill: bool, default = True Whether to set output tensors to 0 or not before use. inference_params: InferenceParams, default = None Inference parameters that are passed to the main model in order to efficiently calculate and store the context during inference. - pad_between_seqs: Optional[bool], default = `None` - If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. - If true, there are padding tokens between individual sequences in a packed batch, - i.e. qkv_format = 'thd'. + pad_between_seqs: Optional[bool], default = None + If ``None``, inferred from :attr:`qkv_format`, cu_seqlens and cu_seqlens_padded. + If ``True``, there are padding tokens between individual sequences in a packed batch, + i.e. :attr:`qkv_format` = ``'thd'``. """ if self_attn_mask_type is None: From b9506aa6079938f92c1c24b602cc914f66f0a4bd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 4 Nov 2025 21:42:48 +0000 Subject: [PATCH 14/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/cross_entropy.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py index a71b108956..30002cdbfd 100644 --- a/transformer_engine/pytorch/cross_entropy.py +++ b/transformer_engine/pytorch/cross_entropy.py @@ -76,9 +76,7 @@ def backward(ctx, grad_output): tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. """ (inp,) = ctx.saved_tensors - inp = triton_cross_entropy.cross_entropy_backward( - inp, grad_output, ctx.is_cg_capturable - ) + inp = triton_cross_entropy.cross_entropy_backward(inp, grad_output, ctx.is_cg_capturable) return ( inp, None,
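The reformatted call above lives inside an autograd ``backward`` that must return one gradient (or ``None``) for every input of the corresponding ``forward``. As an editorial aside (not part of the patch), the toy function below illustrates that contract, which is why the cross-entropy backward returns ``None`` placeholders alongside the input gradient:

import torch

class ScaleByConstant(torch.autograd.Function):
    # forward takes (inp, scale, flag), so backward must return exactly three
    # values: a gradient for inp and None for the two non-tensor arguments.
    @staticmethod
    def forward(ctx, inp, scale, flag):
        ctx.scale = scale
        return inp * scale

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None, None

x = torch.randn(4, requires_grad=True)
y = ScaleByConstant.apply(x, 2.0, True)
y.sum().backward()
print(x.grad)   # every element equals 2.0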