
Conversation

@phu0ngng
Collaborator

Description

Performance improvement over the current main branch for E2E training of Llama 3.1 8B on a GB200.

Recipe                   Speedup
te_fp8_delayedscaling    0.23%
te_fp8_currentscaling    0.06%
te_mxfp8                 0.45%
te_nvfp4                 0.52%

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring
  • Performance improvement

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Phuong Nguyen <[email protected]>
@phu0ngng phu0ngng marked this pull request as ready for review October 24, 2025 17:09

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR refactors the JAX normalization layer (layernorm and rmsnorm) to decouple quantization from the fused normalization kernels by switching from the internal _quantize_dbias_impl to the public quantize function. When TE's fused normalization primitives are disabled (fallback path), the code now performs normalization first using _jax_layernorm / _jax_rmsnorm and then applies quantization via the standalone quantize helper. The quantizer parameter is made optional (defaults to None) to support cases where quantization is not needed. This change unifies the quantization API surface across the codebase and delivers a 0.06% to 0.52% E2E training speedup on GB200 hardware for LLaMA 3.1 8B across four quantization recipes (FP8 delayed/current scaling, MXFP8, NVFP4), by enabling optimized quantization (e.g., flatten_axis=-1) that was previously missing in the fallback path.
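
As a rough illustration of the decoupled fallback path described above: the names _jax_layernorm, quantize, and quantizer follow the summary, but the signatures and the trivial quantize stand-in below are assumptions for illustration, not the actual TransformerEngine implementation.

```python
import jax
import jax.numpy as jnp


def _jax_layernorm(x, gamma, beta, eps=1e-6):
    # Native-JAX layernorm used when TE's fused norm primitive is disabled.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + eps) * gamma + beta


def quantize(x, quantizer):
    # Stand-in for the public TE quantize helper this PR switches to;
    # the real helper dispatches to TE's optimized quantization kernels.
    return quantizer(x)


def layernorm_fwd_fallback(x, gamma, beta, quantizer=None, eps=1e-6):
    # Normalization and quantization are decoupled: normalize first with
    # native JAX, then quantize only if a quantizer was provided (optional).
    out = _jax_layernorm(x, gamma, beta, eps)
    if quantizer is None:
        return out
    return quantize(out, quantizer)
```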

Important Files Changed

Filename: transformer_engine/jax/cpp_extensions/normalization.py
Score: 3/5
Overview: Refactored layernorm_fwd and rmsnorm_fwd to use public quantize function instead of _quantize_dbias_impl; made quantizer optional; decoupled normalization from quantization when fused ops are disabled

Confidence score: 3/5

  • This PR introduces a performance optimization with a clear refactoring goal but contains subtle API inconsistencies that need verification before merge.
  • Score reduced primarily due to inconsistent handling of quantize return values (some call sites unpack a tuple on line 1039 while others treat it as a single value on line 1060; see the sketch after this list) and the removal of the is_dbias parameter without clear documentation of whether this affects backward compatibility with existing quantizer implementations.
  • Pay close attention to lines 1039 and 1060 in transformer_engine/jax/cpp_extensions/normalization.py to verify that the quantize function's return signature is handled correctly in all code paths, and confirm that removing the is_dbias and dq_dtype parameters doesn't break any existing quantizer implementations.
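
To make the return-value concern flagged above concrete, here is a contrived sketch; neither helper below is the real TE code, and the assumed return shapes are illustrative only.

```python
# Hypothetical helpers showing the two return conventions in question.

def quantize_single(x):
    # Assumed shape of the public quantize helper: returns only the
    # quantized tensor.
    return x


def quantize_pair(x):
    # Assumed shape of the old internal _quantize_dbias_impl: returns an
    # (output, dbias) tuple.
    return x, None


x = [1.0, 2.0, 3.0]

out = quantize_single(x)       # correct for a single-value return
out, dbias = quantize_pair(x)  # correct for a tuple return

# Mixing the conventions is the bug class flagged above:
#   out, dbias = quantize_single(x)  -> ValueError at unpack time
#   out = quantize_pair(x)           -> silently carries a tuple downstream
```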

1 file reviewed, no comments



@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR refactors the JAX normalization layer to use the public quantize() API instead of the internal _quantize_dbias_impl() function when TransformerEngine's fused normalization is disabled. The change affects both LayerNorm and RMSNorm code paths, routing all quantization operations through a unified entry point. When NormFwdPrimitive.enabled() returns False, the code now performs normalization using JAX's native implementation, then applies TE's optimized quantization kernel separately. This decouples normalization from quantization, enabling TE's performant quantization even when the fused norm kernel is unavailable (e.g., due to cuDNN version constraints or environment flags). The refactor also simplifies the API by removing the is_dbias and dq_dtype parameters from quantize calls. The change maintains all existing fallback paths for MXFP8, current tensor scaling, and NVFP4 formats, and delivers measurable end-to-end speedups (0.06% to 0.52%) for LLaMA 3.1 8B training on GB200 hardware.
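
A minimal sketch of the dispatch described above, with stand-in definitions so it runs on its own: NormFwdPrimitive.enabled(), _jax_rmsnorm, and quantize are named in the summary, but the gate, the fused-path placeholder, and all signatures here are illustrative assumptions rather than the actual TE code.

```python
import jax
import jax.numpy as jnp


def norm_fwd_primitive_enabled():
    # Stand-in for NormFwdPrimitive.enabled(): False when the fused kernel
    # is unavailable (e.g., cuDNN version constraints or environment flags).
    return False


def fused_rmsnorm_fwd(x, gamma, eps, quantizer):
    # Placeholder for the fused TE norm(+quantize) primitive.
    raise NotImplementedError


def quantize(x, quantizer):
    # Stand-in for the public quantize entry point that the fallback path
    # now routes through.
    return quantizer(x)


def _jax_rmsnorm(x, gamma, eps=1e-6):
    # Native-JAX RMSNorm used on the fallback path.
    rms = jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x * rms * gamma


def rmsnorm_fwd(x, gamma, quantizer=None, eps=1e-6):
    if norm_fwd_primitive_enabled():
        # Fused path: normalization (plus optional quantization) in one primitive.
        return fused_rmsnorm_fwd(x, gamma, eps, quantizer)
    # Fallback path: normalize in JAX, then quantize separately through the
    # single public entry point; quantization is skipped when quantizer is None.
    out = _jax_rmsnorm(x, gamma, eps)
    return out if quantizer is None else quantize(out, quantizer)
```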

Important Files Changed

Filename: transformer_engine/jax/cpp_extensions/normalization.py
Score: 4/5
Overview: Replaces internal _quantize_dbias_impl() calls with public quantize() API across LayerNorm and RMSNorm fallback paths, simplifying quantization interface while routing all quantization through unified TE kernel

Confidence score: 4/5

  • This PR is safe to merge with minimal risk, as it's primarily a refactoring that consolidates quantization logic through a well-tested public API
  • Score reflects that the change is well-structured and maintains backward compatibility across all fallback paths (MXFP8, current scaling, NVFP4), though the PR checklist indicates tests were not added to specifically validate the refactored quantization paths
  • Pay close attention to transformer_engine/jax/cpp_extensions/normalization.py to ensure the quantize() API behaves identically to the previous _quantize_dbias_impl() implementation across all quantization formats, especially for edge cases involving cuDNN version constraints and transpose_batch_sequence flags

1 file reviewed, no comments


Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment


Good idea! LGTM

@phu0ngng
Collaborator Author

/te-ci JAX L0
