Commit f378eaf
[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135)
* Fix failing tests for dropout=0.1 and bias for fused attn for blackwell
Signed-off-by: Kshitij Lakhani <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Fix the skip message
Signed-off-by: Kshitij Lakhani <[email protected]>
* Assert in fused attn bwd pass for sm100
Signed-off-by: Kshitij Lakhani <[email protected]>
Add check for sm100
Signed-off-by: Kshitij Lakhani <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Add support to get all devs in the process for jax
Signed-off-by: Kshitij Lakhani <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Code clean up
Signed-off-by: Kshitij Lakhani <[email protected]>
* Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion
Signed-off-by: Kshitij Lakhani <[email protected]>
* Represent attn bias using enum instead of string
Signed-off-by: Kshitij Lakhani <[email protected]>
---------
Signed-off-by: Kshitij Lakhani <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>1 parent 3b4366b commit f378eaf
File tree
3 files changed
+25
-0
lines changed- tests/jax
- transformer_engine/jax/cpp_extensions
3 files changed
+25
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
41 | 41 | | |
42 | 42 | | |
43 | 43 | | |
| 44 | + | |
44 | 45 | | |
45 | 46 | | |
46 | 47 | | |
| |||
348 | 349 | | |
349 | 350 | | |
350 | 351 | | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
351 | 360 | | |
352 | 361 | | |
353 | 362 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
34 | 34 | | |
35 | 35 | | |
36 | 36 | | |
| 37 | + | |
37 | 38 | | |
38 | 39 | | |
39 | 40 | | |
| |||
2745 | 2746 | | |
2746 | 2747 | | |
2747 | 2748 | | |
| 2749 | + | |
| 2750 | + | |
| 2751 | + | |
| 2752 | + | |
| 2753 | + | |
2748 | 2754 | | |
2749 | 2755 | | |
2750 | 2756 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
193 | 193 | | |
194 | 194 | | |
195 | 195 | | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
196 | 206 | | |
197 | 207 | | |
198 | 208 | | |
| |||
0 commit comments