-
Notifications
You must be signed in to change notification settings - Fork 502
[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 #2135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 #2135
Conversation
a85ad4f
to
330ffd5
Compare
cd1b8c9
to
4a2a993
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks!
#34107254 Launched a pipeline as I needed to alter the JAX image for testing. |
942d624
to
319807a
Compare
/te-ci jax |
Signed-off-by: Kshitij Lakhani <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Lakhani <[email protected]>
Signed-off-by: Kshitij Lakhani <[email protected]> Add check for sm100 Signed-off-by: Kshitij Lakhani <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Lakhani <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Lakhani <[email protected]>
…g unnecessary type conversion Signed-off-by: Kshitij Lakhani <[email protected]>
Signed-off-by: Kshitij Lakhani <[email protected]>
1ac4bd5
to
b76131c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
Description
The cuDNN kernel path to allow for dropout along with determinism (for bias) is unsupported as of now on sm100 resulting in failures. This PR helps fix this.
I confirmed that 78 test failures were seen in the JAX nightly are now being skipped in the pipeline run for this PR and that there are no new additional skips due to the logic in the PR.
Type of change
Changes
i) Skips the fused attn tests so as to not try and exercise the unsupported path.
ii) Puts an assert in place in fused attn bwd to guard against inadvertent usage of the unsupported path
Checklist: