Commit e2f2a0b
authored
[JAX] Make SR rng state always 2D (num_devices, 4) to fix partitioning issue (#2294)
* Make SR rng state always 2D (num_devices, 4)
Signed-off-by: Jeremy Berchtold <[email protected]>
* fix pure-jax impl
Signed-off-by: Jeremy Berchtold <[email protected]>
* fix test shape
Signed-off-by: Jeremy Berchtold <[email protected]>
---------
Signed-off-by: Jeremy Berchtold <[email protected]>1 parent eb34783 commit e2f2a0b
File tree
3 files changed
+14
-12
lines changed- tests/jax
- transformer_engine/jax/quantize
3 files changed
+14
-12
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
876 | 876 | | |
877 | 877 | | |
878 | 878 | | |
879 | | - | |
| 879 | + | |
880 | 880 | | |
881 | 881 | | |
882 | 882 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
631 | 631 | | |
632 | 632 | | |
633 | 633 | | |
634 | | - | |
635 | | - | |
636 | | - | |
637 | | - | |
| 634 | + | |
| 635 | + | |
638 | 636 | | |
639 | 637 | | |
640 | 638 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
34 | 34 | | |
35 | 35 | | |
36 | 36 | | |
| 37 | + | |
37 | 38 | | |
38 | 39 | | |
39 | 40 | | |
| |||
633 | 634 | | |
634 | 635 | | |
635 | 636 | | |
636 | | - | |
637 | | - | |
638 | | - | |
| 637 | + | |
| 638 | + | |
| 639 | + | |
| 640 | + | |
| 641 | + | |
639 | 642 | | |
640 | 643 | | |
641 | 644 | | |
642 | 645 | | |
643 | 646 | | |
644 | 647 | | |
645 | 648 | | |
646 | | - | |
647 | | - | |
| 649 | + | |
| 650 | + | |
| 651 | + | |
648 | 652 | | |
649 | 653 | | |
650 | 654 | | |
651 | 655 | | |
652 | | - | |
653 | | - | |
| 656 | + | |
| 657 | + | |
654 | 658 | | |
655 | 659 | | |
656 | 660 | | |
| |||
0 commit comments