Skip to content

Conversation

KshitijLakhani
Copy link
Collaborator

@KshitijLakhani KshitijLakhani commented Aug 28, 2025

Description

The cuDNN kernel path to allow for dropout along with determinism (for bias) is unsupported as of now on sm100 resulting in failures. This PR helps fix this.

I confirmed that 78 test failures were seen in the JAX nightly are now being skipped in the pipeline run for this PR and that there are no new additional skips due to the logic in the PR.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

i) Skips the fused attn tests so as to not try and exercise the unsupported path.
ii) Puts an assert in place in fused attn bwd to guard against inadvertent usage of the unsupported path

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani KshitijLakhani self-assigned this Aug 28, 2025
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/jax-dropout-bias-sm100-test-failures branch from a85ad4f to 330ffd5 Compare August 28, 2025 18:49
@KshitijLakhani KshitijLakhani marked this pull request as ready for review August 29, 2025 17:10
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/jax-dropout-bias-sm100-test-failures branch from cd1b8c9 to 4a2a993 Compare August 29, 2025 17:10
@KshitijLakhani KshitijLakhani changed the title Fix failing tests for dropout=0.1 and bias for fused attn for blackwell [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 Aug 29, 2025
@KshitijLakhani KshitijLakhani requested review from phu0ngng and cyanguwa and removed request for phu0ngng August 29, 2025 17:17
phu0ngng
phu0ngng previously approved these changes Aug 29, 2025
Copy link
Collaborator

@phu0ngng phu0ngng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks!

@KshitijLakhani KshitijLakhani marked this pull request as draft August 29, 2025 17:52
@KshitijLakhani KshitijLakhani marked this pull request as ready for review August 29, 2025 22:24
@KshitijLakhani
Copy link
Collaborator Author

KshitijLakhani commented Aug 29, 2025

#34107254 Launched a pipeline as I needed to alter the JAX image for testing.
Using the internal nightly JAX image dated 08/26 as suggested by Brian(Jinxin Yag)
NO Failures

@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/jax-dropout-bias-sm100-test-failures branch from 942d624 to 319807a Compare August 29, 2025 22:26
@KshitijLakhani
Copy link
Collaborator Author

/te-ci jax

@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/jax-dropout-bias-sm100-test-failures branch from 1ac4bd5 to b76131c Compare August 31, 2025 23:58
Copy link
Collaborator

@phu0ngng phu0ngng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks!

@KshitijLakhani KshitijLakhani merged commit f378eaf into NVIDIA:main Sep 3, 2025
12 checks passed
@KshitijLakhani KshitijLakhani deleted the klakhani/fix/jax-dropout-bias-sm100-test-failures branch September 3, 2025 21:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants