[JAX] Fix bug with pre scale bias #2300
Conversation
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
/te-ci jax
Greptile Overview
Greptile Summary
This PR fixes a critical bug in the JAX attention implementation where pre-scale bias was being added twice to attention weights. The fix modifies transformer.py to add the bias once in the unfused attention path and then nullify it (bias = None) to prevent the downstream DPA module from adding it again. Additionally, the test suite was strengthened by replacing forward-pass mean-only comparison with full tensor comparison for non-FP8 modes, addressing inadequate test coverage that allowed the bug to persist (backward tests used rtol/atol ~0.039 when actual gradient differences were ~0.003).
The changes integrate with the existing attention bias architecture where AttnBiasType.PRE_SCALE_BIAS requires special handling—bias must be added before scaling, unlike POST_SCALE_BIAS. The fix coordinates between the transformer layer and the Softmax/DPA module to ensure bias is applied exactly once in the attention computation pipeline.
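To make the intended semantics concrete, here is a minimal JAX sketch of how pre-scale and post-scale bias are meant to enter the attention logits; the function and variable names are illustrative, not the actual transformer.py code:

```python
import jax
import jax.numpy as jnp

def attention_probs(logits, bias, scale, pre_scale_bias):
    """Illustrative only: apply the attention bias exactly once, before or after scaling."""
    if pre_scale_bias and bias is not None:
        logits = logits + bias  # PRE_SCALE_BIAS: bias added before scaling
        bias = None             # the fix: the downstream softmax must not add it again
    logits = logits * scale
    if bias is not None:        # POST_SCALE_BIAS: bias added after scaling
        logits = logits + bias
    return jax.nn.softmax(logits, axis=-1)
```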
Confidence score: 4/5
- This PR fixes a correctness bug and strengthens test coverage but has minor concerns about FP8 testing gaps
- Score reflects that while the bias fix is correct, the FP8 fallback in tests (lines 345-348 of test_layer.py) still uses mean-only comparison, perpetuating weak test coverage for that code path; also, parameter order inconsistency between _output_fn and _loss_fn could cause maintenance issues
- Pay close attention to tests/jax/test_layer.py lines 345-348, where FP8 mode still uses mean-only comparison instead of full tensor comparison
2 files reviewed, 1 comment
tests/jax/test_layer.py
Outdated
if not get_quantize_config().is_fp8_enabled():
    assert_allclose(ref_out, test_out, **tols)
else:
    assert_allclose(ref_out.mean(), test_out.mean(), **tols)
logic: Only comparing means in FP8 mode loses test coverage. This doesn't validate output shape, distribution, or element-wise correctness. The PR description mentions test_layer only tests means, which was insufficient to catch the double-addition bug. What specific tolerance values were found to work for full tensor comparison in FP8 mode without false positives?
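For illustration only, a full-tensor comparison for the FP8 branch might look like the fragment below. It reuses the helpers already present in the quoted test, and the FP8 tolerance values are placeholders that would need to be determined empirically; they are not an answer to the reviewer's question:

```python
# Hypothetical FP8 tolerances; real values would need tuning to avoid false positives.
fp8_tols = dict(rtol=5e-2, atol=5e-2)

if not get_quantize_config().is_fp8_enabled():
    assert_allclose(ref_out, test_out, **tols)
else:
    # Compare every element instead of only the means.
    assert_allclose(ref_out, test_out, **fp8_tols)
```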
The difference for fp8 is around 0.4, which is huge. Maybe this is the bug and needs investigation.
I decided to skip these tests to get this merged sooner.
Thanks for finding and raising the tolerance issue, Pawel! I'll take a look today and see how we can make these tolerances stricter.
Signed-off-by: Pawel Gadzinski <[email protected]>
Force-pushed from 1e91f44 to c0a0947
Signed-off-by: Pawel Gadzinski <[email protected]>
@jberchtold-nvidia I removed the tolerance change from this PR, since the pipeline was failing and I saw that you are working on it in the other PR. I will leave only the bug fix.
Greptile Overview
Greptile Summary
Fixed critical bug where pre-scale bias was incorrectly added twice in attention computation. When attn_bias_type == PRE_SCALE_BIAS, the bias was added to attn_weights at transformer.py:199 and then added again inside the Softmax module at module.py:194. The fix sets bias = None after the first addition to prevent double-addition.
Key changes:
- Added bias = None after the pre-scale bias addition in _UnfusedDotProductAttention (see the sketch below)
- Ensures the bias is only applied once for the PRE_SCALE_BIAS case
- POST_SCALE_BIAS path remains unchanged (bias still passed to Softmax)
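A short sketch of the change described above (an approximation for illustration, not a verbatim excerpt of transformer.py):

```python
# Inside _UnfusedDotProductAttention, for the pre-scale bias case:
if attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
    attn_weights = attn_weights + bias
    bias = None  # added by this PR: the Softmax module then skips its own bias addition
```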
Note from PR author: The original tests were insufficient to catch this bug because they only compared means in forward pass and used overly loose tolerances (rtol/atol ~0.039 for bf16 when actual gradient magnitudes were ~0.003) for backward pass comparisons.
Confidence Score: 5/5
- This PR is safe to merge - it's a minimal, surgical fix for a clear bug
- The fix is a simple one-line addition that correctly prevents double-addition of bias. The logic is clear: after manually adding bias for PRE_SCALE_BIAS case, setting it to None prevents it from being added again in the Softmax module. No edge cases or side effects identified.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/flax/transformer.py | 5/5 | Fixed double-addition bug for PRE_SCALE_BIAS by setting bias to None after first addition |
Sequence Diagram
sequenceDiagram
participant DPA as _UnfusedDotProductAttention
participant SM as Softmax (module.py)
Note over DPA: attn_bias_type = PRE_SCALE_BIAS
rect rgb(255, 200, 200)
Note over DPA: BEFORE FIX (Bug)
DPA->>DPA: Line 199: attn_weights += bias
DPA->>SM: Line 240: Softmax(attn_weights, mask, bias)
SM->>SM: Line 194: logits = logits + bias
Note over SM: ❌ Bias added TWICE!
end
rect rgb(200, 255, 200)
Note over DPA: AFTER FIX (Correct)
DPA->>DPA: Line 199: attn_weights += bias
DPA->>DPA: Line 200: bias = None
DPA->>SM: Line 240: Softmax(attn_weights, mask, None)
SM->>SM: Line 194: Skip (bias is None)
Note over SM: ✓ Bias added once
end
1 file reviewed, no comments
/te-ci jax
jberchtold-nvidia
left a comment
LGTM, thanks!
KshitijLakhani
left a comment
LGTM
* fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> --------- Signed-off-by: Pawel Gadzinski <[email protected]>
Description
Pre-scale bias is added twice:
here for the first time:
TransformerEngine/transformer_engine/jax/flax/transformer.py, line 199 (commit e2f2a0b)
and here for the second time:
TransformerEngine/transformer_engine/jax/flax/module.py, line 194 (commit e2f2a0b)
But what's more important, there is a problem with test_layer: the forward pass is compared only by its mean, which is not enough to catch most bugs. For the backward pass all tensors are compared, but with very loose tolerances computed here:
TransformerEngine/tests/jax/utils.py, line 1486 (commit e2f2a0b)
For bf16 the tolerances are {'rtol': 0.039372532809214794, 'atol': 0.039372532809214794}, but the max of all wgrads and dgrads is around 0.003, so the comparison basically tests nothing. I made a temporary fix to catch this issue; we will see if anything more gets caught. But these tolerances need to be lowered.
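To see why those tolerances are effectively vacuous, here is a small self-contained illustration (the array values are made up to mimic the ~0.003 gradient magnitudes; they are not taken from the test suite):

```python
import numpy as np

# Hypothetical gradients with magnitudes around 0.003, similar to the wgrads/dgrads above.
ref_grad = np.full((4,), 0.003, dtype=np.float32)
bad_grad = np.zeros((4,), dtype=np.float32)  # an entirely wrong "gradient"

# assert_allclose passes when |actual - desired| <= atol + rtol * |desired|.
# Here 0.003 <= 0.0394 + 0.0394 * 0.003, so even an all-zero gradient is accepted.
np.testing.assert_allclose(bad_grad, ref_grad,
                           rtol=0.039372532809214794,
                           atol=0.039372532809214794)
print("comparison passed despite the gradient being completely wrong")
```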
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: