Skip to content

CI: 04/17/25 upstream sync #370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 821 commits into from

Conversation

rocm-repo-management-api-2[bot]
Copy link

Daily sync with upstream

bythew3i and others added 30 commits April 8, 2025 10:52
So duplicated load/store ops can be removed.

PiperOrigin-RevId: 745209849
PiperOrigin-RevId: 745212009
Pass pytype_srcs as data to the pybind_extension rule.

PiperOrigin-RevId: 745238783
PiperOrigin-RevId: 745247778
These should be used directly from ml_dtypes.

PiperOrigin-RevId: 745256523
Now that jax-ml@db11efa has landed, we're free to split up xla_extension without creating binary size problems or having to be quite so careful about cross-module dependencies. Here weakref_lru_cache has absolutely nothing to do with XLA.

There's no reason weakref_lru_cache is in the same Python extension as everything else.

PiperOrigin-RevId: 745271825
…es` to True.

The main changes here are:

* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.

* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.

* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.

* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.

Co-authored-by: Matthew Johnson <[email protected]>
PiperOrigin-RevId: 745276474
…JAX_SKIP_SLOW_TESTS=true

Description:
- Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true to reduce the test execution time
  - especially for TSAN CI job where this test takes ~700 seconds to pass with the recent 3.13 cpython
- Removed optional deps for 3.14
I expected Mosaic can canonicalize 2 same strided loads to one but it did not. (We will fix this in Mosaic). For now, manually converting to one strided load boosts 20~35% speedup in both v6e and v5e single chip for Meta-Llama-3-8B.

PiperOrigin-RevId: 745294058
This (private) API will shortly be deleted, and hlo_to_stablehlo is its replacement.

PiperOrigin-RevId: 745333506
…houldn't expose this to public API and have users use `psum` instead which will dispatch to `psum_invariant` when `check_rep=True`.

PiperOrigin-RevId: 745352875
PiperOrigin-RevId: 745375892
…manager

This allows unreserving the barrier once it is no longer needed and is consistent
with how resource estimation works, e.g. for `cond`.

PiperOrigin-RevId: 745483567
When run under an optimized build and Python 3.13.2t, I saw the
following high probability crash in lax_control_flow_test:

```
                Stack trace of thread 3526917:
                #0  0x00007f0898c4bf91 dump_frame (libpython3.13t.so.1.0 + 0x24bf91)
                #1  0x00007f0898c4b73f dump_traceback (libpython3.13t.so.1.0 + 0x24b73f)
                #2  0x00007f0898c4b86f _Py_DumpTracebackThreads (libpython3.13t.so.1.0 + 0x24b86f)
                #3  0x00007f0898cd4fe0 faulthandler_dump_traceback (libpython3.13t.so.1.0 + 0x2d4fe0)
                #4  0x00007f0898cd4f44 faulthandler_fatal_error (libpython3.13t.so.1.0 + 0x2d4f44)
                #5  0x00007f0898849e20 __restore_rt (libc.so.6 + 0x3fe20)
                #6  0x00007f07eb80e493 _ZNSt8__detail16_Hashtable_allocISaINS_10_Hash_nodeISt4pairIKN3jax15WeakrefLRUCache15WeakrefCacheKeyENS4_17WeakrefCacheValueEELb1EEEEE18_M_deallocate_nodeEPS9_ (libjax_common.so + 0x2c0e493)
                #7  0x00007f07eb80e13e _ZN3jax15WeakrefLRUCache5ClearEv (libjax_common.so + 0x2c0e13e)
                #8  0x00007f07eb812e37 _ZZN8nanobind6detail11func_createILb0ELb1EZNS_16cpp_function_defIN3jax15WeakrefLRUCacheEvS4_JEJNS_5scopeENS_4nameENS_9is_methodENS_9lock_selfEEEEvMT1_FT0_DpT2_EDpRKT3_EUlPS4_E_vJSJ_EJLm0EEJS5_S6_S7_S8_EEEP>
                #9  0x00007f07eb7fff70 _ZN8nanobind6detailL25nb_func_vectorcall_simpleEP7_objectPKS2_mS2_ (libjax_common.so + 0x2bfff70)
                #10 0x00007f0898dbbdee _PyObject_VectorcallTstate (libpython3.13t.so.1.0 + 0x3bbdee)
                #11 0x00007f0898d1d4db _PyEval_EvalFrame (libpython3.13t.so.1.0 + 0x31d4db)
                #12 0x00007f0898d1ee78 _PyObject_VectorcallTstate (libpython3.13t.so.1.0 + 0x31ee78)
                #13 0x00007f0898dc0054 _PyVectorcall_Call (libpython3.13t.so.1.0 + 0x3c0054)
                #14 0x00007f0898d1d4db _PyEval_EvalFrame (libpython3.13t.so.1.0 + 0x31d4db)
                #15 0x00007f0898d1e02c _PyObject_VectorcallDictTstate (libpython3.13t.so.1.0 + 0x31e02c)
                #16 0x00007f0898ed8e35 slot_tp_call (libpython3.13t.so.1.0 + 0x4d8e35)
                #17 0x00007f0898dbc312 _PyObject_MakeTpCall (libpython3.13t.so.1.0 + 0x3bc312)
                #18 0x00007f0898d1d4db _PyEval_EvalFrame (libpython3.13t.so.1.0 + 0x31d4db)
                #19 0x00007f0898d1ef54 _PyObject_VectorcallTstate (libpython3.13t.so.1.0 + 0x31ef54)
                #20 0x00007f0899094c1f thread_run (libpython3.13t.so.1.0 + 0x694c1f)
                #21 0x00007f0898fa0c58 pythread_wrapper (libpython3.13t.so.1.0 + 0x5a0c58)
                #22 0x00007f089889c103 start_thread (libc.so.6 + 0x92103)
                #23 0x00007f089891a7b8 __clone3 (libc.so.6 + 0x1107b8)
```

It appears that this is due to freeing Python objects during
unordered_map::clear(), which may release the enclosing critical section
(`nb::lock_self()` on the method). Fix this by deferring destruction of
the both the keys and the values to after the map's destruction.
The CUDA-specific primitives need to be explicitly skipped.

PiperOrigin-RevId: 745504040
It looks like that release does not support Python 3.10 that is still within our support window.

PiperOrigin-RevId: 745508105
PiperOrigin-RevId: 745508763
Google-ML-Automation and others added 26 commits April 16, 2025 07:02
Since we do not have a Linux Arm64 RBE pool, we do not run the tests on Arm64. Instead, we cross-compile the test targets on the Linux x86 RBE pool. The job name on Linux Arm64 runs will now show "build only" to avoid any confusion. Also, run only a single Python version for Linux Arm64

PiperOrigin-RevId: 748374021
… we were checking for a scalar instead. Fixes jax-ml#28070

PiperOrigin-RevId: 748418451
… we were checking for a scalar instead. Fixes jax-ml#28070

PiperOrigin-RevId: 748418451
…oduleNotFoundError

PiperOrigin-RevId: 748440123
…oduleNotFoundError

PiperOrigin-RevId: 748440123
…ual axes on the mesh.

PiperOrigin-RevId: 748534841
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot requested a review from a team as a code owner April 17, 2025 06:02
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot enabled auto-merge (rebase) April 17, 2025 06:02
auto-merge was automatically disabled May 1, 2025 15:10

Pull request was closed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.