[JAX] Use TE quantization when TE fused norm is disabled #2303
base: main
Conversation
Signed-off-by: Phuong Nguyen <[email protected]>
Greptile Overview
Greptile Summary
This PR refactors the JAX normalization layer (layernorm and rmsnorm) to decouple quantization from fused normalization kernels by switching from the internal _quantize_dbias_impl to the public quantize function. When TE's fused normalization primitives are disabled (fallback path), the code now performs normalization first using _jax_layernorm / _jax_rmsnorm and then applies quantization via the standalone quantize helper. The quantizer parameter is made optional (defaults to None) to support cases where quantization is not needed. This change unifies the quantization API surface across the codebase and delivers a 0.06%-0.52% E2E training speedup on GB200 hardware for LLaMA 3.1 8B across four quantization recipes (FP8 delayed/current scaling, MXFP8, NVFP4) by enabling optimized quantization (e.g., flatten_axis=-1) in the fallback path, which was previously missing.
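As a rough illustration of the fallback flow described above, the sketch below normalizes with plain JAX and only then hands the result to a quantization step. This is not the actual normalization.py code: _jax_layernorm_ref, layernorm_fwd_fallback, and quantize_fn are hypothetical names, and TE's real quantize helper takes a quantizer object (plus arguments such as flatten_axis=-1) rather than a bare callable.

```python
import jax
import jax.numpy as jnp


def _jax_layernorm_ref(x, gamma, beta, eps=1e-6):
    # Plain-JAX layernorm standing in for TE's _jax_layernorm fallback.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + eps) * gamma + beta


def layernorm_fwd_fallback(x, gamma, beta, quantize_fn=None):
    # Decoupled fallback: normalize first, quantize afterwards.
    # quantize_fn=None mirrors the PR's now-optional quantizer (no quantization).
    out = _jax_layernorm_ref(x, gamma, beta)
    if quantize_fn is None:
        return out
    return quantize_fn(out)


# Toy usage: a bf16 downcast stands in for a real TE quantizer.
x = jnp.ones((4, 8), dtype=jnp.float32)
y = layernorm_fwd_fallback(x, jnp.ones((8,)), jnp.zeros((8,)),
                           quantize_fn=lambda t: t.astype(jnp.bfloat16))
```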
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/cpp_extensions/normalization.py | 3/5 | Refactored layernorm_fwd and rmsnorm_fwd to use public quantize function instead of _quantize_dbias_impl; made quantizer optional; decoupled normalization from quantization when fused ops are disabled |
Confidence score: 3/5
- This PR introduces performance optimization with a clear refactoring goal but contains subtle API inconsistencies that need verification before merge.
- Score reduced primarily due to inconsistent handling of quantize return values (some call sites unpack a tuple on line 1039 while others treat it as a single value on line 1060) and the removal of the is_dbias parameter without clear documentation of whether this affects backward compatibility with existing quantizer implementations.
- Pay close attention to lines 1039 and 1060 in transformer_engine/jax/cpp_extensions/normalization.py to verify that the quantize function's return signature is handled correctly in all code paths, and confirm that removing the is_dbias and dq_dtype parameters doesn't break any existing quantizer implementations.
1 file reviewed, no comments
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR refactors the JAX normalization layer to use the public quantize() API instead of the internal _quantize_dbias_impl() function when TransformerEngine's fused normalization is disabled. The change affects both LayerNorm and RMSNorm code paths, routing all quantization operations through a unified entry point. When NormFwdPrimitive.enabled() returns False, the code now performs normalization using JAX's native implementation, then applies TE's optimized quantization kernel separately. This decouples normalization from quantization, enabling TE's performant quantization even when the fused norm kernel is unavailable (e.g., due to cuDNN version constraints or environment flags). The refactor also simplifies the API by removing is_dbias and dq_dtype parameters from quantize calls. The change maintains all existing fallback paths for MXFP8, current tensor scaling, and NVFP4 formats, demonstrating measurable performance improvements (0.06-0.52%) for end-to-end LLaMA 3.1 8B training on GB200 hardware.
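The dispatch structure this summary describes can be sketched as below. This is a hedged outline under stated assumptions, not TE source: only NormFwdPrimitive.enabled() and quantize() are named in the review, so the fused kernel, the native-JAX norm, and the quantize helper are represented here by placeholder callables.

```python
def layernorm_fwd(x, gamma, beta, quantizer=None, *,
                  fused_norm_enabled, fused_norm_fn, jax_norm_fn, quantize_fn):
    # Illustrative dispatch only; all *_fn arguments are placeholders for the
    # real TE primitives/helpers, and fused_norm_enabled stands in for
    # NormFwdPrimitive.enabled().
    if fused_norm_enabled:
        # Fused path: the TE kernel handles normalization (and quantization) in one shot.
        return fused_norm_fn(x, gamma, beta, quantizer)
    # Fallback path introduced by this PR: normalize with native JAX, then reuse
    # TE's public quantize() so the optimized quantization kernel still runs.
    out = jax_norm_fn(x, gamma, beta)
    if quantizer is not None:
        out = quantize_fn(out, quantizer)
    return out
```

Keeping quantization behind the optional quantizer argument is what lets the same entry point serve both the unquantized path and the FP8/MXFP8/NVFP4 recipes.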
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/cpp_extensions/normalization.py | 4/5 | Replaces internal _quantize_dbias_impl() calls with public quantize() API across LayerNorm and RMSNorm fallback paths, simplifying quantization interface while routing all quantization through unified TE kernel |
Confidence score: 4/5
- This PR is safe to merge with minimal risk, as it's primarily a refactoring that consolidates quantization logic through a well-tested public API
- Score reflects that the change is well-structured and maintains backward compatibility across all fallback paths (MXFP8, current scaling, NVFP4), though the PR checklist indicates tests were not added to specifically validate the refactored quantization paths
- Pay close attention to transformer_engine/jax/cpp_extensions/normalization.py to ensure the quantize() API behaves identically to the previous _quantize_dbias_impl() implementation across all quantization formats, especially for edge cases involving cuDNN version constraints and transpose_batch_sequence flags
1 file reviewed, no comments
Good idea! LGTM
/te-ci JAX L0
Description
Performance improvement over the current main branch for E2E training of LLaMA 3.1 8B on a GB200 (0.06%-0.52% speedup across the FP8 delayed-scaling, FP8 current-scaling, MXFP8, and NVFP4 recipes).
Type of change
Checklist: